Exemplo n.º 1
0
    def test_shape_inference_unsupported_symbols_skip_fusion_group(self):
        """Test Glow shape inference unsupported symbols including skipping of
        symbols after a secondary fusion group.

        A chain of element-wise ops is traced with Glow fusion restricted to
        a window of node indices, followed by a ``torch.chain_matmul`` call
        outside that window. ``glow_shape_inference_find_unsupported_symbols``
        must report ``aten::chain_matmul`` when ``skip_last_fusion_node`` is
        False, and report nothing when it is True (nodes after the last
        fusion node are skipped).
        """

        def f(a, b):
            x1 = a * b
            x2 = x1 * b
            x3 = x2 * a
            x4 = x3 / b
            x5 = x4 / a
            x6 = x5 / b
            x7 = x6 * a
            x8 = x7 * b
            return x8 * torch.chain_matmul(x8, x8)

        torch_glow.enableFusionPass_DO_NOT_USE_THIS()

        a = torch.randn(5, 5)
        b = torch.randn(5, 5)
        args = (a, b)

        try:
            # Restrict fusion to a window of node indices (presumably the
            # half-open range [3, 6) over graph node order -- TODO confirm).
            torch_glow.setFusionStartIndex(3)
            torch_glow.setFusionEndIndex(6)

            jit_f = torch.jit.trace(f, args)
            jit_f_graph = jit_f.graph_for(*args)
        finally:
            # The fusion indices are set through module-level torch_glow calls
            # and need an explicit clear; do it even if tracing fails so the
            # setting cannot leak into other tests.
            torch_glow.clearFusionIndices()

        # Don't skip nodes after the last fusion node.
        # In this case, one of the nodes (chain_matmul) following the last
        # fusion node is not supported, and should be reported.
        actual = torch_glow.glow_shape_inference_find_unsupported_symbols(
            jit_f_graph, args, skip_last_fusion_node=False
        )
        expected = [
            "aten::chain_matmul",
        ]
        self.assertEqual(set(expected), set(actual))

        # DO skip nodes after the last fusion node.
        # In this case, one of the nodes (chain_matmul) following the last
        # fusion node is not supported, but is suppressed due to the
        # skip_last_fusion_node flag.
        actual = torch_glow.glow_shape_inference_find_unsupported_symbols(
            jit_f_graph, args, skip_last_fusion_node=True
        )
        expected = []
        self.assertEqual(set(expected), set(actual))
Exemplo n.º 2
0
    def test_op_index_blacklist_allowlist(self):
        """Test Glow fuser allowlist overwrites index blacklisting mechanism.

        The fusion index window admits only one ``aten::div``, but
        ``aten::div`` is also placed on the fusion override allowlist. The
        test expects all three divisions, and none of the multiplications,
        to end up inside Glow fusion subgraphs.
        """
        def f(a, b):
            x1 = a * b
            x2 = x1 * b
            x3 = x2 * a
            x4 = x3 / b
            x5 = x4 / a
            x6 = x5 / b
            x7 = x6 * a
            x8 = x7 * b
            return x8

        torch_glow.enableFusionPass()

        a = torch.randn(5, 5)
        b = torch.randn(5, 5)

        try:
            # Only one div is allowed by index
            torch_glow.setFusionStartIndex(5)
            torch_glow.setFusionEndIndex(6)
            # But all divs are allowed by allowlist
            torch_glow.setFusionOverrideAllowlist(["aten::div"])

            jit_f = torch.jit.trace(f, (a, b))
            jit_f_graph = jit_f.graph_for(a, b)
        finally:
            # Reset the module-level fuser settings even if tracing fails so
            # they cannot leak into other tests.
            torch_glow.clearFusionIndices()
            torch_glow.clearFusionOverrideAllowlist()

        # Count the muls/divs that ended up inside Glow fusion subgraphs.
        fused_muls = 0
        fused_divs = 0
        for node in jit_f_graph.nodes():
            if node.kind() != GLOW_NODE_NAME:
                continue
            glow_subgraph = node.g(SUBGRAPH_ATTR)
            # Distinct name: don't shadow the outer loop variable.
            for fused_node in glow_subgraph.nodes():
                kind = fused_node.kind()
                if kind == "aten::mul":
                    fused_muls += 1
                elif kind == "aten::div":
                    fused_divs += 1

        self.assertEqual(fused_muls, 0, "Expected no aten::muls to be fused")
        self.assertEqual(fused_divs, 3, "Expected all 3 aten::divs to be fused")
Exemplo n.º 3
0
    def test_op_index_blacklist(self):
        """Test Glow fuser index blacklisting mechanism.

        Fusion is restricted to node indices 3..6, which in ``f`` covers the
        three divisions (presumably a half-open [start, end) window over
        graph node order -- TODO confirm). The test expects exactly those
        three ``aten::div`` nodes, and no ``aten::mul``, inside Glow fusion
        subgraphs.
        """
        def f(a, b):
            x1 = a * b
            x2 = x1 * b
            x3 = x2 * a
            x4 = x3 / b
            x5 = x4 / a
            x6 = x5 / b
            x7 = x6 * a
            x8 = x7 * b
            return x8

        torch_glow.enableFusionPass_DO_NOT_USE_THIS()

        a = torch.randn(5, 5)
        b = torch.randn(5, 5)

        try:
            torch_glow.setFusionStartIndex(3)
            torch_glow.setFusionEndIndex(6)

            jit_f = torch.jit.trace(f, (a, b))
            jit_f_graph = jit_f.graph_for(a, b)
        finally:
            # Reset the module-level fusion indices even if tracing fails so
            # they cannot leak into other tests.
            torch_glow.clearFusionIndices()

        # Count the muls/divs that ended up inside Glow fusion subgraphs.
        fused_muls = 0
        fused_divs = 0
        for node in jit_f_graph.nodes():
            if node.kind() != GLOW_FUSION_GROUP:
                continue
            glow_subgraph = node.g(SUBGRAPH_ATTR)
            # Distinct name: don't shadow the outer loop variable.
            for fused_node in glow_subgraph.nodes():
                kind = fused_node.kind()
                if kind == "aten::mul":
                    fused_muls += 1
                elif kind == "aten::div":
                    fused_divs += 1

        self.assertEqual(fused_muls, 0, "Expected no aten::muls to be fused")
        self.assertEqual(fused_divs, 3, "Expected all 3 aten::divs to be fused")