def test_serimem():
    """Round-trip an `IMem` through both the binary and the JSON-string codecs."""
    def _roundtrip(mem):
        # The binary encoding must be lossless.
        decoded_bin = bin2imem(imem2bin(mem))
        assert mem == decoded_bin
        # The JSON-string encoding must be lossless as well.
        decoded_json = jstr2imem(imem2jstr(mem))
        assert mem == decoded_json

    sample = {Ref('a'): ival(0), Ref('b'): ival(1), Ref('c'): ival(2)}
    _roundtrip(mkmap(sample))
def test_serexpr2():
    """Generate 1000 expressions and check each survives JSON-string and binary round-trips."""
    wlib = mkwlib(lib_arith, 5)
    imem: IMem = mkmap({Ref('a'): ival(0), Ref('b'): ival(1), Ref('c'): ival(2)})
    gen = genexpr(wlib, [imem])
    for _ in range(1000):
        top, mem, _vals, _weights = next(gen)
        original = gather(top, mem)
        assert original == jstr2expr(expr2jstr(original))
        assert original == bin2expr(expr2bin(original))
def interp(expr: Expr, lib: Lib, m: IMem) -> Tuple[IExpr, IMem]:
    """Evaluate `expr` using library `lib` and memory `m`.

    Returns a pair `(value, memory)`. For `Let` and lambda application the
    returned memory is the one extended via `.set` with the new binding;
    `Val`, `Lam` and `Intrin` return `m` unchanged. Each intrinsic argument
    is evaluated against the incoming `m`, and any memory it produces is
    discarded.

    Raises:
        ValueError: on an unknown `Val` payload, a non-`ILam` value in
            application position, or an unsupported expression node.
    """
    if isinstance(expr, Val):
        payload = expr.val
        if isinstance(payload, Ref):
            # Variable lookup; a missing name surfaces as whatever `m[...]` raises.
            return (m[payload], m)
        if isinstance(payload, Const):
            return (IVal(payload.const), m)
        raise ValueError(f"Invalid value {expr}")

    if isinstance(expr, Let):
        bound, after_bind = interp(expr.expr, lib, m)
        return interp(expr.body, lib, after_bind.set(expr.ref, bound))

    if isinstance(expr, Ap):
        callee, after_func = interp(expr.func, lib, m)
        argument, after_arg = interp(expr.arg, lib, after_func)
        if not isinstance(callee, ILam):
            raise ValueError(f"Invalid callable {callee}")
        # `ILam` stores only a parameter name and a body, so the parameter is
        # bound into the memory current at the call site.
        return interp(callee.body, lib, after_arg.set(Ref(callee.name), argument))

    if isinstance(expr, Lam):
        return (ILam(expr.name, expr.body), m)

    if isinstance(expr, Intrin):
        # Resolve the library entry before evaluating any argument.
        entry = lib[expr.name]
        iargs = {name: interp(arg, lib, m)[0] for name, arg in expr.args.items()}
        return (entry.impl(iargs), m)

    raise ValueError(f"Invalid expression {expr}")
def test_refs() -> None:
    """Check `refs_`, `refs` and `decls` on a nested let-expression.

    `b` is declared but never used in the innermost intrinsic, while `c` is
    used without being declared.
    """
    def _outer_body(a):
        return let_('b', num(42),
                    lambda b: intrin(MethodName("add"), [('a', a), ('c', ref('c'))]))

    e = let_('a', num(33), _outer_body)
    assert refs_(e) == {Ref('a'), Ref('b'), Ref('c')}
    assert refs(e) == {Ref('a'), Ref('c')}
    assert decls(e) == {Ref('a'), Ref('b')}
def _make(b: Build):
    """Write `batch_size` random input memories as JSON to the build's `out_inputs` path.

    NOTE(review): `num_inputs`, `batch_size`, `range_min` and `range_max` are
    not defined here; they presumably come from an enclosing scope — confirm.
    """
    build_setoutpaths(b, 1)
    # Inputs shared by every generated memory; currently none are enabled.
    builtin_inputs: dict = {
        # Ref("i0"):IVal(0),
        # Ref("i1"):IVal(1)
    }
    # Random inputs are numbered after the builtin ones.
    offset = len(builtin_inputs)

    def _random_mem():
        drawn = {Ref(f"i{offset + k}"): IVal(randint(range_min, range_max))
                 for k in range(num_inputs)}
        return mkmap(_union(builtin_inputs, drawn))

    imems: List[IMem] = [_random_mem() for _ in range(batch_size)]
    writejson(mklens(b).out_inputs.syspath, [imem2json(mem) for mem in imems])
def test_gengather() -> None:
    """Every expression `gengather` produces for `c` must evaluate to `c`'s known value.

    `mem` defines `c = add(a, neg(i))` with `a = 33`; `imem` holds the
    expected value of each ref for the input `i = 32`.
    """
    mn = MethodName
    mem = Mem({
        Ref('a'): num(33),
        Ref('b'): intrin(mn('neg'), [('a', ref('i'))]),
        Ref('c'): intrin(mn('add'), [('a', ref('a')), ('b', ref('b'))]),
    })
    imem = IMem({
        Ref('i'): IVal(32),
        Ref('a'): IVal(33),
        Ref('b'): IVal(-32),
        Ref('c'): IVal(1),
    })
    checked = 0
    for e in gengather(Ref('c'), mem):
        # Give the interpreter only the refs this expression leaves external.
        # Compute them once rather than once per candidate key.
        ext = extrefs(e)
        inps = IMem({k: v for k, v in imem.items() if k in ext})
        iexpr, _ = interp(e, lib_arith, inps)
        print(print_expr(e), iexpr, inps.dict)
        assert iexpr == imem[Ref('c')]
        checked += 1
    # Without this, an empty `gengather` result would make the test pass vacuously.
    assert checked > 0, "gengather produced no expressions"
def test_example():
    """Write an `Example` twice via `examples2fd` and read both back via `fd2examples`."""
    import os
    import tempfile

    e = Example(inp=mkmap({Ref('a'): ival(0)}),
                expr=ap(ref('a'), ref('b')),
                out=IVal(33))
    # Use a private temporary directory instead of a fixed '/tmp' path: it
    # works on every platform, concurrent runs cannot clobber each other, and
    # the file is removed afterwards.
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, 'binfile')
        with open(path, 'wb') as f:
            add = examples2fd(f)
            add(e)
            add(e)
        with open(path, 'rb') as f:
            _next = fd2examples(f)
            assert _next() == e
            assert _next() == e
def test_gather() -> None:
    """`gather` inlines `c = add(a, neg(i))` into one expression whose only external ref is `i`."""
    mn = MethodName
    defs = {
        Ref('a'): num(33),
        Ref('b'): intrin(mn('neg'), [('a', ref('i'))]),
        Ref('c'): intrin(mn('add'), [('a', ref('a')), ('b', ref('b'))]),
    }
    expr = gather(Ref('c'), Mem(defs))
    assert extrefs(expr) == {Ref('i')}
    result, _ = interp(expr, lib_arith, mkmap({Ref('i'): ival(32)}))
    # Expected: add(33, neg(32)).
    assert result == IVal(1)
def test_genexpr() -> None:
    """Check 1000 `genexpr` expressions against the values it reports for a single input."""
    wlib = mkwlib(lib_arith, 5)
    imem: IMem = mkmap({Ref('a'): ival(0), Ref('b'): ival(1), Ref('c'): ival(2)})
    allowed = {Ref('a'), Ref('b'), Ref('c')}
    gen = genexpr(wlib, [imem])
    for _ in range(1000):
        top, mem, vals, _weights = next(gen)
        expr = gather(top, mem)
        value, _ = interp(expr, lib_arith, imem)
        # One input memory was supplied, so exactly one value map is reported.
        assert len(vals) == 1
        assert value == vals[0][top]
        free = extrefs(expr)
        assert len(free) > 0
        assert free.issubset(allowed)
        print(print_expr(expr), value)
def let_(name: str, expr: Expr, body: Callable[[Expr], Expr]) -> Expr:
    """Build a `Let` that binds `name` to `expr` in the body produced by `body`.

    `body` is called with a reference expression for `name`, so nested lets
    can be written as Python lambdas.
    """
    binder = Ref(str(name))
    return Let(binder, expr, body(ref(name)))
def ref(x: str) -> Expr:
    """Return a `Val` expression referring to the name `x` (coerced with `str`)."""
    target = Ref(str(x))
    return Val(target)
def json2imem(d: dict) -> IMem:
    """Decode a JSON-derived dict of `name -> encoded value` into an `IMem`."""
    decoded = {}
    for name, encoded in d.items():
        key = Ref(name)
        decoded[key] = json2iexpr(encoded)
    return TMap(decoded)
def bin2imem(d: BIN) -> IMem:
    """Decode a `Tag.imem` binary node into an `IMem`.

    Each list entry is a tuple whose first field is the ref name and whose
    second field is an encoded value decoded via `bin2iexpr`.
    """
    assert d.tag == Tag.imem, f"Unexpected tag {d.tag}"
    decoded = {}
    for item in d.value.list.list:
        pair = item.tuple
        key = Ref(str(pair.v1.string))
        decoded[key] = bin2iexpr(pair.v2.node)
    return TMap(decoded)
def genexpr(
    wlib: WLib,
    inputs: List[IMem]
) -> Iterator[Tuple[Ref, TMap[Ref, Expr], List[IMem], Dict[Ref, int]]]:
    """
    Iterate over the space of lambda-expressions with `len(inputs[0])` input
    arguments. For every expression visited, provide the results of its
    evaluation on every input of the `inputs` list.

    Arguments:
    * `wlib`: Weighted library of primitive operations. See also `mkwlib`.
    * `inputs`: Collection of the inputs on which to grow the expression. All
      inputs in the list should be of the same size and use the same names.

    Yields a tuple of:
    * Top-level reference `Ref` of the newly created expression.
    * Map of `Ref -> Intrin`, where the `Intrin`s may contain more refs from
      the same map.
    * List of memories, one per entry of `inputs` (so its size equals
      `len(inputs)`); each maps every ref known so far to its value on that
      input.
    * Map of `Ref -> weight` for every ref known so far (inputs have weight
      1). NOTE: this is the generator's live internal dict; it keeps growing
      as iteration continues.

    Candidates whose evaluation gives an `IError` on any input, or an integer
    with absolute value above 10000, are skipped and not yielded.
    """
    # All inputs should provide the same input names
    assert all([i.keys() == inputs[0].keys() for i in inputs])
    nbatch = len(inputs)
    lib = {k: wl[0] for k, wl in wlib.items()}
    libws = {k: wl[1] for k, wl in wlib.items()}
    # `lib` is never modified, so wrap it for `interp` once, not per evaluation.
    tlib = TMap(lib)
    exprcache: Dict[Ref, Expr] = {}
    exprw: Dict[Ref, int] = {k: 1 for k in inputs[0].keys()}
    valcache: List[Dict[Ref, IExpr]] = [OrderedDict(i.dict) for i in inputs]
    W = 0
    while True:
        W += 1
        for op in lib.values():
            w = libws[op.name]
            nargs = len(op.argnames)
            # Snapshot of the refs known before this op is tried; refs added
            # below become candidates only for the next op.
            vws: List[Tuple[Ref, int]] = list(exprw.items())
            for valindices in permute(weights=[a[1] for a in vws], nargs=nargs,
                                      target_weight=W - w):
                argrefs: List[Ref] = [vws[i][0] for i in valindices]
                assert len(op.argnames) == len(argrefs)
                e2name = Ref(mkname('val'))
                e2expr = intrin(op.name, [(nm, Val(ai)) for nm, ai in zip(op.argnames, argrefs)])
                # TODO: Make this block customizable via callbacks
                err = False
                acc: List[IExpr] = []
                for b in range(nbatch):
                    e2val, _ = interp(e2expr, tlib, TMap(valcache[b]))
                    if isinstance(e2val, IError):
                        err = True
                        break
                    if isinstance(e2val, IVal) and isinstance(e2val.val, int):
                        # Reject integers that grew too large in magnitude.
                        # (`abs` is never negative, so one bound suffices.)
                        if abs(e2val.val) > 10000:
                            err = True
                            break
                    acc.append(e2val)
                if err:
                    continue
                for b in range(nbatch):
                    valcache[b][e2name] = acc[b]
                exprcache[e2name] = e2expr
                exprw[e2name] = W
                yield (e2name, TMap(exprcache), [TMap(fd) for fd in valcache], exprw)