@puzzler10
Created August 26, 2017 03:06
Nearest Neighbour classifier
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Nearest Neighbour Classification\n",
"The nearest neighbour classifier is a very simple algorithm for image classification. While it is not used much in practice, it is simple to implement and helps build a deeper understanding of the problems in image classification. \n",
"\n",
"Given an image, the nearest neighbour classifier searches its stored images for the closest match and predicts that match's label. We \"train\" the classifier simply by giving it a large collection of labelled images to search through. The more images we give it, the closer the match it can find, but the longer each search takes.\n",
"\n",
"How do you compare images? An image is basically a large grid of pixels. Colours can be made up from different combinations of the three primary colours of light, so each pixel has a red, a green and a blue component. A 32x32 image is therefore represented as three 32x32 matrices, one each for red, green and blue. \n",
"\n",
"With this representation it is easy to compare images. For a pair of images, we can take the pixel-by-pixel differences between them in the red, green and blue layers and then sum the absolute values of those differences. This is called the $L^1$ (or Manhattan) distance, $d_1(I_1, I_2) = \\sum_p \\left| I_1^p - I_2^p \\right|$, where the sum runs over every pixel value $p$ in all three layers, and it is probably the simplest way to compare two images. \n",
"\n",
"The nearest neighbour classifier is not a particularly efficient or useful classifier. It does better than chance, but nowhere near as well as stronger classifiers such as convolutional neural networks. "
]
},
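{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the $L^1$ distance in NumPy, applied to two made-up 2x2 RGB images rather than CIFAR-10 data. The function name `l1_distance` is our own choice. The arrays are cast to a signed integer type before subtracting, because pixel values such as CIFAR-10's are stored as unsigned 8-bit integers, which cannot represent negative differences.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def l1_distance(img_a, img_b):\n",
"    # Cast to a signed type so that differences can be negative\n",
"    a = img_a.astype(np.int32)\n",
"    b = img_b.astype(np.int32)\n",
"    # Sum of absolute differences over every pixel and every colour channel\n",
"    return np.sum(np.abs(a - b))\n",
"\n",
"# Two tiny 2x2 images with 3 colour channels each (values are made up)\n",
"img_a = np.array([[[10, 20, 30], [40, 50, 60]],\n",
"                  [[70, 80, 90], [100, 110, 120]]], dtype=np.uint8)\n",
"img_b = img_a.copy()\n",
"img_b[0, 0] = [0, 0, 0]  # change a single pixel\n",
"\n",
"print(l1_distance(img_a, img_b))  # 10 + 20 + 30 = 60\n",
"```\n",
"\n",
"Only the changed pixel contributes, so the distance is the sum of its three channel differences."
]
},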
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup\n",
"We are following along with the [Stanford tutorial](http://cs231n.github.io/classification/). We will use a dataset called [CIFAR-10](http://www.cs.toronto.edu/~kriz/cifar.html); it comprises 60,000 labelled tiny images (32x32 pixels) across ten classes, split into 50,000 training images (five batches of 10,000) and 10,000 test images (one batch). Each class has 5,000 training images available. \n",
"\n",
"The first thing to do is download the Python version of the [dataset](http://www.cs.toronto.edu/~kriz/cifar.html) and extract it to a folder. CIFAR-10 comes in 'pickled' form, meaning that each batch is a Python object serialised to a byte stream with the `pickle` module. This is not a convenient format for us, so we have to convert the byte stream back into Python objects (unpickling) before we can use it. \n",
"\n",
"First let's import the libraries we'll want:"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import pickle\n",
"import os\n",
"import sys\n",
"import glob\n",
"import numpy as np\n",
"\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The data comes in a number of batches: one for the test set and five for the training set. Each batch file contains a dictionary with the following:\n",
"\n",
"* data: a 10000x3072 numpy array of uint8s, where each row of the array stores a 32x32 colour image. The first 1024 entries contain the red channel values, the next 1024 green, and the final 1024 blue.\n",
"* labels: a list of 10000 numbers in the range 0-9. The number at index i indicates the label of the ith image in the array data.\n",
"\n",
"We will create a helper function for unpickling our batches, which will return a dictionary for each batch. Dictionaries aren't that helpful for us to work with, so we'll pull the data and labels out into numpy arrays instead. \n",
"\n",
"We'll create three functions to unpack the data: \n",
"\n",
"* `unpickle`: a function for unpickling each batch \n",
"* `load_CIFAR_batch`: a function for extracting the data and labels from each batch\n",
"* `load_CIFAR10`: a function for loading the data and labels of the entire dataset (separating train and test batches)\n",
"\n",
"Each row of the returned data will contain one image. This format will be hard for us to visualise, so we'll reshape each row into a 32x32x3 array (height, width, colour channel) that plotting functions can display. "
]
},
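{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see why the reshape works, here is a small sketch on a fake batch of two rows built by hand (not real CIFAR-10 data): every pixel gets red 200, green 100 and blue 50, laid out in the same channel-by-channel order as a CIFAR-10 row. The reshape and transpose are the same ones used in `load_CIFAR10` below.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Fake rows in CIFAR-10 order: 1024 red values, then 1024 green, then 1024 blue\n",
"red = np.full(1024, 200, dtype=np.uint8)\n",
"green = np.full(1024, 100, dtype=np.uint8)\n",
"blue = np.full(1024, 50, dtype=np.uint8)\n",
"row = np.concatenate([red, green, blue])   # shape (3072,)\n",
"batch = np.stack([row, row])               # shape (2, 3072): two images\n",
"\n",
"# Split each row into 3 channels of 32x32, then move the channel axis last\n",
"images = batch.reshape((batch.shape[0], 3, 32, 32)).transpose(0, 2, 3, 1)\n",
"print(images.shape)              # (2, 32, 32, 3)\n",
"print(images[0, 0, 0].tolist())  # [200, 100, 50]: red, green, blue of the top-left pixel\n",
"```\n",
"\n",
"Without the transpose, the last axis would index columns rather than colour channels, and `imshow` would not receive an RGB image."
]
},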
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def unpickle(file):\n",
"    \"\"\"\n",
"    Unpickles a file stream\n",
"    \n",
"    :param file: file path\n",
"    :return: dict with keys {batch_label, labels, data, filenames}\n",
"    \"\"\"\n",
"    # encoding='latin1' lets Python 3 read the CIFAR-10 batches, which were pickled with Python 2\n",
"    with open(file, 'rb') as fo:\n",
"        batch = pickle.load(fo, encoding='latin1')\n",
"    return batch"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def load_CIFAR_batch(file):\n",
"    \"\"\"\n",
"    Load data and labels from a batch file\n",
"    \n",
"    :param file: file path\n",
"    :return: \n",
"        X: uint8 ndarray of shape (number of images in batch, 3x32x32), one row per image\n",
"        Y: ndarray of shape (number of images in batch, ) holding the labels\n",
"    \"\"\"\n",
"    file_dict = unpickle(file)\n",
"    X = file_dict['data']\n",
"    # labels are stored as a plain Python list; convert them to an array\n",
"    Y = np.array(file_dict['labels'])\n",
"    return X, Y"
]
},
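{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you want to check the loading code without downloading CIFAR-10, one option is to write a small fake batch in the same dictionary format with `pickle` and read it back. The sketch below is self-contained; the file name `fake_batch`, the labels and the tiny sizes are made up for illustration. It reads the file with `encoding='latin1'`, as `unpickle` does.\n",
"\n",
"```python\n",
"import os\n",
"import pickle\n",
"import tempfile\n",
"import numpy as np\n",
"\n",
"# A tiny fake batch with the same keys as a CIFAR-10 batch (contents are made up)\n",
"fake = {\n",
"    'batch_label': 'fake batch',\n",
"    'labels': [3, 1, 4],\n",
"    'data': np.random.randint(0, 256, size=(3, 3072), dtype=np.uint8),\n",
"    'filenames': ['a.png', 'b.png', 'c.png'],\n",
"}\n",
"\n",
"with tempfile.TemporaryDirectory() as tmp:\n",
"    path = os.path.join(tmp, 'fake_batch')\n",
"    with open(path, 'wb') as fo:\n",
"        pickle.dump(fake, fo)\n",
"    # Read it back the same way unpickle() reads the real batches\n",
"    with open(path, 'rb') as fo:\n",
"        batch = pickle.load(fo, encoding='latin1')\n",
"\n",
"x = batch['data']\n",
"y = np.array(batch['labels'])\n",
"print(x.shape, x.dtype, y)\n",
"```"
]
},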
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def load_CIFAR10(root):\n",
"    \"\"\"\n",
"    Load the CIFAR10 dataset. \n",
"    \n",
"    Currently handles multiple train batches, but only one test batch.\n",
"    \n",
"    :param root: path to the root folder containing the CIFAR10 dataset\n",
"    :return: \n",
"        x_train: ndarray of shape (number of train images, 32, 32, 3) holding RGB values of images\n",
"        y_train: ndarray of shape (number of train images, ) holding labels for the training set\n",
"        x_test: ndarray of shape (number of test images, 32, 32, 3) holding RGB values of images\n",
"        y_test: ndarray of shape (number of test images, ) holding labels for the testing set\n",
"    \"\"\"\n",
"    # Build full paths instead of calling os.chdir, which would change the notebook's working directory.\n",
"    # sorted() keeps the training batches in a predictable order (glob's order is arbitrary).\n",
"    train_batch_list = sorted(glob.glob(os.path.join(root, 'data_batch*')))\n",
"    test_batch = glob.glob(os.path.join(root, 'test_batch*'))\n",
"    \n",
"    # Training set: collect every batch in a list, then join them along the image axis\n",
"    x_batches, y_batches = [], []\n",
"    for file in train_batch_list:\n",
"        x_batch, y_batch = load_CIFAR_batch(file)\n",
"        x_batches.append(x_batch)\n",
"        y_batches.append(y_batch)\n",
"    x_train = np.concatenate(x_batches)\n",
"    y_train = np.concatenate(y_batches)\n",
"    \n",
"    # Each row holds 1024 red, then 1024 green, then 1024 blue values:\n",
"    # reshape to (n_images, 3, 32, 32), then move the channel axis last to get (n_images, 32, 32, 3)\n",
"    x_train = x_train.reshape((x_train.shape[0], 3, 32, 32)).transpose(0, 2, 3, 1)\n",
"    \n",
"    # Test set, reshaped the same way\n",
"    x_test, y_test = load_CIFAR_batch(test_batch[0])\n",
"    x_test = x_test.reshape((x_test.shape[0], 3, 32, 32)).transpose(0, 2, 3, 1)\n",
"    \n",
"    return x_train, y_train, x_test, y_test"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can go ahead and load up our data now."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# Load the training and test sets (change path_data to wherever you extracted CIFAR-10)\n",
"path_data = '/Users/tomroth/Documents/deeplearning_courses/cs231n_exercises/cifar-10-batches-py/'\n",
"x_train, y_train, x_test, y_test = load_CIFAR10(path_data)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's take a look at an image to see if it loaded correctly."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x18739a780>"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAGgRJREFUeJztnW2MXGd1x/9nZmd3be8m9q5jx6zf8uIiggGHbi0QCNEi\nUIqQAqWN4APKhwjzgaIiUVVRKpX0G1QFxIcKyTQRpkqBtIESVWmrECECUhWyMY7jxIE4wSF+iTfx\nC7t+29mZOf0w19LGuec/s3dn79h9/j/J8u4989znzDP3zN15/nPOMXeHECI9Kv12QAjRHxT8QiSK\ngl+IRFHwC5EoCn4hEkXBL0SiKPiFSBQFvxCJouAXIlEGljLYzG4D8E0AVQD/7O5fYY9fMzbuE5u2\n5Np6/U1Ddraic0Wj2OmceULdiI18vgJ+FHOjk3HRQ1oFx3EXI6v1+HxLoMg5gyEnXzuGszNn4ie3\ngMLBb2ZVAP8E4MMAjgB40swedvfnojETm7bgwf/6Wa6t1aIvfS5NsmaNZnw+NhezzQfzzbdiR5rN\nZkE/4nOypZpvNnKPN8gl3fL4hEb8cOJI9AbL3njrjfgP0Sbzg5wzWn93EvxkfYtcpwDg5Hq0+fga\nWawfX/2bz3R9jqX82b8TwCF3f8nd6wC+D+D2JZxPCFEiSwn+CQCvLPj9SHZMCHEVsOwbfma2y8ym\nzGzq1MnXl3s6IUSXLCX4jwLYtOD3jdmxN+Duu9190t0nx8bXLmE6IUQvWUrwPwlgm5ndYGaDAD4F\n4OHeuCWEWG4K7/a7e8PM/hLA/6At9d3v7s92GAULdp2j4wwjco0RsaNCjGQTOHynZHNRG3nrrTBH\nyFpFz61KHDGygW1GVALiYiQtMmWhWulKoXrzOZkyElrIGlaq8agCCkdmDE3GLpKASrRWizjVknR+\nd38EwCNLOYcQoj/oG35CJIqCX4hEUfALkSgKfiESRcEvRKIsabe/CJVQfFm8KFMhWhl7V2MqGhMc\nI3WlQuQwJzaaqcbGEWkoVMuYjEbXkUhU8RnRChNZ4lFVdsICclj7nEGCEblCmJzHlpHJmCCvJ5NT\nI3qRBas7vxCJouAXIlEU/EIkioJfiERR8AuRKKXv9rMqcxHRBivb/2W71C2iLLCd+0pgosoCsdE6\nfQVr7oU75kyRYEkudJUXXwePJlWRmcLFB2jSTLwBT0p/kfMNkBebKQjsmitS8izOxur+utGdX4hE\nUfALkSgKfiESRcEvRKIo+IVIFAW/EIlSutQX1SsrUrGOyXJUzmOyInEkkoBo6TlidFJ7jkk2rNZd\nu5HSm2kFnXwAoEKTRIol1ESjmHxFazKScUw+rAa3t2aD1BKMBoH73yqYvBPlVbHahGENv0W8Xrrz\nC5EoCn4hEkXBL0SiKPiFSBQFvxCJouAXIlGWJPWZ2WEAswCaABruPkkfj06yUj5xHiCTf4gkQ32I\nJaCopRh/ByWSHZUqWRYe8THMPCxWZ5ALR0Uy/lhWH/OD2Kh8mA/NxKTnK5odufj1Z/Jg7GP38dUL\nnf+P3V29t4W4ytCf/UIkylKD3wH8xMyeMrNdvXBICFEOS/2z//3uftTM1gF41Myed/fHFz4ge1PY\nBQBvmdi4xOmEEL1iSXd+dz+a/T8N4EcAduY8Zre7T7r75Nj42qVMJ4ToIYWD38xWmdnopZ8BfATA\ngV45JoRYXpbyZ/96AD/KsvQGAPyru/83H+KxXEYLTObTIrJRpcIys5rEFppCSYl1W4rbk3Uo0sna\nQsUmIGg1RZaDFLmMszDbNuJHZCxYwNOpHwUy/vKTH9tzkfWg7dyIscVus8F8pItaeNdeTFezwsHv\n7i8BeFfR8UKI/iKpT4hEUfALkSgKfiESRcEvRKIo+IVIlCumVx+VlKIzFSzqyGQj2h8tHlVgTCcZ\njWWxEU8CnYpJny0iDzEfmRwZFmol5+OZb/Fc
PLszMNDrg8xFa50W62sYyYDGXpjw5VQBTyFEBxT8\nQiSKgl+IRFHwC5EoCn4hEqUPu/35u5GFdtlpXkxva74xG89voVvihWwtqhJEGStkl53JB3SRF9/2\njKkYdJudPufFtxSj9fGWYT3o9R2tVeFWb92hO78QiaLgFyJRFPxCJIqCX4hEUfALkSgKfiESpWSp\nz+CBrtEMas9dSVQL1KVjypATKWeeFZKrxC9bJXg/Zy3KqsTJhs/HfhAMUZ1EUtMwlCmBlpP7VJXU\nawyuqxZ5Xi0jNR4Ltj1rhesRy7PGivhF18ciVE/d+YVIFAW/EImi4BciURT8QiSKgl+IRFHwC5Eo\nHaU+M7sfwMcATLv79uzYGIAfANgK4DCAO9z99FIcYQrF0vOXekOkzLGsshaRMFtEzmNZibwFWFDD\nj2ZNFpSvyHOLMgVpxlxBH9nVE5fw622WYHsgec1Yhl70vAtmn3ZLN3f+7wC47bJjdwN4zN23AXgs\n+10IcRXRMfjd/XEApy47fDuAPdnPewB8vMd+CSGWmaKf+de7+/Hs51fR7tgrhLiKWPKGn7c/sIQf\nTsxsl5lNmdnUqZOvL3U6IUSPKBr8J8xsAwBk/09HD3T33e4+6e6TY+NrC04nhOg1RYP/YQB3Zj/f\nCeDHvXFHCFEW3Uh93wPwQQBrzewIgC8D+AqAB83sLgAvA7iju+kcFslUtD1Vb8W+0IcONi/wXllY\nNmKtyFjGX2BrsedFlpc9Y5bhFmlRVdaBijwvJlWyNY6kTyYdslesxWRRdk4qB+fbmJRaDbxczNXW\nMfjd/dOB6UOLmEcIcYWhb/gJkSgKfiESRcEvRKIo+IVIFAW/EIlSfq++QLIxKl+V40NHWyAbUdeL\ntXYrlLnXni/I6iOyEX/GxNpqhKZqJSgkSnyvsqmYDEgKXUZZhKxXH3vOTeIHlz5jEa4ZjPNWXPSz\nWg2KfsYevAnd+YVIFAW/EImi4BciURT8QiSKgl+IRFHwC5EopUt9kYTFkt8imScsfNgBJisyiQ2e\n76QHxzMj8YRkgRGZZ4As1kDQ7i6SkwDeE26AFJ6sk6Vqeb7/bO2rTLJjbetYIdRg/T3wDwAqBbPz\nmAzI649GlWHJmHCu7mNCd34hEkXBL0SiKPiFSBQFvxCJouAXIlFK3e03eNiGirUzQit/DN1dZRRt\nkxXsyrKEjiJ17oBQWAAAnDv7+9B2MiiPPj8/T/yIJxtaORqPI4ysGsk93mySXfaB4dDGVIdGI04w\nihQhdtejyUy0jRo5J1WY8kdaNT4jq+/XLbrzC5EoCn4hEkXBL0SiKPiFSBQFvxCJouAXIlG6add1\nP4CPAZh29+3ZsXsBfBbAa9nD7nH3R7qZMJLSWEuucEzB4n583OJr+NH2TiRBh01VsVjKefHXz4a2\nJ598Mvf43NxcOKZej2XAeQ8yhQC869ZbQ9s7tm/PPc6kvlVrhkJbM5B7AdBiiJHExhJ05oks1ySy\nYlS3EODXd5RkxBKugg5fPa/h9x0At+Uc/4a778j+dRX4Qogrh47B7+6PAzhVgi9CiBJZymf+L5jZ\nfjO738zW9MwjIUQpFA3+bwG4EcAOAMcBfC16oJntMrMpM5s6dfJkwemEEL2mUPC7+wl3b3q7hMq3\nAewkj93t7pPuPjk2Pl7UTyFEjykU/Ga2YcGvnwBwoDfuCCHKohup73sAPghgrZkdAfBlAB80sx1o\ni1WHAXyuq9kcqEQyCpFeIpkkPFdHP1i7KyIbBdILa7tVVI70ZiwprV87Ftq2bHxL7vEKkaFOnor3\nc+utWOobIE/8+efy7wc337yNnC80gdY7ZFJfYGOSI2sbViGZduylbjIfA92OJbrG8nf3dAx+d/90\nzuH7FjGHEOIKRN/wEyJRFPxCJIqCX4hEUfALkSgKfiESpfR2XRFcoigml5UFazVWIZlZxIT6xTjT\nbmgwftne
uu2m3OOjo3Ehzqee2hvaBkfib26fu3AhtEWS6diaa8MxtDgmk72IjBm18nKWJUig1ym9\nDhYjwrVpETkyKuC5mA52uvMLkSgKfiESRcEvRKIo+IVIFAW/EImi4BciUUqX+iLBgxVGDDPtiMRD\nCz4y6TAopggAhnwbywSMpCYAaBEfp6ePh7Znnv5VaLt48WLu8Vd+97twTHUgvgxuuDm2HTt6LLS9\n973vyz3OsgubpJ9gtRJnFzrpW9cKrqsayc5rksuD9shjlxW7rgJXWNFPtKJ46V7r051fiERR8AuR\nKAp+IRJFwS9Eoij4hUiUknf7Hc1gt5TuogZJES2SSeEsyYK95ZHd+UYzfzeazcXyOZqkTt/4daQV\nQi1+2arIb3k1Sionj4/HNQHrzXpoO3Y83u1ft/763ONm8a49rXfI1BuyKx691C22k05etFbQsq09\njFyPZJwHz5uOqUS1MLXbL4TogIJfiERR8AuRKAp+IRJFwS9Eoij4hUiUbtp1bQLwXQDr0c4a2O3u\n3zSzMQA/ALAV7ZZdd7j7aXYu97hNEpdJ8mm24mQJ1o5pIEjQAbjcVAmSS5gKxRJZrr3mmtD26xde\nCG3rNmwMbefOncs9Pro6lvrOnj0b2l49Fst5hw6/HNq+/+8P5R7/iz//VDhmaHA4tDEpmKnE9fmg\n1h0pCshsLGGMlukj10FUq6/B5lpUY67ApS4e0wDwJXe/BcB7AHzezG4BcDeAx9x9G4DHst+FEFcJ\nHYPf3Y+7+97s51kABwFMALgdwJ7sYXsAfHy5nBRC9J5FfeY3s60AbgXwBID17n4p6fxVtD8WCCGu\nEroOfjMbAfAQgC+6+8xCm7erWeR+QDGzXWY2ZWZTp0graCFEuXQV/GZWQzvwH3D3H2aHT5jZhsy+\nAcB03lh33+3uk+4+OTYWf4dcCFEuHYPf2tvw9wE46O5fX2B6GMCd2c93Avhx790TQiwX3WT1vQ/A\nZwA8Y2b7smP3APgKgAfN7C4ALwO4o9OJ3B0X5+NMNjYujwrJbgPJiGqG9c+ARj2/Bh4AVKuDwUzx\ne+jLRA6bnn4ttJ09fz601VnWWaB7NYj0WRlaEdqun9gU2jZuzW8NBgArRvJlzMGVq8IxTVYej2QD\nNjx+PeeCa2eoWovnYvX2mCRNazmGplAOrhCpj9WG7JaOwe/uv0Bcd/NDS/ZACNEX9A0/IRJFwS9E\noij4hUgUBb8QiaLgFyJRSi3gef7CBex9en+ujRWzjDL0aoOx+0M1UiiyFbeFWrUivwAmAFQq+VKf\nV+Ixe/fuC2379j0d2s7Mzoa29Vu2hraNG/Mz/g4dOhSOGSfFPTdv3hzabtr21tC2NZABT7x2Mhwz\nF2TgAVxim6vPhbZK0AtrgLTrqhiT0Ug2HdHz5kk7uihvlcmDEU2ml16G7vxCJIqCX4hEUfALkSgK\nfiESRcEvRKIo+IVIlFKlvkazgVO/P5NrW7EiziwbGMh3c4Bk9VnUywzAViJfrb5mNLQNrxjJPf7i\nb4/E51t9bWi76aYbQtvpmbio5jXr8vvgAcATT/wy9/grR2IfG/Ox9PnJT/5ZaFuzJq7P8PzB53OP\nn3g1lvrqTKYiBTDPkwzIWi3I3iNVP6uk3x2T0owV/iRSnwVyJJO/Ixnw3Ll4LS5Hd34hEkXBL0Si\nKPiFSBQFvxCJouAXIlFK3e13B6LcjXmyS7lmzZrc40PD+Yk2ALB+bf4YAKgRlWBmJl+NAIDZs/mt\nsGBxzbc/eGtc525iIt61PzMb7/afPl8PbTv/6A9zj7/zHW+P5zoTP+dhssarV8ftxi6cu5B7/NzZ\nmdzjAICBuK5ek9SsI0IAms38tXJSH4+pDkVq8QFAo8BuPxsT1QtkdQQvR3d+IRJFwS9Eoij4hUgU\nBb8QiaLgFyJRFPxCJEpHqc/MNgH4LtotuB3Abnf/ppndC+CzAC71nLrH3R
/hJ6ugEsg5J0/GCR+z\ngWz04oXT4Zihaix5rF0TS1QsqQOBJDO8Mk4GYslHzUYsETKZh71jb964Ifd4tRrXNIwSp4C4fiIA\n1OfihKC3XH9d7vFXXjkWjhlaFSd3MT1vZiaWD+v1QOrz+Hx1UkuwOhCvI0vemSdt6iKpj5QthAe1\nBBdT9q8bnb8B4EvuvtfMRgE8ZWaPZrZvuPs/LmI+IcQVQje9+o4DOJ79PGtmBwFMLLdjQojlZVGf\n+c1sK4BbATyRHfqCme03s/vNLP5KnRDiiqPr4DezEQAPAfiiu88A+BaAGwHsQPsvg68F43aZ2ZSZ\nTdGvdgohSqWr4DezGtqB/4C7/xAA3P2Euze93Sj82wB25o11993uPunuk6uCnu1CiPLpGPzWbpVy\nH4CD7v71BccXbit/AsCB3rsnhFguutntfx+AzwB4xswu9Z66B8CnzWwH2urCYQCf62ZCD2SNsbX5\n0hAAzAc15ppzv4/n8ViGWrFiOLRVQLLHghZPTcRznTsfZAICmK/H4+bqpH1ZK85+qwdaD5P6WCbY\nAJG2qtXYj8GgtdlNWzaFYyLfAaBBau416xdDmzfz15gobzCyVpEsBwBN4mMkzQFAI5B8mQTbIlmO\n3dLNbv8vkN+gjGv6QogrGn3DT4hEUfALkSgKfiESRcEvRKIo+IVIlFILeLZarVD6YrKGBelNrICk\nNWL5p1qJpZz63FxoGx4Yyj1eo3JY/hiAF56kklIjnq8VyE0sQyxfzLk0F5EjyVqdnc1f/wEiDw5f\nE7+eddK6at346tDWms/PCJ0l56sRH43mzcUZkFaJx83P5a9V0+PXOcoSdCI3Xo7u/EIkioJfiERR\n8AuRKAp+IRJFwS9Eoij4hUiUkqW+Ji4GUt/4mrFwXCR4RNIbAGzcvDG0DQ3GUs7Bg8+FtqPHTuQe\nXzGyKhwzPj4e2mrVuGClDZLCmSApacH7eYv0n4uyFQFggEiOXonPaSvybXNBQU0A8Pm4P2GF9Nar\nDsRS5epVK3OPXzz/ejimVZ8NbUzWHR+JX8/r168LbR7IhydejX1sNvPnGhzo/n6uO78QiaLgFyJR\nFPxCJIqCX4hEUfALkSgKfiESpVSpr1arYf11+ZLHhXNxoctKkPG3ffvbwzGbN14f2mZnYiln5cqR\n0Hb+Yn6G2KHfvhSOeeE3L4Y2lsm4Zk3cA2XVqtjHqBjnykDyAoBa0D8RACxWHGmvwRXD+VLUxYtx\ntuWF+djWIhlzM6fjno3r1uX3Lhwh8uzIaLxWmzasD20TG2I5b7BGMjE9/7m9/npcoHZ2Jv9a/I9/\neyAcczm68wuRKAp+IRJFwS9Eoij4hUgUBb8QidJxt9/MhgE8DmAoe/y/u/uXzWwMwA8AbEW7Xdcd\n7h5vuwLwlqMeJHawhI+5C/k7m/v2/Soc8+wzsR8VUjxvoBYvyZatW3OPv+1tbwvHnD0bJ6scOBC3\nN3zppVhBOH36TGgbGgrqDNbiHX1mW1GLk6cGa/ktuQBgcDDfxuZq0lZp8etSrcZ+bA5as22+fks4\nZtOWOCns2lVx8s4w2dE38tzm6vm1EIeGRsMxMyPnc4/XyGtyOd3c+ecA/Im7vwvtdty3mdl7ANwN\n4DF33wbgsex3IcRVQsfg9zaXbl+17J8DuB3Anuz4HgAfXxYPhRDLQlef+c2smnXonQbwqLs/AWC9\nux/PHvIqgPjbD0KIK46ugt/dm+6+A8BGADvNbPtldgfyKxKY2S4zmzKzqbNn42/WCSHKZVG7/e5+\nBsBPAdwG4ISZbQCA7P/pYMxud59098mRkXgDQwhRLh2D38yuM7PV2c8rAHwYwPMAHgZwZ/awOwH8\neLmcFEL0nm4SezYA2GNmVbTfLB509/80s/8F8KCZ3QXgZQB3dDqRw9HyfMnjmtH4r4K58/lS37Hj\nr4Rjzs/GchiT32qBRAUAP/v5z3OPDw
byGsClrUgOA4CJiYnQVq//JrRVq/ly08hInAw0EIwBgFbQ\nFgqIE1IAYCZYf9aGjLXkunAxloJvvOHm0HY6SPqJkrQAoDYYr8fojbFEWKnE4dRsxFLfqZP5azU8\nHCcYjY/nJ34NkBqDb3pspwe4+34At+YcPwngQ13PJIS4otA3/IRIFAW/EImi4BciURT8QiSKgl+I\nRLGo5tuyTGb2GtqyIACsBRD3IyoP+fFG5Mcbudr82OLu13VzwlKD/w0Tm025+2RfJpcf8kN+6M9+\nIVJFwS9EovQz+Hf3ce6FyI83Ij/eyP9bP/r2mV8I0V/0Z78QidKX4Dez28zs12Z2yMz6VvvPzA6b\n2TNmts/Mpkqc934zmzazAwuOjZnZo2b2QvZ/3K9ref2418yOZmuyz8w+WoIfm8zsp2b2nJk9a2Z/\nlR0vdU2IH6WuiZkNm9kvzezpzI+/z473dj3cvdR/AKoAXgRwI4BBAE8DuKVsPzJfDgNY24d5PwDg\n3QAOLDj2DwDuzn6+G8BX++THvQD+uuT12ADg3dnPowB+A+CWsteE+FHqmgAwACPZzzUATwB4T6/X\nox93/p0ADrn7S+5eB/B9tIuBJoO7Pw7g1GWHSy+IGvhROu5+3N33Zj/PAjgIYAIlrwnxo1S8zbIX\nze1H8E8AWFiF4wj6sMAZDuAnZvaUme3qkw+XuJIKon7BzPZnHwuW/ePHQsxsK9r1I/paJPYyP4CS\n16SMormpb/i939uFSf8UwOfN7AP9dgjgBVFL4FtofyTbAeA4gK+VNbGZjQB4CMAX3X1moa3MNcnx\no/Q18SUUze2WfgT/UQCbFvy+MTtWOu5+NPt/GsCP0P5I0i+6Koi63Lj7iezCawH4NkpaEzOroR1w\nD7j7D7PDpa9Jnh/9WpNs7kUXze2WfgT/kwC2mdkNZjYI4FNoFwMtFTNbZWajl34G8BEAcf+s5eeK\nKIh66eLK+ARKWBMzMwD3ATjo7l9fYCp1TSI/yl6T0ormlrWDedlu5kfR3kl9EcDf9smHG9FWGp4G\n8GyZfgD4Htp/Ps6jvedxF4BxtNuevQDgJwDG+uTHvwB4BsD+7GLbUIIf70f7T9j9APZl/z5a9poQ\nP0pdEwDvBPCrbL4DAP4uO97T9dA3/IRIlNQ3/IRIFgW/EImi4BciURT8QiSKgl+IRFHwC5EoCn4h\nEkXBL0Si/B/Nk6vAWWQDvwAAAABJRU5ErkJggg==\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x17dce3128>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"i=100\n",
"plt.imshow(x_train[i])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
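"It's worth noting that if we change x_train to a `float` type, the image is displayed incorrectly (in the output below it looks like a negative), which can really mess with you if you're not expecting it. For integer RGB data `plt.imshow` assumes values from 0 to 255, but for float RGB data it expects values between 0 and 1, so float images should be divided by 255 before plotting. "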
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x128271668>"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAGeZJREFUeJztnVuMXWd1x//r3GbGnjHxJTG24+I4MZdgiBOGgMSltAiU\nIqSAWkXwgPIQYR4oKhJ9iFKppG+0KiAeKiTTRJiKAikJIqqiVsFCimhRiBMSx8GE2MF2xtgzycSO\nZ+y5nXNWH86OOnb2+p8ze2b2jPn+P8nymW+db+91vrPXuXz/s9Yyd4cQIj0qK+2AEGJlUPALkSgK\nfiESRcEvRKIo+IVIFAW/EImi4BciURT8QiSKgl+IRKktZrKZ3QbgWwCqAP7V3b9GT1Yxr9fyX28M\nVsCBQibAYiubF9nI4egR+bwiRwQsOCg/VdGFLHjMgIqRX5sWXKxwVuG1X5YnbcmON3FhBlMzzZ4O\nWDj4zawK4F8AfAzACIAnzOxhd/9NNKdeq+D6awZzbZXKwj+EVMlDrFXj47FzMVs9OF+9EjtSrVYL\n+hEfky1VvZr/lNbI1Vex+IBO/DDiSPgiRC70Rq0d2qrMD3LMaP2NvNCQ5Sh0nQKAkevR6/E1slA/\nHjgQht8bj7Hgs/4/twI46u4vuvssgB8CuH0RxxNClMhign8bgJfm/T2SjQkhrgAW9Z2/F8xsL4C9\nAFBnn9OFEKWymHf+UwC2z/v72mzsEtx9n7sPu/sw+94mhCiXxQT/EwB2mdl1ZtYA8BkADy+NW0KI\n5abwx353b5rZXwP4b3Skvvvd/Tk+y+DBrnM0Tn1AvGPLapS0iZGpTdFeNDsXtcWb22gzR8haRY+t\nRRxx8hbgTlQCpswFPjJlodUuVliGKiOhhaxhuxXPKqBwZMbQVKSgTjtaqwUcalHf+d39EQCPLOYY\nQoiVQb/wEyJRFPxCJIqCX4hEUfALkSgKfiESZdl/4Xc57fD1ZuGiTJtoZURFo3IeExwjdaVN5DAj\nNppMx+YRaShUy5iMRteRSFTxEVEJE1niWS12wIL9JVrBOhq5Qpicx5aRyZggzyeTUyOorNgjeucX\nIlEU/EIkioJfiERR8AuRKAp+IRKl9N1+UgkvnBFtsLL9X7ZLXSGveWznvh2YqLJAbLS+X8FaceGO\nOVMkWJILXWWSIBUdjyVVkTOFiw/wmozhU01Kf5HjNcmTzRQEds0VKXlGs7F6RO/8QiSKgl+IRFHw\nC5EoCn4hEkXBL0SiKPiFSJRSpT5HXK+sSMU6JstROY/JisSRSAKipeeI0Wg149jGat2559efqwSd\nfACgTZNEiiXURLOYfEVrMpJ5TD5sBepbNWgbBwCtaBK4/5WCyTtRXhWrTRjV8FtI/pPe+YVIFAW/\nEImi4BciURT8QiSKgl+IRFHwC5Eoi5L6zOw4gAkALQBNdx/uNofLSsF5wnEm/xBJhvoQvx5GLcVY\nVh+V7KhUybLwiI9h5mGxOoNcOSqS8cey+pgfxEblw3xoJiY9XtHsyIWvP5MHl6KG31Lo/H/m7q8s\nwXGEECWij/1CJMpig98B/MzMnjSzvUvhkBCiHBb7sf+D7n7KzK4B8KiZ/dbdH5t/h+xFYS8A1MJa\n7kKIsllUNLr7qez/MQA/AXBrzn32ufuwuw9X6W/ZhRBlUjj4zWytmQ29fhvAxwEcXirHhBDLy2I+\n9m8G8JNMcqgB+Hd3/69uk0K5jBaYzKdCZKN2m2VmVYktNIWSEuu2FLcn61Kkk7WFik1A0GqKLAcp\nchlnYXZsxI/IWLCAp1E/CmT85Sc/ds5F1oO2cyPGCtMWg/ORLmpd5OXeKBz87v4igJuWwAchxAqg\nHTghEkXBL0SiKPiFSBQFvxCJouAXIlFK7tVniMQSKilFRytY1JHJRsWypYr9eInLaCyLjXgS6FRM\n+qwQyY75yOTIsFArOR7PfIvPxbM7AwO9Psi5
aK3TYn0NIxnQ2ROzBFqf3vmFSBQFvxCJouAXIlEU\n/EIkioJfiEQpebffEe16Ftplp3kxS1vzjdl4fgvdEi9kq1CVIMpYIbvsTD6gi7zwtmdMxaDb7PQx\nL7ylGK2PtwzrQa/vaK2KtHpTuy4hRDcU/EIkioJfiERR8AuRKAp+IRJFwS9EopQs9QEWaBHVoPbc\naqJVoC4dU4aMSDl1Vkiu3YxNQcYHa1HWIk7WrB77QXBEdRJJTcNQpgQqRjJZWqReY3BdVcjjqjip\n8Viw7VklXI9YnnVWxI9dHz2y+iNOCLEsKPiFSBQFvxCJouAXIlEU/EIkioJfiETpKvWZ2f0APglg\nzN13Z2MbAPwIwA4AxwHc4e5nF+MIS0ZaLe09I2WOZZVViIRZIXINy0rkLcCCGn40a7KgfEUeW5Qp\nSDPmCvrIrp64hN/SZgl2JpLnjGXoRY+7YPZpr/Tyzv9dALddNnY3gAPuvgvAgexvIcQVRNfgd/fH\nALx62fDtAPZnt/cD+NQS+yWEWGaKfuff7O6ns9tn0OnYK4S4glj0z3vd3c3iEixmthfAXgCoVbW/\nKMRqoWg0jprZFgDI/h+L7uju+9x92N2Hq1HpISFE6RQN/ocB3JndvhPAT5fGHSFEWfQi9f0AwEcA\nbDKzEQBfBfA1AA+Y2V0ATgC4o7fTGTx6vaHtqZb2E0PoQxebFeiRVFg2Yq3IWEZXYKuwx0WWlz1i\nluEWaVEt1oGKPC4mVbI1jqRPJh2yZ6zCZFF2TCoH59uYlNoKvFzI1dY1+N39s4Hpows4jxBilaEd\nOCESRcEvRKIo+IVIFAW/EImi4BciUUov4BnpSk7lq3J86GoLXiup68VauxXK3OucL8jqI7IRf8TE\nWokvn1Y7KCRKfG+xUzEZkBS6jLIIWa8+9pirxA8ufcYiXDWYZ5W46GerFRT9jD14A3rnFyJRFPxC\nJIqCX4hEUfALkSgKfiESRcEvRKKUKvU5YgmLJb9FMk9Y+LCbH0w2Yq+HQc0SUsukcFFKJzJPkyxW\nM2h3F8lJAO8J1ySFJxtkqSqW7z9b+xaT7FjbOlYINVh/C/wDgHbB7DwmA/L6o1FlWDJHvfqEEEVR\n8AuRKAp+IRJFwS9Eoij4hUiU0hN7ojZUrJ0Rgqq/dHeVUbRNVrAryxI6itS5A0JhAQCwdvBNoW3j\nxk254/V6nfgRn2zm4kQ8jzB5YTJ3vFolu+zN6dDGVIdaLb6MI0WI1SakyUy0jRo5JlWY8md6Kz5i\nXN+vdwVM7/xCJIqCX4hEUfALkSgKfiESRcEvRKIo+IVIlF7add0P4JMAxtx9dzZ2L4DPA3g5u9s9\n7v5I99NZKKWxllzhnILF/fi8hdfwo+2dSIIOO1Xb49fl69/2ztD23ve+N3e8r68vnNNoxDJg3YJM\nIQDP/PrXoe3Zw4dzx5nUd+HsTGijTV5ZPb5AYmMJOnUiy1WJrBjVLQT49R0lGbGEq6DD15LX8Psu\ngNtyxr/p7nuyfz0EvhBiNdE1+N39MQCvluCLEKJEFvOd/0tmdsjM7jez9UvmkRCiFIoG/7cB7ASw\nB8BpAF+P7mhme83soJkdZN+JhBDlUij43X3U3Vve2ZH4DoBbyX33ufuwuw9XSeMIIUS5FIpGM9sy\n789PA8jf2hVCrFp6kfp+AOAjADaZ2QiArwL4iJntQUdZOA7gC72esB3JKER6iWSS8FjdYC2XiCRT\nCaQX1narqBxp1fipGX0l3n89MfKH3PE2+cq1ccOG0NaoxFJfkzzwt9+4O3f86NEXyPFCE2i9Qyb1\nBTYmObK2YW2Sacee6irzMdDtWKJrLH/3Ttfgd/fP5gzft4BzCCFWIfoSLkSiKPiFSBQFvxCJouAX\nIlEU/EIkSukFPCO4RFFMLisL1mqsTTKziAmN/jjTbma2Gdqef+FY7vjERFyI8z3vuSW0zU6eDW1r\nBwZCWySZ
vnr2tXAOLY7JZC8iY0atvIxlCRLodUqvg4WIcB0qRI6MCngupIOd3vmFSBQFvxCJouAX\nIlEU/EIkioJfiERR8AuRKKVKfY5YDWGFEcNMOyLx0IKPTDoMiikCgCPfxjIBI6kJACrEx2uu2RLa\n3nXTzaGtv78/d3z7n/xJOKfVjKXD3x+NbVu3bQ1tv/zl/+SOs+zCKukn2GrH2YVG6kRUgutqjmTn\nVcnlEffIA7+s2HUVuMKKfqISxYt69QkhuqDgFyJRFPxCJIqCX4hEUfALkSil7vabAVEFX7qLGiRF\nVEgmhbEkC1ZBnOzO16r5u9HsXCyfo0rq9I2/HCfUYC7egW8hv+XVxPh4fK7xuCZgo9oIbVu3xLv9\nY6Nncsfd4117Wu+QqTdkVzx6qitsJ508aRXyfslagBmZZ8HjpnPaUS1M7fYLIbqg4BciURT8QiSK\ngl+IRFHwC5EoCn4hEqWXdl3bAXwPwGZ08nL2ufu3zGwDgB8B2IFOy6473J3oUx1JI2qTxGWSfKqV\nOFmCtWNqBgk6AJeb2kFyCVOhWCLLa+fPh7a37doV2sZOj4S2tWvX5o5PnIulvsHBwdD25q2xnHfD\njreEts/81V/mjv/Hj38YzpmZnQ5tTApmKnGjHtS6I0UBmY0ljNEyfeQ6iGr11di5Qpm791qBvbzz\nNwF8xd1vBPB+AF80sxsB3A3ggLvvAnAg+1sIcYXQNfjd/bS7P5XdngBwBMA2ALcD2J/dbT+ATy2X\nk0KIpWdB3/nNbAeAmwE8DmCzu5/OTGfQ+VoghLhC6Dn4zWwQwIMAvuzul3xZ9c4X9twvG2a218wO\nmtlBVpBBCFEuPQW/mdXRCfzvu/tD2fComW3J7FsAjOXNdfd97j7s7sNsg04IUS5dg986NaruA3DE\n3b8xz/QwgDuz23cC+OnSuyeEWC56yer7AIDPAXjWzJ7Oxu4B8DUAD5jZXQBOALij24HMDP31hScS\nRjXy2iS7jb2uVcP6Z0CtkV8DDwBardngTLGM8xYih11zzdWhbXDNmtDWYFlnge5VI9Jne2YqtJ05\n9VJoGzme3xoMAKYm82XM2YsXwjlVVh6PZAPWLH4++4JrZ6Y1F5+L1dtjkjSt5RiaQjm4TaQ+Vhuy\nV7pGorv/ArHU/tFFeyCEWBH0Cz8hEkXBL0SiKPiFSBQFvxCJouAXIlFKLeC5Zs0Abrnp3bk2Vswy\nytCbm42lvpk5UiiyEreFujCVXwATANrtfKnP2vGcW27ZE9r27LkptF01NBTaRk8cD20jI/kZfzfc\ncEM4Z5wU9zx58mRoO/bC86HteCADbr56YzinL8jAA7jE1tfoC23toBdWk7Trajt7TyRZc0TPq5N2\ndJGYxuTBCJbNejl65xciURT8QiSKgl+IRFHwC5EoCn4hEkXBL0SilCr11apVbHjTVbm2qak4s6zZ\nzJf0miSrz6NeZgCOE/nq3PmJ0DY9NZk7fv1118bHO/daaDt27Pehbf26uKjm+bH8PngA8L733Zo7\nvv3a2MdaPZY+H3zwodB29mzc4+/t73h77vjmN8dSX4Ol9ZECmGtIBuTcXJC9R6p+tki/uyrx0Vnh\nTyL1eSBHMvk7kgEf/d+nwjmXo3d+IRJFwS9Eoij4hUgUBb8QiaLgFyJRSt3tNwBR7kZ9bbxje/Zs\nfhewmen8RBsAGH0l7hw2R1SCdevy1QgAGBrMb4UFj5Mpfvd8XOfu1Kl41/6qoXi3f/2aRmj71RNP\n5o4feva5+FxXxY95mqzxuXNxu7GBtQO542sH14Vz0Izr6lVJzToiBKBazV8rI/XxmOpQpBYfANQK\n7PazOVG9QFZH8HL0zi9Eoij4hUgUBb8QiaLgFyJRFPxCJIqCX4hE6Sr1mdl2AN9DpwW3A9jn7t8y\ns3sBfB7Ay9ld73H3R/jRHO1Aztm4MU74GApko+sH1odzZlqx5PHK2ViiYk
kdCCSZ6YtxMhBLPqrW\nYomQyTxE2cLJkdO5461WXNMwSpwCeE24Rl+cEPSHMy/njm/fvjWcM3MhTu5iet66dbF82GgEUp/F\nx2uQWoKtZryOLHmnTtrURVIfKVsIC2oJLqTsXy86fxPAV9z9KTMbAvCkmT2a2b7p7v+8gPMJIVYJ\nvfTqOw3gdHZ7wsyOANi23I4JIZaXBX3nN7MdAG4G8Hg29CUzO2Rm95tZ/BlcCLHq6Dn4zWwQwIMA\nvuzu5wF8G8BOAHvQ+WTw9WDeXjM7aGYHL5Ka+EKIcukp+M2sjk7gf9/dHwIAdx9195Z3diu+AyC3\nhIy773P3YXcfXjMQN1cQQpRL1+C3TqbAfQCOuPs35o1vmXe3TwM4vPTuCSGWi152+z8A4HMAnjWz\np7OxewB81sz2oCP/HQfwha5HcsACWePVV/KlIQCoBzXmqn1vCud0PqzkMzU1Hdra5PWwHbR4qiI+\n19o1QSYggHojntfXIO3LKnH2WyPQepjUxzLBmkTaarViP2aD1mbHTrwUzol8B4AaqblXbfSHNqvm\nrzFR3uBkrSJZDgCqxMdImgOAWiD5Mgm2QrIce6WX3f5fIF8+7KLpCyFWM/qFnxCJouAXIlEU/EIk\nioJfiERR8AuRKKUW8KxUKqH0xWQND9KbWAFJr8XyT6sdSzmNvviHSNPN/F8ozlE5LP5VIys8SSWl\nWny+SiA3sQwxEBmqWiNyJFmrwaH89W8SeXD6fPx8NkjrqrHxc6GtUs/PCB0ix5sjPjrNm4szIL0d\nz6v35a9V1eLnOcoSNCI3Xo7e+YVIFAW/EImi4BciURT8QiSKgl+IRFHwC5EopUt9/YHUN3721XBe\nJHhE0hsAjJwcCW0zs7GU84533Bjatm3dnDs+NXkhnDM+Ph7a5lpxwUqfJYUz6dOWLxFWSP+5KFsR\nAJpEcrR2fEyfyrf1BQU1AcDqcX/CNumt12rGUuW5Cxdzx/vXbArnVBpDoY3JuuOT8fN5ZnQstFkg\nH25+c+xjtZp/rtkmK+96KXrnFyJRFPxCJIqCX4hEUfALkSgKfiESRcEvRKKUKvXNNZsYfTlf8hhY\nGxe6bAcZf4cPPxfOOTlyJrQNrYulnIsXJ0Pbmv78DLEbrtsZztn11utDG8tkPHv2bGi7cCH2MSrG\neTGQvABgLuifCAAeK4601+DUdL4U1d8fZ1sO1GNbhWTMrVsf94sZG8vvXThJ5NnJiXitXjo9GtpO\nnY7lvNk5kolp+Y9t06a4QO3Quvxr8fxEfG284bw931MI8UeFgl+IRFHwC5EoCn4hEkXBL0SidN3t\nN7N+AI8B6Mvu/2N3/6qZbQDwIwA70GnXdYe7x1vUnWOhESR2sISPvoH8nc09e24O57zzXbEfbVI8\nrzkX78CfOH48d/zIkSPhnMHBOFll9+7doW3nzlhBWL/+qtA2MxPUGZyLd/SZbWouTp6anctvyQUA\ns7P5NnauKm2VFj8vrVbsx8mgNdvJMyfCOS+diJPCXrsQJ+9Mkx19J4+tr5FfC3FmZiKcs25yTe44\nW9/L6eWdfwbAn7v7Tei0477NzN4P4G4AB9x9F4AD2d9CiCuErsHvHV4XD+vZPwdwO4D92fh+AJ9a\nFg+FEMtCT9/5zayadegdA/Couz8OYLO7v/4LijMA8pPdhRCrkp6C391b7r4HwLUAbjWz3ZfZHUHx\ndzPba2YHzezg1HTcGlsIUS4L2u1393MAfg7gNgCjZrYFALL/c3/b6O773H3Y3YcHyE87hRDl0jX4\nzexqM7squz0A4GMAfgvgYQB3Zne7E8BPl8tJIcTS00tizxYA+82sis6LxQPu/p9m9ksAD5jZXQBO\nALij24HMDBXLf705PxHLGn1r8qW+rVu2h3PWDMVyGJPf5gKJCgD+9EMfyh2fDeQ1gEsvkRwGAKdO\nnQptjcZbQ1urlS83TU7GCR/NYA4AVI
K2UECckAIA64L1Z23IWEuugf5YCn7x90dD2/og6SdK0gKA\nudl4PSZejCXCdjuWI6u1+H12w8b8tZqejhOMxsfzVfUmqTF4OV2D390PAXiDoO7u4wA+2vOZhBCr\nCv3CT4hEUfALkSgKfiESRcEvRKIo+IVIFOv8OK+kk5m9jI4sCACbALxS2slj5MelyI9LudL8eIu7\nX93LAUsN/ktObHbQ3YdX5OTyQ37ID33sFyJVFPxCJMpKBv++FTz3fOTHpciPS/mj9WPFvvMLIVYW\nfewXIlFWJPjN7DYze97MjprZitX+M7PjZvasmT1tZgdLPO/9ZjZmZofnjW0ws0fN7IXs/7gH1fL6\nca+ZncrW5Gkz+0QJfmw3s5+b2W/M7Dkz+5tsvNQ1IX6UuiZm1m9mvzKzZzI//iEbX9r1cPdS/wGo\nAjgGYCeABoBnANxYth+ZL8cBbFqB834YwC0ADs8b+ycAd2e37wbwjyvkx70A/rbk9dgC4Jbs9hCA\n3wG4sew1IX6UuiYADMBgdrsO4HEA71/q9ViJd/5bARx19xfdfRbAD9EpBpoM7v4YgFcvGy69IGrg\nR+m4+2l3fyq7PQHgCIBtKHlNiB+l4h2WvWjuSgT/NgAvzft7BCuwwBkO4Gdm9qSZ7V0hH15nNRVE\n/ZKZHcq+Fiz714/5mNkOdOpHrGiR2Mv8AEpekzKK5qa+4fdB7xQm/QsAXzSzD6+0QwAviFoC30bn\nK9keAKcBfL2sE5vZIIAHAXzZ3c/Pt5W5Jjl+lL4mvoiiub2yEsF/CsD8+lvXZmOl4+6nsv/HAPwE\nna8kK0VPBVGXG3cfzS68NoDvoKQ1MbM6OgH3fXd/KBsufU3y/FipNcnOveCiub2yEsH/BIBdZnad\nmTUAfAadYqClYmZrzWzo9dsAPg7gMJ+1rKyKgqivX1wZn0YJa2JmBuA+AEfc/RvzTKWuSeRH2WtS\nWtHcsnYwL9vN/AQ6O6nHAPzdCvmwEx2l4RkAz5XpB4AfoPPxcQ6dPY+7AGxEp+3ZCwB+BmDDCvnx\nbwCeBXAou9i2lODHB9H5CHsIwNPZv0+UvSbEj1LXBMC7Afw6O99hAH+fjS/peugXfkIkSuobfkIk\ni4JfiERR8AuRKAp+IRJFwS9Eoij4hUgUBb8QiaLgFyJR/g9zitVMwdRzBQAAAABJRU5ErkJggg==\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x1873ce9b0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"x_train_float = x_train.astype(float)\n",
"plt.imshow(x_train_float[i])"
]
},
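{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small sketch of the fix, using a made-up 2x2 image instead of `x_train`: converting to float keeps the 0-255 values, and dividing by 255 puts them in the 0-1 range that `imshow` expects for floats. Only NumPy is used here, so the plotting call is left as a comment.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# A made-up 2x2 RGB image stored as uint8, like the CIFAR-10 images\n",
"img = np.array([[[255, 0, 0], [0, 255, 0]],\n",
"                [[0, 0, 255], [128, 128, 128]]], dtype=np.uint8)\n",
"\n",
"img_float = img.astype(float)   # values still run from 0 to 255, now as floats\n",
"img_scaled = img_float / 255.0  # rescaled to 0..1, the range imshow expects for floats\n",
"\n",
"print(img_float.max(), img_scaled.max())  # 255.0 1.0\n",
"# plt.imshow(img_scaled) should then show the same colours as plt.imshow(img)\n",
"```"
]
},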
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Building the Nearest Neighbour classifier\n",
"We are now ready to build our classifier. We do this by creating a class that contains:\n",
"\n",
"* a constructor (the `__init__` method)\n",
"* a `train` method\n",
"* a `predict` method\n",
"\n",
"Using the classifier is then easy: we create a `NearestNeighbour` object and call its `train` and `predict` methods."
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class NearestNeighbour(object):\n",
"    def __init__(self):\n",
"        pass\n",
"    \n",
"    def train(self, x_train, y_train):\n",
"        \"\"\" x_train has shape NxD, where N is the number of images (rows) and D=3x32x32.\n",
"        Assumes x_train is flattened out: a 2D array with one row per image.\n",
"        This function only stores the training data. \"\"\"\n",
"        # Store the pixels as int16: subtracting two uint8 arrays wraps around modulo 256\n",
"        # instead of going negative, which would corrupt the L1 distances in predict.\n",
"        # int16 comfortably holds differences in the range -255..255.\n",
"        self.x_train = x_train.astype(np.int16)\n",
"        self.y_train = y_train\n",
"    \n",
"    def predict(self, x_test):\n",
"        \"\"\"\n",
"        Compare each image in the test set with every image in the training set\n",
"        and predict the label of the closest training image under the L1 distance.\n",
"        Assumes x_test is flattened out: a 2D array with one row per image.\n",
"        \"\"\"\n",
"        y_predicted_classes = []\n",
"        for x in x_test.astype(np.int16):\n",
"            # L1 distance from this test image to every training image\n",
"            differences = np.abs(self.x_train - x)\n",
"            differences_rowsums = np.sum(differences, axis=1)  # 1D array, one distance per training image\n",
"            closest_image_index = differences_rowsums.argmin()\n",
"            y_predicted_classes.append(self.y_train[closest_image_index])\n",
"            \n",
"            # track progress\n",
"            progress = len(y_predicted_classes)\n",
"            if progress % 100 == 0:\n",
"                print('Progress: %f' % (progress / x_test.shape[0]))\n",
"        \n",
"        return y_predicted_classes"
]
},
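{
"cell_type": "markdown",
"metadata": {},
"source": [
"The subtraction in `predict` needs care: CIFAR-10 pixels are loaded as `uint8`, and NumPy's unsigned integer arithmetic on arrays wraps around modulo 256 instead of producing negative numbers, which silently corrupts an L1 distance. A tiny made-up example, together with the usual fix of casting to a wider signed type before subtracting:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"a = np.array([10, 200], dtype=np.uint8)\n",
"b = np.array([20, 100], dtype=np.uint8)\n",
"\n",
"# uint8 arithmetic wraps around: 10 - 20 becomes 246, not -10, so abs() cannot fix it\n",
"print(np.abs(a - b).tolist())                                    # [246, 100]\n",
"\n",
"# Casting to a signed type first gives the true absolute differences\n",
"print(np.abs(a.astype(np.int16) - b.astype(np.int16)).tolist())  # [10, 100]\n",
"```"
]
},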
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The nearest neighbour classifier requires the data to be flattened out: one row per image in a 2D array. "
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"x_train_rows = x_train.reshape((x_train.shape[0], 3*32*32))\n",
"x_test_rows = x_test.reshape((x_test.shape[0], 3*32*32))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can train our model and make some predictions for our test set. Each prediction compares one test image with all 50,000 training images, so predicting the whole test set takes a while; we'll only use the first half of the test set (5,000 images) to get a general idea of the classification accuracy. "
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"nn = NearestNeighbour()\n",
"nn.train(x_train_rows, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"predictions = nn.predict(x_test_rows[0:5000])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can evaluate our accuracy by comparing our predictions against the labelled test set. "
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.25280000000000002"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pred_array = np.array(predictions)\n",
"np.mean(pred_array == y_test[0:5000])"
]
},
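{
"cell_type": "markdown",
"metadata": {},
"source": [
"Overall accuracy can hide large differences between classes. Below is a minimal sketch of per-class accuracy with `np.bincount`, run on made-up labels rather than the CIFAR-10 results. To apply it here, one would presumably pass `pred_array` and `np.array(y_test[0:5000])` and set `n_classes = 10`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Made-up predictions and true labels for 8 test images over 3 classes\n",
"y_true = np.array([0, 0, 1, 1, 1, 2, 2, 2])\n",
"y_pred = np.array([0, 1, 1, 1, 0, 2, 0, 2])\n",
"\n",
"correct = y_pred == y_true\n",
"print('overall accuracy:', correct.mean())  # 5 of 8 correct\n",
"\n",
"# Accuracy within each class: correct predictions / number of images in that class\n",
"n_classes = 3\n",
"per_class_correct = np.bincount(y_true[correct], minlength=n_classes)\n",
"per_class_total = np.bincount(y_true, minlength=n_classes)\n",
"print(per_class_correct / per_class_total)\n",
"```"
]
},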
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
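"We scored about 25% accuracy on the first 5,000 test images, which is better than guessing randomly (10%), but not by much! Part of the shortfall may come from this run subtracting the images as unsigned 8-bit integers, which wrap around instead of going negative and so distort the distances; the Stanford notes report about 38.6% for an L1 nearest neighbour classifier on CIFAR-10. We could improve things a little by generalising to a k-nearest neighbours classifier, which, instead of finding the single closest training image to each test image, finds the k closest images and takes a majority vote over their labels. That usually improves accuracy somewhat, but in reality there are classifiers that are much more powerful than this one, such as convolutional neural networks, and we really should just use them. "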
]
}
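,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the k-nearest neighbours idea described above, written as a standalone function rather than as a change to the `NearestNeighbour` class. The name `knn_predict`, the default `k=5` and the toy data are illustrative assumptions. It uses the same L1 distance as before and breaks voting ties in favour of the smallest label.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def knn_predict(x_train, y_train, x_test, k=5):\n",
"    \"\"\"Predict labels by majority vote over the k nearest training rows (L1 distance).\n",
"\n",
"    x_train: (N, D) array, y_train: (N,) non-negative integer labels, x_test: (M, D) array.\n",
"    \"\"\"\n",
"    x_train = x_train.astype(np.int32)  # signed type avoids uint8 wraparound\n",
"    y_train = np.asarray(y_train)\n",
"    predictions = []\n",
"    for x in x_test.astype(np.int32):\n",
"        distances = np.sum(np.abs(x_train - x), axis=1)\n",
"        nearest = np.argsort(distances)[:k]    # indices of the k closest rows\n",
"        votes = np.bincount(y_train[nearest])  # count each label among them\n",
"        predictions.append(votes.argmax())     # most common label; ties go to the smallest\n",
"    return np.array(predictions)\n",
"\n",
"# Toy data (made up): two clusters of one-pixel images\n",
"x_train = np.array([[0], [1], [2], [10], [11], [12]], dtype=np.uint8)\n",
"y_train = np.array([0, 0, 0, 1, 1, 1])\n",
"x_test = np.array([[1], [11], [3]], dtype=np.uint8)\n",
"print(knn_predict(x_train, y_train, x_test, k=3))  # [0 1 0]\n",
"```\n",
"\n",
"On the full CIFAR-10 training set each test image creates a 50,000 x 3072 difference array, so memory use per prediction is large; a smaller signed type such as `int16` halves it."
]
}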
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}