def test_tensor_attribute(self):
    class TensorAttribute(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.tensor = torch.rand(3, 4)

        def forward(self, x):
            return torch.nn.functional.linear(x, self.tensor)

    ta = TensorAttribute()
    traced = symbolic_trace(ta)
    traced(torch.rand(4, 4))

    class WrapperForQualname(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.ta = TensorAttribute()

        def forward(self, x):
            return torch.nn.functional.linear(x, self.ta.tensor)

    wfq = WrapperForQualname()
    traced2 = symbolic_trace(wfq)
    traced2.graph.lint(traced2)
    traced2(torch.rand(4, 4))
def test_pretty_print(self):
    st = SimpleTest()
    traced = symbolic_trace(st)
    traced.graph.lint(traced)
    printed = str(traced)
    assert 'GraphModuleImpl()' in printed
    assert 'torch.relu' in printed
def test_resnet(self):
    resnet = resnet18()
    resnet.train()

    res_graph = symbolic_trace(resnet)
    res_script = torch.jit.script(res_graph)

    ip = torch.rand(1, 3, 224, 224)

    a = resnet(ip)
    b = res_graph(ip)
    c = res_script(ip)
    self.assertEqual(a, b)
    self.assertEqual(a, c)

    quantizer = Quantizer(res_graph)

    for i in range(10):
        quantizer.observe((torch.rand(1, 3, 224, 224),))

    qgraph = quantizer.quantize()
    qgraph.graph.lint(qgraph)
    qgraph_script = torch.jit.script(qgraph)

    d = qgraph(ip)
    e = qgraph_script(ip)

    # Quantization introduces numerical error, so only check against
    # a loose absolute tolerance
    assert (a - d).abs().max() < 2
    self.assertEqual(d, e)
def test_unpack_dict_better_error(self):
    class SomeKwargs(torch.nn.Module):
        def forward(self, x=3, y=4):
            return torch.rand(3, 4)

    class UnpacksDict(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.sk = SomeKwargs()

        def forward(self, x: dict):
            return self.sk(**x)

    ud = UnpacksDict()
    with self.assertRaisesRegex(TraceError, 'Proxy object cannot be unpacked as function argument'):
        symbolic_trace(ud)
def test_tensor_constant(self):
    class ConstTensor(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.linear(x, torch.zeros(3, 4))

    ct = ConstTensor()
    traced = symbolic_trace(ct)
    traced.graph.lint(traced)
    traced(torch.rand(4, 4))
def test_unpack_list_better_error(self):
    class SomeArgs(torch.nn.Module):
        def forward(self, a, b):
            return torch.rand(3, 4)

    class UnpacksList(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.sa = SomeArgs()

        def forward(self, x: list):
            return self.sa(*x)

    ul = UnpacksList()
    with self.assertRaisesRegex(TraceError, 'Proxy object cannot be unpacked as function argument'):
        symbolic_trace(ul)
def test_copy_no_remap(self):
    traced = symbolic_trace(SimpleTest())
    g = traced.graph
    copied = torch._fx.Graph()
    for node in g.nodes:
        copied.node_copy(node)
    with self.assertRaisesRegex(RuntimeError, 'does not belong to this Graph'):
        copied.lint()
def test_symbolic_trace_sequential(self):
    class Simple(torch.nn.Module):
        def forward(self, x):
            return torch.neg(x)

    seq = torch.nn.Sequential(Simple(), Simple(), Simple())
    traced = symbolic_trace(seq)
    traced.graph.lint(traced)
    x = torch.rand(3, 4)
    self.assertEqual(traced(x), seq(x))
def checkGraphModule(self, m: torch.nn.Module, args, kwargs=None):
    """Check that an nn.Module's results match the GraphModule version
    for a given set of args/kwargs.
    """
    kwargs = kwargs if kwargs else {}
    ref_outs = m(*args, **kwargs)
    gm = symbolic_trace(m)
    gm.graph.lint(gm)
    test_outs = gm(*args, **kwargs)
    self.assertEqual(ref_outs, test_outs)
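# A minimal usage sketch for checkGraphModule above. This test is hypothetical
# (not part of the original suite) and only illustrates the intended pattern:
# pass a module plus one set of args, and the helper traces it and compares
# the GraphModule's outputs against eager mode.
def test_check_graph_module_usage_sketch(self):
    class AddMul(torch.nn.Module):
        def forward(self, x, y):
            return x * y + y

    self.checkGraphModule(AddMul(), (torch.rand(3, 4), torch.rand(3, 4)))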
def test_reserved_getattr(self):
    """Ensure that we do not name any nodes with a reserved builtin like `getattr`"""
    class M(torch.nn.Module):
        def forward(self, a):
            return a.foo.bar.baz

    m = M()
    m_g = symbolic_trace(m)
    m_g.graph.lint(m_g)
    for node in m_g.graph.nodes:
        self.assertTrue(node.name != "getattr")
def test_pretty_print_graph(self):
    class KwargPrintTest(torch.nn.Module):
        def forward(self, x):
            return torch.squeeze(x + 3.0, dim=2)

    st = KwargPrintTest()
    traced = symbolic_trace(st)
    traced.graph.lint(traced)
    stringed = str(traced.graph)
    for s in ['args', 'kwargs', 'uses']:
        assert s in stringed
def test_graph_module(self):
    class MySub(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.w = torch.nn.Parameter(torch.rand(4, 3))

        def forward(self, x):
            return self.w + x

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(4, 3)
            self.sub_mod = MySub()
            self.w = torch.nn.Parameter(torch.rand(3))

        def forward(self, A, B, c):
            t = torch.sigmoid(A) + self.lin(c)
            return self.sub_mod(t.data + self.w + t + 1 - A + B // A + -A + A.add(B, alpha=3))

    m = MyModule()
    gm = symbolic_trace(m)

    ms = torch.jit.script(gm)

    class M2(torch.nn.Module):
        def forward(self, A):
            m, idx = torch.max(A, 0)
            return m + 1, idx + 1

    m2 = M2()
    gm2 = symbolic_trace(m2)

    class T(torch.nn.Module):
        def forward(self, A, b=4, *args, c=5, **kwargs):
            x = A + 1 + args[0] + kwargs['3']
            return x

    t = T()
    symbolic_trace(t)
def test_torch_custom_ops(self):
    class M(torch.nn.Module):
        def forward(self, a):
            b = torch.ops.aten.sigmoid(a)
            c = torch.ops.aten.cat([a, b])
            return torch.ops.aten.cat((c, c))

    m = M()
    input = torch.randn(3)
    ref_out = m(input)
    gm = symbolic_trace(m)
    gm.graph.lint(gm)
    out = gm(input)
    self.assertEqual(out, ref_out)
def test_graph_edit_with_proxy(self):
    class M(torch.nn.Module):
        def forward(self, a, b):
            return a + b

    m = M()
    g = symbolic_trace(m).graph
    new_g = torch._fx.Graph()
    new_g.graph_copy(g)
    # Test that we can use proxy objects to generate more graph code later,
    # for things that do not need to work with modules
    t = Proxy(new_g.nodes[-1])
    new_g.output((t + t).node)
    gm = GraphModule(m, new_g)
    gm.graph.lint(gm)
    self.assertEqual(gm(3, 4), 14)
def test_symbolic_trace_assert(self):
    message = "assert_foobar"

    class AssertsTensorShape(torch.nn.Module):
        def forward(self, x):
            torch.Assert(x.shape[1] > 4, message)
            return x

    m = AssertsTensorShape()
    # Verify traceability
    traced = symbolic_trace(m)
    # Verify the assertion on the traced module works correctly at runtime
    traced(torch.rand(4, 5))
    with self.assertRaisesRegex(AssertionError, message):
        traced(torch.rand(4, 3))
def test_pickle_graphmodule(self):
    class Nested(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.st = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.st(x)

    n = Nested()
    traced = symbolic_trace(n)
    traced.graph.lint(traced)
    pickled = pickle.dumps(traced)
    loaded = pickle.loads(pickled)
    loaded.graph.lint(loaded)
    x = torch.rand(3, 4)
    self.assertEqual(loaded(x), traced(x))
def test_graph_unique_names(self):
    class M(torch.nn.Module):
        def forward(self, a, b):
            return a + b

    m = M()
    g = symbolic_trace(m).graph
    new_g = torch._fx.Graph()
    new_g.graph_copy(g)
    # Test that we can use proxy objects to generate more graph code later,
    # for things that do not need to work with modules
    t = Proxy(new_g.nodes[-1])
    new_g.output((t + t).node)
    gm = GraphModule(m, new_g)
    seen_names: Set[str] = set()
    for node in gm.graph.nodes:
        assert node.name not in seen_names
        seen_names.add(node.name)
def test_replace_target_nodes_with(self):
    class testModule(torch.nn.Module):
        def forward(self, a, b):
            return a + b

    m = testModule()
    traced = symbolic_trace(m)
    input1 = torch.randn(1)
    input2 = torch.randn(1)
    assert (input1 + input2) == traced(input1, input2)
    GraphManipulation.replace_target_nodes_with(
        fx_module=traced,
        old_op="call_function",
        old_target=operator.add,
        new_op="call_function",
        new_target=operator.mul,
    )
    assert (input1 * input2) == traced(input1, input2)
def test_deepcopy_graphmodule_with_transform(self):
    st = SimpleTest()
    traced = symbolic_trace(st)
    traced.graph.lint(traced)

    def transform(traced):
        new_graph = torch._fx.Graph()
        new_graph.graph_copy(traced.graph)
        relu_out = new_graph.create_node(
            op='call_method', target='neg', args=(new_graph.nodes[-1],), kwargs={})
        new_graph.output(relu_out)
        return GraphModule(traced, new_graph)

    transformed = transform(traced)
    transformed.graph.lint(transformed)
    copied = copy.deepcopy(transformed)
    self.assertNotEqual(id(type(transformed)), id(type(copied)))
    x = torch.randn(3, 4)
    self.assertEqual(copied(x), transformed(x))
def test_deepcopy_with_submods_params(self):
    class Bar(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(3, 4))

        def forward(self, x):
            return torch.relu(x) + self.param

    class Baz(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(3, 4))
            self.bar = Bar()

        def forward(self, x):
            return self.bar(x) - self.param

    baz = Baz()
    traced = symbolic_trace(baz)
    traced.graph.lint(traced)
    copied = copy.deepcopy(traced)
    copied.graph.lint(copied)
def lower_to_elementwise_interpreter(orig_mod: torch.nn.Module) -> torch.nn.Module:
    # ===== Stage 1: Symbolic trace the module =====
    mod = symbolic_trace(orig_mod)

    # ===== Stage 2: Lower GraphModule representation to the C++
    #       interpreter's instruction format =====
    instructions = []
    constant_idx = 0
    constants = {}
    fn_input_names = []

    target_to_name = {operator.add: "add", operator.mul: "mul"}

    # For each instruction, create a triple
    # (instruction_name : str, inputs : List[str], output : str)
    # to feed into the C++ interpreter
    for n in mod.graph.nodes:
        target, args, out_name = n.target, n.args, n.name
        assert len(n.kwargs) == 0, "kwargs currently not supported"

        if n.op == 'placeholder':
            # Placeholders specify function argument names. Save these
            # for later when we generate the wrapper GraphModule
            fn_input_names.append(target)
        elif n.op == 'call_function':
            assert target in target_to_name, f"Unsupported call target {target}"
            arg_names = []
            for arg in args:
                if not isinstance(arg, Node):
                    # Pull out constants. These constants will later be
                    # fed to the interpreter C++ object via add_constant()
                    arg_name = f'constant_{constant_idx}'
                    constants[arg_name] = torch.Tensor(
                        [arg] if isinstance(arg, numbers.Number) else arg)
                    arg_names.append(arg_name)
                    constant_idx += 1
                else:
                    arg_names.append(arg.name)
            instructions.append((target_to_name[target], arg_names, out_name))
        else:
            raise RuntimeError('Unsupported opcode ' + n.op)

    interpreter = torch.classes._TorchScriptTesting._ElementwiseInterpreter()
    # Load constants
    for k, v in constants.items():
        interpreter.add_constant(k, v)
    # Specify names for positional input arguments
    interpreter.set_input_names(fn_input_names)
    # Load instructions
    interpreter.set_instructions(instructions)
    # Specify name for single output
    interpreter.set_output_name(mod.graph.result.name)

    # ===== Stage 3: Create a wrapper GraphModule around the interpreter =====
    class WrapperModule(torch.nn.Module):
        def __init__(self, interpreter):
            super().__init__()
            self.interpreter = interpreter

    wrapper = WrapperModule(interpreter)

    # Create a graph that: 1) Takes function arguments 2) Invokes the interpreter
    # 3) Returns the specified return value

    # FIXME: The following code could be greatly simplified by symbolic_trace'ing
    # the wrapper with a Tracer that considers the Wrapper instance a root
    # module, however, I can't get `__call__` exposed on TorchBind classes
    # without it messing up Python `hasattr` for some reason. More digging
    # into CPython's implementation of hasattr is probably in order...

    graph = torch._fx.Graph()

    # Add placeholders for fn inputs
    placeholder_nodes = []
    for name in fn_input_names:
        placeholder_nodes.append(graph.create_node('placeholder', name))

    # Get the interpreter object
    interpreter_node = graph.create_node('get_attr', 'interpreter')

    # Add a node to call the interpreter instance
    output_node = graph.create_node(
        op='call_method', target='__call__', args=(interpreter_node, placeholder_nodes))

    # Register output
    graph.output(output_node)

    graph.lint(wrapper)

    # Return final GraphModule!!!
    return GraphModule(wrapper, graph)
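# A minimal usage sketch for lower_to_elementwise_interpreter above. This is a
# hypothetical example, not part of the original suite, and it assumes the
# _ElementwiseInterpreter TorchBind test class has already been registered
# (e.g. by loading the TorchScript testing extension library).
class _AddMulModule(torch.nn.Module):
    def forward(self, x, y):
        return x * y + y

def _elementwise_interpreter_usage_sketch():
    # Lower a mul/add-only module to the interpreter-backed GraphModule
    lowered = lower_to_elementwise_interpreter(_AddMulModule())
    x, y = torch.rand(3, 4), torch.rand(3, 4)
    # The lowered module should match eager-mode results
    assert torch.allclose(lowered(x, y), x * y + y)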