Ejemplo n.º 1
0
    def test_solver_debug(self):
        """Debug mode sets test_iter to '1'; adding onlytrain drops test_iter."""
        solver = bct.CaffeSolver(debug=True)
        self.assertEqual(solver.sp['test_iter'], '1')

        solver = bct.CaffeSolver(debug=True, onlytrain=True)
        # assertNotIn reports the container contents on failure, unlike
        # assertFalse(x in y), which only reports "True is not false".
        self.assertNotIn('test_iter', solver.sp)
Ejemplo n.º 2
0
    def test_solver_copy(self):
        """Round-trip a solver through write/add_from_file and compare params.

        Checked for the default solver and for onlytrain mode. The key sets
        are compared as a whole before the values: zipping two sorted key
        lists stops at the shorter one, which would let a missing or extra
        parameter go unnoticed.
        """
        path = osp.join(self.workdir, 'solver.prototxt')

        # First the default mode, then the onlytrain mode.
        for kwargs in ({}, {'onlytrain': True}):
            solver = bct.CaffeSolver(**kwargs)
            solver.write(path)

            solver2 = bct.CaffeSolver(empty=True)
            solver2.add_from_file(path)

            # Same set of parameters, not merely a common prefix of them.
            self.assertEqual(sorted(solver.sp), sorted(solver2.sp))
            for key in solver.sp:
                self.assertEqual(solver.sp[key], solver2.sp[key])
Ejemplo n.º 3
0
    def test_solver_onlytrain(self):
        """In onlytrain mode no solver parameter name mentions 'test'."""
        train_only = bct.CaffeSolver(onlytrain=True)
        for param_name in train_only.sp.keys():
            # str.find returns -1 exactly when the substring is absent.
            position = param_name.find('test')
            self.assertEqual(position, -1)
Ejemplo n.º 4
0
    def test_solver_empty(self):
        """A solver built with empty=True starts with no parameters at all."""
        blank_solver = bct.CaffeSolver(empty=True)
        param_count = len(blank_solver.sp.keys())
        self.assertEqual(param_count, 0)
Ejemplo n.º 5
0
    def test_workdir_setup(self):
        """End-to-end: populate a work dir, train 3 iterations, then classify.

        Writes a debug solver and a VGG-based train/test net into
        ``self.workdir``, optionally seeds it with an initial caffemodel,
        runs ``bct.run`` for 3 iterations, and checks the training log, the
        iteration-3 snapshot, and both classification entry points.
        """
        solver = bct.CaffeSolver(debug=True)
        solver.write(osp.join(self.workdir, 'solver.prototxt'))

        # NOTE(review): '../static/...' paths are relative; this presumably
        # assumes a specific working directory when the test runs — confirm.
        n = caffe.NetSpec()
        n.data, n.label = L.ImageData(transform_param=dict(crop_size=224,
                                                           mean_value=128),
                                      source='../static/imlist.txt',
                                      batch_size=50,
                                      ntop=2)
        # vgg_core presumably builds the VGG layers on top of n.data (it
        # exposes fc7 below) — verify against its definition.
        net = vgg_core(n, learn=True)

        # Two-way classifier head on fc7.
        net.score = L.InnerProduct(net.fc7,
                                   num_output=2,
                                   param=[
                                       dict(lr_mult=5, decay_mult=1),
                                       dict(lr_mult=10, decay_mult=0)
                                   ])
        net.loss = L.SoftmaxWithLoss(net.score, n.label)

        # The same definition is used for both the train and the test net,
        # so serialize it once.
        proto_text = str(net.to_proto())
        for fname in ('trainnet.prototxt', 'testnet.prototxt'):
            with open(osp.join(self.workdir, fname), 'w') as w:
                w.write(proto_text)

        # Seed with pretrained weights only when the template is available.
        caffefile = '/runs/templates/VGG_ILSVRC_16_layers_initial.caffemodel'
        if osp.isfile(caffefile):
            shutil.copyfile(caffefile,
                            osp.join(self.workdir, 'initial.caffemodel'))

        bct.run(self.workdir, nbr_iters=3)

        self.assertTrue(osp.isfile(osp.join(self.workdir, 'train.log')))
        self.assertTrue(
            osp.isfile(osp.join(self.workdir, 'snapshot_iter_3.caffemodel')))

        caffemodel, iter_ = bct.find_latest_caffemodel(self.workdir)

        self.assertEqual(iter_, 3)
        net = bct.load_model(self.workdir,
                             caffemodel,
                             gpuid=0,
                             net_prototxt='testnet.prototxt',
                             phase=caffe.TEST)
        estlist, scorelist = bct.classify_from_datalayer(net,
                                                         n_testinstances=3,
                                                         batch_size=50,
                                                         scorelayer='score')

        # One estimate and one score vector per instance; 2 scores each,
        # matching num_output=2 of the 'score' layer.
        self.assertEqual(len(scorelist), 3)
        self.assertEqual(len(estlist), 3)
        self.assertEqual(len(scorelist[0]), 2)

        # Close the image file deterministically instead of leaking the
        # handle; np.asarray loads the pixel data before the file is closed.
        with Image.open('../static/bbc.jpg') as im:
            img = np.asarray(im)[:224, :224, :]
        imglist = [img] * 6

        # NOTE(review): the trailing 4 is presumably a batch size — confirm.
        estlist, scorelist = bct.classify_from_imlist(imglist, net,
                                                      bct.Transformer(), 4)

        self.assertEqual(len(scorelist), 6)
        self.assertEqual(len(estlist), 6)
        self.assertEqual(len(scorelist[0]), 2)