示例#1
0
  def generate_dequeue_op(self, tpu_device=0):
    """Build the device-side infeed dequeue op for this queue.

    The queue is frozen first. Every tuple shape is then converted to its
    sharded form by the matching sharding policy, the dequeue op is created
    on the requested TPU core, and the dequeued values are tagged with this
    queue's input partition dimensions.

    Args:
      tpu_device: The TPU device ordinal where the infeed instruction should be
        placed.

    Returns:
      The result of `tag_sharding_attribute_for_dequeued_tensors` applied to
      the dequeued tuple and `self._input_partition_dims`; a list of Outputs
      suitable for use within a replicated block.

    Raises:
      ValueError: if the types or shapes of the tuple elements have not been
        set; or if a dequeue op has already been generated.
    """
    self.freeze()
    # Only a single dequeue op may ever be produced per queue.
    if self._generated_dequeue_op:
      raise ValueError("Can't generate two dequeue Ops from the same queue")
    self._generated_dequeue_op = True

    op_name = "%s/dequeue" % self._name
    per_shard_shapes = []
    for tuple_shape, sharding_policy in zip(self._tuple_shapes,
                                            self._sharding_policies):
      per_shard_shapes.append(sharding_policy.get_sharded_shape(tuple_shape))

    with ops.device(tpu.core(tpu_device)):
      dequeued = tpu_ops.infeed_dequeue_tuple(
          dtypes=self._tuple_types, shapes=per_shard_shapes, name=op_name)
    return tag_sharding_attribute_for_dequeued_tensors(
        dequeued, self._input_partition_dims)
示例#2
0
 def _OutfeedEnqueue(self, per_example_tensors):
     """Create the op that enqueues per-example tensors on the outfeed.

     Args:
       per_example_tensors: The tensors to enqueue. The value is wrapped in
         a `py_utils.NestedMap` and flattened before being enqueued.

     Returns:
       `tf.no_op()` when `per_example_tensors` is falsy; otherwise the op
       returned by `tpu_ops.outfeed_enqueue_tuple` for the flattened tensors.
     """
     if not per_example_tensors:
         return tf.no_op()

     tensor_map = py_utils.NestedMap(per_example_tensors)
     # With SPMD the enqueue is placed on TPU core 0; otherwise an empty
     # device string is passed to tf.device.
     if self.spmd:
         placement = tpu.core(0)
     else:
         placement = ''
     with tf.device(placement):
         flat_tensors = tensor_map.Flatten()
         return tpu_ops.outfeed_enqueue_tuple(flat_tensors)
示例#3
0
  def generate_dequeue_op(self, tpu_device=0):
    """Generate the TPU infeed dequeue op for this queue.

    Freezes the queue, asks each sharding policy for the sharded form of its
    tuple shape, creates the dequeue op on the given TPU core and tags the
    dequeued values with the queue's input partition dimensions.

    Args:
      tpu_device: The TPU device ordinal where the infeed instruction should be
        placed.

    Returns:
      A list of Outputs corresponding to a partition of infeed dequeued
      into XLA, suitable for use within a replicated block. The value is
      whatever `tag_sharding_attribute_for_dequeued_tensors` returns.

    Raises:
      ValueError: if the types or shapes of the tuple elements have not been
        set; or if a dequeue op has already been generated.
    """
    self.freeze()
    already_generated = self._generated_dequeue_op
    if already_generated:
      raise ValueError("Can't generate two dequeue Ops from the same queue")
    # Mark the queue before building the op so a second call is rejected.
    self._generated_dequeue_op = True

    dequeue_name = "%s/dequeue" % self._name
    shape_policy_pairs = zip(self._tuple_shapes, self._sharding_policies)
    shard_shapes = [
        sharding_policy.get_sharded_shape(tuple_shape)
        for tuple_shape, sharding_policy in shape_policy_pairs
    ]

    target_core = tpu.core(tpu_device)
    with ops.device(target_core):
      dequeued_values = tpu_ops.infeed_dequeue_tuple(
          dtypes=self._tuple_types, shapes=shard_shapes, name=dequeue_name)
    return tag_sharding_attribute_for_dequeued_tensors(
        dequeued_values, self._input_partition_dims)
示例#4
0
    def devices(self):
        """Return a one-element tuple holding the current replica's device.

        Requires a replica context. If the replica id has a constant value,
        the matching entry of the strategy's `extended.worker_devices` is
        returned; otherwise `tpu.core(0)` is returned.
        """
        distribute_lib.require_replica_context(self)
        strategy = self._strategy
        static_replica_id = tensor_util.constant_value(
            self._replica_id_in_sync_group)

        if static_replica_id is not None:
            return (strategy.extended.worker_devices[static_replica_id], )

        # Non-constant `Tensor` inside `tpu.replicate`.
        # TODO(cjfj): Return other devices when model parallelism is supported.
        return (tpu.core(0), )
示例#5
0
 def _DecodeStep():
     """Decode call to be compiled for TPU.

     Dequeues one input batch, runs the task's `Decode` on it, stores the
     resulting metrics on `self.metrics_nm` as a `py_utils.NestedMap`, and
     enqueues the flattened metrics on the outfeed.

     Returns:
       A single-element list containing the outfeed enqueue op.
     """
     batch = self._task.input.TpuDequeueBatch()
     decoded_metrics = self._task.Decode(batch)
     self.metrics_nm = py_utils.NestedMap(decoded_metrics)

     # With SPMD the enqueue is placed on TPU core 0; otherwise an empty
     # device string is passed to tf.device.
     if self.spmd:
         placement = tpu.core(0)
     else:
         placement = ''
     with tf.device(placement):
         flat_metrics = self.metrics_nm.Flatten()
         enqueue_op = tpu_ops.outfeed_enqueue_tuple(flat_metrics)
         return [enqueue_op]
示例#6
0
  def devices(self):
    """Return a one-element tuple holding the current replica's device.

    Requires a replica context. If the replica id has a constant value, the
    matching entry of the strategy's `extended.worker_devices` is returned;
    otherwise `tpu.core(0)` is returned.
    """
    distribute_lib.require_replica_context(self)
    strategy = self._strategy
    known_replica_id = tensor_util.constant_value(
        self._replica_id_in_sync_group)

    if known_replica_id is not None:
      worker_device = strategy.extended.worker_devices[known_replica_id]
      return (worker_device,)

    # Non-constant `Tensor` inside `tpu.replicate`.
    # TODO(cjfj): Return other devices when model parallelism is supported.
    return (tpu.core(0),)
示例#7
0
 def _DecodeStep():
     """Decode call to be compiled for TPU.

     Inside an opportunistic variable-reuse scope, instantiates the model's
     variables, dequeues one input batch and runs the task's `Decode` on it.
     The metrics are stored on `self.metrics_nm` as a `py_utils.NestedMap`
     and their flattened form is enqueued on the outfeed.

     Returns:
       A single-element list containing the outfeed enqueue op.
     """
     with py_utils.OpportunisticVariableReuseScope(True):
         self._model.InstantiateVariables()
         batch = self._task.input.TpuDequeueBatch()
         decoded_metrics = self._task.Decode(batch)
     self.metrics_nm = py_utils.NestedMap(decoded_metrics)

     # With SPMD the enqueue is placed on TPU core 0; otherwise an empty
     # device string is passed to tf.device.
     if self.spmd:
         placement = tpu.core(0)
     else:
         placement = ''
     with tf.device(placement):
         flat_metrics = self.metrics_nm.Flatten()
         enqueue_op = tpu_ops.outfeed_enqueue_tuple(flat_metrics)
         return [enqueue_op]
示例#8
0
  def generate_dequeue_op(self, tpu_device=0):
    """Generates the device-side Op to dequeue a tuple from the queue.

    Implicitly freezes the queue configuration if it is not already
    frozen, which will raise errors if the shapes and types have not
    been fully specified.

    The shape requested for each tuple element is the policy's
    `get_unpartitioned_shape` of its `get_sharded_shape`. When the queue has
    more than one partition, the dequeued values are additionally passed
    through `tag_sharding_attribute_for_dequeued_tensors`.

    Args:
      tpu_device: The TPU device ordinal where the infeed instruction should be
        placed. If None, no explicit placement will be performed, and it is up
        to the user to call this API from within a proper TPU device scope.
        The XLA code will fail if the TPU dequeue instruction is not bound to
        any device.

    Returns:
      A list of Outputs corresponding to a shard of infeed dequeued
      into XLA, suitable for use within a replicated block.

    Raises:
      ValueError: if the types or shapes of the tuple elements have not been
        set; or if a dequeue op has already been generated.
    """
    self.freeze()
    # Only a single dequeue op may ever be produced per queue.
    if self._generated_dequeue_op:
      raise ValueError("Can't generate two dequeue Ops from the same queue")
    self._generated_dequeue_op = True

    op_name = "%s/dequeue" % self._name
    dequeue_shapes = []
    for tuple_shape, sharding_policy in zip(self._tuple_shapes,
                                            self._sharding_policies):
      sharded = sharding_policy.get_sharded_shape(tuple_shape)
      dequeue_shapes.append(sharding_policy.get_unpartitioned_shape(sharded))

    def _build_dequeue():
      # Shared by the placed and unplaced paths below.
      return tpu_ops.infeed_dequeue_tuple(
          dtypes=self._tuple_types, shapes=dequeue_shapes, name=op_name)

    if tpu_device is None:
      dequeued = _build_dequeue()
    else:
      with ops.device(tpu.core(tpu_device)):
        dequeued = _build_dequeue()

    # Without partitioning there is nothing to tag.
    if self._number_of_partitions <= 1:
      return dequeued

    partition_dims = []
    for tuple_shape, sharding_policy in zip(self._tuple_shapes,
                                            self._sharding_policies):
      all_ones = [1] * tuple_shape.ndims
      partition_dims.append(
          sharding_policy.get_unpartitioned_shape(all_ones).as_list())
    return tag_sharding_attribute_for_dequeued_tensors(dequeued,
                                                       partition_dims)
示例#9
0
    def experimental_logical_device(self, logical_device_id):
        """Places variables and ops on the specified logical device.

        The id is pushed on `self._logical_device_stack` for the duration of
        the block and popped again afterwards, even if the block raises.
        Inside an enclosing TPU context the block additionally runs under
        `ops.device(tpu.core(logical_device_id))`.

        NOTE(review): this is a generator with a single `yield`, presumably
        wrapped by a context-manager decorator at the definition site; the
        decorator is not visible here, so confirm.

        Args:
          logical_device_id: Index of the logical device within a replica.
            Must satisfy `0 <= logical_device_id < self._tpu_devices.shape[1]`.

        Yields:
          None, once the logical device has been made current.

        Raises:
          ValueError: if `logical_device_id` is negative or not smaller than
            the number of logical devices per replica.
        """
        num_logical_devices_per_replica = self._tpu_devices.shape[1]
        # Reject negative ids as well as ids past the end: neither names a
        # logical device, and a negative id would otherwise be pushed on the
        # stack and handed to `tpu.core`.
        if (logical_device_id < 0 or
                logical_device_id >= num_logical_devices_per_replica):
            raise ValueError(
                "`logical_device_id` not in range (was {}, but there are only {} "
                "logical devices per replica).".format(
                    logical_device_id, num_logical_devices_per_replica))

        self._logical_device_stack.append(logical_device_id)
        try:
            if values._enclosing_tpu_context() is None:  # pylint: disable=protected-access
                yield
            else:
                with ops.device(tpu.core(logical_device_id)):
                    yield
        finally:
            # Always restore the previous logical device.
            self._logical_device_stack.pop()
示例#10
0
  def generate_dequeue_op(self, tpu_device=0):
    """Generates the device-side Op to dequeue a tuple from the queue.

    Implicitly freezes the queue configuration if it is not already
    frozen, which will raise errors if the shapes and types have not
    been fully specified.

    Each tuple shape is converted to its sharded form by the matching
    sharding policy before the dequeue op is created.

    Args:
      tpu_device: The TPU device ordinal where the infeed instruction should be
        placed. If None, no explicit placement will be performed, and it is up
        to the user to call this API from within a proper TPU device scope.
        The XLA code will fail if the TPU dequeue instruction is not bound to
        any device.

    Returns:
      A list of Outputs corresponding to a shard of infeed dequeued
      into XLA, suitable for use within a replicated block.

    Raises:
      ValueError: if the types or shapes of the tuple elements have not been
        set; or if a dequeue op has already been generated.
    """
    self.freeze()
    # Only a single dequeue op may ever be produced per queue.
    if self._generated_dequeue_op:
      raise ValueError("Can't generate two dequeue Ops from the same queue")
    self._generated_dequeue_op = True

    op_name = "%s/dequeue" % self._name
    per_shard_shapes = []
    for tuple_shape, sharding_policy in zip(self._tuple_shapes,
                                            self._sharding_policies):
      per_shard_shapes.append(sharding_policy.get_sharded_shape(tuple_shape))

    def _build_dequeue():
      # Shared by the placed and unplaced paths below.
      return tpu_ops.infeed_dequeue_tuple(
          dtypes=self._tuple_types, shapes=per_shard_shapes, name=op_name)

    if tpu_device is None:
      # The caller is responsible for the surrounding device scope.
      return _build_dequeue()
    with ops.device(tpu.core(tpu_device)):
      return _build_dequeue()