Esempio n. 1
0
def test_serialization():
    """Round-trip RectifiedAdam through Keras optimizer (de)serialization.

    The optimizer rebuilt by ``tf.keras.optimizers.deserialize`` must report
    exactly the same config as the optimizer that was serialized.
    """
    # Use the canonical ``learning_rate`` keyword; ``lr`` is only a deprecated
    # alias and is irrelevant to what this test checks.
    optimizer = RectifiedAdam(
        learning_rate=1e-3,
        total_steps=10000,
        warmup_proportion=0.1,
        min_lr=1e-5,
    )
    config = tf.keras.optimizers.serialize(optimizer)
    new_optimizer = tf.keras.optimizers.deserialize(config)
    assert new_optimizer.get_config() == optimizer.get_config()
Esempio n. 2
0
def test_scheduler_serialization():
    """Round-trip a RectifiedAdam whose hyperparameters are schedules.

    Both ``learning_rate`` and ``weight_decay`` are schedule objects; after a
    serialize/deserialize round trip the optimizer config must be unchanged
    and each hyperparameter must still be the serialized form of the original
    schedule (same class name, same schedule config).
    """
    lr_scheduler = tf.keras.optimizers.schedules.ExponentialDecay(
        1e-3, 50, 0.5)
    wd_scheduler = tf.keras.optimizers.schedules.InverseTimeDecay(
        2e-3, 25, 0.25)

    optimizer = RectifiedAdam(learning_rate=lr_scheduler,
                              weight_decay=wd_scheduler)
    config = tf.keras.optimizers.serialize(optimizer)
    new_optimizer = tf.keras.optimizers.deserialize(config)
    new_config = new_optimizer.get_config()
    assert new_config == optimizer.get_config()

    # Check only the keys that identify the schedule instead of comparing to a
    # hand-built two-key dict: depending on the Keras version, serialized
    # objects may carry extra metadata keys, which would make an exact-dict
    # comparison fail even though the schedule round-tripped correctly.
    lr_config = new_config["learning_rate"]
    assert lr_config["class_name"] == "ExponentialDecay"
    assert lr_config["config"] == lr_scheduler.get_config()

    wd_config = new_config["weight_decay"]
    assert wd_config["class_name"] == "InverseTimeDecay"
    assert wd_config["config"] == wd_scheduler.get_config()
Esempio n. 3
0
 def test_get_config(self):
     """get_config() reports the learning rate and the default total_steps."""
     # The value passed via the ``lr`` keyword must appear under the
     # ``learning_rate`` key of the config.
     cfg = RectifiedAdam(lr=1e-4).get_config()
     self.assertEqual(cfg["learning_rate"], 1e-4)
     self.assertEqual(cfg["total_steps"], 0)
Esempio n. 4
0
def test_get_config():
    """get_config() exposes the learning rate and the default total_steps."""
    # The value passed via the ``lr`` keyword must be reported under the
    # ``learning_rate`` key; ``total_steps`` is left at its default.
    cfg = RectifiedAdam(lr=1e-4).get_config()
    expected = {"learning_rate": 1e-4, "total_steps": 0}
    for key, value in expected.items():
        assert cfg[key] == value