예제 #1
0
    def do_refactor(self):
        """Pull ``self.field_name`` up from ``self.class_name`` into its superclass.

        Every class (in any package) that extends the same superclass -- by its
        simple name plus an import of it, or by its package-qualified name --
        and declares a field with the same name and datatype has its copy
        removed.  One declaration is inserted right after the superclass body's
        opening ``{``.  Field initializers are moved into the constructors of
        the classes the field was removed from; a no-arg constructor is
        generated for a class that declares none.

        After the rewrite is applied, the same refactoring is attempted one
        level further up (from the superclass into its own superclass).

        Returns:
            True if the rewrite was applied.  False if the package, class or
            field is unknown, the class has no superclass, the superclass is
            not a class of ``self.package_name``, the superclass already
            declares the field, or no matching field copies were found.
        """
        program = utils2.get_program(self.source_filenames, print_status=True)
        if self.package_name not in program.packages \
                or self.class_name not in program.packages[self.package_name].classes \
                or self.field_name not in program.packages[self.package_name].classes[self.class_name].fields:
            return False

        package_classes = program.packages[self.package_name].classes
        _class: utils_listener_fast.Class = package_classes[self.class_name]
        if _class.superclass_name is None:
            return False

        superclass_name = _class.superclass_name
        # The superclass is looked up in this package only; a superclass from
        # another package or a library cannot be edited here.
        if superclass_name not in package_classes:
            return False

        superclass: utils_listener_fast.Class = package_classes[superclass_name]
        superclass_body_start = utils_listener_fast.TokensInfo(
            superclass.parser_context.classBody())
        superclass_body_start.stop = superclass_body_start.start  # Start and stop both point to the '{'

        if self.field_name in superclass.fields:
            return False

        datatype = _class.fields[self.field_name].datatype
        qualified_superclass_name = (
            None if self.package_name is None
            else self.package_name + '.' + superclass_name)

        # Gather the same-named, same-typed field from every sibling subclass.
        fields_to_remove = []
        for pn in program.packages:
            p: utils_listener_fast.Package = program.packages[pn]
            for cn in p.classes:
                c: utils_listener_fast.Class = p.classes[cn]
                extends_superclass = (
                    (c.superclass_name == superclass_name
                     and c.file_info.has_imported_class(self.package_name,
                                                        superclass_name))
                    or (qualified_superclass_name is not None
                        and c.superclass_name == qualified_superclass_name))
                if extends_superclass \
                        and self.field_name in c.fields \
                        and c.fields[self.field_name].datatype == datatype:
                    fields_to_remove.append(c.fields[self.field_name])

        if not fields_to_remove:
            return False

        # Public if any copy was public; otherwise protected when every copy
        # was protected or private (so subclasses keep access); otherwise no
        # access modifier (package-private).
        is_public = any("public" in fld.modifiers for fld in fields_to_remove)
        is_protected = all("protected" in fld.modifiers
                           or "private" in fld.modifiers
                           for fld in fields_to_remove)

        rewriter = utils2.Rewriter(program, self.filename_mapping)

        rewriter.insert_after(
            superclass_body_start,
            "\n    " + ("public " if is_public else
                        ("protected " if is_protected else "")) + datatype +
            " " + self.field_name + ";")

        for field in fields_to_remove:
            if len(field.neighbor_names) == 0:
                # Sole declarator: drop the whole declaration.
                rewriter.replace(field.get_tokens_info(), "")
                # Have to remove the modifiers too, because of the new grammar.
                for mod_ctx in field.modifiers_parser_contexts:
                    rewriter.replace(utils_listener_fast.TokensInfo(mod_ctx),
                                     "")
            else:
                # Shared declaration (``T a, b;``): drop only this declarator
                # together with one adjacent ','.
                i = field.index_in_variable_declarators
                var_ctxs = field.all_variable_declarator_contexts
                to_remove = utils_listener_fast.TokensInfo(var_ctxs[i])
                if i == 0:
                    to_remove.stop = utils_listener_fast.TokensInfo(
                        var_ctxs[i + 1]).start - 1  # Include the ',' after it
                else:
                    to_remove.start = utils_listener_fast.TokensInfo(
                        var_ctxs[i - 1]).stop + 1  # Include the ',' before it
                rewriter.replace(to_remove, "")

            # Add initializer to class constructor if initializer exists in field declaration
            if field.initializer is None:
                continue
            # Separate name so ``_class`` keeps meaning the starting class.
            owner_class: utils_listener_fast.Class = program.packages[
                field.package_name].classes[field.class_name]
            # Array initializers (``{...}``) need ``new T`` in an assignment.
            initializer_statement = (
                field.name + " = " +
                ("new " + field.datatype +
                 " " if field.initializer.startswith('{') else "") +
                field.initializer + ";")
            # NOTE(review): the statement is inserted right after the
            # constructor's '{', i.e. before any explicit super(...)/this(...)
            # call, which Java requires to be the first statement -- confirm.
            has_constructor = False
            for class_body_decl in owner_class.parser_context.classBody(
            ).getChildren():
                if class_body_decl.getText() in ['{', '}']:
                    continue
                member_decl = class_body_decl.memberDeclaration()
                if member_decl is None:
                    continue
                constructor = member_decl.constructorDeclaration()
                if constructor is None:
                    continue
                body = constructor.constructorBody  # Start token = '{'
                body_start = utils_listener_fast.TokensInfo(body)
                body_start.stop = body_start.start  # Start and stop both point to the '{'
                rewriter.insert_after(
                    body_start, "\n        " + initializer_statement)
                has_constructor = True
            if not has_constructor:
                body = owner_class.parser_context.classBody()
                body_start = utils_listener_fast.TokensInfo(body)
                body_start.stop = body_start.start  # Start and stop both point to the '{'
                # A constructor may carry only an access modifier, and the
                # class may have no modifiers at all (modifiers[0] used to
                # raise IndexError or emit e.g. "abstract Foo()").
                access = next((mod for mod in owner_class.modifiers
                               if mod in ("public", "protected", "private")),
                              None)
                prefix = access + " " if access else ""
                rewriter.insert_after(
                    body_start, "\n    " + prefix +
                    owner_class.name + "() { " + initializer_statement + " }")

        rewriter.apply()

        # Multilevel inheritance: try to pull the same field one level further
        # up, from the superclass into its own superclass.
        # NOTE(review): filename_mapping is not forwarded to the nested
        # refactoring; confirm its default makes it re-read the files just
        # written, otherwise the nested call sees the pre-rewrite sources.
        if superclass.superclass_name is not None:
            PullUpFieldRefactoring(self.source_filenames, self.package_name,
                                   superclass_name, self.field_name).do_refactor()
        return True
예제 #2
0
    def do_refactor(self):
        """Extract interface ``self.interface_name`` from ``self.class_names``.

        For every class of ``self.package_name`` listed in ``self.class_names``:
        ``implements <interface>`` is added (or ``, <interface>`` is appended
        to an existing implements list), and every method in
        ``self.method_keys`` gets ``public`` (``private``/``protected`` are
        removed) and ``@Override``.  Across all packages, private fields and
        local variables declared with one of those class types are retyped to
        the interface unless a method outside the interface is invoked on them
        (``x.m()`` / ``this.x.m()``).  The interface source is then written to
        ``self.interface_filename`` and the rewrites are applied.

        Returns:
            True on success.  False if the package, a class or a method key is
            missing, or the classes disagree on a method's return type or
            parameter types (in which case nothing is written).
        """
        program = utils2.get_program(self.source_filenames, print_status=True)
        if self.package_name not in program.packages:
            return False
        package_classes = program.packages[self.package_name].classes
        if any(class_name not in package_classes
               for class_name in self.class_names) \
                or any(method_key not in package_classes[class_name].methods
                       for class_name in self.class_names
                       for method_key in self.method_keys):
            return False

        method_returntypes = {}
        method_parameters = {}
        method_names = [method_key[:method_key.find('(')]
                        for method_key in self.method_keys]

        rewriter = utils2.Rewriter(program, self.filename_mapping)

        for class_name in self.class_names:
            c: utils_listener_fast.Class = package_classes[class_name]
            # Add implements to the class, after the last header element
            # present: implements list, superclass, type parameters, or name.
            has_superinterface = False
            if c.parser_context.IMPLEMENTS() is not None:  # old: superinterfaces()
                t = utils_listener_fast.TokensInfo(c.parser_context.typeList())
                has_superinterface = True
            elif c.parser_context.EXTENDS() is not None:  # old: superclass()
                t = utils_listener_fast.TokensInfo(c.parser_context.typeType())
            elif c.parser_context.typeParameters() is not None:
                t = utils_listener_fast.TokensInfo(
                    c.parser_context.typeParameters())
            else:
                # old: TokensInfo(c.parser_context.identifier())
                t = utils_listener_fast.TokensInfo(c.parser_context)
                t.stop = c.parser_context.IDENTIFIER().getSymbol().tokenIndex
            rewriter.insert_after(
                t, (", " if has_superinterface else " implements ") +
                self.interface_name)
            for method_key in self.method_keys:
                m: utils_listener_fast.Method = c.methods[method_key]
                # All classes must agree on each method's return type and
                # parameter types; the first class seen defines them.
                if method_key in method_returntypes:
                    if method_returntypes[method_key] != m.returntype:
                        return False
                    known_params = method_parameters[method_key]
                    if len(known_params) != len(m.parameters):
                        return False
                    if any(known[0] != param[0]
                           for known, param in zip(known_params, m.parameters)):
                        return False
                else:
                    method_returntypes[method_key] = m.returntype
                    method_parameters[method_key] = m.parameters
                # Prepend @Override / public before the first modifier (or
                # before the method itself when it has no modifiers).
                if len(m.modifiers_parser_contexts) > 0:
                    t = utils_listener_fast.TokensInfo(
                        m.modifiers_parser_contexts[0])
                else:
                    t = m.get_tokens_info()
                rewriter.insert_before_start(
                    t,
                    ("" if "@Override" in m.modifiers else "@Override\n    ") +
                    ("" if "public" in m.modifiers else "public "))
                # Interface methods are public: drop narrower access modifiers.
                for i, mm in enumerate(m.modifiers):
                    if mm == "private" or mm == "protected":
                        rewriter.replace(
                            utils_listener_fast.TokensInfo(
                                m.modifiers_parser_contexts[i]), "")

        # The extracted classes and the interface live in self.package_name,
        # so types are matched and qualified against that package -- not
        # against the package currently being scanned.
        qualified_interface_name = (
            None if self.package_name is None
            else self.package_name + '.' + self.interface_name)

        def is_extracted_class_type(datatype, file_info):
            """Return True if ``datatype`` names one of ``self.class_names``."""
            for cn in self.class_names:
                if (datatype == cn and file_info.has_imported_class(self.package_name, cn)) \
                        or (self.package_name is not None
                            and datatype == self.package_name + '.' + cn):
                    return True
            return False

        def interface_type_name(file_info):
            """Interface name as it must be spelled in ``file_info``'s file.

            The simple name when that file imports the interface's package,
            otherwise the qualified name (None in the default package).
            """
            if file_info.has_imported_package(self.package_name):
                return self.interface_name
            return qualified_interface_name

        # Change variable types to the interface if only interface methods are used.
        for package_name in program.packages:
            p: utils_listener_fast.Package = program.packages[package_name]
            for class_name in p.classes:
                c: utils_listener_fast.Class = p.classes[class_name]
                fields_of_interest = {}
                for fn in c.fields:
                    f: utils_listener_fast.Field = c.fields[fn]
                    if is_extracted_class_type(f.datatype, f.file_info) \
                            and "private" in f.modifiers:
                        fields_of_interest[f.name] = f
                for method_key in c.methods:
                    m: utils_listener_fast.Method = c.methods[method_key]
                    vars_of_interest = {}
                    for item in m.body_local_vars_and_expr_names:
                        if isinstance(item, utils_listener_fast.LocalVariable) \
                                and is_extracted_class_type(item.datatype, c.file_info):
                            vars_of_interest[item.identifier] = item
                        if isinstance(item, utils_listener_fast.MethodInvocation):
                            ids = item.dot_separated_identifiers
                            # ``x.m()`` or ``this.x.m()``: ids[-2] is the receiver.
                            if len(ids) == 2 or (len(ids) == 3 and ids[0] == "this"):
                                if ids[-2] in vars_of_interest:
                                    if ids[-1] not in method_names:
                                        vars_of_interest.pop(ids[-2])
                                elif ids[-2] in fields_of_interest \
                                        and ids[-1] not in method_names:
                                    fields_of_interest.pop(ids[-2])
                    if not vars_of_interest:
                        continue
                    typename = interface_type_name(m.file_info)
                    if typename is None:
                        continue
                    for var in vars_of_interest.values():
                        # old: var.parser_context.unannType()
                        rewriter.replace(
                            utils_listener_fast.TokensInfo(
                                var.parser_context.typeType()),
                            typename)
                if not fields_of_interest:
                    continue
                typename = interface_type_name(c.file_info)
                if typename is None:
                    continue
                for f in fields_of_interest.values():
                    if len(f.neighbor_names) == 0:
                        rewriter.replace(
                            utils_listener_fast.TokensInfo(
                                f.parser_context.typeType()),
                            typename)  # old: f.parser_context.unannType()
                    elif not any(nn in fields_of_interest
                                 for nn in f.neighbor_names):
                        # Shared declaration whose other declarators keep the
                        # class type: move this declarator into its own
                        # declaration typed as the interface.
                        i = f.index_in_variable_declarators
                        var_ctxs = f.all_variable_declarator_contexts
                        t = utils_listener_fast.TokensInfo(var_ctxs[i])
                        if i == 0:
                            t.stop = utils_listener_fast.TokensInfo(
                                var_ctxs[i + 1]).start - 1  # Include the ',' after it
                        else:
                            # Start just past the previous declarator so only
                            # the separating ',' goes with this one.
                            t.start = utils_listener_fast.TokensInfo(
                                var_ctxs[i - 1]).stop + 1  # Include the ',' before it
                        rewriter.replace(t, "")
                        rewriter.insert_after(
                            f.get_tokens_info(),
                            "\n    private " + typename + " " + f.name +
                            (" = " + f.initializer +
                             ";" if f.initializer is not None else ";"))

        # Create the interface in the package of the extracted classes.
        header = ("" if self.package_name is None
                  else "package " + self.package_name + ";\n\n")
        interface_file_content = (header +
                                  "public interface " + self.interface_name +
                                  "\n" + "{\n")
        for method_key in self.method_keys:
            method_name = method_key[:method_key.find('(')]
            params = ", ".join(param[0] + " " + param[1]
                               for param in method_parameters[method_key])
            interface_file_content += ("    " + method_returntypes[method_key] +
                                       " " + method_name + "(" + params + ");\n")
        interface_file_content += "}\n"

        interface_dir = os.path.dirname(self.interface_filename)
        if interface_dir:
            os.makedirs(interface_dir, exist_ok=True)
        with open(self.interface_filename, "w+") as file:
            file.write(interface_file_content)

        rewriter.apply()
        return True
예제 #3
0
    def do_refactor(self):
        """Push ``self.field_name`` down from ``self.superclass_name``.

        The field declaration is removed from the superclass. A copy is
        inserted after the opening ``{`` of each targeted subclass: those in
        ``self.class_names``, or every direct subclass when that is empty.
        The copy keeps the datatype, the initializer, and a public/protected
        access modifier; other modifiers are not copied.

        Returns:
            True if the rewrite was applied.  False if the package or
            superclass is unknown, ``pre_condition_check`` fails, a targeted
            subclass already declares the field, or the field is accessed as
            ``x.field`` / ``this.x.field`` through a variable or field typed as
            the superclass or as a non-targeted subclass.
        """
        program = utils2.get_program(self.source_filenames, print_status=True)
        if self.package_name not in program.packages \
                or self.superclass_name not in program.packages[self.package_name].classes:
            print("Can't refactor")
            return False
        superclass: utils_listener_fast.Class = program.packages[
            self.package_name].classes[self.superclass_name]

        if not self.pre_condition_check(program, superclass):
            print("Can't refactor")
            return False

        qualified_superclass_name = (
            None if self.package_name is None
            else self.package_name + '.' + self.superclass_name)

        other_derived_classes = []
        classes_to_add_to = []
        for pn in program.packages:
            p: utils_listener_fast.Package = program.packages[pn]
            for cn in p.classes:
                c: utils_listener_fast.Class = p.classes[cn]
                extends_superclass = (
                    (c.superclass_name == self.superclass_name
                     and c.file_info.has_imported_class(self.package_name,
                                                        self.superclass_name))
                    or (qualified_superclass_name is not None
                        and c.superclass_name == qualified_superclass_name))
                if not extends_superclass:
                    continue
                if len(self.class_names) == 0 or cn in self.class_names:
                    if self.field_name in c.fields:
                        print("some classes have same variable")
                        return False
                    classes_to_add_to.append(c)
                else:
                    other_derived_classes.append(c)

        def is_superclass_type(datatype, has_imported_superclass):
            """Return True if ``datatype`` names the superclass."""
            return (datatype == self.superclass_name and has_imported_superclass) \
                or (qualified_superclass_name is not None
                    and datatype == qualified_superclass_name)

        def is_other_derived_type(datatype, file_info):
            """Return True if ``datatype`` names a non-targeted subclass."""
            return any(
                (file_info.has_imported_class(o.package_name, o.name)
                 and datatype == o.name)
                or (o.package_name is not None
                    and datatype == o.package_name + '.' + o.name)
                for o in other_derived_classes)

        # Check if the field is used from the superclass or other derived classes
        for pn in program.packages:
            p: utils_listener_fast.Package = program.packages[pn]
            for cn in p.classes:
                c: utils_listener_fast.Class = p.classes[cn]
                has_imported_superclass = c.file_info.has_imported_class(
                    self.package_name, self.superclass_name)
                # Names of fields whose TYPE is the superclass or another
                # derived class (was compared against the field's name).
                fields_of_superclass_type_or_others = [
                    f.name for f in c.fields.values()
                    if is_superclass_type(f.datatype, has_imported_superclass)
                    or is_other_derived_type(f.datatype, c.file_info)]
                for m in c.methods.values():
                    local_vars_of_superclass_type_or_others = []
                    for item in m.body_local_vars_and_expr_names:
                        if isinstance(item, utils_listener_fast.LocalVariable):
                            if is_superclass_type(item.datatype, has_imported_superclass) \
                                    or is_other_derived_type(item.datatype, c.file_info):
                                local_vars_of_superclass_type_or_others.append(
                                    item.identifier)
                        elif isinstance(item, utils_listener_fast.ExpressionName):
                            ids = item.dot_separated_identifiers
                            # ``x.field`` (x local or field) or ``this.x.field``.
                            if ids[-1] == self.field_name \
                                    and (len(ids) == 2
                                         or (len(ids) == 3 and ids[0] == "this")) \
                                    and ((len(ids) == 2
                                          and ids[-2] in local_vars_of_superclass_type_or_others)
                                         or ids[-2] in fields_of_superclass_type_or_others):
                                return False

        rewriter = utils2.Rewriter(program, self.filename_mapping)

        field = superclass.fields[self.field_name]
        if len(field.neighbor_names) == 0:
            # Sole declarator: drop the whole declaration.
            rewriter.replace(field.get_tokens_info(), "")
            # Have to remove the modifiers too, because of the new grammar.
            for mod_ctx in field.modifiers_parser_contexts:
                rewriter.replace(utils_listener_fast.TokensInfo(mod_ctx), "")
        else:
            # Shared declaration: drop this declarator and one adjacent ','.
            i = field.index_in_variable_declarators
            var_ctxs = field.all_variable_declarator_contexts
            to_remove = utils_listener_fast.TokensInfo(var_ctxs[i])
            if i == 0:
                to_remove.stop = utils_listener_fast.TokensInfo(
                    var_ctxs[i + 1]).start - 1  # Include the ',' after it
            else:
                to_remove.start = utils_listener_fast.TokensInfo(
                    var_ctxs[i - 1]).stop + 1  # Include the ',' before it
            rewriter.replace(to_remove, "")

        is_public = "public" in field.modifiers
        is_protected = "protected" in field.modifiers
        modifier = ("public " if is_public else
                    ("protected " if is_protected else ""))
        for c in classes_to_add_to:
            c_body_start = utils_listener_fast.TokensInfo(
                c.parser_context.classBody())
            c_body_start.stop = c_body_start.start  # Start and stop both point to the '{'
            rewriter.insert_after(c_body_start, "\n    " + modifier + field.datatype + " " + self.field_name
                                  + ((" = " + field.initializer) if field.initializer is not None else "")
                                  + ";")

        rewriter.apply()
        return True
예제 #4
0
    def do_refactor(self):
        """Push ``self.method_key`` down from ``self.superclass_name``.

        A copy of the method is inserted after the opening ``{`` of each
        targeted subclass that calls it. Targeted subclasses are those in
        ``self.class_names``, or every direct subclass when that is empty.
        The declaration, including its modifiers, is then removed from the
        superclass.

        Returns:
            True if the rewrite was applied; False if
            ``pre_propagation_check`` fails or a targeted subclass already
            declares a method with the same key.
        """
        program = get_program(
            self.source_filenames,
            print_status=True)  # getting the program packages
        superclass: utils_listener_fast.Class = program.packages[
            self.package_name].classes[self.superclass_name]
        if not self.pre_propagation_check(program, superclass):
            return False

        def calls_pushed_method(item):
            """Return True if ``item`` invokes the method being pushed down."""
            if not isinstance(item, utils_listener_fast.MethodInvocation):
                return False
            ids = item.dot_separated_identifiers
            # ``m()`` (checked on the first identifier) or ``this.m()``.
            return ids[0] + '()' == self.method_key \
                or (len(ids) >= 2 and ids[0] == "this"
                    and ids[1] + '()' == self.method_key)

        classes_to_add_to = []
        for p in program.packages:
            package: utils_listener_fast.Package = program.packages[p]
            for cn in package.classes:
                c: utils_listener_fast.Class = package.classes[cn]
                extends_superclass = (
                    (c.superclass_name == self.superclass_name
                     and c.file_info.has_imported_class(
                         self.package_name, self.superclass_name)) or
                    (self.package_name is not None and c.superclass_name
                     == self.package_name + '.' + self.superclass_name))
                if not extends_superclass:
                    continue
                # enable functionality of class name: skip non-targeted subclasses
                if len(self.class_names) != 0 and cn not in self.class_names:
                    continue
                # check for methods with same name (post-condition 1)
                if self.method_key in c.methods:
                    print("some classes have same method")
                    return False
                # Add each subclass once, however many of its methods call
                # the pushed method (duplicates inserted it repeatedly).
                if any(calls_pushed_method(item)
                       for method in c.methods.values()
                       for item in method.body_local_vars_and_expr_names):
                    classes_to_add_to.append(c)

        rewriter = utils2.Rewriter(program, self.filename_mapping)

        method = superclass.methods[self.method_key]
        # NOTE(review): the copy carries modifiers, return type, name and body
        # only -- no parameters or throws clause -- which matches the
        # ``name + '()'`` key check above (no-arg methods). body_text[1:-1]
        # presumably strips the enclosing braces -- confirm.
        pushed_method_text = "\n\n    %s %s %s() {\n       %s\n    }\n" % (
            " ".join(method.modifiers), method.returntype, method.name,
            method.body_text[1:-1])

        for c in classes_to_add_to:
            c_body_start = utils_listener_fast.TokensInfo(
                c.parser_context.classBody())
            c_body_start.stop = c_body_start.start  # Start and stop both point to the '{'
            rewriter.insert_after(c_body_start, pushed_method_text)

        # Remove the declaration from the superclass. The modifiers are
        # separate parse contexts in the new grammar, so remove them too
        # rather than leaving them dangling in front of the next member.
        method_token_info = utils_listener_fast.TokensInfo(
            method.parser_context)
        rewriter.replace(method_token_info, "")
        for mod_ctx in method.modifiers_parser_contexts:
            rewriter.replace(utils_listener_fast.TokensInfo(mod_ctx), "")

        rewriter.apply()
        return True