Exemple #1
0
    def _compute_loss(self, input_dict):
        """Build the CTC loss graph and return the batch-averaged loss.

        Args:
          input_dict (dict): dictionary that has to contain the
              following entries::

                input_dict = {
                  "decoder_output": {
                    "logits": tensor of shape [batch_size, time length, num features]
                    "src_length": tensor of shape [batch_size]
                  }
                  "tgt_sequence": tensor of shape [batch_size, time length]
                  "tgt_length": tensor of shape [batch_size]
                }

        Returns:
          mean of the per-example CTC losses.
        """
        decoder_output = input_dict['decoder_output']
        logits = decoder_output['logits']
        targets = input_dict['tgt_sequence']
        target_lengths = input_dict['tgt_length']
        # The source lengths are read from the decoder output (not from the
        # raw inputs) because the encoder might have changed them.
        input_lengths = decoder_output['src_length']

        # Per-example CTC loss; labels are handed over in sparse form.
        sparse_labels = dense_to_sparse(targets, target_lengths)
        per_example_loss = tf.nn.ctc_loss(
            labels=sparse_labels,
            inputs=logits,
            sequence_length=input_lengths,
            ignore_longer_outputs_than_inputs=True,
        )

        if self._mask_nan:
            per_example_loss = mask_nans(per_example_loss)

        # Average over the batch.
        return tf.reduce_mean(per_example_loss)
Exemple #2
0
    def _compute_loss(self, input_dict):
        """CTC loss graph construction.

        Args:
          input_dict (dict): input dictionary that has to contain
              the following fields::

                input_dict = {
                  "decoder_output": {
                    "logits": tensor, shape [batch_size, time length, tgt_vocab_size]
                    "src_length": tensor, shape [batch_size]
                  },
                  "target_tensors": [
                    tgt_sequence (shape=[batch_size, time length, num features]),
                    tgt_length (shape=[batch_size])
                  ]
                }

        Returns:
          averaged CTC loss.
        """
        logits = input_dict['decoder_output']['logits']
        tgt_sequence, tgt_length = input_dict['target_tensors']
        # this loss needs an access to src_length since they
        # might get changed in the encoder
        src_length = input_dict['decoder_output']['src_length']

        # Compute the CTC loss
        total_loss = tf.nn.ctc_loss(
            labels=dense_to_sparse(tgt_sequence, tgt_length),
            inputs=logits,
            sequence_length=src_length,
            ignore_longer_outputs_than_inputs=True,
        )

        if self._mask_nan:
            total_loss = mask_nans(total_loss)

        # Calculate the average loss across the batch
        avg_loss = tf.reduce_mean(total_loss)
        return avg_loss
Exemple #3
0
  def _compute_loss(self, input_dict):
    """Build the CTC loss graph and return the batch-averaged loss.

    Args:
      input_dict (dict): input dictionary that has to contain
          the following fields::

            input_dict = {
              "decoder_output": {
                "logits": tensor, shape [batch_size, time length, tgt_vocab_size]
                "src_length": tensor, shape [batch_size]
              },
              "target_tensors": [
                tgt_sequence (shape=[batch_size, time length, num features]),
                tgt_length (shape=[batch_size])
              ]
            }

    Returns:
      averaged CTC loss.
    """
    decoder_output = input_dict['decoder_output']
    logits = decoder_output['logits']
    targets, target_lengths = input_dict['target_tensors']
    # The source lengths come from the decoder output because the encoder
    # might have changed them.
    input_lengths = decoder_output['src_length']

    # Per-example CTC loss; labels are handed over in sparse form.
    sparse_labels = dense_to_sparse(targets, target_lengths)
    per_example_loss = tf.nn.ctc_loss(
      labels=sparse_labels,
      inputs=logits,
      sequence_length=input_lengths,
      ignore_longer_outputs_than_inputs=True,
    )

    if self._mask_nan:
      per_example_loss = mask_nans(per_example_loss)

    # Average over the batch.
    return tf.reduce_mean(per_example_loss)
Exemple #4
0
def post_process_gradients(grads_and_vars, summaries, lr, clip_gradients,
                           larc_params):
    """Applies post processing to gradients, i.e. clipping, LARC, summaries.

    Args:
      grads_and_vars: list of ``(gradient, variable)`` pairs. A gradient may
          be ``None``; such pairs get no per-variable summaries and are
          passed through the LARC step unchanged.
      summaries: collection of summary names; each summary is emitted only
          if its name is contained in this collection.
      lr: learning rate, used by LARC in 'clip' mode and for the
          ``larc_final_lr`` summary.
      clip_gradients: if not ``None``, forwarded to
          ``_clip_gradients_by_norm`` together with ``grads_and_vars``.
      larc_params: ``None`` to disable LARC, otherwise a dict with required
          key ``'larc_eta'`` (float) and optional keys ``'larc_mode'``
          ('clip' or 'scale', default 'clip'), ``'min_update'`` (float,
          default 1e-7) and ``'epsilon'`` (float, default 1e-7).

    Returns:
      list of ``(gradient, variable)`` pairs after clipping and LARC
      re-scaling (the input pairs if neither is enabled).
    """
    if "global_gradient_norm" in summaries:
        tf.summary.scalar(
            "global_gradient_norm",
            _global_norm_with_cast(grads_and_vars),
        )

    # Optionally clip gradients by global norm.
    if clip_gradients is not None:
        grads_and_vars = _clip_gradients_by_norm(grads_and_vars,
                                                 clip_gradients)

    # Add histograms for variables, gradients and gradient norms.
    for gradient, variable in grads_and_vars:
        if isinstance(gradient, tf.IndexedSlices):
            grad_values = gradient.values
        else:
            grad_values = gradient

        if isinstance(variable, tf.IndexedSlices):
            var_values = variable.values
        else:
            var_values = variable

        if grad_values is not None:
            var_name = variable.name.replace(":", "_")
            if "gradients" in summaries:
                # need to mask nans for automatic loss scaling
                tf.summary.histogram("gradients/%s" % var_name,
                                     mask_nans(grad_values))
            if "gradient_norm" in summaries:
                tf.summary.scalar("gradient_norm/%s" % var_name,
                                  tf.norm(grad_values))
            if "variables" in summaries:
                tf.summary.histogram("variables/%s" % var_name, var_values)
            if "variable_norm" in summaries:
                tf.summary.scalar("variable_norm/%s" % var_name,
                                  tf.norm(var_values))

    if clip_gradients is not None and "global_gradient_norm" in summaries:
        tf.summary.scalar(
            "global_clipped_gradient_norm",
            _global_norm_with_cast(grads_and_vars),
        )

    # LARC gradient re-scaling
    if larc_params is not None:
        check_params(
            config=larc_params,
            required_dict={'larc_eta': float},
            optional_dict={
                'larc_mode': ['clip', 'scale'],
                'min_update': float,
                'epsilon': float
            },
        )
        larc_eta = larc_params['larc_eta']
        larc_mode = larc_params.get('larc_mode', 'clip')
        min_update = larc_params.get('min_update', 1e-7)
        eps = larc_params.get('epsilon', 1e-7)

        grads_and_vars_larc = [None] * len(grads_and_vars)
        for idx, (g, v) in enumerate(grads_and_vars):
            if g is None:
                # Nothing to re-scale: tf.cast / multiplication would fail on
                # a missing gradient, so keep the pair as is (the summary
                # loop above skips missing gradients in the same way).
                grads_and_vars_larc[idx] = (g, v)
                continue

            var_dtype = v.dtype
            # Norms are computed in float32 regardless of the variable dtype.
            v_norm = tf.norm(tensor=tf.cast(v, tf.float32), ord=2)
            g_norm = tf.norm(tensor=tf.cast(g, tf.float32), ord=2)

            if larc_mode == 'clip':
                larc_grad_update = tf.maximum(
                    larc_eta * v_norm / (lr * (g_norm + eps)),
                    min_update,
                )
                if "larc_summaries" in summaries:
                    tf.summary.scalar(
                        'larc_clip_on/{}'.format(v.name),
                        tf.cast(tf.less(larc_grad_update, 1.0), tf.int32))
                # In 'clip' mode the multiplier never exceeds 1.
                larc_grad_update = tf.minimum(larc_grad_update, 1.0)
            else:
                larc_grad_update = tf.maximum(
                    larc_eta * v_norm / (g_norm + eps),
                    min_update,
                )
            larc_grad_update = tf.saturate_cast(larc_grad_update, var_dtype)
            grads_and_vars_larc[idx] = (larc_grad_update * g, v)

            # adding additional summary
            if "larc_summaries" in summaries:
                tf.summary.scalar('larc_grad_update/{}'.format(v.name),
                                  larc_grad_update)
                tf.summary.scalar("larc_final_lr/{}".format(v.name),
                                  tf.cast(lr, var_dtype) * larc_grad_update)
        grads_and_vars = grads_and_vars_larc
    return grads_and_vars
Exemple #5
0
def optimize_loss(loss,
                  learning_rate,
                  optimizer,
                  optimizer_params,
                  global_step=None,
                  dtype=tf.float32,
                  gradient_noise_scale=None,
                  gradient_multipliers=None,
                  clip_gradients=None,
                  learning_rate_decay_fn=None,
                  update_ops=None,
                  variables=None,
                  name=None,
                  summaries=None,
                  colocate_gradients_with_ops=False,
                  increment_global_step=True,
                  LARC_nu=None,
                  LARC_mode='clip',
                  loss_scale=1.0,
                  automatic_loss_scaling=None,
                  on_horovod=False):
    """Given loss and parameters for optimizer, returns a training op.

    Various ways of passing optimizers include:

    - by string specifying the name of the optimizer. See OPTIMIZER_CLS_NAMES
        for full list. E.g. `optimize_loss(..., optimizer='Adam')`.
    - by function taking learning rate `Tensor` as argument and returning an
        `Optimizer` instance. E.g. `optimize_loss(...,
        optimizer=lambda lr: tf.train.MomentumOptimizer(lr, momentum=0.5))`.
      Alternatively, if `learning_rate` is `None`, the function takes no
      arguments. E.g. `optimize_loss(..., learning_rate=None,
        optimizer=lambda: tf.train.MomentumOptimizer(0.5, momentum=0.5))`.
    - by a subclass of `Optimizer` having a single-argument constructor
        (the argument is the learning rate), such as AdamOptimizer or
        AdagradOptimizer. E.g. `optimize_loss(...,
        optimizer=tf.train.AdagradOptimizer)`.
    - by an instance of a subclass of `Optimizer`.
        E.g., `optimize_loss(..., optimizer=tf.train.AdagradOptimizer(0.5))`.

    Args:
      loss: Scalar `Tensor`.
      learning_rate: float or 0-d `Tensor`, magnitude of update per each
                     training step. Can be `None`.
      optimizer: string, class or optimizer instance, used as trainer.
                 string should be name of optimizer, like 'SGD',
                   'Adam', 'Adagrad'. Full list in OPTIMIZER_CLS_NAMES constant.
                 class should be sub-class of `tf.Optimizer` that implements
                   `compute_gradients` and `apply_gradients` functions.
                 optimizer instance should be instantiation of `tf.Optimizer`
                   sub-class and have `compute_gradients` and `apply_gradients`
                   functions.
      optimizer_params: dict of keyword arguments forwarded to the optimizer
                        constructor (or callable) when `optimizer` is not
                        already an `Optimizer` instance.
      global_step: Scalar int `Tensor`, step counter to update on each step
                   unless `increment_global_step` is `False`. If not supplied,
                   `tf.train.get_or_create_global_step()` is used.
      dtype: if equal to "mixed", the optimizer is wrapped in
             `MixedPrecisionOptimizerWrapper`.
      gradient_noise_scale: float or None, adds 0-mean normal noise scaled by this
                            value.
      gradient_multipliers: dict of variables or variable names to floats.
                            If present, gradients for specified
                            variables will be multiplied by given constant.
      clip_gradients: float, callable or `None`. If float, is provided, a global
        clipping is applied to prevent the norm of the gradient to exceed this
        value. Alternatively, a callable can be provided e.g.: adaptive_clipping.
        This callable takes a `list` of `(gradients, variables)` `tuple`s and
        returns the same thing with the gradients modified.
      learning_rate_decay_fn: function, takes `learning_rate` and `global_step`
                              `Tensor`s, returns `Tensor`.
                              Can be used to implement any learning rate decay
                              functions.
                              For example: `tf.train.exponential_decay`.
                              Ignored if `learning_rate` is not supplied.
      update_ops: list of update `Operation`s to execute at each step. If `None`,
                  uses elements of UPDATE_OPS collection. The order of execution
                  between `update_ops` and `loss` is non-deterministic.
      variables: list of variables to optimize or
                 `None` to use all trainable variables.
      name: The name for this operation is used to scope operations and summaries.
      summaries: List of internal quantities to visualize on tensorboard. If not
                 set, "learning_rate" and "global_gradient_norm" are used. The
                 complete list is in OPTIMIZER_SUMMARIES.
      colocate_gradients_with_ops: If True, try colocating gradients with the
                                   corresponding op.
      increment_global_step: Whether to increment `global_step`. If your model
        calls `optimize_loss` multiple times per training step (e.g. to optimize
        different parts of the model), use this arg to avoid incrementing
        `global_step` more times than necessary.
      LARC_nu: If a float, LARC re-scaling will be
               applied https://arxiv.org/pdf/1708.03888.pdf with nu=LARC_nu
      LARC_mode: 'scale' or 'clip'
      loss_scale: the loss is multiplied by this constant before gradients are
                  computed and the gradients are multiplied by its inverse
                  afterwards. No scaling is done when it equals 1.0.
      automatic_loss_scaling: if not None, use the corresponding automatic
                              loss scaling algorithm. Must be one of
                              `AutomaticLossScaler.SUPPORTED_ALGOS`.
                              `dtype` must be "mixed" to use ALS.
      on_horovod: if True, the optimizer is wrapped in
                  `horovod.tensorflow.DistributedOptimizer`.

    Returns:
      Training op: the loss tensor with a control dependency on the gradient
      update op.

    Raises:
      ValueError: if:
          * `loss` is an invalid type or shape.
          * `global_step` is an invalid type or shape.
          * `learning_rate` is an invalid type or value.
          * `optimizer` has the wrong type.
          * `summaries` contains an unknown name.
          * `automatic_loss_scaling` is unknown or used with a `dtype` other
            than "mixed".
          * `clip_gradients` is neither float nor callable.
          * `gradients` is empty.
      AttributeError: if both `clip_gradients` and `LARC_nu` are supplied.
    """
    loss = ops.convert_to_tensor(loss)
    contrib_framework.assert_scalar(loss)
    if global_step is None:
        global_step = tf.train.get_or_create_global_step()
    else:
        tf.train.assert_global_step(global_step)
    with vs.variable_scope(name, "OptimizeLoss", [loss, global_step]):
        # Update ops take UPDATE_OPS collection if not provided.
        if update_ops is None:
            update_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS))
        # Make sure update ops are ran before computing loss.
        if update_ops:
            loss = control_flow_ops.with_dependencies(list(update_ops), loss)

        # Learning rate variable, with possible decay.
        lr = None
        if learning_rate is not None:
            if isinstance(learning_rate, ops.Tensor) and \
               learning_rate.get_shape().ndims == 0:
                lr = learning_rate
            elif isinstance(learning_rate, float):
                if learning_rate < 0.0:
                    # Format the value into the message instead of passing it
                    # as a second exception argument.
                    raise ValueError("Invalid learning_rate %s." %
                                     learning_rate)
                lr = vs.get_variable(
                    "learning_rate", [],
                    trainable=False,
                    initializer=init_ops.constant_initializer(learning_rate))
            else:
                raise ValueError(
                    "Learning rate should be 0d Tensor or float. "
                    "Got %s of type %s" %
                    (str(learning_rate), str(type(learning_rate))))

        if summaries is None:
            summaries = ["learning_rate", "global_gradient_norm"]
        else:
            for summ in summaries:
                if summ not in OPTIMIZER_SUMMARIES:
                    raise ValueError(
                        "Summaries should be one of [%s], you provided %s." %
                        (", ".join(OPTIMIZER_SUMMARIES), summ))
        if learning_rate is not None and learning_rate_decay_fn is not None:
            if global_step is None:
                raise ValueError(
                    "global_step is required for learning_rate_decay_fn.")
            lr = learning_rate_decay_fn(lr, global_step)
            if "learning_rate" in summaries:
                summary.scalar("learning_rate", lr)

        # Create optimizer, given specified parameters.
        if isinstance(optimizer, six.string_types):
            if lr is None:
                raise ValueError(
                    "Learning rate is None, but should be specified if "
                    "optimizer is string (%s)." % optimizer)
            if optimizer not in OPTIMIZER_CLS_NAMES:
                raise ValueError(
                    "Optimizer name should be one of [%s], you provided %s." %
                    (", ".join(OPTIMIZER_CLS_NAMES), optimizer))
            opt = OPTIMIZER_CLS_NAMES[optimizer](learning_rate=lr,
                                                 **optimizer_params)
        elif (isinstance(optimizer, type)
              and issubclass(optimizer, optimizer_.Optimizer)):
            if lr is None:
                raise ValueError(
                    "Learning rate is None, but should be specified if "
                    "optimizer is class (%s)." % optimizer)
            opt = optimizer(learning_rate=lr, **optimizer_params)
        elif isinstance(optimizer, optimizer_.Optimizer):
            opt = optimizer
        elif callable(optimizer):
            if learning_rate is not None:
                opt = optimizer(lr, **optimizer_params)
            else:
                opt = optimizer(**optimizer_params)
            if not isinstance(opt, optimizer_.Optimizer):
                raise ValueError(
                    "Unrecognized optimizer: function should return "
                    "subclass of Optimizer. Got %s." % str(opt))
        else:
            raise ValueError(
                "Unrecognized optimizer: should be string, "
                "subclass of Optimizer, instance of "
                "subclass of Optimizer or function with one argument. "
                "Got %s." % str(optimizer))
        if on_horovod:
            # Imported lazily so horovod is only required when it is used.
            import horovod.tensorflow as hvd
            opt = hvd.DistributedOptimizer(opt)
        # All trainable variables, if specific variables are not specified.
        if variables is None:
            variables = vars_.trainable_variables()

        if automatic_loss_scaling is not None:
            if automatic_loss_scaling not in AutomaticLossScaler.SUPPORTED_ALGOS:
                raise ValueError(
                    "Unknown automatic loss scaling algorithm: %s." %
                    automatic_loss_scaling)
            if dtype != "mixed":
                raise ValueError(
                    "Automatic loss scaling can be used only with "
                    "dtype=mixed.")
            loss_scaler = AutomaticLossScaler(algorithm=automatic_loss_scaling)
        else:
            loss_scaler = None

        if dtype == 'mixed':
            opt = MixedPrecisionOptimizerWrapper(
                opt,
                automatic_loss_scaler=loss_scaler,
            )

        # Compute gradients.
        gradients = opt.compute_gradients(
            loss if loss_scale == 1.0 else loss * loss_scale,
            variables,
            colocate_gradients_with_ops=colocate_gradients_with_ops)

        # Undo the static loss scaling on the gradients.
        if loss_scale != 1.0:
            gradients = _multiply_gradients_const(gradients, 1.0 / loss_scale)

        # Optionally add gradient noise.
        if gradient_noise_scale is not None:
            gradients = _add_scaled_noise_to_gradients(gradients,
                                                       gradient_noise_scale)

        # Multiply some gradients.
        if gradient_multipliers is not None:
            gradients = _multiply_gradients(gradients, gradient_multipliers)
            if not gradients:
                raise ValueError(
                    "Empty list of (gradient, var) pairs encountered. This is most "
                    "likely to be caused by an improper value of gradient_multipliers."
                )

        if "global_gradient_norm" in summaries or "gradient_norm" in summaries:
            summary.scalar(
                "global_norm/gradient_norm",
                clip_ops.global_norm(
                    list(
                        map(lambda x: tf.cast(x, tf.float32),
                            list(zip(*gradients))[0]))),
            )

        # Optionally clip gradients by global norm.
        if clip_gradients is not None and LARC_nu is not None:
            raise AttributeError(
                "LARC and gradient norm clipping should not be used together")
        if isinstance(clip_gradients, float):
            gradients = _clip_gradients_by_norm(gradients, clip_gradients)
        elif callable(clip_gradients):
            gradients = clip_gradients(gradients)
        elif clip_gradients is not None:
            raise ValueError("Unknown type %s for clip_gradients" %
                             type(clip_gradients))

        # Add histograms for variables, gradients and gradient norms.
        for gradient, variable in gradients:
            if isinstance(gradient, ops.IndexedSlices):
                grad_values = gradient.values
            else:
                grad_values = gradient

            if isinstance(variable, ops.IndexedSlices):
                var_values = variable.values
            else:
                var_values = variable

            if grad_values is not None:
                var_name = variable.name.replace(":", "_")
                if "gradients" in summaries:
                    summary.histogram("gradients/%s" % var_name,
                                      mask_nans(grad_values))
                if "gradient_norm" in summaries:
                    summary.scalar("gradient_norm/%s" % var_name,
                                   clip_ops.global_norm([grad_values]))
                if "variables" in summaries:
                    summary.histogram("variables/%s" % var_name, var_values)
                if "variable_norm" in summaries:
                    summary.scalar("variable_norm/%s" % var_name,
                                   clip_ops.global_norm([var_values]))

        if clip_gradients is not None and ("global_gradient_norm" in summaries
                                           or "gradient_norm" in summaries):
            summary.scalar(
                "global_norm/clipped_gradient_norm",
                clip_ops.global_norm(
                    list(
                        map(lambda x: tf.cast(x, tf.float32),
                            list(zip(*gradients))[0]))),
            )

        # LARC gradient re-scaling (only applied when LARC_nu is a float).
        if isinstance(LARC_nu, float):
            for idx, (g, v) in enumerate(gradients):
                var_dtype = v.dtype
                v_norm = tf.norm(tensor=tf.cast(v, tf.float32), ord=2)
                g_norm = tf.norm(tensor=tf.cast(g, tf.float32), ord=2)

                zero_const = tf.constant(0.0, dtype=tf.float32)
                # this condition is necessary for two reasons.
                # First, g_norm can be zero. In that case we can not use the usual
                # formula, but it does not matter what larc_local_lr is, since it will
                # be multiplied by g = 0.
                # Second, v_norm can be zero. In that case, we want to switch to the
                # usual gradient descent since the usual LARC update will always be 0.
                # Thus, we make larc_local_lr = 1e-5 so that
                # we move a little bit away from zero and can use LARC again
                larc_local_lr = tf.cond(
                    pred=tf.logical_and(
                        tf.greater(v_norm, zero_const),
                        tf.greater(g_norm, zero_const),
                    ),
                    true_fn=lambda: LARC_nu * v_norm / g_norm,
                    false_fn=lambda: tf.Print(
                        tf.constant(1e-5, dtype=tf.float32), [],
                        "g_norm = 0 or v_norm = 0 for {}".format(v.name)),
                )
                if LARC_mode == 'clip':
                    summary.scalar(
                        'switched_to_global/{}'.format(v.name),
                        tf.cast(tf.greater(larc_local_lr, lr), tf.int32))
                    # In 'clip' mode the multiplier is capped at 1.0, i.e. the
                    # effective rate never exceeds the global learning rate.
                    larc_local_lr = tf.cond(
                        pred=tf.less(larc_local_lr, lr),
                        true_fn=lambda: larc_local_lr / lr,
                        false_fn=lambda: tf.constant(1.0, dtype=tf.float32),
                    )
                larc_local_lr = tf.saturate_cast(larc_local_lr, var_dtype)
                summary.scalar('larc_local_lr/{}'.format(v.name),
                               larc_local_lr)
                summary.scalar('larc_effective_lr/{}'.format(v.name),
                               larc_local_lr * tf.cast(lr, var_dtype))
                gradients[idx] = (larc_local_lr * g, v)

        # Create gradient updates.
        grad_updates = opt.apply_gradients(
            gradients,
            global_step=global_step if increment_global_step else None,
            name="train")

        # Ensure the train_tensor computes grad_updates.
        train_tensor = control_flow_ops.with_dependencies([grad_updates], loss)

        return train_tensor
Exemple #6
0
def optimize_loss(loss,
                  optimizer,
                  optimizer_params,
                  learning_rate_decay_fn,
                  global_step=None,
                  dtype=tf.float32,
                  gradient_noise_scale=None,
                  gradient_multipliers=None,
                  clip_gradients=None,
                  update_ops=None,
                  variables=None,
                  name=None,
                  summaries=None,
                  colocate_gradients_with_ops=False,
                  increment_global_step=True,
                  larc_params=None,
                  loss_scale=1.0,
                  automatic_loss_scaling=None,
                  on_horovod=False):
  """Given loss and parameters for optimizer, returns a training op.

  The learning rate is always obtained as
  `learning_rate_decay_fn(global_step)`; it is referred to as `lr` below.

  Various ways of passing optimizers include:

  - by string specifying the name of the optimizer. See OPTIMIZER_CLS_NAMES
      for full list. E.g. `optimize_loss(..., optimizer='Adam')`. The
      optimizer is built as `cls(learning_rate=lr, **optimizer_params)`.
  - by a subclass of `Optimizer`, e.g.
      `optimize_loss(..., optimizer=tf.train.AdagradOptimizer)`. It is built
      as `optimizer(learning_rate=lr, **optimizer_params)`.
  - by an instance of a subclass of `Optimizer`, which is used as is
      (`optimizer_params` and `lr` are not applied to it).
      E.g., `optimize_loss(..., optimizer=tf.train.AdagradOptimizer(0.5))`.
  - by function returning an `Optimizer` instance. It is called as
      `optimizer(lr, **optimizer_params)`, or as
      `optimizer(**optimizer_params)` if `lr` is `None`.

  Args:
    loss: Scalar `Tensor`.
    optimizer: string, class, optimizer instance or callable, used as
               trainer (see the list above).
    optimizer_params: dict of keyword arguments forwarded to the optimizer
                      constructor / factory function.
    learning_rate_decay_fn: function, takes the `global_step` `Tensor` and
                            returns the learning rate to use.
    global_step: Scalar int `Tensor`, step counter to update on each step
                 unless `increment_global_step` is `False`. If not supplied,
                 it is obtained with `tf.train.get_or_create_global_step`.
    dtype: model dtype. If equal to "mixed", the optimizer is wrapped in
           `MixedPrecisionOptimizerWrapper`.
    gradient_noise_scale: float or None, adds 0-mean normal noise scaled by this
                          value.
    gradient_multipliers: dict of variables or variable names to floats.
                          If present, gradients for specified
                          variables will be multiplied by given constant.
    clip_gradients: float, callable or `None`. If float, is provided, a global
      clipping is applied to prevent the norm of the gradient to exceed this
      value. Alternatively, a callable can be provided e.g.: adaptive_clipping.
      This callable takes a `list` of `(gradients, variables)` `tuple`s and
      returns the same thing with the gradients modified.
      Cannot be combined with `larc_params`.
    update_ops: list of update `Operation`s to execute at each step. If `None`,
                uses elements of UPDATE_OPS collection. The order of execution
                between `update_ops` and `loss` is non-deterministic.
    variables: list of variables to optimize or
               `None` to use all trainable variables.
    name: The name for this operation is used to scope operations and summaries.
    summaries: List of internal quantities to visualize on tensorboard. If not
               set, the learning rate and the global gradient norm will be
               reported. The complete list is in OPTIMIZER_SUMMARIES.
    colocate_gradients_with_ops: If True, try colocating gradients with the
                                 corresponding op.
    increment_global_step: Whether to increment `global_step`. If your model
      calls `optimize_loss` multiple times per training step (e.g. to optimize
      different parts of the model), use this arg to avoid incrementing
      `global_step` more times than necessary.
    larc_params: dict or None. If not None, LARC gradient re-scaling
                 (https://arxiv.org/pdf/1708.03888.pdf) is applied. Keys:
                 'larc_eta' (float, required), 'larc_mode' ('clip' or
                 'scale', defaults to 'clip'), 'min_update' (float, defaults
                 to 1e-7) and 'epsilon' (float, defaults to 1e-7).
    loss_scale: loss scale passed to `MixedPrecisionOptimizerWrapper` when
                `dtype` is "mixed". Replaced by an `AutomaticLossScaler`
                when `automatic_loss_scaling` is set.
    automatic_loss_scaling: if not None, use the corresponding automatic
                            loss scaling algorithm. Must be one of 'Backoff'
                            or 'LogMax'. `dtype` must be "mixed" to use ALS.
    on_horovod: if True, the optimizer is wrapped in `DistributedOptimizer`.

  Returns:
    Training op: the `loss` tensor with a control dependency on the gradient
    update operation.

  Raises:
    ValueError: if:
        * `loss` is an invalid type or shape.
        * `global_step` is an invalid type or shape.
        * `summaries` contains a name that is not in OPTIMIZER_SUMMARIES.
        * `optimizer` has the wrong type, or is a string / class while the
          learning rate is `None`.
        * `automatic_loss_scaling` is not a supported algorithm, or is used
          with `dtype` other than "mixed".
        * `clip_gradients` is neither float nor callable.
        * `gradients` is empty after applying `gradient_multipliers`.
    AttributeError: if both `clip_gradients` and `larc_params` are supplied.
  """
  loss = ops.convert_to_tensor(loss)
  contrib_framework.assert_scalar(loss)
  if global_step is None:
    global_step = tf.train.get_or_create_global_step()
  else:
    tf.train.assert_global_step(global_step)
  with vs.variable_scope(name, "OptimizeLoss", [loss, global_step]):
    # Update ops take UPDATE_OPS collection if not provided.
    if update_ops is None:
      update_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS))
    # Make sure update ops are ran before computing loss.
    if update_ops:
      loss = control_flow_ops.with_dependencies(list(update_ops), loss)

    if summaries is None:
      summaries = ["learning_rate", "global_gradient_norm"]
    else:
      for summ in summaries:
        if summ not in OPTIMIZER_SUMMARIES:
          raise ValueError("Summaries should be one of [%s], you provided %s." %
                           (", ".join(OPTIMIZER_SUMMARIES), summ))
    # global_step is always set at this point (fetched or created above),
    # so the learning rate schedule can be evaluated unconditionally.
    lr = learning_rate_decay_fn(global_step)

    if "learning_rate" in summaries:
      summary.scalar("learning_rate", lr)

    # Create optimizer, given specified parameters.
    if isinstance(optimizer, six.string_types):
      if lr is None:
        raise ValueError("Learning rate is None, but should be specified if "
                         "optimizer is string (%s)." % optimizer)
      if optimizer not in OPTIMIZER_CLS_NAMES:
        raise ValueError(
            "Optimizer name should be one of [%s], you provided %s." %
            (", ".join(OPTIMIZER_CLS_NAMES), optimizer))
      opt = OPTIMIZER_CLS_NAMES[optimizer](learning_rate=lr, **optimizer_params)
    elif (isinstance(optimizer, type) and
          issubclass(optimizer, optimizer_.Optimizer)):
      if lr is None:
        raise ValueError("Learning rate is None, but should be specified if "
                         "optimizer is class (%s)." % optimizer)
      opt = optimizer(learning_rate=lr, **optimizer_params)
    elif isinstance(optimizer, optimizer_.Optimizer):
      opt = optimizer
    elif callable(optimizer):
      if lr is not None:
        opt = optimizer(lr, **optimizer_params)
      else:
        opt = optimizer(**optimizer_params)
      if not isinstance(opt, optimizer_.Optimizer):
        raise ValueError("Unrecognized optimizer: function should return "
                         "subclass of Optimizer. Got %s." % str(opt))
    else:
      raise ValueError("Unrecognized optimizer: should be string, "
                       "subclass of Optimizer, instance of "
                       "subclass of Optimizer or function with one argument. "
                       "Got %s." % str(optimizer))
    # All trainable variables, if specific variables are not specified.
    if variables is None:
      variables = vars_.trainable_variables()

    if automatic_loss_scaling is not None:
      if automatic_loss_scaling not in AutomaticLossScaler.SUPPORTED_ALGOS:
        raise ValueError("Unknown automatic loss scaling algorithm: %s."
                         % automatic_loss_scaling)
      if dtype != "mixed":
        raise ValueError("Automatic loss scaling can be used only with "
                         "dtype=mixed.")
      # The automatic scaler takes the place of the static loss scale.
      loss_scale = AutomaticLossScaler(algorithm=automatic_loss_scaling)

    if dtype == 'mixed':
      opt = MixedPrecisionOptimizerWrapper(opt, loss_scale=loss_scale)
    if on_horovod:
      opt = DistributedOptimizer(opt)

    # Compute gradients.
    gradients = opt.compute_gradients(
      loss, variables,
      colocate_gradients_with_ops=colocate_gradients_with_ops,
    )

    # Optionally add gradient noise.
    if gradient_noise_scale is not None:
      gradients = _add_scaled_noise_to_gradients(gradients,
                                                 gradient_noise_scale)

    # Multiply some gradients.
    if gradient_multipliers is not None:
      gradients = _multiply_gradients(gradients, gradient_multipliers)
      if not gradients:
        raise ValueError(
            "Empty list of (gradient, var) pairs encountered. This is most "
            "likely to be caused by an improper value of gradient_multipliers.")

    # Gradients are cast to float32 before the global norm is computed.
    if "global_gradient_norm" in summaries or "gradient_norm" in summaries:
      summary.scalar(
        "global_norm/gradient_norm",
        clip_ops.global_norm(list(map(
          lambda x: tf.cast(x, tf.float32),
          list(zip(*gradients))[0])
        )),
      )

    # Optionally clip gradients by global norm.
    if clip_gradients is not None and larc_params is not None:
      raise AttributeError(
        "LARC and gradient norm clipping should not be used together"
      )
    if isinstance(clip_gradients, float):
      gradients = _clip_gradients_by_norm(gradients, clip_gradients)
    elif callable(clip_gradients):
      gradients = clip_gradients(gradients)
    elif clip_gradients is not None:
      raise ValueError(
          "Unknown type %s for clip_gradients" % type(clip_gradients))

    # Add histograms for variables, gradients and gradient norms.
    for gradient, variable in gradients:
      if isinstance(gradient, ops.IndexedSlices):
        grad_values = gradient.values
      else:
        grad_values = gradient

      if isinstance(variable, ops.IndexedSlices):
        var_values = variable.values
      else:
        var_values = variable

      if grad_values is not None:
        # ":" is replaced with "_" to build the summary tag.
        var_name = variable.name.replace(":", "_")
        if "gradients" in summaries:
          summary.histogram("gradients/%s" % var_name, mask_nans(grad_values))
        if "gradient_norm" in summaries:
          summary.scalar("gradient_norm/%s" % var_name,
                         clip_ops.global_norm([grad_values]))
        if "variables" in summaries:
          summary.histogram("variables/%s" % var_name, var_values)
        if "variable_norm" in summaries:
          summary.scalar("variable_norm/%s" % var_name,
                         clip_ops.global_norm([var_values]))

    if clip_gradients is not None and ("global_gradient_norm" in summaries or
                                       "gradient_norm" in summaries):
      summary.scalar(
        "global_norm/clipped_gradient_norm",
        clip_ops.global_norm(list(map(
          lambda x: tf.cast(x, tf.float32),
          list(zip(*gradients))[0])
        )),
      )

    # LARC gradient re-scaling
    if larc_params is not None:
      check_params(
        config=larc_params,
        required_dict={'larc_eta': float},
        optional_dict={
          'larc_mode': ['clip', 'scale'],
          'min_update': float,
          'epsilon': float
        },
      )
      larc_eta = larc_params['larc_eta']
      larc_mode = larc_params.get('larc_mode', 'clip')
      min_update = larc_params.get('min_update', 1e-7)
      eps = larc_params.get('epsilon', 1e-7)

      for idx, (g, v) in enumerate(gradients):
        var_dtype = v.dtype
        # Norms are computed in float32 regardless of the variable dtype.
        v_norm = tf.norm(tensor=tf.cast(v, tf.float32), ord=2)
        g_norm = tf.norm(tensor=tf.cast(g, tf.float32), ord=2)

        if larc_mode == 'clip':
          # clip mode: min(max(eta * |v| / (lr * (|g| + eps)), min_update), 1)
          larc_grad_update = tf.maximum(
            larc_eta * v_norm / (lr * (g_norm + eps)),
            min_update,
          )
          if "larc_summaries" in summaries:
            summary.scalar('larc_clip_on/{}'.format(v.name),
                           tf.cast(tf.less(larc_grad_update, 1.0), tf.int32))
          larc_grad_update = tf.minimum(larc_grad_update, 1.0)
        else:
          # scale mode: max(eta * |v| / (|g| + eps), min_update)
          larc_grad_update = tf.maximum(
            larc_eta * v_norm / (g_norm + eps),
            min_update,
          )
        # Cast the multiplier back to the variable dtype before applying it.
        larc_grad_update = tf.saturate_cast(larc_grad_update, var_dtype)
        gradients[idx] = (larc_grad_update * g, v)

        # adding additional summary
        if "larc_summaries" in summaries:
          summary.scalar('larc_grad_update/{}'.format(v.name), larc_grad_update)
          summary.scalar("larc_final_lr/{}".format(v.name),
                         tf.cast(lr, var_dtype) * larc_grad_update)

    # Create gradient updates.
    grad_updates = opt.apply_gradients(
        gradients,
        global_step=global_step if increment_global_step else None,
        name="train")

    # Ensure the train_tensor computes grad_updates.
    train_tensor = control_flow_ops.with_dependencies([grad_updates], loss)

    return train_tensor