Last active
July 10, 2020 00:25
-
-
Save keitakurita/0fac1cc175591971a18c456ce18c45a5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%load_ext autoreload\n", | |
"%autoreload 2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from pathlib import Path\n", | |
"from typing import *\n", | |
"import torch\n", | |
"import torch.optim as optim\n", | |
"import numpy as np\n", | |
"import pandas as pd\n", | |
"from functools import partial\n", | |
"from overrides import overrides\n", | |
"\n", | |
"from allennlp.data import Instance\n", | |
"from allennlp.data.token_indexers import TokenIndexer\n", | |
"from allennlp.data.tokenizers import Token\n", | |
"from allennlp.nn import util as nn_util" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Config(dict):\n", | |
" def __init__(self, **kwargs):\n", | |
" super().__init__(**kwargs)\n", | |
" for k, v in kwargs.items():\n", | |
" setattr(self, k, v)\n", | |
" \n", | |
" def set(self, key, val):\n", | |
" self[key] = val\n", | |
" setattr(self, key, val)\n", | |
" \n", | |
"config = Config(\n", | |
" testing=True,\n", | |
" seed=1,\n", | |
" batch_size=64,\n", | |
" lr=3e-4,\n", | |
" epochs=2,\n", | |
" hidden_sz=64,\n", | |
" max_seq_len=100, # necessary to limit memory usage\n", | |
" max_vocab_size=100000,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from allennlp.common.checks import ConfigurationError" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"USE_GPU = torch.cuda.is_available()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"DATA_ROOT = Path(\"../data\") / \"jigsaw\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Set random seed manually to replicate results" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<torch._C.Generator at 0x1176dd710>" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"torch.manual_seed(config.seed)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Load Data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from allennlp.data.vocabulary import Vocabulary\n", | |
"from allennlp.data.dataset_readers import DatasetReader" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Prepare dataset" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"label_cols = [\"toxic\", \"severe_toxic\", \"obscene\",\n", | |
" \"threat\", \"insult\", \"identity_hate\"]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from allennlp.data.fields import TextField, MetadataField, ArrayField\n", | |
"\n", | |
"class JigsawDatasetReader(DatasetReader):\n", | |
"    \"\"\"Reads the Jigsaw toxic-comment CSV into AllenNLP Instances.\"\"\"\n", | |
"    def __init__(self, tokenizer: Callable[[str], List[str]]=lambda x: x.split(),\n", | |
"                 token_indexers: Dict[str, TokenIndexer] = None,\n", | |
"                 max_seq_len: Optional[int]=config.max_seq_len) -> None:\n", | |
"        super().__init__(lazy=False)\n", | |
"        self.tokenizer = tokenizer\n", | |
"        self.token_indexers = token_indexers or {\"tokens\": SingleIdTokenIndexer()}\n", | |
"        self.max_seq_len = max_seq_len\n", | |
"\n", | |
"    @overrides\n", | |
"    def text_to_instance(self, tokens: List[Token], id: str,\n", | |
"                         labels: np.ndarray) -> Instance:\n", | |
"        sentence_field = TextField(tokens, self.token_indexers)\n", | |
"        fields = {\"tokens\": sentence_field}\n", | |
"        \n", | |
"        id_field = MetadataField(id)\n", | |
"        fields[\"id\"] = id_field\n", | |
"        \n", | |
"        label_field = ArrayField(array=labels)\n", | |
"        fields[\"label\"] = label_field\n", | |
"\n", | |
"        return Instance(fields)\n", | |
"    \n", | |
"    @overrides\n", | |
"    def _read(self, file_path: str) -> Iterator[Instance]:\n", | |
"        df = pd.read_csv(file_path)\n", | |
"        if config.testing: df = df.head(1000)\n", | |
"        for i, row in df.iterrows():\n", | |
"            tokens = [Token(x) for x in self.tokenizer(row[\"comment_text\"])]\n", | |
"            # enforce the reader's own limit; previously max_seq_len was stored but never used\n", | |
"            if self.max_seq_len is not None:\n", | |
"                tokens = tokens[:self.max_seq_len]\n", | |
"            yield self.text_to_instance(tokens, row[\"id\"], row[label_cols].values)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Prepare token handlers" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We will use the spaCy tokenizer here" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter\n", | |
"from allennlp.data.token_indexers import SingleIdTokenIndexer\n", | |
"\n", | |
"# the token indexer is responsible for mapping tokens to integers\n", | |
"token_indexer = SingleIdTokenIndexer()\n", | |
"\n", | |
"# build the spaCy splitter once: constructing it inside tokenizer() reloaded the\n", | |
"# spaCy pipeline on every call, which dominated dataset-reading time\n", | |
"word_splitter = SpacyWordSplitter(language='en_core_web_sm', pos_tags=False)\n", | |
"\n", | |
"def tokenizer(x: str):\n", | |
"    \"\"\"Tokenize with spaCy and truncate to config.max_seq_len tokens.\"\"\"\n", | |
"    return [w.text for w in word_splitter.split_words(x)[:config.max_seq_len]]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"reader = JigsawDatasetReader(\n", | |
" tokenizer=tokenizer,\n", | |
" token_indexers={\"tokens\": token_indexer}\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"267it [00:02, 94.93it/s]\n", | |
"251it [00:01, 172.26it/s]\n" | |
] | |
} | |
], | |
"source": [ | |
"train_ds, test_ds = (reader.read(DATA_ROOT / fname) for fname in [\"train.csv\", \"test_proced.csv\"])\n", | |
"val_ds = None" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"267" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"len(train_ds)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[<allennlp.data.instance.Instance at 0x1a2b034160>,\n", | |
" <allennlp.data.instance.Instance at 0x1a2b016208>,\n", | |
" <allennlp.data.instance.Instance at 0x1a2afec748>,\n", | |
" <allennlp.data.instance.Instance at 0x1a2af92828>,\n", | |
" <allennlp.data.instance.Instance at 0x1a2af8a4a8>,\n", | |
" <allennlp.data.instance.Instance at 0x1a2af7e630>,\n", | |
" <allennlp.data.instance.Instance at 0x1a2af79710>,\n", | |
" <allennlp.data.instance.Instance at 0x1a2af66550>,\n", | |
" <allennlp.data.instance.Instance at 0x10bd9a518>,\n", | |
" <allennlp.data.instance.Instance at 0x1a28d5def0>]" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"train_ds[:10]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'tokens': [Explanation,\n", | |
" Why,\n", | |
" the,\n", | |
" edits,\n", | |
" made,\n", | |
" under,\n", | |
" my,\n", | |
" username,\n", | |
" Hardcore,\n", | |
" Metallica,\n", | |
" Fan,\n", | |
" were,\n", | |
" reverted,\n", | |
" ?,\n", | |
" They,\n", | |
" were,\n", | |
" n't,\n", | |
" vandalisms,\n", | |
" ,,\n", | |
" just,\n", | |
" closure,\n", | |
" on,\n", | |
" some,\n", | |
" GAs,\n", | |
" after,\n", | |
" I,\n", | |
" voted,\n", | |
" at,\n", | |
" New,\n", | |
" York,\n", | |
" Dolls,\n", | |
" FAC,\n", | |
" .,\n", | |
" And,\n", | |
" please,\n", | |
" do,\n", | |
" n't,\n", | |
" remove,\n", | |
" the,\n", | |
" template,\n", | |
" from,\n", | |
" the,\n", | |
" talk,\n", | |
" page,\n", | |
" since,\n", | |
" I,\n", | |
" 'm,\n", | |
" retired,\n", | |
" now.89.205.38.27],\n", | |
" '_token_indexers': {'tokens': <allennlp.data.token_indexers.single_id_token_indexer.SingleIdTokenIndexer at 0x1a27b07400>},\n", | |
" '_indexed_tokens': None,\n", | |
" '_indexer_name_to_indexed_token': None}" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"vars(train_ds[0].fields[\"tokens\"])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Prepare vocabulary" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"01/29/2019 19:57:00 - INFO - allennlp.data.vocabulary - Fitting token dictionary from dataset.\n", | |
"100%|██████████| 267/267 [00:00<00:00, 11635.95it/s]\n" | |
] | |
} | |
], | |
"source": [ | |
"vocab = Vocabulary.from_instances(train_ds, max_vocab_size=config.max_vocab_size)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Prepare iterator" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"The iterator is responsible for batching the data and preparing it for input into the model. We'll use the BucketIterator that batches text sequences of similar lengths together." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from allennlp.data.iterators import BucketIterator" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"iterator = BucketIterator(batch_size=config.batch_size, \n", | |
" sorting_keys=[(\"tokens\", \"num_tokens\")],\n", | |
" )" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We need to tell the iterator how to numericalize the text data. We do this by passing the vocabulary to the iterator. This step is easy to forget so be careful! " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"iterator.index_with(vocab)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Read sample" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"batch = next(iter(iterator(train_ds)))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'tokens': {'tokens': tensor([[ 131, 264, 21, ..., 0, 0, 0],\n", | |
" [ 74, 203, 24, ..., 0, 0, 0],\n", | |
" [ 5, 85, 26, ..., 0, 0, 0],\n", | |
" ...,\n", | |
" [ 5, 103, 1068, ..., 0, 0, 0],\n", | |
" [2972, 622, 1099, ..., 0, 0, 0],\n", | |
" [ 5, 3301, 8, ..., 0, 0, 0]])},\n", | |
" 'id': ['0029541a38c523a0',\n", | |
" '003d77a20601cec1',\n", | |
" '006fda507acd9769',\n", | |
" '00173958f46763a2',\n", | |
" '00a5394e626e72c6',\n", | |
" '005e2ae8f864f76c',\n", | |
" '0060c5c9030b2d14',\n", | |
" '0095756047a71716',\n", | |
" '0000997932d777bf',\n", | |
" '0061b075244dd234',\n", | |
" '002d6c9d9f85e81f',\n", | |
" '001735f961a23fc4',\n", | |
" '000113f07ec002fd',\n", | |
" '0029b87aa9c7dc4a',\n", | |
" '00070ef96486d6f9',\n", | |
" '007bc29766a43e3c',\n", | |
" '004f5608984d99f1',\n", | |
" '000b08c464718505',\n", | |
" '007bbfa4da2bc32d',\n", | |
" '00537730daf8c5f1',\n", | |
" '008f22e7b58e559b',\n", | |
" '009b3b15f1ada72f',\n", | |
" '004f981460421bdf',\n", | |
" '000f35deef84dc4a',\n", | |
" '002f0e29c60807b1',\n", | |
" '003dbd1b9b354c1f',\n", | |
" '004f6dbe69f3545d',\n", | |
" '00733f0a4a58cf42',\n", | |
" '0057b7710cb5ebb2',\n", | |
" '006f2c1459f3b6b1',\n", | |
" '00349c6325526c11',\n", | |
" '0030614cfd96d9d1',\n", | |
" '006120d209a4a46c',\n", | |
" '00744c2f77391702',\n", | |
" '007571394afafcb5',\n", | |
" '0005c987bdfc9d4b',\n", | |
" '0052a7e684beeb1a',\n", | |
" '0022cf8467ebc9fd',\n", | |
" '004de318396bbf8b',\n", | |
" '0053bab79133c0fc',\n", | |
" '001956c382006abd',\n", | |
" '0015f4aa35ebe9b5',\n", | |
" '000ffab30195c5e1',\n", | |
" '002a6beca33307b3',\n", | |
" '0028d62e8a5629aa',\n", | |
" '0063dd8f202a698a',\n", | |
" '008198c5a9d85a8e',\n", | |
" '007f1839ada915e6',\n", | |
" '0037e59caead9dab',\n", | |
" '00a20f187531df59',\n", | |
" '00585c1da10b448b',\n", | |
" '00961bcaadd6a278',\n", | |
" '0082b4b42b3f07a1',\n", | |
" '005b214511a69b4b',\n", | |
" '008e2acf5bcf4be4',\n", | |
" '002264ea4d5f2887',\n", | |
" '0013a8b1a5f26bcb',\n", | |
" '0074b307c2d9a100',\n", | |
" '008f93320e3661b8',\n", | |
" '008344a80c43b8c9',\n", | |
" '0038f191ffc93d75',\n", | |
" '0069e6d57a3beb51',\n", | |
" '0087c131ccffe160',\n", | |
" '00a330961879175c'],\n", | |
" 'label': tensor([[0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [1., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [1., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [1., 0., 1., 0., 1., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [1., 0., 1., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [1., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0., 0.]])}" | |
] | |
}, | |
"execution_count": 22, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"batch" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[ 131, 264, 21, ..., 0, 0, 0],\n", | |
" [ 74, 203, 24, ..., 0, 0, 0],\n", | |
" [ 5, 85, 26, ..., 0, 0, 0],\n", | |
" ...,\n", | |
" [ 5, 103, 1068, ..., 0, 0, 0],\n", | |
" [2972, 622, 1099, ..., 0, 0, 0],\n", | |
" [ 5, 3301, 8, ..., 0, 0, 0]])" | |
] | |
}, | |
"execution_count": 23, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"batch[\"tokens\"][\"tokens\"]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"torch.Size([64, 84])" | |
] | |
}, | |
"execution_count": 24, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"batch[\"tokens\"][\"tokens\"].shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Prepare Model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import torch.optim as optim" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from allennlp.modules.seq2vec_encoders import Seq2VecEncoder, PytorchSeq2VecWrapper\n", | |
"from allennlp.nn.util import get_text_field_mask\n", | |
"from allennlp.models import Model\n", | |
"from allennlp.modules.text_field_embedders import TextFieldEmbedder\n", | |
"\n", | |
"class BaselineModel(Model):\n", | |
"    \"\"\"Embed tokens, encode them to a single vector, and project to one logit per label.\"\"\"\n", | |
"    def __init__(self, word_embeddings: TextFieldEmbedder,\n", | |
"                 encoder: Seq2VecEncoder,\n", | |
"                 out_sz: int=len(label_cols),\n", | |
"                 vocab: Vocabulary=vocab):\n", | |
"        # vocab is now an explicit (defaulted) argument rather than a hidden notebook global\n", | |
"        super().__init__(vocab)\n", | |
"        self.word_embeddings = word_embeddings\n", | |
"        self.encoder = encoder\n", | |
"        self.projection = nn.Linear(self.encoder.get_output_dim(), out_sz)\n", | |
"        self.loss = nn.BCEWithLogitsLoss()\n", | |
"    \n", | |
"    def forward(self, tokens: Dict[str, torch.Tensor],\n", | |
"                id: Any, label: torch.Tensor) -> Dict[str, torch.Tensor]:\n", | |
"        # mask out padding positions before encoding\n", | |
"        mask = get_text_field_mask(tokens)\n", | |
"        embeddings = self.word_embeddings(tokens)\n", | |
"        state = self.encoder(embeddings, mask)\n", | |
"        class_logits = self.projection(state)\n", | |
"        \n", | |
"        output = {\"class_logits\": class_logits}\n", | |
"        # multi-label objective: sigmoid + binary cross-entropy per label column\n", | |
"        output[\"loss\"] = self.loss(class_logits, label)\n", | |
"\n", | |
"        return output" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Prepare embeddings" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from allennlp.modules.token_embedders import Embedding\n", | |
"from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder\n", | |
"\n", | |
"# size the embedding table to the fitted vocabulary, not to the max_vocab_size cap:\n", | |
"# config.max_vocab_size + 2 allocates 100,002 rows of 300-dim floats even when the\n", | |
"# actual vocabulary (as fitted above) is far smaller\n", | |
"token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),\n", | |
"                            embedding_dim=300, padding_index=0)\n", | |
"# the embedder maps the input tokens to the appropriate embedding matrix\n", | |
"word_embeddings: TextFieldEmbedder = BasicTextFieldEmbedder({\"tokens\": token_embedding})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper\n", | |
"encoder: Seq2VecEncoder = PytorchSeq2VecWrapper(nn.LSTM(word_embeddings.get_output_dim(),\n", | |
" config.hidden_sz, bidirectional=True, batch_first=True))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Notice how simple and modular the code for initializing the model is. All the complexity is delegated to each component." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model = BaselineModel(\n", | |
" word_embeddings, \n", | |
" encoder, \n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"if USE_GPU: model.cuda()\n", | |
"else: model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Basic sanity checks" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"batch = nn_util.move_to_device(batch, 0 if USE_GPU else -1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"tokens = batch[\"tokens\"]\n", | |
"# bug fix: previously `labels = batch` bound the whole batch dict, not the label tensor\n", | |
"labels = batch[\"label\"]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'tokens': tensor([[ 131, 264, 21, ..., 0, 0, 0],\n", | |
" [ 74, 203, 24, ..., 0, 0, 0],\n", | |
" [ 5, 85, 26, ..., 0, 0, 0],\n", | |
" ...,\n", | |
" [ 5, 103, 1068, ..., 0, 0, 0],\n", | |
" [2972, 622, 1099, ..., 0, 0, 0],\n", | |
" [ 5, 3301, 8, ..., 0, 0, 0]])}" | |
] | |
}, | |
"execution_count": 33, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"tokens" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 34, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[1, 1, 1, ..., 0, 0, 0],\n", | |
" [1, 1, 1, ..., 0, 0, 0],\n", | |
" [1, 1, 1, ..., 0, 0, 0],\n", | |
" ...,\n", | |
" [1, 1, 1, ..., 0, 0, 0],\n", | |
" [1, 1, 1, ..., 0, 0, 0],\n", | |
" [1, 1, 1, ..., 0, 0, 0]])" | |
] | |
}, | |
"execution_count": 34, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"mask = get_text_field_mask(tokens)\n", | |
"mask" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 35, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[-0.0372, 0.0801, -0.0296, 0.0281, 0.0153, -0.0244],\n", | |
" [-0.0387, 0.0801, -0.0318, 0.0281, 0.0164, -0.0250],\n", | |
" [-0.0384, 0.0804, -0.0303, 0.0254, 0.0170, -0.0253],\n", | |
" [-0.0393, 0.0794, -0.0324, 0.0276, 0.0162, -0.0238],\n", | |
" [-0.0388, 0.0807, -0.0317, 0.0281, 0.0175, -0.0239],\n", | |
" [-0.0397, 0.0785, -0.0307, 0.0275, 0.0172, -0.0237],\n", | |
" [-0.0382, 0.0805, -0.0304, 0.0266, 0.0164, -0.0243],\n", | |
" [-0.0376, 0.0781, -0.0321, 0.0286, 0.0144, -0.0258],\n", | |
" [-0.0392, 0.0809, -0.0318, 0.0294, 0.0155, -0.0242],\n", | |
" [-0.0386, 0.0801, -0.0310, 0.0295, 0.0176, -0.0244],\n", | |
" [-0.0380, 0.0806, -0.0310, 0.0264, 0.0162, -0.0255],\n", | |
" [-0.0396, 0.0808, -0.0306, 0.0273, 0.0169, -0.0257],\n", | |
" [-0.0389, 0.0808, -0.0314, 0.0292, 0.0173, -0.0251],\n", | |
" [-0.0398, 0.0805, -0.0308, 0.0264, 0.0161, -0.0255],\n", | |
" [-0.0376, 0.0790, -0.0300, 0.0289, 0.0143, -0.0261],\n", | |
" [-0.0393, 0.0804, -0.0301, 0.0296, 0.0165, -0.0248],\n", | |
" [-0.0391, 0.0809, -0.0312, 0.0263, 0.0168, -0.0254],\n", | |
" [-0.0388, 0.0803, -0.0315, 0.0268, 0.0168, -0.0255],\n", | |
" [-0.0381, 0.0799, -0.0306, 0.0290, 0.0154, -0.0245],\n", | |
" [-0.0379, 0.0810, -0.0316, 0.0274, 0.0140, -0.0264],\n", | |
" [-0.0392, 0.0802, -0.0315, 0.0271, 0.0163, -0.0254],\n", | |
" [-0.0377, 0.0803, -0.0314, 0.0279, 0.0165, -0.0237],\n", | |
" [-0.0394, 0.0814, -0.0301, 0.0294, 0.0172, -0.0246],\n", | |
" [-0.0384, 0.0813, -0.0299, 0.0286, 0.0169, -0.0247],\n", | |
" [-0.0399, 0.0799, -0.0328, 0.0273, 0.0149, -0.0255],\n", | |
" [-0.0383, 0.0796, -0.0309, 0.0283, 0.0150, -0.0262],\n", | |
" [-0.0395, 0.0779, -0.0302, 0.0266, 0.0139, -0.0243],\n", | |
" [-0.0394, 0.0797, -0.0304, 0.0267, 0.0151, -0.0246],\n", | |
" [-0.0395, 0.0805, -0.0316, 0.0269, 0.0161, -0.0256],\n", | |
" [-0.0374, 0.0806, -0.0300, 0.0263, 0.0157, -0.0255],\n", | |
" [-0.0385, 0.0813, -0.0313, 0.0279, 0.0175, -0.0247],\n", | |
" [-0.0374, 0.0794, -0.0316, 0.0272, 0.0154, -0.0256],\n", | |
" [-0.0388, 0.0797, -0.0314, 0.0269, 0.0165, -0.0253],\n", | |
" [-0.0401, 0.0795, -0.0311, 0.0280, 0.0162, -0.0247],\n", | |
" [-0.0391, 0.0808, -0.0308, 0.0271, 0.0166, -0.0259],\n", | |
" [-0.0399, 0.0808, -0.0315, 0.0270, 0.0164, -0.0260],\n", | |
" [-0.0390, 0.0803, -0.0306, 0.0267, 0.0164, -0.0247],\n", | |
" [-0.0394, 0.0803, -0.0299, 0.0286, 0.0168, -0.0236],\n", | |
" [-0.0387, 0.0800, -0.0304, 0.0267, 0.0163, -0.0255],\n", | |
" [-0.0409, 0.0800, -0.0297, 0.0295, 0.0179, -0.0246],\n", | |
" [-0.0381, 0.0802, -0.0332, 0.0277, 0.0154, -0.0244],\n", | |
" [-0.0377, 0.0800, -0.0298, 0.0282, 0.0167, -0.0252],\n", | |
" [-0.0396, 0.0812, -0.0302, 0.0295, 0.0158, -0.0243],\n", | |
" [-0.0394, 0.0799, -0.0318, 0.0285, 0.0161, -0.0245],\n", | |
" [-0.0390, 0.0808, -0.0313, 0.0279, 0.0153, -0.0254],\n", | |
" [-0.0387, 0.0797, -0.0304, 0.0258, 0.0163, -0.0245],\n", | |
" [-0.0393, 0.0803, -0.0324, 0.0285, 0.0175, -0.0229],\n", | |
" [-0.0387, 0.0778, -0.0314, 0.0280, 0.0158, -0.0258],\n", | |
" [-0.0375, 0.0809, -0.0312, 0.0277, 0.0151, -0.0259],\n", | |
" [-0.0395, 0.0806, -0.0313, 0.0276, 0.0151, -0.0245],\n", | |
" [-0.0387, 0.0797, -0.0311, 0.0275, 0.0142, -0.0256],\n", | |
" [-0.0394, 0.0804, -0.0304, 0.0289, 0.0174, -0.0231],\n", | |
" [-0.0379, 0.0788, -0.0325, 0.0275, 0.0152, -0.0239],\n", | |
" [-0.0387, 0.0802, -0.0319, 0.0261, 0.0165, -0.0248],\n", | |
" [-0.0401, 0.0804, -0.0319, 0.0292, 0.0139, -0.0255],\n", | |
" [-0.0386, 0.0815, -0.0327, 0.0297, 0.0152, -0.0263],\n", | |
" [-0.0402, 0.0799, -0.0302, 0.0263, 0.0168, -0.0246],\n", | |
" [-0.0391, 0.0801, -0.0312, 0.0280, 0.0169, -0.0242],\n", | |
" [-0.0393, 0.0799, -0.0310, 0.0270, 0.0163, -0.0258],\n", | |
" [-0.0383, 0.0800, -0.0317, 0.0283, 0.0172, -0.0236],\n", | |
" [-0.0367, 0.0808, -0.0306, 0.0292, 0.0149, -0.0246],\n", | |
" [-0.0386, 0.0797, -0.0308, 0.0266, 0.0170, -0.0247],\n", | |
" [-0.0389, 0.0798, -0.0319, 0.0276, 0.0171, -0.0256],\n", | |
" [-0.0388, 0.0804, -0.0307, 0.0261, 0.0168, -0.0248]],\n", | |
" grad_fn=<AddmmBackward>)" | |
] | |
}, | |
"execution_count": 35, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"embeddings = model.word_embeddings(tokens)\n", | |
"state = model.encoder(embeddings, mask)\n", | |
"class_logits = model.projection(state)\n", | |
"class_logits" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 36, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'class_logits': tensor([[-0.0372, 0.0801, -0.0296, 0.0281, 0.0153, -0.0244],\n", | |
" [-0.0387, 0.0801, -0.0318, 0.0281, 0.0164, -0.0250],\n", | |
" [-0.0384, 0.0804, -0.0303, 0.0254, 0.0170, -0.0253],\n", | |
" [-0.0393, 0.0794, -0.0324, 0.0276, 0.0162, -0.0238],\n", | |
" [-0.0388, 0.0807, -0.0317, 0.0281, 0.0175, -0.0239],\n", | |
" [-0.0397, 0.0785, -0.0307, 0.0275, 0.0172, -0.0237],\n", | |
" [-0.0382, 0.0805, -0.0304, 0.0266, 0.0164, -0.0243],\n", | |
" [-0.0376, 0.0781, -0.0321, 0.0286, 0.0144, -0.0258],\n", | |
" [-0.0392, 0.0809, -0.0318, 0.0294, 0.0155, -0.0242],\n", | |
" [-0.0386, 0.0801, -0.0310, 0.0295, 0.0176, -0.0244],\n", | |
" [-0.0380, 0.0806, -0.0310, 0.0264, 0.0162, -0.0255],\n", | |
" [-0.0396, 0.0808, -0.0306, 0.0273, 0.0169, -0.0257],\n", | |
" [-0.0389, 0.0808, -0.0314, 0.0292, 0.0173, -0.0251],\n", | |
" [-0.0398, 0.0805, -0.0308, 0.0264, 0.0161, -0.0255],\n", | |
" [-0.0376, 0.0790, -0.0300, 0.0289, 0.0143, -0.0261],\n", | |
" [-0.0393, 0.0804, -0.0301, 0.0296, 0.0165, -0.0248],\n", | |
" [-0.0391, 0.0809, -0.0312, 0.0263, 0.0168, -0.0254],\n", | |
" [-0.0388, 0.0803, -0.0315, 0.0268, 0.0168, -0.0255],\n", | |
" [-0.0381, 0.0799, -0.0306, 0.0290, 0.0154, -0.0245],\n", | |
" [-0.0379, 0.0810, -0.0316, 0.0274, 0.0140, -0.0264],\n", | |
" [-0.0392, 0.0802, -0.0315, 0.0271, 0.0163, -0.0254],\n", | |
" [-0.0377, 0.0803, -0.0314, 0.0279, 0.0165, -0.0237],\n", | |
" [-0.0394, 0.0814, -0.0301, 0.0294, 0.0172, -0.0246],\n", | |
" [-0.0384, 0.0813, -0.0299, 0.0286, 0.0169, -0.0247],\n", | |
" [-0.0399, 0.0799, -0.0328, 0.0273, 0.0149, -0.0255],\n", | |
" [-0.0383, 0.0796, -0.0309, 0.0283, 0.0150, -0.0262],\n", | |
" [-0.0395, 0.0779, -0.0302, 0.0266, 0.0139, -0.0243],\n", | |
" [-0.0394, 0.0797, -0.0304, 0.0267, 0.0151, -0.0246],\n", | |
" [-0.0395, 0.0805, -0.0316, 0.0269, 0.0161, -0.0256],\n", | |
" [-0.0374, 0.0806, -0.0300, 0.0263, 0.0157, -0.0255],\n", | |
" [-0.0385, 0.0813, -0.0313, 0.0279, 0.0175, -0.0247],\n", | |
" [-0.0374, 0.0794, -0.0316, 0.0272, 0.0154, -0.0256],\n", | |
" [-0.0388, 0.0797, -0.0314, 0.0269, 0.0165, -0.0253],\n", | |
" [-0.0401, 0.0795, -0.0311, 0.0280, 0.0162, -0.0247],\n", | |
" [-0.0391, 0.0808, -0.0308, 0.0271, 0.0166, -0.0259],\n", | |
" [-0.0399, 0.0808, -0.0315, 0.0270, 0.0164, -0.0260],\n", | |
" [-0.0390, 0.0803, -0.0306, 0.0267, 0.0164, -0.0247],\n", | |
" [-0.0394, 0.0803, -0.0299, 0.0286, 0.0168, -0.0236],\n", | |
" [-0.0387, 0.0800, -0.0304, 0.0267, 0.0163, -0.0255],\n", | |
" [-0.0409, 0.0800, -0.0297, 0.0295, 0.0179, -0.0246],\n", | |
" [-0.0381, 0.0802, -0.0332, 0.0277, 0.0154, -0.0244],\n", | |
" [-0.0377, 0.0800, -0.0298, 0.0282, 0.0167, -0.0252],\n", | |
" [-0.0396, 0.0812, -0.0302, 0.0295, 0.0158, -0.0243],\n", | |
" [-0.0394, 0.0799, -0.0318, 0.0285, 0.0161, -0.0245],\n", | |
" [-0.0390, 0.0808, -0.0313, 0.0279, 0.0153, -0.0254],\n", | |
" [-0.0387, 0.0797, -0.0304, 0.0258, 0.0163, -0.0245],\n", | |
" [-0.0393, 0.0803, -0.0324, 0.0285, 0.0175, -0.0229],\n", | |
" [-0.0387, 0.0778, -0.0314, 0.0280, 0.0158, -0.0258],\n", | |
" [-0.0375, 0.0809, -0.0312, 0.0277, 0.0151, -0.0259],\n", | |
" [-0.0395, 0.0806, -0.0313, 0.0276, 0.0151, -0.0245],\n", | |
" [-0.0387, 0.0797, -0.0311, 0.0275, 0.0142, -0.0256],\n", | |
" [-0.0394, 0.0804, -0.0304, 0.0289, 0.0174, -0.0231],\n", | |
" [-0.0379, 0.0788, -0.0325, 0.0275, 0.0152, -0.0239],\n", | |
" [-0.0387, 0.0802, -0.0319, 0.0261, 0.0165, -0.0248],\n", | |
" [-0.0401, 0.0804, -0.0319, 0.0292, 0.0139, -0.0255],\n", | |
" [-0.0386, 0.0815, -0.0327, 0.0297, 0.0152, -0.0263],\n", | |
" [-0.0402, 0.0799, -0.0302, 0.0263, 0.0168, -0.0246],\n", | |
" [-0.0391, 0.0801, -0.0312, 0.0280, 0.0169, -0.0242],\n", | |
" [-0.0393, 0.0799, -0.0310, 0.0270, 0.0163, -0.0258],\n", | |
" [-0.0383, 0.0800, -0.0317, 0.0283, 0.0172, -0.0236],\n", | |
" [-0.0367, 0.0808, -0.0306, 0.0292, 0.0149, -0.0246],\n", | |
" [-0.0386, 0.0797, -0.0308, 0.0266, 0.0170, -0.0247],\n", | |
" [-0.0389, 0.0798, -0.0319, 0.0276, 0.0171, -0.0256],\n", | |
" [-0.0388, 0.0804, -0.0307, 0.0261, 0.0168, -0.0248]],\n", | |
" grad_fn=<AddmmBackward>),\n", | |
" 'loss': tensor(0.6964, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)}" | |
] | |
}, | |
"execution_count": 36, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model(**batch)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 37, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"loss = model(**batch)[\"loss\"]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 38, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(0.6964, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)" | |
] | |
}, | |
"execution_count": 38, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"loss" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 39, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"loss.backward()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Train" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 41, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"optimizer = optim.Adam(model.parameters(), lr=config.lr)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 42, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from allennlp.training.trainer import Trainer\n", | |
"\n", | |
"trainer = Trainer(\n", | |
" model=model,\n", | |
" optimizer=optimizer,\n", | |
" iterator=iterator,\n", | |
" train_dataset=train_ds,\n", | |
" cuda_device=0 if USE_GPU else -1,\n", | |
" num_epochs=config.epochs,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 43, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"01/29/2019 19:57:05 - INFO - allennlp.training.trainer - Beginning training.\n", | |
"01/29/2019 19:57:05 - INFO - allennlp.training.trainer - Epoch 0/1\n", | |
"01/29/2019 19:57:05 - INFO - allennlp.training.trainer - Peak CPU memory usage MB: 315.695104\n", | |
"01/29/2019 19:57:05 - INFO - allennlp.training.trainer - Training\n", | |
"loss: 0.6929 ||: 100%|██████████| 5/5 [00:08<00:00, 1.54s/it]\n", | |
"01/29/2019 19:57:14 - INFO - allennlp.training.trainer - Training | Validation\n", | |
"01/29/2019 19:57:14 - INFO - allennlp.training.trainer - loss | 0.693 | N/A\n", | |
"01/29/2019 19:57:14 - INFO - allennlp.training.trainer - cpu_memory_MB | 315.695 | N/A\n", | |
"01/29/2019 19:57:14 - INFO - allennlp.training.trainer - Epoch duration: 00:00:08\n", | |
"01/29/2019 19:57:14 - INFO - allennlp.training.trainer - Estimated training time remaining: 0:00:08\n", | |
"01/29/2019 19:57:14 - INFO - allennlp.training.trainer - Epoch 1/1\n", | |
"01/29/2019 19:57:14 - INFO - allennlp.training.trainer - Peak CPU memory usage MB: 734.257152\n", | |
"01/29/2019 19:57:14 - INFO - allennlp.training.trainer - Training\n", | |
"loss: 0.6820 ||: 100%|██████████| 5/5 [00:07<00:00, 1.76s/it]\n", | |
"01/29/2019 19:57:22 - INFO - allennlp.training.trainer - Training | Validation\n", | |
"01/29/2019 19:57:22 - INFO - allennlp.training.trainer - loss | 0.682 | N/A\n", | |
"01/29/2019 19:57:22 - INFO - allennlp.training.trainer - cpu_memory_MB | 734.257 | N/A\n", | |
"01/29/2019 19:57:22 - INFO - allennlp.training.trainer - Epoch duration: 00:00:08\n" | |
] | |
} | |
], | |
"source": [ | |
"metrics = trainer.train()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"celltoolbar": "Tags", | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.6" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi, what is the test_proced file here?