def get_data(name, data_dir, meta_dir, gpu_nums):
    """Build the Camvid DataFlow pipeline for one split.

    Args:
        name: split name; any name containing ``'train'`` enables the
            training-time augmentations and multi-GPU batching.
        data_dir: root directory of the Camvid images.
        meta_dir: directory holding the split list files.
        gpu_nums: number of GPUs; the training batch size is
            ``args.batch_size * gpu_nums``.

    Returns:
        A DataFlow yielding batched ``(image, label)`` pairs: batched and
        ZMQ-prefetched for training, batch size 1 otherwise.
    """
    is_train = 'train' in name  # idiomatic: was `True if ... else False`
    ds = Camvid(data_dir, meta_dir, name, shuffle=True)
    if is_train:  # merged the two identical guards from the original
        ds = MapData(ds, RandomResize)
        shape_aug = [
            RandomCropWithPadding(args.crop_size, IGNORE_LABEL),
            Flip(horiz=True),
        ]
    else:
        shape_aug = []
    ds = AugmentImageComponents(ds, shape_aug, (0, 1), copy=False)

    def f(ds):
        # Subtract the per-channel mean (BGR order, presumably — TODO confirm
        # against the loader). Shape (1, 1, 3) broadcasts over HWC images;
        # the original "# NCHW" comment was wrong.
        image, label = ds
        mean = np.array([104, 116, 122])
        const_arr = np.resize(mean, (1, 1, 3))
        image = image - const_arr
        return image, label

    ds = MapData(ds, f)
    if is_train:
        ds = BatchData(ds, args.batch_size * gpu_nums)
        ds = PrefetchDataZMQ(ds, 1)  # one prefetch process feeding the trainer
    else:
        ds = BatchData(ds, 1)
    return ds
def train_net(args, ctx):
    """Train the ResNet-101 DeepLab network on Camvid.

    Args:
        args: parsed CLI namespace; reads ``batch_size``, ``crop_size``,
            ``load`` (checkpoint path), ``scratch`` and ``frequent``.
        ctx: list of MXNet device contexts to train on.
    """
    logger.auto_set_dir()

    # Training symbol uses batch BN statistics; the eval symbol freezes them
    # (use_global_stats=True).
    sym_instance = resnet101_deeplab_new()
    sym = sym_instance.get_symbol(NUM_CLASSES, is_train=True,
                                  use_global_stats=False)
    eval_sym_instance = resnet101_deeplab_new()
    eval_sym = eval_sym_instance.get_symbol(NUM_CLASSES, is_train=False,
                                            use_global_stats=True)

    # setup multi-gpu
    gpu_nums = len(ctx)
    input_batch_size = args.batch_size * gpu_nums

    # Reuse gpu_nums instead of recomputing len(ctx) twice.
    train_data = get_data("train", DATA_DIR, LIST_DIR, gpu_nums)
    test_data = get_data("val", DATA_DIR, LIST_DIR, gpu_nums)

    # infer shape (per-GPU shapes, NCHW)
    data_shape_dict = {
        'data': (args.batch_size, 3, args.crop_size[0], args.crop_size[1]),
        'label': (args.batch_size, 1, args.crop_size[0], args.crop_size[1]),
    }
    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)

    # Load and initialize params. Both the resume and scratch paths load the
    # same checkpoint, so the duplicated load_init_param call was hoisted out
    # of the branches; only the starting epoch, logging and weight init differ.
    begin_epoch = 1
    if not args.scratch:
        # The epoch number is encoded in the checkpoint filename as
        # "...-<epoch>-..."; parse it lazily here so scratch mode never
        # touches it (the original parsed it unconditionally).
        begin_epoch = int(args.load.rsplit("-", 2)[1])
        logger.info('continue training from {}'.format(begin_epoch))
    else:
        logger.info(args.load)
    arg_params, aux_params = load_init_param(args.load, convert=True)
    if args.scratch:
        sym_instance.init_weights(arg_params, aux_params)

    # check parameter shapes
    sym_instance.check_parameter_shapes(arg_params, aux_params, data_shape_dict)

    data_names = ['data']
    label_names = ['label']
    mod = MutableModule(sym, data_names=data_names, label_names=label_names,
                        context=ctx, fixed_param_prefix=fixed_param_prefix)

    # metric
    fcn_loss_metric = metric.FCNLogLossMetric(args.frequent, Camvid.class_num())
    eval_metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [fcn_loss_metric]:
        eval_metrics.add(child_metric)

    # callbacks
    batch_end_callbacks = [callback.Speedometer(input_batch_size,
                                                frequent=args.frequent)]
    epoch_end_callbacks = [
        mx.callback.module_checkpoint(
            mod, os.path.join(logger.get_logger_dir(), "mxnetgo"),
            period=1, save_optimizer_states=True),
    ]

    lr_scheduler = StepScheduler(train_data.size() * EPOCH_SCALE, lr_step_list)

    # optimizer
    optimizer_params = {'momentum': 0.9,
                        'wd': 0.0005,
                        'learning_rate': 2.5e-4,
                        'lr_scheduler': lr_scheduler,
                        'rescale_grad': 1.0,
                        'clip_gradient': None}

    logger.info("epoch scale = {}".format(EPOCH_SCALE))
    mod.fit(train_data=train_data, args=args, eval_sym=eval_sym,
            eval_sym_instance=eval_sym_instance, eval_data=test_data,
            eval_metric=eval_metrics, epoch_end_callback=epoch_end_callbacks,
            batch_end_callback=batch_end_callbacks, kvstore=kvstore,
            optimizer='sgd', optimizer_params=optimizer_params,
            arg_params=arg_params, aux_params=aux_params,
            begin_epoch=begin_epoch, num_epoch=end_epoch,
            epoch_scale=EPOCH_SCALE, validation_on_last=validation_on_last)
# Environment knobs set before any MXNet GPU work: turn off cuDNN autotune
# and GPU peer-to-peer transfers. NOTE(review): presumably for determinism /
# to work around multi-GPU issues — confirm against the training setup.
os.environ['MXNET_CUDNN_AUTOTUNE_DEFAULT'] = '0'
os.environ['MXNET_ENABLE_GPU_P2P'] = '0'

# Training configuration constants.
IGNORE_LABEL = 255  # pad/ignore value used by RandomCropWithPadding in get_data
CROP_HEIGHT = 320
CROP_WIDTH = 320
tile_height = 321
tile_width = 321
batch_size = 23
EPOCH_SCALE = 90  # multiplies train_data.size() for the LR scheduler horizon
end_epoch = 9
lr_step_list = [(6, 1e-3), (9, 1e-4)]  # presumably (epoch, lr) steps for StepScheduler — verify
NUM_CLASSES = Camvid.class_num()
validation_on_last = end_epoch  # run validation only on the final epoch
kvstore = "device"
fixed_param_prefix = []  # no parameters are frozen
symbol_str = "symbol_resnet_deeplabv1"
from symbols.symbol_resnet_deeplabv1 import resnet101_deeplab_new


def parse_args():
    # CLI for the training script. NOTE: the definition continues beyond the
    # visible chunk of this file.
    parser = argparse.ArgumentParser(description='Train deeplab network')
    # training
    parser.add_argument("--gpu", default="4")
    parser.add_argument('--frequent', help='frequency of logging', default=10, type=int)
    parser.add_argument('--view', action='store_true')
    parser.add_argument("--validation", action="store_true")