Exemplo n.º 1
0
    def test_reassign_args_kwargs_uses(self):
        graph = torch.fx.Graph()
        x, y = Proxy(graph.placeholder('x')), Proxy(graph.placeholder('y'))
        z = x + y
        zed = z + z + z
        graph.output(zed.node)
        graph.lint()

        # zed = z + z + z -> zed = z + z + x
        zed.node.args = (zed.node.args[0], x.node)
        self.assertEqual(x.node.users.keys(), [z.node, zed.node])

        # z = x + y -> z = y + y
        z.node.args = (y.node, y.node)
        self.assertEqual(x.node.users.keys(), [zed.node])
Exemplo n.º 2
0
 def load_non_quantized(n):
     """Return the float (non-quantized) Node registered for ``n``.

     A Node already cached in ``env`` under ``n.name`` is returned as is.
     Otherwise the quantized counterpart is taken from ``quant_env``, a
     ``dequantize`` call on it is recorded through a Proxy, and the resulting
     Node is cached in ``env`` and returned.
     """
     name = n.name
     if name in env:
         return env[name]
     assert name in quant_env, \
         'trying to load float node but did not find node:' + name + \
         ' in quantized environment:' + str(quant_env)
     float_node = Proxy(quant_env[name]).dequantize().node
     env[name] = float_node
     return float_node
Exemplo n.º 3
0
def wrap_in_activation_function(m: GraphModule,
                                fn: ActivationFunction) -> GraphModule:
    """Wrap the single return value of ``m`` in the activation function ``fn``.

    The implementation registered for ``fn`` in ``activation_functions`` is
    symbolically traced and called on a Proxy of the current return value.
    The Nodes it records are spliced into ``m.graph`` just before the
    ``output`` Node, and the ``output`` Node is rewired to return the new
    value. ``m`` is modified in place, recompiled, and returned.

    Args:
        m: GraphModule whose output should be wrapped.
        fn: Key into ``activation_functions`` selecting the activation.

    Returns:
        ``m`` itself, after its graph has been rewritten and recompiled.

    Raises:
        AssertionError: if the graph has no ``output`` Node, or the ``output``
            Node does not have exactly one input Node.
    """
    # Get output node (searching from the end of the node list)
    output_node: Optional[Node] = None
    for n in reversed(m.graph.nodes):
        if n.op == "output":
            output_node = n
            break
    assert output_node is not None

    # Get the actual output (the "input" of the output node). This is
    # the Node we want to wrap in a user-specified activation function
    assert len(output_node.all_input_nodes) == 1
    wrap_node = output_node.all_input_nodes[0]

    # Wrap the actual output in a Proxy
    wrap_proxy = Proxy(wrap_node)

    # Get the implementation of the specified activation function and
    # symbolically trace it
    fn_impl = activation_functions[fn]
    fn_impl_traced = symbolic_trace(fn_impl)

    # Call the specified activation function using the Proxy wrapper for
    # `wrap_node`. The result of this call is another Proxy, which we
    # can hook into our existing Graph. The new Nodes consume `wrap_node`,
    # so they must come after it. Inserting before `output_node` (instead
    # of before `wrap_node`) keeps the graph topologically ordered, and
    # successive insertions stay in creation order.
    with m.graph.inserting_before(output_node):
        fn_impl_output_node = fn_impl_traced(wrap_proxy)
        new_args = (fn_impl_output_node.node, )
        output_node.args = new_args

    # Regenerate `m.forward` from the edited graph; without this, calling `m`
    # would keep running the code generated before the edit.
    m.recompile()
    return m
Exemplo n.º 4
0
 def load_non_quantized(n: Node) -> Node:
     """Return the float (non-quantized) Node registered for ``n``.

     An entry already present in ``env`` under ``n.name`` is returned
     unchanged. Otherwise the quantized Node stored in ``quant_env`` is
     dequantized through a Proxy, cached in ``env``, and returned.
     """
     name = n.name
     if name in env:
         return env[name]
     assert name in quant_env, \
         'trying to load float node but did not find ' + \
         'node:' + name + \
         ' in quantized or non quantized environment, env: ' + \
         str(env) + ' quant_env:' + str(quant_env)
     dequantized = Proxy(quant_env[name]).dequantize().node
     env[name] = dequantized
     return dequantized
Exemplo n.º 5
0
 def load_non_quantized(n: Node) -> Node:
     """Return a float Node for ``n``, dequantizing it on first use if needed.

     ``env`` maps ``n.name`` to a ``(node, dtype)`` pair. If ``dtype`` is
     falsy or already ``torch.float``, the stored Node is returned. Otherwise
     a ``dequantize`` call is recorded through a Proxy, and the pair in
     ``env`` is replaced by ``(dequantized_node, torch.float)``.
     """
     name = n.name
     assert name in env, \
         'trying to load float node but did not find ' + \
         'node:' + name + \
         ' in env: ' + \
         str(env)
     stored_node, stored_dtype = env[name]
     if not stored_dtype or stored_dtype == torch.float:
         return stored_node
     float_node = Proxy(stored_node).dequantize().node
     env[name] = float_node, torch.float
     return float_node
Exemplo n.º 6
0
 def test_graph_edit_with_proxy(self):
     """A Proxy built from a traced graph's result can append more nodes."""
     class M(torch.nn.Module):
         def forward(self, a, b):
             return a + b
     module = M()
     traced_graph = symbolic_trace(module).graph
     result_proxy = Proxy(traced_graph.result)
     # Proxy objects can generate further graph code after tracing, for edits
     # that do not need to work with modules.
     doubled = result_proxy + result_proxy
     traced_graph.output(doubled.node)
     edited = GraphModule(module, traced_graph)
     self.assertEqual(edited(3, 4), 14)
Exemplo n.º 7
0
 def test_graph_edit_with_proxy(self):
     """Nodes appended via a Proxy on a copied graph produce runnable code."""
     class M(torch.nn.Module):
         def forward(self, a, b):
             return a + b
     module = M()
     source_graph = symbolic_trace(module).graph
     copied_graph = torch.fx.Graph()
     copied_graph.graph_copy(source_graph)
     last_value = Proxy(copied_graph.nodes[-1])
     # Proxy objects can generate further graph code after tracing, for edits
     # that do not need to work with modules.
     doubled = last_value + last_value
     copied_graph.output(doubled.node)
     edited = GraphModule(module, copied_graph)
     edited.graph.lint(edited)
     self.assertEqual(edited(3, 4), 14)
Exemplo n.º 8
0
 def test_graph_unique_names(self):
     """Every node of a copied-and-extended graph has a distinct name."""
     class M(torch.nn.Module):
         def forward(self, a, b):
             return a + b
     module = M()
     source_graph = symbolic_trace(module).graph
     copied_graph = torch.fx.Graph()
     copied_graph.graph_copy(source_graph)
     last_value = Proxy(copied_graph.nodes[-1])
     # Proxy objects can generate further graph code after tracing, for edits
     # that do not need to work with modules.
     doubled = last_value + last_value
     copied_graph.output(doubled.node)
     edited = GraphModule(module, copied_graph)
     names_so_far : Set[str] = set()
     for node_name in (nd.name for nd in edited.graph.nodes):
         assert node_name not in names_so_far
         names_so_far.add(node_name)
Exemplo n.º 9
0
 def load_non_quantized(n: Node) -> Node:
     """Return a float Node for ``n``, dequantizing one on demand.

     ``env[n.name]`` is a mapping from dtype to Node (presumably one entry per
     materialized dtype of the value; verify against callers). A
     ``torch.float`` entry is preferred, then a ``None`` entry. Otherwise the
     first of ``torch.quint8``, ``torch.qint8``, ``torch.float16`` that is
     present is dequantized through a Proxy, and the result is cached under
     ``torch.float``.
     """
     assert n.name in env, \
         'trying to load float node but did not find ' + \
         'node:' + n.name + \
         ' in env: ' + \
         str(env)
     dtype_to_node = env[n.name]
     for unquantized_dtype in (torch.float, None):
         if unquantized_dtype in dtype_to_node:
             return dtype_to_node[unquantized_dtype]
     # First supported quantized dtype present, in priority order.
     quantized_node = next(
         (dtype_to_node[q_dtype]
          for q_dtype in (torch.quint8, torch.qint8, torch.float16)
          if q_dtype in dtype_to_node),
         None)
     assert quantized_node is not None, "Did not find a supported quantized dtype:{}".format(dtype_to_node)
     env[n.name][torch.float] = Proxy(quantized_node).dequantize().node
     return env[n.name][torch.float]
Exemplo n.º 10
0
# Symbolically trace an instance of `M`. After tracing, `self.relu` is
# represented as a `call_module` Node, so the generated `forward`
# function's code will contain the call `self.relu(x)`
m = symbolic_trace(M())

# Insert nodes from the ReLU graph in place of the original call to
# `self.relu`
# create a graph-appending tracer pointing to the original graph
tracer = torch.fx.proxy.GraphAppendingTracer(m.graph)
# Walk a snapshot of the node list: the loop body inserts and erases Nodes,
# and a fixed copy keeps the traversal independent of those edits.
for node in list(m.graph.nodes):
    # Find `call_module` Node in `m` that corresponds to `self.relu`.
    # This is the Node we want to swap out for an inlined version of the
    # same call
    if (node.op, node.target) == ("call_module", "relu"):
        with m.graph.inserting_before(node):
            # Create a Proxy from each Node in the current Node's
            # args/kwargs
            proxy_args = map_arg(node.args, lambda n: Proxy(n, tracer))
            proxy_kwargs = map_arg(node.kwargs, lambda n: Proxy(n, tracer))
            # Call `m.relu` with the newly-created Proxy arguments.
            # `m.relu` is the generic version of the function; by
            # calling it with Proxies created from Nodes in `m`, we're
            # emitting Nodes that reference existing values in the IR.
            # The result of this call is another Proxy, which we can
            # hook into our existing Graph to complete the function
            # inlining.
            proxy_output = m.relu(*proxy_args, **proxy_kwargs)
            # Replace the relu `call_module` node with the inlined
            # version of the function
            node.replace_all_uses_with(proxy_output.node)
            # Make sure that the old relu Node is erased
            m.graph.erase_node(node)

# Check the edited graph is well-formed, then regenerate `m.forward` from it.
# Without `recompile`, calling `m` would still run the pre-edit code.
m.graph.lint()
m.recompile()
        tanh_1 = torch.tanh(cat_1);  cat_1 = None
        neg_1 = torch.neg(tanh_1);  tanh_1 = None
        return neg_1

'''

# Create a graph independently of symbolic tracing
graph = Graph()
# One appending tracer, explicitly bound to `graph` and shared by every
# Proxy created below
tracer = torch.fx.proxy.GraphAppendingTracer(graph)

# Create raw Nodes
raw1 = graph.placeholder('x')
raw2 = graph.placeholder('y')

# Initialize Proxies using the raw Nodes and the shared appending tracer
y = Proxy(raw1, tracer)
z = Proxy(raw2, tracer)
# Alternative without an explicit tracer (see the note below):
# y = Proxy(raw1)
# z = Proxy(raw2)

# Create other operations using the Proxies `y` and `z`
a = torch.cat([y, z])
b = torch.tanh(a)
c = torch.neg(b)
# Because every Proxy here was created with the graph's own appending
# tracer, multi-input operations such as `torch.cat([y, z])` above and
# `torch.add(b, c)` below work without multiple tracers being created at
# run time, which leads to errors. To try this out for yourself, replace
# the two `Proxy(rawN, tracer)` lines above with the commented-out
# `Proxy(raw1)` / `Proxy(raw2)` lines.
z = torch.add(b, c)
Exemplo n.º 12
0
 def lift_shape(i):
     """Wrap ``i`` in a Proxy and copy its ``shape`` and ``bdim`` onto it."""
     lifted = Proxy(i)
     # Copy the attributes in the same order as before: shape, then bdim.
     for attr_name in ('shape', 'bdim'):
         setattr(lifted, attr_name, getattr(i, attr_name))
     return lifted
        cat_1 = torch.cat([x, y]);  x = y = None
        tanh_1 = torch.tanh(cat_1);  cat_1 = None
        neg_1 = torch.neg(tanh_1);  tanh_1 = None
        return neg_1

'''

# Create a graph independently of symbolic tracing
graph = Graph()
# Share one appending tracer bound to `graph` across all Proxies. Creating
# each Proxy without an explicit tracer gives every Proxy its own tracer, and
# combining Proxies from multiple tracers in one operation (such as
# `torch.cat([y, z])` below) leads to errors.
tracer = torch.fx.proxy.GraphAppendingTracer(graph)

# Create raw Nodes
raw1 = graph.placeholder('x')
raw2 = graph.placeholder('y')

# Initialize Proxies using the raw Nodes and the shared tracer
y = Proxy(raw1, tracer)
z = Proxy(raw2, tracer)

# Create other operations using the Proxies `y` and `z`
a = torch.cat([y, z])
b = torch.tanh(a)
c = torch.neg(b)

# Create a new output Node and add it to the Graph. The Nodes created above
# are already part of `graph`; the output Node marks `c` as the value the
# generated `forward` returns.
graph.output(c.node)

# Wrap our created Graph in a GraphModule to get a final, runnable
# `nn.Module` instance
mod = GraphModule(torch.nn.Module(), graph)