def method_registration(f: NativeFunction) -> Optional[str]:
    """Emit the TraceType wrapper-registration snippet for `f`.

    Returns None for ops whose tracing kernels are maintained by hand
    (listed in MANUAL_TRACER).
    """
    op_name = cpp.name(f.func)
    if op_name in MANUAL_TRACER:
        # Hand-written tracing kernel; nothing to register here.
        return None
    registration = WRAPPER_REGISTRATION.substitute(
        name=f.func.name,
        type_wrapper_name=type_wrapper_name(f),
        class_type='TraceType',
    )
    return registration
def gen_variable_type(
    out: str,
    native_yaml_path: str,
    differentiability_infos: Sequence[DifferentiabilityInfo],
    template_path: str,
    operator_selector: SelectiveBuilder,
) -> None:
    """VariableType.h and VariableType.cpp body

    This is the at::Type subclass for differentiable tensors. The
    implementation of each function dispatches to the base tensor type to
    compute the output. The grad_fn is attached to differentiable functions.
    """
    # Keep only functions selected for training, in a stable (name-sorted) order.
    selected = [
        f for f in parse_native_yaml(native_yaml_path)
        if operator_selector.is_native_function_selected_for_training(f)
    ]
    selected.sort(key=lambda f: cpp.name(f.func))
    fns_with_infos = match_differentiability_info(selected, differentiability_infos)

    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    gen_variable_type_shard(fm, fns_with_infos, 'VariableType.h', 'VariableType.h')

    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    num_shards = 5
    shards: List[List[NativeFunctionWithDifferentiabilityInfo]] = [
        [] for _ in range(num_shards)
    ]

    # functions are assigned arbitrarily but stably to a file based on hash
    for fn in fns_with_infos:
        bucket = sum(ord(c) for c in cpp.name(fn.func.func)) % num_shards
        shards[bucket].append(fn)

    for i, shard in enumerate(shards):
        gen_variable_type_shard(fm, shard, 'VariableType.cpp', f'VariableType_{i}.cpp')
    # The "Everything" file contains all functions, for consumers that do
    # not want the sharded layout.
    gen_variable_type_shard(fm, fns_with_infos, 'VariableType.cpp', 'VariableTypeEverything.cpp')
def __call__(self, f: NativeFunction) -> Optional[str]:
    """Generate the C++ Tensor-method binding for `f`.

    Returns None when `f` has no method variant; otherwise returns the
    declaration or definition text depending on `self.target`.
    """
    if Variant.method not in f.variants:
        return None

    # Method variants never take an explicit out argument and always have `self`.
    assert not f.func.is_out_fn()
    assert f.func.arguments.self_arg is not None

    name = cpp.name(f.func)

    sig_group = CppSignatureGroup.from_native_function(f, method=True, fallback_binding=f.manual_cpp_binding)

    if self.target is Target.DECLARATION:
        # Declare both the user-facing and (if present) the faithful signature.
        result = f"{sig_group.signature.decl()} const;\n"
        if sig_group.faithful_signature is not None:
            result += f"{sig_group.faithful_signature.decl()} const;\n"
        return result

    if self.target is not Target.DEFINITION:
        assert_never(self.target)

    def generate_defn(faithful: bool) -> str:
        # Translate the chosen C++ signature into dispatcher arguments and
        # emit either a dispatcher call or a static-dispatch stub.
        dispatcher_sig = DispatcherSignature.from_schema(f.func)

        if faithful:
            sig = sig_group.faithful_signature
            assert sig is not None
        else:
            sig = sig_group.signature

        dispatcher_exprs = translate(sig.arguments(), dispatcher_sig.arguments(), method=True)
        dispatcher_exprs_str = ', '.join(a.expr for a in dispatcher_exprs)

        static_dispatch_block = static_dispatch(f, sig, method=True, backend_index=self.static_dispatch_backend_index)
        if static_dispatch_block is None:
            return f"""
// aten::{f.func}
{sig.defn(prefix="Tensor::")} const {{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
        .typed<{dispatcher_sig.type()}>();
    return op.call({dispatcher_exprs_str});
}}
"""
        else:
            return f"""
// aten::{f.func}
{sig.defn(prefix="Tensor::")} const {{
{static_dispatch_block}
}}
"""

    result = generate_defn(faithful=False)
    if sig_group.faithful_signature is not None:
        result += generate_defn(faithful=True)

    return result
def callImpl(self, f: NativeFunction) -> str:
    """Generate the C++ at:: namespace function binding for `f`.

    Emits a declaration or definition depending on `self.target`; when
    `self.is_redispatching_fn` is set, emits the redispatch variant that
    takes an explicit DispatchKeySet.
    """
    name = cpp.name(f.func)

    sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=f.manual_cpp_binding)

    if self.target is Target.DECLARATION:
        sig_str = sig_group.signature.decl(is_redispatching_fn=self.is_redispatching_fn)
        result = f"TORCH_API {sig_str};\n"
        if sig_group.faithful_signature is not None:
            sig_str = sig_group.faithful_signature.decl(is_redispatching_fn=self.is_redispatching_fn)
            result += f"TORCH_API {sig_str};\n"
        return result

    if self.target is not Target.DEFINITION:
        assert_never(self.target)

    def generate_defn(faithful: bool) -> str:
        dispatcher_sig = DispatcherSignature.from_schema(f.func)

        # Fall back to the plain signature when no faithful one exists.
        if faithful and sig_group.faithful_signature is not None:
            sig = sig_group.faithful_signature
        else:
            sig = sig_group.signature

        dispatcher_exprs = translate(sig.arguments(), dispatcher_sig.arguments())
        if self.is_redispatching_fn:
            # Redispatch functions forward an explicit DispatchKeySet first.
            dispatcher_exprs_str = ', '.join(['dispatchKeySet'] + [a.expr for a in dispatcher_exprs])
            dispatcher_call = 'redispatch'
        else:
            dispatcher_exprs_str = ', '.join(a.expr for a in dispatcher_exprs)
            dispatcher_call = 'call'

        static_dispatch_block = static_dispatch(f, sig, method=False, backend_index=self.static_dispatch_backend_index)
        if static_dispatch_block is None:
            return f"""
// aten::{f.func}
{sig.defn(is_redispatching_fn=self.is_redispatching_fn)} {{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
        .typed<{dispatcher_sig.type()}>();
    return op.{dispatcher_call}({dispatcher_exprs_str});
}}
"""
        else:
            return f"""
// aten::{f.func}
{sig.defn(is_redispatching_fn=self.is_redispatching_fn)} {{
{static_dispatch_block}
}}
"""

    # Emit the plain definition first (when there is a faithful signature,
    # generate_defn(False) picks the plain one), then the faithful one.
    result = generate_defn(sig_group.faithful_signature is None)
    if sig_group.faithful_signature is not None:
        result += generate_defn(True)
    return result
def gen_variable_type_shard(
    fm: FileManager,
    fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
    template_name: str,
    output_name: str,
) -> None:
    """Write one shard of VariableType from `template_name` to `output_name`.

    For each function in the shard, collects the method declaration, the
    method definition (only when the kernel is code-generated, i.e. not
    hand-written and using the derived dispatch strategy), and the wrapper
    registration, then instantiates the template with those lists.
    """
    type_declarations: List[str] = []
    type_definitions: List[str] = []
    wrapper_registrations: List[str] = []

    for fn in fns_with_infos:
        f = fn.func
        name = cpp.name(f.func)
        formals = gen_formals(f)

        type_declarations.append(METHOD_DECLARATION.substitute(
            return_type=cpp.returns_type(f.func.returns),
            type_wrapper_name=type_wrapper_name(f),
            formals=formals,
        ))

        # Only generate a definition + registration for kernels that are
        # not hand-written and that use the derived dispatch strategy.
        if name not in MANUAL_AUTOGRAD and dispatch_strategy(fn) == 'use_derived':
            type_definitions.append(METHOD_DEFINITION.substitute(
                return_type=cpp.returns_type(f.func.returns),
                type_wrapper_name=type_wrapper_name(f),
                type_definition_body=emit_body(fn),
                formals=formals,
            ))
            wrapper_registrations.append(gen_wrapper_registration(f))

        # See Note [Manual Backend kernels]
        assert (name in MANUAL_BACKEND) == f.manual_kernel_registration
        # If you want to register a kernel to Autograd, you must make the op abstract.
        # In other words, this op must have dispatch section in native_functions.yaml.
        if name in MANUAL_AUTOGRAD_AND_TRACER or (fn.info and fn.info.has_derivatives):
            msg = (f'There\'s a formula for {name}(or its functional variant) in derivatives.yaml. '
                   f'It\'s required to add a dispatch section for it with explicit supported backends e.g CPU/CUDA '
                   f'or DefaultBackend in native_functions.yaml. Please see '
                   f'https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native#choosing-the-right-dispatch-keyword '
                   f'for instructions to choose the right dispatch keyword.')
            assert f.is_abstract, msg

    fm.write_with_template(output_name, template_name, lambda: {
        'generated_comment': '@' + f'generated from {fm.template_dir}/{template_name}',
        'type_derived_method_declarations': type_declarations,
        'type_derived_method_definitions': type_definitions,
        'wrapper_registrations': wrapper_registrations,
    })
def emit_view_lambda(f: NativeFunction, unpacked_bindings: List[Binding]) -> str:
    """ Generate an additional lambda function to recover views in backward when as_strided is not supported.
    See Note [View + Inplace update for base tensor] and [View + Inplace update for view tensor] for more details.
    """
    input_base = 'input_base'
    replay_view_func = ''
    updated_unpacked_args: List[str] = []
    # Only these argument types are known to be safe to capture in the
    # replay lambda (possibly after materialization below).
    known_view_arg_simple_types: List[CType] = [
        BaseCType(intT),
        OptionalCType(BaseCType(intT)),
        BaseCType(boolT),
        BaseCType(intArrayRefT)
    ]

    for unpacked_binding in unpacked_bindings:
        arg, arg_type = unpacked_binding.name, unpacked_binding.nctype.type
        if arg == 'self_':
            # `self_` is replaced by the lambda's input_base parameter.
            updated_unpacked_args.append(input_base)
            continue
        if arg_type not in known_view_arg_simple_types:
            known_types_str = ', '.join([str(t) for t in known_view_arg_simple_types])
            raise TypeError(
                f'You are adding an {arg_type} {arg} argument to op {cpp.name(f.func)} in addition to known types: '
                f'{known_types_str}. Please update the list or materialize it so that it can be closed '
                'over by value, also add a test in pytorch/xla/test/test_operations.py where this code '
                'is exercised.')

        if arg_type == BaseCType(intArrayRefT):
            # It's not safe to close over IntArrayRef by value, since this is a
            # reference type, so materialize a vector to close over by value
            arg_vec = arg + '_vec'
            replay_view_func += ARRAYREF_TO_VEC.substitute(arg=arg, vec=arg_vec)
            updated_unpacked_args.append(arg_vec)
        elif arg_type == OptionalCType(BaseCType(intT)):
            # Materialize int64_t? to int64_t
            arg_value = arg + '_val'
            replay_view_func += OPTIONAL_TO_VAL.substitute(arg=arg, val=arg_value, default='0')
            updated_unpacked_args.append(arg_value)
        else:
            updated_unpacked_args.append(arg)

    replay_view_call = emit_view_call(f, input_base, updated_unpacked_args)
    replay_view_func += REPLAY_VIEW_LAMBDA_FUNC.substitute(
        input_base=input_base,
        replay_view_call=replay_view_call)

    # Emitted as a C++ boolean literal for the generated code.
    is_view_with_metadata_change = 'true' if cpp.name(f.func) in VIEW_FUNCTIONS_WITH_METADATA_CHANGE else 'false'

    return SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE.substitute(
        is_view_with_metadata_change=is_view_with_metadata_change,
        replay_view_func=replay_view_func)
def cpp_dispatch_target(f: NativeFunction) -> str:
    """Return the C++ expression used to invoke `f` from generated code.

    Method variants dispatch through `self.`; function variants live in the
    torch:: namespace when they are factory-like (take tensor options or are
    *_like ops) and in at:: otherwise.
    """
    name = cpp.name(f.func)
    if Variant.method in f.variants:
        return f'self.{name}'
    if Variant.function in f.variants:
        is_factory_like = has_tensor_options(f) or f.func.name.name.base.endswith('_like')
        namespace = 'torch' if is_factory_like else 'at'
        return f'{namespace}::{name}'
    raise RuntimeError(f'could not dispatch, neither function nor method: {f.func}')
def should_generate_py_binding(f: NativeFunction) -> bool:
    """Decide whether a Python binding should be emitted for `f`.

    Skips ops whose name matches a compiled SKIP_PYTHON_BINDINGS pattern or
    whose full schema string is listed in SKIP_PYTHON_BINDINGS_SIGNATURES.
    """
    name = cpp.name(f.func)
    if any(skip_regex.match(name) for skip_regex in SKIP_PYTHON_BINDINGS):
        return False
    # Some ops are skipped by their full schema rather than by name.
    signature = str(f.func)
    return signature not in SKIP_PYTHON_BINDINGS_SIGNATURES
def gen_trace_type(out: str, native_yaml_path: str, template_path: str) -> None:
    """Generate the sharded TraceType_*.cpp files plus the Everything file."""
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    num_shards = 5
    shards: List[List[NativeFunction]] = [[] for _ in range(num_shards)]

    # functions are assigned arbitrarily but stably to a file based on hash
    native_functions = sorted(parse_native_yaml(native_yaml_path),
                              key=lambda f: cpp.name(f.func))
    for f in native_functions:
        bucket = sum(ord(c) for c in cpp.name(f.func)) % num_shards
        shards[bucket].append(f)

    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    for i, shard in enumerate(shards):
        gen_trace_type_shard(fm, shard, f'_{i}')
    gen_trace_type_shard(fm, native_functions, 'Everything')
def gen_trace_type(out: str, native_functions: List[NativeFunction], template_path: str) -> None:
    """Generate TraceType.cpp shards via FileManager.write_sharded."""
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    # Ops with hand-written tracing kernels are excluded from generation.
    traced_fns = [fn for fn in native_functions if cpp.name(fn.func) not in MANUAL_TRACER]
    fm.write_sharded(
        'TraceType.cpp',
        traced_fns,
        key_fn=lambda fn: cpp.name(fn.func),
        base_env={
            'generated_comment': f'@generated from {template_path}/TraceType.cpp',
        },
        env_callable=gen_trace_type_func,
        num_shards=5,
        sharded_keys={'trace_method_definitions', 'trace_wrapper_registrations'},
    )
def enforce_same_tensorimpl_and_storage(call: str, unpacked_bindings: List[Binding]) -> str:
    """Wrap `call` with debug-only checks that tensor / tensor-list arguments
    keep the same TensorImpl and Storage pointers across the call.

    NOTE(review): `f` is read as a free variable here — this function appears
    to depend on an enclosing scope that binds the current NativeFunction as
    `f`; confirm against the surrounding code before moving it.
    """
    save_ptrs_stmts: List[str] = []
    enforce_same_ptrs_stmts: List[str] = []
    if cpp.name(f.func) not in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
        for unpacked_binding in unpacked_bindings:
            arg = unpacked_binding.name
            # Strip const/ref so the type comparisons below are canonical.
            noref_cpp_type = unpacked_binding.nctype.type.remove_const_ref()
            if noref_cpp_type == BaseCType(tensorListT):
                save_ptrs_stmts += [
                    SAVE_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                    SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg)
                ]
                enforce_same_ptrs_stmts += [
                    ENFORCE_SAME_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                    ENFORCE_SAME_TENSORLIST_IMPL.substitute(tensorlist_name=arg)
                ]
            elif noref_cpp_type == ListCType(OptionalCType(BaseCType(tensorT))):
                save_ptrs_stmts += [
                    SAVE_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                    SAVE_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)
                ]
                enforce_same_ptrs_stmts += [
                    ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                    ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)
                ]
            elif noref_cpp_type == BaseCType(tensorT):
                save_ptrs_stmts += [
                    SAVE_TENSOR_STORAGE.substitute(tensor_name=arg),
                    SAVE_TENSOR_IMPL.substitute(tensor_name=arg)
                ]
                enforce_same_ptrs_stmts += [
                    ENFORCE_SAME_TENSOR_STORAGE.substitute(tensor_name=arg),
                    ENFORCE_SAME_TENSOR_IMPL.substitute(tensor_name=arg)
                ]
    # Both lists must be populated together or not at all.
    assert (save_ptrs_stmts and enforce_same_ptrs_stmts) or (
        not save_ptrs_stmts and not enforce_same_ptrs_stmts)
    if save_ptrs_stmts and enforce_same_ptrs_stmts:
        # The save/enforce statements only run in debug builds.
        call = RUN_ONLY_IN_DEBUG_MODE.substitute(statements=save_ptrs_stmts) + \
            call + \
            RUN_ONLY_IN_DEBUG_MODE.substitute(statements=enforce_same_ptrs_stmts)
    return call
def generate_out_variant_call(g: NativeFunctionsGroup) -> str:
    """Build the C++ expression that calls the out= variant of group `g`
    via the at::cpu:: namespace."""
    schema = g.out.func
    assert schema.is_out_fn()
    # In the at::cpu:: calling convention the out arguments come first.
    arg_names = [out_arg.name for out_arg in schema.arguments.out]
    for arg in schema.arguments.non_out:
        if isinstance(arg, SelfArgument):
            arg_names.append(arg.argument.name)
            continue
        assert isinstance(arg, Argument)
        arg_names.append(arg.name)
    joined_args = ",".join(arg_names)
    return f'at::cpu::{cpp.name(schema)}({joined_args})'
def enforce_same_tensorimpl_and_storage(call: str, unpacked_bindings: List[Binding]) -> str:
    """Wrap `call` with debug-only checks that tensor / tensor-list arguments
    keep the same TensorImpl and Storage pointers across the call.
    (Older variant that compares C++ type strings instead of CType objects.)

    NOTE(review): `f` is read as a free variable here — this function appears
    to depend on an enclosing scope that binds the current NativeFunction as
    `f`; confirm against the surrounding code before moving it.
    """
    save_ptrs_stmts: List[str] = []
    enforce_same_ptrs_stmts: List[str] = []
    if cpp.name(f.func) not in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
        for unpacked_binding in unpacked_bindings:
            arg = unpacked_binding.name
            # Strip references so the string comparisons below are canonical.
            noref_cpp_type = unpacked_binding.ctype.cpp_type(strip_ref=True)
            if noref_cpp_type == 'TensorList':
                save_ptrs_stmts += [
                    SAVE_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                    SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg)
                ]
                enforce_same_ptrs_stmts += [
                    ENFORCE_SAME_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                    ENFORCE_SAME_TENSORLIST_IMPL.substitute(tensorlist_name=arg)
                ]
            elif noref_cpp_type == 'c10::List<c10::optional<Tensor>>':
                save_ptrs_stmts += [
                    SAVE_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                    SAVE_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)
                ]
                enforce_same_ptrs_stmts += [
                    ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                    ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)
                ]
            elif noref_cpp_type == 'Tensor':
                save_ptrs_stmts += [
                    SAVE_TENSOR_STORAGE.substitute(tensor_name=arg),
                    SAVE_TENSOR_IMPL.substitute(tensor_name=arg)
                ]
                enforce_same_ptrs_stmts += [
                    ENFORCE_SAME_TENSOR_STORAGE.substitute(tensor_name=arg),
                    ENFORCE_SAME_TENSOR_IMPL.substitute(tensor_name=arg)
                ]
    # Both lists must be populated together or not at all.
    assert (save_ptrs_stmts and enforce_same_ptrs_stmts) or (
        not save_ptrs_stmts and not enforce_same_ptrs_stmts)
    if save_ptrs_stmts and enforce_same_ptrs_stmts:
        # The save/enforce statements only run in debug builds.
        call = RUN_ONLY_IN_DEBUG_MODE.substitute(statements=save_ptrs_stmts) + \
            call + \
            RUN_ONLY_IN_DEBUG_MODE.substitute(statements=enforce_same_ptrs_stmts)
    return call
def should_generate_py_binding(f: NativeFunction) -> bool:
    """Decide whether a Python binding should be emitted for `f`.

    Skips ops whose name fully matches a SKIP_PYTHON_BINDINGS pattern or
    whose rendered signature is listed in SKIP_PYTHON_BINDINGS_SIGNATURES.
    """
    name = cpp.name(f.func)
    for pattern in SKIP_PYTHON_BINDINGS:
        # BUGFIX: the previous `re.match('^' + pattern + '$', name)` broke
        # patterns containing alternation: '^foo|bar$' parses as
        # (^foo)|(bar$), matching any name starting with 'foo' OR ending
        # with 'bar'. re.fullmatch anchors the whole pattern correctly.
        if re.fullmatch(pattern, name):
            return False

    args = ', '.join(argument_type_str(arg.type) for arg in signature(f).arguments())
    sig = f'{name}({args})'
    for pattern in SKIP_PYTHON_BINDINGS_SIGNATURES:
        if pattern == sig:
            return False

    return True
def method_definition(f: NativeFunction) -> Optional[str]:
    """Emit the TraceType method definition for `f`, or None when the
    tracing kernel is maintained by hand (MANUAL_TRACER)."""
    if cpp.name(f.func) in MANUAL_TRACER:
        return None

    # Render each schema-order argument as "<cpp type> <name>".
    arg_decls = []
    for a in f.func.schema_order_arguments():
        cpp_type = cpp.argument_type(a, binds="__placeholder__").cpp_type()
        arg_decls.append(f'{cpp_type} {a.name}')

    return METHOD_DEFINITION.substitute(
        return_type=cpp.returns_type(f.func.returns),
        type_wrapper_name=type_wrapper_name(f),
        formals=', '.join(arg_decls),
        type_definition_body=emit_trace_body(f),
    )
def go(f: NativeFunction) -> str: # header comments deprecated = '[deprecated] ' if ps.deprecated else '' schema_comment = f'// {deprecated}aten::{f.func}' # dispatch lambda signature name = cpp.name(f.func) lambda_formals = ', '.join( map(lambda a: f"{a.type_str} {a.name}", dispatch_lambda_args(ps, f))) lambda_return = dispatch_lambda_return_str(f) # dispatch lambda body dispatch_callee = cpp_dispatch_target(f) dispatch_args = ', '.join(cpp_dispatch_exprs(f, python_signature=ps)) # from arg parser outputs to dispatch lambda arguments parser_outputs = arg_parser_output_exprs(ps, f) lambda_arg_exprs = dispatch_lambda_exprs(ps, f) inits = '\n'.join(lambda_arg_exprs.inits) lambda_args = ', '.join(lambda_arg_exprs.exprs) # scatter fields # TODO: Checking `ps.method and ('requires_grad' in parser_outputs)` is a hacky # solution for enabling the 'requires_grad' argument for tensor methods # new_full, new_empty, and new_zeros. A much better but more difficult to # implement solution involves refactoring according to Ed's description here: # https://github.com/pytorch/pytorch/issues/36455#issuecomment-614767589 need_set_requires_grad = ps.tensor_options_args and ( not has_tensor_options(f) or (ps.method and ('requires_grad' in parser_outputs))) set_requires_grad = f'.set_requires_grad({parser_outputs["requires_grad"].expr})' \ if need_set_requires_grad else '' if lambda_return == 'void': return f"""\ {schema_comment} {inits} auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{ pybind11::gil_scoped_release no_gil; {dispatch_callee}({dispatch_args}); }}; dispatch_{name}({lambda_args}){set_requires_grad}; Py_RETURN_NONE; """ else: typename = namedtuple_typenames.get(gen_namedtuple_typename_key(f)) namedtuple_typeref = f'&{typename}, ' if typename is not None else '' return f"""\
def emit_dispatch_call(f: NativeFunction, input_base: str, unpacked_args: Sequence[str]) -> str:
    """ Dispatch call via function in a namespace or method on Tensor."""
    dispatcher_sig = DispatcherSignature.from_schema(f.func)
    dispatcher_exprs = dispatcher_sig.exprs()

    # code-generated autograd kernels plumb and recompute dispatch keys
    # directly through the kernel for performance. Ops also always have a
    # function variant of the redispatch API.
    # See Note [Plumbing Keys Through The Dispatcher] for details.
    dispatch_key_set = 'ks & c10::after_autograd_keyset'
    api_name = cpp.name(f.func, faithful_name_for_out_overloads=True)
    return CALL_REDISPATCH.substitute(
        api_name=api_name,
        unpacked_args=[dispatch_key_set, *unpacked_args],
    )
def method_registration(f: NativeFunction) -> Optional[str]:
    """Emit the TraceType wrapper-registration snippet for `f`.

    Returns None for hand-written tracing kernels; otherwise picks the boxed
    registration for new-style c10 ops and the unboxed-only form for legacy ops.
    """
    if cpp.name(f.func) in MANUAL_TRACER:
        return None

    template = (WRAPPER_REGISTRATION
                if f.use_c10_dispatcher.dispatcher_uses_new_style()
                else UNBOXEDONLY_WRAPPER_REGISTRATION)
    return template.substitute(
        name=f.func.name,
        type_wrapper_name=type_wrapper_name(f),
        class_type='TraceType',
    )
def method_definition(f: NativeFunction) -> Optional[str]:
    """Emit the TraceType method definition for `f`, or None when the
    tracing kernel is maintained by hand (MANUAL_TRACER).

    The formals are rendered differently depending on whether the op uses
    the new-style c10 dispatcher (schema-order arguments) or the legacy
    C++ signature group.
    """
    if cpp.name(f.func) in MANUAL_TRACER:
        return None

    if f.use_c10_dispatcher.dispatcher_uses_new_style():
        arg_strs = [f'{cpp.argument_type(a)} {a.name}'
                    for a in f.func.schema_order_arguments()]
    else:
        sig_group = CppSignatureGroup.from_schema(f.func, method=False)
        arg_strs = [f'{a.type} {a.name}'
                    for a in sig_group.signature.arguments()]

    return METHOD_DEFINITION.substitute(
        return_type=cpp.returns_type(f.func.returns),
        type_wrapper_name=type_wrapper_name(f),
        formals=', '.join(arg_strs),
        type_definition_body=emit_trace_body(f),
    )
def emit_namedtuple_typedefs(
    overloads: Sequence[PythonSignatureNativeFunctionPair]
) -> Tuple[List[str], Dict[str, str]]:
    """
    Generate block of named tuple type def inits, and add typeref snippets
    to declarations that use them
    """
    flddefnames: Dict[str, str] = {}  # map from unique field name lists to field def name
    flddefs: List[str] = []  # field def declarations
    typenames: Dict[str, str] = {}  # map from unique name + field name lists to typedef name
    typedefs: List[str] = []  # typedef declarations and init code

    for overload in overloads:
        fieldnames = namedtuple_fieldnames(overload.function.func.returns)
        if not fieldnames:
            # Unnamed returns need no struct-sequence machinery.
            continue

        fn_key = '_'.join(fieldnames)
        fieldsname = flddefnames.get(fn_key)
        if fieldsname is None:
            # First time we see this field-name combination: emit the
            # PyStructSequence_Field array once and reuse it afterwards.
            fieldsname = f'NamedTuple_fields{"" if not flddefs else len(flddefs)}'
            flddefnames[fn_key] = fieldsname
            fields = ', '.join(f'{{"{fn}", ""}}' for fn in fieldnames)
            flddefs.append(f"""\
static PyStructSequence_Field {fieldsname}[] = {{ {fields}, {{nullptr}} }};
""")

        name = cpp.name(overload.function.func)  # use @with_native_function?
        tn_key = gen_namedtuple_typename_key(overload.function)
        typename = typenames.get(tn_key)
        if typename is None:
            # Lazily-initialized PyTypeObject per unique (name, fields) key.
            typename = f'NamedTuple{"" if not typedefs else len(typedefs)}'
            typenames[tn_key] = typename
            typedefs.append(f"""\
static PyTypeObject {typename};
static bool {typename}_initialized = false;
if (!{typename}_initialized) {{
  {typename}_initialized = true;
  static PyStructSequence_Desc desc = {{ "torch.return_types.{name}", nullptr, {fieldsname}, {len(fieldnames)} }};
  PyStructSequence_InitType(&{typename}, &desc);
  {typename}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
}}
""")

    return flddefs + typedefs, typenames
def method_definition(f: NativeFunction) -> str:
    """Emit the TraceType kernel definition for `f`.

    Callers must have already filtered out hand-written tracing kernels.
    """
    assert cpp.name(f.func) not in MANUAL_TRACER

    # code-generated tracing kernels plumb and recompute dispatch keys
    # directly through the kernel for performance.
    # See Note [Plumbing Keys Through The Dispatcher] for details.
    params = ['c10::DispatchKeySet ks']
    for a in f.func.schema_order_arguments():
        params.append(f'{cpp.argument_type(a, binds="__placeholder__").cpp_type()} {a.name}')

    return METHOD_DEFINITION.substitute(
        return_type=cpp.returns_type(f.func.returns).cpp_type(),
        type_wrapper_name=type_wrapper_name(f),
        formals=', '.join(params),
        type_definition_body=emit_trace_body(f),
    )
def generate_return_type_definition_and_map_entry(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> Tuple[List[str], List[str]]:
    """
    Generate block of function in `python_return_types.cpp` to initialize
    and return named tuple for a native function which returns named tuple
    and relevant entry for the map in same file.
    """
    typenames: Dict[str, str] = {}  # map from unique name + field name lists to typedef name
    definitions: List[str] = []  # function defintion to register the typedef
    map_entries: List[str] = []  # C++ map entry of <function_name, function creates it namedtuple>

    for overload in overloads:
        fieldnames = namedtuple_fieldnames(overload.function.func.returns)
        if not fieldnames:
            # Unnamed returns need no named-tuple machinery.
            continue

        fields = ', '.join(f'{{"{fn}", ""}}' for fn in fieldnames)

        name = cpp.name(overload.function.func)  # use @with_native_function?
        tn_key = gen_namedtuple_typename_key(overload.function)
        typename = typenames.get(tn_key)

        if typename is None:
            # One lazily-initialized getter per unique (name, fields) key.
            typename = f'{name}NamedTuple{"" if not definitions else len(definitions)}'
            typenames[tn_key] = typename
            definitions.append(f"""\
PyTypeObject* get_{name}_namedtuple() {{
    static PyStructSequence_Field NamedTuple_fields[] = {{ {fields}, {{nullptr}} }};
    static PyTypeObject {typename};
    static bool is_initialized = false;
    static PyStructSequence_Desc desc = {{ "torch.return_types.{name}", nullptr, NamedTuple_fields, {len(fieldnames)} }};
    if (!is_initialized) {{
        PyStructSequence_InitType(&{typename}, &desc);
        {typename}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
        is_initialized = true;
    }}
    return &{typename};
}}
""")
            map_entries.append(f'{{"{name}", get_{name}_namedtuple()}}, ')

    return definitions, map_entries
def __call__(self, f: NativeFunction) -> Optional[str]:
    """Generate the C++ at:: namespace function binding for `f`.

    Returns None when `f` has no function variant; otherwise returns the
    declaration or definition text depending on `self.target`.
    """
    if Variant.function not in f.variants:
        return None

    name = cpp.name(f.func)

    sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=f.manual_cpp_binding)

    if self.target is Target.DECLARATION:
        # Declare both the user-facing and (if present) the faithful signature.
        result = f"TORCH_API {sig_group.signature.decl()};\n"
        if sig_group.faithful_signature is not None:
            result += f"TORCH_API {sig_group.faithful_signature.decl()};\n"
        return result

    if self.target is not Target.DEFINITION:
        assert_never(self.target)

    def generate_defn(faithful: bool) -> str:
        dispatcher_sig = DispatcherSignature.from_schema(f.func)

        # Fall back to the plain signature when no faithful one exists.
        if faithful and sig_group.faithful_signature is not None:
            sig = sig_group.faithful_signature
        else:
            sig = sig_group.signature

        dispatcher_exprs = translate(sig.arguments(), dispatcher_sig.arguments())
        dispatcher_exprs_str = ', '.join(a.expr for a in dispatcher_exprs)

        return f"""
// aten::{f.func}
{sig.defn()} {{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
        .typed<{dispatcher_sig.type()}>();
    return op.call({dispatcher_exprs_str});
}}
"""

    # Emit the plain definition first (when a faithful signature exists,
    # generate_defn(False) picks the plain one), then the faithful one.
    result = generate_defn(sig_group.faithful_signature is None)
    if sig_group.faithful_signature is not None:
        result += generate_defn(True)
    return result
def go(f: NativeFunction) -> Optional[str]:
    """Generate the C++ Tensor-method binding for `f` (legacy codegen variant).

    NOTE(review): `target` is read as a free variable — this function appears
    to be a closure over an enclosing scope that binds `target`; confirm
    against the surrounding code.
    """
    if Variant.method not in f.variants:
        return None

    # Method variants never take an explicit out argument and have exactly
    # one `self` argument.
    assert not f.func.is_out_fn()
    assert len(f.func.arguments) > 0
    assert sum(a.name == 'self' for a in f.func.arguments) == 1

    name = cpp.name(f.func)
    sig_group = CppSignatureGroup.from_schema(f.func, method=True)

    if target is Target.DECLARATION:
        result = f"{sig_group.signature.decl()} const;\n"
        if sig_group.faithful_signature is not None:
            result += f"{sig_group.faithful_signature.decl()} const;\n"
        return result

    assert target is Target.DEFINITION

    def generate_defn(sig: CppSignature) -> str:
        dispatcher_sig = DispatcherSignature.from_schema(f.func)
        dispatcher_exprs = dispatcher.cpparguments_exprs(sig.argument_packs())
        dispatcher_exprs_str = ', '.join(map(lambda a: a.expr, dispatcher_exprs))

        return f"""
// aten::{f.func}
{sig.defn(prefix="Tensor::")} const {{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
        .typed<{dispatcher_sig.type()}>();
    return op.call({dispatcher_exprs_str});
}}
"""

    result = generate_defn(sig_group.signature)
    if sig_group.faithful_signature is not None:
        result += generate_defn(sig_group.faithful_signature)
    return result
def go(f: NativeFunction) -> Optional[str]:
    """Generate the C++ at:: namespace function binding for `f` (legacy
    codegen variant).

    NOTE(review): `target` and `local` are read as free variables — this
    function appears to be a closure over an enclosing scope that binds
    them; confirm against the surrounding code.
    """
    # Hand-registered kernels and ops without a function variant are skipped.
    if f.manual_kernel_registration:
        return None
    if Variant.function not in f.variants:
        return None

    name = cpp.name(f.func)
    sig_group = CppSignatureGroup.from_schema(f.func, method=False)

    if target is Target.DECLARATION:
        result = f"CAFFE2_API {sig_group.signature.decl()};\n"
        if sig_group.faithful_signature is not None:
            result += f"CAFFE2_API {sig_group.faithful_signature.decl()};\n"
        return result

    assert target is Target.DEFINITION

    def generate_defn(sig: CppSignature) -> str:
        dispatcher_sig = DispatcherSignature.from_schema(f.func)
        dispatcher_exprs = dispatcher.cpparguments_exprs(sig.argument_packs())
        dispatcher_exprs_str = ', '.join(map(lambda a: a.expr, dispatcher_exprs))

        return f"""
// aten::{f.func}
{sig.defn()} {{
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("aten::{f.func.name.name}", "{f.func.name.overload_name}")
        .typed<{dispatcher_sig.type()}>();
    return op.call({dispatcher_exprs_str});
}}
"""

    result = generate_defn(sig_group.signature)
    if sig_group.faithful_signature is not None:
        # Faithful definitions are only emitted for new-style c10 ops.
        if local.use_c10_dispatcher().dispatcher_uses_new_style():
            result += generate_defn(sig_group.faithful_signature)
    return result
def gen_inplace_or_view_type(
    out: str,
    native_yaml_path: str,
    fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
    template_path: str
) -> None:
    """Generate the sharded ADInplaceOrViewType files plus the Everything file."""
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    num_shards = 2
    shards: List[List[NativeFunctionWithDifferentiabilityInfo]] = [
        [] for _ in range(num_shards)
    ]

    # functions are assigned arbitrarily but stably to a file based on hash
    for fn in fns_with_infos:
        bucket = sum(ord(c) for c in cpp.name(fn.func.func)) % num_shards
        shards[bucket].append(fn)

    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    for idx, shard in enumerate(shards):
        gen_inplace_or_view_type_shard(fm, shard, f'_{idx}')
    gen_inplace_or_view_type_shard(fm, fns_with_infos, 'Everything')
def gen_autograd(
    aten_path: str,
    native_functions_path: str,
    out: str,
    autograd_dir: str,
    operator_selector: SelectiveBuilder,
    disable_autograd: bool = False,
) -> None:
    """Top-level entry point: generate all autograd-related C++ sources
    (VariableType, TraceType, Functions, variable_factories.h) into `out`.
    """
    # Parse and load derivatives.yaml
    from .load_derivatives import load_derivatives
    differentiability_infos = load_derivatives(
        os.path.join(autograd_dir, 'derivatives.yaml'), native_functions_path)

    template_path = os.path.join(autograd_dir, 'templates')

    # Only functions selected for training are generated, in a stable
    # (name-sorted) order.
    fns = list(sorted(filter(
        operator_selector.is_native_function_selected_for_training,
        parse_native_yaml(native_functions_path)), key=lambda f: cpp.name(f.func)))
    fns_with_diff_infos: List[NativeFunctionWithDifferentiabilityInfo] = match_differentiability_info(
        fns, differentiability_infos)

    # Generate VariableType.h/cpp
    from .gen_trace_type import gen_trace_type
    from .gen_variable_type import gen_variable_type
    if not disable_autograd:
        gen_variable_type(out, native_functions_path, fns_with_diff_infos, template_path)
        # operator filter not applied as tracing sources are excluded in selective build
        gen_trace_type(out, native_functions_path, template_path)

    # Generate Functions.h/cpp
    from .gen_autograd_functions import gen_autograd_functions_lib
    gen_autograd_functions_lib(out, differentiability_infos, template_path)

    # Generate variable_factories.h
    from .gen_variable_factories import gen_variable_factories
    gen_variable_factories(out, native_functions_path, template_path)
def format_trace_op_name(f: NativeFunction) -> str:
    """Emit the code that selects the JIT op name to record for `f`."""
    # TODO: byte-for-byte compatible with old codegen behavior - should clean up
    single_name = (
        f.func.kind() in (SchemaKind.functional, SchemaKind.out)
        or f.func.name.name.dunder_method
    )
    if single_name:
        # special case for *_out functions: the in-place and out-of-place ops
        # are overloaded with the same name in the JIT
        trace_name = str(f.func.name.name)
        trace_name = RENAME_TRACE.get(trace_name, trace_name)
        return OP_NAME.substitute(trace_name=trace_name)

    # otherwise, this is an in-place op and we need to emit both in- and
    # out-of-place versions
    base = f.func.name.name.base
    outplace_trace_name = RENAME_TRACE.get(base, base)
    inplace = cpp.name(f.func)
    inplace_trace_name = RENAME_TRACE.get(inplace, inplace)

    return SELECT.substitute(
        cond='tracer_state->force_outplace',
        true=OP_NAME.substitute(trace_name=outplace_trace_name),
        false=OP_NAME.substitute(trace_name=inplace_trace_name),
    )
def compute_native_function_declaration(f: NativeFunction) -> List[str]:
    """Return the CAFFE2_API declaration lines for `f`'s native kernels."""
    if f.dispatch is None:
        names = [cpp.name(f.func)]
    else:
        names = list(f.dispatch.values())

    rs: List[str] = []
    # Sometimes a function name shows up multiple times; only generate
    # it once!
    seen = set()
    for n in names:
        # Skip duplicates and legacy:: kernels (declared elsewhere).
        if n in seen or "legacy::" in n:
            continue
        seen.add(n)
        returns_type = legacy_dispatcher.returns_type(f.func.returns)
        args = legacy_dispatcher.arguments(f.func)
        arg_str = ', '.join(a.str_with_default() for a in args)
        rs.append(f"CAFFE2_API {returns_type} {n}({arg_str});")

    return rs
def format_prerecord_trace(f: NativeFunction) -> str:
    """Emit the tracing preamble recorded before the op runs, or '' when
    `f` is not traced."""
    if not should_trace(f):
        return ''

    # TODO: clean up old codegen behavior
    is_inplace = (
        f.func.kind() in (SchemaKind.inplace, SchemaKind.out)
        and not f.func.name.name.dunder_method
    )

    add_args = ''
    if is_inplace:
        add_args = RENAME_TRACE_ADD_ARGS.get(f.func.name.name.base, '')
    if add_args:
        # Extra trace inputs only apply when force_outplace rewrites the op.
        additional_inputs = SELECT.substitute(
            cond='tracer_state->force_outplace',
            true=add_args,
            false='',
        )
    else:
        additional_inputs = ''

    inplace_guard = ''
    if is_inplace:
        mutable_input = f.func.arguments.out[0].name if f.func.arguments.out else 'self'
        inplace_guard = INPLACE_GUARD.substitute(
            name=cpp.name(f.func),
            mutable_input=mutable_input,
        )

    return PRE_RECORD_TRACE.substitute(
        set_op_name=format_trace_op_name(f),
        add_trace_inputs=format_trace_inputs(f) + additional_inputs,
        inplace_guard=inplace_guard,
    )