Example no. 1
0
def semseg_train(network, exp_dir, exp_prefix, args):
    """Train a semantic-segmentation network with Caffe.

    Args:
        network: either the literal string 'seq' (the net is generated
            from ``args`` via ``models.semseg_seq``) or a path to an
            existing ``.prototxt`` network definition.
        exp_dir: directory where generated prototxt files and training
            snapshots are written.
        exp_prefix: optional filename prefix for generated files and
            snapshots; falsy values fall back to 'snapshot'/no prefix.
        args: parsed options carrying solver hyper-parameters, dataset
            settings, the ``cpu`` flag and an optional ``init_model``.

    Raises:
        ValueError: if ``args.init_model`` is neither a ``.solverstate``
            nor a ``.caffemodel`` file.
    """
    # Device selection; GPU id 0 is hard-wired when not on CPU.
    if args.cpu:
        caffe.set_mode_cpu()
    else:
        caffe.set_mode_gpu()
        caffe.set_device(0)

    if exp_prefix:
        snapshot_prefix = os.path.join(exp_dir, exp_prefix)
        exp_prefix = exp_prefix + '_'
    else:
        snapshot_prefix = os.path.join(exp_dir, 'snapshot')
        exp_prefix = ''

    if network == 'seq':
        # Generate both the train/test net and a deploy variant from the
        # same architecture description.
        batch_norm = True
        conv_weight_filler = 'xavier'
        network = models.semseg_seq(arch_str=args.arch,
                                    skip_str=args.skips,
                                    dataset=args.dataset,
                                    dataset_params=args.dataset_params,
                                    feat_dims_str=args.feat,
                                    lattice_dims_str=args.lattice,
                                    sample_size=args.sample_size,
                                    batch_size=args.batch_size,
                                    batchnorm=batch_norm,
                                    conv_weight_filler=conv_weight_filler,
                                    save_path=os.path.join(
                                        exp_dir, exp_prefix + 'net.prototxt'))

        models.semseg_seq(deploy=True,
                          arch_str=args.arch,
                          skip_str=args.skips,
                          dataset=args.dataset,
                          dataset_params=args.dataset_params,
                          feat_dims_str=args.feat,
                          lattice_dims_str=args.lattice,
                          sample_size=args.sample_size,
                          batchnorm=batch_norm,
                          save_path=os.path.join(
                              exp_dir, exp_prefix + 'net_deploy.prototxt'))
    else:
        assert network.endswith(
            '.prototxt'), 'Please provide a valid prototxt file'
        print('Using network defined at {}'.format(network))

    random_seed = 0
    debug_info = False
    # standard_solver returns the path of the solver prototxt it wrote;
    # keep it in a separate name instead of shadowing the solver object.
    solver_path = create_solver.standard_solver(network,
                                                network,
                                                snapshot_prefix,
                                                base_lr=args.base_lr,
                                                gamma=args.lr_decay,
                                                stepsize=args.stepsize,
                                                test_iter=args.test_iter,
                                                test_interval=args.test_interval,
                                                max_iter=args.num_iter,
                                                snapshot=args.snapshot_interval,
                                                solver_type=args.solver_type,
                                                weight_decay=args.weight_decay,
                                                iter_size=args.iter_size,
                                                debug_info=debug_info,
                                                random_seed=random_seed,
                                                save_path=os.path.join(
                                                    exp_dir,
                                                    exp_prefix + 'solver.prototxt'))
    solver = caffe.get_solver(solver_path)

    # Optionally resume a full solver state or just warm-start weights.
    if args.init_model:
        if args.init_model.endswith('.solverstate'):
            solver.restore(args.init_model)
        elif args.init_model.endswith('.caffemodel'):
            solver.net.copy_from(args.init_model)
        else:
            raise ValueError('Invalid file: {}'.format(args.init_model))

    # Removed leftover debugging breakpoint (import pdb; pdb.set_trace())
    # which unconditionally halted every training run.
    solver.solve()
Example no. 2
0
def partseg_train_single_model(network, exp_dir, args):
    """Train one combined-category part-segmentation model with Caffe.

    ``network`` is either the literal 'seq' (the net is generated from
    ``args`` via ``models.partseg_seq_combined_categories``) or a path
    to an existing ``.prototxt`` file. All generated files, snapshots
    and the solver prototxt are written under ``exp_dir``. Optional
    warm-start weights (``args.init_model``) and solver state
    (``args.init_state``) may be given either as explicit file paths or
    as an iteration number resolved against ``exp_dir``.
    """
    # Pick the compute device: GPU 0 unless CPU mode was requested.
    if not args.cpu:
        caffe.set_mode_gpu()
        caffe.set_device(0)
    else:
        caffe.set_mode_cpu()

    if network == 'seq':
        # Arguments shared by the train/test net and the deploy net.
        shared = dict(arch_str=args.arch,
                      skip_str=args.skips,
                      renorm_class=args.renorm_class,
                      dataset=args.dataset,
                      dataset_params=args.dataset_params,
                      feat_dims_str=args.feat,
                      lattice_dims_str=args.lattice,
                      sample_size=args.sample_size,
                      batchnorm=True)
        network = models.partseg_seq_combined_categories(
            batch_size=args.batch_size,
            conv_weight_filler='xavier',
            save_path=os.path.join(exp_dir, 'net.prototxt'),
            **shared)
        models.partseg_seq_combined_categories(
            deploy=True,
            save_path=os.path.join(exp_dir, 'net_deploy.prototxt'),
            **shared)
    else:
        assert network.endswith('.prototxt'), 'Please provide a valid prototxt file'
        print('Using network defined at {}'.format(network))

    # Write the solver prototxt, then instantiate the solver from it.
    solver_path = create_solver.standard_solver(
        network,
        network,
        os.path.join(exp_dir, 'snapshot'),
        base_lr=args.base_lr,
        gamma=args.lr_decay,
        stepsize=args.stepsize,
        test_iter=args.test_iter,
        test_interval=args.test_interval,
        max_iter=args.num_iter,
        snapshot=args.snapshot_interval,
        solver_type=args.solver_type,
        weight_decay=args.weight_decay,
        iter_size=args.iter_size,
        debug_info=False,
        random_seed=0,
        save_path=os.path.join(exp_dir, 'solver.prototxt'))
    solver = caffe.get_solver(solver_path)

    # Warm-start weights: an explicit .caffemodel path, or an iteration
    # number resolved to a snapshot file inside exp_dir.
    if args.init_model:
        weights = (args.init_model
                   if args.init_model.endswith('.caffemodel')
                   else os.path.join(exp_dir,
                                     'snapshot_iter_{}.caffemodel'.format(args.init_model)))
        solver.net.copy_from(weights)

    # Resume solver state analogously.
    if args.init_state:
        state = (args.init_state
                 if args.init_state.endswith('.solverstate')
                 else os.path.join(exp_dir,
                                   'snapshot_iter_{}.solverstate'.format(args.init_state)))
        solver.restore(state)

    solver.solve()
Example no. 3
0
def partseg_train(network, exp_dir, category, args):
    """Train a per-category part-segmentation model across multiple GPUs.

    Spawns one worker process per device in ``args.gpus``; the workers
    share gradients through Caffe's NCCL bindings. ``network`` is either
    the literal 'seq' (the net is generated from ``args`` via
    ``models.partseg_seq``) or a path to an existing ``.prototxt`` file.
    Generated prototxt files and snapshots are written under ``exp_dir``
    with a ``category``-based prefix.
    """
    def _solve_worker(solver_path, args, uid, rank):
        # Per-process device setup for multi-GPU training.
        if args.cpu:
            caffe.set_mode_cpu()
        else:
            caffe.set_mode_gpu()
        # NOTE(review): set_device is called even in CPU mode (preserved
        # from the original) — presumably harmless there; confirm.
        caffe.set_device(args.gpus[rank])
        caffe.set_solver_count(len(args.gpus))
        caffe.set_solver_rank(rank)
        caffe.set_multiprocess(True)

        solver = caffe.get_solver(solver_path)

        # Warm-start weights: explicit .caffemodel path, or an iteration
        # number resolved to a category snapshot inside exp_dir.
        if args.init_model:
            if args.init_model.endswith('.caffemodel'):
                solver.net.copy_from(args.init_model)
            else:
                solver.net.copy_from(os.path.join(exp_dir, '{}_iter_{}.caffemodel'.format(category, args.init_model)))

        # Resume solver state analogously.
        if args.init_state:
            if args.init_state.endswith('.solverstate'):
                solver.restore(args.init_state)
            else:
                solver.restore(os.path.join(exp_dir, '{}_iter_{}.solverstate'.format(category, args.init_state)))

        # NCCL: broadcast initial weights from rank 0, then (optionally)
        # hook layer-wise gradient reduction into the backward pass.
        nccl = caffe.NCCL(solver, uid)
        nccl.bcast()
        if solver.param.layer_wise_reduce:
            solver.net.after_backward(nccl)
        solver.step(solver.param.max_iter)

    if network == 'seq':
        batch_norm = True
        conv_weight_filler = 'xavier'
        network = models.partseg_seq(arch_str=args.arch,
                                     skip_str=args.skips,
                                     dataset=args.dataset,
                                     dataset_params=args.dataset_params,
                                     category=category,
                                     feat_dims_str=args.feat,
                                     lattice_dims_str=args.lattice,
                                     sample_size=args.sample_size,
                                     batch_size=args.batch_size,
                                     batchnorm=batch_norm,
                                     conv_weight_filler=conv_weight_filler,
                                     save_path=os.path.join(exp_dir, category + '_net.prototxt'))

        models.partseg_seq(deploy=True,
                           arch_str=args.arch,
                           skip_str=args.skips,
                           dataset=args.dataset,
                           dataset_params=args.dataset_params,
                           category=category,
                           feat_dims_str=args.feat,
                           lattice_dims_str=args.lattice,
                           sample_size=args.sample_size,
                           batchnorm=batch_norm,
                           save_path=os.path.join(exp_dir, category + '_net_deploy.prototxt'))
    else:
        assert network.endswith('.prototxt'), 'Please provide a valid prototxt file'
        print('Using network defined at {}'.format(network))

    random_seed = 0
    debug_info = False

    # standard_solver returns the path of the solver prototxt it wrote.
    solver_path = create_solver.standard_solver(network,
                                                network,
                                                os.path.join(exp_dir, category) + '_' + args.prefix,
                                                base_lr=args.base_lr,
                                                gamma=args.lr_decay,
                                                stepsize=args.stepsize,
                                                test_iter=args.test_iter,
                                                test_interval=args.test_interval,
                                                max_iter=args.num_iter,
                                                snapshot=args.snapshot_interval,
                                                solver_type=args.solver_type,
                                                weight_decay=args.weight_decay,
                                                iter_size=args.iter_size,
                                                debug_info=debug_info,
                                                random_seed=random_seed,
                                                save_path=os.path.join(exp_dir, category + '_solver.prototxt'))
    # One shared NCCL uid identifies this multi-process training group.
    uid = caffe.NCCL.new_uid()

    caffe.init_log(0, True)
    caffe.log('Using devices %s' % str(args.gpus))

    # Launch one daemon worker per GPU and wait for all to finish.
    procs = []
    for rank in range(len(args.gpus)):
        p = Process(target=_solve_worker,
                    args=(solver_path, args, uid, rank))
        p.daemon = True
        p.start()
        procs.append(p)
    for p in procs:
        p.join()