Exemple #1
0
    def testOptimizeLayoutUnsplittable(self):
        """Check that the layout optimizer returns "" when nothing can be split.

        Builds two zero tensors shaped "a:10,b:5" and "b:5,c:20" on
        ``self.mesh``, unstacks the first along ``a`` and the second along
        ``c``, then asserts that ``solve()`` on the optimizer obtained from
        ``self.get_layout_optimizer()`` yields the empty string.
        """
        left = mtf.zeros(self.mesh, "a:10,b:5")
        right = mtf.zeros(self.mesh, "b:5,c:20")

        # Unstack each tensor along its outer dimension; the operations are
        # created only for their effect on the graph, so results are dropped.
        for tensor, dim_name, dim_size in ((left, "a", 10), (right, "c", 20)):
            mtf.UnstackOperation(tensor, mtf.Dimension(dim_name, dim_size))

        # No dimensions can be split, because a and c are unstack dimensions
        # and b has size 5 (so there are divisibility issues).
        layout = self.get_layout_optimizer().solve()
        self.assertEqual(layout, "")
Exemple #2
0
 def testUnstackOperation(self):
     """Check the dimension sets reported by an UnstackOperation.

     Unstacks ``self.x`` along ``self.b_dim`` and asserts that the
     operation reports "a" as its only splittable dimension and "b" as
     its only unsplittable dimension.
     """
     op = mtf.UnstackOperation(self.x, self.b_dim)
     expected_splittable = frozenset({"a"})
     expected_unsplittable = frozenset({"b"})
     self.assertEqual(op.splittable_dims, expected_splittable)
     self.assertEqual(op.unsplittable_dims, expected_unsplittable)