Example #1
def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values):
    fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module)
    primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
    tangent_inputs = list(filter(_is_tangent, joint_module.graph.nodes))
    # Construct the forward module
    fwd_graph = _extract_graph_with_inputs_outputs(joint_module.graph,
                                                   primal_inputs,
                                                   fwd_outputs + saved_values)
    bwd_graph = _extract_graph_with_inputs_outputs(
        joint_module.graph, saved_values + tangent_inputs, bwd_outputs)

    # This is to filter out saved values that don't actually end up being used by the backwards pass
    for node in bwd_graph.nodes:
        if node.op == 'placeholder' and not node.users:
            for saved_value in saved_values:
                if saved_value.name == node.name:
                    saved_values.remove(saved_value)
                    break

    # Now, we re-generate the fwd/bwd graphs.
    # NB: This might increase compilation time, but I doubt it matters
    fwd_graph = _extract_graph_with_inputs_outputs(joint_module.graph,
                                                   primal_inputs,
                                                   fwd_outputs + saved_values)
    bwd_graph = _extract_graph_with_inputs_outputs(
        joint_module.graph, saved_values + tangent_inputs, bwd_outputs)

    fwd_module = fx.GraphModule(joint_module, fwd_graph)
    bwd_module = fx.GraphModule(joint_module, bwd_graph)
    return fwd_module, bwd_module
Example #2
def graph_fails(graph, inps):
    nonlocal num_queries
    graph = copy.deepcopy(graph)
    num_queries += 1
    mod = fx.GraphModule(fail_f, graph)
    mod.graph.lint()
    return module_fails(mod, inps)
Example #3
def decompose(model: torch.nn.Module, example_inputs) -> torch.nn.Module:
    """
    decompose(model, example_inputs) takes in a model, decomposes any of the functions in
    `decomposition_rules` into their constituent operations, and returns an `nn.Module`
    without any of the operations that have decomposition rules.
    """
    # Run it multiple times so we converge to a fixed point.
    for _ in range(5):
        model = fx.symbolic_trace(model)
        ShapeProp(model).propagate(*example_inputs)
        new_graph = fx.Graph()
        env = {}
        for node in model.graph.nodes:
            if node.op == 'call_function' and node.target in decomposition_rules:
                # If the current function is in `decomposition_rules`, we use
                # `Proxy` objects to decompose the operations using the
                # decomposition rule. See
                # https://pytorch.org/docs/master/fx.html#proxy-retracing for
                # more details.
                proxy_args = map_arg(node.args,
                                     lambda n: fx.Proxy(env[n.name]))
                proxy_kwargs = map_arg(node.kwargs,
                                       lambda n: fx.Proxy(env[n.name]))
                new_node = decomposition_rules[node.target](
                    *proxy_args, **proxy_kwargs).node
                env[node.name] = new_node
            else:
                new_node = new_graph.node_copy(node, lambda x: env[x.name])
                env[node.name] = new_node
        model = fx.GraphModule(model, new_graph)
    return model
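A minimal usage sketch (not part of the original source): it assumes a hypothetical `decomposition_rules` table keyed by the traced call_function target, here with one rule that rewrites torch.addmm into mm plus an add.

import torch
import torch.fx as fx

# Hypothetical rule table: each entry rewrites one call_function target
# into its constituent operations via Proxy retracing.
decomposition_rules = {
    torch.addmm: lambda bias, mat1, mat2: torch.mm(mat1, mat2) + bias,
}

class AddMM(torch.nn.Module):
    def forward(self, bias, x, w):
        return torch.addmm(bias, x, w)

# decomposed = decompose(AddMM(), (torch.randn(3, 4), torch.randn(3, 5), torch.randn(5, 4)))
# print(decomposed.code)  # expected to show mm/add in place of addmm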
Example #4
def fuse(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
    """
    Fuses convolution/BN layers for inference purposes. Will deepcopy your
    model by default, but can modify the model inplace as well.
    """
    patterns = [(nn.Conv1d, nn.BatchNorm1d), (nn.Conv2d, nn.BatchNorm2d),
                (nn.Conv3d, nn.BatchNorm3d)]
    if not inplace:
        model = copy.deepcopy(model)
    fx_model = fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())
    new_graph = copy.deepcopy(fx_model.graph)

    for pattern in patterns:
        for node in new_graph.nodes:
            if matches_module_pattern(pattern, node, modules):
                if len(node.args[0].users) > 1:  # Output of conv is used by other nodes
                    continue
                conv = modules[node.args[0].target]
                bn = modules[node.target]
                if not bn.track_running_stats:
                    continue
                fused_conv = fuse_conv_bn_eval(conv, bn)
                replace_node_module(node.args[0], modules, fused_conv)
                node.replace_all_uses_with(node.args[0])
                new_graph.erase_node(node)
    return fx.GraphModule(fx_model, new_graph)
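A usage sketch for the fuser, assuming its helpers (matches_module_pattern, replace_node_module, fuse_conv_bn_eval) are importable; in eval mode the fused module should reproduce the original outputs.

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).eval()
inp = torch.randn(1, 3, 32, 32)

# fused = fuse(model)
# torch.testing.assert_close(fused(inp), model(inp))
# print(fused.graph)  # the BatchNorm call_module node should be gone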
Example #5
def vmap(model: torch.nn.Module, in_axes: Tuple[Optional[int], ...],
         example_args: Tuple[Any, ...]) -> torch.nn.Module:
    """vmap
    Given a model with inputs, vmap will return a function that works on
    batched versions of those inputs. Which inputs will be batched is
    determined by in_axes. In addition, as vmap requires shape (actually
    rank) information, we will pass in example_args (example inputs for the
    original module).
    """
    in_axes = iter(in_axes)
    fx_model = fx.symbolic_trace(model)
    # Here we run a shape propagation pass in order to annotate the graph with shape information.
    ShapeProp(fx_model).propagate(*example_args)
    # As vmap rewrites the whole graph, it's easiest to create an entirely new
    # graph and append to that.
    new_graph: fx.Graph = fx.Graph()

    # We will create an environment to map the new nodes created to the
    # corresponding old nodes.
    def lookup_env(l):
        return fx.node.map_aggregate(
            l, lambda x: env[x.name] if isinstance(x, fx.Node) else x)

    env = {}
    for node in fx_model.graph.nodes:
        if node.op == 'placeholder':
            # If the node is an input placeholder, we simply copy it over and
            # annotate it with the batch dimension from `in_axes`.
            new_node = new_graph.placeholder(node.name)
            new_node.bdim = next(in_axes)
            new_node.meta = node.meta
            env[node.name] = new_node
        elif node.op == 'output':
            new_graph.output(env[node.args[0].name])
        elif node.op == 'call_function':
            new_args = lookup_env(node.args)
            # If any of the inputs to the function has a new batch dimension,
            # we will need to use our batching rules. Otherwise, we will simply
            # copy the node over.
            if any([
                    x.bdim is not None for x in new_args
                    if isinstance(x, fx.Node)
            ]):
                new_node = gen_batching_rule_function(node.target, *new_args)
            else:
                new_node = new_graph.node_copy(node, lambda x: env[x.name])
                new_node.bdim = None
            new_node.meta = node.meta
            env[node.name] = new_node
        else:
            raise RuntimeError("Not yet implemented")

    res = fx.GraphModule(fx_model, new_graph)
    print(res.code)
    res.graph.lint()
    return res
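A call-signature sketch for the transform above, under the assumption that gen_batching_rule_function has batching rules for every traced op; in_axes supplies one entry per placeholder.

import torch

class Scale(torch.nn.Module):
    def forward(self, x):
        return x * 2

# The original module works on a single example of shape (3,); the returned
# module is expected to accept a batch along dim 0, i.e. shape (B, 3):
# batched = vmap(Scale(), in_axes=(0,), example_args=(torch.randn(3),))
# out = batched(torch.randn(8, 3))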
Example #6
def grad(model: torch.nn.Module,
         example_inps: Tuple[Any, ...],
         get_value=True) -> torch.nn.Module:
    fx_model = fx.symbolic_trace(model)
    ShapeProp(fx_model).propagate(*example_inps)
    # As grad rewrites the whole graph, it's easiest to create an entirely new
    # graph and append to that.
    val_map = {}
    new_graph: fx.Graph = fx.Graph()
    orig_output = new_graph.graph_copy(fx_model.graph, val_map)

    def shape_proxy(node):
        proxy = fx.Proxy(val_map[node])
        proxy.shape = node.meta['shape']
        proxy.dim = lambda: len(proxy.shape)
        return proxy

    inputs = []
    ones = new_graph.create_node('call_function', torch.ones, ([], ))

    for node in reversed(fx_model.graph.nodes):
        if node.op == 'output':
            assert (len(node.args) == 1)
            val_map[node.args[0]].grad = [fx.Proxy(ones)]
        elif node.op == 'placeholder':
            inputs.append(sum(val_map[node].grad).node)
        elif node.op == 'call_function':
            g = sum(val_map[node].grad)
            new_args = [
                shape_proxy(i) if isinstance(i, fx.Node) else i
                for i in node.args
            ]
            if node.target not in vjp_map:
                raise RuntimeError("vjp not yet implemented")
            new_grads = vjp_map[node.target](g, *new_args)
            if not isinstance(new_grads, tuple):
                new_grads = (new_grads, )
            for new_g, arg in zip(new_grads, new_args):
                if isinstance(arg, fx.Proxy):
                    if not hasattr(arg.node, 'grad'):
                        arg.node.grad = []
                    arg.node.grad.append(new_g)
        elif node.op == 'call_method':
            raise RuntimeError("doesn't support methods since i'm lazy")

    if len(inputs) == 1:
        inputs = inputs[0]
    else:
        inputs = inputs[::-1]
    if get_value:
        new_graph.output((orig_output, inputs))
    else:
        new_graph.output(inputs)
    res = fx.GraphModule(fx_model, new_graph)
    res.graph.lint()
    return res
Example #7
    def test_remove_duplicate_output_args(self):
        class Sub(nn.Module):
            def forward(self, x):
                return (x, x)

        class Top(nn.Module):
            def __init__(self):
                super().__init__()
                self.a = Sub()

            def forward(self, x):
                a_res = self.a(x)
                return a_res[0] + a_res[1]

        class Tracer(fx.Tracer):
            def is_leaf_module(self, m, qn):
                if isinstance(m, Sub):  # don't trace into
                    return True
                return False

        top = Top()
        ttop = fx.GraphModule(top, Tracer().trace(top), "top")
        ttop.a = fx.symbolic_trace(ttop.a)

        name_to_processed_subnet = dedup.remove_duplicate_output_args(
            ttop, ["a"])

        ttop(1)  # running inference should still work

        processed_a = name_to_processed_subnet["a"]
        *_, a_output = processed_a.module.graph.nodes
        a_output: fx.Node

        ttop_graph_actual = str(ttop.graph).strip()
        ttop_graph_expected = """
graph():
    %x : [#users=1] = placeholder[target=x]
    %a : [#users=2] = call_module[target=a](args = (%x,), kwargs = {})
    %getitem : [#users=1] = call_function[target=operator.getitem](args = (%a, 0), kwargs = {})
    %getitem_1 : [#users=1] = call_function[target=operator.getitem](args = (%a, 0), kwargs = {})
    %add : [#users=1] = call_function[target=operator.add](args = (%getitem, %getitem_1), kwargs = {})
    return add
""".strip()
        assert (ttop_graph_expected == ttop_graph_actual
                ), f"Unexpected ttop graph: {ttop_graph_actual}"

        ttop_a_graph_actual = str(ttop.a.graph).strip()
        ttop_a_graph_expected = """
graph():
    %x : [#users=1] = placeholder[target=x]
    return (x,)
""".strip()
        assert (ttop_a_graph_expected == ttop_a_graph_actual
                ), f"Unexpected ttop.a graph: {ttop_a_graph_actual}"
Example #8
def draw_graph(traced: torch.fx.GraphModule, fname: str, figname: str = "fx_graph", clear_meta=True):
    if clear_meta:
        new_graph = copy.deepcopy(traced.graph)
        traced = fx.GraphModule(traced, new_graph)
        for node in traced.graph.nodes:
            node.meta = {}
    base, ext = os.path.splitext(fname)
    if not ext:
        ext = ".svg"
    print(f"Writing FX graph to file: {base}{ext}")
    g = graph_drawer.FxGraphDrawer(traced, figname)
    x = g.get_main_dot_graph()
    getattr(x, "write_" + ext.lstrip("."))(f"{base}{ext}")
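A short usage sketch, assuming pydot/graphviz are installed so FxGraphDrawer can write the file:

import torch
import torch.fx as fx

traced = fx.symbolic_trace(torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()))
# draw_graph(traced, "my_model")      # no extension given, so it writes my_model.svg
# draw_graph(traced, "my_model.png")  # extension selects the pydot writer (write_png)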
Example #9
def truncate(model, k):
    model = fx.symbolic_trace(model)
    new_graph = fx.Graph()
    env = {}

    cnt = 0
    for node in list(model.graph.nodes):
        new_node = new_graph.node_copy(node, lambda x: env[x.name])
        env[node.name] = new_node
        cnt += 1
        if cnt == k:
            new_graph.output(env[node.name])
            break

    return fx.GraphModule(model, new_graph)
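A usage sketch: keeping the first k copied nodes turns whatever the k-th node produced into the module's output (placeholders count toward k).

import torch

class Net(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x).sum()

# truncated = truncate(Net(), 2)  # keeps the placeholder and the relu node
# print(truncated.code)           # returns the relu result instead of the sum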
Example #10
def extract_subgraph(orig_module: nn.Module, nodes: List[fx.Node], inputs: List[fx.Node], outputs: List[fx.Node]):
    """
    Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph.
    """
    new_graph = fx.Graph()
    env: Dict[fx.Node, fx.Node] = {}
    for input in inputs:
        new_node = new_graph.placeholder(input.name)
        env[input] = new_node
    for node in nodes:
        new_node = new_graph.node_copy(node, lambda x: env[x])
        env[node] = new_node
    new_graph.output([env[output] for output in outputs])
    new_graph.lint()
    return fx.GraphModule(orig_module, new_graph)
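A sketch of driving extract_subgraph on a whole traced module, using the placeholders as subgraph inputs and the value feeding the original output as the subgraph output; nodes must be in topological order so every argument is already in env.

import torch
import torch.fx as fx

class MLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.lin(x))

model = MLP()
traced = fx.symbolic_trace(model)
placeholders = [n for n in traced.graph.nodes if n.op == 'placeholder']
body = [n for n in traced.graph.nodes if n.op not in ('placeholder', 'output')]
output_node = next(n for n in traced.graph.nodes if n.op == 'output')
# sub = extract_subgraph(model, body, placeholders, [output_node.args[0]])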
Example #11
def profile_function(name, f, inp):
    fx_g = make_fx(f)(inp)

    new_g = fx_graph_cse(fx_g.graph)
    new_g = fx.GraphModule(fx_g, new_g)
    # do not benchmark against the scripted version because script already does some CSE
    # script_f = torch.jit.script(fx_g)
    # script_g = torch.jit.script(new_g)
    # avg_cuda_time_f = profile_it(script_f, inp)
    # avg_cuda_time_g = profile_it(script_g, inp)
    avg_cuda_time_f = profile_it(fx_g, inp)
    avg_cuda_time_g = profile_it(new_g, inp)
    num_node_decrease = len(fx_g.graph.nodes) - len(new_g.graph.nodes)

    print(f"{name}, {avg_cuda_time_f}, {avg_cuda_time_g}, {num_node_decrease}, {len(fx_g.graph.nodes)}")
Example #12
def check(f, t, delta, check_val=True, graph_input=False):
    if graph_input:
        fx_g = f
    else:
        fx_g = make_fx(f)(t)
    new_graph = fx_graph_cse(fx_g.graph)
    new_g = fx.GraphModule(fx_g, new_graph)

    # the number of nodes should decrease or stay the same
    old_num_nodes = len(fx_g.graph.nodes)
    new_num_nodes = len(new_graph.nodes)
    if delta == -1:
        assert old_num_nodes >= new_num_nodes, (
            f"number of nodes increased {old_num_nodes}, {new_num_nodes}")
    else:
        assert old_num_nodes == new_num_nodes + delta, (
            f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}"
        )

    # a second pass of CSE should not remove any more nodes
    pass_2_graph = fx_graph_cse(new_graph)
    pass_2_num_nodes = len(pass_2_graph.nodes)
    assert pass_2_num_nodes == new_num_nodes, (
        f"second pass graph has fewer nodes {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}"
    )

    # check correctness
    if check_val:
        true_result = fx_g(t)
        our_result = new_g(t)
        if true_result is None:  # both return None
            assert our_result is None, f"true result is None, CSE result is {our_result}"
        else:  # results returned are the same
            assert torch.all(true_result == our_result), (
                f"results are different {true_result}, {our_result}"
            )  # check results are the same
Example #13
def invert(model: torch.nn.Module) -> torch.nn.Module:
    fx_model = fx.symbolic_trace(model)
    new_graph = fx.Graph()  # We build up an entirely new graph by walking the original in reverse
    env = {}
    for node in reversed(fx_model.graph.nodes):
        if node.op == "call_function":
            # This creates a node in the new graph with the inverse function,
            # and passes `env[node.name]` (i.e. the previous output node) as
            # input.
            new_node = new_graph.call_function(invert_mapping[node.target],
                                               (env[node.name], ))
            env[node.args[0].name] = new_node
        elif node.op == "output":
            # We turn the output into an input placeholder
            new_node = new_graph.placeholder(node.name)
            env[node.args[0].name] = new_node
        elif node.op == "placeholder":
            # We turn the input placeholder into an output
            new_graph.output(env[node.name])
        else:
            raise RuntimeError("Not implemented")

    new_graph.lint()
    return fx.GraphModule(fx_model, new_graph)
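A usage sketch assuming a hypothetical invert_mapping table from each call_function target to its inverse; applying the inverted module to the original output should recover the input.

import torch

# Hypothetical inverse table for the traced call_function targets.
invert_mapping = {torch.exp: torch.log, torch.neg: torch.neg}

class Exp(torch.nn.Module):
    def forward(self, x):
        return torch.exp(x)

# inv = invert(Exp())
# x = torch.randn(5)
# torch.testing.assert_close(inv(Exp()(x)), x)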
Example #14
def codegen_tensor_product(
    irreps_in1: o3.Irreps,
    in1_var: List[float],
    irreps_in2: o3.Irreps,
    in2_var: List[float],
    irreps_out: o3.Irreps,
    out_var: List[float],
    instructions: List[Instruction],
    normalization: str = 'component',
    shared_weights: bool = False,
    specialized_code: bool = True,
    optimize_einsums: bool = True,
) -> Tuple[fx.GraphModule, fx.GraphModule]:
    graph_out = fx.Graph()
    graph_right = fx.Graph()

    # = Function definitions =
    x1s_out = fx.Proxy(graph_out.placeholder('x1', torch.Tensor))
    x2s_out = fx.Proxy(graph_out.placeholder('x2', torch.Tensor))
    ws_out = fx.Proxy(graph_out.placeholder('w', torch.Tensor))

    x2s_right = fx.Proxy(graph_right.placeholder('x2', torch.Tensor))
    ws_right = fx.Proxy(graph_right.placeholder('w', torch.Tensor))

    empty_out = fx.Proxy(
        graph_out.call_function(torch.empty, ((), ), dict(device='cpu')))
    empty_right = fx.Proxy(
        graph_right.call_function(torch.empty, ((), ), dict(device='cpu')))
    if shared_weights:
        size_out = torch.broadcast_tensors(
            empty_out.expand(x1s_out.shape[:-1]),
            empty_out.expand(x2s_out.shape[:-1]))[0].shape
        size_right = x2s_right.shape[:-1]
    else:
        size_out = torch.broadcast_tensors(
            empty_out.expand(x1s_out.shape[:-1]),
            empty_out.expand(x2s_out.shape[:-1]),
            empty_out.expand(ws_out.shape[:-1]))[0].shape
        size_right = torch.broadcast_tensors(
            empty_right.expand(x2s_right.shape[:-1]),
            empty_right.expand(ws_right.shape[:-1]))[0].shape

    # = Short-circuit for zero dimensional =
    # We produce no code for empty instructions
    instructions = [ins for ins in instructions if 0 not in ins.path_shape]

    if len(instructions) == 0:
        out_out = x1s_out.new_zeros(size_out + (irreps_out.dim, ))
        out_right = x2s_right.new_zeros(size_right + (
            irreps_in1.dim,
            irreps_out.dim,
        ))

        graph_out.output(out_out.node, torch.Tensor)
        graph_right.output(out_right.node, torch.Tensor)
        # Short circuit
        return (fx.GraphModule({}, graph_out, "tp_forward"),
                fx.GraphModule({}, graph_right, "tp_right"))

    # = Broadcast inputs =
    if shared_weights:
        x1s_out, x2s_out = x1s_out.broadcast_to(
            size_out + (-1, )), x2s_out.broadcast_to(size_out + (-1, ))
    else:
        x1s_out, x2s_out, ws_out = x1s_out.broadcast_to(
            size_out + (-1, )), x2s_out.broadcast_to(
                size_out + (-1, )), ws_out.broadcast_to(size_out + (-1, ))
        x2s_right, ws_right = x2s_right.broadcast_to(
            size_right + (-1, )), ws_right.broadcast_to(size_right + (-1, ))

    outsize_out = size_out + (irreps_out.dim, )
    outsize_right = size_right + (
        irreps_in1.dim,
        irreps_out.dim,
    )

    x1s_out = x1s_out.reshape(-1, irreps_in1.dim)
    x2s_out = x2s_out.reshape(-1, irreps_in2.dim)
    x2s_right = x2s_right.reshape(-1, irreps_in2.dim)

    batch_out = x1s_out.shape[0]
    batch_right = x2s_right.shape[0]

    # = Determine number of weights and reshape weights ==
    weight_numel = sum(
        prod(ins.path_shape) for ins in instructions if ins.has_weight)
    if weight_numel > 0:
        ws_out = ws_out.reshape(-1, weight_numel)
        ws_right = ws_right.reshape(-1, weight_numel)
    del weight_numel

    # = book-keeping for wigners =
    w3j = []
    w3j_dict_out = dict()
    w3j_dict_right = dict()

    # = extract individual input irreps =
    # If only one input irrep, can avoid creating a view
    if len(irreps_in1) == 1:
        x1_list_out = [
            x1s_out.reshape(batch_out, irreps_in1[0].mul, irreps_in1[0].ir.dim)
        ]
    else:
        x1_list_out = [
            x1s_out[:, i].reshape(batch_out, mul_ir.mul, mul_ir.ir.dim)
            for i, mul_ir in zip(irreps_in1.slices(), irreps_in1)
        ]

    x2_list_out = []
    x2_list_right = []
    # If only one input irrep, can avoid creating a view
    if len(irreps_in2) == 1:
        x2_list_out.append(
            x2s_out.reshape(batch_out, irreps_in2[0].mul,
                            irreps_in2[0].ir.dim))
        x2_list_right.append(
            x2s_right.reshape(batch_right, irreps_in2[0].mul,
                              irreps_in2[0].ir.dim))
    else:
        for i, mul_ir in zip(irreps_in2.slices(), irreps_in2):
            x2_list_out.append(x2s_out[:, i].reshape(batch_out, mul_ir.mul,
                                                     mul_ir.ir.dim))
            x2_list_right.append(
                x2s_right[:, i].reshape(batch_right, mul_ir.mul, mul_ir.ir.dim))

    # The einsum string index to prepend to the weights if the weights are not shared and have a batch dimension
    z = '' if shared_weights else 'z'

    # Cache of input irrep pairs whose outer products (xx) have already been computed
    xx_dict = dict()

    # Current index in the flat weight tensor
    flat_weight_index = 0

    out_list_out = []
    out_list_right = []

    for ins in instructions:
        mul_ir_in1 = irreps_in1[ins.i_in1]
        mul_ir_in2 = irreps_in2[ins.i_in2]
        mul_ir_out = irreps_out[ins.i_out]

        assert mul_ir_in1.ir.p * mul_ir_in2.ir.p == mul_ir_out.ir.p
        assert abs(mul_ir_in1.ir.l - mul_ir_in2.ir.l
                   ) <= mul_ir_out.ir.l <= mul_ir_in1.ir.l + mul_ir_in2.ir.l

        if mul_ir_in1.dim == 0 or mul_ir_in2.dim == 0 or mul_ir_out.dim == 0:
            continue

        alpha = ins.path_weight * out_var[ins.i_out] / sum(
            in1_var[i.i_in1] * in2_var[i.i_in2]
            for i in instructions if i.i_out == ins.i_out)

        # Open the profiler block
        name = f"{mul_ir_in1} x {mul_ir_in2} = {mul_ir_out} {ins.connection_mode} {ins.has_weight}"
        handle_out = graph_out.call_function(
            torch.ops.profiler._record_function_enter, (name, ))
        handle_right = graph_right.call_function(
            torch.ops.profiler._record_function_enter, (name, ))

        x1_out = x1_list_out[ins.i_in1]
        x2_out = x2_list_out[ins.i_in2]
        x2_right = x2_list_right[ins.i_in2]

        e1_right = fx.Proxy(
            graph_right.call_function(
                torch.eye, (mul_ir_in1.mul, ),
                dict(dtype=x2s_right.dtype.node,
                     device=x2s_right.device.node)))
        e2_right = fx.Proxy(
            graph_right.call_function(
                torch.eye, (mul_ir_in2.mul, ),
                dict(dtype=x2s_right.dtype.node,
                     device=x2s_right.device.node)))
        i1_right = fx.Proxy(
            graph_right.call_function(
                torch.eye, (mul_ir_in1.ir.dim, ),
                dict(dtype=x2s_right.dtype.node,
                     device=x2s_right.device.node)))

        assert ins.connection_mode in [
            'uvw', 'uvu', 'uvv', 'uuw', 'uuu', 'uvuv'
        ]

        alpha = sqrt(
            alpha / {
                'uvw': (mul_ir_in1.mul * mul_ir_in2.mul),
                'uvu': mul_ir_in2.mul,
                'uvv': mul_ir_in1.mul,
                'uuw': mul_ir_in1.mul,
                'uuu': 1,
                'uvuv': 1,
            }[ins.connection_mode])

        if ins.has_weight:
            # Extract the weight from the flattened weight tensor
            w_out = ws_out[:, flat_weight_index:flat_weight_index +
                           prod(ins.path_shape)].reshape((
                               () if shared_weights else (-1, )) +
                                                         tuple(ins.path_shape))
            w_right = ws_right[:, flat_weight_index:flat_weight_index +
                               prod(ins.path_shape)].reshape(
                                   (() if shared_weights else (-1, )) +
                                   tuple(ins.path_shape))
            flat_weight_index += prod(ins.path_shape)

        # Construct the general xx in case this instruction isn't specialized
        # If this isn't used, the dead code will get removed
        key = (ins.i_in1, ins.i_in2, ins.connection_mode[:2])
        if key not in xx_dict:
            if ins.connection_mode[:2] == 'uv':
                xx_dict[key] = torch.einsum('zui,zvj->zuvij', x1_out, x2_out)
            if ins.connection_mode[:2] == 'uu':
                xx_dict[key] = torch.einsum('zui,zuj->zuij', x1_out, x2_out)
        xx = xx_dict[key]

        # Create a proxy & request for the relevant wigner w3j
        # If not used (because of specialized code), will get removed later.
        key = (mul_ir_in1.ir.l, mul_ir_in2.ir.l, mul_ir_out.ir.l)
        if key not in w3j:
            w3j_dict_out[key] = fx.Proxy(
                graph_out.get_attr(f"_w3j_{key[0]}_{key[1]}_{key[2]}"))
            w3j_dict_right[key] = fx.Proxy(
                graph_right.get_attr(f"_w3j_{key[0]}_{key[1]}_{key[2]}"))
            w3j.append(key)
        w3j_out = w3j_dict_out[key]
        w3j_right = w3j_dict_right[key]

        exp = {'component': 1, 'norm': -1}[normalization]

        if ins.connection_mode == 'uvw':
            assert ins.has_weight
            if specialized_code and key == (0, 0, 0):
                ein_out = torch.einsum(
                    f"{z}uvw,zu,zv->zw", w_out,
                    x1_out.reshape(batch_out, mul_ir_in1.dim),
                    x2_out.reshape(batch_out, mul_ir_in2.dim))
                ein_right = torch.einsum(
                    f"{z}uvw,zv->zuw", w_right,
                    x2_right.reshape(batch_right, mul_ir_in2.dim))
            elif specialized_code and mul_ir_in1.ir.l == 0:
                ein_out = torch.einsum(
                    f"{z}uvw,zu,zvj->zwj", w_out,
                    x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out)
                ein_right = torch.einsum(f"{z}uvw,zvi->zuwi", w_right,
                                         x2_right)
            elif specialized_code and mul_ir_in2.ir.l == 0:
                ein_out = torch.einsum(
                    f"{z}uvw,zui,zv->zwi", w_out, x1_out,
                    x2_out.reshape(batch_out, mul_ir_in2.dim))
                ein_right = torch.einsum(
                    f"{z}uvw,ij,zv->zuiwj", w_right, i1_right,
                    x2_right.reshape(batch_right, mul_ir_in2.dim))
            elif specialized_code and mul_ir_out.ir.l == 0:
                ein_out = torch.einsum(f"{z}uvw,zui,zvi->zw", w_out, x1_out,
                                       x2_out) / sqrt(mul_ir_in1.ir.dim)**exp
                ein_right = torch.einsum(f"{z}uvw,zvi->zuiw", w_right,
                                         x2_right) / sqrt(
                                             mul_ir_in1.ir.dim)**exp
            else:
                ein_out = torch.einsum(f"{z}uvw,ijk,zuvij->zwk", w_out,
                                       w3j_out, xx)
                ein_right = torch.einsum(f"{z}uvw,ijk,zvj->zuiwk", w_right,
                                         w3j_right, x2_right)
        if ins.connection_mode == 'uvu':
            assert mul_ir_in1.mul == mul_ir_out.mul
            if ins.has_weight:
                if specialized_code and key == (0, 0, 0):
                    ein_out = torch.einsum(
                        f"{z}uv,zu,zv->zu", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim),
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        f"{z}uv,uw,zv->zuw", w_right, e1_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}uv,zu,zvj->zuj", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out)
                    ein_right = torch.einsum(f"{z}uv,uw,zvi->zuwi", w_right,
                                             e1_right, x2_right)
                elif specialized_code and mul_ir_in2.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}uv,zui,zv->zui", w_out, x1_out,
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        f"{z}uv,ij,uw,zv->zuiwj", w_right, i1_right, e1_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and mul_ir_out.ir.l == 0:
                    ein_out = torch.einsum(f"{z}uv,zui,zvi->zu", w_out, x1_out,
                                           x2_out) / sqrt(
                                               mul_ir_in1.ir.dim)**exp
                    ein_right = torch.einsum(f"{z}uv,uw,zvi->zuiw", w_right,
                                             e1_right, x2_right) / sqrt(
                                                 mul_ir_in1.ir.dim)**exp
                else:
                    ein_out = torch.einsum(f"{z}uv,ijk,zuvij->zuk", w_out,
                                           w3j_out, xx)
                    ein_right = torch.einsum(f"{z}uv,ijk,uw,zvj->zuiwk",
                                             w_right, w3j_right, e1_right,
                                             x2_right)
            else:
                # not so useful operation because v is summed
                ein_out = torch.einsum("ijk,zuvij->zuk", w3j_out, xx)
                ein_right = torch.einsum("ijk,uw,zvj->zuiwk", w3j_right,
                                         e1_right, x2_right)
        if ins.connection_mode == 'uvv':
            assert mul_ir_in2.mul == mul_ir_out.mul
            if ins.has_weight:
                if specialized_code and key == (0, 0, 0):
                    ein_out = torch.einsum(
                        f"{z}uv,zu,zv->zv", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim),
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        f"{z}uv,vw,zv->zuw", w_right, e2_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}uv,zu,zvj->zvj", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out)
                    ein_right = torch.einsum(f"{z}uv,vw,zvi->zuwi", w_right,
                                             e2_right, x2_right)
                elif specialized_code and mul_ir_in2.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}uv,zui,zv->zvi", w_out, x1_out,
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        f"{z}uv,ij,vw,zv->zuiwj", w_right, i1_right, e2_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and mul_ir_out.ir.l == 0:
                    ein_out = torch.einsum(f"{z}uv,zui,zvi->zv", w_out, x1_out,
                                           x2_out) / sqrt(
                                               mul_ir_in1.ir.dim)**exp
                    ein_right = torch.einsum(f"{z}uv,vw,zvi->zuiw", w_right,
                                             e2_right, x2_right) / sqrt(
                                                 mul_ir_in1.ir.dim)**exp
                else:
                    ein_out = torch.einsum(f"{z}uv,ijk,zuvij->zvk", w_out,
                                           w3j_out, xx)
                    ein_right = torch.einsum(f"{z}uv,ijk,zvj->zuivk", w_right,
                                             w3j_right, x2_right)
            else:
                # not so useful operation because u is summed
                # only specialize out for this path
                if specialized_code and key == (0, 0, 0):
                    ein_out = torch.einsum(
                        "zu,zv->zv", x1_out.reshape(batch_out, mul_ir_in1.dim),
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    ein_out = torch.einsum(
                        "zu,zvj->zvj", x1_out.reshape(batch_out,
                                                      mul_ir_in1.dim), x2_out)
                elif specialized_code and mul_ir_in2.ir.l == 0:
                    ein_out = torch.einsum(
                        "zui,zv->zvi", x1_out,
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                elif specialized_code and mul_ir_out.ir.l == 0:
                    ein_out = torch.einsum("zui,zvi->zv", x1_out,
                                           x2_out) / sqrt(
                                               mul_ir_in1.ir.dim)**exp
                else:
                    ein_out = torch.einsum("ijk,zuvij->zvk", w3j_out, xx)
                s2ones = fx.Proxy(
                    graph_right.call_function(
                        torch.ones, (mul_ir_in1.mul, ),
                        dict(device=x2_right.device.node,
                             dtype=x2_right.dtype.node)))
                ein_right = torch.einsum("u,ijk,zvj->zuivk", s2ones, w3j_right,
                                         x2_right)
        if ins.connection_mode == 'uuw':
            assert mul_ir_in1.mul == mul_ir_in2.mul
            if ins.has_weight:
                if specialized_code and key == (0, 0, 0):
                    ein_out = torch.einsum(
                        f"{z}uw,zu,zu->zw", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim),
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}uw,zu,zuj->zwj", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out)
                elif specialized_code and mul_ir_in2.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}uw,zui,zu->zwi", w_out, x1_out,
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                elif specialized_code and mul_ir_out.ir.l == 0:
                    ein_out = torch.einsum(f"{z}uw,zui,zui->zw", w_out, x1_out,
                                           x2_out) / sqrt(
                                               mul_ir_in1.ir.dim)**exp
                else:
                    ein_out = torch.einsum(f"{z}uw,ijk,zuij->zwk", w_out,
                                           w3j_out, xx)
                # TODO: specialize right()
                ein_right = torch.einsum(f"{z}uw,ijk,zuj->zuiwk", w_right,
                                         w3j_right, x2_right)
            else:
                # equivalent to tp(x, y, 'uuu').sum('u')
                assert mul_ir_out.mul == 1
                ein_out = torch.einsum("ijk,zuij->zk", w3j_out, xx)
                ein_right = torch.einsum("ijk,zuj->zuik", w3j_right, x2_right)
        if ins.connection_mode == 'uuu':
            assert mul_ir_in1.mul == mul_ir_in2.mul == mul_ir_out.mul
            if ins.has_weight:
                if specialized_code and key == (0, 0, 0):
                    ein_out = torch.einsum(
                        f"{z}u,zu,zu->zu", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim),
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        f"{z}u,uw,zu->zuw", w_right, e2_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and key == (
                        1, 1, 1) and normalization == "component":
                    ein_out = torch.einsum(f"{z}u,zui->zui", w_out,
                                           torch.cross(x1_out, x2_out,
                                                       dim=2)) / sqrt(2)
                    # For cross product, use the general case right()
                    ein_right = torch.einsum(f"{z}u,ijk,uw,zuj->zuiwk",
                                             w_right, w3j_right, e1_right,
                                             x2_right)
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}u,zu,zuj->zuj", w_out,
                        x1_out.reshape(batch_out, mul_ir_in1.dim), x2_out)
                    ein_right = torch.einsum(f"{z}u,uw,zui->zuwi", w_right,
                                             e2_right, x2_right)
                elif specialized_code and mul_ir_in2.ir.l == 0:
                    ein_out = torch.einsum(
                        f"{z}u,zui,zu->zui", w_out, x1_out,
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        f"{z}u,ij,uw,zu->zuiwj", w_right, i1_right, e2_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and mul_ir_out.ir.l == 0:
                    ein_out = torch.einsum(f"{z}u,zui,zui->zu", w_out, x1_out,
                                           x2_out) / sqrt(
                                               mul_ir_in1.ir.dim)**exp
                    ein_right = torch.einsum(f"{z}u,uw,zui->zuiw", w_right,
                                             e2_right, x2_right) / sqrt(
                                                 mul_ir_in1.ir.dim)**exp
                else:
                    ein_out = torch.einsum(f"{z}u,ijk,zuij->zuk", w_out,
                                           w3j_out, xx)
                    ein_right = torch.einsum(f"{z}u,ijk,uw,zuj->zuiwk",
                                             w_right, w3j_right, e1_right,
                                             x2_right)
            else:
                if specialized_code and key == (0, 0, 0):
                    ein_out = torch.einsum(
                        "zu,zu->zu", x1_out.reshape(batch_out, mul_ir_in1.dim),
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        "uw,zu->zuw", e2_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and key == (
                        1, 1, 1) and normalization == "component":
                    ein_out = torch.cross(x1_out, x2_out,
                                          dim=2) * (1.0 / sqrt(2))
                    # For cross product, use the general case right()
                    ein_right = torch.einsum("ijk,uw,zuj->zuiwk", w3j_right,
                                             e1_right, x2_right)
                elif specialized_code and mul_ir_in1.ir.l == 0:
                    ein_out = torch.einsum(
                        "zu,zuj->zuj", x1_out.reshape(batch_out,
                                                      mul_ir_in1.dim), x2_out)
                    ein_right = torch.einsum("uw,zui->zuwi", e2_right,
                                             x2_right)
                elif specialized_code and mul_ir_in2.ir.l == 0:
                    ein_out = torch.einsum(
                        "zui,zu->zui", x1_out,
                        x2_out.reshape(batch_out, mul_ir_in2.dim))
                    ein_right = torch.einsum(
                        "ij,uw,zu->zuiwj", i1_right, e2_right,
                        x2_right.reshape(batch_right, mul_ir_in2.dim))
                elif specialized_code and mul_ir_out.ir.l == 0:
                    ein_out = torch.einsum("zui,zui->zu", x1_out,
                                           x2_out) / sqrt(
                                               mul_ir_in1.ir.dim)**exp
                    ein_right = torch.einsum("uw,zui->zuiw", e2_right,
                                             x2_right) / sqrt(
                                                 mul_ir_in1.ir.dim)**exp
                else:
                    ein_out = torch.einsum("ijk,zuij->zuk", w3j_out, xx)
                    ein_right = torch.einsum("ijk,uw,zuj->zuiwk", w3j_right,
                                             e1_right, x2_right)
        if ins.connection_mode == 'uvuv':
            assert mul_ir_in1.mul * mul_ir_in2.mul == mul_ir_out.mul
            if ins.has_weight:
                # TODO implement specialized code
                ein_out = torch.einsum(f"{z}uv,ijk,zuvij->zuvk", w_out,
                                       w3j_out, xx)
                ein_right = torch.einsum(f"{z}uv,ijk,uw,zvj->zuiwvk", w_right,
                                         w3j_right, e1_right, x2_right)
            else:
                # TODO implement specialized code
                ein_out = torch.einsum("ijk,zuvij->zuvk", w3j_out, xx)
                ein_right = torch.einsum("ijk,uw,zvj->zuiwvk", w3j_right,
                                         e1_right, x2_right)

        ein_out = alpha * ein_out
        ein_right = alpha * ein_right

        out_list_out += [ein_out.reshape(batch_out, mul_ir_out.dim)]
        out_list_right += [
            ein_right.reshape(batch_right, mul_ir_in1.dim, mul_ir_out.dim)
        ]

        # Close the profiler block
        graph_out.call_function(torch.ops.profiler._record_function_exit,
                                (handle_out, ))
        graph_right.call_function(torch.ops.profiler._record_function_exit,
                                  (handle_right, ))

        # Remove unused w3js:
        if len(w3j_out.node.users) == 0 and len(w3j_right.node.users) == 0:
            del w3j[-1]
            # The w3j nodes are reshapes, so we have to remove them from the graph
            # Although they are dead code, they try to reshape to dimensions that don't exist
            # (since the corresponding w3js are not in w3j)
            # so they screw up the shape propagation, even though they would be removed later as dead code by TorchScript.
            graph_out.erase_node(w3j_dict_out.pop(key).node)
            graph_right.erase_node(w3j_dict_right.pop(key).node)

    # = Return the result =
    out_out = [
        _sum_tensors([
            out for ins, out in zip(instructions, out_list_out)
            if ins.i_out == i_out
        ],
                     shape=(batch_out, mul_ir_out.dim),
                     like=x1s_out)
        for i_out, mul_ir_out in enumerate(irreps_out) if mul_ir_out.mul > 0
    ]
    if len(out_out) > 1:
        out_out = torch.cat(out_out, dim=1)
    else:
        # Avoid an unnecessary copy in a size one torch.cat
        out_out = out_out[0]

    out_right = [
        torch.cat([
            _sum_tensors([
                out for ins, out in zip(instructions, out_list_right)
                if (ins.i_in1, ins.i_out) == (i_in1, i_out)
            ],
                         shape=(batch_right, mul_ir_in1.dim, mul_ir_out.dim),
                         like=x2s_right)
            for i_out, mul_ir_out in enumerate(irreps_out)
            if mul_ir_out.mul > 0
        ],
                  dim=2) for i_in1, mul_ir_in1 in enumerate(irreps_in1)
        if mul_ir_in1.mul > 0
    ]
    if len(out_right) > 1:
        out_right = torch.cat(out_right, dim=1)
    else:
        out_right = out_right[0]

    out_out = out_out.reshape(outsize_out)
    out_right = out_right.reshape(outsize_right)

    graph_out.output(out_out.node, torch.Tensor)
    graph_right.output(out_right.node, torch.Tensor)

    # check graphs
    graph_out.lint()
    graph_right.lint()

    # Make GraphModules
    wigner_mats = {}
    for l_1, l_2, l_out in w3j:
        wig = o3.wigner_3j(l_1, l_2, l_out)

        if normalization == 'component':
            wig *= (2 * l_out + 1)**0.5
        if normalization == 'norm':
            wig *= (2 * l_1 + 1)**0.5 * (2 * l_2 + 1)**0.5

        wigner_mats[f"_w3j_{l_1}_{l_2}_{l_out}"] = wig

    # By putting the constants in a Module rather than a dict,
    # we force FX to copy them as buffers instead of as attributes.
    #
    # FX seems to have resolved this issue for dicts in 1.9, but we support all the way back to 1.8.0.
    constants_root = torch.nn.Module()
    for wkey, wmat in wigner_mats.items():
        constants_root.register_buffer(wkey, wmat)
    graphmod_out = fx.GraphModule(constants_root,
                                  graph_out,
                                  class_name="tp_forward")
    graphmod_right = fx.GraphModule(constants_root,
                                    graph_right,
                                    class_name="tp_right")

    # == Optimize ==
    # TODO: when eliminate_dead_code() is in PyTorch stable, use that
    if optimize_einsums:
        # Note that for our einsums, we can optimize _once_ for _any_ batch dimension
        # and still get the right path for _all_ batch dimensions.
        # This is because our einsums are essentially of the form:
        #    zuvw,ijk,zuvij->zwk    OR     uvw,ijk,zuvij->zwk
        # In the first case, all but one operands have the batch dimension
        #    => The first contraction gains the batch dimension
        #    => All following contractions have batch dimension
        #    => All possible contraction paths have cost that scales linearly in batch size
        #    => The optimal path is the same for all batch sizes
        # For the second case, this logic follows as long as the first contraction is not between the first two operands. Since those two operands do not share any indexes, contracting them first is a rare pathological case. See
        # https://github.com/dgasmith/opt_einsum/issues/158
        # for more details.
        #
        # TODO: consider the impact maximum intermediate result size on this logic
        #         \- this is the `memory_limit` option in opt_einsum
        # TODO: allow user to choose opt_einsum parameters?
        #
        # We use float32 and zeros to save memory and time, since opt_einsum_fx looks only at traced shapes, not values or dtypes.
        batchdim = 4
        example_inputs = (
            torch.zeros((batchdim, irreps_in1.dim)),
            torch.zeros((batchdim, irreps_in2.dim)),
            torch.zeros(
                1 if shared_weights else batchdim,
                flat_weight_index,
            ),
        )
        graphmod_out = jitable(
            optimize_einsums_full(graphmod_out, example_inputs))
        graphmod_right = jitable(
            optimize_einsums_full(graphmod_right, example_inputs[1:]))

    return graphmod_out, graphmod_right
Example #15
def _codegen_linear(
    irreps_in: o3.Irreps,
    irreps_out: o3.Irreps,
    instructions: List[Instruction],
    biases: List[bool],
    f_in: Optional[int] = None,
    f_out: Optional[int] = None,
    shared_weights: bool = False,
    optimize_einsums: bool = True,
) -> Tuple[fx.GraphModule, int, int]:
    graph_out = fx.Graph()

    # = Function definitions =
    x = fx.Proxy(graph_out.placeholder('x', torch.Tensor))
    ws = fx.Proxy(graph_out.placeholder('w', torch.Tensor))
    bs = fx.Proxy(graph_out.placeholder('b', torch.Tensor))

    if f_in is None:
        size = x.shape[:-1]
        outsize = size + (irreps_out.dim, )
    else:
        size = x.shape[:-2]
        outsize = size + (
            f_out,
            irreps_out.dim,
        )

    bias_numel = sum(mul_ir.dim for bias, mul_ir in zip(biases, irreps_out)
                     if bias)
    if bias_numel > 0:
        if f_out is None:
            bs = bs.reshape(-1, bias_numel)
        else:
            bs = bs.reshape(-1, f_out, bias_numel)

    # = Short-circuit for nothing to do =
    # We produce no code for empty instructions
    instructions = [ins for ins in instructions if 0 not in ins.path_shape]

    if len(instructions) == 0 and bias_numel == 0:
        out = x.new_zeros(outsize)

        graph_out.output(out.node, torch.Tensor)
        # Short circuit
        # 0 is weight_numel
        return fx.GraphModule({}, graph_out, "linear_forward"), 0, 0

    if f_in is None:
        x = x.reshape(-1, irreps_in.dim)
    else:
        x = x.reshape(-1, f_in, irreps_in.dim)
    batch_out = x.shape[0]

    out_bias_list = []
    bias_index = 0
    for bias, mul_ir_out in zip(biases, irreps_out):
        if bias:
            if sum(biases) == 1:
                b = bs
            else:
                b = bs.narrow(-1, bias_index, mul_ir_out.dim)
                bias_index += mul_ir_out.dim
            out_bias_list += [[
                b.expand(batch_out, -1) if f_out is None else b.expand(
                    batch_out, f_out, -1)
            ]]
        else:
            out_bias_list += [[]]

    weight_numel = sum(prod(ins.path_shape) for ins in instructions)
    if weight_numel > 0:
        ws = ws.reshape(-1, weight_numel) if f_in is None else ws.reshape(
            -1, f_in, f_out, weight_numel)

    # = extract individual input irreps =
    if len(irreps_in) == 1:
        x_list = [
            x.reshape(batch_out, *(() if f_in is None else (f_in, )),
                      irreps_in[0].mul, irreps_in[0].ir.dim)
        ]
    else:
        x_list = [
            x.narrow(-1, i.start,
                     mul_ir.dim).reshape(batch_out,
                                         *(() if f_in is None else (f_in, )),
                                         mul_ir.mul, mul_ir.ir.dim)
            for i, mul_ir in zip(irreps_in.slices(), irreps_in)
        ]

    z = '' if shared_weights else 'z'

    flat_weight_index = 0

    out_list = []

    for ins in instructions:
        mul_ir_in = irreps_in[ins.i_in]
        mul_ir_out = irreps_out[ins.i_out]

        # Short-circuit for empty irreps
        if mul_ir_in.dim == 0 or mul_ir_out.dim == 0:
            continue

        # Extract the weight from the flattened weight tensor
        path_nweight = prod(ins.path_shape)
        if len(instructions) == 1:
            # Avoid unnecessary view when there is only one weight
            w = ws
        else:
            w = ws.narrow(-1, flat_weight_index, path_nweight)
        w = w.reshape((() if shared_weights else (-1, )) +
                      (() if f_in is None else (f_in, f_out)) + ins.path_shape)
        flat_weight_index += path_nweight

        if f_in is None:
            ein_out = torch.einsum(f"{z}uw,zui->zwi", w, x_list[ins.i_in])
        else:
            ein_out = torch.einsum(f"{z}xyuw,zxui->zywi", w, x_list[ins.i_in])
        alpha = 1.0 / math.sqrt((f_in or 1) * mul_ir_in.mul *
                                sum(1 if other_ins.i_out == ins.i_out else 0
                                    for other_ins in instructions))
        ein_out = alpha * ein_out

        out_list += [
            ein_out.reshape(batch_out, *(() if f_out is None else (f_out, )),
                            mul_ir_out.dim)
        ]

    # = Return the result =
    out = [
        _sum_tensors([
            out
            for ins, out in zip(instructions, out_list) if ins.i_out == i_out
        ] + out_bias_list[i_out],
                     shape=(batch_out, *(() if f_out is None else
                                         (f_out, )), mul_ir_out.dim),
                     like=x) for i_out, mul_ir_out in enumerate(irreps_out)
        if mul_ir_out.mul > 0
    ]
    if len(out) > 1:
        out = torch.cat(out, dim=-1)
    else:
        out = out[0]

    out = out.reshape(outsize)

    graph_out.output(out.node, torch.Tensor)

    # check graphs
    graph_out.lint()

    graphmod_out = fx.GraphModule({}, graph_out, "linear_forward")

    # TODO: when eliminate_dead_code() is in PyTorch stable, use that
    if optimize_einsums:
        # See _tensor_product/_codegen.py for notes
        batchdim = 4
        example_inputs = (
            torch.zeros((batchdim, *(() if f_in is None else
                                     (f_in, )), irreps_in.dim)),
            torch.zeros(
                1 if shared_weights else batchdim,
                f_in or 1,
                f_out or 1,
                weight_numel,
            ),
            torch.zeros(
                1 if shared_weights else batchdim,
                f_out or 1,
                bias_numel,
            ),
        )
        graphmod_out = jitable(
            optimize_einsums_full(graphmod_out, example_inputs))

    return graphmod_out, weight_numel, bias_numel
Example #16
def optimize_for_inference(
        model: torch.nn.Module,
        pass_config: Optional[Dict[str, Any]] = None,
        tracer: Type[fx.Tracer] = fx.Tracer) -> torch.nn.Module:
    """
    Performs a set of optimization passes to optimize a model for the
    purposes of inference. Specifically, the passes that are run are:
    1. Conv/BN fusion
    2. Dropout removal
    3. MKL layout optimizations

    The third optimization relies on a heuristic function (supplied via
    `pass_config["mkldnn_layout_optimize"]["heuristic"]`) that's used to
    determine whether a subgraph should be explicitly run in MKL layout.

    Note: As FX does not currently handle aliasing, this pass currently
    assumes nothing aliases. If that isn't true, use at your own risk.
    """
    default_pass_config = {
        "conv_bn_fuse": True,
        "remove_dropout": True,
        "mkldnn_layout_optimize": {
            'heuristic': use_mkl_length
        },
    }
    if pass_config is None:
        pass_config = {}
    default_pass_config.update(pass_config)

    if default_pass_config["conv_bn_fuse"]:
        model = fuse(model)
    if default_pass_config["remove_dropout"]:
        model = remove_dropout(model)
    if default_pass_config["mkldnn_layout_optimize"] is False:
        return model
    if not isinstance(default_pass_config["mkldnn_layout_optimize"], dict):
        raise RuntimeError("mkldnn_layout_optimize config is not a dict")
    if "heuristic" not in default_pass_config["mkldnn_layout_optimize"]:
        raise RuntimeError(
            "Heuristic not found in mkldnn_layout_optimize config")
    use_mkl_heuristic = default_pass_config["mkldnn_layout_optimize"][
        "heuristic"]

    cur_tracer = tracer()
    fx_graph = cur_tracer.trace(copy.deepcopy(model))
    fx_model = fx.GraphModule(cur_tracer.root, fx_graph)
    modules: Dict[str, nn.Module] = dict(model.named_modules())

    class MklSupport(Enum):
        NO = 1
        YES = 2
        UNKNOWN = 3

    # Inserts to_mkldnn and to_dense around every node we want to be a MKLDNN node.
    # If the op is in `mkldnn_supported` then we always treat it as a MKLDNN node.
    # However, if it's in `mkldnn_supported_unknown`, then we only treat it as
    # a MKLDNN node if its inputs are MKLDNN nodes.
    for node in list(fx_graph.nodes):
        supports_mkldnn = MklSupport.NO
        if node.op == 'call_module':
            cur_module = modules[node.target]
            if type(cur_module) in mkldnn_supported:
                supports_mkldnn = MklSupport.YES
                sample_parameter = next(cur_module.parameters(), None)
                if sample_parameter is not None:
                    assert (sample_parameter.dtype == torch.float
                            ), "this pass is only for torch.float modules"
                    assert (sample_parameter.device == torch.device('cpu')
                            ), "this pass is only for CPU modules"
        elif node.op == 'call_function':
            if node.target in mkldnn_supported:
                supports_mkldnn = MklSupport.YES
            elif node.target in mkldnn_supported_unknown:
                supports_mkldnn = MklSupport.UNKNOWN

        if supports_mkldnn != MklSupport.NO:
            if supports_mkldnn == MklSupport.UNKNOWN:
                if not any([arg.target == 'to_dense' for arg in node.args]):
                    continue
            with fx_graph.inserting_before(node):
                mkldnn_args = fx.map_arg(
                    node.args,
                    lambda n: fx_graph.call_method('to_mkldnn', (n, )))

            node.args = cast(Tuple[fx.node.Argument], mkldnn_args)

            with fx_graph.inserting_after(node):
                dense_x = fx_graph.create_node('call_method', 'to_dense',
                                               (node, ))
                node.replace_all_uses_with(dense_x)
                dense_x.args = (node, )

    # Does pre-conversion of all modules into MKLDNN (when possible)
    old_modules = modules_to_mkldnn(list(fx_graph.nodes), modules)
    fx_graph.old_modules = old_modules  # type: ignore[attr-defined]

    # optimizes all a -> to_dense -> to_mkldnn -> b patterns into a -> b
    for node in fx_graph.nodes:
        if node.op == 'call_method' and node.target == 'to_dense':
            prv_node = node.args[0]
            users = list(node.users)
            for user in users:
                if user.op == 'call_method' and user.target == 'to_mkldnn':
                    user.replace_all_uses_with(prv_node)
                    fx_graph.erase_node(user)
            if len(node.users) == 0:
                fx_graph.erase_node(node)

    num_nodes = len(fx_graph.nodes)
    uf = UnionFind(num_nodes)

    def get_color(n):
        if hasattr(n, 'color'):  # Current node is part of a MKL subgraph
            return uf.find(n.color)
        if hasattr(n, 'start_color'):  # Current node is input to MKL subgraph
            return uf.find(n.start_color)
        return None

    # This code is to find each MKLDNN subgraph. Each MKLDNN subgraph consists
    # of input nodes (which are only `to_mkldnn` calls), output nodes
    # (`to_dense` calls), and intermediate nodes, which are run entirely on
    # MKLDNN layout tensors.
    #
    # Specifically, this code does a flood fill on a directed acyclic graph
    # (DAG), starting from each possible "start node" (i.e: `to_mkldnn` nodes).
    # If every node only had one input, this would be sufficient. However, in
    # the case that a node has multiple inputs coming from different start
    # nodes (i.e. colors), we need to join these 2 colors into 1. That's done
    # using a Disjoint Set Union.
    for cur_idx, node in enumerate(fx_graph.nodes):
        if node.op == 'call_method' and node.target == 'to_mkldnn':
            node.start_color = cur_idx
            uf.make_set(cur_idx)
        elif node.op == 'call_method' and node.target == 'to_dense':
            assert (get_color(node.args[0]) is not None)
            node.end_color = get_color(node.args[0])
        else:
            cur_colors = [
                get_color(i) for i in node.all_input_nodes
                if isinstance(i, fx.Node) if get_color(i) is not None
            ]

            if len(cur_colors) == 0:
                continue
            assert (not any(i is None for i in cur_colors))
            cur_colors = sorted(cur_colors)
            node.color = cur_colors[0]
            for other_color in cur_colors[1:]:
                uf.join(cur_colors[0], other_color)

    mkldnn_graphs: Dict[int, MklSubgraph] = defaultdict(
        lambda: MklSubgraph(fx_graph))
    for node in fx_graph.nodes:
        if hasattr(node, 'color'):
            mkldnn_graphs[uf.find(node.color)].nodes.append(node)
        if hasattr(node, 'start_color'):
            mkldnn_graphs[uf.find(node.start_color)].start_nodes.append(node)
        if hasattr(node, 'end_color'):
            mkldnn_graphs[uf.find(node.end_color)].end_nodes.append(node)

    # Now that we have all the subgraphs, we need to decide which MKLDNN
    # subgraphs we actually want to keep in MKLDNN.
    for graph in mkldnn_graphs.values():
        if not use_mkl_heuristic(graph):
            for node in graph.start_nodes + graph.end_nodes:
                prv = node.args[0]
                node.replace_all_uses_with(prv)
                fx_graph.erase_node(node)
            reset_modules(graph.nodes, modules, old_modules)

    mkldnn_conversions = 0
    for node in fx_graph.nodes:
        if node.target == 'to_mkldnn' or node.target == 'to_dense':
            mkldnn_conversions += 1

    logging.info(f"mkldnn conversions: {mkldnn_conversions}")
    fx_graph.lint()
    result = fx.GraphModule(model, fx_graph)
    return result
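# Usage sketch (an addition, not part of the original example), assuming the
# pass above is PyTorch's `optimize_for_inference` from
# `torch.fx.experimental.optimization`; the model, input, and reliance on the
# default config below are illustrative assumptions.
import torch
import torchvision.models as models
from torch.fx.experimental.optimization import optimize_for_inference

cnn = models.resnet18().eval()  # the pass only supports float32 CPU modules
inp = torch.randn(1, 3, 224, 224)

# Default config: conv/bn fusion, dropout removal, and MKLDNN layout
# conversion driven by the `use_mkl_length` heuristic.
optimized = optimize_for_inference(cnn)

with torch.no_grad():
    out = optimized(inp)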
Example #17
0
def create_feature_extractor(
    model: nn.Module,
    return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
    train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
    eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
    tracer_kwargs: Optional[Dict[str, Any]] = None,
    suppress_diff_warning: bool = False,
) -> fx.GraphModule:
    """
    Creates a new graph module that returns intermediate nodes from a given
    model as dictionary with user specified keys as strings, and the requested
    outputs as values. This is achieved by re-writing the computation graph of
    the model via FX to return the desired nodes as outputs. All unused nodes
    are removed, together with their corresponding parameters.

    Desired output nodes must be specified as a ``.`` separated
    path walking the module hierarchy from top level module down to leaf
    operation or leaf module. For more details on the node naming conventions
    used here, please see the :ref:`relevant subheading <about-node-names>`
    in the `documentation <https://pytorch.org/vision/stable/feature_extraction.html>`_.

    Not all models will be FX traceable, although with some massaging they can
    be made to cooperate. Here's a (not exhaustive) list of tips:

        - If you don't need to trace through a particular, problematic
          sub-module, turn it into a "leaf module" by passing a list of
          ``leaf_modules`` as one of the ``tracer_kwargs`` (see example below).
          It will not be traced through, but rather, the resulting graph will
          hold a reference to that module's forward method.
        - Likewise, you may turn functions into leaf functions by passing a
          list of ``autowrap_functions`` as one of the ``tracer_kwargs`` (see
          example below).
        - Some inbuilt Python functions can be problematic. For instance,
          ``int`` will raise an error during tracing. You may wrap them in your
          own function and then pass that in ``autowrap_functions`` as one of
          the ``tracer_kwargs``.

    For further information on FX see the
    `torch.fx documentation <https://pytorch.org/docs/stable/fx.html>`_.

    Args:
        model (nn.Module): model on which we will extract the features
        return_nodes (list or dict, optional): either a ``List`` or a ``Dict``
            containing the names (or partial names - see note above)
            of the nodes for which the activations will be returned. If it is
            a ``Dict``, the keys are the node names, and the values
            are the user-specified keys for the graph module's returned
            dictionary. If it is a ``List``, it is treated as a ``Dict`` mapping
            node specification strings directly to output names. In the case
            that ``train_return_nodes`` and ``eval_return_nodes`` are specified,
            this should not be specified.
        train_return_nodes (list or dict, optional): similar to
            ``return_nodes``. This can be used if the return nodes
            for train mode are different from those of eval mode.
            If this is specified, ``eval_return_nodes`` must also be specified,
            and ``return_nodes`` should not be specified.
        eval_return_nodes (list or dict, optional): similar to
            ``return_nodes``. This can be used if the return nodes
            for train mode are different from those of eval mode.
            If this is specified, ``train_return_nodes`` must also be specified,
            and ``return_nodes`` should not be specified.
        tracer_kwargs (dict, optional): a dictionary of keyword arguments for
            ``NodePathTracer`` (which passes them onto its parent class
            `torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
            By default it will be set to wrap and make leaf nodes all torchvision ops:
            {"autowrap_modules": (math, torchvision.ops,), "leaf_modules": _get_leaf_modules_for_ops(),}
            WARNING: If the user provides ``tracer_kwargs``, the above default
            arguments will be appended to the user-provided dictionary.
        suppress_diff_warning (bool, optional): whether to suppress a warning
            when there are discrepancies between the train and eval version of
            the graph. Defaults to False.

    Examples::

        >>> # Feature extraction with resnet
        >>> model = torchvision.models.resnet18()
        >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
        >>> model = create_feature_extractor(
        >>>     model, {'layer1': 'feat1', 'layer3': 'feat2'})
        >>> out = model(torch.rand(1, 3, 224, 224))
        >>> print([(k, v.shape) for k, v in out.items()])
        >>>     [('feat1', torch.Size([1, 64, 56, 56])),
        >>>      ('feat2', torch.Size([1, 256, 14, 14]))]

        >>> # Specifying leaf modules and leaf functions
        >>> def leaf_function(x):
        >>>     # This would raise a TypeError if traced through
        >>>     return int(x)
        >>>
        >>> class LeafModule(torch.nn.Module):
        >>>     def forward(self, x):
        >>>         # This would raise a TypeError if traced through
        >>>         int(x.shape[0])
        >>>         return torch.nn.functional.relu(x + 4)
        >>>
        >>> class MyModule(torch.nn.Module):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.conv = torch.nn.Conv2d(3, 1, 3)
        >>>         self.leaf_module = LeafModule()
        >>>
        >>>     def forward(self, x):
        >>>         leaf_function(x.shape[0])
        >>>         x = self.conv(x)
        >>>         return self.leaf_module(x)
        >>>
        >>> model = create_feature_extractor(
        >>>     MyModule(), return_nodes=['leaf_module'],
        >>>     tracer_kwargs={'leaf_modules': [LeafModule],
        >>>                    'autowrap_functions': [leaf_function]})

    """
    tracer_kwargs = _set_default_tracer_kwargs(tracer_kwargs)
    is_training = model.training

    if all(arg is None
           for arg in [return_nodes, train_return_nodes, eval_return_nodes]):
        raise ValueError(
            "Either `return_nodes`, or both `train_return_nodes` and `eval_return_nodes`, should be specified"
        )

    if (train_return_nodes is None) ^ (eval_return_nodes is None):
        raise ValueError(
            "If any of `train_return_nodes` and `eval_return_nodes` are specified, then both should be specified"
        )

    if not ((return_nodes is None) ^ (train_return_nodes is None)):
        raise ValueError(
            "If `train_return_nodes` and `eval_return_nodes` are specified, then `return_nodes` should not be specified, and vice versa"
        )

    # Put *_return_nodes into Dict[str, str] format
    def to_strdict(n) -> Dict[str, str]:
        if isinstance(n, list):
            return {str(i): str(i) for i in n}
        return {str(k): str(v) for k, v in n.items()}

    if train_return_nodes is None:
        return_nodes = to_strdict(return_nodes)
        train_return_nodes = deepcopy(return_nodes)
        eval_return_nodes = deepcopy(return_nodes)
    else:
        train_return_nodes = to_strdict(train_return_nodes)
        eval_return_nodes = to_strdict(eval_return_nodes)

    # Repeat the tracing and graph rewriting for train and eval mode
    tracers = {}
    graphs = {}
    mode_return_nodes: Dict[str, Dict[str, str]] = {
        "train": train_return_nodes,
        "eval": eval_return_nodes
    }
    for mode in ["train", "eval"]:
        if mode == "train":
            model.train()
        elif mode == "eval":
            model.eval()

        # Instantiate our NodePathTracer and use that to trace the model
        tracer = NodePathTracer(**tracer_kwargs)
        graph = tracer.trace(model)

        name = model.__class__.__name__ if isinstance(
            model, nn.Module) else model.__name__
        graph_module = fx.GraphModule(tracer.root, graph, name)

        available_nodes = list(tracer.node_to_qualname.values())
        # FIXME We don't know if we should expect this to happen
        if len(set(available_nodes)) != len(available_nodes):
            raise ValueError(
                "There are duplicate nodes! Please raise an issue https://github.com/pytorch/vision/issues"
            )
        # Check that all outputs in return_nodes are present in the model
        for query in mode_return_nodes[mode].keys():
            # To check if a query is available we need to check that at least
            # one of the available names starts with it up to a .
            if not any([
                    re.match(rf"^{query}(\.|$)", n) is not None
                    for n in available_nodes
            ]):
                raise ValueError(
                    f"node: '{query}' is not present in model. Hint: use "
                    "`get_graph_node_names` to make sure the "
                    "`return_nodes` you specified are present. It may even "
                    "be that you need to specify `train_return_nodes` and "
                    "`eval_return_nodes` separately.")

        # Remove the existing output nodes for the current mode
        orig_output_nodes = []
        for n in reversed(graph_module.graph.nodes):
            if n.op == "output":
                orig_output_nodes.append(n)
        if not orig_output_nodes:
            raise ValueError(
                "No output nodes found in graph_module.graph.nodes")

        for n in orig_output_nodes:
            graph_module.graph.erase_node(n)

        # Find nodes corresponding to return_nodes and make them into output_nodes
        nodes = [n for n in graph_module.graph.nodes]
        output_nodes = OrderedDict()
        for n in reversed(nodes):
            module_qualname = tracer.node_to_qualname.get(n)
            if module_qualname is None:
                # NOTE - Known cases where this happens:
                # - Node representing creation of a tensor constant - probably
                #   not interesting as a return node
                # - When packing outputs into a named tuple like in InceptionV3
                continue
            for query in mode_return_nodes[mode]:
                depth = query.count(".")
                if ".".join(module_qualname.split(".")[:depth + 1]) == query:
                    output_nodes[mode_return_nodes[mode][query]] = n
                    mode_return_nodes[mode].pop(query)
                    break
        output_nodes = OrderedDict(reversed(list(output_nodes.items())))

        # And add them in the end of the graph
        with graph_module.graph.inserting_after(nodes[-1]):
            graph_module.graph.output(output_nodes)

        # Remove unused modules / parameters
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()

        # Keep track of the tracer and graph so we can choose the main one
        tracers[mode] = tracer
        graphs[mode] = graph

    # Warn user if there are any discrepancies between the graphs of the
    # train and eval modes
    if not suppress_diff_warning:
        _warn_graph_differences(tracers["train"], tracers["eval"])

    # Build the final graph module
    graph_module = DualGraphModule(model,
                                   graphs["train"],
                                   graphs["eval"],
                                   class_name=name)

    # Restore original training mode
    model.train(is_training)
    graph_module.train(is_training)

    return graph_module
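# Usage sketch (an addition, not part of the original example): the docstring
# above describes, but does not demonstrate, specifying separate train/eval
# return nodes. The ResNet-50 node names below assume a standard torchvision
# model; adjust them for your own architecture.
import torch
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor

backbone = torchvision.models.resnet50()
extractor = create_feature_extractor(
    backbone,
    train_return_nodes={'layer1': 'feat1', 'layer4': 'feat4'},
    eval_return_nodes={'layer4': 'feat4'},
)

extractor.eval()
feats = extractor(torch.rand(1, 3, 224, 224))
print({k: v.shape for k, v in feats.items()})  # {'feat4': torch.Size([1, 2048, 7, 7])}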
Example #18
0
def minifier(fail_f: fx.GraphModule, inps, module_fails):
    """
    Minimizes an FX graph with given inputs, such that the resulting FX graph still returns True for module_fails.

    Uses 2 main strategies:
    1. Truncates suffix: Removes some suffix from the graph and sets a new output.
    2. Delta Debugging: Tries replacing half of the graph with inputs. If that fails,
        tries replacing a quarter of the graph, and so on.

    >>> failing_function = fx.symbolic_trace(f)
    >>> minifier(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps))

    note: module_fails returns True if it fails.
    """
    failing_graph = fail_f.graph
    cur_size = len(failing_graph.nodes)

    def graph_fails(graph, inps):
        mod = fx.GraphModule(fail_f, graph)
        mod.graph.lint()
        return module_fails(mod, inps)

    ConcreteProp(fail_f).propagate(*inps)
    if not graph_fails(failing_graph, inps):
        raise RuntimeError("Input graph did not fail the tester")
    print(f"Started off with {cur_size} nodes")

    def remove_suffix(cur_graph, cur_inps):
        print("Strategy: Remove suffix")
        assert graph_fails(cur_graph, cur_inps)
        gap = 2**math.floor(math.log2(len(cur_graph.nodes)))
        tested = set()
        while gap >= 1:
            new_graph = fx.Graph()
            env = {}
            for idx, node in enumerate(cur_graph.nodes):
                new_node = new_graph.node_copy(node, lambda x: env[x])
                if node.op not in ['placeholder', 'output']:
                    if idx % gap == 0 and idx not in tested:
                        output_node = new_graph.output((new_node, ))
                        if graph_fails(new_graph, cur_inps) and len(
                                new_graph.nodes) < len(cur_graph.nodes):
                            print()
                            print(
                                f"SUCCESS: Removed [{idx}:{len(cur_graph.nodes)})"
                            )
                            return (new_graph, cur_inps), True
                        else:
                            tested.add(idx)
                            new_graph.erase_node(output_node)
                env[node] = new_node
            gap //= 2
        print("FAIL: Could not remove suffix")
        return (cur_graph, cur_inps), False

    def remove_unused_inputs(cur_graph, cur_inps):
        assert graph_fails(cur_graph, cur_inps)
        ph_nodes = _get_placeholders(cur_graph)
        if len(ph_nodes) != len(cur_inps):
            print(cur_graph)
            print(len(cur_inps))
        assert len(ph_nodes) == len(cur_inps)

        new_inps = []
        for idx in range(len(ph_nodes)):
            if len(ph_nodes[idx].users) == 0:
                cur_graph.erase_node(ph_nodes[idx])
            else:
                new_inps.append(cur_inps[idx])

        if len(new_inps) < len(cur_inps) and graph_fails(cur_graph, new_inps):
            print("Strategy: Remove unused inputs")
            print(
                f"SUCCESS: Went from {len(cur_inps)} inputs to {len(new_inps)} inputs"
            )
            return (cur_graph, new_inps), True
        else:
            return (cur_graph, new_inps), False

    def eliminate_dead_code(cur_graph, cur_inps):
        orig_size = len(cur_graph.nodes)
        if cur_graph.eliminate_dead_code() and graph_fails(
                cur_graph, cur_inps):
            print("Strategy: Eliminate dead code")
            print(
                f"SUCCESS: Went from {orig_size} nodes to {len(cur_graph.nodes)} nodes"
            )
            return (cur_graph, cur_inps), True
        else:
            return (cur_graph, cur_inps), False

    def consolidate_placeholders(cur_graph):
        new_graph = fx.Graph()
        env = {}
        for node in cur_graph.nodes:
            if node.op == 'placeholder':
                new_node = new_graph.node_copy(node, lambda x: env[x])
                env[node] = new_node

        for node in cur_graph.nodes:
            if node.op != 'placeholder':
                new_node = new_graph.node_copy(node, lambda x: env[x])
                env[node] = new_node
        return new_graph

    def delta_debugging(cur_graph: fx.Graph, cur_inps):
        print("Strategy: Delta Debugging")
        assert graph_fails(cur_graph, cur_inps)
        starting_placeholders = len(_get_placeholders(cur_graph))
        num_nodes = len(cur_graph.nodes)
        gap = int(2**math.floor(math.log2(num_nodes)))
        while gap >= 1:
            for start_range in range(0, num_nodes, gap):
                is_removing = False
                new_graph = copy.deepcopy(cur_graph)
                new_inps = cur_inps[:]
                end_range = min(num_nodes, start_range + gap)
                for idx in range(start_range, end_range):
                    new_node = list(new_graph.nodes)[idx]
                    if new_node.op not in ['placeholder', 'output']:
                        is_removing = True
                        _convert_node_to_placeholder(new_node, new_inps)
                if not is_removing:
                    continue
                new_graph = consolidate_placeholders(new_graph)
                if graph_fails(new_graph, new_inps):
                    print(
                        f"SUCCESS: Removed ({start_range}:{end_range}] - Went from {starting_placeholders} "
                        f"placeholders to {len(_get_placeholders(new_graph))}")
                    return (new_graph, new_inps), True
            gap //= 2

        print("FAIL: Could not remove prefix")
        return (cur_graph, cur_inps), False

    print("###################")
    print(f"Current size: {len(failing_graph.nodes)}")
    print("###################")
    while True:
        any_succeeded = False
        strategies = [
            remove_suffix, eliminate_dead_code, remove_unused_inputs,
            delta_debugging, eliminate_dead_code, remove_unused_inputs
        ]
        for strategy in strategies:
            out = strategy(copy.deepcopy(failing_graph), inps[:])
            (cur_graph, cur_inps), succeeded = out
            if succeeded:
                print()
                print("###################")
                print(f"Current size: {len(cur_graph.nodes)}")
                print("###################")
                failing_graph = cur_graph
                inps = cur_inps
                any_succeeded = True

        if not any_succeeded:
            break
    failing_fx = fx.GraphModule(fail_f, failing_graph)
    print(f"""
inps = {[(i.shape, i.dtype) for i in inps]}
inps = [torch.zeros(())] + [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]
{failing_fx.code}
f = torch.jit.script(forward)
with torch.jit.fuser("fuser2"):
  for _ in range(5):
    f(*inps)""")
    return failing_fx, inps
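# Usage sketch (an addition, not part of the original example): driving the
# minifier above with a toy failing function. Both `f` and the `module_fails`
# predicate are assumptions; any callable returning True on a "bad" graph
# works. Here a graph is considered failing while it still calls relu.
import torch
import torch.fx as fx

def f(x):
    return x.cos().relu().sum()

def module_fails(fx_g, inps):
    return any(node.op == 'call_method' and node.target == 'relu'
               for node in fx_g.graph.nodes)

failing_function = fx.symbolic_trace(f)
minified_fx, minified_inps = minifier(failing_function, [torch.randn(5)], module_fails)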
Example #19
0
def deepcopy_fx_graph(fx_graph):
    # `fail_f` is a free variable taken from the enclosing minifier scope.
    return fx.GraphModule(fail_f, copy.deepcopy(fx_graph)).graph
Example #20
0
def minifier(fail_f: fx.GraphModule,
             inps,
             module_fails,
             dump_state: Callable = dump_state):
    """
    Minimizes an FX graph with given inputs, such that the resulting FX graph still returns True for module_fails.

    Uses 2 main strategies:
    1. Truncates suffix: Removes some suffix from the graph and sets a new output.
    2. Delta Debugging: Tries replacing half of the graph with inputs. If that fails,
        tries replacing a quarter of the graph, and so on.

    >>> failing_function = fx.symbolic_trace(f)
    >>> minifier(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps))

    note: module_fails returns True if it fails.
    """
    failing_graph = fail_f.graph
    cur_size = len(failing_graph.nodes)

    num_queries = 0

    def deepcopy_fx_graph(fx_graph):
        return fx.GraphModule(fail_f, copy.deepcopy(fx_graph)).graph

    def graph_fails(graph, inps):
        nonlocal num_queries
        graph = copy.deepcopy(graph)
        num_queries += 1
        mod = fx.GraphModule(fail_f, graph)
        mod.graph.lint()
        return module_fails(mod, inps)

    ConcreteProp(fail_f).propagate(*inps)
    if not graph_fails(failing_graph, inps):
        raise RuntimeError("Input graph did not fail the tester")
    print(f"Started off with {cur_size} nodes")

    def _register_strategy(strategy: Callable, name: str):
        @wraps(strategy)
        def new_func(old_state: ReproState, granularity=1):
            print()
            print(
                f"Strategy: {name} (G: {granularity}) ({len(old_state.graph.nodes)} nodes, {len(old_state.inps)} inputs)"
            )
            new_state = strategy(deepcopy_fx_graph(old_state.graph),
                                 list(old_state.inps), granularity)
            if new_state is not None:
                new_nodes = len(new_state.graph.nodes)
                old_nodes = len(old_state.graph.nodes)
                new_inps = len(new_state.inps)
                old_inps = len(old_state.inps)
                new_outs = len(get_outputs(new_state.graph))
                old_outs = len(get_outputs(old_state.graph))
                progress_made = False
                if new_nodes < old_nodes:
                    progress_made = True
                    print(
                        f"SUCCESS: Went from {old_nodes} to {new_nodes} nodes")
                if new_inps > old_inps:
                    progress_made = True
                    print(
                        f"SUCCESS: Went from {old_inps} to {new_inps} inputs")
                if new_outs < old_outs:
                    progress_made = True
                    print(
                        f"SUCCESS: Went from {old_outs} to {new_outs} outputs")

                if not progress_made:
                    raise RuntimeError("Success raised but no progress made?")

                if not graph_fails(new_state.graph, new_state.inps):
                    print(
                        "WARNING: Something went wrong, not applying this minification"
                    )
                    return None
                return new_state
            else:
                print(f"FAIL: {name}")
            return None

        return new_func

    def register_strategy(name: str):
        return partial(_register_strategy, name=name)

    @register_strategy("Truncate suffix")
    def remove_suffix(cur_graph, cur_inps, granularity):
        tested = set()
        new_graph = fx.Graph()
        env = {}
        for idx, node in enumerate(cur_graph.nodes):
            new_node = new_graph.node_copy(node, lambda x: env[x])
            if node.op not in ['placeholder', 'output']:
                # If idx is divisible by (granularity * 2), it would have been checked already.
                if idx % granularity == 0 and (idx % (granularity * 2) !=
                                               0) and idx not in tested:
                    output_node = new_graph.output((new_node, ))
                    if len(new_graph.nodes) < len(
                            cur_graph.nodes) and graph_fails(
                                new_graph, cur_inps):
                        return ReproState(new_graph, cur_inps)
                    else:
                        tested.add(idx)
                        new_graph.erase_node(output_node)
            env[node] = new_node
        return None

    @register_strategy("Remove outputs")
    def remove_outputs(cur_graph, cur_inps, granularity):
        granularity = max(1, granularity // 2)
        for idx, node in enumerate(cur_graph.nodes):
            node.idx = idx
            if node.op == 'output':
                output = node
                break

        output_args = sorted(output.args[0],
                             key=lambda x: x.idx
                             if isinstance(x, fx.Node) else int(1e9))
        if len(output_args) == 1:
            return None

        for idx in range(0, len(output_args), granularity):
            output.args = (output_args[:idx] +
                           output_args[idx + granularity:], )
            if graph_fails(cur_graph, cur_inps):
                return ReproState(cur_graph, cur_inps)
        return None

    def remove_unused_inputs_unchecked(cur_state: ReproState):
        cur_graph = cur_state.graph
        cur_inps = cur_state.inps
        ph_nodes = get_placeholders(cur_graph)
        assert len(ph_nodes) == len(cur_inps)

        new_inps = []
        for idx in range(len(ph_nodes)):
            if len(ph_nodes[idx].users) == 0:
                cur_graph.erase_node(ph_nodes[idx])
            else:
                new_inps.append(cur_inps[idx])
        if len(new_inps) < len(cur_inps):
            return ReproState(cur_graph, new_inps)
        return None

    def remove_unused_inputs_checked(cur_state: ReproState):
        new_state = remove_unused_inputs_unchecked(cur_state)
        if new_state is not None and graph_fails(new_state.graph,
                                                 new_state.inps):
            return new_state
        return None

    def _remove_unused_wrapper(cur_graph, cur_inps, granularity):
        return remove_unused_inputs_checked(ReproState(cur_graph, cur_inps))

    remove_unused_inputs = register_strategy("Remove unused inputs")(
        _remove_unused_wrapper)

    @register_strategy("Eliminate dead code")
    def eliminate_dead_code(cur_graph, cur_inps, granularity):
        if cur_graph.eliminate_dead_code() and graph_fails(
                cur_graph, cur_inps):
            return ReproState(cur_graph, cur_inps)
        return None

    def _consolidate_placeholders(cur_graph):
        new_graph = fx.Graph()
        env = {}
        for node in cur_graph.nodes:
            if node.op == 'placeholder':
                new_node = new_graph.node_copy(node, lambda x: env[x])
                env[node] = new_node

        for node in cur_graph.nodes:
            if node.op != 'placeholder':
                new_node = new_graph.node_copy(node, lambda x: env[x])
                env[node] = new_node
        return new_graph

    @register_strategy("Delta Debugging")
    def delta_debugging(cur_graph: fx.Graph, cur_inps, granularity):
        num_nodes = len(cur_graph.nodes)
        for start_range in range(0, num_nodes, granularity):
            is_removing = False
            new_graph = deepcopy_fx_graph(cur_graph)
            new_inps = cur_inps[:]
            end_range = min(num_nodes, start_range + granularity)
            for idx in range(start_range, end_range):
                new_node = list(new_graph.nodes)[idx]
                if new_node.op not in ['placeholder', 'output']:
                    is_removing = True
                    _convert_node_to_placeholder(new_node, new_inps)
            if not is_removing:
                continue
            new_graph = _consolidate_placeholders(new_graph)
            new_state = remove_unused_inputs_unchecked(
                ReproState(new_graph, new_inps))
            if new_state is None:
                new_state = ReproState(new_graph, new_inps)
            if graph_fails(new_state.graph, new_state.inps):
                return ReproState(new_state.graph, new_state.inps)

        return None

    failing_state = ReproState(failing_graph, inps)

    def try_granularity(failing_state, granularity, use_non_granular):
        print(f"Trying granularity {granularity}")

        strategies = []
        num_nodes = len(failing_state.graph.nodes)
        num_outputs = len(get_outputs(failing_state.graph))
        if num_outputs > num_nodes // 2:
            strategies += [remove_outputs]

        if use_non_granular:
            strategies += [eliminate_dead_code, remove_unused_inputs]

        strategies += [remove_suffix, delta_debugging]

        for strategy in strategies:
            new_state = strategy(failing_state, granularity)
            if new_state is not None:
                return new_state
        return None

    while True:
        dump_state(fx.GraphModule(fail_f, failing_state.graph),
                   failing_state.inps)
        granularity = int(2**(math.floor(
            math.log2(len(failing_state.graph.nodes)))))
        new_state = try_granularity(failing_state,
                                    granularity,
                                    use_non_granular=True)
        if new_state is not None:
            failing_state = new_state
            continue

        granularity //= 2
        has_progress = False
        while granularity >= 1:
            new_state = try_granularity(failing_state,
                                        granularity,
                                        use_non_granular=False)
            if new_state is not None:
                failing_state = new_state
                has_progress = True
                break
            granularity //= 2
        if has_progress:
            continue

        new_state = remove_outputs(failing_state, 1)
        if new_state is not None:
            failing_state = new_state
            continue

        break

    if not graph_fails(failing_state.graph, failing_state.inps):
        raise RuntimeError(
            "Uh oh, something went wrong :( Final graph is not failing")

    print(f"Made {num_queries} queries")
    failing_fx = fx.GraphModule(fail_f, failing_state.graph)
    dump_state(failing_fx, failing_state.inps)
    print("Wrote minimal repro out to repro.py")
    return failing_fx, failing_state.inps
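# Usage sketch (an addition, not part of the original example): a custom
# `dump_state` callback for the minifier above. The default, per the
# signature, is a module-level `dump_state` that writes out a repro; this
# illustrative replacement just prints the generated code and input shapes.
import torch
import torch.fx as fx

def print_state(fx_g: fx.GraphModule, inps):
    print(fx_g.code)
    print([tuple(i.shape) for i in inps])

# minified_fx, minified_inps = minifier(
#     failing_function, [torch.randn(5)], module_fails, dump_state=print_state)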
Example #21
0
    def graph_fails(graph, inps):
        # `fail_f` and `module_fails` are free variables from the enclosing
        # minifier scope (see Example #20).
        mod = fx.GraphModule(fail_f, graph)
        mod.graph.lint()
        return module_fails(mod, inps)