@morganmcg1
Created August 22, 2020 00:04
Karpathy's minGPT in Fastai
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# A Quick fastai Version of Andrej Karpathy's minGPT Play Char Demo\n",
"- You can find the Play Char demo in the minGPT repo here: https://github.com/karpathy/minGPT\n",
"- Goal: Generate Shakespeare\n",
"\n",
"This notebook is partially based on the fastai Transformers tutorial: http://docs.fast.ai/tutorial.transformers\n",
"\n",
"**Note**:\n",
"- The minGPT repo must be downloaded into the same folder as this notebook"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from fastai.text.all import *\n",
"from minGPT.mingpt.model import GPT, GPTConfig, GPT1Config"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data\n",
"\n",
"You can download the raw text file at: https://github.com/karpathy/char-rnn/blob/master/data/tinyshakespeare/input.txt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'First Citizen:\\nBefore we proceed any further, hear me speak.\\n\\nAll:\\nSpeak, speak.\\n\\nFirst Citizen:\\nYou'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"with open('input.txt', 'r') as f: raw_text = f.read()\n",
"raw_text[:100]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data Loaders"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class CharTransform(Transform):\n",
"    def __init__(self, data, block_size):\n",
"        # Sort the characters so the char <-> index mapping is the same on every run\n",
"        chars = sorted(set(data))\n",
"        data_size, vocab_size = len(data), len(chars)\n",
"        print('data has %d characters, %d unique.' % (data_size, vocab_size))\n",
"\n",
"        self.stoi = { ch:i for i,ch in enumerate(chars) }\n",
"        self.itos = { i:ch for i,ch in enumerate(chars) }\n",
"        self.block_size = block_size\n",
"        self.vocab_size = vocab_size\n",
"        self.data = data\n",
"        self.n_sequences = math.ceil(len(self.data) / (self.block_size + 1))\n",
"\n",
"    def encodes(self, o):\n",
"        # The index `o` is ignored: each call samples a random chunk of block_size + 1 characters\n",
"        i = np.random.randint(0, len(self.data) - (self.block_size + 1))\n",
"        chunk = self.data[i:i+self.block_size+1]\n",
"        dix = [self.stoi[s] for s in chunk]\n",
"        return torch.tensor(dix)\n",
"\n",
"    def decodes(self, o):\n",
"        t = ''.join([self.itos[s.item()] for s in o])\n",
"        return TitledStr(t)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Note**:\n",
"- `block_size` in Karpathy's code is equivalent to the sequence length (`seq_len`) in fastai\n",
"- We do not specify a validation set because Karpathy's notebook does not use one, so we set `split_idx=0` in `TfmdLists`"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"data has 1115394 characters, 65 unique.\n"
]
}
],
"source": [
"sl = 128\n",
"block_size = sl\n",
"n_samples = math.ceil(len(raw_text) / (block_size + 1))\n",
"\n",
"tls = TfmdLists(list(range(n_samples)), tfms=[CharTransform(raw_text, block_size)], split_idx=0, dl_type=LMDataLoader)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We specify `dl_type=LMDataLoader` so that, when we convert this `TfmdLists` to `DataLoaders`, fastai uses an `LMDataLoader` (we have a language modeling problem) instead of the usual `TfmdDL`."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"e them?\n",
"\n",
"LADY GREY:\n",
"What you command, that rests in me to do.\n",
"\n",
"KING EDWARD IV:\n",
"But you will take exceptions to my boon.\n",
"\n",
"LADY GRE\n"
]
}
],
"source": [
"show_at(tls.train, 0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The fastai library expects the data to be assembled in a `DataLoaders` object (something that has a training and validation dataloader). We can get one by using the `dataloaders` method. We just have to specify a batch size and a sequence length. "
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"bs = 256\n",
"dls = tls.dataloaders(bs=bs, seq_len=sl)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"#o = dls.one_batch(); len(o), o[0].size(), o[1].size(), o"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text</th>\n",
" <th>text_</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>que time would lie unswept,\\nAnd mountainous error be too highly heapt\\nFor truth to o'er-peer. Rather than fool it so,\\nLet the hi</td>\n",
" <td>ue time would lie unswept,\\nAnd mountainous error be too highly heapt\\nFor truth to o'er-peer. Rather than fool it so,\\nLet the hig</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>an action and capacity,\\nOf notill stand there,\\nRemembering how I love thy company.\\n\\nROMEO:\\nAnd I'll still stay, to have thee sti</td>\n",
" <td>n action and capacity,\\nOf notill stand there,\\nRemembering how I love thy company.\\n\\nROMEO:\\nAnd I'll still stay, to have thee stil</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"dls.show_batch(max_n=2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Callback to Grab First Output"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we need to implement the `after_pred` event and replace `self.learn.pred` (which contains the predictions that will be passed to the loss function) with just its first element. This is because Karpathy's model also calculates the loss in the forward pass and returns `(logits, loss)`. We only need the logits.\n",
"\n",
"In callbacks, there is a shortcut that lets you access any of the underlying `Learner` attributes, so we can write `self.pred[0]` instead of `self.learn.pred[0]`. That shortcut only works for read access, not write, so we have to write `self.learn.pred` on the left side of the assignment (otherwise we would set a `pred` attribute on the `Callback`)."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"class DropOutput(Callback):\n",
" def after_pred(self): self.learn.pred = self.pred[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## GPT Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"I had to use 6 layers instead of 8 because my 2080 GPU has 13GB of RAM."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"mconf = GPTConfig(dls.char_transform.vocab_size, sl, n_layer=6, n_head=8, n_embd=512)\n",
"model = GPT(mconf)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Learner"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), opt_func=partial(Adam, sqr_mom=0.95, wd=0.1), \n",
" cbs=[DropOutput]) #.to_fp16()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/morgan/anaconda3/envs/fastai2_me/lib/python3.7/site-packages/fastprogress/fastprogress.py:74: UserWarning: Your generator is empty.\n",
" warn(\"Your generator is empty.\")\n"
]
},
{
"data": {
"text/plain": [
"SuggestedLRs(lr_min=0.0019054606556892395, lr_steep=2.511886486900039e-05)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>3.338925</td>\n",
" <td>None</td>\n",
" <td>00:18</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>3.062790</td>\n",
" <td>None</td>\n",
" <td>00:18</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>2.832788</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>2.687037</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>2.595154</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>2.532446</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>2.486231</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>2.451136</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>2.417365</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>2.380837</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>2.334122</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>2.279138</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>2.206434</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>2.130157</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>2.056702</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15</td>\n",
" <td>1.991350</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>16</td>\n",
" <td>1.929794</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>17</td>\n",
" <td>1.873611</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>18</td>\n",
" <td>1.829401</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>19</td>\n",
" <td>1.782693</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>20</td>\n",
" <td>1.759206</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>21</td>\n",
" <td>1.719195</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>22</td>\n",
" <td>1.684518</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>23</td>\n",
" <td>1.652220</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>24</td>\n",
" <td>1.630679</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>25</td>\n",
" <td>1.604731</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>26</td>\n",
" <td>1.581347</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>27</td>\n",
" <td>1.560486</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>28</td>\n",
" <td>1.543253</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>29</td>\n",
" <td>1.525474</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>30</td>\n",
" <td>1.508890</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>31</td>\n",
" <td>1.493199</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>32</td>\n",
" <td>1.478913</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>33</td>\n",
" <td>1.466304</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>34</td>\n",
" <td>1.452765</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>35</td>\n",
" <td>1.441874</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>36</td>\n",
" <td>1.430259</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>37</td>\n",
" <td>1.420458</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>38</td>\n",
" <td>1.408505</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>39</td>\n",
" <td>1.398455</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>40</td>\n",
" <td>1.389998</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>41</td>\n",
" <td>1.382348</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>42</td>\n",
" <td>1.371678</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>43</td>\n",
" <td>1.360793</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>44</td>\n",
" <td>1.351078</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>45</td>\n",
" <td>1.342533</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>46</td>\n",
" <td>1.337409</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>47</td>\n",
" <td>1.329052</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>48</td>\n",
" <td>1.319077</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>49</td>\n",
" <td>1.310749</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>50</td>\n",
" <td>1.302483</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>51</td>\n",
" <td>1.294455</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>52</td>\n",
" <td>1.287185</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>53</td>\n",
" <td>1.277416</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>54</td>\n",
" <td>1.269452</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>55</td>\n",
" <td>1.259615</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>56</td>\n",
" <td>1.252440</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>57</td>\n",
" <td>1.243651</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>58</td>\n",
" <td>1.234919</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>59</td>\n",
" <td>1.227910</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>60</td>\n",
" <td>1.217211</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>61</td>\n",
" <td>1.209070</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>62</td>\n",
" <td>1.199655</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>63</td>\n",
" <td>1.191217</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>64</td>\n",
" <td>1.185001</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>65</td>\n",
" <td>1.175376</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>66</td>\n",
" <td>1.167852</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>67</td>\n",
" <td>1.159078</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>68</td>\n",
" <td>1.151620</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>69</td>\n",
" <td>1.142666</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>70</td>\n",
" <td>1.133764</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>71</td>\n",
" <td>1.127385</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>72</td>\n",
" <td>1.119594</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>73</td>\n",
" <td>1.112152</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>74</td>\n",
" <td>1.104161</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>75</td>\n",
" <td>1.097133</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>76</td>\n",
" <td>1.089942</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>77</td>\n",
" <td>1.083349</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>78</td>\n",
" <td>1.074918</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>79</td>\n",
" <td>1.068063</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>80</td>\n",
" <td>1.061971</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>81</td>\n",
" <td>1.056730</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>82</td>\n",
" <td>1.051151</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>83</td>\n",
" <td>1.044623</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>84</td>\n",
" <td>1.039983</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>85</td>\n",
" <td>1.036456</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>86</td>\n",
" <td>1.030712</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>87</td>\n",
" <td>1.023932</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>88</td>\n",
" <td>1.019215</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>89</td>\n",
" <td>1.013263</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>90</td>\n",
" <td>1.009832</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>91</td>\n",
" <td>1.005648</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>92</td>\n",
" <td>1.004165</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>93</td>\n",
" <td>1.001175</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>94</td>\n",
" <td>0.997694</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>95</td>\n",
" <td>0.995172</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>96</td>\n",
" <td>0.990947</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>97</td>\n",
" <td>0.989438</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>98</td>\n",
" <td>0.986149</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" <tr>\n",
" <td>99</td>\n",
" <td>0.982865</td>\n",
" <td>None</td>\n",
" <td>00:19</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/morgan/anaconda3/envs/fastai2_me/lib/python3.7/site-packages/fastprogress/fastprogress.py:74: UserWarning: Your generator is empty.\n",
" warn(\"Your generator is empty.\")\n"
]
}
],
"source": [
"learn.fit_one_cycle(100, 6e-4, div_final=10) \n",
"\n",
"# div_final=10 will ensure we finish at the same lr as Karpathy"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.recorder.plot_loss()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note there is no validation loss or validation perplexity as we didn't specify a validation set (as per Karpathy's notebook)."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"O God, O God! to our bloody supply the horsemen are stay and forth. But wherefore stand you to the crown?\n",
"\n",
"DUKE VINCENTIO:\n",
"These are the people's voices,\n",
"To make the varlet concerning me to town.\n",
"I was no brother out the man that wilt thou have made forth thyself.\n",
"\n",
"GLOUCESTER:\n",
"This is a man, and thy supposed king,\n",
"Like men cracking in the world,\n",
"Which were instruction mortal tiden let us hence;\n",
"And, for the dangerous seats of the earth\n",
"Would be the sheep that was thus fined\n",
"An an unusual good deeds, that does not.\n",
"\n",
"CORIOLANUS:\n",
"Why that is this?\n",
"\n",
"FRANCISCA:\n",
"What was? much of you, if I show much, we have said\n",
"Shall be your first with your shame, take you into\n",
"the court: back up my packet again of his silk,\n",
"And that you might repossess him well\n",
"Then when he did show more wantons,\n",
"Than the proudest hollow can find it off.\n",
"\n",
"MARCIUS:\n",
"Though they shall feel, they change put us on you.\n",
"\n",
"CLARENCE:\n",
"Or the duke's death, the love been done:\n",
"I am so broad to live in thee and thy looks.\n",
"\n",
"RICHARD:\n",
"I will do well, I see thee here.\n",
"\n",
"GRUMIO:\n",
"Ay, sir, the man take note upon your gates: arry this good\n",
"lady, by you.\n",
"\n",
"VINCENTIO:\n",
"Thou art perfect these are at the last I see thee\n",
"In such a disguised, though they be not taught but they are not; but they\n",
"are full of vanity. They say, the duke was too much,\n",
"To make a schoolar, a silken posterity,\n",
"That it may call you now?\n",
"\n",
"BIANCA:\n",
"What, my gracious lady?\n",
"\n",
"PETRUCHIO:\n",
"Why, then, 'tis no less than this, but a sharp-pointed match.\n",
"What is the bloody which this story rich,\n",
"To make thee stranger about the stone,\n",
"And both thy speech. What is it thou,\n",
"To be the man to be thus apprehended;\n",
"And so, I trust I. What if it be so,\n",
"I will content thee with thy wisdoms: better for thy death,\n",
"By this time I had rather had been so\n",
"fidiused for a maid: there is no less less extend.\n",
"\n",
"KING RICHARD II:\n",
"Is it even that so strict me?\n",
"\n",
"GLOUCESTER:\n",
"And teach you, my lord, is dead?\n",
"\n",
"GLOUCESTER:\n",
"And why the king's daughter is too slander'd?\n",
"\n",
"LADY GREY:\n",
"To tell you plain, I pray you,\n"
]
}
],
"source": [
"from minGPT.mingpt.utils import sample\n",
"\n",
"context = \"O God, O God!\"\n",
"x = torch.tensor([dls.char_transform.stoi[s] for s in context], dtype=torch.long)[None,...].to(dls.device)\n",
"y = sample(model, x, 2000, temperature=0.9, sample=True, top_k=5)[0]\n",
"completion = ''.join([dls.char_transform.itos[int(i)] for i in y])\n",
"print(completion)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
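The read-versus-write behaviour described in the "Callback to Grab First Output" section of the notebook above can be reproduced in plain Python. The sketch below is only an illustration: `ToyLearner`, `ToyCallback`, `ToyDropOutput` and `BrokenDropOutput` are hypothetical stand-ins, not fastai classes, and fastai's real delegation mechanism differs in detail. It assumes only that attribute reads fall back to the learner through `__getattr__`, which Python calls on failed lookups but never on assignment.

```python
class ToyLearner:
    """Hypothetical stand-in for a Learner; it only stores a prediction."""
    def __init__(self, pred):
        self.pred = pred


class ToyCallback:
    """Hypothetical stand-in for a callback that forwards reads to its learner."""
    def __init__(self, learn):
        self.learn = learn

    def __getattr__(self, name):
        # Python calls this only when normal lookup fails, so it covers
        # attribute reads. Assignments never pass through here.
        return getattr(self.learn, name)


class ToyDropOutput(ToyCallback):
    def after_pred(self):
        # Read through the shortcut, write explicitly on the learner.
        self.learn.pred = self.pred[0]


class BrokenDropOutput(ToyCallback):
    def after_pred(self):
        # The assignment creates a new `pred` attribute on the callback
        # and leaves the learner's prediction untouched.
        self.pred = self.pred[0]


good_learner = ToyLearner(("logits", "loss"))
ToyDropOutput(good_learner).after_pred()

bad_learner = ToyLearner(("logits", "loss"))
bad_callback = BrokenDropOutput(bad_learner)
bad_callback.after_pred()

print(good_learner.pred)  # the learner now holds only the first element
print(bad_learner.pred)   # the learner still holds the full tuple
```

With the explicit `self.learn.pred` on the left side, the loss function receives only the logits. The broken variant silently shadows the value on the callback, which is why the notebook spells out `self.learn`.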
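The `div_final=10` comment in the training cell above can be checked with a little arithmetic. The following is a minimal sketch of a one-cycle schedule, not fastai's own code: `cosine_anneal` and `one_cycle_lr` are hypothetical helpers, and the defaults (`div=25`, `div_final=1e5`, `pct_start=0.25`) are what I recall `fit_one_cycle` using, so treat them as assumptions. It also assumes minGPT's cosine decay bottoms out at 10% of the base learning rate, which is what the notebook comment implies.

```python
import math


def cosine_anneal(start, end, pct):
    """Cosine interpolation from `start` (pct=0) to `end` (pct=1)."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)


def one_cycle_lr(pct, lr_max, div=25.0, div_final=1e5, pct_start=0.25):
    """Learning rate at training progress `pct` in [0, 1] (assumed defaults)."""
    if pct < pct_start:
        # Warm-up phase: lr_max / div up to lr_max
        return cosine_anneal(lr_max / div, lr_max, pct / pct_start)
    # Annealing phase: lr_max down to lr_max / div_final
    return cosine_anneal(lr_max, lr_max / div_final,
                         (pct - pct_start) / (1 - pct_start))


lr_max = 6e-4
start_lr = one_cycle_lr(0.0, lr_max, div_final=10)
peak_lr = one_cycle_lr(0.25, lr_max, div_final=10)
final_lr = one_cycle_lr(1.0, lr_max, div_final=10)

print(start_lr, peak_lr, final_lr)
```

Under these assumptions the schedule ends at `lr_max / div_final`, one tenth of the peak rate of `6e-4`. With the default `div_final` it would instead decay to a value several orders of magnitude smaller.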