# gist by @tuulos, created March 16, 2023 07:08

import random

from metaflow import (FlowSpec, step, S3, Parameter, profile,
                      kubernetes, conda, conda_base)

# change the columns to match your schema (or remove the column list to load all columns)
COLUMNS = ['VendorID', 'tpep_pickup_datetime', 'tpep_dropoff_datetime']

# group parquet files into batches of roughly 1GB each
def shard_data(src, batch_size=1_000_000_000):
    with S3() as s3:
        objs = s3.list_recursive([src])
        # shuffle so every batch gets a random mix of files
        random.shuffle(objs)
        while objs:
            size = 0
            batch = []
            # pop files into the batch until its total size reaches batch_size
            while objs and size < batch_size:
                obj = objs.pop()
                batch.append(obj.url)
                size += obj.size
            yield batch
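
# For example (the bucket here is hypothetical), you could inspect the shards with:
#
#   for batch in shard_data('s3://my-bucket/taxi-data/'):
#       print(len(batch), 'files in this shard')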

@conda_base(python='3.8.10')
class ShardedDataFlow(FlowSpec):
    s3root = Parameter('s3root', help="S3 root for data")

    @step
    def start(self):
        # materialize the shards up front, then fan out one task per shard
        self.shards = list(shard_data(self.s3root))
        self.next(self.process_shard_arrow, foreach='shards')

    @kubernetes(memory=12000)
    @conda(libraries={'pyarrow': '5.0.0'})
    @step
    def process_shard_arrow(self):
        import pyarrow
        from pyarrow.parquet import ParquetFile

        self.shard_files = self.input
        with S3() as s3:
            with profile('loading data'):
                # download all files in this shard in parallel
                objs = s3.get_many(self.shard_files)
            with profile('deserializing parquet'):
                # read only the requested columns and concatenate into one table
                table = pyarrow.concat_tables(
                    [ParquetFile(obj.path).read(columns=COLUMNS) for obj in objs])
        self.arrow_table_len = len(table)
        self.next(self.process_shard_polars)

    @kubernetes(memory=12000)
    @conda(libraries={'polars': '0.16.13'})
    @step
    def process_shard_polars(self):
        import polars

        self.shard_files = self.input
        with S3() as s3:
            with profile('loading data'):
                objs = s3.get_many(self.shard_files)
            with profile('deserializing polars'):
                table = polars.concat(
                    [polars.read_parquet(obj.path, columns=COLUMNS) for obj in objs])
        print('table', table)
        self.polars_table_len = len(table)
        self.next(self.join)

    @step
    def join(self, inputs):
        # sum the per-shard row counts computed in the fan-out
        print('total rows', sum(inp.arrow_table_len for inp in inputs))
        self.next(self.end)

    @step
    def end(self):
        pass


if __name__ == '__main__':
    ShardedDataFlow()
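
# A minimal sketch of how to run this flow, assuming Metaflow is configured
# with conda and a Kubernetes cluster for the @kubernetes steps (the file name
# and S3 path below are hypothetical):
#
#   python sharded_data_flow.py --environment=conda run --s3root s3://my-bucket/taxi-data/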