def testCreateSlotWithCustomReplicatedXlaSharding(self):
  # slot_creator is used only in optimizer V1.
  # We insert our own custom replicated XLA sharding that overrides the SPMD
  # sharding copied over by the slot_creator.
  with ops.Graph().as_default(), self.cached_session():
    v = variables.Variable([1.0, 2.5], name="var")
    v = xla_sharding.mesh_split(
        v, np.array([0, 1]), [0], use_sharding_op=False)
    with ops.control_dependencies(None):
      slot = slot_creator.create_zeros_slot(
          v, name="slot", dtype=dtypes.float64, copy_xla_sharding=True)
      slot = xla_sharding.replicate(slot, use_sharding_op=False)
      self.assertNotEqual(
          xla_sharding.get_tensor_sharding(v),
          xla_sharding.get_tensor_sharding(slot))
      slot_sharding = xla_sharding.get_tensor_sharding(slot)
      slot_proto = xla_data_pb2.OpSharding()
      slot_proto.ParseFromString(slot_sharding)
      self.assertEqual(
          slot_proto,
          xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED))
def ApplyToVariable(self, variable: tf.Variable) -> tf.Variable:
  if self.is_replicated:
    return xla_sharding.replicate(variable, use_sharding_op=False)
  return xla_sharding.mesh_split(
      variable, self.device_mesh, self.split_dims_mapping,
      use_sharding_op=False)
def replicate_helper(tensor):
  replicated_tensor = xla_sharding.replicate(
      array_ops.ones([4, 5, 6], dtype=dtypes.float32))
  self.assertIsNone(xla_sharding.get_tensor_sharding(tensor))
  replicated_sharding = xla_sharding.get_tensor_sharding(replicated_tensor)
  self.assertIsNotNone(replicated_sharding)
  self.assertIsNone(
      xla_sharding.get_sharding_tile_shape(replicated_sharding))
  return replicated_tensor
def ApplyToTensor(self,
                  tensor: tf.Tensor,
                  use_sharding_op: bool = True) -> tf.Tensor:
  if self.is_replicated:
    return xla_sharding.replicate(tensor, use_sharding_op=use_sharding_op)
  return xla_sharding.mesh_split(
      tensor, self.device_mesh, self.split_dims_mapping,
      use_sharding_op=use_sharding_op)
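# Illustrative sketch, not from the source class: the two branches of
# ApplyToTensor above map onto these direct xla_sharding calls. The 2x4
# device mesh, the split_dims_mapping value, and the traced function are
# made-up example values, not the class's real configuration.
import numpy as np
import tensorflow as tf
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding


@tf.function
def annotate(x):  # x is assumed to be a rank-2 float tensor.
  device_mesh = np.arange(8).reshape([2, 4])  # 8 devices on a 2x4 mesh.
  # Split tensor dim 0 across mesh axis 0; -1 leaves dim 1 unsplit.
  split = xla_sharding.mesh_split(
      x, device_mesh, [0, -1], use_sharding_op=True)
  # Or annotate the tensor as fully replicated on every device.
  replicated = xla_sharding.replicate(x, use_sharding_op=True)
  return split, replicated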
def handle(self):
  if save_context.in_save_context() or context.executing_eagerly():
    return self._vars[0].handle

  if tpu_util.enclosing_tpu_context() is None:
    raise NotImplementedError(
        'TPUReplicatedVariable.handle is not available outside tpu context '
        'or save context')
  else:
    with tpu_util.outside_or_skip_tpu_context():
      return xla_sharding.replicate(
          tpu_partition_ops.tpu_partitioned_input(
              [v.handle for v in self._vars], partition_dim=-1))
def _tag_sharding_attribute_for_dequeued_tensor(tensor, dims):
  """Tags the appropriate XLA sharding attribute onto the dequeued tensor.

  Args:
    tensor: The dequeued tensor on TPU.
    dims: A list of integers describing how the tensor is partitioned.

  Returns:
    The same tensor with the xla_sharding attribute.
  """
  if dims is None:
    return xla_sharding.replicate(tensor)
  elif np.prod(dims) == 1:
    return xla_sharding.assign_device(tensor, 0)
  else:
    tile_assignment = np.arange(np.prod(dims)).reshape(dims)
    return xla_sharding.tile(tensor=tensor, tile_assignment=tile_assignment)
def _tag_sharding_attribute_for_dequeued_tensor(self, tensor, dims):
  """Tags the appropriate XLA sharding attribute onto the dequeued tensor.

  Args:
    tensor: The dequeued tensor on TPU.
    dims: A list of integers describing how the tensor is partitioned.

  Returns:
    The same tensor with the xla_sharding attribute.
  """
  if dims is None:
    return xla_sharding.replicate(tensor)
  elif np.prod(dims) == 1:
    return xla_sharding.assign_device(tensor, 0)
  else:
    tile_shape = np.array(tensor.shape.as_list()) // dims
    tile_assignment = np.arange(np.prod(dims)).reshape(dims)
    return xla_sharding.tile(
        tensor=tensor,
        tile_shape=xla_shape.CreateShapeFromDtypeAndTuple(
            dtype=np.dtype(tensor.dtype.as_numpy_dtype),
            shape_tuple=tile_shape),
        tile_assignment=tile_assignment)
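# Worked illustration, not part of the source: the values the else branch
# above computes for a hypothetical [4, 6]-shaped tensor partitioned with
# dims=[2, 2].
import numpy as np

dims = [2, 2]
tensor_shape = [4, 6]
tile_assignment = np.arange(np.prod(dims)).reshape(dims)
# tile_assignment == [[0, 1], [2, 3]]: device i is assigned tile i.
tile_shape = np.array(tensor_shape) // dims
# tile_shape == [2, 3]: each device holds a 2x3 slice of the tensor.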
def _experimental_replicate_to_logical_devices(self, tensor):
  """See `DistributionStrategy.experimental_replicate_to_logical_devices`."""
  return xla_sharding.replicate(tensor, use_sharding_op=True)
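# Hedged usage sketch for the public API that the private method above backs
# (tf.distribute.TPUStrategy.experimental_replicate_to_logical_devices). The
# TPU resolver/topology setup and the toy step function are assumptions for
# illustration, not taken from the source.
import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
    topology, computation_shape=[1, 1, 1, 2], num_replicas=1)
strategy = tf.distribute.TPUStrategy(
    resolver, experimental_device_assignment=device_assignment)


@tf.function
def step_fn(inputs):
  # Keep the batch fully replicated across the replica's logical devices.
  images = strategy.experimental_replicate_to_logical_devices(inputs)
  return tf.reduce_sum(images)


strategy.run(step_fn, args=(tf.ones([8, 8]),))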
def Replicate(x, use_sharding_op=True):
  """Wrapper around xla_sharding.replicate; a no-op when not running on TPU."""
  if not py_utils_flags.use_tpu():
    return x
  return xla_sharding.replicate(x, use_sharding_op=use_sharding_op)
def Replicate(x, use_sharding_op=True):
  """Wrapper of xla_sharding.replicate."""
  return xla_sharding.replicate(x, use_sharding_op=use_sharding_op)
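# Minimal usage sketch for the wrappers above, assuming graph mode (e.g.
# inside tf.function) on TPU; the matmul "model" is a made-up stand-in.
# xla_sharding.replicate annotates the op that produces the tensor, so it is
# normally called from traced/graph code rather than eagerly.
import tensorflow as tf
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding


@tf.function
def forward(x):  # x is assumed to have shape [batch, 8].
  logits = tf.matmul(x, tf.ones([8, 8]))
  # With use_sharding_op=True a dedicated sharding op is inserted, so the
  # matmul's own attributes are left untouched.
  return xla_sharding.replicate(logits, use_sharding_op=True)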