def _create_learning_rate_scheduler(lrs_config, optimizer): """Create optimizer learning rate scheduler based on config. Args: lrs_config: A learning rate scheduler configparser object. optimizer: An associated optimizer Returns: A learning rate scheduler. Raises: ValueError: when using an unsupported input data type. """ lrs_type = lrs_config['type'] total_step = int(lrs_config['total_step']) lr_scheduler = None if lrs_type == 'one_cycle': lr_max = float(lrs_config['lr_max']) moms = eval(lrs_config['moms']) div_factor = float(lrs_config['div_factor']) pct_start = float(lrs_config['pct_start']) lr_scheduler = lsf.OneCycle(optimizer, total_step, lr_max, moms, div_factor, pct_start) if lr_scheduler is None: raise ValueError('Learning_rate %s not supported.' % lrs_type) return lr_scheduler
def _create_learning_rate_scheduler(learning_rate_config, optimizer, total_step): """Create optimizer learning rate scheduler based on config. Args: learning_rate_config: A LearningRate proto message. Returns: A learning rate. Raises: ValueError: when using an unsupported input data type. """ lr_scheduler = None learning_rate_type = learning_rate_config.WhichOneof('learning_rate') if learning_rate_type == 'multi_phase': config = learning_rate_config.multi_phase lr_phases = [] mom_phases = [] for phase_cfg in config.phases: lr_phases.append((phase_cfg.start, phase_cfg.lambda_func)) mom_phases.append( (phase_cfg.start, phase_cfg.momentum_lambda_func)) lr_scheduler = lsf.LRSchedulerStep(optimizer, total_step, lr_phases, mom_phases) if learning_rate_type == 'one_cycle': config = learning_rate_config.one_cycle lr_scheduler = lsf.OneCycle(optimizer, total_step, config.lr_max, list(config.moms), config.div_factor, config.pct_start) if learning_rate_type == 'exponential_decay': config = learning_rate_config.exponential_decay lr_scheduler = lsf.ExponentialDecay(optimizer, total_step, config.initial_learning_rate, config.decay_length, config.decay_factor, config.staircase) if learning_rate_type == 'manual_stepping': config = learning_rate_config.manual_stepping lr_scheduler = lsf.ManualStepping(optimizer, total_step, list(config.boundaries), list(config.rates)) if lr_scheduler is None: raise ValueError('Learning_rate %s not supported.' % learning_rate_type) return lr_scheduler