def do_refactor(self):
    """Extract an interface declaring ``self.method_keys`` from ``self.class_names``.

    All source edits are buffered in a ``symbol_table.Rewriter`` and are only
    written by ``rewriter.apply()`` at the end; the interface file itself is
    written just before that.

    Steps:
      1. Check that ``self.package_name``, every class in ``self.class_names``
         and every selected method in each of those classes exist.
      2. For every class: add ``implements <interface>`` (or append to an
         existing ``implements`` list), prefix each selected method with
         ``@Override`` / ``public`` when missing and remove ``private`` /
         ``protected`` modifiers.  The selected methods must have the same
         return type and parameter types in every class.
      3. In every class of the program, retype private fields and local
         variables declared with one of ``self.class_names`` to the
         interface, unless a non-interface method is invoked on them (only
         the ``x.m()`` and ``this.x.m()`` forms are analysed).
      4. Write the interface to ``self.interface_filename``.

    Returns:
        True if the refactoring was applied, False if a precondition failed
        (no file is written in that case).
    """
    program = symbol_table.get_program(self.source_filenames, print_status=True)

    # --- Preconditions ------------------------------------------------------
    if self.package_name not in program.packages:
        return False
    target_classes = program.packages[self.package_name].classes
    if any(class_name not in target_classes for class_name in self.class_names):
        return False
    if any(method_key not in target_classes[class_name].methods
           for class_name in self.class_names
           for method_key in self.method_keys):
        return False

    method_returntypes = {}
    method_parameters = {}
    # Simple method names: the part of each key before its parameter list.
    method_names = [method_key[:method_key.find('(')] for method_key in self.method_keys]

    def is_target_type(datatype, file_info):
        """Return True if ``datatype`` refers to one of ``self.class_names``.

        The extracted classes (and the new interface) live in
        ``self.package_name``, so that package is used for both the import
        check and the fully-qualified comparison.
        """
        for target in self.class_names:
            if datatype == target and file_info.has_imported_class(self.package_name, target):
                return True
            if self.package_name is not None and datatype == self.package_name + '.' + target:
                return True
        return False

    def interface_typename(file_info):
        """Return how the interface should be spelled in ``file_info``'s file.

        Returns None when the target package is the default (None) package
        and the file has not imported it, in which case no retyping is done.
        """
        if file_info.has_imported_package(self.package_name):
            return self.interface_name
        if self.package_name is None:
            return None
        return self.package_name + '.' + self.interface_name

    rewriter = symbol_table.Rewriter(program, self.filename_mapping)

    # --- Make each class implement the interface ---------------------------
    for class_name in self.class_names:
        c: symbol_table.Class = target_classes[class_name]

        # Find the token after which ` implements X` / `, X` is inserted.
        has_superinterface = False
        if c.parser_context.IMPLEMENTS() is not None:
            t = symbol_table.TokensInfo(c.parser_context.typeList())
            has_superinterface = True
        elif c.parser_context.EXTENDS() is not None:
            t = symbol_table.TokensInfo(c.parser_context.typeType())
        elif c.parser_context.typeParameters() is not None:
            t = symbol_table.TokensInfo(c.parser_context.typeParameters())
        else:
            # No header clauses: insert right after the class identifier.
            t = symbol_table.TokensInfo(c.parser_context)
            t.stop = c.parser_context.IDENTIFIER().getSymbol().tokenIndex
        rewriter.insert_after(t, (", " if has_superinterface else " implements ") + self.interface_name)

        for method_key in self.method_keys:
            m: symbol_table.Method = c.methods[method_key]

            # Every class must agree on the return type and parameter types.
            if method_key in method_returntypes:
                if method_returntypes[method_key] != m.returntype:
                    return False
                known_params = method_parameters[method_key]
                if len(known_params) != len(m.parameters):
                    return False
                if any(known[0] != param[0] for known, param in zip(known_params, m.parameters)):
                    return False
            else:
                method_returntypes[method_key] = m.returntype
                method_parameters[method_key] = m.parameters

            # Interface methods are public and override the interface declaration.
            if len(m.modifiers_parser_contexts) > 0:
                t = symbol_table.TokensInfo(m.modifiers_parser_contexts[0])
            else:
                t = m.get_tokens_info()
            rewriter.insert_before_start(
                t,
                ("" if "@Override" in m.modifiers else "@Override\n ")
                + ("" if "public" in m.modifiers else "public ")
            )
            # Assumes m.modifiers and m.modifiers_parser_contexts are parallel lists.
            for i, modifier in enumerate(m.modifiers):
                if modifier == "private" or modifier == "protected":
                    rewriter.replace(symbol_table.TokensInfo(m.modifiers_parser_contexts[i]), "")

    # --- Retype fields / locals that only use interface methods ------------
    for scanned_package_name in program.packages:
        p: symbol_table.Package = program.packages[scanned_package_name]
        for scanned_class_name in p.classes:
            c: symbol_table.Class = p.classes[scanned_class_name]

            # Private fields declared with one of the extracted classes.
            fields_of_interest = {}
            for fn in c.fields:
                f: symbol_table.Field = c.fields[fn]
                if is_target_type(f.datatype, f.file_info) and "private" in f.modifiers:
                    fields_of_interest[f.name] = f

            for method_key in c.methods:
                m: symbol_table.Method = c.methods[method_key]
                vars_of_interest = {}
                for item in m.body_local_vars_and_expr_names:
                    if isinstance(item, symbol_table.LocalVariable):
                        if is_target_type(item.datatype, c.file_info):
                            vars_of_interest[item.identifier] = item
                    if isinstance(item, symbol_table.MethodInvocation):
                        ids = item.dot_separated_identifiers
                        if len(ids) == 2 or (len(ids) == 3 and ids[0] == "this"):
                            receiver, invoked = ids[-2], ids[-1]
                            # Calling a non-interface method disqualifies the receiver.
                            if receiver in vars_of_interest:
                                if invoked not in method_names:
                                    vars_of_interest.pop(receiver)
                            elif receiver in fields_of_interest and invoked not in method_names:
                                fields_of_interest.pop(receiver)

                if vars_of_interest:
                    typename = interface_typename(m.file_info)
                    if typename is not None:
                        for var in vars_of_interest.values():
                            rewriter.replace(symbol_table.TokensInfo(var.parser_context.typeType()), typename)

            if not fields_of_interest:
                continue
            typename = interface_typename(c.file_info)
            if typename is None:
                continue
            for f in fields_of_interest.values():
                if len(f.neighbor_names) == 0:
                    rewriter.replace(symbol_table.TokensInfo(f.parser_context.typeType()), typename)
                elif not any(nn in fields_of_interest for nn in f.neighbor_names):
                    # The field shares a declaration with fields that keep the
                    # class type: cut its declarator out (with one ',') and
                    # re-declare it separately with the interface type.
                    # NOTE: when some neighbours are also of interest, nothing is changed.
                    i = f.index_in_variable_declarators
                    declarators = f.all_variable_declarator_contexts
                    t = symbol_table.TokensInfo(declarators[i])
                    if i == 0:
                        # Extend to just before the next declarator: takes the ',' after it.
                        t.stop = symbol_table.TokensInfo(declarators[i + 1]).start - 1
                    else:
                        # Start just after the previous declarator: takes the ',' before it.
                        t.start = symbol_table.TokensInfo(declarators[i - 1]).stop + 1
                    rewriter.replace(t, "")
                    rewriter.insert_after(
                        f.get_tokens_info(),
                        "\n private " + typename + " " + f.name
                        + (" = " + f.initializer + ";" if f.initializer is not None else ";")
                    )

    # --- Create the interface file ------------------------------------------
    # The interface belongs to the target package (never a leaked loop variable);
    # the default package gets no package declaration.
    header = "" if self.package_name is None else "package " + self.package_name + ";\n\n"
    interface_file_content = header + "public interface " + self.interface_name + "\n" + "{\n"
    for method_key, method_name in zip(self.method_keys, method_names):
        params = ", ".join(param[0] + " " + param[1] for param in method_parameters[method_key])
        interface_file_content += " " + method_returntypes[method_key] + " " + method_name + "(" + params + ");\n"
    interface_file_content += "}\n"

    interface_dir = os.path.dirname(self.interface_filename)
    if interface_dir:
        os.makedirs(interface_dir, exist_ok=True)
    with open(self.interface_filename, "w+", encoding='utf8', errors='ignore') as file:
        file.write(interface_file_content)

    rewriter.apply()
    return True
def do_refactor(self):
    """Push the field ``self.field_name`` down from ``self.superclass_name``.

    The field declaration is removed from the superclass and re-declared in
    the selected direct subclasses of ``self.superclass_name``: those named in
    ``self.class_names``, or all direct subclasses when that list is empty.

    The refactoring is aborted (returns False, nothing is written) when:
      * the package or the superclass cannot be found, or
        ``self.pre_condition_check`` rejects the superclass;
      * a selected subclass already declares a field with that name;
      * there is no selected subclass to receive the field;
      * the field is accessed as ``x.<field>`` / ``this.x.<field>`` where ``x``
        is a local variable or field typed as the superclass or as a
        non-selected subclass (such accesses would break).

    Only a ``public`` or ``protected`` modifier is carried over to the new
    declarations; other modifiers of the original field are not.

    Returns:
        True if the refactoring was applied, False otherwise.
    """
    program = symbol_table.get_program(self.source_filenames, print_status=False)
    if (self.package_name not in program.packages
            or self.superclass_name not in program.packages[self.package_name].classes):
        print(f"Cannot push-down field from {self.superclass_name}")
        return False
    superclass: symbol_table.Class = program.packages[self.package_name].classes[self.superclass_name]
    if not self.pre_condition_check(program, superclass):
        print(f"Cannot push-down field from {superclass.name}")
        return False

    # Fully-qualified superclass name; None for the default package.
    qualified_superclass = (
        None if self.package_name is None else self.package_name + '.' + self.superclass_name
    )

    # --- Split the direct subclasses into receivers and the others ---------
    other_derived_classes = []
    classes_to_add_to = []
    for pn in program.packages:
        p: symbol_table.Package = program.packages[pn]
        for cn in p.classes:
            c: symbol_table.Class = p.classes[cn]
            is_direct_subclass = (
                (c.superclass_name == self.superclass_name
                 and c.file_info.has_imported_class(self.package_name, self.superclass_name))
                or (qualified_superclass is not None and c.superclass_name == qualified_superclass)
            )
            if not is_direct_subclass:
                continue
            if not self.class_names or cn in self.class_names:
                if self.field_name in c.fields:
                    print("some classes have same variable")
                    return False
                classes_to_add_to.append(c)
            else:
                other_derived_classes.append(c)

    if not classes_to_add_to:
        # Removing the field without re-declaring it anywhere would just delete it.
        print(f"Cannot push-down field from {superclass.name}: no subclass to push it into")
        return False

    def is_superclass_type(datatype, has_imported_superclass):
        """True if ``datatype`` names the superclass (simple or qualified)."""
        return ((datatype == self.superclass_name and has_imported_superclass)
                or (qualified_superclass is not None and datatype == qualified_superclass))

    def is_other_derived_type(datatype, file_info):
        """True if ``datatype`` names a subclass that will not receive the field."""
        return any(
            (file_info.has_imported_class(o.package_name, o.name) and datatype == o.name)
            or (o.package_name is not None and datatype == o.package_name + '.' + o.name)
            for o in other_derived_classes
        )

    # --- Reject if the field is reached through an unsuitable reference ----
    for pn in program.packages:
        p: symbol_table.Package = program.packages[pn]
        for cn in p.classes:
            c: symbol_table.Class = p.classes[cn]
            has_imported_superclass = c.file_info.has_imported_class(
                self.package_name, self.superclass_name)

            risky_fields = set()
            for fn in c.fields:
                f: symbol_table.Field = c.fields[fn]
                if is_superclass_type(f.datatype, has_imported_superclass) \
                        or is_other_derived_type(f.datatype, c.file_info):
                    risky_fields.add(f.name)

            for mk in c.methods:
                m: symbol_table.Method = c.methods[mk]
                risky_locals = set()
                for item in m.body_local_vars_and_expr_names:
                    if isinstance(item, symbol_table.LocalVariable):
                        if is_superclass_type(item.datatype, has_imported_superclass) \
                                or is_other_derived_type(item.datatype, c.file_info):
                            risky_locals.add(item.identifier)
                    elif isinstance(item, symbol_table.ExpressionName):
                        ids = item.dot_separated_identifiers
                        if ids[-1] != self.field_name:
                            continue
                        # `x.field` through a risky local or field.
                        if len(ids) == 2 and (ids[0] in risky_locals or ids[0] in risky_fields):
                            return False
                        # `this.x.field` through a risky field.
                        if len(ids) == 3 and ids[0] == "this" and ids[1] in risky_fields:
                            return False

    # --- Rewrite -------------------------------------------------------------
    rewriter = symbol_table.Rewriter(program, self.filename_mapping)

    field = superclass.fields[self.field_name]
    if len(field.neighbor_names) == 0:
        rewriter.replace(field.get_tokens_info(), "")
        # The modifiers are separate parse contexts in this grammar; remove them too.
        for mod_ctx in field.modifiers_parser_contexts:
            rewriter.replace(symbol_table.TokensInfo(mod_ctx), "")
    else:
        # Remove only this declarator from a multi-variable declaration.
        i = field.index_in_variable_declarators
        var_ctxs = field.all_variable_declarator_contexts
        to_remove = symbol_table.TokensInfo(var_ctxs[i])
        if i == 0:
            to_remove.stop = symbol_table.TokensInfo(var_ctxs[i + 1]).start - 1  # include the ',' after it
        else:
            to_remove.start = symbol_table.TokensInfo(var_ctxs[i - 1]).stop + 1  # include the ',' before it
        rewriter.replace(to_remove, "")

    if "public" in field.modifiers:
        modifier = "public "
    elif "protected" in field.modifiers:
        modifier = "protected "
    else:
        modifier = ""
    declaration = ("\n " + modifier + field.datatype + " " + self.field_name
                   + ((" = " + field.initializer) if field.initializer is not None else "")
                   + ";")
    for c in classes_to_add_to:
        c_body_start = symbol_table.TokensInfo(c.parser_context.classBody())
        c_body_start.stop = c_body_start.start  # start and stop both point to the '{'
        rewriter.insert_after(c_body_start, declaration)

    rewriter.apply()
    return True
def do_refactor(self):
    """Pull the field ``self.field_name`` up from ``self.class_name`` into its superclass.

    Every direct subclass of that superclass declaring a field with the same
    name and datatype has its declaration removed, and one declaration is
    added at the top of the superclass body.  The new field is ``public`` if
    any removed field was public, ``protected`` if all were protected or
    private, and package-private otherwise.  When a removed field had an
    initializer, the assignment is moved into every constructor of its class,
    or into a newly added no-argument constructor when the class has none.

    Only a superclass located in ``self.package_name`` is supported.

    Returns:
        True if the refactoring was applied, False if validation failed (an
        error is logged and nothing is written).
    """
    import re

    program = symbol_table.get_program(self.source_filenames, print_status=False)
    if (
        self.package_name not in program.packages
        or self.class_name not in program.packages[self.package_name].classes
        or self.field_name not in program.packages[self.package_name].classes[self.class_name].fields
    ):
        logger.error("One or more inputs are not valid.")
        return False

    package_classes = program.packages[self.package_name].classes
    _class: symbol_table.Class = package_classes[self.class_name]
    if _class.superclass_name is None:
        logger.error("Super class is none.")
        return False
    superclass_name = _class.superclass_name
    if not package_classes.get(superclass_name):
        logger.error("Super class package is none!")
        return False
    superclass: symbol_table.Class = package_classes[superclass_name]
    superclass_body_start = symbol_table.TokensInfo(superclass.parser_context.classBody())
    superclass_body_start.stop = superclass_body_start.start  # start and stop both point to the '{'
    if self.field_name in superclass.fields:
        logger.error("Field is in superclass fields.")
        return False

    datatype = _class.fields[self.field_name].datatype
    qualified_superclass = (
        None if self.package_name is None else self.package_name + '.' + superclass_name
    )

    # --- Collect the same-named, same-typed fields of all direct subclasses ---
    fields_to_remove = []
    for pn in program.packages:
        p: symbol_table.Package = program.packages[pn]
        for cn in p.classes:
            c: symbol_table.Class = p.classes[cn]
            # A simple superclass name resolves to the superclass only from the
            # same package or through an import; otherwise require the qualified name.
            extends_superclass = (
                (c.superclass_name == superclass_name
                 and (pn == self.package_name
                      or c.file_info.has_imported_class(self.package_name, superclass_name)))
                or (qualified_superclass is not None and c.superclass_name == qualified_superclass)
            )
            if (extends_superclass
                    and self.field_name in c.fields
                    and c.fields[self.field_name].datatype == datatype):
                fields_to_remove.append(c.fields[self.field_name])

    if len(fields_to_remove) == 0:
        logger.error("No fields to remove.")
        return False

    is_public = any("public" in f.modifiers for f in fields_to_remove)
    is_protected = all("protected" in f.modifiers or "private" in f.modifiers for f in fields_to_remove)

    rewriter = symbol_table.Rewriter(program, self.filename_mapping)
    rewriter.insert_after(
        superclass_body_start,
        "\n\t" + ("public " if is_public else ("protected " if is_protected else ""))
        + datatype + " " + self.field_name + ";"
    )

    for field in fields_to_remove:
        field: symbol_table.Field = field
        if len(field.neighbor_names) == 0:
            rewriter.replace(field.get_tokens_info(), "")
            # The modifiers are separate parse contexts in this grammar; remove them too.
            for mod_ctx in field.modifiers_parser_contexts:
                rewriter.replace(symbol_table.TokensInfo(mod_ctx), "")
        else:
            # Remove only this declarator from a multi-variable declaration.
            i = field.index_in_variable_declarators
            var_ctxs = field.all_variable_declarator_contexts
            to_remove = symbol_table.TokensInfo(var_ctxs[i])
            if i == 0:
                to_remove.stop = symbol_table.TokensInfo(var_ctxs[i + 1]).start - 1  # include the ',' after it
            else:
                to_remove.start = symbol_table.TokensInfo(var_ctxs[i - 1]).stop + 1  # include the ',' before it
            rewriter.replace(to_remove, "")

        if field.initializer is None:
            continue

        # Move the initializer into the constructor(s) of the field's class.
        owner: symbol_table.Class = program.packages[field.package_name].classes[field.class_name]
        initializer = field.initializer
        if '()' in initializer:
            # TODO: requires better handling.  The initializer text appears to
            # lack whitespace (e.g. `newArrayList<>()`); re-insert the space
            # after `new` only at token starts, so the field name and
            # identifiers such as `renew` / `x.newInstance()` stay intact.
            initializer = re.sub(r'(?<![\w$.])new(?=[\w$])', 'new ', initializer)
        if initializer.startswith('{'):
            # Array initializer: `{...}` is only valid in a declaration.
            initializer = "new " + field.datatype + " " + initializer
        initializer_statement = field.name + " = " + initializer + ";"

        has_constructor = False
        for class_body_decl in owner.parser_context.classBody().getChildren():
            if class_body_decl.getText() in ['{', '}']:
                continue
            member_decl = class_body_decl.memberDeclaration()
            if member_decl is None:
                continue
            constructor = member_decl.constructorDeclaration()
            if constructor is None:
                continue
            body_start = symbol_table.TokensInfo(constructor.constructorBody)
            body_start.stop = body_start.start  # start and stop both point to the '{'
            rewriter.insert_after(body_start, "\n\t" + initializer_statement)
            has_constructor = True

        if not has_constructor:
            body_start = symbol_table.TokensInfo(owner.parser_context.classBody())
            body_start.stop = body_start.start  # start and stop both point to the '{'
            # Only an access modifier is valid on a constructor (not abstract/final/...).
            access = next((mod for mod in owner.modifiers if mod in ("public", "protected", "private")), None)
            prefix = access + " " if access else ""
            rewriter.insert_after(
                body_start,
                "\n\t" + prefix + owner.name + "() { " + initializer_statement + " }"
            )

    rewriter.apply()
    # TODO: check for multilevel inheritance recursively.
    return True