Beispiel #1
0
def exploration_order(targets: [Exp],
                      context: Context,
                      pool: Pool = RUNTIME_POOL):
    """
    Decide the order in which subexpressions of `targets` are explored for
    possible improvements.

    Targets are processed in the order given. For each target, the
    (subexpression, subcontext, subpool) triples produced by
    `all_subexpressions_with_context_information` (deduplicated with
    `unique`) are visited in this priority order:
      1. runtime-pool subexpressions before all others,
      2. then lower `complexity()` of the subcontext,
      3. then smaller `size()` of the subexpression.

    Yields (target, subexpression, subcontext, subpool) tuples.
    """

    def priority(entry):
        sub, sub_ctx, sub_pool = entry
        pool_rank = 0 if sub_pool == RUNTIME_POOL else 1
        return (pool_rank, sub_ctx.complexity(), sub.size())

    for target in targets:
        candidates = unique(all_subexpressions_with_context_information(
            target, context, pool=pool))
        ordered = sorted(candidates, key=priority)
        for sub, sub_ctx, sub_pool in ordered:
            yield (target, sub, sub_ctx, sub_pool)
Beispiel #2
0
Datei: wf.py Projekt: uwplse/cozy
def exp_wf(e : Exp, context : Context, pool = RUNTIME_POOL, assumptions : Exp = ETRUE, solver = None):
    """Check that `e` is well-formed.

    Every subexpression of `e` is checked with `exp_wf_nonrecursive`. The
    result is True when all of them pass. Otherwise the first failure is
    reported: a `No` verdict is wrapped in an ExpIsNotWf recording `e` and
    the offending subexpression, and any other falsy verdict is returned
    unchanged.

    Parameters:
        e - the expression to check
        context - a context describing e's variables
        pool - the pool e lives in
        assumptions - facts that hold whenever e begins executing
            (NOTE: these need not include the path conditions from the
            context, though including them is harmless.) They are adapted
            to each subexpression's context via `ctx.adapt` before use.
        solver - a ModelCachingSolver used to decide formulas; a fresh one
            is created when omitted

    This function requires that:
     - all free variables in `e` are used in the correct pool
     - EStateVar only occurs in runtime expressions
    """
    checker = solver if solver is not None else ModelCachingSolver(vars=[], funcs={})
    for sub, sub_ctx, sub_pool in all_subexpressions_with_context_information(e, context, pool):
        local_assumptions = sub_ctx.adapt(assumptions, context)
        verdict = exp_wf_nonrecursive(checker, sub, sub_ctx, sub_pool, assumptions=local_assumptions)
        if verdict:
            continue
        return ExpIsNotWf(e, sub, verdict.msg) if isinstance(verdict, No) else verdict
    return True
Beispiel #3
0
def possibly_useful(solver, e : Exp, context : Context, pool = RUNTIME_POOL, assumptions : Exp = ETRUE, ops : [Op] = ()) -> bool:
    """Check `possibly_useful_nonrecursive` on every subexpression of `e`.

    Subexpressions are visited in the order produced by
    `all_subexpressions_with_context_information`. The first falsy verdict
    is returned as-is, so callers can inspect the reason. If every
    subexpression passes, the result is True.

    NOTE(review): unlike `exp_wf`, `assumptions` is passed through unchanged
    rather than adapted to each subexpression's context with `ctx.adapt`.
    Confirm that `possibly_useful_nonrecursive` expects root-context
    assumptions.
    """
    subexpressions = all_subexpressions_with_context_information(e, context, pool)
    for sub, sub_ctx, sub_pool in subexpressions:
        verdict = possibly_useful_nonrecursive(
            solver, sub, sub_ctx, sub_pool, assumptions=assumptions, ops=ops)
        if not verdict:
            return verdict
    return True
Beispiel #4
0
    def test_pool_affects_alpha_equivalence(self):
        """The same subexpression gets non-alpha-equivalent contexts when the
        enclosing expression is explored from the runtime pool vs. the state
        pool."""
        e = EMap(EEmptyList().with_type(INT_BAG), ELambda(x, ONE))
        root_ctx = RootCtx(args=(), state_vars=())
        assert retypecheck(e)

        def contexts_of_one(start_pool):
            # Collect the context of every occurrence of ONE when `e` is
            # explored starting from `start_pool`.
            return [sub_ctx
                    for sub, sub_ctx, _ in all_subexpressions_with_context_information(e, root_ctx, start_pool)
                    if sub == ONE]

        runtime_ctxs = contexts_of_one(RUNTIME_POOL)
        assert len(runtime_ctxs) == 1
        state_ctxs = contexts_of_one(STATE_POOL)
        assert len(state_ctxs) == 1

        runtime_ctx, state_ctx = runtime_ctxs[0], state_ctxs[0]
        assert runtime_ctx != state_ctx
        assert not runtime_ctx.alpha_equivalent(state_ctx)
Beispiel #5
0
 def test_let(self):
     """Replacing the bound variable inside ELet's lambda body with ZERO
     rewrites exactly that one occurrence."""
     e1 = ELet(ZERO, ELambda(x, x))
     root_ctx = RootCtx(args=(), state_vars=())
     assert retypecheck(e1)
     hits = 0
     for sub, sub_ctx, sub_pool in all_subexpressions_with_context_information(e1, root_ctx, RUNTIME_POOL):
         if sub == x:
             # Swap this occurrence of `x` (in its own context/pool) for ZERO.
             rewritten = replace(e1, root_ctx, RUNTIME_POOL, x, sub_ctx, sub_pool, ZERO)
             assert rewritten == ELet(ZERO, ELambda(x, ZERO))
             hits += 1
     assert hits == 1
Beispiel #6
0
    def test_estatevar_ctx(self):
        """A state variable wrapped in EStateVar inside a lambda is reported
        with a RootCtx, and can be replaced using that context."""
        xs = EVar("xs").with_type(INT_BAG)
        x = EVar("x").with_type(INT)
        y = EVar("y").with_type(BOOL)
        e = EMap(xs, ELambda(x, EStateVar(y)))
        root_ctx = RootCtx(args=(xs,), state_vars=(y,))
        assert retypecheck(e)

        # Use distinct names for the loop's context so the root context is
        # not overwritten. Rebinding `ctx` here would make the `replace`
        # call below depend on whichever context the traversal yields last.
        y_ctxs = [sub_ctx
                  for sub, sub_ctx, _ in all_subexpressions_with_context_information(e, root_ctx)
                  if sub == y]
        # Require exactly one visit so the check below cannot pass vacuously.
        assert len(y_ctxs) == 1
        # `y` is nested under ELambda(x, ...), but because it is wrapped in
        # EStateVar this test expects it to be analysed in a RootCtx.
        assert isinstance(y_ctxs[0], RootCtx)

        e = replace(
            e, root_ctx, RUNTIME_POOL,
            y, y_ctxs[0], STATE_POOL,
            ETRUE)

        assert e == EMap(xs, ELambda(x, EStateVar(ETRUE))), pprint(e)
Beispiel #7
0
def possibly_useful(solver,
                    e: Exp,
                    context: Context,
                    pool=RUNTIME_POOL,
                    assumptions: Exp = ETRUE,
                    ops: [Op] = ()) -> bool:
    """Check `possibly_useful_nonrecursive` on every subexpression of `e`.

    Returns the first falsy verdict, unchanged, so the caller can see why
    the check failed. Returns True if every subexpression passes.
    Evaluation is lazy and stops at the first failure.

    NOTE(review): `assumptions` is passed unchanged to each subexpression's
    check rather than being adapted to that subexpression's context.
    Confirm this is what `possibly_useful_nonrecursive` expects.
    """
    verdicts = (
        possibly_useful_nonrecursive(solver,
                                     sub,
                                     sub_ctx,
                                     sub_pool,
                                     assumptions=assumptions,
                                     ops=ops)
        for sub, sub_ctx, sub_pool in all_subexpressions_with_context_information(
            e, context, pool))
    # next(..., True): the first falsy verdict (even if it is None/False),
    # or True when the generator is exhausted without a failure.
    return next((verdict for verdict in verdicts if not verdict), True)
Beispiel #8
0
def exploration_order(targets : [Exp], context : Context, pool : Pool = RUNTIME_POOL):
    """
    Produce the exploration order for subexpressions of `targets`.

    Each target is handled in turn. Its distinct (subexpression,
    subcontext, subpool) triples are sorted so that runtime-pool entries
    come first, then entries whose subcontext has lower `complexity()`,
    then smaller subexpressions. The sort is stable, so ties keep their
    traversal order.

    Yields (target, subexpression, subcontext, subpool) tuples.
    """

    def rank(triple):
        sub, sub_ctx, sub_pool = triple
        return (0 if sub_pool == RUNTIME_POOL else 1,
                sub_ctx.complexity(),
                sub.size())

    for target in targets:
        triples = unique(all_subexpressions_with_context_information(target, context, pool=pool))
        yield from ((target, sub, sub_ctx, sub_pool)
                    for sub, sub_ctx, sub_pool in sorted(triples, key=rank))
Beispiel #9
0
def fold_into_map(e, context):
    """Yield candidate rewrites of `e` that fold it into a state-pool map.

    For every subexpression of `e` (explored from the runtime pool) of the
    form `EStateVar(m)[k]`, i.e. an EMapGet whose map is an EStateVar:
      1. build `f = lambda arg: e[EStateVar(m)[k] := arg]`, which is `e`
         with that occurrence of the lookup replaced by a fresh variable;
      2. if every free variable of `f` is a state variable of `context`,
         yield `EStateVar(map_values(m, strip_EStateVar(f)))[k]`, typed as
         `e.type`.

    Parameters:
        e - the runtime expression to rewrite
        context - the context describing e's variables

    NOTE(review): the yielded expressions are candidates only; this
    function does not itself check that they are equivalent to `e`.
    """
    fvs = free_vars(e)
    state_vars = [v for v, p in context.vars() if p == STATE_POOL]
    for subexp, subcontext, subpool in all_subexpressions_with_context_information(e, context, RUNTIME_POOL):
        if not (isinstance(subexp, EMapGet) and isinstance(subexp.map, EStateVar)):
            continue
        # e is of the form `... EStateVar(state_map)[key] ...`
        state_map = subexp.map.e
        key = subexp.key
        # Abstract `e` over this particular occurrence of the lookup.
        arg = fresh_var(subexp.type, omit=fvs)
        func = ELambda(arg, replace(
            e, context, RUNTIME_POOL,
            subexp, subcontext, subpool,
            arg))
        # The function ends up inside an EStateVar (via new_map below), so
        # it may only mention state variables.
        if not all(v in state_vars for v in free_vars(func)):
            continue
        func = strip_EStateVar(func)
        new_map = map_values(state_map, func.apply_to)
        yield EMapGet(EStateVar(new_map).with_type(new_map.type), key).with_type(e.type)
Beispiel #10
0
def fold_into_map(e, context):
    """Yield candidate rewrites of `e` that fold it into a state-pool map.

    For every subexpression of `e` (explored from the runtime pool) of the
    form `EStateVar(m)[k]`, i.e. an EMapGet whose map is an EStateVar:
      1. build `f = lambda arg: e[EStateVar(m)[k] := arg]`, which is `e`
         with that occurrence of the lookup replaced by a fresh variable;
      2. if every free variable of `f` is a state variable of `context`,
         yield `EStateVar(map_values(m, strip_EStateVar(f)))[k]`, typed as
         `e.type`.

    Parameters:
        e - the runtime expression to rewrite
        context - the context describing e's variables

    NOTE(review): the yielded expressions are candidates only; this
    function does not itself check that they are equivalent to `e`.
    """
    fvs = free_vars(e)
    state_vars = [v for v, p in context.vars() if p == STATE_POOL]
    for subexp, subcontext, subpool in all_subexpressions_with_context_information(e, context, RUNTIME_POOL):
        if not (isinstance(subexp, EMapGet) and isinstance(subexp.map, EStateVar)):
            continue
        # e is of the form `... EStateVar(state_map)[key] ...`
        state_map = subexp.map.e
        key = subexp.key
        # Abstract `e` over this particular occurrence of the lookup.
        arg = fresh_var(subexp.type, omit=fvs)
        func = ELambda(arg, replace(
            e, context, RUNTIME_POOL,
            subexp, subcontext, subpool,
            arg))
        # The function ends up inside an EStateVar (via new_map below), so
        # it may only mention state variables.
        if not all(v in state_vars for v in free_vars(func)):
            continue
        func = strip_EStateVar(func)
        new_map = map_values(state_map, func.apply_to)
        yield EMapGet(EStateVar(new_map).with_type(new_map.type), key).with_type(e.type)
Beispiel #11
0
def exp_wf(e: Exp,
           context: Context,
           pool=RUNTIME_POOL,
           assumptions: Exp = ETRUE,
           solver=None):
    """Check that `e` is well-formed.

    Returns True when every subexpression passes `exp_wf_nonrecursive`.
    Otherwise returns the first failure: a `No` verdict becomes an
    ExpIsNotWf(e, offending_subexpression, message), and any other falsy
    verdict is returned unchanged.

    Parameters:
        e - the expression to check
        context - a context describing e's variables
        pool - the pool e lives in
        assumptions - facts that hold whenever e begins executing
            (NOTE: these need not include the path conditions from the
            context, though including them is harmless.) Each
            subexpression sees them adapted to its own context.
        solver - a ModelCachingSolver for deciding formulas; a fresh one is
            created when omitted

    This function requires that:
     - all free variables in `e` are used in the correct pool
     - EStateVar only occurs in runtime expressions
    """
    wf_solver = ModelCachingSolver(vars=[], funcs={}) if solver is None else solver
    subexpressions = all_subexpressions_with_context_information(
        e, context, pool)
    for sub, sub_ctx, sub_pool in subexpressions:
        verdict = exp_wf_nonrecursive(
            wf_solver, sub, sub_ctx, sub_pool,
            assumptions=sub_ctx.adapt(assumptions, context))
        if verdict:
            continue
        if isinstance(verdict, No):
            return ExpIsNotWf(e, sub, verdict.msg)
        return verdict
    return True
Beispiel #12
0
 def test_all_subexpressions_with_context_information_minheap(self):
     """Smoke test: traversing an EMakeMinHeap expression must not raise."""
     key_func = ELambda(x, x)
     heap = EMakeMinHeap(EEmptyList().with_type(INT_BAG), key_func).with_type(TMinHeap(INT, key_func))
     root_ctx = RootCtx(args=(), state_vars=())
     # Drain the generator completely; only the absence of errors matters.
     for _ in all_subexpressions_with_context_information(heap, root_ctx):
         pass
Beispiel #13
0
def search_for_improvements(targets: [Exp], wf_solver: ModelCachingSolver,
                            context: Context, examples: [{
                                str: object
                            }], cost_model: CostModel,
                            stop_callback: Callable[[], bool], hints: [Exp],
                            ops: [Op], blacklist: {
                                (Exp, Context, Pool, Exp): str
                            }):
    """Search for potential improvements to any of the target expressions.

    This function yields expressions that look like improvements (or are
    ambiguous with respect to some target).  The expressions are only
    guaranteed to be correct on the given examples.

    This function may add new items to the given blacklist.

    Parameters:
        targets - expressions to improve; targets[0] is used for the
            example fingerprint and as the baseline for cost pruning
        wf_solver - solver passed to `exp_wf` and `possibly_useful`
        context - root context the targets are explored in (runtime pool)
        examples - example inputs used to fingerprint expressions
        cost_model - used for cost pruning and handed to the Enumerator
            and SearchInfo
        stop_callback - polled during the search; when it returns a true
            value, StopException is raised
        hints - extra expressions whose subexpressions (together with the
            targets') are passed to the Enumerator as hints
        ops - forwarded to `possibly_useful`
        blacklist - stored in the SearchInfo handed to
            `_consider_replacement`

    The loop only ends by raising StopException: either `stop_callback()`
    returns true, or the enumerator reports that no expressions can exist
    above the current size.
    """

    root_ctx = context

    # Candidate filter for the enumerator. Returns True, or a falsy value
    # explaining why the candidate was rejected: not well-formed, rejected
    # by `possibly_useful`, or (when cost pruning is enabled) a runtime
    # candidate that compares as more expensive than targets[0].
    def check_wf(e, ctx, pool):
        with task("pruning", size=e.size()):
            is_wf = exp_wf(e, pool=pool, context=ctx, solver=wf_solver)
            if not is_wf:
                return is_wf
            res = possibly_useful(wf_solver, e, ctx, pool, ops=ops)
            if not res:
                return res
            if cost_pruning.value and pool == RUNTIME_POOL and cost_model.compare(
                    e, targets[0], ctx, pool) == Order.GT:
                return No("too expensive")
            return True

    # Hints are every (subexpression, ctx, pool) triple of the targets and
    # of the user-supplied hints, deduplicated and sorted by `hint_order`.
    with task("setting up hints"):
        frags = list(
            unique(
                itertools.chain(
                    *[
                        all_subexpressions_with_context_information(
                            t, root_ctx) for t in targets
                    ], *[
                        all_subexpressions_with_context_information(
                            h, root_ctx) for h in hints
                    ])))
        frags.sort(key=hint_order)
        enum = Enumerator(examples=examples,
                          cost_model=cost_model,
                          check_wf=check_wf,
                          hints=frags,
                          heuristics=try_optimize,
                          stop_callback=stop_callback,
                          do_eviction=enable_eviction.value)

    target_fp = Fingerprint.of(targets[0], examples)

    with task("setting up watches"):
        # Group each target's subexpressions by context so that examples
        # are instantiated only once per context below.
        watches_by_context = OrderedDict()
        for target in targets:
            for e, ctx, pool in unique(
                    all_subexpressions_with_context_information(
                        target, context=root_ctx, pool=RUNTIME_POOL)):
                l = watches_by_context.get(ctx)
                if l is None:
                    l = []
                    watches_by_context[ctx] = l
                l.append((target, e, pool))

        # watches: (fingerprint on the context's examples, ctx, pool)
        #          -> [(target, subexpression), ...]
        watches = OrderedDict()
        for ctx, exprs in watches_by_context.items():
            exs = ctx.instantiate_examples(examples)
            for target, e, pool in exprs:
                fp = Fingerprint.of(e, exs)
                k = (fp, ctx, pool)
                l = watches.get(k)
                if l is None:
                    l = []
                    watches[k] = l
                l.append((target, e))

        # Every distinct (ctx, pool) pair, in exploration order.
        watched_ctxs = list(
            unique(
                (ctx, pool)
                for _, _, ctx, pool in exploration_order(targets, root_ctx)))

    search_info = SearchInfo(context=root_ctx,
                             targets=targets,
                             target_fingerprint=target_fp,
                             examples=examples,
                             check_wf=check_wf,
                             cost_model=cost_model,
                             blacklist=blacklist)

    # Each minor iteration enumerates candidates at the current `size`,
    # then increments it.
    size = 0
    while True:

        print("starting minor iteration {} with |cache|={}".format(
            size, enum.cache_size()))
        if stop_callback():
            raise StopException()

        # Obvious substitutions: a candidate whose fingerprint equals that of
        # a watched subexpression in the same (ctx, pool) is tried as a
        # replacement for it.
        for ctx, pool in watched_ctxs:
            with task("searching for obvious substitutions",
                      ctx=ctx,
                      pool=pool_name(pool)):
                for info in enum.enumerate_with_info(size=size,
                                                     context=ctx,
                                                     pool=pool):
                    with task("searching for obvious substitution",
                              expression=pprint(info.e)):
                        fp = info.fingerprint
                        # NOTE(review): this scans every watch for every
                        # candidate and filters by (ctx, pool); grouping the
                        # watches by (ctx, pool) up front would avoid the
                        # linear scan.
                        for ((fpx, cc, pp), reses) in watches.items():
                            if cc != ctx or pp != pool:
                                continue

                            if not fpx.equal_to(fp):
                                continue

                            for target, watched_e in reses:
                                replacement = info.e
                                event(
                                    "possible substitution: {} ---> {}".format(
                                        pprint(watched_e),
                                        pprint(replacement)))
                                # Log the target with the watched occurrence
                                # replaced by the placeholder variable `___`.
                                event("replacement locations: {}".format(
                                    pprint(
                                        replace(target, root_ctx, RUNTIME_POOL,
                                                watched_e, ctx, pool,
                                                EVar("___")))))

                                if alpha_equivalent(watched_e, replacement):
                                    event("no change")
                                    continue

                                yield from _consider_replacement(
                                    target, watched_e, ctx, pool, replacement,
                                    search_info)

        # Blind substitutions (opt-in): try every type-correct candidate at
        # every explored subexpression, filtered by
        # `should_consider_replacement`.
        if check_blind_substitutions.value:
            print("Guessing at substitutions...")
            for target, e, ctx, pool in exploration_order(targets, root_ctx):
                with task("checking substitutions",
                          target=pprint(
                              replace(target, root_ctx, RUNTIME_POOL, e, ctx,
                                      pool, EVar("___"))),
                          e=pprint(e)):
                    for info in enum.enumerate_with_info(size=size,
                                                         context=ctx,
                                                         pool=pool):
                        with task("checking substitution",
                                  expression=pprint(info.e)):
                            if stop_callback():
                                raise StopException()
                            replacement = info.e
                            if replacement.type != e.type:
                                event("wrong type (is {}, need {})".format(
                                    pprint(replacement.type), pprint(e.type)))
                                continue
                            if alpha_equivalent(replacement, e):
                                event("no change")
                                continue
                            # NOTE(review): the fingerprint of `e` depends
                            # only on (e, ctx), yet it is recomputed for every
                            # candidate; it could be hoisted out of this loop.
                            should_consider = should_consider_replacement(
                                target, root_ctx, e, ctx, pool,
                                Fingerprint.of(
                                    e, ctx.instantiate_examples(examples)),
                                info.e, info.fingerprint)
                            if not should_consider:
                                event(
                                    "skipped; `should_consider_replacement` returned {}"
                                    .format(should_consider))
                                continue

                            yield from _consider_replacement(
                                target, e, ctx, pool, replacement, search_info)

        # Stop once the enumerator says no larger runtime expressions exist
        # in the root context.
        if not enum.expressions_may_exist_above_size(context, RUNTIME_POOL,
                                                     size):
            raise StopException(
                "no more expressions can exist above size={}".format(size))

        size += 1
Beispiel #14
0
def search_for_improvements(
        targets       : [Exp],
        wf_solver     : ModelCachingSolver,
        context       : Context,
        examples      : [{str:object}],
        cost_model    : CostModel,
        stop_callback : Callable[[], bool],
        hints         : [Exp],
        ops           : [Op],
        blacklist     : {(Exp, Context, Pool, Exp) : str}):
    """Search for potential improvements to any of the target expressions.

    This function yields expressions that look like improvements (or are
    ambiguous with respect to some target).  The expressions are only
    guaranteed to be correct on the given examples.

    This function may add new items to the given blacklist.

    Parameters:
        targets - expressions to improve; targets[0] is used for the
            example fingerprint and as the baseline for cost pruning
        wf_solver - solver passed to `exp_wf` and `possibly_useful`
        context - root context the targets are explored in (runtime pool)
        examples - example inputs used to fingerprint expressions
        cost_model - used for cost pruning and handed to the Enumerator
            and SearchInfo
        stop_callback - polled during the search; when it returns a true
            value, StopException is raised
        hints - extra expressions whose subexpressions (together with the
            targets') are passed to the Enumerator as hints
        ops - forwarded to `possibly_useful`
        blacklist - stored in the SearchInfo handed to
            `_consider_replacement`

    The loop only ends by raising StopException: either `stop_callback()`
    returns true, or the enumerator reports that no expressions can exist
    above the current size.
    """

    root_ctx = context
    # Candidate filter for the enumerator. Returns True, or a falsy value
    # explaining why the candidate was rejected: not well-formed, rejected
    # by `possibly_useful`, or (when cost pruning is enabled) a runtime
    # candidate that compares as more expensive than targets[0].
    def check_wf(e, ctx, pool):
        with task("pruning", size=e.size()):
            is_wf = exp_wf(e, pool=pool, context=ctx, solver=wf_solver)
            if not is_wf:
                return is_wf
            res = possibly_useful(wf_solver, e, ctx, pool, ops=ops)
            if not res:
                return res
            if cost_pruning.value and pool == RUNTIME_POOL and cost_model.compare(e, targets[0], ctx, pool) == Order.GT:
                return No("too expensive")
            return True

    # Hints are every (subexpression, ctx, pool) triple of the targets and
    # of the user-supplied hints, deduplicated and sorted by `hint_order`.
    with task("setting up hints"):
        frags = list(unique(itertools.chain(
            *[all_subexpressions_with_context_information(t, root_ctx) for t in targets],
            *[all_subexpressions_with_context_information(h, root_ctx) for h in hints])))
        frags.sort(key=hint_order)
        enum = Enumerator(
            examples=examples,
            cost_model=cost_model,
            check_wf=check_wf,
            hints=frags,
            heuristics=try_optimize,
            stop_callback=stop_callback,
            do_eviction=enable_eviction.value)

    target_fp = Fingerprint.of(targets[0], examples)

    with task("setting up watches"):
        # Group each target's subexpressions by context so that examples
        # are instantiated only once per context below.
        watches_by_context = OrderedDict()
        for target in targets:
            for e, ctx, pool in unique(all_subexpressions_with_context_information(target, context=root_ctx, pool=RUNTIME_POOL)):
                l = watches_by_context.get(ctx)
                if l is None:
                    l = []
                    watches_by_context[ctx] = l
                l.append((target, e, pool))

        # watches: (fingerprint on the context's examples, ctx, pool)
        #          -> [(target, subexpression), ...]
        watches = OrderedDict()
        for ctx, exprs in watches_by_context.items():
            exs = ctx.instantiate_examples(examples)
            for target, e, pool in exprs:
                fp = Fingerprint.of(e, exs)
                k = (fp, ctx, pool)
                l = watches.get(k)
                if l is None:
                    l = []
                    watches[k] = l
                l.append((target, e))

        # Every distinct (ctx, pool) pair, in exploration order.
        watched_ctxs = list(unique((ctx, pool) for _, _, ctx, pool in exploration_order(targets, root_ctx)))

    search_info = SearchInfo(
        context=root_ctx,
        targets=targets,
        target_fingerprint=target_fp,
        examples=examples,
        check_wf=check_wf,
        cost_model=cost_model,
        blacklist=blacklist)

    # Each minor iteration enumerates candidates at the current `size`,
    # then increments it.
    size = 0
    while True:

        print("starting minor iteration {} with |cache|={}".format(size, enum.cache_size()))
        if stop_callback():
            raise StopException()

        # Obvious substitutions: a candidate whose fingerprint equals that of
        # a watched subexpression in the same (ctx, pool) is tried as a
        # replacement for it.
        for ctx, pool in watched_ctxs:
            with task("searching for obvious substitutions", ctx=ctx, pool=pool_name(pool)):
                for info in enum.enumerate_with_info(size=size, context=ctx, pool=pool):
                    with task("searching for obvious substitution", expression=pprint(info.e)):
                        fp = info.fingerprint
                        # NOTE(review): this scans every watch for every
                        # candidate and filters by (ctx, pool); grouping the
                        # watches by (ctx, pool) up front would avoid the
                        # linear scan.
                        for ((fpx, cc, pp), reses) in watches.items():
                            if cc != ctx or pp != pool:
                                continue

                            if not fpx.equal_to(fp):
                                continue

                            for target, watched_e in reses:
                                replacement = info.e
                                event("possible substitution: {} ---> {}".format(pprint(watched_e), pprint(replacement)))
                                # Log the target with the watched occurrence
                                # replaced by the placeholder variable `___`.
                                event("replacement locations: {}".format(pprint(replace(target, root_ctx, RUNTIME_POOL, watched_e, ctx, pool, EVar("___")))))

                                if alpha_equivalent(watched_e, replacement):
                                    event("no change")
                                    continue

                                yield from _consider_replacement(target, watched_e, ctx, pool, replacement, search_info)

        # Blind substitutions (opt-in): try every type-correct candidate at
        # every explored subexpression, filtered by
        # `should_consider_replacement`.
        if check_blind_substitutions.value:
            print("Guessing at substitutions...")
            for target, e, ctx, pool in exploration_order(targets, root_ctx):
                with task("checking substitutions",
                        target=pprint(replace(target, root_ctx, RUNTIME_POOL, e, ctx, pool, EVar("___"))),
                        e=pprint(e)):
                    for info in enum.enumerate_with_info(size=size, context=ctx, pool=pool):
                        with task("checking substitution", expression=pprint(info.e)):
                            if stop_callback():
                                raise StopException()
                            replacement = info.e
                            if replacement.type != e.type:
                                event("wrong type (is {}, need {})".format(pprint(replacement.type), pprint(e.type)))
                                continue
                            if alpha_equivalent(replacement, e):
                                event("no change")
                                continue
                            # NOTE(review): the fingerprint of `e` depends
                            # only on (e, ctx), yet it is recomputed for every
                            # candidate; it could be hoisted out of this loop.
                            should_consider = should_consider_replacement(
                                target, root_ctx,
                                e, ctx, pool, Fingerprint.of(e, ctx.instantiate_examples(examples)),
                                info.e, info.fingerprint)
                            if not should_consider:
                                event("skipped; `should_consider_replacement` returned {}".format(should_consider))
                                continue

                            yield from _consider_replacement(target, e, ctx, pool, replacement, search_info)

        # Stop once the enumerator says no larger runtime expressions exist
        # in the root context.
        if not enum.expressions_may_exist_above_size(context, RUNTIME_POOL, size):
            raise StopException("no more expressions can exist above size={}".format(size))

        size += 1