Skip to content

Instantly share code, notes, and snippets.

@dmesquita
Created July 5, 2020 18:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dmesquita/4828356f6990b075e623c67cb703640e to your computer and use it in GitHub Desktop.
import os
import yaml
from sklearn.datasets import fetch_20newsgroups
import pandas as pd
# Read the "prepare" stage parameters (including the newsgroup categories to keep).
# A context manager closes the file handle; the original bare open() leaked it.
with open('params.yaml', encoding='utf-8') as params_file:
    params = yaml.safe_load(params_file)['prepare']
categories = params['categories']

# Create the output folder (no error if it already exists).
data_path = os.path.join('data', 'prepared')
os.makedirs(data_path, exist_ok=True)

# Fetch the train and test splits, restricted to the configured categories.
newsgroups_train = fetch_20newsgroups(subset='train', categories=categories)
newsgroups_test = fetch_20newsgroups(subset='test', categories=categories)
def newsgroups_to_csv(split_name, newsgroups, out_dir=None):
    """Write one newsgroups split to ``<out_dir>/<split_name>.csv``.

    The CSV has the columns ``text``, ``target`` (integer class id) and
    ``target_name`` (the category name for that id), preceded by the default
    pandas index column.

    Args:
        split_name: Base name of the output file, e.g. ``"train"``.
        newsgroups: Object exposing ``data`` (documents), ``target``
            (class ids) and ``target_names`` (names indexed by class id),
            as returned by ``fetch_20newsgroups``.
        out_dir: Directory to write into. Defaults to the module-level
            ``data_path`` so existing two-argument calls keep working.
    """
    if out_dir is None:
        out_dir = data_path
    target_names = list(newsgroups.target_names)
    targets = [int(t) for t in newsgroups.target]
    # Build the frame column-wise and look names up by id. The previous
    # transpose + merge approach made every column object dtype and relied on
    # the merge to keep the original row order.
    df = pd.DataFrame({
        'text': list(newsgroups.data),
        'target': targets,
        'target_name': [target_names[t] for t in targets],
    })
    df.to_csv(os.path.join(out_dir, split_name + ".csv"))
# Persist both splits as CSV files named after the split ("train", then "test").
for split, bunch in (("train", newsgroups_train), ("test", newsgroups_test)):
    newsgroups_to_csv(split, bunch)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment