Skip to content

Instantly share code, notes, and snippets.

@EmilienDupont
Last active September 12, 2020 18:44
Show Gist options
  • Save EmilienDupont/99c7127dedb921a5a1f96d37d23c0d4b to your computer and use it in GitHub Desktop.
Save EmilienDupont/99c7127dedb921a5a1f96d37d23c0d4b to your computer and use it in GitHub Desktop.
mnist chicken code

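Code for the MNIST chicken experiments, which feed an out-of-distribution image (a chicken) to an MNIST digit classifier and inspect its confidence. Note that this code was written a long time ago and is not maintained.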

from torch.utils.data import DataLoader
from torchvision import datasets, transforms
def _mnist_like_dataloaders(dataset_cls, root, batch_size, download,
                            shuffle_test):
    """Build (train, test) DataLoaders for an MNIST-style torchvision dataset.

    Shared implementation behind ``get_mnist_dataloaders`` and
    ``get_fashion_mnist_dataloaders``; the two only differ in dataset class
    and storage directory.

    Args:
        dataset_cls: torchvision dataset class accepting
            ``(root, train=..., download=..., transform=...)``.
        root: directory the dataset files are read from / downloaded to.
        batch_size: number of samples per batch for both loaders.
        download: forwarded to both the train and the test split.
        shuffle_test: whether the test loader shuffles its samples.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    # Resize(32) scales the shorter image edge to 32 pixels; for the square
    # MNIST-style images this gives the (32, 32) size the public docstrings
    # promise. ToTensor then converts each image to a tensor.
    all_transforms = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
    ])
    train_data = dataset_cls(root, train=True, download=download,
                             transform=all_transforms)
    # Pass ``download`` to the test split as well, so it does not silently
    # depend on the train-split call above having fetched its files.
    test_data = dataset_cls(root, train=False, download=download,
                            transform=all_transforms)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=batch_size,
                             shuffle=shuffle_test)
    return train_loader, test_loader


def get_mnist_dataloaders(batch_size=128, root='../ml-sandbox/data',
                          download=True, shuffle_test=True):
    """MNIST dataloader with (32, 32) images.

    Args:
        batch_size: number of samples per batch (default 128).
        root: directory holding the MNIST files (default keeps the original
            hard-coded location).
        download: download the data into ``root`` if it is missing.
        shuffle_test: shuffle the test loader (default True, as before); set
            False for a reproducible evaluation order.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    return _mnist_like_dataloaders(datasets.MNIST, root, batch_size,
                                   download, shuffle_test)


def get_fashion_mnist_dataloaders(batch_size=128,
                                  root='../ml-sandbox/fashion_data',
                                  download=True, shuffle_test=True):
    """FashionMNIST dataloader with (32, 32) images.

    Args:
        batch_size: number of samples per batch (default 128).
        root: directory holding the FashionMNIST files (default keeps the
            original hard-coded location).
        download: download the data into ``root`` if it is missing.
        shuffle_test: shuffle the test loader (default True, as before); set
            False for a reproducible evaluation order.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    return _mnist_like_dataloaders(datasets.FashionMNIST, root, batch_size,
                                   download, shuffle_test)
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Load the trained model\n",
"import torch\n",
"from models import CNN\n",
"\n",
"cnn = CNN()\n",
"cnn.load_state_dict(torch.load('mnist_cnn.pt', map_location=lambda storage, loc: storage))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Get test loader\n",
"from dataloaders import get_mnist_dataloaders\n",
"_, test_loader = get_mnist_dataloaders(1)"
]
},
{
"cell_type": "code",
"execution_count": 155,
"metadata": {},
"outputs": [],
"source": [
"# Calculate confidence scores, only for correctly classified digits\n",
"from torch.autograd import Variable\n",
"num_scores = 1e5\n",
"label_probs = []\n",
"low_conf = []\n",
"for i, (img, label) in enumerate(test_loader):\n",
" probs = cnn(Variable(img))\n",
" _, idx = torch.max(probs, 1)\n",
" # Only append if label was correctly predicted\n",
" if idx.data.numpy()[0] == label.numpy()[0]:\n",
" label_prob = probs[0, label[0]]\n",
" label_prob = label_prob.data[0]\n",
" label_probs.append(label_prob)\n",
" if label_prob < .8:\n",
" low_conf.append((img, label[0], label_prob))\n",
" if i >= num_scores:\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": 156,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x116332490>]"
]
},
"execution_count": 156,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAGKhJREFUeJzt3XtwnfV95/H3V3dZliXbkuWLZGyC\nwYiLIfES0jTh0kCMm0CbpI2ZTtN2aNxmQre7G9KBTZZN2dnpZCe7JNkh2bhTmmknhQW6bBzGs5AY\nUhpawCbGxpcYC9vYlmVLsu63c3TO+e4f55E5ErZ1bB2dR+c5n9eMRs/l53O+P/no459/z83cHRER\niZaSsAsQEZHcU7iLiESQwl1EJIIU7iIiEaRwFxGJIIW7iEgEKdxFRCJI4S4iEkEKdxGRCCoL640b\nGhp81apVYb29iEhBeuONN7rdvXG6dqGF+6pVq9i5c2dYby8iUpDM7N1s2mlaRkQkghTuIiIRpHAX\nEYkghbuISAQp3EVEImjacDezx82s08z2nme/mdl3zazNzPaY2QdzX6aIiFyMbEbuPwQ2XGD/XcCa\n4Gsz8P2ZlyUiIjMx7Xnu7v6yma26QJN7gL/z9PP6XjWzejNb5u4dOapRQpZMOfFEingiRSyRZDzl\npFJOyp1k8D3l6XbJlOMOSZ9YTn9PerD97LKTTPFeG8/4s1Pb+HuvkwqeCjn18ZATq45PWb/w/smv\nceE/m29OOG8cRn9De9hnSH+5v3F1E+ta6mf1PXJxEdMK4HjG+olg2/vC3cw2kx7ds3Llyhy8tWQj\nlkjSORCjo3+Mjv5ROvrH6B2JMzCa4MxQjKFYgtHxJKPxJGPjSUbHk0GQpwM9kdJzdiW6zPL/nksW\nVBVEuGfN3bcAWwDWr1+vxMiBkXiCd8+M0DkY48xQjFMDY3T0jdHRP8apgVFO9Y/RPRR/35+rKC1h\nQXUZi2oqWFBVzvzKMhrnV1JdUUplWQmVZenvFcFy+nt6vbzUKDGjtCT9vaTEKDWjxHhvuYTJbc4u\nZ7QJ2p2zzaTX5732ZhD8Mk78Uk78blqw4b31if02aZ0p+8/V5nyvmW9hBE/6fcPqseRKLsK9HWjJ\nWG8OtkkOJZIpTvSOcqR7mMPdwxztHuZAxwA73+19X9u66nKW1VWxtK6K61bUsXRBNcvqqlhWX8Wy\nuiqaFlRRW1UeQi9EJF9yEe5bgfvN7Engw0C/5ttnzt053jPK468c4eVDXRw7MzJpeqS2qozLG2r4\n0q0fYO3SWpoXVrNwXgVNC6qoqQztlkEiMkdMmwJm9gRwK9BgZieA/wyUA7j7/wK2ARuBNmAE+KPZ\nKjbqjveMsPtEH8+8cYJDp4do7xultMS47aolbLhmKasaari8oYbVDTUsqqnQf51F5LyyOVvm3mn2\nO/DlnFVUZA6eGuT7P29jT3s/h7uGAVhRX80NK+v544+t5va1S7hscU3IVYpIodH/30P07Z+9zfde\neoeq8hI+fPliPr++heub61m/aiHlpbp4WEQuncI9JL881su3f3aIT1y9hG9+9noWz68MuyQRiRCF\ne565O79o6+Zrz+5lSW0l39l0ow6AikjOKVXy6JfHevnLn+xn9/E+qstLefTz6xTsIjIrlCx5kkim\neODp3RzuGubBu9byRx9dRWVZadhliUhEKdzz5IX9pzncNcz3fu+DbLxuWdjliEjEKdxnWSyR5IGn\n97DtrQ5aFlXzyWuWhl2SiBQBhfsse+Qn+/nJ7pP87vpmvnTrFZSW6MIjEZl9CvdZkkw533upjR+9\ndow/ueVyHrrr6rBLEpEionCfJd/dfojvbD/Ex69s5Kt3XhV2OSJSZBTus+Cf3u7i+z9/hztbm/jB\n739I94ARkbzTNe6z4Af/9A5NdZX8t89dr2AX
kVAo3HOof2Scv33lCDuO9rDx2mXUz6sIuyQRKVKa\nlsmhx185wne2H6KirIS7dC67iIRI4Z4jyZTz3J6T3NBSzz9+6dd0yqOIhErTMjny8I/38k7XMHde\n06RgF5HQKdxz4MnXj/Gj145xR2sTf/LxD4RdjoiIwn2mjnQP8x+ffYvayjL+02+2atQuInOC5txn\noHsoxgNP76ayrJQXH7iVxlo9cENE5gaF+yUaG0+yacurtHUO8Y1PtyrYRWROUbhfgrHxJJ/6n7+g\nrXOIb372Oj7/b1aGXZKIyCSac79I48kUf/HMHgW7iMxpCveL9K3nD7J190keuPNKBbuIzFkK94sw\nnkzxxOvH+PS65dx/+5qwyxEROS+F+0X4l3fOMDCW4NPX69YCIjK3Kdwvwj+89i4L55Xz8Ssbwy5F\nROSCFO5ZGo4l2H6gk89+sJmq8tKwyxERuSCFe5Z2vttLIuXccpVG7SIy9yncszCeTPGjV9+lvNT4\n0GULwy5HRGRaCvcsfP3Zvbyw/zRfufMq5lXoui8RmfsU7tNIJFM8t+ckv/OhZv70Ft3xUUQKg8J9\nGvs7BhiOJ3WGjIgUFIX7NF4/0gPATasXhVyJiEj2sgp3M9tgZgfNrM3MHjzH/svMbLuZ7TGzn5tZ\nc+5LDcdrR3q4bPE8mhZUhV2KiEjWpg13MysFHgPuAlqBe82sdUqzbwF/5+7XA48Af5XrQsMwNp5k\nx9EeblqlUbuIFJZsRu43AW3uftjd48CTwD1T2rQCLwbLL51jf0H665cP0zcyzm/fuCLsUkRELko2\n4b4COJ6xfiLYlmk38Jlg+beBWjNbPPPywrXv5ABXLJnPr13REHYpIiIXJVcHVB8AbjGzXcAtQDuQ\nnNrIzDab2U4z29nV1ZWjt549PSNxFtdUhF2GiMhFyybc24GWjPXmYNtZ7n7S3T/j7jcCXwu29U19\nIXff4u7r3X19Y+PcP7WwZzjOIoW7iBSgbMJ9B7DGzFabWQWwCdia2cDMGsxs4rUeAh7PbZnh6B2O\ns1DhLiIFaNpwd/cEcD/wPHAAeMrd95nZI2Z2d9DsVuCgmb0NNAH/dZbqzZtUyunVtIyIFKisbpTi\n7tuAbVO2PZyx/AzwTG5LC1f/6Dgph4XzFO4iUnh0hep59IzEATTnLiIFSeF+Hp0DMQAWz1e4i0jh\nUbifx6HOQQDWLKkNuRIRkYuncD+PX50apK66nKYFlWGXIiJy0RTu53Cke5gf72rnqqW1mFnY5YiI\nXDSF+xTxRIovPP4aJSXGFz92edjliIhcEj0zbopDnYMc7xnl0c+v447WprDLERG5JBq5T/GrjvSB\n1OtW1IVciYjIpVO4T3GgY4DKshJWLa4JuxQRkUumcM+QSKZ4Yf9p1jXXU1aqH42IFC4lWIb//tO3\nOdYzwn0fWx12KSIiM6JwD4wnUzz+iyN8et1y7tSBVBEpcAr3wKHTQ8QSKT5x9RKd2y4iBU/hDrg7\nf/3PhwG4vrk+5GpERGZO4Q6cGhjj2V3ttCyqZtXieWGXIyIyYwp3oHd4HICvbbxaUzIiEgkKd2Bw\nLB3utVXlIVciIpIbCndgYCwBwAKFu4hEhMKdzJG7brUjItGgcAcGRhXuIhItCndgMJiW0Zy7iESF\nwh0YGBunqryEijL9OEQkGpRmpEfuOpgqIlGicCc9ctd8u4hESdGHezyRYm/7AEtqq8IuRUQkZ4o6\n3FMp56vP7OZYzwibP67npYpIdBTtXEQy5Tz2Uhs/fvMkD9x5JbetXRJ2SSIiOVOU4b77eB9fePx1\n+kfH2XjdUr582xVhlyQiklNFGe5//+q79I+O819+61o+c+MK3SxMRCKn6MLd3dl+4DS/dcNyfv/m\ny8IuR0RkVhTdAdXuoTi9I+Pc0KKHcohIdBVduB/vHQGgZZEeyiEi0VV84d6jcBeR6CvacG9eWB1y\nJSIisyer
cDezDWZ20MzazOzBc+xfaWYvmdkuM9tjZhtzX2puHD0zQsP8SuZVFN2xZBEpItOGu5mV\nAo8BdwGtwL1m1jql2deBp9z9RmAT8L1cF5orBzoGuHpZbdhliIjMqmxG7jcBbe5+2N3jwJPAPVPa\nOLAgWK4DTuauxNyJJ1IcOj1E67IF0zcWESlg2cxNrACOZ6yfAD48pc03gBfM7M+AGuATOakux450\nDxNPpmhdrnAXkWjL1QHVe4EfunszsBH4ezN732ub2WYz22lmO7u6unL01tnr6B8FoHmhzpQRkWjL\nJtzbgZaM9eZgW6b7gKcA3P1fgSqgYeoLufsWd1/v7usbGxsvreIZ6BqMAdA4vzLv7y0ikk/ZhPsO\nYI2ZrTazCtIHTLdOaXMM+A0AM7uadLjnf2g+je6hOAANtRUhVyIiMrumDXd3TwD3A88DB0ifFbPP\nzB4xs7uDZl8Bvmhmu4EngD90d5+toi9V12CMmopSnQYpIpGXVcq5+zZg25RtD2cs7wc+mtvScq97\nKEZDraZkRCT6iuoK1a7BmObbRaQoFFW4dw/FaFC4i0gRKKpw7xqK0ahpGREpAkUT7vFEir6RcY3c\nRaQoFE24nxlOn+Ou0yBFpBgUTbh3D6bPcdcBVREpBkUT7l1DYwA6FVJEikLRhHt7b/q+Mhq5i0gx\nKIpw7xmO860X3uYDjTUsq6sKuxwRkVlXFNfhb3n5MINj4zz9px+hrLQo/j0TkSIX+aRzd7a91cEt\nVzZyZZOewCQixSHy4f5Wez/Heka4/eqmsEsREcmbSId7KuX82yd2saS2kruuXRp2OSIieRPpcH/l\nnW6Onhnh659q1ZWpIlJUIh3u2w90Ul1eyiev0ZSMiBSXSIf7rmO9rGupo7KsNOxSRETyKrLhPhpP\nsu/kADeuXBh2KSIieRfZcH9h/ykSKedja973nG4RkciLbLj/v72nWF5Xxc2rF4ddiohI3kU23Dv6\nx/jAkvmUlFjYpYiI5F1kw/3McIzFNbp3u4gUp8iGe89QnEU1OrddRIpTJMN9bDzJcDzJ4vkauYtI\ncYpkuPcMp5+6tEjTMiJSpCIZ7meGFO4iUtwiGe5/+y9HAHRAVUSKVuTCvXsoxrO72rluRR3XrqgL\nuxwRkVBELtx/frALd/irz1xHVbnuKSMixSly4b7/5ADV5aVcs3xB2KWIiIQmcuHeNxpnUU0FZroy\nVUSKV+TCfWB0nLrq8rDLEBEJVeTCvW9E4S4iErlw7x8dp36ewl1Eilvkwr1P4S4ikl24m9kGMzto\nZm1m9uA59j9qZm8GX2+bWV/uS52eu9M/Os4CTcuISJErm66BmZUCjwF3ACeAHWa21d33T7Rx93+f\n0f7PgBtnodZpjY2niCdS1FfrylQRKW7ZjNxvAtrc/bC7x4EngXsu0P5e4IlcFHex+kbT95TRtIyI\nFLtswn0FcDxj/USw7X3M7DJgNfDiefZvNrOdZrazq6vrYmud1pHuYQCW11fn/LVFRApJrg+obgKe\ncffkuXa6+xZ3X+/u6xsbG3P81rC3vR+Aa3V1qogUuWzCvR1oyVhvDradyyZCmpIB2Ns+wPK6KhbP\n1xOYRKS4ZRPuO4A1ZrbazCpIB/jWqY3MbC2wEPjX3JaYvf0dA7Qu150gRUSmDXd3TwD3A88DB4Cn\n3H2fmT1iZndnNN0EPOnuPjulXlgskeRI9zBrl9aG8fYiInPKtKdCArj7NmDblG0PT1n/Ru7Kunht\nnUMkU85VCncRkehcoXrw1CCARu4iIkQs3CtKS1jdUBN2KSIioYtMuB84NcgVS+ZTVhqZLomIXLLI\nJOHBUwOakhERCUQm3LsGYyyrrwq7DBGROSES4Z5Ipkg5VJbpgdgiIhCRcI8nUwBUlEWiOyIiMxaJ\nNIwngnDXwVQRESAi4R4Lwr2yPBLdERGZsUikoUbuIiKTRSINJ0bumnMXEU
mLRBrGEunbx+tsGRGR\ntEiE+8S0TKVG7iIiQMTCXdMyIiJpkUjDmEbuIiKTRCINNXIXEZksEmmoK1RFRCaLRBq+d0BVZ8uI\niEBEwn3iVEiN3EVE0iKRhjoVUkRkskikoa5QFRGZLBJpGNO9ZUREJolEGmpaRkRkskikYTyZoqK0\nBDMLuxQRkTkhEuEeG09pvl1EJEMkEnFwbJyaSp3jLiIyIRLhfrx3hJaF88IuQ0RkzohGuPeMsnKR\nwl1EZELBh3s8keJk/ygtCncRkbMKPtzb+0ZxR+EuIpKh4MN9b3s/AFc11YZciYjI3FHw4b7rWB9V\n5SWsXaZwFxGZUPDh/lZ7H9cur6Nctx4QETkrq0Q0sw1mdtDM2szswfO0+V0z229m+8zsH3Jb5vmd\nGY7TtKAqX28nIlIQyqZrYGalwGPAHcAJYIeZbXX3/Rlt1gAPAR91914zWzJbBU81HEvoAiYRkSmy\nGbnfBLS5+2F3jwNPAvdMafNF4DF37wVw987clnl+I7EkNZXT/hslIlJUsgn3FcDxjPUTwbZMVwJX\nmtkrZvaqmW3IVYEX4u4MxxPMV7iLiEySq1QsA9YAtwLNwMtmdp2792U2MrPNwGaAlStXzvhNR8eT\npByN3EVEpshm5N4OtGSsNwfbMp0Atrr7uLsfAd4mHfaTuPsWd1/v7usbGxsvteazhmIJAGoqNOcu\nIpIpm3DfAawxs9VmVgFsArZOafN/SY/aMbMG0tM0h3NY5zkNx9IPxtbIXURksmnD3d0TwP3A88AB\n4Cl332dmj5jZ3UGz54EzZrYfeAn4qrufma2iJwxPjNwV7iIik2SViu6+Ddg2ZdvDGcsO/IfgK28m\nwl0HVEVEJivoyzqH4xq5i4icS0GH+1Aw5z5fFzGJiExS0OE+MS0zr0IjdxGRTAUd7u29o5SWGI21\nlWGXIiIypxR0uB/uHmLlonm6I6SIyBQFnYqHu4ZZ3VATdhkiInNOwYZ7IpniSPcwlyvcRUTep2DD\n/dXDPcQSKT502cKwSxERmXMKNtxf2H+KeRWl3LY2b7eOFxEpGAUb7kfPjLBmyXyqynWOu4jIVAUb\n7qf6R/V4PRGR8yjgcB9jWZ3CXUTkXAoy3IdjCQbGEiytqw67FBGROakgw/3UwBgAS+t0ZaqIyLkU\nZLif7g/CfYFG7iIi51KQ4d4RhLvm3EVEzq0gw/29aRmFu4jIuRRkuHf0j1I/r1znuIuInEdBhvup\n/hhLdY67iMh5FWS4d/SPakpGROQCCi7c+0fGOXhqkLVLF4RdiojInFVw4b79V6dJpJxPXtMUdiki\nInNWwYV7bVU5d7Q2sa65PuxSRETmrIJ7svQdrU3c0apRu4jIhRTcyF1ERKancBcRiSCFu4hIBCnc\nRUQiSOEuIhJBCncRkQhSuIuIRJDCXUQkgszdw3ljsy7g3Uv84w1Adw7LKQTqc3FQn4vDTPp8mbs3\nTtcotHCfCTPb6e7rw64jn9Tn4qA+F4d89FnTMiIiEaRwFxGJoEIN9y1hFxAC9bk4qM/FYdb7XJBz\n7iIicmGFOnIXEZELKLhwN7MNZnbQzNrM7MGw68kVM3vczDrNbG/GtkVm9lMzOxR8XxhsNzP7bvAz\n2GNmHwyv8ktnZi1m9pKZ7TezfWb258H2yPbbzKrM7HUz2x30+S+D7avN7LWgb//bzCqC7ZXBeluw\nf1WY9V8qMys1s11m9lywHun+ApjZUTN7y8zeNLOdwba8fbYLKtzNrBR4DLgLaAXuNbPWcKvKmR8C\nG6ZsexDY7u5rgO3BOqT7vyb42gx8P0815loC+Iq7twI3A18O/j6j3O8YcLu7rwNuADaY2c3AN4FH\n3f0KoBe4L2h/H9AbbH80aFeI/hw4kLEe9f5OuM3db8g47TF/n213L5gv4CPA8xnrDwEPhV1XDvu3\nCtibsX4QWBYsLwMOBss/AO49V7tC/g
J+DNxRLP0G5gG/BD5M+oKWsmD72c858DzwkWC5LGhnYdd+\nkf1sDoLsduA5wKLc34x+HwUapmzL22e7oEbuwArgeMb6iWBbVDW5e0ewfAqYeL5g5H4OwX+/bwRe\nI+L9DqYo3gQ6gZ8C7wB97p4ImmT262yfg/39wOL8Vjxj3wb+AkgF64uJdn8nOPCCmb1hZpuDbXn7\nbBfcM1SLlbu7mUXy1CYzmw/8I/Dv3H3AzM7ui2K/3T0J3GBm9cCzwNqQS5o1ZvYpoNPd3zCzW8Ou\nJ89+3d3bzWwJ8FMz+1Xmztn+bBfayL0daMlYbw62RdVpM1sGEHzvDLZH5udgZuWkg/1H7v5/gs2R\n7zeAu/cBL5Gelqg3s4nBVma/zvY52F8HnMlzqTPxUeBuMzsKPEl6auY7RLe/Z7l7e/C9k/Q/4jeR\nx892oYX7DmBNcKS9AtgEbA25ptm0FfiDYPkPSM9JT2z/QnCE/WagP+O/egXD0kP0vwEOuPv/yNgV\n2X6bWWMwYsfMqkkfYzhAOuQ/FzSb2ueJn8XngBc9mJQtBO7+kLs3u/sq0r+vL7r77xHR/k4wsxoz\nq51YBu4E9pLPz3bYBx0u4SDFRuBt0vOUXwu7nhz26wmgAxgnPd92H+m5xu3AIeBnwKKgrZE+a+gd\n4C1gfdj1X2Kff530vOQe4M3ga2OU+w1cD+wK+rwXeDjYfjnwOtAGPA1UBturgvW2YP/lYfdhBn2/\nFXiuGPob9G938LVvIqvy+dnWFaoiIhFUaNMyIiKSBYW7iEgEKdxFRCJI4S4iEkEKdxGRCFK4i4hE\nkMJdRCSCFO4iIhH0/wFmefIpm8xTwgAAAABJRU5ErkJggg==\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x10bb532d0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot lowest confidence scores\n",
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"plt.plot(np.sort(label_probs)[:500])"
]
},
{
"cell_type": "code",
"execution_count": 98,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(761, 767)\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x10b2188d0>"
]
},
"execution_count": 98,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAFh1JREFUeJzt3Xts1dWWB/DvopbyqoXSUsr7LfLw\nIlaECIqYaxhiooaJr2A0IZebiSaa3DFRJ44a5w/vZMT4x8QJDuRyb1RA0YgjOheNxogKFJSH4CCP\nAuXVKvKGQsuaP86PpJDfWj09j9+h7O8nIZzudXZ/u7929dfzW2fvLaoKIgpPp0IPgIgKg8lPFCgm\nP1GgmPxEgWLyEwWKyU8UKCY/UaCY/ESBYvITBeqabDqLyEwArwMoAvDfqvqK9/yKigodMmRIbKy5\nudnsd+zYsdj2kpISs0+PHj28oRBdlerq6vDLL79IOs/NOPlFpAjAfwL4PYB6AOtEZIWqbrX6DBky\nBOvWrYuNHTlyxDzWxx9/HNs+dOhQs8/UqVPNmEckrfOWE95bqzMdRyZv107ya6b8qqmpSfu52fzZ\nPwnADlXdparnACwBcE8Wn4+IEpRN8vcHsK/Vx/VRGxF1AHm/4Sci80SkVkRqGxsb8304IkpTNsm/\nH8DAVh8PiNouoaoLVLVGVWsqKyuzOBwR5VI2yb8OwEgRGSoinQE8CGBFboZFRPmW8d1+VW0WkScA\n/C9Spb5FqvpjGv1i29evX2/2+fbbb2Pbz549a/YpLi42Y1a5EQCqqqrMGO+K09Ukqzq/qq4EsDJH\nYyGiBPEdfkSBYvITBYrJTxQoJj9RoJj8RIHK6m5/ezU1NWHPnj2xsVWrVpn99u3bF9t+6NAhs8/q\n1avNWFlZmRmbN2+eGRs7dmxsO0uA1BHxyk8UKCY/UaCY/ESBYvITBYrJTxSoRO/2nzlzBhs3boyN\neWvu9e8fv0bI+fPnzT7XXnutGbMmCgFAfX29GXvppZdi28eNG2f2YSWArlS88hMFislPFCgmP1Gg\nmPxEgWLyEwWKyU8UqERLfUVFRSgtLY2NWe0A0NDQENt+4cKFjMZx/PhxM3bu3DkzNn/+/Nj2p59+\n2uwzevRoM8YyIBUSr/xEgWLyEwWKyU8UKCY/UaCY/ESBYvITBSqrUp+I1AE4AaAFQLOq1njPLy0t\nxZ133hkbGzlypNmvb9++se2ffvqp2Wf37t1m7OTJk2asW7duZuzo0aOx7W+//bbZ5/HHHzdj1tdF\nlIRc1PnvUNVfcvB5iChB/LOfKFDZJr8C+LuIrBcRe81rIrriZPtn/1RV3S8ifQCsEpGfVPWr1k+I\nfinMA4BBgwZleTgiypWsrvyquj/6vwHABwAmxTxngarWqGpNZWVlNocjohzKOPlFpLuIlF58DOAu\nAFtyNTAiyq9s/uyvAvBBNDPtGgBvq6pde2uD95Jg9uzZse1Hjhwx+yxbtiyjcXgLf544cSK2fefO\nnWYfrwzobQ3mzXIkyoWMk19VdwH4XQ7HQkQJYqmPKFBMfqJAMfmJAsXkJwoUk58oUIku4JmpoqKi\n2Pbbb7/d7FNWVmbGrAVBAaCxsdGMdeoU/7vS2zPQm3nYvXt3MzZnzhwz5u1rSJQuXvmJAsXkJwoU\nk58oUEx+okAx+YkC1SHu9l9zTfwwhw8fbvaZOHGiGfviiy/MWFNTkxk7ffp0bHvnzp3NPn369DFj\nq1evNmPeHf3bbrut3cfztgazzi9gV1oAQFXNmKUjbFGWydcFAC0tLWbMm4Rmnf8uXbqYfbxYunjl\nJwoUk58oUEx+okAx+YkCxeQnChSTnyhQHaLUZ+natasZu/XWW83YunXrzFhzc3O7Y14fb4LRiBEj\nzJg36aeurs6M7dmzJ7bdK7F5W5R54/DWOywv
L49t98qi+WCV7c6dO2f2OXv2rBnzzqN17gFgw4YN\nZuyGG26IbR84cKDZx/qZ88qNl+OVnyhQTH6iQDH5iQLF5CcKFJOfKFBMfqJAtVnqE5FFAO4G0KCq\n46K2cgBLAQwBUAfgflX9LX/DNMdmxqzyCQCUlJSYMa9UYs3q87bWOnnyZEbjOHbsmBnbtGmTGbPK\nn15Zrrq6ut2fry1WKc0ri3qzC73SnDfzsG/fvrHtFy5cMPvs3bvXjGVaMrV+dgBg7dq1se3ez441\ne9NbT/Jy6Vz5/wJg5mVtzwD4XFVHAvg8+piIOpA2k19VvwJw+WTkewAsjh4vBnBvjsdFRHmW6Wv+\nKlU9GD0+hNSOvUTUgWR9w09T7580lz4RkXkiUisitd6a+ESUrEyT/7CIVANA9L+5C4aqLlDVGlWt\nqayszPBwRJRrmSb/CgCPRo8fBfBhboZDRElJp9T3DoDpACpEpB7ACwBeAbBMROYC2APg/nwOMhP9\n+/c3Y9aMM8BfGNGakVZcXGz28co127dvN2PelmJeuax3795mzOItWFlVZd/O8c6jxZvVV19fb8as\nchgA7N6924yNGjUqtn3ChAlmH+/cr1mzxowNGjTIjHmlZ6vs6M2otM5jexZIbTP5VfUhI3Rn2kch\noisO3+FHFCgmP1GgmPxEgWLyEwWKyU8UqA6xgGcme6f169fPjHnlq8OHD5sxq7TlzZjzSn0HDx40\nY6dOnTJjmSy46c089M6vt8ecNzPOilVUVJh9vP0VvQVZvRmQn332WWz75s2bzT7Dhg0zY1OnTjVj\n3jtYv/vuOzNmzWb0SqnWG+baU+rjlZ8oUEx+okAx+YkCxeQnChSTnyhQTH6iQHWIUp9VvvBKVN7s\ntmnTppmx9957z4xZZTRvlp2375sX88po3sKfR48ebfexvFLlmTNnzJi3uKc103Hfvn1mH+/rsmbn\nAcDo0aPN2OTJk2Pb3333XbPP119/bcasBUEB/+fAW8vCKmN6Mw+tBTy9GaaX45WfKFBMfqJAMfmJ\nAsXkJwoUk58oUB3ibr/Fm8TgrRU3e/ZsM/bll1+aMWvLKG/dv+uvv96M7d+/34x5E4K87alOnDgR\n297U1GT28cbvbaHlbQ1lfc6ePXuafbw71V6VoK6uzoxZE4m8apA3mcmrfgwcONCMeevxDR48OLbd\nuqMP2OeKE3uIqE1MfqJAMfmJAsXkJwoUk58oUEx+okCls13XIgB3A2hQ1XFR24sA/gDg4qJlz6nq\nynwNMhPepJ+RI0eaMW+Ntk8++SS2vVevXmYfr2TnTajxnD592ox5k0ss3nqB3rEymWB04MABs0+n\nTva1yCtheWXAcePGxbZ7E6e8MrE3sccrA3olTmtrOa/02Z6SniWdK/9fAMyMaX9NVSdE/66oxCei\ntrWZ/Kr6FQD7XQ9E1CFl85r/CRHZJCKLRMT+u5eIrkiZJv8bAIYDmADgIIBXrSeKyDwRqRWRWm9d\ncyJKVkbJr6qHVbVFVS8AeBPAJOe5C1S1RlVrvNVMiChZGSW/iFS3+vA+AFtyMxwiSko6pb53AEwH\nUCEi9QBeADBdRCYAUAB1AP6YxzHmnFfKmTNnjhlbv359bLtXDvNKfV4pp6ysLKPPaZU4vZl73vi9\ndfq8WX1WGdPr433N3hqEx48fN2Pbtm2Lbfdm4N18881mzJuJ+euvv5qxGTNmmLEePXqYsXxqM/lV\n9aGY5oV5GAsRJYjv8CMKFJOfKFBMfqJAMfmJAsXkJwpUh17A0+PNevJm/I0YMcKM3XXXXbHtS5cu\nTX9grbS0tGTUL5NFML3FIHfu3GnGvFmC3mxGa4zejErv8+3YscOMrVu3zoxZZV2vdNjQ0GDGvLLo\npEnme90wZcoUM2bNZvR+TpOa1UdEVyEmP1GgmPxEgWLyEwWKyU8UKCY/UaCu2lJfprwSysMPPxzb\n7u1nt3z5
8qzHdDlvwc3S0tLYdm9xzH79+pkxb6bagAEDzJg1U80aHwCMGTPGjHl7Hlp7KALATz/9\nFNvulfNGjx5txu69914zdt1115kxb1al9TPnfc9ygVd+okAx+YkCxeQnChSTnyhQTH6iQAV5tz/T\nSREVFRWx7XPnzjX7eOsFLlmyxIw1NTWZse7du5uxbt26xbYfOWLvu+KtJehta1VfX2/Gxo8fH9vu\nTYzZvn27GRs0aJAZmz59uhmzKhLeunnTpk0zY17VIdOfq6Kiooz6ZYtXfqJAMfmJAsXkJwoUk58o\nUEx+okAx+YkClc52XQMB/BVAFVLbcy1Q1ddFpBzAUgBDkNqy635V/S1/Q01GJuWakpISM+Zt/+VN\n3Fi5cqUZ89afs2Le1+VtyeWVFb1xWKVFb2LP0aNHzZhXDvPOv7U5rFfezHRrM6+s642xUNK58jcD\n+JOqjgEwGcDjIjIGwDMAPlfVkQA+jz4mog6izeRX1YOquiF6fALANgD9AdwDYHH0tMUA7LmORHTF\naddrfhEZAuBGAGsAVKnqwSh0CKmXBUTUQaSd/CLSA8ByAE+p6iV7ImtqgfHYRcZFZJ6I1IpIbWNj\nY1aDJaLcSSv5RaQYqcR/S1Xfj5oPi0h1FK8GELs0iqouUNUaVa2xbr4QUfLaTH5J3SZeCGCbqs5v\nFVoB4NHo8aMAPsz98IgoX9KZ1XcrgEcAbBaRH6K25wC8AmCZiMwFsAfA/fkZ4pXDKpdlWkZ77LHH\nzNhNN91kxp5//nkzdvLkydj23r17m328NQgzWXvO62eND/C3L/NmF3r9rBKht7WWV470vmav1JeL\n7bVyrc3kV9WvAVgjvzO3wyGipPAdfkSBYvITBYrJTxQoJj9RoJj8RIEKcgHPJHklHq/EZi2ACQCz\nZs0yYwsXLoxt92bneTPOvK3BvDLg3r17Y9sHDx5s9jl//rwZ80qEXjn1lltuiW0fO3as2adXr15m\nzPuedTS88hMFislPFCgmP1GgmPxEgWLyEwWKyU8UqKunbnGV8UpKDzzwgBn7/vvvY9t//vlns483\nG82bTdfc3GzGrNJcXV1dRuMYMWKEGZs8ebIZmzJlSmz7tddea/a5Emfg5QOv/ESBYvITBYrJTxQo\nJj9RoJj8RIHi3f48S61qnlvl5eVm7Mknn4xtf/XVV80+3h14r+owbtw4M3bHHXfEti9fvtzsM3Pm\nTDM2Y8YMM5bpOoMW73t2NVUCeOUnChSTnyhQTH6iQDH5iQLF5CcKFJOfKFBtlvpEZCCAvyK1BbcC\nWKCqr4vIiwD+AODi1rvPqerKfA00KR29zGOtTffyyy+bfZ599lkzdvbsWTM2d+5cM2aVIz/66COz\nz7Bhw8yYV87zWN9Pb8JSpt/njvDz0Vo6df5mAH9S1Q0iUgpgvYisimKvqep/5G94RJQv6ezVdxDA\nwejxCRHZBqB/vgdGRPnVrtf8IjIEwI0A1kRNT4jIJhFZJCL2esdEdMVJO/lFpAeA5QCeUtXjAN4A\nMBzABKT+Moh9/6iIzBORWhGpbWxsjHsKERVAWskvIsVIJf5bqvo+AKjqYVVtUdULAN4EELvhuaou\nUNUaVa2prKzM1biJKEttJr+kbmEuBLBNVee3aq9u9bT7AGzJ/fCIKF/Sudt/K4BHAGwWkR+itucA\nPCQiE5Aq/9UB+GNeRpgHXjkvk5hX4snHrL5MlJaWmrG7777bjHlr5/Xvb9/33b9/f2x7nz59zD7V\n1dVmrKioyIx5Olr5LUnp3O3/GkDcGezwNX2ikPEdfkSBYvITBYrJTxQoJj9RoJj8RIG6ahfwzEeJ\nzZoJ1tLS0u4+QOZj9MpX1hZaW7duNfscOXLEjHnvylyzZo0Z27NnT2z7uXPnzD6nTp0yY2VlZWbM\nY52rfMzc62hlRV75iQLF5CcKFJOfKFBMfqJAMfmJAsXkJwrUVVvq83glNq
80d/78+Xa1A8Dx48fN\nmFWWA4Bu3bqZMa9cZu27980335h9Bg8ebMZ27dplxrwSYUlJSWx79+7dzT67d+82Yz179jRj3n6C\n1mxAb5bg1VTO8/DKTxQoJj9RoJj8RIFi8hMFislPFCgmP1GgWOq7jDdDzyqxnT592uyzYcMGM7Z2\n7Voz5pWvvHLZgQMHYtu9Pfe8Y40fPz6jcXTt2jW23SoBen0AoKmpyYx55VnreJmW+jz5mKWZT7zy\nEwWKyU8UKCY/UaCY/ESBYvITBarNu/0i0gXAVwBKoue/p6oviMhQAEsA9AawHsAjqmrPOEmYdwe1\nUyf7d14m/by75RMmTDBj5eXlZmzHjh1mzGOtdde3b1+zj3eXvaqqyox5d+6Li4tj2zt37mz26dKl\nixnz+nkx63uW6d3+jnZH35POlb8JwAxV/R1S23HPFJHJAP4M4DVVHQHgNwBz8zdMIsq1NpNfU05G\nHxZH/xTADADvRe2LAdyblxESUV6k9ZpfRIqiHXobAKwCsBPAUVW9OCG9HoC9ZSsRXXHSSn5VbVHV\nCQAGAJgEYHS6BxCReSJSKyK13hrwRJSsdt3tV9WjAL4AMAVATxG5eKdrAIDYDdlVdYGq1qhqTWVl\nZVaDJaLcaTP5RaRSRHpGj7sC+D2AbUj9EvjH6GmPAvgwX4MkotxLZ2JPNYDFIlKE1C+LZar6PyKy\nFcASEfk3AN8DWJjHceaUV3bxynZW2cgqawF++cqbGDNq1Cgzlkkpypv84pW9cl0W9Y7lncdMvi9A\nstt15WOLuHxqM/lVdROAG2PadyH1+p+IOiC+w48oUEx+okAx+YkCxeQnChSTnyhQkmR5QkQaAeyJ\nPqwA8EtiB7dxHJfiOC7V0cYxWFXTejddosl/yYFFalW1piAH5zg4Do6Df/YThYrJTxSoQib/ggIe\nuzWO41Icx6Wu2nEU7DU/ERUW/+wnClRBkl9EZorI/4nIDhF5phBjiMZRJyKbReQHEalN8LiLRKRB\nRLa0aisXkVUi8nP0f68CjeNFEdkfnZMfRGRWAuMYKCJfiMhWEflRRJ6M2hM9J844Ej0nItJFRNaK\nyMZoHC9F7UNFZE2UN0tFxF65NB2qmug/AEVILQM2DEBnABsBjEl6HNFY6gBUFOC4twGYCGBLq7Z/\nB/BM9PgZAH8u0DheBPDPCZ+PagATo8elALYDGJP0OXHGkeg5ASAAekSPiwGsATAZwDIAD0bt/wXg\nn7I5TiGu/JMA7FDVXZpa6nsJgHsKMI6CUdWvABy5rPkepBZCBRJaENUYR+JU9aCqbogen0BqsZj+\nSPicOONIlKbkfdHcQiR/fwD7Wn1cyMU/FcDfRWS9iMwr0BguqlLVg9HjQwDsBfPz7wkR2RS9LMj7\ny4/WRGQIUutHrEEBz8ll4wASPidJLJob+g2/qao6EcA/AHhcRG4r9ICA1G9+pH4xFcIbAIYjtUfD\nQQCvJnVgEekBYDmAp1T1eOtYkuckZhyJnxPNYtHcdBUi+fcDGNjqY3Pxz3xT1f3R/w0APkBhVyY6\nLCLVABD931CIQajq4egH7wKAN5HQORGRYqQS7i1VfT9qTvycxI2jUOckOna7F81NVyGSfx2AkdGd\ny84AHgSwIulBiEh3ESm9+BjAXQC2+L3yagVSC6ECBVwQ9WKyRe5DAudEUgvjLQSwTVXntwolek6s\ncSR9ThJbNDepO5iX3c2chdSd1J0A/qVAYxiGVKVhI4AfkxwHgHeQ+vPxPFKv3eYitefh5wB+BvAZ\ngPICjeNvADYD2IRU8lUnMI6pSP1JvwnAD9G/WUmfE2cciZ4TADcgtSjuJqR+0fxrq5/ZtQB2AHgX\nQEk2x+E7/IgCFfoNP6JgMfmJAsXkJwoUk58oUEx+okAx+YkCxeQnChSTnyhQ/w8QHD824eMvegAA\nAABJRU5ErkJggg==\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x10b0843d0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Grayscale chicken image\n",
"from scipy.misc import imread, imresize, imsave\n",
"img = imread('chicken.png', flatten=True) # Convert to grayscale\n",
"imsave('chicken_bw.png', img)\n",
"img = imresize(img, (32, 32))\n",
"imsave('chicken_bw_small.png', img)\n",
"img = img / 255.\n",
"plt.imshow(img, cmap='gray')"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [],
"source": [
"# Pass image through CNN\n",
"chicken_tensor = torch.Tensor(np.expand_dims(np.expand_dims(img, 0), 0))\n",
"probs = cnn(Variable(chicken_tensor))"
]
},
{
"cell_type": "code",
"execution_count": 113,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 5.21423068e-26 5.14779191e-16 8.10420988e-06 3.05335366e-16\n",
" 6.03861604e-31 9.99991894e-01 1.32513748e-12 2.47835812e-22\n",
" 8.77226292e-10 4.37692135e-28]\n"
]
}
],
"source": [
"print(probs.data[0].numpy())"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [],
"source": [
"bad_fives = []\n",
"for ex in low_conf:\n",
" if ex[1] == 5:\n",
" bad_fives.append(ex)"
]
},
{
"cell_type": "code",
"execution_count": 114,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x10c064a10>"
]
},
"execution_count": 114,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAEitJREFUeJzt3XmMVWWax/HvYwmCgBRriYiDC2bA\nZpFUcO+0jXSUmKjJaNSM8Q/TdCZtgonGuCTTzPyjPRk1/uUEl7RNxGVcIppWKFfQIMqiAg00SzBN\nCZSNFousVTzzxz1kCnOfU7eq7lLl+/skpG69z33rvBz41bn3vPe8x9wdEUnPKbUegIjUhsIvkiiF\nXyRRCr9IohR+kUQp/CKJUvhFEqXwiyRK4RdJ1Kk96Wxm1wJPAnXAM+7+aCfP18cJRSrM3a2U51l3\nP95rZnXA34BZwA7gC+A2d/9rTh+FX6TCSg1/T172zwC2uPs2dz8KvATc0IOfJyJV1JPwjwX+3uH7\nHVmbiPQBPXrPXwozmwPMqfR2RKRrehL+ZmBch+/PztpO4u7zgfmg9/wivUlPXvZ/AUwws3PNrD9w\nK7CoPMMSkUrr9pHf3dvM7G5gMYWpvufcfX3ZRiYiFdXtqb5ubUwv+0UqrhpTfSLShyn8IolS+EUS\npfCLJErhF0mUwi+SKIVfJFEKv0iiFH6RRCn8IolS+EUSpfCLJErhF0mUwi+SKIVfJFEKv0iiFH6R\nRCn8Iomq+NLdPyennFL8d2VDQ0PY59JLLw1rZ555Zlg7dOhQWNuzZ09YO3jwYNH2urq6sE99fX1Y\ny/u7Rfsjz8aNG8PamjVrwlpLS0uXtyX5dOQXSZTCL5IohV8kUQq/SKIUfpFEKfwiierRVJ+ZbQf2\nA+1Am7s3lmNQvVU0XTZ58uSwz/333x/WpkyZEtb27dsX1r755puw1traWrS9f//+YZ8xY8aEtQkT\nJoS1vCnHAwcOFG1fuHBh2Gfr1q1hTVN95VeOef6r3f0fZfg5IlJFetkvkqieht+BJWa2yszmlGNA\nIlIdPX3Zf6W7N5vZaKDJzDa6+9KOT8h+KegXg0gv06Mjv7s3Z19bgDeAGUWeM9/dG3/uJwNF+ppu\nh9/MBpnZkBOPgd8A68o1MBGprJ687G8A3jCzEz9nobu/W5ZR9TFDhw4Na3nTaAMGDAhrR44cCWtn\nnXVWl7eX/TsVdfz48bC2efPmsLZs2bIu9/vggw/CPnlTmFJ+3Q6/u28DppZxLCJSRZrqE0mUwi+S\nKIVfJFEKv0iiFH6RRGkBzy6IpsTyFp7MW4gzb/ptyZIlYa2pqSmstbe3h7VIW1tbWFu9enVY27Zt\nW1g7fPhwl8ch1aUjv0iiFH6RRCn8IolS+EUSpfCLJEpn+7sgOpOed9Y77wIddw9r0W23IH92Yd26\n8l5YmTd7kHdBkPR+OvKLJErhF0mUwi+SKIVfJFEKv0iiFH6RRGmqrwzyprw2bNgQ1s4555ywdskl\nl4S15cuXh7W8C3FEOtKRXyRRCr9IohR+kUQp/CKJUvhFEqXwiySq06k+M3sOuB5ocfdfZG3DgZeB\n8cB24BZ3/6Fyw+y7Fi1aFNYmTpwY1s4999ywNmnSpLA2YsSIou179uwJ+0iaSjny/wm49idtDwDv\nu/sE4P3sexHpQzoNv7svBb7/SfMNwPPZ4+eBG8s8LhGpsO6+529w953Z410U7tgrIn1Ijz/e6+5u\nZuGSNGY2B5jT0+2ISHl198i/28zGAGRfW6Inuvt8d29098ZubktEKqC74V8E3Jk9vhN4szzDEZFq\nKWWq70XgV8BIM9sB/AF4FHjFzO4CvgFuqeQg+7KPPvoorF1//fVhbfz48WGtoSE+xTJq1Kii7Zrq\nk5/qNPzufltQmlnmsYhIFekTfiKJUvhFEqXw
iyRK4RdJlMIvkigt4FlhO3fuDGubN28Oa3v37g1r\nedOAkydPLtq+cePGsI+kSUd+kUQp/CKJUvhFEqXwiyRK4RdJlMIvkihN9dXQ+vXrw1pzc3NYu/DC\nC8PazJnFr7dqamoK+7S2toa17qqvry/aPnLkyLDPWWedFdaGDRsW1tzDtWQ4fPhw0fZVq1aFffL2\nR3t7e1jra3TkF0mUwi+SKIVfJFEKv0iiFH6RROlsfw0tXbo0rF133XVh7bLLLgtrF110UdH2vFt8\nrVmzJqwNGTKkW7Xp06cXbb/88svDPldffXVYmzp1alhra2sLay0txReWfvDBB8M+ixcvDmvfffdd\nWDt+/HhY64105BdJlMIvkiiFXyRRCr9IohR+kUQp/CKJKuV2Xc8B1wMt7v6LrG0e8FvgxLzHQ+7+\nl0oN8udq9+7dYe2HH34Ia3kXl5xxxhlF26MpQAAzC2uzZs0Ka3lTc1OmTCnaPmjQoLBP3pTdgQMH\nwlrehT2jR48u2v7YY4+FfY4cORLW8i6Qyvs3yxtjrZRy5P8TcG2R9ifcfVr2R8EX6WM6Db+7LwW+\nr8JYRKSKevKe/24z+9rMnjOz+GJrEemVuhv+p4DzgWnATiB8A2Vmc8xspZmt7Oa2RKQCuhV+d9/t\n7u3ufhx4GpiR89z57t7o7o3dHaSIlF+3wm9mYzp8exOwrjzDEZFqsc6mIMzsReBXwEhgN/CH7Ptp\ngAPbgd+5e3xfqv//Wb1vvqOXmjt3brdq48aNK9p+9OjRsE/eFFv//v3DWl1dXVg79dTis8h56xYu\nWbIkrH3yySfdGscjjzxStD3vlmfRlYAA9913X1h76623wlreVGW5uXs8d9tBp/P87n5bkeZnuzwi\nEelV9Ak/kUQp/CKJUvhFEqXwiyRK4RdJlBbw7KU+/fTTsNbYGH9e6vbbby/aPnDgwLBP3sKT+/fv\nD2ufffZZWIum7fJuk7V9+/awlnfF3CmnxMewHTt2FG1/+OGHwz5XXnllWIv2L8C3334b1j7++OOw\nVis68oskSuEXSZTCL5IohV8kUQq/SKIUfpFEaaqvl4quigPo169fWIsW49y7d2/Y5+233w5rr776\nalhrbm4Oa9G0V96UXd7CmXnTkXkLkK5evbpo+7x588I+CxcuDGt59xqcPHlyWFuxYkVYO3z4cFir\nJB35RRKl8IskSuEXSZTCL5IohV8kUTrb30uNHDkyrI0YMSKsRevxRRe4ACxYsCCsffjhh2Ht2LFj\nYa2a8tahjNYuzLvA6ODBg2Gtvr4+rE2cODGsnX322WFty5YtYa2SdOQXSZTCL5IohV8kUQq/SKIU\nfpFEKfwiiep0qs/MxgF/Bhoo3J5rvrs/aWbDgZeB8RRu2XWLu8dXbUiX5F3kkreuXnSxzfLly8M+\nn3/+eVjLu5VXXxCt79fQ0BD2ybuoKk/e1Gdv3I+lHPnbgHvdfRJwKfB7M5sEPAC87+4TgPez70Wk\nj+g0/O6+091XZ4/3AxuAscANwPPZ054HbqzUIEWk/Lr0nt/MxgMXAyuAhg535t1F4W2BiPQRJb+5\nMbPBwGvAPe6+r+MCCu7u0e23zWwOMKenAxWR8irpyG9m/SgE/wV3fz1r3m1mY7L6GKDoTc3dfb67\nN7p7fKcJEam6TsNvhUP8s8AGd3+8Q2kRcGf2+E7gzfIPT0QqpZSX/VcAdwBrzezLrO0h4FHgFTO7\nC/gGuKUyQ+zbTj/99LB24YUXhrW8K8vefffdsLZ58+ai7e+9917Yp7W1Naz1ddG03axZs8I+Q4cO\nDWt56+1t2rQprOXdiqxWOg2/u38CRCskzizvcESkWvQJP5FEKfwiiVL4RRKl8IskSuEXSZQW8Kyw\n8847L6zde++9Ye3jjz8Oay+99FJYa29vL9r+448/hn36urzbl40bN65o+x133BH2GTZsWFjbunVr\nWGtpKfo5
t15LR36RRCn8IolS+EUSpfCLJErhF0mUwi+SKE31Vdj06dPD2rRp08Ja3tRc3mKc0VV9\n/fv3D/vkOX78eFiLFseE/PvnRTouENOVbeXdB2/u3LlF26+44oqwT96+yruicu3atWGtN9KRXyRR\nCr9IohR+kUQp/CKJUvhFEqWz/RU2duzYsJa3vt8111zTre199tlnXe6Td0Z/27ZtYS26aAZgz549\nRdvz/s719fVhLe/2WjNnxqvJXXXVVUXb887oL1u2LKy98847Ya03rtOXR0d+kUQp/CKJUvhFEqXw\niyRK4RdJlMIvkqhOp/rMbBzwZwq34HZgvrs/aWbzgN8C32VPfcjd/1KpgfZVeRfotLW1hbXx48eH\ntZtvvjmszZ49u2h73kUzeVN9x44dC2vRrbAgXksw7wKdurq6sJa3Tt+gQYPC2pEjR4q2L1iwIOzz\nzDPPhLX169eHtbx91RuVMs/fBtzr7qvNbAiwysyastoT7v7flRueiFRKKffq2wnszB7vN7MNQPzJ\nFRHpE7r0nt/MxgMXAyuyprvN7Gsze87M4vWORaTXKTn8ZjYYeA24x933AU8B5wPTKLwyeCzoN8fM\nVprZyjKMV0TKpKTwm1k/CsF/wd1fB3D33e7e7u7HgaeBGcX6uvt8d29098ZyDVpEeq7T8FvhNPGz\nwAZ3f7xD+5gOT7sJWFf+4YlIpZRytv8K4A5grZl9mbU9BNxmZtMoTP9tB35XkRH2cYsXLw5ro0aN\nCmtTp04Na3lXpA0cOLBo+/Dhw8M+eevt7dq1K6zljX/o0KFF2/fv3x/2ia4EhHjKDmDv3r1hramp\nqWh73jqIW7Zs6dY4+ppSzvZ/AhSbJNacvkgfpk/4iSRK4RdJlMIvkiiFXyRRCr9Ioqw7t1Xq9sbM\nqrexXiLvyrcLLrggrI0ePTqs5V39dtpppxVtHzx4cNgn7/9Aa2trWDvjjDPC2oABA4q2Hz16NOxz\n4MCBsJbX79ChQ2Ft06ZNRdvz/l7VzEQluHt8CWcHOvKLJErhF0mUwi+SKIVfJFEKv0iiFH6RRGmq\nT+RnRlN9IpJL4RdJlMIvkiiFXyRRCr9IohR+kUQp/CKJUvhFEqXwiyRK4RdJlMIvkiiFXyRRpdyr\nb4CZfW5mX5nZejP7j6z9XDNbYWZbzOxlM4vvISUivU4pR/4jwK/dfSqF23Ffa2aXAn8EnnD3C4Af\ngLsqN0wRKbdOw+8FJ5ZV7Zf9ceDXwKtZ+/PAjRUZoYhUREnv+c2sLrtDbwvQBGwFWt29LXvKDmBs\nZYYoIpVQUvjdvd3dpwFnAzOAfy51A2Y2x8xWmtnKbo5RRCqgS2f73b0V+BC4DKg3sxN3pDgbaA76\nzHf3Rndv7NFIRaSsSjnbP8rM6rPHA4FZwAYKvwT+JXvancCblRqkiJRfp2v4mdkUCif06ij8snjF\n3f/TzM4DXgKGA2uAf3X3I538LK3hJ1Jhpa7hpwU8RX5mtICniORS+EUSpfCLJErhF0mUwi+SqFM7\nf0pZ/QP4Jns8Mvu+1jSOk2kcJ+tr4/inUn9gVaf6Ttqw2cre8Kk/jUPjSHUcetkvkiiFXyRRtQz/\n/BpuuyON42Qax8l+tuOo2Xt+EaktvewXSVRNwm9m15rZpmzxzwdqMYZsHNvNbK2ZfVnNxUbM7Dkz\nazGzdR3ahptZk5ltzr4Oq9E45plZc7ZPvjSz2VUYxzgz+9DM/potEjs3a6/qPskZR1X3SdUWzXX3\nqv6hcGnwVuA8oD/wFTCp2uPIxrIdGFmD7f4SmA6s69D2X8AD2eMHgD/WaBzzgPuqvD/GANOzx0OA\nvwGTqr1PcsZR1X0CGDA4e9wPWAFcCrwC3Jq1/w/wbz3ZTi2O/DOALe6+zd2PUlgT4IYajKNm3H0p\n8P1Pmm+gsG4CVGlB1GAcVefuO919dfZ4P4XFYsZS5X2SM46q8oKKL5pbi/
CPBf7e4ftaLv7pwBIz\nW2Vmc2o0hhMa3H1n9ngX0FDDsdxtZl9nbwsq/vajIzMbD1xM4WhXs33yk3FAlfdJNRbNTf2E35Xu\nPh24Dvi9mf2y1gOCwm9+Cr+YauEp4HwK92jYCTxWrQ2b2WDgNeAed9/XsVbNfVJkHFXfJ96DRXNL\nVYvwNwPjOnwfLv5Zae7enH1tAd6gsJNrZbeZjQHIvrbUYhDuvjv7j3cceJoq7RMz60chcC+4++tZ\nc9X3SbFx1GqfZNvu8qK5papF+L8AJmRnLvsDtwKLqj0IMxtkZkNOPAZ+A6zL71VRiygshAo1XBD1\nRNgyN1GFfWJmBjwLbHD3xzuUqrpPonFUe59UbdHcap3B/MnZzNkUzqRuBR6u0RjOozDT8BWwvprj\nAF6k8PLxGIX3bncBI4D3gc3Ae8DwGo1jAbAW+JpC+MZUYRxXUnhJ/zXwZfZndrX3Sc44qrpPgCkU\nFsX9msIvmn/v8H/2c2AL8L/AaT3Zjj7hJ5Ko1E/4iSRL4RdJlMIvkiiFXyRRCr9IohR+kUQp/CKJ\nUvhFEvV/qeJPz11SHToAAAAASUVORK5CYII=\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x10bc4e210>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.imshow(bad_fives[3][0][0, 0, :, :].numpy(), cmap='gray')"
]
},
{
"cell_type": "code",
"execution_count": 127,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
],
"text/vnd.plotly.v1+html": [
"<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import plotly\n",
"import plotly.graph_objs as go\n",
"import plotly.offline as py\n",
"from plotly import tools\n",
"\n",
"py.init_notebook_mode(connected=True)\n",
"\n",
"# Plot the softmaxes\n",
"def bar_plot(x, y, filename='bar_plot.html', interactive=True):\n",
" layout = go.Layout(xaxis=dict(title='Digit'), yaxis=dict(range=[0, 1], title='p(y|x)'), \n",
" width=800, height=400, font=dict(size=18))\n",
" fig = go.Figure(data=[go.Bar(x=x, y=y)], layout=layout)\n",
" if interactive:\n",
" py.iplot(fig)\n",
" else:\n",
" py.plot(fig, filename=filename, auto_open=False)\n",
"\n",
"bar_plot(range(10), probs.data[0].numpy(), 'true_softmax.html', interactive=False)\n",
"approx_unif = 5 + np.random.rand(10)\n",
"bar_plot(range(10), approx_unif / np.sum(approx_unif), 'expected_softmax.html', interactive=False)"
]
},
{
"cell_type": "code",
"execution_count": 128,
"metadata": {},
"outputs": [],
"source": [
"# Test that image of chicken is not just a fluke\n",
"import dataloaders\n",
"reload(dataloaders)\n",
"from dataloaders import get_fashion_mnist_dataloaders\n",
"_, fashion_loader = get_fashion_mnist_dataloaders(1)"
]
},
{
"cell_type": "code",
"execution_count": 136,
"metadata": {},
"outputs": [],
"source": [
"# Calculate confidence scores for fashion data\n",
"from torch.autograd import Variable\n",
"num_scores = 1e5\n",
"max_probs = []\n",
"max_prob_idx = []\n",
"for i, (img, _) in enumerate(fashion_loader):\n",
" probs = cnn(Variable(img))\n",
" max_prob, idx = torch.max(probs, 1)\n",
" max_prob_idx.append(idx)\n",
" max_prob = max_prob.data[0]\n",
" max_probs.append(max_prob)\n",
" if i >= num_scores:\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": 182,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.6347\n",
"0.7428\n",
"0.8889\n"
]
}
],
"source": [
"# Which fraction of fashion mnist images get classified with more than 99% accuracy?\n",
"print(np.sum(np.array(max_probs) > 0.99) / float(len(max_probs)))\n",
"print(np.sum(np.array(max_probs) > 0.95) / float(len(max_probs)))\n",
"print(np.sum(np.array(max_probs) > 0.75) / float(len(max_probs)))"
]
},
{
"cell_type": "code",
"execution_count": 160,
"metadata": {},
"outputs": [],
"source": [
"# Merge the two lists and sort them and see where different things end up\n",
"fashion_and_mnist_probs = np.array(max_probs + label_probs)\n",
"sorted_idx = np.argsort(fashion_and_mnist_probs)"
]
},
{
"cell_type": "code",
"execution_count": 161,
"metadata": {},
"outputs": [],
"source": [
"# sorted_idx = sorted_idx[::100]"
]
},
{
"cell_type": "code",
"execution_count": 175,
"metadata": {},
"outputs": [],
"source": [
"subsample = 200\n",
"x = range(len(sorted_idx))[::subsample]\n",
"y = fashion_and_mnist_probs[sorted_idx][::subsample]\n",
"color_bool = sorted_idx[::subsample] >= 10000 # 1 if it is mnist, 0 if it is fashion\n",
"colors = ['red' if b < 0.5 else 'blue' for b in color_bool]\n",
"\n",
"def big_bar_plot(x, y, colors=colors, filename='bar_plot.html', interactive=True):\n",
" #layout = go.Layout(xaxis=dict(title='Digit'), yaxis=dict(range=[0, 1], title='p(y|x)'), \n",
" # width=800, height=400, font=dict(size=18))\n",
" fig = go.Figure(data=[go.Bar(x=x, y=y, marker=dict(color=colors))])#, layout=layout)\n",
" if interactive:\n",
" py.iplot(fig)\n",
" else:\n",
" py.plot(fig, filename=filename, auto_open=False)"
]
},
{
"cell_type": "code",
"execution_count": 176,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, False, False, False,\n",
" False, False, False, True, False, False, False, False, False,\n",
" False, True, True, False, False, False, False, False, True,\n",
" True, False, False, True, False, False, False, True, True,\n",
" True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, False,\n",
" True, True, True, True, True, True, True, True, True,\n",
" True, True, True, False, True, True, False, False, False,\n",
" False, False, False, False, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True, True], dtype=bool)"
]
},
"execution_count": 176,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"color_bool"
]
},
{
"cell_type": "code",
"execution_count": 177,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.plotly.v1+json": {
"data": [
{
"marker": {
"color": [
"red",
"red",
"red",
"red",
"red",
"red",
"red",
"red",
"red",
"red",
"red",
"red",
"red",
"red",
"red",
"red",
"red",
"red",
"red",
"red",
"red",
"blue",
"red",
"red",
"red",
"red",
"red",
"red",
"blue",
"blue",
"red",
"red",
"red",
"red",
"red",
"blue",
"blue",
"red",
"red",
"blue",
"red",
"red",
"red",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue",
"red",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue",
"red",
"blue",
"blue",
"red",
"red",
"red",
"red",
"red",
"red",
"red",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue",
"blue"
]
},
"type": "bar",
"x": [
0,
200,
400,
600,
800,
1000,
1200,
1400,
1600,
1800,
2000,
2200,
2400,
2600,
2800,
3000,
3200,
3400,
3600,
3800,
4000,
4200,
4400,
4600,
4800,
5000,
5200,
5400,
5600,
5800,
6000,
6200,
6400,
6600,
6800,
7000,
7200,
7400,
7600,
7800,
8000,
8200,
8400,
8600,
8800,
9000,
9200,
9400,
9600,
9800,
10000,
10200,
10400,
10600,
10800,
11000,
11200,
11400,
11600,
11800,
12000,
12200,
12400,
12600,
12800,
13000,
13200,
13400,
13600,
13800,
14000,
14200,
14400,
14600,
14800,
15000,
15200,
15400,
15600,
15800,
16000,
16200,
16400,
16600,
16800,
17000,
17200,
17400,
17600,
17800,
18000,
18200,
18400,
18600,
18800,
19000,
19200,
19400,
19600,
19800
],
"y": [
0.3061355650424957,
0.5143805146217346,
0.5719403624534607,
0.6228542327880859,
0.6767367124557495,
0.7216569781303406,
0.7659032940864563,
0.8029114603996277,
0.8369762301445007,
0.8670023083686829,
0.8938217759132385,
0.9159621000289917,
0.9324735403060913,
0.9476518034934998,
0.9592661261558533,
0.9686580300331116,
0.9767616987228394,
0.982664942741394,
0.9874541759490967,
0.9906601309776306,
0.9931192994117737,
0.9949137568473816,
0.9965328574180603,
0.9975701570510864,
0.998376727104187,
0.998912513256073,
0.9992625713348389,
0.9994962811470032,
0.9996829628944397,
0.9997949004173279,
0.9998713731765747,
0.9999213218688965,
0.9999499320983887,
0.9999693632125854,
0.9999820590019226,
0.9999905824661255,
0.9999957084655762,
0.9999978542327881,
0.9999989867210388,
0.999999463558197,
0.9999997615814209,
0.9999998807907104,
0.9999999403953552,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1
]
}
],
"layout": {}
},
"text/html": [
"<div id=\"7961d549-988b-4b21-961e-e2afd1f88ee7\" style=\"height: 525px; width: 100%;\" class=\"plotly-graph-div\"></div><script type=\"text/javascript\">require([\"plotly\"], function(Plotly) { window.PLOTLYENV=window.PLOTLYENV || {};window.PLOTLYENV.BASE_URL=\"https://plot.ly\";Plotly.newPlot(\"7961d549-988b-4b21-961e-e2afd1f88ee7\", [{\"marker\": {\"color\": [\"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"blue\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"blue\", \"blue\", \"red\", \"red\", \"red\", \"red\", \"red\", \"blue\", \"blue\", \"red\", \"red\", \"blue\", \"red\", \"red\", \"red\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"red\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"red\", \"blue\", \"blue\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\"]}, \"y\": [0.3061355650424957, 0.5143805146217346, 0.5719403624534607, 0.6228542327880859, 0.6767367124557495, 0.7216569781303406, 0.7659032940864563, 0.8029114603996277, 0.8369762301445007, 0.8670023083686829, 0.8938217759132385, 0.9159621000289917, 0.9324735403060913, 0.9476518034934998, 0.9592661261558533, 0.9686580300331116, 0.9767616987228394, 0.982664942741394, 0.9874541759490967, 0.9906601309776306, 0.9931192994117737, 0.9949137568473816, 0.9965328574180603, 0.9975701570510864, 0.998376727104187, 0.998912513256073, 0.9992625713348389, 0.9994962811470032, 0.9996829628944397, 0.9997949004173279, 0.9998713731765747, 0.9999213218688965, 0.9999499320983887, 
0.9999693632125854, 0.9999820590019226, 0.9999905824661255, 0.9999957084655762, 0.9999978542327881, 0.9999989867210388, 0.999999463558197, 0.9999997615814209, 0.9999998807907104, 0.9999999403953552, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], \"type\": \"bar\", \"x\": [0, 200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000, 2200, 2400, 2600, 2800, 3000, 3200, 3400, 3600, 3800, 4000, 4200, 4400, 4600, 4800, 5000, 5200, 5400, 5600, 5800, 6000, 6200, 6400, 6600, 6800, 7000, 7200, 7400, 7600, 7800, 8000, 8200, 8400, 8600, 8800, 9000, 9200, 9400, 9600, 9800, 10000, 10200, 10400, 10600, 10800, 11000, 11200, 11400, 11600, 11800, 12000, 12200, 12400, 12600, 12800, 13000, 13200, 13400, 13600, 13800, 14000, 14200, 14400, 14600, 14800, 15000, 15200, 15400, 15600, 15800, 16000, 16200, 16400, 16600, 16800, 17000, 17200, 17400, 17600, 17800, 18000, 18200, 18400, 18600, 18800, 19000, 19200, 19400, 19600, 19800]}], {}, {\"linkText\": \"Export to plot.ly\", \"showLink\": true})});</script>"
],
"text/vnd.plotly.v1+html": [
"<div id=\"7961d549-988b-4b21-961e-e2afd1f88ee7\" style=\"height: 525px; width: 100%;\" class=\"plotly-graph-div\"></div><script type=\"text/javascript\">require([\"plotly\"], function(Plotly) { window.PLOTLYENV=window.PLOTLYENV || {};window.PLOTLYENV.BASE_URL=\"https://plot.ly\";Plotly.newPlot(\"7961d549-988b-4b21-961e-e2afd1f88ee7\", [{\"marker\": {\"color\": [\"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"blue\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"blue\", \"blue\", \"red\", \"red\", \"red\", \"red\", \"red\", \"blue\", \"blue\", \"red\", \"red\", \"blue\", \"red\", \"red\", \"red\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"red\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"red\", \"blue\", \"blue\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\"]}, \"y\": [0.3061355650424957, 0.5143805146217346, 0.5719403624534607, 0.6228542327880859, 0.6767367124557495, 0.7216569781303406, 0.7659032940864563, 0.8029114603996277, 0.8369762301445007, 0.8670023083686829, 0.8938217759132385, 0.9159621000289917, 0.9324735403060913, 0.9476518034934998, 0.9592661261558533, 0.9686580300331116, 0.9767616987228394, 0.982664942741394, 0.9874541759490967, 0.9906601309776306, 0.9931192994117737, 0.9949137568473816, 0.9965328574180603, 0.9975701570510864, 0.998376727104187, 0.998912513256073, 0.9992625713348389, 0.9994962811470032, 0.9996829628944397, 0.9997949004173279, 0.9998713731765747, 0.9999213218688965, 0.9999499320983887, 
0.9999693632125854, 0.9999820590019226, 0.9999905824661255, 0.9999957084655762, 0.9999978542327881, 0.9999989867210388, 0.999999463558197, 0.9999997615814209, 0.9999998807907104, 0.9999999403953552, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], \"type\": \"bar\", \"x\": [0, 200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000, 2200, 2400, 2600, 2800, 3000, 3200, 3400, 3600, 3800, 4000, 4200, 4400, 4600, 4800, 5000, 5200, 5400, 5600, 5800, 6000, 6200, 6400, 6600, 6800, 7000, 7200, 7400, 7600, 7800, 8000, 8200, 8400, 8600, 8800, 9000, 9200, 9400, 9600, 9800, 10000, 10200, 10400, 10600, 10800, 11000, 11200, 11400, 11600, 11800, 12000, 12200, 12400, 12600, 12800, 13000, 13200, 13400, 13600, 13800, 14000, 14200, 14400, 14600, 14800, 15000, 15200, 15400, 15600, 15800, 16000, 16200, 16400, 16600, 16800, 17000, 17200, 17400, 17600, 17800, 18000, 18200, 18400, 18600, 18800, 19000, 19200, 19400, 19600, 19800]}], {}, {\"linkText\": \"Export to plot.ly\", \"showLink\": true})});</script>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"big_bar_plot(x, y, colors)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.2"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
import torch.nn as nn
class CNN(nn.Module):
    """Small convolutional classifier for 32x32 single-channel images.

    Three stride-2 convolutions (kernel 4, padding 1) halve the spatial size
    each time (32 -> 16 -> 8 -> 4), so ``forward`` expects input of shape
    (N, 1, 32, 32) and returns (N, 10) class probabilities that sum to 1
    along dim 1.

    Submodule names are kept stable because they determine the
    ``state_dict`` keys stored in checkpoints such as ``mnist_cnn.pt``.
    """

    def __init__(self):
        super().__init__()
        self.img_to_features = nn.Sequential(
            nn.Conv2d(1, 16, (4, 4), stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, (4, 4), stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, (4, 4), stride=2, padding=1),
            nn.ReLU()
        )
        self.features_to_probs = nn.Sequential(
            nn.Linear(64 * 4 * 4, 10),
            # Softmax over the class dimension. The implicit-dim form already
            # picks dim=1 for 2-D input, but it is deprecated and warns on
            # every call.
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        """Map images of shape (N, 1, 32, 32) to class probabilities (N, 10)."""
        features = self.img_to_features(x)
        # Flatten the (N, 64, 4, 4) feature maps to (N, 1024) for the linear layer.
        probs = self.features_to_probs(features.view(features.size(0), -1))
        return probs
import torch
import torch.nn as nn
from dataloaders import get_mnist_dataloaders
from models import CNN
from torch.autograd import Variable
def eval_model(model, data):
    """Return the classification accuracy of ``model`` on ``data``.

    Args:
        model: Module mapping a batch of inputs to per-class scores of shape
            (batch, num_classes); the predicted class is the argmax along dim 1.
        data: Iterable of ``(imgs, labels)`` batches, e.g. a ``DataLoader``.

    Returns:
        Fraction of examples whose predicted class equals the label, as a
        float. Raises ZeroDivisionError if ``data`` yields no examples.
    """
    # Run on whatever device the model lives on instead of assuming CUDA.
    device = next(model.parameters()).device
    correct = 0
    total = 0
    # Evaluation needs no gradients; skipping them saves memory and time.
    with torch.no_grad():
        for imgs, labels in data:
            probs = model(imgs.to(device))
            _, predicted = torch.max(probs, 1)
            total += labels.size(0)
            # Compare on CPU, where the labels are; .item() keeps a Python int.
            correct += (predicted.cpu() == labels).sum().item()
    return float(correct) / total
def train_epoch(model, data, *, opt=None, loss_fn=None, log_every=100):
    """Run one pass of optimization over every batch in ``data``.

    Args:
        model: Module to train. Batches are moved to the device of its first
            parameter.
        data: Iterable of ``(imgs, labels)`` batches, e.g. a ``DataLoader``.
        opt: Optimizer to step. Defaults to the module-level ``optimizer``.
        loss_fn: Callable ``loss_fn(model_output, labels)`` returning a scalar
            loss tensor. Defaults to the module-level ``criterion``.
        log_every: Print the current loss after every ``log_every`` batches.
    """
    # Fall back to the script's globals so existing two-argument calls work.
    if opt is None:
        opt = optimizer
    if loss_fn is None:
        loss_fn = criterion
    # Follow the model's device instead of hard-coding CUDA.
    device = next(model.parameters()).device
    for i, (imgs, labels) in enumerate(data):
        imgs, labels = imgs.to(device), labels.to(device)
        opt.zero_grad()
        loss = loss_fn(model(imgs), labels)
        loss.backward()
        opt.step()
        if (i + 1) % log_every == 0:
            # .item() replaces loss.data[0], which raises IndexError on the
            # 0-dim loss tensors returned by PyTorch 1.0+.
            print("Iteration: {}, Loss: {}".format(i, loss.item()))
# Get datasets
train_loader, test_loader = get_mnist_dataloaders(batch_size=100)

# Create model; use the GPU when one is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cnn = CNN().to(device)


def probs_nll_loss(probs, labels):
    """Cross-entropy loss for a model whose output is already softmax probabilities.

    CNN ends in nn.Softmax, and nn.CrossEntropyLoss expects unnormalized
    logits. Passing CNN's output to it would apply softmax a second time.
    Taking the log here and using NLL gives the intended cross-entropy.
    Clamping keeps log() finite, and gradients free of NaN, if a probability
    underflows to 0.
    """
    return nn.functional.nll_loss(torch.log(probs.clamp(min=1e-12)), labels)


# Loss and Optimizer
criterion = probs_nll_loss
optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-3)

# Train the Model
for epoch in range(10):
    train_epoch(cnn, train_loader)
    test_acc = eval_model(cnn, test_loader)
    print("Epoch {}, Test Accuracy: {}\n".format(epoch + 1, test_acc))

# Save the Trained Model
torch.save(cnn.state_dict(), 'mnist_cnn.pt')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment