def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs):
    """
    Prototype ATen executor.

    Just executes the context's graph.
    """

    if executor == "aten":
        gm = GraphModule({}, ctx.graph)
        return gm.forward(*args, **kwargs)
    elif executor == "nvfuser":
        if not torch.cuda.is_available():
            raise RuntimeError(
                "Attempting to use nvFuser trace executor but CUDA is not available!"
            )

        # PROTOTYPE nvfuser executor
        # Only accepts tensor inputs and single tensor outputs
        # Does not handle kwargs
        # Does not support reusing the same ctx to execute!
        assert len(kwargs) == 0

        # TODO: make this a proper trace -> trace transform that
        # doesn't mutate the context
        graph_fd = ctx.graph.placeholder("fd")
        ctx.graph._root.append(graph_fd)

        fusion = Fusion()
        with FusionDefinition(fusion) as fd:
            # Transforms graph to call nvfuser lowerings
            nv_args = [fd]
            for arg in args:
                if isinstance(arg, torch.Tensor):
                    x = fd.define_tensor(
                        arg.size(), arg.stride(), getnvFuserDtype(arg.dtype)
                    )
                    fd.add_input(x)
                    nv_args.append(x)
                else:
                    nv_args.append(arg)

            for x in ctx.graph.nodes:
                if x.op == "call_function":
                    x.target = x.target.impl_nvfuser
                    x.args = (graph_fd,) + x.args

            gm = GraphModule({}, ctx.graph)
            out = gm.forward(*nv_args)
            flat_out, unflatten_spec = torch.utils._pytree.tree_flatten(out)
            for o in flat_out:
                fd.add_output(o)

            return torch.utils._pytree.tree_unflatten(
                fusion.execute(
                    tuple(arg for arg in args if isinstance(arg, torch.Tensor))
                ),
                unflatten_spec,
            )

    msg = "Received unexpected value for 'executor': {0}. Allowed values are: aten, nvfuser.".format(
        executor
    )
    raise ValueError(msg)
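# A small standalone sketch (not from the original source) of the pytree
# round-trip both executor paths rely on: tree_flatten pulls the tensor leaves
# out of an arbitrarily nested output, and tree_unflatten restores the original
# structure around the executed results.
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

out = {"a": torch.ones(2), "b": (torch.zeros(3), torch.ones(1))}
leaves, spec = tree_flatten(out)                          # tensor leaves + structure spec
restored = tree_unflatten([t * 2 for t in leaves], spec)  # same nesting, processed leaves
print(restored["b"][0])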
def execute(gm: GraphModule, *args, executor: str = "aten", **kwargs):
    """
    Prototype ATen executor.

    Just executes the GraphModule's graph.
    """

    if executor == "aten":
        return gm.forward(*args, **kwargs)
    elif executor == "nvfuser":
        if not torch.cuda.is_available():
            raise RuntimeError(
                "Attempting to use nvFuser trace executor but CUDA is not available!"
            )

        # PROTOTYPE nvfuser executor
        # Everything in the graph must support nvfuser
        fusion = Fusion()
        with FusionDefinition(fusion) as fd:

            class FusionInterpreter(torch.fx.Interpreter):
                def call_function(self, target, args, kwargs):
                    # Redirect each prim to its nvfuser lowering
                    target = target.impl_nvfuser
                    args = (fd,) + args
                    return target(*args, **kwargs)

            def to_nv(arg):
                if isinstance(arg, torch.Tensor):
                    x = fd.define_tensor(
                        arg.size(), arg.stride(), getnvFuserDtype(arg.dtype)
                    )
                    fd.add_input(x)
                    return x
                else:
                    return arg

            # Transforms graph to call nvfuser lowerings
            nv_args = tree_map(to_nv, args)
            nv_kwargs = tree_map(to_nv, kwargs)

            out = FusionInterpreter(gm).run(*nv_args, **nv_kwargs)
            flat_out, unflatten_spec = torch.utils._pytree.tree_flatten(out)
            for o in flat_out:
                fd.add_output(o)

            return torch.utils._pytree.tree_unflatten(
                fusion.execute(
                    tuple(arg for arg in args if isinstance(arg, torch.Tensor))
                ),
                unflatten_spec,
            )

    msg = "Received unexpected value for 'executor': {0}. Allowed values are: aten, nvfuser.".format(
        executor
    )
    raise ValueError(msg)
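# A minimal sketch (not from the original source) of the torch.fx.Interpreter
# pattern used above: overriding call_function lets every node in a traced
# graph be redirected to an alternate implementation, which is how the executor
# swaps each prim for its impl_nvfuser lowering. The remapping below is a
# made-up stand-in, not the real lowering table.
import torch
import torch.fx

class SubstitutingInterpreter(torch.fx.Interpreter):
    def call_function(self, target, args, kwargs):
        # Hypothetical remapping: replace torch.add with torch.sub just to
        # show that targets can be substituted node by node.
        if target is torch.add:
            return torch.sub(*args, **kwargs)
        return super().call_function(target, args, kwargs)

def f(a, b):
    return torch.add(a, b)

gm = torch.fx.symbolic_trace(f)
print(SubstitutingInterpreter(gm).run(torch.ones(2), torch.ones(2)))  # tensor([0., 0.])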
# Reconstructed preamble (not shown in the original excerpt): build the Fusion
# and define the inputs. t0 is assumed to be a 3-D float tensor, matching
# input1 below and the [True, True, False] broadcast of t1.
import torch
from torch._C._nvfuser import Fusion, FusionDefinition

fusion = Fusion()

with FusionDefinition(fusion) as fd:
    t0 = fd.define_tensor(3)
    t1 = fd.define_tensor(1)
    s0 = fd.define_scalar()

    fd.add_input(t0)
    fd.add_input(t1)
    fd.add_input(s0)

    c0 = fd.define_constant(3.0)

    t1_b = fd.ops.broadcast(t1, [True, True, False])
    t2 = fd.ops.add(t0, t1_b)
    t3 = fd.ops.mul(t2, c0)
    t4 = fd.ops.mul(t3, s0)
    t5 = fd.ops.relu(t4)
    t6 = fd.ops.sum(t5, [-1], False)

    fd.add_output(t6)

fusion.print_ir()

# Execute Fusion
input1 = torch.ones(2, 4, 8, device='cuda')
input2 = torch.ones(8, device='cuda')

# Kernel compilation should be cached for the 2nd iteration
# with input tensors of the same shape
for _ in range(5):
    outputs = fusion.execute([input1, input2, 2.0])

print(outputs[0])
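# A hedged sanity check (not part of the original example): the same computation
# in eager PyTorch, assuming the input1/input2 definitions and the scalar value
# 2.0 from the block above. Eager broadcasting of the 1-D input2 against the
# 3-D input1 mirrors the explicit broadcast in the fusion.
ref = torch.sum(torch.relu((input1 + input2) * 3.0 * 2.0), dim=-1)
assert torch.allclose(outputs[0], ref)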
import torch

from torch._C._nvfuser import Fusion, FusionDefinition, DataType

# Construct and Define Fusion
fusion = Fusion()

with FusionDefinition(fusion) as fd:
    t0 = fd.define_tensor(2, DataType.Double)
    t1 = fd.define_tensor(2, DataType.Double)

    t0h = fd.ops.cast(t0, DataType.Half)
    t1h = fd.ops.cast(t1, DataType.Half)
    t2 = fd.ops.add(t0h, t1h)
    t3 = fd.ops.relu(t2)

    fd.add_output(t3)

fusion.print_ir()

# Execute Fusion
input1 = torch.ones(2, 4, device='cuda', dtype=torch.float64)
input2 = torch.ones(2, 4, device='cuda', dtype=torch.float64)

# Kernel compilation should be cached for the 2nd iteration
# with input tensors of the same shape
for _ in range(5):
    outputs = fusion.execute([input1, input2])

print(outputs[0])
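# A hedged sanity check (not part of the original example), reusing input1 and
# input2 from the block above: the fusion casts the double inputs to half,
# adds them, and applies relu, so an eager reference does the same.
ref = torch.relu(input1.to(torch.half) + input2.to(torch.half))
assert torch.allclose(outputs[0], ref)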
# Reconstructed preamble (not shown in the original excerpt): imports and the
# fusion1 input definitions. t0 is assumed to be 1-D and t1 3-D, matching the
# broadcast_in_dim call and the input shapes below.
import torch
import torch._prims as prims
import torch._refs as refs
from torch._C._nvfuser import Fusion, FusionDefinition

fusion1 = Fusion()

with FusionDefinition(fusion1) as fd:
    t0 = fd.define_tensor(1)
    t1 = fd.define_tensor(3)

    fd.add_input(t0)
    fd.add_input(t1)

    t0_b = fd.ops.broadcast_in_dim(t0, [2, 3, 4], [1])
    t2 = fd.ops.add(t0_b, t1)

    fd.add_output(t2)

fusion1.print_ir()

# Execute Fusion
input1 = torch.randn(3, device='cuda')
input2 = torch.randn(2, 3, 4, device='cuda')

# Kernel compilation should be cached for the 2nd iteration
# with input tensors of the same shape
for _ in range(5):
    o = fusion1.execute([input1, input2])[0]

assert o.shape == torch.Size([2, 3, 4])

# Reference in prim torch
ref_o = refs.add(prims.broadcast_in_dim(input1, [2, 3, 4], [1]), input2)
assert ref_o.allclose(o)
assert ref_o.shape == o.shape

fusion2 = Fusion()

input1 = torch.randn(1, 1, 4, device='cuda')
input2 = torch.randn(2, 3, 4, device='cuda')

with FusionDefinition(fusion2) as fd:
    t0 = fd.define_tensor(sizes=input1.size(), strides=input1.stride())
# Reconstructed preamble (not shown in the original excerpt): imports and the
# fusion1 input definitions, assumed to mirror the previous version of this
# example (a 1-D t0 broadcast against a 3-D t1).
import torch
from torch._C._nvfuser import Fusion, FusionDefinition

fusion1 = Fusion()

with FusionDefinition(fusion1) as fd:
    t0 = fd.define_tensor(1)
    t1 = fd.define_tensor(3)

    fd.add_input(t0)
    fd.add_input(t1)

    t0_b = fd.ops.broadcast_in_dim(t0, [2, 3, 4], [1])
    t2 = fd.ops.add(t0_b, t1)

    fd.add_output(t2)

fusion1.print_ir()

# Execute Fusion
input1 = torch.ones(3, device='cuda')
input2 = torch.ones(2, 3, 4, device='cuda')

# Kernel compilation should be cached for the 2nd iteration
# with input tensors of the same shape
for _ in range(5):
    outputs = fusion1.execute([input1, input2])

print(outputs[0])

fusion2 = Fusion()

input1 = torch.ones(1, 1, 4, device='cuda')
input2 = torch.ones(2, 3, 4, device='cuda')

with FusionDefinition(fusion2) as fd:
    t0 = fd.define_tensor(sizes=input1.size(), strides=input1.stride())
    t1 = fd.define_tensor(sizes=input2.size(), strides=input2.stride())

    fd.add_input(t0)
    fd.add_input(t1)