def test_exclude_weight_decay_sgdw():
    """Check that SGDW honours ``exclude_from_weight_decay`` name patterns.

    NOTE(review): a later function in this file reuses the name
    ``test_exclude_weight_decay_sgdw``. Python rebinds the module-level name
    to that later definition, so pytest never collects this one (flake8 F811).
    Unlike the later version, this body never calls
    ``optimizer._set_decay_var_list`` before querying
    ``_do_use_weight_decay``, so it looks like a stale pre-refactor copy.
    TODO: delete this copy, or give it a distinct name if it is meant to run.
    """
    optimizer = weight_decay_optimizers.SGDW(
        learning_rate=0.01, weight_decay=1e-4, exclude_from_weight_decay=["var1"]
    )
    assert optimizer._do_use_weight_decay(tf.Variable([], name="var0"))
    assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1"))
    # The "var1" pattern is expected to exclude "var1_weight" as well,
    # not only an exact "var1" match.
    assert not optimizer._do_use_weight_decay(tf.Variable([], name="var1_weight"))
def test_weight_decay_with_piecewise_constant_decay_schedule():
    """Train briefly with SGDW whose ``weight_decay`` is a schedule, not a float.

    Builds a tiny ``Dense(2)`` classifier and compiles it with an SGDW
    optimizer whose weight decay is a ``PiecewiseConstantDecay`` schedule
    (boundary at step 2: 1e-4, then 1e-5). It then checks that ``model.fit``
    runs end to end.
    """
    model = tf.keras.models.Sequential([tf.keras.layers.Dense(2)])
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    wd_schedule = tf.optimizers.schedules.PiecewiseConstantDecay([2], [1e-4, 1e-5])
    optimizer = weight_decay_optimizers.SGDW(learning_rate=1e-2, weight_decay=wd_schedule)
    model.compile(optimizer=optimizer, loss=loss, metrics=["accuracy"])
    # A seeded local generator keeps the test deterministic without mutating
    # numpy's global RNG state.
    rng = np.random.RandomState(0)
    x = rng.uniform(size=(4, 1))
    # Sparse categorical labels must be integer class ids in [0, 2) to match
    # the Dense(2) logits. Uniform floats in [0, 1) would all collapse to
    # class 0 when cast.
    y = rng.randint(0, 2, size=(4, 1))
    # 4 samples with batch_size=1 gives 4 optimizer steps, so training runs
    # past the schedule's boundary at step 2.
    model.fit(x, y, batch_size=1, epochs=1)
def test_exclude_weight_decay_sgdw():
    """SGDW skips weight decay for variables matched by an exclusion pattern.

    With ``exclude_from_weight_decay=["var1"]``, ``var0`` should still be
    decayed, while both ``var1`` and ``var1_weight`` should be excluded. The
    candidate variables are registered via ``_set_decay_var_list`` before
    they are queried.
    """
    sgdw = weight_decay_optimizers.SGDW(
        learning_rate=0.01,
        weight_decay=1e-4,
        exclude_from_weight_decay=["var1"],
    )
    # Variable name -> whether the optimizer should apply decay to it.
    # Insertion order fixes the order in which variables are created and checked.
    expected_decay = {"var0": True, "var1": False, "var1_weight": False}
    variables = {name: tf.Variable([], name=name) for name in expected_decay}

    sgdw._set_decay_var_list(list(variables.values()))

    for name, should_decay in expected_decay.items():
        assert bool(sgdw._do_use_weight_decay(variables[name])) == should_decay