def test_mergemap()->None:
    """Exercise `mergemap` on overlapping keys and on an empty operand on either side."""
    def _merged(lhs:TMap[int,int], rhs:TMap[int,int], combine)->TMap[int,int]:
        return mergemap(lhs, rhs, combine)

    # (left, right, combining function, expected merge result)
    cases = [
        # key 1 occurs on both sides and is combined; keys 2 and 3 are carried over
        (mkmap({1: 1, 2: 2}), mkmap({1: 10, 3: 30}), lambda x, y: x + y,
         mkmap({1: 11, 2: 2, 3: 30})),
        # empty right-hand side
        (mkmap({1: 1, 3: 3}), mkmap({}), lambda x, y: int(x + y),
         mkmap({1: 1, 3: 3})),
        # empty left-hand side
        (mkmap({}), mkmap({1: 1, 3: 3}), lambda x, y: int(x + y),
         mkmap({1: 1, 3: 3})),
    ]
    for lhs, rhs, combine, expected in cases:
        assert _merged(lhs, rhs, combine) == expected
def test_serimem():
    """Round-trip a small memory through the binary and JSON-string memory codecs."""
    def _roundtrip(mem):
        # both codec pairs must reproduce the memory exactly
        assert mem == bin2imem(imem2bin(mem))
        assert mem == jstr2imem(imem2jstr(mem))

    sample = mkmap({Ref(name): ival(idx) for idx, name in enumerate('abc')})
    _roundtrip(sample)
def test_let2()->None:
    """Two nested `let` bindings feed an `add` intrinsic; the result must be 33+42."""
    # Lambda parameter names are kept as-is in case `let` inspects them.
    program = let(num(33), lambda a:
                  let(num(42), lambda b:
                      intrin(MethodName("add"), [('a', a), ('b', b)])))
    result, _ = interp(program, lib_arith, mkmap())
    assert isinstance(result, IVal)
    assert result.val == 33 + 42
def _test_example_pickle(i:int)->Example:
    """Encode/decode an Example through `ex2bin`/`bin2ex` and return the decoded copy.

    The expression is picked from EXPRS cyclically by `i`, and `IVal(i)` is used as
    the output. The decoded expression is validated with `check_expr` before returning.
    """
    chosen = EXPRS[i % len(EXPRS)]
    source = Example(inp=mkmap({}), expr=chosen, out=IVal(i))
    decoded = bin2ex(ex2bin(source))
    check_expr(decoded.expr)
    return decoded
def freqs(e: Expr) -> TMap[MethodName, int]:
    """ Return intrinsic usage statistics for `e`.

    Walks the expression tree and maps each intrinsic's name to how many
    `Intrin` nodes carry it. Counts from sub-expressions are combined by
    addition via `mergemap`. Raises ValueError on an unrecognised node type.
    """
    def _sum(lhs, rhs):
        return mergemap(lhs, rhs, lambda x, y: x + y)

    # The isinstance checks keep their original order in case node classes overlap.
    if isinstance(e, Val):
        return mkmap({})
    if isinstance(e, Ap):
        return _sum(freqs(e.func), freqs(e.arg))
    if isinstance(e, Lam):
        return freqs(e.body)
    if isinstance(e, Let):
        return _sum(freqs(e.expr), freqs(e.body))
    if isinstance(e, Intrin):
        # Only the intrinsic's own name is counted; its arguments are not visited.
        # NOTE(review): if intrinsic arguments can be compound expressions, their
        # intrinsics are not counted here — confirm this is intended.
        return mkmap({e.name: 1})
    raise ValueError(f"Invalid expression '{e}'")
def test_gather()->None:
    """`gather` builds the expression for `c` from a small Mem.

    Only the undefined `i` should remain as an external ref. With i=32 the
    expression is expected to evaluate to IVal(1).
    """
    m = MethodName
    definitions = {
        Ref('a'): num(33),
        Ref('b'): intrin(m('neg'), [('a', ref('i'))]),
        Ref('c'): intrin(m('add'), [('a', ref('a')), ('b', ref('b'))]),
    }
    mem = Mem(definitions)
    expr = gather(Ref('c'), mem)
    assert extrefs(expr) == {Ref('i')}
    value, _ = interp(expr, lib_arith, mkmap({Ref('i'): ival(32)}))
    assert value == IVal(1)
def _make(b:Build):
    """Sample `batch_size` random input memories and write them as JSON to `out_inputs`.

    Each memory binds refs `i<k>` (numbered after any fixed inputs) to
    `IVal(randint(range_min, range_max))`. `range_min`, `range_max`, `num_inputs`
    and `batch_size` are free names resolved outside this function (not visible here).
    """
    build_setoutpaths(b, 1)
    # Fixed inputs merged into every sample. None are enabled at the moment;
    # earlier candidates were Ref("i0"):IVal(0) and Ref("i1"):IVal(1).
    fixed_inputs:dict = {}
    offset = len(fixed_inputs)  # random refs are numbered after the fixed ones

    def _sample():
        randoms = {Ref(f"i{k + offset}"): IVal(randint(range_min, range_max))
                   for k in range(num_inputs)}
        return mkmap(_union(fixed_inputs, randoms))

    memories:List[IMem] = [_sample() for _ in range(batch_size)]
    writejson(mklens(b).out_inputs.syspath, [imem2json(mem) for mem in memories])
def test_serexpr2():
    """Gathered `genexpr` samples survive both expression codecs.

    The codecs are `expr2jstr`/`jstr2expr` and `expr2bin`/`bin2expr`.
    """
    wlib = mkwlib(lib_arith, 5)
    imem:IMem = mkmap({Ref(name): ival(idx) for idx, name in enumerate('abc')})
    stream = genexpr(wlib, [imem])
    for _ in range(1000):
        root, mem, _vals, _weight = next(stream)
        expr = gather(root, mem)
        assert expr == jstr2expr(expr2jstr(expr))
        assert expr == bin2expr(expr2bin(expr))
def test_genexpr()->None:
    """Check 1000 `genexpr` samples against the generator's own report.

    For each sample, interpreting the gathered expression must reproduce the
    value the generator reported. Its external refs must be a non-empty subset
    of the input refs a, b, c.
    """
    wlib = mkwlib(lib_arith, 5)
    inputs = [Ref('a'), Ref('b'), Ref('c')]
    imem:IMem = mkmap({r: ival(idx) for idx, r in enumerate(inputs)})
    allowed = set(inputs)
    stream = genexpr(wlib, [imem])
    for _ in range(1000):
        root, mem, vals, _weight = next(stream)
        expr = gather(root, mem)
        value, _ = interp(expr, lib_arith, imem)
        assert len(vals) == 1
        assert value == vals[0][root]
        free = extrefs(expr)
        assert len(free) > 0
        assert free.issubset(allowed)
        print(print_expr(expr), value)
def test_example():
    """Two Examples written with `examples2fd` are read back unchanged by `fd2examples`.

    The binary file lives in a private temporary directory rather than the fixed
    path `/tmp/binfile`. This keeps the test portable, prevents collisions
    between concurrent test runs, and removes the file afterwards.
    """
    import os
    import tempfile

    e = Example(inp=mkmap({Ref('a'): ival(0)}), expr=ap(ref('a'), ref('b')), out=IVal(33))
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, 'binfile')
        with open(path, 'wb') as f:
            add = examples2fd(f)
            add(e)
            add(e)
        # The writer's file is closed (and flushed) before reading it back.
        with open(path, 'rb') as f:
            _next = fd2examples(f)
            n = _next()
            assert n == e
            n = _next()
            assert n == e
def test_let1()->None:
    """A `let` whose body is just the bound variable evaluates to the bound value."""
    result, _ = interp(let(num(33), lambda x: x), lib_arith, mkmap())
    assert isinstance(result, IVal)
    assert result.val == 33
def ival(n:int)->IVal:
    """Return the IVal obtained by interpreting the literal `num(n)` with `lib_arith`."""
    literal = num(n)
    value, _ = interp(literal, lib_arith, mkmap())
    assert isinstance(value, IVal)
    return value
def labels(e:Example, l:Lib=lib_arith)->Tensor:
    """Return the normalised intrinsic-usage histogram of `e.expr` as a Tensor.

    Every method of library `l` gets a zero-initialised slot. The counts from
    `freqs(e.expr)` are merged in by addition, and each count is divided by the
    total, so the entries sum to 1 (up to rounding). Slot order follows the
    iteration order of the merged map (presumably library keys first — confirm
    against `mergemap`).

    Raises ZeroDivisionError when the merged map is non-empty but its total is
    zero, i.e. `e.expr` uses no intrinsic.
    """
    counts = mergemap(mkmap({mn: 0 for mn in l.keys()}), freqs(e.expr), lambda x, y: x + y)
    # Compute the normaliser once instead of re-summing for every entry (was O(n^2)).
    total = sum(counts.values())
    fractions = [n / total for _, n in counts.items()]
    return Tensor(fractions)