class MethodUsageListener(UtilsListener):
    """Appends a `new <target_class>()` argument to calls (and constructor invocations) of the given methods."""

    def __init__(self, filename: str, methods: list, target_class: str):
        super(MethodUsageListener, self).__init__(filename)
        self.methods = methods
        self.method_names = set(map(lambda m: m.name, methods))
        self.rewriter = None
        self.target_class = target_class

    def enterCompilationUnit(self, ctx: JavaParser.CompilationUnitContext):
        super().enterCompilationUnit(ctx)
        self.rewriter = TokenStreamRewriter(ctx.parser.getTokenStream())

    def enterClassCreatorRest(self, ctx: JavaParser.ClassCreatorRestContext):
        if type(ctx.parentCtx) is JavaParser.CreatorContext:
            if ctx.parentCtx.createdName().IDENTIFIER()[0].getText() not in self.method_names:
                return
            text = f"new {self.target_class}()" if ctx.arguments().expressionList() is None \
                else f", new {self.target_class}()"
            index = ctx.arguments().RPAREN().symbol.tokenIndex
            self.rewriter.insertBeforeIndex(index, text)

    def exitMethodCall(self, ctx: JavaParser.MethodCallContext):
        super().exitMethodCall(ctx)
        if ctx.THIS() is not None:
            return
        if ctx.IDENTIFIER().getText() in self.method_names:
            text = f"new {self.target_class}()" if ctx.expressionList() is None \
                else f", new {self.target_class}()"
            self.rewriter.insertBeforeIndex(ctx.RPAREN().symbol.tokenIndex, text)

    def exitClassBody(self, ctx: JavaParser.ClassBodyContext):
        super().exitClassBody(ctx)
        save(self.rewriter, self.filename)
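# A minimal driver sketch showing how a listener like the one above is typically
# applied with the ANTLR Python runtime. `JavaLexer` and the exact way `methods`
# is collected are assumptions; only `JavaParser` and `MethodUsageListener`
# appear in the code above, and the real project may wire this up differently.
from antlr4 import FileStream, CommonTokenStream, ParseTreeWalker

def apply_method_usage_listener(filename, methods, target_class):
    stream = FileStream(filename, encoding='utf8')
    lexer = JavaLexer(stream)                     # assumed generated lexer
    tokens = CommonTokenStream(lexer)
    parser = JavaParser(tokens)                   # generated parser used by the listener
    tree = parser.compilationUnit()
    listener = MethodUsageListener(filename, methods, target_class)
    ParseTreeWalker().walk(listener, tree)        # exitClassBody saves the rewritten file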
def testInsertBeforeIndexZero(self):
    input = InputStream('abc')
    lexer = TestLexer(input)
    stream = CommonTokenStream(lexer=lexer)
    stream.fill()
    rewriter = TokenStreamRewriter(tokens=stream)
    rewriter.insertBeforeIndex(0, '0')
    self.assertEqual(rewriter.getDefaultText(), '0abc')
def testDropPrevCoveredInsert(self):
    input = InputStream('abc')
    lexer = TestLexer(input)
    stream = CommonTokenStream(lexer=lexer)
    stream.fill()
    rewriter = TokenStreamRewriter(tokens=stream)
    rewriter.insertBeforeIndex(1, 'foo')
    rewriter.replaceRange(1, 2, 'foo')
    self.assertEqual('afoofoo', rewriter.getDefaultText())
def testCombineInserts(self):
    input = InputStream('abc')
    lexer = TestLexer(input)
    stream = CommonTokenStream(lexer=lexer)
    stream.fill()
    rewriter = TokenStreamRewriter(tokens=stream)
    rewriter.insertBeforeIndex(0, 'x')
    rewriter.insertBeforeIndex(0, 'y')
    self.assertEqual('yxabc', rewriter.getDefaultText())
def testReplaceThenInsertBeforeLastIndex(self):
    input = InputStream('abc')
    lexer = TestLexer(input)
    stream = CommonTokenStream(lexer=lexer)
    stream.fill()
    rewriter = TokenStreamRewriter(tokens=stream)
    rewriter.replaceIndex(2, 'x')
    rewriter.insertBeforeIndex(2, 'y')
    self.assertEqual('abyx', rewriter.getDefaultText())
def testReplaceRangeThenInsertAtLeftEdge(self):
    input = InputStream('abcccba')
    lexer = TestLexer(input)
    stream = CommonTokenStream(lexer=lexer)
    stream.fill()
    rewriter = TokenStreamRewriter(tokens=stream)
    rewriter.replaceRange(2, 4, 'x')
    rewriter.insertBeforeIndex(2, 'y')
    self.assertEqual('abyxba', rewriter.getDefaultText())
def test2InsertBeforeAfterMiddleIndex(self):
    input = InputStream('abc')
    lexer = TestLexer(input)
    stream = CommonTokenStream(lexer=lexer)
    stream.fill()
    rewriter = TokenStreamRewriter(tokens=stream)
    rewriter.insertBeforeIndex(1, 'x')
    rewriter.insertAfter(1, 'x')
    self.assertEqual(rewriter.getDefaultText(), 'axbxc')
def testCombineInsertOnLeftWithDelete(self):
    input = InputStream('abc')
    lexer = TestLexer(input)
    stream = CommonTokenStream(lexer=lexer)
    stream.fill()
    rewriter = TokenStreamRewriter(tokens=stream)
    rewriter.delete('default', 0, 2)
    rewriter.insertBeforeIndex(0, 'z')
    self.assertEqual('z', rewriter.getDefaultText())
def testInsertBeforeTokenThenDeleteThatToken(self):
    input = InputStream('abc')
    lexer = TestLexer(input)
    stream = CommonTokenStream(lexer=lexer)
    stream.fill()
    rewriter = TokenStreamRewriter(tokens=stream)
    rewriter.insertBeforeIndex(1, 'foo')
    rewriter.replaceRange(1, 2, 'foo')
    self.assertEqual('afoofoo', rewriter.getDefaultText())
def testLeaveAloneDisjointInsert2(self):
    input = InputStream('abcc')
    lexer = TestLexer(input)
    stream = CommonTokenStream(lexer=lexer)
    stream.fill()
    rewriter = TokenStreamRewriter(tokens=stream)
    rewriter.replaceRange(2, 3, 'foo')
    rewriter.insertBeforeIndex(1, 'x')
    self.assertEqual('axbfoo', rewriter.getDefaultText())
def testCombineInsertOnLeftWithReplace(self):
    input = InputStream('abc')
    lexer = TestLexer(input)
    stream = CommonTokenStream(lexer=lexer)
    stream.fill()
    rewriter = TokenStreamRewriter(tokens=stream)
    rewriter.replaceRange(0, 2, 'foo')
    rewriter.insertBeforeIndex(0, 'z')
    self.assertEqual('zfoo', rewriter.getDefaultText())
def testInsertThenReplaceSameIndex(self):
    input = InputStream('abc')
    lexer = TestLexer(input)
    stream = CommonTokenStream(lexer=lexer)
    stream.fill()
    rewriter = TokenStreamRewriter(tokens=stream)
    rewriter.insertBeforeIndex(0, '0')
    rewriter.replaceIndex(0, 'x')
    self.assertEqual('0xbc', rewriter.getDefaultText())
def testDisjointInserts(self):
    input = InputStream('abc')
    lexer = TestLexer(input)
    stream = CommonTokenStream(lexer=lexer)
    stream.fill()
    rewriter = TokenStreamRewriter(tokens=stream)
    rewriter.insertBeforeIndex(1, 'x')
    rewriter.insertBeforeIndex(2, 'y')
    rewriter.insertBeforeIndex(0, 'z')
    self.assertEqual('zaxbyc', rewriter.getDefaultText())
def test2ReplaceMiddleIndex1InsertBefore(self):
    input = InputStream('abc')
    lexer = TestLexer(input)
    stream = CommonTokenStream(lexer=lexer)
    stream.fill()
    rewriter = TokenStreamRewriter(tokens=stream)
    rewriter.insertBeforeIndex(0, "_")
    rewriter.replaceIndex(1, 'x')
    rewriter.replaceIndex(1, 'y')
    self.assertEqual('_ayc', rewriter.getDefaultText())
def testPreservesOrderOfContiguousInserts(self):
    """
    Test for fix for: https://github.com/antlr/antlr4/issues/550
    """
    input = InputStream('aa')
    lexer = TestLexer(input)
    stream = CommonTokenStream(lexer=lexer)
    stream.fill()
    rewriter = TokenStreamRewriter(tokens=stream)
    rewriter.insertBeforeIndex(0, '<b>')
    rewriter.insertAfter(0, '</b>')
    rewriter.insertBeforeIndex(1, '<b>')
    rewriter.insertAfter(1, '</b>')
    self.assertEqual('<b>a</b><b>a</b>', rewriter.getDefaultText())
def testReplaceThenDeleteMiddleIndex(self):
    input = InputStream('abc')
    lexer = TestLexer(input)
    stream = CommonTokenStream(lexer=lexer)
    stream.fill()
    rewriter = TokenStreamRewriter(tokens=stream)
    rewriter.replaceRange(0, 2, 'x')
    rewriter.insertBeforeIndex(1, '0')
    with self.assertRaises(ValueError) as ctx:
        rewriter.getDefaultText()
    self.assertEqual(
        'insert op <InsertBeforeOp@[@1,1:1=\'b\',<2>,1:1]:"0"> within boundaries of previous <ReplaceOp@[@0,0:0=\'a\',<1>,1:0]..[@2,2:2=\'c\',<3>,1:2]:"x">',
        str(ctx.exception)
    )
def testReplaceRangeThenInsertAtRightEdge(self):
    input = InputStream('abcccba')
    lexer = TestLexer(input)
    stream = CommonTokenStream(lexer=lexer)
    stream.fill()
    rewriter = TokenStreamRewriter(tokens=stream)
    rewriter.replaceRange(2, 4, 'x')
    rewriter.insertBeforeIndex(4, 'y')
    with self.assertRaises(ValueError) as ctx:
        rewriter.getDefaultText()
    msg = str(ctx.exception)
    self.assertEqual(
        "insert op <InsertBeforeOp@[@4,4:4='c',<3>,1:4]:\"y\"> within boundaries of previous <ReplaceOp@[@2,2:2='c',<3>,1:2]..[@4,4:4='c',<3>,1:4]:\"x\">",
        msg
    )
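# Sketch of a shared setup helper (not part of the original suite): every test
# above builds the same lexer -> filled token stream -> rewriter pipeline, so
# the boilerplate could be factored onto the TestCase class like this, assuming
# TestLexer remains the generated lexer these tests rely on.
def make_rewriter(self, text):
    lexer = TestLexer(InputStream(text))
    stream = CommonTokenStream(lexer=lexer)
    stream.fill()
    return TokenStreamRewriter(tokens=stream)

# Usage inside a test: rewriter = self.make_rewriter('abc')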
class FieldUsageListener(UtilsListener): """ FieldUsageListener finds all the usage of an specified field f, from a class c in package pkg. """ def __init__(self, filename: str, source_class: str, source_package: str, target_class: str, target_package: str, field_name: str, field_candidates: set, field_tobe_moved: Field): super(FieldUsageListener, self).__init__(filename) self.source_class = source_class self.source_package = source_package self.target_class = target_class self.target_package = target_package self.field_name = field_name self.has_imported_source = False self.has_imported_target = False self.usages = [] # current class name is the public class in each file. self.current_class_name = "" self.field_candidates = field_candidates self.rewriter = None # this represents the text to be added in target i.e. public int a; self.field_tobe_moved = field_tobe_moved self.methods_tobe_updated = [] def enterCompilationUnit(self, ctx: JavaParser.CompilationUnitContext): super().enterCompilationUnit(ctx) self.rewriter = TokenStreamRewriter(ctx.parser.getTokenStream()) def enterClassDeclaration(self, ctx: JavaParser.ClassDeclarationContext): super().enterClassDeclaration(ctx) if ctx.parentCtx.classOrInterfaceModifier()[0].getText() == "public": self.current_class_name = ctx.IDENTIFIER().getText() else: return self.has_imported_source = self.file_info.has_imported_package(self.package.name) or \ self.file_info.has_imported_class(self.package.name, self.source_class) # import target if we're not in Target and have not imported before if self.current_class_name != self.target_class: self.rewriter.insertBeforeIndex( ctx.parentCtx.start.tokenIndex, f"import {self.target_package}.{self.target_class};\n") def enterClassBody(self, ctx: JavaParser.ClassBodyContext): super().exitClassBody(ctx) if self.current_class_name == self.target_class: replacement_text = "" if self.field_tobe_moved.name == self.field_name: for mod in self.field_tobe_moved.modifiers: replacement_text += f"{mod} " replacement_text += f"{self.field_tobe_moved.datatype} {self.field_tobe_moved.name};" self.rewriter.insertAfter(ctx.start.tokenIndex, f"\n\t{replacement_text}\n") # add getter and setter name = self.field_tobe_moved.name method_name = self.field_tobe_moved.name.upper( ) + self.field_tobe_moved.name[1:-1] type = self.field_tobe_moved.datatype getter = f"\tpublic {type} get{method_name}() {{ return this.{name}; }}\n" setter = f"\tpublic void set{method_name}({type} {name}) {{ this.{name} = {name}; }}\n" self.rewriter.insertBeforeIndex(ctx.stop.tokenIndex, getter) self.rewriter.insertBeforeIndex(ctx.stop.tokenIndex, setter) def exitFieldDeclaration(self, ctx: JavaParser.FieldDeclarationContext): super().exitFieldDeclaration(ctx) if self.current_class_name != self.source_class: return if self.field_tobe_moved is None: field = self.package.classes[self.current_class_name].fields[ ctx.variableDeclarators().children[0].children[0].IDENTIFIER( ).getText()] if field.name == self.field_name: self.field_tobe_moved = field def exitClassBody(self, ctx: JavaParser.ClassBodyContext): super().exitClassBody(ctx) save(self.rewriter, self.filename) def exitMethodDeclaration(self, ctx: JavaParser.MethodDeclarationContext): super().exitMethodDeclaration(ctx) # we will remove getter and setter from source # and add it to target so there is no need to # find usages there if self.current_class_name == self.source_class and \ self.is_method_getter_or_setter(ctx.IDENTIFIER().getText()): self.rewriter.replaceRange( 
ctx.parentCtx.parentCtx.start.tokenIndex, ctx.parentCtx.parentCtx.stop.tokenIndex, "") def exitConstructorDeclaration( self, ctx: JavaParser.ConstructorDeclarationContext): self.current_method.name = ctx.IDENTIFIER().getText() self.current_method.returntype = self.current_method.class_name self.handleMethodUsage(ctx, True) super().exitConstructorDeclaration(ctx) def exitMethodBody(self, ctx: JavaParser.MethodBodyContext): super().exitMethodBody(ctx) self.handleMethodUsage(ctx, False) def handleMethodUsage(self, ctx, is_constructor: bool): method_identifier = ctx.IDENTIFIER().getText( ) if is_constructor else ctx.parentCtx.IDENTIFIER().getText() formal_params = ctx.formalParameters( ) if is_constructor else ctx.parentCtx.formalParameters() target_added = False target_param_name = "$$target" target_param = f"Target {target_param_name}" if \ len(self.current_method.parameters) == 0 \ else f", Target {target_param_name}" # if we have not imported source package or # Source class just ignore this if not self.has_imported_source: return local_candidates = set() if self.current_class_name == self.source_class: # we will remove getter and setter from source # and add it to target so there is no need to # find usages there if self.is_method_getter_or_setter(method_identifier): self.rewriter.replaceRange(ctx.start.tokenIndex, ctx.stop.tokenIndex, "") return local_candidates.add("this") # find parameters with type Source for t, identifier in self.current_method.parameters: if t == self.source_class: local_candidates.add(identifier) # find all local variables with type Source for var_or_exprs in self.current_method.body_local_vars_and_expr_names: if type(var_or_exprs) is LocalVariable: if var_or_exprs.datatype == self.source_class: local_candidates.add(var_or_exprs.identifier) should_ignore = False for var_or_exprs in self.current_method.body_local_vars_and_expr_names: if type(var_or_exprs) is ExpressionName: # we're going to find source.field try: local_ctx = var_or_exprs.parser_context.parentCtx.parentCtx.parentCtx.parentCtx.parentCtx.parentCtx creator = local_ctx.expression()[0].getText() if creator.__contains__( f"new{self.source_class}" ) and local_ctx.IDENTIFIER().getText() == self.field_name: self.propagate_field(local_ctx, target_param_name) except: pass if len(var_or_exprs.dot_separated_identifiers) < 2: continue if (var_or_exprs.dot_separated_identifiers[0] in local_candidates or var_or_exprs.dot_separated_identifiers[0] in self.field_candidates) and \ var_or_exprs.dot_separated_identifiers[1] == self.field_name: if not target_added: # add target to param self.rewriter.insertBeforeIndex( formal_params.stop.tokenIndex, target_param) self.methods_tobe_updated.append(self.current_method) target_added = True self.usages.append(var_or_exprs.parser_context) self.propagate_field(var_or_exprs.parser_context, target_param_name) elif type(var_or_exprs) is MethodInvocation: # we are going to find getter or setters # if len(var_or_exprs.dot_separated_identifiers) < 2: # continue if var_or_exprs.dot_separated_identifiers[ 0] == f"new{self.source_class}": if var_or_exprs.parser_context.methodCall() is not None and \ self.is_method_getter_or_setter( var_or_exprs.parser_context.methodCall().IDENTIFIER().getText()): self.propagate_getter_setter( var_or_exprs.parser_context, target_param_name) elif self.is_method_getter_or_setter( var_or_exprs.dot_separated_identifiers[0]): if not target_added: # add target to param self.rewriter.insertBeforeIndex( formal_params.stop.tokenIndex, target_param) 
self.methods_tobe_updated.append(self.current_method) target_added = True if not should_ignore and var_or_exprs.parser_context is not None and type( var_or_exprs.parser_context ) is not JavaParser.ExpressionContext: continue self.usages.append(var_or_exprs.parser_context) self.propagate_getter_setter_form2( var_or_exprs.parser_context, target_param_name) elif len(var_or_exprs.dot_separated_identifiers ) > 1 and self.is_getter_or_setter( var_or_exprs.dot_separated_identifiers[0], var_or_exprs.dot_separated_identifiers[1], local_candidates): if not target_added: # add target to param self.rewriter.insertBeforeIndex( formal_params.stop.tokenIndex, target_param) self.methods_tobe_updated.append(self.current_method) target_added = True self.usages.append(var_or_exprs.parser_context) self.propagate_getter_setter(var_or_exprs.parser_context, target_param_name) def is_getter_or_setter(self, first_id: str, second_id: str, local_candidates: set): return ( first_id in local_candidates or first_id in self.field_candidates ) and (second_id == f"set{self.field_name[0].upper() + self.field_name[1:-1]}" or second_id == f"get{self.field_name[0].upper() + self.field_name[1:-1]}" or second_id == f"has{self.field_name[0].upper() + self.field_name[1:-1]}" or second_id == f"is{self.field_name[0].upper() + self.field_name[1:-1]}") def is_method_getter_or_setter(self, method: str): return (method == f"set{self.field_name[0].upper() + self.field_name[1:-1]}" or method == f"get{self.field_name[0].upper() + self.field_name[1:-1]}" or method == f"has{self.field_name[0].upper() + self.field_name[1:-1]}" or method == f"is{self.field_name[0].upper() + self.field_name[1:-1]}") def propagate_getter_setter(self, ctx: JavaParser.ExpressionContext, target_name: str): index = ctx.DOT().symbol.tokenIndex self.rewriter.replaceRange(ctx.start.tokenIndex, index - 1, target_name) def propagate_getter_setter_form2(self, ctx: JavaParser.ExpressionContext, target_name: str): """ form 2 is getA() setA()... """ self.rewriter.insertBeforeIndex(ctx.start.tokenIndex, f"{target_name}.") def propagate_field(self, ctx: JavaParser.ExpressionContext, target_name: str): index = ctx.DOT().symbol.tokenIndex self.rewriter.replaceRange(ctx.start.tokenIndex, index - 1, target_name)
class EncapsulateFiledRefactoringListener(JavaParserLabeledListener): """ To implement encapsulate field refactoring. Makes a public field private and provide accessors and mutator methods. """ def __init__(self, common_token_stream: CommonTokenStream = None, package_name: str = None, source_class_name: str = None, field_identifier: str = None): """ Args: common_token_stream (CommonTokenStream): contains the program tokens package_name (str): The enclosing package of the field source_class_name (str): The enclosing class of the field field_identifier (str): The field name to be encapsulated Returns: object (DecreaseMethodVisibilityListener): An instance of EncapsulateFiledRefactoringListener """ self.token_stream = common_token_stream if package_name is None: self.package_name = '' else: self.package_name = package_name self.source_class_name = source_class_name self.field_identifier = field_identifier self.getter_exist = False self.setter_exist = False self.in_source_class = False self.in_selected_package = True if self.package_name == '' else False # Move all the tokens in the source code in a buffer, token_stream_rewriter. if common_token_stream is not None: self.token_stream_rewriter = \ TokenStreamRewriter(common_token_stream) else: raise TypeError('common_token_stream is None') def enterPackageDeclaration( self, ctx: JavaParserLabeled.PackageDeclarationContext): if self.package_name == ctx.qualifiedName().getText(): self.in_selected_package = True else: self.in_selected_package = False def enterClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext): if ctx.IDENTIFIER().getText() == self.source_class_name: self.in_source_class = True def exitClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext): self.in_source_class = False def exitFieldDeclaration(self, ctx: JavaParserLabeled.FieldDeclarationContext): if self.in_source_class and self.in_selected_package: if ctx.variableDeclarators().variableDeclarator( 0).variableDeclaratorId().getText( ) == self.field_identifier: if not ctx.parentCtx.parentCtx.modifier(0): self.token_stream_rewriter.insertBeforeIndex( index=ctx.typeType().stop.tokenIndex, text='private ') elif ctx.parentCtx.parentCtx.modifier(0).getText() == 'public': self.token_stream_rewriter.replaceRange( from_idx=ctx.parentCtx.parentCtx.modifier( 0).start.tokenIndex, to_idx=ctx.parentCtx.parentCtx.modifier( 0).stop.tokenIndex, text='private') else: return for c in ctx.parentCtx.parentCtx.parentCtx.classBodyDeclaration( ): try: print('method name: ' + c.memberDeclaration(). methodDeclaration().IDENTIFIER().getText()) if c.memberDeclaration().methodDeclaration().IDENTIFIER() \ .getText() == 'get' + str.capitalize( self.field_identifier): self.getter_exist = True if c.memberDeclaration().methodDeclaration().IDENTIFIER() \ .getText() == 'set' + str.capitalize( self.field_identifier): self.setter_exist = True except: logger.error("not method !!!") logger.debug("setter find: " + str(self.setter_exist)) logger.debug("getter find: " + str(self.getter_exist)) # generate accessor and mutator methods # Accessor body new_code = '' if not self.getter_exist: new_code = '\n\t// new getter method\n\t' new_code += 'public ' + ctx.typeType().getText() + \ ' get' + str.capitalize(self.field_identifier) new_code += '() { \n\t\treturn this.' 
+ self.field_identifier \ + ';' + '\n\t}\n' # Mutator body if not self.setter_exist: new_code += '\n\t// new setter method\n\t' new_code += 'public void set' + str.capitalize( self.field_identifier) new_code += '(' + ctx.typeType().getText() + ' ' \ + self.field_identifier + ') { \n\t\t' new_code += 'this.' + self.field_identifier + ' = ' \ + self.field_identifier + ';' + '\n\t}\n' self.token_stream_rewriter.insertAfter(ctx.stop.tokenIndex, new_code) hidden = self.token_stream.getHiddenTokensToRight( ctx.stop.tokenIndex) # self.token_stream_rewriter.replaceRange(from_idx=hidden[0].tokenIndex, # to_idx=hidden[-1].tokenIndex, # text='\n\t/*End of accessor and mutator methods!*/\n\n') def exitExpression21(self, ctx: JavaParserLabeled.Expression21Context): if self.in_source_class and self.in_selected_package: if ctx.expression(0).getText() == self.field_identifier or \ ctx.expression(0).getText() == 'this.' + self.field_identifier: expr_code = self.token_stream_rewriter.getText( program_name=self.token_stream_rewriter. DEFAULT_PROGRAM_NAME, start=ctx.expression(1).start.tokenIndex, stop=ctx.expression(1).stop.tokenIndex) new_code = 'this.set' + str.capitalize( self.field_identifier) + '(' + expr_code + ')' self.token_stream_rewriter.replaceRange( ctx.start.tokenIndex, ctx.stop.tokenIndex, new_code) def exitExpression0(self, ctx: JavaParserLabeled.Expression0Context): if self.in_source_class and self.in_selected_package: try: if ctx.parentCtx.getChild(1).getText() in ('=', '+=', '-=', '*=', '/=', '&=', '|=', '^=', '>>=', '>>>=', '<<=', '%=') and \ ctx.parentCtx.getChild(0) == ctx: return except: pass if ctx.getText() == self.field_identifier: new_code = 'this.get' + str.capitalize( self.field_identifier) + '()' self.token_stream_rewriter.replaceRange( ctx.start.tokenIndex, ctx.stop.tokenIndex, new_code) def exitExpression1(self, ctx: JavaParserLabeled.Expression1Context): if self.in_source_class and self.in_selected_package: try: if ctx.parentCtx.getChild(1).getText() in ('=', '+=', '-=', '*=', '/=', '&=', '|=', '^=', '>>=', '>>>=', '<<=', '%=') and \ ctx.parentCtx.getChild(0) == ctx: return except: pass if ctx.getText() == 'this.' + self.field_identifier: new_code = 'this.get' + str.capitalize( self.field_identifier) + '()' self.token_stream_rewriter.replaceRange( ctx.start.tokenIndex, ctx.stop.tokenIndex, new_code) def exitCompilationUnit(self, ctx: JavaParserLabeled.CompilationUnitContext): try: hidden = self.token_stream.getHiddenTokensToLeft( ctx.start.tokenIndex) self.token_stream_rewriter.replaceRange( from_idx=hidden[0].tokenIndex, to_idx=hidden[-1].tokenIndex, text='/*After refactoring (Refactored version)*/\n') except: pass
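# Illustrative only: a hypothetical before/after of the encapsulate-field
# transformation the listener above aims at. The field name `id` is invented,
# and the exact whitespace/comments the listener emits may differ.
ENCAPSULATE_FIELD_BEFORE = """
public class A {
    public int id;

    void touch() { id = id + 1; }
}
"""

ENCAPSULATE_FIELD_AFTER = """
public class A {
    private int id;
    // new getter method
    public int getId() { return this.id; }
    // new setter method
    public void setId(int id) { this.id = id; }

    void touch() { this.setId(this.getId() + 1); }
}
"""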
def do_refactor(self): program = get_program(self.source_filenames) static = 0 if self.class_name not in program.packages[self.package_name].classes or self.target_class_name not in \ program.packages[ self.target_package_name].classes or self.method_key not in \ program.packages[self.package_name].classes[ self.class_name].methods: return False _sourceclass = program.packages[self.package_name].classes[ self.class_name] _targetclass = program.packages[self.target_package_name].classes[ self.target_class_name] _method = program.packages[self.package_name].classes[ self.class_name].methods[self.method_key] if _method.is_constructor: return False Rewriter_ = Rewriter(program, lambda x: x) tokens_info = TokensInfo( _method.parser_context) # tokens of ctx method param_tokens_info = TokensInfo(_method.formalparam_context) method_declaration_info = TokensInfo( _method.method_declaration_context) exp = [ ] # برای نگه داری متغیرهایی که داخل کلاس تعریف شدند و در بدنه متد استفاده شدند exps = tokens_info.get_token_index(tokens_info.token_stream.tokens, tokens_info.start, tokens_info.stop) # check that method is static or not for modifier in _method.modifiers: if modifier == "static": static = 1 for token in exps: if token.text in _sourceclass.fields: exp.append(token.tokenIndex) # check that where this method is call for package_names in program.packages: package = program.packages[package_names] for class_ in package.classes: _class = package.classes[class_] for method_ in _class.methods: __method = _class.methods[method_] for inv in __method.body_method_invocations: invc = __method.body_method_invocations[inv] method_name = self.method_key[:self.method_key.find('(' )] if invc[0] == method_name: inv_tokens_info = TokensInfo(inv) if static == 0: class_token_info = TokensInfo( _class.body_context) Rewriter_.insert_after_start( class_token_info, self.target_class_name + " " + str.lower(self.target_class_name) + "=" + "new " + self.target_class_name + "();") Rewriter_.apply() Rewriter_.insert_before_start( class_token_info, "import " + self.target_package_name + "." 
+ self.target_class_name + ";") Rewriter_.replace(inv_tokens_info, self.target_class_name) Rewriter_.apply() class_tokens_info = TokensInfo(_targetclass.parser_context) package_tokens_info = TokensInfo( program.packages[self.target_package_name].package_ctx) singlefileelement = SingleFileElement(_method.parser_context, _method.filename) token_stream_rewriter = TokenStreamRewriter( singlefileelement.get_token_stream()) # insert name of source.java class befor param that define in body of classe (that use in method) for index in exp: token_stream_rewriter.insertBeforeIndex( index=index, text=str.lower(self.class_name) + ".") for inv in _method.body_method_invocations: if inv.getText() == self.target_class_name: inv_tokens_info_target = TokensInfo(inv) token_stream_rewriter.replaceRange( from_idx=inv_tokens_info_target.start, to_idx=inv_tokens_info_target.stop + 1, text=" ") # insert source.java class befor methods of sourcr class that used in method for i in _method.body_method_invocations_without_typename: if i.getText() == self.class_name: ii = _method.body_method_invocations_without_typename[i] i_tokens = TokensInfo(ii[0]) token_stream_rewriter.insertBeforeIndex( index=i_tokens.start, text=str.lower(self.class_name) + ".") # pass object of source.java class to method if param_tokens_info.start is not None: token_stream_rewriter.insertBeforeIndex( param_tokens_info.start, text=self.class_name + " " + str.lower(self.class_name) + ",") else: token_stream_rewriter.insertBeforeIndex( method_declaration_info.stop, text=self.class_name + " " + str.lower(self.class_name)) strofmethod = token_stream_rewriter.getText( program_name=token_stream_rewriter.DEFAULT_PROGRAM_NAME, start=tokens_info.start, stop=tokens_info.stop) Rewriter_.insert_before(tokens_info=class_tokens_info, text=strofmethod) Rewriter_.insert_after( package_tokens_info, "import " + self.target_package_name + "." + self.target_class_name + ";") Rewriter_.replace(tokens_info, "") Rewriter_.apply() return True
class ReplaceConstructorWithFactoryFunctionRefactoringListener(JavaParserLabeledListener):
    def __init__(self, common_token_stream: CommonTokenStream = None, target_class: str = None):
        if common_token_stream is None:
            raise ValueError('common_token_stream is None')
        else:
            self.codeRewrite = TokenStreamRewriter(common_token_stream)
        if target_class is None:
            raise ValueError("source_class is None")
        else:
            self.target_class = target_class
        self.is_target_class = False
        self.have_constructor = False
        self.new_factory_function = False
        self.new_parameters = []
        self.new_parameters_names = []

    def enterClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext):
        # self.target_class = ctx.IDENTIFIER().getText()
        # have_constructor = False
        # if ctx.IDENTIFIER().getText() ==
        class_identifier = ctx.IDENTIFIER().getText()
        if class_identifier == self.target_class:
            self.is_target_class = True
            # print("class name " + ctx.IDENTIFIER().getText())
        else:
            self.is_target_class = False

    def exitClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext):
        if self.is_target_class:
            self.is_target_class = False

    def enterConstructorDeclaration(self, ctx: JavaParserLabeled.ConstructorDeclarationContext):
        if self.is_target_class:
            # print("constructor name " + ctx.IDENTIFIER().getText())
            # parameters = ctx.formalParameters().getText()
            # print(len(ctx.formalParameters().formalParameterList().formalParameter()))
            grandParentCtx = ctx.parentCtx.parentCtx
            if ctx.IDENTIFIER().getText() == self.target_class:
                self.have_constructor = True
                # do refactor
                """
                Declare the constructor private.
                """
                if grandParentCtx.modifier():
                    if 'public' == grandParentCtx.modifier(0).getText():
                        self.codeRewrite.replaceRange(
                            from_idx=grandParentCtx.modifier(0).start.tokenIndex,
                            to_idx=grandParentCtx.modifier(0).stop.tokenIndex,
                            text='private')
                else:
                    self.codeRewrite.insertBeforeIndex(
                        index=ctx.start.tokenIndex,
                        text="private ")

    def exitConstructorDeclaration(self, ctx: JavaParserLabeled.ConstructorDeclarationContext):
        """
        Create a factory method. Make its body a call to the current constructor.
        """
        if self.is_target_class:
            grandParentCtx = ctx.parentCtx.parentCtx
            self.codeRewrite.insertAfter(
                index=grandParentCtx.stop.tokenIndex,
                text="\n public static " + ctx.IDENTIFIER().getText() + " Create( " +
                     ", ".join(self.new_parameters) + "){\n return new " +
                     ctx.IDENTIFIER().getText() + "(" + ", ".join(self.new_parameters_names) + ");\n}"
            )
            self.new_parameters = []
            self.new_parameters_names = []

    def enterFormalParameterList0(self, ctx: JavaParserLabeled.FormalParameterList0Context):
        # print(len(ctx.formalParameter()))
        pass

    def exitFormalParameterList0(self, ctx: JavaParserLabeled.FormalParameterList0Context):
        pass

    def enterFormalParameter(self, ctx: JavaParserLabeled.FormalParameterContext):
        # print(ctx.typeType().getText())
        # print(ctx.variableDeclaratorId().getText())
        constructorName = ctx.parentCtx.parentCtx.parentCtx.IDENTIFIER().getText()
        if self.target_class == constructorName:
            text = ctx.typeType().getText() + " " + ctx.variableDeclaratorId().getText()
            self.new_parameters.append(text)
            self.new_parameters_names.append(ctx.variableDeclaratorId().getText())

    def exitFormalParameter(self, ctx: JavaParserLabeled.FormalParameterContext):
        pass

    def enterExpression4(self, ctx: JavaParserLabeled.Expression4Context):
        """
        Replace all constructor calls with calls to the factory method.
        """
        # currentMethodOrClassCtx = ctx.parentCtx.parentCtx.parentCtx.parentCtx.parentCtx.parentCtx.parentCtx.parentCtx
        # print(ctx.parentCtx.parentCtx.parentCtx.parentCtx.parentCtx.parentCtx.parentCtx.parentCtx.getText())
        if ctx.creator().createdName().getText() == self.target_class:
            self.codeRewrite.replaceRange(
                from_idx=ctx.start.tokenIndex,
                to_idx=ctx.stop.tokenIndex,
                text=self.target_class + "." + "Create" + ctx.creator().classCreatorRest().getText())
class MakeFieldStaticRefactoringListener(JavaParserLabeledListener): """ To implement the encapsulate filed refactored Encapsulate field: Make a public field private and provide accessors """ def __init__(self, common_token_stream: CommonTokenStream = None, field_identifier: str = None, class_identifier: str = None, package_identifier: str = None): """ :param common_token_stream: """ self.token_stream = common_token_stream self.field_identifier = field_identifier self.class_identifier = class_identifier self.package_identifier = package_identifier self.declared_objects_names = [] self.is_package_imported = False self.in_selected_package = False self.in_selected_class = False self.in_some_package = False # Move all the tokens in the source code in a buffer, token_stream_rewriter. if common_token_stream is not None: self.token_stream_rewriter = TokenStreamRewriter( common_token_stream) else: raise TypeError('common_token_stream is None') def enterPackageDeclaration( self, ctx: JavaParserLabeled.PackageDeclarationContext): self.in_some_package = True if self.package_identifier is not None: if self.package_identifier == ctx.qualifiedName().getText(): self.in_selected_package = True print("Package Found") def enterClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext): if self.package_identifier is None and not self.in_some_package or\ self.package_identifier is not None and self.in_selected_package: if ctx.IDENTIFIER().getText() == self.class_identifier: print("Class Found") self.in_selected_class = True def enterImportDeclaration( self, ctx: JavaParserLabeled.ImportDeclarationContext): if self.package_identifier is not None: if ctx.getText() == "import" + self.package_identifier + "." + self.class_identifier + ";" \ or ctx.getText() == "import" + self.package_identifier + ".*" + ";" \ or ctx.getText() == "import" + self.package_identifier + ";": self.is_package_imported = True def exitFieldDeclaration(self, ctx: JavaParserLabeled.FieldDeclarationContext): if self.package_identifier is None and not self.in_some_package\ or self.package_identifier is not None and self.in_selected_package: if self.in_selected_class: if ctx.variableDeclarators().variableDeclarator(0)\ .variableDeclaratorId().getText() == self.field_identifier: grand_parent_ctx = ctx.parentCtx.parentCtx if len(grand_parent_ctx.modifier()) == 0: self.token_stream_rewriter.insertBeforeIndex( index=ctx.parentCtx.start.tokenIndex, text=' static ') else: is_static = False for modifier in grand_parent_ctx.modifier(): if modifier.getText() == "static": is_static = True break if not is_static: self.token_stream_rewriter.insertAfter( index=grand_parent_ctx.start.tokenIndex + len(grand_parent_ctx.modifier()), text=' static ') if self.package_identifier is None or self.package_identifier is not None and self.is_package_imported: if ctx.typeType().classOrInterfaceType() is not None: if ctx.typeType().classOrInterfaceType().getText( ) == self.class_identifier: self.declared_objects_names.append( ctx.variableDeclarators().variableDeclarator( 0).variableDeclaratorId().getText()) print("Object " + ctx.variableDeclarators().variableDeclarator(0).variableDeclaratorId().getText()\ + " of type " + self.class_identifier + " found.") def enterExpression1(self, ctx: JavaParserLabeled.Expression1Context): if self.is_package_imported or self.package_identifier is None or self.in_selected_package: for object_name in self.declared_objects_names: if ctx.getText() == object_name + "." 
+ self.field_identifier: self.token_stream_rewriter.replaceIndex( index=ctx.start.tokenIndex, text=self.class_identifier)
class MakeMethodNonStaticRefactoringListener(JavaParserLabeledListener): """ To implement Make Method None-Static refactoring based on its actors. """ def __init__(self, common_token_stream: CommonTokenStream = None, target_class: str = None, target_methods: list = None): """ """ if common_token_stream is None: raise ValueError('common_token_stream is None') else: self.token_stream_rewriter = TokenStreamRewriter( common_token_stream) if target_class is None: raise ValueError("source_class is None") else: self.target_class = target_class if target_methods is None or len(target_methods) == 0: raise ValueError("target method must have one method name") else: self.target_methods = target_methods self.target_class_data = None self.is_target_class = False self.detected_field = None self.detected_method = None self.TAB = "\t" self.NEW_LINE = "\n" self.code = "" def enterClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext): class_identifier = ctx.IDENTIFIER().getText() if class_identifier == self.target_class: self.is_target_class = True self.target_class_data = {'constructors': []} else: self.is_target_class = False def exitClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext): if self.is_target_class: have_default_constructor = False for constructor in self.target_class_data['constructor']: if len(constructor.parameters) == 0: have_default_constructor = True break if not have_default_constructor: self.token_stream_rewriter.insertBeforeIndex( index=ctx.stop.tokenIndex - 1, text= f'\n\t public {self.target_class_data["constructors"][0]} ()\n\t{{}}\n' ) self.is_target_class = False def enterMethodDeclaration( self, ctx: JavaParserLabeled.MethodDeclarationContext): if self.is_target_class: if ctx.IDENTIFIER().getText() in self.target_methods: grand_parent_ctx = ctx.parentCtx.parentCtx if grand_parent_ctx.modifier(): if len(grand_parent_ctx.modifier()) == 2: self.token_stream_rewriter.delete( program_name=self.token_stream_rewriter. DEFAULT_PROGRAM_NAME, from_idx=grand_parent_ctx.modifier( 1).start.tokenIndex - 1, to_idx=grand_parent_ctx.modifier( 1).stop.tokenIndex) else: if grand_parent_ctx.modifier(0).getText() == 'static': self.token_stream_rewriter.delete( program_name=self.token_stream_rewriter. DEFAULT_PROGRAM_NAME, from_idx=grand_parent_ctx.modifier( 0).start.tokenIndex - 1, to_idx=grand_parent_ctx.modifier( 0).stop.tokenIndex) else: return None def enterConstructorDeclaration( self, ctx: JavaParserLabeled.ConstructorDeclarationContext): if self.is_target_class: if ctx.formalParameters().formalParameterList(): constructor_parameters = [ ctx.formalParameters().formalParameterList().children[i] for i in range( len(ctx.formalParameters().formalParameterList(). 
children)) if i % 2 == 0 ] else: constructor_parameters = [] constructor_text = '' for modifier in ctx.parentCtx.parentCtx.modifier(): constructor_text += modifier.getText() + ' ' constructor_text += ctx.IDENTIFIER().getText() constructor_text += ' ( ' for parameter in constructor_parameters: constructor_text += parameter.typeType().getText() + ' ' constructor_text += parameter.variableDeclaratorId().getText( ) + ', ' if constructor_parameters: constructor_text = constructor_text[:len(constructor_text) - 2] constructor_text += ')\n\t{' constructor_text += self.token_stream_rewriter.getText( program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME, start=ctx.block().start.tokenIndex + 1, stop=ctx.block().stop.tokenIndex - 1) constructor_text += '}\n' self.target_class_data['constructors'].append( ConstructorOrMethod( name=self.target_class, parameters=[ Parameter(parameterType=p.typeType().getText(), name=p.variableDeclaratorId().IDENTIFIER(). getText()) for p in constructor_parameters ], text=constructor_text))
class MakeMethodStaticRefactoringListener(JavaParserLabeledListener):
    """
    To implement Make Method Static refactoring based on its actors.
    Adds the `static` modifier to the target methods and redirects call sites
    from instances of the target class to the class itself.
    """

    def __init__(self, common_token_stream: CommonTokenStream = None,
                 target_class: str = None, target_methods: list = None):
        if common_token_stream is None:
            raise ValueError('common_token_stream is None')
        else:
            self.token_stream_rewriter = TokenStreamRewriter(common_token_stream)
        if target_class is None:
            raise ValueError("source_class is None")
        else:
            self.target_class = target_class
        if target_methods is None or len(target_methods) == 0:
            raise ValueError("target method must have one method name")
        else:
            self.target_methods = target_methods
        self.is_target_class = False
        self.detected_instance_of_target_class = []
        self.TAB = "\t"
        self.NEW_LINE = "\n"
        self.code = ""

    def enterClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext):
        class_identifier = ctx.IDENTIFIER().getText()
        if class_identifier == self.target_class:
            self.is_target_class = True
        else:
            self.is_target_class = False

    def exitClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext):
        if self.is_target_class:
            self.is_target_class = False

    def enterMethodDeclaration(self, ctx: JavaParserLabeled.MethodDeclarationContext):
        if self.is_target_class:
            if ctx.IDENTIFIER().getText() in self.target_methods:
                if 'this.' in ctx.getText():
                    raise ValueError("this method cannot be refactored")
                grand_parent_ctx = ctx.parentCtx.parentCtx
                if grand_parent_ctx.modifier():
                    if len(grand_parent_ctx.modifier()) == 2:
                        return None
                    else:
                        self.token_stream_rewriter.insertAfter(
                            index=grand_parent_ctx.modifier(0).stop.tokenIndex,
                            program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME,
                            text=" static")
                else:
                    self.token_stream_rewriter.insertBeforeIndex(
                        index=ctx.start.tokenIndex,
                        text="static ")

    def enterLocalVariableDeclaration(self, ctx: JavaParserLabeled.LocalVariableDeclarationContext):
        if ctx.typeType().getText() == self.target_class:
            self.detected_instance_of_target_class.append(
                ctx.variableDeclarators().variableDeclarator(0).variableDeclaratorId().IDENTIFIER().getText())
            self.token_stream_rewriter.delete(
                program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME,
                from_idx=ctx.start.tokenIndex,
                to_idx=ctx.stop.tokenIndex + 1)

    def enterMethodCall0(self, ctx: JavaParserLabeled.MethodCall0Context):
        if ctx.IDENTIFIER().getText() in self.target_methods:
            if ctx.parentCtx.expression().getText() in self.detected_instance_of_target_class:
                self.token_stream_rewriter.replace(
                    program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME,
                    from_idx=ctx.parentCtx.expression().start.tokenIndex,
                    to_idx=ctx.parentCtx.expression().stop.tokenIndex,
                    text=self.target_class)
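# Illustrative only: a hypothetical before/after for the make-method-static
# listener above. Names are invented; the listener also removes the local
# variable declaration of the target-class instance, as shown.
MAKE_METHOD_STATIC_BEFORE = """
class Calc {
    public int twice(int x) { return 2 * x; }
}
// call site elsewhere:
Calc c = new Calc();
int y = c.twice(21);
"""

MAKE_METHOD_STATIC_AFTER = """
class Calc {
    public static int twice(int x) { return 2 * x; }
}
// call site elsewhere:
int y = Calc.twice(21);
"""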
class InlineClassRefactoringListener(JavaParserLabeledListener): """ To implement inline class refactoring based on its actors. Creates a new class and move fields and methods from two old class to the new one, then delete the two class """ def __init__( self, common_token_stream: CommonTokenStream = None, source_class: str = None, source_class_data: dict = None, target_class: str = None, target_class_data: dict = None, is_complete: bool = False): """ """ if common_token_stream is None: raise ValueError('common_token_stream is None') else: self.token_stream_rewriter = TokenStreamRewriter(common_token_stream) if source_class is None: raise ValueError("source_class is None") else: self.source_class = source_class if target_class is None: raise ValueError("new_class is None") else: self.target_class = target_class if target_class: self.target_class = target_class if source_class_data: self.source_class_data = source_class_data else: self.source_class_data = {'fields': [], 'methods': [], 'constructors': []} if target_class_data: self.target_class_data = target_class_data else: self.target_class_data = {'fields': [], 'methods': [], 'constructors': []} self.field_that_has_source = [] self.has_source_new = False self.is_complete = is_complete self.is_target_class = False self.is_source_class = False self.detected_field = None self.detected_method = None self.TAB = "\t" self.NEW_LINE = "\n" self.code = "" def enterClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext): class_identifier = ctx.IDENTIFIER().getText() if class_identifier == self.source_class: self.is_source_class = True self.is_target_class = False elif class_identifier == self.target_class: self.is_target_class = True self.is_source_class = False else: self.is_target_class = False self.is_source_class = False def exitClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext): if self.is_target_class and (self.source_class_data['fields'] or self.source_class_data['constructors'] or self.source_class_data['methods']): if not self.is_complete: final_fields = merge_fields(self.source_class_data['fields'], self.target_class_data['fields'], self.target_class) final_constructors = merge_constructors(self.source_class_data['constructors'], self.target_class_data['constructors']) final_methods = merge_methods(self.source_class_data['methods'], self.target_class_data['methods']) text = '\t' for field in final_fields: text += field.text + '\n' for constructor in final_constructors: text += constructor.text + '\n' for method in final_methods: text += method.text + '\n' self.token_stream_rewriter.insertBeforeIndex( index=ctx.stop.tokenIndex, text=text ) self.is_complete = True else: self.is_target_class = False elif self.is_source_class: if ctx.parentCtx.classOrInterfaceModifier(0) is None: return self.is_source_class = False self.token_stream_rewriter.delete( program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME, from_idx=ctx.parentCtx.classOrInterfaceModifier(0).start.tokenIndex, to_idx=ctx.stop.tokenIndex ) def enterClassBody(self, ctx: JavaParserLabeled.ClassBodyContext): if self.is_source_class: self.code += self.token_stream_rewriter.getText( program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME, start=ctx.start.tokenIndex + 1, stop=ctx.stop.tokenIndex - 1 ) self.token_stream_rewriter.delete( program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME, from_idx=ctx.parentCtx.start.tokenIndex, to_idx=ctx.parentCtx.stop.tokenIndex ) else: return None def enterFieldDeclaration(self, ctx: 
JavaParserLabeled.FieldDeclarationContext): if self.is_source_class or self.is_target_class: field_text = '' for child in ctx.children: if child.getText() == ';': field_text = field_text[:len(field_text) - 1] + ';' break field_text += child.getText() + ' ' name = ctx.variableDeclarators().variableDeclarator(0).variableDeclaratorId().IDENTIFIER().getText() if ctx.typeType().classOrInterfaceType() is not None and \ ctx.typeType().classOrInterfaceType().getText() == self.source_class: self.field_that_has_source.append(name) return modifier_text = '' for modifier in ctx.parentCtx.parentCtx.modifier(): modifier_text += modifier.getText() + ' ' field_text = modifier_text + field_text if self.is_source_class: self.source_class_data['fields'].append(Field(name=name, text=field_text)) else: self.target_class_data['fields'].append(Field(name=name, text=field_text)) def exitFieldDeclaration(self, ctx: JavaParserLabeled.FieldDeclarationContext): if self.is_target_class: if ctx.typeType().classOrInterfaceType().getText() == self.source_class: grand_parent_ctx = ctx.parentCtx.parentCtx self.token_stream_rewriter.delete( program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME, from_idx=grand_parent_ctx.start.tokenIndex, to_idx=grand_parent_ctx.stop.tokenIndex) def enterConstructorDeclaration(self, ctx: JavaParserLabeled.ConstructorDeclarationContext): if self.is_source_class or self.is_target_class: if ctx.formalParameters().formalParameterList(): constructor_parameters = [ctx.formalParameters().formalParameterList().children[i] for i in range(len(ctx.formalParameters().formalParameterList().children)) if i % 2 == 0] else: constructor_parameters = [] constructor_text = '' for modifier in ctx.parentCtx.parentCtx.modifier(): constructor_text += modifier.getText() + ' ' if self.is_source_class: constructor_text += self.target_class else: constructor_text += ctx.IDENTIFIER().getText() constructor_text += ' ( ' for parameter in constructor_parameters: constructor_text += parameter.typeType().getText() + ' ' constructor_text += parameter.variableDeclaratorId().getText() + ', ' if constructor_parameters: constructor_text = constructor_text[:len(constructor_text) - 2] constructor_text += ')\n\t{' constructor_text += self.token_stream_rewriter.getText( program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME, start=ctx.block().start.tokenIndex + 1, stop=ctx.block().stop.tokenIndex - 1 ) constructor_text += '}\n' if self.is_source_class: self.source_class_data['constructors'].append(ConstructorOrMethod( name=self.target_class, parameters=[Parameter(parameter_type=p.typeType().getText(), name=p.variableDeclaratorId().IDENTIFIER().getText()) for p in constructor_parameters], text=constructor_text, constructor_body=self.token_stream_rewriter.getText( program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME, start=ctx.block().start.tokenIndex + 1, stop=ctx.block().stop.tokenIndex - 1 ))) else: self.target_class_data['constructors'].append(ConstructorOrMethod( name=self.target_class, parameters=[Parameter(parameter_type=p.typeType().getText(), name=p.variableDeclaratorId().IDENTIFIER().getText()) for p in constructor_parameters], text=constructor_text, constructor_body=self.token_stream_rewriter.getText( program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME, start=ctx.block().start.tokenIndex + 1, stop=ctx.block().stop.tokenIndex - 1 ))) proper_constructor = get_proper_constructor(self.target_class_data['constructors'][-1], self.source_class_data['constructors']) if proper_constructor is None: 
return self.token_stream_rewriter.insertBeforeIndex( index=ctx.stop.tokenIndex, text=proper_constructor.constructorBody ) def enterMethodDeclaration(self, ctx: JavaParserLabeled.MethodDeclarationContext): if self.is_source_class or self.is_target_class: if ctx.formalParameters().formalParameterList(): method_parameters = [ctx.formalParameters().formalParameterList().children[i] for i in range(len(ctx.formalParameters().formalParameterList().children)) if i % 2 == 0] else: method_parameters = [] method_text = '' for modifier in ctx.parentCtx.parentCtx.modifier(): method_text += modifier.getText() + ' ' type_text = ctx.typeTypeOrVoid().getText() if type_text == self.source_class: type_text = self.target_class if self.is_target_class: self.token_stream_rewriter.replace( program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME, from_idx=ctx.typeTypeOrVoid().start.tokenIndex, to_idx=ctx.typeTypeOrVoid().stop.tokenIndex, text=type_text ) method_text += type_text + ' ' + ctx.IDENTIFIER().getText() method_text += ' ( ' for parameter in method_parameters: method_text += parameter.typeType().getText() + ' ' method_text += parameter.variableDeclaratorId().getText() + ', ' if method_parameters: method_text = method_text[:len(method_text) - 2] method_text += ')\n\t{' method_text += self.token_stream_rewriter.getText( program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME, start=ctx.methodBody().start.tokenIndex + 1, stop=ctx.methodBody().stop.tokenIndex - 1 ) method_text += '}\n' if self.is_source_class: self.source_class_data['methods'].append(ConstructorOrMethod( name=ctx.IDENTIFIER().getText(), parameters=[Parameter( parameter_type=p.typeType().getText(), name=p.variableDeclaratorId().IDENTIFIER().getText()) for p in method_parameters], text=method_text)) else: self.target_class_data['methods'].append(ConstructorOrMethod( name=ctx.IDENTIFIER().getText(), parameters=[Parameter( parameter_type=p.typeType().getText(), name=p.variableDeclaratorId().IDENTIFIER().getText()) for p in method_parameters], text=method_text)) def enterExpression1(self, ctx: JavaParserLabeled.Expression1Context): if ctx.IDENTIFIER() is None and ctx.IDENTIFIER().getText() in self.field_that_has_source: field_text = ctx.expression().getText() self.token_stream_rewriter.replace( program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME, from_idx=ctx.start.tokenIndex, to_idx=ctx.stop.tokenIndex, text=field_text ) def exitExpression21(self, ctx: JavaParserLabeled.Expression21Context): if self.has_source_new: self.has_source_new = False self.token_stream_rewriter.delete( program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME, from_idx=ctx.start.tokenIndex, to_idx=ctx.stop.tokenIndex + 1 ) def enterExpression4(self, ctx: JavaParserLabeled.Expression4Context): if ctx.children[-1].children[0].getText() == self.source_class: self.has_source_new = True def enterCreatedName0(self, ctx: JavaParserLabeled.CreatedName0Context): if ctx.IDENTIFIER(0).getText() == self.source_class and self.target_class: self.token_stream_rewriter.replaceIndex( index=ctx.start.tokenIndex, text=self.target_class ) def enterCreatedName1(self, ctx: JavaParserLabeled.CreatedName1Context): if ctx.getText() == self.source_class and self.target_class: self.token_stream_rewriter.replaceIndex( index=ctx.start.tokenIndex, text=self.target_class ) def enterFormalParameter(self, ctx: JavaParserLabeled.FormalParameterContext): class_type = ctx.typeType().classOrInterfaceType() if class_type: if class_type.IDENTIFIER(0).getText() == self.source_class and 
self.target_class: self.token_stream_rewriter.replaceIndex( index=class_type.start.tokenIndex, text=self.target_class ) def enterQualifiedName(self, ctx: JavaParserLabeled.QualifiedNameContext): if ctx.IDENTIFIER(0).getText() == self.source_class and self.target_class: self.token_stream_rewriter.replaceIndex( index=ctx.start.tokenIndex, text=self.target_class ) def exitExpression0(self, ctx: JavaParserLabeled.Expression0Context): if ctx.primary().getText() == self.source_class and self.target_class: self.token_stream_rewriter.replaceIndex( index=ctx.start.tokenIndex, text=self.target_class ) def enterLocalVariableDeclaration(self, ctx: JavaParserLabeled.LocalVariableDeclarationContext): if ctx.typeType().classOrInterfaceType(): if ctx.typeType().classOrInterfaceType().getText() == self.source_class and self.target_class: self.token_stream_rewriter.replace( program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME, from_idx=ctx.typeType().start.tokenIndex, to_idx=ctx.typeType().stop.tokenIndex, text=self.target_class )