Example #1
import ast

# `lsf` is assumed to be second.pytorch's fastai-style schedule helpers.
import torchplus.train.learning_schedules_fastai as lsf


def _create_learning_rate_scheduler(lrs_config, optimizer):
    """Create optimizer learning rate scheduler based on config.
  
    Args:
      lrs_config: A learning rate scheduler configparser object.
      optimizer: An associated optimizer
  
    Returns:
      A learning rate scheduler.
  
    Raises:
      ValueError: when using an unsupported input data type.
    """
    lrs_type = lrs_config['type']
    total_step = int(lrs_config['total_step'])
    lr_scheduler = None

    if lrs_type == 'one_cycle':
        lr_max = float(lrs_config['lr_max'])
        # Parse e.g. "[0.95, 0.85]" safely instead of using eval().
        moms = ast.literal_eval(lrs_config['moms'])
        div_factor = float(lrs_config['div_factor'])
        pct_start = float(lrs_config['pct_start'])
        lr_scheduler = lsf.OneCycle(optimizer, total_step, lr_max, moms,
                                    div_factor, pct_start)

    if lr_scheduler is None:
        raise ValueError('Learning rate scheduler %s is not supported.' %
                         lrs_type)

    return lr_scheduler
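
A minimal usage sketch for the configparser-based variant above. The section name, the numeric values, and the stand-in model/optimizer are illustrative assumptions; only the key names are taken from what the function reads. `lsf.OneCycle` is assumed to accept a plain torch.optim optimizer whose momentum it can drive via `betas`.

import configparser

import torch

# Hypothetical INI section; the keys mirror exactly what the function reads.
_CFG_TEXT = """
[lr_scheduler]
type = one_cycle
total_step = 29240
lr_max = 2.25e-3
moms = [0.95, 0.85]
div_factor = 10.0
pct_start = 0.4
"""

parser = configparser.ConfigParser()
parser.read_string(_CFG_TEXT)

model = torch.nn.Linear(4, 2)                     # hypothetical stand-in network
optimizer = torch.optim.Adam(model.parameters())  # Adam exposes `betas` for moms
scheduler = _create_learning_rate_scheduler(parser['lr_scheduler'], optimizer)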
Example #2

def _create_learning_rate_scheduler(learning_rate_config, optimizer,
                                    total_step):
    """Create optimizer learning rate scheduler based on config.

    Args:
      learning_rate_config: A LearningRate proto message.

    Returns:
      A learning rate.

    Raises:
      ValueError: when using an unsupported input data type.
    """
    # Dispatch on the oneof field set in the LearningRate message.
    learning_rate_type = learning_rate_config.WhichOneof('learning_rate')

    if learning_rate_type == 'multi_phase':
        config = learning_rate_config.multi_phase
        lr_phases = []
        mom_phases = []
        for phase_cfg in config.phases:
            lr_phases.append((phase_cfg.start, phase_cfg.lambda_func))
            mom_phases.append(
                (phase_cfg.start, phase_cfg.momentum_lambda_func))
        lr_scheduler = lsf.LRSchedulerStep(optimizer, total_step, lr_phases,
                                           mom_phases)
    elif learning_rate_type == 'one_cycle':
        config = learning_rate_config.one_cycle
        lr_scheduler = lsf.OneCycle(optimizer, total_step, config.lr_max,
                                    list(config.moms), config.div_factor,
                                    config.pct_start)
    elif learning_rate_type == 'exponential_decay':
        config = learning_rate_config.exponential_decay
        lr_scheduler = lsf.ExponentialDecay(optimizer, total_step,
                                            config.initial_learning_rate,
                                            config.decay_length,
                                            config.decay_factor,
                                            config.staircase)
    elif learning_rate_type == 'manual_stepping':
        config = learning_rate_config.manual_stepping
        lr_scheduler = lsf.ManualStepping(optimizer, total_step,
                                          list(config.boundaries),
                                          list(config.rates))
    else:
        raise ValueError('Learning rate scheduler %s is not supported.' %
                         learning_rate_type)

    return lr_scheduler
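
For the proto-based variant, a minimal calling sketch. In a real setup the config would be a compiled `LearningRate` protobuf message; because the function only calls `WhichOneof` and reads one sub-message, a hypothetical duck-typed stand-in (labeled as such below) is enough to illustrate the dispatch, and every field value is an illustrative assumption.

from types import SimpleNamespace

import torch


class _FakeLearningRate:
    """Hypothetical stand-in for a LearningRate proto message.

    Mimics only what the builder reads: WhichOneof('learning_rate') and
    the matching sub-config attribute.
    """

    def __init__(self, **oneof):
        (self._which, config), = oneof.items()
        setattr(self, self._which, config)

    def WhichOneof(self, _name):
        return self._which


# Illustrative values; decay_length is assumed to be a fraction of total_step.
exp_decay = SimpleNamespace(initial_learning_rate=2e-4, decay_length=0.1,
                            decay_factor=0.8, staircase=True)
model = torch.nn.Linear(4, 2)  # hypothetical stand-in network
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
scheduler = _create_learning_rate_scheduler(
    _FakeLearningRate(exponential_decay=exp_decay), optimizer,
    total_step=29240)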