Example #1
0
    def test_getattr(self):
        """Test fusion of the PyTorch prim::GetAttr Node into the Glow subgraph.

        Traces a small module whose forward calls a ``torch.nn.Linear``
        submodule, enables the Glow fusion pass, and checks two things:
        no ``prim::GetAttr`` node is left in the top-level graph, and at least
        one ``prim::GetAttr`` appears inside a Glow fusion group's subgraph.
        """
        with torch.no_grad():

            class Model(torch.nn.Module):
                def __init__(self):
                    super(Model, self).__init__()
                    self.linear = torch.nn.Linear(2, 1)

                def forward(self, x):
                    return self.linear(x)

            x = torch.tensor([2.0, 3.0])

            torch_glow.enableFusionPass_DO_NOT_USE_THIS()

            m = Model()
            jit_m = torch.jit.trace(m, x)
            jit_m_graph = jit_m.graph_for(x)

            # Ensure all prim::GetAttrs were fused and none were left out
            found_getattrs = False
            for node in jit_m_graph.nodes():
                kind = node.kind()
                assert (
                    kind != "prim::GetAttr"
                ), "Expected all prim::GetAttrs to be in Glow subgraph"
                if kind == GLOW_FUSION_GROUP:
                    glow_subgraph = node.g(SUBGRAPH_ATTR)
                    # Use a separate name for subgraph nodes so the outer
                    # loop's `node` is not shadowed.
                    if any(
                        sub_node.kind() == "prim::GetAttr"
                        for sub_node in glow_subgraph.nodes()
                    ):
                        found_getattrs = True

            assert found_getattrs, "Expected to find prim::GetAttrs in the Glow subgraph"
Example #2
0
    def test_op_blacklist(self):
        """Test Glow fuser op kind blacklisting mechanism.

        Blacklists ``aten::add`` and checks two things: it never appears inside
        a Glow fusion group, and ``aten::sub`` still does. The blacklist is
        cleared in a ``finally`` so a failing assertion cannot leak it into
        later tests.
        """
        def f(a, b):
            return (a + b) * (a - b)

        torch_glow.enableFusionPass_DO_NOT_USE_THIS()
        torch_glow.setFusionBlacklist(["aten::add"])

        try:
            a = torch.randn(5, 5)
            b = torch.randn(5, 5)

            jit_f = torch.jit.trace(f, (a, b))

            jit_f_graph = jit_f.graph_for(a, b)

            fused_add = False
            fused_sub = False
            for node in jit_f_graph.nodes():
                if node.kind() == GLOW_FUSION_GROUP:
                    glow_subgraph = node.g(SUBGRAPH_ATTR)
                    # Separate name so the outer loop variable is not shadowed.
                    for sub_node in glow_subgraph.nodes():
                        if sub_node.kind() == "aten::add":
                            fused_add = True
                        if sub_node.kind() == "aten::sub":
                            fused_sub = True

            assert not fused_add, "Expected aten::add to be blacklisted"
            assert fused_sub, "Expected aten::sub to not be blacklisted"
        finally:
            torch_glow.clearFusionBlacklist()
Example #3
0
def run_model(model, image, use_glow, backend, print_graph):
    """Trace ``model`` on ``image`` and return its five largest outputs.

    Args:
        model: module (or callable) handed to ``torch.jit.trace``.
        image: example input used both to trace and to run the model.
        use_glow: if true, switch on the Glow fusion pass before tracing.
        backend: Glow backend name; applied only when ``use_glow`` is true
            and ``backend`` is truthy.
        print_graph: if true, print the graph from ``graph_for(image)``.

    Returns:
        ``(indices, values)`` taken from ``topk(5)`` on the model output.

    Note: the torch_glow settings changed here are not reset before returning.
    """
    if use_glow:
        torch_glow.enableFusionPass_DO_NOT_USE_THIS()
        if backend:
            torch_glow.setGlowBackend(backend)

    with torch.no_grad():
        traced_model = torch.jit.trace(model, image)
        if print_graph:
            optimized_graph = traced_model.graph_for(image)
            print(optimized_graph)
        # topk returns a (values, indices) pair; the caller gets indices first.
        values, indices = traced_model(image).topk(5)
        return (indices, values)
Example #4
0
    def test_print_jit_indices(self):
        """Smoke-test tracing and running with JIT node-index printing enabled.

        Enables the Glow fusion pass and JIT node-index printing, then traces
        and runs a small two-``add`` function. The test passes as long as
        nothing raises; it makes no assertions about the printed output.
        """
        def test_f(a, b):
            c = a.add(b)
            return c.add(c)

        x = torch.randn(4)
        y = torch.randn(4)

        torch_glow.enableFusionPass_DO_NOT_USE_THIS()
        torch_glow.enable_printing_jit_node_indices()

        # check_trace=False skips torch.jit.trace's re-run consistency check.
        # NOTE(review): despite the name, `graph` is the traced callable, not a
        # torch Graph object.
        graph = torch.jit.trace(test_f, (x, y), check_trace=False)
        graph(x, y)
    def test_shape_inference_unsupported_symbols_skip_fusion_group(self):
        """Test Glow shape inference's reporting of unsupported symbols,
        including skipping of symbols after the last fusion group.

        Fusion is restricted with start index 3 and end index 6, so some nodes
        fall outside the fusion group, including the trailing
        ``aten::chain_matmul``. The test then checks that shape inference
        reports that symbol only when ``skip_last_fusion_node`` is False.
        """

        def f(a, b):
            x1 = a * b
            x2 = x1 * b
            x3 = x2 * a
            x4 = x3 / b
            x5 = x4 / a
            x6 = x5 / b
            x7 = x6 * a
            x8 = x7 * b
            return x8 * torch.chain_matmul(x8, x8)

        torch_glow.enableFusionPass_DO_NOT_USE_THIS()
        torch_glow.setFusionStartIndex(3)
        torch_glow.setFusionEndIndex(6)

        a = torch.randn(5, 5)
        b = torch.randn(5, 5)

        try:
            jit_f = torch.jit.trace(f, (a, b))

            jit_f_graph = jit_f.graph_for(a, b)
        finally:
            # Clear the indices even if tracing or graph_for raises, so they
            # cannot leak into other tests.
            torch_glow.clearFusionIndices()

        args = (a, b)

        # Don't skip nodes after the last fusion node.
        # In this case, one of the nodes (chain_matmul) following the last
        # fusion node is not supported and should be reported.
        actual = torch_glow.glow_shape_inference_find_unsupported_symbols(
            jit_f_graph, args, skip_last_fusion_node=False
        )
        expected = [
            "aten::chain_matmul",
        ]
        self.assertEqual(set(expected), set(actual))

        # DO skip nodes after the last fusion node.
        # In this case, one of the nodes (chain_matmul) following the last
        # fusion node is not supported, but it is suppressed by the
        # skip_last_fusion_node flag.
        actual = torch_glow.glow_shape_inference_find_unsupported_symbols(
            jit_f_graph, args, skip_last_fusion_node=True
        )
        self.assertEqual(set(), set(actual))
    def test_quantized_cut(self):
        """Test cut quantized chunk in the middle.

        ``quantized::add_relu`` is blacklisted, so the quantized graph has to
        be cut around it. The test then checks two things:
        - every remaining top-level node is a Glow fusion group, a blacklisted
          op, or a constant;
        - the Interpreter-backend result matches eager PyTorch.

        The JIT profiling executor/mode flags and the fusion blacklist are
        process-global. They are restored in a ``finally`` so this test cannot
        leak them into other tests.
        """
        # Both setters return the previous value, which is restored below.
        old_profiling_executor = torch._C._jit_set_profiling_executor(False)
        old_profiling_mode = torch._C._jit_set_profiling_mode(False)

        def fun(a, b, c, d):
            q = torch.nn.quantized.Quantize(scale=1.0 / 21,
                                            zero_point=0,
                                            dtype=torch.quint8)
            dq = torch.nn.quantized.DeQuantize()
            a = q(a)
            b = q(b)
            c = q(c)
            d = q(d)
            adds = torch.ops.quantized.add(a, b, scale=1.0 / 17, zero_point=5)
            adds2 = torch.ops.quantized.add(c, d, scale=1.0 / 14, zero_point=4)
            res = torch.ops.quantized.add_relu(adds,
                                               adds2,
                                               scale=1.0 / 18,
                                               zero_point=6)
            res = torch.ops.quantized.add(res,
                                          res,
                                          scale=1.0 / 13,
                                          zero_point=7)
            res = dq(res)
            return res

        try:
            with torch.no_grad():
                a = torch.randn([5, 5])
                b = torch.randn([5, 5])
                c = torch.randn([5, 5])
                d = torch.randn([5, 5])
                res_torch = fun(a, b, c, d)
                torch_glow.enableFusionPass_DO_NOT_USE_THIS()
                # Cut using blacklist functionality
                blacklist = ["quantized::add_relu"]
                torch_glow.setFusionBlacklist(blacklist)
                torch_glow.setGlowBackend("Interpreter")
                traced_model = torch.jit.trace(fun, (a, b, c, d))
                for node in traced_model.graph_for(a, b, c, d).nodes():
                    kind = node.kind()
                    # Make sure the blacklist is working
                    assert (
                        kind == GLOW_FUSION_GROUP
                        or kind in blacklist
                        or kind == "prim::Constant"
                    ), f"Unexpected node kind outside Glow fusion groups: {kind}"
                res_glow = traced_model(a, b, c, d)
                assert torch.allclose(res_torch, res_glow), (
                    f"Glow result {res_glow} differs from PyTorch result {res_torch}"
                )
        finally:
            torch_glow.clearFusionBlacklist()
            torch._C._jit_set_profiling_executor(old_profiling_executor)
            torch._C._jit_set_profiling_mode(old_profiling_mode)
Example #7
0
def ephemeral_torchglow_settings(
    fp16=False,
    backend=DEFAULT_BACKEND,
    fusion=False,
    blocklist=None,
    accept_all_layouts=False,
):
    """Apply a set of torch_glow settings for the duration of one ``yield``.

    This is a generator meant to be driven as a context manager (presumably
    wrapped with ``contextlib.contextmanager`` where it is defined; confirm).

    Before yielding, it applies the requested settings. On exit, including
    when an exception propagates, it restores these to the values they had
    on entry: fp16 conversion, fp16 clipping, fused-fp16 conversion, the
    fusion pass, the backend and the fusion blocklist.

    Args:
        fp16: if true, enable fp16 conversion, fused fp16 conversion and fp16
            clipping; otherwise disable all three.
        backend: Glow backend name to select.
        fusion: if true, enable the Glow fusion pass; otherwise disable it.
        blocklist: iterable of op kinds to keep out of fusion; ``None``
            clears the blocklist.
        accept_all_layouts: enable (true) or disable (false) accepting all
            layouts.
    """
    old_fp16 = torch_glow.get_convert_to_fp16()
    old_clip = torch_glow.get_clip_fp16()
    old_convert_fused = torch_glow.get_convert_fused_to_fp16()
    old_backend = torch_glow.getGlowBackendName()
    old_blocklist = torch_glow.getFusionBlacklist()
    old_fusion = torch_glow.getFusionPassEnabled()
    try:
        if fusion:
            torch_glow.enableFusionPass_DO_NOT_USE_THIS()
        else:
            torch_glow.disableFusionPass()
        if fp16:
            torch_glow.enable_convert_to_fp16()
            torch_glow.enable_convert_fused_to_fp16()
            torch_glow.enable_clip_fp16()
        else:
            torch_glow.disable_convert_to_fp16()
            torch_glow.disable_convert_fused_to_fp16()
            torch_glow.disable_clip_fp16()
        if blocklist is None:
            torch_glow.clearFusionBlacklist()
        else:
            torch_glow.setFusionBlacklist(list(blocklist))
        if accept_all_layouts:
            torch_glow.enable_accept_all_layout()
        else:
            torch_glow.disable_accept_all_layout()
        torch_glow.setGlowBackend(backend)
        yield
    finally:
        if old_fp16:
            torch_glow.enable_convert_to_fp16()
        else:
            torch_glow.disable_convert_to_fp16()
        if old_clip:
            torch_glow.enable_clip_fp16()
        else:
            torch_glow.disable_clip_fp16()
        if old_convert_fused:
            torch_glow.enable_convert_fused_to_fp16()
        else:
            torch_glow.disable_convert_fused_to_fp16()
        if old_fusion:
            torch_glow.enableFusionPass_DO_NOT_USE_THIS()
        else:
            torch_glow.disableFusionPass()
        torch_glow.setGlowBackend(old_backend)
        torch_glow.setFusionBlacklist(old_blocklist)
        # NOTE(review): the accept-all-layouts flag is NOT restored because
        # no getter for it is read above; callers keep whatever value this
        # call set. TODO: save/restore it if torch_glow exposes a getter.
    def test_backend_specific_options(self):
        """Test loading backend specific options from YAML file.

        Writes a one-line YAML options file and passes its path to
        ``torch_glow.loadBackendSpecificOptions``. It then traces and runs a
        simple element-wise add with the Glow fusion pass enabled. The test
        passes if nothing raises.
        """
        def add_fn(a, b):
            return a.add(b)

        lhs = torch.randn(4)
        rhs = torch.randn(4)

        # Flush so the options are on disk before torch_glow reads the file by
        # name while the handle is still open.
        with tempfile.NamedTemporaryFile() as yaml_file:
            yaml_file.write(b"interpreter-memory: 4194304\n")
            yaml_file.flush()

            torch_glow.loadBackendSpecificOptions(yaml_file.name)
            torch_glow.enableFusionPass_DO_NOT_USE_THIS()
            traced_add = torch.jit.trace(add_fn, (lhs, rhs), check_trace=False)
            traced_add(lhs, rhs)
Example #9
0
    def test_op_index_blacklist(self):
        """Test Glow fuser index blacklisting mechanism.

        Sets the fusion start index to 3 and the end index to 6. It then checks
        that no ``aten::mul`` node and all three ``aten::div`` nodes of ``f``
        end up inside Glow fusion groups. The fusion indices are cleared in a
        ``finally`` so a failure during tracing cannot leak them into later
        tests.
        """
        def f(a, b):
            x1 = a * b
            x2 = x1 * b
            x3 = x2 * a
            x4 = x3 / b
            x5 = x4 / a
            x6 = x5 / b
            x7 = x6 * a
            x8 = x7 * b
            return x8

        torch_glow.enableFusionPass_DO_NOT_USE_THIS()
        torch_glow.setFusionStartIndex(3)
        torch_glow.setFusionEndIndex(6)

        try:
            a = torch.randn(5, 5)
            b = torch.randn(5, 5)

            jit_f = torch.jit.trace(f, (a, b))

            jit_f_graph = jit_f.graph_for(a, b)
        finally:
            # Clear the indices even if tracing or graph_for raises.
            torch_glow.clearFusionIndices()

        fused_muls = 0
        fused_divs = 0
        for node in jit_f_graph.nodes():
            if node.kind() == GLOW_FUSION_GROUP:
                glow_subgraph = node.g(SUBGRAPH_ATTR)
                # Separate name so the outer loop variable is not shadowed.
                for sub_node in glow_subgraph.nodes():
                    if sub_node.kind() == "aten::mul":
                        fused_muls += 1
                    if sub_node.kind() == "aten::div":
                        fused_divs += 1

        assert fused_muls == 0, "Expected no aten::muls to be fused"
        assert fused_divs == 3, "Expected all 3 aten::divs to be fused"
Example #10
0
@torch.jit.script
def foo(a, b):
    """Compute ``c / c**3`` with ``c = a * b``, as a chain of mul/div ops."""
    c = a.mul(b)
    a = c.mul(c)  # c**2 (rebinds the parameter `a`)
    a = c.mul(a)  # c**3
    d = c.div(a)
    return d


# Print the graph the JIT produces for `foo` on (x, y) before Glow is enabled.
# NOTE(review): x and y are defined outside this chunk; confirm their setup.
print("original jit ir")
print(foo.graph_for(x, y))

# Reference result, computed before the Glow fusion pass is turned on.
jit_res = foo(x, y)

# Turn on the Glow fusion pass for everything executed after this point.
torch_glow.enableFusionPass_DO_NOT_USE_THIS()


@torch.jit.script
def foo_glow(a, b):
    """Scripted wrapper that forwards its inputs to ``foo``.

    The script below prints this wrapper's graph and output so they can be
    compared with those of ``foo``.
    """
    return foo(a, b)


# Print the wrapper's graph now that the Glow fusion pass is enabled
# (presumably it shows Glow fusion groups; confirm by inspection).
print("glow jit ir")
print(foo_glow.graph_for(x, y))

# Same computation as `jit_res`, run with the Glow fusion pass enabled.
jit_glow_res = foo_glow(x, y)

# Print the reference (non-Glow) result for comparison with the Glow one.
print("jit_res")
print(jit_res)
print("jit_glow_res")