Example #1
0
 def check_wf(e, ctx, pool):
     """Return True if candidate `e` is acceptable, else a falsy `No` reason.

     Reads `self` from the enclosing scope (assumptions, wf_solver,
     cost_model, targets). The candidate is rejected when:
       - `exp_wf` raises ExpIsNotWf (the message names the offending
         subexpression and the reason),
       - `good_idea` gives a falsy verdict for any subexpression produced by
         `shred` (that verdict is returned unchanged), or
       - `pool` is RUNTIME_POOL and the cost model orders `e` as Order.GT
         relative to the first target ("too expensive").
     """
     with task("checking well-formedness", size=e.size()):
         try:
             exp_wf(e,
                    pool=pool,
                    context=ctx,
                    assumptions=self.assumptions,
                    solver=self.wf_solver)
         except ExpIsNotWf as not_wf:
             where = pprint(not_wf.offending_subexpression)
             return No("at {}: {}".format(where, not_wf.reason))

         # Every piece of `e`, each in its own context/pool, must pass the
         # `good_idea` heuristic.
         for sub, sub_ctx, sub_pool in shred(e, ctx, pool):
             verdict = good_idea(self.wf_solver, sub, sub_ctx, sub_pool,
                                 assumptions=self.assumptions)
             if not verdict:
                 return verdict

         # Only runtime expressions are compared against the current target's
         # cost.
         if pool == RUNTIME_POOL:
             order = self.cost_model.compare(e, self.targets[0], ctx, pool)
             if order == Order.GT:
                 return No("too expensive")
         return True
Example #2
0
 def test_let(self):
     """Substitute ZERO for the lambda-bound `x` inside a let body.

     `shred` must visit `x` exactly once. Replacing it at that location must
     produce ELet(ZERO, ELambda(x, ZERO)).
     """
     let_exp = ELet(ZERO, ELambda(x, x))
     root_ctx = RootCtx(args=(), state_vars=())
     assert retypecheck(let_exp)
     hits = 0
     for sub, sub_ctx, sub_pool in shred(let_exp, root_ctx, RUNTIME_POOL):
         if sub == x:
             rewritten = replace(let_exp, root_ctx, RUNTIME_POOL, x, sub_ctx,
                                 sub_pool, ZERO)
             assert rewritten == ELet(ZERO, ELambda(x, ZERO))
             hits += 1
     assert hits == 1
Example #3
0
    def test_pool_affects_alpha_equivalence(self):
        """The same subexpression gets different contexts in different pools.

        `shred` must find ONE exactly once from each starting pool. The
        context found from RUNTIME_POOL must be neither equal to nor
        alpha-equivalent to the one found from STATE_POOL.
        """
        e = EMap(EEmptyList().with_type(INT_BAG), ELambda(x, ONE))
        root_ctx = RootCtx(args=(), state_vars=())
        assert retypecheck(e)

        def contexts_of_one(start_pool):
            return [sub_ctx
                    for sub, sub_ctx, _ in shred(e, root_ctx, start_pool)
                    if sub == ONE]

        runtime_ctxs = contexts_of_one(RUNTIME_POOL)
        assert len(runtime_ctxs) == 1
        (runtime_ctx,) = runtime_ctxs

        state_ctxs = contexts_of_one(STATE_POOL)
        assert len(state_ctxs) == 1
        (state_ctx,) = state_ctxs

        assert runtime_ctx != state_ctx
        assert not runtime_ctx.alpha_equivalent(state_ctx)
Example #4
0
def exp_wf(e : Exp, context : Context, pool = RUNTIME_POOL, assumptions : Exp = T, solver = None):
    """Check that `e` is well-formed in `context`; raise if it is not.

    Every subexpression yielded by ``shred(e, context, pool)`` is checked
    with ``exp_wf_nonrecursive`` in its own context and pool. The
    `assumptions` are stated relative to `context` and are passed through
    ``sub_ctx.adapt(assumptions, context)`` for each subexpression.

    If `solver` is None, a fresh ``ModelCachingSolver(vars=[], funcs={})`` is
    used.

    Returns True if every subexpression passes.

    Raises ExpIsNotWf(e, offending_subexpression, reason) for the first
    failing subexpression. The per-subexpression ExpIsNotWf is chained as
    ``__cause__`` so its traceback is not lost.
    """
    if solver is None:
        solver = ModelCachingSolver(vars=[], funcs={})
    for sub, sub_ctx, sub_pool in shred(e, context, pool):
        try:
            exp_wf_nonrecursive(solver, sub, sub_ctx, sub_pool,
                                assumptions=sub_ctx.adapt(assumptions, context))
        except ExpIsNotWf as exc:
            # Re-raise with the whole expression `e` attached, keeping the
            # original failure as the explicit cause.
            raise ExpIsNotWf(e, sub, exc.reason) from exc
    return True
Example #5
0
    def test_estatevar_ctx(self):
        """A state variable wrapped in EStateVar is shredded in a RootCtx.

        Replacing it there with T must rewrite only the state expression.
        """
        xs = EVar("xs").with_type(INT_BAG)
        x = EVar("x").with_type(INT)
        y = EVar("y").with_type(BOOL)
        e = EMap(xs, ELambda(x, EStateVar(y)))
        root_ctx = RootCtx(args=(xs, ), state_vars=(y, ))
        assert retypecheck(e)

        # Keep the loop's context distinct from `root_ctx`. `replace` below
        # needs the root context for the whole expression, not whatever
        # context `shred` happened to yield last.
        y_ctxs = [sub_ctx for ee, sub_ctx, pool in shred(e, root_ctx)
                  if ee == y]
        assert len(y_ctxs) == 1, "shred should visit y exactly once"
        y_ctx = y_ctxs[0]
        assert isinstance(y_ctx, RootCtx)

        e = replace(e, root_ctx, RUNTIME_POOL, y, y_ctx, STATE_POOL, T)

        assert e == EMap(xs, ELambda(x, EStateVar(T))), pprint(e)
Example #6
0
def good_idea_recursive(solver,
                        e: Exp,
                        context: Context,
                        pool=RUNTIME_POOL,
                        assumptions: Exp = T,
                        ops: [Op] = ()) -> bool:
    """Run `good_idea` on every subexpression of `e`.

    Subexpressions are visited in the order ``shred(e, context, pool)``
    yields them. Each is checked in its own context and pool with the shared
    `solver`, `assumptions` and `ops`. The first falsy verdict is returned
    unchanged; if there is none, the result is True.
    """
    for piece in shred(e, context, pool):
        sub, sub_ctx, sub_pool = piece
        verdict = good_idea(solver, sub, sub_ctx, sub_pool,
                            assumptions=assumptions, ops=ops)
        if not verdict:
            return verdict
    return True
Example #7
0
def fold_into_map(e, context):
    """Yield rewrites of `e` that fold the surrounding computation into a map.

    The candidates are the subexpressions visited by
    ``shred(e, context, RUNTIME_POOL)`` that have the form
    ``EMapGet(EStateVar(m), key)``. For each one, the rest of `e` is
    abstracted as ``func = ELambda(arg, e[subexp := arg])`` with a fresh
    `arg`.

    A candidate is kept only if every free variable of `func` is a
    STATE_POOL variable of `context`. For each kept candidate this yields
    ``EMapGet(EStateVar(map_values(m, func'.apply_to)), key)`` typed as
    `e.type`, where `func'` is ``strip_EStateVar(func)``.
    """
    fvs = free_vars(e)
    # A set: membership is tested for every free variable of every candidate.
    state_vars = {v for v, p in context.vars() if p == STATE_POOL}
    for subexp, subcontext, subpool in shred(e, context, RUNTIME_POOL):
        if not (isinstance(subexp, EMapGet)
                and isinstance(subexp.map, EStateVar)):
            continue
        map_exp = subexp.map.e
        key = subexp.key
        # e is of the form `... EStateVar(map_exp)[key] ...`
        arg = fresh_var(subexp.type, omit=fvs)
        func = ELambda(
            arg,
            replace(e, context, RUNTIME_POOL, subexp, subcontext, subpool,
                    arg))
        if not all(v in state_vars for v in free_vars(func)):
            continue
        func = strip_EStateVar(func)
        new_map = map_values(map_exp, func.apply_to)
        yield EMapGet(EStateVar(new_map).with_type(new_map.type),
                      key).with_type(e.type)
Example #8
0
def exploration_order(targets: [Exp],
                      context: Context,
                      pool: Pool = RUNTIME_POOL):
    """Order the subexpressions of each target for improvement attempts.

    Targets are handled one at a time, in the given order. Each target's
    distinct subexpressions (``unique(shred(...))``) are sorted by:
      1. pool: RUNTIME_POOL before any other pool,
      2. ``ctx.complexity()``, ascending,
      3. ``e.size()``, ascending.

    Yields (target, subexpression, subcontext, subpool) tuples.
    """

    def priority(candidate):
        sub, sub_ctx, sub_pool = candidate
        pool_rank = 0 if sub_pool == RUNTIME_POOL else 1
        return (pool_rank, sub_ctx.complexity(), sub.size())

    for target in targets:
        ordered = sorted(unique(shred(target, context, pool=pool)),
                         key=priority)
        for sub, sub_ctx, sub_pool in ordered:
            yield target, sub, sub_ctx, sub_pool
Example #9
0
def exp_wf(e: Exp,
           context: Context,
           pool=RUNTIME_POOL,
           assumptions: Exp = T,
           solver=None):
    """Check the well-formedness of `e`.

    Returns True when every subexpression of `e` (as produced by `shred`)
    passes `exp_wf_nonrecursive`. Otherwise it returns (it does not raise) a
    value describing the first failure:
      - an ExpIsNotWf naming `e`, the offending subexpression and the
        message, when the check returned a `No`;
      - the check's own falsy result in any other case.

    Parameters:
        e - the expression to check
        context - a context describing e's variables
        pool - the pool e lives in
        assumptions - facts that hold whenever e begins executing
            (NOTE: these do NOT need to include the path conditions from the
            context, but including them is fine.)
        solver - a ModelCachingSolver for solving formulas; a fresh one is
            created when omitted

    Preconditions:
     - all free variables in `e` are used in the correct pool
     - EStateVar only occurs in runtime expressions
    """
    if solver is None:
        solver = ModelCachingSolver(vars=[], funcs={})
    for sub, sub_ctx, sub_pool in shred(e, context, pool):
        local_assumptions = sub_ctx.adapt(assumptions, context)
        verdict = exp_wf_nonrecursive(solver, sub, sub_ctx, sub_pool,
                                      assumptions=local_assumptions)
        if verdict:
            continue
        # Wrap a plain `No` so callers learn which subexpression failed; any
        # other falsy verdict is passed through untouched.
        if isinstance(verdict, No):
            return ExpIsNotWf(e, sub, verdict.msg)
        return verdict
    return True
Example #10
0
 def test_shred_minheap(self):
     """Fully consuming `shred` on an EMakeMinHeap must not raise."""
     key_func = ELambda(x, x)
     heap = EMakeMinHeap(EEmptyList().with_type(INT_BAG), key_func)
     heap = heap.with_type(TMinHeap(INT, key_func))
     root_ctx = RootCtx(args=(), state_vars=())
     for _ in shred(heap, root_ctx):
         pass
Example #11
0
    def next(self):
        """Generate candidate improvements of `self.targets`, forever.

        Each minor iteration enumerates expressions of the current `size`
        (0, 1, 2, ...). For every subexpression visited by
        `exploration_order`, candidates are enumerated in that
        subexpression's context and pool. Each candidate with the same type
        that is not alpha-equivalent to the subexpression is substituted for
        it. The rewritten target is yielded when it:
          (1) is not alpha-equivalent to an existing target,
          (2) has a fingerprint that `self.matches` the first target's,
          (3) passes `check_wf`, and
          (4) compares as Order.LT or Order.AMBIGUOUS against the target it
              was derived from.

        Substitutions that fail (2)-(4) are added to `self.blacklist` and
        skipped afterwards. The blacklist is created as an empty set the
        first time it is needed and persists across calls.

        Raises StopException whenever `self.stop_callback()` returns true.
        """
        # Local falsy result type carrying a human-readable rejection reason.
        class No(object):
            def __init__(self, msg):
                self.msg = msg

            def __bool__(self):
                return False

            def __str__(self):
                return "no: {}".format(self.msg)

        # with task("pre-computing cardinalities"):
        #     cards = [self.cost_model.cardinality(ctx.e) for ctx in enumerate_fragments(self.target) if is_collection(ctx.e.type)]

        root_ctx = self.context

        def check_wf(e, ctx, pool):
            """Return True if `e` is acceptable, else a falsy `No` reason.

            Rejects `e` when `exp_wf` raises ExpIsNotWf. Also rejects it when
            it lives in RUNTIME_POOL and the cost model orders it Order.GT
            relative to the first target.
            """
            with task("checking well-formedness", size=e.size()):
                try:
                    exp_wf(e,
                           pool=pool,
                           context=ctx,
                           assumptions=self.assumptions,
                           solver=self.wf_solver)
                except ExpIsNotWf as exc:
                    return No("at {}: {}".format(
                        pprint(exc.offending_subexpression), exc.reason))
                if pool == RUNTIME_POOL and self.cost_model.compare(
                        e, self.targets[0], ctx, pool) == Order.GT:
                    # from cozy.cost_model import debug_comparison
                    # debug_comparison(self.cost_model, e, self.target, ctx)
                    return No("too expensive")
                # if isinstance(e.type, TBag):
                #     c = self.cost_model.cardinality(e)
                #     if all(cc < c for cc in cards):
                #         # print("too big: {}".format(pprint(e)))
                #         return No("too big")
                return True

        # Seed the enumerator with every distinct subexpression of the
        # targets and the hints.
        frags = list(
            unique(
                itertools.chain(*[shred(t, root_ctx) for t in self.targets],
                                *[shred(h, root_ctx) for h in self.hints])))
        enum = Enumerator(examples=self.examples,
                          cost_model=self.cost_model,
                          check_wf=check_wf,
                          hints=frags,
                          heuristics=try_optimize,
                          stop_callback=self.stop_callback)

        size = 0
        # target_cost = self.cost_model.cost(self.target, RUNTIME_POOL)
        # Every accepted rewrite must match the first target's behavior on
        # the examples.
        target_fp = fingerprint(self.targets[0], self.examples)

        # Set of (e, ctx, pool, replacement) keys already known to fail.
        if not hasattr(self, "blacklist"):
            self.blacklist = set()

        while True:

            print("starting minor iteration {} with |cache|={}".format(
                size, enum.cache_size()))
            if self.stop_callback():
                raise StopException()

            # Count of substitutions that got past the cheap filters
            # (type, no-change, blacklist) this iteration.
            n = 0
            for target, e, ctx, pool in exploration_order(
                    self.targets, root_ctx):
                # The task label shows `target` with `e` replaced by a "___"
                # placeholder, i.e. where the substitution happens.
                with task("checking substitutions",
                          target=pprint(
                              replace(target, root_ctx, RUNTIME_POOL, e, ctx,
                                      pool, EVar("___"))),
                          e=pprint(e)):
                    for info in enum.enumerate_with_info(size=size,
                                                         context=ctx,
                                                         pool=pool):
                        with task("checking substitution",
                                  expression=pprint(info.e)):
                            if self.stop_callback():
                                raise StopException()
                            if info.e.type != e.type:
                                event("wrong type (is {}, need {})".format(
                                    pprint(info.e.type), pprint(e.type)))
                                continue
                            if alpha_equivalent(info.e, e):
                                event("no change")
                                continue

                            # Identifies this substitution: what is replaced,
                            # where, and with what.
                            k = (e, ctx, pool, info.e)
                            if k in self.blacklist:
                                event("blacklisted")
                                continue

                            n += 1
                            ee = freshen_binders(
                                replace(target, root_ctx, RUNTIME_POOL, e, ctx,
                                        pool, info.e), root_ctx)
                            # Not blacklisted: this rejection reflects the
                            # current targets rather than the substitution.
                            if any(
                                    alpha_equivalent(t, ee)
                                    for t in self.targets):
                                event("already seen")
                                continue
                            if not self.matches(fingerprint(ee, self.examples),
                                                target_fp):
                                event("incorrect")
                                self.blacklist.add(k)
                                continue
                            wf = check_wf(ee, root_ctx, RUNTIME_POOL)
                            if not wf:
                                event("not well-formed [wf={}]".format(wf))
                                # if "expensive" in str(wf):
                                #     print(repr(self.cost_model.examples))
                                #     print(repr(ee))
                                self.blacklist.add(k)
                                continue
                            if self.cost_model.compare(
                                    ee, target, root_ctx,
                                    RUNTIME_POOL) not in (Order.LT,
                                                          Order.AMBIGUOUS):
                                event("not an improvement")
                                self.blacklist.add(k)
                                continue
                            print(
                                "FOUND A GUESS AFTER {} CONSIDERED".format(n))
                            yield ee

            print("CONSIDERED {}".format(n))
            size += 1

        # NOTE(review): unreachable. The `while True` loop above has no
        # `break`; it exits only via StopException or when the consumer stops
        # iterating the generator.
        raise NoMoreImprovements()
Example #12
0
    def search(self):
        """Generate candidate improvements of `self.targets`, forever.

        Each minor iteration asks an `Enumerator` for expressions of the
        current `size` (0, 1, 2, ...). It tries them as replacements for
        subexpressions of the targets in two phases:

        1. Obvious substitutions: an enumerated expression whose fingerprint
           matches a watched target subexpression (same context and pool) is
           substituted for that subexpression.
        2. If `check_all_substitutions.value` is set, every subexpression
           from `exploration_order` is also tried against every enumerated
           expression of the same type that `should_consider_replacement`
           accepts.

        A rewritten target is yielded when it:
          - is not alpha-equivalent to an existing target,
          - passes `check_wf`,
          - has a fingerprint matching the first target's, and
          - compares as Order.LT or Order.AMBIGUOUS against the target it
            was derived from.

        Substitutions rejected by the last three checks are recorded in
        `self.blacklist`, a mapping from substitution key to reason string.
        NOTE(review): this method assumes `self.blacklist` already exists.

        Raises StopException whenever a check of `self.stop_callback()`
        returns true.
        """

        root_ctx = self.context

        def check_wf(e, ctx, pool):
            """Return True if `e` is an acceptable candidate, else a falsy reason.

            Checks, in order:
              - `exp_wf`,
              - `good_idea_recursive` (with `self.ops`),
              - for RUNTIME_POOL expressions only, that the cost model does
                not order `e` as Order.GT relative to the first target.
            """
            with task("checking well-formedness", size=e.size()):
                is_wf = exp_wf(e,
                               pool=pool,
                               context=ctx,
                               solver=self.wf_solver)
                if not is_wf:
                    return is_wf
                res = good_idea_recursive(self.wf_solver,
                                          e,
                                          ctx,
                                          pool,
                                          ops=self.ops)
                if not res:
                    return res
                if pool == RUNTIME_POOL and self.cost_model.compare(
                        e, self.targets[0], ctx, pool) == Order.GT:
                    return No("too expensive")
                return True

        # Seed the enumerator with every distinct subexpression of the
        # targets and hints, ordered by `hint_order`.
        frags = list(
            unique(
                itertools.chain(*[shred(t, root_ctx) for t in self.targets],
                                *[shred(h, root_ctx) for h in self.hints])))
        frags.sort(key=hint_order)
        enum = Enumerator(examples=self.examples,
                          cost_model=self.cost_model,
                          check_wf=check_wf,
                          hints=frags,
                          heuristics=try_optimize,
                          stop_callback=self.stop_callback,
                          do_eviction=enable_eviction.value)

        size = 0
        target_fp = fingerprint(self.targets[0], self.examples)

        # Index each distinct subexpression of each target by
        # (fingerprint on its context's instantiated examples, ctx, pool).
        # Each entry lists the (target, subexpression) pairs with that key.
        watches = OrderedDict()
        for target in self.targets:
            for e, ctx, pool in unique(
                    shred(target, context=root_ctx, pool=RUNTIME_POOL)):
                exs = ctx.instantiate_examples(self.examples)
                fp = fingerprint(e, exs)
                watches.setdefault((fp, ctx, pool), []).append((target, e))
        watched_ctxs = list(
            unique((ctx, pool) for fp, ctx, pool in watches.keys()))

        def consider_new_target(old_target, e, ctx, pool, replacement):
            """Try substituting `replacement` for `e` (in `ctx`/`pool`) in `old_target`.

            This is a generator. It yields the rewritten target if it is
            acceptable and yields nothing otherwise. Every call increments
            the enclosing counter `n`.
            """
            nonlocal n
            n += 1
            k = (e, ctx, pool, replacement)
            if enable_blacklist.value and k in self.blacklist:
                event("blacklisted")
                print("skipping blacklisted substitution: {} ---> {} ({})".
                      format(pprint(e), pprint(replacement),
                             self.blacklist[k]))
                return
            # Rewrite the target passed in as `old_target`. Do not read the
            # enclosing loops' `target` variable, which this closure would
            # otherwise pick up through late binding.
            new_target = freshen_binders(
                replace(old_target, root_ctx, RUNTIME_POOL, e, ctx, pool,
                        replacement), root_ctx)
            if any(alpha_equivalent(t, new_target) for t in self.targets):
                event("already seen")
                return
            wf = check_wf(new_target, root_ctx, RUNTIME_POOL)
            if not wf:
                msg = "not well-formed [wf={}]".format(wf)
                event(msg)
                self.blacklist[k] = msg
                return
            if not fingerprints_match(fingerprint(new_target, self.examples),
                                      target_fp):
                msg = "not correct"
                event(msg)
                self.blacklist[k] = msg
                return
            if self.cost_model.compare(new_target, old_target, root_ctx,
                                       RUNTIME_POOL) not in (Order.LT,
                                                             Order.AMBIGUOUS):
                msg = "not an improvement"
                event(msg)
                self.blacklist[k] = msg
                return
            print("FOUND A GUESS AFTER {} CONSIDERED".format(n))
            print(" * in {}".format(pprint(old_target)))
            print(" * replacing {}".format(pprint(e)))
            print(" * with {}".format(pprint(replacement)))
            yield new_target

        while True:

            print("starting minor iteration {} with |cache|={}".format(
                size, enum.cache_size()))
            if self.stop_callback():
                raise StopException()

            # Number of substitutions handed to consider_new_target this
            # iteration (incremented inside it via `nonlocal`).
            n = 0

            # Phase 1: substitutions whose fingerprints already match.
            for ctx, pool in watched_ctxs:
                with task("searching for obvious substitutions",
                          ctx=ctx,
                          pool=pool_name(pool)):
                    for info in enum.enumerate_with_info(size=size,
                                                         context=ctx,
                                                         pool=pool):
                        with task("searching for obvious substitution",
                                  expression=pprint(info.e)):
                            # Honor stop requests here as phase 2 does;
                            # enumeration can run for a long time.
                            if self.stop_callback():
                                raise StopException()
                            fp = info.fingerprint
                            for ((fpx, cc, pp), reses) in watches.items():
                                if cc != ctx or pp != pool:
                                    continue

                                if not fingerprints_match(fpx, fp):
                                    continue

                                for target, watched_e in reses:
                                    replacement = info.e
                                    event("possible substitution: {} ---> {}".
                                          format(pprint(watched_e),
                                                 pprint(replacement)))
                                    event("replacement locations: {}".format(
                                        pprint(
                                            replace(target, root_ctx,
                                                    RUNTIME_POOL, watched_e,
                                                    ctx, pool, EVar("___")))))

                                    if alpha_equivalent(
                                            watched_e, replacement):
                                        event("no change")
                                        continue

                                    yield from consider_new_target(
                                        target, watched_e, ctx, pool,
                                        replacement)

            # Phase 2 (optional): try every same-typed candidate everywhere.
            if check_all_substitutions.value:
                print("Guessing at substitutions...")
                for target, e, ctx, pool in exploration_order(
                        self.targets, root_ctx):
                    with task("checking substitutions",
                              target=pprint(
                                  replace(target, root_ctx, RUNTIME_POOL, e,
                                          ctx, pool, EVar("___"))),
                              e=pprint(e)):
                        # Loop-invariant: the fingerprint of the expression
                        # being replaced does not depend on the candidate.
                        e_fp = fingerprint(
                            e, ctx.instantiate_examples(self.examples))
                        for info in enum.enumerate_with_info(size=size,
                                                             context=ctx,
                                                             pool=pool):
                            with task("checking substitution",
                                      expression=pprint(info.e)):
                                if self.stop_callback():
                                    raise StopException()
                                replacement = info.e
                                if replacement.type != e.type:
                                    event("wrong type (is {}, need {})".format(
                                        pprint(replacement.type),
                                        pprint(e.type)))
                                    continue
                                if alpha_equivalent(replacement, e):
                                    event("no change")
                                    continue
                                should_consider = should_consider_replacement(
                                    target, root_ctx, e, ctx, pool, e_fp,
                                    replacement, info.fingerprint)
                                if not should_consider:
                                    event(
                                        "skipped; `should_consider_replacement` returned {}"
                                        .format(should_consider))
                                    continue

                                yield from consider_new_target(
                                    target, e, ctx, pool, replacement)

            print("CONSIDERED {}".format(n))
            size += 1