def test_revert_indentation_change(self):
    """Check `CodeGenerator.revert_indentation_change` on a table of label sequences.

    Each case is (vnode value, label names, expected), where expected is either the
    returned string or the exception class that must be raised.
    """
    cases = [
        ("\n ",
         (cls.CLS_NEWLINE, cls.CLS_SPACE_INC, cls.CLS_SPACE_INC),
         "\n "),
        ("\n ",
         (cls.CLS_NEWLINE, cls.CLS_SPACE_DEC, cls.CLS_SPACE_DEC),
         "\n "),
        ("\n\t ",
         (cls.CLS_NEWLINE, cls.CLS_TAB_INC, cls.CLS_SPACE_INC),
         "\n"),
        ("\n ",
         (cls.CLS_NEWLINE, cls.CLS_TAB_INC, cls.CLS_TAB_INC),
         InapplicableIndentation),
        (" ",
         (cls.CLS_SPACE, cls.CLS_SPACE_INC, cls.CLS_SPACE_INC),
         ValueError),
    ]
    for value, label_names, expected in cases:
        start = Position(0, 1, 1)
        end = Position(len(value), 1, len(value) + 1)
        labels = tuple(cls.CLASS_INDEX[name] for name in label_names)
        vnode = VirtualNode(value, start, end, y=labels)
        if not isinstance(expected, str):
            # Anything that is not a string is the exception class to expect.
            with self.assertRaises(expected):
                CodeGenerator.revert_indentation_change(vnode)
            continue
        self.assertEqual(CodeGenerator.revert_indentation_change(vnode), expected)
def _test(self):
    """Apply the predicted labels by index and compare the code generated with "local" indentation."""
    labels = deepcopy(self.y)
    for index, label in zip(y_indexes, y_pred):
        labels[index] = label
    generator = CodeGenerator(self.feature_extractor)
    # Every labeled sample gets its own "winning rule" index.
    winners = list(range(len(self.vnodes_y)))
    pred_vnodes = generator.apply_predicted_y(
        self.vnodes, self.vnodes_y, winners, FakeRules(labels))
    self.assertEqual(generator.generate(pred_vnodes, "local"), result_local)
def _generate_token_fixes(
        self, file: File, fe: FeatureExtractor, feature_extractor_output,
        bblfsh_stub: "bblfsh.aliases.ProtocolServiceStub", rules: Rules,
) -> Tuple[List[LineFix], List[VirtualNode], numpy.ndarray, numpy.ndarray]:
    """
    Predict labels for one file and turn the changed lines into `LineFix`-es.

    :param file: File being analyzed; its path is used to build the files mapping for the \
                 UAST check.
    :param fe: Feature extractor that produced `feature_extractor_output`.
    :param feature_extractor_output: Tuple `(X, y, (vnodes_y, vnodes, vnode_parents, \
                                     node_parents))`.
    :param bblfsh_stub: Babelfish GRPC service stub, passed to the UAST check.
    :param rules: Rules used to predict the labels.
    :return: Tuple of the line fixes, the vnodes with predictions applied, the pure \
             predictions (filtered when the UAST check is enabled) and the labels.
    """
    X, y, (vnodes_y, vnodes, vnode_parents, node_parents) = feature_extractor_output
    prediction = rules.predict(X=X, vnodes_y=vnodes_y, vnodes=vnodes, feature_extractor=fe)
    y_pred_pure, rule_winners, new_rules, grouped_quote_predictions = prediction
    y_pred = rules.fill_missing_predictions(y_pred_pure, y)
    if self.config["uast_break_check"]:
        filtered = filter_uast_breaking_preds(
            y=y, y_pred=y_pred, vnodes_y=vnodes_y, vnodes=vnodes,
            files={file.path: file}, feature_extractor=fe, stub=bblfsh_stub,
            vnode_parents=vnode_parents, node_parents=node_parents,
            rule_winners=rule_winners,
            grouped_quote_predictions=grouped_quote_predictions)
        y, y_pred, vnodes_y, rule_winners, safe_preds = filtered
        # Keep the pure predictions aligned with the filtered samples.
        y_pred_pure = y_pred_pure[safe_preds]
    assert len(y) == len(y_pred)
    assert len(y) == len(rule_winners)
    code_generator = CodeGenerator(fe, skip_errors=True)
    new_vnodes = code_generator.apply_predicted_y(vnodes, vnodes_y, rule_winners, new_rules)
    token_fixes = []
    grouped_lines = self._group_line_nodes(y, y_pred, vnodes_y, new_vnodes, rule_winners)
    for line_number, line in grouped_lines:
        line_ys, line_ys_pred, line_vnodes_y, new_line_vnodes, line_winners = line
        generated = code_generator.generate(new_line_vnodes, "local")
        # Drop the leading newlines and keep only the first generated line.
        new_code_line = generated.lstrip("\n").splitlines()[0]
        confidence = self._get_comment_confidence(
            line_ys, line_ys_pred, line_winners, new_rules)
        fixed_vnodes = []
        for vnode in new_line_vnodes:
            if hasattr(vnode, "y_old") and vnode.y_old != vnode.y:
                fixed_vnodes.append(vnode)
        fix = LineFix(
            line_number=line_number,       # line number for the comment
            suggested_code=new_code_line,  # code line suggested by our model
            fixed_vnodes=fixed_vnodes,     # VirtualNode-s with changed y
            confidence=confidence,         # overall confidence in the prediction, 0-100
        )
        token_fixes.append(fix)
    return token_fixes, new_vnodes, y_pred_pure, y
def __init__(self, feature_extractor: FeatureExtractor, debug: bool=False):
    """
    Construct a UASTStabilityChecker.

    :param feature_extractor: Feature extraction class that was used to generate data for \
                              the check.
    :param debug: Whether to log the code diff of unsafe predictions at debug level.
    """
    self._debug = debug
    self._feature_extractor = feature_extractor
    # The code generator is created with skip_errors=False.
    self._code_generator = CodeGenerator(feature_extractor, skip_errors=False)
    self._parsing_cache = {}  # type: Dict[int, Optional[Tuple[bblfsh.Node, int, int]]]
def _generate_token_fixes(
        self, file: File, fe: FeatureExtractor, feature_extractor_output,
        bblfsh_stub: "bblfsh.aliases.ProtocolServiceStub", rules: Rules,
) -> Tuple[List[LineFix], List[VirtualNode], numpy.ndarray, numpy.ndarray]:
    """
    Predict labels for one file and turn the changed lines into `LineFix`-es.

    :param file: File being analyzed; its language selects the analysis config and its \
                 path is used in log messages.
    :param fe: Feature extractor that produced `feature_extractor_output`.
    :param feature_extractor_output: Tuple `(X, y, (vnodes_y, vnodes, vnode_parents, \
                                     node_parents))`.
    :param bblfsh_stub: Babelfish GRPC service stub, passed to the UAST stability check.
    :param rules: Rules used to predict the labels.
    :return: Tuple of the line fixes, the vnodes with predictions applied, the pure \
             predictions (filtered when the UAST check is enabled) and the labels.
    """
    X, y, (vnodes_y, vnodes, vnode_parents, node_parents) = feature_extractor_output
    prediction = rules.predict(X=X, vnodes_y=vnodes_y, vnodes=vnodes, feature_extractor=fe)
    y_pred_pure, rule_winners, new_rules, grouped_quote_predictions = prediction
    y_pred = rules.fill_missing_predictions(y_pred_pure, y)
    if self.analyze_config[file.language.lower()]["uast_break_check"]:
        checker = UASTStabilityChecker(fe)
        checked = checker.check(
            y=y, y_pred=y_pred, vnodes_y=vnodes_y, vnodes=vnodes, files=[file],
            stub=bblfsh_stub, vnode_parents=vnode_parents, node_parents=node_parents,
            rule_winners=rule_winners,
            grouped_quote_predictions=grouped_quote_predictions)
        y, y_pred, vnodes_y, rule_winners, safe_preds = checked
        # Keep the pure predictions aligned with the filtered samples.
        y_pred_pure = y_pred_pure[safe_preds]
    assert len(y) == len(y_pred)
    assert len(y) == len(rule_winners)
    code_generator = CodeGenerator(fe, skip_errors=True)
    new_vnodes = code_generator.apply_predicted_y(vnodes, vnodes_y, rule_winners, new_rules)
    token_fixes = []
    newline_index = CLASS_INDEX[CLS_NEWLINE]
    grouped_lines = self._group_line_nodes(y, y_pred, vnodes_y, new_vnodes, rule_winners)
    for line_number, line in grouped_lines:
        line_ys, line_ys_pred, line_vnodes_y, new_line_vnodes, line_winners = line
        try:
            new_code_line = code_generator.generate_new_line(new_line_vnodes)
        except Exception:
            # Best effort: report the failure and publish the fix without a suggestion.
            vnodes_dump = "\n".join(
                "%s, y_old=%s" % (repr(vn), getattr(vn, "y_old", "N/A"))
                for vn in new_line_vnodes)
            self._log.exception(
                "Failed to generate new line suggestion for line %d in %s. line vnodes:\n%s",
                line_number, file.path, vnodes_dump)
            new_code_line = None
        if new_line_vnodes:
            head = new_line_vnodes[0]
            if hasattr(head, "y_old") and newline_index in head.y_old:
                if head.y.count(newline_index) < head.y_old.count(newline_index):
                    # Some lines were removed. This means that several original lines
                    # should be modified. GitHub Suggested Change feature cannot handle
                    # such cases right now. To not confuse the user we do not provide any
                    # code suggestion.
                    new_code_line = None
        confidence = self._get_comment_confidence(
            line_ys, line_ys_pred, line_winners, new_rules)
        fixed_vnodes = []
        for vnode in new_line_vnodes:
            if hasattr(vnode, "y_old") and vnode.y_old != vnode.y:
                fixed_vnodes.append(vnode)
        fix = LineFix(
            line_number=line_number,       # line number for the comment
            suggested_code=new_code_line,  # code line suggested by our model
            fixed_vnodes=fixed_vnodes,     # VirtualNode-s with changed y
            confidence=confidence,         # overall confidence in the prediction, 0-100
        )
        token_fixes.append(fix)
    return token_fixes, new_vnodes, y_pred_pure, y
def _test(self):
    """Apply the predicted labels by vnode start offset and compare the generated file.

    Fails explicitly when an offset does not match the start of any labeled vnode instead
    of silently overwriting the label of the last vnode.
    """
    y_cur = deepcopy(self.y)
    for offset, yi in zip(offsets, y_pred):
        for i, vnode in enumerate(vnodes_y):
            if offset == vnode.start.offset:
                break
        else:
            # Without this branch `i` would point to the last vnode (or be undefined for
            # an empty sequence) and a wrong label would be changed.
            self.fail("No labeled vnode starts at offset %s" % offset)
        y_cur[i] = yi
    code_generator = CodeGenerator(self.feature_extractor)
    pred_vnodes = code_generator.apply_predicted_y(
        self.vnodes, self.vnodes_y, list(range(len(self.vnodes_y))), FakeRules(y_cur))
    generated_file = code_generator.generate(pred_vnodes)
    self.assertEqual(generated_file, result)
def test_vnode_positions(self):
    """Regenerate every grouped line and compare it with the same line of the original code."""
    generator = CodeGenerator(feature_extractor=self.extractor)
    original_lines = self.code.decode("utf-8", "replace").splitlines()
    original_lines.append("\r\n")
    all_equal = True
    # `self.y - 1` serves as the predictions so that they differ from the labels.
    grouped = FormatAnalyzer._group_line_nodes(
        self.y, self.y - 1, self.vnodes_y, self.vnodes, repeat(0))
    for line_number, line in grouped:
        line_ys, line_ys_pred, line_vnodes_y, new_line_vnodes, line_winners = line
        restored = generator.generate_new_line(new_line_vnodes)
        original = original_lines[line_number - 1]
        if original == restored:
            continue
        # Report every mismatch before failing to ease debugging.
        print("Lines %d are different" % line_number, file=sys.stderr)
        print(repr(original), file=sys.stderr)
        print(repr(restored), file=sys.stderr)
        print()
        all_equal = False
    self.assertTrue(all_equal, "Original and restored lines are different")
class UASTStabilityChecker:
    """
    Check if predictions change the UAST structure of the processed files.

    See `check()` for more info.
    """

    _log = getLogger("UASTStabilityChecker")
    _check_return_type = Tuple[numpy.ndarray, numpy.ndarray, Sequence[VirtualNode],
                               numpy.ndarray, numpy.ndarray]

    def __init__(self, feature_extractor: FeatureExtractor, debug: bool=False):
        """
        Construct a UASTStabilityChecker.

        :param feature_extractor: Feature extraction class that was used to generate data for \
                                  the check.
        :param debug: Logs code diff for unsafe predictions with debug level.
        """
        self._feature_extractor = feature_extractor
        self._code_generator = CodeGenerator(self._feature_extractor, skip_errors=False)
        self._parsing_cache = {}  # type: Dict[int, Optional[Tuple[bblfsh.Node, int, int]]]
        self._debug = debug

    def _parse_code(self, parent: bblfsh.Node, content: str,
                    stub: "bblfsh.aliases.ProtocolServiceStub",
                    node_parents: Mapping[int, bblfsh.Node], path: str,
                    ) -> Optional[Tuple[bblfsh.Node, int, int]]:
        """
        Find a parent node that Babelfish can parse and parse it.

        Iterates over the parents of the given node until one is parsable and returns the
        parsed UAST or None if it reaches the root without finding a parsable parent. The
        cache will be used to avoid recomputations for parents that have already been
        considered.

        :param parent: First node to try to parse. Will go up in the tree if it fails.
        :param content: Content of the file.
        :param stub: Babelfish GRPC service stub.
        :param node_parents: Parents mapping of the input UASTs.
        :param path: Path of the file being parsed.
        :return: tuple of the parsed UAST and the corresponding starting and ending offsets. \
                 None if Babelfish failed to parse the whole file.
        """
        descendants = []
        current_ancestor = parent
        while current_ancestor is not None:
            if id(current_ancestor) in self._parsing_cache:
                result = self._parsing_cache[id(current_ancestor)]
                break
            descendants.append(current_ancestor)
            start, end = (current_ancestor.start_position.offset,
                          current_ancestor.end_position.offset)
            uast, errors = parse_uast(stub, content[start:end], filename="",
                                      language=self._feature_extractor.language)
            if not errors:
                result = uast, start, end
                break
            current_ancestor = node_parents.get(id(current_ancestor), None)
        else:
            # The loop ended without `break`: no ancestor could be parsed.
            result = None
            self._log.warning("skipped file %s, due to errors in parsing the whole content", path)
        # Every node visited on the way up shares the same outcome.
        for descendant in descendants:
            self._parsing_cache[id(descendant)] = result
        return result

    def _check_file(
            self, y: numpy.ndarray, y_pred: numpy.ndarray, vnodes_y: Sequence[VirtualNode],
            vnodes: Sequence[VirtualNode], file: File, stub: "bblfsh.aliases.ProtocolServiceStub",
            vnode_parents: Mapping[int, bblfsh.Node], node_parents: Mapping[int, bblfsh.Node],
            rule_winners: numpy.ndarray, grouped_quote_predictions: QuotedNodeTripleMapping,
    ) -> _check_return_type:
        """
        Filter the predictions that change the UAST of a single file.

        The parameters have the same meaning as in `check()`, restricted to one `file`.

        :return: Tuple of filtered `y`, `y_pred`, `vnodes_y`, `rule_winners` and the integer \
                 indices of the samples that were kept.
        """
        # TODO(warenlg): Add current algorithm description.
        # TODO(vmarkovtsev): Apply ML to not parse all the parents.
        self._parsing_cache = {}
        unsafe_preds = []
        file_content = file.content.decode("utf-8", "replace")
        vnodes_i = 0
        # Only the samples with a prediction that differs from the label are checked.
        changes = numpy.where((y_pred != -1) & (y != y_pred))[0]
        start_offset_to_vnodes = {}
        end_offset_to_vnodes = {}
        for i, vnode in enumerate(vnodes):
            if vnode.start.offset not in start_offset_to_vnodes:
                # NOOP always included
                start_offset_to_vnodes[vnode.start.offset] = i
        for i, vnode in enumerate(vnodes[::-1]):
            if vnode.end.offset not in end_offset_to_vnodes:
                # NOOP always included that is why we have reverse order in this loop
                end_offset_to_vnodes[vnode.end.offset] = len(vnodes) - i
        for i in changes:
            vnode_y = vnodes_y[i]
            # Advance the cursor in `vnodes` until it points to the current labeled vnode.
            while vnode_y is not vnodes[vnodes_i]:
                vnodes_i += 1
                if vnodes_i >= len(vnodes):
                    raise AssertionError("vnodes_y and vnodes are not consistent.")
            if id(vnode_y) in grouped_quote_predictions:
                # quote types are special case
                group = grouped_quote_predictions[id(vnode_y)]
                if group is None:
                    # already handled with the previous vnode
                    continue
                vnode1, vnode2, vnode3 = group
                content_before = file_content[vnode1.start.offset:vnode3.end.offset]
                content_after = (self._feature_extractor.label_to_str(y_pred[i])
                                 + vnode2.value
                                 + self._feature_extractor.label_to_str(y_pred[i + 1]))
                parsed_before, errors = parse_uast(stub, content_before, filename="",
                                                   language=self._feature_extractor.language)
                if not errors:
                    parsed_after, errors = parse_uast(stub, content_after, filename="",
                                                      language=self._feature_extractor.language)
                    if not self.check_uasts_equivalent(parsed_before, parsed_after):
                        unsafe_preds.append(i)
                        unsafe_preds.append(i + 1)  # Second quote
                continue
            parsed_before = self._parse_code(vnode_parents[id(vnode_y)], file_content, stub,
                                             node_parents, vnode_y.path)
            if parsed_before is None:
                continue
            parent_before, start, end = parsed_before
            vnode_start_index = start_offset_to_vnodes[start]
            vnode_end_index = end_offset_to_vnodes[end]
            assert vnode_start_index <= vnodes_i < vnode_end_index
            try:
                content_after = self._code_generator.generate_one_change(
                    vnodes[vnode_start_index:vnode_end_index],
                    vnodes_i - vnode_start_index, y_pred[i])
            except CodeGenerationBaseError as e:
                self._log.debug("Code generator can't generate code: %s", repr(e.args))
                unsafe_preds.append(i)
                continue
            parent_after, errors_after = parse_uast(
                stub, content_after, filename="", language=self._feature_extractor.language)
            if errors_after:
                unsafe_preds.append(i)
                continue
            if not self.check_uasts_equivalent(parent_before, parent_after):
                if self._debug:
                    self._log.debug(
                        "Bad prediction\nfile:%s\nDiff:\n%s\n\n", vnode_y.path,
                        "\n".join(line for line in difflib.unified_diff(
                            file_content[start:end].splitlines(), content_after.splitlines(),
                            fromfile="original", tofile="suggested", lineterm="")))
                unsafe_preds.append(i)
        self._log.info("%d filtered out of %d with changes", len(unsafe_preds),
                       changes.shape[0])
        unsafe_preds = frozenset(unsafe_preds)
        # An explicit integer dtype is required: an empty list would otherwise produce a
        # float64 array, which numpy refuses to use as an index.
        safe_preds = numpy.array([i for i in range(len(y)) if i not in unsafe_preds],
                                 dtype=int)
        vnodes_y = [vn for i, vn in enumerate(vnodes_y) if i not in unsafe_preds]
        return y[safe_preds], y_pred[safe_preds], vnodes_y, rule_winners[safe_preds], safe_preds

    def check(
            self, y: numpy.ndarray, y_pred: numpy.ndarray, vnodes_y: Sequence[VirtualNode],
            vnodes: Sequence[VirtualNode], files: Sequence[File],
            stub: "bblfsh.aliases.ProtocolServiceStub",
            vnode_parents: Mapping[int, bblfsh.Node], node_parents: Mapping[int, bblfsh.Node],
            rule_winners: numpy.ndarray, grouped_quote_predictions: QuotedNodeTripleMapping,
    ) -> _check_return_type:
        """
        Filter the model's predictions that modify the UAST apart from changing Node positions.

        :param y: Numpy 1-dimensional array of labels.
        :param y_pred: Numpy 1-dimensional array of predicted labels by the model.
        :param vnodes_y: Sequence of the labeled `VirtualNode`-s corresponding to labeled samples.
        :param vnodes: Sequence of all the `VirtualNode`-s corresponding to the input.
        :param files: Sequence of File-s with content, uast and path.
        :param stub: Babelfish GRPC service stub.
        :param vnode_parents: `VirtualNode`-s' parents mapping as the LCA of the closest \
                              left and right babelfish nodes.
        :param node_parents: Parents mapping of the input UASTs.
        :param rule_winners: Numpy array with the indexes of the winning rules for each sample.
        :param grouped_quote_predictions: Quotes predictions (handled differently from the rest).
        :return: Tuple of filtered `y`, `y_pred`, `vnodes_y`, `rule_winners` and the indices \
                 of the predictions that are considered valid i.e. that are not breaking \
                 the UAST.
        """
        if len(files) == 1:
            return self._check_file(y, y_pred, vnodes_y, vnodes, files[0], stub, vnode_parents,
                                    node_parents, rule_winners, grouped_quote_predictions)
        if len(vnodes_y) == 0:
            # There are no labeled samples, hence nothing to filter.
            return y, y_pred, list(vnodes_y), rule_winners, numpy.array([], dtype=int)

        # There is more than one file in data and data splitting is required.
        # The logic of the next code is about splitting mostly.
        current_path = vnodes_y[0].path
        file_vnodes_indexes = {current_path: [0, None]}
        for i, vnode in enumerate(vnodes):
            if vnode.path != current_path:
                file_vnodes_indexes[current_path][1] = i
                file_vnodes_indexes[vnode.path] = [i, None]
                current_path = vnode.path
        file_vnodes_indexes[current_path][1] = len(vnodes)

        files_by_path = {file.path: file for file in files}
        # Indices in `vnodes_y` at which a new file starts. Using explicit [start, end)
        # ranges guarantees that the last sample and the last file are included.
        boundaries = [i for i in range(1, len(vnodes_y))
                      if vnodes_y[i].path != vnodes_y[i - 1].path]
        check_result = []
        for i_start, i_end in zip([0] + boundaries, boundaries + [len(vnodes_y)]):
            current_path = vnodes_y[i_start].path
            vnodes_start, vnodes_end = file_vnodes_indexes[current_path]
            file_result = list(self._check_file(
                y[i_start:i_end], y_pred[i_start:i_end], vnodes_y[i_start:i_end],
                vnodes[vnodes_start:vnodes_end], files_by_path[current_path], stub,
                vnode_parents, node_parents, rule_winners[i_start:i_end],
                grouped_quote_predictions))
            # Convert the per-file safe indices to the indices in the whole input.
            file_result[-1] = file_result[-1] + i_start
            check_result.append(file_result)
        res = []
        for res_i in zip(*check_result):
            if isinstance(res_i[0], list):
                res.append(list(chain(*res_i)))
            else:
                res.append(numpy.concatenate(res_i))
        return tuple(res)

    @staticmethod
    def check_uasts_equivalent(uast1: bblfsh.Node, uast2: bblfsh.Node) -> bool:
        """
        Check if 2 UAST nodes are identical regarding `roles`, `internal_type` and `token` of \
        their subtree members.

        :param uast1: The bblfsh.Node of the first UAST to compare.
        :param uast2: The bblfsh.Node of the second UAST to compare.
        :return: True if the 2 input UASTs are identical and False otherwise.
        """
        queue1 = [uast1]
        queue2 = [uast2]
        while queue1 or queue2:
            try:
                node1 = queue1.pop()
                node2 = queue2.pop()
            except IndexError:
                # One of the trees has more nodes than the other.
                return False
            for child1, child2 in zip(node1.children, node2.children):
                if (child1.roles != child2.roles or child1.internal_type != child2.internal_type
                        or child1.token != child2.token):
                    return False
            queue1.extend(node1.children)
            queue2.extend(node2.children)
        return True
def test_reproduction(self):
    """Generating code from unchanged vnodes must give the decoded file content for both indentation modes."""
    for indentation in ("local", "global"):
        generator = CodeGenerator(self.feature_extractor)
        reproduced = generator.generate(self.vnodes, indentation)
        original = self.file.content.decode("utf-8")
        self.assertEqual(reproduced, original)
def calc_metrics(bad_style_code: str, correct_style_code: str, fe: FeatureExtractor,
                 vnodes: Sequence[VirtualNode], url: str, commit: str) -> Dict[str, Any]:
    """
    Calculate metrics for model output.

    Algorithm description:
    1. For the given model predictions `y_pred` we generate a new file. Now we have 3 files
       to compare:
       1. `bad_style_code`. The file from head revision where style mistakes were applied.
          We inspect this file to find them.
       2. `correct_style_code`. The file from base revision. We use this file to train the
          repo format model. In the ideal case, we should be able to restore this file.
       3. `predicted_code`. The file we get as the format model output.
    2. We compare the files on a character level. To do so we have to align them first.
       `align3` function is used for that. There is an example::

           bad_style_code     = "import   abcd"
           correct_style_code = "import abcd"
           predicted_code     = "import  abcd,"
           align3(bad_style_code, correct_style_code, predicted_code) ->
               ("import   abcd␣",
                "import ␣␣abcd␣",
                "import  ␣abcd,")

    3. Now we are able to compare the sequences character by character.
       `calc_aligned_metrics` function is used for that. We can have 5 cases here.
       Let's consider them in the same example::

           ("import   abcd␣",  # aligned bad_style_code
            "import ␣␣abcd␣",  # aligned correct_style_code
            "import  ␣abcd,")  # aligned predicted_code
             ^      ^^    ^
             1      23    4

       1. All characters are equal. Everything is fine.
       2. Characters in bad style and predicted code are equal, but the character in the
          correct code is different. So, the style mistake is undetected.
       3. Characters in correct style and predicted code are equal, but the character in the
          bad style code is different. So, the style mistake is detected and correctly
          fixed.
       4. Characters in bad style and correct style code are equal, but the character in the
          predicted code is different. So, a new style mistake is introduced. We call this
          situation misdetection and want to avoid it as much as possible.
       5. All characters are different. There is no such case in the example, but this means
          that the style mistake is detected but wrongly fixed.

    Thus, as output we have 4 numbers:
    1. style mistake misdetection,
    2. undetected style mistake,
    3. detected style mistake with the wrong fix,
    4. detected style mistake with the correct fix.

    In scientific words:
    1. False positive.
    2 + 3. False negative. We have two types of false negatives. The first one is when the
           error was missed and there is no fix. The second one is when the error was found
           but wrongly fixed.
    4. True positive.

    :param bad_style_code: The file from head revision where style mistakes were applied.
    :param correct_style_code: The file from base revision. In ideal case, we should be able to \
                               restore it.
    :param fe: Feature extraction class that was used to generate corresponding data. Set a value \
               to None if no changes were introduced for `bad_style_code`.
    :param vnodes: Sequence of all the `VirtualNode`-s corresponding to the input code file. \
                   Should be ordered by position. New y values should be applied.
    :param url: Repository url if applicable. Useful for more informative warning messages.
    :param commit: Commit hash if applicable. Useful for more informative warning messages.
    :return: A dictionary with losses and predicted code.
    """
    generator = CodeGenerator(fe, skip_errors=True, url=url, commit=commit)
    predicted_code = generator.generate(vnodes)
    aligned = align3(bad_style_code, correct_style_code, predicted_code)
    misdetection, undetected, detected_wrong_fix, detected_correct_fix = \
        calc_aligned_metrics(*aligned)
    return {
        "misdetection": misdetection,
        "undetected": undetected,
        "detected_wrong_fix": detected_wrong_fix,
        "detected_correct_fix": detected_correct_fix,
        "predicted_file": predicted_code,
    }
def test_reproduction(self):
    """Generating code from unchanged vnodes must give the original file content."""
    reproduced = CodeGenerator(self.feature_extractor).generate(self.vnodes)
    self.assertEqual(reproduced, self.file.content)
def test_apply_new_indentation(self):
    """Check `CodeGenerator.apply_new_indentation` results, errors and its warning message.

    Each case is (vnode value, expected, old label names, new label names, last indentation),
    where expected is either the returned tuple or the exception class that must be raised.
    """
    cases = [
        ("\n ", ("\n", " "),
         (cls.CLS_NEWLINE, cls.CLS_SPACE_INC, cls.CLS_SPACE_INC),
         (cls.CLS_NEWLINE, ), ""),
        ("\n ", ("\n", " "),
         (cls.CLS_NEWLINE, cls.CLS_SPACE_DEC, cls.CLS_SPACE_DEC),
         (cls.CLS_NEWLINE, ), ""),
        ("\n\t ", ("\n", ""),
         (cls.CLS_NEWLINE, cls.CLS_TAB_INC, cls.CLS_SPACE_INC),
         (cls.CLS_NEWLINE, ), ""),
        ("\n ", InapplicableIndentation,
         (cls.CLS_NEWLINE, cls.CLS_TAB_INC, cls.CLS_TAB_INC),
         (cls.CLS_NEWLINE, ), ""),
        ("\n ", ValueError,
         (cls.CLS_NEWLINE, cls.CLS_SPACE, cls.CLS_SPACE_INC, cls.CLS_SPACE_INC),
         (cls.CLS_NEWLINE, ), ""),
        ("\n\t ", InapplicableIndentation,
         (cls.CLS_NEWLINE, cls.CLS_SPACE_INC, cls.CLS_SPACE_INC),
         (cls.CLS_NEWLINE, cls.CLS_SPACE_DEC), ""),
        ("\n\t ", ValueError,
         (cls.CLS_NEWLINE, cls.CLS_SPACE_DEC),
         (cls.CLS_NEWLINE, cls.CLS_SPACE_DEC, cls.CLS_SPACE, cls.CLS_SPACE_DEC), ""),
        ("\n\n ", ("\n", " "),
         (cls.CLS_NEWLINE, cls.CLS_NEWLINE, cls.CLS_SPACE_DEC),
         (cls.CLS_NEWLINE, cls.CLS_SPACE_DEC, cls.CLS_SPACE_DEC, cls.CLS_SPACE_DEC), ""),
        ("", ("\n", " "),
         (cls.CLS_NOOP, ),
         (cls.CLS_NEWLINE, ), " "),
        ("", ("\n\n", ""),
         (cls.CLS_NOOP, ),
         (cls.CLS_NEWLINE, cls.CLS_NEWLINE), ""),
    ]
    for value, expected, old_names, new_names, last_ident in cases:
        # The end position is derived from the number of new labels.
        end = Position(len(new_names), 1, len(new_names) + 1)
        vnode = VirtualNode(value, Position(0, 1, 1), end,
                            y=tuple(cls.CLASS_INDEX[name] for name in new_names))
        vnode.y_old = tuple(cls.CLASS_INDEX[name] for name in old_names)
        if not isinstance(expected, tuple):
            # Anything that is not a tuple is the exception class to expect.
            with self.assertRaises(expected):
                CodeGenerator.apply_new_indentation(vnode, last_ident)
            continue
        self.assertEqual(CodeGenerator.apply_new_indentation(vnode, last_ident), expected)

    captured = None

    def _capture_warning(*args):
        nonlocal captured
        captured = args[0]

    try:
        # Temporarily replace the logger method to record the warning text.
        original_warning = CodeGenerator._log.warning
        CodeGenerator._log.warning = _capture_warning
        new_names = (cls.CLS_NEWLINE, cls.CLS_SPACE_DEC, cls.CLS_SPACE_DEC,
                     cls.CLS_SPACE_DEC)
        vnode = VirtualNode("\n ", Position(0, 1, 1), Position(3, 1, 4),
                            y=tuple(cls.CLASS_INDEX[name] for name in new_names))
        old_names = (cls.CLS_NEWLINE, cls.CLS_SPACE_DEC)
        vnode.y_old = tuple(cls.CLASS_INDEX[name] for name in old_names)
        CodeGenerator.apply_new_indentation(vnode, "")
        prefix = "There is no indentation characters left to decrease for vnode"
        self.assertEqual(captured[:len(prefix)], prefix)
    finally:
        CodeGenerator._log.warning = original_warning
def test_generate_new_line(self):
    """Check the line suggestions produced by `CodeGenerator.generate_new_line`.

    For every named case the predicted labels are applied at the given vnode start offsets
    and the new lines generated for the changed line groups are compared with the expected
    ones. The test fails explicitly when an offset does not match the start of any labeled
    vnode instead of silently changing the label of the last vnode.
    """
    self.maxDiff = None
    expected_res = {
        "nothing changed": [],
        "remove new line in the end of 4th line": None,
        "indentation in the beginning":
            [" import { makeToast } from '../../common/app/Toasts/redux';"],
        "remove indentation in the 4th line till the end":
            [" return Object.keys(flash)", " }"],
        "new line between 6th and 7th regular code lines":
            ["\n return messages.map(message => ({"],
        "new line in the middle of the 7th code line with indentation increase":
            [" return messages\n .map(message => ({", " })"],
        "new line in the middle of the 7th code line with indentation decrease":
            [" return messages\n .map(message => ({", " })"],
        "new line in the middle of the 7th code line without indentation increase":
            [" return messages\n .map(message => ({"],
        "change quotes":
            ['import { makeToast } from "../../common/app/Toasts/redux";'],
        "remove indentation decrease 11th line":
            [" }));"],
        "change indentation decrease to indentation increase 11th line":
            [" }));"],
        "change indentation decrease to indentation increase 11th line but keep the rest":
            [" }));", "})"],
    }

    base = Path(__file__).parent
    # str() is needed for Python 3.5
    with lzma.open(str(base / "benchmark_small.js.xz"), mode="rt") as fin:
        contents = fin.read()
    with lzma.open(str(base / "benchmark_small.js.uast.xz")) as fin:
        uast = bblfsh.Node.FromString(fin.read())
    config = FormatAnalyzer._load_config(get_config())
    fe_config = config["train"]["javascript"]

    for case in expected_res:
        offsets, y_pred, _ = cases[case]
        feature_extractor = FeatureExtractor(
            language="javascript", label_composites=label_composites,
            **fe_config["feature_extractor"])
        file = UnicodeFile(content=contents, uast=uast, path="", language="")
        X, y, (vnodes_y, vnodes, vnode_parents, node_parents) = \
            feature_extractor.extract_features([file])
        y_cur = deepcopy(y)
        for offset, yi in zip(offsets, y_pred):
            for i, vnode in enumerate(vnodes_y):
                if offset == vnode.start.offset:
                    break
            else:
                # Without this branch `i` would point to the last vnode (or be undefined
                # for an empty sequence) and a wrong label would be changed.
                self.fail("%s: no labeled vnode starts at offset %s" % (case, offset))
            y_cur[i] = yi
        code_generator = CodeGenerator(feature_extractor)
        pred_vnodes = code_generator.apply_predicted_y(
            vnodes, vnodes_y, list(range(len(vnodes_y))), FakeRules(y_cur))
        res = []
        for gln in FormatAnalyzer._group_line_nodes(
                y, y_cur, vnodes_y, pred_vnodes, [1] * len(y)):
            line, (line_y, line_y_pred, line_vnodes_y, line_vnodes, line_rule_winners) = gln
            new_code_line = code_generator.generate_new_line(line_vnodes)
            res.append(new_code_line)
        if expected_res[case] is not None:
            # None means that we delete some lines. We do not handle this properly now.
            self.assertEqual(res, expected_res[case], case)