def testHalfWayWithEnd(self):
    """At step 5 of 10, expect the midpoint between the start and end rates."""
    initial_lr = 0.05
    final_lr = 0.001
    global_step = 5
    decayed = learning_rate_decay.polynomial_decay(
        initial_lr, global_step, 10, end_learning_rate=final_lr)
    # Halfway through the schedule the test expects the average of both rates.
    midpoint = (initial_lr + final_lr) * 0.5
    self.assertAllClose(self.evaluate(decayed), midpoint, 1e-6)
def testBeyondEnd(self):
    """At step 15 with 10 decay steps, expect the rate to equal the end rate."""
    initial_lr = 0.05
    final_lr = 0.001
    global_step = 15  # past the 10 decay steps passed below
    decayed = learning_rate_decay.polynomial_decay(
        initial_lr, global_step, 10, end_learning_rate=final_lr)
    self.assertAllClose(self.evaluate(decayed), final_lr, 1e-6)
def testEnd(self):
    """At step 10 of 10 decay steps, expect the rate to equal the end rate."""
    step = 10
    lr = 0.05
    end_lr = 0.001
    decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr)
    expected = end_lr
    # self.evaluate replaces the deprecated test_session()/Tensor.eval()
    # pairing, matching the sibling tests that already use it.
    self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
def testHalfWay(self):
    """At step 5 of 10 with a zero end rate, expect half the initial rate."""
    step = 5
    lr = 0.05
    end_lr = 0.0
    decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr)
    expected = lr * 0.5
    # self.evaluate replaces the deprecated test_session()/Tensor.eval()
    # pairing, matching the sibling tests that already use it.
    self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
def testBeginWithCycle(self):
    """At step 0 with cycle=True, expect the initial rate unchanged."""
    initial_lr = 0.001
    total_decay_steps = 10
    global_step = 0
    decayed = learning_rate_decay.polynomial_decay(
        initial_lr, global_step, total_decay_steps, cycle=True)
    self.assertAllClose(self.evaluate(decayed), initial_lr, 1e-6)
def testBeyondEndWithCycle(self):
    """At step 15 with cycle=True, expect a quarter of the span above end rate."""
    initial_lr = 0.05
    final_lr = 0.001
    global_step = 15  # past the 10 decay steps passed below
    decayed = learning_rate_decay.polynomial_decay(
        initial_lr, global_step, 10, final_lr, cycle=True)
    # The test expects 25% of (initial - final) to remain on top of final.
    want = (initial_lr - final_lr) * 0.25 + final_lr
    self.assertAllClose(self.evaluate(decayed), want, 1e-6)
def testBeyondEnd(self):
    """At step 15 with 10 decay steps, expect the rate to equal the end rate."""
    step = 15
    lr = 0.05
    end_lr = 0.001
    decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr)
    expected = end_lr
    # self.evaluate replaces the deprecated test_session()/Tensor.eval()
    # pairing, matching the sibling tests that already use it.
    self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
def testHalfWayWithEnd(self):
    """At step 5 of 10, expect the midpoint between the start and end rates."""
    step = 5
    lr = 0.05
    end_lr = 0.001
    decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr)
    expected = (lr + end_lr) * 0.5
    # self.evaluate replaces the deprecated test_session()/Tensor.eval()
    # pairing, matching the sibling tests that already use it.
    self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
def testHalfWay(self):
    """At step 5 of 10 with power=0.5 and zero end rate, expect lr * 0.5**power."""
    initial_lr = 0.05
    final_lr = 0.0
    exponent = 0.5
    global_step = 5
    decayed = learning_rate_decay.polynomial_decay(
        initial_lr, global_step, 10, final_lr, power=exponent)
    want = initial_lr * 0.5**exponent
    self.assertAllClose(self.evaluate(decayed), want, 1e-6)
def testBeyondEnd(self):
    """At step 15 with 10 decay steps and power=0.5, expect the end rate."""
    initial_lr = 0.05
    final_lr = 0.001
    exponent = 0.5
    global_step = 15  # past the 10 decay steps passed below
    decayed = learning_rate_decay.polynomial_decay(
        initial_lr, global_step, 10, final_lr, power=exponent)
    self.assertAllClose(self.evaluate(decayed), final_lr, 1e-6)
def testHalfWayWithEnd(self):
    """At step 5 of 10 with power=0.5, expect (lr - end) * 0.5**power + end."""
    initial_lr = 0.05
    final_lr = 0.001
    exponent = 0.5
    global_step = 5
    decayed = learning_rate_decay.polynomial_decay(
        initial_lr, global_step, 10, final_lr, power=exponent)
    want = (initial_lr - final_lr) * 0.5**exponent + final_lr
    self.assertAllClose(self.evaluate(decayed), want, 1e-6)
def testBeyondEndWithCycle(self):
    """At step 15 with cycle=True and power=0.5, expect 0.25**power of the span."""
    step = 15
    lr = 0.05
    end_lr = 0.001
    power = 0.5
    decayed_lr = learning_rate_decay.polynomial_decay(
        lr, step, 10, end_lr, power=power, cycle=True)
    # The test expects 0.25 ** power of (lr - end_lr) to remain above end_lr.
    expected = (lr - end_lr) * 0.25 ** power + end_lr
    # self.evaluate replaces the deprecated test_session()/Tensor.eval()
    # pairing, matching the sibling tests that already use it.
    self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
def testHalfWay(self):
    """At step 5 of 10 with power=0.5 and zero end rate, expect lr * 0.5**power."""
    step = 5
    lr = 0.05
    end_lr = 0.0
    power = 0.5
    decayed_lr = learning_rate_decay.polynomial_decay(
        lr, step, 10, end_lr, power=power)
    expected = lr * 0.5 ** power
    # self.evaluate replaces the deprecated test_session()/Tensor.eval()
    # pairing, matching the sibling tests that already use it.
    self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
def testHalfWayWithEnd(self):
    """At step 5 of 10 with power=0.5, expect (lr - end) * 0.5**power + end."""
    step = 5
    lr = 0.05
    end_lr = 0.001
    power = 0.5
    decayed_lr = learning_rate_decay.polynomial_decay(
        lr, step, 10, end_lr, power=power)
    expected = (lr - end_lr) * 0.5 ** power + end_lr
    # self.evaluate replaces the deprecated test_session()/Tensor.eval()
    # pairing, matching the sibling tests that already use it.
    self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
def apply_lr_decay(cfg, global_step):
    """Return the learning rate for the schedule named by ``cfg.lr_decay``.

    Args:
        cfg: Configuration object. ``cfg.lr_decay`` selects the schedule
            (``None``, ``'exp'``, ``'piecewise'``, ``'polynomial'``,
            ``'natural_exp'``, ``'inverse_time'`` or ``'STN'``). Which other
            attributes are read (``lr``, ``decay_steps``, ``decay_rate``,
            ``staircase``, ``lr_boundaries``, ``lr_values``, ``end_lr``,
            ``power``, ``_FLOATX``) depends on the schedule chosen.
        global_step: Training step forwarded to the selected schedule.

    Returns:
        ``cfg.lr`` unchanged when ``cfg.lr_decay`` is ``None``; otherwise the
        value produced by the selected schedule.

    Raises:
        NotImplementedError: If ``cfg.lr_decay`` is not one of the names above.
    """
    schedule = cfg.lr_decay

    if schedule is None:
        # No schedule configured: use the base learning rate as-is.
        return cfg.lr

    if schedule == 'exp':
        return exponential_decay(cfg.lr, global_step, cfg.decay_steps,
                                 cfg.decay_rate, staircase=cfg.staircase)

    if schedule == 'piecewise':
        return piecewise_constant(global_step, cfg.lr_boundaries,
                                  cfg.lr_values)

    if schedule == 'polynomial':
        # NOTE(review): ``cycle`` is driven by ``cfg.staircase`` here --
        # confirm this reuse of the flag is intentional.
        return polynomial_decay(cfg.lr, global_step, cfg.decay_steps,
                                end_learning_rate=cfg.end_lr,
                                power=cfg.power, cycle=cfg.staircase)

    if schedule == 'natural_exp':
        return natural_exp_decay(cfg.lr, global_step, cfg.decay_steps,
                                 cfg.decay_rate, staircase=cfg.staircase)

    if schedule == 'inverse_time':
        return inverse_time_decay(cfg.lr, global_step, cfg.decay_steps,
                                  cfg.decay_rate, staircase=cfg.staircase)

    if schedule == 'STN':
        # epoch = global_step / decay_steps cast to int32; the rate is then
        # cfg.lr * 0.5 ** (epoch / 50).
        epoch = tf.cast(global_step / cfg.decay_steps, tf.int32)
        return cfg.lr * tf.pow(0.5, tf.cast(epoch / 50, cfg._FLOATX))

    # Name the rejected value so a misconfiguration is easy to diagnose.
    raise NotImplementedError(
        'Unsupported lr_decay schedule: {!r}'.format(schedule))