Ejemplo n.º 1
0
 def watch(self, new_target):
     """
     Point this object at `new_target` and rebuild its search state.

     Resets the backlog counter, rebuilds `self.roots` (a list of
     (expression, pool) pairs sorted by expression size) from the
     subexpressions of the target and of every hint, and rebuilds
     `self._watches`, which groups the target's fragments by (pool, type)
     with the largest fragments first.
     """
     print("watching new target...")
     self.backlog_counter = 0
     self.target = new_target
     self.roots = OrderedSet()
     root_types = OrderedSet()

     sources = [all_exps(new_target)] + [all_exps(hint) for hint in self.hints]
     for candidate in itertools.chain.from_iterable(sources):
         # Lambdas are never used as roots.
         if isinstance(candidate, ELambda):
             continue
         for pool in ALL_POOLS:
             root = strip_EStateVar(candidate) if pool == STATE_POOL else candidate
             fvs = free_vars(root)
             legal = all(v in self.legal_free_vars for v in fvs) and self.is_legal_in_pool(root, pool)
             if not legal:
                 _on_exp(root, "rejected root", pool_name(pool))
                 continue
             _on_exp(root, "new root", pool_name(pool))
             root._root = True
             self.roots.add((root, pool))
             # A state-pool root that mentions only state variables is also
             # offered to the runtime pool, wrapped in EStateVar.
             if pool == STATE_POOL and all(v in self.state_vars for v in fvs):
                 wrapped = EStateVar(root).with_type(root.type)
                 self.roots.add((wrapped, RUNTIME_POOL))
             root_types.add(root.type)

     for binder in self.binders:
         root_types.add(binder.type)
     # Every type seen so far contributes one constructed value as a runtime root.
     for t in root_types:
         self.roots.add((construct_value(t), RUNTIME_POOL))
     self.roots = sorted(self.roots, key=lambda root_and_pool: root_and_pool[0].size())

     def watch_key(ctx):
         return (ctx.pool, ctx.e.type)

     def biggest_first(ctxs):
         return sorted(ctxs, key=lambda ctx: -ctx.e.size())

     self._watches = group_by(enumerate_fragments2(new_target), k=watch_key, v=biggest_first)
     print("done!")
Ejemplo n.º 2
0
def _consider(e, size, context, pool):
    """Report that expression `e` is being considered.

    Prints a one-line notice when `_interesting` accepts the expression and
    `verbose.value` is false, then opens a "considering expression" task via
    `task_begin`, passing along the expression's details.
    """
    announce = _interesting(e, size, context, pool) and not verbose.value
    if announce:
        notice = "considering {} @ size={} in {}/{}".format(
            pprint(e), size, context, pool_name(pool))
        print(notice)
    task_begin(
        "considering expression",
        expression=pprint(e),
        size=size,
        context=context,
        pool=pool_name(pool),
        interesting=_interesting(e, size, context, pool))
Ejemplo n.º 3
0
def _consider(e, size, context, pool):
    """Called when an Enumerator sees an expression for the first time.

    Interesting expressions are echoed to stdout unless verbose output is
    enabled; in every case a "considering expression" task is started with
    the expression, size, context, pool name and interestingness attached.
    """
    if _interesting(e, size, context, pool):
        if not verbose.value:
            template = "considering {} @ size={} in {}/{}"
            print(template.format(pprint(e), size, context, pool_name(pool)))

    rendered = pprint(e)
    pool_label = pool_name(pool)
    task_begin("considering expression",
               expression=rendered,
               size=size,
               context=context,
               pool=pool_label,
               interesting=_interesting(e, size, context, pool))
Ejemplo n.º 4
0
def _accept(e, size, context, pool, fingerprint):
    """Called when an Enumerator "accepts" an expression and adds it to the cache.

    Prints the fingerprint of interesting expressions unless verbose output
    is enabled, logs an "accepting" event, and then calls `task_end()`.
    """
    if _interesting(e, size, context, pool):
        if not verbose.value:
            print("accepting [fp={}]".format(fingerprint))
    description = "accepting {} @ {} in {}/{}".format(
        pprint(e), size, context, pool_name(pool))
    event(description)
    task_end()
Ejemplo n.º 5
0
    def _possible_replacements(self, e, pool, cost):
        """
        Yield one tuple per watched expression that `e` could replace.

        Each yielded tuple is (watched expression, e, context facts,
        replacement function). A watched expression is skipped when it equals
        `e`, when `e` mentions binders the watched context does not bind, when
        (with CHECK_SUBST_COST enabled) `e` is not strictly better in cost, or
        when substituting `e` makes the target disagree with itself on some
        example.
        """
        binders_in_e = OrderedSet(v for v in free_vars(e) if v in self.binders)
        for ctx in self._watched_contexts(pool, e.type):
            watched = ctx.e
            watched_pool = ctx.pool
            substitute = ctx.replace_e_with

            assert e.type == watched.type
            assert watched_pool == pool
            _on_exp(e, "considering replacement of", watched)

            # Replacing an expression with itself changes nothing.
            if e == watched:
                continue

            escaping = [b for b in binders_in_e if b not in ctx.bound_vars]
            if escaping:
                _on_exp(e, "skipped exp with free binders", ", ".join(b.id for b in escaping))
                continue

            if CHECK_SUBST_COST:
                watched_cost = self.cost_model.cost(watched, pool=pool)
                ordering = self.compare_costs(cost, watched_cost)
                skip_reason = None
                if ordering == Cost.WORSE:
                    skip_reason = "skipped worse replacement"
                elif ordering == Cost.UNORDERED:
                    skip_reason = "skipped equivalent replacement"
                if skip_reason is not None:
                    _on_exp(e, skip_reason, pool_name(pool), watched)
                    continue

            # Only report substitutions that leave the target unchanged on
            # every known example.
            agrees = all(eval_bulk(EEq(self.target, substitute(e)), self.all_examples))
            if not agrees:
                _on_exp(e, "visited pointless replacement", watched)
                continue
            yield (watched, e, ctx.facts, substitute)
Ejemplo n.º 6
0
    def enumerate_with_info(self, context: Context, size: int,
                            pool: Pool) -> [EnumeratedExp]:
        """Yield EnumeratedExp records for the request (context, size, pool).

        A request for a non-canonical context is redirected to the canonical
        context, and each result's expression is adapted back with
        `context.adapt`. Otherwise the parent context's results (if there is a
        parent) are yielded first, followed by the results for this exact
        (pool, size, context) key: taken from `self.cache` when an entry
        exists, or else produced from `self.enumerate_core`, filtered, and
        appended to the cache entry as they are accepted.

        Raises:
            StopException: when `self.stop_callback()` returns a true value.
        """
        canonical_context = self.canonical_context(context)
        if canonical_context is not context:
            print("adapting request: {} ---> {}".format(
                context, canonical_context))
            for info in self.enumerate_with_info(canonical_context, size,
                                                 pool):
                yield info._replace(e=context.adapt(info.e, canonical_context))
            return

        # Results from the parent context are yielded before this context's own.
        if context.parent() is not None:
            yield from self.enumerate_with_info(context.parent(), size, pool)

        k = (pool, size, context)
        res = self.cache.get(k)
        if res is not None:
            # print("[[{} cached @ size={}]]".format(len(res), size))
            for e in res:
                yield e
        else:
            # print("ENTER {}".format(k))
            examples = context.instantiate_examples(self.examples)
            assert k not in self.in_progress, "recursive enumeration?? {}".format(
                k)
            self.in_progress.add(k)
            # The (still empty) result list is registered in the cache before it
            # is filled, so the nested enumerate_with_info call below for this
            # same key takes the cached branch and sees what was accepted so far.
            res = []
            self.cache[k] = res
            queue = self.enumerate_core(context, size, pool)
            cost_model = self.cost_model
            while True:
                if self.stop_callback():
                    raise StopException()

                # `queue` may be rebound below (accelerations are chained in
                # front of it), hence the explicit next() instead of a for loop.
                try:
                    e = next(queue)
                except StopIteration:
                    break

                fvs = free_vars(e)
                if not belongs_in_context(fvs, context):
                    continue

                e = freshen_binders(e, context)
                _consider(e, context, pool)

                wf = self.check_wf(e, context, pool)
                if not wf:
                    _skip(e, context, pool, "wf={}".format(wf))
                    continue

                fp = fingerprint(e, examples)

                # collect all expressions from parent contexts
                # (every size from 0 through `size` is re-enumerated, then only
                # the expressions sharing this fingerprint are kept)
                with task("collecting prev exps",
                          size=size,
                          context=context,
                          pool=pool_name(pool)):
                    prev = []
                    for sz in range(0, size + 1):
                        prev.extend(self.enumerate_with_info(
                            context, sz, pool))
                    prev = [p.e for p in prev if p.fingerprint == fp]

                if any(alpha_equivalent(e, p) for p in prev):
                    _skip(e, context, pool, "duplicate")
                    should_keep = False
                else:
                    # decide whether to keep this expression,
                    # decide which can be evicted
                    should_keep = True
                    # cost = self.cost_model.cost(e, pool)
                    # print("prev={}".format(prev))
                    # print("seen={}".format(self.seen))
                    with task("comparing to cached equivalents"):
                        for prev_exp in prev:
                            event("previous: {}".format(pprint(prev_exp)))
                            # prev_cost = self.cost_model.cost(prev_exp, pool)
                            # ordering = cost.compare_to(prev_cost)
                            # `e` is dropped as soon as the eviction policy
                            # leaves it out for any same-fingerprint expression.
                            to_keep = eviction_policy(e, context, prev_exp,
                                                      context, pool,
                                                      cost_model)
                            if e not in to_keep:
                                _skip(e, context, pool,
                                      "preferring {}".format(pprint(prev_exp)))
                                should_keep = False
                                break

                            # if ordering == Order.LT:
                            #     pass
                            # elif ordering == Order.GT:
                            #     self.blacklist.add(e_key)
                            #     _skip(e, context, pool, "worse than {}".format(pprint(prev_exp)))
                            #     should_keep = False
                            #     break
                            # else:
                            #     self.blacklist.add(e_key)
                            #     _skip(e, context, pool, "{} to cached {}".format(
                            #         "equal" if ordering == Order.EQUAL else "similar",
                            #         pprint(prev_exp)))
                            #     assert ordering in (Order.EQUAL, Order.AMBIGUOUS)
                            #     should_keep = False
                            #     break

                if should_keep:

                    # Evict cached expressions with the same fingerprint (in this
                    # pool, in this context or one of its parents) that the
                    # eviction policy no longer keeps now that `e` exists.
                    # Victims are collected first and removed afterwards.
                    with task("evicting"):
                        to_evict = []
                        for (key, exps) in self.cache.items():
                            (p, s, c) = key
                            if p == pool and c in itertools.chain(
                                [context], parent_contexts(context)):
                                for ee in exps:
                                    if ee.fingerprint == fp:  # and cost_model.compare(e, ee.e, context, pool) == Order.LT:
                                        # to_evict.append((key, ee))
                                        to_keep = eviction_policy(
                                            e, context, ee.e, c, pool,
                                            cost_model)
                                        if ee.e not in to_keep:
                                            to_evict.append((key, ee))
                        for key, ee in to_evict:
                            (p, s, c) = key
                            # self.blacklist.add((ee.e, c, pool))
                            _evict(ee.e, c, pool, e)
                            self.cache[key].remove(ee)
                            self.seen[(c, p, fp)].remove(ee.e)

                    # Record the accepted expression in `seen` and in the cache
                    # entry before handing it to the caller.
                    _accept(e, context, pool)
                    seen_key = (context, pool, fp)
                    if seen_key not in self.seen:
                        self.seen[seen_key] = []
                    self.seen[seen_key].append(e)
                    info = EnumeratedExp(e=e, fingerprint=fp, cost=None)
                    res.append(info)
                    yield info

                    # Expressions suggested by the heuristics for `e` are
                    # processed before the rest of the queue.
                    with task("accelerating"):
                        to_try = make_random_access(
                            self.heuristics(e, context, pool))
                        if to_try:
                            # print("trying {} accelerations".format(len(to_try)))
                            queue = itertools.chain(to_try, queue)

            # print("EXIT {}".format(k))
            # NOTE(review): this line is not inside a try/finally, so it is not
            # reached if StopException is raised above or if the consumer
            # abandons the generator early; `k` then stays in `in_progress`.
            self.in_progress.remove(k)
Ejemplo n.º 7
0
def _accept(e, size, context, pool):
    """Report that expression `e` has been accepted.

    Prints "accepting" for interesting expressions unless verbose output is
    enabled, logs an event describing the accepted expression, and then calls
    `task_end()`.
    """
    should_print = _interesting(e, size, context, pool) and not verbose.value
    if should_print:
        print("accepting")
    summary = "accepting {} @ {} in {}/{}".format(
        pprint(e), size, context, pool_name(pool))
    event(summary)
    task_end()
Ejemplo n.º 8
0
    def enumerate_with_info(self, context: Context, size: int,
                            pool: Pool) -> [EnumeratedExp]:
        """Yield EnumeratedExp records for the request (context, size, pool).

        A request for a non-canonical context is redirected to the canonical
        context, and each result's expression is adapted back with
        `context.adapt`. Otherwise the parent context's expressions (if there
        is a parent) are yielded first, re-fingerprinted against this context's
        examples, followed by the results for this exact (pool, size, context)
        key: taken from `self.cache` when an entry exists, or else produced
        from `self.enumerate_core`, filtered, and appended to the cache entry
        as they are accepted.

        Raises:
            StopException: when `self.stop_callback()` returns a true value.
        """
        canonical_context = self.canonical_context(context)
        if canonical_context is not context:
            print("adapting request: {} ---> {}".format(
                context, canonical_context))
            for info in self.enumerate_with_info(canonical_context, size,
                                                 pool):
                yield info._replace(e=context.adapt(info.e, canonical_context))
            return

        examples = context.instantiate_examples(self.examples)
        if context.parent() is not None:
            # Parent results are wrapped in fresh records whose fingerprint is
            # computed with this context's examples rather than the parent's.
            for info in self.enumerate_with_info(context.parent(), size, pool):
                e = info.e
                yield EnumeratedExp(e=e, fingerprint=fingerprint(e, examples))

        k = (pool, size, context)
        res = self.cache.get(k)
        if res is not None:
            for e in res:
                yield e
        else:
            assert k not in self.in_progress, "recursive enumeration?? {}".format(
                k)
            self.in_progress.add(k)
            # The (still empty) result list is registered in the cache before it
            # is filled, so the nested enumerate_with_info call below for this
            # same key takes the cached branch and sees what was accepted so far.
            res = []
            self.cache[k] = res
            queue = self.enumerate_core(context, size, pool)
            cost_model = self.cost_model
            while True:
                if self.stop_callback():
                    raise StopException()

                # `queue` may be rebound below (accelerations are chained in
                # front of it), hence the explicit next() instead of a for loop.
                try:
                    e = next(queue)
                except StopIteration:
                    break

                fvs = free_vars(e)
                if not belongs_in_context(fvs, context):
                    continue

                e = freshen_binders(e, context)
                _consider(e, size, context, pool)

                wf = self.check_wf(e, context, pool)
                if not wf:
                    _skip(e, size, context, pool, "wf={}".format(wf))
                    continue

                fp = fingerprint(e, examples)

                # collect all expressions from parent contexts
                # (every size from 0 through `size` is re-enumerated, then only
                # the expressions sharing this fingerprint are kept)
                with task("collecting prev exps",
                          size=size,
                          context=context,
                          pool=pool_name(pool)):
                    prev = []
                    for sz in range(0, size + 1):
                        prev.extend(self.enumerate_with_info(
                            context, sz, pool))
                    prev = [p.e for p in prev if p.fingerprint == fp]

                if any(alpha_equivalent(e, p) for p in prev):
                    _skip(e, size, context, pool, "duplicate")
                    should_keep = False
                else:
                    # decide whether to keep this expression
                    # (`e` is dropped as soon as the eviction policy leaves it
                    # out for any same-fingerprint expression)
                    should_keep = True
                    with task("comparing to cached equivalents"):
                        for prev_exp in prev:
                            event("previous: {}".format(pprint(prev_exp)))
                            to_keep = eviction_policy(e, context, prev_exp,
                                                      context, pool,
                                                      cost_model)
                            if e not in to_keep:
                                _skip(e, size, context, pool,
                                      "preferring {}".format(pprint(prev_exp)))
                                should_keep = False
                                break

                if should_keep:

                    # Optional eviction: cached expressions with the same
                    # fingerprint, pool and context (at any size) are removed
                    # when the eviction policy no longer keeps them. Victims are
                    # collected first and removed afterwards.
                    if self.do_eviction:
                        with task("evicting"):
                            to_evict = []
                            for (key, exps) in self.cache.items():
                                (p, s, c) = key
                                if p == pool and c == context:
                                    for ee in exps:
                                        if ee.fingerprint == fp:
                                            event("considering eviction of {}".
                                                  format(pprint(ee.e)))
                                            to_keep = eviction_policy(
                                                e, context, ee.e, c, pool,
                                                cost_model)
                                            if ee.e not in to_keep:
                                                to_evict.append((key, ee))
                            for key, ee in to_evict:
                                (p, s, c) = key
                                _evict(ee.e, s, c, pool, e)
                                self.cache[key].remove(ee)
                                self.seen[(c, p, fp)].remove(ee.e)

                    # Record the accepted expression in `seen` and in the cache
                    # entry before handing it to the caller.
                    _accept(e, size, context, pool)
                    seen_key = (context, pool, fp)
                    if seen_key not in self.seen:
                        self.seen[seen_key] = []
                    self.seen[seen_key].append(e)
                    info = EnumeratedExp(e=e, fingerprint=fp)
                    res.append(info)
                    yield info

                    # Expressions suggested by the heuristics for `e` are
                    # processed before the rest of the queue.
                    with task("accelerating"):
                        to_try = make_random_access(
                            self.heuristics(e, context, pool))
                        if to_try:
                            event("trying {} accelerations of {}".format(
                                len(to_try), pprint(e)))
                            queue = itertools.chain(to_try, queue)

            # NOTE(review): this line is not inside a try/finally, so it is not
            # reached if StopException is raised above or if the consumer
            # abandons the generator early; `k` then stays in `in_progress`.
            self.in_progress.remove(k)
Ejemplo n.º 9
0
    def next(self):
        """Return the next candidate replacement found by the search.

        The returned value is one of the tuples produced by
        `self._possible_replacements`. Expressions are pulled from
        `self.builder_iter`, filtered by cost and by comparison with previously
        seen expressions that share a fingerprint, added to the cache, and then
        checked for possible replacements. When the current size is exhausted
        the size is increased and a new builder iterator is created.

        Raises:
            StopException: when `self.stop_callback()` returns a true value.
            NoMoreImprovements: when `self.last_progress` is less than
                `(self.current_size + 1) // 2` once a size is exhausted.
        """
        target_cost = self.cost_model.cost(self.target, RUNTIME_POOL)
        self.ncount += 1
        while True:
            # A backlog is left behind when an accepted expression produced a
            # replacement on an earlier call. The replacements are recomputed
            # here and the one at index `backlog_counter` is returned, until
            # the list is exhausted.
            if self.backlog is not None:
                if self.stop_callback():
                    raise StopException()
                (e, pool, cost) = self.backlog
                improvements = list(self._possible_replacements(e, pool, cost))
                if self.backlog_counter < len(improvements):
                    i = improvements[self.backlog_counter]
                    self.backlog_counter += 1
                    return i
                else:
                    self.backlog = None
                    self.backlog_counter = 0
            for (e, pool) in self.builder_iter:
                self._on_exp(e, pool)
                if self.stop_callback():
                    raise StopException()

                # # Stopgap measure... long story --Calvin
                # bad = False
                # for x in all_exps(e):
                #     if isinstance(x, EStateVar):
                #         if any(v not in self.state_vars for v in free_vars(x.e)):
                #             bad = True
                #             _on_exp(e, "skipping due to illegal free vars under EStateVar")
                # if bad:
                #     continue

                new_e = self.pre_optimize(e, pool) if preopt.value else e
                if new_e is not e:
                    _on_exp(e, "preoptimized", new_e)
                    e = new_e

                cost = self.cost_model.cost(e, pool)

                # Runtime-pool expressions that cost more than the target are
                # discarded, but only under a monotonic cost model or when
                # hyperaggressive culling is enabled.
                if pool == RUNTIME_POOL and (self.cost_model.is_monotonic() or hyperaggressive_culling.value) and self.compare_costs(cost, target_cost) == Cost.WORSE:
                    _on_exp(e, "too expensive", cost, target_cost)
                    continue

                # Compare against every previously seen expression in this pool
                # that has the same fingerprint.
                fp = self._fingerprint(e)
                prev = list(self.seen.find_all(pool, fp))
                should_add = True
                if not prev:
                    _on_exp(e, "new", pool_name(pool))
                elif any(alpha_equivalent(e, ee) for (ee, _, _) in prev):
                    _on_exp(e, "duplicate")
                    should_add = False
                else:
                    better_than = None
                    worse_than = None
                    for prev_exp, prev_size, prev_cost in prev:
                        self._on_cost_cmp()
                        ordering = self.compare_costs(cost, prev_cost)
                        assert ordering in (Cost.WORSE, Cost.BETTER, Cost.UNORDERED)
                        # Strong progress: reject `e` if it is not worse than
                        # `prev_exp` yet contains a subexpression that is
                        # alpha-equivalent to `prev_exp`.
                        if enforce_strong_progress.value and ordering != Cost.WORSE:
                            bad = find_one(all_exps(e), lambda ee: alpha_equivalent(ee, prev_exp))
                            if bad:
                                _on_exp(e, "failed strong progress requirement", bad)
                                should_add = False
                                break
                        _on_exp(e, ordering, pool_name(pool), prev_exp)
                        if ordering == Cost.UNORDERED:
                            continue
                        elif ordering == Cost.BETTER:
                            # `e` beats `prev_exp`: drop the old expression from
                            # both the cache and the seen set.
                            better_than = (prev_exp, prev_size, prev_cost)
                            _on_exp(prev_exp, "found better alternative", e)
                            self.cache.evict(prev_exp, size=prev_size, pool=pool)
                            self.seen.remove(prev_exp, pool, fp)
                            # Optionally also evict every cached expression in
                            # this pool that contains `prev_exp`. The cache is
                            # copied into a list because it is mutated while
                            # being scanned.
                            if (self.cost_model.is_monotonic() or hyperaggressive_culling.value) and hyperaggressive_eviction.value:
                                for (cached_e, size, p) in list(self.cache):
                                    if p != pool:
                                        continue
                                    if prev_exp in all_exps(cached_e):
                                        _on_exp(cached_e, "evicted since it contains", prev_exp)
                                        self.cache.evict(cached_e, size=size, pool=pool)
                        else:
                            should_add = False
                            worse_than = (prev_exp, prev_size, prev_cost)
                            # break
                    # Being better than one equivalent expression and worse than
                    # another is reported as a diagnostic only; no exception is
                    # raised.
                    if worse_than and better_than:
                        print("Uh-oh! Strange cost relationship between")
                        print("  (1) this exp: {}".format(pprint(e)))
                        print("  (2) prev. A:  {}".format(pprint(worse_than[0])))
                        print("  (2) prev. B:  {}".format(pprint(better_than[0])))
                        print("e1 = {}".format(repr(e)))
                        print("e2 = {}".format(repr(worse_than[0])))
                        print("e3 = {}".format(repr(better_than[0])))
                        print("(1) vs (2): {}".format(cost.compare_to(worse_than[2], self.assumptions)))
                        print("(2) vs (3): {}".format(worse_than[2].compare_to(better_than[2], self.assumptions)))
                        print("(3) vs (1): {}".format(better_than[2].compare_to(cost, self.assumptions)))
                        # raise Exception("insane cost model behavior")

                if should_add:
                    self.cache.add(e, pool=pool, size=self.current_size)
                    self.seen.add(e, pool, fp, self.current_size, cost)
                    self.last_progress = self.current_size
                else:
                    continue

                # Return the first possible replacement, remembering `e` as the
                # backlog so that later calls return the remaining ones
                # (starting at index 1).
                for pr in self._possible_replacements(e, pool, cost):
                    self.backlog = (e, pool, cost)
                    self.backlog_counter = 1
                    return pr

            # The builder is exhausted for the current size. Give up if the last
            # size at which an expression was added is below half the current
            # size, rounded up.
            if self.last_progress < (self.current_size+1) // 2:
                raise NoMoreImprovements("hit termination condition")

            self.current_size += 1
            self.builder_iter = self.builder.build(self.cache, self.current_size)
            # The roots are appended to the builder's output only at size 0.
            if self.current_size == 0:
                self.builder_iter = itertools.chain(self.builder_iter, list(self.roots))
            # Print the tallies in `_fates`, largest count first, then reset them.
            for f, ct in sorted(_fates.items(), key=lambda x: x[1], reverse=True):
                print("  {:6} | {}".format(ct, f))
            _fates.clear()
            self._start_minor_it()
Ejemplo n.º 10
0
def search_for_improvements(targets: [Exp], wf_solver: ModelCachingSolver,
                            context: Context, examples: [{
                                str: object
                            }], cost_model: CostModel,
                            stop_callback: Callable[[], bool], hints: [Exp],
                            ops: [Op], blacklist: {
                                (Exp, Context, Pool, Exp): str
                            }):
    """Search for potential improvements to any of the target expressions.

    This function yields expressions that look like improvements (or are
    ambiguous with respect to some target).  The expressions are only
    guaranteed to be correct on the given examples.

    This function may add new items to the given blacklist.

    The search proceeds in "minor iterations" of increasing size, starting at
    size 0.  Each iteration enumerates expressions of that size in every
    watched (context, pool) pair and yields whatever `_consider_replacement`
    produces for the candidate substitutions it finds.

    Raises:
        StopException: when `stop_callback()` returns a true value, or when
            the enumerator reports that no expressions can exist above the
            current size.
    """

    root_ctx = context

    def check_wf(e, ctx, pool):
        """Return a true value if `e` should be kept, else a falsy result.

        Falsy results come from `exp_wf`, from `possibly_useful`, or are
        `No("too expensive")` when cost pruning is enabled and a runtime-pool
        expression compares as Order.GT against the first target.
        """
        with task("pruning", size=e.size()):
            is_wf = exp_wf(e, pool=pool, context=ctx, solver=wf_solver)
            if not is_wf:
                return is_wf
            res = possibly_useful(wf_solver, e, ctx, pool, ops=ops)
            if not res:
                return res
            if cost_pruning.value and pool == RUNTIME_POOL and cost_model.compare(
                    e, targets[0], ctx, pool) == Order.GT:
                return No("too expensive")
            return True

    # The enumerator's hints are the de-duplicated subexpressions (with their
    # context information) of every target and every hint, sorted by
    # `hint_order`.
    with task("setting up hints"):
        frags = list(
            unique(
                itertools.chain(
                    *[
                        all_subexpressions_with_context_information(
                            t, root_ctx) for t in targets
                    ], *[
                        all_subexpressions_with_context_information(
                            h, root_ctx) for h in hints
                    ])))
        frags.sort(key=hint_order)
        enum = Enumerator(examples=examples,
                          cost_model=cost_model,
                          check_wf=check_wf,
                          hints=frags,
                          heuristics=try_optimize,
                          stop_callback=stop_callback,
                          do_eviction=enable_eviction.value)

    # Only the first target is fingerprinted here.
    target_fp = Fingerprint.of(targets[0], examples)

    with task("setting up watches"):
        # First group the targets' subexpressions by context, so that the
        # examples are instantiated once per context in the loop below.
        watches_by_context = OrderedDict()
        for target in targets:
            for e, ctx, pool in unique(
                    all_subexpressions_with_context_information(
                        target, context=root_ctx, pool=RUNTIME_POOL)):
                l = watches_by_context.get(ctx)
                if l is None:
                    l = []
                    watches_by_context[ctx] = l
                l.append((target, e, pool))

        # watches: (fingerprint, context, pool) -> list of (target, watched
        # subexpression) pairs.
        watches = OrderedDict()
        for ctx, exprs in watches_by_context.items():
            exs = ctx.instantiate_examples(examples)
            for target, e, pool in exprs:
                fp = Fingerprint.of(e, exs)
                k = (fp, ctx, pool)
                l = watches.get(k)
                if l is None:
                    l = []
                    watches[k] = l
                l.append((target, e))

        # The distinct (context, pool) pairs, in the order in which
        # `exploration_order` produces them.
        watched_ctxs = list(
            unique(
                (ctx, pool)
                for _, _, ctx, pool in exploration_order(targets, root_ctx)))

    search_info = SearchInfo(context=root_ctx,
                             targets=targets,
                             target_fingerprint=target_fp,
                             examples=examples,
                             check_wf=check_wf,
                             cost_model=cost_model,
                             blacklist=blacklist)

    size = 0
    while True:

        print("starting minor iteration {} with |cache|={}".format(
            size, enum.cache_size()))
        if stop_callback():
            raise StopException()

        # Pass 1: for each enumerated expression, look for watched
        # subexpressions in the same context and pool whose fingerprint is
        # equal to the expression's.
        for ctx, pool in watched_ctxs:
            with task("searching for obvious substitutions",
                      ctx=ctx,
                      pool=pool_name(pool)):
                for info in enum.enumerate_with_info(size=size,
                                                     context=ctx,
                                                     pool=pool):
                    with task("searching for obvious substitution",
                              expression=pprint(info.e)):
                        fp = info.fingerprint
                        for ((fpx, cc, pp), reses) in watches.items():
                            if cc != ctx or pp != pool:
                                continue

                            if not fpx.equal_to(fp):
                                continue

                            for target, watched_e in reses:
                                replacement = info.e
                                event(
                                    "possible substitution: {} ---> {}".format(
                                        pprint(watched_e),
                                        pprint(replacement)))
                                # Logging only: the target is shown with the
                                # watched expression replaced by a "___"
                                # placeholder variable.
                                event("replacement locations: {}".format(
                                    pprint(
                                        replace(target, root_ctx, RUNTIME_POOL,
                                                watched_e, ctx, pool,
                                                EVar("___")))))

                                if alpha_equivalent(watched_e, replacement):
                                    event("no change")
                                    continue

                                yield from _consider_replacement(
                                    target, watched_e, ctx, pool, replacement,
                                    search_info)

        # Pass 2 (optional): try enumerated expressions as replacements for each
        # explored subexpression without requiring equal fingerprints up
        # front; `should_consider_replacement` filters the candidates instead.
        if check_blind_substitutions.value:
            print("Guessing at substitutions...")
            for target, e, ctx, pool in exploration_order(targets, root_ctx):
                with task("checking substitutions",
                          target=pprint(
                              replace(target, root_ctx, RUNTIME_POOL, e, ctx,
                                      pool, EVar("___"))),
                          e=pprint(e)):
                    for info in enum.enumerate_with_info(size=size,
                                                         context=ctx,
                                                         pool=pool):
                        with task("checking substitution",
                                  expression=pprint(info.e)):
                            if stop_callback():
                                raise StopException()
                            replacement = info.e
                            if replacement.type != e.type:
                                event("wrong type (is {}, need {})".format(
                                    pprint(replacement.type), pprint(e.type)))
                                continue
                            if alpha_equivalent(replacement, e):
                                event("no change")
                                continue
                            should_consider = should_consider_replacement(
                                target, root_ctx, e, ctx, pool,
                                Fingerprint.of(
                                    e, ctx.instantiate_examples(examples)),
                                info.e, info.fingerprint)
                            if not should_consider:
                                event(
                                    "skipped; `should_consider_replacement` returned {}"
                                    .format(should_consider))
                                continue

                            yield from _consider_replacement(
                                target, e, ctx, pool, replacement, search_info)

        # Stop when the enumerator reports that no larger expressions can exist
        # in the root context's runtime pool.
        if not enum.expressions_may_exist_above_size(context, RUNTIME_POOL,
                                                     size):
            raise StopException(
                "no more expressions can exist above size={}".format(size))

        size += 1
Ejemplo n.º 11
0
    def search(self):
        """Search for improved versions of ``self.targets`` (generator).

        Candidate expressions are enumerated by increasing ``size``, starting
        at 0.  In every iteration two strategies are applied:

        1. "Obvious" substitutions: every subexpression of every target is
           watched under the key ``(fingerprint, context, pool)``.  Whenever
           the enumerator produces an expression whose fingerprint matches a
           watched one in the same context and pool, it is tried as a
           replacement for the watched subexpression.
        2. If ``check_all_substitutions.value`` is true, every enumerated
           expression is additionally tried at every location produced by
           ``exploration_order``, subject to a type check, an
           alpha-equivalence check and ``should_consider_replacement``.

        Yields:
            New target expressions that pass ``check_wf``, whose fingerprint
            on ``self.examples`` matches that of ``self.targets[0]``, and
            that the cost model ranks as ``Order.LT`` or ``Order.AMBIGUOUS``
            relative to the target they were derived from.

        Raises:
            StopException: when ``self.stop_callback()`` returns a true value.
        """

        root_ctx = self.context

        def check_wf(e, ctx, pool):
            """Return a truthy value iff `e` is worth keeping in `pool`.

            Falsy results from `exp_wf` / `good_idea_recursive` are returned
            unchanged so that callers can report them.
            """
            with task("checking well-formedness", size=e.size()):
                is_wf = exp_wf(e,
                               pool=pool,
                               context=ctx,
                               solver=self.wf_solver)
                if not is_wf:
                    return is_wf
                res = good_idea_recursive(self.wf_solver,
                                          e,
                                          ctx,
                                          pool,
                                          ops=self.ops)
                if not res:
                    return res
                # Runtime expressions strictly costlier than the first target
                # are pruned.
                if pool == RUNTIME_POOL and self.cost_model.compare(
                        e, self.targets[0], ctx, pool) == Order.GT:
                    return No("too expensive")
                return True

        # Subexpressions of the targets and hints seed the enumerator.
        frags = list(
            unique(
                itertools.chain(*[shred(t, root_ctx) for t in self.targets],
                                *[shred(h, root_ctx) for h in self.hints])))
        frags.sort(key=hint_order)
        enum = Enumerator(examples=self.examples,
                          cost_model=self.cost_model,
                          check_wf=check_wf,
                          hints=frags,
                          heuristics=try_optimize,
                          stop_callback=self.stop_callback,
                          do_eviction=enable_eviction.value)

        size = 0
        target_fp = fingerprint(self.targets[0], self.examples)

        # (fingerprint, ctx, pool) -> [(target, watched subexpression)]
        watches = OrderedDict()
        for target in self.targets:
            for e, ctx, pool in unique(
                    shred(target, context=root_ctx, pool=RUNTIME_POOL)):
                exs = ctx.instantiate_examples(self.examples)
                fp = fingerprint(e, exs)
                watches.setdefault((fp, ctx, pool), []).append((target, e))
        watched_ctxs = list(
            unique((ctx, pool) for fp, ctx, pool in watches.keys()))

        # Index the watches by (ctx, pool) so that the enumeration loop only
        # visits watches that can apply, instead of scanning all of them for
        # every enumerated expression.  Insertion order is preserved.
        watches_by_location = OrderedDict()
        for (fp, ctx, pool), reses in watches.items():
            watches_by_location.setdefault((ctx, pool), []).append((fp, reses))

        def consider_new_target(old_target, e, ctx, pool, replacement):
            """Try replacing `e` with `replacement` inside `old_target`.

            Yields the resulting target if it is new, well-formed, matches
            the target fingerprint and is not costlier than `old_target`.
            Rejected substitutions are recorded in ``self.blacklist``.
            """
            nonlocal n
            n += 1
            k = (e, ctx, pool, replacement)
            if enable_blacklist.value and k in self.blacklist:
                event("blacklisted")
                print("skipping blacklisted substitution: {} ---> {} ({})".
                      format(pprint(e), pprint(replacement),
                             self.blacklist[k]))
                return
            # Use the `old_target` parameter rather than whatever loop
            # variable named `target` happens to be live in the enclosing
            # scope.
            new_target = freshen_binders(
                replace(old_target, root_ctx, RUNTIME_POOL, e, ctx, pool,
                        replacement), root_ctx)
            if any(alpha_equivalent(t, new_target) for t in self.targets):
                event("already seen")
                return
            wf = check_wf(new_target, root_ctx, RUNTIME_POOL)
            if not wf:
                msg = "not well-formed [wf={}]".format(wf)
                event(msg)
                self.blacklist[k] = msg
                return
            if not fingerprints_match(fingerprint(new_target, self.examples),
                                      target_fp):
                msg = "not correct"
                event(msg)
                self.blacklist[k] = msg
                return
            if self.cost_model.compare(new_target, old_target, root_ctx,
                                       RUNTIME_POOL) not in (Order.LT,
                                                             Order.AMBIGUOUS):
                msg = "not an improvement"
                event(msg)
                self.blacklist[k] = msg
                return
            print("FOUND A GUESS AFTER {} CONSIDERED".format(n))
            print(" * in {}".format(pprint(old_target)))
            print(" * replacing {}".format(pprint(e)))
            print(" * with {}".format(pprint(replacement)))
            yield new_target

        while True:

            print("starting minor iteration {} with |cache|={}".format(
                size, enum.cache_size()))
            if self.stop_callback():
                raise StopException()

            # Number of substitutions considered during this iteration;
            # incremented by `consider_new_target`.
            n = 0

            for ctx, pool in watched_ctxs:
                location_watches = watches_by_location.get((ctx, pool), ())
                with task("searching for obvious substitutions",
                          ctx=ctx,
                          pool=pool_name(pool)):
                    for info in enum.enumerate_with_info(size=size,
                                                         context=ctx,
                                                         pool=pool):
                        with task("searching for obvious substitution",
                                  expression=pprint(info.e)):
                            fp = info.fingerprint
                            replacement = info.e
                            for fpx, reses in location_watches:
                                if not fingerprints_match(fpx, fp):
                                    continue

                                for target, watched_e in reses:
                                    event("possible substitution: {} ---> {}".
                                          format(pprint(watched_e),
                                                 pprint(replacement)))
                                    event("replacement locations: {}".format(
                                        pprint(
                                            replace(target, root_ctx,
                                                    RUNTIME_POOL, watched_e,
                                                    ctx, pool, EVar("___")))))

                                    if alpha_equivalent(
                                            watched_e, replacement):
                                        event("no change")
                                        continue

                                    yield from consider_new_target(
                                        target, watched_e, ctx, pool,
                                        replacement)

            if check_all_substitutions.value:
                print("Guessing at substitutions...")
                for target, e, ctx, pool in exploration_order(
                        self.targets, root_ctx):
                    with task("checking substitutions",
                              target=pprint(
                                  replace(target, root_ctx, RUNTIME_POOL, e,
                                          ctx, pool, EVar("___"))),
                              e=pprint(e)):
                        for info in enum.enumerate_with_info(size=size,
                                                             context=ctx,
                                                             pool=pool):
                            with task("checking substitution",
                                      expression=pprint(info.e)):
                                if self.stop_callback():
                                    raise StopException()
                                replacement = info.e
                                if replacement.type != e.type:
                                    event("wrong type (is {}, need {})".format(
                                        pprint(replacement.type),
                                        pprint(e.type)))
                                    continue
                                if alpha_equivalent(replacement, e):
                                    event("no change")
                                    continue
                                should_consider = should_consider_replacement(
                                    target, root_ctx, e, ctx, pool,
                                    fingerprint(
                                        e,
                                        ctx.instantiate_examples(
                                            self.examples)), info.e,
                                    info.fingerprint)
                                if not should_consider:
                                    event(
                                        "skipped; `should_consider_replacement` returned {}"
                                        .format(should_consider))
                                    continue

                                yield from consider_new_target(
                                    target, e, ctx, pool, replacement)

            print("CONSIDERED {}".format(n))
            size += 1
Ejemplo n.º 12
0
Archivo: core.py Proyecto: uwplse/cozy
def search_for_improvements(
        targets       : [Exp],
        wf_solver     : ModelCachingSolver,
        context       : Context,
        examples      : [{str:object}],
        cost_model    : CostModel,
        stop_callback : Callable[[], bool],
        hints         : [Exp],
        ops           : [Op],
        blacklist     : {(Exp, Context, Pool, Exp) : str}):
    """Search for potential improvements to any of the target expressions.

    This function yields expressions that look like improvements (or are
    ambiguous with respect to some target).  The expressions are only
    guaranteed to be correct on the given examples.

    This function may add new items to the given blacklist.

    Candidates are enumerated by increasing size, starting at 0.  In each
    iteration, enumerated expressions whose fingerprint equals that of a
    watched subexpression (same context and pool) are tried as replacements
    for it; if `check_blind_substitutions.value` is true, every enumerated
    expression is additionally tried at every location produced by
    `exploration_order`, subject to a type check, an alpha-equivalence check
    and `should_consider_replacement`.

    Raises:
        StopException: when `stop_callback()` returns a true value, or when
            the enumerator reports that no expressions can exist above the
            current size.
    """

    root_ctx = context
    def check_wf(e, ctx, pool):
        """Return a truthy value iff `e` survives pruning in `pool`.

        Falsy results from `exp_wf` / `possibly_useful` are returned
        unchanged so that callers can report them.
        """
        with task("pruning", size=e.size()):
            is_wf = exp_wf(e, pool=pool, context=ctx, solver=wf_solver)
            if not is_wf:
                return is_wf
            res = possibly_useful(wf_solver, e, ctx, pool, ops=ops)
            if not res:
                return res
            # Optionally prune runtime expressions strictly costlier than
            # the first target.
            if cost_pruning.value and pool == RUNTIME_POOL and cost_model.compare(e, targets[0], ctx, pool) == Order.GT:
                return No("too expensive")
            return True

    with task("setting up hints"):
        # Subexpressions of the targets and hints seed the enumerator.
        frags = list(unique(itertools.chain(
            *[all_subexpressions_with_context_information(t, root_ctx) for t in targets],
            *[all_subexpressions_with_context_information(h, root_ctx) for h in hints])))
        frags.sort(key=hint_order)
        enum = Enumerator(
            examples=examples,
            cost_model=cost_model,
            check_wf=check_wf,
            hints=frags,
            heuristics=try_optimize,
            stop_callback=stop_callback,
            do_eviction=enable_eviction.value)

    target_fp = Fingerprint.of(targets[0], examples)

    with task("setting up watches"):
        # Group by context first so that the examples are instantiated once
        # per context rather than once per subexpression.
        watches_by_context = OrderedDict()
        for target in targets:
            for e, ctx, pool in unique(all_subexpressions_with_context_information(target, context=root_ctx, pool=RUNTIME_POOL)):
                watches_by_context.setdefault(ctx, []).append((target, e, pool))

        # (fingerprint, ctx, pool) -> [(target, watched subexpression)]
        watches = OrderedDict()
        for ctx, exprs in watches_by_context.items():
            exs = ctx.instantiate_examples(examples)
            for target, e, pool in exprs:
                fp = Fingerprint.of(e, exs)
                watches.setdefault((fp, ctx, pool), []).append((target, e))

        # Index the watches by (ctx, pool) so that the enumeration loop only
        # visits watches that can apply, instead of scanning all of them for
        # every enumerated expression.  Insertion order is preserved.
        watches_by_location = OrderedDict()
        for (fp, ctx, pool), reses in watches.items():
            watches_by_location.setdefault((ctx, pool), []).append((fp, reses))

        watched_ctxs = list(unique((ctx, pool) for _, _, ctx, pool in exploration_order(targets, root_ctx)))

    search_info = SearchInfo(
        context=root_ctx,
        targets=targets,
        target_fingerprint=target_fp,
        examples=examples,
        check_wf=check_wf,
        cost_model=cost_model,
        blacklist=blacklist)

    size = 0
    while True:

        print("starting minor iteration {} with |cache|={}".format(size, enum.cache_size()))
        if stop_callback():
            raise StopException()

        for ctx, pool in watched_ctxs:
            # May be empty; the enumeration below still runs in that case,
            # exactly as it did when all watches were scanned.
            location_watches = watches_by_location.get((ctx, pool), ())
            with task("searching for obvious substitutions", ctx=ctx, pool=pool_name(pool)):
                for info in enum.enumerate_with_info(size=size, context=ctx, pool=pool):
                    with task("searching for obvious substitution", expression=pprint(info.e)):
                        fp = info.fingerprint
                        replacement = info.e
                        for fpx, reses in location_watches:
                            if not fpx.equal_to(fp):
                                continue

                            for target, watched_e in reses:
                                event("possible substitution: {} ---> {}".format(pprint(watched_e), pprint(replacement)))
                                event("replacement locations: {}".format(pprint(replace(target, root_ctx, RUNTIME_POOL, watched_e, ctx, pool, EVar("___")))))

                                if alpha_equivalent(watched_e, replacement):
                                    event("no change")
                                    continue

                                yield from _consider_replacement(target, watched_e, ctx, pool, replacement, search_info)

        if check_blind_substitutions.value:
            print("Guessing at substitutions...")
            for target, e, ctx, pool in exploration_order(targets, root_ctx):
                with task("checking substitutions",
                        target=pprint(replace(target, root_ctx, RUNTIME_POOL, e, ctx, pool, EVar("___"))),
                        e=pprint(e)):
                    for info in enum.enumerate_with_info(size=size, context=ctx, pool=pool):
                        with task("checking substitution", expression=pprint(info.e)):
                            if stop_callback():
                                raise StopException()
                            replacement = info.e
                            if replacement.type != e.type:
                                event("wrong type (is {}, need {})".format(pprint(replacement.type), pprint(e.type)))
                                continue
                            if alpha_equivalent(replacement, e):
                                event("no change")
                                continue
                            should_consider = should_consider_replacement(
                                target, root_ctx,
                                e, ctx, pool, Fingerprint.of(e, ctx.instantiate_examples(examples)),
                                info.e, info.fingerprint)
                            if not should_consider:
                                event("skipped; `should_consider_replacement` returned {}".format(should_consider))
                                continue

                            yield from _consider_replacement(target, e, ctx, pool, replacement, search_info)

        if not enum.expressions_may_exist_above_size(context, RUNTIME_POOL, size):
            raise StopException("no more expressions can exist above size={}".format(size))

        size += 1