Skip to content

Instantly share code, notes, and snippets.

Last active January 4, 2019 13:26
Show Gist options
  • Save axel-angel/b2af7d980eb217a0af07 to your computer and use it in GitHub Desktop.
Save axel-angel/b2af7d980eb217a0af07 to your computer and use it in GitHub Desktop.
Caffe script to compute accuracy and confusion matrix
# -*- coding: utf-8 -*-
# Author: Axel Angel, copyright 2015, license GPLv3.
import sys
import caffe
import numpy as np
import lmdb
import argparse
from collections import defaultdict
def flat_shape(x):
    """Return x reshaped without its singleton dimensions.

    Example: an array of shape (1, 28, 28) comes back with shape (28, 28).
    An array whose dimensions are all 1 comes back as a 0-d array.
    """
    # Pass a real tuple: on Python 3 `filter` returns a lazy iterator, which
    # `reshape` rejects with "TypeError: expected sequence object ...".
    return x.reshape(tuple(s for s in x.shape if s > 1))
def lmdb_reader(fpath):
    """Yield (key, image, label) for every record of an LMDB database.

    fpath: path of the LMDB directory.

    Each stored value is parsed as a serialized caffe Datum. The image is
    yielded with its singleton dimensions removed (via flat_shape) and the
    label as an int.
    """
    import lmdb

    # Read-only access is enough here and avoids touching the dataset.
    lmdb_env = lmdb.open(fpath, readonly=True)
    try:
        with lmdb_env.begin() as lmdb_txn:
            for key, value in lmdb_txn.cursor():
                datum = caffe.proto.caffe_pb2.Datum()
                # A fresh Datum only holds defaults; fill it from the raw bytes.
                datum.ParseFromString(value)
                label = int(datum.label)
                image = caffe.io.datum_to_array(datum)
                yield (key, flat_shape(image), label)
    finally:
        # Runs when the generator is exhausted or closed early.
        lmdb_env.close()
def leveldb_reader(fpath):
    """Yield (key, image, label) for every record of a LevelDB database.

    fpath: path of the LevelDB directory.

    Each stored value is parsed as a serialized caffe Datum. The image is
    yielded with its singleton dimensions removed (via flat_shape) and the
    label as an int.
    """
    import leveldb

    db = leveldb.LevelDB(fpath)
    for key, value in db.RangeIter():
        datum = caffe.proto.caffe_pb2.Datum()
        # A fresh Datum only holds defaults; fill it from the raw bytes.
        datum.ParseFromString(value)
        label = int(datum.label)
        image = caffe.io.datum_to_array(datum)
        yield (key, flat_shape(image), label)
def npz_reader(fpath):
    """Yield (index, image, label) for every sample of an .npz archive.

    fpath: path of an archive saved with positional arrays, i.e.
    np.savez(fpath, images, labels), so images are stored under 'arr_0'
    and labels under 'arr_1'.
    """
    # Load both arrays up front so the archive can be closed right away.
    with np.load(fpath) as npz:
        xs = npz['arr_0']
        ls = npz['arr_1']
    # Pair samples with zip: stacking images and labels into one array
    # fails as soon as the images have more than one dimension.
    for i, (x, l) in enumerate(zip(xs, ls)):
        yield (i, x, l)
if __name__ == "__main__":
    # Command line: network definition, trained weights, and exactly one
    # dataset source (LMDB, LevelDB or NPZ).
    parser = argparse.ArgumentParser()
    parser.add_argument('--proto', type=str, required=True)
    parser.add_argument('--model', type=str, required=True)
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--lmdb', type=str, default=None)
    group.add_argument('--leveldb', type=str, default=None)
    group.add_argument('--npz', type=str, default=None)
    args = parser.parse_args()

    count = 0
    correct = 0
    matrix = defaultdict(int)  # (real, pred) -> int
    labels_set = set()

    net = caffe.Net(args.proto, args.model, caffe.TEST)
    # print() with a single pre-formatted string behaves the same on
    # Python 2 and Python 3.
    print("args %s" % (vars(args),))

    # The argument group is required and mutually exclusive, so exactly one
    # of the three sources is set.
    if args.lmdb is not None:
        reader = lmdb_reader(args.lmdb)
    elif args.leveldb is not None:
        reader = leveldb_reader(args.leveldb)
    else:
        reader = npz_reader(args.npz)

    for i, image, label in reader:
        # Add a leading axis to the sample, then wrap it in a batch of one.
        image_caffe = image.reshape(1, *image.shape)
        out = net.forward_all(data=np.asarray([image_caffe]))
        plabel = int(out['prob'][0].argmax(axis=0))

        count += 1
        iscorrect = label == plabel
        if iscorrect:
            correct += 1
        matrix[(label, plabel)] += 1
        labels_set.update([label, plabel])

        if not iscorrect:
            print("\rError: i=%s, expected %i but predicted %i"
                  % (i, label, plabel))

        # "\r" rewrites the same console line to show a running accuracy;
        # flush so it shows up even though no newline is written.
        sys.stdout.write("\rAccuracy: %.1f%%" % (100. * correct / count))
        sys.stdout.flush()

    print(", %i/%i corrects" % (correct, count))

    print("")
    print("Confusion matrix:")
    print("(r , p) | count")
    # Sort so the rows come out in a stable, readable order.
    ordered_labels = sorted(labels_set)
    for l in ordered_labels:
        for pl in ordered_labels:
            print("(%i , %i) | %i" % (l, pl, matrix[(l, pl)]))
Copy link

mtngld commented Jun 28, 2015


Where is flat_shape defined?

Copy link

same question !

File "", line 24, in lmdb_reader
yield (key, flat_shape(image), label)
NameError: global name 'flat_shape' is not defined

Copy link

Try my modified script if you don't use the LMDB but use the training images txt file instead.

Copy link

How can I use it for HDF5 files? Please help.

Copy link

I've added flat_shape; it just removes singleton (size-1) dimensions.

Copy link

If you implement it, I'll happily update the gist above. Try something along the lines:

import h5py
def hdf5_reader(file_name):
    """Sketch: yield (key, image, label) for each entry of the first HDF5 group.

    file_name: path of the HDF5 file, opened read-only.

    NOTE(review): this is an untested outline, like the original suggestion.
    It assumes every entry of the first group stores a serialized caffe
    Datum -- confirm this against the layout of your HDF5 file.
    """
    with h5py.File(file_name, 'r') as h5_file:  # read-only, closed on exit
        # keys() is a method and must be called; take the first name.
        names = list(h5_file.keys())
        if not names:
            return  # empty file: nothing to yield
        group = h5_file[names[0]]
        for key, value in group.items():
            datum = caffe.proto.caffe_pb2.Datum()
            # NOTE(review): value[()] reads the whole dataset; presumably it
            # yields the serialized bytes -- verify for your data.
            datum.ParseFromString(value[()])
            label = int(datum.label)
            image = caffe.io.datum_to_array(datum)
            yield key, flat_shape(image), label

Copy link

To cope with encoded images I extended the code above like this:

def getImage(datum):
    # Return the image stored in a caffe Datum, with a dedicated path for
    # encoded image data (datum.encoded set).
    # NOTE(review): the right-hand sides of the three assignments below are
    # cut off in this copy of the snippet; restore them from the original
    # comment before using this function.
    if datum.encoded:
        # cStringIO exists only on Python 2.
        from cStringIO import StringIO
        import PIL
        s = StringIO(
        image = np.array(
        image =
    # NOTE(review): as visible here, `image` is only bound when
    # datum.encoded is true -- confirm how non-encoded datums are handled.
    return image

def lmdb_reader(fpath):
    """Yield (key, image, label) for every record of an LMDB database.

    fpath: path of the LMDB directory.

    Each stored value is parsed as a serialized caffe Datum and turned into
    an image by getImage, so encoded datums are handled too. The image is
    yielded with its singleton dimensions removed (via flat_shape) and the
    label as an int.
    """
    import lmdb

    # Read-only access is enough here and avoids touching the dataset.
    lmdb_env = lmdb.open(fpath, readonly=True)
    try:
        with lmdb_env.begin() as lmdb_txn:
            for key, value in lmdb_txn.cursor():
                datum = caffe.proto.caffe_pb2.Datum()
                # A fresh Datum only holds defaults; fill it from the raw bytes.
                datum.ParseFromString(value)
                label = int(datum.label)
                image = getImage(datum)
                yield (key, flat_shape(image), label)
    finally:
        # Runs when the generator is exhausted or closed early.
        lmdb_env.close()

def leveldb_reader(fpath):
    """Yield (key, image, label) for every record of a LevelDB database.

    fpath: path of the LevelDB directory.

    Each stored value is parsed as a serialized caffe Datum and turned into
    an image by getImage, so encoded datums are handled too. The image is
    yielded with its singleton dimensions removed (via flat_shape) and the
    label as an int.
    """
    import leveldb

    db = leveldb.LevelDB(fpath)
    for key, value in db.RangeIter():
        datum = caffe.proto.caffe_pb2.Datum()
        # A fresh Datum only holds defaults; fill it from the raw bytes.
        datum.ParseFromString(value)
        label = int(datum.label)
        image = getImage(datum)
        yield (key, flat_shape(image), label)

Copy link

@axel-angel, Don't you need to subtract the image mean like in the example here:
Also what about channel swap?

Copy link

@axel-angel, in addition, the images would need a channel swap as well, correct?

Copy link

I am working with an LMDB database. When I run this code I get the error: argument --proto is required. Please help.

Copy link

@monjoybme You need to launch it like this:
python ../src/ --proto lenet.prototxt --model snapshots/lenet_mnist_v3-id_iter_1000.caffemodel --lmdb ../caffe/examples/mnist/mnist_test_lmdb/

accord @axel-angel


Copy link

tringn commented Nov 19, 2018

@axel-angel, thanks for your amazing work.
I am using your script and I got stuck at:

I1119 17:07:53.463573 12920 net.cpp:283] Network initialization done.
args{'proto': 'test.prototxt', 'model': 'models/caffenet_age_train_iter_50000.caffemodel', 'lmdb': 'lmdb_full/age_test_lmdb/', 'leveldb': None, 'npz': None}
Traceback (most recent call last):
  File "", line 75, in <module>
    for i, image, label in reader:
  File "", line 28, in lmdb_reader
    yield (key, flat_shape(image), label)
  File "", line 15, in flat_shape
    return x.reshape(filter(lambda s: s > 1, x.shape))
TypeError: expected sequence object with len >= 0 or a single integer

I used Python 3 to run it. Can you suggest a solution? Thanks.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment