def to_nvfuser(arg):
    if isinstance(arg, torch.Tensor):
        return nvFuserTensorTemplate(
            arg.size(), arg.stride(), getnvFuserDtype(arg.dtype)
        )
    elif isinstance(arg, Number):
        return nvFuserScalarTemplate(getnvFuserDtype(type(arg)))
    else:
        return arg
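# A minimal usage sketch (illustrative only; the exact template reprs and
# DataType names below are assumptions, not confirmed output):
#
#   t = torch.empty(2, 3, device="cuda")
#   to_nvfuser(t)    # roughly nvFuserTensorTemplate(torch.Size([2, 3]), (3, 1), DataType.Float)
#   to_nvfuser(3.14) # roughly nvFuserScalarTemplate(DataType.Double)
#   to_nvfuser("x")  # -> "x", non-tensor, non-number values pass through unchanged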
def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs):
    """
    Prototype executor for a context's traced graph.

    The "aten" executor just runs the graph directly; the "nvfuser"
    executor lowers it to an nvFuser fusion first.
    """

    if executor == "aten":
        gm = GraphModule({}, ctx.graph)
        return gm.forward(*args, **kwargs)
    elif executor == "nvfuser":
        if not torch.cuda.is_available():
            raise RuntimeError(
                "Attempting to use nvFuser trace executor but CUDA is not available!"
            )

        # PROTOTYPE nvfuser executor
        # Only accepts tensor inputs and single tensor outputs
        # Does not handle kwargs
        # Does not support reusing the same ctx to execute!
        assert len(kwargs) == 0

        # TODO: make this a proper trace -> trace transform that
        # doesn't mutate the context
        graph_fd = ctx.graph.placeholder("fd")
        ctx.graph._root.append(graph_fd)

        fusion = Fusion()
        with FusionDefinition(fusion) as fd:
            # Transforms the graph to call nvFuser lowerings
            nv_args = [fd]
            for arg in args:
                if isinstance(arg, torch.Tensor):
                    x = fd.define_tensor(
                        arg.size(), arg.stride(), getnvFuserDtype(arg.dtype)
                    )
                    fd.add_input(x)
                    nv_args.append(x)
                else:
                    nv_args.append(arg)

            for x in ctx.graph.nodes:
                if x.op == "call_function":
                    x.target = x.target.impl_nvfuser
                    x.args = (graph_fd,) + x.args

            gm = GraphModule({}, ctx.graph)
            out = gm.forward(*nv_args)
            flat_out, unflatten_spec = torch.utils._pytree.tree_flatten(out)
            for o in flat_out:
                fd.add_output(o)

        return torch.utils._pytree.tree_unflatten(
            fusion.execute(
                tuple(arg for arg in args if isinstance(arg, torch.Tensor))
            ),
            unflatten_spec,
        )

    msg = "Received unexpected value for 'executor': {0}. Allowed values are: aten, nvfuser.".format(
        executor
    )
    raise ValueError(msg)
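# A minimal usage sketch (assumptions: `ctx` is a PrimContext whose graph was
# already populated by tracing; how the tracing is set up is elided here and
# not part of this function):
#
#   a = torch.randn(4, 4, device="cuda")
#   ...  # trace prims into ctx.graph
#   aten_result = execute(ctx, a, executor="aten")
#   nv_result = execute(ctx, a, executor="nvfuser")  # requires CUDA
#
# Note that, per the comments above, the nvFuser path mutates ctx.graph (it
# adds an "fd" placeholder and rewrites node targets in place), so the same
# context cannot be passed to execute() a second time.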
def to_nv(arg):
    if isinstance(arg, torch.Tensor):
        x = fd.define_tensor(arg.size(), arg.stride(), getnvFuserDtype(arg.dtype))
        fd.add_input(x)
        return x
    else:
        return arg
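# Note: `fd` is a free variable here, so `to_nv` presumably lives inside the
# FusionDefinition block of `execute` above, closing over `fd`. With it in
# scope, the inline argument-conversion loop there reduces to (sketch):
#
#   nv_args = [fd] + [to_nv(arg) for arg in args]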
def _convert_element_type_nvfuser(fd: Any, a: Tensor, dtype: torch.dtype) -> Tensor:
    nvfuser_dtype = getnvFuserDtype(dtype)
    return fd.Ops.cast(nvfuser_dtype, a)  # type: ignore[attr-defined]
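# How this lowering is reached (sketch, following the rewrite in `execute`
# above: each call_function node's target is swapped for its `impl_nvfuser`
# and the FusionDefinition placeholder is prepended to its args). A traced
# call like
#
#   prims.convert_element_type(a, torch.float16)
#
# is therefore replayed as, roughly,
#
#   _convert_element_type_nvfuser(fd, a, torch.float16)
#
# where `a` is the nvFuser tensor produced by fd.define_tensor rather than a
# torch.Tensor.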