# Example 1
    def test_op_blacklist(self):
        """Test the Glow fuser op-kind blacklisting mechanism.

        After blacklisting ``aten::add``, tracing a function that uses
        both add and sub should yield a Glow fusion group containing the
        sub but not the add.
        """

        def f(a, b):
            return (a + b) * (a - b)

        torch_glow.enableFusionPass_DO_NOT_USE_THIS()
        torch_glow.setFusionBlacklist(["aten::add"])

        a = torch.randn(5, 5)
        b = torch.randn(5, 5)

        jit_f = torch.jit.trace(f, (a, b))
        jit_f_graph = jit_f.graph_for(a, b)

        fused_add = False
        fused_sub = False
        for node in jit_f_graph.nodes():
            if node.kind() == GLOW_FUSION_GROUP:
                glow_subgraph = node.g(SUBGRAPH_ATTR)
                # Distinct name: the original shadowed the outer loop
                # variable ``node``, which is error-prone.
                for subnode in glow_subgraph.nodes():
                    if subnode.kind() == "aten::add":
                        fused_add = True
                    if subnode.kind() == "aten::sub":
                        fused_sub = True

        assert not fused_add, "Expected aten::add to be blacklisted"
        assert fused_sub, "Expected aten::sub to not be blacklisted"

        # Reset global fuser state so later tests are unaffected.
        torch_glow.clearFusionBlacklist()
# Example 2
def ephemeral_torchglow_settings(
    fp16=False, backend=DEFAULT_BACKEND, fusion=False, blocklist=None
):
    """Temporarily apply torch_glow settings, restoring the originals on exit.

    Generator intended for use as a context manager (presumably wrapped
    with ``contextlib.contextmanager`` at the decoration site — confirm).

    Args:
        fp16: enable fp16 conversion, fused-conversion, and clipping.
        backend: Glow backend name to select while active.
        fusion: enable the Glow fusion pass while active.
        blocklist: iterable of op kinds to exclude from fusion, or ``None``
            to clear the fusion blacklist.
    """
    # Snapshot the current global settings so they can be restored.
    old_fp16 = torch_glow.get_convert_to_fp16()
    old_clip = torch_glow.get_clip_fp16()
    old_convert_fused = torch_glow.get_convert_fused_to_fp16()
    old_backend = torch_glow.getGlowBackendName()
    old_blocklist = torch_glow.getFusionBlacklist()
    old_fusion = torch_glow.getFusionPassEnabled()
    try:
        if fusion:
            torch_glow.enableFusionPass()
        else:
            torch_glow.disableFusionPass()
        if fp16:
            torch_glow.enable_convert_to_fp16()
            torch_glow.enable_convert_fused_to_fp16()
            torch_glow.enable_clip_fp16()
        else:
            torch_glow.disable_convert_to_fp16()
            torch_glow.disable_convert_fused_to_fp16()
            torch_glow.disable_clip_fp16()
        if blocklist is None:
            torch_glow.clearFusionBlacklist()
        else:
            torch_glow.setFusionBlacklist(list(blocklist))
        torch_glow.setGlowBackend(backend)
        yield
    finally:
        # Restore the snapshot. Plain if/else replaces the original
        # conditional-expressions-as-statements, which used ternaries
        # purely for their side effects.
        if old_fp16:
            torch_glow.enable_convert_to_fp16()
        else:
            torch_glow.disable_convert_to_fp16()
        if old_clip:
            torch_glow.enable_clip_fp16()
        else:
            torch_glow.disable_clip_fp16()
        if old_convert_fused:
            torch_glow.enable_convert_fused_to_fp16()
        else:
            torch_glow.disable_convert_fused_to_fp16()
        if old_fusion:
            torch_glow.enableFusionPass()
        else:
            torch_glow.disableFusionPass()
        torch_glow.setGlowBackend(old_backend)
        torch_glow.setFusionBlacklist(old_blocklist)
# Example 3
    def test_op_blacklist_allowlist(self):
        """Test that the Glow fuser allowlist overrides the blacklist.

        Both add and sub are blacklisted, then sub is placed on the
        override allowlist; the traced graph's fusion group should contain
        the sub but not the add.
        """

        def f(a, b):
            return (a + b) * (a - b)

        torch_glow.enableFusionPass()
        torch_glow.setFusionBlacklist(["aten::add", "aten::sub"])
        torch_glow.setFusionOverrideAllowlist(["aten::sub"])

        a = torch.randn(5, 5)
        b = torch.randn(5, 5)

        jit_f = torch.jit.trace(f, (a, b))
        jit_f_graph = jit_f.graph_for(a, b)

        fused_add = False
        fused_sub = False
        for node in jit_f_graph.nodes():
            if node.kind() == GLOW_NODE_NAME:
                glow_subgraph = node.g(SUBGRAPH_ATTR)
                # Distinct name: the original shadowed the outer loop
                # variable ``node``, which is error-prone.
                for subnode in glow_subgraph.nodes():
                    if subnode.kind() == "aten::add":
                        fused_add = True
                    if subnode.kind() == "aten::sub":
                        fused_sub = True

        assert not fused_add, "Expected aten::add to be blacklisted"
        assert fused_sub, "Expected aten::sub to not be blacklisted"

        # Reset global fuser state so later tests are unaffected.
        torch_glow.clearFusionBlacklist()
        torch_glow.clearFusionOverrideAllowlist()
# Example 4
def test_min_graph_size():
    """Test Glow fuser minimum fusion group size mechanism."""

    # Fusion is applied manually via glowCustomFuseDebug_ below, so the
    # automatic pass is turned off here.
    torch_glow.disableFusionPass()

    # Disable aten::div so that each group of aten::mul nodes will be forced
    # into separate subgraphs
    torch_glow.setFusionBlacklist(["aten::div"])

    # Set minimum fusion group size to 3 nodes so that the smallest group which
    # contains only 2 nodes will not be created
    torch_glow.setMinFusionGroupSize(3)

    a = torch.randn(5, 5)
    b = torch.randn(5, 5)
    c = torch.randn(5, 5)

    # NOTE(review): ``f`` is not defined in this function — it presumably
    # comes from an enclosing scope or was lost in extraction; confirm it
    # is a three-argument function mixing aten::mul and aten::div ops.
    jit_f = torch.jit.trace(f, (a, b, c))
    jit_f_graph = jit_f.graph_for(a, b, c)

    # print("before: ", jit_f_graph)

    # Run the Glow custom fuser directly on the traced graph (the fusion
    # pass itself was disabled above).
    torch_glow.glowCustomFuseDebug_(jit_f_graph)

    # print("after: ", jit_f_graph)

    # Count top-level Glow fusion groups in the fused graph.
    fusion_nodes = 0
    for node in jit_f_graph.nodes():
        if node.kind() == GLOW_NODE_NAME:
            fusion_nodes += 1

    assert fusion_nodes == 2, "Expected smallest fusion group to not be created"

    # Restore global fuser state for subsequent tests.
    torch_glow.clearFusionBlacklist()
    torch_glow.setMinFusionGroupSize(0)