Skip to content

Instantly share code, notes, and snippets.

@erikbern
Created December 7, 2022 22:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save erikbern/0a9c6e929021f1fdb9a89d325a800cff to your computer and use it in GitHub Desktop.
Save erikbern/0a9c6e929021f1fdb9a89d325a800cff to your computer and use it in GitHub Desktop.
import sys
import modal
# Modal stub shared by every @stub.function in this file. Its image is Debian
# slim plus the pip packages imported inside the function bodies below
# (datasets, transformers); torch is presumably the backend transformers
# uses for the model — confirm.
stub = modal.Stub(
    image=modal.Image.debian_slim().pip_install(["datasets", "torch", "transformers"])
)
class Predictor:
    """Sentiment scorer: ``__enter__`` builds the pipeline, ``predict`` uses it."""

    def __enter__(self):
        """Create the sentiment pipeline and store it on the instance.

        Returns nothing; the pipeline is kept as ``self.sentiment_pipeline``.
        NOTE(review): presumably Modal invokes this when a container starts,
        so the model is loaded once rather than per call — confirm against
        the Modal lifecycle documentation.
        """
        # Imported inside the method rather than at module level; presumably
        # because transformers is only installed in the remote image — confirm.
        from transformers import pipeline

        model_name = "distilbert-base-uncased-finetuned-sst-2-english"
        self.sentiment_pipeline = pipeline(model=model_name)

    @stub.function(cpu=4)
    def predict(self, phrase: str):
        """Score one piece of text.

        Args:
            phrase: The text to classify; the pipeline is asked to truncate
                it to ``max_length=512``.

        Returns:
            A ``(phrase, positive_score)`` tuple, where ``positive_score`` is
            the score the pipeline reports for the ``"POSITIVE"`` label.

        Raises:
            KeyError: If the pipeline output has no ``"POSITIVE"`` entry.
        """
        labelled_scores = self.sentiment_pipeline(
            phrase,
            truncation=True,
            max_length=512,
            top_k=2,
        )
        # Expected output shape, one entry per label, e.g.:
        # [{'label': 'NEGATIVE', 'score': 0.99}, {'label': 'POSITIVE', 'score': 0.01}]
        score_by_label = {}
        for entry in labelled_scores:
            score_by_label[entry["label"]] = entry["score"]
        return phrase, score_by_label["POSITIVE"]
@stub.function
def get_data():
    """Return the ``"text"`` field of every row in the IMDB ``"test"`` split.

    Returns:
        A list with one entry per row of ``load_dataset("imdb")["test"]``,
        in iteration order.
    """
    # Imported inside the function rather than at module level; presumably
    # because datasets is only installed in the remote image — confirm.
    from datasets import load_dataset

    test_split = load_dataset("imdb")["test"]
    return [example["text"] for example in test_split]
if __name__ == "__main__":
    # Everything below happens inside stub.run(); presumably that is what
    # makes the @stub.function calls execute on Modal — confirm.
    with stub.run():
        reviews = get_data()
        scored_reviews = Predictor().predict.map(reviews)
        # Print each score to four decimals followed by the first 80
        # characters of the text it belongs to.
        for phrase, score in scored_reviews:
            print(f"{score:.4f} {phrase[:80]}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment