Example 1
 # `cls` below is assumed to be bound in the enclosing test scope to the module
 # that defines the CLS_* constants and CLASS_INDEX.
 def test_revert_indentation_change(self):
     cases = [
         ("\n    ", (cls.CLS_NEWLINE, cls.CLS_SPACE_INC, cls.CLS_SPACE_INC),
          "\n  "),
         ("\n    ", (cls.CLS_NEWLINE, cls.CLS_SPACE_DEC, cls.CLS_SPACE_DEC),
          "\n      "),
         ("\n\t ", (cls.CLS_NEWLINE, cls.CLS_TAB_INC, cls.CLS_SPACE_INC),
          "\n"),
         ("\n    ", (cls.CLS_NEWLINE, cls.CLS_TAB_INC, cls.CLS_TAB_INC),
          InapplicableIndentation),
         ("   ", (cls.CLS_SPACE, cls.CLS_SPACE_INC, cls.CLS_SPACE_INC),
          ValueError),
     ]
     for value, y, result in cases:
         vnode = VirtualNode(value,
                             Position(0, 1, 1),
                             Position(len(value), 1,
                                      len(value) + 1),
                             y=tuple(cls.CLASS_INDEX[i] for i in y))
         if isinstance(result, str):
             self.assertEqual(
                 CodeGenerator.revert_indentation_change(vnode), result)
         else:
             with self.assertRaises(result):
                 CodeGenerator.revert_indentation_change(vnode)
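
In the first case above, the two CLS_SPACE_INC labels record that the predicted y added two spaces of indentation, so reverting the change strips them and "\n    " becomes "\n  "; symmetrically, the CLS_SPACE_DEC case re-adds the removed spaces, yielding "\n      ".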
Example 2
 def _test(self):
     # `y_indexes`, `y_pred`, and `result_local` are free variables closed over
     # from the enclosing test factory (not shown in this snippet).
     y_cur = deepcopy(self.y)
     for i, yi in zip(y_indexes, y_pred):
         y_cur[i] = yi
     code_generator = CodeGenerator(self.feature_extractor)
     pred_vnodes = code_generator.apply_predicted_y(
         self.vnodes, self.vnodes_y, list(range(len(self.vnodes_y))),
         FakeRules(y_cur))
     generated_file = code_generator.generate(pred_vnodes, "local")
     self.assertEqual(generated_file, result_local)
Example 3
 def _generate_token_fixes(
     self,
     file: File,
     fe: FeatureExtractor,
     feature_extractor_output,
     bblfsh_stub: "bblfsh.aliases.ProtocolServiceStub",
     rules: Rules,
 ) -> Tuple[List[LineFix], List[VirtualNode], numpy.ndarray, numpy.ndarray]:
     X, y, (vnodes_y, vnodes, vnode_parents,
            node_parents) = feature_extractor_output
     y_pred_pure, rule_winners, new_rules, grouped_quote_predictions = rules.predict(
         X=X, vnodes_y=vnodes_y, vnodes=vnodes, feature_extractor=fe)
     y_pred = rules.fill_missing_predictions(y_pred_pure, y)
     if self.config["uast_break_check"]:
         y, y_pred, vnodes_y, rule_winners, safe_preds = filter_uast_breaking_preds(
             y=y,
             y_pred=y_pred,
             vnodes_y=vnodes_y,
             vnodes=vnodes,
             files={file.path: file},
             feature_extractor=fe,
             stub=bblfsh_stub,
             vnode_parents=vnode_parents,
             node_parents=node_parents,
             rule_winners=rule_winners,
             grouped_quote_predictions=grouped_quote_predictions)
         y_pred_pure = y_pred_pure[safe_preds]
     assert len(y) == len(y_pred)
     assert len(y) == len(rule_winners)
     code_generator = CodeGenerator(fe, skip_errors=True)
     new_vnodes = code_generator.apply_predicted_y(vnodes, vnodes_y,
                                                   rule_winners, new_rules)
     token_fixes = []
     for line_number, line in self._group_line_nodes(
             y, y_pred, vnodes_y, new_vnodes, rule_winners):
         line_ys, line_ys_pred, line_vnodes_y, new_line_vnodes, line_winners = line
         new_code_line = code_generator.generate(
             new_line_vnodes, "local").lstrip("\n").splitlines()[0]
         confidence = self._get_comment_confidence(line_ys, line_ys_pred,
                                                   line_winners, new_rules)
         fixed_vnodes = [
             vnode for vnode in new_line_vnodes
             if hasattr(vnode, "y_old") and vnode.y_old != vnode.y
         ]
          token_fixes.append(LineFix(
              line_number=line_number,  # line number for the comment
              suggested_code=new_code_line,  # code line suggested by our model
              fixed_vnodes=fixed_vnodes,  # VirtualNode-s with changed y
              confidence=confidence,  # overall confidence in the prediction, 0-100
          ))
     return token_fixes, new_vnodes, y_pred_pure, y
Example 4
    def __init__(self, feature_extractor: FeatureExtractor, debug: bool=False):
        """
        Construct a UASTStabilityChecker.

        :param feature_extractor: Feature extraction class that was used to generate data for \
                                  the check.
        :param debug: Log the code diff for unsafe predictions at debug level.
        """
        self._feature_extractor = feature_extractor
        self._code_generator = CodeGenerator(self._feature_extractor, skip_errors=False)
        self._parsing_cache = {}  # type: Dict[int, Optional[Tuple[bblfsh.Node, int, int]]]
        self._debug = debug
Example 5
 def _generate_token_fixes(
         self, file: File, fe: FeatureExtractor, feature_extractor_output,
         bblfsh_stub: "bblfsh.aliases.ProtocolServiceStub", rules: Rules,
 ) -> Tuple[List[LineFix], List[VirtualNode], numpy.ndarray, numpy.ndarray]:
     X, y, (vnodes_y, vnodes, vnode_parents, node_parents) = feature_extractor_output
     y_pred_pure, rule_winners, new_rules, grouped_quote_predictions = rules.predict(
         X=X, vnodes_y=vnodes_y, vnodes=vnodes, feature_extractor=fe)
     y_pred = rules.fill_missing_predictions(y_pred_pure, y)
     if self.analyze_config[file.language.lower()]["uast_break_check"]:
         checker = UASTStabilityChecker(fe)
         y, y_pred, vnodes_y, rule_winners, safe_preds = checker.check(
             y=y, y_pred=y_pred, vnodes_y=vnodes_y, vnodes=vnodes, files=[file],
             stub=bblfsh_stub, vnode_parents=vnode_parents, node_parents=node_parents,
             rule_winners=rule_winners, grouped_quote_predictions=grouped_quote_predictions)
         y_pred_pure = y_pred_pure[safe_preds]
     assert len(y) == len(y_pred)
     assert len(y) == len(rule_winners)
     code_generator = CodeGenerator(fe, skip_errors=True)
     new_vnodes = code_generator.apply_predicted_y(vnodes, vnodes_y, rule_winners, new_rules)
     token_fixes = []
     newline_index = CLASS_INDEX[CLS_NEWLINE]
     for line_number, line in self._group_line_nodes(
             y, y_pred, vnodes_y, new_vnodes, rule_winners):
         line_ys, line_ys_pred, line_vnodes_y, new_line_vnodes, line_winners = line
         try:
             new_code_line = code_generator.generate_new_line(new_line_vnodes)
         except Exception:
             self._log.exception(
                 "Failed to generate new line suggestion for line %d in %s. line vnodes:\n%s",
                 line_number, file.path, "\n".join(
                     "%s, y_old=%s" % (repr(vn), getattr(vn, "y_old", "N/A"))
                     for vn in new_line_vnodes))
             new_code_line = None
         if (new_line_vnodes and hasattr(new_line_vnodes[0], "y_old") and newline_index in
                 new_line_vnodes[0].y_old):
             lines_num_diff = new_line_vnodes[0].y.count(newline_index) - \
                              new_line_vnodes[0].y_old.count(newline_index)
             if lines_num_diff < 0:
                  # Some lines were removed, which means that several original lines would have
                  # to be modified. The GitHub Suggested Change feature cannot handle such cases
                  # right now, so to avoid confusing the user we do not provide any code
                  # suggestion.
                 new_code_line = None
         confidence = self._get_comment_confidence(line_ys, line_ys_pred, line_winners,
                                                   new_rules)
         fixed_vnodes = [vnode for vnode in new_line_vnodes if
                         hasattr(vnode, "y_old") and vnode.y_old != vnode.y]
         token_fixes.append(LineFix(
             line_number=line_number,        # line number for the comment
             suggested_code=new_code_line,   # code line suggested by our model
             fixed_vnodes=fixed_vnodes,      # VirtualNode-s with changed y
             confidence=confidence,          # overall confidence in the prediction, 0-100
         ))
     return token_fixes, new_vnodes, y_pred_pure, y
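
The newline-count guard in the middle of this method can be exercised in isolation. A minimal sketch with hypothetical values; `newline_index` stands in for `CLASS_INDEX[CLS_NEWLINE]` and the two tuples play the role of `y_old` and `y` of the first vnode on the line:

newline_index = 0  # hypothetical stand-in for CLASS_INDEX[CLS_NEWLINE]
y_old = (newline_index, newline_index)  # the original labels contained two newlines
y_new = (newline_index,)                # the predicted labels keep only one
lines_num_diff = y_new.count(newline_index) - y_old.count(newline_index)
assert lines_num_diff == -1  # negative: a line was removed, so the suggestion is suppressed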
Example 6
 def _test(self):
     # `offsets`, `y_pred`, and `result` are free variables closed over from
     # the enclosing test factory (not shown in this snippet).
     y_cur = deepcopy(self.y)
     for offset, yi in zip(offsets, y_pred):
         i = None
         for i, vnode in enumerate(vnodes_y):  # noqa: B007
             if offset == vnode.start.offset:
                 break
         y_cur[i] = yi
     code_generator = CodeGenerator(self.feature_extractor)
     pred_vnodes = code_generator.apply_predicted_y(
         self.vnodes, self.vnodes_y, list(range(len(self.vnodes_y))),
         FakeRules(y_cur))
     generated_file = code_generator.generate(pred_vnodes)
     self.assertEqual(generated_file, result)
Example 7
 def test_vnode_positions(self):
     code_generator = CodeGenerator(feature_extractor=self.extractor)
     lines = self.code.decode("utf-8", "replace").splitlines()
     lines.append("\r\n")
     ok = True
     for line_number, line in FormatAnalyzer._group_line_nodes(
             self.y, self.y - 1, self.vnodes_y, self.vnodes, repeat(0)):
         line_ys, line_ys_pred, line_vnodes_y, new_line_vnodes, line_winners = line
         new_code_line = code_generator.generate_new_line(new_line_vnodes)
         if lines[line_number - 1] != new_code_line:
             print("Lines %d are different" % line_number, file=sys.stderr)
             print(repr(lines[line_number - 1]), file=sys.stderr)
             print(repr(new_code_line), file=sys.stderr)
             print()
             ok = False
     self.assertTrue(ok, "Original and restored lines are different")
Example 8
class UASTStabilityChecker:
    """
    Check if predictions change the UAST structure of the processed files.

    See `check()` or `file_check()` for more info.
    """

    _log = getLogger("UASTStabilityChecker")
    _check_return_type = Tuple[numpy.ndarray, numpy.ndarray, Sequence[VirtualNode], numpy.ndarray,
                               numpy.ndarray]

    def __init__(self, feature_extractor: FeatureExtractor, debug: bool=False):
        """
        Construct a UASTStabilityChecker.

        :param feature_extractor: Feature extraction class that was used to generate data for \
                                  the check.
        :param debug: Log the code diff for unsafe predictions at debug level.
        """
        self._feature_extractor = feature_extractor
        self._code_generator = CodeGenerator(self._feature_extractor, skip_errors=False)
        self._parsing_cache = {}  # type: Dict[int, Optional[Tuple[bblfsh.Node, int, int]]]
        self._debug = debug

    def _parse_code(self, parent: bblfsh.Node, content: str,
                    stub: "bblfsh.aliases.ProtocolServiceStub",
                    node_parents: Mapping[int, bblfsh.Node], path: str,
                    ) -> Optional[Tuple[bblfsh.Node, int, int]]:
        """
        Find a parent node that Babelfish can parse and parse it.

        Iterates over the parents of the current virtual node until one of them can be parsed,
        and returns the parsed UAST, or None if the root is reached without finding a parsable
        parent.

        The cache will be used to avoid recomputations for parents that have already been
        considered.

        :param parent: First virtual node to try to parse. Will go up in the tree if it fails.
        :param content: Content of the file.
        :param stub: Babelfish GRPC service stub.
        :param node_parents: Parents mapping of the input UASTs.
        :param path: Path of the file being parsed.
        :return: tuple of the parsed UAST and the corresponding starting and ending offsets. \
                 None if Babelfish failed to parse the whole file.
        """
        descendants = []
        current_ancestor = parent
        while current_ancestor is not None:
            if id(current_ancestor) in self._parsing_cache:
                result = self._parsing_cache[id(current_ancestor)]
                break
            descendants.append(current_ancestor)
            start, end = (current_ancestor.start_position.offset,
                          current_ancestor.end_position.offset)
            uast, errors = parse_uast(stub, content[start:end], filename="",
                                      language=self._feature_extractor.language)
            if not errors:
                result = uast, start, end
                break
            current_ancestor = node_parents.get(id(current_ancestor), None)
        else:
            result = None
            self._log.warning("skipped file %s, due to errors in parsing the whole content", path)
        for descendant in descendants:
            self._parsing_cache[id(descendant)] = result
        return result

    def _check_file(
            self, y: numpy.ndarray, y_pred: numpy.ndarray, vnodes_y: Sequence[VirtualNode],
            vnodes: Sequence[VirtualNode], file: File, stub: "bblfsh.aliases.ProtocolServiceStub",
            vnode_parents: Mapping[int, bblfsh.Node], node_parents: Mapping[int, bblfsh.Node],
            rule_winners: numpy.ndarray, grouped_quote_predictions: QuotedNodeTripleMapping,
    ) -> _check_return_type:
        # TODO(warenlg): Add current algorithm description.
        # TODO(vmarkovtsev): Apply ML to not parse all the parents.
        self._parsing_cache = {}
        unsafe_preds = []
        file_content = file.content.decode("utf-8", "replace")
        vnodes_i = 0
        changes = numpy.where((y_pred != -1) & (y != y_pred))[0]
        start_offset_to_vnodes = {}
        end_offset_to_vnodes = {}
        for i, vnode in enumerate(vnodes):
            if vnode.start.offset not in start_offset_to_vnodes:
                # NOOP vnodes are always included
                start_offset_to_vnodes[vnode.start.offset] = i
        for i, vnode in enumerate(vnodes[::-1]):
            if vnode.end.offset not in end_offset_to_vnodes:
                # NOOP vnodes are always included; that is why this loop goes in reverse order
                end_offset_to_vnodes[vnode.end.offset] = len(vnodes) - i
        for i in changes:
            vnode_y = vnodes_y[i]
            while vnode_y is not vnodes[vnodes_i]:
                vnodes_i += 1
                if vnodes_i >= len(vnodes):
                    raise AssertionError("vnodes_y and vnodes are not consistent.")
            if id(vnode_y) in grouped_quote_predictions:
                # quote types are a special case
                group = grouped_quote_predictions[id(vnode_y)]
                if group is None:
                    # already handled with the previous vnode
                    continue
                vnode1, vnode2, vnode3 = group
                content_before = file_content[vnode1.start.offset:vnode3.end.offset]
                content_after = (self._feature_extractor.label_to_str(y_pred[i]) + vnode2.value +
                                 self._feature_extractor.label_to_str(y_pred[i + 1]))
                parsed_before, errors = parse_uast(stub, content_before, filename="",
                                                   language=self._feature_extractor.language)
                if not errors:
                    parsed_after, errors = parse_uast(stub, content_after, filename="",
                                                      language=self._feature_extractor.language)
                    if not self.check_uasts_equivalent(parsed_before, parsed_after):
                        unsafe_preds.append(i)
                        unsafe_preds.append(i + 1)  # Second quote
                continue

            parsed_before = self._parse_code(vnode_parents[id(vnode_y)], file_content, stub,
                                             node_parents, vnode_y.path)
            if parsed_before is None:
                continue
            parent_before, start, end = parsed_before
            vnode_start_index = start_offset_to_vnodes[start]
            vnode_end_index = end_offset_to_vnodes[end]

            assert vnode_start_index <= vnodes_i < vnode_end_index
            try:
                content_after = self._code_generator.generate_one_change(
                    vnodes[vnode_start_index:vnode_end_index],
                    vnodes_i - vnode_start_index, y_pred[i])
            except CodeGenerationBaseError as e:
                self._log.debug("Code generator can't generate code: %s", repr(e.args))
                unsafe_preds.append(i)
                continue
            parent_after, errors_after = parse_uast(
                stub, content_after, filename="", language=self._feature_extractor.language)
            if errors_after:
                unsafe_preds.append(i)
                continue
            if not self.check_uasts_equivalent(parent_before, parent_after):
                if self._debug:
                    self._log.debug(
                        "Bad prediction\nfile:%s\nDiff:\n%s\n\n", vnode_y.path,
                        "\n".join(line for line in difflib.unified_diff(
                            file_content[start:end].splitlines(), content_after.splitlines(),
                            fromfile="original", tofile="suggested", lineterm="")))
                unsafe_preds.append(i)
        self._log.info("%d filtered out of %d with changes", len(unsafe_preds), changes.shape[0])
        unsafe_preds = frozenset(unsafe_preds)
        safe_preds = numpy.array([i for i in range(len(y)) if i not in unsafe_preds])
        vnodes_y = [vn for i, vn in enumerate(list(vnodes_y)) if i not in unsafe_preds]
        return y[safe_preds], y_pred[safe_preds], vnodes_y, rule_winners[safe_preds], safe_preds

    def check(
            self, y: numpy.ndarray, y_pred: numpy.ndarray, vnodes_y: Sequence[VirtualNode],
            vnodes: Sequence[VirtualNode], files: Sequence[File],
            stub: "bblfsh.aliases.ProtocolServiceStub", vnode_parents: Mapping[int, bblfsh.Node],
            node_parents: Mapping[int, bblfsh.Node], rule_winners: numpy.ndarray,
            grouped_quote_predictions: QuotedNodeTripleMapping,
    ) -> _check_return_type:
        """
        Filter out the model's predictions that modify the UAST beyond changing Node positions.

        :param y: Numpy 1-dimensional array of labels.
        :param y_pred: Numpy 1-dimensional array of predicted labels by the model.
        :param vnodes_y: Sequence of the labeled `VirtualNode`-s corresponding to labeled samples.
        :param vnodes: Sequence of all the `VirtualNode`-s corresponding to the input.
        :param files: Sequence of File-s with content, uast and path.
        :param stub: Babelfish GRPC service stub.
        :param vnode_parents: `VirtualNode`-s' parents mapping as the LCA of the closest \
                               left and right babelfish nodes.
        :param node_parents: Parents mapping of the input UASTs.
        :param rule_winners: Numpy array with the indexes of the winning rules for each sample.
        :param grouped_quote_predictions: Quotes predictions (handled differently from the rest).
        :return: List of prediction indices that are considered valid, i.e. that do not break \
                 the UAST.
        """
        if len(files) == 1:
            return self._check_file(y, y_pred, vnodes_y, vnodes, files[0], stub, vnode_parents,
                                    node_parents, rule_winners, grouped_quote_predictions)
        # There is more than one file in the data, so the data has to be split per file.
        # Most of the code below deals with this splitting.
        current_path = vnodes_y[0].path
        file_vnodes_indexes = {current_path: [0, None]}
        for i, vnode in enumerate(vnodes):
            if vnode.path != current_path:
                file_vnodes_indexes[current_path][1] = i
                file_vnodes_indexes[vnode.path] = [i, None]
                current_path = vnode.path
        file_vnodes_indexes[current_path][1] = len(vnodes)
        files = {file.path: file for file in files}
        current_path = vnodes_y[0].path
        i_start = 0
        check_result = []
        for i, vnode_y in enumerate(vnodes_y):
            if vnode_y.path != current_path or i + 1 == len(vnodes_y):
                file_vnodes = vnodes[file_vnodes_indexes[current_path][0]:
                                     file_vnodes_indexes[current_path][1]]
                check_result.append(list(self._check_file(
                    y[i_start:i], y_pred[i_start:i], vnodes_y[i_start:i], file_vnodes,
                    files[current_path], stub, vnode_parents, node_parents,
                    rule_winners[i_start:i], grouped_quote_predictions)))
                check_result[-1][-1] += i_start
                current_path = vnode_y.path
                i_start = i

        res = []
        for res_i in zip(*check_result):
            if isinstance(res_i[0], list):
                res.append(list(chain(*res_i)))
            else:
                res.append(numpy.concatenate(res_i))
        return tuple(res)

    @staticmethod
    def check_uasts_equivalent(uast1: bblfsh.Node, uast2: bblfsh.Node) -> bool:
        """
        Check if 2 UAST nodes are identical regarding `roles`, `internal_type` and `token` of \
        their subtree members.

        :param uast1: The bblfsh.Node of the first UAST to compare.
        :param uast2: The bblfsh.Node of the second UAST to compare.
        :return: True if the 2 input UASTs are identical and False otherwise.
        """
        queue1 = [uast1]
        queue2 = [uast2]
        while queue1 or queue2:
            try:
                node1 = queue1.pop()
                node2 = queue2.pop()
            except IndexError:
                return False
            for child1, child2 in zip(node1.children, node2.children):
                if (child1.roles != child2.roles or child1.internal_type != child2.internal_type
                        or child1.token != child2.token):
                    return False
            queue1.extend(node1.children)
            queue2.extend(node2.children)
        return True
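
A minimal usage sketch for `check_uasts_equivalent` on hand-built trees. It assumes the bblfsh Python client, where `bblfsh.Node` is a protobuf message and the repeated `children` field supports `add()` with keyword arguments:

import bblfsh

# Two toy UASTs that differ only in one token (hypothetical values).
uast1 = bblfsh.Node(internal_type="File")
uast1.children.add(internal_type="Identifier", token="foo")
uast2 = bblfsh.Node(internal_type="File")
uast2.children.add(internal_type="Identifier", token="bar")

print(UASTStabilityChecker.check_uasts_equivalent(uast1, uast1))  # True
print(UASTStabilityChecker.check_uasts_equivalent(uast1, uast2))  # False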
Example 9
 def test_reproduction(self):
     for indent in ("local", "global"):
         code_generator = CodeGenerator(self.feature_extractor)
         generated_file = code_generator.generate(self.vnodes, indent)
         self.assertEqual(generated_file, self.file.content.decode("utf-8"))
Example 10
def calc_metrics(bad_style_code: str, correct_style_code: str,
                 fe: FeatureExtractor, vnodes: Sequence[VirtualNode], url: str,
                 commit: str) -> Dict[str, Any]:
    """
    Calculate metrics for model output.

    Algorithm description:
    1. For given model predictions `y_pred` we generate a new file.
       Now we have 3 files to compare:
       1. `bad_style_code`. The file from the head revision where style mistakes were applied.
          We inspect this file to find them.
       2. `correct_style_code`. The file from the base revision. We use this file to train the
          repo format model. In the ideal case, we should be able to restore this file.
       3. `predicted_style`. The file we get as the format model output.
    2. We compare the files on a character level. To do so we have to align them first.
       The `align3` function is used for that. Here is an example:
    >>> bad_style_code = "import   abcd"
    >>> correct_style_code = "import abcd"
    >>> predicted_code = "import  abcd,"
    >>> print(align3(bad_style_code, correct_style_code, predicted_code))
    >>> Out[1]: ("import   abcd␣",
    >>>          "import ␣␣abcd␣",
    >>>          "import  ␣abcd,")
    3. Now we are able to compare the sequences character by character. The
       `calc_aligned_metrics` function is used for that. There are 5 possible cases; let's
       consider them in the same example:
       ("import   abcd␣",  # aligned bad_style_code
        "import ␣␣abcd␣",  # aligned correct_style_code
        "import  ␣abcd,")  # aligned predicted_code
         ^      ^^    ^
         1      23    4

         1. All three characters are equal. Everything is fine.
         2. The characters in the bad style and predicted code are equal, but the correct code
            differs. The style mistake went undetected.
         3. The characters in the correct style and predicted code are equal, but the bad style
            code differs. The style mistake was detected and correctly fixed.
         4. The characters in the bad style and correct style code are equal, but the predicted
            code differs. A new style mistake was introduced. We call this situation a
            misdetection and want to avoid it as much as possible.
         5. All three characters are different. There is no such case in the example, but it
            means that a style mistake was detected but wrongly fixed.

         Thus, as output we have 4 numbers:
         1. misdetected style mistakes,
         2. undetected style mistakes,
         3. detected style mistakes with a wrong fix,
         4. detected style mistakes with a correct fix.

         In scientific terms:
         1. False positives.
         2 + 3. False negatives. There are two types: either the error was missed and there is
                no fix, or the error was found but wrongly fixed.
         4. True positives.

    :param bad_style_code: The file from the head revision where style mistakes were applied.
    :param correct_style_code: The file from the base revision. In the ideal case, we should be \
                               able to restore it.
    :param fe: Feature extraction class that was used to generate the corresponding data. Set \
            it to None if no changes were introduced for `bad_style_code`.
    :param vnodes: Sequence of all the `VirtualNode`-s corresponding to the input code file. \
                   Should be ordered by position. New y values should be applied.
    :param url: Repository url if applicable. Useful for more informative warning messages.
    :param commit: Commit hash if applicable. Useful for more informative warning messages.

    :return: A dictionary with losses and predicted code.
    """
    predicted_code = CodeGenerator(fe,
                                   skip_errors=True,
                                   url=url,
                                   commit=commit).generate(vnodes)
    misdetection, undetected, detected_wrong_fix, detected_correct_fix = \
        calc_aligned_metrics(*align3(bad_style_code, correct_style_code, predicted_code))
    losses = {
        "misdetection": misdetection,
        "undetected": undetected,
        "detected_wrong_fix": detected_wrong_fix,
        "detected_correct_fix": detected_correct_fix,
        "predicted_file": predicted_code,
    }
    return losses
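
The five cases above can be sketched as a standalone counter. This is a toy re-implementation for illustration only, not the actual `calc_aligned_metrics`; it assumes the three strings have already been aligned to equal length by `align3`:

def classify_aligned_chars(bad, correct, predicted):
    # Count the five cases described in calc_metrics' docstring over three
    # already-aligned strings of equal length.
    counters = {"misdetection": 0, "undetected": 0,
                "detected_wrong_fix": 0, "detected_correct_fix": 0}
    for b, c, p in zip(bad, correct, predicted):
        if b == c == p:
            pass                                   # case 1: nothing to fix, nothing changed
        elif b == p:
            counters["undetected"] += 1            # case 2: the mistake was missed
        elif c == p:
            counters["detected_correct_fix"] += 1  # case 3: the mistake was fixed
        elif b == c:
            counters["misdetection"] += 1          # case 4: a new mistake was introduced
        else:
            counters["detected_wrong_fix"] += 1    # case 5: detected but fixed the wrong way
    return counters

print(classify_aligned_chars("import   abcd␣", "import ␣␣abcd␣", "import  ␣abcd,"))
# {'misdetection': 1, 'undetected': 1, 'detected_wrong_fix': 0, 'detected_correct_fix': 1}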
Example 11
 def test_reproduction(self):
     code_generator = CodeGenerator(self.feature_extractor)
     generated_file = code_generator.generate(self.vnodes)
     self.assertEqual(generated_file, self.file.content)
Example 12
    def test_apply_new_indentation(self):
        cases = [
            ("\n    ", ("\n", "  "), (cls.CLS_NEWLINE, cls.CLS_SPACE_INC,
                                      cls.CLS_SPACE_INC), (cls.CLS_NEWLINE, ),
             ""),
            ("\n    ", ("\n", "      "), (cls.CLS_NEWLINE, cls.CLS_SPACE_DEC,
                                          cls.CLS_SPACE_DEC),
             (cls.CLS_NEWLINE, ), ""),
            ("\n\t ", ("\n", ""), (cls.CLS_NEWLINE, cls.CLS_TAB_INC,
                                   cls.CLS_SPACE_INC), (cls.CLS_NEWLINE, ),
             ""),
            ("\n    ", InapplicableIndentation,
             (cls.CLS_NEWLINE, cls.CLS_TAB_INC,
              cls.CLS_TAB_INC), (cls.CLS_NEWLINE, ), ""),
            ("\n   ", ValueError, (cls.CLS_NEWLINE, cls.CLS_SPACE,
                                   cls.CLS_SPACE_INC, cls.CLS_SPACE_INC),
             (cls.CLS_NEWLINE, ), ""),
            ("\n\t  ", InapplicableIndentation,
             (cls.CLS_NEWLINE, cls.CLS_SPACE_INC,
              cls.CLS_SPACE_INC), (cls.CLS_NEWLINE, cls.CLS_SPACE_DEC), ""),
            ("\n\t   ", ValueError, (cls.CLS_NEWLINE, cls.CLS_SPACE_DEC),
             (cls.CLS_NEWLINE, cls.CLS_SPACE_DEC, cls.CLS_SPACE,
              cls.CLS_SPACE_DEC), ""),
            ("\n\n    ", ("\n", "  "), (cls.CLS_NEWLINE, cls.CLS_NEWLINE,
                                        cls.CLS_SPACE_DEC),
             (cls.CLS_NEWLINE, cls.CLS_SPACE_DEC, cls.CLS_SPACE_DEC,
              cls.CLS_SPACE_DEC), ""),
            ("", ("\n", "  "), (cls.CLS_NOOP, ), (cls.CLS_NEWLINE, ), "  "),
            ("", ("\n\n", ""), (cls.CLS_NOOP, ), (cls.CLS_NEWLINE,
                                                  cls.CLS_NEWLINE), ""),
        ]
        for value, result, y_old, y, last_ident in cases:
            vnode = VirtualNode(value,
                                Position(0, 1, 1),
                                Position(len(y), 1,
                                         len(y) + 1),
                                y=tuple(cls.CLASS_INDEX[i] for i in y))
            vnode.y_old = tuple(cls.CLASS_INDEX[i] for i in y_old)
            if isinstance(result, tuple):
                self.assertEqual(
                    CodeGenerator.apply_new_indentation(vnode, last_ident),
                    result)
            else:
                with self.assertRaises(result):
                    CodeGenerator.apply_new_indentation(vnode, last_ident)

        msg = None

        def _warning(*args):
            nonlocal msg
            msg = args[0]

        try:
            backup_warning = CodeGenerator._log.warning
            CodeGenerator._log.warning = _warning
            vnode = VirtualNode(
                "\n ",
                Position(0, 1, 1),
                Position(3, 1, 4),
                y=tuple(cls.CLASS_INDEX[i]
                        for i in (cls.CLS_NEWLINE, cls.CLS_SPACE_DEC,
                                  cls.CLS_SPACE_DEC, cls.CLS_SPACE_DEC)))
            vnode.y_old = tuple(cls.CLASS_INDEX[i]
                                for i in (cls.CLS_NEWLINE, cls.CLS_SPACE_DEC))
            CodeGenerator.apply_new_indentation(vnode, "")
            expected_msg = "There is no indentation characters left to decrease for vnode"
            self.assertEqual(msg[:len(expected_msg)], expected_msg)
        finally:
            CodeGenerator._log.warning = backup_warning
Example 13
    def test_generate_new_line(self):
        self.maxDiff = None
        expected_res = {
            "nothing changed": [],
            "remove new line in the end of 4th line":
            None,
            "indentation in the beginning":
            [" import { makeToast } from '../../common/app/Toasts/redux';"],
            "remove indentation in the 4th line till the end":
            [" return Object.keys(flash)", " }"],
            "new line between 6th and 7th regular code lines":
            ["\n      return messages.map(message => ({"],
            "new line in the middle of the 7th code line with indentation increase":
            ["      return messages\n        .map(message => ({", "  })"],
            "new line in the middle of the 7th code line with indentation decrease":
            ["      return messages\n    .map(message => ({", "      })"],
            "new line in the middle of the 7th code line without indentation increase":
            ["      return messages\n      .map(message => ({"],
            "change quotes":
            ['import { makeToast } from "../../common/app/Toasts/redux";'],
            "remove indentation decrease 11th line": ["        }));"],
            "change indentation decrease to indentation increase 11th line":
            ["          }));"],
            "change indentation decrease to indentation increase 11th line but keep the rest":
            ["          }));", "})"],
        }

        base = Path(__file__).parent
        # str() is needed for Python 3.5
        with lzma.open(str(base / "benchmark_small.js.xz"), mode="rt") as fin:
            contents = fin.read()
        with lzma.open(str(base / "benchmark_small.js.uast.xz")) as fin:
            uast = bblfsh.Node.FromString(fin.read())
        config = FormatAnalyzer._load_config(get_config())
        fe_config = config["train"]["javascript"]

        for case in expected_res:
            offsets, y_pred, _ = cases[case]
            feature_extractor = FeatureExtractor(
                language="javascript",
                label_composites=label_composites,
                **fe_config["feature_extractor"])
            file = UnicodeFile(content=contents,
                               uast=uast,
                               path="",
                               language="")
            X, y, (vnodes_y, vnodes, vnode_parents, node_parents) = \
                feature_extractor.extract_features([file])
            y_cur = deepcopy(y)
            for offset, yi in zip(offsets, y_pred):
                i = None
                for i, vnode in enumerate(vnodes_y):  # noqa: B007
                    if offset == vnode.start.offset:
                        break
                y_cur[i] = yi
            code_generator = CodeGenerator(feature_extractor)
            pred_vnodes = code_generator.apply_predicted_y(
                vnodes, vnodes_y, list(range(len(vnodes_y))), FakeRules(y_cur))
            res = []
            for gln in FormatAnalyzer._group_line_nodes(
                    y, y_cur, vnodes_y, pred_vnodes, [1] * len(y)):
                line, (line_y, line_y_pred, line_vnodes_y, line_vnodes,
                       line_rule_winners) = gln
                new_code_line = code_generator.generate_new_line(line_vnodes)
                res.append(new_code_line)
            if expected_res[case] is not None:
                # None means that some lines were deleted. We do not handle this properly for now.
                self.assertEqual(res, expected_res[case], case)