Code example #1
    def test_package_fx_with_imports(self):
        import package_a.subpackage

        # Manually construct a graph that invokes a leaf function
        graph = Graph()
        a = graph.placeholder("x")
        b = graph.placeholder("y")
        c = graph.call_function(package_a.subpackage.leaf_function, (a, b))
        d = graph.call_function(torch.sin, (c, ))
        graph.output(d)
        gm = GraphModule(torch.nn.Module(), graph)

        f = BytesIO()
        with PackageExporter(f) as pe:
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", gm)
        f.seek(0)

        pi = PackageImporter(f)
        loaded_gm = pi.load_pickle("model", "model.pkl")
        input_x = torch.rand(2, 3)
        input_y = torch.rand(2, 3)

        self.assertTrue(
            torch.allclose(loaded_gm(input_x, input_y), gm(input_x, input_y)))

        # Check that the packaged version of the leaf_function dependency is
        # not the same as in the outer env.
        packaged_dependency = pi.import_module("package_a.subpackage")
        self.assertTrue(packaged_dependency is not package_a.subpackage)
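
For reference, the imports this test depends on are not part of the snippet; a plausible set (an assumption, since the surrounding test file is not shown) would be:

import torch
from io import BytesIO
from torch.fx import Graph, GraphModule
from torch.package import PackageExporter, PackageImporter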
Code example #2
def deepcopy_graph(gm: GraphModule) -> GraphModule:
    """
    Performs a deepcopy of the GraphModule while also copying the relevant attributes to know whether the model was
    traced with dynamic axes, and what were the values if that is the case.
    """

    # First, create a copy of the module without the graph.
    graph = gm.__dict__.pop("_graph")
    fake_mod = torch.nn.Module()
    fake_mod.__dict__ = copy.deepcopy(gm.__dict__)
    gm.__dict__["_graph"] = graph

    # Then, copy the graph.
    val_map = {}
    graph_clone = Graph()
    output_val = graph_clone.graph_copy(graph, val_map=val_map)
    graph_clone.output(output_val)

    # Finally create a new GraphModule (or a subclass of GraphModule) from the module and the graph copies.
    # gm.__class__ is used to take into account that gm can be an instance of a subclass of GraphModule.
    clone = gm.__class__(fake_mod, graph_clone)

    # Restore the dynamic axes related attributes to the clone.
    attributes = _cache_attributes(gm)
    attributes["dynamic2static"] = {val_map.get(k, k): v for k, v in attributes["dynamic2static"].items()}
    attributes["static2dynamic"] = {v: k for k, v in attributes["dynamic2static"].items()}
    _restore_attributes_(clone, attributes)

    return clone
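
A minimal usage sketch, assuming the module-private helpers `_cache_attributes` and `_restore_attributes_` referenced above are in scope (they are not shown here):

import torch
from torch.fx import symbolic_trace

gm = symbolic_trace(torch.nn.Linear(3, 3))  # hypothetical toy model
clone = deepcopy_graph(gm)
assert clone is not gm and clone.graph is not gm.graph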
Code example #3
File: test_fx.py Project: leonvol/pytorch
    def test_remove_uses(self):
        g: torch.fx.Graph = Graph()
        x: torch.fx.Node = g.placeholder('x')
        relu: torch.fx.Node = g.call_function(torch.relu, (x, ))
        neg: torch.fx.Node = g.call_function(torch.neg, (relu, ))
        g.output(neg)

        neg.replace_all_uses_with(relu)
        g.erase_node(neg)

        self.assertTrue(neg not in relu.users)
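
After a replace-then-erase edit like this, torch.fx's built-in `Graph.lint` is a cheap well-formedness check (our addition, not part of the original test):

g.lint()  # raises if any node still references the erased `neg`
print(g)  # the printed graph no longer contains `neg`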
Code example #4
File: quantization.py Project: zxin1023/pytorch
    def quantize(self):
        self.quantized_graph = Graph()
        self.delegate = DelegateBase(self.quantized_graph)

        env = {}
        quant_env = {}

        def load_arg(n, quantized):
            if not quantized:
                if n.name not in env and n.name in quant_env:
                    env[n.name] = Proxy(quant_env[n.name]).dequantize().node
                return env[n.name]
            else:
                if n.name not in quant_env and n.name in env:
                    quant_env[n.name] = self.quants[n.name].quantize(
                        env[n.name])
                return quant_env[n.name]

        def copy_recursive(node):
            def load_or_emit(n):
                if n.name in env or n.name in quant_env:
                    return load_arg(n, quantized=False)
                else:
                    return copy_recursive(n)

            r = env[node.name] = self.quantized_graph.node_copy(
                node, lambda n: load_arg(n, quantized=False))
            return r

        for node in self.graph.nodes:
            root_node, obj = self.matches.get(node.name, (None, None))
            if root_node is None:
                # not quantized, just copy it
                env[node.name] = self.quantized_graph.node_copy(
                    node, lambda n: load_arg(n, quantized=False))

            elif root_node is node:
                r = obj.quantize(
                    self, node, lambda a: map_arg(
                        a, lambda n: load_arg(n, quantized=True)))
                if r is NotImplemented:
                    # quantizer chose not to quantize the node; take the entire match and just copy it over
                    env[node.name] = copy_recursive(node)
                else:
                    quant_env[node.name] = r

        self.quantized_graph.output(
            load_arg(self.graph.result, quantized=False))
        return GraphModule(self.root, self.quantized_graph)
Code example #5
File: test_fx.py Project: qianjia1996/pytorch
    def test_graph_fns(self):
        g = Graph()
        a = g.placeholder('a')
        b = g.call_module('linear', (a, ))
        c = g.get_param('bias')
        d = g.call_method('add', (b, c))
        e = g.call_function(torch.sin, (d, ))
        g.output(e)
        mod = torch.nn.Module()
        mod.linear = torch.nn.Linear(3, 4)
        mod.bias = torch.rand(4)
        gm = GraphModule(mod, g)
        input = torch.rand(3)
        r = gm(input)
        ref = torch.sin(mod.linear(input) + mod.bias)
        self.assertEqual(r, ref)
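
Note that `g.get_param('bias')` reflects an early torch.fx prototype; on current releases the same node is created with `get_attr` (our adaptation, worth hedging since this snippet targets an old commit):

c = g.get_attr('bias')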
Code example #6
By the end of the tutorial, we'll have added the following method to an
empty ``nn.Module`` class.

.. code-block:: python

    def forward(self, x, y):
        cat_1 = torch.cat([x, y]);  x = y = None
        tanh_1 = torch.tanh(cat_1);  cat_1 = None
        neg_1 = torch.neg(tanh_1);  tanh_1 = None
        return neg_1

'''

# Create a graph independently of symbolic tracing
graph = Graph()
tracer = torch.fx.proxy.GraphAppendingTracer(graph)

# Create raw Nodes
raw1 = graph.placeholder('x')
raw2 = graph.placeholder('y')

# Initialize Proxies using the raw Nodes and the graph-appending tracer
y = Proxy(raw1, tracer)
z = Proxy(raw2, tracer)

# Create other operations using the Proxies `y` and `z`
a = torch.cat([y, z])
b = torch.tanh(a)
c = torch.neg(b)

# Register the final value as the Graph's output; this pulls in all the
# Nodes created above, since they are all linked through to the output
graph.output(c.node)
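
To make the hand-built graph runnable, wrap it in a GraphModule as the other examples here do; the equivalence check against the docstring's `forward` is our addition:

gm = GraphModule(torch.nn.Module(), graph)
x, y = torch.rand(3), torch.rand(3)
assert torch.equal(gm(x, y), torch.neg(torch.tanh(torch.cat([x, y]))))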
Code example #7
File: fx.py Project: rusty1s/pytorch_geometric
        def trace(self, root: st.Union[torch.nn.Module, st.Callable[..., Any]],
                  concrete_args: Optional[Dict[str, Any]] = None) -> Graph:

            if isinstance(root, torch.nn.Module):
                self.root = root
                fn = type(root).forward
                self.submodule_paths = {
                    mod: name
                    for name, mod in root.named_modules()
                }
            else:
                self.root = torch.nn.Module()
                fn = root

            tracer_cls: Optional[st.Type['Tracer']] = getattr(
                self, '__class__', None)
            self.graph = Graph(tracer_cls=tracer_cls)

            self.tensor_attrs: Dict[st.Union[torch.Tensor, st.ScriptObject],
                                    str] = {}

            def collect_tensor_attrs(m: torch.nn.Module,
                                     prefix_atoms: st.List[str]):
                for k, v in m.__dict__.items():
                    if isinstance(v, (torch.Tensor, st.ScriptObject)):
                        self.tensor_attrs[v] = '.'.join(prefix_atoms + [k])
                for k, v in m.named_children():
                    collect_tensor_attrs(v, prefix_atoms + [k])

            collect_tensor_attrs(self.root, [])

            assert isinstance(fn, st.FunctionType)

            fn_globals = fn.__globals__  # run before it gets patched
            fn, args = self.create_args_for_root(
                fn, isinstance(root, torch.nn.Module), concrete_args)

            # Cache parameter proxies to reduce the number of get_attr calls
            parameter_proxy_cache: Dict[str, st.Proxy] = {}

            @st.functools.wraps(st._orig_module_getattr)
            def module_getattr_wrapper(mod, attr):
                attr_val = st._orig_module_getattr(mod, attr)
                return self._module_getattr(attr, attr_val,
                                            parameter_proxy_cache)

            @st.functools.wraps(st._orig_module_call)
            def module_call_wrapper(mod, *args, **kwargs):
                def forward(*args, **kwargs):
                    return st._orig_module_call(mod, *args, **kwargs)

                st._autowrap_check(
                    patcher,
                    getattr(getattr(mod, "forward", mod), "__globals__", {}),
                    self._autowrap_function_ids)
                return self.call_module(mod, forward, args, kwargs)

            with st._Patcher() as patcher:
                # allow duplicate patches to support the case of nested calls
                patcher.patch_method(torch.nn.Module, "__getattr__",
                                     module_getattr_wrapper, deduplicate=False)
                patcher.patch_method(torch.nn.Module, "__call__",
                                     module_call_wrapper, deduplicate=False)
                patcher.patch_method(Aggregation, "__call__",
                                     module_call_wrapper, deduplicate=False)
                st._patch_wrapped_functions(patcher)
                st._autowrap_check(patcher, fn_globals,
                                   self._autowrap_function_ids)
                for module in self._autowrap_search:
                    st._autowrap_check(patcher, module.__dict__,
                                       self._autowrap_function_ids)
                self.create_node(
                    'output', 'output', (self.create_arg(fn(*args)), ), {},
                    type_expr=fn.__annotations__.get('return', None))

            self.submodule_paths = None

            return self.graph
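
A hedged usage sketch: assuming this `trace` override sits on a `Tracer` subclass of `torch.fx.Tracer` (as in torch_geometric's fx.py) and `model` is some hypothetical GNN module, tracing follows the standard torch.fx pattern:

tracer = Tracer()
graph = tracer.trace(model)
gm = GraphModule(tracer.root, graph)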
Code example #8
    def call(self, graph_module: GraphModule) -> PassResult:
        """
        Return a new copy of torch.fx.GraphModule with CSE applied to the input graph

        Example usage:

        from torch.fx.experimental.proxy_tensor import make_fx
        def f(a):
            b = a * a
            c = a * a
            return b + c

        p = CSEPass()
        traced_graph = make_fx(f)(torch.tensor(1))
        print(traced_graph)
        result = p(traced_graph)
        print(result.graph_module)
        """
        def get_aten_target(node):
            if hasattr(node.target, 'overloadpacket'):
                return node.target.overloadpacket
            return node.target

        modified = False
        new_graph = Graph()
        env: Dict[Node, Node] = {}  # map from node in the old graph to node in the new graph
        hash_env: Dict[Tuple[torch._ops.OpOverload, int], Node] = {}  # map from hash to a node in the new graph
        token_map: Dict[Tuple[torch._ops.OpOverload, int], Dict[str, Any]] = {}  # map from hash to token
        for n in graph_module.graph.nodes:
            # Placeholder, output, and get_attr nodes are copied to the new graph unchanged;
            # banned ops (e.g. random operations) are likewise never CSE'd away
            if n.op == 'placeholder' or n.op == 'output' or n.op == 'get_attr' or get_aten_target(
                    n) in self.banned_ops:
                new_node = new_graph.node_copy(n, lambda x: env[x])
                env[n] = new_node
            else:  # n.op == 'call_function'; 'call_module' and 'call_method' should never appear here
                # substitute args and kwargs members with their mappings in env, where present;
                # the specs can be used to reconstruct nested lists/dictionaries
                def substitute(arg_list):
                    arg_list, spec = tree_flatten(arg_list)
                    for i in range(len(arg_list)):
                        v = arg_list[i]
                        if isinstance(v, Node) and v in env:
                            arg_list[i] = env[v]
                    return tuple(arg_list), spec

                args, args_spec = substitute(n.args)
                kwargs, kwargs_spec = substitute(n.kwargs)

                # each token corresponds to a unique node
                # nodes with the same token can be substituted
                token = {
                    "target": n.target,
                    "args": args,
                    "args_spec": args_spec,
                    "kwargs": kwargs,
                    "kwargs_spec": kwargs_spec
                }

                # hash substituted args to a number, do not hash specs because specs are not hashable
                hash_arg = hash((args, kwargs))
                hash_val = (n.target, hash_arg)

                # check if a node has a substitute and can be eliminated
                hash_val_in_hash_env = hash_val in hash_env
                if hash_val_in_hash_env and token_map[hash_val] == token:
                    modified = True  # a substitution happened, so the graph is modified
                    env[n] = hash_env[hash_val]
                    continue

                new_node = new_graph.node_copy(n, lambda x: env[x])
                env[n] = new_node
                if not hash_val_in_hash_env:
                    hash_env[hash_val] = new_node
                    token_map[hash_val] = token

        csed_gm = GraphModule(graph_module, new_graph)
        return PassResult(csed_gm, modified)
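
The returned PassResult (torch.fx.passes.infra.pass_base.PassResult) pairs the rewritten module with a flag recording whether any substitution happened, so callers can skip further work when nothing changed:

result = p(traced_graph)  # continuing the docstring's usage example
if result.modified:
    traced_graph = result.graph_module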