def test_pretty_print(self):
    """Check that str() of a traced module shows the GraphModule header and the traced op."""
    module_under_test = SimpleTest()
    graph_module = symbolic_trace(module_under_test)
    rendered = str(graph_module)
    # Both the synthesized class header and the traced call must appear.
    for expected_fragment in ('GraphModuleImpl()', 'torch.relu'):
        assert expected_fragment in rendered
by creating a custom Tracer and overriding `is_leaf_module`. In this case, we'll keep the default behavior for all `torch.nn` Modules except for `ReLU`. """ class M1(torch.nn.Module): def __init__(self): super().__init__() self.relu = torch.nn.ReLU() def forward(self, x): return self.relu(x) default_traced: GraphModule = symbolic_trace(M1()) """ Tracing with the default tracer and calling `print_tabular` produces: opcode name target args kwargs ----------- ------ -------- --------- -------- placeholder x x () {} call_module relu_1 relu (x,) {} output output output (relu_1,) {} """ default_traced.graph.print_tabular() class LowerReluTracer(Tracer): def is_leaf_module(self, m: torch.nn.Module, qualname: str):
compiled_f = nnc_compile(fx_model, args, get_loopnest=True) return compiled_f return wrapped ################################ # Example usage and Benchmarking ################################ def bench(f, warmup=3, iters=1000): for _ in range(warmup): f() begin = time.time() for _ in range(iters): f() print(time.time() - begin) if __name__ == '__main__': def f(a, b): return (torch.cos(a) * torch.sin(b))[:2000] mod = fx.symbolic_trace(f) inps = (torch.randn(5000), torch.randn(5000)) ShapeProp(mod).propagate(*inps) cg = nnc_compile(mod, inps) bench(lambda: cg(*inps)) bench(lambda: f(*inps))
def lower_to_elementwise_interpreter(
        orig_mod: torch.nn.Module) -> torch.nn.Module:
    """Lower ``orig_mod`` to the C++ ``_ElementwiseInterpreter`` TorchBind class.

    Symbolically traces the module, converts its add/mul call_function
    nodes into the interpreter's (name, inputs, output) instruction
    triples, and wraps the resulting interpreter in a new GraphModule
    whose forward simply invokes it.
    """
    # ===== Stage 1: Symbolic trace the module =====
    mod = symbolic_trace(orig_mod)

    # ===== Stage 2: Lower GraphModule representation to the C++
    # interpreter's instruction format ======
    instructions = []
    constant_idx = 0
    constants = {}
    fn_input_names = []

    # Only elementwise add/mul are supported by the interpreter.
    target_to_name = {operator.add: "add", operator.mul: "mul"}

    output_node: Optional[Node] = None
    # For each instruction, create a triple
    # (instruction_name : str, inputs : List[str], output : str)
    # to feed into the C++ interpreter
    for n in mod.graph.nodes:
        target, args, out_name = n.target, n.args, n.name
        assert len(n.kwargs) == 0, "kwargs currently not supported"

        if n.op == 'placeholder':
            # Placeholders specify function argument names. Save these
            # for later when we generate the wrapper GraphModule
            fn_input_names.append(target)
        elif n.op == 'call_function':
            assert target in target_to_name, "Unsupported call target " + target
            arg_names = []
            for arg in args:
                if not isinstance(arg, Node):
                    # Pull out constants. These constants will later be
                    # fed to the interpreter C++ object via add_constant()
                    arg_name = f'constant_{constant_idx}'
                    constants[arg_name] = torch.Tensor(
                        [arg] if isinstance(arg, numbers.Number) else arg)
                    arg_names.append(arg_name)
                    constant_idx += 1
                else:
                    arg_names.append(arg.name)
            instructions.append(
                (target_to_name[target], arg_names, out_name))
        elif n.op == 'output':
            if output_node is not None:
                raise RuntimeError('Multiple output nodes!')
            output_node = n
        else:
            raise RuntimeError('Unsupported opcode ' + n.op)

    interpreter = torch.classes._TorchScriptTesting._ElementwiseInterpreter()
    # Load constants
    for k, v in constants.items():
        interpreter.add_constant(k, v)
    # Specify names for positional input arguments
    interpreter.set_input_names(fn_input_names)
    # Load instructions
    interpreter.set_instructions(instructions)
    # Specify name for single output
    assert isinstance(output_node.args[0], torch.fx.Node)
    interpreter.set_output_name(output_node.args[0].name)

    # ===== Stage 3: Create a wrapper GraphModule around the interpreter =====
    class WrapperModule(torch.nn.Module):
        # Holds the interpreter as a submodule attribute so the graph
        # below can fetch it with a get_attr node.
        def __init__(self, interpreter):
            super().__init__()
            self.interpreter = interpreter

    wrapper = WrapperModule(interpreter)

    # Create a graph that: 1) Takes function arguments 2) Invokes the interpreter
    # 3) Returns the specified return value
    # FIXME: The following code could be greatly simplified by symbolic_trace'ing
    # the wrapper with a Tracer that considers the Wrapper instance a root
    # module, however, I can't get `__call__` exposed on TorchBind classes
    # without it messing up Python `hasattr` for some reason. More digging
    # into CPython's implementation of hasattr is probably in order...
    graph = torch.fx.Graph()
    # Add placeholders for fn inputs
    placeholder_nodes = []
    for name in fn_input_names:
        placeholder_nodes.append(graph.create_node('placeholder', name))

    # Get the interpreter object
    interpreter_node = graph.create_node('get_attr', 'interpreter')

    # Add a node to call the interpreter instance
    output_node = graph.create_node(op='call_method',
                                    target='__call__',
                                    args=(interpreter_node, placeholder_nodes))

    # Register output
    graph.output(output_node)

    graph.lint(wrapper)

    # Return final GraphModule!!!
    return GraphModule(wrapper, graph)
def test_nonetype_annotation(self): eb = torch.nn.EmbeddingBag(3, 4) symbolic_trace(eb)
def test_fn_type_annotation_empty(self): def forward(a: List[torch.Tensor]): return a[0] torch.jit.script(symbolic_trace(forward))
def nnc_compile(model: torch.nn.Module, example_inputs) -> torch.nn.Module:
    """
    nnc_compile(model, example_inputs) returns a function with the same args
    as `model.forward`, with an extra argument corresponding to where the
    output is stored. This function takes the inputs (which must be PyTorch
    tensors with the same shapes as example_inputs), and passes them to an
    NNC executor.
    """
    fx_model = fx.symbolic_trace(model)
    # ShapeProp annotates each node with .shape/.dtype, which we rely on below.
    ShapeProp(fx_model).propagate(*example_inputs)

    # This env maps from nodes to `te.ExprHandle`, which represent the output
    # of an NNC computation.
    env = {}

    def get_te_shapes(node):
        # Shapes come from the ShapeProp pass above.
        return [te.ExprHandle.int(i) for i in node.shape]

    def get_nnc_type(dtype):
        # Minimal dtype mapping; anything else is not yet implemented.
        if dtype == torch.float:
            return te.Dtype.Float
        elif dtype == torch.long:
            return te.Dtype.Long
        else:
            raise RuntimeError("nyi")

    def get_te_type(node):
        return get_nnc_type(node.dtype)

    def lookup_env(l):
        # Map fx.Nodes (possibly nested in lists/tuples) to their NNC values.
        return fx.node.map_aggregate(
            l, lambda x: env[x.name] if isinstance(x, fx.Node) else x)

    def fetch_attr(target: str):
        # Resolve a dotted get_attr target (e.g. "sub.weight") on fx_model.
        target_atoms = target.split('.')
        attr_itr = fx_model
        for i, atom in enumerate(target_atoms):
            if not hasattr(attr_itr, atom):
                raise RuntimeError(
                    f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}")
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    outs = None
    inputs = []
    module_attrs = []
    for node in fx_model.graph.nodes:
        if node.op == 'placeholder':
            # We simply map the input placeholder to a `te.Placeholder`, which
            # also represents an input to the NNC computation.
            shapes = get_te_shapes(node)
            env[node.name] = te.Placeholder(node.name, get_te_type(node), shapes)
            inputs.append(env[node.name])
        elif node.op == 'call_function':
            # This does the bulk of the work - we call `lower_function`, which
            # returns a `te.ExprHandle` (the output of a NNC computation), and
            # put it in our environment.
            result = lower_function(node, node.target, lookup_env(node.args),
                                    node.args)
            env[node.name] = result
        elif node.op == 'output':
            outs = list(lookup_env(node.args))
        elif node.op == 'get_attr':
            # As NNC doesn't have any concept of state, we pull out the module
            # attributes and pass them in as inputs to NNC.
            module_attrs.append(node)
            # BUG FIX: this branch previously reused the stale `shapes` local
            # left over from the last 'placeholder' node; compute this
            # attribute's own shape instead.
            shapes = get_te_shapes(node)
            env[node.name] = te.Placeholder(node.name, get_te_type(node), shapes)
        else:
            raise RuntimeError("not yet implemented")

    loopnest = te.LoopNest(outs)
    loopnest.prepare_for_codegen()
    stmt = te.simplify(loopnest.root_stmt())
    # Buffer order must match the call order in `f` below:
    # module attributes, then inputs, then outputs.
    cg = te.construct_codegen('llvm', stmt, [
        te.BufferArg(x)
        for x in [env[i.name] for i in module_attrs] + inputs + outs
    ])

    def f(inps):
        # `inps` must include the output buffer(s) after the regular inputs.
        module_stuff = [fetch_attr(i.target) for i in module_attrs]
        cg.call(module_stuff + list(inps))

    return f
def test_general_shape_ops(self):
    """ A test that checks dequantize will be swapped for
    all supported general shape ops like aten::flatten
    without actually checking for execution of these ops
    """
    class M(torch.nn.Module):
        # Holder module: the forward below chains every supported
        # "general shape" op; it is traced/quantized but never executed.
        def __init__(self):
            super(M, self).__init__()
            self.maxpool1d = torch.nn.MaxPool1d(kernel_size=3)
            self.maxpool2d = torch.nn.MaxPool2d(kernel_size=3)
            self.maxpool3d = torch.nn.MaxPool3d(kernel_size=3)
            self.dropout = torch.nn.Dropout()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 3)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            x = self.conv1(x)
            # add_scalar
            x = x + 3
            # mul_scalar
            x = x * 3
            # add_scalar_out
            x += 3
            # mul_scalar_out
            x *= 3
            # add_scalar_relu
            x = x + 3
            x = F.relu(x)
            # add_scalar_relu_out
            x += 3
            x = F.relu(x)
            # mul_scalar_relu
            x = x * 3
            x = F.relu(x)
            # mul_scalar_relu_out
            x *= 3
            x = F.relu(x)
            x = self.maxpool1d(x)
            x = self.maxpool2d(x)
            x = self.maxpool3d(x)
            x = torch.flatten(x)
            x = torch.max(x)
            x = torch.min(x)
            x = x.reshape([-1])
            x = x.resize_(1, 1, x.numel())
            x = x.view(-1)
            # prim::ListConstruct
            xs = [x, x]
            # prim::ListUnpack
            x, y = xs
            # prim::TupleConstruct
            xs = (x, x)
            # prim::TupleUnpack
            x, y = xs
            x = x.transpose(1, 2)
            x = x.contiguous()
            x, y = torch.chunk(x, 2)
            x = F.dropout(x)
            x = self.dropout(x)
            x, _ = torch.sort(x)
            x = x.permute(0, 2, 3, 1)
            x = x.repeat_interleave(3, 1)
            x = torch.repeat_interleave(x, 3, 1)
            x = self.relu(x)
            x = F.relu(x)
            x = F.relu(x, inplace=True)
            x = x.relu()
            x.relu_()
            x = x.squeeze(0)
            x.squeeze_(0)
            x = torch.squeeze(x, 0)
            x = x.unsqueeze(0)
            x.unsqueeze_(0)
            x = torch.unsqueeze(x, 0)
            x = x.detach()
            x.detach_()
            x = x.repeat(4, 2)
            y = []
            y.append(x)
            z = torch.stack(y, 0)
            z = [z, z]
            x, _ = z
            x = self.conv2(x)
            return x

    # NOTE(review): `data` is currently unused — the model is never run.
    data = torch.rand(1, 3, 10, 10)
    # This model is not executable since we just put all ops
    # in the same forward
    m = M()
    original = symbolic_trace(m)
    # nothing to fuse so skipping the fuse step
    quantizer = Quantizer()
    qconfig_dict = {'': default_qconfig}
    prepared = quantizer.prepare(original, qconfig_dict)
    # not runnable
    quantized = quantizer.convert(prepared)

    # This checks that the dequantize from the output of first conv
    # is being propagated to the end, so that we don't insert extra
    # observers and also successfully fused two quantized::conv2d
    # patterns
    # one quantize_per_tensor for input
    # check exact counts of quantize and dequantize
    count_check = {
        ns.call_function(torch.quantize_per_tensor): 1,
        ns.call_method('dequantize'): 1
    }
    # Expected node order: quantize -> conv -> conv -> dequantize.
    order_check = [
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Conv2d),
        ns.call_module(nnq.Conv2d),
        ns.call_method('dequantize'),
    ]
    self.checkGraphModuleNodes(quantized,
                               expected_node_occurrence=count_check,
                               expected_node_list=order_check)
def test_general_value_ops(self):
    """ A test that checks correct patterns are produced for
    all supported general value ops like aten::avg_pool2d
    without actually checking for execution of these ops
    """
    class M(torch.nn.Module):
        # Holder module: forward chains every supported "general value"
        # op; it is traced/quantized but never executed.
        def __init__(self):
            super(M, self).__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.avg_pool1d = torch.nn.AvgPool1d(3)
            self.avg_pool2d = torch.nn.AvgPool2d(3)
            self.avg_pool3d = torch.nn.AvgPool3d(3)
            self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d((1))
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
            self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d((1, 1, 1))
            self.leaky_relu = torch.nn.LeakyReLU()
            self.hardsigmoid = torch.nn.Hardsigmoid()
            self.sigmoid = torch.nn.Sigmoid()
            self.tanh = torch.nn.Tanh()

        def forward(self, x):
            x = self.conv(x)
            x = self.avg_pool1d(x)
            x = self.avg_pool2d(x)
            x = self.avg_pool3d(x)
            x = self.adaptive_avg_pool1d(x)
            x = self.adaptive_avg_pool2d(x)
            x = self.adaptive_avg_pool3d(x)
            x = F.avg_pool1d(x, 3)
            x = F.avg_pool2d(x, 3)
            x = F.avg_pool3d(x, 3)
            x = F.adaptive_avg_pool1d(x, (1))
            x = F.adaptive_avg_pool2d(x, (1, 1))
            x = F.adaptive_avg_pool3d(x, (1, 1, 1))
            x = torch.mean(x)
            x = torch.mean(x, [2, 3], False)
            x = x.mean()
            x = x.mean([2, 3], True)
            x = F.interpolate(x, 4, mode='nearest')
            x = F.interpolate(x, 4, mode='linear')
            x = self.leaky_relu(x)
            x = F.leaky_relu(x)
            x = F.leaky_relu(x, inplace=True)
            x = x.leaky_relu()
            x.leaky_relu_()
            x = self.hardsigmoid(x)
            x = F.hardsigmoid(x)
            x = F.hardsigmoid(x, inplace=True)
            x = x.hardsigmoid()
            x.hardsigmoid_()
            x = self.sigmoid(x)
            x = torch.sigmoid(x)
            # F.sigmoid is deprecated
            x = x.sigmoid()
            x.sigmoid_()
            x = self.tanh(x)
            # F.tanh is deprecated
            x = torch.tanh(x)
            x = x.tanh()
            x.tanh_()
            x = self.conv(x)
            return x

    # This model is not executable since we just put all ops
    # in the same forward
    m = M()
    original = symbolic_trace(m)
    # nothing to fuse so skipping the fuse step
    quantizer = Quantizer()
    qconfig_dict = {'': default_qconfig}
    prepared = quantizer.prepare(original, qconfig_dict)
    # not runnable
    quantized = quantizer.convert(prepared)

    # This checks that the dequantize from the output of first conv
    # is being propagated to the end, so that we don't insert extra
    # observers
    # check exact counts of quantize and dequantize
    count_check = {
        ns.call_function(torch.quantize_per_tensor): 1,
        ns.call_method('dequantize'): 1
    }
    # Expected node order: quantize -> conv -> conv -> dequantize.
    order_check = [
        ns.call_function(torch.quantize_per_tensor),
        ns.call_module(nnq.Conv2d),
        ns.call_module(nnq.Conv2d),
        ns.call_method('dequantize'),
    ]
    self.checkGraphModuleNodes(quantized,
                               expected_node_occurrence=count_check,
                               expected_node_list=order_check)
def replace_pattern(gm: GraphModule, pattern: Callable,
                    replacement: Callable) -> None:
    """
    Matches all possible non-overlapping sets of operators and their
    data dependencies (``pattern``) in the Graph of a GraphModule
    (``gm``), then replaces each of these matched subgraphs with another
    subgraph (``replacement``).

    Args:
        ``gm``: The GraphModule that wraps the Graph to operate on
        ``pattern``: The subgraph to match in ``gm`` for replacement
        ``replacement``: The subgraph to replace ``pattern`` with

    Examples:

    .. code-block:: python

        import torch
        from torch.fx import symbolic_trace, subgraph_rewriter

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, w1, w2):
                m1 = torch.cat([w1, w2]).sum()
                m2 = torch.cat([w1, w2]).sum()
                return x + torch.max(m1) + torch.max(m2)

        def pattern(w1, w2):
            return torch.cat([w1, w2]).sum()

        def replacement(w1, w2):
            return torch.stack([w1, w2])

        traced_module = symbolic_trace(M())

        subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

    The above code will first match ``pattern`` in the ``forward`` method
    of ``traced_module``. Pattern-matching is done based on use-def
    relationships, not node names. For example, if you had
    ``p = torch.cat([a, b])`` in ``pattern``, you could match
    ``m = torch.cat([a, b])`` in the original ``forward`` function,
    despite the variable names being different (``p`` vs ``m``).

    The ``return`` statement in ``pattern`` is matched based on its value
    only; it may or may not match to the ``return`` statement in the
    larger graph. In other words, the pattern doesn't have to extend to
    the end of the larger graph.

    When the pattern is matched, it will be removed from the larger
    function and replaced by ``replacement``. If there are multiple
    matches for ``pattern`` in the larger function, each non-overlapping
    match will be replaced. In the case of a match overlap, the first
    found match in the set of overlapping matches will be replaced.
    ("First" here being defined as the first in a topological ordering
    of the Nodes' use-def relationships. In most cases, the first Node
    is the parameter that appears directly after ``self``, while the
    last Node is whatever the function returns.)

    One important thing to note is that the parameters of the
    ``pattern`` Callable must be used in the Callable itself, and the
    parameters of the ``replacement`` Callable must match the pattern.
    The first rule is why, in the above code block, the ``forward``
    function has parameters ``x, w1, w2``, but the ``pattern`` function
    only has parameters ``w1, w2``. ``pattern`` doesn't use ``x``, so it
    shouldn't specify ``x`` as a parameter.

    As an example of the second rule, consider replacing

    .. code-block:: python

        def pattern(x, y):
            return torch.neg(x) + torch.relu(y)

    with

    .. code-block:: python

        def replacement(x, y):
            return torch.relu(x)

    In this case, ``replacement`` needs the same number of parameters as
    ``pattern`` (both ``x`` and ``y``), even though the parameter ``y``
    isn't used in ``replacement``.

    After calling ``subgraph_rewriter.replace_pattern``, the generated
    Python code looks like this:

    .. code-block:: python

        def forward(self, x, w1, w2):
            stack_1 = torch.stack([w1, w2])
            sum_1 = stack_1.sum()
            stack_2 = torch.stack([w1, w2])
            sum_2 = stack_2.sum()
            max_1 = torch.max(sum_1)
            add_1 = x + max_1
            max_2 = torch.max(sum_2)
            add_2 = add_1 + max_2
            return add_2
    """
    # Get the graphs for `gm`, `pattern`, `replacement`
    original_graph = gm.graph
    pattern_graph = symbolic_trace(pattern).graph
    replacement_graph = symbolic_trace(replacement).graph

    # Find all possible pattern matches in original_graph. Note that
    # pattern matches may overlap with each other.
    matcher = SubgraphMatcher(pattern_graph)
    matches: List[Match] = []

    # Consider each node as an "anchor" (deepest matching graph node)
    for anchor in original_graph.nodes:

        if matcher.matches_subgraph_from_anchor(anchor):

            def pattern_is_contained(nodes_map: Dict[Node, Node]) -> bool:
                # `lookup` represents all the nodes in `original_graph`
                # that are part of `pattern`
                lookup: Dict[Node, Node] = {v: k for k, v in nodes_map.items()}
                for n in lookup.keys():
                    if n.op == "placeholder" or lookup[n].op == "output":
                        continue
                    for user in n.users:
                        # If this node has users that were not in
                        # `lookup`, then it must leak out of the
                        # pattern subgraph
                        if user not in lookup:
                            return False
                return True

            # It's not a match if the pattern leaks out into the rest
            # of the graph
            if pattern_is_contained(matcher.nodes_map):
                # NOTE(review): this loop appends one (identical) Match
                # per entry of nodes_map rather than a single Match per
                # anchor; the duplicates are later skipped by
                # `overlaps_with_prev_match`, but this looks
                # unintentional — verify against upstream.
                for k, v in matcher.nodes_map.items():
                    # Shallow copy nodes_map
                    matches.append(
                        Match(anchor=anchor,
                              nodes_map=copy.copy(matcher.nodes_map)))

    # The set of all nodes in `original_graph` that we've seen thus far
    # as part of a pattern match
    replaced_nodes: Set[Node] = set()

    # Return TRUE if one of the nodes in the current match has already
    # been used as part of another match
    def overlaps_with_prev_match(match: Match) -> bool:
        for n in match.nodes_map.values():
            if n in replaced_nodes and n.op != "placeholder":
                return True
        return False

    for match in matches:

        # Skip overlapping matches
        if overlaps_with_prev_match(match):
            continue

        # Map replacement graph nodes to their copy in `original_graph`
        val_map: Dict[Node, Node] = {}

        pattern_placeholders = [
            n for n in pattern_graph.nodes if n.op == "placeholder"
        ]
        assert len(pattern_placeholders)
        replacement_placeholders = [
            n for n in replacement_graph.nodes if n.op == "placeholder"
        ]
        assert len(pattern_placeholders) == len(replacement_placeholders)
        # Pair replacement placeholders with pattern placeholders by position.
        placeholder_map = {
            r: p
            for r, p in zip(replacement_placeholders, pattern_placeholders)
        }

        # node from `original_graph` that matched with the output node
        # in `pattern`
        subgraph_output: Node = match.anchor

        def mark_node_as_replaced(n: Node) -> None:
            # Recursively record every matched node (and its matched
            # inputs) as consumed by this replacement.
            if n not in match.nodes_map.values():
                return
            for n_ in n.all_input_nodes:
                mark_node_as_replaced(n_)
            replaced_nodes.add(n)

        mark_node_as_replaced(subgraph_output)

        # Initialize `val_map` with mappings from placeholder nodes in
        # `replacement` to their corresponding node in `original_graph`
        for replacement_node in replacement_placeholders:
            # Get the `original_graph` placeholder node
            # corresponding to the current `replacement_node`
            pattern_node = placeholder_map[replacement_node]
            original_graph_node = match.nodes_map[pattern_node]
            # Populate `val_map`
            val_map[replacement_node] = original_graph_node

        # Copy the replacement graph over
        with original_graph.inserting_before(subgraph_output):
            copied_output = original_graph.graph_copy(replacement_graph,
                                                      val_map)
        assert isinstance(copied_output, Node)

        # We only want to copy in the output node from `pattern` if we
        # have an output-output match. Otherwise, we leave out the
        # `pattern` output node so we don't have two outputs in the
        # resultant graph
        if subgraph_output.op != "output":
            subgraph_output = subgraph_output.args[0]  # type: ignore
        subgraph_output.replace_all_uses_with(copied_output)

        # Erase the `pattern` nodes
        for node in reversed(original_graph.nodes):
            if len(node.users) == 0 and node.op != "output":
                original_graph.erase_node(node)

    # Update the passed-in GraphModule to reflect the new state of
    # `original_graph`
    gm.recompile()
def checkGraphModeFxOp(self, model, inputs, check_spec,
                       quant_type=QuantType.STATIC, debug=False):
    """ Quantizes model with graph mode quantization on fx and check if the
    quantized model contains the quantized_node

    Args:
        model: floating point torch.nn.Module
        inputs: one positional sample input arguments for model
        check_spec: either:
            quantized_node: a tuple of 2 elements, first element is
                the op type for GraphModule node, second element is the
                target function for call_function and type of the module
                for call_module, e.g. ('call_function', torch.ops.quantized.conv2d)
            node map: a dict from node (tuple of 2 elements) to number of
                occurrences (int), e.g.
                {('call_function', torch.quantize_per_tensor): 1}
            ordered node list: a list of node (tuple of 2 elements), e.g.
                [('call_function', torch.quantize_per_tensor),
                 ('call_function', torch.dequantize)]
    """
    # TODO: make img_data a single example instead of a list
    if type(inputs) == list:
        inputs = inputs[0]

    # QAT requires the observers to see training-mode behavior.
    if quant_type == QuantType.QAT:
        model.train()
    else:
        model.eval()

    original = symbolic_trace(model)
    fused = fuse(original)

    quantizer = Quantizer()
    # TODO: uncomment after we make per channel observer work in the flow
    # qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}
    qconfig_dict = {'': default_qconfig}
    if quant_type == QuantType.DYNAMIC:
        prepared = quantizer.prepare_dynamic(fused, qconfig_dict)
    else:
        prepared = quantizer.prepare(fused, qconfig_dict)
    # Run one calibration pass so observers record statistics.
    prepared(*inputs)
    qgraph = quantizer.convert(prepared)
    qgraph_debug = quantizer.convert(prepared, debug=True)

    result = qgraph(*inputs)
    result_debug = qgraph_debug(*inputs)

    # BUG FIX: the failure message used to be attached as a trailing
    # `, 'msg'` tuple expression, which silently discarded it. Pass it
    # as the `msg` argument of assertEqual instead.
    self.assertEqual(
        (result - result_debug).abs().max(), 0,
        'Expecting debug and non-debug option to produce identical result')

    if debug:
        print()
        print('quant type:', quant_type)
        print('original graph module:', type(model))
        self.printGraphModule(original)
        print()
        print('quantized graph module:', type(qgraph))
        self.printGraphModule(qgraph)
        print()
    self.checkGraphModule(qgraph, check_spec)
# Non-torch annotation with no internal forward references class M3(torch.nn.Module): def forward(self, x: typing.List[torch.Tensor], a: A) -> torch.Tensor: return a(x[0]) # Non-torch annotation with internal forward references class M4(torch.nn.Module): def forward(self, x: typing.List['torch.Tensor'], a: A) -> 'torch.Tensor': return a(x[0]) x = torch.rand(2, 3) ref = torch.add(x, x) traced1 = symbolic_trace(M1()) res1 = traced1(x, A()) assert torch.all(torch.eq(ref, res1)) traced2 = symbolic_trace(M2()) res2 = traced2(x, A()) assert torch.all(torch.eq(ref, res2)) traced3 = symbolic_trace(M3()) res3 = traced3([x], A()) assert torch.all(torch.eq(ref, res3)) traced4 = symbolic_trace(M4()) res4 = traced4([x], A()) assert torch.all(torch.eq(ref, res4))
def test_functional(self):
    """ Test quantizing functional conv and linear

    Runs each sample module through trace/script sanity checks, then
    through the FX quantization flow (static or dynamic), and asserts
    the expected quantized node appears in the converted graph.
    """
    class Conv(torch.nn.Module):
        # Functional conv2d with the weight passed in as an input.
        def __init__(self):
            super().__init__()
            self.stride = (1, 1)
            self.padding = (0, 0)
            self.dilation = (1, 1)
            self.groups = 1

        def forward(self, x, weight):
            return F.conv2d(x, weight, None, self.stride, self.padding,
                            self.dilation, self.groups)

    conv_input = torch.rand(1, 3, 224, 224)
    conv_weight = torch.rand(3, 3, 3, 3)

    class Linear(torch.nn.Module):
        # Functional linear with the weight passed in as an input.
        def __init__(self):
            super().__init__()

        def forward(self, x, weight):
            return F.linear(x, weight)

    linear_input = torch.rand(8, 5)
    linear_weight = torch.rand(10, 5)

    class LinearModule(torch.nn.Module):
        # Module-based linear for the call_module quantization path.
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(5, 10)

        def forward(self, x):
            return self.linear(x)

    linear_module_input = torch.rand(8, 5)

    # (is_dynamic, module class, sample inputs, expected quantized node)
    tests = [
        (False, Conv, (conv_input, conv_weight),
         ('call_function', torch.ops.quantized.conv2d)),
        (True, Linear, (linear_input, linear_weight),
         ('call_function', torch.ops.quantized.linear_dynamic)),
        (False, Linear, (linear_input, linear_weight),
         ('call_function', torch.ops.quantized.linear)),
        (True, LinearModule, (linear_module_input,),
         ('call_module', torch.nn.quantized.dynamic.Linear)),
        (False, LinearModule, (linear_module_input,),
         ('call_module', torch.nn.quantized.Linear)),
    ]

    for is_dynamic, M, inputs, quantized_node in tests:
        m = M().eval()
        qconfig = default_qconfig
        graph = symbolic_trace(m)
        script = torch.jit.script(graph)

        # Sanity check: eager, traced and scripted modules all agree.
        a = m(*inputs)
        b = graph(*inputs)
        c = script(*inputs)
        assert (a - b).abs().max() == 0
        assert (a - c).abs().max() == 0
        assert torch.allclose(a, b)
        assert torch.allclose(a, c)

        # Quantize
        graph = fuse(graph)

        quantizer = Quantizer()
        qconfig_dict = {'': qconfig}
        if is_dynamic:
            prepared = quantizer.prepare_dynamic(graph, qconfig_dict)
        else:
            prepared = quantizer.prepare(graph, qconfig_dict)
        # One calibration pass so observers record statistics.
        prepared(*inputs)
        qgraph = quantizer.convert(prepared)
        qgraph_debug = quantizer.convert(prepared, debug=True)
        qgraph.eval()
        qgraph_debug.eval()
        qgraph_script = torch.jit.script(qgraph)

        d = qgraph(*inputs)
        d_debug = qgraph_debug(*inputs)
        e = qgraph_script(*inputs)
        e_debug = qgraph_debug(*inputs)

        found = False
        modules = dict(qgraph.root.named_modules())
        for node in qgraph.graph.nodes:
            if node.op == 'call_function':
                found = found or node.op == quantized_node[0] \
                    and node.target == quantized_node[1]
            elif node.op == 'call_module':
                found = found or node.op == quantized_node[0] \
                    and type(modules[node.target]) == quantized_node[1]
        # BUG FIX: the message previously referenced the undefined name
        # `quantized_op`, which raised NameError on assertion failure.
        assert found, 'Expected to find quantized node:' + str(quantized_node)
        # assert (a-d).abs().max() < 2
        assert torch.allclose(d, e)
        assert (d - d_debug).abs().max() == 0
        assert (e - e_debug).abs().max() == 0
def nnc_compile(model: torch.nn.Module, example_inputs) -> torch.nn.Module:
    """
    nnc_compile(model, example_inputs) returns a function with the same args
    as `model.forward`, with an extra argument corresponding to where the
    output is stored. This function takes the inputs (which must be PyTorch
    tensors with the same shapes as example_inputs), and passes them to an
    NNC executor.
    """
    fx_model = fx.symbolic_trace(model)
    # ShapeProp populates node.meta['tensor_meta'] (shape/dtype), which
    # everything below depends on.
    ShapeProp(fx_model).propagate(*example_inputs)

    # This env maps from nodes to `te.ExprHandle`, which represent the output
    # of an NNC computation.
    env = {}

    def get_te_type(node):
        return get_nnc_type(node.meta['tensor_meta'].dtype)

    # NOTE(review): `gen_compute` is never called in this function —
    # presumably leftover scaffolding; verify before removing.
    def gen_compute(args):
        te_args = [env[arg.name] for arg in args]

    def lookup_env(l):
        # Map fx.Nodes (possibly nested in lists/tuples) to their NNC values.
        res = fx.node.map_aggregate(
            l, lambda x: env[x.name] if isinstance(x, fx.Node) else x)
        return res

    def fetch_attr(target: str):
        # Resolve a dotted get_attr target (e.g. "sub.weight") on fx_model.
        target_atoms = target.split('.')
        attr_itr = fx_model
        for i, atom in enumerate(target_atoms):
            if not hasattr(attr_itr, atom):
                raise RuntimeError(
                    f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}")
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    outs = None
    inputs = []
    module_attrs = []
    compute_stmts = []
    for node in fx_model.graph.nodes:
        if node.op == 'placeholder':
            # We simply map the input placeholder to a `te.Placeholder`, which
            # also represents an input to the NNC computation.
            shapes = get_te_shapes(node.meta['tensor_meta'].shape)
            placeholder = te.Placeholder(node.name, get_te_type(node), shapes)
            env[node.name] = placeholder.data()
            inputs.append(placeholder)
        elif node.op == 'call_function':
            # This does the bulk of the work - we call `lower_function`, which
            # returns a `te.ExprHandle` (the output of a NNC computation), and
            # put it in our environment.
            if 'tensor_meta' in node.meta:
                # todo: fix kwargs handling
                if node.kwargs:
                    raise RuntimeError("kwargs nyi")
                buf, stmt = lower_function(node, node.target,
                                           lookup_env(node.args), node.args)
                # if isinstance(stmt, list)
                compute_stmts.extend(stmt)
                env[node.name] = buf
            elif node.target == getattr or node.target == operator.getitem:
                # todo: handle non-tensor computations correctly
                continue
        elif node.op == 'output':
            # Normalize node.args to a flat tuple of output nodes.
            args = node.args
            if not isinstance(args, tuple):
                args = (args,)
            if isinstance(args[0], tuple):
                args = args[0]
            te_args = lookup_env(args)
            # outs = (NNC output buffers, (shape, dtype) pairs for allocation)
            outs = (list(te_args),
                    [(i.meta['tensor_meta'].shape, i.meta['tensor_meta'].dtype)
                     for i in args])
        elif node.op == 'get_attr':
            # As NNC doesn't have any concept of state, we pull out the module
            # attributes and pass them in as inputs to NNC.
            module_attrs.append(node)
            shapes = get_te_shapes(
                process_shape(node.meta['tensor_meta'].shape))
            placeholder = te.Placeholder(node.name, get_te_type(node), shapes)
            env[node.name] = placeholder.data()
        else:
            print(node.op, node.target)
            raise RuntimeError("not yet implemented")

    loopnest = te.LoopNest(te.Stmt(compute_stmts), outs[0])
    # loopnest.inline_intermediate_bufs(True)
    loopnest.simplify()
    loopnest.prepare_for_codegen()
    stmt = te.simplify(loopnest.root_stmt())
    # Buffer order must match the call order in `f` below:
    # module attributes, then inputs, then outputs.
    cg = te.construct_codegen('llvm', stmt, [
        te.BufferArg(x)
        for x in [env[i.name] for i in module_attrs] + inputs + outs[0]
    ])
    # Pre-allocate default output tensors, reused across calls when the
    # caller does not supply out_tensors.
    alloc_results = [
        torch.empty(shape, dtype=dtype) for shape, dtype in outs[1]
    ]
    if module_attrs:
        module_stuff = [
            fetch_attr(i.target).contiguous().data for i in module_attrs
        ]
    else:
        module_stuff = []

    def f(*inps, out_tensors=None):
        # begin = time.time()
        if out_tensors is None:
            results = alloc_results
        else:
            results = out_tensors
        cg.call(module_stuff + list(inps) + results)
        # When the caller supplied out_tensors, results are written
        # in place and nothing is returned.
        if out_tensors is None:
            if len(results) == 1:
                return results[0]
            return results

    return f
def test_subgraph_rewriter_placeholder_matching(self):
    """
    This tests that a placeholder Node can be matched to a Node with
    a different number of input Nodes. In the example below, the
    original traced Module looks like this:

        opcode         target                                                      args                      kwargs
        -------------  ----------------------------------------------------------  ------------------------  --------
        placeholder    x                                                           ()                        {}
        call_function  <built-in function add>                                     (x, 3)                    {}
        call_method    dequantize                                                  (add,)                    {}
        call_function  <built-in method sigmoid of type object at 0x7f7c1f440fe0>  (dequantize,)             {}
        call_method    to                                                          (sigmoid, torch.float16)  {}
        output         output                                                      (to,)                     {}

    while the pattern we want to match looks like this:

        opcode         target                                                      args                      kwargs
        -------------  ----------------------------------------------------------  ------------------------  --------
        placeholder    x                                                           ()                        {}
        call_method    dequantize                                                  (x,)                      {}
        call_function  <built-in method sigmoid of type object at 0x7f7c1f440fe0>  (dequantize,)             {}
        call_method    to                                                          (sigmoid, torch.float16)  {}
        output         output                                                      (to,)                     {}

    Here, we want to be able to match the original graph's
    `call_function.add` Node with the pattern graph's
    `placeholder.x` Node.

    Credit to Jerry Zhang (GitHub: jerryzh168) for this test case
    """
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.dtype = torch.float16

        def forward(self, x):
            x += 3
            x = x.dequantize()
            x = torch.sigmoid(x)
            dtype = self.dtype
            x = x.to(dtype)
            return x

    def pattern(x):
        # dequantize -> sigmoid -> to(float16); the chain replaced below.
        x = x.dequantize()
        x = torch.sigmoid(x)
        x = x.to(torch.float16)
        return x

    def replacement(x):
        # Identity: the matched chain is simply removed.
        return x

    def comparison(x):
        # What forward() should compute after the rewrite.
        return x + 3

    traced = symbolic_trace(M())
    comparison_fn = symbolic_trace(comparison)

    x = torch.randn(3, 4)

    subgraph_rewriter.replace_pattern(traced, pattern, replacement)

    # The rewritten graph must still be well-formed.
    traced.graph.lint()

    ref_outs = comparison_fn(x)
    test_outs = traced.forward(x)
    self.assertEqual(ref_outs, test_outs)
def _test_model_impl(self, mode, name, model, eager_quantizable_model, check_with_eager=True, diff_of_quant=None, diff_from_eager=None): if diff_of_quant is None or diff_from_eager is None: diff_of_quant = {} diff_from_eager = {} if mode not in diff_of_quant or mode not in diff_from_eager: diff_of_quant[mode] = {} diff_from_eager[mode] = {} input_tensor = torch.rand(1, 3, 224, 224) input_tensor_inception = torch.rand(1, 3, 299, 299) output_value = torch.randint(0, 1, (1, )) # print('quantizing:', name, ' mode:', mode) if name == 'inception_v3': input_value = input_tensor_inception else: input_value = input_tensor qconfig = default_qconfig if mode == 'static' else default_qat_qconfig qconfig_dict = {'': qconfig} graph_module = symbolic_trace(model) # print('graph module:', graph_module.src) script = torch.jit.script(graph_module) # make sure graph module and script module are both runanble original_out = graph_module(input_value) is_not_tuple_out = not isinstance(original_out, tuple) script_out = script(input_value) self.assertEqual( (original_out - script_out).abs().max(), 0, 'Reslut of original graph module and script module does not match') # set to train just before quantization if mode != 'static': model.train() graph_module = fuse(graph_module) quantizer = Quantizer() prepared = quantizer.prepare(graph_module, qconfig_dict) if mode == 'ddp': mp.spawn(run_ddp, args=(world_size, prepared), nprocs=world_size, join=True) elif mode == 'qat': assert prepared.training, 'prepared must be in training mode for qat' optimizer = torch.optim.SGD(prepared.parameters(), lr=0.0001) criterion = nn.CrossEntropyLoss() train_one_epoch(prepared, criterion, optimizer, [(input_value, output_value)], torch.device('cpu'), 1) else: for i in range(10): prepared(input_value) # print('after observation root:', prepared.root) qgraph = quantizer.convert(prepared) # print('after quantization root:', qgraph.root) # print('after quantization code:', qgraph.src) qgraph.eval() qgraph_script = 
torch.jit.script(qgraph) # print('quantized and scripted:', qgraph_script.graph) qgraph_out = qgraph(input_value) qgraph_script = qgraph_script(input_value) if is_not_tuple_out: diff_of_quant[mode][name] = (original_out - qgraph_out).abs().max() assert torch.allclose(qgraph_out, qgraph_script), 'graph, scripted graph' else: print('tuple output') if eager_quantizable_model is not None: # comparing to eager mode quantization qeager = eager_quantizable_model ref_out = qeager(input_value) qeager.qconfig = qconfig if mode == 'static': qeager.fuse_model() prepare(qeager, inplace=True) else: qeager.train() qeager.fuse_model() prepare_qat(qeager, inplace=True) # calibration if mode == 'ddp': mp.spawn(run_ddp, args=(world_size, qeager), nprocs=world_size, join=True) elif mode == 'qat': assert qeager.training, 'qeager should be in training mode for qat' optimizer = torch.optim.SGD(qeager.parameters(), lr=0.0001) train_one_epoch(qeager, criterion, optimizer, [(input_value, output_value)], torch.device('cpu'), 1) else: for i in range(10): qeager(input_value) # print('ref after observation:', qeager) convert(qeager, inplace=True) qeager.eval() # print('ref after quantization:', qeager) qeager_out = qeager(input_value) qeager_script = torch.jit.script(qeager) qscript_out = qeager_script(input_value) if is_not_tuple_out: diff_from_eager[mode][name] = (qeager_out - qgraph_out).abs().max() if check_with_eager: self.assertEqual( diff_from_eager[mode][name], 0, 'Result of graph mode quantization and ' + 'eager mode quantization on model: ' + name + ' should match. Mode: ' + mode + ' diff:' + str(diff_from_eager[mode][name]))
return (q, p) q_i = torch.tensor(onp.array(q_i['z'])) p_i = torch.tensor(onp.array(p_i(seed)['z'])) inv_mass_matrix = torch.eye(D) # inv_mass_matrix = torch.ones(D) step_size = 0.001 num_steps = 10000 import torch.fx as fx # is_u_turning = nnc_compile(wrap_key(is_u_turning, (q_i,p_i,q_i)), (q_i, p_i, q_i)) # is_u_turning = nnc_compile(is_u_turning, example_inputs=(q_i, p_i, q_i)) inps = (q_i, p_i, potential_fn, inv_mass_matrix, step_size) leapfrog = fx.symbolic_trace(wrap_key(leapfrog, inps)) leapfrog = remove_args(leapfrog) # leapfrog = truncate(leapfrog, 7) # leapfrog = fx.symbolic_trace(leapfrog, concrete_args={'potential_fn': potential_fn, 'step_size': step_size, 'inverse_mass_matrix': inv_mass_matrix}) # if VERSION == 'nnc': leapfrog = nnc_compile(leapfrog, example_inputs=(q_i, p_i, inv_mass_matrix)) # elif VERSION == 'ts': # leapfrog = torch.jit.script(leapfrog) # elif VERSION == 'pt': # pass out = get_final_state(potential_fn, inv_mass_matrix, step_size, q_i, p_i) begin = time.time() out = get_final_state(potential_fn, inv_mass_matrix, step_size, q_i, p_i) print(out[0][0])
for node in list(model.graph.nodes): new_node = new_graph.node_copy(node, lambda x: env[x.name]) env[node.name] = new_node cnt += 1 if cnt == k: new_graph.output(env[node.name]) break return fx.GraphModule(model, new_graph) torch.manual_seed(0) inputs = (torch.randn(1, 3, 256, 256),) model = mobilenet_v2() model.eval() fx_model = fx.symbolic_trace(model) fx_model = fuse(fx_model) fx_model = decompose(fx_model, example_inputs=inputs) # f = torch.jit.freeze(model) # torch._C._fancy_compile(f.graph, list(inputs[0].shape)) # f(*inputs) # fx_model = truncate(fx_model, 4) print(fx_model) # print(fx_model.code) nnc_model = nnc_compile(fx_model, example_inputs=inputs) import time iters = 1 with torch.no_grad(): nnc_model(*inputs) begin = time.time() # while True:
# Sample module class M(torch.nn.Module): def __init__(self): super().__init__() self.relu = torch.nn.ReLU() def forward(self, x): return self.relu(x) + 1.0 # Symbolically trace an instance of `M`. After tracing, `self.relu` is # represented as a `call_module` Node. The full operation in the # generated `forward` function's code will appear as `self.relu(x)` m = symbolic_trace(M()) # Insert nodes from the ReLU graph in place of the original call to # `self.relu` for node in m.graph.nodes: # Find `call_module` Node in `m` that corresponds to `self.relu`. # This is the Node we want to swap out for an inlined version of the # same call if (node.op, node.target) == ("call_module", "relu"): with m.graph.inserting_before(node): # Create a Proxy from each Node in the current Node's # args/kwargs proxy_args = map_arg(node.args, Proxy) proxy_kwargs = map_arg(node.kwargs, Proxy) # Call `m.relu` with the newly-created Proxy arguments. # `m.relu` is the generic version of the function; by
def checkGraphModeFxOp(self, model, inputs, quant_type, expected_node=None, expected_node_occurrence=None, expected_node_list=None, debug=False, print_debug_info=False): """ Quantizes model with graph mode quantization on fx and check if the quantized model contains the quantized_node Args: model: floating point torch.nn.Module inputs: one positional sample input arguments for model expected_node: NodeSpec e.g. NodeSpec.call_function(torch.quantize_per_tensor) expected_node_occurrence: a dict from NodeSpec to expected number of occurences (int) e.g. {NodeSpec.call_function(torch.quantize_per_tensor) : 1, NodeSpec.call_method('dequantize'): 1} expected_node_list: a list of NodeSpec, used to check the order of the occurrence of Node e.g. [NodeSpec.call_function(torch.quantize_per_tensor), NodeSpec.call_module(nnq.Conv2d), NodeSpec.call_function(F.hardtanh_), NodeSpec.call_method('dequantize')] """ # TODO: make img_data a single example instead of a list if type(inputs) == list: inputs = inputs[0] if quant_type == QuantType.QAT: model.train() else: model.eval() original = symbolic_trace(model) fused = fuse_fx(original) qconfig_dict = { '': get_default_qconfig(torch.backends.quantized.engine) } if quant_type == QuantType.DYNAMIC: prepare = prepare_dynamic_fx convert = convert_dynamic_fx else: prepare = prepare_fx convert = convert_fx prepared = prepare(fused, qconfig_dict) prepared(*inputs) qgraph = convert(prepared) qgraph_debug = convert(prepared, debug=True) result = qgraph(*inputs) result_debug = qgraph_debug(*inputs) self.assertEqual((result - result_debug).abs().max(), 0), \ 'Expecting debug and non-debug option to produce identical result' if print_debug_info: print() print('quant type:', quant_type) print('origianl graph module:', type(model)) self.printGraphModule(original) print() print('quantized graph module:', type(qgraph)) self.printGraphModule(qgraph) print() qgraph_to_check = qgraph_debug if debug else qgraph self.checkGraphModuleNodes(qgraph_to_check, 
expected_node, expected_node_occurrence, expected_node_list)
def test_update_args_kwargs_yells_at_you(self): symtraced = symbolic_trace(SimpleTest()) node = next(iter(symtraced.graph.nodes)) with self.assertRaisesRegex(AttributeError, '__update_args_kwargs'): node.__update_args_kwargs((), {})
def test_namedtuple_return_trace(self): class NamedTupReturn(torch.nn.Module): def forward(self, x): return Pair(x, x) traced = symbolic_trace(NamedTupReturn())
def test_torch_fx_len(self): class FXLenTest(torch.nn.Module): def forward(self, x): return torch.fx.len(x) traced = symbolic_trace(FXLenTest())
    def _prepare(self, model, qconfig_dict, inplace, is_dynamic_quant,
                 is_standalone_module):
        """Insert observers into `model`'s graph ahead of quantization.

        standalone_module means it a submodule that is not inlined in
        parent module, and will be quantized separately as one unit.

        When we are preparing a standalone module:
        input of the module is observed in parent module, output of the
        module is observed in the standalone module.

        Returns:
            model(GraphModule): prepared standalone module with following attributes:
                _standalone_module_observed_input_idxs(List[Int]): a list of
                    indexs for the graph inputs that needs to be observed in
                    parent module
                _output_is_observed(Bool): a boolean variable indicate whether
                    the output of the custom module is observed or not
        """
        if not inplace:
            model = copy.deepcopy(model)
        self.is_dynamic_quant = is_dynamic_quant
        # dynamic vs static quantization use different fusion/match patterns
        if self.is_dynamic_quant:
            self.patterns = get_dynamic_quant_patterns()
        else:
            self.patterns = get_quant_patterns()

        flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
        # TODO: support regex as well
        propagate_qconfig_(model, flattened_qconfig_dict)
        if model.training:
            self._qat_swap_modules(model)

        self.modules = dict(model.named_modules())

        convert_dict_to_ordered_dict(qconfig_dict)
        # map from node name to qconfig, used in _find_matches
        self._generate_qconfig_map(model, model.graph, qconfig_dict)

        # match the patterns that will get quantized
        standalone_module_names = qconfig_dict.get('standalone_module_name', None)
        matches = self._find_matches(
            model.graph, self.modules, self.patterns, standalone_module_names)

        # find _inputs_ to matched nodes that are not quantized, these
        # have to be quantized, which requires measuring stats,
        # initialize an DefaultQuant object for each
        quants = self._find_quants(model.graph, matches)

        self.activation_post_process_map = dict()
        # env maps original node names to their copies in observed_graph
        env = {}
        observed_graph = Graph()
        observed_node_names_set = set()

        def load_arg(a):
            # Remap (possibly nested) node arguments to their observed copies.
            return map_arg(a, lambda node: env[node.name])

        # indexes for the inputs that needs to be observed
        standalone_module_observed_input_idxs = []
        graph_inputs = []
        for node in model.graph.nodes:
            if node.op == 'placeholder':
                graph_inputs.append(node.name)

        get_new_observer_name = get_new_attr_name_with_prefix(
            'activation_post_process_')

        for node in model.graph.nodes:
            if node.name in observed_node_names_set:
                continue

            prefix = node.name + '_activation_post_process_'
            root_node, _, obj, qconfig = matches.get(
                node.name, (None, None, None, None))
            if root_node is None:
                # node is not part of any matched pattern: copy it verbatim
                env[node.name] = observed_graph.node_copy(node, load_arg)
            elif root_node is node:
                # node is the root of a matched pattern
                env[node.name] = observed_graph.node_copy(node, load_arg)
                if qconfig is None:
                    continue

                def insert_observer(node, observer, device):
                    # Registers `observer` as a submodule of `model` and
                    # routes `node`'s output through it in observed_graph.
                    # NOTE: closes over `prefix` (late binding) — safe here
                    # because it is only called within the same iteration.
                    get_new_observer_name = get_new_attr_name_with_prefix(
                        prefix)
                    observer_name = get_new_observer_name(model)
                    setattr(model, observer_name, observer)
                    self.activation_post_process_map[node.name] = observer
                    env[node.name] = observed_graph.create_node(
                        'call_module', observer_name, (load_arg(node), ), {})
                    observed_node_names_set.add(node.name)
                    if device:
                        getattr(model, observer_name).to(device)

                if isinstance(obj, CustomModuleQuantizeHandler):
                    # swap the custom module for its observed counterpart
                    custom_module = self.modules[node.target]
                    observed_custom_module_class = \
                        get_observed_custom_module_class(type(custom_module))
                    observed_custom_module = \
                        observed_custom_module_class.from_float(custom_module)
                    mark_observed_custom_module(observed_custom_module,
                                                type(custom_module))
                    parent_name, name = _parent_name(node.target)
                    setattr(self.modules[parent_name], name,
                            observed_custom_module)

                # index for input of custom module that needs to be observed in parent
                standalone_module_input_idxs = None
                if isinstance(obj, StandaloneModuleQuantizeHandler):
                    # observe standalone module
                    standalone_module = self.modules[node.target]
                    traced_standalone_module = symbolic_trace(
                        standalone_module)
                    if self.is_dynamic_quant:
                        prepare = torch.quantization.quantize_fx._prepare_dynamic_standalone_module_fx
                    else:
                        prepare = torch.quantization.quantize_fx._prepare_standalone_module_fx
                    observed_standalone_module = prepare(
                        traced_standalone_module, {'': qconfig})
                    observed_standalone_module.qconfig = qconfig
                    standalone_module_input_idxs = observed_standalone_module._standalone_module_observed_input_idxs
                    observed_standalone_module = mark_observed_standalone_module(
                        observed_standalone_module)
                    parent_name, name = _parent_name(node.target)
                    setattr(self.modules[parent_name], name,
                            observed_standalone_module)
                    self.modules[node.target] = observed_standalone_module

                # don't need to insert observer for output in dynamic quantization
                if self.is_dynamic_quant:
                    continue

                # inserting observers for output of observed module, or mark the output
                # as observed
                if isinstance(obj, CopyNode):
                    assert node.op in [
                        'call_module', 'call_function', 'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'

                    def is_observed(input_arg):
                        if isinstance(input_arg, Node):
                            return input_arg.name in observed_node_names_set
                        elif isinstance(input_arg, list):
                            return all(map(is_observed, input_arg))
                    # propagate observed property from input
                    if is_observed(node.args[0]):
                        observed_node_names_set.add(node.name)
                elif (isinstance(obj, Add) or isinstance(obj, Mul)) and not obj.all_nodes:
                    # binary ops with a scalar operand: observed iff the
                    # tensor input is observed
                    if node.args[0].name in observed_node_names_set:
                        observed_node_names_set.add(node.name)
                elif isinstance(obj, StandaloneModuleQuantizeHandler):
                    assert node.op == 'call_module'
                    output_is_observed = self.modules[
                        node.target]._output_is_observed
                    if output_is_observed:
                        observed_node_names_set.add(node.name)
                elif qconfig is not None and obj.all_nodes:
                    # observer for outputs
                    new_observer = qconfig.activation()
                    # respect device affinity when adding observers
                    device = assert_and_get_unique_device(model)
                    insert_observer(node, new_observer, device)

                # insert observer for input of standalone module
                if standalone_module_input_idxs is not None:
                    for idx in standalone_module_input_idxs:
                        if node.args[idx].name not in observed_node_names_set:
                            new_observer = qconfig.activation()
                            device = assert_and_get_unique_device(model)
                            insert_observer(node.args[idx],
                                            new_observer, device)
            else:
                # node belongs to a matched pattern but is not its root:
                # copy it; the root handles observation for the pattern
                env[node.name] = observed_graph.node_copy(node, load_arg)

            if node.name not in observed_node_names_set and node.name in quants:
                if is_standalone_module and node.name in graph_inputs:
                    # we'll insert observer for input of standalone module
                    # in parent graph
                    standalone_module_observed_input_idxs.append(
                        graph_inputs.index(node.name))
                    continue
                get_new_observer_name = get_new_attr_name_with_prefix(prefix)
                observer_name = get_new_observer_name(model)
                _, qconfig, is_weight = quants[node.name]
                if qconfig is not None:
                    # TODO: use insert_observer
                    new_observer = \
                        qconfig.weight() if is_weight else qconfig.activation()
                    # respect device affinity when adding observers
                    device = assert_and_get_unique_device(model)
                    if device:
                        new_observer.to(device)
                    self.activation_post_process_map[node.name] = new_observer
                    setattr(model, observer_name,
                            self.activation_post_process_map[node.name])
                    env[node.name] = observed_graph.create_node(
                        'call_module', observer_name, (load_arg(node), ), {})
                    observed_node_names_set.add(node.name)

        observed_graph.output(load_arg(model.graph.result))
        model = GraphModule(model, observed_graph)
        self.save_state(model)
        if is_standalone_module:
            assert isinstance(model.graph.result, Node), \
                'standalone module returning dict is not yet supported'
            # indicator for whether output is observed or not.
            # This used for correctly quantize standalone modules
            output_is_observed = model.graph.result.name in observed_node_names_set
            model._standalone_module_observed_input_idxs = standalone_module_observed_input_idxs
            model._output_is_observed = output_is_observed
        return model