def unpack(self, parallel_tensor):
  """Unpack a parallel tensor into its components.

  Args:
    parallel_tensor: A tensor or composite tensor placed on the
      ParallelDevice.

  Returns:
    A flat list of tensors, one per `self.components`.

  Raises:
    ValueError: If `parallel_tensor` is not a tensor, composite tensor,
      or variable.
  """
  self._assert_eager()
  if not isinstance(parallel_tensor,
                    (ops.Tensor, composite_tensor.CompositeTensor,
                     variables.Variable)):
    raise ValueError(
        "Expected a tensor, got {}.".format(parallel_tensor))
  with ops.device(self._name):
    # Expand composites so each leaf tensor is split individually; each
    # component's leaves are re-packed into the original structure below.
    flattened = nest.flatten(parallel_tensor, expand_composites=True)
    # NOTE(review): tpu_replicated_output appears to be reused here to split
    # one parallel tensor into per-component tensors — confirm against the
    # ParallelDevice op registration.
    unpacked = [
        tpu_ops.tpu_replicated_output(
            packed, num_replicas=len(self.components))
        for packed in flattened
    ]
    # zip(*unpacked) transposes [leaf][component] -> [component][leaf].
    return [
        nest.pack_sequence_as(
            parallel_tensor, component_value, expand_composites=True)
        for component_value in zip(*unpacked)
    ]
def _unpack_tensor(self, parallel_tensor):
  """Helper to unpack a single tensor.

  Args:
    parallel_tensor: A tensor-like value placed on the ParallelDevice.

  Returns:
    The result of splitting `parallel_tensor` into `len(self.components)`
    per-component tensors.

  Raises:
    ValueError: If `parallel_tensor` is not a tensor-like value.
  """
  accepted_types = (ops.Tensor, composite_tensor.CompositeTensor,
                    variables.Variable)
  if not isinstance(parallel_tensor, accepted_types):
    raise ValueError(
        "Expected a tensor, got {}.".format(parallel_tensor))
  with ops.device(self._name):
    return tpu_ops.tpu_replicated_output(
        parallel_tensor, num_replicas=len(self.components))
def unpack(self, parallel_tensor):
  """Unpack a parallel tensor into its components.

  Args:
    parallel_tensor: A tensor placed on the ParallelDevice.

  Returns:
    A flat list of tensors, one per `self.components`.

  Raises:
    ValueError: If `parallel_tensor` is not a tensor-like value.
  """
  # Validate up front for a clear error, matching _unpack_tensor;
  # otherwise a bad argument surfaces as an opaque op-level failure
  # inside tpu_replicated_output.
  if not isinstance(parallel_tensor,
                    (ops.Tensor, composite_tensor.CompositeTensor,
                     variables.Variable)):
    raise ValueError(
        "Expected a tensor, got {}.".format(parallel_tensor))
  with ops.device(self._name):
    return tpu_ops.tpu_replicated_output(
        parallel_tensor, num_replicas=len(self.components))