def create_differentiability_info(
    defn: Dict[Any, Any],
    functions_by_signature: Dict[FunctionSchema, List[NativeFunction]],
    functions_by_schema: Dict[str, NativeFunction],
    op_counter: Counter[str],
) -> DifferentiabilityInfo:
    """Processes a single entry `defn` in derivatives.yaml"""

    def canonical_function(
        functions: Sequence[NativeFunction], name: str
    ) -> NativeFunction:
        for f in functions:
            if (
                not f.func.is_functional_fn()
                and not f.func.is_out_fn()
                and name == str(f.func.name.name)
            ):
                return f
        # some functions only have in-place variants
        assert name + "_" == cpp.name(functions[0].func)
        return functions[0]

    def split_names(raw_names: str) -> Tuple[str, ...]:
        """Given "foo, bar", return ["foo", "bar"]."""
        return tuple(x.strip() for x in raw_names.split(","))

    def check_grad_usage(defn_name: str, derivatives: Sequence[Derivative]) -> None:
        """
        Check for some subtle mistakes one might make when writing derivatives.
        These mistakes will compile, but will be latent until a function is
        used with double backwards.
        """
        uses_grad = False  # true if any derivative uses "grad"
        num_grads_uses = 0  # count of uses of "grads" or "grads[INDEX]"
        uses_named_grads = False  # true if any derivative uses "grad_{name}"
        used_grads_indices: List[int] = []  # which indices of grads are used
        for d in derivatives:
            formula = d.formula
            uses_grad = uses_grad or bool(
                re.findall(IDENT_REGEX.format("grad"), formula)
            )
            num_grads_uses += len(re.findall(IDENT_REGEX.format("grads"), formula))
            uses_named_grads = uses_named_grads or bool(d.named_gradients)
            used_grads_indices.extend(used_gradient_indices(formula))
        # This is a basic sanity check: the number of places we see
        # "grads" should be no fewer than the number of indices we see
        # inside "grads". They may not be equal because we may use
        # "grads" without an index.
        assert num_grads_uses >= len(used_grads_indices)
        # Thus if the number is equal, every use of grads is also indexed.
        only_used_grads_indices = num_grads_uses == len(used_grads_indices)

        if uses_grad and num_grads_uses > 0:
            raise RuntimeError(
                f"Derivative definition of {defn_name} in derivatives.yaml illegally "
                "mixes use of 'grad' and 'grads'. Consider replacing "
                "occurrences of 'grad' with 'grads[0]'"
            )

        if only_used_grads_indices and set(used_grads_indices) == {0}:
            raise RuntimeError(
                f"Derivative definition of {defn_name} in derivatives.yaml solely "
                "refers to 'grads[0]'. If the first output is indeed the "
                "only differentiable output, replace 'grads[0]' with 'grad'; "
                "otherwise, there is a likely error in your derivatives "
                "declaration."
            )

        if uses_named_grads and (uses_grad or num_grads_uses > 0):
            raise RuntimeError(
                f"Derivative definition of {defn_name} in derivatives.yaml illegally "
                'mixes use of "grad_RETURN_NAME" and "grad" or "grads[x]". Use '
                "only one method for identifying gradients."
            )

    @with_native_function
    def set_up_derivatives(
        f: NativeFunction,
    ) -> Tuple[
        Sequence[Derivative],
        Sequence[ForwardDerivative],
        Sequence[Binding],
        Sequence[str],
        Sequence[str],
    ]:
        # Set up the derivative information
        derivatives: List[Derivative] = []
        forward_derivatives: List[ForwardDerivative] = []
        non_differentiable_arg_names: List[str] = []
        args_with_derivatives_set: Set[str] = set()

        all_arg_names = [a.name for a in cpp_arguments(f)]
        all_ret_names = [
            r.name for r in f.func.returns
        ]  # only used for the assert below

        # output_differentiability is captured from the enclosed
        # scope. Don't modify it.
        #
        # If it is not present, then no output is explicitly
        # undifferentiable.
        #
        # It may be present and shorter than the length of return
        # values. If that's the case, any return value that does not
        # have a corresponding entry is considered not differentiable.
        differentiability = output_differentiability or [True] * len(f.func.returns)

        # A return is available as a named gradient ...
        available_named_gradients = [
            f"grad_{ret.name}"
            for ret, differentiable in zip(f.func.returns, differentiability)
            # if it has not been explicitly made undifferentiable
            if differentiable
            # and if it has a name
            and ret.name is not None
            # and if its type is differentiable
            and ret.type.is_tensor_like()
        ]

        for raw_names in sorted(defn.keys()):
            formula = defn[raw_names]
            names = split_names(raw_names)

            for name in names:
                assert not (name in all_arg_names and name in all_ret_names), (
                    f"While processing the derivative formula for '{f.func.name}' wrt '{name}', "
                    f"expected '{name}' to not be both an input arg and named return. "
                )

            if is_forward_derivative_definition(all_arg_names, names):
                forward_derivatives.append(create_forward_derivative(f, formula, names))
            else:
                if formula.lower().strip() == "non_differentiable":
                    non_differentiable_arg_names += names
                else:
                    derivative = create_derivative(
                        f, formula, names, available_named_gradients
                    )
                    derivatives.append(derivative)
                    args_with_derivatives_set |= set(names)

        overlap = args_with_derivatives_set.intersection(non_differentiable_arg_names)
        if overlap:
            raise RuntimeError(
                f"derivatives definition for {defn} has overlapping non_differentiable "
                f"and differentiable variables: {overlap}"
            )

        # Next, let us determine the list of inputs in order.
        # TODO: do we need to eagerly calculate and save it here? Can it be
        # derived from NativeFunction and `derivatives` at call sites instead?
        args_with_derivatives = [
            a for a in cpp_arguments(f) if a.name in args_with_derivatives_set
        ]

        # Postprocess forward derivative definitions now that we know the
        # differentiable arguments.
        forward_derivatives = postprocess_forward_derivatives(
            f,
            defn_name,
            all_arg_names,
            derivatives,
            forward_derivatives,
            args_with_derivatives,
        )

        # Test to see if the use of 'grads' makes sense.
        check_grad_usage(defn_name, derivatives)

        return (
            derivatives,
            forward_derivatives,
            args_with_derivatives,
            non_differentiable_arg_names,
            available_named_gradients,
        )

    # NB: Removes 'name' from the defn dictionary.
    specification = defn.pop("name")
    defn_name, _ = split_name_params(specification)

    # NB: Removes 'output_differentiability' from the defn dictionary.
    # `None` means all outputs are differentiable.
    output_differentiability = defn.pop("output_differentiability", None)
    output_differentiability_conditions = None
    if output_differentiability and any(
        isinstance(diff, str) for diff in output_differentiability
    ):
        if len(output_differentiability) != 1:
            raise RuntimeError(
                f"Not supported: for {specification}, "
                f"output_differentiability must either be "
                f"List[bool] or a List[str] where each str is a "
                f"condition. In the case where it is a condition, "
                f"we only support single-output functions. "
                f"Please file us an issue."
            )
        output_differentiability_conditions = output_differentiability
        output_differentiability = [True]

    schema_function = functions_by_schema.get(specification)
    if not schema_function:
        avail = "\n".join(
            k for k, v in functions_by_schema.items() if cpp.name(v.func) == defn_name
        )
        raise RuntimeError(
            f"could not find ATen function for schema: {specification}. "
            f"Available signatures:\n{avail}"
        )

    # Now map this to the legacy schema; this isn't technically necessary, but
    # we'd need some logic here to map in-place schemas to the out-of-place
    # variants.
    # TODO: maybe the logic to handle the legacy schema is no longer necessary?
    signature = schema_function.func.signature()
    functions = functions_by_signature[signature]
    if len(functions) == 0:
        avail = "\n".join(
            str(k)
            for k, v in functions_by_signature.items()
            if cpp.name(k) == defn_name
        )
        raise RuntimeError(
            f"could not find ATen function for legacy signature: {signature} "
            f"corresponding to schema {specification}. Please report a bug to PyTorch. "
            f"Available signatures:\n{avail}"
        )

    canonical = canonical_function(functions, defn_name)
    if "grad_input_mask" in (a.name for a in cpp_arguments(canonical)):
        raise RuntimeError(
            f"Schema for {defn_name} has an argument named grad_input_mask, "
            "but this name would be shadowed by our codegen. "
            "Please use a different name in native_functions.yaml."
        )

    if "result" in (a.name for a in cpp_arguments(canonical)):
        raise RuntimeError(
            f"Schema for {defn_name} has an argument named result, "
            "but this is only allowed for outputs. "
            "Please use a different name in native_functions.yaml."
        )

    (
        derivatives,
        forward_derivatives,
        args_with_derivatives,
        non_differentiable_arg_names,
        available_named_gradients,
    ) = set_up_derivatives(canonical)

    used_named_gradients: Set[str] = set()
    for d in derivatives:
        used_named_gradients |= d.named_gradients

    # Only assign an op name if we are actually going to calculate a derivative.
    op = None
    if args_with_derivatives:
        op_prefix = _create_op_prefix(defn_name)
        op = f"{op_prefix}{op_counter[op_prefix]}"
        op_counter[op_prefix] += 1

    return DifferentiabilityInfo(
        name=defn_name,
        func=canonical,
        op=op,
        derivatives=derivatives,
        forward_derivatives=forward_derivatives,
        all_saved_inputs=dedup_vars([v for d in derivatives for v in d.saved_inputs]),
        all_saved_outputs=dedup_vars([v for d in derivatives for v in d.saved_outputs]),
        available_named_gradients=available_named_gradients,
        used_named_gradients=used_named_gradients,
        args_with_derivatives=args_with_derivatives,
        non_differentiable_arg_names=non_differentiable_arg_names,
        output_differentiability=output_differentiability,
        output_differentiability_conditions=output_differentiability_conditions,
    )
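
# Illustrative only: the shape of a derivatives.yaml entry that
# create_differentiability_info consumes. The formulas here are assumptions
# for the sake of the example, but the structure matches the code above:
# 'name' and 'output_differentiability' are popped from `defn`, and every
# remaining key is an argument name (or comma-separated names) mapped to its
# gradient formula.
#
#   - name: mul.Tensor(Tensor self, Tensor other) -> Tensor
#     self: grad * other
#     other: grad * self
#
# A hypothetical entry with an explicitly non-differentiable second output
# might look like:
#
#   - name: some_op(Tensor self) -> (Tensor, Tensor)
#     output_differentiability: [True, False]
#     self: grad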
def load_deprecated_signatures(
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    deprecated_yaml_path: str,
    *,
    method: bool,
    pyi: bool,
) -> List[PythonSignatureNativeFunctionPair]:
    # The deprecated.yaml doesn't have complete type information; we need to
    # find and leverage the original ATen signature (to which it delegates
    # the call) to generate the full python signature.
    # We join the deprecated and the original signatures using type-only form.

    # native function -> type-only signature
    @with_native_function
    def signature_original(f: NativeFunction) -> str:
        # remove inplace suffix but keep outplace suffix
        opname = str(f.func.name.name.base)
        if f.func.is_out_fn():
            opname += "_out"
        if f.func.name.name.inplace and pyi:
            opname += "_"
        args = CppSignatureGroup.from_native_function(
            f, method=False
        ).signature.arguments()
        # Simply ignore TensorOptionsArguments as it does not exist in deprecated.yaml.
        types = ", ".join(
            argument_type_str(a.argument.type)
            for a in args
            if isinstance(a.argument, Argument)
        )
        return f"{opname}({types})"

    # deprecated -> type-only native signature (according to the call order)
    def signature_deprecated(
        opname: str, params: List[str], call_args: List[str]
    ) -> str:
        # create a mapping of parameter name to parameter type
        types: Dict[str, str] = {}
        for param in params:
            if param == "*":
                continue
            type, name = param.split(" ")
            types[name] = type
        # if the name in the call is not in the parameter list, assume it's
        # a literal Scalar
        rearranged_types = ", ".join(types.get(arg, "Scalar") for arg in call_args)
        return f"{opname}({rearranged_types})"

    # group the original ATen signatures by type-only signature
    grouped: Dict[str, List[PythonSignatureNativeFunctionPair]] = defaultdict(list)
    for pair in pairs:
        grouped[signature_original(pair.function)].append(pair)

    # find matching original signatures for each deprecated signature
    results: List[PythonSignatureNativeFunctionPair] = []

    with open(deprecated_yaml_path, "r") as f:
        deprecated_defs = yaml.load(f, Loader=YamlLoader)

    for deprecated in deprecated_defs:
        _, params = split_name_params(deprecated["name"])
        aten_name, call_args = split_name_params(deprecated["aten"])

        for pair in grouped[signature_deprecated(aten_name, params, call_args)]:
            # It uses the types from the original ATen declaration, but the
            # ordering and parameter names from the deprecated overload. Any
            # default parameter values from the original ATen declaration are
            # ignored.
            # A deprecated signature might reorder input_args and input_kwargs,
            # but never changes output_args nor TensorOptions (if any?),
            # so here we only look into these two types of args.
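            # A hypothetical reordering case, loosely modeled on the addmm
            # entry in deprecated.yaml (treat the exact parameter lists as
            # assumptions):
            #   deprecated: addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2)
            #   aten:       addmm(self, mat1, mat2, beta)
            # The loop below walks the deprecated parameter order and pulls
            # each argument's type information out of src_args.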
            python_sig = pair.signature
            src_args: Dict[str, PythonArgument] = {
                a.name: PythonArgument(
                    name=a.name,
                    type=a.type,
                    default=None,
                    default_init=None,
                )
                for a in itertools.chain(python_sig.input_args, python_sig.input_kwargs)
            }

            args: List[str] = []
            input_args: List[PythonArgument] = []
            input_kwargs: List[PythonArgument] = []

            kwarg_only = False
            for param in params:
                if param == "*":
                    kwarg_only = True
                    continue
                _, param_name = param.split(" ")
                args.append(param_name)
                if param_name not in src_args:
                    # output argument
                    continue
                if not kwarg_only:
                    if not method or param_name != "self":
                        input_args.append(src_args[param_name])
                else:
                    input_kwargs.append(src_args[param_name])

            results.append(
                PythonSignatureNativeFunctionPair(
                    signature=PythonSignatureDeprecated(
                        name=python_sig.name,
                        input_args=tuple(input_args),
                        input_kwargs=tuple(input_kwargs),
                        output_args=python_sig.output_args,
                        tensor_options_args=python_sig.tensor_options_args,
                        method=python_sig.method,
                        deprecated_args_names=tuple(args),
                        deprecated_args_exprs=tuple(call_args),
                        returns=python_sig.returns,
                    ),
                    function=pair.function,
                )
            )

    return results
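
# Sketch of the type-only join performed above; the entry is modeled on
# deprecated.yaml, but treat the exact strings as assumptions:
#
#   - name: add(Tensor self, Scalar alpha, Tensor other)
#     aten: add(self, other, alpha)
#
# For this entry, signature_deprecated("add", ["Tensor self", "Scalar alpha",
# "Tensor other"], ["self", "other", "alpha"]) yields
# "add(Tensor, Tensor, Scalar)", which is then looked up among the
# signature_original keys computed from the ATen pairs.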
def load_deprecated_signatures(
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    deprecated_yaml_path: str,
    *,
    method: bool,
    pyi: bool,
) -> List[PythonSignatureNativeFunctionPair]:
    # The deprecated.yaml doesn't have complete type information; we need to
    # find and leverage the original ATen signature (to which it delegates
    # the call) to generate the full python signature.
    # We join the deprecated and the original signatures using type-only form.

    # group the original ATen signatures by name
    grouped: Dict[str, List[PythonSignatureNativeFunctionPair]] = defaultdict(list)
    for pair in pairs:
        grouped[pair.signature.name].append(pair)

    # find matching original signatures for each deprecated signature
    results: List[PythonSignatureNativeFunctionPair] = []

    with open(deprecated_yaml_path, "r") as f:
        deprecated_defs = yaml.load(f, Loader=YamlLoader)

    for deprecated in deprecated_defs:
        schema = FunctionSchema.parse(deprecated["name"])
        aten_name, call_args = split_name_params(deprecated["aten"])
        is_out = aten_name.endswith("_out")
        if is_out:
            aten_name = aten_name.replace("_out", "")

        # HACK: these are fixed constants used to pass to the aten function.
        # The type must be known ahead of time.
        known_constants = {
            "1": Type.parse("Scalar"),
        }
        schema_args_by_name = {a.name: a for a in schema.arguments.flat_all}
        for name in call_args:
            assert (
                name in schema_args_by_name or name in known_constants
            ), f"deprecation definition: Unrecognized value {name}"

        # Map deprecated signature arguments to their aten signature and test
        # if the types and alias annotations match.
        def is_schema_compatible(aten_schema: FunctionSchema) -> bool:
            arguments: Iterable[Argument]
            if is_out:
                arguments = itertools.chain(
                    aten_schema.arguments.out, aten_schema.arguments.flat_non_out
                )
            else:
                arguments = aten_schema.arguments.flat_all

            for i, arg in enumerate(arguments):
                if i < len(call_args):
                    arg_name = call_args[i]
                    if arg_name in known_constants:
                        schema_type = known_constants[arg_name]
                        schema_annotation = None
                    else:
                        schema_arg = schema_args_by_name[arg_name]
                        schema_type = schema_arg.type
                        schema_annotation = schema_arg.annotation
                    if schema_type != arg.type or schema_annotation != arg.annotation:
                        return False
                else:
                    if arg.default is None:
                        return False

            return len(schema.returns) == len(aten_schema.returns) and all(
                a == b for a, b in zip(schema.returns, aten_schema.returns)
            )

        any_schema_found = False
        for pair in grouped[aten_name]:
            if not is_schema_compatible(pair.function.func):
                continue
            any_schema_found = True

            python_sig = signature_from_schema(
                schema,
                category_override=pair.function.category_override,
                method=method,
                pyi=pyi,
            )

            results.append(
                PythonSignatureNativeFunctionPair(
                    signature=PythonSignatureDeprecated(
                        name=python_sig.name,
                        input_args=python_sig.input_args,
                        input_kwargs=python_sig.input_kwargs,
                        output_args=python_sig.output_args,
                        tensor_options_args=python_sig.tensor_options_args,
                        method=python_sig.method,
                        deprecated_schema=schema,
                        deprecated_args_exprs=tuple(call_args),
                        returns=python_sig.returns,
                    ),
                    function=pair.function,
                )
            )

        assert (
            any_schema_found
        ), f"No native function with name {aten_name} matched signature:\n{str(schema)}"

    return results
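
# Illustrative only: for out= overloads the deprecated entry's `aten:` target
# carries an `_out` suffix, and the out argument leads the call order, which
# is why is_schema_compatible chains arguments.out before flat_non_out.
# A plausible entry (assumed, not quoted verbatim from deprecated.yaml):
#
#   - name: add(Tensor self, Scalar alpha, Tensor other, *, Tensor(a!) out)
#     aten: add_out(out, self, other, alpha)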