예제 #1
0
    def unpack(self, parallel_tensor):
        """Unpack a parallel tensor into its per-component values.

        Args:
          parallel_tensor: A tensor, composite tensor, or variable placed on
            the ParallelDevice.

        Returns:
          A list with one entry per `self.components`. Each entry has the same
          structure as `parallel_tensor` (composite tensors are rebuilt via
          `nest.pack_sequence_as(..., expand_composites=True)`) and holds the
          values produced for that component.

        Raises:
          ValueError: If `parallel_tensor` is not a tensor, composite tensor,
            or variable.
        """
        self._assert_eager()
        if not isinstance(parallel_tensor,
                          (ops.Tensor, composite_tensor.CompositeTensor,
                           variables.Variable)):
            raise ValueError(
                "Expected a tensor, got {}.".format(parallel_tensor))
        num_replicas = len(self.components)
        # One accumulator per component so the result always has
        # len(self.components) entries; transposing with zip(*unpacked) would
        # yield an empty list if the structure flattened to no tensors.
        per_component = [[] for _ in range(num_replicas)]
        with ops.device(self._name):
            for packed in nest.flatten(parallel_tensor, expand_composites=True):
                replicas = tpu_ops.tpu_replicated_output(
                    packed, num_replicas=num_replicas)
                # Route the i-th replica output of every flat tensor to the
                # i-th component's accumulator.
                for accumulator, replica in zip(per_component, replicas):
                    accumulator.append(replica)
        return [
            nest.pack_sequence_as(parallel_tensor,
                                  flat_values,
                                  expand_composites=True)
            for flat_values in per_component
        ]
예제 #2
0
 def _unpack_tensor(self, parallel_tensor):
   """Split one parallel value into its per-replica outputs.

   Args:
     parallel_tensor: A tensor, composite tensor, or variable placed on the
       ParallelDevice.

   Returns:
     The result of `tpu_ops.tpu_replicated_output` for `parallel_tensor`,
     run under this device with `num_replicas=len(self.components)`.

   Raises:
     ValueError: If `parallel_tensor` is none of the accepted types.
   """
   accepted_types = (ops.Tensor, composite_tensor.CompositeTensor,
                     variables.Variable)
   if isinstance(parallel_tensor, accepted_types):
     with ops.device(self._name):
       replica_count = len(self.components)
       return tpu_ops.tpu_replicated_output(
           parallel_tensor, num_replicas=replica_count)
   raise ValueError(
       "Expected a tensor, got {}.".format(parallel_tensor))
예제 #3
0
  def unpack(self, parallel_tensor):
    """Split a parallel tensor into one value per component device.

    Args:
      parallel_tensor: A tensor placed on the ParallelDevice.

    Returns:
      The outputs of `tpu_ops.tpu_replicated_output` for `parallel_tensor`,
      requested with one replica per entry of `self.components`.
    """
    device_scope = ops.device(self._name)
    with device_scope:
      replica_count = len(self.components)
      outputs = tpu_ops.tpu_replicated_output(
          parallel_tensor, num_replicas=replica_count)
    return outputs