def get_symbol_train(network, data_shape, **kwargs):
    """Wrapper for get symbol for train

    Parameters
    ----------
    network : str
        name for the base network symbol
    data_shape : int
        input shape
    kwargs : dict
        see symbol_builder.get_symbol_train for more details
    """
    if network.startswith('legacy'):
        logging.warning('Using legacy model.')
        return symbol_builder.import_module(network).get_symbol_train(**kwargs)
    config = get_config(network, data_shape, **kwargs).copy()
    config.update(kwargs)
    return symbol_builder.get_symbol_train(**config)
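
# Usage sketch (illustrative only): the network name, data shape, and keyword
# arguments below are placeholders, not values taken from this repository.
#
#     sym = get_symbol_train('resnet50', 512, num_classes=20,
#                            nms_thresh=0.45, nms_topk=400)
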
def train_net(net, train_path, num_classes, batch_size,
              data_shape, mean_pixels, resume, finetune, pretrained, epoch,
              prefix, ctx, begin_epoch, end_epoch, frequent, learning_rate,
              momentum, weight_decay, lr_refactor_step, lr_refactor_ratio,
              optimizer_name='nadam',
              freeze_layer_pattern='',
              shape_range=(320, 512), random_shape_step=0, random_shape_epoch=10,
              num_example=10000, label_pad_width=350,
              nms_thresh=0.45, force_nms=False, ovp_thresh=0.5,
              use_difficult=False, class_names=None,
              voc07_metric=False, nms_topk=400, force_suppress=False,
              train_list="", val_path="", val_list="", iter_monitor=0,
              monitor_pattern=".*", log_file=None):
    """
    Wrapper for training phase.

    Parameters
    ----------
    net : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        epoch to load when starting from a pretrained model (resume and finetune carry their own epoch numbers)
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        training momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    freeze_layer_pattern : str
        regex pattern for layers that need to be fixed (frozen)
    shape_range : tuple of (min, max)
        random data shape range
    random_shape_step : int
        step size for random data shape, defined by network, 0 to disable
    random_shape_epoch : int
        number of epochs before switching to the next random shape
    num_example : int
        number of training images
    label_pad_width : int
        pad training and validation labels to this width so their widths stay in sync
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        suppress overlapping objects from different classes
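    ovp_thresh : float
        overlap (IoU) threshold for a detection to match a ground-truth box during validation
    use_difficult : bool
        include ground-truth objects marked as difficult in the evaluation
    class_names : list of str or None
        class names used by the validation metric for per-class reporting
    voc07_metric : bool
        use the 11-point PASCAL VOC07 average precision instead of the integral AP
    nms_topk : int
        keep at most this many detections after non-maximum suppression
    force_suppress : bool
        suppress overlapping detections regardless of class in the network's NMS
    optimizer_name : str
        optimizer to use, e.g. 'sgd' or 'nadam'; momentum is only applied for 'sgd' and 'nag'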
    train_list : str
        list file path for training; this will replace the labels embedded in the record file
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation; this will replace the labels embedded in the record file
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)

    # check args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    # str.strip() removes a character set, not a suffix, so remove '_yolo' explicitly
    base_net = net[:-len('_yolo')] if net.endswith('_yolo') else net
    prefix += '_' + base_net + '_' + str(data_shape[1])

    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    # load symbol
    sys.path.append(os.path.join(cfg.ROOT_DIR, 'symbol'))
    # symbol_module = importlib.import_module("symbol_" + net)
    # net = symbol_module.get_symbol(num_classes, nms_thresh=nms_thresh,
    #     force_suppress=force_suppress, nms_topk=nms_topk)
    net = get_symbol_train(net, num_classes, nms_thresh, force_suppress, nms_topk)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [name for name in net.list_arguments() if re_prog.match(name)]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '('+ ','.join([str(c) for c in ctx]) + ')'
    allow_missing = True
    if resume > 0:
        logger.info("Resume training with {} from epoch {}"
            .format(ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
        allow_missing = False
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}"
            .format(ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # the prediction convolution layers' names start with 'relu', so freezing 'conv*' leaves them trainable
        fixed_param_names = [name for name in net.list_arguments() \
            if name.startswith('conv')]
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}"
            .format(ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}"
            .format(ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) + ']')

    # fit parameters
    batch_end_callback = mx.callback.Speedometer(batch_size, frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    monitor = mx.mon.Monitor(iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None

    # run fit net, every n epochs we run evaluation network to get mAP
    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=4)
    else:
        valid_metric = MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=4)

    # init training module
    mod = mx.mod.Module(net, label_names=('yolo_output_label',), logger=logger, context=ctx,
                        fixed_param_names=fixed_param_names)

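    # multi-scale training: when random_shape_step > 0, split the epochs into
    # chunks of random_shape_epoch and draw a new candidate data shape (a
    # multiple of random_shape_step within shape_range) for each chunk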
    random_shape_step = int(random_shape_step)
    if random_shape_step > 0:
        fit_begins = list(range(begin_epoch, end_epoch, random_shape_epoch))
        fit_ends = fit_begins[1:] + [end_epoch]
        assert len(shape_range) == 2
        data_shapes = [(3, x * random_shape_step, x * random_shape_step) \
            for x in range(shape_range[0] // random_shape_step,
            shape_range[1] // random_shape_step + 1)]
        logger.info("Candidate random shapes:" + str(data_shapes))
    else:
        fit_begins = [begin_epoch]
        fit_ends = [end_epoch]
        data_shapes = [data_shape]

    learning_rate, lr_scheduler = get_lr_scheduler(learning_rate, lr_refactor_step,
        lr_refactor_ratio, num_example, batch_size, begin_epoch)
    optimizer_params = {'learning_rate': learning_rate,
                        'wd': weight_decay,
                        'lr_scheduler': lr_scheduler,
                        'clip_gradient': 4.0,
                        'rescale_grad': 1.0}
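    # momentum only applies to SGD-style optimizers; adaptive optimizers such as
    # nadam parameterize their own moments and do not accept a momentum argument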
    if optimizer_name in ('sgd', 'nag'):
        optimizer_params['momentum'] = momentum
    # optimizer_params={'learning_rate':learning_rate,
    #                   'momentum':momentum,
    #                   'wd':weight_decay,
    #                   'lr_scheduler':lr_scheduler,
    #                   'clip_gradient':10,
    #                   'rescale_grad': 1.0 }
    for begin, end in zip(fit_begins, fit_ends):
        if len(data_shapes) == 1:
            data_shape = data_shapes[0]
        else:
            data_shape = data_shapes[random.randint(0, len(data_shapes)-1)]
            logger.info("Setting random data shape: " + str(data_shape))

        train_iter = DetRecordIter(train_path, batch_size, data_shape, mean_pixels=mean_pixels,
            label_pad_width=label_pad_width, path_imglist=train_list, **cfg.train)

        if val_path:
            val_iter = DetRecordIter(val_path, batch_size, data_shape, mean_pixels=mean_pixels,
                label_pad_width=label_pad_width, path_imglist=val_list, **cfg.valid)
        else:
            val_iter = None

        # bind the module and set its parameters on the first pass only
        if not mod.binded:
            mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
            mod = set_mod_params(mod, args, auxs, logger)

        mod.fit(train_iter,
                val_iter,
                eval_metric=MultiBoxMetric(),
                validation_metric=valid_metric,
                batch_end_callback=batch_end_callback,
                epoch_end_callback=epoch_end_callback,
                optimizer=optimizer_name,
                optimizer_params=optimizer_params,
                begin_epoch=begin,
                num_epoch=end,
                initializer=mx.init.Xavier(),
                arg_params=args,
                aux_params=auxs,
                allow_missing=allow_missing,
                monitor=monitor,
                force_rebind=True,
                force_init=True)

        args, auxs = mod.get_params()
        allow_missing = False
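
# Usage sketch (illustrative only): every path, class name, and hyper-parameter
# below is a placeholder chosen for illustration, not a value taken from this
# repository.
#
#     train_net('darknet19_yolo', 'data/train.rec', num_classes=20, batch_size=32,
#               data_shape=416, mean_pixels=(123, 117, 104),
#               resume=-1, finetune=-1, pretrained='model/darknet19', epoch=0,
#               prefix='model/yolo', ctx=[mx.gpu(0)], begin_epoch=0, end_epoch=240,
#               frequent=20, learning_rate=0.001, momentum=0.9, weight_decay=0.0005,
#               lr_refactor_step='80,160', lr_refactor_ratio=0.1,
#               val_path='data/val.rec', class_names=['aeroplane', 'bicycle'],
#               voc07_metric=True, log_file='train.log')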