# Imports assumed from the surrounding module (a TensorFlow Models-style
# image classification package); exact paths may differ in other checkouts.
from absl import logging
import tensorflow as tf

from official.vision.image_classification import learning_rate
from official.vision.image_classification.configs import base_configs


def build_learning_rate(params: base_configs.LearningRateConfig,
                        batch_size: int = None,
                        train_epochs: int = None,
                        train_steps: int = None):
  """Build the learning rate given the provided configuration."""
  decay_type = params.name
  base_lr = params.initial_lr
  decay_rate = params.decay_rate
  if params.decay_epochs is not None:
    decay_steps = params.decay_epochs * train_steps
  else:
    decay_steps = 0
  if params.warmup_epochs is not None:
    warmup_steps = params.warmup_epochs * train_steps
  else:
    warmup_steps = 0

  lr_multiplier = params.scale_by_batch_size

  if lr_multiplier and lr_multiplier > 0:
    # Scale the learning rate based on the batch size and a multiplier.
    base_lr *= lr_multiplier * batch_size
    logging.info(
        'Scaling the learning rate based on the batch size '
        'multiplier. New base_lr: %f', base_lr)

  if decay_type == 'exponential':
    logging.info(
        'Using exponential learning rate with: '
        'initial_learning_rate: %f, decay_steps: %d, '
        'decay_rate: %f', base_lr, decay_steps, decay_rate)
    lr = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=base_lr,
        decay_steps=decay_steps,
        decay_rate=decay_rate,
        staircase=params.staircase)
  elif decay_type == 'stepwise':
    steps_per_epoch = params.examples_per_epoch // batch_size
    boundaries = [boundary * steps_per_epoch for boundary in params.boundaries]
    multipliers = [batch_size * multiplier for multiplier in params.multipliers]
    logging.info(
        'Using stepwise learning rate. Parameters: '
        'boundaries: %s, values: %s', boundaries, multipliers)
    lr = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=boundaries, values=multipliers)
  elif decay_type == 'cosine_with_warmup':
    lr = learning_rate.CosineDecayWithWarmup(
        batch_size=batch_size,
        total_steps=train_epochs * train_steps,
        warmup_steps=warmup_steps)
  else:
    # Fail loudly instead of falling through and returning an unbound `lr`
    # (a NameError in the original control flow when no branch matched).
    raise ValueError('Unknown learning rate decay type: {}'.format(decay_type))

  # Cosine-with-warmup handles its own warmup; wrap every other schedule.
  if warmup_steps > 0 and decay_type != 'cosine_with_warmup':
    logging.info('Applying %d warmup steps to the learning rate', warmup_steps)
    lr = learning_rate.WarmupDecaySchedule(lr, warmup_steps, warmup_lr=base_lr)
  return lr
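
# --- Usage sketch (illustrative, not from the original source) ---
# A minimal example of how build_learning_rate might be invoked. It assumes
# LearningRateConfig accepts the fields the function reads above (name,
# initial_lr, decay_rate, decay_epochs, staircase); the specific values are
# hypothetical. Note that `train_steps` is consumed as steps-per-epoch by the
# multiplications in the function body.
lr_params = base_configs.LearningRateConfig(
    name='exponential',
    initial_lr=0.1,
    decay_rate=0.97,
    decay_epochs=2.4,
    staircase=True)
lr_schedule = build_learning_rate(
    lr_params, batch_size=256, train_epochs=90, train_steps=1000)
# The resulting schedule plugs directly into a Keras optimizer.
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)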

# Enclosing test case assumed for runnability; assertAllClose comes from
# tf.test.TestCase. The class name here is illustrative.
class LearningRateTests(tf.test.TestCase):

  def test_cosine_decay_with_warmup(self):
    """Basic computational test for cosine decay with warmup."""
    expected_lrs = [0.0, 0.1, 0.05, 0.0]
    lr = learning_rate.CosineDecayWithWarmup(
        batch_size=256, total_steps=3, warmup_steps=1)
    for step in [0, 1, 2, 3]:
      self.assertAllClose(lr(step), expected_lrs[step])
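
# --- Why those expected values hold (a sketch, not the library code) ---
# The numbers above are consistent with a linear warmup to a peak rate of 0.1
# (assuming the schedule scales a base rate of 0.1 by batch_size / 256, so
# batch_size=256 leaves it at 0.1), followed by a half-cosine decay to 0 over
# the remaining steps. The helper below is a hypothetical reference, not the
# CosineDecayWithWarmup implementation itself.
import math

def _expected_lr(step, peak_lr=0.1, warmup_steps=1, total_steps=3):
  if step < warmup_steps:
    # Linear warmup from 0 up to the peak rate.
    return peak_lr * step / warmup_steps
  # Half-cosine decay from the peak rate down to 0.
  progress = (step - warmup_steps) / (total_steps - warmup_steps)
  return peak_lr * 0.5 * (1 + math.cos(math.pi * progress))

assert all(
    math.isclose(_expected_lr(s), e, abs_tol=1e-9)
    for s, e in zip(range(4), [0.0, 0.1, 0.05, 0.0]))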