Code for the MNIST chicken experiments. Note that this code was written a long time ago and is no longer maintained.
Last active
September 12, 2020 18:44
-
-
Save EmilienDupont/99c7127dedb921a5a1f96d37d23c0d4b to your computer and use it in GitHub Desktop.
mnist chicken code
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.utils.data import DataLoader | |
from torchvision import datasets, transforms | |
def get_mnist_dataloaders(batch_size=128, data_dir='../ml-sandbox/data',
                          shuffle_test=False):
    """Return MNIST train/test dataloaders with (32, 32) images.

    Images are resized with ``transforms.Resize(32)`` (MNIST images are
    square, so the result is 32x32) and converted to tensors with
    ``transforms.ToTensor()``.

    Args:
        batch_size: Number of images per batch for both loaders.
        data_dir: Root directory holding the MNIST files; they are
            downloaded there if missing.
        shuffle_test: Whether to shuffle the test loader. Defaults to False
            so evaluation passes visit test images in a fixed, reproducible
            order. The training loader is always shuffled.

    Returns:
        A ``(train_loader, test_loader)`` tuple of ``DataLoader`` objects.
    """
    all_transforms = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
    ])
    train_data = datasets.MNIST(data_dir, train=True, download=True,
                                transform=all_transforms)
    # Also pass download=True for the test split so it does not rely on the
    # train split having fetched the files first. torchvision skips the
    # download when the data is already present.
    test_data = datasets.MNIST(data_dir, train=False, download=True,
                               transform=all_transforms)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    # Evaluation does not need a random order; a fixed order makes results
    # that index into the test stream reproducible.
    test_loader = DataLoader(test_data, batch_size=batch_size,
                             shuffle=shuffle_test)
    return train_loader, test_loader
def get_fashion_mnist_dataloaders(batch_size=128,
                                  data_dir='../ml-sandbox/fashion_data',
                                  shuffle_test=False):
    """Return FashionMNIST train/test dataloaders with (32, 32) images.

    Images are resized with ``transforms.Resize(32)`` (FashionMNIST images
    are square, so the result is 32x32) and converted to tensors with
    ``transforms.ToTensor()``.

    Args:
        batch_size: Number of images per batch for both loaders.
        data_dir: Root directory holding the FashionMNIST files; they are
            downloaded there if missing.
        shuffle_test: Whether to shuffle the test loader. Defaults to False
            so evaluation passes visit test images in a fixed, reproducible
            order. The training loader is always shuffled.

    Returns:
        A ``(train_loader, test_loader)`` tuple of ``DataLoader`` objects.
    """
    all_transforms = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
    ])
    train_data = datasets.FashionMNIST(data_dir, train=True, download=True,
                                       transform=all_transforms)
    # Also pass download=True for the test split so it does not rely on the
    # train split having fetched the files first. torchvision skips the
    # download when the data is already present.
    test_data = datasets.FashionMNIST(data_dir, train=False, download=True,
                                      transform=all_transforms)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    # Evaluation does not need a random order; a fixed order makes results
    # that index into the test stream reproducible.
    test_loader = DataLoader(test_data, batch_size=batch_size,
                             shuffle=shuffle_test)
    return train_loader, test_loader
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Load the trained model\n", | |
"import torch\n", | |
"from models import CNN\n", | |
"\n", | |
"cnn = CNN()\n", | |
"cnn.load_state_dict(torch.load('mnist_cnn.pt', map_location=lambda storage, loc: storage))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Get test loader\n", | |
"from dataloaders import get_mnist_dataloaders\n", | |
"_, test_loader = get_mnist_dataloaders(1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 155, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Calculate confidence scores, only for correctly classified digits\n", | |
"from torch.autograd import Variable\n", | |
"num_scores = 1e5\n", | |
"label_probs = []\n", | |
"low_conf = []\n", | |
"for i, (img, label) in enumerate(test_loader):\n", | |
" probs = cnn(Variable(img))\n", | |
" _, idx = torch.max(probs, 1)\n", | |
" # Only append if label was correctly predicted\n", | |
" if idx.data.numpy()[0] == label.numpy()[0]:\n", | |
" label_prob = probs[0, label[0]]\n", | |
" label_prob = label_prob.data[0]\n", | |
" label_probs.append(label_prob)\n", | |
" if label_prob < .8:\n", | |
" low_conf.append((img, label[0], label_prob))\n", | |
" if i >= num_scores:\n", | |
" break" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 156, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[<matplotlib.lines.Line2D at 0x116332490>]" | |
] | |
}, | |
"execution_count": 156, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAGKhJREFUeJzt3XtwnfV95/H3V3dZliXbkuWLZGyC\nwYiLIfES0jTh0kCMm0CbpI2ZTtN2aNxmQre7G9KBTZZN2dnpZCe7JNkh2bhTmmknhQW6bBzGs5AY\nUhpawCbGxpcYC9vYlmVLsu63c3TO+e4f55E5ErZ1bB2dR+c5n9eMRs/l53O+P/no459/z83cHRER\niZaSsAsQEZHcU7iLiESQwl1EJIIU7iIiEaRwFxGJIIW7iEgEKdxFRCJI4S4iEkEKdxGRCCoL640b\nGhp81apVYb29iEhBeuONN7rdvXG6dqGF+6pVq9i5c2dYby8iUpDM7N1s2mlaRkQkghTuIiIRpHAX\nEYkghbuISAQp3EVEImjacDezx82s08z2nme/mdl3zazNzPaY2QdzX6aIiFyMbEbuPwQ2XGD/XcCa\n4Gsz8P2ZlyUiIjMx7Xnu7v6yma26QJN7gL/z9PP6XjWzejNb5u4dOapRQpZMOfFEingiRSyRZDzl\npFJOyp1k8D3l6XbJlOMOSZ9YTn9PerD97LKTTPFeG8/4s1Pb+HuvkwqeCjn18ZATq45PWb/w/smv\nceE/m29OOG8cRn9De9hnSH+5v3F1E+ta6mf1PXJxEdMK4HjG+olg2/vC3cw2kx7ds3Llyhy8tWQj\nlkjSORCjo3+Mjv5ROvrH6B2JMzCa4MxQjKFYgtHxJKPxJGPjSUbHk0GQpwM9kdJzdiW6zPL/nksW\nVBVEuGfN3bcAWwDWr1+vxMiBkXiCd8+M0DkY48xQjFMDY3T0jdHRP8apgVFO9Y/RPRR/35+rKC1h\nQXUZi2oqWFBVzvzKMhrnV1JdUUplWQmVZenvFcFy+nt6vbzUKDGjtCT9vaTEKDWjxHhvuYTJbc4u\nZ7QJ2p2zzaTX5732ZhD8Mk78Uk78blqw4b31if02aZ0p+8/V5nyvmW9hBE/6fcPqseRKLsK9HWjJ\nWG8OtkkOJZIpTvSOcqR7mMPdwxztHuZAxwA73+19X9u66nKW1VWxtK6K61bUsXRBNcvqqlhWX8Wy\nuiqaFlRRW1UeQi9EJF9yEe5bgfvN7Engw0C/5ttnzt053jPK468c4eVDXRw7MzJpeqS2qozLG2r4\n0q0fYO3SWpoXVrNwXgVNC6qoqQztlkEiMkdMmwJm9gRwK9BgZieA/wyUA7j7/wK2ARuBNmAE+KPZ\nKjbqjveMsPtEH8+8cYJDp4do7xultMS47aolbLhmKasaari8oYbVDTUsqqnQf51F5LyyOVvm3mn2\nO/DlnFVUZA6eGuT7P29jT3s/h7uGAVhRX80NK+v544+t5va1S7hscU3IVYpIodH/30P07Z+9zfde\neoeq8hI+fPliPr++heub61m/aiHlpbp4WEQuncI9JL881su3f3aIT1y9hG9+9noWz68MuyQRiRCF\ne565O79o6+Zrz+5lSW0l39l0ow6AikjOKVXy6JfHevnLn+xn9/E+qstLefTz6xTsIjIrlCx5kkim\neODp3RzuGubBu9byRx9dRWVZadhliUhEKdzz5IX9pzncNcz3fu+DbLxuWdjliEjEKdxnWSyR5IGn\n97DtrQ5aFlXzyWuWhl2SiBQBhfsse+Qn+/nJ7pP87vpmvnTrFZSW6MIjEZl9CvdZkkw533upjR+9\ndow/ueVyHrrr6rBLEpEionCfJd/dfojvbD/Ex69s5Kt3XhV2OSJSZBTus+Cf3u7i+z9/hztbm/jB\n739I94ARkbzTNe6z4Af/9A5NdZX8t89dr2AX
kVAo3HOof2Scv33lCDuO9rDx2mXUz6sIuyQRKVKa\nlsmhx185wne2H6KirIS7dC67iIRI4Z4jyZTz3J6T3NBSzz9+6dd0yqOIhErTMjny8I/38k7XMHde\n06RgF5HQKdxz4MnXj/Gj145xR2sTf/LxD4RdjoiIwn2mjnQP8x+ffYvayjL+02+2atQuInOC5txn\noHsoxgNP76ayrJQXH7iVxlo9cENE5gaF+yUaG0+yacurtHUO8Y1PtyrYRWROUbhfgrHxJJ/6n7+g\nrXOIb372Oj7/b1aGXZKIyCSac79I48kUf/HMHgW7iMxpCveL9K3nD7J190keuPNKBbuIzFkK94sw\nnkzxxOvH+PS65dx/+5qwyxEROS+F+0X4l3fOMDCW4NPX69YCIjK3Kdwvwj+89i4L55Xz8Ssbwy5F\nROSCFO5ZGo4l2H6gk89+sJmq8tKwyxERuSCFe5Z2vttLIuXccpVG7SIy9yncszCeTPGjV9+lvNT4\n0GULwy5HRGRaCvcsfP3Zvbyw/zRfufMq5lXoui8RmfsU7tNIJFM8t+ckv/OhZv70Ft3xUUQKg8J9\nGvs7BhiOJ3WGjIgUFIX7NF4/0gPATasXhVyJiEj2sgp3M9tgZgfNrM3MHjzH/svMbLuZ7TGzn5tZ\nc+5LDcdrR3q4bPE8mhZUhV2KiEjWpg13MysFHgPuAlqBe82sdUqzbwF/5+7XA48Af5XrQsMwNp5k\nx9EeblqlUbuIFJZsRu43AW3uftjd48CTwD1T2rQCLwbLL51jf0H665cP0zcyzm/fuCLsUkRELko2\n4b4COJ6xfiLYlmk38Jlg+beBWjNbPPPywrXv5ABXLJnPr13REHYpIiIXJVcHVB8AbjGzXcAtQDuQ\nnNrIzDab2U4z29nV1ZWjt549PSNxFtdUhF2GiMhFyybc24GWjPXmYNtZ7n7S3T/j7jcCXwu29U19\nIXff4u7r3X19Y+PcP7WwZzjOIoW7iBSgbMJ9B7DGzFabWQWwCdia2cDMGsxs4rUeAh7PbZnh6B2O\ns1DhLiIFaNpwd/cEcD/wPHAAeMrd95nZI2Z2d9DsVuCgmb0NNAH/dZbqzZtUyunVtIyIFKisbpTi\n7tuAbVO2PZyx/AzwTG5LC1f/6Dgph4XzFO4iUnh0hep59IzEATTnLiIFSeF+Hp0DMQAWz1e4i0jh\nUbifx6HOQQDWLKkNuRIRkYuncD+PX50apK66nKYFlWGXIiJy0RTu53Cke5gf72rnqqW1mFnY5YiI\nXDSF+xTxRIovPP4aJSXGFz92edjliIhcEj0zbopDnYMc7xnl0c+v447WprDLERG5JBq5T/GrjvSB\n1OtW1IVciYjIpVO4T3GgY4DKshJWLa4JuxQRkUumcM+QSKZ4Yf9p1jXXU1aqH42IFC4lWIb//tO3\nOdYzwn0fWx12KSIiM6JwD4wnUzz+iyN8et1y7tSBVBEpcAr3wKHTQ8QSKT5x9RKd2y4iBU/hDrg7\nf/3PhwG4vrk+5GpERGZO4Q6cGhjj2V3ttCyqZtXieWGXIyIyYwp3oHd4HICvbbxaUzIiEgkKd2Bw\nLB3utVXlIVciIpIbCndgYCwBwAKFu4hEhMKdzJG7brUjItGgcAcGRhXuIhItCndgMJiW0Zy7iESF\nwh0YGBunqryEijL9OEQkGpRmpEfuOpgqIlGicCc9ctd8u4hESdGHezyRYm/7AEtqq8IuRUQkZ4o6\n3FMp56vP7OZYzwibP67npYpIdBTtXEQy5Tz2Uhs/fvMkD9x5JbetXRJ2SSIiOVOU4b77eB9fePx1\n+kfH2XjdUr582xVhlyQiklNFGe5//+q79I+O819+61o+c+MK3SxMRCKn6MLd3dl+4DS/dcNyfv/m\ny8IuR0RkVhTdAdXuoTi9I+Pc0KKHcohIdBVduB/vHQGgZZEeyiEi0VV84d6jcBeR6CvacG9eWB1y\nJSIisyer
cDezDWZ20MzazOzBc+xfaWYvmdkuM9tjZhtzX2puHD0zQsP8SuZVFN2xZBEpItOGu5mV\nAo8BdwGtwL1m1jql2deBp9z9RmAT8L1cF5orBzoGuHpZbdhliIjMqmxG7jcBbe5+2N3jwJPAPVPa\nOLAgWK4DTuauxNyJJ1IcOj1E67IF0zcWESlg2cxNrACOZ6yfAD48pc03gBfM7M+AGuATOakux450\nDxNPpmhdrnAXkWjL1QHVe4EfunszsBH4ezN732ub2WYz22lmO7u6unL01tnr6B8FoHmhzpQRkWjL\nJtzbgZaM9eZgW6b7gKcA3P1fgSqgYeoLufsWd1/v7usbGxsvreIZ6BqMAdA4vzLv7y0ikk/ZhPsO\nYI2ZrTazCtIHTLdOaXMM+A0AM7uadLjnf2g+je6hOAANtRUhVyIiMrumDXd3TwD3A88DB0ifFbPP\nzB4xs7uDZl8Bvmhmu4EngD90d5+toi9V12CMmopSnQYpIpGXVcq5+zZg25RtD2cs7wc+mtvScq97\nKEZDraZkRCT6iuoK1a7BmObbRaQoFFW4dw/FaFC4i0gRKKpw7xqK0ahpGREpAkUT7vFEir6RcY3c\nRaQoFE24nxlOn+Ou0yBFpBgUTbh3D6bPcdcBVREpBkUT7l1DYwA6FVJEikLRhHt7b/q+Mhq5i0gx\nKIpw7xmO860X3uYDjTUsq6sKuxwRkVlXFNfhb3n5MINj4zz9px+hrLQo/j0TkSIX+aRzd7a91cEt\nVzZyZZOewCQixSHy4f5Wez/Heka4/eqmsEsREcmbSId7KuX82yd2saS2kruuXRp2OSIieRPpcH/l\nnW6Onhnh659q1ZWpIlJUIh3u2w90Ul1eyiev0ZSMiBSXSIf7rmO9rGupo7KsNOxSRETyKrLhPhpP\nsu/kADeuXBh2KSIieRfZcH9h/ykSKedja973nG4RkciLbLj/v72nWF5Xxc2rF4ddiohI3kU23Dv6\nx/jAkvmUlFjYpYiI5F1kw/3McIzFNbp3u4gUp8iGe89QnEU1OrddRIpTJMN9bDzJcDzJ4vkauYtI\ncYpkuPcMp5+6tEjTMiJSpCIZ7meGFO4iUtwiGe5/+y9HAHRAVUSKVuTCvXsoxrO72rluRR3XrqgL\nuxwRkVBELtx/frALd/irz1xHVbnuKSMixSly4b7/5ADV5aVcs3xB2KWIiIQmcuHeNxpnUU0FZroy\nVUSKV+TCfWB0nLrq8rDLEBEJVeTCvW9E4S4iErlw7x8dp36ewl1Eilvkwr1P4S4ikl24m9kGMzto\nZm1m9uA59j9qZm8GX2+bWV/uS52eu9M/Os4CTcuISJErm66BmZUCjwF3ACeAHWa21d33T7Rx93+f\n0f7PgBtnodZpjY2niCdS1FfrylQRKW7ZjNxvAtrc/bC7x4EngXsu0P5e4IlcFHex+kbT95TRtIyI\nFLtswn0FcDxj/USw7X3M7DJgNfDiefZvNrOdZrazq6vrYmud1pHuYQCW11fn/LVFRApJrg+obgKe\ncffkuXa6+xZ3X+/u6xsbG3P81rC3vR+Aa3V1qogUuWzCvR1oyVhvDradyyZCmpIB2Ns+wPK6KhbP\n1xOYRKS4ZRPuO4A1ZrbazCpIB/jWqY3MbC2wEPjX3JaYvf0dA7Qu150gRUSmDXd3TwD3A88DB4Cn\n3H2fmT1iZndnNN0EPOnuPjulXlgskeRI9zBrl9aG8fYiInPKtKdCArj7NmDblG0PT1n/Ru7Kunht\nnUMkU85VCncRkehcoXrw1CCARu4iIkQs3CtKS1jdUBN2KSIioYtMuB84NcgVS+ZTVhqZLomIXLLI\nJOHBUwOakhERCUQm3LsGYyyrrwq7DBGROSES4Z5Ipkg5VJbpgdgiIhCRcI8nUwBUlEWiOyIiMxaJ\nNIwngnDXwVQRESAi4R4Lwr2yPBLdERGZsUikoUbuIiKTRSINJ0bumnMXEU
mLRBrGEunbx+tsGRGR\ntEiE+8S0TKVG7iIiQMTCXdMyIiJpkUjDmEbuIiKTRCINNXIXEZksEmmoK1RFRCaLRBq+d0BVZ8uI\niEBEwn3iVEiN3EVE0iKRhjoVUkRkskikoa5QFRGZLBJpGNO9ZUREJolEGmpaRkRkskikYTyZoqK0\nBDMLuxQRkTkhEuEeG09pvl1EJEMkEnFwbJyaSp3jLiIyIRLhfrx3hJaF88IuQ0RkzohGuPeMsnKR\nwl1EZELBh3s8keJk/ygtCncRkbMKPtzb+0ZxR+EuIpKh4MN9b3s/AFc11YZciYjI3FHw4b7rWB9V\n5SWsXaZwFxGZUPDh/lZ7H9cur6Nctx4QETkrq0Q0sw1mdtDM2szswfO0+V0z229m+8zsH3Jb5vmd\nGY7TtKAqX28nIlIQyqZrYGalwGPAHcAJYIeZbXX3/Rlt1gAPAR91914zWzJbBU81HEvoAiYRkSmy\nGbnfBLS5+2F3jwNPAvdMafNF4DF37wVw987clnl+I7EkNZXT/hslIlJUsgn3FcDxjPUTwbZMVwJX\nmtkrZvaqmW3IVYEX4u4MxxPMV7iLiEySq1QsA9YAtwLNwMtmdp2792U2MrPNwGaAlStXzvhNR8eT\npByN3EVEpshm5N4OtGSsNwfbMp0Atrr7uLsfAd4mHfaTuPsWd1/v7usbGxsvteazhmIJAGoqNOcu\nIpIpm3DfAawxs9VmVgFsArZOafN/SY/aMbMG0tM0h3NY5zkNx9IPxtbIXURksmnD3d0TwP3A88AB\n4Cl332dmj5jZ3UGz54EzZrYfeAn4qrufma2iJwxPjNwV7iIik2SViu6+Ddg2ZdvDGcsO/IfgK28m\nwl0HVEVEJivoyzqH4xq5i4icS0GH+1Aw5z5fFzGJiExS0OE+MS0zr0IjdxGRTAUd7u29o5SWGI21\nlWGXIiIypxR0uB/uHmLlonm6I6SIyBQFnYqHu4ZZ3VATdhkiInNOwYZ7IpniSPcwlyvcRUTep2DD\n/dXDPcQSKT502cKwSxERmXMKNtxf2H+KeRWl3LY2b7eOFxEpGAUb7kfPjLBmyXyqynWOu4jIVAUb\n7qf6R/V4PRGR8yjgcB9jWZ3CXUTkXAoy3IdjCQbGEiytqw67FBGROakgw/3UwBgAS+t0ZaqIyLkU\nZLif7g/CfYFG7iIi51KQ4d4RhLvm3EVEzq0gw/29aRmFu4jIuRRkuHf0j1I/r1znuIuInEdBhvup\n/hhLdY67iMh5FWS4d/SPakpGROQCCi7c+0fGOXhqkLVLF4RdiojInFVw4b79V6dJpJxPXtMUdiki\nInNWwYV7bVU5d7Q2sa65PuxSRETmrIJ7svQdrU3c0apRu4jIhRTcyF1ERKancBcRiSCFu4hIBCnc\nRUQiSOEuIhJBCncRkQhSuIuIRJDCXUQkgszdw3ljsy7g3Uv84w1Adw7LKQTqc3FQn4vDTPp8mbs3\nTtcotHCfCTPb6e7rw64jn9Tn4qA+F4d89FnTMiIiEaRwFxGJoEIN9y1hFxAC9bk4qM/FYdb7XJBz\n7iIicmGFOnIXEZELKLhwN7MNZnbQzNrM7MGw68kVM3vczDrNbG/GtkVm9lMzOxR8XxhsNzP7bvAz\n2GNmHwyv8ktnZi1m9pKZ7TezfWb258H2yPbbzKrM7HUz2x30+S+D7avN7LWgb//bzCqC7ZXBeluw\nf1WY9V8qMys1s11m9lywHun+ApjZUTN7y8zeNLOdwba8fbYLKtzNrBR4DLgLaAXuNbPWcKvKmR8C\nG6ZsexDY7u5rgO3BOqT7vyb42gx8P0815loC+Iq7twI3A18O/j6j3O8YcLu7rwNuADaY2c3AN4FH\n3f0KoBe4L2h/H9AbbH80aFeI/hw4kLEe9f5OuM3db8g47TF/n213L5gv4CPA8xnrDwEPhV1XDvu3\nCtibsX4QWBYsLwMOBss/AO49V7tC/g
J+DNxRLP0G5gG/BD5M+oKWsmD72c858DzwkWC5LGhnYdd+\nkf1sDoLsduA5wKLc34x+HwUapmzL22e7oEbuwArgeMb6iWBbVDW5e0ewfAqYeL5g5H4OwX+/bwRe\nI+L9DqYo3gQ6gZ8C7wB97p4ImmT262yfg/39wOL8Vjxj3wb+AkgF64uJdn8nOPCCmb1hZpuDbXn7\nbBfcM1SLlbu7mUXy1CYzmw/8I/Dv3H3AzM7ui2K/3T0J3GBm9cCzwNqQS5o1ZvYpoNPd3zCzW8Ou\nJ89+3d3bzWwJ8FMz+1Xmztn+bBfayL0daMlYbw62RdVpM1sGEHzvDLZH5udgZuWkg/1H7v5/gs2R\n7zeAu/cBL5Gelqg3s4nBVma/zvY52F8HnMlzqTPxUeBuMzsKPEl6auY7RLe/Z7l7e/C9k/Q/4jeR\nx892oYX7DmBNcKS9AtgEbA25ptm0FfiDYPkPSM9JT2z/QnCE/WagP+O/egXD0kP0vwEOuPv/yNgV\n2X6bWWMwYsfMqkkfYzhAOuQ/FzSb2ueJn8XngBc9mJQtBO7+kLs3u/sq0r+vL7r77xHR/k4wsxoz\nq51YBu4E9pLPz3bYBx0u4SDFRuBt0vOUXwu7nhz26wmgAxgnPd92H+m5xu3AIeBnwKKgrZE+a+gd\n4C1gfdj1X2Kff530vOQe4M3ga2OU+w1cD+wK+rwXeDjYfjnwOtAGPA1UBturgvW2YP/lYfdhBn2/\nFXiuGPob9G938LVvIqvy+dnWFaoiIhFUaNMyIiKSBYW7iEgEKdxFRCJI4S4iEkEKdxGRCFK4i4hE\nkMJdRCSCFO4iIhH0/wFmefIpm8xTwgAAAABJRU5ErkJggg==\n", | |
"text/plain": [ | |
"<matplotlib.figure.Figure at 0x10bb532d0>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# Plot lowest confidence scores\n", | |
"%matplotlib inline\n", | |
"import matplotlib.pyplot as plt\n", | |
"import numpy as np\n", | |
"\n", | |
"plt.plot(np.sort(label_probs)[:500])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 98, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(761, 767)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<matplotlib.image.AxesImage at 0x10b2188d0>" | |
] | |
}, | |
"execution_count": 98, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAFh1JREFUeJzt3Xts1dWWB/DvopbyqoXSUsr7LfLw\nIlaECIqYaxhiooaJr2A0IZebiSaa3DFRJ44a5w/vZMT4x8QJDuRyb1RA0YgjOheNxogKFJSH4CCP\nAuXVKvKGQsuaP86PpJDfWj09j9+h7O8nIZzudXZ/u7929dfzW2fvLaoKIgpPp0IPgIgKg8lPFCgm\nP1GgmPxEgWLyEwWKyU8UKCY/UaCY/ESBYvITBeqabDqLyEwArwMoAvDfqvqK9/yKigodMmRIbKy5\nudnsd+zYsdj2kpISs0+PHj28oRBdlerq6vDLL79IOs/NOPlFpAjAfwL4PYB6AOtEZIWqbrX6DBky\nBOvWrYuNHTlyxDzWxx9/HNs+dOhQs8/UqVPNmEckrfOWE95bqzMdRyZv107ya6b8qqmpSfu52fzZ\nPwnADlXdparnACwBcE8Wn4+IEpRN8vcHsK/Vx/VRGxF1AHm/4Sci80SkVkRqGxsb8304IkpTNsm/\nH8DAVh8PiNouoaoLVLVGVWsqKyuzOBwR5VI2yb8OwEgRGSoinQE8CGBFboZFRPmW8d1+VW0WkScA\n/C9Spb5FqvpjGv1i29evX2/2+fbbb2Pbz549a/YpLi42Y1a5EQCqqqrMGO+K09Ukqzq/qq4EsDJH\nYyGiBPEdfkSBYvITBYrJTxQoJj9RoJj8RIHK6m5/ezU1NWHPnj2xsVWrVpn99u3bF9t+6NAhs8/q\n1avNWFlZmRmbN2+eGRs7dmxsO0uA1BHxyk8UKCY/UaCY/ESBYvITBYrJTxSoRO/2nzlzBhs3boyN\neWvu9e8fv0bI+fPnzT7XXnutGbMmCgFAfX29GXvppZdi28eNG2f2YSWArlS88hMFislPFCgmP1Gg\nmPxEgWLyEwWKyU8UqERLfUVFRSgtLY2NWe0A0NDQENt+4cKFjMZx/PhxM3bu3DkzNn/+/Nj2p59+\n2uwzevRoM8YyIBUSr/xEgWLyEwWKyU8UKCY/UaCY/ESBYvITBSqrUp+I1AE4AaAFQLOq1njPLy0t\nxZ133hkbGzlypNmvb9++se2ffvqp2Wf37t1m7OTJk2asW7duZuzo0aOx7W+//bbZ5/HHHzdj1tdF\nlIRc1PnvUNVfcvB5iChB/LOfKFDZJr8C+LuIrBcRe81rIrriZPtn/1RV3S8ifQCsEpGfVPWr1k+I\nfinMA4BBgwZleTgiypWsrvyquj/6vwHABwAmxTxngarWqGpNZWVlNocjohzKOPlFpLuIlF58DOAu\nAFtyNTAiyq9s/uyvAvBBNDPtGgBvq6pde2uD95Jg9uzZse1Hjhwx+yxbtiyjcXgLf544cSK2fefO\nnWYfrwzobQ3mzXIkyoWMk19VdwH4XQ7HQkQJYqmPKFBMfqJAMfmJAsXkJwoUk58oUIku4JmpoqKi\n2Pbbb7/d7FNWVmbGrAVBAaCxsdGMdeoU/7vS2zPQm3nYvXt3MzZnzhwz5u1rSJQuXvmJAsXkJwoU\nk58oUEx+okAx+YkC1SHu9l9zTfwwhw8fbvaZOHGiGfviiy/MWFNTkxk7ffp0bHvnzp3NPn369DFj\nq1evNmPeHf3bbrut3cfztgazzi9gV1oAQFXNmKUjbFGWydcFAC0tLWbMm4Rmnf8uXbqYfbxYunjl\nJwoUk58oUEx+okAx+YkCxeQnChSTnyhQHaLUZ+natasZu/XWW83YunXrzFhzc3O7Y14fb4LRiBEj\nzJg36aeurs6M7dmzJ7bdK7F5W5R54/DWOywv
L49t98qi+WCV7c6dO2f2OXv2rBnzzqN17gFgw4YN\nZuyGG26IbR84cKDZx/qZ88qNl+OVnyhQTH6iQDH5iQLF5CcKFJOfKFBMfqJAtVnqE5FFAO4G0KCq\n46K2cgBLAQwBUAfgflX9LX/DNMdmxqzyCQCUlJSYMa9UYs3q87bWOnnyZEbjOHbsmBnbtGmTGbPK\nn15Zrrq6ut2fry1WKc0ri3qzC73SnDfzsG/fvrHtFy5cMPvs3bvXjGVaMrV+dgBg7dq1se3ez441\ne9NbT/Jy6Vz5/wJg5mVtzwD4XFVHAvg8+piIOpA2k19VvwJw+WTkewAsjh4vBnBvjsdFRHmW6Wv+\nKlU9GD0+hNSOvUTUgWR9w09T7580lz4RkXkiUisitd6a+ESUrEyT/7CIVANA9L+5C4aqLlDVGlWt\nqayszPBwRJRrmSb/CgCPRo8fBfBhboZDRElJp9T3DoDpACpEpB7ACwBeAbBMROYC2APg/nwOMhP9\n+/c3Y9aMM8BfGNGakVZcXGz28co127dvN2PelmJeuax3795mzOItWFlVZd/O8c6jxZvVV19fb8as\nchgA7N6924yNGjUqtn3ChAlmH+/cr1mzxowNGjTIjHmlZ6vs6M2otM5jexZIbTP5VfUhI3Rn2kch\noisO3+FHFCgmP1GgmPxEgWLyEwWKyU8UqA6xgGcme6f169fPjHnlq8OHD5sxq7TlzZjzSn0HDx40\nY6dOnTJjmSy46c089M6vt8ecNzPOilVUVJh9vP0VvQVZvRmQn332WWz75s2bzT7Dhg0zY1OnTjVj\n3jtYv/vuOzNmzWb0SqnWG+baU+rjlZ8oUEx+okAx+YkCxeQnChSTnyhQTH6iQHWIUp9VvvBKVN7s\ntmnTppmx9957z4xZZTRvlp2375sX88po3sKfR48ebfexvFLlmTNnzJi3uKc103Hfvn1mH+/rsmbn\nAcDo0aPN2OTJk2Pb3333XbPP119/bcasBUEB/+fAW8vCKmN6Mw+tBTy9GaaX45WfKFBMfqJAMfmJ\nAsXkJwoUk58oUB3ibr/Fm8TgrRU3e/ZsM/bll1+aMWvLKG/dv+uvv96M7d+/34x5E4K87alOnDgR\n297U1GT28cbvbaHlbQ1lfc6ePXuafbw71V6VoK6uzoxZE4m8apA3mcmrfgwcONCMeevxDR48OLbd\nuqMP2OeKE3uIqE1MfqJAMfmJAsXkJwoUk58oUEx+okCls13XIgB3A2hQ1XFR24sA/gDg4qJlz6nq\nynwNMhPepJ+RI0eaMW+Ntk8++SS2vVevXmYfr2TnTajxnD592ox5k0ss3nqB3rEymWB04MABs0+n\nTva1yCtheWXAcePGxbZ7E6e8MrE3sccrA3olTmtrOa/02Z6SniWdK/9fAMyMaX9NVSdE/66oxCei\ntrWZ/Kr6FQD7XQ9E1CFl85r/CRHZJCKLRMT+u5eIrkiZJv8bAIYDmADgIIBXrSeKyDwRqRWRWm9d\ncyJKVkbJr6qHVbVFVS8AeBPAJOe5C1S1RlVrvNVMiChZGSW/iFS3+vA+AFtyMxwiSko6pb53AEwH\nUCEi9QBeADBdRCYAUAB1AP6YxzHmnFfKmTNnjhlbv359bLtXDvNKfV4pp6ysLKPPaZU4vZl73vi9\ndfq8WX1WGdPr433N3hqEx48fN2Pbtm2Lbfdm4N18881mzJuJ+euvv5qxGTNmmLEePXqYsXxqM/lV\n9aGY5oV5GAsRJYjv8CMKFJOfKFBMfqJAMfmJAsXkJwpUh17A0+PNevJm/I0YMcKM3XXXXbHtS5cu\nTX9grbS0tGTUL5NFML3FIHfu3GnGvFmC3mxGa4zejErv8+3YscOMrVu3zoxZZV2vdNjQ0GDGvLLo\npEnme90wZcoUM2bNZvR+TpOa1UdEVyEmP1GgmPxEgWLyEwWKyU8UKCY/UaCu2lJfprwSysMPPxzb\n7u1nt3z5
8qzHdDlvwc3S0tLYdm9xzH79+pkxb6bagAEDzJg1U80aHwCMGTPGjHl7Hlp7KALATz/9\nFNvulfNGjx5txu69914zdt1115kxb1al9TPnfc9ygVd+okAx+YkCxeQnChSTnyhQTH6iQAV5tz/T\nSREVFRWx7XPnzjX7eOsFLlmyxIw1NTWZse7du5uxbt26xbYfOWLvu+KtJehta1VfX2/Gxo8fH9vu\nTYzZvn27GRs0aJAZmz59uhmzKhLeunnTpk0zY17VIdOfq6Kiooz6ZYtXfqJAMfmJAsXkJwoUk58o\nUEx+okAx+YkClc52XQMB/BVAFVLbcy1Q1ddFpBzAUgBDkNqy635V/S1/Q01GJuWakpISM+Zt/+VN\n3Fi5cqUZ89afs2Le1+VtyeWVFb1xWKVFb2LP0aNHzZhXDvPOv7U5rFfezHRrM6+s642xUNK58jcD\n+JOqjgEwGcDjIjIGwDMAPlfVkQA+jz4mog6izeRX1YOquiF6fALANgD9AdwDYHH0tMUA7LmORHTF\naddrfhEZAuBGAGsAVKnqwSh0CKmXBUTUQaSd/CLSA8ByAE+p6iV7ImtqgfHYRcZFZJ6I1IpIbWNj\nY1aDJaLcSSv5RaQYqcR/S1Xfj5oPi0h1FK8GELs0iqouUNUaVa2xbr4QUfLaTH5J3SZeCGCbqs5v\nFVoB4NHo8aMAPsz98IgoX9KZ1XcrgEcAbBaRH6K25wC8AmCZiMwFsAfA/fkZ4pXDKpdlWkZ77LHH\nzNhNN91kxp5//nkzdvLkydj23r17m328NQgzWXvO62eND/C3L/NmF3r9rBKht7WWV470vmav1JeL\n7bVyrc3kV9WvAVgjvzO3wyGipPAdfkSBYvITBYrJTxQoJj9RoJj8RIEKcgHPJHklHq/EZi2ACQCz\nZs0yYwsXLoxt92bneTPOvK3BvDLg3r17Y9sHDx5s9jl//rwZ80qEXjn1lltuiW0fO3as2adXr15m\nzPuedTS88hMFislPFCgmP1GgmPxEgWLyEwWKyU8UqKunbnGV8UpKDzzwgBn7/vvvY9t//vlns483\nG82bTdfc3GzGrNJcXV1dRuMYMWKEGZs8ebIZmzJlSmz7tddea/a5Emfg5QOv/ESBYvITBYrJTxQo\nJj9RoJj8RIHi3f48S61qnlvl5eVm7Mknn4xtf/XVV80+3h14r+owbtw4M3bHHXfEti9fvtzsM3Pm\nTDM2Y8YMM5bpOoMW73t2NVUCeOUnChSTnyhQTH6iQDH5iQLF5CcKFJOfKFBtlvpEZCCAvyK1BbcC\nWKCqr4vIiwD+AODi1rvPqerKfA00KR29zGOtTffyyy+bfZ599lkzdvbsWTM2d+5cM2aVIz/66COz\nz7Bhw8yYV87zWN9Pb8JSpt/njvDz0Vo6df5mAH9S1Q0iUgpgvYisimKvqep/5G94RJQv6ezVdxDA\nwejxCRHZBqB/vgdGRPnVrtf8IjIEwI0A1kRNT4jIJhFZJCL2esdEdMVJO/lFpAeA5QCeUtXjAN4A\nMBzABKT+Moh9/6iIzBORWhGpbWxsjHsKERVAWskvIsVIJf5bqvo+AKjqYVVtUdULAN4EELvhuaou\nUNUaVa2prKzM1biJKEttJr+kbmEuBLBNVee3aq9u9bT7AGzJ/fCIKF/Sudt/K4BHAGwWkR+itucA\nPCQiE5Aq/9UB+GNeRpgHXjkvk5hX4snHrL5MlJaWmrG7777bjHlr5/Xvb9/33b9/f2x7nz59zD7V\n1dVmrKioyIx5Olr5LUnp3O3/GkDcGezwNX2ikPEdfkSBYvITBYrJTxQoJj9RoJj8RIG6ahfwzEeJ\nzZoJ1tLS0u4+QOZj9MpX1hZaW7duNfscOXLEjHnvylyzZo0Z27NnT2z7uXPnzD6nTp0yY2VlZWbM\nY52rfMzc62hlRV75iQLF5CcKFJOfKFBMfqJAMfmJAsXkJwrUVVvq83glNq
80d/78+Xa1A8Dx48fN\nmFWWA4Bu3bqZMa9cZu27980335h9Bg8ebMZ27dplxrwSYUlJSWx79+7dzT67d+82Yz179jRj3n6C\n1mxAb5bg1VTO8/DKTxQoJj9RoJj8RIFi8hMFislPFCgmP1GgWOq7jDdDzyqxnT592uyzYcMGM7Z2\n7Voz5pWvvHLZgQMHYtu9Pfe8Y40fPz6jcXTt2jW23SoBen0AoKmpyYx55VnreJmW+jz5mKWZT7zy\nEwWKyU8UKCY/UaCY/ESBYvITBarNu/0i0gXAVwBKoue/p6oviMhQAEsA9AawHsAjqmrPOEmYdwe1\nUyf7d14m/by75RMmTDBj5eXlZmzHjh1mzGOtdde3b1+zj3eXvaqqyox5d+6Li4tj2zt37mz26dKl\nixnz+nkx63uW6d3+jnZH35POlb8JwAxV/R1S23HPFJHJAP4M4DVVHQHgNwBz8zdMIsq1NpNfU05G\nHxZH/xTADADvRe2LAdyblxESUV6k9ZpfRIqiHXobAKwCsBPAUVW9OCG9HoC9ZSsRXXHSSn5VbVHV\nCQAGAJgEYHS6BxCReSJSKyK13hrwRJSsdt3tV9WjAL4AMAVATxG5eKdrAIDYDdlVdYGq1qhqTWVl\nZVaDJaLcaTP5RaRSRHpGj7sC+D2AbUj9EvjH6GmPAvgwX4MkotxLZ2JPNYDFIlKE1C+LZar6PyKy\nFcASEfk3AN8DWJjHceaUV3bxynZW2cgqawF++cqbGDNq1Cgzlkkpypv84pW9cl0W9Y7lncdMvi9A\nstt15WOLuHxqM/lVdROAG2PadyH1+p+IOiC+w48oUEx+okAx+YkCxeQnChSTnyhQkmR5QkQaAeyJ\nPqwA8EtiB7dxHJfiOC7V0cYxWFXTejddosl/yYFFalW1piAH5zg4Do6Df/YThYrJTxSoQib/ggIe\nuzWO41Icx6Wu2nEU7DU/ERUW/+wnClRBkl9EZorI/4nIDhF5phBjiMZRJyKbReQHEalN8LiLRKRB\nRLa0aisXkVUi8nP0f68CjeNFEdkfnZMfRGRWAuMYKCJfiMhWEflRRJ6M2hM9J844Ej0nItJFRNaK\nyMZoHC9F7UNFZE2UN0tFxF65NB2qmug/AEVILQM2DEBnABsBjEl6HNFY6gBUFOC4twGYCGBLq7Z/\nB/BM9PgZAH8u0DheBPDPCZ+PagATo8elALYDGJP0OXHGkeg5ASAAekSPiwGsATAZwDIAD0bt/wXg\nn7I5TiGu/JMA7FDVXZpa6nsJgHsKMI6CUdWvABy5rPkepBZCBRJaENUYR+JU9aCqbogen0BqsZj+\nSPicOONIlKbkfdHcQiR/fwD7Wn1cyMU/FcDfRWS9iMwr0BguqlLVg9HjQwDsBfPz7wkR2RS9LMj7\ny4/WRGQIUutHrEEBz8ll4wASPidJLJob+g2/qao6EcA/AHhcRG4r9ICA1G9+pH4xFcIbAIYjtUfD\nQQCvJnVgEekBYDmAp1T1eOtYkuckZhyJnxPNYtHcdBUi+fcDGNjqY3Pxz3xT1f3R/w0APkBhVyY6\nLCLVABD931CIQajq4egH7wKAN5HQORGRYqQS7i1VfT9qTvycxI2jUOckOna7F81NVyGSfx2AkdGd\ny84AHgSwIulBiEh3ESm9+BjAXQC2+L3yagVSC6ECBVwQ9WKyRe5DAudEUgvjLQSwTVXntwolek6s\ncSR9ThJbNDepO5iX3c2chdSd1J0A/qVAYxiGVKVhI4AfkxwHgHeQ+vPxPFKv3eYitefh5wB+BvAZ\ngPICjeNvADYD2IRU8lUnMI6pSP1JvwnAD9G/WUmfE2cciZ4TADcgtSjuJqR+0fxrq5/ZtQB2AHgX\nQEk2x+E7/IgCFfoNP6JgMfmJAsXkJwoUk58oUEx+okAx+YkCxeQnChSTnyhQ/w8QHD824eMvegAA\nAABJRU5ErkJggg==\n", | |
"text/plain": [ | |
"<matplotlib.figure.Figure at 0x10b0843d0>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# Grayscale chicken image\n", | |
"from scipy.misc import imread, imresize, imsave\n", | |
"img = imread('chicken.png', flatten=True) # Convert to grayscale\n", | |
"imsave('chicken_bw.png', img)\n", | |
"img = imresize(img, (32, 32))\n", | |
"imsave('chicken_bw_small.png', img)\n", | |
"img = img / 255.\n", | |
"plt.imshow(img, cmap='gray')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 73, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Pass image through CNN\n", | |
"chicken_tensor = torch.Tensor(np.expand_dims(np.expand_dims(img, 0), 0))\n", | |
"probs = cnn(Variable(chicken_tensor))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 113, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[ 5.21423068e-26 5.14779191e-16 8.10420988e-06 3.05335366e-16\n", | |
" 6.03861604e-31 9.99991894e-01 1.32513748e-12 2.47835812e-22\n", | |
" 8.77226292e-10 4.37692135e-28]\n" | |
] | |
} | |
], | |
"source": [ | |
"print(probs.data[0].numpy())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 74, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"bad_fives = []\n", | |
"for ex in low_conf:\n", | |
" if ex[1] == 5:\n", | |
" bad_fives.append(ex)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 114, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<matplotlib.image.AxesImage at 0x10c064a10>" | |
] | |
}, | |
"execution_count": 114, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAEitJREFUeJzt3XmMVWWax/HvYwmCgBRriYiDC2bA\nZpFUcO+0jXSUmKjJaNSM8Q/TdCZtgonGuCTTzPyjPRk1/uUEl7RNxGVcIppWKFfQIMqiAg00SzBN\nCZSNFousVTzzxz1kCnOfU7eq7lLl+/skpG69z33rvBz41bn3vPe8x9wdEUnPKbUegIjUhsIvkiiF\nXyRRCr9IohR+kUQp/CKJUvhFEqXwiyRK4RdJ1Kk96Wxm1wJPAnXAM+7+aCfP18cJRSrM3a2U51l3\nP95rZnXA34BZwA7gC+A2d/9rTh+FX6TCSg1/T172zwC2uPs2dz8KvATc0IOfJyJV1JPwjwX+3uH7\nHVmbiPQBPXrPXwozmwPMqfR2RKRrehL+ZmBch+/PztpO4u7zgfmg9/wivUlPXvZ/AUwws3PNrD9w\nK7CoPMMSkUrr9pHf3dvM7G5gMYWpvufcfX3ZRiYiFdXtqb5ubUwv+0UqrhpTfSLShyn8IolS+EUS\npfCLJErhF0mUwi+SKIVfJFEKv0iiFH6RRCn8IolS+EUSpfCLJErhF0mUwi+SKIVfJFEKv0iiFH6R\nRCn8Iomq+NLdPyennFL8d2VDQ0PY59JLLw1rZ555Zlg7dOhQWNuzZ09YO3jwYNH2urq6sE99fX1Y\ny/u7Rfsjz8aNG8PamjVrwlpLS0uXtyX5dOQXSZTCL5IohV8kUQq/SKIUfpFEKfwiierRVJ+ZbQf2\nA+1Am7s3lmNQvVU0XTZ58uSwz/333x/WpkyZEtb27dsX1r755puw1traWrS9f//+YZ8xY8aEtQkT\nJoS1vCnHAwcOFG1fuHBh2Gfr1q1hTVN95VeOef6r3f0fZfg5IlJFetkvkqieht+BJWa2yszmlGNA\nIlIdPX3Zf6W7N5vZaKDJzDa6+9KOT8h+KegXg0gv06Mjv7s3Z19bgDeAGUWeM9/dG3/uJwNF+ppu\nh9/MBpnZkBOPgd8A68o1MBGprJ687G8A3jCzEz9nobu/W5ZR9TFDhw4Na3nTaAMGDAhrR44cCWtn\nnXVWl7eX/TsVdfz48bC2efPmsLZs2bIu9/vggw/CPnlTmFJ+3Q6/u28DppZxLCJSRZrqE0mUwi+S\nKIVfJFEKv0iiFH6RRGkBzy6IpsTyFp7MW4gzb/ptyZIlYa2pqSmstbe3h7VIW1tbWFu9enVY27Zt\nW1g7fPhwl8ch1aUjv0iiFH6RRCn8IolS+EUSpfCLJEpn+7sgOpOed9Y77wIddw9r0W23IH92Yd26\n8l5YmTd7kHdBkPR+OvKLJErhF0mUwi+SKIVfJFEKv0iiFH6RRGmqrwzyprw2bNgQ1s4555ywdskl\nl4S15cuXh7W8C3FEOtKRXyRRCr9IohR+kUQp/CKJUvhFEqXwiySq06k+M3sOuB5ocfdfZG3DgZeB\n8cB24BZ3/6Fyw+y7Fi1aFNYmTpwY1s4999ywNmnSpLA2YsSIou179uwJ+0iaSjny/wm49idtDwDv\nu/sE4P3sexHpQzoNv7svBb7/SfMNwPPZ4+eBG8s8LhGpsO6+529w953Z410U7tgrIn1Ijz/e6+5u\nZuGSNGY2B5jT0+2ISHl198i/28zGAGRfW6Inuvt8d29098ZubktEKqC74V8E3Jk9vhN4szzDEZFq\nKWWq70XgV8BIM9sB/AF4FHjFzO4CvgFuqeQg+7KPPvoorF1//fVhbfz48WGtoSE+xTJq1Kii7Zrq\nk5/qNPzufltQmlnmsYhIFekTfiKJUvhFEqXw
iyRK4RdJlMIvkigt4FlhO3fuDGubN28Oa3v37g1r\nedOAkydPLtq+cePGsI+kSUd+kUQp/CKJUvhFEqXwiyRK4RdJlMIvkihN9dXQ+vXrw1pzc3NYu/DC\nC8PazJnFr7dqamoK+7S2toa17qqvry/aPnLkyLDPWWedFdaGDRsW1tzDtWQ4fPhw0fZVq1aFffL2\nR3t7e1jra3TkF0mUwi+SKIVfJFEKv0iiFH6RROlsfw0tXbo0rF133XVh7bLLLgtrF110UdH2vFt8\nrVmzJqwNGTKkW7Xp06cXbb/88svDPldffXVYmzp1alhra2sLay0txReWfvDBB8M+ixcvDmvfffdd\nWDt+/HhY64105BdJlMIvkiiFXyRRCr9IohR+kUQp/CKJKuV2Xc8B1wMt7v6LrG0e8FvgxLzHQ+7+\nl0oN8udq9+7dYe2HH34Ia3kXl5xxxhlF26MpQAAzC2uzZs0Ka3lTc1OmTCnaPmjQoLBP3pTdgQMH\nwlrehT2jR48u2v7YY4+FfY4cORLW8i6Qyvs3yxtjrZRy5P8TcG2R9ifcfVr2R8EX6WM6Db+7LwW+\nr8JYRKSKevKe/24z+9rMnjOz+GJrEemVuhv+p4DzgWnATiB8A2Vmc8xspZmt7Oa2RKQCuhV+d9/t\n7u3ufhx4GpiR89z57t7o7o3dHaSIlF+3wm9mYzp8exOwrjzDEZFqsc6mIMzsReBXwEhgN/CH7Ptp\ngAPbgd+5e3xfqv//Wb1vvqOXmjt3brdq48aNK9p+9OjRsE/eFFv//v3DWl1dXVg79dTis8h56xYu\nWbIkrH3yySfdGscjjzxStD3vlmfRlYAA9913X1h76623wlreVGW5uXs8d9tBp/P87n5bkeZnuzwi\nEelV9Ak/kUQp/CKJUvhFEqXwiyRK4RdJlBbw7KU+/fTTsNbYGH9e6vbbby/aPnDgwLBP3sKT+/fv\nD2ufffZZWIum7fJuk7V9+/awlnfF3CmnxMewHTt2FG1/+OGHwz5XXnllWIv2L8C3334b1j7++OOw\nVis68oskSuEXSZTCL5IohV8kUQq/SKIUfpFEaaqvl4quigPo169fWIsW49y7d2/Y5+233w5rr776\nalhrbm4Oa9G0V96UXd7CmXnTkXkLkK5evbpo+7x588I+CxcuDGt59xqcPHlyWFuxYkVYO3z4cFir\nJB35RRKl8IskSuEXSZTCL5IohV8kUTrb30uNHDkyrI0YMSKsRevxRRe4ACxYsCCsffjhh2Ht2LFj\nYa2a8tahjNYuzLvA6ODBg2Gtvr4+rE2cODGsnX322WFty5YtYa2SdOQXSZTCL5IohV8kUQq/SKIU\nfpFEKfwiiep0qs/MxgF/Bhoo3J5rvrs/aWbDgZeB8RRu2XWLu8dXbUiX5F3kkreuXnSxzfLly8M+\nn3/+eVjLu5VXXxCt79fQ0BD2ybuoKk/e1Gdv3I+lHPnbgHvdfRJwKfB7M5sEPAC87+4TgPez70Wk\nj+g0/O6+091XZ4/3AxuAscANwPPZ054HbqzUIEWk/Lr0nt/MxgMXAyuAhg535t1F4W2BiPQRJb+5\nMbPBwGvAPe6+r+MCCu7u0e23zWwOMKenAxWR8irpyG9m/SgE/wV3fz1r3m1mY7L6GKDoTc3dfb67\nN7p7fKcJEam6TsNvhUP8s8AGd3+8Q2kRcGf2+E7gzfIPT0QqpZSX/VcAdwBrzezLrO0h4FHgFTO7\nC/gGuKUyQ+zbTj/99LB24YUXhrW8K8vefffdsLZ58+ai7e+9917Yp7W1Naz1ddG03axZs8I+Q4cO\nDWt56+1t2rQprOXdiqxWOg2/u38CRCskzizvcESkWvQJP5FEKfwiiVL4RRKl8IskSuEXSZQW8Kyw\n8847L6zde++9Ye3jjz8Oay+99FJYa29vL9r+448/hn36urzbl40bN65o+x133BH2GTZsWFjbunVr\nWGtpKfo5
t15LR36RRCn8IolS+EUSpfCLJErhF0mUwi+SKE31Vdj06dPD2rRp08Ja3tRc3mKc0VV9\n/fv3D/vkOX78eFiLFseE/PvnRTouENOVbeXdB2/u3LlF26+44oqwT96+yruicu3atWGtN9KRXyRR\nCr9IohR+kUQp/CKJUvhFEqWz/RU2duzYsJa3vt8111zTre199tlnXe6Td0Z/27ZtYS26aAZgz549\nRdvz/s719fVhLe/2WjNnxqvJXXXVVUXb887oL1u2LKy98847Ya03rtOXR0d+kUQp/CKJUvhFEqXw\niyRK4RdJlMIvkqhOp/rMbBzwZwq34HZgvrs/aWbzgN8C32VPfcjd/1KpgfZVeRfotLW1hbXx48eH\ntZtvvjmszZ49u2h73kUzeVN9x44dC2vRrbAgXksw7wKdurq6sJa3Tt+gQYPC2pEjR4q2L1iwIOzz\nzDPPhLX169eHtbx91RuVMs/fBtzr7qvNbAiwysyastoT7v7flRueiFRKKffq2wnszB7vN7MNQPzJ\nFRHpE7r0nt/MxgMXAyuyprvN7Gsze87M4vWORaTXKTn8ZjYYeA24x933AU8B5wPTKLwyeCzoN8fM\nVprZyjKMV0TKpKTwm1k/CsF/wd1fB3D33e7e7u7HgaeBGcX6uvt8d29098ZyDVpEeq7T8FvhNPGz\nwAZ3f7xD+5gOT7sJWFf+4YlIpZRytv8K4A5grZl9mbU9BNxmZtMoTP9tB35XkRH2cYsXLw5ro0aN\nCmtTp04Na3lXpA0cOLBo+/Dhw8M+eevt7dq1K6zljX/o0KFF2/fv3x/2ia4EhHjKDmDv3r1hramp\nqWh73jqIW7Zs6dY4+ppSzvZ/AhSbJNacvkgfpk/4iSRK4RdJlMIvkiiFXyRRCr9Ioqw7t1Xq9sbM\nqrexXiLvyrcLLrggrI0ePTqs5V39dtpppxVtHzx4cNgn7/9Aa2trWDvjjDPC2oABA4q2Hz16NOxz\n4MCBsJbX79ChQ2Ft06ZNRdvz/l7VzEQluHt8CWcHOvKLJErhF0mUwi+SKIVfJFEKv0iiFH6RRGmq\nT+RnRlN9IpJL4RdJlMIvkiiFXyRRCr9IohR+kUQp/CKJUvhFEqXwiyRK4RdJlMIvkiiFXyRRpdyr\nb4CZfW5mX5nZejP7j6z9XDNbYWZbzOxlM4vvISUivU4pR/4jwK/dfSqF23Ffa2aXAn8EnnD3C4Af\ngLsqN0wRKbdOw+8FJ5ZV7Zf9ceDXwKtZ+/PAjRUZoYhUREnv+c2sLrtDbwvQBGwFWt29LXvKDmBs\nZYYoIpVQUvjdvd3dpwFnAzOAfy51A2Y2x8xWmtnKbo5RRCqgS2f73b0V+BC4DKg3sxN3pDgbaA76\nzHf3Rndv7NFIRaSsSjnbP8rM6rPHA4FZwAYKvwT+JXvancCblRqkiJRfp2v4mdkUCif06ij8snjF\n3f/TzM4DXgKGA2uAf3X3I538LK3hJ1Jhpa7hpwU8RX5mtICniORS+EUSpfCLJErhF0mUwi+SqFM7\nf0pZ/QP4Jns8Mvu+1jSOk2kcJ+tr4/inUn9gVaf6Ttqw2cre8Kk/jUPjSHUcetkvkiiFXyRRtQz/\n/BpuuyON42Qax8l+tuOo2Xt+EaktvewXSVRNwm9m15rZpmzxzwdqMYZsHNvNbK2ZfVnNxUbM7Dkz\nazGzdR3ahptZk5ltzr4Oq9E45plZc7ZPvjSz2VUYxzgz+9DM/potEjs3a6/qPskZR1X3SdUWzXX3\nqv6hcGnwVuA8oD/wFTCp2uPIxrIdGFmD7f4SmA6s69D2X8AD2eMHgD/WaBzzgPuqvD/GANOzx0OA\nvwGTqr1PcsZR1X0CGDA4e9wPWAFcCrwC3Jq1/w/wbz3ZTi2O/DOALe6+zd2PUlgT4IYajKNm3H0p\n8P1Pmm+gsG4CVGlB1GAcVefuO919dfZ4P4XFYsZS5X2SM46q8oKKL5pbi/
CPBf7e4ftaLv7pwBIz\nW2Vmc2o0hhMa3H1n9ngX0FDDsdxtZl9nbwsq/vajIzMbD1xM4WhXs33yk3FAlfdJNRbNTf2E35Xu\nPh24Dvi9mf2y1gOCwm9+Cr+YauEp4HwK92jYCTxWrQ2b2WDgNeAed9/XsVbNfVJkHFXfJ96DRXNL\nVYvwNwPjOnwfLv5Zae7enH1tAd6gsJNrZbeZjQHIvrbUYhDuvjv7j3cceJoq7RMz60chcC+4++tZ\nc9X3SbFx1GqfZNvu8qK5papF+L8AJmRnLvsDtwKLqj0IMxtkZkNOPAZ+A6zL71VRiygshAo1XBD1\nRNgyN1GFfWJmBjwLbHD3xzuUqrpPonFUe59UbdHcap3B/MnZzNkUzqRuBR6u0RjOozDT8BWwvprj\nAF6k8PLxGIX3bncBI4D3gc3Ae8DwGo1jAbAW+JpC+MZUYRxXUnhJ/zXwZfZndrX3Sc44qrpPgCkU\nFsX9msIvmn/v8H/2c2AL8L/AaT3Zjj7hJ5Ko1E/4iSRL4RdJlMIvkiiFXyRRCr9IohR+kUQp/CKJ\nUvhFEvV/qeJPz11SHToAAAAASUVORK5CYII=\n", | |
"text/plain": [ | |
"<matplotlib.figure.Figure at 0x10bc4e210>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"plt.imshow(bad_fives[3][0][0, 0, :, :].numpy(), cmap='gray')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 127, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>" | |
], | |
"text/vnd.plotly.v1+html": [ | |
"<script>requirejs.config({paths: { 'plotly': ['https://cdn.plot.ly/plotly-latest.min']},});if(!window.Plotly) {{require(['plotly'],function(plotly) {window.Plotly=plotly;});}}</script>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"import plotly\n", | |
"import plotly.graph_objs as go\n", | |
"import plotly.offline as py\n", | |
"from plotly import tools\n", | |
"\n", | |
"py.init_notebook_mode(connected=True)\n", | |
"\n", | |
"# Plot the softmaxes\n", | |
"def bar_plot(x, y, filename='bar_plot.html', interactive=True):\n", | |
" layout = go.Layout(xaxis=dict(title='Digit'), yaxis=dict(range=[0, 1], title='p(y|x)'), \n", | |
" width=800, height=400, font=dict(size=18))\n", | |
" fig = go.Figure(data=[go.Bar(x=x, y=y)], layout=layout)\n", | |
" if interactive:\n", | |
" py.iplot(fig)\n", | |
" else:\n", | |
" py.plot(fig, filename=filename, auto_open=False)\n", | |
"\n", | |
"bar_plot(range(10), probs.data[0].numpy(), 'true_softmax.html', interactive=False)\n", | |
"approx_unif = 5 + np.random.rand(10)\n", | |
"bar_plot(range(10), approx_unif / np.sum(approx_unif), 'expected_softmax.html', interactive=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 128, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Test that image of chicken is not just a fluke\n", | |
"import dataloaders\n", | |
"reload(dataloaders)\n", | |
"from dataloaders import get_fashion_mnist_dataloaders\n", | |
"_, fashion_loader = get_fashion_mnist_dataloaders(1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 136, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Calculate confidence scores for fashion data\n", | |
"from torch.autograd import Variable\n", | |
"num_scores = 1e5\n", | |
"max_probs = []\n", | |
"max_prob_idx = []\n", | |
"for i, (img, _) in enumerate(fashion_loader):\n", | |
" probs = cnn(Variable(img))\n", | |
" max_prob, idx = torch.max(probs, 1)\n", | |
" max_prob_idx.append(idx)\n", | |
" max_prob = max_prob.data[0]\n", | |
" max_probs.append(max_prob)\n", | |
" if i >= num_scores:\n", | |
" break" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 182, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"0.6347\n", | |
"0.7428\n", | |
"0.8889\n" | |
] | |
} | |
], | |
"source": [ | |
    "# Which fraction of Fashion-MNIST images get classified with more than 99% (then 95%, 75%) confidence?\n", | |
"print(np.sum(np.array(max_probs) > 0.99) / float(len(max_probs)))\n", | |
"print(np.sum(np.array(max_probs) > 0.95) / float(len(max_probs)))\n", | |
"print(np.sum(np.array(max_probs) > 0.75) / float(len(max_probs)))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 160, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Merge the two lists and sort them and see where different things end up\n", | |
"fashion_and_mnist_probs = np.array(max_probs + label_probs)\n", | |
"sorted_idx = np.argsort(fashion_and_mnist_probs)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 161, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# sorted_idx = sorted_idx[::100]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 175, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"subsample = 200\n", | |
"x = range(len(sorted_idx))[::subsample]\n", | |
"y = fashion_and_mnist_probs[sorted_idx][::subsample]\n", | |
"color_bool = sorted_idx[::subsample] >= 10000 # 1 if it is mnist, 0 if it is fashion\n", | |
"colors = ['red' if b < 0.5 else 'blue' for b in color_bool]\n", | |
"\n", | |
"def big_bar_plot(x, y, colors=colors, filename='bar_plot.html', interactive=True):\n", | |
" #layout = go.Layout(xaxis=dict(title='Digit'), yaxis=dict(range=[0, 1], title='p(y|x)'), \n", | |
" # width=800, height=400, font=dict(size=18))\n", | |
" fig = go.Figure(data=[go.Bar(x=x, y=y, marker=dict(color=colors))])#, layout=layout)\n", | |
" if interactive:\n", | |
" py.iplot(fig)\n", | |
" else:\n", | |
" py.plot(fig, filename=filename, auto_open=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 176, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([False, False, False, False, False, False, False, False, False,\n", | |
" False, False, False, False, False, False, False, False, False,\n", | |
" False, False, False, True, False, False, False, False, False,\n", | |
" False, True, True, False, False, False, False, False, True,\n", | |
" True, False, False, True, False, False, False, True, True,\n", | |
" True, True, True, True, True, True, True, True, True,\n", | |
" True, True, True, True, True, True, True, True, False,\n", | |
" True, True, True, True, True, True, True, True, True,\n", | |
" True, True, True, False, True, True, False, False, False,\n", | |
" False, False, False, False, True, True, True, True, True,\n", | |
" True, True, True, True, True, True, True, True, True, True], dtype=bool)" | |
] | |
}, | |
"execution_count": 176, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"color_bool" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 177, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.plotly.v1+json": { | |
"data": [ | |
{ | |
"marker": { | |
"color": [ | |
"red", | |
"red", | |
"red", | |
"red", | |
"red", | |
"red", | |
"red", | |
"red", | |
"red", | |
"red", | |
"red", | |
"red", | |
"red", | |
"red", | |
"red", | |
"red", | |
"red", | |
"red", | |
"red", | |
"red", | |
"red", | |
"blue", | |
"red", | |
"red", | |
"red", | |
"red", | |
"red", | |
"red", | |
"blue", | |
"blue", | |
"red", | |
"red", | |
"red", | |
"red", | |
"red", | |
"blue", | |
"blue", | |
"red", | |
"red", | |
"blue", | |
"red", | |
"red", | |
"red", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"red", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"red", | |
"blue", | |
"blue", | |
"red", | |
"red", | |
"red", | |
"red", | |
"red", | |
"red", | |
"red", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue", | |
"blue" | |
] | |
}, | |
"type": "bar", | |
"x": [ | |
0, | |
200, | |
400, | |
600, | |
800, | |
1000, | |
1200, | |
1400, | |
1600, | |
1800, | |
2000, | |
2200, | |
2400, | |
2600, | |
2800, | |
3000, | |
3200, | |
3400, | |
3600, | |
3800, | |
4000, | |
4200, | |
4400, | |
4600, | |
4800, | |
5000, | |
5200, | |
5400, | |
5600, | |
5800, | |
6000, | |
6200, | |
6400, | |
6600, | |
6800, | |
7000, | |
7200, | |
7400, | |
7600, | |
7800, | |
8000, | |
8200, | |
8400, | |
8600, | |
8800, | |
9000, | |
9200, | |
9400, | |
9600, | |
9800, | |
10000, | |
10200, | |
10400, | |
10600, | |
10800, | |
11000, | |
11200, | |
11400, | |
11600, | |
11800, | |
12000, | |
12200, | |
12400, | |
12600, | |
12800, | |
13000, | |
13200, | |
13400, | |
13600, | |
13800, | |
14000, | |
14200, | |
14400, | |
14600, | |
14800, | |
15000, | |
15200, | |
15400, | |
15600, | |
15800, | |
16000, | |
16200, | |
16400, | |
16600, | |
16800, | |
17000, | |
17200, | |
17400, | |
17600, | |
17800, | |
18000, | |
18200, | |
18400, | |
18600, | |
18800, | |
19000, | |
19200, | |
19400, | |
19600, | |
19800 | |
], | |
"y": [ | |
0.3061355650424957, | |
0.5143805146217346, | |
0.5719403624534607, | |
0.6228542327880859, | |
0.6767367124557495, | |
0.7216569781303406, | |
0.7659032940864563, | |
0.8029114603996277, | |
0.8369762301445007, | |
0.8670023083686829, | |
0.8938217759132385, | |
0.9159621000289917, | |
0.9324735403060913, | |
0.9476518034934998, | |
0.9592661261558533, | |
0.9686580300331116, | |
0.9767616987228394, | |
0.982664942741394, | |
0.9874541759490967, | |
0.9906601309776306, | |
0.9931192994117737, | |
0.9949137568473816, | |
0.9965328574180603, | |
0.9975701570510864, | |
0.998376727104187, | |
0.998912513256073, | |
0.9992625713348389, | |
0.9994962811470032, | |
0.9996829628944397, | |
0.9997949004173279, | |
0.9998713731765747, | |
0.9999213218688965, | |
0.9999499320983887, | |
0.9999693632125854, | |
0.9999820590019226, | |
0.9999905824661255, | |
0.9999957084655762, | |
0.9999978542327881, | |
0.9999989867210388, | |
0.999999463558197, | |
0.9999997615814209, | |
0.9999998807907104, | |
0.9999999403953552, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1, | |
1 | |
] | |
} | |
], | |
"layout": {} | |
}, | |
"text/html": [ | |
"<div id=\"7961d549-988b-4b21-961e-e2afd1f88ee7\" style=\"height: 525px; width: 100%;\" class=\"plotly-graph-div\"></div><script type=\"text/javascript\">require([\"plotly\"], function(Plotly) { window.PLOTLYENV=window.PLOTLYENV || {};window.PLOTLYENV.BASE_URL=\"https://plot.ly\";Plotly.newPlot(\"7961d549-988b-4b21-961e-e2afd1f88ee7\", [{\"marker\": {\"color\": [\"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"blue\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"blue\", \"blue\", \"red\", \"red\", \"red\", \"red\", \"red\", \"blue\", \"blue\", \"red\", \"red\", \"blue\", \"red\", \"red\", \"red\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"red\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"red\", \"blue\", \"blue\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\"]}, \"y\": [0.3061355650424957, 0.5143805146217346, 0.5719403624534607, 0.6228542327880859, 0.6767367124557495, 0.7216569781303406, 0.7659032940864563, 0.8029114603996277, 0.8369762301445007, 0.8670023083686829, 0.8938217759132385, 0.9159621000289917, 0.9324735403060913, 0.9476518034934998, 0.9592661261558533, 0.9686580300331116, 0.9767616987228394, 0.982664942741394, 0.9874541759490967, 0.9906601309776306, 0.9931192994117737, 0.9949137568473816, 0.9965328574180603, 0.9975701570510864, 0.998376727104187, 0.998912513256073, 0.9992625713348389, 0.9994962811470032, 0.9996829628944397, 0.9997949004173279, 0.9998713731765747, 0.9999213218688965, 0.9999499320983887, 
0.9999693632125854, 0.9999820590019226, 0.9999905824661255, 0.9999957084655762, 0.9999978542327881, 0.9999989867210388, 0.999999463558197, 0.9999997615814209, 0.9999998807907104, 0.9999999403953552, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], \"type\": \"bar\", \"x\": [0, 200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000, 2200, 2400, 2600, 2800, 3000, 3200, 3400, 3600, 3800, 4000, 4200, 4400, 4600, 4800, 5000, 5200, 5400, 5600, 5800, 6000, 6200, 6400, 6600, 6800, 7000, 7200, 7400, 7600, 7800, 8000, 8200, 8400, 8600, 8800, 9000, 9200, 9400, 9600, 9800, 10000, 10200, 10400, 10600, 10800, 11000, 11200, 11400, 11600, 11800, 12000, 12200, 12400, 12600, 12800, 13000, 13200, 13400, 13600, 13800, 14000, 14200, 14400, 14600, 14800, 15000, 15200, 15400, 15600, 15800, 16000, 16200, 16400, 16600, 16800, 17000, 17200, 17400, 17600, 17800, 18000, 18200, 18400, 18600, 18800, 19000, 19200, 19400, 19600, 19800]}], {}, {\"linkText\": \"Export to plot.ly\", \"showLink\": true})});</script>" | |
], | |
"text/vnd.plotly.v1+html": [ | |
"<div id=\"7961d549-988b-4b21-961e-e2afd1f88ee7\" style=\"height: 525px; width: 100%;\" class=\"plotly-graph-div\"></div><script type=\"text/javascript\">require([\"plotly\"], function(Plotly) { window.PLOTLYENV=window.PLOTLYENV || {};window.PLOTLYENV.BASE_URL=\"https://plot.ly\";Plotly.newPlot(\"7961d549-988b-4b21-961e-e2afd1f88ee7\", [{\"marker\": {\"color\": [\"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"blue\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"blue\", \"blue\", \"red\", \"red\", \"red\", \"red\", \"red\", \"blue\", \"blue\", \"red\", \"red\", \"blue\", \"red\", \"red\", \"red\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"red\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"red\", \"blue\", \"blue\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"red\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\", \"blue\"]}, \"y\": [0.3061355650424957, 0.5143805146217346, 0.5719403624534607, 0.6228542327880859, 0.6767367124557495, 0.7216569781303406, 0.7659032940864563, 0.8029114603996277, 0.8369762301445007, 0.8670023083686829, 0.8938217759132385, 0.9159621000289917, 0.9324735403060913, 0.9476518034934998, 0.9592661261558533, 0.9686580300331116, 0.9767616987228394, 0.982664942741394, 0.9874541759490967, 0.9906601309776306, 0.9931192994117737, 0.9949137568473816, 0.9965328574180603, 0.9975701570510864, 0.998376727104187, 0.998912513256073, 0.9992625713348389, 0.9994962811470032, 0.9996829628944397, 0.9997949004173279, 0.9998713731765747, 0.9999213218688965, 0.9999499320983887, 
0.9999693632125854, 0.9999820590019226, 0.9999905824661255, 0.9999957084655762, 0.9999978542327881, 0.9999989867210388, 0.999999463558197, 0.9999997615814209, 0.9999998807907104, 0.9999999403953552, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], \"type\": \"bar\", \"x\": [0, 200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000, 2200, 2400, 2600, 2800, 3000, 3200, 3400, 3600, 3800, 4000, 4200, 4400, 4600, 4800, 5000, 5200, 5400, 5600, 5800, 6000, 6200, 6400, 6600, 6800, 7000, 7200, 7400, 7600, 7800, 8000, 8200, 8400, 8600, 8800, 9000, 9200, 9400, 9600, 9800, 10000, 10200, 10400, 10600, 10800, 11000, 11200, 11400, 11600, 11800, 12000, 12200, 12400, 12600, 12800, 13000, 13200, 13400, 13600, 13800, 14000, 14200, 14400, 14600, 14800, 15000, 15200, 15400, 15600, 15800, 16000, 16200, 16400, 16600, 16800, 17000, 17200, 17400, 17600, 17800, 18000, 18200, 18400, 18600, 18800, 19000, 19200, 19400, 19600, 19800]}], {}, {\"linkText\": \"Export to plot.ly\", \"showLink\": true})});</script>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"big_bar_plot(x, y, colors)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"anaconda-cloud": {}, | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 1 | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.nn as nn | |
class CNN(nn.Module):
    """Small convolutional classifier whose forward pass returns class probabilities.

    Three stride-2 convolutions (4x4 kernels, padding 1) each halve the spatial
    size, so a (N, 1, 32, 32) batch becomes (N, 64, 4, 4) features. These are
    flattened and mapped by a linear layer to 10 scores, followed by a softmax.

    Note: ``forward`` returns probabilities, not logits.
    """

    def __init__(self):
        super(CNN, self).__init__()
        # Single input channel; spatial size 32 -> 16 -> 8 -> 4 for 32x32 inputs.
        self.img_to_features = nn.Sequential(
            nn.Conv2d(1, 16, (4, 4), stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, (4, 4), stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, (4, 4), stride=2, padding=1),
            nn.ReLU()
        )
        # The flattened size 64 * 4 * 4 is what ties this model to 32x32 inputs.
        self.features_to_probs = nn.Sequential(
            nn.Linear(64 * 4 * 4, 10),
            # Normalize explicitly over the class dimension of the (N, 10)
            # scores; leaving dim implicit is deprecated and warns in PyTorch.
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        """Return a (N, 10) tensor of class probabilities for the image batch x."""
        features = self.img_to_features(x)
        # Flatten everything except the batch dimension before the linear layer.
        probs = self.features_to_probs(features.view(features.size(0), -1))
        return probs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
from dataloaders import get_mnist_dataloaders | |
from models import CNN | |
from torch.autograd import Variable | |
def eval_model(model, data):
    """Return the classification accuracy of ``model`` on ``data``.

    Args:
        model: module mapping an image batch to per-class scores of shape
            (batch, num_classes); the arg-max over dim 1 is the prediction.
        data: iterable of ``(imgs, labels)`` batches, e.g. a DataLoader.

    Returns:
        Fraction of correctly classified examples, as a Python float.

    Raises:
        ZeroDivisionError: if ``data`` yields no examples.
    """
    # Evaluate on whatever device the model lives on instead of assuming CUDA.
    device = next(model.parameters()).device
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        # Gradients are never used here; no_grad avoids building autograd graphs.
        with torch.no_grad():
            for imgs, labels in data:
                probs = model(imgs.to(device))
                _, predicted = torch.max(probs, 1)
                total += labels.size(0)
                correct += (predicted.cpu() == labels).sum().item()
    finally:
        # Restore the caller's train/eval mode so training can continue unaffected.
        model.train(was_training)
    return float(correct) / total
def train_epoch(model, data, optimizer=None, criterion=None):
    """Run one pass of optimization over ``data``.

    Args:
        model: module whose forward returns class *probabilities* (it ends in
            a softmax, like ``CNN``).
        data: iterable of ``(imgs, labels)`` batches.
        optimizer: optimizer to step; defaults to the module-level ``optimizer``.
        criterion: loss called as ``criterion(scores, labels)``; defaults to the
            module-level ``criterion``. It receives log-probabilities, which is
            correct for both ``nn.CrossEntropyLoss`` and ``nn.NLLLoss``.
    """
    # Fall back to the script-level globals so existing calls keep working.
    if optimizer is None:
        optimizer = globals()["optimizer"]
    if criterion is None:
        criterion = globals()["criterion"]
    # Train on the model's own device instead of assuming CUDA.
    device = next(model.parameters()).device
    model.train()
    for i, (imgs, labels) in enumerate(data):
        imgs = imgs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        probs = model(imgs)
        # The model already applies a softmax, and nn.CrossEntropyLoss applies
        # its own log-softmax; passing probabilities would double-softmax and
        # weaken the gradients. Log-probabilities fix this because log_softmax is
        # shift-invariant: log_softmax(log(softmax(z))) == log_softmax(z).
        # The clamp keeps log() finite if a probability underflows to zero.
        log_probs = torch.log(probs.clamp(min=torch.finfo(probs.dtype).tiny))
        loss = criterion(log_probs, labels)
        loss.backward()
        optimizer.step()
        if (i + 1) % 100 == 0:
            # .item() replaces loss.data[0], which fails on 0-dim loss tensors.
            print("Iteration: {}, Loss: {}".format(i + 1, loss.item()))
# Data: 32x32 MNIST, served in batches of 100.
train_loader, test_loader = get_mnist_dataloaders(batch_size=100)

# Model on the GPU; Module.cuda() moves parameters in place and returns the module.
cnn = CNN().cuda()

# `criterion` and `optimizer` stay module-level because train_epoch reads them.
# NOTE(review): CNN ends in nn.Softmax, while nn.CrossEntropyLoss applies its own
# log-softmax, so it expects logits or log-probabilities rather than probabilities.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-3)

# Ten epochs, reporting held-out accuracy after each one (epochs shown 1-based).
for epoch in range(10):
    train_epoch(cnn, train_loader)
    test_acc = eval_model(cnn, test_loader)
    print("Epoch {}, Test Accuracy: {}\n".format(epoch + 1, test_acc))

# Persist only the learned weights; rebuild with CNN() + load_state_dict to reuse.
torch.save(cnn.state_dict(), 'mnist_cnn.pt')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment