Skip to content

Instantly share code, notes, and snippets.

@tuulos
Created August 4, 2023 21:25
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tuulos/70c468bb22be3e81abce828499210d1f to your computer and use it in GitHub Desktop.
Save tuulos/70c468bb22be3e81abce828499210d1f to your computer and use it in GitHub Desktop.
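Demonstrates resumable processing in a Metaflow step: each processed item is checkpointed in S3, so a retried step continues from the first unprocessed item.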
from metaflow import FlowSpec, step, retry, S3, current
from functools import wraps
import pickle
import random
# First component of every checkpoint key; the decorator below appends the
# current task's pathspec and a zero-padded item index to it.
PREFIX = "resumable-processing"
class resumable_processing:
    """Step decorator that checkpoints processed items in S3 so a retry can resume.

    The wrapped step reads its input sequence from the flow attribute named by
    ``process``. While the step body runs, two temporary attributes are set on
    the flow object:

    * ``iter_items`` -- iterator over the input items that have no checkpoint
      yet (the first ``k`` items are skipped when ``k`` checkpoints exist),
    * ``append_item(obj)`` -- pickles ``obj`` and stores it in S3 as the next
      checkpoint.

    After the step body returns, every checkpoint is loaded and the resulting
    list is stored in the flow attribute named by ``results``.

    Args:
        process: Name of the flow attribute holding the items to process. The
            value must support ``len()`` and slicing.
        results: Name of the flow attribute that receives the list of results.
    """

    def __init__(self, process="list", results="output"):
        self.input = process
        self.output = results

    def __call__(self, f):
        """Return ``f`` wrapped with checkpoint/resume handling.

        The returned wrapper raises ``Exception`` if, after ``f`` returns, the
        number of checkpoints differs from the number of input items. Any
        exception raised by ``f`` itself propagates unchanged.
        """

        @wraps(f)
        def func(s):
            # NOTE(review): resuming relies on current.pathspec being the same
            # for every retry of a task -- confirm against Metaflow docs.
            prefix = f"{PREFIX}/{current.pathspec}/"
            index = 0

            # checkpoint an item in S3
            def append(obj):
                nonlocal index
                with S3(run=s) as s3:
                    # zero-padded to 12 digits so that lexicographic key order
                    # equals processing order
                    s3.put(f"{prefix}{index:012d}", pickle.dumps(obj))
                index += 1

            # find checkpointed items in S3; their count is the resume offset
            with S3(run=s) as s3:
                index = sum(1 for _ in s3.list_recursive([prefix]))

            # make data and the checkpointing function available to the user code
            inputs = getattr(s, self.input)
            s.append_item = append
            s.iter_items = iter(inputs[index:])

            try:
                # call the user code
                f(s)
            finally:
                # delete these internal attributes so they won't get saved as
                # artifacts -- also when the user code raises
                delattr(s, "iter_items")
                delattr(s, "append_item")

            # retrieve all data; blobs are read while the S3 context is open
            with S3(run=s) as s3:
                # sort explicitly instead of relying on the listing order
                checkpoints = sorted(
                    s3.get_recursive([prefix]), key=lambda obj: obj.key
                )
                # NOTE: pickle.loads must only see trusted data; these blobs
                # are the ones written by append() above.
                out = [pickle.loads(obj.blob) for obj in checkpoints]

            # make sure all items were processed
            if len(out) != len(inputs):
                raise Exception("Not all items were processed")

            # save results in an artifact
            setattr(s, self.output, out)

        return func
class ResumingFlow(FlowSpec):
    """Example flow whose middle step fails at random and is retried."""

    @step
    def start(self):
        """Define the items to process and move on to the flaky step."""
        self.list = [4, 1, 2, 3, 9, 5]
        self.next(self.flaky_step)

    @retry
    @resumable_processing(process="list", results="output")
    @step
    def flaky_step(self):
        """Process the remaining items, raising at random with probability 0.3 per item."""
        print("processing", self.list)

        # iter_items is provided by the decorator and covers only the items
        # that have not been checkpointed yet
        for value in self.iter_items:
            print("processing item", value)

            # placeholder for the real work
            result = value + 2

            # simulated failure, triggered before the result is persisted
            should_fail = random.random() < 0.3
            if should_fail:
                raise Exception("Random failure!")

            # checkpoint the result through the decorator-provided helper
            self.append_item(result)
            print(value, "processed successfully")

        self.next(self.end)

    @step
    def end(self):
        """Print the results collected by the flaky step."""
        print("results", self.output)
# Instantiate the flow only when this file is executed as a script.
if __name__ == "__main__":
    ResumingFlow()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment