示例#1
0
class MethodUsageListener(UtilsListener):
    """
    Appends a ``new <target_class>()`` argument to every call of the given
    methods and to matching ``new`` expressions, then saves the rewritten
    file whenever a class body is exited.
    """

    def __init__(self, filename: str, methods, target_class: str):
        """
        Args:
            filename: file being rewritten; passed to ``UtilsListener`` and
                later to ``save``.
            methods: iterable of method descriptors; only their ``name``
                attribute is read here.
            target_class: class whose fresh instance becomes the extra
                argument.
        """
        super().__init__(filename)
        self.methods = methods
        self.method_names = {m.name for m in methods}
        self.rewriter = None
        self.target_class = target_class

    def _extra_argument(self, has_arguments: bool) -> str:
        """Return the text spliced in front of the closing ``)``."""
        new_target = f"new {self.target_class}()"
        # A non-empty argument list needs a separating comma.
        return f", {new_target}" if has_arguments else new_target

    def enterCompilationUnit(self, ctx: JavaParser.CompilationUnitContext):
        super().enterCompilationUnit(ctx)
        # One rewriter per file, over the parser's token stream.
        self.rewriter = TokenStreamRewriter(ctx.parser.getTokenStream())

    def enterClassCreatorRest(self, ctx: JavaParser.ClassCreatorRestContext):
        """Add the target instance to the arguments of a ``new X(...)``."""
        if type(ctx.parentCtx) is JavaParser.CreatorContext:
            created = ctx.parentCtx.createdName().IDENTIFIER()[0].getText()
            if created not in self.method_names:
                return
        # NOTE(review): class creators under any other parent context are
        # rewritten unconditionally — confirm this is intended.
        arguments = ctx.arguments()
        text = self._extra_argument(arguments.expressionList() is not None)
        self.rewriter.insertBeforeIndex(arguments.RPAREN().symbol.tokenIndex,
                                        text)

    def exitMethodCall(self, ctx: JavaParser.MethodCallContext):
        """Add the target instance to calls of the tracked methods."""
        super().exitMethodCall(ctx)
        identifier = ctx.IDENTIFIER()
        # ``this(...)`` and ``super(...)`` calls have no IDENTIFIER token;
        # only THIS() used to be checked, so ``super(...)`` crashed on
        # ``None.getText()``.
        if ctx.THIS() is not None or identifier is None:
            return
        if identifier.getText() in self.method_names:
            text = self._extra_argument(ctx.expressionList() is not None)
            self.rewriter.insertBeforeIndex(ctx.RPAREN().symbol.tokenIndex,
                                            text)

    def exitClassBody(self, ctx: JavaParser.ClassBodyContext):
        super().exitClassBody(ctx)
        # Persist the rewritten token stream back to the file.
        save(self.rewriter, self.filename)
示例#2
0
    def testInsertBeforeIndexZero(self):
        """An insert before token 0 appears ahead of all original text."""
        tokens = CommonTokenStream(lexer=TestLexer(InputStream('abc')))
        tokens.fill()
        rw = TokenStreamRewriter(tokens=tokens)
        rw.insertBeforeIndex(0, '0')

        self.assertEqual(rw.getDefaultText(), '0abc')
    def testInsertBeforeIndexZero(self):
        """An insert before token 0 appears ahead of all original text."""
        char_stream = InputStream('abc')
        lexer = TestLexer(char_stream)
        stream = CommonTokenStream(lexer=lexer)
        stream.fill()
        rewriter = TokenStreamRewriter(tokens=stream)
        rewriter.insertBeforeIndex(0, '0')

        # assertEqual: the assertEquals alias was removed in Python 3.12.
        self.assertEqual(rewriter.getDefaultText(), '0abc')
    def testDropPrevCoveredInsert(self):
        """Insert at 1, then replace 1..2: output is insert text + replacement."""
        tokens = CommonTokenStream(lexer=TestLexer(InputStream('abc')))
        tokens.fill()
        rw = TokenStreamRewriter(tokens=tokens)

        rw.insertBeforeIndex(1, 'foo')
        rw.replaceRange(1, 2, 'foo')

        expected = 'afoofoo'
        self.assertEqual(expected, rw.getDefaultText())
示例#5
0
    def testCombineInserts(self):
        """Two inserts at one index: the later insert is emitted first."""
        tokens = CommonTokenStream(lexer=TestLexer(InputStream('abc')))
        tokens.fill()
        rw = TokenStreamRewriter(tokens=tokens)

        for text in ('x', 'y'):
            rw.insertBeforeIndex(0, text)

        self.assertEqual('yxabc', rw.getDefaultText())
    def testReplaceThenInsertBeforeLastIndex(self):
        """An insert placed after a replace at the same index precedes it."""
        tokens = CommonTokenStream(lexer=TestLexer(InputStream('abc')))
        tokens.fill()
        rw = TokenStreamRewriter(tokens=tokens)

        last = 2
        rw.replaceIndex(last, 'x')
        rw.insertBeforeIndex(last, 'y')

        self.assertEqual('abyx', rw.getDefaultText())
    def testCombineInsertOnLeftWithDelete(self):
        """An insert at the left edge of a deleted range survives the delete."""
        char_stream = InputStream('abc')
        lexer = TestLexer(char_stream)
        stream = CommonTokenStream(lexer=lexer)
        stream.fill()
        rewriter = TokenStreamRewriter(tokens=stream)

        rewriter.delete('default', 0, 2)
        rewriter.insertBeforeIndex(0, 'z')

        # assertEqual: the assertEquals alias was removed in Python 3.12.
        self.assertEqual('z', rewriter.getDefaultText())
    def testLeaveAloneDisjointInsert2(self):
        """An insert outside a replaced range is kept unchanged."""
        char_stream = InputStream('abcc')
        lexer = TestLexer(char_stream)
        stream = CommonTokenStream(lexer=lexer)
        stream.fill()
        rewriter = TokenStreamRewriter(tokens=stream)

        rewriter.replaceRange(2, 3, 'foo')
        rewriter.insertBeforeIndex(1, 'x')

        # assertEqual: the assertEquals alias was removed in Python 3.12.
        self.assertEqual('axbfoo', rewriter.getDefaultText())
示例#9
0
    def testReplaceThenInsertBeforeLastIndex(self):
        """An insert placed after a replace at the same index precedes it."""
        char_stream = InputStream('abc')
        lexer = TestLexer(char_stream)
        stream = CommonTokenStream(lexer=lexer)
        stream.fill()
        rewriter = TokenStreamRewriter(tokens=stream)

        rewriter.replaceIndex(2, 'x')
        rewriter.insertBeforeIndex(2, 'y')

        # assertEqual: the assertEquals alias was removed in Python 3.12.
        self.assertEqual('abyx', rewriter.getDefaultText())
    def testInsertBeforeTokenThenDeleteThatToken(self):
        """Insert at 1, then replace 1..2: output is insert text + replacement."""
        char_stream = InputStream('abc')
        lexer = TestLexer(char_stream)
        stream = CommonTokenStream(lexer=lexer)
        stream.fill()
        rewriter = TokenStreamRewriter(tokens=stream)

        rewriter.insertBeforeIndex(1, 'foo')
        rewriter.replaceRange(1, 2, 'foo')

        # assertEqual: the assertEquals alias was removed in Python 3.12.
        self.assertEqual('afoofoo', rewriter.getDefaultText())
示例#11
0
    def testReplaceRangeThenInsertAtLeftEdge(self):
        """An insert at the first index of a replaced range lands before it."""
        tokens = CommonTokenStream(lexer=TestLexer(InputStream('abcccba')))
        tokens.fill()
        rw = TokenStreamRewriter(tokens=tokens)

        left, right = 2, 4
        rw.replaceRange(left, right, 'x')
        rw.insertBeforeIndex(left, 'y')

        self.assertEqual('abyxba', rw.getDefaultText())
    def testInsertThenReplaceSameIndex(self):
        """A replace at an index keeps an earlier insert at that index."""
        char_stream = InputStream('abc')
        lexer = TestLexer(char_stream)
        stream = CommonTokenStream(lexer=lexer)
        stream.fill()
        rewriter = TokenStreamRewriter(tokens=stream)

        rewriter.insertBeforeIndex(0, '0')
        rewriter.replaceIndex(0, 'x')

        # assertEqual: the assertEquals alias was removed in Python 3.12.
        self.assertEqual('0xbc', rewriter.getDefaultText())
示例#13
0
    def test2InsertBeforeAfterMiddleIndex(self):
        """Inserts before and after the middle token bracket it."""
        tokens = CommonTokenStream(lexer=TestLexer(InputStream('abc')))
        tokens.fill()
        rw = TokenStreamRewriter(tokens=tokens)

        middle = 1
        rw.insertBeforeIndex(middle, 'x')
        rw.insertAfter(middle, 'x')

        self.assertEqual(rw.getDefaultText(), 'axbxc')
示例#14
0
    def testCombineInsertOnLeftWithDelete(self):
        """An insert at the left edge of a deleted range survives the delete."""
        tokens = CommonTokenStream(lexer=TestLexer(InputStream('abc')))
        tokens.fill()
        rw = TokenStreamRewriter(tokens=tokens)

        rw.delete('default', 0, 2)
        rw.insertBeforeIndex(0, 'z')

        result = rw.getDefaultText()
        self.assertEqual('z', result)
    def testReplaceRangeThenInsertAtLeftEdge(self):
        """An insert at the first index of a replaced range lands before it."""
        char_stream = InputStream('abcccba')
        lexer = TestLexer(char_stream)
        stream = CommonTokenStream(lexer=lexer)
        stream.fill()
        rewriter = TokenStreamRewriter(tokens=stream)

        rewriter.replaceRange(2, 4, 'x')
        rewriter.insertBeforeIndex(2, 'y')

        # assertEqual: the assertEquals alias was removed in Python 3.12.
        self.assertEqual('abyxba', rewriter.getDefaultText())
示例#16
0
    def testInsertBeforeTokenThenDeleteThatToken(self):
        """Insert at 1, then replace 1..2: output is insert text + replacement."""
        tokens = CommonTokenStream(lexer=TestLexer(InputStream('abc')))
        tokens.fill()
        rw = TokenStreamRewriter(tokens=tokens)

        payload = 'foo'
        rw.insertBeforeIndex(1, payload)
        rw.replaceRange(1, 2, payload)

        self.assertEqual('afoofoo', rw.getDefaultText())
    def testCombineInserts(self):
        """Two inserts at one index: the later insert is emitted first."""
        char_stream = InputStream('abc')
        lexer = TestLexer(char_stream)
        stream = CommonTokenStream(lexer=lexer)
        stream.fill()
        rewriter = TokenStreamRewriter(tokens=stream)

        rewriter.insertBeforeIndex(0, 'x')
        rewriter.insertBeforeIndex(0, 'y')

        # assertEqual: the assertEquals alias was removed in Python 3.12.
        self.assertEqual('yxabc', rewriter.getDefaultText())
示例#18
0
    def testLeaveAloneDisjointInsert2(self):
        """An insert outside a replaced range is kept unchanged."""
        tokens = CommonTokenStream(lexer=TestLexer(InputStream('abcc')))
        tokens.fill()
        rw = TokenStreamRewriter(tokens=tokens)

        rw.replaceRange(2, 3, 'foo')
        rw.insertBeforeIndex(1, 'x')

        result = rw.getDefaultText()
        self.assertEqual('axbfoo', result)
示例#19
0
    def testDropPrevCoveredInsert(self):
        """Insert at 1, then replace 1..2: output is insert text + replacement."""
        char_stream = InputStream('abc')
        lexer = TestLexer(char_stream)
        stream = CommonTokenStream(lexer=lexer)
        stream.fill()
        rewriter = TokenStreamRewriter(tokens=stream)

        rewriter.insertBeforeIndex(1, 'foo')
        rewriter.replaceRange(1, 2, 'foo')

        # assertEqual: the assertEquals alias was removed in Python 3.12.
        self.assertEqual('afoofoo', rewriter.getDefaultText())
示例#20
0
    def testCombineInsertOnLeftWithReplace(self):
        """An insert at the left edge of a replaced range precedes the replacement."""
        tokens = CommonTokenStream(lexer=TestLexer(InputStream('abc')))
        tokens.fill()
        rw = TokenStreamRewriter(tokens=tokens)

        rw.replaceRange(0, 2, 'foo')
        rw.insertBeforeIndex(0, 'z')

        result = rw.getDefaultText()
        self.assertEqual('zfoo', result)
    def test2InsertBeforeAfterMiddleIndex(self):
        """Inserts before and after the middle token bracket it."""
        char_stream = InputStream('abc')
        lexer = TestLexer(char_stream)
        stream = CommonTokenStream(lexer=lexer)
        stream.fill()
        rewriter = TokenStreamRewriter(tokens=stream)

        rewriter.insertBeforeIndex(1, 'x')
        rewriter.insertAfter(1, 'x')

        # assertEqual: the assertEquals alias was removed in Python 3.12.
        self.assertEqual(rewriter.getDefaultText(), 'axbxc')
示例#22
0
    def testInsertThenReplaceSameIndex(self):
        """A replace at an index keeps an earlier insert at that index."""
        tokens = CommonTokenStream(lexer=TestLexer(InputStream('abc')))
        tokens.fill()
        rw = TokenStreamRewriter(tokens=tokens)

        first = 0
        rw.insertBeforeIndex(first, '0')
        rw.replaceIndex(first, 'x')

        self.assertEqual('0xbc', rw.getDefaultText())
    def testCombineInsertOnLeftWithReplace(self):
        """An insert at the left edge of a replaced range precedes the replacement."""
        char_stream = InputStream('abc')
        lexer = TestLexer(char_stream)
        stream = CommonTokenStream(lexer=lexer)
        stream.fill()
        rewriter = TokenStreamRewriter(tokens=stream)

        rewriter.replaceRange(0, 2, 'foo')
        rewriter.insertBeforeIndex(0, 'z')

        # assertEqual: the assertEquals alias was removed in Python 3.12.
        self.assertEqual('zfoo', rewriter.getDefaultText())
示例#24
0
    def testDisjointInserts(self):
        """Inserts at distinct indices each land before their own token."""
        char_stream = InputStream('abc')
        lexer = TestLexer(char_stream)
        stream = CommonTokenStream(lexer=lexer)
        stream.fill()
        rewriter = TokenStreamRewriter(tokens=stream)

        rewriter.insertBeforeIndex(1, 'x')
        rewriter.insertBeforeIndex(2, 'y')
        rewriter.insertBeforeIndex(0, 'z')

        # assertEqual: the assertEquals alias was removed in Python 3.12.
        self.assertEqual('zaxbyc', rewriter.getDefaultText())
    def test2ReplaceMiddleIndex1InsertBefore(self):
        """The last of two replaces at one index wins; a prior insert is kept."""
        char_stream = InputStream('abc')
        lexer = TestLexer(char_stream)
        stream = CommonTokenStream(lexer=lexer)
        stream.fill()
        rewriter = TokenStreamRewriter(tokens=stream)

        rewriter.insertBeforeIndex(0, "_")
        rewriter.replaceIndex(1, 'x')
        rewriter.replaceIndex(1, 'y')

        # assertEqual: the assertEquals alias was removed in Python 3.12.
        self.assertEqual('_ayc', rewriter.getDefaultText())
    def testDisjointInserts(self):
        """Inserts at distinct indices each land before their own token."""
        tokens = CommonTokenStream(lexer=TestLexer(InputStream('abc')))
        tokens.fill()
        rw = TokenStreamRewriter(tokens=tokens)

        for index, text in ((1, 'x'), (2, 'y'), (0, 'z')):
            rw.insertBeforeIndex(index, text)

        self.assertEqual('zaxbyc', rw.getDefaultText())
示例#27
0
    def test2ReplaceMiddleIndex1InsertBefore(self):
        """The last of two replaces at one index wins; a prior insert is kept."""
        tokens = CommonTokenStream(lexer=TestLexer(InputStream('abc')))
        tokens.fill()
        rw = TokenStreamRewriter(tokens=tokens)

        rw.insertBeforeIndex(0, "_")
        for replacement in ('x', 'y'):
            rw.replaceIndex(1, replacement)

        self.assertEqual('_ayc', rw.getDefaultText())
    def testReplaceThenDeleteMiddleIndex(self):
        """An insert strictly inside a replaced range raises ValueError."""
        char_stream = InputStream('abc')
        lexer = TestLexer(char_stream)
        stream = CommonTokenStream(lexer=lexer)
        stream.fill()
        rewriter = TokenStreamRewriter(tokens=stream)

        rewriter.replaceRange(0, 2, 'x')
        rewriter.insertBeforeIndex(1, '0')

        with self.assertRaises(ValueError) as ctx:
            rewriter.getDefaultText()
        # Python 3 exceptions have no ``.message``; str() yields the text.
        self.assertEqual(
            'insert op <InsertBeforeOp@[@1,1:1=\'b\',<2>,1:1]:"0"> within boundaries of previous <ReplaceOp@[@0,0:0=\'a\',<1>,1:0]..[@2,2:2=\'c\',<3>,1:2]:"x">',
            str(ctx.exception))
示例#29
0
    def testPreservesOrderOfContiguousInserts(self):
        """
        Test for fix for: https://github.com/antlr/antlr4/issues/550

        Wrapping each token in before/after inserts keeps the tags ordered.
        """
        tokens = CommonTokenStream(lexer=TestLexer(InputStream('aa')))
        tokens.fill()
        rw = TokenStreamRewriter(tokens=tokens)

        for index in (0, 1):
            rw.insertBeforeIndex(index, '<b>')
            rw.insertAfter(index, '</b>')

        self.assertEqual('<b>a</b><b>a</b>', rw.getDefaultText())
示例#30
0
    def testReplaceThenDeleteMiddleIndex(self):
        """An insert strictly inside a replaced range raises ValueError."""
        tokens = CommonTokenStream(lexer=TestLexer(InputStream('abc')))
        tokens.fill()
        rw = TokenStreamRewriter(tokens=tokens)

        rw.replaceRange(0, 2, 'x')
        rw.insertBeforeIndex(1, '0')

        with self.assertRaises(ValueError) as raised:
            rw.getDefaultText()
        expected = 'insert op <InsertBeforeOp@[@1,1:1=\'b\',<2>,1:1]:"0"> within boundaries of previous <ReplaceOp@[@0,0:0=\'a\',<1>,1:0]..[@2,2:2=\'c\',<3>,1:2]:"x">'
        self.assertEqual(expected, str(raised.exception))
    def testPreservesOrderOfContiguousInserts(self):
        """
        Test for fix for: https://github.com/antlr/antlr4/issues/550

        Wrapping each token in before/after inserts keeps the tags ordered.
        """
        char_stream = InputStream('aa')
        lexer = TestLexer(char_stream)
        stream = CommonTokenStream(lexer=lexer)
        stream.fill()
        rewriter = TokenStreamRewriter(tokens=stream)

        rewriter.insertBeforeIndex(0, '<b>')
        rewriter.insertAfter(0, '</b>')
        rewriter.insertBeforeIndex(1, '<b>')
        rewriter.insertAfter(1, '</b>')

        # assertEqual: the assertEquals alias was removed in Python 3.12.
        self.assertEqual('<b>a</b><b>a</b>', rewriter.getDefaultText())
    def testReplaceRangeThenInsertAtRightEdge(self):
        """An insert at the last index of a replaced range raises ValueError."""
        char_stream = InputStream('abcccba')
        lexer = TestLexer(char_stream)
        stream = CommonTokenStream(lexer=lexer)
        stream.fill()
        rewriter = TokenStreamRewriter(tokens=stream)

        rewriter.replaceRange(2, 4, 'x')
        rewriter.insertBeforeIndex(4, 'y')

        with self.assertRaises(ValueError) as ctx:
            rewriter.getDefaultText()
        # Python 3 exceptions have no ``.message``; str() yields the text.
        msg = str(ctx.exception)
        self.assertEqual(
            "insert op <InsertBeforeOp@[@4,4:4='c',<3>,1:4]:\"y\"> within boundaries of previous <ReplaceOp@[@2,2:2='c',<3>,1:2]..[@4,4:4='c',<3>,1:4]:\"x\">",
            msg)
示例#33
0
    def testReplaceRangeThenInsertAtRightEdge(self):
        """An insert at the last index of a replaced range raises ValueError."""
        tokens = CommonTokenStream(lexer=TestLexer(InputStream('abcccba')))
        tokens.fill()
        rw = TokenStreamRewriter(tokens=tokens)

        left, right = 2, 4
        rw.replaceRange(left, right, 'x')
        rw.insertBeforeIndex(right, 'y')

        with self.assertRaises(ValueError) as raised:
            rw.getDefaultText()
        expected = "insert op <InsertBeforeOp@[@4,4:4='c',<3>,1:4]:\"y\"> within boundaries of previous <ReplaceOp@[@2,2:2='c',<3>,1:2]..[@4,4:4='c',<3>,1:4]:\"x\">"
        self.assertEqual(expected, str(raised.exception))
示例#34
0
class FieldUsageListener(UtilsListener):
    """
    FieldUsageListener finds all the usages of a specified field ``f`` of a
    class ``c`` in package ``pkg`` and rewrites them for moving the field to
    ``target_class``:

    * the public class of a file that is not the target gets an import of
      the target class (unless already imported);
    * the target class receives the field declaration and a getter/setter;
    * the source class loses the field's getter/setter methods;
    * methods/constructors touching the field (directly or via its
      accessors) get an extra ``<target_class> $$target`` parameter and the
      access is redirected to that parameter.
    """

    def __init__(self, filename: str, source_class: str, source_package: str,
                 target_class: str, target_package: str, field_name: str,
                 field_candidates: set, field_tobe_moved: Field):
        super(FieldUsageListener, self).__init__(filename)
        self.source_class = source_class
        self.source_package = source_package
        self.target_class = target_class
        self.target_package = target_package
        self.field_name = field_name
        self.has_imported_source = False
        self.has_imported_target = False
        # parser contexts of every rewritten usage
        self.usages = []
        # current class name is the public class in each file.
        self.current_class_name = ""
        # identifiers through which the field may be reached; presumably
        # fields of type Source — TODO confirm against caller.
        self.field_candidates = field_candidates
        self.rewriter = None
        # this represents the field to be added in target i.e. public int a;
        self.field_tobe_moved = field_tobe_moved
        # methods that received the extra target parameter
        self.methods_tobe_updated = []

    @staticmethod
    def _capitalize(name: str) -> str:
        """Upper-case the first character of *name* (``age`` -> ``Age``)."""
        return name[:1].upper() + name[1:]

    def enterCompilationUnit(self, ctx: JavaParser.CompilationUnitContext):
        super().enterCompilationUnit(ctx)
        self.rewriter = TokenStreamRewriter(ctx.parser.getTokenStream())

    def enterClassDeclaration(self, ctx: JavaParser.ClassDeclarationContext):
        super().enterClassDeclaration(ctx)

        # Only the file's public class is processed. Inspect every modifier
        # defensively: a class may have none, the first one may be an
        # annotation, and some parent contexts (e.g. of nested classes)
        # have no classOrInterfaceModifier() at all.
        get_modifiers = getattr(ctx.parentCtx, "classOrInterfaceModifier",
                                None)
        if get_modifiers is None or not any(
                modifier.getText() == "public"
                for modifier in get_modifiers()):
            return
        self.current_class_name = ctx.IDENTIFIER().getText()

        # NOTE(review): this checks self.package.name rather than
        # self.source_package — confirm that is intended.
        self.has_imported_source = self.file_info.has_imported_package(self.package.name) or \
                                   self.file_info.has_imported_class(self.package.name, self.source_class)

        # import target if we're not in Target and have not imported before
        self.has_imported_target = self.file_info.has_imported_package(self.target_package) or \
                                   self.file_info.has_imported_class(self.target_package, self.target_class)
        if self.current_class_name != self.target_class and \
                not self.has_imported_target:
            self.rewriter.insertBeforeIndex(
                ctx.parentCtx.start.tokenIndex,
                f"import {self.target_package}.{self.target_class};\n")

    def enterClassBody(self, ctx: JavaParser.ClassBodyContext):
        """In the target class, declare the moved field and its accessors."""
        super().enterClassBody(ctx)
        if self.current_class_name != self.target_class:
            return
        field = self.field_tobe_moved
        if field is None:
            # Field not resolved (yet); nothing to declare.
            return

        replacement_text = ""
        if field.name == self.field_name:
            replacement_text = "".join(f"{mod} " for mod in field.modifiers)
            replacement_text += f"{field.datatype} {field.name};"
        self.rewriter.insertAfter(ctx.start.tokenIndex,
                                  f"\n\t{replacement_text}\n")

        # add getter and setter
        name = field.name
        method_name = self._capitalize(name)
        datatype = field.datatype

        getter = f"\tpublic {datatype} get{method_name}() {{ return this.{name}; }}\n"
        setter = f"\tpublic void set{method_name}({datatype} {name}) {{ this.{name} = {name}; }}\n"
        # One insert: successive inserts at the same index are emitted in
        # reverse order, which would put the setter first.
        self.rewriter.insertBeforeIndex(ctx.stop.tokenIndex, getter + setter)

    def exitFieldDeclaration(self, ctx: JavaParser.FieldDeclarationContext):
        """Record the field to move when walking the source class."""
        super().exitFieldDeclaration(ctx)
        if self.current_class_name != self.source_class:
            return

        if self.field_tobe_moved is None:
            # first declarator of the declaration -> its identifier
            field = self.package.classes[self.current_class_name].fields[
                ctx.variableDeclarators().children[0].children[0].IDENTIFIER(
                ).getText()]
            if field.name == self.field_name:
                self.field_tobe_moved = field

    def exitClassBody(self, ctx: JavaParser.ClassBodyContext):
        super().exitClassBody(ctx)
        save(self.rewriter, self.filename)

    def exitMethodDeclaration(self, ctx: JavaParser.MethodDeclarationContext):
        super().exitMethodDeclaration(ctx)
        # we will remove getter and setter from source
        # and add it to target so there is no need to
        # find usages there
        if self.current_class_name == self.source_class and \
                self.is_method_getter_or_setter(ctx.IDENTIFIER().getText()):
            self.rewriter.replaceRange(
                ctx.parentCtx.parentCtx.start.tokenIndex,
                ctx.parentCtx.parentCtx.stop.tokenIndex, "")

    def exitConstructorDeclaration(
            self, ctx: JavaParser.ConstructorDeclarationContext):
        self.current_method.name = ctx.IDENTIFIER().getText()
        self.current_method.returntype = self.current_method.class_name
        self.handleMethodUsage(ctx, True)
        super().exitConstructorDeclaration(ctx)

    def exitMethodBody(self, ctx: JavaParser.MethodBodyContext):
        super().exitMethodBody(ctx)
        self.handleMethodUsage(ctx, False)

    def handleMethodUsage(self, ctx, is_constructor: bool):
        """
        Rewrite usages of the moved field inside one method or constructor.

        Args:
            ctx: the constructor declaration when *is_constructor* is true,
                otherwise a method body whose parent is the declaration.
            is_constructor: selects where the name/parameters are read from.
        """
        # if we have not imported source package or
        # Source class just ignore this
        if not self.has_imported_source:
            return

        declaration = ctx if is_constructor else ctx.parentCtx
        method_identifier = declaration.IDENTIFIER().getText()
        formal_params = declaration.formalParameters()

        target_param_name = "$$target"
        separator = "" if len(self.current_method.parameters) == 0 else ", "
        # The parameter type is the real target class (was hard-coded
        # "Target").
        target_param = f"{separator}{self.target_class} {target_param_name}"
        target_added = False

        def add_target_param():
            # Append the target parameter once per method.
            nonlocal target_added
            if target_added:
                return
            self.rewriter.insertBeforeIndex(formal_params.stop.tokenIndex,
                                            target_param)
            self.methods_tobe_updated.append(self.current_method)
            target_added = True

        local_candidates = set()
        if self.current_class_name == self.source_class:
            # we will remove getter and setter from source
            # and add it to target so there is no need to
            # find usages there
            if self.is_method_getter_or_setter(method_identifier):
                self.rewriter.replaceRange(ctx.start.tokenIndex,
                                           ctx.stop.tokenIndex, "")
                return
            local_candidates.add("this")

        # find parameters with type Source
        for param_type, identifier in self.current_method.parameters:
            if param_type == self.source_class:
                local_candidates.add(identifier)

        # find all local variables with type Source
        for item in self.current_method.body_local_vars_and_expr_names:
            if type(item) is LocalVariable and \
                    item.datatype == self.source_class:
                local_candidates.add(item.identifier)

        for item in self.current_method.body_local_vars_and_expr_names:
            if type(item) is ExpressionName:
                self._propagate_new_source_field(item, target_param_name)

                ids = item.dot_separated_identifiers
                if len(ids) < 2:
                    continue
                if (ids[0] in local_candidates or
                        ids[0] in self.field_candidates) and \
                        ids[1] == self.field_name:
                    add_target_param()
                    self.usages.append(item.parser_context)
                    self.propagate_field(item.parser_context,
                                         target_param_name)

            elif type(item) is MethodInvocation:
                # we are going to find getter or setters
                ids = item.dot_separated_identifiers
                if ids[0] == f"new{self.source_class}":
                    method_call = item.parser_context.methodCall()
                    if method_call is not None and \
                            method_call.IDENTIFIER() is not None and \
                            self.is_method_getter_or_setter(
                                method_call.IDENTIFIER().getText()):
                        self.propagate_getter_setter(item.parser_context,
                                                     target_param_name)
                elif self.is_method_getter_or_setter(ids[0]):
                    add_target_param()
                    parser_context = item.parser_context
                    # form 2 rewrites need an expression context; a missing
                    # context used to crash in the rewrite below.
                    if parser_context is None or \
                            type(parser_context) is not JavaParser.ExpressionContext:
                        continue
                    self.usages.append(parser_context)
                    self.propagate_getter_setter_form2(parser_context,
                                                       target_param_name)
                elif len(ids) > 1 and self.is_getter_or_setter(
                        ids[0], ids[1], local_candidates):
                    add_target_param()
                    self.usages.append(item.parser_context)
                    self.propagate_getter_setter(item.parser_context,
                                                 target_param_name)

    def _propagate_new_source_field(self, expression_name, target_name: str):
        """Best effort: rewrite ``new Source().field`` to use *target_name*."""
        try:
            # NOTE(review): six levels up is presumably the enclosing
            # ``expression . IDENTIFIER`` node — confirm against the grammar.
            local_ctx = expression_name.parser_context.parentCtx.parentCtx.parentCtx.parentCtx.parentCtx.parentCtx
            creator = local_ctx.expression()[0].getText()
            if f"new{self.source_class}" in creator and \
                    local_ctx.IDENTIFIER().getText() == self.field_name:
                self.propagate_field(local_ctx, target_name)
        except (AttributeError, IndexError, TypeError, ValueError):
            # The context chain does not always have this shape; skip it
            # instead of swallowing every exception with a bare except.
            pass

    def _accessor_names(self) -> set:
        """Names of the field's accessors: set/get/has/is + capitalized name."""
        suffix = self._capitalize(self.field_name)
        return {f"{prefix}{suffix}" for prefix in ("set", "get", "has", "is")}

    def is_getter_or_setter(self, first_id: str, second_id: str,
                            local_candidates: set):
        """True when ``first_id.second_id`` is an accessor call on a candidate."""
        return (first_id in local_candidates
                or first_id in self.field_candidates) and \
            second_id in self._accessor_names()

    def is_method_getter_or_setter(self, method: str):
        """True when *method* names one of the field's accessors."""
        return method in self._accessor_names()

    def propagate_getter_setter(self, ctx: JavaParser.ExpressionContext,
                                target_name: str):
        """Replace the receiver before the first ``.`` with *target_name*."""
        self.propagate_field(ctx, target_name)

    def propagate_getter_setter_form2(self, ctx: JavaParser.ExpressionContext,
                                      target_name: str):
        """
        form 2 is getA() setA()...; prefix the call with ``<target_name>.``.
        """
        self.rewriter.insertBeforeIndex(ctx.start.tokenIndex,
                                        f"{target_name}.")

    def propagate_field(self, ctx: JavaParser.ExpressionContext,
                        target_name: str):
        """Replace the tokens before the ``.`` of *ctx* with *target_name*."""
        index = ctx.DOT().symbol.tokenIndex
        self.rewriter.replaceRange(ctx.start.tokenIndex, index - 1,
                                   target_name)
示例#35
0
class EncapsulateFiledRefactoringListener(JavaParserLabeledListener):
    """

    To implement encapsulate field refactoring.

    Makes a public (or modifier-less) field private and provides accessor and
    mutator methods. Inside the source class, reads of the field become
    ``this.getX()`` and writes (plain or compound assignments) become
    ``this.setX(...)``. Accessor names follow the JavaBeans convention: only
    the first letter of the field name is upper-cased (``firstName`` ->
    ``getFirstName``).

    """

    # Operators of the assignment alternative (``expression21``). The
    # left-hand side of these must not be turned into a getter call.
    ASSIGNMENT_OPERATORS = ('=', '+=', '-=', '*=', '/=', '&=', '|=', '^=',
                            '>>=', '>>>=', '<<=', '%=')

    def __init__(self,
                 common_token_stream: CommonTokenStream = None,
                 package_name: str = None,
                 source_class_name: str = None,
                 field_identifier: str = None):
        """

        Args:

            common_token_stream (CommonTokenStream): contains the program tokens

            package_name (str): The enclosing package of the field

            source_class_name (str): The enclosing class of the field

            field_identifier (str): The field name to be encapsulated

        Raises:

            TypeError: if ``common_token_stream`` is None.

        """

        self.token_stream = common_token_stream
        self.package_name = '' if package_name is None else package_name
        self.source_class_name = source_class_name
        self.field_identifier = field_identifier
        self.getter_exist = False
        self.setter_exist = False
        self.in_source_class = False
        self.in_selected_package = self.package_name == ''
        # One flag per enclosing class declaration (True = the source class).
        # Leaving a nested class restores the state of its enclosing class.
        self._class_stack = []
        # > 0 while walking an existing getter/setter of the field. Accesses
        # there stay direct; otherwise the accessor would call itself.
        self._accessor_depth = 0

        # Move all the tokens in the source code in a buffer, token_stream_rewriter.
        if common_token_stream is not None:
            self.token_stream_rewriter = \
                TokenStreamRewriter(common_token_stream)
        else:
            raise TypeError('common_token_stream is None')

    # ----------------------------------------------------------------- helpers

    def _accessor_suffix(self) -> str:
        """Return the field name with only its first letter upper-cased."""
        name = self.field_identifier or ''
        return name[:1].upper() + name[1:]

    def _accessor_names(self) -> tuple:
        """Return ``(getter_name, setter_name)`` for the encapsulated field."""
        suffix = self._accessor_suffix()
        return 'get' + suffix, 'set' + suffix

    def _should_rewrite_accesses(self) -> bool:
        """True when field reads/writes at the current position must be rewritten."""
        return (self.in_source_class and self.in_selected_package
                and self._accessor_depth == 0)

    def _is_assignment_target(self, ctx) -> bool:
        """True if ``ctx`` is the left operand of an assignment expression."""
        parent = ctx.parentCtx
        if parent is None or parent.getChildCount() < 2:
            return False
        return (parent.getChild(0) is ctx
                and parent.getChild(1).getText() in self.ASSIGNMENT_OPERATORS)

    # ------------------------------------------------------- scope tracking

    def enterPackageDeclaration(
            self, ctx: JavaParserLabeled.PackageDeclarationContext):
        self.in_selected_package = \
            self.package_name == ctx.qualifiedName().getText()

    def enterClassDeclaration(self,
                              ctx: JavaParserLabeled.ClassDeclarationContext):
        is_source = ctx.IDENTIFIER().getText() == self.source_class_name
        self._class_stack.append(is_source)
        self.in_source_class = is_source

    def exitClassDeclaration(self,
                             ctx: JavaParserLabeled.ClassDeclarationContext):
        if self._class_stack:
            self._class_stack.pop()
        self.in_source_class = \
            self._class_stack[-1] if self._class_stack else False

    def enterMethodDeclaration(
            self, ctx: JavaParserLabeled.MethodDeclarationContext):
        if self.in_source_class and \
                ctx.IDENTIFIER().getText() in self._accessor_names():
            self._accessor_depth += 1

    def exitMethodDeclaration(
            self, ctx: JavaParserLabeled.MethodDeclarationContext):
        # Same condition as on enter: the enclosing class is unchanged here.
        if self.in_source_class and self._accessor_depth > 0 and \
                ctx.IDENTIFIER().getText() in self._accessor_names():
            self._accessor_depth -= 1

    # ------------------------------------------------------------- rewriting

    def exitFieldDeclaration(self,
                             ctx: JavaParserLabeled.FieldDeclarationContext):
        """Make the selected field private and add missing getter/setter."""
        if not (self.in_source_class and self.in_selected_package):
            return
        declared_name = ctx.variableDeclarators().variableDeclarator(
            0).variableDeclaratorId().getText()
        if declared_name != self.field_identifier:
            return

        member_ctx = ctx.parentCtx.parentCtx  # classBodyDeclaration (modifiers)
        modifiers = member_ctx.modifier()
        if not modifiers:
            # Insert before the FIRST token of the type. Its last token would
            # split generic/array types (`List<T>`, `int[]`).
            self.token_stream_rewriter.insertBeforeIndex(
                index=ctx.typeType().start.tokenIndex, text='private ')
        else:
            public_modifier = next(
                (m for m in modifiers if m.getText() == 'public'), None)
            if public_modifier is None:
                return
            self.token_stream_rewriter.replaceRange(
                from_idx=public_modifier.start.tokenIndex,
                to_idx=public_modifier.stop.tokenIndex,
                text='private')

        getter_name, setter_name = self._accessor_names()
        for member in member_ctx.parentCtx.classBodyDeclaration():
            try:
                method_name = member.memberDeclaration().methodDeclaration(
                ).IDENTIFIER().getText()
            except AttributeError:
                continue  # not a method: field, block, ';', nested type, ...
            if method_name == getter_name:
                self.getter_exist = True
            if method_name == setter_name:
                self.setter_exist = True

        logger.debug("setter find: " + str(self.setter_exist))
        logger.debug("getter find: " + str(self.getter_exist))

        field = self.field_identifier
        type_text = ctx.typeType().getText()
        new_code = ''
        # Accessor body
        if not self.getter_exist:
            new_code = '\n\t// new getter method\n\t'
            new_code += 'public ' + type_text + ' ' + getter_name
            new_code += '() { \n\t\treturn this.' + field + ';' + '\n\t}\n'
        # Mutator body
        if not self.setter_exist:
            new_code += '\n\t// new setter method\n\t'
            new_code += 'public void ' + setter_name
            new_code += '(' + type_text + ' ' + field + ') { \n\t\t'
            new_code += 'this.' + field + ' = ' + field + ';' + '\n\t}\n'
        if new_code:
            self.token_stream_rewriter.insertAfter(ctx.stop.tokenIndex,
                                                   new_code)

    def exitExpression21(self, ctx: JavaParserLabeled.Expression21Context):
        """Turn ``x = v`` / ``this.x op= v`` into a setter call."""
        if not self._should_rewrite_accesses():
            return
        target = ctx.expression(0).getText()
        if target not in (self.field_identifier,
                          'this.' + self.field_identifier):
            return
        # Right-hand side as already rewritten (inner reads -> getter calls).
        value_code = self.token_stream_rewriter.getText(
            program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME,
            start=ctx.expression(1).start.tokenIndex,
            stop=ctx.expression(1).stop.tokenIndex)
        getter_name, setter_name = self._accessor_names()
        operator = ctx.getChild(1).getText()
        if operator != '=':
            # Compound assignment: `x += v` -> setX(getX() + (v)).
            value_code = ('this.' + getter_name + '() ' + operator[:-1] +
                          ' (' + value_code + ')')
        new_code = 'this.' + setter_name + '(' + value_code + ')'
        self.token_stream_rewriter.replaceRange(
            ctx.start.tokenIndex, ctx.stop.tokenIndex, new_code)

    def exitExpression0(self, ctx: JavaParserLabeled.Expression0Context):
        """Turn a bare read of the field (``x``) into a getter call."""
        if not self._should_rewrite_accesses() or \
                self._is_assignment_target(ctx):
            return
        if ctx.getText() == self.field_identifier:
            new_code = 'this.' + self._accessor_names()[0] + '()'
            self.token_stream_rewriter.replaceRange(
                ctx.start.tokenIndex, ctx.stop.tokenIndex, new_code)

    def exitExpression1(self, ctx: JavaParserLabeled.Expression1Context):
        """Turn ``this.x`` reads into a getter call."""
        if not self._should_rewrite_accesses() or \
                self._is_assignment_target(ctx):
            return
        if ctx.getText() == 'this.' + self.field_identifier:
            new_code = 'this.' + self._accessor_names()[0] + '()'
            self.token_stream_rewriter.replaceRange(
                ctx.start.tokenIndex, ctx.stop.tokenIndex, new_code)

    def exitCompilationUnit(self,
                            ctx: JavaParserLabeled.CompilationUnitContext):
        """Replace the hidden tokens before the file's first token with a banner.

        NOTE: this overwrites any leading comment (e.g. a license header).
        Nothing happens when there are no leading hidden tokens.
        """
        if ctx.start is None:
            return
        hidden = self.token_stream.getHiddenTokensToLeft(ctx.start.tokenIndex)
        if not hidden:  # None when there are no hidden tokens to the left
            return
        self.token_stream_rewriter.replaceRange(
            from_idx=hidden[0].tokenIndex,
            to_idx=hidden[-1].tokenIndex,
            text='/*After refactoring (Refactored version)*/\n')
示例#36
0
    def do_refactor(self):
        """
        Move ``self.method_key`` from the source class to the target class.

        Call sites whose invoked name equals the method name are redirected
        to the target class. Non-static methods also get a target instance
        field declared in the calling class. The method text is inserted
        into the target class with an extra leading parameter of the source
        class type, and source-class fields/methods used in its body are
        qualified with that parameter. The original method text is then
        removed.

        Returns:
            bool: False when a precondition fails (missing package, class or
            method, or the method is a constructor); True after the rewrite.
        """
        program = get_program(self.source_filenames)

        # Preconditions: both packages, both classes and the method exist.
        packages = program.packages
        if self.package_name not in packages or \
                self.target_package_name not in packages:
            return False
        source_package = packages[self.package_name]
        target_package = packages[self.target_package_name]
        if (self.class_name not in source_package.classes
                or self.target_class_name not in target_package.classes
                or self.method_key not in
                source_package.classes[self.class_name].methods):
            return False

        _sourceclass = source_package.classes[self.class_name]
        _targetclass = target_package.classes[self.target_class_name]
        _method = _sourceclass.methods[self.method_key]

        if _method.is_constructor:
            return False

        Rewriter_ = Rewriter(program, lambda x: x)
        tokens_info = TokensInfo(
            _method.parser_context)  # tokens of ctx method
        param_tokens_info = TokensInfo(_method.formalparam_context)
        method_declaration_info = TokensInfo(
            _method.method_declaration_context)
        # Token indices, inside the method body, of identifiers that name
        # fields declared in the source class.
        exps = tokens_info.get_token_index(tokens_info.token_stream.tokens,
                                           tokens_info.start, tokens_info.stop)
        exp = [token.tokenIndex for token in exps
               if token.text in _sourceclass.fields]

        # check that method is static or not
        is_static = "static" in _method.modifiers
        # Bare method name, e.g. "foo" for "foo(int)"; loop-invariant.
        method_name = self.method_key[:self.method_key.find('(')]

        # Redirect every call site of the method to the target class.
        for package_names in packages:
            package = packages[package_names]
            for class_ in package.classes:
                _class = package.classes[class_]
                for method_ in _class.methods:
                    caller_method = _class.methods[method_]
                    for inv in caller_method.body_method_invocations:
                        invc = caller_method.body_method_invocations[inv]
                        if invc[0] == method_name:
                            inv_tokens_info = TokensInfo(inv)
                            # Built unconditionally: the import insertion below
                            # needs it for static methods too (it used to be
                            # unbound in that case -> UnboundLocalError).
                            class_token_info = TokensInfo(_class.body_context)
                            if not is_static:
                                Rewriter_.insert_after_start(
                                    class_token_info,
                                    self.target_class_name + " " +
                                    str.lower(self.target_class_name) + "=" +
                                    "new " + self.target_class_name + "();")
                                Rewriter_.apply()
                            Rewriter_.insert_before_start(
                                class_token_info,
                                "import " + self.target_package_name + "." +
                                self.target_class_name + ";")
                            Rewriter_.replace(inv_tokens_info,
                                              self.target_class_name)
                        Rewriter_.apply()

        class_tokens_info = TokensInfo(_targetclass.parser_context)
        package_tokens_info = TokensInfo(target_package.package_ctx)
        singlefileelement = SingleFileElement(_method.parser_context,
                                              _method.filename)
        token_stream_rewriter = TokenStreamRewriter(
            singlefileelement.get_token_stream())

        # Qualify each use of a source-class field with the source object name.
        for index in exp:
            token_stream_rewriter.insertBeforeIndex(
                index=index, text=str.lower(self.class_name) + ".")

        # Drop invocations whose text equals the target class name, together
        # with the token that follows them.
        for inv in _method.body_method_invocations:
            if inv.getText() == self.target_class_name:
                inv_tokens_info_target = TokensInfo(inv)
                token_stream_rewriter.replaceRange(
                    from_idx=inv_tokens_info_target.start,
                    to_idx=inv_tokens_info_target.stop + 1,
                    text=" ")

        # Qualify calls to source-class methods with the source object name.
        for i in _method.body_method_invocations_without_typename:
            if i.getText() == self.class_name:
                ii = _method.body_method_invocations_without_typename[i]
                i_tokens = TokensInfo(ii[0])
                token_stream_rewriter.insertBeforeIndex(
                    index=i_tokens.start,
                    text=str.lower(self.class_name) + ".")

        # Pass an object of the source class to the method as first parameter.
        if param_tokens_info.start is not None:
            token_stream_rewriter.insertBeforeIndex(
                param_tokens_info.start,
                text=self.class_name + " " + str.lower(self.class_name) + ",")
        else:
            token_stream_rewriter.insertBeforeIndex(
                method_declaration_info.stop,
                text=self.class_name + " " + str.lower(self.class_name))

        strofmethod = token_stream_rewriter.getText(
            program_name=token_stream_rewriter.DEFAULT_PROGRAM_NAME,
            start=tokens_info.start,
            stop=tokens_info.stop)

        Rewriter_.insert_before(tokens_info=class_tokens_info,
                                text=strofmethod)
        Rewriter_.insert_after(
            package_tokens_info, "import " + self.target_package_name + "." +
            self.target_class_name + ";")
        Rewriter_.replace(tokens_info, "")
        Rewriter_.apply()
        return True
示例#37
0
class ReplaceConstructorWithFactoryFunctionRefactoringListener(JavaParserLabeledListener):
    """
    Replace constructor with factory function refactoring.

    For the target class it:
      * makes constructors private (``public`` -> ``private``; ``private``
        is added when no access modifier is present),
      * inserts ``public static <Class> Create(...)`` after each constructor,
        whose body returns ``new <Class>(...)``,
      * rewrites ``new <Class>(...)`` expressions into ``<Class>.Create(...)``.
    """

    def __init__(self, common_token_stream: CommonTokenStream = None,
                 target_class: str = None):
        """
        Args:
            common_token_stream (CommonTokenStream): contains the program tokens
            target_class (str): name of the class whose constructors are replaced

        Raises:
            ValueError: if either argument is None.
        """
        if common_token_stream is None:
            raise ValueError('common_token_stream is None')
        self.codeRewrite = TokenStreamRewriter(common_token_stream)

        if target_class is None:
            raise ValueError("target_class is None")
        self.target_class = target_class

        self.is_target_class = False
        self.have_constructor = False
        self.new_factory_function = False
        # Parameters of the constructor being walked ("Type name") and their names.
        self.new_parameters = []
        self.new_parameters_names = []
        # One flag per enclosing class declaration (True = target class), so
        # leaving a nested class restores the state of the enclosing one.
        self._class_stack = []

    def enterClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext):
        is_target = ctx.IDENTIFIER().getText() == self.target_class
        self._class_stack.append(is_target)
        self.is_target_class = is_target

    def exitClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext):
        if self._class_stack:
            self._class_stack.pop()
        self.is_target_class = self._class_stack[-1] if self._class_stack else False

    def enterConstructorDeclaration(self, ctx: JavaParserLabeled.ConstructorDeclarationContext):
        """Declare the constructor private."""
        if not self.is_target_class or ctx.IDENTIFIER().getText() != self.target_class:
            return
        self.have_constructor = True
        grandParentCtx = ctx.parentCtx.parentCtx  # classBodyDeclaration: owns the modifiers
        access_modifier = next(
            (m for m in grandParentCtx.modifier()
             if m.getText() in ('public', 'protected', 'private')),
            None)
        if access_modifier is None:
            # Package-private (possibly annotated): add `private` before the name.
            self.codeRewrite.insertBeforeIndex(
                index=ctx.start.tokenIndex,
                text="private "
            )
        elif access_modifier.getText() == 'public':
            self.codeRewrite.replaceRange(
                from_idx=access_modifier.start.tokenIndex,
                to_idx=access_modifier.stop.tokenIndex,
                text='private')

    def exitConstructorDeclaration(self, ctx: JavaParserLabeled.ConstructorDeclarationContext):
        """
        Create a factory method. Make its body a call to the current constructor.
        """
        if self.is_target_class:
            grandParentCtx = ctx.parentCtx.parentCtx
            self.codeRewrite.insertAfter(
                index=grandParentCtx.stop.tokenIndex,
                text="\n    public static " + ctx.IDENTIFIER().getText() + " Create( " + ", ".join(
                    self.new_parameters) +
                     "){\n       return new " + ctx.IDENTIFIER().getText() + "(" + ", ".join(
                    self.new_parameters_names) + ");\n}"
            )
        self.new_parameters = []
        self.new_parameters_names = []

    def _collect_constructor_parameter(self, ctx, type_text):
        """Record a parameter of a target-class constructor for the factory.

        ``ctx`` is a formalParameter/lastFormalParameter. Parameters of
        methods, lambdas, etc. are ignored. Their third ancestor is not a
        constructor declaration, and lambdas have no IDENTIFIER to compare.
        """
        owner = ctx.parentCtx.parentCtx.parentCtx
        if not isinstance(owner, JavaParserLabeled.ConstructorDeclarationContext):
            return
        if not self.is_target_class or owner.IDENTIFIER().getText() != self.target_class:
            return
        declarator = ctx.variableDeclaratorId()
        self.new_parameters.append(type_text + " " + declarator.getText())
        # Bare identifier: `int a[]` must be forwarded as `a`, not `a[]`.
        self.new_parameters_names.append(declarator.IDENTIFIER().getText())

    def enterFormalParameter(self, ctx: JavaParserLabeled.FormalParameterContext):
        self._collect_constructor_parameter(ctx, ctx.typeType().getText())

    def enterLastFormalParameter(self, ctx):
        # Varargs parameter (`Type... name`); previously dropped from the factory.
        self._collect_constructor_parameter(ctx, ctx.typeType().getText() + "...")

    def enterExpression4(self, ctx: JavaParserLabeled.Expression4Context):
        """
        Replace all constructor calls with calls to the factory method.

        Only the ``new <Class>`` tokens are rewritten to ``<Class>.Create``.
        The argument list keeps its original text and formatting, and nested
        creations inside it can be rewritten independently.
        """
        creator = ctx.creator()
        created_name = creator.createdName()
        if created_name.getText() != self.target_class:
            return
        class_creator_rest = creator.classCreatorRest()
        if class_creator_rest is None or class_creator_rest.classBody() is not None:
            # Array creation (`new T[n]`) or anonymous subclass (`new T() {...}`):
            # neither can be expressed as a factory call.
            return
        self.codeRewrite.replaceRange(
            from_idx=ctx.start.tokenIndex,
            to_idx=created_name.stop.tokenIndex,
            text=self.target_class + "." + "Create")
示例#38
0
class MakeFieldStaticRefactoringListener(JavaParserLabeledListener):
    """
    To implement make field static refactoring.

    Adds the ``static`` modifier to the selected field of the selected class.
    It also rewrites ``obj.field`` accesses, where ``obj`` is a recorded field
    whose type is the selected class, into ``ClassName.field``.
    """
    def __init__(self,
                 common_token_stream: CommonTokenStream = None,
                 field_identifier: str = None,
                 class_identifier: str = None,
                 package_identifier: str = None):
        """
        Args:
            common_token_stream (CommonTokenStream): contains the program tokens
            field_identifier (str): name of the field to make static
            class_identifier (str): name of the class declaring the field
            package_identifier (str): package of that class (None = no package)

        Raises:
            TypeError: if ``common_token_stream`` is None.
        """
        self.token_stream = common_token_stream
        self.field_identifier = field_identifier
        self.class_identifier = class_identifier
        self.package_identifier = package_identifier
        # Names of fields whose declared type is the selected class.
        self.declared_objects_names = []

        self.is_package_imported = False
        self.in_selected_package = False
        self.in_selected_class = False
        self.in_some_package = False
        # Context of the selected class while it is walked (to detect its exit).
        self._selected_class_ctx = None

        # Move all the tokens in the source code in a buffer, token_stream_rewriter.
        if common_token_stream is not None:
            self.token_stream_rewriter = TokenStreamRewriter(
                common_token_stream)
        else:
            raise TypeError('common_token_stream is None')

    def _in_target_package(self) -> bool:
        """True if the file being walked belongs to the selected class's package."""
        if self.package_identifier is None:
            return not self.in_some_package
        return self.in_selected_package

    def enterPackageDeclaration(
            self, ctx: JavaParserLabeled.PackageDeclarationContext):
        self.in_some_package = True
        if self.package_identifier is not None:
            if self.package_identifier == ctx.qualifiedName().getText():
                self.in_selected_package = True
                print("Package Found")

    def enterClassDeclaration(self,
                              ctx: JavaParserLabeled.ClassDeclarationContext):
        if self._in_target_package() and \
                ctx.IDENTIFIER().getText() == self.class_identifier:
            print("Class Found")
            self.in_selected_class = True
            self._selected_class_ctx = ctx

    def exitClassDeclaration(self,
                             ctx: JavaParserLabeled.ClassDeclarationContext):
        # Stop treating later classes of the same file as the selected one.
        if ctx is self._selected_class_ctx:
            self.in_selected_class = False
            self._selected_class_ctx = None

    def enterImportDeclaration(
            self, ctx: JavaParserLabeled.ImportDeclarationContext):
        # getText() joins tokens without whitespace: "importa.b.C;".
        if self.package_identifier is not None:
            if ctx.getText() == "import" + self.package_identifier + "." + self.class_identifier + ";" \
                    or ctx.getText() == "import" + self.package_identifier + ".*" + ";" \
                    or ctx.getText() == "import" + self.package_identifier + ";":
                self.is_package_imported = True

    def _add_static_modifier(self, ctx):
        """Add ``static`` to the field declaration ``ctx`` unless already present."""
        grand_parent_ctx = ctx.parentCtx.parentCtx  # classBodyDeclaration
        modifiers = grand_parent_ctx.modifier()
        if not modifiers:
            self.token_stream_rewriter.insertBeforeIndex(
                index=ctx.parentCtx.start.tokenIndex,
                text=' static ')
        elif not any(m.getText() == "static" for m in modifiers):
            # Insert after the last modifier's final token. A modifier count
            # is not a token offset: whitespace tokens and multi-token
            # annotations (e.g. `@SuppressWarnings("x")`) lie in between.
            self.token_stream_rewriter.insertAfter(
                index=modifiers[-1].stop.tokenIndex,
                text=' static ')

    def exitFieldDeclaration(self,
                             ctx: JavaParserLabeled.FieldDeclarationContext):
        first_declared = ctx.variableDeclarators().variableDeclarator(
            0).variableDeclaratorId().getText()

        if self._in_target_package() and self.in_selected_class and \
                first_declared == self.field_identifier:
            self._add_static_modifier(ctx)

        # The selected class is visible here without qualification.
        if self.package_identifier is None or self.is_package_imported \
                or self.in_selected_package:
            class_type = ctx.typeType().classOrInterfaceType()
            if class_type is not None and \
                    class_type.getText() == self.class_identifier:
                self.declared_objects_names.append(first_declared)
                print("Object " + first_declared
                      + " of type " + self.class_identifier + " found.")

    def enterExpression1(self, ctx: JavaParserLabeled.Expression1Context):
        """Rewrite ``obj.field`` into ``ClassName.field`` for recorded objects."""
        if self.is_package_imported or self.package_identifier is None or self.in_selected_package:
            for object_name in self.declared_objects_names:
                if ctx.getText() == object_name + "." + self.field_identifier:
                    self.token_stream_rewriter.replaceIndex(
                        index=ctx.start.tokenIndex, text=self.class_identifier)
示例#39
0
class MakeMethodNonStaticRefactoringListener(JavaParserLabeledListener):
    """

    To implement Make Method Non-Static refactoring based on its actors.

    Removes the ``static`` modifier from the target methods of the target
    class. If the class declares constructors but none without parameters,
    a public parameterless constructor is added.

    """
    def __init__(self,
                 common_token_stream: CommonTokenStream = None,
                 target_class: str = None,
                 target_methods: list = None):
        """
        Args:
            common_token_stream (CommonTokenStream): contains the program tokens
            target_class (str): name of the class owning the methods
            target_methods (list): names of the methods to make non-static

        Raises:
            ValueError: if an argument is missing or ``target_methods`` is empty.
        """

        if common_token_stream is None:
            raise ValueError('common_token_stream is None')
        self.token_stream_rewriter = TokenStreamRewriter(common_token_stream)

        if target_class is None:
            raise ValueError("target_class is None")
        self.target_class = target_class
        if target_methods is None or len(target_methods) == 0:
            raise ValueError("target method must have one method name")
        self.target_methods = target_methods

        self.target_class_data = None
        self.is_target_class = False
        self.detected_field = None
        self.detected_method = None
        self.TAB = "\t"
        self.NEW_LINE = "\n"
        self.code = ""
        # One flag per enclosing class declaration (True = target class).
        self._class_stack = []

    def enterClassDeclaration(self,
                              ctx: JavaParserLabeled.ClassDeclarationContext):
        is_target = ctx.IDENTIFIER().getText() == self.target_class
        self._class_stack.append(is_target)
        self.is_target_class = is_target
        if is_target:
            self.target_class_data = {'constructors': []}

    def exitClassDeclaration(self,
                             ctx: JavaParserLabeled.ClassDeclarationContext):
        was_target = self._class_stack.pop() if self._class_stack \
            else self.is_target_class
        if was_target:
            constructors = self.target_class_data['constructors']
            have_default_constructor = any(
                len(constructor.parameters) == 0 for constructor in constructors)
            # With no declared constructor Java supplies an implicit default
            # one, so a constructor is only added when others exist.
            if constructors and not have_default_constructor:
                self.token_stream_rewriter.insertBeforeIndex(
                    index=ctx.stop.tokenIndex - 1,
                    text=f'\n\t public {self.target_class} ()\n\t{{}}\n')
        self.is_target_class = self._class_stack[-1] if self._class_stack else False

    def enterMethodDeclaration(
            self, ctx: JavaParserLabeled.MethodDeclarationContext):
        """Delete the ``static`` modifier (and the token before it) of a target method."""
        if not self.is_target_class or \
                ctx.IDENTIFIER().getText() not in self.target_methods:
            return None
        grand_parent_ctx = ctx.parentCtx.parentCtx  # classBodyDeclaration
        for modifier in grand_parent_ctx.modifier():
            if modifier.getText() == 'static':
                # start - 1 also removes the whitespace token before `static`.
                self.token_stream_rewriter.delete(
                    program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME,
                    from_idx=modifier.start.tokenIndex - 1,
                    to_idx=modifier.stop.tokenIndex)
                break
        return None

    def enterConstructorDeclaration(
            self, ctx: JavaParserLabeled.ConstructorDeclarationContext):
        """Record the target class's constructors (signature, parameters, text)."""
        if not self.is_target_class:
            return
        parameter_list = ctx.formalParameters().formalParameterList()
        # Children alternate parameter, ',', parameter, ...
        constructor_parameters = parameter_list.children[::2] if parameter_list else []

        constructor_text = ''
        for modifier in ctx.parentCtx.parentCtx.modifier():
            constructor_text += modifier.getText() + ' '
        # The name follows all modifiers (it was appended once per modifier).
        constructor_text += ctx.IDENTIFIER().getText()
        constructor_text += ' ( '
        constructor_text += ', '.join(
            parameter.typeType().getText() + ' ' +
            parameter.variableDeclaratorId().getText()
            for parameter in constructor_parameters)
        constructor_text += ')\n\t{'
        constructor_text += self.token_stream_rewriter.getText(
            program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME,
            start=ctx.block().start.tokenIndex + 1,
            stop=ctx.block().stop.tokenIndex - 1)
        constructor_text += '}\n'
        self.target_class_data['constructors'].append(
            ConstructorOrMethod(
                name=self.target_class,
                parameters=[
                    Parameter(parameterType=p.typeType().getText(),
                              name=p.variableDeclaratorId().IDENTIFIER().
                              getText()) for p in constructor_parameters
                ],
                text=constructor_text))
示例#40
0
class MakeMethodStaticRefactoringListener(JavaParserLabeledListener):
    """
    To implement make method static refactoring.

    Adds the ``static`` modifier to the target methods of the target class.
    Local variable declarations whose type is the target class are deleted,
    and calls to a target method through such a variable
    (``obj.method(...)``) are rewritten to ``TargetClass.method(...)``.
    """

    def __init__(self,
                 common_token_stream: CommonTokenStream = None,
                 target_class: str = None,
                 target_methods: list = None):
        """
        Args:
            common_token_stream: token stream that the rewriter edits.
            target_class: name of the class declaring the target methods.
            target_methods: names of the methods to make static.

        Raises:
            ValueError: if an argument is missing or ``target_methods`` is
                empty.
        """
        if common_token_stream is None:
            raise ValueError('common_token_stream is None')
        self.token_stream_rewriter = TokenStreamRewriter(common_token_stream)

        if target_class is None:
            raise ValueError("target_class is None")
        self.target_class = target_class

        if not target_methods:
            raise ValueError("target method must have one method name")
        self.target_methods = target_methods

        self.is_target_class = False
        # Names of the classes currently being walked, innermost last, so
        # that leaving a nested class restores the enclosing class's state.
        self._class_stack = []
        # Local variables of the target class type whose declarations were
        # deleted; calls made through them are redirected to the class.
        self.detected_instance_of_target_class = []
        self.TAB = "\t"
        self.NEW_LINE = "\n"
        self.code = ""

    def enterClassDeclaration(self,
                              ctx: JavaParserLabeled.ClassDeclarationContext):
        """Track whether the walker is now inside the target class."""
        self._class_stack.append(ctx.IDENTIFIER().getText())
        self.is_target_class = self._class_stack[-1] == self.target_class

    def exitClassDeclaration(self,
                             ctx: JavaParserLabeled.ClassDeclarationContext):
        """Restore the target-class flag of the enclosing class, if any."""
        if self._class_stack:
            self._class_stack.pop()
        self.is_target_class = (bool(self._class_stack)
                                and self._class_stack[-1] == self.target_class)

    def enterMethodDeclaration(
            self, ctx: JavaParserLabeled.MethodDeclarationContext):
        """Add ``static`` to a target method that is not already static.

        Raises:
            ValueError: if the method body text contains ``this.``.
        """
        if not self.is_target_class:
            return None
        if ctx.IDENTIFIER().getText() not in self.target_methods:
            return None
        # ctx.getText() has no whitespace; any explicit use of 'this.' means
        # the method depends on the instance.
        if 'this.' in ctx.getText():
            raise ValueError("this method can not refactor")
        modifiers = ctx.parentCtx.parentCtx.modifier()
        if any(modifier.getText() == 'static' for modifier in modifiers):
            # Already static: nothing to do.
            return None
        if modifiers:
            # Place 'static' after the existing modifiers
            # (e.g. 'public final' -> 'public final static').
            self.token_stream_rewriter.insertAfter(
                index=modifiers[-1].stop.tokenIndex,
                program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME,
                text=" static")
        else:
            self.token_stream_rewriter.insertBeforeIndex(
                index=ctx.start.tokenIndex, text="static ")
        return None

    def enterLocalVariableDeclaration(
            self, ctx: JavaParserLabeled.LocalVariableDeclarationContext):
        """Delete local declarations of the target class type.

        Every declared variable name is remembered so that calls through it
        can be redirected to the class in ``enterMethodCall0``.
        """
        if ctx.typeType().getText() != self.target_class:
            return
        # The whole declaration is removed, so record every declarator
        # (e.g. both 'a' and 'b' in 'Target a = ..., b = ...').
        self.detected_instance_of_target_class.extend(
            declarator.variableDeclaratorId().IDENTIFIER().getText()
            for declarator in ctx.variableDeclarators().variableDeclarator())
        # NOTE(review): the declaration is removed even if the variable is
        # also used for other purposes. stop.tokenIndex + 1 also removes the
        # following token, presumably the terminating ';'.
        self.token_stream_rewriter.delete(
            program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME,
            from_idx=ctx.start.tokenIndex,
            to_idx=ctx.stop.tokenIndex + 1)

    def enterMethodCall0(self, ctx: JavaParserLabeled.MethodCall0Context):
        """Rewrite ``instance.method(...)`` to ``TargetClass.method(...)``."""
        if ctx.IDENTIFIER().getText() not in self.target_methods:
            return
        parent = ctx.parentCtx
        # Only 'receiver.method(...)' (Expression1) has a receiver; a bare
        # 'method(...)' call has no expression() accessor and needs no change.
        if not isinstance(parent, JavaParserLabeled.Expression1Context):
            return
        receiver = parent.expression()
        if receiver.getText() in self.detected_instance_of_target_class:
            self.token_stream_rewriter.replace(
                program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME,
                from_idx=receiver.start.tokenIndex,
                to_idx=receiver.stop.tokenIndex,
                text=self.target_class)
示例#41
0
class InlineClassRefactoringListener(JavaParserLabeledListener):
    """
    To implement inline class refactoring based on its actors.

    Moves the fields, constructors and methods of the source class into the
    target class and deletes the source class. References to the source
    class (creations, parameter and local variable types, qualified names,
    ...) are renamed to the target class.
    """

    def __init__(
            self, common_token_stream: CommonTokenStream = None,
            source_class: str = None, source_class_data: dict = None,
            target_class: str = None, target_class_data: dict = None, is_complete: bool = False):
        """
        Args:
            common_token_stream: token stream that the rewriter edits.
            source_class: name of the class being inlined (it is removed).
            source_class_data: members already collected from the source
                class, as ``{'fields': [], 'methods': [], 'constructors': []}``.
                When omitted or empty a fresh structure is created; a
                non-empty dict is used (and appended to) directly, not copied.
            target_class: name of the class that receives the members.
            target_class_data: same structure as ``source_class_data``, for
                the target class.
            is_complete: True when the merged members were already inserted
                into the target class.

        Raises:
            ValueError: if ``common_token_stream``, ``source_class`` or
                ``target_class`` is None.
        """
        if common_token_stream is None:
            raise ValueError('common_token_stream is None')
        self.token_stream_rewriter = TokenStreamRewriter(common_token_stream)

        if source_class is None:
            raise ValueError("source_class is None")
        self.source_class = source_class

        if target_class is None:
            raise ValueError("target_class is None")
        self.target_class = target_class

        self.source_class_data = source_class_data or {'fields': [], 'methods': [], 'constructors': []}
        self.target_class_data = target_class_data or {'fields': [], 'methods': [], 'constructors': []}

        # Names of fields whose type is the source class.
        self.field_that_has_source = []
        # Set when a 'new' expression creating the source class is seen.
        self.has_source_new = False

        self.is_complete = is_complete
        self.is_target_class = False
        self.is_source_class = False
        self.detected_field = None
        self.detected_method = None
        self.TAB = "\t"
        self.NEW_LINE = "\n"
        self.code = ""

    def enterClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext):
        """Set the source/target flags for the class being entered.

        At most one flag ends up True; a name matching both ``source_class``
        and ``target_class`` is treated as the source class.
        """
        name = ctx.IDENTIFIER().getText()
        entering_source = name == self.source_class
        self.is_source_class = entering_source
        self.is_target_class = not entering_source and name == self.target_class

    def exitClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext):
        """Insert the merged members into the target class, or delete the source class.

        Leaving the target class (only when some fields, constructors or
        methods were collected from the source class): if ``is_complete`` is
        False, the merged members are inserted before the last token of the
        class declaration and ``is_complete`` is set, so this happens once;
        otherwise ``is_target_class`` is cleared.

        Leaving the source class: the declaration is deleted from its first
        class modifier up to its last token. A source class without
        modifiers is not deleted here.
        """
        if self.is_target_class and (self.source_class_data['fields'] or
                                     self.source_class_data['constructors'] or
                                     self.source_class_data['methods']):
            if not self.is_complete:
                # merge_* are defined elsewhere; presumably they combine both
                # member lists and resolve duplicates — TODO confirm.
                final_fields = merge_fields(self.source_class_data['fields'], self.target_class_data['fields'],
                                            self.target_class)
                final_constructors = merge_constructors(self.source_class_data['constructors'],
                                                        self.target_class_data['constructors'])
                final_methods = merge_methods(self.source_class_data['methods'], self.target_class_data['methods'])
                # One text blob: fields, then constructors, then methods.
                text = '\t'
                for field in final_fields:
                    text += field.text + '\n'
                for constructor in final_constructors:
                    text += constructor.text + '\n'
                for method in final_methods:
                    text += method.text + '\n'
                # ctx.stop is the last token of the class declaration,
                # presumably the closing '}' of its body.
                self.token_stream_rewriter.insertBeforeIndex(
                    index=ctx.stop.tokenIndex,
                    text=text
                )
                self.is_complete = True
            else:
                self.is_target_class = False
        elif self.is_source_class:
            # NOTE(review): assumes the parent context exposes
            # classOrInterfaceModifier (a top-level type declaration) —
            # confirm for nested source classes.
            if ctx.parentCtx.classOrInterfaceModifier(0) is None:
                return
            self.is_source_class = False
            # Delete from the first modifier (e.g. 'public') through the end
            # of the class declaration.
            self.token_stream_rewriter.delete(
                program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME,
                from_idx=ctx.parentCtx.classOrInterfaceModifier(0).start.tokenIndex,
                to_idx=ctx.stop.tokenIndex
            )

    def enterClassBody(self, ctx: JavaParserLabeled.ClassBodyContext):
        """Capture and delete the body of the source class.

        The text between the braces of the body is appended to ``self.code``
        and the body's parent context is deleted from the token stream.
        Class bodies outside the source class are ignored.
        """
        if not self.is_source_class:
            return None
        rewriter = self.token_stream_rewriter
        program = rewriter.DEFAULT_PROGRAM_NAME
        inner_text = rewriter.getText(
            program_name=program,
            start=ctx.start.tokenIndex + 1,
            stop=ctx.stop.tokenIndex - 1
        )
        self.code += inner_text
        declaration = ctx.parentCtx
        rewriter.delete(
            program_name=program,
            from_idx=declaration.start.tokenIndex,
            to_idx=declaration.stop.tokenIndex
        )

    def enterFieldDeclaration(self, ctx: JavaParserLabeled.FieldDeclarationContext):
        """Record a field of the source or target class.

        A field whose type is the source class is not recorded as a member;
        only its (first) variable name is remembered in
        ``field_that_has_source``. Any other field is stored as a ``Field``
        named after its first declared variable, with text made of the
        modifiers, the type, the declarators and ';'.
        """
        if not (self.is_source_class or self.is_target_class):
            return
        pieces = []
        for child in ctx.children:
            piece = child.getText()
            if piece == ';':
                declaration_text = ' '.join(pieces) + ';'
                break
            pieces.append(piece)
        else:
            # No ';' child: every piece keeps its trailing space.
            declaration_text = ''.join(piece + ' ' for piece in pieces)

        first_declarator = ctx.variableDeclarators().variableDeclarator(0)
        name = first_declarator.variableDeclaratorId().IDENTIFIER().getText()

        class_type = ctx.typeType().classOrInterfaceType()
        if class_type is not None and class_type.getText() == self.source_class:
            self.field_that_has_source.append(name)
            return

        modifier_text = ''.join(m.getText() + ' ' for m in ctx.parentCtx.parentCtx.modifier())
        field = Field(name=name, text=modifier_text + declaration_text)
        bucket = self.source_class_data if self.is_source_class else self.target_class_data
        bucket['fields'].append(field)

    def exitFieldDeclaration(self, ctx: JavaParserLabeled.FieldDeclarationContext):
        """Delete target-class fields whose type is the source class.

        The whole member declaration (the field's grandparent context,
        including its modifiers) is removed. Fields of a primitive type have
        no ``classOrInterfaceType`` and are left untouched.
        """
        if not self.is_target_class:
            return
        class_type = ctx.typeType().classOrInterfaceType()
        # Guard primitive fields such as 'int x;', which previously crashed
        # with AttributeError on None.getText().
        if class_type is None or class_type.getText() != self.source_class:
            return
        grand_parent_ctx = ctx.parentCtx.parentCtx
        self.token_stream_rewriter.delete(
            program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME,
            from_idx=grand_parent_ctx.start.tokenIndex,
            to_idx=grand_parent_ctx.stop.tokenIndex)

    def enterConstructorDeclaration(self, ctx: JavaParserLabeled.ConstructorDeclarationContext):
        """Record a constructor of the source or target class.

        The constructor is stored as a ``ConstructorOrMethod`` named after
        the target class; source-class constructors are also renamed to the
        target class in the recorded text. For a target-class constructor,
        the body of the source constructor returned by
        ``get_proper_constructor`` (if any) is inserted before the
        constructor's last token.
        """
        if not (self.is_source_class or self.is_target_class):
            return
        rewriter = self.token_stream_rewriter
        parameter_list = ctx.formalParameters().formalParameterList()
        # Even-indexed children are the parameters; odd ones are ',' tokens.
        constructor_parameters = parameter_list.children[::2] if parameter_list else []

        modifier_text = ''.join(
            modifier.getText() + ' ' for modifier in ctx.parentCtx.parentCtx.modifier())
        name_text = self.target_class if self.is_source_class else ctx.IDENTIFIER().getText()
        parameter_text = ', '.join(
            p.typeType().getText() + ' ' + p.variableDeclaratorId().getText()
            for p in constructor_parameters)
        # Body between the braces, including edits already made by the rewriter;
        # extracted once and reused for both the text and constructor_body.
        body = rewriter.getText(
            program_name=rewriter.DEFAULT_PROGRAM_NAME,
            start=ctx.block().start.tokenIndex + 1,
            stop=ctx.block().stop.tokenIndex - 1
        )
        constructor_text = (modifier_text + name_text + ' ( ' + parameter_text
                            + ')\n\t{' + body + '}\n')

        constructor = ConstructorOrMethod(
            name=self.target_class,
            parameters=[Parameter(parameter_type=p.typeType().getText(),
                                  name=p.variableDeclaratorId().IDENTIFIER().getText())
                        for p in constructor_parameters],
            text=constructor_text,
            constructor_body=body)

        if self.is_source_class:
            self.source_class_data['constructors'].append(constructor)
            return

        self.target_class_data['constructors'].append(constructor)
        proper_constructor = get_proper_constructor(constructor, self.source_class_data['constructors'])
        if proper_constructor is None:
            return
        rewriter.insertBeforeIndex(
            index=ctx.stop.tokenIndex,
            text=proper_constructor.constructorBody
        )

    def enterMethodDeclaration(self, ctx: JavaParserLabeled.MethodDeclarationContext):
        """Record a method of the source or target class.

        The method text (modifiers, return type, name, parameters and the
        body as currently rewritten) is stored as a ``ConstructorOrMethod``
        in ``source_class_data`` or ``target_class_data``. A return type
        equal to the source class is renamed to the target class in the
        recorded text and, inside the target class, in the token stream too.
        """
        if not (self.is_source_class or self.is_target_class):
            return
        rewriter = self.token_stream_rewriter
        parameter_list = ctx.formalParameters().formalParameterList()
        # Even-indexed children are the parameters; odd ones are ',' tokens.
        method_parameters = parameter_list.children[::2] if parameter_list else []
        modifier_text = ''.join(
            modifier.getText() + ' ' for modifier in ctx.parentCtx.parentCtx.modifier())

        return_type = ctx.typeTypeOrVoid()
        type_text = return_type.getText()
        if type_text == self.source_class:
            type_text = self.target_class
            if self.is_target_class:
                rewriter.replace(
                    program_name=rewriter.DEFAULT_PROGRAM_NAME,
                    from_idx=return_type.start.tokenIndex,
                    to_idx=return_type.stop.tokenIndex,
                    text=type_text
                )

        method_name = ctx.IDENTIFIER().getText()
        parameter_text = ', '.join(
            p.typeType().getText() + ' ' + p.variableDeclaratorId().getText()
            for p in method_parameters)
        # NOTE(review): the body is read when the method is entered,
        # presumably before the rename hooks for its inner nodes run, so
        # source-class references inside it may not be renamed in this text.
        body = rewriter.getText(
            program_name=rewriter.DEFAULT_PROGRAM_NAME,
            start=ctx.methodBody().start.tokenIndex + 1,
            stop=ctx.methodBody().stop.tokenIndex - 1
        )
        method_text = (modifier_text + type_text + ' ' + method_name
                       + ' ( ' + parameter_text + ')\n\t{' + body + '}\n')

        method = ConstructorOrMethod(
            name=method_name,
            parameters=[Parameter(parameter_type=p.typeType().getText(),
                                  name=p.variableDeclaratorId().IDENTIFIER().getText())
                        for p in method_parameters],
            text=method_text)
        data = self.source_class_data if self.is_source_class else self.target_class_data
        data['methods'].append(method)

    def enterExpression1(self, ctx: JavaParserLabeled.Expression1Context):
        """Collapse ``expr.field`` to ``expr`` when ``field`` held the source instance.

        For example, with 'helper' in ``field_that_has_source``,
        ``this.helper`` becomes ``this``, so ``this.helper.foo()`` reads
        ``this.foo()``. Member accesses without a direct identifier (such as
        method calls) return None from ``IDENTIFIER()`` and are skipped.
        """
        identifier = ctx.IDENTIFIER()
        # The original test was 'is None and ...getText()', which crashed on
        # every identifier-less access and never matched otherwise.
        if identifier is None or identifier.getText() not in self.field_that_has_source:
            return
        field_text = ctx.expression().getText()
        self.token_stream_rewriter.replace(
            program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME,
            from_idx=ctx.start.tokenIndex,
            to_idx=ctx.stop.tokenIndex,
            text=field_text
        )

    def exitExpression21(self, ctx: JavaParserLabeled.Expression21Context):
        """Delete this expression if a source-class ``new`` is pending.

        Consumes the ``has_source_new`` flag and removes the expression plus
        the token right after it. Presumably Expression21 is an assignment
        such as ``this.helper = new Source()`` — TODO confirm the grammar.
        """
        if not self.has_source_new:
            return
        self.has_source_new = False
        rewriter = self.token_stream_rewriter
        rewriter.delete(
            program_name=rewriter.DEFAULT_PROGRAM_NAME,
            from_idx=ctx.start.tokenIndex,
            to_idx=ctx.stop.tokenIndex + 1
        )

    def enterExpression4(self, ctx: JavaParserLabeled.Expression4Context):
        """Flag a ``new`` expression whose created name is the source class.

        The flag is only ever set here; ``exitExpression21`` clears it.
        """
        creator = ctx.children[-1]
        created_text = creator.children[0].getText()
        if created_text != self.source_class:
            return
        self.has_source_new = True

    def enterCreatedName0(self, ctx: JavaParserLabeled.CreatedName0Context):
        """Rename a created type whose first identifier is the source class."""
        first_name = ctx.IDENTIFIER(0).getText()
        if first_name != self.source_class or not self.target_class:
            return
        self.token_stream_rewriter.replaceIndex(index=ctx.start.tokenIndex, text=self.target_class)

    def enterCreatedName1(self, ctx: JavaParserLabeled.CreatedName1Context):
        """Rename a created type whose full text equals the source class."""
        is_source_name = ctx.getText() == self.source_class
        if not (is_source_name and self.target_class):
            return
        self.token_stream_rewriter.replaceIndex(index=ctx.start.tokenIndex, text=self.target_class)

    def enterFormalParameter(self, ctx: JavaParserLabeled.FormalParameterContext):
        """Rename a parameter type whose first identifier is the source class.

        Parameters without a class/interface type (primitives) are skipped.
        """
        class_type = ctx.typeType().classOrInterfaceType()
        if not class_type:
            return
        type_name = class_type.IDENTIFIER(0).getText()
        if type_name != self.source_class or not self.target_class:
            return
        self.token_stream_rewriter.replaceIndex(index=class_type.start.tokenIndex, text=self.target_class)

    def enterQualifiedName(self, ctx: JavaParserLabeled.QualifiedNameContext):
        """Rename the leading identifier of a qualified name if it is the source class."""
        leading = ctx.IDENTIFIER(0).getText()
        should_rename = leading == self.source_class and self.target_class
        if should_rename:
            self.token_stream_rewriter.replaceIndex(index=ctx.start.tokenIndex, text=self.target_class)

    def exitExpression0(self, ctx: JavaParserLabeled.Expression0Context):
        """Rename a primary expression that is exactly the source class name.

        This covers, e.g., static access through the class name.
        """
        primary_text = ctx.primary().getText()
        if primary_text != self.source_class or not self.target_class:
            return
        self.token_stream_rewriter.replaceIndex(index=ctx.start.tokenIndex, text=self.target_class)

    def enterLocalVariableDeclaration(self, ctx: JavaParserLabeled.LocalVariableDeclarationContext):
        """Rename a local variable's declared type from the source to the target class.

        Only class/interface types are considered, and only when their text
        exactly equals ``source_class`` (a qualified form such as
        ``pkg.Source`` does not match). This applies in every class, not only
        in the source or target class.
        """
        if ctx.typeType().classOrInterfaceType():
            if ctx.typeType().classOrInterfaceType().getText() == self.source_class and self.target_class:
                # Replace the whole declared-type span with the target class name.
                self.token_stream_rewriter.replace(
                    program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME,
                    from_idx=ctx.typeType().start.tokenIndex,
                    to_idx=ctx.typeType().stop.tokenIndex,
                    text=self.target_class
                )