Skip to content

Instantly share code, notes, and snippets.

@tuulos
Created February 4, 2023 00:52
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tuulos/8e22d5e64bc93fb143d774a146f9170b to your computer and use it in GitHub Desktop.
Save tuulos/8e22d5e64bc93fb143d774a146f9170b to your computer and use it in GitHub Desktop.
Train a model with a config file using Metaflow
from metaflow import FlowSpec, step, IncludeFile
def dataset_wine():
    """Load the scikit-learn wine dataset.

    scikit-learn is imported inside the function, so it is only needed
    when this loader is actually called.

    Returns:
        Whatever ``load_wine(return_X_y=True)`` returns, i.e. the
        features and the targets as a two-element tuple.
    """
    from sklearn.datasets import load_wine

    return load_wine(return_X_y=True)
def model_knn(train_data, train_labels):
    """Fit a k-nearest-neighbours classifier with default settings.

    Args:
        train_data: Training features, passed straight to ``fit``.
        train_labels: Training targets, passed straight to ``fit``.

    Returns:
        The fitted ``KNeighborsClassifier`` instance.
    """
    from sklearn.neighbors import KNeighborsClassifier

    # ``fit`` returns the estimator itself, so the fitted model can be
    # handed back directly.
    return KNeighborsClassifier().fit(train_data, train_labels)
def model_svm(train_data, train_labels):
    """Fit a support-vector classifier with a polynomial kernel.

    Args:
        train_data: Training features, passed straight to ``fit``.
        train_labels: Training targets, passed straight to ``fit``.

    Returns:
        The fitted ``SVC`` instance.
    """
    from sklearn.svm import SVC

    # ``fit`` returns the estimator itself, so the fitted model can be
    # handed back directly.
    return SVC(kernel='poly').fit(train_data, train_labels)
# Registries that map a name to the callable implementing it.
# MODELS values take (train_data, train_labels) and return a fitted model;
# DATASETS values take no arguments and return the loaded data.
MODELS = dict(knn=model_knn, svm=model_svm)
DATASETS = dict(wine=dataset_wine)
class TrainWithConfigFlow(FlowSpec):
    """Train and evaluate a model chosen by a JSON config file.

    The config is supplied through the ``config`` IncludeFile parameter
    (default path ``config.json``). It must be a JSON object with two keys:

    * ``dataset`` -- a key of the module-level ``DATASETS`` registry
    * ``model`` -- a key of the module-level ``MODELS`` registry

    Steps run in the order ``start`` -> ``load_data`` -> ``train`` -> ``end``.
    """

    config_file = IncludeFile('config', default='config.json')

    @step
    def start(self):
        """Parse the config and reject bad dataset/model names up front.

        Raises:
            TypeError: If the config is not a JSON object.
            KeyError: If ``dataset`` or ``model`` is missing or does not
                name an entry in the corresponding registry.
        """
        import json

        self.config = json.loads(self.config_file)
        if not isinstance(self.config, dict):
            raise TypeError(
                'config must be a JSON object, got %s'
                % type(self.config).__name__
            )
        # Validate here so that a typo in the config fails immediately with
        # a readable message, instead of surfacing as a bare KeyError in a
        # later step after the data has already been loaded and split.
        for key, registry in (('dataset', DATASETS), ('model', MODELS)):
            name = self.config.get(key)
            # The isinstance check also keeps unhashable values (lists,
            # dicts) away from the registry lookup.
            if not isinstance(name, str) or name not in registry:
                raise KeyError(
                    'config[%r] must be one of %s, got %r'
                    % (key, sorted(registry), name)
                )
        self.next(self.load_data)

    @step
    def load_data(self):
        """Load the configured dataset and hold out 20% of it for testing.

        Stores ``train_data``, ``test_data``, ``train_labels`` and
        ``test_labels`` on the flow. The split uses ``random_state=0``.
        """
        from sklearn.model_selection import train_test_split

        print('Loading dataset', self.config['dataset'])
        X, y = DATASETS[self.config['dataset']]()
        (
            self.train_data,
            self.test_data,
            self.train_labels,
            self.test_labels,
        ) = train_test_split(X, y, test_size=0.2, random_state=0)
        self.next(self.train)

    @step
    def train(self):
        """Fit the configured model on the training split."""
        print("Training model", self.config['model'])
        trainer = MODELS[self.config['model']]
        self.model = trainer(self.train_data, self.train_labels)
        self.next(self.end)

    @step
    def end(self):
        """Score the fitted model on the held-out split and print it."""
        self.score = self.model.score(self.test_data, self.test_labels)
        print('Eval score', self.score)
if __name__ == '__main__':
    # Only launch the flow when this file is executed as a script, not
    # when it is imported. Metaflow flows are started by instantiating
    # the FlowSpec subclass.
    TrainWithConfigFlow()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment