def testWeightDecay(self):
    """Checks AdamW weight decay plus two ways of opting a variable out.

    Three identical AdamW optimizers each receive a zero gradient for one
    variable initialised to 2.0. The first applies decay, the second excludes
    its variable via ``exclude_from_weight_decay([...])``, and the third
    excludes it via ``build(..., exclude_from_weight_decay=[...])``.

    NOTE(review): a second ``testWeightDecay`` is defined immediately after
    this one. If both are in the same class, this definition is silently
    replaced and never runs (pylint function-redefined); consider deleting
    one. The two versions also disagree. This one expects 1.992 with
    learning_rate=0.001, while the later one expects 1.992 with
    learning_rate=1, so they cannot both hold for an AdamW whose decay scales
    with the learning rate. The exclusion API used here (positional
    ``exclude_from_weight_decay`` and a ``build`` keyword) also differs from
    the later version's ``var_names=``/``var_list=`` calls. Confirm which
    version matches the current AdamW.
    """
    zero_grad = tf.zeros(())
    decayed_var = tf.Variable(2.0)
    listed_var = tf.Variable(2.0)
    built_var = tf.Variable(2.0)

    def new_adamw():
        return adamw_new.AdamW(learning_rate=0.001, weight_decay=0.004)

    # No exclusion: the variable is expected to move away from 2.0.
    decay_opt = new_adamw()
    decay_opt.apply_gradients(zip([zero_grad], [decayed_var]))

    # Exclusion registered before the first update.
    listed_opt = new_adamw()
    listed_opt.exclude_from_weight_decay([listed_var])
    listed_opt.apply_gradients(zip([zero_grad], [listed_var]))

    # Exclusion passed while explicitly building the optimizer.
    built_opt = new_adamw()
    built_opt.build([built_var], exclude_from_weight_decay=[built_var])
    built_opt.apply_gradients(zip([zero_grad], [built_var]))

    self.assertEqual(decayed_var, 1.992)
    self.assertEqual(listed_var, 2.0)
    self.assertEqual(built_var, 2.0)
def testWeightDecay(self):
    """Checks AdamW weight decay and both exclusion mechanisms.

    Every variable starts at 2.0 and receives a zero gradient from its own
    AdamW(learning_rate=1, weight_decay=0.004). The test expects:

    * the unexcluded variable to end at 1.992, which equals
      2.0 - 2.0 * 0.004 * 1 (presumably decay = var * weight_decay * lr;
      confirm against AdamW);
    * a variable excluded by name substring ("exclude") to stay at 2.0;
    * a variable excluded by object reference to stay at 2.0.
    """
    zero_grad = tf.zeros(())
    decayed = tf.Variable(2.0)
    excluded_by_name = tf.Variable(2.0, name="exclude")
    excluded_by_ref = tf.Variable(2.0)

    def make_adamw():
        return adamw_new.AdamW(learning_rate=1, weight_decay=0.004)

    plain_opt = make_adamw()
    plain_opt.apply_gradients(zip([zero_grad], [decayed]))

    name_opt = make_adamw()
    name_opt.exclude_from_weight_decay(var_names=["exclude"])
    name_opt.apply_gradients(zip([zero_grad], [excluded_by_name]))

    ref_opt = make_adamw()
    ref_opt.exclude_from_weight_decay(var_list=[excluded_by_ref])
    ref_opt.apply_gradients(zip([zero_grad], [excluded_by_ref]))

    self.assertEqual(decayed, 1.992)
    self.assertEqual(excluded_by_name, 2.0)
    self.assertEqual(excluded_by_ref, 2.0)
def testPassingMissingWDError(self):
    """AdamW must reject ``weight_decay=None`` with a descriptive ValueError.

    ``assertRaisesRegex`` checks both the exception type (ValueError) and that
    its message matches "Missing value of", so one check covers both
    properties.
    """
    with self.assertRaisesRegex(ValueError, "Missing value of"):
        _ = adamw_new.AdamW(0.01, weight_decay=None)
ds_combinations.multi_worker_mirrored_2x2_gpu,
ds_combinations.central_storage_strategy_with_two_gpus,
]  # Closes a list whose opening is above this view.

# Each `*_new_fn` below pairs a display name with a zero-argument lambda that
# builds a brand-new optimizer instance every time it is called. The shared
# first positional argument 0.002 is presumably the learning rate (confirm
# against each optimizer's signature). NamedObject presumably supplies the
# readable name used in parameterized test IDs; TODO confirm.

# EMA path with ema_overwrite_frequency=None; the SGD entry below uses 1, so
# both settings are exercised.
adadelta_new_fn = tf.__internal__.test.combinations.NamedObject(
    "experimentaladadelta",
    lambda: adadelta_new.Adadelta(  # pylint: disable=g-long-lambda
        0.002, use_ema=True, ema_overwrite_frequency=None))
adagrad_new_fn = tf.__internal__.test.combinations.NamedObject(
    "experimentaladagrad", lambda: adagrad_new.Adagrad(0.002))
adam_new_fn = tf.__internal__.test.combinations.NamedObject(
    "experimentaladam", lambda: adam_new.Adam(0.002))
# AdamW is given a non-zero weight_decay; other AdamW tests in this file
# raise ValueError when it is None.
adamw_new_fn = tf.__internal__.test.combinations.NamedObject(
    "experimentaladamw", lambda: adamw_new.AdamW(0.002, weight_decay=0.004))
rmsprop_new_fn = tf.__internal__.test.combinations.NamedObject(
    "experimentalrmsprop", lambda: rmsprop_new.RMSprop(0.002))
# EMA path with ema_overwrite_frequency=1 (contrast with Adadelta above).
sgd_new_fn = tf.__internal__.test.combinations.NamedObject(
    "experimentalsgdaverage",
    lambda: sgd_new.SGD(  # pylint: disable=g-long-lambda
        0.002, use_ema=True, ema_overwrite_frequency=1))

# Optimizer factories, presumably iterated by parameterized tests. The list
# continues past this view. NOTE(review): confirm sgd_new_fn is among the
# remaining entries.
OPTIMIZER_FN = [
    adadelta_new_fn,
    adagrad_new_fn,
    adam_new_fn,
    adamw_new_fn,
    rmsprop_new_fn,
# Factories for the non-"_new" optimizer modules. Each NamedObject pairs a
# display name with a zero-argument lambda that builds a fresh optimizer on
# every call (first positional arg 0.002, presumably the learning rate).
# Unlike the "experimental*" names used for the *_new_fn factories, these
# names carry no prefix.

# EMA path with ema_overwrite_frequency=None; sgd_fn below uses 1.
adadelta_fn = tf.__internal__.test.combinations.NamedObject(
    "adadelta",
    lambda: adadelta.Adadelta(  # pylint: disable=g-long-lambda
        0.002, use_ema=True, ema_overwrite_frequency=None),
)
adagrad_fn = tf.__internal__.test.combinations.NamedObject(
    "adagrad", lambda: adagrad.Adagrad(0.002))
adam_fn = tf.__internal__.test.combinations.NamedObject(
    "adam", lambda: adam.Adam(0.002))
adamax_fn = tf.__internal__.test.combinations.NamedObject(
    "adamax", lambda: adamax.Adamax(0.002))
adamw_fn = tf.__internal__.test.combinations.NamedObject(
    "adamw", lambda: adamw.AdamW(0.002, weight_decay=0.004))
ftrl_fn = tf.__internal__.test.combinations.NamedObject(
    "ftrl", lambda: ftrl.Ftrl(0.002))
# NOTE(review): "experimentnadam" breaks this group's naming pattern (every
# sibling is the bare optimizer name) and looks like a truncated
# "experimental". Consider "nadam", but check that no test filter depends on
# the current ID first.
nadam_fn = tf.__internal__.test.combinations.NamedObject(
    "experimentnadam", lambda: nadam.Nadam(0.002))
rmsprop_fn = tf.__internal__.test.combinations.NamedObject(
    "rmsprop", lambda: rmsprop.RMSprop(0.002))
# EMA path with ema_overwrite_frequency=1 (contrast with adadelta_fn above).
sgd_fn = tf.__internal__.test.combinations.NamedObject(
    "sgdaverage",
    lambda: sgd.SGD(  # pylint: disable=g-long-lambda
        0.002, use_ema=True, ema_overwrite_frequency=1),
)

# Optimizer factories, presumably iterated by parameterized tests. The list
# body continues past this view. NOTE(review): OPTIMIZER_FN is also assigned
# in the "_new" snippet earlier in this file. If both are in one module, the
# later assignment wins.
OPTIMIZER_FN = [