def canonical_function(functions: Sequence[NativeFunction], name: str) -> NativeFunction:
    """Return the first entry of ``functions`` whose C++ name equals ``name``.

    When no entry matches, the first function is returned instead, after
    asserting that its C++ name is ``name`` followed by an underscore.
    """
    matched = next((fn for fn in functions if cpp.name(fn.func) == name), None)
    if matched is not None:
        return matched

    # some functions only have in-place variants
    first = functions[0]
    assert name + "_" == cpp.name(first.func)
    return first
def format_prerecord_trace(f: NativeFunction) -> str:
    """Render the PRE_RECORD_TRACE template for ``f``.

    Returns the empty string when ``should_trace(f)`` is false. For inplace
    and out= schemas that are not dunder methods, the result additionally
    carries an in-place guard and (if RENAME_TRACE_ADD_ARGS has an entry for
    the base name) extra trace inputs selected on
    ``tracer_state->force_outplace``.
    """
    if not should_trace(f):
        return ""

    schema = f.func

    # TODO: clean up old codegen behavior
    is_inplace = (
        schema.kind() in (SchemaKind.inplace, SchemaKind.out)
        and not schema.name.name.dunder_method
    )

    add_args = ""
    if is_inplace:
        add_args = RENAME_TRACE_ADD_ARGS.get(schema.name.name.base, "")

    additional_inputs = ""
    if add_args:
        additional_inputs = SELECT.substitute(
            cond="tracer_state->force_outplace",
            true=add_args,
            false="",
        )

    set_op_name = format_trace_op_name(f)
    add_trace_inputs = format_trace_inputs(f) + additional_inputs

    inplace_guard = ""
    if is_inplace:
        guard_name = cpp.name(schema)
        out_args = schema.arguments.out
        # Guard the first out= argument when there is one, otherwise "self".
        mutable_input = out_args[0].name if out_args else "self"
        inplace_guard = INPLACE_GUARD.substitute(
            name=guard_name,
            mutable_input=mutable_input,
        )

    return PRE_RECORD_TRACE.substitute(
        set_op_name=set_op_name,
        add_trace_inputs=add_trace_inputs,
        inplace_guard=inplace_guard,
    )
def is_factory_function(f: NativeFunction) -> bool:
    """Tell whether ``f`` is treated as a factory function.

    ``f`` must have the ``function`` variant, and must either take tensor
    options (per ``python.has_tensor_options``) or have a C++ name ending in
    ``_like``.
    """
    if Variant.function in f.variants:
        cpp_name = cpp.name(f.func)
        takes_tensor_options = python.has_tensor_options(f)
        return takes_tensor_options or cpp_name.endswith("_like")
    return False
def process_function(f: NativeFunction) -> Optional[str]: name = cpp.name(f.func) has_tensor_options = python.has_tensor_options(f) is_factory = has_tensor_options or name.endswith("_like") if Variant.function not in f.variants or not is_factory: return None sig = CppSignatureGroup.from_native_function(f, method=False).signature formals: List[str] = [] exprs: List[str] = [] requires_grad = "false" for arg in sig.arguments(): qualified_type = fully_qualified_type(arg.type) if arg.default: formals.append(f"{qualified_type} {arg.name} = {arg.default}") else: formals.append(f"{qualified_type} {arg.name}") if isinstance(arg.argument, TensorOptionsArguments): # note: we remove the requires_grad setting from the TensorOptions because # it is ignored anyways (and we actually have an assertion that it isn't set # which would fail otherwise). We handle requires_grad explicitly here # instead of passing it through to the kernel. exprs.append( f"at::TensorOptions({arg.name}).requires_grad(c10::nullopt)") # Manually set the requires_grad bit on the result tensor. requires_grad = f"{arg.name}.requires_grad()" else: exprs.append(arg.name) return f"""\
def gen_trace_type(out: str, native_functions: List[NativeFunction], template_path: str) -> None:
    """Write the sharded ``TraceType.cpp`` output into ``out``.

    Functions whose C++ name appears in MANUAL_TRACER are left out; the rest
    are distributed over five shards keyed by their ``root_name``.
    """
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)

    traced_functions = [
        fn for fn in native_functions if cpp.name(fn.func) not in MANUAL_TRACER
    ]
    base_env = {
        "generated_comment": f"@generated from {template_path}/TraceType.cpp",
    }
    sharded_keys = {
        "ops_headers",
        "trace_method_definitions",
        "trace_wrapper_registrations",
    }

    fm.write_sharded(
        "TraceType.cpp",
        traced_functions,
        key_fn=lambda fn: fn.root_name,
        base_env=base_env,
        env_callable=gen_trace_type_func,
        num_shards=5,
        sharded_keys=sharded_keys,
    )
def emit_namedtuple_call(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> Tuple[List[str], Dict[str, str]]:
    """
    Build the named tuple typedef init lines for ``overloads``.

    Returns the list of typedef declarations together with a map from each
    namedtuple typename key to the typedef name assigned to it. Overloads
    without named return fields are ignored, and overloads sharing a key
    reuse the typedef emitted for the first of them.
    """
    # map from unique name + field name lists to typedef name
    typenames: Dict[str, str] = {}
    # typedef declarations and init code
    typedefs: List[str] = []

    for pair in overloads:
        native_fn = pair.function
        if not namedtuple_fieldnames(native_fn.func.returns):
            continue

        name = cpp.name(native_fn.func)  # use @with_native_function?
        tn_key = gen_namedtuple_typename_key(native_fn)
        if typenames.get(tn_key) is not None:
            # A typedef was already emitted for this key.
            continue

        # The first typedef has no numeric suffix; later ones use the running count.
        typename = f'NamedTuple{"" if not typedefs else len(typedefs)}'
        typenames[tn_key] = typename
        typedefs.append(
            f"""\
static PyTypeObject* {typename} = get_namedtuple("{name}");"""
        )

    return typedefs, typenames
def method_registration(f: NativeFunction) -> str:
    """Render the WRAPPER_REGISTRATION template for ``f`` under class ``TraceType``.

    Asserts that the C++ name of ``f`` is not listed in MANUAL_TRACER.
    """
    assert cpp.name(f.func) not in MANUAL_TRACER

    substitutions = {
        "name": f.func.name,
        "type_wrapper_name": type_wrapper_name(f),
        "class_type": "TraceType",
    }
    return WRAPPER_REGISTRATION.substitute(**substitutions)
def emit_view_lambda(f: NativeFunction, unpacked_bindings: List[Binding]) -> str:
    """Generate an additional lambda function to recover views in backward when
    as_strided is not supported.

    See Note [View + Inplace update for base tensor] and
    [View + Inplace update for view tensor] for more details.

    Raises:
        TypeError: if a binding other than ``self_`` has a type outside the
            list of known simple view-argument types.
    """
    input_base = "input_base"

    array_ref_types = (BaseCType(intArrayRefT), BaseCType(symIntArrayRefT))
    optional_long_type = OptionalCType(BaseCType(longT))
    known_view_arg_simple_types: List[CType] = [
        BaseCType(longT),
        optional_long_type,
        BaseCType(boolT),
        *array_ref_types,
    ]

    replay_view_func = ""
    updated_unpacked_args: List[str] = []

    for binding in unpacked_bindings:
        arg = binding.name
        arg_type = binding.nctype.type

        if arg == "self_":
            # The unpacked self is replaced by the lambda's base input.
            updated_unpacked_args.append(input_base)
            continue

        if arg_type not in known_view_arg_simple_types:
            known_types_str = ", ".join([str(t) for t in known_view_arg_simple_types])
            raise TypeError(
                f"You are adding an {arg_type} {arg} argument to op {cpp.name(f.func)} in addition to known types: "
                f"{known_types_str}. Please update the list or materialize it so that it can be closed "
                "over by value, also add a test in pytorch/xla/test/test_operations.py where this code "
                "is exercised."
            )

        if arg_type in array_ref_types:
            # It's not safe to close over IntArrayRef by value, since this is a
            # reference type, so materialize a vector to close over by value
            arg_vec = arg + "_vec"
            replay_view_func += ARRAYREF_TO_VEC.substitute(arg=arg, vec=arg_vec)
            updated_unpacked_args.append(arg_vec)
        elif arg_type == optional_long_type:
            # Materialize int64_t? to int64_t
            arg_value = arg + "_val"
            replay_view_func += OPTIONAL_TO_VAL.substitute(
                arg=arg, val=arg_value, default="0"
            )
            updated_unpacked_args.append(arg_value)
        else:
            updated_unpacked_args.append(arg)

    replay_view_call = emit_view_call(f, input_base, updated_unpacked_args)
    replay_view_func += REPLAY_VIEW_LAMBDA_FUNC.substitute(
        input_base=input_base, replay_view_call=replay_view_call
    )

    if cpp.name(f.func) in VIEW_FUNCTIONS_WITH_METADATA_CHANGE:
        is_view_with_metadata_change = "true"
    else:
        is_view_with_metadata_change = "false"

    return SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE.substitute(
        is_view_with_metadata_change=is_view_with_metadata_change,
        replay_view_func=replay_view_func,
    )
def canonical_function(functions: Sequence[NativeFunction], name: str) -> NativeFunction:
    """Return the first function that is neither functional nor out= and is named ``name``.

    The comparison uses ``str(f.func.name.name)``. When nothing matches, the
    first function is returned after asserting that its C++ name is ``name``
    followed by an underscore.
    """
    for candidate in functions:
        schema = candidate.func
        if schema.is_functional_fn() or schema.is_out_fn():
            continue
        if name == str(schema.name.name):
            return candidate

    # some functions only have in-place variants
    first = functions[0]
    assert name + "_" == cpp.name(first.func)
    return first
def generate_call_to_view_ops(g: NativeFunctionsViewGroup, backend_index: BackendIndex) -> str:
    """Build the ``at::native::<kernel>(<args>)`` call expression for the view op of ``g``.

    The kernel name comes from ``backend_index`` when it has a kernel for the
    view function, and otherwise falls back to the schema's C++ name.
    Arguments are passed by name in schema order.
    """
    schema = g.view.func
    default_kernel_name = cpp.name(schema)
    kernel = backend_index.get_kernel(g.view)
    kernel_name = kernel.kernel if kernel else default_kernel_name

    joined_args = ",".join(arg.name for arg in schema.schema_order_arguments())
    namespace_name = "native"
    return f"at::{namespace_name}::{kernel_name}({joined_args})"
def go(f: NativeFunction) -> str: # header comments if isinstance(ps, PythonSignatureDeprecated): schema_comment = f"// [deprecated] aten::{ps.deprecated_schema}" else: schema_comment = f"// aten::{f.func}" deprecated = "[deprecated] " if ps.deprecated else "" # dispatch lambda signature name = cpp.name(f.func) lambda_formals = ", ".join( map(lambda a: f"{a.type_str} {a.name}", dispatch_lambda_args(ps, f))) lambda_return = dispatch_lambda_return_str(f) # dispatch lambda body dispatch_callee = cpp_dispatch_target(f) dispatch_args = ", ".join(cpp_dispatch_exprs(f, python_signature=ps)) # from arg parser outputs to dispatch lambda arguments parser_outputs = arg_parser_output_exprs(ps, f) lambda_arg_exprs = dispatch_lambda_exprs(ps, f) inits = "\n".join(lambda_arg_exprs.inits) lambda_args = ", ".join(lambda_arg_exprs.exprs) # scatter fields # TODO: Checking `ps.method and ('requires_grad' in parser_outputs)` is a hacky # solution for enabling the 'requires_grad' argument for tensor methods # new_full, new_empty, and new_zeros. A much better but more difficult to # implement solution involves refactoring according to Ed's description here: # https://github.com/pytorch/pytorch/issues/36455#issuecomment-614767589 need_set_requires_grad = ps.tensor_options_args and ( not has_tensor_options(f) or (ps.method and ("requires_grad" in parser_outputs))) set_requires_grad = ( f'.set_requires_grad({parser_outputs["requires_grad"].expr})' if need_set_requires_grad else "") if lambda_return == "void": return f"""\ {schema_comment} {inits} auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{ pybind11::gil_scoped_release no_gil; {dispatch_callee}({dispatch_args}); }}; dispatch_{name}({lambda_args}){set_requires_grad}; Py_RETURN_NONE; """ else: typename = namedtuple_typenames.get(gen_namedtuple_typename_key(f)) namedtuple_typeref = f"{typename}, " if typename is not None else "" return f"""\
def should_generate_py_binding(f: NativeFunction) -> bool:
    """Return False when ``f`` is excluded from Python binding generation.

    A function is excluded when its C++ name matches any regex in
    SKIP_PYTHON_BINDINGS, or when its full schema string equals an entry of
    SKIP_PYTHON_BINDINGS_SIGNATURES.
    """
    name = cpp.name(f.func)
    if any(skip_regex.match(name) for skip_regex in SKIP_PYTHON_BINDINGS):
        return False

    signature = str(f.func)
    return not any(
        pattern == signature for pattern in SKIP_PYTHON_BINDINGS_SIGNATURES
    )
def generate_out_variant_call(g: NativeFunctionsGroup) -> str:
    """Build the ``at::cpu::<name>(<args>)`` call for the out= variant of ``g``.

    The out arguments come first, followed by the non-out arguments in
    order. Asserts that the schema of ``g.out`` is an out= function.
    """
    schema = g.out.func
    assert schema.is_out_fn()

    def _plain_name(arg) -> str:
        # SelfArgument wraps the real argument; everything else must be an Argument.
        if isinstance(arg, SelfArgument):
            return arg.argument.name
        assert isinstance(arg, Argument)
        return arg.name

    arg_names = [out_arg.name for out_arg in schema.arguments.out]
    arg_names.extend(_plain_name(arg) for arg in schema.arguments.non_out)

    cpp_func_name = cpp.name(schema)
    cpp_arg_names = ",".join(arg_names)
    return f"at::cpu::{cpp_func_name}({cpp_arg_names})"
def type_wrapper_name(f: NativeFunction, key: str = "Default") -> str:
    """Compose the wrapper name ``<cpp name>[_<overload name>][_<key>]`` for ``f``.

    The overload name is appended only when non-empty, and ``key`` is
    appended only when it differs from ``"Default"``.
    """
    overload_name = f.func.name.overload_name
    base_name = cpp.name(f.func)
    name = f"{base_name}_{overload_name}" if overload_name else base_name

    # The key argument is only used in gen_variable_type where we need fns per autograd dispatch key.
    # In gen_trace_type and gen_inplace_view_type where only one fn per native_fn must be generated,
    # the key argument should not be passed.
    # We do not append key if it is Default so that generated functions from
    # before per-dispatch-key derivatives were added retain the same names.
    key_suffix = "" if key == "Default" else f"_{key}"
    return name + key_suffix
def generate_return_type_definition_and_map_entry(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> Tuple[List[str], List[str]]:
    """
    Generate, for `python_return_types.cpp`, the functions that initialize and
    return the named tuple type of each native function returning a named
    tuple, together with the matching entries for the map in the same file.

    Overloads without named return fields are skipped, and only the first
    overload seen for a given typename key produces output.
    """
    # map from unique name + field name lists to typedef name
    typenames: Dict[str, str] = {}
    # function definition to register the typedef
    definitions: List[str] = []
    # C++ map entry of <function_name, function creates it namedtuple>
    map_entries: List[str] = []

    for pair in overloads:
        native_fn = pair.function
        fieldnames = namedtuple_fieldnames(native_fn.func.returns)
        if not fieldnames:
            continue

        fields = ", ".join(f'{{"{fn}", ""}}' for fn in fieldnames)

        name = cpp.name(native_fn.func)  # use @with_native_function?
        tn_key = gen_namedtuple_typename_key(native_fn)
        if typenames.get(tn_key) is not None:
            # Already emitted a definition for this key.
            continue

        typename = f'{name}NamedTuple{"" if not definitions else len(definitions)}'
        typenames[tn_key] = typename
        definitions.append(
            f"""\
PyTypeObject* get_{name}_namedtuple() {{
    static PyStructSequence_Field NamedTuple_fields[] = {{ {fields}, {{nullptr}} }};
    static PyTypeObject {typename};
    static bool is_initialized = false;
    static PyStructSequence_Desc desc = {{ "torch.return_types.{name}", nullptr, NamedTuple_fields, {len(fieldnames)} }};
    if (!is_initialized) {{
        PyStructSequence_InitType(&{typename}, &desc);
        {typename}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
        is_initialized = true;
    }}
    return &{typename};
}}
"""
        )
        map_entries.append(f'{{"{name}", get_{name}_namedtuple()}}, ')

    return definitions, map_entries
def gen_autograd(
    native_functions_path: str,
    tags_path: str,
    out: str,
    autograd_dir: str,
    operator_selector: SelectiveBuilder,
    disable_autograd: bool = False,
) -> None:
    """Drive the autograd code generators, writing their output under ``out``.

    Loads ``derivatives.yaml`` from ``autograd_dir``, parses the native
    function yaml, and then runs the individual generators. The VariableType
    and inplace-or-view generators are skipped when ``disable_autograd`` is
    true.
    """
    # Parse and load derivatives.yaml
    derivatives_path = os.path.join(autograd_dir, "derivatives.yaml")
    differentiability_infos, used_dispatch_keys = load_derivatives(
        derivatives_path, native_functions_path, tags_path
    )

    template_path = os.path.join(autograd_dir, "templates")

    native_funcs = parse_native_yaml(native_functions_path, tags_path).native_functions

    # Keep only functions selected for training, ordered by C++ name.
    selected_for_training = [
        fn
        for fn in native_funcs
        if operator_selector.is_native_function_selected_for_training(fn)
    ]
    fns = sorted(selected_for_training, key=lambda f: cpp.name(f.func))

    fns_with_diff_infos: List[
        NativeFunctionWithDifferentiabilityInfo
    ] = match_differentiability_info(fns, differentiability_infos)

    # Generate VariableType.h/cpp
    if not disable_autograd:
        gen_variable_type(
            out,
            native_functions_path,
            tags_path,
            fns_with_diff_infos,
            template_path,
            used_dispatch_keys,
        )

        gen_inplace_or_view_type(
            out,
            native_functions_path,
            tags_path,
            fns_with_diff_infos,
            template_path,
        )

        # operator filter not applied as tracing sources are excluded in selective build
        gen_trace_type(out, native_funcs, template_path)

    # Generate Functions.h/cpp
    gen_autograd_functions_lib(out, differentiability_infos, template_path)

    # Generate variable_factories.h
    gen_variable_factories(out, native_functions_path, tags_path, template_path)
def should_generate_py_binding(f: NativeFunction) -> bool:
    """Return False when ``f`` is excluded from Python binding generation.

    Excluded are functions tagged ``"generated"``, functions whose C++ name
    matches a regex in SKIP_PYTHON_BINDINGS, and functions whose schema
    string equals an entry of SKIP_PYTHON_BINDINGS_SIGNATURES.
    """
    # So far, all NativeFunctions that are entirely code-generated do not get python bindings.
    if "generated" in f.tags:
        return False

    name = cpp.name(f.func)
    if any(skip_regex.match(name) for skip_regex in SKIP_PYTHON_BINDINGS):
        return False

    signature = str(f.func)
    return not any(
        pattern == signature for pattern in SKIP_PYTHON_BINDINGS_SIGNATURES
    )
def method_definition(f: NativeFunction) -> str:
    """Render the METHOD_DEFINITION template for the tracing kernel of ``f``.

    Asserts that the C++ name of ``f`` is not listed in MANUAL_TRACER.
    """
    assert cpp.name(f.func) not in MANUAL_TRACER

    # code-generated tracing kernels plumb and recompute dispatch keys directly through the kernel for performance.
    # See Note [Plumbing Keys Through The Dispatcher] for details.
    formal_parts = ["c10::DispatchKeySet ks"]
    for a in f.func.schema_order_arguments():
        arg_cpp_type = cpp.argument_type(a, binds="__placeholder__").cpp_type()
        formal_parts.append(f"{arg_cpp_type} {a.name}")
    formals = ", ".join(formal_parts)

    return METHOD_DEFINITION.substitute(
        return_type=cpp.returns_type(f.func.returns).cpp_type(),
        type_wrapper_name=type_wrapper_name(f),
        formals=formals,
        type_definition_body=emit_trace_body(f),
    )
def process_function(f: NativeFunction) -> Optional[str]:
    """Generate the inline C++ factory wrapper(s) for ``f``.

    Returns None unless ``f`` has the ``function`` variant and either takes
    tensor options or has a C++ name ending in ``_like``. Otherwise one
    wrapper is emitted for the regular C++ signature, plus one for the
    symint signature when the signature group has one.
    """
    name = cpp.name(f.func)
    has_tensor_options = python.has_tensor_options(f)
    is_factory = has_tensor_options or name.endswith("_like")

    if not (Variant.function in f.variants and is_factory):
        return None

    cpp_sigs = CppSignatureGroup.from_native_function(f, method=False)
    sigs = [cpp_sigs.signature]
    if cpp_sigs.symint_signature is not None:
        sigs.append(cpp_sigs.symint_signature)

    wrappers: List[str] = []
    for sig in sigs:
        formals: List[str] = []
        exprs: List[str] = []
        requires_grad = "false"

        for arg in sig.arguments():
            qualified_type = fully_qualified_type(arg.type)
            formal = f"{qualified_type} {arg.name}"
            if arg.default:
                formal += f" = {arg.default}"
            formals.append(formal)

            if not isinstance(arg.argument, TensorOptionsArguments):
                exprs.append(arg.name)
                continue

            # note: we remove the requires_grad setting from the TensorOptions because
            # it is ignored anyways (and we actually have an assertion that it isn't set
            # which would fail otherwise). We handle requires_grad explicitly here
            # instead of passing it through to the kernel.
            exprs.append(
                f"at::TensorOptions({arg.name}).requires_grad(c10::nullopt)"
            )
            # Manually set the requires_grad bit on the result tensor.
            requires_grad = f"{arg.name}.requires_grad()"

        wrappers.append(
            f"""\
inline at::Tensor {sig.name()}({', '.join(formals)}) {{
  at::AutoDispatchBelowADInplaceOrView guard;
  return autograd::make_variable(at::{sig.name()}({', '.join(exprs)}), /*requires_grad=*/{requires_grad});
}}
"""
        )

    return "".join(wrappers)
def format_trace_op_name(f: NativeFunction) -> str:
    """Render the op-name snippet recorded by the tracer for ``f``.

    Functional schemas, out= schemas and dunder methods yield a single
    OP_NAME. Any other schema yields a SELECT between the out-of-place name
    and the in-place name, conditioned on ``tracer_state->force_outplace``.
    Every name is passed through RENAME_TRACE first.
    """

    def _op_name(raw_name: str) -> str:
        # Apply the rename table; names without an entry are kept unchanged.
        return OP_NAME.substitute(trace_name=RENAME_TRACE.get(raw_name, raw_name))

    schema = f.func

    # TODO: byte-for-byte compatible with old codegen behavior - should clean up
    single_name = (
        schema.kind() in (SchemaKind.functional, SchemaKind.out)
        or schema.name.name.dunder_method
    )
    if single_name:
        # special case for *_out functions: the in-place and out-of-place ops
        # are overloaded with the same name in the JIT
        return _op_name(str(schema.name.name))

    # otherwise, this is an in-place op and we need to emit both in- and
    # out-of-place versions
    outplace_trace_name = schema.name.name.base
    inplace_trace_name = cpp.name(schema)

    return SELECT.substitute(
        cond="tracer_state->force_outplace",
        true=_op_name(outplace_trace_name),
        false=_op_name(inplace_trace_name),
    )
def generate_out_variant_call(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
    """Build the C++ call expression for the out= variant of ``g``.

    Structured groups are called as ``at::cpu::<kernel>(out..., args...)``
    with the out arguments first. Non-structured groups are called as
    ``at::native::<kernel>(args..., out)`` and must have exactly one out
    argument. The kernel name is resolved through ``get_out_kernel_name``.

    Asserts that the schema of ``g.out`` is an out= function and, for
    non-structured groups, that there is exactly one out argument.
    """
    schema = g.out.func
    assert schema.is_out_fn()

    kernel_name = get_out_kernel_name(g, backend_index)

    # structured op starts with the output tensor argument.
    arg_names = (
        [out_arg.name for out_arg in schema.arguments.out] if g.structured else []
    )

    for arg in schema.arguments.non_out:
        if isinstance(arg, SelfArgument):
            arg_names.append(arg.argument.name)
        else:
            assert isinstance(arg, Argument)
            arg_names.append(arg.name)

    if not g.structured:
        # Non-structured kernels take their single out argument last.
        assert len(schema.arguments.out) == 1
        arg_names.append(schema.arguments.out[0].name)

    cpp_arg_names = ",".join(arg_names)
    namespace_name = "cpu" if g.structured else "native"
    return f"at::{namespace_name}::{kernel_name}({cpp_arg_names})"
def name(func: FunctionSchema) -> str:
    """Return the name of ``func`` as computed by ``cpp.name``."""
    cpp_name = cpp.name(func)
    return cpp_name
def gen_namedtuple_typename_key(f: NativeFunction) -> str:
    """Return the C++ name of ``f`` and its named return fields, joined by underscores."""
    key_parts = [cpp.name(f.func)]
    key_parts += namedtuple_fieldnames(f.func.returns)
    return "_".join(key_parts)
def generate_function(
    f: NativeFunction, k: SchemaKind
) -> Tuple[NativeFunction, Dict[DispatchKey, Dict["OperatorName", "BackendMetadata"]]]:
    """Create a generated NativeFunction of kind ``k`` derived from ``f``.

    Only ``SchemaKind.functional`` and ``SchemaKind.out`` are supported.
    Returns the new NativeFunction together with its backend metadata, which
    holds a CompositeExplicitAutograd entry for functional variants and is
    empty for out= variants.

    Raises:
        AssertionError: if ``k`` is any other kind, or if an out= variant is
            requested for a function that is neither inplace nor mutable.
    """
    from torchgen.api import cpp

    if k == SchemaKind.functional:
        assert f.func.kind() != SchemaKind.functional
        gets_composite_kernel = True
        # The new "functional" NativeFunction has:
        # - any mutable arguments have been converted into (immutable) returns.
        #   (if a mutable argument was not also a return, it gets converted to one)
        # - a "functional" overload name.
        # The default grouping logic in signature() actually already does this,
        # so we can piggy-back off it (but we still want return names)
        original_overload = f.func.name.overload_name
        if original_overload:
            functional_overload = f"{original_overload}_functional"
        else:
            functional_overload = "functional"
        func = f.func.signature(keep_return_names=True).with_name(
            f.func.name.remove_inplace().with_overload(functional_overload)
        )
    elif k == SchemaKind.out:
        # We generate out= ops mostly just so that we can pair up NativeFunctions into groups easily,
        # but at least today, there is no good reason to actually use them.
        # we'll generate a dispatcher entry for them, but won't actually register any kernels for them.
        gets_composite_kernel = False
        source_kind = f.func.kind()
        if source_kind == SchemaKind.inplace:
            func = self_to_out_signature(f.func)
        elif source_kind == SchemaKind.mutable:
            func = mutable_to_out_signature(f.func)
        else:
            raise AssertionError(
                "We only bother generating out= functions from either inplace or mutable variants"
            )
    else:
        raise AssertionError(
            "We currently only generate either functional or out= NativeFunctions"
        )

    backend_metadata = {}
    if gets_composite_kernel:
        backend_metadata = {
            DispatchKey.CompositeExplicitAutograd: {
                func.name: BackendMetadata(cpp.name(func), structured=False)
            }
        }

    generated = NativeFunction(
        func=func,
        use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors,
        # These generated fn's aren't meant to be user friendly- don't generate methods.
        variants={Variant.function},
        structured=False,
        structured_delegate=None,
        structured_inherits=None,
        precomputed=None,
        autogen=[],
        ufunc_inner_loop={},
        manual_kernel_registration=False,
        manual_cpp_binding=False,
        python_module=None,
        category_override=None,
        device_guard=False,
        device_check=DeviceCheckType.NoCheck,
        loc=f.loc,
        cpp_no_default_args=set(),
        is_abstract=f.is_abstract,
        has_composite_implicit_autograd_kernel=False,
        has_composite_explicit_autograd_kernel=gets_composite_kernel,
        # Every generated NativeFunction gets a "generated" tag, so it's easy to tell
        # which NativeFunction objects did not come directly from native_functions.yaml.
        tags={"generated"},
    )
    return generated, backend_metadata
def type_wrapper_name(f: NativeFunction) -> str:
    """Return the C++ name of ``f``, suffixed with ``_<overload name>`` when the overload name is non-empty."""
    overload_name = f.func.name.overload_name
    if not overload_name:
        return cpp.name(f.func)
    return f"{cpp.name(f.func)}_{overload_name}"
def generate_function(
    f: NativeFunction, k: SchemaKind
) -> Tuple[NativeFunction, Dict[DispatchKey, Dict["OperatorName", "BackendMetadata"]]]:
    """Create a generated NativeFunction of kind ``k`` derived from ``f``.

    Only ``SchemaKind.functional`` and ``SchemaKind.out`` are supported.
    Returns the new NativeFunction together with backend metadata holding a
    single CompositeExplicitAutograd kernel entry for it.

    Raises:
        AssertionError: if ``k`` is any other kind, or if an out= variant is
            requested for a function that is not inplace, mutable or
            functional.
    """
    from torchgen.api import cpp

    if k == SchemaKind.functional:
        assert f.func.kind() != SchemaKind.functional
        # The new "functional" NativeFunction has:
        # - any mutable arguments have been converted into (immutable) returns.
        #   (if a mutable argument was not also a return, it gets converted to one)
        # - "_functional" appended to the base name, ONLY IF this op has a mutable variant.
        #   See Note [Overload Ambiguity With Functional Variants]
        # The default grouping logic in signature() actually already does this,
        # so we can piggy-back off it (but we still want return names)
        functional_schema = f.func.signature(keep_return_names=True)
        functional_base_name = BaseOperatorName(
            base=f.func.name.name.base,
            inplace=False,
            dunder_method=f.func.name.name.dunder_method,
            # See Note [Overload Ambiguity With Functional Variants]
            functional_overload=f.func.kind() == SchemaKind.mutable,
        )
        func = functional_schema.with_name(
            OperatorName(
                name=functional_base_name,
                overload_name=f.func.name.overload_name,
            )
        )
    elif k == SchemaKind.out:
        # We generate out= ops mostly just so that we can pair up NativeFunctions into groups easily,
        # but at least today, there is no good reason to actually use them.
        # we'll generate a dispatcher entry for them, but won't actually register any kernels for them.
        source_kind = f.func.kind()
        if source_kind == SchemaKind.inplace:
            func = self_to_out_signature(f.func)
        elif source_kind == SchemaKind.mutable:
            func = mutable_to_out_signature(f.func)
        elif source_kind == SchemaKind.functional:
            func = functional_to_out_signature(f.func)
        else:
            raise AssertionError(
                "We only bother generating out= functions from either inplace or mutable or functional variants"
            )
    else:
        raise AssertionError(
            "We currently only generate either functional or out= NativeFunctions"
        )

    # Generated kernel naming convention for out: <op_name>_<overload_name>. The reason for this is to
    # disambiguate operator with the same name but different overload name, e.g., `randn.names_out` and
    # `randn.generator_with_names_out`.
    if func.kind() == SchemaKind.out:
        kernel_name = func.name.unambiguous_name()
    else:
        kernel_name = cpp.name(func)

    kernel_metadata = BackendMetadata(
        kernel=kernel_name,
        structured=False,
        cpp_namespace=DEFAULT_KERNEL_NAMESPACE,
    )
    backend_metadata = {
        DispatchKey.CompositeExplicitAutograd: {func.name: kernel_metadata}
    }

    generated = NativeFunction(
        func=func,
        use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors,
        # These generated fn's aren't meant to be user friendly- don't generate methods.
        variants={Variant.function},
        structured=False,
        structured_delegate=None,
        structured_inherits=None,
        precomputed=None,
        autogen=[],
        ufunc_inner_loop={},
        manual_kernel_registration=False,
        manual_cpp_binding=False,
        python_module=None,
        category_override=None,
        device_guard=False,
        device_check=DeviceCheckType.NoCheck,
        loc=f.loc,
        cpp_no_default_args=set(),
        is_abstract=f.is_abstract,
        has_composite_implicit_autograd_kernel=False,
        has_composite_explicit_autograd_kernel=True,
        has_composite_explicit_autograd_non_functional_kernel=False,
        # Every generated NativeFunction gets a "generated" tag, so it's easy to tell
        # which NativeFunction objects did not come directly from native_functions.yaml.
        tags={"generated"} | (f.tags & {"nondeterministic_seeded"}),
        namespace=f.namespace,
    )
    return generated, backend_metadata
def create_differentiability_info(
    defn: Dict[Any, Any],
    functions_by_signature: Dict[FunctionSchema, List[NativeFunction]],
    functions_by_schema: Dict[str, NativeFunction],
    op_counter: Counter[str],
) -> DifferentiabilityInfo:
    """Processes a single entry `defn` in derivatives.yaml

    NB: the ``"name"`` and ``"output_differentiability"`` keys are popped
    from ``defn``, and ``op_counter`` is incremented for the op prefix when
    an op name gets assigned.

    Raises:
        RuntimeError: if the schema or its legacy signature cannot be
            resolved, if an argument is named ``grad_input_mask`` or
            ``result``, or if the derivative formulas are inconsistent.
    """

    def canonical_function(functions: Sequence[NativeFunction], name: str) -> NativeFunction:
        for candidate in functions:
            schema = candidate.func
            if schema.is_functional_fn() or schema.is_out_fn():
                continue
            if name == str(schema.name.name):
                return candidate
        # some functions only have in-place variants
        assert name + "_" == cpp.name(functions[0].func)
        return functions[0]

    def split_names(raw_names: str) -> Tuple[str, ...]:
        """Given "foo, bar", return ["foo", "bar"]."""
        return tuple(piece.strip() for piece in raw_names.split(","))

    def check_grad_usage(defn_name: str, derivatives: Sequence[Derivative]) -> None:
        """
        Check for some subtle mistakes one might make when writing derivatives.
        These mistakes will compile, but will be latent until a function is
        used with double backwards.
        """
        uses_grad = False  # true if any derivative uses "grad"
        num_grads_uses = 0  # count of uses of "grads" or "grads[INDEX]"
        uses_named_grads = False  # true if any derivative uses "grad_{name}"
        used_grads_indices: List[int] = []  # which indices of grads are used

        for d in derivatives:
            formula = d.formula
            if not uses_grad and re.findall(IDENT_REGEX.format("grad"), formula):
                uses_grad = True
            num_grads_uses += len(re.findall(IDENT_REGEX.format("grads"), formula))
            if not uses_named_grads and d.named_gradients:
                uses_named_grads = True
            used_grads_indices.extend(used_gradient_indices(formula))

        # This is a basic sanity check: the number of places we see
        # "grads" should be no fewer than the number of indices we see
        # inside "grads". They may not be equal because we may use
        # "grads" without an index.
        assert num_grads_uses >= len(used_grads_indices)
        # Thus if the number is equal, every use of grads is also
        # indexed.
        only_used_grads_indices = num_grads_uses == len(used_grads_indices)

        if uses_grad and num_grads_uses > 0:
            raise RuntimeError(
                f"Derivative definition of {defn_name} in derivatives.yaml illegally "
                "mixes use of 'grad' and 'grads'. Consider replacing "
                "occurrences of 'grad' with 'grads[0]'"
            )

        if only_used_grads_indices and set(used_grads_indices) == {0}:
            raise RuntimeError(
                f"Derivative definition of {defn_name} in derivatives.yaml solely "
                "refers to 'grads[0]'. If the first output is indeed the "
                "only differentiable output, replace 'grads[0]' with 'grad'; "
                "otherwise, there is a likely error in your derivatives "
                "declaration."
            )

        if uses_named_grads and (uses_grad or num_grads_uses > 0):
            raise RuntimeError(
                f"Derivative definition of {defn_name} in derivatives.yaml illegally "
                'mixes use of "grad_RETURN_NAME" and "grad" or "grads[x]". Use '
                "only one method for identifying gradients."
            )

    @with_native_function
    def set_up_derivatives(
        f: NativeFunction,
    ) -> Tuple[
        Sequence[Derivative],
        Sequence[ForwardDerivative],
        Sequence[Binding],
        Sequence[str],
        Sequence[str],
    ]:
        # Set up the derivative information
        derivatives: List[Derivative] = []
        forward_derivatives: List[ForwardDerivative] = []
        non_differentiable_arg_names: List[str] = []
        args_with_derivatives_set: Set[str] = set()

        all_arg_names = [a.name for a in cpp_arguments(f)]
        # only used for the assert below
        all_ret_names = [r.name for r in f.func.returns]

        # output_differentiability is captured from the enclosed
        # scope. Don't modify it.
        #
        # If it is not present, then no output is explicitly
        # undifferentiable.
        #
        # It may be present and shorter than the length of return
        # values. If that's the case, any return value that does not
        # have a corresponding entry is considered not differentiable.
        differentiability = output_differentiability or [True] * len(f.func.returns)

        # A return is available as a named gradient ...
        available_named_gradients = [
            f"grad_{ret.name}"
            for ret, differentiable in zip(f.func.returns, differentiability)
            # if it has not been explicitly made undifferentiable
            if differentiable
            # and if it has a name
            and ret.name is not None
            # and if its type is differentiable
            and ret.type.is_tensor_like()
        ]

        for raw_names in sorted(defn.keys()):
            formula = defn[raw_names]
            names = split_names(raw_names)

            for name in names:
                assert not (name in all_arg_names and name in all_ret_names), (
                    f"While processing the derivative formula for '{f.func.name}' wrt '{name}', "
                    f"expected '{name}' to not be both an input arg and named return. "
                )

            if is_forward_derivative_definition(all_arg_names, names):
                forward_derivatives.append(create_forward_derivative(f, formula, names))
            elif formula.lower().strip() == "non_differentiable":
                non_differentiable_arg_names += names
            else:
                derivatives.append(
                    create_derivative(f, formula, names, available_named_gradients)
                )
                args_with_derivatives_set |= set(names)

        overlap = args_with_derivatives_set.intersection(non_differentiable_arg_names)
        if overlap:
            raise RuntimeError(
                f"derivatives definition for {defn} have overlapped non_differentiable "
                f"and differentiable variables: {overlap}"
            )

        # Next, let us determine the list of inputs in order.
        # TODO: do we need eagerly calculate and save it here? Can it be derived
        # from NativeFunction and `derivatives` on callsites instead?
        args_with_derivatives = [
            a for a in cpp_arguments(f) if a.name in args_with_derivatives_set
        ]

        # Postprocess forward derivatives definitions now that we know the differentiable arguments
        forward_derivatives = postprocess_forward_derivatives(
            f,
            defn_name,
            all_arg_names,
            derivatives,
            forward_derivatives,
            args_with_derivatives,
        )

        # Test to see if the use of 'grads' makes sense.
        check_grad_usage(defn_name, derivatives)

        return (
            derivatives,
            forward_derivatives,
            args_with_derivatives,
            non_differentiable_arg_names,
            available_named_gradients,
        )

    # NB: Removes 'name' from defn dictionary
    specification = defn.pop("name")
    defn_name, _ = split_name_params(specification)

    # NB: Removes 'output_differentiability' from defn dictionary
    #     `None` means all differentiable.
    output_differentiability = defn.pop("output_differentiability", None)
    output_differentiability_conditions = None
    if output_differentiability and any(
        isinstance(diff, str) for diff in output_differentiability
    ):
        if len(output_differentiability) != 1:
            raise RuntimeError(
                f"Not supported: for {specification},"
                f"output_differentiability must either be "
                f"List[bool] or a List[str] where each str is a "
                f"condition. In the case where it is a condition, "
                f"we only support single-output functions. "
                f"Please file us an issue. "
            )
        output_differentiability_conditions = output_differentiability
        output_differentiability = [True]

    schema_function = functions_by_schema.get(specification)
    if not schema_function:
        avail = "\n".join(
            k for k, v in functions_by_schema.items() if cpp.name(v.func) == defn_name
        )
        raise RuntimeError(
            f"could not find ATen function for schema: {specification} "
            f".  Available signatures:\n{avail}"
        )

    # now map this to the legacy schema; this isn't technically necessary, but we'd need some logic here
    # to map in-place schemas to the out-of-place variants.
    # TODO: maybe the logic to handle the legacy schema is no longer necessary?
    signature = schema_function.func.signature()
    functions = functions_by_signature[signature]
    if len(functions) == 0:
        avail = "\n".join(
            str(k) for k, v in functions_by_signature.items() if cpp.name(k) == defn_name
        )
        raise RuntimeError(
            f"could not find ATen function for legacy signature: {signature} "
            f"corresponding to schema {specification}.  Please report a bug to PyTorch. "
            f"Available signatures:\n{avail}"
        )

    canonical = canonical_function(functions, defn_name)

    if "grad_input_mask" in (a.name for a in cpp_arguments(canonical)):
        raise RuntimeError(
            f"Schema for {defn_name} has an argument named grad_input_mask, "
            "but this name would be shadowed by our codegen. "
            "Please use a different name in native_functions.yaml."
        )

    if "result" in (a.name for a in cpp_arguments(canonical)):
        raise RuntimeError(
            f"Schema for {defn_name} has an argument named result, "
            "but this is only allowed for outputs."
            "Please use a different name in native_functions.yaml."
        )

    (
        derivatives,
        forward_derivatives,
        args_with_derivatives,
        non_differentiable_arg_names,
        available_named_gradients,
    ) = set_up_derivatives(canonical)

    used_named_gradients: Set[str] = set()
    for d in derivatives:
        used_named_gradients |= d.named_gradients

    # only assign an op name if we are actually going to calculate a derivative
    op = None
    if args_with_derivatives:
        op_prefix = _create_op_prefix(defn_name)
        op = f"{op_prefix}{op_counter[op_prefix]}"
        op_counter[op_prefix] += 1

    all_saved_inputs = dedup_vars([v for d in derivatives for v in d.saved_inputs])
    all_saved_outputs = dedup_vars([v for d in derivatives for v in d.saved_outputs])

    return DifferentiabilityInfo(
        name=defn_name,
        func=canonical,
        op=op,
        derivatives=derivatives,
        forward_derivatives=forward_derivatives,
        all_saved_inputs=all_saved_inputs,
        all_saved_outputs=all_saved_outputs,
        available_named_gradients=available_named_gradients,
        used_named_gradients=used_named_gradients,
        args_with_derivatives=args_with_derivatives,
        non_differentiable_arg_names=non_differentiable_arg_names,
        output_differentiability=output_differentiability,
        output_differentiability_conditions=output_differentiability_conditions,
    )
def use_derived(fn: NativeFunctionWithDifferentiabilityInfo) -> bool:
    """Return True when ``fn`` is not in MANUAL_AUTOGRAD and its dispatch strategy is ``"use_derived"``."""
    native_fn = fn.func
    if cpp.name(native_fn.func) in MANUAL_AUTOGRAD:
        return False
    return dispatch_strategy(fn) == "use_derived"
def wrapper_name(func: FunctionSchema) -> str:
    """Return the C++ name of ``func``, suffixed with ``_<overload name>`` when the overload name is non-empty."""
    overload_name = func.name.overload_name
    if not overload_name:
        return cpp.name(func)
    return f"{cpp.name(func)}_{overload_name}"
def get_out_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
    """Resolve the kernel name used to call the out= variant of ``g``.

    A non-structured group with a kernel registered in ``backend_index``
    uses that kernel's name; every other case uses the C++ name of the out=
    schema.
    """
    kernel = backend_index.get_kernel(g.out)
    if not g.structured and kernel is not None:
        return kernel.kernel
    return cpp.name(g.out.func)