Example #1
    def test_package_fx_with_imports(self):
        import package_a.subpackage

        # Manually construct a graph that invokes a leaf function
        graph = Graph()
        a = graph.placeholder("x")
        b = graph.placeholder("y")
        c = graph.call_function(package_a.subpackage.leaf_function, (a, b))
        d = graph.call_function(torch.sin, (c,))
        graph.output(d)
        gm = GraphModule(torch.nn.Module(), graph)

        f = BytesIO()
        with PackageExporter(f) as pe:
            pe.intern("**")
            pe.save_pickle("model", "model.pkl", gm)
        f.seek(0)

        pi = PackageImporter(f)
        loaded_gm = pi.load_pickle("model", "model.pkl")
        input_x = torch.rand(2, 3)
        input_y = torch.rand(2, 3)

        self.assertTrue(
            torch.allclose(loaded_gm(input_x, input_y), gm(input_x, input_y)))

        # Check that the packaged version of the leaf_function dependency is
        # not the same as in the outer env.
        packaged_dependency = pi.import_module("package_a.subpackage")
        self.assertTrue(packaged_dependency is not package_a.subpackage)
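This test is shown as an excerpt; roughly the following imports are assumed (`package_a.subpackage` is a fixture package from the PyTorch test suite, so that path only exists inside the tests):

import torch
from io import BytesIO
from torch.fx import Graph, GraphModule
from torch.package import PackageExporter, PackageImporter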
Example #2

import copy

import torch
from torch.fx import Graph, GraphModule


def deepcopy_graph(gm: GraphModule) -> GraphModule:
    """
    Performs a deepcopy of the GraphModule while also copying the relevant attributes to know whether the model was
    traced with dynamic axes, and what were the values if that is the case.
    """

    # First, create a copy of the module without the graph.
    graph = gm.__dict__.pop("_graph")
    fake_mod = torch.nn.Module()
    fake_mod.__dict__ = copy.deepcopy(gm.__dict__)
    gm.__dict__["_graph"] = graph

    # Then, copy the graph.
    val_map = {}
    graph_clone = Graph()
    output_val = graph_clone.graph_copy(graph, val_map=val_map)
    graph_clone.output(output_val)

    # Finally create a new GraphModule (or a subclass of GraphModule) from the module and the graph copies.
    # gm.__class__ is used to take into account that gm can be an instance of a subclass of GraphModule.
    clone = gm.__class__(fake_mod, graph_clone)

    # Restore the dynamic axes related attributes to the clone.
    attributes = _cache_attributes(gm)
    attributes["dynamic2static"] = {val_map.get(k, k): v for k, v in attributes["dynamic2static"].items()}
    attributes["static2dynamic"] = {v: k for k, v in attributes["dynamic2static"].items()}
    _restore_attributes_(clone, attributes)

    return clone
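A minimal usage sketch for `deepcopy_graph`, assuming the `_cache_attributes` and `_restore_attributes_` helpers referenced above (defined elsewhere in the same module, not shown here) are in scope:

import torch
from torch.fx import symbolic_trace

model = torch.nn.Linear(4, 4)
gm = symbolic_trace(model)
clone = deepcopy_graph(gm)  # independent Graph, same behavior

x = torch.rand(2, 4)
assert torch.allclose(gm(x), clone(x))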
Example #3
    def test_graph_fns(self):
        g = Graph()
        a = g.placeholder('a')
        b = g.call_module('linear', (a,))
        c = g.get_attr('bias')
        d = g.call_method('add', (b, c))
        e = g.call_function(torch.sin, (d,))
        g.output(e)
        mod = torch.nn.Module()
        mod.linear = torch.nn.Linear(3, 4)
        mod.bias = torch.rand(4)
        gm = GraphModule(mod, g)
        input = torch.rand(3)
        r = gm(input)
        ref = torch.sin(mod.linear(input) + mod.bias)
        self.assertEqual(r, ref)
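For comparison, the same module could be produced by tracing a regular `nn.Module` rather than assembling the Graph by hand; a sketch using `torch.fx.symbolic_trace` (the class name here is illustrative):

import torch
from torch.fx import symbolic_trace

class SinLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 4)
        self.bias = torch.rand(4)

    def forward(self, a):
        # traces to the same call_module / get_attr / call_method /
        # call_function sequence as the hand-built graph above
        return torch.sin(self.linear(a).add(self.bias))

gm = symbolic_trace(SinLinear())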
Example #4

import torch
from torch.fx import Graph, GraphModule, Proxy
from torch.fx.proxy import GraphAppendingTracer

# Set up an empty Graph and a tracer that appends new Nodes to it
graph = Graph()
tracer = GraphAppendingTracer(graph)

# Create raw Nodes
raw1 = graph.placeholder('x')
raw2 = graph.placeholder('y')

# Initialize Proxies using the raw Nodes and graph's default tracer
y = Proxy(raw1, tracer)
z = Proxy(raw2, tracer)
# y = Proxy(raw1)
# z = Proxy(raw2)

# Create other operations using the Proxies `y` and `z`
a = torch.cat([y, z])
b = torch.tanh(a)
c = torch.neg(b)
# By using the graph's own appending tracer to create the Proxies,
# we can use n-ary operators on operations without multiple tracers
# being created at run time, which leads to errors. To try this out
# for yourself, construct `y` and `z` with the commented-out
# `Proxy(raw1)` / `Proxy(raw2)` lines above instead.
z = torch.add(b, c)

# Create a new output Node and add it to the Graph. By doing this, the
# Graph will contain all the Nodes we just created (since they're all
# linked to the output Node)
graph.output(c.node)

# Wrap our created Graph in a GraphModule to get a final, runnable
# `nn.Module` instance
mod = GraphModule(torch.nn.Module(), graph)
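The assembled `mod` can be called like any other `nn.Module`. A quick sanity check of the computation it encodes, neg(tanh(cat([x, y]))):

x = torch.rand(3)
y = torch.rand(3)
out = mod(x, y)
assert torch.allclose(out, torch.neg(torch.tanh(torch.cat([x, y]))))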
Example #5
class Quantizer:
    def __init__(self,
                 mod,
                 patterns=DEFAULT_QUANTIZATION_PATTERNS,
                 quant_ctor=DefaultQuant):
        self.root = mod
        self.graph = mod.graph
        self.quant_ctor = quant_ctor

        # cached information for observe
        self.state_dict = self.root.state_dict()
        self.modules = dict(self.root.named_modules())

        # match the patterns that will get quantized
        self.matches = self._find_matches(patterns)
        # find _inputs_ to matched nodes that are not quantized; these
        # have to be quantized, which requires measuring stats, so we
        # initialize a quant_ctor object for each
        self.quants = self._find_quants(quant_ctor)

    def observe(self, args):
        # Most of this function is just an interpreter for the graph.
        # It would be possible to put this in some abstraction, but it
        # is pretty nice to be able to see exactly what is happening
        # here and hack on it. Maybe we should just provide an example
        # interpreter that people copy/paste and then edit.
        args_iter = iter(args)
        env = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        for node in self.graph.nodes:
            if node.op == 'placeholder':
                result = next(args_iter)
            elif node.op == 'get_attr':
                result = self.state_dict[node.target]
            elif node.op == 'call_function':
                result = node.target(*load_arg(node.args),
                                     **load_arg(node.kwargs))
            elif node.op == 'call_method':
                self_obj, *args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*args, **kwargs)
            elif node.op == 'call_module':
                result = self.modules[node.target](*load_arg(node.args),
                                                   **load_arg(node.kwargs))

            env[node.name] = result
            root_node, obj = self.matches.get(node.name, (None, None))
            if root_node is node:
                obj.observe(node, env)
            if node.name in self.quants:
                self.quants[node.name].observe(node, env)

        return load_arg(self.graph.result)

    def quantize(self):
        self.quantized_graph = Graph()

        env = {}
        quant_env = {}

        def load_arg(n, quantized):
            if not quantized:
                if n.name not in env and n.name in quant_env:
                    env[n.name] = Proxy(quant_env[n.name]).dequantize().node
                return env[n.name]
            else:
                if n.name not in quant_env and n.name in env:
                    quant_env[n.name] = self.quants[n.name].quantize(
                        env[n.name])
                return quant_env[n.name]

        def copy_recursive(node):
            def load_or_emit(n):
                if n.name in env or n.name in quant_env:
                    return load_arg(n, quantized=False)
                else:
                    return copy_recursive(n)

            r = env[node.name] = self.quantized_graph.node_copy(
                node, lambda n: load_arg(n, quantized=False))
            return r

        for node in self.graph.nodes:
            root_node, obj = self.matches.get(node.name, (None, None))
            if root_node is None:
                # not quantized, just copy it over
                env[node.name] = self.quantized_graph.node_copy(
                    node, lambda n: load_arg(n, quantized=False))

            elif root_node is node:
                r = obj.quantize(
                    self, node, lambda a: map_arg(
                        a, lambda n: load_arg(n, quantized=True)))
                if r is NotImplemented:
                    # quantizer chose not to quantize this node; take
                    # the entire match and just copy it over
                    env[node.name] = copy_recursive(node)
                else:
                    quant_env[node.name] = r

        self.quantized_graph.output(
            load_arg(self.graph.result, quantized=False))
        return GraphModule(self.root, self.quantized_graph)

    def _find_matches(self, patterns):
        modules = dict(self.root.named_modules())
        match_map = {}  # node name -> (root_node, match_value?)

        def apply_match(pattern, node, match):
            if isinstance(pattern, tuple):
                s, *args = pattern
                apply_match(s, node, match)
                for subpattern, arg in zip(args, node.args):
                    apply_match(subpattern, arg, match)
            else:
                match_map[node.name] = match

        for node in reversed(self.graph.nodes):
            if node.name not in match_map:
                for pattern, value in patterns.items():
                    if matches(modules, node, pattern):
                        apply_match(pattern, node, (node, value(self, node)))

        return match_map

    def _find_quants(self, quant_ctor):
        quants = {}

        def visit_arg(n):
            # note: we have to measure quantization information even for
            # nodes where we might not end up using it (because the input
            # is already quantized). This is because each match has the
            # option to return NotImplemented (if, for instance, it is an
            # __add__ and the data type is not appropriate).
            if n.name not in quants:
                quants[n.name] = quant_ctor(self, n)

        for node in self.graph.nodes:
            if node.name in self.matches:
                map_arg(node.args, visit_arg)
                map_arg(node.kwargs, visit_arg)
        return quants
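For context, a plausible calibrate-then-convert driver for this class. `traced_model` and `calib_batches` are illustrative names, and the collaborators referenced above (`DEFAULT_QUANTIZATION_PATTERNS`, `DefaultQuant`, `map_arg`, `matches`) come from the same prototype module:

quantizer = Quantizer(traced_model)     # traced_model: illustrative, a torch.fx GraphModule
for batch in calib_batches:             # calib_batches: illustrative calibration inputs
    quantizer.observe((batch,))         # interpret the graph once per batch to gather stats
quantized_model = quantizer.quantize()  # emit a new GraphModule with quantized ops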