コード例 #1
0
    def test_zero_steps_schedule(self, variant):
        """Check that a schedule with a non-positive step count stays constant.

        For `num_steps` of -1 and 0 (passed as the fourth positional argument
        of `polynomial_schedule`), every evaluated count is expected to
        return `initial_value`.
        """
        initial_value = 10.
        final_value = 20.

        for num_steps in [-1, 0]:
            # Get schedule function.
            schedule_fn = schedules.polynomial_schedule(
                initial_value, final_value, 1, num_steps)
            schedule_fn = variant(schedule_fn)
            # No transition should ever happen, so every count yields the
            # initial value rather than moving towards `final_value`.
            for count in range(15):
                np.testing.assert_allclose(schedule_fn(count), initial_value)
コード例 #2
0
 def test_linear(self, variant):
     """Check linear schedule."""
     # Get schedule function.
     schedule_fn = schedules.polynomial_schedule(10., 20., 1, 10)
     schedule_fn = variant(schedule_fn)
     # Evaluate the schedule at counts 0..14, i.e. past the end of the
     # 10-count transition so the final plateau is exercised too.
     generated_vals = np.array([schedule_fn(count) for count in range(15)])
     # Expected: 10, 11, ..., 19 during the transition, then 20 thereafter.
     expected_vals = np.array(list(range(10, 20)) + [20] * 5,
                              dtype=np.float32)
     # assert_allclose takes (actual, desired); keeping that order makes
     # failure messages label the generated values as ACTUAL.
     np.testing.assert_allclose(generated_vals, expected_vals, atol=1e-3)
コード例 #3
0
 def test_nonlinear(self, variant):
     """Check non-linear (quadratic) schedule."""
     # Get schedule function.
     schedule_fn = schedules.polynomial_schedule(25., 10., 2, 10)
     schedule_fn = variant(schedule_fn)
     # Evaluate the schedule at counts 0..14, covering the 10-count
     # transition plus 5 counts on the final plateau.
     generated_vals = np.array([schedule_fn(count) for count in range(15)])
     # Expected: end + (init - end) * (1 - n / steps)**power during the
     # transition (init=25, end=10, power=2, steps=10), then 10 afterwards.
     expected_vals = np.array(
         [10. + 15. * (1. - n / 10)**2 for n in range(10)] + [10] * 5,
         dtype=np.float32)
     # assert_allclose takes (actual, desired); keeping that order makes
     # failure messages label the generated values as ACTUAL.
     np.testing.assert_allclose(generated_vals, expected_vals, atol=1e-3)
コード例 #4
0
 def test_with_decay_begin(self, variant):
     """Check quadratic schedule with non-zero schedule begin."""
     # Get schedule function.
     schedule_fn = schedules.polynomial_schedule(30.,
                                                 10.,
                                                 2,
                                                 10,
                                                 transition_begin=4)
     schedule_fn = variant(schedule_fn)
     # Evaluate the schedule at counts 0..19: 4 counts before the transition
     # begins, 10 transition counts, and 6 counts on the final plateau.
     generated_vals = np.array([schedule_fn(count) for count in range(20)])
     # Expected: hold at 30 until transition_begin=4, then the quadratic
     # decay 10 + 20 * (1 - n / 10)**2, then 10 for the remaining counts.
     expected_vals = np.array(
         [30.] * 4 + [10. + 20. * (1. - n / 10)**2
                      for n in range(10)] + [10] * 6,
         dtype=np.float32)
     # assert_allclose takes (actual, desired); keeping that order makes
     # failure messages label the generated values as ACTUAL.
     np.testing.assert_allclose(generated_vals, expected_vals, atol=1e-3)