Пример #1
0
 def construct_concrete(self, t: Type, e: Exp, out: Exp):
     """
     Construct a value of type `t` from the expression `e` and store it in
     lvalue `out`.

     Returns the statement that performs the construction:
       - types with their own ``construct_concrete`` hook delegate to it;
       - sets, bags and lists initialize `out` to an empty native
         collection and then ``add`` every element of `e`;
       - maps initialize `out` and fill it via ``self.construct_map``;
       - handles and scalars are a plain ``out = e;`` assignment;
       - anything else is delegated to a registered extension handler.

     Raises:
         AssertionError: for set/bag/list types when `out` occurs free in
             `e` (`out` is reset before `e` is iterated, so reading it
             would observe the emptied collection).
         NotImplementedError: when no strategy applies to `t`.
     """
     if hasattr(t, "construct_concrete"):
         return t.construct_concrete(e, out)
     elif isinstance(t, TSet):
         # Tested before TBag so the set path wins even if TSet is declared
         # as a TBag subtype (the class hierarchy is not visible here).
         if isinstance(e, EUnaryOp) and e.op == UOp.Distinct:
             # Inserting into the native set presumably drops duplicates,
             # so the Distinct wrapper is redundant here.
             return self.construct_concrete(t, e.e, out)
         assert out not in free_vars(e)
         x = self.fv(t.t, "x")
         return SSeq(self.initialize_native_set(out),
                     SForEach(x, e, SCall(out, "add", [x])))
     elif isinstance(t, (TBag, TList)):
         assert out not in free_vars(e)
         x = self.fv(t.t, "x")
         return SSeq(self.initialize_native_list(out),
                     SForEach(x, e, SCall(out, "add", [x])))
     elif isinstance(t, TMap):
         return SSeq(self.initialize_native_map(out),
                     self.construct_map(t, e, out))
     elif isinstance(t, THandle) or is_scalar(t):
         # Handles and scalars are copied by plain assignment.
         return SEscape("{indent}{lhs} = {rhs};\n", ["lhs", "rhs"],
                        [out, e])
     else:
         h = extension_handler(type(t))
         if h is not None:
             return h.codegen(e, self.state_exps, out=out)
         raise NotImplementedError(t, e, out)
Пример #2
0
 def _eq(self, e1, e2, indent):
     """
     Emit an equality test between `e1` and `e2`.

     Handles are compared by address.  Scalars, and pairs of native maps,
     native sets or native lists, are compared with ``==``.  Bags are
     compared by building ``self.histogram`` of each side and comparing
     those recursively.  Set and map equality, and any other combination
     of types, raise NotImplementedError.

     Returns the ``(setup, result)`` pair produced by ``self.visit``; the bag
     case concatenates the setup code of its three sub-steps.
     """
     t1, t2 = e1.type, e2.type

     if isinstance(t1, THandle):
         by_address = EEscape("({e1} == {e2})", ["e1", "e2"],
                              [self.addr_of(e1), self.addr_of(e2)])
         return self.visit(by_address.with_type(BOOL), indent)

     native_kinds = (library.TNativeMap, library.TNativeSet,
                     library.TNativeList)
     if is_scalar(t1) or any(isinstance(t1, kind) and isinstance(t2, kind)
                             for kind in native_kinds):
         direct = EEscape("({e1} == {e2})", ["e1", "e2"], [e1, e2])
         return self.visit(direct.with_type(BOOL), indent)

     if isinstance(t1, TSet) and isinstance(t2, TSet):
         raise NotImplementedError("set equality")

     if isinstance(t1, TBag) or isinstance(t2, TBag):
         setup_lhs, hist_lhs = self.histogram(e1, indent)
         setup_rhs, hist_rhs = self.histogram(e2, indent)
         setup_cmp, result = self._eq(hist_lhs, hist_rhs, indent)
         return (setup_lhs + setup_rhs + setup_cmp, result)

     if isinstance(t1, TMap) or isinstance(t2, TMap):
         raise NotImplementedError("map equality")

     raise NotImplementedError((t1, t2))
Пример #3
0
 def visit_SAssign(self, s):
     """
     Emit code for the assignment statement ``s.lhs = s.rhs``.

     Scalar right-hand sides are assigned directly.  Any other value is
     first declared into a fresh temporary (via ``self.declare``) and then
     assigned to the target through an ``EMove`` of that temporary.
     """
     if not is_scalar(s.rhs.type):
         tmp = self.fv(s.lhs.type)
         self.declare(tmp, s.rhs)
         moved = EMove(tmp).with_type(tmp.type)
         self.write_stmt(self.visit(s.lhs), " = ", self.visit(moved), ";")
         return
     self.write_stmt(self.visit(s.lhs), " = ", self.visit(s.rhs), ";")
Пример #4
0
 def declare(self, v: EVar, initial_value: Exp = None):
     """
     Emit a declaration of variable `v`, optionally initialized.

     A scalar with an initial value is declared and initialized in a single
     statement.  Otherwise the variable is declared on its own and, when an
     initial value is supplied, filled afterwards by the statement that
     ``self.construct_concrete`` builds.
     """
     if initial_value is None or not is_scalar(v.type):
         self.write_stmt(self.visit(v.type, v.id), ";")
         if initial_value is not None:
             self.visit(self.construct_concrete(v.type, initial_value, v))
         return
     # The initializer is visited before the declaration, as before.
     init_code = self.visit(initial_value)
     self.write_stmt(self.visit(v.type, v.id), " = ", init_code, ";")
Пример #5
0
 def _eq(self, e1, e2, indent):
     """
     Emit a Java equality test between `e1` and `e2`.

     Unboxed primitives are compared with ``==``.  Other scalars, and pairs
     of native maps, native sets or native lists, are compared with
     ``java.util.Objects.equals``.  Everything else is delegated to the
     superclass implementation.
     """
     lhs_type, rhs_type = e1.type, e2.type
     if self.boxed or not self.is_primitive(lhs_type):
         same_native = any(
             isinstance(lhs_type, kind) and isinstance(rhs_type, kind)
             for kind in (library.TNativeMap, library.TNativeSet,
                          library.TNativeList))
         if not (is_scalar(lhs_type) or same_native):
             return super()._eq(e1, e2, indent)
         check = EEscape("java.util.Objects.equals({e1}, {e2})",
                         ["e1", "e2"], [e1, e2])
     else:
         check = EEscape("({e1} == {e2})", ("e1", "e2"), (e1, e2))
     return self.visit(check.with_type(BOOL), indent)
Пример #6
0
    def build(self, cache, size):
        """
        Enumerate candidate expressions of the given `size`.

        This is a generator: every candidate `e` built for a pool is passed
        to ``self.check(e, pool)`` and the result of that call is yielded.

        For ``size == 1`` the atoms of each pool are produced: the constants
        T, F, ZERO and ONE, every binder, plus the state variables (in
        STATE_POOL) or the query arguments (in RUNTIME_POOL).  Larger
        candidates apply one new operator to sub-expressions taken from
        `cache`, whose sizes are split from ``size - 1`` by ``pick_to_sum``.
        If ``build_exprs.value`` is false only the atoms are produced.

        Args:
            cache: previously built expressions, queried through ``find`` and
                ``find_collections`` by pool, size and (optionally) type.
            size: the size of the expressions to enumerate.
        """
        binders_by_type = group_by(self.binders, lambda b: b.type)

        for pool in ALL_POOLS:
            if size == 1:
                yield self.check(T, pool)
                yield self.check(F, pool)
                yield self.check(ZERO, pool)
                yield self.check(ONE, pool)
                for b in self.binders:
                    yield self.check(b, pool)
                if pool == STATE_POOL:
                    for v in self.state_vars:
                        yield self.check(v, pool)
                elif pool == RUNTIME_POOL:
                    for v in self.args:
                        yield self.check(v, pool)

            if not build_exprs.value:
                # Compound expressions are disabled.  `continue` (rather than
                # `return`) so the atoms of every pool are emitted, not only
                # those of the first pool in ALL_POOLS.
                continue

            if pool == RUNTIME_POOL:
                # Wrap expressions over state variables only in EStateVar and
                # offer them to the runtime pool.  This does not depend on
                # `pool`, so do it once instead of once per pool (which only
                # produced duplicate candidates).
                for e in cache.find(pool=STATE_POOL, size=size - 1):
                    if all(v in self.state_vars for v in free_vars(e)):
                        yield self.check(
                            EStateVar(e).with_type(e.type), RUNTIME_POOL)

            for e in cache.find(pool=pool, size=size - 1):
                t = TBag(e.type)
                yield self.check(EEmptyList().with_type(t), pool)
                yield self.check(ESingleton(e).with_type(t), pool)

            for e in cache.find(pool=pool, type=TRecord, size=size - 1):
                for (f, t) in e.type.fields:
                    yield self.check(EGetField(e, f).with_type(t), pool)
            for e in cache.find_collections(pool=pool, size=size - 1):
                if is_numeric(e.type.t):
                    yield self.check(
                        EUnaryOp(UOp.Sum, e).with_type(e.type.t), pool)
            for e in cache.find(pool=pool, type=THandle, size=size - 1):
                yield self.check(
                    EGetField(e, "val").with_type(e.type.value_type), pool)
            for e in cache.find(pool=pool, type=TTuple, size=size - 1):
                for n in range(len(e.type.ts)):
                    yield self.check(
                        ETupleGet(e, n).with_type(e.type.ts[n]), pool)
            for e in cache.find(pool=pool, type=BOOL, size=size - 1):
                yield self.check(EUnaryOp(UOp.Not, e).with_type(BOOL), pool)
            for e in cache.find(pool=pool, type=INT, size=size - 1):
                yield self.check(EUnaryOp("-", e).with_type(INT), pool)

            for m in cache.find(pool=pool, type=TMap, size=size - 1):
                yield self.check(EMapKeys(m).with_type(TBag(m.type.k)), pool)

            for (sz1, sz2) in pick_to_sum(2, size - 1):
                for a1 in cache.find(pool=pool, size=sz1):
                    if not is_numeric(a1.type):
                        continue
                    for a2 in cache.find(pool=pool, type=a1.type, size=sz2):
                        # Arithmetic keeps the operand type: is_numeric may
                        # admit types other than INT, so typing every sum or
                        # difference as INT would be wrong for those.
                        yield self.check(
                            EBinOp(a1, "+", a2).with_type(a1.type), pool)
                        yield self.check(
                            EBinOp(a1, "-", a2).with_type(a1.type), pool)
                        yield self.check(
                            EBinOp(a1, ">", a2).with_type(BOOL), pool)
                        yield self.check(
                            EBinOp(a1, "<", a2).with_type(BOOL), pool)
                        yield self.check(
                            EBinOp(a1, ">=", a2).with_type(BOOL), pool)
                        yield self.check(
                            EBinOp(a1, "<=", a2).with_type(BOOL), pool)
                for a1 in cache.find_collections(pool=pool, size=sz1):
                    for a2 in cache.find(pool=pool, type=a1.type, size=sz2):
                        yield self.check(
                            EBinOp(a1, "+", a2).with_type(a1.type), pool)
                        yield self.check(
                            EBinOp(a1, "-", a2).with_type(a1.type), pool)
                    for a2 in cache.find(pool=pool, type=a1.type.t, size=sz2):
                        yield self.check(
                            EBinOp(a2, BOp.In, a1).with_type(BOOL), pool)
                for a1 in cache.find(pool=pool, type=BOOL, size=sz1):
                    for a2 in cache.find(pool=pool, type=BOOL, size=sz2):
                        yield self.check(
                            EBinOp(a1, BOp.And, a2).with_type(BOOL), pool)
                        yield self.check(
                            EBinOp(a1, BOp.Or, a2).with_type(BOOL), pool)
                for a1 in cache.find(pool=pool, size=sz1):
                    if not isinstance(a1.type, TMap):
                        for a2 in cache.find(pool=pool, type=a1.type,
                                             size=sz2):
                            yield self.check(EEq(a1, a2), pool)
                            yield self.check(
                                EBinOp(a1, "!=", a2).with_type(BOOL), pool)
                for m in cache.find(pool=pool, type=TMap, size=sz1):
                    for k in cache.find(pool=pool, type=m.type.k, size=sz2):
                        yield self.check(
                            EMapGet(m, k).with_type(m.type.v), pool)
                        yield self.check(EHasKey(m, k).with_type(BOOL), pool)

            for (sz1, sz2, sz3) in pick_to_sum(3, size - 1):
                for cond in cache.find(pool=pool, type=BOOL, size=sz1):
                    for then_branch in cache.find(pool=pool, size=sz2):
                        for else_branch in cache.find(pool=pool,
                                                      size=sz3,
                                                      type=then_branch.type):
                            yield self.check(
                                ECond(cond, then_branch,
                                      else_branch).with_type(then_branch.type),
                                pool)

            for bag in cache.find_collections(pool=pool, size=size - 1):
                # len of bag
                count = EUnaryOp(UOp.Length, bag).with_type(INT)
                yield self.check(count, pool)
                # empty?
                yield self.check(
                    EUnaryOp(UOp.Empty, bag).with_type(BOOL), pool)
                # exists?
                yield self.check(
                    EUnaryOp(UOp.Exists, bag).with_type(BOOL), pool)
                # singleton?
                yield self.check(EEq(count, ONE), pool)

                yield self.check(
                    EUnaryOp(UOp.The, bag).with_type(bag.type.t), pool)
                yield self.check(
                    EUnaryOp(UOp.Distinct, bag).with_type(bag.type), pool)
                yield self.check(
                    EUnaryOp(UOp.AreUnique, bag).with_type(BOOL), pool)

                if bag.type.t == BOOL:
                    yield self.check(
                        EUnaryOp(UOp.Any, bag).with_type(BOOL), pool)
                    yield self.check(
                        EUnaryOp(UOp.All, bag).with_type(BOOL), pool)

            for (sz1, sz2) in pick_to_sum(2, size - 1):
                for bag in cache.find_collections(pool=pool, size=sz1):
                    for binder in binders_by_type[bag.type.t]:
                        for body in itertools.chain(
                                cache.find(pool=pool, size=sz2), (binder, )):
                            yield self.check(
                                EMap(bag,
                                     ELambda(binder,
                                             body)).with_type(TBag(body.type)),
                                pool)
                            if body.type == BOOL:
                                yield self.check(
                                    EFilter(bag,
                                            ELambda(binder,
                                                    body)).with_type(bag.type),
                                    pool)
                            if body.type == INT:
                                yield self.check(
                                    EArgMin(bag, ELambda(
                                        binder, body)).with_type(bag.type.t),
                                    pool)
                                yield self.check(
                                    EArgMax(bag, ELambda(
                                        binder, body)).with_type(bag.type.t),
                                    pool)
                            if pool == RUNTIME_POOL and isinstance(
                                    body.type, TBag):
                                yield self.check(
                                    EFlatMap(bag,
                                             ELambda(binder, body)).with_type(
                                                 TBag(body.type.t)), pool)

        if not build_exprs.value:
            # Keep the original contract: no compound expressions at all.
            return

        # Maps are only built in the state pool, keyed by scalar elements.
        for (sz1, sz2) in pick_to_sum(2, size - 1):
            for bag in cache.find_collections(pool=STATE_POOL, size=sz1):
                if not is_scalar(bag.type.t):
                    continue
                for b in binders_by_type[bag.type.t]:
                    for val in cache.find(pool=STATE_POOL, size=sz2):
                        t = TMap(bag.type.t, val.type)
                        m = EMakeMap2(bag, ELambda(b, val)).with_type(t)
                        yield self.check(m, STATE_POOL)
Пример #7
0
 def _is_concrete(self, e):
     """
     Report whether the type of `e` counts as concrete.

     Scalars are concrete.  Values whose type is *exactly* TMap, TSet or
     TBag are not; because the test uses ``type(...)`` rather than
     ``isinstance``, subclasses of those types still count as concrete.
     Every other type is concrete.
     """
     if is_scalar(e.type):
         return True
     non_concrete_types = (TMap, TSet, TBag)
     return type(e.type) not in non_concrete_types
Пример #8
0
    def run(self):
        """
        Search for better implementations of query `self.q`.

        While the search runs, everything printed goes to
        ``<log_dir>/<query name>.log``.  Every expression produced (starting
        with the original ``self.q.ret``, then each result of
        ``core.improve``) is split by ``tease_apart`` into a new
        representation and return expression, which are passed to
        ``self.k``.  ``core.StopException`` ends the search early;
        presumably it is raised by ``core.improve`` once ``stop_callback``
        (i.e. ``self.stop_requested``) becomes true.
        """
        from contextlib import redirect_stdout

        print("STARTING IMPROVEMENT JOB {} (|examples|={})".format(
            self.q.name, len(self.examples or ())))
        os.makedirs(log_dir.value, exist_ok=True)
        log_path = os.path.join(log_dir.value, "{}.log".format(self.q.name))
        # redirect_stdout restores sys.stdout on exit; assigning sys.stdout
        # directly left it pointing at the closed log file afterwards.
        with open(log_path, "w", buffering=LINE_BUFFER_MODE) as f, \
                redirect_stdout(f):
            print("STARTING IMPROVEMENT JOB {} (|examples|={})".format(
                self.q.name, len(self.examples or ())))
            print(pprint(self.q))

            if nice_children.value:
                os.nice(20)

            all_types = self.ctx.all_types
            expr = ETuple(
                (EAll(self.assumptions),
                 self.q.ret)).with_type(TTuple((BOOL, self.q.ret.type)))

            # Find how many binders per type fixup_binders needs for `expr`,
            # trying 1, 2, 3, ... until it stops raising.
            n_binders = 1
            done = False
            while not done:
                binders = [fresh_var(t) for t in all_types
                           for _ in range(n_binders)]
                try:
                    core.fixup_binders(expr, binders, throw=True)
                    done = True
                except Exception:
                    # Not enough binders yet; retry with one more per type.
                    # (Previously a bare `except:`, which also swallowed
                    # KeyboardInterrupt and SystemExit.)
                    pass
                n_binders += 1

            # NOTE(review): n_binders is incremented once more after the
            # successful attempt, so one more binder per scalar type is
            # created than the minimum found above.  Kept as-is; confirm
            # whether this headroom is intentional.
            binders = [
                fresh_var(t) for t in all_types if is_scalar(t)
                for _ in range(n_binders)
            ]
            print("Using {} binders".format(n_binders))

            query_vars = (free_vars(EAll(self.assumptions))
                          | free_vars(self.q.ret))
            relevant_state_vars = [v for v in self.state if v in query_vars]
            used_vars = free_vars(self.q.ret)
            for a in self.q.assumptions:
                used_vars |= free_vars(a)
            args = [EVar(v).with_type(t) for (v, t) in self.q.args]
            args = [a for a in args if a in used_vars]
            b = BinderBuilder(binders, relevant_state_vars, args)
            if accelerate.value:
                b = AcceleratedBuilder(b, binders, relevant_state_vars, args)

            try:
                for expr in itertools.chain(
                        (self.q.ret, ),
                        core.improve(
                            target=self.q.ret,
                            assumptions=EAll(self.assumptions),
                            hints=self.hints,
                            examples=self.examples,
                            binders=binders,
                            state_vars=relevant_state_vars,
                            args=args,
                            cost_model=CompositeCostModel(),
                            builder=b,
                            stop_callback=lambda: self.stop_requested)):

                    new_rep, new_ret = tease_apart(expr)
                    self.k(new_rep, new_ret)
                print("PROVED OPTIMALITY FOR {}".format(self.q.name))
            except core.StopException:
                print("stopping synthesis of {}".format(self.q.name))
                return
Пример #9
0
def improve_implementation(impl: Implementation,
                           timeout: datetime.timedelta = datetime.timedelta(
                               seconds=60),
                           progress_callback=None) -> Implementation:
    """
    Improve the query implementations of `impl` with parallel jobs.

    An ImproveQueryJob is started for each query in ``impl.query_specs``.
    Solutions reported by the jobs are applied to a copy of `impl` (for each
    query only the latest solution in a batch is kept).  After each applied
    solution the running jobs are reconciled with the possibly-changed set
    of queries.  The loop ends when every job is done or `timeout` has
    elapsed; all jobs are then stopped.

    Args:
        impl: the implementation to improve.  A copy with fresh containers
            for its state, queries and updates is made, and only that copy
            is modified.
        timeout: wall-clock budget for the improvement loop.
        progress_callback: optional; called with
            ``(impl, impl.code, impl.concretization_functions)`` after each
            applied solution.

    Returns:
        The improved copy of `impl`.
    """
    start_time = datetime.datetime.now()

    # we statefully modify `impl`, so let's make a defensive copy
    impl = Implementation(impl.spec, list(impl.concrete_state),
                          list(impl.query_specs),
                          OrderedDict(impl.query_impls),
                          defaultdict(SNoOp, impl.updates),
                          defaultdict(SNoOp, impl.handle_updates))

    # gather root types
    types = list(all_types(impl.spec))
    basic_types = {t for t in types if is_scalar(t)}
    basic_types |= {BOOL, INT}
    print("basic types:")
    for t in basic_types:
        print("  --> {}".format(pprint(t)))
    basic_types = list(basic_types)
    ctx = SynthCtx(all_types=types, basic_types=basic_types)

    # currently running improvement jobs
    improvement_jobs = []

    with jobs.SafeQueue() as solutions_q:

        def solution_callback(q):
            """Return a callback that enqueues solutions found for `q`."""
            # A factory binds `q` now instead of at call time.
            return lambda new_rep, new_ret: solutions_q.put(
                (q, new_rep, new_ret))

        def stop_jobs(js):
            """Stop every job in `js` and drop it from improvement_jobs."""
            js = list(js)
            jobs.stop_jobs(js)
            for j in js:
                improvement_jobs.remove(j)

        def reconcile_jobs():
            """Start jobs for uncovered queries; stop jobs for removed ones."""
            # figure out what new jobs we need
            job_query_names = {j.q.name for j in improvement_jobs}
            new = [
                ImproveQueryJob(
                    ctx,
                    impl.abstract_state,
                    list(impl.spec.assumptions) + list(q.assumptions),
                    q,
                    k=solution_callback(q),
                    hints=[
                        EStateVar(c).with_type(c.type)
                        for c in impl.concretization_functions.values()
                    ])
                for q in impl.query_specs if q.name not in job_query_names
            ]

            # figure out what old jobs we can stop
            impl_query_names = {q.name for q in impl.query_specs}
            old = [
                j for j in improvement_jobs if j.q.name not in impl_query_names
            ]

            # make it so
            stop_jobs(old)
            for j in new:
                j.start()
            improvement_jobs.extend(new)

        # start jobs
        reconcile_jobs()

        # wait for results
        deadline = Timeout(timeout)
        done = False
        while not done and not deadline.is_timed_out():
            for j in improvement_jobs:
                if j.done:
                    if j.successful:
                        j.join()
                    else:
                        print("failed job: {}".format(j), file=sys.stderr)

            done = all(j.done for j in improvement_jobs)

            try:
                # list of (Query, new_rep, new_ret) objects
                results = solutions_q.drain(block=True, timeout=0.5)
            except Empty:
                continue

            # group by query name, favoring later (i.e. better) solutions
            print("updating with {} new solutions".format(len(results)))
            improved_queries_by_name = OrderedDict()
            killed = 0
            for r in results:
                name = r[0].name
                if name in improved_queries_by_name:
                    killed += 1
                improved_queries_by_name[name] = r
            if killed:
                print(" --> dropped {} worse solutions".format(killed))

            improvements = list(improved_queries_by_name.values())

            # apply updates in the order the queries appear in the spec
            # (first occurrence wins); unknown queries get -1 and sort first
            position = {}
            for idx, qq in enumerate(impl.query_specs):
                position.setdefault(qq.name, idx)
            improvements.sort(key=lambda r: position.get(r[0].name, -1))
            print("update order:")
            for (q, _, _) in improvements:
                print("  --> {}".format(q.name))

            # update query implementations
            for step, (q, new_rep, new_ret) in enumerate(improvements, 1):
                print("considering update {}/{}...".format(
                    step, len(improvements)))
                # this guard might be false if a better solution was
                # enqueued but the job has already been cleaned up
                if q.name in {qq.name for qq in impl.query_specs}:
                    elapsed = datetime.datetime.now() - start_time
                    print("SOLUTION FOR {} AT {} [size={}]".format(
                        q.name, elapsed,
                        new_ret.size() + sum(proj.size()
                                             for (v, proj) in new_rep)))
                    print("-" * 40)
                    for (sv, proj) in new_rep:
                        print("  {} : {} = {}".format(sv.id, pprint(sv.type),
                                                      pprint(proj)))
                    print("  return {}".format(pprint(new_ret)))
                    print("-" * 40)
                    impl.set_impl(q, new_rep, new_ret)

                    # clean up
                    impl.cleanup()
                    if progress_callback is not None:
                        progress_callback(
                            (impl, impl.code, impl.concretization_functions))
                    reconcile_jobs()

            # reconcile_jobs() may have started jobs for newly created
            # queries after `done` was computed above; re-check so the loop
            # does not exit (and stop those jobs) before they get to run.
            done = all(j.done for j in improvement_jobs)

        # stop jobs
        print("Stopping jobs")
        stop_jobs(list(improvement_jobs))
        return impl
Пример #10
0
 def _eq(self, e1, e2):
     """
     Emit a Java equality test between `e1` and `e2`.

     Unboxed primitives are compared with ``==``; other scalars with
     ``java.util.Objects.equals``.  Non-scalar types are delegated to the
     superclass implementation.
     """
     if self.boxed or not self.is_primitive(e1.type):
         if not is_scalar(e1.type):
             return super()._eq(e1, e2)
         template = "java.util.Objects.equals({e1}, {e2})"
         names, operands = ["e1", "e2"], [e1, e2]
     else:
         template = "({e1} == {e2})"
         names, operands = ("e1", "e2"), (e1, e2)
     return self.visit(EEscape(template, names, operands).with_type(BOOL))