示例#1
0
def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs):
    """
    Prototype trace executor.

    Runs the graph recorded in ``ctx.graph`` with the requested executor.

    Args:
        ctx: context object whose ``graph`` attribute holds the traced graph.
        *args: positional inputs forwarded to the graph.
        executor: ``"aten"`` (default) or ``"nvfuser"``.
        **kwargs: keyword inputs; only accepted by the ``"aten"`` executor.

    Returns:
        Whatever the graph returns. On the nvFuser path the graph output is
        flattened, every leaf is registered as a fusion output, and the
        executed results are unflattened back into the same structure.

    Raises:
        RuntimeError: ``executor="nvfuser"`` while CUDA is not available.
        AssertionError: ``executor="nvfuser"`` together with keyword inputs.
        ValueError: ``executor`` is neither ``"aten"`` nor ``"nvfuser"``.

    Note:
        The nvFuser path mutates ``ctx.graph`` in place (it adds an ``fd``
        placeholder and retargets every ``call_function`` node), so the same
        ``ctx`` cannot be executed a second time.
    """

    if executor == "aten":
        gm = GraphModule({}, ctx.graph)
        return gm.forward(*args, **kwargs)
    elif executor == "nvfuser":
        if not torch.cuda.is_available():
            raise RuntimeError(
                "Attempting to use nvFuser trace executor but CUDA is not available!"
            )

        # PROTOTYPE nvfuser executor
        # Does not handle kwargs
        # Does not support reusing the same ctx to execute!
        assert len(kwargs) == 0, "the nvfuser executor does not handle kwargs"
        # TODO: make this a proper trace -> trace transform that
        # doesn't mutate the context
        # Extra placeholder through which the FusionDefinition is passed to
        # every lowered call; it is attached directly after the graph root,
        # matching `fd` being the first entry of nv_args below.
        graph_fd = ctx.graph.placeholder("fd")
        ctx.graph._root.append(graph_fd)

        fusion = Fusion()
        with FusionDefinition(fusion) as fd:
            # Transforms graph to call nvfuser lowerings
            nv_args = [fd]
            for arg in args:
                if isinstance(arg, torch.Tensor):
                    nv_tensor = fd.define_tensor(arg.size(), arg.stride(),
                                                 getnvFuserDtype(arg.dtype))
                    fd.add_input(nv_tensor)
                    nv_args.append(nv_tensor)
                else:
                    # Non-tensor inputs are forwarded unchanged. (This used
                    # to append the stale/undefined tensor handle instead of
                    # the argument itself.)
                    nv_args.append(arg)

            # Retarget every call to its nvFuser implementation and thread
            # the fusion definition through as the leading argument.
            for node in ctx.graph.nodes:
                if node.op == "call_function":
                    node.target = node.target.impl_nvfuser
                    node.args = (graph_fd, ) + node.args

            gm = GraphModule({}, ctx.graph)
            out = gm.forward(*nv_args)
            flat_out, unflatten_spec = torch.utils._pytree.tree_flatten(out)
            for o in flat_out:
                fd.add_output(o)

            # Only tensors were registered as fusion inputs, so only tensors
            # are handed to the compiled fusion.
            return torch.utils._pytree.tree_unflatten(
                fusion.execute(
                    tuple(arg for arg in args
                          if isinstance(arg, torch.Tensor))),
                unflatten_spec,
            )

    msg = "Received unexpected value for 'executor': {0}. Allowed values are: aten, nvfuser.".format(
        executor)
    raise ValueError(msg)
示例#2
0
def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs):
    """
    Prototype trace executor.

    Runs the graph recorded in ``ctx.graph`` with the requested executor.

    Args:
        ctx: context object whose ``graph`` attribute holds the traced graph.
        *args: positional inputs forwarded to the graph.
        executor: ``"aten"`` (default) or ``"nvfuser"``.
        **kwargs: keyword inputs; only accepted by the ``"aten"`` executor.

    Returns:
        Whatever the graph returns. On the nvFuser path the graph output is
        registered as the single fusion output and the first element of the
        executed results is returned.

    Raises:
        AssertionError: ``executor="nvfuser"`` together with keyword inputs.
        ValueError: ``executor`` is neither ``"aten"`` nor ``"nvfuser"``.

    Note:
        The nvFuser path mutates ``ctx.graph`` in place (it adds an ``fd``
        placeholder and retargets every ``call_function`` node), so the same
        ``ctx`` cannot be executed a second time.
    """

    if executor == "aten":
        gm = GraphModule({}, ctx.graph)
        return gm.forward(*args, **kwargs)
    elif executor == "nvfuser":
        # PROTOTYPE nvfuser executor
        # Only supports a single tensor output
        # Does not handle kwargs
        # Does not support reusing the same ctx to execute!
        assert len(kwargs) == 0, "the nvfuser executor does not handle kwargs"
        # TODO: make this a proper trace -> trace transform that
        # doesn't mutate the context
        # Extra placeholder through which the FusionDefinition is passed to
        # every lowered call; it is attached directly after the graph root,
        # matching `fd` being the first entry of nv_args below.
        graph_fd = ctx.graph.placeholder("fd")
        ctx.graph._root.append(graph_fd)

        fusion = Fusion()
        with FusionDefinition(fusion) as fd:
            # Transforms graph to call nvfuser lowerings
            nv_args = [fd]
            for arg in args:
                if isinstance(arg, torch.Tensor):
                    nv_tensor = fd.define_tensor(
                        arg.ndim, _torch_dtype_to_nvfuser_dtype_map[arg.dtype])
                    fd.add_input(nv_tensor)
                    nv_args.append(nv_tensor)
                else:
                    # Non-tensor inputs are forwarded unchanged. (This used
                    # to append the stale/undefined tensor handle instead of
                    # the argument itself.)
                    nv_args.append(arg)

            # Retarget every call to its nvFuser implementation and thread
            # the fusion definition through as the leading argument.
            for node in ctx.graph.nodes:
                if node.op == "call_function":
                    node.target = node.target.impl_nvfuser
                    node.args = (graph_fd, ) + node.args

            gm = GraphModule({}, ctx.graph)
            out = gm.forward(*nv_args)
            fd.add_output(out)

            # Only tensors were registered as fusion inputs, so only tensors
            # are handed to the compiled fusion; one output was registered,
            # hence the [0].
            return fusion.execute(
                tuple(arg for arg in args if isinstance(arg, torch.Tensor)))[0]

    msg = "Received unexpected value for 'executor': {0}. Allowed values are: aten, nvfuser.".format(
        executor)
    raise ValueError(msg)
示例#3
0
def execute(gm: GraphModule, *args, executor: str = "aten", **kwargs):
    """
    Prototype trace executor.

    Runs ``gm`` with the requested executor.

    Args:
        gm: the graph module to execute.
        *args: positional inputs forwarded to the graph.
        executor: ``"aten"`` (default) or ``"nvfuser"``.
        **kwargs: keyword inputs forwarded to the graph.

    Returns:
        Whatever the graph returns. On the nvFuser path the graph output is
        flattened, every leaf is registered as a fusion output, and the
        executed results are unflattened back into the same structure.

    Raises:
        RuntimeError: ``executor="nvfuser"`` while CUDA is not available.
        ValueError: ``executor`` is neither ``"aten"`` nor ``"nvfuser"``.
    """

    if executor == "aten":
        return gm.forward(*args, **kwargs)
    elif executor == "nvfuser":
        if not torch.cuda.is_available():
            raise RuntimeError(
                "Attempting to use nvFuser trace executor but CUDA is not available!"
            )

        # PROTOTYPE nvfuser executor
        # Everything in the graph must support nvfuser

        fusion = Fusion()
        with FusionDefinition(fusion) as fd:

            class FusionInterpreter(torch.fx.Interpreter):
                def call_function(self, target, args, kwargs):
                    # Dispatch to the nvFuser lowering, threading the fusion
                    # definition through as the leading argument.
                    target = target.impl_nvfuser
                    args = (fd, ) + args
                    return target(*args, **kwargs)

            # Concrete tensors in the exact order their fusion inputs are
            # registered, so fusion.execute receives one tensor per input.
            # (Filtering only the top-level positional args missed tensors
            # nested inside containers or passed by keyword.)
            fusion_inputs = []

            def to_nv(arg):
                if isinstance(arg, torch.Tensor):
                    nv_tensor = fd.define_tensor(arg.size(), arg.stride(),
                                                 getnvFuserDtype(arg.dtype))
                    fd.add_input(nv_tensor)
                    fusion_inputs.append(arg)
                    return nv_tensor
                else:
                    return arg

            # Transforms graph to call nvfuser lowerings
            nv_args = tree_map(to_nv, args)
            nv_kwargs = tree_map(to_nv, kwargs)

            # NOTE(review): whether Interpreter.run accepts arbitrary keyword
            # inputs depends on the torch.fx version -- confirm.
            out = FusionInterpreter(gm).run(*nv_args, **nv_kwargs)
            flat_out, unflatten_spec = torch.utils._pytree.tree_flatten(out)
            for o in flat_out:
                fd.add_output(o)

            return torch.utils._pytree.tree_unflatten(
                fusion.execute(tuple(fusion_inputs)),
                unflatten_spec,
            )

    msg = "Received unexpected value for 'executor': {0}. Allowed values are: aten, nvfuser.".format(
        executor)
    raise ValueError(msg)
示例#4
0
def execute(gm: GraphModule, *args, executor: str = "aten"):
    """
    Prototype trace executor.

    Runs ``gm`` on the positional inputs with the selected executor:
    ``"aten"`` calls ``gm.forward`` directly, ``"nvfuser"`` delegates to
    ``nvfuser_execute``. Any other value raises ``ValueError``.
    """

    # Hand off to the nvFuser executor first; everything else is either the
    # plain ATen path or an invalid selection.
    if executor == "nvfuser":
        return nvfuser_execute(gm, *args)

    if executor != "aten":
        raise ValueError(
            "Received unexpected value for 'executor': {0}. Allowed values are: aten, nvfuser.".format(
                executor))

    return gm.forward(*args)