def analyze(self, ptr_from: ReferencePointer, ptr_to: ReferencePointer,  # noqa: D
            data_service: DataService, changes: Iterable[Change]) -> [Comment]:
    self._log.info("analyze %s %s", ptr_from.commit, ptr_to.commit)
    comments = []
    parser = TokenParser(stem_threshold=100, single_shot=True)
    words = autocorrect.word.KNOWN_WORDS.copy()
    try:
        for name in self.model.names:
            if len(name) >= 3:
                autocorrect.word.KNOWN_WORDS.add(name)
        for change in changes:
            suggestions = defaultdict(list)
            new_lines = set(find_new_lines(change.base, change.head))
            for node in bblfsh.filter(change.head.uast, "//*[@roleIdentifier]"):
                if node.start_position is not None and node.start_position.line in new_lines:
                    for part in parser.split(node.token):
                        if part not in self.model.names:
                            fixed = autocorrect.spell(part)
                            if fixed != part:
                                suggestions[node.start_position.line].append(
                                    (node.token, part, fixed))
            for line, s in suggestions.items():
                comment = Comment()
                comment.file = change.head.path
                comment.text = "\n".join("`%s`: %s > %s" % fix for fix in s)
                comment.line = line
                comment.confidence = 100
                comments.append(comment)
    finally:
        autocorrect.word.KNOWN_WORDS = words
    return comments
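# Illustrative sketch (not part of the analyzer): how a single identifier flows through the
# autocorrect-based check in `analyze` above. It reuses the same calls - TokenParser.split,
# autocorrect.spell and the autocorrect.word.KNOWN_WORDS set; the helper name, `vocabulary`
# and the sample identifier "tenzorShape" are made up for the example.
def _demo_spellcheck_identifier(identifier="tenzorShape", vocabulary=frozenset({"shape"})):
    parser = TokenParser(stem_threshold=100, single_shot=True)
    words = autocorrect.word.KNOWN_WORDS.copy()
    fixes = []
    try:
        autocorrect.word.KNOWN_WORDS.update(w for w in vocabulary if len(w) >= 3)
        for part in parser.split(identifier):  # "tenzorShape" -> ["tenzor", "shape"] (expected)
            if part not in vocabulary:
                fixed = autocorrect.spell(part)
                if fixed != part:
                    # same message format as the Comment text built above
                    fixes.append("`%s`: %s > %s" % (identifier, part, fixed))
    finally:
        autocorrect.word.KNOWN_WORDS = words
    return fixes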
def remove_non_typos(dataset: str, filtered_dataset: str) -> None:
    """
    Remove non-typoed identifiers from the dataset.

    1. Remove examples where the token splits of the wrong and the correct identifiers are \
       equal (they differ only in non-alpha characters or in casing).
    2. Remove examples where the wrong and the correct identifiers are equal at the lemma level.

    :param dataset: Path to the dataset.
    :param filtered_dataset: Path to save the filtered dataset to.
    """
    data = pandas.read_csv(dataset, header=0, usecols=[0, 1], names=["wrong", "correct"],
                           keep_default_na=False)

    # Filter examples with equal splits
    tp = TokenParser(min_split_length=1, stem_threshold=400, single_shot=True,
                     max_token_length=400, attach_upper=True)
    data["wrong_split"] = data["wrong"].apply(lambda x: " ".join(tp.split(x)))
    data["correct_split"] = data["correct"].apply(lambda x: " ".join(tp.split(x)))
    data = data[data["wrong_split"] != data["correct_split"]]

    os.system("python3 -m spacy download en")
    nlp = spacy.load("en", disable=["parser", "ner"])

    # Filter examples with equal lemmas
    def _lemmatize(token):
        lemm = nlp(token)
        if len(lemm) > 1 or lemm[0].lemma_ == "-PRON-" or (
                token[-2:] == "ss" and lemm[0].lemma_ == token[:-1]):
            return token
        return lemm[0].lemma_

    data["wrong_lem"] = data["wrong_split"].apply(
        lambda x: " ".join(_lemmatize(token) for token in x.split()))
    data["correct_lem"] = data["correct_split"].apply(
        lambda x: " ".join(_lemmatize(token) for token in x.split()))
    data = data[(data["wrong_lem"] != data["correct_lem"]) &
                (data["wrong_lem"] != data["correct_split"]) &
                (data["correct_lem"] != data["wrong_split"])]

    # Save the filtered dataset
    whole_data = pandas.read_csv(dataset, header=0, keep_default_na=False)
    whole_data = whole_data.loc[data.index]
    whole_data.to_csv(filtered_dataset, compression="xz", index=False)
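# Illustrative usage (not part of the module): filtering a typo dataset on disk. The CSV
# layout (a header row plus "wrong" and "correct" identifier columns) follows the read_csv
# call above; both file names are hypothetical, and the output gets the ".xz" suffix because
# the function writes with compression="xz".
if __name__ == "__main__":
    remove_non_typos("typos_dataset.csv", "typos_dataset_filtered.csv.xz")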
def reconstruct_identifier(tokenizer: TokenParser, pred_tokens: List[str], identifier: str) \
        -> str:
    """
    Reconstruct an identifier given the predicted tokens and the initial identifier.

    :param tokenizer: Instance of TokenParser used to split identifiers.
    :param pred_tokens: List of predicted tokens.
    :param identifier: Initial identifier.
    :return: Reconstructed identifier based on the predicted tokens.
    """
    identifier_l = identifier.lower()
    # check the required parameters
    assert tokenizer._single_shot, "TokenParser should be initialized with " \
                                   "`single_shot=True` for IdTyposAnalyzer"
    # sanity check
    initial_tokens = list(tokenizer.split(identifier))
    err = "Number of predicted tokens (%s) is not equal to the number of tokens in the " \
          "identifier (%s) for identifier '%s', predicted tokens '%s', tokens in the " \
          "identifier '%s'"
    assert len(initial_tokens) == len(pred_tokens), \
        err % (len(initial_tokens), len(pred_tokens), identifier, pred_tokens, initial_tokens)
    # reconstruction
    res = []
    prev_end = 0
    for token, pred_token in zip(initial_tokens, pred_tokens):
        curr = identifier_l.find(token, prev_end)
        assert curr != -1, "TokenParser is broken, the subtoken `%s` was not found in the " \
                           "identifier `%s`" % (token, identifier)
        if curr != prev_end:
            # a delimiter was found
            res.append(identifier[prev_end:curr])
        if identifier[curr:curr + len(token)].isupper():
            # upper case
            res.append(pred_token.upper())
        elif identifier[curr:curr + len(token)][0].isupper():
            # capitalized
            res.append(pred_token[0].upper() + pred_token[1:])
        else:
            res.append(pred_token)
        prev_end = curr + len(token)
    if prev_end != len(identifier):
        # suffix
        res.append(identifier[prev_end:])
    return "".join(res)
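# Illustrative usage (a sketch, not part of the module): fixing a typo while preserving
# delimiters and casing. It assumes that TokenParser(min_split_length=1, single_shot=True)
# splits "get_USR_id" into ["get", "usr", "id"]; the identifier and the corrected tokens
# are made up for the example.
if __name__ == "__main__":
    tokenizer = TokenParser(min_split_length=1, single_shot=True)
    fixed = reconstruct_identifier(tokenizer, ["get", "user", "id"], "get_USR_id")
    # expected: "get_USER_id" - the "_" delimiters and the upper-case chunk are kept
    print(fixed)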
class TokenParserTests(unittest.TestCase):
    def setUp(self):
        self.tp = TokenParser(stem_threshold=4, max_token_length=20)
        self.tp._single_shot = False

    def test_process_token(self):
        self.tp.max_token_length = 100
        tokens = [
            ("UpperCamelCase", ["upper", "camel", "case"]),
            ("camelCase", ["camel", "case"]),
            ("FRAPScase", ["frap", "case"]),
            ("SQLThing", ["sqlt", "hing"]),
            ("_Astra", ["astra"]),
            ("CAPS_CONST", ["caps", "const"]),
            ("_something_SILLY_", ["someth", "silli"]),
            ("blink182", ["blink"]),
            ("FooBar100500Bingo", ["foo", "bar", "bingo"]),
            ("Man45var", ["man", "var"]),
            ("method_name", ["method", "name"]),
            ("Method_Name", ["method", "name"]),
            ("101dalms", ["dalm"]),
            ("101_dalms", ["dalm"]),
            ("101_DalmsBug", ["dalm", "bug"]),
            ("101_Dalms45Bug7", ["dalm", "bug"]),
            ("wdSize", ["wd", "size", "wdsize"]),
            ("Glint", ["glint"]),
            ("foo_BAR", ["foo", "bar"]),
            ("sourced.ml.algorithms.uast_ids_to_bag",
             ["sourc", "sourcedml", "algorithm", "mlalgorithm", "uast", "ids", "idsto",
              "bag", "tobag"]),
            ("WORSTnameYOUcanIMAGINE", ["worst", "name", "you", "can", "imagin"]),
            # Another bad example. Parser failed to parse it correctly
            ("SmallIdsToFoOo", ["small", "ids", "idsto", "fo", "oo"]),
            ("SmallIdFooo", ["small", "smallid", "fooo", "idfooo"]),
            ("ONE_M0re_.__badId.example",
             ["one", "onem", "re", "bad", "rebad", "badid", "exampl", "idexampl"]),
            ("never_use_Such__varsableNames", ["never", "use", "such", "varsabl", "name"]),
            ("a.b.c.d", ["a", "b", "c", "d"]),
            ("A.b.Cd.E", ["a", "b", "cd", "e"]),
            ("looong_sh_loooong_sh",
             ["looong", "looongsh", "loooong", "shloooong", "loooongsh"]),
            ("sh_sh_sh_sh", ["sh", "sh", "sh", "sh"]),
            ("loooong_loooong_loooong", ["loooong", "loooong", "loooong"]),
        ]
        for token, correct in tokens:
            res = list(self.tp.process_token(token))
            self.assertEqual(res, correct)

    def test_process_token_single_shot(self):
        self.tp.max_token_length = 100
        self.tp._single_shot = True
        self.tp.min_split_length = 1
        tokens = [
            ("UpperCamelCase", ["upper", "camel", "case"]),
            ("camelCase", ["camel", "case"]),
            ("FRAPScase", ["frap", "case"]),
            ("SQLThing", ["sqlt", "hing"]),
            ("_Astra", ["astra"]),
            ("CAPS_CONST", ["caps", "const"]),
            ("_something_SILLY_", ["someth", "silli"]),
            ("blink182", ["blink"]),
            ("FooBar100500Bingo", ["foo", "bar", "bingo"]),
            ("Man45var", ["man", "var"]),
            ("method_name", ["method", "name"]),
            ("Method_Name", ["method", "name"]),
            ("101dalms", ["dalm"]),
            ("101_dalms", ["dalm"]),
            ("101_DalmsBug", ["dalm", "bug"]),
            ("101_Dalms45Bug7", ["dalm", "bug"]),
            ("wdSize", ["wd", "size"]),
            ("Glint", ["glint"]),
            ("foo_BAR", ["foo", "bar"]),
            ("sourced.ml.algorithms.uast_ids_to_bag",
             ["sourc", "ml", "algorithm", "uast", "ids", "to", "bag"]),
            ("WORSTnameYOUcanIMAGINE", ["worst", "name", "you", "can", "imagin"]),
            # Another bad example. Parser failed to parse it correctly
            ("SmallIdsToFoOo", ["small", "ids", "to", "fo", "oo"]),
            ("SmallIdFooo", ["small", "id", "fooo"]),
            ("ONE_M0re_.__badId.example", ["one", "m", "re", "bad", "id", "exampl"]),
            ("never_use_Such__varsableNames", ["never", "use", "such", "varsabl", "name"]),
            ("a.b.c.d", ["a", "b", "c", "d"]),
            ("A.b.Cd.E", ["a", "b", "cd", "e"]),
            ("looong_sh_loooong_sh", ["looong", "sh", "loooong", "sh"]),
            ("sh_sh_sh_sh", ["sh", "sh", "sh", "sh"]),
            ("loooong_loooong_loooong", ["loooong", "loooong", "loooong"]),
        ]
        for token, correct in tokens:
            res = list(self.tp.process_token(token))
            self.assertEqual(res, correct)
        min_split_length = 3
        self.tp.min_split_length = min_split_length
        for token, correct in tokens:
            res = list(self.tp.process_token(token))
            self.assertEqual(res, [c for c in correct if len(c) >= min_split_length])

    def test_split(self):
        self.assertEqual(list(self.tp.split("set for")), ["set", "for"])
        self.assertEqual(list(self.tp.split("set /for.")), ["set", "for"])
        self.assertEqual(list(self.tp.split("NeverHav")), ["never", "hav"])
        self.assertEqual(list(self.tp.split("PrintAll")), ["print", "all"])
        self.assertEqual(list(self.tp.split("PrintAllExcept")), ["print", "all", "except"])
        self.assertEqual(
            list(self.tp.split("print really long line")),
            # 'longli' is an expected artifact due to edge effects
            ["print", "really", "long", "longli"])
        self.assertEqual(list(self.tp.split("set /for. *&PrintAll")),
                         ["set", "for", "print", "all"])
        self.assertEqual(list(self.tp.split("JumpDown not Here")),
                         ["jump", "down", "not", "here"])
        self.assertEqual(list(self.tp.split("a b c d")), ["a", "b", "c", "d"])
        self.assertEqual(list(self.tp.split("a b long c d")),
                         ["a", "b", "long", "blong", "longc", "d"])
        self.assertEqual(list(self.tp.split("AbCd")), ["ab", "cd"])

    def test_split_single_shot(self):
        self.tp._single_shot = True
        self.tp.min_split_length = 1
        self.assertEqual(
            list(self.tp.split("print really long line")),
            # 'li' is an expected artifact due to edge effects
            ["print", "really", "long", "li"])
        self.assertEqual(list(self.tp.split("a b c d")), ["a", "b", "c", "d"])
        self.assertEqual(list(self.tp.split("a b long c d")), ["a", "b", "long", "c", "d"])
        self.assertEqual(list(self.tp.split("AbCd")), ["ab", "cd"])

    def test_stem(self):
        self.assertEqual(self.tp.stem("lol"), "lol")
        self.assertEqual(self.tp.stem("apple"), "appl")
        self.assertEqual(self.tp.stem("orange"), "orang")
        self.assertEqual(self.tp.stem("embedding"), "embed")
        self.assertEqual(self.tp.stem("Alfred"), "Alfred")
        self.assertEqual(self.tp.stem("Pluto"), "Pluto")

    def test_pickle(self):
        tp = pickle.loads(pickle.dumps(self.tp))
        self.assertEqual(tp.stem("embedding"), "embed")
class IdTyposAnalyzer(Analyzer):
    """
    Identifier typos analyzer.
    """

    log = logging.getLogger("IdTyposAnalyzer")
    model_type = None
    version = 1
    description = "Corrector of typos in source code identifiers."
    corrector_manager = TyposCorrectorManager()

    DEFAULT_LINE_LENGTH_LIMIT = 500
    DEFAULT_N_CANDIDATES = 3
    DEFAULT_CONFIDENCE_THRESHOLD = 0.1
    INDEX_COLUMN = "index"

    def __init__(self, model: AnalyzerModel, url: str, config: Mapping[str, Any]):
        """
        Initialize a new instance of IdTyposAnalyzer.

        :param model: The instance of the model loaded from the repository or freshly trained.
        :param url: The analyzed project's Git remote.
        :param config: Configuration of the analyzer of unspecified structure.
        """
        super().__init__(model, url, config)
        self.model = self.corrector_manager.get(config.get("model"))
        self.n_candidates = config.get("n_candidates", self.DEFAULT_N_CANDIDATES)
        self.confidence_threshold = config.get(
            "confidence_threshold", self.DEFAULT_CONFIDENCE_THRESHOLD)
        self.parser = TokenParser(stem_threshold=40, single_shot=True)

    @with_changed_uasts_and_contents
    def analyze(self, ptr_from: ReferencePointer, ptr_to: ReferencePointer,
                data_service: DataService, **data) -> [Comment]:
        """
        Return the list of `Comment`-s - found typo corrections.

        :param ptr_from: The Git revision of the fork point. Exists in both the original and \
                         the forked repositories.
        :param ptr_to: The Git revision to analyze. Exists only in the forked repository.
        :param data_service: The channel to the data service in Lookout server to query for \
                             UASTs, file contents, etc.
        :param data: Extra data passed into the method. Used by the decorators to simplify \
                     the data retrieval.
        :return: List of found review suggestions. Refer to \
                 lookout/core/server/sdk/service_analyzer.proto.
        """
        log = self.log
        comments = []
        changes = list(data["changes"])
        base_files_by_lang = files_by_language(c.base for c in changes)
        head_files_by_lang = files_by_language(c.head for c in changes)
        line_length = self.config.get("line_length_limit", self.DEFAULT_LINE_LENGTH_LIMIT)
        for lang, head_files in head_files_by_lang.items():
            for file in filter_files(head_files, line_length, log):
                try:
                    prev_file = base_files_by_lang[lang][file.path]
                except KeyError:
                    lines = []
                    old_identifiers = set()
                else:
                    lines = find_new_lines(prev_file, file)
                    old_identifiers = {
                        node.token for node in uast2sequence(prev_file.uast)
                        if bblfsh.role_id("IDENTIFIER") in node.roles
                        and bblfsh.role_id("IMPORT") not in node.roles and node.token
                    }
                changed_nodes = extract_changed_nodes(file.uast, lines)
                new_identifiers = [
                    node for node in changed_nodes
                    if bblfsh.role_id("IDENTIFIER") in node.roles
                    and bblfsh.role_id("IMPORT") not in node.roles
                    and node.token and node.token not in old_identifiers
                ]
                if not new_identifiers:
                    continue
                suggestions = self.check_identifiers([n.token for n in new_identifiers])
                for index in suggestions.keys():
                    corrections = suggestions[index]
                    for token in corrections.keys():
                        comment = Comment()
                        comment.file = file.path
                        corrections_line = " " + ", ".join(
                            "%s (%d%%)" % (candidate[0], int(candidate[1] * 100))
                            for candidate in corrections[token])
                        comment.text = "Possible typo in \"%s\". Suggestions:" % \
                            new_identifiers[index].token + corrections_line
                        comment.line = new_identifiers[index].start_position.line
                        comment.confidence = int(corrections[token][0][1] * 100)
                        comments.append(comment)
        return comments

    @classmethod
    @with_uasts_and_contents
    def train(cls, ptr: ReferencePointer, config: dict, data_service: DataService,
              **data) -> AnalyzerModel:
        """
        Generate a new model on top of the specified source code.

        :param ptr: Git repository state pointer.
        :param config: Configuration of the training of unspecified structure.
        :param data_service: The channel to the data service in Lookout server to query for \
                             UASTs, file contents, etc.
        :param data: Extra data passed into the method. Used by the decorators to simplify \
                     the data retrieval.
        :return: Instance of `AnalyzerModel` (`model_type`, to be precise).
        """
        return DummyAnalyzerModel()

    def check_identifiers(self, identifiers: List[str],
                          ) -> Dict[int, Dict[str, List[Tuple[str, float]]]]:
        """
        Check tokens from identifiers for typos.

        :param identifiers: List of identifiers to check.
        :return: Dictionary of corrections grouped by ids of corresponding identifier \
                 in 'identifiers' and typoed tokens which have correction suggestions.
        """
        df = pandas.DataFrame(columns=[self.INDEX_COLUMN, SPLIT_COLUMN])
        df[self.INDEX_COLUMN] = range(len(identifiers))
        df[SPLIT_COLUMN] = [" ".join(self.parser.split(i)) for i in identifiers]
        df = flatten_data(df, new_column_name=TYPO_COLUMN)
        suggestions = self.model.suggest(df, n_candidates=self.n_candidates, return_all=False)
        suggestions = self.filter_suggestions(df, suggestions)
        grouped_suggestions = defaultdict(dict)
        for index, row in df.iterrows():
            if index in suggestions.keys():
                grouped_suggestions[row[self.INDEX_COLUMN]][row[TYPO_COLUMN]] = \
                    suggestions[index]
        return grouped_suggestions

    def filter_suggestions(self, test_df: pandas.DataFrame,
                           suggestions: Dict[int, List[Tuple[str, float]]],
                           ) -> Dict[int, List[Tuple[str, float]]]:
        """
        Filter suggestions based on the repo specifics and the confidence threshold.

        :param test_df: DataFrame with info about the tested tokens.
        :param suggestions: Dictionary of correction suggestions grouped by \
                            typoed token index in test_df.
        :return: Dictionary of filtered suggestions grouped by typoed token index in test_df.
        """
        filtered_suggestions = {}
        tokens = test_df.typo
        for index, candidates in suggestions.items():
            filtered_candidates = []
            for candidate in candidates:
                if candidate[0] == tokens[index] or candidate[1] < self.confidence_threshold:
                    break
                filtered_candidates.append(candidate)
            if filtered_candidates:
                filtered_suggestions[index] = filtered_candidates
        return filtered_suggestions
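# Illustrative sketch (not part of the analyzer): the shape of what check_identifiers returns
# and how analyze consumes it. The identifiers and the suggestion values below are made up;
# a real call needs a typos corrector model obtained through TyposCorrectorManager, which is
# not constructed here.
if __name__ == "__main__":
    identifiers = ["recieveMessage", "sendMessage"]
    # {identifier index: {typoed token: [(candidate, confidence), ...]}}
    grouped_suggestions = {0: {"recieve": [("receive", 0.94), ("retrieve", 0.03)]}}
    for index, corrections in grouped_suggestions.items():
        for token, candidates in corrections.items():
            line = " " + ", ".join("%s (%d%%)" % (c, int(p * 100)) for c, p in candidates)
            print("Possible typo in \"%s\". Suggestions:%s" % (identifiers[index], line))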
class TokenParserTests(unittest.TestCase):
    def setUp(self):
        self.tp = TokenParser(stem_threshold=4, max_token_length=20)

    def test_process_token(self):
        _max_token_length = self.tp.max_token_length
        self.tp.max_token_length = 100
        tokens = [
            ("sourced.ml.algorithms.uast_ids_to_bag",
             ["sourc", "sourcedml", "algorithm", "mlalgorithm", "uast", "ids", "idsto",
              "bag", "tobag"]),
            ("WORSTnameYOUcanIMAGINE", ["worst", "name", "you", "can", "imagin"]),
            # Another bad example. Parser failed to parse it correctly
            ("SmallIdsToFoOo", ["small", "ids", "idsto", "fo", "oo"]),
            ("SmallIdFooo", ["small", "smallid", "fooo", "idfooo"]),
            ("ONE_M0re_.__badId.example",
             ["one", "onem", "re", "bad", "rebad", "badid", "exampl", "idexampl"]),
            ("never_use_Such__varsableNames", ["never", "use", "such", "varsabl", "name"]),
            ("a.b.c.d", ["a", "b", "c", "d"]),
            ("A.b.Cd.E", ["a", "b", "cd", "e"]),
            ("looong_sh_loooong_sh",
             ["looong", "looongsh", "loooong", "shloooong", "loooongsh"]),
            ("sh_sh_sh_sh", ["sh", "sh", "sh", "sh"]),
            ("loooong_loooong_loooong", ["loooong", "loooong", "loooong"]),
        ]
        for token, correct in tokens:
            res = list(self.tp.process_token(token))
            self.assertEqual(res, correct)
        self.tp.max_token_length = _max_token_length

    def test_split(self):
        self.assertEqual(list(self.tp.split("set for")), ["set", "for"])
        self.assertEqual(list(self.tp.split("set /for.")), ["set", "for"])
        self.assertEqual(list(self.tp.split("NeverHav")), ["never", "hav"])
        self.assertEqual(list(self.tp.split("PrintAll")), ["print", "all"])
        self.assertEqual(list(self.tp.split("PrintAllExcept")), ["print", "all", "except"])
        self.assertEqual(
            list(self.tp.split("print really long line")),
            # 'longli' is expected artifact due to edge effects
            ["print", "really", "long", "longli"])
        self.assertEqual(list(self.tp.split("set /for. *&PrintAll")),
                         ["set", "for", "print", "all"])
        self.assertEqual(list(self.tp.split("JumpDown not Here")),
                         ["jump", "down", "not", "here"])

    def test_stem(self):
        self.assertEqual(self.tp.stem("lol"), "lol")
        self.assertEqual(self.tp.stem("apple"), "appl")
        self.assertEqual(self.tp.stem("orange"), "orang")
        self.assertEqual(self.tp.stem("embedding"), "embed")
        self.assertEqual(self.tp.stem("Alfred"), "Alfred")
        self.assertEqual(self.tp.stem("Pluto"), "Pluto")

    def test_pickle(self):
        tp = pickle.loads(pickle.dumps(self.tp))
        self.assertEqual(tp.stem("embedding"), "embed")
class FunctionNameAnalyzer(Analyzer):
    log = logging.getLogger("FunctionNameAnalyzer")
    model_type = FunctionNameModel
    version = "1"
    description = "Analyzer that suggests function names."

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.token_parser = TokenParser(min_split_length=1, single_shot=True)

    @with_changed_uasts_and_contents
    def analyze(self, ptr_from: ReferencePointer, ptr_to: ReferencePointer,
                data_request_stub: DataStub, **data) -> [Comment]:
        comments = []
        changes = list(data["changes"])
        base_files = files_by_language(c.base for c in changes)
        head_files = files_by_language(c.head for c in changes)
        for lang, lang_head_files in head_files.items():
            if lang.lower() != "java":
                continue
            self.log.info("Working on %d java files", len(lang_head_files))
            for i, (path, file) in enumerate(lang_head_files.items(), start=1):
                try:
                    prev_file = base_files[lang][path]
                except KeyError:
                    lines = None
                else:
                    lines = find_new_lines(prev_file, file)
                names_file, tokens_file, line_numbers = self._extract_features(file, lines)
                for prediction, target, score, line_number, type_hint in self.translate(
                        tokens_file, names_file, line_numbers):
                    comment = Comment()
                    comment.line = line_number
                    comment.file = path
                    comment.confidence = int(round(score * 100))
                    if type_hint == FunctionNameAnalyzer.TranslationTypes.LESS_DETAILED:
                        comment.text = "Consider a more generic name: %s instead of %s" % (
                            prediction, target)
                    else:
                        comment.text = "Consider a more specific name: %s instead of %s" % (
                            prediction, target)
                    comments.append(comment)
                self.log.info("Processed %d files", i)
        return comments

    @classmethod
    @with_uasts_and_contents
    def train(cls, ptr: ReferencePointer, config: Dict[str, Any], data_request_stub: DataStub,
              **data) -> FunctionNameModel:
        """
        Dummy train.

        :param ptr: Git repository state pointer.
        :param config: Configuration dict.
        :param data: Contains "files" - the list of files in the pointed state.
        :param data_request_stub: Connection to the Lookout data retrieval service, not used.
        :return: FunctionNameModel dummy model.
        """
        return FunctionNameModel().construct(cls, ptr)

    def process_node(self, node, last_position):
        if IDENTIFIER in node.roles and node.token and FUNCTION not in node.roles:
            for x in self.token_parser(node.token):
                yield x, last_position

    def process_uast(self, uast):
        stack = [(uast, [0, 0])]
        while stack:
            node, last_position = stack.pop()
            if node.start_position.line != 0:
                # Many nodes do not have a position, so it is a good heuristic to take
                # the last node in the tree that has one.
                last_position[0] = node.start_position.line
                last_position[1] = 0
            if node.start_position.col != 0:
                last_position[1] = node.start_position.col
            yield from self.process_node(node, last_position)
            stack.extend([(child, list(last_position)) for child in node.children])

    def extract_functions_from_uast(self, uast: bblfsh.Node):
        for node in uast2sequence(uast):
            if node.internal_type != "MethodDeclaration":
                continue
            for subnode in node.children:
                if FUNCTION not in subnode.roles and NAME not in subnode.roles:
                    continue
                name = subnode.token
                break
            tokens = list(self.process_uast(node))
            if len(tokens) < 5:
                continue
            yield (name, node.start_position.line, node.end_position.line,
                   [token for token, pos in sorted(tokens, key=lambda x: x[1])])

    def get_affected_functions(self, uast, lines: Optional[Sequence[int]]):
        functions_info = list(self.extract_functions_from_uast(uast))
        i = 0
        res = []
        for line in sorted(lines):
            while i < len(functions_info):
                if functions_info[i][1] <= line <= functions_info[i][2]:
                    res.append(functions_info[i])
                    i += 1
                    break
                elif line < functions_info[i][1]:
                    break
                elif line > functions_info[i][2]:
                    i += 1
        return res

    @staticmethod
    def to_nmt_files(functions_info):
        func_start = []
        with tempfile.NamedTemporaryFile(delete=False, mode="w") as func_names:
            with tempfile.NamedTemporaryFile(delete=False, mode="w") as func_tokens:
                for name, start_line, end_line, tokens in functions_info:
                    func_names.write(" ".join(list(name)) + "\n")
                    func_tokens.write(" ".join(tokens) + "\n")
                    func_start.append(start_line)
        return func_names.name, func_tokens.name, func_start

    def _extract_features(self, file, lines: Optional[Sequence[int]]):
        if file.language.lower() != "java":
            raise ValueError("Only the Java language is supported for now")
        if lines:
            affected_functions = self.get_affected_functions(file.uast, lines)
        else:
            # all functions are affected because the file is new
            affected_functions = self.extract_functions_from_uast(file.uast)
        return self.to_nmt_files(affected_functions)

    class TranslationTypes(Enum):
        NOOP = 0
        MORE_DETAILED = 1
        LESS_DETAILED = 2
        CLASS_TO_FUNCTION = 3
        FUNCTION_TO_CLASS = 4
        CLASS_TO_CLASS = 5
        OTHER = 6

    def classify_translation(self, prediction, target):
        split_prediction = set(self.token_parser.split(prediction))
        split_target = set(self.token_parser.split(target))
        if prediction[0].isupper():
            if target[0].isupper():
                return FunctionNameAnalyzer.TranslationTypes.CLASS_TO_CLASS
            return FunctionNameAnalyzer.TranslationTypes.FUNCTION_TO_CLASS
        elif target[0].isupper():
            return FunctionNameAnalyzer.TranslationTypes.CLASS_TO_FUNCTION
        if split_prediction == split_target:
            return FunctionNameAnalyzer.TranslationTypes.NOOP
        elif split_prediction > split_target:
            return FunctionNameAnalyzer.TranslationTypes.MORE_DETAILED
        elif split_prediction < split_target:
            return FunctionNameAnalyzer.TranslationTypes.LESS_DETAILED
        return FunctionNameAnalyzer.TranslationTypes.OTHER

    def translate(self, source_file, target_file, line_numbers):
        model = str(Path(__file__).parent.parent / "models" / "model.pt")
        with open(target_file) as fh:
            targets = ["".join(line.strip().split(" ")) for line in fh.readlines()]
        command = "translate.py -model %s -src %s -tgt %s" % (model, source_file, target_file)
        with patch("sys.argv", command.split(" ")):
            scores, gold_scores, translations = onmt.infer.main()
        for [score], gold_score, [translation], target, line_number \
                in zip(scores, gold_scores, translations, targets, line_numbers):
            prediction = "".join(translation.split(" "))
            gold_score = gold_score / len(target)
            pred_score = score / len(prediction)
            score = 1 / (1 + math.exp(-pred_score - gold_score))
            hint_type = self.classify_translation(prediction, target)
            if hint_type in [FunctionNameAnalyzer.TranslationTypes.MORE_DETAILED,
                             FunctionNameAnalyzer.TranslationTypes.LESS_DETAILED]:
                yield prediction, target, score, line_number, hint_type
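# Illustrative sketch (not part of the analyzer): the subtoken-set comparison behind
# classify_translation. The function names below are made up; it assumes that
# TokenParser(min_split_length=1, single_shot=True).split lowercases and splits on camelCase
# boundaries, e.g. "readFileData" -> ["read", "file", "data"].
if __name__ == "__main__":
    parser = TokenParser(min_split_length=1, single_shot=True)
    prediction_tokens = set(parser.split("readFileData"))  # {"read", "file", "data"} (expected)
    target_tokens = set(parser.split("readFile"))          # {"read", "file"} (expected)
    # A strict superset of the target's subtokens maps to MORE_DETAILED (suggest a more
    # specific name); a strict subset maps to LESS_DETAILED (suggest a more generic name).
    print(prediction_tokens > target_tokens)  # expected: True -> MORE_DETAILED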