Example #1
0
# Solver hyper-parameters, iteration budget and snapshot/display settings.
solver_config.gamma = 1e-4
solver_config.power = 0.75
solver_config.max_iter = 50000
solver_config.snapshot = 2000
solver_config.snapshot_prefix = 'net'
solver_config.display = 1

# Register both the training and the testing device.
# pygt.caffe.enumerate_devices(False)
pygt.caffe.set_devices((options.train_device, options.test_device))

# Snapshots already written under this prefix. Each entry is indexed as
# (iteration, snapshot path) below; presumably ordered oldest to newest.
solverstates = pygt.getSolverStates(solver_config.snapshot_prefix)

# First training method: skip only when a snapshot has already reached
# max_iter. Otherwise build the solver, resume from the newest snapshot
# if there is one, and train.
if not (len(solverstates) > 0
        and solverstates[-1][0] >= solver_config.max_iter):
    solver, test_net = pygt.init_solver(solver_config, options)
    if len(solverstates):
        solver.restore(solverstates[-1][1])
    pygt.train(solver, test_net, [dataset], [test_dataset], options)

# Re-read the snapshot list so the code that follows sees what training wrote.
solverstates = pygt.getSolverStates(solver_config.snapshot_prefix)

# Second training method
if (solverstates[-1][0] >= solver_config.max_iter):
    # Modify some solver options
    solver_config.max_iter = int(3e5)
    solver_config.train_net = 'net_train_malis.prototxt'
    options.loss_function = 'malis'
    # Initialize and restore solver
    solver, test_net = pygt.init_solver(solver_config, options)
    if (len(solverstates) > 0):
Example #2
0
# Solver hyper-parameters: Adam optimizer with an 'inv' learning-rate policy.
solver_config.weight_decay = 0.000005
solver_config.lr_policy = 'inv'
solver_config.gamma = 0.0001
solver_config.power = 0.75
solver_config.max_iter = 8000
solver_config.snapshot = 2000
solver_config.snapshot_prefix = 'net'
solver_config.type = 'Adam'
solver_config.display = 1

# Set devices: only the training device is registered here.
# pygt.caffe.enumerate_devices(False)
pygt.caffe.set_devices((options.train_device,))

# Snapshots already written under this prefix. Each entry is indexed as
# (iteration, snapshot path) below; presumably ordered oldest to newest.
solverstates = pygt.getSolverStates(solver_config.snapshot_prefix)

# First training method: train from scratch, or resume from the newest
# snapshot while it is still short of max_iter.
if len(solverstates) == 0 or solverstates[-1][0] < solver_config.max_iter:
    solver, test_net = pygt.init_solver(solver_config, options)
    if len(solverstates) > 0:
        solver.restore(solverstates[-1][1])
    pygt.train(solver, test_net, datasets, [], options)






Example #3
0
    def train(args):
        """Train the network in two stages: Euclidean loss, then MALIS loss.

        Loads the train/test datasets, configures an Adam solver, selects the
        training device and then:

        1. Trains with the 'euclid' stage and loss until ``max_iter``
           (100000) iterations. It resumes from the newest 'net' snapshot if
           one exists and skips the stage if that snapshot already reached
           ``max_iter``.
        2. If the newest snapshot is at or beyond that iteration count, it
           switches the train stage and loss to 'malis', raises ``max_iter``
           to 300000 and continues training from that snapshot.

        Args:
            args: parsed command-line options. This function reads
                ``data_path``, ``seg_path``, ``data_name``, ``seg_name``,
                ``augment``, ``transform``, ``train_device`` and
                ``test_device``.

        Raises:
            RuntimeError: if no solver snapshot exists when the MALIS stage
                is checked, so there is nothing to resume from.
        """
        print('training...')
        train_dataset, test_dataset = Data.get(args.data_path,
                                               args.seg_path,
                                               args.data_name,
                                               args.seg_name,
                                               augment=(args.augment == 1),
                                               transform=(args.transform == 1))

        # Set solver options
        print('Initializing solver...')
        solver_config = pygt.caffe.SolverParameter()
        solver_config.train_net = 'net.prototxt'

        # Adam optimizer with an 'inv' learning-rate policy.
        solver_config.type = 'Adam'
        solver_config.base_lr = 1e-4
        solver_config.momentum = 0.99
        solver_config.momentum2 = 0.999
        solver_config.delta = 1e-8
        solver_config.weight_decay = 0.000005
        solver_config.lr_policy = 'inv'
        solver_config.gamma = 0.0001
        solver_config.power = 0.75

        # Iteration budget for the first (Euclidean) stage; the MALIS stage
        # below raises it to 300000.
        solver_config.max_iter = 100000
        solver_config.snapshot = 2000
        solver_config.snapshot_prefix = 'net'
        solver_config.display = 0

        # Set devices
        print('Setting devices...')
        print(tuple(set((args.train_device, args.test_device))))
        pygt.caffe.enumerate_devices(False)
        # Only the training device is registered with set_devices.
        # NOTE(review): the device id is cast with int() here but stored
        # uncast in options below; confirm which type pygt expects.
        pygt.caffe.set_devices((int(args.train_device), ))

        print('devices set...')
        options = TrainOptions()
        options.train_device = args.train_device
        options.test_device = args.test_device

        # First training method: Euclidean loss. Each solverstates entry is
        # indexed as (iteration, snapshot path), inferred from the comparison
        # with max_iter and the restore() call below.
        solver_config.train_state.add_stage('euclid')
        solverstates = pygt.getSolverStates(solver_config.snapshot_prefix)

        print('solver_config.max_iter:', solver_config.max_iter)
        if (len(solverstates) == 0
                or solverstates[-1][0] < solver_config.max_iter):
            options.loss_function = 'euclid'
            solver, test_net = pygt.init_solver(solver_config, options)
            if len(solverstates) > 0:
                print('restoring...', solverstates[-1][1])
                solver.restore(solverstates[-1][1])

            print('euclidean training...')
            pygt.train(solver, test_net, train_dataset, test_dataset, options)

        print('Second training method')
        # Second training method: MALIS loss, resumed from the newest snapshot.
        solver_config.train_state.set_stage(0, 'malis')
        solverstates = pygt.getSolverStates(solver_config.snapshot_prefix)
        print('solver_config.max_iter:', solver_config.max_iter)
        # Check for snapshots before indexing so an empty list fails with a
        # clear message instead of a bare IndexError.
        if len(solverstates) == 0:
            raise RuntimeError(
                'No solver snapshot with prefix %r found; cannot start the '
                'MALIS training stage.' % solver_config.snapshot_prefix)
        print('solverstates[-1][0]:', solverstates[-1][0])
        if solverstates[-1][0] >= solver_config.max_iter:
            # Modify some solver options
            solver_config.max_iter = 300000
            options.loss_function = 'malis'
            # Initialize the solver and restore it from the newest snapshot,
            # which is guaranteed to exist by the check above.
            solver, test_net = pygt.init_solver(solver_config, options)
            solver.restore(solverstates[-1][1])

            print('malis training...')
            pygt.train(solver, test_net, train_dataset, test_dataset, options)