def MeshSplit(x,
              device_mesh,
              tensor_split_dims_mapping,
              use_sharding_op=True,
              unspecified_dims=None):
  """Wrapper of xla_sharding.mesh_split()."""
  if (not py_utils_flags.use_tpu() or tensor_split_dims_mapping is None or
      device_mesh is None or device_mesh.size <= 1):
    return x
  # Apply the prefix in the context.
  tensor_split_dims_mapping = (
      _MESH_SPLIT_DIM_PREFIXES.stack + tensor_split_dims_mapping)
  num_tiles = np.prod(
      [device_mesh.shape[i] for i in tensor_split_dims_mapping if i >= 0])
  if num_tiles <= 1:
    return x
  if _MANUAL_MESH_DIMS.stack or unspecified_dims:
    return xla_sharding.mesh_split(
        x,
        device_mesh,
        tensor_split_dims_mapping,
        use_sharding_op=use_sharding_op,
        manual_mesh_dims=_MANUAL_MESH_DIMS.stack,
        unspecified_dims=unspecified_dims)
  # Do not include manual_mesh_dims or unspecified_dims, to support legacy TF
  # versions.
  return xla_sharding.mesh_split(
      x, device_mesh, tensor_split_dims_mapping,
      use_sharding_op=use_sharding_op)
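# Illustrative usage sketch, not part of the original sources: it shows how
# the MeshSplit wrapper above might be called to shard a [rows, cols] weight
# over a hypothetical 2x4 TPU device mesh, splitting dim 0 along mesh axis 0
# and leaving dim 1 unsplit (-1 means "do not split"). The mesh shape, tensor
# shape, and helper name are assumptions, and the module-level np/tf imports
# used elsewhere in this file are assumed to be present.
def _ExampleMeshSplitUsage():
  device_mesh = np.arange(8).reshape([2, 4])  # assumed 8-device topology
  w = tf.ones([1024, 4096])
  # Rows are split 2 ways over mesh axis 0; columns stay replicated.
  # On non-TPU runs the wrapper above simply returns `w` unchanged.
  return MeshSplit(w, device_mesh, tensor_split_dims_mapping=[0, -1])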
def ApplyToVariable(self, variable: tf.Variable) -> tf.Variable:
  if self.is_replicated:
    return xla_sharding.replicate(variable, use_sharding_op=False)
  return xla_sharding.mesh_split(
      variable,
      self.device_mesh,
      self.split_dims_mapping,
      use_sharding_op=False)
def testCreateSlotWithCustomSplitXlaSharding(self):
  # slot_creator is used only in optimizer V1.
  # We insert our own custom split XLA sharding that overrides the SPMD
  # sharding copied over by the slot_creator.
  with ops.Graph().as_default(), self.cached_session():
    v = variables.Variable([1.0, 2.5, 10.0, 15.1], name="var")
    v = xla_sharding.mesh_split(
        v, np.array([0, 1]), [0], use_sharding_op=False)
    with ops.control_dependencies(None):
      slot = slot_creator.create_zeros_slot(
          v, name="slot", dtype=dtypes.float64, copy_xla_sharding=True)
      slot = xla_sharding.split(
          slot, split_dimension=0, num_devices=4, use_sharding_op=False)

      self.assertNotEqual(
          xla_sharding.get_tensor_sharding(v),
          xla_sharding.get_tensor_sharding(slot))

      slot_sharding = xla_sharding.get_tensor_sharding(slot)
      slot_proto = xla_data_pb2.OpSharding()
      slot_proto.ParseFromString(slot_sharding)
      self.assertEqual(
          slot_proto,
          xla_data_pb2.OpSharding(
              type=xla_data_pb2.OpSharding.OTHER,
              tile_assignment_dimensions=[4],
              tile_assignment_devices=range(4)))
def testCreateSlotWithCustomReplicatedXlaSharding(self):
  # slot_creator is used only in optimizer V1.
  # We insert our own custom replicated XLA sharding that overrides the SPMD
  # sharding copied over by the slot_creator.
  with ops.Graph().as_default(), self.cached_session():
    v = variables.Variable([1.0, 2.5], name="var")
    v = xla_sharding.mesh_split(
        v, np.array([0, 1]), [0], use_sharding_op=False)
    with ops.control_dependencies(None):
      slot = slot_creator.create_zeros_slot(
          v, name="slot", dtype=dtypes.float64, copy_xla_sharding=True)
      slot = xla_sharding.replicate(slot, use_sharding_op=False)

      self.assertNotEqual(
          xla_sharding.get_tensor_sharding(v),
          xla_sharding.get_tensor_sharding(slot))

      slot_sharding = xla_sharding.get_tensor_sharding(slot)
      slot_proto = xla_data_pb2.OpSharding()
      slot_proto.ParseFromString(slot_sharding)
      self.assertEqual(
          slot_proto,
          xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED))
def ApplyToTensor(self,
                  tensor: tf.Tensor,
                  use_sharding_op: bool = True) -> tf.Tensor:
  if self.is_replicated:
    return xla_sharding.replicate(tensor, use_sharding_op=use_sharding_op)
  return xla_sharding.mesh_split(
      tensor,
      self.device_mesh,
      self.split_dims_mapping,
      use_sharding_op=use_sharding_op)
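# Illustrative sketch, not part of the original sources: ApplyToVariable and
# ApplyToTensor above are methods of a sharding-spec object that exposes
# `is_replicated`, `device_mesh`, and `split_dims_mapping`. The stand-in class
# and helper below are hypothetical and exist only to show the call pattern;
# the real spec class is not reproduced here.
class _ExampleShardingSpec:
  """Hypothetical minimal spec carrying the fields the methods above use."""

  def __init__(self, device_mesh=None, split_dims_mapping=None,
               is_replicated=False):
    self.device_mesh = device_mesh                # np.ndarray of device ids
    self.split_dims_mapping = split_dims_mapping  # one entry per tensor dim
    self.is_replicated = is_replicated

  def ApplyToTensor(self, tensor: tf.Tensor,
                    use_sharding_op: bool = True) -> tf.Tensor:
    if self.is_replicated:
      return xla_sharding.replicate(tensor, use_sharding_op=use_sharding_op)
    return xla_sharding.mesh_split(
        tensor, self.device_mesh, self.split_dims_mapping,
        use_sharding_op=use_sharding_op)


def _ExampleApplyToTensorUsage():
  # Split dim 0 of an activation over axis 0 of an assumed 2x4 device mesh,
  # leaving dim 1 unsplit; inserts a sharding op around the tensor.
  spec = _ExampleShardingSpec(
      device_mesh=np.arange(8).reshape([2, 4]), split_dims_mapping=[0, -1])
  return spec.ApplyToTensor(tf.ones([16, 32]))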
def MeshSplit(x, device_mesh, tensor_split_dims_mapping, use_sharding_op=True):
  """Wrapper of xla_sharding.mesh_split()."""
  if (not py_utils_flags.use_tpu() or tensor_split_dims_mapping is None or
      device_mesh is None or device_mesh.size <= 1):
    return x
  num_tiles = np.prod(
      [device_mesh.shape[i] for i in tensor_split_dims_mapping if i >= 0])
  if num_tiles <= 1:
    return x
  return xla_sharding.mesh_split(
      x, device_mesh, tensor_split_dims_mapping,
      use_sharding_op=use_sharding_op)
def testCopyXlaSharding(self):
  ema = moving_averages.ExponentialMovingAverage(0.25, name="foo_avg")
  v = variables.Variable(_Repeat(10.0, 2), name="v")
  self.assertIsNone(xla_sharding.get_tensor_sharding(v))
  v = xla_sharding.mesh_split(v, np.array([0, 1]), [0], use_sharding_op=False)
  self.assertIsNotNone(xla_sharding.get_tensor_sharding(v))
  self.evaluate(variables.global_variables_initializer())
  ema.apply([v])
  avg = ema.average(v)
  self.assertEqual(
      xla_sharding.get_tensor_sharding(v),
      xla_sharding.get_tensor_sharding(avg))
def testCreateSlotFromVariableCopyXlaSharding(self):
  # slot_creator is used only in optimizer V1.
  with ops.Graph().as_default(), self.cached_session():
    v = variables.Variable([1.0, 2.5], name="var")
    v = xla_sharding.mesh_split(
        v, np.array([0, 1]), [0], use_sharding_op=False)
    slot = slot_creator.create_slot(
        v, v.initialized_value(), name="slot", copy_xla_sharding=True)
    self.assertEqual(
        xla_sharding.get_tensor_sharding(v),
        xla_sharding.get_tensor_sharding(slot))
def testCreateZerosSlotFromVariableCopyXlaSharding(self):
  # slot_creator is used only in optimizer V1.
  with ops.Graph().as_default(), self.cached_session():
    v = variables.Variable([1.0, 2.5], name="var")
    v = xla_sharding.mesh_split(
        v, np.array([0, 1]), [0], use_sharding_op=False)
    with ops.control_dependencies(None):
      slot = slot_creator.create_zeros_slot(
          v, name="slot", dtype=dtypes.float64, copy_xla_sharding=True)
      self.assertEqual(
          xla_sharding.get_tensor_sharding(v),
          xla_sharding.get_tensor_sharding(slot))
def MeshSplit(x, device_mesh, tensor_split_dims_mapping, use_sharding_op=True):
  """Wrapper of xla_sharding.mesh_split()."""
  if (not py_utils_flags.use_tpu() or tensor_split_dims_mapping is None or
      device_mesh is None or device_mesh.size <= 1):
    return x
  # Apply the prefix in the context.
  tensor_split_dims_mapping = (
      _MESH_SPLIT_DIM_PREFIXES.stack + tensor_split_dims_mapping)
  num_tiles = np.prod(
      [device_mesh.shape[i] for i in tensor_split_dims_mapping if i >= 0])
  if num_tiles <= 1:
    return x
  return xla_sharding.mesh_split(
      x, device_mesh, tensor_split_dims_mapping,
      use_sharding_op=use_sharding_op)
def testCreateSlotWithoutXlaSharding(self):
  # slot_creator is used only in optimizer V1.
  # The SPMD sharding annotations should not be copied since the primary
  # variable and the slot variable have different ranks.
  with ops.Graph().as_default(), self.cached_session():
    v = variables.Variable([1.0, 2.5], name="var")
    v = xla_sharding.mesh_split(
        v, np.array([0, 1]), [0], use_sharding_op=False)
    with ops.control_dependencies(None):
      slot = slot_creator.create_slot(
          v,
          constant_op.constant(10, name="const"),
          name="slot",
          copy_xla_sharding=True)
      self.assertIsNone(xla_sharding.get_tensor_sharding(slot))
      self.assertNotEqual(
          xla_sharding.get_tensor_sharding(v),
          xla_sharding.get_tensor_sharding(slot))
def testXlaSharding(self):
  dtype = dtypes.float32
  with self.session(graph=ops.Graph()):
    # Initialize variables for numpy implementation.
    var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
    grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
    var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
    grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)

    var0 = resource_variable_ops.ResourceVariable(var0_np, name="var0")
    var1 = resource_variable_ops.ResourceVariable(var1_np, name="var1")
    var0, var1 = [
        xla_sharding.mesh_split(
            v, np.array([0, 1]), [0], use_sharding_op=False)
        for v in (var0, var1)
    ]
    grads0 = constant_op.constant(grads0_np)
    grads1 = constant_op.constant(grads1_np)
    learning_rate = lambda: 0.001

    opt = adam.AdamOptimizer(learning_rate=learning_rate)
    update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
    self.evaluate(variables.global_variables_initializer())
    self.evaluate(update)

    # The beta accumulators are not sharded.
    beta1_power, beta2_power = opt._get_beta_accumulators()
    self.assertIsNone(xla_sharding.get_tensor_sharding(beta1_power))
    self.assertIsNone(xla_sharding.get_tensor_sharding(beta2_power))

    # Variables and slots are sharded.
    for v in (var0, var1):
      self.assertIsNotNone(xla_sharding.get_tensor_sharding(v))
      for slot_name in ("m", "v"):
        slot = opt.get_slot(v, slot_name)
        self.assertIsNotNone(xla_sharding.get_tensor_sharding(slot))