Example 1
def get_replica_id():
  """Returns an id number for the current replica, counting from 0.

  If not operating in a supported replicated context this function will return
  0.
  """

  tf_replicator = get_tf_replicator()

  if tf_replicator:
    return tf_replicator.current_replica_id
  elif tf.distribute.has_strategy() and tf.distribute.get_replica_context():
    return tf.distribute.get_replica_context().replica_id_in_sync_group

  # The code below this point is based on
  # TensorTracer._add_replica_id_to_graph().
  num_replicas = get_num_replicas()

  if num_replicas <= 1:
    return 0

  with tf.control_dependencies(None):
    # Uses None as dependency to run outside of TPU graph rewrites.
    return tpu_ops.tpu_replicated_input(list(range(num_replicas)),
                                        name="replica_id")
Example 2
def pnum_tensor(self):
    if self._pnum_tensor is not None:
        return self._pnum_tensor
    with utils.outside_all_rewrites():
        tf.logging.info("Create pnum_tensor")
        self._pnum_tensor = tpu_ops.tpu_replicated_input(
            self._physical_to_logical, name="pnum_constants")
        return self._pnum_tensor
Example 3
def pack(self, tensors):
    """Create a tensor on the parallel device from a sequence of tensors.

    Args:
      tensors: A flat list of tensors, one per device in `self.components`.

    Returns:
      A single tensor placed on `self.name`.
    """
    with ops.device(self.name):
        return tpu_ops.tpu_replicated_input(inputs=tensors)
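This `pack` comes from TensorFlow's internal `ParallelDevice`. A hedged usage sketch, assuming the internal module path `tensorflow.python.distribute.parallel_device` and two virtual CPUs so it runs anywhere (this setup is an assumption, not part of the snippet above):

import tensorflow as tf
from tensorflow.python.distribute.parallel_device import parallel_device

# Split the physical CPU into two logical devices (illustrative setup).
cpus = tf.config.list_physical_devices("CPU")
tf.config.set_logical_device_configuration(
    cpus[0], [tf.config.LogicalDeviceConfiguration()] * 2)

pd = parallel_device.ParallelDevice(components=["/cpu:0", "/cpu:1"])
packed = pd.pack([tf.constant(1.0), tf.constant(2.0)])
with tf.device(pd.name):
  doubled = packed * 2.0   # runs once per component device
print(pd.unpack(doubled))  # [2.0, 4.0]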
Example 4
def local_tpu_replica_id():
  """Returns the index of the current TPU replica."""
  num_tpu_replicas = tpu_function.get_tpu_context().number_of_shards
  if num_tpu_replicas is not None:
    # Need tf.control_dependencies(None) in order to make sure this is run
    # on CPU (not TPU)
    with tf.control_dependencies(None):
      return tpu_ops.tpu_replicated_input(
          list(range(num_tpu_replicas)), name='local_replica_id')
  else:
    # The non-TPU case.
    return 0
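The branch on `number_of_shards` can be seen with the internal `tpu_function` helpers; `tpu_shard_context` is normally entered for you by `tpu.replicate`, so the sketch below is only illustrative:

from tensorflow.python.tpu import tpu_function

# Outside any TPU sharding context the shard count is unknown (None),
# so local_tpu_replica_id() takes the non-TPU branch and returns 0.
assert tpu_function.get_tpu_context().number_of_shards is None

# Inside a shard context the count is set, so the tpu_replicated_input
# branch is taken instead.
with tpu_function.tpu_shard_context(8):
  assert tpu_function.get_tpu_context().number_of_shards == 8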
Example 5
def _pack_tensor(self, *tensors):
    """Helper to pack plain-old-tensors, not structures or composites."""
    for tensor in tensors:
        if not isinstance(tensor,
                          (ops.Tensor, composite_tensor.CompositeTensor,
                           variables.Variable)):
            raise ValueError((
                "Every component to pack onto the ParallelDevice must already be "
                "a tensor, got {}. Consider running `tf.constant` or "
                "`tf.convert_to_tensor` first on literal values."
            ).format(tensors))
    with ops.device(self._name):
        return tpu_ops.tpu_replicated_input(inputs=tensors)
Example 6
def _pack_tensor(self, *tensors):
  """Helper to pack plain-old-tensors, not structures or composites."""
  for tensor in tensors:
    if not isinstance(tensor, (ops.Tensor, composite_tensor.CompositeTensor,
                               variables.Variable)):
      raise ValueError(
          ("Every component to pack onto the ParallelDevice must already be "
           "a tensor, got {}. Consider running `tf.constant` or "
           "`tf.convert_to_tensor` first on literal values.")
          .format(tensors))
  with ops.device(None):
    # Explicitly read variable values. This can not be done on the parallel
    # device since the tensors are to be packed.
    tensors = [t.read_value() if isinstance(t, variables.Variable)
               else t for t in tensors]
  with ops.device(self._name):
    return tpu_ops.tpu_replicated_input(inputs=tensors)
Example 7
def pack(self, tensors):
    """Create a tensor on the parallel device from a sequence of tensors.

    Args:
      tensors: A list of tensors, one per device in `self.components`.
        Composite tensors with the same structure on each device are accepted.

    Returns:
      A single tensor placed on the ParallelDevice.
    """
    self._assert_eager()
    if len(tensors) != len(self.components):
        raise ValueError((
            "Creating a parallel tensor requires one tensor per component. "
            "Got {} but was expecting {}.").format(len(tensors),
                                                   len(self.components)))
    for tensor in tensors:
        if not isinstance(tensor,
                          (ops.Tensor, composite_tensor.CompositeTensor,
                           variables.Variable)):
            raise ValueError((
                "Every component must already be a tensor, got {}. Consider "
                "running `tf.constant` or `tf.convert_to_tensor` first on "
                "literal values.").format(tensors))
    first_structure = tensors[0]
    flat_tensors = []
    for tensor in tensors:
        nest.assert_same_structure(first_structure,
                                   tensor,
                                   expand_composites=True)
        flat_tensors.append(nest.flatten(tensor, expand_composites=True))
    parallel_tensors = []
    with ops.device(self._name):
        for tensors in zip(*flat_tensors):
            parallel_tensors.append(
                tpu_ops.tpu_replicated_input(inputs=tensors))
    return nest.pack_sequence_as(first_structure,
                                 parallel_tensors,
                                 expand_composites=True)
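Compared with Example 3, this version also accepts composite tensors: each component is flattened with `nest` (expanding composites), the leaves are replicated one by one with `tpu_replicated_input`, and the original structure is restored at the end. Continuing the illustrative `pd` from the sketch after Example 3:

# Both components carry the same RaggedTensor structure, so the structure
# check passes and every flattened leaf gets its own replicated input.
packed_ragged = pd.pack([tf.ragged.constant([[1, 2], [3]]),
                         tf.ragged.constant([[4], [5, 6]])])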
Example 8
def shard_id():
    """Get an integer scalar Tensor indicating the index of the current shard."""
    # Prevent the TPU compiler from rewriting this part of the graph.
    with tf.control_dependencies(None):
        return tpu_ops.tpu_replicated_input(list(range(num_tpu_shards())))