def __init__(self, parsed, par_names):
    """Build beniget chains for ``parsed`` and walk from every parameter use.

    Args:
        parsed: Tree passed to ``beniget.DefUseChains.visit``.
        par_names: Container of identifiers, tested with ``in``. A node of
            the use-def chains is kept in ``self.params`` only if it has an
            ``id`` attribute contained in ``par_names`` and its chain value
            is truthy (non-empty).
    """
    self.def_use = beniget.DefUseChains()
    self.def_use.visit(parsed)
    self.use_def = beniget.UseDefChains(self.def_use)

    # Select the nodes that name one of the parameters and actually resolve
    # to something in the use-def chains.
    selected = {}
    for use_node, definitions in self.use_def.chains.items():
        if not hasattr(use_node, 'id'):
            continue
        if use_node.id not in par_names:
            continue
        if not definitions:
            continue
        selected[use_node] = definitions
    self.params = selected

    # Output graph and traversal bookkeeping; set up before the visits below
    # (presumably read/updated by ``self.visit`` — defined outside this view).
    self.graph = pydot.Dot(graph_type='digraph')
    self.node_to_dot_node = {}
    self.src_to_target = set()
    self.already_visited = set()

    for start in self.params:
        self.visit(start)
def check_unbound_identifier_message(self, code, expected_messages, filename=None):
    """Assert that analysing ``code`` with beniget emits ``expected_messages``.

    ``code`` is parsed with ``ast.parse`` and visited by a
    ``beniget.DefUseChains`` constructed with ``filename``. The text recorded
    on the first stream of ``captured_output`` (presumably stdout) is
    stripped and split into lines. The number of lines must equal
    ``len(expected_messages)``, and each expected string must be a substring
    of the corresponding produced line.

    Args:
        code: Source text to analyse.
        expected_messages: Expected message fragments, in output order. May be
            empty, to assert that nothing is emitted.
        filename: Forwarded to ``beniget.DefUseChains``.
    """
    node = ast.parse(code)
    chains = beniget.DefUseChains(filename)
    with captured_output() as (out, _err):
        chains.visit(node)
    output = out.getvalue().strip()
    # "".split("\n") == [""]: without this guard, silence would be counted as
    # one empty message and an empty ``expected_messages`` could never match.
    produced_messages = output.split("\n") if output else []
    # Pass the produced messages as the failure message so a count mismatch
    # shows what was actually emitted.
    self.assertEqual(len(expected_messages), len(produced_messages), produced_messages)
    for expected, produced in zip(expected_messages, produced_messages):
        self.assertIn(expected, produced, "actual message contains expected message")
def visit_Module(self, node):
    """Run beniget def-use analysis on the module and store it in ``self.result``.

    The chains object itself is stored, not a copy. This method does not
    call ``self.generic_visit``.
    """
    chains = beniget.DefUseChains()
    chains.visit(node)
    self.result = chains
def cells_to_stmts(
    cells: Sequence[str],
) -> Tuple[Sequence[ast.stmt], Sequence[ast.FunctionDef], ast.Module]:
    """Turn notebook-style source cells into the AST pieces of a script.

    Each cell ``i`` becomes a function ``cell{i}``:
      * its parameters are the variables it uses that another cell defines;
      * it returns the variables it defines that another cell uses.

    Imports and function definitions found at a cell's top level are hoisted
    out of the cell function. If a cell's last statement is a bare
    expression, its value is assigned to ``cell{i}_output`` and printed. A
    ``main`` function calls every cell function in order, threading the
    returned variables into later calls.

    Args:
        cells: Source text of each cell. The cells are concatenated verbatim
            for the cross-cell dataflow analysis, so each cell is assumed to
            end with a newline (TODO: confirm against callers).

    Returns:
        ``(imports, function_defs, main_call)``:
          * ``imports`` is the list of hoisted import statements;
          * ``function_defs`` is a tuple of the hoisted function definitions
            interleaved with the ``cell{i}`` functions in cell order, followed
            by ``main``;
          * ``main_call`` is a module containing the
            ``if __name__ == "__main__"`` guard that calls ``main()``.
    """
    # cell_reads[i] holds every variable cells[i] uses from another cell.
    cell_reads: List[Set[str]] = [set() for _ in cells]
    # cell_writes[i] holds every variable cells[i] defines for another cell.
    cell_writes: List[Set[str]] = [set() for _ in cells]

    # cell_ends[k] is the first 1-based line number *after* cell k in the
    # concatenated source.
    cell_ends: List[int] = []
    end = 1
    for cell in cells:
        end += cell.count("\n")
        cell_ends.append(end)

    def line_to_cell(lineno: int) -> int:
        """Map a 1-based line of the concatenated source to its cell index."""
        for cell_no, cell_end in enumerate(cell_ends):
            if lineno < cell_end:
                return cell_no
        # Anything past the last counted newline belongs to the final cell.
        return len(cells) - 1

    cells_gast = gast.parse("".join(cells))
    duc = beniget.DefUseChains()
    duc.visit(cells_gast)
    for var in duc.locals[cells_gast]:
        # A definition node without a line number cannot be attributed to a
        # cell; skip it instead of raising AttributeError.
        def_lineno = getattr(var.node, "lineno", None)
        if def_lineno is None:
            continue
        def_cell = line_to_cell(def_lineno)
        for user in var.users():
            use_cell = line_to_cell(user.node.lineno)
            if use_cell != def_cell:
                cell_writes[def_cell].add(var.name())
                cell_reads[use_cell].add(var.name())

    imports: List[Union[ast.Import, ast.ImportFrom]] = []
    function_defs: List[ast.FunctionDef] = []
    main_function = cast(ast.FunctionDef, ast.parse("def main(): ...").body[0])
    main_function.body = []
    for i, (cell, reads, writes) in enumerate(zip(cells, cell_reads, cell_writes)):
        name = f"cell{i}"
        # Sorted so the generated signatures and calls do not depend on set
        # iteration order, which varies between runs for str elements.
        reads_str = ",".join(sorted(reads))
        func_def = cast(ast.FunctionDef, ast.parse(f"def {name}({reads_str}): ...").body[0])
        inner_stmts: List[ast.stmt] = []
        for stmt, is_last in last_sentinel(ast.parse(cell).body):
            if isinstance(stmt, (ast.Import, ast.ImportFrom)):
                imports.append(stmt)
            elif isinstance(stmt, ast.FunctionDef):
                function_defs.append(stmt)
            elif is_last and isinstance(stmt, ast.Expr):
                # ast.parse returns an ast.Module; take its single statement so
                # the function body holds real stmt nodes (a Module there is
                # rejected by compile()).
                inner_stmts.append(
                    ast.parse(f"cell{i}_output = {ast.unparse(stmt)}").body[0])
                inner_stmts.append(
                    ast.parse(f"print(cell{i}_output)").body[0])
            else:
                inner_stmts.append(stmt)
        if inner_stmts:
            func_def.body = inner_stmts
        function_defs.append(func_def)
        if writes:
            writes_str = ",".join(sorted(writes))
            func_def.body.append(ast.parse(f"return {writes_str}").body[0])
            main_function.body.append(
                ast.parse(f"{writes_str} = {name}({reads_str})").body[0])
        else:
            main_function.body.append(
                ast.parse(f"{name}({reads_str})").body[0])
    main_call = ast.parse('if __name__ == "__main__":\n main()\n')
    return imports, (*function_defs, main_function), main_call
# Reproduction script: parse a module whose function has a return annotation
# ``-> T`` and look that annotation node up in beniget's def-use and use-def
# chains.
import beniget
import gast as ast

mod = ast.parse("""
T = int
def func() -> T:
    return 1
""")

# body[0] is ``T = int``; body[1] is ``def func``, whose ``returns`` field is
# the annotation expression ``T``.
fdef = mod.body[1]
node = fdef.returns

du = beniget.DefUseChains()
du.visit(mod)
# Bare subscript: if ``chains`` is a plain dict (as in beniget), this raises
# KeyError when the annotation node was not recorded.
du.chains[node]

ud = beniget.UseDefChains(du)
ud.chains[node]
def __init__(self, module_node):
    """Compute beniget def-use chains for ``module_node``.

    The chains are stored in ``self.chains``. The method also creates empty
    ``attributes`` and ``users`` sets, which are filled outside this method
    (not visible here).
    """
    chains = beniget.DefUseChains()
    chains.visit(module_node)
    self.chains = chains
    self.attributes, self.users = set(), set()