def deepcopy_graph(gm: GraphModule) -> GraphModule:
    """
    Performs a deepcopy of the GraphModule while also copying the relevant attributes to know whether the model was
    traced with dynamic axes, and what were the values if that is the case.

    Everything in ``gm.__dict__`` except the graph is deep-copied; the graph itself is copied with
    ``Graph.graph_copy`` so that the old-node -> new-node mapping is available for re-keying the
    ``dynamic2static`` / ``static2dynamic`` attributes. The graph is always re-attached to ``gm``, even if the
    deep copy raises.

    Args:
        gm: The GraphModule (or instance of a GraphModule subclass) to copy.

    Returns:
        A new instance of ``gm.__class__`` built from the copied state and the copied graph.
    """
    # First, create a copy of the module without the graph. The graph is detached temporarily so that the
    # deepcopy does not include it; the finally clause re-attaches it to ``gm`` even if the deepcopy raises,
    # instead of leaving ``gm`` without a graph.
    fake_mod = torch.nn.Module()
    graph = gm.__dict__.pop("_graph")
    try:
        fake_mod.__dict__ = copy.deepcopy(gm.__dict__)
    finally:
        gm.__dict__["_graph"] = graph

    # Then, copy the graph. ``val_map`` is filled with the mapping from nodes of ``graph`` to their copies.
    val_map = {}
    graph_clone = Graph()
    output_val = graph_clone.graph_copy(graph, val_map=val_map)
    graph_clone.output(output_val)

    # Finally create a new GraphModule (or a subclass of GraphModule) from the module and the graph copies.
    # gm.__class__ is used to take into account that gm can be an instance of a subclass of GraphModule.
    clone = gm.__class__(fake_mod, graph_clone)

    # Restore the dynamic axes related attributes to the clone. Keys of ``dynamic2static`` found in ``val_map``
    # (nodes of the original graph) are replaced by their copies, other keys are kept as-is, and
    # ``static2dynamic`` is rebuilt as the inverse mapping.
    attributes = _cache_attributes(gm)
    attributes["dynamic2static"] = {val_map.get(k, k): v for k, v in attributes["dynamic2static"].items()}
    attributes["static2dynamic"] = {v: k for k, v in attributes["dynamic2static"].items()}
    _restore_attributes_(clone, attributes)

    return clone
예제 #2
0
    def quantize(self):
        """Build a quantized copy of ``self.graph`` and return it as a new ``GraphModule``.

        Nodes that are not part of any match are copied verbatim. For each match, the match object's
        ``quantize`` is called on the match's root node with a loader that yields quantized inputs. If it
        returns ``NotImplemented``, the root node is copied over unquantized, together with any of its inputs
        that have not been emitted yet (the interior nodes of the match, which the main loop skips).

        Side effects: sets ``self.quantized_graph`` and ``self.delegate``.
        """
        self.quantized_graph = Graph()
        self.delegate = DelegateBase(self.quantized_graph)

        # Node name -> value in the new graph, in float form (env) and quantized form (quant_env).
        env = {}
        quant_env = {}

        def load_arg(n, quantized):
            # Fetch ``n`` in the requested form, converting lazily between the two environments.
            if not quantized:
                if n.name not in env and n.name in quant_env:
                    env[n.name] = Proxy(quant_env[n.name]).dequantize().node
                return env[n.name]
            else:
                if n.name not in quant_env and n.name in env:
                    quant_env[n.name] = self.quants[n.name].quantize(
                        env[n.name])
                return quant_env[n.name]

        def copy_recursive(node):
            # Copy ``node`` unquantized, first emitting any inputs that are in neither environment.
            def load_or_emit(n):
                if n.name in env or n.name in quant_env:
                    return load_arg(n, quantized=False)
                return copy_recursive(n)

            r = env[node.name] = self.quantized_graph.node_copy(node, load_or_emit)
            return r

        for node in self.graph.nodes:
            root_node, obj = self.matches.get(node.name, (None, None))
            if root_node is None:
                # not quantized just copy it
                env[node.name] = self.quantized_graph.node_copy(
                    node, lambda n: load_arg(n, quantized=False))

            elif root_node is node:
                r = obj.quantize(
                    self, node, lambda a: map_arg(
                        a, lambda n: load_arg(n, quantized=True)))
                if r is NotImplemented:
                    # the quantizer declined to quantize this match: copy the whole match over unquantized
                    env[node.name] = copy_recursive(node)
                else:
                    quant_env[node.name] = r
            # else: interior node of a match; emitted by its root (or by copy_recursive) instead

        self.quantized_graph.output(
            load_arg(self.graph.result, quantized=False))
        return GraphModule(self.root, self.quantized_graph)
예제 #3
0
    def test_package_fx_with_imports(self):
        """A hand-built GraphModule that calls a function from another module survives a
        torch.package round trip, and the loaded copy resolves that module from the package."""
        import package_a.subpackage

        # Manually construct a graph that invokes a leaf function
        graph = Graph()
        a = graph.placeholder("x")
        b = graph.placeholder("y")
        c = graph.call_function(package_a.subpackage.leaf_function, (a, b))
        d = graph.call_function(torch.sin, (c, ))
        graph.output(d)
        gm = GraphModule(torch.nn.Module(), graph)

        f = BytesIO()
        with PackageExporter(f) as pe:
            # intern every module matched by the "**" glob
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", gm)
        f.seek(0)

        pi = PackageImporter(f)
        loaded_gm = pi.load_pickle("model", "model.pkl")
        input_x = torch.rand(2, 3)
        input_y = torch.rand(2, 3)

        self.assertTrue(
            torch.allclose(loaded_gm(input_x, input_y), gm(input_x, input_y)))

        # Check that the packaged version of the leaf_function dependency is
        # not the same as in the outer env.
        packaged_dependency = pi.import_module("package_a.subpackage")
        self.assertIsNot(packaged_dependency, package_a.subpackage)
예제 #4
0
 def test_graph_fns(self):
     """Build a graph by hand with placeholder / call_module / get_param / call_method /
     call_function / output and check the GraphModule matches eager execution."""
     g = Graph()
     a = g.placeholder('a')
     b = g.call_module('linear', (a, ))
     c = g.get_param('bias')
     d = g.call_method('add', (b, c))
     e = g.call_function(torch.sin, (d, ))
     g.output(e)
     mod = torch.nn.Module()
     mod.linear = torch.nn.Linear(3, 4)
     mod.bias = torch.rand(4)
     gm = GraphModule(mod, g)
     # ``inp`` rather than ``input`` to avoid shadowing the builtin
     inp = torch.rand(3)
     r = gm(inp)
     ref = torch.sin(mod.linear(inp) + mod.bias)
     self.assertEqual(r, ref)
예제 #5
0
파일: test_fx.py 프로젝트: leonvol/pytorch
    def test_remove_uses(self):
        """After redirecting a node's uses and erasing it, it no longer appears in its input's users."""
        g: torch.fx.Graph = Graph()
        x: torch.fx.Node = g.placeholder('x')
        relu: torch.fx.Node = g.call_function(torch.relu, (x, ))
        neg: torch.fx.Node = g.call_function(torch.neg, (relu, ))
        g.output(neg)

        # Point neg's users (the output node) at relu so neg becomes unused, then erase it.
        neg.replace_all_uses_with(relu)
        g.erase_node(neg)

        self.assertNotIn(neg, relu.users)
By the end of the tutorial, we'll have added the following method to an
empty ``nn.Module`` class.

.. code-block:: python

    def forward(self, x, y):
        cat_1 = torch.cat([x, y]);  x = y = None
        tanh_1 = torch.tanh(cat_1);  cat_1 = None
        neg_1 = torch.neg(tanh_1);  tanh_1 = None
        return neg_1

'''

# Create a graph independently of symbolic tracing
graph = Graph()
# One tracer shared by both Proxies below; torch operations on those Proxies
# record new Nodes in ``graph``.
tracer = torch.fx.proxy.GraphAppendingTracer(graph)

# Create raw Nodes: the inputs ``x`` and ``y`` of the generated ``forward``
raw1 = graph.placeholder('x')
raw2 = graph.placeholder('y')

# Initialize Proxies using the raw Nodes and the GraphAppendingTracer created above
y = Proxy(raw1, tracer)
z = Proxy(raw2, tracer)
# Alternative without passing a tracer explicitly.
# NOTE(review): each Proxy presumably then builds its own default tracer for
# the node's graph — confirm before relying on it:
# y = Proxy(raw1)
# z = Proxy(raw2)

# Create other operations using the Proxies `y` and `z`; each call records a
# new Node in ``graph`` (cat, then tanh, as in the generated ``forward``).
a = torch.cat([y, z])
b = torch.tanh(a)
예제 #7
0
        def trace(self, root: st.Union[torch.nn.Module, st.Callable[..., Any]],
                  concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
            """Symbolically trace ``root`` and return the resulting ``Graph``.

            The tracing helpers come from the ``st`` module alias. Besides patching
            ``torch.nn.Module.__getattr__`` / ``__call__`` for the duration of the
            trace, ``Aggregation.__call__`` is patched with the same call wrapper,
            so calls on ``Aggregation`` objects are also routed through
            ``self.call_module``.

            Args:
                root: an ``nn.Module`` (the class's ``forward`` is traced) or a
                    plain Python function.
                concrete_args: optional mapping passed on to
                    ``self.create_args_for_root``.

            Returns:
                ``self.graph``, terminated by an ``output`` node holding the
                traced return value.

            Side effects: sets ``self.root``, ``self.graph`` and
            ``self.tensor_attrs``. For ``nn.Module`` roots, ``self.submodule_paths``
            is populated during tracing; it is reset to ``None`` afterwards.
            NOTE(review): that reset is not in a ``finally`` block, so it is
            skipped if tracing raises.
            """

            if isinstance(root, torch.nn.Module):
                self.root = root
                # Unbound ``forward`` taken from the class, not the instance.
                fn = type(root).forward
                # Reverse lookup: submodule object -> qualified name.
                self.submodule_paths = {
                    mod: name
                    for name, mod in root.named_modules()
                }
            else:
                # Plain function: trace it against an empty root module.
                self.root = torch.nn.Module()
                fn = root

            tracer_cls: Optional[st.Type['Tracer']] = getattr(
                self, '__class__', None)
            self.graph = Graph(tracer_cls=tracer_cls)

            # Tensor / ScriptObject attribute object -> dotted attribute path from root.
            self.tensor_attrs: Dict[st.Union[torch.Tensor, st.ScriptObject],
                                    str] = {}

            def collect_tensor_attrs(m: torch.nn.Module,
                                     prefix_atoms: st.List[str]):
                # Record tensors / ScriptObjects held directly in m.__dict__, then recurse into children.
                for k, v in m.__dict__.items():
                    if isinstance(v, (torch.Tensor, st.ScriptObject)):
                        self.tensor_attrs[v] = '.'.join(prefix_atoms + [k])
                for k, v in m.named_children():
                    collect_tensor_attrs(v, prefix_atoms + [k])

            collect_tensor_attrs(self.root, [])

            # Only plain Python functions can be traced.
            assert isinstance(fn, st.FunctionType)

            fn_globals = fn.__globals__  # run before it gets patched
            fn, args = self.create_args_for_root(
                fn, isinstance(root, torch.nn.Module), concrete_args)

            parameter_proxy_cache: Dict[str, st.Proxy] = {
            }  # Reduce number of get_attr calls

            # Replacement for nn.Module.__getattr__: look the attribute up normally, then let
            # self._module_getattr post-process the value (sharing parameter_proxy_cache).
            @st.functools.wraps(st._orig_module_getattr)
            def module_getattr_wrapper(mod, attr):
                attr_val = st._orig_module_getattr(mod, attr)
                return self._module_getattr(attr, attr_val,
                                            parameter_proxy_cache)

            # Replacement for Module.__call__ (and Aggregation.__call__): routes the call through
            # self.call_module, passing a ``forward`` that invokes the original, unpatched __call__.
            @st.functools.wraps(st._orig_module_call)
            def module_call_wrapper(mod, *args, **kwargs):
                def forward(*args, **kwargs):
                    return st._orig_module_call(mod, *args, **kwargs)

                # ``patcher`` is bound by the ``with`` statement below; this wrapper is only
                # installed (and therefore only called) while that patcher is active.
                st._autowrap_check(
                    patcher,
                    getattr(getattr(mod, "forward", mod), "__globals__", {}),
                    self._autowrap_function_ids)
                return self.call_module(mod, forward, args, kwargs)

            with st._Patcher() as patcher:
                # allow duplicate patches to support the case of nested calls
                patcher.patch_method(torch.nn.Module, "__getattr__",
                                     module_getattr_wrapper, deduplicate=False)
                patcher.patch_method(torch.nn.Module, "__call__",
                                     module_call_wrapper, deduplicate=False)
                patcher.patch_method(Aggregation, "__call__",
                                     module_call_wrapper, deduplicate=False)
                st._patch_wrapped_functions(patcher)
                st._autowrap_check(patcher, fn_globals,
                                   self._autowrap_function_ids)
                for module in self._autowrap_search:
                    st._autowrap_check(patcher, module.__dict__,
                                       self._autowrap_function_ids)
                # Run the function on the proxy args and record its result as the graph's output.
                self.create_node(
                    'output', 'output', (self.create_arg(fn(*args)), ), {},
                    type_expr=fn.__annotations__.get('return', None))

            self.submodule_paths = None

            return self.graph
예제 #8
0
class Quantizer:
    """Pattern-based quantizer for a traced module.

    On construction, ``patterns`` are matched against ``mod.graph``, and one ``quant_ctor`` object is
    created for every node used as an input of a matched node. ``observe`` runs the float graph so that
    match objects and quant objects can record statistics; ``quantize`` builds the quantized
    ``GraphModule``.
    """

    def __init__(self,
                 mod,
                 patterns=DEFAULT_QUANTIZATION_PATTERNS,
                 quant_ctor=DefaultQuant):
        """
        Args:
            mod: the module to quantize; must expose ``graph``, ``state_dict()`` and ``named_modules()``.
            patterns: mapping pattern -> factory; the factory is called as ``factory(quantizer, root_node)``
                for each match.
            quant_ctor: factory called as ``quant_ctor(quantizer, node)`` for each input of a matched node.
        """
        self.root = mod
        self.graph = mod.graph
        self.quant_ctor = quant_ctor

        # cached information for observe
        self.state_dict = self.root.state_dict()
        self.modules = dict(self.root.named_modules())

        # match the patterns that will get quantized
        self.matches = self._find_matches(patterns)
        # find _inputs_ to matched nodes that are not quantized, these
        # have to be quantized, which requires measuring stats,
        # initialize an quant_ctor object for each
        self.quants = self._find_quants(quant_ctor)

    def observe(self, args):
        """Interpret the float graph on ``args``, letting matches and quant objects observe each node.

        Returns the value of ``self.graph.result`` computed for ``args``.
        """
        # most of this function is just an interpreter for the graph
        # it would be possible to put this in some abstraction, but
        # it is pretty nice to just be able to see exactly what is happening here
        # and hack on it.
        # maybe we should just provide an example interpreter that people copy/paste
        # then edit.
        args_iter = iter(args)
        env = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        for node in self.graph.nodes:
            if node.op == 'placeholder':
                result = next(args_iter)
            elif node.op == 'get_attr':
                result = self.state_dict[node.target]
            elif node.op == 'call_function':
                result = node.target(*load_arg(node.args),
                                     **load_arg(node.kwargs))
            elif node.op == 'call_method':
                # ``method_args`` (not ``args``) so the ``args`` parameter is not shadowed
                self_obj, *method_args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*method_args, **kwargs)
            elif node.op == 'call_module':
                result = self.modules[node.target](*load_arg(node.args),
                                                   **load_arg(node.kwargs))

            env[node.name] = result
            root_node, obj = self.matches.get(node.name, (None, None))
            if root_node is node:
                obj.observe(node, env)
            if node.name in self.quants:
                self.quants[node.name].observe(node, env)

        return load_arg(self.graph.result)

    def quantize(self):
        """Build a quantized copy of ``self.graph`` and return it as a new ``GraphModule``.

        Nodes outside any match are copied verbatim. For each match, the match object's ``quantize`` is
        called on the root node with a loader yielding quantized inputs. If it returns ``NotImplemented``,
        the root node is copied over unquantized, together with any inputs not yet emitted (the interior
        nodes of the match, which the main loop skips).

        Side effect: sets ``self.quantized_graph``.
        """
        self.quantized_graph = Graph()

        # Node name -> value in the new graph, in float form (env) and quantized form (quant_env).
        env = {}
        quant_env = {}

        def load_arg(n, quantized):
            # Fetch ``n`` in the requested form, converting lazily between the two environments.
            if not quantized:
                if n.name not in env and n.name in quant_env:
                    env[n.name] = Proxy(quant_env[n.name]).dequantize().node
                return env[n.name]
            else:
                if n.name not in quant_env and n.name in env:
                    quant_env[n.name] = self.quants[n.name].quantize(
                        env[n.name])
                return quant_env[n.name]

        def copy_recursive(node):
            # Copy ``node`` unquantized, first emitting any inputs that are in neither environment.
            def load_or_emit(n):
                if n.name in env or n.name in quant_env:
                    return load_arg(n, quantized=False)
                return copy_recursive(n)

            r = env[node.name] = self.quantized_graph.node_copy(node, load_or_emit)
            return r

        for node in self.graph.nodes:
            root_node, obj = self.matches.get(node.name, (None, None))
            if root_node is None:
                # not quantized just copy it
                env[node.name] = self.quantized_graph.node_copy(
                    node, lambda n: load_arg(n, quantized=False))

            elif root_node is node:
                r = obj.quantize(
                    self, node, lambda a: map_arg(
                        a, lambda n: load_arg(n, quantized=True)))
                if r is NotImplemented:
                    # the quantizer declined to quantize this match: copy the whole match over unquantized
                    env[node.name] = copy_recursive(node)
                else:
                    quant_env[node.name] = r
            # else: interior node of a match; emitted by its root (or by copy_recursive) instead

        self.quantized_graph.output(
            load_arg(self.graph.result, quantized=False))
        return GraphModule(self.root, self.quantized_graph)

    def _find_matches(self, patterns):
        """Return ``{node name: (root_node, match_object)}`` for every node covered by a pattern match.

        Nodes are visited in reverse graph order; a node already covered by a match found earlier is not
        tried as a new root. NOTE(review): if several patterns match the same root, all are applied in
        ``patterns`` order and later ones overwrite earlier entries.
        """
        modules = dict(self.root.named_modules())
        match_map = {}  # node name -> (root_node, match_value?)

        def apply_match(pattern, node, match):
            # A tuple pattern is (head, *arg_patterns): head covers ``node`` itself and each arg
            # pattern is applied positionally to ``node.args``; anything else is a leaf.
            if isinstance(pattern, tuple):
                s, *args = pattern
                apply_match(s, node, match)
                for subpattern, arg in zip(args, node.args):
                    apply_match(subpattern, arg, match)
            else:
                match_map[node.name] = match

        for node in reversed(self.graph.nodes):
            if node.name not in match_map:
                for pattern, value in patterns.items():
                    if matches(modules, node, pattern):
                        apply_match(pattern, node, (node, value(self, node)))

        return match_map

    def _find_quants(self, quant_ctor):
        """Return ``{node name: quant_ctor(self, node)}`` for every node used as an input of a matched node."""
        quants = {}

        def visit_arg(n):
            # note: we have to measure quantization information
            # even for nodes where we might not use it because it is already
            # quantized. This is because each match has the option to
            # say NotImplemented (if for instance, it is an __add__ and the data type is not appropriate)
            if n.name not in quants:
                quants[n.name] = quant_ctor(self, n)

        for node in self.graph.nodes:
            if node.name in self.matches:
                map_arg(node.args, visit_arg)
                map_arg(node.kwargs, visit_arg)
        return quants
예제 #9
0
    def call(self, graph_module: GraphModule) -> PassResult:
        """
        Return a new copy of torch.fx.GraphModule with CSE applied to the input graph

        Nodes are visited in graph order. ``placeholder``, ``output`` and ``get_attr`` nodes, and nodes
        whose target (or the target's ``overloadpacket``, when it has one) is in ``self.banned_ops``, are
        copied unchanged. Any other node is dropped in favour of the first previously copied node with the
        same target and the same (already substituted) args/kwargs and their nesting structure.

        Example usage:

        from torch.fx.experimental.proxy_tensor import make_fx
        def f(a):
            b = a * a
            c = a * a
            return b+c

        p = CSEPass()
        traced_graph = make_fx(f)(torch.tensor(1))
        print(traced_graph)
        result = p(traced_graph)
        print(result.graph_module)

        Returns:
            ``PassResult(csed_gm, modified)`` where ``modified`` is True iff at least one node was eliminated.
        """
        def get_aten_target(node):
            # Used only for the banned_ops check: compare OpOverloads by their overload packet.
            if hasattr(node.target, 'overloadpacket'):
                return node.target.overloadpacket
            return node.target

        modified = False
        new_graph = Graph()
        env: Dict[Node, Node] = {}  # map from node in the old graph to node in the new graph
        hash_env: Dict[Tuple[torch._ops.OpOverload, int],
                       Node] = {}  # map from hash to a node in the new graph
        token_map: Dict[Tuple[torch._ops.OpOverload, int],
                        Dict[str, Any]] = {}  # map from hash to token

        def substitute(arg_list):
            # Flatten (possibly nested) args/kwargs and replace every old-graph Node by its copy in the
            # new graph; the returned spec records the nesting so it can be part of the token.
            flat, spec = tree_flatten(arg_list)
            return tuple(env.get(v, v) if isinstance(v, Node) else v for v in flat), spec

        for n in graph_module.graph.nodes:
            # The placeholder, output, and get_attr nodes are copied to the new graph without change
            # do not CSE away random operations
            if n.op in ('placeholder', 'output', 'get_attr') or get_aten_target(n) in self.banned_ops:
                new_node = new_graph.node_copy(n, lambda x: env[x])
                env[n] = new_node
                continue

            # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method'
            # substitute args and kwargs members to their mapping in env if exists
            # specs can be used to reconstruct nested list/dictionaries
            args, args_spec = substitute(n.args)
            kwargs, kwargs_spec = substitute(n.kwargs)

            # each token corresponds to a unique node
            # nodes with the same token can be substituted
            token = {
                "target": n.target,
                "args": args,
                "args_spec": args_spec,
                "kwargs": kwargs,
                "kwargs_spec": kwargs_spec
            }

            # hash substituted args to a number, do not hash specs because specs are not hashable
            # NOTE(review): raises TypeError if a flattened leaf is itself unhashable
            hash_arg = hash((args, kwargs))
            hash_val = (n.target, hash_arg)

            # check if a node has a substitute and can be eliminated; the token comparison guards
            # against hash collisions
            hash_val_in_hash_env = hash_val in hash_env
            if hash_val_in_hash_env and token_map[hash_val] == token:
                modified = True  # substitution happens and the graph is modified
                env[n] = hash_env[hash_val]
                continue

            new_node = new_graph.node_copy(n, lambda x: env[x])
            env[n] = new_node
            if not hash_val_in_hash_env:
                hash_env[hash_val] = new_node
                token_map[hash_val] = token

        csed_gm = GraphModule(graph_module, new_graph)
        return PassResult(csed_gm, modified)