예제 #1
0
파일: ops_test.py 프로젝트: trantorznh/mesh
  def testMeshImpl(self):
    """Checks MeshImpl properties and its tensor-dimension to mesh-axis rules."""
    mesh_shape = mtf.Shape(
        [mtf.Dimension("batch", 4), mtf.Dimension("model", 8)])
    rules = mtf.LayoutRules(
        [("batch", "batch"), ("d_ff", "model"), ("heads", "model")])
    impl = mtf.MeshImpl(shape=mesh_shape, layout_rules=rules)

    # Properties exposed by the implementation for this mesh shape and rules.
    self.assertEqual(impl.shape, mesh_shape)
    self.assertLen(mesh_shape, impl.ndims)
    self.assertEqual(impl.layout_rules, rules)
    self.assertEqual(impl.size, mesh_shape.size)
    self.assertTrue(impl.supports_control_dependencies)

    batch_dim = mtf.Dimension("batch", 128)
    length_dim = mtf.Dimension("length", 500)
    d_ff_dim = mtf.Dimension("d_ff", 2048)
    heads_dim = mtf.Dimension("heads", 8)
    # "batch" is split over mesh axis 0; "d_ff" and "heads" share axis 1.
    for tensor_dim, mesh_axis in ((batch_dim, 0), (d_ff_dim, 1),
                                  (heads_dim, 1)):
      self.assertEqual(impl.tensor_dimension_to_mesh_axis(tensor_dim),
                       mesh_axis)
    # "length" has no rule, so it is not split (None) in the layout.
    expected_layout = mtf.TensorLayout([0, None, 1])
    actual_layout = impl.tensor_layout(
        mtf.Shape([batch_dim, length_dim, d_ff_dim]))
    self.assertEqual(actual_layout, expected_layout)
예제 #2
0
파일: ops_test.py 프로젝트: tspannhw/mesh
 def testConvertToLayoutRules(self, inputs):
     """Checks that `inputs` converts to rules d_ff->model, heads->model.

     NOTE(review): `inputs` is expected to come from a @parameterized
     decorator that is not part of this excerpt.
     """
     converted = mtf.convert_to_layout_rules(inputs)
     reference = mtf.LayoutRules([("d_ff", "model"), ("heads", "model")])
     # Compare the underlying (dimension, mesh axis) pairs directly.
     self.assertEqual(converted._pairs, reference._pairs)
예제 #3
0
파일: ops_test.py 프로젝트: tspannhw/mesh
class MeshTensorFlowTest(parameterized.TestCase, tf.test.TestCase):
    """Tests for mtf conversion helpers, layouts, graphs, meshes and lowering."""

    @parameterized.parameters(
        (mtf.Dimension("x", 5), ),
        (("x", 5), ),
    )
    def testConvertToDimension(self, inputs):
        """Both a Dimension and a (name, size) tuple convert to x:5."""
        dimension = mtf.convert_to_dimension(inputs)
        self.assertEqual(dimension.name, "x")
        self.assertEqual(dimension.size, 5)

    def testConvertToDimensionGenericInputs(self):
        """None is passed through; a bare integer raises TypeError."""
        self.assertIsNone(mtf.convert_to_dimension(None))
        with self.assertRaises(TypeError):
            mtf.convert_to_dimension(5)

    @parameterized.parameters(
        (mtf.Shape([mtf.Dimension("x", 4),
                    mtf.Dimension("y", 8)]), ),
        ("x:4;y:8", ),
        ("x:4.y:8", ),
        ("x:4 y:8", ),
        ("x:4,y:8", ),
    )
    def testConvertToShape(self, inputs):
        """Shapes and "name:size" strings convert to the same Shape.

        The string form is accepted with ';', '.', ' ' or ',' separators.
        """
        shape = mtf.convert_to_shape(inputs)
        self.assertEqual(
            shape, mtf.Shape([mtf.Dimension("x", 4),
                              mtf.Dimension("y", 8)]))

    def testConvertToShapeGenericInputs(self):
        """An empty list gives no dims, None passes through, bad text raises."""
        shape = mtf.convert_to_shape([])
        self.assertEqual(shape.dims, [])
        self.assertIsNone(mtf.convert_to_shape(None))
        with self.assertRaises(ValueError):
            # "x;4" is not in the "name:size" form used by the valid inputs.
            mtf.convert_to_shape("x;4")

    @parameterized.parameters(
        (mtf.LayoutRules([("d_ff", "model"), ("heads", "model")]), ),
        ("d_ff:model;heads:model", ),
        ("d_ff:model.heads:model", ),
        ("d_ff:model heads:model", ),
        ("d_ff:model,heads:model", ),
        ([("d_ff", "model"), ("heads", "model")], ),
    )
    def testConvertToLayoutRules(self, inputs):
        """Every supported spelling yields rules d_ff->model, heads->model."""
        layout_rules = mtf.convert_to_layout_rules(inputs)
        # Compared via the private `_pairs`; presumably LayoutRules itself
        # does not define value equality -- TODO confirm.
        self.assertEqual(
            layout_rules._pairs,
            mtf.LayoutRules([("d_ff", "model"), ("heads", "model")])._pairs)

    def testConvertToLayoutRulesGenericInputs(self):
        """A rule string without "dim:axis" pairs raises ValueError."""
        with self.assertRaises(ValueError):
            mtf.convert_to_layout_rules("d_ff;heads")

    def testTensorLayout(self):
        """Checks axis mapping and the fully-replicated flag of TensorLayout."""
        tensor_layout = mtf.TensorLayout([0, 2, 1])
        # The argument appears to be a mesh rank N: the result lists, for each
        # of the first N mesh axes, the tensor axis laid out on it.
        self.assertEqual(tensor_layout.mesh_axis_to_tensor_axis(0), ())
        self.assertEqual(tensor_layout.mesh_axis_to_tensor_axis(1), (0, ))
        self.assertEqual(tensor_layout.mesh_axis_to_tensor_axis(2), (0, 2))
        # Any axis mapped onto the mesh means the tensor is not replicated.
        tensor_layout = mtf.TensorLayout([None, 0])
        self.assertFalse(tensor_layout.is_fully_replicated)
        tensor_layout = mtf.TensorLayout([None, None, None])
        self.assertTrue(tensor_layout.is_fully_replicated)

    def _assertGraphCounts(self, graph, operations, tensors, trainable,
                           variables):
        """Asserts how many operations, tensors and variables `graph` holds."""
        self.assertLen(graph.operations, operations)
        self.assertLen(graph.tensors, tensors)
        self.assertLen(graph.trainable_variables, trainable)
        self.assertLen(graph.all_variables, variables)

    def testGraph(self):
        """Graph bookkeeping grows as tensors are imported and variables made."""
        graph = mtf.Graph()
        self._assertGraphCounts(
            graph, operations=0, tensors=0, trainable=0, variables=0)
        mesh = mtf.Mesh(graph, "mesh_test")
        _ = mtf.import_tf_tensor(mesh,
                                 tf_tensor=tf.constant(0.),
                                 shape=mtf.Shape([]))
        # Importing a TF tensor adds an operation and tensor, no variables.
        self._assertGraphCounts(
            graph, operations=1, tensors=1, trainable=0, variables=0)
        _ = mtf.get_variable(mesh, "variable_0", mtf.Shape([]), trainable=True)
        self._assertGraphCounts(
            graph, operations=2, tensors=2, trainable=1, variables=1)
        _ = mtf.get_variable(mesh,
                             "variable_1",
                             mtf.Shape([]),
                             trainable=False)
        # A non-trainable variable is counted only in all_variables.
        self._assertGraphCounts(
            graph, operations=3, tensors=3, trainable=1, variables=2)

    # NOTE(review): `tf.contrib` exists only in TensorFlow 1.x; on TF 2.x this
    # decorator has to come from `test_util` instead.
    @tf.contrib.eager.run_test_in_graph_and_eager_modes()
    def testLowering(self):
        """Exporting an imported scalar through a Lowering keeps its value."""
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        inputs = tf.constant(0.)
        mtf_inputs = mtf.import_tf_tensor(mesh,
                                          tf_tensor=inputs,
                                          shape=mtf.Shape([]))
        # Rank-0 mesh with a single device entry ("" -- presumably the
        # default TensorFlow placement).
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        outputs = lowering.export_to_tf_tensor(mtf_inputs)
        inputs_value, outputs_value = self.evaluate([inputs, outputs])
        self.assertEqual(inputs_value, outputs_value)

        # Check that methods run without error.
        _ = lowering.copy_masters_to_slices()
        _ = lowering.copy_slices_to_masters()

    def testMesh(self):
        """A Mesh exposes the Graph it was created in."""
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        self.assertEqual(mesh.graph, graph)

    def testMeshImpl(self):
        """Checks MeshImpl properties and dimension-to-mesh-axis layout rules."""
        shape = mtf.Shape(
            [mtf.Dimension("batch", 4),
             mtf.Dimension("model", 8)])
        layout_rules = mtf.LayoutRules([("batch", "batch"), ("d_ff", "model"),
                                        ("heads", "model")])
        mesh_impl = mtf.MeshImpl(shape=shape, layout_rules=layout_rules)
        self.assertEqual(mesh_impl.shape, shape)
        self.assertEqual(mesh_impl.ndims, len(shape))
        self.assertEqual(mesh_impl.layout_rules, layout_rules)
        self.assertEqual(mesh_impl.size, shape.size)
        self.assertTrue(mesh_impl.supports_control_dependencies)

        batch = mtf.Dimension("batch", 128)
        length = mtf.Dimension("length", 500)
        d_ff = mtf.Dimension("d_ff", 2048)
        heads = mtf.Dimension("heads", 8)
        self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(batch), 0)
        self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(d_ff), 1)
        self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(heads), 1)
        # "length" has no layout rule, so it stays unsplit (None).
        self.assertEqual(
            mesh_impl.tensor_layout(mtf.Shape([batch, length, d_ff])),
            mtf.TensorLayout([0, None, 1]))
예제 #4
0
class MeshTensorFlowTest(parameterized.TestCase, tf.test.TestCase):
    """Tests for mtf conversions, graphs, lowering and 1-D pooling helpers."""

    @parameterized.parameters(
        (mtf.Dimension("x", 5), ),
        (("x", 5), ),
    )
    def testConvertToDimension(self, inputs):
        """Both a Dimension and a (name, size) tuple convert to x:5."""
        dimension = mtf.convert_to_dimension(inputs)
        self.assertEqual(dimension.name, "x")
        self.assertEqual(dimension.size, 5)

    def testConvertToDimensionGenericInputs(self):
        """None is passed through; a bare integer raises TypeError."""
        self.assertIsNone(mtf.convert_to_dimension(None))
        with self.assertRaises(TypeError):
            mtf.convert_to_dimension(5)

    @parameterized.parameters(
        (mtf.Shape([mtf.Dimension("x", 4),
                    mtf.Dimension("y", 8)]), ),
        ("x:4;y:8", ),
        ("x:4.y:8", ),
        ("x:4 y:8", ),
        ("x:4,y:8", ),
    )
    def testConvertToShape(self, inputs):
        """Shapes and "name:size" strings convert to the same Shape.

        The string form is accepted with ';', '.', ' ' or ',' separators.
        """
        shape = mtf.convert_to_shape(inputs)
        self.assertEqual(
            shape, mtf.Shape([mtf.Dimension("x", 4),
                              mtf.Dimension("y", 8)]))

    def testConvertToShapeGenericInputs(self):
        """An empty list gives no dims, None passes through, bad text raises."""
        shape = mtf.convert_to_shape([])
        self.assertEqual(shape.dims, [])
        self.assertIsNone(mtf.convert_to_shape(None))
        with self.assertRaises(ValueError):
            # "x;4" is not in the "name:size" form used by the valid inputs.
            mtf.convert_to_shape("x;4")

    @parameterized.parameters(
        (mtf.LayoutRules([("d_ff", "model"), ("heads", "model")]), ),
        ("d_ff:model;heads:model", ),
        ("d_ff:model.heads:model", ),
        ("d_ff:model heads:model", ),
        ("d_ff:model,heads:model", ),
        ([("d_ff", "model"), ("heads", "model")], ),
    )
    def testConvertToLayoutRules(self, inputs):
        """Every supported spelling yields rules d_ff->model, heads->model."""
        layout_rules = mtf.convert_to_layout_rules(inputs)
        # Compared via the private `_pairs`; presumably LayoutRules itself
        # does not define value equality -- TODO confirm.
        self.assertEqual(
            layout_rules._pairs,
            mtf.LayoutRules([("d_ff", "model"), ("heads", "model")])._pairs)

    def testConvertToLayoutRulesGenericInputs(self):
        """A rule string without "dim:axis" pairs raises ValueError."""
        with self.assertRaises(ValueError):
            mtf.convert_to_layout_rules("d_ff;heads")

    def testTensorLayout(self):
        """Checks axis mapping and the fully-replicated flag of TensorLayout."""
        tensor_layout = mtf.TensorLayout([0, 2, 1])
        # The argument appears to be a mesh rank N: the result lists, for each
        # of the first N mesh axes, the tensor axis laid out on it.
        self.assertEqual(tensor_layout.mesh_axis_to_tensor_axis(0), ())
        self.assertEqual(tensor_layout.mesh_axis_to_tensor_axis(1), (0, ))
        self.assertEqual(tensor_layout.mesh_axis_to_tensor_axis(2), (0, 2))
        # Any axis mapped onto the mesh means the tensor is not replicated.
        tensor_layout = mtf.TensorLayout([None, 0])
        self.assertFalse(tensor_layout.is_fully_replicated)
        tensor_layout = mtf.TensorLayout([None, None, None])
        self.assertTrue(tensor_layout.is_fully_replicated)

    def _assertGraphCounts(self, graph, operations, trainable, variables):
        """Asserts how many operations and variables `graph` holds."""
        self.assertLen(graph.operations, operations)
        self.assertLen(graph.trainable_variables, trainable)
        self.assertLen(graph.all_variables, variables)

    def testGraph(self):
        """Graph bookkeeping grows as tensors are imported and variables made."""
        graph = mtf.Graph()
        self._assertGraphCounts(graph, operations=0, trainable=0, variables=0)
        mesh = mtf.Mesh(graph, "mesh_test")
        _ = mtf.import_tf_tensor(mesh,
                                 tf_tensor=tf.constant(0.),
                                 shape=mtf.Shape([]))
        # Importing a TF tensor adds an operation but no variables.
        self._assertGraphCounts(graph, operations=1, trainable=0, variables=0)
        _ = mtf.get_variable(mesh, "variable_0", mtf.Shape([]), trainable=True)
        self._assertGraphCounts(graph, operations=2, trainable=1, variables=1)
        _ = mtf.get_variable(mesh,
                             "variable_1",
                             mtf.Shape([]),
                             trainable=False)
        # A non-trainable variable is counted only in all_variables.
        self._assertGraphCounts(graph, operations=3, trainable=1, variables=2)

    def testGraphNames(self):
        """Graph.unique_name de-duplicates names with numeric suffixes."""
        # Standard usage: repeated requests get "_1", "_2", ... suffixes.
        graph = mtf.Graph()
        self.assertEqual(graph.unique_name("a"), "a")
        self.assertEqual(graph.unique_name("a"), "a_1")
        self.assertEqual(graph.unique_name("a"), "a_2")

        # Edge cases: the user may request a name such as "a_1" that collides
        # with a generated one; neither side may reuse the other's name.
        graph = mtf.Graph()
        self.assertEqual(graph.unique_name("a"), "a")
        self.assertEqual(graph.unique_name("a"), "a_1")
        self.assertEqual(graph.unique_name("a_1"), "a_1_1")

        graph = mtf.Graph()
        self.assertEqual(graph.unique_name("a"), "a")
        self.assertEqual(graph.unique_name("a_1"), "a_1")
        self.assertEqual(graph.unique_name("a"), "a_2")

        # Names collide case-insensitively: "A" clashes with an existing "a".
        graph = mtf.Graph()
        self.assertEqual(graph.unique_name("a"), "a")
        self.assertEqual(graph.unique_name("A"), "A_1")

    @test_util.run_in_graph_and_eager_modes()
    def testLowering(self):
        """Exporting an imported scalar through a Lowering keeps its value."""
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        inputs = tf.constant(0.)
        mtf_inputs = mtf.import_tf_tensor(mesh,
                                          tf_tensor=inputs,
                                          shape=mtf.Shape([]))
        # Rank-0 mesh with a single device entry ("" -- presumably the
        # default TensorFlow placement).
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        outputs = lowering.export_to_tf_tensor(mtf_inputs)
        inputs_value, outputs_value = self.evaluate([inputs, outputs])
        self.assertEqual(inputs_value, outputs_value)

        # Check that methods run without error.
        _ = lowering.copy_masters_to_slices()
        _ = lowering.copy_slices_to_masters()

    def testMesh(self):
        """A Mesh exposes the Graph it was created in."""
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        self.assertEqual(mesh.graph, graph)

    def testMeshImpl(self):
        """Checks MeshImpl properties and dimension-to-mesh-axis layout rules."""
        shape = mtf.Shape(
            [mtf.Dimension("batch", 4),
             mtf.Dimension("model", 8)])
        layout_rules = mtf.LayoutRules([("batch", "batch"), ("d_ff", "model"),
                                        ("heads", "model")])
        mesh_impl = mtf.MeshImpl(shape=shape, layout_rules=layout_rules)
        self.assertEqual(mesh_impl.shape, shape)
        self.assertLen(shape, mesh_impl.ndims)
        self.assertEqual(mesh_impl.layout_rules, layout_rules)
        self.assertEqual(mesh_impl.size, shape.size)
        self.assertTrue(mesh_impl.supports_control_dependencies)

        batch = mtf.Dimension("batch", 128)
        length = mtf.Dimension("length", 500)
        d_ff = mtf.Dimension("d_ff", 2048)
        heads = mtf.Dimension("heads", 8)
        self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(batch), 0)
        self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(d_ff), 1)
        self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(heads), 1)
        # "length" has no layout rule, so it stays unsplit (None).
        self.assertEqual(
            mesh_impl.tensor_layout(mtf.Shape([batch, length, d_ff])),
            mtf.TensorLayout([0, None, 1]))

    @parameterized.parameters(
        {
            "pool_fn": np.mean,
            "pool_fn_mtf": mtf.reduce_mean
        }, {
            "pool_fn": np.max,
            "pool_fn_mtf": mtf.reduce_max
        }, {
            "pool_fn": np.min,
            "pool_fn_mtf": mtf.reduce_min
        })
    def testPoolTensor1d(self, pool_fn, pool_fn_mtf):
        """pool_tensor_1d matches a NumPy reference over disjoint windows."""
        converter = mtf_test_utils.NumpyConverter()
        pool_size = 2
        # Seeded locally so a failure can be reproduced exactly.
        x = np.random.RandomState(0).randn(2, 3, 4, 5)
        # Split the pooled axis (axis 2) into [num_windows, pool_size] and
        # reduce each window, so the reference follows `pool_size`.
        windowed = x.reshape(x.shape[:2] +
                             (x.shape[2] // pool_size, pool_size) +
                             x.shape[3:])
        expected = pool_fn(windowed, axis=3)

        x_mtf = converter.convert_np_array_to_mtf_tensor(x, dtype=tf.float32)
        pooled_mtf = mtf.pool_tensor_1d(x_mtf,
                                        pool_dim=x_mtf.shape.dims[2],
                                        reduce_fn=pool_fn_mtf,
                                        pool_size=pool_size)
        actual = converter.convert_mtf_tensor_to_np_array(pooled_mtf)
        self.assertAllClose(expected, actual)

    @parameterized.parameters({"pool_size": 2}, {"pool_size": 3})
    def testStrideTensor1d(self, pool_size):
        """stride_tensor_1d keeps every `pool_size`-th element of axis 2."""
        converter = mtf_test_utils.NumpyConverter()
        x = np.random.RandomState(0).randint(0, 100, size=[2, 3, 6, 5])
        x_mtf = converter.convert_np_array_to_mtf_tensor(x)
        # Elements 0, pool_size, 2 * pool_size, ... along axis 2.
        expected = x[:, :, ::pool_size, :]
        strided_mtf = mtf.stride_tensor_1d(x_mtf,
                                           pool_dim=x_mtf.shape.dims[2],
                                           pool_size=pool_size)
        actual = converter.convert_mtf_tensor_to_np_array(strided_mtf)
        self.assertAllEqual(expected, actual)

    def testReduceFirst(self):
        """reduce_first keeps only index 0 of the reduced dimension."""
        converter = mtf_test_utils.NumpyConverter()
        x = np.random.RandomState(0).randint(0, 100, size=[2, 3, 6, 5])
        x_mtf = converter.convert_np_array_to_mtf_tensor(x)
        expected = x[:, :, 0, :]
        reduced_mtf = mtf.reduce_first(x_mtf, reduced_dim=x_mtf.shape.dims[2])
        actual = converter.convert_mtf_tensor_to_np_array(reduced_mtf)
        self.assertAllEqual(expected, actual)