Ejemplo n.º 1
0
 def _convert_param(self, param, name):
     """Convert ``param`` into a param name, special-casing the first one.

     When ``param.position_index`` is 0 and the wrapped function is not a
     staticmethod, an ``InstanceExecutedParamName`` is returned. It is built
     from ``self._instance`` or, for a classmethod, from
     ``self._instance.py__class__()``. Every other parameter is delegated to
     the parent implementation.
     """
     if param.position_index != 0:
         return super()._convert_param(param, name)

     function_value = self._function_value
     tree_node = function_value.tree_node
     if function_is_classmethod(tree_node):
         # Classmethods receive the class as their first argument.
         bound_to = self._instance.py__class__()
     elif function_is_staticmethod(tree_node):
         # Staticmethods have no implicit first argument.
         return super()._convert_param(param, name)
     else:
         bound_to = self._instance
     return InstanceExecutedParamName(bound_to, function_value, name)
Ejemplo n.º 2
0
def extract_function(inference_state, path, module_context, name, pos,
                     until_pos):
    """Build a refactoring that moves a code range into a new function.

    The nodes found between ``pos`` and ``until_pos`` are turned into the body
    of a new function called ``name``. The original nodes are replaced by a
    call to that function (prefixed with ``return`` or an assignment where
    needed).

    :param inference_state: Passed through to the resulting ``Refactoring``.
    :param path: Key under which the node changes are registered.
    :param module_context: Context whose ``tree_node`` is searched for the
        nodes to extract.
    :param name: Name of the function that gets generated.
    :param pos: Start of the range (presumably a ``(line, column)`` tuple --
        TODO confirm against ``_find_nodes``).
    :param until_pos: End of the range; ``until_pos[0]`` is used to split the
        prefix of the leaf following the extracted nodes.
    :return: A ``Refactoring`` mapping ``path`` to the node replacements.
    :raises AssertionError: If no nodes are found for the given range.
    """
    nodes = _find_nodes(module_context.tree_node, pos, until_pos)
    assert nodes

    is_expression, _ = _is_expression_with_error(nodes)
    first_node, last_node = nodes[0], nodes[-1]
    context = module_context.create_context(first_node)
    is_bound_method = context.is_bound_method()
    params, return_variables = list(
        _find_inputs_and_outputs(module_context, context, nodes))

    # For module-level code the insertion leaf is determined later, so it
    # stays None here.
    insert_before_leaf = None
    if not context.is_module():
        insertion_node = _get_code_insertion_node(context.tree_node,
                                                  is_bound_method)
        insert_before_leaf = insertion_node.get_first_leaf()

    if is_expression:
        code_block = 'return ' + _expression_nodes_to_string(nodes) + '\n'
        remaining_prefix = None
        has_ending_return_stmt = False
    else:
        has_ending_return_stmt = _is_node_ending_return_stmt(last_node)
        if not has_ending_return_stmt:
            if return_variables:
                # Keep only the defined variables that are actually used
                # afterwards. If none are used (e.g. the range covers the
                # whole function), fall back to the last defined variable.
                needed_variables = list(
                    _find_needed_output_variables(
                        context, first_node.parent, last_node.end_pos,
                        return_variables))
                return_variables = needed_variables or [return_variables[-1]]
            else:
                return_variables = []

        remaining_prefix, code_block = _suite_nodes_to_string(nodes, pos)
        after_leaf = last_node.get_next_leaf()
        first, second = _split_prefix_at(after_leaf, until_pos[0])
        code_block = dedent(code_block + first)
        if not has_ending_return_stmt:
            # NOTE(review): with no return variables this string is empty, so
            # the body ends in 'return \n' and the replacement below starts
            # with ' = ' -- confirm whether that case can occur.
            output_var_str = ', '.join(return_variables)
            code_block += 'return ' + output_var_str + '\n'

    # Check if we have to raise RefactoringError. A trailing return statement
    # is excluded from that check.
    nodes_to_check = nodes[:-1] if has_ending_return_stmt else nodes
    _check_for_non_extractables(nodes_to_check)

    decorator = ''
    self_param = None
    if not is_bound_method:
        code_block += '\n'
    else:
        if not function_is_staticmethod(context.tree_node):
            param_names = context.get_value().get_param_names()
            if param_names:
                # The first parameter (self/cls) is re-added explicitly to
                # the signature and used as the call prefix, so it is removed
                # from the regular inputs.
                self_param = param_names[0].string_name
                params = [p for p in params if p != self_param]

        if function_is_classmethod(context.tree_node):
            decorator = '@classmethod\n'

    if self_param is None:
        signature_params = params
        call_target = name
    else:
        signature_params = [self_param] + params
        call_target = self_param + '.' + name

    function_code = '%sdef %s(%s):\n%s' % (
        decorator, name, ', '.join(signature_params),
        indent_block(code_block))
    function_call = '%s(%s)' % (call_target, ', '.join(params))

    if is_expression:
        replacement = function_call
    elif has_ending_return_stmt:
        replacement = 'return ' + function_call + '\n'
    else:
        replacement = output_var_str + ' = ' + function_call + '\n'

    replacement_dct = _replace(nodes, replacement, function_code, pos,
                               insert_before_leaf, remaining_prefix)
    if not is_expression:
        # Restore the part of the following leaf's prefix that was not moved
        # into the extracted function.
        replacement_dct[after_leaf] = second + after_leaf.value
    return Refactoring(inference_state, {path: replacement_dct})