def test_exp_schedule(backend):
    """Verify ExpSchedule yields lr_init / (1 + decay * epoch) for epochs 0..9."""
    base_rate = 0.1
    decay_rate = 0.01
    schedule = ExpSchedule(decay_rate)
    for ep in range(10):
        # Closed-form expected value for the inverse-time decay checked here.
        expected = base_rate / (1. + decay_rate * ep)
        actual = schedule.get_learning_rate(learning_rate=base_rate, epoch=ep)
        assert np.allclose(actual, expected)
def get_learning_rate(self, learning_rate, epoch):
    """Return the learning rate computed by ExpSchedule for a snapped epoch.

    The epoch is first passed through ``round_to(epoch, self.epoch_freq)``,
    and the result is handed to ``ExpSchedule.get_learning_rate``. The exact
    rounding rule depends on ``round_to``, which is defined elsewhere.
    Presumably it snaps to a multiple of ``epoch_freq``; TODO confirm.

    Args:
        learning_rate: base learning rate forwarded unchanged.
        epoch: current epoch, snapped before delegation.

    Returns:
        Whatever ``ExpSchedule.get_learning_rate`` returns for the snapped epoch.
    """
    snapped_epoch = round_to(epoch, self.epoch_freq)
    # Explicit base-class call (not super()) to keep the original dispatch target.
    return ExpSchedule.get_learning_rate(self, learning_rate, snapped_epoch)