Ejemplo n.º 1
0
def get_data(name, meta_dir, gpu_nums):
    """Build the Cityscapes dataflow for one split.

    Args:
        name: split name handed to ``Cityscapes``. ``'train'`` enables the
            random resize / crop / flip augmentation, multi-GPU batching and
            ZMQ prefetching; any other split is batched one sample at a time.
        meta_dir: directory passed to ``Cityscapes`` (presumably where the
            split lists live -- TODO confirm).
        gpu_nums: number of devices; the train batch size is
            ``args.batch_size * gpu_nums``.

    Returns:
        A batched dataflow of mean-subtracted ``(image, label)`` datapoints.

    Note:
        Reads the module-level ``args`` (``crop_size``, ``batch_size``) and
        ``IGNORE_LABEL``.
    """
    is_train = name == 'train'
    ds = Cityscapes(meta_dir, name, shuffle=True)

    if is_train:  # special augmentation, train split only
        # Applied to components 0 and 1 (image and label). IGNORE_LABEL is
        # handed to the crop -- presumably the label fill value for padding.
        shape_aug = [
            RandomResize(xrange=(0.7, 1.5),
                         yrange=(0.7, 1.5),
                         aspect_ratio_thres=0.15),
            RandomCropWithPadding(args.crop_size, IGNORE_LABEL),
            Flip(horiz=True),
        ]
    else:
        shape_aug = []

    ds = AugmentImageComponents(ds, shape_aug, (0, 1), copy=False)

    # Per-channel mean, shaped (1, 1, 3) so it broadcasts over the LAST axis
    # of the image, i.e. a channel-last (H, W, C) layout -- not NCHW.
    # Built once here instead of once per sample inside the mapper.
    pixel_mean = np.array([104, 116, 122]).reshape(1, 1, 3)

    def subtract_mean(dp):
        """Return ``(image - pixel_mean, label)`` for one ``(image, label)`` datapoint."""
        image, label = dp
        # Out-of-place subtraction: the shared pixel_mean is never mutated and
        # the result dtype follows NumPy promotion (integer mean -> integer result).
        return image - pixel_mean, label

    ds = MapData(ds, subtract_mean)
    if is_train:
        ds = BatchData(ds, args.batch_size * gpu_nums)
        ds = PrefetchDataZMQ(ds, 3)
    else:
        ds = BatchData(ds, 1)
    return ds
Ejemplo n.º 2
0
def train_net(args, ctx):
    """Train the ResNet-101 DeepLab-v2 network on Cityscapes.

    Args:
        args: parsed options. This function reads ``batch_size``,
            ``crop_size``, ``load``, ``scratch`` and ``frequent``, and also
            forwards ``args`` itself to ``mod.fit``.
        ctx: sequence of devices to train on (``len(ctx)`` is used as the
            GPU count and ``ctx`` is passed as the module context).

    Side effects:
        Sets up the logger directory, checkpoints the module (including
        optimizer state) into it every epoch, and runs the training loop.

    Raises:
        ValueError / IndexError when resuming (``args.scratch`` false) and
        ``args.load`` does not contain an integer epoch in the expected field.
    """
    logger.auto_set_dir()

    from symbols.symbol_resnet_deeplabv2 import resnet101_deeplab_new

    # Separate symbol instances for the training and evaluation graphs; both
    # are built with use_global_stats=True.
    sym_instance = resnet101_deeplab_new()
    sym = sym_instance.get_symbol(NUM_CLASSES,
                                  is_train=True,
                                  use_global_stats=True)

    eval_sym_instance = resnet101_deeplab_new()
    eval_sym = eval_sym_instance.get_symbol(NUM_CLASSES,
                                            is_train=False,
                                            use_global_stats=True)

    # setup multi-gpu: the total batch is args.batch_size per device
    gpu_nums = len(ctx)
    input_batch_size = args.batch_size * gpu_nums

    train_data = get_data("train", LIST_DIR, gpu_nums)
    test_data = get_data("val", LIST_DIR, gpu_nums)

    # infer shape -- uses the per-device batch size (args.batch_size),
    # not input_batch_size.
    data_shape_dict = {
        'data': (args.batch_size, 3, args.crop_size[0], args.crop_size[1]),
        'label': (args.batch_size, 1, args.crop_size[0], args.crop_size[1])
    }

    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)

    # load and initialize params
    if args.scratch:
        logger.info(args.load)
        arg_params, aux_params = load_init_param(args.load, convert=True)
        sym_instance.init_weights(arg_params, aux_params)
        begin_epoch = 1
    else:
        # The resume epoch is the middle field of args.load.rsplit("-", 2)
        # -- TODO confirm the checkpoint naming scheme. Parsed only here so a
        # --scratch run whose path has no '-' does not crash with IndexError.
        begin_epoch = int(args.load.rsplit("-", 2)[1])
        logger.info('continue training from {}'.format(begin_epoch))
        arg_params, aux_params = load_init_param(args.load, convert=True)

    # check parameter shapes
    sym_instance.check_parameter_shapes(arg_params, aux_params,
                                        data_shape_dict)

    data_names = ['data']
    label_names = ['label']

    mod = MutableModule(sym,
                        data_names=data_names,
                        label_names=label_names,
                        context=ctx,
                        fixed_param_prefix=fixed_param_prefix)

    # decide training params
    # metric
    # NOTE(review): fcn_loss_metric is constructed but never added to
    # eval_metrics, so the CompositeEvalMetric handed to fit() is empty.
    # Confirm whether `eval_metrics.add(fcn_loss_metric)` was intended.
    fcn_loss_metric = metric.FCNLogLossMetric(args.frequent,
                                              Cityscapes.class_num())
    eval_metrics = mx.metric.CompositeEvalMetric()

    # callbacks
    batch_end_callbacks = [
        callback.Speedometer(input_batch_size, frequent=args.frequent)
    ]
    checkpoint_prefix = os.path.join(logger.get_logger_dir(), "mxnetgo")
    epoch_end_callbacks = [
        mx.callback.module_checkpoint(mod, checkpoint_prefix, period=1,
                                      save_optimizer_states=True),
    ]

    lr_scheduler = StepScheduler(train_data.size() * EPOCH_SCALE, lr_step_list)

    # optimizer
    optimizer_params = {
        'momentum': 0.9,
        'wd': 0.0005,
        'learning_rate': 2.5e-4,
        'lr_scheduler': lr_scheduler,
        'rescale_grad': 1.0,
        'clip_gradient': None
    }

    logger.info("epoch scale = {}".format(EPOCH_SCALE))
    mod.fit(train_data=train_data,
            args=args,
            eval_sym=eval_sym,
            eval_sym_instance=eval_sym_instance,
            eval_data=test_data,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callbacks,
            batch_end_callback=batch_end_callbacks,
            kvstore=kvstore,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            arg_params=arg_params,
            aux_params=aux_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            epoch_scale=EPOCH_SCALE,
            validation_on_last=validation_on_last)
Ejemplo n.º 3
0
# Process environment. NOTE: PYTHONUNBUFFERED is read by the interpreter at
# start-up, so setting it here only affects child processes, not this
# process's own stdout. The MXNET_* variables configure cuDNN autotuning and
# GPU peer-to-peer access; they only matter if set before MXNet reads them.
os.environ.update({
    'PYTHONUNBUFFERED': '1',
    'MXNET_CUDNN_AUTOTUNE_DEFAULT': '0',
    'MXNET_ENABLE_GPU_P2P': '0',
})

# Label value passed to the training crop as padding; presumably also the
# "void" class ignored by loss/metrics -- TODO confirm.
IGNORE_LABEL = 255

# Crop and tile geometry.
CROP_HEIGHT, CROP_WIDTH = 672, 672
tile_height, tile_width = 1024, 1024
batch_size = 7

# Schedule. NOTE(review): lr_step_list is handed to StepScheduler; its first
# value (1e-3) is larger than the base learning rate (2.5e-4) set in
# train_net -- confirm that an LR increase is intended.
EPOCH_SCALE = 4
end_epoch = 10
lr_step_list = [(6, 1e-3), (10, 1e-4)]
NUM_CLASSES = Cityscapes.class_num()
validation_on_last = 2

# Module / kvstore setup; fixed_param_prefix is passed to MutableModule
# (presumably the parameter-name prefixes kept frozen).
kvstore = "device"
fixed_param_prefix = ["conv1", "bn_conv1", "res2", "bn2", "gamma", "beta"]
symbol_str = "resnet_v1_101_deeplab"


def parse_args():
    """Create the argument parser for DeepLab training.

    NOTE(review): as visible here, the function only builds the parser and
    registers --gpu, --frequent, --view and --validation. It has no
    ``return`` and never calls ``parser.parse_args()``, so it returns None.
    The rest of the body appears to be truncated; confirm against the full
    script. For example, the train_net in this file also reads
    ``args.batch_size``, ``args.crop_size``, ``args.load`` and ``args.scratch``.
    """
    parser = argparse.ArgumentParser(description='Train deeplab network')
    # Device selection as a string (default "5"); format presumably a GPU id
    # list -- TODO confirm how it is parsed downstream.
    parser.add_argument("--gpu", default="5")
    parser.add_argument('--frequent',
                        help='frequency of logging',
                        default=800,
                        type=int)
    # Boolean switches; their consumers are not visible in this chunk.
    parser.add_argument('--view', action='store_true')
    parser.add_argument("--validation", action="store_true")
Ejemplo n.º 4
0
def get_data(name, data_dir, meta_dir, config):
    """Return the unshuffled Cityscapes split *name*, batched one sample at a time.

    ``config`` is accepted for call-site compatibility but is not used.
    """
    return BatchData(Cityscapes(data_dir, meta_dir, name, shuffle=False), 1)