コード例 #1
0
ファイル: schedule_test.py プロジェクト: stjordanis/optax
    def test_exponential(self, staircase, transition_begin):
        """Checks the exponential schedule against a NumPy reference.

        The schedule is expected to hold ``init_value`` for the first
        ``transition_begin`` steps and afterwards follow
        ``init_value * decay_rate ** p`` with ``p = k / transition_steps``
        (floored when ``staircase`` is set), where ``k`` counts the steps
        taken since ``transition_begin``.
        """
        # Get schedule function.
        init_value = 1.
        num_steps = 15
        transition_steps = 2
        # decay_rate > 1, so this schedule grows rather than decays.
        decay_rate = 2.
        schedule_fn = self.variant(
            schedule.exponential_decay(init_value=init_value,
                                       transition_steps=transition_steps,
                                       decay_rate=decay_rate,
                                       transition_begin=transition_begin,
                                       staircase=staircase))

        # Test that generated values equal the expected schedule values.
        def _staircased(count):
            # Exponent for `count` steps past transition_begin; rounded down
            # to whole transition periods in staircase mode.
            p = count / transition_steps
            if staircase:
                p = np.floor(p)
            return p

        generated_vals = [
            schedule_fn(count) for count in range(num_steps + transition_begin)
        ]
        # Constant prefix covering the delay, then the curve indexed from 0.
        expected_vals = np.array(
            [init_value] * transition_begin + [
                init_value * np.power(decay_rate, _staircased(count))
                for count in range(num_steps)
            ],
            dtype=np.float32)
        np.testing.assert_allclose(expected_vals,
                                   np.array(generated_vals),
                                   atol=1e-3)
コード例 #2
0
ファイル: schedule_test.py プロジェクト: stjordanis/optax
 def test_nonvalid_decay_rate(self, staircase):
     """Checks a non-valid decay rate (0) results in a constant schedule.

     With ``decay_rate=0.`` every queried step must evaluate to
     ``init_value``, with or without ``staircase``.
     """
     init_value = 1.
     schedule_fn = self.variant(
         schedule.exponential_decay(init_value=init_value,
                                    transition_steps=2,
                                    decay_rate=0.,
                                    staircase=staircase))
     # Probe steps spanning several transition periods (transition_steps=2).
     for count in range(15):
         np.testing.assert_allclose(schedule_fn(count), init_value)
コード例 #3
0
ファイル: schedule_test.py プロジェクト: stjordanis/optax
 def test_constant_schedule(self, staircase):
     """Checks that a decay rate of 1 yields a flat exponential schedule.

     Every step in ``[0, num_steps)`` must evaluate to ``init_value``.
     """
     num_steps = 15
     init_value = 1.
     # decay_rate=1 means init_value * 1 ** p == init_value for any exponent.
     schedule_fn = self.variant(
         schedule.exponential_decay(init_value=init_value,
                                    transition_steps=num_steps,
                                    decay_rate=1.,
                                    staircase=staircase))
     actual = np.array([schedule_fn(step) for step in range(num_steps)])
     expected = np.full(num_steps, init_value, dtype=np.float32)
     np.testing.assert_allclose(expected, actual, atol=1e-3)
コード例 #4
0
 def test_immutable_count(self):
     """Checks the schedule accepts a read-only JAX array as the step count.

     Starting at 32 and halving every step (``transition_steps=1``,
     ``decay_rate=0.5``) must give 32, 16, 8, 4, 2 when each count is passed
     as a ``jnp`` float32 scalar instead of a Python int.
     """
     num_steps = 5
     # Get schedule function.
     init_value = 32.
     schedule_fn = self.variant(
         schedule.exponential_decay(init_value=init_value,
                                    transition_steps=1,
                                    decay_rate=0.5))
     # Test that generated values equal the expected schedule values.
     generated_vals = []
     for count in range(num_steps):
         # Jax arrays are read-only in ChexVariantType.WITHOUT_DEVICE.
         immutable_count = jnp.array(count, dtype=jnp.float32)
         generated_vals.append(schedule_fn(immutable_count))
     # init_value * 0.5 ** count for count = 0..num_steps-1.
     expected_vals = np.array([32, 16, 8, 4, 2], dtype=np.float32)
     np.testing.assert_allclose(expected_vals,
                                np.array(generated_vals),
                                atol=1e-3)
コード例 #5
0
    def test_end_value_with_staircase(self, decay_rate, end_value, staircase):
        """Checks the exponential schedule is bounded by ``end_value``.

        The NumPy reference is ``init_value`` for the first
        ``transition_begin`` steps, then ``init_value * decay_rate ** p``
        (``p`` floored when ``staircase`` is set). It is then clipped at
        ``end_value``: from below when ``decay_rate < 1``, from above
        otherwise.
        """
        init_value = 1.
        num_steps = 11
        transition_steps = 2
        transition_begin = 3
        schedule_fn = self.variant(
            schedule.exponential_decay(init_value=init_value,
                                       transition_steps=transition_steps,
                                       decay_rate=decay_rate,
                                       transition_begin=transition_begin,
                                       staircase=staircase,
                                       end_value=end_value))

        actual = np.array([
            schedule_fn(step)
            for step in range(num_steps + transition_begin)
        ])

        # Exponent per post-delay step; whole periods only in staircase mode.
        exponents = np.arange(num_steps) / transition_steps
        if staircase:
            exponents = np.floor(exponents)
        decayed = init_value * np.power(decay_rate, exponents)
        expected = np.concatenate(
            [np.full(transition_begin, init_value), decayed]).astype(np.float32)

        # A shrinking schedule is floored at end_value; a growing one is capped.
        bound = np.maximum if decay_rate < 1.0 else np.minimum
        expected = bound(expected, end_value)

        np.testing.assert_allclose(expected, actual, atol=1e-3)