def visit_Module(self, node: cst.Module) -> None:
    """Report a module whose header does not begin with the expected lines.

    Nothing happens when ``self.rule_disabled`` is set. Otherwise the module
    is matched against ``self.header_matcher`` followed by any number of
    further header entries. On a mismatch the module is reported together
    with a replacement whose header is ``self.header_replacement`` followed
    by the module's existing header.
    """
    if self.rule_disabled:
        return

    # The trailing ZeroOrMore() lets arbitrary header lines follow the
    # required prefix; sequence matchers must describe the full sequence.
    expected_header = m.Module(header=[*self.header_matcher, m.ZeroOrMore()])
    if m.matches(node, expected_header):
        return

    patched_header = [*self.header_replacement, *node.header]
    self.report(node, replacement=node.with_changes(header=patched_header))
def _split_module(
    self, orig_module: libcst.Module, updated_module: libcst.Module
) -> Tuple[
    List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]],
    List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]],
    List[Union[libcst.SimpleStatementLine, libcst.BaseCompoundStatement]],
]:
    """Split the updated module body into three parts around the imports.

    The split indices are computed on ``orig_module`` and then applied to
    ``updated_module.body``.

    Args:
        orig_module: The unmodified module; its statements are matched to
            locate the leading ``__strict__`` flag / docstring and the
            statements that contain an entry of ``self.all_imports``.
        updated_module: The modified module whose body is sliced.

    Returns:
        A tuple ``(before, existing, after)`` of statement lists:
        ``before`` holds the statements that must stay ahead of any import
        (a leading ``__strict__`` assignment or a leading docstring),
        ``existing`` holds everything from there up to and including the
        last statement containing a known import, and ``after`` holds the
        remaining statements.
    """
    statement_before_import_location = 0
    import_add_location = 0

    # never insert an import before initial __strict__ flag
    strict_flag_first = m.Module(
        body=[
            m.SimpleStatementLine(
                body=[
                    m.Assign(
                        targets=[m.AssignTarget(target=m.Name("__strict__"))]
                    )
                ]
            ),
            m.ZeroOrMore(),
        ]
    )
    if m.matches(orig_module, strict_flag_first):
        statement_before_import_location = import_add_location = 1

    string_statement = m.SimpleStatementLine(
        body=[m.Expr(value=m.SimpleString())]
    )

    # This works under the principle that while we might modify node contents,
    # we have yet to modify the number of statements. So we can match on the
    # original tree but break up the statements of the modified tree. If we
    # change this assumption in this visitor, we will have to change this code.
    for i, statement in enumerate(orig_module.body):
        # Only a string expression in the very first position is a module
        # docstring. Treating a later bare string the same way would reset
        # both indices to 1 and throw away the import position already
        # recorded for earlier statements.
        if i == 0 and m.matches(statement, string_statement):
            statement_before_import_location = import_add_location = 1
        elif isinstance(statement, libcst.SimpleStatementLine):
            # Identity comparison: we are looking for the exact nodes that
            # were recorded in self.all_imports, not equal-looking ones.
            if any(
                possible_import is known_import
                for possible_import in statement.body
                for known_import in self.all_imports
            ):
                import_add_location = i + 1

    body = updated_module.body
    return (
        list(body[:statement_before_import_location]),
        list(body[statement_before_import_location:import_add_location]),
        list(body[import_add_location:]),
    )
def _has_testnode(node: cst.Module) -> bool:
    """Return whether the module body contains a test function or class.

    A statement qualifies when it is a function definition whose name starts
    with ``test_`` or a class definition whose name starts with ``Test``.
    The pattern is applied to the module's ``body`` sequence.
    """
    test_function = m.FunctionDef(
        name=m.Name(
            value=m.MatchIfTrue(lambda name: name.startswith("test_"))
        )
    )
    test_class = m.ClassDef(
        name=m.Name(
            value=m.MatchIfTrue(lambda name: name.startswith("Test"))
        )
    )

    # Sequence wildcard matchers match LibCST nodes in a row within a
    # sequence; they never implicitly match a partial sequence. The pattern
    # therefore has to describe the complete body, which is why it is
    # bracketed by ``ZeroOrMore()`` on both sides of the part we care about.
    body_pattern = [
        m.ZeroOrMore(),
        m.AtLeastN(n=1, matcher=m.OneOf(test_function, test_class)),
        m.ZeroOrMore(),
    ]
    return m.matches(node, m.Module(body=body_pattern))
def _is_awaitable_callable(annotation: str) -> bool:
    """Return whether ``annotation`` is a callable returning a coroutine.

    The annotation string must start with ``typing.Callable``,
    ``typing.ClassMethod`` or ``StaticMethod``, parse as a module, and have
    the shape ``<callable>[[...], typing.Coroutine[...]]``: a subscript whose
    second element is itself a subscript of the attribute
    ``typing.Coroutine``. Anything else yields ``False``.
    """
    callable_prefixes = (
        "typing.Callable",
        "typing.ClassMethod",
        "StaticMethod",
    )
    if not annotation.startswith(callable_prefixes):
        # Not a callable-style annotation at all; nothing to parse.
        return False

    try:
        # The annotation text may not be parseable as a module. If parsing
        # fails it cannot be the structure we are looking for.
        annotation_module = cst.parse_module(annotation)
    except Exception:
        return False

    # Capture the base of the return-type subscript (expected to be the
    # attribute ``typing.Coroutine``) under the key "base_return_type".
    return_type_base = m.SaveMatchedNode(m.Attribute(), "base_return_type")
    return_type_element = m.SubscriptElement(
        slice=m.Index(value=m.Subscript(value=return_type_base))
    )
    callable_pattern = m.Module(
        body=[
            m.SimpleStatementLine(
                body=[
                    m.Expr(
                        value=m.Subscript(
                            slice=[m.SubscriptElement(), return_type_element],
                        )
                    )
                ]
            ),
        ]
    )

    # ``extract`` returns None when the parsed annotation does not have the
    # expected structure.
    captured: Optional[
        Dict[str, Union[Sequence[cst.CSTNode], cst.CSTNode]]
    ] = m.extract(annotation_module, callable_pattern)
    if captured is None or "base_return_type" not in captured:
        return False

    return_type_name = get_full_name_for_node(
        cst.ensure_type(captured["base_return_type"], cst.CSTNode)
    )
    # A None name compares unequal, so no separate None check is needed.
    return return_type_name == "typing.Coroutine"