Example #1
0
def parfetch(config, crop_width, crop_height, isegdb):
    """Fetch one training sample (input arrays + label arrays) for a segdb entry.

    :param config: experiment configuration (network / dataset / TRAIN sections)
    :param crop_width: crop width in pixels, used when config.TRAIN.enable_crop
    :param crop_height: crop height in pixels, used when config.TRAIN.enable_crop
    :param isegdb: a single segmentation-db record
    :return: {'data': dict of input arrays, 'label': dict of label arrays}
    """
    # get testing data for multigpu

    if config.dataset.dataset == "PascalVOC" or config.dataset.dataset == "ADE20K":
        # VOC/ADE20K path: the helper already returns network-ready arrays,
        # so only the optional metric labels need to be generated here.
        datas = {}
        labels = {}
        datas['data'], labels['label'] = get_segmentation_image_voc(
            isegdb, config)
        if config.network.use_metric:
            labels['metric_label'] = generate_metric_label(labels['label'])
        if config.TRAIN.use_mult_metric:
            # one metric label per skip step (multi-scale metric supervision)
            for i in [1, 2, 4]:
                labels['metric_label_' + str(i)] = generate_metric_label(
                    labels['label'], skip_step=i)

        return {'data': datas, 'label': labels}
    else:
        datas, labels = get_segmentation_train_batch(isegdb, config)
        feature_stride = config.network.LABEL_STRIDE
        network_ratio = config.network.ratio
        if config.TRAIN.enable_crop:
            datas_internal = datas['data']
            labels_internal = labels['label']
            # pick a random top-left corner (sx, sy) for the crop;
            # arrays are NCHW, so shape[3] is width and shape[2] is height
            sx = math.floor(random.random() *
                            (datas_internal.shape[3] - crop_width + 1))
            sy = math.floor(random.random() *
                            (datas_internal.shape[2] - crop_height + 1))
            sx = (int)(sx)
            sy = (int)(sy)

            assert (sx >= 0 and sx < datas_internal.shape[3] - crop_width + 1)
            assert (sy >= 0 and sy < datas_internal.shape[2] - crop_height + 1)

            # inclusive bottom-right corner of the crop
            ex = (int)(sx + crop_width - 1)
            ey = (int)(sy + crop_height - 1)

            datas_internal = datas_internal[:, :, sy:ey + 1, sx:ex + 1]
            labels_internal = labels_internal[:, :, sy:ey + 1, sx:ex + 1]

            if config.network.use_crop_context:
                crop_context_scale = config.network.crop_context_scale

                # enlarge the crop window by 1/crop_context_scale, rounded
                # up so the context patch is divisible by the label stride
                scale_width = make_divisible(
                    int(float(crop_width) / crop_context_scale),
                    feature_stride)
                scale_height = make_divisible(
                    int(float(crop_height) / crop_context_scale),
                    feature_stride)
                pad_width = int(scale_width - crop_width) / 2
                pad_height = int(scale_height - crop_height) / 2

                # zero-pad the full image so a context window of
                # (scale_height, scale_width) always fits around the crop
                datas['origin_data'] = np.zeros(
                    (datas['data'].shape[0], datas['data'].shape[1],
                     datas['data'].shape[2] + 2 * int(pad_height),
                     datas['data'].shape[3] + 2 * int(pad_width)))
                datas['origin_data'][:, :,
                                     int(pad_height):datas['data'].shape[2] +
                                     int(pad_height),
                                     int(pad_width):datas['data'].shape[3] +
                                     int(pad_width)] = datas['data']

                # labels are padded with 255 (ignore index) instead of zeros
                labels['origin_label'] = np.full(
                    (labels['label'].shape[0], labels['label'].shape[1],
                     labels['label'].shape[2] + 2 * int(pad_height),
                     labels['label'].shape[3] + 2 * int(pad_width)), 255)
                labels[
                    'origin_label'][:, :,
                                    int(pad_height):labels['label'].shape[2] +
                                    int(pad_height),
                                    int(pad_width):labels['label'].shape[3] +
                                    int(pad_width)] = labels['label']

                # cut the context window out of the padded arrays
                datas_origin = datas['origin_data'][:, :, sy:sy + scale_height,
                                                    sx:sx + scale_width]

                labels_origin = labels['origin_label'][:, :,
                                                       sy:sy + scale_height,
                                                       sx:sx + scale_width]

                datas['origin_data'] = datas_origin
                labels['origin_label'] = labels_origin

                # labels_origin_in = np.zeros((labels['origin_label'].shape[0],labels['origin_label'].shape[1],
                #                   labels['origin_label'].shape[2]//feature_stride,labels['origin_label'].shape[3]//feature_stride))
                # for i, label in enumerate(labels['origin_label']):
                #     label_im = Image.fromarray(np.squeeze(label.astype(np.uint8, copy=False))).resize(
                #         (labels['origin_label'].shape[3] // feature_stride,
                #          labels['origin_label'].shape[2] // feature_stride), Image.NEAREST)
                #     label = np.array(label_im)
                #     labels_origin_in[i, 0, :, :] = label
                #
                # labels['origin_label']=labels_origin_in

                # one roi per batch image locating the crop inside the
                # context window: [batch_index, x1, y1, x2, y2]
                rois = []
                for i, im_info in zip(xrange(datas_internal.shape[0]),
                                      datas['im_info']):
                    rois.append(
                        np.array([
                            i, pad_width, pad_height, pad_width + crop_width,
                            pad_height + crop_height
                        ]).reshape((1, 5)))
                datas['rois'] = tensor_vstack(rois)
                # print rois

                datas['data'] = datas_internal
                labels['label'] = labels_internal

            else:
                # no context: rois are the crop coordinates mapped back to
                # original-image scale via network_ratio / im_info scale
                rois = []
                for i, im_info in zip(xrange(datas_internal.shape[0]),
                                      datas['im_info']):
                    scale = im_info[2]
                    rois.append(
                        np.array([
                            i, sx * network_ratio / scale,
                            sy * network_ratio / scale,
                            (ex + 1) * network_ratio / scale,
                            (ey + 1) * network_ratio / scale
                        ]).reshape((1, 5)))
                datas['rois'] = tensor_vstack(rois)

                datas['data'] = datas_internal
                labels['label'] = labels_internal
                assert (datas['data'].shape[2]
                        == crop_height) and (datas['data'].shape[3]
                                             == crop_width)
        else:
            # no cropping: one roi covering each whole (ratio-scaled) image
            datas_internal = datas['data']
            rois = []
            for i, im_info in zip(xrange(datas_internal.shape[0]),
                                  datas['im_info']):
                im_size = im_info[:2]
                rois.append(
                    np.array([
                        i, 0, 0, im_size[1] * network_ratio,
                        im_size[0] * network_ratio
                    ]).reshape((1, 5)))
            datas['rois'] = tensor_vstack(rois)

        # if feature_stride == 1:
        #     assert (labels['label'].shape[2] == crop_height) and (labels['label'].shape[3] == crop_width)
        # else:

        labels_in = dict()
        # NOTE(review): 'origin_label' is only created above when both
        # enable_crop and use_crop_context are set; on other paths this
        # line relies on get_segmentation_train_batch having provided it,
        # otherwise it raises KeyError — TODO confirm against that helper.
        labels_in['origin_label'] = labels['origin_label']
        labels_in['label'] = np.zeros(
            (labels['label'].shape[0], labels['label'].shape[1],
             labels['label'].shape[2] // feature_stride,
             labels['label'].shape[3] // feature_stride))

        # to reshape the label to the network label
        # (nearest-neighbour downsample by feature_stride, per image)
        for i, label in enumerate(labels['label']):
            label_im = Image.fromarray(
                np.squeeze(label.astype(np.uint8, copy=False))).resize(
                    (labels['label'].shape[3] // feature_stride,
                     labels['label'].shape[2] // feature_stride),
                    Image.NEAREST)
            label = np.array(label_im)
            labels_in['label'][i, 0, :, :] = label

        labels = labels_in

        if config.TRAIN.enable_ignore_border:
            # mark a border of ignore_border_size pixels with ignore value 255
            labels['label'] = border_ignore_label(
                labels['label'], config.TRAIN.ignore_border_size, 255.0)

        if config.network.use_metric:
            labels['metric_label'] = generate_metric_label(labels['label'])

        if config.TRAIN.use_mult_metric:
            # suffix names so keys stay unique when scale_list != [1, 2, 4]
            scale_name = ['a', 'b', 'c']
            if config.network.scale_list == [1, 2, 4]:
                scale_name = ['', '', '']
            for idx, i in enumerate(config.network.scale_list):
                labels['metric_label_' + str(i) +
                       scale_name[idx]] = generate_metric_label(
                           labels['label'], skip_step=i)

        return {'data': datas, 'label': labels}
Example #2
0
File: train_v1.py  Project: zhuangyqin/DRN
def train_net(args, ctx, pretrained, epoch, prefix, begin_epoch, end_epoch, lr,
              lr_step):
    """Train a segmentation network end-to-end.

    :param args: parsed CLI arguments (uses args.cfg, args.frequent)
    :param ctx: list of mx.context (one per GPU)
    :param pretrained: prefix of the pretrained model to initialize from
    :param epoch: epoch number of the pretrained checkpoint to load
    :param prefix: checkpoint prefix (joined onto the output path)
    :param begin_epoch: epoch to start/resume training from
    :param end_epoch: epoch to stop at
    :param lr: base learning rate
    :param lr_step: comma-separated epochs at which lr is multiplied by 0.1
    """

    logger, final_output_path, _, tensorboard_path = create_env(
        config.output_path, args.cfg, config.dataset.image_set)
    prefix = os.path.join(final_output_path, prefix)

    # print config
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # instantiate the symbol class named by config.symbol
    print "config.symbol", config.symbol
    sym_instance = eval(config.symbol + '.' + config.symbol)()
    if config.TRAIN.use_dynamic:
        sym_gen = sym_instance.sym_gen(config, is_train=True)
    else:
        sym = sym_instance.get_symbol(config, is_train=True)

    # infer max shape
    # (crop size when cropping is enabled, otherwise the configured scales)
    scales = [(config.TRAIN.crop_size[0], config.TRAIN.crop_size[1])
              ] if config.TRAIN.enable_crop else config.SCALES
    label_stride = config.network.LABEL_STRIDE
    network_ratio = config.network.ratio
    if config.network.use_context:
        if config.network.use_crop_context:
            # NOTE(review): 736x736 appears to be a hard-coded maximum
            # context-window size — confirm it covers all configs.
            max_data_shape = [
                ('data',
                 (config.TRAIN.BATCH_IMAGES, 3, config.TRAIN.crop_size[0],
                  config.TRAIN.crop_size[1])),
                ('origin_data', (config.TRAIN.BATCH_IMAGES, 3, 736, 736)),
                ('rois', (config.TRAIN.BATCH_IMAGES, 5))
            ]
        else:
            max_data_shape = [
                ('data',
                 (config.TRAIN.BATCH_IMAGES, 3, config.TRAIN.crop_size[0],
                  config.TRAIN.crop_size[1])),
                ('origin_data', (config.TRAIN.BATCH_IMAGES, 3,
                                 int(config.SCALES[0][0] * network_ratio),
                                 int(config.SCALES[0][1] * network_ratio))),
                ('rois', (config.TRAIN.BATCH_IMAGES, 5))
            ]
    else:
        if config.TRAIN.enable_crop:
            max_data_shape = [('data', (config.TRAIN.BATCH_IMAGES, 3,
                                        config.TRAIN.crop_size[0],
                                        config.TRAIN.crop_size[1]))]
        else:
            max_data_shape = [
                ('data',
                 (config.TRAIN.BATCH_IMAGES, 3,
                  max([make_divisible(v[0], label_stride) for v in scales]),
                  max([make_divisible(v[1], label_stride) for v in scales])))
            ]

    # maximum label shapes mirror the data shapes, downsampled by the
    # label stride; metric labels carry an extra 9-channel axis
    if config.network.use_mult_label:

        if config.network.use_crop_context:
            max_label_shape = [
                ('label',
                 (config.TRAIN.BATCH_IMAGES, 1,
                  make_divisible(config.TRAIN.crop_size[0], label_stride) //
                  config.network.LABEL_STRIDE,
                  make_divisible(config.TRAIN.crop_size[1], label_stride) //
                  config.network.LABEL_STRIDE)),
                ('origin_label', (config.TRAIN.BATCH_IMAGES, 1, 736, 736))
            ]

        else:
            max_label_shape = [
                ('label',
                 (config.TRAIN.BATCH_IMAGES, 1,
                  make_divisible(config.TRAIN.crop_size[0], label_stride) //
                  config.network.LABEL_STRIDE,
                  make_divisible(config.TRAIN.crop_size[1], label_stride) //
                  config.network.LABEL_STRIDE)),
                ('origin_label', (config.TRAIN.BATCH_IMAGES, 1,
                                  int(config.SCALES[0][0] * network_ratio),
                                  int(config.SCALES[0][1] * network_ratio)))
            ]
    elif config.network.use_metric:
        scale_list = config.network.scale_list
        # suffix names so keys stay unique when scale_list != [1, 2, 4]
        scale_name = ['a', 'b', 'c']
        if config.network.scale_list == [1, 2, 4]:
            scale_name = ['', '', '']

        if config.TRAIN.enable_crop:
            if config.TRAIN.use_mult_metric:
                max_label_shape = [
                    ('label', (config.TRAIN.BATCH_IMAGES, 1,
                               config.TRAIN.crop_size[0] // label_stride,
                               config.TRAIN.crop_size[1] // label_stride)),
                    ('metric_label_' + str(scale_list[0]) + scale_name[0],
                     (config.TRAIN.BATCH_IMAGES, 9, 1,
                      config.TRAIN.crop_size[0] // label_stride,
                      config.TRAIN.crop_size[1] // label_stride)),
                    ('metric_label_' + str(scale_list[1]) + scale_name[1],
                     (config.TRAIN.BATCH_IMAGES, 9, 1,
                      config.TRAIN.crop_size[0] // label_stride,
                      config.TRAIN.crop_size[1] // label_stride)),
                    ('metric_label_' + str(scale_list[2]) + scale_name[2],
                     (config.TRAIN.BATCH_IMAGES, 9, 1,
                      config.TRAIN.crop_size[0] // label_stride,
                      config.TRAIN.crop_size[1] // label_stride))
                ]

            else:
                max_label_shape = [
                    ('label', (config.TRAIN.BATCH_IMAGES, 1,
                               config.TRAIN.crop_size[0] // label_stride,
                               config.TRAIN.crop_size[1] // label_stride)),
                    ('metric_label',
                     (config.TRAIN.BATCH_IMAGES, 9, 1,
                      config.TRAIN.crop_size[0] // label_stride,
                      config.TRAIN.crop_size[1] // label_stride))
                ]
        else:
            if config.TRAIN.use_mult_metric:

                max_label_shape = [
                    ('label',
                     (config.TRAIN.BATCH_IMAGES, 1,
                      max([make_divisible(v[0], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE,
                      max([make_divisible(v[1], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE)),
                    ('metric_label_' + str(scale_list[0]) + scale_name[0],
                     (config.TRAIN.BATCH_IMAGES, 9, 1,
                      max([make_divisible(v[0], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE,
                      max([make_divisible(v[1], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE)),
                    ('metric_label_' + str(scale_list[1]) + scale_name[1],
                     (config.TRAIN.BATCH_IMAGES, 9, 1,
                      max([make_divisible(v[0], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE,
                      max([make_divisible(v[1], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE)),
                    ('metric_label_' + str(scale_list[2]) + scale_name[2],
                     (config.TRAIN.BATCH_IMAGES, 9, 1,
                      max([make_divisible(v[0], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE,
                      max([make_divisible(v[1], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE))
                ]

            else:
                max_label_shape = [
                    ('label',
                     (config.TRAIN.BATCH_IMAGES, 1,
                      max([make_divisible(v[0], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE,
                      max([make_divisible(v[1], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE)),
                    ('metric_label',
                     (config.TRAIN.BATCH_IMAGES, 9, 1,
                      max([make_divisible(v[0], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE,
                      max([make_divisible(v[1], label_stride)
                           for v in scales]) // config.network.LABEL_STRIDE))
                ]

    else:
        if config.TRAIN.enable_crop:
            max_label_shape = [('label',
                                (config.TRAIN.BATCH_IMAGES, 1,
                                 config.TRAIN.crop_size[0] // label_stride,
                                 config.TRAIN.crop_size[1] // label_stride))]
        else:
            max_label_shape = [
                ('label',
                 (config.TRAIN.BATCH_IMAGES, 1,
                  max([make_divisible(v[0], label_stride)
                       for v in scales]) // config.network.LABEL_STRIDE,
                  max([make_divisible(v[1], label_stride)
                       for v in scales]) // config.network.LABEL_STRIDE))
            ]

    print "max_label_shapes", max_label_shape
    if config.TRAIN.use_dynamic:
        sym = sym_gen([max_data_shape])

        # setup multi-gpu
    input_batch_size = config.TRAIN.BATCH_IMAGES * len(ctx)
    NUM_GPUS = len(ctx)

    # load dataset and prepare imdb for training
    # ('+'-separated image sets are merged into a single segdb)
    image_sets = [iset for iset in config.dataset.image_set.split('+')]
    segdbs = [
        load_gt_segdb(config.dataset.dataset,
                      image_set,
                      config.dataset.root_path,
                      config.dataset.dataset_path,
                      result_path=final_output_path,
                      flip=True) for image_set in image_sets
    ]
    segdb = merge_segdb(segdbs)

    # load training data
    train_data = TrainDataLoader(sym,
                                 segdb,
                                 config,
                                 batch_size=input_batch_size,
                                 shuffle=config.TRAIN.SHUFFLE,
                                 ctx=ctx,
                                 use_context=config.network.use_context,
                                 use_mult_label=config.network.use_mult_label,
                                 use_metric=config.network.use_metric)

    # loading val data
    if config.TRAIN.eval_data_frequency > 0:
        val_image_set = config.dataset.test_image_set
        val_root_path = config.dataset.root_path
        val_dataset = config.dataset.dataset
        val_dataset_path = config.dataset.dataset_path
        val_imdb = eval(val_dataset)(val_image_set,
                                     val_root_path,
                                     val_dataset_path,
                                     result_path=final_output_path)
        val_segdb = val_imdb.gt_segdb()

        val_data = TrainDataLoader(
            sym,
            val_segdb,
            config,
            batch_size=input_batch_size,
            shuffle=config.TRAIN.SHUFFLE,
            ctx=ctx,
            use_context=config.network.use_context,
            use_mult_label=config.network.use_mult_label,
            use_metric=config.network.use_metric)
    else:
        val_data = None

    # print sym.list_arguments()
    print 'providing maximum shape', max_data_shape, max_label_shape
    max_data_shape, max_label_shape = train_data.infer_shape(
        max_data_shape, max_label_shape)
    print 'providing maximum shape', max_data_shape, max_label_shape

    # infer shape
    data_shape_dict = dict(train_data.provide_data_single +
                           train_data.provide_label_single)
    if config.TRAIN.use_dynamic:
        sym = sym_gen([train_data.provide_data_single])
    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)

    # guard against duplicate argument names in the symbol
    nset = set()
    for nm in sym.list_arguments():
        if nm in nset:
            raise ValueError('Duplicate names detected, %s' % str(nm))
        nset.add(nm)

    # load and initialize params
    if config.TRAIN.RESUME:
        # resume: load the checkpoint at begin_epoch plus its optimizer state
        print 'continue training from ', begin_epoch
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
        sym_instance.check_parameter_shapes(arg_params,
                                            aux_params,
                                            data_shape_dict,
                                            is_train=True)
        preload_opt_states = load_preload_opt_states(prefix, begin_epoch)
        # preload_opt_states = None
    else:
        # fresh start: load pretrained weights, init the remaining layers
        print pretrained
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        preload_opt_states = None
        if not config.TRAIN.FINTUNE:
            fixed_param_names = sym_instance.init_weights(
                config, arg_params, aux_params)
        sym_instance.check_parameter_shapes(arg_params,
                                            aux_params,
                                            data_shape_dict,
                                            is_train=True)

    # check parameter shapes
    # sym_instance.check_parameter_shapes(arg_params, aux_params, data_shape_dict)

    # create solver
    fixed_param_prefix = config.network.FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data_single]
    label_names = [k[0] for k in train_data.provide_label_single]

    mod = MutableModule(
        sym,
        data_names=data_names,
        label_names=label_names,
        logger=logger,
        context=ctx,
        max_data_shapes=[max_data_shape for _ in xrange(NUM_GPUS)],
        max_label_shapes=[max_label_shape for _ in xrange(NUM_GPUS)],
        fixed_param_prefix=fixed_param_prefix)

    # metric
    imagecrossentropylossmetric = metric.ImageCrossEntropyLossMetric()
    localmetric = metric.LocalImageCrossEntropyLossMetric()
    globalmetric = metric.GlobalImageCrossEntropyLossMetric()
    pixcelAccMetric = metric.PixcelAccMetric()
    eval_metrics = mx.metric.CompositeEvalMetric()

    # pick the metric set matching the network's loss configuration
    if config.network.use_mult_label:
        metric_list = [
            imagecrossentropylossmetric, localmetric, globalmetric,
            pixcelAccMetric
        ]
    elif config.network.use_metric:
        if config.TRAIN.use_crl_ses:
            metric_list = [
                imagecrossentropylossmetric,
                metric.SigmoidPixcelAccMetric(1),
                metric.SigmoidPixcelAccMetric(2),
                metric.SigmoidPixcelAccMetric(3),
                metric.CenterLossMetric(4),
                metric.CenterLossMetric(5),
                metric.CenterLossMetric(6), pixcelAccMetric
            ]

        elif config.network.use_sigmoid_metric:
            if config.TRAIN.use_mult_metric:
                metric_list = [
                    imagecrossentropylossmetric,
                    metric.SigmoidPixcelAccMetric(1),
                    metric.SigmoidPixcelAccMetric(2),
                    metric.SigmoidPixcelAccMetric(3), pixcelAccMetric
                ]
            else:
                metric_list = [
                    imagecrossentropylossmetric,
                    metric.SigmoidPixcelAccMetric(), pixcelAccMetric
                ]
        else:
            if config.TRAIN.use_mult_metric:
                metric_list = [
                    imagecrossentropylossmetric,
                    metric.MetricLossMetric(1),
                    metric.MetricLossMetric(2),
                    metric.MetricLossMetric(3), pixcelAccMetric
                ]
            else:
                metric_list = [
                    imagecrossentropylossmetric,
                    metric.MetricLossMetric(1), pixcelAccMetric
                ]
    elif config.network.mult_loss:
        metric_list = [
            imagecrossentropylossmetric,
            metric.MImageCrossEntropyLossMetric(1),
            metric.MImageCrossEntropyLossMetric(2), pixcelAccMetric
        ]
    elif config.TRAIN.use_center:
        if config.TRAIN.use_one_center:
            metric_list = [
                imagecrossentropylossmetric, pixcelAccMetric,
                metric.CenterLossMetric(1)
            ]
        else:
            metric_list = [
                imagecrossentropylossmetric, pixcelAccMetric,
                metric.CenterLossMetric(1),
                metric.CenterLossMetric(2),
                metric.CenterLossMetric(3)
            ]
    else:
        metric_list = [imagecrossentropylossmetric, pixcelAccMetric]
    # rpn_eval_metric, rpn_cls_metric, rpn_bbox_metric, eval_metric, cls_metric, bbox_metric
    for child_metric in metric_list:
        eval_metrics.add(child_metric)

    # callback
    # NOTE(review): the tensorboard-callback branch is disabled with a
    # hard-coded `if False:` — only the plain Speedometer path ever runs.
    if False:
        batch_end_callback = [
            callback.Speedometer(train_data.batch_size,
                                 frequent=args.frequent),
            callback.TensorboardCallback(tensorboard_path,
                                         prefix="train/batch")
        ]
        epoch_end_callback = mx.callback.module_checkpoint(
            mod, prefix, period=1, save_optimizer_states=True)
        shared_tensorboard = batch_end_callback[1]

        epoch_end_metric_callback = callback.TensorboardCallback(
            tensorboard_path,
            shared_tensorboard=shared_tensorboard,
            prefix="train/epoch")
        eval_end_callback = callback.TensorboardCallback(
            tensorboard_path,
            shared_tensorboard=shared_tensorboard,
            prefix="val/epoch")
        lr_callback = callback.LrCallback(
            tensorboard_path,
            shared_tensorboard=shared_tensorboard,
            prefix='train/batch')
    else:
        batch_end_callback = [
            callback.Speedometer(train_data.batch_size, frequent=args.frequent)
        ]
        epoch_end_callback = mx.callback.module_checkpoint(
            mod, prefix, period=1, save_optimizer_states=True)
        epoch_end_metric_callback = None
        eval_end_callback = None

    #decide learning rate
    # steps already passed (epoch <= begin_epoch) shrink the starting lr
    base_lr = lr
    lr_factor = 0.1
    lr_epoch = [float(epoch) for epoch in lr_step.split(',')]
    lr_epoch_diff = [
        epoch - begin_epoch for epoch in lr_epoch if epoch > begin_epoch
    ]
    lr = base_lr * (lr_factor**(len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [
        int(epoch * len(segdb) / input_batch_size) for epoch in lr_epoch_diff
    ]
    print 'lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters

    # NOTE(review): if config.TRAIN.lr_type is neither "MultiStage" nor
    # "MultiFactor", lr_scheduler is never assigned and the optimizer
    # construction below raises NameError — confirm valid config values.
    if config.TRAIN.lr_type == "MultiStage":
        lr_scheduler = LinearWarmupMultiStageScheduler(
            lr_iters,
            lr_factor,
            config.TRAIN.warmup,
            config.TRAIN.warmup_lr,
            config.TRAIN.warmup_step,
            args.frequent,
            stop_lr=lr * 0.01)
    elif config.TRAIN.lr_type == "MultiFactor":
        lr_scheduler = LinearWarmupMultiFactorScheduler(
            lr_iters, lr_factor, config.TRAIN.warmup, config.TRAIN.warmup_lr,
            config.TRAIN.warmup_step, args.frequent)

    # NOTE(review): similarly, an optimizer value other than "sgd"/"adam"
    # leaves `optimizer` unbound for set_lr_mult below.
    if config.TRAIN.optimizer == "sgd":
        optimizer_params = {
            'momentum': config.TRAIN.momentum,
            'wd': config.TRAIN.wd,
            'learning_rate': lr,
            'lr_scheduler': lr_scheduler,
            'rescale_grad': 1.0,
            'clip_gradient': None
        }
        optimizer = SGD(**optimizer_params)
    elif config.TRAIN.optimizer == "adam":
        optimizer_params = {
            'learning_rate': lr,
            'lr_scheduler': lr_scheduler,
            'rescale_grad': 1.0,
            'clip_gradient': None
        }
        optimizer = Adam(**optimizer_params)
        print "optimizer adam"

    # apply a reduced lr multiplier to parameters matching the freeze pattern
    freeze_layer_pattern = config.TRAIN.FIXED_PARAMS_PATTERN

    if freeze_layer_pattern.strip():
        args_lr_mult = {}
        re_prog = re.compile(freeze_layer_pattern)
        if freeze_layer_pattern:
            fixed_param_names = [
                name for name in sym.list_arguments() if re_prog.match(name)
            ]
        print "============================"
        print "fixed_params_names:"
        print(fixed_param_names)
        for name in fixed_param_names:
            args_lr_mult[name] = config.TRAIN.FIXED_PARAMS_PATTERN_LR_MULT
        print "============================"
    else:
        args_lr_mult = {}
    optimizer.set_lr_mult(args_lr_mult)

    # data_shape_dict = dict(train_data.provide_data_single + train_data.provide_label_single)
    # if config.TRAIN.use_dynamic:
    #     sym = sym_gen([train_data.provide_data_single])
    # pprint.pprint(data_shape_dict)
    # sym_instance.infer_shape(data_shape_dict)

    # wrap loaders in a prefetching iterator for pipeline overlap
    if not isinstance(train_data, PrefetchingIter) and config.TRAIN.use_thread:
        train_data = PrefetchingIter(train_data)

    if val_data:
        if not isinstance(val_data, PrefetchingIter):
            val_data = PrefetchingIter(val_data)

    # `Debug` is a module-level flag; Monitor(1) dumps stats every batch
    if Debug:
        monitor = mx.monitor.Monitor(1)
    else:
        monitor = None

    # train
    mod.fit(train_data,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            kvstore=config.default.kvstore,
            eval_end_callback=eval_end_callback,
            epoch_end_metric_callback=epoch_end_metric_callback,
            optimizer=optimizer,
            eval_data=val_data,
            arg_params=arg_params,
            aux_params=aux_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            allow_missing=begin_epoch == 0,
            allow_extra=True,
            monitor=monitor,
            preload_opt_states=preload_opt_states,
            eval_data_frequency=config.TRAIN.eval_data_frequency)
Example #3
0
def get_segmentation_test_batch(segdb,
                                config,
                                is_train=False,
                                has_label=True,
                                scale=1.0):
    """
    return a dict of train batch
    :param segdb: ['image', 'flipped']
    :param config: the config setting
    :return: data, label, im_info
    """
    imgs, seg_cls_gts, segdb, origin_ims, origin_labels = get_segmentation_image(
        segdb, config, is_train=is_train, has_label=has_label, scale=scale)

    im_array = tensor_vstack(imgs)
    if has_label:
        seg_cls_gt = tensor_vstack(seg_cls_gts)
    else:
        seg_cls_gt = []
    im_info = tensor_vstack(
        [np.array([isegdb['im_info']], dtype=np.float32) for isegdb in segdb])
    origin_im = tensor_vstack(origin_ims)
    rois = []

    if config.network.use_crop_context:
        crop_context_scale = config.network.crop_context_scale
        crop_height, crop_width = config.TRAIN.crop_size
        feature_stride = config.network.LABEL_STRIDE
        scale_width = make_divisible(
            int(float(crop_width) / crop_context_scale), feature_stride)
        scale_height = make_divisible(
            int(float(crop_height) / crop_context_scale), feature_stride)
        pad_width = int(scale_width - crop_width) / 2
        pad_height = int(scale_height - crop_height) / 2

        origin_data = np.zeros((im_array.shape[0], im_array.shape[1],
                                im_array.shape[2] + 2 * int(pad_height),
                                im_array.shape[3] + 2 * int(pad_width)))
        origin_data[:, :,
                    int(pad_height):im_array.shape[2] + int(pad_height),
                    int(pad_width):im_array.shape[3] +
                    int(pad_width)] = im_array

        for i, im_info in enumerate(im_info):
            im_size = im_info[:2]
            rois.append(
                np.array([
                    i, pad_width, pad_height, pad_width + im_size[1],
                    pad_width + im_size[0]
                ]).reshape((1, 5)))
        rois = tensor_vstack(rois)
        # print rois

    else:
        network_ratio = config.network.ratio
        for i, im_info in enumerate(im_info):
            im_size = im_info[:2]
            rois.append(
                np.array([
                    i, 0, 0, im_size[1] * network_ratio,
                    im_size[0] * network_ratio
                ]).reshape((1, 5)))
        rois = tensor_vstack(rois)
        print rois

    data = {
        'data': im_array,
        'im_info': im_info,
        'origin_data': origin_im,
        'rois': rois
    }

    label = {'label': seg_cls_gt}

    return {'data': data, 'label': label}
Example #4
0
def test_deeplab():
    """Evaluate a trained deeplab model on the configured test image set.

    Reads the module-level `config` and `args` globals; loads the checkpoint
    at config.TEST.test_epoch, builds a test data loader and a Predictor,
    and runs pred_eval over the whole test set.
    """
    epoch = config.TEST.test_epoch
    ctx = [mx.gpu(int(i)) for i in config.gpus.split(',')]
    image_set = config.dataset.test_image_set
    root_path = config.dataset.root_path
    dataset = config.dataset.dataset
    dataset_path = config.dataset.dataset_path

    logger, final_output_path, experiments_path, _ = create_env(
        config.output_path, args.cfg, image_set)
    # checkpoint prefix lives next to the training output directory,
    # named after the '+'-joined training image sets
    prefix = os.path.join(
        final_output_path, '..',
        '_'.join([iset for iset in config.dataset.image_set.split('+')]),
        config.TRAIN.model_prefix)

    # print config
    logger.info('testing config:{}\n'.format(pprint.pformat(config)))

    # instantiate the dataset class named by `dataset` (e.g. a segdb imdb)
    imdb = eval(dataset)(image_set,
                         root_path,
                         dataset_path,
                         result_path=experiments_path)
    segdb = imdb.gt_segdb()

    #get test data iter
    batch_size = (config.TEST.BATCH_IMAGES) * len(ctx)
    mctx = ctx

    test_data = TestDataLoader(segdb,
                               config=config,
                               batch_size=batch_size,
                               shuffle=False,
                               ctx=mctx,
                               has_label=imdb.has_label)

    # infer shape
    data_shapes = [(data_shape[1][0], data_shape[1][1], data_shape[1][2],
                    data_shape[1][3])
                   for data_shape in test_data.provide_data_single]

    # round spatial dims up to a multiple of 32 for the network input
    provide_data_single = [('data', (1, 3,
                                     make_divisible(data_shapes[0][2], 32),
                                     make_divisible(data_shapes[0][3], 32)))]
    data_shape_dict = dict(provide_data_single)
    config.SCALES = [(provide_data_single[0][1][2],
                      provide_data_single[0][1][3])]
    # load symbol and testing data
    sym_instance = eval(config.symbol + '.' + config.symbol)()
    sym = sym_instance.get_symbol(config, is_train=False)
    config.SCALES = [(data_shape_dict['data'][2], data_shape_dict['data'][3])]
    sym_instance.infer_shape(data_shape_dict)

    arg_params, aux_params = load_param(prefix, epoch, process=True)
    # sym_instance.init_weights(config,arg_params,aux_params)
    print "arg_params keys", arg_params.keys()
    sym_instance.check_parameter_shapes(arg_params,
                                        aux_params,
                                        data_shape_dict,
                                        is_train=False)

    # decide maximum shape
    data_names = [k[0] for k in test_data.provide_data_single]
    label_names = ['label']
    max_data_shape = [[('data', (config.TEST.BATCH_IMAGES, 3,
                                 max([v[0] for v in config.SCALES]),
                                 max([v[1] for v in config.SCALES])))]]

    # create predictor
    sym_gen = sym_instance.sym_gen(config)

    predictor = Predictor(sym_gen,
                          data_names,
                          label_names,
                          context=ctx,
                          max_data_shapes=max_data_shape,
                          provide_data=[provide_data_single],
                          provide_label=test_data.provide_label,
                          arg_params=arg_params,
                          aux_params=aux_params)

    # start detection
    # NOTE(review): this forcibly overrides the --ignore_cache CLI flag,
    # so cached results are always recomputed.
    args.ignore_cache = True
    pred_eval(predictor,
              test_data,
              imdb,
              vis=args.vis,
              ignore_cache=args.ignore_cache,
              logger=logger)