Skip to content

Instantly share code, notes, and snippets.

@standarderror
Last active September 30, 2016 05:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save standarderror/27129036f98d8e478987fba93b28f0f2 to your computer and use it in GitHub Desktop.
Save standarderror/27129036f98d8e478987fba93b28f0f2 to your computer and use it in GitHub Desktop.
# Choose the image(s) to visualise and plot the first one.
# The slice 102:103 keeps a length-1 leading axis, so `im[0]` is the single
# selected image and `im` can be passed to predict() as a batch of one.
# NOTE(review): `allX` is defined outside this snippet — presumably the image
# array the model was trained on; confirm.
im = allX[102:103]
plt.axis('off')
plt.imshow(im[0].astype('uint8'))
plt.gcf().set_size_inches(2, 2)
# Run the image(s) through the first conv layer: wrap `conv_1` in its own
# tflearn.DNN that reuses the session of the existing `model`.
# NOTE(review): `conv_1` and `model` are defined outside this snippet —
# presumably the first conv layer and the trained network; confirm.
m2 = tflearn.DNN(conv_1, session=model.session)
yhat = m2.predict(im)
# Slice off the outputs for the first image (they are plotted further below).
# NOTE(review): `array` is used unqualified while the rest of the file uses
# `np.` — confirm that `array` is actually in scope (e.g. via a pylab import).
yhat_1 = array(yhat[0])
def vis_conv(v, ix, iy, ch, cy, cx, p=0):
    """Tile the channels of one conv-layer output into a single 2-D image.

    Each channel becomes an ``(iy + 2) x (ix + 2)`` tile (the feature map
    plus a 1-pixel border filled with ``p``), and the tiles are laid out on
    a grid of ``cy`` rows by ``cx`` columns in channel order.

    Args:
        v: Array-like holding ``iy * ix * ch`` values, ordered so that it
            reshapes to ``(iy, ix, ch)``.
        ix: Width of each feature map.
        iy: Height of each feature map.
        ch: Number of channels; must equal ``cy * cx``.
        cy: Number of tile rows in the output grid.
        cx: Number of tile columns in the output grid.
        p: Constant used to fill the border around every tile.

    Returns:
        A 2-D array of shape ``(cy * (iy + 2), cx * (ix + 2))``.

    Raises:
        ValueError: If ``cy * cx != ch`` or ``v`` does not hold
            ``iy * ix * ch`` values.
    """
    # Fail early with a readable message; otherwise the mismatch only shows
    # up as an opaque "cannot reshape" error further down.
    if cy * cx != ch:
        raise ValueError(
            "grid of {0}x{1} tiles cannot hold {2} channels".format(cy, cx, ch)
        )

    maps = np.reshape(v, (iy, ix, ch))

    # Pad only the two spatial axes so every tile gets a visible separator.
    border = ((1, 1), (1, 1), (0, 0))
    maps = np.pad(maps, pad_width=border, mode='constant', constant_values=p)
    tile_h = iy + 2
    tile_w = ix + 2

    # Split the channel axis into grid coordinates, then interleave them with
    # the spatial axes: (tile_h, tile_w, cy, cx) -> (cy, tile_h, cx, tile_w).
    grid = np.reshape(maps, (tile_h, tile_w, cy, cx))
    grid = np.transpose(grid, (2, 0, 3, 1))
    return np.reshape(grid, (cy * tile_h, cx * tile_w))
# h_conv1 - processed image
# Dimensions handed to vis_conv: yhat_1 is reshaped there to (iy, ix, ch) and
# its channels are tiled on a cy x cx grid, so cy * cx must equal ch.
# NOTE(review): 64x64 is presumably the spatial size of the conv_1 output
# (same as the input image) — confirm against the network definition.
ix = 64 # img size
iy = 64
ch = 32
cy = 4 # grid from channels: 32 = 4x8
cx = 8
v = vis_conv(yhat_1,ix,iy,ch,cy,cx)
# Show the tiled feature maps in greyscale; 'nearest' interpolation avoids
# smoothing between neighbouring pixels.
plt.figure(figsize = (12,12))
plt.imshow(v,cmap="Greys_r",interpolation='nearest')
plt.axis('off');
## Acknowledgements: @rgr on Stack Overflow, http://stackoverflow.com/a/35247876
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment