import numpy as np
import tensorflow as tf

from tensorflow_addons.optimizers import RectifiedAdam


def test_serialization():
    optimizer = RectifiedAdam(
        learning_rate=1e-3, total_steps=10000, warmup_proportion=0.1, min_lr=1e-5
    )
    config = tf.keras.optimizers.serialize(optimizer)
    new_optimizer = tf.keras.optimizers.deserialize(config)
    assert new_optimizer.get_config() == optimizer.get_config()
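

# Supplementary sketch, not part of the original suite: beyond config equality,
# check that a deserialized optimizer applies the same update as the original
# on a single gradient step. The variables and gradient values below are
# illustrative assumptions.
def test_serialization_round_trip_step():
    optimizer = RectifiedAdam(learning_rate=1e-3)
    restored = tf.keras.optimizers.deserialize(
        tf.keras.optimizers.serialize(optimizer)
    )
    var_orig = tf.Variable([1.0, 2.0])
    var_restored = tf.Variable([1.0, 2.0])
    grad = tf.constant([0.1, 0.1])
    optimizer.apply_gradients([(grad, var_orig)])
    restored.apply_gradients([(grad, var_restored)])
    np.testing.assert_allclose(var_orig.numpy(), var_restored.numpy(), rtol=1e-6)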


def test_scheduler_serialization():
    lr_scheduler = tf.keras.optimizers.schedules.ExponentialDecay(1e-3, 50, 0.5)
    wd_scheduler = tf.keras.optimizers.schedules.InverseTimeDecay(2e-3, 25, 0.25)
    optimizer = RectifiedAdam(
        learning_rate=lr_scheduler, weight_decay=wd_scheduler
    )
    config = tf.keras.optimizers.serialize(optimizer)
    new_optimizer = tf.keras.optimizers.deserialize(config)
    assert new_optimizer.get_config() == optimizer.get_config()
    assert new_optimizer.get_config()["learning_rate"] == {
        "class_name": "ExponentialDecay",
        "config": lr_scheduler.get_config(),
    }
    assert new_optimizer.get_config()["weight_decay"] == {
        "class_name": "InverseTimeDecay",
        "config": wd_scheduler.get_config(),
    }
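

# Supplementary sketch, not part of the original suite: confirm that the
# schedule config carried on the restored optimizer rebuilds a schedule that
# evaluates to the same values as the original at a few sample steps.
def test_scheduler_round_trip_values():
    lr_scheduler = tf.keras.optimizers.schedules.ExponentialDecay(1e-3, 50, 0.5)
    optimizer = RectifiedAdam(learning_rate=lr_scheduler)
    config = tf.keras.optimizers.serialize(optimizer)
    new_optimizer = tf.keras.optimizers.deserialize(config)
    restored_lr = tf.keras.optimizers.schedules.deserialize(
        new_optimizer.get_config()["learning_rate"]
    )
    for step in [0, 25, 50, 100]:
        np.testing.assert_allclose(
            restored_lr(step).numpy(), lr_scheduler(step).numpy()
        )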


def test_get_config():
    opt = RectifiedAdam(learning_rate=1e-4)
    config = opt.get_config()
    assert config["learning_rate"] == 1e-4
    assert config["total_steps"] == 0
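

# Supplementary sketch, not part of the original suite: the dict returned by
# get_config() should be sufficient to rebuild an equivalent optimizer through
# the standard from_config() classmethod.
def test_from_config_round_trip():
    opt = RectifiedAdam(learning_rate=1e-4)
    rebuilt = RectifiedAdam.from_config(opt.get_config())
    assert rebuilt.get_config() == opt.get_config()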