def test_to_glow_selective(self):
    """Selectively lower two submodules of ``model`` to Glow and compare outputs.

    The ``foo.bar`` and ``qux`` submodules are lowered with
    ``torch_glow.to_glow_selective(..., inplace=False)``. The returned module
    is then traced with ``torch.jit.trace``. Its output must match the output
    of ``model`` itself on the same inputs.
    """
    # Two constant length-4 inputs: all 8s and all 7s.
    inputs = (torch.zeros(4) + 8, torch.zeros(4) + 7)
    torch_res = model(*inputs)

    # Presumably captures the inputs each named submodule receives while
    # `model` runs on `inputs`. They feed the per-submodule compilation specs.
    bar_inputs = torch_glow.get_submod_inputs(model, "foo.bar", inputs)
    qux_inputs = torch_glow.get_submod_inputs(model, "qux", inputs)

    # `model` is passed as-is (not traced first), and every entry pairs the
    # compilation spec with the example inputs. Presumably the inputs are
    # needed to trace each submodule; confirm against to_glow_selective.
    # inplace=False presumably leaves `model` untouched for other tests.
    glow_mod = torch_glow.to_glow_selective(
        model,
        {
            "foo.bar": (get_compilation_spec(bar_inputs), bar_inputs),
            "qux": (get_compilation_spec(qux_inputs), qux_inputs),
        },
        inplace=False,
    )
    glow_mod = torch.jit.trace(glow_mod, inputs)
    glow_res = glow_mod(*inputs)

    # Use the TestCase assertion instead of a bare `assert`. A bare assert is
    # stripped when Python runs with -O and would let the test pass silently.
    self.assertTrue(
        torch.allclose(torch_res, glow_res),
        "Glow-lowered module output differs from the eager model output",
    )
def test_to_glow_selective_already_scripted(self):
    """Selectively lower submodules of a traced copy of ``model`` to Glow.

    ``model`` is first traced with ``torch.jit.trace`` under
    ``torch.no_grad()``. The ``foo.bar`` and ``qux`` submodules of the traced
    module are then lowered with ``torch_glow.to_glow_selective(...,
    inplace=False)``. The result must match the output of ``model`` on the
    same inputs.
    """
    # Two constant length-4 inputs: all 8s and all 7s.
    inputs = (torch.zeros(4) + 8, torch.zeros(4) + 7)
    torch_res = model(*inputs)

    # Presumably captures the inputs each named submodule receives while
    # `model` runs on `inputs`. They feed the per-submodule compilation specs.
    bar_inputs = torch_glow.get_submod_inputs(model, "foo.bar", inputs)
    qux_inputs = torch_glow.get_submod_inputs(model, "qux", inputs)

    with torch.no_grad():
        traced_model = torch.jit.trace(model, inputs)

    # The mapping values here are compilation specs only, with no example
    # inputs. Presumably the traced graph already carries what is needed;
    # confirm against to_glow_selective.
    glow_mod = torch_glow.to_glow_selective(
        traced_model,
        {
            "foo.bar": get_compilation_spec(bar_inputs),
            "qux": get_compilation_spec(qux_inputs),
        },
        inplace=False,
    )
    glow_res = glow_mod(*inputs)

    # Use the TestCase assertion instead of a bare `assert`. A bare assert is
    # stripped when Python runs with -O and would let the test pass silently.
    self.assertTrue(
        torch.allclose(torch_res, glow_res),
        "Glow-lowered traced module output differs from the eager model output",
    )
def test_input_spec(self):
    """Test setting quantized and non-quantized input specs.

    A traced ``TestModule`` is lowered to Glow in two ways: as a whole via
    ``torch_glow.to_glow``, and for its ``add`` submodule only via
    ``torch_glow.to_glow_selective``. Each lowered module must reproduce the
    traced module's output on the same two 1x1 inputs.
    """
    with torch.no_grad():
        lhs = torch.tensor([[0.1]])
        rhs = torch.tensor([[0.1]])
        example = (lhs, rhs)

        module = TestModule()
        traced = torch.jit.trace(module, example)
        expected = traced(*example)

        def assert_matches_reference(glow_module):
            # Run the lowered module and compare it against the traced result.
            actual = glow_module(*example)
            self.assertTrue(torch.allclose(expected, actual))

        # Non-quantized input case: lower the whole traced module.
        assert_matches_reference(
            torch_glow.to_glow(traced, get_compilation_spec(example))
        )

        # Quantized input case: lower only the "add" submodule, with a spec
        # built from the inputs that submodule receives. The quantization
        # itself presumably happens inside TestModule, which is not visible here.
        add_inputs = torch_glow.get_submod_inputs(module, "add", example)
        assert_matches_reference(
            torch_glow.to_glow_selective(
                traced, {"add": get_compilation_spec(add_inputs)}
            )
        )