PyArrow
1. Introduction
대용량 데이터를 처리하다보면 주로 Spark에서 나온 Parquet 파일을 다루는 일이 많습니다.
이런 데이터를 통해서 모델을 학습시키게 되는데, 워낙 데이터 사이즈가 크다 보니 데이터를 모두 메모리에 올려서 사용하는 방법은 적절하지 않습니다.
예를 들어 시스템에 메모리가 수천 기가바이트가 아닌 이상 Pandas 또는 Numpy를 통해 인메모리로 모두 올리는 것인 매우 비효율적입니다.
본문에서는 PyArrow 의 주요 기능들에 대해서 레퍼런스 형태로 빠르게 설명하도록 하겠습니다.
1.1 Installation
Macbook
brew install apache-arrow
Ubuntu
sudo apt update
sudo apt install -y -V ca-certificates lsb-release wget
wget https://apache.jfrog.io/artifactory/arrow/$(lsb_release --id --short | tr 'A-Z' 'a-z')/apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb
sudo apt install -y -V ./apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb
sudo apt update
sudo apt install -y -V libarrow-dev # For C++
sudo apt install -y -V libarrow-glib-dev # For GLib (C)
sudo apt install -y -V libarrow-dataset-dev # For Apache Arrow Dataset C++
sudo apt install -y -V libarrow-dataset-glib-dev # For Apache Arrow Dataset GLib (C)
sudo apt install -y -V libarrow-acero-dev # For Apache Arrow Acero
sudo apt install -y -V libarrow-flight-dev # For Apache Arrow Flight C++
sudo apt install -y -V libarrow-flight-glib-dev # For Apache Arrow Flight GLib (C)
sudo apt install -y -V libarrow-flight-sql-dev # For Apache Arrow Flight SQL C++
sudo apt install -y -V libarrow-flight-sql-glib-dev # For Apache Arrow Flight SQL GLib (C)
sudo apt install -y -V libgandiva-dev # For Gandiva C++
sudo apt install -y -V libgandiva-glib-dev # For Gandiva GLib (C)
sudo apt install -y -V libparquet-dev # For Apache Parquet C++
sudo apt install -y -V libparquet-glib-dev # For Apache Parquet GLib (C)
2. Dataset
2.1 Basic Iteration
import pyarrow.dataset as ds
dataset = ds.dataset('./data', format='parquet', partitioning=['dt'])
print('n_rows:', dataset.count_rows())
for batch in dataset.to_batches():
# Pandas
df = batch.to_pandas()
# Pandas는 모든 데이터를 불러오는데, batch 에서 특정 column 만 불러올 수 있습니다.
for i in range(batch.num_rows):
col0 = batch.column('col_0')[0].as_py()
3 ParquetDataset
3.1 Basic Usage
ParquetDataset 으로 로딩시, partition 이 안불러와집니다. 반드시 위의 ds.dataset 으로 데이터 처리하세요
ParquetDataset 은 다음의 특징을 갖고 있습니다.
- 모든 데이터를 인메모리로 올리지 않음 (아래 Pyarrow size 보면 64이고, Pandas는 692,478)
- 디렉토리 지정시 Partition된 데이터를 한꺼번에 가져옴 (그 안에 dt=20230101 디렉토리 모두 읽음)
- fragments 통해서 각 파일의 메타정보를 꺼낼수 있음 (row 갯수)
- schema 통해서 어떤 컬럼들이 있는지 확인 가능
주요 parameters
filters
: rows 를 필터링 할 수 있음!! 매우 유용
('x', '=', 0) ('y', 'in', ['a', 'b', 'c']) ('z', 'not in', {'a','b'})
memory_map
: source가 file path 인 경우, memory_map 사용시 performance 증가 함coerce_int96_timestamp_unit
: INT96 으로 저장된 timestamp 를 특정 resolution으로 변환 (‘ms’)metadata_nthreads
: 메타데이터 읽을때 쓰레드 갯수를 지정하며, 특히 partitioned dataset 읽을 때 속도 증가 (현재use_legacy_dataset
사용시 지원 안됨)use_legacy_dataset
: false 을 경우 새로운 Arrow Dataset API 를 사용함 ->filters
사용 가능하게 해줌
예제 코드
from pyarrow.parquet import ParquetDataset
dataset = ParquetDataset('./data',
memory_map=True,
use_legacy_dataset=False)
df = pd.read_parquet('./data')
file_rows = [frag.count_rows() for frag in dataset.fragments]
print('Pandas shape :', df.shape)
print('Pandas size :', sys.getsizeof(df))
print('Pyarrow size :', sys.getsizeof(dataset))
print('files :', dataset.files)
print('fragments :', dataset.fragments)
print('files rows :', file_rows)
print('column size :', len(dataset.schema))
Pandas shape : (1000, 14)
Pandas size : 692478
Pyarrow size : 64
files : ['./data/dt=20230101/userdata.parquet']
fragments : [<pyarrow.dataset.ParquetFileFragment path=./data/dt=20230101/userdata.parquet partition=[dt=20230101]>]
files rows : [1000]
column size : 14
3.2 Calculate total number of rows
import pyarrow.parquet as pq
dataset = pq.ParquetDataset('./data', use_legacy_dataset=False)
total_rows = 0
for frag in dataset.fragments:
for i, row_group in enumerate(frag.row_groups):
total_rows += row_group.num_rows
print(total_rows) # 전체 rows 갯수.
3.3 Pandas Iteration
왜 이렇게 해야하냐하면.. 데이터 사이즈가 크기 때문입니다.
그냥 쉽게 pandas 로 불러올 수가 없기 때문입니다.
frag.to_batches()
사용시 batch iteration 을 할 수 있고,
이때 중요한 건 batch
는 실제로 데이터를 읽어온 객체 입니다. 따라서 해당 배치 사이즈 만큼의 데이터가 메모리에 올라가게 됩니다.
따라서 순차적 읽기에는 해당 iteration을 사용해서 처리시 효과적으로 처리 가능하나.. random access 가 필요할 시에는 문제가 발생 할 수 있습니다.
- to_pandas() 함수 사용해서 pandas 로 꺼낼 수 있습니다.
- Fragment Wiki
import pyarrow.parquet as pq
dataset = pq.ParquetDataset('./data', use_legacy_dataset=False)
for frag in dataset.fragments:
for batch in frag.to_batches():
df = batch.to_pandas()
row = batch.take(pa.array([0])) # 특정 row 를 꺼낼 수 있음
print('num rows :', batch.num_rows)
print('Pandas shape:', df.shape)
print(row['gender'])
num rows : 1000
Pandas shape: (1000, 13)
[
"Female"
]
registration_dttm id first_name last_name email gender ip_address cc country birthdate salary title comments
0 2016-02-03 07:55:29 1 Amanda Jordan ajordan0@com.com Female 1.197.201.2 6759521864920116 Indonesia 3/8/1971 49756.53 Internal Auditor 1E+02
3.4 Efficient Iteraration
Python 으로 필요한 데이터만 꺼낼수도 있습니다.
대용량 데이터를 다룰때 특히 pandas 로 전부다 꺼내지 말고, 필요한것만 python 으로 불러오는게 더 효율적입니다.
import pyarrow.parquet as pq
dataset = pq.ParquetDataset('./data', use_legacy_dataset=False)
for frag in dataset.fragments:
for batch in frag.to_batches():
for i in range(batch.num_rows):
user_id = batch.column('user_id')[i].as_py()
name = batch.column('name')[i].as_py()
age = batch.column('age')[i].as_py()
print(user_id, name, age)
3.5 PyTorch Dataset
import random
from bisect import bisect_right
from typing import List, Iterator, Optional, Tuple
import pandas as pd
from pyarrow.dataset import ParquetFileFragment
from pyarrow.lib import RecordBatch
from pyarrow.parquet import ParquetDataset
from torch.utils.data import Dataset
class PyArrowDataset(Dataset):
"""
Restriction
- Don't shuffle in Dataloader. this is for efficiency to precess large dataset.
If you need to shuffle, do it before this custom dataset. (like in SparkSQL)
But the algorithm supports random access.
"""
def __init__(self, source: str, seed: int = 123):
self.source = source
self.seed = seed
# Pyarrow
self.dataset = ParquetDataset(source, use_legacy_dataset=False)
self.fragments: List[ParquetFileFragment] = self.dataset.fragments
self._batches: Optional[Iterator[RecordBatch]] = None
self._batch: Optional[RecordBatch] = None
# Indexing meta information to make search faster
self._cumulative_n_rows: List[int] = []
self._batch_idx: int = 0
# Index
self._fragment_idx = 0
# Initialization
self._init()
def _init(self):
random.seed(self.seed)
self._cumulative_n_rows = [frag.count_rows() for frag in self.fragments]
for i in range(1, len(self._cumulative_n_rows)):
self._cumulative_n_rows[i] += self._cumulative_n_rows[i - 1]
def _get_next(self, idx: int) -> Tuple[RecordBatch, int, int]:
def get_prev_cum_frag_size(_fragment_idx):
if _fragment_idx >= 1:
return self._cumulative_n_rows[_fragment_idx - 1]
return 0
# Calculate fragment idx
fragment_idx = self._fragment_idx
fragment_changed = False
_prev_size = get_prev_cum_frag_size(fragment_idx)
_cur_size = self._cumulative_n_rows[self._fragment_idx]
if (idx < _prev_size) or (idx >= _cur_size):
fragment_idx = bisect_right(self._cumulative_n_rows, idx)
assert fragment_idx < len(self.fragments)
# fragment_idx %= len(self.fragments)
fragment_changed = self._fragment_idx != fragment_idx
self._fragment_idx = fragment_idx
self._batch_idx = 0
del self._batches
del self._batch
self._batches = None
self._batch = None
# Calculate batch idx
_prev_size = get_prev_cum_frag_size(fragment_idx)
batch_idx = idx - _prev_size
batch_changed = batch_idx < self._batch_idx
# Calculate batches of the fragment
if self._batches is None or fragment_changed or batch_changed:
if self._batches:
self._batches.clear()
self.batches = self.fragments[fragment_idx].to_batches()
self._batch = None
if self._batch is None:
self._batch = next(self.batches)
self._batch_idx = 0
while True:
if self._batch_idx <= batch_idx < self._batch_idx + self._batch.num_rows:
break
self._batch_idx += self._batch.num_rows
self._batch = next(self.batches)
return self._batch, fragment_idx, batch_idx - self._batch_idx
def __del__(self):
del self.dataset
del self.fragments
del self._batches
del self._batch
def __len__(self):
return self._cumulative_n_rows[-1]
def __getitem__(self, idx):
batch, fragment_idx, batch_idx = self._get_next(idx)
return batch.column('idx')[batch_idx].as_py()
dataset = PyArrowDataset("./data")
print(dataset[50000])
print(dataset[5000])
print(dataset[1])
print(dataset[500000])
print(dataset[151547])
print(dataset[127])
print(dataset[878])