Example #1
def create_derivative(f: NativeFunction, formula: str,
                      var_names: Tuple[str, ...]) -> Derivative:
    arguments = cpp_arguments(f)
    argument_names = tuple(a.name for a in arguments)
    argument_types = tuple(a.type for a in arguments)

    return_names = tuple(n if n != 'self' else 'result'
                         for n in cpp.return_names(f))
    return_types = tuple(cpp.return_type(r) for r in f.func.returns)

    formula, saved_inputs = saved_variables(formula, argument_names,
                                            argument_types, var_names)
    formula, saved_outputs = saved_variables(formula, return_names,
                                             return_types, var_names)

    # Check that the referenced derivatives in the formula are in bounds
    for i in used_gradient_indices(formula):
        if i >= len(f.func.returns):
            raise RuntimeError(
                f'Out of bounds grads access: derivative formula for {cpp.name(f.func)} '
                f'used grads[{i}], but the forward only returns {len(f.func.returns)} outputs.'
            )

    return Derivative(
        formula=formula,
        var_names=var_names,
        saved_inputs=saved_inputs,
        saved_outputs=saved_outputs,
    )
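Example #1 validates `grads[i]` references via `used_gradient_indices`, which is defined elsewhere in the module. A minimal sketch of what such a helper could look like, assuming it simply scans the formula text (an illustration, not the actual torchgen implementation):

import re
from typing import List

def used_gradient_indices(formula: str) -> List[int]:
    # Collect every i for which the formula references grads[i],
    # e.g. 'grads[0] * other' yields [0].
    return [int(m) for m in re.findall(r'grads\[(\d+)\]', formula)]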
Example #2
    def gen_differentiable_outputs(
            f: NativeFunction) -> List[DifferentiableOutput]:
        outputs: List[DifferentiableOutput] = [
            DifferentiableOutput(name=name,
                                 type=ret.type,
                                 cpp_type=cpp.return_type(ret))
            for name, ret in zip(cpp.return_names(f), f.func.returns)
        ]

        # `info` (the DifferentiabilityInfo for this op) comes from the
        # enclosing scope.
        output_differentiability = info.output_differentiability if info else None
        if output_differentiability is not None:
            differentiable_outputs: List[DifferentiableOutput] = []
            if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
                raise RuntimeError(
                    "output_differentiability=False for inplace operation (version_counter won't get updated)"
                )
            for differentiable, output in zip(output_differentiability,
                                              outputs):
                if differentiable:
                    differentiable_outputs.append(output)
            return differentiable_outputs

        candidate_differentiable_outputs = list(
            filter(lambda r: is_differentiable(r.name, r.type), outputs))

        if uses_single_grad(info):
            return candidate_differentiable_outputs[:1]
        else:
            return candidate_differentiable_outputs
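Both the two-argument `is_differentiable` used here and `uses_single_grad` are defined outside this snippet. A plausible sketch of the former, assuming differentiability hinges on the output type being tensor-like (an assumption; the real helper may consult more state):

from torchgen.model import Type  # import path is an assumption; torchgen's Type exposes is_tensor_like()

def is_differentiable(name: str, type: Type) -> bool:
    # Only Tensor-like outputs can carry gradients.
    return type.is_tensor_like()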
Example #3
def create_derivative(f: NativeFunction, formula: str, var_names: Tuple[str, ...]) -> Derivative:
    original_formula = formula
    arguments: List[NamedCType] = [a.nctype.remove_const_ref() for a in cpp_arguments(f)]

    return_names = tuple(n if n != 'self' else 'result' for n in cpp.return_names(f))
    return_types = tuple(cpp.return_type(r).remove_const_ref() for r in f.func.returns)

    named_returns = [NamedCType(name, type) for name, type in zip(return_names, return_types)]

    formula, saved_inputs = saved_variables(formula, arguments, var_names)
    formula, saved_outputs = saved_variables(formula, named_returns, var_names)

    # Check that the referenced derivatives in the formula are in bounds
    for i in used_gradient_indices(formula):
        if i >= len(f.func.returns):
            raise RuntimeError(
                f'Out of bounds grads access: derivative formula for {cpp.name(f.func)} '
                f'used grads[{i}], but the forward only returns {len(f.func.returns)} outputs.'
            )

    return Derivative(
        formula=formula,
        original_formula=original_formula,
        var_names=var_names,
        saved_inputs=saved_inputs,
        saved_outputs=saved_outputs,
    )
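Example #3 folds the separate name/type tuples of Example #1 into `NamedCType` values and strips `const &` wrappers up front. A simplified sketch of the shape of `NamedCType` (assuming the underlying CType also implements remove_const_ref; field names follow the usage above):

from dataclasses import dataclass

@dataclass(frozen=True)
class NamedCType:
    name: str
    type: 'CType'  # the C++ type model; assumed to provide remove_const_ref()

    def remove_const_ref(self) -> 'NamedCType':
        # Dropping `const T&` lets saved-variable handling compare and
        # store plain value types.
        return NamedCType(self.name, self.type.remove_const_ref())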
Example #4
def gen_differentiable_outputs(
        fn: NativeFunctionWithDifferentiabilityInfo
) -> List[DifferentiableOutput]:
    f = fn.func
    info = fn.info
    outputs: List[DifferentiableOutput] = [
        DifferentiableOutput(name=name,
                             type=ret.type,
                             cpp_type=cpp.return_type(ret).cpp_type())
        for name, ret in zip(cpp.return_names(f), f.func.returns)
    ]
    output_differentiability = info.output_differentiability if info else None
    if output_differentiability is not None:
        if len(output_differentiability) != len(outputs):
            raise RuntimeError(
                f"The length of output_differentiability ({len(output_differentiability)}), "
                f"does not match the number of outputs ({len(outputs)}).")
        differentiable_outputs: List[DifferentiableOutput] = []
        if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
            raise RuntimeError(
                "output_differentiability=False for inplace operation (version_counter won't get updated)"
            )
        for differentiable, output in zip(output_differentiability, outputs):
            if differentiable:
                differentiable_outputs.append(output)
        return differentiable_outputs
    candidate_differentiable_outputs = list(
        filter(lambda r: is_differentiable(r.name, r.type, info), outputs))
    if uses_single_grad(info):
        return candidate_differentiable_outputs[:1]
    else:
        return candidate_differentiable_outputs
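`uses_single_grad(info)` decides whether only the first output's gradient is consumed. A sketch of the idea, under the assumption that `info.derivatives` holds the parsed Derivative objects and that referencing the bare `grad` identifier (rather than indexed `grads[i]`) is the signal:

import re

def uses_single_grad(info) -> bool:
    # An op whose formulas use the single `grad` identifier only ever
    # reads the gradient of its first output.
    if info is None:
        return False
    return any(re.search(r'\bgrad\b', d.formula) for d in info.derivatives)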
Example #5
def compute_returns_yaml(
        f: NativeFunction) -> Tuple[List[Dict[str, str]], Dict[str, str]]:
    # Note [name and field_name]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~
    # To understand name_to_field_name, we must first talk about this
    # schema:
    #
    #   lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR)
    #
    # There is something very odd about this schema: it is an out
    # variant of the function (that is to say, it will convert into
    # at::lstsq_out() in the C++ API), but the names of the output
    # return arguments don't match the keyword argument names of
    # the inputs.  It TURNS OUT that in this situation, the historical
    # Declarations.yaml we want to output is this (abbreviated to
    # only show relevant fields):
    #
    #   arguments:
    #     ...
    #   - field_name: solution
    #     name: X
    #   - field_name: QR
    #     name: qr
    #     ...
    #
    #   returns:
    #   - field_name: solution
    #     name: X
    #   - field_name: QR
    #     name: qr
    #
    # The names of the return fields are stored in 'field_name', and the
    # names of the arguments are stored in 'name'.  So when we process
    # arguments, we need a way to get at the corresponding return.  At
    # the moment, this is most conveniently done by constructing a
    # mapping from name (the argument concept) to field_name (the
    # return concept) while processing return arguments, since we don't
    # directly maintain this correspondence in the modeling of function
    # schema itself.
    #
    # See also https://github.com/pytorch/pytorch/issues/43114
    name_to_field_name: Dict[str, str] = {}

    # Compute the returns field of the YAML entry
    returns = []
    for i, r in enumerate(f.func.returns):
        # If we have an inplace function, the return argument is
        # implicitly named self.
        # TODO: Consider incorporating this into the data model
        if f.func.name.name.inplace:
            assert i == 0, "illegal inplace function with multiple returns"
            name = 'self'
        # If this is an out function, the name is the name of the
        # corresponding out argument (r.name will get recorded in
        # field_name later.)
        elif f.func.is_out_fn():
            name = f.func.out_arguments[i].name
        # If the return argument is explicitly named...
        elif r.name:
            name_conflict = any(r.name == a.name
                                for a in f.func.schema_order_arguments())
            if name_conflict and not f.func.is_out_fn():
                name = f'{r.name}_return'
            else:
                name = r.name
        # If there is no explicit name, we just name the output 'result',
        # unless it's a multi-return, in which case the outputs are named
        # result0, result1, etc. (zero-indexed)
        else:
            name = 'result' if len(f.func.returns) == 1 else f'result{i}'

        ret = {
            'dynamic_type': dynamic_type(r.type),
            'name': name,
            'type': cpp.return_type(r),
        }

        if r.name:
            # See Note [name and field_name]
            ret['field_name'] = r.name
            if f.func.is_out_fn():
                name_to_field_name[f.func.out_arguments[i].name] = r.name

        returns.append(ret)

    return returns, name_to_field_name
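Working the lstsq.X schema from Note [name and field_name] through this function makes the mapping concrete; `lstsq_out_fn` below is a hypothetical NativeFunction parsed from that schema:

returns, name_to_field_name = compute_returns_yaml(lstsq_out_fn)
# Keyword out-argument names win as 'name'; the schema's return names
# are preserved as 'field_name'.
assert [r['name'] for r in returns] == ['X', 'qr']
assert [r['field_name'] for r in returns] == ['solution', 'QR']
assert name_to_field_name == {'X': 'solution', 'qr': 'QR'}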
Example #6
def compute_returns_yaml(
        f: NativeFunction) -> Tuple[List[Dict[str, str]], Dict[str, str]]:
    # Note [name and field_name]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~
    # To understand name_to_field_name, we must first talk about this
    # schema:
    #
    #   lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR)
    #
    # There is something very odd about this schema: it is an out
    # variant of the function (that is to say, it will convert into
    # at::lstsq_out() in the C++ API), but the names of the output
    # return arguments don't match the keyword argument names of
    # the inputs.  It TURNS OUT that in this situation, the historical
    # Declarations.yaml we want to output is this (abbreviated to
    # only show relevant fields):
    #
    #   arguments:
    #     ...
    #   - field_name: solution
    #     name: X
    #   - field_name: QR
    #     name: qr
    #     ...
    #
    #   returns:
    #   - field_name: solution
    #     name: X
    #   - field_name: QR
    #     name: qr
    #
    # The names of the return fields are stored in 'field_name', and the
    # names of the arguments are stored in 'name'.  So when we process
    # arguments, we need a way to get at the corresponding return.  At
    # the moment, this is most conveniently done by constructing a
    # mapping from name (the argument concept) to field_name (the
    # return concept) while processing return arguments, since we don't
    # directly maintain this correspondence in the modeling of function
    # schema itself.
    #
    # See also https://github.com/pytorch/pytorch/issues/43114
    name_to_field_name: Dict[str, str] = {}

    # Compute the returns field of the YAML entry
    names = cpp.return_names(f)
    returns = []
    for i, (r, name) in enumerate(zip(f.func.returns, names)):
        ret = {
            'dynamic_type': dynamic_type(r.type),
            'name': name,
            'type': cpp.return_type(r).cpp_type(),
        }

        if r.name:
            # See Note [name and field_name]
            ret['field_name'] = r.name
            if f.func.is_out_fn():
                name_to_field_name[f.func.arguments.out[i].name] = r.name

        returns.append(ret)

    return returns, name_to_field_name
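Example #6 is a refactor of Example #5: the per-return naming branches have moved into `cpp.return_names`. Reconstructing that helper from Example #5's branches gives roughly the following (a sketch; the real helper also disambiguates returns whose names collide with argument names by appending `_return`):

from typing import List
from torchgen.model import NativeFunction  # import path is an assumption

def return_names(f: NativeFunction) -> List[str]:
    names: List[str] = []
    for i, r in enumerate(f.func.returns):
        if f.func.name.name.inplace:
            assert i == 0, "illegal inplace function with multiple returns"
            names.append('self')  # inplace ops implicitly return self
        elif f.func.is_out_fn():
            names.append(f.func.arguments.out[i].name)  # named after the out argument
        elif r.name:
            names.append(r.name)
        else:
            names.append('result' if len(f.func.returns) == 1 else f'result{i}')
    return names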
Example #7
    def check_tensorimpl_and_storage(call: str,
                                     unpacked_bindings: List[Binding]) -> str:
        # See NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
        stmts_before_call: List[str] = []
        stmts_after_call: List[str] = []

        if cpp.name(f.func) in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
            return call

        # Check properties of inputs (enforce (1))
        for unpacked_binding in unpacked_bindings:
            arg = unpacked_binding.name
            noref_cpp_type = unpacked_binding.nctype.type.remove_const_ref()
            if noref_cpp_type == BaseCType(tensorListT):
                stmts_before_call += [
                    SAVE_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                    SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg)
                ]
                stmts_after_call += [
                    ENFORCE_SAME_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                    ENFORCE_SAME_TENSORLIST_IMPL.substitute(tensorlist_name=arg)
                ]
            elif noref_cpp_type == ListCType(OptionalCType(BaseCType(tensorT))):
                stmts_before_call += [
                    SAVE_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                    SAVE_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)
                ]
                stmts_after_call += [
                    ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                    ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)
                ]
            elif noref_cpp_type == BaseCType(tensorT):
                stmts_before_call += [
                    SAVE_TENSOR_STORAGE.substitute(tensor_name=arg),
                    SAVE_TENSOR_IMPL.substitute(tensor_name=arg)
                ]
                stmts_after_call += [
                    ENFORCE_SAME_TENSOR_STORAGE.substitute(tensor_name=arg, out_tensor_name=arg),
                    ENFORCE_SAME_TENSOR_IMPL.substitute(tensor_name=arg)
                ]

        # Either both save-statements and check-statements were generated,
        # or neither were.
        assert bool(stmts_before_call) == bool(stmts_after_call)

        # Check properties of outputs (enforce (2), (3))
        if f.func.kind() not in (SchemaKind.inplace, SchemaKind.out):
            base_name = f.func.name.name.base  # TODO: should be str(f.func.name.name)?
            aliased_arg_name = ALL_VIEW_FUNCTIONS.get(base_name, None)
            if aliased_arg_name is not None:
                aliased_arg_name = unpacked_name(aliased_arg_name)
            for i, (ret, ret_name) in enumerate(zip(f.func.returns, cpp.return_names(f))):
                noref_cpp_type = cpp.return_type(ret).remove_const_ref()
                if noref_cpp_type == BaseCType(tensorT):
                    if aliased_arg_name is not None:
                        assert i == 0, "Expect non-CompositeImplicitAutograd view function {base} to return single output"
                        stmts_after_call += [
                            ENFORCE_SAME_TENSOR_STORAGE.substitute(
                                tensor_name=aliased_arg_name,
                                out_tensor_name=ret_name)
                        ]
                    else:
                        if type_wrapper_name(f) not in DONT_ENFORCE_STORAGE_IMPL_USE_COUNT:
                            stmts_after_call += [
                                ENFORCE_TENSOR_STORAGE_USE_COUNT_EQUALS_ONE.substitute(
                                    tensor_name=ret_name, fn_name=type_wrapper_name(f))
                            ]

                    if type_wrapper_name(f) not in DONT_ENFORCE_TENSOR_IMPL_USE_COUNT:
                        stmts_after_call += [
                            ENFORCE_TENSOR_IMPL_USE_COUNT_LT_OR_EQ_ONE.substitute(
                                tensor_name=ret_name, fn_name=type_wrapper_name(f))
                        ]

                # Currently we don't have any functions that return the following types, but
                # we should update the checks once we do
                elif noref_cpp_type in (ListCType(OptionalCType(BaseCType(tensorT))),
                                        BaseCType(tensorListT)):
                    raise AssertionError(f"Please add use_count checks for {noref_cpp_type}")

        if stmts_before_call and stmts_after_call:
            call = (RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_before_call)
                    + call
                    + RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_after_call))
        return call
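The SAVE_*/ENFORCE_* names above are CodeTemplate instances defined alongside this function. As one illustration of the mechanism, RUN_ONLY_IN_DEBUG_MODE plausibly wraps its statements in a debug-only guard (the exact template text is an assumption):

from torchgen.code_template import CodeTemplate  # older releases: tools.codegen.code_template

RUN_ONLY_IN_DEBUG_MODE = CodeTemplate("""\
#ifndef NDEBUG
${statements}
#endif
""")

# CodeTemplate.substitute expands a list-valued binding one entry per line,
# which is why the function accumulates plain lists of statement strings.
print(RUN_ONLY_IN_DEBUG_MODE.substitute(statements=['check_a();', 'check_b();']))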