def test_remove_exceptions(self):
        """Check that Glow's removeExceptions JIT pass drops raise nodes.

        Scripts ``foo`` with TorchScript, asserts that its graph mentions
        ``prim::RaiseException``, applies ``torch_glow.removeExceptions_`` to
        that graph, and asserts that the same graph object no longer
        mentions it.
        """
        raise_op = "prim::RaiseException"
        scripted_graph = torch.jit.script(foo).graph

        # Precondition: without it, the final assertion would pass trivially.
        had_raise = graph_contains_str(scripted_graph, raise_op)
        assert had_raise

        # Called only for its effect on the graph; the return value is unused.
        torch_glow.removeExceptions_(scripted_graph)

        still_has_raise = graph_contains_str(scripted_graph, raise_op)
        assert not still_has_raise
Example #2
0
 def test_fuse_linear(self):
     """Check that Glow's fuseBranchedLinearPattern JIT pass emits a fused op.

     Parses ``graph_str`` into an IR graph, asserts that
     ``glow::fused_linear`` is absent, applies
     ``torch_glow.fuseBranchedLinearPattern_`` to the graph, and asserts
     that the same graph object now mentions the fused op.
     """
     fused_op = "glow::fused_linear"
     ir_graph = torch._C.parse_ir(graph_str)

     # Precondition: the fused op must not already be in the parsed IR.
     fused_before = graph_contains_str(ir_graph, fused_op)
     assert not fused_before

     # Called only for its effect on the graph; the return value is unused.
     torch_glow.fuseBranchedLinearPattern_(ir_graph)

     fused_after = graph_contains_str(ir_graph, fused_op)
     assert fused_after
 def test_remove_exceptions(self):
     """Test Glow's removeExceptions JIT pass.

     Parses ``graph_str`` into an IR graph, asserts that it mentions
     ``prim::RaiseException``, applies ``torch_glow.removeExceptions_`` to
     the graph, and asserts that the same graph object no longer mentions
     it.
     """
     graph = torch._C.parse_ir(graph_str)

     # ``assert`` is a statement, not a function: written without call-style
     # parentheses so that adding a message later cannot turn the condition
     # into an always-true tuple.
     # Precondition: without it, the final assertion would pass trivially.
     assert graph_contains_str(graph, "prim::RaiseException")

     # Called only for its effect on the graph; the return value is unused.
     torch_glow.removeExceptions_(graph)

     assert not graph_contains_str(graph, "prim::RaiseException")