Skip to content

Instantly share code, notes, and snippets.

@keitakurita
Last active July 10, 2020 00:25
Show Gist options
  • Save keitakurita/0fac1cc175591971a18c456ce18c45a5 to your computer and use it in GitHub Desktop.
Save keitakurita/0fac1cc175591971a18c456ce18c45a5 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"from typing import *\n",
"import torch\n",
"import torch.optim as optim\n",
"import numpy as np\n",
"import pandas as pd\n",
"from functools import partial\n",
"from overrides import overrides\n",
"\n",
"from allennlp.data import Instance\n",
"from allennlp.data.token_indexers import TokenIndexer\n",
"from allennlp.data.tokenizers import Token\n",
"from allennlp.nn import util as nn_util"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class Config(dict):\n",
"    \"\"\"Dictionary of settings whose entries can also be read as attributes.\"\"\"\n",
"\n",
"    def __init__(self, **kwargs):\n",
"        super().__init__(**kwargs)\n",
"        # Mirror each entry onto the instance so `config.lr` equals `config[\"lr\"]`.\n",
"        for name in kwargs:\n",
"            setattr(self, name, kwargs[name])\n",
"\n",
"    def set(self, key, val):\n",
"        \"\"\"Store `val` under `key` as both a dict entry and an attribute.\"\"\"\n",
"        self[key] = val\n",
"        setattr(self, key, val)\n",
"\n",
"\n",
"config = Config(\n",
"    testing=True,  # when true, the dataset reader keeps only the first 1000 rows of each file\n",
"    seed=1,\n",
"    batch_size=64,\n",
"    lr=3e-4,\n",
"    epochs=2,\n",
"    hidden_sz=64,\n",
"    max_seq_len=100,  # necessary to limit memory usage\n",
"    max_vocab_size=100000,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from allennlp.common.checks import ConfigurationError"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"USE_GPU = torch.cuda.is_available()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"DATA_ROOT = Path(\"../data\") / \"jigsaw\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Set random seed manually to replicate results"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x1176dd710>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import random\n",
"\n",
"# Seed Python's and NumPy's generators as well as PyTorch's: library code outside torch\n",
"# (e.g. batch shuffling) may draw from them, and results are only repeatable if all are fixed.\n",
"random.seed(config.seed)\n",
"np.random.seed(config.seed)\n",
"torch.manual_seed(config.seed)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load Data"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from allennlp.data.vocabulary import Vocabulary\n",
"from allennlp.data.dataset_readers import DatasetReader"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Prepare dataset"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"label_cols = [\"toxic\", \"severe_toxic\", \"obscene\",\n",
" \"threat\", \"insult\", \"identity_hate\"]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from allennlp.data.fields import TextField, MetadataField, ArrayField\n",
"# Needed by the constructor's default indexer; without this import the fallback raises\n",
"# NameError unless a later cell happened to import the name first.\n",
"from allennlp.data.token_indexers import SingleIdTokenIndexer\n",
"\n",
"class JigsawDatasetReader(DatasetReader):\n",
"    \"\"\"Turn a Jigsaw toxic-comment CSV into AllenNLP instances.\n",
"\n",
"    Every row becomes one `Instance` with three fields: `tokens` (TextField),\n",
"    `id` (MetadataField) and `label` (ArrayField holding the `label_cols` values).\n",
"\n",
"    Args:\n",
"        tokenizer: callable mapping a raw comment string to a list of token strings.\n",
"        token_indexers: indexers attached to every TextField; when omitted a single\n",
"            `SingleIdTokenIndexer` is registered under the key \"tokens\".\n",
"        max_seq_len: maximum number of tokens kept per comment; `None` disables truncation.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, tokenizer: Callable[[str], List[str]]=lambda x: x.split(),\n",
"                 token_indexers: Dict[str, TokenIndexer] = None,\n",
"                 max_seq_len: Optional[int]=config.max_seq_len) -> None:\n",
"        super().__init__(lazy=False)\n",
"        self.tokenizer = tokenizer\n",
"        self.token_indexers = token_indexers or {\"tokens\": SingleIdTokenIndexer()}\n",
"        self.max_seq_len = max_seq_len\n",
"\n",
"    @overrides\n",
"    def text_to_instance(self, tokens: List[Token], id: str,\n",
"                         labels: np.ndarray) -> Instance:\n",
"        \"\"\"Bundle one comment's tokens, id and label vector into an `Instance`.\"\"\"\n",
"        fields = {\n",
"            \"tokens\": TextField(tokens, self.token_indexers),\n",
"            \"id\": MetadataField(id),\n",
"            \"label\": ArrayField(array=labels),\n",
"        }\n",
"        return Instance(fields)\n",
"\n",
"    @overrides\n",
"    def _read(self, file_path: str) -> Iterator[Instance]:\n",
"        \"\"\"Yield one instance per CSV row (only the first 1000 rows when `config.testing` is set).\"\"\"\n",
"        df = pd.read_csv(file_path)\n",
"        if config.testing:\n",
"            df = df.head(1000)\n",
"        # Extract the columns once; `iterrows()` would build a throwaway Series for every row.\n",
"        label_rows = df[label_cols].values\n",
"        for text, comment_id, labels in zip(df[\"comment_text\"], df[\"id\"], label_rows):\n",
"            words = self.tokenizer(text)\n",
"            if self.max_seq_len is not None:\n",
"                # Enforce the reader's own limit instead of relying on the tokenizer to truncate.\n",
"                words = words[:self.max_seq_len]\n",
"            yield self.text_to_instance([Token(w) for w in words], comment_id, labels)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Prepare token handlers"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will use the spaCy tokenizer here."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter\n",
"from allennlp.data.token_indexers import SingleIdTokenIndexer\n",
"\n",
"# the token indexer is responsible for mapping tokens to integers\n",
"token_indexer = SingleIdTokenIndexer()\n",
"\n",
"# Build the splitter once; creating it inside `tokenizer` repeated that setup for every comment.\n",
"_word_splitter = SpacyWordSplitter(language='en_core_web_sm', pos_tags=False)\n",
"\n",
"def tokenizer(x: str) -> List[str]:\n",
"    \"\"\"Split `x` into words and return the text of at most `config.max_seq_len` of them.\"\"\"\n",
"    return [w.text for w in _word_splitter.split_words(x)[:config.max_seq_len]]"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"reader = JigsawDatasetReader(\n",
" tokenizer=tokenizer,\n",
" token_indexers={\"tokens\": token_indexer}\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"267it [00:02, 94.93it/s]\n",
"251it [00:01, 172.26it/s]\n"
]
}
],
"source": [
"# NOTE(review): test_proced.csv is not created anywhere in this notebook; the reader pulls the\n",
"# label columns from every file it reads, so it presumably still carries them — confirm its origin.\n",
"train_ds = reader.read(DATA_ROOT / \"train.csv\")\n",
"test_ds = reader.read(DATA_ROOT / \"test_proced.csv\")\n",
"val_ds = None"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"267"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(train_ds)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<allennlp.data.instance.Instance at 0x1a2b034160>,\n",
" <allennlp.data.instance.Instance at 0x1a2b016208>,\n",
" <allennlp.data.instance.Instance at 0x1a2afec748>,\n",
" <allennlp.data.instance.Instance at 0x1a2af92828>,\n",
" <allennlp.data.instance.Instance at 0x1a2af8a4a8>,\n",
" <allennlp.data.instance.Instance at 0x1a2af7e630>,\n",
" <allennlp.data.instance.Instance at 0x1a2af79710>,\n",
" <allennlp.data.instance.Instance at 0x1a2af66550>,\n",
" <allennlp.data.instance.Instance at 0x10bd9a518>,\n",
" <allennlp.data.instance.Instance at 0x1a28d5def0>]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_ds[:10]"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'tokens': [Explanation,\n",
" Why,\n",
" the,\n",
" edits,\n",
" made,\n",
" under,\n",
" my,\n",
" username,\n",
" Hardcore,\n",
" Metallica,\n",
" Fan,\n",
" were,\n",
" reverted,\n",
" ?,\n",
" They,\n",
" were,\n",
" n't,\n",
" vandalisms,\n",
" ,,\n",
" just,\n",
" closure,\n",
" on,\n",
" some,\n",
" GAs,\n",
" after,\n",
" I,\n",
" voted,\n",
" at,\n",
" New,\n",
" York,\n",
" Dolls,\n",
" FAC,\n",
" .,\n",
" And,\n",
" please,\n",
" do,\n",
" n't,\n",
" remove,\n",
" the,\n",
" template,\n",
" from,\n",
" the,\n",
" talk,\n",
" page,\n",
" since,\n",
" I,\n",
" 'm,\n",
" retired,\n",
" now.89.205.38.27],\n",
" '_token_indexers': {'tokens': <allennlp.data.token_indexers.single_id_token_indexer.SingleIdTokenIndexer at 0x1a27b07400>},\n",
" '_indexed_tokens': None,\n",
" '_indexer_name_to_indexed_token': None}"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vars(train_ds[0].fields[\"tokens\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Prepare vocabulary"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"01/29/2019 19:57:00 - INFO - allennlp.data.vocabulary - Fitting token dictionary from dataset.\n",
"100%|██████████| 267/267 [00:00<00:00, 11635.95it/s]\n"
]
}
],
"source": [
"vocab = Vocabulary.from_instances(train_ds, max_vocab_size=config.max_vocab_size)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Prepare iterator"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The iterator is responsible for batching the data and preparing it for input into the model. We'll use the BucketIterator, which batches text sequences of similar lengths together."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"from allennlp.data.iterators import BucketIterator"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"iterator = BucketIterator(batch_size=config.batch_size, \n",
" sorting_keys=[(\"tokens\", \"num_tokens\")],\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We need to tell the iterator how to numericalize the text data. We do this by passing the vocabulary to the iterator. This step is easy to forget so be careful! "
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"iterator.index_with(vocab)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Read sample"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"batch = next(iter(iterator(train_ds)))"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'tokens': {'tokens': tensor([[ 131, 264, 21, ..., 0, 0, 0],\n",
" [ 74, 203, 24, ..., 0, 0, 0],\n",
" [ 5, 85, 26, ..., 0, 0, 0],\n",
" ...,\n",
" [ 5, 103, 1068, ..., 0, 0, 0],\n",
" [2972, 622, 1099, ..., 0, 0, 0],\n",
" [ 5, 3301, 8, ..., 0, 0, 0]])},\n",
" 'id': ['0029541a38c523a0',\n",
" '003d77a20601cec1',\n",
" '006fda507acd9769',\n",
" '00173958f46763a2',\n",
" '00a5394e626e72c6',\n",
" '005e2ae8f864f76c',\n",
" '0060c5c9030b2d14',\n",
" '0095756047a71716',\n",
" '0000997932d777bf',\n",
" '0061b075244dd234',\n",
" '002d6c9d9f85e81f',\n",
" '001735f961a23fc4',\n",
" '000113f07ec002fd',\n",
" '0029b87aa9c7dc4a',\n",
" '00070ef96486d6f9',\n",
" '007bc29766a43e3c',\n",
" '004f5608984d99f1',\n",
" '000b08c464718505',\n",
" '007bbfa4da2bc32d',\n",
" '00537730daf8c5f1',\n",
" '008f22e7b58e559b',\n",
" '009b3b15f1ada72f',\n",
" '004f981460421bdf',\n",
" '000f35deef84dc4a',\n",
" '002f0e29c60807b1',\n",
" '003dbd1b9b354c1f',\n",
" '004f6dbe69f3545d',\n",
" '00733f0a4a58cf42',\n",
" '0057b7710cb5ebb2',\n",
" '006f2c1459f3b6b1',\n",
" '00349c6325526c11',\n",
" '0030614cfd96d9d1',\n",
" '006120d209a4a46c',\n",
" '00744c2f77391702',\n",
" '007571394afafcb5',\n",
" '0005c987bdfc9d4b',\n",
" '0052a7e684beeb1a',\n",
" '0022cf8467ebc9fd',\n",
" '004de318396bbf8b',\n",
" '0053bab79133c0fc',\n",
" '001956c382006abd',\n",
" '0015f4aa35ebe9b5',\n",
" '000ffab30195c5e1',\n",
" '002a6beca33307b3',\n",
" '0028d62e8a5629aa',\n",
" '0063dd8f202a698a',\n",
" '008198c5a9d85a8e',\n",
" '007f1839ada915e6',\n",
" '0037e59caead9dab',\n",
" '00a20f187531df59',\n",
" '00585c1da10b448b',\n",
" '00961bcaadd6a278',\n",
" '0082b4b42b3f07a1',\n",
" '005b214511a69b4b',\n",
" '008e2acf5bcf4be4',\n",
" '002264ea4d5f2887',\n",
" '0013a8b1a5f26bcb',\n",
" '0074b307c2d9a100',\n",
" '008f93320e3661b8',\n",
" '008344a80c43b8c9',\n",
" '0038f191ffc93d75',\n",
" '0069e6d57a3beb51',\n",
" '0087c131ccffe160',\n",
" '00a330961879175c'],\n",
" 'label': tensor([[0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [1., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [1., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [1., 0., 1., 0., 1., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [1., 0., 1., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [1., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.]])}"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"batch"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 131, 264, 21, ..., 0, 0, 0],\n",
" [ 74, 203, 24, ..., 0, 0, 0],\n",
" [ 5, 85, 26, ..., 0, 0, 0],\n",
" ...,\n",
" [ 5, 103, 1068, ..., 0, 0, 0],\n",
" [2972, 622, 1099, ..., 0, 0, 0],\n",
" [ 5, 3301, 8, ..., 0, 0, 0]])"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"batch[\"tokens\"][\"tokens\"]"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([64, 84])"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"batch[\"tokens\"][\"tokens\"].shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Prepare Model"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"from allennlp.modules.seq2vec_encoders import Seq2VecEncoder, PytorchSeq2VecWrapper\n",
"from allennlp.nn.util import get_text_field_mask\n",
"from allennlp.models import Model\n",
"from allennlp.modules.text_field_embedders import TextFieldEmbedder\n",
"\n",
"class BaselineModel(Model):\n",
"    \"\"\"Embed the tokens, encode them into one vector, and project to one logit per label.\n",
"\n",
"    Args:\n",
"        word_embeddings: embedder applied to the `tokens` field.\n",
"        encoder: Seq2Vec encoder applied to the embedded tokens and their mask.\n",
"        out_sz: number of logits produced; defaults to `len(label_cols)`.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, word_embeddings: TextFieldEmbedder,\n",
"                 encoder: Seq2VecEncoder,\n",
"                 out_sz: int=len(label_cols)):\n",
"        # Uses the notebook-level `vocab`, so that cell must have run before a model is built.\n",
"        super().__init__(vocab)\n",
"        self.word_embeddings = word_embeddings\n",
"        self.encoder = encoder\n",
"        self.projection = nn.Linear(self.encoder.get_output_dim(), out_sz)\n",
"        self.loss = nn.BCEWithLogitsLoss()\n",
"\n",
"    def forward(self, tokens: Dict[str, torch.Tensor],\n",
"                id: Any, label: torch.Tensor = None) -> Dict[str, torch.Tensor]:\n",
"        \"\"\"Run the model on one batch.\n",
"\n",
"        Args:\n",
"            tokens: indexed `tokens` field of the batch.\n",
"            id: comment ids from the batch; not used in the computation.\n",
"            label: target tensor passed to `BCEWithLogitsLoss`; may be omitted when\n",
"                only predictions are wanted.\n",
"\n",
"        Returns:\n",
"            Dict with \"class_logits\" and, only when `label` is given, \"loss\".\n",
"        \"\"\"\n",
"        mask = get_text_field_mask(tokens)\n",
"        embeddings = self.word_embeddings(tokens)\n",
"        state = self.encoder(embeddings, mask)\n",
"        class_logits = self.projection(state)\n",
"\n",
"        output = {\"class_logits\": class_logits}\n",
"        # Without targets there is nothing to score, so the loss entry is left out.\n",
"        if label is not None:\n",
"            output[\"loss\"] = self.loss(class_logits, label)\n",
"\n",
"        return output"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Prepare embeddings"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"from allennlp.modules.token_embedders import Embedding\n",
"from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder\n",
"\n",
"# Size the embedding table from the vocabulary actually fitted on the training data instead of\n",
"# the `max_vocab_size` cap, which allocated 100,002 rows no matter how many tokens were seen.\n",
"token_embedding = Embedding(num_embeddings=vocab.get_vocab_size(\"tokens\"),\n",
"                            embedding_dim=300, padding_index=0)\n",
"# the embedder maps the input tokens to the appropriate embedding matrix\n",
"word_embeddings: TextFieldEmbedder = BasicTextFieldEmbedder({\"tokens\": token_embedding})"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper\n",
"encoder: Seq2VecEncoder = PytorchSeq2VecWrapper(nn.LSTM(word_embeddings.get_output_dim(),\n",
" config.hidden_sz, bidirectional=True, batch_first=True))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Notice how simple and modular the code for initializing the model is. All the complexity is delegated to each component."
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"model = BaselineModel(\n",
" word_embeddings, \n",
" encoder, \n",
")"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"# Move the model to the GPU when one is available; on CPU nothing needs to happen.\n",
"if USE_GPU:\n",
"    model.cuda()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Basic sanity checks"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"batch = nn_util.move_to_device(batch, 0 if USE_GPU else -1)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"tokens = batch[\"tokens\"]\n",
"# Take the label tensor itself; the original bound the whole batch dict to `labels`.\n",
"labels = batch[\"label\"]"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'tokens': tensor([[ 131, 264, 21, ..., 0, 0, 0],\n",
" [ 74, 203, 24, ..., 0, 0, 0],\n",
" [ 5, 85, 26, ..., 0, 0, 0],\n",
" ...,\n",
" [ 5, 103, 1068, ..., 0, 0, 0],\n",
" [2972, 622, 1099, ..., 0, 0, 0],\n",
" [ 5, 3301, 8, ..., 0, 0, 0]])}"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokens"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" ...,\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0]])"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mask = get_text_field_mask(tokens)\n",
"mask"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.0372, 0.0801, -0.0296, 0.0281, 0.0153, -0.0244],\n",
" [-0.0387, 0.0801, -0.0318, 0.0281, 0.0164, -0.0250],\n",
" [-0.0384, 0.0804, -0.0303, 0.0254, 0.0170, -0.0253],\n",
" [-0.0393, 0.0794, -0.0324, 0.0276, 0.0162, -0.0238],\n",
" [-0.0388, 0.0807, -0.0317, 0.0281, 0.0175, -0.0239],\n",
" [-0.0397, 0.0785, -0.0307, 0.0275, 0.0172, -0.0237],\n",
" [-0.0382, 0.0805, -0.0304, 0.0266, 0.0164, -0.0243],\n",
" [-0.0376, 0.0781, -0.0321, 0.0286, 0.0144, -0.0258],\n",
" [-0.0392, 0.0809, -0.0318, 0.0294, 0.0155, -0.0242],\n",
" [-0.0386, 0.0801, -0.0310, 0.0295, 0.0176, -0.0244],\n",
" [-0.0380, 0.0806, -0.0310, 0.0264, 0.0162, -0.0255],\n",
" [-0.0396, 0.0808, -0.0306, 0.0273, 0.0169, -0.0257],\n",
" [-0.0389, 0.0808, -0.0314, 0.0292, 0.0173, -0.0251],\n",
" [-0.0398, 0.0805, -0.0308, 0.0264, 0.0161, -0.0255],\n",
" [-0.0376, 0.0790, -0.0300, 0.0289, 0.0143, -0.0261],\n",
" [-0.0393, 0.0804, -0.0301, 0.0296, 0.0165, -0.0248],\n",
" [-0.0391, 0.0809, -0.0312, 0.0263, 0.0168, -0.0254],\n",
" [-0.0388, 0.0803, -0.0315, 0.0268, 0.0168, -0.0255],\n",
" [-0.0381, 0.0799, -0.0306, 0.0290, 0.0154, -0.0245],\n",
" [-0.0379, 0.0810, -0.0316, 0.0274, 0.0140, -0.0264],\n",
" [-0.0392, 0.0802, -0.0315, 0.0271, 0.0163, -0.0254],\n",
" [-0.0377, 0.0803, -0.0314, 0.0279, 0.0165, -0.0237],\n",
" [-0.0394, 0.0814, -0.0301, 0.0294, 0.0172, -0.0246],\n",
" [-0.0384, 0.0813, -0.0299, 0.0286, 0.0169, -0.0247],\n",
" [-0.0399, 0.0799, -0.0328, 0.0273, 0.0149, -0.0255],\n",
" [-0.0383, 0.0796, -0.0309, 0.0283, 0.0150, -0.0262],\n",
" [-0.0395, 0.0779, -0.0302, 0.0266, 0.0139, -0.0243],\n",
" [-0.0394, 0.0797, -0.0304, 0.0267, 0.0151, -0.0246],\n",
" [-0.0395, 0.0805, -0.0316, 0.0269, 0.0161, -0.0256],\n",
" [-0.0374, 0.0806, -0.0300, 0.0263, 0.0157, -0.0255],\n",
" [-0.0385, 0.0813, -0.0313, 0.0279, 0.0175, -0.0247],\n",
" [-0.0374, 0.0794, -0.0316, 0.0272, 0.0154, -0.0256],\n",
" [-0.0388, 0.0797, -0.0314, 0.0269, 0.0165, -0.0253],\n",
" [-0.0401, 0.0795, -0.0311, 0.0280, 0.0162, -0.0247],\n",
" [-0.0391, 0.0808, -0.0308, 0.0271, 0.0166, -0.0259],\n",
" [-0.0399, 0.0808, -0.0315, 0.0270, 0.0164, -0.0260],\n",
" [-0.0390, 0.0803, -0.0306, 0.0267, 0.0164, -0.0247],\n",
" [-0.0394, 0.0803, -0.0299, 0.0286, 0.0168, -0.0236],\n",
" [-0.0387, 0.0800, -0.0304, 0.0267, 0.0163, -0.0255],\n",
" [-0.0409, 0.0800, -0.0297, 0.0295, 0.0179, -0.0246],\n",
" [-0.0381, 0.0802, -0.0332, 0.0277, 0.0154, -0.0244],\n",
" [-0.0377, 0.0800, -0.0298, 0.0282, 0.0167, -0.0252],\n",
" [-0.0396, 0.0812, -0.0302, 0.0295, 0.0158, -0.0243],\n",
" [-0.0394, 0.0799, -0.0318, 0.0285, 0.0161, -0.0245],\n",
" [-0.0390, 0.0808, -0.0313, 0.0279, 0.0153, -0.0254],\n",
" [-0.0387, 0.0797, -0.0304, 0.0258, 0.0163, -0.0245],\n",
" [-0.0393, 0.0803, -0.0324, 0.0285, 0.0175, -0.0229],\n",
" [-0.0387, 0.0778, -0.0314, 0.0280, 0.0158, -0.0258],\n",
" [-0.0375, 0.0809, -0.0312, 0.0277, 0.0151, -0.0259],\n",
" [-0.0395, 0.0806, -0.0313, 0.0276, 0.0151, -0.0245],\n",
" [-0.0387, 0.0797, -0.0311, 0.0275, 0.0142, -0.0256],\n",
" [-0.0394, 0.0804, -0.0304, 0.0289, 0.0174, -0.0231],\n",
" [-0.0379, 0.0788, -0.0325, 0.0275, 0.0152, -0.0239],\n",
" [-0.0387, 0.0802, -0.0319, 0.0261, 0.0165, -0.0248],\n",
" [-0.0401, 0.0804, -0.0319, 0.0292, 0.0139, -0.0255],\n",
" [-0.0386, 0.0815, -0.0327, 0.0297, 0.0152, -0.0263],\n",
" [-0.0402, 0.0799, -0.0302, 0.0263, 0.0168, -0.0246],\n",
" [-0.0391, 0.0801, -0.0312, 0.0280, 0.0169, -0.0242],\n",
" [-0.0393, 0.0799, -0.0310, 0.0270, 0.0163, -0.0258],\n",
" [-0.0383, 0.0800, -0.0317, 0.0283, 0.0172, -0.0236],\n",
" [-0.0367, 0.0808, -0.0306, 0.0292, 0.0149, -0.0246],\n",
" [-0.0386, 0.0797, -0.0308, 0.0266, 0.0170, -0.0247],\n",
" [-0.0389, 0.0798, -0.0319, 0.0276, 0.0171, -0.0256],\n",
" [-0.0388, 0.0804, -0.0307, 0.0261, 0.0168, -0.0248]],\n",
" grad_fn=<AddmmBackward>)"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"embeddings = model.word_embeddings(tokens)\n",
"state = model.encoder(embeddings, mask)\n",
"class_logits = model.projection(state)\n",
"class_logits"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'class_logits': tensor([[-0.0372, 0.0801, -0.0296, 0.0281, 0.0153, -0.0244],\n",
" [-0.0387, 0.0801, -0.0318, 0.0281, 0.0164, -0.0250],\n",
" [-0.0384, 0.0804, -0.0303, 0.0254, 0.0170, -0.0253],\n",
" [-0.0393, 0.0794, -0.0324, 0.0276, 0.0162, -0.0238],\n",
" [-0.0388, 0.0807, -0.0317, 0.0281, 0.0175, -0.0239],\n",
" [-0.0397, 0.0785, -0.0307, 0.0275, 0.0172, -0.0237],\n",
" [-0.0382, 0.0805, -0.0304, 0.0266, 0.0164, -0.0243],\n",
" [-0.0376, 0.0781, -0.0321, 0.0286, 0.0144, -0.0258],\n",
" [-0.0392, 0.0809, -0.0318, 0.0294, 0.0155, -0.0242],\n",
" [-0.0386, 0.0801, -0.0310, 0.0295, 0.0176, -0.0244],\n",
" [-0.0380, 0.0806, -0.0310, 0.0264, 0.0162, -0.0255],\n",
" [-0.0396, 0.0808, -0.0306, 0.0273, 0.0169, -0.0257],\n",
" [-0.0389, 0.0808, -0.0314, 0.0292, 0.0173, -0.0251],\n",
" [-0.0398, 0.0805, -0.0308, 0.0264, 0.0161, -0.0255],\n",
" [-0.0376, 0.0790, -0.0300, 0.0289, 0.0143, -0.0261],\n",
" [-0.0393, 0.0804, -0.0301, 0.0296, 0.0165, -0.0248],\n",
" [-0.0391, 0.0809, -0.0312, 0.0263, 0.0168, -0.0254],\n",
" [-0.0388, 0.0803, -0.0315, 0.0268, 0.0168, -0.0255],\n",
" [-0.0381, 0.0799, -0.0306, 0.0290, 0.0154, -0.0245],\n",
" [-0.0379, 0.0810, -0.0316, 0.0274, 0.0140, -0.0264],\n",
" [-0.0392, 0.0802, -0.0315, 0.0271, 0.0163, -0.0254],\n",
" [-0.0377, 0.0803, -0.0314, 0.0279, 0.0165, -0.0237],\n",
" [-0.0394, 0.0814, -0.0301, 0.0294, 0.0172, -0.0246],\n",
" [-0.0384, 0.0813, -0.0299, 0.0286, 0.0169, -0.0247],\n",
" [-0.0399, 0.0799, -0.0328, 0.0273, 0.0149, -0.0255],\n",
" [-0.0383, 0.0796, -0.0309, 0.0283, 0.0150, -0.0262],\n",
" [-0.0395, 0.0779, -0.0302, 0.0266, 0.0139, -0.0243],\n",
" [-0.0394, 0.0797, -0.0304, 0.0267, 0.0151, -0.0246],\n",
" [-0.0395, 0.0805, -0.0316, 0.0269, 0.0161, -0.0256],\n",
" [-0.0374, 0.0806, -0.0300, 0.0263, 0.0157, -0.0255],\n",
" [-0.0385, 0.0813, -0.0313, 0.0279, 0.0175, -0.0247],\n",
" [-0.0374, 0.0794, -0.0316, 0.0272, 0.0154, -0.0256],\n",
" [-0.0388, 0.0797, -0.0314, 0.0269, 0.0165, -0.0253],\n",
" [-0.0401, 0.0795, -0.0311, 0.0280, 0.0162, -0.0247],\n",
" [-0.0391, 0.0808, -0.0308, 0.0271, 0.0166, -0.0259],\n",
" [-0.0399, 0.0808, -0.0315, 0.0270, 0.0164, -0.0260],\n",
" [-0.0390, 0.0803, -0.0306, 0.0267, 0.0164, -0.0247],\n",
" [-0.0394, 0.0803, -0.0299, 0.0286, 0.0168, -0.0236],\n",
" [-0.0387, 0.0800, -0.0304, 0.0267, 0.0163, -0.0255],\n",
" [-0.0409, 0.0800, -0.0297, 0.0295, 0.0179, -0.0246],\n",
" [-0.0381, 0.0802, -0.0332, 0.0277, 0.0154, -0.0244],\n",
" [-0.0377, 0.0800, -0.0298, 0.0282, 0.0167, -0.0252],\n",
" [-0.0396, 0.0812, -0.0302, 0.0295, 0.0158, -0.0243],\n",
" [-0.0394, 0.0799, -0.0318, 0.0285, 0.0161, -0.0245],\n",
" [-0.0390, 0.0808, -0.0313, 0.0279, 0.0153, -0.0254],\n",
" [-0.0387, 0.0797, -0.0304, 0.0258, 0.0163, -0.0245],\n",
" [-0.0393, 0.0803, -0.0324, 0.0285, 0.0175, -0.0229],\n",
" [-0.0387, 0.0778, -0.0314, 0.0280, 0.0158, -0.0258],\n",
" [-0.0375, 0.0809, -0.0312, 0.0277, 0.0151, -0.0259],\n",
" [-0.0395, 0.0806, -0.0313, 0.0276, 0.0151, -0.0245],\n",
" [-0.0387, 0.0797, -0.0311, 0.0275, 0.0142, -0.0256],\n",
" [-0.0394, 0.0804, -0.0304, 0.0289, 0.0174, -0.0231],\n",
" [-0.0379, 0.0788, -0.0325, 0.0275, 0.0152, -0.0239],\n",
" [-0.0387, 0.0802, -0.0319, 0.0261, 0.0165, -0.0248],\n",
" [-0.0401, 0.0804, -0.0319, 0.0292, 0.0139, -0.0255],\n",
" [-0.0386, 0.0815, -0.0327, 0.0297, 0.0152, -0.0263],\n",
" [-0.0402, 0.0799, -0.0302, 0.0263, 0.0168, -0.0246],\n",
" [-0.0391, 0.0801, -0.0312, 0.0280, 0.0169, -0.0242],\n",
" [-0.0393, 0.0799, -0.0310, 0.0270, 0.0163, -0.0258],\n",
" [-0.0383, 0.0800, -0.0317, 0.0283, 0.0172, -0.0236],\n",
" [-0.0367, 0.0808, -0.0306, 0.0292, 0.0149, -0.0246],\n",
" [-0.0386, 0.0797, -0.0308, 0.0266, 0.0170, -0.0247],\n",
" [-0.0389, 0.0798, -0.0319, 0.0276, 0.0171, -0.0256],\n",
" [-0.0388, 0.0804, -0.0307, 0.0261, 0.0168, -0.0248]],\n",
" grad_fn=<AddmmBackward>),\n",
" 'loss': tensor(0.6964, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)}"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model(**batch)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"loss = model(**batch)[\"loss\"]"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.6964, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loss"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"loss.backward()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Train"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"optimizer = optim.Adam(model.parameters(), lr=config.lr)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"from allennlp.training.trainer import Trainer\n",
"\n",
"trainer = Trainer(\n",
" model=model,\n",
" optimizer=optimizer,\n",
" iterator=iterator,\n",
" train_dataset=train_ds,\n",
" cuda_device=0 if USE_GPU else -1,\n",
" num_epochs=config.epochs,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"01/29/2019 19:57:05 - INFO - allennlp.training.trainer - Beginning training.\n",
"01/29/2019 19:57:05 - INFO - allennlp.training.trainer - Epoch 0/1\n",
"01/29/2019 19:57:05 - INFO - allennlp.training.trainer - Peak CPU memory usage MB: 315.695104\n",
"01/29/2019 19:57:05 - INFO - allennlp.training.trainer - Training\n",
"loss: 0.6929 ||: 100%|██████████| 5/5 [00:08<00:00, 1.54s/it]\n",
"01/29/2019 19:57:14 - INFO - allennlp.training.trainer - Training | Validation\n",
"01/29/2019 19:57:14 - INFO - allennlp.training.trainer - loss | 0.693 | N/A\n",
"01/29/2019 19:57:14 - INFO - allennlp.training.trainer - cpu_memory_MB | 315.695 | N/A\n",
"01/29/2019 19:57:14 - INFO - allennlp.training.trainer - Epoch duration: 00:00:08\n",
"01/29/2019 19:57:14 - INFO - allennlp.training.trainer - Estimated training time remaining: 0:00:08\n",
"01/29/2019 19:57:14 - INFO - allennlp.training.trainer - Epoch 1/1\n",
"01/29/2019 19:57:14 - INFO - allennlp.training.trainer - Peak CPU memory usage MB: 734.257152\n",
"01/29/2019 19:57:14 - INFO - allennlp.training.trainer - Training\n",
"loss: 0.6820 ||: 100%|██████████| 5/5 [00:07<00:00, 1.76s/it]\n",
"01/29/2019 19:57:22 - INFO - allennlp.training.trainer - Training | Validation\n",
"01/29/2019 19:57:22 - INFO - allennlp.training.trainer - loss | 0.682 | N/A\n",
"01/29/2019 19:57:22 - INFO - allennlp.training.trainer - cpu_memory_MB | 734.257 | N/A\n",
"01/29/2019 19:57:22 - INFO - allennlp.training.trainer - Epoch duration: 00:00:08\n"
]
}
],
"source": [
"metrics = trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"celltoolbar": "Tags",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@vkurpad
Copy link

vkurpad commented Jun 12, 2019

Hi, what is the test_proced file here?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment