Example #1
    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        """Accumulates scaled gradients and applies them every N micro-batches."""
        # With a single micro-batch there is nothing to accumulate; delegate
        # straight to the wrapped optimizer.
        if self._num_micro_batches == 1:
            return self._opt.apply_gradients(grads_and_vars, global_step)
        global_step = global_step or py_utils.GetOrCreateGlobalStepVar()
        # Create a 'grad_accum' slot for every variable being optimized.
        with tf.init_scope():
            self._create_slots([v for (_, v) in grads_and_vars])

        accums = []
        variables = []

        for g, v in grads_and_vars:
            accum = self.get_slot(v, 'grad_accum')
            variables.append(v)
            # Scale each micro-batch gradient by 1/num_micro_batches so the
            # accumulated total is the average gradient over the full batch.
            # pytype: disable=attribute-error
            if isinstance(g, tf.IndexedSlices):
                scaled_grad = tf.IndexedSlices(
                    g.values / self._num_micro_batches,
                    g.indices,
                    dense_shape=g.dense_shape)
            else:
                scaled_grad = g / self._num_micro_batches
            # Add the scaled gradient into the accumulator slot.
            accum_tensor = accum.read_value()
            accums.append(accum.assign(accum_tensor + scaled_grad))
            # pytype: enable=attribute-error

        def _ApplyAndReset():
            # Optionally sum the accumulators across TPU replicas before
            # applying them.
            normalized_accums = accums
            if self._apply_crs_to_grad:
                normalized_accums = [
                    tf.tpu.cross_replica_sum(accum.read_value())
                    for accum in accums
                ]
            apply_op = self._opt.apply_gradients(
                list(zip(normalized_accums, variables)))
            # Zero the accumulators only after the update has been applied.
            with tf.control_dependencies([apply_op]):
                zero_op = [
                    tf.assign(accum, tf.zeros_like(accum)) for accum in accums
                ]
            return tf.group(zero_op, tf.assign_add(global_step, 1))

        def _Accum():
            # Accumulation-only step: the accumulator assigns above already
            # run via the control dependency at the end of this method.
            return tf.no_op()

        # Apply on every num_micro_batches-th step; otherwise only accumulate.
        accum_step = tf.cond(
            tf.equal(
                tf.math.floormod(self._counter + 1, self._num_micro_batches),
                0),
            _ApplyAndReset,  # Apply the accumulated gradients and reset.
            _Accum)  # Accumulate gradients.

        # Ensure the accumulator updates run before the apply/reset branch,
        # then advance the micro-batch counter.
        with tf.control_dependencies([tf.group(accums)]):
            return tf.group(accum_step, tf.assign_add(self._counter, 1))
Example #2
    def _resource_apply_sparse(self, grad, handle, indices):
        # Densify the sparse gradient (IndexedSlices -> Tensor) and reuse the
        # dense update path. Simple, but materializes a full-size tensor.
        return self._resource_apply_dense(
            tf.convert_to_tensor(
                tf.IndexedSlices(grad, indices, tf.shape(handle))), handle)
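
The conversion above densifies the sparse gradient so the dense update path can be reused, trading memory for a single code path. A small standalone sketch of what tf.convert_to_tensor does to an IndexedSlices (the values, indices, and shape below are made up for illustration):

import tensorflow as tf

values = tf.constant([[1.0, 2.0], [3.0, 4.0]])   # gradient rows
indices = tf.constant([0, 2])                    # variable rows they update
dense_shape = tf.constant([4, 2])                # shape of the variable

slices = tf.IndexedSlices(values, indices, dense_shape)
dense = tf.convert_to_tensor(slices)             # densify, as in the example

# Equivalent to scattering the slice values into a zero tensor:
same = tf.scatter_nd(tf.expand_dims(indices, 1), values, dense_shape)
print(tf.reduce_all(tf.equal(dense, same)).numpy())  # True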