class PropagateFieldUsageListener(JavaParserLabeledListener):
    """Route qualified accesses of one field through a named object.

    Every member access ``<expr>.<field_name>`` visited by the walker (an
    ``Expression1`` node whose member is an identifier equal to ``field_name``)
    is rewritten to ``<expr>.<object_name>.<field_name>`` in the rewriter's
    default program.
    """

    def __init__(self, common_token_stream: CommonTokenStream, object_name: str, field_name: str):
        # Pending edits are recorded here; the caller reads the rewritten text.
        self.token_stream_rewriter = TokenStreamRewriter(common_token_stream)
        self.field_name = field_name
        self.object_name = object_name

    def enterExpression1(self, ctx: JavaParserLabeled.Expression1Context):
        member = ctx.IDENTIFIER()
        # Skip nodes whose member is not a plain identifier, or is a different name.
        if member is None or member.getText() != self.field_name:
            return
        rewriter = self.token_stream_rewriter
        # ctx.stop is the member identifier token; prefix it with the object name.
        rewriter.insertBeforeToken(
            token=ctx.stop,
            text=self.object_name + ".",
            program_name=rewriter.DEFAULT_PROGRAM_NAME,
        )
class NewClassPropagation(JavaParserLabeledListener):
    """Make methods in the extracted class reach back into the source class.

    For every method whose name has a non-empty entry in ``method_map``:

    * a ``<source_class> ref`` parameter is appended to its signature, and
    * uses of the listed fields that were *not* moved (bare ``x`` or
      ``this.x``) are redirected through ``ref``.

    ``method_map`` maps a method name to an iterable of field names
    (presumably the source-class fields that method accesses; confirm with the
    caller). It is shared with other listeners, so it is never mutated here.
    """

    def __init__(self, common_token_stream: CommonTokenStream, method_map: dict, source_class: str,
                 moved_fields: list):
        self.token_stream_rewriter = TokenStreamRewriter(common_token_stream)
        self.method_map = method_map
        self.source_class = source_class
        self.moved_fields = moved_fields
        # Field names needing `ref` inside the method currently being walked, or None.
        self.fields = None

    def enterMethodDeclaration(self, ctx: JavaParserLabeled.MethodDeclarationContext):
        used_fields = self.method_map.get(ctx.IDENTIFIER().getText())
        # Work on a private copy so the caller's method_map is left untouched.
        self.fields = set(used_fields) if used_fields else None
        if not self.fields:
            return
        parameters = ctx.formalParameters()
        if parameters.getText() == "()":
            text = f"{self.source_class} ref"
        else:
            text = f", {self.source_class} ref"
        # formalParameters().stop is the closing ')' of the parameter list.
        self.token_stream_rewriter.insertBeforeToken(
            token=parameters.stop,
            text=text,
            program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME
        )

    def exitMethodDeclaration(self, ctx: JavaParserLabeled.MethodDeclarationContext):
        self.fields = None

    def _needs_ref(self, name: str) -> bool:
        """Return True if ``name`` is a field to reach via ``ref`` in the current method."""
        return bool(self.fields) and name in self.fields and name not in self.moved_fields

    def enterExpression1(self, ctx: JavaParserLabeled.Expression1Context):
        if not self.fields or ctx.expression().getText() != "this":
            return
        member = ctx.IDENTIFIER()
        # Compare the exact member name; a substring test on the whole
        # expression text would match e.g. field `s` inside `this.count`.
        if member is not None and self._needs_ref(member.getText()):
            self.token_stream_rewriter.replaceSingleToken(
                token=ctx.expression().primary().start,
                text="ref"
            )

    def enterPrimary4(self, ctx: JavaParserLabeled.Primary4Context):
        # Every occurrence is rewritten. Consuming the name after its first
        # use would leave later uses unresolved in the new class.
        if self._needs_ref(ctx.getText()):
            self.token_stream_rewriter.insertBeforeToken(
                token=ctx.start,
                text="ref."
            )
class ExtractClassRefactoringListener(JavaParserLabeledListener):
    """Implement the Extract Class refactoring.

    While walking a compilation unit this listener:

    * builds the source text of ``new_class`` in ``self.code`` (package line,
      imports, the moved fields and copies of the moved methods);
    * removes moved fields from ``source_class``;
    * replaces the body of each moved method in ``source_class`` with a call
      delegating to ``this.<object_name>.<method>(...)``;
    * adds ``public <new_class> <object_name> = new <new_class>();`` to
      ``source_class``;
    * makes ``private`` source-class fields used by moved methods ``public``;
    * prefixes qualified accesses of moved fields with ``<object_name>.``.

    :param common_token_stream: token stream of the parsed file (required).
    :param source_class: name of the class to extract from (required).
    :param new_class: name of the class to create (required).
    :param moved_fields: names of fields to move (default: none).
    :param moved_methods: names of methods to move (default: none).
    :param method_map: method name -> iterable of field names, presumably the
        source-class fields the method accesses (default: empty).
    :raises ValueError: if a required argument is None.
    """

    def __init__(self, common_token_stream: CommonTokenStream = None, source_class: str = None,
                 new_class: str = None, moved_fields=None, moved_methods=None,
                 method_map: dict = None):
        self.method_map = {} if method_map is None else method_map
        self.moved_methods = [] if moved_methods is None else moved_methods
        self.moved_fields = [] if moved_fields is None else moved_fields

        if common_token_stream is None:
            raise ValueError("common_token_stream is None")
        self.token_stream_rewriter = TokenStreamRewriter(common_token_stream)

        if source_class is None:
            raise ValueError("source_class is None")
        self.source_class = source_class

        if new_class is None:
            raise ValueError("new_class is None")
        self.new_class = new_class

        self.is_source_class = False
        # Names of the class declarations currently open, innermost last.
        self._class_stack = []
        self.detected_field = None
        self.detected_method = None
        self.TAB = "\t"
        self.NEW_LINE = "\n"
        self.code = ""
        self.package_name = ""
        self.parameters = []
        # Instance-field name for the new class: new_class with a lower-case first letter.
        self.object_name = self.new_class[:1].lower() + self.new_class[1:]
        self.modifiers = ""
        self.do_increase_visibility = False
        # Fields that moved methods use must be reachable from the new class.
        # A method missing from method_map contributes nothing.
        self.fields_to_increase_visibility = set()
        for method in self.moved_methods:
            self.fields_to_increase_visibility.update(self.method_map.get(method) or ())

    def _text_of(self, node) -> str:
        """Return the current (possibly already rewritten) source text spanned by ``node``."""
        return self.token_stream_rewriter.getText(
            program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME,
            start=node.start.tokenIndex,
            stop=node.stop.tokenIndex)

    def enterPackageDeclaration(self, ctx: JavaParserLabeled.PackageDeclarationContext):
        if ctx.qualifiedName() and not self.package_name:
            self.package_name = ctx.qualifiedName().getText()
            self.code += f"package {self.package_name};{self.NEW_LINE}"

    def enterImportDeclaration(self, ctx: JavaParserLabeled.ImportDeclarationContext):
        # Copy every import verbatim so the new class compiles with the same names.
        self.code += f"\n{self._text_of(ctx)}\n"

    def enterClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext):
        class_identifier = ctx.IDENTIFIER().getText()
        self._class_stack.append(class_identifier)
        self.is_source_class = class_identifier == self.source_class
        if self.is_source_class:
            self.code += self.NEW_LINE * 2
            self.code += f"// New class({self.new_class}) generated by CodART" + self.NEW_LINE
            self.code += f"class {self.new_class}{self.NEW_LINE}" + "{" + self.NEW_LINE

    def enterClassBody(self, ctx: JavaParserLabeled.ClassBodyContext):
        # Anonymous classes also have a classBody. Only the body that belongs
        # to the source class declaration itself receives the new field.
        if self.is_source_class and isinstance(ctx.parentCtx, JavaParserLabeled.ClassDeclarationContext):
            self.token_stream_rewriter.insertAfterToken(
                token=ctx.start,
                text="\n\t" + f"public {self.new_class} {self.object_name} = new {self.new_class}();",
                program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME)

    def exitClassDeclaration(self, ctx: JavaParserLabeled.ClassDeclarationContext):
        if self._class_stack:
            self._class_stack.pop()
        if ctx.IDENTIFIER().getText() == self.source_class:
            self.code += "}"
        # Resume the enclosing class's state. Leaving an unrelated class must
        # not flag what follows as belonging to the source class.
        self.is_source_class = bool(self._class_stack) and self._class_stack[-1] == self.source_class

    def exitCompilationUnit(self, ctx: JavaParserLabeled.CompilationUnitContext):
        pass

    def enterVariableDeclaratorId(self, ctx: JavaParserLabeled.VariableDeclaratorIdContext):
        # Kept for compatibility. exitFieldDeclaration now reads the declarators directly.
        if not self.is_source_class:
            return None
        field_identifier = ctx.IDENTIFIER().getText()
        if field_identifier in self.moved_fields:
            self.detected_field = field_identifier

    def enterFieldDeclaration(self, ctx: JavaParserLabeled.FieldDeclarationContext):
        if not self.is_source_class:
            return None
        # Compare declared names only. getText() on the declarators would
        # include initializers (`x=5`) and split badly on commas inside them.
        names = [declarator.variableDeclaratorId().IDENTIFIER().getText()
                 for declarator in ctx.variableDeclarators().variableDeclarator()]
        if not any(name in self.fields_to_increase_visibility for name in names):
            return None
        # The grandparent (class body declaration) holds the field's modifiers.
        for modifier in ctx.parentCtx.parentCtx.modifier():
            if modifier.getText() == "private":
                self.token_stream_rewriter.replaceSingleToken(token=modifier.start, text="public")

    def exitFieldDeclaration(self, ctx: JavaParserLabeled.FieldDeclarationContext):
        if not self.is_source_class:
            return None
        declarators_ctx = ctx.variableDeclarators()
        moved, kept = [], []
        for declarator in declarators_ctx.variableDeclarator():
            name = declarator.variableDeclaratorId().IDENTIFIER().getText()
            (moved if name in self.moved_fields else kept).append(declarator)
        if not moved:
            self.detected_field = None
            return None

        # Use source text (not getText()) so whitespace in types and
        # initializers survives, e.g. `List<? extends T>` or `new Foo()`.
        field_type = self._text_of(ctx.typeType())
        for declarator in moved:
            self.code += f"public {field_type} {self._text_of(declarator)};{self.NEW_LINE}"

        if kept:
            # Drop only the moved declarators; modifiers and the kept fields'
            # initializers stay as written.
            self.token_stream_rewriter.replaceRange(
                from_idx=declarators_ctx.start.tokenIndex,
                to_idx=declarators_ctx.stop.tokenIndex,
                text=", ".join(self._text_of(declarator) for declarator in kept))
        else:
            # Every declarator moved: remove the whole declaration, modifiers included.
            grand_parent_ctx = ctx.parentCtx.parentCtx
            self.token_stream_rewriter.delete(
                program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME,
                from_idx=grand_parent_ctx.start.tokenIndex,
                to_idx=grand_parent_ctx.stop.tokenIndex)
        self.detected_field = None

    def enterMethodDeclaration(self, ctx: JavaParserLabeled.MethodDeclarationContext):
        if not self.is_source_class:
            return None
        method_identifier = ctx.IDENTIFIER().getText()
        if method_identifier in self.moved_methods:
            self.detected_method = method_identifier

    def _is_detected_method_parameter(self, ctx) -> bool:
        """Return True if ``ctx`` is a parameter of the moved method itself.

        Lambdas and methods of nested or anonymous classes inside the moved
        method also declare parameters. Those must not become arguments of the
        delegating call.
        """
        parameter_list = ctx.parentCtx
        parameters = parameter_list.parentCtx if parameter_list is not None else None
        owner = parameters.parentCtx if parameters is not None else None
        return (isinstance(owner, JavaParserLabeled.MethodDeclarationContext)
                and owner.IDENTIFIER().getText() == self.detected_method)

    def enterFormalParameter(self, ctx: JavaParserLabeled.FormalParameterContext):
        if self.detected_method and self._is_detected_method_parameter(ctx):
            self.parameters.append(ctx.variableDeclaratorId().IDENTIFIER().getText())

    def enterLastFormalParameter(self, ctx):
        # Varargs parameter (`T... xs`), which must also be forwarded.
        if self.detected_method and self._is_detected_method_parameter(ctx):
            self.parameters.append(ctx.variableDeclaratorId().IDENTIFIER().getText())

    def exitMethodDeclaration(self, ctx: JavaParserLabeled.MethodDeclarationContext):
        if not self.is_source_class:
            return None
        method_identifier = ctx.IDENTIFIER().getText()
        if self.detected_method != method_identifier:
            return None

        # Copy the (unmodified) method into the new class.
        self.code += self.NEW_LINE + ("public " + self._text_of(ctx) + self.NEW_LINE)

        # Methods that use source-class fields receive the source object as an extra argument.
        if self.method_map.get(method_identifier):
            self.parameters.append("this")
        call = f"this.{self.object_name}.{self.detected_method}(" + ",".join(self.parameters) + ");"
        # `return <void call>;` is not valid Java, so void methods just call.
        if ctx.typeTypeOrVoid().getText() != "void":
            call = "return " + call
        # Replace the body in the source class with the delegating call.
        self.token_stream_rewriter.replaceRange(
            from_idx=ctx.methodBody().start.tokenIndex,
            to_idx=ctx.stop.tokenIndex,
            text="{" + f"\n{call}\n" + "}")
        self.parameters = []
        self.detected_method = None

    def enterExpression1(self, ctx: JavaParserLabeled.Expression1Context):
        identifier = ctx.IDENTIFIER()
        if identifier is None:
            return
        # Redirect `expr.movedField` through the new object, except inside
        # moved methods, whose text is copied into the new class as-is.
        if identifier.getText() in self.moved_fields and self.detected_method not in self.moved_methods:
            self.token_stream_rewriter.insertBeforeToken(
                token=ctx.stop,
                text=self.object_name + ".",
                program_name=self.token_stream_rewriter.DEFAULT_PROGRAM_NAME)