def get_learning_rate(lr_config, global_step):
    """
    Instantiate a learning rate operation given a configuration with learning
    rate, name, and parameters.

    :param lr_config: learning rate configuration
    :param global_step: global step `Tensor`
    :return: learning rate operation
    """
    lr = lr_config.rate
    name = lr_config.name
    if "exponential_decay" == name:
        decay = exponential_decay(lr, global_step, **lr_config.params)
    elif "inverse_time_decay" == name:
        decay = inverse_time_decay(lr, global_step, **lr_config.params)
    elif "vaswani" == name:
        decay = _transformer_learning_rate(lr_config, global_step)
    elif "bert" == name:
        decay = _bert_learning_rate(lr_config, global_step)
    elif "clr" == name:
        decay = cyclic_learning_rate(
            global_step,
            learning_rate=lr_config.rate,
            max_lr=lr_config.params.get('max_lr', 0.1),
            step_size=lr_config.steps_per_epoch * lr_config.params.get('step_size', 4))
    else:
        raise ValueError("Unknown learning rate schedule: {}".format(name))
    return decay
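
A minimal usage sketch, assuming TF1-style graph execution and that the decay
functions above come from tf.compat.v1.train; lr_config is modeled here as a
SimpleNamespace carrying the fields the function reads (rate, name, params,
steps_per_epoch), with illustrative values.

# Hypothetical usage sketch; the lr_config fields and values are assumptions,
# not part of the original module.
from types import SimpleNamespace

import tensorflow.compat.v1 as tf

lr_config = SimpleNamespace(
    rate=0.1,                    # base learning rate
    name="inverse_time_decay",   # selects the matching branch above
    params={"decay_steps": 10, "decay_rate": 0.96},  # forwarded as **lr_config.params
    steps_per_epoch=1000,        # only read by the "clr" branch
)

global_step = tf.train.get_or_create_global_step()
learning_rate = get_learning_rate(lr_config, global_step)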
def testStaircase(self):
    initial_lr = 0.1
    k = 10
    decay_rate = 0.96
    step = resource_variable_ops.ResourceVariable(0)
    decayed_lr = learning_rate_decay.inverse_time_decay(
        initial_lr, step, k, decay_rate, staircase=True)

    self.evaluate(variables.global_variables_initializer())
    for i in range(k + 1):
        expected = initial_lr / (1 + decay_rate * (i // k))
        self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
        self.evaluate(step.assign_add(1))

def testDecay(self):
    initial_lr = 0.1
    k = 10
    decay_rate = 0.96
    step = resource_variable_ops.ResourceVariable(0)
    decayed_lr = learning_rate_decay.inverse_time_decay(
        initial_lr, step, k, decay_rate)

    self.evaluate(variables.global_variables_initializer())
    for i in range(k + 1):
        expected = initial_lr / (1 + i / k * decay_rate)
        self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
        self.evaluate(step.assign_add(1))

def testDecay(self):
    initial_lr = 0.1
    k = 10
    decay_rate = 0.96
    step = state_ops.variable_op([], dtypes.int32)
    assign_step = state_ops.assign(step, 0)
    increment_step = state_ops.assign_add(step, 1)
    decayed_lr = learning_rate_decay.inverse_time_decay(
        initial_lr, step, k, decay_rate)
    with self.test_session():
        assign_step.op.run()
        for i in range(k + 1):
            expected = initial_lr / (1 + i / k * decay_rate)
            self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
            increment_step.op.run()

def testStaircase(self):
    initial_lr = 0.1
    k = 10
    decay_rate = 0.96
    step = gen_state_ops._variable(shape=[], dtype=dtypes.int32,
                                   name="step", container="", shared_name="")
    assign_step = state_ops.assign(step, 0)
    increment_step = state_ops.assign_add(step, 1)
    decayed_lr = learning_rate_decay.inverse_time_decay(
        initial_lr, step, k, decay_rate, staircase=True)
    with self.test_session():
        assign_step.op.run()
        for i in range(k + 1):
            expected = initial_lr / (1 + decay_rate * (i // k))
            self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
            increment_step.op.run()
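
For reference, the closed form these tests assert can be written out in plain
Python: continuous inverse time decay divides the rate by
1 + decay_rate * step / decay_steps, and staircase=True floors that ratio,
matching the i // k term in the testStaircase variants. A small sketch (the
function name here is illustrative, not from the library):

def inverse_time_decay_value(initial_lr, step, decay_steps, decay_rate,
                             staircase=False):
    # Mirrors the `expected` values computed in the tests above.
    progress = step // decay_steps if staircase else step / decay_steps
    return initial_lr / (1 + decay_rate * progress)

# Spot checks against the loop bounds used in the tests (k = 10):
assert inverse_time_decay_value(0.1, 0, 10, 0.96) == 0.1
assert abs(inverse_time_decay_value(0.1, 10, 10, 0.96) - 0.1 / 1.96) < 1e-12
assert inverse_time_decay_value(0.1, 9, 10, 0.96, staircase=True) == 0.1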

def apply_lr_decay(cfg, global_step):
    # Learning rate schedule
    if cfg.lr_decay is None:
        lr = cfg.lr
    elif cfg.lr_decay == 'exp':
        lr = exponential_decay(cfg.lr, global_step, cfg.decay_steps,
                               cfg.decay_rate, staircase=cfg.staircase)
    elif cfg.lr_decay == 'piecewise':
        lr = piecewise_constant(global_step, cfg.lr_boundaries, cfg.lr_values)
    elif cfg.lr_decay == 'polynomial':
        # Note: passes the staircase flag as polynomial_decay's `cycle` argument.
        lr = polynomial_decay(cfg.lr, global_step, cfg.decay_steps,
                              end_learning_rate=cfg.end_lr, power=cfg.power,
                              cycle=cfg.staircase)
    elif cfg.lr_decay == 'natural_exp':
        lr = natural_exp_decay(cfg.lr, global_step, cfg.decay_steps,
                               cfg.decay_rate, staircase=cfg.staircase)
    elif cfg.lr_decay == 'inverse_time':
        lr = inverse_time_decay(cfg.lr, global_step, cfg.decay_steps,
                                cfg.decay_rate, staircase=cfg.staircase)
    elif cfg.lr_decay == 'STN':
        # Decays by 0.5 ** (epoch / 50), where an "epoch" is cfg.decay_steps steps.
        epoch = tf.cast(global_step / cfg.decay_steps, tf.int32)
        lr = cfg.lr * tf.pow(0.5, tf.cast(epoch / 50, cfg._FLOATX))
    else:
        raise NotImplementedError()
    return lr
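
A hypothetical call site for apply_lr_decay, again sketched with a
SimpleNamespace; the attribute names follow the branches above, and the values
(and the tf import) are assumptions about the surrounding module.

# Hypothetical usage; cfg fields and values are illustrative only.
from types import SimpleNamespace

import tensorflow.compat.v1 as tf

cfg = SimpleNamespace(
    lr=0.01,
    lr_decay='exp',      # one of: None, 'exp', 'piecewise', 'polynomial',
                         # 'natural_exp', 'inverse_time', 'STN'
    decay_steps=1000,
    decay_rate=0.96,
    staircase=False,
)

global_step = tf.train.get_or_create_global_step()
lr = apply_lr_decay(cfg, global_step)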