def test_cos_piecewise(self):
  """Cosine interpolation across deliberately unordered boundaries.

  Scales compound from one boundary to the next: 400 -> 240 at step 3 (x0.6),
  -> 288 at step 5 (x1.2), and stays at 288 from step 7 on (x1.0).
  """
  # Keys are intentionally listed out of order (5, 3, 7).
  cosine_fn = self.variant(
      schedule.piecewise_interpolate_schedule('cosine', 400., {
          5: 1.2,
          3: 0.6,
          7: 1.
      }))
  observed = list(map(cosine_fn, range(9)))
  wanted = [400., 360., 280., 240., 264., 288., 288., 288., 288.]
  np.testing.assert_allclose(observed, wanted, atol=1e-3)
def test_linear_piecewise(self):
  """Linear interpolation up then down, then constant past the last boundary.

  200 rises to 300 by step 5 (x1.5), falls to 75 by step 10 (x0.25), and then
  holds at 75.
  """
  linear_fn = self.variant(
      schedule.piecewise_interpolate_schedule('linear', 200., {
          5: 1.5,
          10: 0.25
      }))
  observed = list(map(linear_fn, range(13)))
  wanted = [
      200., 220., 240., 260., 280., 300.,  # ramp up to step 5
      255., 210., 165., 120., 75.,  # ramp down to step 10
      75., 75.,  # flat afterwards
  ]
  np.testing.assert_allclose(observed, wanted, atol=1e-3)
def test_invalid_type(self):
  """Unrecognised interpolation types must raise ValueError."""
  # Each tuple holds the positional arguments for one call. Every call passes an
  # interpolation type that the schedule is expected to reject.
  bad_calls = (
      ('linar', 13.),
      ('', 13., {5: 3.}),
      (None, 13., {}),
  )
  for call_args in bad_calls:
    with self.assertRaises(ValueError):
      schedule.piecewise_interpolate_schedule(*call_args)  # pytype: disable=wrong-arg-types
def test_invalid_scale(self):
  """A negative scale factor must be rejected with ValueError."""
  negative_scale = {5: -3}
  with self.assertRaises(ValueError):
    schedule.piecewise_interpolate_schedule('linear', 13., negative_scale)
def test_no_dict(self):
  """Omitting the boundaries dict keeps the schedule at its initial value."""
  constant_fn = self.variant(
      schedule.piecewise_interpolate_schedule('cosine', 17.))
  observed = list(map(constant_fn, range(3)))
  np.testing.assert_allclose(observed, [17.] * 3, atol=1e-3)
def test_empty_dict(self):
  """An empty boundaries dict keeps the schedule at its initial value."""
  constant_fn = self.variant(
      schedule.piecewise_interpolate_schedule('linear', 13., {}))
  observed = list(map(constant_fn, range(5)))
  np.testing.assert_allclose(observed, [13.] * 5, atol=1e-3)