def eval_expr(exp, env, out):
    """Relate the object-language expression ``exp``, evaluated in ``env``, to ``out``.

    The alternatives below are tried via ``conde`` in this order:

    * ``[quote, datum]``: ``out`` is ``datum`` itself.
      - It must satisfy ``no_item(out, is_internal)``.
      - ``missing(quote, env)`` must hold.
    * ``[list_, elems...]``: ``out`` is the list of values of ``elems``, via ``eval_list``.
      - ``missing(list_, env)`` must hold.
    * ``[rator, rand]`` (application): both positions are evaluated. Then either:
      - ``eval_builtin`` applies a builtin value, or
      - a ``Closure(param, code, captured)`` runs ``code`` in ``Env(param, rand_value, captured)``.
    * ``[lambda_, [param], code]``: ``out`` is ``Closure(param, code, env)``.
      - ``param`` must be a ``Symbol``.
      - ``missing(lambda_, env)`` must hold.
    * A ``Symbol``: ``out`` is its binding, found with ``lookup``.
    * A builtin name: mapped by ``builtin``, provided ``missing(exp, env)`` holds.
    """
    quoted = (
        eq([quote, out], exp),
        no_item(out, is_internal),
        missing(quote, env),
    )

    listed = fresh(lambda elems: conjp(
        eq([list_, elems, ...], exp),
        missing(list_, env),
        eval_list(elems, env, out),
    ))

    applied = fresh(
        5,
        lambda rator, rand, param, code, captured: conj(
            # Evaluate the two-element form [operator, operand] left to right.
            eval_list(exp, env, [rator, rand]),
            conde(
                eval_builtin(rator, rand, out),
                (
                    eq(Closure(param, code, captured), rator),
                    # Run the closure body in its captured env extended with the argument.
                    eval_expr(code, Env(param, rand, captured), out),
                ),
            ),
        ),
    )

    abstraction = fresh(
        2,
        lambda param, code: conjp(
            eq([lambda_, [param], code], exp),
            eq(Closure(param, code, env), out),
            eqt(param, Symbol),
            missing(lambda_, env),
        ),
    )

    variable = (eqt(exp, Symbol), lookup(exp, env, out))
    primitive = (builtin(exp, out), missing(exp, env))

    return conde(quoted, listed, applied, abstraction, variable, primitive)
def calc(e, o):
    """Relate an arithmetic expression ``e`` to its value ``o``.

    ``e`` takes one of two forms:

    * An ``int``, which evaluates to itself.
    * A 3-tuple ``(op, lhs, rhs)``:
      - ``op`` is one of ``"+"``, ``"-"``, ``"*"``, ``"/"``.
      - Both operands are evaluated recursively.
      - The results are combined with the matching ``add``/``sub``/``mul``/``div`` relation.

    NOTE(review): the recursive ``calc`` calls below execute while this goal
    is being built, with no ``fresh``/``zzz`` thunk around them. Unless
    ``calc`` is made lazy by something not shown here (e.g. a decorator),
    building the goal recurses without bound -- confirm.
    """
    lhs, rhs = Var(), Var()
    lhs_val, rhs_val = Var(), Var()
    op = Var()

    compound = [
        eq((op, lhs, rhs), e),
        calc(lhs, lhs_val),
        calc(rhs, rhs_val),
        conde(*[
            [eq(symbol, op), relation(lhs_val, rhs_val, o)]
            for symbol, relation in (
                ("+", add), ("-", sub), ("*", mul), ("/", div))
        ]),
    ]
    literal = [eqt(e, int), eq(e, o)]
    return conde(compound, literal)
def eval_list(lst, env, out):
    """Relate a list of expressions ``lst`` to the list ``out`` of their values.

    The empty list maps to the empty list. Otherwise the head is evaluated
    with ``eval_expr`` and the tail is handled recursively.
    """
    empty_case = (eq([], lst), eq([], out))

    def cons_case(first, rest, first_val, rest_vals):
        # In the ``[head, tail, ...]`` pattern, ``tail`` is the remaining list
        # (it is what the recursive call below walks).
        return conjp(
            eq([first, rest, ...], lst),
            eq([first_val, rest_vals, ...], out),
            eval_expr(first, env, first_val),
            eval_list(rest, env, rest_vals),
        )

    return conde(empty_case, fresh(4, cons_case))
def jugs(states):
    """Relate ``states`` to a valid move history of the two-jug puzzle.

    ``states`` is a list of ``(big, small, act)`` triples, newest first:
    ``big`` and ``small`` are the jug contents after the move and ``act``
    names the move that produced them from the next (older) entry. The
    oldest entry is always ``(0, 0, "")``. Jug capacities are the
    module-level constants ``BIG`` and ``SMALL``.
    """
    return disj(
        # Base case: only the initial state, both jugs empty.
        eq([(0, 0, "")], states),
        fresh(
            8,
            # ``_`` is the previous state's act and ``__`` the history older
            # than the previous state; both are otherwise unused.
            lambda big, small, act, prev_big, prev_small, tail, _, __: conjp(
                # ``states`` = newest (big, small, act) followed by ``tail``.
                eq([(big, small, act), tail, ...], states),
                # The previous state is the head of ``tail``.
                eq([(prev_big, prev_small, _), __, ...], tail),
                conde(
                    # Moves that change the big jug.
                    [
                        conde(
                            # Fill or empty the big jug; small is unchanged.
                            [
                                eq(small, prev_small),
                                conde(
                                    [eq(big, BIG), eq(act, "fill big")],
                                    [eq(big, 0), eq(act, "empty big")])],
                            # Pour between the jugs: the total is conserved
                            # and the pour ends with the target full or the
                            # source empty.
                            [
                                fresh(
                                    lambda total: conjp(
                                        conde(
                                            [eq(big, BIG), eq(act, "to big")],
                                            [
                                                eq(small, SMALL),
                                                eq(act, "to small")],
                                            [
                                                eq(small, 0),
                                                neq(big, BIG),
                                                eq(act, "to big")],
                                            [
                                                eq(big, 0),
                                                neq(small, SMALL),
                                                eq(act, "to small")]),
                                        add(big, small, total),
                                        add(prev_big, prev_small, total)))]),
                        # Rule out moves that leave the big jug unchanged.
                        neq(big, prev_big)],
                    # Moves that change only the small jug.
                    [
                        eq(big, prev_big),
                        conde(
                            [eq(small, SMALL), eq(act, "fill small")],
                            [eq(small, 0), eq(act, "empty small")]),
                        neq(small, prev_small)]),
                # Both jugs stay within 0..capacity.
                gte(big, 0),
                lte(big, BIG),
                gte(small, 0),
                lte(small, SMALL),
                # The older history must be valid as well.
                jugs(tail),
            )))
def lookup(o, v, x):
    """Relate type environment ``o`` and name ``v`` to the type ``x`` of ``v``.

    ``o`` is a chain of ``((name, binding), rest)`` pairs. The first entry
    whose name equals ``v`` is used, so earlier entries shadow later ones.

    * A ``TMono(t)`` binding gives ``t`` directly.
    * A ``TPoly(def_env, def_expr)`` binding re-infers ``def_expr`` in
      ``def_env`` against the requested type ``x`` each time it is looked up.
    """
    def scan(name, binding, rest):
        shape = eq(o, ((name, binding), rest))
        monomorphic = [eq(name, v), eq(binding, TMono(x))]

        def polymorphic(def_env, def_expr, def_ann):
            return conjp(
                eq(name, v),
                eq(binding, TPoly(def_env, def_expr)),
                infer(def_expr, def_env, x, def_ann),
            )

        poly = fresh(3, polymorphic)
        further = [neq(name, v), lookup(rest, v, x)]
        return conj(shape, conde(monomorphic, poly, further))

    return fresh(3, scan)
def infer(expr, env, typ, ann):
    """Relate ``expr``, typed in environment ``env``, to its type ``typ``.

    ``ann`` is ``expr`` rebuilt with type information:

    * Symbol references become ``Ann(Sym(name), type)``.
    * Lambda and ``Rec`` binders are wrapped in ``Ann``.
    * ``Let`` keeps its bound value unannotated.
    * Literals are returned unchanged.

    ``env`` is a chain of ``((name, binding), rest)`` pairs whose bindings
    are ``TMono(type)`` or ``TPoly(env, expr)``.

    ``Let`` binds its name to ``TPoly(env, value)``, so the value is
    re-inferred wherever the name is looked up. The value is also checked
    once at the binding site. Without that check, an ill-typed binding whose
    name the body never uses would be accepted.
    """
    return conde(
        # Symbol: its type comes from the environment.
        fresh(lambda n: conjp(
            eq(expr, Sym(n)),
            lookup(env, n, typ),
            eq(ann, Ann(expr, typ)),
        )),
        # Application: the argument's type is the function's domain.
        fresh(
            5,
            lambda f, fa, a, at, aa: conjp(
                eq(expr, App(f, a)),
                infer(a, env, at, aa),
                infer(f, env, TFunc(at, typ), fa),
                eq(ann, App(fa, aa)),
            )),
        # Abstraction: the parameter is monomorphic inside the body.
        fresh(
            5,
            lambda v, vt, b, bt, ba: conjp(
                eq(expr, Abs(Sym(v), b)),
                infer(b, ((v, TMono(vt)), env), bt, ba),
                eq(typ, TFunc(vt, bt)),
                eq(ann, Abs(Ann(Sym(v), vt), ba)),
            )),
        # Let: the body sees ``n`` as a polymorphic (re-inferred) binding.
        fresh(
            4,
            lambda n, v, b, bm: conjp(
                eq(expr, Let(Sym(n), v, b)),
                infer(b, ((n, TPoly(env, v)), env), typ, bm),
                # Otherwise ``v`` is only checked at uses of ``n``, so an
                # unused, ill-typed binding would slip through.
                fresh(2, lambda vt, vm: infer(v, env, vt, vm)),
                eq(ann, Let(Sym(n), v, bm)),
            )),
        # Rec: ``n`` is bound monomorphically to the type of the whole term.
        fresh(
            3,
            lambda n, b, bm: conjp(
                eq(expr, Rec(Sym(n), b)),
                infer(b, ((n, TMono(typ)), env), typ, bm),
                eq(ann, Rec(Ann(Sym(n), typ), bm)),
            )),
        # If: boolean condition; both branches share the result type.
        fresh(
            6,
            lambda c, cm, a, am, b, bm: conjp(
                eq(expr, If(c, a, b)),
                infer(c, env, TTerm("bool"), cm),
                infer(a, env, typ, am),
                infer(b, env, typ, bm),
                eq(ann, If(cm, am, bm)),
            )),
        # NOTE(review): ``bool`` is a subclass of ``int`` in Python. If
        # ``eqt`` is isinstance-based, ``True``/``False`` also match this
        # clause and get type ``int`` -- confirm how ``eqt`` checks types.
        [eqt(expr, int), eq(typ, TTerm("int")), eq(ann, expr)],
        [eqt(expr, bool), eq(typ, TTerm("bool")), eq(ann, expr)],
    )
def delay_goal(a, b):
    """Fixture used by GOALS_DATA: recursion with no explicit thunk.

    This mirrors ``zzz_goal`` but calls itself directly. GOALS_DATA expects
    ``delay_goal(1, c)`` to yield ``[(1, 1), (1, 2)]``, the same answers as
    ``zzz_goal(1, c)``.

    NOTE(review): both recursive calls are evaluated while this goal is being
    built, whatever the value of ``a``. This only terminates if
    ``delay_goal`` is made lazy by something not shown here (e.g. a
    decorator) -- confirm.
    """
    return conde(
        eq(b, (1, a)),
        # a == 1: recurse with 2, contributing the answer (1, 2).
        (eq(a, 1), delay_goal(2, b)),
        # a == 2: recurse with 3 (only b == (1, 3)), then require a == b.
        # 2 never unifies with (1, 3), so this clause can never succeed.
        (eq(a, 2), delay_goal(3, b), eq(a, b))
    )
def zzz_goal(a, b):
    """Fixture used by GOALS_DATA: recursion guarded by ``zzz`` thunks.

    Each recursive call is wrapped in ``lambda:``, so it is not constructed
    while this goal is being built. ``zzz`` presumably forces the thunk when
    the goal runs. GOALS_DATA expects ``zzz_goal(1, c)`` to yield
    ``[(1, 1), (1, 2)]``.
    """
    return conde(
        eq(b, (1, a)),
        # a == 1: recurse with 2, contributing the answer (1, 2).
        (eq(a, 1), zzz(lambda: zzz_goal(2, b))),
        # a == 2: recurse with 3 (only b == (1, 3)), then require a == b.
        # 2 never unifies with (1, 3), so this clause can never succeed.
        (eq(a, 2), zzz(lambda: zzz_goal(3, b)), eq(a, b))
    )
# Each entry: (goal, query variable, expected reified answers in stream order).
# ``a``, ``b`` and ``c`` are module-level logic variables defined elsewhere.
GOALS_DATA = [
    (conj(eq(a, 1), fresh(lambda x: conj(eq(a, x), eq(x, c)))), c, [1]),
    (
        conj(
            eq(a, 1),
            fresh(2, lambda x, y: conj(eq((x, y), (a, 2)), eq(c, (x, y))))
        ),
        c,
        [(1, 2)]
    ),
    (conjp(eq(a, 1), eq(b, 2), eq(c, (a, b))), c, [(1, 2)]),
    (conjp(disjp(eq(a, 1), eq(a, 2), eq(a, 3)), eq(c, a)), c, [1, 3, 2]),
    (
        conde(
            (eq(a, 1), eq(b, 2), eq(c, a)),
            (eq(a, 2), eq(b, 3), eq(c, b)),
            eq(c, 5),
        ),
        c,
        [1, 5, 3]
    ),
    (zzz_goal(1, c), c, [(1, 1), (1, 2)]),
    (delay_goal(1, c), c, [(1, 1), (1, 2)]),
]


@pytest.mark.parametrize("goal, query, expected", GOALS_DATA)
def test_dsl(goal, query, expected):
    """Run ``goal`` from the initial state and compare every reified answer."""
    stream = unfold(goal(initial()))
    answers = [reify(query, state) for state in stream]
    assert answers == expected
def eval_builtin(fn, arg, out):
    """Relate applying builtin value ``fn`` to the list ``arg`` with result ``out``.

    ``car_fn`` yields the head of ``arg``; ``cdr_fn`` yields its tail.
    """
    def split(other):
        # ``other`` is whichever part of ``arg`` is not returned.
        take_head = (eq(fn, car_fn), eq([out, other, ...], arg))
        take_tail = (eq(fn, cdr_fn), eq([other, out, ...], arg))
        return conde(take_head, take_tail)

    return fresh(split)
def builtin(fn, out):
    """Relate a builtin name ``fn`` to the builtin value ``out`` it denotes."""
    table = ((car, car_fn), (cdr, cdr_fn))
    return conde(*((eq(fn, name), eq(out, value)) for name, value in table))
def lookup(var, env, out):
    """Relate symbol ``var`` to its value ``out`` in environment ``env``.

    ``env`` is a chain of ``Env(symbol, value, rest)`` frames. The first
    frame whose symbol equals ``var`` wins, so later frames are shadowed.
    """
    def frame(rest, sym, val):
        shape = eq(Env(sym, val, rest), env)
        found = (eq(sym, var), eq(val, out))
        keep_looking = (neq(sym, var), lookup(var, rest, out))
        return conj(shape, conde(found, keep_looking))

    return fresh(3, frame)