import unittest

import torch
import torch_glow

# Node kind produced by the Glow fuser (matches the literal used in the tests below)
GLOW_NODE_NAME = "glow::FusionGroup"


# Assumed unittest.TestCase wrapper: the original listing shows only the test
# methods themselves.
class TestGlowFuser(unittest.TestCase):
    def test_only_tensor_outputs(self):
        """Test that the Glow fuser only produces tensor outputs."""
        def f(a, b):
            x = (a + b).size(0)
            c = a.reshape(x, -1)
            return a + c

        torch_glow.disableFusionPass()

        a = torch.randn(5, 5)
        b = torch.randn(5, 5)

        jit_f = torch.jit.trace(f, (a, b))
        jit_f_graph = jit_f.graph_for(a, b)

        # By creating a graph with an aten::size (supported) feeding into an
        # unsupported op (prim::ListConstruct), we check that even a supported
        # op will not be fused if it produces a non-tensor output for the
        # fusion group.
        torch_glow.glowCustomFuseDebug_(
            jit_f_graph,
            ["prim::Constant", "aten::add", "aten::size", "aten::reshape"])

        fusion_nodes = 0
        aten_sizes = 0
        for node in jit_f_graph.nodes():
            if node.kind() == GLOW_NODE_NAME:
                fusion_nodes += 1
            if node.kind() == "aten::size":
                aten_sizes += 1

        assert (
            fusion_nodes == 2
        ), "Expected two fusion nodes to be split up with aten::size between them"
        assert aten_sizes == 1, "Expected aten::size not to be fused"
    def test_max_fusion_merge_size_zero(self):
        """Test Glow fuser maximum fusion merge size mechanism set to zero."""
        def f(a):
            return a * a * a * a * a * a

        torch_glow.disableFusionPass()

        # Set the maximum fusion merge size to 0, meaning there is no limit
        # on how many nodes can be merged into a single fusion group
        torch_glow.setMaxFusionMergeSize(0)

        a = torch.randn(5, 5)

        jit_f = torch.jit.trace(f, (a,))
        jit_f_graph = jit_f.graph_for(a)

        # print("before: ", jit_f_graph)

        torch_glow.glowCustomFuseDebug_(jit_f_graph)

        # print("after: ", jit_f_graph)

        fusion_nodes = 0
        for node in jit_f_graph.nodes():
            if node.kind() == GLOW_NODE_NAME:
                fusion_nodes += 1

        assert fusion_nodes == 1, "Expected just one fusion group to be created"

        # Reset the maximum fusion merge size to the default (no limit)
        torch_glow.setMaxFusionMergeSize(0)

    def test_max_fusion_merge_size(self):
        """Test Glow fuser maximum fusion merge size mechanism."""
        def f(a):
            return a * a * a * a * a * a

        torch_glow.disableFusionPass()

        # Set the maximum fusion merge size to 3 nodes so that the whole
        # graph cannot fit into a single fusion group
        torch_glow.setMaxFusionMergeSize(3)

        a = torch.randn(5, 5)

        jit_f = torch.jit.trace(f, (a,))
        jit_f_graph = jit_f.graph_for(a)

        # print("before: ", jit_f_graph)

        torch_glow.glowCustomFuseDebug_(jit_f_graph)

        # print("after: ", jit_f_graph)

        fusion_nodes = 0
        for node in jit_f_graph.nodes():
            if node.kind() == GLOW_NODE_NAME:
                fusion_nodes += 1

        assert fusion_nodes > 1, "Expected more than one fusion group to be created"

        # Reset the maximum fusion merge size to the default (no limit)
        torch_glow.setMaxFusionMergeSize(0)

    def test_fuse_parallel_branches_without_fusible_root(self):
        r"""Test GlowFuser fusing parallel branches without a common fusible root

                x = add(x, x)       y = add(y, y)
                        |                  |
                b1 = add(x, x)      b2 = add(y, y)
                        \                 /
                res = TupleConstruct(b1, b2)

        This should be fused as

                        glow::FusionGroup_0
                                    |
                            TupleConstruct

        """
        def test_fuser(x, y):
            x = x + x
            y = y + y
            branch1 = x + x
            branch2 = y + y
            res = (branch1, branch2)
            return res

        inputs = (torch.randn(2, 4), torch.randn(2, 4))
        traced = torch.jit.trace(test_fuser, inputs)
        torch_glow.glowCustomFuseDebug_(traced.graph)

        count = 0
        for node in traced.graph.nodes():
            if node.kind() == "glow::FusionGroup":
                count += 1
        assert count == 1, f"Expect 1 glow::FusionGroup, found {count}."
def test_min_graph_size():
    """Test Glow fuser minimum fusion group size mechanism."""

    torch_glow.disableFusionPass()

    # Disable aten::div so that each group of aten::mul nodes will be forced
    # into separate subgraphs
    torch_glow.setFusionBlacklist(["aten::div"])

    # Set minimum fusion group size to 3 nodes so that the smallest group which
    # contains only 2 nodes will not be created
    torch_glow.setMinFusionGroupSize(3)

    a = torch.randn(5, 5)
    b = torch.randn(5, 5)
    c = torch.randn(5, 5)

    jit_f = torch.jit.trace(f, (a, b, c))
    jit_f_graph = jit_f.graph_for(a, b, c)

    # print("before: ", jit_f_graph)

    torch_glow.glowCustomFuseDebug_(jit_f_graph)

    # print("after: ", jit_f_graph)

    fusion_nodes = 0
    for node in jit_f_graph.nodes():
        if node.kind() == GLOW_NODE_NAME:
            fusion_nodes += 1

    assert fusion_nodes == 2, "Expected smallest fusion group to not be created"

    # Restore the default fuser settings
    torch_glow.clearFusionBlacklist()
    torch_glow.setMinFusionGroupSize(0)
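

# NOTE: Model is omitted from the original listing. This is a minimal sketch
# consistent with the ops whitelisted in the test below: a convolution
# (aten::_convolution is deliberately left off the whitelist) feeding a
# linear layer, whose trace yields aten::t, aten::matmul, and aten::add_.
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)
        self.linear = torch.nn.Linear(3, 5)

    def forward(self, x):
        x = self.conv(x)
        return self.linear(x)
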
def test_fuse_necessary_getattrs_only():
    """Test that the Glow fuser fuses only the prim::GetAttr nodes it needs."""
    m = Model()
    x = torch.randn(1, 3, 5, 5)

    torch_glow.disableFusionPass()

    jit_m = torch.jit.trace(m, x)
    jit_m_graph = jit_m.graph_for(x)

    # Don't fuse aten::_convolution; only the listed node kinds are fusible
    torch_glow.glowCustomFuseDebug_(
        jit_m_graph,
        ["prim::Constant", "prim::GetAttr", "aten::t", "aten::matmul", "aten::add_"],
    )

    return m(x)