Beispiel #1
0
 def test_extract_precedence_sequence(self) -> None:
     """Saving the same name twice in one sequence keeps the last capture.

     Both arguments of ``d(e, f * g)`` are saved under ``"arg"``; the result
     must hold only the second one, ``f * g``.
     """
     expression = cst.parse_expression("a + b[c], d(e, f * g)")
     call_pattern = m.Call(args=[
         m.Arg(m.SaveMatchedNode(m.DoNotCare(), "arg")),
         m.Arg(m.SaveMatchedNode(m.DoNotCare(), "arg")),
     ])
     nodes = m.extract(
         expression,
         m.Tuple(elements=[m.DoNotCare(), m.Element(call_pattern)]),
     )

     outer_tuple = cst.ensure_type(expression, cst.Tuple)
     call = cst.ensure_type(outer_tuple.elements[1].value, cst.Call)
     self.assertEqual(nodes, {"arg": call.args[1].value})
Beispiel #2
0
    def visit_With(self, node: cst.With):
        """Record the names referenced by a ``with`` statement.

        Every ``Name`` value found anywhere inside any of the statement's
        ``WithItem``s (context-manager expressions and their ``as`` targets) is
        collected. The list is then recorded:

        * in ``self.fn_may_args_var_use`` when ``self.stack`` is non-empty,
        * in ``self.cls_may_vars_use`` when ``self.cls_stack`` is non-empty and
          ``self.cls_stack[0].name`` is among the collected names,
        * and always passed to ``self.__find_module_vars_use``.
        """
        # Walk every with-item. The previous implementation extracted
        # ``items`` and indexed ``[0]``, so ``with a as x, b as y:`` silently
        # ignored everything after the first item.
        with_names = [
            name_node.value
            for with_item in node.items
            for name_node in match.findall(with_item, match.Name())
        ]
        if len(self.stack) > 0:
            self.fn_may_args_var_use.append(with_names)

        if len(self.cls_stack) > 0:
            # NOTE(review): compares ``cls_stack[0].name`` against plain str
            # values; confirm ``.name`` is a str (not a cst.Name) here.
            if self.cls_stack[0].name in with_names:
                self.cls_may_vars_use.append(with_names)

        self.__find_module_vars_use(with_names)
Beispiel #3
0
 def test_extract_optional_wildcard_present(self) -> None:
     """A ``ZeroOrOne`` wildcard that is present still saves its capture.

     The third call argument ``h.i.j`` is an ``Attribute`` and must be
     returned under ``"arg"``.
     """
     expression = cst.parse_expression("a + b[c], d(e, f * g, h.i.j)")
     optional_attribute_arg = m.ZeroOrOne(
         m.Arg(m.SaveMatchedNode(m.Attribute(), "arg")))
     call_pattern = m.Call(args=[
         m.DoNotCare(),
         m.DoNotCare(),
         optional_attribute_arg,
     ])
     nodes = m.extract(
         expression,
         m.Tuple(elements=[m.DoNotCare(), m.Element(call_pattern)]),
     )

     outer_tuple = cst.ensure_type(expression, cst.Tuple)
     call = cst.ensure_type(outer_tuple.elements[1].value, cst.Call)
     self.assertEqual(nodes, {"arg": call.args[2].value})
Beispiel #4
0
 def test_extract_sequence_multiple_wildcards(self) -> None:
     """Wildcards on either side of a fixed element capture the runs around it.

     ``head`` and ``tail`` receive tuples of the elements before and after the
     element whose value is the integer ``3``; ``element`` is that element.
     """
     expression = cst.parse_expression("1, 2, 3, 4")
     element_three = m.Element(value=m.Integer(value="3"))
     pattern = m.Tuple(elements=(
         m.SaveMatchedNode(m.ZeroOrMore(), "head"),
         m.SaveMatchedNode(element_three, "element"),
         m.SaveMatchedNode(m.ZeroOrMore(), "tail"),
     ))
     nodes = m.extract(expression, pattern)

     elements = cst.ensure_type(expression, cst.Tuple).elements
     expected = {
         "head": tuple(elements[:2]),
         "element": elements[2],
         "tail": tuple(elements[3:]),
     }
     self.assertEqual(nodes, expected)
Beispiel #5
0
    def visit_ImportAlias(self, node: cst.ImportAlias):
        """
        Extracts imports.

        Even if an Import has no alias, the Import node will still have an
        ImportAlias carrying its real name. The real (possibly dotted) name is
        appended to ``self.imports``; when an ``as`` alias is present, the alias
        name is appended as well.

        Returns False so the children of the ImportAlias are not traversed.
        """
        # Extract the (real) import name under "name" and, if present, the
        # alias string under "alias".
        import_info = match.extract(
            node,
            match.ImportAlias(
                # Either there is an alias, whose name is saved as "alias" ...
                asname=match.AsName(
                    name=match.Name(
                        value=match.SaveMatchedNode(
                            match.MatchRegex(r'(.)+'),  # any non-empty name
                            "alias")))
                # ... or there is none. OR-ing the AsName matcher with its
                # negation makes this field always match, so imports without
                # an alias are still extracted.
                | ~match.AsName(),
                name=match.SaveMatchedNode(  # Match & save name of import
                    match.DoNotCare(), "name")))

        # extract() returns None when the pattern does not match; never run a
        # membership test against None.
        if import_info is None:
            return False

        # Add import if a name could be extracted.
        if "name" in import_info:
            # The name node is converted to code so multi-level imports
            # (e.g. import x.y.z) are handled uniformly.
            # TODO: This might be un-needed after implementing import type extraction change.
            # TODO: So far, no differentiation between import and from imports.
            import_name = self.__convert_node_to_code(import_info["name"])
            import_name = self.__clean_string_whitespace(import_name)
            self.imports.append(import_name)

        if "alias" in import_info:
            alias_name = self.__clean_string_whitespace(import_info["alias"])
            self.imports.append(alias_name)

        # No need to traverse children, as we already performed the matching.
        return False
Beispiel #6
0
 def test_extract_multiple(self) -> None:
     """Several distinct captures in one pattern are all returned.

     ``left`` is the left operand of ``a + b[c]`` and ``func`` is the callee
     of ``d(e, f * g)``.
     """
     expression = cst.parse_expression("a + b[c], d(e, f * g)")
     pattern = m.Tuple(elements=[
         m.Element(
             m.BinaryOperation(left=m.SaveMatchedNode(m.Name(), "left"))),
         m.Element(m.Call(func=m.SaveMatchedNode(m.Name(), "func"))),
     ])
     nodes = m.extract(expression, pattern)

     elements = cst.ensure_type(expression, cst.Tuple).elements
     binary_op = cst.ensure_type(elements[0].value, cst.BinaryOperation)
     call = cst.ensure_type(elements[1].value, cst.Call)
     expected = {"left": binary_op.left, "func": call.func}
     self.assertEqual(nodes, expected)
Beispiel #7
0
    def __extract_variable_name(self, node: cst.AssignTarget):
        """Extract the variable name(s) bound by an assignment target.

        Three target shapes are recognised:

        * ``x = ...``     -> the extract dict (``"name"``) plus a ``"type"``
          entry ``(t, INF_TYPE_ANNOT or UNK_TYPE_ANNOT)``;
        * ``a, b = ...``  -> ``{"names": <__extract_names_multi_assign(...)>}``;
        * ``obj.x = ...`` -> the extract dict (``"obj_name"``, ``"name"``) plus
          the same ``"type"`` entry.

        Returns None when the target has none of these shapes.
        """
        any_text = r'(.)+'  # non-empty string
        single_name = match.Name(
            value=match.SaveMatchedNode(match.MatchRegex(any_text), "name"))
        multi_target = match.Tuple(
            elements=match.SaveMatchedNode(match.DoNotCare(), "names"))
        # Variables assigned without annotation inside __init__ (self.x = 2).
        attribute_target = match.Attribute(
            value=match.Name(value=match.SaveMatchedNode(
                match.MatchRegex(any_text), "obj_name")),
            attr=match.Name(
                match.SaveMatchedNode(match.MatchRegex(any_text), "name")),
        )

        extracted = match.extract(
            node,
            match.AssignTarget(target=match.OneOf(
                single_name, multi_target, attribute_target)))

        if extracted is None:
            return None

        if "name" in extracted:
            # Presumably an inferred type, falsy when none is known; the marker
            # records which case applied.
            inferred = self.__get_type_from_metadata(node.target)
            marker = INF_TYPE_ANNOT if inferred else UNK_TYPE_ANNOT
            extracted['type'] = (inferred, marker)
            return extracted

        if "names" in extracted:
            saved_elements = list(extracted['names'])
            return {'names': self.__extract_names_multi_assign(saved_elements)}

        return None
Beispiel #8
0
 def _is_awaitable_callable(annotation: str) -> bool:
     if not (annotation.startswith("typing.Callable")
             or annotation.startswith("typing.ClassMethod")
             or annotation.startswith("StaticMethod")):
         # Exit early if this is not even a `typing.Callable` annotation.
         return False
     try:
         # Wrap this in a try-except since the type annotation may not be parse-able as a module.
         # If it is not parse-able, we know it's not what we are looking for anyway, so return `False`.
         parsed_ann = cst.parse_module(annotation)
     except Exception:
         return False
     # If passed annotation does not match the expected annotation structure for a `typing.Callable` with
     # typing.Coroutine as the return type, matched_callable_ann will simply be `None`.
     # The expected structure of an awaitable callable annotation from Pyre is: typing.Callable()[[...], typing.Coroutine[...]]
     matched_callable_ann: Optional[Dict[str, Union[
         Sequence[cst.CSTNode], cst.CSTNode]]] = m.extract(
             parsed_ann,
             m.Module(body=[
                 m.SimpleStatementLine(body=[
                     m.Expr(value=m.Subscript(slice=[
                         m.SubscriptElement(),
                         m.SubscriptElement(slice=m.Index(value=m.Subscript(
                             value=m.SaveMatchedNode(
                                 m.Attribute(),
                                 "base_return_type",
                             )))),
                     ], ))
                 ]),
             ]),
         )
     if (matched_callable_ann is not None
             and "base_return_type" in matched_callable_ann):
         base_return_type = get_full_name_for_node(
             cst.ensure_type(matched_callable_ann["base_return_type"],
                             cst.CSTNode))
         return (base_return_type is not None
                 and base_return_type == "typing.Coroutine")
     return False
Beispiel #9
0
    def __extract_assign_newtype(self, node: cst.Assign):
        """
        Attempts extracting a NewType declaration from the provided Assign node.

        Matches single-target assignments of the form
        ``Name = NewType(<string literal>, ...)`` (a bare ``NewType`` name, not
        an attribute such as ``typing.NewType``). On a match, the assigned name,
        with any single quotes stripped, is appended to ``self.class_defs``.
        """
        # Exactly one target, which must be a plain name; save it as "type".
        single_name_target = match.AssignTarget(
            target=match.Name(
                value=match.SaveMatchedNode(
                    match.MatchRegex(r'(.)+'),  # any non-empty name
                    "type")))
        # A call to ``NewType`` whose first argument is a string literal; any
        # further arguments are accepted.
        newtype_call = match.Call(
            func=match.Name(value="NewType"),
            args=[
                match.Arg(value=match.SimpleString()),
                match.ZeroOrMore(),
            ])
        matcher_newtype = match.Assign(
            targets=[single_name_target], value=newtype_call)

        extracted_type = match.extract(node, matcher_newtype)
        if extracted_type is None:
            return

        # Append the additional type to the list
        # TODO: Either rename class defs, or create new list for additional types
        self.class_defs.append(extracted_type["type"].strip("\'"))
Beispiel #10
0
    def visit_BinaryOperation(self, node: cst.BinaryOperation) -> None:
        """Report ``"..." % expr`` formatting, with an f-string fix when safe.

        When the left operand satisfies ``_match_simple_string`` and the right
        operand satisfies the simple-expression check, every ``%s`` placeholder
        is replaced by the matching expression (a tuple operand supplies one
        expression per element). The node is then reported with a
        ``cst.FormattedString`` replacement.

        If the placeholder and expression counts differ, or an expression fails
        to convert, the node is reported without a replacement. Any other ``%``
        applied to a ``str`` (not bytes) ``SimpleString`` is also reported
        without a replacement.
        """
        expr_key = "expr"
        extracts = m.extract(
            node,
            m.BinaryOperation(
                left=m.MatchIfTrue(_match_simple_string),
                operator=m.Modulo(),
                right=m.SaveMatchedNode(
                    m.MatchIfTrue(
                        _gen_match_simple_expression(
                            self.context.wrapper.module)),
                    expr_key,
                ),
            ),
        )

        if extracts:
            expr = extracts[expr_key]
            simple_string = cst.ensure_type(node.left, cst.SimpleString)
            # Double literal braces so they stay literal inside the f-string.
            innards = simple_string.raw_value.replace("{",
                                                      "{{").replace("}", "}}")
            tokens = innards.split("%s")
            expressions = ([elm.value for elm in expr.elements] if isinstance(
                expr, cst.Tuple) else [expr])
            if len(expressions) != len(tokens) - 1:
                # Only generate a warning when the %-string does not come with
                # exactly one element per placeholder. Checking only for too
                # *few* elements (as before) let '"%s" % (a, b)' be autofixed to
                # f"{a}", silently dropping ``b``.
                self.report(node)
                return
            parts = []
            if len(tokens[0]) > 0:
                parts.append(cst.FormattedStringText(value=tokens[0]))
            escape_transformer = EscapeStringQuote(simple_string.quote)
            for expression, token in zip(expressions, tokens[1:]):
                try:
                    parts.append(
                        cst.FormattedStringExpression(expression=cast(
                            cst.BaseExpression,
                            expression.visit(escape_transformer),
                        )))
                except Exception:
                    # Conversion failed: warn but don't offer a replacement.
                    self.report(node)
                    return
                if len(token) > 0:
                    parts.append(cst.FormattedStringText(value=token))
            start = f"f{simple_string.prefix}{simple_string.quote}"
            replacement = cst.FormattedString(parts=parts,
                                              start=start,
                                              end=simple_string.quote)
            self.report(node, replacement=replacement)
        elif m.matches(
                node,
                m.BinaryOperation(
                    left=m.SimpleString(),
                    operator=m.Modulo())) and isinstance(
                        cst.ensure_type(
                            node.left, cst.SimpleString).evaluated_value, str):
            self.report(node)
    def visit_Call(self, node: cst.Call) -> None:
        """Report ``self.assertTrue(x, <literal>)`` and suggest a replacement.

        The second positional argument of ``assertTrue`` is the failure message.
        Passing a literal there is reported with:

        * ``True``          -> ``self.assertTrue(x)``
        * ``None``          -> ``self.assertIsNone(x)``
        * ``False``         -> ``self.assertFalse(x)``
        * any other literal -> ``self.assertEqual(x, <literal>)``
        """
        result = m.extract(
            node,
            m.Call(
                func=m.Attribute(value=m.Name("self"),
                                 attr=m.Name("assertTrue")),
                args=[
                    m.DoNotCare(),
                    # Positional only (keyword=None): ``assertTrue(x, msg=None)``
                    # explicitly passes a message and must not be rewritten to
                    # ``assertIsNone(x)``. The old matcher accepted keywords.
                    m.Arg(
                        value=m.SaveMatchedNode(
                            m.OneOf(
                                m.Integer(),
                                m.Float(),
                                m.Imaginary(),
                                m.Tuple(),
                                m.List(),
                                m.Set(),
                                m.Dict(),
                                m.Name("None"),
                                m.Name("True"),
                                m.Name("False"),
                            ),
                            "second",
                        ),
                        keyword=None,
                    ),
                ],
            ),
        )

        if not result:
            return

        second_arg = result["second"]
        if isinstance(second_arg, Sequence):
            second_arg = second_arg[0]

        # Keep only the first argument, without its trailing comma.
        first_arg_only = [
            node.args[0].with_changes(comma=cst.MaybeSentinel.DEFAULT)
        ]
        method_name = cst.ensure_type(node.func, cst.Attribute).attr

        if m.matches(second_arg, m.Name("True")):
            new_call = node.with_changes(args=first_arg_only)
        elif m.matches(second_arg, m.Name("None")):
            new_call = node.with_changes(
                func=node.func.with_deep_changes(old_node=method_name,
                                                 value="assertIsNone"),
                args=first_arg_only,
            )
        elif m.matches(second_arg, m.Name("False")):
            new_call = node.with_changes(
                func=node.func.with_deep_changes(old_node=method_name,
                                                 value="assertFalse"),
                args=first_arg_only,
            )
        else:
            new_call = node.with_deep_changes(old_node=method_name,
                                              value="assertEqual")

        self.report(node, replacement=new_call)
    def leave_BinaryOperation(
        self, original_node: cst.BinaryOperation, updated_node: cst.BinaryOperation
    ) -> cst.BaseExpression:
        """Rewrite simple ``"..." % expr`` formatting into an f-string.

        When the left operand satisfies ``_match_simple_string`` and the right
        operand satisfies the simple-expression check, each ``%s`` placeholder
        is filled with the matching expression. A tuple operand supplies one
        expression per element. The result is returned as a
        ``cst.FormattedString``.

        In every other case, ``updated_node`` is returned so that rewrites
        already applied to child nodes are preserved. This covers:

        * no match,
        * a placeholder/expression count mismatch,
        * an expression that fails to convert.
        """
        expr_key = "expr"
        extracts = m.extract(
            original_node,
            m.BinaryOperation(
                # pyre-fixme[6]: Expected `Union[m._matcher_base.AllOf[typing.Union[m...
                left=m.MatchIfTrue(_match_simple_string),
                operator=m.Modulo(),
                # pyre-fixme[6]: Expected `Union[m._matcher_base.AllOf[typing.Union[m...
                right=m.SaveMatchedNode(
                    m.MatchIfTrue(_gen_match_simple_expression(self.module)),
                    expr_key,
                ),
            ),
        )

        if not extracts:
            # Returning ``original_node`` here (as before) threw away rewrites
            # of children, e.g. the operands of ``"%s" % a + "%s" % b``.
            return updated_node

        exprs = extracts[expr_key]
        exprs = (exprs,) if not isinstance(exprs, Sequence) else exprs
        simple_string = cst.ensure_type(original_node.left, cst.SimpleString)
        # Double literal braces so they stay literal inside the f-string.
        innards = simple_string.raw_value.replace("{", "{{").replace("}", "}}")
        tokens = innards.split("%s")
        # Flatten the operands (tuples contribute their element values).
        # ``chain.from_iterable`` replaces ``list(*itertools.chain(...))``,
        # which only worked when ``exprs`` had exactly one entry.
        expressions: List[cst.CSTNode] = list(
            itertools.chain.from_iterable(
                [elm.value for elm in expr.elements]
                if isinstance(expr, cst.Tuple)
                else [expr]
                for expr in exprs
            )
        )
        if len(expressions) != len(tokens) - 1:
            # the %-string doesn't come with one element per placeholder (too
            # few *or* too many): leave it alone rather than drop arguments.
            return updated_node

        parts = []
        if len(tokens[0]) > 0:
            parts.append(cst.FormattedStringText(value=tokens[0]))
        escape_transformer = EscapeStringQuote(simple_string.quote)
        for expression, token in zip(expressions, tokens[1:]):
            try:
                parts.append(
                    cst.FormattedStringExpression(
                        expression=cast(
                            cst.BaseExpression,
                            expression.visit(escape_transformer),
                        )
                    )
                )
            except Exception:
                return updated_node
            if len(token) > 0:
                parts.append(cst.FormattedStringText(value=token))
        start = f"f{simple_string.prefix}{simple_string.quote}"
        return cst.FormattedString(
            parts=parts, start=start, end=simple_string.quote
        )
    def visit_Call(self, node: cst.Call) -> None:
        """Suggest ``assertIsNone``/``assertIsNotNone`` for ``is None`` assertions.

        Matches ``self.assertTrue(...)`` or ``self.assertFalse(...)`` called
        with exactly one argument of the form ``x is None``, ``x is not None``,
        or ``not`` applied to either. Three things each count as one negation:

        * ``assertFalse``,
        * a leading ``not``,
        * ``is not``.

        An even count yields ``self.assertIsNone(x)`` and an odd count yields
        ``self.assertIsNotNone(x)``.
        """

        def _first(value):
            # A saved value that comes back as a sequence is unwrapped to its
            # first item.
            return value[0] if isinstance(value, Sequence) else value

        is_none_target = m.ComparisonTarget(
            m.SaveMatchedNode(
                m.OneOf(m.Is(), m.IsNot()),
                "comparison_type",
            ),
            comparator=m.Name("None"),
        )
        is_none_comparison = m.Comparison(comparisons=[is_none_target])
        assertion_attr = m.SaveMatchedNode(
            m.OneOf(m.Name("assertTrue"), m.Name("assertFalse")),
            "assertion_name",
        )
        checked_argument = m.SaveMatchedNode(
            m.OneOf(
                is_none_comparison,
                m.UnaryOperation(operator=m.Not(),
                                 expression=is_none_comparison),
            ),
            "argument",
        )
        result = m.extract(
            node,
            m.Call(
                func=m.Attribute(value=m.Name("self"), attr=assertion_attr),
                args=[m.Arg(checked_argument)],
            ),
        )
        if not result:
            return

        assertion_name = _first(result["assertion_name"])
        argument = _first(result["argument"])
        comparison_type = _first(result["comparison_type"])

        if m.matches(argument, m.Comparison()):
            comparison = ensure_type(argument, cst.Comparison)
        else:
            comparison = ensure_type(
                ensure_type(argument, cst.UnaryOperation).expression,
                cst.Comparison)
        assertion_argument = comparison.left

        negations_seen = sum((
            m.matches(assertion_name, m.Name("assertFalse")),
            m.matches(argument, m.UnaryOperation()),
            m.matches(comparison_type, m.IsNot()),
        ))
        new_attr = "assertIsNotNone" if negations_seen % 2 else "assertIsNone"

        new_call = node.with_changes(
            func=cst.Attribute(value=cst.Name("self"),
                               attr=cst.Name(new_attr)),
            args=[cst.Arg(assertion_argument)],
        )
        if new_call is not node:
            self.report(node, replacement=new_call)