Example #1
0
    def setUp(self):
        """Builds a small mtf graph and a LayoutValidator over it.

        The graph lives on mesh "my_mesh" and contains:
          * two zero tensors of shapes [a=5, b=10, concat=15] and
            [a=5, b=10, concat=20], joined by a ConcatOperation along "concat";
          * a zero tensor whose [a, b] shape is wrapped by
            mtf.anonymous_shape.
        The validator is built for mesh shape [m1=4, m2=2] and stored in
        self.valid_layouts.
        """
        super(LayoutValidatorTest, self).setUp()
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")

        dim_a = mtf.Dimension("a", 5)
        dim_b = mtf.Dimension("b", 10)
        # Two dimensions that share the name "concat" but differ in size; the
        # ConcatOperation below joins tensors along that name.
        short_concat = mtf.Dimension("concat", 15)
        long_concat = mtf.Dimension("concat", 20)

        leading_dims = [dim_a, dim_b]
        left = mtf.zeros(mesh, mtf.Shape(leading_dims + [short_concat]))
        right = mtf.zeros(mesh, mtf.Shape(leading_dims + [long_concat]))
        mtf.ConcatOperation([left, right], "concat")

        # This tensor has an anonymous shape, which is meant to be
        # unsplittable: none of its dimensions should appear in
        # test_SplittableMtfDimensionNames. The result is deliberately
        # discarded; only its presence in the graph matters.
        mtf.zeros(mesh, mtf.anonymous_shape(mtf.Shape([dim_a, dim_b])))

        mesh_shape = mtf.Shape([("m1", 4), ("m2", 2)])
        self.valid_layouts = valid_layouts.LayoutValidator(graph, mesh_shape)
Example #2
0
 def _compute_layout_validator(self):
     """Computes self._layout_validator.

     Builds a valid_layouts.LayoutValidator from self.mtf_graph and
     self.mesh_shape and stores it on self._layout_validator. Returns None.
     """
     graph, mesh_shape = self.mtf_graph, self.mesh_shape
     self._layout_validator = valid_layouts.LayoutValidator(graph, mesh_shape)