def test_seriexpr():
    """Check that sample IExprs survive both serialization round trips.

    Each sample is encoded and decoded through the binary codec
    (iexpr2bin / bin2iexpr) and the JSON-string codec
    (iexpr2jstr / jstr2iexpr); the decoded value must equal the input.
    """
    def _roundtrips(sample):
        via_bin = bin2iexpr(iexpr2bin(sample))
        assert sample == via_bin
        via_jstr = jstr2iexpr(iexpr2jstr(sample))
        assert sample == via_jstr

    _roundtrips(IVal(33))
    _roundtrips(IVal("foo"))
    # A lambda over `pat` whose body is the `neg` intrinsic applied to `pat`,
    # itself applied to 42 -- exercises the ILam / IAp encodings.
    negator = ILam("pat", intrin(MethodName('neg'), [('a', ref('pat'))]))
    _roundtrips(IAp(negator, IVal(42)))
    _roundtrips(IError("the message"))
def interp(expr: Expr, lib: Lib, m: IMem) -> Tuple[IExpr, IMem]:
    """Evaluate ``expr`` using intrinsic library ``lib`` and memory ``m``.

    Returns a pair ``(value, memory)``, where ``memory`` is the memory
    after evaluation (``Let`` and ``Ap`` evaluate their body in an extended
    memory; the other cases return ``m`` unchanged).

    Notes:
      * ``Lam`` evaluates to ``ILam(name, body)`` without capturing ``m``.
        When applied, the body is evaluated in the caller's memory extended
        with the parameter binding, so free variables resolve dynamically.
      * ``Intrin`` arguments are each evaluated against ``m``; any memory
        they produce is discarded. The resulting dict (argument name to
        value) is passed to ``lib[expr.name].impl``.

    Raises:
        ValueError: for an unknown value kind, a non-lambda in call
            position, or an unknown expression kind.
    """
    if isinstance(expr, Val):
        value = expr.val
        if isinstance(value, Ref):
            return (m[value], m)
        if isinstance(value, Const):
            return (IVal(value.const), m)
        raise ValueError(f"Invalid value {expr}")

    if isinstance(expr, Let):
        bound, mem_after_bind = interp(expr.expr, lib, m)
        return interp(expr.body, lib, mem_after_bind.set(expr.ref, bound))

    if isinstance(expr, Ap):
        # Callee first, then argument, threading memory through both.
        callee, mem_after_func = interp(expr.func, lib, m)
        argument, mem_after_arg = interp(expr.arg, lib, mem_after_func)
        if not isinstance(callee, ILam):
            raise ValueError(f"Invalid callable {callee}")
        return interp(callee.body, lib,
                      mem_after_arg.set(Ref(callee.name), argument))

    if isinstance(expr, Lam):
        return (ILam(expr.name, expr.body), m)

    if isinstance(expr, Intrin):
        # Look up the intrinsic before evaluating any argument.
        entry = lib[expr.name]
        evaluated = {
            arg_name: interp(arg_expr, lib, m)[0]
            for arg_name, arg_expr in expr.args.items()
        }
        return (entry.impl(evaluated), m)

    raise ValueError(f"Invalid expression {expr}")
def _test_example_pickle(i: int) -> Example:
    """Build an Example from ``EXPRS`` (cycled by ``i``), round-trip it through
    the binary codec (ex2bin / bin2ex), validate the decoded expression with
    ``check_expr`` and return the decoded Example.

    Despite the name, the round trip uses ``ex2bin``/``bin2ex``, not ``pickle``.
    """
    source = Example(
        inp=mkmap({}),
        expr=EXPRS[i % len(EXPRS)],
        out=IVal(i),
    )
    decoded = bin2ex(ex2bin(source))
    check_expr(decoded.expr)
    return decoded
def bin2iexpr(j: BIN) -> IExpr:
    """Decode a binary node ``j`` into an interpreter value (IExpr).

    Dispatches on ``j.tag``:
      * ``Tag.ival``   -> IVal from the ``int64`` or ``string`` field of ``j.value``
      * ``Tag.ilam``   -> ILam(name, body); the body is an uninterpreted Expr,
                          hence decoded with ``bin2expr``
      * ``Tag.iap``    -> IAp(func, arg), both decoded recursively
      * ``Tag.ierror`` -> IError(message)

    NOTE(review): ``j.value`` looks like a protobuf message with a oneof
    (``HasField``) -- confirm against the schema.

    Raises:
        ValueError: on an IVal with neither field set, or an unknown tag.
    """
    tag = j.tag
    if tag == Tag.ival:
        payload = j.value
        if payload.HasField('int64'):
            return IVal(int(payload.int64))
        if payload.HasField('string'):
            return IVal(str(payload.string))
        raise ValueError(f"Invalid value expression {j}")
    if tag == Tag.ilam:
        pair = j.value.tuple
        return ILam(str(pair.v1.string), bin2expr(pair.v2.node))
    if tag == Tag.iap:
        pair = j.value.tuple
        return IAp(bin2iexpr(pair.v1.node), bin2iexpr(pair.v2.node))
    if tag == Tag.ierror:
        return IError(str(j.value.string))
    raise ValueError(f"Invalid expression {_flat(j)}")
def json2iexpr(j: dict) -> IExpr:
    """Decode a JSON-style dict into an interpreter value (IExpr).

    Expected shapes, keyed by ``j['t']``:
      * ``'ival'``   -> ``{'ival': {'t': 'int' | 'str', 'val': ...}}``
      * ``'ilam'``   -> ``{'name': ..., 'body': <Expr json>}`` (body via json2expr)
      * ``'iap'``    -> ``{'func': <IExpr json>, 'arg': <IExpr json>}``
      * ``'ierror'`` -> ``{'msg': ...}``

    Raises:
        ValueError: for an unknown value type or expression tag.
        KeyError: if a required key is missing.
    """
    kind = j['t']
    if kind == 'ival':
        payload = j['ival']
        value_kind = payload['t']
        if value_kind == 'int':
            return IVal(int(payload['val']))
        if value_kind == 'str':
            return IVal(str(payload['val']))
        raise ValueError(f"Invalid value expression {j}")
    if kind == 'ilam':
        return ILam(j['name'], json2expr(j['body']))
    if kind == 'iap':
        return IAp(json2iexpr(j['func']), json2iexpr(j['arg']))
    if kind == 'ierror':
        return IError(j['msg'])
    raise ValueError(f"Invalid expression {j}")
def test_gengather() -> None:
    """Every expression produced by ``gengather`` for ``c`` must evaluate to c's value.

    Memory: a = 33, b = neg(i), c = add(a, b). With i = 32 the reference
    values are a = 33, b = -32, c = 1. Each generated expression is
    evaluated with only the reference values it mentions (``extrefs``).
    The test also requires at least one generated expression, so that it
    cannot pass vacuously.
    """
    mn = MethodName
    mem = Mem({
        Ref('a'): num(33),
        Ref('b'): intrin(mn('neg'), [('a', ref('i'))]),
        Ref('c'): intrin(mn('add'), [('a', ref('a')), ('b', ref('b'))]),
    })
    imem = IMem({
        Ref('i'): IVal(32),
        Ref('a'): IVal(33),
        Ref('b'): IVal(-32),
        Ref('c'): IVal(1),
    })
    expected = imem[Ref('c')]
    checked = 0
    for e in gengather(Ref('c'), mem):
        # extrefs(e) does not depend on the memory entry being filtered;
        # compute it once per expression rather than once per entry.
        refs = extrefs(e)
        inps = IMem({k: v for k, v in imem.items() if k in refs})
        iexpr, _ = interp(e, lib_arith, inps)
        print(print_expr(e), iexpr, inps.dict)
        assert iexpr == expected
        checked += 1
    # Without this, an empty gengather result would make the test pass trivially.
    assert checked > 0, "gengather produced no expressions"
def _make(b: Build):
    """Write ``batch_size`` random input memories as JSON to the build's
    ``out_inputs`` path.

    Each memory is built by ``mkmap`` from ``builtin_inputs`` merged with
    ``num_inputs`` generated bindings ``i<k>`` (``k`` offset by the number of
    builtin inputs), each bound to ``IVal(randint(range_min, range_max))``.

    NOTE(review): ``range_min``, ``range_max``, ``num_inputs``, ``batch_size``
    and ``randint`` are free names resolved outside this function (probably
    an enclosing scope) -- confirm their source.
    """
    build_setoutpaths(b, 1)
    # Bindings shared by every memory. Currently empty; fixed inputs such as
    # Ref("i0"): IVal(0) and Ref("i1"): IVal(1) used to be listed here.
    builtin_inputs: dict = {}
    # Generated inputs are numbered after the builtin ones.
    offset = len(builtin_inputs.keys())

    def _random_memory():
        generated = {
            Ref(f"i{n + offset}"): IVal(randint(range_min, range_max))
            for n in range(num_inputs)
        }
        return mkmap(_union(builtin_inputs, generated))

    IMEMs: List[IMem] = [_random_memory() for _ in range(batch_size)]
    writejson(mklens(b).out_inputs.syspath, [imem2json(M) for M in IMEMs])
def test_gather() -> None:
    """``gather`` must inline a and b into c, leaving only input i free.

    Memory: a = 33, b = neg(i), c = add(a, b). Evaluating the gathered
    expression with i = 32 must give 33 + (-32) = 1.
    """
    mn = MethodName
    ref_a, ref_b, ref_c = Ref('a'), Ref('b'), Ref('c')
    mem = Mem({
        ref_a: num(33),
        ref_b: intrin(mn('neg'), [('a', ref('i'))]),
        ref_c: intrin(mn('add'), [('a', ref('a')), ('b', ref('b'))]),
    })
    gathered = gather(ref_c, mem)
    # a and b are inlined, so i is the only external reference left.
    assert extrefs(gathered) == {Ref('i')}
    result, _ = interp(gathered, lib_arith, mkmap({Ref('i'): ival(32)}))
    assert result == IVal(1)
def test_example():
    """Write the same Example twice with ``examples2fd`` and check that
    ``fd2examples`` reads both back equal to the original.

    Uses a private temporary directory rather than a fixed ``/tmp`` path.
    Concurrent test runs therefore cannot clobber each other, the test
    works on platforms without ``/tmp``, and nothing is left behind.
    """
    import os
    import tempfile

    e = Example(inp=mkmap({Ref('a'): ival(0)}), expr=ap(ref('a'), ref('b')), out=IVal(33))
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, 'binfile')
        with open(path, 'wb') as f:
            add = examples2fd(f)
            add(e)
            add(e)
        with open(path, 'rb') as f:
            _next = fd2examples(f)
            n = _next()
            assert n == e
            n = _next()
            assert n == e
def _mul(a, b) -> IExpr:
    """Return the product ``a * b`` wrapped in an IVal."""
    product = a * b
    return IVal(product)
def _add(a, b) -> IExpr:
    """Return the sum ``a + b`` wrapped in an IVal."""
    total = a + b
    return IVal(total)
def _prec(a) -> IExpr:
    """Return the predecessor ``a - 1`` wrapped in an IVal."""
    predecessor = a - 1
    return IVal(predecessor)
def _succ(a) -> IExpr:
    """Return the successor ``a + 1`` wrapped in an IVal."""
    successor = a + 1
    return IVal(successor)
def _sqrt(a) -> IExpr:
    """Return the floor of the square root of ``a`` as an IVal.

    For a negative ``a`` (or anything for which ``a >= 0`` is false, e.g.
    NaN) returns ``IError('Sqrt from negative')``.

    Integer inputs use exact integer arithmetic (``math.isqrt``, Python
    3.8+). ``int(sqrt(a))`` goes through a float and is wrong for large
    ints: for a = 4503599761588224 it yields 67108865, whereas the true
    floor is 67108864. It also raises OverflowError once ``a`` exceeds the
    float range. Non-int inputs, or Pythons without ``math.isqrt``, keep the
    previous ``int(sqrt(a))`` behaviour.
    """
    if a >= 0:
        if isinstance(a, int):
            import math
            if hasattr(math, 'isqrt'):  # added in Python 3.8
                return IVal(math.isqrt(a))
        return IVal(int(sqrt(a)))
    return IError('Sqrt from negative')
def _sqr(a) -> IExpr:
    """Return ``a`` squared (computed as ``a * a``) wrapped in an IVal."""
    square = a * a
    return IVal(square)
def _neg(a) -> IExpr:
    """Return the arithmetic negation ``-a`` wrapped in an IVal."""
    negated = -a
    return IVal(negated)
def _div(a, b) -> IExpr:
    """Return the floor quotient ``a // b`` wrapped in an IVal.

    Division by zero does not raise; it yields ``IError('Division by zero')``.
    """
    try:
        quotient = a // b
    except ZeroDivisionError:
        return IError('Division by zero')
    return IVal(quotient)