def semseg_train(network, exp_dir, exp_prefix, args):
    if args.cpu:
        caffe.set_mode_cpu()
    else:
        caffe.set_mode_gpu()
        caffe.set_device(0)

    if exp_prefix:
        snapshot_prefix = os.path.join(exp_dir, exp_prefix)
        exp_prefix = exp_prefix + '_'
    else:
        snapshot_prefix = os.path.join(exp_dir, 'snapshot')
        exp_prefix = ''

    if network == 'seq':
        batch_norm = True
        conv_weight_filler = 'xavier'
        # Generate the train/test prototxt from the architecture string, and a
        # matching deploy prototxt alongside it.
        network = models.semseg_seq(arch_str=args.arch,
                                    skip_str=args.skips,
                                    dataset=args.dataset,
                                    dataset_params=args.dataset_params,
                                    feat_dims_str=args.feat,
                                    lattice_dims_str=args.lattice,
                                    sample_size=args.sample_size,
                                    batch_size=args.batch_size,
                                    batchnorm=batch_norm,
                                    conv_weight_filler=conv_weight_filler,
                                    save_path=os.path.join(exp_dir, exp_prefix + 'net.prototxt'))
        models.semseg_seq(deploy=True,
                          arch_str=args.arch,
                          skip_str=args.skips,
                          dataset=args.dataset,
                          dataset_params=args.dataset_params,
                          feat_dims_str=args.feat,
                          lattice_dims_str=args.lattice,
                          sample_size=args.sample_size,
                          batchnorm=batch_norm,
                          save_path=os.path.join(exp_dir, exp_prefix + 'net_deploy.prototxt'))
    else:
        assert network.endswith('.prototxt'), 'Please provide a valid prototxt file'
        print('Using network defined at {}'.format(network))

    random_seed = 0
    debug_info = False
    solver = create_solver.standard_solver(network,
                                           network,
                                           snapshot_prefix,
                                           base_lr=args.base_lr,
                                           gamma=args.lr_decay,
                                           stepsize=args.stepsize,
                                           test_iter=args.test_iter,
                                           test_interval=args.test_interval,
                                           max_iter=args.num_iter,
                                           snapshot=args.snapshot_interval,
                                           solver_type=args.solver_type,
                                           weight_decay=args.weight_decay,
                                           iter_size=args.iter_size,
                                           debug_info=debug_info,
                                           random_seed=random_seed,
                                           save_path=os.path.join(exp_dir, exp_prefix + 'solver.prototxt'))
    solver = caffe.get_solver(solver)

    # Resume from a solver state, or initialize weights from a caffemodel.
    if args.init_model:
        if args.init_model.endswith('.solverstate'):
            solver.restore(args.init_model)
        elif args.init_model.endswith('.caffemodel'):
            solver.net.copy_from(args.init_model)
        else:
            raise ValueError('Invalid file: {}'.format(args.init_model))

    solver.solve()
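
# Hypothetical usage sketch (not part of this script): it only illustrates the
# argparse-style namespace fields that semseg_train() reads above. The concrete
# values are illustrative assumptions, not defaults of this repository.
#
#   from argparse import Namespace
#   args = Namespace(cpu=False, arch='64_128_256', skips='', dataset='facade',
#                    dataset_params=None, feat='x_y_z', lattice='x_y_z',
#                    sample_size=1000, batch_size=32, base_lr=0.01, lr_decay=0.1,
#                    stepsize=10000, test_iter=10, test_interval=1000,
#                    num_iter=50000, snapshot_interval=5000, solver_type='ADAM',
#                    weight_decay=0.0005, iter_size=1, init_model=None)
#   semseg_train('seq', '/tmp/exp_semseg', 'run1', args)
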
def partseg_train_single_model(network, exp_dir, args):
    if args.cpu:
        caffe.set_mode_cpu()
    else:
        caffe.set_mode_gpu()
        caffe.set_device(0)

    if network == 'seq':
        batch_norm = True
        conv_weight_filler = 'xavier'
        # Single network trained jointly over all categories, with optional
        # per-class renormalization.
        network = models.partseg_seq_combined_categories(arch_str=args.arch,
                                                         skip_str=args.skips,
                                                         renorm_class=args.renorm_class,
                                                         dataset=args.dataset,
                                                         dataset_params=args.dataset_params,
                                                         feat_dims_str=args.feat,
                                                         lattice_dims_str=args.lattice,
                                                         sample_size=args.sample_size,
                                                         batch_size=args.batch_size,
                                                         batchnorm=batch_norm,
                                                         conv_weight_filler=conv_weight_filler,
                                                         save_path=os.path.join(exp_dir, 'net.prototxt'))
        models.partseg_seq_combined_categories(deploy=True,
                                               arch_str=args.arch,
                                               skip_str=args.skips,
                                               renorm_class=args.renorm_class,
                                               dataset=args.dataset,
                                               dataset_params=args.dataset_params,
                                               feat_dims_str=args.feat,
                                               lattice_dims_str=args.lattice,
                                               sample_size=args.sample_size,
                                               batchnorm=batch_norm,
                                               save_path=os.path.join(exp_dir, 'net_deploy.prototxt'))
    else:
        assert network.endswith('.prototxt'), 'Please provide a valid prototxt file'
        print('Using network defined at {}'.format(network))

    random_seed = 0
    debug_info = False
    solver = create_solver.standard_solver(network,
                                           network,
                                           os.path.join(exp_dir, 'snapshot'),
                                           base_lr=args.base_lr,
                                           gamma=args.lr_decay,
                                           stepsize=args.stepsize,
                                           test_iter=args.test_iter,
                                           test_interval=args.test_interval,
                                           max_iter=args.num_iter,
                                           snapshot=args.snapshot_interval,
                                           solver_type=args.solver_type,
                                           weight_decay=args.weight_decay,
                                           iter_size=args.iter_size,
                                           debug_info=debug_info,
                                           random_seed=random_seed,
                                           save_path=os.path.join(exp_dir, 'solver.prototxt'))
    solver = caffe.get_solver(solver)

    # init_model may be a .caffemodel path or a snapshot iteration number.
    if args.init_model:
        if args.init_model.endswith('.caffemodel'):
            solver.net.copy_from(args.init_model)
        else:
            solver.net.copy_from(os.path.join(exp_dir, 'snapshot_iter_{}.caffemodel'.format(args.init_model)))

    # init_state may be a .solverstate path or a snapshot iteration number.
    if args.init_state:
        if args.init_state.endswith('.solverstate'):
            solver.restore(args.init_state)
        else:
            solver.restore(os.path.join(exp_dir, 'snapshot_iter_{}.solverstate'.format(args.init_state)))

    solver.solve()
def partseg_train(network, exp_dir, category, args):

    def solve2(solver, args, uid, rank):
        # Per-process worker: binds one solver to one GPU and joins the NCCL ring.
        if args.cpu:
            caffe.set_mode_cpu()
        else:
            caffe.set_mode_gpu()
            caffe.set_device(args.gpus[rank])
            caffe.set_solver_count(len(args.gpus))
            caffe.set_solver_rank(rank)
            caffe.set_multiprocess(True)
        solver = caffe.get_solver(solver)

        # init_model / init_state may be explicit paths or snapshot iteration numbers.
        if args.init_model:
            if args.init_model.endswith('.caffemodel'):
                solver.net.copy_from(args.init_model)
            else:
                solver.net.copy_from(os.path.join(exp_dir, '{}_iter_{}.caffemodel'.format(category, args.init_model)))
        if args.init_state:
            if args.init_state.endswith('.solverstate'):
                solver.restore(args.init_state)
            else:
                solver.restore(os.path.join(exp_dir, '{}_iter_{}.solverstate'.format(category, args.init_state)))

        nccl = caffe.NCCL(solver, uid)
        nccl.bcast()
        if solver.param.layer_wise_reduce:
            solver.net.after_backward(nccl)
        solver.step(solver.param.max_iter)

    if network == 'seq':
        batch_norm = True
        conv_weight_filler = 'xavier'
        # Per-category network: train/test and deploy prototxts are written per category.
        network = models.partseg_seq(arch_str=args.arch,
                                     skip_str=args.skips,
                                     dataset=args.dataset,
                                     dataset_params=args.dataset_params,
                                     category=category,
                                     feat_dims_str=args.feat,
                                     lattice_dims_str=args.lattice,
                                     sample_size=args.sample_size,
                                     batch_size=args.batch_size,
                                     batchnorm=batch_norm,
                                     conv_weight_filler=conv_weight_filler,
                                     save_path=os.path.join(exp_dir, category + '_net.prototxt'))
        models.partseg_seq(deploy=True,
                           arch_str=args.arch,
                           skip_str=args.skips,
                           dataset=args.dataset,
                           dataset_params=args.dataset_params,
                           category=category,
                           feat_dims_str=args.feat,
                           lattice_dims_str=args.lattice,
                           sample_size=args.sample_size,
                           batchnorm=batch_norm,
                           save_path=os.path.join(exp_dir, category + '_net_deploy.prototxt'))
    else:
        assert network.endswith('.prototxt'), 'Please provide a valid prototxt file'
        print('Using network defined at {}'.format(network))

    random_seed = 0
    debug_info = False
    solver = create_solver.standard_solver(network,
                                           network,
                                           os.path.join(exp_dir, category) + '_' + args.prefix,
                                           base_lr=args.base_lr,
                                           gamma=args.lr_decay,
                                           stepsize=args.stepsize,
                                           test_iter=args.test_iter,
                                           test_interval=args.test_interval,
                                           max_iter=args.num_iter,
                                           snapshot=args.snapshot_interval,
                                           solver_type=args.solver_type,
                                           weight_decay=args.weight_decay,
                                           iter_size=args.iter_size,
                                           debug_info=debug_info,
                                           random_seed=random_seed,
                                           save_path=os.path.join(exp_dir, category + '_solver.prototxt'))

    # Multi-GPU training: one solver process per GPU, gradients synchronized via NCCL.
    uid = caffe.NCCL.new_uid()
    caffe.init_log(0, True)
    caffe.log('Using devices %s' % str(args.gpus))

    procs = []
    for rank in range(len(args.gpus)):
        p = Process(target=solve2, args=(solver, args, uid, rank))
        p.daemon = True
        p.start()
        procs.append(p)
    for p in procs:
        p.join()
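
# Hypothetical launch sketch (an assumption, not code from this repository): trains
# one per-category model on two GPUs via the NCCL path above. Category names,
# dataset name, and all argument values are illustrative only; the namespace fields
# match what partseg_train() and solve2() read.
#
#   from argparse import Namespace
#   args = Namespace(cpu=False, gpus=[0, 1], prefix='snapshot', arch='64_128_256',
#                    skips='', dataset='shapenet', dataset_params=None,
#                    feat='x_y_z', lattice='x_y_z', sample_size=3000,
#                    batch_size=32, base_lr=0.01, lr_decay=0.1, stepsize=10000,
#                    test_iter=10, test_interval=1000, num_iter=20000,
#                    snapshot_interval=2000, solver_type='ADAM',
#                    weight_decay=0.0005, iter_size=1,
#                    init_model=None, init_state=None)
#   for category in ('Airplane', 'Chair'):
#       partseg_train('seq', '/tmp/exp_partseg', category, args)
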