Example #1
0
def _functools_partial_inference(node, context=None):
    """Infer a ``functools.partial(func, *args, **kwargs)`` call.

    Builds an ``objects.PartialFunction`` that copies the wrapped function's
    definition (name, args, body, decorators, annotations) and keeps the
    call site so the already-bound arguments can be applied later.

    :param node: the ``Call`` node being inferred.
    :param context: optional inference context, forwarded to the call-site
        builder and to inference of the wrapped callable.
    :returns: an iterator yielding exactly one ``PartialFunction``.
    :raises astroid.UseInferenceDefault: when this tip cannot model the call
        (no arguments, nothing bound, the wrapped object is not an inferable
        ``FunctionDef``, or a bound keyword the function cannot accept), so
        astroid falls back to its default inference.
    """
    call = arguments.CallSite.from_call(node, context=context)
    number_of_positional = len(call.positional_arguments)
    if number_of_positional < 1:
        raise astroid.UseInferenceDefault(
            "functools.partial takes at least one argument"
        )
    if number_of_positional == 1 and not call.keyword_arguments:
        raise astroid.UseInferenceDefault(
            "functools.partial needs at least to have some filled arguments"
        )

    partial_function = call.positional_arguments[0]
    try:
        inferred_wrapped_function = next(partial_function.infer(context=context))
    except (astroid.InferenceError, StopIteration) as exc:
        # An empty inference result raises StopIteration; if it escaped into
        # a calling generator it would be turned into RuntimeError (PEP 479),
        # so treat it like any other inference failure.
        raise astroid.UseInferenceDefault from exc
    if inferred_wrapped_function is astroid.Uninferable:
        raise astroid.UseInferenceDefault("Cannot infer the wrapped function")
    if not isinstance(inferred_wrapped_function, astroid.FunctionDef):
        raise astroid.UseInferenceDefault("The wrapped function is not a function")

    # Determine if the passed keywords into the callsite are supported
    # by the wrapped function.
    wrapped_args = inferred_wrapped_function.args
    if not wrapped_args:
        function_parameters = []
        accepts_any_keyword = False
    else:
        # NOTE(review): positional-only names cannot actually be passed by
        # keyword at runtime; they are accepted here to keep existing behavior.
        function_parameters = chain(
            wrapped_args.args or (),
            wrapped_args.posonlyargs or (),
            wrapped_args.kwonlyargs or (),
        )
        # A ``**kwargs`` parameter accepts keywords of any name.
        accepts_any_keyword = bool(wrapped_args.kwarg)
    parameter_names = {
        param.name
        for param in function_parameters
        if isinstance(param, astroid.AssignName)
    }
    unknown_keywords = set(call.keyword_arguments) - parameter_names
    if unknown_keywords and not accepts_any_keyword:
        raise astroid.UseInferenceDefault(
            "wrapped function received unknown parameters"
        )

    partial_function = objects.PartialFunction(
        call,
        name=inferred_wrapped_function.name,
        doc=inferred_wrapped_function.doc,
        lineno=inferred_wrapped_function.lineno,
        col_offset=inferred_wrapped_function.col_offset,
        parent=inferred_wrapped_function.parent,
    )
    # Reuse the wrapped function's nodes so the partial infers like it.
    partial_function.postinit(
        args=inferred_wrapped_function.args,
        body=inferred_wrapped_function.body,
        decorators=inferred_wrapped_function.decorators,
        returns=inferred_wrapped_function.returns,
        type_comment_returns=inferred_wrapped_function.type_comment_returns,
        type_comment_args=inferred_wrapped_function.type_comment_args,
    )
    return iter((partial_function,))
Example #2
0
def _functools_partial_inference(node, context=None):
    """Infer a ``functools.partial(func, *args, **kwargs)`` call.

    Produces a ``FunctionDef`` subclass instance that copies the wrapped
    function's definition. When its call result is inferred, the bound
    positional arguments are prepended and the bound keywords are added to
    the call context.

    :param node: the ``Call`` node being inferred.
    :param context: optional inference context used to infer the wrapped
        callable.
    :returns: an iterator yielding exactly one partial-function node.
    :raises astroid.UseInferenceDefault: when this tip cannot model the call
        (no arguments, nothing bound, the wrapped object is not an inferable
        ``FunctionDef``, or a bound keyword the function cannot accept).
    """
    # NOTE(review): context is not forwarded to the call-site builder; some
    # astroid versions accept ``from_call(node, context=context)`` - confirm
    # the installed signature before passing it.
    call = arguments.CallSite.from_call(node)
    number_of_positional = len(call.positional_arguments)
    if number_of_positional < 1:
        raise astroid.UseInferenceDefault(
            "functools.partial takes at least one argument")
    if number_of_positional == 1 and not call.keyword_arguments:
        raise astroid.UseInferenceDefault(
            "functools.partial needs at least to have some filled arguments")

    partial_function = call.positional_arguments[0]
    try:
        inferred_wrapped_function = next(
            partial_function.infer(context=context))
    except (astroid.InferenceError, StopIteration) as exc:
        # An empty inference result raises StopIteration; if it escaped into
        # a calling generator it would become RuntimeError (PEP 479).
        raise astroid.UseInferenceDefault from exc
    if inferred_wrapped_function is astroid.Uninferable:
        raise astroid.UseInferenceDefault("Cannot infer the wrapped function")
    if not isinstance(inferred_wrapped_function, astroid.FunctionDef):
        raise astroid.UseInferenceDefault(
            "The wrapped function is not a function")

    # Determine if the passed keywords into the callsite are supported
    # by the wrapped function.
    wrapped_args = inferred_wrapped_function.args
    if not wrapped_args:
        function_parameters = ()
        accepts_any_keyword = False
    else:
        function_parameters = chain(
            wrapped_args.args or (),
            wrapped_args.kwonlyargs or (),
        )
        # A ``**kwargs`` parameter accepts keywords of any name.
        accepts_any_keyword = bool(wrapped_args.kwarg)
    parameter_names = {
        param.name for param in function_parameters
        if isinstance(param, astroid.AssignName)
    }
    unknown_keywords = set(call.keyword_arguments) - parameter_names
    if unknown_keywords and not accepts_any_keyword:
        raise astroid.UseInferenceDefault(
            "wrapped function received unknown parameters")

    # Return a wrapped() object that can be used further for inference
    class PartialFunction(astroid.FunctionDef):

        # Number of positional arguments bound by the partial() call.
        filled_positionals = len(call.positional_arguments[1:])
        # Names of the keywords bound by the partial() call.
        filled_keywords = list(call.keyword_arguments)

        def infer_call_result(self, caller=None, context=None):
            """Infer the wrapped function's result with bound args applied.

            Mutates ``context.callcontext`` in place: bound keywords that the
            current call does not already pass are appended, and the bound
            positionals are prepended to the call's positional arguments.
            """
            bound_args = call.positional_arguments[1:]
            bound_keywords = call.keyword_arguments

            if context:
                current_passed_keywords = {
                    keyword
                    for (keyword, _) in context.callcontext.keywords
                }
                for keyword, value in bound_keywords.items():
                    # Keywords given at the actual call override bound ones.
                    if keyword not in current_passed_keywords:
                        context.callcontext.keywords.append((keyword, value))

                call_context_args = context.callcontext.args or []
                context.callcontext.args = bound_args + call_context_args

            return super().infer_call_result(caller=caller, context=context)

    partial_function = PartialFunction(
        name=inferred_wrapped_function.name,
        doc=inferred_wrapped_function.doc,
        lineno=inferred_wrapped_function.lineno,
        col_offset=inferred_wrapped_function.col_offset,
        parent=inferred_wrapped_function.parent,
    )
    # Reuse the wrapped function's nodes so the partial infers like it.
    partial_function.postinit(
        args=inferred_wrapped_function.args,
        body=inferred_wrapped_function.body,
        decorators=inferred_wrapped_function.decorators,
        returns=inferred_wrapped_function.returns,
        type_comment_returns=inferred_wrapped_function.type_comment_returns,
        type_comment_args=inferred_wrapped_function.type_comment_args,
    )
    return iter((partial_function, ))