Beispiel #1
0
  def testConcatOperation(self):
    """Only the dims shared by all concat inputs are splittable."""
    # Both inputs agree on "a" and "b"; they differ only in the size of
    # the "concat" dimension (5 vs. 7), which is the axis joined along.
    inputs = [
        mtf.zeros(self.mesh,
                  mtf.Shape([self.a_dim, self.b_dim,
                             mtf.Dimension("concat", concat_size)]))
        for concat_size in (5, 7)
    ]

    op = mtf.ConcatOperation(inputs, "concat")

    expected_splittable = frozenset(["a", "b"])
    expected_unsplittable = frozenset(["concat"])
    self.assertEqual(op.splittable_dims, expected_splittable)
    self.assertEqual(op.unsplittable_dims, expected_unsplittable)
Beispiel #2
0
    def setUp(self):
        """Builds a small mtf graph and a LayoutValidator over it.

        The graph holds a concat of two tensors that share dims "a" and "b",
        plus one tensor whose shape has been made anonymous.
        """
        super(LayoutValidatorTest, self).setUp()
        mtf_graph = mtf.Graph()
        my_mesh = mtf.Mesh(mtf_graph, "my_mesh")

        dim_a = mtf.Dimension("a", 5)
        dim_b = mtf.Dimension("b", 10)

        # The two concat operands differ only in the size of "concat"
        # (15 vs. 20); the op is kept in the graph for the validator to see.
        operands = [
            mtf.zeros(my_mesh,
                      mtf.Shape([dim_a, dim_b,
                                 mtf.Dimension("concat", concat_size)]))
            for concat_size in (15, 20)
        ]
        mtf.ConcatOperation(operands, "concat")

        # A tensor with an anonymous shape is expected to be unsplittable,
        # so none of its dimensions should appear in
        # test_SplittableMtfDimensionNames.
        anonymous = mtf.anonymous_shape(mtf.Shape([dim_a, dim_b]))
        _ = mtf.zeros(my_mesh, anonymous)

        self.valid_layouts = valid_layouts.LayoutValidator(
            mtf_graph, mtf.Shape([("m1", 4), ("m2", 2)]))