示例#1
0
def train_net(args, ctx, pretrained, epoch, prefix, begin_epoch, end_epoch, lr,
              lr_step):
    """Train a segmentation network end-to-end.

    Builds the training symbol named by ``config.symbol``, infers the maximum
    data/label shapes implied by the crop / scale / context / metric settings,
    loads the segdb training data (and optional validation data), restores or
    initializes parameters, and fits a ``MutableModule`` with the configured
    optimizer, learning-rate schedule and evaluation metrics.

    :param args: parsed CLI arguments (uses ``args.cfg`` and ``args.frequent``)
    :param ctx: list of MXNet device contexts; one image batch per device
    :param pretrained: checkpoint prefix to initialize weights from
    :param epoch: epoch number of the pretrained checkpoint
    :param prefix: checkpoint name prefix (joined onto the output path)
    :param begin_epoch: epoch to resume from (0 starts a fresh run)
    :param end_epoch: epoch at which training stops
    :param lr: base learning rate
    :param lr_step: comma-separated epoch numbers at which LR is decayed
    """

    logger, final_output_path, _, tensorboard_path = create_env(
        config.output_path, args.cfg, config.dataset.image_set)
    prefix = os.path.join(final_output_path, prefix)

    # print config
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    print "config.symbol", config.symbol
    # eval() instantiates the symbol class named in the config
    # (module.ClassName with identical names); config.symbol is trusted input.
    sym_instance = eval(config.symbol + '.' + config.symbol)()
    if config.TRAIN.use_dynamic:
        # dynamic mode: defer symbol construction until the data shape is known
        sym_gen = sym_instance.sym_gen(config, is_train=True)
    else:
        sym = sym_instance.get_symbol(config, is_train=True)

    # infer max shape
    # when cropping is enabled the only scale is the crop size itself;
    # otherwise all configured scales are considered
    scales = [(config.TRAIN.crop_size[0], config.TRAIN.crop_size[1])
              ] if config.TRAIN.enable_crop else config.SCALES
    label_stride = config.network.LABEL_STRIDE
    network_ratio = config.network.ratio
    if config.network.use_context:
        if config.network.use_crop_context:
            # NOTE(review): 736x736 appears to be a fixed context-window
            # resolution agreed with the data loader -- TODO confirm.
            max_data_shape = [
                ('data',
                 (config.TRAIN.BATCH_IMAGES, 3, config.TRAIN.crop_size[0],
                  config.TRAIN.crop_size[1])),
                ('origin_data', (config.TRAIN.BATCH_IMAGES, 3, 736, 736)),
                ('rois', (config.TRAIN.BATCH_IMAGES, 5))
            ]
        else:
            # context branch without crop-context: the original image is
            # rescaled by the network ratio
            max_data_shape = [
                ('data',
                 (config.TRAIN.BATCH_IMAGES, 3, config.TRAIN.crop_size[0],
                  config.TRAIN.crop_size[1])),
                ('origin_data', (config.TRAIN.BATCH_IMAGES, 3,
                                 int(config.SCALES[0][0] * network_ratio),
                                 int(config.SCALES[0][1] * network_ratio))),
                ('rois', (config.TRAIN.BATCH_IMAGES, 5))
            ]
    else:
        if config.TRAIN.enable_crop:
            max_data_shape = [('data', (config.TRAIN.BATCH_IMAGES, 3,
                                        config.TRAIN.crop_size[0],
                                        config.TRAIN.crop_size[1]))]
        else:
            # pad every scale up to a multiple of the label stride, then take
            # the per-dimension maximum over all scales
            max_data_shape = [
                ('data',
                 (config.TRAIN.BATCH_IMAGES, 3,
                  max([make_divisible(v[0], label_stride) for v in scales]),
                  max([make_divisible(v[1], label_stride) for v in scales])))
            ]

    # label shapes mirror the data shapes, downsampled by LABEL_STRIDE
    if config.network.use_mult_label:

        if config.network.use_crop_context:
            max_label_shape = [
                ('label',
                 (config.TRAIN.BATCH_IMAGES, 1,
                  make_divisible(config.TRAIN.crop_size[0], label_stride) //
                  config.network.LABEL_STRIDE,
                  make_divisible(config.TRAIN.crop_size[1], label_stride) //
                  config.network.LABEL_STRIDE)),
                ('origin_label', (config.TRAIN.BATCH_IMAGES, 1, 736, 736))
            ]

        else:
            max_label_shape = [
                ('label',
                 (config.TRAIN.BATCH_IMAGES, 1,
                  make_divisible(config.TRAIN.crop_size[0], label_stride) //
                  config.network.LABEL_STRIDE,
                  make_divisible(config.TRAIN.crop_size[1], label_stride) //
                  config.network.LABEL_STRIDE)),
                ('origin_label', (config.TRAIN.BATCH_IMAGES, 1,
                                  int(config.SCALES[0][0] * network_ratio),
                                  int(config.SCALES[0][1] * network_ratio)))
            ]
    elif config.network.use_metric:
        # metric-learning labels: one 9-channel metric label per scale;
        # the letter suffixes disambiguate blob names unless the scale list
        # is the default [1, 2, 4]
        scale_list = config.network.scale_list
        scale_name = ['a', 'b', 'c']
        if config.network.scale_list == [1, 2, 4]:
            scale_name = ['', '', '']

        if config.TRAIN.enable_crop:
            if config.TRAIN.use_mult_metric:
                max_label_shape = [
                    ('label', (config.TRAIN.BATCH_IMAGES, 1,
                               config.TRAIN.crop_size[0] // label_stride,
                               config.TRAIN.crop_size[1] // label_stride)),
                    ('metric_label_' + str(scale_list[0]) + scale_name[0],
                     (config.TRAIN.BATCH_IMAGES, 9, 1,
                      config.TRAIN.crop_size[0] // label_stride,
                      config.TRAIN.crop_size[1] // label_stride)),
                    ('metric_label_' + str(scale_list[1]) + scale_name[1],
                     (config.TRAIN.BATCH_IMAGES, 9, 1,
                      config.TRAIN.crop_size[0] // label_stride,
                      config.TRAIN.crop_size[1] // label_stride)),
                    ('metric_label_' + str(scale_list[2]) + scale_name[2],
                     (config.TRAIN.BATCH_IMAGES, 9, 1,
                      config.TRAIN.crop_size[0] // label_stride,
                      config.TRAIN.crop_size[1] // label_stride))
                ]

            else:
                max_label_shape = [
                    ('label', (config.TRAIN.BATCH_IMAGES, 1,
                               config.TRAIN.crop_size[0] // label_stride,
                               config.TRAIN.crop_size[1] // label_stride)),
                    ('metric_label',
                     (config.TRAIN.BATCH_IMAGES, 9, 1,
                      config.TRAIN.crop_size[0] // label_stride,
                      config.TRAIN.crop_size[1] // label_stride))
                ]
        else:
            if config.TRAIN.use_mult_metric:

                max_label_shape = [
                    ('label',
                     (config.TRAIN.BATCH_IMAGES, 1,
                      max([make_divisible(v[0], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE,
                      max([make_divisible(v[1], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE)),
                    ('metric_label_' + str(scale_list[0]) + scale_name[0],
                     (config.TRAIN.BATCH_IMAGES, 9, 1,
                      max([make_divisible(v[0], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE,
                      max([make_divisible(v[1], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE)),
                    ('metric_label_' + str(scale_list[1]) + scale_name[1],
                     (config.TRAIN.BATCH_IMAGES, 9, 1,
                      max([make_divisible(v[0], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE,
                      max([make_divisible(v[1], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE)),
                    ('metric_label_' + str(scale_list[2]) + scale_name[2],
                     (config.TRAIN.BATCH_IMAGES, 9, 1,
                      max([make_divisible(v[0], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE,
                      max([make_divisible(v[1], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE))
                ]

            else:
                max_label_shape = [
                    ('label',
                     (config.TRAIN.BATCH_IMAGES, 1,
                      max([make_divisible(v[0], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE,
                      max([make_divisible(v[1], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE)),
                    ('metric_label',
                     (config.TRAIN.BATCH_IMAGES, 9, 1,
                      max([make_divisible(v[0], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE,
                      max([make_divisible(v[1], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE))
                ]

    else:
        # plain single-label segmentation
        if config.TRAIN.enable_crop:
            max_label_shape = [('label',
                                (config.TRAIN.BATCH_IMAGES, 1,
                                 config.TRAIN.crop_size[0] // label_stride,
                                 config.TRAIN.crop_size[1] // label_stride))]
        else:
            max_label_shape = [
                ('label',
                 (config.TRAIN.BATCH_IMAGES, 1,
                  max([make_divisible(v[0], label_stride)
                       for v in scales]) // config.network.LABEL_STRIDE,
                  max([make_divisible(v[1], label_stride)
                       for v in scales]) // config.network.LABEL_STRIDE))
            ]

    print "max_label_shapes", max_label_shape
    if config.TRAIN.use_dynamic:
        # build the symbol for the maximum data shape in dynamic mode
        sym = sym_gen([max_data_shape])

    # setup multi-gpu: total batch size is per-device batch x device count
    input_batch_size = config.TRAIN.BATCH_IMAGES * len(ctx)
    NUM_GPUS = len(ctx)

    # load dataset and prepare imdb for training
    # '+' joins multiple image sets into one training run
    image_sets = [iset for iset in config.dataset.image_set.split('+')]
    segdbs = [
        load_gt_segdb(config.dataset.dataset,
                      image_set,
                      config.dataset.root_path,
                      config.dataset.dataset_path,
                      result_path=final_output_path,
                      flip=True) for image_set in image_sets
    ]
    segdb = merge_segdb(segdbs)

    # load training data
    train_data = TrainDataLoader(sym,
                                 segdb,
                                 config,
                                 batch_size=input_batch_size,
                                 shuffle=config.TRAIN.SHUFFLE,
                                 ctx=ctx,
                                 use_context=config.network.use_context,
                                 use_mult_label=config.network.use_mult_label,
                                 use_metric=config.network.use_metric)

    # loading val data (only when periodic evaluation is requested)
    if config.TRAIN.eval_data_frequency > 0:
        val_image_set = config.dataset.test_image_set
        val_root_path = config.dataset.root_path
        val_dataset = config.dataset.dataset
        val_dataset_path = config.dataset.dataset_path
        # eval() resolves the dataset class named in the config
        val_imdb = eval(val_dataset)(val_image_set,
                                     val_root_path,
                                     val_dataset_path,
                                     result_path=final_output_path)
        val_segdb = val_imdb.gt_segdb()

        val_data = TrainDataLoader(
            sym,
            val_segdb,
            config,
            batch_size=input_batch_size,
            shuffle=config.TRAIN.SHUFFLE,
            ctx=ctx,
            use_context=config.network.use_context,
            use_mult_label=config.network.use_mult_label,
            use_metric=config.network.use_metric)
    else:
        val_data = None

    # print sym.list_arguments()
    print 'providing maximum shape', max_data_shape, max_label_shape
    # let the loader refine the maximum shapes it will actually provide
    max_data_shape, max_label_shape = train_data.infer_shape(
        max_data_shape, max_label_shape)
    print 'providing maximum shape', max_data_shape, max_label_shape

    # infer shape
    data_shape_dict = dict(train_data.provide_data_single +
                           train_data.provide_label_single)
    if config.TRAIN.use_dynamic:
        # rebuild the symbol for the shapes the loader actually provides
        sym = sym_gen([train_data.provide_data_single])
    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)

    # sanity check: symbol argument names must be unique
    nset = set()
    for nm in sym.list_arguments():
        if nm in nset:
            raise ValueError('Duplicate names detected, %s' % str(nm))
        nset.add(nm)

    # load and initialize params
    if config.TRAIN.RESUME:
        print 'continue training from ', begin_epoch
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
        sym_instance.check_parameter_shapes(arg_params,
                                            aux_params,
                                            data_shape_dict,
                                            is_train=True)
        # also restore the optimizer state so momentum etc. carries over
        preload_opt_states = load_preload_opt_states(prefix, begin_epoch)
        # preload_opt_states = None
    else:
        print pretrained
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        preload_opt_states = None
        if not config.TRAIN.FINTUNE:
            # fresh heads: initialize weights not present in the checkpoint
            fixed_param_names = sym_instance.init_weights(
                config, arg_params, aux_params)
        sym_instance.check_parameter_shapes(arg_params,
                                            aux_params,
                                            data_shape_dict,
                                            is_train=True)

    # check parameter shapes
    # sym_instance.check_parameter_shapes(arg_params, aux_params, data_shape_dict)

    # create solver
    fixed_param_prefix = config.network.FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data_single]
    label_names = [k[0] for k in train_data.provide_label_single]

    # one max-shape entry per GPU so each executor can pre-allocate
    mod = MutableModule(
        sym,
        data_names=data_names,
        label_names=label_names,
        logger=logger,
        context=ctx,
        max_data_shapes=[max_data_shape for _ in xrange(NUM_GPUS)],
        max_label_shapes=[max_label_shape for _ in xrange(NUM_GPUS)],
        fixed_param_prefix=fixed_param_prefix)

    # metric
    imagecrossentropylossmetric = metric.ImageCrossEntropyLossMetric()
    localmetric = metric.LocalImageCrossEntropyLossMetric()
    globalmetric = metric.GlobalImageCrossEntropyLossMetric()
    pixcelAccMetric = metric.PixcelAccMetric()
    eval_metrics = mx.metric.CompositeEvalMetric()

    # choose the metric set matching the network's output heads
    if config.network.use_mult_label:
        metric_list = [
            imagecrossentropylossmetric, localmetric, globalmetric,
            pixcelAccMetric
        ]
    elif config.network.use_metric:
        if config.TRAIN.use_crl_ses:
            metric_list = [
                imagecrossentropylossmetric,
                metric.SigmoidPixcelAccMetric(1),
                metric.SigmoidPixcelAccMetric(2),
                metric.SigmoidPixcelAccMetric(3),
                metric.CenterLossMetric(4),
                metric.CenterLossMetric(5),
                metric.CenterLossMetric(6), pixcelAccMetric
            ]

        elif config.network.use_sigmoid_metric:
            if config.TRAIN.use_mult_metric:
                metric_list = [
                    imagecrossentropylossmetric,
                    metric.SigmoidPixcelAccMetric(1),
                    metric.SigmoidPixcelAccMetric(2),
                    metric.SigmoidPixcelAccMetric(3), pixcelAccMetric
                ]
            else:
                metric_list = [
                    imagecrossentropylossmetric,
                    metric.SigmoidPixcelAccMetric(), pixcelAccMetric
                ]
        else:
            if config.TRAIN.use_mult_metric:
                metric_list = [
                    imagecrossentropylossmetric,
                    metric.MetricLossMetric(1),
                    metric.MetricLossMetric(2),
                    metric.MetricLossMetric(3), pixcelAccMetric
                ]
            else:
                metric_list = [
                    imagecrossentropylossmetric,
                    metric.MetricLossMetric(1), pixcelAccMetric
                ]
    elif config.network.mult_loss:
        metric_list = [
            imagecrossentropylossmetric,
            metric.MImageCrossEntropyLossMetric(1),
            metric.MImageCrossEntropyLossMetric(2), pixcelAccMetric
        ]
    elif config.TRAIN.use_center:
        if config.TRAIN.use_one_center:
            metric_list = [
                imagecrossentropylossmetric, pixcelAccMetric,
                metric.CenterLossMetric(1)
            ]
        else:
            metric_list = [
                imagecrossentropylossmetric, pixcelAccMetric,
                metric.CenterLossMetric(1),
                metric.CenterLossMetric(2),
                metric.CenterLossMetric(3)
            ]
    else:
        metric_list = [imagecrossentropylossmetric, pixcelAccMetric]
    # rpn_eval_metric, rpn_cls_metric, rpn_bbox_metric, eval_metric, cls_metric, bbox_metric
    for child_metric in metric_list:
        eval_metrics.add(child_metric)

    # callback
    # NOTE(review): dead branch -- tensorboard callbacks are hard-disabled by
    # `if False:`; the else path always runs. Re-enable via a config flag or
    # delete the branch.
    if False:
        batch_end_callback = [
            callback.Speedometer(train_data.batch_size,
                                 frequent=args.frequent),
            callback.TensorboardCallback(tensorboard_path,
                                         prefix="train/batch")
        ]
        epoch_end_callback = mx.callback.module_checkpoint(
            mod, prefix, period=1, save_optimizer_states=True)
        shared_tensorboard = batch_end_callback[1]

        epoch_end_metric_callback = callback.TensorboardCallback(
            tensorboard_path,
            shared_tensorboard=shared_tensorboard,
            prefix="train/epoch")
        eval_end_callback = callback.TensorboardCallback(
            tensorboard_path,
            shared_tensorboard=shared_tensorboard,
            prefix="val/epoch")
        lr_callback = callback.LrCallback(
            tensorboard_path,
            shared_tensorboard=shared_tensorboard,
            prefix='train/batch')
    else:
        batch_end_callback = [
            callback.Speedometer(train_data.batch_size, frequent=args.frequent)
        ]
        epoch_end_callback = mx.callback.module_checkpoint(
            mod, prefix, period=1, save_optimizer_states=True)
        epoch_end_metric_callback = None
        eval_end_callback = None

    #decide learning rate
    base_lr = lr
    lr_factor = 0.1
    lr_epoch = [float(epoch) for epoch in lr_step.split(',')]
    # steps still ahead of the resume point
    lr_epoch_diff = [
        epoch - begin_epoch for epoch in lr_epoch if epoch > begin_epoch
    ]
    # apply the decays already passed when resuming mid-schedule
    lr = base_lr * (lr_factor**(len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [
        int(epoch * len(segdb) / input_batch_size) for epoch in lr_epoch_diff
    ]
    print 'lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters

    # NOTE(review): if config.TRAIN.lr_type is neither "MultiStage" nor
    # "MultiFactor", lr_scheduler is never bound and the optimizer setup
    # below raises NameError.
    if config.TRAIN.lr_type == "MultiStage":
        lr_scheduler = LinearWarmupMultiStageScheduler(
            lr_iters,
            lr_factor,
            config.TRAIN.warmup,
            config.TRAIN.warmup_lr,
            config.TRAIN.warmup_step,
            args.frequent,
            stop_lr=lr * 0.01)
    elif config.TRAIN.lr_type == "MultiFactor":
        lr_scheduler = LinearWarmupMultiFactorScheduler(
            lr_iters, lr_factor, config.TRAIN.warmup, config.TRAIN.warmup_lr,
            config.TRAIN.warmup_step, args.frequent)

    # NOTE(review): same hazard here -- an optimizer value other than
    # "sgd"/"adam" leaves `optimizer` unbound.
    if config.TRAIN.optimizer == "sgd":
        optimizer_params = {
            'momentum': config.TRAIN.momentum,
            'wd': config.TRAIN.wd,
            'learning_rate': lr,
            'lr_scheduler': lr_scheduler,
            'rescale_grad': 1.0,
            'clip_gradient': None
        }
        optimizer = SGD(**optimizer_params)
    elif config.TRAIN.optimizer == "adam":
        optimizer_params = {
            'learning_rate': lr,
            'lr_scheduler': lr_scheduler,
            'rescale_grad': 1.0,
            'clip_gradient': None
        }
        optimizer = Adam(**optimizer_params)
        print "optimizer adam"

    # optionally scale down the LR of parameters matching a regex pattern
    freeze_layer_pattern = config.TRAIN.FIXED_PARAMS_PATTERN

    if freeze_layer_pattern.strip():
        args_lr_mult = {}
        re_prog = re.compile(freeze_layer_pattern)
        if freeze_layer_pattern:
            fixed_param_names = [
                name for name in sym.list_arguments() if re_prog.match(name)
            ]
        print "============================"
        print "fixed_params_names:"
        print(fixed_param_names)
        for name in fixed_param_names:
            args_lr_mult[name] = config.TRAIN.FIXED_PARAMS_PATTERN_LR_MULT
        print "============================"
    else:
        args_lr_mult = {}
    optimizer.set_lr_mult(args_lr_mult)

    # data_shape_dict = dict(train_data.provide_data_single + train_data.provide_label_single)
    # if config.TRAIN.use_dynamic:
    #     sym = sym_gen([train_data.provide_data_single])
    # pprint.pprint(data_shape_dict)
    # sym_instance.infer_shape(data_shape_dict)

    # wrap loaders in a prefetching iterator so data loading overlaps compute
    if not isinstance(train_data, PrefetchingIter) and config.TRAIN.use_thread:
        train_data = PrefetchingIter(train_data)

    if val_data:
        if not isinstance(val_data, PrefetchingIter):
            val_data = PrefetchingIter(val_data)

    # Debug is a module-level flag; monitoring dumps every op's outputs
    if Debug:
        monitor = mx.monitor.Monitor(1)
    else:
        monitor = None

    # train
    mod.fit(train_data,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            kvstore=config.default.kvstore,
            eval_end_callback=eval_end_callback,
            epoch_end_metric_callback=epoch_end_metric_callback,
            optimizer=optimizer,
            eval_data=val_data,
            arg_params=arg_params,
            aux_params=aux_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            allow_missing=begin_epoch == 0,
            allow_extra=True,
            monitor=monitor,
            preload_opt_states=preload_opt_states,
            eval_data_frequency=config.TRAIN.eval_data_frequency)
def train_net(args, ctx, pretrained, epoch, prefix, begin_epoch, end_epoch, lr, lr_step):
    """Train a Faster R-CNN style detection network.

    Builds the symbol named by ``config.symbol``, loads the gt roidb for the
    configured image sets, feeds it through ``AnchorLoader``, and fits a
    ``MutableModule`` with SGD and a warmup multi-factor LR schedule.

    :param args: parsed CLI arguments (uses ``args.cfg`` and ``args.frequent``)
    :param ctx: list of MXNet device contexts; one image batch per device
    :param pretrained: checkpoint prefix to initialize weights from
    :param epoch: epoch number of the pretrained checkpoint
    :param prefix: checkpoint name prefix (joined onto the output path)
    :param begin_epoch: epoch to resume from (0 starts a fresh run)
    :param end_epoch: epoch at which training stops
    :param lr: base learning rate
    :param lr_step: comma-separated epoch numbers at which LR is decayed
    """
    # create the logger and the corresponding output path
    logger, final_output_path = create_logger(config.output_path, args.cfg, config.dataset.image_set)
    prefix = os.path.join(final_output_path, prefix)

    # load symbol (and archive the symbol source next to the run outputs)
    shutil.copy2(os.path.join(curr_path, 'symbols', config.symbol + '.py'), final_output_path)
    sym_instance = eval(config.symbol + '.' + config.symbol)()
    sym = sym_instance.get_symbol(config, is_train=True)
    # feature symbol: fetch rpn_cls_score_output from the network sym;
    # the anchor loader uses it to infer the RPN feature-map shape
    feat_sym = sym.get_internals()['rpn_cls_score_output']

    # setup multi-gpu
    # enable multi-GPU training: each card trains one batch
    batch_size = len(ctx)
    input_batch_size = config.TRAIN.BATCH_IMAGES * batch_size

    # print config
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # load dataset and prepare imdb for training
    # '+' joins multiple image sets, e.g. 2007_trainval+2012_trainval
    image_sets = [iset for iset in config.dataset.image_set.split('+')]
    # load the gt roidb given the dataset type, image set, root path and
    # dataset path; TRAIN.FLIP augments the data with horizontal flips
    roidbs = [load_gt_roidb(config.dataset.dataset, image_set, config.dataset.root_path, config.dataset.dataset_path,
                            flip=config.TRAIN.FLIP)
              for image_set in image_sets]
    # merge the individual roidbs into one
    roidb = merge_roidb(roidbs)
    # filter rois according to the filtering rules in the config
    roidb = filter_roidb(roidb, config)
    # load training data
    # AnchorLoader yields classification/regression anchor targets: it matches
    # anchors against the roidb to find positive/negative samples, given the
    # anchor scales, ratios and the RPN feature stride
    train_data = AnchorLoader(feat_sym, roidb, config, batch_size=input_batch_size, shuffle=config.TRAIN.SHUFFLE, ctx=ctx,
                              feat_stride=config.network.RPN_FEAT_STRIDE, anchor_scales=config.network.ANCHOR_SCALES,
                              anchor_ratios=config.network.ANCHOR_RATIOS, aspect_grouping=config.TRAIN.ASPECT_GROUPING)

    # infer max shape
    max_data_shape = [('data', (config.TRAIN.BATCH_IMAGES, 3, max([v[0] for v in config.SCALES]), max([v[1] for v in config.SCALES])))]
    max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape)
    # reserve room for up to 100 ground-truth boxes (x1, y1, x2, y2, cls)
    max_data_shape.append(('gt_boxes', (config.TRAIN.BATCH_IMAGES, 100, 5)))
    print('providing maximum shape', max_data_shape, max_label_shape)

    data_shape_dict = dict(train_data.provide_data_single + train_data.provide_label_single)
    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)

    # load and initialize params
    if config.TRAIN.RESUME:
        print('continue training from ', begin_epoch)
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        sym_instance.init_weight(config, arg_params, aux_params)

    # check parameter shapes
    sym_instance.check_parameter_shapes(arg_params, aux_params, data_shape_dict)

    # create solver
    fixed_param_prefix = config.network.FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data_single]
    label_names = [k[0] for k in train_data.provide_label_single]

    # one max-shape entry per device so each executor can pre-allocate
    mod = MutableModule(sym, data_names=data_names, label_names=label_names,
                        logger=logger, context=ctx, max_data_shapes=[max_data_shape for _ in range(batch_size)],
                        max_label_shapes=[max_label_shape for _ in range(batch_size)], fixed_param_prefix=fixed_param_prefix)

    if config.TRAIN.RESUME:
        # restore optimizer state (momentum etc.) from the checkpoint
        mod._preload_opt_states = '%s-%04d.states'%(prefix, begin_epoch)

    # decide training params
    # metric
    rpn_eval_metric = metric.RPNAccMetric()
    rpn_cls_metric = metric.RPNLogLossMetric()
    rpn_bbox_metric = metric.RPNL1LossMetric()
    eval_metric = metric.RCNNAccMetric(config)
    cls_metric = metric.RCNNLogLossMetric(config)
    bbox_metric = metric.RCNNL1LossMetric(config)
    eval_metrics = mx.metric.CompositeEvalMetric()
    # rpn_eval_metric, rpn_cls_metric, rpn_bbox_metric, eval_metric, cls_metric, bbox_metric
    for child_metric in [rpn_eval_metric, rpn_cls_metric, rpn_bbox_metric, eval_metric, cls_metric, bbox_metric]:
        eval_metrics.add(child_metric)
    # callback
    batch_end_callback = callback.Speedometer(train_data.batch_size, frequent=args.frequent)
    # bbox target normalization: 2 classes when class-agnostic regression
    means = np.tile(np.array(config.TRAIN.BBOX_MEANS), 2 if config.CLASS_AGNOSTIC else config.dataset.NUM_CLASSES)
    stds = np.tile(np.array(config.TRAIN.BBOX_STDS), 2 if config.CLASS_AGNOSTIC else config.dataset.NUM_CLASSES)
    epoch_end_callback = [mx.callback.module_checkpoint(mod, prefix, period=1, save_optimizer_states=True), callback.do_checkpoint(prefix, means, stds)]
    # decide learning rate
    base_lr = lr
    lr_factor = config.TRAIN.lr_factor
    lr_epoch = [float(epoch) for epoch in lr_step.split(',')]
    # keep only decay steps still ahead of the resume point, and apply the
    # decays already passed when resuming mid-schedule
    lr_epoch_diff = [epoch - begin_epoch for epoch in lr_epoch if epoch > begin_epoch]
    lr = base_lr * (lr_factor ** (len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [int(epoch * len(roidb) / batch_size) for epoch in lr_epoch_diff]
    print('lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters)
    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor, config.TRAIN.warmup, config.TRAIN.warmup_lr, config.TRAIN.warmup_step)
    # optimizer
    optimizer_params = {'momentum': config.TRAIN.momentum,
                        'wd': config.TRAIN.wd,
                        'learning_rate': lr,
                        'lr_scheduler': lr_scheduler,
                        'rescale_grad': 1.0,
                        'clip_gradient': None}

    # overlap data loading with computation
    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)

    # train
    mod.fit(train_data, eval_metric=eval_metrics, epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback, kvstore=config.default.kvstore,
            optimizer='sgd', optimizer_params=optimizer_params,
            arg_params=arg_params, aux_params=aux_params, begin_epoch=begin_epoch, num_epoch=end_epoch)
示例#3
0
def train_net(args, ctx, pretrained, epoch, prefix, begin_epoch, end_epoch, lr, lr_step):
    """Train a detection network (Faster R-CNN style, optional keypoints).

    For the ``JSONList`` dataset a timestamped checkpoint prefix and a
    dedicated file logger are created; otherwise the project logger is used.
    Loads the gt roidb, builds an (multi-)RPN anchor loader, and fits a
    ``MutableModule`` with SGD and a warmup multi-factor LR schedule.

    :param args: parsed CLI arguments (uses ``args.cfg`` and ``args.frequent``)
    :param ctx: list of MXNet device contexts
    :param pretrained: checkpoint prefix to initialize weights from
    :param epoch: epoch number of the pretrained checkpoint
    :param prefix: checkpoint name prefix (joined onto the output path)
    :param begin_epoch: epoch to resume from (0 starts a fresh run)
    :param end_epoch: epoch at which training stops
    :param lr: base learning rate
    :param lr_step: comma-separated epoch numbers at which LR is decayed
    """
    if config.dataset.dataset != 'JSONList':
        logger, final_output_path = create_logger(config.output_path, args.cfg, config.dataset.image_set)
        prefix = os.path.join(final_output_path, prefix)
        shutil.copy2(args.cfg, prefix+'.yaml')
    else:
        import datetime
        import logging
        # JSONList runs get a timestamped prefix and their own file logger
        final_output_path = config.output_path
        prefix = prefix + '_' + datetime.datetime.now().strftime("%Y-%m-%d_%H_%M_%S")
        prefix = os.path.join(final_output_path, prefix)
        shutil.copy2(args.cfg, prefix+'.yaml')
        log_file = prefix + '.log'
        head = '%(asctime)-15s %(message)s'
        logging.basicConfig(filename=log_file, format=head)
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        logger.info('prefix: %s' % prefix)
        print('prefix: %s' % prefix)

    # load symbol
    #shutil.copy2(os.path.join(curr_path, '..', 'symbols', config.symbol + '.py'), final_output_path)
    sym_instance = eval(config.symbol + '.' + config.symbol)()
    sym = sym_instance.get_symbol(config, is_train=True)

    # setup multi-gpu
    batch_size = config.TRAIN.IMAGES_PER_GPU * len(ctx)

    # print config
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # load dataset and prepare imdb for training
    image_sets = [iset for iset in config.dataset.image_set.split('+')]
    roidbs = [load_gt_roidb(config.dataset.dataset, image_set, config.dataset.root_path, config.dataset.dataset_path,
                            flip=config.TRAIN.FLIP)
              for image_set in image_sets]
    roidb = merge_roidb(roidbs)
    roidb = filter_roidb(roidb, config)
    # load training data
    if config.network.MULTI_RPN:
        # BUG FIX: was `assert Fasle` (NameError); the intent is to fail fast
        # with the 'still developing' message until this path is finished.
        assert False, 'still developing'
        num_layers = len(config.network.MULTI_RPN_STRIDES)
        rpn_syms = [sym.get_internals()['rpn%d_cls_score_output'% l] for l in range(num_layers)]
        train_data = PyramidAnchorLoader(rpn_syms, roidb, config, batch_size=batch_size, shuffle=config.TRAIN.SHUFFLE,
                                         ctx=ctx, feat_strides=config.network.MULTI_RPN_STRIDES, anchor_scales=config.network.ANCHOR_SCALES,
                                         anchor_ratios=config.network.ANCHOR_RATIOS, aspect_grouping=config.TRAIN.ASPECT_GROUPING,
                                         allowed_border=np.inf)
    else:
        # single-RPN path: the anchor loader infers the RPN feature-map shape
        # from the rpn_cls_score output
        feat_sym = sym.get_internals()['rpn_cls_score_output']
        train_data = AnchorLoader(feat_sym, roidb, config, batch_size=batch_size, shuffle=config.TRAIN.SHUFFLE, ctx=ctx,
                                  feat_stride=config.network.RPN_FEAT_STRIDE, anchor_scales=config.network.ANCHOR_SCALES,
                                  anchor_ratios=config.network.ANCHOR_RATIOS, aspect_grouping=config.TRAIN.ASPECT_GROUPING)

    # infer max shape
    data_shape_dict = dict(train_data.provide_data + train_data.provide_label)
    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)

    # load and initialize params
    if config.TRAIN.RESUME:
        print('continue training from ', begin_epoch)
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        sym_instance.init_weight(config, arg_params, aux_params)

    # check parameter shapes
    sym_instance.check_parameter_shapes(arg_params, aux_params, data_shape_dict)

    # create solver
    fixed_param_prefix = config.network.FIXED_PARAMS
    mod = MutableModule(sym,
                        train_data.data_names,
                        train_data.label_names,
                        context=ctx,
                        logger=logger,
                        fixed_param_prefix=fixed_param_prefix)

    if config.TRAIN.RESUME:
        # restore optimizer state (momentum etc.) from the checkpoint
        mod._preload_opt_states = '%s-%04d.states'%(prefix, begin_epoch)

    # decide training params
    # metric
    rpn_eval_metric = metric.RPNAccMetric()
    rpn_cls_metric = metric.RPNLogLossMetric()
    rpn_bbox_metric = metric.RPNL1LossMetric()
    eval_metric = metric.RCNNAccMetric(config)
    cls_metric = metric.RCNNLogLossMetric(config)
    bbox_metric = metric.RCNNL1LossMetric(config)
    eval_metrics = mx.metric.CompositeEvalMetric()
    # rpn_eval_metric, rpn_cls_metric, rpn_bbox_metric, eval_metric, cls_metric, bbox_metric
    for child_metric in [rpn_eval_metric, rpn_cls_metric, rpn_bbox_metric, eval_metric, cls_metric, bbox_metric]:
        eval_metrics.add(child_metric)
    if config.network.PREDICT_KEYPOINTS:
        # extra heads: keypoint classification accuracy/log-loss and L1 loss
        kps_cls_acc = metric.KeypointAccMetric(config)
        kps_cls_loss = metric.KeypointLogLossMetric(config)
        kps_pos_loss = metric.KeypointL1LossMetric(config)
        eval_metrics.add(kps_cls_acc)
        eval_metrics.add(kps_cls_loss)
        eval_metrics.add(kps_pos_loss)

    # callback
    batch_end_callback = callback.Speedometer(batch_size, frequent=args.frequent)
    # bbox target normalization: 2 classes when class-agnostic regression.
    # NOTE(review): means/stds are computed but not passed to the checkpoint
    # callback (plain mx.callback.do_checkpoint is used) -- confirm whether
    # unnormalized bbox weights in the saved params are intended.
    means = np.tile(np.array(config.TRAIN.BBOX_MEANS), 2 if config.CLASS_AGNOSTIC else config.dataset.NUM_CLASSES)
    stds = np.tile(np.array(config.TRAIN.BBOX_STDS), 2 if config.CLASS_AGNOSTIC else config.dataset.NUM_CLASSES)
    epoch_end_callback = [mx.callback.do_checkpoint(prefix)]
    # decide learning rate
    base_lr = lr
    lr_factor = config.TRAIN.lr_factor
    lr_epoch = [float(epoch) for epoch in lr_step.split(',')]
    # keep only decay steps still ahead of the resume point, and apply the
    # decays already passed when resuming mid-schedule
    lr_epoch_diff = [epoch - begin_epoch for epoch in lr_epoch if epoch > begin_epoch]
    lr = base_lr * (lr_factor ** (len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [int(epoch * len(roidb) / batch_size) for epoch in lr_epoch_diff]
    print('lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters)
    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor, config.TRAIN.warmup, config.TRAIN.warmup_lr, config.TRAIN.warmup_step)
    # optimizer
    optimizer_params = {'momentum': config.TRAIN.momentum,
                        'wd': config.TRAIN.wd,
                        'learning_rate': lr,
                        'lr_scheduler': lr_scheduler,
                        'rescale_grad': 1.0,
                        'clip_gradient': None}

    # overlap data loading with computation
    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)

    # train
    mod.fit(train_data, eval_metric=eval_metrics, epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback, kvstore=config.TRAIN.kvstore,
            optimizer='sgd', optimizer_params=optimizer_params,
            arg_params=arg_params, aux_params=aux_params, begin_epoch=begin_epoch, num_epoch=end_epoch)
    # give the prefetcher worker a moment to drain before shutting it down
    time.sleep(10)
    train_data.iters[0].terminate()
示例#4
0
def train_rcnn(cfg, dataset, image_set, root_path, dataset_path,
               frequent, kvstore, flip, shuffle, resume,
               ctx, pretrained, epoch, prefix, begin_epoch, end_epoch,
               train_shared, lr, lr_step, proposal, logger=None, output_path=None):
    """Train an R-CNN head on cached region proposals.

    Args:
        cfg: experiment configuration object (easydict-style).
        dataset, image_set, root_path, dataset_path: dataset identifiers used
            to build the roidb; `image_set` may join several sets with '+'.
        frequent: speedometer logging frequency, in batches.
        kvstore: mxnet kvstore type (e.g. 'device').
        flip: whether to augment the roidb with horizontally flipped images.
        shuffle: whether to shuffle the training roidb.
        resume: if True, reload params saved under `prefix` at `begin_epoch`.
        ctx: list of mx.context, one per GPU.
        pretrained, epoch: pretrained model prefix and checkpoint epoch.
        prefix: checkpoint save prefix.
        begin_epoch, end_epoch: training epoch range.
        train_shared: if True, freeze the params shared with the RPN.
        lr, lr_step: base learning rate and comma-separated decay epochs.
        proposal: proposal method name forwarded to load_proposal_roidb.
        logger: optional logging.Logger; a default root logger is set up
            when None.
        output_path: result path forwarded to the roidb loader.
    """
    # Fix seeds so proposal sampling / shuffling is reproducible.
    mx.random.seed(0)
    np.random.seed(0)
    # set up logger
    if not logger:
        logging.basicConfig()
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)

    # load symbol
    sym_instance = eval(cfg.symbol + '.' + cfg.symbol)()
    sym = sym_instance.get_symbol_rcnn(cfg, is_train=True)

    # setup multi-gpu: one image batch per context
    batch_size = len(ctx)
    input_batch_size = cfg.TRAIN.BATCH_IMAGES * batch_size

    # print cfg
    pprint.pprint(cfg)
    logger.info('training rcnn cfg:{}\n'.format(pprint.pformat(cfg)))

    rpn_path = cfg.dataset.proposal_cache
    # load dataset and prepare imdb for training ('+' joins multiple sets)
    image_sets = [iset for iset in image_set.split('+')]
    roidbs = [load_proposal_roidb(dataset, image_set, root_path, dataset_path,
                                  proposal=proposal, append_gt=True, flip=flip, result_path=output_path,
                                  rpn_path=rpn_path, top_roi=cfg.TRAIN.TOP_ROIS)
              for image_set in image_sets]
    roidb = merge_roidb(roidbs)
    roidb = filter_roidb(roidb, cfg)
    means, stds = add_bbox_regression_targets(roidb, cfg)

    # load training data
    train_data = ROIIter(roidb, cfg, batch_size=input_batch_size, shuffle=shuffle,
                         ctx=ctx, aspect_grouping=cfg.TRAIN.ASPECT_GROUPING)

    # infer max shape: pad each image side up to a multiple of IMAGE_STRIDE.
    # NOTE(review): when a side is already stride-aligned this still adds one
    # full extra stride of padding -- presumably harmless; verify if memory
    # is tight.
    max_height = max([v[0] for v in cfg.SCALES])
    max_width = max([v[1] for v in cfg.SCALES])
    paded_max_height = max_height + cfg.network.IMAGE_STRIDE - max_height % cfg.network.IMAGE_STRIDE
    paded_max_width = max_width + cfg.network.IMAGE_STRIDE - max_width % cfg.network.IMAGE_STRIDE

    max_data_shape = [('data', (cfg.TRAIN.BATCH_IMAGES, 3, paded_max_height, paded_max_width))]
    # infer shape
    data_shape_dict = dict(train_data.provide_data_single + train_data.provide_label_single)
    sym_instance.infer_shape(data_shape_dict)
    # print shape (use the configured logger, not the root logger)
    pprint.pprint(sym_instance.arg_shape_dict)
    logger.info(pprint.pformat(sym_instance.arg_shape_dict))

    # BATCH_ROIS == -1 means "use all cached top proposals"
    max_batch_roi = cfg.TRAIN.TOP_ROIS if cfg.TRAIN.BATCH_ROIS == -1 else cfg.TRAIN.BATCH_ROIS
    num_class = 2 if cfg.CLASS_AGNOSTIC else cfg.dataset.NUM_CLASSES
    max_label_shape = [('label', (cfg.TRAIN.BATCH_IMAGES, max_batch_roi)),
                       ('bbox_target', (cfg.TRAIN.BATCH_IMAGES, max_batch_roi, num_class * 4)),
                       ('bbox_weight', (cfg.TRAIN.BATCH_IMAGES, max_batch_roi, num_class * 4))]

    if cfg.network.USE_NONGT_INDEX:
        max_label_shape.append(('nongt_index', (2000,)))

    if cfg.network.ROIDispatch:
        # ROIs are split evenly over four dispatch branches; use integer
        # division so the shape entries stay ints under Python 3.
        for branch in range(4):
            max_data_shape.append(
                ('rois_{}'.format(branch), (cfg.TRAIN.BATCH_IMAGES, max_batch_roi // 4, 5)))
    else:
        max_data_shape.append(('rois', (cfg.TEST.PROPOSAL_POST_NMS_TOP_N + 30, 5)))

    # load and initialize params
    if resume:
        print('continue training from ', begin_epoch)
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        sym_instance.init_weight_rcnn(cfg, arg_params, aux_params)

    # check parameter shapes
    sym_instance.check_parameter_shapes(arg_params, aux_params, data_shape_dict)

    # prepare training: create solver
    data_names = [k[0] for k in train_data.provide_data_single]
    label_names = [k[0] for k in train_data.provide_label_single]
    fixed_param_prefix = cfg.network.FIXED_PARAMS_SHARED if train_shared else cfg.network.FIXED_PARAMS

    # (the ROIDispatch and non-dispatch cases previously built byte-identical
    # modules, so the duplicated branches are merged into one construction)
    mod = MutableModule(sym, data_names=data_names, label_names=label_names,
                        logger=logger, context=ctx,
                        max_data_shapes=[max_data_shape for _ in range(batch_size)],
                        max_label_shapes=[max_label_shape for _ in range(batch_size)],
                        fixed_param_prefix=fixed_param_prefix)
    if cfg.TRAIN.RESUME:
        # also restore the optimizer state saved next to the checkpoint
        mod._preload_opt_states = '%s-%04d.states' % (prefix, begin_epoch)

    # decide training params
    # metric
    eval_metric = metric.RCNNAccMetric(cfg)
    cls_metric = metric.RCNNLogLossMetric(cfg)
    bbox_metric = metric.RCNNL1LossMetric(cfg)
    eval_metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [eval_metric, cls_metric, bbox_metric]:
        eval_metrics.add(child_metric)
    if cfg.TRAIN.LEARN_NMS:
        eval_metrics.add(metric.NMSLossMetric(cfg, 'pos'))
        eval_metrics.add(metric.NMSLossMetric(cfg, 'neg'))
        eval_metrics.add(metric.NMSAccMetric(cfg))
    # callback
    batch_end_callback = callback.Speedometer(train_data.batch_size, frequent=frequent)
    epoch_end_callback = [mx.callback.module_checkpoint(mod, prefix, period=1, save_optimizer_states=True),
                          callback.do_checkpoint(prefix, means, stds)]
    # decide learning rate: decay at the configured epochs, skipping decays
    # that already happened before begin_epoch when resuming.
    base_lr = lr
    lr_factor = cfg.TRAIN.lr_factor
    lr_epoch = [float(epoch) for epoch in lr_step.split(',')]
    lr_epoch_diff = [epoch - begin_epoch for epoch in lr_epoch if epoch > begin_epoch]
    lr = base_lr * (lr_factor ** (len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [int(epoch * len(roidb) / batch_size) for epoch in lr_epoch_diff]
    print('lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters)
    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor, cfg.TRAIN.warmup, cfg.TRAIN.warmup_lr,
                                              cfg.TRAIN.warmup_step)
    # optimizer
    optimizer_params = {'momentum': cfg.TRAIN.momentum,
                        'wd': cfg.TRAIN.wd,
                        'learning_rate': lr,
                        'lr_scheduler': lr_scheduler,
                        'rescale_grad': 1.0,
                        'clip_gradient': None}

    # wrap with a prefetching iterator so data loading overlaps compute
    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)

    mod.fit(train_data, eval_metric=eval_metrics, epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback, kvstore=kvstore,
            optimizer='sgd', optimizer_params=optimizer_params,
            arg_params=arg_params, aux_params=aux_params, begin_epoch=begin_epoch, num_epoch=end_epoch)
# ----- 示例#5 (example separator from the original scrape; kept as a comment) -----
# 0
def train_net(args, ctx, pretrained, epoch, prefix, begin_epoch, end_epoch, lr,
              lr_step):
    """Train a pose-refinement network on observed/rendered image pairs.

    Args:
        args: parsed command-line arguments (uses .cfg, .vis, .temp, .frequent).
        ctx: list of mx.context, one per GPU.
        pretrained: pretrained model prefix, or the literal string 'xavier'
            to train from random (Xavier) initialization.
        epoch: epoch of the pretrained checkpoint.
        prefix: checkpoint save prefix (joined under the output path).
        begin_epoch, end_epoch: training epoch range.
        lr: base learning rate.
        lr_step: comma-separated epochs at which to decay the learning rate.

    Raises:
        ValueError: if config.TRAIN.optimizer is neither 'adam' nor 'sgd'.
    """
    new_args_name = args.cfg
    if args.vis:
        config.TRAIN.VISUALIZE = True
    logger, final_output_path = create_logger(config.output_path,
                                              new_args_name,
                                              config.dataset.image_set,
                                              args.temp)
    prefix = os.path.join(final_output_path, prefix)
    logger.info('called with args {}'.format(args))

    print(config.train_iter.SE3_PM_LOSS)
    if config.train_iter.SE3_PM_LOSS:
        print("SE3_PM_LOSS == True")
    else:
        print("SE3_PM_LOSS == False")

    if not config.network.STANDARD_FLOW_REP:
        print_and_log("[h, w] representation for flow is dep", logger)

    # load dataset and prepare imdb for training: one pairdb per
    # (dataset, image_set, class) combination
    image_sets = [iset for iset in config.dataset.image_set.split('+')]
    datasets = [dset for dset in config.dataset.dataset.split('+')]
    print("config.dataset.class_name: {}".format(config.dataset.class_name))
    print("image_sets: {}".format(image_sets))
    if datasets[0].startswith('ModelNet'):
        # ModelNet class names look like 'category/model'; only the model
        # part is appended to the image-set name
        pairdbs = [
            load_gt_pairdb(config,
                           datasets[i],
                           image_sets[i] + class_name.split('/')[-1],
                           config.dataset.root_path,
                           config.dataset.dataset_path,
                           class_name=class_name,
                           result_path=final_output_path)
            for class_name in config.dataset.class_name
            for i in range(len(image_sets))
        ]
    else:
        pairdbs = [
            load_gt_pairdb(config,
                           datasets[i],
                           image_sets[i] + class_name,
                           config.dataset.root_path,
                           config.dataset.dataset_path,
                           class_name=class_name,
                           result_path=final_output_path)
            for class_name in config.dataset.class_name
            for i in range(len(image_sets))
        ]
    pairdb = merge_pairdb(pairdbs)

    if not args.temp:
        # back up the symbol file next to the outputs; copy in-process with
        # shutil instead of shelling out to `cp` (portable, no shell, errors
        # are raised instead of silently ignored)
        import shutil
        src_file = os.path.join(curr_path, 'symbols', config.symbol + '.py')
        dst_file = os.path.join(
            final_output_path,
            '{}_{}.py'.format(config.symbol, time.strftime('%Y-%m-%d-%H-%M')))
        shutil.copy2(src_file, dst_file)

    sym_instance = eval(config.symbol + '.' + config.symbol)()
    sym = sym_instance.get_symbol(config, is_train=True)

    # setup multi-gpu: one pair batch per context
    batch_size = len(ctx)
    input_batch_size = config.TRAIN.BATCH_PAIRS * batch_size

    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # load training data
    train_data = TrainDataLoader(sym,
                                 pairdb,
                                 config,
                                 batch_size=input_batch_size,
                                 shuffle=config.TRAIN.SHUFFLE,
                                 ctx=ctx)

    train_data.get_batch_parallel()
    # largest (height, width) over all configured training scales
    max_scale = [max([v[0] for v in config.SCALES]),
                 max([v[1] for v in config.SCALES])]
    max_data_shape = [('image_observed', (config.TRAIN.BATCH_PAIRS, 3,
                                          max_scale[0], max_scale[1])),
                      ('image_rendered', (config.TRAIN.BATCH_PAIRS, 3,
                                          max_scale[0], max_scale[1])),
                      ('depth_gt_observed', (config.TRAIN.BATCH_PAIRS, 1,
                                             max_scale[0], max_scale[1])),
                      ('src_pose', (config.TRAIN.BATCH_PAIRS, 3, 4)),
                      ('tgt_pose', (config.TRAIN.BATCH_PAIRS, 3, 4))]
    if config.network.INPUT_DEPTH:
        max_data_shape.append(('depth_observed', (config.TRAIN.BATCH_PAIRS, 1,
                                                  max_scale[0], max_scale[1])))
        max_data_shape.append(('depth_rendered', (config.TRAIN.BATCH_PAIRS, 1,
                                                  max_scale[0], max_scale[1])))
    if config.network.INPUT_MASK:
        max_data_shape.append(('mask_observed', (config.TRAIN.BATCH_PAIRS, 1,
                                                 max_scale[0], max_scale[1])))
        max_data_shape.append(('mask_rendered', (config.TRAIN.BATCH_PAIRS, 1,
                                                 max_scale[0], max_scale[1])))

    # rotation is 3 numbers for Euler angles, otherwise 4 (quaternion-style)
    rot_param = 3 if config.network.ROT_TYPE == "EULER" else 4
    max_label_shape = [('rot', (config.TRAIN.BATCH_PAIRS, rot_param)),
                       ('trans', (config.TRAIN.BATCH_PAIRS, 3))]
    if config.network.PRED_FLOW:
        max_label_shape.append(('flow', (config.TRAIN.BATCH_PAIRS, 2,
                                         max_scale[0], max_scale[1])))
        max_label_shape.append(('flow_weights', (config.TRAIN.BATCH_PAIRS, 2,
                                                 max_scale[0], max_scale[1])))
    if config.train_iter.SE3_PM_LOSS:
        max_label_shape.append(
            ('point_cloud_model', (config.TRAIN.BATCH_PAIRS, 3,
                                   config.train_iter.NUM_3D_SAMPLE)))
        max_label_shape.append(
            ('point_cloud_weights', (config.TRAIN.BATCH_PAIRS, 3,
                                     config.train_iter.NUM_3D_SAMPLE)))
        max_label_shape.append(
            ('point_cloud_observed', (config.TRAIN.BATCH_PAIRS, 3,
                                      config.train_iter.NUM_3D_SAMPLE)))
    if config.network.PRED_MASK:
        max_label_shape.append(
            ('mask_gt_observed', (config.TRAIN.BATCH_PAIRS, 1, max_scale[0],
                                  max_scale[1])))

    print_and_log(
        'providing maximum shape, {}, {}'.format(max_data_shape,
                                                 max_label_shape), logger)

    # infer shape
    data_shape_dict = dict(train_data.provide_data_single +
                           train_data.provide_label_single)
    print_and_log('\ndata_shape_dict: {}\n'.format(data_shape_dict), logger)
    sym_instance.infer_shape(data_shape_dict)

    # debug dump of every internal output shape
    print('************(wg): infering shape **************')
    internals = sym.get_internals()
    _, out_shapes, _ = internals.infer_shape(**data_shape_dict)
    print(sym.list_outputs())
    shape_dict = dict(zip(internals.list_outputs(), out_shapes))
    pprint.pprint(shape_dict)

    # load and initialize params
    if config.TRAIN.RESUME:
        print('continue training from ', begin_epoch)
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    elif pretrained == 'xavier':
        # arg_params/aux_params are intentionally left undefined here; the
        # xavier fit path below initializes the module instead of loading.
        print('xavier')
    else:
        print(pretrained)
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        print('arg_params: ', arg_params.keys())
        print('aux_params: ', aux_params.keys())
        if not config.network.skip_initialize:
            sym_instance.init_weights(config, arg_params, aux_params)

    # check parameter shapes (no loaded params to check in the xavier case)
    if pretrained != 'xavier':
        sym_instance.check_parameter_shapes(arg_params, aux_params,
                                            data_shape_dict)

    # create solver
    fixed_param_prefix = config.network.FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data_single]
    label_names = [k[0] for k in train_data.provide_label_single]

    mod = MutableModule(
        sym,
        data_names=data_names,
        label_names=label_names,
        logger=logger,
        context=ctx,
        max_data_shapes=[max_data_shape for _ in range(batch_size)],
        max_label_shapes=[max_label_shape for _ in range(batch_size)],
        fixed_param_prefix=fixed_param_prefix,
        config=config)

    # decide training params
    # metrics
    eval_metrics = mx.metric.CompositeEvalMetric()

    metric_list = []
    iter_idx = 0
    if config.network.PRED_FLOW:
        metric_list.append(metric.Flow_L2LossMetric(config, iter_idx))
        metric_list.append(metric.Flow_CurLossMetric(config, iter_idx))
    if config.train_iter.SE3_DIST_LOSS:
        metric_list.append(metric.Rot_L2LossMetric(config, iter_idx))
        metric_list.append(metric.Trans_L2LossMetric(config, iter_idx))
    if config.train_iter.SE3_PM_LOSS:
        metric_list.append(metric.PointMatchingLossMetric(config, iter_idx))
    if config.network.PRED_MASK:
        metric_list.append(metric.MaskLossMetric(config, iter_idx))

    # Visualize Training Batches
    if config.TRAIN.VISUALIZE:
        metric_list.append(metric.SimpleVisualize(config))
        metric_list.append(
            metric.MinibatchVisualize(config))  # flow visualization

    for child_metric in metric_list:
        eval_metrics.add(child_metric)

    # callback
    batch_end_callback = callback.Speedometer(train_data.batch_size,
                                              frequent=args.frequent)
    epoch_end_callback = mx.callback.module_checkpoint(
        mod, prefix, period=1, save_optimizer_states=True)
    # decide learning rate: decay at the configured epochs, skipping decays
    # already passed when resuming.
    base_lr = lr
    lr_factor = 0.1
    lr_epoch = [float(epoch) for epoch in lr_step.split(',')]
    lr_epoch_diff = [
        epoch - begin_epoch for epoch in lr_epoch if epoch > begin_epoch
    ]
    lr = base_lr * (lr_factor**(len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [
        int(epoch * len(pairdb) / batch_size) for epoch in lr_epoch_diff
    ]
    print('lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters)

    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor,
                                              config.TRAIN.warmup,
                                              config.TRAIN.warmup_lr,
                                              config.TRAIN.warmup_step)

    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)

    # train: both optimizers share everything except optimizer_params, and
    # both init modes share everything except initialization kwargs.
    if config.TRAIN.optimizer == 'adam':
        # NOTE(review): the adam path does not use the warmup lr_scheduler.
        # The original code also dropped optimizer_params entirely on the
        # non-xavier adam branch (so `lr` was silently ignored); it is now
        # passed consistently.
        optimizer_params = {'learning_rate': lr}
    elif config.TRAIN.optimizer == 'sgd':
        optimizer_params = {
            'momentum': config.TRAIN.momentum,
            'wd': config.TRAIN.wd,
            'learning_rate': lr,
            'lr_scheduler': lr_scheduler,
            'rescale_grad': 1.0,
            'clip_gradient': None
        }
    else:
        # previously an unknown optimizer fell through and silently skipped
        # training altogether
        raise ValueError('unknown optimizer: {}'.format(config.TRAIN.optimizer))

    fit_kwargs = dict(eval_metric=eval_metrics,
                      epoch_end_callback=epoch_end_callback,
                      batch_end_callback=batch_end_callback,
                      kvstore=config.default.kvstore,
                      optimizer=config.TRAIN.optimizer,
                      optimizer_params=optimizer_params,
                      begin_epoch=begin_epoch,
                      num_epoch=end_epoch,
                      prefix=prefix)
    if pretrained == 'xavier':
        # random init: zero the pose-head weights, Xavier everywhere else
        init = mx.init.Mixed(['rot_weight|trans_weight', '.*'], [
            mx.init.Zero(),
            mx.init.Xavier(
                rnd_type='gaussian', factor_type="in", magnitude=2)
        ])
        mod.fit(train_data, initializer=init, force_init=True, **fit_kwargs)
    else:
        mod.fit(train_data, arg_params=arg_params, aux_params=aux_params,
                **fit_kwargs)
# ----- 示例#6 (example separator from the original scrape; kept as a comment) -----
# 0
def train_feature_distance_net(args, ctx, pretrained, pretrained_flow, epoch,
                               prefix, begin_epoch, end_epoch, lr, lr_step):
    """Train the feature-distance (feature L2 loss) network with RMSprop.

    Args:
        args: parsed command-line arguments (uses .cfg, .frequent).
        ctx: list of mx.context, one per GPU.
        pretrained: pretrained model prefix to initialize from.
        pretrained_flow: flow-network prefix; currently unused -- loading of
            the flow params is commented out below.
        epoch: epoch of the pretrained checkpoint.
        prefix: checkpoint save prefix (joined under the output path).
        begin_epoch, end_epoch: training epoch range.
        lr, lr_step: NOTE(review): currently ignored -- the RMSprop learning
            rate is hard-coded below; confirm before relying on these args.

    Raises:
        ValueError: if config.TRAIN.G_type is not 0 or 1.
    """
    # ==============prepare logger==============
    logger, final_output_path = create_logger(config.output_path, args.cfg,
                                              config.dataset.image_set)
    prefix = os.path.join(final_output_path, prefix)

    # ==============load symbol==============
    # keep a copy of the symbol definition next to the outputs
    shutil.copy2(os.path.join(curr_path, 'symbols', config.symbol + '.py'),
                 final_output_path)
    sym_instance = eval(config.symbol + '.' + config.symbol)()
    if config.TRAIN.G_type == 0:
        sym = sym_instance.get_train_feature_distance_symbol(config)
    elif config.TRAIN.G_type == 1:
        sym = sym_instance.get_train_feature_distance_symbol_res(config)
    else:
        # previously fell through silently and crashed later with a
        # NameError on `sym`
        raise ValueError('unsupported G_type: {}'.format(config.TRAIN.G_type))

    # ==============setup multi-gpu==============
    batch_size = len(ctx)
    input_batch_size = config.TRAIN.BATCH_IMAGES * batch_size

    # ==============print config==============
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # ==============load dataset and prepare imdb for training==============
    image_sets = [iset for iset in config.dataset.image_set.split('+')]
    roidbs = [
        load_gt_roidb(config.dataset.dataset,
                      image_set,
                      config.dataset.root_path,
                      config.dataset.dataset_path,
                      flip=config.TRAIN.FLIP) for image_set in image_sets
    ]
    roidb = merge_roidb(roidbs)
    roidb = filter_roidb(roidb, config)
    train_iter = ImagenetVIDIter(roidb, input_batch_size, config,
                                 config.TRAIN.SHUFFLE, ctx)

    # infer max shape: current frame plus reference frame
    max_data_shape = [('data', (config.TRAIN.BATCH_IMAGES, 3,
                                max([v[0] for v in config.SCALES]),
                                max([v[1] for v in config.SCALES]))),
                      ('data_ref', (config.TRAIN.BATCH_IMAGES, 3,
                                    max([v[0] for v in config.SCALES]),
                                    max([v[1] for v in config.SCALES])))]
    # (was a Python-2 print statement; the format call prints the same text
    # on both Python 2 and 3)
    print('providing maximum shape {}'.format(max_data_shape))

    data_shape_dict = dict(train_iter.provide_data_single)
    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)

    # ==============init params==============
    if config.TRAIN.RESUME:
        print('continue training from ', begin_epoch)
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        # arg_params_flow, aux_params_flow = load_param(pretrained_flow, epoch, convert=True)
        # arg_params.update(arg_params_flow)
        # aux_params.update(aux_params_flow)
        sym_instance.init_weight(config, arg_params, aux_params)

    # ==============create solver==============
    fixed_param_prefix = config.network.FIXED_PARAMS
    data_names = train_iter.data_name
    label_names = train_iter.label_name

    mod = MutableModule(
        sym,
        data_names=data_names,
        label_names=label_names,
        logger=logger,
        context=ctx,
        max_data_shapes=[max_data_shape for _ in range(batch_size)],
        max_label_shapes=None,
        fixed_param_prefix=fixed_param_prefix)

    if config.TRAIN.RESUME:
        mod._preload_opt_states = '%s-%04d.states' % (prefix, begin_epoch)

    # ==============optimizer==============
    # NOTE(review): learning rate is hard-coded; the `lr`/`lr_step` args
    # are not consulted here.
    optimizer_params = {
        'learning_rate': 0.00005,
    }

    if not isinstance(train_iter, PrefetchingIter):
        train_iter = PrefetchingIter(train_iter)

    batch_end_callback = callback.Speedometer(train_iter.batch_size,
                                              frequent=args.frequent)
    epoch_end_callback = [
        mx.callback.module_checkpoint(mod,
                                      prefix,
                                      period=1,
                                      save_optimizer_states=True),
        callback.do_checkpoint(prefix)
    ]

    feature_L2_loss = metric.FeatureL2LossMetric(config)
    eval_metrics = mx.metric.CompositeEvalMetric()
    eval_metrics.add(feature_L2_loss)

    # allow_missing=True: params absent from the checkpoint are initialized
    # with Normal(0.02)
    mod.fit(train_iter,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            kvstore=config.default.kvstore,
            optimizer='RMSprop',
            optimizer_params=optimizer_params,
            arg_params=arg_params,
            aux_params=aux_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            initializer=mx.init.Normal(0.02),
            allow_missing=True)
# ----- 示例#7 (example separator from the original scrape; kept as a comment) -----
# 0
def train_net(args, ctx, pretrained, epoch, prefix, begin_epoch, end_epoch, lr,
              lr_step):
    """Train a RetinaNet-style detector over an FPN (p4-p7) backbone.

    Args:
        args: parsed command-line arguments (uses .cfg, .frequent).
        ctx: list of mx.context, one per GPU.
        pretrained: pretrained model prefix to initialize from.
        epoch: epoch of the pretrained checkpoint.
        prefix: checkpoint save prefix (joined under the output path).
        begin_epoch, end_epoch: training epoch range.
        lr: base learning rate.
        lr_step: comma-separated epochs at which to decay the learning rate.
    """
    logger, final_output_path = create_logger(config.output_path, args.cfg,
                                              config.dataset.image_set)
    prefix = os.path.join(final_output_path, prefix)
    # load symbol (keep a copy of the symbol file next to the outputs)
    shutil.copy2(os.path.join(curr_path, 'symbols', config.symbol + '.py'),
                 final_output_path)
    sym_instance = eval(config.symbol + '.' + config.symbol)()

    sym = sym_instance.get_retina_symbol(config, is_train=True)

    # per-pyramid-level feature symbols and anchor settings (p4..p7)
    feat_sym = [sym.get_internals()['box_pred/p{}_output'.format(level)]
                for level in (4, 5, 6, 7)]
    feat_stride = [config.network.p4_RPN_FEAT_STRIDE,
                   config.network.p5_RPN_FEAT_STRIDE,
                   config.network.p6_RPN_FEAT_STRIDE,
                   config.network.p7_RPN_FEAT_STRIDE]
    anchor_scales = [config.network.p4_ANCHOR_SCALES,
                     config.network.p5_ANCHOR_SCALES,
                     config.network.p6_ANCHOR_SCALES,
                     config.network.p7_ANCHOR_SCALES]
    anchor_ratios = [config.network.p4_ANCHOR_RATIOS,
                     config.network.p5_ANCHOR_RATIOS,
                     config.network.p6_ANCHOR_RATIOS,
                     config.network.p7_ANCHOR_RATIOS]

    # setup multi-gpu: one image batch per context
    batch_size = len(ctx)
    input_batch_size = config.TRAIN.BATCH_IMAGES * batch_size

    # print config
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # load dataset and prepare imdb for training ('+' joins multiple sets)
    image_sets = [iset for iset in config.dataset.image_set.split('+')]
    roidbs = [
        load_gt_roidb(config.dataset.dataset,
                      image_set,
                      config.dataset.root_path,
                      config.dataset.dataset_path,
                      flip=config.TRAIN.FLIP) for image_set in image_sets
    ]
    roidb = merge_roidb(roidbs)
    roidb = filter_roidb(roidb, config)

    # load training data
    train_data = AnchorLoader(feat_sym,
                              feat_stride,
                              anchor_scales,
                              anchor_ratios,
                              roidb,
                              config,
                              batch_size=input_batch_size,
                              shuffle=config.TRAIN.SHUFFLE,
                              ctx=ctx,
                              aspect_grouping=config.TRAIN.ASPECT_GROUPING)
    # infer max shape
    max_data_shape = [('data', (config.TRAIN.BATCH_IMAGES, 3,
                                max([v[0] for v in config.SCALES]),
                                max([v[1] for v in config.SCALES])))]
    max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape)
    max_data_shape.append(('gt_boxes', (config.TRAIN.BATCH_IMAGES, 100, 5)))
    # (was a Python-2 print statement; this prints the same text on 2 and 3)
    print('providing maximum shape {} {}'.format(max_data_shape,
                                                 max_label_shape))

    data_shape_dict = dict(train_data.provide_data_single +
                           train_data.provide_label_single)
    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)

    # load and initialize params
    if config.TRAIN.RESUME:
        print('continue training from ', begin_epoch)
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        sym_instance.init_weight(config, arg_params, aux_params)

    # check parameter shapes
    sym_instance.check_parameter_shapes(arg_params, aux_params,
                                        data_shape_dict)
    # create solver
    fixed_param_prefix = config.network.FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data_single]
    label_names = [k[0] for k in train_data.provide_label_single]

    mod = MutableModule(
        sym,
        data_names=data_names,
        label_names=label_names,
        logger=logger,
        context=ctx,
        max_data_shapes=[max_data_shape for _ in range(batch_size)],
        max_label_shapes=[max_label_shape for _ in range(batch_size)],
        fixed_param_prefix=fixed_param_prefix)

    if config.TRAIN.RESUME:
        mod._preload_opt_states = '%s-%04d.states' % (prefix, begin_epoch)

    # decide training params
    # metric (local name fixed from the original 'Retina_toal_...' typo)
    retina_total_eval_metric = metric.RetinaToalAccMetric()
    retina_cls_metric = metric.RetinaFocalLossMetric()
    retina_bbox_metric = metric.RetinaL1LossMetric()

    eval_metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [
            retina_total_eval_metric, retina_cls_metric, retina_bbox_metric
    ]:
        eval_metrics.add(child_metric)
    # callback
    batch_end_callback = callback.Speedometer(train_data.batch_size,
                                              frequent=args.frequent)
    means = np.tile(np.array(config.TRAIN.BBOX_MEANS),
                    2 if config.CLASS_AGNOSTIC else config.dataset.NUM_CLASSES)
    stds = np.tile(np.array(config.TRAIN.BBOX_STDS),
                   2 if config.CLASS_AGNOSTIC else config.dataset.NUM_CLASSES)
    epoch_end_callback = [
        mx.callback.module_checkpoint(mod,
                                      prefix,
                                      period=1,
                                      save_optimizer_states=True),
        callback.do_checkpoint(prefix, means, stds)
    ]
    # decide learning rate: decay at the configured epochs, skipping decays
    # already passed when resuming.
    base_lr = lr
    lr_factor = config.TRAIN.lr_factor
    lr_epoch = [float(epoch) for epoch in lr_step.split(',')]
    lr_epoch_diff = [
        epoch - begin_epoch for epoch in lr_epoch if epoch > begin_epoch
    ]
    lr = base_lr * (lr_factor**(len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [
        int(epoch * len(roidb) / batch_size) for epoch in lr_epoch_diff
    ]
    print(lr_step.split(','))
    print('lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters)
    # NOTE(review): this warmup scheduler is built but never added to
    # optimizer_params below, so it has no effect -- confirm whether that
    # is intentional before wiring it in.
    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor,
                                              config.TRAIN.warmup,
                                              config.TRAIN.warmup_lr,
                                              config.TRAIN.warmup_step)
    # optimizer
    optimizer_params = {
        'learning_rate': lr,
        'wd': 0.0001,
    }

    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)
    # train
    initializer = mx.init.MSRAPrelu(factor_type='out', slope=0)

    print("-----------------------train--------------------------------")
    mod.fit(train_data,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            kvstore=config.default.kvstore,
            optimizer='adam',
            optimizer_params=optimizer_params,
            initializer=initializer,
            arg_params=arg_params,
            aux_params=aux_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch)
def train_net(args, ctx, pretrained, epoch, prefix, begin_epoch, end_epoch, lr,
              lr_step):
    """Main train function for segmentation.

    Args:
        args:
            parameter parser (argparse namespace; must provide ``cfg`` and
            ``frequent``)
        ctx:
            list of GPU contexts; its length defines the number of devices
        pretrained:
            pretrained model file path prefix
        epoch:
            checkpoint epoch of the pretrained model to load
        prefix:
            model save name prefix
        begin_epoch:
            epoch at which to start training (non-zero when resuming)
        end_epoch:
            last epoch of the training phase
        lr:
            base learning rate
        lr_step:
            comma-separated string of epoch numbers at which to decay
            the learning rate

    """
    ##########################################
    # Step 1. Create logger and set up the save prefix
    ##########################################
    logger, final_output_path = create_logger(config.output_path, args.cfg,
                                              config.dataset.image_set)
    prefix = os.path.join(final_output_path, prefix)

    ##########################################
    # Step 2. Copy the symbols and load the symbol to build network
    ##########################################
    # copy the symbol source next to the output so the run is reproducible
    shutil.copy2(os.path.join(curr_path, 'symbols', config.symbol + '.py'),
                 final_output_path)
    sym_instance = eval(config.symbol + '.' + config.symbol)()
    sym = sym_instance.get_symbol(config, is_train=True)
    #
    #sym = eval('get_' + args.network + '_train')(num_classes=config.dataset.NUM_CLASSES)

    ##########################################
    # Step 3. Setup multi-gpu and batch size
    ##########################################
    # one per-device batch on each context; effective batch scales with #GPUs
    batch_size = len(ctx)
    input_batch_size = config.TRAIN.BATCH_IMAGES * batch_size

    # print config
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    ############################################
    # Step 4. load dataset and prepare imdb for training
    ############################################
    # image_set may combine several sets joined with '+'
    image_sets = [iset for iset in config.dataset.image_set.split('+')]
    segdbs = [
        load_gt_segdb(config.dataset.dataset,
                      image_set,
                      config.dataset.root_path,
                      config.dataset.dataset_path,
                      result_path=final_output_path,
                      flip=config.TRAIN.FLIP) for image_set in image_sets
    ]
    segdb = merge_segdb(segdbs)

    ############################################
    # Step 5. Set dataloader and set the data shape
    ############################################
    train_data = TrainDataLoader(sym,
                                 segdb,
                                 config,
                                 batch_size=input_batch_size,
                                 crop_height=config.TRAIN.CROP_HEIGHT,
                                 crop_width=config.TRAIN.CROP_WIDTH,
                                 shuffle=config.TRAIN.SHUFFLE,
                                 ctx=ctx)

    # infer max shape: with cropping enabled the crop size is the largest
    # data/label shape the module ever has to allocate for
    max_scale = [(config.TRAIN.CROP_HEIGHT, config.TRAIN.CROP_WIDTH)]
    max_data_shape = [('data', (config.TRAIN.BATCH_IMAGES, 3,
                                max([v[0] for v in max_scale]),
                                max([v[1] for v in max_scale])))]
    max_label_shape = [('label', (config.TRAIN.BATCH_IMAGES, 1,
                                  max([v[0] for v in max_scale]),
                                  max([v[1] for v in max_scale])))]
    # max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape, max_label_shape)
    print('providing maximum shape', max_data_shape, max_label_shape)

    # infer shape
    data_shape_dict = dict(train_data.provide_data_single +
                           train_data.provide_label_single)
    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)

    ##############################################
    # Step 6. load and initialize params
    ##############################################
    if config.TRAIN.RESUME:
        # resume: restore the checkpoint saved under this run's prefix
        print('continue training from ', begin_epoch)
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        # fresh run: load the pretrained backbone and init the new layers
        print(pretrained)
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        sym_instance.init_weights(config, arg_params, aux_params)

    # check parameter shapes
    sym_instance.check_parameter_shapes(arg_params, aux_params,
                                        data_shape_dict)

    ##############################################
    # Step 6b. Create solver and set metrics
    ##############################################
    fixed_param_prefix = config.network.FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data_single]
    label_names = [k[0] for k in train_data.provide_label_single]

    mod = MutableModule(
        sym,
        data_names=data_names,
        label_names=label_names,
        logger=logger,
        context=ctx,
        max_data_shapes=[max_data_shape for _ in xrange(batch_size)],
        max_label_shapes=[max_label_shape for _ in xrange(batch_size)],
        fixed_param_prefix=fixed_param_prefix)

    # decide training params
    # metric
    fcn_loss_metric = metric.FCNLogLossMetric(config.default.frequent *
                                              batch_size)
    eval_metrics = mx.metric.CompositeEvalMetric()

    # rpn_eval_metric, rpn_cls_metric, rpn_bbox_metric, eval_metric, cls_metric, bbox_metric
    for child_metric in [fcn_loss_metric]:
        eval_metrics.add(child_metric)

    ##############################################
    # Step 7. Set callback for training process
    ##############################################
    batch_end_callback = callback.Speedometer(train_data.batch_size,
                                              frequent=args.frequent)
    # checkpoint every epoch, including optimizer state so runs can resume
    epoch_end_callback = mx.callback.module_checkpoint(
        mod, prefix, period=1, save_optimizer_states=True)

    ##############################################
    # Step 8. Decide learning rate and optimizers
    ##############################################
    base_lr = lr
    lr_factor = 0.1
    lr_epoch = [float(epoch) for epoch in lr_step.split(',')]
    # keep only decay steps still ahead of begin_epoch (resume support)
    lr_epoch_diff = [
        epoch - begin_epoch for epoch in lr_epoch if epoch > begin_epoch
    ]
    # pre-apply the decay factor for any steps already passed before begin_epoch
    lr = base_lr * (lr_factor**(len(lr_epoch) - len(lr_epoch_diff)))
    # convert remaining decay epochs into absolute iteration counts
    lr_iters = [
        int(epoch * len(segdb) / batch_size) for epoch in lr_epoch_diff
    ]
    print('lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters)

    # warmup phase followed by multi-factor step decay
    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor,
                                              config.TRAIN.warmup,
                                              config.TRAIN.warmup_lr,
                                              config.TRAIN.warmup_step)

    # optimizer
    optimizer_params = {
        'momentum': config.TRAIN.momentum,
        'wd': config.TRAIN.wd,
        'learning_rate': lr,
        'lr_scheduler': lr_scheduler,
        'rescale_grad': 1.0,
        'clip_gradient': None
    }

    # overlap data loading with computation
    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)

    ##############################################
    # Step 9 Start to train
    ##############################################
    mod.fit(train_data,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            kvstore=config.default.kvstore,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            arg_params=arg_params,
            aux_params=aux_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch)
示例#9
0
def train_rcnn(cfg,
               dataset,
               image_set,
               root_path,
               dataset_path,
               frequent,
               kvstore,
               flip,
               shuffle,
               resume,
               ctx,
               pretrained,
               epoch,
               prefix,
               begin_epoch,
               end_epoch,
               train_shared,
               lr,
               lr_step,
               proposal,
               logger=None,
               output_path=None):
    """Train the R-FCN detection head on precomputed region proposals.

    Args:
        cfg: experiment configuration (easydict-style; TRAIN/network sections)
        dataset: dataset name passed to ``load_proposal_roidb``
        image_set: image set spec; several sets may be joined with '+'
        root_path: dataset root path
        dataset_path: dataset directory path
        frequent: logging frequency (batches) for the Speedometer callback
        kvstore: kvstore type for ``mod.fit``
        flip: whether to augment the roidb with horizontally flipped images
        shuffle: whether the data loader shuffles samples
        resume: resume from this run's checkpoint instead of ``pretrained``
        ctx: list of device contexts; its length defines the device count
        pretrained: pretrained model path prefix
        epoch: checkpoint epoch of the pretrained model to load
        prefix: model save name prefix
        begin_epoch: epoch at which to start training
        end_epoch: last epoch of the training phase
        train_shared: if True, freeze the shared backbone params
            (``FIXED_PARAMS_SHARED``) instead of the default set
        lr: base learning rate
        lr_step: comma-separated epoch numbers for learning rate decay
        proposal: proposal method name used to load the roidb
        logger: optional logger; a default INFO logger is created if None
        output_path: result path handed to ``load_proposal_roidb``
    """
    # set up logger
    # create a default INFO-level logger if the caller did not supply one
    if not logger:
        logging.basicConfig()
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)

    # load symbol
    # resolve the configured detection network symbol class by name
    sym_instance = eval(cfg.symbol + '.' + cfg.symbol)()
    # build the R-FCN symbol; is_train=True selects the training graph
    sym = sym_instance.get_symbol_rfcn(cfg, is_train=True)

    # setup multi-gpu
    # one per-device batch on each context
    batch_size = len(ctx)
    input_batch_size = cfg.TRAIN.BATCH_IMAGES * batch_size

    # print cfg
    pprint.pprint(cfg)
    logger.info('training rcnn cfg:{}\n'.format(pprint.pformat(cfg)))

    # load dataset and prepare imdb for training
    image_sets = [iset for iset in image_set.split('+')]
    roidbs = [
        load_proposal_roidb(dataset,
                            image_set,
                            root_path,
                            dataset_path,
                            proposal=proposal,
                            append_gt=True,
                            flip=flip,
                            result_path=output_path)
        for image_set in image_sets
    ]
    roidb = merge_roidb(roidbs)
    roidb = filter_roidb(roidb, cfg)
    # means/stds normalize the bbox regression targets; they are needed
    # again at checkpoint time to un-normalize the saved weights
    means, stds = add_bbox_regression_targets(roidb, cfg)

    # load training data
    train_data = ROIIter(roidb,
                         cfg,
                         batch_size=input_batch_size,
                         shuffle=shuffle,
                         ctx=ctx,
                         aspect_grouping=cfg.TRAIN.ASPECT_GROUPING)

    # infer max shape: largest configured scale bounds all allocations
    max_data_shape = [('data', (cfg.TRAIN.BATCH_IMAGES, 3,
                                max([v[0] for v in cfg.SCALES]),
                                max([v[1] for v in cfg.SCALES])))]

    # infer shape
    data_shape_dict = dict(train_data.provide_data_single +
                           train_data.provide_label_single)
    sym_instance.infer_shape(data_shape_dict)

    # load and initialize params
    if resume:
        print('continue training from ', begin_epoch)
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        sym_instance.init_weight_rfcn(cfg, arg_params, aux_params)

    # check parameter shapes
    sym_instance.check_parameter_shapes(arg_params, aux_params,
                                        data_shape_dict)

    # prepare training
    # create solver
    data_names = [k[0] for k in train_data.provide_data_single]
    label_names = [k[0] for k in train_data.provide_label_single]
    if train_shared:
        # freeze the backbone shared with the RPN stage
        fixed_param_prefix = cfg.network.FIXED_PARAMS_SHARED
    else:
        fixed_param_prefix = cfg.network.FIXED_PARAMS
    mod = MutableModule(
        sym,
        data_names=data_names,
        label_names=label_names,
        logger=logger,
        context=ctx,
        max_data_shapes=[max_data_shape for _ in range(batch_size)],
        fixed_param_prefix=fixed_param_prefix)

    if cfg.TRAIN.RESUME:
        # also restore the saved optimizer state when resuming
        mod._preload_opt_states = '%s-%04d.states' % (prefix, begin_epoch)

    # decide training params
    # metric
    eval_metric = metric.RCNNAccMetric(cfg)
    cls_metric = metric.RCNNLogLossMetric(cfg)
    bbox_metric = metric.RCNNL1LossMetric(cfg)
    eval_metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [eval_metric, cls_metric, bbox_metric]:
        eval_metrics.add(child_metric)
    # callback
    batch_end_callback = callback.Speedometer(train_data.batch_size,
                                              frequent=frequent)
    # two checkpoints per epoch: raw module state (with optimizer states)
    # and a copy with bbox weights un-normalized via means/stds
    epoch_end_callback = [
        mx.callback.module_checkpoint(mod,
                                      prefix,
                                      period=1,
                                      save_optimizer_states=True),
        callback.do_checkpoint(prefix, means, stds)
    ]
    # decide learning rate
    base_lr = lr
    lr_factor = cfg.TRAIN.lr_factor
    lr_epoch = [float(epoch) for epoch in lr_step.split(',')]
    # keep only decay steps still ahead of begin_epoch (resume support)
    lr_epoch_diff = [
        epoch - begin_epoch for epoch in lr_epoch if epoch > begin_epoch
    ]
    # pre-apply decay for steps already passed, then map epochs to iterations
    lr = base_lr * (lr_factor**(len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [
        int(epoch * len(roidb) / batch_size) for epoch in lr_epoch_diff
    ]
    print('lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters)
    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor,
                                              cfg.TRAIN.warmup,
                                              cfg.TRAIN.warmup_lr,
                                              cfg.TRAIN.warmup_step)
    # optimizer
    optimizer_params = {
        'momentum': cfg.TRAIN.momentum,
        'wd': cfg.TRAIN.wd,
        'learning_rate': lr,
        'lr_scheduler': lr_scheduler,
        'rescale_grad': 1.0,
        'clip_gradient': None
    }

    # train

    # overlap data loading with computation
    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)

    mod.fit(train_data,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            kvstore=kvstore,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            arg_params=arg_params,
            aux_params=aux_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch)
示例#10
0
def train_net(args, ctx, pretrained, epoch, prefix, begin_epoch, end_epoch, lr,
              lr_step):
    """Train a segmentation network with TensorBoard logging and a
    validation loader evaluated during training.

    Args:
        args: argparse namespace; must provide ``cfg`` and ``frequent``
        ctx: list of GPU contexts; its length defines the device count
        pretrained: pretrained model path prefix
        epoch: checkpoint epoch of the pretrained model to load
        prefix: model save name prefix
        begin_epoch: epoch at which to start training (non-zero on resume)
        end_epoch: last epoch of the training phase
        lr: base learning rate
        lr_step: comma-separated epoch numbers for learning rate decay
    """

    logger, final_output_path, _, tensorboard_path = create_env(
        config.output_path, args.cfg, config.dataset.image_set)
    prefix = os.path.join(final_output_path, prefix)

    # print config
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # resolve the configured symbol class by name and build the train graph
    print "config.symbol", config.symbol
    sym_instance = eval(config.symbol + '.' + config.symbol)()
    sym = sym_instance.get_symbol(config, is_train=True)

    # setup multi-gpu
    input_batch_size = config.TRAIN.BATCH_IMAGES * len(ctx)
    NUM_GPUS = len(ctx)

    # load dataset and prepare imdb for training
    # image_set may combine several sets joined with '+'
    image_sets = [iset for iset in config.dataset.image_set.split('+')]
    segdbs = [
        load_gt_segdb(config.dataset.dataset,
                      image_set,
                      config.dataset.root_path,
                      config.dataset.dataset_path,
                      result_path=final_output_path)
        for image_set in image_sets
    ]
    segdb = merge_segdb(segdbs)

    # load training data
    train_data = TrainDataLoader(sym,
                                 segdb,
                                 config,
                                 batch_size=input_batch_size,
                                 shuffle=config.TRAIN.SHUFFLE,
                                 ctx=ctx)

    # loading val data: build the test imdb and wrap it in the same loader
    val_image_set = config.dataset.test_image_set
    val_root_path = config.dataset.root_path
    val_dataset = config.dataset.dataset
    val_dataset_path = config.dataset.dataset_path
    val_imdb = eval(val_dataset)(val_image_set,
                                 val_root_path,
                                 val_dataset_path,
                                 result_path=final_output_path)
    val_segdb = val_imdb.gt_segdb()

    val_data = TrainDataLoader(sym,
                               val_segdb,
                               config,
                               batch_size=input_batch_size,
                               shuffle=config.TRAIN.SHUFFLE,
                               ctx=ctx)

    # infer max shape: crop size bounds the data; labels are downsampled
    # by LABEL_STRIDE relative to the input
    max_scale = [(config.TRAIN.crop_size[0], config.TRAIN.crop_size[1])]
    max_data_shape = [('data', (config.TRAIN.BATCH_IMAGES, 3,
                                max([v[0] for v in max_scale]),
                                max([v[1] for v in max_scale])))]
    max_label_shape = [
        ('label',
         (config.TRAIN.BATCH_IMAGES, 1,
          max([v[0] for v in max_scale]) // config.network.LABEL_STRIDE,
          max([v[1] for v in max_scale]) // config.network.LABEL_STRIDE))
    ]
    max_data_shape, max_label_shape = train_data.infer_shape(
        max_data_shape, max_label_shape)
    print 'providing maximum shape', max_data_shape, max_label_shape

    # infer shape
    data_shape_dict = dict(train_data.provide_data_single +
                           train_data.provide_label_single)
    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)

    # load and initialize params
    if config.TRAIN.RESUME:
        # resume: restore params and optimizer states from this run's prefix
        print 'continue training from ', begin_epoch
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
        preload_opt_states = load_preload_opt_states(prefix, begin_epoch)
    else:
        # fresh run: load the pretrained backbone and init new layers
        print pretrained
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        preload_opt_states = None
        sym_instance.init_weights(config, arg_params, aux_params)

    # check parameter shapes
    sym_instance.check_parameter_shapes(arg_params, aux_params,
                                        data_shape_dict)

    # create solver
    fixed_param_prefix = config.network.FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data_single]
    label_names = [k[0] for k in train_data.provide_label_single]

    mod = MutableModule(
        sym,
        data_names=data_names,
        label_names=label_names,
        logger=logger,
        context=ctx,
        max_data_shapes=[max_data_shape for _ in xrange(NUM_GPUS)],
        max_label_shapes=[max_label_shape for _ in xrange(NUM_GPUS)],
        fixed_param_prefix=fixed_param_prefix)

    # metric
    imagecrossentropylossmetric = metric.ImageCrossEntropyLossMetric()
    pixcelAccMetric = metric.PixcelAccMetric()
    eval_metrics = mx.metric.CompositeEvalMetric()

    # rpn_eval_metric, rpn_cls_metric, rpn_bbox_metric, eval_metric, cls_metric, bbox_metric
    for child_metric in [imagecrossentropylossmetric, pixcelAccMetric]:
        eval_metrics.add(child_metric)

    # callback
    batch_end_callback = [
        callback.Speedometer(train_data.batch_size, frequent=args.frequent),
        callback.TensorboardCallback(tensorboard_path, prefix="train/batch")
    ]
    epoch_end_callback = mx.callback.module_checkpoint(
        mod, prefix, period=1, save_optimizer_states=True)
    # all TensorBoard callbacks share the writer created by the batch callback
    shared_tensorboard = batch_end_callback[1]

    epoch_end_metric_callback = callback.TensorboardCallback(
        tensorboard_path,
        shared_tensorboard=shared_tensorboard,
        prefix="train/epoch")
    eval_end_callback = callback.TensorboardCallback(
        tensorboard_path,
        shared_tensorboard=shared_tensorboard,
        prefix="val/epoch")
    lr_callback = callback.LrCallback(tensorboard_path,
                                      shared_tensorboard=shared_tensorboard,
                                      prefix='train/batch')

    #decide learning rate
    base_lr = lr
    lr_factor = 0.1
    lr_epoch = [float(epoch) for epoch in lr_step.split(',')]
    # keep only decay steps still ahead of begin_epoch (resume support)
    lr_epoch_diff = [
        epoch - begin_epoch for epoch in lr_epoch if epoch > begin_epoch
    ]
    # pre-apply decay for steps already passed, then map epochs to iterations
    lr = base_lr * (lr_factor**(len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [
        int(epoch * len(segdb) / input_batch_size) for epoch in lr_epoch_diff
    ]
    print 'lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters

    # NOTE(review): if lr_type is neither "MultiStage" nor "MultiFactor",
    # lr_scheduler is never assigned and the code below raises NameError
    if config.TRAIN.lr_type == "MultiStage":
        lr_scheduler = LinearWarmupMultiStageScheduler(
            lr_iters,
            lr_factor,
            config.TRAIN.warmup,
            config.TRAIN.warmup_lr,
            config.TRAIN.warmup_step,
            args.frequent,
            stop_lr=lr * 0.01)
    elif config.TRAIN.lr_type == "MultiFactor":
        lr_scheduler = LinearWarmupMultiFactorScheduler(
            lr_iters, lr_factor, config.TRAIN.warmup, config.TRAIN.warmup_lr,
            config.TRAIN.warmup_step, args.frequent)

    optimizer_params = {
        'momentum': config.TRAIN.momentum,
        'wd': config.TRAIN.wd,
        'learning_rate': lr,
        'lr_scheduler': lr_scheduler,
        'rescale_grad': 1.0,
        'clip_gradient': None
    }
    # build the optimizer explicitly so per-parameter lr multipliers can be set
    optimizer = SGD(**optimizer_params)

    # optionally scale the lr of params matching FIXED_PARAMS_PATTERN
    freeze_layer_pattern = config.TRAIN.FIXED_PARAMS_PATTERN
    if freeze_layer_pattern.strip():
        args_lr_mult = {}
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [
            name for name in sym.list_arguments() if re_prog.match(name)
        ]
        print "fixed_params_names:"
        print(fixed_param_names)
        for name in fixed_param_names:
            args_lr_mult[name] = config.TRAIN.FIXED_PARAMS_PATTERN_LR_MULT
    else:
        args_lr_mult = {}
    optimizer.set_lr_mult(args_lr_mult)

    # overlap data loading with computation
    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)

    if not isinstance(val_data, PrefetchingIter):
        val_data = PrefetchingIter(val_data)

    # NOTE(review): Debug appears to be a module-level flag — confirm
    if Debug:
        monitor = mx.monitor.Monitor(1)
    else:
        monitor = None

    initializer = mx.initializer.Xavier(magnitude=1, rnd_type="gaussian")
    # train
    # allow_missing only on a fresh run so resumes fail loudly on gaps
    mod.fit(train_data,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            kvstore=config.default.kvstore,
            eval_end_callback=eval_end_callback,
            epoch_end_metric_callback=epoch_end_metric_callback,
            optimizer=optimizer,
            eval_data=val_data,
            arg_params=arg_params,
            aux_params=aux_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            allow_missing=begin_epoch == 0,
            allow_extra=True,
            monitor=monitor,
            preload_opt_states=preload_opt_states,
            eval_data_frequency=config.TRAIN.eval_data_frequency,
            initializer=initializer)
示例#11
0
def train_net(config, output_path, logger=logging):
    """Train an FCIS (resnet_v1_101) instance-segmentation network.

    Args:
        config: experiment configuration (easydict-style); supplies the
            pretrained model, epochs, lr schedule, dataset and network
            settings, and the GPU list (``config.gpus``)
        output_path: directory where checkpoints and results are written
        logger: logger object; defaults to the ``logging`` module
    """

    # train_net(cfg_path, ctx, config.network.pretrained, config.network.pretrained_epoch,
    #           config.TRAIN.model_prefix, config.TRAIN.begin_epoch, config.TRAIN.end_epoch,
    #           config.TRAIN.lr, config.TRAIN.lr_step)

    # train parameters
    pretrained_model = config.network.pretrained
    epoch = config.network.pretrained_epoch
    prefix = config.TRAIN.model_prefix
    begin_epoch = config.TRAIN.begin_epoch
    end_epoch = config.TRAIN.end_epoch
    lr = config.TRAIN.lr
    lr_step = config.TRAIN.lr_step

    prefix = os.path.join(output_path, prefix)

    # network parameters
    BATCH_IMAGES = config.TRAIN.BATCH_IMAGES
    SCALES = config.SCALES

    # gpu stuff: one context per comma-separated GPU id
    ctx = [mx.gpu(int(i)) for i in config.gpus.split(',')]

    # final_output_path = output_path

    # load symbol
    # shutil.copy2(os.path.join(curr_path, 'symbols', config.symbol + '.py'), output_path)
    # sym_instance = eval(config.symbol)()
    network = resnet_v1_101_fcis()
    sym = network.get_symbol(config, is_train=True)
    # the RPN score output is needed by AnchorLoader to size anchor targets
    feat_sym = sym.get_internals()['rpn_cls_score_output']

    # setup multi-gpu
    batch_size = len(ctx)
    input_batch_size = BATCH_IMAGES * batch_size

    # load dataset and prepare imdb for training
    cfg_ds = config.dataset
    ds_name = cfg_ds.dataset
    # image_set may combine several sets joined with '+'
    image_sets = [iset for iset in cfg_ds.image_set.split('+')]
    if ds_name.lower() == "labelme":
        from utils.load_data import load_labelme_gt_sdsdb
        sdsdbs = [
            load_labelme_gt_sdsdb(image_set,
                                  cfg_ds.dataset_path,
                                  cfg_ds.root_path,
                                  flip=config.TRAIN.FLIP,
                                  mask_size=config.MASK_SIZE,
                                  binary_thresh=config.BINARY_THRESH,
                                  classes=cfg_ds.CLASSES)
            for image_set in image_sets
        ]
    else:
        sdsdbs = [
            load_gt_sdsdb(ds_name,
                          image_set,
                          cfg_ds.root_path,
                          cfg_ds.dataset_path,
                          mask_size=config.MASK_SIZE,
                          binary_thresh=config.BINARY_THRESH,
                          result_path=output_path,
                          flip=config.TRAIN.FLIP) for image_set in image_sets
        ]
    sdsdb = merge_roidb(sdsdbs)
    sdsdb = filter_roidb(sdsdb, config)

    # load training data
    train_data = AnchorLoader(feat_sym,
                              sdsdb,
                              config,
                              batch_size=input_batch_size,
                              shuffle=config.TRAIN.SHUFFLE,
                              ctx=ctx,
                              feat_stride=config.network.RPN_FEAT_STRIDE,
                              anchor_scales=config.network.ANCHOR_SCALES,
                              anchor_ratios=config.network.ANCHOR_RATIOS,
                              aspect_grouping=config.TRAIN.ASPECT_GROUPING,
                              allowed_border=config.TRAIN.RPN_ALLOWED_BORDER)

    # infer max shape: largest configured scale plus room for up to
    # 100 ground-truth boxes and masks per image
    max_data_shape = [('data', (BATCH_IMAGES, 3, max([v[0] for v in SCALES]),
                                max(v[1] for v in SCALES)))]
    max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape)
    max_data_shape.append(('gt_boxes', (BATCH_IMAGES, 100, 5)))
    max_data_shape.append(
        ('gt_masks', (BATCH_IMAGES, 100, max([v[0] for v in SCALES]),
                      max(v[1] for v in SCALES))))
    print 'providing maximum shape', max_data_shape, max_label_shape

    # infer shape
    data_shape_dict = dict(train_data.provide_data_single +
                           train_data.provide_label_single)
    print 'data shape:'
    pprint.pprint(data_shape_dict)
    network.infer_shape(data_shape_dict)

    # load and initialize params
    if config.TRAIN.RESUME:
        print 'continue training from ', begin_epoch
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        arg_params, aux_params = load_param(pretrained_model,
                                            epoch,
                                            convert=True)
        network.init_weight(config, arg_params, aux_params)

    # check parameter shapes
    network.check_parameter_shapes(arg_params, aux_params, data_shape_dict)

    # create solver
    fixed_param_prefix = config.network.FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data_single]
    label_names = [k[0] for k in train_data.provide_label_single]

    mod = MutableModule(
        sym,
        data_names=data_names,
        label_names=label_names,
        logger=logger,
        context=ctx,
        max_data_shapes=[max_data_shape for _ in xrange(batch_size)],
        max_label_shapes=[max_label_shape for _ in xrange(batch_size)],
        fixed_param_prefix=fixed_param_prefix)

    # decide training metric
    # RPN, classification accuracy/loss, regression loss
    rpn_acc = metric.RPNAccMetric()
    rpn_cls_loss = metric.RPNLogLossMetric()
    rpn_bbox_loss = metric.RPNL1LossMetric()

    fcis_acc = metric.FCISAccMetric(config)
    fcis_acc_fg = metric.FCISAccFGMetric(config)
    fcis_cls_loss = metric.FCISLogLossMetric(config)
    fcis_bbox_loss = metric.FCISL1LossMetric(config)
    fcis_mask_loss = metric.FCISMaskLossMetric(config)

    eval_metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [
            rpn_acc, rpn_cls_loss, rpn_bbox_loss, fcis_acc, fcis_acc_fg,
            fcis_cls_loss, fcis_bbox_loss, fcis_mask_loss
    ]:
        eval_metrics.add(child_metric)

    batch_end_callback = callback.Speedometer(train_data.batch_size,
                                              frequent=config.default.frequent)
    # means/stds un-normalize bbox regression weights in saved checkpoints;
    # class-agnostic regression uses only 2 box classes (bg/fg)
    means = np.tile(np.array(config.TRAIN.BBOX_MEANS),
                    2 if config.CLASS_AGNOSTIC else cfg_ds.NUM_CLASSES)
    stds = np.tile(np.array(config.TRAIN.BBOX_STDS),
                   2 if config.CLASS_AGNOSTIC else cfg_ds.NUM_CLASSES)
    epoch_end_callback = callback.do_checkpoint(prefix, means, stds)

    # print epoch, begin_epoch, end_epoch, lr_step
    base_lr = lr
    lr_factor = 0.1
    lr_epoch = [float(epoch) for epoch in lr_step.split(',')]
    # keep only decay steps still ahead of begin_epoch (resume support)
    lr_epoch_diff = [
        epoch - begin_epoch for epoch in lr_epoch if epoch > begin_epoch
    ]
    # pre-apply decay for steps already passed, then map epochs to iterations
    lr = base_lr * (lr_factor**(len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [
        int(epoch * len(sdsdb) / batch_size) for epoch in lr_epoch_diff
    ]
    print 'lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters
    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor,
                                              config.TRAIN.warmup,
                                              config.TRAIN.warmup_lr,
                                              config.TRAIN.warmup_step)
    # optimizer
    optimizer_params = {
        'momentum': config.TRAIN.momentum,
        'wd': config.TRAIN.wd,
        'learning_rate': lr,
        'lr_scheduler': lr_scheduler,
        'rescale_grad': 1.0,
        'clip_gradient': None
    }

    # overlap data loading with computation
    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)

    # del sdsdb
    # a = mx.viz.plot_network(sym)
    # a.render('../example', view=True)
    # print 'prepare sds finished'

    mod.fit(train_data,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            kvstore=config.default.kvstore,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            arg_params=arg_params,
            aux_params=aux_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch)
示例#12
0
def train_net(args, ctx, pretrained_res, pretrained_vgg, epoch, prefix, begin_epoch, end_epoch, lr, lr_step):
    mx.random.seed(3)
    np.random.seed(3)

    logger, final_output_path = create_logger(config.output_path, args.cfg, config.dataset.image_set)
    prefix = os.path.join(final_output_path, prefix)

    # load symbol
    shutil.copy2(os.path.join(curr_path, 'symbols', config.symbol + '.py'), final_output_path)
    sym_instance = eval(config.symbol)()
    sym = sym_instance.get_symbol(config, is_train=True)
    feat_sym = sym.get_internals()['rpn_cls_score_output']

    # setup multi-gpu
    batch_size = len(ctx)
    input_batch_size = config.TRAIN.BATCH_IMAGES * batch_size

    # print config
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # load dataset and prepare imdb for training
    image_sets = [iset for iset in config.dataset.image_set.split('+')]
    sdsdbs = [load_gt_sdsdb(config.dataset.dataset, image_set, config.dataset.root_path, config.dataset.dataset_path,
                            mask_size=config.MASK_SIZE, binary_thresh=config.BINARY_THRESH,
                            result_path=final_output_path, flip=config.TRAIN.FLIP)
              for image_set in image_sets]
    sdsdb = merge_roidb(sdsdbs)
    sdsdb = filter_roidb(sdsdb, config)

    # load training data
    train_data = AnchorLoader(feat_sym, sdsdb, config, batch_size=input_batch_size, shuffle=config.TRAIN.SHUFFLE,
                              ctx=ctx, feat_stride=config.network.RPN_FEAT_STRIDE, anchor_scales=config.network.ANCHOR_SCALES,
                              anchor_ratios=config.network.ANCHOR_RATIOS, aspect_grouping=config.TRAIN.ASPECT_GROUPING,
                              allowed_border=config.TRAIN.RPN_ALLOWED_BORDER)

    # infer max shape
    max_data_shape = [('data', (config.TRAIN.BATCH_IMAGES, 3,
                                max([v[0] for v in config.SCALES]), max(v[1] for v in config.SCALES)))]
    max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape)
    max_data_shape.append(('gt_boxes', (config.TRAIN.BATCH_IMAGES, 100, 5)))
    max_data_shape.append(('gt_masks', (config.TRAIN.BATCH_IMAGES, 100, max([v[0] for v in config.SCALES]), max(v[1] for v in config.SCALES))))
    print 'providing maximum shape', max_data_shape, max_label_shape

    # infer shape
    data_shape_dict = dict(train_data.provide_data_single + train_data.provide_label_single)
    print 'data shape:'
    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)
    # inshape, outshape, uaxshape = sym_instance.infer_shape(data_shape_dict)
    # print 'symbol inshape: %s ' % (str(inshape))
    # print 'symbol outshape: %s' % (str(outshape))

    '''
    internals = sym.get_internals()
    _, out_shapes, _ = internals.infer_shape(**data_shape_dict)
    print(sym.list_outputs())
    shape_dict = dict(zip(internals.list_outputs(), out_shapes))
    pprint.pprint(shape_dict)
    '''

    # load and initialize params
    if config.TRAIN.RESUME:
        print 'continue training from ', begin_epoch
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)

    else:
        # load vgg-16 & resnet-101 parameters
	# pretrained_res = pretrained
	# pretrained_vgg = './model/pretrained_model/VGG_FC_ILSVRC_16'
        arg_params_res, aux_params_res = load_param(pretrained_res, epoch, convert=True)
        arg_params_vgg, aux_params_vgg = load_param(pretrained_vgg, epoch, convert=True)
	# print 'params of resnet-101'
	# print arg_params_res
	# print 'params of vgg-16'
	# print arg_params_vgg
        arg_params = dict(arg_params_res, **arg_params_vgg)
        aux_params = dict(aux_params_res, **aux_params_vgg)
	# print 'arg_params: \n %s' % (str(arg_params))
        sym_instance.init_weight(config, arg_params, aux_params)

    # check parameter shapes
    sym_instance.check_parameter_shapes(arg_params, aux_params, data_shape_dict)

    # create solver
    fixed_param_prefix = config.network.FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data_single]
    label_names = [k[0] for k in train_data.provide_label_single]

    mod = MutableModule(sym, data_names=data_names, label_names=label_names,
                        logger=logger, context=ctx, max_data_shapes=[max_data_shape for _ in xrange(batch_size)],
                        max_label_shapes=[max_label_shape for _ in xrange(batch_size)], fixed_param_prefix=fixed_param_prefix)

    # decide training metric
    # RPN, classification accuracy/loss, regression loss
    rpn_acc = metric.RPNAccMetric()
    rpn_cls_loss = metric.RPNLogLossMetric()
    rpn_bbox_loss = metric.RPNL1LossMetric()

    fcis_acc = metric.FCISAccMetric(config)
    fcis_acc_fg = metric.FCISAccFGMetric(config)
    fcis_cls_loss = metric.FCISLogLossMetric(config)
    fcis_bbox_loss = metric.FCISL1LossMetric(config)
    fcis_mask_loss = metric.FCISMaskLossMetric(config)

    eval_metrics = mx.metric.CompositeEvalMetric()
    # accumulate all loss, fcn-8s loss should be added here
    for child_metric in [rpn_acc, rpn_cls_loss, rpn_bbox_loss,
                         # fcis_acc_fg, fcis_cls_loss, fcis_bbox_loss, fcis_mask_loss]:
                         fcis_acc, fcis_acc_fg, fcis_cls_loss, fcis_bbox_loss, fcis_mask_loss]:
        eval_metrics.add(child_metric)

    batch_end_callback = callback.Speedometer(train_data.batch_size, frequent=args.frequent)
    means = np.tile(np.array(config.TRAIN.BBOX_MEANS), 2 if config.CLASS_AGNOSTIC else config.dataset.NUM_CLASSES)
    stds = np.tile(np.array(config.TRAIN.BBOX_STDS), 2 if config.CLASS_AGNOSTIC else config.dataset.NUM_CLASSES)
    epoch_end_callback = callback.do_checkpoint(prefix, means, stds)

    # print epoch, begin_epoch, end_epoch, lr_step
    base_lr = lr
    lr_factor = 0.1
    lr_epoch = [float(epoch) for epoch in lr_step.split(',')]
    lr_epoch_diff = [epoch - begin_epoch for epoch in lr_epoch if epoch > begin_epoch]
    lr = base_lr * (lr_factor ** (len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [int(epoch * len(sdsdb) / batch_size) for epoch in lr_epoch_diff]
    print 'lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters
    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor, config.TRAIN.warmup, config.TRAIN.warmup_lr, config.TRAIN.warmup_step)
    # optimizer
    optimizer_params = {'momentum': config.TRAIN.momentum,
                        'wd': config.TRAIN.wd,
                        'learning_rate': lr,
                        'lr_scheduler': lr_scheduler,
                        'rescale_grad': 1.0,
                        'clip_gradient': None}

    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)

    # del sdsdb
    # a = mx.viz.plot_network(sym)
    # a.render('../example', view=True)
    # print 'prepare sds finished'

    mod.fit(train_data, eval_metric=eval_metrics, epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback, kvstore=config.default.kvstore,
            optimizer='sgd', optimizer_params=optimizer_params,
            arg_params=arg_params, aux_params=aux_params, begin_epoch=begin_epoch, num_epoch=end_epoch)