def gen_structured(g: NativeFunctionsGroup) -> List[str]:
    """Generate the native header declarations for a structured group.

    Emits one ``struct TORCH_API structured_<name>`` declaration (deriving
    from ``at::meta::<meta name>`` and declaring ``impl``) for each distinct
    kernel name that ``g.out`` registers under a structured dispatch key.
    It then emits one plain ``TORCH_API`` function declaration for each
    distinct kernel name registered under a non-structured dispatch key by
    any function of the group.

    :param g: the native functions group whose kernels are declared.
    :returns: the generated C++ declaration strings, in discovery order.
    """
    meta_name = meta.name(g)
    out_args = structured.impl_arguments(g)
    rs: List[str] = []

    # Structured kernels are looked up only on the out= variant's dispatch
    # table; each distinct kernel name is declared once.
    seen: Set[Any] = set()
    structured_kernels: List[str] = []
    for k, n in g.out.dispatch.items():
        if n in seen or not is_structured_dispatch_key(k):
            continue
        seen.add(n)
        structured_kernels.append(n)
    if structured_kernels:
        # The impl() parameter list is identical for every kernel, so it is
        # rendered once instead of once per kernel.
        impl_params = ', '.join(a.decl() for a in out_args)
        for n in structured_kernels:
            rs.append(f"""\
struct TORCH_API structured_{n} : public at::meta::{meta_name} {{
void impl({impl_params});
}};
""")

    # Every non-structured kernel gets an ordinary function declaration;
    # names are deduplicated across all functions of the group.
    seen = set()
    for f in g.functions():
        kernel_names: List[str] = []
        for k, n in f.dispatch.items():
            if n in seen or is_structured_dispatch_key(k):
                continue
            seen.add(n)
            kernel_names.append(n)
        if not kernel_names:
            continue
        # The rendered signature depends only on f, not on the kernel name.
        returns_type = native.returns_type(f.func.returns)
        args_str = ', '.join(a.decl() for a in native.arguments(f.func))
        rs.extend(f"TORCH_API {returns_type} {n}({args_str});" for n in kernel_names)
    return rs
def compute_meta_function_declaration(g: StructuredNativeFunctions) -> str:
    """Render the ``CAFFE2_API`` declaration of the meta function for ``g``.

    The group's signature is queried while ``native_function_manager(g.out)``
    is active. The name, return type and argument list all come from the
    ``meta`` helpers, and each argument is stringified with ``str``.
    """
    with native_function_manager(g.out):
        signature = g.signature()
        fn_name = meta.name(signature)
        ret_type = meta.returns_type(signature.returns)
        params = ', '.join(str(param) for param in meta.arguments(signature))
        return f"CAFFE2_API {ret_type} {fn_name}({params});"
def compute_meta_function_declaration(g: StructuredNativeFunctions) -> str: with native_function_manager(g.out): name = meta.name(g) args = structured.meta_arguments(g) args_str = ', '.join(a.decl() for a in args) parent_class = g.out.structured_inherits if parent_class is None: parent_class = "at::impl::MetaBase" return f"""\
def compute_native_function_declaration(
        g: Union[StructuredNativeFunctions, NativeFunction]) -> List[str]:
    """Generate the ``at::native`` header declarations for ``g``.

    For a structured group:

    - Emit one ``struct TORCH_API structured_<name>`` declaration (deriving
      from ``at::meta::<meta name>`` and declaring ``impl``) per distinct
      kernel name registered on ``g.out`` under a structured dispatch key.
    - Then emit one ``TORCH_API`` function declaration per distinct kernel
      name registered under a non-structured dispatch key by any function
      of the group.

    For a single native function, emit one ``TORCH_API`` function
    declaration per distinct kernel name in its dispatch table. Names
    containing ``"legacy::"`` are skipped.

    :param g: a structured group or a single native function.
    :returns: the generated C++ declaration strings, in discovery order.
    """
    if isinstance(g, StructuredNativeFunctions):
        meta_name = meta.name(g)
        out_args = structured.impl_arguments(g)
        rs: List[str] = []

        # Structured kernels are looked up only on the out= variant's
        # dispatch table; each distinct kernel name is declared once.
        seen: Set[Any] = set()
        structured_kernels: List[str] = []
        for k, n in g.out.dispatch.items():
            if n in seen or not is_structured_dispatch_key(k):
                continue
            seen.add(n)
            structured_kernels.append(n)
        if structured_kernels:
            # The impl() parameter list is identical for every kernel, so
            # it is rendered once instead of once per kernel.
            impl_params = ', '.join(a.decl() for a in out_args)
            for n in structured_kernels:
                rs.append(f"""\
struct TORCH_API structured_{n} : public at::meta::{meta_name} {{
void impl({impl_params});
}};
""")

        # Non-structured kernels get plain declarations; names are
        # deduplicated across every function of the group.
        seen = set()
        for f in g.functions():
            kernel_names: List[str] = []
            for k, n in f.dispatch.items():
                if n in seen or is_structured_dispatch_key(k):
                    continue
                seen.add(n)
                kernel_names.append(n)
            if not kernel_names:
                continue
            # The rendered signature depends only on f, not on the name.
            returns_type = native.returns_type(f.func.returns)
            args_str = ', '.join(a.decl() for a in native.arguments(f.func))
            rs.extend(f"TORCH_API {returns_type} {n}({args_str});" for n in kernel_names)
        return rs

    f = g
    # Sometimes a function name shows up multiple times; only generate it
    # once! dict.fromkeys() deduplicates while keeping first-seen order.
    names = [n for n in dict.fromkeys(f.dispatch.values()) if "legacy::" not in n]
    if not names:
        return []
    # Return type and argument list are the same for every kernel name of
    # this function, so they are computed a single time.
    returns_type = native.returns_type(f.func.returns)
    args_str = ', '.join(a.decl() for a in native.arguments(f.func))
    return [f"TORCH_API {returns_type} {n}({args_str});" for n in names]
def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> List[str]:
    """Declare the structured kernel class for ``g`` on one backend.

    Looks up the kernel that ``backend_index`` assigns to ``g``; when there
    is none, an empty list is returned. Otherwise the result is a single
    declaration of ``struct structured_<kernel>``. It derives from
    ``at::meta::<meta name>`` and has an ``impl`` method taking the
    structured impl arguments. The ``TORCH_API`` export prefix is added
    only when the backend is not external.
    """
    base_class = meta.name(g)
    impl_params = structured.impl_arguments(g)
    kernel_info = backend_index.get_kernel(g)
    if kernel_info is not None:
        export_prefix = 'TORCH_API ' if not backend_index.external else ''
        param_list = ', '.join(param.decl() for param in impl_params)
        declaration = f"""\
struct {export_prefix}structured_{kernel_info.kernel} : public at::meta::{base_class} {{
void impl({param_list});
}};
"""
        return [declaration]
    return []
def gen_one(f: NativeFunction) -> Optional[str]:
    """Generate the definition or registration for one structured variant.

    This is a closure: it reads ``self`` (``target``, ``selector``,
    ``dispatch_key``) and the structured group ``g`` from the enclosing
    scope.

    - ``Target.DEFINITION``: returns a C++ wrapper. The wrapper calls
      ``meta::<meta_name>`` on the functional arguments and prepares the
      output according to ``f``'s schema kind. It then calls the out=
      kernel ``at::native::<g.out.dispatch[dispatch_key]>``.
    - ``Target.REGISTRATION``: returns an ``m.impl(...)`` line, or ``None``
      when the selector does not select the operator.

    :param f: one function of the structured group ``g``.
    :returns: the generated C++ text, or ``None`` if ``f`` is not selected
        for registration.
    """
    assert self.target is not Target.DECLARATION

    op_name = f"aten::{f.func.name}"
    if self.target is Target.REGISTRATION and not self.selector.is_operator_selected(op_name):
        return None

    sig = NativeSignature.from_schema(f.func)

    if self.target is Target.DEFINITION:
        # The functional signature / meta call are only needed for the
        # definition, so registration no longer pays for computing them.
        # TODO: put this into StructuredNativeFunctions itself
        functional_func = g.out.func.signature()
        functional_sig = DispatcherSignature.from_schema(functional_func)
        meta_name = meta.name(functional_func)
        # This is a little abusive; this assumes that the functionalization
        # transformation ALWAYS refers to valid arguments in the original
        # signature
        functional_exprs = ', '.join(e.expr for e in functional_sig.exprs())

        out_impl_name = f"at::native::{g.out.dispatch[self.dispatch_key]}"
        k = f.func.kind()
        # TODO: work a little harder to generate fresh names for 'result'
        # TODO: less praying that I picked the right argument name for 'self'
        if k is SchemaKind.functional:
            out_expr = "result"
            prologue = "auto result = tensor_from_meta(meta_result);"
        elif k is SchemaKind.inplace:
            out_expr = "self"
            prologue = "// TODO: consistency check assert"
        elif k is SchemaKind.out:
            # TODO: generalize this for multi-out
            assert len(f.func.out_arguments) == 1, "multi-out structured not supported yet"
            # TODO: properly get the expression as it was brought into
            # scope by sig
            out_expr = f.func.out_arguments[0].name
            prologue = f"""
// TODO: add a consistency check for meta_result
{out_expr}.resize_(meta_result.sizes);
"""
        else:
            # Fail loudly instead of hitting an UnboundLocalError on
            # out_expr/prologue below if a new schema kind appears.
            assert_never(k)

        device_guard = ""
        if is_generic_dispatch_key(self.dispatch_key) or is_cuda_dispatch_key(self.dispatch_key):
            # TODO: avoid copypasting the computation of self_args,
            # candidate_args and device_of
            # Prefer 'self', then out arguments, then any argument: the
            # first tensor-like one supplies the device for the guard.
            self_args = (a for a in f.func.arguments if a.name == "self")
            candidate_args = itertools.chain(self_args, f.func.out_arguments, f.func.arguments)
            device_of = next((f'{a.name}' for a in candidate_args if a.type.is_tensor_like()), None)
            if f.device_guard and device_of is not None:
                # TODO: Use OptionalCUDAGuard when possible
                device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"
            # TODO: figure out what to do about structured kernels and
            # factory functions

        # For an overview of what this template code looks like, see
        # https://github.com/pytorch/rfcs/pull/9
        return f"""\
{sig.defn()} {{
{device_guard}
auto meta_result = meta::{meta_name}({functional_exprs});
{prologue}
{out_impl_name}({out_expr}, {functional_exprs});
return {out_expr};
}}
"""
    elif self.target is Target.REGISTRATION:
        if local.use_c10_dispatcher() is UseC10Dispatcher.full:
            payload = f'TORCH_FN({sig.name()})'
        else:
            payload = f'torch::CppFunction::makeUnboxedOnly({sig.name()})'
        return f'm.impl("{f.func.name}", {payload});'
    else:
        assert_never(self.target)