def test_zero_steps_schedule(self, variant):
    """Check that a non-positive number of steps yields a constant schedule."""
    initial_value = 10.
    final_value = 20.
    # With zero or negative transition steps the schedule is expected to stay
    # at its initial value for every count, never reaching `final_value`.
    for num_steps in [-1, 0]:
        schedule_fn = schedules.polynomial_schedule(
            initial_value, final_value, 1, num_steps)
        schedule_fn = variant(schedule_fn)
        for count in range(15):
            np.testing.assert_allclose(schedule_fn(count), initial_value)
def test_linear(self, variant):
    """Check linear schedule."""
    # Power 1 over 10 steps: expected to go 10 -> 19 in unit increments,
    # then hold at the final value 20 for the remaining counts.
    schedule_fn = schedules.polynomial_schedule(10., 20., 1, 10)
    schedule_fn = variant(schedule_fn)
    generated_vals = np.array([schedule_fn(count) for count in range(15)])
    expected_vals = np.array(
        list(range(10, 20)) + [20] * 5, dtype=np.float32)
    # assert_allclose's signature is (actual, desired).
    np.testing.assert_allclose(generated_vals, expected_vals, atol=1e-3)
def test_nonlinear(self, variant):
    """Check non-linear (quadratic) schedule."""
    # Power 2 over 10 steps from 25 down to 10, then held at 10.
    schedule_fn = schedules.polynomial_schedule(25., 10., 2, 10)
    schedule_fn = variant(schedule_fn)
    generated_vals = np.array([schedule_fn(count) for count in range(15)])
    expected_vals = np.array(
        [10. + 15. * (1. - n / 10)**2 for n in range(10)] + [10] * 5,
        dtype=np.float32)
    # assert_allclose's signature is (actual, desired).
    np.testing.assert_allclose(generated_vals, expected_vals, atol=1e-3)
def test_with_decay_begin(self, variant):
    """Check quadratic schedule with non-zero schedule begin."""
    # Held at 30 for the first 4 counts (transition_begin=4), then a
    # quadratic decay to 10 over 10 steps, then held at 10.
    schedule_fn = schedules.polynomial_schedule(
        30., 10., 2, 10, transition_begin=4)
    schedule_fn = variant(schedule_fn)
    generated_vals = np.array([schedule_fn(count) for count in range(20)])
    expected_vals = np.array(
        [30.] * 4
        + [10. + 20. * (1. - n / 10)**2 for n in range(10)]
        + [10] * 6,
        dtype=np.float32)
    # assert_allclose's signature is (actual, desired).
    np.testing.assert_allclose(generated_vals, expected_vals, atol=1e-3)