def main() -> None:
    parser = argparse.ArgumentParser(description='Generate ATen source files')
    parser.add_argument(
        '-s',
        '--source-path',
        help='path to source directory for ATen',
        default='aten/src/ATen')
    parser.add_argument(
        '-o',
        '--output-dependencies',
        help='output a list of dependencies into the given file and exit')
    parser.add_argument(
        '-d', '--install_dir',
        help='output directory',
        default='build/aten/src/ATen')
    parser.add_argument(
        '--rocm',
        action='store_true',
        help='reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly')
    # TODO: --op_registration_whitelist will be removed when all call-sites
    # for gen.py are moved over to using the operator YAML file for mobile
    # custom build.
    parser.add_argument(
        '--op_registration_whitelist',
        nargs='*',
        help='filter op registrations by the whitelist (if set); '
             'each item is `namespace`::`operator name` without overload name; '
             'e.g.: aten::empty aten::conv2d ...')
    parser.add_argument(
        '--op_selection_yaml_path',
        help='Provide a path to the operator selection (for custom build) YAML '
             'that contains the information about the set of selected operators '
             'and their categories (training, ...). Each operator is either a '
             'full operator name with overload or just a bare operator name. '
             'The operator names also contain the namespace prefix (e.g. aten::)')
    parser.add_argument(
        '--backend_whitelist',
        nargs='*',
        help='filter dispatch backend by the whitelist (if set), '
             'e.g.: CPU CUDA QuantizedCPU ...')
    parser.add_argument(
        '--static_dispatch_backend',
        help='generate static dispatch code for the specific backend (if set)')
    parser.add_argument(
        '--force_schema_registration',
        action='store_true',
        help='force schema-only registrations to be generated for all ops, '
             'including those that are not listed on --op_registration_whitelist')
    options = parser.parse_args()

    selector = get_custom_build_selector(
        options.op_registration_whitelist,
        options.op_selection_yaml_path,
    )

    native_functions = parse_native_yaml(
        os.path.join(options.source_path, 'native/native_functions.yaml'))

    pre_grouped_native_functions: Dict[FunctionSchema, Dict[SchemaKind, NativeFunction]]
    pre_grouped_native_functions = defaultdict(dict)
    for f in native_functions:
        d = pre_grouped_native_functions[f.func.signature()]
        assert f.func.kind() not in d
        d[f.func.kind()] = f

    def flatten_pre_group(
            d: Dict[SchemaKind, NativeFunction]
    ) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
        r = NativeFunctionsGroup.from_dict(d)
        if r is None:
            return list(d.values())
        else:
            return [r]

    # TODO: how come ValuesView isn't a Sequence lol
    grouped_native_functions = list(
        concatMap(flatten_pre_group,
                  list(pre_grouped_native_functions.values())))
    structured_native_functions = [
        g for g in grouped_native_functions
        if isinstance(g, NativeFunctionsGroup)
    ]

    template_dir = os.path.join(options.source_path, "templates")

    # NB: It is mandatory to NOT use os.path.join here, as the install directory
    # will eventually be ingested by cmake, which does not respect Windows style
    # path slashes. If you switch this to use os.path.join, you'll get an error
    # like:
    #
    #   Syntax error in cmake code when parsing string
    #
    #     C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/build/aten/src/ATen\core/TensorMethods.h
    #
    #   Invalid character escape '\c'.
    core_install_dir = f'{options.install_dir}/core'
    pathlib.Path(core_install_dir).mkdir(parents=True, exist_ok=True)

    def make_file_manager(install_dir: str) -> FileManager:
        return FileManager(install_dir=install_dir,
                           template_dir=template_dir,
                           dry_run=options.output_dependencies)

    core_fm = make_file_manager(core_install_dir)
    cpu_fm = make_file_manager(options.install_dir)
    cuda_fm = make_file_manager(options.install_dir)

    extra_cuda_headers = '''\
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/ATenCUDAGeneral.h>
#include <ATen/cuda/CUDADevice.h>
#include <ATen/cuda/CUDAContext.h>'''
    if options.rocm:
        extra_cuda_headers = '''\
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <ATen/hip/ATenHIPGeneral.h>
#include <ATen/hip/HIPDevice.h>
#include <ATen/hip/HIPContext.h>'''

    dispatch_keys = [
        DispatchKey.CPU,
        DispatchKey.SparseCPU,
        DispatchKey.SparseCsrCPU,
        DispatchKey.MkldnnCPU,
        DispatchKey.CUDA,
        DispatchKey.SparseCUDA,
        DispatchKey.SparseCsrCUDA,
        DispatchKey.QuantizedCPU,
        DispatchKey.QuantizedCUDA,
        DispatchKey.CompositeImplicitAutograd,
        DispatchKey.CompositeExplicitAutograd,
        # Meta is a magic key: it is automatically generated for structured
        # kernels
        DispatchKey.Meta,
    ]
    # Only a limited set of dispatch keys get CPUFunctions.h headers generated
    # for them; this is the set
    functions_keys = {
        DispatchKey.CPU,
        DispatchKey.CUDA,
        DispatchKey.CompositeImplicitAutograd,
        DispatchKey.CompositeExplicitAutograd,
    }
    if options.backend_whitelist:
        dispatch_keys = [
            k for k in dispatch_keys
            if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist
        ]

    static_dispatch_backend: Optional[DispatchKey] = None
    if options.static_dispatch_backend:
        static_dispatch_backend = DispatchKey.parse(
            options.static_dispatch_backend)

    for dispatch_key in dispatch_keys:
        fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm

        fm.write_with_template(f'Register{dispatch_key}.cpp', 'RegisterDispatchKey.cpp', lambda: {
            'extra_cuda_headers': extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else '',
            'legacy_th_headers':
                '#include <ATen/LegacyTHFunctionsCPU.h>' if dispatch_key == DispatchKey.CPU else
                '#include <ATen/LegacyTHFunctionsCUDA.h>' if dispatch_key == DispatchKey.CUDA else
                '',
            'DispatchKey': dispatch_key,
            'dispatch_namespace': dispatch_key.lower(),
            'dispatch_namespaced_definitions': list(concatMap(
                dest.RegisterDispatchKey(
                    dispatch_key, Target.NAMESPACED_DEFINITION, selector, rocm=options.rocm),
                grouped_native_functions
            )),
            'dispatch_anonymous_definitions': list(concatMap(
                dest.RegisterDispatchKey(
                    dispatch_key, Target.ANONYMOUS_DEFINITION, selector, rocm=options.rocm),
                grouped_native_functions
            )),
            'dispatch_registrations': list(concatMap(
                dest.RegisterDispatchKey(
                    dispatch_key, Target.REGISTRATION, selector, rocm=options.rocm),
                grouped_native_functions
            )),
        })

        if dispatch_key in functions_keys:
            fm.write_with_template(f'{dispatch_key}Functions.h', 'DispatchKeyFunctions.h', lambda: {
                'dispatch_namespace': dispatch_key.lower(),
                'dispatch_namespaced_declarations': list(concatMap(
                    dest.RegisterDispatchKey(
                        dispatch_key, Target.NAMESPACED_DECLARATION, selector, rocm=options.rocm),
                    grouped_native_functions
                )),
            })

        del fm

    # BackendSelect is generated specially
    cpu_fm.write('RegisterBackendSelect.cpp', lambda: {
        'backend_select_method_definitions':
            list(mapMaybe(ComputeBackendSelect(Target.DEFINITION), native_functions)),
        'backend_select_function_registrations':
            list(mapMaybe(ComputeBackendSelect(Target.REGISTRATION), native_functions)),
    })

    cpu_fm.write('MetaFunctions.h', lambda: {
        'declarations': list(mapMaybe(
            compute_meta_function_declaration, structured_native_functions)),
    })

    schema_selector = selector
    if options.force_schema_registration:
        schema_selector = SelectiveBuilder.get_nop_selector()
    cpu_fm.write('RegisterSchema.cpp', lambda: {
        'schema_registrations': list(mapMaybe(RegisterSchema(schema_selector), native_functions)),
    })

    cpu_fm.write('Functions.h', lambda: {
        'function_declarations': list(mapMaybe(ComputeFunction(
            Target.DECLARATION,
            static_dispatch_backend=static_dispatch_backend,
            is_redispatching_fn=False), native_functions)),
    })
    cpu_fm.write('Functions.cpp', lambda: {
        'static_dispatch_extra_headers': static_dispatch_extra_headers(static_dispatch_backend),
        'function_definitions': list(mapMaybe(ComputeFunction(
            Target.DEFINITION,
            static_dispatch_backend=static_dispatch_backend,
            is_redispatching_fn=False), native_functions)),
    })
    cpu_fm.write('RedispatchFunctions.h', lambda: {
        'function_redispatch_declarations': list(mapMaybe(ComputeFunction(
            Target.DECLARATION,
            static_dispatch_backend=static_dispatch_backend,
            is_redispatching_fn=True), native_functions)),
    })
    cpu_fm.write('RedispatchFunctions.cpp', lambda: {
        'static_dispatch_extra_headers': static_dispatch_extra_headers(static_dispatch_backend),
        'function_redispatch_definitions': list(mapMaybe(ComputeFunction(
            Target.DEFINITION,
            static_dispatch_backend=static_dispatch_backend,
            is_redispatching_fn=True), native_functions)),
    })
    core_fm.write('TensorBody.h', lambda: {
        'tensor_method_declarations': list(mapMaybe(ComputeTensorMethod(
            Target.DECLARATION,
            static_dispatch_backend=static_dispatch_backend), native_functions)),
    })
    core_fm.write('TensorMethods.cpp', lambda: {
        'static_dispatch_extra_headers': static_dispatch_extra_headers(static_dispatch_backend),
        'tensor_method_definitions': list(mapMaybe(ComputeTensorMethod(
            Target.DEFINITION,
            static_dispatch_backend=static_dispatch_backend), native_functions)),
    })
    core_fm.write('ATenOpList.cpp', lambda: {
        'aten_ops': list(mapMaybe(compute_aten_op, native_functions)),
    })
    cpu_fm.write('NativeFunctions.h', lambda: {
        'native_function_declarations': list(concatMap(
            dest.compute_native_function_declaration, grouped_native_functions)),
    })

    cpu_fm.write('Declarations.yaml', lambda: format_yaml(
        [compute_declaration_yaml(f) for f in native_functions]))
    cpu_fm.write('RegistrationDeclarations.h', lambda: {
        'registration_declarations': [compute_registration_declarations(f) for f in native_functions],
    })

    if options.output_dependencies:
        cpu_fm.write_outputs(options.output_dependencies)
        core_fm.write_outputs(f"{options.output_dependencies}-core")
        cuda_fm.write_outputs(f"{options.output_dependencies}-cuda")
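
# A minimal entry-point sketch, assuming this generator is meant to be run
# directly as a script (e.g. `python gen.py -s aten/src/ATen`); the
# surrounding build system may drive main() differently.
if __name__ == '__main__':
    main()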
def gen_unstructured(self, f: NativeFunction, g: Optional[NativeFunctionsGroup] = None) -> Optional[str]:
    with native_function_manager(f):
        inplace_meta = False
        gets_out_inplace_wrapper = False
        if not self.backend_index.has_kernel(f):
            if (self.backend_index.dispatch_key == DispatchKey.Meta and
                    f.func.kind() is SchemaKind.inplace and
                    # Defer to composites for meta implementation
                    not f.has_composite_kernel and
                    # Inplace list operations are not supported
                    len(f.func.returns) == 1):
                inplace_meta = True
            elif (not self.backend_index.use_out_as_primary and
                    g is not None and
                    gets_generated_out_inplace_wrapper(f, g, self.backend_index)):
                # We want to generate inplace/out wrappers that don't have
                # a kernel for the backend.
                gets_out_inplace_wrapper = True
            else:
                return None
        if f.manual_kernel_registration:
            return None

        if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(f):
            return None

        sig = self.wrapper_kernel_sig(f)

        name = sig.name()
        returns_type = sig.returns_type().cpp_type()
        args = sig.arguments()
        args_str = ', '.join(a.defn() for a in args)

        # See Note [Direct dispatch bindings]
        cpp_sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=False)

        if self.target is Target.NAMESPACED_DECLARATION:
            result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
            if cpp_sig_group.faithful_signature is not None:
                result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
            return result

        elif self.target is Target.NAMESPACED_DEFINITION:
            def generate_defn(cpp_sig: CppSignature) -> str:
                return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""
            result = generate_defn(cpp_sig_group.signature)
            if cpp_sig_group.faithful_signature is not None:
                result += generate_defn(cpp_sig_group.faithful_signature)
            return result

        elif self.target is Target.ANONYMOUS_DEFINITION:
            # short circuit for inplace_meta
            if inplace_meta:
                assert f.func.arguments.self_arg is not None
                self_arg_name = f.func.arguments.self_arg.argument.name
                # TODO: handle in place on tensor list
                return f"""
{returns_type} {name}({args_str}) {{
  TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(),
    "Cannot inplace into non-meta tensor with meta tensor argument");
  return {self_arg_name};
}}
"""

            # short circuit for generated inplace/out wrappers
            if gets_out_inplace_wrapper:
                return self.gen_out_inplace_wrapper(f, g)

            metadata = self.backend_index.get_kernel(f)
            if metadata is None:
                return None
            if self.class_method_name is None:
                impl_name = f"{self.cpp_namespace}::{metadata.kernel}"
            else:
                impl_name = f"{self.cpp_namespace}::{self.class_method_name}::{metadata.kernel}"

            args_exprs_str = ', '.join(a.name for a in args)

            device_check = '  // No device check\n'
            if is_cuda_dispatch_key(self.backend_index.dispatch_key):
                device_check_args = itertools.chain(
                    f.func.arguments.out, f.func.arguments.flat_positional)
                device_check = RegisterDispatchKey.gen_device_check(
                    f.device_check, list(device_check_args), name)

            device_guard = "// DeviceGuard omitted"  # default
            if f.device_guard and is_cuda_dispatch_key(self.backend_index.dispatch_key):
                has_tensor_options = any(
                    isinstance(a.argument, TensorOptionsArguments) for a in args)
                if has_tensor_options:
                    # kernel is creating a tensor
                    device_guard = """globalContext().lazyInitCUDA();
  const DeviceGuard device_guard(device_or_default(device));"""
                else:
                    # kernel is operating on existing tensors

                    # There is a precedence order for which argument we use to
                    # do the device guard; candidate_args below encodes that
                    # order.
                    self_arg = [f.func.arguments.self_arg.argument] if f.func.arguments.self_arg is not None else []
                    candidate_args = itertools.chain(
                        self_arg,
                        f.func.arguments.out,
                        f.func.arguments.flat_positional)

                    # Only tensor like arguments are eligible
                    device_of = next(
                        (f'{a.name}' for a in candidate_args if a.type.is_tensor_like()),
                        None)
                    if device_of is not None:
                        device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"

            return f"""\
namespace {{

{returns_type} {name}({args_str}) {{
{device_check}
{device_guard}
  return {impl_name}({args_exprs_str});
}}

}} // anonymous namespace
"""

        elif self.target is Target.REGISTRATION:
            if f.manual_kernel_registration:
                return None
            else:
                payload = f"TORCH_FN({name})"
                return f'm.impl("{f.func.name}",\n{payload});\n'
        else:
            assert_never(self.target)
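
# For illustration only: for an unstructured CPU kernel, the
# ANONYMOUS_DEFINITION branch above stitches the device check, device guard,
# and kernel call into a wrapper of roughly this shape (operator and kernel
# names here are hypothetical, not taken from native_functions.yaml):
#
#   namespace {
#   at::Tensor & wrapper_foo_(at::Tensor & self) {
#     // No device check
#     // DeviceGuard omitted
#     return at::native::foo_(self);
#   }
#   } // anonymous namespace
#
# and the REGISTRATION branch then emits the matching
#   m.impl("foo_", TORCH_FN(wrapper_foo_));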
def gen_one(self, f: NativeFunction) -> Optional[str]:
    assert not f.manual_kernel_registration

    if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(f):
        return None

    # TODO: Now, there is something interesting going on here. In the code below,
    # we generate CompositeExplicitAutograd implementations of functional and inplace
    # based on the out implementation. But in fact, out is definable by
    # functional too (just not very efficiently), and this is honestly the
    # MORE likely situation for a backend implementor. How do we pick?
    # Well, taking a page from Haskell type classes and default methods,
    # we could conceivably register a circular definition (out in terms
    # of functional, and functional in terms of out) and just require
    # someone to implement one or the other. We'd have to do a little bit
    # of work to not register one of these "weak" definitions unless there
    # is a strong definition somewhere in the DAG! So it's not implemented yet.
    if self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd and f.func.kind() is SchemaKind.out:
        # Never generate a default implementation for out, that's what you
        # have to define as a backend implementor
        return None

    # Note [Direct dispatch bindings]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Signature of the non-dispatched function we'll expose in a header
    # (e.g., at::cpu::add). We don't generate methods (TODO: do this
    # when CPUTensor class is a thing); nor do we generate fallback
    # bindings for manual_cpp_binding functions.
    cpp_sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=False)

    # Signature of the wrapper function we'll register to the dispatcher
    sig = NativeSignature(f.func, prefix="wrapper_")

    if self.target is Target.NAMESPACED_DECLARATION:
        result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
        if cpp_sig_group.faithful_signature is not None:
            result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
        return result

    elif self.target is Target.NAMESPACED_DEFINITION:
        def generate_defn(cpp_sig: CppSignature) -> str:
            return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""
        result = generate_defn(cpp_sig_group.signature)
        if cpp_sig_group.faithful_signature is not None:
            result += generate_defn(cpp_sig_group.faithful_signature)
        return result

    elif self.target is Target.ANONYMOUS_DEFINITION:
        k = f.func.kind()

        # Construct the body of the wrapper function with signature sig
        sig_body = []
        # We'll use context to keep track of any variables we've brought
        # into scope while generating code
        context: List[Union[Binding, Expr]] = list(sig.arguments())

        # Initialize the class corresponding to this structured
        # operator; feeding it the output argument(s) if it is known
        if self.backend_index.dispatch_key is DispatchKey.Meta:
            class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
            parent_class = f"at::meta::structured_{meta.name(self.g)}"
        elif self.backend_index.dispatch_key is DispatchKey.CompositeExplicitAutograd:
            # TODO: dedup this branch
            class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}"
            parent_class = f"at::meta::structured_{meta.name(self.g)}"
        else:
            metadata = self.backend_index.get_kernel(self.g)
            assert metadata is not None
            class_name = f"structured_{metadata.kernel}_{k.name}"
            parent_class = f"{self.cpp_namespace}::structured_{metadata.kernel}"

        if is_cuda_dispatch_key(self.backend_index.dispatch_key):
            device_check_args = itertools.chain(
                f.func.arguments.out, f.func.arguments.flat_positional)
            sig_body.append(
                RegisterDispatchKey.gen_device_check(
                    f.device_check, list(device_check_args), sig.name()))

        if k is SchemaKind.functional:
            sig_body.append(f"{class_name} op;")
        elif k is SchemaKind.inplace:
            sig_body.append(f"{class_name} op(self);")
        elif k is SchemaKind.out:
            out_args_str = ', '.join(a.name for a in f.func.arguments.out)
            sig_body.append(f"{class_name} op({out_args_str});")

        # Translate the input native arguments into structured
        # arguments for the meta call
        meta_exprs = ', '.join(
            e.expr for e in translate(context, structured.meta_arguments(self.g), method=False))

        if self.g.out.precomputed:
            # If this function group has precomputed elements, the meta function
            # returns a struct containing them which must be saved so that it
            # can be unpacked when generating code to call the impl.
            sig_body.append(f"auto precompute = op.meta({meta_exprs});")

            # Put all of the contents of the precompute struct into the context
            # so that translate will be able to return the correct args for the
            # call to the impl.
            for precomputed_elems in self.g.out.precomputed.replace.values():
                for arg in precomputed_elems:
                    context.append(Expr(
                        expr=f"precompute.{arg.name}",
                        type=structured.argument_type(arg, binds=arg.name),
                    ))

            # Add a use of the precompute struct so FB internal compilers don't
            # complain that there is an unused variable.
            sig_body.append("(void)precompute;")
        else:
            sig_body.append(f"op.meta({meta_exprs});")

        # After running meta, op.outputs_ is guaranteed to be valid;
        # add it to the context
        out_args = structured.out_arguments(self.g)
        maybe_star = '*' if k is SchemaKind.functional else ''
        for i, out_arg in enumerate(out_args):
            assert ConstRefCType(BaseCType(tensorT)) == out_arg.nctype.type
            context.append(Expr(
                expr=f"{maybe_star}op.outputs_[{i}]",
                # TODO: Stop hardcoding that the output type is a Tensor. Note
                # that for the codegen here this is fine because outputs_ is
                # hardcoded to be tensor already
                type=NamedCType(out_arg.nctype.name, MutRefCType(BaseCType(tensorT)))))

        # With the expanded context, do the impl call (if not a meta
        # function)
        if self.backend_index.dispatch_key == DispatchKey.CompositeExplicitAutograd:
            # TODO: https://github.com/pytorch/pytorch/issues/53023
            out_sig_group = CppSignatureGroup.from_native_function(
                self.g.out, method=False, fallback_binding=f.manual_cpp_binding)
            out_sig = out_sig_group.most_faithful_signature()
            api_name = out_sig.name()
            out_exprs = ', '.join(
                e.expr for e in translate(context, out_sig.arguments(), method=False))
            # TODO: I think this means structured won't work with method
            # only functions (but maybe you're saved by faithful? iunno.)
            # NB: Originally I wrote this as an at::redispatch call, but
            # I got in trouble because that meant I needed a DispatchKeySet
            # in the wrapper function, which meant I needed a DispatchKeySet
            # in the DispatchKeyFunctions declarations, but the defined API
            # there does NOT permit a dispatch key set.  I think you can
            # probably unwind this by calling some function to do the TLS
            # fetch and get the DispatchKeySet when you don't have it, but
            # I didn't do it for this version
            sig_body.append(f"at::{api_name}({out_exprs});")
        elif self.backend_index.dispatch_key != DispatchKey.Meta:
            impl_exprs = ', '.join(
                e.expr for e in translate(context, structured.impl_arguments(self.g), method=False))
            sig_body.append(f"op.impl({impl_exprs});")

        # Destructively return the final tensors
        # TODO: Do this in translate instead
        if k is SchemaKind.functional:
            if len(f.func.returns) == 1:
                ret_expr = "std::move(op.outputs_[0]).take()"  # small optimization
            else:
                moved = ', '.join(
                    f"std::move(op.outputs_[{i}]).take()" for i in range(len(f.func.returns)))
                ret_expr = f"std::make_tuple({moved})"
        elif k is SchemaKind.inplace:
            ret_expr = "self"
        elif k is SchemaKind.out:
            if len(f.func.returns) == 1:
                ret_expr = f.func.arguments.out[0].name
            else:
                refs = ', '.join(a.name for a in f.func.arguments.out)
                ret_expr = f"std::forward_as_tuple({refs})"
        sig_body.append(f"return {ret_expr};")

        sig_body_str = "\n".join(sig_body)

        # For an overview of what this template code looks like, see
        # https://github.com/pytorch/rfcs/pull/9
        return f"""\
{self.gen_class(
    f, k,
    class_name=class_name,
    parent_class=parent_class,
    generate_super=self.g.out.structured_inherits is not None
)}

{sig.defn()} {{
{sig_body_str}
}}
"""

    elif self.target is Target.REGISTRATION:
        return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
    else:
        assert_never(self.target)
        # Silence mypy's "Missing return statement" error
        return None
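
# For illustration only: for a functional structured kernel, the statements
# collected into sig_body above come out roughly as follows (the class name
# and argument list here are hypothetical):
#
#   structured_foo_functional op;
#   op.meta(self, other);
#   op.impl(self, other, *op.outputs_[0]);
#   return std::move(op.outputs_[0]).take();
#
# i.e. construct the structured class, run meta() to size the outputs, run
# impl() against the owned outputs, then destructively move them out.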
def gen_unstructured(self, f: NativeFunction) -> Optional[str]:
    inplace_meta = False
    if self.dispatch_key not in f.dispatch:
        if (self.dispatch_key == DispatchKey.Meta and
                f.func.kind() is SchemaKind.inplace and
                # Defer to composites for meta implementation
                DispatchKey.CompositeImplicitAutograd not in f.dispatch and
                DispatchKey.CompositeExplicitAutograd not in f.dispatch and
                # Inplace list operations are not supported
                len(f.func.returns) == 1):
            inplace_meta = True
        else:
            return None
    if f.manual_kernel_registration:
        return None

    if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(f):
        return None

    sig = NativeSignature(f.func, prefix='wrapper_')

    name = sig.name()
    returns_type = sig.returns_type().cpp_type()
    args = sig.arguments()
    args_str = ', '.join(a.defn() for a in args)

    # See Note [Direct dispatch bindings]
    cpp_sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=False)

    if self.target is Target.NAMESPACED_DECLARATION:
        result = f"TORCH_API {cpp_sig_group.signature.decl()};\n"
        if cpp_sig_group.faithful_signature is not None:
            result += f"TORCH_API {cpp_sig_group.faithful_signature.decl()};\n"
        return result

    elif self.target is Target.NAMESPACED_DEFINITION:
        def generate_defn(cpp_sig: CppSignature) -> str:
            return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""
        result = generate_defn(cpp_sig_group.signature)
        if cpp_sig_group.faithful_signature is not None:
            result += generate_defn(cpp_sig_group.faithful_signature)
        return result

    elif self.target is Target.ANONYMOUS_DEFINITION:
        # short circuit for inplace_meta
        if inplace_meta:
            assert f.func.arguments.self_arg is not None
            self_arg_name = f.func.arguments.self_arg.argument.name
            # TODO: handle in place on tensor list
            return f"""
{returns_type} {name}({args_str}) {{
  TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(),
    "Cannot inplace into non-meta tensor with meta tensor argument");
  return {self_arg_name};
}}
"""

        impl_name = f"at::native::{f.dispatch[self.dispatch_key]}"

        args_exprs_str = ', '.join(a.name for a in args)

        device_guard = "// DeviceGuard omitted"  # default
        if f.device_guard and is_cuda_dispatch_key(self.dispatch_key):
            has_tensor_options = any(
                isinstance(a.argument, TensorOptionsArguments) for a in args)
            if has_tensor_options:
                # kernel is creating a tensor
                device_guard = """globalContext().lazyInitCUDA();
  const DeviceGuard device_guard(device_or_default(device));"""
            else:
                # kernel is operating on existing tensors

                # There is a precedence order for which argument we use to
                # do the device guard; candidate_args below encodes that
                # order.
                self_arg = [f.func.arguments.self_arg.argument] if f.func.arguments.self_arg is not None else []
                candidate_args = itertools.chain(
                    self_arg,
                    f.func.arguments.out,
                    f.func.arguments.flat_positional)

                # Only tensor like arguments are eligible
                device_of = next(
                    (f'{a.name}' for a in candidate_args if a.type.is_tensor_like()),
                    None)
                if device_of is not None:
                    device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"

        return f"""\
namespace {{

{returns_type} {name}({args_str}) {{
{device_guard}
  return {impl_name}({args_exprs_str});
}}

}} // anonymous namespace
"""

    elif self.target is Target.REGISTRATION:
        if f.manual_kernel_registration:
            return None
        else:
            payload = f"TORCH_FN({name})"
            return f'm.impl("{f.func.name}",\n{payload});\n'
    else:
        assert_never(self.target)
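
# For illustration only: when f.device_guard is set and the kernel takes
# TensorOptions (i.e. it creates a tensor), the CUDA wrapper emitted above
# looks roughly like this (names and argument lists are hypothetical):
#
#   namespace {
#   at::Tensor wrapper_foo(at::IntArrayRef size, c10::optional<at::Device> device) {
#     globalContext().lazyInitCUDA();
#     const DeviceGuard device_guard(device_or_default(device));
#     return at::native::foo_cuda(size, device);
#   }
#   } // anonymous namespace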