def testMeshImpl(self):
  """Checks MeshImpl properties and its dimension-to-mesh-axis mapping."""
  mesh_shape = mtf.Shape([
      mtf.Dimension("batch", 4),
      mtf.Dimension("model", 8),
  ])
  rules = mtf.LayoutRules([
      ("batch", "batch"),
      ("d_ff", "model"),
      ("heads", "model"),
  ])
  impl = mtf.MeshImpl(shape=mesh_shape, layout_rules=rules)

  # The basic properties should mirror the constructor arguments.
  self.assertEqual(impl.shape, mesh_shape)
  self.assertLen(mesh_shape, impl.ndims)
  self.assertEqual(impl.layout_rules, rules)
  self.assertEqual(impl.size, mesh_shape.size)
  self.assertTrue(impl.supports_control_dependencies)

  batch_dim = mtf.Dimension("batch", 128)
  length_dim = mtf.Dimension("length", 500)
  d_ff_dim = mtf.Dimension("d_ff", 2048)
  heads_dim = mtf.Dimension("heads", 8)

  # "batch" is mesh axis 0 and "model" is mesh axis 1 in `mesh_shape`.
  expected_axes = (
      (batch_dim, 0),
      (d_ff_dim, 1),
      (heads_dim, 1),
  )
  for tensor_dim, mesh_axis in expected_axes:
    self.assertEqual(impl.tensor_dimension_to_mesh_axis(tensor_dim), mesh_axis)

  # "length" has no layout rule above; the expected layout holds None for it.
  tensor_shape = mtf.Shape([batch_dim, length_dim, d_ff_dim])
  self.assertEqual(
      impl.tensor_layout(tensor_shape), mtf.TensorLayout([0, None, 1]))
def testConvertToLayoutRules(self, inputs):
  """Checks that `inputs` converts to the d_ff/heads -> model layout rules."""
  converted = mtf.convert_to_layout_rules(inputs)
  reference = mtf.LayoutRules([("d_ff", "model"), ("heads", "model")])
  # Compare the underlying pair collections of the two rule objects.
  self.assertEqual(converted._pairs, reference._pairs)
class MeshTensorFlowTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for mtf conversion helpers, TensorLayout, Graph, Lowering, MeshImpl."""

  @parameterized.parameters(
      (mtf.Dimension("x", 5),),
      (("x", 5),),
  )
  def testConvertToDimension(self, inputs):
    """A Dimension or a (name, size) tuple converts to dimension "x" of size 5."""
    dimension = mtf.convert_to_dimension(inputs)
    self.assertEqual(dimension.name, "x")
    self.assertEqual(dimension.size, 5)

  def testConvertToDimensionGenericInputs(self):
    """None converts to None; a bare int is rejected with TypeError."""
    dimension = mtf.convert_to_dimension(None)
    # assertIsNone gives a clearer failure than assertEqual(..., None).
    self.assertIsNone(dimension)
    with self.assertRaises(TypeError):
      mtf.convert_to_dimension(5)

  @parameterized.parameters(
      (mtf.Shape([mtf.Dimension("x", 4), mtf.Dimension("y", 8)]),),
      ("x:4;y:8",),
      ("x:4.y:8",),
      ("x:4 y:8",),
      ("x:4,y:8",),
  )
  def testConvertToShape(self, inputs):
    """A Shape or a "name:size" string (several separators) converts equally."""
    shape = mtf.convert_to_shape(inputs)
    self.assertEqual(
        shape, mtf.Shape([mtf.Dimension("x", 4), mtf.Dimension("y", 8)]))

  def testConvertToShapeGenericInputs(self):
    """Covers the empty list, None, and a malformed string."""
    shape = mtf.convert_to_shape([])
    self.assertEqual(shape.dims, [])
    shape = mtf.convert_to_shape(None)
    self.assertIsNone(shape)
    # "x;4" has no "name:size" pair and must be rejected.
    with self.assertRaises(ValueError):
      mtf.convert_to_shape("x;4")

  @parameterized.parameters(
      (mtf.LayoutRules([("d_ff", "model"), ("heads", "model")]),),
      ("d_ff:model;heads:model",),
      ("d_ff:model.heads:model",),
      ("d_ff:model heads:model",),
      ("d_ff:model,heads:model",),
      ([("d_ff", "model"), ("heads", "model")],),
  )
  def testConvertToLayoutRules(self, inputs):
    """LayoutRules, strings and pair lists all convert to the same pairs."""
    layout_rules = mtf.convert_to_layout_rules(inputs)
    self.assertEqual(
        layout_rules._pairs,
        mtf.LayoutRules([("d_ff", "model"), ("heads", "model")])._pairs)

  def testConvertToLayoutRulesGenericInputs(self):
    """A string without "name:axis" pairs raises ValueError."""
    with self.assertRaises(ValueError):
      mtf.convert_to_layout_rules("d_ff;heads")

  def testTensorLayout(self):
    """Checks mesh_axis_to_tensor_axis and is_fully_replicated."""
    tensor_layout = mtf.TensorLayout([0, 2, 1])
    # The argument is varied from 0 to 2; the expected tuples grow with it.
    self.assertEqual(tensor_layout.mesh_axis_to_tensor_axis(0), ())
    self.assertEqual(tensor_layout.mesh_axis_to_tensor_axis(1), (0,))
    self.assertEqual(tensor_layout.mesh_axis_to_tensor_axis(2), (0, 2))
    tensor_layout = mtf.TensorLayout([None, 0])
    self.assertFalse(tensor_layout.is_fully_replicated)
    tensor_layout = mtf.TensorLayout([None, None, None])
    self.assertTrue(tensor_layout.is_fully_replicated)

  def testGraph(self):
    """Tracks Graph bookkeeping as a tensor and two variables are added."""
    graph = mtf.Graph()
    # A fresh graph holds nothing; assertEmpty states that intent directly.
    self.assertEmpty(graph.operations)
    self.assertEmpty(graph.tensors)
    self.assertEmpty(graph.trainable_variables)
    self.assertEmpty(graph.all_variables)

    mesh = mtf.Mesh(graph, "mesh_test")
    _ = mtf.import_tf_tensor(mesh,
                             tf_tensor=tf.constant(0.),
                             shape=mtf.Shape([]))
    self.assertLen(graph.operations, 1)
    self.assertLen(graph.tensors, 1)
    self.assertEmpty(graph.trainable_variables)
    self.assertEmpty(graph.all_variables)

    _ = mtf.get_variable(mesh, "variable_0", mtf.Shape([]), trainable=True)
    self.assertLen(graph.operations, 2)
    self.assertLen(graph.tensors, 2)
    self.assertLen(graph.trainable_variables, 1)
    self.assertLen(graph.all_variables, 1)

    # A non-trainable variable only shows up in all_variables.
    _ = mtf.get_variable(mesh, "variable_1", mtf.Shape([]), trainable=False)
    self.assertLen(graph.operations, 3)
    self.assertLen(graph.tensors, 3)
    self.assertLen(graph.trainable_variables, 1)
    self.assertLen(graph.all_variables, 2)

  # NOTE(review): tf.contrib does not exist in TF 2.x; confirm the targeted
  # TensorFlow version before relying on this decorator.
  @tf.contrib.eager.run_test_in_graph_and_eager_modes()
  def testLowering(self):
    """Imports a scalar, lowers the graph, and exports the same value."""
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    inputs = tf.constant(0.)
    mtf_inputs = mtf.import_tf_tensor(mesh,
                                      tf_tensor=inputs,
                                      shape=mtf.Shape([]))
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    outputs = lowering.export_to_tf_tensor(mtf_inputs)
    inputs_value, outputs_value = self.evaluate([inputs, outputs])
    self.assertEqual(inputs_value, outputs_value)

    # Check that methods run without error.
    _ = lowering.copy_masters_to_slices()
    _ = lowering.copy_slices_to_masters()

  def testMesh(self):
    """A Mesh exposes the Graph it was constructed with."""
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    self.assertEqual(mesh.graph, graph)

  def testMeshImpl(self):
    """Checks MeshImpl properties and its dimension-to-mesh-axis mapping."""
    shape = mtf.Shape(
        [mtf.Dimension("batch", 4),
         mtf.Dimension("model", 8)])
    layout_rules = mtf.LayoutRules([("batch", "batch"), ("d_ff", "model"),
                                    ("heads", "model")])
    mesh_impl = mtf.MeshImpl(shape=shape, layout_rules=layout_rules)
    self.assertEqual(mesh_impl.shape, shape)
    self.assertEqual(mesh_impl.ndims, len(shape))
    self.assertEqual(mesh_impl.layout_rules, layout_rules)
    self.assertEqual(mesh_impl.size, shape.size)
    self.assertTrue(mesh_impl.supports_control_dependencies)

    batch = mtf.Dimension("batch", 128)
    length = mtf.Dimension("length", 500)
    d_ff = mtf.Dimension("d_ff", 2048)
    heads = mtf.Dimension("heads", 8)
    # "batch" is mesh axis 0 and "model" is mesh axis 1 in `shape`.
    self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(batch), 0)
    self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(d_ff), 1)
    self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(heads), 1)
    # "length" has no layout rule above; the expected layout holds None for it.
    self.assertEqual(
        mesh_impl.tensor_layout(mtf.Shape([batch, length, d_ff])),
        mtf.TensorLayout([0, None, 1]))
class MeshTensorFlowTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for mtf conversion helpers, Graph, Lowering, MeshImpl and pooling."""

  @parameterized.parameters(
      (mtf.Dimension("x", 5),),
      (("x", 5),),
  )
  def testConvertToDimension(self, inputs):
    """A Dimension or a (name, size) tuple converts to dimension "x" of size 5."""
    dimension = mtf.convert_to_dimension(inputs)
    self.assertEqual(dimension.name, "x")
    self.assertEqual(dimension.size, 5)

  def testConvertToDimensionGenericInputs(self):
    """None converts to None; a bare int is rejected with TypeError."""
    dimension = mtf.convert_to_dimension(None)
    # assertIsNone gives a clearer failure than assertEqual(..., None).
    self.assertIsNone(dimension)
    with self.assertRaises(TypeError):
      mtf.convert_to_dimension(5)

  @parameterized.parameters(
      (mtf.Shape([mtf.Dimension("x", 4), mtf.Dimension("y", 8)]),),
      ("x:4;y:8",),
      ("x:4.y:8",),
      ("x:4 y:8",),
      ("x:4,y:8",),
  )
  def testConvertToShape(self, inputs):
    """A Shape or a "name:size" string (several separators) converts equally."""
    shape = mtf.convert_to_shape(inputs)
    self.assertEqual(
        shape, mtf.Shape([mtf.Dimension("x", 4), mtf.Dimension("y", 8)]))

  def testConvertToShapeGenericInputs(self):
    """Covers the empty list, None, and a malformed string."""
    shape = mtf.convert_to_shape([])
    self.assertEqual(shape.dims, [])
    shape = mtf.convert_to_shape(None)
    self.assertIsNone(shape)
    # "x;4" has no "name:size" pair and must be rejected.
    with self.assertRaises(ValueError):
      mtf.convert_to_shape("x;4")

  @parameterized.parameters(
      (mtf.LayoutRules([("d_ff", "model"), ("heads", "model")]),),
      ("d_ff:model;heads:model",),
      ("d_ff:model.heads:model",),
      ("d_ff:model heads:model",),
      ("d_ff:model,heads:model",),
      ([("d_ff", "model"), ("heads", "model")],),
  )
  def testConvertToLayoutRules(self, inputs):
    """LayoutRules, strings and pair lists all convert to the same pairs."""
    layout_rules = mtf.convert_to_layout_rules(inputs)
    self.assertEqual(
        layout_rules._pairs,
        mtf.LayoutRules([("d_ff", "model"), ("heads", "model")])._pairs)

  def testConvertToLayoutRulesGenericInputs(self):
    """A string without "name:axis" pairs raises ValueError."""
    with self.assertRaises(ValueError):
      mtf.convert_to_layout_rules("d_ff;heads")

  def testTensorLayout(self):
    """Checks mesh_axis_to_tensor_axis and is_fully_replicated."""
    tensor_layout = mtf.TensorLayout([0, 2, 1])
    # The argument is varied from 0 to 2; the expected tuples grow with it.
    self.assertEqual(tensor_layout.mesh_axis_to_tensor_axis(0), ())
    self.assertEqual(tensor_layout.mesh_axis_to_tensor_axis(1), (0,))
    self.assertEqual(tensor_layout.mesh_axis_to_tensor_axis(2), (0, 2))
    tensor_layout = mtf.TensorLayout([None, 0])
    self.assertFalse(tensor_layout.is_fully_replicated)
    tensor_layout = mtf.TensorLayout([None, None, None])
    self.assertTrue(tensor_layout.is_fully_replicated)

  def testGraph(self):
    """Tracks Graph bookkeeping as a tensor and two variables are added."""
    graph = mtf.Graph()
    self.assertEmpty(graph.operations)
    self.assertEmpty(graph.trainable_variables)
    self.assertEmpty(graph.all_variables)

    mesh = mtf.Mesh(graph, "mesh_test")
    _ = mtf.import_tf_tensor(mesh,
                             tf_tensor=tf.constant(0.),
                             shape=mtf.Shape([]))
    self.assertLen(graph.operations, 1)
    self.assertEmpty(graph.trainable_variables)
    self.assertEmpty(graph.all_variables)

    _ = mtf.get_variable(mesh, "variable_0", mtf.Shape([]), trainable=True)
    self.assertLen(graph.operations, 2)
    self.assertLen(graph.trainable_variables, 1)
    self.assertLen(graph.all_variables, 1)

    # A non-trainable variable only shows up in all_variables.
    _ = mtf.get_variable(mesh, "variable_1", mtf.Shape([]), trainable=False)
    self.assertLen(graph.operations, 3)
    self.assertLen(graph.trainable_variables, 1)
    self.assertLen(graph.all_variables, 2)

  def testGraphNames(self):
    """Checks the suffixes Graph.unique_name appends to repeated names."""
    # Standard Usage.
    graph = mtf.Graph()
    self.assertEqual(graph.unique_name("a"), "a")
    self.assertEqual(graph.unique_name("a"), "a_1")
    self.assertEqual(graph.unique_name("a"), "a_2")

    # Edge cases, the user may choose the name "a_1".
    graph = mtf.Graph()
    self.assertEqual(graph.unique_name("a"), "a")
    self.assertEqual(graph.unique_name("a"), "a_1")
    self.assertEqual(graph.unique_name("a_1"), "a_1_1")

    graph = mtf.Graph()
    self.assertEqual(graph.unique_name("a"), "a")
    self.assertEqual(graph.unique_name("a_1"), "a_1")
    self.assertEqual(graph.unique_name("a"), "a_2")

    # Case insensitive.
    graph = mtf.Graph()
    self.assertEqual(graph.unique_name("a"), "a")
    self.assertEqual(graph.unique_name("A"), "A_1")

  @test_util.run_in_graph_and_eager_modes()
  def testLowering(self):
    """Imports a scalar, lowers the graph, and exports the same value."""
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    inputs = tf.constant(0.)
    mtf_inputs = mtf.import_tf_tensor(mesh,
                                      tf_tensor=inputs,
                                      shape=mtf.Shape([]))
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    outputs = lowering.export_to_tf_tensor(mtf_inputs)
    inputs_value, outputs_value = self.evaluate([inputs, outputs])
    self.assertEqual(inputs_value, outputs_value)

    # Check that methods run without error.
    _ = lowering.copy_masters_to_slices()
    _ = lowering.copy_slices_to_masters()

  def testMesh(self):
    """A Mesh exposes the Graph it was constructed with."""
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    self.assertEqual(mesh.graph, graph)

  def testMeshImpl(self):
    """Checks MeshImpl properties and its dimension-to-mesh-axis mapping."""
    shape = mtf.Shape(
        [mtf.Dimension("batch", 4),
         mtf.Dimension("model", 8)])
    layout_rules = mtf.LayoutRules([("batch", "batch"), ("d_ff", "model"),
                                    ("heads", "model")])
    mesh_impl = mtf.MeshImpl(shape=shape, layout_rules=layout_rules)
    self.assertEqual(mesh_impl.shape, shape)
    self.assertLen(shape, mesh_impl.ndims)
    self.assertEqual(mesh_impl.layout_rules, layout_rules)
    self.assertEqual(mesh_impl.size, shape.size)
    self.assertTrue(mesh_impl.supports_control_dependencies)

    batch = mtf.Dimension("batch", 128)
    length = mtf.Dimension("length", 500)
    d_ff = mtf.Dimension("d_ff", 2048)
    heads = mtf.Dimension("heads", 8)
    # "batch" is mesh axis 0 and "model" is mesh axis 1 in `shape`.
    self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(batch), 0)
    self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(d_ff), 1)
    self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(heads), 1)
    # "length" has no layout rule above; the expected layout holds None for it.
    self.assertEqual(
        mesh_impl.tensor_layout(mtf.Shape([batch, length, d_ff])),
        mtf.TensorLayout([0, None, 1]))

  @parameterized.parameters(
      {
          "pool_fn": np.mean,
          "pool_fn_mtf": mtf.reduce_mean
      }, {
          "pool_fn": np.max,
          "pool_fn_mtf": mtf.reduce_max
      }, {
          "pool_fn": np.min,
          "pool_fn_mtf": mtf.reduce_min
      })
  def testPoolTensor1d(self, pool_fn, pool_fn_mtf):
    """Compares mtf.pool_tensor_1d against a NumPy reference on axis 2.

    Args:
      pool_fn: NumPy reduction used to build the expected result.
      pool_fn_mtf: the mtf reduction passed to pool_tensor_1d as reduce_fn.
    """
    converter = mtf_test_utils.NumpyConverter()
    pool_size = 2
    x = np.random.randn(2, 3, 4, 5)
    # Axis 2 has size 4, so a pool size of 2 yields two windows.
    expected = np.empty(shape=[2, 3, 2, 5])
    expected[:, :, 0, :] = pool_fn(x[:, :, 0:2, :], axis=2)
    expected[:, :, 1, :] = pool_fn(x[:, :, 2:4, :], axis=2)

    x_mtf = converter.convert_np_array_to_mtf_tensor(x, dtype=tf.float32)
    pooled_mtf = mtf.pool_tensor_1d(x_mtf,
                                    pool_dim=x_mtf.shape.dims[2],
                                    reduce_fn=pool_fn_mtf,
                                    pool_size=pool_size)
    actual = converter.convert_mtf_tensor_to_np_array(pooled_mtf)
    self.assertAllClose(expected, actual)

  @parameterized.parameters({"pool_size": 2}, {"pool_size": 3})
  def testStrideTensor1d(self, pool_size):
    """mtf.stride_tensor_1d should keep every pool_size-th slice of axis 2.

    Args:
      pool_size: stride applied along axis 2.
    """
    converter = mtf_test_utils.NumpyConverter()
    x = np.random.randint(0, 100, size=[2, 3, 6, 5])
    x_mtf = converter.convert_np_array_to_mtf_tensor(x)
    expected = x[:, :, range(0, x.shape[2], pool_size), :]
    strided_mtf = mtf.stride_tensor_1d(x_mtf,
                                       pool_dim=x_mtf.shape.dims[2],
                                       pool_size=pool_size)
    actual = converter.convert_mtf_tensor_to_np_array(strided_mtf)
    self.assertAllEqual(expected, actual)

  def testReduceFirst(self):
    """mtf.reduce_first should return the index-0 slice of the reduced axis."""
    converter = mtf_test_utils.NumpyConverter()
    x = np.random.randint(0, 100, size=[2, 3, 6, 5])
    x_mtf = converter.convert_np_array_to_mtf_tensor(x)
    expected = x[:, :, 0, :]
    reduced_mtf = mtf.reduce_first(x_mtf, reduced_dim=x_mtf.shape.dims[2])
    actual = converter.convert_mtf_tensor_to_np_array(reduced_mtf)
    self.assertAllEqual(expected, actual)