def test_exponential(self, staircase, transition_begin):
    """Checks the exponential schedule against a NumPy reference.

    The expected schedule holds `init_value` for the first
    `transition_begin` steps and then follows
    `init_value * decay_rate ** p` with `p = count / transition_steps`,
    where `p` is floored when `staircase` is truthy.

    Args:
      staircase: whether the exponent is floored to a whole number.
      transition_begin: number of leading steps expected to stay at
        `init_value`; must be an int because it is used as a list repeat
        count and added to the step range.
    """
    # Get schedule function.
    init_value = 1.
    num_steps = 15
    transition_steps = 2
    decay_rate = 2.
    schedule_fn = self.variant(
        schedule.exponential_decay(
            init_value=init_value,
            transition_steps=transition_steps,
            decay_rate=decay_rate,
            transition_begin=transition_begin,
            staircase=staircase))

    def _staircased(count):
        """Returns the exponent for `count`, floored in staircase mode."""
        p = count / transition_steps
        if staircase:
            p = np.floor(p)
        return p

    # Test that generated values equal the expected schedule values.
    generated_vals = [
        schedule_fn(count) for count in range(num_steps + transition_begin)
    ]
    expected_vals = np.array(
        [init_value] * transition_begin + [
            init_value * np.power(decay_rate, _staircased(count))
            for count in range(num_steps)
        ],
        dtype=np.float32)
    # assert_allclose takes (actual, desired): pass the generated values
    # first so that failure reports label the two sides correctly.
    np.testing.assert_allclose(
        np.array(generated_vals), expected_vals, atol=1e-3)
def test_nonvalid_decay_rate(self, staircase):
    """Checks that a zero decay rate results in a constant schedule.

    Args:
      staircase: forwarded unchanged to `schedule.exponential_decay`.
    """
    init_value = 1.
    decay_schedule = schedule.exponential_decay(
        init_value=init_value,
        transition_steps=2,
        decay_rate=0.,
        staircase=staircase)
    schedule_fn = self.variant(decay_schedule)

    # Every one of the first 15 steps must return the initial value.
    for step in range(15):
        np.testing.assert_allclose(schedule_fn(step), init_value)
def test_constant_schedule(self, staircase):
    """Checks that a decay rate of 1 yields a constant schedule.

    Args:
      staircase: forwarded unchanged to `schedule.exponential_decay`.
    """
    num_steps = 15
    # Get schedule function.
    init_value = 1.
    schedule_fn = self.variant(
        schedule.exponential_decay(
            init_value=init_value,
            transition_steps=num_steps,
            decay_rate=1.,
            staircase=staircase))

    # Test that generated values equal the expected schedule values.
    generated_vals = [schedule_fn(count) for count in range(num_steps)]
    expected_vals = np.array([init_value] * num_steps, dtype=np.float32)
    # assert_allclose takes (actual, desired): pass the generated values
    # first so that failure reports label the two sides correctly.
    np.testing.assert_allclose(
        np.array(generated_vals), expected_vals, atol=1e-3)
def test_immutable_count(self):
    """Checks the schedule when the count is passed as a JAX array.

    With `init_value=32`, `decay_rate=0.5` and `transition_steps=1`, the
    first five steps are expected to be 32, 16, 8, 4 and 2.
    """
    num_steps = 5
    # Get schedule function.
    init_value = 32.
    schedule_fn = self.variant(
        schedule.exponential_decay(
            init_value=init_value,
            transition_steps=1,
            decay_rate=0.5))

    # Test that generated values equal the expected schedule values.
    generated_vals = []
    for count in range(num_steps):
        # Jax arrays are read-only in ChexVariantType.WITHOUT_DEVICE.
        immutable_count = jnp.array(count, dtype=jnp.float32)
        generated_vals.append(schedule_fn(immutable_count))
    expected_vals = np.array([32, 16, 8, 4, 2], dtype=np.float32)
    # assert_allclose takes (actual, desired): pass the generated values
    # first so that failure reports label the two sides correctly.
    np.testing.assert_allclose(
        np.array(generated_vals), expected_vals, atol=1e-3)
def test_end_value_with_staircase(self, decay_rate, end_value, staircase):
    """Checks that `end_value` bounds the exponential schedule.

    The reference is the unclipped exponential schedule (constant at
    `init_value` for `transition_begin` steps, then
    `init_value * decay_rate ** p`), which is then bounded below by
    `end_value` when `decay_rate < 1.0` and bounded above by it otherwise.

    Args:
      decay_rate: base of the exponential; also selects the clip direction.
      end_value: bound applied to the expected schedule values.
      staircase: whether the exponent is floored to a whole number.
    """
    # Get schedule function.
    init_value = 1.
    num_steps = 11
    transition_steps = 2
    transition_begin = 3
    schedule_fn = self.variant(
        schedule.exponential_decay(
            init_value=init_value,
            transition_steps=transition_steps,
            decay_rate=decay_rate,
            transition_begin=transition_begin,
            staircase=staircase,
            end_value=end_value))

    def _staircased(count):
        """Returns the exponent for `count`, floored in staircase mode."""
        p = count / transition_steps
        if staircase:
            p = np.floor(p)
        return p

    # Test that generated values equal the expected schedule values.
    generated_vals = [
        schedule_fn(count) for count in range(num_steps + transition_begin)
    ]
    expected_vals = np.array(
        [init_value] * transition_begin + [
            init_value * np.power(decay_rate, _staircased(count))
            for count in range(num_steps)
        ],
        dtype=np.float32)

    # A decaying schedule is bounded below by end_value; otherwise
    # end_value acts as an upper bound.
    if decay_rate < 1.0:
        expected_vals = np.maximum(expected_vals, end_value)
    else:
        expected_vals = np.minimum(expected_vals, end_value)

    # assert_allclose takes (actual, desired): pass the generated values
    # first so that failure reports label the two sides correctly.
    np.testing.assert_allclose(
        np.array(generated_vals), expected_vals, atol=1e-3)