Example 1
    def test_nvfuser_call_module_backend(self, device, dtype):
        class Model(torch.nn.Module):
            def __init__(self):
                super(Model, self).__init__()
                self.bn = torch.nn.BatchNorm2d(3)
                self.relu = torch.nn.ReLU()

            def forward(self, inp):
                o = self.bn(inp)
                o = self.relu(o)
                return o

        inp = torch.randn(2, 3, 4, 5).to(dtype=dtype, device=device)
        m = Model().to(dtype=dtype, device=device)

        # Note that the traced module here contains only `call_module` nodes,
        # which aren't fused by the nvfuser backend. `nvfuser.compile` should
        # still run without error regardless.
        traced = symbolic_trace(m)

        nvfuser = NvFuserBackend()
        compiled_module = nvfuser.compile(traced)

        eager_result = m(inp)
        nvfuser_result = compiled_module(inp)

        torch.testing.assert_close(eager_result,
                                   nvfuser_result,
                                   rtol=1e-5,
                                   atol=1e-5)
Example 2
    def test_partitioner_xfail(self, fn, expected_partition):
        traced = symbolic_trace(fn)

        supported_ops = MockOperatorSupport()
        partitioner = CapabilityBasedPartitioner(traced, supported_ops, allows_single_node_partition=True)
        partitions = partitioner.propose_partitions()

        partitions_name = [[node.name for node in partition.nodes] for partition in partitions]
        with self.assertRaises(Exception):
            assert len(partitions_name) == len(expected_partition)
Example 3
    def test_fuser_util_xfail(self, partition):
        m = TestModule()
        gm = symbolic_trace(m)

        nodes_by_name = {node.name: node for node in gm.graph.nodes}

        partitions = []
        for node_names in partition:
            partitions.append([nodes_by_name[name] for name in node_names])

        with self.assertRaises(Exception):
            fuse_by_partitions(gm, partitions)
Example 4
    def test_subgraph_matcher(self, test_model):
        traced = symbolic_trace(test_model.forward)
        pattern_traced = symbolic_trace(test_model.pattern)

        for test_case in test_model.test_cases:

            matcher = SubgraphMatcher(
                pattern_traced.graph,
                match_output=test_case.match_output,
                match_placeholder=test_case.match_placeholder,
                remove_overlapping_matches=test_case.remove_overlapping_matches
            )
            matches = matcher.match(traced.graph)

            assert len(matches) == test_case.num_matches

            for match in matches:
                for node in pattern_traced.graph.nodes:
                    if not test_case.match_placeholder and node.op == "placeholder":
                        continue
                    if not test_case.match_output and node.op == "output":
                        continue
                    assert node in match.nodes_map
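
For context, a minimal sketch of the kind of model/pattern fixture these test cases iterate over (the names and fixture structure here are hypothetical; the real fixtures live in the test file's parameterization):

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

class WithTwoRelus(torch.nn.Module):  # hypothetical fixture
    def forward(self, x):
        return torch.relu(x) + torch.relu(-x)

def relu_pattern(x):  # hypothetical pattern
    return torch.relu(x)

traced = symbolic_trace(WithTwoRelus())
pattern_traced = symbolic_trace(relu_pattern)

# Ignore the pattern's placeholder/output nodes and match anywhere in the graph
matcher = SubgraphMatcher(pattern_traced.graph,
                          match_output=False,
                          match_placeholder=False,
                          remove_overlapping_matches=True)
matches = matcher.match(traced.graph)
assert len(matches) == 2  # one match per torch.relu call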
Example 5
    def test_nvfuser_backend(self, device, dtype):
        m = HF_T5_Partial()
        m.to(device)

        traced = symbolic_trace(m)

        nvfuser = NvFuserBackend()
        compiled_module = nvfuser.compile(traced)

        inputs = self._generate_random_inputs(device, m.inputs_meta())

        eager_result = m(*inputs)
        nvfuser_result = compiled_module(*inputs)

        torch.testing.assert_close(eager_result, nvfuser_result, rtol=1e-5, atol=1e-5)
Example 6
    def test_fuser_util(self, partition):
        m = TestModule()
        gm = symbolic_trace(m)

        nodes_by_name = {node.name: node for node in gm.graph.nodes}

        partitions = []
        for node_names in partition:
            partitions.append([nodes_by_name[name] for name in node_names])

        fused_graph = fuse_by_partitions(gm, partitions)

        a, b, c = torch.rand(4), torch.rand(4), torch.rand(4)

        expected = m(a, b, c)
        result = fused_graph(a, b, c)

        torch.testing.assert_close(expected, result)
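
The `TestModule` used by the fuser tests is defined elsewhere in the test file; a representative sketch (hypothetical structure, but consistent with the three-tensor call signature used above) might be:

import torch

class TestModule(torch.nn.Module):  # hypothetical sketch of the real fixture
    def forward(self, a, b, c):
        # Node names such as `add`, `add_1`, `relu_1` are what the
        # partition lists in these tests refer to.
        add = a + b
        add_1 = add + c
        relu_1 = add_1.relu()
        return relu_1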
Example 7
    def test_partitioner(self, fn, expected_partition):
        traced = symbolic_trace(fn)

        supported_ops = MockOperatorSupport()
        partitioner = CapabilityBasedPartitioner(traced, supported_ops, allows_single_node_partition=True)
        partitions = partitioner.propose_partitions()

        partitions_name = [[node.name for node in partition.nodes] for partition in partitions]
        assert len(partitions_name) == len(expected_partition)
        for i in range(len(partitions_name)):
            assert set(partitions_name[i]) == set(expected_partition[i])

        fused_graph = partitioner.fuse_partitions(partitions)

        a, b, c = torch.rand(4), torch.rand(4), torch.rand(4)

        expected = fn(a, b, c)
        result = fused_graph(a, b, c)
        torch.testing.assert_close(expected, result)
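
`MockOperatorSupport` is likewise defined elsewhere in the test file. A minimal sketch of such an `OperatorSupport` subclass, assuming the standard `torch.fx.passes.operator_support` API (the op allowlist here is hypothetical):

import operator
import torch
from torch.fx.passes.operator_support import OperatorSupport

class MockOperatorSupport(OperatorSupport):  # hypothetical sketch
    def is_node_supported(self, submodules, node):
        # Mark a small allowlist of pointwise ops as "supported"; the
        # CapabilityBasedPartitioner then proposes maximal partitions
        # made up entirely of supported nodes.
        return (node.op == "call_function"
                and node.target in {operator.add, torch.add, torch.relu})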
Example 8
def replace_pattern(
        gm: GraphModule,
        pattern: Callable,
        replacement: Callable,
        is_match_filters: Optional[List[Callable]] = None) -> List[Match]:
    """
    Matches all possible non-overlapping sets of operators and their
    data dependencies (``pattern``) in the Graph of a GraphModule
    (``gm``), then replaces each of these matched subgraphs with another
    subgraph (``replacement``).

    Args:
        ``gm``: The GraphModule that wraps the Graph to operate on
        ``pattern``: The subgraph to match in ``gm`` for replacement
        ``replacement``: The subgraph to replace ``pattern`` with
        ``is_match_filters``: An optional list of callables, each taking
            ``(match, pattern_graph, replacement_graph)`` and returning a
            bool. A candidate match is replaced only if every filter
            returns ``True``

    Returns:
        List[Match]: A list of ``Match`` objects representing the places
        in the original graph that ``pattern`` was matched to. The list
        is empty if there are no matches. ``Match`` is defined as:

        .. code-block:: python

            class Match(NamedTuple):
                # Node from which the match was found
                anchor: Node
                # Maps nodes in the pattern subgraph to nodes in the larger graph
                nodes_map: Dict[Node, Node]

    Examples:

    .. code-block:: python

        import torch
        from torch.fx import symbolic_trace, subgraph_rewriter

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, w1, w2):
                m1 = torch.cat([w1, w2]).sum()
                m2 = torch.cat([w1, w2]).sum()
                return x + torch.max(m1) + torch.max(m2)

        def pattern(w1, w2):
            return torch.cat([w1, w2]).sum()

        def replacement(w1, w2):
            return torch.stack([w1, w2])

        traced_module = symbolic_trace(M())

        subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

    The above code will first match ``pattern`` in the ``forward``
    method of ``traced_module``. Pattern-matching is done based on
    use-def relationships, not node names. For example, if you had
    ``p = torch.cat([a, b])`` in ``pattern``, you could match
    ``m = torch.cat([a, b])`` in the original ``forward`` function,
    despite the variable names being different (``p`` vs ``m``).

    The ``return`` statement in ``pattern`` is matched based on its
    value only; it may or may not match the ``return`` statement in
    the larger graph. In other words, the pattern doesn't have to extend
    to the end of the larger graph.

    When the pattern is matched, it will be removed from the larger
    function and replaced by ``replacement``. If there are multiple
    matches for ``pattern`` in the larger function, each non-overlapping
    match will be replaced. In the case of a match overlap, the first
    found match in the set of overlapping matches will be replaced.
    ("First" here being defined as the first in a topological ordering
    of the Nodes' use-def relationships. In most cases, the first Node
    is the parameter that appears directly after ``self``, while the
    last Node is whatever the function returns.)

    One important thing to note is that the parameters of the
    ``pattern`` Callable must be used in the Callable itself,
    and the parameters of the ``replacement`` Callable must match
    the pattern. The first rule is why, in the above code block, the
    ``forward`` function has parameters ``x, w1, w2``, but the
    ``pattern`` function only has parameters ``w1, w2``. ``pattern``
    doesn't use ``x``, so it shouldn't specify ``x`` as a parameter.
    As an example of the second rule, consider replacing

    .. code-block:: python

        def pattern(x, y):
            return torch.neg(x) + torch.relu(y)

    with

    .. code-block:: python

        def replacement(x, y):
            return torch.relu(x)

    In this case, ``replacement`` needs the same number of parameters
    as ``pattern`` (both ``x`` and ``y``), even though the parameter
    ``y`` isn't used in ``replacement``.

    After calling ``subgraph_rewriter.replace_pattern``, the generated
    Python code looks like this:

    .. code-block:: python

        def forward(self, x, w1, w2):
            stack_1 = torch.stack([w1, w2])
            sum_1 = stack_1.sum()
            stack_2 = torch.stack([w1, w2])
            sum_2 = stack_2.sum()
            max_1 = torch.max(sum_1)
            add_1 = x + max_1
            max_2 = torch.max(sum_2)
            add_2 = add_1 + max_2
            return add_2
    """
    # Get the module and graph for `gm`, `pattern`, `replacement`
    original_module = gm
    original_graph = original_module.graph
    pattern_module = symbolic_trace(pattern)
    pattern_graph = pattern_module.graph
    replacement_module = symbolic_trace(replacement)
    replacement_graph = replacement_module.graph

    # Find all possible pattern matches in original_graph. Note that
    # pattern matches may overlap with each other.
    matcher = _SubgraphMatcher(pattern_graph)
    matches: List[Match] = []

    # Consider each node as an "anchor" (deepest matching graph node)
    for anchor in original_graph.nodes:

        if matcher.matches_subgraph_from_anchor(anchor, original_module,
                                                pattern_module):

            def pattern_is_contained(nodes_map: Dict[Node, Node]) -> bool:
                # `lookup` represents all the nodes in `original_graph`
                # that are part of `pattern`
                lookup: Dict[Node, Node] = {v: k for k, v in nodes_map.items()}
                for n in lookup.keys():

                    # Nodes that can "leak"...
                    if not isinstance(lookup[n], Node):
                        continue
                    # Placeholders (by definition)
                    if lookup[n].op == "placeholder":
                        continue
                    # Pattern output (acts as a container)
                    if lookup[n].op == "output":
                        continue
                    # Result contained by pattern output (what we'll
                    # hook in to the new Graph, thus what we'll
                    # potentially use in other areas of the Graph as
                    # an input Node)
                    if (len(lookup[n].users) == 1
                            and next(iter(lookup[n].users)).op == "output"):
                        continue

                    if not isinstance(n, Node):
                        continue

                    for user in n.users:
                        # If this node has users that were not in
                        # `lookup`, then it must leak out of the
                        # pattern subgraph
                        if user not in lookup:
                            return False
                return True

            # It's not a match if the pattern leaks out into the rest
            # of the graph
            if pattern_is_contained(matcher.nodes_map):
                # Shallow copy nodes_map
                matches.append(
                    Match(anchor=anchor, nodes_map=dict(matcher.nodes_map)))

    # The set of all nodes in `original_graph` that we've seen thus far
    # as part of a pattern match
    replaced_nodes: Set[Node] = set()
    # As we progressively replace nodes, we'll need to keep track of how the match results should change
    match_changed_node: Dict[Node, Node] = dict()

    # Return True if one of the nodes in the current match has already
    # been used as part of another match
    def overlaps_with_prev_match(match: Match) -> bool:
        for pn, gn in match.nodes_map.items():
            if not isinstance(pn, Node):
                continue
            if pn.op in ["placeholder", "output"]:
                continue
            if not isinstance(gn, Node):
                continue
            if gn in replaced_nodes and gn.op != "placeholder":
                return True
        return False

    if is_match_filters is None:
        is_match_filters = []

    def is_match(match: Match):
        # for mypy
        assert is_match_filters is not None
        for match_filter in is_match_filters:
            if not match_filter(match, pattern_graph, replacement_graph):
                return False
        return True

    for match in matches:
        # Skip overlapping matches
        if overlaps_with_prev_match(match):
            continue

        if not is_match(match):
            continue

        # Map replacement graph nodes to their copy in `original_graph`
        val_map: Dict[Node, Node] = {}

        pattern_placeholders = [
            n for n in pattern_graph.nodes if n.op == "placeholder"
        ]
        assert len(pattern_placeholders) > 0
        replacement_placeholders = [
            n for n in replacement_graph.nodes if n.op == "placeholder"
        ]
        assert len(pattern_placeholders) == len(replacement_placeholders)
        placeholder_map = {
            r: p
            for r, p in zip(replacement_placeholders, pattern_placeholders)
        }

        # node from `original_graph` that matched with the output node
        # in `pattern`
        subgraph_output: Node = match.anchor

        def mark_node_as_replaced(n: Node) -> None:
            if n not in match.nodes_map.values():
                return
            for n_ in n.all_input_nodes:
                mark_node_as_replaced(n_)
            replaced_nodes.add(n)

        for input_node in subgraph_output.all_input_nodes:
            mark_node_as_replaced(input_node)

        # Initialize `val_map` with mappings from placeholder nodes in
        # `replacement` to their corresponding node in `original_graph`
        for replacement_node in replacement_placeholders:
            # Get the `original_graph` placeholder node
            # corresponding to the current `replacement_node`
            pattern_node = placeholder_map[replacement_node]
            original_graph_node = match_changed_node.get(
                match.nodes_map[pattern_node], match.nodes_map[pattern_node])

            # Populate `val_map`
            val_map[replacement_node] = original_graph_node

        # Copy the replacement graph over
        with original_graph.inserting_before(subgraph_output):
            copied_output = original_graph.graph_copy(replacement_graph,
                                                      val_map)

        # Hook the output Node of the replacement subgraph in to the
        # original Graph at the correct location

        # CASE 1: We need to hook the replacement subgraph in somewhere
        # in the middle of the graph. We replace the Node in the
        # original graph that corresponds to the end of the pattern
        # subgraph
        if subgraph_output.op != "output":
            pattern_outputs = [
                n for n in pattern_graph.nodes if n.op == "output"
            ]
            assert len(pattern_outputs) > 0
            replacement_outputs = [
                n for n in replacement_graph.nodes if n.op == "output"
            ]
            assert len(replacement_outputs) == len(pattern_outputs)
            outputs_map = {
                p: r
                for r, p in zip(replacement_outputs, pattern_outputs)
            }

            for pn, gn in match.nodes_map.items():
                if not isinstance(gn, Node):
                    continue
                if gn.op == "placeholder":
                    continue

                # Search for the node corresponding to the output of the pattern
                if pn.op != "output":
                    continue
                assert subgraph_output == gn

                # Update all anchor inputs to the new nodes
                rn = outputs_map[pn]
                for pn_input, rn_input in zip(pn.args, rn.args):
                    gn_input = match.nodes_map[pn_input]  # type: ignore[index]
                    rn_input_in_original_graph = val_map[rn_input]
                    gn_input.replace_all_uses_with(rn_input_in_original_graph)
                    # We store the updated node point in case other nodes want to use it
                    match_changed_node[gn_input] = rn_input_in_original_graph

            assert subgraph_output.op != "output"
        # CASE 2: The pattern subgraph match extends to the end of the
        # original graph, so we need to change the current graph's
        # output Node to reflect the insertion of the replacement graph.
        # We'll keep the current output Node, but update its args and
        # `_input_nodes` as necessary
        else:
            subgraph_output.args = (copied_output,)
            if isinstance(copied_output, Node):
                subgraph_output._input_nodes = {copied_output: None}

        assert isinstance(copied_output, Node)
        # Erase the `pattern` nodes
        for node in reversed(original_graph.nodes):
            if (len(node.users) == 0 and node.op != "output"
                    and node.op != "placeholder"):
                original_graph.erase_node(node)

    # Update the passed-in GraphModule to reflect the new state of
    # `original_graph`
    gm.recompile()

    # If `replacement` was an nn.Module, we'll need to make sure that
    # all the submodules have been copied over correctly
    if isinstance(replacement, torch.nn.Module):
        _replace_submodules(gm, replacement)

    return matches
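
The ``is_match_filters`` argument is not exercised in the docstring example above. A minimal usage sketch (the filter logic here is hypothetical), following the ``filter(match, pattern_graph, replacement_graph)`` calling convention used by ``is_match``:

from torch.fx import Node, subgraph_rewriter

def only_full_matches(match, pattern_graph, replacement_graph):
    # Hypothetical filter: accept a match only if every pattern node was
    # mapped to an actual Node in the original graph.
    return all(isinstance(gn, Node) for gn in match.nodes_map.values())

subgraph_rewriter.replace_pattern(traced_module, pattern, replacement,
                                  is_match_filters=[only_full_matches])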
Example 9
def merge_matmul(in_mod: torch.nn.Module):
    """
    A graph transformation that merges matrix multiplication operations that share the same right-hand
    side operand into one large matrix multiplication.
               ____      _________        _________
      ----    |    |    |         |     M|  A * C  |
    M| A  |  T| B  | * K|    C    | =    |---------|
      ---- ,  |    |    |         |     T|  B * C  |
       K       ----      ---------        ---------
                K            R                R
    """
    gm = symbolic_trace(in_mod)

    rhs_users: Dict[Node, List[Node]] = {}
    lhs_users: Dict[Node, List[Node]] = {}

    # Populate rhs_users and lhs_users - maps from LHS/RHS matrix multiply
    # operands to the matmuls of which they are the LHS/RHS.
    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target is not torch.matmul:
            continue

        lhs, rhs = node.args

        # TODO: Properly handle aliasing caused by get_attr. For now,
        # use the attribute name as the operand if the node is a
        # get_attr.
        lhs = lhs.target if lhs.op == "get_attr" else lhs
        rhs = rhs.target if rhs.op == "get_attr" else rhs

        lhs_users.setdefault(lhs, []).append(node)
        rhs_users.setdefault(rhs, []).append(node)

    for rhs, mms in rhs_users.items():
        # There must be at least two matmuls for a merge to make sense.
        if len(mms) < 2:
            continue

        # All matmuls must not depend on each other directly or indirectly
        # in order for the merge to be possible.
        if not are_nodes_independent(mms):
            continue

        lhs_vals = [mm.args[0] for mm in mms]

        # Merge the matmul.
        # Collect a list of LHS operands and the single RHS operand.
        lhs = [gm.graph.get_attr(l) if isinstance(l, str) else l for l in lhs_vals]
        rhs = gm.graph.get_attr(rhs) if isinstance(rhs, str) else rhs

        # Concatenate all the LHS operands.
        merge_mm_cat = gm.graph.call_function(torch.cat, (lhs,), {})

        # Multiply the concatenated LHS operands with the one RHS. This will produce
        # the same results as all the individual matmuls involving rhs in the original graph,
        # but they will all be concatenated together.
        merge_mm = gm.graph.call_function(torch.matmul, (merge_mm_cat, rhs,), {})

        # Split the result of the merged matmul using the shapes of the LHS operands
        # to ascertain how large each chunk should be.
        merge_mm_sizes = [
            gm.graph.call_function(get_first_dim, (l,), {}) for l in lhs
        ]
        merge_mm_split = gm.graph.call_function(
            torch.split, (merge_mm, merge_mm_sizes), {}
        )
        merge_mm_res = [
            gm.graph.call_function(operator.getitem, (merge_mm_split, out), {})
            for out in range(len(lhs))
        ]

        # Replace all uses of the original, unmerged matmuls with the equivalent split chunk from the merged matmul.
        for old, new in zip(mms, merge_mm_res):
            old.replace_all_uses_with(new)
            gm.graph.erase_node(old)

        # All of the new nodes created above were inserted at the end, so we need to sort
        # the nodes topologically to make sure all definitions precede uses.
        legalize_graph(gm)

    gm.recompile()
    gm.graph.lint()
    return gm
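
A minimal usage sketch for merge_matmul, assuming a module whose two matmuls share the same right-hand-side parameter (the module and shapes are hypothetical):

import torch

class TwoMatmuls(torch.nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.rhs = torch.nn.Parameter(torch.rand(3, 4))

    def forward(self, a, b):
        # Both matmuls share `self.rhs`, so merge_matmul can rewrite them
        # into a single cat -> matmul -> split sequence.
        return torch.matmul(a, self.rhs), torch.matmul(b, self.rhs)

merged = merge_matmul(TwoMatmuls())
a, b = torch.rand(2, 3), torch.rand(5, 3)
out_a, out_b = merged(a, b)  # numerically identical to the unmerged module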