예제 #1
0
    def test_external_calls(self):
        dtype = torch.float32

        ONE = te.ExprHandle.int(1)
        FOUR = te.ExprHandle.int(4)
        A = te.BufHandle('A', [ONE, FOUR], dtype)
        B = te.BufHandle('B', [FOUR, ONE], dtype)
        C = te.BufHandle('C', [ONE, ONE], dtype)

        s = te.ExternalCall(C, "nnc_aten_matmul", [A, B], [])

        loopnest = te.LoopNest(s, [C])
        loopnest.prepare_for_codegen()
        codegen = te.construct_codegen('ir_eval', s,
                                       [te.BufferArg(x) for x in [A, B, C]])

        tA = torch.ones(1, 4)
        tB = torch.ones(4, 1)
        tC = torch.empty(1, 1)
        codegen.call([tA, tB, tC])
        torch.testing.assert_close(torch.matmul(tA, tB), tC)
예제 #2
0
def nnc_compile(model: torch.nn.Module, example_inputs) -> torch.nn.Module:
    """
    Trace `model` with FX, lower the graph to NNC and compile it with LLVM.

    Args:
        model: module that is symbolically traced with `fx.symbolic_trace`.
        example_inputs: inputs forwarded to `ShapeProp(...).propagate`; the
            `shape`/`dtype` read from the graph nodes below come from that
            propagation.

    Returns:
        A plain function `f(inps)` (not an `nn.Module`, despite the
        annotation). `inps` is a single sequence that is appended to the
        module attributes referenced by the graph and handed to the compiled
        kernel. The kernel's buffer arguments are, in order: module
        attributes, graph inputs, graph outputs -- so `inps` must hold the
        input tensors followed by the tensor(s) the outputs are written
        into. `f` itself returns nothing.

    Raises:
        RuntimeError: for a node dtype other than `torch.float`/`torch.long`
            ("nyi"), for a node op other than placeholder / call_function /
            output / get_attr, or when a `get_attr` target cannot be
            resolved on the traced module.
    """
    fx_model = fx.symbolic_trace(model)
    ShapeProp(fx_model).propagate(*example_inputs)

    # This env maps from nodes to `te.ExprHandle`, which represent the output
    # of an NNC computation.
    env = {}

    def get_te_shapes(node):
        return [te.ExprHandle.int(i) for i in node.shape]

    def get_nnc_type(dtype):
        if dtype == torch.float:
            return te.Dtype.Float
        elif dtype == torch.long:
            return te.Dtype.Long
        else:
            raise RuntimeError("nyi")

    def get_te_type(node):
        return get_nnc_type(node.dtype)

    def make_placeholder(node):
        # Always derive the shape from the node itself so that a `get_attr`
        # node never inherits the shape of an unrelated earlier input.
        shapes = get_te_shapes(node)
        return te.Placeholder(node.name, get_te_type(node), shapes)

    def lookup_env(l):
        # Replace every fx.Node inside a (possibly nested) argument structure
        # with its lowered NNC value; other values pass through unchanged.
        return fx.node.map_aggregate(
            l, lambda x: env[x.name] if isinstance(x, fx.Node) else x)

    def fetch_attr(target: str):
        # Walk the dotted path `target` starting at the traced module.
        target_atoms = target.split('.')
        attr_itr = fx_model
        for i, atom in enumerate(target_atoms):
            if not hasattr(attr_itr, atom):
                # Include the missing atom itself in the reported path.
                raise RuntimeError(
                    f"Node referenced nonexistant target {'.'.join(target_atoms[:i + 1])}"
                )
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    outs = None
    inputs = []
    module_attrs = []
    for node in fx_model.graph.nodes:
        if node.op == 'placeholder':
            # We simply map the input placeholder to a `te.Placeholder`, which
            # also represents an input to the NNC computation.
            env[node.name] = make_placeholder(node)
            inputs.append(env[node.name])
        elif node.op == 'call_function':
            # This does the bulk of the work - we call `lower_function`, which
            # returns a `te.ExprHandle` (the output of a NNC computation), and
            # put it in our environment.
            result = lower_function(node, node.target, lookup_env(node.args),
                                    node.args)
            env[node.name] = result
        elif node.op == 'output':
            outs = list(lookup_env(node.args))
        elif node.op == 'get_attr':
            # As NNC doesn't have any concept of state, we pull out the module
            # attributes and pass them in as inputs to NNC.
            module_attrs.append(node)
            env[node.name] = make_placeholder(node)
        else:
            raise RuntimeError("not yet implemented")

    loopnest = te.LoopNest(outs)
    loopnest.prepare_for_codegen()
    stmt = te.simplify(loopnest.root_stmt())
    # Argument order of the compiled kernel: module attributes, inputs, outputs.
    cg = te.construct_codegen('llvm', stmt, [
        te.BufferArg(x)
        for x in [env[i.name] for i in module_attrs] + inputs + outs
    ])

    def f(inps):
        # Module attributes are looked up on every call, so the kernel sees
        # whatever the traced module currently holds.
        module_stuff = [fetch_attr(i.target) for i in module_attrs]
        cg.call(module_stuff + list(inps))

    return f
예제 #3
0
def nnc_compile(fx_model: fx.GraphModule,
                example_inputs,
                get_loopnest=False) -> torch.nn.Module:
    """
    Lower an FX graph module to NNC and compile it with the LLVM codegen.

    Args:
        fx_model: graph module whose graph is lowered node by node.
        example_inputs: example arguments; they are flattened with
            `fx_model.graph.flatten_inps` and fed to `ShapeProp` so that
            nodes carry `meta['tensor_meta']`.
        get_loopnest: when true, return the `te.LoopNest` built from the
            lowered statements before any simplification or codegen.

    Returns:
        The `te.LoopNest` if `get_loopnest` is set. Otherwise a plain
        function `f(*inps, out_tensors=None)` (not an `nn.Module`, despite
        the annotation). `f` flattens its inputs with
        `fx_model.graph.flatten_inps`, allocates the outputs with
        `torch.empty` from the shapes/dtypes recorded on the output nodes
        unless `out_tensors` is supplied, runs the kernel, and returns
        `fx_model.graph.unflatten_outs(results)`.

    Raises:
        RuntimeError: "kwargs nyi" for a lowered `call_function` node with
            keyword arguments, "not yet implemented" for a node op other
            than placeholder / call_function / output / get_attr, or when a
            `get_attr` target cannot be resolved on `fx_model`.
    """
    # Flatten once and reuse the result for shape propagation.
    flat_example_inputs = fx_model.graph.flatten_inps(*example_inputs)
    ShapeProp(fx_model).propagate(*flat_example_inputs)

    # This env maps from nodes to `te.ExprHandle`, which represent the output
    # of an NNC computation.
    env = {}

    def get_te_type(node):
        return get_nnc_type(node.meta['tensor_meta'].dtype)

    def lookup_env(l):
        # Replace every fx.Node inside a (possibly nested) argument structure
        # with its lowered NNC value; other values pass through unchanged.
        return fx.node.map_aggregate(
            l, lambda x: env[x.name] if isinstance(x, fx.Node) else x)

    def fetch_attr(target: str):
        # Walk the dotted path `target` starting at the graph module.
        target_atoms = target.split('.')
        attr_itr = fx_model
        for i, atom in enumerate(target_atoms):
            if not hasattr(attr_itr, atom):
                # Include the missing atom itself in the reported path.
                raise RuntimeError(
                    f"Node referenced nonexistant target {'.'.join(target_atoms[:i + 1])}"
                )
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    outs = None
    inputs = []
    module_attrs = []
    compute_stmts = []
    for node in fx_model.graph.nodes:
        if node.op == 'placeholder':
            # We simply map the input placeholder to a `te.Placeholder`, which
            # also represents an input to the NNC computation. Placeholders
            # without tensor metadata are skipped.
            if 'tensor_meta' not in node.meta:
                continue
            shapes = get_te_shapes(node.meta['tensor_meta'].shape)
            placeholder = te.Placeholder(node.name, get_te_type(node), shapes)
            env[node.name] = placeholder.data()
            inputs.append(placeholder)
        elif node.op == 'call_function':
            # This does the bulk of the work - we call `lower_function`, which
            # returns the output buffer and the statements computing it; the
            # buffer goes into our environment.
            if 'tensor_meta' in node.meta:
                # todo: fix kwargs handling
                if node.kwargs:
                    raise RuntimeError("kwargs nyi")
                buf, stmt = lower_function(node, node.target,
                                           lookup_env(node.args), node.args)
                compute_stmts.extend(stmt)
                env[node.name] = buf
            elif node.target == getattr or node.target == operator.getitem:
                # todo: handle non-tensor computations correctly
                # NOTE: any other call_function node lacking tensor_meta
                # also ends up not being lowered, without an error.
                continue
        elif node.op == 'output':
            # Turn FX immutable lists into plain lists so the output
            # structure can be flattened into individual nodes.
            args = pytree.tree_map(
                lambda x: list(x) if isinstance(
                    x, fx.immutable_collections.immutable_list) else x,
                node.args)
            flat_args, _ = pytree.tree_flatten(list(args))
            te_args = lookup_env(flat_args)
            # outs[0]: lowered output buffers; outs[1]: (shape, dtype) pairs
            # used to allocate result tensors at call time.
            outs = (list(te_args), [(i.meta['tensor_meta'].shape,
                                     i.meta['tensor_meta'].dtype)
                                    for i in flat_args])
        elif node.op == 'get_attr':
            # As NNC doesn't have any concept of state, we pull out the module
            # attributes and pass them in as inputs to NNC.
            module_attrs.append(node)
            shapes = get_te_shapes(
                process_shape(node.meta['tensor_meta'].shape))
            placeholder = te.Placeholder(node.name, get_te_type(node), shapes)
            env[node.name] = placeholder.data()
        else:
            print(node.op, node.target)
            raise RuntimeError("not yet implemented")

    loopnest = te.LoopNest(te.Stmt(compute_stmts), outs[0])
    if get_loopnest:
        return loopnest
    loopnest.simplify()
    loopnest.prepare_for_codegen()
    stmt = te.simplify(loopnest.root_stmt())
    # Argument order of the compiled kernel: module attributes, inputs, outputs.
    cg = te.construct_codegen('llvm', stmt, [
        te.BufferArg(x)
        for x in [env[i.name] for i in module_attrs] + inputs + outs[0]
    ])
    # Module attributes are fetched once, at compile time, and reused by
    # every call of the returned function.
    module_stuff = [
        fetch_attr(i.target).contiguous().data for i in module_attrs
    ]

    def f(*inps, out_tensors=None):
        inps = fx_model.graph.flatten_inps(*inps)
        if out_tensors is None:
            results = [
                torch.empty(shape, dtype=dtype) for shape, dtype in outs[1]
            ]
        else:
            results = out_tensors
        cg.call(module_stuff + list(inps) + results)
        return fx_model.graph.unflatten_outs(results)

    return f