def _patch_arguments_(
    gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True
):
    """
    Patches nodes by replacing their arguments with their corresponding values in mapping (supports regular types,
    tuples, and slices).
    """

    def _patch_slice(s, mapping):
        return slice(mapping.get(s.start, s.start), mapping.get(s.stop, s.stop), mapping.get(s.step, s.step))

    graph = gm.graph
    supported_types = (Node, str, int, float)
    for node in graph.nodes:
        new_args = []
        for arg in node.args:
            if isinstance(arg, tuple):
                new_arg = []
                for a in arg:
                    if isinstance(a, slice):
                        new_arg.append(_patch_slice(a, mapping))
                    else:
                        new_arg.append(mapping.get(a, a))
                new_args.append(tuple(new_arg))
            elif isinstance(arg, slice):
                new_args.append(_patch_slice(arg, mapping))
            elif isinstance(arg, supported_types):
                new_args.append(mapping.get(arg, arg))
            else:
                new_args.append(arg)
        node.args = tuple(new_args)

    if lint_and_recompile:
        graph.lint()
        gm.recompile()
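
# A minimal usage sketch (assuming `_patch_arguments_` above is in scope; the `Flatten`
# module is illustrative only): trace a module with a hard-coded batch size, insert a
# node that reads the size from the input at runtime, and patch every literal `8` to
# point at that node instead.
import torch
from torch.fx import symbolic_trace


class Flatten(torch.nn.Module):
    def forward(self, x):
        return x.view(8, -1)  # 8 = batch size baked in at trace time


gm = symbolic_trace(Flatten())
placeholder = next(n for n in gm.graph.nodes if n.op == "placeholder")
with gm.graph.inserting_after(placeholder):
    batch_size_node = gm.graph.call_method("size", args=(placeholder, 0))

_patch_arguments_(gm, {8: batch_size_node})
print(gm(torch.randn(4, 3)).shape)  # now works for any batch size
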
Example #2
def wrap_in_activation_function(m: GraphModule,
                                fn: ActivationFunction) -> GraphModule:
    # Get output node
    output_node: Optional[Node] = None
    for n in reversed(m.graph.nodes):
        if n.op == "output":
            output_node = n
            break
    assert output_node

    # Get the actual output (the "input" of the output node). This is
    # the Node we want to wrap in a user-specified activation function
    assert len(output_node.all_input_nodes) == 1
    wrap_node = output_node.all_input_nodes[0]

    # Wrap the actual output in a Proxy
    wrap_proxy = Proxy(wrap_node)

    # Get the implementation of the specified activation function and
    # symbolically trace it
    fn_impl = activation_functions[fn]
    fn_impl_traced = symbolic_trace(fn_impl)

    # Call the specified activation function using the Proxy wrapper for
    # `wrap_node`. The result of this call is another Proxy, which we
    # can hook into our existing Graph.
    with m.graph.inserting_after(wrap_node):
        fn_impl_output_node = fn_impl_traced(wrap_proxy)
        new_args = (fn_impl_output_node.node, )
        output_node.args = new_args

    m.recompile()
    return m
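
# The snippet above assumes an `ActivationFunction` enum and an `activation_functions`
# lookup table that are not shown; the definitions below are illustrative stand-ins
# only, followed by a small usage example.
from enum import Enum, auto
from typing import Dict, Optional

import torch
from torch.fx import GraphModule, Node, Proxy, symbolic_trace


class ActivationFunction(Enum):
    RELU = auto()
    GELU = auto()


activation_functions: Dict[ActivationFunction, torch.nn.Module] = {
    ActivationFunction.RELU: torch.nn.ReLU(),
    ActivationFunction.GELU: torch.nn.GELU(),
}


class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1


traced = symbolic_trace(AddOne())
wrap_in_activation_function(traced, ActivationFunction.RELU)
print(traced(torch.tensor([-2.0, 3.0])))  # relu is applied to the original output
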
def _patch_getitem_(
    gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True
):
    """Patches getitem nodes by replacing current arguments to their corresponding values in mapping."""
    # TODO: combine this with the patch_argument function which seems to do almost the same thing.
    graph = gm.graph
    for node in graph.nodes:
        if node.op == "call_function" and node.target == operator.getitem:
            indices = node.args[1]
            if isinstance(indices, tuple):
                new_indices = []
                for idx in indices:
                    if isinstance(idx, slice):
                        new_indices.append(
                            slice(
                                mapping.get(idx.start, idx.start),
                                mapping.get(idx.stop, idx.stop),
                                mapping.get(idx.step, idx.step),
                            )
                        )
                    elif isinstance(idx, int):
                        new_indices.append(mapping.get(idx, idx))
                    else:
                        new_indices.append(idx)

                node.args = (node.args[0], tuple(new_indices))
            else:
                node.args = (node.args[0], mapping.get(node.args[1], node.args[1]))

    if lint_and_recompile:
        graph.lint()
        gm.recompile()
def _insert_encoder_sequence_length_node_(gm: GraphModule, lint_and_recompile: bool = True) -> Node:
    """Inserts a node that retrieves the encoder sequence length dynamically from the input of the model."""
    graph = gm.graph
    input_names = set(gm.dummy_inputs.keys())
    encoder_sequence_length_node = None
    for node in graph.nodes:
        if node.op == "placeholder" and node.name in input_names and "decoder" not in node.name:
            with graph.inserting_after(node):
                # There are two cases to handle:
                #   1. num_choices < 0, meaning that the model is not performing a "multiple choice" task; in this
                #      case the input shape is [batch_size, sequence_length] => index 1
                #   2. num_choices > 0, meaning that the model is performing a "multiple choice" task; in this case
                #      the input shape is [batch_size, num_choices, sequence_length] => index 2
                encoder_sequence_length_node = graph.call_method("size", args=(node, 1 if gm.num_choices < 0 else 2))

    if encoder_sequence_length_node is None:
        raise ValueError("Could not insert the node that computes the encoder sequence length")

    if lint_and_recompile:
        graph.lint()
        gm.recompile()

    # Useful when retracing for quantization.
    if hasattr(gm, "_qconfig_map"):
        gm._qconfig_map[encoder_sequence_length_node.name] = None

    return encoder_sequence_length_node
def transform_to_dynamic_input_(gm: GraphModule, is_retracing: bool = False):
    """Transformation that enables traced models to perform inference on dynamic input shapes."""
    graph = gm.graph
    static2dynamic = {}

    # Inserting the nodes that will fetch the batch size and sequence lengths dynamically.
    if gm.use_dynamic_batch_size:
        batch_size_node = _insert_batch_size_node_(gm, lint_and_recompile=False)
        static2dynamic[gm.static_batch_size] = batch_size_node
        if gm.num_choices > 0:
            with graph.inserting_after(batch_size_node):
                static2dynamic[gm.static_batch_size * gm.num_choices] = graph.call_function(
                    operator.mul, args=(batch_size_node, gm.num_choices)
                )
            # Useful when retracing for quantization.
            if hasattr(gm, "_qconfig_map"):
                gm._qconfig_map[static2dynamic[gm.static_batch_size * gm.num_choices]] = None

    if gm.use_dynamic_sequence_length:
        encoder_sequence_length_node = _insert_encoder_sequence_length_node_(gm, lint_and_recompile=False)
        static2dynamic[gm.static_sequence_length[0]] = encoder_sequence_length_node

        # TODO: do the same for the decoder.

    _change_view_methods_(gm, static2dynamic, lint_and_recompile=False)
    _patch_getitem_(gm, static2dynamic, lint_and_recompile=False)

    remove_unused_nodes_(gm, lint_and_recompile=False)

    graph.lint()
    gm.recompile()

    gm.static2dynamic = static2dynamic
    gm.dynamic2static = {v: k for (k, v) in static2dynamic.items()}
Example #6
def tensorexpr_compile(fx_module: fx.GraphModule, flat_args) -> Callable:
    """Compiles the given fx_module using TensorExpr Kernel"""
    inp_devices = {i.device for i in flat_args if isinstance(i, torch.Tensor)}
    assert len(inp_devices) == 1
    inp_device = list(inp_devices)[0]
    inputs = []
    output_refs = []
    for node in fx_module.graph.nodes:
        if node.op == "placeholder":
            inputs.append(node)
        elif node.op == "output":
            outputs = node.args[0]
            if not isinstance(outputs, Iterable):
                outputs = (outputs, )
            new_outputs = []
            for idx, output in enumerate(outputs):
                # Appends (bool, idx) pairs
                # if True, read from kernel outputs
                # if False, read from kernel inputs
                if output in inputs:
                    output_refs.append((False, inputs.index(output)))
                elif output in outputs[:idx]:
                    output_refs.append(
                        (True, output_refs[outputs.index(output)][1]))
                else:
                    output_refs.append((True, len(new_outputs)))
                    new_outputs.append(output)
            node.args = (tuple(new_outputs), )
    fx_module.graph.lint()
    fx_module.recompile()

    for i in range(0, 100):
        attr = f"_tensor_constant{i}"
        if hasattr(fx_module, attr):
            setattr(fx_module, attr, getattr(fx_module, attr).to(inp_device))
        else:
            break

    jit_module = torch.jit.trace(fx_module, flat_args)
    jit_module = torch.jit.freeze(jit_module.eval())
    torch._C._jit_trace_module(jit_module._c, tuple(flat_args))
    torch._C._te.remove_unused_self_argument(jit_module.graph)
    torch._C._te.annotate_input_shapes(jit_module.graph, tuple(flat_args))
    torch._C._jit_pass_lower_all_tuples(jit_module.graph)
    te_kernel = torch._C._te.TensorExprKernel(jit_module.graph)

    def f(*args):
        outs = te_kernel.run(args)
        if not isinstance(outs, tuple) and not isinstance(outs, list):
            outs = (outs, )
        real_outs = []
        for out in output_refs:
            if out[0]:
                real_outs.append(outs[out[1]])
            else:
                real_outs.append(args[out[1]])
        return real_outs

    return f
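
# A tiny, self-contained illustration of the (is_kernel_output, index) bookkeeping
# above, using plain strings in place of fx.Node objects: with kernel inputs [a, b]
# and requested outputs (a, c, c), `a` is read back from the inputs, `c` becomes the
# single real kernel output, and the duplicate `c` re-reads it.
inputs = ["a", "b"]
outputs = ("a", "c", "c")
output_refs, new_outputs = [], []
for idx, output in enumerate(outputs):
    if output in inputs:
        output_refs.append((False, inputs.index(output)))
    elif output in outputs[:idx]:
        output_refs.append((True, output_refs[outputs.index(output)][1]))
    else:
        output_refs.append((True, len(new_outputs)))
        new_outputs.append(output)
assert output_refs == [(False, 0), (True, 0), (True, 0)]
assert new_outputs == ["c"]
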
def remove_unused_nodes_(gm: GraphModule, lint_and_recompile: bool = True):
    """Removes all the unused nodes in a GraphModule."""
    graph = gm.graph
    for node in graph.nodes:
        if not node.users and node.op not in ["placeholder", "output"]:
            graph.erase_node(node)

    if lint_and_recompile:
        graph.lint()
        gm.recompile()
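
# A small usage sketch (assuming `remove_unused_nodes_` above is in scope; the `Inc`
# module is illustrative only): insert a node that nothing consumes and confirm the
# pass erases it.
import torch
from torch.fx import symbolic_trace


class Inc(torch.nn.Module):
    def forward(self, x):
        return x + 1


gm = symbolic_trace(Inc())
placeholder = next(n for n in gm.graph.nodes if n.op == "placeholder")
with gm.graph.inserting_after(placeholder):
    gm.graph.call_method("size", args=(placeholder, 0))  # dead node, no users

assert any(n.op == "call_method" for n in gm.graph.nodes)
remove_unused_nodes_(gm)
assert not any(n.op == "call_method" for n in gm.graph.nodes)
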
Example #8
def ts_compile(fx_g: fx.GraphModule, _) -> Callable:
    """
    Compiles the :attr:`fx_g` with the TorchScript compiler.

    .. warning::
        This API is experimental and likely to change.

    Args:
        fx_g(fx.GraphModule): The input Fx graph module to be compiled.

    Returns:
        Torch scripted model.
    """
    for node in fx_g.graph.nodes:
        if node.target in (torch.ops.aten.new_zeros, torch.ops.aten.new_empty):
            if node.args[1] == []:
                args = list(node.args)
                args[1] = [1]
                node.args = tuple(args)
        elif node.target is torch.ops.aten.masked_fill and node.args[
                2] == float("-inf"):
            # Fx graph to torchscript fails for -inf
            args = list(node.args)
            args[2] = -3.403 * 10**37
            node.args = tuple(args)

    for node in fx_g.graph.nodes:
        new_kwargs = {}
        for k, v in node.kwargs.items():
            if isinstance(v, torch.device):
                v = v.type
            new_kwargs[k] = v
        node.kwargs = new_kwargs

    fx_g.graph.lint()

    # print(set([i.target for i in fx_g.graph.nodes if i.op == 'call_function']))
    # Works around this NVFuser issue: https://github.com/csarofeen/pytorch/issues/1311
    for i in range(1000):
        attr = f"_tensor_constant{i}"
        if hasattr(fx_g, attr):
            setattr(fx_g, attr, getattr(fx_g, attr).cuda())
        else:
            break

    fx_g.recompile()

    f = torch.jit.script(fx_g)

    torch._C._jit_pass_remove_mutation(f.graph)

    f = torch.jit.freeze(f.eval())
    f = torch.jit.optimize_for_inference(f)
    return f
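
# A small, hedged, self-contained demo of the kwargs normalization loop above:
# torch.device objects baked into the graph as keyword arguments can trip up
# torch.jit.script, so they are replaced by their string form ("cpu"/"cuda")
# before scripting. The function `f` is illustrative only.
import torch
from torch import fx


def f(x):
    return x + torch.ones_like(x, device=torch.device("cpu"))


gm = fx.symbolic_trace(f)
for node in gm.graph.nodes:
    new_kwargs = {}
    for k, v in node.kwargs.items():
        if isinstance(v, torch.device):
            v = v.type  # torch.device("cpu") -> "cpu"
        new_kwargs[k] = v
    node.kwargs = new_kwargs
gm.recompile()

scripted = torch.jit.script(gm)
print(scripted(torch.zeros(3)))  # tensor([1., 1., 1.])
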
def _change_view_methods_(
    gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True
):
    """
    Changes arguments of view ops that refer to static batch size / sequence lengths to make them refer to the
    batch_size / sequence_length nodes.
    """
    graph = gm.graph
    for node in graph.nodes:
        if node.op == "call_method" and node.target == "view":
            if isinstance(node.args[1], tuple):
                node.args = (node.args[0], *node.args[1])
            node.args = tuple((mapping.get(arg, arg) for arg in node.args))

    if lint_and_recompile:
        graph.lint()
        gm.recompile()
Example #10
def _remove_duplicate_output_args(gm: fx.GraphModule) -> RemoveDuplicateResult:
    output_nodes = [n for n in gm.graph.nodes if n.op == "output"]
    assert len(output_nodes) == 1, \
           f"Expecting exactly one `output` node, but got {len(output_nodes)}"

    changed = False
    # arg node name to its index in the new output args tuple
    name_to_idx: t.Dict[str, int] = {}
    output_node = output_nodes[0]

    # Output op only uses its `args[0]`, and it does not have `kwargs`.
    # https://pytorch.org/docs/stable/fx.html#torch.fx.Node
    args: t.Sequence[t.Any] = output_node.args[0]

    # Only concern ourselves with the case where the args is an iterable of fx.Node.
    # Other return cases (e.g., a single value) are possible, but we don't handle
    # them in this pass.
    if not (isinstance(args, t.Iterable) and all(isinstance(a, fx.Node) for a in args)):
        return RemoveDuplicateResult(replacement_map=None, module=gm)

    # Map old index of the arg node to the remaining node's idx,
    # initialized to `i => i`
    replacement_map: t.List[int] = list(range(len(args)))
    args_new = []
    for idx, a in enumerate(args):
        assert isinstance(a, fx.Node), \
               f"Expecting fx.Node instance, but got: {type(a)}"

        if a.name not in name_to_idx:
            args_new.append(a)
            name_to_idx[a.name] = len(args_new) - 1
        else:
            changed = True
            _LOGGER.warning(
                f"Replaced duplicate output arg '{a.name}': "
                f"{idx} -> {name_to_idx[a.name]}"
            )
        replacement_map[idx] = name_to_idx[a.name]

    output_node.args = (tuple(args_new),)
    if changed:
        gm.recompile()
    return RemoveDuplicateResult(replacement_map, module=gm)
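
# A hedged sketch of the supporting pieces the snippet above assumes (`_LOGGER` and
# `RemoveDuplicateResult` are illustrative stand-ins), plus a tiny module whose two
# outputs are the same node.
import dataclasses
import logging
import typing as t

import torch
from torch import fx

_LOGGER = logging.getLogger(__name__)


@dataclasses.dataclass
class RemoveDuplicateResult:
    replacement_map: t.Optional[t.List[int]]
    module: fx.GraphModule


class TwoOfTheSame(torch.nn.Module):
    def forward(self, x):
        y = x + 1
        return y, y


gm = fx.symbolic_trace(TwoOfTheSame())
result = _remove_duplicate_output_args(gm)
print(result.replacement_map)  # [0, 0]: the second output now reuses index 0
print(gm(torch.zeros(1)))      # a single-element tuple after deduplication
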
Example #11
def ts_compile(fx_g: fx.GraphModule, inps) -> Callable:
    """
    Compiles the :attr:`fx_g` with the TorchScript compiler.

    .. warning::
        This API is experimental and likely to change.

    Args:
        fx_g(fx.GraphModule): The input Fx graph module to be compiled.

    Returns:
        Torch scripted model.
    """

    with _disable_jit_autocast():
        strip_overloads(fx_g)

        for node in fx_g.graph.nodes:
            if (node.target == torch.ops.aten._to_copy and len(node.args) == 1
                    and len(node.kwargs) == 1 and "dtype" in node.kwargs):
                node.target = torch.ops.aten.to

        for node in fx_g.graph.nodes:
            new_kwargs = {}
            for k, v in node.kwargs.items():
                if isinstance(v, torch.device):
                    v = v.type
                new_kwargs[k] = v
            node.kwargs = new_kwargs

        fx_g.graph.lint()

        fx_g.recompile()

        f = torch.jit.script(fx_g)

        torch._C._jit_pass_remove_mutation(f.graph)

        f = torch.jit.freeze(f.eval())
        f = torch.jit.optimize_for_inference(f)
        f(*inps)
    return f
Example #12
def remove_duplicate_output_args(
    top_level: fx.GraphModule,
    target_subnets: t.Collection[str]
) -> t.Mapping[str, "RemoveDuplicateResult"]:
    """Removes duplicate output args.

    This pass removes duplicate output args from the target subnets and fixes
    their uses in the top level module where the subnets are called. This pass
    must be called after acc split on the top-level net and subsequent calls to
    the acc trace on the subnets.

    This pass will change both the subnets and top level module.

    Returns:
        a mapping of each target subnet name to its deduplication result
    """

    processed_subnets = {}
    for node in top_level.graph.nodes:  # type: fx.Node
        if node.op == "call_module" and node.name in target_subnets:
            assert isinstance(node.target, str)
            sub_gm = top_level.get_submodule(node.target)
            assert isinstance(sub_gm, fx.GraphModule)

            replace_res = _remove_duplicate_output_args(sub_gm)
            processed_subnets[node.name] = replace_res
            if replace_res.replacement_map is None:
                continue
            sub_gm.recompile()

            needs_recompile = False
            # iterate on the copy since we will be changing elements of node.users
            for user in list(node.users):
                idx = _ensure_proper_output_use(user, node)
                idx_new = replace_res.replacement_map[idx]
                if idx_new != idx:
                    user.args = (user.args[0], idx_new)
                    needs_recompile = True

            if needs_recompile:
                top_level.recompile()
    return processed_subnets
Example #13
def change_fx_graph_return_to_tuple(fx_g: fx.GraphModule):
    for node in fx_g.graph.nodes:
        if node.op == "output":
            # output nodes always have one argument
            node_arg = node.args[0]
            out_nodes = []
            if isinstance(node_arg, list):
                # Don't return NoneType elements.
                for out_node in node_arg:
                    if out_node is not None:
                        out_nodes.append(out_node)
                # If there is a single tensor/element to be returned don't
                # create a tuple for it.
                if len(out_nodes) == 1:
                    node.args = out_nodes
                else:
                    node.args = (tuple(out_nodes), )
    fx_g.graph.lint()
    fx_g.recompile()
    return fx_g
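
# A small usage sketch (assuming `change_fx_graph_return_to_tuple` above is in scope;
# the `WithNones` module is illustrative only): a module that returns Nones alongside
# tensors; the pass drops the Nones and packs the remaining outputs into a tuple.
import torch
from torch import fx


class WithNones(torch.nn.Module):
    def forward(self, x):
        return [x + 1, None, x * 2]


gm = change_fx_graph_return_to_tuple(fx.symbolic_trace(WithNones()))
print(gm(torch.ones(2)))  # (tensor([2., 2.]), tensor([2., 2.]))
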
Example #14
def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
    """
    Sets a breakpoint in `gm`'s generated python code. It drops into pdb when
    `gm` gets run.

    Args:
        gm: graph module to insert the breakpoint into. It is then recompiled
            for the change to take effect.

    Returns:
        the `gm` with breakpoint inserted.
    """
    def insert_pdb(body):
        return ["import pdb; pdb.set_trace()\n", *body]

    with gm.graph.on_generate_code(make_transformer=lambda cur_transform: (
            # new code transformer to register
            lambda body:
        (insert_pdb(cur_transform(body) if cur_transform else body)))):
        gm.recompile()

    return gm
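
# A brief usage sketch (the `Relu` module is illustrative only): after `set_trace`,
# running the module drops into pdb inside the generated forward.
import torch
from torch import fx


class Relu(torch.nn.Module):
    def forward(self, x):
        return x.relu()


gm = set_trace(fx.symbolic_trace(Relu()))
# gm(torch.randn(3))  # uncomment to break into pdb when the graph runs
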
Example #15
def force_lazy_device(model: fx.GraphModule):
    """
    Factory methods in an Fx graph may create tensors on a specific eager device.
    If we take no action, those eager tensors will be mixed with lazy tensors and
    cause a crash. This method overwrites those eager devices with the lazy device.
    """
    def tolazydevice(dev):
        if isinstance(dev, torch.device):
            return torch.device("lazy", index=dev.index)
        return dev

    def hasDeviceArg(args, kwargs):
        return any(
            isinstance(arg, torch.device)
            for arg in itertools.chain(args, kwargs.values()))

    for nd in model.graph.nodes:
        nd.args = tuple(tolazydevice(arg) for arg in nd.args)
        nd.kwargs = {k: tolazydevice(v) for k, v in nd.kwargs.items()}

        # For torchbench models like yolov3 and hf_Bart, dynamo generates Fx graphs that return
        # eager tensors on the default device
        # (check https://gist.github.com/shunting314/eabdf6c769c59bc384469717b8f9bb7f for yolov3,
        # and https://gist.github.com/shunting314/8d5e2d9348a3258959d3954186c48814 for hf_Bart).
        # To force those tensors onto the lazy device, we can not simply override
        # the device argument since there is no explicit device argument.
        # What we do here is, for the list of covered tensor factory methods,
        # add an explicit lazy device argument.
        #
        # TODO: This solution is not ideal since we may miss some factory methods. In the future,
        # when we support lazy mode, this method can be replaced by that.
        if nd.target in tensor_factory_functions and not hasDeviceArg(
                nd.args, nd.kwargs):
            kwargs = dict(
                nd.kwargs)  # nd.kwargs is immutable. make a mutable copy.
            kwargs["device"] = torch.device("lazy")
            nd.kwargs = kwargs

    model.recompile()
Example #16
def _insert_batch_size_node_(gm: GraphModule, lint_and_recompile: bool = True) -> Node:
    """Inserts a node that retrieves the batch size dynamically from the input of the model."""
    graph = gm.graph
    input_names = set(gm.dummy_inputs.keys())
    batch_size_node = None
    for node in graph.nodes:
        if node.op == "placeholder" and node.name in input_names:
            with graph.inserting_after(node):
                batch_size_node = graph.call_method("size", args=(node, 0))

    if batch_size_node is None:
        raise ValueError("Could not insert the node that computes the batch size")

    if lint_and_recompile:
        graph.lint()
        gm.recompile()

    # Useful when retracing for quantization.
    if hasattr(gm, "_qconfig_map"):
        gm._qconfig_map[batch_size_node.name] = None

    return batch_size_node
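
# A small usage sketch (assuming `_insert_batch_size_node_` above is in scope; the
# `Embed` module is illustrative only): the helper expects the GraphModule to carry a
# `dummy_inputs` dict naming the model inputs, which we attach by hand here.
import torch
from torch.fx import symbolic_trace


class Embed(torch.nn.Module):
    def forward(self, input_ids):
        return input_ids + 1


gm = symbolic_trace(Embed())
gm.dummy_inputs = {"input_ids": torch.zeros(2, 4, dtype=torch.long)}
batch_size_node = _insert_batch_size_node_(gm)
print(batch_size_node)  # the inserted `input_ids.size(0)` call_method node
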
Example #17
def min_cut_rematerialization_partition(
    joint_module: fx.GraphModule, _joint_inputs
) -> Tuple[fx.GraphModule, fx.GraphModule]:
    """
    Partitions the joint graph such that the backward recomputes the forward.
    Recomputing helps in trading off memory bandwidth with computation.

    To create the fwd and bwd graphs, we copy the joint graph, manually set the
    outputs to just the original forward or backward outputs, and then run the
    resulting graphs through dead code elimination.

    .. warning::
        This API is experimental and likely to change.

    Args:
        joint_module(fx.GraphModule): The joint forward and backward graph. This
            is the result of AOT Autograd tracing.

    Returns:
        Returns the generated forward and backward Fx graph modules.
    """
    try:
        import networkx as nx
    except ImportError:
        raise RuntimeError("Need networkx installed to perform smart recomputation heuristics")

    joint_module.graph.eliminate_dead_code()
    joint_module.recompile()
    fx_g = joint_module.graph

    #  add the CSE pass
    cse_graph = fx_graph_cse(fx_g)
    joint_module.graph = cse_graph
    full_bw_graph = joint_module.graph

    name_to_node = {}
    for node in joint_module.graph.nodes:
        name_to_node[node.name] = node

    def classify_nodes(joint_module):
        required_bw_nodes = set()
        for node in joint_module.graph.nodes:
            if node.op == 'placeholder' and "tangents" in node.target:
                required_bw_nodes.add(node)
            if node in required_bw_nodes:
                for user in node.users:
                    required_bw_nodes.add(user)

        primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
        fwd_outputs, _ = _extract_fwd_bwd_outputs(joint_module)
        forward_only_graph = _extract_graph_with_inputs_outputs(joint_module.graph, primal_inputs, fwd_outputs)
        required_fw_nodes = {name_to_node[node.name] for node in forward_only_graph.nodes
                             if node.op != 'output'}
        unclaimed_nodes = {node for node in joint_module.graph.nodes
                           if node not in required_fw_nodes and node not in required_bw_nodes}
        return required_fw_nodes, required_bw_nodes, unclaimed_nodes

    required_fw_nodes, required_bw_nodes, unclaimed_nodes = classify_nodes(joint_module)
    for node in reversed(joint_module.graph.nodes):
        if node not in required_fw_nodes:
            node.dist_from_bw = 0
        else:
            node.dist_from_bw = int(1e9)
            for user in node.users:
                node.dist_from_bw = min(node.dist_from_bw, user.dist_from_bw + 1)

    aten = torch.ops.aten

    pointwise_ops = [aten.add, aten.sub, aten.div, aten.atan2, aten.mul, aten.max, aten.min, aten.pow, aten.remainder, aten.fmod, aten.__and__, aten.__or__, aten.__xor__, aten.__lshift__, aten.__rshift__, aten.eq, aten.ne, aten.ge, aten.gt, aten.le, aten.lt, aten.abs, aten.bitwise_not, aten.ceil, aten.floor, aten.frac, aten.neg, aten.relu, aten.round, aten.silu, aten.trunc, aten.log, aten.log10, aten.log1p, aten.log2, aten.lgamma, aten.exp, aten.expm1, aten.erf, aten.erfc, aten.cos, aten.acos, aten.cosh, aten.sin, aten.asin, aten.sinh, aten.tan, aten.atan, aten.tanh, aten.atanh, aten.sqrt, aten.rsqrt, aten.reciprocal, aten.sigmoid, aten.softplus, aten.threshold, aten.threshold_backward, aten.clamp, aten.where, aten.lerp, aten.addcmul, aten.gelu, aten.gelu_backward]  # noqa: E501
    misc_ops = [aten.to, aten.type_as, operator.getitem]

    reduction_ops = [aten.softmax, aten._softmax, aten._softmax_backward_data, aten.sum, aten.mean, aten._grad_sum_to_size, aten.sum_to_size, aten.amax]  # noqa: E501

    # not recomputed by default since these are kinda expensive/hard to fuse into
    # norm_ops = [aten.instance_norm, aten._batch_norm_impl_index, aten.native_batch_norm, aten.batch_norm, aten._batch_norm_impl_index_backward, aten.native_layer_norm, aten.layer_norm, aten.native_layer_norm_backward]  # noqa: E501

    # Not used by default since NVFuser can't fuse view ops
    # view_ops = [aten.expand, aten.clone, aten.transpose, aten.t, aten.view, aten._unsafe_view, aten.permute, aten.transpose, aten.t, aten._reshape_alias, aten.squeeze, aten.unsqueeze, aten.reshape, aten.cat, aten.slice, aten.split, aten.select, aten.repeat]  # noqa: E501

    # These are the view ops that NVFuser can fuse
    view_ops = [aten.squeeze, aten.unsqueeze]
    random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like]
    compute_intensive_ops = [aten.mm, aten.convolution, aten.convolution_backward, aten.bmm, aten.addmm, aten.upsample_bilinear2d]  # noqa: E501
    unrecomputable_ops = random_ops + compute_intensive_ops

    recomputable_ops = set(
        pointwise_ops
        + misc_ops
        + reduction_ops
        + view_ops
    )
    fusible_ops = recomputable_ops | set(random_ops)

    AGGRESSIVE_RECOMPUTATION = False

    def ban_recomputation(node):
        if AGGRESSIVE_RECOMPUTATION:
            return (node.op == 'call_function' and get_aten_target(node) in unrecomputable_ops)
        else:
            if node.op != 'call_function':
                return False
            if get_aten_target(node) not in recomputable_ops:
                return True
            # If the output of the reduction is 4x smaller (arbitrary choice),
            # then we don't allow recomputation.
            if get_aten_target(node) in reduction_ops:
                input_tensors_size = sum(_size_of(i.meta['tensor_meta']) for i in node.args if isinstance(i, fx.Node))
                output_size = _size_of(node.meta['tensor_meta'])
                return (output_size * 4 < input_tensors_size)
            return False

    def is_fusible(a, b):
        return get_aten_target(a) in fusible_ops and get_aten_target(b) in fusible_ops

    def is_materialized(node):
        if node.op == 'placeholder':
            return True

        return not all(is_fusible(node, user) for user in node.users)

    def get_node_weight(node):
        mem_sz = _size_of(node.meta['tensor_meta'])

        # Heuristic to bias towards nodes closer to the backwards pass
        mem_sz = int(mem_sz + node.dist_from_bw)

        if is_materialized(node):
            return mem_sz
        else:
            return mem_sz * 2

    nx_graph = nx.DiGraph()
    for node in full_bw_graph.nodes:
        if node.op == 'output':
            continue

        if node in required_bw_nodes:
            nx_graph.add_edge(node.name + "_in", "sink", capacity=math.inf)
            continue

        if node.op == 'placeholder' and "primals" in node.target:
            nx_graph.add_edge("source", node.name + "_in", capacity=math.inf)

        # If a node can't be recomputed (too expensive or involves randomness),
        # we prevent it from being recomputed by adding an inf edge to the source
        # We only need to ban nodes in the fw pass, as those are the only ones that would be recomputed.
        if ban_recomputation(node) and node in required_fw_nodes:
            nx_graph.add_edge("source", node.name + "_in", capacity=math.inf)

        if 'tensor_meta' not in node.meta:
            weight = math.inf
        else:
            weight = get_node_weight(node)

        # Creates the weights on the "node" edge
        nx_graph.add_edge(node.name + "_in", node.name + "_out", capacity=weight)
        for user in node.users:
            nx_graph.add_edge(node.name + "_out", user.name + "_in", capacity=math.inf)

    cut_value, partition = nx.minimum_cut(nx_graph, "source", "sink")
    reachable, non_reachable = partition
    cutset = set()
    for u, nbrs in ((n, nx_graph[n]) for n in reachable):
        cutset.update((u, v) for v in nbrs if v in non_reachable)

    cut_nodes = set()
    for node_in, node_out in cutset:
        assert node_in[:-3] == node_out[:-4]
        node_name = node_in[:-3]
        cut_nodes.add(node_name)

    # To make this stuff deterministic
    node_idx = {node: idx for idx, node in enumerate(joint_module.graph.nodes)}
    saved_values = sorted((name_to_node[node] for node in cut_nodes), key=lambda x: node_idx[x])
    return _extract_fwd_bwd_modules(joint_module, saved_values)
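
# A minimal, self-contained illustration of the min-cut formulation used above: every
# candidate tensor contributes a (name_in -> name_out) edge weighted by its size,
# primal inputs hang off "source", backward-required nodes feed "sink", and the
# minimum cut selects the cheapest set of tensors to save between the forward and
# backward passes. Here `a` (size 4) feeds `b` (size 64), which the backward needs;
# the cut saves `a` and lets `b` be recomputed from it. The node names are made up.
import math

import networkx as nx

nx_graph = nx.DiGraph()
nx_graph.add_edge("source", "x_in", capacity=math.inf)  # primal input
nx_graph.add_edge("x_in", "x_out", capacity=math.inf)
nx_graph.add_edge("x_out", "a_in", capacity=math.inf)
nx_graph.add_edge("a_in", "a_out", capacity=4)           # small intermediate
nx_graph.add_edge("a_out", "b_in", capacity=math.inf)
nx_graph.add_edge("b_in", "b_out", capacity=64)          # large intermediate
nx_graph.add_edge("b_out", "sink", capacity=math.inf)    # required by backward

cut_value, (reachable, non_reachable) = nx.minimum_cut(nx_graph, "source", "sink")
cutset = {(u, v) for u in reachable for v in nx_graph[u] if v in non_reachable}
print(cut_value, cutset)  # 4 {('a_in', 'a_out')}: save `a`, recompute `b`
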
Example #18
def replace_pattern(gm: GraphModule, pattern: Callable,
                    replacement: Callable) -> None:
    """
    Matches all possible non-overlapping sets of operators and their
    data dependencies (``pattern``) in the Graph of a GraphModule
    (``gm``), then replaces each of these matched subgraphs with another
    subgraph (``replacement``).

    Args:
        ``gm``: The GraphModule that wraps the Graph to operate on
        ``pattern``: The subgraph to match in ``gm`` for replacement
        ``replacement``: The subgraph to replace ``pattern`` with

    Examples:

    .. code-block:: python

        import torch
        from torch.fx import symbolic_trace, subgraph_rewriter

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, w1, w2):
                m1 = torch.cat([w1, w2]).sum()
                m2 = torch.cat([w1, w2]).sum()
                return x + torch.max(m1) + torch.max(m2)

        def pattern(w1, w2):
            return torch.cat([w1, w2]).sum()

        def replacement(w1, w2):
            return torch.stack([w1, w2])

        traced_module = symbolic_trace(M())

        subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

    The above code will first match ``pattern`` in the ``forward``
    method of ``traced_module``. Pattern-matching is done based on
    use-def relationships, not node names. For example, if you had
    ``p = torch.cat([a, b])`` in ``pattern``, you could match
    ``m = torch.cat([a, b])`` in the original ``forward`` function,
    despite the variable names being different (``p`` vs ``m``).

    The ``return`` statement in ``pattern`` is matched based on its
    value only; it may or may not match to the ``return`` statement in
    the larger graph. In other words, the pattern doesn't have to extend
    to the end of the larger graph.

    When the pattern is matched, it will be removed from the larger
    function and replaced by ``replacement``. If there are multiple
    matches for ``pattern`` in the larger function, each non-overlapping
    match will be replaced. In the case of a match overlap, the first
    found match in the set of overlapping matches will be replaced.
    ("First" here being defined as the first in a topological ordering
    of the Nodes' use-def relationships. In most cases, the first Node
    is the parameter that appears directly after ``self``, while the
    last Node is whatever the function returns.)

    One important thing to note is that the parameters of the
    ``pattern`` Callable must be used in the Callable itself,
    and the parameters of the ``replacement`` Callable must match
    the pattern. The first rule is why, in the above code block, the
    ``forward`` function has parameters ``x, w1, w2``, but the
    ``pattern`` function only has parameters ``w1, w2``. ``pattern``
    doesn't use ``x``, so it shouldn't specify ``x`` as a parameter.
    As an example of the second rule, consider replacing

    .. code-block:: python

        def pattern(x, y):
            return torch.neg(x) + torch.relu(y)

    with

    .. code-block:: python

        def replacement(x, y):
            return torch.relu(x)

    In this case, ``replacement`` needs the same number of parameters
    as ``pattern`` (both ``x`` and ``y``), even though the parameter
    ``y`` isn't used in ``replacement``.

    After calling ``subgraph_rewriter.replace_pattern``, the generated
    Python code looks like this:

    .. code-block:: python

        def forward(self, x, w1, w2):
            stack_1 = torch.stack([w1, w2])
            sum_1 = stack_1.sum()
            stack_2 = torch.stack([w1, w2])
            sum_2 = stack_2.sum()
            max_1 = torch.max(sum_1)
            add_1 = x + max_1
            max_2 = torch.max(sum_2)
            add_2 = add_1 + max_2
            return add_2

    """
    # Get the graphs for `gm`, `pattern`, `replacement`
    original_graph = gm.graph
    pattern_graph = symbolic_trace(pattern).graph
    replacement_graph = symbolic_trace(replacement).graph

    # Find all possible pattern matches in original_graph. Note that
    # pattern matches may overlap with each other.
    matcher = SubgraphMatcher(pattern_graph)
    matches: List[Match] = []

    # Consider each node as an "anchor" (deepest matching graph node)
    for anchor in original_graph.nodes:

        if matcher.matches_subgraph_from_anchor(anchor):

            def pattern_is_contained(nodes_map: Dict[Node, Node]) -> bool:
                # `lookup` represents all the nodes in `original_graph`
                # that are part of `pattern`
                lookup: Dict[Node, Node] = {v: k for k, v in nodes_map.items()}
                for n in lookup.keys():
                    if n.op == "placeholder" or lookup[n].op == "output":
                        continue
                    for user in n.users:
                        # If this node has users that were not in
                        # `lookup`, then it must leak out of the
                        # pattern subgraph
                        if user not in lookup:
                            return False
                return True

            # It's not a match if the pattern leaks out into the rest
            # of the graph
            if pattern_is_contained(matcher.nodes_map):
                # Shallow copy nodes_map
                matches.append(
                    Match(anchor=anchor,
                          nodes_map=copy.copy(matcher.nodes_map)))

    # The set of all nodes in `original_graph` that we've seen thus far
    # as part of a pattern match
    replaced_nodes: Set[Node] = set()

    # Return TRUE if one of the nodes in the current match has already
    # been used as part of another match
    def overlaps_with_prev_match(match: Match) -> bool:
        for n in match.nodes_map.values():
            if n in replaced_nodes and n.op != "placeholder":
                return True
        return False

    for match in matches:

        # Skip overlapping matches
        if overlaps_with_prev_match(match):
            continue

        # Map replacement graph nodes to their copy in `original_graph`
        val_map: Dict[Node, Node] = {}

        pattern_placeholders = [
            n for n in pattern_graph.nodes if n.op == "placeholder"
        ]
        assert len(pattern_placeholders)
        replacement_placeholders = [
            n for n in replacement_graph.nodes if n.op == "placeholder"
        ]
        assert len(pattern_placeholders) == len(replacement_placeholders)
        placeholder_map = {
            r: p
            for r, p in zip(replacement_placeholders, pattern_placeholders)
        }

        # node from `original_graph` that matched with the output node
        # in `pattern`
        subgraph_output: Node = match.anchor

        def mark_node_as_replaced(n: Node) -> None:
            if n not in match.nodes_map.values():
                return
            for n_ in n.all_input_nodes:
                mark_node_as_replaced(n_)
            replaced_nodes.add(n)

        mark_node_as_replaced(subgraph_output)

        # Initialize `val_map` with mappings from placeholder nodes in
        # `replacement` to their corresponding node in `original_graph`
        for replacement_node in replacement_placeholders:
            # Get the `original_graph` placeholder node
            # corresponding to the current `replacement_node`
            pattern_node = placeholder_map[replacement_node]
            original_graph_node = match.nodes_map[pattern_node]
            # Populate `val_map`
            val_map[replacement_node] = original_graph_node

        # Copy the replacement graph over
        with original_graph.inserting_before(subgraph_output):
            copied_output = original_graph.graph_copy(replacement_graph,
                                                      val_map)
        assert isinstance(copied_output, Node)

        # We only want to copy in the output node from `pattern` if we
        # have an output-output match. Otherwise, we leave out the
        # `pattern` output node so we don't have two outputs in the
        # resultant graph
        if subgraph_output.op != "output":
            subgraph_output = subgraph_output.args[0]  # type: ignore
        subgraph_output.replace_all_uses_with(copied_output)

        # Erase the `pattern` nodes
        for node in reversed(original_graph.nodes):
            if len(node.users) == 0 and node.op != "output":
                original_graph.erase_node(node)

    # Update the passed-in GraphModule to reflect the new state of
    # `original_graph`
    gm.recompile()
Example #19
def convert(
        model: GraphModule, is_reference: bool = False,
        convert_custom_config_dict: Dict[str, Any] = None,
        is_standalone_module: bool = False,
        _remove_qconfig_flag: bool = True,
        convert_qconfig_dict: Dict[str, Any] = None,
        backend_config_dict: Optional[Dict[str, Any]] = None) -> torch.nn.Module:
    """
    We convert an observed model (a module with observer calls) to a reference
    quantized model. The rules are simple:
    1. for each observer module call in the graph, we convert it to calls to
       quantize and dequantize functions based on the observer instance
    2. for weighted operations like linear/conv, we convert them to a reference
       quantized module; this requires us to know whether the dtype configured for
       the weight is supported in the backend, which is determined in the prepare
       step and stored in observed_node_names, so we can decide whether to swap
       the module based on this set

    standalone_module means a submodule that is not inlined in the parent module
    and will be quantized separately as one unit.

    Returns a quantized standalone module; whether the input/output is quantized is
    specified by prepare_custom_config_dict, with input_quantized_idxs and
    output_quantized_idxs. Please see the docs for prepare_fx for details.
    """
    if convert_custom_config_dict is None:
        convert_custom_config_dict = {}
    node_name_to_scope, prepare_custom_config_dict, observed_node_names = restore_state(model)
    qconfig_map: Dict[str, QConfigAny] = model._qconfig_map  # type: ignore[assignment]

    # TODO: this should be removed now that GPU support for quantization is available.
    # However, in practice, as of 7/22/2021, certain functions that get called by convert expect
    # only cpu arguments.
    # As an example, in TestQuantizeFxModels.test_qat_functional_linear when device='cuda',
    # fold_weight will call quantized::linear_prepack which doesn't support QuantizedCuda backend.
    if not is_reference:
        model.cpu()

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    # We use remove_duplicate=False here because torch.cat uses
    # the same activation_post_process module instance but different names
    modules = dict(model.named_modules(remove_duplicate=False))

    # TODO refactor this code once we update the prepare logic to have additional information on
    # which graph nodes have been observed and share that with convert to decide which observers to ignore.
    if convert_qconfig_dict:
        prepare_qconfig_dict: Dict[str, Dict[Any, Any]] = model._qconfig_dict  # type: ignore[assignment]
        modules_copy = copy.deepcopy(modules)
        convert_dict_to_ordered_dict(convert_qconfig_dict)
        if model._is_qat:
            convert_qconfig_dict = update_qconfig_for_qat(convert_qconfig_dict, {})
        convert_qconfig_dict = update_qconfig_for_fusion(model, convert_qconfig_dict)

        compare_prepare_convert_qconfig_dict(prepare_qconfig_dict, convert_qconfig_dict)  # type: ignore[arg-type]
        convert_qconfig_map = generate_qconfig_map(model, modules_copy, model.graph, convert_qconfig_dict, node_name_to_scope)
        # check the convert_qconfig_map generated and ensure that all the values either match what was set in prepare qconfig_map
        # or are set to None in the convert_qconfig_map.
        for k, v in qconfig_map.items():
            assert k in convert_qconfig_map, 'Expected key {} in convert qconfig_map'.format(k)
            if convert_qconfig_map[k] is not None:
                assert qconfig_equals(v, convert_qconfig_map[k]), 'Expected k {} to have the same value in prepare qconfig_dict \
                and convert qconfig_dict, found {} updated to {}.'.format(k, v, convert_qconfig_map[k])
        qconfig_map = convert_qconfig_map

    custom_module_classes = get_custom_module_class_keys(
        convert_custom_config_dict,
        "observed_to_quantized_custom_module_class")
    custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {})

    if model._equalization_qconfig_map is not None:
        # If we want to do equalization then do the following:
        # Calculate the equalization scale, update the observers with the scaled
        # inputs, and scale the weight
        weight_eq_obs_dict = update_obs_for_equalization(model, modules)
        convert_eq_obs(model, modules, weight_eq_obs_dict)

    # always run weight observers in the top level forward method
    # for dynamic quant ops or weight only quant ops
    run_weight_observers(model)

    graph_inputs: List[str] = []
    for node in model.graph.nodes:
        if node.op == 'placeholder':
            graph_inputs.append(node.name)

    # TODO: move this outside of this function
    def replace_observer_with_quantize_dequantize_node(
            model: torch.nn.Module,
            graph: Graph,
            node: Node,
            modules: Dict[str, torch.nn.Module],
            node_name_to_scope: Dict[str, Tuple[str, type]],
            qconfig_map: Dict[str, QConfigAny]) -> None:
        """ Replace activation_post_process module call node with quantize and
        dequantize node

        Before:
        ... -> observer_0(x) -> ...
        After:
        ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...
        """
        assert modules is not None
        assert isinstance(node.target, str)
        module_path, prefix = get_module_path_and_prefix(node, node_name_to_scope, qconfig_map)
        observer_module = modules[node.target]
        maybe_quantize_node_info = get_quantize_node_info(observer_module)
        # Skip replacing observers with quant/dequant nodes if the qconfigs of all
        # consumers and producers of this observer are None
        skip_replacement = all([
            has_none_qconfig(n, qconfig_map) for n in
            list(node.args) + list(node.users.keys())])
        if skip_replacement or maybe_quantize_node_info is None:
            # didn't find a corresponding quantize op and info for the observer_module
            # so we just remove the observer
            with graph.inserting_before(node):
                node.replace_all_uses_with(node.args[0])
                graph.erase_node(node)
        else:
            # otherwise, we can convert the observer module call to a quantize/dequantize node
            node_type, quantize_op, qparams = maybe_quantize_node_info
            # replace observer node with quant - dequant node
            with graph.inserting_before(node):
                input_node = node.args[0]
                inputs = [input_node]
                for key, value in qparams.items():
                    # TODO: we can add the information of whether a value needs to
                    # be registered as an attribute in qparams dict itself
                    if key in ['_scale_', '_zero_point_']:
                        # For scale and zero_point values we register them as buffers in the root module.
                        # TODO: maybe need more complex attr name here
                        qparam_node = create_getattr_from_value(model, graph, module_path + prefix + key, value)
                        inputs.append(qparam_node)
                    else:
                        # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
                        inputs.append(value)

                quantized_node = graph.create_node(node_type, quantize_op, tuple(inputs), {})
                dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
                node.replace_all_uses_with(dequantized_node)
                graph.erase_node(node)

    # this is a temporary hack for custom module, we may want to implement
    # this properly after the custom module class design is finalized
    def replace_observer_with_dequantize_node(node: Node, graph: Graph):
        call_custom_module_node = node.args[0]
        assert isinstance(call_custom_module_node, Node), \
            f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}"
        node.replace_all_uses_with(call_custom_module_node)
        graph.erase_node(node)
        insert_dequantize_node(call_custom_module_node, graph)

    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    input_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "input_quantized_idxs", [])
    output_quantized_idxs: List[int] = prepare_custom_config_dict.get(
        "output_quantized_idxs", [])

    if backend_config_dict is None:
        backend_config_dict = get_native_backend_config_dict()
    root_module_to_quantized_reference_module = get_root_module_to_quantized_reference_module(backend_config_dict)
    # convert tuples so that it can work with isinstance(module, tuple_of_classes)
    root_module_classes = tuple(root_module_to_quantized_reference_module.keys())
    qat_module_classes = get_qat_module_classes(backend_config_dict)
    fused_module_classes = get_fused_module_classes(backend_config_dict)
    statically_quantized_custom_module_nodes: Set[Node] = set()

    for node in list(model.graph.nodes):
        if node.op == 'placeholder':
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                # Inputs are assumed to be quantized if the user specified the
                # input_quantized_idxs override.
                # We need to dequantize the inputs since all operators take
                # floating point inputs in reference quantized models.
                insert_dequantize_node(node, model.graph)
        elif node.op == "output":
            # If the argument is empty we don't need to do anything
            if len(output_quantized_idxs) == 0:
                continue
            # Results are kept quantized if the user specified the
            # output_quantized_idxs override.
            # Remove the trailing dequantize operator for the node, if any.
            return_node = node
            output = node.args[0]
            # outputs can be Node, list, tuple, dict, other cases are not supported yet
            if isinstance(output, (list, tuple)):
                for idx in output_quantized_idxs:
                    maybe_recursive_remove_dequantize(output[idx], return_node, model.graph)
            elif isinstance(output, (Node, dict)):
                # we treat dict as a single argument currently, but it can be extended
                # to support {"key": dtype} after we change output_quantized_idxs to
                # dict
                if 0 in output_quantized_idxs:
                    maybe_recursive_remove_dequantize(output, return_node, model.graph)
            else:
                warnings.warn(f"Unsupported node type for output_quantized_idxs: {type(output)}")
        elif node.op == "call_module":
            if is_activation_post_process(modules[node.target]):
                observed_node = node.args[0]
                if observed_node in statically_quantized_custom_module_nodes:
                    replace_observer_with_dequantize_node(node, model.graph)
                else:
                    replace_observer_with_quantize_dequantize_node(
                        model, model.graph, node, modules, node_name_to_scope,
                        qconfig_map)
            elif is_observed_standalone_module(modules[node.target]):
                convert_standalone_module(
                    node, modules, model, is_reference, backend_config_dict)
            elif type(modules[node.target]) in set(
                    root_module_classes).union(qat_module_classes).union(fused_module_classes):
                # extra check for fused module classes to make sure they are fused module classes
                # of target modules
                if type(modules[node.target]) in fused_module_classes and \
                   type(modules[node.target][0]) not in root_module_classes:
                    continue
                convert_weighted_module(
                    node, modules, observed_node_names, qconfig_map, backend_config_dict)
            elif type(modules[node.target]) in custom_module_classes:
                convert_custom_module(
                    node, model.graph, modules, custom_module_class_mapping,
                    statically_quantized_custom_module_nodes)

    preserved_attributes = set(convert_custom_config_dict.get("preserved_attributes", []))
    model = QuantizedGraphModule(model, copy.deepcopy(model.graph), preserved_attributes)

    # remove deadcode after converting observers to quant/dequant ops
    model.graph.eliminate_dead_code()
    model.recompile()

    # TODO: maybe move this to quantize_fx.py
    if not is_reference:
        model = duplicate_dequantize_node(model)
        model = duplicate_quantize_dynamic_node(model)
        model = lower_to_fbgemm(model, qconfig_map, node_name_to_scope)
        model = remove_quant_dequant_pairs(model)
        model = remove_extra_dequantize(model)
    # TODO: this looks hacky, we want to check why we need this and see if we can
    # remove this
    # removes qconfig and activation_post_process modules
    if _remove_qconfig_flag:
        _remove_qconfig(model)
    return model
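
# Hedged usage sketch via the public FX-graph-mode entry points that drive a convert
# step like the one above. The qconfig_dict-style API shown here matches the vintage
# of this code; newer releases use a QConfigMapping and require example_inputs in
# prepare_fx, so adjust accordingly for your torch version.
import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

float_model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}

prepared = prepare_fx(float_model, qconfig_dict)   # insert observers
prepared(torch.randn(8, 4))                        # calibrate
quantized = convert_fx(prepared)                   # observers -> quantize/dequantize + quantized modules
print(quantized(torch.randn(2, 4)).shape)
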