Example #1
import caffe

# DataProvider is assumed to be supplied by the surrounding project; it is not
# part of this listing.


class Net(object):
    def __init__(self,
                 net,
                 snapshot_prefix,
                 dp_params,
                 preproc,
                 iter=None,
                 test=False):
        if iter is not None:
            self.N = caffe.Net(net,
                               snapshot_prefix + str(iter) + '.caffemodel',
                               caffe.TRAIN)
            self.iter = iter
        else:
            if test:
                self.N = caffe.Net(net, snapshot_prefix, caffe.TEST)
#                self.N = caffe.Net(net, caffe.TRAIN)
            else:
                self.N = caffe.Net(net, snapshot_prefix, caffe.TRAIN)
            self.iter = 0

        # Data provider
        self.dp = DataProvider(dp_params, preproc)
        self.bsize = self.dp.batch_size

        self.prevs = {}
        self.test = test

    def forward(self):
        ind = self.iter * self.bsize
        if self.test:
            _data, _labels = self.dp.get_batch_test(ind)
        else:
            _data, _labels = self.dp.get_batch(ind)

        # set data as input
        self.N.blobs['data'].data[...] = _data
        for label_key in _labels.keys():
            self.N.blobs[label_key].data[...] = _labels[label_key].reshape(
                self.N.blobs[label_key].data.shape)

        # Forward
#        t0 = time.time()
        out = self.N.forward()
        self.iter += 1
        return out

    def backward(self):
        self.N.backward()
        # update filter parameters
        #        t0 = time.time()
        for layer_name, lay in zip(self.N._layer_names, self.N.layers):
            for blobind, blob in enumerate(lay.blobs):
                diff = blob.diff[:]
                key = (layer_name, blobind)
                if key in self.prevs:
                    previous_change = self.prevs[key]
                else:
                    previous_change = 0

                # fixed SGD hyperparameters (hard-coded in this example)
                lr = 0.01
                wd = 0.0005
                momentum = 0.9
                # second parameter blob (typically the bias): double the
                # learning rate and skip weight decay
                if blobind == 1:
                    lr = 2 * lr
                    wd = 0
                # BatchNorm blobs hold running statistics; never update them
                if lay.type == "BatchNorm":
                    lr = 0
                    wd = 0
                # momentum SGD step with L2 weight decay
                change = (momentum * previous_change
                          - lr * diff
                          - lr * wd * blob.data[:])

                blob.data[:] += change
                self.prevs[key] = change

    def empty_diff(self):
        for layer_name, lay in zip(self.N._layer_names, self.N.layers):
            for blobind, blob in enumerate(lay.blobs):
                blob.diff[:] = 0
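A minimal usage sketch for the Net wrapper above. The prototxt path, snapshot prefix, and DataProvider dictionaries are placeholders rather than values from the original listing; the sketch assumes a snapshot saved as 'snapshots/net_10000.caffemodel'.

# Hypothetical driver loop for the Net class; all paths and params are assumptions.
model = Net('train_val.prototxt',            # network definition prototxt
            'snapshots/net_',                # snapshot prefix; iter is appended below
            dp_params={'data_dir': 'data/'}, # forwarded to DataProvider
            preproc={'mean': 128},           # forwarded to DataProvider
            iter=10000)                      # loads snapshots/net_10000.caffemodel

for _ in range(100):
    model.empty_diff()        # clear accumulated parameter gradients
    out = model.forward()     # fetch a batch and run the forward pass
    model.backward()          # backprop and apply the momentum/weight-decay update
    print out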
Example #2
import ast
import os
import subprocess
import sys
import time

import numpy as np

import caffe

# DataProvider is assumed to be supplied by the surrounding project.


def main(parser):
    # retrieve options
    (options, args) = parser.parse_args()
    net = options.net
    display = options.display
    base_lr = options.learning_rate
    lr_policy = options.lr_policy
    gamma = options.gamma
    stepsize = options.stepsize
    max_iter = options.max_iter
    momentum = options.momentum
    iter_size = options.iter_size
    weight_decay = options.weight_decay
    iter_snapshot = options.iter_resume
    snapshot = options.snapshot
    snapshot_prefix = options.snapshot_prefix
    test_interval = options.test_interval
    command_for_test = options.command_for_test
    weights = options.weights
    do_test = options.do_test
    dp_params = ast.literal_eval(options.dp_params)
    preproc = ast.literal_eval(options.preproc)

    # Data provider
    dp = DataProvider(dp_params, preproc)
    isize = dp.batch_size
    iter_size = dp.iter_size    # note: overrides options.iter_size parsed above
    bsize = isize * iter_size

    # Load the net
    if weights is not None:
        print "Loading weights from " + weights
        N = caffe.Net(net, weights, caffe.TRAIN)
    else:
        if iter_snapshot == 0:
            N = caffe.Net(net, caffe.TRAIN)
        else:
            print "Resuming training from " + snapshot_prefix + '_iter_' + str(iter_snapshot) + '.caffemodel'
            N = caffe.Net(net, snapshot_prefix + '_iter_' +  str(iter_snapshot) + '.caffemodel', caffe.TRAIN)

    # Save the net if training gets stopped
    import signal
    def sigint_handler(signum, frame):
        # snapshot the current weights (uses the loop variable `iter`) and exit
        print 'Training paused...'
        print 'Snapshotting for iteration ' + str(iter)
        N.save(snapshot_prefix + '_iter_' + str(iter) + '.caffemodel')
        sys.exit(0)
    signal.signal(signal.SIGINT, sigint_handler)

    # save weights before training
    if iter_snapshot == 0:
        print 'Snapshotting for initial weights'
        N.save(snapshot_prefix + '_iter_initial.caffemodel')

    # Start training
    loss = { key: 0 for key in N.outputs }
    prevs = {}
    for iter in range(iter_snapshot, max_iter):
        # clear param diffs
        for layer_name, lay in zip(N._layer_names, N.layers):
            for blobind, blob in enumerate(lay.blobs):
                blob.diff[:] = 0
        # clear loss  
        for k in loss.keys():
            loss[k] = 0

        # update weights at every <iter_size> iterations 
        for i in range(iter_size):
            # load data batch
            t0 = time.time()
            ind = iter * bsize + i * isize
            _data, _labels = dp.get_batch(ind)
            load_time = time.time() - t0

            # set data as input
            N.blobs['data'].data[...] = _data
            for label_key in _labels.keys():
                N.blobs[label_key].data[...] = _labels[label_key].reshape(N.blobs[label_key].data.shape)

            # Forward
#            t0 = time.time()
            out = N.forward()
#            forward_time = time.time()
            # Backward
            N.backward()
#            backward_time = time.time() - forward_time
#            forward_time -= t0

            for k in out.keys():
                loss[k] += np.copy(out[k])

        # learning rate schedule
        if lr_policy == "step":
            learning_rate = base_lr * (gamma ** (iter / stepsize))
        else:
            learning_rate = base_lr    # fall back to a constant learning rate

        # print output loss
        print "Iteration", iter, "(lr: {0:.4f})".format( learning_rate )
        for k in np.sort(out.keys()):
            loss[k] /= iter_size
            print "Iteration", iter, ", ", k, "=", loss[k]
        sys.stdout.flush()

        # update filter parameters
#        t0 = time.time()
        for layer_name, lay in zip(N._layer_names, N.layers):
            for blobind, blob in enumerate(lay.blobs):
                diff = blob.diff[:]
                key = (layer_name, blobind)
                if key in prevs:
                    previous_change = prevs[key]
                else:
                    previous_change = 0.

                lr = learning_rate
                wd = weight_decay
#                if ("SpatialCorr" in lay.type):
#                    wd = 1000. * wd
                # second parameter blob (typically the bias): double lr, no decay
                if (blobind == 1) and ("SpatialCorr" not in lay.type):
                    lr = 2. * lr
                    wd = 0.
                # BatchNorm blobs hold running statistics; never update them
                if lay.type == "BatchNorm":
                    lr = 0.
                    wd = 0.
                if "SpatialCorr" in lay.type:
                    if "corronly" in net:
                        wd = 0.
                    else:
                        wd = 1.    # for training pos
                # momentum SGD step; the gradient is averaged over iter_size sub-batches
                change = (momentum * previous_change
                          - lr * diff / iter_size
                          - lr * wd * blob.data[:])

                blob.data[:] += change
                prevs[key] = change

#                # for debugging
#                if layer_name=='loss_scl':
#                    print "pos:", np.reshape(blob.data[:],-1)
#                    print "change:", np.reshape(change,-1)
#                    print "prev_change:", np.reshape(previous_change,-1)
#                    print "wd:", wd
#                    print "diff:", np.reshape(diff,-1)
#        update_time = time.time() - t0
#        print "loading: {0:.2f}, forward: {1:.2f}, backward: {2:.2f}, update: {3:.2f}".format(load_time, forward_time, backward_time, update_time)
        
        # save weights 
        if iter % snapshot == 0:
            print 'Snapshotting for iteration ' + str(iter)
            N.save(snapshot_prefix + '_iter_' + str(iter) + '.caffemodel')

        # test on validation set
        if do_test and (iter % test_interval == 0):
            print 'Test for iteration ' + str(iter)
            weights_test = snapshot_prefix + '_iter_' + str(iter) + '.caffemodel'
            if not os.path.exists(weights_test):
                N.save(weights_test)
            command_for_test_iter = command_for_test + ' test.sh ' + net + ' ' + weights_test + ' "' + options.dp_params + '" "' + options.preproc + '"'
            print command_for_test_iter
            subprocess.call(command_for_test_iter, shell=True)
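For completeness, a sketch of the OptionParser wiring that main() expects. Only the dest attribute names are taken from the code above; the flag spellings and default values are assumptions.

if __name__ == '__main__':
    from optparse import OptionParser
    parser = OptionParser()
    # dest names mirror the attributes read in main(); flags/defaults are guesses
    parser.add_option('--net', dest='net')
    parser.add_option('--weights', dest='weights', default=None)
    parser.add_option('--display', dest='display', type='int', default=20)
    parser.add_option('--learning_rate', dest='learning_rate', type='float', default=0.01)
    parser.add_option('--lr_policy', dest='lr_policy', default='step')
    parser.add_option('--gamma', dest='gamma', type='float', default=0.1)
    parser.add_option('--stepsize', dest='stepsize', type='int', default=10000)
    parser.add_option('--max_iter', dest='max_iter', type='int', default=100000)
    parser.add_option('--momentum', dest='momentum', type='float', default=0.9)
    parser.add_option('--iter_size', dest='iter_size', type='int', default=1)
    parser.add_option('--weight_decay', dest='weight_decay', type='float', default=0.0005)
    parser.add_option('--iter_resume', dest='iter_resume', type='int', default=0)
    parser.add_option('--snapshot', dest='snapshot', type='int', default=5000)
    parser.add_option('--snapshot_prefix', dest='snapshot_prefix')
    parser.add_option('--test_interval', dest='test_interval', type='int', default=5000)
    parser.add_option('--command_for_test', dest='command_for_test', default='bash')
    parser.add_option('--do_test', dest='do_test', action='store_true', default=False)
    parser.add_option('--dp_params', dest='dp_params', default='{}')
    parser.add_option('--preproc', dest='preproc', default='{}')
    main(parser)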