def train_net(args, ctx, pretrained, pretrained_base, pretrained_ec, epoch,
              prefix, begin_epoch, end_epoch, lr, lr_step):
    """Build the training symbol, load data and weights, and run training.

    Args:
        args: parsed command-line arguments; ``args.cfg`` and
            ``args.frequent`` are read here.
        ctx: list of device contexts. Its length is used as the device
            count and the list is handed to the module as ``context``.
        pretrained: checkpoint prefix loaded first when not resuming.
        pretrained_base: checkpoint prefix loaded second; its parameters
            overwrite same-named entries loaded from ``pretrained``.
        pretrained_ec: checkpoint prefix loaded last, with
            ``argprefix=config.TRAIN.arg_prefix``; its parameters overwrite
            same-named entries from the two earlier loads.
        epoch: checkpoint epoch used for all three pretrained loads.
        prefix: checkpoint name prefix; it is joined onto the output
            directory returned by ``create_logger``.
        begin_epoch: epoch to start from. When ``config.TRAIN.RESUME`` is
            set, parameters and optimizer states of this epoch are loaded.
        end_epoch: passed to ``fit`` as ``num_epoch``.
        lr: base learning rate.
        lr_step: comma-separated epoch numbers used as learning-rate steps
            (factor 0.1). Steps at or before ``begin_epoch`` are applied to
            the starting rate immediately.
    """
    logger, final_output_path = create_logger(config.output_path, args.cfg,
                                              config.dataset.image_set)
    prefix = os.path.join(final_output_path, prefix)

    # load symbol (a copy of the symbol file is kept next to the outputs)
    shutil.copy2(os.path.join(curr_path, 'symbols', config.symbol + '.py'),
                 final_output_path)
    # NOTE(review): eval() on a config value executes arbitrary code; this is
    # only acceptable while the config file is trusted input.
    sym_instance = eval(config.symbol + '.' + config.symbol)()
    sym = sym_instance.get_train_symbol(config)

    # setup multi-gpu
    num_ctx = len(ctx)
    input_batch_size = config.TRAIN.BATCH_IMAGES * num_ctx

    # print config
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # load dataset and prepare imdb for training
    image_sets = config.dataset.image_set.split('+')
    segdbs = [
        load_gt_segdb(config.dataset.dataset, image_set,
                      config.dataset.root_path, config.dataset.dataset_path,
                      result_path=final_output_path, flip=config.TRAIN.FLIP)
        for image_set in image_sets
    ]
    segdb = merge_segdb(segdbs)

    # load training data
    train_data = TrainDataLoader(sym, segdb, config,
                                 batch_size=input_batch_size,
                                 crop_height=config.TRAIN.CROP_HEIGHT,
                                 crop_width=config.TRAIN.CROP_WIDTH,
                                 shuffle=config.TRAIN.SHUFFLE, ctx=ctx)

    # infer max shape: largest value of each scale component over all scales
    max_scale_0 = max(v[0] for v in config.SCALES)
    max_scale_1 = max(v[1] for v in config.SCALES)
    max_data_shape = [
        ('data', (config.TRAIN.BATCH_IMAGES, 3, max_scale_0, max_scale_1)),
        ('data_ref',
         (config.TRAIN.KEY_INTERVAL - 1, 3, max_scale_0, max_scale_1)),
        ('eq_flag', (1, )),
    ]
    max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape)
    # Single-argument print calls produce the same text on Python 2 and 3.
    print('providing maximum shape {} {}'.format(max_data_shape,
                                                 max_label_shape))

    data_shape_dict = dict(train_data.provide_data_single +
                           train_data.provide_label_single)
    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)

    # load and initialize params
    if config.TRAIN.RESUME:
        print('continue training from ' + str(begin_epoch))
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        print(pretrained)
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        # Later checkpoints win on name clashes because of dict.update().
        arg_params_base, aux_params_base = load_param(
            pretrained_base, epoch, convert=True)
        arg_params.update(arg_params_base)
        aux_params.update(aux_params_base)
        arg_params_ec, aux_params_ec = load_param(
            pretrained_ec, epoch, convert=True,
            argprefix=config.TRAIN.arg_prefix)
        arg_params.update(arg_params_ec)
        aux_params.update(aux_params_ec)
        sym_instance.init_weight(config, arg_params, aux_params)

    # check parameter shapes
    sym_instance.check_parameter_shapes(arg_params, aux_params,
                                        data_shape_dict)

    # create solver
    fixed_param_prefix = config.network.FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data_single]
    label_names = [k[0] for k in train_data.provide_label_single]
    mod = MutableModule(
        sym,
        data_names=data_names,
        label_names=label_names,
        logger=logger,
        context=ctx,
        max_data_shapes=[max_data_shape for _ in range(num_ctx)],
        max_label_shapes=[max_label_shape for _ in range(num_ctx)],
        fixed_param_prefix=fixed_param_prefix)

    if config.TRAIN.RESUME:
        # Point the module at the optimizer states saved with the checkpoint.
        mod._preload_opt_states = '%s-%04d.states' % (prefix, begin_epoch)

    # decide training params
    # metric
    fcn_loss_metric = metric.FCNLogLossMetric(config.default.frequent * num_ctx)
    eval_metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [fcn_loss_metric]:
        eval_metrics.add(child_metric)

    # callback
    batch_end_callback = callback.Speedometer(train_data.batch_size,
                                              frequent=args.frequent)
    epoch_end_callback = mx.callback.module_checkpoint(
        mod, prefix, period=1, save_optimizer_states=True)

    # decide learning rate
    base_lr = lr
    lr_factor = 0.1
    # The loop variable is deliberately not called `epoch`: on Python 2 a list
    # comprehension leaks its variable and would overwrite the parameter.
    lr_epoch = [float(step_epoch) for step_epoch in lr_step.split(',')]
    lr_epoch_diff = [
        step_epoch - begin_epoch
        for step_epoch in lr_epoch if step_epoch > begin_epoch
    ]
    # Apply the decay for every step that lies at or before begin_epoch.
    lr = base_lr * (lr_factor ** (len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [
        int(step_epoch * len(segdb) / num_ctx) for step_epoch in lr_epoch_diff
    ]
    print('lr {} lr_epoch_diff {} lr_iters {}'.format(lr, lr_epoch_diff,
                                                      lr_iters))
    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor,
                                              config.TRAIN.warmup,
                                              config.TRAIN.warmup_lr,
                                              config.TRAIN.warmup_step)

    # optimizer
    optimizer_params = {
        'momentum': config.TRAIN.momentum,
        'wd': config.TRAIN.wd,
        'learning_rate': lr,
        'lr_scheduler': lr_scheduler,
        'rescale_grad': 1.0,
        'clip_gradient': None,
    }

    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)

    # train
    mod.fit(train_data,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            kvstore=config.default.kvstore,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            arg_params=arg_params,
            aux_params=aux_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch)
def train_net(args, ctx, pretrained, epoch, prefix, begin_epoch, end_epoch,
              lr, lr_step):
    """Main train function for segmentation.

    Args:
        args: parsed command-line arguments (``args.cfg`` and
            ``args.frequent`` are read here)
        ctx: list of GPU contexts; its length is used as the device count
        pretrained: pretrained file path
        epoch: pretrained checkpoint epoch
        prefix: model save name prefix
        begin_epoch: epoch at which training starts
        end_epoch: epoch at which training ends
        lr: base learning rate
        lr_step: comma-separated epoch numbers at which to decay the
            learning rate
    """
    ##########################################
    # Step 1. Create logger and set up the save prefix
    ##########################################
    logger, final_output_path = create_logger(config.output_path, args.cfg,
                                              config.dataset.image_set)
    prefix = os.path.join(final_output_path, prefix)

    ##########################################
    # Step 2. Copy the symbols and load the symbol to build network
    ##########################################
    shutil.copy2(os.path.join(curr_path, 'symbols', config.symbol + '.py'),
                 final_output_path)
    # NOTE(review): eval() on a config value executes arbitrary code; this is
    # only acceptable while the config file is trusted input.
    sym_instance = eval(config.symbol + '.' + config.symbol)()
    sym = sym_instance.get_symbol(config, is_train=True)

    ##########################################
    # Step 3. Setup multi-gpu and batch size
    ##########################################
    num_ctx = len(ctx)
    input_batch_size = config.TRAIN.BATCH_IMAGES * num_ctx

    # print config
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    ############################################
    # Step 4. Load dataset and prepare imdb for training
    ############################################
    image_sets = config.dataset.image_set.split('+')
    segdbs = [
        load_gt_segdb(config.dataset.dataset, image_set,
                      config.dataset.root_path, config.dataset.dataset_path,
                      result_path=final_output_path, flip=config.TRAIN.FLIP)
        for image_set in image_sets
    ]
    segdb = merge_segdb(segdbs)

    ############################################
    # Step 5. Set dataloader and set the data shape
    ############################################
    train_data = TrainDataLoader(sym, segdb, config,
                                 batch_size=input_batch_size,
                                 crop_height=config.TRAIN.CROP_HEIGHT,
                                 crop_width=config.TRAIN.CROP_WIDTH,
                                 shuffle=config.TRAIN.SHUFFLE, ctx=ctx)

    # infer max shape: the crop size is the only scale considered here
    max_scale = [(config.TRAIN.CROP_HEIGHT, config.TRAIN.CROP_WIDTH)]
    max_height = max(v[0] for v in max_scale)
    max_width = max(v[1] for v in max_scale)
    max_data_shape = [
        ('data', (config.TRAIN.BATCH_IMAGES, 3, max_height, max_width)),
    ]
    max_label_shape = [
        ('label', (config.TRAIN.BATCH_IMAGES, 1, max_height, max_width)),
    ]
    print('providing maximum shape', max_data_shape, max_label_shape)

    # infer shape
    data_shape_dict = dict(train_data.provide_data_single +
                           train_data.provide_label_single)
    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)

    ##############################################
    # Step 6. Load and initialize params
    ##############################################
    if config.TRAIN.RESUME:
        print('continue training from ', begin_epoch)
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
        # NOTE(review): checkpoints below are written with
        # save_optimizer_states=True, but nothing in this function reloads the
        # saved optimizer states on resume -- confirm this is intended.
    else:
        print(pretrained)
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        sym_instance.init_weights(config, arg_params, aux_params)

    # check parameter shapes
    sym_instance.check_parameter_shapes(arg_params, aux_params,
                                        data_shape_dict)

    ##############################################
    # Step 7. Create solver and set metrics
    ##############################################
    fixed_param_prefix = config.network.FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data_single]
    label_names = [k[0] for k in train_data.provide_label_single]
    # range (not xrange) so the block runs on both Python 2 and Python 3.
    mod = MutableModule(
        sym,
        data_names=data_names,
        label_names=label_names,
        logger=logger,
        context=ctx,
        max_data_shapes=[max_data_shape for _ in range(num_ctx)],
        max_label_shapes=[max_label_shape for _ in range(num_ctx)],
        fixed_param_prefix=fixed_param_prefix)

    # decide training params
    # metric
    fcn_loss_metric = metric.FCNLogLossMetric(config.default.frequent * num_ctx)
    eval_metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [fcn_loss_metric]:
        eval_metrics.add(child_metric)

    ##############################################
    # Step 8. Set callback for training process
    ##############################################
    batch_end_callback = callback.Speedometer(train_data.batch_size,
                                              frequent=args.frequent)
    epoch_end_callback = mx.callback.module_checkpoint(
        mod, prefix, period=1, save_optimizer_states=True)

    ##############################################
    # Step 9. Decide learning rate and optimizers
    ##############################################
    base_lr = lr
    lr_factor = 0.1
    # The loop variable is deliberately not called `epoch`: on Python 2 a list
    # comprehension leaks its variable and would overwrite the parameter.
    lr_epoch = [float(step_epoch) for step_epoch in lr_step.split(',')]
    lr_epoch_diff = [
        step_epoch - begin_epoch
        for step_epoch in lr_epoch if step_epoch > begin_epoch
    ]
    # Apply the decay for every step that lies at or before begin_epoch.
    lr = base_lr * (lr_factor ** (len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [
        int(step_epoch * len(segdb) / num_ctx) for step_epoch in lr_epoch_diff
    ]
    print('lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters)
    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor,
                                              config.TRAIN.warmup,
                                              config.TRAIN.warmup_lr,
                                              config.TRAIN.warmup_step)

    # optimizer
    optimizer_params = {
        'momentum': config.TRAIN.momentum,
        'wd': config.TRAIN.wd,
        'learning_rate': lr,
        'lr_scheduler': lr_scheduler,
        'rescale_grad': 1.0,
        'clip_gradient': None,
    }

    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)

    ##############################################
    # Step 10. Start to train
    ##############################################
    mod.fit(train_data,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            kvstore=config.default.kvstore,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            arg_params=arg_params,
            aux_params=aux_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch)