Example #1
    def test_pretty_print(self):
        st = SimpleTest()
        traced = symbolic_trace(st)
        printed = str(traced)
        assert 'GraphModuleImpl()' in printed
        assert 'torch.relu' in printed
Example #2
by creating a custom Tracer and overriding `is_leaf_module`. In this
case, we'll keep the default behavior for all `torch.nn` Modules except
for `ReLU`.
"""


class M1(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(x)


default_traced: GraphModule = symbolic_trace(M1())
"""
Tracing with the default tracer and calling `print_tabular` produces:

    opcode       name    target    args       kwargs
    -----------  ------  --------  ---------  --------
    placeholder  x       x         ()         {}
    call_module  relu_1  relu      (x,)       {}
    output       output  output    (relu_1,)  {}

"""
default_traced.graph.print_tabular()


class LowerReluTracer(Tracer):
    def is_leaf_module(self, m: torch.nn.Module, qualname: str):
        # Per the docstring above: keep the default leaf behavior for all
        # `torch.nn` Modules except `ReLU`, which we trace through.
        if isinstance(m, torch.nn.ReLU):
            return False
        return super().is_leaf_module(m, qualname)
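
# A hedged usage sketch (not in the original snippet): with ReLU no
# longer a leaf, tracing M1 through LowerReluTracer records the
# underlying functional relu call as a `call_function` Node rather
# than a `call_module` Node.
lower_relu_tracer = LowerReluTracer()
custom_traced_graph = lower_relu_tracer.trace(M1())
custom_traced_graph.print_tabular()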
Example #3
        compiled_f = nnc_compile(fx_model, args, get_loopnest=True)
        return compiled_f

    return wrapped


################################
# Example usage and Benchmarking
################################


def bench(f, warmup=3, iters=1000):
    for _ in range(warmup):
        f()
    begin = time.time()
    for _ in range(iters):
        f()
    print(time.time() - begin)


if __name__ == '__main__':

    def f(a, b):
        return (torch.cos(a) * torch.sin(b))[:2000]

    mod = fx.symbolic_trace(f)
    inps = (torch.randn(5000), torch.randn(5000))
    ShapeProp(mod).propagate(*inps)
    cg = nnc_compile(mod, inps)
    bench(lambda: cg(*inps))
    bench(lambda: f(*inps))
Example #4
        def lower_to_elementwise_interpreter(
                orig_mod: torch.nn.Module) -> torch.nn.Module:
            # ===== Stage 1: Symbolic trace the module =====
            mod = symbolic_trace(orig_mod)

            # ===== Stage 2: Lower GraphModule representation to the C++
            #       interpreter's instruction format =====
            instructions = []
            constant_idx = 0
            constants = {}
            fn_input_names = []

            target_to_name = {operator.add: "add", operator.mul: "mul"}

            output_node: Optional[Node] = None
            # For each instruction, create a triple
            # (instruction_name : str, inputs : List[str], output : str)
            # to feed into the C++ interpreter
            for n in mod.graph.nodes:
                target, args, out_name = n.target, n.args, n.name
                assert len(n.kwargs) == 0, "kwargs currently not supported"

                if n.op == 'placeholder':
                    # Placeholders specify function argument names. Save these
                    # for later when we generate the wrapper GraphModule
                    fn_input_names.append(target)
                elif n.op == 'call_function':
                    assert target in target_to_name, f"Unsupported call target {target}"
                    arg_names = []
                    for arg in args:
                        if not isinstance(arg, Node):
                            # Pull out constants. These constants will later be
                            # fed to the interpreter C++ object via add_constant()
                            arg_name = f'constant_{constant_idx}'
                            constants[arg_name] = torch.Tensor(
                                [arg] if isinstance(arg, numbers.Number) else arg)
                            arg_names.append(arg_name)
                            constant_idx += 1
                        else:
                            arg_names.append(arg.name)
                    instructions.append(
                        (target_to_name[target], arg_names, out_name))
                elif n.op == 'output':
                    if output_node is not None:
                        raise RuntimeError('Multiple output nodes!')
                    output_node = n
                else:
                    raise RuntimeError('Unsupported opcode ' + n.op)

            interpreter = torch.classes._TorchScriptTesting._ElementwiseInterpreter(
            )
            # Load constants
            for k, v in constants.items():
                interpreter.add_constant(k, v)
            # Specify names for positional input arguments
            interpreter.set_input_names(fn_input_names)
            # Load instructions
            interpreter.set_instructions(instructions)
            # Specify name for single output
            assert isinstance(output_node.args[0], torch.fx.Node)
            interpreter.set_output_name(output_node.args[0].name)

            # ===== Stage 3: Create a wrapper GraphModule around the interpreter =====
            class WrapperModule(torch.nn.Module):
                def __init__(self, interpreter):
                    super().__init__()
                    self.interpreter = interpreter

            wrapper = WrapperModule(interpreter)

            # Create a graph that: 1) Takes function arguments 2) Invokes the interpreter
            # 3) Returns the specified return value

            # FIXME: The following code could be greatly simplified by symbolic_trace'ing
            # the wrapper with a Tracer that considers the Wrapper instance a root
            # module, however, I can't get `__call__` exposed on TorchBind classes
            # without it messing up Python `hasattr` for some reason. More digging
            # into CPython's implementation of hasattr is probably in order...

            graph = torch.fx.Graph()
            # Add placeholders for fn inputs
            placeholder_nodes = []
            for name in fn_input_names:
                placeholder_nodes.append(graph.create_node(
                    'placeholder', name))

            # Get the interpreter object
            interpreter_node = graph.create_node('get_attr', 'interpreter')

            # Add a node to call the interpreter instance
            output_node = graph.create_node(op='call_method',
                                            target='__call__',
                                            args=(interpreter_node,
                                                  placeholder_nodes))

            # Register output
            graph.output(output_node)

            graph.lint(wrapper)

            # Return final GraphModule!!!
            return GraphModule(wrapper, graph)
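
        # Hedged usage sketch (not part of the original snippet; assumes the
        # `_ElementwiseInterpreter` TorchBind class is registered in this
        # build and that the function above is in scope):
        class AddMulModule(torch.nn.Module):  # hypothetical test module
            def forward(self, x, y):
                return x + y * y

        lowered = lower_to_elementwise_interpreter(AddMulModule())
        x, y = torch.randn(3), torch.randn(3)
        assert torch.allclose(lowered(x, y), AddMulModule()(x, y))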
Example #5
    def test_nonetype_annotation(self):
        eb = torch.nn.EmbeddingBag(3, 4)
        symbolic_trace(eb)
Example #6
    def test_fn_type_annotation_empty(self):
        def forward(a: List[torch.Tensor]):
            return a[0]

        torch.jit.script(symbolic_trace(forward))
Example #7
def nnc_compile(model: torch.nn.Module, example_inputs) -> torch.nn.Module:
    """
    nnc_compile(model, example_inputs) returns a function with the same args
    as `model.forward`, with an extra argument corresponding to where the
    output is stored. This function takes the inputs (which must be PyTorch
    tensors with the same shapes as example_inputs), and passes them to an
    NNC executor.
    """
    fx_model = fx.symbolic_trace(model)
    ShapeProp(fx_model).propagate(*example_inputs)

    # This env maps from nodes to `te.ExprHandle`, which represent the output
    # of an NNC computation.
    env = {}

    def get_te_shapes(node):
        return [te.ExprHandle.int(i) for i in node.shape]

    def get_nnc_type(dtype):
        if dtype == torch.float:
            return te.Dtype.Float
        elif dtype == torch.long:
            return te.Dtype.Long
        else:
            raise RuntimeError("nyi")

    def get_te_type(node):
        return get_nnc_type(node.dtype)

    def gen_compute(args):
        te_args = [env[arg.name] for arg in args]

    def lookup_env(l):
        return fx.node.map_aggregate(
            l, lambda x: env[x.name] if isinstance(x, fx.Node) else x)

    def fetch_attr(target: str):
        target_atoms = target.split('.')
        attr_itr = fx_model
        for i, atom in enumerate(target_atoms):
            if not hasattr(attr_itr, atom):
                raise RuntimeError(
                    f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}"
                )
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    outs = None
    inputs = []
    module_attrs = []
    for node in fx_model.graph.nodes:
        if node.op == 'placeholder':
            # We simply map the input placeholder to a `te.Placeholder`, which
            # also represents an input to the NNC computation.
            shapes = get_te_shapes(node)
            env[node.name] = te.Placeholder(node.name, get_te_type(node),
                                            shapes)
            inputs.append(env[node.name])
        elif node.op == 'call_function':
            # This does the bulk of the work - we call `lower_function`, which
            # returns a `te.ExprHandle` (the output of a NNC computation), and
            # put it in our environment.
            result = lower_function(node, node.target, lookup_env(node.args),
                                    node.args)
            env[node.name] = result
        elif node.op == 'output':
            outs = list(lookup_env(node.args))
        elif node.op == 'get_attr':
            # As NNC doesn't have any concept of state, we pull out the module
            # attributes and pass them in as inputs to NNC. (We recompute the
            # shapes here; the original reused a stale `shapes` variable.)
            module_attrs.append(node)
            shapes = get_te_shapes(node)
            env[node.name] = te.Placeholder(node.name, get_te_type(node),
                                            shapes)
        else:
            raise RuntimeError("not yet implemented")

    loopnest = te.LoopNest(outs)
    loopnest.prepare_for_codegen()
    stmt = te.simplify(loopnest.root_stmt())
    cg = te.construct_codegen('llvm', stmt, [
        te.BufferArg(x)
        for x in [env[i.name] for i in module_attrs] + inputs + outs
    ])

    def f(inps):
        module_stuff = [fetch_attr(i.target) for i in module_attrs]
        cg.call(module_stuff + list(inps))

    return f
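
# Hedged usage sketch: per the docstring above, the caller appends a
# preallocated output tensor to the inputs. `nnc_compile` runs ShapeProp
# internally, so the example inputs only need to have the right shapes.
def g(a, b):  # hypothetical elementwise function
    return torch.cos(a) * torch.sin(b)

gm = fx.symbolic_trace(g)
example = (torch.randn(1024), torch.randn(1024))
compiled = nnc_compile(gm, example)
out = torch.empty(1024)
compiled([*example, out])  # the result is written into `out`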
Example #8
    def test_general_shape_ops(self):
        """ A test that checks dequantize will be swapped for
        all supported general shape ops like aten::flatten
        without actually checking for execution of these ops
        """
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.maxpool1d = torch.nn.MaxPool1d(kernel_size=3)
                self.maxpool2d = torch.nn.MaxPool2d(kernel_size=3)
                self.maxpool3d = torch.nn.MaxPool3d(kernel_size=3)
                self.dropout = torch.nn.Dropout()
                self.conv1 = torch.nn.Conv2d(3, 3, 3)
                self.conv2 = torch.nn.Conv2d(3, 3, 3)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.conv1(x)
                # add_scalar
                x = x + 3
                # mul_scalar
                x = x * 3
                # add_scalar_out
                x += 3
                # mul_scalar_out
                x *= 3
                # add_scalar_relu
                x = x + 3
                x = F.relu(x)
                # add_scalar_relu_out
                x += 3
                x = F.relu(x)
                # mul_scalar_relu
                x = x * 3
                x = F.relu(x)
                # mul_scalar_relu_out
                x *= 3
                x = F.relu(x)
                x = self.maxpool1d(x)
                x = self.maxpool2d(x)
                x = self.maxpool3d(x)
                x = torch.flatten(x)
                x = torch.max(x)
                x = torch.min(x)
                x = x.reshape([-1])
                x = x.resize_(1, 1, x.numel())
                x = x.view(-1)
                # prim::ListConstruct
                xs = [x, x]
                # prim::ListUnpack
                x, y = xs
                # prim::TupleConstruct
                xs = (x, x)
                # prim::TupleUnpack
                x, y = xs
                x = x.transpose(1, 2)
                x = x.contiguous()
                x, y = torch.chunk(x, 2)
                x = F.dropout(x)
                x = self.dropout(x)
                x, _ = torch.sort(x)
                x = x.permute(0, 2, 3, 1)
                x = x.repeat_interleave(3, 1)
                x = torch.repeat_interleave(x, 3, 1)
                x = self.relu(x)
                x = F.relu(x)
                x = F.relu(x, inplace=True)
                x = x.relu()
                x.relu_()
                x = x.squeeze(0)
                x.squeeze_(0)
                x = torch.squeeze(x, 0)
                x = x.unsqueeze(0)
                x.unsqueeze_(0)
                x = torch.unsqueeze(x, 0)
                x = x.detach()
                x.detach_()
                x = x.repeat(4, 2)
                y = []
                y.append(x)
                z = torch.stack(y, 0)
                z = [z, z]
                x, _ = z
                x = self.conv2(x)
                return x

        data = torch.rand(1, 3, 10, 10)
        # This model is not executable since we just put all ops
        # in the same forward
        m = M()
        original = symbolic_trace(m)
        # nothing to fuse so skipping the fuse step
        quantizer = Quantizer()
        qconfig_dict = {'': default_qconfig}
        prepared = quantizer.prepare(original, qconfig_dict)
        # not runnable
        quantized = quantizer.convert(prepared)

        # This checks that the dequantize from the output of first conv
        # is being propagated to the end, so that we don't insert extra
        # observers and also successfully fused two quantized::conv2d
        # patterns
        # one quantize_per_tensor for input
        # check exact counts of quantize and dequantize
        count_check = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method('dequantize'): 1
        }
        order_check = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize'),
        ]
        self.checkGraphModuleNodes(quantized,
                                   expected_node_occurrence=count_check,
                                   expected_node_list=order_check)
Example #9
    def test_general_value_ops(self):
        """ A test that checks that correct patterns are produced for
        all supported general value ops like aten::avg_pool2d,
        without actually checking for execution of these ops
        """
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.avg_pool1d = torch.nn.AvgPool1d(3)
                self.avg_pool2d = torch.nn.AvgPool2d(3)
                self.avg_pool3d = torch.nn.AvgPool3d(3)
                self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d((1))
                self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
                self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d(
                    (1, 1, 1))
                self.leaky_relu = torch.nn.LeakyReLU()
                self.hardsigmoid = torch.nn.Hardsigmoid()
                self.sigmoid = torch.nn.Sigmoid()
                self.tanh = torch.nn.Tanh()

            def forward(self, x):
                x = self.conv(x)
                x = self.avg_pool1d(x)
                x = self.avg_pool2d(x)
                x = self.avg_pool3d(x)
                x = self.adaptive_avg_pool1d(x)
                x = self.adaptive_avg_pool2d(x)
                x = self.adaptive_avg_pool3d(x)
                x = F.avg_pool1d(x, 3)
                x = F.avg_pool2d(x, 3)
                x = F.avg_pool3d(x, 3)
                x = F.adaptive_avg_pool1d(x, (1))
                x = F.adaptive_avg_pool2d(x, (1, 1))
                x = F.adaptive_avg_pool3d(x, (1, 1, 1))
                x = torch.mean(x)
                x = torch.mean(x, [2, 3], False)
                x = x.mean()
                x = x.mean([2, 3], True)
                x = F.interpolate(x, 4, mode='nearest')
                x = F.interpolate(x, 4, mode='linear')
                x = self.leaky_relu(x)
                x = F.leaky_relu(x)
                x = F.leaky_relu(x, inplace=True)
                x = x.leaky_relu()
                x.leaky_relu_()
                x = self.hardsigmoid(x)
                x = F.hardsigmoid(x)
                x = F.hardsigmoid(x, inplace=True)
                x = x.hardsigmoid()
                x.hardsigmoid_()
                x = self.sigmoid(x)
                x = torch.sigmoid(x)
                # F.sigmoid is deprecated
                x = x.sigmoid()
                x.sigmoid_()
                x = self.tanh(x)
                # F.tanh is deprecated
                x = torch.tanh(x)
                x = x.tanh()
                x.tanh_()
                x = self.conv(x)
                return x

        # This model is not executable since we just put all ops
        # in the same forward
        m = M()
        original = symbolic_trace(m)
        # nothing to fuse so skipping the fuse step
        quantizer = Quantizer()
        qconfig_dict = {'': default_qconfig}
        prepared = quantizer.prepare(original, qconfig_dict)
        # not runnable
        quantized = quantizer.convert(prepared)

        # This checks that the dequantize from the output of first conv
        # is being propagated to the end, so that we don't insert extra
        # observers
        # check exact counts of quantize and dequantize
        count_check = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method('dequantize'): 1
        }
        order_check = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize'),
        ]
        self.checkGraphModuleNodes(quantized,
                                   expected_node_occurrence=count_check,
                                   expected_node_list=order_check)
Example #10
def replace_pattern(gm: GraphModule, pattern: Callable,
                    replacement: Callable) -> None:
    """
    Matches all possible non-overlapping sets of operators and their
    data dependencies (``pattern``) in the Graph of a GraphModule
    (``gm``), then replaces each of these matched subgraphs with another
    subgraph (``replacement``).

    Args:
        ``gm``: The GraphModule that wraps the Graph to operate on
        ``pattern``: The subgraph to match in ``gm`` for replacement
        ``replacement``: The subgraph to replace ``pattern`` with

    Examples:

    .. code-block:: python

        import torch
        from torch.fx import symbolic_trace, subgraph_rewriter

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, w1, w2):
                m1 = torch.cat([w1, w2]).sum()
                m2 = torch.cat([w1, w2]).sum()
                return x + torch.max(m1) + torch.max(m2)

        def pattern(w1, w2):
            return torch.cat([w1, w2]).sum()

        def replacement(w1, w2):
            return torch.stack([w1, w2])

        traced_module = symbolic_trace(M())

        subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

    The above code will first match ``pattern`` in the ``forward``
    method of ``traced_module``. Pattern-matching is done based on
    use-def relationships, not node names. For example, if you had
    ``p = torch.cat([a, b])`` in ``pattern``, you could match
    ``m = torch.cat([a, b])`` in the original ``forward`` function,
    despite the variable names being different (``p`` vs ``m``).

    The ``return`` statement in ``pattern`` is matched based on its
    value only; it may or may not match to the ``return`` statement in
    the larger graph. In other words, the pattern doesn't have to extend
    to the end of the larger graph.

    When the pattern is matched, it will be removed from the larger
    function and replaced by ``replacement``. If there are multiple
    matches for ``pattern`` in the larger function, each non-overlapping
    match will be replaced. In the case of a match overlap, the first
    found match in the set of overlapping matches will be replaced.
    ("First" here being defined as the first in a topological ordering
    of the Nodes' use-def relationships. In most cases, the first Node
    is the parameter that appears directly after ``self``, while the
    last Node is whatever the function returns.)

    One important thing to note is that the parameters of the
    ``pattern`` Callable must be used in the Callable itself,
    and the parameters of the ``replacement`` Callable must match
    the pattern. The first rule is why, in the above code block, the
    ``forward`` function has parameters ``x, w1, w2``, but the
    ``pattern`` function only has parameters ``w1, w2``. ``pattern``
    doesn't use ``x``, so it shouldn't specify ``x`` as a parameter.
    As an example of the second rule, consider replacing

    .. code-block:: python

        def pattern(x, y):
            return torch.neg(x) + torch.relu(y)

    with

    .. code-block:: python

        def replacement(x, y):
            return torch.relu(x)

    In this case, ``replacement`` needs the same number of parameters
    as ``pattern`` (both ``x`` and ``y``), even though the parameter
    ``y`` isn't used in ``replacement``.

    After calling ``subgraph_rewriter.replace_pattern``, the generated
    Python code looks like this:

    .. code-block:: python

        def forward(self, x, w1, w2):
            stack_1 = torch.stack([w1, w2])
            sum_1 = stack_1.sum()
            stack_2 = torch.stack([w1, w2])
            sum_2 = stack_2.sum()
            max_1 = torch.max(sum_1)
            add_1 = x + max_1
            max_2 = torch.max(sum_2)
            add_2 = add_1 + max_2
            return add_2

    """
    # Get the graphs for `gm`, `pattern`, `replacement`
    original_graph = gm.graph
    pattern_graph = symbolic_trace(pattern).graph
    replacement_graph = symbolic_trace(replacement).graph

    # Find all possible pattern matches in original_graph. Note that
    # pattern matches may overlap with each other.
    matcher = SubgraphMatcher(pattern_graph)
    matches: List[Match] = []

    # Consider each node as an "anchor" (deepest matching graph node)
    for anchor in original_graph.nodes:

        if matcher.matches_subgraph_from_anchor(anchor):

            def pattern_is_contained(nodes_map: Dict[Node, Node]) -> bool:
                # `lookup` represents all the nodes in `original_graph`
                # that are part of `pattern`
                lookup: Dict[Node, Node] = {v: k for k, v in nodes_map.items()}
                for n in lookup.keys():
                    if n.op == "placeholder" or lookup[n].op == "output":
                        continue
                    for user in n.users:
                        # If this node has users that were not in
                        # `lookup`, then it must leak out of the
                        # pattern subgraph
                        if user not in lookup:
                            return False
                return True

            # It's not a match if the pattern leaks out into the rest
            # of the graph
            if pattern_is_contained(matcher.nodes_map):
                # Shallow copy nodes_map; appending once per match (the
                # original looped over nodes_map, adding duplicate Matches)
                matches.append(
                    Match(anchor=anchor,
                          nodes_map=copy.copy(matcher.nodes_map)))

    # The set of all nodes in `original_graph` that we've seen thus far
    # as part of a pattern match
    replaced_nodes: Set[Node] = set()

    # Return TRUE if one of the nodes in the current match has already
    # been used as part of another match
    def overlaps_with_prev_match(match: Match) -> bool:
        for n in match.nodes_map.values():
            if n in replaced_nodes and n.op != "placeholder":
                return True
        return False

    for match in matches:

        # Skip overlapping matches
        if overlaps_with_prev_match(match):
            continue

        # Map replacement graph nodes to their copy in `original_graph`
        val_map: Dict[Node, Node] = {}

        pattern_placeholders = [
            n for n in pattern_graph.nodes if n.op == "placeholder"
        ]
        assert len(pattern_placeholders)
        replacement_placeholders = [
            n for n in replacement_graph.nodes if n.op == "placeholder"
        ]
        assert len(pattern_placeholders) == len(replacement_placeholders)
        placeholder_map = {
            r: p
            for r, p in zip(replacement_placeholders, pattern_placeholders)
        }

        # node from `original_graph` that matched with the output node
        # in `pattern`
        subgraph_output: Node = match.anchor

        def mark_node_as_replaced(n: Node) -> None:
            if n not in match.nodes_map.values():
                return
            for n_ in n.all_input_nodes:
                mark_node_as_replaced(n_)
            replaced_nodes.add(n)

        mark_node_as_replaced(subgraph_output)

        # Initialize `val_map` with mappings from placeholder nodes in
        # `replacement` to their corresponding node in `original_graph`
        for replacement_node in replacement_placeholders:
            # Get the `original_graph` placeholder node
            # corresponding to the current `replacement_node`
            pattern_node = placeholder_map[replacement_node]
            original_graph_node = match.nodes_map[pattern_node]
            # Populate `val_map`
            val_map[replacement_node] = original_graph_node

        # Copy the replacement graph over
        with original_graph.inserting_before(subgraph_output):
            copied_output = original_graph.graph_copy(replacement_graph,
                                                      val_map)
        assert isinstance(copied_output, Node)

        # We only want to copy in the output node from `pattern` if we
        # have an output-output match. Otherwise, we leave out the
        # `pattern` output node so we don't have two outputs in the
        # resultant graph
        if subgraph_output.op != "output":
            subgraph_output = subgraph_output.args[0]  # type: ignore
        subgraph_output.replace_all_uses_with(copied_output)

        # Erase the `pattern` nodes
        for node in reversed(original_graph.nodes):
            if len(node.users) == 0 and node.op != "output":
                original_graph.erase_node(node)

    # Update the passed-in GraphModule to reflect the new state of
    # `original_graph`
    gm.recompile()
Example #11
    def checkGraphModeFxOp(self,
                           model,
                           inputs,
                           check_spec,
                           quant_type=QuantType.STATIC,
                           debug=False):
        """ Quantizes model with graph mode quantization on fx and check if the
        quantized model contains the quantized_node

        Args:
            model: floating point torch.nn.Module
            inputs: a tuple of positional sample inputs for the model
            check_spec:
              either:
                quantized_node: a tuple of 2 elements, first element is
                   the op type for GraphModule node, second element
                   is the target function for call_function and
                   type of the module for call_module
                   e.g. ('call_function', torch.ops.quantized.conv2d)
                node map: a dict from node (tuple of 2 elements) to
                   number of occurrences (int)
                   e.g. {('call_function', torch.quantize_per_tensor): 1}
                ordered node list: a list of nodes (tuples of 2 elements)
                   e.g. [('call_function', torch.quantize_per_tensor),
                         ('call_function', torch.dequantize)]
        """
        # TODO: make img_data a single example instead of a list
        if type(inputs) == list:
            inputs = inputs[0]
        if quant_type == QuantType.QAT:
            model.train()
        else:
            model.eval()
        original = symbolic_trace(model)
        fused = fuse(original)

        quantizer = Quantizer()
        # TODO: uncomment after we make per channel observer work in the flow
        # qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}
        qconfig_dict = {'': default_qconfig}
        if quant_type == QuantType.DYNAMIC:
            prepared = quantizer.prepare_dynamic(fused, qconfig_dict)
        else:
            prepared = quantizer.prepare(fused, qconfig_dict)

        prepared(*inputs)
        qgraph = quantizer.convert(prepared)
        qgraph_debug = quantizer.convert(prepared, debug=True)

        result = qgraph(*inputs)
        result_debug = qgraph_debug(*inputs)

        self.assertEqual(
            (result - result_debug).abs().max(), 0,
            'Expecting debug and non-debug option to produce identical result')

        if debug:
            print()
            print('quant type:', quant_type)
            print('original graph module:', type(model))
            self.printGraphModule(original)
            print()
            print('quantized graph module:', type(qgraph))
            self.printGraphModule(qgraph)
            print()
        self.checkGraphModule(qgraph, check_spec)
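
    # Hedged usage sketch (hypothetical module): exercising the
    # "quantized_node" flavor of check_spec described in the docstring.
    def test_static_conv2d(self):
        class ConvModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)

            def forward(self, x):
                return self.conv(x)

        self.checkGraphModeFxOp(
            ConvModel(), (torch.rand(1, 3, 8, 8),),
            check_spec=('call_module', torch.nn.quantized.Conv2d),
            quant_type=QuantType.STATIC)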
Example #12
# Non-torch annotation with no internal forward references
class M3(torch.nn.Module):
    def forward(self, x: typing.List[torch.Tensor], a: A) -> torch.Tensor:
        return a(x[0])

# Non-torch annotation with internal forward references
class M4(torch.nn.Module):
    def forward(self, x: typing.List['torch.Tensor'], a: A) -> 'torch.Tensor':
        return a(x[0])

x = torch.rand(2, 3)

ref = torch.add(x, x)

traced1 = symbolic_trace(M1())
res1 = traced1(x, A())
assert torch.all(torch.eq(ref, res1))

traced2 = symbolic_trace(M2())
res2 = traced2(x, A())
assert torch.all(torch.eq(ref, res2))

traced3 = symbolic_trace(M3())
res3 = traced3([x], A())
assert torch.all(torch.eq(ref, res3))

traced4 = symbolic_trace(M4())
res4 = traced4([x], A())
assert torch.all(torch.eq(ref, res4))
Example #13
    def test_functional(self):
        """ Test quantizing functional conv and linear
        """
        class Conv(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.stride = (1, 1)
                self.padding = (0, 0)
                self.dilation = (1, 1)
                self.groups = 1

            def forward(self, x, weight):
                return F.conv2d(x, weight, None, self.stride, self.padding,
                                self.dilation, self.groups)

        conv_input = torch.rand(1, 3, 224, 224)
        conv_weight = torch.rand(3, 3, 3, 3)

        class Linear(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, weight):
                return F.linear(x, weight)

        linear_input = torch.rand(8, 5)
        linear_weight = torch.rand(10, 5)

        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        linear_module_input = torch.rand(8, 5)

        tests = [
            (False, Conv, (conv_input, conv_weight),
             ('call_function', torch.ops.quantized.conv2d)),
            (True, Linear, (linear_input, linear_weight),
             ('call_function', torch.ops.quantized.linear_dynamic)),
            (False, Linear, (linear_input, linear_weight),
             ('call_function', torch.ops.quantized.linear)),
            (True, LinearModule, (linear_module_input, ),
             ('call_module', torch.nn.quantized.dynamic.Linear)),
            (False, LinearModule, (linear_module_input, ),
             ('call_module', torch.nn.quantized.Linear)),
        ]

        for is_dynamic, M, inputs, quantized_node in tests:
            m = M().eval()
            qconfig = default_qconfig

            graph = symbolic_trace(m)
            script = torch.jit.script(graph)

            a = m(*inputs)
            b = graph(*inputs)
            c = script(*inputs)
            assert (a - b).abs().max() == 0
            assert (a - c).abs().max() == 0
            assert torch.allclose(a, b)
            assert torch.allclose(a, c)

            graph = fuse(graph)

            quantizer = Quantizer()
            qconfig_dict = {'': qconfig}
            if is_dynamic:
                prepared = quantizer.prepare_dynamic(graph, qconfig_dict)
            else:
                prepared = quantizer.prepare(graph, qconfig_dict)

            prepared(*inputs)

            qgraph = quantizer.convert(prepared)
            qgraph_debug = quantizer.convert(prepared, debug=True)
            qgraph.eval()
            qgraph_debug.eval()
            qgraph_script = torch.jit.script(qgraph)

            d = qgraph(*inputs)
            d_debug = qgraph_debug(*inputs)
            e = qgraph_script(*inputs)
            e_debug = qgraph_debug(*inputs)

            found = False
            modules = dict(qgraph.root.named_modules())
            for node in qgraph.graph.nodes:
                if node.op == 'call_function':
                    found = found or (node.op == quantized_node[0]
                                      and node.target == quantized_node[1])
                elif node.op == 'call_module':
                    found = found or (node.op == quantized_node[0]
                                      and type(modules[node.target]) ==
                                      quantized_node[1])
            assert found, 'Expected to find quantized node: ' + str(
                quantized_node)
            # assert (a-d).abs().max() < 2
            assert torch.allclose(d, e)
            assert (d - d_debug).abs().max() == 0
            assert (e - e_debug).abs().max() == 0
Example #14
def nnc_compile(model: torch.nn.Module, example_inputs) -> torch.nn.Module:
    """
    nnc_compile(model, example_inputs) returns a function with the same args
    as `model.forward`, with an extra argument corresponding to where the
    output is stored. This function takes the inputs (which must be PyTorch
    tensors with the same shapes as example_inputs), and passes them to an
    NNC executor.
    """
    fx_model = fx.symbolic_trace(model)
    ShapeProp(fx_model).propagate(*example_inputs)

    # This env maps from nodes to `te.ExprHandle`, which represent the output
    # of an NNC computation.
    env = {}

    def get_te_type(node):
        return get_nnc_type(node.meta['tensor_meta'].dtype)

    def gen_compute(args):
        te_args = [env[arg.name] for arg in args]

    def lookup_env(l):
        res = fx.node.map_aggregate(
            l, lambda x: env[x.name] if isinstance(x, fx.Node) else x)
        return res

    def fetch_attr(target: str):
        target_atoms = target.split('.')
        attr_itr = fx_model
        for i, atom in enumerate(target_atoms):
            if not hasattr(attr_itr, atom):
                raise RuntimeError(
                    f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}"
                )
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    outs = None
    inputs = []
    module_attrs = []
    compute_stmts = []
    for node in fx_model.graph.nodes:
        if node.op == 'placeholder':
            # We simply map the input placeholder to a `te.Placeholder`, which
            # also represents an input to the NNC computation.
            shapes = get_te_shapes(node.meta['tensor_meta'].shape)
            placeholder = te.Placeholder(node.name, get_te_type(node), shapes)
            env[node.name] = placeholder.data()
            inputs.append(placeholder)
        elif node.op == 'call_function':
            # This does the bulk of the work - we call `lower_function`, which
            # returns a `te.ExprHandle` (the output of a NNC computation), and
            # put it in our environment.
            if 'tensor_meta' in node.meta:
                # todo: fix kwargs handling
                if node.kwargs:
                    raise RuntimeError("kwargs nyi")
                buf, stmt = lower_function(node, node.target,
                                           lookup_env(node.args), node.args)
                # if isinstance(stmt, list)
                compute_stmts.extend(stmt)
                env[node.name] = buf
            elif node.target == getattr or node.target == operator.getitem:
                # todo: handle non-tensor computations correctly
                continue
        elif node.op == 'output':
            args = node.args
            if not isinstance(args, tuple):
                args = (args, )
            if isinstance(args[0], tuple):
                args = args[0]
            te_args = lookup_env(args)
            outs = (list(te_args), [(i.meta['tensor_meta'].shape,
                                     i.meta['tensor_meta'].dtype)
                                    for i in args])
        elif node.op == 'get_attr':
            # As NNC doesn't have any concept of state, we pull out the module
            # attributes and pass them in as inputs to NNC.
            module_attrs.append(node)
            shapes = get_te_shapes(
                process_shape(node.meta['tensor_meta'].shape))
            placeholder = te.Placeholder(node.name, get_te_type(node), shapes)
            env[node.name] = placeholder.data()
        else:
            print(node.op, node.target)
            raise RuntimeError("not yet implemented")

    loopnest = te.LoopNest(te.Stmt(compute_stmts), outs[0])
    # loopnest.inline_intermediate_bufs(True)
    loopnest.simplify()
    loopnest.prepare_for_codegen()
    stmt = te.simplify(loopnest.root_stmt())
    cg = te.construct_codegen('llvm', stmt, [
        te.BufferArg(x)
        for x in [env[i.name] for i in module_attrs] + inputs + outs[0]
    ])
    alloc_results = [
        torch.empty(shape, dtype=dtype) for shape, dtype in outs[1]
    ]
    if module_attrs:
        module_stuff = [
            fetch_attr(i.target).contiguous().data for i in module_attrs
        ]
    else:
        module_stuff = []

    def f(*inps, out_tensors=None):
        # begin = time.time()
        if out_tensors is None:
            results = alloc_results
        else:
            results = out_tensors
        cg.call(module_stuff + list(inps) + results)
        if out_tensors is None:
            if len(results) == 1:
                return results[0]
            return results

    return f
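
# Hedged usage sketch: `out_tensors` is the "extra argument" from the
# docstring above; pass a preallocated output, or omit it and let the
# compiled function hand back an internally allocated result.
def g(a, b):  # hypothetical elementwise function
    return torch.cos(a) * torch.sin(b)

gm = fx.symbolic_trace(g)
example = (torch.randn(1024), torch.randn(1024))
compiled = nnc_compile(gm, example)
out = torch.empty(1024)
compiled(*example, out_tensors=[out])  # writes into `out`
res = compiled(*example)               # allocates the output instead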
Example #15
    def test_subgraph_rewriter_placeholder_matching(self):
        """
        This tests that a placeholder Node can be matched to a Node with
        a different number of input Nodes. In the example below, the
        original traced Module looks like this:
            opcode         target                                                      args                      kwargs
            -------------  ----------------------------------------------------------  ------------------------  --------
            placeholder    x                                                           ()                        {}
            call_function  <built-in function add>                                     (x, 3)                    {}
            call_method    dequantize                                                  (add,)                    {}
            call_function  <built-in method sigmoid of type object at 0x7f7c1f440fe0>  (dequantize,)             {}
            call_method    to                                                          (sigmoid, torch.float16)  {}
            output         output                                                      (to,)                     {}
        while the pattern we want to match looks like this:
            opcode         target                                                      args                      kwargs
            -------------  ----------------------------------------------------------  ------------------------  --------
            placeholder    x                                                           ()                        {}
            call_method    dequantize                                                  (x,)                      {}
            call_function  <built-in method sigmoid of type object at 0x7f7c1f440fe0>  (dequantize,)             {}
            call_method    to                                                          (sigmoid, torch.float16)  {}
            output         output                                                      (to,)                     {}
        Here, we want to be able to match the original graph's
        `call_function.add` Node with the pattern graph's
        `placeholder.x` Node.
        Credit to Jerry Zhang (GitHub: jerryzh168) for this test case
        """
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dtype = torch.float16

            def forward(self, x):
                x += 3
                x = x.dequantize()
                x = torch.sigmoid(x)
                dtype = self.dtype
                x = x.to(dtype)
                return x

        def pattern(x):
            x = x.dequantize()
            x = torch.sigmoid(x)
            x = x.to(torch.float16)
            return x

        def replacement(x):
            return x

        def comparison(x):
            return x + 3

        traced = symbolic_trace(M())
        comparison_fn = symbolic_trace(comparison)

        x = torch.randn(3, 4)

        subgraph_rewriter.replace_pattern(traced, pattern, replacement)

        traced.graph.lint()

        ref_outs = comparison_fn(x)
        test_outs = traced.forward(x)
        self.assertEqual(ref_outs, test_outs)
Example #16
    def _test_model_impl(self,
                         mode,
                         name,
                         model,
                         eager_quantizable_model,
                         check_with_eager=True,
                         diff_of_quant=None,
                         diff_from_eager=None):
        if diff_of_quant is None or diff_from_eager is None:
            diff_of_quant = {}
            diff_from_eager = {}

        if mode not in diff_of_quant or mode not in diff_from_eager:
            diff_of_quant[mode] = {}
            diff_from_eager[mode] = {}

        input_tensor = torch.rand(1, 3, 224, 224)
        input_tensor_inception = torch.rand(1, 3, 299, 299)
        output_value = torch.randint(0, 1, (1, ))

        # print('quantizing:', name, ' mode:', mode)
        if name == 'inception_v3':
            input_value = input_tensor_inception
        else:
            input_value = input_tensor

        qconfig = default_qconfig if mode == 'static' else default_qat_qconfig
        qconfig_dict = {'': qconfig}
        graph_module = symbolic_trace(model)
        # print('graph module:', graph_module.src)
        script = torch.jit.script(graph_module)

        # make sure graph module and script module are both runnable
        original_out = graph_module(input_value)
        is_not_tuple_out = not isinstance(original_out, tuple)
        script_out = script(input_value)
        self.assertEqual(
            (original_out - script_out).abs().max(), 0,
            'Results of original graph module and script module do not match')

        # set to train just before quantization
        if mode != 'static':
            model.train()

        graph_module = fuse(graph_module)
        quantizer = Quantizer()
        prepared = quantizer.prepare(graph_module, qconfig_dict)

        if mode == 'ddp':
            mp.spawn(run_ddp,
                     args=(world_size, prepared),
                     nprocs=world_size,
                     join=True)
        elif mode == 'qat':
            assert prepared.training, 'prepared must be in training mode for qat'
            optimizer = torch.optim.SGD(prepared.parameters(), lr=0.0001)
            criterion = nn.CrossEntropyLoss()
            train_one_epoch(prepared, criterion,
                            optimizer, [(input_value, output_value)],
                            torch.device('cpu'), 1)
        else:
            for i in range(10):
                prepared(input_value)

        # print('after observation root:', prepared.root)

        qgraph = quantizer.convert(prepared)
        # print('after quantization root:', qgraph.root)
        # print('after quantization code:', qgraph.src)
        qgraph.eval()
        qgraph_script = torch.jit.script(qgraph)
        # print('quantized and scripted:', qgraph_script.graph)

        qgraph_out = qgraph(input_value)
        qgraph_script_out = qgraph_script(input_value)

        if is_not_tuple_out:
            diff_of_quant[mode][name] = (original_out - qgraph_out).abs().max()
            assert torch.allclose(qgraph_out, qgraph_script_out), \
                'graph and scripted graph outputs should match'
        else:
            print('tuple output')

        if eager_quantizable_model is not None:
            # comparing to eager mode quantization
            qeager = eager_quantizable_model
            ref_out = qeager(input_value)
            qeager.qconfig = qconfig
            if mode == 'static':
                qeager.fuse_model()
                prepare(qeager, inplace=True)
            else:
                qeager.train()
                qeager.fuse_model()
                prepare_qat(qeager, inplace=True)

            # calibration
            if mode == 'ddp':
                mp.spawn(run_ddp,
                         args=(world_size, qeager),
                         nprocs=world_size,
                         join=True)
            elif mode == 'qat':
                assert qeager.training, 'qeager should be in training mode for qat'
                optimizer = torch.optim.SGD(qeager.parameters(), lr=0.0001)
                train_one_epoch(qeager, criterion, optimizer,
                                [(input_value, output_value)],
                                torch.device('cpu'), 1)
            else:
                for i in range(10):
                    qeager(input_value)

            # print('ref after observation:', qeager)

            convert(qeager, inplace=True)
            qeager.eval()

            # print('ref after quantization:', qeager)
            qeager_out = qeager(input_value)
            qeager_script = torch.jit.script(qeager)
            qscript_out = qeager_script(input_value)
            if is_not_tuple_out:
                diff_from_eager[mode][name] = (qeager_out -
                                               qgraph_out).abs().max()
                if check_with_eager:
                    self.assertEqual(
                        diff_from_eager[mode][name], 0,
                        'Result of graph mode quantization and ' +
                        'eager mode quantization on model: ' + name +
                        ' should match. Mode: ' + mode + ' diff:' +
                        str(diff_from_eager[mode][name]))
Example #17
    return (q, p)


q_i = torch.tensor(onp.array(q_i['z']))
p_i = torch.tensor(onp.array(p_i(seed)['z']))
inv_mass_matrix = torch.eye(D)
# inv_mass_matrix = torch.ones(D)
step_size = 0.001
num_steps = 10000

import torch.fx as fx
# is_u_turning = nnc_compile(wrap_key(is_u_turning, (q_i,p_i,q_i)), (q_i, p_i, q_i))
# is_u_turning = nnc_compile(is_u_turning, example_inputs=(q_i, p_i, q_i))

inps = (q_i, p_i, potential_fn, inv_mass_matrix, step_size)
leapfrog = fx.symbolic_trace(wrap_key(leapfrog, inps))
leapfrog = remove_args(leapfrog)
# leapfrog = truncate(leapfrog, 7)
# leapfrog = fx.symbolic_trace(leapfrog, concrete_args={'potential_fn': potential_fn, 'step_size': step_size, 'inverse_mass_matrix': inv_mass_matrix})

# if VERSION == 'nnc':
leapfrog = nnc_compile(leapfrog, example_inputs=(q_i, p_i, inv_mass_matrix))
# elif VERSION == 'ts':
#     leapfrog = torch.jit.script(leapfrog)
# elif VERSION == 'pt':
#     pass

out = get_final_state(potential_fn, inv_mass_matrix, step_size, q_i, p_i)
begin = time.time()
out = get_final_state(potential_fn, inv_mass_matrix, step_size, q_i, p_i)
print(out[0][0])
Example #18
def truncate(model: fx.GraphModule, k: int) -> fx.GraphModule:
    # Reconstructed prelude (the original snippet begins mid-function):
    # copy the first `k` nodes into a fresh graph, then emit an output.
    new_graph = fx.Graph()
    env: dict = {}
    cnt = 0
    for node in list(model.graph.nodes):
        new_node = new_graph.node_copy(node, lambda x: env[x.name])
        env[node.name] = new_node
        cnt += 1
        if cnt == k:
            new_graph.output(env[node.name])
            break

    return fx.GraphModule(model, new_graph)


torch.manual_seed(0)
inputs = (torch.randn(1, 3, 256, 256),)
model = mobilenet_v2()
model.eval()
fx_model = fx.symbolic_trace(model)
fx_model = fuse(fx_model)
fx_model = decompose(fx_model, example_inputs=inputs)
# f = torch.jit.freeze(model)
# torch._C._fancy_compile(f.graph, list(inputs[0].shape))
# f(*inputs)
# fx_model = truncate(fx_model, 4)
print(fx_model)
# print(fx_model.code)
nnc_model = nnc_compile(fx_model, example_inputs=inputs)
import time
iters = 1
with torch.no_grad():
    nnc_model(*inputs)
    begin = time.time()
    # while True:

Example #19
# Sample module
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(x) + 1.0


# Symbolically trace an instance of `M`. After tracing, `self.relu` is
# represented as a `call_module` Node. The full operation in the
# generated `forward` function's code will appear as `self.relu(x)`
m = symbolic_trace(M())

# Insert nodes from the ReLU graph in place of the original call to
# `self.relu`
for node in m.graph.nodes:
    # Find `call_module` Node in `m` that corresponds to `self.relu`.
    # This is the Node we want to swap out for an inlined version of the
    # same call
    if (node.op, node.target) == ("call_module", "relu"):
        with m.graph.inserting_before(node):
            # Create a Proxy from each Node in the current Node's
            # args/kwargs
            proxy_args = map_arg(node.args, Proxy)
            proxy_kwargs = map_arg(node.kwargs, Proxy)
            # Call `m.relu` with the newly-created Proxy arguments.
            # `m.relu` is the generic version of the function; by
Example #20
    def checkGraphModeFxOp(self,
                           model,
                           inputs,
                           quant_type,
                           expected_node=None,
                           expected_node_occurrence=None,
                           expected_node_list=None,
                           debug=False,
                           print_debug_info=False):
        """ Quantizes model with graph mode quantization on fx and check if the
        quantized model contains the quantized_node

        Args:
            model: floating point torch.nn.Module
            inputs: a tuple of positional sample inputs for the model
            expected_node: NodeSpec
                  e.g. NodeSpec.call_function(torch.quantize_per_tensor)
            expected_node_occurrence: a dict from NodeSpec to
                  expected number of occurrences (int)
                  e.g. {NodeSpec.call_function(torch.quantize_per_tensor) : 1,
                        NodeSpec.call_method('dequantize'): 1}
            expected_node_list: a list of NodeSpec, used to check the order
                  of the occurrence of Node
                  e.g. [NodeSpec.call_function(torch.quantize_per_tensor),
                        NodeSpec.call_module(nnq.Conv2d),
                        NodeSpec.call_function(F.hardtanh_),
                        NodeSpec.call_method('dequantize')]
        """
        # TODO: make img_data a single example instead of a list
        if type(inputs) == list:
            inputs = inputs[0]
        if quant_type == QuantType.QAT:
            model.train()
        else:
            model.eval()
        original = symbolic_trace(model)
        fused = fuse_fx(original)

        qconfig_dict = {
            '': get_default_qconfig(torch.backends.quantized.engine)
        }
        if quant_type == QuantType.DYNAMIC:
            prepare = prepare_dynamic_fx
            convert = convert_dynamic_fx
        else:
            prepare = prepare_fx
            convert = convert_fx

        prepared = prepare(fused, qconfig_dict)
        prepared(*inputs)
        qgraph = convert(prepared)
        qgraph_debug = convert(prepared, debug=True)

        result = qgraph(*inputs)
        result_debug = qgraph_debug(*inputs)

        self.assertEqual(
            (result - result_debug).abs().max(), 0,
            'Expecting debug and non-debug option to produce identical result')

        if print_debug_info:
            print()
            print('quant type:', quant_type)
            print('original graph module:', type(model))
            self.printGraphModule(original)
            print()
            print('quantized graph module:', type(qgraph))
            self.printGraphModule(qgraph)
            print()
        qgraph_to_check = qgraph_debug if debug else qgraph
        self.checkGraphModuleNodes(qgraph_to_check, expected_node,
                                   expected_node_occurrence,
                                   expected_node_list)
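
    # Hedged usage sketch (hypothetical module): the three NodeSpec-based
    # check arguments described in the docstring, used together.
    def test_static_conv2d_node_specs(self):
        class ConvModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)

            def forward(self, x):
                return self.conv(x)

        self.checkGraphModeFxOp(
            ConvModel(), (torch.rand(1, 3, 8, 8),), QuantType.STATIC,
            expected_node=NodeSpec.call_module(nnq.Conv2d),
            expected_node_occurrence={
                NodeSpec.call_function(torch.quantize_per_tensor): 1,
                NodeSpec.call_method('dequantize'): 1},
            expected_node_list=[
                NodeSpec.call_function(torch.quantize_per_tensor),
                NodeSpec.call_module(nnq.Conv2d),
                NodeSpec.call_method('dequantize')])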
Example #21
    def test_update_args_kwargs_yells_at_you(self):
        symtraced = symbolic_trace(SimpleTest())
        node = next(iter(symtraced.graph.nodes))
        with self.assertRaisesRegex(AttributeError, '__update_args_kwargs'):
            node.__update_args_kwargs((), {})
Example #22
    def test_namedtuple_return_trace(self):
        class NamedTupReturn(torch.nn.Module):
            def forward(self, x):
                return Pair(x, x)

        traced = symbolic_trace(NamedTupReturn())
Example #23
    def test_torch_fx_len(self):
        class FXLenTest(torch.nn.Module):
            def forward(self, x):
                return torch.fx.len(x)

        traced = symbolic_trace(FXLenTest())
Example #24
    def _prepare(self, model, qconfig_dict, inplace, is_dynamic_quant,
                 is_standalone_module):
        """ standalone_module means it a submodule that is not inlined in parent module,
        and will be quantized separately as one unit.

        When we are preparing a standalone module:
        input of the module is observed in parent module, output of the module
        is observed in the standalone module.
        Returns:
            model(GraphModule): prepared standalone module with following attributes:
                _standalone_module_observed_input_idxs(List[Int]): a list of indices for the graph inputs that
                                         need to be observed in the parent module
                _output_is_observed(Bool): a boolean variable indicating whether the output of the
                                   custom module is observed or not
        """
        if not inplace:
            model = copy.deepcopy(model)
        self.is_dynamic_quant = is_dynamic_quant
        if self.is_dynamic_quant:
            self.patterns = get_dynamic_quant_patterns()
        else:
            self.patterns = get_quant_patterns()

        flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
        # TODO: support regex as well
        propagate_qconfig_(model, flattened_qconfig_dict)
        if model.training:
            self._qat_swap_modules(model)

        self.modules = dict(model.named_modules())

        convert_dict_to_ordered_dict(qconfig_dict)
        # map from node name to qconfig, used in _find_matches
        self._generate_qconfig_map(model, model.graph, qconfig_dict)

        # match the patterns that will get quantized
        standalone_module_names = qconfig_dict.get('standalone_module_name',
                                                   None)
        matches = self._find_matches(model.graph, self.modules, self.patterns,
                                     standalone_module_names)

        # find _inputs_ to matched nodes that are not quantized. These
        # have to be quantized, which requires measuring stats;
        # initialize a DefaultQuant object for each
        quants = self._find_quants(model.graph, matches)

        self.activation_post_process_map = dict()
        env = {}
        observed_graph = Graph()
        observed_node_names_set = set()
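        # `env` maps node names in the original graph to the corresponding
        # nodes in `observed_graph` as it is rebuilt; `observed_node_names_set`
        # records values whose outputs are already covered by an observer.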

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        # indexes for the inputs that need to be observed
        standalone_module_observed_input_idxs = []
        graph_inputs = []
        for node in model.graph.nodes:
            if node.op == 'placeholder':
                graph_inputs.append(node.name)

        get_new_observer_name = get_new_attr_name_with_prefix(
            'activation_post_process_')

        for node in model.graph.nodes:
            if node.name in observed_node_names_set:
                continue

            prefix = node.name + '_activation_post_process_'
            root_node, _, obj, qconfig = matches.get(node.name,
                                                     (None, None, None, None))
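            # Three cases: the node matched no pattern (plain copy), it is
            # the root of a matched pattern (copy it and insert observers),
            # or it is a non-root node of a pattern (plain copy; observers
            # are handled at the root).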
            if root_node is None:
                env[node.name] = observed_graph.node_copy(node, load_arg)
            elif root_node is node:
                env[node.name] = observed_graph.node_copy(node, load_arg)
                if qconfig is None:
                    continue

                def insert_observer(node, observer, device):
                    get_new_observer_name = get_new_attr_name_with_prefix(
                        prefix)
                    observer_name = get_new_observer_name(model)
                    setattr(model, observer_name, observer)
                    self.activation_post_process_map[node.name] = observer
                    env[node.name] = observed_graph.create_node(
                        'call_module', observer_name, (load_arg(node), ), {})
                    observed_node_names_set.add(node.name)
                    if device:
                        getattr(model, observer_name).to(device)

                if isinstance(obj, CustomModuleQuantizeHandler):
                    custom_module = self.modules[node.target]
                    observed_custom_module_class = \
                        get_observed_custom_module_class(type(custom_module))
                    observed_custom_module = \
                        observed_custom_module_class.from_float(custom_module)
                    mark_observed_custom_module(observed_custom_module,
                                                type(custom_module))
                    parent_name, name = _parent_name(node.target)
                    setattr(self.modules[parent_name], name,
                            observed_custom_module)

                # indexes for inputs of the standalone module that need to be
                # observed in the parent
                standalone_module_input_idxs = None
                if isinstance(obj, StandaloneModuleQuantizeHandler):
                    # observe standalone module
                    standalone_module = self.modules[node.target]
                    traced_standalone_module = symbolic_trace(
                        standalone_module)
                    if self.is_dynamic_quant:
                        prepare = torch.quantization.quantize_fx._prepare_dynamic_standalone_module_fx
                    else:
                        prepare = torch.quantization.quantize_fx._prepare_standalone_module_fx
                    observed_standalone_module = prepare(
                        traced_standalone_module, {'': qconfig})
                    observed_standalone_module.qconfig = qconfig
                    standalone_module_input_idxs = observed_standalone_module._standalone_module_observed_input_idxs
                    observed_standalone_module = mark_observed_standalone_module(
                        observed_standalone_module)
                    parent_name, name = _parent_name(node.target)
                    setattr(self.modules[parent_name], name,
                            observed_standalone_module)
                    self.modules[node.target] = observed_standalone_module

                # don't need to insert observer for output in dynamic quantization
                if self.is_dynamic_quant:
                    continue

                # insert observers for the output of the observed module, or
                # mark the output as observed
                if isinstance(obj, CopyNode):
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'

                    def is_observed(input_arg):
                        if isinstance(input_arg, Node):
                            return input_arg.name in observed_node_names_set
                        elif isinstance(input_arg, list):
                            return all(map(is_observed, input_arg))

                    # propagate observed property from input
                    if is_observed(node.args[0]):
                        observed_node_names_set.add(node.name)
                elif (isinstance(obj, Add)
                      or isinstance(obj, Mul)) and not obj.all_nodes:
                    if node.args[0].name in observed_node_names_set:
                        observed_node_names_set.add(node.name)
                elif isinstance(obj, StandaloneModuleQuantizeHandler):
                    assert node.op == 'call_module'
                    output_is_observed = self.modules[
                        node.target]._output_is_observed
                    if output_is_observed:
                        observed_node_names_set.add(node.name)
                elif qconfig is not None and obj.all_nodes:
                    # observer for outputs
                    new_observer = qconfig.activation()
                    # respect device affinity when adding observers
                    device = assert_and_get_unique_device(model)
                    insert_observer(node, new_observer, device)

                # insert observer for input of standalone module
                if standalone_module_input_idxs is not None:
                    for idx in standalone_module_input_idxs:
                        if node.args[idx].name not in observed_node_names_set:
                            new_observer = qconfig.activation()
                            device = assert_and_get_unique_device(model)
                            insert_observer(node.args[idx], new_observer,
                                            device)
            else:
                env[node.name] = observed_graph.node_copy(node, load_arg)

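            # Independently of the pattern matching above, a value recorded in
            # `quants` feeds a quantized op and still needs an observer here,
            # unless one was already inserted for it.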
            if node.name not in observed_node_names_set and node.name in quants:
                if is_standalone_module and node.name in graph_inputs:
                    # we'll insert observer for input of standalone module
                    # in parent graph
                    standalone_module_observed_input_idxs.append(
                        graph_inputs.index(node.name))
                    continue
                get_new_observer_name = get_new_attr_name_with_prefix(prefix)
                observer_name = get_new_observer_name(model)
                _, qconfig, is_weight = quants[node.name]
                if qconfig is not None:
                    # TODO: use insert_observer
                    new_observer = \
                        qconfig.weight() if is_weight else qconfig.activation()
                    # respect device affinity when adding observers
                    device = assert_and_get_unique_device(model)
                    if device:
                        new_observer.to(device)
                    self.activation_post_process_map[node.name] = new_observer
                    setattr(model, observer_name,
                            self.activation_post_process_map[node.name])
                    env[node.name] = observed_graph.create_node(
                        'call_module', observer_name, (load_arg(node), ), {})
                    observed_node_names_set.add(node.name)

        observed_graph.output(load_arg(model.graph.result))
        model = GraphModule(model, observed_graph)
        self.save_state(model)
        if is_standalone_module:
            assert isinstance(model.graph.result, Node), \
                'standalone module returning dict is not yet supported'
            # indicator for whether the output is observed or not.
            # This is used to correctly quantize standalone modules
            output_is_observed = model.graph.result.name in observed_node_names_set
            model._standalone_module_observed_input_idxs = standalone_module_observed_input_idxs
            model._output_is_observed = output_is_observed
        return model
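
"""
End to end, `_prepare` is reached through the public FX graph-mode
quantization entry points. A minimal sketch of that flow (the model and
calibration data are placeholders; depending on the torch version,
`prepare_fx` may also require an `example_inputs` argument):
"""
import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

float_model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).eval()
qconfig_dict = {'': get_default_qconfig('fbgemm')}

prepared = prepare_fx(float_model, qconfig_dict)  # inserts observers, as in _prepare above
prepared(torch.randn(2, 4))                       # calibration forward pass
quantized = convert_fx(prepared)                  # swap observed ops for quantized ones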