Example #1
def train_net(args, ctx, pretrained, epoch, prefix, begin_epoch, end_epoch,
              lr=0.001, lr_step='5'):
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # setup config
    config.TRAIN.BATCH_IMAGES = 1
    config.TRAIN.BATCH_ROIS = 128
    config.TRAIN.END2END = True
    config.TRAIN.BBOX_NORMALIZATION_PRECOMPUTED = True

    # load symbol
    sym = eval('get_' + args.network + '_train')(num_classes=config.NUM_CLASSES, num_anchors=config.NUM_ANCHORS)
    feat_sym = sym.get_internals()['rpn_cls_score_output']

    # setup multi-gpu
    batch_size = len(ctx)
    input_batch_size = config.TRAIN.BATCH_IMAGES * batch_size

    # print config
    pprint.pprint(config)

    # load dataset and prepare imdb for training
    image_sets = [iset for iset in args.image_set.split('+')]
    roidbs = [load_gt_roidb(args.dataset, image_set, args.root_path, args.dataset_path,
                            flip=not args.no_flip)
              for image_set in image_sets]
    roidb = merge_roidb(roidbs)
    roidb = filter_roidb(roidb)

    # load training data
    train_data = AnchorLoader(feat_sym, roidb, batch_size=input_batch_size, shuffle=not args.no_shuffle,
                              ctx=ctx, work_load_list=args.work_load_list,
                              feat_stride=config.RPN_FEAT_STRIDE, anchor_scales=config.ANCHOR_SCALES,
                              anchor_ratios=config.ANCHOR_RATIOS, aspect_grouping=config.TRAIN.ASPECT_GROUPING)

    # infer max shape
    max_data_shape = [('data', (input_batch_size, 3, max([v[0] for v in config.SCALES]), max([v[1] for v in config.SCALES])))]
    max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape)
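    # reserve room for up to 100 ground-truth boxes per image (each box is 5 values: 4 coordinates + class)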
    max_data_shape.append(('gt_boxes', (input_batch_size, 100, 5)))
    print('providing maximum shape', max_data_shape, max_label_shape)

    # infer shape
    data_shape_dict = dict(train_data.provide_data + train_data.provide_label)
    arg_shape, out_shape, aux_shape = sym.infer_shape(**data_shape_dict)
    arg_shape_dict = dict(zip(sym.list_arguments(), arg_shape))
    out_shape_dict = dict(zip(sym.list_outputs(), out_shape))
    aux_shape_dict = dict(zip(sym.list_auxiliary_states(), aux_shape))
    print('output shape')
    pprint.pprint(out_shape_dict)

    # load and initialize params
    if args.resume:
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        arg_params['rpn_conv_3x3_weight'] = mx.random.normal(0, 0.01, shape=arg_shape_dict['rpn_conv_3x3_weight'])
        arg_params['rpn_conv_3x3_bias'] = mx.nd.zeros(shape=arg_shape_dict['rpn_conv_3x3_bias'])
        arg_params['rpn_cls_score_weight'] = mx.random.normal(0, 0.01, shape=arg_shape_dict['rpn_cls_score_weight'])
        arg_params['rpn_cls_score_bias'] = mx.nd.zeros(shape=arg_shape_dict['rpn_cls_score_bias'])
        arg_params['rpn_bbox_pred_weight'] = mx.random.normal(0, 0.01, shape=arg_shape_dict['rpn_bbox_pred_weight'])
        arg_params['rpn_bbox_pred_bias'] = mx.nd.zeros(shape=arg_shape_dict['rpn_bbox_pred_bias'])
        arg_params['cls_score_weight'] = mx.random.normal(0, 0.01, shape=arg_shape_dict['cls_score_weight'])
        arg_params['cls_score_bias'] = mx.nd.zeros(shape=arg_shape_dict['cls_score_bias'])
        arg_params['bbox_pred_weight'] = mx.random.normal(0, 0.001, shape=arg_shape_dict['bbox_pred_weight'])
        arg_params['bbox_pred_bias'] = mx.nd.zeros(shape=arg_shape_dict['bbox_pred_bias'])

    # check parameter shapes
    for k in sym.list_arguments():
        if k in data_shape_dict:
            continue
        assert k in arg_params, k + ' not initialized'
        assert arg_params[k].shape == arg_shape_dict[k], \
            'shape inconsistent for ' + k + ' inferred ' + str(arg_shape_dict[k]) + ' provided ' + str(arg_params[k].shape)
    for k in sym.list_auxiliary_states():
        assert k in aux_params, k + ' not initialized'
        assert aux_params[k].shape == aux_shape_dict[k], \
            'shape inconsistent for ' + k + ' inferred ' + str(aux_shape_dict[k]) + ' provided ' + str(aux_params[k].shape)

    # create solver
    fixed_param_prefix = config.FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data]
    label_names = [k[0] for k in train_data.provide_label]
    mod = MutableModule(sym, data_names=data_names, label_names=label_names,
                        logger=logger, context=ctx, work_load_list=args.work_load_list,
                        max_data_shapes=max_data_shape, max_label_shapes=max_label_shape,
                        fixed_param_prefix=fixed_param_prefix)

    # decide training params
    # metric
    rpn_eval_metric = metric.RPNAccMetric()
    rpn_cls_metric = metric.RPNLogLossMetric()
    rpn_bbox_metric = metric.RPNL1LossMetric()
    eval_metric = metric.RCNNAccMetric()
    cls_metric = metric.RCNNLogLossMetric()
    bbox_metric = metric.RCNNL1LossMetric()
    eval_metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [rpn_eval_metric, rpn_cls_metric, rpn_bbox_metric, eval_metric, cls_metric, bbox_metric]:
        eval_metrics.add(child_metric)
    # callback
    batch_end_callback = callback.Speedometer(train_data.batch_size, frequent=args.frequent)
    means = np.tile(np.array(config.TRAIN.BBOX_MEANS), config.NUM_CLASSES)
    stds = np.tile(np.array(config.TRAIN.BBOX_STDS), config.NUM_CLASSES)
    epoch_end_callback = callback.do_checkpoint(prefix, means, stds)
    # decide learning rate
    base_lr = lr
    lr_factor = 0.1
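    # pre-decay the base lr for any step boundaries already passed (e.g. when resuming), then
    # convert the remaining boundaries from epochs to iterations for MultiFactorScheduler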
    lr_epoch = [int(epoch) for epoch in lr_step.split(',')]
    lr_epoch_diff = [epoch - begin_epoch for epoch in lr_epoch if epoch > begin_epoch]
    lr = base_lr * (lr_factor ** (len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [int(epoch * len(roidb) / batch_size) for epoch in lr_epoch_diff]
    print('lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters)
    lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(lr_iters, lr_factor)
    # optimizer
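    # rescale_grad averages gradients across devices; clip_gradient clips each gradient element to [-5, 5]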
    optimizer_params = {'momentum': 0.9,
                        'wd': 0.0005,
                        'learning_rate': lr,
                        'lr_scheduler': lr_scheduler,
                        'rescale_grad': (1.0 / batch_size),
                        'clip_gradient': 5}

    # train
    mod.fit(train_data, eval_metric=eval_metrics, epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback, kvstore=args.kvstore,
            optimizer='sgd', optimizer_params=optimizer_params,
            arg_params=arg_params, aux_params=aux_params, begin_epoch=begin_epoch, num_epoch=end_epoch)
Example #2
def train_net(args,
              ctx,
              pretrained,
              epoch,
              prefix,
              begin_epoch,
              end_epoch,
              lr=0.001,
              lr_step='5'):
    # setup config
    config.TRAIN.BATCH_IMAGES = 1
    config.TRAIN.BATCH_ROIS = 128
    config.TRAIN.END2END = True
    config.TRAIN.BBOX_NORMALIZATION_PRECOMPUTED = True

    # load symbol
    sym = eval('get_' + args.network + '_train')(
        num_classes=config.NUM_CLASSES, num_anchors=config.NUM_ANCHORS)
    feat_sym = sym.get_internals()['rpn_cls_score_output']

    # setup multi-gpu
    batch_size = len(ctx)
    input_batch_size = config.TRAIN.BATCH_IMAGES * batch_size

    # print config
    logger.info(pprint.pformat(config))

    # load dataset and prepare imdb for training
    image_sets = [iset for iset in args.image_set.split('+')]
    roidbs = [
        load_gt_roidb(args.dataset,
                      image_set,
                      args.root_path,
                      args.dataset_path,
                      flip=not args.no_flip) for image_set in image_sets
    ]
    roidb = merge_roidb(roidbs)
    roidb = filter_roidb(roidb)

    # load training data
    train_data = AnchorLoader(feat_sym,
                              roidb,
                              batch_size=input_batch_size,
                              shuffle=not args.no_shuffle,
                              ctx=ctx,
                              work_load_list=args.work_load_list,
                              feat_stride=config.RPN_FEAT_STRIDE,
                              anchor_scales=config.ANCHOR_SCALES,
                              anchor_ratios=config.ANCHOR_RATIOS,
                              aspect_grouping=config.TRAIN.ASPECT_GROUPING)

    # infer max shape
    max_data_shape = [('data', (input_batch_size, 3,
                                max([v[0] for v in config.SCALES]),
                                max([v[1] for v in config.SCALES])))]
    max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape)
    max_data_shape.append(('gt_boxes', (input_batch_size, 100, 5)))
    logger.info('providing maximum shape %s %s' %
                (max_data_shape, max_label_shape))

    # infer shape
    data_shape_dict = dict(train_data.provide_data + train_data.provide_label)
    arg_shape, out_shape, aux_shape = sym.infer_shape(**data_shape_dict)
    arg_shape_dict = dict(zip(sym.list_arguments(), arg_shape))
    out_shape_dict = dict(zip(sym.list_outputs(), out_shape))
    aux_shape_dict = dict(zip(sym.list_auxiliary_states(), aux_shape))
    logger.info('output shape %s' % pprint.pformat(out_shape_dict))

    # load and initialize params
    if args.resume:
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        arg_params['rpn_conv_3x3_weight'] = mx.random.normal(
            0, 0.01, shape=arg_shape_dict['rpn_conv_3x3_weight'])
        arg_params['rpn_conv_3x3_bias'] = mx.nd.zeros(
            shape=arg_shape_dict['rpn_conv_3x3_bias'])
        arg_params['rpn_cls_score_weight'] = mx.random.normal(
            0, 0.01, shape=arg_shape_dict['rpn_cls_score_weight'])
        arg_params['rpn_cls_score_bias'] = mx.nd.zeros(
            shape=arg_shape_dict['rpn_cls_score_bias'])
        arg_params['rpn_bbox_pred_weight'] = mx.random.normal(
            0, 0.01, shape=arg_shape_dict['rpn_bbox_pred_weight'])
        arg_params['rpn_bbox_pred_bias'] = mx.nd.zeros(
            shape=arg_shape_dict['rpn_bbox_pred_bias'])
        arg_params['cls_score_weight'] = mx.random.normal(
            0, 0.01, shape=arg_shape_dict['cls_score_weight'])
        arg_params['cls_score_bias'] = mx.nd.zeros(
            shape=arg_shape_dict['cls_score_bias'])
        arg_params['bbox_pred_weight'] = mx.random.normal(
            0, 0.001, shape=arg_shape_dict['bbox_pred_weight'])
        arg_params['bbox_pred_bias'] = mx.nd.zeros(
            shape=arg_shape_dict['bbox_pred_bias'])

    # check parameter shapes
    for k in sym.list_arguments():
        if k in data_shape_dict:
            continue
        assert k in arg_params, k + ' not initialized'
        assert arg_params[k].shape == arg_shape_dict[k], \
            'shape inconsistent for ' + k + ' inferred ' + str(arg_shape_dict[k]) + ' provided ' + str(arg_params[k].shape)
    for k in sym.list_auxiliary_states():
        assert k in aux_params, k + ' not initialized'
        assert aux_params[k].shape == aux_shape_dict[k], \
            'shape inconsistent for ' + k + ' inferred ' + str(aux_shape_dict[k]) + ' provided ' + str(aux_params[k].shape)

    # create solver
    fixed_param_prefix = config.FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data]
    label_names = [k[0] for k in train_data.provide_label]
    mod = MutableModule(sym,
                        data_names=data_names,
                        label_names=label_names,
                        logger=logger,
                        context=ctx,
                        work_load_list=args.work_load_list,
                        max_data_shapes=max_data_shape,
                        max_label_shapes=max_label_shape,
                        fixed_param_prefix=fixed_param_prefix)

    # decide training params
    # metric
    rpn_eval_metric = metric.RPNAccMetric()
    rpn_cls_metric = metric.RPNLogLossMetric()
    rpn_bbox_metric = metric.RPNL1LossMetric()
    eval_metric = metric.RCNNAccMetric()
    cls_metric = metric.RCNNLogLossMetric()
    bbox_metric = metric.RCNNL1LossMetric()
    eval_metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [
            rpn_eval_metric, rpn_cls_metric, rpn_bbox_metric, eval_metric,
            cls_metric, bbox_metric
    ]:
        eval_metrics.add(child_metric)
    # callback
    batch_end_callback = mx.callback.Speedometer(train_data.batch_size,
                                                 frequent=args.frequent,
                                                 auto_reset=False)
    means = np.tile(np.array(config.TRAIN.BBOX_MEANS), config.NUM_CLASSES)
    stds = np.tile(np.array(config.TRAIN.BBOX_STDS), config.NUM_CLASSES)
    epoch_end_callback = callback.do_checkpoint(prefix, means, stds)
    # decide learning rate
    base_lr = lr
    lr_factor = 0.1
    lr_epoch = [int(epoch) for epoch in lr_step.split(',')]
    lr_epoch_diff = [
        epoch - begin_epoch for epoch in lr_epoch if epoch > begin_epoch
    ]
    lr = base_lr * (lr_factor**(len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [
        int(epoch * len(roidb) / batch_size) for epoch in lr_epoch_diff
    ]
    logger.info('lr %f lr_epoch_diff %s lr_iters %s' %
                (lr, lr_epoch_diff, lr_iters))
    lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(lr_iters, lr_factor)
    # optimizer
    optimizer_params = {
        'momentum': 0.9,
        'wd': 0.0005,
        'learning_rate': lr,
        'lr_scheduler': lr_scheduler,
        'rescale_grad': (1.0 / batch_size),
        'clip_gradient': 5
    }

    # train
    mod.fit(train_data,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            kvstore=args.kvstore,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            arg_params=arg_params,
            aux_params=aux_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch)
Example #3
def train_net(args,
              ctx,
              pretrained,
              epoch,
              prefix,
              begin_epoch,
              end_epoch,
              lr=0.001,
              lr_step='5'):
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # setup config
    config.TRAIN.BATCH_IMAGES = 1
    config.TRAIN.BATCH_ROIS = 128
    config.TRAIN.END2END = True
    config.TRAIN.BBOX_NORMALIZATION_PRECOMPUTED = True

    # load symbol

    if args.use_global_context or args.use_roi_align:
        sym = eval('get_' + args.network + '_train')(
            num_classes=config.NUM_CLASSES,
            num_anchors=config.NUM_ANCHORS,
            use_global_context=args.use_global_context,
            use_roi_align=args.use_roi_align)
    else:
        sym = eval('get_' + args.network + '_train')(
            num_classes=config.NUM_CLASSES, num_anchors=config.NUM_ANCHORS)

    feat_sym = sym.get_internals()['rpn_cls_score_output']

    # setup multi-gpu
    batch_size = len(ctx)
    input_batch_size = config.TRAIN.BATCH_IMAGES * batch_size

    # print config
    pprint.pprint(config)

    if not args.use_ava_recordio:
        # load dataset and prepare imdb for training
        image_sets = [iset for iset in args.image_set.split('+')]
        roidbs = [
            load_gt_roidb(args.dataset,
                          image_set,
                          args.root_path,
                          args.dataset_path,
                          flip=not args.no_flip) for image_set in image_sets
        ]
        roidb = merge_roidb(roidbs)
        roidb = filter_roidb(roidb)

        # load training data
        train_data = AnchorLoader(
            feat_sym,
            roidb,
            batch_size=input_batch_size,
            shuffle=not args.no_shuffle,
            ctx=ctx,
            work_load_list=args.work_load_list,
            feat_stride=config.RPN_FEAT_STRIDE,
            anchor_scales=config.ANCHOR_SCALES,
            anchor_ratios=config.ANCHOR_RATIOS,
            aspect_grouping=config.TRAIN.ASPECT_GROUPING,
            use_data_augmentation=args.use_data_augmentation)
    else:
        f = open(args.classes_names)
        classes = ['__background__']
        for line in f.readlines():
            classes.append(line.strip().split(' ')[0])

        path_imgidx = args.ava_recordio_name + '.idx'
        path_imgrec = args.ava_recordio_name + '.rec'

        record = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')  # pylint: disable=redefined-variable-type

        train_data = AnchorLoaderAvaRecordIO(
            feat_sym,
            record,
            classes,
            batch_size=input_batch_size,
            shuffle=not args.no_shuffle,
            ctx=ctx,
            work_load_list=args.work_load_list,
            feat_stride=config.RPN_FEAT_STRIDE,
            anchor_scales=config.ANCHOR_SCALES,
            anchor_ratios=config.ANCHOR_RATIOS,
            aspect_grouping=config.TRAIN.ASPECT_GROUPING,
            use_data_augmentation=args.use_data_augmentation)

    # infer max shape
    max_data_shape = [('data', (input_batch_size, 3,
                                max([v[0] for v in config.SCALES]),
                                max([v[1] for v in config.SCALES])))]
    max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape)
    max_data_shape.append(('gt_boxes', (input_batch_size, 100, 5)))
    print('providing maximum shape', max_data_shape, max_label_shape)

    # infer shape
    data_shape_dict = dict(train_data.provide_data + train_data.provide_label)
    arg_shape, out_shape, aux_shape = sym.infer_shape(**data_shape_dict)
    arg_shape_dict = dict(zip(sym.list_arguments(), arg_shape))
    out_shape_dict = dict(zip(sym.list_outputs(), out_shape))
    aux_shape_dict = dict(zip(sym.list_auxiliary_states(), aux_shape))
    print('output shape')
    pprint.pprint(out_shape_dict)
    print('arg shape')
    #  pprint.pprint(arg_shape_dict)
    # load and initialize params
    if args.resume:
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        arg_params['rpn_conv_3x3_weight'] = mx.random.normal(
            0, 0.01, shape=arg_shape_dict['rpn_conv_3x3_weight'])
        arg_params['rpn_conv_3x3_bias'] = mx.nd.zeros(
            shape=arg_shape_dict['rpn_conv_3x3_bias'])
        arg_params['rpn_cls_score_weight'] = mx.random.normal(
            0, 0.01, shape=arg_shape_dict['rpn_cls_score_weight'])
        arg_params['rpn_cls_score_bias'] = mx.nd.zeros(
            shape=arg_shape_dict['rpn_cls_score_bias'])
        arg_params['rpn_bbox_pred_weight'] = mx.random.normal(
            0, 0.01, shape=arg_shape_dict['rpn_bbox_pred_weight'])
        arg_params['rpn_bbox_pred_bias'] = mx.nd.zeros(
            shape=arg_shape_dict['rpn_bbox_pred_bias'])
        arg_params['cls_score_weight'] = mx.random.normal(
            0, 0.01, shape=arg_shape_dict['cls_score_weight'])
        arg_params['cls_score_bias'] = mx.nd.zeros(
            shape=arg_shape_dict['cls_score_bias'])
        arg_params['bbox_pred_weight'] = mx.random.normal(
            0, 0.001, shape=arg_shape_dict['bbox_pred_weight'])
        arg_params['bbox_pred_bias'] = mx.nd.zeros(
            shape=arg_shape_dict['bbox_pred_bias'])

    if args.use_global_context:
        # additional params for using global context
        """
        for arg_param_name in sym.list_arguments():
            if 'stage5' in arg_param_name:
                # print(arg_param_name, arg_param_name.replace('stage5', 'stage4'))
                arg_params[arg_param_name] = arg_params[arg_param_name.replace('stage5', 'stage4')].copy()  # params of stage5 is initialized from stage4
        arg_params['bn2_gamma'] = arg_params['bn1_gamma'].copy()
        arg_params['bn2_beta'] = arg_params['bn1_beta'].copy()
        """
        for aux_param_name in sym.list_auxiliary_states():
            if 'stage5' in aux_param_name:
                # print(aux_param_name, aux_param_name.replace('stage5', 'stage4'))
                # params of stage5 are initialized from stage4
                aux_params[aux_param_name] = aux_params[
                    aux_param_name.replace('stage5', 'stage4')].copy()
        aux_params['bn2_moving_mean'] = aux_params['bn1_moving_mean'].copy()
        aux_params['bn2_moving_var'] = aux_params['bn1_moving_var'].copy()

    # check parameter shapes
    for k in sym.list_arguments():
        if k in data_shape_dict:
            continue
        assert k in arg_params, k + ' not initialized'
        assert arg_params[k].shape == arg_shape_dict[k], \
            'shape inconsistent for ' + k + ' inferred ' + str(arg_shape_dict[k]) + ' provided ' + str(arg_params[k].shape)
    for k in sym.list_auxiliary_states():
        assert k in aux_params, k + ' not initialized'
        assert aux_params[k].shape == aux_shape_dict[k], \
            'shape inconsistent for ' + k + ' inferred ' + str(aux_shape_dict[k]) + ' provided ' + str(aux_params[k].shape)

    # create solver
    fixed_param_prefix = config.FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data]
    label_names = [k[0] for k in train_data.provide_label]
    mod = MutableModule(sym,
                        data_names=data_names,
                        label_names=label_names,
                        logger=logger,
                        context=ctx,
                        work_load_list=args.work_load_list,
                        max_data_shapes=max_data_shape,
                        max_label_shapes=max_label_shape,
                        fixed_param_prefix=fixed_param_prefix)

    # decide training params
    # metric
    rpn_eval_metric = metric.RPNAccMetric()
    rpn_cls_metric = metric.RPNLogLossMetric()
    rpn_bbox_metric = metric.RPNL1LossMetric()
    eval_metric = metric.RCNNAccMetric()
    cls_metric = metric.RCNNLogLossMetric()
    bbox_metric = metric.RCNNL1LossMetric()
    eval_metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [
            rpn_eval_metric, rpn_cls_metric, rpn_bbox_metric, eval_metric,
            cls_metric, bbox_metric
    ]:
        eval_metrics.add(child_metric)
    # callback
    batch_end_callback = callback.Speedometer(train_data.batch_size,
                                              frequent=args.frequent)
    means = np.tile(np.array(config.TRAIN.BBOX_MEANS), config.NUM_CLASSES)
    stds = np.tile(np.array(config.TRAIN.BBOX_STDS), config.NUM_CLASSES)
    epoch_end_callback = callback.do_checkpoint(prefix, means, stds)
    # decide learning rate
    base_lr = lr
    lr_factor = 0.1
    lr_epoch = [int(epoch) for epoch in lr_step.split(',')]
    lr_epoch_diff = [
        epoch - begin_epoch for epoch in lr_epoch if epoch > begin_epoch
    ]
    lr = base_lr * (lr_factor**(len(lr_epoch) - len(lr_epoch_diff)))
    if not args.use_ava_recordio:
        lr_iters = [
            int(epoch * len(roidb) / batch_size) for epoch in lr_epoch_diff
        ]
    else:
        lr_iters = [
            int(epoch * train_data.provide_size() / batch_size)
            for epoch in lr_epoch_diff
        ]
    print('lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters)
    lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(lr_iters, lr_factor)
    # optimizer
    optimizer_params = {
        'momentum': 0.9,
        'wd': 0.0005,
        'learning_rate': lr,
        'lr_scheduler': lr_scheduler,
        'rescale_grad': (1.0 / batch_size),
        'clip_gradient': 5
    }

    # train
    mod.fit(train_data,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            kvstore=args.kvstore,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            arg_params=arg_params,
            aux_params=aux_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch)
Example #4
def train_net(args, ctx, pretrained, epoch, prefix, begin_epoch, end_epoch,
              lr=0.001, lr_step='5'):
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # setup config
    config.TRAIN.BATCH_IMAGES = 1
    config.TRAIN.BATCH_ROIS = 128
    config.TRAIN.END2END = True
    config.TRAIN.BBOX_NORMALIZATION_PRECOMPUTED = True

    # load symbol
    sym_instance = eval('symbol_' + args.network)()
    sym_gen = sym_instance.get_symbol
    sym = sym_gen(46, config, is_train=True)
    feat_sym = sym.get_internals()['rpn_cls_score_output']

    # setup multi-gpu
    batch_size = len(ctx)
    input_batch_size = config.TRAIN.BATCH_IMAGES * batch_size

    # print config
    pprint.pprint(config)

    # load dataset and prepare imdb for training
    dataset = Dataset(args.root_path, args.dataset, args.subset, split=args.split)
    roidb = dataset.gt_roidb()
    W = dataset.W

    # load training data
    train_data = AnchorLoader(feat_sym, roidb, batch_size=input_batch_size, shuffle=not args.no_shuffle,
                              ctx=ctx, work_load_list=args.work_load_list,
                              feat_stride=config.RPN_FEAT_STRIDE, anchor_scales=config.ANCHOR_SCALES,
                              anchor_ratios=config.ANCHOR_RATIOS)

    # infer max shape
    max_data_shape = [('data', (input_batch_size, 3, max([v[0] for v in config.SCALES]), max([v[1] for v in config.SCALES])))]
    max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape)
    max_data_shape.append(('gt_boxes', (input_batch_size, 100, 5)))
    print('providing maximum shape', max_data_shape, max_label_shape)

    # infer shape
    # get a new symbol
    bucket_key = train_data.bucket_key
    print(train_data.provide_data)
    data_shape_dict = dict(train_data.provide_data + train_data.provide_label)
    sym_instance.infer_shape(data_shape_dict)
    #arg_shape, out_shape, aux_shape = curr_sym.infer_shape(**data_shape_dict)
    #arg_shape_dict = dict(zip(curr_sym.list_arguments(), arg_shape))
    #out_shape_dict = dict(zip(curr_sym.list_outputs(), out_shape))
    #aux_shape_dict = dict(zip(curr_sym.list_auxiliary_states(), aux_shape))
    #del arg_shape_dict['lstm_parameters']
    #print(curr_sym.list_arguments())
    #print(aux_shape_dict)

    # load and initialize params
    if args.resume:
        print("continue training from epoch {}".format(begin_epoch))
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        if config.RNN.USE_W2V:
            arg_params['embed_weight'] = mx.nd.array(W)
        else:
            arg_params['embed_weight'] = mx.random.uniform(0, 0.01, shape=arg_shape_dict['embed_weight'])
        sym_instance.init_weight(config, arg_params, aux_params)
    # no checking
    # for k in arg_shape_dict.iterkeys():
    #     if k in data_shape_dict:
    #         continue
    #     assert k in arg_params, k + ' not initialized'
    #     assert arg_params[k].shape == arg_shape_dict[k], \
    #         'shape inconsistent for ' + k + ' inferred ' + str(arg_shape_dict[k]) + ' provided ' + str(arg_params[k].shape)
    # for k in sym.list_auxiliary_states():
    #     assert k in aux_params, k + ' not initialized'
    #     assert aux_params[k].shape == aux_shape_dict[k], \
    #         'shape inconsistent for ' + k + ' inferred ' + str(aux_shape_dict[k]) + ' provided ' + str(aux_params[k].shape)

    # create solver
    fixed_param_prefix = config.FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data]
    label_names = [k[0] for k in train_data.provide_label]
    mod = MutableModule(sym_gen, config, data_names=data_names, label_names=label_names,
                        logger=logger, context=ctx, work_load_list=args.work_load_list,
                        max_data_shapes=max_data_shape, max_label_shapes=max_label_shape,
                        fixed_param_prefix=fixed_param_prefix)

    # decide training params
    # metric
    rpn_eval_metric = metric.RPNAccMetric()
    rpn_cls_metric = metric.RPNLogLossMetric()
    rpn_bbox_metric = metric.RPNL1LossMetric()
    eval_metric = metric.RCNNAccMetric()
    cls_metric = metric.RCNNLogLossMetric()
    bbox_metric = metric.RCNNL1LossMetric()
    eval_metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [rpn_eval_metric, rpn_cls_metric, rpn_bbox_metric, eval_metric, cls_metric, bbox_metric]:
        eval_metrics.add(child_metric)
    # callback
    batch_end_callback = callback.Speedometer(train_data.batch_size, frequent=args.frequent)
    means = np.tile(np.array(config.TRAIN.BBOX_MEANS), config.NUM_CLASSES)
    stds = np.tile(np.array(config.TRAIN.BBOX_STDS), config.NUM_CLASSES)
    epoch_end_callback = callback.do_checkpoint(config.ENCODER_CELL, prefix, means, stds)
    # decide learning rate
    base_lr = lr
    lr_factor = 0.1
    lr_epoch = [int(epoch) for epoch in lr_step.split(',')]
    lr_epoch_diff = [epoch - begin_epoch for epoch in lr_epoch if epoch > begin_epoch]
    lr = base_lr * (lr_factor ** (len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [int(epoch * len(roidb) / batch_size) for epoch in lr_epoch_diff]
    print('lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters)
    lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(lr_iters, lr_factor)
    # optimizer
    optimizer_params = {'momentum': 0.9,
                        'wd': 0.0005,
                        'learning_rate': lr,
                        'lr_scheduler': lr_scheduler,
                        'rescale_grad': (1.0 / batch_size),
                        'clip_gradient': 5}
    # initializer for fused RNN
    # TODO: not successfully added; try asking about it in GitHub issues.
    initializer = mx.initializer.FusedRNN(init=mx.init.Xavier(factor_type='in', magnitude=2.34),
                                          num_hidden=1024, num_layers=2, mode='lstm')
    # train
    mod.fit(train_data, eval_metric=eval_metrics, epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback, kvstore=args.kvstore,
            optimizer='sgd', optimizer_params=optimizer_params, allow_missing=True,
            initializer=mx.init.Xavier(factor_type='in', magnitude=2.34),
            arg_params=arg_params, aux_params=aux_params, begin_epoch=begin_epoch, num_epoch=end_epoch)
Example #5
def train_net(args):

    if args.rand_seed > 0:
        np.random.seed(args.rand_seed)
        mx.random.seed(args.rand_seed)
        random.seed(args.rand_seed)

    # print config
    logger.info(pprint.pformat(config))
    logger.info(pprint.pformat(args))

    # load dataset and prepare imdb for training
    image_sets = [iset for iset in args.image_set.split('+')]
    roidbs = [load_gt_roidb(args.dataset, image_set, args.dataset_path,
                            flip=args.flip)
              for image_set in image_sets]
    roidb = merge_roidb(roidbs)
    roidb = filter_roidb(roidb)

    samplepcnt = args.begin_sample

    if samplepcnt == 100:
        sroidb = roidb
    else:
        sroidb = sample_roidb(roidb, samplepcnt)  # Sample by percentage of all images
    logger.info('Sampling %d pcnt : %d training slices' % (samplepcnt, len(sroidb)))

    # Debug to see if we can concatenate ROIDB's
    #print(sroidb)
    #dir(sroidb)
    #newroidb = sroidb + roidb
    #newroidb = append_roidb(sroidb, roidb)
    #print( "--Append test: " + str(len(sroidb)) +" " + str(len(roidb)) + " = " + str(len(newroidb)) ) 

    # load symbol
    sym = eval('get_' + args.network)(is_train=True, num_classes=config.NUM_CLASSES, num_anchors=config.NUM_ANCHORS)
    feat_sym = sym.get_internals()['rpn_cls_score_output']

    # setup multi-gpu
    ctx = [mx.gpu(int(i)) for i in args.gpus.split(',')]
    batch_size = len(ctx)
    input_batch_size = config.TRAIN.SAMPLES_PER_BATCH * batch_size

    # load training data
    train_data = AnchorLoader(feat_sym, sroidb, batch_size=input_batch_size, shuffle=args.shuffle,
                              ctx=ctx, work_load_list=args.work_load_list,
                              feat_stride=config.RPN_FEAT_STRIDE, anchor_scales=config.ANCHOR_SCALES,
                              anchor_ratios=config.ANCHOR_RATIOS, aspect_grouping=config.TRAIN.ASPECT_GROUPING,
                              nThreads=default.prefetch_thread_num)

    # infer max shape
    max_data_shape = [('data', (input_batch_size*config.NUM_IMAGES_3DCE, config.NUM_SLICES, config.MAX_SIZE, config.MAX_SIZE))]
    max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape)
    max_data_shape.append(('gt_boxes', (input_batch_size*config.NUM_IMAGES_3DCE, 5, 5)))
    logger.info('providing maximum shape %s %s' % (max_data_shape, max_label_shape))

    # load and initialize and check params
    arg_params, aux_params = init_params(args, sym, train_data)

    # create solver
    fixed_param_prefix = config.FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data]
    label_names = [k[0] for k in train_data.provide_label]
    mod = MutableModule(sym, data_names=data_names, label_names=label_names,
                        logger=logger, context=ctx, work_load_list=args.work_load_list,
                        max_data_shapes=max_data_shape, max_label_shapes=max_label_shape,
                        fixed_param_prefix=fixed_param_prefix)

    # decide training params
    # metric
    # rpn_eval_metric = metric.RPNAccMetric()
    rpn_cls_metric = metric.RPNLogLossMetric()
    rpn_bbox_metric = metric.RPNL1LossMetric()
    # eval_metric = metric.RCNNAccMetric()
    cls_metric = metric.RCNNLogLossMetric()
    bbox_metric = metric.RCNNL1LossMetric()
    eval_metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [rpn_cls_metric, rpn_bbox_metric,  cls_metric, bbox_metric]:
        eval_metrics.add(child_metric)

    # callback
    batch_end_callback = callback.Speedometer(train_data.batch_size, frequent=args.frequent)
    means = np.tile(np.array(config.TRAIN.BBOX_MEANS), config.NUM_CLASSES)
    stds = np.tile(np.array(config.TRAIN.BBOX_STDS), config.NUM_CLASSES)
    epoch_end_callback = (callback.do_checkpoint(args.e2e_prefix, means, stds),
                          callback.do_validate(args.e2e_prefix))
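    # mod.fit here takes (train_data, roidb), matching the custom fit() in Example #6, which keeps the
    # last epoch-end callback's return value as cur_val to decide whether to resample the training set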

    arg_names = [x for x in sym.list_arguments() if x not in data_names+label_names]
    opt = get_optimizer(args, arg_names, len(sroidb) / input_batch_size, args.iter_size)

    # train
    default.testing = False
    mod.fit(train_data, roidb, eval_metric=eval_metrics, epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback, kvstore=args.kvstore,
            optimizer=opt, iter_size=args.iter_size,
            arg_params=arg_params, aux_params=aux_params,
            begin_epoch=args.begin_epoch, num_epoch=args.e2e_epoch)
Example #6
    def fit(
            self,
            train_data,
            ogdb,
            eval_data=None,
            eval_metric='acc',
            epoch_end_callback=None,
            batch_end_callback=None,
            kvstore='local',
            optimizer='sgd',
            optimizer_params=(('learning_rate', 0.01),),  # ,('rescale_grad', 1.0/8.0),)  # 8 gpu attempt
            eval_end_callback=None,
            iter_size=1,
            eval_batch_end_callback=None,
            initializer=Uniform(0.01),
            arg_params=None,
            aux_params=None,
            allow_missing=False,
            force_rebind=False,
            force_init=False,
            begin_epoch=0,
            num_epoch=None,
            validation_metric=None,
            monitor=None):
        """Ke's revision: add iter_size. Trains the module parameters.

        Check out the `Module Tutorial <http://mxnet.io/tutorials/basic/module.html>`_ to see
        an end-to-end use-case.

        Parameters
        ----------
        train_data : DataIter
            Train DataIter.
        eval_data : DataIter
            If not ``None``, will be used as validation set and the performance
            after each epoch will be evaluated.
        eval_metric : str or EvalMetric
            Defaults to 'accuracy'. The performance measure used to display during training.
            Other possible predefined metrics are:
            'ce' (CrossEntropy), 'f1', 'mae', 'mse', 'rmse', 'top_k_accuracy'.
        epoch_end_callback : function or list of functions
            Each callback will be called with the current `epoch`, `symbol`, `arg_params`
            and `aux_params`.
        batch_end_callback : function or list of function
            Each callback will be called with a `BatchEndParam`.
        kvstore : str or KVStore
            Defaults to 'local'.
        optimizer : str or Optimizer
            Defaults to 'sgd'.
        optimizer_params : dict
            Defaults to ``(('learning_rate', 0.01),)``. The parameters for
            the optimizer constructor.
            The default value is not a dict, just to avoid pylint warning on dangerous
            default values.
        eval_end_callback : function or list of function
            These will be called at the end of each full evaluation, with the metrics over
            the entire evaluation set.
        eval_batch_end_callback : function or list of function
            These will be called at the end of each mini-batch during evaluation.
        initializer : Initializer
            The initializer is called to initialize the module parameters when they are
            not already initialized.
        arg_params : dict
            Defaults to ``None``, if not ``None``, should be existing parameters from a trained
            model or loaded from a checkpoint (previously saved model). In this case,
            the value here will be used to initialize the module parameters, unless they
            are already initialized by the user via a call to `init_params` or `fit`.
            `arg_params` has a higher priority than `initializer`.
        aux_params : dict
            Defaults to ``None``. Similar to `arg_params`, except for auxiliary states.
        allow_missing : bool
            Defaults to ``False``. Indicates whether to allow missing parameters when `arg_params`
            and `aux_params` are not ``None``. If this is ``True``, then the missing parameters
            will be initialized via the `initializer`.
        force_rebind : bool
            Defaults to ``False``. Whether to force rebinding the executors if already bound.
        force_init : bool
            Defaults to ``False``. Indicates whether to force initialization even if the
            parameters are already initialized.
        begin_epoch : int
            Defaults to 0. Indicates the starting epoch. Usually, if resumed from a
            checkpoint saved at a previous training phase at epoch N, then this value should be
            N+1.
        num_epoch : int
            Number of epochs for training.

        Examples
        --------
        >>> # An example of using fit for training.
        >>> # Assume training dataIter and validation dataIter are ready
        >>> # Assume loading a previously checkpointed model
        >>> sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
        >>> mod.fit(train_data=train_dataiter, eval_data=val_dataiter, optimizer='sgd',
        ...     optimizer_params={'learning_rate':0.01, 'momentum': 0.9},
        ...     arg_params=arg_params, aux_params=aux_params,
        ...     eval_metric='acc', num_epoch=10, begin_epoch=3)
        """
        assert num_epoch is not None, 'please specify number of epochs'

        self.bind(data_shapes=train_data.provide_data,
                  label_shapes=train_data.provide_label,
                  for_training=True,
                  force_rebind=force_rebind,
                  grad_req='add')
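        # grad_req='add' makes the executors accumulate gradients, enabling updates every iter_size batches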
        if monitor is not None:
            self.install_monitor(monitor)
        self.init_params(initializer=initializer,
                         arg_params=arg_params,
                         aux_params=aux_params,
                         allow_missing=allow_missing,
                         force_init=force_init)
        self.init_optimizer(kvstore=kvstore,
                            optimizer=optimizer,
                            optimizer_params=optimizer_params)

        if validation_metric is None:
            validation_metric = eval_metric
        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)

        annealing_steps = 0  # number of current annealing steps in current epoch
        redo_training = 0  # Flag to redo training / resample
        val_list = []  # list of validation results per annealing step
        cur_val = 0
        target_prec = 50
        #Note: we want to identify the best cluster of images / training sets with a low percentage
        ################################################################################
        # training loop
        ################################################################################
        for epoch in range(begin_epoch, num_epoch):
            tic = time.time()
            eval_metric.reset()
            nbatch = 0
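            # if the previous epoch missed the precision target, grow the training set by sampling
            # more of the original roidb (falling back to the full set once progress stagnates)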
            if redo_training:
                annealing_steps = annealing_steps + 1
                self.logger.info('Redoing training to meet criteria = %d',
                                 annealing_steps)
                #sroidb = train_data.roidb #passthrough test

                atick = time.time()

                iterdiff = 1.0
                # Check if we've stagnated
                if len(val_list) > 2:
                    itermean = (val_list[-1] + val_list[-2] + val_list[-3]) / 3
                    iterdiff = abs(itermean - val_list[-1])
                    self.logger.info('Last 3 samples have diff of: %f',
                                     iterdiff)

                if iterdiff < 0.01:
                    self.logger.info(
                        'Reached a stagnated annealing criteria, dumping current samples'
                    )
                    # Do something drastic
                    # Lets try to instantly use the original db
                    sroidb = ogdb

                    # Try to read in another random subset
                    #sroidb = sample_roidb(ogdb, 25) # Sample with removal
                else:
                    # Continue as usual
                    # Select a new random subset
                    newroidb = sample_roidb(ogdb,
                                            15)  # Without removal, this is 10%

                    # Append old with new
                    sroidb = append_roidb(train_data.roidb, newroidb)

                # Create new training data instance by passing most of previous arguments and new random db
                train_data2 = AnchorLoader(
                    train_data.feat_sym,
                    sroidb,
                    train_data.batch_size,
                    train_data.shuffle,
                    train_data.ctx,
                    train_data.work_load_list,
                    train_data.feat_stride,
                    train_data.anchor_scales,
                    train_data.anchor_ratios,
                    train_data.aspect_grouping,
                    nThreads=default.prefetch_thread_num)

                # Overwrite old train_data with the new one
                train_data = train_data2
                data_iter = iter(train_data)

                atock = time.time()
                self.logger.info('Annealing[%d] Time cost=%.3f',
                                 annealing_steps, (atock - atick))
            else:
                data_iter = iter(train_data)
                annealing_steps = 0
                val_list = []
                #target_prec=cur_val+5
                target_prec = target_prec + 5
            end_of_batch = False
            next_data_batch = next(data_iter)

            while not end_of_batch:
                data_batch = next_data_batch
                if monitor is not None:
                    monitor.tic()
                # self.forward_backward(data_batch)
                self.forward(data_batch, is_train=True, grad_req='add')
                self.backward()
                if nbatch % iter_size == 0:  # update every iter_size batches
                    self.update()
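                    # gradients were accumulated with grad_req='add', so reset them to zero after each update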
                    for g in self._curr_module._exec_group.grad_arrays:
                        for g1 in g:
                            if g1 is not None:
                                g1[:] = 0.

                try:
                    # pre fetch next batch
                    next_data_batch = next(data_iter)
                    self.prepare(next_data_batch)
                except StopIteration:
                    end_of_batch = True

                self.update_metric(eval_metric, data_batch.label)

                if monitor is not None:
                    monitor.toc_print()

                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch,
                                                     nbatch=nbatch,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    for callback in _as_list(batch_end_callback):
                        callback(batch_end_params)
                nbatch += 1

            # one epoch of training is finished
            for name, val in eval_metric.get_name_value():
                self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
            toc = time.time()
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))
            #print('Epoch[%d] Time cost=%.3f', epoch, (toc-tic))
            # sync aux params across devices
            arg_params, aux_params = self.get_params()
            self.set_params(arg_params, aux_params)

            if epoch_end_callback is not None:
                for callback in _as_list(epoch_end_callback):
                    cur_val = callback(epoch, self.symbol, arg_params,
                                       aux_params)

            self.logger.info('Returned Validation=%f', val)
            val_list.append(val)
            #----------------------------------------
            # evaluation on validation set
            if eval_data:
                self.logger.info('Evaluating data')
                res = self.score(eval_data,
                                 validation_metric,
                                 score_end_callback=eval_end_callback,
                                 batch_end_callback=eval_batch_end_callback,
                                 epoch=epoch)
                #TODO: pull this into default
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name,
                                     val)

            #----------
            # Check epoch if it falls within the validation threshold
            if cur_val < target_prec:
                # Evaluate list of precision/validation results first
                #val_list
                print(eval_data)

                #else
                redo_training = 1
                self.logger.info('Retraining data=%f', val)
            else:
                redo_training = 0

            self.logger.info('Annealing steps=%f', annealing_steps)

            # end of 1 epoch, reset the data-iter for another epoch
            train_data.reset()
Example #7
def train_net(args,
              ctx,
              pretrained,
              epoch,
              prefix,
              begin_epoch,
              end_epoch,
              lr=0.001,
              lr_step=50000):
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # setup config
    config.TRAIN.HAS_RPN = True
    config.TRAIN.BATCH_SIZE = 1
    config.TRAIN.BATCH_IMAGES = 1
    config.TRAIN.BATCH_ROIS = 128
    config.TRAIN.END2END = True
    config.TRAIN.BBOX_NORMALIZATION_PRECOMPUTED = True
    config.TRAIN.BG_THRESH_LO = 0.0

    # load symbol
    sym = eval('get_' + args.network + '_train')()
    feat_sym = sym.get_internals()['rpn_cls_score_output']

    # setup multi-gpu
    config.TRAIN.BATCH_IMAGES *= len(ctx)
    config.TRAIN.BATCH_SIZE *= len(ctx)

    # print config
    pprint.pprint(config)

    # load dataset and prepare imdb for training
    imdb = eval(args.dataset)(args.image_set, args.root_path,
                              args.dataset_path)
    roidb = imdb.gt_roidb()
    if args.flip:
        roidb = imdb.append_flipped_images(roidb)

    # load training data
    train_data = AnchorLoader(feat_sym,
                              roidb,
                              batch_size=config.TRAIN.BATCH_SIZE,
                              shuffle=True,
                              ctx=ctx,
                              work_load_list=args.work_load_list)

    # infer max shape
    max_data_shape = [('data', (config.TRAIN.BATCH_SIZE, 3, 1000, 1000))]
    max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape)
    max_data_shape.append(('gt_boxes', (config.TRAIN.BATCH_SIZE, 100, 5)))
    print('providing maximum shape', max_data_shape, max_label_shape)

    # load pretrained
    arg_params, aux_params = load_param(pretrained, epoch, convert=True)

    # infer shape
    data_shape_dict = dict(train_data.provide_data + train_data.provide_label)
    arg_shape, out_shape, aux_shape = sym.infer_shape(**data_shape_dict)
    arg_shape_dict = dict(zip(sym.list_arguments(), arg_shape))
    out_shape_dict = dict(zip(sym.list_outputs(), out_shape))
    aux_shape_dict = dict(zip(sym.list_auxiliary_states(), aux_shape))
    print('output shape')
    pprint.pprint(out_shape_dict)

    # initialize params
    if not args.resume:
        arg_params['rpn_conv_3x3_weight'] = mx.random.normal(
            0, 0.01, shape=arg_shape_dict['rpn_conv_3x3_weight'])
        arg_params['rpn_conv_3x3_bias'] = mx.nd.zeros(
            shape=arg_shape_dict['rpn_conv_3x3_bias'])
        arg_params['rpn_cls_score_weight'] = mx.random.normal(
            0, 0.01, shape=arg_shape_dict['rpn_cls_score_weight'])
        arg_params['rpn_cls_score_bias'] = mx.nd.zeros(
            shape=arg_shape_dict['rpn_cls_score_bias'])
        arg_params['rpn_bbox_pred_weight'] = mx.random.normal(
            0, 0.01, shape=arg_shape_dict['rpn_bbox_pred_weight'])
        arg_params['rpn_bbox_pred_bias'] = mx.nd.zeros(
            shape=arg_shape_dict['rpn_bbox_pred_bias'])
        arg_params['cls_score_weight'] = mx.random.normal(
            0, 0.01, shape=arg_shape_dict['cls_score_weight'])
        arg_params['cls_score_bias'] = mx.nd.zeros(
            shape=arg_shape_dict['cls_score_bias'])
        arg_params['bbox_pred_weight'] = mx.random.normal(
            0, 0.001, shape=arg_shape_dict['bbox_pred_weight'])
        arg_params['bbox_pred_bias'] = mx.nd.zeros(
            shape=arg_shape_dict['bbox_pred_bias'])

    # check parameter shapes
    for k in sym.list_arguments():
        if k in data_shape_dict:
            continue
        assert arg_params[k].shape == arg_shape_dict[k], \
            'shape inconsistent for ' + k + ' inferred ' + str(arg_shape_dict[k]) + ' provided ' + str(arg_params[k].shape)
    for k in sym.list_auxiliary_states():
        assert aux_params[k].shape == aux_shape_dict[k], \
            'shape inconsistent for ' + k + ' inferred ' + str(aux_shape_dict[k]) + ' provided ' + str(aux_params[k].shape)

    # create solver
    fixed_param_prefix = ['conv1', 'conv2']
    data_names = [k[0] for k in train_data.provide_data]
    label_names = [k[0] for k in train_data.provide_label]
    mod = MutableModule(sym,
                        data_names=data_names,
                        label_names=label_names,
                        logger=logger,
                        context=ctx,
                        work_load_list=args.work_load_list,
                        max_data_shapes=max_data_shape,
                        max_label_shapes=max_label_shape,
                        fixed_param_prefix=fixed_param_prefix)

    # decide training params
    # metric
    rpn_eval_metric = metric.RPNAccMetric()
    rpn_cls_metric = metric.RPNLogLossMetric()
    rpn_bbox_metric = metric.RPNL1LossMetric()
    eval_metric = metric.RCNNAccMetric()
    cls_metric = metric.RCNNLogLossMetric()
    bbox_metric = metric.RCNNL1LossMetric()
    eval_metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [
            rpn_eval_metric, rpn_cls_metric, rpn_bbox_metric, eval_metric,
            cls_metric, bbox_metric
    ]:
        eval_metrics.add(child_metric)
    # callback
    batch_end_callback = callback.Speedometer(train_data.batch_size,
                                              frequent=args.frequent)
    means = np.tile(np.array(config.TRAIN.BBOX_MEANS), imdb.num_classes)
    stds = np.tile(np.array(config.TRAIN.BBOX_STDS), imdb.num_classes)
    epoch_end_callback = callback.do_checkpoint(prefix, means, stds)
    # optimizer
    optimizer_params = {
        'momentum': 0.9,
        'wd': 0.0005,
        'learning_rate': lr,
        'lr_scheduler': mx.lr_scheduler.FactorScheduler(lr_step, 0.1),
        'rescale_grad': (1.0 / config.TRAIN.BATCH_SIZE)
    }

    # train
    mod.fit(train_data,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            kvstore=args.kvstore,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            arg_params=arg_params,
            aux_params=aux_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch)