def testConcatOperation(self):
  """ConcatOperation: the concat dimension is unsplittable, the others are not.

  Two zero tensors share the "a" and "b" dimensions of the fixture and differ
  only in the size of their "concat" dimension (5 vs. 7). Concatenating them
  along "concat" should leave "a" and "b" splittable while "concat" is
  reported as unsplittable.
  """
  # Operands are created in the same order as the sizes listed here.
  operands = [
      mtf.zeros(self.mesh,
                mtf.Shape([self.a_dim, self.b_dim,
                           mtf.Dimension("concat", size)]))
      for size in (5, 7)
  ]
  op = mtf.ConcatOperation(operands, "concat")

  expected_splittable = frozenset(["a", "b"])
  expected_unsplittable = frozenset(["concat"])
  self.assertEqual(op.splittable_dims, expected_splittable)
  self.assertEqual(op.unsplittable_dims, expected_unsplittable)
def setUp(self):
  """Builds a small Mesh TensorFlow graph and a LayoutValidator over it.

  The graph contains:
    * two zero tensors with dimensions ("a", 5) and ("b", 10), plus a
      "concat" dimension of size 15 and 20 respectively, joined by a
      ConcatOperation along "concat";
    * one zero tensor with an anonymous shape built from ("a", 5), ("b", 10).

  The validator is stored on ``self.valid_layouts`` and is built for a mesh
  of shape [("m1", 4), ("m2", 2)].
  """
  super(LayoutValidatorTest, self).setUp()
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  dim_a = mtf.Dimension("a", 5)
  dim_b = mtf.Dimension("b", 10)

  # Operands differ only in the length of their "concat" dimension. They are
  # created in the listed order (15, then 20) before the concat op itself.
  operands = [
      mtf.zeros(mesh, mtf.Shape([dim_a, dim_b, mtf.Dimension("concat", n)]))
      for n in (15, 20)
  ]
  mtf.ConcatOperation(operands, "concat")

  # We add a tensor with anonymous shape, which is supposed to be
  # unsplittable (i.e. none of its dimensions show up during
  # test_SplittableMtfDimensionNames).
  _ = mtf.zeros(mesh, mtf.anonymous_shape(mtf.Shape([dim_a, dim_b])))

  self.valid_layouts = valid_layouts.LayoutValidator(
      graph, mtf.Shape([("m1", 4), ("m2", 2)]))