def load_derivatives(derivatives_yaml_path: str, native_yaml_path: str) -> Sequence[DifferentiabilityInfo]:
    # Do some caching, as this is a deterministic function.
    global _GLOBAL_LOAD_DERIVATIVE_CACHE
    key = (derivatives_yaml_path, native_yaml_path)
    if key not in _GLOBAL_LOAD_DERIVATIVE_CACHE:

        with open(derivatives_yaml_path, 'r') as f:
            definitions = yaml.load(f, Loader=YamlLoader)

        functions = parse_native_yaml(native_yaml_path).native_functions

        # What's the difference between a function schema and a signature?
        # A function schema is the complete declaration, including mutability
        # annotations, default values, etc. A signature is the canonical schema
        # for a group of semantically related functions (the in-place / out /
        # functional variants).
        functions_by_signature: Dict[FunctionSchema, List[NativeFunction]] = defaultdict(list)
        functions_by_schema: Dict[str, NativeFunction] = dict()
        for function in functions:
            functions_by_signature[function.func.signature()].append(function)
            assert str(function.func) not in functions_by_schema
            functions_by_schema[str(function.func)] = function

        # Keep track of how many of each op we have seen, so that duplicated
        # op names can be disambiguated with a numeric suffix (see the op-name
        # assignment in create_differentiability_info below).
        op_counter = Counter[str]()

        infos = [
            create_differentiability_info(defn, functions_by_signature, functions_by_schema, op_counter)
            for defn in definitions]

        _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = infos

    return _GLOBAL_LOAD_DERIVATIVE_CACHE[key]
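# --- Illustrative usage sketch (commentary only, not part of the codegen) ---
# A derivatives.yaml entry pairs a full function schema with per-input gradient
# formulas; schematically (the formulas here are simplified for illustration):
#
#   - name: mul.Tensor(Tensor self, Tensor other) -> Tensor
#     self: grad * other
#     other: grad * self
#
# A caller would then load everything in one shot (the paths below are
# assumptions for illustration, not fixed by this module):
#
#   infos = load_derivatives(
#       'tools/autograd/derivatives.yaml',
#       'aten/src/ATen/native/native_functions.yaml')
#
# Because results are memoized in _GLOBAL_LOAD_DERIVATIVE_CACHE, repeated calls
# with the same pair of paths return the same cached sequence.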
def create_differentiability_info(
    defn: Dict[Any, Any],
    functions_by_signature: Dict[FunctionSchema, List[NativeFunction]],
    functions_by_schema: Dict[str, NativeFunction],
    op_counter: Counter[str],
) -> DifferentiabilityInfo:
    """Processes a single entry `defn` in derivatives.yaml"""

    def canonical_function(functions: Sequence[NativeFunction], name: str) -> NativeFunction:
        for f in functions:
            if cpp.name(f.func) == name:
                return f
        # some functions only have in-place variants
        assert name + '_' == cpp.name(functions[0].func)
        return functions[0]

    def split_names(raw_names: str) -> Tuple[str, ...]:
        """Given "foo, bar", return ("foo", "bar")."""
        return tuple(x.strip() for x in raw_names.split(','))

    def check_grad_usage(defn_name: str, derivatives: Sequence[Derivative]) -> None:
        """
        Check for some subtle mistakes one might make when writing derivatives.
        These mistakes will compile, but will be latent until a function is
        used with double backwards.
        """

        uses_grad = False  # true if any derivative uses "grad"
        num_grads_uses = 0  # count of uses of "grads" or "grads[INDEX]"
        uses_named_grads = False  # true if any derivative uses "grad_{name}"
        used_grads_indices: List[int] = []  # which indices of grads are used
        for d in derivatives:
            formula = d.formula
            uses_grad = uses_grad or bool(re.findall(IDENT_REGEX.format('grad'), formula))
            num_grads_uses += len(re.findall(IDENT_REGEX.format('grads'), formula))
            uses_named_grads = uses_named_grads or bool(d.named_gradients)
            used_grads_indices.extend(used_gradient_indices(formula))
        # Basic sanity check: the number of places we see "grads" should be no
        # fewer than the number of indices we see inside "grads". They may not
        # be equal because we may use "grads" without an index.
        assert num_grads_uses >= len(used_grads_indices)
        # Thus if the counts are equal, every use of grads is also indexed.
        only_used_grads_indices = num_grads_uses == len(used_grads_indices)

        if uses_grad and num_grads_uses > 0:
            raise RuntimeError(f"Derivative definition of {defn_name} in derivatives.yaml illegally "
                               "mixes use of 'grad' and 'grads'. Consider replacing "
                               "occurrences of 'grad' with 'grads[0]'")

        if only_used_grads_indices and set(used_grads_indices) == {0}:
            raise RuntimeError(f"Derivative definition of {defn_name} in derivatives.yaml solely "
                               "refers to 'grads[0]'. If the first output is indeed the "
                               "only differentiable output, replace 'grads[0]' with 'grad'; "
                               "otherwise, there is likely an error in your derivatives "
                               "declaration.")

        if uses_named_grads and (uses_grad or num_grads_uses > 0):
            raise RuntimeError(f'Derivative definition of {defn_name} in derivatives.yaml illegally '
                               'mixes use of "grad_RETURN_NAME" and "grad" or "grads[x]". Use '
                               'only one method for identifying gradients.')

    @with_native_function
    def set_up_derivatives(f: NativeFunction) -> Tuple[
        Sequence[Derivative],
        Sequence[ForwardDerivative],
        Sequence[Binding],
        Sequence[str],
        Sequence[str],
    ]:
        # Set up the derivative information
        derivatives: List[Derivative] = []
        forward_derivatives: List[ForwardDerivative] = []
        non_differentiable_arg_names: List[str] = []
        args_with_derivatives_set: Set[str] = set()

        all_arg_names = [a.name for a in cpp_arguments(f)]

        # output_differentiability is captured from the enclosing scope.
        # Don't modify it.
        #
        # If it is not present, then no output is explicitly undifferentiable.
        #
        # It may be present and shorter than the number of return values. In
        # that case, any return value without a corresponding entry is
        # considered not differentiable.
        differentiability = output_differentiability or [True] * len(f.func.returns)
        # A return is available as a named gradient ...
        available_named_gradients = [
            f'grad_{ret.name}'
            for ret, differentiable in zip(f.func.returns, differentiability)
            # ... if it has not been explicitly made undifferentiable
            if differentiable
            # ... and if it has a name
            and ret.name is not None
            # ... and if its type is differentiable
            and ret.type.is_tensor_like()
        ]

        for raw_names in sorted(defn.keys()):
            formula = defn[raw_names]
            names = split_names(raw_names)

            if is_forward_derivative_definition(all_arg_names, names):
                forward_derivatives.append(create_forward_derivative(f, formula, names))
            else:
                if formula.lower().strip() == 'non_differentiable':
                    non_differentiable_arg_names += names
                else:
                    derivative = create_derivative(f, formula, names, available_named_gradients)
                    derivatives.append(derivative)
                    args_with_derivatives_set |= set(names)

        overlap = args_with_derivatives_set.intersection(non_differentiable_arg_names)
        if overlap:
            raise RuntimeError(f'derivatives definition for {defn} has overlapping non_differentiable '
                               f'and differentiable variables: {overlap}')

        # Next, let us determine the list of inputs in order.
        # TODO: do we need to eagerly calculate and save it here? Can it be derived
        # from NativeFunction and `derivatives` at the call sites instead?
        args_with_derivatives = [a for a in cpp_arguments(f) if a.name in args_with_derivatives_set]

        # Postprocess forward derivative definitions now that we know the differentiable arguments
        forward_derivatives = postprocess_forward_derivatives(f, defn_name, all_arg_names, derivatives,
                                                              forward_derivatives, args_with_derivatives)

        # Check that the use of 'grads' makes sense.
        check_grad_usage(defn_name, derivatives)

        return (derivatives, forward_derivatives, args_with_derivatives,
                non_differentiable_arg_names, available_named_gradients)

    # NB: Removes 'name' from the defn dictionary
    specification = defn.pop('name')
    defn_name, _ = split_name_params(specification)
    # NB: Removes 'output_differentiability' from the defn dictionary.
    # `None` means all outputs are differentiable.
    output_differentiability = defn.pop('output_differentiability', None)
    output_differentiability_conditions = None
    if output_differentiability and any(
            isinstance(diff, str) for diff in output_differentiability):
        if len(output_differentiability) != 1:
            raise RuntimeError(f'Not supported: for {specification}, '
                               f'output_differentiability must either be '
                               f'List[bool] or a List[str] where each str is a '
                               f'condition. In the case where it is a condition, '
                               f'we only support single-output functions. '
                               f'Please file us an issue.')
        output_differentiability_conditions = output_differentiability
        output_differentiability = [True]

    schema_function = functions_by_schema.get(specification)
    if not schema_function:
        avail = '\n'.join(k for k, v in functions_by_schema.items()
                          if cpp.name(v.func) == defn_name)
        raise RuntimeError(f'could not find ATen function for schema: {specification}. '
                           f'Available signatures:\n{avail}')

    # Now map this to the legacy schema; this isn't technically necessary, but we'd need some logic
    # here to map in-place schemas to the out-of-place variants.
    # TODO: maybe the logic to handle the legacy schema is no longer necessary?
    signature = schema_function.func.signature()
    functions = functions_by_signature[signature]
    if len(functions) == 0:
        avail = '\n'.join(str(k) for k, v in functions_by_signature.items()
                          if cpp.name(k) == defn_name)
        raise RuntimeError(f'could not find ATen function for legacy signature: {signature} '
                           f'corresponding to schema {specification}. Please report a bug to PyTorch. '
                           f'Available signatures:\n{avail}')

    canonical = canonical_function(functions, defn_name)
    if 'grad_input_mask' in (a.name for a in cpp_arguments(canonical)):
        raise RuntimeError(f"Schema for {defn_name} has an argument named grad_input_mask, "
                           "but this name would be shadowed by our codegen. "
                           "Please use a different name in native_functions.yaml.")

    if 'result' in (a.name for a in cpp_arguments(canonical)):
        raise RuntimeError(f"Schema for {defn_name} has an argument named result, "
                           "but this is only allowed for outputs. "
                           "Please use a different name in native_functions.yaml.")

    (derivatives, forward_derivatives, args_with_derivatives,
     non_differentiable_arg_names, available_named_gradients) = set_up_derivatives(canonical)

    used_named_gradients: Set[str] = set()
    for d in derivatives:
        used_named_gradients |= d.named_gradients

    # only assign an op name if we are actually going to calculate a derivative
    op = None
    if args_with_derivatives:
        op_prefix = _create_op_prefix(defn_name)
        op = f'{op_prefix}{op_counter[op_prefix]}'
        op_counter[op_prefix] += 1

    return DifferentiabilityInfo(
        name=defn_name,
        func=canonical,
        op=op,
        derivatives=derivatives,
        forward_derivatives=forward_derivatives,
        all_saved_inputs=dedup_vars([v for d in derivatives for v in d.saved_inputs]),
        all_saved_outputs=dedup_vars([v for d in derivatives for v in d.saved_outputs]),
        available_named_gradients=available_named_gradients,
        used_named_gradients=used_named_gradients,
        args_with_derivatives=args_with_derivatives,
        non_differentiable_arg_names=non_differentiable_arg_names,
        output_differentiability=output_differentiability,
        output_differentiability_conditions=output_differentiability_conditions,
    )
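# --- Illustrative sketch of the op-naming scheme (commentary only; a minimal
# standalone model of the op_counter logic above, not a separate API) ---
# Because op_counter is shared across all derivatives.yaml entries, duplicated
# op prefixes receive numeric suffixes: the first entry whose prefix is
# 'MulBackward' becomes 'MulBackward0', the next 'MulBackward1', and so on.
#
#   from collections import Counter
#
#   op_counter: Counter[str] = Counter()
#
#   def assign_op(prefix: str) -> str:
#       # Append the number of times we have already seen this prefix.
#       op = f'{prefix}{op_counter[prefix]}'
#       op_counter[prefix] += 1
#       return op
#
#   assert assign_op('MulBackward') == 'MulBackward0'
#   assert assign_op('MulBackward') == 'MulBackward1'
#   assert assign_op('AddBackward') == 'AddBackward0'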