def train_rpn(cfg,
              dataset,
              image_set,
              root_path,
              dataset_path,
              frequent,
              kvstore,
              flip,
              shuffle,
              resume,
              ctx,
              pretrained,
              epoch,
              prefix,
              begin_epoch,
              end_epoch,
              train_shared,
              lr,
              lr_step,
              logger=None,
              output_path=None):
    """Train the RPN symbol of ``cfg.symbol`` with ``MutableModule.fit``.

    Steps, in order: build the RPN training symbol, load/merge/filter the
    ground-truth roidbs of every image set named in ``image_set``, wrap them
    in an ``AnchorLoader``, load the weights (resumed or pretrained), derive
    the learning-rate schedule and run SGD.

    Side effect: ``cfg.TRAIN.BATCH_IMAGES`` is overwritten in place with
    ``cfg.TRAIN.ALTERNATE.RPN_BATCH_IMAGES``.

    :param cfg: configuration object (read, and mutated as described above)
    :param dataset: forwarded to ``load_gt_roidb``
    :param image_set: one or more image-set names joined with ``'+'``
    :param root_path: forwarded to ``load_gt_roidb``
    :param dataset_path: forwarded to ``load_gt_roidb``
    :param frequent: forwarded to ``callback.Speedometer`` as ``frequent``
    :param kvstore: forwarded to ``mod.fit``
    :param flip: forwarded to ``load_gt_roidb`` as ``flip``
    :param shuffle: forwarded to ``AnchorLoader`` as ``shuffle``
    :param resume: if true, load weights from ``prefix`` at ``begin_epoch``;
        otherwise load ``pretrained`` at ``epoch`` and run ``init_weight_rpn``
    :param ctx: sequence of devices; its length is the device count
    :param pretrained: checkpoint prefix used when not resuming
    :param epoch: checkpoint epoch used when not resuming
    :param prefix: checkpoint prefix for resuming and for saving
    :param begin_epoch: first epoch passed to ``mod.fit``
    :param end_epoch: passed to ``mod.fit`` as ``num_epoch``
    :param train_shared: selects ``cfg.network.FIXED_PARAMS_SHARED`` over
        ``cfg.network.FIXED_PARAMS`` as the fixed-parameter prefixes
    :param lr: base learning rate
    :param lr_step: comma-separated epochs at which the rate is multiplied
        by ``cfg.TRAIN.lr_factor``
    :param logger: logger to use; a root logger at INFO level is created
        when this is falsy
    :param output_path: forwarded to ``load_gt_roidb`` as ``result_path``
    """
    # fall back to the root logger when the caller did not provide one
    if not logger:
        logging.basicConfig()
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)

    # the RPN stage runs with its own images-per-device setting
    cfg.TRAIN.BATCH_IMAGES = cfg.TRAIN.ALTERNATE.RPN_BATCH_IMAGES

    # build the symbol; the class is looked up by name via eval
    sym_instance = eval(cfg.symbol + '.' + cfg.symbol)()
    sym = sym_instance.get_symbol_rpn(cfg, is_train=True)
    feat_sym = sym.get_internals()['rpn_cls_score_output']

    # BATCH_IMAGES images go to each device
    num_devices = len(ctx)
    input_batch_size = cfg.TRAIN.BATCH_IMAGES * num_devices

    # dump the effective configuration to stdout and to the log
    pprint.pprint(cfg)
    logger.info('training rpn cfg:{}\n'.format(pprint.pformat(cfg)))

    # one roidb per image set, then merge and filter
    roidbs = []
    for set_name in image_set.split('+'):
        roidbs.append(
            load_gt_roidb(dataset,
                          set_name,
                          root_path,
                          dataset_path,
                          result_path=output_path,
                          flip=flip))
    roidb = filter_roidb(merge_roidb(roidbs), cfg)

    # data iterator
    train_data = AnchorLoader(feat_sym,
                              roidb,
                              cfg,
                              batch_size=input_batch_size,
                              shuffle=shuffle,
                              ctx=ctx,
                              feat_stride=cfg.network.RPN_FEAT_STRIDE,
                              anchor_scales=cfg.network.ANCHOR_SCALES,
                              anchor_ratios=cfg.network.ANCHOR_RATIOS,
                              aspect_grouping=cfg.TRAIN.ASPECT_GROUPING)

    # largest input the module has to accommodate, taken from cfg.SCALES
    largest_dim0 = max(scale[0] for scale in cfg.SCALES)
    largest_dim1 = max(scale[1] for scale in cfg.SCALES)
    max_data_shape, max_label_shape = train_data.infer_shape(
        [('data', (cfg.TRAIN.BATCH_IMAGES, 3, largest_dim0, largest_dim1))])
    print('providing maximum shape', max_data_shape, max_label_shape)

    # shapes of a single-device batch
    data_shape_dict = dict(train_data.provide_data_single +
                           train_data.provide_label_single)
    sym_instance.infer_shape(data_shape_dict)

    # weights: resume from our own checkpoint or start from the pretrained one
    if resume:
        print('continue training from ', begin_epoch)
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        sym_instance.init_weight_rpn(cfg, arg_params, aux_params)

    sym_instance.check_parameter_shapes(arg_params, aux_params,
                                        data_shape_dict)

    # module
    fixed_param_prefix = (cfg.network.FIXED_PARAMS_SHARED
                          if train_shared else cfg.network.FIXED_PARAMS)
    mod = MutableModule(
        sym,
        data_names=[item[0] for item in train_data.provide_data_single],
        label_names=[item[0] for item in train_data.provide_label_single],
        logger=logger,
        context=ctx,
        max_data_shapes=[max_data_shape] * num_devices,
        max_label_shapes=[max_label_shape] * num_devices,
        fixed_param_prefix=fixed_param_prefix)

    # metrics: accuracy, classification log-loss, box L1 loss
    child_metrics = [
        metric.RPNAccMetric(),
        metric.RPNLogLossMetric(),
        metric.RPNL1LossMetric(),
    ]
    eval_metrics = mx.metric.CompositeEvalMetric()
    for child_metric in child_metrics:
        eval_metrics.add(child_metric)

    # callbacks
    batch_end_callback = callback.Speedometer(train_data.batch_size,
                                              frequent=frequent)
    epoch_end_callback = mx.callback.module_checkpoint(
        mod, prefix, period=1, save_optimizer_states=True)

    # learning rate: apply the decay steps that lie at or before
    # begin_epoch up front, schedule the remaining ones by iteration count
    lr_factor = cfg.TRAIN.lr_factor
    step_epochs = [int(step) for step in lr_step.split(',')]
    lr_epoch_diff = [
        step - begin_epoch for step in step_epochs if step > begin_epoch
    ]
    steps_already_passed = len(step_epochs) - len(lr_epoch_diff)
    current_lr = lr * (lr_factor**steps_already_passed)
    num_samples = len(roidb)
    lr_iters = [
        int(remaining * num_samples / num_devices)
        for remaining in lr_epoch_diff
    ]
    print('lr', current_lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters',
          lr_iters)
    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor,
                                              cfg.TRAIN.warmup,
                                              cfg.TRAIN.warmup_lr,
                                              cfg.TRAIN.warmup_step)

    # optimizer settings
    optimizer_params = {
        'momentum': cfg.TRAIN.momentum,
        'wd': cfg.TRAIN.wd,
        'learning_rate': current_lr,
        'lr_scheduler': lr_scheduler,
        'rescale_grad': 1.0,
        'clip_gradient': None
    }

    # make sure batches are prefetched
    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)

    # train
    mod.fit(train_data,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            kvstore=kvstore,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            arg_params=arg_params,
            aux_params=aux_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch)
# Example #2
# 0
def generate_proposals(predictor, test_data, imdb, cfg, vis=False, thresh=0.):
    """
    Generate detections results using RPN.

    The thresholded proposals are pickled to
    ``<imdb.result_path>/rpn_data/<imdb.name>_rpn.pkl``; when ``thresh > 0``
    the unfiltered proposals are additionally written to
    ``<imdb.name>_full_rpn.pkl`` in the same folder.

    :param predictor: Predictor
    :param test_data: data iterator, must be non-shuffled
    :param imdb: image database
    :param cfg: configuration, forwarded to ``vis_all_detection``
    :param vis: controls visualization
    :param thresh: thresh for valid detections
    :return: list of detected boxes
    """
    assert vis or not test_data.shuffle
    data_names = [k[0] for k in test_data.provide_data[0]]

    if not isinstance(test_data, PrefetchingIter):
        test_data = PrefetchingIter(test_data)

    idx = 0
    t = time.time()
    imdb_boxes = list()
    original_boxes = list()
    for im_info, data_batch in test_data:
        # t1: time spent waiting for the batch
        t1 = time.time() - t
        t = time.time()

        scales = [iim_info[0, 2] for iim_info in im_info]
        scores_all, boxes_all, data_dict_all = im_proposal(
            predictor, data_batch, data_names, scales)
        # t2: time spent in the network for the whole batch
        t2 = time.time() - t
        t = time.time()
        for scores, boxes, data_dict, scale in zip(scores_all, boxes_all,
                                                   data_dict_all, scales):
            # assemble proposals: box columns followed by the score column
            dets = np.hstack((boxes, scores))
            original_boxes.append(dets)

            # filter proposals
            keep = np.where(dets[:, 4:] > thresh)[0]
            dets = dets[keep, :]
            imdb_boxes.append(dets)

            if vis:
                vis_all_detection(data_dict['data'].asnumpy(), [dets], ['obj'],
                                  scale, cfg)

            # single formatted string: same text under Python 2 and 3
            print('generating %d/%d proposal %d data %.4fs net %.4fs' %
                  (idx + 1, imdb.num_images, dets.shape[0], t1,
                   t2 / test_data.batch_size))
            idx += 1

    assert len(imdb_boxes) == imdb.num_images, 'calculations not complete'

    # save results; makedirs also creates a missing imdb.result_path
    rpn_folder = os.path.join(imdb.result_path, 'rpn_data')
    if not os.path.isdir(rpn_folder):
        os.makedirs(rpn_folder)

    rpn_file = os.path.join(rpn_folder, imdb.name + '_rpn.pkl')
    with open(rpn_file, 'wb') as f:
        cPickle.dump(imdb_boxes, f, cPickle.HIGHEST_PROTOCOL)

    if thresh > 0:
        full_rpn_file = os.path.join(rpn_folder, imdb.name + '_full_rpn.pkl')
        with open(full_rpn_file, 'wb') as f:
            cPickle.dump(original_boxes, f, cPickle.HIGHEST_PROTOCOL)

    print('wrote rpn proposals to {}'.format(rpn_file))
    return imdb_boxes
# Example #3
# 0
def pred_eval(predictor, test_data, imdb, cfg, vis=False, thresh=1e-3, logger=None, ignore_cache=False):
    """
    Predict boxes and masks for every image of ``test_data`` and evaluate
    them with ``imdb.evaluate_sds``.

    Results are cached as ``<imdb.name>_detections.pkl`` and
    ``<imdb.name>_masks.pkl`` under ``imdb.result_path``. When both files
    exist and ``ignore_cache`` is false they are loaded instead of running
    the predictor.

    :param predictor: Predictor, forwarded to ``im_detect``
    :param test_data: data iterator, must be non-shuffled unless ``vis`` is set
    :param imdb: image database
    :param cfg: configuration (``cfg.TEST.*``, ``cfg.BINARY_THRESH`` and,
        if present, ``cfg.CLASS_AGNOSTIC`` are read)
    :param vis: controls visualization
    :param thresh: score threshold for valid detections (only applied when
        ``cfg.TEST.USE_MASK_MERGE`` is false)
    :param logger: optional logger for progress and the evaluation summary
    :param ignore_cache: recompute even if cached result files exist
    :return: None
    """

    det_file = os.path.join(imdb.result_path, imdb.name + '_detections.pkl')
    seg_file = os.path.join(imdb.result_path, imdb.name + '_masks.pkl')

    if os.path.exists(det_file) and os.path.exists(seg_file) and not ignore_cache:
        with open(det_file, 'rb') as f:
            all_boxes = cPickle.load(f)
        with open(seg_file, 'rb') as f:
            all_masks = cPickle.load(f)
    else:
        assert vis or not test_data.shuffle
        data_names = [k[0] for k in test_data.provide_data[0]]

        if not isinstance(test_data, PrefetchingIter):
            test_data = PrefetchingIter(test_data)

        # function pointers
        nms = py_nms_wrapper(cfg.TEST.NMS)
        mask_voting = gpu_mask_voting if cfg.TEST.USE_GPU_MASK_MERGE else cpu_mask_voting

        # A configuration without CLASS_AGNOSTIC is treated as class-specific.
        # This replaces a raise/bare-except construct that also hid genuine
        # indexing errors.
        class_agnostic = getattr(cfg, 'CLASS_AGNOSTIC', False)

        max_per_image = 100 if cfg.TEST.USE_MASK_MERGE else -1
        num_images = imdb.num_images
        # results are indexed as all_boxes[class][image] / all_masks[class][image]
        all_boxes = [[[] for _ in range(num_images)]
                     for _ in range(imdb.num_classes)]
        all_masks = [[[] for _ in range(num_images)]
                     for _ in range(imdb.num_classes)]

        idx = 0
        t = time.time()
        for data_batch in test_data:
            # t1: time spent waiting for the batch
            t1 = time.time() - t
            t = time.time()

            scales = [device_data[1].asnumpy()[0, 2] for device_data in data_batch.data]
            scores_all, boxes_all, masks_all, data_dict_all = im_detect(predictor, data_batch, data_names, scales, cfg)
            im_shapes = [device_data[0].shape[2:4] for device_data in data_batch.data]

            # t2: time spent in the network
            t2 = time.time() - t
            t = time.time()

            # post processing
            for delta, (scores, boxes, masks, data_dict) in enumerate(zip(scores_all, boxes_all, masks_all, data_dict_all)):

                if not cfg.TEST.USE_MASK_MERGE:
                    # per-class thresholding followed by NMS; class 0 is skipped
                    for j in range(1, imdb.num_classes):
                        indexes = np.where(scores[:, j] > thresh)[0]
                        cls_scores = scores[indexes, j, np.newaxis]
                        cls_masks = masks[indexes, 1, :, :]
                        if class_agnostic:
                            cls_boxes = boxes[indexes, :]
                        else:
                            cls_boxes = boxes[indexes, j * 4:(j + 1) * 4]

                        cls_dets = np.hstack((cls_boxes, cls_scores))
                        keep = nms(cls_dets)
                        all_boxes[j][idx + delta] = cls_dets[keep, :]
                        all_masks[j][idx + delta] = cls_masks[keep, :]
                else:
                    masks = masks[:, 1:, :, :]
                    # network input size divided by the scale factor
                    im_height = np.round(im_shapes[delta][0] / scales[delta]).astype('int')
                    im_width = np.round(im_shapes[delta][1] / scales[delta]).astype('int')
                    boxes = clip_boxes(boxes, (im_height, im_width))
                    result_mask, result_box = mask_voting(masks, boxes, scores, imdb.num_classes,
                                                          max_per_image, im_width, im_height,
                                                          cfg.TEST.NMS, cfg.TEST.MASK_MERGE_THRESH,
                                                          cfg.BINARY_THRESH)
                    for j in range(1, imdb.num_classes):
                        all_boxes[j][idx + delta] = result_box[j]
                        all_masks[j][idx + delta] = result_mask[j][:, 0, :, :]

                if vis:
                    boxes_this_image = [[]] + [all_boxes[j][idx + delta] for j in range(1, imdb.num_classes)]
                    masks_this_image = [[]] + [all_masks[j][idx + delta] for j in range(1, imdb.num_classes)]
                    vis_all_mask(data_dict['data'].asnumpy(), boxes_this_image, masks_this_image, imdb.classes, scales[delta], cfg)

            idx += test_data.batch_size
            # t3: time spent in post processing
            t3 = time.time() - t
            t = time.time()

            progress = 'testing {}/{} data {:.4f}s net {:.4f}s post {:.4f}s'.format(idx, imdb.num_images, t1, t2, t3)
            print(progress)
            if logger:
                logger.info(progress)

        with open(det_file, 'wb') as f:
            cPickle.dump(all_boxes, f, protocol=cPickle.HIGHEST_PROTOCOL)
        with open(seg_file, 'wb') as f:
            cPickle.dump(all_masks, f, protocol=cPickle.HIGHEST_PROTOCOL)

    info_str = imdb.evaluate_sds(all_boxes, all_masks)
    if logger:
        logger.info('evaluate detections: \n{}'.format(info_str))
# Example #4
# 0
def train_net(args, ctx, pretrained, pretrained_base, pretrained_ec, epoch,
              prefix, begin_epoch, end_epoch, lr, lr_step):
    """Train the network described by the module-level ``config``.

    Creates the logger/output directory, builds the training symbol, loads
    and merges the ground-truth segdbs, loads weights (resumed, or merged
    from three pretrained checkpoints) and runs SGD via ``MutableModule.fit``.

    :param args: parsed arguments; ``args.cfg`` and ``args.frequent`` are read
    :param ctx: sequence of devices; its length is the device count
    :param pretrained: checkpoint prefix loaded first when not resuming
    :param pretrained_base: checkpoint prefix whose parameters are merged
        over those of ``pretrained``
    :param pretrained_ec: checkpoint prefix loaded with
        ``argprefix=config.TRAIN.arg_prefix`` and merged last
    :param epoch: epoch of the three pretrained checkpoints
    :param prefix: checkpoint prefix, placed under the run's output directory
    :param begin_epoch: first epoch passed to ``mod.fit``; also the epoch
        resumed from when ``config.TRAIN.RESUME`` is set
    :param end_epoch: passed to ``mod.fit`` as ``num_epoch``
    :param lr: base learning rate
    :param lr_step: comma-separated epochs at which the rate is multiplied
        by 0.1
    """
    logger, final_output_path = create_logger(config.output_path, args.cfg,
                                              config.dataset.image_set)
    prefix = os.path.join(final_output_path, prefix)

    # load symbol (a copy of its source file is kept with the results)
    shutil.copy2(os.path.join(curr_path, 'symbols', config.symbol + '.py'),
                 final_output_path)
    # NOTE(review): eval() on a configuration value; config.symbol must come
    # from a trusted configuration file.
    sym_instance = eval(config.symbol + '.' + config.symbol)()
    sym = sym_instance.get_train_symbol(config)

    # setup multi-gpu
    batch_size = len(ctx)
    input_batch_size = config.TRAIN.BATCH_IMAGES * batch_size

    # print config
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # load dataset and prepare imdb for training
    image_sets = config.dataset.image_set.split('+')
    segdbs = [
        load_gt_segdb(config.dataset.dataset,
                      image_set,
                      config.dataset.root_path,
                      config.dataset.dataset_path,
                      result_path=final_output_path,
                      flip=config.TRAIN.FLIP) for image_set in image_sets
    ]
    segdb = merge_segdb(segdbs)

    # load training data
    train_data = TrainDataLoader(sym,
                                 segdb,
                                 config,
                                 batch_size=input_batch_size,
                                 crop_height=config.TRAIN.CROP_HEIGHT,
                                 crop_width=config.TRAIN.CROP_WIDTH,
                                 shuffle=config.TRAIN.SHUFFLE,
                                 ctx=ctx)

    # infer max shape
    max_scale_dim0 = max(v[0] for v in config.SCALES)
    max_scale_dim1 = max(v[1] for v in config.SCALES)
    max_data_shape = [('data', (config.TRAIN.BATCH_IMAGES, 3, max_scale_dim0,
                                max_scale_dim1)),
                      ('data_ref', (config.TRAIN.KEY_INTERVAL - 1, 3,
                                    max_scale_dim0, max_scale_dim1)),
                      ('eq_flag', (1, ))]
    max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape)
    # Every print below passes ONE pre-formatted string so the output is the
    # same whether or not print is a function in this module.
    print('providing maximum shape {} {}'.format(max_data_shape,
                                                 max_label_shape))

    data_shape_dict = dict(train_data.provide_data_single +
                           train_data.provide_label_single)
    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)

    # load and initialize params
    if config.TRAIN.RESUME:
        # previously print('continue training from ', begin_epoch), which
        # printed a tuple repr next to this function's print statements
        print('continue training from {}'.format(begin_epoch))
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        print(pretrained)
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        # later checkpoints override entries of earlier ones
        arg_params_base, aux_params_base = load_param(pretrained_base,
                                                      epoch,
                                                      convert=True)
        arg_params.update(arg_params_base)
        aux_params.update(aux_params_base)
        arg_params_ec, aux_params_ec = load_param(
            pretrained_ec,
            epoch,
            convert=True,
            argprefix=config.TRAIN.arg_prefix)
        arg_params.update(arg_params_ec)
        aux_params.update(aux_params_ec)
        sym_instance.init_weight(config, arg_params, aux_params)

    # check parameter shapes
    sym_instance.check_parameter_shapes(arg_params, aux_params,
                                        data_shape_dict)

    # create solver
    fixed_param_prefix = config.network.FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data_single]
    label_names = [k[0] for k in train_data.provide_label_single]

    mod = MutableModule(
        sym,
        data_names=data_names,
        label_names=label_names,
        logger=logger,
        context=ctx,
        max_data_shapes=[max_data_shape for _ in range(batch_size)],
        max_label_shapes=[max_label_shape for _ in range(batch_size)],
        fixed_param_prefix=fixed_param_prefix)

    if config.TRAIN.RESUME:
        # optimizer state file written by module_checkpoint for begin_epoch
        mod._preload_opt_states = '%s-%04d.states' % (prefix, begin_epoch)

    # decide training params
    # metric
    fcn_loss_metric = metric.FCNLogLossMetric(config.default.frequent *
                                              batch_size)
    eval_metrics = mx.metric.CompositeEvalMetric()
    eval_metrics.add(fcn_loss_metric)

    # callback
    batch_end_callback = callback.Speedometer(train_data.batch_size,
                                              frequent=args.frequent)
    epoch_end_callback = mx.callback.module_checkpoint(
        mod, prefix, period=1, save_optimizer_states=True)

    # decide learning rate: decay steps at or before begin_epoch are applied
    # immediately, the remaining ones are converted to iteration counts
    base_lr = lr
    lr_factor = 0.1
    # loop variable is not called `epoch`, so the parameter is not clobbered
    lr_epoch = [float(step) for step in lr_step.split(',')]
    lr_epoch_diff = [
        step - begin_epoch for step in lr_epoch if step > begin_epoch
    ]
    lr = base_lr * (lr_factor**(len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [
        int(step * len(segdb) / batch_size) for step in lr_epoch_diff
    ]
    print('lr {} lr_epoch_diff {} lr_iters {}'.format(lr, lr_epoch_diff,
                                                      lr_iters))

    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor,
                                              config.TRAIN.warmup,
                                              config.TRAIN.warmup_lr,
                                              config.TRAIN.warmup_step)

    # optimizer
    optimizer_params = {
        'momentum': config.TRAIN.momentum,
        'wd': config.TRAIN.wd,
        'learning_rate': lr,
        'lr_scheduler': lr_scheduler,
        'rescale_grad': 1.0,
        'clip_gradient': None
    }

    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)

    # train
    mod.fit(train_data,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            kvstore=config.default.kvstore,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            arg_params=arg_params,
            aux_params=aux_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch)
def train_net(args, ctx, pretrained, epoch, prefix, begin_epoch, end_epoch, lr,
              lr_step):
    """Train the detection network described by the module-level ``config``.

    Creates the logger/output directory, builds the training symbol, loads,
    merges and filters the ground-truth roidbs, wraps them in an
    ``AnchorLoader``, loads weights (resumed or pretrained) and runs SGD via
    ``MutableModule.fit`` with RPN and RCNN metrics attached.

    :param args: parsed arguments; ``args.cfg`` and ``args.frequent`` are read
    :param ctx: sequence of devices; its length is the device count
    :param pretrained: checkpoint prefix used when not resuming
    :param epoch: checkpoint epoch used when not resuming
    :param prefix: checkpoint prefix, placed under the run's output directory
    :param begin_epoch: first epoch passed to ``mod.fit``; also the epoch
        resumed from when ``config.TRAIN.RESUME`` is set
    :param end_epoch: passed to ``mod.fit`` as ``num_epoch``
    :param lr: base learning rate
    :param lr_step: comma-separated epochs at which the rate is multiplied
        by ``config.TRAIN.lr_factor``
    """
    # logger plus the output directory that belongs to this run
    logger, final_output_path = create_logger(config.output_path, args.cfg,
                                              config.dataset.image_set)
    prefix = os.path.join(final_output_path, prefix)

    # keep a copy of the symbol source next to the results, then build it
    symbol_source = os.path.join(curr_path, 'symbols', config.symbol + '.py')
    shutil.copy2(symbol_source, final_output_path)
    sym_instance = eval(config.symbol + '.' + config.symbol)()
    sym = sym_instance.get_symbol(config, is_train=True)
    # internal output of the network that is handed to the anchor loader
    feat_sym = sym.get_internals()['rpn_cls_score_output']

    # multi-device training: BATCH_IMAGES images go to each device
    num_devices = len(ctx)
    input_batch_size = config.TRAIN.BATCH_IMAGES * num_devices

    # dump the effective configuration to stdout and to the log
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # image sets are joined with '+', e.g. 2007_trainval+2012_trainval;
    # load one ground-truth roidb per set, then merge and filter them
    roidbs = []
    for set_name in config.dataset.image_set.split('+'):
        roidbs.append(
            load_gt_roidb(config.dataset.dataset,
                          set_name,
                          config.dataset.root_path,
                          config.dataset.dataset_path,
                          flip=config.TRAIN.FLIP))
    roidb = filter_roidb(merge_roidb(roidbs), config)

    # data iterator, configured with the anchor scales/ratios and the
    # RPN feature stride from the network config
    train_data = AnchorLoader(feat_sym,
                              roidb,
                              config,
                              batch_size=input_batch_size,
                              shuffle=config.TRAIN.SHUFFLE,
                              ctx=ctx,
                              feat_stride=config.network.RPN_FEAT_STRIDE,
                              anchor_scales=config.network.ANCHOR_SCALES,
                              anchor_ratios=config.network.ANCHOR_RATIOS,
                              aspect_grouping=config.TRAIN.ASPECT_GROUPING)

    # largest input the module has to accommodate, taken from config.SCALES
    largest_dim0 = max(scale[0] for scale in config.SCALES)
    largest_dim1 = max(scale[1] for scale in config.SCALES)
    max_data_shape, max_label_shape = train_data.infer_shape(
        [('data', (config.TRAIN.BATCH_IMAGES, 3, largest_dim0,
                   largest_dim1))])
    max_data_shape.append(('gt_boxes', (config.TRAIN.BATCH_IMAGES, 100, 5)))
    print('providing maximum shape', max_data_shape, max_label_shape)

    # shapes of a single-device batch
    data_shape_dict = dict(train_data.provide_data_single +
                           train_data.provide_label_single)
    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)

    # weights: continue from our own checkpoint when RESUME is set,
    # otherwise start from the pretrained model and initialise new layers
    if config.TRAIN.RESUME:
        print('continue training from ', begin_epoch)
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        sym_instance.init_weight(config, arg_params, aux_params)

    # verify the loaded parameters against the inferred shapes
    sym_instance.check_parameter_shapes(arg_params, aux_params,
                                        data_shape_dict)

    # module (solver)
    mod = MutableModule(
        sym,
        data_names=[item[0] for item in train_data.provide_data_single],
        label_names=[item[0] for item in train_data.provide_label_single],
        logger=logger,
        context=ctx,
        max_data_shapes=[max_data_shape] * num_devices,
        max_label_shapes=[max_label_shape] * num_devices,
        fixed_param_prefix=config.network.FIXED_PARAMS)

    if config.TRAIN.RESUME:
        # optimizer state file written by module_checkpoint for begin_epoch
        mod._preload_opt_states = '%s-%04d.states' % (prefix, begin_epoch)

    # metrics: three for the RPN followed by three for the RCNN head
    child_metrics = [
        metric.RPNAccMetric(),
        metric.RPNLogLossMetric(),
        metric.RPNL1LossMetric(),
        metric.RCNNAccMetric(config),
        metric.RCNNLogLossMetric(config),
        metric.RCNNL1LossMetric(config),
    ]
    eval_metrics = mx.metric.CompositeEvalMetric()
    for child_metric in child_metrics:
        eval_metrics.add(child_metric)

    # batch-end callback, invoked with the configured frequency
    batch_end_callback = callback.Speedometer(train_data.batch_size,
                                              frequent=args.frequent)
    # box means/stds are tiled twice when class-agnostic, otherwise once
    # per class
    tile_count = (2 if config.CLASS_AGNOSTIC else config.dataset.NUM_CLASSES)
    means = np.tile(np.array(config.TRAIN.BBOX_MEANS), tile_count)
    stds = np.tile(np.array(config.TRAIN.BBOX_STDS), tile_count)
    # epoch-end callbacks: module checkpoint plus the project's checkpoint
    # that receives the box means/stds
    epoch_end_callback = [
        mx.callback.module_checkpoint(mod,
                                      prefix,
                                      period=1,
                                      save_optimizer_states=True),
        callback.do_checkpoint(prefix, means, stds)
    ]

    # learning rate: lr_step looks like "3,5"; decay steps at or before
    # begin_epoch are applied immediately, the rest become iteration counts
    lr_factor = config.TRAIN.lr_factor
    lr_epoch = [float(step) for step in lr_step.split(',')]
    lr_epoch_diff = [
        step - begin_epoch for step in lr_epoch if step > begin_epoch
    ]
    print('lr_epoch', lr_epoch, 'begin_epoch', begin_epoch)
    steps_already_passed = len(lr_epoch) - len(lr_epoch_diff)
    current_lr = lr * (lr_factor**steps_already_passed)
    num_samples = len(roidb)
    lr_iters = [
        int(remaining * num_samples / num_devices)
        for remaining in lr_epoch_diff
    ]
    print('lr', current_lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters',
          lr_iters)
    # multi-factor scheduler with warm-up
    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor,
                                              config.TRAIN.warmup,
                                              config.TRAIN.warmup_lr,
                                              config.TRAIN.warmup_step)

    # optimizer settings
    optimizer_params = {
        'momentum': config.TRAIN.momentum,
        'wd': config.TRAIN.wd,
        'learning_rate': current_lr,
        'lr_scheduler': lr_scheduler,
        'rescale_grad': 1.0,
        'clip_gradient': None
    }

    # make sure batches are prefetched
    if not isinstance(train_data, PrefetchingIter):
        print('!!!train_data is not PrefetchingIter!!!')
        train_data = PrefetchingIter(train_data)

    # train with SGD from begin_epoch to end_epoch
    mod.fit(train_data,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            kvstore=config.default.kvstore,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            arg_params=arg_params,
            aux_params=aux_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch)
# Example #6
# 0
def pred_eval(predictor,
              test_data,
              imdb,
              cfg,
              vis=False,
              thresh=1e-3,
              logger=None,
              ignore_cache=True):
    """
    wrapper for calculating offline validation for faster data analysis
    in this example, all threshold are set by hand

    Detection is run once per entry of ``cfg.TEST_SCALES``; each run is
    pickled to ``<imdb.name>_detections_<scale index>.pkl`` under
    ``imdb.result_path``. The per-scale results are then stacked, reduced
    with NMS or soft-NMS, optionally capped at ``cfg.TEST.max_per_image``
    detections per image, pickled to ``<imdb.name>_detections.pkl`` and
    passed to ``imdb.evaluate_detections``.

    Side effect: ``cfg.SCALES`` is overwritten for every test scale and is
    left at the last one.

    :param predictor: Predictor
    :param test_data: data iterator, must be non-shuffle
    :param imdb: image database
    :param cfg: configuration (read, and mutated as described above)
    :param vis: controls visualization
    :param thresh: valid detection threshold
    :param logger: optional logger for the evaluation summary
    :param ignore_cache: when false and the merged detection file already
        exists, it is loaded and evaluated without running the predictor
    :return:
    """

    def _evaluate(boxes):
        # run the dataset evaluation and log its summary when a logger is set
        info_str = imdb.evaluate_detections(boxes)
        if logger:
            logger.info('evaluate detections: \n{}'.format(info_str))

    def _single_scale_file(scale_index):
        # path of the pickle holding the detections of one test scale
        return os.path.join(
            imdb.result_path,
            imdb.name + '_detections_' + str(scale_index) + '.pkl')

    det_file = os.path.join(imdb.result_path, imdb.name + '_detections.pkl')
    if os.path.exists(det_file) and not ignore_cache:
        with open(det_file, 'rb') as fid:
            all_boxes = cPickle.load(fid)
        _evaluate(all_boxes)
        return

    assert vis or not test_data.shuffle
    data_names = [k[0] for k in test_data.provide_data[0]]

    if not isinstance(test_data, PrefetchingIter):
        test_data = PrefetchingIter(test_data)

    # limit detections to max_per_image over all classes
    max_per_image = cfg.TEST.max_per_image
    num_images = imdb.num_images

    # pass 1: detect at every test scale and store each result on disk
    for test_scale_index, test_scale in enumerate(cfg.TEST_SCALES):
        cfg.SCALES = [test_scale]
        test_data.reset()

        # all detections are collected into:
        #    all_boxes[cls][image] = N x 5 array of detections in
        #    (x1, y1, x2, y2, score)
        all_boxes_single_scale = [[[] for _ in range(num_images)]
                                  for _ in range(imdb.num_classes)]

        detect_at_single_scale(predictor, data_names, imdb, test_data, cfg,
                               thresh, vis, all_boxes_single_scale, logger)

        with open(_single_scale_file(test_scale_index), 'wb') as f:
            cPickle.dump(all_boxes_single_scale,
                         f,
                         protocol=cPickle.HIGHEST_PROTOCOL)

    # pass 2: stack the per-scale detections of every (class, image) pair
    all_boxes = [[[] for _ in range(num_images)]
                 for _ in range(imdb.num_classes)]

    for test_scale_index, _ in enumerate(cfg.TEST_SCALES):
        det_file_single_scale = _single_scale_file(test_scale_index)
        if not os.path.exists(det_file_single_scale):
            continue
        with open(det_file_single_scale, 'rb') as fid:
            all_boxes_single_scale = cPickle.load(fid)
        for idx_class in range(1, imdb.num_classes):
            for idx_im in range(num_images):
                scale_dets = all_boxes_single_scale[idx_class][idx_im]
                if len(all_boxes[idx_class][idx_im]) == 0:
                    all_boxes[idx_class][idx_im] = scale_dets
                else:
                    all_boxes[idx_class][idx_im] = np.vstack(
                        (all_boxes[idx_class][idx_im], scale_dets))

    # pass 3: suppress duplicates; the wrapper is built once instead of
    # once per (class, image) pair
    if cfg.TEST.USE_SOFTNMS:
        suppress = py_softnms_wrapper(cfg.TEST.SOFTNMS_THRESH,
                                      max_dets=max_per_image)
    else:
        nms = py_nms_wrapper(cfg.TEST.NMS)

        def suppress(dets):
            return dets[nms(dets), :]

    for idx_class in range(1, imdb.num_classes):
        for idx_im in range(num_images):
            all_boxes[idx_class][idx_im] = suppress(
                all_boxes[idx_class][idx_im])

    # pass 4: keep at most max_per_image detections per image over all
    # classes by thresholding at the max_per_image-th best score
    if max_per_image > 0:
        for idx_im in range(num_images):
            image_scores = np.hstack([
                all_boxes[j][idx_im][:, -1]
                for j in range(1, imdb.num_classes)
            ])
            if len(image_scores) > max_per_image:
                image_thresh = np.sort(image_scores)[-max_per_image]
                for j in range(1, imdb.num_classes):
                    keep = np.where(
                        all_boxes[j][idx_im][:, -1] >= image_thresh)[0]
                    all_boxes[j][idx_im] = all_boxes[j][idx_im][keep, :]

    with open(det_file, 'wb') as f:
        cPickle.dump(all_boxes, f, protocol=cPickle.HIGHEST_PROTOCOL)

    _evaluate(all_boxes)
# Example #7
# 0
def train_net(args, ctx, pretrained, epoch, prefix, begin_epoch, end_epoch, lr,
              lr_step):

    logger, final_output_path, _, tensorboard_path = create_env(
        config.output_path, args.cfg, config.dataset.image_set)
    prefix = os.path.join(final_output_path, prefix)

    # print config
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    print "config.symbol", config.symbol
    sym_instance = eval(config.symbol + '.' + config.symbol)()
    if config.TRAIN.use_dynamic:
        sym_gen = sym_instance.sym_gen(config, is_train=True)
    else:
        sym = sym_instance.get_symbol(config, is_train=True)

    # infer max shape
    scales = [(config.TRAIN.crop_size[0], config.TRAIN.crop_size[1])
              ] if config.TRAIN.enable_crop else config.SCALES
    label_stride = config.network.LABEL_STRIDE
    network_ratio = config.network.ratio
    if config.network.use_context:
        if config.network.use_crop_context:
            max_data_shape = [
                ('data',
                 (config.TRAIN.BATCH_IMAGES, 3, config.TRAIN.crop_size[0],
                  config.TRAIN.crop_size[1])),
                ('origin_data', (config.TRAIN.BATCH_IMAGES, 3, 736, 736)),
                ('rois', (config.TRAIN.BATCH_IMAGES, 5))
            ]
        else:
            max_data_shape = [
                ('data',
                 (config.TRAIN.BATCH_IMAGES, 3, config.TRAIN.crop_size[0],
                  config.TRAIN.crop_size[1])),
                ('origin_data', (config.TRAIN.BATCH_IMAGES, 3,
                                 int(config.SCALES[0][0] * network_ratio),
                                 int(config.SCALES[0][1] * network_ratio))),
                ('rois', (config.TRAIN.BATCH_IMAGES, 5))
            ]
    else:
        if config.TRAIN.enable_crop:
            max_data_shape = [('data', (config.TRAIN.BATCH_IMAGES, 3,
                                        config.TRAIN.crop_size[0],
                                        config.TRAIN.crop_size[1]))]
        else:
            max_data_shape = [
                ('data',
                 (config.TRAIN.BATCH_IMAGES, 3,
                  max([make_divisible(v[0], label_stride) for v in scales]),
                  max([make_divisible(v[1], label_stride) for v in scales])))
            ]

    if config.network.use_mult_label:

        if config.network.use_crop_context:
            max_label_shape = [
                ('label',
                 (config.TRAIN.BATCH_IMAGES, 1,
                  make_divisible(config.TRAIN.crop_size[0], label_stride) //
                  config.network.LABEL_STRIDE,
                  make_divisible(config.TRAIN.crop_size[1], label_stride) //
                  config.network.LABEL_STRIDE)),
                ('origin_label', (config.TRAIN.BATCH_IMAGES, 1, 736, 736))
            ]

        else:
            max_label_shape = [
                ('label',
                 (config.TRAIN.BATCH_IMAGES, 1,
                  make_divisible(config.TRAIN.crop_size[0], label_stride) //
                  config.network.LABEL_STRIDE,
                  make_divisible(config.TRAIN.crop_size[1], label_stride) //
                  config.network.LABEL_STRIDE)),
                ('origin_label', (config.TRAIN.BATCH_IMAGES, 1,
                                  int(config.SCALES[0][0] * network_ratio),
                                  int(config.SCALES[0][1] * network_ratio)))
            ]
    elif config.network.use_metric:
        scale_list = config.network.scale_list
        scale_name = ['a', 'b', 'c']
        if config.network.scale_list == [1, 2, 4]:
            scale_name = ['', '', '']

        if config.TRAIN.enable_crop:
            if config.TRAIN.use_mult_metric:
                max_label_shape = [
                    ('label', (config.TRAIN.BATCH_IMAGES, 1,
                               config.TRAIN.crop_size[0] // label_stride,
                               config.TRAIN.crop_size[1] // label_stride)),
                    ('metric_label_' + str(scale_list[0]) + scale_name[0],
                     (config.TRAIN.BATCH_IMAGES, 9, 1,
                      config.TRAIN.crop_size[0] // label_stride,
                      config.TRAIN.crop_size[1] // label_stride)),
                    ('metric_label_' + str(scale_list[1]) + scale_name[1],
                     (config.TRAIN.BATCH_IMAGES, 9, 1,
                      config.TRAIN.crop_size[0] // label_stride,
                      config.TRAIN.crop_size[1] // label_stride)),
                    ('metric_label_' + str(scale_list[2]) + scale_name[2],
                     (config.TRAIN.BATCH_IMAGES, 9, 1,
                      config.TRAIN.crop_size[0] // label_stride,
                      config.TRAIN.crop_size[1] // label_stride))
                ]

            else:
                max_label_shape = [
                    ('label', (config.TRAIN.BATCH_IMAGES, 1,
                               config.TRAIN.crop_size[0] // label_stride,
                               config.TRAIN.crop_size[1] // label_stride)),
                    ('metric_label',
                     (config.TRAIN.BATCH_IMAGES, 9, 1,
                      config.TRAIN.crop_size[0] // label_stride,
                      config.TRAIN.crop_size[1] // label_stride))
                ]
        else:
            if config.TRAIN.use_mult_metric:

                max_label_shape = [
                    ('label',
                     (config.TRAIN.BATCH_IMAGES, 1,
                      max([make_divisible(v[0], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE,
                      max([make_divisible(v[1], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE)),
                    ('metric_label_' + str(scale_list[0]) + scale_name[0],
                     (config.TRAIN.BATCH_IMAGES, 9, 1,
                      max([make_divisible(v[0], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE,
                      max([make_divisible(v[1], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE)),
                    ('metric_label_' + str(scale_list[1]) + scale_name[1],
                     (config.TRAIN.BATCH_IMAGES, 9, 1,
                      max([make_divisible(v[0], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE,
                      max([make_divisible(v[1], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE)),
                    ('metric_label_' + str(scale_list[2]) + scale_name[2],
                     (config.TRAIN.BATCH_IMAGES, 9, 1,
                      max([make_divisible(v[0], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE,
                      max([make_divisible(v[1], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE))
                ]

            else:
                max_label_shape = [
                    ('label',
                     (config.TRAIN.BATCH_IMAGES, 1,
                      max([make_divisible(v[0], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE,
                      max([make_divisible(v[1], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE)),
                    ('metric_label',
                     (config.TRAIN.BATCH_IMAGES, 9, 1,
                      max([make_divisible(v[0], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE,
                      max([make_divisible(v[1], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE))
                ]

    else:
        if config.TRAIN.enable_crop:
            max_label_shape = [('label',
                                (config.TRAIN.BATCH_IMAGES, 1,
                                 config.TRAIN.crop_size[0] // label_stride,
                                 config.TRAIN.crop_size[1] // label_stride))]
        else:
            max_label_shape = [
                ('label',
                 (config.TRAIN.BATCH_IMAGES, 1,
                  max([make_divisible(v[0], label_stride)
                       for v in scales]) // config.network.LABEL_STRIDE,
                  max([make_divisible(v[1], label_stride)
                       for v in scales]) // config.network.LABEL_STRIDE))
            ]

    print "max_label_shapes", max_label_shape
    if config.TRAIN.use_dynamic:
        sym = sym_gen([max_data_shape])

        # setup multi-gpu
    input_batch_size = config.TRAIN.BATCH_IMAGES * len(ctx)
    NUM_GPUS = len(ctx)

    # load dataset and prepare imdb for training
    image_sets = [iset for iset in config.dataset.image_set.split('+')]
    segdbs = [
        load_gt_segdb(config.dataset.dataset,
                      image_set,
                      config.dataset.root_path,
                      config.dataset.dataset_path,
                      result_path=final_output_path,
                      flip=True) for image_set in image_sets
    ]
    segdb = merge_segdb(segdbs)

    # load training data
    train_data = TrainDataLoader(sym,
                                 segdb,
                                 config,
                                 batch_size=input_batch_size,
                                 shuffle=config.TRAIN.SHUFFLE,
                                 ctx=ctx,
                                 use_context=config.network.use_context,
                                 use_mult_label=config.network.use_mult_label,
                                 use_metric=config.network.use_metric)

    # loading val data
    if config.TRAIN.eval_data_frequency > 0:
        val_image_set = config.dataset.test_image_set
        val_root_path = config.dataset.root_path
        val_dataset = config.dataset.dataset
        val_dataset_path = config.dataset.dataset_path
        val_imdb = eval(val_dataset)(val_image_set,
                                     val_root_path,
                                     val_dataset_path,
                                     result_path=final_output_path)
        val_segdb = val_imdb.gt_segdb()

        val_data = TrainDataLoader(
            sym,
            val_segdb,
            config,
            batch_size=input_batch_size,
            shuffle=config.TRAIN.SHUFFLE,
            ctx=ctx,
            use_context=config.network.use_context,
            use_mult_label=config.network.use_mult_label,
            use_metric=config.network.use_metric)
    else:
        val_data = None

    # print sym.list_arguments()
    print 'providing maximum shape', max_data_shape, max_label_shape
    max_data_shape, max_label_shape = train_data.infer_shape(
        max_data_shape, max_label_shape)
    print 'providing maximum shape', max_data_shape, max_label_shape

    # infer shape
    data_shape_dict = dict(train_data.provide_data_single +
                           train_data.provide_label_single)
    if config.TRAIN.use_dynamic:
        sym = sym_gen([train_data.provide_data_single])
    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)

    nset = set()
    for nm in sym.list_arguments():
        if nm in nset:
            raise ValueError('Duplicate names detected, %s' % str(nm))
        nset.add(nm)

    # load and initialize params
    if config.TRAIN.RESUME:
        print 'continue training from ', begin_epoch
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
        sym_instance.check_parameter_shapes(arg_params,
                                            aux_params,
                                            data_shape_dict,
                                            is_train=True)
        preload_opt_states = load_preload_opt_states(prefix, begin_epoch)
        # preload_opt_states = None
    else:
        print pretrained
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        preload_opt_states = None
        if not config.TRAIN.FINTUNE:
            fixed_param_names = sym_instance.init_weights(
                config, arg_params, aux_params)
        sym_instance.check_parameter_shapes(arg_params,
                                            aux_params,
                                            data_shape_dict,
                                            is_train=True)

    # check parameter shapes
    # sym_instance.check_parameter_shapes(arg_params, aux_params, data_shape_dict)

    # create solver
    fixed_param_prefix = config.network.FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data_single]
    label_names = [k[0] for k in train_data.provide_label_single]

    mod = MutableModule(
        sym,
        data_names=data_names,
        label_names=label_names,
        logger=logger,
        context=ctx,
        max_data_shapes=[max_data_shape for _ in xrange(NUM_GPUS)],
        max_label_shapes=[max_label_shape for _ in xrange(NUM_GPUS)],
        fixed_param_prefix=fixed_param_prefix)

    # metric
    imagecrossentropylossmetric = metric.ImageCrossEntropyLossMetric()
    localmetric = metric.LocalImageCrossEntropyLossMetric()
    globalmetric = metric.GlobalImageCrossEntropyLossMetric()
    pixcelAccMetric = metric.PixcelAccMetric()
    eval_metrics = mx.metric.CompositeEvalMetric()

    if config.network.use_mult_label:
        metric_list = [
            imagecrossentropylossmetric, localmetric, globalmetric,
            pixcelAccMetric
        ]
    elif config.network.use_metric:
        if config.TRAIN.use_crl_ses:
            metric_list = [
                imagecrossentropylossmetric,
                metric.SigmoidPixcelAccMetric(1),
                metric.SigmoidPixcelAccMetric(2),
                metric.SigmoidPixcelAccMetric(3),
                metric.CenterLossMetric(4),
                metric.CenterLossMetric(5),
                metric.CenterLossMetric(6), pixcelAccMetric
            ]

        elif config.network.use_sigmoid_metric:
            if config.TRAIN.use_mult_metric:
                metric_list = [
                    imagecrossentropylossmetric,
                    metric.SigmoidPixcelAccMetric(1),
                    metric.SigmoidPixcelAccMetric(2),
                    metric.SigmoidPixcelAccMetric(3), pixcelAccMetric
                ]
            else:
                metric_list = [
                    imagecrossentropylossmetric,
                    metric.SigmoidPixcelAccMetric(), pixcelAccMetric
                ]
        else:
            if config.TRAIN.use_mult_metric:
                metric_list = [
                    imagecrossentropylossmetric,
                    metric.MetricLossMetric(1),
                    metric.MetricLossMetric(2),
                    metric.MetricLossMetric(3), pixcelAccMetric
                ]
            else:
                metric_list = [
                    imagecrossentropylossmetric,
                    metric.MetricLossMetric(1), pixcelAccMetric
                ]
    elif config.network.mult_loss:
        metric_list = [
            imagecrossentropylossmetric,
            metric.MImageCrossEntropyLossMetric(1),
            metric.MImageCrossEntropyLossMetric(2), pixcelAccMetric
        ]
    elif config.TRAIN.use_center:
        if config.TRAIN.use_one_center:
            metric_list = [
                imagecrossentropylossmetric, pixcelAccMetric,
                metric.CenterLossMetric(1)
            ]
        else:
            metric_list = [
                imagecrossentropylossmetric, pixcelAccMetric,
                metric.CenterLossMetric(1),
                metric.CenterLossMetric(2),
                metric.CenterLossMetric(3)
            ]
    else:
        metric_list = [imagecrossentropylossmetric, pixcelAccMetric]
    # rpn_eval_metric, rpn_cls_metric, rpn_bbox_metric, eval_metric, cls_metric, bbox_metric
    for child_metric in metric_list:
        eval_metrics.add(child_metric)

    # callback
    if False:
        batch_end_callback = [
            callback.Speedometer(train_data.batch_size,
                                 frequent=args.frequent),
            callback.TensorboardCallback(tensorboard_path,
                                         prefix="train/batch")
        ]
        epoch_end_callback = mx.callback.module_checkpoint(
            mod, prefix, period=1, save_optimizer_states=True)
        shared_tensorboard = batch_end_callback[1]

        epoch_end_metric_callback = callback.TensorboardCallback(
            tensorboard_path,
            shared_tensorboard=shared_tensorboard,
            prefix="train/epoch")
        eval_end_callback = callback.TensorboardCallback(
            tensorboard_path,
            shared_tensorboard=shared_tensorboard,
            prefix="val/epoch")
        lr_callback = callback.LrCallback(
            tensorboard_path,
            shared_tensorboard=shared_tensorboard,
            prefix='train/batch')
    else:
        batch_end_callback = [
            callback.Speedometer(train_data.batch_size, frequent=args.frequent)
        ]
        epoch_end_callback = mx.callback.module_checkpoint(
            mod, prefix, period=1, save_optimizer_states=True)
        epoch_end_metric_callback = None
        eval_end_callback = None

    #decide learning rate
    base_lr = lr
    lr_factor = 0.1
    lr_epoch = [float(epoch) for epoch in lr_step.split(',')]
    lr_epoch_diff = [
        epoch - begin_epoch for epoch in lr_epoch if epoch > begin_epoch
    ]
    lr = base_lr * (lr_factor**(len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [
        int(epoch * len(segdb) / input_batch_size) for epoch in lr_epoch_diff
    ]
    print 'lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters

    if config.TRAIN.lr_type == "MultiStage":
        lr_scheduler = LinearWarmupMultiStageScheduler(
            lr_iters,
            lr_factor,
            config.TRAIN.warmup,
            config.TRAIN.warmup_lr,
            config.TRAIN.warmup_step,
            args.frequent,
            stop_lr=lr * 0.01)
    elif config.TRAIN.lr_type == "MultiFactor":
        lr_scheduler = LinearWarmupMultiFactorScheduler(
            lr_iters, lr_factor, config.TRAIN.warmup, config.TRAIN.warmup_lr,
            config.TRAIN.warmup_step, args.frequent)

    if config.TRAIN.optimizer == "sgd":
        optimizer_params = {
            'momentum': config.TRAIN.momentum,
            'wd': config.TRAIN.wd,
            'learning_rate': lr,
            'lr_scheduler': lr_scheduler,
            'rescale_grad': 1.0,
            'clip_gradient': None
        }
        optimizer = SGD(**optimizer_params)
    elif config.TRAIN.optimizer == "adam":
        optimizer_params = {
            'learning_rate': lr,
            'lr_scheduler': lr_scheduler,
            'rescale_grad': 1.0,
            'clip_gradient': None
        }
        optimizer = Adam(**optimizer_params)
        print "optimizer adam"

    freeze_layer_pattern = config.TRAIN.FIXED_PARAMS_PATTERN

    if freeze_layer_pattern.strip():
        args_lr_mult = {}
        re_prog = re.compile(freeze_layer_pattern)
        if freeze_layer_pattern:
            fixed_param_names = [
                name for name in sym.list_arguments() if re_prog.match(name)
            ]
        print "============================"
        print "fixed_params_names:"
        print(fixed_param_names)
        for name in fixed_param_names:
            args_lr_mult[name] = config.TRAIN.FIXED_PARAMS_PATTERN_LR_MULT
        print "============================"
    else:
        args_lr_mult = {}
    optimizer.set_lr_mult(args_lr_mult)

    # data_shape_dict = dict(train_data.provide_data_single + train_data.provide_label_single)
    # if config.TRAIN.use_dynamic:
    #     sym = sym_gen([train_data.provide_data_single])
    # pprint.pprint(data_shape_dict)
    # sym_instance.infer_shape(data_shape_dict)

    if not isinstance(train_data, PrefetchingIter) and config.TRAIN.use_thread:
        train_data = PrefetchingIter(train_data)

    if val_data:
        if not isinstance(val_data, PrefetchingIter):
            val_data = PrefetchingIter(val_data)

    if Debug:
        monitor = mx.monitor.Monitor(1)
    else:
        monitor = None

    # train
    mod.fit(train_data,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            kvstore=config.default.kvstore,
            eval_end_callback=eval_end_callback,
            epoch_end_metric_callback=epoch_end_metric_callback,
            optimizer=optimizer,
            eval_data=val_data,
            arg_params=arg_params,
            aux_params=aux_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            allow_missing=begin_epoch == 0,
            allow_extra=True,
            monitor=monitor,
            preload_opt_states=preload_opt_states,
            eval_data_frequency=config.TRAIN.eval_data_frequency)
# Example #8
# 0
def pred_eval(gpu_id,
              feat_predictors,
              aggr_predictors_feat_array,
              aggr_predictors_rfcn,
              test_data,
              imdb,
              cfg,
              vis=False,
              thresh=1e-3,
              logger=None,
              ignore_cache=True):
    """
    wrapper for calculating offline validation for faster data analysis
    in this example, all threshold are set by hand
    :param predictor: Predictor
    :param test_data: data iterator, must be non-shuffle
    :param imdb: image database
    :param vis: controls visualization
    :param thresh: valid detection threshold
    :return:
    """

    det_file = os.path.join(imdb.result_path, imdb.name + '_' + str(gpu_id))
    if cfg.TEST.SEQ_NMS == True:
        det_file += '_raw'
    print 'det_file=', det_file
    if os.path.exists(det_file) and not ignore_cache:
        with open(det_file, 'rb') as fid:
            all_boxes, frame_ids = cPickle.load(fid)
        return all_boxes, frame_ids

    assert vis or not test_data.shuffle
    data_names = [k[0] for k in test_data.provide_data[0]]
    num_images = test_data.size
    roidb_frame_ids = [x['frame_id'] for x in test_data.roidb]

    if not isinstance(test_data, PrefetchingIter):
        test_data = PrefetchingIter(test_data)

    nms = py_nms_wrapper(cfg.TEST.NMS)
    # limit detections to max_per_image over all classes
    max_per_image = cfg.TEST.max_per_image

    # all detections are collected into:
    #    all_boxes[cls][image] = N x 5 array of detections in
    #    (x1, y1, x2, y2, score)
    all_boxes = [[[] for _ in range(num_images)]
                 for _ in range(imdb.num_classes)]
    frame_ids = np.zeros(num_images, dtype=np.int)

    roidb_idx = -1
    roidb_offset = -1
    idx = 0
    all_frame_interval = cfg.TEST.KEY_FRAME_INTERVAL * 2 + 1

    data_time, net_time, post_time, seq_time = 0.0, 0.0, 0.0, 0.0
    t = time.time()

    # loop through all the test data
    for im_info, key_frame_flag, data_batch in test_data:
        t1 = time.time() - t
        t = time.time()

        #################################################
        # new video                                     #
        #################################################
        # empty lists and append padding images
        # do not do prediction yet
        if key_frame_flag == 0:
            roidb_idx += 1
            roidb_offset = -1
            # init data_lsit and feat_list for a new video
            data_list = deque(maxlen=all_frame_interval)
            feat_list = deque(maxlen=all_frame_interval)
            image, feat = get_resnet_output(feat_predictors, data_batch,
                                            data_names)
            # append cfg.TEST.KEY_FRAME_INTERVAL+1 padding images in the front (first frame)
            while len(data_list) < cfg.TEST.KEY_FRAME_INTERVAL + 1:
                data_list.append(image)
                preprocess(feat, idx)
                feat_list.append(feat)

            get_feature_init()

        #################################################
        # main part of the loop                         #
        #################################################
        elif key_frame_flag == 2:
            # keep appending data to the lists without doing prediction until the lists contain 2 * cfg.TEST.KEY_FRAME_INTERVAL objects
            if len(data_list) < all_frame_interval - 1:
                image, feat = get_resnet_output(feat_predictors, data_batch,
                                                data_names)
                data_list.append(image)
                preprocess(feat, idx)
                feat_list.append(feat)

            else:
                scales = [iim_info[0, 2] for iim_info in im_info]

                image, feat = get_resnet_output(feat_predictors, data_batch,
                                                data_names)
                data_list.append(image)
                preprocess(feat, idx)
                feat_list.append(feat)
                prepare_data(data_list, feat_list, data_batch)
                aggr_feat = im_detect_feat(aggr_predictors_feat_array,
                                           data_batch, data_names, scales, cfg,
                                           cfg.TEST.INTERVALS)
                if cfg.TEST.SELECT_FEATURES:
                    aggr_feat = get_feature(
                        [f[0].asnumpy() for f in aggr_feat],
                        [int(x) for x in im_info[0][0, :2]])
                    aggr_feat = [mx.nd.array(aggr_feat)]
                else:
                    aggr_feat = list(aggr_feat)[-1]
                prepare_aggregation(aggr_feat, data_batch)
                pred_result = im_detect_rfcn(aggr_predictors_rfcn, data_batch,
                                             data_names, scales, cfg)

                roidb_offset += 1
                frame_ids[idx] = roidb_frame_ids[roidb_idx] + roidb_offset

                t2 = time.time() - t
                t = time.time()
                process_pred_result(
                    pred_result, imdb, thresh, cfg, nms, all_boxes, idx,
                    max_per_image, vis,
                    data_list[cfg.TEST.KEY_FRAME_INTERVAL].asnumpy(), scales)
                idx += test_data.batch_size

                t3 = time.time() - t
                t = time.time()
                data_time += t1
                net_time += t2
                post_time += t3
                print 'testing {}/{} data {:.4f}s net {:.4f}s post {:.4f}s'.format(
                    idx, num_images, data_time / idx * test_data.batch_size,
                    net_time / idx * test_data.batch_size,
                    post_time / idx * test_data.batch_size)
                if logger:
                    logger.info(
                        'testing {}/{} data {:.4f}s net {:.4f}s post {:.4f}s'.
                        format(idx, num_images,
                               data_time / idx * test_data.batch_size,
                               net_time / idx * test_data.batch_size,
                               post_time / idx * test_data.batch_size))
        #################################################
        # end part of a video                           #
        #################################################
        elif key_frame_flag == 1:  # last frame of a video
            end_counter = 0
            image, feat = get_resnet_output(feat_predictors, data_batch,
                                            data_names)
            while end_counter < cfg.TEST.KEY_FRAME_INTERVAL + 1:
                data_list.append(image)
                preprocess(feat, idx)
                feat_list.append(feat)
                prepare_data(data_list, feat_list, data_batch)
                aggr_feat = im_detect_feat(aggr_predictors_feat_array,
                                           data_batch, data_names, scales, cfg,
                                           cfg.TEST.INTERVALS)
                if cfg.TEST.SELECT_FEATURES:
                    aggr_feat = get_feature(
                        [f[0].asnumpy() for f in aggr_feat],
                        [int(x) for x in im_info[0][0, :2]])
                    aggr_feat = [mx.nd.array(aggr_feat)]
                else:
                    aggr_feat = list(aggr_feat)[-1]
                prepare_aggregation(aggr_feat, data_batch)
                pred_result = im_detect_rfcn(aggr_predictors_rfcn, data_batch,
                                             data_names, scales, cfg)

                roidb_offset += 1
                frame_ids[idx] = roidb_frame_ids[roidb_idx] + roidb_offset

                t2 = time.time() - t
                t = time.time()
                process_pred_result(
                    pred_result, imdb, thresh, cfg, nms, all_boxes, idx,
                    max_per_image, vis,
                    data_list[cfg.TEST.KEY_FRAME_INTERVAL].asnumpy(), scales)
                idx += test_data.batch_size
                t3 = time.time() - t
                t = time.time()
                data_time += t1
                net_time += t2
                post_time += t3

                print 'testing {}/{} data {:.4f}s net {:.4f}s post {:.4f}s'.format(
                    idx, num_images, data_time / idx * test_data.batch_size,
                    net_time / idx * test_data.batch_size,
                    post_time / idx * test_data.batch_size)
                if logger:
                    logger.info(
                        'testing {}/{} data {:.4f}s net {:.4f}s post {:.4f}s'.
                        format(idx, num_images,
                               data_time / idx * test_data.batch_size,
                               net_time / idx * test_data.batch_size,
                               post_time / idx * test_data.batch_size))
                end_counter += 1

    get_feature_init()

    with open(det_file, 'wb') as f:
        cPickle.dump((all_boxes, frame_ids),
                     f,
                     protocol=cPickle.HIGHEST_PROTOCOL)

    return all_boxes, frame_ids
def train_net(args, ctx, pretrained, epoch, prefix, begin_epoch, end_epoch, lr, lr_step):
    """Train the detection network described by the module-level ``config``.

    :param args: parsed command-line arguments; ``args.cfg`` and
        ``args.frequent`` are read here
    :param ctx: list of device contexts; its length is the number of
        per-device batches that make up one input batch
    :param pretrained: checkpoint prefix of the pretrained model, loaded
        when ``config.TRAIN.RESUME`` is false
    :param epoch: epoch of the pretrained checkpoint to load
    :param prefix: checkpoint prefix; it is joined onto the output directory
        returned by ``create_logger``
    :param begin_epoch: first epoch to train; also the checkpoint epoch that
        is loaded when resuming
    :param end_epoch: forwarded to ``fit`` as ``num_epoch``
    :param lr: base learning rate
    :param lr_step: comma-separated epochs at which the learning rate is
        multiplied by ``config.TRAIN.lr_factor``
    :return: None
    """
    # Create the logger and the matching output directory.
    logger, final_output_path = create_logger(config.output_path, args.cfg, config.dataset.image_set)
    prefix = os.path.join(final_output_path, prefix)

    # load symbol
    shutil.copy2(os.path.join(curr_path, 'symbols', config.symbol + '.py'), final_output_path)
    # NOTE(security): eval() executes whatever config.symbol names; only use
    # trusted configuration files.
    sym_instance = eval(config.symbol + '.' + config.symbol)()
    sym = sym_instance.get_symbol(config, is_train=True)
    # Feature symbol: the 'rpn_cls_score_output' internal of the network.
    feat_sym = sym.get_internals()['rpn_cls_score_output']

    # setup multi-gpu: every context receives config.TRAIN.BATCH_IMAGES images
    batch_size = len(ctx)
    input_batch_size = config.TRAIN.BATCH_IMAGES * batch_size

    # print config
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # load dataset and prepare imdb for training
    # Several image sets may be joined with '+', e.g. 2007_trainval+2012_trainval.
    image_sets = config.dataset.image_set.split('+')
    # One ground-truth roidb per image set; config.TRAIN.FLIP is forwarded as
    # the `flip` option of the loader.
    roidbs = [load_gt_roidb(config.dataset.dataset, image_set, config.dataset.root_path,
                            config.dataset.dataset_path, flip=config.TRAIN.FLIP)
              for image_set in image_sets]
    # Merge the per-set roidbs, then drop entries rejected by the config rules.
    roidb = merge_roidb(roidbs)
    roidb = filter_roidb(roidb, config)

    # load training data
    # The anchor loader is configured with the RPN feature stride and the
    # anchor scales / ratios taken from the network config.
    train_data = AnchorLoader(feat_sym, roidb, config, batch_size=input_batch_size,
                              shuffle=config.TRAIN.SHUFFLE, ctx=ctx,
                              feat_stride=config.network.RPN_FEAT_STRIDE,
                              anchor_scales=config.network.ANCHOR_SCALES,
                              anchor_ratios=config.network.ANCHOR_RATIOS,
                              aspect_grouping=config.TRAIN.ASPECT_GROUPING)

    # infer max shape
    max_height = max([v[0] for v in config.SCALES])
    max_width = max([v[1] for v in config.SCALES])
    max_data_shape = [('data', (config.TRAIN.BATCH_IMAGES, 3, max_height, max_width))]
    max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape)
    max_data_shape.append(('gt_boxes', (config.TRAIN.BATCH_IMAGES, 100, 5)))
    print('providing maximum shape', max_data_shape, max_label_shape)

    data_shape_dict = dict(train_data.provide_data_single + train_data.provide_label_single)
    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)

    # load and initialize params
    if config.TRAIN.RESUME:
        print('continue training from ', begin_epoch)
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        sym_instance.init_weight(config, arg_params, aux_params)

    # check parameter shapes
    sym_instance.check_parameter_shapes(arg_params, aux_params, data_shape_dict)

    # create solver
    fixed_param_prefix = config.network.FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data_single]
    label_names = [k[0] for k in train_data.provide_label_single]

    mod = MutableModule(sym, data_names=data_names, label_names=label_names,
                        logger=logger, context=ctx,
                        max_data_shapes=[max_data_shape for _ in range(batch_size)],
                        max_label_shapes=[max_label_shape for _ in range(batch_size)],
                        fixed_param_prefix=fixed_param_prefix)

    if config.TRAIN.RESUME:
        mod._preload_opt_states = '%s-%04d.states' % (prefix, begin_epoch)

    # decide training params
    # metric: three RPN metrics followed by three RCNN metrics, in this order
    eval_metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [metric.RPNAccMetric(),
                         metric.RPNLogLossMetric(),
                         metric.RPNL1LossMetric(),
                         metric.RCNNAccMetric(config),
                         metric.RCNNLogLossMetric(config),
                         metric.RCNNL1LossMetric(config)]:
        eval_metrics.add(child_metric)

    # callback
    batch_end_callback = callback.Speedometer(train_data.batch_size, frequent=args.frequent)
    num_reg_classes = 2 if config.CLASS_AGNOSTIC else config.dataset.NUM_CLASSES
    means = np.tile(np.array(config.TRAIN.BBOX_MEANS), num_reg_classes)
    stds = np.tile(np.array(config.TRAIN.BBOX_STDS), num_reg_classes)
    epoch_end_callback = [mx.callback.module_checkpoint(mod, prefix, period=1, save_optimizer_states=True),
                          callback.do_checkpoint(prefix, means, stds)]

    # decide learning rate
    base_lr = lr
    lr_factor = config.TRAIN.lr_factor
    # The loop variable is deliberately not called `epoch`: Python 2 list
    # comprehensions leak their loop variable and would overwrite the
    # `epoch` parameter.
    lr_epoch = [float(step_epoch) for step_epoch in lr_step.split(',')]
    lr_epoch_diff = [step_epoch - begin_epoch for step_epoch in lr_epoch if step_epoch > begin_epoch]
    # Steps that lie at or before begin_epoch have already been applied.
    lr = base_lr * (lr_factor ** (len(lr_epoch) - len(lr_epoch_diff)))
    # Convert epochs to iteration counts. Each iteration consumes
    # input_batch_size images (the loader's batch size), so that is the
    # divisor; dividing by the number of contexts only would postpone every
    # step by a factor of config.TRAIN.BATCH_IMAGES.
    # NOTE(review): assumes the scheduler counts one step per training
    # batch - confirm against WarmupMultiFactorScheduler.
    lr_iters = [int(step_epoch * len(roidb) / input_batch_size) for step_epoch in lr_epoch_diff]
    print('lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters)
    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor, config.TRAIN.warmup,
                                              config.TRAIN.warmup_lr, config.TRAIN.warmup_step)
    # optimizer
    optimizer_params = {'momentum': config.TRAIN.momentum,
                        'wd': config.TRAIN.wd,
                        'learning_rate': lr,
                        'lr_scheduler': lr_scheduler,
                        'rescale_grad': 1.0,
                        'clip_gradient': None}

    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)

    # train
    mod.fit(train_data, eval_metric=eval_metrics, epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback, kvstore=config.default.kvstore,
            optimizer='sgd', optimizer_params=optimizer_params,
            arg_params=arg_params, aux_params=aux_params, begin_epoch=begin_epoch, num_epoch=end_epoch)
Example #10
0
def pred_eval_dota_rotbox_Rroi(predictor,
                               test_data,
                               imdb,
                               cfg,
                               vis=False,
                               draw=False,
                               thresh=1e-3,
                               logger=None,
                               ignore_cache=True):
    """
    Run rotated-box detection over ``test_data`` and evaluate it with ``imdb``.

    Detections are pickled to ``<imdb.result_path>/<imdb.name>_detections.pkl``.
    When that file exists and ``ignore_cache`` is false, the cached detections
    are evaluated directly and no inference is run.

    :param predictor: Predictor handed to ``im_detect_rotbox_Rroi``
    :param test_data: data iterator, must be non-shuffle unless ``vis`` is set
    :param imdb: image database; supplies ``num_images``, ``num_classes``,
        ``classes``, ``result_path``, ``name`` and ``evaluate_detections``
    :param cfg: configuration; ``cfg.TEST.max_per_image``,
        ``cfg.TEST.save_img_path`` and ``cfg.network.RRoI_CLASS_AGNOSTIC``
        are read here
    :param vis: controls visualization
    :param draw: when true, write one rendered ``.jpg`` per image into
        ``cfg.TEST.save_img_path``
    :param thresh: valid detection threshold
    :param logger: optional logger that receives progress and results
    :param ignore_cache: recompute detections even if a cache file exists;
        also forwarded to ``imdb.evaluate_detections``
    :return: None
    """
    det_file = os.path.join(imdb.result_path, imdb.name + '_detections.pkl')
    if os.path.exists(det_file) and not ignore_cache:
        # NOTE(security): unpickling runs arbitrary code; the cache file must
        # come from a trusted earlier run.
        with open(det_file, 'rb') as fid:
            all_boxes = cPickle.load(fid)
        info_str = imdb.evaluate_detections(all_boxes, ignore_cache)
        if logger:
            logger.info('evaluate detections: \n{}'.format(info_str))
        return

    assert vis or not test_data.shuffle
    data_names = [k[0] for k in test_data.provide_data[0]]

    if not isinstance(test_data, PrefetchingIter):
        test_data = PrefetchingIter(test_data)

    # limit detections to max_per_image over all classes
    max_per_image = cfg.TEST.max_per_image

    num_images = imdb.num_images
    # all detections are collected into:
    #    all_boxes[cls][image] = N x 9 array of detections in
    #    (x1, y1, x2, y2, x3, y3, x4, y4, score)
    all_boxes = [[[] for _ in range(num_images)]
                 for _ in range(imdb.num_classes)]

    progress_fmt = 'testing {}/{} data {:.4f}s net {:.4f}s post {:.4f}s'
    idx = 0
    data_time, net_time, post_time = 0.0, 0.0, 0.0
    t = time.time()
    for im_info, data_batch in test_data:
        t1 = time.time() - t
        t = time.time()

        scales = [iim_info[0, 2] for iim_info in im_info]
        scores_all, boxes_all, data_dict_all = im_detect_rotbox_Rroi(
            predictor, data_batch, data_names, scales, cfg)
        t2 = time.time() - t
        t = time.time()
        for delta, (scores, boxes, data_dict) in enumerate(
                zip(scores_all, boxes_all, data_dict_all)):
            # Position of this image in the whole test set.
            image_idx = idx + delta

            for j in range(1, imdb.num_classes):
                indexes = np.where(scores[:, j] > thresh)[0]
                cls_scores = scores[indexes, j, np.newaxis]
                if cfg.network.RRoI_CLASS_AGNOSTIC:
                    # class-agnostic regression: columns 8..15 are used for
                    # every class
                    cls_boxes = boxes[indexes, 8:16]
                else:
                    cls_boxes = boxes[indexes, j * 8:(j + 1) * 8]
                cls_quadrangle_dets = np.hstack((cls_boxes, cls_scores))
                # TODO: check the thresh
                keep = py_cpu_nms_poly(cls_quadrangle_dets, 0.3)
                all_boxes[j][image_idx] = cls_quadrangle_dets[keep, :]

            if max_per_image > 0:
                image_scores = np.hstack([
                    all_boxes[j][image_idx][:, -1]
                    for j in range(1, imdb.num_classes)
                ])
                if len(image_scores) > max_per_image:
                    # Keep only detections scoring at least as high as the
                    # max_per_image-th best score of this image.
                    image_thresh = np.sort(image_scores)[-max_per_image]
                    for j in range(1, imdb.num_classes):
                        dets = all_boxes[j][image_idx]
                        keep = np.where(dets[:, -1] >= image_thresh)[0]
                        all_boxes[j][image_idx] = dets[keep, :]

            if vis or draw:
                # Index 0 (background) gets an empty placeholder.
                boxes_this_image = [[]] + [
                    all_boxes[j][image_idx]
                    for j in range(1, imdb.num_classes)
                ]

            if vis:
                vis_all_detection(data_dict['data'].asnumpy(),
                                  boxes_this_image, imdb.classes,
                                  scales[delta], cfg)

            if draw:
                if not os.path.isdir(cfg.TEST.save_img_path):
                    os.makedirs(cfg.TEST.save_img_path)
                # Name the file after the image index, not the batch start,
                # so images of one batch do not overwrite each other.
                path = os.path.join(cfg.TEST.save_img_path,
                                    str(image_idx) + '.jpg')
                im = draw_all_poly_detection(data_dict['data'].asnumpy(),
                                             boxes_this_image,
                                             imdb.classes,
                                             scales[delta],
                                             cfg,
                                             threshold=0.2)
                print(path)
                cv2.imwrite(path, im)

        idx += test_data.batch_size
        t3 = time.time() - t
        t = time.time()
        data_time += t1
        net_time += t2
        post_time += t3
        # Running averages, expressed per batch.
        progress = progress_fmt.format(
            idx, imdb.num_images, data_time / idx * test_data.batch_size,
            net_time / idx * test_data.batch_size,
            post_time / idx * test_data.batch_size)
        print(progress)
        if logger:
            logger.info(progress)

    with open(det_file, 'wb') as f:
        cPickle.dump(all_boxes, f, protocol=cPickle.HIGHEST_PROTOCOL)

    info_str = imdb.evaluate_detections(all_boxes, ignore_cache)
    if logger:
        logger.info('evaluate detections: \n{}'.format(info_str))
Example #11
0
def _append_raw_scores(file_name, frame_id, pred_result, num_classes):
    """Append one text row per (class, proposal) of ``pred_result`` to ``file_name``.

    Every row holds ``frame_id``, the class index, the class score and
    columns 4..7 of ``boxes`` for that proposal. No threshold is applied.
    """
    row_fmt = '{:d} {:d} {:f} {:.2f} {:.2f} {:.2f} {:.2f}\n'
    with open(file_name, 'a+') as f:
        for scores, boxes, _ in pred_result:
            for j in range(1, num_classes):
                for m in range(len(scores[:, j])):
                    f.write(
                        row_fmt.format(frame_id, j, scores[m, j],
                                       boxes[m, 4], boxes[m, 5],
                                       boxes[m, 6], boxes[m, 7]))


def _report_test_progress(idx, num_images, batch_size, data_time, net_time,
                          post_time, logger):
    """Print, and log when a logger is given, the running per-batch timings."""
    progress = 'testing {}/{} data {:.4f}s net {:.4f}s post {:.4f}s'.format(
        idx, num_images, data_time / idx * batch_size,
        net_time / idx * batch_size, post_time / idx * batch_size)
    print(progress)
    if logger:
        logger.info(progress)


def pred_eval(gpu_id,
              feat_predictors,
              aggr_predictors,
              test_data,
              imdb,
              cfg,
              vis=False,
              thresh=1e-3,
              logger=None,
              ignore_cache=True):
    """
    Run video detection over ``test_data`` and collect per-frame detections.

    The results are pickled to ``<imdb.result_path>/<imdb.name>_<gpu_id>``
    (with a ``_raw`` suffix when ``cfg.TEST.SEQ_NMS`` is set); when that file
    exists and ``ignore_cache`` is false it is loaded instead of running
    inference. Raw per-class scores of every predicted frame are also
    appended to ``imdb.get_result_file_template().format('bind')``.

    :param gpu_id: identifier appended to the detection file name
    :param feat_predictors: predictor handed to ``get_resnet_output``
    :param aggr_predictors: predictor handed to ``im_detect``
    :param test_data: data iterator, must be non-shuffle unless ``vis`` is
        set; yields ``(im_info, key_frame_flag, data_batch)`` where the flag
        is 0 for the first frame of a video, 1 for the last and 2 otherwise
    :param imdb: image database
    :param cfg: configuration
    :param vis: controls visualization
    :param thresh: valid detection threshold
    :param logger: optional logger
    :param ignore_cache: recompute even if a detection file already exists
    :return: ``(all_boxes, frame_ids)``
    """

    det_file = os.path.join(imdb.result_path, imdb.name + '_' + str(gpu_id))
    if cfg.TEST.SEQ_NMS == True:
        det_file += '_raw'
    print('det_file= ' + det_file)
    if os.path.exists(det_file) and not ignore_cache:
        # NOTE(security): unpickling runs arbitrary code; the cache file must
        # come from a trusted earlier run.
        with open(det_file, 'rb') as fid:
            all_boxes, frame_ids = cPickle.load(fid)
        return all_boxes, frame_ids

    assert vis or not test_data.shuffle
    # data, img_info, data_cache, feat_cache
    data_names = [k[0] for k in test_data.provide_data[0]]

    num_images = test_data.size
    roidb_frame_ids = [x['frame_id'] for x in test_data.roidb]
    print('roidb_frame_ids:{}'.format(roidb_frame_ids))

    if not isinstance(test_data, PrefetchingIter):
        test_data = PrefetchingIter(test_data)

    nms = py_nms_wrapper(cfg.TEST.NMS)
    # limit detections to max_per_image over all classes
    max_per_image = cfg.TEST.max_per_image

    # all detections are collected into:
    #    all_boxes[cls][image] = N x 5 array of detections in
    #    (x1, y1, x2, y2, score)
    all_boxes = [[[] for _ in range(num_images)]
                 for _ in range(imdb.num_classes)]
    # The builtin int replaces the np.int alias, which newer NumPy removed.
    frame_ids = np.zeros(num_images, dtype=int)

    roidb_idx = -1
    roidb_offset = -1
    idx = 0

    # 'off' mode detects on every frame in isolation; every other mode keeps
    # a sliding window of frames around the one being detected.
    e2e_off = cfg.TRAIN.E2E_NAME == 'off'
    if e2e_off:
        all_frame_interval = 2
        vis_index = 0  # the window holds just the current frame
    else:
        all_frame_interval = cfg.TEST.KEY_FRAME_INTERVAL * 2 + 1
        vis_index = cfg.TEST.KEY_FRAME_INTERVAL

    data_time, net_time, post_time = 0.0, 0.0, 0.0
    t = time.time()

    # loop through all the test data
    for im_info, key_frame_flag, data_batch in test_data:
        print('key_frame_flag:{}'.format(key_frame_flag))
        t1 = time.time() - t
        t = time.time()

        if e2e_off:
            if key_frame_flag == 0:
                # first frame of a new video
                roidb_idx += 1
                roidb_offset = -1
            elif key_frame_flag == 2:
                pass
            else:
                # other flags (e.g. the last frame, flag 1) are not
                # processed in this mode
                continue

            # A fresh one-frame window for every frame.
            data_list = deque(maxlen=all_frame_interval)
            feat_list = deque(maxlen=all_frame_interval)
            scales = [iim_info[0, 2] for iim_info in im_info]

            image, feat = get_resnet_output(feat_predictors, data_batch,
                                            data_names, cfg)
            data_list.append(image)
            feat_list.append(feat)
            prepare_data(data_list, feat_list, data_batch, logger)
            pred_result = im_detect(aggr_predictors, data_batch, data_names,
                                    scales, cfg)
        else:
            if key_frame_flag == 0:
                #################################################
                # new video                                     #
                #################################################
                # reset the window and fill in padding copies of the first
                # frame; do not do prediction yet
                roidb_idx += 1
                roidb_offset = -1
                data_list = deque(maxlen=all_frame_interval)
                feat_list = deque(maxlen=all_frame_interval)
                image, feat = get_resnet_output(feat_predictors, data_batch,
                                                data_names, cfg)
                if cfg.TRAIN.E2E_NAME == 'base':
                    key_frame_interval = -1  # no padding in 'base' mode
                else:
                    key_frame_interval = cfg.TEST.KEY_FRAME_INTERVAL
                if logger:
                    logger.info(
                        'key_frame_interval:{}\n'.format(key_frame_interval))

                while len(data_list) < key_frame_interval + 1:
                    print("RR, no append___________________________________-")
                    data_list.append(image)
                    feat_list.append(feat)
                continue
            elif key_frame_flag == 2:
                #################################################
                # main part of the loop                         #
                #################################################
                # keep appending without predicting until the window is one
                # frame short of full
                if len(data_list) < all_frame_interval - 1:
                    image, feat = get_resnet_output(feat_predictors,
                                                    data_batch, data_names,
                                                    cfg)
                    data_list.append(image)
                    feat_list.append(feat)
                    continue
                shape_fmt = 'pred_eval,flag=2:{}'
            elif key_frame_flag == 1:
                #################################################
                # end part of a video                           #
                #################################################
                # NOTE(review): a prediction is made only while the window is
                # exactly one frame short of full; once it has filled up, the
                # last frame is skipped. Kept as written - confirm intent.
                if len(data_list) != all_frame_interval - 1:
                    continue
                shape_fmt = 'pred_eval,flag=1:{}'
            else:
                continue

            # Scales always come from the current batch, so the last-frame
            # case no longer depends on a value left over from an earlier
            # iteration (or on a name that was never bound).
            scales = [iim_info[0, 2] for iim_info in im_info]
            image, feat = get_resnet_output(feat_predictors, data_batch,
                                            data_names, cfg)
            data_list.append(image)
            feat_list.append(feat)
            prepare_data(data_list, feat_list, data_batch, logger)
            shape_msg = shape_fmt.format(data_batch.data[0][-2].shape)
            print(shape_msg)
            if logger:
                logger.info(shape_msg)
            pred_result = im_detect(aggr_predictors, data_batch, data_names,
                                    scales, cfg)

        # ---- shared tail: bookkeeping for a frame that was predicted ----
        roidb_offset += 1
        frame_ids[idx] = roidb_frame_ids[roidb_idx] + roidb_offset
        if e2e_off:
            print('roidb_idx:{}'.format(roidb_idx))
            print('roidb_frames_ids[roidb_idx]:{}'.format(
                roidb_frame_ids[roidb_idx]))
            print('roidb_offset:{}'.format(roidb_offset))
            print('idx:{}'.format(idx))
            print('frame_ids[idx]:{}'.format(frame_ids[idx]))

        t2 = time.time() - t
        t = time.time()
        _append_raw_scores(imdb.get_result_file_template().format('bind'),
                           frame_ids[idx], pred_result, imdb.num_classes)
        process_pred_result(pred_result, imdb, thresh, cfg, nms, all_boxes,
                            idx, max_per_image, vis,
                            data_list[vis_index].asnumpy(), scales)
        idx += test_data.batch_size
        t3 = time.time() - t
        t = time.time()
        data_time += t1
        net_time += t2
        post_time += t3

        _report_test_progress(idx, num_images, test_data.batch_size,
                              data_time, net_time, post_time, logger)

    with open(det_file, 'wb') as f:
        cPickle.dump((all_boxes, frame_ids),
                     f,
                     protocol=cPickle.HIGHEST_PROTOCOL)

    return all_boxes, frame_ids
Example #12
0
def train_rcnn(cfg, dataset, image_set, root_path, dataset_path,
               frequent, kvstore, flip, shuffle, resume,
               ctx, pretrained, epoch, prefix, begin_epoch, end_epoch,
               train_shared, lr, lr_step, proposal, logger=None, output_path=None):
    """Train the R-CNN symbol on proposal roidbs with SGD.

    The symbol is built with ``get_symbol_rcnn``, one proposal roidb is loaded
    per entry of ``image_set``, the merged/filtered roidb is wrapped in a
    ``ROIIter`` and a ``MutableModule`` is fitted on it.

    Args:
        cfg: configuration object; ``cfg.symbol``, ``cfg.TRAIN``, ``cfg.network``,
            ``cfg.dataset``, ``cfg.SCALES`` and ``cfg.CLASS_AGNOSTIC`` are read.
        dataset, root_path, dataset_path: forwarded to ``load_proposal_roidb``.
        image_set: one or more image-set names joined with ``'+'``.
        frequent: forwarded to ``callback.Speedometer`` as ``frequent``.
        kvstore: forwarded to ``mod.fit``.
        flip: forwarded to ``load_proposal_roidb``.
        shuffle: forwarded to ``ROIIter``.
        resume: when truthy, parameters are loaded from ``(prefix, begin_epoch)``
            instead of ``(pretrained, epoch)``.
        ctx: list of device contexts; its length is used as the device count.
        pretrained, epoch: checkpoint prefix/epoch used when not resuming.
        prefix: checkpoint prefix used for saving and for resuming.
        begin_epoch, end_epoch: passed to ``mod.fit`` as ``begin_epoch`` /
            ``num_epoch``.
        train_shared: selects ``FIXED_PARAMS_SHARED`` over ``FIXED_PARAMS``.
        lr: base learning rate.
        lr_step: comma-separated epochs at which the rate is multiplied by
            ``cfg.TRAIN.lr_factor``.
        proposal: forwarded to ``load_proposal_roidb``.
        logger: optional logger; the root logger at INFO level is used if falsy.
        output_path: forwarded to ``load_proposal_roidb`` as ``result_path``.
    """
    mx.random.seed(0)
    np.random.seed(0)

    # set up logger
    if not logger:
        logging.basicConfig()
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)

    # load symbol
    # NOTE(review): eval() resolves the class from a config-supplied name;
    # only safe while the config file is trusted.
    sym_instance = eval(cfg.symbol + '.' + cfg.symbol)()
    sym = sym_instance.get_symbol_rcnn(cfg, is_train=True)

    # setup multi-gpu: one slice of BATCH_IMAGES images per device
    batch_size = len(ctx)
    input_batch_size = cfg.TRAIN.BATCH_IMAGES * batch_size

    # print cfg
    pprint.pprint(cfg)
    logger.info('training rcnn cfg:{}\n'.format(pprint.pformat(cfg)))

    # load dataset and prepare imdb for training
    # (loop variable deliberately differs from the `image_set` parameter:
    # Python 2 list comprehensions leak their variable into this scope)
    rpn_path = cfg.dataset.proposal_cache
    roidbs = [load_proposal_roidb(dataset, iset, root_path, dataset_path,
                                  proposal=proposal, append_gt=True, flip=flip, result_path=output_path,
                                  rpn_path=rpn_path, top_roi=cfg.TRAIN.TOP_ROIS)
              for iset in image_set.split('+')]
    roidb = merge_roidb(roidbs)
    roidb = filter_roidb(roidb, cfg)
    means, stds = add_bbox_regression_targets(roidb, cfg)

    # load training data
    train_data = ROIIter(roidb, cfg, batch_size=input_batch_size, shuffle=shuffle,
                         ctx=ctx, aspect_grouping=cfg.TRAIN.ASPECT_GROUPING)

    # infer max shape: largest configured scale, bumped to the next multiple of
    # IMAGE_STRIDE (an already aligned size still gains one full stride)
    image_stride = cfg.network.IMAGE_STRIDE
    max_height = max([v[0] for v in cfg.SCALES])
    max_width = max([v[1] for v in cfg.SCALES])
    paded_max_height = max_height + image_stride - max_height % image_stride
    paded_max_width = max_width + image_stride - max_width % image_stride

    max_data_shape = [('data', (cfg.TRAIN.BATCH_IMAGES, 3, paded_max_height, paded_max_width))]
    # infer shape
    data_shape_dict = dict(train_data.provide_data_single + train_data.provide_label_single)
    sym_instance.infer_shape(data_shape_dict)
    # print shape (through the caller's logger, not the logging module root)
    pprint.pprint(sym_instance.arg_shape_dict)
    logger.info(pprint.pformat(sym_instance.arg_shape_dict))

    max_batch_roi = cfg.TRAIN.TOP_ROIS if cfg.TRAIN.BATCH_ROIS == -1 else cfg.TRAIN.BATCH_ROIS
    num_class = 2 if cfg.CLASS_AGNOSTIC else cfg.dataset.NUM_CLASSES
    max_label_shape = [('label', (cfg.TRAIN.BATCH_IMAGES, max_batch_roi)),
                       ('bbox_target', (cfg.TRAIN.BATCH_IMAGES, max_batch_roi, num_class * 4)),
                       ('bbox_weight', (cfg.TRAIN.BATCH_IMAGES, max_batch_roi, num_class * 4))]

    if cfg.network.USE_NONGT_INDEX:
        max_label_shape.append(('nongt_index', (2000,)))

    if cfg.network.ROIDispatch:
        # floor division keeps the shape entry an integer under true division
        rois_per_input = max_batch_roi // 4
        for roi_idx in range(4):
            max_data_shape.append(
                ('rois_%d' % roi_idx, (cfg.TRAIN.BATCH_IMAGES, rois_per_input, 5)))
    else:
        max_data_shape.append(('rois', (cfg.TEST.PROPOSAL_POST_NMS_TOP_N + 30, 5)))

    # load and initialize params
    if resume:
        print('continue training from ', begin_epoch)
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        sym_instance.init_weight_rcnn(cfg, arg_params, aux_params)

    # check parameter shapes
    sym_instance.check_parameter_shapes(arg_params, aux_params, data_shape_dict)

    # prepare training
    # create solver
    data_names = [k[0] for k in train_data.provide_data_single]
    label_names = [k[0] for k in train_data.provide_label_single]
    if train_shared:
        fixed_param_prefix = cfg.network.FIXED_PARAMS_SHARED
    else:
        fixed_param_prefix = cfg.network.FIXED_PARAMS

    # the module is built the same way whether or not ROIDispatch is enabled;
    # only the max shapes assembled above differ
    mod = MutableModule(sym, data_names=data_names, label_names=label_names,
                        logger=logger, context=ctx,
                        max_data_shapes=[max_data_shape for _ in range(batch_size)],
                        max_label_shapes=[max_label_shape for _ in range(batch_size)],
                        fixed_param_prefix=fixed_param_prefix)
    # NOTE(review): optimizer-state preload is keyed on cfg.TRAIN.RESUME while
    # parameter loading above is keyed on the `resume` argument — confirm the
    # two are always set together.
    if cfg.TRAIN.RESUME:
        mod._preload_opt_states = '%s-%04d.states' % (prefix, begin_epoch)

    # decide training params
    # metric
    eval_metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [metric.RCNNAccMetric(cfg),
                         metric.RCNNLogLossMetric(cfg),
                         metric.RCNNL1LossMetric(cfg)]:
        eval_metrics.add(child_metric)
    if cfg.TRAIN.LEARN_NMS:
        eval_metrics.add(metric.NMSLossMetric(cfg, 'pos'))
        eval_metrics.add(metric.NMSLossMetric(cfg, 'neg'))
        eval_metrics.add(metric.NMSAccMetric(cfg))
    # callback
    batch_end_callback = callback.Speedometer(train_data.batch_size, frequent=frequent)
    epoch_end_callback = [mx.callback.module_checkpoint(mod, prefix, period=1, save_optimizer_states=True),
                          callback.do_checkpoint(prefix, means, stds)]
    # decide learning rate: steps already passed at begin_epoch are folded into
    # the starting rate, the remaining ones are converted to iteration counts
    base_lr = lr
    lr_factor = cfg.TRAIN.lr_factor
    lr_epoch = [float(step) for step in lr_step.split(',')]
    lr_epoch_diff = [step - begin_epoch for step in lr_epoch if step > begin_epoch]
    lr = base_lr * (lr_factor ** (len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [int(step * len(roidb) / batch_size) for step in lr_epoch_diff]
    print('lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters)
    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor, cfg.TRAIN.warmup, cfg.TRAIN.warmup_lr,
                                              cfg.TRAIN.warmup_step)
    # optimizer
    optimizer_params = {'momentum': cfg.TRAIN.momentum,
                        'wd': cfg.TRAIN.wd,
                        'learning_rate': lr,
                        'lr_scheduler': lr_scheduler,
                        'rescale_grad': 1.0,
                        'clip_gradient': None}

    # train
    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)

    mod.fit(train_data, eval_metric=eval_metrics, epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback, kvstore=kvstore,
            optimizer='sgd', optimizer_params=optimizer_params,
            arg_params=arg_params, aux_params=aux_params, begin_epoch=begin_epoch, num_epoch=end_epoch)
Example #13
0
def train_net(args, ctx, pretrained, epoch, prefix, begin_epoch, end_epoch, lr,
              lr_step):
    """Train the network returned by ``get_retina_symbol`` with Adam.

    Reads the module-level ``config`` and ``curr_path``. Anchor targets are
    generated for the four ``box_pred/p4..p7`` outputs of the symbol.

    Args:
        args: parsed command-line arguments; ``args.cfg`` and ``args.frequent``
            are read.
        ctx: list of device contexts; its length is used as the device count.
        pretrained, epoch: checkpoint prefix/epoch used when
            ``config.TRAIN.RESUME`` is false.
        prefix: checkpoint name, joined onto the logger's output directory.
        begin_epoch, end_epoch: passed to ``mod.fit`` as ``begin_epoch`` /
            ``num_epoch``.
        lr: base learning rate.
        lr_step: comma-separated epochs at which the rate is multiplied by
            ``config.TRAIN.lr_factor``.
    """
    logger, final_output_path = create_logger(config.output_path, args.cfg,
                                              config.dataset.image_set)
    prefix = os.path.join(final_output_path, prefix)

    # load symbol; a copy of the symbol source is kept next to the checkpoints
    shutil.copy2(os.path.join(curr_path, 'symbols', config.symbol + '.py'),
                 final_output_path)
    # NOTE(review): eval() resolves the class from a config-supplied name;
    # only safe while the config file is trusted.
    sym_instance = eval(config.symbol + '.' + config.symbol)()
    sym = sym_instance.get_retina_symbol(config, is_train=True)

    # per-level outputs and anchor settings, ordered p4, p5, p6, p7
    internals = sym.get_internals()
    feat_sym = [
        internals['box_pred/p4_output'],
        internals['box_pred/p5_output'],
        internals['box_pred/p6_output'],
        internals['box_pred/p7_output'],
    ]
    net_cfg = config.network
    feat_stride = [
        net_cfg.p4_RPN_FEAT_STRIDE,
        net_cfg.p5_RPN_FEAT_STRIDE,
        net_cfg.p6_RPN_FEAT_STRIDE,
        net_cfg.p7_RPN_FEAT_STRIDE,
    ]
    anchor_scales = [
        net_cfg.p4_ANCHOR_SCALES,
        net_cfg.p5_ANCHOR_SCALES,
        net_cfg.p6_ANCHOR_SCALES,
        net_cfg.p7_ANCHOR_SCALES,
    ]
    anchor_ratios = [
        net_cfg.p4_ANCHOR_RATIOS,
        net_cfg.p5_ANCHOR_RATIOS,
        net_cfg.p6_ANCHOR_RATIOS,
        net_cfg.p7_ANCHOR_RATIOS,
    ]

    # setup multi-gpu: one slice of BATCH_IMAGES images per device
    batch_size = len(ctx)
    input_batch_size = config.TRAIN.BATCH_IMAGES * batch_size

    # print config
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # load dataset and prepare imdb for training
    roidbs = [
        load_gt_roidb(config.dataset.dataset,
                      iset,
                      config.dataset.root_path,
                      config.dataset.dataset_path,
                      flip=config.TRAIN.FLIP)
        for iset in config.dataset.image_set.split('+')
    ]
    roidb = merge_roidb(roidbs)
    roidb = filter_roidb(roidb, config)

    # load training data
    train_data = AnchorLoader(feat_sym,
                              feat_stride,
                              anchor_scales,
                              anchor_ratios,
                              roidb,
                              config,
                              batch_size=input_batch_size,
                              shuffle=config.TRAIN.SHUFFLE,
                              ctx=ctx,
                              aspect_grouping=config.TRAIN.ASPECT_GROUPING)

    # infer max shape
    max_data_shape = [('data', (config.TRAIN.BATCH_IMAGES, 3,
                                max([v[0] for v in config.SCALES]),
                                max([v[1] for v in config.SCALES])))]
    max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape)
    max_data_shape.append(('gt_boxes', (config.TRAIN.BATCH_IMAGES, 100, 5)))
    # single-argument print: same text under Python 2 and Python 3
    print('providing maximum shape {} {}'.format(max_data_shape,
                                                 max_label_shape))

    data_shape_dict = dict(train_data.provide_data_single +
                           train_data.provide_label_single)
    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)

    # load and initialize params
    if config.TRAIN.RESUME:
        print('continue training from ', begin_epoch)
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        sym_instance.init_weight(config, arg_params, aux_params)

    # check parameter shapes
    sym_instance.check_parameter_shapes(arg_params, aux_params,
                                        data_shape_dict)

    # create solver
    fixed_param_prefix = config.network.FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data_single]
    label_names = [k[0] for k in train_data.provide_label_single]

    mod = MutableModule(
        sym,
        data_names=data_names,
        label_names=label_names,
        logger=logger,
        context=ctx,
        max_data_shapes=[max_data_shape for _ in range(batch_size)],
        max_label_shapes=[max_label_shape for _ in range(batch_size)],
        fixed_param_prefix=fixed_param_prefix)

    if config.TRAIN.RESUME:
        mod._preload_opt_states = '%s-%04d.states' % (prefix, begin_epoch)

    # decide training params
    # metric
    retina_total_eval_metric = metric.RetinaToalAccMetric()
    retina_cls_metric = metric.RetinaFocalLossMetric()
    retina_bbox_metric = metric.RetinaL1LossMetric()

    eval_metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [
            retina_total_eval_metric, retina_cls_metric, retina_bbox_metric
    ]:
        eval_metrics.add(child_metric)

    # callback
    batch_end_callback = callback.Speedometer(train_data.batch_size,
                                              frequent=args.frequent)
    num_reg_classes = (2 if config.CLASS_AGNOSTIC else
                       config.dataset.NUM_CLASSES)
    means = np.tile(np.array(config.TRAIN.BBOX_MEANS), num_reg_classes)
    stds = np.tile(np.array(config.TRAIN.BBOX_STDS), num_reg_classes)
    epoch_end_callback = [
        mx.callback.module_checkpoint(mod,
                                      prefix,
                                      period=1,
                                      save_optimizer_states=True),
        callback.do_checkpoint(prefix, means, stds)
    ]

    # decide learning rate: steps already passed at begin_epoch are folded into
    # the starting rate, the remaining ones are converted to iteration counts.
    # The loop variable avoids the name `epoch` so the parameter is not rebound
    # (Python 2 list comprehensions leak their variable).
    base_lr = lr
    lr_factor = config.TRAIN.lr_factor
    lr_epoch = [float(step) for step in lr_step.split(',')]
    lr_epoch_diff = [
        step - begin_epoch for step in lr_epoch if step > begin_epoch
    ]
    lr = base_lr * (lr_factor**(len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [
        int(step * len(roidb) / batch_size) for step in lr_epoch_diff
    ]
    print(lr_step.split(','))
    print('lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters)
    # NOTE(review): the scheduler is constructed but not placed in
    # optimizer_params below, so it is never handed to mod.fit here —
    # confirm a constant Adam learning rate is intended.
    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor,
                                              config.TRAIN.warmup,
                                              config.TRAIN.warmup_lr,
                                              config.TRAIN.warmup_step)
    # optimizer
    optimizer_params = {
        'learning_rate': lr,
        'wd': 0.0001,
    }

    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)

    # train
    initializer = mx.init.MSRAPrelu(factor_type='out', slope=0)

    print("-----------------------train--------------------------------")
    mod.fit(train_data,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            kvstore=config.default.kvstore,
            optimizer='adam',
            optimizer_params=optimizer_params,
            initializer=initializer,
            arg_params=arg_params,
            aux_params=aux_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch)
def train_net(args, ctx, pretrained, epoch, prefix, begin_epoch, end_epoch, lr,
              lr_step):
    """Train the network returned by ``get_symbol`` end-to-end with SGD.

    Reads the module-level ``config`` and ``curr_path``. For the ``'JSONList'``
    dataset the checkpoint prefix is time-stamped and logging goes to
    ``<prefix>.log``; for every other dataset ``create_logger`` supplies the
    logger and the output directory.

    Args:
        args: parsed command-line arguments; ``args.cfg`` and ``args.frequent``
            are read.
        ctx: list of device contexts; its length is used as the device count.
        pretrained, epoch: checkpoint prefix/epoch used when
            ``config.TRAIN.RESUME`` is false.
        prefix: checkpoint name, joined onto the output directory.
        begin_epoch, end_epoch: passed to ``mod.fit`` as ``begin_epoch`` /
            ``num_epoch``.
        lr: base learning rate.
        lr_step: comma-separated epochs at which the rate is multiplied by
            ``config.TRAIN.lr_factor``.
    """
    if config.dataset.dataset == 'JSONList':
        import datetime
        import logging
        final_output_path = config.output_path
        stamp = datetime.datetime.now().strftime("%Y-%m-%d_%H_%M_%S")
        prefix = os.path.join(final_output_path, prefix + '_' + stamp)
        # keep the exact config used for this run beside its checkpoints
        shutil.copy2(args.cfg, prefix + '.yaml')
        logging.basicConfig(filename=prefix + '.log',
                            format='%(asctime)-15s %(message)s')
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        logger.info('prefix: %s' % prefix)
        print('prefix: %s' % prefix)
    else:
        logger, final_output_path = create_logger(config.output_path, args.cfg,
                                                  config.dataset.image_set)
        prefix = os.path.join(final_output_path, prefix)

    # symbol: archive its source file, then instantiate and build it
    symbol_source = os.path.join(curr_path, 'symbols', config.symbol + '.py')
    shutil.copy2(symbol_source, final_output_path)
    sym_instance = eval(config.symbol + '.' + config.symbol)()
    sym = sym_instance.get_symbol(config, is_train=True)

    # one slice of BATCH_IMAGES images per device
    batch_size = len(ctx)
    input_batch_size = config.TRAIN.BATCH_IMAGES * batch_size

    # dump the configuration to stdout and to the log
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # ground-truth roidb: one per '+'-separated image set, merged and filtered
    roidbs = []
    for set_name in config.dataset.image_set.split('+'):
        roidbs.append(
            load_gt_roidb(config.dataset.dataset,
                          set_name,
                          config.dataset.root_path,
                          config.dataset.dataset_path,
                          flip=config.TRAIN.FLIP))
    roidb = filter_roidb(merge_roidb(roidbs), config)

    # data iterator: pyramid loader when several RPN levels are configured
    internals = sym.get_internals()
    if config.network.MULTI_RPN:
        strides = config.network.MULTI_RPN_STRIDES
        rpn_syms = [
            internals['rpn%d_cls_score_output' % level]
            for level in range(len(strides))
        ]
        train_data = PyramidAnchorLoader(
            rpn_syms,
            roidb,
            config,
            batch_size=input_batch_size,
            shuffle=config.TRAIN.SHUFFLE,
            ctx=ctx,
            feat_strides=strides,
            anchor_scales=config.network.ANCHOR_SCALES,
            anchor_ratios=config.network.ANCHOR_RATIOS,
            aspect_grouping=config.TRAIN.ASPECT_GROUPING,
            allowed_border=np.inf)
    else:
        train_data = AnchorLoader(internals['rpn_cls_score_output'],
                                  roidb,
                                  config,
                                  batch_size=input_batch_size,
                                  shuffle=config.TRAIN.SHUFFLE,
                                  ctx=ctx,
                                  feat_stride=config.network.RPN_FEAT_STRIDE,
                                  anchor_scales=config.network.ANCHOR_SCALES,
                                  anchor_ratios=config.network.ANCHOR_RATIOS,
                                  aspect_grouping=config.TRAIN.ASPECT_GROUPING)

    # maximum shapes the module has to be able to bind
    tallest = max([scale[0] for scale in config.SCALES])
    widest = max([scale[1] for scale in config.SCALES])
    max_data_shape = [('data', (config.TRAIN.BATCH_IMAGES, 3, tallest,
                                widest))]
    max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape)
    max_data_shape.append(('gt_boxes', (config.TRAIN.BATCH_IMAGES, 100, 5)))
    print('providing maximum shape', max_data_shape, max_label_shape)

    single_data = train_data.provide_data_single
    single_label = train_data.provide_label_single
    data_shape_dict = dict(single_data + single_label)
    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)

    # parameters: resume from our own checkpoint or start from the pretrained one
    if config.TRAIN.RESUME:
        print('continue training from ', begin_epoch)
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        sym_instance.init_weight(config, arg_params, aux_params)

    sym_instance.check_parameter_shapes(arg_params, aux_params,
                                        data_shape_dict)

    # solver
    mod = MutableModule(
        sym,
        data_names=[entry[0] for entry in single_data],
        label_names=[entry[0] for entry in single_label],
        logger=logger,
        context=ctx,
        max_data_shapes=[max_data_shape for _ in range(batch_size)],
        max_label_shapes=[max_label_shape for _ in range(batch_size)],
        fixed_param_prefix=config.network.FIXED_PARAMS)

    if config.TRAIN.RESUME:
        mod._preload_opt_states = '%s-%04d.states' % (prefix, begin_epoch)

    # metrics, added in order: RPN acc / log-loss / L1, then RCNN acc / log-loss / L1
    children = [
        metric.RPNAccMetric(),
        metric.RPNLogLossMetric(),
        metric.RPNL1LossMetric(),
        metric.RCNNAccMetric(config),
        metric.RCNNLogLossMetric(config),
        metric.RCNNL1LossMetric(config),
    ]
    eval_metrics = mx.metric.CompositeEvalMetric()
    for child in children:
        eval_metrics.add(child)
    if config.network.PREDICT_KEYPOINTS:
        for child in (metric.KeypointAccMetric(config),
                      metric.KeypointLogLossMetric(config),
                      metric.KeypointL1LossMetric(config)):
            eval_metrics.add(child)

    # callbacks
    batch_end_callback = callback.Speedometer(train_data.batch_size,
                                              frequent=args.frequent)
    reg_repeat = 2 if config.CLASS_AGNOSTIC else config.dataset.NUM_CLASSES
    means = np.tile(np.array(config.TRAIN.BBOX_MEANS), reg_repeat)
    stds = np.tile(np.array(config.TRAIN.BBOX_STDS), reg_repeat)
    epoch_end_callback = [
        mx.callback.module_checkpoint(mod,
                                      prefix,
                                      period=1,
                                      save_optimizer_states=True),
        callback.do_checkpoint(prefix, means, stds)
    ]

    # learning rate: steps at or before begin_epoch scale the starting rate,
    # later ones become iteration milestones for the scheduler
    lr_factor = config.TRAIN.lr_factor
    step_epochs = [float(token) for token in lr_step.split(',')]
    pending = [ep - begin_epoch for ep in step_epochs if ep > begin_epoch]
    base_lr = lr
    lr = base_lr * (lr_factor**(len(step_epochs) - len(pending)))
    lr_iters = [int(ep * len(roidb) / batch_size) for ep in pending]
    print('lr', lr, 'lr_epoch_diff', pending, 'lr_iters', lr_iters)
    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor,
                                              config.TRAIN.warmup,
                                              config.TRAIN.warmup_lr,
                                              config.TRAIN.warmup_step)

    # optimizer
    optimizer_params = dict(momentum=config.TRAIN.momentum,
                            wd=config.TRAIN.wd,
                            learning_rate=lr,
                            lr_scheduler=lr_scheduler,
                            rescale_grad=1.0,
                            clip_gradient=None)

    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)

    # train
    mod.fit(train_data,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            kvstore=config.default.kvstore,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            arg_params=arg_params,
            aux_params=aux_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch)
Example #15
0
def train_net(args, ctx, pretrained, epoch, prefix, begin_epoch, end_epoch, lr,
              lr_step):
    """Train the network returned by ``get_symbol`` on polygon roidbs with SGD.

    Reads the module-level ``config`` and ``curr_path``. Anchor targets are
    generated per pyramid level, one ``rpn_cls_score_p<k>_output`` for every
    stride in ``config.network.RPN_FEAT_STRIDE`` (``k = log2(stride)``).

    Args:
        args: parsed command-line arguments; ``args.cfg`` and ``args.frequent``
            are read.
        ctx: list of device contexts; its length is used as the device count.
        pretrained, epoch: checkpoint prefix/epoch used when
            ``config.TRAIN.RESUME`` is false.
        prefix: checkpoint name, joined onto the logger's output directory.
        begin_epoch, end_epoch: passed to ``mod.fit`` as ``begin_epoch`` /
            ``num_epoch``.
        lr: base learning rate.
        lr_step: comma-separated epochs at which the rate is multiplied by
            ``config.TRAIN.lr_factor``.
    """
    mx.random.seed(3)
    np.random.seed(3)
    logger, final_output_path = create_logger(config.output_path, args.cfg,
                                              config.dataset.image_set)
    prefix = os.path.join(final_output_path, prefix)

    # load symbol; a copy of the symbol source is kept next to the checkpoints
    shutil.copy2(os.path.join(curr_path, 'symbols', config.symbol + '.py'),
                 final_output_path)
    # NOTE(review): eval() resolves the class from a config-supplied name;
    # only safe while the config file is trusted.
    sym_instance = eval(config.symbol + '.' + config.symbol)()
    sym = sym_instance.get_symbol(config, is_train=True)

    # one RPN classification output per pyramid level
    feat_pyramid_level = np.log2(config.network.RPN_FEAT_STRIDE).astype(int)
    internals = sym.get_internals()
    feat_sym = [
        internals['rpn_cls_score_p' + str(level) + '_output']
        for level in feat_pyramid_level
    ]
    print('load symbol END')

    # setup multi-gpu: one slice of BATCH_IMAGES images per device
    batch_size = len(ctx)
    input_batch_size = config.TRAIN.BATCH_IMAGES * batch_size

    # print config
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # load dataset and prepare imdb for training
    print('Start load dataset and prepare imdb for training')
    roidbs = [
        load_gt_roidb_poly(config.dataset.dataset,
                           iset,
                           config.dataset.root_path,
                           config.dataset.dataset_path,
                           flip=config.TRAIN.FLIP)
        for iset in config.dataset.image_set.split('+')
    ]
    roidb = merge_roidb(roidbs)
    roidb = filter_roidb(roidb, config)
    print('Start load training data')

    # load training data
    train_data = PyramidAnchorIterator_poly(
        feat_sym,
        roidb,
        config,
        batch_size=input_batch_size,
        shuffle=config.TRAIN.SHUFFLE,
        ctx=ctx,
        feat_strides=config.network.RPN_FEAT_STRIDE,
        anchor_scales=config.network.ANCHOR_SCALES,
        anchor_ratios=config.network.ANCHOR_RATIOS,
        aspect_grouping=config.TRAIN.ASPECT_GROUPING,
        allowed_border=np.inf)

    # infer max shape
    max_data_shape = [('data', (config.TRAIN.BATCH_IMAGES, 3,
                                max([v[0] for v in config.SCALES]),
                                max([v[1] for v in config.SCALES])))]
    max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape)
    max_data_shape.append(('gt_boxes', (config.TRAIN.BATCH_IMAGES, 300, 9)))
    # single-argument print: same text under Python 2 and Python 3
    print('providing maximum shape {} {}'.format(max_data_shape,
                                                 max_label_shape))

    data_shape_dict = dict(train_data.provide_data_single +
                           train_data.provide_label_single)
    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)

    # load and initialize params
    if config.TRAIN.RESUME:
        print('continue training from ', begin_epoch)
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        sym_instance.init_weight(config, arg_params, aux_params)

    # check parameter shapes
    sym_instance.check_parameter_shapes(arg_params, aux_params,
                                        data_shape_dict)

    # create solver
    fixed_param_prefix = config.network.FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data_single]
    label_names = [k[0] for k in train_data.provide_label_single]

    mod = MutableModule(
        sym,
        data_names=data_names,
        label_names=label_names,
        logger=logger,
        context=ctx,
        max_data_shapes=[max_data_shape for _ in range(batch_size)],
        max_label_shapes=[max_label_shape for _ in range(batch_size)],
        fixed_param_prefix=fixed_param_prefix)

    if config.TRAIN.RESUME:
        mod._preload_opt_states = '%s-%04d.states' % (prefix, begin_epoch)

    # decide training params
    # metric: RPN metrics, then RCNN metrics, then the RRoI metrics
    rpn_eval_metric = metric.RPNAccMetric()
    rpn_cls_metric = metric.RPNLogLossMetric()
    rpn_bbox_metric = metric.RPNL1LossMetric()
    rpn_fg_metric = metric.RPNFGFraction(config)
    eval_fg_metric = metric.RCNNFGAccuracy(config)
    eval_metric = metric.RCNNAccMetric(config)
    cls_metric = metric.RCNNLogLossMetric(config)
    bbox_metric = metric.RCNNL1LossMetric(config)
    RCNN_proposal_fraction_metric = metric.RCNNFGFraction(config)
    Rroi_fg_accuracy = metric.RRoIRCNNFGAccuracy(config)
    Rroi_accuracy = metric.RRoIAccMetric(config)
    Rroi_cls_metric = metric.RRoIRCNNLogLossMetric(config)
    Rroi_bbox_metric = metric.RRoIRCNNL1LossMetric(config)
    eval_metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [
            rpn_eval_metric, rpn_cls_metric, rpn_bbox_metric, rpn_fg_metric,
            eval_fg_metric, eval_metric, cls_metric, bbox_metric,
            RCNN_proposal_fraction_metric, Rroi_fg_accuracy, Rroi_accuracy,
            Rroi_cls_metric, Rroi_bbox_metric
    ]:
        eval_metrics.add(child_metric)

    # callback
    batch_end_callback = callback.Speedometer(train_data.batch_size,
                                              frequent=args.frequent)
    num_reg_classes = (2 if config.CLASS_AGNOSTIC else
                       config.dataset.NUM_CLASSES)
    means = np.tile(np.array(config.TRAIN.BBOX_MEANS), num_reg_classes)
    stds = np.tile(np.array(config.TRAIN.BBOX_STDS), num_reg_classes)
    # the RRoI head has its own class-agnostic switch
    num_rroi_reg_classes = (2 if config.network.RRoI_CLASS_AGNOSTIC else
                            config.dataset.NUM_CLASSES)
    Rroi_means = np.tile(np.array(config.TRAIN.RRoI_BBOX_MEANS),
                         num_rroi_reg_classes)
    Rroi_stds = np.tile(np.array(config.TRAIN.RRoI_BBOX_STDS),
                        num_rroi_reg_classes)

    epoch_end_callback = [
        mx.callback.module_checkpoint(mod,
                                      prefix,
                                      period=1,
                                      save_optimizer_states=True),
        callback.do_checkpoint_Rroi(prefix, means, stds, Rroi_means, Rroi_stds)
    ]

    # decide learning rate: steps already passed at begin_epoch are folded into
    # the starting rate, the remaining ones are converted to iteration counts.
    # The loop variable avoids the name `epoch` so the parameter is not rebound
    # (Python 2 list comprehensions leak their variable).
    base_lr = lr
    lr_factor = config.TRAIN.lr_factor
    lr_epoch = [float(step) for step in lr_step.split(',')]
    lr_epoch_diff = [
        step - begin_epoch for step in lr_epoch if step > begin_epoch
    ]
    lr = base_lr * (lr_factor**(len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [
        int(step * len(roidb) / batch_size) for step in lr_epoch_diff
    ]
    print('lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters)
    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor,
                                              config.TRAIN.warmup,
                                              config.TRAIN.warmup_lr,
                                              config.TRAIN.warmup_step)
    # optimizer
    optimizer_params = {
        'momentum': config.TRAIN.momentum,
        'wd': config.TRAIN.wd,
        'learning_rate': lr,
        'lr_scheduler': lr_scheduler,
        'clip_gradient': None
    }

    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)

    # train
    mod.fit(train_data,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            kvstore=config.default.kvstore,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            arg_params=arg_params,
            aux_params=aux_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch)
def train_net(args, ctx, pretrained, epoch, prefix, begin_epoch, end_epoch, lr,
              lr_step):
    """
    Fine-tune the detector with an explicit list of frozen parameters.

    Builds the training symbol named by ``config.symbol``, loads the ground
    truth roidb of every image set in ``config.dataset.image_set``, restores
    weights from a hard-coded checkpoint, freezes every parameter listed in
    ``arg_params.txt`` except the names removed below, and runs ``mod.fit``.

    :param args: parsed arguments; ``args.cfg`` and ``args.frequent`` are read
    :param ctx: list of contexts handed to the data loader and the module
    :param pretrained: not used in this variant (weights come from the
        hard-coded checkpoint path below)
    :param epoch: not used in this variant, see ``pretrained``
    :param prefix: checkpoint prefix, joined onto the final output path
    :param begin_epoch: first epoch; also selects the optimizer state file
        when ``config.TRAIN.RESUME`` is set
    :param end_epoch: passed to ``mod.fit`` as ``num_epoch``
    :param lr: base learning rate
    :param lr_step: comma separated epochs at which the learning rate is
        multiplied by ``config.TRAIN.lr_factor``
    :return:
    """
    logger, final_output_path = create_logger(config.output_path, args.cfg,
                                              config.dataset.image_set)
    prefix = os.path.join(final_output_path, prefix)

    # load symbol (a copy of the symbol source is stored next to the outputs)
    shutil.copy2(os.path.join(curr_path, 'symbols', config.symbol + '.py'),
                 final_output_path)
    # NOTE(review): eval() on a configuration value; config must be trusted.
    sym_instance = eval(config.symbol + '.' + config.symbol)()
    sym = sym_instance.get_symbol(config, is_train=True)
    feat_sym = sym.get_internals()['rpn_cls_score_output']

    # setup multi-gpu
    batch_size = len(ctx)
    input_batch_size = config.TRAIN.BATCH_IMAGES * batch_size

    # print config
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # load dataset and prepare imdb for training
    image_sets = [iset for iset in config.dataset.image_set.split('+')]
    roidbs = [
        load_gt_roidb(config.dataset.dataset,
                      image_set,
                      config.dataset.root_path,
                      config.dataset.dataset_path,
                      flip=config.TRAIN.FLIP) for image_set in image_sets
    ]
    roidb = merge_roidb(roidbs)
    roidb = filter_roidb(roidb, config)
    # load training data
    train_data = AnchorLoader(feat_sym,
                              roidb,
                              config,
                              batch_size=input_batch_size,
                              shuffle=config.TRAIN.SHUFFLE,
                              ctx=ctx,
                              feat_stride=config.network.RPN_FEAT_STRIDE,
                              anchor_scales=config.network.ANCHOR_SCALES,
                              anchor_ratios=config.network.ANCHOR_RATIOS,
                              aspect_grouping=config.TRAIN.ASPECT_GROUPING)

    # infer max shape
    max_data_shape = [('data', (config.TRAIN.BATCH_IMAGES, 3,
                                max([v[0] for v in config.SCALES]),
                                max([v[1] for v in config.SCALES])))]
    max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape)
    max_data_shape.append(('gt_boxes', (config.TRAIN.BATCH_IMAGES, 100, 5)))
    print('providing maximum shape', max_data_shape, max_label_shape)

    data_shape_dict = dict(train_data.provide_data_single +
                           train_data.provide_label_single)
    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)

    # load and initialize params
    # The regular resume / pretrained path is disabled in this variant:
    #if config.TRAIN.RESUME:
    #    print('continue training from ', begin_epoch)
    #    arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    #else:
    #    arg_params, aux_params = load_param(pretrained, epoch, convert=True)
    #    sym_instance.init_weight(config, arg_params, aux_params)

    print('transfer learning...')

    # Choose the initialization weights (COCO or UADETRAC or pretrained)
    #arg_params, aux_params = load_param('/raid10/home_ext/Deformable-ConvNets/output/rfcn_dcn_Shuo_UADTRAC/resnet_v1_101_voc0712_rfcn_dcn_Shuo_UADETRAC/trainlist_full/rfcn_UADTRAC', 5, convert=True)
    #arg_params, aux_params = load_param('/raid10/home_ext/Deformable-ConvNets/model/rfcn_dcn_coco', 0, convert=True)
    arg_params, aux_params = load_param(
        '/raid10/home_ext/Deformable-ConvNets/output/rfcn_dcn_Shuo_AICity/resnet_v1_101_voc0712_rfcn_dcn_Shuo_AICityVOC1080_FreezeCOCO_rpnOnly_all/1080_all/rfcn_AICityVOC1080_FreezeCOCO_rpnOnly_all',
        4,
        convert=True)

    sym_instance.init_weight_Shuo(config, arg_params, aux_params)

    # check parameter shapes
    sym_instance.check_parameter_shapes(arg_params, aux_params,
                                        data_shape_dict)

    # create solver
    fixed_param_prefix = config.network.FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data_single]
    label_names = [k[0] for k in train_data.provide_label_single]

    # NOTE(review): this module is replaced by the MutableModule_Shuo built
    # below before it is ever used; it looks redundant -- confirm and drop.
    mod = MutableModule(
        sym,
        data_names=data_names,
        label_names=label_names,
        logger=logger,
        context=ctx,
        max_data_shapes=[max_data_shape for _ in range(batch_size)],
        max_label_shapes=[max_label_shape for _ in range(batch_size)],
        fixed_param_prefix=fixed_param_prefix)

    # freeze parameters using fixed_param_names: list of str
    # The listing file is read inside a context manager so that the handle is
    # closed; each line contributes the text that precedes its first '<'.
    with open(
            '/raid10/home_ext/Deformable-ConvNets/rfcn/symbols/arg_params.txt'
    ) as para_file:
        para_list = [line.split('<')[0] for line in para_file.readlines()]
    #    para_list.remove('rfcn_cls_weight')
    #    para_list.remove('rfcn_cls_bias')
    #    para_list.remove('rfcn_cls_offset_t_weight')
    #    para_list.remove('rfcn_cls_offset_t_bias')
    #
    # Names taken out of the frozen list stay trainable. list.remove raises
    # ValueError when a name is missing from the listing file.
    for trainable_name in (
            'res5a_branch2b_offset_weight',
            'res5a_branch2b_offset_bias',
            'res5b_branch2b_offset_weight',
            'res5b_branch2b_offset_bias',
            'res5c_branch2b_offset_weight',
            'res5c_branch2b_offset_bias',
            'conv_new_1_weight',
            'conv_new_1_bias',
            'rfcn_bbox_weight',
            'rfcn_bbox_bias',
            'rfcn_bbox_offset_t_weight',
            'rfcn_bbox_offset_t_bias',
    ):
        para_list.remove(trainable_name)

    mod = MutableModule_Shuo(
        sym,
        data_names=data_names,
        label_names=label_names,
        logger=logger,
        context=ctx,
        max_data_shapes=[max_data_shape for _ in range(batch_size)],
        max_label_shapes=[max_label_shape for _ in range(batch_size)],
        fixed_param_prefix=fixed_param_prefix,
        fixed_param_names=para_list)

    if config.TRAIN.RESUME:
        mod._preload_opt_states = '%s-%04d.states' % (prefix, begin_epoch)

    # decide training params
    # metric
    rpn_eval_metric = metric.RPNAccMetric()
    rpn_cls_metric = metric.RPNLogLossMetric()
    rpn_bbox_metric = metric.RPNL1LossMetric()
    eval_metric = metric.RCNNAccMetric(config)
    cls_metric = metric.RCNNLogLossMetric(config)
    bbox_metric = metric.RCNNL1LossMetric(config)
    eval_metrics = mx.metric.CompositeEvalMetric()
    # rpn_eval_metric, rpn_cls_metric, rpn_bbox_metric, eval_metric, cls_metric, bbox_metric
    for child_metric in [
            rpn_eval_metric, rpn_cls_metric, rpn_bbox_metric, eval_metric,
            cls_metric, bbox_metric
    ]:
        eval_metrics.add(child_metric)
    # callback
    batch_end_callback = callback.Speedometer(train_data.batch_size,
                                              frequent=args.frequent)
    means = np.tile(np.array(config.TRAIN.BBOX_MEANS),
                    2 if config.CLASS_AGNOSTIC else config.dataset.NUM_CLASSES)
    stds = np.tile(np.array(config.TRAIN.BBOX_STDS),
                   2 if config.CLASS_AGNOSTIC else config.dataset.NUM_CLASSES)
    epoch_end_callback = [
        mx.callback.module_checkpoint(mod,
                                      prefix,
                                      period=1,
                                      save_optimizer_states=True),
        callback.do_checkpoint(prefix, means, stds)
    ]
    # decide learning rate
    # The loop variables are named step_epoch so that they do not shadow (and,
    # on Python 2, rebind) the ``epoch`` parameter.
    base_lr = lr
    lr_factor = config.TRAIN.lr_factor
    lr_epoch = [float(step_epoch) for step_epoch in lr_step.split(',')]
    lr_epoch_diff = [
        step_epoch - begin_epoch for step_epoch in lr_epoch
        if step_epoch > begin_epoch
    ]
    # steps that lie at or before begin_epoch are applied to the start rate
    lr = base_lr * (lr_factor**(len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [
        int(step_epoch * len(roidb) / batch_size)
        for step_epoch in lr_epoch_diff
    ]
    print('lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters)
    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor,
                                              config.TRAIN.warmup,
                                              config.TRAIN.warmup_lr,
                                              config.TRAIN.warmup_step)
    # optimizer
    optimizer_params = {
        'momentum': config.TRAIN.momentum,
        'wd': config.TRAIN.wd,
        'learning_rate': lr,
        'lr_scheduler': lr_scheduler,
        'rescale_grad': 1.0,
        'clip_gradient': None
    }

    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)

    # train
    mod.fit(train_data,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            kvstore=config.default.kvstore,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            arg_params=arg_params,
            aux_params=aux_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch)
Example #17
0
def pred_eval(predictor,
              test_data,
              imdb,
              cfg,
              vis=False,
              thresh=1e-3,
              logger=None,
              ignore_cache=True):
    """
    Detect on every batch of ``test_data``, filter with ``psoft`` and evaluate.

    Besides the per-class detections, the proposals kept by ``psoft`` are
    collected per batch (score, box and matching feature row) and passed to
    ``imdb.evaluate_detections`` together with the detections.

    :param predictor: Predictor
    :param test_data: data iterator, must be non-shuffle unless ``vis`` is set
    :param imdb: image database
    :param cfg: configuration; ``TEST.NMS``, ``TEST.max_per_image`` and
        ``CLASS_AGNOSTIC`` are read
    :param vis: controls visualization
    :param thresh: valid detection threshold
    :param logger: optional logger receiving progress and evaluation text
    :param ignore_cache: when False and a detection file exists, evaluate the
        cached detections and return without running the predictor
    :return:
    """
    det_file = os.path.join(imdb.result_path, imdb.name + '_detections.pkl')
    cache_usable = os.path.exists(det_file) and not ignore_cache
    if cache_usable:
        with open(det_file, 'rb') as fid:
            all_boxes = pickle.load(fid)
        info_str = imdb.evaluate_detections(all_boxes)
        if logger:
            logger.info('evaluate detections: \n{}'.format(info_str))
        return

    assert vis or not test_data.shuffle
    data_names = [entry[0] for entry in test_data.provide_data[0]]

    if not isinstance(test_data, PrefetchingIter):
        test_data = PrefetchingIter(test_data)

    # built as in the original code; the filtering below goes through psoft
    nms = py_nms_wrapper(cfg.TEST.NMS)

    # limit detections to max_per_image over all classes
    max_per_image = cfg.TEST.max_per_image

    num_images = imdb.num_images
    # all detections are collected into:
    #    all_boxes[cls][image] = N x 5 array of detections in
    #    (x1, y1, x2, y2, score)
    all_boxes = [[[] for _ in range(num_images)]
                 for _ in range(imdb.num_classes)]
    rois = []

    idx = 0
    data_time, net_time, post_time = 0.0, 0.0, 0.0
    t = time.time()
    for im_info, data_batch in test_data:
        t1 = time.time() - t
        t = time.time()

        scales = [info[0, 2] for info in im_info]
        (scores_all, boxes_all, roi_score_all, rois_all, roi_feat_all,
         data_dict_all) = im_detect(predictor, data_batch, data_names, scales,
                                    cfg)
        assert len(roi_score_all) == len(rois_all) == len(roi_feat_all) == 1

        # proposals: box columns followed by the score column
        nms_input = np.hstack((rois_all[0], roi_score_all[0]))
        roi_nms = psoft(nms_input, -1.0)
        source_feat = roi_feat_all[0]
        roi_feat = np.zeros((roi_nms.shape[0], source_feat.shape[1]),
                            dtype=source_feat.dtype)
        # copy the feature row of every input proposal that equals a kept one
        for kept_row in range(roi_nms.shape[0]):
            for src_row in range(nms_input.shape[0]):
                if np.all(roi_nms[kept_row] == nms_input[src_row]):
                    roi_feat[kept_row] = source_feat[src_row]
        rois.append({
            'score': roi_nms[:, 4],
            'bbox': roi_nms[:, :4],
            'feat': roi_feat
        })

        t2 = time.time() - t
        t = time.time()
        for delta, (scores, boxes, data_dict) in enumerate(
                zip(scores_all, boxes_all, data_dict_all)):
            image_index = idx + delta
            # first filter every foreground class, then store the results
            filtered = []
            for cls in range(1, imdb.num_classes):
                selected = np.where(scores[:, cls] > thresh)[0]
                cls_scores = scores[selected, cls, np.newaxis]
                if cfg.CLASS_AGNOSTIC:
                    cls_boxes = boxes[selected, 4:8]
                else:
                    cls_boxes = boxes[selected, cls * 4:(cls + 1) * 4]
                filtered.append(
                    psoft(np.hstack((cls_boxes, cls_scores)), thresh))
            for cls, dets in enumerate(filtered, 1):
                all_boxes[cls][image_index] = dets

            if max_per_image > 0:
                image_scores = np.hstack([
                    all_boxes[cls][image_index][:, -1]
                    for cls in range(1, imdb.num_classes)
                ])
                if len(image_scores) > max_per_image:
                    # score of the max_per_image-th best detection
                    image_thresh = np.sort(image_scores)[-max_per_image]
                    for cls in range(1, imdb.num_classes):
                        dets = all_boxes[cls][image_index]
                        keep = np.where(dets[:, -1] >= image_thresh)[0]
                        all_boxes[cls][image_index] = dets[keep, :]

            if vis:
                boxes_this_image = [[]]
                boxes_this_image += [
                    all_boxes[cls][image_index]
                    for cls in range(1, imdb.num_classes)
                ]
                vis_all_detection(data_dict['data'].asnumpy(),
                                  boxes_this_image, imdb.classes,
                                  scales[delta], cfg)

        idx += test_data.batch_size
        t3 = time.time() - t
        t = time.time()
        data_time += t1
        net_time += t2
        post_time += t3
        # running averages, formatted once and sent to stdout and the logger
        progress = 'testing {}/{} data {:.4f}s net {:.4f}s post {:.4f}s'.format(
            idx, imdb.num_images, data_time / idx * test_data.batch_size,
            net_time / idx * test_data.batch_size,
            post_time / idx * test_data.batch_size)
        print(progress)
        if logger:
            logger.info(progress)

    with open(det_file, 'wb') as f:
        pickle.dump(all_boxes, f, protocol=pickle.HIGHEST_PROTOCOL)

    info_str = imdb.evaluate_detections(all_boxes, rois)
    if logger:
        logger.info('evaluate detections: \n{}'.format(info_str))
Example #18
0
def pred_eval(predictor,
              test_data,
              imdb,
              cfg,
              vis=False,
              thresh=1e-3,
              logger=None,
              ignore_cache=True):
    """
    wrapper for calculating offline validation for faster data analysis
    in this example, all threshold are set by hand

    With ``cfg.TEST.LEARN_NMS`` the thresholded per-class detections are
    stored as they are; otherwise they are optionally limited to the
    ``cfg.TEST.FIRST_N`` highest scores and passed through NMS or soft-NMS
    (``cfg.TEST.SOFTNMS``).

    :param predictor: Predictor
    :param test_data: data iterator, must be non-shuffle
    :param imdb: image database
    :param cfg: configuration object
    :param vis: controls visualization
    :param thresh: valid detection threshold
    :param logger: optional logger receiving progress and evaluation text
    :param ignore_cache: when False and a detection file exists, evaluate the
        cached detections and return without running the predictor
    :return:
    """

    det_file = os.path.join(imdb.result_path, imdb.name + '_detections.pkl')
    if os.path.exists(det_file) and not ignore_cache:
        with open(det_file, 'rb') as fid:
            all_boxes = cPickle.load(fid)
        info_str = imdb.evaluate_detections(all_boxes)
        if logger:
            logger.info('evaluate detections: \n{}'.format(info_str))
        return

    assert vis or not test_data.shuffle
    data_names = [k[0] for k in test_data.provide_data[0]]

    if not isinstance(test_data, PrefetchingIter):
        test_data = PrefetchingIter(test_data)

    if cfg.TEST.SOFTNMS:
        nms = py_softnms_wrapper(cfg.TEST.NMS)
    else:
        nms = py_nms_wrapper(cfg.TEST.NMS)

    # limit detections to max_per_image over all classes
    max_per_image = cfg.TEST.max_per_image

    num_images = imdb.num_images
    # all detections are collected into:
    #    all_boxes[cls][image] = N x 5 array of detections in
    #    (x1, y1, x2, y2, score)
    all_boxes = [[[] for _ in range(num_images)]
                 for _ in range(imdb.num_classes)]
    # class_lut[cls] lists the image indices with at least one detection of
    # that class above ``thresh``
    class_lut = [[] for _ in range(imdb.num_classes)]
    valid_tally = 0
    valid_sum = 0

    idx = 0
    t = time.time()
    inference_count = 0
    all_inference_time = []
    post_processing_time = []
    for im_info, data_batch in test_data:
        t1 = time.time() - t
        t = time.time()

        scales = [iim_info[0, 2] for iim_info in im_info]
        scores_all, boxes_all, data_dict_all = im_detect(
            predictor, data_batch, data_names, scales, cfg)

        t2 = time.time() - t
        t = time.time()
        for delta, (scores, boxes, data_dict) in enumerate(
                zip(scores_all, boxes_all, data_dict_all)):
            if cfg.TEST.LEARN_NMS:
                # scores and boxes are indexed by j - 1 in this branch and no
                # NMS wrapper is applied to the result
                for j in range(1, imdb.num_classes):
                    indexes = np.where(scores[:, j - 1] > thresh)[0]
                    cls_scores = scores[indexes, j - 1:j]
                    cls_boxes = boxes[indexes, j - 1, :]
                    cls_dets = np.hstack((cls_boxes, cls_scores))
                    # count the valid ground truth
                    if len(cls_scores) > 0:
                        class_lut[j].append(idx + delta)
                        valid_tally += len(cls_scores)
                        valid_sum += len(scores)
                    all_boxes[j][idx + delta] = cls_dets
            else:
                for j in range(1, imdb.num_classes):
                    indexes = np.where(scores[:, j] > thresh)[0]
                    if cfg.TEST.FIRST_N > 0:
                        # todo: check whether the order affects the result
                        sort_indices = np.argsort(
                            scores[:, j])[-cfg.TEST.FIRST_N:]
                        # sort_indices = np.argsort(-scores[:, j])[0:cfg.TEST.FIRST_N]
                        indexes = np.intersect1d(sort_indices, indexes)

                    cls_scores = scores[indexes, j, np.newaxis]
                    cls_boxes = boxes[indexes,
                                      4:8] if cfg.CLASS_AGNOSTIC else boxes[
                                          indexes, j * 4:(j + 1) * 4]
                    # count the valid ground truth
                    if len(cls_scores) > 0:
                        class_lut[j].append(idx + delta)
                        valid_tally += len(cls_scores)
                        valid_sum += len(scores)
                    cls_dets = np.hstack((cls_boxes, cls_scores))
                    if cfg.TEST.SOFTNMS:
                        # the soft-NMS wrapper result is stored directly
                        all_boxes[j][idx + delta] = nms(cls_dets)
                    else:
                        # the NMS wrapper result is used to index cls_dets
                        keep = nms(cls_dets)
                        all_boxes[j][idx + delta] = cls_dets[keep, :]

            if max_per_image > 0:
                image_scores = np.hstack([
                    all_boxes[j][idx + delta][:, -1]
                    for j in range(1, imdb.num_classes)
                ])
                if len(image_scores) > max_per_image:
                    image_thresh = np.sort(image_scores)[-max_per_image]
                    for j in range(1, imdb.num_classes):
                        keep = np.where(
                            all_boxes[j][idx + delta][:,
                                                      -1] >= image_thresh)[0]
                        all_boxes[j][idx +
                                     delta] = all_boxes[j][idx +
                                                           delta][keep, :]

            if vis:
                boxes_this_image = [[]] + [
                    all_boxes[j][idx + delta]
                    for j in range(1, imdb.num_classes)
                ]
                vis_all_detection(data_dict['data'].asnumpy(),
                                  boxes_this_image, imdb.classes,
                                  scales[delta], cfg)

        idx += test_data.batch_size
        t3 = time.time() - t
        t = time.time()
        post_processing_time.append(t3)
        all_inference_time.append(t1 + t2 + t3)
        inference_count += 1
        if inference_count % 200 == 0:
            # average over at most the 500 most recent batches
            valid_count = min(inference_count, 500)
            print("--->> running-average inference time per batch: {}".format(
                float(sum(all_inference_time[-valid_count:])) / valid_count))
            print("--->> running-average post processing time per batch: {}".
                  format(
                      float(sum(post_processing_time[-valid_count:])) /
                      valid_count))
        # Call form with a single argument: prints the same text on Python 2
        # and Python 3, and matches the print(...) calls above. The original
        # print statement is a syntax error on Python 3.
        progress = 'testing {}/{} data {:.4f}s net {:.4f}s post {:.4f}s'.format(
            idx, imdb.num_images, t1, t2, t3)
        print(progress)
        if logger:
            logger.info(progress)

    with open(det_file, 'wb') as f:
        cPickle.dump(all_boxes, f, protocol=cPickle.HIGHEST_PROTOCOL)

    # np.save('class_lut.npy', class_lut)

    info_str = imdb.evaluate_detections(all_boxes)
    if logger:
        logger.info('evaluate detections: \n{}'.format(info_str))
        num_valid_classes = [len(x) for x in class_lut]
        logger.info('valid class ratio:{}'.format(
            np.sum(num_valid_classes) / float(num_images)))
        # the 0.01 keeps the ratio defined when nothing passed the threshold
        logger.info('valid score ratio:{}'.format(
            float(valid_tally) / float(valid_sum + 0.01)))
Example #19
0
def pred_eval(predictor, test_data, imdb, cfg, vis=False, thresh=1e-3, logger=None, ignore_cache=True):
    """
    Run detection over ``test_data`` and write a pseudo ground-truth annotation file.

    Detections are thresholded and filtered per class with the soft-NMS
    wrapper, optionally capped at ``cfg.TEST.max_per_image``. Annotations are
    only produced in the ``vis`` branch, where ``draw_all_detection`` returns
    the updated annotation list and id counter. The collected annotations
    replace the ``'annotations'`` entry of the loaded COCO json, which is then
    written to a new file.

    :param predictor: Predictor
    :param test_data: data iterator exposing ``roidb``; must be non-shuffle
        unless ``vis`` is set
    :param imdb: image database
    :param cfg: configuration object
    :param vis: when True, ``draw_all_detection`` is called for every image
    :param thresh: valid detection threshold (a commented-out alternative
        signature in the original source used 0.7)
    :param logger: optional logger receiving the progress text
    :param ignore_cache: not used by this variant
    :return:
    """
    co_occur_matrix = np.load('/home/user/Deformable-ConvNets2/tmp/co_occur_matrix.npy')
    nor_co_occur_matrix = np.zeros((90, 90))
    row_max = np.zeros(90)
    co_occur_matrix = co_occur_matrix.astype(int)
    # Normalise every column by its sum. The divisor is converted to float so
    # that the integer matrix is not floor-divided on Python 2.
    for ind in range(co_occur_matrix.shape[0]):
        column_sum = np.sum(co_occur_matrix[:, ind])
        if column_sum != 0:
            nor_co_occur_matrix[:, ind] = co_occur_matrix[:, ind] / float(column_sum)
        row_max[ind] = np.amax(nor_co_occur_matrix[:, ind])

    assert vis or not test_data.shuffle
    data_names = [k[0] for k in test_data.provide_data[0]]

    # keep a handle on the roidb before the iterator is wrapped
    roidb = test_data.roidb

    if not isinstance(test_data, PrefetchingIter):
        test_data = PrefetchingIter(test_data)

    soft_nms = py_softnms_wrapper(cfg.TEST.NMS)

    # limit detections to max_per_image over all classes
    max_per_image = cfg.TEST.max_per_image

    num_images = imdb.num_images

    # all detections are collected into all_boxes[cls][image]
    all_boxes = [[[] for _ in range(num_images)]
                 for _ in range(imdb.num_classes)]

    idx = 0
    data_time, net_time, post_time = 0.0, 0.0, 0.0
    t = time.time()

    annotation_file = '/home/user/Deformable-ConvNets-test/data/coco/annotations/kinstances_unlabeled2017.json'
    # context manager so the annotation file handle is closed after loading
    with open(annotation_file, 'r') as annotation_fid:
        dataset = json.load(annotation_fid)
    annotations = []
    id_count = 1

    for im_info, data_batch in test_data:
        t1 = time.time() - t
        t = time.time()

        scales = [iim_info[0, 2] for iim_info in im_info]
        scores_all, boxes_all, data_dict_all = im_detect(predictor, data_batch, data_names, scales, cfg)

        t2 = time.time() - t
        t = time.time()
        for delta, (scores, boxes, data_dict) in enumerate(zip(scores_all, boxes_all, data_dict_all)):
            for j in range(1, imdb.num_classes):
                indexes = np.where(scores[:, j] > thresh)[0]
                cls_scores = scores[indexes, j, np.newaxis]
                cls_boxes = boxes[indexes, 4:8] if cfg.CLASS_AGNOSTIC else boxes[indexes, j * 4:(j + 1) * 4]
                cls_dets = np.hstack((cls_boxes, cls_scores))
                # the wrapper result is used as a list of row indices here
                keep = soft_nms(cls_dets)
                keep = keep.tolist()
                all_boxes[j][idx + delta] = cls_dets[keep, :]

            if max_per_image > 0:
                image_scores = np.hstack([all_boxes[j][idx + delta][:, -1]
                                          for j in range(1, imdb.num_classes)])
                if len(image_scores) > max_per_image:
                    image_thresh = np.sort(image_scores)[-max_per_image]
                    for j in range(1, imdb.num_classes):
                        keep = np.where(all_boxes[j][idx + delta][:, -1] >= image_thresh)[0]
                        all_boxes[j][idx + delta] = all_boxes[j][idx + delta][keep, :]

            if vis:
                boxes_this_image = [[]] + [all_boxes[j][idx + delta] for j in range(1, imdb.num_classes)]
                # idx + delta selects the roidb entry of this image; plain idx
                # would repeat the first image of the batch for every delta
                im_name = roidb[idx + delta]['image'].rsplit("/", 1)[-1]
                result = draw_all_detection(data_dict['data'].asnumpy(), boxes_this_image, imdb.classes,
                                            scales[delta], cfg, im_name, annotations, id_count,
                                            nor_co_occur_matrix, row_max)
                annotations = result['ann']
                id_count = result['id_count']

        idx += test_data.batch_size
        t3 = time.time() - t
        t = time.time()
        data_time += t1
        net_time += t2
        post_time += t3
        # single-argument print(...) prints the same text on Python 2 and 3
        progress = 'testing {}/{} data {:.4f}s net {:.4f}s post {:.4f}s'.format(
            idx, imdb.num_images, data_time / idx * test_data.batch_size,
            net_time / idx * test_data.batch_size, post_time / idx * test_data.batch_size)
        print(progress)
        if logger:
            logger.info(progress)

    dataset.update({'annotations': annotations})
    save_annotation_file = '/home/user/Deformable-ConvNets-test/data/coco/annotations/instances_unlabeled2017_ssl.json'
    with open(save_annotation_file, 'w') as f:
        json.dump(dataset, f)

    print("Finish generate pseudo ground truth!")
def train_net(args, ctx, pretrained, epoch, prefix, begin_epoch, end_epoch, lr, lr_step):
    """
    Train the detection network end to end with fixed random seeds.

    Builds the training symbol named by ``config.symbol``, loads the ground
    truth roidb of every image set in ``config.dataset.image_set``, restores
    either the resume checkpoint or the pretrained weights, assembles the
    evaluation metrics (RPN metrics only for joint training or when NMS is not
    learned, NMS metrics only with ``config.TRAIN.LEARN_NMS``) and runs
    ``mod.fit``.

    :param args: parsed arguments; ``args.cfg`` and ``args.frequent`` are read
    :param ctx: list of contexts handed to the data loader and the module
    :param pretrained: prefix of the pretrained model, used when not resuming
    :param epoch: epoch of the pretrained model, used when not resuming
    :param prefix: checkpoint prefix, joined onto the final output path
    :param begin_epoch: first epoch; also the checkpoint epoch when resuming
    :param end_epoch: passed to ``mod.fit`` as ``num_epoch``
    :param lr: base learning rate
    :param lr_step: comma separated epochs at which the learning rate is
        multiplied by ``config.TRAIN.lr_factor``
    :return:
    """
    np.random.seed(0)
    mx.random.seed(0)
    logger, final_output_path = create_logger(config.output_path, args.cfg, config.dataset.image_set)
    prefix = os.path.join(final_output_path, prefix)

    # load symbol (a copy of the symbol source is stored next to the outputs)
    shutil.copy2(os.path.join(curr_path, 'symbols', config.symbol + '.py'), final_output_path)
    # NOTE(review): eval() on a configuration value; config must be trusted.
    sym_instance = eval(config.symbol + '.' + config.symbol)()
    sym = sym_instance.get_symbol(config, is_train=True)
    feat_sym = sym.get_internals()['rpn_cls_score_output']

    # setup multi-gpu
    batch_size = len(ctx)
    input_batch_size = config.TRAIN.BATCH_IMAGES * batch_size

    # print config
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # load dataset and prepare imdb for training
    image_sets = [iset for iset in config.dataset.image_set.split('+')]
    roidbs = [load_gt_roidb(config.dataset.dataset, image_set, config.dataset.root_path, config.dataset.dataset_path,
                            flip=config.TRAIN.FLIP)
              for image_set in image_sets]
    roidb = merge_roidb(roidbs)
    roidb = filter_roidb(roidb, config)

    # load training data
    train_data = AnchorLoader(feat_sym, roidb, config, batch_size=input_batch_size, shuffle=config.TRAIN.SHUFFLE, ctx=ctx,
                              feat_stride=config.network.RPN_FEAT_STRIDE, anchor_scales=config.network.ANCHOR_SCALES,
                              anchor_ratios=config.network.ANCHOR_RATIOS, aspect_grouping=config.TRAIN.ASPECT_GROUPING)

    # infer max shape: the largest configured scale bounds the data blob
    max_data_shape = [('data', (config.TRAIN.BATCH_IMAGES, 3, max([v[0] for v in config.SCALES]), max([v[1] for v in config.SCALES])))]
    max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape)
    max_data_shape.append(('gt_boxes', (config.TRAIN.BATCH_IMAGES, 100, 5)))
    logger.info('providing maximum shape'+str(max_data_shape)+"  "+str(max_label_shape))

    data_shape_dict = dict(train_data.provide_data_single + train_data.provide_label_single)

    logger.info("data_sahpe_dict:\n{}".format(pprint.pformat(data_shape_dict)))

    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)
    pprint.pprint(sym_instance.arg_shape_dict)

    logger.info("sym_instance.arg_shape_dict\n")
    # routed through ``logger`` like every other message in this function
    # (the original called the ``logging`` module directly here)
    logger.info(pprint.pformat(sym_instance.arg_shape_dict))
    #dot = mx.viz.plot_network(sym, node_attrs={'shape': 'rect', 'fixedsize': 'false'})
    #dot.render(os.path.join('./output/rcnn/network_vis', config.symbol + '_rcnn'))

    # load and initialize params
    if config.TRAIN.RESUME:
        print('continue training from ', begin_epoch)
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        sym_instance.init_weight(config, arg_params, aux_params)

    # check parameter shapes
    sym_instance.check_parameter_shapes(arg_params, aux_params, data_shape_dict)

    # create solver
    fixed_param_prefix = config.network.FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data_single]
    label_names = [k[0] for k in train_data.provide_label_single]

    mod = MutableModule(sym, data_names=data_names, label_names=label_names,
                        logger=logger, context=ctx, max_data_shapes=[max_data_shape for _ in range(batch_size)],
                        max_label_shapes=[max_label_shape for _ in range(batch_size)], fixed_param_prefix=fixed_param_prefix)

    if config.TRAIN.RESUME:
        mod._preload_opt_states = '%s-%04d.states'%(prefix, begin_epoch)

    # decide training params
    # metric
    eval_metric = metric.RCNNAccMetric(config)
    cls_metric = metric.RCNNLogLossMetric(config)
    bbox_metric = metric.RCNNL1LossMetric(config)
    eval_metrics = mx.metric.CompositeEvalMetric()
    # rpn_eval_metric, rpn_cls_metric, rpn_bbox_metric, eval_metric, cls_metric, bbox_metric
    if config.TRAIN.JOINT_TRAINING or (not config.TRAIN.LEARN_NMS):
        rpn_eval_metric = metric.RPNAccMetric()
        rpn_cls_metric = metric.RPNLogLossMetric()
        rpn_bbox_metric = metric.RPNL1LossMetric()
        for child_metric in [rpn_eval_metric, rpn_cls_metric, rpn_bbox_metric]:
            eval_metrics.add(child_metric)
    for child_metric in [eval_metric, cls_metric, bbox_metric]:
        eval_metrics.add(child_metric)
    if config.TRAIN.LEARN_NMS:
        eval_metrics.add(metric.NMSLossMetric(config, 'pos'))
        eval_metrics.add(metric.NMSLossMetric(config, 'neg'))
        eval_metrics.add(metric.NMSAccMetric(config))

    # callback
    batch_end_callback = callback.Speedometer(train_data.batch_size, frequent=args.frequent)
    means = np.tile(np.array(config.TRAIN.BBOX_MEANS), 2 if config.CLASS_AGNOSTIC else config.dataset.NUM_CLASSES)
    stds = np.tile(np.array(config.TRAIN.BBOX_STDS), 2 if config.CLASS_AGNOSTIC else config.dataset.NUM_CLASSES)
    epoch_end_callback = [mx.callback.module_checkpoint(mod, prefix, period=1, save_optimizer_states=True),
                          callback.do_checkpoint(prefix, means, stds)]
    # decide learning rate
    # The loop variables are named step_epoch so that they do not shadow (and,
    # on Python 2, rebind) the ``epoch`` parameter.
    base_lr = lr
    lr_factor = config.TRAIN.lr_factor
    lr_epoch = [float(step_epoch) for step_epoch in lr_step.split(',')]
    lr_epoch_diff = [step_epoch - begin_epoch for step_epoch in lr_epoch if step_epoch > begin_epoch]
    # steps that lie at or before begin_epoch are applied to the start rate
    lr = base_lr * (lr_factor ** (len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [int(step_epoch * len(roidb) / batch_size) for step_epoch in lr_epoch_diff]
    print('lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters)
    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor, config.TRAIN.warmup, config.TRAIN.warmup_lr, config.TRAIN.warmup_step)
    # optimizer
    optimizer_params = {'momentum': config.TRAIN.momentum,
                        'wd': config.TRAIN.wd,
                        'learning_rate': lr,
                        'lr_scheduler': lr_scheduler,
                        'rescale_grad': 1.0,
                        'clip_gradient': None}

    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)

    # train
    mod.fit(train_data, eval_metric=eval_metrics, epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback, kvstore=config.default.kvstore,
            optimizer='sgd', optimizer_params=optimizer_params,
            arg_params=arg_params, aux_params=aux_params, begin_epoch=begin_epoch, num_epoch=end_epoch)
Example #21
0
def pred_eval(predictor,
              test_data,
              imdb,
              cfg,
              vis=False,
              thresh=1e-3,
              logger=None,
              ignore_cache=True):
    """
    Run the detector over ``test_data``, cache the detections and evaluate them.

    If ``<imdb.result_path>/<imdb.name>_detections.pkl`` already exists and
    ``ignore_cache`` is False, the cached detections are evaluated directly
    and no inference is run.

    :param predictor: Predictor, forwarded to ``im_detect``
    :param test_data: data iterator, must be non-shuffle (unless ``vis`` is set)
    :param imdb: image database; supplies ``result_path``, ``name``,
        ``num_images``, ``num_classes``, ``classes`` and
        ``evaluate_detections``
    :param cfg: configuration; reads ``TEST.NMS``, ``TEST.max_per_image``,
        ``network.PREDICT_KEYPOINTS`` and ``CLASS_AGNOSTIC``
    :param vis: controls visualization
    :param thresh: valid detection threshold
    :param logger: optional logger that receives progress and the final
        evaluation summary
    :param ignore_cache: when True (default) inference is always re-run, even
        if a detection cache file exists
    :return: None
    """

    det_file = os.path.join(imdb.result_path, imdb.name + '_detections.pkl')
    if os.path.exists(det_file) and not ignore_cache:
        # NOTE: unpickling can execute arbitrary code -- the cache is expected
        # to be a file written by a previous run of this function.
        with open(det_file, 'rb') as fid:
            cache_res = cPickle.load(fid)
        all_boxes = cache_res['all_boxes']
        all_keypoints = cache_res.get('all_keypoints')
        info_str = imdb.evaluate_detections(all_boxes,
                                            all_keypoints=all_keypoints)
        if logger:
            logger.info('evaluate detections: \n{}'.format(info_str))
        return

    # Results are stored by running image index, so the iteration order has to
    # be deterministic unless we are only visualizing.
    assert vis or not test_data.shuffle
    data_names = [k[0] for k in test_data.provide_data[0]]

    if not isinstance(test_data, PrefetchingIter):
        test_data = PrefetchingIter(test_data)

    nms = py_nms_wrapper(cfg.TEST.NMS)

    # limit detections to max_per_image over all classes
    max_per_image = cfg.TEST.max_per_image

    num_images = imdb.num_images
    # all detections are collected into:
    #    all_boxes[cls][image] = N x 5 array of detections in
    #    (x1, y1, x2, y2, score)
    all_boxes = [[np.array([]) for _ in range(num_images)]
                 for _ in range(imdb.num_classes)]
    all_keypoints = None
    if cfg.network.PREDICT_KEYPOINTS:
        all_keypoints = [[np.array([]) for _ in range(num_images)]
                         for _ in range(imdb.num_classes)]

    idx = 0
    data_time, net_time, post_time = 0.0, 0.0, 0.0
    t = time.time()
    for im_info, data_batch in test_data:
        # t1: time spent waiting for the data iterator
        t1 = time.time() - t
        t = time.time()

        scales = [iim_info[0, 2] for iim_info in im_info]
        rets = im_detect(predictor, data_batch, data_names, scales, cfg)
        scores_all = rets[0]
        boxes_all = rets[1]
        data_dict_all = rets[-1]
        if cfg.network.PREDICT_KEYPOINTS:
            pred_kps_all = rets[2]

        # t2: time spent in the network forward pass
        t2 = time.time() - t
        t = time.time()
        for delta, (scores, boxes, data_dict) in enumerate(
                zip(scores_all, boxes_all, data_dict_all)):
            image_idx = idx + delta

            # Class index 0 is skipped (presumably background -- confirm
            # against the imdb class list).
            for j in range(1, imdb.num_classes):
                indexes = np.where(scores[:, j] > thresh)[0]
                cls_scores = scores[indexes, j, np.newaxis]
                if cfg.CLASS_AGNOSTIC:
                    cls_boxes = boxes[indexes, 4:8]
                else:
                    cls_boxes = boxes[indexes, j * 4:(j + 1) * 4]
                cls_dets = np.hstack((cls_boxes, cls_scores))
                keep = nms(cls_dets)
                all_boxes[j][image_idx] = cls_dets[keep, :]
                if cfg.network.PREDICT_KEYPOINTS:
                    all_keypoints[j][image_idx] = pred_kps_all[delta][
                        indexes, :][keep, :]

            if max_per_image > 0:
                # Keep only the max_per_image highest scoring detections of
                # this image, pooled over all classes.
                image_scores = np.hstack([
                    all_boxes[j][image_idx][:, -1]
                    for j in range(1, imdb.num_classes)
                ])
                if len(image_scores) > max_per_image:
                    image_thresh = np.sort(image_scores)[-max_per_image]
                    for j in range(1, imdb.num_classes):
                        keep = np.where(
                            all_boxes[j][image_idx][:, -1] >= image_thresh)[0]
                        all_boxes[j][image_idx] = \
                            all_boxes[j][image_idx][keep, :]
                        if cfg.network.PREDICT_KEYPOINTS:
                            all_keypoints[j][image_idx] = \
                                all_keypoints[j][image_idx][keep, :]

            if vis:
                # Empty placeholder for class 0 keeps the list aligned with
                # imdb.classes.
                boxes_this_image = [[]] + [
                    all_boxes[j][image_idx]
                    for j in range(1, imdb.num_classes)
                ]
                vis_all_detection(data_dict['data'].asnumpy(),
                                  boxes_this_image, imdb.classes,
                                  scales[delta], cfg)

        idx += test_data.batch_size
        # t3: time spent in post-processing (NMS, capping, visualization)
        t3 = time.time() - t
        t = time.time()
        data_time += t1
        net_time += t2
        post_time += t3

        # Running per-batch averages; format once and reuse for stdout and
        # the logger.  A single-argument print(...) call behaves the same on
        # Python 2 and Python 3.
        progress_msg = (
            'testing {}/{} data {:.4f}s net {:.4f}s post {:.4f}s'.format(
                idx, imdb.num_images,
                data_time / idx * test_data.batch_size,
                net_time / idx * test_data.batch_size,
                post_time / idx * test_data.batch_size))
        print(progress_msg)
        if logger:
            logger.info(progress_msg)

    with open(det_file, 'wb') as f:
        cPickle.dump(
            {
                'all_boxes': all_boxes,
                'all_keypoints': all_keypoints
            },
            f,
            protocol=cPickle.HIGHEST_PROTOCOL)

    info_str = imdb.evaluate_detections(all_boxes, all_keypoints=all_keypoints)
    if logger:
        logger.info('evaluate detections: \n{}'.format(info_str))
Example #22
0
def train_net(args, ctx, pretrained, epoch, prefix, begin_epoch, end_epoch, lr,
              lr_step):
    """
    Train the pyramid-RPN / RCNN detector end to end on one or more datasets.

    Several datasets can be merged by separating the entries of
    ``config.dataset.dataset``, ``config.dataset.image_set``,
    ``config.dataset.dataset_path`` and ``config.dataset.per_category_epoch_max``
    with ``;``.  Within one dataset, image sets are separated with ``+``.

    :param args: parsed command line arguments; reads ``args.cfg``,
        ``args.frequent`` and ``args.debug``
    :param ctx: list of device contexts; one data slice per context
    :param pretrained: prefix of the pretrained model (used when not resuming)
    :param epoch: epoch of the pretrained model to load
    :param prefix: checkpoint prefix, placed under the run's output directory
    :param begin_epoch: epoch to start (or resume) from
    :param end_epoch: epoch to stop at, forwarded to ``mod.fit`` as ``num_epoch``
    :param lr: base learning rate
    :param lr_step: comma separated epochs at which the learning rate is
        multiplied by ``config.TRAIN.lr_factor``
    :return: None
    """
    mx.random.seed(3)
    np.random.seed(3)
    logger, final_output_path = create_logger(config.output_path, args.cfg,
                                              config.dataset.image_set)
    prefix = os.path.join(final_output_path, prefix)

    config['final_output_path'] = final_output_path

    # load symbol
    # Keep a copy of the symbol definition next to the run's outputs.
    shutil.copy2(os.path.join(curr_path, 'symbols', config.symbol + '.py'),
                 final_output_path)
    # NOTE: eval() on the configured symbol name -- the config file must be
    # trusted.
    sym_instance = eval(config.symbol + '.' + config.symbol)()
    sym = sym_instance.get_symbol(config, is_train=True)

    # One RPN classification output per pyramid level; the level index is
    # log2 of the corresponding feature stride.
    feat_pyramid_level = np.log2(config.network.RPN_FEAT_STRIDE).astype(int)
    feat_sym = [
        sym.get_internals()['rpn_cls_score_p' + str(x) + '_output']
        for x in feat_pyramid_level
    ]

    # setup multi-gpu
    batch_size = len(ctx)
    input_batch_size = config.TRAIN.BATCH_IMAGES * batch_size

    # print config
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # leonid: semicolon ";" support to allow several different datasets to be
    # merged
    datasets = config.dataset.dataset.split(';')
    image_sets = config.dataset.image_set.split(';')
    data_paths = config.dataset.dataset_path.split(';')
    if isinstance(config.dataset.per_category_epoch_max, str):
        per_category_epoch_max = [
            float(x) for x in config.dataset.per_category_epoch_max.split(';')
        ]
    else:
        per_category_epoch_max = [float(config.dataset.per_category_epoch_max)]

    # Optional dataset settings default to the empty string when absent.
    if 'classes_list_fname' in config.dataset:
        classes_list_fname = config.dataset.classes_list_fname
    else:
        classes_list_fname = ''

    if 'num_ex_per_class' in config.dataset:
        num_ex_per_class = config.dataset.num_ex_per_class
    else:
        num_ex_per_class = ''

    roidbs = []
    # Category indices of each further dataset are shifted past those of the
    # datasets loaded before it.
    categ_index_offs = 0
    for iD, dataset in enumerate(datasets):
        # load dataset and prepare imdb for training
        for image_set in image_sets[iD].split('+'):
            cur_roidb, cur_num_classes = load_gt_roidb(
                dataset,
                image_set,
                config.dataset.root_path,
                data_paths[iD],
                flip=config.TRAIN.FLIP,
                per_category_epoch_max=per_category_epoch_max[iD],
                return_num_classes=True,
                categ_index_offs=categ_index_offs,
                classes_list_fname=classes_list_fname,
                num_ex_per_class=num_ex_per_class)

            roidbs.append(cur_roidb)
        # Advanced once per dataset, using the class count reported for its
        # last image set.
        categ_index_offs += cur_num_classes
    roidb = merge_roidb(roidbs)
    roidb = filter_roidb(roidb, config)

    # load training data

    train_data = PyramidAnchorIterator(
        feat_sym,
        roidb,
        config,
        batch_size=input_batch_size,
        shuffle=config.TRAIN.SHUFFLE,
        ctx=ctx,
        feat_strides=config.network.RPN_FEAT_STRIDE,
        anchor_scales=config.network.ANCHOR_SCALES,
        anchor_ratios=config.network.ANCHOR_RATIOS,
        aspect_grouping=config.TRAIN.ASPECT_GROUPING,
        allowed_border=np.inf)

    # infer max shape
    max_data_shape = [('data', (config.TRAIN.BATCH_IMAGES, 3,
                                max([v[0] for v in config.SCALES]),
                                max([v[1] for v in config.SCALES])))]
    max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape)
    max_data_shape.append(('gt_boxes', (config.TRAIN.BATCH_IMAGES, 100, 5)))
    # Single formatted argument: same text as the former Python 2 print
    # statement, and valid syntax on Python 3.
    print('providing maximum shape {} {}'.format(max_data_shape,
                                                 max_label_shape))

    # With a locked base net no label inputs are fed to the symbol.
    if not config.network.base_net_lock:
        data_shape_dict = dict(train_data.provide_data_single +
                               train_data.provide_label_single)
    else:
        data_shape_dict = dict(train_data.provide_data_single)
    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)

    # load and initialize params
    if config.TRAIN.RESUME:
        print('continue training from ', begin_epoch)
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        sym_instance.init_weight(config, arg_params, aux_params)

    if config.TRAIN.LOAD_EMBEDDING:
        try:
            import cPickle  # Python 2
        except ImportError:
            import pickle as cPickle  # Python 3
        # NOTE: unpickling can execute arbitrary code -- EMBEDDING_FNAME must
        # point at a trusted file.
        with open(config.TRAIN.EMBEDDING_FNAME, 'rb') as fid:
            model_data = cPickle.load(fid)
        # Overwrite the three embedding dense layers with the stored
        # (weight, bias) pairs.
        for fcn in ['1', '2', '3']:
            layer = model_data['dense_' + fcn]
            weight = ListList2ndarray(layer[0])
            bias = mx.nd.array(layer[1])
            arg_params['embed_dense_' + fcn + '_weight'] = weight
            arg_params['embed_dense_' + fcn + '_bias'] = bias

    # check parameter shapes
    sym_instance.check_parameter_shapes(arg_params, aux_params,
                                        data_shape_dict)

    # create solver
    fixed_param_prefix = config.network.FIXED_PARAMS
    alt_fixed_param_prefix = config.network.ALT_FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data_single]
    if not config.network.base_net_lock:
        label_names = [k[0] for k in train_data.provide_label_single]
    else:
        label_names = []

    mod = MutableModule(
        sym,
        data_names=data_names,
        label_names=label_names,
        logger=logger,
        context=ctx,
        max_data_shapes=[max_data_shape for _ in range(batch_size)],
        max_label_shapes=[max_label_shape for _ in range(batch_size)],
        fixed_param_prefix=fixed_param_prefix,
        alt_fixed_param_prefix=alt_fixed_param_prefix)

    # Leonid: Comment out the following two lines if switching to smaller number of GPUs and resuming training, then after it starts running un-comment back
    # if config.TRAIN.RESUME:
    #     mod._preload_opt_states = '%s-%04d.states'%(prefix, begin_epoch)
    # TODO: release this.

    # decide training params
    # metric
    # The RPN accuracy / loss metrics are only tracked when the base net is
    # not locked.
    if not config.network.base_net_lock:
        rpn_eval_metric = metric.RPNAccMetric()
        rpn_cls_metric = metric.RPNLogLossMetric()
        rpn_bbox_metric = metric.RPNL1LossMetric()
    rpn_fg_metric = metric.RPNFGFraction(config)
    eval_metric = metric.RCNNAccMetric(config)
    eval_fg_metric = metric.RCNNFGAccuracy(config)
    cls_metric = metric.RCNNLogLossMetric(config)
    bbox_metric = metric.RCNNL1LossMetric(config)
    eval_metrics = mx.metric.CompositeEvalMetric()

    if not config.network.base_net_lock:
        all_child_metrics = [
            rpn_eval_metric, rpn_cls_metric, rpn_bbox_metric, rpn_fg_metric,
            eval_fg_metric, eval_metric, cls_metric, bbox_metric
        ]
    else:
        all_child_metrics = [
            rpn_fg_metric, eval_fg_metric, eval_metric, cls_metric, bbox_metric
        ]

    ################################################
    ### added / updated by Leonid to support oneshot
    ################################################
    if config.network.EMBEDDING_DIM != 0:
        if config.network.EMBED_LOSS_ENABLED:
            all_child_metrics += [
                metric.RepresentativesMetric(config, final_output_path)
            ]  # moved from above. JS.
            all_child_metrics += [metric.EmbedMetric(config)]
            if config.network.BG_REPS:
                all_child_metrics += [metric.BGModelMetric(config)]
        if config.network.REPS_CLS_LOSS:
            all_child_metrics += [metric.RepsCLSMetric(config)]
        if config.network.ADDITIONAL_LINEAR_CLS_LOSS:
            all_child_metrics += [metric.RCNNLinLogLossMetric(config)]
        if config.network.VAL_FILTER_REGRESS:
            all_child_metrics += [metric.ValRegMetric(config)]
        if config.network.SCORE_HIST_REGRESS:
            all_child_metrics += [metric.ScoreHistMetric(config)]
    ################################################

    for child_metric in all_child_metrics:
        eval_metrics.add(child_metric)

    # callback
    batch_end_callback = callback.Speedometer(train_data.batch_size,
                                              frequent=args.frequent)
    # Bbox normalisation statistics, tiled once per regression class
    # (2 when class agnostic).
    means = np.tile(np.array(config.TRAIN.BBOX_MEANS),
                    2 if config.CLASS_AGNOSTIC else config.dataset.NUM_CLASSES)
    stds = np.tile(np.array(config.TRAIN.BBOX_STDS),
                   2 if config.CLASS_AGNOSTIC else config.dataset.NUM_CLASSES)
    epoch_end_callback = [
        mx.callback.module_checkpoint(mod,
                                      prefix,
                                      period=1,
                                      save_optimizer_states=True),
        callback.do_checkpoint(prefix, means, stds)
    ]

    # decide learning rate
    base_lr = lr
    lr_factor = config.TRAIN.lr_factor
    lr_epoch = [float(step) for step in lr_step.split(',')]
    # Decay steps still ahead of us, relative to begin_epoch.
    lr_epoch_diff = [
        step - begin_epoch for step in lr_epoch if step > begin_epoch
    ]
    # Steps already passed (when resuming) are applied to the base rate
    # up front.
    lr = base_lr * (lr_factor**(len(lr_epoch) - len(lr_epoch_diff)))
    # NOTE(review): batch_size here is len(ctx), not input_batch_size, so the
    # step positions presumably only match real iteration counts when
    # config.TRAIN.BATCH_IMAGES == 1 -- confirm intent before changing.
    lr_iters = [
        int(step * len(roidb) / batch_size) for step in lr_epoch_diff
    ]
    print('lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters)
    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor,
                                              config.TRAIN.warmup,
                                              config.TRAIN.warmup_lr,
                                              config.TRAIN.warmup_step)
    # optimizer
    optimizer_params = {
        'momentum': config.TRAIN.momentum,
        'wd': config.TRAIN.wd,
        'learning_rate': lr,
        'lr_scheduler': lr_scheduler,
        'clip_gradient': None
    }

    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)

    if args.debug == 1:
        # Snapshot the parameters so they can be compared after training.
        import copy
        arg_params_ = copy.deepcopy(arg_params)
        aux_params_ = copy.deepcopy(aux_params)

    # train
    mod.fit(train_data,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            kvstore=config.default.kvstore,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            arg_params=arg_params,
            aux_params=aux_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            config=config)

    if args.debug == 1:
        dictCompare(aux_params_, aux_params)
        dictCompare(arg_params_, arg_params)
Example #23
0
def train_net(args, ctx, pretrained, epoch, prefix, begin_epoch, end_epoch, lr,
              lr_step):
    """
    Train the FCIS instance segmentation network end to end.

    Ground-truth sdsdb entries are cached as pickles under
    ``<root_path>/cache/COCOMask/<split>/gt_sdsdb.pkl`` and reused on later
    runs.  Only the image sets ``train2014`` and ``valminusminival2014`` are
    handled; any other entry of ``config.dataset.image_set`` is skipped (a
    warning is logged).

    :param args: parsed command line arguments; reads ``args.cfg`` and
        ``args.frequent``
    :param ctx: list of device contexts; one data slice per context
    :param pretrained: prefix of the pretrained model (used when not resuming)
    :param epoch: epoch of the pretrained model to load
    :param prefix: checkpoint prefix, placed under the run's output directory
    :param begin_epoch: epoch to start (or resume) from
    :param end_epoch: epoch to stop at, forwarded to ``mod.fit`` as ``num_epoch``
    :param lr: base learning rate
    :param lr_step: comma separated epochs at which the learning rate is
        multiplied by 0.1
    :return: None
    """
    mx.random.seed(3)
    np.random.seed(3)

    logger, final_output_path = create_logger(config.output_path, args.cfg,
                                              config.dataset.image_set)
    prefix = os.path.join(final_output_path, prefix)

    # load symbol
    # Keep a copy of the symbol definition next to the run's outputs.
    shutil.copy2(os.path.join(curr_path, 'symbols', config.symbol + '.py'),
                 final_output_path)
    # NOTE: eval() on the configured symbol name -- the config file must be
    # trusted.
    sym_instance = eval(config.symbol)()
    sym = sym_instance.get_symbol(config, is_train=True)

    # Dump the network graph and the names of all internal outputs for
    # inspection.
    dot = mx.viz.plot_network(sym)
    dot.render('graph/nn.gv', view=False)
    all_layers = sym.get_internals().list_outputs()
    # 'w' truncates an existing file, which is what the former
    # remove-then-append sequence achieved.
    with open('graph/nodes.txt', 'w') as f:
        for layer in all_layers:
            f.write(layer + '\n')

    feat_sym = sym.get_internals()['rpn_cls_score_output']

    # setup multi-gpu
    batch_size = len(ctx)
    input_batch_size = config.TRAIN.BATCH_IMAGES * batch_size

    # print config
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # load dataset and prepare imdb for training
    # Supported image sets, mapped to the cache sub-directory that holds
    # their pickled ground-truth sdsdb.
    cache_subdir_of = {
        'train2014': 'train2014',
        'valminusminival2014': 'val2014',
    }
    sdsdbs = []
    for image_set in config.dataset.image_set.split('+'):
        if image_set not in cache_subdir_of:
            # Unknown sets were always skipped; make that visible.
            logger.warning(
                'image set {!r} is not supported by this script and is '
                'skipped'.format(image_set))
            continue

        gt_sdsdb_file = os.path.join(config.dataset.root_path, 'cache',
                                     'COCOMask', cache_subdir_of[image_set],
                                     'gt_sdsdb.pkl')
        if os.path.exists(gt_sdsdb_file):
            # NOTE: unpickling can execute arbitrary code -- the cache is
            # expected to be a file written by a previous run.
            with open(gt_sdsdb_file, 'rb') as f:
                sdsdbs.append(pkl.load(f))
        else:
            cur_sdsdb = load_gt_sdsdb(
                config.dataset.dataset,
                image_set,
                config.dataset.root_path,
                config.dataset.dataset_path,
                mask_size=config.MASK_SIZE,
                binary_thresh=config.BINARY_THRESH,
                result_path=final_output_path,
                flip=config.TRAIN.FLIP)
            with open(gt_sdsdb_file, 'wb') as f:
                pkl.dump(cur_sdsdb, f, protocol=pkl.HIGHEST_PROTOCOL)
            sdsdbs.append(cur_sdsdb)

    sdsdb = merge_roidb(sdsdbs)
    sdsdb = filter_roidb(sdsdb, config)

    # load training data
    train_data = AnchorLoader(feat_sym,
                              sdsdb,
                              config,
                              batch_size=input_batch_size,
                              shuffle=config.TRAIN.SHUFFLE,
                              ctx=ctx,
                              feat_stride=config.network.RPN_FEAT_STRIDE,
                              anchor_scales=config.network.ANCHOR_SCALES,
                              anchor_ratios=config.network.ANCHOR_RATIOS,
                              aspect_grouping=config.TRAIN.ASPECT_GROUPING,
                              allowed_border=config.TRAIN.RPN_ALLOWED_BORDER)

    # infer max shape
    max_scale_0 = max([v[0] for v in config.SCALES])
    max_scale_1 = max(v[1] for v in config.SCALES)
    max_data_shape = [('data', (config.TRAIN.BATCH_IMAGES, 3, max_scale_0,
                                max_scale_1))]
    max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape)
    max_data_shape.append(('gt_boxes', (config.TRAIN.BATCH_IMAGES, 100, 5)))
    max_data_shape.append(('gt_masks', (config.TRAIN.BATCH_IMAGES, 100,
                                        max_scale_0, max_scale_1)))
    # Single formatted argument: same text as the former Python 2 print
    # statements, and valid syntax on Python 3 (applies to every print below).
    print('providing maximum shape {} {}'.format(max_data_shape,
                                                 max_label_shape))

    # infer shape
    data_shape_dict = dict(train_data.provide_data_single +
                           train_data.provide_label_single)
    print('data shape:')
    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)

    # load and initialize params
    if config.TRAIN.RESUME:
        print('continue training from  {}'.format(begin_epoch))
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        sym_instance.init_weight(config, arg_params, aux_params)

    # check parameter shapes
    sym_instance.check_parameter_shapes(arg_params, aux_params,
                                        data_shape_dict)

    # create solver
    fixed_param_prefix = config.network.FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data_single]
    label_names = [k[0] for k in train_data.provide_label_single]

    mod = MutableModule(
        sym,
        data_names=data_names,
        label_names=label_names,
        logger=logger,
        context=ctx,
        max_data_shapes=[max_data_shape for _ in range(batch_size)],
        max_label_shapes=[max_label_shape for _ in range(batch_size)],
        fixed_param_prefix=fixed_param_prefix)

    # decide training metric
    # RPN, classification accuracy/loss, regression loss
    eval_metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [
            metric.RPNAccMetric(),
            metric.RPNLogLossMetric(),
            metric.RPNL1LossMetric(),
            metric.FCISAccMetric(config),
            metric.FCISAccFGMetric(config),
            metric.FCISLogLossMetric(config),
            metric.FCISL1LossMetric(config),
            metric.FCISMaskLossMetric(config),
    ]:
        eval_metrics.add(child_metric)

    batch_end_callback = callback.Speedometer(train_data.batch_size,
                                              frequent=args.frequent)
    # Bbox normalisation statistics, tiled once per regression class
    # (2 when class agnostic).
    means = np.tile(np.array(config.TRAIN.BBOX_MEANS),
                    2 if config.CLASS_AGNOSTIC else config.dataset.NUM_CLASSES)
    stds = np.tile(np.array(config.TRAIN.BBOX_STDS),
                   2 if config.CLASS_AGNOSTIC else config.dataset.NUM_CLASSES)
    epoch_end_callback = callback.do_checkpoint(prefix, means, stds)

    # decide learning rate
    base_lr = lr
    lr_factor = 0.1
    lr_epoch = [float(step) for step in lr_step.split(',')]
    # Decay steps still ahead of us, relative to begin_epoch.
    lr_epoch_diff = [
        step - begin_epoch for step in lr_epoch if step > begin_epoch
    ]
    # Steps already passed (when resuming) are applied to the base rate
    # up front.
    lr = base_lr * (lr_factor**(len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [
        int(step * len(sdsdb) / batch_size) for step in lr_epoch_diff
    ]
    print('lr {} lr_epoch_diff {} lr_iters {}'.format(lr, lr_epoch_diff,
                                                      lr_iters))
    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor,
                                              config.TRAIN.warmup,
                                              config.TRAIN.warmup_lr,
                                              config.TRAIN.warmup_step)
    # optimizer
    optimizer_params = {
        'momentum': config.TRAIN.momentum,
        'wd': config.TRAIN.wd,
        'learning_rate': lr,
        'lr_scheduler': lr_scheduler,
        'rescale_grad': 1.0,
        'clip_gradient': None
    }

    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)

    # The same checkpoint callback is used at the end of every epoch and is
    # also handed to fit() as `batches_checkpoint` together with
    # num_batches_save_ckpt=2000.
    mod.fit(train_data,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            kvstore=config.default.kvstore,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            arg_params=arg_params,
            aux_params=aux_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            batches_checkpoint=epoch_end_callback,
            num_batches_save_ckpt=2000)
Example #24
0
def train_net(args, ctx, pretrained, epoch, prefix, begin_epoch, end_epoch, lr,
              lr_step):
    """
    Train the pyramid-RPN / RCNN detector with a hand-written Gluon loop.

    Unlike the ``Module.fit`` based variants, this function wraps the symbol
    in ``FPNNet``, runs it through ``DataParallelModel`` and drives the
    optimisation with ``mx.gluon.Trainer``.  Parameters are saved after every
    epoch to ``<prefix>-<epoch>-0.0.params``.

    :param args: parsed command line arguments; reads ``args.cfg``
    :param ctx: list of device contexts; one data slice per context
    :param pretrained: prefix of the pretrained model (used when not resuming)
    :param epoch: epoch of the pretrained model to load
    :param prefix: checkpoint prefix, placed under the run's output directory
    :param begin_epoch: first epoch to run (or resume from)
    :param end_epoch: epoch to stop at (exclusive)
    :param lr: base learning rate
    :param lr_step: comma separated epochs at which the learning rate is
        multiplied by ``config.TRAIN.lr_factor``
    :return: None
    """
    mx.random.seed(3)
    np.random.seed(3)
    logger, final_output_path = create_logger(config.output_path, args.cfg,
                                              config.dataset.image_set)
    prefix = os.path.join(final_output_path, prefix)

    # load symbol
    # Keep a copy of the symbol definition next to the run's outputs.
    shutil.copy2(os.path.join(curr_path, 'symbols', config.symbol + '.py'),
                 final_output_path)
    # NOTE: eval() on the configured symbol name -- the config file must be
    # trusted.
    sym_instance = eval(config.symbol + '.' + config.symbol)()
    sym = sym_instance.get_symbol(config, is_train=True)

    # One RPN classification output per pyramid level; the level index is
    # log2 of the corresponding feature stride.
    feat_pyramid_level = np.log2(config.network.RPN_FEAT_STRIDE).astype(int)
    feat_sym = [
        sym.get_internals()['rpn_cls_score_p' + str(x) + '_output']
        for x in feat_pyramid_level
    ]

    # setup multi-gpu
    batch_size = len(ctx)
    input_batch_size = config.TRAIN.BATCH_IMAGES * batch_size

    # print config
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # load dataset and prepare imdb for training
    roidbs = [
        load_gt_roidb(config.dataset.dataset,
                      image_set,
                      config.dataset.root_path,
                      config.dataset.dataset_path,
                      flip=config.TRAIN.FLIP)
        for image_set in config.dataset.image_set.split('+')
    ]
    roidb = merge_roidb(roidbs)
    roidb = filter_roidb(roidb, config)

    # load training data

    train_data = PyramidAnchorIterator(
        feat_sym,
        roidb,
        config,
        batch_size=input_batch_size,
        shuffle=config.TRAIN.SHUFFLE,
        ctx=ctx,
        feat_strides=config.network.RPN_FEAT_STRIDE,
        anchor_scales=config.network.ANCHOR_SCALES,
        anchor_ratios=config.network.ANCHOR_RATIOS,
        aspect_grouping=config.TRAIN.ASPECT_GROUPING,
        allowed_border=np.inf)

    # infer max shape
    max_data_shape = [('data', (config.TRAIN.BATCH_IMAGES, 3,
                                max([v[0] for v in config.SCALES]),
                                max([v[1] for v in config.SCALES])))]
    max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape)
    max_data_shape.append(('gt_boxes', (config.TRAIN.BATCH_IMAGES, 100, 5)))
    # Single formatted argument: same text as the former Python 2 print
    # statement, and valid syntax on Python 3.
    print('providing maximum shape {} {}'.format(max_data_shape,
                                                 max_label_shape))

    data_shape_dict = dict(train_data.provide_data_single +
                           train_data.provide_label_single)
    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)

    # load and initialize params
    # (weight initialisation and the parameter shape check used by the
    # Module based variants are disabled in this script)
    if config.TRAIN.RESUME:
        print('continue training from ', begin_epoch)
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)

    # decide training params
    # metric
    eval_metrics = mx.metric.CompositeEvalMetric()
    for child_metric in [
            metric.RPNAccMetric(),
            metric.RPNLogLossMetric(),
            metric.RPNL1LossMetric(),
            metric.RPNFGFraction(config),
            metric.RCNNFGAccuracy(config),
            metric.RCNNAccMetric(config),
            metric.RCNNLogLossMetric(config),
            metric.RCNNL1LossMetric(config),
    ]:
        eval_metrics.add(child_metric)

    # decide learning rate
    base_lr = lr
    lr_factor = config.TRAIN.lr_factor
    lr_epoch = [float(step) for step in lr_step.split(',')]
    # Decay steps still ahead of us, relative to begin_epoch.
    lr_epoch_diff = [
        step - begin_epoch for step in lr_epoch if step > begin_epoch
    ]
    # Steps already passed (when resuming) are applied to the base rate
    # up front.
    lr = base_lr * (lr_factor**(len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [
        int(step * len(roidb) / batch_size) for step in lr_epoch_diff
    ]
    print('lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters)
    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor,
                                              config.TRAIN.warmup,
                                              config.TRAIN.warmup_lr,
                                              config.TRAIN.warmup_step)
    # optimizer
    optimizer_params = {
        'momentum': config.TRAIN.momentum,
        'wd': config.TRAIN.wd,
        'learning_rate': lr,
        'lr_scheduler': lr_scheduler,
        'clip_gradient': None
    }
    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)

    net = FPNNet(sym, args_pretrained=arg_params, auxes_pretrained=aux_params)

    # create multi-threaded DataParallel Model.
    net_parallel = DataParallelModel(net, ctx_list=ctx)

    # create trainer,
    # !Important: A trainer can be only created after the function `reset_ctx` is called.
    # Please Note that DataParallelModel will call reset_ctx to initialize parameters on gpus.
    trainer = mx.gluon.Trainer(net.collect_params(), 'sgd', optimizer_params)

    # Honour the end_epoch argument (previously ignored in favour of
    # config.TRAIN.end_epoch); the loop variable no longer shadows the
    # `epoch` parameter.
    for cur_epoch in range(begin_epoch, end_epoch):
        train_data.reset()
        net.hybridize(static_alloc=True, static_shape=False)
        progress_bar = tqdm.tqdm(total=len(roidb))
        for nbatch, data_batch in enumerate(train_data):
            # One input list per context: that context's data followed by
            # its labels, cast to float32 and moved onto the device.
            inputs = [[
                x.astype('f').as_in_context(c) for x in d + l
            ] for c, d, l in zip(ctx, data_batch.data, data_batch.label)]
            with ag.record():
                outputs = net_parallel(*inputs)
                # Flatten the per-context output tuples and back-propagate
                # through all of them at once.
                ag.backward(sum(outputs, ()))
            trainer.step(1)
            # Metrics are updated from the first context only.
            eval_metrics.update(data_batch.label[0], outputs[0])
            if nbatch % 100 == 0:
                msg = ','.join([
                    '{}={:.3f}'.format(w, v)
                    for w, v in zip(*eval_metrics.get())
                ])
                msg += ",lr={}".format(trainer.learning_rate)
                logger.info(msg)
                print(msg)
                eval_metrics.reset()
            progress_bar.update(len(inputs))
        progress_bar.close()
        net.hybridize(static_alloc=True, static_shape=False)
        # No validation is run here; a placeholder score is logged and
        # embedded in the checkpoint file name.
        eval_result = ("mAP", 0.0)
        logger.info(eval_result)
        save_path = "{}-{}-{}.params".format(prefix, cur_epoch,
                                             eval_result[1])
        net.collect_params().save(save_path)
        logger.info("Saved checkpoint to {}.".format(save_path))
def _double_eval_collect(scores, boxes, data_dict, store, image_idx, scale,
                         imdb, cfg, nms, thresh, target_key, gt_key):
    """Write the per-class detections of one image into ``store``.

    :param scores: per-box scores of the image; indexed ``[:, j - 1, 0]`` when
        ``cfg.TEST.LEARN_NMS`` is set and ``[:, j]`` otherwise
    :param boxes: per-box coordinates matching ``scores``
    :param data_dict: network input dict of the image (only read when DEBUG)
    :param store: ``store[cls][image]`` table that receives the detections
    :param image_idx: column of ``store`` to fill
    :param scale: image scale used to map ground truth back (DEBUG only)
    :param nms: NMS callable built by the caller
    :param thresh: minimum score for a box to be kept
    :param target_key: ``data_dict`` key of the NMS targets (DEBUG only)
    :param gt_key: ``data_dict`` key of the ground-truth boxes (DEBUG only)
    :return: ``(full_count, pos_count, target_count)``; all zero unless the
        module-level ``DEBUG`` flag and ``cfg.TEST.LEARN_NMS`` are both set
    """
    full_count = 0
    pos_count = 0
    target_count = 0

    if cfg.TEST.LEARN_NMS:
        for j in range(1, imdb.num_classes):
            indexes = np.where(scores[:, j - 1, 0] > thresh)[0]
            cls_scores = scores[indexes, j - 1, :]
            cls_boxes = boxes[indexes, j - 1, :]
            cls_dets = np.hstack((cls_boxes, cls_scores))
            # the learned-NMS output is stored as is, no extra NMS is applied
            store[j][image_idx] = cls_dets

            if DEBUG:
                # how many NMS targets does plain NMS on all boxes recover?
                keep = nms(cls_dets)
                nms_cls_dets = cls_dets[keep, :]
                target = data_dict[target_key]
                target = target[np.where(target[:, 4] == j - 1)]
                full_count += bbox_equal_count(nms_cls_dets, target)

                gt_boxes = data_dict[gt_key][0].asnumpy()
                gt_boxes = gt_boxes[np.where(gt_boxes[:, 4] == j)[0], :4]
                gt_boxes /= scale

                if len(cls_boxes) != 0 and len(gt_boxes) != 0:
                    # same question, restricted to boxes overlapping a
                    # ground-truth box by more than 0.5
                    overlap_mat = bbox_overlaps(cls_boxes.astype(float),
                                                gt_boxes.astype(float))
                    matched_rows = np.where(overlap_mat > 0.5)[0]
                    keep = nms(cls_dets[matched_rows])
                    nms_cls_dets = cls_dets[matched_rows][keep]
                    pos_count += bbox_equal_count(nms_cls_dets, target)
                target_count += len(target)
    else:
        for j in range(1, imdb.num_classes):
            indexes = np.where(scores[:, j] > thresh)[0]
            if cfg.TEST.FIRST_N > 0:
                # todo: check whether the order affects the result
                sort_indices = np.argsort(scores[:, j])[-cfg.TEST.FIRST_N:]
                indexes = np.intersect1d(sort_indices, indexes)

            cls_scores = scores[indexes, j, np.newaxis]
            if cfg.CLASS_AGNOSTIC:
                cls_boxes = boxes[indexes, 4:8]
            else:
                cls_boxes = boxes[indexes, j * 4:(j + 1) * 4]
            cls_dets = np.hstack((cls_boxes, cls_scores))
            if cfg.TEST.SOFTNMS:
                # the soft-NMS wrapper's return value is stored directly
                store[j][image_idx] = nms(cls_dets)
            else:
                keep = nms(cls_dets)
                store[j][image_idx] = cls_dets[keep, :]

    return full_count, pos_count, target_count


def _double_eval_keep_top(store, image_idx, num_classes, max_per_image):
    """Keep only the ``max_per_image`` best-scoring boxes of one image."""
    if max_per_image <= 0:
        return
    image_scores = np.hstack(
        [store[j][image_idx][:, -1] for j in range(1, num_classes)])
    if len(image_scores) > max_per_image:
        image_thresh = np.sort(image_scores)[-max_per_image]
        for j in range(1, num_classes):
            keep = np.where(store[j][image_idx][:, -1] >= image_thresh)[0]
            store[j][image_idx] = store[j][image_idx][keep, :]


def _double_eval_vis_boxes(store, image_idx, data_dict, scale, imdb, cfg,
                           show_gt, target_key, gt_key, label):
    """Build the per-class box list of one image for visualization.

    When ``show_gt`` is set, ground-truth boxes (and, with
    ``cfg.TEST.LEARN_NMS``, the NMS targets) are stacked under their class.
    Note that the target rows of ``data_dict[target_key]`` are rewritten in
    place while doing so.
    """
    boxes_this_image = [[]] + [
        store[j][image_idx] for j in range(1, imdb.num_classes)
    ]
    if show_gt:
        gt_boxes = data_dict[gt_key][0]
        for gt_box in gt_boxes:
            gt_box = gt_box.asnumpy()
            gt_cls = int(gt_box[4])
            gt_box = gt_box / scale
            # column 4 held the class id; it is overwritten with 1
            gt_box[4] = 1
            if cfg.TEST.LEARN_NMS:
                gt_box = np.append(gt_box, 1)
            boxes_this_image[gt_cls] = np.vstack(
                (boxes_this_image[gt_cls], gt_box))

        if cfg.TEST.LEARN_NMS:
            target_boxes = data_dict[target_key]
            for target_box in target_boxes:
                print(label, target_box * scale)
                target_cls = int(target_box[4]) + 1
                target_box[4] = 2 + target_box[5]
                target_box[5] = target_box[6]
                target_box = target_box[:6]
                boxes_this_image[target_cls] = np.vstack(
                    (boxes_this_image[target_cls], target_box))
    return boxes_this_image


def pred_double_eval(predictor,
                     test_data,
                     imdb,
                     cfg,
                     vis=False,
                     thresh=1e-3,
                     logger=None,
                     ignore_cache=True,
                     show_gt=False):
    """
    wrapper for calculating offline validation for faster data analysis
    in this example, all threshold are set by hand

    Each batch is run through ``im_double_detect``, which yields detections
    for a current image and a reference image. Both are thresholded, passed
    through NMS (unless ``cfg.TEST.LEARN_NMS`` is set) and capped to
    ``cfg.TEST.max_per_image`` boxes per image.

    :param predictor: Predictor
    :param test_data: data iterator, must be non-shuffle
    :param imdb: image database
    :param cfg: experiment configuration
    :param vis: controls visualization
    :param thresh: valid detection threshold
    :param logger: optional logger that receives progress messages
    :param ignore_cache: when False and a cached detection file exists, the
        cached detections are evaluated and nothing is recomputed
    :param show_gt: when visualizing, also draw the ground-truth boxes
    :return:
    """
    global num_of_is_full_max

    det_file = os.path.join(imdb.result_path, imdb.name + '_detections.pkl')
    if os.path.exists(det_file) and not ignore_cache:
        # NOTE: unpickling runs arbitrary code; only load trusted cache files.
        with open(det_file, 'rb') as fid:
            all_boxes = cPickle.load(fid)
        info_str = imdb.evaluate_detections(all_boxes)
        if logger:
            logger.info('evaluate detections: \n{}'.format(info_str))
        return

    assert vis or not test_data.shuffle
    data_names = [k[0] for k in test_data.provide_data[0]]
    num_images = test_data.size

    if not isinstance(test_data, PrefetchingIter):
        test_data = PrefetchingIter(test_data)

    if cfg.TEST.SOFTNMS:
        nms = py_softnms_wrapper(cfg.TEST.NMS)
    else:
        nms = py_nms_wrapper(cfg.TEST.NMS)

    # limit detections to max_per_image over all classes
    max_per_image = cfg.TEST.max_per_image

    # all detections are collected into:
    #    all_boxes[cls][image] = N x 5 array of detections in
    #    (x1, y1, x2, y2, score)
    all_boxes = [[[] for _ in range(num_images)]
                 for _ in range(imdb.num_classes)]
    ref_all_boxes = [[[] for _ in range(num_images)]
                     for _ in range(imdb.num_classes)]

    idx = 0
    t = time.time()
    inference_count = 0
    all_inference_time = []
    post_processing_time = []
    nms_full_count = []
    nms_pos_count = []
    is_max_count = []
    all_count = []
    for im_info, data_batch in test_data:
        t1 = time.time() - t
        t = time.time()

        scales = [iim_info[0, 2] for iim_info in im_info]
        (scores_all, boxes_all, ref_scores_all, ref_boxes_all,
         data_dict_all) = im_double_detect(predictor, data_batch, data_names,
                                           scales, cfg)

        t2 = time.time() - t
        t = time.time()
        nms_full_count_per_batch = 0
        nms_pos_count_per_batch = 0
        is_max_count_per_batch = num_of_is_full_max[0]
        all_count_per_batch = 0
        for delta, (scores, boxes, ref_scores, ref_boxes,
                    data_dict) in enumerate(
                        zip(scores_all, boxes_all, ref_scores_all,
                            ref_boxes_all, data_dict_all)):
            image_idx = idx + delta
            scale = scales[delta]

            # current image
            full_cnt, pos_cnt, target_cnt = _double_eval_collect(
                scores, boxes, data_dict, all_boxes, image_idx, scale, imdb,
                cfg, nms, thresh, 'nms_multi_target', 'gt_boxes')
            nms_full_count_per_batch += full_cnt
            nms_pos_count_per_batch += pos_cnt
            all_count_per_batch += target_cnt
            _double_eval_keep_top(all_boxes, image_idx, imdb.num_classes,
                                  max_per_image)
            if vis:
                boxes_this_image = _double_eval_vis_boxes(
                    all_boxes, image_idx, data_dict, scale, imdb, cfg,
                    show_gt, 'nms_multi_target', 'gt_boxes', "cur")

            # reference image
            full_cnt, pos_cnt, target_cnt = _double_eval_collect(
                ref_scores, ref_boxes, data_dict, ref_all_boxes, image_idx,
                scale, imdb, cfg, nms, thresh, 'ref_nms_multi_target',
                'ref_gt_boxes')
            nms_full_count_per_batch += full_cnt
            nms_pos_count_per_batch += pos_cnt
            all_count_per_batch += target_cnt
            _double_eval_keep_top(ref_all_boxes, image_idx, imdb.num_classes,
                                  max_per_image)
            if vis:
                ref_boxes_this_image = _double_eval_vis_boxes(
                    ref_all_boxes, image_idx, data_dict, scale, imdb, cfg,
                    show_gt, 'ref_nms_multi_target', 'ref_gt_boxes', "ref")
                vis_double_all_detection(data_dict['data'][0:1].asnumpy(),
                                         boxes_this_image,
                                         data_dict['data'][1:2].asnumpy(),
                                         ref_boxes_this_image, imdb.classes,
                                         scales[delta], cfg)

        if DEBUG:
            nms_full_count.append(nms_full_count_per_batch)
            nms_pos_count.append(nms_pos_count_per_batch)
            is_max_count.append(is_max_count_per_batch)
            all_count.append(all_count_per_batch)
            total_targets = sum(all_count)
            if total_targets:
                ratios = [
                    1.0 * sum(counts) / total_targets
                    for counts in (nms_full_count, nms_pos_count,
                                   is_max_count)
                ]
            else:
                # no targets seen yet: report nan instead of dividing by zero
                ratios = [float('nan')] * 3
            print("full:{} pos:{} max:{}".format(*ratios))
        idx += test_data.batch_size
        t3 = time.time() - t
        t = time.time()
        post_processing_time.append(t3)
        all_inference_time.append(t1 + t2 + t3)
        inference_count += 1
        if inference_count % 200 == 0:
            # average over at most the 500 most recent batches
            valid_count = 500 if inference_count > 500 else inference_count
            print("--->> running-average inference time per batch: {}".format(
                float(sum(all_inference_time[-valid_count:])) / valid_count))
            print("--->> running-average post processing time per batch: {}".
                  format(
                      float(sum(post_processing_time[-valid_count:])) /
                      valid_count))
        progress_msg = (
            'testing {}/{} data {:.4f}s net {:.4f}s post {:.4f}s'.format(
                idx, num_images, t1, t2, t3))
        print(progress_msg)
        if logger:
            logger.info(progress_msg)
    # NOTE(review): freshly computed detections are neither written to
    # det_file nor passed to imdb.evaluate_detections here, so the cached
    # branch above only fires if something else produced the file - confirm
    # whether a save/evaluate step is missing.
Example #26
0
def train_net(args, ctx, pretrained_dir, pretrained_resnet, epoch, prefix,
              begin_epoch, end_epoch, lr, lr_step):
    """Build the symbol, data loader and module from ``config`` and train.

    :param args: parsed command line; ``args.cfg``, ``args.usePhilly`` and
        ``args.frequent`` are read
    :param ctx: list of device contexts, one per device
    :param pretrained_dir: directory of the pretrained model
    :param pretrained_resnet: file prefix of the pretrained model
    :param epoch: epoch of the pretrained model to load
    :param prefix: checkpoint prefix, placed under the run's output path
    :param begin_epoch: epoch to start from (also the epoch resumed from)
    :param end_epoch: epoch to stop at
    :param lr: base learning rate
    :param lr_step: comma separated epochs at which the rate is decayed
    """
    logger, final_output_path = create_logger(config.output_path, args.cfg,
                                              config.dataset.image_set)
    prefix = os.path.join(final_output_path, prefix)

    # load symbol
    shutil.copy2(os.path.join(curr_path, 'symbols', config.symbol + '.py'),
                 final_output_path)
    # NOTE: eval() of a config value - the configuration must be trusted.
    sym_instance = eval(config.symbol + '.' + config.symbol)()
    sym = sym_instance.get_symbol(config, is_train=True)
    feat_sym = sym.get_internals()['rpn_cls_score_output']

    # setup multi-gpu
    batch_size = len(ctx)
    input_batch_size = config.TRAIN.BATCH_IMAGES * batch_size

    # print config
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    git_commit_id = commands.getoutput('git rev-parse HEAD')
    print("Git commit id:", git_commit_id)
    logger.info('Git commit id: {}'.format(git_commit_id))

    # load dataset and prepare imdb for training
    image_sets = config.dataset.image_set.split('+')
    roidbs = [
        load_gt_roidb(config.dataset.dataset,
                      image_set,
                      config.dataset.root_path,
                      config.dataset.dataset_path,
                      motion_iou_path=config.dataset.motion_iou_path,
                      flip=config.TRAIN.FLIP,
                      use_philly=args.usePhilly) for image_set in image_sets
    ]
    roidb = merge_roidb(roidbs)
    roidb = filter_roidb(roidb, config)
    # load training data
    train_data = AnchorLoader(feat_sym,
                              roidb,
                              config,
                              batch_size=input_batch_size,
                              shuffle=config.TRAIN.SHUFFLE,
                              ctx=ctx,
                              feat_stride=config.network.RPN_FEAT_STRIDE,
                              anchor_scales=config.network.ANCHOR_SCALES,
                              anchor_ratios=config.network.ANCHOR_RATIOS,
                              aspect_grouping=config.TRAIN.ASPECT_GROUPING,
                              normalize_target=config.network.NORMALIZE_RPN,
                              bbox_mean=config.network.ANCHOR_MEANS,
                              bbox_std=config.network.ANCHOR_STDS)

    # infer max shape
    max_data_shape = [('data', (config.TRAIN.BATCH_IMAGES, 3,
                                max([v[0] for v in config.SCALES]),
                                max([v[1] for v in config.SCALES])))]
    max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape)
    max_data_shape.append(('gt_boxes', (config.TRAIN.BATCH_IMAGES, 100, 5)))
    print('providing maximum shape', max_data_shape, max_label_shape)

    data_shape_dict = dict(train_data.provide_data_single +
                           train_data.provide_label_single)
    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)

    # create solver
    fixed_param_prefix = config.network.FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data_single]
    label_names = [k[0] for k in train_data.provide_label_single]

    mod = MutableModule(
        sym,
        data_names=data_names,
        label_names=label_names,
        logger=logger,
        context=ctx,
        max_data_shapes=[max_data_shape for _ in range(batch_size)],
        max_label_shapes=[max_label_shape for _ in range(batch_size)],
        fixed_param_prefix=fixed_param_prefix)

    # load and initialize params
    params_loaded = False
    if config.TRAIN.RESUME:
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
        mod._preload_opt_states = '%s-%04d.states' % (prefix, begin_epoch)
        print('continue training from ', begin_epoch)
        # the epoch needs a placeholder, otherwise logging cannot format the
        # record and the message is lost
        logger.info('continue training from %s', begin_epoch)
        params_loaded = True
    elif config.TRAIN.AUTO_RESUME:
        # newest checkpoint first; begin_epoch itself is not probed
        for cur_epoch in range(end_epoch - 1, begin_epoch, -1):
            params_filename = '{}-{:04d}.params'.format(prefix, cur_epoch)
            states_filename = '{}-{:04d}.states'.format(prefix, cur_epoch)
            if os.path.exists(params_filename) and os.path.exists(
                    states_filename):
                begin_epoch = cur_epoch
                arg_params, aux_params = load_param(prefix,
                                                    cur_epoch,
                                                    convert=True)
                mod._preload_opt_states = states_filename
                print('auto continue training from {}, {}'.format(
                    params_filename, states_filename))
                logger.info('auto continue training from {}, {}'.format(
                    params_filename, states_filename))
                params_loaded = True
                break
    if not params_loaded:
        # nothing to resume from: start from the pretrained weights
        arg_params, aux_params = load_param(os.path.join(
            pretrained_dir, pretrained_resnet),
                                            epoch,
                                            convert=True)

    sym_instance.init_weight(config, arg_params, aux_params)
    # check parameter shapes
    sym_instance.check_parameter_shapes(arg_params, aux_params,
                                        data_shape_dict)

    # decide training params
    # metric
    eval_metric = metric.RCNNAccMetric(config)
    cls_metric = metric.RCNNLogLossMetric(config)
    bbox_metric = metric.RCNNL1LossMetric(config)
    eval_metrics = mx.metric.CompositeEvalMetric()

    for child_metric in [eval_metric, cls_metric, bbox_metric]:
        eval_metrics.add(child_metric)
    if config.TRAIN.JOINT_TRAINING or (not config.TRAIN.LEARN_NMS):
        rpn_eval_metric = metric.RPNAccMetric()
        rpn_cls_metric = metric.RPNLogLossMetric()
        rpn_bbox_metric = metric.RPNL1LossMetric()
        for child_metric in [rpn_eval_metric, rpn_cls_metric, rpn_bbox_metric]:
            eval_metrics.add(child_metric)
    if config.TRAIN.LEARN_NMS:
        eval_metrics.add(metric.NMSLossMetric(config, 'pos'))
        eval_metrics.add(metric.NMSLossMetric(config, 'neg'))
        eval_metrics.add(metric.NMSAccMetric(config))

    # callback
    batch_end_callback = [
        callback.Speedometer(train_data.batch_size, frequent=args.frequent)
    ]

    if config.USE_PHILLY:
        total_iter = (end_epoch - begin_epoch) * len(roidb) / input_batch_size
        progress_frequent = min(args.frequent * 10, 100)
        batch_end_callback.append(
            callback.PhillyProgressCallback(total_iter, progress_frequent))

    means = np.tile(np.array(config.TRAIN.BBOX_MEANS),
                    2 if config.CLASS_AGNOSTIC else config.dataset.NUM_CLASSES)
    stds = np.tile(np.array(config.TRAIN.BBOX_STDS),
                   2 if config.CLASS_AGNOSTIC else config.dataset.NUM_CLASSES)
    epoch_end_callback = [
        mx.callback.module_checkpoint(mod,
                                      prefix,
                                      period=1,
                                      save_optimizer_states=True),
        callback.do_checkpoint(prefix, means, stds)
    ]
    # decide learning rate
    # base_lr = lr * len(ctx) * config.TRAIN.BATCH_IMAGES
    base_lr = lr
    lr_factor = config.TRAIN.lr_factor
    # the loop variable is deliberately not called ``epoch``: under Python 2
    # scoping a comprehension variable would overwrite the parameter
    lr_epoch = [float(step) for step in lr_step.split(',')]
    lr_epoch_diff = [
        step - begin_epoch for step in lr_epoch if step > begin_epoch
    ]
    # decay steps already behind begin_epoch are applied up front
    lr = base_lr * (lr_factor**(len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [
        int(step * len(roidb) / batch_size) for step in lr_epoch_diff
    ]
    print('lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters)
    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor,
                                              config.TRAIN.warmup,
                                              config.TRAIN.warmup_lr,
                                              config.TRAIN.warmup_step)
    # optimizer
    optimizer_params = {
        'momentum': config.TRAIN.momentum,
        'wd': config.TRAIN.wd,
        'learning_rate': lr,
        'lr_scheduler': lr_scheduler,
        'rescale_grad': 1.0,
        'clip_gradient': None
    }

    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)

    # train
    mod.fit(train_data,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            kvstore=config.default.kvstore,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            arg_params=arg_params,
            aux_params=aux_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch)
Example #27
0
def pred_eval_impression_offline_seq_nms(gpu_id,
                                         first_predictor,
                                         cur_predictor,
                                         key_predictor,
                                         test_data,
                                         imdb,
                                         cfg,
                                         vis=False,
                                         thresh=1e-4,
                                         logger=None,
                                         ignore_cache=True):
    """
    wrapper for calculating offline validation for faster data analysis
    in this example, all threshold are set by hand

    Detections above ``thresh`` are collected per class and frame without
    NMS and pickled together with the frame ids to a per-``gpu_id`` file.

    :param gpu_id: id used to name the per-device detection file
    :param first_predictor: Predictor run on the first frame of a video
    :param cur_predictor: Predictor producing the detections of a frame
    :param key_predictor: Predictor run on the first frame of a new segment
    :param test_data: data iterator, must be non-shuffle
    :param imdb: image database
    :param cfg: experiment configuration
    :param vis: controls visualization (currently forced off, see below)
    :param thresh: valid detection threshold
    :param logger: optional logger that receives progress messages
    :param ignore_cache: when False and the detection file exists, its
        content is returned and nothing is recomputed
    :return: ``(all_boxes, frame_ids)``
    """
    det_file = os.path.join(imdb.result_path, imdb.name + '_' + str(gpu_id))
    print('det_file= {}'.format(det_file))
    if os.path.exists(det_file) and not ignore_cache:
        # NOTE: unpickling runs arbitrary code; only load trusted cache files.
        with open(det_file, 'rb') as fid:
            all_boxes, frame_ids = cPickle.load(fid)
        return all_boxes, frame_ids

    assert vis or not test_data.shuffle
    # visualization is switched off regardless of the argument, which makes
    # the ``if vis`` branch further down unreachable
    vis = False
    data_names = [k[0] for k in test_data.provide_data[0]]
    num_images = test_data.size
    roidb_frame_ids = [x['frame_id'] for x in test_data.roidb]

    if not isinstance(test_data, PrefetchingIter):
        test_data = PrefetchingIter(test_data)

    # all detections are collected into:
    #    all_boxes[cls][image] = N x 5 array of detections in
    #    (x1, y1, x2, y2, score)
    all_boxes = [[[] for _ in range(num_images)]
                 for _ in range(imdb.num_classes)]
    frame_ids = np.zeros(num_images, dtype=int)

    roidb_idx = -1
    roidb_offset = -1
    idx = 0
    data_time, net_time, post_time = 0.0, 0.0, 0.0
    t = time.time()
    for im_info, key_frame_flag, data_batch in test_data:
        t1 = time.time() - t
        t = time.time()
        scales = [iim_info[0, 2] for iim_info in im_info]
        if key_frame_flag == 0:  # current frame is the first frame of the video
            key_scores_all, key_boxes_all, key_data_dict_all, conv_feat, _, _ = im_detect_impression_offline(
                first_predictor, data_batch, data_names, scales, cfg)
            feat_task = conv_feat
            impression = conv_feat
            # feed the key-frame feature into the last data slot
            data_batch.data[0][-1] = feat_task
            data_batch.provide_data[0][-1] = ('key_feat_task', feat_task.shape)
            scores_all, pred_boxes_all, data_dict_all, _, _, _ = im_detect_impression_offline(
                cur_predictor, data_batch, data_names, scales, cfg)
        elif key_frame_flag == 1:  # current frame is the keyframe
            # reuse the results already computed for the key frame
            scores_all = key_scores_all
            pred_boxes_all = key_boxes_all
            data_dict_all = key_data_dict_all
        elif key_frame_flag == 2:  # first frame of the new segment
            data_batch.data[0][-2] = impression
            data_batch.provide_data[0][-2] = ('impression', impression.shape)
            key_scores_all, key_boxes_all, key_data_dict_all, _, impression, feat_task = im_detect_impression_offline(
                key_predictor, data_batch, data_names, scales, cfg)
            data_batch.data[0][-1] = feat_task
            data_batch.provide_data[0][-1] = ('key_feat_task', feat_task.shape)
            scores_all, pred_boxes_all, data_dict_all, _, _, _ = im_detect_impression_offline(
                cur_predictor, data_batch, data_names, scales, cfg)
        else:
            data_batch.data[0][-1] = feat_task
            data_batch.provide_data[0][-1] = ('key_feat_task', feat_task.shape)
            scores_all, pred_boxes_all, data_dict_all, _, _, _ = im_detect_impression_offline(
                cur_predictor, data_batch, data_names, scales, cfg)

        # a flag of 0 starts the next roidb entry; every other frame is
        # addressed as an offset from that entry's frame id
        if key_frame_flag == 0:
            roidb_idx += 1
            roidb_offset = 0
        else:
            roidb_offset += 1
        # NOTE(review): only slot ``idx`` is written, which presumably relies
        # on one image per batch - confirm.
        frame_ids[idx] = roidb_frame_ids[roidb_idx] + roidb_offset
        t2 = time.time() - t
        t = time.time()
        for delta, (scores, boxes, data_dict) in enumerate(
                zip(scores_all, pred_boxes_all, data_dict_all)):
            for j in range(1, imdb.num_classes):
                indexes = np.where(scores[:, j] > thresh)[0]
                cls_scores = scores[indexes, j, np.newaxis]
                if cfg.CLASS_AGNOSTIC:
                    cls_boxes = boxes[indexes, 4:8]
                else:
                    cls_boxes = boxes[indexes, j * 4:(j + 1) * 4]
                cls_dets = np.hstack((cls_boxes, cls_scores))

                # stored without NMS
                all_boxes[j][idx + delta] = cls_dets

            if vis:
                show_boxes_with_nms(data_dict['data_cur'].asnumpy(), scores,
                                    boxes, imdb.classes, scales[delta], cfg)

        idx += test_data.batch_size
        t3 = time.time() - t
        t = time.time()
        data_time += t1
        net_time += t2
        post_time += t3
        # running averages, scaled to one batch
        progress_msg = (
            'seq nms testing {}/{} data {:.4f}s net {:.4f}s post {:.4f}s'.
            format(idx, num_images,
                   data_time / idx * test_data.batch_size,
                   net_time / idx * test_data.batch_size,
                   post_time / idx * test_data.batch_size))
        print(progress_msg)
        if logger:
            logger.info(progress_msg)
    with open(det_file, 'wb') as f:
        cPickle.dump((all_boxes, frame_ids),
                     f,
                     protocol=cPickle.HIGHEST_PROTOCOL)
    return all_boxes, frame_ids
Example #28
0
def pred_eval(gpu_id,
              key_predictor,
              cur_predictor,
              test_data,
              imdb,
              cfg,
              vis=False,
              thresh=1e-4,
              logger=None,
              ignore_cache=True):
    """
    Run detection over ``test_data`` with a key-frame / current-frame
    predictor pair, cache the detections on disk and return them.

    :param gpu_id: only used to build the name of the cache file
    :param key_predictor: predictor passed to ``im_detect`` whenever
        ``key_frame_flag != 2``; that call also yields the feature ``feat``
    :param cur_predictor: predictor passed to ``im_detect`` when
        ``key_frame_flag == 2``; the last data entry of the batch is replaced
        by the most recent ``feat`` beforehand
    :param test_data: data iterator yielding
        ``(im_info, key_frame_flag, data_batch)``, must be non-shuffle
        unless ``vis`` is set
    :param imdb: image database (``result_path``, ``name``, ``num_classes``
        and ``classes`` are read)
    :param cfg: config; ``cfg.TEST.NMS``, ``cfg.TEST.max_per_image`` and
        ``cfg.CLASS_AGNOSTIC`` are read here
    :param vis: controls visualization
    :param thresh: valid detection threshold
    :param logger: optional logger that receives the progress line
    :param ignore_cache: when false and the cache file exists, the cached
        result is returned without running the predictors
    :return: ``(all_boxes, frame_ids)`` where
        ``all_boxes[cls][image]`` is an N x 5 array (x1, y1, x2, y2, score)
    """

    det_file = os.path.join(imdb.result_path,
                            imdb.name + '_' + str(gpu_id) + '_detections.pkl')
    if os.path.exists(det_file) and not ignore_cache:
        # NOTE: unpickling executes code embedded in the file; the cache is
        # expected to be one written by this function.
        with open(det_file, 'rb') as fid:
            all_boxes, frame_ids = cPickle.load(fid)
        return all_boxes, frame_ids

    assert vis or not test_data.shuffle
    data_names = [k[0] for k in test_data.provide_data[0]]
    num_images = test_data.size
    roidb_frame_ids = [x['frame_id'] for x in test_data.roidb]

    if not isinstance(test_data, PrefetchingIter):
        test_data = PrefetchingIter(test_data)

    nms = py_nms_wrapper(cfg.TEST.NMS)

    # limit detections to max_per_image over all classes
    max_per_image = cfg.TEST.max_per_image

    # all detections are collected into:
    #    all_boxes[cls][image] = N x 5 array of detections in
    #    (x1, y1, x2, y2, score)
    all_boxes = [[[] for _ in range(num_images)]
                 for _ in range(imdb.num_classes)]
    # builtin int is exactly what the removed ``np.int`` alias referred to
    frame_ids = np.zeros(num_images, dtype=int)

    roidb_idx = -1
    roidb_offset = -1
    idx = 0
    data_time, net_time, post_time = 0.0, 0.0, 0.0
    t = time.time()
    for im_info, key_frame_flag, data_batch in test_data:
        t1 = time.time() - t
        t = time.time()

        scales = [iim_info[0, 2] for iim_info in im_info]
        if key_frame_flag != 2:
            # this call refreshes ``feat`` for later flag-2 batches
            scores_all, boxes_all, data_dict_all, feat = im_detect(
                key_predictor, data_batch, data_names, scales, cfg)
        else:
            # reuse the feature produced by the latest key_predictor call
            data_batch.data[0][-1] = feat
            data_batch.provide_data[0][-1] = ('feat_key', feat.shape)
            scores_all, boxes_all, data_dict_all, _ = im_detect(
                cur_predictor, data_batch, data_names, scales, cfg)

        # flag 0 advances to the next roidb entry; any other flag counts
        # frames relative to the current entry
        if key_frame_flag == 0:
            roidb_idx += 1
            roidb_offset = 0
        else:
            roidb_offset += 1

        # NOTE(review): only one frame id is stored per batch; presumably
        # the test batch size is 1 -- confirm against the caller.
        frame_ids[idx] = roidb_frame_ids[roidb_idx] + roidb_offset

        t2 = time.time() - t
        t = time.time()
        for delta, (scores, boxes, data_dict) in enumerate(
                zip(scores_all, boxes_all, data_dict_all)):
            for j in range(1, imdb.num_classes):
                indexes = np.where(scores[:, j] > thresh)[0]
                cls_scores = scores[indexes, j, np.newaxis]
                if cfg.CLASS_AGNOSTIC:
                    cls_boxes = boxes[indexes, 4:8]
                else:
                    cls_boxes = boxes[indexes, j * 4:(j + 1) * 4]
                cls_dets = np.hstack((cls_boxes, cls_scores))
                keep = nms(cls_dets)
                all_boxes[j][idx + delta] = cls_dets[keep, :]

            if max_per_image > 0:
                image_scores = np.hstack([
                    all_boxes[j][idx + delta][:, -1]
                    for j in range(1, imdb.num_classes)
                ])
                if len(image_scores) > max_per_image:
                    # score of the max_per_image-th best detection
                    image_thresh = np.sort(image_scores)[-max_per_image]
                    for j in range(1, imdb.num_classes):
                        dets = all_boxes[j][idx + delta]
                        keep = np.where(dets[:, -1] >= image_thresh)[0]
                        all_boxes[j][idx + delta] = dets[keep, :]

            if vis:
                boxes_this_image = [[]] + [
                    all_boxes[j][idx + delta]
                    for j in range(1, imdb.num_classes)
                ]
                vis_all_detection(data_dict['data'].asnumpy(),
                                  boxes_this_image, imdb.classes,
                                  scales[delta], cfg)

        idx += test_data.batch_size
        t3 = time.time() - t
        t = time.time()
        data_time += t1
        net_time += t2
        post_time += t3
        # running per-batch averages; built once, then printed and logged
        progress = 'testing {}/{} data {:.4f}s net {:.4f}s post {:.4f}s'.format(
            idx, num_images, data_time / idx * test_data.batch_size,
            net_time / idx * test_data.batch_size,
            post_time / idx * test_data.batch_size)
        print(progress)
        if logger:
            logger.info(progress)

    with open(det_file, 'wb') as f:
        cPickle.dump((all_boxes, frame_ids),
                     f,
                     protocol=cPickle.HIGHEST_PROTOCOL)

    return all_boxes, frame_ids
Example #29
0
def _report_segmentation_results(evaluation_results, logger=None):
    """Print, and log when ``logger`` is given, the per-class IU values and the mean IU."""

    def emit(message):
        print(message)
        if logger:
            logger.info(message)

    emit('evaluate segmentation: \n')
    mean_iu = evaluation_results['meanIU']
    iu_array = evaluation_results['IU_array']
    emit('IU_array:\n')
    for iu in iu_array:
        emit('%.5f' % iu)
    emit('meanIU:%.5f' % mean_iu)


def pred_eval(predictor,
              test_data,
              imdb,
              vis=False,
              ignore_cache=None,
              logger=None):
    """
    Run segmentation prediction over ``test_data``, evaluate it through
    ``imdb.evaluate_segmentations`` and report the IU numbers.

    When the result file already exists and ``ignore_cache`` is falsy, the
    cached evaluation is loaded and reported instead of running the predictor.

    :param predictor: Predictor
    :param test_data: data iterator, must be non-shuffle unless ``vis`` is set
    :param imdb: image database
    :param vis: only relaxes the non-shuffle assertion here
    :param ignore_cache: ignore the saved cache file
    :param logger: the logger instance
    :return: None, results are printed / logged and written to the cache file
    """
    res_file = os.path.join(imdb.result_path, imdb.name + '_segmentations.pkl')
    if os.path.exists(res_file) and not ignore_cache:
        # NOTE: unpickling executes code embedded in the file; the cache is
        # expected to be one written by this function.
        with open(res_file, 'rb') as fid:
            evaluation_results = cPickle.load(fid)
        _report_segmentation_results(evaluation_results, logger)
        return

    assert vis or not test_data.shuffle
    if not isinstance(test_data, PrefetchingIter):
        test_data = PrefetchingIter(test_data)

    num_images = imdb.num_images
    all_segmentation_result = [[] for _ in range(num_images)]
    idx = 0

    data_time, net_time, post_time = 0.0, 0.0, 0.0
    t = time.time()
    for data_batch in test_data:
        t1 = time.time() - t
        t = time.time()
        output_all = predictor.predict(data_batch)
        # reduce each 'softmax_output' to its argmax along axis 1
        output_all = [
            mx.ndarray.argmax(output['softmax_output'], axis=1).asnumpy()
            for output in output_all
        ]
        t2 = time.time() - t
        t = time.time()

        all_segmentation_result[idx:idx + test_data.batch_size] = [
            output.astype('int8') for output in output_all
        ]

        idx += test_data.batch_size
        t3 = time.time() - t
        t = time.time()

        data_time += t1
        net_time += t2
        post_time += t3
        # running per-batch averages; built once, then printed and logged
        progress = 'testing {}/{} data {:.4f}s net {:.4f}s post {:.4f}s'.format(
            idx, imdb.num_images, data_time / idx * test_data.batch_size,
            net_time / idx * test_data.batch_size,
            post_time / idx * test_data.batch_size)
        print(progress)
        if logger:
            logger.info(progress)

    evaluation_results = imdb.evaluate_segmentations(all_segmentation_result)

    if not os.path.exists(res_file) or ignore_cache:
        with open(res_file, 'wb') as f:
            cPickle.dump(evaluation_results,
                         f,
                         protocol=cPickle.HIGHEST_PROTOCOL)

    _report_segmentation_results(evaluation_results, logger)
def train_net(args, ctx, pretrained, pretrained_flow, epoch, prefix,
              begin_epoch, end_epoch, lr, lr_step):
    """
    Build the training symbol, data loader, module, metrics, callbacks and
    learning-rate schedule from the module-level ``config`` and run
    ``mod.fit``.

    :param args: parsed arguments; ``args.cfg`` and ``args.frequent`` are read
    :param ctx: list of device contexts, one module replica per entry
    :param pretrained: prefix of the pretrained parameters (used when not
        resuming)
    :param pretrained_flow: not used by the active code, only referenced by
        the commented-out flow parameter loading below
    :param epoch: epoch of the pretrained parameters to load
    :param prefix: checkpoint prefix, joined onto the final output path
    :param begin_epoch: first epoch to train; also the epoch resumed from
        when ``config.TRAIN.RESUME`` is set
    :param end_epoch: passed to ``mod.fit`` as ``num_epoch``
    :param lr: base learning rate
    :param lr_step: comma-separated epoch numbers; the ones not after
        ``begin_epoch`` each scale the starting rate by
        ``config.TRAIN.lr_factor``, the remaining ones are converted to
        iteration counts for the scheduler
    :return: None
    """
    logger, final_output_path = create_logger(config.output_path, args.cfg,
                                              config.dataset.image_set)
    prefix = os.path.join(final_output_path, prefix)

    # load symbol (a copy of the symbol source is kept next to the outputs)
    shutil.copy2(os.path.join(curr_path, 'symbols', config.symbol + '.py'),
                 final_output_path)
    # NOTE: eval() runs whatever expression config.symbol spells out, so the
    # config must come from a trusted source.
    sym_instance = eval(config.symbol + '.' + config.symbol)()
    sym = sym_instance.get_train_symbol(config)
    feat_sym = sym.get_internals()['rpn_cls_score_output']
    feat_conv_3x3_relu = sym.get_internals()['feat_conv_3x3_relu_output']

    # setup multi-gpu
    batch_size = len(ctx)
    input_batch_size = config.TRAIN.BATCH_IMAGES * batch_size

    # print config
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # load dataset and prepare imdb for training
    image_sets = [iset for iset in config.dataset.image_set.split('+')]
    roidbs = [
        load_gt_roidb(config.dataset.dataset,
                      image_set,
                      config.dataset.root_path,
                      config.dataset.dataset_path,
                      flip=config.TRAIN.FLIP) for image_set in image_sets
    ]
    roidb = merge_roidb(roidbs)
    roidb = filter_roidb(roidb, config)
    # load training data
    train_data = AnchorLoader(feat_sym,
                              feat_conv_3x3_relu,
                              roidb,
                              config,
                              batch_size=input_batch_size,
                              shuffle=config.TRAIN.SHUFFLE,
                              ctx=ctx,
                              feat_stride=config.network.RPN_FEAT_STRIDE,
                              anchor_scales=config.network.ANCHOR_SCALES,
                              anchor_ratios=config.network.ANCHOR_RATIOS,
                              aspect_grouping=config.TRAIN.ASPECT_GROUPING,
                              normalize_target=config.network.NORMALIZE_RPN,
                              bbox_mean=config.network.ANCHOR_MEANS,
                              bbox_std=config.network.ANCHOR_STDS)

    # infer max shape
    #max_data_shape = [('data', (config.TRAIN.BATCH_IMAGES, 3, max([v[0] for v in config.SCALES]), max([v[1] for v in config.SCALES]))),
    #                  ('data_ref', (config.TRAIN.BATCH_IMAGES, 3, max([v[0] for v in config.SCALES]), max([v[1] for v in config.SCALES]))),
    #                  ('eq_flag', (1,))]
    # largest value found at position 0 / position 1 of the SCALES entries
    max_scale_0 = max([v[0] for v in config.SCALES])
    max_scale_1 = max([v[1] for v in config.SCALES])
    max_ref_shape = (config.TRAIN.BATCH_IMAGES, 3, max_scale_0, max_scale_1)

    # output shape of the feature symbol for the largest 'data_ref' input
    _, feat_shape111, _ = feat_conv_3x3_relu.infer_shape(data_ref=max_ref_shape)

    max_data_shape = [('data_ref', max_ref_shape),
                      ('eq_flag', (1, )),
                      ('motion_vector', (config.TRAIN.BATCH_IMAGES, 2,
                                         int(feat_shape111[0][2]),
                                         int(feat_shape111[0][3])))]

    max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape)
    max_data_shape.append(('gt_boxes', (config.TRAIN.BATCH_IMAGES, 100, 5)))
    print('providing maximum shape {} {}'.format(max_data_shape,
                                                 max_label_shape))

    data_shape_dict = dict(train_data.provide_data_single +
                           train_data.provide_label_single)
    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)

    # load and initialize params
    if config.TRAIN.RESUME:
        # single formatted string: a parenthesised argument list would be
        # printed as a tuple by the Python 2 print statement
        print('continue training from {}'.format(begin_epoch))
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        #arg_params_flow, aux_params_flow = load_param(pretrained_flow, epoch, convert=True)
        #arg_params.update(arg_params_flow)
        #aux_params.update(aux_params_flow)
        sym_instance.init_weight(config, arg_params, aux_params)

    # check parameter shapes
    sym_instance.check_parameter_shapes(arg_params, aux_params,
                                        data_shape_dict)

    # create solver
    fixed_param_prefix = config.network.FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data_single]
    label_names = [k[0] for k in train_data.provide_label_single]

    mod = MutableModule(
        sym,
        data_names=data_names,
        label_names=label_names,
        logger=logger,
        context=ctx,
        max_data_shapes=[max_data_shape for _ in range(batch_size)],
        max_label_shapes=[max_label_shape for _ in range(batch_size)],
        fixed_param_prefix=fixed_param_prefix)

    if config.TRAIN.RESUME:
        mod._preload_opt_states = '%s-%04d.states' % (prefix, begin_epoch)

    # decide training params
    # metric
    rpn_eval_metric = metric.RPNAccMetric()
    rpn_cls_metric = metric.RPNLogLossMetric()
    rpn_bbox_metric = metric.RPNL1LossMetric()
    eval_metric = metric.RCNNAccMetric(config)
    cls_metric = metric.RCNNLogLossMetric(config)
    bbox_metric = metric.RCNNL1LossMetric(config)
    eval_metrics = mx.metric.CompositeEvalMetric()
    # rpn_eval_metric, rpn_cls_metric, rpn_bbox_metric, eval_metric, cls_metric, bbox_metric
    for child_metric in [
            rpn_eval_metric, rpn_cls_metric, rpn_bbox_metric, eval_metric,
            cls_metric, bbox_metric
    ]:
        eval_metrics.add(child_metric)
    # callback
    batch_end_callback = callback.Speedometer(train_data.batch_size,
                                              frequent=args.frequent)
    # the bbox means/stds are tiled per class (2 copies when class agnostic)
    num_reg_classes = 2 if config.CLASS_AGNOSTIC else config.dataset.NUM_CLASSES
    means = np.tile(np.array(config.TRAIN.BBOX_MEANS), num_reg_classes)
    stds = np.tile(np.array(config.TRAIN.BBOX_STDS), num_reg_classes)
    epoch_end_callback = [
        mx.callback.module_checkpoint(mod,
                                      prefix,
                                      period=1,
                                      save_optimizer_states=True),
        callback.do_checkpoint(prefix, means, stds)
    ]
    # decide learning rate
    # (loop variables are not called ``epoch`` so the parameter of that name
    # is not overwritten by Python 2 comprehension scoping)
    base_lr = lr
    lr_factor = config.TRAIN.lr_factor
    lr_epoch = [float(step) for step in lr_step.split(',')]
    lr_epoch_diff = [
        step - begin_epoch for step in lr_epoch if step > begin_epoch
    ]
    # steps already behind begin_epoch are applied to the starting rate
    lr = base_lr * (lr_factor**(len(lr_epoch) - len(lr_epoch_diff)))
    # NOTE(review): batch_size is len(ctx) here, not input_batch_size; the
    # two differ when config.TRAIN.BATCH_IMAGES != 1 -- confirm intent.
    lr_iters = [
        int(step * len(roidb) / batch_size) for step in lr_epoch_diff
    ]
    print('lr {} lr_epoch_diff {} lr_iters {}'.format(lr, lr_epoch_diff,
                                                      lr_iters))
    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor,
                                              config.TRAIN.warmup,
                                              config.TRAIN.warmup_lr,
                                              config.TRAIN.warmup_step)
    # optimizer
    optimizer_params = {
        'momentum': config.TRAIN.momentum,
        'wd': config.TRAIN.wd,
        'learning_rate': lr,
        'lr_scheduler': lr_scheduler,
        'rescale_grad': 1.0,
        'clip_gradient': None
    }

    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)

    print('Start to train model')
    # train
    mod.fit(train_data,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            kvstore=config.default.kvstore,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            arg_params=arg_params,
            aux_params=aux_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch)
def pred_eval(predictor, test_data, imdb, cfg, vis=False, thresh=1e-3, logger=None, ignore_cache=True):
    """
    Run detection once per entry of ``cfg.TEST_SCALES``, merge the per-scale
    detections, suppress duplicates, cache the result and evaluate it with
    ``imdb.evaluate_detections``.

    When the merged detection file already exists and ``ignore_cache`` is
    false, the cached detections are evaluated directly.

    :param predictor: Predictor
    :param test_data: data iterator, must be non-shuffle unless ``vis`` is set
    :param imdb: image database
    :param cfg: config; ``cfg.SCALES`` is overwritten with each test scale in
        turn and is left at the last one
    :param vis: controls visualization
    :param thresh: valid detection threshold
    :param logger: optional logger that receives the evaluation summary
    :param ignore_cache: ignore the saved merged detection file
    :return: None
    """

    def single_scale_file(scale_index):
        # per-scale detections are exchanged through one pickle per scale
        return os.path.join(imdb.result_path, imdb.name + '_detections_' + str(scale_index) + '.pkl')

    det_file = os.path.join(imdb.result_path, imdb.name + '_detections.pkl')
    if os.path.exists(det_file) and not ignore_cache:
        # NOTE: unpickling executes code embedded in the file; the cache is
        # expected to be one written by this function.
        with open(det_file, 'rb') as fid:
            all_boxes = cPickle.load(fid)
        info_str = imdb.evaluate_detections(all_boxes)
        if logger:
            logger.info('evaluate detections: \n{}'.format(info_str))
        return

    assert vis or not test_data.shuffle
    data_names = [k[0] for k in test_data.provide_data[0]]

    if not isinstance(test_data, PrefetchingIter):
        test_data = PrefetchingIter(test_data)

    # limit detections to max_per_image over all classes
    max_per_image = cfg.TEST.max_per_image
    num_images = imdb.num_images

    # pass 1: detect at every test scale and dump each result separately
    for test_scale_index, test_scale in enumerate(cfg.TEST_SCALES):
        det_file_single_scale = single_scale_file(test_scale_index)
        # if os.path.exists(det_file_single_scale):
        #    continue
        cfg.SCALES = [test_scale]
        test_data.reset()

        # all detections are collected into:
        #    all_boxes[cls][image] = N x 5 array of detections in
        #    (x1, y1, x2, y2, score)
        all_boxes_single_scale = [[[] for _ in range(num_images)]
                                  for _ in range(imdb.num_classes)]

        detect_at_single_scale(predictor, data_names, imdb, test_data, cfg, thresh, vis, all_boxes_single_scale, logger)

        with open(det_file_single_scale, 'wb') as f:
            cPickle.dump(all_boxes_single_scale, f, protocol=cPickle.HIGHEST_PROTOCOL)

    # pass 2: stack the per-scale detections of every (class, image) pair
    # all detections are collected into:
    #    all_boxes[cls][image] = N x 5 array of detections in
    #    (x1, y1, x2, y2, score)
    all_boxes = [[[] for _ in range(num_images)] for _ in range(imdb.num_classes)]

    for test_scale_index, test_scale in enumerate(cfg.TEST_SCALES):
        det_file_single_scale = single_scale_file(test_scale_index)
        if not os.path.exists(det_file_single_scale):
            continue
        with open(det_file_single_scale, 'rb') as fid:
            all_boxes_single_scale = cPickle.load(fid)
        for idx_class in range(1, imdb.num_classes):
            for idx_im in range(0, num_images):
                if len(all_boxes[idx_class][idx_im]) == 0:
                    all_boxes[idx_class][idx_im] = all_boxes_single_scale[idx_class][idx_im]
                else:
                    all_boxes[idx_class][idx_im] = np.vstack((all_boxes[idx_class][idx_im], all_boxes_single_scale[idx_class][idx_im]))

    # pass 3: suppress duplicates; the suppression callable depends only on
    # cfg, so it is built once instead of once per (class, image) pair
    # (assumes the wrapper factories have no per-call side effects)
    use_softnms = cfg.TEST.USE_SOFTNMS
    if use_softnms:
        soft_nms = py_softnms_wrapper(cfg.TEST.SOFTNMS_THRESH, max_dets=max_per_image)
    else:
        nms = py_nms_wrapper(cfg.TEST.NMS)

    for idx_class in range(1, imdb.num_classes):
        for idx_im in range(0, num_images):
            dets = all_boxes[idx_class][idx_im]
            if use_softnms:
                all_boxes[idx_class][idx_im] = soft_nms(dets)
            else:
                keep = nms(dets)
                all_boxes[idx_class][idx_im] = dets[keep, :]

    if max_per_image > 0:
        for idx_im in range(0, num_images):
            image_scores = np.hstack([all_boxes[j][idx_im][:, -1]
                                      for j in range(1, imdb.num_classes)])
            if len(image_scores) > max_per_image:
                # score of the max_per_image-th best detection in this image
                image_thresh = np.sort(image_scores)[-max_per_image]
                for j in range(1, imdb.num_classes):
                    keep = np.where(all_boxes[j][idx_im][:, -1] >= image_thresh)[0]
                    all_boxes[j][idx_im] = all_boxes[j][idx_im][keep, :]

    with open(det_file, 'wb') as f:
        cPickle.dump(all_boxes, f, protocol=cPickle.HIGHEST_PROTOCOL)

    info_str = imdb.evaluate_detections(all_boxes)
    if logger:
        logger.info('evaluate detections: \n{}'.format(info_str))