Example #1
    def testCreateSlotWithCustomSplitXlaSharding(self):
        # slot_creator is used only in optimizer V1.
        # We insert our own custom split XLA sharding that overrides the SPMD
        # sharding copied over by the slot_creator.
        with ops.Graph().as_default(), self.cached_session():
            v = variables.Variable([1.0, 2.5, 10.0, 15.1], name="var")
            v = xla_sharding.mesh_split(v,
                                        np.array([0, 1]), [0],
                                        use_sharding_op=False)
            with ops.control_dependencies(None):
                slot = slot_creator.create_zeros_slot(v,
                                                      name="slot",
                                                      dtype=dtypes.float64,
                                                      copy_xla_sharding=True)
                slot = xla_sharding.split(slot,
                                          split_dimension=0,
                                          num_devices=4,
                                          use_sharding_op=False)

            self.assertNotEqual(xla_sharding.get_tensor_sharding(v),
                                xla_sharding.get_tensor_sharding(slot))

            slot_sharding = xla_sharding.get_tensor_sharding(slot)
            slot_proto = xla_data_pb2.OpSharding()
            slot_proto.ParseFromString(slot_sharding)
            self.assertEqual(
                slot_proto,
                xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.OTHER,
                                        tile_assignment_dimensions=[4],
                                        tile_assignment_devices=range(4)))
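The snippets in these examples appear to be excerpted from TensorFlow's Python test suite and omit their import preamble. A plausible set of imports, assuming the usual layout of the TensorFlow source tree, would be:

    # Hedged sketch of the imports the examples above and below rely on.
    import numpy as np

    from tensorflow.compiler.xla import xla_data_pb2
    from tensorflow.python.compiler.xla.experimental import xla_sharding
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import dtypes
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import array_ops
    from tensorflow.python.ops import resource_variable_ops
    from tensorflow.python.ops import variables
    from tensorflow.python.training import adam
    from tensorflow.python.training import moving_averages
    from tensorflow.python.training import slot_creator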
Example #2
    def testCreateSlotWithCustomReplicatedXlaSharding(self):
        # slot_creator is used only in optimizer V1.
        # We insert our own custom replicated XLA sharding that overrides the SPMD
        # sharding copied over by the slot_creator.
        with ops.Graph().as_default(), self.cached_session():
            v = variables.Variable([1.0, 2.5], name="var")
            v = xla_sharding.mesh_split(v,
                                        np.array([0, 1]), [0],
                                        use_sharding_op=False)
            with ops.control_dependencies(None):
                slot = slot_creator.create_zeros_slot(v,
                                                      name="slot",
                                                      dtype=dtypes.float64,
                                                      copy_xla_sharding=True)
                slot = xla_sharding.replicate(slot, use_sharding_op=False)

            self.assertNotEqual(xla_sharding.get_tensor_sharding(v),
                                xla_sharding.get_tensor_sharding(slot))

            slot_sharding = xla_sharding.get_tensor_sharding(slot)
            slot_proto = xla_data_pb2.OpSharding()
            slot_proto.ParseFromString(slot_sharding)
            self.assertEqual(
                slot_proto,
                xla_data_pb2.OpSharding(
                    type=xla_data_pb2.OpSharding.REPLICATED))
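The examples above compare shardings either as raw serialized strings or as decoded OpSharding protos. When debugging a mismatch it can help to decode the attribute in one place; below is a small hypothetical helper (not part of the test files, the name is ours) that uses only the APIs already shown:

    def _describe_sharding(tensor):
        # Returns the decoded OpSharding proto attached to `tensor`, or None
        # if the tensor carries no sharding annotation.
        serialized = xla_sharding.get_tensor_sharding(tensor)
        if serialized is None:
            return None
        proto = xla_data_pb2.OpSharding()
        proto.ParseFromString(serialized)
        return proto

    # For the slot above, _describe_sharding(slot).type would be
    # xla_data_pb2.OpSharding.REPLICATED, while the mesh-split variable v
    # carries a tiled (type OTHER) sharding.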
Example #3
    def split_helper(tensor):
        self.assertIsNone(xla_sharding.get_tensor_sharding(tensor))
        split_tensor = xla_sharding.split(tensor, 2, 3)
        self.assertIsInstance(split_tensor, ops.Tensor)
        split_sharding = xla_sharding.get_tensor_sharding(split_tensor)
        split_shape = xla_sharding.get_sharding_tile_shape(split_sharding)
        expected_shape = [1, 1, 3]
        self.assertEqual(expected_shape, split_shape)
        return split_tensor
Example #4
    def replicate_helper(tensor):
        # The replicated tensor is built from a fresh ones tensor; the
        # `tensor` argument is only checked for having no sharding attached.
        replicated_tensor = xla_sharding.replicate(
            array_ops.ones([4, 5, 6], dtype=dtypes.float32))
        self.assertIsNone(xla_sharding.get_tensor_sharding(tensor))
        replicated_sharding = xla_sharding.get_tensor_sharding(replicated_tensor)
        self.assertIsNotNone(replicated_sharding)
        self.assertIsNone(
            xla_sharding.get_sharding_tile_shape(replicated_sharding))
        return replicated_tensor
Example #5
    def tile_helper(tensor):
        self.assertIsNone(xla_sharding.get_tensor_sharding(tensor))
        tiled_tensor = xla_sharding.tile(tensor, np.array([2, 1, 6]))
        self.assertIsInstance(tiled_tensor, ops.Tensor)
        tiled_sharding = xla_sharding.get_tensor_sharding(tiled_tensor)
        tile_shape = xla_sharding.get_sharding_tile_shape(tiled_sharding)
        # The tile assignment np.array([2, 1, 6]) is a rank-1 array with three
        # entries, so the sharding's tile shape is [3].
        expected_shape = [3]
        self.assertEqual(expected_shape, tile_shape)
        return tiled_tensor
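In the original xla_sharding tests, the split/replicate/tile helpers above are nested inside test methods (which is also where `self` comes from) and are typically wrapped in a tf.function before being called on a small constant input. A rough sketch of such a call site, with the wrapper and the 4x5x6 input chosen by us for illustration:

    # Hypothetical driver, assumed to run inside the enclosing test method;
    # the import would normally live at the top of the file.
    from tensorflow.python.eager import def_function

    split_fn = def_function.function(split_helper)
    in_tensor = array_ops.ones([4, 5, 6], dtype=dtypes.float32)
    # The result is annotated as split along dimension 2 across 3 devices.
    out_tensor = split_fn(in_tensor)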
Example #6
    def testCopyXlaSharding(self):
        ema = moving_averages.ExponentialMovingAverage(0.25, name="foo_avg")
        v = variables.Variable(_Repeat(10.0, 2), name="v")
        self.assertIsNone(xla_sharding.get_tensor_sharding(v))
        v = xla_sharding.mesh_split(v,
                                    np.array([0, 1]), [0],
                                    use_sharding_op=False)
        self.assertIsNotNone(xla_sharding.get_tensor_sharding(v))
        self.evaluate(variables.global_variables_initializer())
        ema.apply([v])
        avg = ema.average(v)
        self.assertEqual(xla_sharding.get_tensor_sharding(v),
                         xla_sharding.get_tensor_sharding(avg))
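The moving-averages example above references _Repeat, a small helper defined elsewhere in the same test file; a minimal equivalent looks like:

    def _Repeat(value, dim):
        # Returns `value` unchanged when dim == 1, otherwise a list of `dim` copies,
        # so _Repeat(10.0, 2) builds the initial value [10.0, 10.0].
        if dim == 1:
            return value
        return [value] * dim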
Example #7
    def testCreateSlotFromVariableCopyXlaSharding(self):
        # slot_creator is used only in optimizer V1.
        with ops.Graph().as_default(), self.cached_session():
            v = variables.Variable([1.0, 2.5], name="var")
            v = xla_sharding.mesh_split(v,
                                        np.array([0, 1]), [0],
                                        use_sharding_op=False)
            slot = slot_creator.create_slot(v,
                                            v.initialized_value(),
                                            name="slot",
                                            copy_xla_sharding=True)
            self.assertEqual(xla_sharding.get_tensor_sharding(v),
                             xla_sharding.get_tensor_sharding(slot))
Example #8
    def testCreateZerosSlotFromVariableCopyXlaSharding(self):
        # slot_creator is used only in optimizer V1.
        with ops.Graph().as_default(), self.cached_session():
            v = variables.Variable([1.0, 2.5], name="var")
            v = xla_sharding.mesh_split(v,
                                        np.array([0, 1]), [0],
                                        use_sharding_op=False)
            with ops.control_dependencies(None):
                slot = slot_creator.create_zeros_slot(v,
                                                      name="slot",
                                                      dtype=dtypes.float64,
                                                      copy_xla_sharding=True)
            self.assertEqual(xla_sharding.get_tensor_sharding(v),
                             xla_sharding.get_tensor_sharding(slot))
Example #9
    def copy_helper(tensor):
        tensor_src = array_ops.identity(tensor)
        tensor_src = xla_sharding.split(tensor, 2, 3)
        sharding_src = xla_sharding.get_tensor_sharding(tensor_src)
        shape_src = xla_sharding.get_sharding_tile_shape(sharding_src)
        self.assertEqual([1, 1, 3], shape_src)

        tensor_dest = array_ops.identity(tensor)
        self.assertIsNone(xla_sharding.get_tensor_sharding(tensor_dest))

        xla_sharding.copy_sharding(tensor_src, tensor_dest)
        sharding_dest = xla_sharding.get_tensor_sharding(tensor_dest)
        shape_dest = xla_sharding.get_sharding_tile_shape(sharding_dest)
        self.assertEqual([1, 1, 3], shape_dest)
        return tensor_dest
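Like the split/replicate/tile helpers earlier, copy_helper is nested inside a test method in the original file. The return value of copy_sharding is not captured here: with use_sharding_op=False the sharding attribute appears to be written onto the op producing tensor_dest in place, which is what the following get_tensor_sharding call reads back. A call-site sketch, mirroring the hypothetical driver shown after the tile helper:

    # Hypothetical driver, assumed to run inside the enclosing test method.
    copy_fn = def_function.function(copy_helper)
    dest = copy_fn(array_ops.ones([4, 5, 6], dtype=dtypes.float32))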
Example #10
    def testCreateSlotWithoutXlaSharding(self):
        # slot_creator is used only in optimizer V1.
        # The SPMD sharding annotations should not be copied since the primary
        # variable and slot variable have different ranks.
        with ops.Graph().as_default(), self.cached_session():
            v = variables.Variable([1.0, 2.5], name="var")
            v = xla_sharding.mesh_split(v,
                                        np.array([0, 1]), [0],
                                        use_sharding_op=False)
            with ops.control_dependencies(None):
                slot = slot_creator.create_slot(v,
                                                constant_op.constant(
                                                    10, name="const"),
                                                name="slot",
                                                copy_xla_sharding=True)
            self.assertIsNone(xla_sharding.get_tensor_sharding(slot))
            self.assertNotEqual(xla_sharding.get_tensor_sharding(v),
                                xla_sharding.get_tensor_sharding(slot))
Example #11
    def testXlaSharding(self):
        dtype = dtypes.float32
        with self.session(graph=ops.Graph()):
            # Initialize variables for numpy implementation.
            var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
            grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
            var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
            grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)

            var0 = resource_variable_ops.ResourceVariable(var0_np, name="var0")
            var1 = resource_variable_ops.ResourceVariable(var1_np, name="var1")
            var0, var1 = [
                xla_sharding.mesh_split(v,
                                        np.array([0, 1]), [0],
                                        use_sharding_op=False)
                for v in (var0, var1)
            ]
            grads0 = constant_op.constant(grads0_np)
            grads1 = constant_op.constant(grads1_np)

            learning_rate = lambda: 0.001

            opt = adam.AdamOptimizer(learning_rate=learning_rate)
            update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))

            self.evaluate(variables.global_variables_initializer())
            self.evaluate(update)
            # The beta accumulators are not sharded.
            beta1_power, beta2_power = opt._get_beta_accumulators()
            self.assertIsNone(xla_sharding.get_tensor_sharding(beta1_power))
            self.assertIsNone(xla_sharding.get_tensor_sharding(beta2_power))

            # Variables and slots are sharded.
            for v in (var0, var1):
                self.assertIsNotNone(xla_sharding.get_tensor_sharding(v))
                for slot_name in ("m", "v"):
                    slot = opt.get_slot(v, slot_name)
                    self.assertIsNotNone(
                        xla_sharding.get_tensor_sharding(slot))