1. Introduction

대용량 데이터를 처리하다보면 주로 Spark에서 나온 Parquet 파일을 다루는 일이 많습니다.
이런 데이터를 통해서 모델을 학습시키게 되는데, 워낙 데이터 사이즈가 크다 보니 데이터를 모두 메모리에 올려서 사용하는 방법은 적절하지 않습니다.
예를 들어 시스템에 메모리가 수천 기가바이트가 아닌 이상 Pandas 또는 Numpy를 통해 인메모리로 모두 올리는 것인 매우 비효율적입니다.
본문에서는 PyArrow 의 주요 기능들에 대해서 레퍼런스 형태로 빠르게 설명하도록 하겠습니다.

1.1 Installation

Macbook

brew install apache-arrow

Ubuntu

sudo apt update
sudo apt install -y -V ca-certificates lsb-release wget
wget https://apache.jfrog.io/artifactory/arrow/$(lsb_release --id --short | tr 'A-Z' 'a-z')/apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb
sudo apt install -y -V ./apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb
sudo apt update
sudo apt install -y -V libarrow-dev # For C++
sudo apt install -y -V libarrow-glib-dev # For GLib (C)
sudo apt install -y -V libarrow-dataset-dev # For Apache Arrow Dataset C++
sudo apt install -y -V libarrow-dataset-glib-dev # For Apache Arrow Dataset GLib (C)
sudo apt install -y -V libarrow-acero-dev # For Apache Arrow Acero
sudo apt install -y -V libarrow-flight-dev # For Apache Arrow Flight C++
sudo apt install -y -V libarrow-flight-glib-dev # For Apache Arrow Flight GLib (C)
sudo apt install -y -V libarrow-flight-sql-dev # For Apache Arrow Flight SQL C++
sudo apt install -y -V libarrow-flight-sql-glib-dev # For Apache Arrow Flight SQL GLib (C)
sudo apt install -y -V libgandiva-dev # For Gandiva C++
sudo apt install -y -V libgandiva-glib-dev # For Gandiva GLib (C)
sudo apt install -y -V libparquet-dev # For Apache Parquet C++
sudo apt install -y -V libparquet-glib-dev # For Apache Parquet GLib (C)

2. Dataset

2.1 Basic Iteration

import pyarrow.dataset as ds

dataset = ds.dataset('./data', format='parquet', partitioning=['dt'])
print('n_rows:', dataset.count_rows())

for batch in dataset.to_batches():
    # Pandas
    df = batch.to_pandas()
    
    # Pandas는 모든 데이터를 불러오는데, batch 에서 특정 column 만 불러올 수 있습니다.
    for i in range(batch.num_rows):
        col0 = batch.column('col_0')[0].as_py()

        

3 ParquetDataset

3.1 Basic Usage

ParquetDataset 으로 로딩시, partition 이 안불러와집니다. 반드시 위의 ds.dataset 으로 데이터 처리하세요

ParquetDataset 은 다음의 특징을 갖고 있습니다.

  1. 모든 데이터를 인메모리로 올리지 않음 (아래 Pyarrow size 보면 64이고, Pandas는 692,478)
  2. 디렉토리 지정시 Partition된 데이터를 한꺼번에 가져옴 (그 안에 dt=20230101 디렉토리 모두 읽음)
  3. fragments 통해서 각 파일의 메타정보를 꺼낼수 있음 (row 갯수)
  4. schema 통해서 어떤 컬럼들이 있는지 확인 가능

주요 parameters

  • filters: rows 를 필터링 할 수 있음!! 매우 유용
    ('x', '=', 0)
    ('y', 'in', ['a', 'b', 'c'])
    ('z', 'not in', {'a','b'})
    
  • memory_map: source가 file path 인 경우, memory_map 사용시 performance 증가 함
  • coerce_int96_timestamp_unit: INT96 으로 저장된 timestamp 를 특정 resolution으로 변환 (‘ms’)
  • metadata_nthreads: 메타데이터 읽을때 쓰레드 갯수를 지정하며, 특히 partitioned dataset 읽을 때 속도 증가 (현재 use_legacy_dataset 사용시 지원 안됨)
  • use_legacy_dataset: false 을 경우 새로운 Arrow Dataset API 를 사용함 -> filters 사용 가능하게 해줌

예제 코드

from pyarrow.parquet import ParquetDataset

dataset = ParquetDataset('./data',
                         memory_map=True,                         
                         use_legacy_dataset=False)
df = pd.read_parquet('./data')

file_rows = [frag.count_rows() for frag in dataset.fragments]

print('Pandas shape :', df.shape)
print('Pandas  size :', sys.getsizeof(df))
print('Pyarrow size :', sys.getsizeof(dataset))
print('files        :', dataset.files)
print('fragments    :', dataset.fragments)
print('files rows   :', file_rows)
print('column size  :', len(dataset.schema))
Pandas shape : (1000, 14)
Pandas  size : 692478
Pyarrow size : 64
files        : ['./data/dt=20230101/userdata.parquet']
fragments    : [<pyarrow.dataset.ParquetFileFragment path=./data/dt=20230101/userdata.parquet partition=[dt=20230101]>]
files rows   : [1000]
column size  : 14

3.2 Calculate total number of rows

import pyarrow.parquet as pq

dataset = pq.ParquetDataset('./data', use_legacy_dataset=False)
total_rows = 0
for frag in dataset.fragments:
    for i, row_group in enumerate(frag.row_groups):
        total_rows += row_group.num_rows

print(total_rows)  # 전체 rows 갯수.

3.3 Pandas Iteration

왜 이렇게 해야하냐하면.. 데이터 사이즈가 크기 때문입니다.
그냥 쉽게 pandas 로 불러올 수가 없기 때문입니다.
frag.to_batches() 사용시 batch iteration 을 할 수 있고, 이때 중요한 건 batch 는 실제로 데이터를 읽어온 객체 입니다. 따라서 해당 배치 사이즈 만큼의 데이터가 메모리에 올라가게 됩니다.
따라서 순차적 읽기에는 해당 iteration을 사용해서 처리시 효과적으로 처리 가능하나.. random access 가 필요할 시에는 문제가 발생 할 수 있습니다.

  • to_pandas() 함수 사용해서 pandas 로 꺼낼 수 있습니다.
  • Fragment Wiki
import pyarrow.parquet as pq

dataset = pq.ParquetDataset('./data', use_legacy_dataset=False)

for frag in dataset.fragments:
    for batch in frag.to_batches():
        df = batch.to_pandas()
        row = batch.take(pa.array([0]))  # 특정 row 를 꺼낼 수 있음
        
        print('num rows    :', batch.num_rows)
        print('Pandas shape:', df.shape)
        print(row['gender'])
num rows    : 1000
Pandas shape: (1000, 13)
[
  "Female"
]
	registration_dttm	id	first_name	last_name	email	gender	ip_address	cc	country	birthdate	salary	title	comments
0	2016-02-03 07:55:29	1	Amanda	Jordan	ajordan0@com.com	Female	1.197.201.2	6759521864920116	Indonesia	3/8/1971	49756.53	Internal Auditor	1E+02

3.4 Efficient Iteraration

Python 으로 필요한 데이터만 꺼낼수도 있습니다.
대용량 데이터를 다룰때 특히 pandas 로 전부다 꺼내지 말고, 필요한것만 python 으로 불러오는게 더 효율적입니다.

import pyarrow.parquet as pq

dataset = pq.ParquetDataset('./data', use_legacy_dataset=False)

for frag in dataset.fragments:
    for batch in frag.to_batches():
        for i in range(batch.num_rows):
            user_id = batch.column('user_id')[i].as_py()
            name = batch.column('name')[i].as_py()
            age = batch.column('age')[i].as_py()
            print(user_id, name, age)

3.5 PyTorch Dataset

import random
from bisect import bisect_right
from typing import List, Iterator, Optional, Tuple

import pandas as pd
from pyarrow.dataset import ParquetFileFragment
from pyarrow.lib import RecordBatch
from pyarrow.parquet import ParquetDataset
from torch.utils.data import Dataset


class PyArrowDataset(Dataset):
    """
    Restriction
     - Don't shuffle in Dataloader. this is for efficiency to precess large dataset.
       If you need to shuffle, do it before this custom dataset. (like in SparkSQL)
       But the algorithm supports random access.
    """

    def __init__(self, source: str, seed: int = 123):
        self.source = source
        self.seed = seed

        # Pyarrow
        self.dataset = ParquetDataset(source, use_legacy_dataset=False)
        self.fragments: List[ParquetFileFragment] = self.dataset.fragments
        self._batches: Optional[Iterator[RecordBatch]] = None
        self._batch: Optional[RecordBatch] = None

        # Indexing meta information to make search faster
        self._cumulative_n_rows: List[int] = []
        self._batch_idx: int = 0

        # Index
        self._fragment_idx = 0

        # Initialization
        self._init()

    def _init(self):
        random.seed(self.seed)

        self._cumulative_n_rows = [frag.count_rows() for frag in self.fragments]
        for i in range(1, len(self._cumulative_n_rows)):
            self._cumulative_n_rows[i] += self._cumulative_n_rows[i - 1]

    def _get_next(self, idx: int) -> Tuple[RecordBatch, int, int]:
        def get_prev_cum_frag_size(_fragment_idx):
            if _fragment_idx >= 1:
                return self._cumulative_n_rows[_fragment_idx - 1]
            return 0

        # Calculate fragment idx
        fragment_idx = self._fragment_idx
        fragment_changed = False
        _prev_size = get_prev_cum_frag_size(fragment_idx)
        _cur_size = self._cumulative_n_rows[self._fragment_idx]
        if (idx < _prev_size) or (idx >= _cur_size):
            fragment_idx = bisect_right(self._cumulative_n_rows, idx)
            assert fragment_idx < len(self.fragments)
            # fragment_idx %= len(self.fragments)
            fragment_changed = self._fragment_idx != fragment_idx
            self._fragment_idx = fragment_idx

            self._batch_idx = 0
            del self._batches
            del self._batch
            self._batches = None
            self._batch = None

        # Calculate batch idx
        _prev_size = get_prev_cum_frag_size(fragment_idx)
        batch_idx = idx - _prev_size
        batch_changed = batch_idx < self._batch_idx

        # Calculate batches of the fragment
        if self._batches is None or fragment_changed or batch_changed:
            if self._batches:
                self._batches.clear()

            self.batches = self.fragments[fragment_idx].to_batches()
            self._batch = None

        if self._batch is None:
            self._batch = next(self.batches)
            self._batch_idx = 0

        while True:
            if self._batch_idx <= batch_idx < self._batch_idx + self._batch.num_rows:
                break

            self._batch_idx += self._batch.num_rows
            self._batch = next(self.batches)

        return self._batch, fragment_idx, batch_idx - self._batch_idx

    def __del__(self):
        del self.dataset
        del self.fragments
        del self._batches
        del self._batch

    def __len__(self):
        return self._cumulative_n_rows[-1]


    def __getitem__(self, idx):
        batch, fragment_idx, batch_idx = self._get_next(idx)
        return batch.column('idx')[batch_idx].as_py()
dataset = PyArrowDataset("./data")
print(dataset[50000])
print(dataset[5000])
print(dataset[1])
print(dataset[500000])
print(dataset[151547])
print(dataset[127])
print(dataset[878])