def construct_concrete(self, t: Type, e: Exp, out: Exp):
    """Emit statements that evaluate `e` (of type `t`) into the lvalue `out`.

    Dispatch is by representation type: library types that know how to
    construct themselves are deferred to, collections are filled element
    by element, and handle/scalar values become plain assignments.
    """
    # Library-provided types implement their own construction logic.
    if hasattr(t, "construct_concrete"):
        return t.construct_concrete(e, out)
    if isinstance(t, (TBag, TList)):
        # `out` is re-initialized before the copy loop, so it must not
        # occur free in the source expression.
        assert out not in free_vars(e)
        elem = self.fv(t.t, "x")
        fill = SForEach(elem, e, SCall(out, "add", [elem]))
        return SSeq(self.initialize_native_list(out), fill)
    if isinstance(t, TSet):
        # A `distinct` wrapper on the source is dropped — presumably the
        # native set deduplicates on insert anyway (TODO: confirm).
        if isinstance(e, EUnaryOp) and e.op == UOp.Distinct:
            return self.construct_concrete(t, e.e, out)
        elem = self.fv(t.t, "x")
        fill = SForEach(elem, e, SCall(out, "add", [elem]))
        return SSeq(self.initialize_native_set(out), fill)
    if isinstance(t, TMap):
        return SSeq(self.initialize_native_map(out), self.construct_map(t, e, out))
    if isinstance(t, THandle) or is_scalar(t):
        # Handles and scalars are copied with a single assignment.
        return SEscape("{indent}{lhs} = {rhs};\n", ["lhs", "rhs"], [out, e])
    # Last resort: a registered extension handler for this type.
    h = extension_handler(type(t))
    if h is not None:
        return h.codegen(e, self.state_exps, out=out)
    raise NotImplementedError(t, e, out)
def _eq(self, e1, e2, indent):
    """Compile an equality test between expressions `e1` and `e2`.

    Handles compare by address; scalars and matching native collection
    types compare directly with `==`.  Bags compare via their element
    histograms.  Set and map equality are not implemented.
    """
    def both_are(T):
        # True when both operands have the given native type.
        return isinstance(e1.type, T) and isinstance(e2.type, T)

    if isinstance(e1.type, THandle):
        by_addr = EEscape("({e1} == {e2})", ["e1", "e2"],
                          [self.addr_of(e1), self.addr_of(e2)]).with_type(BOOL)
        return self.visit(by_addr, indent)
    if (is_scalar(e1.type)
            or both_are(library.TNativeMap)
            or both_are(library.TNativeSet)
            or both_are(library.TNativeList)):
        direct = EEscape("({e1} == {e2})", ["e1", "e2"], [e1, e2]).with_type(BOOL)
        return self.visit(direct, indent)
    if isinstance(e1.type, TSet) and isinstance(e2.type, TSet):
        raise NotImplementedError("set equality")
    if isinstance(e1.type, TBag) or isinstance(e2.type, TBag):
        # Two bags are equal iff their histograms (element -> count
        # maps) are equal.
        setup1, h1 = self.histogram(e1, indent)
        setup2, h2 = self.histogram(e2, indent)
        setup3, res = self._eq(h1, h2, indent)
        return (setup1 + setup2 + setup3, res)
    if isinstance(e1.type, TMap) or isinstance(e2.type, TMap):
        raise NotImplementedError("map equality")
    raise NotImplementedError((e1.type, e2.type))
def visit_SAssign(self, s):
    """Emit code for an assignment statement.

    Scalar right-hand sides are assigned directly; anything else is
    first constructed into a fresh temporary and then moved into place.
    """
    if is_scalar(s.rhs.type):
        self.write_stmt(self.visit(s.lhs), " = ", self.visit(s.rhs), ";")
        return
    tmp = self.fv(s.lhs.type)
    self.declare(tmp, s.rhs)
    moved = EMove(tmp).with_type(tmp.type)
    self.write_stmt(self.visit(s.lhs), " = ", self.visit(moved), ";")
def declare(self, v: EVar, initial_value: Exp = None):
    """Emit a declaration for `v`, optionally with an initializer.

    Scalar initializers are folded into the declaration itself; other
    initializers are constructed into `v` after it has been declared.
    """
    if initial_value is not None and is_scalar(v.type):
        # Visit the initializer first, matching the original evaluation
        # order of the two visits.
        init = self.visit(initial_value)
        self.write_stmt(self.visit(v.type, v.id), " = ", init, ";")
        return
    self.write_stmt(self.visit(v.type, v.id), ";")
    if initial_value is not None:
        self.visit(self.construct_concrete(v.type, initial_value, v))
def _eq(self, e1, e2, indent):
    """Compile a Java equality test between `e1` and `e2`.

    Unboxed primitives compare with `==`; scalars and matching native
    collection types go through the null-safe `java.util.Objects.equals`.
    Everything else falls back to the generic superclass strategy.
    """
    if not self.boxed and self.is_primitive(e1.type):
        prim_eq = EEscape("({e1} == {e2})", ("e1", "e2"), (e1, e2)).with_type(BOOL)
        return self.visit(prim_eq, indent)

    def both_are(T):
        # True when both operands have the given native type.
        return isinstance(e1.type, T) and isinstance(e2.type, T)

    if (is_scalar(e1.type)
            or both_are(library.TNativeMap)
            or both_are(library.TNativeSet)
            or both_are(library.TNativeList)):
        obj_eq = EEscape("java.util.Objects.equals({e1}, {e2})",
                         ["e1", "e2"], [e1, e2]).with_type(BOOL)
        return self.visit(obj_eq, indent)
    return super()._eq(e1, e2, indent)
def build(self, cache, size):
    """Enumerate candidate expressions of exactly the given `size`.

    Yields `self.check(...)` results for every expression form the
    synthesizer should consider, drawing smaller subexpressions from
    `cache`.  The whole enumeration runs once per pool (state/runtime).
    NOTE(review): yield order is significant to downstream search
    heuristics only insofar as enumeration order — do not reorder.
    """
    # print("Cache:")
    # for (e, sz, pool) in cache:
    #     from cozy.syntax_tools import pprint
    #     print(" @size={}, pool={}\t:\t{}".format(sz, pool, pprint(e)))
    binders_by_type = group_by(self.binders, lambda b: b.type)
    for pool in ALL_POOLS:
        # Size-1 atoms: literals, binders, and the variables legal in
        # this pool.
        if size == 1:
            yield self.check(T, pool)
            yield self.check(F, pool)
            yield self.check(ZERO, pool)
            yield self.check(ONE, pool)
            for b in self.binders:
                yield self.check(b, pool)
            if pool == STATE_POOL:
                for v in self.state_vars:
                    yield self.check(v, pool)
            elif pool == RUNTIME_POOL:
                for v in self.args:
                    yield self.check(v, pool)
        # Optionally stop after atoms (controlled by a global flag).
        if not build_exprs.value:
            return
        # Lift pure-state expressions into the runtime pool via EStateVar.
        for e in cache.find(pool=STATE_POOL, size=size - 1):
            if all(v in self.state_vars for v in free_vars(e)):
                yield self.check(EStateVar(e).with_type(e.type), RUNTIME_POOL)
        # Collection constructors: empty list and singleton.
        for e in cache.find(pool=pool, size=size - 1):
            t = TBag(e.type)
            yield self.check(EEmptyList().with_type(t), pool)
            yield self.check(ESingleton(e).with_type(t), pool)
        # Record field projections.
        for e in cache.find(pool=pool, type=TRecord, size=size - 1):
            for (f, t) in e.type.fields:
                yield self.check(EGetField(e, f).with_type(t), pool)
        # Sum over numeric collections.
        for e in cache.find_collections(pool=pool, size=size - 1):
            if is_numeric(e.type.t):
                yield self.check(EUnaryOp(UOp.Sum, e).with_type(e.type.t), pool)
        # Handle dereference.
        for e in cache.find(pool=pool, type=THandle, size=size - 1):
            yield self.check(EGetField(e, "val").with_type(e.type.value_type), pool)
        # Tuple projections.
        for e in cache.find(pool=pool, type=TTuple, size=size - 1):
            for n in range(len(e.type.ts)):
                yield self.check(ETupleGet(e, n).with_type(e.type.ts[n]), pool)
        # Unary boolean / integer operators.
        for e in cache.find(pool=pool, type=BOOL, size=size - 1):
            yield self.check(EUnaryOp(UOp.Not, e).with_type(BOOL), pool)
        for e in cache.find(pool=pool, type=INT, size=size - 1):
            yield self.check(EUnaryOp("-", e).with_type(INT), pool)
        # Map key-set.
        for m in cache.find(pool=pool, type=TMap, size=size - 1):
            yield self.check(EMapKeys(m).with_type(TBag(m.type.k)), pool)
        # Binary operators: split the remaining size budget two ways.
        for (sz1, sz2) in pick_to_sum(2, size - 1):
            # Arithmetic and comparisons on numeric operands.
            for a1 in cache.find(pool=pool, size=sz1):
                if not is_numeric(a1.type):
                    continue
                for a2 in cache.find(pool=pool, type=a1.type, size=sz2):
                    yield self.check(EBinOp(a1, "+", a2).with_type(INT), pool)
                    yield self.check(EBinOp(a1, "-", a2).with_type(INT), pool)
                    yield self.check(EBinOp(a1, ">", a2).with_type(BOOL), pool)
                    yield self.check(EBinOp(a1, "<", a2).with_type(BOOL), pool)
                    yield self.check(EBinOp(a1, ">=", a2).with_type(BOOL), pool)
                    yield self.check(EBinOp(a1, "<=", a2).with_type(BOOL), pool)
            # Collection union/difference and membership.
            for a1 in cache.find_collections(pool=pool, size=sz1):
                for a2 in cache.find(pool=pool, type=a1.type, size=sz2):
                    yield self.check(EBinOp(a1, "+", a2).with_type(a1.type), pool)
                    yield self.check(EBinOp(a1, "-", a2).with_type(a1.type), pool)
                for a2 in cache.find(pool=pool, type=a1.type.t, size=sz2):
                    yield self.check(EBinOp(a2, BOp.In, a1).with_type(BOOL), pool)
            # Boolean connectives.
            for a1 in cache.find(pool=pool, type=BOOL, size=sz1):
                for a2 in cache.find(pool=pool, type=BOOL, size=sz2):
                    yield self.check(EBinOp(a1, BOp.And, a2).with_type(BOOL), pool)
                    yield self.check(EBinOp(a1, BOp.Or, a2).with_type(BOOL), pool)
            # (In)equality on any non-map operands.
            for a1 in cache.find(pool=pool, size=sz1):
                if not isinstance(a1.type, TMap):
                    for a2 in cache.find(pool=pool, type=a1.type, size=sz2):
                        yield self.check(EEq(a1, a2), pool)
                        yield self.check(EBinOp(a1, "!=", a2).with_type(BOOL), pool)
            # Map lookup and key test.
            for m in cache.find(pool=pool, type=TMap, size=sz1):
                for k in cache.find(pool=pool, type=m.type.k, size=sz2):
                    yield self.check(EMapGet(m, k).with_type(m.type.v), pool)
                    yield self.check(EHasKey(m, k).with_type(BOOL), pool)
        # Conditionals: split the budget three ways.
        for (sz1, sz2, sz3) in pick_to_sum(3, size - 1):
            for cond in cache.find(pool=pool, type=BOOL, size=sz1):
                for then_branch in cache.find(pool=pool, size=sz2):
                    for else_branch in cache.find(pool=pool, size=sz3, type=then_branch.type):
                        yield self.check(ECond(cond, then_branch, else_branch).with_type(then_branch.type), pool)
        # Unary operators over collections.
        for bag in cache.find_collections(pool=pool, size=size - 1):
            # len of bag
            count = EUnaryOp(UOp.Length, bag).with_type(INT)
            yield self.check(count, pool)
            # empty?
            yield self.check(EUnaryOp(UOp.Empty, bag).with_type(BOOL), pool)
            # exists?
            yield self.check(EUnaryOp(UOp.Exists, bag).with_type(BOOL), pool)
            # singleton?
            yield self.check(EEq(count, ONE), pool)
            yield self.check(EUnaryOp(UOp.The, bag).with_type(bag.type.t), pool)
            yield self.check(EUnaryOp(UOp.Distinct, bag).with_type(bag.type), pool)
            yield self.check(EUnaryOp(UOp.AreUnique, bag).with_type(BOOL), pool)
            if bag.type.t == BOOL:
                yield self.check(EUnaryOp(UOp.Any, bag).with_type(BOOL), pool)
                yield self.check(EUnaryOp(UOp.All, bag).with_type(BOOL), pool)
        # Higher-order forms: map / filter / argmin / argmax / flatmap,
        # with the binder itself allowed as a body.
        for (sz1, sz2) in pick_to_sum(2, size - 1):
            for bag in cache.find_collections(pool=pool, size=sz1):
                for binder in binders_by_type[bag.type.t]:
                    for body in itertools.chain(cache.find(pool=pool, size=sz2), (binder,)):
                        yield self.check(EMap(bag, ELambda(binder, body)).with_type(TBag(body.type)), pool)
                        if body.type == BOOL:
                            yield self.check(EFilter(bag, ELambda(binder, body)).with_type(bag.type), pool)
                        if body.type == INT:
                            yield self.check(EArgMin(bag, ELambda(binder, body)).with_type(bag.type.t), pool)
                            yield self.check(EArgMax(bag, ELambda(binder, body)).with_type(bag.type.t), pool)
                        if pool == RUNTIME_POOL and isinstance(body.type, TBag):
                            yield self.check(EFlatMap(bag, ELambda(binder, body)).with_type(TBag(body.type.t)), pool)
        # Map construction lives only in the state pool, and only over
        # scalar key types.
        for (sz1, sz2) in pick_to_sum(2, size - 1):
            for bag in cache.find_collections(pool=STATE_POOL, size=sz1):
                if not is_scalar(bag.type.t):
                    continue
                for b in binders_by_type[bag.type.t]:
                    for val in cache.find(pool=STATE_POOL, size=sz2):
                        t = TMap(bag.type.t, val.type)
                        m = EMakeMap2(bag, ELambda(b, val)).with_type(t)
                        yield self.check(m, STATE_POOL)
def _is_concrete(self, e):
    """Return True if `e`'s type needs no special construction step.

    Only the abstract collection types — exactly TMap, TSet, or TBag —
    are non-concrete.  The `type(...)` comparison (rather than
    `isinstance`) is preserved from the original: subtypes such as
    native implementations are NOT matched, presumably on purpose.
    """
    return is_scalar(e.type) or type(e.type) not in (TMap, TSet, TBag)
def run(self):
    """Run one query-improvement job.

    Redirects stdout to a per-query log file, searches for the number of
    binder variables needed to close `expr`, then drives `core.improve`
    over the query's return expression, reporting each improvement via
    `self.k(new_rep, new_ret)`.

    NOTE(review): `sys.stdout` is redirected but never restored —
    presumably this runs in a dedicated child process; confirm before
    calling in-process.
    """
    print("STARTING IMPROVEMENT JOB {} (|examples|={})".format(
        self.q.name, len(self.examples or ())))
    os.makedirs(log_dir.value, exist_ok=True)
    with open(os.path.join(log_dir.value, "{}.log".format(self.q.name)),
              "w", buffering=LINE_BUFFER_MODE) as f:
        sys.stdout = f
        print("STARTING IMPROVEMENT JOB {} (|examples|={})".format(
            self.q.name, len(self.examples or ())))
        print(pprint(self.q))
        if nice_children.value:
            os.nice(20)
        all_types = self.ctx.all_types
        n_binders = 1
        done = False
        expr = ETuple(
            (EAll(self.assumptions), self.q.ret)).with_type(TTuple((BOOL, self.q.ret.type)))
        # Grow n_binders until fixup_binders succeeds on `expr`.
        while not done:
            binders = []
            for t in all_types:
                # if isinstance(t, TBag):
                #     binders += [fresh_var(t.t) for i in range(n_binders)]
                for i in range(n_binders):
                    b = fresh_var(t)
                    binders.append(b)
            try:
                core.fixup_binders(expr, binders, throw=True)
                done = True
            except Exception:
                # FIX: was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit and could make this loop
                # uninterruptible.  Any ordinary failure just means "try
                # again with more binders".
                pass
            n_binders += 1
        # Fresh binders for scalar types only, at the discovered count.
        binders = [
            fresh_var(t)
            for t in all_types if is_scalar(t)
            for i in range(n_binders)
        ]
        print("Using {} binders".format(n_binders))
        # Restrict attention to state variables the query actually mentions.
        relevant_state_vars = [
            v for v in self.state
            if v in free_vars(EAll(self.assumptions)) | free_vars(self.q.ret)
        ]
        used_vars = free_vars(self.q.ret)
        for a in self.q.assumptions:
            used_vars |= free_vars(a)
        args = [EVar(v).with_type(t) for (v, t) in self.q.args]
        args = [a for a in args if a in used_vars]
        b = BinderBuilder(binders, relevant_state_vars, args)
        if accelerate.value:
            b = AcceleratedBuilder(b, binders, relevant_state_vars, args)
        try:
            # Report the unimproved query first, then each improvement
            # found by the synthesizer.
            for expr in itertools.chain(
                    (self.q.ret,),
                    core.improve(
                        target=self.q.ret,
                        assumptions=EAll(self.assumptions),
                        hints=self.hints,
                        examples=self.examples,
                        binders=binders,
                        state_vars=relevant_state_vars,
                        args=args,
                        cost_model=CompositeCostModel(),
                        builder=b,
                        stop_callback=lambda: self.stop_requested)):
                new_rep, new_ret = tease_apart(expr)
                self.k(new_rep, new_ret)
            print("PROVED OPTIMALITY FOR {}".format(self.q.name))
        except core.StopException:
            print("stopping synthesis of {}".format(self.q.name))
            return
def improve_implementation(impl: Implementation,
                           timeout: datetime.timedelta = datetime.timedelta(seconds=60),
                           progress_callback=None) -> Implementation:
    """Improve the query implementations of `impl` until `timeout` expires.

    Spawns one ImproveQueryJob per query spec, drains their solutions
    from a shared queue, and folds each improvement back into a
    defensive copy of `impl` (the input object is not mutated).  If
    `progress_callback` is given it is invoked with
    `(impl, impl.code, impl.concretization_functions)` after each batch
    of updates.  Returns the improved copy.
    """
    start_time = datetime.datetime.now()
    # we statefully modify `impl`, so let's make a defensive copy
    impl = Implementation(
        impl.spec,
        list(impl.concrete_state),
        list(impl.query_specs),
        OrderedDict(impl.query_impls),
        defaultdict(SNoOp, impl.updates),
        defaultdict(SNoOp, impl.handle_updates))
    # gather root types
    types = list(all_types(impl.spec))
    basic_types = set(t for t in types if is_scalar(t))
    basic_types |= {BOOL, INT}
    print("basic types:")
    for t in basic_types:
        print(" --> {}".format(pprint(t)))
    basic_types = list(basic_types)
    ctx = SynthCtx(all_types=types, basic_types=basic_types)
    # the actual worker threads
    improvement_jobs = []
    with jobs.SafeQueue() as solutions_q:
        def stop_jobs(js):
            # Stop the given jobs and drop them from the active list.
            js = list(js)
            jobs.stop_jobs(js)
            for j in js:
                improvement_jobs.remove(j)
        def reconcile_jobs():
            # Sync the running jobs with impl.query_specs: start a job
            # for every query that lacks one, stop jobs whose query no
            # longer exists.
            # figure out what new jobs we need
            job_query_names = set(j.q.name for j in improvement_jobs)
            new = []
            for q in impl.query_specs:
                if q.name not in job_query_names:
                    new.append(ImproveQueryJob(
                        ctx,
                        impl.abstract_state,
                        list(impl.spec.assumptions) + list(q.assumptions),
                        q,
                        # note: the extra lambda layer binds q early,
                        # avoiding the late-binding closure pitfall
                        k=(lambda q: lambda new_rep, new_ret: solutions_q.put((q, new_rep, new_ret)))(q),
                        hints=[EStateVar(c).with_type(c.type) for c in impl.concretization_functions.values()]))
            # figure out what old jobs we can stop
            impl_query_names = set(q.name for q in impl.query_specs)
            old = [j for j in improvement_jobs if j.q.name not in impl_query_names]
            # make it so
            stop_jobs(old)
            for j in new:
                j.start()
            improvement_jobs.extend(new)
        # start jobs
        reconcile_jobs()
        # wait for results
        timeout = Timeout(timeout)
        done = False
        while not done and not timeout.is_timed_out():
            for j in improvement_jobs:
                if j.done:
                    if j.successful:
                        j.join()
                    else:
                        print("failed job: {}".format(j), file=sys.stderr)
                        # raise Exception("failed job: {}".format(j))
            done = all(j.done for j in improvement_jobs)
            try:
                # list of (Query, new_rep, new_ret) objects
                results = solutions_q.drain(block=True, timeout=0.5)
            except Empty:
                continue
            # group by query name, favoring later (i.e. better) solutions
            print("updating with {} new solutions".format(len(results)))
            improved_queries_by_name = OrderedDict()
            killed = 0
            for r in results:
                q, new_rep, new_ret = r
                if q.name in improved_queries_by_name:
                    killed += 1
                improved_queries_by_name[q.name] = r
            if killed:
                print(" --> dropped {} worse solutions".format(killed))
            improvements = list(improved_queries_by_name.values())
            def index_of(l, p):
                # Index of the first element of `l` satisfying predicate
                # `p`, or -1 if none does.
                if not isinstance(l, list):
                    l = list(l)
                for i in range(len(l)):
                    if p(l[i]):
                        return i
                return -1
            # Apply improvements in query-spec declaration order.
            improvements.sort(key=lambda i: index_of(impl.query_specs, lambda qq: qq.name == i[0].name))
            print("update order:")
            for (q, _, _) in improvements:
                print(" --> {}".format(q.name))
            # update query implementations
            i = 1
            for (q, new_rep, new_ret) in improvements:
                print("considering update {}/{}...".format(i, len(improvements)))
                i += 1
                # this guard might be false if a better solution was
                # enqueued but the job has already been cleaned up
                if q.name in [qq.name for qq in impl.query_specs]:
                    elapsed = datetime.datetime.now() - start_time
                    print("SOLUTION FOR {} AT {} [size={}]".format(
                        q.name, elapsed,
                        new_ret.size() + sum(proj.size() for (v, proj) in new_rep)))
                    print("-" * 40)
                    for (sv, proj) in new_rep:
                        print(" {} : {} = {}".format(sv.id, pprint(sv.type), pprint(proj)))
                    print(" return {}".format(pprint(new_ret)))
                    print("-" * 40)
                    impl.set_impl(q, new_rep, new_ret)
            # clean up
            impl.cleanup()
            if progress_callback is not None:
                progress_callback((impl, impl.code, impl.concretization_functions))
            reconcile_jobs()
        # stop jobs
        print("Stopping jobs")
        stop_jobs(list(improvement_jobs))
        return impl
def _eq(self, e1, e2):
    """Compile a Java equality test between `e1` and `e2`.

    Unboxed primitives use `==`; other scalars use the null-safe
    `java.util.Objects.equals`; everything else defers to the parent
    class's generic equality strategy.
    """
    if not self.boxed and self.is_primitive(e1.type):
        prim_eq = EEscape("({e1} == {e2})", ("e1", "e2"), (e1, e2)).with_type(BOOL)
        return self.visit(prim_eq)
    if is_scalar(e1.type):
        obj_eq = EEscape("java.util.Objects.equals({e1}, {e2})",
                         ["e1", "e2"], [e1, e2]).with_type(BOOL)
        return self.visit(obj_eq)
    return super()._eq(e1, e2)