def test_generalization3(self):
    """generalize() must keep a binder that the kept variable depends on.

    ``z`` ranges over the singleton ``{y}``, and ``y`` is bound by the outer
    UnderBinder, so the test expects ``generalize({z})`` to hand back the
    innermost context object itself (identity, not just equality).
    """
    outer = UnderBinder(
        RootCtx(args=[x, int_bag], state_vars=[]),
        y, int_bag, RUNTIME_POOL)
    singleton_y = ESingleton(y).with_type(TBag(y.type))
    inner = UnderBinder(outer, z, singleton_y, RUNTIME_POOL)
    assert inner.generalize({z}) is inner
def test_generalization2(self):
    """generalize() drops an intermediate binder the kept variable ignores.

    ``z`` ranges over ``int_bag`` rather than anything involving ``y``, so the
    test expects ``generalize({z})`` to return a new context that binds only
    ``z`` directly under the root.
    """
    root = RootCtx(args=[x, int_bag], state_vars=[])
    under_y = UnderBinder(root, y, int_bag, RUNTIME_POOL)
    under_z = UnderBinder(under_y, z, int_bag, RUNTIME_POOL)
    generalized = under_z.generalize({z})
    assert generalized is not under_z
    assert generalized == UnderBinder(root, z, int_bag, RUNTIME_POOL)
def test_hint_instantation(self):
    """A hint recorded in one context should be enumerable in related contexts.

    The hint ``f(x)`` is registered under the context ``y in {x}``.  The test
    enumerates size-0 expressions in that context, in its parent, and in three
    other UnderBinder variants, and requires that an expression
    alpha-equivalent to the hint shows up in every one of them.

    NOTE(review): a second ``test_hint_instantation`` with the same body also
    appears in this source; if both live in the same class, the later
    definition silently replaces this one and only one of them runs.
    """
    x = EVar("x").with_type(INT)
    y = EVar("y").with_type(INT)
    z = EVar("z").with_type(INT)
    hint = ECall("f", (x,)).with_type(INT)
    context = UnderBinder(
        RootCtx(args=[x]),
        v=y,
        bag=ESingleton(x).with_type(TBag(x.type)),
        bag_pool=RUNTIME_POOL)
    cost_model = CostModel()

    def f(a):
        return a + 1

    enumerator = Enumerator(
        examples=[{"x": 1, "f": f}, {"x": 100, "f": f}],
        hints=[(hint, context, RUNTIME_POOL)],
        cost_model=cost_model)
    results = []
    for ctx in (
            context,
            context.parent(),
            UnderBinder(context, v=z,
                        bag=ESingleton(y).with_type(TBag(y.type)),
                        bag_pool=RUNTIME_POOL),
            # The singleton holds x, so its bag type is derived from x.type.
            UnderBinder(context.parent(), v=z,
                        bag=ESingleton(x).with_type(TBag(x.type)),
                        bag_pool=RUNTIME_POOL),
            UnderBinder(context.parent(), v=y,
                        bag=ESingleton(ONE).with_type(INT_BAG),
                        bag_pool=RUNTIME_POOL)):
        print("-" * 30)
        found = False
        for e in enumerator.enumerate(ctx, 0, RUNTIME_POOL):
            print(" -> {}".format(pprint(e)))
            found = found or alpha_equivalent(e, hint)
        print("found? {}".format(found))
        results.append(found)
    # Report the per-context outcomes so a failure shows which context missed.
    assert all(results), results
def test_complicated_adapt(self):
    """Adapt the variable ``p`` from a one-binder context into a two-binder one.

    Both contexts share the same kind of root: eight state bags (lineitem,
    orders, part, customer, supplier, partsupp, nation, region), the lineitem
    fields as arguments, and four helper functions.  Both bind ``p`` over the
    same doubly filtered ``part`` bag.  The destination additionally binds
    ``l`` over ``lineitem + {record built from the arguments}`` outside of
    ``p``.  The test only prints the adapted expression; it asserts nothing.
    """
    char_t = TNative('char')
    date_t = TNative('uint64_t')

    # Field layouts of the record types used below.
    lineitem_fields = (
        ('orderkey', TInt()), ('partkey', TInt()), ('suppkey', TInt()),
        ('linenumber', TInt()), ('quantity', TFloat()),
        ('extendedprice', TFloat()), ('discount', TFloat()),
        ('tax', TFloat()), ('returnflag', char_t), ('linestatus', char_t),
        ('shipdate', date_t), ('commitdate', date_t),
        ('receiptdate', date_t), ('shipinstruct', TString()),
        ('shipmode', TString()), ('comment', TString()))
    lineitem_t = TRecord(lineitem_fields)
    orders_t = TRecord((
        ('orderkey', TInt()), ('custkey', TInt()), ('orderstatus', char_t),
        ('totalprice', TFloat()), ('orderdate', date_t),
        ('orderpriority', TString()), ('clerk', TString()),
        ('shippriority', TInt()), ('comment', TString())))
    part_t = TRecord((
        ('partkey', TInt()), ('name', TString()), ('mfgr', TString()),
        ('brand', TString()), ('part_type', TString()), ('size', TInt()),
        ('container', TString()), ('retailprice', TFloat()),
        ('comment', TString())))
    customer_t = TRecord((
        ('custkey', TInt()), ('name', TString()), ('address', TString()),
        ('nationkey', TInt()), ('phone', TString()), ('acctbal', TFloat()),
        ('mktsegment', TString()), ('comment', TString())))
    supplier_t = TRecord((
        ('suppkey', TInt()), ('name', TString()), ('address', TString()),
        ('nationkey', TInt()), ('phone', TString()), ('acctbal', TFloat()),
        ('comment', TString())))
    partsupp_t = TRecord((
        ('partkey', TInt()), ('suppkey', TInt()), ('availqty', TInt()),
        ('supplycost', TFloat()), ('comment', TString())))
    nation_t = TRecord((
        ('nationkey', TInt()), ('name', TString()), ('regionkey', TInt()),
        ('comment', TString())))
    region_t = TRecord((
        ('regionkey', TInt()), ('name', TString()), ('comment', TString())))

    def make_root():
        # Build a fresh root context on every call (one per context below).
        return RootCtx(
            state_vars=OrderedSet([
                EVar('lineitem').with_type(TBag(lineitem_t)),
                EVar('orders').with_type(TBag(orders_t)),
                EVar('part').with_type(TBag(part_t)),
                EVar('customer').with_type(TBag(customer_t)),
                EVar('supplier').with_type(TBag(supplier_t)),
                EVar('partsupp').with_type(TBag(partsupp_t)),
                EVar('nation').with_type(TBag(nation_t)),
                EVar('region').with_type(TBag(region_t))]),
            args=OrderedSet(
                [EVar(name).with_type(t) for name, t in lineitem_fields]),
            funcs=OrderedDict([
                ('div', TFunc((TFloat(), TFloat()), TFloat())),
                ('int2float', TFunc((TInt(),), TFloat())),
                ('brand23', TFunc((), TString())),
                ('medbox', TFunc((), TString()))]))

    def part_var():
        return EVar('p').with_type(part_t)

    def part_field_is(field, func):
        # Builds the predicate ``\p. p.<field> == <func>()``.
        return ELambda(part_var(), EBinOp(
            EGetField(part_var(), field).with_type(TString()),
            '==',
            ECall(func, ()).with_type(TString())).with_type(TBool()))

    def make_part_bag():
        # state(part filtered by brand == brand23() then container == medbox())
        part_bag_t = TBag(part_t)
        by_brand = EFilter(
            EVar('part').with_type(part_bag_t),
            part_field_is('brand', 'brand23')).with_type(part_bag_t)
        by_container = EFilter(
            by_brand,
            part_field_is('container', 'medbox')).with_type(part_bag_t)
        return EStateVar(by_container).with_type(part_bag_t)

    e = part_var()
    e_ctx = UnderBinder(
        parent=make_root(),
        v=part_var(),
        bag=make_part_bag(),
        bag_pool=RUNTIME_POOL)

    # Bag bound to ``l``: the lineitem state plus one record made of the args.
    lineitem_bag_t = TBag(lineitem_t)
    new_lineitem = EMakeRecord(tuple(
        (name, EVar(name).with_type(t)) for name, t in lineitem_fields
    )).with_type(lineitem_t)
    lineitem_plus_new = EBinOp(
        EStateVar(EVar('lineitem').with_type(lineitem_bag_t)).with_type(lineitem_bag_t),
        '+',
        ESingleton(new_lineitem).with_type(lineitem_bag_t)).with_type(lineitem_bag_t)

    dest_ctx = UnderBinder(
        parent=UnderBinder(
            parent=make_root(),
            v=EVar('l').with_type(lineitem_t),
            bag=lineitem_plus_new,
            bag_pool=RUNTIME_POOL),
        v=part_var(),
        bag=make_part_bag(),
        bag_pool=RUNTIME_POOL)

    print("adapting {}".format(pprint(e)))
    print("in {}".format(e_ctx))
    print("to {}".format(dest_ctx))
    e_prime = dest_ctx.adapt(e, e_ctx)
    print(" ---> {}".format(pprint(e_prime)))
def test_enumerator_fingerprints(self):
    """Fingerprints must be computed in the context that was requested.

    Every size-0 expression enumerated under the binder ``y in {0} + {1}``
    should carry a fingerprint of one and the same length.
    """
    x = EVar("x").with_type(INT)
    root = RootCtx(args=(x,), state_vars=())
    enumerator = Enumerator(
        examples=[{"x": 0}, {"x": 1}],
        cost_model=CostModel())
    y = EVar("y").with_type(INT)
    zero_or_one = EBinOp(
        ESingleton(ZERO).with_type(INT_BAG),
        "+",
        ESingleton(ONE).with_type(INT_BAG)).with_type(INT_BAG)
    inner_ctx = UnderBinder(root, y, zero_or_one, RUNTIME_POOL)

    lengths = set()
    for info in enumerator.enumerate_with_info(inner_ctx, 0, RUNTIME_POOL):
        lengths.add(len(info.fingerprint))
        print(info)
    assert len(lengths) == 1, lengths
def test_hint_instantation(self):
    """A hint recorded in one context should be enumerable in related contexts.

    The hint ``f(x)`` is registered under the context ``y in {x}``.  The test
    enumerates size-0 expressions in that context, in its parent, and in three
    other UnderBinder variants, and requires that an expression
    alpha-equivalent to the hint shows up in every one of them.

    NOTE(review): another ``test_hint_instantation`` with the same body also
    appears in this source; if both live in the same class, only the later
    definition is kept and run.
    """
    x = EVar("x").with_type(INT)
    y = EVar("y").with_type(INT)
    z = EVar("z").with_type(INT)
    hint = ECall("f", (x, )).with_type(INT)
    context = UnderBinder(
        RootCtx(args=[x]),
        v=y,
        bag=ESingleton(x).with_type(TBag(x.type)),
        bag_pool=RUNTIME_POOL)
    cost_model = CostModel()

    def f(a):
        return a + 1

    enumerator = Enumerator(
        examples=[{"x": 1, "f": f}, {"x": 100, "f": f}],
        hints=[(hint, context, RUNTIME_POOL)],
        cost_model=cost_model)
    results = []
    for ctx in (
            context,
            context.parent(),
            UnderBinder(context, v=z,
                        bag=ESingleton(y).with_type(TBag(y.type)),
                        bag_pool=RUNTIME_POOL),
            # The singleton holds x, so its bag type is derived from x.type.
            UnderBinder(context.parent(), v=z,
                        bag=ESingleton(x).with_type(TBag(x.type)),
                        bag_pool=RUNTIME_POOL),
            UnderBinder(context.parent(), v=y,
                        bag=ESingleton(ONE).with_type(INT_BAG),
                        bag_pool=RUNTIME_POOL)):
        print("-" * 30)
        found = False
        for e in enumerator.enumerate(ctx, 0, RUNTIME_POOL):
            print(" -> {}".format(pprint(e)))
            found = found or alpha_equivalent(e, hint)
        print("found? {}".format(found))
        results.append(found)
    # Report the per-context outcomes so a failure shows which context missed.
    assert all(results), results
def test_state_pool_boundary(self):
    """State expressions must not be enumerated beneath a runtime binder.

    A counting Enumerator subclass tallies every ``_enumerate_core`` call made
    for STATE_POOL.  The test enumerates once in the root context and once
    under an UnderBinder whose bag pool is RUNTIME_POOL, and expects exactly
    one state-pool enumeration across both runs.
    """
    class CountingEnumerator(Enumerator):
        """Enumerator that counts how often the state pool is enumerated."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.state_enumerations = 0

        def _enumerate_core(self, context, size, pool):
            print("_enumerate_core({}, {}, {})".format(context, size, pool))
            if pool == STATE_POOL:
                self.state_enumerations += 1
            return super()._enumerate_core(context, size, pool)

    state_bag = EVar("state").with_type(INT_BAG)
    root = RootCtx(state_vars=[state_bag], args=[EVar("arg").with_type(INT)])
    enumerator = CountingEnumerator(
        examples=[
            {"state": Bag([10]), "arg": 10},
            {"state": Bag([20]), "arg": 30},
        ],
        cost_model=CostModel())

    # Both loops exist only to drive the enumerator; the yielded
    # expressions themselves are discarded.
    for _ in enumerator.enumerate(root, 1, RUNTIME_POOL):
        pass
    runtime_binder_ctx = UnderBinder(
        root,
        EVar("x").with_type(INT),
        EStateVar(state_bag).with_type(state_bag.type),
        RUNTIME_POOL)
    for _ in enumerator.enumerate(runtime_binder_ctx, 1, RUNTIME_POOL):
        pass

    assert enumerator.state_enumerations == 1
def build_lambdas(bag, pool, body_size):
    """Yield ``ELambda(arg, body)`` for every body enumerated over a fresh arg.

    ``arg`` is a fresh variable of the bag's element type (``bag.type.t``),
    chosen to avoid every variable already bound in ``context``.  Bodies are
    enumerated at ``body_size`` in ``pool`` under a new UnderBinder that binds
    ``arg`` over ``bag``.

    Closure: reads ``context`` and ``self`` from the enclosing scope.
    """
    # context.vars() yields pairs; only the variable half matters here
    # (the second component is presumably its pool — not used).
    bound = {var for var, _ in context.vars()}
    arg = fresh_var(bag.type.t, omit=bound)
    body_ctx = UnderBinder(context, v=arg, bag=bag, bag_pool=pool)
    for body in self.enumerate(body_ctx, body_size, pool):
        yield ELambda(arg, body)
def test_generalization1(self):
    """Generalizing to a root-level argument collapses to the root context.

    ``x`` is one of the root's arguments, so the test expects
    ``generalize({x})`` on the ``y`` binder to return the root object itself.
    """
    root = RootCtx(args=[x, int_bag], state_vars=[])
    under_y = UnderBinder(root, y, int_bag, RUNTIME_POOL)
    assert under_y.generalize({x}) is root