Example No. 1
def MeshSplit(x,
              device_mesh,
              tensor_split_dims_mapping,
              use_sharding_op=True,
              unspecified_dims=None):
    """Wrapper of xla_sharding.mesh_split()."""
    if (not py_utils_flags.use_tpu() or tensor_split_dims_mapping is None
            or device_mesh is None or device_mesh.size <= 1):
        return x
    # Apply the prefix in the context.
    tensor_split_dims_mapping = (_MESH_SPLIT_DIM_PREFIXES.stack +
                                 tensor_split_dims_mapping)
    num_tiles = np.prod(
        [device_mesh.shape[i] for i in tensor_split_dims_mapping if i >= 0])
    if num_tiles <= 1:
        return x
    if _MANUAL_MESH_DIMS.stack or unspecified_dims:
        return xla_sharding.mesh_split(
            x,
            device_mesh,
            tensor_split_dims_mapping,
            use_sharding_op=use_sharding_op,
            manual_mesh_dims=_MANUAL_MESH_DIMS.stack,
            unspecified_dims=unspecified_dims)
    # Do not include manual_mesh_dims or unspecified_dims to support legacy TF
    # versions.
    return xla_sharding.mesh_split(x,
                                   device_mesh,
                                   tensor_split_dims_mapping,
                                   use_sharding_op=use_sharding_op)
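A minimal usage sketch for the wrapper above, assuming a Lingvo TPU context where py_utils_flags.use_tpu() returns True and the _MESH_SPLIT_DIM_PREFIXES stack is empty; the mesh and tensor shapes are made up for illustration. Outside that context, MeshSplit simply returns its input unchanged.

import numpy as np
import tensorflow as tf

# Hypothetical 2x4 TPU mesh: axis 0 for data parallelism, axis 1 for model
# parallelism.
device_mesh = np.arange(8).reshape([2, 4])
# Shard the batch dimension of a [batch, hidden] activation over mesh axis 0;
# a mapping entry of -1 leaves that tensor dimension unsplit.
act = tf.zeros([16, 1024])
act = MeshSplit(act, device_mesh, tensor_split_dims_mapping=[0, -1])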
Example No. 2
 def ApplyToVariable(self, variable: tf.Variable) -> tf.Variable:
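     """Attaches this spec's sharding directly to `variable` (no separate sharding op)."""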
     if self.is_replicated:
         return xla_sharding.replicate(variable, use_sharding_op=False)
     return xla_sharding.mesh_split(variable,
                                    self.device_mesh,
                                    self.split_dims_mapping,
                                    use_sharding_op=False)
Example No. 3
    def testCreateSlotWithCustomSplitXlaSharding(self):
        # slot_creator is used only in optimizer V1.
        # We insert our own custom split XLA sharding that overrides the SPMD
        # sharding copied over by the slot_creator.
        with ops.Graph().as_default(), self.cached_session():
            v = variables.Variable([1.0, 2.5, 10.0, 15.1], name="var")
            v = xla_sharding.mesh_split(v,
                                        np.array([0, 1]), [0],
                                        use_sharding_op=False)
            with ops.control_dependencies(None):
                slot = slot_creator.create_zeros_slot(v,
                                                      name="slot",
                                                      dtype=dtypes.float64,
                                                      copy_xla_sharding=True)
                slot = xla_sharding.split(slot,
                                          split_dimension=0,
                                          num_devices=4,
                                          use_sharding_op=False)

            self.assertNotEqual(xla_sharding.get_tensor_sharding(v),
                                xla_sharding.get_tensor_sharding(slot))

            slot_sharding = xla_sharding.get_tensor_sharding(slot)
            slot_proto = xla_data_pb2.OpSharding()
            slot_proto.ParseFromString(slot_sharding)
            self.assertEqual(
                slot_proto,
                xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.OTHER,
                                        tile_assignment_dimensions=[4],
                                        tile_assignment_devices=range(4)))
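The decode-and-compare pattern in this test generalizes; below is a small helper sketch for inspecting the sharding attached to any annotated tensor. The imports are the TF-internal modules these examples already use, and their exact paths may differ across TensorFlow versions.

# TF-internal modules; import paths vary by TensorFlow version.
from tensorflow.compiler.xla import xla_data_pb2
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding


def describe_sharding(tensor):
    """Parses and returns the OpSharding proto attached to `tensor`, or None."""
    serialized = xla_sharding.get_tensor_sharding(tensor)
    if serialized is None:
        return None
    proto = xla_data_pb2.OpSharding()
    proto.ParseFromString(serialized)
    return proto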
Example No. 4
    def testCreateSlotWithCustomReplicatedXlaSharding(self):
        # slot_creator is used only in optimizer V1.
        # We insert our own custom replicated XLA sharding that overrides the SPMD
        # sharding copied over by the slot_creator.
        with ops.Graph().as_default(), self.cached_session():
            v = variables.Variable([1.0, 2.5], name="var")
            v = xla_sharding.mesh_split(v,
                                        np.array([0, 1]), [0],
                                        use_sharding_op=False)
            with ops.control_dependencies(None):
                slot = slot_creator.create_zeros_slot(v,
                                                      name="slot",
                                                      dtype=dtypes.float64,
                                                      copy_xla_sharding=True)
                slot = xla_sharding.replicate(slot, use_sharding_op=False)

            self.assertNotEqual(xla_sharding.get_tensor_sharding(v),
                                xla_sharding.get_tensor_sharding(slot))

            slot_sharding = xla_sharding.get_tensor_sharding(slot)
            slot_proto = xla_data_pb2.OpSharding()
            slot_proto.ParseFromString(slot_sharding)
            self.assertEqual(
                slot_proto,
                xla_data_pb2.OpSharding(
                    type=xla_data_pb2.OpSharding.REPLICATED))
Example No. 5
 def ApplyToTensor(self,
                   tensor: tf.Tensor,
                   use_sharding_op: bool = True) -> tf.Tensor:
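     """Applies this spec's sharding to `tensor`, via a sharding op by default."""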
     if self.is_replicated:
         return xla_sharding.replicate(tensor,
                                       use_sharding_op=use_sharding_op)
     return xla_sharding.mesh_split(tensor,
                                    self.device_mesh,
                                    self.split_dims_mapping,
                                    use_sharding_op=use_sharding_op)
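For comparison, a standalone sketch of the two branches above (replicated vs. mesh-split), run in graph mode so the annotation can be read back with get_tensor_sharding. The device mesh and tensor shape are made up, and the TF-internal import path may differ across TensorFlow versions.

import numpy as np
import tensorflow as tf
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding

with tf.Graph().as_default():
    t = tf.zeros([8, 4])
    # Replicated: every device keeps a full copy of the tensor.
    t_rep = xla_sharding.replicate(t, use_sharding_op=True)
    # Mesh split: dim 0 sharded over axis 1 of a 1x2 mesh; dim 1 is unsplit.
    t_split = xla_sharding.mesh_split(t, np.arange(2).reshape([1, 2]), [1, -1],
                                      use_sharding_op=True)
    # The two annotations differ: REPLICATED vs. a tiled (OTHER) sharding.
    assert (xla_sharding.get_tensor_sharding(t_rep) !=
            xla_sharding.get_tensor_sharding(t_split))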
Example No. 6
def MeshSplit(x, device_mesh, tensor_split_dims_mapping, use_sharding_op=True):
    """Wrapper of xla_sharding.mesh_split()."""
    if (not py_utils_flags.use_tpu() or tensor_split_dims_mapping is None
            or device_mesh is None or device_mesh.size <= 1):
        return x
    num_tiles = np.prod(
        [device_mesh.shape[i] for i in tensor_split_dims_mapping if i >= 0])
    if num_tiles <= 1:
        return x
    return xla_sharding.mesh_split(x,
                                   device_mesh,
                                   tensor_split_dims_mapping,
                                   use_sharding_op=use_sharding_op)
Example No. 7
 def testCopyXlaSharding(self):
     ema = moving_averages.ExponentialMovingAverage(0.25, name="foo_avg")
     v = variables.Variable(_Repeat(10.0, 2), name="v")
     self.assertIsNone(xla_sharding.get_tensor_sharding(v))
     v = xla_sharding.mesh_split(v,
                                 np.array([0, 1]), [0],
                                 use_sharding_op=False)
     self.assertIsNotNone(xla_sharding.get_tensor_sharding(v))
     self.evaluate(variables.global_variables_initializer())
     ema.apply([v])
     avg = ema.average(v)
     self.assertEqual(xla_sharding.get_tensor_sharding(v),
                      xla_sharding.get_tensor_sharding(avg))
Example No. 8
 def testCreateSlotFromVariableCopyXlaSharding(self):
     # slot_creator is used only in optimizer V1.
     with ops.Graph().as_default(), self.cached_session():
         v = variables.Variable([1.0, 2.5], name="var")
         v = xla_sharding.mesh_split(v,
                                     np.array([0, 1]), [0],
                                     use_sharding_op=False)
         slot = slot_creator.create_slot(v,
                                         v.initialized_value(),
                                         name="slot",
                                         copy_xla_sharding=True)
         self.assertEqual(xla_sharding.get_tensor_sharding(v),
                          xla_sharding.get_tensor_sharding(slot))
Example No. 9
 def testCreateZerosSlotFromVariableCopyXlaSharding(self):
     # slot_creator is used only in optimizer V1.
     with ops.Graph().as_default(), self.cached_session():
         v = variables.Variable([1.0, 2.5], name="var")
         v = xla_sharding.mesh_split(v,
                                     np.array([0, 1]), [0],
                                     use_sharding_op=False)
         with ops.control_dependencies(None):
             slot = slot_creator.create_zeros_slot(v,
                                                   name="slot",
                                                   dtype=dtypes.float64,
                                                   copy_xla_sharding=True)
         self.assertEqual(xla_sharding.get_tensor_sharding(v),
                          xla_sharding.get_tensor_sharding(slot))
Example No. 10
def MeshSplit(x, device_mesh, tensor_split_dims_mapping, use_sharding_op=True):
    """Wrapper of xla_sharding.mesh_split()."""
    if (not py_utils_flags.use_tpu() or tensor_split_dims_mapping is None
            or device_mesh is None or device_mesh.size <= 1):
        return x
    # Apply the prefix in the context.
    tensor_split_dims_mapping = (_MESH_SPLIT_DIM_PREFIXES.stack +
                                 tensor_split_dims_mapping)
    num_tiles = np.prod(
        [device_mesh.shape[i] for i in tensor_split_dims_mapping if i >= 0])
    if num_tiles <= 1:
        return x
    return xla_sharding.mesh_split(x,
                                   device_mesh,
                                   tensor_split_dims_mapping,
                                   use_sharding_op=use_sharding_op)
Example No. 11
 def testCreateSlotWithoutXlaSharding(self):
     # slot_creator is used only in optimizer V1.
     # The SPMD sharding annotations should not be copied since the primary
     # variable and slot variable have different ranks.
     with ops.Graph().as_default(), self.cached_session():
         v = variables.Variable([1.0, 2.5], name="var")
         v = xla_sharding.mesh_split(v,
                                     np.array([0, 1]), [0],
                                     use_sharding_op=False)
         with ops.control_dependencies(None):
             slot = slot_creator.create_slot(v,
                                             constant_op.constant(
                                                 10, name="const"),
                                             name="slot",
                                             copy_xla_sharding=True)
         self.assertIsNone(xla_sharding.get_tensor_sharding(slot))
         self.assertNotEqual(xla_sharding.get_tensor_sharding(v),
                             xla_sharding.get_tensor_sharding(slot))
Example No. 12
    def testXlaSharding(self):
        dtype = dtypes.float32
        with self.session(graph=ops.Graph()):
            # Initialize variables for numpy implementation.
            var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
            grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
            var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
            grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)

            var0 = resource_variable_ops.ResourceVariable(var0_np, name="var0")
            var1 = resource_variable_ops.ResourceVariable(var1_np, name="var1")
            var0, var1 = [
                xla_sharding.mesh_split(v,
                                        np.array([0, 1]), [0],
                                        use_sharding_op=False)
                for v in (var0, var1)
            ]
            grads0 = constant_op.constant(grads0_np)
            grads1 = constant_op.constant(grads1_np)

            learning_rate = lambda: 0.001

            opt = adam.AdamOptimizer(learning_rate=learning_rate)
            update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))

            self.evaluate(variables.global_variables_initializer())
            self.evaluate(update)
            # The beta accumulators are not sharded.
            beta1_power, beta2_power = opt._get_beta_accumulators()
            self.assertIsNone(xla_sharding.get_tensor_sharding(beta1_power))
            self.assertIsNone(xla_sharding.get_tensor_sharding(beta2_power))

            # Variables and slots are sharded.
            for v in (var0, var1):
                self.assertIsNotNone(xla_sharding.get_tensor_sharding(v))
                for slot_name in ("m", "v"):
                    slot = opt.get_slot(v, slot_name)
                    self.assertIsNotNone(
                        xla_sharding.get_tensor_sharding(slot))