def get_replicated_var_handle(self, name, vars_):
  """Returns a variable handle for the replicated TPU variable `name`.

  This is a method used by an experimental replicated variable implementation
  and is not intended as a public API.

  Handles are cached in `self._replicated_vars` keyed by `name`, so the
  TPUReplicatedInput node is only built the first time a name is requested.

  Args:
    name: The common name of the variable.
    vars_: The replicated TPU variables.

  Returns:
    The handle of the TPU replicated input node.
  """
  handle = self._replicated_vars.get(name)
  if handle is not None:
    return handle

  # Builds a TPUReplicatedInput node for the variable, if one does not already
  # exist. The TPUReplicatedInput node must belong to the enclosing
  # control-flow scope of the TPUReplicateContext.
  # TODO(phawkins): consider changing the contract of the TPU encapsulation
  # so the TPUReplicatedInput nodes go inside the TPUReplicateContext scope
  # instead.

  # pylint: disable=protected-access
  graph = ops.get_default_graph()
  saved_context = graph._get_control_flow_context()
  graph._set_control_flow_context(self.outer_context)
  try:
    handle = tpu_ops.tpu_replicated_input(
        [v.handle for v in vars_], name=name + "/handle")
  finally:
    # Always put the caller's control-flow context back, even when building
    # the node fails; otherwise later ops would silently be created in the
    # outer context.
    graph._set_control_flow_context(saved_context)
  # pylint: enable=protected-access
  self._replicated_vars[name] = handle
  return handle
def get_replicated_var_handle(self, var):
  """Returns a variable handle for replicated TPU variable 'var'.

  This is a method used by an experimental replicated variable implementation
  and is not intended as a public API.

  Handles are cached in `self._replicated_vars` keyed by `var`, so the
  TPUReplicatedInput node is only built the first time a variable is
  requested.

  Args:
    var: The replicated TPU variable.

  Returns:
    The handle of the TPU replicated input node.
  """
  handle = self._replicated_vars.get(var)
  if handle is not None:
    return handle

  # Builds a TPUReplicatedInput node for the variable, if one does not already
  # exist. The TPUReplicatedInput node must belong to the enclosing
  # control-flow scope of the TPUReplicateContext.
  # TODO(phawkins): consider changing the contract of the TPU encapsulation
  # so the TPUReplicatedInput nodes go inside the TPUReplicateContext scope
  # instead.

  # pylint: disable=protected-access
  graph = ops.get_default_graph()
  saved_context = graph._get_control_flow_context()
  graph._set_control_flow_context(self.outer_context)
  try:
    handle = tpu_ops.tpu_replicated_input(
        [v.handle for v in var._vars], name=var.name + "/handle")
  finally:
    # Always put the caller's control-flow context back, even when building
    # the node fails; otherwise later ops would silently be created in the
    # outer context.
    graph._set_control_flow_context(saved_context)
  # pylint: enable=protected-access
  self._replicated_vars[var] = handle
  return handle
def get_replica_id():
  """Returns the zero-based id of the replica this code is running in.

  Three sources are consulted in order: the TF-Replicator returned by
  `get_tf_replicator()`, the active `tf.distribute` replica context, and
  finally a TPUReplicatedInput node built over `range(num_replicas)`.

  Returns:
    The replica id reported by the first applicable source, or 0 when
    `get_num_replicas()` reports at most one replica.
  """
  replicator = get_tf_replicator()
  if replicator:
    return replicator.current_replica_id

  in_replica_context = (
      tf.distribute.has_strategy() and tf.distribute.get_replica_context())
  if in_replica_context:
    return tf.distribute.get_replica_context().replica_id_in_sync_group

  # Everything below follows TensorTracer._add_replica_id_to_graph().
  replica_count = get_num_replicas()
  if replica_count <= 1:
    return 0

  replica_ids = list(range(replica_count))
  # A `None` dependency keeps the node outside of TPU graph rewrites.
  with tf.control_dependencies(None):
    return tpu_ops.tpu_replicated_input(replica_ids, name="replica_id")
def pnum_tensor(self):
  """Returns the cached processor-number tensor, building it on first use.

  The tensor is a TPUReplicatedInput over `range(self.size)` created inside
  `mtf_utils.outside_all_rewrites()` and memoized on `self._pnum_tensor`.
  """
  if self._pnum_tensor is None:
    with mtf_utils.outside_all_rewrites():
      tf.logging.info("Create pnum_tensor")
      processor_ids = list(range(self.size))
      self._pnum_tensor = tpu_ops.tpu_replicated_input(
          processor_ids, name="pnum_constants")
  return self._pnum_tensor
def pnum_tensor(self):
  """Returns the cached processor-number tensor, building it on first use.

  The tensor is a TPUReplicatedInput over `self._physical_to_logical` created
  inside `utils.outside_all_rewrites()` and memoized on `self._pnum_tensor`.
  """
  if self._pnum_tensor is None:
    with utils.outside_all_rewrites():
      tf.logging.info("Create pnum_tensor")
      self._pnum_tensor = tpu_ops.tpu_replicated_input(
          self._physical_to_logical, name="pnum_constants")
  return self._pnum_tensor
def get_replica_id():
  """Returns an id number for the current replica, counting from 0.

  Returns:
    A TPUReplicatedInput node over `range(get_num_tpu_shards())`, or `None`
    when `get_num_tpu_shards()` returns a falsy value.
  """
  # Modelled on TensorTracer._add_replica_id_to_graph(). Replicas and shards
  # are assumed to be the same thing here.
  shard_count = get_num_tpu_shards()
  if not shard_count:
    return None

  replica_ids = list(range(shard_count))
  # A `None` dependency keeps the node outside of TPU graph rewrites.
  with tf.control_dependencies(None):
    return tpu_ops.tpu_replicated_input(replica_ids, name="replica_id")
def _add_replica_id_to_graph(self, num_replicas, result_tensor):
  """Adds nodes for computing the replica ID to the graph.

  Args:
    num_replicas: Number of replicas; a falsy value means it is not known.
    result_tensor: Tensor to return, tied to the replica-id node when one is
      created.

  Returns:
    `result_tensor` itself when `num_replicas` is falsy (and
    `self._replica_id` is set to 'unknown'); otherwise an identity of
    `result_tensor` carrying a control dependency on the replica-id node.
  """
  if num_replicas:
    self._num_replicas = num_replicas
    replica_ids = list(range(self._num_replicas))

    # A `None` dependency keeps the node outside of TPU graph rewrites.
    with ops.control_dependencies(None):
      self._replica_id = tpu_ops.tpu_replicated_input(
          replica_ids, name='tt_replica_id')

    # Making the result depend on a consumer of replica_id guarantees the
    # replica_id node ends up in the graph.
    replica_id_consumer = array_ops.identity(self._replica_id).op
    with ops.control_dependencies([replica_id_consumer]):
      return array_ops.identity(result_tensor)

  self._replica_id = 'unknown'
  return result_tensor
def replicate(computation,
              inputs=None,
              infeed_queue=None,
              device_assignment=None,
              name=None):
  """Builds a graph operator that runs a replicated TPU computation.

  Args:
    computation: A Python function that builds the computation to replicate.
    inputs: A list of lists of input tensors or `None` (equivalent to
      `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
      have the same number of inputs.
    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
      of arguments as inputs to computation.
    device_assignment: If not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation with physical cores in
      the TPU topology. Uses a default device assignment if `None`. The
      `DeviceAssignment` may be omitted if each replica of the computation uses
      only one core, and there is either only one replica, or the number of
      replicas is equal to the number of cores in the TPU system.
    name: The name of the operator.

  Returns:
    A list of lists of output tensors, indexed by `[replica_num][output_num]`.
    If `computation` returns no tensors, a list of NoOps indexed by
    `[replica_num]` is returned instead. An empty `inputs` list yields `[]`.

  Raises:
    TypeError: If `inputs` is not a list of lists/tuples, or if `computation`
      cannot be called with the supplied number of inputs.
    ValueError: If replicas do not all have the same number of inputs or
      matching input dtypes, if a return value of `computation` is neither an
      Operation nor convertible to a Tensor, or if `computation` returns an
      Operation before a Tensor.
  """
  if name is None:
    name = "TPUReplicate"
  inputs = [[]] if inputs is None else inputs

  metadata_kwargs = {}
  if device_assignment is not None:
    # Turn the Numpy array into a flattened list so we can pass it as an
    # operator attribute.
    metadata_kwargs = {
        "topology":
            device_assignment.topology.serialized(),
        "device_assignment":
            device_assignment.core_assignment.flatten().tolist(),
        "computation_shape":
            device_assignment.computation_shape.tolist()
    }

  if ((not isinstance(inputs, list)) or
      any(not isinstance(inp, (list, tuple)) for inp in inputs)):
    raise TypeError("tpu.replicate() inputs must be a list of lists/tuples")

  num_replicas = len(inputs)

  # No replicas? Nothing to do.
  if num_replicas == 0:
    return []

  # Converts inputs to Tensors.
  inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in inputs]

  # Verifies that all replicas have matching numbers and types of inputs
  input_types = [x.dtype for x in inputs[0]]
  input_arity = len(input_types)
  for i, replica_inputs in enumerate(inputs):
    if len(replica_inputs) != input_arity:
      raise ValueError("Replicas must have the same number of inputs. "
                       "Replica 0 had {} inputs, replica {} had {} "
                       "inputs.".format(input_arity, i, len(replica_inputs)))

    types = [x.dtype for x in replica_inputs]
    if types != input_types:
      raise ValueError(
          "Replicas must have matching input types. Replica 0 had "
          "input types {}, replica {} had input types {}".format(
              input_types, i, types))

  arg_error = tpu_function.check_function_argument_count(
      computation, input_arity, infeed_queue)
  if arg_error is not None:
    input_names = str([t.name for t in inputs[0]])
    if infeed_queue is None:
      raise TypeError(
          "Supplied computation cannot be called with the specified inputs. "
          "You specified %d inputs: %s, but the computation needs %s" %
          (input_arity, input_names, arg_error))
    else:
      raise TypeError(
          "Supplied computation cannot be called with the specified inputs. "
          "You specified %d inputs: %s and %d additional inputs from infeed,"
          " but the computation needs %s" %
          (input_arity, input_names, infeed_queue.number_of_tuple_elements,
           arg_error))

  graph = ops.get_default_graph()

  with ops.name_scope(name, "replicate"):
    # Fan-in: Builds a TPUReplicatedInput node for each input.
    computation_inputs = [
        tpu_ops.tpu_replicated_input(
            [replica_inputs[i] for replica_inputs in inputs],
            name="input{}".format(i)) for i in range(input_arity)
    ]

    context = TPUReplicateContext(name=graph.unique_name("cluster"))
    try:
      context.Enter()

      metadata = tpu_ops.tpu_replicate_metadata(
          num_replicas=num_replicas, **metadata_kwargs)

      with tpu_function.tpu_shard_context(
          num_replicas), ops.control_dependencies([metadata]):

        # The EncapsulateTPUComputations rewrite needs to identify the
        # replicated arguments inside each computation. Adds identity operators
        # tagged with an attribute _tpu_replicated_input to identify the
        # replicated inputs.
        # pylint: disable=protected-access
        with graph._attr_scope({
            "_tpu_replicated_input": attr_value_pb2.AttrValue(b=True)
        }):
          computation_inputs = [
              array_ops.identity(x, name="replicated_input_{}".format(i))
              for i, x in enumerate(computation_inputs)
          ]
        # pylint: enable=protected-access

        # If there is an infeed queue, adds the dequeued values to the
        # computation's inputs.
        if infeed_queue is not None:
          infeed_queue.set_number_of_shards(num_replicas)
          computation_inputs.extend(infeed_queue.generate_dequeue_op())

        # Only resource variables work inside a TPU computation, so turn on
        # resource variables for the computation.
        # TODO(phawkins): consider removing this code. It will
        # be less confusing to clients if they knowingly choose to use resource
        # variables.
        vscope = variable_scope.get_variable_scope()
        saved_use_resource = vscope.use_resource
        vscope.set_use_resource(True)
        try:
          outputs = computation(*computation_inputs)
        finally:
          # Restore the caller's setting even when `computation` raises, so a
          # failed build does not leave the enclosing variable scope forced
          # to resource variables.
          vscope.set_use_resource(saved_use_resource)

      # If the computation only returned one value, makes it a tuple.
      if not isinstance(outputs, (list, tuple)):
        outputs = (outputs,)

      try:
        with ops.device(core(0)):
          outputs = [
              o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
              for o in outputs
          ]
      except Exception as e:
        raise ValueError(
            "TPU function return values must all either be Operations or "
            "convertible to Tensors. Got '%s'" % str(e))

      # Separates the returned Operations and Tensors.
      output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
      output_tensors = [
          o for o in outputs if not isinstance(o, ops.Operation)
      ]

      if outputs != output_tensors + output_operations:
        raise ValueError(
            "TPU functions must return zero-or more Tensor values followed by "
            "zero or more Operations.")
      output_arity = len(output_tensors)

      # Wraps outputs in Identity ops. Otherwise a replicated input copied
      # straight to an output would bypass the replicate(). This would be bad
      # because the TPUReplicatedInput/TPUReplicatedOutput operator would not
      # be rewritten away, leading to a runtime error.
      # TODO(phawkins): extend the rewrite to elide these nodes instead.
      new_output_tensors = []
      for t in output_tensors:
        with ops.device(t.device if t.device else core(0)):
          new_output_tensors.append(array_ops.identity(t))
      output_tensors = new_output_tensors
    finally:
      context.report_unsupported_operations()
      context.Exit()

    # Fan-out: Builds a TPUReplicatedOutput node for each output.
    outputs = [
        tpu_ops.tpu_replicated_output(
            t, num_replicas, name="output{}".format(i))
        for i, t in enumerate(output_tensors)
    ]

    with ops.control_dependencies(output_operations):
      if output_arity == 0:
        # Returns a list of NoOps dependent on the replication Op, indexed by
        # [replica_num].
        return [
            control_flow_ops.no_op(name="%s_shard_%d" % (name, i))
            for i in range(num_replicas)
        ]
      else:
        # Wraps the outputs in identity operators so the names of any possible
        # `fetch` nodes are preserved by the replication rewrite.
        return [[
            array_ops.identity(
                outputs[out][replica],
                name="output_%d_shard_%d" % (out, replica))
            for out in range(output_arity)
        ] for replica in range(num_replicas)]
def split_compile_and_replicate(computation,
                                inputs=None,
                                infeed_queue=None,
                                device_assignment=None,
                                name=None,
                                use_tpu=True):
  """Builds graph operators that runs compilation and replicated computation.

  This is a lower level interface than replicate that returns a separate compile
  and execute output tensor. In the generated graph the compile op feeds into
  the execute op and no additional compilation is incurred when running the
  compile op before the execute op. The compile op returns additional
  information about the compilation but does not return the compiled program.

  Args:
    computation: A Python function that builds the computation to replicate.
    inputs: A list of lists of input tensors or `None` (equivalent to
      `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
      have the same number of inputs.
    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
      of arguments as inputs to computation.
    device_assignment: If not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation with physical cores in
      the TPU topology. Uses a default device assignment if `None`. The
      `DeviceAssignment` may be omitted if each replica of the computation uses
      only one core, and there is either only one replica, or the number of
      replicas is equal to the number of cores in the TPU system.
    name: (Deprecated) Does nothing.
    use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU
      backends. Currently, only supports a default placement (computation is
      placed on GPU if one is available, and on CPU if not).

  Returns:
    A two-element list: the compile op, followed by a list of lists of output
    tensors indexed by `[replica_num][output_num]` (or, when `computation`
    returns no tensors, a list of NoOps indexed by `[replica_num]`). An empty
    `inputs` list yields `[]`.

  Raises:
    TypeError: If `inputs` is not a list of lists/tuples, or if `computation`
      cannot be called with the supplied number of inputs.
    ValueError: If replicas do not all have the same number of inputs or
      matching input dtypes, if a return value of `computation` is neither an
      Operation nor convertible to a Tensor, or if `computation` returns an
      Operation before a Tensor.
  """
  del name
  inputs = [[]] if inputs is None else inputs

  metadata_kwargs = {}
  if device_assignment is not None:
    # Turn the Numpy array into a flattened list so we can pass it as an
    # operator attribute.
    metadata_kwargs = {
        "topology":
            device_assignment.topology.serialized(),
        "device_assignment":
            device_assignment.core_assignment.flatten().tolist()
    }
    # TODO(phawkins): remove this case after the forward compatibility window
    # expires on 2018-10-5.
    if api_compat.forward_compatible(2018, 10, 5):
      metadata_kwargs["num_cores_per_replica"] = (
          device_assignment.num_cores_per_replica)
    else:
      metadata_kwargs["computation_shape"] = [
          device_assignment.num_cores_per_replica
      ]

  if ((not isinstance(inputs, list)) or
      any(not isinstance(inp, (list, tuple)) for inp in inputs)):
    raise TypeError("tpu.replicate() inputs must be a list of lists/tuples")

  num_replicas = len(inputs)

  # No replicas? Nothing to do.
  if num_replicas == 0:
    return []

  # Converts inputs to Tensors.
  inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in inputs]

  # Verifies that all replicas have matching numbers and types of inputs
  input_types = [x.dtype for x in inputs[0]]
  input_arity = len(input_types)
  for i, replica_inputs in enumerate(inputs):
    if len(replica_inputs) != input_arity:
      raise ValueError("Replicas must have the same number of inputs. "
                       "Replica 0 had {} inputs, replica {} had {} "
                       "inputs.".format(input_arity, i, len(replica_inputs)))

    types = [x.dtype for x in replica_inputs]
    if types != input_types:
      raise ValueError(
          "Replicas must have matching input types. Replica 0 had "
          "input types {}, replica {} had input types {}".format(
              input_types, i, types))

  arg_error = xla.check_function_argument_count(computation, input_arity,
                                                infeed_queue)
  if arg_error is not None:
    input_names = str([t.name for t in inputs[0]])
    if infeed_queue is None:
      raise TypeError(
          "Supplied computation cannot be called with the specified inputs. "
          "You specified %d inputs: %s, but the computation needs %s" %
          (input_arity, input_names, arg_error))
    else:
      raise TypeError(
          "Supplied computation cannot be called with the specified inputs. "
          "You specified %d inputs: %s and %d additional inputs from infeed,"
          " but the computation needs %s" %
          (input_arity, input_names, infeed_queue.number_of_tuple_elements,
           arg_error))

  graph = ops.get_default_graph()

  # Fan-in: Builds a TPUReplicatedInput node for each input.
  computation_inputs = [
      tpu_ops.tpu_replicated_input(
          [replica_inputs[i] for replica_inputs in inputs],
          name="input{}".format(i)) for i in range(input_arity)
  ]

  cluster_name = graph.unique_name("cluster")
  pivot = control_flow_ops.no_op(name=cluster_name + "/pivot")
  context = TPUReplicateContext(
      name=cluster_name, num_replicas=num_replicas, pivot=pivot)
  try:
    context.Enter()

    metadata = tpu_ops.tpu_replicate_metadata(
        num_replicas=num_replicas, use_tpu=use_tpu, **metadata_kwargs)

    with tpu_function.tpu_shard_context(
        num_replicas), ops.control_dependencies([metadata]):

      # Add identity ops so even unused inputs are "consumed" by the
      # computation. This is to avoid orphaned TPUReplicatedInput nodes.
      # TODO(phawkins): consider instead pruning unused TPUReplicatedInput
      # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs.
      computation_inputs = [
          array_ops.identity(x, name="replicated_input_{}".format(i))
          for i, x in enumerate(computation_inputs)
      ]

      # If there is an infeed queue, adds the dequeued values to the
      # computation's inputs.
      if infeed_queue is not None:
        infeed_queue.set_number_of_shards(num_replicas)
        computation_inputs.extend(infeed_queue.generate_dequeue_op())

      # Only resource variables work inside a TPU computation, so turn on
      # resource variables for the computation.
      # TODO(phawkins): consider removing this code. It will
      # be less confusing to clients if they knowingly choose to use resource
      # variables.
      # Partitioned variables is not supported (b/112311320).
      vscope = variable_scope.get_variable_scope()
      saved_use_resource = vscope.use_resource
      saved_custom_getter = vscope.custom_getter

      def custom_getter(getter, name, *args, **kwargs):
        """Variables on TPU have a few restrictions."""
        # `.get` rather than indexing: a getter invoked without an explicit
        # `partitioner` keyword is treated the same as `partitioner=None`.
        partitioner = kwargs.get("partitioner")
        if partitioner is not None:
          kwargs["partitioner"] = None
          logging.warning(
              "Partitioned variables are not supported on TPU. Got "
              "`partitioner` that is {} for variable {}. "
              "Setting `partitioner` to `None`.".format(partitioner, name))
        if saved_custom_getter is None:
          return getter(name, *args, **kwargs)
        else:
          return saved_custom_getter(getter, name, *args, **kwargs)

      vscope.set_use_resource(True)
      vscope.set_custom_getter(custom_getter)
      try:
        outputs = computation(*computation_inputs)
      finally:
        # Restore the caller's scope settings even when `computation` raises,
        # so a failed build does not leave the TPU-only overrides installed.
        vscope.set_use_resource(saved_use_resource)
        vscope.set_custom_getter(saved_custom_getter)

    # If the computation returns `None`, make it an empty tuple.
    if outputs is None:
      outputs = tuple()
    # If the computation only returned one value, makes it a tuple.
    if not isinstance(outputs, (list, tuple)):
      outputs = (outputs,)

    # Append `no_op` here so that fetching any return value of this function
    # will trigger TPUExecute node. A new tuple is built so that a list
    # returned by `computation` is not extended in place.
    outputs = tuple(outputs) + (control_flow_ops.no_op(),)
    try:
      with ops.device(core(0)):
        outputs = [
            o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
            for o in outputs
        ]
    except Exception as e:
      raise ValueError(
          "TPU function return values must all either be Operations or "
          "convertible to Tensors. Got '%s'" % str(e))

    # Separates the returned Operations and Tensors.
    output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
    output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]

    if outputs != output_tensors + output_operations:
      raise ValueError(
          "TPU functions must return zero-or more Tensor values followed by "
          "zero or more Operations.")
    output_arity = len(output_tensors)

    # Wraps outputs in Identity ops. Otherwise a replicated input copied
    # straight to an output would bypass the replicate(). This would be bad
    # because the TPUReplicatedInput/TPUReplicatedOutput operator would not
    # be rewritten away, leading to a runtime error.
    # TODO(phawkins): extend the rewrite to elide these nodes instead.
    new_output_tensors = []
    for t in output_tensors:
      with ops.device(t.device if t.device else core(0)):
        new_output_tensors.append(array_ops.identity(t))
    output_tensors = new_output_tensors
    context.ExitResult(output_tensors)
  finally:
    context.report_unsupported_operations()
    context.Exit()
    host_compute_core = context.HostComputeCore()

  if host_compute_core:
    attr_value = attr_value_pb2.AttrValue()
    attr_value.list.s.extend([compat.as_bytes(x) for x in host_compute_core])
    metadata._set_attr("host_compute_core", attr_value)  # pylint: disable=protected-access

  # Fan-out: Builds a TPUReplicatedOutput node for each output.
  outputs = [
      tpu_ops.tpu_replicated_output(
          t, num_replicas, name="output{}".format(i))
      for i, t in enumerate(output_tensors)
  ]

  with ops.control_dependencies([metadata]):
    if use_tpu:
      compile_status = tpu_ops.tpu_compilation_result()
      op = compile_status.op
      attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name))
      op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value)  # pylint: disable=protected-access
    else:
      compile_status = control_flow_ops.no_op(name="compilation_status")

  with ops.control_dependencies(output_operations):
    if output_arity == 0:
      # Returns a list of NoOps dependent on the replication Op, indexed by
      # [replica_num].
      return [
          compile_status, [
              control_flow_ops.no_op(name="shard_%d" % i)
              for i in range(num_replicas)
          ]
      ]
    else:
      # Wraps the outputs in identity operators so the names of any possible
      # `fetch` nodes are preserved by the replication rewrite.
      return [
          compile_status, [[
              array_ops.identity(
                  outputs[out][replica],
                  name="output_%d_shard_%d" % (out, replica))
              for out in range(output_arity)
          ] for replica in range(num_replicas)]
      ]
def shard_id():
  """Returns an integer scalar Tensor holding the index of the current shard."""
  shard_indices = list(range(num_tpu_shards()))
  # Clearing control dependencies stops the TPU compiler from rewriting this
  # part of the graph.
  with tf.control_dependencies(None):
    return tpu_ops.tpu_replicated_input(shard_indices)
def split_compile_and_replicate(computation,
                                inputs=None,
                                infeed_queue=None,
                                device_assignment=None,
                                name=None,
                                use_tpu=True):
  """Builds graph operators that runs compilation and replicated computation.

  This is a lower level interface than replicate that returns a separate compile
  and execute output tensor. In the generated graph the compile op feeds into
  the execute op and no additional compilation is incurred when running the
  compile op before the execute op. The compile op returns additional
  information about the compilation but does not return the compiled program.

  Args:
    computation: A Python function that builds the computation to replicate.
    inputs: A list of lists of input tensors or `None` (equivalent to
      `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
      have the same number of inputs.
    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
      of arguments as inputs to computation.
    device_assignment: If not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation with physical cores in
      the TPU topology. Uses a default device assignment if `None`. The
      `DeviceAssignment` may be omitted if each replica of the computation uses
      only one core, and there is either only one replica, or the number of
      replicas is equal to the number of cores in the TPU system.
    name: (Deprecated) Does nothing.
    use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU
      backends. Currently, only supports a default placement (computation is
      placed on GPU if one is available, and on CPU if not).

  Returns:
    A two-element list: the compile op, followed by a list of lists of output
    tensors indexed by `[replica_num][output_num]` (or, when `computation`
    returns no tensors, a list of NoOps indexed by `[replica_num]`). An empty
    `inputs` list yields `[]`.

  Raises:
    TypeError: If `inputs` is not a list of lists/tuples, or if `computation`
      cannot be called with the supplied number of inputs.
    ValueError: If replicas do not all have the same number of inputs or
      matching input dtypes, if a return value of `computation` is neither an
      Operation nor convertible to a Tensor, or if `computation` returns an
      Operation before a Tensor.
  """
  del name
  inputs = [[]] if inputs is None else inputs

  metadata_kwargs = {}
  if device_assignment is not None:
    # Turn the Numpy array into a flattened list so we can pass it as an
    # operator attribute.
    metadata_kwargs = {
        "topology":
            device_assignment.topology.serialized(),
        "device_assignment":
            device_assignment.core_assignment.flatten().tolist(),
        "computation_shape":
            device_assignment.computation_shape.tolist()
    }

  if ((not isinstance(inputs, list)) or
      any(not isinstance(inp, (list, tuple)) for inp in inputs)):
    raise TypeError("tpu.replicate() inputs must be a list of lists/tuples")

  num_replicas = len(inputs)

  # No replicas? Nothing to do.
  if num_replicas == 0:
    return []

  # Converts inputs to Tensors.
  inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in inputs]

  # Verifies that all replicas have matching numbers and types of inputs
  input_types = [x.dtype for x in inputs[0]]
  input_arity = len(input_types)
  for i, replica_inputs in enumerate(inputs):
    if len(replica_inputs) != input_arity:
      raise ValueError("Replicas must have the same number of inputs. "
                       "Replica 0 had {} inputs, replica {} had {} "
                       "inputs.".format(input_arity, i, len(replica_inputs)))

    types = [x.dtype for x in replica_inputs]
    if types != input_types:
      raise ValueError(
          "Replicas must have matching input types. Replica 0 had "
          "input types {}, replica {} had input types {}".format(
              input_types, i, types))

  arg_error = tpu_function.check_function_argument_count(
      computation, input_arity, infeed_queue)
  if arg_error is not None:
    input_names = str([t.name for t in inputs[0]])
    if infeed_queue is None:
      raise TypeError(
          "Supplied computation cannot be called with the specified inputs. "
          "You specified %d inputs: %s, but the computation needs %s" %
          (input_arity, input_names, arg_error))
    else:
      raise TypeError(
          "Supplied computation cannot be called with the specified inputs. "
          "You specified %d inputs: %s and %d additional inputs from infeed,"
          " but the computation needs %s" %
          (input_arity, input_names, infeed_queue.number_of_tuple_elements,
           arg_error))

  graph = ops.get_default_graph()

  # Fan-in: Builds a TPUReplicatedInput node for each input.
  computation_inputs = [
      tpu_ops.tpu_replicated_input(
          [replica_inputs[i] for replica_inputs in inputs],
          name="input{}".format(i)) for i in range(input_arity)
  ]

  cluster_name = graph.unique_name("cluster")
  pivot = control_flow_ops.no_op(name=cluster_name + "/pivot")
  context = TPUReplicateContext(
      name=cluster_name, num_replicas=num_replicas, pivot=pivot)
  try:
    context.Enter()

    metadata = tpu_ops.tpu_replicate_metadata(
        num_replicas=num_replicas, use_tpu=use_tpu, **metadata_kwargs)

    with tpu_function.tpu_shard_context(
        num_replicas), ops.control_dependencies([metadata]):

      # Add identity ops so even unused inputs are "consumed" by the
      # computation. This is to avoid orphaned TPUReplicatedInput nodes.
      # TODO(phawkins): consider instead pruning unused TPUReplicatedInput
      # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs.
      computation_inputs = [
          array_ops.identity(x, name="replicated_input_{}".format(i))
          for i, x in enumerate(computation_inputs)
      ]

      # If there is an infeed queue, adds the dequeued values to the
      # computation's inputs.
      if infeed_queue is not None:
        infeed_queue.set_number_of_shards(num_replicas)
        computation_inputs.extend(infeed_queue.generate_dequeue_op())

      # Only resource variables work inside a TPU computation, so turn on
      # resource variables for the computation.
      # TODO(phawkins): consider removing this code. It will
      # be less confusing to clients if they knowingly choose to use resource
      # variables.
      # Partitioned variables is not supported (b/112311320).
      vscope = variable_scope.get_variable_scope()
      saved_use_resource = vscope.use_resource
      saved_custom_getter = vscope.custom_getter

      def custom_getter(getter, name, *args, **kwargs):
        """Variables on TPU have a few restrictions."""
        # `.get` rather than indexing: a getter invoked without an explicit
        # `partitioner` keyword is treated the same as `partitioner=None`.
        partitioner = kwargs.get("partitioner")
        if partitioner is not None:
          kwargs["partitioner"] = None
          logging.warning(
              "Partitioned variables are not supported on TPU. Got "
              "`partitioner` that is {} for variable {}. "
              "Setting `partitioner` to `None`."
              .format(partitioner, name))
        # Chain to whatever custom getter the scope already had, so that
        # installing this one does not silently disable the caller's getter.
        if saved_custom_getter is None:
          return getter(name, *args, **kwargs)
        return saved_custom_getter(getter, name, *args, **kwargs)

      vscope.set_use_resource(True)
      vscope.set_custom_getter(custom_getter)
      try:
        outputs = computation(*computation_inputs)
      finally:
        # Restore the caller's scope settings even when `computation` raises,
        # so a failed build does not leave the TPU-only overrides installed.
        vscope.set_use_resource(saved_use_resource)
        vscope.set_custom_getter(saved_custom_getter)

    # If the computation returns `None`, make it an empty tuple.
    if outputs is None:
      outputs = tuple()
    # If the computation only returned one value, makes it a tuple.
    if not isinstance(outputs, (list, tuple)):
      outputs = (outputs,)

    # Append `no_op` here so that fetching any return value of this function
    # will trigger TPUExecute node. A new tuple is built so that a list
    # returned by `computation` is not extended in place.
    outputs = tuple(outputs) + (control_flow_ops.no_op(),)
    try:
      with ops.device(core(0)):
        outputs = [
            o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
            for o in outputs
        ]
    except Exception as e:
      raise ValueError(
          "TPU function return values must all either be Operations or "
          "convertible to Tensors. Got '%s'" % str(e))

    # Separates the returned Operations and Tensors.
    output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
    output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]

    if outputs != output_tensors + output_operations:
      raise ValueError(
          "TPU functions must return zero-or more Tensor values followed by "
          "zero or more Operations.")
    output_arity = len(output_tensors)

    # Wraps outputs in Identity ops. Otherwise a replicated input copied
    # straight to an output would bypass the replicate(). This would be bad
    # because the TPUReplicatedInput/TPUReplicatedOutput operator would not
    # be rewritten away, leading to a runtime error.
    # TODO(phawkins): extend the rewrite to elide these nodes instead.
    new_output_tensors = []
    for t in output_tensors:
      with ops.device(t.device if t.device else core(0)):
        new_output_tensors.append(array_ops.identity(t))
    output_tensors = new_output_tensors
    context.ExitResult(output_tensors)
  finally:
    context.report_unsupported_operations()
    context.Exit()
    host_compute_core = context.HostComputeCore()

  if host_compute_core:
    attr_value = attr_value_pb2.AttrValue()
    attr_value.list.s.extend([compat.as_bytes(x) for x in host_compute_core])
    metadata._set_attr("host_compute_core", attr_value)  # pylint: disable=protected-access

  # Fan-out: Builds a TPUReplicatedOutput node for each output.
  outputs = [
      tpu_ops.tpu_replicated_output(
          t, num_replicas, name="output{}".format(i))
      for i, t in enumerate(output_tensors)
  ]

  with ops.control_dependencies([metadata]):
    if use_tpu:
      compile_status = tpu_ops.tpu_compilation_result()
      op = compile_status.op
      attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name))
      op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value)  # pylint: disable=protected-access
    else:
      compile_status = control_flow_ops.no_op(name="compilation_status")

  with ops.control_dependencies(output_operations):
    if output_arity == 0:
      # Returns a list of NoOps dependent on the replication Op, indexed by
      # [replica_num].
      return [
          compile_status, [
              control_flow_ops.no_op(name="shard_%d" % i)
              for i in range(num_replicas)
          ]
      ]
    else:
      # Wraps the outputs in identity operators so the names of any possible
      # `fetch` nodes are preserved by the replication rewrite.
      return [
          compile_status, [[
              array_ops.identity(
                  outputs[out][replica],
                  name="output_%d_shard_%d" % (out, replica))
              for out in range(output_arity)
          ] for replica in range(num_replicas)]
      ]
def replicate(computation,
              inputs=None,
              infeed_queue=None,
              global_tpu_id=None,
              name=None):
  """Builds a graph operator that runs a replicated TPU computation.

  Args:
    computation: a Python function that builds the computation to replicate.
    inputs: a list of lists of input tensors or None (equivalent to [[]]),
      indexed by [replica_num][input_num]. All replicas must have the same
      number of inputs.
    infeed_queue: if not None, the InfeedQueue from which to append a tuple
      of arguments as inputs to computation.
    global_tpu_id: if not None, a Numpy 2D array indicating the global id of
      each TPU device in the system. The outer dimension of the array is host
      task id, and the inner dimension is device ordinal, so e.g.,
      global_tpu_id[x][y] indicates the global id of device
      /task:x/device:TPU_NODE:y.
    name: name of the operator. Defaults to "TPUReplicate" when None.

  Returns:
    A list of lists of output tensors, indexed by [replica_num][output_num].
    If `computation` produces no output tensors (it returns only Operations,
    or returns None), a flat list of NoOps indexed by [replica_num] is
    returned instead. If `inputs` is an empty list, returns an empty list.

  Raises:
    TypeError: if `inputs` is not a list of lists/tuples.
    TypeError: if the number of inputs per replica (plus any infeed tuple
      elements) does not match the formal parameters of `computation`.
    ValueError: if all replicas do not have equal numbers of input tensors,
      or their input types differ.
    ValueError: if a return value of `computation` is neither an Operation
      nor convertible to a Tensor, or if Tensors are returned after
      Operations.
  """
  if name is None:
    name = "TPUReplicate"
  inputs = [[]] if inputs is None else inputs

  if global_tpu_id is not None:
    # Turn the Numpy array into a flattened list.
    global_tpu_id = global_tpu_id.flatten().tolist()

  if ((not isinstance(inputs, list)) or
      any(not isinstance(inp, (list, tuple)) for inp in inputs)):
    raise TypeError("tpu.replicate() inputs must be a list of lists/tuples")

  num_replicas = len(inputs)

  # No replicas? Nothing to do.
  if num_replicas == 0:
    return []

  # Converts inputs to Tensors.
  inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in inputs]

  # Verifies that all replicas have matching numbers and types of inputs
  input_types = [x.dtype for x in inputs[0]]
  input_arity = len(input_types)
  for i, replica_inputs in enumerate(inputs):
    if len(replica_inputs) != input_arity:
      raise ValueError("Replicas must have the same number of inputs. "
                       "Replica 0 had {} inputs, replica {} had {} "
                       "inputs.".format(input_arity, i, len(replica_inputs)))

    types = [x.dtype for x in replica_inputs]
    if types != input_types:
      raise ValueError(
          "Replicas must have matching input types. Replica 0 had "
          "input types {}, replica {} had input types {}".format(
              input_types, i, types))

  arg_error = tpu_function.check_function_argument_count(
      computation, input_arity, infeed_queue)
  if arg_error is not None:
    input_names = str([i.name for i in inputs[0]])
    if infeed_queue is None:
      raise TypeError(
          "Supplied computation cannot be called with the specified inputs. "
          "You specified %d inputs: %s, but the computation needs %s" % (
              input_arity, input_names, arg_error))
    else:
      raise TypeError(
          "Supplied computation cannot be called with the specified inputs. "
          "You specified %d inputs: %s and %d additional inputs from infeed,"
          " but the computation needs %s" % (
              input_arity, input_names,
              infeed_queue.number_of_tuple_elements, arg_error))

  graph = ops.get_default_graph()

  with ops.name_scope(name, "replicate"):
    # Fan-in: Builds a TPUReplicatedInput node for each input.
    computation_inputs = [
        tpu_ops.tpu_replicated_input(
            [replica_inputs[i] for replica_inputs in inputs],
            name="input{}".format(i))
        for i in range(input_arity)
    ]

    context = TPUReplicateContext(name=graph.unique_name("cluster"))
    try:
      context.Enter()

      metadata = tpu_ops.tpu_replicate_metadata(
          num_replicas=num_replicas, global_tpu_id=global_tpu_id)

      with tpu_function.tpu_shard_context(
          num_replicas), ops.control_dependencies([metadata]):

        # The EncapsulateTPUComputations rewrite needs to identify the
        # replicated arguments inside each computation. Adds identity
        # operators tagged with an attribute _tpu_replicated_input to
        # identify the replicated inputs.
        # pylint: disable=protected-access
        with graph._attr_scope({"_tpu_replicated_input":
                                attr_value_pb2.AttrValue(b=True)}):
          computation_inputs = [
              array_ops.identity(x, name="replicated_input_{}".format(i))
              for i, x in enumerate(computation_inputs)]
        # pylint: enable=protected-access

        # If there is an infeed queue, adds the dequeued values to the
        # computation's inputs.
        if infeed_queue is not None:
          infeed_queue.set_number_of_shards(num_replicas)
          for t in infeed_queue.generate_dequeue_op():
            computation_inputs.append(t)

        # Only resource variables work inside a TPU computation, so turn on
        # resource variables for the computation.
        # TODO(phawkins): consider removing this code. It will
        # be less confusing to clients if they knowingly choose to use
        # resource variables.
        vscope = variable_scope.get_variable_scope()
        saved_use_resource = vscope.use_resource
        vscope.set_use_resource(True)
        try:
          outputs = computation(*computation_inputs)
        finally:
          # Restore the caller's setting even when `computation` raises, so a
          # failed build does not leave the enclosing scope modified.
          vscope.set_use_resource(saved_use_resource)

      # If the computation returns `None`, make it an empty tuple.
      if outputs is None:
        outputs = tuple()
      # If the computation only returned one value, makes it a tuple.
      if not isinstance(outputs, (list, tuple)):
        outputs = (outputs,)

      try:
        with ops.device(core(0)):
          outputs = [
              o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
              for o in outputs
          ]
      except Exception as e:
        raise ValueError(
            "TPU function return values must all either be Operations or "
            "convertible to Tensors. Got '%s'" % str(e))

      # Separates the returned Operations and Tensors.
      output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
      output_tensors = [o for o in outputs
                        if not isinstance(o, ops.Operation)]

      if outputs != output_tensors + output_operations:
        raise ValueError(
            "TPU functions must return zero-or more Tensor values followed by "
            "zero or more Operations.")
      output_arity = len(output_tensors)

      # Wraps outputs in Identity ops. Otherwise a replicated input copied
      # straight to an output would bypass the replicate(). This would be bad
      # because the TPUReplicatedInput/TPUReplicatedOutput operator would not
      # be rewritten away, leading to a runtime error.
      # TODO(phawkins): extend the rewrite to elide these nodes instead.
      new_output_tensors = []
      for t in output_tensors:
        with ops.device(t.device if t.device else core(0)):
          new_output_tensors.append(array_ops.identity(t))
      output_tensors = new_output_tensors
    finally:
      context.Exit()

    # Fan-out: Builds a TPUReplicatedOutput node for each output.
    outputs = [tpu_ops.tpu_replicated_output(output_tensors[i], num_replicas,
                                             name="output{}".format(i))
               for i in range(output_arity)]

    with ops.control_dependencies(output_operations):
      if output_arity == 0:
        # Returns a list of NoOps dependent on the replication Op, indexed by
        # [replica_num].
        return [
            control_flow_ops.no_op(name="%s_shard_%d" % (name, i))
            for i in range(num_replicas)
        ]
      else:
        # Wraps the outputs in identity operators so the names of any possible
        # `fetch` nodes are preserved by the replication rewrite.
        return [
            [array_ops.identity(outputs[out][replica],
                                name="output_%d_shard_%d" % (out, replica))
             for out in range(output_arity)]
            for replica in range(num_replicas)
        ]
def get_tpu_replica_id():
  """Builds a replicated input over the integer ids 0..num_tpu_replicas()-1.

  The node is created with control dependencies cleared
  (`tf.control_dependencies(None)`).

  Returns:
    The result of `tpu_ops.tpu_replicated_input` applied to the list of
    replica indices.
  """
  with tf.control_dependencies(None):
    replica_ids = list(range(num_tpu_replicas()))
    replica_id = tpu_ops.tpu_replicated_input(replica_ids)
  return replica_id
def replicate(computation,
              inputs=None,
              infeed_queue=None,
              device_assignment=None,
              name=None):
  """Builds a graph operator that runs a replicated TPU computation.

  Args:
    computation: A Python function that builds the computation to replicate.
    inputs: A list of lists of input tensors or `None` (equivalent to
      `[[]]`), indexed by `[replica_num][input_num]`. All replicas must
      have the same number of inputs.
    infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
      of arguments as inputs to computation.
    device_assignment: If not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation with physical cores in
      the TPU topology. Uses a default device assignment if `None`. The
      `DeviceAssignment` may be omitted if each replica of the computation uses
      only one core, and there is either only one replica, or the number of
      replicas is equal to the number of cores in the TPU system.
    name: (Deprecated) Does nothing.

  Returns:
    A list of lists of output tensors, indexed by `[replica_num][output_num]`.
    If `computation` produces no output tensors (it returns only Operations,
    or returns `None`), a flat list of NoOps indexed by `[replica_num]` is
    returned instead. If `inputs` is an empty list, returns an empty list.

  Raises:
    TypeError: If `inputs` is not a list of lists/tuples.
    TypeError: If the number of inputs per replica (plus any infeed tuple
      elements) does not match the formal parameters of `computation`.
    ValueError: If all replicas do not have equal numbers of input tensors,
      or their input types differ.
    ValueError: If a return value of `computation` is neither an Operation
      nor convertible to a Tensor, or if Tensors are returned after
      Operations.
  """
  del name
  inputs = [[]] if inputs is None else inputs

  metadata_kwargs = {}
  if device_assignment is not None:
    # Turn the Numpy array into a flattened list so we can pass it as an
    # operator attribute.
    metadata_kwargs = {
        "topology":
            device_assignment.topology.serialized(),
        "device_assignment":
            device_assignment.core_assignment.flatten().tolist(),
        "computation_shape":
            device_assignment.computation_shape.tolist()
    }

  if ((not isinstance(inputs, list)) or
      any(not isinstance(inp, (list, tuple)) for inp in inputs)):
    raise TypeError("tpu.replicate() inputs must be a list of lists/tuples")

  num_replicas = len(inputs)

  # No replicas? Nothing to do.
  if num_replicas == 0:
    return []

  # Converts inputs to Tensors.
  inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in inputs]

  # Verifies that all replicas have matching numbers and types of inputs
  input_types = [x.dtype for x in inputs[0]]
  input_arity = len(input_types)
  for i, replica_inputs in enumerate(inputs):
    if len(replica_inputs) != input_arity:
      raise ValueError("Replicas must have the same number of inputs. "
                       "Replica 0 had {} inputs, replica {} had {} "
                       "inputs.".format(input_arity, i, len(replica_inputs)))

    types = [x.dtype for x in replica_inputs]
    if types != input_types:
      raise ValueError(
          "Replicas must have matching input types. Replica 0 had "
          "input types {}, replica {} had input types {}".format(
              input_types, i, types))

  arg_error = tpu_function.check_function_argument_count(
      computation, input_arity, infeed_queue)
  if arg_error is not None:
    input_names = str([i.name for i in inputs[0]])
    if infeed_queue is None:
      raise TypeError(
          "Supplied computation cannot be called with the specified inputs. "
          "You specified %d inputs: %s, but the computation needs %s" % (
              input_arity, input_names, arg_error))
    else:
      raise TypeError(
          "Supplied computation cannot be called with the specified inputs. "
          "You specified %d inputs: %s and %d additional inputs from infeed,"
          " but the computation needs %s" % (
              input_arity, input_names,
              infeed_queue.number_of_tuple_elements, arg_error))

  graph = ops.get_default_graph()

  # Fan-in: Builds a TPUReplicatedInput node for each input.
  computation_inputs = [
      tpu_ops.tpu_replicated_input(
          [replica_inputs[i] for replica_inputs in inputs],
          name="input{}".format(i))
      for i in range(input_arity)
  ]

  context = TPUReplicateContext(
      name=graph.unique_name("cluster"), num_replicas=num_replicas)
  try:
    context.Enter()

    metadata = tpu_ops.tpu_replicate_metadata(
        num_replicas=num_replicas, **metadata_kwargs)

    with tpu_function.tpu_shard_context(
        num_replicas), ops.control_dependencies([metadata]):

      # The EncapsulateTPUComputations rewrite needs to identify the
      # replicated arguments inside each computation. Adds identity operators
      # tagged with an attribute _tpu_replicated_input to identify the
      # replicated inputs.
      # pylint: disable=protected-access
      with graph._attr_scope({"_tpu_replicated_input":
                              attr_value_pb2.AttrValue(b=True)}):
        computation_inputs = [
            array_ops.identity(x, name="replicated_input_{}".format(i))
            for i, x in enumerate(computation_inputs)]
      # pylint: enable=protected-access

      # If there is an infeed queue, adds the dequeued values to the
      # computation's inputs.
      if infeed_queue is not None:
        infeed_queue.set_number_of_shards(num_replicas)
        for t in infeed_queue.generate_dequeue_op():
          computation_inputs.append(t)

      # Only resource variables work inside a TPU computation, so turn on
      # resource variables for the computation.
      # TODO(phawkins): consider removing this code. It will
      # be less confusing to clients if they knowingly choose to use resource
      # variables.
      vscope = variable_scope.get_variable_scope()
      saved_use_resource = vscope.use_resource
      vscope.set_use_resource(True)
      try:
        outputs = computation(*computation_inputs)
      finally:
        # Restore the caller's setting even when `computation` raises, so a
        # failed build does not leave the enclosing scope modified.
        vscope.set_use_resource(saved_use_resource)

    # If the computation returns `None`, make it an empty tuple.
    if outputs is None:
      outputs = tuple()
    # If the computation only returned one value, makes it a tuple.
    if not isinstance(outputs, (list, tuple)):
      outputs = (outputs,)

    try:
      with ops.device(core(0)):
        outputs = [
            o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
            for o in outputs
        ]
    except Exception as e:
      raise ValueError(
          "TPU function return values must all either be Operations or "
          "convertible to Tensors. Got '%s'" % str(e))

    # Separates the returned Operations and Tensors.
    output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
    output_tensors = [o for o in outputs
                      if not isinstance(o, ops.Operation)]

    if outputs != output_tensors + output_operations:
      raise ValueError(
          "TPU functions must return zero-or more Tensor values followed by "
          "zero or more Operations.")
    output_arity = len(output_tensors)

    # Wraps outputs in Identity ops. Otherwise a replicated input copied
    # straight to an output would bypass the replicate(). This would be bad
    # because the TPUReplicatedInput/TPUReplicatedOutput operator would not
    # be rewritten away, leading to a runtime error.
    # TODO(phawkins): extend the rewrite to elide these nodes instead.
    new_output_tensors = []
    for t in output_tensors:
      with ops.device(t.device if t.device else core(0)):
        new_output_tensors.append(array_ops.identity(t))
    output_tensors = new_output_tensors
  finally:
    context.report_unsupported_operations()
    context.Exit()
    host_compute_core = context.HostComputeCore()

  if host_compute_core:
    attr_value = attr_value_pb2.AttrValue()
    attr_value.list.s.extend([compat.as_bytes(x) for x in host_compute_core])
    metadata._set_attr("host_compute_core", attr_value)  # pylint: disable=protected-access

  # Fan-out: Builds a TPUReplicatedOutput node for each output.
  outputs = [tpu_ops.tpu_replicated_output(output_tensors[i], num_replicas,
                                           name="output{}".format(i))
             for i in range(output_arity)]

  with ops.control_dependencies(output_operations):
    if output_arity == 0:
      # Returns a list of NoOps dependent on the replication Op, indexed by
      # [replica_num].
      return [
          control_flow_ops.no_op(name="shard_%d" % i)
          for i in range(num_replicas)
      ]
    else:
      # Wraps the outputs in identity operators so the names of any possible
      # `fetch` nodes are preserved by the replication rewrite.
      return [
          [array_ops.identity(outputs[out][replica],
                              name="output_%d_shard_%d" % (out, replica))
           for out in range(output_arity)]
          for replica in range(num_replicas)
      ]