def test_apply_gradients(self):

    x = variable_scope.get_variable("x", initializer=1., dtype=dtypes.float32)
    dataset = dataset_ops.Dataset.from_tensor_slices([np.nan, np.inf, 0.1])
    itr = dataset_ops.make_one_shot_iterator(dataset)

    lr = 1
    opt = gd.GradientDescentOptimizer(lr)
    lsm = lsm_lib.FixedLossScaleManager(1.e4)
    opt = lso.LossScaleOptimizer(opt, lsm)
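    # LossScaleOptimizer wraps the base optimizer; apply_gradients should skip
    # the variable update whenever the incoming gradient is non-finite. The
    # dataset above feeds nan, inf, and then a finite 0.1 as the "gradients".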
    train_fn = lambda: opt.apply_gradients([(itr.get_next(), x)])
    if not context.executing_eagerly():
      train_op = train_fn()

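    # The nan and inf gradients are skipped, so x stays at 1 for the first two
    # steps; the final finite gradient 0.1 with lr=1 gives x = 1 - 0.1 = 0.9.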
    expected_output = [1, 1, 1 - 0.1]
    actual_output = []

    self.evaluate(variables.global_variables_initializer())
    for _ in range(3):
      # nan or inf is not applied.
      if context.executing_eagerly():
        train_fn()
      else:
        self.evaluate(train_op)
      actual_output.append(self.evaluate(x))
    self.assertAllClose(expected_output, actual_output)

def test_apply_gradients_loss_scale_is_updated(self):

    class SimpleLossScaleManager(lsm_lib.LossScaleManager):
      """A simple loss scale manager for easier testing.

      It increments the loss scale by 1 if grads are finite, and decreases it
      by 1 otherwise.
      """

      def __init__(self, loss_scale):
        self._loss_scale = variable_scope.variable(
            name="loss_scale",
            initial_value=loss_scale,
            dtype=dtypes.float32,
            trainable=False)

      def get_loss_scale(self):
        return self._loss_scale

      def update_loss_scale(self, if_finite_grads):
        return control_flow_ops.cond(
            if_finite_grads,
            lambda: state_ops.assign_add(self._loss_scale, 1),
            lambda: state_ops.assign_sub(self._loss_scale, 1))

        x = variable_scope.get_variable("x",
                                        initializer=1.,
                                        dtype=dtypes.float32)
        dataset = dataset_ops.Dataset.from_tensor_slices([np.nan, np.inf, 0.1])
        itr = dataset_ops.make_one_shot_iterator(dataset)

        lr = 1
        init_loss_scale = 8
        opt = gd.GradientDescentOptimizer(lr)
        lsm = SimpleLossScaleManager(init_loss_scale)
        opt = lso.LossScaleOptimizer(opt, lsm)
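    # After each apply_gradients, the manager's update_loss_scale should be
    # called with whether the gradients were finite, so the loss scale moves
    # by 1 every step.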
    train_fn = lambda: opt.apply_gradients([(itr.get_next(), x)])
    if not context.executing_eagerly():
      train_op = train_fn()

    self.evaluate(variables.global_variables_initializer())

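    # The nan and inf gradients decrement the loss scale twice (8 -> 7 -> 6);
    # the final finite gradient increments it back to 7.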
    expected_loss_scale = [
        init_loss_scale - 1, init_loss_scale - 2, init_loss_scale - 2 + 1
    ]
    expected_output = [1, 1, 1 - 0.1]
    actual_output = []
    for i in range(3):
      # nan or inf is not applied.
      if context.executing_eagerly():
        train_fn()
      else:
        self.evaluate(train_op)
      actual_output.append(self.evaluate(x))
      self.assertAllClose(expected_loss_scale[i],
                          self.evaluate(lsm._loss_scale))
    self.assertAllClose(expected_output, actual_output)

def loss_scale_opt_fn(opt):
    """Wraps `opt` with a LossScaleOptimizer using a fixed loss scale."""
    return lso.LossScaleOptimizer(opt, lsm_lib.FixedLossScaleManager(1e4))
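
# Example (hypothetical) usage: wrap a plain optimizer so that apply_gradients
# skips updates for non-finite gradients, e.g.
#   opt = loss_scale_opt_fn(gd.GradientDescentOptimizer(0.1))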