Example #1
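These snippets target TF 1.x; the `tf.` and `contrib_tpu.` prefixes presumably correspond to `import tensorflow as tf` and `from tensorflow.contrib import tpu as contrib_tpu`.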
    def build_training_op(self, loss):
        """Get training operation.

    Args:
      loss: a loss function for training.

    Define the optimization operation and perform gradient calculation for both
      TPU/Non-TPU training.

    Returns:
      Computed gradient.
    """
        adam_optimizer = tf.train.AdamOptimizer(
            learning_rate=self._decayed_learning_rate, epsilon=1e-5)
        if self._use_tpu:
            # Notes from: learning/brain/research/dune/examples/v2018_09/train.py
            # On TPUs, reduce_mean runs on each chip separately, and by default
            # only the loss of the first chip is reported.
            #
            # You can either:
            # - keep this branch, which synchronizes the losses across the
            #   chips to obtain the full loss over all samples, or
            # - remove this section, gaining some performance but reporting the
            #   loss only from the first chip.
            # CrossShardOptimizer scales the loss in compute_gradients so that
            # gradients are averaged across the TPU shards.
            adam_optimizer = tf.tpu.CrossShardOptimizer(adam_optimizer)

            tpu_sum_loss = contrib_tpu.cross_replica_sum(loss /
                                                         self._tpu_num_shards)

            grads_and_vars = adam_optimizer.compute_gradients(
                tpu_sum_loss, self.total_params)
            sum_grads = []
            sum_vars = []
            for (grad, var) in grads_and_vars:
                if grad is None:
                    sum_grads.append(grad)
                    sum_vars.append(var)
                else:
                    # Average this gradient across the TPU shards.
                    sum_grads.append(
                        contrib_tpu.cross_replica_sum(grad) /
                        self._tpu_num_shards)
                    sum_vars.append(var)
            # Clip the cross-replica averaged gradients by global norm.
            norm_grads, _ = tf.clip_by_global_norm(sum_grads, 0.5)
            grads_and_vars = list(zip(norm_grads, sum_vars))
        else:
            grads_and_vars = adam_optimizer.compute_gradients(
                loss, self.total_params)
            grads, variables = zip(*grads_and_vars)
            # Clip the gradients by global norm before applying them.
            norm_grads, _ = tf.clip_by_global_norm(grads, 0.5)
            grads_and_vars = list(zip(norm_grads, variables))

        return adam_optimizer.apply_gradients(
            grads_and_vars, global_step=tf.train.get_global_step())
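For reference, the non-TPU branch above follows the standard compute/clip/apply pattern. Below is a minimal, self-contained TF 1.x sketch of that pattern; the toy variable, toy loss and hyperparameter values are illustrative and not taken from the original class:

import tensorflow as tf

# Toy parameter and loss, only to exercise the gradient pipeline.
weights = tf.Variable([1.0, 2.0], name="weights")
loss = tf.reduce_sum(tf.square(weights))

optimizer = tf.train.AdamOptimizer(learning_rate=1e-3, epsilon=1e-5)
grads_and_vars = optimizer.compute_gradients(loss, var_list=[weights])
grads, variables = zip(*grads_and_vars)

# Clip by global norm, matching the clip_norm of 0.5 used above.
clipped_grads, _ = tf.clip_by_global_norm(grads, 0.5)
train_op = optimizer.apply_gradients(
    list(zip(clipped_grads, variables)),
    global_step=tf.train.get_or_create_global_step())

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op)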
Example #2
def tpu_all_sum(tensor, name=None):
    """Sums `tensor` across all TPU replicas."""
    return contrib_tpu.cross_replica_sum(tensor, name=name)
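`cross_replica_sum` is only meaningful inside a replicated TPU computation. The following is a hedged sketch of how the `tpu_all_sum` helper from Example #2 could be used to average a per-replica loss; it requires an actual TPU runtime to run, and `num_shards` and the random local loss are illustrative, not from the original code:

import tensorflow as tf
from tensorflow.contrib import tpu as contrib_tpu  # TF 1.x

num_shards = 8  # illustrative shard count


def per_replica_step():
    # Each replica computes its own loss; cross_replica_sum of the scaled
    # losses yields the mean loss over all shards on every replica.
    local_loss = tf.random_uniform([])
    return tpu_all_sum(local_loss / num_shards)


# contrib_tpu.shard replicates the computation across `num_shards` TPU cores.
outputs = contrib_tpu.shard(per_replica_step, num_shards=num_shards)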