def testCreateSlotWithCustomSplitXlaSharding(self):
    """A custom split sharding on the slot overrides the copied SPMD one."""
    # slot_creator is used only in optimizer V1. The slot first inherits the
    # primary variable's mesh_split sharding via copy_xla_sharding=True, then
    # we stamp our own 4-way split on top of it.
    with ops.Graph().as_default(), self.cached_session():
        primary = variables.Variable([1.0, 2.5, 10.0, 15.1], name="var")
        primary = xla_sharding.mesh_split(
            primary, np.array([0, 1]), [0], use_sharding_op=False)
        with ops.control_dependencies(None):
            slot_var = slot_creator.create_zeros_slot(
                primary,
                name="slot",
                dtype=dtypes.float64,
                copy_xla_sharding=True)
            slot_var = xla_sharding.split(
                slot_var,
                split_dimension=0,
                num_devices=4,
                use_sharding_op=False)
            # The override must make the slot's sharding diverge from the
            # primary's.
            self.assertNotEqual(
                xla_sharding.get_tensor_sharding(primary),
                xla_sharding.get_tensor_sharding(slot_var))
            # Decode the serialized sharding and pin the exact proto we expect
            # for a 4-way split along dimension 0.
            decoded = xla_data_pb2.OpSharding()
            decoded.ParseFromString(xla_sharding.get_tensor_sharding(slot_var))
            expected = xla_data_pb2.OpSharding(
                type=xla_data_pb2.OpSharding.OTHER,
                tile_assignment_dimensions=[4],
                tile_assignment_devices=range(4))
            self.assertEqual(decoded, expected)
def testCreateSlotWithCustomReplicatedXlaSharding(self):
    """A custom replicated sharding on the slot overrides the copied one."""
    # slot_creator is used only in optimizer V1. The slot first inherits the
    # primary variable's mesh_split sharding via copy_xla_sharding=True, then
    # we replace it with an explicit REPLICATED annotation.
    with ops.Graph().as_default(), self.cached_session():
        primary = variables.Variable([1.0, 2.5], name="var")
        primary = xla_sharding.mesh_split(
            primary, np.array([0, 1]), [0], use_sharding_op=False)
        with ops.control_dependencies(None):
            slot_var = slot_creator.create_zeros_slot(
                primary,
                name="slot",
                dtype=dtypes.float64,
                copy_xla_sharding=True)
            slot_var = xla_sharding.replicate(slot_var, use_sharding_op=False)
            # The override must make the slot's sharding diverge from the
            # primary's.
            self.assertNotEqual(
                xla_sharding.get_tensor_sharding(primary),
                xla_sharding.get_tensor_sharding(slot_var))
            # Decode the serialized sharding and check it is plain REPLICATED.
            decoded = xla_data_pb2.OpSharding()
            decoded.ParseFromString(xla_sharding.get_tensor_sharding(slot_var))
            self.assertEqual(
                decoded,
                xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED))
def split_helper(tensor):
    """Split `tensor` 3-ways along dimension 2 and verify the sharding.

    `self` is a free variable from the enclosing test method.
    """
    # The input must arrive without any sharding annotation.
    self.assertIsNone(xla_sharding.get_tensor_sharding(tensor))
    result = xla_sharding.split(tensor, 2, 3)
    self.assertIsInstance(result, ops.Tensor)
    # A 3-way split on dim 2 of a rank-3 tensor yields a [1, 1, 3] tile
    # assignment.
    sharding_proto = xla_sharding.get_tensor_sharding(result)
    self.assertEqual(
        [1, 1, 3], xla_sharding.get_sharding_tile_shape(sharding_proto))
    return result
def replicate_helper(tensor):
    """Replicate a fresh ones tensor and verify its sharding annotation.

    `self` is a free variable from the enclosing test method.
    """
    # NOTE(review): the replicate is applied to a brand-new ones tensor, not
    # to the `tensor` argument — `tensor` is only checked to remain unsharded.
    # Confirm this asymmetry with the other *_helper functions is intended.
    replicated_tensor = xla_sharding.replicate(
        array_ops.ones([4, 5, 6], dtype=dtypes.float32))
    self.assertIsNone(xla_sharding.get_tensor_sharding(tensor))
    # The replicated tensor carries a sharding, but a replicated sharding has
    # no tile shape.
    replicated_sharding = xla_sharding.get_tensor_sharding(replicated_tensor)
    self.assertIsNotNone(replicated_sharding)
    self.assertIsNone(
        xla_sharding.get_sharding_tile_shape(replicated_sharding))
    return replicated_tensor
def tile_helper(tensor):
    """Tile `tensor` with assignment [2, 1, 6] and verify the sharding.

    `self` is a free variable from the enclosing test method.
    """
    # The input must arrive without any sharding annotation.
    self.assertIsNone(xla_sharding.get_tensor_sharding(tensor))
    result = xla_sharding.tile(tensor, np.array([2, 1, 6]))
    self.assertIsInstance(result, ops.Tensor)
    sharding_proto = xla_sharding.get_tensor_sharding(result)
    # get_sharding_tile_shape returns the shape OF the tile assignment
    # [2, 1, 6], i.e. a single-element list [3].
    self.assertEqual(
        [3], xla_sharding.get_sharding_tile_shape(sharding_proto))
    return result
def testCopyXlaSharding(self):
    """ExponentialMovingAverage copies the variable's sharding to its average."""
    ema = moving_averages.ExponentialMovingAverage(0.25, name="foo_avg")
    var = variables.Variable(_Repeat(10.0, 2), name="v")
    # Starts unsharded; mesh_split attaches the SPMD annotation in place.
    self.assertIsNone(xla_sharding.get_tensor_sharding(var))
    var = xla_sharding.mesh_split(
        var, np.array([0, 1]), [0], use_sharding_op=False)
    self.assertIsNotNone(xla_sharding.get_tensor_sharding(var))
    self.evaluate(variables.global_variables_initializer())
    ema.apply([var])
    # The shadow (average) variable must carry the same sharding as the
    # primary variable.
    shadow = ema.average(var)
    self.assertEqual(
        xla_sharding.get_tensor_sharding(var),
        xla_sharding.get_tensor_sharding(shadow))
def testCreateSlotFromVariableCopyXlaSharding(self):
    """create_slot with copy_xla_sharding mirrors the primary's sharding."""
    # slot_creator is used only in optimizer V1.
    with ops.Graph().as_default(), self.cached_session():
        primary = variables.Variable([1.0, 2.5], name="var")
        primary = xla_sharding.mesh_split(
            primary, np.array([0, 1]), [0], use_sharding_op=False)
        slot_var = slot_creator.create_slot(
            primary,
            primary.initialized_value(),
            name="slot",
            copy_xla_sharding=True)
        # The slot must end up with the identical serialized sharding.
        self.assertEqual(
            xla_sharding.get_tensor_sharding(primary),
            xla_sharding.get_tensor_sharding(slot_var))
def testCreateZerosSlotFromVariableCopyXlaSharding(self):
    """create_zeros_slot with copy_xla_sharding mirrors the primary's sharding."""
    # slot_creator is used only in optimizer V1.
    with ops.Graph().as_default(), self.cached_session():
        primary = variables.Variable([1.0, 2.5], name="var")
        primary = xla_sharding.mesh_split(
            primary, np.array([0, 1]), [0], use_sharding_op=False)
        with ops.control_dependencies(None):
            slot_var = slot_creator.create_zeros_slot(
                primary,
                name="slot",
                dtype=dtypes.float64,
                copy_xla_sharding=True)
            # The slot must end up with the identical serialized sharding,
            # even though its dtype differs from the primary's.
            self.assertEqual(
                xla_sharding.get_tensor_sharding(primary),
                xla_sharding.get_tensor_sharding(slot_var))
def copy_helper(tensor):
    """Shard a copy of `tensor`, then copy that sharding onto a second copy.

    `self` is a free variable from the enclosing test method.
    Returns the destination tensor carrying the copied sharding.
    """
    # Fix: the original assigned `array_ops.identity(tensor)` to tensor_src
    # and then immediately overwrote it with `xla_sharding.split(tensor, ...)`
    # (a dead store), which also annotated the caller's `tensor` op directly.
    # Splitting the identity copy keeps `tensor` itself unannotated and
    # matches the evident intent of the dead line.
    tensor_src = xla_sharding.split(array_ops.identity(tensor), 2, 3)
    sharding_src = xla_sharding.get_tensor_sharding(tensor_src)
    shape_src = xla_sharding.get_sharding_tile_shape(sharding_src)
    # 3-way split on dim 2 of a rank-3 tensor -> tile assignment [1, 1, 3].
    self.assertEqual([1, 1, 3], shape_src)

    # A fresh identity copy starts with no sharding of its own.
    tensor_dest = array_ops.identity(tensor)
    self.assertIsNone(xla_sharding.get_tensor_sharding(tensor_dest))

    # copy_sharding transfers the source annotation onto the destination op.
    xla_sharding.copy_sharding(tensor_src, tensor_dest)
    sharding_dest = xla_sharding.get_tensor_sharding(tensor_dest)
    shape_dest = xla_sharding.get_sharding_tile_shape(sharding_dest)
    self.assertEqual([1, 1, 3], shape_dest)
    return tensor_dest
def testCreateSlotWithoutXlaSharding(self):
    """No sharding is copied when primary and slot have different ranks."""
    # slot_creator is used only in optimizer V1. The primary is rank-1 while
    # the slot value is a rank-0 constant, so the SPMD annotation must NOT be
    # propagated even though copy_xla_sharding=True.
    with ops.Graph().as_default(), self.cached_session():
        primary = variables.Variable([1.0, 2.5], name="var")
        primary = xla_sharding.mesh_split(
            primary, np.array([0, 1]), [0], use_sharding_op=False)
        with ops.control_dependencies(None):
            scalar_value = constant_op.constant(10, name="const")
            slot_var = slot_creator.create_slot(
                primary, scalar_value, name="slot", copy_xla_sharding=True)
            # The slot stays unsharded and therefore differs from the primary.
            self.assertIsNone(xla_sharding.get_tensor_sharding(slot_var))
            self.assertNotEqual(
                xla_sharding.get_tensor_sharding(primary),
                xla_sharding.get_tensor_sharding(slot_var))
def testXlaSharding(self):
    """Adam slots inherit variable sharding; beta accumulators stay unsharded."""
    dtype = dtypes.float32
    with self.session(graph=ops.Graph()):
        # Numpy reference values used to initialize the graph variables.
        np_dtype = dtype.as_numpy_dtype
        var0_np = np.array([1.0, 2.0], dtype=np_dtype)
        var1_np = np.array([3.0, 4.0], dtype=np_dtype)
        grads0_np = np.array([0.1, 0.1], dtype=np_dtype)
        grads1_np = np.array([0.01, 0.01], dtype=np_dtype)

        var0 = resource_variable_ops.ResourceVariable(var0_np, name="var0")
        var1 = resource_variable_ops.ResourceVariable(var1_np, name="var1")
        # Annotate both variables with the same SPMD mesh split.
        var0, var1 = [
            xla_sharding.mesh_split(
                v, np.array([0, 1]), [0], use_sharding_op=False)
            for v in (var0, var1)
        ]
        grads0 = constant_op.constant(grads0_np)
        grads1 = constant_op.constant(grads1_np)

        # Callable learning rate exercises the lazy-LR code path.
        opt = adam.AdamOptimizer(learning_rate=lambda: 0.001)
        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
        self.evaluate(variables.global_variables_initializer())
        self.evaluate(update)

        # The scalar beta accumulators are not sharded.
        beta1_power, beta2_power = opt._get_beta_accumulators()
        for accumulator in (beta1_power, beta2_power):
            self.assertIsNone(xla_sharding.get_tensor_sharding(accumulator))

        # Variables and their "m"/"v" slots are all sharded.
        for v in (var0, var1):
            self.assertIsNotNone(xla_sharding.get_tensor_sharding(v))
            for slot_name in ("m", "v"):
                slot = opt.get_slot(v, slot_name)
                self.assertIsNotNone(xla_sharding.get_tensor_sharding(slot))