  def test_step_learning_rate_with_linear_warmup(self):
    params = params_dict.ParamsDict({
        'type': 'step',
        'init_learning_rate': 0.2,
        'warmup_learning_rate': 0.1,
        'warmup_steps': 100,
        'learning_rate_levels': [0.02, 0.002],
        'learning_rate_steps': [200, 400],
    })
    learning_rate_fn = learning_rates.learning_rate_generator(params)
    # Linear warmup from 0.1 to 0.2 over the first 100 steps.
    lr = learning_rate_fn(0).numpy()
    self.assertAlmostEqual(0.1, lr)
    lr = learning_rate_fn(50).numpy()
    self.assertAlmostEqual(0.15, lr)
    lr = learning_rate_fn(100).numpy()
    self.assertAlmostEqual(0.2, lr)
    # Plateau at the initial rate until the first decay boundary at step 200.
    lr = learning_rate_fn(150).numpy()
    self.assertAlmostEqual(0.2, lr)
    lr = learning_rate_fn(200).numpy()
    self.assertAlmostEqual(0.02, lr)
    lr = learning_rate_fn(300).numpy()
    self.assertAlmostEqual(0.02, lr)
    # Second decay boundary at step 400 drops the rate to its final level.
    lr = learning_rate_fn(400).numpy()
    self.assertAlmostEqual(0.002, lr)
    lr = learning_rate_fn(500).numpy()
    self.assertAlmostEqual(0.002, lr)
    lr = learning_rate_fn(600).numpy()
    self.assertAlmostEqual(0.002, lr)
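
# A minimal sketch of the schedule the test above exercises, assuming the
# generator ramps linearly from warmup_learning_rate to init_learning_rate
# over warmup_steps and then applies piecewise-constant drops at
# learning_rate_steps. The function name and defaults are illustrative only,
# not the learning_rates implementation.
import tensorflow as tf

def step_lr_with_linear_warmup(global_step, init_lr=0.2, warmup_lr=0.1,
                               warmup_steps=100, lr_levels=(0.02, 0.002),
                               lr_steps=(200, 400)):
  step = tf.cast(global_step, tf.float32)
  # Linear warmup: interpolate between warmup_lr and init_lr.
  lr = warmup_lr + (init_lr - warmup_lr) * step / warmup_steps
  lr = tf.where(step >= warmup_steps, init_lr, lr)
  # Piecewise-constant decay: take the level of the last boundary passed.
  for boundary, level in zip(lr_steps, lr_levels):
    lr = tf.where(step >= boundary, level, lr)
  return lr

# e.g. step_lr_with_linear_warmup(50) -> 0.15, matching the assertion above.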
  def __init__(self, params):
    self._use_bfloat16 = params.architecture.use_bfloat16

    # Optimization.
    self._optimizer_fn = OptimizerFactory(params.train.optimizer)
    self._learning_rate = learning_rates.learning_rate_generator(
        params.train.learning_rate)

    self._frozen_variable_prefix = params.train.frozen_variable_prefix

    # Checkpoint restoration.
    self._checkpoint = params.train.checkpoint.as_dict()

    # Summary.
    self._enable_summary = params.enable_summary
    self._model_dir = params.model_dir
  def __init__(self, params):
    self._use_bfloat16 = params.architecture.use_bfloat16

    if params.architecture.use_bfloat16:
      policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
          'mixed_bfloat16')
      tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

    # Optimization.
    self._optimizer_fn = OptimizerFactory(params.train.optimizer)
    self._learning_rate = learning_rates.learning_rate_generator(
        params.train.learning_rate)

    self._frozen_variable_prefix = params.train.frozen_variable_prefix
    self._l2_weight_decay = params.train.l2_weight_decay

    # Checkpoint restoration.
    self._checkpoint = params.train.checkpoint.as_dict()

    # Summary.
    self._enable_summary = params.enable_summary
    self._model_dir = params.model_dir
  def test_cosine_learning_rate_with_linear_warmup(self):
    params = params_dict.ParamsDict({
        'type': 'cosine',
        'init_learning_rate': 0.2,
        'warmup_learning_rate': 0.1,
        'warmup_steps': 100,
        'total_steps': 1100,
    })
    learning_rate_fn = learning_rates.learning_rate_generator(params)
    # Linear warmup from 0.1 to 0.2 over the first 100 steps.
    lr = learning_rate_fn(0).numpy()
    self.assertAlmostEqual(0.1, lr)
    lr = learning_rate_fn(50).numpy()
    self.assertAlmostEqual(0.15, lr)
    lr = learning_rate_fn(100).numpy()
    self.assertAlmostEqual(0.2, lr)
    # Half-cosine decay from 0.2 down to 0 over the remaining 1000 steps.
    lr = learning_rate_fn(350).numpy()
    self.assertAlmostEqual(0.17071067811865476, lr)
    lr = learning_rate_fn(600).numpy()
    self.assertAlmostEqual(0.1, lr)
    lr = learning_rate_fn(850).numpy()
    self.assertAlmostEqual(0.029289321881345254, lr)
    lr = learning_rate_fn(1100).numpy()
    self.assertAlmostEqual(0.0, lr)
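
# A minimal sketch of the cosine schedule checked above, assuming a
# half-cosine decay from init_learning_rate to zero over the
# (total_steps - warmup_steps) steps that follow the same linear warmup.
# Illustrative only, not the learning_rates implementation.
import math
import tensorflow as tf

def cosine_lr_with_linear_warmup(global_step, init_lr=0.2, warmup_lr=0.1,
                                 warmup_steps=100, total_steps=1100):
  step = tf.cast(global_step, tf.float32)
  warmup = warmup_lr + (init_lr - warmup_lr) * step / warmup_steps
  # Cosine anneal: progress runs from 0 at the end of warmup to 1 at
  # total_steps, so cos(pi * progress) sweeps from 1 down to -1.
  progress = (step - warmup_steps) / (total_steps - warmup_steps)
  decay = 0.5 * init_lr * (1.0 + tf.cos(math.pi * progress))
  return tf.where(step < warmup_steps, warmup, decay)

# e.g. cosine_lr_with_linear_warmup(600) -> 0.1 (the schedule's midpoint),
# and cosine_lr_with_linear_warmup(350) -> 0.5 * 0.2 * (1 + cos(pi/4))
# = 0.17071..., matching the assertions above.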
  def __init__(self, params):
    self._use_bfloat16 = params.architecture.use_bfloat16

    if params.architecture.use_bfloat16:
      tf.compat.v2.keras.mixed_precision.set_global_policy('mixed_bfloat16')

    # Optimization.
    self._optimizer_fn = optimizers.OptimizerFactory(params.train.optimizer)
    self._learning_rate = learning_rates.learning_rate_generator(
        params.train.total_steps, params.train.learning_rate)

    self._frozen_variable_prefix = params.train.frozen_variable_prefix
    self._regularization_var_regex = (
        params.train.regularization_variable_regex)
    self._l2_weight_decay = params.train.l2_weight_decay

    # Checkpoint restoration.
    self._checkpoint = params.train.checkpoint.as_dict()

    # Summary.
    self._enable_summary = params.enable_summary
    self._model_dir = params.model_dir
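
# A hypothetical ParamsDict layout covering every field the constructors
# above read; the keys come from the attribute accesses in the snippets,
# while the values are placeholders, not defaults from the library.
params = params_dict.ParamsDict({
    'enable_summary': True,
    'model_dir': '/tmp/model',
    'architecture': {
        'use_bfloat16': False,
    },
    'train': {
        'optimizer': {},         # consumed by OptimizerFactory
        'learning_rate': {},     # consumed by learning_rate_generator
        'total_steps': 1100,     # passed explicitly in the last variant
        'frozen_variable_prefix': '',
        'regularization_variable_regex': '',
        'l2_weight_decay': 1e-4,
        'checkpoint': {},        # restored via .as_dict()
    },
})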