Example #1
    def testCreateSlotWithCustomReplicatedXlaSharding(self):
        # slot_creator is used only in optimizer V1.
        # We insert our own custom replicated XLA sharding that overrides the SPMD
        # sharding copied over by the slot_creator.
        with ops.Graph().as_default(), self.cached_session():
            v = variables.Variable([1.0, 2.5], name="var")
            v = xla_sharding.mesh_split(v,
                                        np.array([0, 1]), [0],
                                        use_sharding_op=False)
            with ops.control_dependencies(None):
                slot = slot_creator.create_zeros_slot(v,
                                                      name="slot",
                                                      dtype=dtypes.float64,
                                                      copy_xla_sharding=True)
                slot = xla_sharding.replicate(slot, use_sharding_op=False)

            self.assertNotEqual(xla_sharding.get_tensor_sharding(v),
                                xla_sharding.get_tensor_sharding(slot))

            slot_sharding = xla_sharding.get_tensor_sharding(slot)
            slot_proto = xla_data_pb2.OpSharding()
            slot_proto.ParseFromString(slot_sharding)
            self.assertEqual(
                slot_proto,
                xla_data_pb2.OpSharding(
                    type=xla_data_pb2.OpSharding.REPLICATED))
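A standalone sketch of the override behavior exercised above: with use_sharding_op=False the REPLICATED annotation is written onto the op that produces the tensor, replacing the sharding that mesh_split had attached. These are TensorFlow-internal modules, so the import paths below are the ones used inside the TF code base and may move between releases.

import numpy as np

from tensorflow.compiler.xla import xla_data_pb2
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops

with ops.Graph().as_default():
  x = array_ops.ones([2])
  # First shard x across a two-device mesh, then overwrite with replication.
  x = xla_sharding.mesh_split(x, np.array([0, 1]), [0], use_sharding_op=False)
  x = xla_sharding.replicate(x, use_sharding_op=False)
  proto = xla_data_pb2.OpSharding()
  proto.ParseFromString(xla_sharding.get_tensor_sharding(x))
  assert proto.type == xla_data_pb2.OpSharding.REPLICATED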
Example #2
 def ApplyToVariable(self, variable: tf.Variable) -> tf.Variable:
     if self.is_replicated:
         return xla_sharding.replicate(variable, use_sharding_op=False)
     return xla_sharding.mesh_split(variable,
                                    self.device_mesh,
                                    self.split_dims_mapping,
                                    use_sharding_op=False)
Example #3
 def replicate_helper(tensor):
   replicated_tensor = xla_sharding.replicate(
       array_ops.ones([4, 5, 6], dtype=dtypes.float32))
   self.assertIsNone(xla_sharding.get_tensor_sharding(tensor))
   replicated_sharding = xla_sharding.get_tensor_sharding(replicated_tensor)
   self.assertIsNotNone(replicated_sharding)
   self.assertIsNone(
       xla_sharding.get_sharding_tile_shape(replicated_sharding))
   return replicated_tensor
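For contrast with the replicated case above, a tiled sharding does carry a tile shape, so get_sharding_tile_shape returns the tile-assignment dimensions instead of None. A minimal sketch under the same internal-module caveat (import paths may differ between TF releases):

import numpy as np

from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.framework import dtypes, ops
from tensorflow.python.ops import array_ops

with ops.Graph().as_default():
  t = array_ops.ones([4, 6], dtype=dtypes.float32)
  # Tile the first dimension two ways; the second dimension is not split.
  tiled = xla_sharding.tile(t, np.arange(2).reshape([2, 1]))
  tiled_sharding = xla_sharding.get_tensor_sharding(tiled)
  assert list(xla_sharding.get_sharding_tile_shape(tiled_sharding)) == [2, 1]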
Example #4
 def ApplyToTensor(self,
                   tensor: tf.Tensor,
                   use_sharding_op: bool = True) -> tf.Tensor:
     if self.is_replicated:
         return xla_sharding.replicate(tensor,
                                       use_sharding_op=use_sharding_op)
     return xla_sharding.mesh_split(tensor,
                                    self.device_mesh,
                                    self.split_dims_mapping,
                                    use_sharding_op=use_sharding_op)
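A hedged sketch of how a spec object with this method might be exercised on plain graph-mode tensors. The SimpleSpec class and its field values are hypothetical stand-ins for whatever carries is_replicated, device_mesh, and split_dims_mapping in the real library, and the TensorFlow-internal import paths may vary between releases.

import numpy as np

from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops


class SimpleSpec:
  """Hypothetical container for the fields ApplyToTensor reads."""

  def __init__(self, is_replicated, device_mesh=None, split_dims_mapping=None):
    self.is_replicated = is_replicated
    self.device_mesh = device_mesh
    self.split_dims_mapping = split_dims_mapping

  def ApplyToTensor(self, tensor, use_sharding_op=True):
    if self.is_replicated:
      return xla_sharding.replicate(tensor, use_sharding_op=use_sharding_op)
    return xla_sharding.mesh_split(tensor, self.device_mesh,
                                   self.split_dims_mapping,
                                   use_sharding_op=use_sharding_op)


with ops.Graph().as_default():
  x = array_ops.ones([8, 128])
  # Mark x as fully replicated.
  x_rep = SimpleSpec(is_replicated=True).ApplyToTensor(x)
  # Shard x over a 4x2 mesh: dim 0 across mesh dim 0, dim 1 across mesh dim 1.
  x_split = SimpleSpec(
      is_replicated=False,
      device_mesh=np.arange(8).reshape([4, 2]),
      split_dims_mapping=[0, 1]).ApplyToTensor(x)
  assert xla_sharding.get_tensor_sharding(x_split) is not None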
Example #5
  def handle(self):
    if save_context.in_save_context() or context.executing_eagerly():
      return self._vars[0].handle

    if tpu_util.enclosing_tpu_context() is None:
      raise NotImplementedError('TPUReplicatedVariable.handle is not available '
                                'outside tpu context or save context')
    else:
      with tpu_util.outside_or_skip_tpu_context():
        return xla_sharding.replicate(
            tpu_partition_ops.tpu_partitioned_input(
                [v.handle for v in self._vars], partition_dim=-1))
Example #6
def _tag_sharding_attribute_for_dequeued_tensor(tensor, dims):
  """Tags appropriate XLA sharding attribute to the dequeued tensor.

  Args:
    tensor: The dequeued tensor on TPU.
    dims: A list of integers describing how the tensor is partitioned.

  Returns:
    The same tensor with the xla_sharding attribute.
  """
  if dims is None:
    return xla_sharding.replicate(tensor)
  elif np.prod(dims) == 1:
    return xla_sharding.assign_device(tensor, 0)
  else:
    tile_assignment = np.arange(np.prod(dims)).reshape(dims)
    return xla_sharding.tile(tensor=tensor, tile_assignment=tile_assignment)
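A hedged sketch of how the three branches behave, calling the helper above directly on graph-mode tensors. This is purely illustrative: it assumes the function above and its imports (np, xla_sharding) are in scope, the helper itself is private to the TPU infeed code, and the TensorFlow-internal import paths may vary.

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops

with ops.Graph().as_default():
  # dims=None: the dequeued tensor is marked as fully replicated.
  _tag_sharding_attribute_for_dequeued_tensor(array_ops.ones([4, 6]), None)
  # np.prod(dims) == 1: the whole tensor is assigned to device 0.
  _tag_sharding_attribute_for_dequeued_tensor(array_ops.ones([4, 6]), [1, 1])
  # Otherwise: tiled, here 2 ways along dim 0 and 2 ways along dim 1.
  _tag_sharding_attribute_for_dequeued_tensor(array_ops.ones([4, 6]), [2, 2])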
Example #8
    def _tag_sharding_attribute_for_dequeued_tensor(self, tensor, dims):
        """Tags appropriate XLA sharding attribute to the dequeued tensor.

    Args:
      tensor: The dequeued tensor on TPU.
      dims: A list of integer describes how the tensor is partitioned.

    Returns:
      The same tensor with the xla_sharding attribute.
    """
        if dims is None:
            return xla_sharding.replicate(tensor)
        elif np.prod(dims) == 1:
            return xla_sharding.assign_device(tensor, 0)
        else:
            tile_shape = np.array(tensor.shape.as_list()) // dims
            tile_assignment = np.arange(np.prod(dims)).reshape(dims)
            return xla_sharding.tile(
                tensor=tensor,
                tile_shape=xla_shape.CreateShapeFromDtypeAndTuple(
                    dtype=np.dtype(tensor.dtype.as_numpy_dtype),
                    shape_tuple=tile_shape),
                tile_assignment=tile_assignment)
Example #9
  def _tag_sharding_attribute_for_dequeued_tensor(self, tensor, dims):
    """Tags appropriate XLA sharding attribute to the dequeued tensor.

    Args:
      tensor: The dequeued tensor on TPU.
      dims: A list of integers describing how the tensor is partitioned.

    Returns:
      The same tensor with the xla_sharding attribute.
    """
    if dims is None:
      return xla_sharding.replicate(tensor)
    elif np.prod(dims) == 1:
      return xla_sharding.assign_device(tensor, 0)
    else:
      tile_shape = np.array(tensor.shape.as_list()) // dims
      tile_assignment = np.arange(np.prod(dims)).reshape(dims)
      return xla_sharding.tile(
          tensor=tensor,
          tile_shape=xla_shape.CreateShapeFromDtypeAndTuple(
              dtype=np.dtype(tensor.dtype.as_numpy_dtype),
              shape_tuple=tile_shape),
          tile_assignment=tile_assignment)
Example #10
 def _experimental_replicate_to_logical_devices(self, tensor):
     """See `DistributionStrategy.experimental_replicate_to_logical_devices`."""
     return xla_sharding.replicate(tensor, use_sharding_op=True)
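At the user level this is exposed in TF 2.x as tf.distribute.TPUStrategy.experimental_replicate_to_logical_devices. A hedged sketch, assuming a reachable TPU whose cores can be grouped into two logical devices per replica; model, loss_fn, and the input batches are placeholders, not part of the snippet above.

import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
    topology, computation_shape=[1, 1, 1, 2], num_replicas=1)
strategy = tf.distribute.TPUStrategy(
    resolver, experimental_device_assignment=device_assignment)


@tf.function
def step_fn(images, labels):
  # Split the batch's height dimension across the two logical devices.
  images = strategy.experimental_split_to_logical_devices(images, [1, 2, 1, 1])
  logits = model(images)
  # Every logical device needs the same labels and logits to compute the loss.
  labels = strategy.experimental_replicate_to_logical_devices(labels)
  logits = strategy.experimental_replicate_to_logical_devices(logits)
  return loss_fn(labels, logits)


strategy.run(step_fn, args=(images_batch, labels_batch))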
Example #11
def Replicate(x, use_sharding_op=True):
    """Wrapper of xla_sharding.replicate."""
    if not py_utils_flags.use_tpu():
        return x
    return xla_sharding.replicate(x, use_sharding_op=use_sharding_op)
Example #12
def Replicate(x, use_sharding_op=True):
    """Wrapper of xla_sharding.replicate."""
    return xla_sharding.replicate(x, use_sharding_op=use_sharding_op)
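Both wrappers default to use_sharding_op=True, which attaches the annotation to a separate XlaSharding op instead of mutating the op that produced the tensor, so any sharding already present on the producer stays intact. A minimal sketch of that distinction (TensorFlow-internal modules; import paths may vary):

from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops

with ops.Graph().as_default():
  x = array_ops.ones([4, 4])
  y = xla_sharding.replicate(x, use_sharding_op=True)
  # The annotation lives on the new sharding op's output, not on x's producer.
  assert xla_sharding.get_tensor_sharding(x) is None
  assert xla_sharding.get_tensor_sharding(y) is not None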