def visit_ImportFrom(self, imp: ast.ImportFrom):
    """Rewrite an activated ``from X import a, b as c`` into lazy-module calls.

    The node is returned untouched when its line is not in ``self.act`` or
    when it is a star import (only the first alias is inspected for ``'*'``).

    Otherwise the statement is replaced by a list of generated statements:

    * when ``imp.module`` is ``None`` (``from . import x``), a prelude that
      computes the parent package name from ``__name__`` and ``imp.level``
      and stores it under the placeholder name ``.from_mod``;
    * a loop over the ``(name, asname)`` pairs that calls the runtime
      lazy-module function with ``globals()``, the pair and the source module.

    Returns:
        ``imp`` itself, or the list of replacement statements (prelude first,
        then the loop).
    """
    if imp.lineno not in self.act or imp.names[0].name == '*':
        return imp

    if imp.module is None:
        from_mod = ast.Name(".from_mod")

        @quote
        def mk_from_mod(from_mod_n, level):
            from_mod_n = '.'.join(__name__.split('.')[:-level])

        prelude = mk_from_mod(from_mod, ast.Constant(imp.level))
    else:
        # NOTE(review): imp.level is not consulted on this branch, so
        # ``from .pkg import x`` is handled like ``from pkg import x`` here --
        # confirm that the runtime helper resolves this as intended.
        prelude = []
        from_mod = ast.Constant(imp.module)

    names_var = ast.Constant(
        tuple((alias.name, alias.asname) for alias in imp.names))
    name = ast.Name(".name")
    asname = ast.Name(".asname")

    @quote
    def f(lazy_module, from_mod, name, asname, names_var):
        for name, asname, in names_var:
            lazy_module(globals(), name, asname, from_mod)

    stmts = f(
        ast.Name(runtime_lazy_mod, ast.Load()),
        from_mod,
        name,
        asname,
        names_var,
    )
    prelude.extend(stmts)
    # Return the prelude *and* the loop. Returning only ``stmts`` would drop
    # the ``.from_mod`` assignment that the loop reads for ``from . import x``.
    return prelude
def visit_Import(self, imp: ast.Import):
    """Rewrite an activated ``import a, b as c`` statement into lazy-module calls.

    Statements whose line is not in ``self.act`` are returned unchanged.
    For the others, the ``(name, asname)`` pairs of the import are frozen into
    a constant tuple, and the result of the quoted template below (a loop that
    calls the runtime lazy-module function with ``globals()`` and each pair)
    is returned in place of the original node.
    """
    if imp.lineno not in self.act:
        return imp

    # Loop variables of the generated code; built without a ctx, as before.
    loop_name = ast.Name(".name")
    loop_asname = ast.Name(".asname")
    alias_pairs = ast.Constant(
        tuple([(alias.name, alias.asname) for alias in imp.names]))
    lazy_fn = ast.Name(runtime_lazy_mod, ast.Load())

    @quote
    def lazy_import_loop(lazy_module, name, asname, names_var):
        for name, asname, in names_var:
            lazy_module(globals(), name, asname)

    return lazy_import_loop(lazy_fn, loop_name, loop_asname, alias_pairs)
def visit_BinOp(self, n: ast.BinOp):
    """Turn a matching binary operation on an activated line into a call.

    ``self.pair`` is indexed as ``(operator class name, function name)``.
    When ``n`` sits on an activated line and its operator class is the
    configured one, the node becomes ``function(left, right)`` with both
    operands visited recursively and the original position preserved.
    Every other node goes through ``generic_visit``.
    """
    if n.lineno not in self.activation:
        return self.generic_visit(n)

    op_kind = type(n.op).__name__
    pair = self.pair
    if op_kind != pair[0]:
        return self.generic_visit(n)

    callee = ast.Name(pair[1], ast.Load())
    operands = [self.visit(n.left), self.visit(n.right)]
    return ast.Call(
        callee,
        operands,
        [],
        lineno=n.lineno,
        col_offset=n.col_offset,
    )
def visit_With(self, node: ast.With):
    """Recognise ``with <token>(...):`` blocks and compile their ``if`` cases.

    A ``with`` statement is only handled here when it is on an activated
    line, has at least one item, and its first context expression is a call
    to a plain name equal to ``self.token``; anything else is delegated to
    ``generic_visit``.

    For a recognised block:

    * the call must have no keyword arguments and every body statement must
      be an ``if`` without ``else`` (enforced with ``assert``);
    * the value to match is the single call argument, or a tuple of all the
      arguments when there is not exactly one;
    * each ``if`` test is turned into a case by ``SyntacticPatternBinding``
      and paired with its (recursively visited) body.
    """
    if node.lineno not in self.activation:
        return self.generic_visit(node)

    if not node.items:
        return self.generic_visit(node)

    item = node.items[0].context_expr
    if not isinstance(item, ast.Call):
        return self.generic_visit(node)

    fn = item.func
    if not isinstance(fn, ast.Name) or fn.id != self.token:
        return self.generic_visit(node)

    assert not item.keywords
    assert all(isinstance(stmt, ast.If) for stmt in node.body)

    # Compare by value: ``is not 1`` tests object identity, which only works
    # by accident of small-int caching and triggers a SyntaxWarning on 3.8+.
    if len(item.args) != 1:
        val_to_match = ast.Tuple(item.args, ast.Load())
    else:
        val_to_match = item.args[0]

    cached = Symbol(self.next_id, node.lineno, node.col_offset).to_name()

    ifs = node.body  # type: t.List[ast.If]
    # Validate every case before compiling any of them.
    for if_ in ifs:
        assert not if_.orelse

    case_comp = CaseCompilation()
    spb = SyntacticPatternBinding(case_comp)

    pairs = []
    for if_ in ifs:
        case = spb.visit(if_.test)
        stmts = Stmts([self.visit(each) for each in if_.body])
        pairs.append((case, stmts))
    # NOTE(review): the visible body stops here -- ``val_to_match``,
    # ``cached`` and ``pairs`` are not consumed and nothing is returned.
    # Presumably the lowering step follows outside this excerpt; confirm.
def visit_FunctionDef(self, fn: ast.FunctionDef):
    """Turn a ``@<token>``-decorated function into an AST-building function.

    A function is only rewritten when it is on an activated line and carries
    exactly one decorator that is a plain name equal to ``self.token``;
    otherwise it is delegated to ``generic_visit``.

    For a matching function the decorator is removed, the function's
    arguments are collected and spliced into each body statement, and the
    spliced body is converted to a location-free literal whose ``repr`` is
    parsed back into an expression. The body is then replaced by a single
    ``return <runtime_ast_build>(<literal>)`` statement.

    Returns:
        The (mutated) ``fn`` node, or the result of ``generic_visit``.
    """
    if fn.lineno not in self.activation:
        return self.generic_visit(fn)

    # Compare by value: ``is not 1`` tests object identity, which only works
    # by accident of small-int caching and triggers a SyntaxWarning on 3.8+.
    if len(fn.decorator_list) != 1:
        return self.generic_visit(fn)

    deco = fn.decorator_list[0]
    if not (isinstance(deco, ast.Name) and deco.id == self.token):
        return self.generic_visit(fn)

    assert fn.body
    fn.decorator_list.pop()

    arg_collector = ArgumentCollector()
    arg_collector.visit(fn)
    splicing = Splicing(arg_collector.args)

    # The synthesized ``return`` replaces the whole body, so it is anchored
    # at the body's first statement (the head), not its last one.
    body_head = fn.body[0]
    lineno, col_offset = body_head.lineno, body_head.col_offset

    new_body = [splicing.visit(stmt) for stmt in fn.body]

    mod = ast.parse(repr(
        ast_to_literal_without_locations(new_body)))  # type: ast.Module
    ast.fix_missing_locations(mod)
    expr = mod.body[0]  # type: ast.Expr
    value = expr.value

    fn.body = [
        ast.Return(
            ast.Call(ast.Name(runtime_ast_build, ast.Load()), [value], []),
            lineno=lineno,
            col_offset=col_offset,
        )
    ]
    return fn
def visit_Call(self, n: ast.Call):
    """Compile a call appearing in a pattern into a case.

    Keyword arguments are not supported and raise ``NotImplementedError``.
    Calls to a few plain names are special forms:

    * ``pin(x)`` with exactly one argument -> ``case_comp.pin``;
    * ``isinstance(...)`` -> ``case_comp.instance_of`` on the single
      argument, or on a tuple of all the arguments otherwise;
    * ``when(...)`` -> ``case_comp.guard`` on the single argument, or on the
      ``and`` of all the arguments otherwise.

    Any other call becomes ``case_comp.recog2`` with the callee expression
    and the recursively visited arguments.
    """
    if n.keywords:
        raise NotImplementedError(n)

    callee = n.func
    call_args = n.args
    special = callee.id if isinstance(callee, ast.Name) else None
    has_single_arg = len(call_args) == 1

    if special == 'pin' and has_single_arg:
        return self.case_comp.pin(Expr(call_args[0]))

    if special == 'isinstance':
        operand = (call_args[0] if has_single_arg
                   else ast.Tuple(call_args, ast.Load()))
        expr = Expr(operand)
        return self.case_comp.instance_of(expr)

    if special == 'when':
        operand = (call_args[0] if has_single_arg
                   else ast.BoolOp(op=ast.And(), values=call_args))
        expr = Expr(operand)
        return self.case_comp.guard(expr)

    sub_patterns = [self.visit(elt) for elt in call_args]
    return self.case_comp.recog2(Expr(callee), sub_patterns)
def type_as(self, ty):
    """Return a combinator that guards a pattern with an ``isinstance`` check.

    ``ty`` may be a class (its ``__name__`` is used) or a name string; either
    way it is wrapped in a load-context ``ast.Name``. The returned function
    ``then(pattern)`` produces a ``Pattern`` that first applies ``pattern``
    and then wraps the resulting suite with the type-check template below.
    """
    if isinstance(ty, type):
        ty = ty.__name__
    ty = ast.Name(ty, ast.Load())

    def then(pattern):
        # Template for the generated check: run ``stmts`` when ``tag`` is an
        # instance of ``ty``, otherwise assign ``None`` to ``ret``.
        # NOTE(review): presumably ``@quote`` is expanded by the
        # template-python extension so that calling ``quote_tychk`` yields
        # AST statements with the arguments spliced in -- confirm. Do not add
        # a docstring inside it; it would become part of the template.
        # noinspection PyStatementEffect PyUnusedLocal
        @quote
        def quote_tychk(ret, tag, ty, stmts):
            if isinstance(tag, ty):
                stmts
            else:
                ret = None

        @dyn_check
        def pat(target: Expr, remain: Stmts):
            # Apply the inner pattern first, then wrap what it produced in
            # the type check against ``target.value``.
            remain = pattern.apply(target, remain)
            stmts = quote_tychk(self.ret, target.value, ty, remain.suite)
            return Stmts(stmts)

        return Pattern(pat)

    return then
def __init__(self, ret_sym: str = '.RET'):
    """Store a load-context ``ast.Name`` for ``ret_sym`` as ``self.ret``."""
    ret_name = ast.Name(id=ret_sym, ctx=ast.Load())
    self.ret = ret_name
def to_name(self):
    """Build a load-context ``ast.Name`` from ``self.name`` at this object's position."""
    ident = self.name
    node = ast.Name(
        id=ident,
        ctx=ast.Load(),
        lineno=self.lineno,
        col_offset=self.col_offset,
    )
    return node
def gen(self) -> ast.Name:
    """Return a fresh load-context name ``"<base>.<n>"`` and bump the counter.

    ``n`` is the current entry for ``self.base`` in ``self.names``; the entry
    is incremented afterwards so the next call yields a different name.
    """
    base = self.base
    serial = self.names[base]
    self.names[base] += 1
    fresh_id = '{}.{}'.format(base, serial)
    return ast.Name(fresh_id, ast.Load())
# moshmosh?
# +template-python
import typing as t
from moshmosh.ast_compat import ast, get_constant
from moshmosh.extensions.pattern_matching.runtime import NotExhaustive
from toolz import compose

T = t.TypeVar('T')
G = t.TypeVar('G')
H = t.TypeVar('H')

# AST reference (by name) to the NotExhaustive exception class.
not_exhaustive_err_type = ast.Name(NotExhaustive.__name__, ast.Load())


def quote(_):
    """Placeholder for the ``@quote`` decorator; always raises.

    NOTE(review): presumably the template-python extension enabled above
    removes ``@quote`` before the code runs, so this is only reached when the
    extension did not process the file -- confirm.
    """
    # ``NotImplemented`` is a comparison sentinel, not an exception; raising
    # it produces a confusing TypeError instead of the intended error.
    raise NotImplementedError(
        "quote is a compile-time marker and must not be called at runtime")


class Names(dict):
    """Dict that initialises a missing key to the current number of entries."""

    def __missing__(self, key):
        v = self[key] = len(self)
        return v


class Gensym:
    """Generator of names derived from a base name."""

    # Counter table defined on the class, so all instances share it.
    names = Names()

    def __init__(self, base_name):
        self.base = base_name

    def gen(self) -> ast.Name:
        # NOTE(review): only the first statement of this method is visible
        # in this excerpt; the rest of the body is presumably cut off.
        i = self.names[self.base]
def visit_Name(self, n: ast.Name):
    """Give ``n`` a ``Load`` context when its ``ctx`` is absent or ``None``."""
    if getattr(n, 'ctx', None) is None:
        n.ctx = ast.Load()