Ejemplo n.º 1
0
 def test_get_config(self):
     opt = Lookahead("adam", sync_period=10, slow_step_size=0.4)
     opt = tf.keras.optimizers.deserialize(
         tf.keras.optimizers.serialize(opt))
     config = opt.get_config()
     self.assertEqual(config["sync_period"], 10)
     self.assertEqual(config["slow_step_size"], 0.4)
Ejemplo n.º 2
0
 def test_get_config(self):
     self.skipTest('Wait #33614 to be fixed')
     opt = Lookahead('adam', sync_period=10, slow_step_size=0.4)
     opt = tf.keras.optimizers.deserialize(
         tf.keras.optimizers.serialize(opt))
     config = opt.get_config()
     self.assertEqual(config['sync_period'], 10)
     self.assertEqual(config['slow_step_size'], 0.4)
Ejemplo n.º 3
0
def test_get_config():
    opt = Lookahead("adam", sync_period=10, slow_step_size=0.4)
    opt = tf.keras.optimizers.deserialize(tf.keras.optimizers.serialize(opt))
    config = opt.get_config()
    assert config["sync_period"] == 10
    assert config["slow_step_size"] == 0.4
Ejemplo n.º 4
0
def test_serialization():
    optimizer = Lookahead("adam", sync_period=10, slow_step_size=0.4)
    config = tf.keras.optimizers.serialize(optimizer)
    new_optimizer = tf.keras.optimizers.deserialize(config)
    assert new_optimizer.get_config() == optimizer.get_config()