Example #1
def decompose(model: torch.nn.Module, example_inputs) -> torch.nn.Module:
    """
    decompose(model, example_inputs) takes in a model, decomposes any of the functions in `decomposition_rules` to its constituent operations, and returns a `nn.Module` without any of the operations with decomposition rules.
    """
    # Run it multiple times so we converge to a fixed point.
    for _ in range(5):
        model = fx.symbolic_trace(model)
        ShapeProp(model).propagate(*example_inputs)
        new_graph = fx.Graph()
        env = {}
        for node in model.graph.nodes:
            if node.op == 'call_function' and node.target in decomposition_rules:
                # If the current function is in `decomposition_rules`, we use
                # `Proxy` objects to decompose the operations using the
                # decomposition rule. See
                # https://pytorch.org/docs/master/fx.html#proxy-retracing for
                # more details.
                proxy_args = map_arg(node.args,
                                     lambda n: fx.Proxy(env[n.name]))
                proxy_kwargs = map_arg(node.kwargs,
                                       lambda n: fx.Proxy(env[n.name]))
                new_node = decomposition_rules[node.target](
                    *proxy_args, **proxy_kwargs).node
                env[node.name] = new_node
            else:
                new_node = new_graph.node_copy(node, lambda x: env[x.name])
                env[node.name] = new_node
        model = fx.GraphModule(model, new_graph)
    return model
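
A minimal usage sketch for the decompose pass above. The `addmm_decomposition` rule, the `decomposition_rules` table, and the toy module are illustrative assumptions, not part of the original example.

import torch
import torch.fx as fx

# Hypothetical rule: rewrite torch.addmm into mm + add (assumes the default
# alpha/beta of 1; purely illustrative).
def addmm_decomposition(input, mat1, mat2):
    return torch.mm(mat1, mat2) + input

decomposition_rules = {torch.addmm: addmm_decomposition}

class AddmmModule(torch.nn.Module):
    def forward(self, bias, x, w):
        return torch.addmm(bias, x, w)

example_inputs = (torch.randn(4, 4), torch.randn(4, 8), torch.randn(8, 4))
decomposed = decompose(AddmmModule(), example_inputs)
# The resulting graph contains mm/add nodes instead of addmm.
print(decomposed.code)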
Example #2
    def use_mkl_heuristic(graph: MklSubgraph) -> bool:
        nonlocal fx_model, old_modules
        input_nodes = graph.start_nodes
        if fx_model is None:
            fx_model = graph.fx_graph.owning_module
            old_modules = graph.fx_graph.old_modules  # type: ignore[attr-defined]
            ShapeProp(fx_model).propagate(example_inputs)
        sample_inputs = [torch.randn(node.shape)
                         for node in input_nodes]  # type: ignore[attr-defined]
        output_args = cast(List[fx.Node],
                           [node.args[0] for node in graph.end_nodes])
        submodule = extract_subgraph(fx_model, graph.nodes, input_nodes,
                                     output_args)

        def benchmark(f):
            for _ in range(warmup):
                f()
            begin = time.time()
            for _ in range(iters):
                out = f()
            return time.time() - begin

        mkl_time = benchmark(lambda: [
            i.to_dense()
            for i in submodule(*[i.to_mkldnn() for i in sample_inputs])
        ])

        reset_modules(submodule.graph.nodes, dict(submodule.named_modules()),
                      old_modules)
        no_mkl_time = benchmark(lambda: submodule(*sample_inputs))
        return mkl_time < no_mkl_time
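
The same measurement idea, sketched standalone with the public MKL-DNN helpers: time a module once through the oneDNN layout (tensor.to_mkldnn() / .to_dense(), torch.utils.mkldnn.to_mkldnn) and once eagerly, and keep whichever is faster. The Conv2d module, shapes, and iteration counts are assumptions, and a PyTorch build with MKL-DNN support is required.

import time
import torch
import torch.utils.mkldnn

def bench(f, warmup=3, iters=10):
    for _ in range(warmup):
        f()
    begin = time.time()
    for _ in range(iters):
        f()
    return time.time() - begin

conv = torch.nn.Conv2d(64, 64, 3, padding=1).eval()
x = torch.randn(1, 64, 56, 56)

with torch.no_grad():
    mkl_conv = torch.utils.mkldnn.to_mkldnn(conv)
    mkl_time = bench(lambda: mkl_conv(x.to_mkldnn()).to_dense())
    eager_time = bench(lambda: conv(x))
print("use mkldnn:", mkl_time < eager_time)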
Example #3
    def test_resnet50(self):
        gm_run = symbolic_trace(resnet50())
        sample_input = torch.randn(1, 3, 224, 224)

        # run our nodes
        ShapeProp(gm_run).propagate(sample_input)

        gm_static = symbolic_trace(resnet50())

        for n in gm_static.graph.nodes:
            n.type = None

        g = GraphTypeChecker({}, gm_static)
        g.type_check()
        # here we are checking for consistency with fully dynamic nodes
        for n1, n2 in zip(gm_static.graph.nodes, gm_run.graph.nodes):
            assert is_consistent(n1.type,
                                 TensorType(n2.meta['tensor_meta'].shape))

        # here we give the same input as at runtime
        gm_static_with_types = symbolic_trace(resnet50())

        # we initialize our placeholder
        for n in gm_static_with_types.graph.nodes:
            if n.op == 'placeholder':
                n.type = TensorType((1, 3, 224, 224))

        g = GraphTypeChecker({}, gm_static_with_types)
        g.type_check()
        for n1, n2 in zip(gm_static_with_types.graph.nodes,
                          gm_run.graph.nodes):
            assert n1.type == TensorType(n2.meta['tensor_meta'].shape)
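
For reference, a small sketch of the runtime side this test compares against: after ShapeProp runs, each node carries a tensor_meta record with the observed shape and dtype. The toy module is an assumption.

import torch
import torch.fx as fx
from torch.fx.passes.shape_prop import ShapeProp

class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 5)

    def forward(self, x):
        return torch.relu(self.linear(x))

gm = fx.symbolic_trace(Tiny())
ShapeProp(gm).propagate(torch.randn(2, 3))
for node in gm.graph.nodes:
    tm = node.meta.get('tensor_meta')
    if tm is not None:
        print(node.name, tuple(tm.shape), tm.dtype)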
Example #4
    def _test_const_fold_tensor_meta(self, requires_grad):
        """
        Verify tensor_meta is handled correctly.
        """

        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.attr_1 = torch.nn.Parameter(torch.tensor([[-0.9]]), requires_grad)
                self.attr_2 = torch.nn.Parameter(torch.tensor([[17.1]]), requires_grad)

            def forward(self, x, y):
                a = self.attr_1 + self.attr_1
                x = x - a
                return x * y + self.attr_2

        mod = ConstFoldTestModule()
        gm = torch.fx.symbolic_trace(mod)
        in_x, in_y = torch.tensor([[-0.45]]), torch.tensor([0.9])
        ShapeProp(gm).propagate(in_x, in_y)
        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(gm)
        self._verify_const_fold_mod(mod_folded)

        mod_folded.run_folding()

        for n in mod_folded.graph.nodes:
            if n.op == "get_attr":
                attr = self._get_attr(n)
                self.assertEqual(_extract_tensor_metadata(attr), n.meta["tensor_meta"])

        # Now run both folded and non-folded to check results equal.
        base_result = mod(in_x, in_y)
        fold_result = mod_folded(in_x, in_y)
        self.assertTrue(torch.equal(fold_result, base_result))
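
A minimal sketch of the constant-folding entry point on its own, outside the test harness; the module below is illustrative. split_const_subgraphs extracts the parameter-only subgraph (here self.w + self.w) so it can be evaluated once and its result reused at runtime.

import torch
import torch.fx
from torch.fx.experimental import const_fold

class FoldMe(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        # `self.w + self.w` does not depend on `x`, so it can be folded.
        return x + (self.w + self.w)

mod = FoldMe()
folded = const_fold.split_const_subgraphs(torch.fx.symbolic_trace(mod))
folded.run_folding()
x = torch.randn(4, 4)
torch.testing.assert_close(folded(x), mod(x))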
Example #5
def vmap(model: torch.nn.Module, in_axes: Tuple[Optional[int], ...],
         example_args: Tuple[Any, ...]) -> torch.nn.Module:
    """vmap
    Given a model with inputs, vmap will return a function that works on
    batched versions of those inputs. Which inputs will be batched is
    determined by in_axes. In addition, as vmap requires shape (actually
    rank) information, we will pass in example_args (example inputs for the
    original module).
    """
    in_axes = iter(in_axes)
    fx_model = fx.symbolic_trace(model)
    # Here we run a shape propagation pass in order to annotate the graph with shape information.
    ShapeProp(fx_model).propagate(*example_args)
    # As vmap rewrites the whole graph, it's easiest to create an entirely new
    # graph and append to that.
    new_graph: fx.Graph = fx.Graph()

    # We create an environment that maps the names of the original nodes to
    # the newly created nodes.
    def lookup_env(l):
        return fx.node.map_aggregate(
            l, lambda x: env[x.name] if isinstance(x, fx.Node) else x)

    env = {}
    for node in fx_model.graph.nodes:
        if node.op == 'placeholder':
            # If the node is an input placeholder, we simply copy it over and
            # annotate it with the batch dimension from `in_axes`.
            new_node = new_graph.placeholder(node.name)
            new_node.bdim = next(in_axes)
            new_node.meta = node.meta
            env[node.name] = new_node
        elif node.op == 'output':
            new_graph.output(env[node.args[0].name])
        elif node.op == 'call_function':
            new_args = lookup_env(node.args)
            # If any of the inputs to the function has a new batch dimension,
            # we will need to use our batching rules. Otherwise, we will simply
            # copy the node over.
            if any(x.bdim is not None for x in new_args
                   if isinstance(x, fx.Node)):
                new_node = gen_batching_rule_function(node.target, *new_args)
            else:
                new_node = new_graph.node_copy(node, lambda x: env[x.name])
                new_node.bdim = None
            new_node.meta = node.meta
            env[node.name] = new_node
        else:
            raise RuntimeError("Not yet implemented")

    res = fx.GraphModule(fx_model, new_graph)
    print(res.code)
    res.graph.lint()
    return res
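
A hedged usage sketch for the vmap above; it assumes the batching rules used by gen_batching_rule_function (defined elsewhere in the original file) cover torch.sin and operator.mul, and the function and shapes are illustrative.

def pointwise(x, y):
    return torch.sin(x) * y

# Batch only the first argument along dim 0; `y` stays unbatched.
batched = vmap(pointwise,
               in_axes=(0, None),
               example_args=(torch.randn(3), torch.randn(3)))
out = batched(torch.randn(8, 3), torch.randn(3))  # expected shape: (8, 3)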
Example #6
def grad(model: torch.nn.Module,
         example_inps: Tuple[Any, ...],
         get_value=True) -> torch.nn.Module:
    fx_model = fx.symbolic_trace(model)
    ShapeProp(fx_model).propagate(*example_inps)
    # As grad rewrites the whole graph, it's easiest to create an entirely new
    # graph and append to that.
    val_map = {}
    new_graph: fx.Graph = fx.Graph()
    orig_output = new_graph.graph_copy(fx_model.graph, val_map)

    def shape_proxy(node):
        proxy = fx.Proxy(val_map[node])
        proxy.shape = node.meta['shape']
        proxy.dim = lambda: len(proxy.shape)
        return proxy

    inputs = []
    ones = new_graph.create_node('call_function', torch.ones, ([], ))

    for node in reversed(fx_model.graph.nodes):
        if node.op == 'output':
            assert (len(node.args) == 1)
            val_map[node.args[0]].grad = [fx.Proxy(ones)]
        elif node.op == 'placeholder':
            inputs.append(sum(val_map[node].grad).node)
        elif node.op == 'call_function':
            g = sum(val_map[node].grad)
            new_args = [
                shape_proxy(i) if isinstance(i, fx.Node) else i
                for i in node.args
            ]
            if node.target not in vjp_map:
                raise RuntimeError("vjp not yet implemented")
            new_grads = vjp_map[node.target](g, *new_args)
            if not isinstance(new_grads, tuple):
                new_grads = (new_grads, )
            for new_g, arg in zip(new_grads, new_args):
                if isinstance(arg, fx.Proxy):
                    if not hasattr(arg.node, 'grad'):
                        arg.node.grad = []
                    arg.node.grad.append(new_g)
        elif node.op == 'call_method':
            raise RuntimeError("doesn't support methods since i'm lazy")

    if len(inputs) == 1:
        inputs = inputs[0]
    else:
        inputs = inputs[::-1]
    if get_value:
        new_graph.output((orig_output, inputs))
    else:
        new_graph.output(inputs)
    res = fx.GraphModule(fx_model, new_graph)
    res.graph.lint()
    return res
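
A hedged usage sketch for the grad above; it assumes vjp_map (defined elsewhere in the original file) has an entry for torch.sin, and the function is illustrative.

def f(x):
    return torch.sin(x)

grad_f = grad(f, (torch.randn(3),))
x = torch.randn(3)
value, dx = grad_f(x)  # get_value=True (the default) returns (output, gradient)
# dx should match torch.cos(x), the vjp of sin against a broadcast ones.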
Example #7
def get_size_of_all_nodes(fx_module: GraphModule, args: List[torch.Tensor]) -> None:
    """Given a fx graph module, update each node with its total size (weights + bias + output)
    and its output_size(output). For a non-module node, the total size is the output size.
    return total size"""
    # Mark shape and dtype for each node (node.shape and node.dtype)
    ShapeProp(fx_module).propagate(*args)
    # Calculate the total size of the whole fx graph
    total_size_of_graph = 0.0
    for node in fx_module.graph.nodes:
        if node.op == "output":
            break
        node.size_bytes = get_size_of_node(fx_module, node)
    return
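
A usage sketch for get_size_of_all_nodes; it assumes the get_size_of_node helper from the same module is available, the toy module is illustrative, and the output_size / total_size field names on the attached size_bytes record are stated here as an assumption.

import torch
import torch.fx

class Small(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return torch.relu(self.linear(x))

gm = torch.fx.symbolic_trace(Small())
get_size_of_all_nodes(gm, [torch.randn(4, 16)])
for node in gm.graph.nodes:
    if hasattr(node, 'size_bytes'):
        print(node.name, node.size_bytes.total_size, node.size_bytes.output_size)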
Example #8
    def test_resnet50(self):
        gm_run = symbolic_trace(resnet50())
        sample_input = torch.randn(1, 3, 224, 224)

        # run our nodes
        ShapeProp(gm_run).propagate(sample_input)

        gm_static = symbolic_trace(resnet50())

        for n in gm_static.graph.nodes:
            n.type = None

        g = GraphTypeChecker({}, gm_static)
        g.type_check()
        gm_static.graph.eliminate_dead_code()
        gm_run.graph.eliminate_dead_code()
        # here we are checking for consistency with fully dynamic nodes
        for n1, n2 in zip(gm_static.graph.nodes, gm_run.graph.nodes):
            assert is_consistent(n1.type,
                                 TensorType(n2.meta['tensor_meta'].shape))

        # here we give the same input as at runtime
        gm_static_with_types = symbolic_trace(resnet50())

        # we initialize our placeholder
        for n in gm_static_with_types.graph.nodes:
            if n.op == 'placeholder':
                n.type = TensorType((1, 3, 224, 224))

        g = GraphTypeChecker({}, gm_static_with_types)
        g.type_check()
        for n1, n2 in zip(gm_static_with_types.graph.nodes,
                          gm_run.graph.nodes):
            assert n1.type == TensorType(n2.meta['tensor_meta'].shape)

        # apply shape inference to graph and check
        # that the batch size is equal across all layers
        infer_symbolic_types(gm_static)

        batch_sizes = set()
        gm_static.graph.eliminate_dead_code()
        for n in gm_static.graph.nodes:
            assert isinstance(n.type, TensorType)
            batch_sizes.add(n.type.__args__[0])
        assert (len(batch_sizes) == 1)
Example #9
def nnc_compile(model: torch.nn.Module, example_inputs) -> torch.nn.Module:
    """
    nnc_compile(model, example_inputs) returns a function with the same args
    as `model.forward`, with an extra argument corresponding to where the
    output is stored. This function takes the inputs (which must be PyTorch
    tensors with the same shapes as example_inputs), and passes them to an
    NNC executor.
    """
    fx_model = fx.symbolic_trace(model)
    ShapeProp(fx_model).propagate(*example_inputs)

    # This env maps from nodes to `te.ExprHandle`, which represent the output
    # of an NNC computation.
    env = {}

    def get_te_shapes(node):
        return [te.ExprHandle.int(i) for i in node.shape]

    def get_nnc_type(dtype):
        if dtype == torch.float:
            return te.Dtype.Float
        elif dtype == torch.long:
            return te.Dtype.Long
        else:
            raise RuntimeError("nyi")

    def get_te_type(node):
        return get_nnc_type(node.dtype)

    def gen_compute(args):
        te_args = [env[arg.name] for arg in args]

    def lookup_env(l):
        return fx.node.map_aggregate(
            l, lambda x: env[x.name] if isinstance(x, fx.Node) else x)

    def fetch_attr(target: str):
        target_atoms = target.split('.')
        attr_itr = fx_model
        for i, atom in enumerate(target_atoms):
            if not hasattr(attr_itr, atom):
                raise RuntimeError(
                    f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}"
                )
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    outs = None
    inputs = []
    module_attrs = []
    for node in fx_model.graph.nodes:
        if node.op == 'placeholder':
            # We simply map the input placeholder to a `te.Placeholder`, which
            # also represents an input to the NNC computation.
            shapes = get_te_shapes(node)
            env[node.name] = te.Placeholder(node.name, get_te_type(node),
                                            shapes)
            inputs.append(env[node.name])
        elif node.op == 'call_function':
            # This does the bulk of the work - we call `lower_function`, which
            # returns a `te.ExprHandle` (the output of a NNC computation), and
            # put it in our environment.
            result = lower_function(node, node.target, lookup_env(node.args),
                                    node.args)
            env[node.name] = result
        elif node.op == 'output':
            outs = list(lookup_env(node.args))
        elif node.op == 'get_attr':
            # As NNC doesn't have any concept of state, we pull out the module
            # attributes and pass them in as inputs to NNC.
            module_attrs.append(node)
            shapes = get_te_shapes(node)
            env[node.name] = te.Placeholder(node.name, get_te_type(node),
                                            shapes)
        else:
            raise RuntimeError("not yet implemented")

    loopnest = te.LoopNest(outs)
    loopnest.prepare_for_codegen()
    stmt = te.simplify(loopnest.root_stmt())
    cg = te.construct_codegen('llvm', stmt, [
        te.BufferArg(x)
        for x in [env[i.name] for i in module_attrs] + inputs + outs
    ])

    def f(inps):
        module_stuff = [fetch_attr(i.target) for i in module_attrs]
        cg.call(module_stuff + list(inps))

    return f
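
A hedged sketch of calling the compiled function returned above. Note the calling convention from the docstring: the caller passes the inputs together with preallocated output storage, since the codegen writes results into the trailing buffer arguments. The function, shapes, and the assumption that lower_function covers torch.cos, torch.sin, and operator.mul are illustrative.

def point_fn(a, b):
    return torch.cos(a) * torch.sin(b)

inps = (torch.randn(1024), torch.randn(1024))
compiled = nnc_compile(point_fn, inps)

out = torch.empty(1024)
compiled(list(inps) + [out])  # results are written into `out`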
Example #10
def nnc_compile(fx_model: fx.GraphModule,
                example_inputs,
                get_loopnest=False) -> torch.nn.Module:
    """
    nnc_compile(model, example_inputs) returns a function with the same args
    as `model.forward`, with an extra argument corresponding to where the
    output is stored. This function takes the inputs (which must be PyTorch
    tensors with the same shapes as example_inputs), and passes them to an
    NNC executor.
    """
    # fx_model = fx.symbolic_trace(model)
    flat_example_inputs = fx_model.graph.flatten_inps(*example_inputs)
    ShapeProp(fx_model).propagate(*flat_example_inputs)

    # This env maps from nodes to `te.ExprHandle`, which represent the output
    # of an NNC computation.
    env = {}

    def get_te_type(node):
        return get_nnc_type(node.meta['tensor_meta'].dtype)

    def gen_compute(args):
        te_args = [env[arg.name] for arg in args]

    def lookup_env(l):
        res = fx.node.map_aggregate(
            l, lambda x: env[x.name] if isinstance(x, fx.Node) else x)
        return res

    def fetch_attr(target: str):
        target_atoms = target.split('.')
        attr_itr = fx_model
        for i, atom in enumerate(target_atoms):
            if not hasattr(attr_itr, atom):
                raise RuntimeError(
                    f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}"
                )
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    outs = None
    inputs = []
    module_attrs = []
    compute_stmts = []
    for node in fx_model.graph.nodes:
        if node.op == 'placeholder':
            # We simply map the input placeholder to a `te.Placeholder`, which
            # also represents an input to the NNC computation.
            if 'tensor_meta' not in node.meta:
                continue
            shapes = get_te_shapes(node.meta['tensor_meta'].shape)
            placeholder = te.Placeholder(node.name, get_te_type(node), shapes)
            env[node.name] = placeholder.data()
            inputs.append(placeholder)
        elif node.op == 'call_function':
            # This does the bulk of the work - we call `lower_function`, which
            # returns a `te.ExprHandle` (the output of a NNC computation), and
            # put it in our environment.
            if 'tensor_meta' in node.meta:
                # todo: fix kwargs handling
                if node.kwargs:
                    raise RuntimeError("kwargs nyi")
                buf, stmt = lower_function(node, node.target,
                                           lookup_env(node.args), node.args)
                # if isinstance(stmt, list)
                compute_stmts.extend(stmt)
                env[node.name] = buf
            elif node.target == getattr or node.target == operator.getitem:
                # todo: handle non-tensor computations correctly
                continue
        elif node.op == 'output':
            args = node.args
            args = pytree.tree_map(
                lambda x: list(x) if isinstance(
                    x, fx.immutable_collections.immutable_list) else x, args)
            flat_args, _ = pytree.tree_flatten(list(args))
            te_args = lookup_env(flat_args)
            outs = (list(te_args), [(i.meta['tensor_meta'].shape,
                                     i.meta['tensor_meta'].dtype)
                                    for i in flat_args])
        elif node.op == 'get_attr':
            # As NNC doesn't have any concept of state, we pull out the module
            # attributes and pass them in as inputs to NNC.
            module_attrs.append(node)
            shapes = get_te_shapes(
                process_shape(node.meta['tensor_meta'].shape))
            placeholder = te.Placeholder(node.name, get_te_type(node), shapes)
            env[node.name] = placeholder.data()
        else:
            raise RuntimeError(
                f"not yet implemented: {node.op} {node.target}")

    loopnest = te.LoopNest(te.Stmt(compute_stmts), outs[0])
    if get_loopnest:
        return loopnest
    # loopnest.inline_intermediate_bufs(True)
    loopnest.simplify()
    loopnest.prepare_for_codegen()
    stmt = te.simplify(loopnest.root_stmt())
    cg = te.construct_codegen('llvm', stmt, [
        te.BufferArg(x)
        for x in [env[i.name] for i in module_attrs] + inputs + outs[0]
    ])
    if module_attrs:
        module_stuff = [
            fetch_attr(i.target).contiguous().data for i in module_attrs
        ]
    else:
        module_stuff = []

    def f(*inps, out_tensors=None):
        inps = fx_model.graph.flatten_inps(*inps)
        if out_tensors is None:
            results = [
                torch.empty(shape, dtype=dtype) for shape, dtype in outs[1]
            ]
            # results = alloc_results
        else:
            results = out_tensors
        full_inps = module_stuff + list(inps) + results
        cg.call(full_inps)
        results = fx_model.graph.unflatten_outs(results)
        return results

    return f
Example #11
        compiled_f = nnc_compile(fx_model, args, get_loopnest=True)
        return compiled_f

    return wrapped


################################
# Example usage and Benchmarking
################################


def bench(f, warmup=3, iters=1000):
    for _ in range(warmup):
        f()
    begin = time.time()
    for _ in range(iters):
        f()
    print(time.time() - begin)


if __name__ == '__main__':

    def f(a, b):
        return (torch.cos(a) * torch.sin(b))[:2000]

    mod = fx.symbolic_trace(f)
    inps = (torch.randn(5000), torch.randn(5000))
    ShapeProp(mod).propagate(*inps)
    cg = nnc_compile(mod, inps)
    bench(lambda: cg(*inps))
    bench(lambda: f(*inps))