Example #1
import os
import logging

from mxnet import gluon
# NOTE: LRScheduler is assumed to be provided by the surrounding project
# (for example a local lr_scheduler module or gluoncv.utils); adjust the
# import to wherever the class actually lives.


def prepare_trainer(net,
                    optimizer_name,
                    wd,
                    momentum,
                    lr_mode,
                    lr,
                    lr_decay_period,
                    lr_decay_epoch,
                    lr_decay,
                    target_lr,
                    poly_power,
                    warmup_epochs,
                    warmup_lr,
                    warmup_mode,
                    batch_size,
                    num_epochs,
                    num_training_samples,
                    dtype,
                    state_file_path=None):

    # Build the list of epochs at which the learning rate decays: either at a
    # fixed period or at an explicit comma-separated list of epoch indices.
    if lr_decay_period > 0:
        lr_decay_epoch = list(range(lr_decay_period, num_epochs, lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in lr_decay_epoch.split(',')]
    num_batches = num_training_samples // batch_size
    lr_scheduler = LRScheduler(
        mode=lr_mode,
        base_lr=lr,
        n_iters=num_batches,
        n_epochs=num_epochs,
        step=lr_decay_epoch,
        step_factor=lr_decay,
        target_lr=target_lr,
        power=poly_power,
        warmup_epochs=warmup_epochs,
        warmup_lr=warmup_lr,
        warmup_mode=warmup_mode)

    optimizer_params = {'learning_rate': lr,
                        'wd': wd,
                        'momentum': momentum,
                        'lr_scheduler': lr_scheduler}
    if dtype != 'float32':
        # Keep a float32 master copy of the weights when training in a lower precision.
        optimizer_params['multi_precision'] = True

    trainer = gluon.Trainer(
        params=net.collect_params(),
        optimizer=optimizer_name,
        optimizer_params=optimizer_params)

    # Optionally resume from a previously saved trainer state.
    if state_file_path and os.path.exists(state_file_path):
        logging.info('Loading trainer states: {}'.format(state_file_path))
        trainer.load_states(state_file_path)
        if trainer._optimizer.wd != wd:
            trainer._optimizer.wd = wd
            logging.info('Reset the weight decay: {}'.format(wd))
        # Use the freshly built scheduler instead of the one restored from the state file.
        trainer._optimizer.lr_scheduler = lr_scheduler

    return trainer, lr_scheduler
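
A minimal sketch of how this helper might be invoked, assuming the calling script has already built a Gluon net and parsed its hyperparameters; every concrete value below is an illustrative placeholder, not a setting taken from the original script:

# Hypothetical call; all values are placeholders chosen only for illustration.
trainer, lr_scheduler = prepare_trainer(
    net=net,
    optimizer_name="nag",
    wd=1e-4,
    momentum=0.9,
    lr_mode="cosine",
    lr=0.1,
    lr_decay_period=0,
    lr_decay_epoch="40,60",
    lr_decay=0.1,
    target_lr=1e-8,
    poly_power=2.0,
    warmup_epochs=5,
    warmup_lr=1e-8,
    warmup_mode="linear",
    batch_size=256,
    num_epochs=120,
    num_training_samples=1281167,
    dtype="float32",
    state_file_path=None)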
Example #2
import os
import logging

from mxnet import gluon
# LRScheduler is assumed to come from the project's utilities, as in Example #1.


def prepare_trainer(net,
                    optimizer_name,
                    wd,
                    momentum,
                    lr_mode,
                    lr,
                    lr_decay_period,
                    lr_decay_epoch,
                    lr_decay,
                    target_lr,
                    poly_power,
                    warmup_epochs,
                    warmup_lr,
                    warmup_mode,
                    batch_size,
                    num_epochs,
                    num_training_samples,
                    dtype,
                    gamma_wd_mult=1.0,
                    beta_wd_mult=1.0,
                    bias_wd_mult=1.0,
                    state_file_path=None):
    """
    Prepare trainer.

    Parameters:
    ----------
    net : HybridBlock
        Model.
    optimizer_name : str
        Name of optimizer.
    wd : float
        Weight decay rate.
    momentum : float
        Momentum value.
    lr_mode : str
        Learning rate scheduler mode.
    lr : float
        Learning rate.
    lr_decay_period : int
        Interval for periodic learning rate decays.
    lr_decay_epoch : str
        Epochs at which the learning rate decays.
    lr_decay : float
        Decay rate of learning rate.
    target_lr : float
        Final learning rate.
    poly_power : float
        Power value for poly LR scheduler.
    warmup_epochs : int
        Number of warmup epochs.
    warmup_lr : float
        Starting warmup learning rate.
    warmup_mode : str
        Learning rate scheduler warmup mode.
    batch_size : int
        Training batch size.
    num_epochs : int
        Number of training epochs.
    num_training_samples : int
        Number of training samples in dataset.
    dtype : str
        Base data type for tensors.
    gamma_wd_mult : float
        Weight decay multiplier for batchnorm gamma.
    beta_wd_mult : float
        Weight decay multiplier for batchnorm beta.
    bias_wd_mult : float
        Weight decay multiplier for bias.
    state_file_path : str, default None
        Path for file with trainer state.

    Returns:
    -------
    Trainer
        Trainer.
    LRScheduler
        Learning rate scheduler.
    """
    if gamma_wd_mult != 1.0:
        for k, v in net.collect_params(".*gamma").items():
            v.wd_mult = gamma_wd_mult

    if beta_wd_mult != 1.0:
        for k, v in net.collect_params(".*beta").items():
            v.wd_mult = beta_wd_mult

    if bias_wd_mult != 1.0:
        for k, v in net.collect_params(".*bias").items():
            v.wd_mult = bias_wd_mult

    if lr_decay_period > 0:
        lr_decay_epoch = list(
            range(lr_decay_period, num_epochs, lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in lr_decay_epoch.split(",")]
    num_batches = num_training_samples // batch_size
    lr_scheduler = LRScheduler(mode=lr_mode,
                               base_lr=lr,
                               n_iters=num_batches,
                               n_epochs=num_epochs,
                               step=lr_decay_epoch,
                               step_factor=lr_decay,
                               target_lr=target_lr,
                               power=poly_power,
                               warmup_epochs=warmup_epochs,
                               warmup_lr=warmup_lr,
                               warmup_mode=warmup_mode)

    optimizer_params = {
        "learning_rate": lr,
        "wd": wd,
        "momentum": momentum,
        "lr_scheduler": lr_scheduler
    }
    if dtype != "float32":
        optimizer_params["multi_precision"] = True

    trainer = gluon.Trainer(params=net.collect_params(),
                            optimizer=optimizer_name,
                            optimizer_params=optimizer_params)

    # Optionally resume from a previously saved trainer state.
    if state_file_path and os.path.exists(state_file_path):
        logging.info("Loading trainer states: {}".format(state_file_path))
        trainer.load_states(state_file_path)
        if trainer._optimizer.wd != wd:
            trainer._optimizer.wd = wd
            logging.info("Reset the weight decay: {}".format(wd))
        # Use the freshly built scheduler instead of the one restored from the state file.
        trainer._optimizer.lr_scheduler = lr_scheduler

    return trainer, lr_scheduler
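
Example #2 extends Example #1 with per-parameter weight-decay multipliers: collect_params accepts a regular expression over parameter names, so batchnorm gamma/beta and bias parameters can be given their own wd_mult. A small sketch of that mechanism, assuming net is a Gluon HybridBlock (the multiplier value is illustrative):

# Select parameters by name via regex and scale their weight decay; a wd_mult
# of 0.0 effectively excludes the matched parameters from weight decay.
for name, param in net.collect_params(".*gamma").items():
    param.wd_mult = 0.0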
Example #3
import os
import logging

from mxnet import gluon
# LRScheduler is assumed to come from the project's utilities, as in Example #1.


def prepare_trainer(net,
                    optimizer_name,
                    wd,
                    momentum,
                    lr_mode,
                    lr,
                    lr_decay_period,
                    lr_decay_epoch,
                    lr_decay,
                    target_lr,
                    poly_power,
                    warmup_epochs,
                    warmup_lr,
                    warmup_mode,
                    batch_size,
                    num_epochs,
                    num_training_samples,
                    dtype,
                    gamma_wd_mult=1.0,
                    beta_wd_mult=1.0,
                    bias_wd_mult=1.0,
                    state_file_path=None):

    # Optionally rescale (typically disable) weight decay for batchnorm gamma/beta
    # and bias parameters, selected by a regex over the parameter names.
    if gamma_wd_mult != 1.0:
        for k, v in net.collect_params(".*gamma").items():
            v.wd_mult = gamma_wd_mult

    if beta_wd_mult != 1.0:
        for k, v in net.collect_params(".*beta").items():
            v.wd_mult = beta_wd_mult

    if bias_wd_mult != 1.0:
        for k, v in net.collect_params(".*bias").items():
            v.wd_mult = bias_wd_mult

    if lr_decay_period > 0:
        lr_decay_epoch = list(
            range(lr_decay_period, num_epochs, lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in lr_decay_epoch.split(",")]
    num_batches = num_training_samples // batch_size
    lr_scheduler = LRScheduler(mode=lr_mode,
                               base_lr=lr,
                               n_iters=num_batches,
                               n_epochs=num_epochs,
                               step=lr_decay_epoch,
                               step_factor=lr_decay,
                               target_lr=target_lr,
                               power=poly_power,
                               warmup_epochs=warmup_epochs,
                               warmup_lr=warmup_lr,
                               warmup_mode=warmup_mode)

    optimizer_params = {
        "learning_rate": lr,
        "wd": wd,
        "momentum": momentum,
        "lr_scheduler": lr_scheduler
    }
    if dtype != "float32":
        optimizer_params["multi_precision"] = True

    trainer = gluon.Trainer(params=net.collect_params(),
                            optimizer=optimizer_name,
                            optimizer_params=optimizer_params)

    # Optionally resume from a previously saved trainer state.
    if state_file_path and os.path.exists(state_file_path):
        logging.info("Loading trainer states: {}".format(state_file_path))
        trainer.load_states(state_file_path)
        if trainer._optimizer.wd != wd:
            trainer._optimizer.wd = wd
            logging.info("Reset the weight decay: {}".format(wd))
        # Use the freshly built scheduler instead of the one restored from the state file.
        trainer._optimizer.lr_scheduler = lr_scheduler

    return trainer, lr_scheduler
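
For context, a hedged fragment of a training loop that could consume the returned trainer; train_data, loss_fn, ctx and the hyperparameters are assumed to exist in the surrounding script, and whether the scheduler advances automatically through the optimizer or needs an explicit per-iteration update depends on the LRScheduler implementation the project actually uses:

from mxnet import autograd

# Hypothetical loop; train_data, loss_fn and ctx are assumptions, not part of the examples above.
for epoch in range(num_epochs):
    for data, label in train_data:
        data = data.as_in_context(ctx)
        label = label.as_in_context(ctx)
        with autograd.record():
            loss = loss_fn(net(data), label)
        loss.backward()
        trainer.step(batch_size)
    if state_file_path:
        trainer.save_states(state_file_path)  # lets a later run resume via prepare_trainer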