def testDefaultInitValueWithExponentialDecay(self):
    """Checks exponential decay when no initial value is supplied.

    The expected result is the bare decay factor ``rate ** (step / steps)``,
    i.e. this test relies on the wrapper's default initial value being 1.0.
    """
    steps_per_decay = 10
    step = 5
    rate = 0.96
    decay_fn = decay_over_time_wrapper(
        configs.DecayConfig(steps_per_decay, rate))
    # Only the step is passed; the initial value is left to its default.
    actual = decay_fn(step)
    wanted = rate**(step / steps_per_decay)
    self.assertAllClose(actual, wanted, 1e-6)
def testExponentialDecay(self):
    """Checks exponential decay: ``init * rate ** (step / steps)``."""
    initial = 0.1
    steps_per_decay = 10
    step = 5
    rate = 0.96
    exp_config = configs.DecayConfig(steps_per_decay, rate)
    decay_fn = decay_over_time_wrapper(exp_config)
    actual = decay_fn(step, initial)
    wanted = initial * rate**(step / steps_per_decay)
    self.assertAllClose(actual, wanted, 1e-6)
def testBoundedDecay(self):
    """Checks that a configured minimum bounds the decayed value.

    The decay config receives ``lower_bound`` as its third positional
    argument, and the result is expected to equal that bound.
    """
    initial = 0.1
    lower_bound = 0.99
    steps_per_decay = 10
    step = 5
    rate = 0.96
    # NOTE(review): lower_bound (0.99) is already above initial (0.1), so the
    # bound wins no matter how much decay is applied. A bound between the
    # decayed value and the initial value would exercise clipping more sharply.
    bounded_config = configs.DecayConfig(steps_per_decay, rate, lower_bound)
    decay_fn = decay_over_time_wrapper(bounded_config)
    clipped = decay_fn(step, initial)
    self.assertAllClose(clipped, lower_bound, 1e-6)
def testInverseTimeDecay(self):
    """Checks inverse-time decay: ``init / (1 + rate * step / steps)``."""
    initial = 0.1
    steps_per_decay = 10
    step = 5
    rate = 0.9
    inverse_config = configs.DecayConfig(
        steps_per_decay,
        rate,
        decay_type=configs.DecayType.INVERSE_TIME_DECAY,
    )
    decay_fn = decay_over_time_wrapper(inverse_config)
    actual = decay_fn(step, initial)
    denominator = 1 + rate * step / steps_per_decay
    self.assertAllClose(actual, initial / denominator, 1e-6)
def testNaturalExpDecay(self):
    """Checks natural-exponential decay: ``init * exp(-rate * step / steps)``."""
    initial = 0.1
    steps_per_decay = 10
    step = 5
    rate = 0.9
    natural_exp_config = configs.DecayConfig(
        steps_per_decay,
        rate,
        decay_type=configs.DecayType.NATURAL_EXP_DECAY,
    )
    decay_fn = decay_over_time_wrapper(natural_exp_config)
    exponent = -rate * step / steps_per_decay
    wanted = initial * math.exp(exponent)
    self.assertAllClose(decay_fn(step, initial), wanted, 1e-6)