Exemplo n.º 1
0
def train_net(args, ctx, pretrained, pretrained_base, pretrained_ec, epoch,
              prefix, begin_epoch, end_epoch, lr, lr_step):
    """Train the network defined by ``config.symbol`` on the configured data.

    Loads (or resumes) parameters, builds a ``MutableModule`` spanning all
    devices in ``ctx`` and runs SGD with a warm-up multi-factor schedule.

    NOTE(review): another ``train_net`` with a different signature is defined
    later in this module; once the module is imported that later definition
    shadows this one -- confirm which is intended.

    Args:
        args: parsed command-line arguments (``args.cfg`` and
            ``args.frequent`` are read).
        ctx: list of device contexts; its length is the number of replicas.
        pretrained: checkpoint prefix of the main pretrained model.
        pretrained_base: checkpoint prefix whose params are merged over
            ``pretrained``.
        pretrained_ec: checkpoint prefix loaded with
            ``argprefix=config.TRAIN.arg_prefix`` and merged last.
        epoch: checkpoint epoch used for all three pretrained prefixes.
        prefix: model save name, joined onto the output directory.
        begin_epoch: epoch to start training from (also the epoch resumed
            from when ``config.TRAIN.RESUME`` is set).
        end_epoch: epoch at which training stops.
        lr: base learning rate.
        lr_step: comma-separated epochs at which the learning rate is
            multiplied by 0.1.
    """
    logger, final_output_path = create_logger(config.output_path, args.cfg,
                                              config.dataset.image_set)
    prefix = os.path.join(final_output_path, prefix)

    # load symbol (a copy of its source is kept with the training outputs)
    shutil.copy2(os.path.join(curr_path, 'symbols', config.symbol + '.py'),
                 final_output_path)
    # NOTE: eval on a config value; config.symbol is assumed to come from a
    # trusted local cfg file -- never pass untrusted input here.
    sym_instance = eval(config.symbol + '.' + config.symbol)()
    sym = sym_instance.get_train_symbol(config)

    # setup multi-gpu: batch_size is the number of devices; each device
    # receives BATCH_IMAGES images per step.
    batch_size = len(ctx)
    input_batch_size = config.TRAIN.BATCH_IMAGES * batch_size

    # print config
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # load dataset and prepare imdb for training
    image_sets = config.dataset.image_set.split('+')
    segdbs = [
        load_gt_segdb(config.dataset.dataset,
                      image_set,
                      config.dataset.root_path,
                      config.dataset.dataset_path,
                      result_path=final_output_path,
                      flip=config.TRAIN.FLIP) for image_set in image_sets
    ]
    segdb = merge_segdb(segdbs)

    # load training data
    train_data = TrainDataLoader(sym,
                                 segdb,
                                 config,
                                 batch_size=input_batch_size,
                                 crop_height=config.TRAIN.CROP_HEIGHT,
                                 crop_width=config.TRAIN.CROP_WIDTH,
                                 shuffle=config.TRAIN.SHUFFLE,
                                 ctx=ctx)

    # infer max shape from the largest configured scale
    max_height = max(v[0] for v in config.SCALES)
    max_width = max(v[1] for v in config.SCALES)
    max_data_shape = [
        ('data', (config.TRAIN.BATCH_IMAGES, 3, max_height, max_width)),
        ('data_ref', (config.TRAIN.KEY_INTERVAL - 1, 3, max_height,
                      max_width)),
        ('eq_flag', (1, )),
    ]
    max_data_shape, max_label_shape = train_data.infer_shape(max_data_shape)
    # single-argument print(...) behaves the same on Python 2 and Python 3
    print('providing maximum shape {} {}'.format(max_data_shape,
                                                 max_label_shape))

    data_shape_dict = dict(train_data.provide_data_single +
                           train_data.provide_label_single)
    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)

    # load and initialize params
    if config.TRAIN.RESUME:
        print('continue training from {}'.format(begin_epoch))
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        print(pretrained)
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        # later loads override earlier ones for any shared parameter names
        arg_params_base, aux_params_base = load_param(pretrained_base,
                                                      epoch,
                                                      convert=True)
        arg_params.update(arg_params_base)
        aux_params.update(aux_params_base)
        arg_params_ec, aux_params_ec = load_param(
            pretrained_ec,
            epoch,
            convert=True,
            argprefix=config.TRAIN.arg_prefix)
        arg_params.update(arg_params_ec)
        aux_params.update(aux_params_ec)
        sym_instance.init_weight(config, arg_params, aux_params)

    # check parameter shapes
    sym_instance.check_parameter_shapes(arg_params, aux_params,
                                        data_shape_dict)

    # create solver
    fixed_param_prefix = config.network.FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data_single]
    label_names = [k[0] for k in train_data.provide_label_single]

    mod = MutableModule(
        sym,
        data_names=data_names,
        label_names=label_names,
        logger=logger,
        context=ctx,
        max_data_shapes=[max_data_shape for _ in range(batch_size)],
        max_label_shapes=[max_label_shape for _ in range(batch_size)],
        fixed_param_prefix=fixed_param_prefix)

    if config.TRAIN.RESUME:
        # restore the optimizer states written by module_checkpoint below
        mod._preload_opt_states = '%s-%04d.states' % (prefix, begin_epoch)

    # decide training params
    # metric
    fcn_loss_metric = metric.FCNLogLossMetric(config.default.frequent *
                                              batch_size)
    eval_metrics = mx.metric.CompositeEvalMetric()

    for child_metric in [fcn_loss_metric]:
        eval_metrics.add(child_metric)

    # callback
    batch_end_callback = callback.Speedometer(train_data.batch_size,
                                              frequent=args.frequent)
    epoch_end_callback = mx.callback.module_checkpoint(
        mod, prefix, period=1, save_optimizer_states=True)

    # decide learning rate
    base_lr = lr
    lr_factor = 0.1
    # loop variable named `step` so it does not shadow the `epoch` parameter
    lr_epoch = [float(step) for step in lr_step.split(',')]
    lr_epoch_diff = [
        step - begin_epoch for step in lr_epoch if step > begin_epoch
    ]
    # decay steps that already passed before begin_epoch are applied up front
    lr = base_lr * (lr_factor**(len(lr_epoch) - len(lr_epoch_diff)))
    # The scheduler counts optimizer updates and each update consumes
    # input_batch_size images (BATCH_IMAGES per device), so epochs must be
    # converted with input_batch_size, not with the device count alone.
    lr_iters = [
        int(step * len(segdb) / input_batch_size) for step in lr_epoch_diff
    ]
    print('lr {} lr_epoch_diff {} lr_iters {}'.format(lr, lr_epoch_diff,
                                                       lr_iters))

    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor,
                                              config.TRAIN.warmup,
                                              config.TRAIN.warmup_lr,
                                              config.TRAIN.warmup_step)

    # optimizer
    optimizer_params = {
        'momentum': config.TRAIN.momentum,
        'wd': config.TRAIN.wd,
        'learning_rate': lr,
        'lr_scheduler': lr_scheduler,
        'rescale_grad': 1.0,
        'clip_gradient': None
    }

    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)

    # train
    mod.fit(train_data,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            kvstore=config.default.kvstore,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            arg_params=arg_params,
            aux_params=aux_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch)
def train_net(args, ctx, pretrained, epoch, prefix, begin_epoch, end_epoch, lr,
              lr_step):
    """Main train function for segmentation.

    Args:
        args:
            parsed command-line arguments (``args.cfg`` and ``args.frequent``
            are read)
        ctx:
            list of device contexts; one replica is trained per entry
        pretrained:
            pretrained checkpoint prefix
        epoch:
            pretrained checkpoint epoch
        prefix:
            model save name prefix, joined onto the output directory
        begin_epoch:
            epoch at which training starts (or resumes when
            ``config.TRAIN.RESUME`` is set)
        end_epoch:
            epoch at which training stops
        lr:
            base learning rate
        lr_step:
            comma-separated epochs at which the learning rate is multiplied
            by 0.1

    """
    ##########################################
    # Step 1. Create logger and set up the save prefix
    ##########################################
    logger, final_output_path = create_logger(config.output_path, args.cfg,
                                              config.dataset.image_set)
    prefix = os.path.join(final_output_path, prefix)

    ##########################################
    # Step 2. Copy the symbols and load the symbol to build network
    ##########################################
    shutil.copy2(os.path.join(curr_path, 'symbols', config.symbol + '.py'),
                 final_output_path)
    # NOTE: eval on a config value; config.symbol is assumed to come from a
    # trusted local cfg file -- never pass untrusted input here.
    sym_instance = eval(config.symbol + '.' + config.symbol)()
    sym = sym_instance.get_symbol(config, is_train=True)

    ##########################################
    # Step 3. Setup multi-gpu and batch size
    ##########################################
    # batch_size is the number of devices; each one gets BATCH_IMAGES images
    batch_size = len(ctx)
    input_batch_size = config.TRAIN.BATCH_IMAGES * batch_size

    # print config
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    ############################################
    # Step 4. Load dataset and prepare imdb for training
    ############################################
    image_sets = config.dataset.image_set.split('+')
    segdbs = [
        load_gt_segdb(config.dataset.dataset,
                      image_set,
                      config.dataset.root_path,
                      config.dataset.dataset_path,
                      result_path=final_output_path,
                      flip=config.TRAIN.FLIP) for image_set in image_sets
    ]
    segdb = merge_segdb(segdbs)

    ############################################
    # Step 5. Set dataloader and set the data shape
    ############################################
    crop_height = config.TRAIN.CROP_HEIGHT
    crop_width = config.TRAIN.CROP_WIDTH
    train_data = TrainDataLoader(sym,
                                 segdb,
                                 config,
                                 batch_size=input_batch_size,
                                 crop_height=crop_height,
                                 crop_width=crop_width,
                                 shuffle=config.TRAIN.SHUFFLE,
                                 ctx=ctx)

    # The crop size handed to the loader is used as the maximum data and
    # label shape.
    max_data_shape = [('data', (config.TRAIN.BATCH_IMAGES, 3, crop_height,
                                crop_width))]
    max_label_shape = [('label', (config.TRAIN.BATCH_IMAGES, 1, crop_height,
                                  crop_width))]
    # single-argument print(...) behaves the same on Python 2 and Python 3
    print('providing maximum shape {} {}'.format(max_data_shape,
                                                 max_label_shape))

    # infer shape
    data_shape_dict = dict(train_data.provide_data_single +
                           train_data.provide_label_single)
    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)

    ##############################################
    # Step 6. Load and initialize params
    ##############################################
    if config.TRAIN.RESUME:
        print('continue training from {}'.format(begin_epoch))
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        print(pretrained)
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        sym_instance.init_weights(config, arg_params, aux_params)

    # check parameter shapes
    sym_instance.check_parameter_shapes(arg_params, aux_params,
                                        data_shape_dict)

    ##############################################
    # Step 7. Create solver and set metrics
    ##############################################
    fixed_param_prefix = config.network.FIXED_PARAMS
    data_names = [k[0] for k in train_data.provide_data_single]
    label_names = [k[0] for k in train_data.provide_label_single]

    # range (not xrange) so this runs on both Python 2 and Python 3
    mod = MutableModule(
        sym,
        data_names=data_names,
        label_names=label_names,
        logger=logger,
        context=ctx,
        max_data_shapes=[max_data_shape for _ in range(batch_size)],
        max_label_shapes=[max_label_shape for _ in range(batch_size)],
        fixed_param_prefix=fixed_param_prefix)

    if config.TRAIN.RESUME:
        # Optimizer states are saved every epoch (save_optimizer_states=True
        # below); reload them so momentum is not reset when resuming.
        mod._preload_opt_states = '%s-%04d.states' % (prefix, begin_epoch)

    # metric
    fcn_loss_metric = metric.FCNLogLossMetric(config.default.frequent *
                                              batch_size)
    eval_metrics = mx.metric.CompositeEvalMetric()

    for child_metric in [fcn_loss_metric]:
        eval_metrics.add(child_metric)

    ##############################################
    # Step 8. Set callback for training process
    ##############################################
    batch_end_callback = callback.Speedometer(train_data.batch_size,
                                              frequent=args.frequent)
    epoch_end_callback = mx.callback.module_checkpoint(
        mod, prefix, period=1, save_optimizer_states=True)

    ##############################################
    # Step 9. Decide learning rate and optimizers
    ##############################################
    base_lr = lr
    lr_factor = 0.1
    # loop variable named `step` so it does not shadow the `epoch` parameter
    lr_epoch = [float(step) for step in lr_step.split(',')]
    lr_epoch_diff = [
        step - begin_epoch for step in lr_epoch if step > begin_epoch
    ]
    # decay steps that already passed before begin_epoch are applied up front
    lr = base_lr * (lr_factor**(len(lr_epoch) - len(lr_epoch_diff)))
    # The scheduler counts optimizer updates and each update consumes
    # input_batch_size images (BATCH_IMAGES per device), so epochs must be
    # converted with input_batch_size, not with the device count alone.
    lr_iters = [
        int(step * len(segdb) / input_batch_size) for step in lr_epoch_diff
    ]
    print('lr {} lr_epoch_diff {} lr_iters {}'.format(lr, lr_epoch_diff,
                                                       lr_iters))

    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor,
                                              config.TRAIN.warmup,
                                              config.TRAIN.warmup_lr,
                                              config.TRAIN.warmup_step)

    # optimizer
    optimizer_params = {
        'momentum': config.TRAIN.momentum,
        'wd': config.TRAIN.wd,
        'learning_rate': lr,
        'lr_scheduler': lr_scheduler,
        'rescale_grad': 1.0,
        'clip_gradient': None
    }

    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)

    ##############################################
    # Step 10. Start to train
    ##############################################
    mod.fit(train_data,
            eval_metric=eval_metrics,
            epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback,
            kvstore=config.default.kvstore,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            arg_params=arg_params,
            aux_params=aux_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch)