Example #1
0
    def do_refactor(self):
        """Extract an interface from ``self.class_names`` and retarget their users.

        All source edits are staged on one ``symbol_table.Rewriter`` and only
        applied at the very end, so an early ``return False`` leaves the
        sources untouched and no interface file is written.

        Steps:
            1. Validate that ``self.package_name`` exists, contains every class
               in ``self.class_names``, and that each of those classes declares
               every method in ``self.method_keys``.
            2. Add ``implements <interface>`` (or ``, <interface>``) to every
               target class and make each extracted method ``@Override public``
               (removing ``private``/``protected``).
            3. In every class of the program, retype local variables and
               private fields declared with a target class type to the new
               interface, as long as only extracted methods are invoked on them.
            4. Write the interface source to ``self.interface_filename``.

        Returns:
            True if the refactoring was applied; False if the inputs are
            invalid or the target classes disagree on the return type or
            parameter types of an extracted method.
        """
        program = symbol_table.get_program(self.source_filenames, print_status=True)
        if self.package_name not in program.packages:
            return False
        target_package_classes = program.packages[self.package_name].classes
        if any(class_name not in target_package_classes for class_name in self.class_names):
            return False
        if any(method_key not in target_package_classes[class_name].methods
               for class_name in self.class_names for method_key in self.method_keys):
            return False

        method_returntypes = {}
        method_parameters = {}
        # Method keys look like "name(...)"; keep only the name to match invocations by name.
        method_names = [method_key[:method_key.find('(')] for method_key in self.method_keys]

        rewriter = symbol_table.Rewriter(program, self.filename_mapping)

        for class_name in self.class_names:
            c: symbol_table.Class = target_package_classes[class_name]
            # Add implements to the class: extend an existing implements list, otherwise
            # insert after the extends clause / type parameters / class name.
            has_superinterface = False
            if c.parser_context.IMPLEMENTS() is not None:  # old: c.parser_context.superinterfaces()
                t = symbol_table.TokensInfo(c.parser_context.typeList())
                has_superinterface = True
            elif c.parser_context.EXTENDS() is not None:  # old: c.parser_context.superclass()
                t = symbol_table.TokensInfo(c.parser_context.typeType())
            elif c.parser_context.typeParameters() is not None:
                t = symbol_table.TokensInfo(c.parser_context.typeParameters())
            else:
                # old: TokensInfo(c.parser_context.identifier())
                t = symbol_table.TokensInfo(c.parser_context)
                t.stop = c.parser_context.IDENTIFIER().getSymbol().tokenIndex
            rewriter.insert_after(t, (", " if has_superinterface else " implements ") + self.interface_name)

            for method_key in self.method_keys:
                m: symbol_table.Method = c.methods[method_key]
                # Every target class must agree on the signature that goes into the interface.
                if method_key in method_returntypes:
                    if method_returntypes[method_key] != m.returntype:
                        return False
                    known_parameters = method_parameters[method_key]
                    if len(known_parameters) != len(m.parameters):
                        return False
                    # Only parameter types (index 0) must match; names may differ.
                    if any(known[0] != param[0] for known, param in zip(known_parameters, m.parameters)):
                        return False
                else:
                    method_returntypes[method_key] = m.returntype
                    method_parameters[method_key] = m.parameters

                # Prefix the method with "@Override public", before its first modifier if it has one.
                if len(m.modifiers_parser_contexts) > 0:
                    t = symbol_table.TokensInfo(m.modifiers_parser_contexts[0])
                else:
                    t = m.get_tokens_info()
                rewriter.insert_before_start(
                    t,
                    ("" if "@Override" in m.modifiers else "@Override\n    ")
                    + ("" if "public" in m.modifiers else "public ")
                )
                # Interface methods are public, so drop any narrower access modifier.
                for i, modifier in enumerate(m.modifiers):
                    if modifier in ("private", "protected"):
                        # old: m.parser_context.methodModifier(i)
                        rewriter.replace(symbol_table.TokensInfo(m.modifiers_parser_contexts[i]), "")

        def is_target_class_type(datatype, file_info):
            """Return True if ``datatype``, as written in ``file_info``, names a target class."""
            # The target classes live in self.package_name, not in the package being scanned.
            return any(
                (datatype == cn and file_info.has_imported_class(self.package_name, cn))
                or (self.package_name is not None and datatype == self.package_name + '.' + cn)
                for cn in self.class_names
            )

        def interface_type_name(file_info):
            """Return how ``file_info`` should spell the interface type, or None if it cannot."""
            if file_info.has_imported_package(self.package_name):
                return self.interface_name
            if self.package_name is None:
                # A type in the unnamed package has no qualified name to fall back on.
                return None
            return self.package_name + '.' + self.interface_name

        # Change variable types to the interface if only interface methods are used.
        for pn in program.packages:
            p: symbol_table.Package = program.packages[pn]
            for cls_name in p.classes:
                c: symbol_table.Class = p.classes[cls_name]
                fields_of_interest = {}
                for fn in c.fields:
                    f: symbol_table.Field = c.fields[fn]
                    if is_target_class_type(f.datatype, f.file_info) and "private" in f.modifiers:
                        fields_of_interest[f.name] = f

                for method_key in c.methods:
                    m: symbol_table.Method = c.methods[method_key]
                    vars_of_interest = {}
                    for item in m.body_local_vars_and_expr_names:
                        if isinstance(item, symbol_table.LocalVariable):
                            if is_target_class_type(item.datatype, c.file_info):
                                vars_of_interest[item.identifier] = item
                        if isinstance(item, symbol_table.MethodInvocation):
                            ids = item.dot_separated_identifiers
                            # Only "x.m()" and "this.x.m()" receivers are tracked; calling a
                            # non-interface method disqualifies the receiver from retyping.
                            if len(ids) == 2 or (len(ids) == 3 and ids[0] == "this"):
                                receiver, invoked = ids[-2], ids[-1]
                                if receiver in vars_of_interest:
                                    if invoked not in method_names:
                                        vars_of_interest.pop(receiver)
                                elif receiver in fields_of_interest and invoked not in method_names:
                                    fields_of_interest.pop(receiver)
                    if vars_of_interest:
                        typename = interface_type_name(m.file_info)
                        if typename is not None:
                            for var in vars_of_interest.values():
                                # old: var.parser_context.unannType()
                                rewriter.replace(symbol_table.TokensInfo(var.parser_context.typeType()), typename)

                if not fields_of_interest:
                    continue
                typename = interface_type_name(c.file_info)
                if typename is None:
                    continue
                for f in fields_of_interest.values():
                    if len(f.neighbor_names) == 0:
                        rewriter.replace(symbol_table.TokensInfo(f.parser_context.typeType()),
                                         typename)  # old: f.parser_context.unannType()
                    elif not any(nn in fields_of_interest for nn in f.neighbor_names):
                        # Split this declarator out of its multi-variable declaration and
                        # re-declare it separately with the interface type. (When a neighbor is
                        # also a candidate, the declaration is left unchanged.)
                        i = f.index_in_variable_declarators
                        var_ctxs = f.all_variable_declarator_contexts
                        t = symbol_table.TokensInfo(var_ctxs[i])
                        if i == 0:
                            t.stop = symbol_table.TokensInfo(var_ctxs[i + 1]).start - 1  # Include the ',' after it
                        else:
                            t.start = symbol_table.TokensInfo(var_ctxs[i - 1]).stop + 1  # Include the ',' before it
                        rewriter.replace(t, "")
                        rewriter.insert_after(
                            f.get_tokens_info(),
                            "\n    private " + typename + " " + f.name + (
                                " = " + f.initializer + ";" if f.initializer is not None else ";")
                        )

        # Create the interface; it lives in the same package as the extracted classes.
        method_declarations = []
        for method_key in self.method_keys:
            method_name = method_key[:method_key.find('(')]
            parameter_list = ", ".join(param[0] + " " + param[1] for param in method_parameters[method_key])
            method_declarations.append(
                "    " + method_returntypes[method_key] + " " + method_name + "(" + parameter_list + ");\n"
            )
        package_header = ("package " + self.package_name + ";\n\n") if self.package_name is not None else ""
        interface_file_content = (
                package_header
                + "public interface " + self.interface_name + "\n"
                + "{\n"
                + "".join(method_declarations)
                + "}\n"
        )

        interface_dir = os.path.dirname(self.interface_filename)
        if interface_dir:
            os.makedirs(interface_dir, exist_ok=True)
        with open(self.interface_filename, "w+", encoding='utf8', errors='ignore') as file:
            file.write(interface_file_content)

        rewriter.apply()
        return True
Example #2
0
    def do_refactor(self):
        """Push the field ``self.field_name`` down from ``self.superclass_name``.

        The field declaration is removed from the superclass (in package
        ``self.package_name``). It is then re-declared at the top of the body
        of each direct subclass listed in ``self.class_names``, or of every
        direct subclass when ``self.class_names`` is empty.

        Returns:
            True if the refactoring was applied. False if any of these hold:
            the superclass or field cannot be found; ``pre_condition_check``
            rejects the refactoring; a target subclass already declares a field
            with the same name; or the field is accessed through a reference
            whose type is the superclass or a subclass that will not receive
            the field.
        """
        program = symbol_table.get_program(self.source_filenames,
                                           print_status=False)
        if (self.package_name not in program.packages
                or self.superclass_name not in program.packages[self.package_name].classes):
            print(f"Cannot push-down field from {self.superclass_name}")
            return False
        superclass: symbol_table.Class = program.packages[
            self.package_name].classes[self.superclass_name]
        if not self.pre_condition_check(program, superclass):
            print(f"Cannot push-down field from {superclass.name}")
            return False
        if self.field_name not in superclass.fields:
            print(f"Cannot push-down field from {superclass.name}")
            return False

        def refers_to(type_name, file_info, package_name, class_name):
            """Return True if ``type_name``, as written in ``file_info``, denotes ``package_name.class_name``."""
            if type_name == class_name and file_info.has_imported_class(package_name, class_name):
                return True
            # A None (default) package has no qualified form; avoid concatenating None.
            return package_name is not None and type_name == package_name + '.' + class_name

        # Partition the direct subclasses into those receiving the field and the rest.
        other_derived_classes = []
        classes_to_add_to = []
        for pn in program.packages:
            p: symbol_table.Package = program.packages[pn]
            for cn in p.classes:
                c: symbol_table.Class = p.classes[cn]
                if not refers_to(c.superclass_name, c.file_info, self.package_name, self.superclass_name):
                    continue
                if len(self.class_names) == 0 or cn in self.class_names:
                    if self.field_name in c.fields:
                        print("some classes have same variable")
                        return False
                    classes_to_add_to.append(c)
                else:
                    other_derived_classes.append(c)

        def is_superclass_or_other_derived(type_name, file_info):
            """True if ``type_name`` is the superclass or a subclass that will not receive the field."""
            return (refers_to(type_name, file_info, self.package_name, self.superclass_name)
                    or any(refers_to(type_name, file_info, o.package_name, o.name)
                           for o in other_derived_classes))

        # Check if the field is used from the superclass or other derived classes
        for pn in program.packages:
            p: symbol_table.Package = program.packages[pn]
            for cn in p.classes:
                c: symbol_table.Class = p.classes[cn]
                # Fields whose *type* is the superclass (or a non-receiving subclass).
                suspect_fields = [
                    c.fields[fn].name for fn in c.fields
                    if is_superclass_or_other_derived(c.fields[fn].datatype, c.file_info)
                ]
                for mk in c.methods:
                    m: symbol_table.Method = c.methods[mk]
                    suspect_locals = []
                    for item in m.body_local_vars_and_expr_names:
                        if isinstance(item, symbol_table.LocalVariable):
                            if is_superclass_or_other_derived(item.datatype, c.file_info):
                                suspect_locals.append(item.identifier)
                        elif isinstance(item, symbol_table.ExpressionName):
                            ids = item.dot_separated_identifiers
                            if ids[-1] != self.field_name:
                                continue
                            # "x.field" through a suspect local or field ...
                            if len(ids) == 2 and (ids[-2] in suspect_locals or ids[-2] in suspect_fields):
                                return False
                            # ... or "this.x.field" through a suspect field.
                            if len(ids) == 3 and ids[0] == "this" and ids[-2] in suspect_fields:
                                return False

        rewriter = symbol_table.Rewriter(program, self.filename_mapping)

        field = superclass.fields[self.field_name]
        if len(field.neighbor_names) == 0:
            rewriter.replace(field.get_tokens_info(), "")
            # Have to remove the modifiers too, because of the new grammar.
            for mod_ctx in field.modifiers_parser_contexts:
                rewriter.replace(symbol_table.TokensInfo(mod_ctx), "")
        else:
            i = field.index_in_variable_declarators
            var_ctxs = field.all_variable_declarator_contexts
            to_remove = symbol_table.TokensInfo(var_ctxs[i])
            if i == 0:
                to_remove.stop = symbol_table.TokensInfo(
                    var_ctxs[i + 1]).start - 1  # Include the ',' after it
            else:
                to_remove.start = symbol_table.TokensInfo(
                    var_ctxs[i - 1]).stop + 1  # Include the ',' before it
            rewriter.replace(to_remove, "")

        # Keep the field's original access level in each subclass (private stays private).
        if "public" in field.modifiers:
            modifier = "public "
        elif "protected" in field.modifiers:
            modifier = "protected "
        elif "private" in field.modifiers:
            modifier = "private "
        else:
            modifier = ""
        for c in classes_to_add_to:
            c_body_start = symbol_table.TokensInfo(
                c.parser_context.classBody())
            c_body_start.stop = c_body_start.start  # Start and stop both point to the '{'
            rewriter.insert_after(
                c_body_start,
                ("\n    " + modifier + field.datatype + " " + self.field_name +
                 ((" = " + field.initializer)
                  if field.initializer is not None else "") + ";"))

        rewriter.apply()
        return True
Example #3
0
    def do_refactor(self):
        """Pull the field ``self.field_name`` up from ``self.class_name`` into its superclass.

        Every direct subclass of the superclass that declares a field with the
        same name and datatype has that declaration removed. A single
        declaration is added at the top of the superclass body. Its access is
        public if any copy was public, protected if every copy was protected
        or private, and package-private otherwise. Field initializers are
        moved into the owning class's constructors, or into a newly generated
        no-argument constructor when that class declares none.

        Returns:
            True if the refactoring was applied. False (after logging an error)
            when the inputs are invalid, the class has no superclass in
            ``self.package_name``, the superclass already has the field, or no
            matching fields were found.
        """
        program = symbol_table.get_program(self.source_filenames, print_status=False)
        if (
                self.package_name not in program.packages
                or self.class_name not in program.packages[self.package_name].classes
                or self.field_name not in program.packages[self.package_name].classes[self.class_name].fields
        ):
            logger.error("One or more inputs are not valid.")
            return False

        package_classes = program.packages[self.package_name].classes
        _class: symbol_table.Class = package_classes[self.class_name]
        if _class.superclass_name is None:
            logger.error("Super class is none.")
            return False

        superclass_name = _class.superclass_name
        if not package_classes.get(superclass_name):
            logger.error("Super class package is none!")
            return False

        superclass: symbol_table.Class = package_classes[superclass_name]
        superclass_body_start = symbol_table.TokensInfo(superclass.parser_context.classBody())
        superclass_body_start.stop = superclass_body_start.start  # Start and stop both point to the '{'

        if self.field_name in superclass.fields:
            logger.error("Field is in superclass fields.")
            return False

        datatype = _class.fields[self.field_name].datatype
        superclass_qualified = (self.package_name + '.' + superclass_name
                                if self.package_name is not None else None)

        # Collect the same-named, same-typed field from every direct subclass of the superclass.
        fields_to_remove = []
        for pn in program.packages:
            p: symbol_table.Package = program.packages[pn]
            for cn in p.classes:
                c: symbol_table.Class = p.classes[cn]
                extends_superclass = (
                    c is _class  # known subclass: superclass_name was read from it
                    or (c.superclass_name == superclass_name
                        and c.file_info.has_imported_class(self.package_name, superclass_name))
                    # Fully-qualified "extends pkg.Super".
                    or (superclass_qualified is not None and c.superclass_name == superclass_qualified)
                )
                if (extends_superclass and self.field_name in c.fields
                        and c.fields[self.field_name].datatype == datatype):
                    fields_to_remove.append(c.fields[self.field_name])

        if len(fields_to_remove) == 0:
            logger.error("No fields to remove.")
            return False

        is_public = any("public" in f.modifiers for f in fields_to_remove)
        is_protected = all("protected" in f.modifiers or "private" in f.modifiers for f in fields_to_remove)

        rewriter = symbol_table.Rewriter(program, self.filename_mapping)

        modifier = "public " if is_public else ("protected " if is_protected else "")
        rewriter.insert_after(superclass_body_start, "\n\t" + modifier + datatype + " " + self.field_name + ";")

        for field in fields_to_remove:
            field: symbol_table.Field = field
            if len(field.neighbor_names) == 0:
                rewriter.replace(field.get_tokens_info(), "")
                # Have to remove the modifiers too, because of the new grammar.
                for mod_ctx in field.modifiers_parser_contexts:
                    rewriter.replace(symbol_table.TokensInfo(mod_ctx), "")
            else:
                i = field.index_in_variable_declarators
                var_ctxs = field.all_variable_declarator_contexts
                to_remove = symbol_table.TokensInfo(var_ctxs[i])
                if i == 0:
                    to_remove.stop = symbol_table.TokensInfo(
                        var_ctxs[i + 1]).start - 1  # Include the ',' after it
                else:
                    to_remove.start = symbol_table.TokensInfo(
                        var_ctxs[i - 1]).stop + 1  # Include the ',' before it
                rewriter.replace(to_remove, "")

            # Add initializer to class constructor if initializer exists in field declaration
            if field.initializer is None:
                continue
            owner: symbol_table.Class = program.packages[field.package_name].classes[field.class_name]

            initializer = field.initializer
            # Todo: Requires better handling. The initializer text seems to have lost its
            # whitespace (e.g. "newArrayList<>()"), so re-insert the space after "new".
            # Only the initializer is touched, so the field name and datatype are never mangled.
            if 'new' in initializer and '()' in initializer:
                initializer = initializer.replace('new', 'new ')
            if initializer.startswith('{'):
                # A bare array initializer is only legal in a declaration; turn it into an array creation.
                initializer = "new " + field.datatype + " " + initializer
            initializer_statement = field.name + " = " + initializer + ";"

            has_constructor = False
            for class_body_decl in owner.parser_context.classBody().getChildren():
                if class_body_decl.getText() in ['{', '}']:
                    continue
                member_decl = class_body_decl.memberDeclaration()
                if member_decl is None:
                    continue
                constructor = member_decl.constructorDeclaration()
                if constructor is None:
                    continue
                has_constructor = True
                body = constructor.constructorBody  # Start token = '{'
                body_children = list(body.getChildren())
                first_statement = body_children[1] if len(body_children) > 2 else None
                first_text = first_statement.getText() if first_statement is not None else ""
                if first_text.startswith("this("):
                    # Delegates to another constructor, which performs the assignment itself.
                    continue
                if first_text.startswith("super("):
                    # An explicit super(...) call must remain the first statement.
                    anchor = symbol_table.TokensInfo(first_statement)
                else:
                    anchor = symbol_table.TokensInfo(body)
                    anchor.stop = anchor.start  # Start and stop both point to the '{'
                rewriter.insert_after(anchor, "\n\t" + initializer_statement)

            if not has_constructor:
                # Like Java's default constructor, reuse the class's access modifier (if any).
                access = next((mod for mod in owner.modifiers if mod in ("public", "protected", "private")), None)
                prefix = access + " " if access is not None else ""
                body_start = symbol_table.TokensInfo(owner.parser_context.classBody())
                body_start.stop = body_start.start  # Start and stop both point to the '{'
                rewriter.insert_after(body_start,
                                      "\n\t" + prefix + owner.name + "() { " + initializer_statement + " }")

        rewriter.apply()

        # Todo: check for multilevel inheritance recursively (pull up again from _class.superclass_name).
        return True