def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs):
    """
    Prototype executor for a traced prims context.

    Runs the graph recorded in ``ctx`` on ``args`` with the requested executor.

    Args:
        ctx: context whose ``graph`` (a torch.fx graph) is executed.
        *args: positional inputs forwarded to the graph.
        executor: ``"aten"`` runs the graph directly through a
            ``GraphModule``; ``"nvfuser"`` retargets every ``call_function``
            node to its ``impl_nvfuser`` lowering and runs the result as an
            nvFuser fusion.
        **kwargs: keyword inputs; only the ``"aten"`` executor accepts them.

    Returns:
        The graph's output. For ``"nvfuser"`` the fusion outputs are repacked
        into the same (pytree) structure the lowered graph returned.

    Raises:
        RuntimeError: ``"nvfuser"`` was requested but CUDA is unavailable.
        ValueError: ``executor`` is not a supported value.

    Warning:
        The ``"nvfuser"`` path mutates ``ctx.graph`` in place (it adds a
        placeholder and retargets nodes), so a context cannot be executed
        with nvFuser more than once.
    """
    if executor == "aten":
        gm = GraphModule({}, ctx.graph)
        return gm.forward(*args, **kwargs)
    elif executor == "nvfuser":
        if not torch.cuda.is_available():
            raise RuntimeError(
                "Attempting to use nvFuser trace executor but CUDA is not available!"
            )

        # PROTOTYPE nvfuser executor
        # Only tensor args become fusion inputs; other args are forwarded
        # to the lowerings unchanged.
        # Does not handle kwargs
        # Does not support reusing the same ctx to execute!
        assert len(kwargs) == 0, "nvFuser executor does not support kwargs"

        # TODO: make this a proper trace -> trace transform that
        # doesn't mutate the context
        graph_fd = ctx.graph.placeholder("fd")
        # Node.append inserts after the node it is called on; calling it on
        # the graph's root sentinel presumably moves the new placeholder to
        # the front, so the FusionDefinition is the first positional input
        # (matching nv_args[0] below).
        ctx.graph._root.append(graph_fd)

        fusion = Fusion()
        with FusionDefinition(fusion) as fd:
            # Transforms graph to call nvfuser lowerings
            nv_args = [fd]
            for arg in args:
                if isinstance(arg, torch.Tensor):
                    x = fd.define_tensor(
                        arg.size(), arg.stride(), getnvFuserDtype(arg.dtype)
                    )
                    fd.add_input(x)
                    nv_args.append(x)
                else:
                    # Non-tensor args are not fusion inputs; pass them through
                    # as-is so the lowerings see the caller's value.
                    nv_args.append(arg)

            for node in ctx.graph.nodes:
                if node.op == "call_function":
                    # nvFuser lowerings take the FusionDefinition first.
                    node.target = node.target.impl_nvfuser
                    node.args = (graph_fd,) + node.args

            gm = GraphModule({}, ctx.graph)
            out = gm.forward(*nv_args)
            flat_out, unflatten_spec = torch.utils._pytree.tree_flatten(out)
            for o in flat_out:
                fd.add_output(o)

            # Fusion inputs were registered for the tensor args only, in order.
            return torch.utils._pytree.tree_unflatten(
                fusion.execute(
                    tuple(arg for arg in args if isinstance(arg, torch.Tensor))
                ),
                unflatten_spec,
            )

    msg = "Received unexpected value for 'executor': {0}. Allowed values are: aten, nvfuser.".format(
        executor
    )
    raise ValueError(msg)
def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs):
    """
    Prototype executor for a traced prims context.

    Runs the graph recorded in ``ctx`` on ``args`` with the requested executor.

    Args:
        ctx: context whose ``graph`` (a torch.fx graph) is executed.
        *args: positional inputs forwarded to the graph.
        executor: ``"aten"`` runs the graph directly through a
            ``GraphModule``; ``"nvfuser"`` retargets every ``call_function``
            node to its ``impl_nvfuser`` lowering and runs the result as an
            nvFuser fusion. Fusion inputs are declared by rank and dtype only.
        **kwargs: keyword inputs; only the ``"aten"`` executor accepts them.

    Returns:
        The graph's output. For ``"nvfuser"`` the fusion outputs are repacked
        into the same (pytree) structure the lowered graph returned, so a
        single-tensor output is returned as that tensor.

    Raises:
        RuntimeError: ``"nvfuser"`` was requested but CUDA is unavailable.
        ValueError: ``executor`` is not a supported value.

    Warning:
        The ``"nvfuser"`` path mutates ``ctx.graph`` in place (it adds a
        placeholder and retargets nodes), so a context cannot be executed
        with nvFuser more than once.
    """
    if executor == "aten":
        gm = GraphModule({}, ctx.graph)
        return gm.forward(*args, **kwargs)
    elif executor == "nvfuser":
        if not torch.cuda.is_available():
            raise RuntimeError(
                "Attempting to use nvFuser trace executor but CUDA is not available!"
            )

        # PROTOTYPE nvfuser executor
        # Only tensor args become fusion inputs; other args are forwarded
        # to the lowerings unchanged.
        # Does not handle kwargs
        # Does not support reusing the same ctx to execute!
        assert len(kwargs) == 0, "nvFuser executor does not support kwargs"

        # TODO: make this a proper trace -> trace transform that
        # doesn't mutate the context
        graph_fd = ctx.graph.placeholder("fd")
        # Node.append inserts after the node it is called on; calling it on
        # the graph's root sentinel presumably moves the new placeholder to
        # the front, so the FusionDefinition is the first positional input
        # (matching nv_args[0] below).
        ctx.graph._root.append(graph_fd)

        fusion = Fusion()
        with FusionDefinition(fusion) as fd:
            # Transforms graph to call nvfuser lowerings
            nv_args = [fd]
            for arg in args:
                if isinstance(arg, torch.Tensor):
                    x = fd.define_tensor(
                        arg.ndim, _torch_dtype_to_nvfuser_dtype_map[arg.dtype]
                    )
                    fd.add_input(x)
                    nv_args.append(x)
                else:
                    # Non-tensor args are not fusion inputs; pass them through
                    # as-is so the lowerings see the caller's value.
                    nv_args.append(arg)

            for node in ctx.graph.nodes:
                if node.op == "call_function":
                    # nvFuser lowerings take the FusionDefinition first.
                    node.target = node.target.impl_nvfuser
                    node.args = (graph_fd,) + node.args

            gm = GraphModule({}, ctx.graph)
            out = gm.forward(*nv_args)
            # Register every leaf of the (possibly nested) output so that
            # multi-output graphs work; a single output is one leaf.
            flat_out, unflatten_spec = torch.utils._pytree.tree_flatten(out)
            for o in flat_out:
                fd.add_output(o)

            # Fusion inputs were registered for the tensor args only, in order.
            return torch.utils._pytree.tree_unflatten(
                fusion.execute(
                    tuple(arg for arg in args if isinstance(arg, torch.Tensor))
                ),
                unflatten_spec,
            )

    msg = "Received unexpected value for 'executor': {0}. Allowed values are: aten, nvfuser.".format(
        executor
    )
    raise ValueError(msg)
def execute(gm: GraphModule, *args, executor: str = "aten", **kwargs):
    """
    Prototype executor for a traced ``GraphModule``.

    Args:
        gm: the module to run.
        *args: positional inputs; may contain nested containers of tensors.
        executor: ``"aten"`` calls ``gm.forward`` directly; ``"nvfuser"``
            interprets ``gm`` with every ``call_function`` redirected to its
            ``impl_nvfuser`` lowering and runs the result as an nvFuser fusion.
        **kwargs: keyword inputs forwarded alongside ``args``.

    Returns:
        The module's output. For ``"nvfuser"`` the fusion outputs are repacked
        into the same (pytree) structure the interpreted graph returned.

    Raises:
        RuntimeError: ``"nvfuser"`` was requested but CUDA is unavailable.
        ValueError: ``executor`` is not a supported value.
    """
    if executor == "aten":
        return gm.forward(*args, **kwargs)
    elif executor == "nvfuser":
        if not torch.cuda.is_available():
            raise RuntimeError(
                "Attempting to use nvFuser trace executor but CUDA is not available!"
            )

        # PROTOTYPE nvfuser executor
        # Everything in the graph must support nvfuser

        fusion = Fusion()
        with FusionDefinition(fusion) as fd:

            class FusionInterpreter(torch.fx.Interpreter):
                def call_function(self, target, args, kwargs):
                    # nvFuser lowerings take the FusionDefinition first.
                    target = target.impl_nvfuser
                    args = (fd,) + args
                    return target(*args, **kwargs)

            # Real tensors, recorded in exactly the order they are registered
            # with fd.add_input, so fusion.execute receives matching inputs
            # even when tensors are nested in containers or passed as kwargs.
            fusion_inputs = []

            def to_nv(arg):
                if isinstance(arg, torch.Tensor):
                    x = fd.define_tensor(
                        arg.size(), arg.stride(), getnvFuserDtype(arg.dtype)
                    )
                    fd.add_input(x)
                    fusion_inputs.append(arg)
                    return x
                return arg

            # Transforms graph to call nvfuser lowerings
            nv_args = tree_map(to_nv, args)
            nv_kwargs = tree_map(to_nv, kwargs)

            out = FusionInterpreter(gm).run(*nv_args, **nv_kwargs)
            flat_out, unflatten_spec = torch.utils._pytree.tree_flatten(out)
            for o in flat_out:
                fd.add_output(o)

            return torch.utils._pytree.tree_unflatten(
                fusion.execute(tuple(fusion_inputs)),
                unflatten_spec,
            )

    msg = "Received unexpected value for 'executor': {0}. Allowed values are: aten, nvfuser.".format(
        executor
    )
    raise ValueError(msg)
def execute(gm: GraphModule, *args, executor: str = "aten"):
    """
    Run ``gm`` on ``args`` using the selected executor.

    ``"aten"`` calls ``gm.forward(*args)`` directly; ``"nvfuser"`` delegates
    to ``nvfuser_execute(gm, *args)``. Any other ``executor`` value raises
    ``ValueError``.
    """
    if executor == "aten":
        return gm.forward(*args)
    if executor == "nvfuser":
        return nvfuser_execute(gm, *args)
    raise ValueError(
        f"Received unexpected value for 'executor': {executor}. "
        "Allowed values are: aten, nvfuser."
    )