def testDefaultDecay(self):
  """Checks default-argument linear_cosine_decay against the NumPy reference.

  Samples the schedule at steps 0, 250, ..., 1250 of a 1000-step decay, so
  the last two samples sit at and beyond ``num_training_steps``.
  """
  total_steps = 1000
  base_lr = 1.0
  sample_steps = range(0, 1500, 250)
  for global_step in sample_steps:
    lr_tensor = learning_rate_decay.linear_cosine_decay(
        base_lr, global_step, total_steps)
    reference = self.np_linear_cosine_decay(global_step, total_steps)
    # The trailing 1e-6 is the tolerance forwarded to assertAllClose.
    self.assertAllClose(self.evaluate(lr_tensor), reference, 1e-6)
Пример #2
0
 def testDefaultDecay(self):
     """Compare linear_cosine_decay (default hyperparameters) with NumPy.

     The decay spans 1000 steps; it is probed every 250 steps up to 1250,
     which includes points at and past the end of the schedule.
     """
     decay_steps = 1000
     start_lr = 1.0
     for current_step in (0, 250, 500, 750, 1000, 1250):
         schedule = learning_rate_decay.linear_cosine_decay(
             start_lr, current_step, decay_steps)
         want = self.np_linear_cosine_decay(current_step, decay_steps)
         got = self.evaluate(schedule)
         self.assertAllClose(got, want, 1e-6)
 def testNonDefaultDecay(self):
     """Compare linear_cosine_decay with non-default hyperparameters to NumPy.

     Both the op under test and the ``np_linear_cosine_decay`` reference use
     alpha=0.1, beta=1e-4 and num_periods=5. Steps 0, 250, ..., 1250 of a
     1000-step schedule are checked.

     Values are fetched with ``self.evaluate`` instead of a per-step
     ``with self.test_session():`` block plus ``Tensor.eval()``. The reasons:
     ``test_session`` is deprecated, and ``eval()`` is unsupported on eager
     tensors, whereas ``evaluate`` works in both graph and eager execution.
     """
     num_training_steps = 1000
     initial_lr = 1.0
     # Shared so the op and the reference can never drift apart.
     decay_kwargs = dict(alpha=0.1, beta=1e-4, num_periods=5)
     for step in range(0, 1500, 250):
         decayed_lr = learning_rate_decay.linear_cosine_decay(
             initial_lr, step, num_training_steps, **decay_kwargs)
         expected = self.np_linear_cosine_decay(
             step, num_training_steps, **decay_kwargs)
         # The trailing 1e-6 is the tolerance forwarded to assertAllClose.
         self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
 def testNonDefaultDecay(self):
   """Checks linear_cosine_decay with alpha=0.1, beta=1e-4, num_periods=5.

   The op is compared with ``self.np_linear_cosine_decay`` using the same
   hyperparameters at steps 0, 250, ..., 1250 of a 1000-step schedule.

   Evaluation goes through ``self.evaluate`` rather than a per-step
   ``with self.test_session():`` block plus ``Tensor.eval()``. The reasons:
   ``test_session`` is deprecated, and ``eval()`` is not supported on eager
   tensors, whereas ``evaluate`` works in both graph and eager execution.
   """
   num_training_steps = 1000
   initial_lr = 1.0
   for step in range(0, 1500, 250):
     decayed_lr = learning_rate_decay.linear_cosine_decay(
         initial_lr,
         step,
         num_training_steps,
         alpha=0.1,
         beta=1e-4,
         num_periods=5)
     expected = self.np_linear_cosine_decay(
         step,
         num_training_steps,
         alpha=0.1,
         beta=1e-4,
         num_periods=5)
     self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)