def check_grad_usage(defn_name: str, derivatives: Sequence[Derivative]) -> None:
    """
    Check for some subtle mistakes one might make when writing derivatives.
    These mistakes will compile, but will be latent until a function is
    used with double backwards.
    """

    uses_grad = False  # true if any derivative uses "grad"
    num_grads_uses = 0  # count of uses of "grads" or "grads[INDEX]"
    uses_named_grads = False  # true if any derivative uses "grad_{name}"
    used_grads_indices: List[int] = []  # which indices of grads are used
    for d in derivatives:
        formula = d.formula
        uses_grad = uses_grad or bool(
            re.findall(IDENT_REGEX.format("grad"), formula)
        )
        num_grads_uses += len(re.findall(IDENT_REGEX.format("grads"), formula))
        uses_named_grads = uses_named_grads or bool(d.named_gradients)
        used_grads_indices.extend(used_gradient_indices(formula))
    # This is a basic sanity check: the number of places we see
    # "grads" should be no fewer than the number of indices we see
    # inside "grads". They may not be equal because we may use
    # "grads" without an index.
    assert num_grads_uses >= len(used_grads_indices)
    # Thus if the number is equal, every use of grads is also
    # indexed.
    only_used_grads_indices = num_grads_uses == len(used_grads_indices)

    if uses_grad and num_grads_uses > 0:
        raise RuntimeError(
            f"Derivative definition of {defn_name} in derivatives.yaml illegally "
            "mixes use of 'grad' and 'grads'. Consider replacing "
            "occurrences of 'grad' with 'grads[0]'"
        )

    if only_used_grads_indices and set(used_grads_indices) == {0}:
        raise RuntimeError(
            f"Derivative definition of {defn_name} in derivatives.yaml solely "
            "refers to 'grads[0]'. If the first output is indeed the "
            "only differentiable output, replace 'grads[0]' with 'grad'; "
            "otherwise, there is a likely error in your derivatives "
            "declaration."
        )

    if uses_named_grads and (uses_grad or num_grads_uses > 0):
        raise RuntimeError(
            f"Derivative definition of {defn_name} in derivatives.yaml illegally "
            'mixes use of "grad_RETURN_NAME" and "grad" or "grads[x]". Use '
            "only one method for identifying gradients."
        )

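# Illustrative sketch (hypothetical op and formulas, not taken from
# derivatives.yaml): for an op `foo(Tensor self, Tensor other) -> (Tensor, Tensor)`
# with named returns `out0` and `out1`, each of these backward formulas for `self`
# would be rejected by check_grad_usage above:
#
#   grad * other + grads[1]      # mixes 'grad' and 'grads'
#   grads[0] * other             # only ever refers to 'grads[0]'; should use 'grad'
#   grad_out0 * other + grad     # mixes named gradients with 'grad' / 'grads[x]'
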
def uses_ident(info: Optional[DifferentiabilityInfo], ident: str) -> bool:
    if info is None:
        return False
    for derivative in info.derivatives:
        formula = derivative.formula
        if re.search(IDENT_REGEX.format(ident), formula):
            return True
    return False

def create_derivative(
    f: NativeFunction,
    formula: str,
    var_names: Tuple[str, ...],
    available_named_gradients: Sequence[str],
) -> Derivative:
    original_formula = formula
    arguments: List[NamedCType] = [
        a.nctype.remove_const_ref() for a in cpp_arguments(f)
    ]

    return_names = tuple(n if n != "self" else "result" for n in cpp.return_names(f))
    return_types = tuple(
        cpp.return_type(r).remove_const_ref() for r in f.func.returns
    )

    named_returns = [
        NamedCType(name, type) for name, type in zip(return_names, return_types)
    ]

    formula, saved_inputs = saved_variables(formula, arguments, var_names)
    formula, saved_outputs = saved_variables(formula, named_returns, var_names)

    used_named_gradients = {
        name
        for name in available_named_gradients
        if re.search(IDENT_REGEX.format(name), formula)
    }

    # Check that the referenced derivatives in the formula are in bounds
    for i in used_gradient_indices(formula):
        if i >= len(f.func.returns):
            raise RuntimeError(
                f"Out of bounds grads access: derivative formula for {cpp.name(f.func)} "
                f"used grads[{i}], but the forward only returns {len(f.func.returns)} outputs."
            )

    return Derivative(
        formula=formula,
        original_formula=original_formula,
        var_names=var_names,
        saved_inputs=saved_inputs,
        saved_outputs=saved_outputs,
        named_gradients=used_named_gradients,
    )

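# Illustrative sketch (hypothetical formula): for an op with a single return, a
# backward formula that mentions "grads[1]" would fail the bounds check in
# create_derivative above, since the forward only produces grads[0].
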
def saved_variables(
    formula: str,
    nctypes: List[NamedCType],
    var_names: Tuple[str, ...],
) -> Tuple[str, Tuple[SavedAttribute, ...]]:
    def stride_expr(name: str) -> str:
        assert var_names == (name,), (
            'Replacement for ".strides()" is currently only supported for single derivatives of the same tensor '
            'that ".strides()" is being called on.'
        )
        return f'strides_or_error({name}, "{name}")'

    REPLACEMENTS: List[Tuple[str, Dict[str, Any]]] = [
        # replace self.sizes() with self_sizes
        (
            r"{}.sizes\(\)",
            {
                "suffix": "_sizes",
                "nctype": lambda name: NamedCType(name, BaseCType(intArrayRefT)),
            },
        ),
        # replace self.sym_sizes() with self_sym_sizes
        (
            r"{}.sym_sizes\(\)",
            {
                "suffix": "_sym_sizes",
                "nctype": lambda name: NamedCType(name, BaseCType(symIntArrayRefT)),
            },
        ),
        # replace self->sizes() with self_sizes_opt
        (
            r"{}->sizes\(\)",
            {
                "suffix": "_sizes_opt",
                "nctype": lambda name: NamedCType(
                    name, OptionalCType(BaseCType(intArrayRefT))
                ),
                "expr": lambda name: f"{name}.has_value() ? c10::optional<IntArrayRef>({name}->sizes()) : c10::nullopt",
            },
        ),
        # replace self.options() with self_options
        (
            r"{}.options\(\)",
            {
                "suffix": "_options",
                "nctype": lambda name: NamedCType(name, BaseCType(tensorOptionsT)),
            },
        ),
        # replace zeros_like(self) with self_info
        (
            r"zeros_like\({}\)",
            {
                "suffix": "_info",
                "nctype": lambda name: NamedCType(name, BaseCType(typeAndSizeT)),
                "expr": lambda name: name,  # at save-time
                "res": lambda name: name + "_info.zeros()",  # at eval-time
            },
        ),
        # replace self.size(2) with self_size_2
        (
            r"{}.size\((\w+)\)",
            {
                "suffix": lambda m: "_argsize_{}".format(*m.groups()),
                "nctype": lambda name: NamedCType(name, BaseCType(longT)),
            },
        ),
        # replace self.numel() with self_numel
        (
            r"{}.numel\(\)",
            {
                "suffix": "_numel",
                "nctype": lambda name: NamedCType(name, BaseCType(longT)),
            },
        ),
        # replace to_args_sizes(self) with self_args_sizes
        (
            r"to_args_sizes\({}\)",
            {
                "suffix": "_args_sizes",
                "nctype": lambda name: NamedCType(
                    name, VectorCType(VectorCType(BaseCType(longT)))
                ),
            },
        ),
        # replace to_args_scalartypes(self) with self_args_scalartypes
        (
            r"to_args_scalartypes\({}\)",
            {
                "suffix": "_args_scalartypes",
                "nctype": lambda name: NamedCType(
                    name, VectorCType(BaseCType(scalarTypeT))
                ),
            },
        ),
        # replace TensorGeometry(self) with self_geometry
        (
            r"TensorGeometry\({}\)",
            {
                "suffix": "_geometry",
                "nctype": lambda name: NamedCType(name, BaseCType(tensorGeometryT)),
            },
        ),
        (
            r"{}.scalar_type\(\)",
            {
                "suffix": "_scalar_type",
                "nctype": lambda name: NamedCType(name, BaseCType(scalarTypeT)),
            },
        ),
        # replace self.dim() with self_dim
        (
            r"{}.dim\(\)",
            {
                "suffix": "_dim",
                "nctype": lambda name: NamedCType(name, BaseCType(longT)),
            },
        ),
        # replace self.strides() with self_strides
        (
            r"{}.strides\(\)",
            {
                "suffix": "_strides",
                "nctype": lambda name: NamedCType(name, BaseCType(intArrayRefT)),
                "expr": stride_expr,
            },
        ),
        # replace self.layout() with self_layout
        (
            r"{}.layout\(\)",
            {
                "suffix": "_layout",
                "nctype": lambda name: NamedCType(name, BaseCType(layoutT)),
            },
        ),
        # replace self.is_conj() with self_conjugate
        (
            r"{}.is_conj\(\)",
            {
                "suffix": "_conjugate",
                "nctype": lambda name: NamedCType(name, BaseCType(boolT)),
            },
        ),
    ]

    # find which arguments need to be saved
    saved: List[SavedAttribute] = []

    for nctype in nctypes:
        name = (
            nctype.name.name if isinstance(nctype.name, SpecialArgName) else nctype.name
        )
        # First search the formula for expressions which can be evaluated
        # when the autograd Function is created to avoid saving variables
        for regex, info in REPLACEMENTS:

            def repl(m: Match[str]) -> str:
                suffix: str = (
                    info["suffix"](m) if callable(info["suffix"]) else info["suffix"]
                )
                expr: str = info["expr"](name) if "expr" in info else m.group(0)
                saved.append(
                    SavedAttribute(
                        nctype=info["nctype"](name + suffix),
                        expr=expr,
                    )
                )
                if "res" in info:
                    replacement: str = info["res"](name)
                    return replacement
                return name + suffix

            formula = re.sub(regex.format(name), repl, formula)

        # c10::optional<std::string> types stored in Backward nodes must be
        # converted to c10::optional<c10::string_view> before being passed into
        # the backward function
        if nctype.type == OptionalCType(BaseCType(stringT)):
            formula = re.sub(
                rf"\b{name}\b",
                f"{name}.has_value() ? c10::optional<c10::string_view>({name}.value()) : c10::nullopt",
                formula,
            )

        # Find any variables which remain in the formula and save them
        if re.search(IDENT_REGEX.format(name), formula):
            saved.append(
                SavedAttribute(
                    nctype=nctype,
                    expr=name,
                )
            )

    return formula, tuple(saved)

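# Illustrative sketch (hypothetical backward formula, not taken from
# derivatives.yaml): given
#     "grad.expand(self.sizes()) / self.numel()"
# saved_variables() rewrites the formula to
#     "grad.expand(self_sizes) / self_numel"
# and returns SavedAttributes for `self_sizes` (IntArrayRef, expr "self.sizes()")
# and `self_numel` (int64_t, expr "self.numel()"). Because `self` no longer
# appears as a bare identifier in the rewritten formula, the Tensor itself does
# not need to be saved for backward.
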
def postprocess_forward_derivatives(
    f: NativeFunction,
    defn_name: str,
    all_arg_names: List[str],
    derivatives: List[Derivative],
    forward_derivatives: List[ForwardDerivative],
    args_with_derivatives: Sequence[Binding],
) -> List[ForwardDerivative]:
    def find_required_inputs(formula: str, postfix: str) -> Tuple[str, ...]:
        required_inputs = set()
        for arg in args_with_derivatives:
            if arg.type == "at::TensorList":
                # The functions taking TensorList handle everything internally
                continue
            arg_name = arg.name

            found = re.search(IDENT_REGEX.format(arg_name), formula)
            if found:
                raise RuntimeError(
                    f"The forward formula for {defn_name} is using the base name of the {arg_name} "
                    f"argument which is ambiguous. You should use {arg_name}_p to access the primal "
                    f"value and {arg_name}_t to access the tangent."
                )

            found = re.search(IDENT_REGEX.format(arg_name + postfix), formula)
            if found:
                required_inputs.add(arg_name)

        return tuple(required_inputs)

    updated_derivatives: List[ForwardDerivative] = []

    for defn in forward_derivatives:
        formula = defn.formula
        required_inputs_tangent = find_required_inputs(formula, "_t")
        if formula == "auto_element_wise":
            if (
                (not len(args_with_derivatives) == 1)
                or len(forward_derivatives) > 1
                or len(forward_derivatives[0].var_names) > 1
            ):
                raise RuntimeError(
                    f"Derivative definition of {defn_name} in derivatives.yaml defines the "
                    "forward definition of gradient as element_wise but this only "
                    "works for functions with a single differentiable input and a "
                    "single differentiable output."
                )
            if not len(derivatives) == 1:
                raise RuntimeError(
                    f"Derivative definition of {defn_name} in derivatives.yaml defines the "
                    "forward definition of gradient as element_wise but it does not "
                    "define the gradient formula for its argument, which is required."
                )
            # This transformation is based on the observation that for element-wise functions, the Jacobian
            # matrix is diagonal and thus doing J * v is the same as (v^T J)^T (in practice, we ignore the transpositions)
            # For the complex case, we use hermitian transpose and get (v.conj() J).conj()
            # So here we are going to re-use the backward formula and replace three things:
            # 1) all occurrences of "grad" with "foo_t.conj()", where foo is the name of the unique differentiable input.
            # 2) all usage of an original input "foo" with its primal value "foo_p".
            # 3) conjugate the final result
            # For example, for abs, the backward formula is:
            #   grad * self.sgn()
            # And this function generates a forward formula that is:
            #   (self_t.conj() * self_p.sgn()).conj()

            backward_formula = derivatives[0].original_formula
            input_name = args_with_derivatives[0].name

            # Do replacement 1) of the grad
            def repl(m: Any) -> str:
                return f"{m.group(1)}{input_name}_t.conj(){m.group(2)}"

            fw_formula = re.sub(IDENT_REGEX.format("grad"), repl, backward_formula)

            # Do replacement 2) of the input variables
            for arg in args_with_derivatives:
                arg_name = arg.name

                def repl(m: Any) -> str:
                    return f"{m.group(1)}{arg_name}_p{m.group(2)}"

                fw_formula = re.sub(IDENT_REGEX.format(arg_name), repl, fw_formula)

            # Do the final conjugate 3)
            fw_formula = f"({fw_formula}).conj()"

            # Since there is a single differentiable input and we necessarily need its
            # tangent, we can simply require all differentiable inputs' tangents.
            required_inputs_tangent = tuple(all_arg_names)
            formula = fw_formula
        elif formula == "auto_linear":
            if (
                len(forward_derivatives) > 1
                or len(forward_derivatives[0].var_names) > 1
            ):
                raise RuntimeError(
                    f"Derivative definition of {defn_name} in derivatives.yaml defines the "
                    "forward definition of gradient as linear but this only works "
                    "for functions with a single differentiable output."
                )
            # This transformation is based on the observation that linear functions can be written as:
            #   y = f(x) = A * x
            # for some matrix A, and the Jacobian of the function f is also A.
            # So doing J * v = A * v = f(v).
            # Hence to do the jvp, we simply need to evaluate the function at the point v instead of x.
            # We do this by calling the forward again, replacing any occurrence of the differentiable
            # input "foo" by its tangent "foo_t".
            # Note that multiple inputs are not a problem as long as the function is truly linear w.r.t.
            # the vector where all the differentiable inputs are stacked.

            diff_arg_names = [arg.name for arg in args_with_derivatives]
            assert len(diff_arg_names) > 0

            # Do replacement of input variables
            new_args = []
            for arg_name in all_arg_names:
                if arg_name in diff_arg_names:
                    arg_name = arg_name + "_t"
                new_args.append(arg_name)

            # TODO we are trolling
            if f.func.is_symint_fn():
                defn_name += "_symint"

            # Call into the forward again. We need two cases here to handle both Tensor methods and at:: functions.
            if Variant.function in f.variants:
                fw_formula = "at::{}({})".format(defn_name, ", ".join(new_args))
            else:
                assert Variant.method in f.variants
                fw_formula = "{}.{}({})".format(
                    new_args[0], defn_name, ", ".join(new_args[1:])
                )

            # All of the input tangents are always used so all of them are required here.
            required_inputs_tangent = tuple(diff_arg_names)
            formula = fw_formula

        # At this point, the formula is final and is not modified anymore.

        # In the forward formula, we use the primals instead of the input Tensors.
        # This call inspects the formula to find which inputs' primals are used.
        required_inputs_primal = find_required_inputs(formula, "_p")

        updated_derivatives.append(
            ForwardDerivative(
                formula=formula,
                var_names=defn.var_names,
                var_types=defn.var_types,
                required_inputs_fw_grad=required_inputs_tangent,
                required_inputs_primal=required_inputs_primal,
                required_original_self_value=False,
                is_reusing_outplace_formula=False,
            )
        )

    return updated_derivatives

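# Illustrative sketch (hypothetical schema): for a method-variant linear op
# `foo(Tensor self, Tensor weight, int dim) -> Tensor` whose forward derivative
# is declared as "auto_linear", with `self` and `weight` differentiable,
# postprocess_forward_derivatives above would generate roughly
#     self_t.foo(weight_t, dim)
# i.e. the forward is re-invoked on the tangents of the differentiable inputs,
# while non-differentiable arguments such as `dim` are passed through unchanged.
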
def match_differentiability_info(
    native_functions: List[NativeFunction],
    differentiability_infos: Sequence[DifferentiabilityInfo],
) -> List[NativeFunctionWithDifferentiabilityInfo]:
    """Sets the "derivative" key on declarations to matching autograd function

    In-place functions will use the out-of-place derivative definition if there
    is no in-place specific derivative.
    """

    info_by_schema = {info.func.func: info for info in differentiability_infos}
    functional_info_by_signature = {
        info.func.func.signature(strip_default=True): info
        for info in differentiability_infos
        if info.func.func.kind() == SchemaKind.functional
    }
    non_functional_info_by_signature = {
        info.func.func.signature(strip_default=True): info
        for info in differentiability_infos
        if info.func.func.kind() != SchemaKind.functional
    }

    def find_info(
        f: NativeFunction,
    ) -> Tuple[Optional[DifferentiabilityInfo], bool]:
        # (1) Check for an exact match
        if f.func in info_by_schema:
            return info_by_schema[f.func], True

        # (2) If no exact match, check if the out-of-place variant
        # of this operator has a match.
        # i.e mul() for mul_() or mul_out()
        f_sig = f.func.signature(strip_default=True)
        if f_sig in functional_info_by_signature:
            return functional_info_by_signature[f_sig], False

        # (3) Some operators have a derivative explicitly defined for the mutable
        # variant, but get a code-generated out-of-place variant which does *not*
        # come with a derivative formula.
        # For the generated out-of-place variant, use the mutable variant's formula
        # if it exists.
        if "generated" in f.tags and f_sig in non_functional_info_by_signature:
            info = non_functional_info_by_signature[f_sig]
            # See https://github.com/pytorch/pytorch/pull/76320/files#r874816389
            assert not any(
                "self" in str(inpt.nctype.name) for inpt in info.all_saved_inputs
            ), f"""\
Attempted to convert a derivative formula for a mutable operator
 to be used automatically by its functional variant ("{str(f.func)}").
 This is not currently supported (we'd need to fix up the formula in the codegen)."""
            return info, False

        return None, False

    result: List[NativeFunctionWithDifferentiabilityInfo] = []
    for f in native_functions:
        info, is_exact_match = find_info(f)

        # Currently, the '.strides()' to 'strides_or_error' replacement does not support
        # 'self' derivatives of an inplace function, so we must check for this case.
        if f.func.kind() == SchemaKind.inplace and (info is not None):
            for derivative in info.derivatives:
                if "self" in derivative.var_names:
                    for saved_input in derivative.saved_inputs:
                        assert "strides_or_error" not in saved_input.expr, (
                            "Calling '.strides()' in the 'self' derivative formula of an "
                            f"in-place function is not supported: {f.func}"
                        )

        # For functions that have a single def for out-of-place and inplace (like abs())
        if info and info.forward_derivatives:
            forward_derivatives = info.forward_derivatives

            if f.func.kind() == SchemaKind.inplace:
                # For inplace functions there is a little bit of work to do:
                #  1) Validate the formula and make sure the input that is modified is not used:
                #    - If there is a formula for the inplace variant of the function (is_exact_match == True) then
                #      we make sure that the original value of the input that is being modified inplace (self_p) is
                #      not used in the formula. Note that the formula can use "original_self_p" here and that would
                #      trigger a clone of the original input.
                #    - If we are re-using the out of place formula (is_exact_match == False) then we replace every
                #      occurrence of self_p and self_t by original_self_p and original_self_t. These will be
                #      populated by a cloned version of the original input (either the clone done by the backward AD
                #      logic if self is also used in a backward formula or a special clone that we add).
                #  2) At this point, there cannot be a self_p in the formula.
                #  3) Change "result" into "self_p" as by design, in the inplace function codegen, the result is
                #     simply called self (as it is modified inplace).
                #  4) Update the required primals data in case it used to contain "result" but should now contain
                #     "self"
                #  5) If it is not an exact match, the user formula is not modifying the existing forward grad
                #     inplace as it should. So add some code that makes sure that we do so if the forward grad
                #     already exists.

                assert (
                    len(info.forward_derivatives) == 1
                )  # Only single output inplace should exist
                fw_info = info.forward_derivatives[0]
                formula = fw_info.formula

                def replace_self_with_original_self(formula: str, postfix: str) -> str:
                    def repl(m: Match[str]) -> str:
                        return f"{m.group(1)}original_self{postfix}{m.group(2)}"

                    return re.sub(IDENT_REGEX.format(f"self{postfix}"), repl, formula)

                if re.search(IDENT_REGEX.format("self_p"), formula):
                    if is_exact_match:
                        # For manually defined formulas, don't allow the original value to be used
                        raise RuntimeError(
                            f'The formula for "{f.func.name}" is using the original value of self '
                            "that is being modified inplace. This would lead to wrong forward gradients. "
                            'Please use "result" in the formula only.'
                        )
                    else:
                        # When the original formula is out of place, we save a clone of the primal
                        # value to be able to access this value if needed
                        # replace "self_p"/"self_t" from the formula by "original_self_p"/"original_self_t"
                        formula = replace_self_with_original_self(formula, "_p")
                        formula = replace_self_with_original_self(formula, "_t")

                # replace "result" from the formula by "self_p"
                def repl(m: Match[str]) -> str:
                    return f"{m.group(1)}self_p{m.group(2)}"

                formula = re.sub(IDENT_REGEX.format("result"), repl, formula)

                required_primals = fw_info.required_inputs_primal
                if re.search(IDENT_REGEX.format("self_p"), formula):
                    required_primals = (
                        required_primals + ("self",) if required_primals else ("self",)
                    )

                if not is_exact_match:
                    # NOTE [In-place forward AD formula Optimization]
                    #
                    # This optimization transforms the formula to directly do inplace, i.e.
                    # instead of self_t.copy_(self_t.op()) we do self_t.op_() when the following are met:
                    #
                    # 1) the formula satisfies the pattern: "self_t.op(*args)"
                    # 2) "op" in (1) needs to be the same as the op the derivative is for
                    #
                    # (2) may seem too strict, but currently the only ops that satisfy (1) also satisfy (2)
                    # If there is a need, we can relax (2) to allow any op that has an in-place variant
                    is_single_method_on_self_t = False
                    match = re.fullmatch(r"self_t.([\w]*)\((.*)\)", formula)
                    if match:
                        op_name, between_parens = match.group(1), match.group(2)

                        # We want to...
                        #   Match: self_t.op1(other_p.op2(arg))
                        #   Avoid: self_t.op1(args) + self_t.op2(args)
                        #   Avoid: self_t.op1(other_p.op2(arg)) + self_t.op2(args)
                        def check_parens_nest_level_gt_zero(s: str) -> bool:
                            level = 1
                            for ch in s:
                                if ch == ")":
                                    level -= 1
                                    if level == 0:
                                        return False
                                if ch == "(":
                                    level += 1
                            return True

                        is_single_method_on_self_t = check_parens_nest_level_gt_zero(
                            between_parens
                        )
                    directly_do_inplace = (
                        is_single_method_on_self_t and op_name == info.name
                    )

                    if directly_do_inplace:
                        formula = f"self_t_raw.defined() ? self_t_raw.{op_name}_({between_parens}) : {formula}"
                    else:
                        # Make sure that the forward grad is modified inplace when the original formula
                        # is out of place
                        formula = f"self_t_raw.defined() ? self_t_raw.copy_({formula}) : {formula}"

                required_original_self_value = bool(
                    re.search(IDENT_REGEX.format("original_self_p"), formula)
                )

                forward_derivatives = [
                    ForwardDerivative(
                        formula=formula,
                        var_names=("self",),
                        var_types=fw_info.var_types,
                        required_inputs_fw_grad=fw_info.required_inputs_fw_grad,
                        required_inputs_primal=required_primals,
                        required_original_self_value=required_original_self_value,
                        is_reusing_outplace_formula=not is_exact_match,
                    ),
                ]
        else:
            forward_derivatives = []

        result.append(
            NativeFunctionWithDifferentiabilityInfo(
                func=f, info=info, fw_derivatives=forward_derivatives
            )
        )

    return result

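# Illustrative sketch (hypothetical formulas): suppose an in-place op reuses an
# out-of-place forward AD formula "self_t.mul(other_p)" (is_exact_match == False)
# and the out-of-place op's name is "mul". The formula matches the
# "self_t.op(*args)" pattern with op == info.name, so the optimization above
# rewrites it to
#     self_t_raw.defined() ? self_t_raw.mul_(other_p) : self_t.mul(other_p)
# A formula that does not match the pattern would instead be wrapped as
#     self_t_raw.defined() ? self_t_raw.copy_(<formula>) : <formula>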