def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    Calls tpu_ops.cross_replica_sum() to sum gradient contributions across
    replicas, and then applies the real optimizer.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        compute_gradients().
      global_step: Optional Variable to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation. Defaults to the
        name passed to the Optimizer constructor.

    Returns:
      An `Operation` that applies the gradients. If `global_step` was not None,
      that operation also increments `global_step`.

    Raises:
      ValueError: If `grads_and_vars` is malformed.
    """
    summed_grads_and_vars = []
    for (grad, var) in grads_and_vars:
      if grad is None:
        summed_grads_and_vars.append((grad, var))
      else:
        with ops.colocate_with(grad):
          summed_grads_and_vars.append((tpu_ops.cross_replica_sum(
              grad, self._group_assignment), var))
    return self._opt.apply_gradients(summed_grads_and_vars, global_step, name)
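A minimal usage sketch (not part of the example): the method above matches the pattern of tf.compat.v1.tpu.CrossShardOptimizer, which wraps a regular optimizer so that apply_gradients() sums gradients across replicas. Assumes a TF1-style TPU program; cross_replica_sum only executes inside a TPU context.

import tensorflow.compat.v1 as tf

def build_train_op(loss, learning_rate=0.01):
  # CrossShardOptimizer delegates to an apply_gradients() like the one above.
  base_opt = tf.train.GradientDescentOptimizer(learning_rate)
  opt = tf.tpu.CrossShardOptimizer(base_opt)
  global_step = tf.train.get_or_create_global_step()
  return opt.minimize(loss, global_step=global_step)
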
Example #2

    def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
        if (isinstance(value, values.DistributedValues)
                or tensor_util.is_tensor(value)
            ) and tpu_values.enclosing_tpu_context() is not None:
            if reduce_op == reduce_util.ReduceOp.MEAN:
                # TODO(jhseu):  Revisit once we support model-parallelism.
                value *= (1. / self._num_replicas_in_sync)
            elif reduce_op != reduce_util.ReduceOp.SUM:
                raise NotImplementedError(
                    "Currently only support sum & mean in TPUStrategy.")
            return tpu_ops.cross_replica_sum(value)

        if not isinstance(value, values.DistributedValues):
            # This function handles reducing values that are not PerReplica or
            # Mirrored values. For example, the same value could be present on all
            # replicas in which case `value` would be a single value or value could
            # be 0.
            return cross_device_ops_lib.reduce_non_distributed_value(
                reduce_op, value, destinations, self._num_replicas_in_sync)

        value_list = value.values
        # pylint: disable=protected-access
        if isinstance(value, values.DistributedVariable
                      ) and value._packed_variable is not None:
            value_list = tuple(
                value._packed_variable.on_device(d)
                for d in value._packed_variable.devices)
        # pylint: enable=protected-access

        # Currently XLA op by op mode has a limit for the number of inputs for a
        # single op, thus we break one `add_n` op into a group of `add_n` ops to
        # work around the constraint.
        # TODO(cjfj): Detect when it is possible to use `cross_replica_sum`.
        if len(value.values) <= _XLA_OP_BY_OP_INPUTS_LIMIT:
            output = math_ops.add_n(value_list)
        else:
            output = array_ops.zeros_like(value_list[0],
                                          dtype=value_list[0].dtype)
            for i in range(0, len(value_list), _XLA_OP_BY_OP_INPUTS_LIMIT):
                output += math_ops.add_n(
                    value_list[i:i + _XLA_OP_BY_OP_INPUTS_LIMIT])

        if reduce_op == reduce_util.ReduceOp.MEAN:
            output *= (1. / len(value_list))

        devices = cross_device_ops_lib.get_devices_from(destinations)

        if len(devices) == 1:
            # If necessary, copy to requested destination.
            dest_canonical = device_util.canonicalize(devices[0])
            host_canonical = device_util.canonicalize(self._host_device)

            if dest_canonical != host_canonical:
                with ops.device(dest_canonical):
                    output = array_ops.identity(output)
        else:
            output = cross_device_ops_lib.simple_broadcast(
                output, destinations)

        return output
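The chunked summation above can be illustrated in isolation. A standalone sketch (the limit value is illustrative, not the library constant):

import tensorflow as tf

_INPUTS_LIMIT = 10  # illustrative stand-in for _XLA_OP_BY_OP_INPUTS_LIMIT

def chunked_add_n(tensors, limit=_INPUTS_LIMIT):
    # Sum a long list of tensors in fixed-size groups so that no single
    # add_n op receives more inputs than the per-op limit.
    if len(tensors) <= limit:
        return tf.add_n(tensors)
    output = tf.zeros_like(tensors[0])
    for i in range(0, len(tensors), limit):
        output += tf.add_n(tensors[i:i + limit])
    return output
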
Example #3
def cross_replica_average(inputs=None,
                          num_shards=None,
                          distributed_group_size=None):
    """Calculates the average value of inputs tensor across TPU replicas."""
    group_assignment = None
    if num_shards is not None and distributed_group_size != num_shards:
        group_size = distributed_group_size
        group_assignment = []
        for g in range(num_shards // group_size):
            replica_ids = [g * group_size + i for i in range(group_size)]
            group_assignment.append(replica_ids)
    return tpu_ops.cross_replica_sum(inputs, group_assignment) / tf.cast(
        distributed_group_size, inputs.dtype)
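For concreteness, a plain-Python check (not part of the example) of the group assignment this loop builds for num_shards=8 and distributed_group_size=2:

num_shards, group_size = 8, 2
group_assignment = [
    [g * group_size + i for i in range(group_size)]
    for g in range(num_shards // group_size)
]
print(group_assignment)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
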
Example #4
def _cross_replica_average(self, t, num_shards_per_group):
    """Calculates the average value of input tensor across TPU replicas."""
    num_shards = tpu_function.get_tpu_context().number_of_shards
    group_assignment = None
    if num_shards_per_group > 1:
        if num_shards % num_shards_per_group != 0:
            raise ValueError(
                'num_shards: %d mod shards_per_group: %d, should be 0' %
                (num_shards, num_shards_per_group))
        num_groups = num_shards // num_shards_per_group
        group_assignment = [[
            x for x in range(num_shards) if x // num_shards_per_group == y
        ] for y in range(num_groups)]
    return tpu_ops.cross_replica_sum(t, group_assignment) / tf.cast(
        num_shards_per_group, t.dtype)
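Again purely for illustration (not in the source), the assignment this comprehension yields for num_shards=8 and num_shards_per_group=4:

num_shards, num_shards_per_group = 8, 4
num_groups = num_shards // num_shards_per_group
group_assignment = [
    [x for x in range(num_shards) if x // num_shards_per_group == y]
    for y in range(num_groups)
]
print(group_assignment)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
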
Example #5
    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        """Apply gradients to variables.

    Calls tpu_ops.cross_replica_sum() to sum gradient contributions across
    replicas, and then applies the real optimizer.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        compute_gradients().
      global_step: Optional Variable to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation.  Default to the
        name passed to the Optimizer constructor.

    Returns:
      An `Operation` that applies the gradients. If `global_step` was not None,
      that operation also increments `global_step`.

    Raises:
      ValueError: If the grads_and_vars is malformed.
    """
        summed_grads = []
        variables = []
        for (grad, var) in grads_and_vars:
            variables.append(var)
            if grad is None:
                summed_grads.append(grad)
            else:
                with ops.colocate_with(grad):
                    summed_grad = tpu_ops.cross_replica_sum(
                        grad, self._group_assignment)
                    if self._skip_nan_grad:
                        summed_grad = self.convert_nan_or_inf_to_zero(
                            summed_grad)
                    summed_grads.append(summed_grad)
        if self._clip is not None and self._clip > 0:
            tf.logging.info("Clip global gradient with norm %.3f.", self._clip)
            clipped_grads, _ = tf.clip_by_global_norm(summed_grads, self._clip)
        else:
            tf.logging.info("Do not clip global gradient.")
            clipped_grads = summed_grads

        train_op = self._opt.apply_gradients(
            list(zip(clipped_grads, variables)), global_step, name)

        return train_op
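A standalone illustration (not from the example) of the global-norm clipping step used above; tf.clip_by_global_norm rescales the whole gradient list when its combined norm exceeds the threshold:

import tensorflow as tf

grads = [tf.constant([3.0, 4.0]), tf.constant([0.0, 12.0])]  # global norm = 13
clipped, global_norm = tf.clip_by_global_norm(grads, clip_norm=1.0)
# Each tensor is scaled by clip_norm / global_norm = 1/13, so the clipped
# list has global norm 1.
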
Example #6
    def _reduce_to(self, reduce_op, value, destinations):
        if (isinstance(value, values.DistributedValues)
                or tensor_util.is_tensor(value)
            ) and values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
            if reduce_op == reduce_util.ReduceOp.MEAN:
                # TODO(jhseu):  Revisit once we support model-parallelism.
                value *= (1. / self._num_replicas_in_sync)
            elif reduce_op != reduce_util.ReduceOp.SUM:
                raise NotImplementedError(
                    "Currently only support sum & mean in TPUStrategy.")
            return tpu_ops.cross_replica_sum(value)

        if not isinstance(value, values.DistributedValues):
            # This function handles reducing values that are not PerReplica or
            # Mirrored values. For example, the same value could be present on all
            # replicas in which case `value` would be a single value or value could
            # be 0.
            return cross_device_ops_lib.reduce_non_distributed_value(
                reduce_op, value, destinations, self._num_replicas_in_sync)

        # TODO(cjfj): Detect when it is possible to use `cross_replica_sum`.
        # Always performs the reduction on the TPU host.
        with ops.device(self._host_device):
            output = math_ops.add_n(value.values)
            if reduce_op == reduce_util.ReduceOp.MEAN:
                output *= (1. / len(value.values))

        devices = cross_device_ops_lib.get_devices_from(destinations)

        if len(devices) == 1:
            # If necessary, copy to requested destination.
            dest_canonical = device_util.canonicalize(devices[0])
            host_canonical = device_util.canonicalize(self._host_device)

            if dest_canonical != host_canonical:
                with ops.device(dest_canonical):
                    output = array_ops.identity(output)
        else:
            output = cross_device_ops_lib.simple_broadcast(
                output, destinations)

        return output
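For context (an assumption, not shown in the source), _reduce_to() is normally reached through Strategy.reduce(). The sketch below uses the default strategy so it runs without a TPU; under TPUStrategy the same call dispatches to the method above:

import tensorflow as tf

strategy = tf.distribute.get_strategy()  # default no-op strategy for illustration

def step_fn():
    return tf.constant(2.0)

per_replica = strategy.run(step_fn)
total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)
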
Example #7
    def allreduce(self, x, mesh_axes, reduction_fn_string):
        """Grouped allreduce, (summed across the given dimensions).

    Args:
      x: a LaidOutTensor
      mesh_axes: a list of integers
      reduction_fn_string: "SUM"
    Returns:
      a LaidOutTensor
    Raises:
      ValueError: if the reduction is not yet implemented.
    """
        if not mesh_axes:
            return x
        x = x.to_laid_out_tensor()
        if reduction_fn_string == "SUM":
            group_assignment = self._create_group_assignment(mesh_axes)
            group_size = len(group_assignment[0])
            tf_in = x.one_slice
            dtype = tf_in.dtype
            if dtype == tf.float32:
                cast_to_float32 = False
            elif dtype == tf.bfloat16:
                cast_to_float32 = (group_size >
                                   self._allreduce_in_bfloat16_max_group_size)
            else:
                tf.logging.info("Casting %s to float32 for allreduce" %
                                tf_in.dtype)
                cast_to_float32 = True
            if cast_to_float32:
                tf_in = tf.cast(tf_in, tf.float32)
            tf_out = tpu_ops.cross_replica_sum(tf_in, group_assignment)
            if cast_to_float32:
                tf_out = tf.cast(tf_out, dtype)
            return self.LaidOutTensor([tf_out])
        else:
            for axis in mesh_axes:
                x = self.allconcat(x, axis, 0, stack=True)
                x = self.LaidOutTensor(
                    [mtf.reduction_fn(reduction_fn_string)(x.one_slice, 0)])
            return x
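The dtype handling above can be restated in plain Python (illustrative only): float32 passes through, bfloat16 is upcast only when the group is large, and every other dtype is upcast to float32 for the all-reduce.

def needs_float32_cast(dtype_name, group_size, bf16_max_group_size):
    if dtype_name == "float32":
        return False
    if dtype_name == "bfloat16":
        return group_size > bf16_max_group_size
    return True

print(needs_float32_cast("bfloat16", group_size=8, bf16_max_group_size=16))  # False
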
Example #8

    def tpu_train_step(loss):
      """Generate the TPU graph."""
      del loss
      values = self.infeed_queue[0].generate_dequeue_op(tpu_device=0)
      unflattened_inputs = data_nest.pack_sequence_as(self.feature_structure,
                                                      values)
      features = unflattened_inputs["features"]
      core_id = unflattened_inputs["core_id"]
      new_features = {}
      for k in features:
        s = features[k].shape.as_list()
        s = [self.hparams.num_shards, s[0] // self.hparams.num_shards] + s[1:]
        new_features[k] = tf.squeeze(
            tf.gather(
                tf.reshape(tpu_ops.cross_replica_sum(features[k]), s), core_id),
            [0])

      estimator_spec = model_fn(new_features, None, tf.estimator.ModeKeys.TRAIN,
                                params)
      loss, train_op = estimator_spec.loss, estimator_spec.train_op
      with tf.control_dependencies([train_op]):
        return tf.identity(loss)
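Shape bookkeeping for the reshape above, with illustrative numbers (a per-core batch of 128 examples and feature width 5, sharded across 8 cores):

num_shards = 8
s = [128, 5]                                  # features[k].shape.as_list()
s = [num_shards, s[0] // num_shards] + s[1:]  # [8, 16, 5]
# tf.gather(..., core_id) then picks one [1, 16, 5] slice, and tf.squeeze
# drops the leading unit dimension.
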
Example #9

  def _reduce_to(self, reduce_op, value, destinations):
    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
      if reduce_op == reduce_util.ReduceOp.MEAN:
        # TODO(jhseu):  Revisit once we support model-parallelism.
        value *= (1. / self._num_replicas_in_sync)
      elif reduce_op != reduce_util.ReduceOp.SUM:
        raise NotImplementedError(
            "Currently only support sum & mean in TPUStrategy.")
      return tpu_ops.cross_replica_sum(value)

    if not isinstance(value, values.DistributedValues):
      # This function handles reducing values that are not PerReplica or
      # Mirrored values. For example, the same value could be present on all
      # replicas in which case `value` would be a single value or value could
      # be 0.
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, self._device_map, value, destinations)

    devices = cross_device_ops_lib.get_devices_from(destinations)
    if len(devices) != 1:
      raise ValueError("Multiple devices are not supported for TPUStrategy")

    # Always performs the reduction on the TPU host.
    with ops.device(self._host_device):
      output = math_ops.add_n(value.values)
      if reduce_op == reduce_util.ReduceOp.MEAN:
        output *= (1. / len(value.values))

    # If necessary, copy to requested destination.
    dest_canonical = device_util.canonicalize(devices[0])
    host_canonical = device_util.canonicalize(self._host_device)

    if dest_canonical != host_canonical:
      with ops.device(dest_canonical):
        output = array_ops.identity(output)

    return output
Example #10
def set_masked_grads(self, grads, weights):
    if self._use_tpu:
        grads = [tpu_ops.cross_replica_sum(g) for g in grads]
    self._masked_grads = grads
    # Key the map by variable name, since names are hashable.
    self._weight2masked_grads = {w.name: m for w, m in zip(weights, grads)}
Example #11
def tpu_all_sum(tensor):
  # `name` is captured from the enclosing scope in the original source.
  return tpu_ops.cross_replica_sum(tensor, name=name)