def test_to_glow_selective(self):
        """Selectively lower submodules of an eager model and compare outputs.

        The ``"foo.bar"`` and ``"qux"`` submodules of ``model`` are lowered with
        ``torch_glow.to_glow_selective``. Each map entry is a
        ``(compilation_spec, example_inputs)`` pair. The returned module is then
        traced, and its output must be ``allclose`` to the eager output of
        ``model``.
        """
        inputs = (torch.zeros(4) + 8, torch.zeros(4) + 7)
        torch_res = model(*inputs)

        # Presumably records the tensors each named submodule receives when
        # ``model`` runs on ``inputs``; they seed that submodule's spec.
        bar_inputs = torch_glow.get_submod_inputs(model, "foo.bar", inputs)
        qux_inputs = torch_glow.get_submod_inputs(model, "qux", inputs)

        glow_mod = torch_glow.to_glow_selective(
            model,
            {
                "foo.bar": (get_compilation_spec(bar_inputs), bar_inputs),
                "qux": (get_compilation_spec(qux_inputs), qux_inputs),
            },
            inplace=False,
        )

        glow_mod = torch.jit.trace(glow_mod, inputs)
        glow_res = glow_mod(*inputs)

        # A bare ``assert`` is stripped when Python runs with ``-O``. The
        # TestCase assertion always executes.
        self.assertTrue(torch.allclose(torch_res, glow_res))
    def test_to_glow_selective_already_scripted(self):
        """Selectively lower submodules of an already-traced model.

        ``model`` is traced under ``torch.no_grad()`` before lowering. Unlike
        the eager-model case, the submodule map here carries only compilation
        specs, with no example inputs. The lowered module is called directly,
        and its output must be ``allclose`` to the eager output of ``model``.
        """
        inputs = (torch.zeros(4) + 8, torch.zeros(4) + 7)
        torch_res = model(*inputs)

        # Submodule inputs are gathered from the eager ``model``, before tracing.
        bar_inputs = torch_glow.get_submod_inputs(model, "foo.bar", inputs)
        qux_inputs = torch_glow.get_submod_inputs(model, "qux", inputs)

        with torch.no_grad():
            traced_model = torch.jit.trace(model, inputs)

        glow_mod = torch_glow.to_glow_selective(
            traced_model,
            {
                "foo.bar": get_compilation_spec(bar_inputs),
                "qux": get_compilation_spec(qux_inputs),
            },
            inplace=False,
        )
        glow_res = glow_mod(*inputs)

        # A bare ``assert`` is stripped when Python runs with ``-O``. The
        # TestCase assertion always executes.
        self.assertTrue(torch.allclose(torch_res, glow_res))
    # Example #3
    def test_input_spec(self):
        """Lower a traced ``TestModule`` two ways and compare with TorchScript.

        Everything runs under ``torch.no_grad()``. The whole traced module is
        first lowered with ``torch_glow.to_glow``. Then only its ``"add"``
        submodule is lowered with ``torch_glow.to_glow_selective``. Each Glow
        result must be ``allclose`` to the TorchScript reference output.
        """
        with torch.no_grad():
            lhs = torch.tensor([[0.1]])
            rhs = torch.tensor([[0.1]])
            example = (lhs, rhs)

            mod = TestModule()
            traced_model = torch.jit.trace(mod, example)
            expected = traced_model(*example)

            # Non-quantized input: compile the entire traced module.
            whole_spec = get_compilation_spec(example)
            lowered = torch_glow.to_glow(traced_model, whole_spec)
            self.assertTrue(torch.allclose(expected, lowered(*example)))

            # "Quantized input" case: lower only the "add" submodule.
            # NOTE(review): nothing visible here sets a quantized spec
            # explicitly; presumably the tensors reaching "add" inside
            # TestModule are quantized -- confirm against TestModule.
            add_inputs = torch_glow.get_submod_inputs(mod, "add", example)
            add_spec = get_compilation_spec(add_inputs)
            lowered = torch_glow.to_glow_selective(traced_model, {"add": add_spec})
            self.assertTrue(torch.allclose(expected, lowered(*example)))