예제 #1
0
def main():
    """Run approximate-joint end-to-end Faster R-CNN training.

    All settings are read from the module-level ``args`` object.  The network
    symbol is ``resnet_50`` when ``args.pretrained`` contains "resnet" and
    ``get_faster_rcnn`` otherwise; the roidb comes from ``load_gt_roidb``.

    Side effects: calls ``init_config()`` and multiplies
    ``config.TRAIN.IMS_PER_BATCH`` in place by the number of GPU contexts, so
    calling this twice in one process compounds that value.
    """
    logging.info('########## TRAIN FASTER-RCNN WITH APPROXIMATE JOINT END2END #############')
    init_config()

    # Backbone choice also decides which parameter prefixes stay frozen.
    use_resnet = "resnet" in args.pretrained
    if use_resnet:
        # only ResNet-50 is handled on this path; per the original note the
        # class count is meant to consider background
        sym = resnet_50(num_class=args.num_classes, bn_mom=args.bn_mom,
                        bn_global=True, is_train=True)
        frozen_prefixes = ['conv0', 'stage1', 'stage2', 'bn_data', 'bn0']
    else:
        sym = get_faster_rcnn(num_classes=args.num_classes)
        frozen_prefixes = ['conv1', 'conv2', 'conv3']

    rpn_feat_sym = sym.get_internals()['rpn_cls_score_output']

    # One GPU context per comma-separated id; images-per-batch scales with it.
    devices = [mx.gpu(int(gpu_id)) for gpu_id in args.gpu_ids.split(',')]
    config.TRAIN.IMS_PER_BATCH *= len(devices)
    max_data_shape, max_label_shape = get_max_shape(rpn_feat_sym)

    # Training data (first return value of the loader is not used here).
    _, roidb = load_gt_roidb(args.image_set, args.year, args.root_path,
                             args.devkit_path, flip=not args.no_flip)
    train_data = AnchorLoader(rpn_feat_sym, roidb,
                              batch_size=config.TRAIN.IMS_PER_BATCH,
                              anchor_scales=(4, 8, 16, 32),
                              shuffle=not args.no_shuffle,
                              mode='train',
                              ctx=devices,
                              need_mean=args.need_mean)

    # Parameters: start from the pretrained checkpoint; unless resuming, let
    # init_model prepare them for this symbol.
    arg_params, aux_params, _ = load_param(args.pretrained, args.load_epoch, convert=True)
    if not args.resume:
        arg_params, aux_params = init_model(arg_params, aux_params, train_data,
                                            sym, args.pretrained)

    data_names = [desc[0] for desc in train_data.provide_data]
    label_names = [desc[0] for desc in train_data.provide_label]
    on_batch_end = Speedometer(train_data.batch_size, frequent=args.frequent)
    on_epoch_end = do_checkpoint(args.prefix)

    # Plain step decay by 0.1; the original author noted warm-up seemed unnecessary.
    optimizer_params = dict(
        momentum=args.mom,
        wd=args.wd,
        learning_rate=args.lr,
        lr_scheduler=mx.lr_scheduler.FactorScheduler(args.factor_step, 0.1),
        clip_gradient=1.0,
        rescale_grad=1.0,
    )

    # Train.
    module = MutableModule(sym,
                           data_names=data_names,
                           label_names=label_names,
                           logger=logger,
                           context=devices,
                           max_data_shapes=max_data_shape,
                           max_label_shapes=max_label_shape,
                           fixed_param_prefix=frozen_prefixes)
    module.fit(train_data,
               eval_metric=metric(),
               epoch_end_callback=on_epoch_end,
               batch_end_callback=on_batch_end,
               kvstore=args.kv_store,
               optimizer='sgd',
               optimizer_params=optimizer_params,
               arg_params=arg_params,
               aux_params=aux_params,
               begin_epoch=args.load_epoch,
               num_epoch=args.num_epoch)
예제 #2
0
def main():
    """Train Faster R-CNN end-to-end on a roidb built from an image list.

    Configuration is taken from the module-level ``args`` object.  The symbol
    is ``resnet_50`` if ``args.pretrained`` contains "resnet", else
    ``get_faster_rcnn``; the roidb is produced by ``load_gt_roidb_from_list``.

    Side effects: runs ``init_config()`` and scales
    ``config.TRAIN.IMS_PER_BATCH`` in place by the number of GPU contexts.
    """
    logging.info('########## TRAIN FASTER-RCNN WITH APPROXIMATE JOINT END2END #############')
    init_config()

    is_resnet = "resnet" in args.pretrained
    if is_resnet:
        # only ResNet-50 is considered here; the original note says the class
        # count should consider background
        sym = resnet_50(num_class=args.num_classes, bn_mom=args.bn_mom,
                        bn_global=True, is_train=True)
    else:
        sym = get_faster_rcnn(num_classes=args.num_classes)
    feature_sym = sym.get_internals()['rpn_cls_score_output']

    # Multi-GPU setup: the per-batch image count grows with the device count.
    gpu_contexts = [mx.gpu(int(idx)) for idx in args.gpu_ids.split(',')]
    config.TRAIN.IMS_PER_BATCH *= len(gpu_contexts)
    max_data_shape, max_label_shape = get_max_shape(feature_sym)

    # Data: only the roidb (second return value) is used below.
    use_flip = not args.no_flip
    _, roidb = load_gt_roidb_from_list(args.dataset_name, args.lst, args.dataset_root,
                                       args.outdata_path, flip=use_flip)
    loader_options = dict(batch_size=config.TRAIN.IMS_PER_BATCH,
                          anchor_scales=(4, 8, 16, 32),
                          shuffle=not args.no_shuffle,
                          mode='train',
                          ctx=gpu_contexts,
                          need_mean=args.need_mean)
    train_data = AnchorLoader(feature_sym, roidb, **loader_options)

    # Model parameters: pretrained weights, adapted by init_model unless resuming.
    arg_params, aux_params, _ = load_param(args.pretrained, args.load_epoch, convert=True)
    if not args.resume:
        arg_params, aux_params = init_model(arg_params, aux_params, train_data,
                                            sym, args.pretrained)

    data_names = [entry[0] for entry in train_data.provide_data]
    label_names = [entry[0] for entry in train_data.provide_label]
    speedometer = Speedometer(train_data.batch_size, frequent=args.frequent)
    checkpointer = do_checkpoint(args.prefix)

    # Step decay (x0.1 every args.factor_step updates); no warm-up scheduler,
    # which the original author judged unnecessary.
    step_decay = mx.lr_scheduler.FactorScheduler(args.factor_step, 0.1)
    optimizer_params = {
        'momentum': args.mom,
        'wd': args.wd,
        'learning_rate': args.lr,
        'lr_scheduler': step_decay,
        'clip_gradient': 1.0,
        'rescale_grad': 1.0,
    }

    # Parameter-name prefixes that are kept fixed during training.
    if is_resnet:
        frozen = ['conv0', 'stage1', 'stage2', 'bn_data', 'bn0']
    else:
        frozen = ['conv1', 'conv2', 'conv3']

    # Train.
    trainer = MutableModule(sym, data_names=data_names, label_names=label_names,
                            logger=logger, context=gpu_contexts,
                            max_data_shapes=max_data_shape,
                            max_label_shapes=max_label_shape,
                            fixed_param_prefix=frozen)
    trainer.fit(train_data, eval_metric=metric(),
                epoch_end_callback=checkpointer,
                batch_end_callback=speedometer,
                kvstore=args.kv_store,
                optimizer='sgd', optimizer_params=optimizer_params,
                arg_params=arg_params, aux_params=aux_params,
                begin_epoch=args.load_epoch, num_epoch=args.num_epoch)
예제 #3
0
def train_rpn(image_set, year, root_path, devkit_path, pretrained, epoch,
              prefix, ctx, begin_epoch, end_epoch, frequent, kv_store, work_load_list=None, resume=False):
    """Train the region proposal network built by ``get_vgg_rpn``.

    Args:
        image_set, year, root_path, devkit_path: forwarded unchanged to
            ``load_gt_roidb`` (which is called with ``flip=True``).
        pretrained, epoch: arguments handed to ``load_param`` to obtain the
            starting parameters.
        prefix: passed to ``mx.callback.do_checkpoint`` for per-epoch saves.
        ctx: list of device contexts; its length scales the batch settings.
        begin_epoch, end_epoch: passed to ``fit`` as ``begin_epoch`` and
            ``num_epoch`` respectively.
        frequent: reporting interval handed to ``Speedometer``.
        kv_store: passed to ``fit`` as ``kvstore``.
        work_load_list: forwarded to ``AnchorLoader`` and ``MutableModule``.
        resume: when false, the three RPN head layers are (re)initialized
            before training; when true the loaded parameters are used as-is.

    Side effects: multiplies ``config.TRAIN.BATCH_IMAGES`` and
    ``config.TRAIN.BATCH_SIZE`` in place by ``len(ctx)``, and sets the root
    logger level to INFO.
    """
    import numpy as np
    from rcnn.minibatch import assign_anchor

    # set up logger
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Build the symbol once and slice the RPN score output from that same
    # graph instead of constructing a second, identical symbol.
    sym = get_vgg_rpn()
    feat_sym = sym.get_internals()['rpn_cls_score_output']

    # setup multi-gpu
    num_devices = len(ctx)
    config.TRAIN.BATCH_IMAGES *= num_devices
    config.TRAIN.BATCH_SIZE *= num_devices

    # load training data (first return value is not used here)
    _, roidb = load_gt_roidb(image_set, year, root_path, devkit_path, flip=True)
    train_data = AnchorLoader(feat_sym, roidb, batch_size=config.TRAIN.BATCH_SIZE, shuffle=True, mode='train',
                              ctx=ctx, work_load_list=work_load_list)

    # Infer maximum shapes from a 1000x1000 input: run the anchor assignment
    # on an empty (0, 5) ground-truth array just to read the label shapes.
    max_data_shape = [('data', (config.TRAIN.BATCH_SIZE, 3, 1000, 1000))]
    _, feat_shape, _ = feat_sym.infer_shape(**dict(max_data_shape))
    label = assign_anchor(feat_shape[0], np.zeros((0, 5)), [[1000, 1000, 1.0]])
    max_label_shape = [(name, label[name].shape)
                       for name in ('label', 'bbox_target', 'bbox_inside_weight', 'bbox_outside_weight')]
    # Single pre-formatted argument: same output as the Python 2 print
    # statement, and also valid when print is a function.
    print('providing maximum shape %s %s' % (max_data_shape, max_label_shape))

    # load pretrained
    args, auxs = load_param(pretrained, epoch, convert=True)

    # initialize params: Gaussian(0, 0.01) weights and zero biases for the
    # RPN head layers, sized from the shapes inferred for the full symbol
    if not resume:
        input_shapes = dict(train_data.provide_data + train_data.provide_label)
        arg_shape, _, _ = sym.infer_shape(**input_shapes)
        arg_shape_dict = dict(zip(sym.list_arguments(), arg_shape))
        for layer in ('rpn_conv_3x3', 'rpn_cls_score', 'rpn_bbox_pred'):
            weight_name = layer + '_weight'
            bias_name = layer + '_bias'
            args[weight_name] = mx.random.normal(0, 0.01, shape=arg_shape_dict[weight_name])
            args[bias_name] = mx.nd.zeros(shape=arg_shape_dict[bias_name])

    # prepare training: parameter-name prefixes kept fixed by MutableModule
    if config.TRAIN.FINETUNE:
        fixed_param_prefix = ['conv1', 'conv2', 'conv3', 'conv4', 'conv5']
    else:
        fixed_param_prefix = ['conv1', 'conv2']
    data_names = [k[0] for k in train_data.provide_data]
    label_names = [k[0] for k in train_data.provide_label]
    batch_end_callback = Speedometer(train_data.batch_size, frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)

    # With an RPN, labels equal to -1 are ignored by accuracy / log-loss.
    if config.TRAIN.HAS_RPN is True:
        eval_metric = AccuracyMetric(use_ignore=True, ignore=-1)
        cls_metric = LogLossMetric(use_ignore=True, ignore=-1)
    else:
        eval_metric = AccuracyMetric()
        cls_metric = LogLossMetric()
    bbox_metric = SmoothL1LossMetric()
    eval_metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [eval_metric, cls_metric, bbox_metric]:
        eval_metrics.add(child_metric)

    optimizer_params = {'momentum': 0.9,
                        'wd': 0.0005,
                        'learning_rate': 0.001,
                        'lr_scheduler': mx.lr_scheduler.FactorScheduler(60000, 0.1),
                        'rescale_grad': (1.0 / config.TRAIN.BATCH_SIZE)}

    # train
    mod = MutableModule(sym, data_names=data_names, label_names=label_names,
                        logger=logger, context=ctx, work_load_list=work_load_list,
                        max_data_shapes=max_data_shape, max_label_shapes=max_label_shape,
                        fixed_param_prefix=fixed_param_prefix)
    mod.fit(train_data, eval_metric=eval_metrics, epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback, kvstore=kv_store,
            optimizer='sgd', optimizer_params=optimizer_params,
            arg_params=args, aux_params=auxs, begin_epoch=begin_epoch, num_epoch=end_epoch)
예제 #4
0
def end2end_train(image_set, test_image_set, year, root_path, devkit_path, pretrained, epoch, prefix,
                  ctx, begin_epoch, num_epoch, frequent, kv_store, mom, wd, lr, num_classes, monitor,
                  work_load_list=None, resume=False, use_flip=True, factor_step=50000):
    """Train Faster R-CNN (RPN + detection head) jointly, end to end.

    Args:
        image_set, year, root_path, devkit_path: forwarded to ``load_gt_roidb``.
        test_image_set: accepted for interface compatibility; not used in
            this function.
        pretrained, epoch: arguments handed to ``load_param``.
        prefix: passed to ``do_checkpoint`` for per-epoch saves.
        ctx: list of device contexts; its length scales the batch settings.
        begin_epoch, num_epoch: passed straight through to ``fit``.
        frequent: reporting interval handed to ``Speedometer``.
        kv_store: passed to ``fit`` as ``kvstore``.
        mom, wd, lr: SGD momentum, weight decay and learning rate.
        num_classes: passed to ``get_faster_rcnn``.
        monitor: when truthy, an ``mx.mon.Monitor`` (interval 100) reporting
            ``norm(d) / sqrt(d.size)`` is attached to ``fit``.
        work_load_list: forwarded to ``AnchorLoader`` and ``MutableModule``.
        resume: when false, any ``fc8`` entries are dropped from the loaded
            parameters and the RPN / detection head layers are (re)initialized.
        use_flip: passed to ``load_gt_roidb`` as ``flip``.
        factor_step: step size of the 0.1x ``FactorScheduler``.

    Side effects: overwrites several ``config`` / ``config.TRAIN`` fields,
    multiplies ``config.TRAIN.IMS_PER_BATCH`` and ``config.TRAIN.BATCH_SIZE``
    in place by ``len(ctx)``, and sets the root logger level to INFO.
    """
    import numpy as np
    from rcnn.minibatch import assign_anchor

    # set up logger
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    mon = None
    config.TRAIN.BG_THRESH_HI = 0.5  # TODO(verify)
    config.TRAIN.BG_THRESH_LO = 0.0  # TODO(verify)
    config.TRAIN.RPN_MIN_SIZE = 16

    logging.info('########## TRAIN FASTER-RCNN WITH APPROXIMATE JOINT END2END #############')
    config.TRAIN.HAS_RPN = True
    config.END2END = 1
    config.TRAIN.BBOX_NORMALIZATION_PRECOMPUTED = True
    sym = get_faster_rcnn(num_classes=num_classes)
    feat_sym = sym.get_internals()['rpn_cls_score_output']

    # setup multi-gpu
    num_devices = len(ctx)
    config.TRAIN.IMS_PER_BATCH *= num_devices
    config.TRAIN.BATCH_SIZE *= num_devices  # not used in this function

    # Infer maximum shapes from a 1000x1000 input: run the anchor assignment
    # on an empty (0, 5) ground-truth array just to read the label shapes.
    max_data_shape = [('data', (config.TRAIN.IMS_PER_BATCH, 3, 1000, 1000))]
    _, feat_shape, _ = feat_sym.infer_shape(**dict(max_data_shape))
    label = assign_anchor(feat_shape[0], np.zeros((0, 5)), [[1000, 1000, 1.0]])
    max_label_shape = [(name, label[name].shape)
                       for name in ('label', 'bbox_target', 'bbox_inside_weight', 'bbox_outside_weight')]
    # assume at most 100 objects per image, 5 values each
    max_label_shape.append(('gt_boxes', (config.TRAIN.IMS_PER_BATCH, 5 * 100)))
    # Single pre-formatted argument: same output as the Python 2 print
    # statement, and also valid when print is a function.
    print('providing maximum shape %s %s' % (max_data_shape, max_label_shape))

    # load training data (first return value is not used here)
    _, roidb = load_gt_roidb(image_set, year, root_path, devkit_path, flip=use_flip)
    train_data = AnchorLoader(feat_sym, roidb, batch_size=config.TRAIN.IMS_PER_BATCH, shuffle=True, mode='train',
                              ctx=ctx, work_load_list=work_load_list)
    # load pretrained
    args, auxs, _ = load_param(pretrained, epoch, convert=True)

    # initialize params
    if not resume:
        # Drop fc8 parameters if the checkpoint has them; pop() instead of
        # del so a checkpoint without an fc8 layer does not raise KeyError.
        args.pop('fc8_weight', None)
        args.pop('fc8_bias', None)

        # Shapes are inferred with the leading (batch) dimension forced to 1.
        input_shapes = {k: (1,) + v[1:] for k, v in train_data.provide_data + train_data.provide_label}
        arg_shape, _, _ = sym.infer_shape(**input_shapes)
        arg_shape_dict = dict(zip(sym.list_arguments(), arg_shape))

        # (layer, weight std-dev); biases are zero.  rpn_bbox_pred uses a
        # smaller std-dev so it is not likely to explode with bbox_delta.
        head_init = (('rpn_conv_3x3', 0.01),
                     ('rpn_cls_score', 0.01),
                     ('rpn_bbox_pred', 0.001),
                     ('cls_score', 0.01),
                     ('bbox_pred', 0.01))
        for layer, sigma in head_init:
            weight_name = layer + '_weight'
            bias_name = layer + '_bias'
            args[weight_name] = mx.random.normal(0, sigma, shape=arg_shape_dict[weight_name])
            args[bias_name] = mx.nd.zeros(shape=arg_shape_dict[bias_name])

    # prepare training: parameter-name prefixes kept fixed by MutableModule
    if config.TRAIN.FINETUNE:
        fixed_param_prefix = ['conv1', 'conv2', 'conv3', 'conv4', 'conv5']
    else:
        fixed_param_prefix = ['conv1', 'conv2']
    data_names = [k[0] for k in train_data.provide_data]
    label_names = [k[0] for k in train_data.provide_label]
    batch_end_callback = Speedometer(train_data.batch_size, frequent=frequent)
    epoch_end_callback = do_checkpoint(prefix)

    # RPN metrics first (ignoring label -1), then the detection-head metrics.
    child_metrics = [
        AccuracyMetric(use_ignore=True, ignore=-1, ex_rpn=True),
        LogLossMetric(use_ignore=True, ignore=-1, ex_rpn=True),
        SmoothL1LossMetric(ex_rpn=True),
        AccuracyMetric(),
        LogLossMetric(),
        SmoothL1LossMetric(),
    ]
    eval_metrics = mx.metric.CompositeEvalMetric()
    for child_metric in child_metrics:
        eval_metrics.add(child_metric)

    optimizer_params = {'momentum': mom,
                        'wd': wd,
                        'learning_rate': lr,
                        'lr_scheduler': mx.lr_scheduler.FactorScheduler(factor_step, 0.1),
                        'clip_gradient': 1.0,
                        'rescale_grad': 1.0}
    # train
    mod = MutableModule(sym, data_names=data_names, label_names=label_names,
                        logger=logger, context=ctx, work_load_list=work_load_list,
                        max_data_shapes=max_data_shape, max_label_shapes=max_label_shape,
                        fixed_param_prefix=fixed_param_prefix)
    if monitor:
        def norm_stat(d):
            # norm of the array scaled by the square root of its element count
            return mx.nd.norm(d)/np.sqrt(d.size)
        mon = mx.mon.Monitor(100, norm_stat)

    mod.fit(train_data, eval_metric=eval_metrics, epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback, kvstore=kv_store,
            optimizer='sgd', optimizer_params=optimizer_params, monitor=mon,
            arg_params=args, aux_params=auxs, begin_epoch=begin_epoch, num_epoch=num_epoch)
예제 #5
0
def train_rcnn(image_set,
               year,
               root_path,
               devkit_path,
               pretrained,
               epoch,
               prefix,
               ctx,
               begin_epoch,
               end_epoch,
               frequent,
               kv_store,
               work_load_list=None,
               resume=False,
               proposal='rpn'):
    """Train the detection network built by ``get_vgg_rcnn`` on proposals.

    The roidb loader is looked up by name as ``load_<proposal>_roidb`` and
    called with the dataset arguments and ``flip=True``.  After ``fit``
    finishes, every checkpoint from ``begin_epoch + 1`` through ``end_epoch``
    is reloaded, its ``bbox_pred`` weight/bias are rescaled with the ``stds``
    and ``means`` returned by the loader, and the checkpoint is saved again
    under the same prefix and epoch.

    Side effects: multiplies ``config.TRAIN.BATCH_IMAGES`` and
    ``config.TRAIN.BATCH_SIZE`` in place by ``len(ctx)``, sets the root logger
    level to INFO, and overwrites checkpoint files.
    """
    # set up logger
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # load symbol
    sym = get_vgg_rcnn()

    # setup multi-gpu
    device_count = len(ctx)
    config.TRAIN.BATCH_IMAGES *= device_count
    config.TRAIN.BATCH_SIZE *= device_count

    # load training data
    # NOTE(review): the loader is resolved with eval() on a name built from
    # `proposal`; make sure `proposal` never comes from untrusted input.
    roidb_loader = eval('load_' + proposal + '_roidb')
    _, roidb, means, stds = roidb_loader(image_set, year, root_path, devkit_path, flip=True)
    train_data = ROIIter(roidb, batch_size=config.TRAIN.BATCH_IMAGES, shuffle=True,
                         mode='train', ctx=ctx, work_load_list=work_load_list)

    # infer max shape
    max_data_shape = [('data', (config.TRAIN.BATCH_IMAGES, 3, 1000, 1000))]

    # load pretrained
    args, auxs, _ = load_param(pretrained, epoch, convert=True)

    # initialize params: Gaussian weights (std-dev per layer) and zero biases
    # for the two output layers, sized from the inferred argument shapes
    if not resume:
        shape_inputs = dict(train_data.provide_data + train_data.provide_label)
        inferred_shapes, _, _ = sym.infer_shape(**shape_inputs)
        shape_of = dict(zip(sym.list_arguments(), inferred_shapes))
        for layer, sigma in (('cls_score', 0.01), ('bbox_pred', 0.001)):
            weight_key = layer + '_weight'
            bias_key = layer + '_bias'
            args[weight_key] = mx.random.normal(0, sigma, shape=shape_of[weight_key])
            args[bias_key] = mx.nd.zeros(shape=shape_of[bias_key])

    # prepare training: parameter-name prefixes kept fixed by MutableModule
    fixed_param_prefix = ['conv1', 'conv2']
    if config.TRAIN.FINETUNE:
        fixed_param_prefix = ['conv1', 'conv2', 'conv3', 'conv4', 'conv5']
    data_names = [item[0] for item in train_data.provide_data]
    label_names = [item[0] for item in train_data.provide_label]
    batch_end_callback = Speedometer(train_data.batch_size, frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)

    # With an RPN, labels equal to -1 are ignored by accuracy / log-loss.
    metric_kwargs = {}
    if config.TRAIN.HAS_RPN is True:
        metric_kwargs = {'use_ignore': True, 'ignore': -1}
    children = [AccuracyMetric(**metric_kwargs),
                LogLossMetric(**metric_kwargs),
                SmoothL1LossMetric()]
    eval_metrics = mx.metric.CompositeEvalMetric()
    for child in children:
        eval_metrics.add(child)

    optimizer_params = dict(
        momentum=0.9,
        wd=0.0005,
        learning_rate=0.001,
        lr_scheduler=mx.lr_scheduler.FactorScheduler(30000, 0.1),
        rescale_grad=(1.0 / config.TRAIN.BATCH_SIZE),
    )

    # train
    mod = MutableModule(sym, data_names=data_names, label_names=label_names,
                        logger=logger, context=ctx, work_load_list=work_load_list,
                        max_data_shapes=max_data_shape,
                        fixed_param_prefix=fixed_param_prefix)
    mod.fit(train_data, eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            kvstore=kv_store, optimizer='sgd', optimizer_params=optimizer_params,
            arg_params=args, aux_params=auxs,
            begin_epoch=begin_epoch, num_epoch=end_epoch)

    # edit params and save: scale bbox_pred weights by stds and map the bias
    # to bias * stds + means (presumably undoing regression-target
    # normalization -- confirm against the roidb loader)
    for saved_epoch in range(begin_epoch + 1, end_epoch + 1):
        arg_params, aux_params = load_checkpoint(prefix, saved_epoch)
        scaled_weight = arg_params['bbox_pred_weight'].T * mx.nd.array(stds)
        arg_params['bbox_pred_weight'] = scaled_weight.T
        scaled_bias = arg_params['bbox_pred_bias'] * mx.nd.array(stds)
        arg_params['bbox_pred_bias'] = scaled_bias + mx.nd.array(means)
        save_checkpoint(prefix, saved_epoch, arg_params, aux_params)
예제 #6
0
def train_rcnn(image_set, year, root_path, devkit_path, pretrained, epoch,
               prefix, ctx, begin_epoch, end_epoch, frequent, kv_store,
               work_load_list=None, resume=False, proposal='rpn'):
    """Fit the ``get_vgg_rcnn`` detection network and post-process checkpoints.

    Steps, in order: scale the batch settings by the number of contexts, load
    a roidb through ``load_<proposal>_roidb`` (``flip=True``), load pretrained
    parameters, optionally initialize the ``cls_score`` / ``bbox_pred``
    layers, train with SGD, then rewrite each saved checkpoint so that its
    ``bbox_pred`` weight and bias incorporate the loader's ``stds``/``means``.

    Side effects: multiplies ``config.TRAIN.BATCH_IMAGES`` and
    ``config.TRAIN.BATCH_SIZE`` in place by ``len(ctx)``, sets the root logger
    level to INFO, and overwrites checkpoint files for epochs
    ``begin_epoch + 1`` .. ``end_epoch``.
    """
    # set up logger
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # load symbol
    sym = get_vgg_rcnn()

    # setup multi-gpu
    n_ctx = len(ctx)
    config.TRAIN.BATCH_IMAGES *= n_ctx
    config.TRAIN.BATCH_SIZE *= n_ctx

    # load training data
    # NOTE(review): eval() resolves the loader from a name derived from
    # `proposal`; keep `proposal` restricted to trusted values.
    load_roidb = eval('load_' + proposal + '_roidb')
    _, roidb, means, stds = load_roidb(image_set, year, root_path, devkit_path, flip=True)
    train_data = ROIIter(roidb,
                         batch_size=config.TRAIN.BATCH_IMAGES,
                         shuffle=True,
                         mode='train',
                         ctx=ctx,
                         work_load_list=work_load_list)

    # infer max shape
    max_data_shape = [('data', (config.TRAIN.BATCH_IMAGES, 3, 1000, 1000))]

    # load pretrained
    args, auxs = load_param(pretrained, epoch, convert=True)

    # initialize params (skipped when resuming): Gaussian weights, zero biases
    if not resume:
        input_shapes = {name: shape
                        for name, shape in train_data.provide_data + train_data.provide_label}
        arg_shapes, _, _ = sym.infer_shape(**input_shapes)
        shape_by_name = dict(zip(sym.list_arguments(), arg_shapes))

        def gaussian(param_name, sigma):
            # random-normal tensor shaped like the named symbol argument
            return mx.random.normal(0, sigma, shape=shape_by_name[param_name])

        def zeros(param_name):
            # all-zero tensor shaped like the named symbol argument
            return mx.nd.zeros(shape=shape_by_name[param_name])

        args['cls_score_weight'] = gaussian('cls_score_weight', 0.01)
        args['cls_score_bias'] = zeros('cls_score_bias')
        args['bbox_pred_weight'] = gaussian('bbox_pred_weight', 0.001)
        args['bbox_pred_bias'] = zeros('bbox_pred_bias')

    # prepare training: parameter-name prefixes kept fixed by MutableModule
    if config.TRAIN.FINETUNE:
        frozen_prefixes = ['conv1', 'conv2', 'conv3', 'conv4', 'conv5']
    else:
        frozen_prefixes = ['conv1', 'conv2']
    data_names = [pair[0] for pair in train_data.provide_data]
    label_names = [pair[0] for pair in train_data.provide_label]
    speed_callback = Speedometer(train_data.batch_size, frequent=frequent)
    checkpoint_callback = mx.callback.do_checkpoint(prefix)

    # With an RPN, labels equal to -1 are ignored by accuracy / log-loss.
    with_rpn = config.TRAIN.HAS_RPN is True
    if with_rpn:
        accuracy = AccuracyMetric(use_ignore=True, ignore=-1)
        log_loss = LogLossMetric(use_ignore=True, ignore=-1)
    else:
        accuracy = AccuracyMetric()
        log_loss = LogLossMetric()
    smooth_l1 = SmoothL1LossMetric()
    eval_metrics = mx.metric.CompositeEvalMetric()
    eval_metrics.add(accuracy)
    eval_metrics.add(log_loss)
    eval_metrics.add(smooth_l1)

    optimizer_params = {
        'momentum': 0.9,
        'wd': 0.0005,
        'learning_rate': 0.001,
        'lr_scheduler': mx.lr_scheduler.FactorScheduler(30000, 0.1),
        'rescale_grad': (1.0 / config.TRAIN.BATCH_SIZE),
    }

    # train
    mod = MutableModule(sym,
                        data_names=data_names,
                        label_names=label_names,
                        logger=logger,
                        context=ctx,
                        work_load_list=work_load_list,
                        max_data_shapes=max_data_shape,
                        fixed_param_prefix=frozen_prefixes)
    mod.fit(train_data,
            eval_metric=eval_metrics,
            epoch_end_callback=checkpoint_callback,
            batch_end_callback=speed_callback,
            kvstore=kv_store,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            arg_params=args,
            aux_params=auxs,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch)

    # edit params and save: weight <- (weight.T * stds).T and
    # bias <- bias * stds + means (presumably undoing regression-target
    # normalization -- confirm against the roidb loader)
    first_saved = begin_epoch + 1
    for ckpt_epoch in range(first_saved, end_epoch + 1):
        arg_params, aux_params = load_checkpoint(prefix, ckpt_epoch)
        weight_t = arg_params['bbox_pred_weight'].T
        arg_params['bbox_pred_weight'] = (weight_t * mx.nd.array(stds)).T
        bias = arg_params['bbox_pred_bias']
        arg_params['bbox_pred_bias'] = bias * mx.nd.array(stds) + mx.nd.array(means)
        save_checkpoint(prefix, ckpt_epoch, arg_params, aux_params)
def train_rpn(image_set,
              year,
              root_path,
              devkit_path,
              pretrained,
              epoch,
              prefix,
              ctx,
              begin_epoch,
              end_epoch,
              frequent,
              kv_store,
              work_load_list=None,
              resume=False):
    """Train the region proposal network built by ``get_vgg_rpn``.

    Args:
        image_set, year, root_path, devkit_path: forwarded unchanged to
            ``load_gt_roidb`` (which is called with ``flip=True``).
        pretrained, epoch: arguments handed to ``load_param`` to obtain the
            starting parameters.
        prefix: passed to ``mx.callback.do_checkpoint`` for per-epoch saves.
        ctx: list of device contexts; its length scales the batch settings.
        begin_epoch, end_epoch: passed to ``fit`` as ``begin_epoch`` and
            ``num_epoch`` respectively.
        frequent: reporting interval handed to ``Speedometer``.
        kv_store: passed to ``fit`` as ``kvstore``.
        work_load_list: forwarded to ``AnchorLoader`` and ``MutableModule``.
        resume: when false, the three RPN head layers are (re)initialized
            before training; when true the loaded parameters are used as-is.

    Side effects: multiplies ``config.TRAIN.BATCH_IMAGES`` and
    ``config.TRAIN.BATCH_SIZE`` in place by ``len(ctx)``, and sets the root
    logger level to INFO.
    """
    import numpy as np
    from rcnn.minibatch import assign_anchor

    # set up logger
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # load symbol: build it once and take the RPN score output from the same
    # graph rather than constructing a second, identical symbol
    sym = get_vgg_rpn()
    feat_sym = sym.get_internals()['rpn_cls_score_output']

    # setup multi-gpu
    num_devices = len(ctx)
    config.TRAIN.BATCH_IMAGES *= num_devices
    config.TRAIN.BATCH_SIZE *= num_devices

    # load training data (first return value is not used here)
    _, roidb = load_gt_roidb(image_set,
                             year,
                             root_path,
                             devkit_path,
                             flip=True)
    train_data = AnchorLoader(feat_sym,
                              roidb,
                              batch_size=config.TRAIN.BATCH_SIZE,
                              shuffle=True,
                              mode='train',
                              ctx=ctx,
                              work_load_list=work_load_list)

    # infer max shape from a 1000x1000 input; the anchor assignment is run on
    # an empty (0, 5) ground-truth array only to read the label shapes
    max_data_shape = [('data', (config.TRAIN.BATCH_SIZE, 3, 1000, 1000))]
    _, feat_shape, _ = feat_sym.infer_shape(**dict(max_data_shape))
    label = assign_anchor(feat_shape[0], np.zeros((0, 5)), [[1000, 1000, 1.0]])
    label_keys = ('label', 'bbox_target', 'bbox_inside_weight',
                  'bbox_outside_weight')
    max_label_shape = [(key, label[key].shape) for key in label_keys]
    # Single pre-formatted argument: same output as the Python 2 print
    # statement, and also valid when print is a function.
    print('providing maximum shape %s %s' % (max_data_shape, max_label_shape))

    # load pretrained
    args, auxs = load_param(pretrained, epoch, convert=True)

    # initialize params: Gaussian(0, 0.01) weights and zero biases for the
    # RPN head layers, sized from the shapes inferred for the full symbol
    if not resume:
        input_shapes = dict(train_data.provide_data + train_data.provide_label)
        arg_shape, _, _ = sym.infer_shape(**input_shapes)
        arg_shape_dict = dict(zip(sym.list_arguments(), arg_shape))
        for layer in ('rpn_conv_3x3', 'rpn_cls_score', 'rpn_bbox_pred'):
            weight_name = layer + '_weight'
            bias_name = layer + '_bias'
            args[weight_name] = mx.random.normal(
                0, 0.01, shape=arg_shape_dict[weight_name])
            args[bias_name] = mx.nd.zeros(shape=arg_shape_dict[bias_name])

    # prepare training: parameter-name prefixes kept fixed by MutableModule
    if config.TRAIN.FINETUNE:
        fixed_param_prefix = ['conv1', 'conv2', 'conv3', 'conv4', 'conv5']
    else:
        fixed_param_prefix = ['conv1', 'conv2']
    data_names = [k[0] for k in train_data.provide_data]
    label_names = [k[0] for k in train_data.provide_label]
    batch_end_callback = Speedometer(train_data.batch_size, frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)

    # With an RPN, labels equal to -1 are ignored by accuracy / log-loss.
    if config.TRAIN.HAS_RPN is True:
        eval_metric = AccuracyMetric(use_ignore=True, ignore=-1)
        cls_metric = LogLossMetric(use_ignore=True, ignore=-1)
    else:
        eval_metric = AccuracyMetric()
        cls_metric = LogLossMetric()
    bbox_metric = SmoothL1LossMetric()
    eval_metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [eval_metric, cls_metric, bbox_metric]:
        eval_metrics.add(child_metric)

    optimizer_params = {
        'momentum': 0.9,
        'wd': 0.0005,
        'learning_rate': 0.001,
        'lr_scheduler': mx.lr_scheduler.FactorScheduler(60000, 0.1),
        'rescale_grad': (1.0 / config.TRAIN.BATCH_SIZE)
    }

    # train
    mod = MutableModule(sym,
                        data_names=data_names,
                        label_names=label_names,
                        logger=logger,
                        context=ctx,
                        work_load_list=work_load_list,
                        max_data_shapes=max_data_shape,
                        max_label_shapes=max_label_shape,
                        fixed_param_prefix=fixed_param_prefix)
    mod.fit(train_data,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            kvstore=kv_store,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            arg_params=args,
            aux_params=auxs,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch)