def generate_dequeue_op(self, tpu_device=0): """Generate TPU dequeue ops. Args: tpu_device: The TPU device ordinal where the infeed instruction should be placed. Returns: A list of Outputs corresponding to a partition of infeed dequeued into XLA, suitable for use within a replicated block. Raises: ValueError: if the types or shapes of the tuple elements have not been set; or if a dequeue op has already been generated. """ self.freeze() if self._generated_dequeue_op: raise ValueError("Can't generate two dequeue Ops from the same queue") self._generated_dequeue_op = True full_name = "%s/dequeue" % self._name sharded_shapes = [ policy.get_sharded_shape(shape) for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies) ] with ops.device(tpu.core(tpu_device)): values = tpu_ops.infeed_dequeue_tuple( dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) return tag_sharding_attribute_for_dequeued_tensors( values, self._input_partition_dims)
def _OutfeedEnqueue(self, per_example_tensors):
  if not per_example_tensors:
    return tf.no_op()
  per_example_tensors = py_utils.NestedMap(per_example_tensors)
  device = tpu.core(0) if self.spmd else ''
  with tf.device(device):
    return tpu_ops.outfeed_enqueue_tuple(per_example_tensors.Flatten())
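The flattened tuple enqueued here is read back on the host with the matching outfeed dequeue op. Below is a minimal sketch of that host-side counterpart, assuming the same NestedMap of per-example tensors is available to describe the dtypes and shapes; the helper name and its arguments are illustrative, not taken from the snippet above.

from tensorflow.python.tpu.ops import tpu_ops  # pylint: disable=g-direct-tensorflow-import


def _OutfeedDequeue(per_example_tensors, device_ordinal=0):
  """Illustrative host-side counterpart: rebuilds the NestedMap from outfeed."""
  flat = per_example_tensors.Flatten()
  dequeued = tpu_ops.outfeed_dequeue_tuple(
      dtypes=[t.dtype for t in flat],
      shapes=[t.shape for t in flat],
      device_ordinal=device_ordinal)
  # Pack the flat list back into the original NestedMap structure.
  return per_example_tensors.Pack(dequeued)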
def devices(self):
  distribute_lib.require_replica_context(self)
  ds = self._strategy
  replica_id = tensor_util.constant_value(self._replica_id_in_sync_group)
  if replica_id is None:  # Non-constant `Tensor` inside `tpu.replicate`.
    # TODO(cjfj): Return other devices when model parallelism is supported.
    return (tpu.core(0),)
  else:
    return (ds.extended.worker_devices[replica_id],)
def _DecodeStep():
  """Decode call to be compiled for TPU."""
  input_batch = self._task.input.TpuDequeueBatch()
  metrics_dict = self._task.Decode(input_batch)
  self.metrics_nm = py_utils.NestedMap(metrics_dict)
  device = tpu.core(0) if self.spmd else ''
  with tf.device(device):
    outfeed_enqueue = tpu_ops.outfeed_enqueue_tuple(
        self.metrics_nm.Flatten())
    return [outfeed_enqueue]
def _DecodeStep():
  """Decode call to be compiled for TPU."""
  with py_utils.OpportunisticVariableReuseScope(True):
    self._model.InstantiateVariables()
    input_batch = self._task.input.TpuDequeueBatch()
    metrics_dict = self._task.Decode(input_batch)
  self.metrics_nm = py_utils.NestedMap(metrics_dict)
  device = tpu.core(0) if self.spmd else ''
  with tf.device(device):
    outfeed_enqueue = tpu_ops.outfeed_enqueue_tuple(
        self.metrics_nm.Flatten())
    return [outfeed_enqueue]
def generate_dequeue_op(self, tpu_device=0): """Generates the device-side Op to dequeue a tuple from the queue. Implicitly freezes the queue configuration if it is not already frozen, which will raise errors if the shapes and types have not been fully specified. Args: tpu_device: The TPU device ordinal where the infeed instruction should be placed. If None, no explicit placement will be performed, and it is up to the user to call this API from within a proper TPU device scope. The XLA code will fail if the TPU dequeue instruction is not bound to any device. Returns: A list of Outputs corresponding to a shard of infeed dequeued into XLA, suitable for use within a replicated block. Raises: ValueError: if the types or shapes of the tuple elements have not been set; or if a dequeue op has already been generated. """ self.freeze() if self._generated_dequeue_op: raise ValueError("Can't generate two dequeue Ops from the same queue") self._generated_dequeue_op = True full_name = "%s/dequeue" % self._name sharded_shapes = [ policy.get_unpartitioned_shape(policy.get_sharded_shape(shape)) for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies) ] if tpu_device is not None: with ops.device(tpu.core(tpu_device)): dequeue_op = tpu_ops.infeed_dequeue_tuple( dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) else: dequeue_op = tpu_ops.infeed_dequeue_tuple( dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) if self._number_of_partitions <= 1: return dequeue_op partitions = [ policy.get_unpartitioned_shape([1] * shape.ndims).as_list() for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies) ] return tag_sharding_attribute_for_dequeued_tensors(dequeue_op, partitions)
def experimental_logical_device(self, logical_device_id):
  """Places variables and ops on the specified logical device."""
  num_logical_devices_per_replica = self._tpu_devices.shape[1]
  if logical_device_id >= num_logical_devices_per_replica:
    raise ValueError(
        "`logical_device_id` not in range (was {}, but there are only {} "
        "logical devices per replica).".format(
            logical_device_id, num_logical_devices_per_replica))
  self._logical_device_stack.append(logical_device_id)
  try:
    if values._enclosing_tpu_context() is None:  # pylint: disable=protected-access
      yield
    else:
      with ops.device(tpu.core(logical_device_id)):
        yield
  finally:
    self._logical_device_stack.pop()
def generate_dequeue_op(self, tpu_device=0): """Generates the device-side Op to dequeue a tuple from the queue. Implicitly freezes the queue configuration if it is not already frozen, which will raise errors if the shapes and types have not been fully specified. Args: tpu_device: The TPU device ordinal where the infeed instruction should be placed. If None, no explicit placement will be performed, and it is up to the user to call this API from within a proper TPU device scope. The XLA code will fail if the TPU dequeue instruction is not bound to any device. Returns: A list of Outputs corresponding to a shard of infeed dequeued into XLA, suitable for use within a replicated block. Raises: ValueError: if the types or shapes of the tuple elements have not been set; or if a dequeue op has already been generated. """ self.freeze() if self._generated_dequeue_op: raise ValueError("Can't generate two dequeue Ops from the same queue") self._generated_dequeue_op = True full_name = "%s/dequeue" % self._name sharded_shapes = [ policy.get_sharded_shape(shape) for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies) ] if tpu_device is not None: with ops.device(tpu.core(tpu_device)): return tpu_ops.infeed_dequeue_tuple( dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) else: return tpu_ops.infeed_dequeue_tuple( dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name)