def test_eq()->None:
    """Identically-built expressions compare equal; for `lam`, changing either
    the body constant or the binder name makes the terms unequal."""
    equal_pairs = [
        (let_('a', num(33), lambda x: x),
         let_('a', num(33), lambda x: x)),
        (intrin(MethodName("add"), [('a', num(1)), ('b', num(2))]),
         intrin(MethodName("add"), [('a', num(1)), ('b', num(2))])),
        (ap(ref('a'), ref('b')),
         ap(ref('a'), ref('b'))),
        (lam('a', lambda x: num(44)),
         lam('a', lambda x: num(44))),
    ]
    for left, right in equal_pairs:
        assert left == right

    unequal_pairs = [
        # different body constant
        (lam('a', lambda x: num(44)), lam('a', lambda x: num(0))),
        # different binder name
        (lam('a', lambda x: num(44)), lam('b', lambda x: num(44))),
    ]
    for left, right in unequal_pairs:
        assert left != right
def test_gather()->None:
    """Gathering `c` from memory yields an expression whose only external
    reference is `i`; with `i = 32` it evaluates to `IVal(1)`."""
    memory = Mem({
        Ref('a'): num(33),
        Ref('b'): intrin(MethodName('neg'), [('a', ref('i'))]),
        Ref('c'): intrin(MethodName('add'), [('a', ref('a')), ('b', ref('b'))]),
    })
    gathered = gather(Ref('c'), memory)
    assert extrefs(gathered) == {Ref('i')}

    result, _ = interp(gathered, lib_arith, mkmap({Ref('i'): ival(32)}))
    assert result == IVal(1)
def test_refs()->None:
    """Check the three reference queries on a nested `let_` whose body uses
    `a` and the free name `c`, but not `b`.

    Expected, per the assertions: `refs_` reports a, b, c; `refs` reports a, c;
    `decls` reports the let-bound a, b.
    """
    def _body(a, b):
        # `b` is deliberately left unused in the expression.
        return intrin(MethodName("add"), [('a', a), ('c', ref('c'))])

    e = let_('a', num(33), lambda a: let_('b', num(42), lambda b: _body(a, b)))

    ra, rb, rc = Ref('a'), Ref('b'), Ref('c')
    assert refs_(e) == {ra, rb, rc}
    assert refs(e) == {ra, rc}
    assert decls(e) == {ra, rb}
def bin2expr(j: BIN) -> Expr:
    """Decode a binary-serialized node `j` back into an `Expr`.

    Dispatches on `j.tag`:
    * `Tag.val`    -- leaf; its inner node is `Tag.const` (rebuilt with `num`
      from `int64`) or `Tag.ref` (rebuilt with `ref` from `string`);
    * `Tag.lam`    -- tuple `(name, body)`;
    * `Tag.let`    -- tuple `(name, (bound, body))`;
    * `Tag.ap`     -- tuple `(func, arg)`;
    * `Tag.intrin` -- tuple `(method name, list of (arg name, arg node))`.

    The callback argument passed in by `lam`/`let_` is ignored: the stored body
    is decoded as-is (presumably it already refers to the binder by name).

    Raises:
        ValueError: for an unknown tag; the message contains `_flat(j)`.
    """
    tag = j.tag
    if tag == Tag.val:
        leaf = j.value.node
        leaf_tag = leaf.tag
        if leaf_tag == Tag.const:
            return num(leaf.value.int64)
        if leaf_tag == Tag.ref:
            return ref(leaf.value.string)
        raise ValueError(f"Invalid value expression {_flat(j)}")

    # Each branch reads `j.value.tuple` only after its tag is known, so union
    # members are never touched for the wrong variant.
    if tag == Tag.lam:
        pair = j.value.tuple
        return lam(pair.v1.string, lambda _: bin2expr(pair.v2.node))

    if tag == Tag.let:
        pair = j.value.tuple
        name = pair.v1.string
        rest = pair.v2.tuple
        return let_(name, bin2expr(rest.v1.node), lambda _: bin2expr(rest.v2.node))

    if tag == Tag.ap:
        pair = j.value.tuple
        return ap(bin2expr(pair.v1.node), bin2expr(pair.v2.node))

    if tag == Tag.intrin:
        pair = j.value.tuple
        method = MethodName(str(pair.v1.string))
        named_args = [
            (str(item.tuple.v1.string), bin2expr(item.tuple.v2.node))
            for item in pair.v2.list.list
        ]
        return intrin(method, named_args)

    raise ValueError(f"Invalid expression {_flat(j)}")
def test_let2()->None:
    """Two nested anonymous `let`s (33 and 42) feed `add`; interpreting the
    result with `lib_arith` must give an `IVal` holding their sum."""
    def _sum_body(a, b):
        return intrin(MethodName("add"), [('a', a), ('b', b)])

    # Lambda parameter names kept as `a`/`b` in case `let` derives names from them.
    expr = let(num(33), lambda a: let(num(42), lambda b: _sum_body(a, b)))

    value, _ = interp(expr, lib_arith, mkmap())
    assert isinstance(value, IVal)
    assert value.val == 33 + 42
def test_seriexpr()->None:
    """Interpreter expressions (`IExpr`) survive both the binary and the
    JSON-string serialization round trips unchanged."""
    def _roundtrip(ie: IExpr) -> None:
        assert ie == bin2iexpr(iexpr2bin(ie))
        assert ie == jstr2iexpr(iexpr2jstr(ie))

    _roundtrip(IVal(33))
    _roundtrip(IVal("foo"))
    _roundtrip(IAp(ILam("pat", intrin(MethodName('neg'), [('a', ref('pat'))])), IVal(42)))
    _roundtrip(IError("the message"))
def test_serexpr()->None:
    """Source expressions (`Expr`) survive both the JSON-string and the binary
    serialization round trips unchanged, for every expression form."""
    def _roundtrip(e: Expr) -> None:
        assert e == jstr2expr(expr2jstr(e))
        assert e == bin2expr(expr2bin(e))

    _roundtrip(num(33))
    _roundtrip(ref('a'))
    _roundtrip(lam('b', lambda x: num(44)))
    _roundtrip(let_('a', num(33), lambda x: x))
    _roundtrip(ap(ref('a'), ref('b')))
    _roundtrip(intrin(MethodName('add'), [('a', num(1)), ('b', ref('x'))]))
def test_gengather()->None:
    """Every expression produced by `gengather` for `c` must evaluate, using
    only the inputs it externally references, to the known value of `c`."""
    mn = MethodName
    mem = Mem({
        Ref('a'): num(33),
        Ref('b'): intrin(mn('neg'), [('a', ref('i'))]),
        Ref('c'): intrin(mn('add'), [('a', ref('a')), ('b', ref('b'))]),
    })
    imem = IMem({
        Ref('i'): IVal(32),
        Ref('a'): IVal(33),
        Ref('b'): IVal(-32),
        Ref('c'): IVal(1),
    })

    checked = 0
    for e in gengather(Ref('c'), mem):
        needed = extrefs(e)  # computed once per expression, not per memory entry
        inps = IMem({k: v for k, v in imem.items() if k in needed})
        iexpr, _ = interp(e, lib_arith, inps)
        print(print_expr(e), iexpr, inps.dict)
        assert iexpr == imem[Ref('c')]
        checked += 1

    # Without this the test would pass vacuously if gengather yielded nothing.
    assert checked > 0
def json2expr(j: dict) -> Expr:
    """Rebuild an `Expr` from its JSON dictionary form.

    The node kind is read from key `'t'`:
    * `'val'`    -- `j['val']` is itself tagged `'const'` (-> `num`) or
      `'ref'` (-> `ref`), with the payload under `'val'`;
    * `'lam'`    -- keys `'name'`, `'body'`;
    * `'let'`    -- keys `'ref'`, `'expr'`, `'body'`;
    * `'ap'`     -- keys `'func'`, `'arg'`;
    * `'intrin'` -- `'name'` plus `'args'`, an iterable of `(name, node)` pairs.

    Raises:
        ValueError: for an unknown node or value kind; the message contains `j`.
    """
    kind = j['t']

    if kind == 'val':
        leaf = j['val']
        leaf_kind = leaf['t']
        if leaf_kind == 'const':
            return num(leaf['val'])
        if leaf_kind == 'ref':
            return ref(leaf['val'])
        raise ValueError(f"Invalid value expression {j}")

    if kind == 'lam':
        return lam(j['name'], lambda _: json2expr(j['body']))

    if kind == 'let':
        name = j['ref']
        bound = json2expr(j['expr'])
        return let_(name, bound, lambda _: json2expr(j['body']))

    if kind == 'ap':
        func = json2expr(j['func'])
        arg = json2expr(j['arg'])
        return ap(func, arg)

    if kind == 'intrin':
        method = MethodName(j['name'])
        named_args = [(arg_name, json2expr(sub)) for arg_name, sub in j['args']]
        return intrin(method, named_args)

    raise ValueError(f"Invalid expression {j}")
def test_freqs()->None:
    """An expression containing a single `add` intrinsic has exactly one entry
    in `freqs`, with count 1."""
    def _sum_body(a, b):
        return intrin(MethodName("add"), [('a', a), ('b', b)])

    expr = let(num(33), lambda a: let(num(42), lambda b: _sum_body(a, b)))

    counts = freqs(expr)
    assert counts[MethodName("add")] == 1
    assert len(counts) == 1
def test_print()->None:
    """`print_expr` renders intrinsics, lets and applications in the expected
    textual syntax."""
    cases = [
        (intrin(MethodName("add"), [('a', num(0)), ('b', ref('1'))]),
         "add(a=0,b=1)"),
        (let_('a', num(33), lambda a: num(42)),
         "let a = 33 in 42"),
        (ap(lam('a', lambda a: num(42)), num(33)),
         "((a -> 42) 33)"),
    ]
    for expr, expected in cases:
        assert print_expr(expr) == expected
accs=[] with Pool(2) as p: accs=p.map(_test_tmap_pickle, list(range(100)), chunksize=2) assert len(accs)==100 assert all([accs[i][i]==str(i) for i in range(100)]) def test_refs()->None: e = let_('a', num(33), lambda a: let_('b', num(42), lambda b: intrin(MethodName("add"), [('a',a),('c',ref('c'))]))) assert refs_(e)==set([Ref('a'),Ref('b'),Ref('c')]) assert refs(e)==set([Ref('a'),Ref('c')]) assert decls(e)==set([Ref('a'),Ref('b')]) EXPRS = [ intrin(MethodName("add"), [('a',num(0)),('b',ref('1'))]), let_('a',num(33),lambda a: num(42)), ap(lam('a',lambda a: num(42)), num(33)) ] def test_print()->None: assert print_expr(intrin(MethodName("add"), [('a',num(0)),('b',ref('1'))])) == "add(a=0,b=1)" assert print_expr(let_('a',num(33),lambda a: num(42))) == "let a = 33 in 42" assert print_expr(ap(lam('a',lambda a: num(42)), num(33))) == "((a -> 42) 33)" def test_gather()->None: mn = MethodName mem = Mem({ Ref('a'): num(33), Ref('b'): intrin(mn('neg'),[('a',ref('i'))]), Ref('c'): intrin(mn('add'),[('a',ref('a')),('b',ref('b'))])
def genexpr(
    wlib: WLib,
    inputs: List[IMem],
    maxval: int = 10000,
) -> Iterator[Tuple[Ref, TMap[Ref, Expr], List[IMem], Dict[Ref, int]]]:
    """
    Enumerate lambda-expressions over `len(inputs[0])` input arguments in
    order of growing weight, evaluating each one on every element of `inputs`.

    The input names start with weight 1. For each target weight
    `W = 1, 2, ...` every library operation of weight `w` is applied to the
    argument combinations that `permute` produces for `target_weight=W - w`
    over the refs known so far (presumably: argument weights summing to that
    target). A candidate is discarded if it evaluates to `IError`, or to an
    integer with `abs(value) > maxval`, on any input. The generator loops
    forever; the caller decides when to stop.

    Arguments:
    * `wlib`: Weighted library of primitive operations; each value is indexed
      as `(operation, weight)`. See also `mkwlib`.
    * `inputs`: Collection of the inputs on which to grow the expression. All
      inputs in the list should be of the same size and use the same names.
    * `maxval`: Magnitude bound for integer results (default 10000).

    Yields a tuple of:
    * Top-level reference `Ref` of the newly accepted expression.
    * Map of `Ref -> Expr` for all expressions accepted so far; each is an
      `intrin` whose arguments refer to inputs or to refs from the same map.
    * List of value maps, one per element of `inputs` (same order), giving the
      value of every known ref (inputs and accepted expressions) on that input.
    * Snapshot dict of the weight of every known ref, inputs included.
    """
    # All inputs should provide the same input names
    assert all(i.keys() == inputs[0].keys() for i in inputs)
    nbatch = len(inputs)
    lib = {k: wl[0] for k, wl in wlib.items()}
    libws = {k: wl[1] for k, wl in wlib.items()}
    # Loop-invariant; built once rather than per candidate per input
    # (assumes interp does not mutate the library map).
    tlib = TMap(lib)
    exprcache: Dict[Ref, Expr] = {}
    exprw: Dict[Ref, int] = {k: 1 for k in inputs[0].keys()}
    valcache: List[Dict[Ref, IExpr]] = [OrderedDict(i.dict) for i in inputs]
    W = 0
    while True:
        W += 1
        for op in lib.values():
            w = libws[op.name]
            nargs = len(op.argnames)
            # Snapshot: exprw grows while the candidates below are accepted.
            vws: List[Tuple[Ref, int]] = list(exprw.items())
            for valindices in permute(weights=[rw for _, rw in vws],
                                      nargs=nargs, target_weight=W - w):
                argrefs: List[Ref] = [vws[i][0] for i in valindices]
                assert len(op.argnames) == len(argrefs)
                e2name = Ref(mkname('val'))
                e2expr = intrin(op.name, [(nm, Val(ai)) for nm, ai in zip(op.argnames, argrefs)])
                # TODO: Make this block customizable via callbacks
                acc: List[IExpr] = []
                for b in range(nbatch):
                    e2val, _ = interp(e2expr, tlib, TMap(valcache[b]))
                    if isinstance(e2val, IError):
                        break
                    if (isinstance(e2val, IVal) and isinstance(e2val.val, int)
                            and abs(e2val.val) > maxval):
                        break
                    acc.append(e2val)
                else:
                    # Every input evaluated cleanly: record and publish the candidate.
                    for b in range(nbatch):
                        valcache[b][e2name] = acc[b]
                    exprcache[e2name] = e2expr
                    exprw[e2name] = W
                    # Yield a copy of exprw so consumers that keep results do not
                    # see it mutate as the search continues.
                    yield (e2name, TMap(exprcache), [TMap(fd) for fd in valcache], dict(exprw))