def test():
    """Print the source text of method ``c`` of class ``source``.

    Parses the two move-method fixture files and prints what
    ``get_text_from_file()`` returns for that method.
    """
    # NOTE(review): the fixtures live under tests/move_method_test/ but the
    # lookup uses package "tests.utils_test2" -- confirm this is the package
    # those files declare.
    fixture_files = [
        "tests/move_method_test/source.java",
        "tests/move_method_test/target.java",
    ]
    parsed = get_program(fixture_files)
    source_class = parsed.packages["tests.utils_test2"].classes["source"]
    method_c = source_class.methods["c"]
    print(method_c.get_text_from_file())
def test_utils():
    """Dump the parsed model of ``tests/utils_test.java``.

    Prints the program, then every package, class and method, plus the first
    element of each method parameter. After that it prints every entry of the
    mapping returned by ``move_method`` for the same file, each preceded by a
    ``......`` separator line.
    """
    files = ["tests/utils_test.java"]
    candidates = move_method(files)
    parsed = get_program(files)
    print(parsed)

    for pkg_key in parsed.packages:
        pkg = parsed.packages[pkg_key]
        print(pkg)
        for cls_key in pkg.classes:
            cls = pkg.classes[cls_key]
            print(cls)
            for meth_key in cls.methods:
                meth = cls.methods[meth_key]
                print(meth)
                for param in meth.parameters:
                    print(param[0])

    for key in candidates:
        value = candidates[key]
        print("......")
        print(key)
        print(value)
# Training script: learn the parameter ``model.p`` of a
# ProgrammingShuffledNetwork wrapped around an ImageNet-pretrained
# SqueezeNet 1.0 (put in eval mode) on the MNIST loaders from get_mnist().
# Only ``model.p`` is handed to the optimizer.
PATH = "./models/squeezenet1_0_MNIST_shuffled"  # checkpoint location

# Hyper-parameters.
batch_size = 16
input_size = 224
patch_size = 36
nb_epochs = 20
nb_freq = 10  # passed to train() as save_freq

train_loader, test_loader = get_mnist(batch_size)
pretrained_model = torchvision.models.squeezenet1_0(pretrained=True).eval()
model = ProgrammingShuffledNetwork(
    pretrained_model, input_size, patch_size, blur_sigma=.5
)
optimizer = T.optim.Adam([model.p])

model, loss_history = train(
    model,
    train_loader,
    nb_epochs,
    optimizer,
    C=.05,
    reg_fun=reg_l2,
    save_freq=nb_freq,
    save_path=PATH,
    test_loader=test_loader,
    device=DEVICE,
)
program = get_program(model, PATH, imshow=True)
def pullup_method_refactoring(
        source_filenames: list,
        package_name: str,
        class_name: str,
        method_key: str,
        filename_mapping=lambda x: x + ".rewritten.java"):
    """Pull method ``method_key`` of ``class_name`` up into its superclass.

    The method's original text is inserted relative to the superclass
    declaration. The similar methods reported by ``get_removemethods`` are
    replaced with a blank. Invocations of the method found in classes of
    ``package_name`` are then rewritten.

    Args:
        source_filenames: Java files parsed with ``get_program``.
        package_name: package holding both the class and its superclass.
        class_name: class that currently declares the method.
        method_key: key of the method in ``Class.methods``; the text before
            the first ``(`` is used as the method name when matching calls.
        filename_mapping: maps an input filename to the filename the
            rewritten source is written to.

    Returns:
        True when the refactoring was applied. False when the
        package/class/method is unknown, the class has no superclass parsed
        in the same package, the method is a constructor, its body mentions
        a field of its class, or it calls methods without a type name.
    """
    program = get_program(source_filenames)

    # Preconditions: without them the lookups below raise KeyError.
    if package_name not in program.packages:
        return False
    classes = program.packages[package_name].classes
    if class_name not in classes or method_key not in classes[class_name].methods:
        return False
    _sourceclass = classes[class_name]
    target_class_name = _sourceclass.superclass_name
    if target_class_name is None or target_class_name not in classes:
        return False

    # Similar methods in other classes; they are blanked out below.
    removemethod = get_removemethods(
        program, package_name, target_class_name, method_key, class_name)
    _targetclass = classes[target_class_name]
    _method_name = _sourceclass.methods[method_key]
    tokens_info = TokensInfo(_method_name.parser_context)
    # All tokens of the method's source span.
    exps = tokens_info.get_token_index(
        tokens_info.token_stream.tokens, tokens_info.start, tokens_info.stop)

    if _method_name.is_constructor:
        return False
    # Refuse when the method body mentions a field of its own class.
    if any(token.text in _sourceclass.fields for token in exps):
        return False
    # Calls to methods without a type name are not handled either.
    if _method_name.body_method_invocations_without_typename:
        return False

    rewriter = Rewriter(program, filename_mapping)
    for remove in removemethod:
        entry = removemethod[remove]
        if entry is not None:
            duplicate = classes[remove].methods[str(entry[0])]
            rewriter.replace(TokensInfo(duplicate.parser_context), " ")

    # Copy the method's unmodified text into the superclass.
    class_tokens_info = TokensInfo(_targetclass.parser_context)
    singlefileelement = SingleFileElement(
        _method_name.parser_context, _method_name.filename)
    token_stream_rewriter = TokenStreamRewriter(
        singlefileelement.get_token_stream())
    strofmethod = token_stream_rewriter.getText(
        program_name=token_stream_rewriter.DEFAULT_PROGRAM_NAME,
        start=tokens_info.start,
        stop=tokens_info.stop)
    rewriter.insert_before(tokens_info=class_tokens_info, text=strofmethod)
    rewriter.apply()

    # Wherever this method is invoked inside the package, update the call
    # site. (The previous `a == b & c == d` test raised TypeError because
    # `&` binds tighter than `==`; restricting the walk to `package_name`
    # is the intended package filter.)
    method_name = method_key[:method_key.find('(')]
    for caller_class_name in classes:
        caller_class = classes[caller_class_name]
        for caller_method_key in caller_class.methods:
            caller_method = caller_class.methods[caller_method_key]
            for inv in caller_method.body_method_invocations:
                invc = caller_method.body_method_invocations[inv]
                if invc[0] != method_name:
                    continue
                inv_tokens_info = TokensInfo(inv)
                class_token_info = TokensInfo(caller_class.body_context)
                # NOTE(review): a declaration is inserted for every matching
                # call site, so a class with several calls may receive
                # duplicate declarations -- confirm intended.
                rewriter.insert_after_start(
                    class_token_info,
                    target_class_name + " " + str.lower(target_class_name)
                    + "=" + "new " + target_class_name + "();")
                rewriter.apply()
                rewriter.replace(inv_tokens_info, target_class_name)
                rewriter.apply()
    return True
def pullup_field(source_filenames: list, package_name: str, class_name: str, field_name: str,
                 filename_mapping=lambda x: (x[:-5] if x.endswith(".java") else x) + ".re.java") -> bool:
    """Pull field ``field_name`` of ``class_name`` up into its superclass.

    Every class that extends the same superclass and declares a field with the
    same name and datatype loses that declaration. One declaration (without
    initializer) is inserted right after the '{' of the superclass body.

    The new field is ``public`` if any removed field was public, ``protected``
    if all were protected or private, and package-private otherwise. When a
    removed field had an initializer, the assignment is inserted at the start
    of every constructor of its class. If the class declares no constructor,
    a new no-argument constructor holding the assignment is added.

    Args:
        source_filenames: Java files parsed with ``utils.get_program``.
        package_name: package of ``class_name`` (and of its superclass).
        class_name: class that currently declares the field.
        field_name: name of the field to pull up.
        filename_mapping: maps an input filename to the output filename.

    Returns:
        True when the sources were rewritten. False when the package, class
        or field is unknown, the class has no superclass parsed in the same
        package, the superclass already has the field, or no matching field
        is found.
    """
    program = utils.get_program(source_filenames, print_status=True)
    if package_name not in program.packages \
            or class_name not in program.packages[package_name].classes \
            or field_name not in program.packages[package_name].classes[class_name].fields:
        return False

    package_classes = program.packages[package_name].classes
    _class: utils_listener_fast.Class = package_classes[class_name]
    superclass_name = _class.superclass_name
    # The superclass must be a parsed class of the same package (e.g. not a
    # library class); otherwise the lookup below would raise KeyError.
    if superclass_name is None or superclass_name not in package_classes:
        return False
    superclass: utils_listener_fast.Class = package_classes[superclass_name]
    superclass_body_start = utils_listener_fast.TokensInfo(superclass.parser_context.classBody())
    superclass_body_start.stop = superclass_body_start.start  # Start and stop both point to the '{'

    if field_name in superclass.fields:
        return False

    datatype = _class.fields[field_name].datatype
    fields_to_remove = []
    for pn in program.packages:
        p: utils_listener_fast.Package = program.packages[pn]
        for cn in p.classes:
            c: utils_listener_fast.Class = p.classes[cn]
            extends_superclass = (
                (c.superclass_name == superclass_name
                 and c.file_info.has_imported_class(package_name, superclass_name))
                or (package_name is not None
                    and c.superclass_name == package_name + '.' + superclass_name))
            if extends_superclass \
                    and field_name in c.fields \
                    and c.fields[field_name].datatype == datatype:
                fields_to_remove.append(c.fields[field_name])
    if not fields_to_remove:
        return False

    is_public = any("public" in f.modifiers for f in fields_to_remove)
    is_protected = all("protected" in f.modifiers or "private" in f.modifiers
                       for f in fields_to_remove)
    modifier = "public " if is_public else ("protected " if is_protected else "")

    rewriter = utils.Rewriter(program, filename_mapping)
    rewriter.insert_after(superclass_body_start,
                          "\n " + modifier + datatype + " " + field_name + ";")

    for field in fields_to_remove:
        if len(field.neighbor_names) == 0:
            rewriter.replace(field.get_tokens_info(), "")
            # Have to remove the modifiers too, because of the new grammar.
            for mod_ctx in field.modifiers_parser_contexts:
                rewriter.replace(utils_listener_fast.TokensInfo(mod_ctx), "")
        else:
            # The field shares a declaration with others: remove only its
            # declarator together with one adjacent ','.
            i = field.index_in_variable_declarators
            var_ctxs = field.all_variable_declarator_contexts
            to_remove = utils_listener_fast.TokensInfo(var_ctxs[i])
            if i == 0:
                # Include the ',' after it
                to_remove.stop = utils_listener_fast.TokensInfo(var_ctxs[i + 1]).start - 1
            else:
                # Include the ',' before it
                to_remove.start = utils_listener_fast.TokensInfo(var_ctxs[i - 1]).stop + 1
            rewriter.replace(to_remove, "")

        # Move the initializer into the owning class's constructor(s).
        if field.initializer is not None:
            owner: utils_listener_fast.Class = \
                program.packages[field.package_name].classes[field.class_name]
            # Array initializers `{...}` need an explicit `new T` in an assignment.
            initializer_statement = (
                field.name + " = "
                + ("new " + field.datatype + " " if field.initializer.startswith('{') else "")
                + field.initializer + ";")
            has_constructor = False
            for class_body_decl in owner.parser_context.classBody().getChildren():
                if class_body_decl.getText() in ['{', '}']:
                    continue
                member_decl = class_body_decl.memberDeclaration()
                if member_decl is None:
                    continue
                constructor = member_decl.constructorDeclaration()
                if constructor is None:
                    continue
                body_start = utils_listener_fast.TokensInfo(constructor.constructorBody)  # Start token = '{'
                body_start.stop = body_start.start  # Start and stop both point to the '{'
                rewriter.insert_after(body_start, "\n " + initializer_statement)
                has_constructor = True
            if not has_constructor:
                body_start = utils_listener_fast.TokensInfo(owner.parser_context.classBody())
                body_start.stop = body_start.start  # Start and stop both point to the '{'
                rewriter.insert_after(
                    body_start,
                    "\n " + owner.name + "() { " + initializer_statement + " }")

    rewriter.apply()
    return True
def pushdown_field(
        source_filenames: list,
        package_name: str,
        superclass_name: str,
        field_name: str,
        class_names: list = None,
        filename_mapping=lambda x: (x[:-5] if x.endswith(".java") else x) + ".re.java"
) -> bool:
    """Push field ``field_name`` down from ``superclass_name`` into subclasses.

    The field declaration is removed from the superclass. A copy keeping only
    its ``public``/``protected`` modifier, its datatype and its initializer is
    inserted after the '{' of every selected subclass.

    Args:
        source_filenames: Java files parsed with ``utils.get_program``.
        package_name: package of the superclass (may be None).
        superclass_name: class that currently declares the field.
        field_name: name of the field to push down.
        class_names: subclasses that receive the field. None or empty means
            every subclass found.
        filename_mapping: maps an input filename to the output filename.

    Returns:
        True when the sources were rewritten. False otherwise:
          - the package, superclass or field is unknown;
          - a superclass method reads the field;
          - a selected subclass already declares a field of that name;
          - some code reads the field through a variable typed as the
            superclass or as a non-selected subclass.
    """
    if class_names is None:
        class_names = []

    def refers_to(datatype, type_package, type_name, imported):
        # True when `datatype` names type_package.type_name, either by simple
        # name (only valid when `imported`) or fully qualified. A None
        # package has no qualified form, so it never matches qualified names.
        return (imported and datatype == type_name) \
            or (type_package is not None and datatype == type_package + '.' + type_name)

    program = utils.get_program(source_filenames, print_status=True)
    if package_name not in program.packages \
            or superclass_name not in program.packages[package_name].classes \
            or field_name not in program.packages[package_name].classes[superclass_name].fields:
        return False

    superclass: utils_listener_fast.Class = program.packages[package_name].classes[superclass_name]

    # Refuse when the superclass's own methods use the field (`f` or `this.f`).
    for mk in superclass.methods:
        m: utils_listener_fast.Method = superclass.methods[mk]
        for item in m.body_local_vars_and_expr_names:
            if not isinstance(item, utils_listener_fast.ExpressionName):
                continue
            ids = item.dot_separated_identifiers
            if (len(ids) == 1 and ids[0] == field_name) \
                    or (len(ids) == 2 and ids[0] == "this" and ids[1] == field_name):
                return False

    # Split the direct subclasses into those receiving the field and the rest.
    other_derived_classes = []
    classes_to_add_to = []
    for pn in program.packages:
        p: utils_listener_fast.Package = program.packages[pn]
        for cn in p.classes:
            c: utils_listener_fast.Class = p.classes[cn]
            is_derived = (
                (c.superclass_name == superclass_name
                 and c.file_info.has_imported_class(package_name, superclass_name))
                or (package_name is not None
                    and c.superclass_name == package_name + '.' + superclass_name))
            if not is_derived:
                continue
            if not class_names or cn in class_names:
                if field_name in c.fields:
                    return False
                classes_to_add_to.append(c)
            else:
                other_derived_classes.append(c)

    # Check if the field is used from the superclass or other derived classes.
    for pn in program.packages:
        p: utils_listener_fast.Package = program.packages[pn]
        for cn in p.classes:
            c: utils_listener_fast.Class = p.classes[cn]
            has_imported_superclass = c.file_info.has_imported_class(package_name, superclass_name)

            fields_of_superclass_type_or_others = []
            for fn in c.fields:
                f: utils_listener_fast.Field = c.fields[fn]
                if refers_to(f.datatype, package_name, superclass_name, has_imported_superclass):
                    fields_of_superclass_type_or_others.append(f.name)
                if any(refers_to(f.datatype, o.package_name, o.name,
                                 c.file_info.has_imported_class(o.package_name, o.name))
                       for o in other_derived_classes):
                    fields_of_superclass_type_or_others.append(f.name)

            for mk in c.methods:
                m: utils_listener_fast.Method = c.methods[mk]
                local_vars_of_superclass_type_or_others = []
                for item in m.body_local_vars_and_expr_names:
                    if isinstance(item, utils_listener_fast.LocalVariable):
                        if refers_to(item.datatype, package_name, superclass_name,
                                     has_imported_superclass):
                            local_vars_of_superclass_type_or_others.append(item.identifier)
                        if any(refers_to(item.datatype, o.package_name, o.name,
                                         c.file_info.has_imported_class(o.package_name, o.name))
                               for o in other_derived_classes):
                            local_vars_of_superclass_type_or_others.append(item.identifier)
                    elif isinstance(item, utils_listener_fast.ExpressionName):
                        ids = item.dot_separated_identifiers
                        # `v.f` (v a local or field) or `this.v.f` (v a field).
                        if ids[-1] == field_name \
                                and (len(ids) == 2 or (len(ids) == 3 and ids[0] == "this")) \
                                and ((ids[-2] in local_vars_of_superclass_type_or_others
                                      and len(ids) == 2)
                                     or ids[-2] in fields_of_superclass_type_or_others):
                            return False

    rewriter = utils.Rewriter(program, filename_mapping)
    field = superclass.fields[field_name]
    if len(field.neighbor_names) == 0:
        rewriter.replace(field.get_tokens_info(), "")
        # Have to remove the modifiers too, because of the new grammar.
        for mod_ctx in field.modifiers_parser_contexts:
            rewriter.replace(utils_listener_fast.TokensInfo(mod_ctx), "")
    else:
        i = field.index_in_variable_declarators
        var_ctxs = field.all_variable_declarator_contexts
        to_remove = utils_listener_fast.TokensInfo(var_ctxs[i])
        if i == 0:
            # Include the ',' after it
            to_remove.stop = utils_listener_fast.TokensInfo(var_ctxs[i + 1]).start - 1
        else:
            # Include the ',' before it
            to_remove.start = utils_listener_fast.TokensInfo(var_ctxs[i - 1]).stop + 1
        rewriter.replace(to_remove, "")

    # Only public/protected are carried over; other modifiers are dropped.
    is_public = "public" in field.modifiers
    is_protected = "protected" in field.modifiers
    modifier = ("public " if is_public else ("protected " if is_protected else ""))
    initializer = (" = " + field.initializer) if field.initializer is not None else ""
    for c in classes_to_add_to:
        c_body_start = utils_listener_fast.TokensInfo(c.parser_context.classBody())
        c_body_start.stop = c_body_start.start  # Start and stop both point to the '{'
        rewriter.insert_after(
            c_body_start,
            "\n " + modifier + field.datatype + " " + field_name + initializer + ";")

    rewriter.apply()
    return True
def move_method_refactoring(source_filenames: list, package_name: str, class_name: str,
                            method_key: str, target_class_name: str, target_package_name: str,
                            filename_mapping=None):
    """Move method ``method_key`` from ``class_name`` to ``target_class_name``.

    The moved method receives an extra parameter of the source class type,
    named after the lower-cased class name. Uses of source-class fields and
    of source-class-qualified calls inside the method are prefixed with that
    parameter. The method text is inserted relative to the target class
    declaration and deleted from the source class. Call sites are rewritten
    to use the target class.

    Args:
        source_filenames: Java files parsed with ``get_program``.
        package_name: package of the source class.
        class_name: class that currently declares the method.
        method_key: key of the method in ``Class.methods``; the text before
            the first ``(`` is used as the method name when matching calls.
        target_class_name: class receiving the method.
        target_package_name: package of the target class.
        filename_mapping: maps an input filename to the output filename.
            None (default) keeps the legacy output location
            ``'Real_Test_refactorings/move_method_test_resault' + filename``.
            This argument used to be ignored; it is now honored.

    Returns:
        True when the refactoring was applied. False when a package, class or
        method is unknown, or the method is a constructor.
    """
    program = get_program(source_filenames)
    if package_name not in program.packages or target_package_name not in program.packages:
        return False
    source_classes = program.packages[package_name].classes
    target_classes = program.packages[target_package_name].classes
    if class_name not in source_classes \
            or target_class_name not in target_classes \
            or method_key not in source_classes[class_name].methods:
        return False
    _sourceclass = source_classes[class_name]
    _targetclass = target_classes[target_class_name]
    _method = _sourceclass.methods[method_key]
    if _method.is_constructor:
        return False

    if filename_mapping is None:
        # Legacy default output location of this function.
        def filename_mapping(x):
            return 'Real_Test_refactorings/move_method_test_resault' + x

    rewriter = Rewriter(program, filename_mapping)
    tokens_info = TokensInfo(_method.parser_context)  # tokens of the method declaration
    param_tokens_info = TokensInfo(_method.formalparam_context)
    method_declaration_info = TokensInfo(_method.method_declaration_context)
    exps = tokens_info.get_token_index(
        tokens_info.token_stream.tokens, tokens_info.start, tokens_info.stop)
    is_static = "static" in _method.modifiers
    # Token indexes of identifiers in the method that name fields declared in
    # the source class body.
    exp = [token.tokenIndex for token in exps if token.text in _sourceclass.fields]

    # Rewrite every call site of the method.
    method_name = method_key[:method_key.find('(')]
    for caller_package_name in program.packages:
        caller_package = program.packages[caller_package_name]
        for caller_class_name in caller_package.classes:
            caller_class = caller_package.classes[caller_class_name]
            for caller_method_key in caller_class.methods:
                caller_method = caller_class.methods[caller_method_key]
                for inv in caller_method.body_method_invocations:
                    invc = caller_method.body_method_invocations[inv]
                    if invc[0] != method_name:
                        continue
                    inv_tokens_info = TokensInfo(inv)
                    # Bound unconditionally: the import insertion below needs
                    # it for static methods too (it used to be unbound there).
                    class_token_info = TokensInfo(caller_class.body_context)
                    if not is_static:
                        rewriter.insert_after_start(
                            class_token_info,
                            target_class_name + " " + str.lower(target_class_name)
                            + "=" + "new " + target_class_name + "();")
                        rewriter.apply()
                    # NOTE(review): this inserts the import relative to the
                    # class *body* context -- confirm it lands at file level.
                    rewriter.insert_before_start(
                        class_token_info,
                        "import " + target_package_name + "." + target_class_name + ";")
                    rewriter.replace(inv_tokens_info, target_class_name)
                    rewriter.apply()

    class_tokens_info = TokensInfo(_targetclass.parser_context)
    package_tokens_info = TokensInfo(program.packages[target_package_name].package_ctx)
    singlefileelement = SingleFileElement(_method.parser_context, _method.filename)
    token_stream_rewriter = TokenStreamRewriter(singlefileelement.get_token_stream())

    # Prefix uses of source-class fields with the source-object parameter.
    for index in exp:
        token_stream_rewriter.insertBeforeIndex(index=index, text=str.lower(class_name) + ".")
    # Blank out invocation contexts whose text is the target class name, plus
    # the one following token (stop + 1).
    for inv in _method.body_method_invocations:
        if inv.getText() == target_class_name:
            inv_tokens_info_target = TokensInfo(inv)
            token_stream_rewriter.replaceRange(
                from_idx=inv_tokens_info_target.start,
                to_idx=inv_tokens_info_target.stop + 1,
                text=" ")
    # Prefix source-class method calls with the source-object parameter.
    for i in _method.body_method_invocations_without_typename:
        if i.getText() == class_name:
            ii = _method.body_method_invocations_without_typename[i]
            i_tokens = TokensInfo(ii[0])
            token_stream_rewriter.insertBeforeIndex(
                index=i_tokens.start, text=str.lower(class_name) + ".")
    # Pass an object of the source class to the method.
    if param_tokens_info.start is not None:
        token_stream_rewriter.insertBeforeIndex(
            param_tokens_info.start, text=class_name + " " + str.lower(class_name) + ",")
    else:
        token_stream_rewriter.insertBeforeIndex(
            method_declaration_info.stop, text=class_name + " " + str.lower(class_name))

    strofmethod = token_stream_rewriter.getText(
        program_name=token_stream_rewriter.DEFAULT_PROGRAM_NAME,
        start=tokens_info.start,
        stop=tokens_info.stop)
    rewriter.insert_before(tokens_info=class_tokens_info, text=strofmethod)
    rewriter.insert_after(
        package_tokens_info,
        "import " + target_package_name + "." + target_class_name + ";")
    rewriter.replace(tokens_info, "")
    rewriter.apply()
    return True
def extract_interface(
        source_filenames: list,
        package_name: str,
        class_names: list,
        method_keys: list,
        interface_name: str,
        interface_filename: str,
        filename_mapping=lambda x: (x[:-5] if x.endswith(".java") else x) + ".re.java"
) -> bool:
    """Extract interface ``interface_name`` declaring ``method_keys``.

    Each class in ``class_names`` gets ``implements interface_name`` (or
    ``, interface_name`` when it already implements something). Its matching
    methods become ``public``, gain ``@Override``, and lose ``private`` and
    ``protected``.

    Private fields and local variables typed as one of those classes are
    retyped to the interface when only interface methods are invoked on them.
    The interface source is written to ``interface_filename``; parent
    directories are created as needed.

    Args:
        source_filenames: Java files parsed with ``utils.get_program``.
        package_name: package of the classes and of the new interface (the
            ``package`` line is omitted when it is None).
        class_names: classes that will implement the interface.
        method_keys: keys of the methods to extract; the text before the
            first ``(`` is used as the method name.
        interface_name: simple name of the new interface.
        interface_filename: path of the generated interface file.
        filename_mapping: maps an input filename to the output filename.

    Returns:
        True on success. False when a package, class or method is unknown, or
        when the classes disagree on a method's return type or parameter types.
    """
    program = utils.get_program(source_filenames, print_status=True)
    if package_name not in program.packages \
            or any(class_name not in program.packages[package_name].classes
                   for class_name in class_names) \
            or any(method_key not in program.packages[package_name].classes[class_name].methods
                   for class_name in class_names for method_key in method_keys):
        return False

    method_returntypes = {}
    method_parameters = {}
    method_names = [method_key[:method_key.find('(')] for method_key in method_keys]

    rewriter = utils.Rewriter(program, filename_mapping)
    for class_name in class_names:
        c: utils_listener_fast.Class = program.packages[package_name].classes[class_name]

        # Add `implements <interface>` (or `, <interface>`) to the class header.
        has_superinterface = False
        if c.parser_context.IMPLEMENTS() is not None:  # old: c.parser_context.superinterfaces()
            t = utils_listener_fast.TokensInfo(c.parser_context.typeList())
            has_superinterface = True
        elif c.parser_context.EXTENDS() is not None:  # old: c.parser_context.superclass()
            t = utils_listener_fast.TokensInfo(c.parser_context.typeType())
        elif c.parser_context.typeParameters() is not None:
            t = utils_listener_fast.TokensInfo(c.parser_context.typeParameters())
        else:
            # old: TokensInfo(c.parser_context.identifier())
            t = utils_listener_fast.TokensInfo(c.parser_context)
            t.stop = c.parser_context.IDENTIFIER().getSymbol().tokenIndex
        rewriter.insert_after(t, (", " if has_superinterface else " implements ") + interface_name)

        for method_key in method_keys:
            m: utils_listener_fast.Method = c.methods[method_key]
            # The first class seen defines the signature; the others must agree
            # on the return type and on the parameter types.
            if method_key in method_returntypes:
                if method_returntypes[method_key] != m.returntype:
                    return False
                expected_params = method_parameters[method_key]
                if len(expected_params) != len(m.parameters):
                    return False
                if any(expected[0] != actual[0]
                       for expected, actual in zip(expected_params, m.parameters)):
                    return False
            else:
                method_returntypes[method_key] = m.returntype
                method_parameters[method_key] = m.parameters

            # Manage method modifiers: add `@Override` / `public` if missing...
            if len(m.modifiers_parser_contexts) > 0:
                t = utils_listener_fast.TokensInfo(m.modifiers_parser_contexts[0])
            else:
                t = m.get_tokens_info()
            rewriter.insert_before_start(
                t,  # old: m.get_tokens_info() without requiring t
                ("" if "@Override" in m.modifiers else "@Override\n ")
                + ("" if "public" in m.modifiers else "public "))
            # ...and drop access modifiers narrower than public.
            for i, mm in enumerate(m.modifiers):
                if mm == "private" or mm == "protected":
                    # old: m.parser_context.methodModifier(i)
                    t = utils_listener_fast.TokensInfo(m.modifiers_parser_contexts[i])
                    rewriter.replace(t, "")

    # Change variable types to the interface if only interface methods are used.
    # Loop variables must not shadow `package_name`: the checks below and the
    # generated file refer to the interface's package.
    for scanned_package_name in program.packages:
        p: utils_listener_fast.Package = program.packages[scanned_package_name]
        for scanned_class_name in p.classes:
            c: utils_listener_fast.Class = p.classes[scanned_class_name]

            fields_of_interest = {}
            for fn in c.fields:
                f: utils_listener_fast.Field = c.fields[fn]
                has_candidate_type = any(
                    (f.datatype == cn and f.file_info.has_imported_class(package_name, cn))
                    or (package_name is not None and f.datatype == package_name + '.' + cn)
                    for cn in class_names)
                if has_candidate_type and "private" in f.modifiers:
                    fields_of_interest[f.name] = f

            for mk in c.methods:
                m: utils_listener_fast.Method = c.methods[mk]
                vars_of_interest = {}
                for item in m.body_local_vars_and_expr_names:
                    if isinstance(item, utils_listener_fast.LocalVariable):
                        if any((item.datatype == cn
                                and c.file_info.has_imported_class(package_name, cn))
                               or (package_name is not None
                                   and item.datatype == package_name + '.' + cn)
                               for cn in class_names):
                            vars_of_interest[item.identifier] = item
                    if isinstance(item, utils_listener_fast.MethodInvocation):
                        ids = item.dot_separated_identifiers
                        # `v.m()` or `this.v.m()`: a non-interface call on v
                        # disqualifies v.
                        if len(ids) == 2 or (len(ids) == 3 and ids[0] == "this"):
                            if ids[-2] in vars_of_interest:
                                if ids[-1] not in method_names:
                                    vars_of_interest.pop(ids[-2])
                            elif ids[-2] in fields_of_interest \
                                    and ids[-1] not in method_names:
                                fields_of_interest.pop(ids[-2])

                for var_name in vars_of_interest:
                    var = vars_of_interest[var_name]
                    if m.file_info.has_imported_package(package_name):
                        # old: var.parser_context.unannType()
                        rewriter.replace(
                            utils_listener_fast.TokensInfo(var.parser_context.typeType()),
                            interface_name)
                    else:
                        if package_name is None:
                            break
                        # old: var.parser_context.unannType()
                        rewriter.replace(
                            utils_listener_fast.TokensInfo(var.parser_context.typeType()),
                            package_name + '.' + interface_name)

            for field_name in fields_of_interest:
                f = fields_of_interest[field_name]
                if c.file_info.has_imported_package(package_name):
                    typename = interface_name
                else:
                    if package_name is None:
                        break
                    typename = package_name + '.' + interface_name
                if len(f.neighbor_names) == 0:
                    # old: f.parser_context.unannType()
                    rewriter.replace(
                        utils_listener_fast.TokensInfo(f.parser_context.typeType()), typename)
                elif not any(nn in fields_of_interest for nn in f.neighbor_names):
                    # Split this declarator out of a multi-variable declaration
                    # and redeclare it with the interface type.
                    idx = f.index_in_variable_declarators
                    var_ctxs = f.all_variable_declarator_contexts
                    t = utils_listener_fast.TokensInfo(var_ctxs[idx])
                    if idx == 0:
                        # Include the ',' after it
                        t.stop = utils_listener_fast.TokensInfo(var_ctxs[idx + 1]).start - 1
                    else:
                        # Include the ',' before it (from the previous declarator's
                        # end, not its start).
                        t.start = utils_listener_fast.TokensInfo(var_ctxs[idx - 1]).stop + 1
                    rewriter.replace(t, "")
                    rewriter.insert_after(
                        f.get_tokens_info(),
                        "\n private " + typename + " " + f.name
                        + (" = " + f.initializer + ";" if f.initializer is not None else ";"))

    # Create the interface
    interface_file_content = ""
    if package_name is not None:
        interface_file_content += "package " + package_name + ";\n\n"
    interface_file_content += "public interface " + interface_name + "\n" + "{\n"
    for method_key in method_keys:
        method_name = method_key[:method_key.find('(')]
        params = ", ".join(param[0] + " " + param[1] for param in method_parameters[method_key])
        interface_file_content += (" " + method_returntypes[method_key] + " "
                                   + method_name + "(" + params + ");\n")
    interface_file_content += "}\n"

    # A bare filename has no directory part; nothing to create then.
    interface_dir = os.path.dirname(interface_filename)
    if interface_dir:
        os.makedirs(interface_dir, exist_ok=True)
    with open(interface_filename, "w+", encoding="utf-8") as interface_file:
        interface_file.write(interface_file_content)

    rewriter.apply()
    return True