# Example no. 1
def test_serialization():
    """Round-trip an AdamW optimizer through Keras (de)serialization.

    The restored optimizer must report a configuration identical to the
    original's, including the exclusion patterns.
    """
    original = weight_decay_optimizers.AdamW(
        learning_rate=1e-4,
        weight_decay=1e-4,
        exclude_from_weight_decay=["var1"],
    )
    restored = tf.keras.optimizers.deserialize(
        tf.keras.optimizers.serialize(original)
    )
    assert restored.get_config() == original.get_config()
# Example no. 2
def test_exclude_weight_decay_adamw():
    """Variables matching an exclusion pattern must skip weight decay.

    "var0" does not match and keeps decay; "var1" matches, and — as the
    "var1_weight" case shows — the match also catches names merely
    containing the pattern.
    """
    opt = weight_decay_optimizers.AdamW(
        learning_rate=1e-4,
        weight_decay=1e-4,
        exclude_from_weight_decay=["var1"],
    )
    assert opt._do_use_weight_decay(tf.Variable([], name="var0"))
    assert not opt._do_use_weight_decay(tf.Variable([], name="var1"))
    assert not opt._do_use_weight_decay(tf.Variable([], name="var1_weight"))
def test_serialization_with_wd_schedule():
    """Serialization must survive a schedule-valued weight decay.

    Uses the module-level WEIGHT_DECAY constant as the schedule's
    initial value.
    """
    schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        WEIGHT_DECAY, decay_steps=10, decay_rate=0.9
    )
    original = weight_decay_optimizers.AdamW(
        learning_rate=1e-4, weight_decay=schedule
    )
    restored = tf.keras.optimizers.deserialize(
        tf.keras.optimizers.serialize(original)
    )
    assert restored.get_config() == original.get_config()
# Example no. 4
def test_keras_fit():
    """Check if calling model.fit works."""
    net = tf.keras.models.Sequential([tf.keras.layers.Dense(2)])
    opt = weight_decay_optimizers.AdamW(learning_rate=1e-4, weight_decay=1e-4)
    net.compile(
        optimizer=opt,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    # Unpacking a (2, 4, 1) array yields two (4, 1) arrays — tiny
    # features and labels for a one-epoch smoke run.
    x, y = np.random.uniform(size=(2, 4, 1))
    net.fit(x, y, epochs=1)
def test_keras_fit_with_schedule():
    """Check if calling model.fit works with wd schedule."""
    net = tf.keras.models.Sequential([tf.keras.layers.Dense(2)])
    schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        WEIGHT_DECAY, decay_steps=10, decay_rate=0.9
    )
    opt = weight_decay_optimizers.AdamW(
        learning_rate=1e-4, weight_decay=schedule
    )
    net.compile(
        optimizer=opt,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    # Unpacking a (2, 4, 1) array yields two (4, 1) arrays — tiny
    # features and labels for a one-epoch smoke run.
    x, y = np.random.uniform(size=(2, 4, 1))
    net.fit(x, y, epochs=1)
def test_exclude_weight_decay_adamw():
    """Exclusion patterns must be honored once the decay-var list is set.

    After registering the variables via _set_decay_var_list, only the
    non-matching variable keeps weight decay; both "var1" and
    "var1_weight" are excluded, showing the pattern also catches names
    merely containing it.
    """
    opt = weight_decay_optimizers.AdamW(
        learning_rate=1e-4,
        weight_decay=1e-4,
        exclude_from_weight_decay=["var1"],
    )
    decayed = tf.Variable([], name="var0")
    excluded = tf.Variable([], name="var1")
    excluded_by_substring = tf.Variable([], name="var1_weight")

    opt._set_decay_var_list([decayed, excluded, excluded_by_substring])
    assert opt._do_use_weight_decay(decayed)
    assert not opt._do_use_weight_decay(excluded)
    assert not opt._do_use_weight_decay(excluded_by_substring)