Esempio n. 1
0
def load_network(loadfile, multiview=None, sort_layers=False):
    """Load a checkpoint's layers together with one shuffled test batch.

    Args:
        loadfile: checkpoint path handed to ``IGPUModel.load_checkpoint``.
        multiview: when not ``None``, replaces the checkpoint's
            ``multiview_test`` option in the data-provider parameters.
        sort_layers: when true, the layers are returned as an
            ``OrderedDict`` ordered by the depth ``get_depths`` reports
            for each layer name.

    Returns:
        ``(layers, data, dp)`` where ``layers`` is the checkpoint's layer
        mapping, ``data`` is ``[images, labels]`` followed by any further
        entries of the batch, and ``dp`` is the data provider instance.
        ``images`` has shape ``(cases, num_colors, inner_size, inner_size)``
        and ``labels`` is a flat integer array with one entry per case.
    """
    from convdata import DataProvider
    from python_util.gpumodel import IGPUModel

    load_dic = IGPUModel.load_checkpoint(loadfile)
    layers = load_dic['model_state']['layers']

    if sort_layers:
        depths = get_depths(layers)
        layers = collections.OrderedDict(
            sorted(layers.items(), key=lambda item: depths[item[0]]))

    # Flatten the option parser into a plain name -> value mapping.
    options = {o.name: o.value for o in load_dic['op'].get_options_list()}

    # Only these options are forwarded to the data provider.
    dp_param_names = ('color_noise', 'multiview_test', 'inner_size',
                      'scalar_mean', 'minibatch_size')
    dp_params = {name: options[name] for name in dp_param_names}

    lib_name = "cudaconvnet._ConvNet"
    print("Importing %s C++ module" % lib_name)
    # A non-empty fromlist makes __import__ return the submodule itself
    # rather than the top-level package.
    dp_params['libmodel'] = __import__(lib_name, fromlist=['_ConvNet'])

    if multiview is not None:
        dp_params['multiview_test'] = multiview

    dp = DataProvider.get_instance(options['data_path'],
                                   batch_range=options['test_batch_range'],
                                   type=options['dp_type'],
                                   dp_params=dp_params,
                                   test=True)

    epoch, batchnum, data = dp.get_next_batch()
    images, labels = data[:2]
    # Put cases on the first axis, then unfold each flat row into
    # (colors, height, width).  The shapes are assigned in place, as before,
    # so an incompatible memory layout raises instead of silently copying.
    images = images.T
    images.shape = (images.shape[0], dp.num_colors, dp.inner_size,
                    dp.inner_size)
    labels.shape = (-1, )
    labels = labels.astype('int')
    assert images.shape[0] == labels.shape[0]

    # Shuffle images and labels together; the fixed seed keeps the order
    # reproducible between runs.
    rng = np.random.RandomState(8)
    order = rng.permutation(images.shape[0])
    images = images[order]
    labels = labels[order]

    data = [images, labels] + list(data[2:])

    return layers, data, dp
Esempio n. 2
0
def plain_pickle(loadfile, savefile):
    """Re-save a checkpoint with its option parser flattened to a dict.

    The checkpoint is read with ``IGPUModel.load_checkpoint``; its ``'op'``
    entry is replaced by a plain ``{option name: option value}`` mapping and
    the resulting dictionary is pickled to ``savefile``.
    """
    checkpoint = IGPUModel.load_checkpoint(loadfile)

    # Swap the parser object for a name -> value mapping of its options.
    checkpoint['op'] = dict(
        (opt.name, opt.value) for opt in checkpoint['op'].get_options_list())

    with open(savefile, 'wb') as out:
        # Same protocol the negative value -1 selects.
        pickle.dump(checkpoint, out, protocol=pickle.HIGHEST_PROTOCOL)
    print("Wrote %r" % savefile)
Esempio n. 3
0
""" View the options used to create a checkpoint

    python view_options.py --load-file <checkpoint>
"""
from convnet import ConvNet
from python_util.gpumodel import IGPUModel

# Start from the ConvNet option parser.
op = ConvNet.get_options_parser()

# parse_options hands back the parser together with a second value; both are
# passed straight to the ConvNet constructor.
op, load_dic = IGPUModel.parse_options(op)
model = ConvNet(op, load_dic)

# Print the model's option values, a separator line, then the model state.
model.op.print_values()
print "========================="
model.print_model_state()
Esempio n. 4
0
def loadbcp(name, shape, params=None):
    """Return a layer's biases from a checkpoint, reshaped to ``shape``.

    ``params[0]`` is the checkpoint path and ``params[1]`` the key of the
    layer inside ``model_state['layers']``; ``shape`` is a ``(rows, cols)``
    pair.  ``name`` is accepted but not used here.
    """
    checkpoint = IGPUModel.load_checkpoint(params[0])
    n_rows, n_cols = shape
    layer = checkpoint['model_state']['layers'][params[1]]
    return layer['biases'].reshape(n_rows, n_cols)
Esempio n. 5
0
def loadwcp(name, idx, shape, params=None):
    """Return one weight matrix of a checkpoint layer, reshaped to ``shape``.

    ``params[0]`` is the checkpoint path and ``params[1]`` the key of the
    layer inside ``model_state['layers']``; ``idx`` selects the entry of
    that layer's ``'weights'`` and ``shape`` is a ``(rows, cols)`` pair.
    ``name`` is accepted but not used here.
    """
    checkpoint = IGPUModel.load_checkpoint(params[0])
    n_rows, n_cols = shape
    layer = checkpoint['model_state']['layers'][params[1]]
    return layer['weights'][idx].reshape(n_rows, n_cols)
Esempio n. 6
0
def loadbcp(name, shape, params=None):
    """Load the biases of one layer from a checkpoint as a 2-D array.

    The checkpoint at ``params[0]`` is loaded, the layer stored under key
    ``params[1]`` is looked up, and its ``'biases'`` are reshaped to the
    ``(rows, cols)`` given by ``shape``.  ``name`` is not used here.
    """
    state = IGPUModel.load_checkpoint(params[0])
    height, width = shape
    biases = state['model_state']['layers'][params[1]]['biases']
    return biases.reshape(height, width)
Esempio n. 7
0
def loadwcp(name, idx, shape, params=None):
    """Load one weight matrix of a layer from a checkpoint as a 2-D array.

    The checkpoint at ``params[0]`` is loaded, the layer stored under key
    ``params[1]`` is looked up, and entry ``idx`` of its ``'weights'`` is
    reshaped to the ``(rows, cols)`` given by ``shape``.  ``name`` is not
    used here.
    """
    state = IGPUModel.load_checkpoint(params[0])
    height, width = shape
    weights = state['model_state']['layers'][params[1]]['weights']
    return weights[idx].reshape(height, width)
Esempio n. 8
0
    if not donames:
        load_convnets = sys.argv[1:]  # just the paths
    else:
        load_convnets = sys.argv[1::2]
        legend_names = sys.argv[2::2]  # the paths and a shorter legend name

    print 'Loading convnets ' + ', '.join(load_convnets)

    for cost_idx in range(ncosts):
        pl.figure(cost_idx + basefig)
        if navg_groups > 1:
            allte = None
            alltr = None
        for cv, i in zip(load_convnets, range(len(load_convnets))):
            cp = IGPUModel.load_checkpoint(cv)
            numbatches, train_errors, test_errors = get_errors(
                cp,
                show_cost,
                cost_idx,
                norm=cost_idx == 0,
                smooth=smooth,
                interp_test=dointerp)
            #plot_errors(numbatches, train_errors, test_errors)
            tr = np.array(train_errors, dtype=np.double)
            te = np.array(test_errors, dtype=np.double)
            if navg_groups > 1:
                if allte is None:
                    allte = np.zeros(list(te.shape) + [len(load_convnets)],
                                     dtype=np.double)
                    alltr = np.zeros(list(tr.shape) + [len(load_convnets)],
Esempio n. 9
0
    def plot_predictions(self):
        epoch, batch, data = self.get_next_batch(train=False)  # get a test batch
        num_classes = self.test_data_provider.get_num_classes()
        NUM_ROWS = 2
        NUM_COLS = 4
        NUM_IMGS = NUM_ROWS * NUM_COLS if not self.save_preds else data[0].shape[1]
        NUM_TOP_CLASSES = min(num_classes, 5)  # show this many top labels
        NUM_OUTPUTS = self.model_state["layers"][self.softmax_name]["outputs"]
        PRED_IDX = 1

        label_names = [lab.split(",")[0] for lab in self.test_data_provider.batch_meta["label_names"]]
        if self.only_errors:
            preds = n.zeros((data[0].shape[1], NUM_OUTPUTS), dtype=n.single)
        else:
            preds = n.zeros((NUM_IMGS, NUM_OUTPUTS), dtype=n.single)
            # rand_idx = nr.permutation(n.r_[n.arange(1), n.where(data[1] == 552)[1], n.where(data[1] == 795)[1], n.where(data[1] == 449)[1], n.where(data[1] == 274)[1]])[:NUM_IMGS]
            rand_idx = nr.randint(0, data[0].shape[1], NUM_IMGS)
            if NUM_IMGS < data[0].shape[1]:
                data = [n.require(d[:, rand_idx], requirements="C") for d in data]
        #        data += [preds]
        # Run the model
        print [d.shape for d in data], preds.shape
        self.libmodel.startFeatureWriter(data, [preds], [self.softmax_name])
        IGPUModel.finish_batch(self)
        print preds
        data[0] = self.test_data_provider.get_plottable_data(data[0])

        if self.save_preds:
            if not gfile.Exists(self.save_preds):
                gfile.MakeDirs(self.save_preds)
            preds_thresh = preds > 0.5  # Binarize predictions
            data[0] = data[0] * 255.0
            data[0][data[0] < 0] = 0
            data[0][data[0] > 255] = 255
            data[0] = n.require(data[0], dtype=n.uint8)
            dir_name = "%s_predictions_batch_%d" % (os.path.basename(self.save_file), batch)
            tar_name = os.path.join(self.save_preds, "%s.tar" % dir_name)
            tfo = gfile.GFile(tar_name, "w")
            tf = TarFile(fileobj=tfo, mode="w")
            for img_idx in xrange(NUM_IMGS):
                img = data[0][img_idx, :, :, :]
                imsave = Image.fromarray(img)
                prefix = (
                    "CORRECT"
                    if data[1][0, img_idx] == preds_thresh[img_idx, PRED_IDX]
                    else "FALSE_POS"
                    if preds_thresh[img_idx, PRED_IDX] == 1
                    else "FALSE_NEG"
                )
                file_name = "%s_%.2f_%d_%05d_%d.png" % (
                    prefix,
                    preds[img_idx, PRED_IDX],
                    batch,
                    img_idx,
                    data[1][0, img_idx],
                )
                #                gf = gfile.GFile(file_name, "w")
                file_string = StringIO()
                imsave.save(file_string, "PNG")
                tarinf = TarInfo(os.path.join(dir_name, file_name))
                tarinf.size = file_string.tell()
                file_string.seek(0)
                tf.addfile(tarinf, file_string)
            tf.close()
            tfo.close()
            #                gf.close()
            print "Wrote %d prediction PNGs to %s" % (preds.shape[0], tar_name)
        else:
            fig = pl.figure(3, figsize=(12, 9))
            fig.text(0.4, 0.95, "%s test samples" % ("Mistaken" if self.only_errors else "Random"))
            if self.only_errors:
                # what the net got wrong
                if NUM_OUTPUTS > 1:
                    err_idx = [i for i, p in enumerate(preds.argmax(axis=1)) if p not in n.where(data[2][:, i] > 0)[0]]
                else:
                    err_idx = n.where(data[1][0, :] != preds[:, 0].T)[0]
                    print err_idx
                err_idx = r.sample(err_idx, min(len(err_idx), NUM_IMGS))
                data[0], data[1], preds = data[0][:, err_idx], data[1][:, err_idx], preds[err_idx, :]

            import matplotlib.gridspec as gridspec
            import matplotlib.colors as colors

            cconv = colors.ColorConverter()
            gs = gridspec.GridSpec(NUM_ROWS * 2, NUM_COLS, width_ratios=[1] * NUM_COLS, height_ratios=[2, 1] * NUM_ROWS)
            # print data[1]
            for row in xrange(NUM_ROWS):
                for col in xrange(NUM_COLS):
                    img_idx = row * NUM_COLS + col
                    if data[0].shape[0] <= img_idx:
                        break
                    pl.subplot(gs[(row * 2) * NUM_COLS + col])
                    # pl.subplot(NUM_ROWS*2, NUM_COLS, row * 2 * NUM_COLS + col + 1)
                    pl.xticks([])
                    pl.yticks([])
                    img = data[0][img_idx, :, :, :]
                    img = img.squeeze()
                    if len(img.shape) > 2:  # more than 2 dimensions
                        if img.shape[2] is 2:  # if two channels
                            # copy 2nd to 3rd channel for visualization
                            a1 = img
                            a2 = img[:, :, 1]
                            a2 = a2[:, :, n.newaxis]
                            img = n.concatenate((a1, a2), axis=2)
                        pl.imshow(img, interpolation="lanczos")
                    else:
                        pl.imshow(img, interpolation="lanczos", cmap=pl.gray())
                    show_title = data[1].shape[0] == 1
                    true_label = [int(data[1][0, img_idx])] if show_title else n.where(data[1][:, img_idx] == 1)[0]
                    # print true_label
                    # print preds[img_idx,:].shape
                    # print preds[img_idx,:].max()
                    true_label_names = [label_names[i] for i in true_label]
                    img_labels = sorted(zip(preds[img_idx, :], label_names), key=lambda x: x[0])[-NUM_TOP_CLASSES:]
                    # print img_labels
                    axes = pl.subplot(gs[(row * 2 + 1) * NUM_COLS + col])
                    height = 0.5
                    ylocs = n.array(range(NUM_TOP_CLASSES)) * height
                    pl.barh(
                        ylocs,
                        [l[0] for l in img_labels],
                        height=height,
                        color=["#ffaaaa" if l[1] in true_label_names else "#aaaaff" for l in img_labels],
                    )
                    # pl.title(", ".join(true_labels))
                    if show_title:
                        pl.title(", ".join(true_label_names), fontsize=15, fontweight="bold")
                    else:
                        print true_label_names
                    pl.yticks(
                        ylocs + height / 2,
                        [l[1] for l in img_labels],
                        x=1,
                        backgroundcolor=cconv.to_rgba("0.65", alpha=0.5),
                        weight="bold",
                    )
                    for line in enumerate(axes.get_yticklines()):
                        line[1].set_visible(False)
                    # pl.xticks([width], [''])
                    # pl.yticks([])
                    pl.xticks([])
                    pl.ylim(0, ylocs[-1] + height)
                    pl.xlim(0, 1)
Esempio n. 10
0
    def plot_predictions(self):
        epoch, batch, data = self.get_next_batch(
            train=False)  # get a test batch
        num_classes = self.test_data_provider.get_num_classes()
        NUM_ROWS = 2
        NUM_COLS = 4
        NUM_IMGS = NUM_ROWS * NUM_COLS if not self.save_preds else data[
            0].shape[1]
        NUM_TOP_CLASSES = min(num_classes, 5)  # show this many top labels
        NUM_OUTPUTS = self.model_state['layers'][self.softmax_name]['outputs']
        PRED_IDX = 1

        label_names = [
            lab.split(',')[0]
            for lab in self.test_data_provider.batch_meta['label_names']
        ]
        if self.only_errors:
            preds = n.zeros((data[0].shape[1], NUM_OUTPUTS), dtype=n.single)
        else:
            preds = n.zeros((NUM_IMGS, NUM_OUTPUTS), dtype=n.single)
            #rand_idx = nr.permutation(n.r_[n.arange(1), n.where(data[1] == 552)[1], n.where(data[1] == 795)[1], n.where(data[1] == 449)[1], n.where(data[1] == 274)[1]])[:NUM_IMGS]
            rand_idx = nr.randint(0, data[0].shape[1], NUM_IMGS)
            if NUM_IMGS < data[0].shape[1]:
                data = [
                    n.require(d[:, rand_idx], requirements='C') for d in data
                ]


#        data += [preds]
# Run the model
        print[d.shape for d in data], preds.shape
        self.libmodel.startFeatureWriter(data, [preds], [self.softmax_name])
        IGPUModel.finish_batch(self)
        print preds
        data[0] = self.test_data_provider.get_plottable_data(data[0])

        if self.save_preds:
            if not gfile.Exists(self.save_preds):
                gfile.MakeDirs(self.save_preds)
            preds_thresh = preds > 0.5  # Binarize predictions
            data[0] = data[0] * 255.0
            data[0][data[0] < 0] = 0
            data[0][data[0] > 255] = 255
            data[0] = n.require(data[0], dtype=n.uint8)
            dir_name = '%s_predictions_batch_%d' % (os.path.basename(
                self.save_file), batch)
            tar_name = os.path.join(self.save_preds, '%s.tar' % dir_name)
            tfo = gfile.GFile(tar_name, "w")
            tf = TarFile(fileobj=tfo, mode='w')
            for img_idx in xrange(NUM_IMGS):
                img = data[0][img_idx, :, :, :]
                imsave = Image.fromarray(img)
                prefix = "CORRECT" if data[1][0, img_idx] == preds_thresh[
                    img_idx, PRED_IDX] else "FALSE_POS" if preds_thresh[
                        img_idx, PRED_IDX] == 1 else "FALSE_NEG"
                file_name = "%s_%.2f_%d_%05d_%d.png" % (prefix, preds[
                    img_idx, PRED_IDX], batch, img_idx, data[1][0, img_idx])
                #                gf = gfile.GFile(file_name, "w")
                file_string = StringIO()
                imsave.save(file_string, "PNG")
                tarinf = TarInfo(os.path.join(dir_name, file_name))
                tarinf.size = file_string.tell()
                file_string.seek(0)
                tf.addfile(tarinf, file_string)
            tf.close()
            tfo.close()
            #                gf.close()
            print "Wrote %d prediction PNGs to %s" % (preds.shape[0], tar_name)
        else:
            fig = pl.figure(3, figsize=(12, 9))
            fig.text(
                .4, .95, '%s test samples' %
                ('Mistaken' if self.only_errors else 'Random'))
            if self.only_errors:
                # what the net got wrong
                if NUM_OUTPUTS > 1:
                    err_idx = [
                        i for i, p in enumerate(preds.argmax(axis=1))
                        if p not in n.where(data[2][:, i] > 0)[0]
                    ]
                else:
                    err_idx = n.where(data[1][0, :] != preds[:, 0].T)[0]
                    print err_idx
                err_idx = r.sample(err_idx, min(len(err_idx), NUM_IMGS))
                data[0], data[1], preds = data[0][:, err_idx], data[
                    1][:, err_idx], preds[err_idx, :]

            import matplotlib.gridspec as gridspec
            import matplotlib.colors as colors
            cconv = colors.ColorConverter()
            gs = gridspec.GridSpec(NUM_ROWS * 2,
                                   NUM_COLS,
                                   width_ratios=[1] * NUM_COLS,
                                   height_ratios=[2, 1] * NUM_ROWS)
            #print data[1]
            for row in xrange(NUM_ROWS):
                for col in xrange(NUM_COLS):
                    img_idx = row * NUM_COLS + col
                    if data[0].shape[0] <= img_idx:
                        break
                    pl.subplot(gs[(row * 2) * NUM_COLS + col])
                    #pl.subplot(NUM_ROWS*2, NUM_COLS, row * 2 * NUM_COLS + col + 1)
                    pl.xticks([])
                    pl.yticks([])
                    img = data[0][img_idx, :, :, :]
                    pl.imshow(img, interpolation='lanczos')
                    show_title = data[1].shape[0] == 1
                    true_label = [int(data[1][0, img_idx])
                                  ] if show_title else n.where(
                                      data[1][:, img_idx] == 1)[0]
                    #print true_label
                    #print preds[img_idx,:].shape
                    #print preds[img_idx,:].max()
                    true_label_names = [label_names[i] for i in true_label]
                    img_labels = sorted(zip(preds[img_idx, :], label_names),
                                        key=lambda x: x[0])[-NUM_TOP_CLASSES:]
                    #print img_labels
                    axes = pl.subplot(gs[(row * 2 + 1) * NUM_COLS + col])
                    height = 0.5
                    ylocs = n.array(range(NUM_TOP_CLASSES)) * height
                    pl.barh(ylocs, [l[0] for l in img_labels], height=height, \
                            color=['#ffaaaa' if l[1] in true_label_names else '#aaaaff' for l in img_labels])
                    #pl.title(", ".join(true_labels))
                    if show_title:
                        pl.title(", ".join(true_label_names),
                                 fontsize=15,
                                 fontweight='bold')
                    else:
                        print true_label_names
                    pl.yticks(ylocs + height / 2, [l[1] for l in img_labels],
                              x=1,
                              backgroundcolor=cconv.to_rgba('0.65', alpha=0.5),
                              weight='bold')
                    for line in enumerate(axes.get_yticklines()):
                        line[1].set_visible(False)
                    #pl.xticks([width], [''])
                    #pl.yticks([])
                    pl.xticks([])
                    pl.ylim(0, ylocs[-1] + height)
                    pl.xlim(0, 1)
Esempio n. 11
0
    avg_groups = ['warp','nowarp']
    navg_groups = len(avg_groups)
    
    if not donames:
        load_convnets = sys.argv[1:]    # just the paths
    else:
        load_convnets = sys.argv[1::2]; legend_names = sys.argv[2::2]    # the paths and a shorter legend name
    
    print 'Loading convnets ' + ', '.join(load_convnets)
    
    for cost_idx in range(ncosts):
        pl.figure(cost_idx + basefig)
        if navg_groups > 1: allte = None; alltr = None
        for cv,i in zip(load_convnets, range(len(load_convnets))):
            cp = IGPUModel.load_checkpoint(cv)
            numbatches, train_errors, test_errors = get_errors(cp, show_cost,cost_idx, norm=cost_idx == 0,
                smooth=smooth, interp_test=dointerp)
            #plot_errors(numbatches, train_errors, test_errors)
            tr = np.array(train_errors, dtype=np.double); te = np.array(test_errors, dtype=np.double)
            if navg_groups > 1:
                if allte is None:
                    allte = np.zeros(list(te.shape)+[len(load_convnets)], dtype=np.double)
                    alltr = np.zeros(list(tr.shape)+[len(load_convnets)], dtype=np.double)
                allte[:,i] = te; alltr[:,i] = tr
            else:
                if dolog: tr = np.log10(tr); te = np.log10(te)
                pl.plot(tr, te, label=os.path.basename(legend_names[i] if donames else os.path.normpath(cv)))
        if navg_groups > 1:
            groups = np.arange(len(load_convnets),dtype=np.int64)
            groups = groups.reshape([navg_groups,len(load_convnets)//navg_groups])
Esempio n. 12
0
    def plot_restored_img(self):
        epoch, batch, data = self.get_next_batch(train=False)
        NUM_ROWS = 8
        NUM_COLS = 4
        NUM_PATS = 3
        NUM_IMGS = NUM_ROWS * NUM_COLS if not self.save_preds else data[
            0].shape[1]
        NUM_OUTPUTS = self.model_state['layers'][self.output_name]['outputs']
        preds = n.zeros((NUM_IMGS, NUM_OUTPUTS), dtype=n.single)
        #        rand_idx = nr.randint(0, data[0].shape[1], NUM_IMGS)
        rand_idx = range(NUM_IMGS)
        if NUM_IMGS < data[0].shape[1]:
            data = [n.require(d[:, rand_idx], requirements='C') for d in data]

        print[d.shape for d in data], preds.shape
        self.libmodel.startFeatureWriter(data, [preds], [self.output_name])
        cost_outputs = IGPUModel.finish_batch(self)
        costs, num_cases = cost_outputs[0], cost_outputs[1]
        self.print_costs(cost_outputs)
        #        print preds
        #        data[0] = self.test_data_provider.get_plottable_data(data[0])
        #        preds = self.test_data_provider.get_plottable_data(preds.T)
        patch_size = int(n.sqrt(data[0].shape[0]))
        noisy_size = int(n.sqrt(data[1].shape[0]))

        if self.save_preds:
            pass
        else:
            fig = pl.figure(3, figsize=(12, 9))
            fig.text(.4, .95, '%s %d test samples' % ('First', NUM_IMGS))

            import matplotlib.gridspec as gridspec
            import matplotlib.colors as colors
            cconv = colors.ColorConverter()
            gs = gridspec.GridSpec(NUM_ROWS,
                                   NUM_COLS * NUM_PATS,
                                   width_ratios=[1] * NUM_COLS * NUM_PATS,
                                   height_ratios=[1] * NUM_ROWS)
            #print data[1]
            for row in xrange(NUM_ROWS):
                for col in xrange(NUM_COLS):
                    img_idx = row * NUM_COLS + col
                    if data[0].shape[0] <= img_idx:
                        break
                    pl.subplot(gs[(row * NUM_COLS + col) * NUM_PATS])
                    #pl.subplot(NUM_ROWS*2, NUM_COLS, row * 2 * NUM_COLS + col + 1)
                    pl.xticks([])
                    pl.yticks([])
                    img = data[0][:, img_idx]
                    pl.imshow(self.plottable_data(img, patch_size, patch_size),
                              interpolation='nearest')

                    pl.subplot(gs[(row * NUM_COLS + col) * NUM_PATS + 1])
                    pl.xticks([])
                    pl.yticks([])
                    r_img = preds[img_idx, :]
                    mse = r_img - img
                    mse = mse * mse
                    mse = mse.mean()
                    psnr = -10 * n.log10(mse)
                    pl.title(('%.4f' % psnr), fontsize=10)
                    pl.imshow(self.plottable_data(r_img, patch_size,
                                                  patch_size),
                              interpolation='nearest')

                    pl.subplot(gs[(row * NUM_COLS + col) * NUM_PATS + 2])
                    pl.xticks([])
                    pl.yticks([])
                    n_img = data[1][:, img_idx]
                    border = (noisy_size - patch_size) / 2
                    nc_img = n_img.reshape(noisy_size,
                                           noisy_size)[border:-border,
                                                       border:-border]
                    nc_img = nc_img.reshape(img.shape)
                    n_mse = nc_img - img
                    n_mse = n_mse * n_mse
                    n_mse = n_mse.mean()
                    n_psnr = -10 * n.log10(n_mse)
                    pl.title(('%.4f' % n_psnr), fontsize=10)
                    pl.imshow(self.plottable_data(n_img, noisy_size),
                              interpolation='nearest')
            pl.savefig('restored.png')
Esempio n. 13
0
 def plot_predictions(self):
     epoch, batch, data = self.get_next_batch(train=False) # get a test batch
     num_classes = self.test_data_provider.get_num_classes()
     NUM_ROWS = 2
     NUM_COLS = 4
     NUM_IMGS = NUM_ROWS * NUM_COLS if not self.save_preds else data[0].shape[1]
     NUM_TOP_CLASSES = min(num_classes, 5) # show this many top labels
     NUM_OUTPUTS = self.model_state['layers'][self.softmax_name]['outputs']
     PRED_IDX = 1
     if self.save_preds:
         # print preds
         if not os.path.exists(self.save_preds):   
             os.makedirs(self.save_preds)
         # we process all the batches
         while True:
             # some constant
             NUM_IMGS = data[0].shape[1]
             NUM_TOP_CLASSES = min(num_classes, 5) # show this many top labels
             NUM_OUTPUTS = self.model_state['layers'][self.softmax_name]['outputs']
             PRED_IDX = 1
             preds = n.zeros((NUM_IMGS, NUM_OUTPUTS), dtype=n.single)
             # we only save the prediction result instead of the image
             dir_name = 'predictions_batch_%d' % (batch)
             tar_name = os.path.join(self.save_preds, dir_name)
             # Run the model
             print  [d.shape for d in data], preds.shape
             self.libmodel.startFeatureWriter(data, [preds], [self.softmax_name])
             # in the mean while, prepare to load the next batch of data
             new_epoch, batch, new_data = self.get_next_batch(train=False)
             IGPUModel.finish_batch(self)
             # swap the data
             # concatenate the pred into the groud true label
             preds=n.concatenate((n.transpose(data[1]), preds), axis=1);
             tfo = open(tar_name, "wb");
             cPickle.dump(preds, tfo)
             tfo.close()
             print "Wrote %d prediction PNGs to %s" % (preds.shape[0], tar_name)
             if new_epoch!=epoch:
                 print "All batches process"
                 break;
             data=new_data
     else:
         label_names = [lab.split(',')[0] for lab in self.test_data_provider.batch_meta['label_names']]
         if self.only_errors:
             preds = n.zeros((data[0].shape[1], NUM_OUTPUTS), dtype=n.single)
         else:
             preds = n.zeros((NUM_IMGS, NUM_OUTPUTS), dtype=n.single)
             rand_idx = nr.randint(0, data[0].shape[1], NUM_IMGS)
             if NUM_IMGS < data[0].shape[1]:
                 data = [n.require(d[:,rand_idx], requirements='C') for d in data]
         # Run the model
         print  [d.shape for d in data], preds.shape
         self.libmodel.startFeatureWriter(data, [preds], [self.softmax_name])
         IGPUModel.finish_batch(self)
         # print preds
         data[0] = self.test_data_provider.get_plottable_data(data[0])
         fig = pl.figure(3, figsize=(12,9))
         fig.text(.4, .95, '%s test samples' % ('Mistaken' if self.only_errors else 'Random'))
         if self.only_errors:
             # what the net got wrong
             if NUM_OUTPUTS > 1:
                 err_idx = [i for i,p in enumerate(preds.argmax(axis=1)) if p not in n.where(data[2][:,i] > 0)[0]]
             else:
                 err_idx = n.where(data[1][0,:] != preds[:,0].T)[0]
                 print err_idx
             err_idx = r.sample(err_idx, min(len(err_idx), NUM_IMGS))
             data[0], data[1], preds = data[0][:,err_idx], data[1][:,err_idx], preds[err_idx,:]
             
         
         import matplotlib.gridspec as gridspec
         import matplotlib.colors as colors
         cconv = colors.ColorConverter()
         gs = gridspec.GridSpec(NUM_ROWS*2, NUM_COLS,
                                width_ratios=[1]*NUM_COLS, height_ratios=[2,1]*NUM_ROWS )
         #print data[1]
         for row in xrange(NUM_ROWS):
             for col in xrange(NUM_COLS):
                 img_idx = row * NUM_COLS + col
                 if data[0].shape[0] <= img_idx:
                     break
                 pl.subplot(gs[(row * 2) * NUM_COLS + col])
                 #pl.subplot(NUM_ROWS*2, NUM_COLS, row * 2 * NUM_COLS + col + 1)
                 pl.xticks([])
                 pl.yticks([])
                 img = data[0][img_idx,:,:,:]
                 pl.imshow(img, interpolation='lanczos')
                 show_title = data[1].shape[0] == 1
                 true_label = [int(data[1][0,img_idx])] if show_title else n.where(data[1][:,img_idx]==1)[0]
                 #print true_label
                 #print preds[img_idx,:].shape
                 #print preds[img_idx,:].max()
                 true_label_names = [label_names[i] for i in true_label]
                 img_labels = sorted(zip(preds[img_idx,:], label_names), key=lambda x: x[0])[-NUM_TOP_CLASSES:]
                 #print img_labels
                 axes = pl.subplot(gs[(row * 2 + 1) * NUM_COLS + col])
                 height = 0.5
                 ylocs = n.array(range(NUM_TOP_CLASSES))*height
                 pl.barh(ylocs, [l[0] for l in img_labels], height=height, \
                         color=['#ffaaaa' if l[1] in true_label_names else '#aaaaff' for l in img_labels])
                 #pl.title(", ".join(true_labels))
                 if show_title:
                     pl.title(", ".join(true_label_names), fontsize=15, fontweight='bold')
                 else:
                     print true_label_names
                 pl.yticks(ylocs + height/2, [l[1] for l in img_labels], x=1, backgroundcolor=cconv.to_rgba('0.65', alpha=0.5), weight='bold')
                 for line in enumerate(axes.get_yticklines()): 
                     line[1].set_visible(False) 
                 #pl.xticks([width], [''])
                 #pl.yticks([])
                 pl.xticks([])
                 pl.ylim(0, ylocs[-1] + height)
                 pl.xlim(0, 1)