def build_training_op(self, loss):
  """Builds the training operation.

  Defines the optimization operation and performs gradient calculation for
  both TPU and non-TPU training.

  Args:
    loss: the loss tensor to minimize.

  Returns:
    A training op that applies the clipped gradients and increments the
    global step.
  """
  adam_optimizer = tf.train.AdamOptimizer(
      learning_rate=self._decayed_learning_rate, epsilon=1e-5)

  if self._use_tpu:
    # Notes from: learning/brain/research/dune/examples/v2018_09/train.py
    # If we use TPUs, reduce_mean runs on each chip separately and by default
    # only the loss of the first chip is reported.
    #
    # You can either:
    # - execute this branch, which synchronizes the losses across the chips
    #   to obtain the full loss on all samples, or
    # - remove this section, gaining some performance and getting the loss
    #   only from the first chip.
    adam_optimizer = tf.tpu.CrossShardOptimizer(adam_optimizer)

    # Average the loss across replicas, then compute per-replica gradients.
    tpu_sum_loss = contrib_tpu.cross_replica_sum(loss / self._tpu_num_shards)
    grads_and_vars = adam_optimizer.compute_gradients(
        tpu_sum_loss, self.total_params)

    # Average the gradients across replicas; None gradients pass through.
    sum_grads = []
    sum_vars = []
    for (grad, var) in grads_and_vars:
      if grad is None:
        sum_grads.append(grad)
        sum_vars.append(var)
      else:
        sum_grads.append(
            contrib_tpu.cross_replica_sum(grad) / self._tpu_num_shards)
        sum_vars.append(var)

    # Clip the averaged gradients by global norm.
    norm_grads, _ = tf.clip_by_global_norm(sum_grads, 0.5)
    grads_and_vars = list(zip(norm_grads, sum_vars))
  else:
    grads_and_vars = adam_optimizer.compute_gradients(
        loss, self.total_params)
    grads, variables = zip(*grads_and_vars)
    norm_grads, _ = tf.clip_by_global_norm(grads, 0.5)
    grads_and_vars = list(zip(norm_grads, variables))

  return adam_optimizer.apply_gradients(
      grads_and_vars, global_step=tf.train.get_global_step())
def tpu_all_sum(tensor, name=None):
  """Sums `tensor` across all TPU replicas."""
  return contrib_tpu.cross_replica_sum(tensor, name=name)
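# Usage sketch (illustrative only, not part of this module): how the op
# returned by `build_training_op` might be driven from a TF 1.x TPUEstimator
# model_fn. `MyModel` and `compute_loss` are hypothetical names standing in
# for whatever this codebase actually defines.
#
#   def model_fn(features, labels, mode, params):
#     model = MyModel(params)                      # hypothetical constructor
#     loss = model.compute_loss(features, labels)  # hypothetical loss helper
#     train_op = model.build_training_op(loss)
#     return contrib_tpu.TPUEstimatorSpec(mode=mode, loss=loss,
#                                         train_op=train_op)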