def test_decay_count_smaller_count(self):
  """Checks the cosine decay schedule across its entire decay window.

  The schedule is built with ``alpha=0.0`` and ``decay_steps=10``, then
  sampled at counts 0..9. The test expects the value at count ``t`` to be
  ``initial_value * 0.5 * (1 + cos(pi * t / 10))``.
  """
  initial_value = 0.1
  schedule_fn = self.variant(
      schedule.cosine_decay_schedule(initial_value, 10, 0.0))

  # Sample the schedule once per count, in increasing order.
  observed = np.array([schedule_fn(step) for step in range(10)])

  # Fraction of the decay window elapsed at each count: 0.0, 0.1, ..., 0.9.
  progress = np.arange(10) / 10
  cosine_factor = 0.5 + 0.5 * np.cos(np.pi * progress)

  np.testing.assert_allclose(
      initial_value * cosine_factor, observed, atol=1e-3)
def test_decay_count_greater_count_with_alpha(self):
  """Checks cosine decay with a non-zero floor, sampled past the window.

  The schedule is built with ``decay_steps=5`` and ``alpha=0.1``, then
  sampled at counts 0..11. The test expects the cosine fraction to be
  clamped at 1.0 once the count reaches 5. It also expects the multiplier
  to be rescaled as ``0.9 * cosine + 0.1``, so the value bottoms out at
  ``alpha * initial_value``.
  """
  initial_value = 0.1
  schedule_fn = self.variant(
      schedule.cosine_decay_schedule(initial_value, 5, 0.1))

  # Sample the schedule once per count, in increasing order.
  observed = np.array([schedule_fn(step) for step in range(12)])

  # Elapsed fraction of the 5-step window, held at 1.0 after it ends:
  # 0.0, 0.2, 0.4, 0.6, 0.8, then 1.0 for the remaining seven counts.
  progress = np.minimum(np.arange(12) / 5, 1.0)
  cosine_factor = 0.5 + 0.5 * np.cos(np.pi * progress)
  # Interpolate between the alpha floor (0.1) and 1.0.
  floored_factor = 0.9 * cosine_factor + 0.1

  np.testing.assert_allclose(
      initial_value * floored_factor, observed, atol=1e-3)