Example #1
0
  def testWeightDecay(self):
    """Weight decay shrinks a variable unless it is excluded.

    All three variables start at 2.0 and receive a zero gradient, so any
    change comes from the decay term alone.  Exclusion is exercised two
    ways: via exclude_from_weight_decay() and via build().
    """
    grads = tf.zeros(())
    var1 = tf.Variable(2.0)
    var2 = tf.Variable(2.0)
    var3 = tf.Variable(2.0)

    # Default behaviour: decay is applied.
    optimizer_1 = adamw_new.AdamW(learning_rate=0.001, weight_decay=0.004)
    optimizer_1.apply_gradients(zip([grads], [var1]))

    # Excluded after construction: variable stays at its initial value.
    optimizer_2 = adamw_new.AdamW(learning_rate=0.001, weight_decay=0.004)
    optimizer_2.exclude_from_weight_decay([var2])
    optimizer_2.apply_gradients(zip([grads], [var2]))

    # Excluded through build(): same outcome as the method-based exclusion.
    optimizer_3 = adamw_new.AdamW(learning_rate=0.001, weight_decay=0.004)
    optimizer_3.build([var3], exclude_from_weight_decay=[var3])
    optimizer_3.apply_gradients(zip([grads], [var3]))

    self.assertEqual(var1, 1.992)
    self.assertEqual(var2, 2.0)
    self.assertEqual(var3, 2.0)
Example #2
0
    def testWeightDecay(self):
        """Weight decay applies by default and honours both exclusion APIs.

        Every variable starts at 2.0 with a zero gradient, so only the
        decay term can move it.  Exclusion is tested by variable name
        (var_names) and by variable object (var_list).
        """
        grads = tf.zeros(())
        var1 = tf.Variable(2.0)
        var2 = tf.Variable(2.0, name="exclude")
        var3 = tf.Variable(2.0)

        # No exclusion: decay is applied to var1.
        optimizer_1 = adamw_new.AdamW(learning_rate=1, weight_decay=0.004)
        optimizer_1.apply_gradients(zip([grads], [var1]))

        # Exclusion by name matches var2 ("exclude"), so it is untouched.
        optimizer_2 = adamw_new.AdamW(learning_rate=1, weight_decay=0.004)
        optimizer_2.exclude_from_weight_decay(var_names=["exclude"])
        optimizer_2.apply_gradients(zip([grads], [var2]))

        # Exclusion by variable object leaves var3 untouched as well.
        optimizer_3 = adamw_new.AdamW(learning_rate=1, weight_decay=0.004)
        optimizer_3.exclude_from_weight_decay(var_list=[var3])
        optimizer_3.apply_gradients(zip([grads], [var3]))

        self.assertEqual(var1, 1.992)
        self.assertEqual(var2, 2.0)
        self.assertEqual(var3, 2.0)
Example #3
0
    def testPassingMissingWDError(self):
        """AdamW must raise a descriptive ValueError when weight_decay is None.

        The original test constructed the optimizer twice: once under a
        plain assertRaises(ValueError) and again under assertRaisesRegex
        with the same arguments.  The regex form already verifies both the
        exception type and its message, so the first check was redundant
        and has been removed.
        """
        with self.assertRaisesRegex(ValueError, "Missing value of"):
            _ = adamw_new.AdamW(0.01, weight_decay=None)
Example #4
0
    ds_combinations.multi_worker_mirrored_2x2_gpu,
    ds_combinations.central_storage_strategy_with_two_gpus,
]

# Named factories for the experimental ("new") optimizer implementations.
# NamedObject gives each lambda a readable name in parameterized test output.
adadelta_new_fn = tf.__internal__.test.combinations.NamedObject(
    "experimentaladadelta",
    lambda: adadelta_new.Adadelta(  # pylint: disable=g-long-lambda
        0.002, use_ema=True, ema_overwrite_frequency=None
    ),
)
adagrad_new_fn = tf.__internal__.test.combinations.NamedObject(
    "experimentaladagrad", lambda: adagrad_new.Adagrad(0.002)
)
adam_new_fn = tf.__internal__.test.combinations.NamedObject(
    "experimentaladam", lambda: adam_new.Adam(0.002)
)
adamw_new_fn = tf.__internal__.test.combinations.NamedObject(
    "experimentaladamw", lambda: adamw_new.AdamW(0.002, weight_decay=0.004)
)
rmsprop_new_fn = tf.__internal__.test.combinations.NamedObject(
    "experimentalrmsprop", lambda: rmsprop_new.RMSprop(0.002)
)
sgd_new_fn = tf.__internal__.test.combinations.NamedObject(
    "experimentalsgdaverage",
    lambda: sgd_new.SGD(  # pylint: disable=g-long-lambda
        0.002, use_ema=True, ema_overwrite_frequency=1
    ),
)

OPTIMIZER_FN = [
    adadelta_new_fn,
    adagrad_new_fn,
    adam_new_fn,
    adamw_new_fn,
    rmsprop_new_fn,
# Named factories for the stable optimizer implementations, mirroring the
# experimental set above so both can feed the same parameterized tests.
adadelta_fn = tf.__internal__.test.combinations.NamedObject(
    "adadelta",
    lambda: adadelta.Adadelta(  # pylint: disable=g-long-lambda
        0.002, use_ema=True, ema_overwrite_frequency=None
    ),
)
adagrad_fn = tf.__internal__.test.combinations.NamedObject(
    "adagrad", lambda: adagrad.Adagrad(0.002)
)
adam_fn = tf.__internal__.test.combinations.NamedObject(
    "adam", lambda: adam.Adam(0.002)
)
adamax_fn = tf.__internal__.test.combinations.NamedObject(
    "adamax", lambda: adamax.Adamax(0.002)
)
adamw_fn = tf.__internal__.test.combinations.NamedObject(
    "adamw", lambda: adamw.AdamW(0.002, weight_decay=0.004)
)
ftrl_fn = tf.__internal__.test.combinations.NamedObject(
    "ftrl", lambda: ftrl.Ftrl(0.002)
)
nadam_fn = tf.__internal__.test.combinations.NamedObject(
    "experimentnadam", lambda: nadam.Nadam(0.002)
)
rmsprop_fn = tf.__internal__.test.combinations.NamedObject(
    "rmsprop", lambda: rmsprop.RMSprop(0.002)
)
sgd_fn = tf.__internal__.test.combinations.NamedObject(
    "sgdaverage",
    lambda: sgd.SGD(  # pylint: disable=g-long-lambda
        0.002, use_ema=True, ema_overwrite_frequency=1
    ),
)

OPTIMIZER_FN = [