Example #1
def create_differentiability_info(
    defn: Dict[Any, Any],
    functions_by_signature: Dict[FunctionSchema, List[NativeFunction]],
    functions_by_schema: Dict[str, NativeFunction],
    op_counter: Counter[str],
) -> DifferentiabilityInfo:
    """Processes a single entry `defn` in derivatives.yaml"""
    def canonical_function(functions: Sequence[NativeFunction],
                           name: str) -> NativeFunction:
        for f in functions:
            if (not f.func.is_functional_fn() and not f.func.is_out_fn()
                    and name == str(f.func.name.name)):
                return f
        # some functions only have in-place variants
        assert name + "_" == cpp.name(functions[0].func)
        return functions[0]

    def split_names(raw_names: str) -> Tuple[str, ...]:
        """Given "foo, bar", return ["foo", "bar"]."""
        return tuple(x.strip() for x in raw_names.split(","))

    def check_grad_usage(defn_name: str,
                         derivatives: Sequence[Derivative]) -> None:
        """
        Check for some subtle mistakes one might make when writing derivatives.
        These mistakes will compile, but will be latent until a function is
        used with double backwards.
        """

        uses_grad = False  # true if any derivative uses "grad"
        num_grads_uses = 0  # count of uses of "grads" or "grads[INDEX]"
        uses_named_grads = False  # true if any derivative uses "grad_{name}"
        used_grads_indices: List[int] = []  # which indices of grads are used
        for d in derivatives:
            formula = d.formula
            uses_grad = uses_grad or bool(
                re.findall(IDENT_REGEX.format("grad"), formula))
            num_grads_uses += len(
                re.findall(IDENT_REGEX.format("grads"), formula))
            uses_named_grads = uses_named_grads or bool(d.named_gradients)
            used_grads_indices.extend(used_gradient_indices(formula))
        # This is a basic sanity check: the number of places we see
        # "grads" should be no fewer than the number of indices we see
        # inside "grads". They may not be equal because we may use
        # "grads" without an index.
        assert num_grads_uses >= len(used_grads_indices)
        # Thus if the number is equal, every use of grads is also
        # indexed.
        only_used_grads_indices = num_grads_uses == len(used_grads_indices)

        if uses_grad and num_grads_uses > 0:
            raise RuntimeError(
                f"Derivative definition of {defn_name} in derivatives.yaml illegally "
                "mixes use of 'grad' and 'grads'. Consider replacing "
                "occurrences of 'grad' with 'grads[0]'")

        if only_used_grads_indices and set(used_grads_indices) == {0}:
            raise RuntimeError(
                f"Derivative definition of {defn_name} in derivatives.yaml solely "
                "refers to 'grads[0]'.  If the first output is indeed the "
                "only differentiable output, replace 'grads[0]' with 'grad'; "
                "otherwise, there is a likely error in your derivatives "
                "declaration.")

        if uses_named_grads and (uses_grad or num_grads_uses > 0):
            raise RuntimeError(
                f"Derivative definition of {defn_name} in derivatives.yaml illegally "
                'mixes use of "grad_RETURN_NAME" and "grad" or "grads[x]". Use '
                "only one method for identifying gradients.")

    @with_native_function
    def set_up_derivatives(
        f: NativeFunction,
    ) -> Tuple[Sequence[Derivative], Sequence[ForwardDerivative],
               Sequence[Binding], Sequence[str], Sequence[str]]:
        # Set up the derivative information
        derivatives: List[Derivative] = []
        forward_derivatives: List[ForwardDerivative] = []
        non_differentiable_arg_names: List[str] = []
        args_with_derivatives_set: Set[str] = set()

        all_arg_names = [a.name for a in cpp_arguments(f)]
        all_ret_names = [r.name for r in f.func.returns
                         ]  # only used for the assert below
        # output_differentiability is captured from the enclosed
        # scope. Don't modify it.
        #
        # If it is not present, then no output is explicitly
        # undifferentiable.
        #
        # It may be present and shorter than the length of return
        # values. If that's the case, any return value that does not
        # have a corresponding entry is considered not differentiable.
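        #
        # For example, output_differentiability: [True, False] marks the
        # second return as not differentiable; with three returns, the
        # unlisted third one is treated as not differentiable as well.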
        differentiability = output_differentiability or [True] * len(
            f.func.returns)
        # A return is available as a named gradient ...
        available_named_gradients = [
            f"grad_{ret.name}"
            for ret, differentiable in zip(f.func.returns, differentiability)
            # if it has not been explicitly made undifferentiable
            if differentiable
            # and if it has a name
            and ret.name is not None
            # and if its type is differentiable
            and ret.type.is_tensor_like()
        ]

        for raw_names in sorted(defn.keys()):
            formula = defn[raw_names]
            names = split_names(raw_names)

            for name in names:
                assert not (name in all_arg_names and name in all_ret_names), (
                    f"While processing the derivative formula for '{f.func.name}' wrt '{name}', "
                    f"expected '{name}' to not be both an input arg and named return. "
                )

            if is_forward_derivative_definition(all_arg_names, names):
                forward_derivatives.append(
                    create_forward_derivative(f, formula, names))
            else:
                if formula.lower().strip() == "non_differentiable":
                    non_differentiable_arg_names += names
                else:
                    derivative = create_derivative(f, formula, names,
                                                   available_named_gradients)
                    derivatives.append(derivative)
                    args_with_derivatives_set |= set(names)

        overlap = args_with_derivatives_set.intersection(
            non_differentiable_arg_names)
        if overlap:
            raise RuntimeError(
                f"derivatives definition for {defn} have overlapped non_differentiable "
                f"and differentiable variables: {overlap}")

        # Next, let us determine the list of inputs in order.
        # TODO: do we need to eagerly calculate and save it here? Can it be derived
        # from NativeFunction and `derivatives` on callsites instead?
        args_with_derivatives = [
            a for a in cpp_arguments(f) if a.name in args_with_derivatives_set
        ]

        # Postprocess forward derivatives definitions now that we know the differentiable arguments
        forward_derivatives = postprocess_forward_derivatives(
            f,
            defn_name,
            all_arg_names,
            derivatives,
            forward_derivatives,
            args_with_derivatives,
        )

        # Test to see if the use of 'grads' makes sense.
        check_grad_usage(defn_name, derivatives)

        return (
            derivatives,
            forward_derivatives,
            args_with_derivatives,
            non_differentiable_arg_names,
            available_named_gradients,
        )

    # NB: Removes 'name' from defn dictionary
    specification = defn.pop("name")
    defn_name, _ = split_name_params(specification)
    # NB: Removes 'output_differentiability' from defn dictionary
    #     `None` means all differentiable.
    output_differentiability = defn.pop("output_differentiability", None)
    output_differentiability_conditions = None
    if output_differentiability and any(
        [isinstance(diff, str) for diff in output_differentiability]):
        if len(output_differentiability) != 1:
            raise RuntimeError(
                f"Not supported: for {specification},"
                f"output_differentiability must either be "
                f"List[bool] or a List[str] where each str is a "
                f"condition. In the case where it is a condition, "
                f"we only support single-output functions. "
                f"Please file us an issue. ")
        output_differentiability_conditions = output_differentiability
        output_differentiability = [True]

    schema_function = functions_by_schema.get(specification)
    if not schema_function:
        avail = "\n".join(k for k, v in functions_by_schema.items()
                          if cpp.name(v.func) == defn_name)
        raise RuntimeError(
            f"could not find ATen function for schema: {specification} "
            f".  Available signatures:\n{avail}")

    # now map this to the legacy schema; this isn't technically necessary, but we'd need some logic here
    # to map in-place schemas to the out-of-place variants.
    # TODO: maybe the logic to handle the legacy schema is no longer necessary?
    signature = schema_function.func.signature()
    functions = functions_by_signature[signature]
    if len(functions) == 0:
        avail = "\n".join(
            str(k) for k, v in functions_by_signature.items()
            if cpp.name(k) == defn_name)
        raise RuntimeError(
            f"could not find ATen function for legacy signature: {signature} "
            f"corresponding to schema {specification}.  Please report a bug to PyTorch. "
            f"Available signatures:\n{avail}")

    canonical = canonical_function(functions, defn_name)
    if "grad_input_mask" in (a.name for a in cpp_arguments(canonical)):
        raise RuntimeError(
            f"Schema for {defn_name} has an argument named grad_input_mask, "
            "but this name would be shadowed by our codegen. "
            "Please use a different name in native_functions.yaml.")

    if "result" in (a.name for a in cpp_arguments(canonical)):
        raise RuntimeError(
            f"Schema for {defn_name} has an argument named result, "
            "but this is only allowed for outputs."
            "Please use a different name in native_functions.yaml.")

    (
        derivatives,
        forward_derivatives,
        args_with_derivatives,
        non_differentiable_arg_names,
        available_named_gradients,
    ) = set_up_derivatives(canonical)

    used_named_gradients: Set[str] = set()
    for d in derivatives:
        used_named_gradients |= d.named_gradients

    # only assign an op name if we are actually going to calculate a derivative
    op = None
    if args_with_derivatives:
        op_prefix = _create_op_prefix(defn_name)
        op = f"{op_prefix}{op_counter[op_prefix]}"
        op_counter[op_prefix] += 1

    return DifferentiabilityInfo(
        name=defn_name,
        func=canonical,
        op=op,
        derivatives=derivatives,
        forward_derivatives=forward_derivatives,
        all_saved_inputs=dedup_vars(
            [v for d in derivatives for v in d.saved_inputs]),
        all_saved_outputs=dedup_vars(
            [v for d in derivatives for v in d.saved_outputs]),
        available_named_gradients=available_named_gradients,
        used_named_gradients=used_named_gradients,
        args_with_derivatives=args_with_derivatives,
        non_differentiable_arg_names=non_differentiable_arg_names,
        output_differentiability=output_differentiability,
        output_differentiability_conditions=output_differentiability_conditions,
    )
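
For context, `defn` here is one parsed entry of derivatives.yaml: the `name` key holds the ATen schema, and every remaining key maps one or more argument names to a backward formula. Below is a minimal sketch of such an entry as a Python dict; the schema and formulas are illustrative, not copied from the real derivatives.yaml.

# Hypothetical derivatives.yaml entry after YAML parsing (illustrative only).
example_defn = {
    "name": "mul.Tensor(Tensor self, Tensor other) -> Tensor",
    "self": "grad * other",
    "other": "grad * self",
}

# create_differentiability_info first pops the bookkeeping keys ...
specification = example_defn.pop("name")
# ... then treats every remaining key as a comma-separated argument list and
# every value as the corresponding backward formula, mirroring split_names above.
for raw_names, formula in sorted(example_defn.items()):
    names = tuple(x.strip() for x in raw_names.split(","))
    print(names, "->", formula)
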
Example #2
def load_deprecated_signatures(
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    deprecated_yaml_path: str,
    *,
    method: bool,
    pyi: bool,
) -> List[PythonSignatureNativeFunctionPair]:
    # The deprecated.yaml doesn't have complete type information; we need to
    # find and leverage the original ATen signature (to which it delegates the
    # call) to generate the full Python signature.
    # We join the deprecated and the original signatures using their type-only form.

    # native function -> type-only signature
    @with_native_function
    def signature_original(f: NativeFunction) -> str:
        # remove inplace suffix but keep outplace suffix
        opname = str(f.func.name.name.base)
        if f.func.is_out_fn():
            opname += "_out"
        if f.func.name.name.inplace and pyi:
            opname += "_"
        args = CppSignatureGroup.from_native_function(
            f, method=False).signature.arguments()
        # Simply ignore TensorOptionsArguments, as they do not exist in deprecated.yaml.
        types = ", ".join(
            argument_type_str(a.argument.type) for a in args
            if isinstance(a.argument, Argument))
        return f"{opname}({types})"

    # deprecated -> type-only native signature (according to the call order)
    def signature_deprecated(opname: str, params: List[str],
                             call_args: List[str]) -> str:
        # create a mapping of parameter name to parameter type
        types: Dict[str, str] = {}
        for param in params:
            if param == "*":
                continue
            type, name = param.split(" ")
            types[name] = type
        # if the name in the call is not in the parameter list, assume it's
        # a literal Scalar
        rearranged_types = ", ".join(
            types.get(arg, "Scalar") for arg in call_args)
        return f"{opname}({rearranged_types})"

    # group the original ATen signatures by type-only signature
    grouped: Dict[str,
                  List[PythonSignatureNativeFunctionPair]] = defaultdict(list)
    for pair in pairs:
        grouped[signature_original(pair.function)].append(pair)

    # find matching original signatures for each deprecated signature
    results: List[PythonSignatureNativeFunctionPair] = []

    with open(deprecated_yaml_path, "r") as f:
        deprecated_defs = yaml.load(f, Loader=YamlLoader)

    for deprecated in deprecated_defs:
        _, params = split_name_params(deprecated["name"])
        aten_name, call_args = split_name_params(deprecated["aten"])

        for pair in grouped[signature_deprecated(aten_name, params,
                                                 call_args)]:
            # It uses the types from the original ATen declaration, but the
            # ordering and parameter names from the deprecated overload. Any
            # default parameter values from the original ATen declaration are
            # ignored.
            # The deprecated signature might reorder input_args and input_kwargs,
            # but it never changes output_args or tensor_options_args (if any),
            # so here we only look at those two kinds of args.
            python_sig = pair.signature
            src_args: Dict[str, PythonArgument] = {
                a.name: PythonArgument(
                    name=a.name,
                    type=a.type,
                    default=None,
                    default_init=None,
                )
                for a in itertools.chain(python_sig.input_args,
                                         python_sig.input_kwargs)
            }

            args: List[str] = []
            input_args: List[PythonArgument] = []
            input_kwargs: List[PythonArgument] = []

            kwarg_only = False
            for param in params:
                if param == "*":
                    kwarg_only = True
                    continue
                _, param_name = param.split(" ")
                args.append(param_name)

                if param_name not in src_args:
                    # output argument
                    continue

                if not kwarg_only:
                    if not method or param_name != "self":
                        input_args.append(src_args[param_name])
                else:
                    input_kwargs.append(src_args[param_name])

            results.append(
                PythonSignatureNativeFunctionPair(
                    signature=PythonSignatureDeprecated(
                        name=python_sig.name,
                        input_args=tuple(input_args),
                        input_kwargs=tuple(input_kwargs),
                        output_args=python_sig.output_args,
                        tensor_options_args=python_sig.tensor_options_args,
                        method=python_sig.method,
                        deprecated_args_names=tuple(args),
                        deprecated_args_exprs=tuple(call_args),
                        returns=python_sig.returns,
                    ),
                    function=pair.function,
                ))

    return results
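
The join in this example hinges on both sides producing the same type-only signature string. Below is a self-contained sketch of that key construction; the helper name `type_only_key` and the sample entry are illustrative and not part of the code above (the real `signature_original` derives types from the C++ signature via `argument_type_str`).

from typing import Dict, List


def type_only_key(opname: str, params: List[str], call_args: List[str]) -> str:
    # params come from deprecated.yaml, e.g. ["Tensor self", "Scalar alpha", "Tensor other"]
    types: Dict[str, str] = {}
    for param in params:
        if param == "*":  # the kwarg-only marker carries no type information
            continue
        param_type, name = param.split(" ")
        types[name] = param_type
    # call_args give the order in which the deprecated overload forwards its
    # arguments to ATen; any name not in params is assumed to be a literal Scalar
    rearranged = ", ".join(types.get(arg, "Scalar") for arg in call_args)
    return f"{opname}({rearranged})"


# A hypothetical "add(Tensor self, Scalar alpha, Tensor other)" overload that
# delegates to "add(self, other, alpha)" yields "add(Tensor, Tensor, Scalar)",
# which is then looked up among the keys produced by signature_original.
print(type_only_key("add",
                    ["Tensor self", "Scalar alpha", "Tensor other"],
                    ["self", "other", "alpha"]))
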
Example #3
def load_deprecated_signatures(
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    deprecated_yaml_path: str,
    *,
    method: bool,
    pyi: bool,
) -> List[PythonSignatureNativeFunctionPair]:
    # The deprecated.yaml doesn't have complete type information; we need to
    # find and leverage the original ATen signature (to which it delegates the
    # call) to generate the full Python signature.
    # We join the deprecated and the original signatures by name and then
    # check that the schemas are compatible.

    # group the original ATen signatures by name
    grouped: Dict[str,
                  List[PythonSignatureNativeFunctionPair]] = defaultdict(list)
    for pair in pairs:
        grouped[pair.signature.name].append(pair)

    # find matching original signatures for each deprecated signature
    results: List[PythonSignatureNativeFunctionPair] = []

    with open(deprecated_yaml_path, "r") as f:
        deprecated_defs = yaml.load(f, Loader=YamlLoader)

    for deprecated in deprecated_defs:
        schema = FunctionSchema.parse(deprecated["name"])
        aten_name, call_args = split_name_params(deprecated["aten"])
        is_out = aten_name.endswith("_out")
        if is_out:
            aten_name = aten_name.replace("_out", "")

        # HACK: these are fixed constants used to pass to the ATen function.
        # The type must be known ahead of time.
        known_constants = {
            "1": Type.parse("Scalar"),
        }
        schema_args_by_name = {a.name: a for a in schema.arguments.flat_all}
        for name in call_args:
            assert (name in schema_args_by_name or name in known_constants
                    ), f"deprecation definiton: Unrecognized value {name}"

        # Map deprecated signature arguments to their aten signature and test
        # if the types and alias annotation match.
        def is_schema_compatible(aten_schema: FunctionSchema) -> bool:
            arguments: Iterable[Argument]
            if is_out:
                arguments = itertools.chain(aten_schema.arguments.out,
                                            aten_schema.arguments.flat_non_out)
            else:
                arguments = aten_schema.arguments.flat_all

            for i, arg in enumerate(arguments):
                if i < len(call_args):
                    arg_name = call_args[i]
                    if arg_name in known_constants:
                        schema_type = known_constants[arg_name]
                        schema_annotation = None
                    else:
                        schema_arg = schema_args_by_name[arg_name]
                        schema_type = schema_arg.type
                        schema_annotation = schema_arg.annotation

                    if schema_type != arg.type or schema_annotation != arg.annotation:
                        return False
                else:
                    if arg.default is None:
                        return False

            return len(schema.returns) == len(aten_schema.returns) and all(
                a == b for a, b in zip(schema.returns, aten_schema.returns))

        any_schema_found = False
        for pair in grouped[aten_name]:
            if not is_schema_compatible(pair.function.func):
                continue
            any_schema_found = True

            python_sig = signature_from_schema(
                schema,
                category_override=pair.function.category_override,
                method=method,
                pyi=pyi,
            )

            results.append(
                PythonSignatureNativeFunctionPair(
                    signature=PythonSignatureDeprecated(
                        name=python_sig.name,
                        input_args=python_sig.input_args,
                        input_kwargs=python_sig.input_kwargs,
                        output_args=python_sig.output_args,
                        tensor_options_args=python_sig.tensor_options_args,
                        method=python_sig.method,
                        deprecated_schema=schema,
                        deprecated_args_exprs=tuple(call_args),
                        returns=python_sig.returns,
                    ),
                    function=pair.function,
                ))
        assert (
            any_schema_found
        ), f"No native function with name {aten_name} matched signature:\n  {str(schema)}"

    return results
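
In this example the join is by name, and `is_schema_compatible` decides which overload actually matches: every argument the deprecated overload forwards must agree in type (and alias annotation) with the corresponding ATen parameter, and any trailing ATen parameter it does not supply must have a default. Below is a simplified, self-contained sketch of that positional rule, using plain tuples instead of torchgen's FunctionSchema; the helper `compatible` and its argument encoding are illustrative.

from typing import List, Optional, Tuple

# Each ATen parameter is encoded as (type, default); default is None if absent.
AtenArg = Tuple[str, Optional[str]]


def compatible(call_arg_types: List[str], aten_args: List[AtenArg]) -> bool:
    for i, (aten_type, default) in enumerate(aten_args):
        if i < len(call_arg_types):
            # a forwarded argument must have the same type as the ATen parameter
            if call_arg_types[i] != aten_type:
                return False
        elif default is None:
            # any ATen parameter the deprecated overload does not supply
            # must have a default value
            return False
    return True


# Forwarding (Tensor, Tensor) to "(Tensor, Tensor, Scalar alpha=1)" works,
# because the unsupplied Scalar has a default:
print(compatible(["Tensor", "Tensor"],
                 [("Tensor", None), ("Tensor", None), ("Scalar", "1")]))  # True
# Forwarding (Tensor,) to "(Tensor, Tensor)" does not, because the second
# Tensor has no default:
print(compatible(["Tensor"],
                 [("Tensor", None), ("Tensor", None)]))  # False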