def get_replica_id():
  """Returns an id number for the current replica, counting from 0.

  If not operating in a supported replicated context this function will
  return 0.
  """
  tf_replicator = get_tf_replicator()
  if tf_replicator:
    return tf_replicator.current_replica_id
  elif tf.distribute.has_strategy() and tf.distribute.get_replica_context():
    return tf.distribute.get_replica_context().replica_id_in_sync_group

  # The code below is based on TensorTracer._add_replica_id_to_graph().
  num_replicas = get_num_replicas()
  if num_replicas <= 1:
    return 0
  with tf.control_dependencies(None):
    # Uses None as a dependency to run outside of TPU graph rewrites.
    return tpu_ops.tpu_replicated_input(
        list(range(num_replicas)), name="replica_id")
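# Usage sketch (illustrative, not from the original source): under a TF 2.x
# MirroredStrategy, get_replica_id() resolves through the replica-context
# branch above. Assumes get_replica_id is importable from this module and
# two GPUs are visible.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

def step():
  # Inside `strategy.run` there is an active replica context, so the id is
  # `replica_id_in_sync_group` (0 or 1 here).
  return tf.cast(get_replica_id(), tf.int32)

per_replica_ids = strategy.run(step)  # PerReplica with values 0 and 1.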
def pnum_tensor(self):
  if self._pnum_tensor is not None:
    return self._pnum_tensor
  with utils.outside_all_rewrites():
    tf.logging.info("Create pnum_tensor")
    # Each physical core receives its logical processor number (pnum).
    self._pnum_tensor = tpu_ops.tpu_replicated_input(
        self._physical_to_logical, name="pnum_constants")
  return self._pnum_tensor
def pack(self, tensors):
  """Create a tensor on the parallel device from a sequence of tensors.

  Args:
    tensors: A flat list of tensors, one per device in `self.components`.

  Returns:
    A single tensor placed on `self.name`.
  """
  with ops.device(self.name):
    return tpu_ops.tpu_replicated_input(inputs=tensors)
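# Usage sketch (a minimal illustration, assuming the experimental
# ParallelDevice API from tensorflow.python.distribute.parallel_device that
# matches the `pack` method above, with two local CPU devices available):
import tensorflow as tf
from tensorflow.python.distribute.parallel_device import parallel_device

device = parallel_device.ParallelDevice(
    ["/job:localhost/device:CPU:0", "/job:localhost/device:CPU:1"])
x = device.pack([tf.constant(1.0), tf.constant(2.0)])  # One tensor per device.
with tf.device(device.name):
  y = x * 2.0  # Executes once per component device.
print(device.unpack(y))  # => [2.0, 4.0], one tensor per component.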
def local_tpu_replica_id():
  """Returns the index of the current TPU replica."""
  num_tpu_replicas = tpu_function.get_tpu_context().number_of_shards
  if num_tpu_replicas is not None:
    # tf.control_dependencies(None) is needed to make sure this runs on the
    # CPU (not the TPU).
    with tf.control_dependencies(None):
      return tpu_ops.tpu_replicated_input(
          list(range(num_tpu_replicas)), name='local_replica_id')
  else:
    # The non-TPU case.
    return 0
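# Illustrative sketch (untested): local_tpu_replica_id() only yields distinct
# values inside a TPU-replicated computation, where number_of_shards is set.
# Assumes TF1-style graph mode and an available TPU.
import tensorflow.compat.v1 as tf

def computation():
  # Each of the two replicas receives its own index via tpu_replicated_input.
  return tf.cast(local_tpu_replica_id(), tf.int32)

replica_ids = tf.tpu.replicate(computation, inputs=[[]] * 2)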
def _pack_tensor(self, *tensors):
  """Helper to pack plain-old-tensors, not structures or composites."""
  for tensor in tensors:
    if not isinstance(tensor,
                      (ops.Tensor, composite_tensor.CompositeTensor,
                       variables.Variable)):
      raise ValueError(
          ("Every component to pack onto the ParallelDevice must already be "
           "a tensor, got {}. Consider running `tf.constant` or "
           "`tf.convert_to_tensor` first on literal values.").format(tensors))
  with ops.device(self._name):
    return tpu_ops.tpu_replicated_input(inputs=tensors)
def _pack_tensor(self, *tensors):
  """Helper to pack plain-old-tensors, not structures or composites."""
  for tensor in tensors:
    if not isinstance(tensor,
                      (ops.Tensor, composite_tensor.CompositeTensor,
                       variables.Variable)):
      raise ValueError(
          ("Every component to pack onto the ParallelDevice must already be "
           "a tensor, got {}. Consider running `tf.constant` or "
           "`tf.convert_to_tensor` first on literal values.").format(tensors))
  with ops.device(None):
    # Explicitly read variable values. This cannot be done on the parallel
    # device since the tensors are about to be packed.
    tensors = [t.read_value() if isinstance(t, variables.Variable) else t
               for t in tensors]
  with ops.device(self._name):
    return tpu_ops.tpu_replicated_input(inputs=tensors)
def pack(self, tensors):
  """Create a tensor on the parallel device from a sequence of tensors.

  Args:
    tensors: A list of tensors, one per device in `self.components`. Composite
      tensors with the same structure on each device are accepted.

  Returns:
    A single tensor placed on the ParallelDevice.
  """
  self._assert_eager()
  if len(tensors) != len(self.components):
    raise ValueError(
        ("Creating a parallel tensor requires one tensor per component. "
         "Got {} but was expecting {}.").format(
             len(tensors), len(self.components)))
  for tensor in tensors:
    if not isinstance(tensor,
                      (ops.Tensor, composite_tensor.CompositeTensor,
                       variables.Variable)):
      raise ValueError(
          ("Every component must already be a tensor, got {}. Consider "
           "running `tf.constant` or `tf.convert_to_tensor` first on literal "
           "values.").format(tensors))
  first_structure = tensors[0]
  flat_tensors = []
  for tensor in tensors:
    nest.assert_same_structure(
        first_structure, tensor, expand_composites=True)
    flat_tensors.append(nest.flatten(tensor, expand_composites=True))
  parallel_tensors = []
  with ops.device(self._name):
    # Zip the flattened per-device leaves so that each tpu_replicated_input
    # call packs one leaf across all component devices.
    for component_tensors in zip(*flat_tensors):
      parallel_tensors.append(
          tpu_ops.tpu_replicated_input(inputs=component_tensors))
  return nest.pack_sequence_as(
      first_structure, parallel_tensors, expand_composites=True)
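# Sketch of the composite-tensor path (illustrative): components with the same
# nested structure are flattened with expand_composites=True, packed leaf by
# leaf, then reassembled. Reuses a two-component `device` built as in the
# earlier ParallelDevice sketch.
import tensorflow as tf

ragged_0 = tf.ragged.constant([[1, 2], [3]])  # Component for device 0.
ragged_1 = tf.ragged.constant([[4], [5, 6]])  # Component for device 1.
packed = device.pack([ragged_0, ragged_1])
# `packed` is a single RaggedTensor whose flat_values and row_splits are each
# parallel tensors with one component per device.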
def shard_id():
  """Get an integer scalar Tensor indicating the index of the current shard."""
  # Prevent the TPU compiler from rewriting this part of the graph.
  with tf.control_dependencies(None):
    return tpu_ops.tpu_replicated_input(list(range(num_tpu_shards())))
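# Rough usage sketch (num_tpu_shards() is a helper defined elsewhere in the
# original module, assumed to read number_of_shards from the TPU context).
# Inside a sharded computation, each shard sees its own index:
def computation():
  return shard_id() * 2  # Per-shard result: 0, 2, 4, ...

outputs = tf.compat.v1.tpu.shard(computation, num_shards=8)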