def _get_async_expr_replacement(
    self, node: cst.CSTNode
) -> Optional[cst.CSTNode]:
    """
    Compute a replacement for an expression node, or ``None`` if nothing changes.

    Calls are delegated to ``self._get_async_call_replacement`` and attribute
    accesses to ``self._get_async_attr_replacement``. For ``not <expr>`` and
    for boolean operations the method recurses into the operand(s) and
    rebuilds the node only when at least one operand was replaced. Any other
    node kind yields ``None``.
    """
    if m.matches(node, m.Call()):
        return self._get_async_call_replacement(cast(cst.Call, node))

    if m.matches(node, m.Attribute()):
        return self._get_async_attr_replacement(cast(cst.Attribute, node))

    if m.matches(node, m.UnaryOperation(operator=m.Not())):
        negation = cast(cst.UnaryOperation, node)
        new_operand = self._get_async_expr_replacement(negation.expression)
        if new_operand is None:
            return None
        return negation.with_changes(expression=new_operand)

    if m.matches(node, m.BooleanOperation()):
        bool_op = cast(cst.BooleanOperation, node)
        # Left operand is examined before the right one.
        new_left = self._get_async_expr_replacement(bool_op.left)
        new_right = self._get_async_expr_replacement(bool_op.right)
        if new_left is None and new_right is None:
            return None
        return bool_op.with_changes(
            left=bool_op.left if new_left is None else new_left,
            right=bool_op.right if new_right is None else new_right,
        )

    return None
def __assert_codegen(
    self,
    node: cst.CSTNode,
    expected: str,
    expected_position: Optional[CodeRange] = None,
) -> None:
    """
    Check that ``node`` generates exactly ``expected`` as source text and,
    when ``expected_position`` is given, that the position computed for the
    node during code generation equals it.
    """
    empty_module = cst.Module([])
    self.assertEqual(empty_module.code_for_node(node), expected)

    if expected_position is None:
        return

    # Positions normally only make sense for nodes that live inside a module,
    # so this is unsupported publicly. Here we deliberately use internal APIs
    # (``_codegen`` and ``PositionProvider._computed``) to compute a position
    # for this lone node, which is meaningful in the context of node tests.
    position_provider = PositionProvider()
    codegen_state = PositionProvidingCodegenState(
        default_indent=empty_module.default_indent,
        default_newline=empty_module.default_newline,
        provider=position_provider,
    )
    node._codegen(codegen_state)
    self.assertEqual(position_provider._computed[node], expected_position)
def _replace_or_remove(
    parent: cst.CSTNode,
    original_node: cst.CSTNode,
    replacement_node: Union[cst.CSTNode, cst.RemovalSentinel],
) -> cst.CSTNode:
    """
    Swap ``original_node`` inside ``parent`` for ``replacement_node``.

    When ``replacement_node`` is a ``RemovalSentinel`` the original node is
    deleted with ``deep_remove`` instead, and the resulting tree is checked
    (via ``cst.ensure_type``) to still be a ``CSTNode``.
    """
    if not isinstance(replacement_node, cst.RemovalSentinel):
        return parent.deep_replace(original_node, replacement_node)
    pruned_parent = parent.deep_remove(original_node)
    return cst.ensure_type(pruned_parent, cst.CSTNode)
def basic_parenthesize(
    node: libcst.CSTNode,
    whitespace: Optional[libcst.BaseParenthesizableWhitespace] = None,
) -> libcst.CSTNode:
    """
    Wrap ``node`` in a single pair of parentheses.

    Nodes without an ``lpar`` attribute are returned untouched. When a truthy
    ``whitespace`` is supplied it is placed right after the opening paren.
    Any existing ``lpar``/``rpar`` lists are replaced, not extended.
    """
    if not hasattr(node, "lpar"):
        return node
    opening = (
        libcst.LeftParen(whitespace_after=whitespace)
        if whitespace
        else libcst.LeftParen()
    )
    return node.with_changes(lpar=[opening], rpar=[libcst.RightParen()])
def remove_unused_import_by_node(context: CodemodContext, node: cst.CSTNode) -> None:
    """
    Schedule any imports referenced by ``node`` or one of its children to be
    removed in a future invocation of
    :class:`~libcst.codemod.visitors.RemoveImportsVisitor` by updating the
    ``context`` to include the ``module``, ``obj`` and ``alias`` for each
    import in question. When subclassing from
    :class:`~libcst.codemod.CodemodCommand`, this will be performed for you
    after your transform finishes executing. If you are subclassing from a
    :class:`~libcst.codemod.Codemod` instead, you will need to call the
    :meth:`~libcst.codemod.Codemod.transform_module` method on the module
    under modification with an instance of
    :class:`~libcst.codemod.visitors.RemoveImportsVisitor` after performing
    your transform.

    Note that all imports that are referenced by this ``node`` or its
    children will only be removed if they are not in use at the time of
    executing :meth:`~libcst.codemod.Codemod.transform_module` on an instance
    of :class:`~libcst.codemod.visitors.RemoveImportsVisitor` in order to
    avoid removing an in-use import.

    ``from x import *`` statements are ignored (nothing is scheduled).

    :raises Exception: if ``node`` is an ``ImportFrom`` whose absolute module
        cannot be resolved from ``context.full_module_name``.
    """
    # Special case both Import and ImportFrom so the names they bind can be
    # scheduled for removal directly.
    if isinstance(node, cst.Import):
        for import_alias in node.names:
            RemoveImportsVisitor.remove_unused_import(
                context,
                import_alias.evaluated_name,
                asname=import_alias.evaluated_alias,
            )
    elif isinstance(node, cst.ImportFrom):
        names = node.names
        if isinstance(names, cst.ImportStar):
            # We don't handle removing star imports, so ignore them.
            return
        module_name = get_absolute_module_for_import(context.full_module_name, node)
        if module_name is None:
            raise Exception("Cannot look up absolute module from relative import!")
        for import_alias in names:
            RemoveImportsVisitor.remove_unused_import(
                context,
                module_name,
                obj=import_alias.evaluated_name,
                asname=import_alias.evaluated_alias,
            )
    else:
        # Look up all children that could have been imported. Any that we
        # find will be scheduled for removal.
        node.visit(RemovedNodeVisitor(context))
def rewrite(self,
            tree: cst.CSTNode,
            env: SymbolTable,
            metadata: tp.MutableMapping) -> PASS_ARGS_T:
    """
    Apply the enabled boolean-operator transformers to ``tree``.

    ``AndTransformer``, ``OrTransformer`` and ``NotTransformer`` run in that
    order, each only when its ``replace_and`` / ``replace_or`` /
    ``replace_not`` flag is truthy. ``env`` and ``metadata`` are returned
    as-is alongside the new tree.
    """
    if self.replace_and:
        tree = tree.visit(AndTransformer())
    if self.replace_or:
        tree = tree.visit(OrTransformer())
    if self.replace_not:
        tree = tree.visit(NotTransformer())
    return tree, env, metadata
def assert_node_equal(self, first: libcst.CSTNode, second: libcst.CSTNode) -> None:
    """Fail the test unless ``first`` deep-equals ``second``; the failure message shows both nodes rendered with ``utils.to_string``."""
    if first.deep_equals(second):
        return
    got = utils.to_string(first)
    expected = utils.to_string(second)
    self.fail(f"Nodes are not equal\nGot: {got}\nExpected: {expected}")
def _maybe_autofix_node(self, node: cst.CSTNode, attribute_name: str) -> None:
    """
    Report ``node`` with an autofix when its ``attribute_name`` child has an
    async replacement.

    The child is looked up with ``getattr`` and passed to
    ``self._get_async_expr_replacement``; nothing is reported when that
    returns ``None``.
    """
    current_value = getattr(node, attribute_name)
    new_value = self._get_async_expr_replacement(current_value)
    if new_value is None:
        return
    fixed_node = node.with_changes(**{attribute_name: new_value})
    self.report(node, replacement=fixed_node)
def _cst_node_equality_func(
    a: cst.CSTNode, b: cst.CSTNode, msg: Optional[str] = None
) -> None:
    """
    Type-equality hook intended for ``TestCase.addTypeEqualityFunc``.

    Raises ``AssertionError`` (with both reprs and the optional ``msg``)
    unless ``a.deep_equals(b)``.
    """
    if a.deep_equals(b):
        return
    suffix = f"\n{msg}" if msg is not None else ""
    raise AssertionError(f"\n{a!r}\nis not deeply equal to \n{b!r}{suffix}")
def __assert_visit_returns_identity(self, node: cst.CSTNode) -> None:
    """
    Check that visiting ``node`` with a no-op visitor returns an equal node.
    """
    # TODO: This only checks equality, not identity, because visit currently
    # clones the node (that was easier to implement). Once visit stops
    # cloning, tighten this to an identity check.
    visited = node.visit(_NOOPVisitor())
    self.assertEqual(node, visited)
def _replace_names(
    node: libcst.CSTNode,
    wrapper: libcst.metadata.MetadataWrapper,
    replacements: Dict[str, libcst.CSTNode],
) -> libcst.CSTNode:
    """
    Visit ``node`` with a ``_ReplaceTransformer`` built from ``replacements``,
    with metadata resolved from ``wrapper`` for the duration of the visit.
    """
    transformer = _ReplaceTransformer(replacements)
    with transformer.resolve(wrapper):
        result = node.visit(transformer)
    # node.visit never yields a RemovalSentinel here, so narrow the type.
    return cast(libcst.CSTNode, result)
def _leave(self, original: cst.CSTNode, updated: cst.CSTNode) -> cst.CSTNode:
    """
    Return the node that should take the place of ``updated``.

    NOTE(review): ``first_ref_index``, ``matches``, ``transform`` and
    ``find_attempt`` are free variables here; they are not defined in this
    block and must be provided by an enclosing scope.

    * If ``original`` deep-equals the node at ``first_ref_index`` in the path
      of some match, the first node produced by
      ``astNodeFromAssertion(transform, match)`` is returned instead.
    * If ``original`` is in ``transform.references`` it is kept as-is.
    * If there are no capture references and ``updated`` is the ``Module``,
      the nodes produced for a whole-module match are appended to its body.
    * Otherwise ``updated`` is returned unchanged.
    """
    # TODO: if global scope query create a module tree from scratch?
    # NOTE: in the future can cache lib cst node comparisons for speed
    match = notFound
    if first_ref_index is not None:
        match = tryFind(
            lambda candidate: original.deep_equals(
                candidate.path[first_ref_index].node
            ),
            matches,
        )
    if match is not notFound:
        return first(astNodeFromAssertion(transform, match))
    if original in transform.references:
        # TODO: replace references to anything destroyed by the transform.
        # Until then keep the node: the previous bare `pass` fell through and
        # returned None, breaking the `-> cst.CSTNode` contract.
        return updated
    if find_attempt is notFound and isinstance(updated, cst.Module):
        module_match = Match([CaptureExpr().contextualize(node=updated)])
        return updated.with_changes(
            body=(*updated.body, *astNodeFromAssertion(transform, module_match))
        )
    return updated
def _add_one_to_arg(
    node: cst.CSTNode,
    extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]],
) -> cst.CSTNode:
    """
    Replace the ``extraction["arg"]`` node inside ``node`` with an Integer
    whose value is one greater.

    ``extraction["arg"]`` must be a single ``cst.Integer`` node;
    ``cst.ensure_type`` raises otherwise.
    """
    # The extraction value may be a node or a sequence; narrow it first.
    old_arg = cst.ensure_type(extraction["arg"], cst.CSTNode)
    old_value = int(cst.ensure_type(extraction["arg"], cst.Integer).value)
    incremented = cst.Integer(str(old_value + 1))
    return node.deep_replace(old_arg, incremented)
def rewrite(self,
            tree: cst.CSTNode,
            env: SymbolTable,
            metadata: tp.MutableMapping) -> PASS_ARGS_T:
    """
    Run ``IfExpTransformer`` over ``tree`` using a phi name.

    When ``self.phi`` is a string it is used directly as the name. Otherwise
    a fresh name is generated with ``gen_free_name`` (using
    ``self.phi_name_prefix``) and bound to ``self.phi`` in ``env.locals``.
    """
    if isinstance(self.phi, str):
        phi_name = self.phi
    else:
        phi_name = gen_free_name(tree, env, self.phi_name_prefix)
        env.locals[phi_name] = self.phi
    tree = tree.visit(IfExpTransformer(phi_name))
    return tree, env, metadata
def assert_(py_ast: cst.CSTNode, transform: Transform) -> cst.CSTNode:
    """
    Apply ``transform`` to ``py_ast`` and return the rewritten tree.

    Nodes that deep-equal the first captured reference of a match are
    replaced by the first node produced by ``astNodeFromAssertion``. When the
    transform has no capture references, the produced nodes are appended to
    the module body instead.

    TODO: in programming `assert` has a context of being passive, not fixing
    if it finds that it's incorrect, perhaps a more active word should be
    chosen. Maybe *ensure*?
    """
    matches = transform.matches
    first_ref_index = None
    find_attempt = tryFirst(transform.capture_reference_indices)
    if find_attempt is not notFound:
        _, (first_ref_index, _) = first(transform.capture_reference_indices)

    @unified_visit
    class Transformer(cst.CSTTransformer):
        def _leave(self, original: cst.CSTNode, updated: cst.CSTNode) -> cst.CSTNode:
            # TODO: if global scope query create a module tree from scratch?
            # NOTE: in the future can cache lib cst node comparisons for speed
            match = notFound
            if first_ref_index is not None:
                match = tryFind(
                    lambda candidate: original.deep_equals(
                        candidate.path[first_ref_index].node
                    ),
                    matches,
                )
            if match is not notFound:
                return first(astNodeFromAssertion(transform, match))
            if original in transform.references:
                # TODO: replace references to anything destroyed by the
                # transform. Until then keep the node: the previous bare
                # `pass` fell through and returned None, which is not a valid
                # result for a CSTTransformer leave hook.
                return updated
            if find_attempt is notFound and isinstance(updated, cst.Module):
                module_match = Match([CaptureExpr().contextualize(node=updated)])
                return updated.with_changes(
                    body=(*updated.body, *astNodeFromAssertion(transform, module_match))
                )
            return updated

    return py_ast.visit(Transformer())
def render_node(node: cst.CSTNode, module: Optional[cst.Module] = None) -> str:
    """
    Render ``node`` to source code using ``module``'s formatting settings.

    :param node: node to render. If it has a ``leading_lines`` field, that
        field is cleared before rendering. Nodes without the field are
        rendered as-is; previously they made ``with_changes`` raise.
    :param module: module whose ``code_for_node`` is used; an empty
        ``cst.Module`` when omitted.
    :returns: the generated source text.
    """
    if module is None:
        module = cst.Module(body=[])
    # Only some node kinds carry ``leading_lines``; ``with_changes`` rejects
    # unknown fields, so guard before stripping them.
    if hasattr(node, "leading_lines"):
        node = node.with_changes(leading_lines=())
    return module.code_for_node(node)
def rewrite(self,
            tree: cst.CSTNode,
            env: SymbolTable,
            metadata: tp.MutableMapping) -> PASS_ARGS_T:
    """Visit ``tree`` with an ``AssertRemover``; ``env`` and ``metadata`` pass through unchanged."""
    return tree.visit(AssertRemover()), env, metadata
def _handle_suppression_comment(
    self,
    parent_node: cst.CSTNode,
    parent_attribute_name: str,
    index: int,
    local_supp_comment: SuppressionComment,
    comment_physical_line: int,
) -> None:
    """
    Report codes in a suppression comment that suppressed nothing.

    Only rules listed in ``self.rule_names`` (those included in this lint
    run) are judged; a code for a rule that did not run cannot be proven
    unused. If unused codes are found, the ``len(local_supp_comment.tokens)``
    entries of ``getattr(parent_node, parent_attribute_name)`` starting at
    ``index`` are replaced. They are removed entirely when
    ``_compose_new_comment`` returns no lines. Otherwise they become one
    ``EmptyLine`` comment per returned line. The change is offered through
    ``self.report`` as a replacement of ``parent_node``.
    """
    ignored_rules = local_supp_comment.ignored_rules
    # Type-checker guard only: lint-fixme/lint-ignore comments never yield
    # AllRulesType. TODO: remove this check once noqa is deprecated.
    if isinstance(ignored_rules, AllRulesType):
        return

    # A suppression for a rule that did not run may still be needed, so only
    # rules that actually ran can be considered.
    rules_that_ran = {rule for rule in ignored_rules if rule in self.rule_names}
    if not rules_that_ran:
        return

    comment_line_count = len(local_supp_comment.tokens)
    unneeded_codes = _get_unused_codes_in_comment(local_supp_comment, rules_that_ran)
    if not unneeded_codes:
        return

    new_comment_lines = _compose_new_comment(
        local_supp_comment, unneeded_codes, comment_physical_line
    )
    keep_comment = bool(new_comment_lines)

    if keep_comment:
        # Rewrite the comment, reusing the indent/whitespace of the line at
        # ``index``.
        template_line = getattr(parent_node, parent_attribute_name)[index]
        replacement_lines = [
            cst.EmptyLine(
                indent=template_line.indent,
                whitespace=template_line.whitespace,
                comment=cst.Comment(text),
            )
            for text in new_comment_lines
        ]
    else:
        # No remaining code suppresses anything: offer to delete the comment.
        replacement_lines = []

    new_attribute_value = _modify_parent_attribute(
        parent_node,
        parent_attribute_name,
        index,
        index + comment_line_count,
        replacement_lines,
    )

    if keep_comment:
        message = UNUSED_SUPPRESSION_CODES_IN_COMMENT_MESSAGE.format(
            lint_codes="`, `".join(uc for uc in unneeded_codes)
        )
    else:
        message = UNUSED_SUPPRESSION_COMMENT_MESSAGE

    self.report(
        parent_node,
        message=message,
        replacement=parent_node.with_changes(
            **{parent_attribute_name: new_attribute_value}
        ),
    )
def unmangle_nodes(
    tree: cst.CSTNode,
    template_replacements: Mapping[str, ValidReplacementType],
) -> cst.CSTNode:
    """Visit ``tree`` with a ``TemplateTransformer`` for ``template_replacements`` and return the result as a ``CSTNode``."""
    result = tree.visit(TemplateTransformer(template_replacements))
    return ensure_type(result, cst.CSTNode)
def unroll_for_loops(tree: cst.CSTNode, env: tp.Mapping[str, tp.Any]) -> cst.CSTNode:
    """Return ``tree`` after visiting it with an ``Unroller`` configured with ``env``."""
    unroller = Unroller(env)
    return tree.visit(unroller)
def _normalize(node: cst.CSTNode):
    """
    Strip parentheses, then normalize whitespace, then validate the result.

    ``validate_types_deep`` is called on the final node, which is returned.
    """
    for normalizer_cls in (StripParens, WhiteSpaceNormalizer):
        node = node.visit(normalizer_cls())
    node.validate_types_deep()
    return node
def test_deep_equals_fails(self, a: cst.CSTNode, b: cst.CSTNode) -> None:
    # ``a`` and ``b`` must not be deeply equal.
    are_equal = a.deep_equals(b)
    self.assertFalse(are_equal)
def test_deep_equals_success(self, a: cst.CSTNode, b: cst.CSTNode) -> None:
    # ``a`` and ``b`` must be deeply equal.
    are_equal = a.deep_equals(b)
    self.assertTrue(are_equal)
def renamed(self, old_node: cst.CSTNode):
    """Return a copy of ``old_node`` whose ``name`` is replaced by ``self.get_new_cst_name(old_node.name)``."""
    new_name = self.get_new_cst_name(old_node.name)
    return old_node.with_changes(name=new_name)
def obf_universal(self, node: cst.CSTNode, *types):
    """
    Recursively obfuscate identifiers inside ``node`` and return the result.

    ``types`` is forwarded to ``self.can_rename`` when ``node`` is a bare
    ``Name``; when empty it defaults to ``('a', 'ca', 'v', 'cv')``. Recursive
    calls on child nodes pass no ``types``, so children use that default.

    Calls are handled by ``self.new_obf_function_name`` and/or
    ``self.obf_function_args`` depending on the ``change_*`` flags.
    Subscript slices go through ``self.obf_slice``. Literals and node kinds
    without a branch below are returned unchanged.

    Children are visited in a fixed order. ``self.get_new_cst_name`` may be
    order-sensitive, so that order should not be changed casually.
    """
    if m.matches(node, m.Name()):
        types = types or ('a', 'ca', 'v', 'cv')
        node = cst.ensure_type(node, cst.Name)
        if self.can_rename(node.value, *types):
            node = self.get_new_cst_name(node)
    elif m.matches(node, m.NameItem()):
        node = cst.ensure_type(node, cst.NameItem)
        node = node.with_changes(name=self.obf_universal(node.name))
    elif m.matches(node, m.Call()):
        node = cst.ensure_type(node, cst.Call)
        if self.change_methods or self.change_functions:
            node = self.new_obf_function_name(node)
        if self.change_arguments or self.change_method_arguments:
            node = self.obf_function_args(node)
    elif m.matches(node, m.Attribute()):
        node = cst.ensure_type(node, cst.Attribute)
        # The rewritten value/attr must be stored back into the node; calling
        # obf_universal and discarding the results left attributes unchanged.
        node = node.with_changes(
            value=self.obf_universal(node.value),
            attr=self.obf_universal(node.attr),
        )
    elif m.matches(node, m.AssignTarget()):
        node = cst.ensure_type(node, cst.AssignTarget)
        node = node.with_changes(target=self.obf_universal(node.target))
    elif m.matches(node, m.List() | m.Tuple()):
        if m.matches(node, m.List()):
            node = cst.ensure_type(node, cst.List)
        else:
            node = cst.ensure_type(node, cst.Tuple)
        node = node.with_changes(
            elements=[self.obf_universal(el) for el in node.elements]
        )
    elif m.matches(node, m.Subscript()):
        node = cst.ensure_type(node, cst.Subscript)
        new_slice = [
            el.with_changes(slice=self.obf_slice(el.slice)) for el in node.slice
        ]
        node = node.with_changes(slice=new_slice)
        node = node.with_changes(value=self.obf_universal(node.value))
    elif m.matches(node, m.Element()):
        node = cst.ensure_type(node, cst.Element)
        node = node.with_changes(value=self.obf_universal(node.value))
    elif m.matches(node, m.Dict()):
        node = cst.ensure_type(node, cst.Dict)
        node = node.with_changes(
            elements=[self.obf_universal(el) for el in node.elements]
        )
    elif m.matches(node, m.DictElement()):
        node = cst.ensure_type(node, cst.DictElement)
        new_key = self.obf_universal(node.key)
        new_val = self.obf_universal(node.value)
        node = node.with_changes(key=new_key, value=new_val)
    elif m.matches(node, m.StarredDictElement()):
        node = cst.ensure_type(node, cst.StarredDictElement)
        node = node.with_changes(value=self.obf_universal(node.value))
    elif m.matches(node, m.If() | m.While()):
        # Narrow to the real statement type. The old code tested against
        # `cst.If | cst.IfExp` (a type union, not a matcher) and coerced If
        # statements to IfExp, which raised TypeError for every If.
        if m.matches(node, m.If()):
            node = cst.ensure_type(node, cst.If)
        else:
            node = cst.ensure_type(node, cst.While)
        node = node.with_changes(test=self.obf_universal(node.test))
    elif m.matches(node, m.IfExp()):
        node = cst.ensure_type(node, cst.IfExp)
        node = node.with_changes(body=self.obf_universal(node.body))
        node = node.with_changes(test=self.obf_universal(node.test))
        node = node.with_changes(orelse=self.obf_universal(node.orelse))
    elif m.matches(node, m.Comparison()):
        node = cst.ensure_type(node, cst.Comparison)
        # Comparison targets are processed before the left operand.
        new_compars = [self.obf_universal(target) for target in node.comparisons]
        node = node.with_changes(left=self.obf_universal(node.left))
        node = node.with_changes(comparisons=new_compars)
    elif m.matches(node, m.ComparisonTarget()):
        node = cst.ensure_type(node, cst.ComparisonTarget)
        node = node.with_changes(comparator=self.obf_universal(node.comparator))
    elif m.matches(node, m.FormattedString()):
        node = cst.ensure_type(node, cst.FormattedString)
        node = node.with_changes(
            parts=[self.obf_universal(part) for part in node.parts]
        )
    elif m.matches(node, m.FormattedStringExpression()):
        node = cst.ensure_type(node, cst.FormattedStringExpression)
        node = node.with_changes(expression=self.obf_universal(node.expression))
    elif m.matches(node, m.BinaryOperation() | m.BooleanOperation()):
        if m.matches(node, m.BinaryOperation()):
            node = cst.ensure_type(node, cst.BinaryOperation)
        else:
            node = cst.ensure_type(node, cst.BooleanOperation)
        node = node.with_changes(
            left=self.obf_universal(node.left),
            right=self.obf_universal(node.right),
        )
    elif m.matches(node, m.UnaryOperation()):
        node = cst.ensure_type(node, cst.UnaryOperation)
        node = node.with_changes(expression=self.obf_universal(node.expression))
    elif m.matches(node, m.ListComp()):
        node = cst.ensure_type(node, cst.ListComp)
        node = node.with_changes(elt=self.obf_universal(node.elt))
        node = node.with_changes(for_in=self.obf_universal(node.for_in))
    elif m.matches(node, m.DictComp()):
        node = cst.ensure_type(node, cst.DictComp)
        node = node.with_changes(key=self.obf_universal(node.key))
        node = node.with_changes(value=self.obf_universal(node.value))
        node = node.with_changes(for_in=self.obf_universal(node.for_in))
    elif m.matches(node, m.CompFor()):
        node = cst.ensure_type(node, cst.CompFor)
        node = node.with_changes(target=self.obf_universal(node.target))
        node = node.with_changes(iter=self.obf_universal(node.iter))
        node = node.with_changes(ifs=[self.obf_universal(el) for el in node.ifs])
    elif m.matches(node, m.CompIf()):
        node = cst.ensure_type(node, cst.CompIf)
        node = node.with_changes(test=self.obf_universal(node.test))
    # Literals (Integer, Float, SimpleString) and any other node kinds are
    # returned unchanged.
    return node