Beispiel #1
0
def traceVsGlow(f_torch,
                f_glow,
                check_trace,
                atol,
                rtol,
                *inputs,
                expected_fused_ops=None,
                accept_all_ops=False,
                black_list=None):
    """Compare a fusion-free trace of ``f_torch`` with a Glow-fused trace of ``f_glow``.

    Both callables are traced with ``torch.jit.trace`` on ``inputs`` under
    ``torch.no_grad()``. The reference trace is built with the Glow fusion pass
    disabled and must contain no ``GLOW_NODE_NAME`` nodes. The Glow trace is
    built with fusion enabled and ``black_list`` installed as the fusion
    blacklist. The fused graph is then checked with ``checkExpectedOps`` and the
    outputs with ``checkResult``.

    Args:
        f_torch: callable traced without Glow fusion (reference).
        f_glow: callable traced with Glow fusion enabled.
        check_trace: forwarded to ``torch.jit.trace``.
        atol: absolute tolerance forwarded to ``checkResult``.
        rtol: relative tolerance forwarded to ``checkResult``.
        *inputs: example inputs used for tracing and for running both traces.
        expected_fused_ops: forwarded to ``checkExpectedOps``.
        accept_all_ops: forwarded to ``checkExpectedOps``.
        black_list: op kinds excluded from fusion; ``None`` means no exclusions.

    The fusion pass is disabled and the fusion blacklist cleared on exit, even
    when tracing or an assertion fails. This keeps these process-global
    settings from leaking into later callers.
    """
    if black_list is None:
        black_list = []
    with torch.no_grad():
        try:
            torch_glow.disableFusionPass()

            torch_trace = torch.jit.trace(f_torch, inputs, check_trace=check_trace)
            torch_res = torch_trace(*inputs)

            torch_glow.enableFusionPass()
            torch_glow.setFusionBlacklist(black_list)
            glow_trace = torch.jit.trace(f_glow, inputs, check_trace=check_trace)
            glow_res = glow_trace(*inputs)

            # check that there are no Glow nodes in the torch graph
            torch_graph = torch_trace.graph_for(*inputs)
            print("torch_graph,", torch_graph)

            num_glow_nodes = len(torch_graph.findAllNodes(GLOW_NODE_NAME))
            assert num_glow_nodes == 0, "Expected no Glow nodes, found {}".format(
                num_glow_nodes)

            glow_graph = glow_trace.graph_for(*inputs)
            print("glow_graph,", glow_graph)
        finally:
            # Restore the global fuser state so later tests start clean.
            torch_glow.disableFusionPass()
            torch_glow.clearFusionBlacklist()

    checkExpectedOps(glow_graph, expected_fused_ops, accept_all_ops)
    checkResult(torch_res, glow_res, atol, rtol)
Beispiel #2
0
def ephemeral_torchglow_settings(
    fp16=False, backend=DEFAULT_BACKEND, fusion=False, blocklist=None
):
    """Apply torch_glow global settings for one ``yield`` and restore them afterwards.

    This is a generator. It snapshots the current torch_glow settings, applies
    the requested ones, and yields once. Its ``finally`` clause then puts every
    setting it touched back to the snapshot value.

    NOTE(review): it is presumably meant to be wrapped with
    ``contextlib.contextmanager`` and used in a ``with`` statement. No such
    decorator is visible here, so confirm at the definition site.

    Args:
        fp16: True enables fp16 conversion, fused fp16 conversion and fp16
            clipping; False disables all three.
        backend: backend name passed to ``torch_glow.setGlowBackend``.
        fusion: True enables the Glow fusion pass, False disables it.
        blocklist: iterable of op kinds to exclude from fusion. ``None`` clears
            the fusion blacklist instead.
    """

    def _switch(turn_on, enable_fn, disable_fn):
        # Invoke exactly one of the two setters depending on truthiness.
        (enable_fn if turn_on else disable_fn)()

    # Snapshot the current global state so the finally clause can restore it.
    saved_fp16 = torch_glow.get_convert_to_fp16()
    saved_clip = torch_glow.get_clip_fp16()
    saved_fused_fp16 = torch_glow.get_convert_fused_to_fp16()
    saved_backend = torch_glow.getGlowBackendName()
    saved_blocklist = torch_glow.getFusionBlacklist()
    saved_fusion = torch_glow.getFusionPassEnabled()

    # (enable, disable) setter pairs, listed in the order they are applied.
    fp16_switches = (
        (torch_glow.enable_convert_to_fp16, torch_glow.disable_convert_to_fp16),
        (torch_glow.enable_convert_fused_to_fp16, torch_glow.disable_convert_fused_to_fp16),
        (torch_glow.enable_clip_fp16, torch_glow.disable_clip_fp16),
    )
    try:
        _switch(fusion, torch_glow.enableFusionPass, torch_glow.disableFusionPass)
        for enable_fn, disable_fn in fp16_switches:
            _switch(fp16, enable_fn, disable_fn)
        if blocklist is not None:
            torch_glow.setFusionBlacklist(list(blocklist))
        else:
            torch_glow.clearFusionBlacklist()
        torch_glow.setGlowBackend(backend)
        yield
    finally:
        _switch(saved_fp16, torch_glow.enable_convert_to_fp16, torch_glow.disable_convert_to_fp16)
        _switch(saved_clip, torch_glow.enable_clip_fp16, torch_glow.disable_clip_fp16)
        _switch(
            saved_fused_fp16,
            torch_glow.enable_convert_fused_to_fp16,
            torch_glow.disable_convert_fused_to_fp16,
        )
        _switch(saved_fusion, torch_glow.enableFusionPass, torch_glow.disableFusionPass)
        torch_glow.setGlowBackend(saved_backend)
        torch_glow.setFusionBlacklist(saved_blocklist)
Beispiel #3
0
    def test_op_blacklist_allowlist(self):
        """Test Glow fuser allowlist overwrites blacklist mechanism.

        ``aten::add`` and ``aten::sub`` are both blacklisted. ``aten::sub`` is
        also on the override allowlist, so only ``aten::sub`` is expected inside
        a Glow fusion group.
        """

        def f(a, b):
            return (a + b) * (a - b)

        try:
            torch_glow.enableFusionPass()
            torch_glow.setFusionBlacklist(["aten::add", "aten::sub"])
            torch_glow.setFusionOverrideAllowlist(["aten::sub"])

            a = torch.randn(5, 5)
            b = torch.randn(5, 5)

            jit_f = torch.jit.trace(f, (a, b))

            jit_f_graph = jit_f.graph_for(a, b)

            fused_add = False
            fused_sub = False
            for node in jit_f_graph.nodes():
                if node.kind() == GLOW_NODE_NAME:
                    glow_subgraph = node.g(SUBGRAPH_ATTR)
                    # Separate name so the outer loop variable is not shadowed.
                    for fused_node in glow_subgraph.nodes():
                        if fused_node.kind() == "aten::add":
                            fused_add = True
                        if fused_node.kind() == "aten::sub":
                            fused_sub = True

            assert not fused_add, "Expected aten::add to be blacklisted"
            assert fused_sub, "Expected aten::sub to not be blacklisted"
        finally:
            # Reset the global fuser lists even when an assertion fails, so they
            # do not leak into other tests.
            torch_glow.clearFusionBlacklist()
            torch_glow.clearFusionOverrideAllowlist()
Beispiel #4
0
    def test_getattr(self):
        """Test fusion of the PyTorch prim::GetAttr Node into the Glow subgraph.

        No ``prim::GetAttr`` may remain in the top-level graph, and at least one
        must appear inside a Glow fusion group.
        """
        with torch.no_grad():

            class Model(torch.nn.Module):
                def __init__(self):
                    super(Model, self).__init__()
                    self.linear = torch.nn.Linear(2, 1)

                def forward(self, x):
                    return self.linear(x)

            x = torch.tensor([2.0, 3.0])

            torch_glow.enableFusionPass()

            m = Model()
            jit_m = torch.jit.trace(m, x)
            jit_m_graph = jit_m.graph_for(x)

            # Ensure all prim::GetAttrs were fused and none were left out
            found_getattrs = False
            for node in jit_m_graph.nodes():
                kind = node.kind()
                assert (
                    kind != "prim::GetAttr"
                ), "Expected all prim::GetAttrs to be in Glow subgraph"
                if kind == GLOW_FUSION_GROUP:
                    glow_subgraph = node.g(SUBGRAPH_ATTR)
                    if any(fused_node.kind() == "prim::GetAttr"
                           for fused_node in glow_subgraph.nodes()):
                        found_getattrs = True

            assert found_getattrs, "Expected to find prim::GetAttrs in the Glow subgraph"
Beispiel #5
0
def jitVsGlow_(f_torch,
               f_glow,
               *inputs,
               expected_fused_ops=None,
               accept_all_ops=False):
    """Compare a fusion-free trace of ``f_torch`` with a Glow-fused trace of ``f_glow``.

    Args:
        f_torch: callable traced with the Glow fusion pass disabled (reference).
        f_glow: callable traced with the Glow fusion pass enabled.
        *inputs: example inputs used for tracing and for running both traces.
        expected_fused_ops: op kinds that must appear inside a Glow fusion group
            and must not appear outside one. ``None`` means no specific ops are
            expected.
        accept_all_ops: when True, skip the expected-op checks.

    Raises:
        AssertionError: if the reference graph contains Glow nodes, an expected
            op is left unfused or never appears, or the outputs differ by more
            than ``atol=1e-6``.
    """
    # None means "nothing expected". A set ignores duplicate entries, which
    # would otherwise make the completeness check below fail spuriously.
    expected_ops = set(expected_fused_ops) if expected_fused_ops is not None else set()

    with torch.no_grad():
        torch_glow.disableFusionPass()
        torch_trace = torch.jit.trace(f_torch, inputs)
        torch_res = torch_trace(*inputs)

        torch_glow.enableFusionPass()
        glow_trace = torch.jit.trace(f_glow, inputs)
        glow_res = glow_trace(*inputs)

        # check that there are no Glow nodes in the torch graph
        torch_graph = torch_trace.graph_for(*inputs)
        print("torch_graph,", torch_graph)

        num_glow_nodes = len(torch_graph.findAllNodes(GLOW_NODE_NAME))
        assert num_glow_nodes == 0, "Expected no Glow nodes, found {}".format(
            num_glow_nodes)

        glow_graph = glow_trace.graph_for(*inputs)
        print("glow_graph,", glow_graph)

        expected_fused_ops_seen = set()

        # Check that ops that were *not* fused are *not* in expected_fused_ops
        for node in glow_graph.nodes():
            kind = node.kind()
            if kind != GLOW_NODE_NAME:
                # Nodes outside a Glow fusion group must not be expected ops.
                assert accept_all_ops or kind not in expected_ops, \
                    "Expected {} to be fused".format(kind)
            else:
                # Record which expected ops ended up inside this fusion group.
                glow_group = node.g(SUBGRAPH_ATTR)
                for fused_node in glow_group.nodes():
                    fused_node_kind = fused_node.kind()
                    if accept_all_ops or fused_node_kind in expected_ops:
                        expected_fused_ops_seen.add(fused_node_kind)

        # Any expected op never seen inside a fusion group is missing from the graph.
        assert accept_all_ops or expected_ops <= expected_fused_ops_seen, \
            "Expected all of expected_fused_ops to be in the graph"

        if isinstance(torch_res, (tuple, list)):
            assert len(torch_res) == len(glow_res)
            for torch_out, glow_out in zip(torch_res, glow_res):
                assert torch.allclose(torch_out, glow_out, atol=1e-6)
        else:
            # A single tensor output (including 0-d) is compared as a whole.
            assert torch.allclose(torch_res, glow_res, atol=1e-6)
Beispiel #6
0
def run_model(model, image, use_glow, print_graph):
    """Trace ``model`` on ``image`` and return its top-5 predictions.

    Args:
        model: module or callable passed to ``torch.jit.trace``.
        image: example input, used for tracing and for the forward pass.
        use_glow: True enables the torch_glow fusion pass before tracing. False
            explicitly disables it, so a setting left over from an earlier
            ``use_glow=True`` call does not silently apply.
        print_graph: when True, print ``traced.graph_for(image)``.

    Returns:
        ``(indices, values)`` from ``Tensor.topk(5)`` on the model output
        (``topk`` uses the last dimension by default).
    """
    if use_glow:
        torch_glow.enableFusionPass()
    else:
        torch_glow.disableFusionPass()
    with torch.no_grad():
        traced = torch.jit.trace(model, image)
        if print_graph:
            print(traced.graph_for(image))
        all_outputs = traced(image)
        values, indices = all_outputs.topk(5)
        return (indices, values)
    def test_print_jit_indices(self):
        """Run a traced add graph with Glow fusion and JIT node-index printing turned on."""

        def test_f(a, b):
            c = a.add(b)
            return c.add(c)

        lhs = torch.randn(4)
        rhs = torch.randn(4)
        example_args = (lhs, rhs)

        # Both switches are process-global torch_glow settings.
        torch_glow.enableFusionPass()
        torch_glow.enable_printing_jit_node_indices()

        traced_fn = torch.jit.trace(test_f, example_args, check_trace=False)
        traced_fn(*example_args)
Beispiel #8
0
    def test_quantized_cut(self):
        """Test cut quantized chunk in the middle.

        ``quantized::add_relu`` is put on the fusion blacklist. The traced graph
        may then only contain Glow fusion groups, blacklisted nodes and
        constants, and the result must match eager execution.
        """
        # NOTE(review): these JIT executor settings are process-global and are
        # not restored by this test.
        torch._C._jit_set_profiling_executor(False)
        torch._C._jit_set_profiling_mode(False)

        def fun(a, b, c, d):
            q = torch.nn.quantized.Quantize(scale=1.0 / 128,
                                            zero_point=0,
                                            dtype=torch.quint8)
            dq = torch.nn.quantized.DeQuantize()
            a = q(a)
            b = q(b)
            c = q(c)
            d = q(d)
            adds = torch.ops.quantized.add(a, b, scale=1.0 / 121, zero_point=5)
            adds2 = torch.ops.quantized.add(c,
                                            d,
                                            scale=1.0 / 122,
                                            zero_point=4)
            res = torch.ops.quantized.add_relu(adds,
                                               adds2,
                                               scale=1.0 / 120,
                                               zero_point=6)
            res = torch.ops.quantized.add(res,
                                          res,
                                          scale=1.0 / 128,
                                          zero_point=7)
            res = dq(res)
            return res

        with torch.no_grad():
            a = torch.randn([5, 5])
            b = torch.randn([5, 5])
            c = torch.randn([5, 5])
            d = torch.randn([5, 5])
            res_torch = fun(a, b, c, d)
            torch_glow.enableFusionPass()
            # Cut using blacklist functionality
            blacklist = ["quantized::add_relu"]
            try:
                torch_glow.setFusionBlacklist(blacklist)
                traced_model = torch.jit.trace(fun, (a, b, c, d))
                for node in traced_model.graph_for(a, b, c, d).nodes():
                    kind = node.kind()
                    # Make sure the blacklist is working
                    assert (kind == GLOW_NODE_NAME or kind in blacklist
                            or kind == "prim::Constant"), \
                        "Unexpected node {} outside Glow fusion groups".format(kind)
                res_glow = traced_model(a, b, c, d)
            finally:
                # Keep the blacklist from leaking into later tests, even on failure.
                torch_glow.clearFusionBlacklist()
            print(res_torch)
            print(res_glow)
            assert torch.allclose(res_torch, res_glow)
Beispiel #9
0
    def test_shape_inference_unsupported_symbols_skip_fusion_group(self):
        """Test Glow shape inference unsupported symbols including skipping of
        symbols after a secondary fusion group."""

        def f(a, b):
            x1 = a * b
            x2 = x1 * b
            x3 = x2 * a
            x4 = x3 / b
            x5 = x4 / a
            x6 = x5 / b
            x7 = x6 * a
            x8 = x7 * b
            return x8 * torch.chain_matmul(x8, x8)

        torch_glow.enableFusionPass()
        try:
            torch_glow.setFusionStartIndex(3)
            torch_glow.setFusionEndIndex(6)

            a = torch.randn(5, 5)
            b = torch.randn(5, 5)

            jit_f = torch.jit.trace(f, (a, b))

            jit_f_graph = jit_f.graph_for(a, b)
        finally:
            # Reset the global index window even if tracing fails, so it cannot
            # leak into other tests.
            torch_glow.clearFusionIndices()

        args = (a, b)

        # Don't skip nodes after the last fusion node.
        # In this case, one of the nodes (chain_matmul) following the last
        # fusion node is not supported, and should be reported.
        actual = torch_glow.glow_shape_inference_find_unsupported_symbols(
            jit_f_graph, args, skip_last_fusion_node=False
        )
        expected = [
            "aten::chain_matmul",
        ]
        self.assertEqual(set(expected), set(actual))

        # DO skip nodes after the last fusion node.
        # In this case, one of the nodes (chain_matmul) following the last
        # fusion node is not supported, but is suppressed due to the
        # skip_last_fusion_node flag.
        actual = torch_glow.glow_shape_inference_find_unsupported_symbols(
            jit_f_graph, args, skip_last_fusion_node=True
        )
        expected = []
        self.assertEqual(set(expected), set(actual))
Beispiel #10
0
def scriptVsGlow(
    f,
    atol,
    rtol,
    *inputs,
    expected_fused_ops=None,
    accept_all_ops=False,
    black_list=None,
    use_fp16=False,
    backend_name=None,
):
    """Compare eager ``f`` against ``torch.jit.script(f)`` run through Glow fusion.

    Args:
        f: function called eagerly for the reference result, then scripted.
        atol: absolute tolerance forwarded to ``checkResult``.
        rtol: relative tolerance forwarded to ``checkResult``.
        *inputs: inputs passed to both the eager and the scripted call.
        expected_fused_ops: forwarded to ``checkExpectedOps``.
        accept_all_ops: forwarded to ``checkExpectedOps``.
        black_list: op kinds excluded from fusion; ``None`` means no exclusions.
        use_fp16: True enables fp16 conversion, fused fp16 conversion and fp16
            clipping; False disables them.
        backend_name: Glow backend to use; falsy selects ``"Interpreter"``.

    All torch_glow settings touched here, including the fusion blacklist, are
    reset in a ``finally`` clause. They therefore do not carry over to later
    tests even when scripting or execution fails.
    """
    if black_list is None:
        black_list = []
    with torch.no_grad():

        torch_res = f(*inputs)

        try:
            torch_glow.enableFusionPass()
            torch_glow.setFusionBlacklist(black_list)

            if use_fp16:
                torch_glow.enable_convert_to_fp16()
                torch_glow.enable_convert_fused_to_fp16()
                torch_glow.enable_clip_fp16()
            else:
                torch_glow.disable_convert_to_fp16()
                torch_glow.disable_convert_fused_to_fp16()
                torch_glow.disable_clip_fp16()

            if backend_name:
                torch_glow.setGlowBackend(backend_name)
            else:
                torch_glow.setGlowBackend("Interpreter")

            glow_trace = torch.jit.script(f)
            glow_res = glow_trace(*inputs)

            glow_graph = glow_trace.graph_for(*inputs)
            print("glow_graph,", glow_graph)
        finally:
            # need to explicitly clear settings to avoid carry-over static settings
            torch_glow.disableFusionPass()
            torch_glow.disable_convert_to_fp16()
            torch_glow.disable_convert_fused_to_fp16()
            torch_glow.disable_clip_fp16()
            torch_glow.setGlowBackend("Interpreter")
            torch_glow.clearFusionBlacklist()

    checkExpectedOps(glow_graph, expected_fused_ops, accept_all_ops)
    checkResult(torch_res, glow_res, atol, rtol)
    def test_backend_specific_options(self):
        """Load backend options from a temporary YAML file, then run a fused trace."""

        def test_f(a, b):
            return a.add(b)

        x = torch.randn(4)
        y = torch.randn(4)

        # Single-entry YAML document setting the `interpreter-memory` option.
        yaml_payload = b'interpreter-memory: 4194304\n'

        with tempfile.NamedTemporaryFile() as options_file:
            options_file.write(yaml_payload)
            # Flush so the bytes are on disk before the file is read back by name.
            options_file.flush()
            options_path = options_file.name

            torch_glow.loadBackendSpecificOptions(options_path)
            torch_glow.enableFusionPass()
            traced = torch.jit.trace(test_f, (x, y), check_trace=False)
            traced(x, y)
Beispiel #12
0
    def test_op_index_blacklist_allowlist(self):
        """Test Glow fuser allowlist overwrites index blacklisting mechanism.

        Fusion is restricted by node index via ``setFusionStartIndex`` and
        ``setFusionEndIndex``, while ``aten::div`` is on the override
        allowlist. All three divs, and none of the muls, are expected to be
        fused.
        """

        def f(a, b):
            x1 = a * b
            x2 = x1 * b
            x3 = x2 * a
            x4 = x3 / b
            x5 = x4 / a
            x6 = x5 / b
            x7 = x6 * a
            x8 = x7 * b
            return x8

        try:
            torch_glow.enableFusionPass()
            # Only one div is allowed by index
            torch_glow.setFusionStartIndex(5)
            torch_glow.setFusionEndIndex(6)
            # But all divs are allowed by allowlist
            torch_glow.setFusionOverrideAllowlist(["aten::div"])

            a = torch.randn(5, 5)
            b = torch.randn(5, 5)

            jit_f = torch.jit.trace(f, (a, b))

            jit_f_graph = jit_f.graph_for(a, b)
        finally:
            # Always reset the global index window and allowlist, even if
            # tracing fails, so they cannot leak into other tests.
            torch_glow.clearFusionIndices()
            torch_glow.clearFusionOverrideAllowlist()

        fused_muls = 0
        fused_divs = 0
        for node in jit_f_graph.nodes():
            if node.kind() == GLOW_NODE_NAME:
                glow_subgraph = node.g(SUBGRAPH_ATTR)
                for fused_node in glow_subgraph.nodes():
                    if fused_node.kind() == "aten::mul":
                        fused_muls += 1
                    if fused_node.kind() == "aten::div":
                        fused_divs += 1

        assert fused_muls == 0, "Expected no aten::muls to be fused"
        assert fused_divs == 3, "Expected all 3 aten::divs to be fused"
Beispiel #13
0
    def test_op_index_blacklist(self):
        """Test Glow fuser index blacklisting mechanism.

        Fusion is limited to the node-index window set by
        ``setFusionStartIndex(3)`` and ``setFusionEndIndex(6)``. The test
        expects all three divs, and none of the muls, to be fused.
        """

        def f(a, b):
            x1 = a * b
            x2 = x1 * b
            x3 = x2 * a
            x4 = x3 / b
            x5 = x4 / a
            x6 = x5 / b
            x7 = x6 * a
            x8 = x7 * b
            return x8

        try:
            torch_glow.enableFusionPass()
            torch_glow.setFusionStartIndex(3)
            torch_glow.setFusionEndIndex(6)

            a = torch.randn(5, 5)
            b = torch.randn(5, 5)

            jit_f = torch.jit.trace(f, (a, b))

            jit_f_graph = jit_f.graph_for(a, b)
        finally:
            # Reset the global index window even if tracing fails.
            torch_glow.clearFusionIndices()

        fused_muls = 0
        fused_divs = 0
        for node in jit_f_graph.nodes():
            if node.kind() == GLOW_FUSION_GROUP:
                glow_subgraph = node.g(SUBGRAPH_ATTR)
                for fused_node in glow_subgraph.nodes():
                    if fused_node.kind() == "aten::mul":
                        fused_muls += 1
                    if fused_node.kind() == "aten::div":
                        fused_divs += 1

        assert fused_muls == 0, "Expected no aten::muls to be fused"
        assert fused_divs == 3, "Expected all 3 aten::divs to be fused"
Beispiel #14
0
def scriptVsGlow(f,
                 atol,
                 rtol,
                 *inputs,
                 expected_fused_ops=None,
                 accept_all_ops=False,
                 black_list=None):
    """Compare eager ``f`` against ``torch.jit.script(f)`` run with Glow fusion enabled.

    Args:
        f: function called eagerly for the reference result, then scripted.
        atol: absolute tolerance forwarded to ``checkResult``.
        rtol: relative tolerance forwarded to ``checkResult``.
        *inputs: inputs passed to both the eager and the scripted call.
        expected_fused_ops: forwarded to ``checkExpectedOps``.
        accept_all_ops: forwarded to ``checkExpectedOps``.
        black_list: op kinds excluded from fusion; ``None`` means no exclusions.

    The fusion pass is disabled and the fusion blacklist cleared on exit, even
    on failure, so neither global setting carries over to later tests.
    """
    if black_list is None:
        black_list = []
    with torch.no_grad():

        torch_res = f(*inputs)

        try:
            torch_glow.enableFusionPass()
            torch_glow.setFusionBlacklist(black_list)
            glow_trace = torch.jit.script(f)
            glow_res = glow_trace(*inputs)

            glow_graph = glow_trace.graph_for(*inputs)
            print("glow_graph,", glow_graph)
        finally:
            # Explicitly clear static settings so they do not leak into other tests.
            torch_glow.disableFusionPass()
            torch_glow.clearFusionBlacklist()

    checkExpectedOps(glow_graph, expected_fused_ops, accept_all_ops)
    checkResult(torch_res, glow_res, atol, rtol)
Beispiel #15
0
def jitVsGlow(f, *inputs):
    """Trace ``f`` with and without Glow fusion and compare the two results.

    The reference trace must contain no ``GLOW_NODE_NAME`` nodes. The
    Glow-enabled trace must contain exactly one. Outputs must match within
    ``atol=1e-6``.

    Args:
        f: callable to trace with ``torch.jit.trace``.
        *inputs: example inputs used for tracing and for running both traces.

    The fusion pass is disabled again on exit (even on failure), so this
    global setting does not leak into later callers.
    """
    try:
        torch_glow.disableFusionPass()
        torch_trace = torch.jit.trace(f, inputs)
        torch_res = torch_trace(*inputs)

        torch_glow.enableFusionPass()
        glow_trace = torch.jit.trace(f, inputs)
        glow_res = glow_trace(*inputs)

        # check that there are no Glow nodes in the torch graph
        torch_graph = torch_trace.graph_for(*inputs)
        print("torch_graph,", torch_graph)
        num_glow_nodes = len(torch_graph.findAllNodes(GLOW_NODE_NAME))
        assert num_glow_nodes == 0, "Expected no Glow nodes, found {}".format(
            num_glow_nodes)

        # check that there is exactly 1 Glow node in the glow graph
        glow_graph = glow_trace.graph_for(*inputs)
        print("glow_graph,", glow_graph)
        num_glow_nodes = len(glow_graph.findAllNodes(GLOW_NODE_NAME))
        assert num_glow_nodes == 1, "Expected exactly 1 Glow node, found {}".format(
            num_glow_nodes)
    finally:
        torch_glow.disableFusionPass()

    assert torch.allclose(torch_res, glow_res, atol=1e-6)
import torch
import torch.nn as nn
import torch_glow


class Model(nn.Module):
    """Minimal module wrapping one ``nn.Linear`` with 10 input and 2 output features."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(in_features=10, out_features=2)

    def forward(self, x):
        """Apply the linear layer to ``x``."""
        projected = self.linear(x)
        return projected


# Turn on the JIT profiling mode and the Glow fusion pass before scripting.
torch._C._jit_set_profiling_mode(True)
torch_glow.enableFusionPass()

m = Model()
m_jit = torch.jit.script(m)

x = torch.randn(10)

# Print the IR before any explicit call (author's note: no Glow fusion node yet).
print("initial jit ir")
print(m_jit.graph_for(x))

# Invoke the scripted model three times on the same input.
for _ in range(3):
    m_jit(x)
Beispiel #17
0
def traceVsGlow(
    f_torch,
    f_glow,
    check_trace,
    atol,
    rtol,
    *inputs,
    expected_fused_ops=None,
    accept_all_ops=False,
    black_list=None,
    use_fp16=False,
    backend_name=None,
):
    """Compare a fusion-free trace of ``f_torch`` with a Glow-fused trace of ``f_glow``.

    Args:
        f_torch: callable traced with the Glow fusion pass disabled (reference).
        f_glow: callable traced with the Glow fusion pass enabled.
        check_trace: forwarded to ``torch.jit.trace``.
        atol: absolute tolerance forwarded to ``checkResult``.
        rtol: relative tolerance forwarded to ``checkResult``.
        *inputs: example inputs used for tracing and for running both traces.
        expected_fused_ops: forwarded to ``checkExpectedOps``.
        accept_all_ops: forwarded to ``checkExpectedOps``.
        black_list: op kinds excluded from fusion; ``None`` means no exclusions.
        use_fp16: True enables fp16 conversion, fused fp16 conversion and fp16
            clipping; False disables them.
        backend_name: Glow backend to use; falsy selects ``"Interpreter"``.

    Raises:
        AssertionError: if the reference graph contains any Glow nodes, or if
            ``checkExpectedOps``/``checkResult`` reject the fused result.

    Every torch_glow setting touched here, including the fusion blacklist, is
    reset in a ``finally`` clause so it cannot carry over to later tests.
    """
    if black_list is None:
        black_list = []
    with torch.no_grad():
        try:
            torch_glow.disableFusionPass()

            torch_trace = torch.jit.trace(f_torch, inputs, check_trace=check_trace)
            torch_res = torch_trace(*inputs)

            torch_glow.enableFusionPass()
            torch_glow.setFusionBlacklist(black_list)

            if use_fp16:
                torch_glow.enable_convert_to_fp16()
                torch_glow.enable_convert_fused_to_fp16()
                torch_glow.enable_clip_fp16()
            else:
                torch_glow.disable_convert_to_fp16()
                torch_glow.disable_convert_fused_to_fp16()
                torch_glow.disable_clip_fp16()

            if backend_name:
                torch_glow.setGlowBackend(backend_name)
            else:
                torch_glow.setGlowBackend("Interpreter")

            glow_trace = torch.jit.trace(f_glow, inputs, check_trace=check_trace)
            glow_res = glow_trace(*inputs)

            # check that there are no Glow nodes in the torch graph
            torch_graph = torch_trace.graph_for(*inputs)
            print("torch_graph,", torch_graph)

            num_glow_nodes = len(torch_graph.findAllNodes(GLOW_NODE_NAME))
            assert num_glow_nodes == 0, "Expected no Glow nodes, found {}".format(
                num_glow_nodes)

            glow_graph = glow_trace.graph_for(*inputs)
            print("glow_graph,", glow_graph)
        finally:
            # need to explicitly clear settings to avoid carry-over static
            # settings, including when tracing or the assertion above fails
            torch_glow.disableFusionPass()
            torch_glow.disable_convert_to_fp16()
            torch_glow.disable_convert_fused_to_fp16()
            torch_glow.disable_clip_fp16()
            torch_glow.setGlowBackend("Interpreter")
            torch_glow.clearFusionBlacklist()

    checkExpectedOps(glow_graph, expected_fused_ops, accept_all_ops)
    checkResult(torch_res, glow_res, atol, rtol)
Beispiel #18
0
def jitVsGlow_(f_torch,
               f_glow,
               check_trace,
               atol,
               rtol,
               *inputs,
               expected_fused_ops=None,
               accept_all_ops=False,
               black_list=None):
    """Compare a fusion-free trace of ``f_torch`` with a Glow-fused trace of ``f_glow``.

    Args:
        f_torch: callable traced with the Glow fusion pass disabled (reference).
        f_glow: callable traced with the Glow fusion pass enabled.
        check_trace: forwarded to ``torch.jit.trace``.
        atol: absolute tolerance for ``torch.allclose``.
        rtol: relative tolerance for ``torch.allclose``.
        *inputs: example inputs used for tracing and for running both traces.
        expected_fused_ops: op kinds that must appear inside a Glow fusion group
            and must not appear outside one. ``None`` means no specific ops are
            expected.
        accept_all_ops: when True, skip the expected-op checks.
        black_list: op kinds excluded from fusion; ``None`` means no exclusions.

    Raises:
        AssertionError: if the reference graph has Glow nodes, nothing was fused,
            an expected op was left unfused or never appeared, or the outputs
            differ beyond ``atol``/``rtol``.
    """
    import sys  # used for the shape diagnostics written to stderr below

    if black_list is None:
        black_list = []
    # None means "nothing expected". A set ignores duplicate entries, which
    # would otherwise make the completeness check below fail spuriously.
    expected_ops = set(expected_fused_ops) if expected_fused_ops is not None else set()

    with torch.no_grad():
        torch_glow.disableFusionPass()
        torch_trace = torch.jit.trace(f_torch, inputs, check_trace=check_trace)
        torch_res = torch_trace(*inputs)

        torch_glow.enableFusionPass()
        torch_glow.setFusionBlacklist(black_list)
        glow_trace = torch.jit.trace(f_glow, inputs, check_trace=check_trace)
        glow_res = glow_trace(*inputs)

        # check that there are no Glow nodes in the torch graph
        torch_graph = torch_trace.graph_for(*inputs)
        print("torch_graph,", torch_graph)

        num_glow_nodes = len(torch_graph.findAllNodes(GLOW_NODE_NAME))
        assert num_glow_nodes == 0, "Expected no Glow nodes, found {}".format(
            num_glow_nodes)

        glow_graph = glow_trace.graph_for(*inputs)
        print("glow_graph,", glow_graph)

        expected_fused_ops_seen = set()

        # Whether or not at least one node was fused to Glow.
        nodes_were_fused = False

        # Check that ops that were *not* fused are *not* in expected_fused_ops
        for node in glow_graph.nodes():
            kind = node.kind()
            if kind != GLOW_NODE_NAME:
                # Nodes outside a Glow fusion group must not be expected ops.
                assert accept_all_ops or kind not in expected_ops, \
                    "Expected {} to be fused".format(kind)
            else:
                # Record which expected ops ended up inside this fusion group.
                glow_group = node.g(SUBGRAPH_ATTR)
                for fused_node in glow_group.nodes():
                    nodes_were_fused = True
                    fused_node_kind = fused_node.kind()

                    if accept_all_ops or fused_node_kind in expected_ops:
                        expected_fused_ops_seen.add(fused_node_kind)

        assert nodes_were_fused, "Expected some nodes to be fused to Glow"

        # Any expected op never seen inside a fusion group is missing from the graph.
        assert accept_all_ops or expected_ops <= expected_fused_ops_seen, \
            "Expected all of expected_fused_ops to be in the graph"

        sequence_types = (tuple, list)
        if isinstance(torch_res, sequence_types) or isinstance(glow_res, sequence_types):
            assert isinstance(torch_res, sequence_types) and isinstance(glow_res, sequence_types)
            assert len(torch_res) == len(glow_res)
            for torch_out, glow_out in zip(torch_res, glow_res):
                print("torch shape: {}".format(torch_out.shape),
                      file=sys.stderr)
                print("glow shape: {}".format(glow_out.shape),
                      file=sys.stderr)
                assert torch.allclose(torch_out,
                                      glow_out,
                                      atol=atol,
                                      rtol=rtol)
        else:
            print("torch shape: {}".format(torch_res.shape), file=sys.stderr)
            print("glow shape: {}".format(glow_res.shape), file=sys.stderr)
            is_all_close = torch.allclose(torch_res,
                                          glow_res,
                                          atol=atol,
                                          rtol=rtol)
            if not is_all_close:
                print("torch_res\n", torch_res)
                print("glow_res\n", glow_res)
                print("diff\n", torch.abs(glow_res - torch_res))
            assert is_all_close