def create_forward_derivative(f: NativeFunction, formula: str, names: Tuple[str, ...]) -> ForwardDerivative:
    assert len(names) == 1, "Forward derivatives can define gradients for only one output at a time"
    var_name = names[0]
    var_type: Optional[Type] = None
    for r in f.func.returns:
        if r.name == var_name:
            var_type = r.type
            break

    # Handle default return names
    if var_type is None:
        if var_name == "result":
            assert len(f.func.returns) == 1
            var_type = f.func.returns[0].type
        else:
            res = re.findall(r"^result(\d+)$", var_name)
            if len(res) == 1:
                arg_idx = int(res[0])
                var_type = f.func.returns[arg_idx].type

    assert var_type is not None, "No matching output for forward derivative definition"
    return ForwardDerivative(
        formula=formula,
        var_name=var_name,
        var_type=var_type,
        required_inputs_fw_grad=None,
        required_inputs_primal=None)

def create_forward_derivative(f: NativeFunction, formula: str, names: Tuple[str, ...]) -> ForwardDerivative:
    var_names = names
    var_types: Optional[Tuple[Type, ...]] = None
    for r in f.func.returns:
        if r.name in var_names:
            if var_types is None:
                var_types = tuple()
            var_types = var_types + (r.type,)

    # Handle default return names
    if var_types is None:
        if var_names == ("result",):
            assert len(f.func.returns) == 1
            var_types = (f.func.returns[0].type,)
        else:
            for var_name in var_names:
                res = re.findall(r"^result(\d+)$", var_name)
                if len(res) == 1:
                    if var_types is None:
                        var_types = tuple()
                    arg_idx = int(res[0])
                    var_types = var_types + (f.func.returns[arg_idx].type,)

    assert var_types is not None, "No matching output for forward derivative definition"
    return ForwardDerivative(
        formula=formula,
        var_names=var_names,
        var_types=var_types,
        required_inputs_fw_grad=None,
        required_inputs_primal=None,
        required_original_self_value=False,
        is_reusing_outplace_formula=False)

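# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the codegen itself): derivatives.yaml refers to
# unnamed outputs by the default names "result" (single output) or "result0",
# "result1", ... (multiple outputs). The hypothetical helper below only mirrors the
# regex-based mapping used above from such a default name to the index of the
# corresponding return; `_result_name_to_index` does not exist in the real code.
import re
from typing import Optional


def _result_name_to_index(var_name: str) -> Optional[int]:
    # "result" always refers to the single return; "resultN" refers to return N.
    if var_name == "result":
        return 0
    res = re.findall(r"^result(\d+)$", var_name)
    return int(res[0]) if len(res) == 1 else None


# _result_name_to_index("result")  -> 0
# _result_name_to_index("result2") -> 2
# _result_name_to_index("out")     -> None (no default-name match)
# ---------------------------------------------------------------------------
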
def postprocess_forward_derivatives(
    f: NativeFunction,
    defn_name: str,
    all_arg_names: List[str],
    derivatives: List[Derivative],
    forward_derivatives: List[ForwardDerivative],
    args_with_derivatives: Sequence[Binding],
) -> List[ForwardDerivative]:
    def find_required_inputs(formula: str, postfix: str) -> Tuple[str, ...]:
        required_inputs = set()
        for arg in args_with_derivatives:
            if arg.type == 'TensorList':
                # The functions taking TensorList handle everything internally
                continue
            arg_name = arg.name

            found = re.search(IDENT_REGEX.format(arg_name), formula)
            if found:
                raise RuntimeError(f"The forward formula for {defn_name} is using the base name of the {arg_name} "
                                   f"argument which is ambiguous. You should use {arg_name}_p to access the primal "
                                   f"value and {arg_name}_t to access the tangent.")

            found = re.search(IDENT_REGEX.format(arg_name + postfix), formula)
            if found:
                required_inputs.add(arg_name)

        return tuple(required_inputs)

    updated_derivatives: List[ForwardDerivative] = []

    for defn in forward_derivatives:
        formula = defn.formula
        required_inputs_tangent = find_required_inputs(formula, "_t")
        if formula == "auto_element_wise":
            if len(args_with_derivatives) != 1 or len(forward_derivatives) > 1:
                raise RuntimeError(f"Derivative definition of {defn_name} in derivatives.yaml defines the "
                                   "forward definition of gradient as element_wise but this only "
                                   "works for functions with a single differentiable input and a "
                                   "single differentiable output.")
            if len(derivatives) != 1:
                raise RuntimeError(f"Derivative definition of {defn_name} in derivatives.yaml defines the "
                                   "forward definition of gradient as element_wise but it does not "
                                   "define the gradient formula for its argument, which is required.")
            # This transformation is based on the observation that for element-wise functions, the Jacobian
            # matrix is diagonal and thus doing J * v or v * J gives the same result.
            # So here we are going to re-use the backward formula and replace two things:
            # 1) all occurrences of "grad" with "foo_t", where foo is the name of the unique differentiable input.
            # 2) all usages of an original input "foo" with its primal value "foo_p".
            # For example, for abs, the backward formula is:
            #   grad * self.sgn()
            # And this function generates a forward formula that is:
            #   self_t * self_p.sgn()

            backward_formula = derivatives[0].original_formula
            input_name = args_with_derivatives[0].name

            # Do replacement 1) of the grad
            def repl(m: Any) -> str:
                return f"{m.group(1)}{input_name}_t{m.group(2)}"
            fw_formula = re.sub(IDENT_REGEX.format("grad"), repl, backward_formula)

            # Do replacement 2) of the input variables
            for arg in args_with_derivatives:
                arg_name = arg.name

                def repl(m: Any) -> str:
                    return f"{m.group(1)}{arg_name}_p{m.group(2)}"
                fw_formula = re.sub(IDENT_REGEX.format(arg_name), repl, fw_formula)

            # Since there is a single differentiable input and we necessarily need its tangent, we can
            # simply require all differentiable inputs' tangents.
            required_inputs_tangent = tuple(all_arg_names)
            formula = fw_formula
        elif formula == "auto_linear":
            if len(forward_derivatives) > 1:
                raise RuntimeError(f"Derivative definition of {defn_name} in derivatives.yaml defines the "
                                   "forward definition of gradient as linear but this only works "
                                   "for functions with a single differentiable output.")
            # This transformation is based on the observation that linear functions can be written as:
            #   y = f(x) = A * x
            # for some matrix A, and the Jacobian of the function f is also A.
            # So doing J * v = A * v = f(v).
            # Hence, to do the jvp we simply need to evaluate the function at the point v instead of x.
            # We do this by calling the forward again, replacing any occurrence of the differentiable
            # input "foo" with its tangent "foo_t".
            # Note that multiple inputs are not a problem as long as the function is truly linear with
            # respect to the vector where all the differentiable inputs are stacked.

            diff_arg_names = [arg.name for arg in args_with_derivatives]
            assert len(diff_arg_names) > 0

            # Do replacement of input variables
            new_args = []
            for arg_name in all_arg_names:
                if arg_name in diff_arg_names:
                    arg_name = arg_name + "_t"
                new_args.append(arg_name)

            # Call into the forward again. We need two cases here to handle both Tensor methods and at:: functions.
            if Variant.function in f.variants:
                fw_formula = "at::{}({})".format(defn_name, ", ".join(new_args))
            else:
                assert f.func.kind() is not SchemaKind.inplace
                assert Variant.method in f.variants
                fw_formula = "{}.{}({})".format(new_args[0], defn_name, ", ".join(new_args[1:]))

            # All of the input tangents are always used, so all of them are required here.
            required_inputs_tangent = tuple(diff_arg_names)
            formula = fw_formula

        # At this point, the formula is final and is not modified anymore.
        # In the forward formula, we use the primals instead of the input Tensors.
        # This call inspects the formula to find which inputs' primals are used.
        required_inputs_primal = find_required_inputs(formula, "_p")

        updated_derivatives.append(ForwardDerivative(
            formula=formula,
            var_name=defn.var_name,
            var_type=defn.var_type,
            required_inputs_fw_grad=required_inputs_tangent,
            required_inputs_primal=required_inputs_primal))

    return updated_derivatives

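# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the codegen itself): a standalone rendering of the
# "auto_element_wise" rewrite above, using the abs example from the comments. The
# backward formula "grad * self.sgn()" becomes the forward formula
# "self_t * self_p.sgn()". `_DEMO_IDENT_REGEX` and `_demo_auto_element_wise` are
# hypothetical names; the pattern only assumes that IDENT_REGEX matches a whole
# identifier and captures the surrounding non-word characters in groups 1 and 2.
import re

_DEMO_IDENT_REGEX = r"(^|\W){}($|\W)"  # assumed shape of IDENT_REGEX


def _demo_auto_element_wise(backward_formula: str, input_name: str) -> str:
    # 1) every standalone "grad" becomes "<input>_t"
    fw = re.sub(
        _DEMO_IDENT_REGEX.format("grad"),
        lambda m: f"{m.group(1)}{input_name}_t{m.group(2)}",
        backward_formula,
    )
    # 2) every standalone "<input>" becomes "<input>_p"
    fw = re.sub(
        _DEMO_IDENT_REGEX.format(input_name),
        lambda m: f"{m.group(1)}{input_name}_p{m.group(2)}",
        fw,
    )
    return fw


# _demo_auto_element_wise("grad * self.sgn()", "self") -> "self_t * self_p.sgn()"
#
# For comparison, "auto_linear" does not rewrite the backward formula at all: a
# function foo(Tensor self, Tensor other) with both inputs differentiable would get
# the forward formula "at::foo(self_t, other_t)" (function variant) or
# "self_t.foo(other_t)" (method variant), exactly as built above.
# ---------------------------------------------------------------------------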