def test_cosine_decay_multiplier_end_cycle(self):
    """cosine_decay should return min_lr (0.0) when time == cycle_length with cycle_multiplier=2.

    The result comes from a floating-point cosine evaluation, so it is compared
    with a tight absolute tolerance instead of exact equality. A relative
    tolerance would be meaningless against an expected value of 0.0.
    """
    learning_rate = cosine_decay(10, cycle_length=10, init_lr=0.01, min_lr=0.0, cycle_multiplier=2)
    # delta is far below init_lr, so a real regression still fails, while
    # last-bit rounding differences in cos() do not.
    self.assertAlmostEqual(learning_rate, 0.0, delta=1e-12)
def test_lambda_simple(self):
    """Check _parse_lambda_fallback's summary of a single-line lambda that closes over a local.

    The lambda's expression is intentionally awkward: ``1e-3 + 1 if epochs > 2 else 1e-4``
    parses as ``(1e-3 + 1) if epochs > 2 else 1e-4``.
    NOTE(review): presumably chosen to exercise the parser on a conditional
    expression and on float literals containing '-' -- confirm intent.
    """
    # Passed in so the test can verify that no tables get added for this lambda.
    tables = {}
    # NOTE(review): the role of ret_ref inside _parse_lambda_fallback is not visible
    # here; this test only passes it through as the third argument.
    ret_ref = Flag()
    # Closure variable read by the lambda; the test expects it under resp['kwargs'].
    epochs = 8
    # The expected 'function' string below mirrors this lambda's text, so keep the
    # lambda unchanged (presumably the parser reads its source -- confirm).
    resp = _parse_lambda_fallback(
        lambda step: cosine_decay(step, cycle_length=3750, init_lr=1e-3 + 1 if epochs > 2 else 1e-4), tables, ret_ref)
    self.assertIsInstance(
        resp, dict, "_parse_lambda_fallback should return a dictionary")
    # A plain lambda body like this one should not cause any tables to be generated.
    self.assertEqual(
        {}, tables, "_parse_lambda_fallback should not have generated any tables for this lambda"
    )
    self.assertIn('function', resp, "response should contain a function summary")
    # Expected LaTeX-escaped rendering of the lambda body: '_' becomes '\_'
    # and '-' becomes '{-}'; everything else is kept verbatim.
    self.assertEqual(
        r"cosine\_decay(step, cycle\_length=3750, init\_lr=1e{-}3 + 1 if epochs > 2 else 1e{-}4)",
        resp['function'])
    self.assertIn('kwargs', resp, "response should contain kwargs")
    self.assertIsInstance(resp['kwargs'], dict, "kwargs should be a dictionary")
    # The closure variable `epochs` is expected as a NoEscape key, with its value
    # wrapped in \seqsplit{...}.
    self.assertDictEqual({NoEscape('epochs'): NoEscape(r'\seqsplit{8}')}, resp['kwargs'])
def test_cosine_decay(self):
    """Halfway through a cycle (time=5 of 10), cosine_decay should be midway between init_lr and min_lr.

    cos(pi/2) is not exactly zero in floating point (about 6.1e-17), so exact
    equality with 0.005 only held by incidental rounding. The value is compared
    with a tight absolute tolerance instead.
    """
    learning_rate = cosine_decay(time=5, cycle_length=10, init_lr=0.01, min_lr=0.0)
    # delta is ~1e-10 relative to the expected value: strict enough to catch
    # regressions, loose enough to absorb last-bit rounding.
    self.assertAlmostEqual(learning_rate, 0.005, delta=1e-12)
def test_cosine_decay_cycle(self):
    """Just past the first cycle boundary (time=1001, cycle_length=1000), expect a value near init_lr."""
    lr = cosine_decay(time=1001, cycle_length=1000, init_lr=0.01, min_lr=0.0)
    # Relative tolerance of 0.1%, since the step is one tick past the boundary.
    is_near_peak = math.isclose(lr, 0.01, rel_tol=1e-3)
    self.assertTrue(is_near_peak)