예제 #1
0
def gen_dispatchkey_nativefunc_headers(
    fm: FileManager,
    class_name: str,
    cpp_namespace: str,
    backend_indices: Dict[DispatchKey, BackendIndex],
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    backend_dispatch_key: DispatchKey,
    autograd_dispatch_key: Optional[DispatchKey],
    backend_name: str = "",
) -> None:
    """Write ``<backend_dispatch_key>NativeFunctions.h`` from its template.

    The header declares every unique kernel declaration produced by
    ``dest.compute_native_function_declaration`` for the backend key, followed
    by those for the autograd key (skipped when ``autograd_dispatch_key`` is None).

    Args:
        fm: file manager that renders ``DispatchKeyNativeFunctions.h``.
        class_name: substituted for ``class_name`` in the template; must not be None.
        cpp_namespace: namespace string handed to ``NamespaceHelper`` to build
            the prologue/epilogue wrapping the declarations.
        backend_indices: per-dispatch-key indices; looked up for the keys used.
        grouped_native_functions: functions to emit declarations for.
        backend_dispatch_key: the backend's key; also names the output file.
        autograd_dispatch_key: optional autograd key whose declarations follow.
        backend_name: substituted for ``BackendName`` in the template.
    """
    assert class_name is not None

    def sorted_unique_declarations(key: Optional[DispatchKey]):
        # Backends may repeat kernel names, so collapse repeats through a set to
        # emit each declaration once, then sort so the output is deterministic.
        return sorted(
            set(
                concatMap(
                    lambda g: dest.compute_native_function_declaration(
                        g, backend_indices[key]
                    ),
                    grouped_native_functions,
                )
            )
        )

    backend_decls = sorted_unique_declarations(backend_dispatch_key)
    if autograd_dispatch_key is None:
        autograd_decls = []
    else:
        autograd_decls = sorted_unique_declarations(autograd_dispatch_key)

    namespace = NamespaceHelper(cpp_namespace)
    fm.write_with_template(
        f"{backend_dispatch_key}NativeFunctions.h",
        "DispatchKeyNativeFunctions.h",
        lambda: {
            "generated_comment": (
                "Autogenerated file by gen_backend_stubs.py. Do not edit directly!"
            ),
            "namespace_prologue": namespace.prologue,
            "class_name": class_name,
            "namespace_epilogue": namespace.epilogue,
            "dispatch_declarations": backend_decls + autograd_decls,
            "BackendName": backend_name,
            "DispatchKey": backend_dispatch_key,
        },
    )
예제 #2
0
def parse_full_codegen_ops(
    backend_yaml_path: str,
    grouped_native_functions: Sequence[Union[NativeFunction,
                                             NativeFunctionsGroup]],
) -> List[OperatorName]:
    """Return the operator names listed under ``full_codegen`` in a backend YAML file.

    Args:
        backend_yaml_path: path to the backend YAML file; its top level must be
            a mapping.
        grouped_native_functions: not consulted by this function; kept so that
            existing callers continue to work.

    Returns:
        One parsed ``OperatorName`` per ``full_codegen`` entry, in file order
        (an empty list when the key is absent).

    Raises:
        AssertionError: if the file's top level is not a mapping, or
            ``full_codegen`` is not a list.
    """
    with open(backend_yaml_path, "r") as f:
        yaml_values = yaml.load(f, Loader=YamlLoader)
    assert isinstance(yaml_values, dict)

    full_codegen = yaml_values.pop("full_codegen", [])
    assert isinstance(
        full_codegen,
        list), f'expected "full_codegen" to be a list, but got: {full_codegen}'
    return [OperatorName.parse(name) for name in full_codegen]
예제 #3
0
def parse_native_functions_keys(
    backend_yaml_path: str,
    grouped_native_functions: Sequence[Union[NativeFunction,
                                             NativeFunctionsGroup]],
) -> Tuple[List[OperatorName], List[Any], List[OperatorName]]:
    """Read the ``full_codegen``, ``non_native`` and ``ir_gen`` lists from a backend YAML file.

    Args:
        backend_yaml_path: path to the backend YAML file; its top level must be
            a mapping.
        grouped_native_functions: not consulted by this function; kept so that
            existing callers continue to work.

    Returns:
        ``(full_codegen_opnames, non_native, ir_gen_opnames)``. The
        ``full_codegen`` and ``ir_gen`` entries are parsed into
        ``OperatorName``s, and the ``non_native`` entries are returned as-is.
        A missing key yields an empty list.

    Raises:
        AssertionError: if the file's top level is not a mapping, or any of the
            three values is not a list.
    """
    with open(backend_yaml_path, "r") as f:
        yaml_values = yaml.load(f, Loader=YamlLoader)
    assert isinstance(yaml_values, dict)

    full_codegen = yaml_values.pop("full_codegen", [])
    non_native = yaml_values.pop("non_native", [])
    ir_gen = yaml_values.pop("ir_gen", [])
    assert isinstance(
        full_codegen,
        list), f'expected "full_codegen" to be a list, but got: {full_codegen}'
    assert isinstance(
        non_native,
        list), f'expected "non_native" to be a list, but got: {non_native}'
    assert isinstance(
        ir_gen, list), f'expected "ir_gen" to be a list, but got: {ir_gen}'
    full_codegen_opnames = [OperatorName.parse(name) for name in full_codegen]
    ir_gen_opnames = [OperatorName.parse(name) for name in ir_gen]
    return full_codegen_opnames, non_native, ir_gen_opnames
예제 #4
0
def load_derivatives(derivatives_yaml_path: str, native_yaml_path: str,
                     tags_yaml_path: str) -> Sequence[DifferentiabilityInfo]:
    """Load derivative definitions and resolve them against the native functions.

    Results are memoized in ``_GLOBAL_LOAD_DERIVATIVE_CACHE`` keyed on all three
    input paths. Repeated calls with the same paths return the cached object.

    Args:
        derivatives_yaml_path: YAML file whose loaded value is iterated as a
            sequence of derivative definitions.
        native_yaml_path: native functions YAML, parsed via ``parse_native_yaml``.
        tags_yaml_path: tags YAML, also passed to ``parse_native_yaml``.

    Returns:
        One ``DifferentiabilityInfo`` per derivative definition (in file order),
        followed by the infos returned by ``add_view_copy_derivatives`` for the
        view groups.
    """
    # Do some caching as this is a deterministic function of its inputs.
    global _GLOBAL_LOAD_DERIVATIVE_CACHE
    # tags_yaml_path must be part of the key: it feeds parse_native_yaml, so two
    # calls that differ only in the tags file must not share a cache entry.
    # NOTE(review): if the cache has a module-level annotation typed with a
    # 2-tuple key, widen it to a 3-tuple.
    key = (derivatives_yaml_path, native_yaml_path, tags_yaml_path)
    if key not in _GLOBAL_LOAD_DERIVATIVE_CACHE:

        with open(derivatives_yaml_path, "r") as f:
            definitions = yaml.load(f, Loader=YamlLoader)

        funcs = parse_native_yaml(native_yaml_path,
                                  tags_yaml_path).native_functions
        # From the parsed native functions, separate out the (generated) view_copy functions,
        # so we can generate derivatives for them separately.
        native_functions_with_view_groups = get_grouped_by_view_native_functions(
            funcs)
        native_functions_without_view_copies = concatMap(
            # We need to pull out the view_inplace ops too, since they might have their own derivative entries.
            lambda g: [g] if isinstance(g, NativeFunction) else list(
                g.functions(include_copy=False)),
            native_functions_with_view_groups,
        )
        view_groups = [
            g for g in native_functions_with_view_groups
            if isinstance(g, NativeFunctionsViewGroup)
        ]

        # What's the difference between function schema v.s. signature?
        # function schema is the complete declaration including mutability annotation / default value and etc.
        # signature is the canonical schema for a group of functions (in-place/out/functional variants)
        # that are semantically related.
        functions_by_signature: Dict[FunctionSchema,
                                     List[NativeFunction]] = defaultdict(list)
        functions_by_schema: Dict[str, NativeFunction] = dict()
        for function in native_functions_without_view_copies:
            functions_by_signature[function.func.signature()].append(function)
            # Each full schema string must be unique across all native functions.
            assert str(function.func) not in functions_by_schema
            functions_by_schema[str(function.func)] = function

        # Keep track of how many of which ops we've seen so we can
        # disambiguate them with a numeric suffix.
        op_counter = Counter[str]()

        infos = [
            create_differentiability_info(defn, functions_by_signature,
                                          functions_by_schema, op_counter)
            for defn in definitions
        ]
        infos += add_view_copy_derivatives(infos, view_groups)

        _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = infos

    return _GLOBAL_LOAD_DERIVATIVE_CACHE[key]
def generate_out_args_from_schema(
    func: FunctionSchema,
) -> Tuple[List[Return], List[Argument]]:
    """Derive out= arguments, and the matching returns, from a mutable schema.

    Each tensor-like return of ``func`` becomes a new mutable out argument
    annotated with a fresh alias letter (``a!``, ``b!``, ...). Letters already
    used by ``func``'s arguments are skipped. The argument is named ``out`` when
    ``func`` has exactly one return, and ``out{i}`` (``i`` = return position)
    otherwise.

    Returns:
        ``(new_returns, new_out_args)``. If every return is a plain ``Tensor``,
        ``new_returns`` mirrors the out arguments (same type and annotation).
        Otherwise the tensor-like returns are dropped and only the
        non-tensor-like returns remain.
    """
    # More of a sanity check - our existing restrictions on schemas should enforce that
    # mutable schema kinds never return their mutable arguments.
    assert not any(
        r.annotation is not None and r.annotation.is_write for r in func.returns
    )

    tensorlike_rets = [r for r in func.returns if r.type.is_tensor_like()]
    assert len(tensorlike_rets) > 0

    # Materialize into a set. concatMap's result is wrapped in list() at other
    # call sites, i.e. it may be a one-shot iterator. Repeated `in` tests against
    # an iterator consume it, which would let an already-used alias letter slip
    # through as "valid".
    used_annotations = set(
        concatMap(
            lambda a: [] if a.annotation is None else a.annotation.alias_set,
            func.arguments.flat_all,
        )
    )
    valid_annotations = [
        x for x in "abcdefghijklmnopqrstuvwxyz" if x not in used_annotations
    ]

    all_rets_are_tensors = all(r.type == BaseType(BaseTy.Tensor) for r in func.returns)

    new_out_args: List[Argument] = []
    # The end result of new_returns is that:
    # - If every return is a plain tensor, then the new returns == the old returns, but with the out= alias annotations added.
    # - Otherwise, none of the out arguments show up in the returns (and we're only left with non-tensor-like returns, if any).
    new_returns: List[Return] = []
    for (i, r) in enumerate(func.returns):
        if r.type.is_tensor_like():
            new_out = Argument(
                name="out" if len(func.returns) == 1 else f"out{i}",
                type=r.type,
                default=None,
                # Indexed by return position, so a non-tensor return skips a letter.
                annotation=Annotation.parse(f"{valid_annotations[i]}!"),
            )
            new_out_args.append(new_out)
            if all_rets_are_tensors:
                # The convention for out= schemas is that they only return their out arguments
                # if the return is a plain Tensor (or if it's a tuple of plain Tensors)
                new_ret = Return(
                    name=None, type=new_out.type, annotation=new_out.annotation
                )
                new_returns.append(new_ret)
        else:
            new_returns.append(r)
    return new_returns, new_out_args
예제 #6
0
def jit_arguments(func: FunctionSchema) -> List[Argument]:
    """Flatten ``func``'s arguments into a plain list of ``Argument``s.

    Positional, keyword-only and out arguments are visited in that order.
    A ``SelfArgument`` is unwrapped to its underlying argument, and a
    ``TensorOptionsArguments`` expands to its dtype, layout, device and
    pin_memory arguments.
    """
    def flatten_one(
        item: Union[Argument, TensorOptionsArguments, SelfArgument]
    ) -> List[Argument]:
        if isinstance(item, Argument):
            return [item]
        if isinstance(item, SelfArgument):
            return [item.argument]
        if isinstance(item, TensorOptionsArguments):
            return [item.dtype, item.layout, item.device, item.pin_memory]
        # Exhaustiveness guard: any other kind is a programming error.
        assert_never(item)

    every_argument = itertools.chain(
        func.arguments.positional,
        func.arguments.kwarg_only,
        func.arguments.out,
    )
    flattened: List[Argument] = []
    for item in every_argument:
        flattened.extend(flatten_one(item))
    return flattened
예제 #7
0
def parse_backend_yaml(
    backend_yaml_path: str,
    grouped_native_functions: Sequence[Union[NativeFunction,
                                             NativeFunctionsGroup]],
    backend_indices: Dict[DispatchKey, BackendIndex],
) -> ParsedExternalYaml:
    """Parse an external backend's YAML file and register its BackendIndex entries.

    A ``BackendIndex`` is built for the backend's dispatch key from the ops under
    ``supported`` plus ``full_codegen``. If any ops are listed under
    ``autograd``, a second index is built for ``Autograd<backend>``. Both are
    inserted into ``backend_indices`` in place.

    Args:
        backend_yaml_path: path to the backend YAML file; its top level must be
            a mapping.
        grouped_native_functions: all native functions (possibly grouped); used
            to validate operator names and look up their schemas.
        backend_indices: existing indices; mutated with the new entries. Must
            not already contain the keys being registered.

    Returns:
        A ``ParsedExternalYaml`` of (backend key, autograd key, class name, C++
        namespace, backend_indices). A key is None when no ops were listed
        for it.

    Raises:
        AssertionError: on a missing "backend"/"cpp_namespace", unexpected
            keys, wrongly-typed values, unknown operator names, or an op whose
            variants are split between "supported" and "autograd".
    """

    native_functions_map: Dict[OperatorName, NativeFunction] = {
        f.func.name: f
        for f in concatMap(
            lambda f: [f]
            if isinstance(f, NativeFunction) else list(f.functions()),
            grouped_native_functions,
        )
    }

    with open(backend_yaml_path, "r") as f:
        yaml_values = yaml.load(f, Loader=YamlLoader)
    assert isinstance(yaml_values, dict)

    # Only used in the "unexpected keys" error message below. Keep it in sync
    # with the keys popped from yaml_values in this function.
    # NOTE(review): "extra_headers" is advertised here but never popped, so a
    # file that sets it still fails the unexpected-keys check. Confirm the
    # intended handling.
    valid_keys = [
        "backend",
        "class_name",
        "cpp_namespace",
        "extra_headers",
        "supported",
        "autograd",
        "full_codegen",
        "non_native",
        "ir_gen",
        "use_out_as_primary",
        "device_guard",
    ]

    backend = yaml_values.pop("backend", None)
    assert backend is not None, 'You must provide a value for "backend"'

    class_name = yaml_values.pop("class_name", None)

    cpp_namespace = yaml_values.pop("cpp_namespace", None)
    assert cpp_namespace is not None, 'You must provide a value for "cpp_namespace"'

    # Mostly just defaulting to false to stick with LazyTensor convention.
    use_out_as_primary = yaml_values.pop("use_out_as_primary", False)
    assert isinstance(
        use_out_as_primary, bool
    ), f"You must provide either True or False for use_out_as_primary. Provided: {use_out_as_primary}"

    use_device_guard = yaml_values.pop("device_guard", False)
    assert isinstance(
        use_device_guard, bool
    ), f"You must provide either True or False for device_guard. Provided: {use_device_guard}"

    supported = yaml_values.pop("supported", [])
    if supported is None:
        supported = []  # Allow an empty list of supported ops
    assert isinstance(
        supported, list
    ), f'expected "supported" to be a list, but got: {supported} (of type {type(supported)})'

    supported_autograd = yaml_values.pop("autograd", [])
    if supported_autograd is None:
        supported_autograd = []  # Allow an empty list, mirroring "supported"
    assert isinstance(
        supported_autograd, list
    ), f'expected "autograd" to be a list, but got: {supported_autograd}'

    # full_codegen ops are registered to the backend key here exactly like
    # "supported" ops. The list is also re-parsed in gen_lazy_tensor.py.
    # Check the type first: extending with a bare string would add it one
    # character at a time.
    full_codegen = yaml_values.pop("full_codegen", [])
    assert isinstance(
        full_codegen, list
    ), f'expected "full_codegen" to be a list, but got: {full_codegen}'
    supported.extend(full_codegen)

    # non_native and ir_gen are not used here; they are re-parsed in gen_lazy_tensor.py.
    # Popping them keeps the unexpected-keys check below from rejecting them.
    yaml_values.pop("non_native", {})
    yaml_values.pop("ir_gen", {})

    assert (
        len(yaml_values.keys()) == 0
    ), f'{backend_yaml_path} contains unexpected keys: {", ".join(yaml_values.keys())}. \
Only the following keys are supported: {", ".join(valid_keys)}'

    def create_backend_index(
        backend_ops: List[str],
        dispatch_key: DispatchKey,
        *,
        use_out_as_primary: bool,
        use_device_guard: bool,
    ) -> BackendIndex:
        # Map each listed op to an unstructured kernel named after its dispatcher name.
        metadata: Dict[OperatorName, BackendMetadata] = {}
        for op in backend_ops:
            op_name = OperatorName.parse(op)
            assert (op_name in native_functions_map
                    ), f"Found an invalid operator name: {op_name}"
            # See Note [External Backends Follow Dispatcher API]
            kernel_name = dispatcher.name(native_functions_map[op_name].func)
            # TODO: allow structured external backends later.
            m = BackendMetadata(kernel=kernel_name,
                                structured=False,
                                cpp_namespace=cpp_namespace)
            metadata[op_name] = m
        return BackendIndex(
            dispatch_key=dispatch_key,
            use_out_as_primary=use_out_as_primary,
            external=True,
            symint=True,  # TODO: make this configurable
            device_guard=use_device_guard,
            index=metadata,
        )

    backend_key: Optional[DispatchKey] = None
    if len(supported) > 0:
        with context(
                lambda:
                f'The provided value for "backend" must be a valid DispatchKey, but got {backend}.'
        ):
            backend_key = DispatchKey.parse(backend)

        backend_idx = create_backend_index(
            supported,
            backend_key,
            use_out_as_primary=use_out_as_primary,
            use_device_guard=use_device_guard,
        )
        assert backend_key not in backend_indices
        backend_indices[backend_key] = backend_idx

    autograd_key: Optional[DispatchKey] = None
    if len(supported_autograd) > 0:
        with context(
                lambda:
                f'The "autograd" key was specified, which indicates that you would like to override \
the behavior of autograd for some operators on your backend. However "Autograd{backend}" is not a valid DispatchKey.'
        ):
            autograd_key = DispatchKey.parse(f"Autograd{backend}")

        autograd_idx = create_backend_index(
            supported_autograd,
            autograd_key,
            use_out_as_primary=use_out_as_primary,
            use_device_guard=use_device_guard,
        )
        assert autograd_key not in backend_indices
        backend_indices[autograd_key] = autograd_idx

    def registered_kernels(
        key: Optional[DispatchKey], fns: List[NativeFunction]
    ) -> List[BackendMetadata]:
        # Kernels registered under `key` for any of `fns`; empty when `key` is unset.
        if key is None:
            return []
        index = backend_indices[key]
        return [m for m in (index.get_kernel(fn) for fn in fns) if m is not None]

    # An op's variants may be registered under "supported" or "autograd", but not both.
    for g in grouped_native_functions:
        fns = [g] if isinstance(g, NativeFunction) else list(g.functions())
        forward_kernels = registered_kernels(backend_key, fns)
        backward_kernels = registered_kernels(autograd_key, fns)
        assert (
            len(forward_kernels) == 0 or len(backward_kernels) == 0
        ), f'Currently, all variants of an op must either be registered to a backend key, or to a backend\'s \
autograd key. They cannot be mix and matched. If this is something you need, feel free to create an issue! \
{forward_kernels[0].kernel} is listed under "supported", but {backward_kernels[0].kernel} is listed under "autograd".'

    return ParsedExternalYaml(backend_key, autograd_key, class_name,
                              cpp_namespace, backend_indices)
예제 #8
0
def gen_dispatcher_registrations(
    fm: FileManager,
    output_dir: str,
    class_name: str,
    backend_indices: Dict[DispatchKey, BackendIndex],
    grouped_native_functions: Sequence[Union[NativeFunction,
                                             NativeFunctionsGroup]],
    backend_dispatch_key: DispatchKey,
    dispatch_key: DispatchKey,
    selector: "SelectiveBuilder",
    # build_in_tree is true for lazy TS backend and affects include paths, not used for external backends
    build_in_tree: bool = False,
    per_operator_headers: bool = False,
    backend_name: str = "",
    eager_registration: bool = True,
) -> None:
    """Render ``Register<dispatch_key>.cpp`` with the backend's dispatcher registrations.

    When ``eager_registration`` is true, the registrations are placed in a
    ``TORCH_LIBRARY_IMPL`` block. Otherwise they are placed in a
    ``Register<backend_name><dispatch_key>NativeFunctions()`` function that
    performs them when called.

    Args:
        fm: file manager used to render both templates.
        output_dir: directory prefix of the included
            ``<backend_dispatch_key>NativeFunctions.h`` header.
        class_name: passed as ``class_method_name`` to ``dest.RegisterDispatchKey``;
            must not be None.
        backend_indices: per-dispatch-key indices; ``dispatch_key`` must be present.
        grouped_native_functions: functions to emit registrations/definitions for.
        backend_dispatch_key: key used to name the included header.
        dispatch_key: key being registered; names the output file.
        selector: selective-build filter forwarded to ``dest.RegisterDispatchKey``.
        build_in_tree: use ``<...>`` includes instead of ``"..."`` includes.
        per_operator_headers: when true, omit ``#include <ATen/Functions.h>``.
        backend_name: prefix for the deferred registration function's name.
        eager_registration: choose static vs. deferred registration (see above).
    """
    # In-tree builds include with angle brackets; external ones with quotes.
    include_line = "#include <{}>" if build_in_tree else '#include "{}"'
    external_backend_headers_str = "\n".join(
        include_line.format(header)
        for header in (f"{output_dir}/{backend_dispatch_key}NativeFunctions.h",)
    )

    assert class_name is not None
    backend_index = backend_indices[dispatch_key]

    def kernel_generator(target: Target):
        # Both the registration pass and the definition pass share this configuration.
        return dest.RegisterDispatchKey(
            backend_index,
            target,
            selector,
            rocm=False,
            class_method_name=f"{class_name}",
            skip_dispatcher_op_registration=False,
        )

    registration_lines = list(
        concatMap(kernel_generator(Target.REGISTRATION), grouped_native_functions)
    )
    at_namespace = NamespaceHelper(namespace_str="at")

    if eager_registration:
        deferred_dispatch_registrations = ""
        static_init_dispatch_registrations = CodeTemplate("""\
TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
    $dispatch_registrations_body
};""").substitute(
            dispatch_key=dispatch_key,
            dispatch_registrations_body=registration_lines,
        )
    else:
        static_init_dispatch_registrations = ""
        deferred_dispatch_registrations = CodeTemplate("""\
TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() {
    static auto m = MAKE_TORCH_LIBRARY_IMPL(aten, $dispatch_key);
    $dispatch_registrations_body
}""").substitute(
            backend_name=backend_name,
            dispatch_key=dispatch_key,
            dispatch_registrations_body=registration_lines,
        )

    def definitions_env():
        # Evaluated lazily, only when the definitions template is rendered.
        return {
            "ns_prologue": at_namespace.prologue,
            "ns_epilogue": at_namespace.epilogue,
            "static_init_dispatch_registrations": static_init_dispatch_registrations,
            "deferred_dispatch_registrations": deferred_dispatch_registrations,
            "dispatch_helpers": dest.gen_registration_helpers(backend_index),
            "dispatch_namespace": dispatch_key.lower(),
            "dispatch_namespaced_definitions": "",
            "dispatch_anonymous_definitions": list(
                concatMap(
                    kernel_generator(Target.ANONYMOUS_DEFINITION),
                    grouped_native_functions,
                )
            ),
        }

    fm.write_with_template(
        f"Register{dispatch_key}.cpp",
        "RegisterDispatchKey.cpp",
        lambda: {
            "extra_cuda_headers": "",
            "external_backend_headers": external_backend_headers_str,
            "ops_headers": "" if per_operator_headers else "#include <ATen/Functions.h>",
            "DispatchKey": dispatch_key,
            "dispatch_namespace": dispatch_key.lower(),
            "dispatch_headers": dest.gen_registration_headers(
                backend_index,
                per_operator_headers=per_operator_headers,
                rocm=False,
            ),
            "dispatch_definitions": fm.substitute_with_template(
                "RegisterDispatchDefinitions.ini",
                definitions_env,
            ).split("\n"),
        },
    )
예제 #9
0
def load_derivatives(
    derivatives_yaml_path: str, native_yaml_path: str, tags_yaml_path: str
) -> Tuple[Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]], Set[str]]:
    """Load derivative definitions, grouped per schema and per dispatch key.

    Results are memoized in ``_GLOBAL_LOAD_DERIVATIVE_CACHE`` keyed on all three
    input paths. Repeated calls with the same paths return the cached tuple.

    Entries without a ``dispatch`` section are converted into the newer shape:
    their body is placed under the ``"Default"`` dispatch key, and a top-level
    ``output_differentiability`` is carried over when set.

    Args:
        derivatives_yaml_path: YAML file whose loaded value is iterated as a
            sequence of derivative definition dicts.
        native_yaml_path: native functions YAML, parsed via ``parse_native_yaml``.
        tags_yaml_path: tags YAML, also passed to ``parse_native_yaml``.

    Returns:
        ``(infos, used_dispatch_keys)``. ``infos`` maps each name returned by
        ``create_differentiability_info`` to its per-dispatch-key infos.
        ``used_dispatch_keys`` is the set that ``create_differentiability_info``
        fills while processing the definitions.
    """
    # Do some caching as this is a deterministic function of its inputs.
    global _GLOBAL_LOAD_DERIVATIVE_CACHE
    # tags_yaml_path must be part of the key: it feeds parse_native_yaml, so two
    # calls that differ only in the tags file must not share a cache entry.
    # NOTE(review): if the cache has a module-level annotation typed with a
    # 2-tuple key, widen it to a 3-tuple.
    key = (derivatives_yaml_path, native_yaml_path, tags_yaml_path)
    if key not in _GLOBAL_LOAD_DERIVATIVE_CACHE:

        with open(derivatives_yaml_path, "r") as f:
            definitions = yaml.load(f, Loader=YamlLoader)

        funcs = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions
        # From the parsed native functions, separate out the (generated) view_copy functions,
        # so we can generate derivatives for them separately.
        native_functions_with_view_groups = get_grouped_by_view_native_functions(funcs)
        native_functions_without_view_copies = concatMap(
            # We need to pull out the view_inplace ops too, since they might have their own derivative entries.
            lambda g: [g]
            if isinstance(g, NativeFunction)
            else list(g.functions(include_copy=False)),
            native_functions_with_view_groups,
        )
        view_groups = [
            g
            for g in native_functions_with_view_groups
            if isinstance(g, NativeFunctionsViewGroup)
        ]

        # What's the difference between function schema v.s. signature?
        # function schema is the complete declaration including mutability annotation / default value and etc.
        # signature is the canonical schema for a group of functions (in-place/out/functional variants)
        # that are semantically related.
        functions_by_signature: Dict[
            FunctionSchema, List[NativeFunction]
        ] = defaultdict(list)
        functions_by_schema: Dict[str, NativeFunction] = dict()
        for function in native_functions_without_view_copies:
            functions_by_signature[function.func.signature()].append(function)
            # Each full schema string must be unique across all native functions.
            assert str(function.func) not in functions_by_schema
            functions_by_schema[str(function.func)] = function

        # Keep track of how many of which ops we've seen so we can
        # disambiguate them with a numeric suffix.
        op_counter = Counter[str]()

        # infos is a dict that maps FunctionSchema -> a dict of per dispatch key DifferentiabilityInfos
        # this is useful because in tools/autograd/gen_autograd.py:match_differentiability_info
        # we ultimately need to categorize the DifferentiabilityInfos by FunctionSchema
        infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]] = dict()
        used_dispatch_keys: Set[str] = set()
        for defn_dict in definitions:
            # Ensure that the old derivatives.yaml schema with no dispatch key can be loaded.
            if "dispatch" not in defn_dict:
                specification = defn_dict.pop("name")
                output_differentiability = defn_dict.pop(
                    "output_differentiability", None
                )
                defn_dict = {"name": specification, "dispatch": {"Default": defn_dict}}
                if output_differentiability:
                    defn_dict["output_differentiability"] = output_differentiability
            name, per_dispatch_diffinfos = create_differentiability_info(
                defn_dict,
                functions_by_signature,
                functions_by_schema,
                op_counter,
                used_dispatch_keys,
            )
            infos[name] = per_dispatch_diffinfos

        # Return value unused; presumably extends `infos` in place — confirm.
        add_view_copy_derivatives(infos, view_groups)

        # cache both loaded infos as well as a set of all the dispatch_keys/aliases
        # that appear in derivatives.yaml. used_dispatch_keys is useful for generating
        # VariableType.cpp where we need a TORCH_LIBRARY_IMPL for every autograd dispatch key used
        _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = infos, used_dispatch_keys

    return _GLOBAL_LOAD_DERIVATIVE_CACHE[key]
예제 #10
0
def mutable_to_out_signature(func: FunctionSchema) -> FunctionSchema:
    """Generate the out= schema corresponding to a mutable-kind schema.

    Each tensor-like return becomes a mutable out argument named ``out{i}``
    (``i`` = return position). Each such argument gets a fresh alias letter that
    is not used by ``func``'s arguments. The overload name gains an ``out``
    suffix (``"out"`` when ``func`` had no overload name), and the inplace
    marker is removed from the base name.

    Raises:
        AssertionError: if ``func`` is not of mutable kind, returns a mutable
            (write-annotated) value, or has no tensor-like return.
    """
    # Generating an out= schema from a mutable schema.
    assert func.kind() == SchemaKind.mutable
    # The new out= schema has:
    # - Any non-aliased tensor-like returns are converted to mutable, aliased out= arguments
    #   (if the argument is a tensor then we also return it for method chaining,
    #   otherwise we return nothing)
    # - an "out" overload name
    #
    # Note that:
    # (1) This also means that we can *only* generate an out= variant from a mutable schema
    #     if the mutable schema has at least one tensor-like non-aliasing return.
    # (2) The generated out= variant still has mutable positional arguments,
    #     but if necessary we could probably add another out= variant that also
    #     functionalizes the mutable arguments (a functional_out variant)

    # More of a sanity check - our existing restrictions on schemas should enforce that
    # mutable schema kinds never return their mutable arguments.
    assert not any(r.annotation is not None and r.annotation.is_write
                   for r in func.returns)

    tensorlike_rets = [r for r in func.returns if r.type.is_tensor_like()]
    assert len(tensorlike_rets) > 0

    # Materialize into a set. concatMap's result is wrapped in list() at other
    # call sites, i.e. it may be a one-shot iterator. Repeated `in` tests against
    # an iterator consume it, which would let an already-used alias letter slip
    # through as "valid".
    used_annotations = set(
        concatMap(
            lambda a: [] if a.annotation is None else a.annotation.alias_set,
            func.arguments.flat_all,
        ))
    valid_annotations = [
        x for x in "abcdefghijklmnopqrstuvwxyz" if x not in used_annotations
    ]

    all_rets_are_tensors = all(r.type == BaseType(BaseTy.Tensor)
                               for r in func.returns)

    new_out_args: List[Argument] = []
    # The end result of new_returns is that:
    # - If every return is a plain tensor, then the new returns == the old returns, but with the out= alias annotations added.
    # - Otherwise, none of the out arguments show up in the returns (and we're only left with non-tensor-like returns, if any).
    new_returns: List[Return] = []
    for (i, r) in enumerate(func.returns):
        if r.type.is_tensor_like():
            new_out = Argument(
                name=f"out{i}",
                type=r.type,
                default=None,
                # Indexed by return position, so a non-tensor return skips a letter.
                annotation=Annotation.parse(f"{valid_annotations[i]}!"),
            )
            new_out_args.append(new_out)
            if all_rets_are_tensors:
                # The convention for out= schemas is that they only return their out arguments
                # if the return is a plain Tensor (or if it's a tuple of plain Tensors)
                new_ret = Return(name=None,
                                 type=new_out.type,
                                 annotation=new_out.annotation)
                new_returns.append(new_ret)
        else:
            new_returns.append(r)

    return FunctionSchema(
        name=func.name.remove_inplace().with_overload(
            "out" if not func.name.overload_name else
            f"{func.name.overload_name}_out"),
        arguments=func.arguments.with_out_args(new_out_args),
        returns=tuple(new_returns),
    )