Ejemplo n.º 1
0
 def watch(self, new_target):
     """Point this object at `new_target` and rebuild its roots and watches.

     Resets the backlog counter and records the target, then collects
     "root" expressions from every subexpression of the target and of each
     hint in `self.hints` (lambdas are skipped).  A candidate becomes a
     root in a pool when all of its free variables are in
     `self.legal_free_vars` and `self.is_legal_in_pool` accepts it.  A
     value built by `construct_value` is added for every root type and
     every binder type.  `self.roots` ends up as a list sorted by
     expression size, and `self._watches` groups the fragments of the
     target by (pool, type), largest fragments first.
     """
     print("watching new target...")
     self.backlog_counter = 0
     self.target = new_target
     self.roots = OrderedSet()
     seen_types = OrderedSet()

     sources = [all_exps(new_target)] + [all_exps(h) for h in self.hints]
     for candidate in itertools.chain.from_iterable(sources):
         if isinstance(candidate, ELambda):
             continue
         for pool in ALL_POOLS:
             # EStateVar wrappers are stripped before trying the state pool.
             exp = strip_EStateVar(candidate) if pool == STATE_POOL else candidate
             fvs = free_vars(exp)
             legal = (
                 all(v in self.legal_free_vars for v in fvs)
                 and self.is_legal_in_pool(exp, pool))
             if not legal:
                 _on_exp(exp, "rejected root", pool_name(pool))
                 continue
             _on_exp(exp, "new root", pool_name(pool))
             exp._root = True
             self.roots.add((exp, pool))
             if pool == STATE_POOL and all(v in self.state_vars for v in fvs):
                 # A pure-state root is also offered to the runtime pool,
                 # wrapped in EStateVar.
                 wrapped = EStateVar(exp).with_type(exp.type)
                 self.roots.add((wrapped, RUNTIME_POOL))
             seen_types.add(exp.type)

     for binder in self.binders:
         seen_types.add(binder.type)
     for t in seen_types:
         self.roots.add((construct_value(t), RUNTIME_POOL))
     self.roots = sorted(self.roots, key=lambda root: root[0].size())

     def largest_first(ctxs):
         return sorted(ctxs, key=lambda ctx: -ctx.e.size())

     self._watches = group_by(
         enumerate_fragments2(new_target),
         k=lambda ctx: (ctx.pool, ctx.e.type),
         v=largest_first)
     print("done!")
Ejemplo n.º 2
0
def optimized_sum(xs, args):
    """Yield candidate expressions for the sum of the collection `xs`.

    Which rewritten forms are produced depends on the shape of `xs`; the
    plain `sum_of(xs)` is always yielded last.
    """
    elem_type = xs.type.elem_type

    # State collection: take the sum of the stripped expression and wrap
    # the result in EStateVar.
    if isinstance(xs, EStateVar):
        inner_sum = sum_of(strip_EStateVar(xs))
        yield EStateVar(inner_sum).with_type(elem_type)

    # sum(l + r): combine every sum of l with every sum of r.
    if isinstance(xs, EBinOp) and xs.op == "+":
        for lhs in optimized_sum(xs.e1, args=args):
            for rhs in optimized_sum(xs.e2, args=args):
                yield EBinOp(lhs, "+", rhs).with_type(elem_type)

    # sum(l - r): the subtrahend is first filtered with a predicate built
    # from `optimized_in(arg, l)` before being summed and subtracted.
    if isinstance(xs, EBinOp) and xs.op == "-":
        arg = fresh_var(elem_type)
        for lhs in optimized_sum(xs.e1, args=args):
            membership = ELambda(arg, optimized_in(arg, xs.e1))
            for kept in _simple_filter(xs.e2, membership, args):
                for rhs in optimized_sum(kept, args=args):
                    yield EBinOp(lhs, "-", rhs).with_type(elem_type)

    # `excluded_element` may split xs into (bag, element); if so, subtract
    # that element from each sum of the bag.
    excluded = excluded_element(xs, args)
    if excluded is not None:
        bag, removed = excluded
        for bag_sum in optimized_sum(bag, args):
            yield EBinOp(bag_sum, "-", removed).with_type(removed.type)

    if isinstance(xs, ESingleton):
        yield xs.e

    # Flat-map whose body is `a + b`: flat-map each side separately and sum
    # the concatenation of the results.
    if isinstance(xs, EFlatMap):
        fn = xs.transform_function
        if isinstance(fn.body, EBinOp) and fn.body.op == "+":
            for left in optimized_flatmap(xs.e, ELambda(fn.arg, fn.body.e1), args):
                for right in optimized_flatmap(xs.e, ELambda(fn.arg, fn.body.e2), args):
                    combined = EBinOp(left, "+", right).with_type(left.type)
                    yield from optimized_sum(combined, args)

    yield sum_of(xs)
Ejemplo n.º 3
0
 def test_mutate_preserves_statevar(self):
     """Mutating `EStateVar(x) + 1` under `x = x + 1` and repairing it
     must give `(EStateVar(x) + 1) + 1`."""
     x = EVar("x").with_type(INT)
     original = EBinOp(EStateVar(x), "+", ONE)
     assert retypecheck(original)
     increment = SAssign(x, EBinOp(x, "+", ONE).with_type(INT))
     mutated = inc.mutate(original, increment)
     repaired = repair_well_formedness(
         strip_EStateVar(mutated),
         context=RootCtx(state_vars=[x], args=[]))
     for shown in (original, repaired):
         print(pprint(shown))
     expected = EBinOp(EBinOp(EStateVar(x), "+", ONE), "+", ONE)
     assert repaired == expected
Ejemplo n.º 4
0
 def test_mutate_preserves_statevar(self):
     """Check that the assignment `x = x + 1` applied to `EStateVar(x) + 1`,
     followed by well-formedness repair, yields `(EStateVar(x) + 1) + 1`."""
     x = EVar("x").with_type(INT)
     before = EBinOp(EStateVar(x), "+", ONE)
     assert retypecheck(before)
     bump = SAssign(x, EBinOp(x, "+", ONE).with_type(INT))
     after = strip_EStateVar(inc.mutate(before, bump))
     root_context = RootCtx(state_vars=[x], args=[])
     after = repair_well_formedness(after, context=root_context)
     print(pprint(before))
     print(pprint(after))
     inner = EBinOp(EStateVar(x), "+", ONE)
     assert after == EBinOp(inner, "+", ONE)
Ejemplo n.º 5
0
def repair_EStateVar(e: syntax.Exp,
                     available_state: [syntax.Exp]) -> syntax.Exp:
    """Re-insert EStateVar wrappers into `e` based on `available_state`.

    Existing EStateVar wrappers are stripped first.  The stripped
    expression is then rewritten by a BottomUpRewriter whose `visit_Exp`
    wraps any node that is alpha-equivalent to a member of
    `available_state` in an EStateVar of the same type; every other node is
    delegated to the rewriter's `visit_ADT`.
    """

    class StateVarWrapper(BottomUpRewriter):
        def visit_EStateVar(self, node):
            # An existing wrapper is returned unchanged.
            return node

        def visit_Exp(self, node):
            for state_exp in available_state:
                if alpha_equivalent(node, state_exp):
                    return target_syntax.EStateVar(node).with_type(node.type)
            return super().visit_ADT(node)

    stripped = strip_EStateVar(e)
    return StateVarWrapper().visit(stripped)
Ejemplo n.º 6
0
 def make_subgoal(e, a=[], docstring=None):
     """Register `e` as an internal query and return a call to that query.

     `e` is stripped of EStateVar wrappers first.  When
     `skip_stateless_synthesis` is enabled and `e` mentions no variable
     from the enclosing `ctx`, the stripped `e` is returned directly and no
     query is created.  Otherwise a fresh internal query is appended to the
     enclosing `subgoals` list, using the enclosing `assumptions` plus `a`;
     its arguments are the query's free variables that are not in `ctx`.
     """
     e = strip_EStateVar(e)
     if skip_stateless_synthesis.value:
         touches_ctx = any(v in ctx for v in free_vars(e))
         if not touches_ctx:
             return e
     query_name = fresh_name("query")
     query = syntax.Query(
         query_name,
         syntax.Visibility.Internal,
         [],
         assumptions + a,
         e,
         docstring)
     query_vars = []
     for v in free_vars(query):
         if v not in ctx:
             query_vars.append(v)
     query.args = [(qv.id, qv.type) for qv in query_vars]
     subgoals.append(query)
     call = syntax.ECall(query_name, tuple(query_vars))
     return call.with_type(e.type)
Ejemplo n.º 7
0
Archivo: core.py Proyecto: uwplse/cozy
def can_elim_vars(spec : Exp, assumptions : Exp, vs : [EVar]):
    """Ask the solver whether the value of `spec` is independent of `vs`.

    Each variable in `vs` is mapped to a fresh variable of the same type.
    The result is the solver's verdict on the claim "if `assumptions` hold
    both before and after that renaming, then `spec` (with EStateVar
    wrappers stripped) equals its renamed copy".
    """
    stripped = strip_EStateVar(spec)
    renaming = {}
    for v in vs:
        renaming[v.id] = fresh_var(v.type)
    premise = EAll([assumptions, subst(assumptions, renaming)])
    conclusion = EEq(stripped, subst(stripped, renaming))
    return valid(EImplies(premise, conclusion))
Ejemplo n.º 8
0
def simplify_sum(e):
    """Cancel opposite terms in the sum `e` and rebuild the remainder.

    `break_sum(e)` produces (flag, term) pairs.  Terms with a truthy flag
    are kept as they are; terms with a falsy flag are negated in the
    result.  A truthy-flag term is dropped together with one falsy-flag
    term when the two are alpha-equivalent after stripping EStateVar.
    Returns ZERO if nothing is left, otherwise a left-nested chain of INT
    additions.
    """
    pairs = list(break_sum(e))
    flagged, unflagged = partition(pairs, lambda p: p[0])
    positives = [pair[1] for pair in flagged]
    negatives = [pair[1] for pair in unflagged]

    kept = []
    for pos in positives:
        match = find_one(
            negatives,
            lambda cand: alpha_equivalent(strip_EStateVar(pos), strip_EStateVar(cand)))
        if match:
            negatives.remove(match)
        else:
            kept.append(pos)
    for leftover in negatives:
        kept.append(EUnaryOp("-", leftover).with_type(INT))

    if not kept:
        return ZERO
    total = kept[0]
    for part in kept[1:]:
        total = EBinOp(total, "+", part).with_type(INT)
    return total
Ejemplo n.º 9
0
def can_elim_vars(spec: Exp, assumptions: Exp, vs: [EVar]):
    """Return whether the solver proves `spec` does not depend on `vs`.

    Every variable in `vs` is replaced by a fresh variable of the same
    type.  The formula checked for validity is: `assumptions` and the
    renamed `assumptions` together imply that `spec` (stripped of EStateVar
    wrappers) is equal to its renamed copy.
    """
    bare_spec = strip_EStateVar(spec)
    fresh_for = {}
    for var in vs:
        fresh_for[var.id] = fresh_var(var.type)
    both_assumptions = EAll([assumptions, subst(assumptions, fresh_for)])
    same_value = EEq(bare_spec, subst(bare_spec, fresh_for))
    return valid(EImplies(both_assumptions, same_value))
Ejemplo n.º 10
0
def extract_listcomp(e):
    """
    Recover the parts of a list comprehension from its desugared form.

    :param e: desugared list comprehension expression
    :return: a dict (an "lc" structure) or None when `e` does not have the
             expected shape.  Keys:
             "P" maps each bound variable to the (EStateVar-stripped)
             collection it is pulled from, "C" is the condition, and "V" is
             the returned value, set only when an EMap is present.
             In written form:
             { V(p0, ..., pn) | p0 <- P_0, ..., pn <- P_n, C(p0, ..., pn) }.
             V and C mention the variables p0..pn freely.
    """
    if isinstance(e, EFlatMap):
        source = e.e
        fn = e.transform_function
        lc = extract_listcomp(fn.body)
        if lc is None:
            return None
        lc["P"][fn.arg] = strip_EStateVar(source)
        return lc

    if isinstance(e, EMap):
        value = e.transform_function.body
        lc = extract_listcomp(e.e)
        if lc is None:
            return None
        lc["V"] = value
        return lc

    if isinstance(e, EFilter):
        condition = e.predicate.body
        pulled = {e.predicate.arg: strip_EStateVar(e.e)}
        return {"C": condition, "P": pulled}

    return None
Ejemplo n.º 11
0
def _simple_filter(xs, p, args):
    """Build an expression for `xs` filtered by the predicate `p`.

    A series of special cases is tried in order; the first that applies
    wins.  When none applies, a plain `EFilter(xs, p)` is returned.
    `args` is the collection of variables treated as arguments: rewrites
    that move the filter inside an EStateVar are only used when `p`
    mentions none of them.
    """
    body = p.body

    # Constant predicates and an empty input need no filter at all.
    if body == T:
        return xs
    if body == F:
        return EEmptyList().with_type(xs.type)
    if isinstance(xs, EEmptyList):
        return xs

    def independent_of_args():
        return all(v not in args for v in free_vars(p))

    # Filter over state with an argument-free predicate: filter inside the
    # EStateVar.
    if isinstance(xs, EStateVar) and independent_of_args():
        filtered = EFilter(xs.e, strip_EStateVar(p)).with_type(xs.type)
        return EStateVar(filtered).with_type(xs.type)

    # Lookup in a state map: filter every value of the map instead.
    if isinstance(xs, EMapGet) and isinstance(xs.map, EStateVar) and independent_of_args():
        m = map_values(xs.map.e, lambda ys: _simple_filter(ys, p, args))
        lookup = EMapGet(EStateVar(m).with_type(m.type), xs.key)
        return lookup.with_type(xs.type)

    # Filter distributes over concatenation.
    if isinstance(xs, EBinOp) and xs.op == "+":
        left = _simple_filter(xs.e1, p, args)
        right = _simple_filter(xs.e2, p, args)
        return EBinOp(left, "+", right).with_type(xs.type)

    if isinstance(body, EBinOp) and body.op == "==":
        lhs = body.e1
        rhs = body.e2
        rhs_fvs = free_vars(rhs)

        # Predicate `arg == rhs` with rhs independent of arg: the result
        # is either a singleton of rhs or the empty list.
        if lhs == p.arg and p.arg not in rhs_fvs:
            return optimized_cond(
                optimized_in(rhs, xs),
                ESingleton(rhs).with_type(xs.type),
                EEmptyList().with_type(xs.type)).with_type(xs.type)

        # Predicate `lhs(arg) == rhs` over state: build a state map keyed
        # by lhs and look rhs up in it.
        lhs_fvs = free_vars(lhs)
        indexable = (
            p.arg in lhs_fvs
            and all(a not in lhs_fvs for a in args)
            and p.arg not in rhs_fvs
            and isinstance(xs, EStateVar))
        if indexable:
            k = fresh_var(lhs.type)
            keys = EMap(xs.e, ELambda(p.arg, lhs))
            bucket = ELambda(k, EFilter(xs.e, ELambda(p.arg, EEq(lhs, k))))
            e = EMapGet(EStateVar(EMakeMap2(keys, bucket)), rhs)
            res = retypecheck(e)
            assert res
            return e

    return EFilter(xs, p).with_type(xs.type)
Ejemplo n.º 12
0
def fold_into_map(e, context):
    """Yield rewrites of `e` that fold its computation into a state map.

    For every runtime subexpression of `e` of the form
    `EStateVar(m)[key]`, the rest of `e` is abstracted into a lambda over
    the result of that lookup.  When every free variable of the lambda is
    a state variable of `context`, the lambda (stripped of EStateVar) is
    applied to each value of `m` via `map_values`, and a lookup of `key`
    in the resulting state map is yielded with the type of `e`.

    Parameters:
        e       - the expression to rewrite
        context - the context of `e`; its STATE_POOL variables are the
                  state variables
    """
    fvs = free_vars(e)
    state_vars = [v for v, p in context.vars() if p == STATE_POOL]
    for subexp, subcontext, subpool in all_subexpressions_with_context_information(e, context, RUNTIME_POOL):
        if not (isinstance(subexp, EMapGet) and isinstance(subexp.map, EStateVar)):
            continue
        # Named `state_map` so the builtin `map` is not shadowed.
        state_map = subexp.map.e
        key = subexp.key
        # e is of the form `... EStateVar(state_map)[key] ...`
        arg = fresh_var(subexp.type, omit=fvs)
        func = ELambda(arg, replace(
            e, context, RUNTIME_POOL,
            subexp, subcontext, subpool,
            arg))
        if not all(v in state_vars for v in free_vars(func)):
            continue
        func = strip_EStateVar(func)
        new_map = map_values(state_map, func.apply_to)
        yield EMapGet(EStateVar(new_map).with_type(new_map.type), key).with_type(e.type)
Ejemplo n.º 13
0
def fold_into_map(e, context):
    """Yield versions of `e` whose work is pushed into the values of a state map.

    Each runtime subexpression `EStateVar(m)[key]` inside `e` is replaced
    by a fresh variable, turning `e` into a lambda over the lookup result.
    If that lambda only has state variables of `context` free, it is
    stripped of EStateVar and applied to every value of `m` with
    `map_values`; the generator then yields `EStateVar(new_map)[key]`
    typed like `e`.

    Parameters:
        e       - the expression to rewrite
        context - the context of `e`; variables it places in STATE_POOL
                  count as state variables
    """
    fvs = free_vars(e)
    state_vars = [v for v, p in context.vars() if p == STATE_POOL]
    for subexp, subcontext, subpool in all_subexpressions_with_context_information(e, context, RUNTIME_POOL):
        is_state_lookup = isinstance(subexp, EMapGet) and isinstance(subexp.map, EStateVar)
        if not is_state_lookup:
            continue
        # Avoid the name `map`, which would shadow the builtin.
        underlying_map = subexp.map.e
        key = subexp.key
        # e is of the form `... EStateVar(underlying_map)[key] ...`
        arg = fresh_var(subexp.type, omit=fvs)
        func = ELambda(arg, replace(
            e, context, RUNTIME_POOL,
            subexp, subcontext, subpool,
            arg))
        if not all(v in state_vars for v in free_vars(func)):
            continue
        func = strip_EStateVar(func)
        new_map = map_values(underlying_map, func.apply_to)
        yield EMapGet(EStateVar(new_map).with_type(new_map.type), key).with_type(e.type)
Ejemplo n.º 14
0
def repair_well_formedness(e: Exp,
                           context: Context,
                           extra_available_state: [Exp] = []) -> Exp:
    """Rebuild the EStateVar nodes of an expression that is not well-formed.

    Parameters:
        e                     - the expression to repair
        context               - the intended context for e
        extra_available_state - extra state expressions that e can use

    Provided every expression in extra_available_state is a well-formed
    state expression, the result is a well-formed runtime expression that
    behaves like `e`.  The result is checked with `exp_wf` before it is
    returned.
    """

    def outside_state_var(x):
        return not isinstance(x, EStateVar)

    with task("repairing"):

        repaired = strip_EStateVar(e)

        # Candidate state expressions, largest first.
        state_vars = (v for v, pool in context.vars() if pool == STATE_POOL)
        candidates = unique(itertools.chain(state_vars, extra_available_state))
        available_state = sorted(candidates, key=lambda s: -s.size())

        with task("making replacements", size=repaired.size()):
            for s in available_state:
                wrapped = EStateVar(s).with_type(s.type)
                repaired = replace(repaired,
                                   s,
                                   wrapped,
                                   match=alpha_equivalent,
                                   filter=outside_state_var)

        with task("freshening binders"):
            repaired = freshen_binders(repaired, context)

        with task("checking correctness"):
            res = exp_wf(repaired, context, RUNTIME_POOL)
            assert res, str(res)

        return repaired
Ejemplo n.º 15
0
def map_accelerate(e, state_vars, binders, args, cache, size):
    """Yield (expression, pool) pairs that replace part of `e` by a map lookup.

    For each fragment `arg` of `e` (stripped of EStateVar) that mentions no
    state variable, and each unbound binder of the same type, the fragment
    is replaced by the binder.  If the result only mentions state variables
    and binders, a map from each element of a cached state collection to
    that result is built.  Maps with no binder left free are yielded for
    STATE_POOL, followed by the lookup of `arg` in the map for
    RUNTIME_POOL.
    """
    for (_, arg, f, bound) in enumerate_fragments(strip_EStateVar(e)):
        if any(v in state_vars for v in free_vars(arg)):
            continue
        for binder in binders:
            usable = binder.type == arg.type and binder not in bound
            if not usable:
                continue
            value = f(binder)
            if not all(v in state_vars or v in binders for v in free_vars(value)):
                continue
            for bag in cache.find_collections(pool=STATE_POOL,
                                              size=size,
                                              of=arg.type):
                if isinstance(bag, EEmptyList):
                    continue
                key_to_value = ELambda(binder, value)
                m = EMakeMap2(bag, key_to_value).with_type(TMap(arg.type, e.type))
                assert not any(v in args for v in free_vars(m))
                if any(v in binders for v in free_vars(m)):
                    continue
                yield (m, STATE_POOL)
                lookup = EMapGet(EStateVar(m).with_type(m.type), arg)
                yield (lookup.with_type(e.type), RUNTIME_POOL)
Ejemplo n.º 16
0
Archivo: wf.py Proyecto: uwplse/cozy
def repair_well_formedness(e : Exp, context : Context, extra_available_state : [Exp] = []) -> Exp:
    """Fix up the EStateVar nodes of an expression that is not well-formed.

    Parameters:
        e                     - the expression to repair
        context               - the intended context for e
        extra_available_state - extra state expressions that e can use

    If all expressions in extra_available_state are well-formed state
    expressions, the output is a well-formed runtime expression that
    behaves like `e`.  Well-formedness of the output is asserted via
    `exp_wf`.
    """

    with task("repairing"):

        result = strip_EStateVar(e)

        # State expressions ordered by decreasing size.
        context_state = (v for v, p in context.vars() if p == STATE_POOL)
        all_state = itertools.chain(context_state, extra_available_state)
        available_state = sorted(unique(all_state), key=lambda st: -st.size())

        with task("making replacements", size=result.size()):
            for state_exp in available_state:
                replacement = EStateVar(state_exp).with_type(state_exp.type)
                result = replace(
                    result, state_exp, replacement,
                    match=alpha_equivalent,
                    filter=lambda node: not isinstance(node, EStateVar))

        with task("freshening binders"):
            result = freshen_binders(result, context)

        with task("checking correctness"):
            res = exp_wf(result, context, RUNTIME_POOL)
            assert res, str(res)

        return result
Ejemplo n.º 17
0
def sketch_update(
        lval: syntax.Exp,
        old_value: syntax.Exp,
        new_value: syntax.Exp,
        ctx: [syntax.EVar],
        assumptions: [syntax.Exp] = [],
        invariants: [syntax.Exp] = []) -> (syntax.Stm, [syntax.Query]):
    """
    Write code to update `lval` when it changes from `old_value` to `new_value`.
    Variables in `ctx` are assumed to be part of the data structure abstract
    state, and `assumptions` will be appended to all generated subgoals.

    This function returns a statement (code to update `lval`) and a list of
    subgoals (new queries that appear in the code).

    The shape of the statement depends on the type of `lval`:
      - bags and sets: remove the deleted elements, then add the new ones
      - numbers (when `update_numbers_with_deltas` is enabled): add a delta
      - tuples and records: update each component recursively
      - maps: delete removed keys, then update new or modified keys
      - anything else: assign a freshly computed value

    If the solver proves that `old_value` equals `new_value` under the
    assumptions and invariants, the result is a no-op with no subgoals.
    """

    # Nothing to do when the value provably does not change.
    if valid(
            syntax.EImplies(
                syntax.EAll(itertools.chain(assumptions, invariants)),
                syntax.EEq(old_value, new_value))):
        return (syntax.SNoOp(), [])

    subgoals = []
    new_value = strip_EStateVar(new_value)

    def make_subgoal(e, a=[], docstring=None):
        # Turn `e` into an internal query and return a call to it.  With
        # `skip_stateless_synthesis` enabled, an expression that mentions no
        # variable from `ctx` is returned as-is instead.
        if skip_stateless_synthesis.value and not any(v in ctx
                                                      for v in free_vars(e)):
            return e
        query_name = fresh_name("query")
        query = syntax.Query(query_name, syntax.Visibility.Internal, [],
                             assumptions + a, e, docstring)
        # The query's arguments are its free variables that are not in ctx.
        query_vars = [v for v in free_vars(query) if v not in ctx]
        query.args = [(arg.id, arg.type) for arg in query_vars]
        subgoals.append(query)
        return syntax.ECall(query_name, tuple(query_vars)).with_type(e.type)

    def recurse(*args, **kwargs):
        # Recursive call that merges the nested subgoals into ours and
        # returns only the statement.
        (code, sgs) = sketch_update(*args, **kwargs)
        subgoals.extend(sgs)
        return code

    t = lval.type
    if isinstance(t, syntax.TBag) or isinstance(t, syntax.TSet):
        # Collections: compute what was added and what was deleted, then
        # apply the deletions before the additions.
        to_add = make_subgoal(syntax.EBinOp(new_value, "-",
                                            old_value).with_type(t),
                              docstring="additions to {}".format(pprint(lval)))
        to_del = make_subgoal(
            syntax.EBinOp(old_value, "-", new_value).with_type(t),
            docstring="deletions from {}".format(pprint(lval)))
        v = fresh_var(t.t)
        stm = syntax.seq([
            syntax.SForEach(v, to_del, syntax.SCall(lval, "remove", [v])),
            syntax.SForEach(v, to_add, syntax.SCall(lval, "add", [v]))
        ])
    elif is_numeric(t) and update_numbers_with_deltas.value:
        # Numbers: add the difference between the new and old values.
        change = make_subgoal(syntax.EBinOp(new_value, "-",
                                            old_value).with_type(t),
                              docstring="delta for {}".format(pprint(lval)))
        stm = syntax.SAssign(lval,
                             syntax.EBinOp(lval, "+", change).with_type(t))
    elif isinstance(t, syntax.TTuple):
        # Tuples: update every component independently.
        get = lambda val, i: syntax.ETupleGet(val, i).with_type(t.ts[i])
        stm = syntax.seq([
            recurse(get(lval, i),
                    get(old_value, i),
                    get(new_value, i),
                    ctx,
                    assumptions,
                    invariants=invariants) for i in range(len(t.ts))
        ])
    elif isinstance(t, syntax.TRecord):
        # Records: update every field independently; t.fields[i] is a
        # (name, type) pair.
        get = lambda val, i: syntax.EGetField(val, t.fields[i][0]).with_type(
            t.fields[i][1])
        stm = syntax.seq([
            recurse(get(lval, i),
                    get(old_value, i),
                    get(new_value, i),
                    ctx,
                    assumptions,
                    invariants=invariants) for i in range(len(t.fields))
        ])
    elif isinstance(t, syntax.TMap):
        k = fresh_var(lval.type.k)
        v = fresh_var(lval.type.v)
        key_bag = syntax.TBag(lval.type.k)

        old_keys = target_syntax.EMapKeys(old_value).with_type(key_bag)
        new_keys = target_syntax.EMapKeys(new_value).with_type(key_bag)

        # (1) exit set
        # Keys present before but not after are deleted from the map.
        deleted_keys = syntax.EBinOp(old_keys, "-",
                                     new_keys).with_type(key_bag)
        s1 = syntax.SForEach(
            k,
            make_subgoal(deleted_keys,
                         docstring="keys removed from {}".format(
                             pprint(lval))), target_syntax.SMapDel(lval, k))

        # (2) enter/mod set
        # Keys that are new, or whose value differs between old and new.
        new_or_modified = target_syntax.EFilter(
            new_keys,
            syntax.ELambda(
                k,
                syntax.EAny([
                    syntax.ENot(syntax.EIn(k, old_keys)),
                    syntax.ENot(
                        syntax.EEq(value_at(old_value, k),
                                   value_at(new_value, k)))
                ]))).with_type(key_bag)
        # The per-key update is sketched recursively, assuming that `k` is
        # one of those keys and that `v` holds the old value at `k`.
        update_value = recurse(v,
                               value_at(old_value, k),
                               value_at(new_value, k),
                               ctx=ctx,
                               assumptions=assumptions + [
                                   syntax.EIn(k, new_or_modified),
                                   syntax.EEq(v, value_at(old_value, k))
                               ],
                               invariants=invariants)
        s2 = syntax.SForEach(
            k,
            make_subgoal(new_or_modified,
                         docstring="new or modified keys from {}".format(
                             pprint(lval))),
            target_syntax.SMapUpdate(lval, k, v, update_value))

        stm = syntax.SSeq(s1, s2)
    else:
        # Fallback rule: just compute a new value from scratch
        stm = syntax.SAssign(
            lval,
            make_subgoal(new_value,
                         docstring="new value for {}".format(pprint(lval))))

    return (stm, subgoals)
Ejemplo n.º 18
0
    def build(self, cache, size):
        """Yield candidate expressions of the given size.

        Handwritten rewrites are applied to expressions already in `cache`
        (looked up at `size - 1`, or at pairs of sizes that sum to
        `size - 1`).  Each candidate is passed through `self.check` with
        the pool it is intended for.  Afterwards, everything produced by
        `self.wrapped.build(cache, size)` is yielded as well.
        """

        # Simplify integer sums; only yield when simplification changed
        # something.
        for e in cache.find(pool=RUNTIME_POOL, size=size - 1, type=INT):
            if not is_root(e):
                continue
            e2 = simplify_sum(e)
            if e != e2:
                yield self.check(e2, RUNTIME_POOL)

        # for e in cache.find(pool=RUNTIME_POOL, size=size-1):
        #     if isinstance(e, EMapGet) and isinstance(e.map, EMakeMap2):
        #         x = e.map.value.apply_to(e.key)
        #         x._tag = True
        #         yield self.check(x, RUNTIME_POOL)

        # [x] - ys
        # Rewritten as: if x in ys then [] else [x].
        for e in cache.find_collections(pool=RUNTIME_POOL, size=size - 1):
            if not is_root(e):
                continue
            if isinstance(e, EBinOp) and e.op == "-" and isinstance(
                    e.e1, ESingleton):
                x = e.e1.e
                y = e.e2
                x = ECond(
                    EBinOp(x, BOp.In, y).with_type(BOOL),
                    EEmptyList().with_type(e.type), e.e1).with_type(e.type)
                yield self.check(x, RUNTIME_POOL)
            elif isinstance(e, EUnaryOp) and e.op == UOp.Distinct:
                # distinct(xs): build a state map from each element of xs to
                # T and use its keys.
                e = strip_EStateVar(e)
                m = EMakeMap2(e.e, mk_lambda(e.type.t, lambda x: T)).with_type(
                    TMap(e.type.t, BOOL))
                yield self.check(m, STATE_POOL)
                m = EStateVar(m).with_type(m.type)
                yield self.check(m, RUNTIME_POOL)
                x = EMapKeys(m).with_type(e.type)
                # x._tag = True
                yield self.check(x, RUNTIME_POOL)

        # # x in ys ----> (count x in ys) > 0
        # for e in cache.find(pool=RUNTIME_POOL, type=BOOL, size=size-1):
        #     if isinstance(e, EBinOp) and e.op == BOp.In:
        #         for b in self.binders:
        #             if b.type != e.e1.type:
        #                 continue
        #             x = EGt(
        #                 EUnaryOp(UOp.Length, EFilter(e.e2, ELambda(b, EEq(e.e1, b))).with_type(e.e2.type)).with_type(INT),
        #                 ZERO)
        #             x._tag = True
        #             yield self.check(x, RUNTIME_POOL)

        # argmin/argmax over a concatenation `l + r`: take the arg-min/max of
        # each side separately (inside EStateVar when that side is state) and
        # combine the two results, guarding on whether each side is non-empty.
        for e in cache.find(pool=RUNTIME_POOL, size=size - 1):
            if not is_root(e):
                continue
            if (isinstance(e, EArgMin) or isinstance(
                    e, EArgMax)) and isinstance(e.e, EBinOp) and e.e.op == "+":
                l = e.e.e1
                r = e.e.e2
                op = e.e.op
                # Rebuilds the same kind of node (EArgMin or EArgMax) over x.
                f = lambda x: type(e)(x, e.f).with_type(e.type)
                ll = EStateVar(f(l.e)).with_type(e.type) if isinstance(
                    l, EStateVar) else f(l)
                rr = EStateVar(f(r.e)).with_type(e.type) if isinstance(
                    r, EStateVar) else f(r)
                x = ECond(
                    EUnaryOp(UOp.Exists, l).with_type(BOOL),
                    ECond(
                        EUnaryOp(UOp.Exists, r).with_type(BOOL),
                        f(
                            EBinOp(
                                ESingleton(ll).with_type(e.e.type), op,
                                ESingleton(rr).with_type(e.e.type)).with_type(
                                    e.e.type)), ll).with_type(e.type),
                    rr).with_type(e.type)
                # from cozy.solver import valid
                # assert valid(EEq(e, x), model_callback=print)
                x._tag = True
                yield self.check(x, RUNTIME_POOL)

        # is-last(x, l)
        # Builds a state map from each element of e2 to its number of
        # occurrences in e2, then tests whether the count for e1 exceeds ONE.
        # NOTE(review): `b` is only used to filter on type; the yielded
        # expression does not mention it, so the same candidate is yielded
        # once per binder of matching type -- confirm this is intended.
        for (sz1, sz2) in pick_to_sum(2, size - 1):
            for e1 in cache.find(pool=RUNTIME_POOL, size=sz1):
                if not is_root(e1):
                    continue
                for e2 in cache.find_collections(pool=STATE_POOL,
                                                 size=sz2,
                                                 of=e1.type):
                    if not is_root(e2):
                        continue
                    for b in self.binders:
                        if b.type != e1.type:
                            continue
                        m = EMakeMap2(
                            e2,
                            mk_lambda(
                                e2.type.t, lambda x: EUnaryOp(
                                    UOp.Length,
                                    EFilter(
                                        e2,
                                        mk_lambda(e2.type.t, lambda y: EEq(
                                            x, y))).with_type(e2.type)).
                                with_type(INT))).with_type(TMap(
                                    e2.type.t, INT))
                        # filt = EFilter(e2, ELambda(b, EEq(e1, b))).with_type(e2.type)
                        # x = EEq(
                        #     EUnaryOp(UOp.Length, filt).with_type(INT),
                        #     ONE)
                        x = EGt(
                            EMapGet(EStateVar(m).with_type(m.type),
                                    e1).with_type(INT), ONE)
                        # x._tag = True
                        yield self.check(x, RUNTIME_POOL)

        # histogram
        # for e in cache.find_collections(pool=STATE_POOL, size=size-1):
        #     m = EMakeMap2(e,
        #         mk_lambda(e.type.t, lambda x:
        #             EUnaryOp(UOp.Length, EFilter(e,
        #                 mk_lambda(e.type.t, lambda y: EEq(x, y))).with_type(e.type)).with_type(INT))).with_type(TMap(e.type.t, INT))
        #     m._tag = True
        #     yield self.check(m, STATE_POOL)

        # Fixup EFilter(\x -> ECond...)
        # A filter whose predicate contains a conditional is split into two
        # filters (condition holds / condition fails) and their union.
        for e in cache.find_collections(pool=RUNTIME_POOL, size=size - 1):
            if not is_root(e):
                continue
            if isinstance(e, EFilter):
                for (_, x, r, _) in enumerate_fragments(e.p.body):
                    if isinstance(x, ECond):
                        lhs = EFilter(
                            e.e,
                            ELambda(e.p.arg, EAll([x.cond,
                                                   r(x.then_branch)
                                                   ]))).with_type(e.type)
                        rhs = EFilter(
                            e.e,
                            ELambda(e.p.arg,
                                    EAll([ENot(x.cond),
                                          r(x.else_branch)
                                          ]))).with_type(e.type)
                        union = EBinOp(lhs, "+", rhs).with_type(e.type)
                        # yield self.check(lhs.p.body, RUNTIME_POOL)
                        # yield self.check(rhs.p.body, RUNTIME_POOL)
                        yield self.check(lhs, RUNTIME_POOL)
                        yield self.check(rhs, RUNTIME_POOL)
                        yield self.check(union, RUNTIME_POOL)

        # Try instantiating bound expressions
        # Each free variable of e1 is substituted by cached expressions of
        # the same type from the same pool.
        # NOTE(review): in the runtime pool, e1 is rebound inside the loop
        # over its own free variables, so later iterations start from the
        # already-rewritten expression -- confirm this is intended.
        for pool in (STATE_POOL, RUNTIME_POOL):
            for (sz1, sz2) in pick_to_sum(2, size - 1):
                for e1 in cache.find(pool=pool, size=sz1):
                    if not is_root(e1):
                        continue
                    for v in free_vars(e1):
                        if pool == RUNTIME_POOL:
                            e1 = subst(
                                strip_EStateVar(e1), {
                                    sv.id: EStateVar(sv).with_type(sv.type)
                                    for sv in self.state_vars if sv != v
                                })
                        for e2 in cache.find(pool=pool, type=v.type, size=sz2):
                            yield self.check(subst(e1, {v.id: e2}), pool)

        # Candidates from map_accelerate, plus accelerate_filter for filters
        # that mention no binder.
        for (sz1, sz2) in pick_to_sum(2, size - 1):
            for e in cache.find(pool=RUNTIME_POOL, size=sz1):
                if not is_root(e):
                    continue
                for x, pool in map_accelerate(e, self.state_vars, self.binders,
                                              self.args, cache, sz2):
                    yield self.check(x, pool)
                if isinstance(e, EFilter) and not any(v in self.binders
                                                      for v in free_vars(e)):
                    for x, pool in accelerate_filter(e.e, e.p, self.state_vars,
                                                     self.binders, self.args,
                                                     cache, sz2):
                        yield self.check(x, pool)

        for bag in cache.find_collections(pool=RUNTIME_POOL, size=size - 1):
            if not is_root(bag):
                continue
            # For each argument `a` and each state collection `v` of the same
            # type, filter the bag by the constant conditions `a in v` and
            # `not (a in v)`.
            for a in self.args:
                for v in self.state_vars:
                    if is_collection(v.type) and v.type == a.type:
                        v = EStateVar(v).with_type(v.type)
                        cond = EBinOp(a, BOp.In, v).with_type(BOOL)
                        yield self.check(
                            EFilter(bag, mk_lambda(bag.type.t,
                                                   lambda _: cond)).with_type(
                                                       bag.type), RUNTIME_POOL)
                        yield self.check(
                            EFilter(
                                bag,
                                mk_lambda(bag.type.t,
                                          lambda _: ENot(cond))).with_type(
                                              bag.type), RUNTIME_POOL)

            if isinstance(bag, EFilter):
                # Only filters over pure state collections are handled below.
                if any(v not in self.state_vars for v in free_vars(bag.e)):
                    continue

                # separate filter conds
                # For a conjunction `c1 and c2`: filter by each conjunct
                # alone, and by one after the other in both orders.
                if isinstance(bag.p.body, EBinOp) and bag.p.body.op == BOp.And:
                    p1 = ELambda(bag.p.arg, bag.p.body.e1)
                    p2 = ELambda(bag.p.arg, bag.p.body.e2)
                    f1 = EFilter(bag.e, p1).with_type(bag.type)
                    f2 = EFilter(bag.e, p2).with_type(bag.type)
                    f3 = EFilter(f1, p2).with_type(bag.type)
                    f4 = EFilter(f2, p1).with_type(bag.type)
                    yield self.check(f1, RUNTIME_POOL)
                    yield self.check(f2, RUNTIME_POOL)
                    yield self.check(f3, RUNTIME_POOL)
                    yield self.check(f4, RUNTIME_POOL)

                # construct map lookups
                # When infer_map_lookup finds a key projection in the
                # predicate, group the bag by that projection in a state map
                # and look the key up at runtime.
                binder = bag.p.arg
                inf = infer_map_lookup(bag.p.body, binder,
                                       set(self.state_vars))
                if inf:
                    key_proj, key_lookup, remaining_filter = inf
                    # A second binder of the key's type is needed for the
                    # map's value function.
                    bag_binder = find_one(
                        self.binders,
                        lambda b: b.type == key_proj.type and b != binder)
                    if bag_binder:
                        m = strip_EStateVar(
                            EMakeMap2(
                                EMap(bag.e,
                                     ELambda(binder, key_proj)).with_type(
                                         type(bag.type)(key_proj.type)),
                                ELambda(
                                    bag_binder,
                                    EFilter(
                                        bag.e,
                                        ELambda(binder,
                                                EEq(key_proj,
                                                    bag_binder))).with_type(
                                                        bag.type))).with_type(
                                                            TMap(
                                                                key_proj.type,
                                                                bag.type)))
                        assert not any(v in self.args for v in free_vars(m))
                        yield self.check(m, STATE_POOL)
                        m = EStateVar(m).with_type(m.type)
                        mg = EMapGet(m, key_lookup).with_type(bag.type)
                        yield self.check(mg, RUNTIME_POOL)
                        # The lookup result still has to pass whatever part
                        # of the predicate the map does not cover.
                        yield self.check(
                            EFilter(mg, ELambda(
                                binder, remaining_filter)).with_type(mg.type),
                            RUNTIME_POOL)

        # for e in cache.find(size=size-1):
        #     # F(xs +/- ys) ---> F(xs), F(ys)
        #     for z in break_plus_minus(e):
        #         if z != e:
        #             # print("broke {} --> {}".format(pprint(e), pprint(z)))
        #             yield z

        #     # try reordering operations
        #     for (_, e1, f) in enumerate_fragments(e):
        #         if e1.type == e.type and e1 != e:
        #             for (_, e2, g) in enumerate_fragments(e1):
        #                 if e2.type == e.type and e2 != e1:
        #                     # e == f(g(e2))
        #                     yield g(f(e2))

        # Finally, defer to the wrapped builder.
        yield from self.wrapped.build(cache, size)
Ejemplo n.º 19
0
def sketch_update(
        lval        : syntax.Exp,
        old_value   : syntax.Exp,
        new_value   : syntax.Exp,
        ctx         : [syntax.EVar],
        assumptions : [syntax.Exp] = None,
        invariants  : [syntax.Exp] = None) -> (syntax.Stm, [syntax.Query]):
    """
    Write code to update `lval` when it changes from `old_value` to `new_value`.
    Variables in `ctx` are assumed to be part of the data structure abstract
    state, and `assumptions` will be appended to all generated subgoals.

    `assumptions` and `invariants` may be any iterables of expressions; both
    default to empty.  Their conjunction is the hypothesis under which
    `old_value == new_value` is checked; when that check succeeds no update
    code is needed and a no-op statement is returned.

    The update strategy is chosen from the type of `lval`:
      - bags and sets: remove the deleted elements, then add the new ones
      - numbers (when `update_numbers_with_deltas` is on): add a delta
      - tuples and records: update every component recursively
      - maps: delete removed keys, then update new or modified keys
      - anything else: recompute the value from scratch

    This function returns a statement (code to update `lval`) and a list of
    subgoals (new queries that appear in the code).
    """

    # Work on private lists: the defaults must not be shared mutable objects,
    # and the list concatenations below require real lists even when the
    # caller handed us a tuple or another iterable.
    assumptions = [] if assumptions is None else list(assumptions)
    invariants = [] if invariants is None else list(invariants)

    if valid(syntax.EImplies(
            syntax.EAll(itertools.chain(assumptions, invariants)),
            syntax.EEq(old_value, new_value))):
        return (syntax.SNoOp(), [])

    subgoals = []
    new_value = strip_EStateVar(new_value)

    def make_subgoal(e, a=(), docstring=None):
        # Expressions that mention no state can be used inline when the
        # `skip_stateless_synthesis` option is enabled.
        if skip_stateless_synthesis.value and not any(v in ctx for v in free_vars(e)):
            return e
        query_name = fresh_name("query")
        query = syntax.Query(query_name, syntax.Visibility.Internal, [], assumptions + list(a), e, docstring)
        # Every free variable that is not part of the state becomes a formal
        # parameter of the new query.
        query_vars = [v for v in free_vars(query) if v not in ctx]
        query.args = [(arg.id, arg.type) for arg in query_vars]
        subgoals.append(query)
        return syntax.ECall(query_name, tuple(query_vars)).with_type(e.type)

    def recurse(*args, **kwargs):
        # Recursive call that folds the nested subgoals into ours.
        (code, sgs) = sketch_update(*args, **kwargs)
        subgoals.extend(sgs)
        return code

    t = lval.type
    if isinstance(t, (syntax.TBag, syntax.TSet)):
        to_add = make_subgoal(syntax.EBinOp(new_value, "-", old_value).with_type(t), docstring="additions to {}".format(pprint(lval)))
        to_del = make_subgoal(syntax.EBinOp(old_value, "-", new_value).with_type(t), docstring="deletions from {}".format(pprint(lval)))
        v = fresh_var(t.elem_type)
        stm = syntax.seq([
            syntax.SForEach(v, to_del, syntax.SCall(lval, "remove", [v])),
            syntax.SForEach(v, to_add, syntax.SCall(lval, "add", [v]))])
    elif is_numeric(t) and update_numbers_with_deltas.value:
        change = make_subgoal(syntax.EBinOp(new_value, "-", old_value).with_type(t), docstring="delta for {}".format(pprint(lval)))
        stm = syntax.SAssign(lval, syntax.EBinOp(lval, "+", change).with_type(t))
    elif isinstance(t, syntax.TTuple):
        get = lambda val, i: syntax.ETupleGet(val, i).with_type(t.ts[i])
        stm = syntax.seq([
            recurse(get(lval, i), get(old_value, i), get(new_value, i), ctx, assumptions,
                invariants=invariants)
            for i in range(len(t.ts))])
    elif isinstance(t, syntax.TRecord):
        get = lambda val, i: syntax.EGetField(val, t.fields[i][0]).with_type(t.fields[i][1])
        stm = syntax.seq([
            recurse(get(lval, i), get(old_value, i), get(new_value, i), ctx, assumptions,
                invariants=invariants)
            for i in range(len(t.fields))])
    elif isinstance(t, syntax.TMap):
        k = fresh_var(lval.type.k)
        v = fresh_var(lval.type.v)
        key_bag = syntax.TBag(lval.type.k)

        old_keys = target_syntax.EMapKeys(old_value).with_type(key_bag)
        new_keys = target_syntax.EMapKeys(new_value).with_type(key_bag)

        # (1) exit set
        deleted_keys = syntax.EBinOp(old_keys, "-", new_keys).with_type(key_bag)
        s1 = syntax.SForEach(k, make_subgoal(deleted_keys, docstring="keys removed from {}".format(pprint(lval))),
            target_syntax.SMapDel(lval, k))

        # (2) enter/mod set
        new_or_modified = target_syntax.EFilter(new_keys,
            syntax.ELambda(k, syntax.EAny([syntax.ENot(syntax.EIn(k, old_keys)), syntax.ENot(syntax.EEq(value_at(old_value, k), value_at(new_value, k)))]))).with_type(key_bag)
        # The per-key update may assume that `k` is one of the affected keys
        # and that `v` currently holds the old value stored under `k`.
        update_value = recurse(
            v,
            value_at(old_value, k),
            value_at(new_value, k),
            ctx = ctx,
            assumptions = assumptions + [syntax.EIn(k, new_or_modified), syntax.EEq(v, value_at(old_value, k))],
            invariants = invariants)
        s2 = syntax.SForEach(k, make_subgoal(new_or_modified, docstring="new or modified keys from {}".format(pprint(lval))),
            target_syntax.SMapUpdate(lval, k, v, update_value))

        stm = syntax.SSeq(s1, s2)
    else:
        # Fallback rule: just compute a new value from scratch
        stm = syntax.SAssign(lval, make_subgoal(new_value, docstring="new value for {}".format(pprint(lval))))

    return (stm, subgoals)
Ejemplo n.º 20
0
    def _add_subquery(self, sub_q: Query, used_by: Stm) -> Stm:
        """Register a helper query that some other statement relies on.

        Parameters:
            sub_q - the specification of the helper query
            used_by - the statement that calls `sub_q`

        A copy of `sub_q` is normalized first (extra handle assumptions,
        repaired state-variable boundaries, simplification).  If an existing
        query specification is found equivalent to it, `used_by` is returned
        with its calls redirected to that existing query and nothing is
        added.  Otherwise the normalized query is added and `used_by` is
        returned unchanged.
        """

        with task("adding query", query=sub_q.name):
            # Never mutate the caller's query object.
            sub_q = shallow_copy(sub_q)

            with task("checking whether we need more handle assumptions"):
                handle_assumptions = implicit_handle_assumptions(
                    reachable_handles_at_method(self.spec, sub_q))
                already_implied = valid(
                    EImplies(EAll(sub_q.assumptions),
                             EAll(handle_assumptions)))
                if not already_implied:
                    event("we do!")
                    sub_q.assumptions = (
                        list(sub_q.assumptions) + list(handle_assumptions))

            with task("repairing state var boundaries"):
                extra_available_state = []
                for _, concretization in self._concretization_functions:
                    extra_available_state.append(concretization)
                sub_q.ret = repair_well_formedness(
                    strip_EStateVar(sub_q.ret),
                    self.context_for_method(sub_q),
                    extra_available_state)

            with task("simplifying"):
                assumptions_before = sub_q.assumptions
                a_size_before = sum(a.size() for a in assumptions_before)
                ret_size_before = sub_q.ret.size()

                sub_q.assumptions = tuple(
                    simplify_or_ignore(a) for a in assumptions_before)
                sub_q.ret = simplify(sub_q.ret)

                a_size_after = sum(a.size() for a in sub_q.assumptions)
                ret_size_after = sub_q.ret.size()
                event("|assumptions|: {} -> {}".format(
                    a_size_before, a_size_after))
                event("|ret|: {} -> {}".format(
                    ret_size_before, ret_size_after))

                # Simplification is never supposed to grow the assumptions;
                # dump both versions before failing.
                if a_size_after > a_size_before:
                    print("NO, BAD SIMPLIFICATION")
                    print("original")
                    for a in assumptions_before:
                        print(" - {}".format(pprint(a)))
                    print("simplified")
                    for a in sub_q.assumptions:
                        print(" - {}".format(pprint(a)))
                    assert False

            state_vars = self.abstract_state
            funcs = self.extern_funcs

            def duplicates_sub_q(candidate):
                # Deduplication can be switched off through `dedup_queries`.
                return dedup_queries.value and queries_equivalent(
                    candidate,
                    sub_q,
                    state_vars=state_vars,
                    extern_funcs=funcs,
                    assumptions=EAll(self.abstract_invariants))

            qq = find_one(self.query_specs, duplicates_sub_q)
            if qq is None:
                self.add_query(sub_q)
                return used_by

            event("subgoal {} is equivalent to {}".format(
                sub_q.name, qq.name))

            # For every parameter of the existing query, the position of the
            # same-named parameter in `sub_q`.
            sub_q_arg_names = [x[0] for x in sub_q.args]
            arg_reorder = [sub_q_arg_names.index(a) for (a, t) in qq.args]

            class CallRedirector(BottomUpRewriter):
                def visit_ECall(self, e):
                    visited = tuple(self.visit(a) for a in e.args)
                    if e.func != sub_q.name:
                        return ECall(e.func, visited).with_type(e.type)
                    reordered = tuple(visited[idx] for idx in arg_reorder)
                    return ECall(qq.name, reordered).with_type(e.type)

            return CallRedirector().visit(used_by)
Ejemplo n.º 21
0
def _simple_filter(xs : Exp, p : ELambda, args : {EVar}):
    """Yield expressions equivalent to filtering `xs` by the predicate `p`.

    Assumes the body of p is already in negation normal form.

    Several rewrites are attempted, depending on the shape of `xs` and of the
    predicate body.  When none of them produces anything, the plain
    `EFilter(xs, p)` is yielded, so at least one expression always comes out.
    """
    pred = p.body
    var = p.arg

    def empty():
        return EEmptyList().with_type(xs.type)

    def with_body(new_body):
        return ELambda(var, new_body)

    def predicate_avoids_args():
        return not any(v in args for v in free_vars(p))

    # Trivial predicates and a trivially empty input need no filter at all.
    if pred == ETRUE:
        yield xs
        return
    if pred == EFALSE:
        yield empty()
        return
    if isinstance(xs, EEmptyList):
        yield xs
        return

    produced = False

    # One element: keep it or drop it.
    if isinstance(xs, ESingleton):
        produced = True
        yield optimized_cond(p.apply_to(xs.e), xs, empty())

    pred_is_binop = isinstance(pred, EBinOp)

    # Disjunction: elements matching one side, plus elements matching the
    # other side but not the first.  Both orderings are tried.
    if pred_is_binop and pred.op == BOp.Or:
        lhs, rhs = pred.e1, pred.e2
        for first, second in ((lhs, rhs), (rhs, lhs)):
            for matching_first in _simple_filter(xs, with_body(first), args):
                for matching_rest in _simple_filter(xs, with_body(EAll([second, ENot(first)])), args):
                    produced = True
                    yield EBinOp(matching_first, "+", matching_rest).with_type(xs.type)

    # Conjunction: filter by one side, then by the other; both orderings.
    if pred_is_binop and pred.op == BOp.And:
        lhs, rhs = pred.e1, pred.e2
        for first, second in ((lhs, rhs), (rhs, lhs)):
            for partially_filtered in _simple_filter(xs, with_body(first), args):
                produced = True
                yield from _simple_filter(partially_filtered, with_body(second), args)

    # The predicate mentions no argument: the whole filter can live in the
    # state.
    if isinstance(xs, EStateVar) and predicate_avoids_args():
        produced = True
        state_filter = EFilter(xs.e, strip_EStateVar(p)).with_type(xs.type)
        yield EStateVar(state_filter).with_type(xs.type)

    # Lookup in a state map: filter the values stored in the map instead.
    if isinstance(xs, EMapGet) and isinstance(xs.map, EStateVar) and predicate_avoids_args():
        for filtered_map in map_values_multi(xs.map.e, lambda ys: _simple_filter(ys, p, args)):
            produced = True
            yield EMapGet(EStateVar(filtered_map).with_type(filtered_map.type), xs.key).with_type(xs.type)

    # Sum or difference: filter both operands.
    if isinstance(xs, EBinOp) and xs.op in ("+", "-"):
        for filtered_left in _simple_filter(xs.e1, p, args):
            for filtered_right in _simple_filter(xs.e2, p, args):
                produced = True
                yield EBinOp(filtered_left, xs.op, filtered_right).with_type(xs.type)

    # Equality between something computed from the element and something that
    # is not: look the answer up in a state map keyed by the element side.
    if pred_is_binop and pred.op == "==":
        lhs = pred.e1
        rhs = pred.e2
        rhs_fvs = free_vars(rhs)
        lhs_fvs = free_vars(lhs)
        sides = ((lhs, lhs_fvs), (rhs, rhs_fvs))
        for (e1, fvs1), (e2, fvs2) in (sides, sides[::-1]):
            if var in fvs1 and not any(a in fvs1 for a in args) and var not in fvs2 and isinstance(xs, EStateVar):
                if e1 == var:
                    # The element itself is compared: at most the singleton
                    # of the other side survives.
                    yield optimized_cond(
                        optimized_in(e2, xs),
                        ESingleton(e2).with_type(xs.type),
                        empty())

                k = fresh_var(e1.type)
                keys = EMap(xs.e, with_body(e1))
                bucket_for_key = ELambda(k, EFilter(xs.e, with_body(EEq(e1, k))))
                lookup = EMapGet(EStateVar(EMakeMap2(keys, bucket_for_key)), e2)
                # The lookup was built without types; the call must happen
                # outside the assert so that it still runs under -O.
                typechecked = retypecheck(lookup)
                assert typechecked
                produced = True
                yield lookup

    if not produced:
        yield EFilter(xs, p).with_type(xs.type)
Ejemplo n.º 22
0
def optimized_best(xs, keyfunc, op, args):
    """Yield candidate expressions for the best element of `xs` under `keyfunc`.

    Parameters:
        xs - the collection expression to search
        keyfunc - lambda computing the key that elements are compared by
        op - "<" selects arg-min; any other value selects arg-max
        args - variables that stored state must not mention

    Rewrites are yielded according to the shape of `xs`.  The plain
    `EArgMin`/`EArgMax` over `xs` is always yielded last, so the generator
    never comes up empty.
    """
    argbest = EArgMin if op == "<" else EArgMax
    elem_type = xs.type.elem_type
    key_type = keyfunc.body.type

    # `xs` is some bag with one element excluded.  Evaluate the match once
    # and reuse it rather than calling `excluded_element` a second time.
    excluded = excluded_element(xs, args)
    if excluded is not None:
        bag, x = excluded
        if all(v not in args for v in free_vars(bag)):
            heap_type, make_heap = (TMinHeap, EMakeMinHeap) if op == "<" else (TMaxHeap, EMakeMaxHeap)
            bag = EStateVar(strip_EStateVar(bag)).with_type(bag.type)
            h = make_heap(bag.e, keyfunc).with_type(heap_type(elem_type, key_type))
            for prev_min in optimized_best(bag.e, keyfunc, op, args=args):
                prev_min = EStateVar(prev_min).with_type(elem_type)
                heap_peek = EHeapPeek2(EStateVar(h).with_type(h.type)).with_type(elem_type)
                # Only when the excluded element is in the bag and equals the
                # stored best is the heap consulted; otherwise the stored
                # best is used as-is.
                conds = [optimized_in(x, bag), optimized_eq(x, prev_min)]
                if isinstance(x, EUnaryOp) and x.op == UOp.The:
                    conds = [optimized_exists(x.e)] + conds
                yield optimized_cond(
                    EAll(conds),
                    heap_peek,
                    prev_min)

    if isinstance(xs, EEmptyList):
        yield construct_value(elem_type)

    if isinstance(xs, ESingleton):
        yield xs.e

    # Sum of two bags: combine the best of each side, guarding against either
    # side being empty.
    # NOTE: an earlier strategy that special-cased EStateVar operands and
    # otherwise folded over `break_sum(xs)` used to live here, disabled.
    if isinstance(xs, EBinOp) and xs.op == "+":
        a_ex = optimized_exists(xs.e1)
        b_ex = optimized_exists(xs.e2)
        bag_type = TBag(xs.type.elem_type)
        for a in optimized_best(xs.e1, keyfunc, op, args=args):
            for b in optimized_best(xs.e2, keyfunc, op, args=args):
                yield optimized_cond(a_ex,
                    optimized_cond(b_ex,
                        argbest(EBinOp(ESingleton(a).with_type(bag_type), "+", ESingleton(b).with_type(bag_type)).with_type(bag_type), keyfunc).with_type(elem_type),
                        a),
                    optimized_cond(b_ex, b, construct_value(elem_type)))

    # Mapped collection: search the source collection with the composed key,
    # then transform the winner.
    if isinstance(xs, EMap):
        for b in optimized_best(xs.e, compose(keyfunc, xs.transform_function), op, args):
            yield optimized_cond(optimized_exists(xs.e),
                xs.transform_function.apply_to(b),
                construct_value(elem_type))

    # State collection and a key that mentions no argument: store the answer.
    if isinstance(xs, EStateVar) and not any(v in args for v in free_vars(keyfunc)):
        yield EStateVar(argbest(xs.e, keyfunc).with_type(elem_type)).with_type(elem_type)

    if isinstance(xs, ECond):
        for a in optimized_best(xs.then_branch, keyfunc, op, args=args):
            for b in optimized_best(xs.else_branch, keyfunc, op, args=args):
                yield optimized_cond(xs.cond, a, b)

    if isinstance(xs, EUnaryOp) and xs.op == UOp.Distinct:
        yield from optimized_best(xs.e, keyfunc, op, args=args)

    # NOTE: a rule for EFilter inputs used to be sketched here, disabled.
    yield argbest(xs, keyfunc).with_type(elem_type)
Ejemplo n.º 23
0
def _try_optimize(e, context, pool):
    """Yield candidate replacements for `e`, each passed through `_check`.

    Nothing is yielded when the `accelerate` option is off or when `pool` is
    not the runtime pool.  Otherwise a sequence of rewrite rules is tried,
    each keyed on the shape of `e`.
    """
    if not accelerate.value:
        return

    state_vars = [var for var, var_pool in context.vars() if var_pool == STATE_POOL]
    args = [var for var, var_pool in context.vars() if var_pool == RUNTIME_POOL]

    if pool != RUNTIME_POOL:
        return

    # The expression mentions only state variables: store it whole.
    if all(v in state_vars for v in free_vars(e)):
        stripped = strip_EStateVar(e)
        as_state = EStateVar(stripped).with_type(e.type)
        yield _check(as_state, context, RUNTIME_POOL)

    # Keep only the runtime-pool results of map acceleration.
    for candidate, candidate_pool in map_accelerate(e, context):
        if candidate_pool == RUNTIME_POOL:
            yield _check(candidate, context, candidate_pool)

    # xs[0] can be expressed with the `The` operator.
    if isinstance(e, EListGet) and e.index == ZERO:
        the_first = EUnaryOp(UOp.The, e.e).with_type(e.type)
        yield _check(the_first, context, RUNTIME_POOL)

    if isinstance(e, (EArgMin, EArgMax)):
        ee = optimized_best(e.e,
                            e.f,
                            "<" if isinstance(e, EArgMin) else ">",
                            args=args)
        yield _check(ee, context, RUNTIME_POOL)

    is_binop = isinstance(e, EBinOp)

    if is_collection(e.type) and is_binop and e.op == "-":
        ee = optimized_bag_difference(e.e1, e.e2)
        yield _check(ee, context, RUNTIME_POOL)

    # Deep equality on handles: same address and same value.
    if is_binop and e.op == "===" and isinstance(e.e1.type, THandle):
        same_address = optimized_eq(optimized_addr(e.e1),
                                    optimized_addr(e.e2))
        same_value = optimized_eq(optimized_val(e.e1),
                                  optimized_val(e.e2)).with_type(BOOL)
        yield _check(EAll([same_address, same_value]), context, RUNTIME_POOL)

    if is_binop and e.op == BOp.In:
        ee = optimized_in(e.e1, e.e2)
        yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, EUnaryOp):
        if e.op == UOp.Sum:
            for ee in optimized_sum(e.e, args):
                yield _check(ee, context, RUNTIME_POOL)

        if e.op == UOp.Empty:
            yield _check(optimized_empty(e.e), context, RUNTIME_POOL)

        if e.op == UOp.Exists:
            yield _check(optimized_exists(e.e), context, RUNTIME_POOL)

        if e.op == UOp.Length:
            yield _check(optimized_len(e.e), context, RUNTIME_POOL)

        if e.op == UOp.The:
            for ee in optimize_the(e.e, args):
                yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, EFilter):
        ee = optimize_filter_as_if_distinct(e.e, e.p, args=args)
        yield _check(ee, context, RUNTIME_POOL)
        if isinstance(e.e, EFilter):
            # try swizzle: apply the outer predicate first, the inner one
            # second
            inner = _simple_filter(e.e.e, e.p, args=args)
            swizzled = EFilter(inner, e.e.p).with_type(e.type)
            yield _check(swizzled, context, RUNTIME_POOL)

    if isinstance(e, EMap):
        for ee in optimize_map(e.e, e.f, args=args):
            yield _check(ee, context, RUNTIME_POOL)
Ejemplo n.º 24
0
def map_accelerate(e, context):
    """Yield map-based alternatives for `e` as `(expression, pool)` pairs.

    The subexpressions of `e` are walked with an explicit stack.  Every
    subexpression `arg` for which `context.legal_for(free_vars(arg))` holds is
    tried as a lookup key: its occurrences in `e` are replaced by a binder,
    and if the result no longer mentions any runtime-pool variable, then for
    each collection of candidate keys reachable from a state variable two
    results are yielded:

      - `(m, STATE_POOL)`: a map from the candidate keys to the abstracted
        value
      - `(mg, RUNTIME_POOL)`: a lookup of `arg` in that map

    Nothing is yielded when `is_constant_time(e)` holds.
    """
    with task("map_accelerate", size=e.size()):
        if is_constant_time(e):
            event("skipping map lookup inference for constant-time exp: {}".
                  format(pprint(e)))
            return

        # One binder per key type for the duration of this call.
        @lru_cache()
        def make_binder(t):
            return fresh_var(t, hint="key")

        args = OrderedSet(v for (v, p) in context.vars() if p == RUNTIME_POOL)
        possible_keys = {}  # type -> [exp]
        i = 0

        stk = [e]
        while stk:
            event("exp {} / {}".format(i, e.size()))
            i += 1
            arg = stk.pop()
            if isinstance(arg, tuple):
                stk.extend(arg)
                continue
            if not isinstance(arg, Exp):
                continue
            if isinstance(arg, ELambda):
                stk.append(arg.body)
                continue

            if context.legal_for(free_vars(arg)):
                # all the work happens here
                binder = make_binder(arg.type)
                value = replace(
                    e,
                    arg,
                    binder,
                    match=lambda e1, e2: type(e1) == type(e2) and e1.type == e2
                    .type and alpha_equivalent(e1, e2))
                value = strip_EStateVar(value)
                if any(v in args for v in free_vars(value)):
                    event("not all args were eliminated")
                else:
                    if arg.type not in possible_keys:
                        # Candidate key collections are computed once per key
                        # type; empty-list results are dropped.
                        reachable = [
                            reachable_values_of_type(sv, arg.type)
                            for (sv, p) in context.vars() if p == STATE_POOL
                        ]
                        possible_keys[arg.type] = OrderedSet(
                            x for x in reachable
                            if not isinstance(x, EEmptyList))
                    for keys in possible_keys[arg.type]:
                        m = EMakeMap2(keys, ELambda(binder, value)).with_type(
                            TMap(arg.type, e.type))
                        assert not any(
                            v in args
                            for v in free_vars(m)), "oops! {}; args={}".format(
                                pprint(m), ", ".join(pprint(a) for a in args))
                        yield (m, STATE_POOL)
                        mg = EMapGet(EStateVar(m).with_type(m.type),
                                     arg).with_type(e.type)
                        yield (mg, RUNTIME_POOL)

            if isinstance(arg, EStateVar):
                # do not visit state expressions
                continue

            # Count how many children of `arg` mention runtime arguments; the
            # children are only visited when fewer than two of them do.
            num_with_args = 0
            stk2 = list(arg.children())
            while stk2:
                child = stk2.pop()
                if isinstance(child, tuple):
                    # Flatten tuple-valued children into this local worklist
                    # so that their members take part in the count.  Pushing
                    # them onto `stk` instead would visit them regardless of
                    # the decision below, and twice when the children are
                    # visited.
                    stk2.extend(child)
                    continue
                if not isinstance(child, Exp):
                    continue
                fvs = free_vars(child)
                if fvs & args:
                    num_with_args += 1
                    if num_with_args >= 2:
                        break
            if num_with_args < 2:
                stk.extend(arg.children())
            else:
                event("refusing to visit children of {}".format(pprint(arg)))
Ejemplo n.º 25
0
def optimized_best(xs, keyfunc, op, args):
    """Yield candidate expressions for the best element of `xs` under `keyfunc`.

    Parameters:
        xs - the collection expression to search
        keyfunc - lambda computing the key that elements are compared by
        op - "<" selects arg-min; any other value selects arg-max
        args - variables that stored state must not mention

    Rewrites are yielded according to the shape of `xs`.  The plain
    `EArgMin`/`EArgMax` over `xs` is always yielded last, so the generator
    never comes up empty.
    """
    argbest = EArgMin if op == "<" else EArgMax
    elem_type = xs.type.elem_type
    key_type = keyfunc.body.type

    # `xs` is some bag with one element excluded.  Evaluate the match once
    # and reuse it rather than calling `excluded_element` a second time.
    excluded = excluded_element(xs, args)
    if excluded is not None:
        bag, x = excluded
        if all(v not in args for v in free_vars(bag)):
            heap_type, make_heap = (TMinHeap, EMakeMinHeap) if op == "<" else (
                TMaxHeap, EMakeMaxHeap)
            bag = EStateVar(strip_EStateVar(bag)).with_type(bag.type)
            h = make_heap(bag.e,
                          keyfunc).with_type(heap_type(elem_type, key_type))
            for prev_min in optimized_best(bag.e, keyfunc, op, args=args):
                prev_min = EStateVar(prev_min).with_type(elem_type)
                heap_peek = EHeapPeek2(EStateVar(h).with_type(
                    h.type)).with_type(elem_type)
                # Only when the excluded element is in the bag and equals the
                # stored best is the heap consulted; otherwise the stored
                # best is used as-is.
                conds = [optimized_in(x, bag), optimized_eq(x, prev_min)]
                if isinstance(x, EUnaryOp) and x.op == UOp.The:
                    conds = [optimized_exists(x.e)] + conds
                yield optimized_cond(EAll(conds), heap_peek, prev_min)

    if isinstance(xs, EEmptyList):
        yield construct_value(elem_type)

    if isinstance(xs, ESingleton):
        yield xs.e

    # Sum of two bags: combine the best of each side, guarding against either
    # side being empty.
    # NOTE: an earlier strategy that special-cased EStateVar operands and
    # otherwise folded over `break_sum(xs)` used to live here, disabled.
    if isinstance(xs, EBinOp) and xs.op == "+":
        a_ex = optimized_exists(xs.e1)
        b_ex = optimized_exists(xs.e2)
        bag_type = TBag(xs.type.elem_type)
        for a in optimized_best(xs.e1, keyfunc, op, args=args):
            for b in optimized_best(xs.e2, keyfunc, op, args=args):
                yield optimized_cond(
                    a_ex,
                    optimized_cond(
                        b_ex,
                        argbest(
                            EBinOp(
                                ESingleton(a).with_type(bag_type), "+",
                                ESingleton(b).with_type(bag_type)).with_type(
                                    bag_type),
                            keyfunc).with_type(elem_type), a),
                    optimized_cond(b_ex, b, construct_value(elem_type)))

    # Mapped collection: search the source collection with the composed key,
    # then transform the winner.
    if isinstance(xs, EMap):
        for b in optimized_best(xs.e, compose(keyfunc, xs.transform_function),
                                op, args):
            yield optimized_cond(optimized_exists(xs.e),
                                 xs.transform_function.apply_to(b),
                                 construct_value(elem_type))

    # State collection and a key that mentions no argument: store the answer.
    if isinstance(xs, EStateVar) and not any(v in args
                                             for v in free_vars(keyfunc)):
        yield EStateVar(argbest(
            xs.e, keyfunc).with_type(elem_type)).with_type(elem_type)

    if isinstance(xs, ECond):
        for a in optimized_best(xs.then_branch, keyfunc, op, args=args):
            for b in optimized_best(xs.else_branch, keyfunc, op, args=args):
                yield optimized_cond(xs.cond, a, b)

    if isinstance(xs, EUnaryOp) and xs.op == UOp.Distinct:
        yield from optimized_best(xs.e, keyfunc, op, args=args)

    # NOTE: a rule for EFilter inputs used to be sketched here, disabled.
    yield argbest(xs, keyfunc).with_type(elem_type)
Ejemplo n.º 26
0
def _try_optimize(e: Exp, context: Context, pool: Pool):
    """Yield candidate replacements for `e`, each passed through `_check`.

    Parameters:
        e - the expression to improve
        context - supplies the variables in scope and the pool of each
        pool - the pool `e` lives in

    Nothing is yielded when the `accelerate` option is off or when `pool` is
    not the runtime pool.  Otherwise general rewrite schemes are tried first,
    followed by rules keyed on the specific node type of `e`.
    """
    if not accelerate.value:
        return

    if pool != RUNTIME_POOL:
        return

    # Function-scope imports, gathered in one place (they were previously
    # scattered through the body, one name imported twice).
    from cozy.syntax import ESorted
    from cozy.structures.treemultiset import (
        EMakeMaxTreeMultiset, TMaxTreeMultiset,
        EMakeMinTreeMultiset, TMinTreeMultiset,
        ETreeMultisetElems, ETreeMultisetPeek)

    state_vars = [v for v, p in context.vars() if p == STATE_POOL]
    args = [v for v, p in context.vars() if p == RUNTIME_POOL]

    # ---------------------------------------------------------------------
    # "Rewrite schemes": these trigger on many different AST shapes
    # They are listed first because they are more powerful than the
    # specific rewrite rules below.

    # Closed expressions can be evaluated right away.
    # NOTE(review): `eval` is presumably the project's expression evaluator
    # rather than the builtin - confirm against the file's imports.
    if not free_vars(e) and not free_funcs(e):
        # Only the computation is guarded; the yield happens outside the
        # `try` so the handler covers nothing but evaluation failures.
        try:
            constant = _check(uneval(e.type, eval(e, {})), context,
                              RUNTIME_POOL)
        except NotImplementedError:
            print("Unable to evaluate {!r}".format(e))
        else:
            yield constant

    # The expression mentions only state variables: store it whole.
    if all(v in state_vars for v in free_vars(e)):
        nsv = strip_EStateVar(e)
        sv = EStateVar(nsv).with_type(e.type)
        yield _check(sv, context, RUNTIME_POOL)

    for ee in fold_into_map(e, context):
        yield _check(ee, context, pool)

    # ---------------------------------------------------------------------
    # "Rewrites": these trigger on specific AST nodes

    if isinstance(e, EBinOp):

        if e.op == "-" and is_collection(e.type):
            ee = optimized_bag_difference(e.e1, e.e2)
            yield _check(ee, context, RUNTIME_POOL)

        # Deep equality on handles: same address and same value.
        if e.op == "===" and isinstance(e.e1.type, THandle):
            yield _check(
                EAll([
                    optimized_eq(optimized_addr(e.e1), optimized_addr(e.e2)),
                    optimized_eq(optimized_val(e.e1),
                                 optimized_val(e.e2)).with_type(BOOL)
                ]), context, RUNTIME_POOL)

        if e.op == BOp.In:
            ee = optimized_in(e.e1, e.e2)
            yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, ECond):
        yield _check(optimized_cond(e.cond, e.then_branch, e.else_branch),
                     context, RUNTIME_POOL)

    if isinstance(e, EGetField):
        for ee in optimized_get_field(e.e, e.field_name, args):
            yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, EListGet) and e.index == ZERO:
        for res in optimized_the(e.e, args):
            yield _check(res, context, RUNTIME_POOL)

    # Indexing into a conditional: push the index into both branches.
    # This result goes through `_check` like every other rule's.
    if isinstance(e, EListGet) and isinstance(e.e, ECond):
        yield _check(
            optimized_cond(
                e.e.cond,
                EListGet(e.e.then_branch, e.index).with_type(e.type),
                EListGet(e.e.else_branch, e.index).with_type(e.type)),
            context, RUNTIME_POOL)

    # Indexing into the elements of a tree multiset becomes a peek; also
    # routed through `_check`.
    if isinstance(e, EListGet) and isinstance(e.e, ETreeMultisetElems):
        yield _check(
            ETreeMultisetPeek(e.e.e, e.index).with_type(e.type),
            context, RUNTIME_POOL)

    if isinstance(e, EMapGet):
        ee = inline_mapget(e, context)
        yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, EUnaryOp):

        if e.op == UOp.Sum:
            for ee in optimized_sum(e.e, args):
                yield _check(ee, context, RUNTIME_POOL)

        if e.op == UOp.Length:
            ee = optimized_len(e.e)
            yield _check(ee, context, RUNTIME_POOL)

        if e.op == UOp.Empty:
            ee = optimized_empty(e.e)
            yield _check(ee, context, RUNTIME_POOL)

        if e.op == UOp.Exists:
            ee = optimized_exists(e.e)
            yield _check(ee, context, RUNTIME_POOL)

        if e.op == UOp.Distinct:
            for ee in optimized_distinct(e.e, args):
                yield _check(ee, context, RUNTIME_POOL)

        if e.op == UOp.The:
            for ee in optimized_the(e.e, args):
                yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, (EArgMin, EArgMax)):
        for ee in optimized_best(e.e,
                                 e.key_function,
                                 "<" if isinstance(e, EArgMin) else ">",
                                 args=args):
            yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, EFilter):
        for ee in optimized_filter(e.e, e.predicate, args=args):
            yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, EMap):
        for ee in optimized_map(e.e, e.transform_function, args=args):
            yield _check(ee, context, RUNTIME_POOL)

    # Sorting a state collection: keep a min- and a max-ordered tree multiset
    # in the state and read the elements from the one matching the requested
    # direction.
    if isinstance(e, ESorted) and isinstance(e.e, EStateVar):
        e_max = EMakeMaxTreeMultiset(e.e.e).with_type(
            TMaxTreeMultiset(e.e.e.type.elem_type))
        e_min = EMakeMinTreeMultiset(e.e.e).with_type(
            TMinTreeMultiset(e.e.e.type.elem_type))
        ee = optimized_cond(
            e.asc,
            ETreeMultisetElems(EStateVar(e_min).with_type(
                e_min.type)).with_type(e.type),
            ETreeMultisetElems(EStateVar(e_max).with_type(
                e_max.type)).with_type(e.type))
        yield _check(ee, context, RUNTIME_POOL)
Ejemplo n.º 27
0
def _try_optimize(e, context, pool):
    """Yield candidate rewrites of `e` produced by the acceleration rules.

    This is a generator.  It yields nothing when `accelerate.value` is falsy
    or when `pool` is not `RUNTIME_POOL`.  Otherwise every rule whose guard
    matches `e` contributes its results, each passed through `_check`.

    Parameters:
        e - the expression to rewrite
        context - provides `(variable, pool)` pairs via `context.vars()`
        pool - the pool that `e` is being considered for
    """
    if not accelerate.value:
        return

    # Partition the context's variables by the pool they are tagged with.
    state_vars = [var for var, var_pool in context.vars() if var_pool == STATE_POOL]
    args = [var for var, var_pool in context.vars() if var_pool == RUNTIME_POOL]

    # Every rule below applies to the runtime pool only.
    if pool != RUNTIME_POOL:
        return

    # An expression without free variables or free functions can be
    # evaluated right away and turned back into syntax.
    if not (free_vars(e) or free_funcs(e)):
        try:
            folded = uneval(e.type, eval(e, {}))
            yield _check(folded, context, RUNTIME_POOL)
        except NotImplementedError:
            print("Unable to evaluate {!r}".format(e))

    # If only state variables occur free, offer the whole expression wrapped
    # in a single EStateVar.
    if all(fv in state_vars for fv in free_vars(e)):
        wrapped = EStateVar(strip_EStateVar(e)).with_type(e.type)
        yield _check(wrapped, context, RUNTIME_POOL)

    for candidate in fold_into_map(e, context):
        yield _check(candidate, context, pool)

    if isinstance(e, EListGet) and e.index == ZERO:
        for candidate in optimize_the(e.e, args):
            yield _check(candidate, context, RUNTIME_POOL)

    if isinstance(e, (EArgMin, EArgMax)):
        comparison = "<" if isinstance(e, EArgMin) else ">"
        for candidate in optimized_best(e.e, e.f, comparison, args=args):
            yield _check(candidate, context, RUNTIME_POOL)

    if is_collection(e.type) and isinstance(e, EBinOp) and e.op == "-":
        difference = optimized_bag_difference(e.e1, e.e2)
        yield _check(difference, context, RUNTIME_POOL)

    if isinstance(e, EBinOp) and e.op == "===" and isinstance(e.e1.type, THandle):
        # Compare the two handles piecewise: addresses first, then values.
        same_addr = optimized_eq(optimized_addr(e.e1), optimized_addr(e.e2))
        same_val = optimized_eq(optimized_val(e.e1), optimized_val(e.e2)).with_type(BOOL)
        yield _check(EAll([same_addr, same_val]), context, RUNTIME_POOL)

    if isinstance(e, ECond):
        conditional = optimized_cond(e.cond, e.then_branch, e.else_branch)
        yield _check(conditional, context, RUNTIME_POOL)

    if isinstance(e, EGetField):
        for candidate in optimized_get_field(e.e, e.f, args):
            yield _check(candidate, context, RUNTIME_POOL)

    if isinstance(e, EBinOp) and e.op == BOp.In:
        membership = optimized_in(e.e1, e.e2)
        yield _check(membership, context, RUNTIME_POOL)

    if isinstance(e, EUnaryOp):

        if e.op == UOp.Sum:
            for candidate in optimized_sum(e.e, args):
                yield _check(candidate, context, RUNTIME_POOL)

        if e.op == UOp.Empty:
            yield _check(optimized_empty(e.e), context, RUNTIME_POOL)

        if e.op == UOp.Exists:
            yield _check(optimized_exists(e.e), context, RUNTIME_POOL)

        if e.op == UOp.Length:
            yield _check(optimized_len(e.e), context, RUNTIME_POOL)

        if e.op == UOp.The:
            for candidate in optimize_the(e.e, args):
                yield _check(candidate, context, RUNTIME_POOL)

    if isinstance(e, EFilter):
        filtered = optimize_filter_as_if_distinct(e.e, e.p, args=args)
        yield _check(filtered, context, RUNTIME_POOL)
        inner = e.e
        if isinstance(inner, EFilter):
            # Swizzle: apply the outer predicate first and the inner
            # predicate second.
            swapped = EFilter(
                _simple_filter(inner.e, e.p, args=args),
                inner.p).with_type(e.type)
            yield _check(swapped, context, RUNTIME_POOL)

    if isinstance(e, EMap):
        for candidate in optimize_map(e.e, e.f, args=args):
            yield _check(candidate, context, RUNTIME_POOL)
Ejemplo n.º 28
0
Archivo: core.py Proyecto: timwee/cozy
def can_elim_vars(spec: Exp, assumptions: Exp, vs: [EVar]):
    """Check whether renaming `vs` to fresh variables leaves `spec` unchanged.

    Each variable in `vs` is mapped (by its `id`) to a fresh variable of the
    same type.  The result is whatever `valid` reports for the formula

        (assumptions and assumptions[renamed]) implies (spec == spec[renamed])

    where `spec` has first been passed through `strip_EStateVar`.
    """
    stripped = strip_EStateVar(spec)

    renaming = {}
    for var in vs:
        renaming[var.id] = fresh_var(var.type)

    premise = EAll([assumptions, subst(assumptions, renaming)])
    conclusion = EEq(stripped, subst(stripped, renaming))
    return valid(EImplies(premise, conclusion))
Ejemplo n.º 29
0
def is_lenof(e, xs):
    """Return whether `e` is alpha-equivalent to `ELen(xs)`.

    Both `e` and `xs` are passed through `strip_EStateVar` before the
    comparison is made.
    """
    candidate = strip_EStateVar(e)
    expected = ELen(strip_EStateVar(xs))
    return alpha_equivalent(candidate, expected)
Ejemplo n.º 30
0
Archivo: impls.py Proyecto: uwplse/cozy
    def _add_subquery(self, sub_q : Query, used_by : Stm) -> Stm:
        """Add a query that helps maintain some other state.

        Parameters:
            sub_q - the specification of the helper query
            used_by - the statement that calls `sub_q`

        If a query already exists that is equivalent to `sub_q`, this method
        returns `used_by` rewritten to use the existing query and does not add
        the query to the implementation.  Otherwise it returns `used_by`
        unchanged.

        Raises:
            AssertionError - if simplification made the assumptions larger.
                It is raised explicitly (not via `assert`) so that the check
                still fires when Python runs with `-O`.
        """

        with task("adding query", query=sub_q.name):
            # All modifications below are made to a copy of the given query.
            sub_q = shallow_copy(sub_q)
            with task("checking whether we need more handle assumptions"):
                new_a = implicit_handle_assumptions(
                    reachable_handles_at_method(self.spec, sub_q))
                # Only extend the assumptions when the existing ones do not
                # already imply the new ones.
                if not valid(EImplies(EAll(sub_q.assumptions), EAll(new_a))):
                    event("we do!")
                    sub_q.assumptions = list(itertools.chain(sub_q.assumptions, new_a))

            with task("repairing state var boundaries"):
                extra_available_state = [e for v, e in self._concretization_functions]
                sub_q.ret = repair_well_formedness(
                    strip_EStateVar(sub_q.ret),
                    self.context_for_method(sub_q),
                    extra_available_state)

            with task("simplifying"):
                # Remember the pre-simplification state for the size report
                # and for the sanity check below.
                orig_a = sub_q.assumptions
                orig_a_size = sum(a.size() for a in sub_q.assumptions)
                orig_ret_size = sub_q.ret.size()
                sub_q.assumptions = tuple(simplify_or_ignore(a) for a in sub_q.assumptions)
                sub_q.ret = simplify(sub_q.ret)
                a_size = sum(a.size() for a in sub_q.assumptions)
                ret_size = sub_q.ret.size()
                event("|assumptions|: {} -> {}".format(orig_a_size, a_size))
                event("|ret|: {} -> {}".format(orig_ret_size, ret_size))

                if a_size > orig_a_size:
                    print("NO, BAD SIMPLIFICATION")
                    print("original")
                    for a in orig_a:
                        print(" - {}".format(pprint(a)))
                    print("simplified")
                    for a in sub_q.assumptions:
                        print(" - {}".format(pprint(a)))
                    # `assert False` would be stripped under `python -O`,
                    # letting the bad simplification through silently.
                    raise AssertionError(
                        "simplification increased the total size of the assumptions")

            state_vars = self.abstract_state
            funcs = self.extern_funcs
            qq = find_one(self.query_specs, lambda qq: dedup_queries.value and queries_equivalent(qq, sub_q, state_vars=state_vars, extern_funcs=funcs, assumptions=EAll(self.abstract_invariants)))
            if qq is not None:
                event("subgoal {} is equivalent to {}".format(sub_q.name, qq.name))
                # For each argument of the existing query, find the position of
                # the same-named argument of `sub_q`.  The name list is built
                # once instead of once per argument.
                sub_q_arg_names = [x[0] for x in sub_q.args]
                arg_reorder = [sub_q_arg_names.index(a) for (a, t) in qq.args]
                class Repl(BottomUpRewriter):
                    def visit_ECall(self, e):
                        args = tuple(self.visit(a) for a in e.args)
                        if e.func == sub_q.name:
                            # Redirect the call to the existing query, passing
                            # the arguments in that query's order.
                            args = tuple(args[idx] for idx in arg_reorder)
                            return ECall(qq.name, args).with_type(e.type)
                        else:
                            return ECall(e.func, args).with_type(e.type)
                used_by = Repl().visit(used_by)
            else:
                self.add_query(sub_q)
            return used_by
Ejemplo n.º 31
0
def _try_optimize(e : Exp, context : Context, pool : Pool):
    """Yield candidate rewrites of `e` produced by the acceleration rules.

    This is a generator.  Nothing is yielded when `accelerate.value` is falsy
    or when `pool` is not `RUNTIME_POOL`.  Otherwise every rule whose guard
    matches `e` contributes its results.

    Parameters:
        e - the expression to rewrite
        context - provides `(variable, pool)` pairs via `context.vars()`
        pool - the pool that `e` is being considered for
    """
    if not accelerate.value or pool != RUNTIME_POOL:
        return

    # Partition the context's variables by the pool they are tagged with.
    state_vars = [var for var, var_pool in context.vars() if var_pool == STATE_POOL]
    args = [var for var, var_pool in context.vars() if var_pool == RUNTIME_POOL]

    # ---------------------------------------------------------------------
    # "Rewrite schemes": not tied to one AST node type.  They come first
    # because they are more powerful than the node-specific rules further
    # down.

    # An expression without free variables or free functions can be
    # evaluated right away and turned back into syntax.
    if not (free_vars(e) or free_funcs(e)):
        try:
            folded = uneval(e.type, eval(e, {}))
            yield _check(folded, context, RUNTIME_POOL)
        except NotImplementedError:
            print("Unable to evaluate {!r}".format(e))

    # If only state variables occur free, offer the whole expression wrapped
    # in a single EStateVar.
    if all(fv in state_vars for fv in free_vars(e)):
        wrapped = EStateVar(strip_EStateVar(e)).with_type(e.type)
        yield _check(wrapped, context, RUNTIME_POOL)

    for candidate in fold_into_map(e, context):
        yield _check(candidate, context, pool)

    # ---------------------------------------------------------------------
    # "Rewrites": each of these is guarded by a specific AST node type.

    if isinstance(e, EBinOp):

        if e.op == "-" and is_collection(e.type):
            difference = optimized_bag_difference(e.e1, e.e2)
            yield _check(difference, context, RUNTIME_POOL)

        if e.op == "===" and isinstance(e.e1.type, THandle):
            # Compare the two handles piecewise: addresses first, then values.
            same_addr = optimized_eq(optimized_addr(e.e1), optimized_addr(e.e2))
            same_val = optimized_eq(optimized_val(e.e1), optimized_val(e.e2)).with_type(BOOL)
            yield _check(EAll([same_addr, same_val]), context, RUNTIME_POOL)

        if e.op == BOp.In:
            membership = optimized_in(e.e1, e.e2)
            yield _check(membership, context, RUNTIME_POOL)

    if isinstance(e, ECond):
        conditional = optimized_cond(e.cond, e.then_branch, e.else_branch)
        yield _check(conditional, context, RUNTIME_POOL)

    if isinstance(e, EGetField):
        for candidate in optimized_get_field(e.e, e.field_name, args):
            yield _check(candidate, context, RUNTIME_POOL)

    if isinstance(e, EListGet):

        if e.index == ZERO:
            for candidate in optimized_the(e.e, args):
                yield _check(candidate, context, RUNTIME_POOL)

        if isinstance(e.e, ECond):
            # Push the index operation into both branches of the conditional.
            # This result is yielded directly, without going through _check.
            branching = e.e
            yield optimized_cond(
                branching.cond,
                EListGet(branching.then_branch, e.index).with_type(e.type),
                EListGet(branching.else_branch, e.index).with_type(e.type))

    from cozy.structures.treemultiset import ETreeMultisetElems, ETreeMultisetPeek
    if isinstance(e, EListGet) and isinstance(e.e, ETreeMultisetElems):
        # Also yielded directly, without going through _check.
        yield ETreeMultisetPeek(e.e.e, e.index).with_type(e.type)

    if isinstance(e, EMapGet):
        inlined = inline_mapget(e, context)
        yield _check(inlined, context, RUNTIME_POOL)

    if isinstance(e, EUnaryOp):

        if e.op == UOp.Sum:
            for candidate in optimized_sum(e.e, args):
                yield _check(candidate, context, RUNTIME_POOL)

        if e.op == UOp.Length:
            yield _check(optimized_len(e.e), context, RUNTIME_POOL)

        if e.op == UOp.Empty:
            yield _check(optimized_empty(e.e), context, RUNTIME_POOL)

        if e.op == UOp.Exists:
            yield _check(optimized_exists(e.e), context, RUNTIME_POOL)

        if e.op == UOp.Distinct:
            for candidate in optimized_distinct(e.e, args):
                yield _check(candidate, context, RUNTIME_POOL)

        if e.op == UOp.The:
            for candidate in optimized_the(e.e, args):
                yield _check(candidate, context, RUNTIME_POOL)

    if isinstance(e, (EArgMin, EArgMax)):
        comparison = "<" if isinstance(e, EArgMin) else ">"
        for candidate in optimized_best(e.e, e.key_function, comparison, args=args):
            yield _check(candidate, context, RUNTIME_POOL)

    if isinstance(e, EFilter):
        for candidate in optimized_filter(e.e, e.predicate, args=args):
            yield _check(candidate, context, RUNTIME_POOL)

    if isinstance(e, EMap):
        for candidate in optimized_map(e.e, e.transform_function, args=args):
            yield _check(candidate, context, RUNTIME_POOL)

    from cozy.syntax import ESorted
    from cozy.structures.treemultiset import (
        EMakeMaxTreeMultiset,
        TMaxTreeMultiset,
        EMakeMinTreeMultiset,
        TMinTreeMultiset,
        ETreeMultisetElems,
    )
    if isinstance(e, ESorted) and isinstance(e.e, EStateVar):
        # Sorting a state collection: build a max- and a min-tree multiset
        # over it and pick between their element lists based on `e.asc`.
        bag = e.e.e
        e_max = EMakeMaxTreeMultiset(bag).with_type(
            TMaxTreeMultiset(bag.type.elem_type))
        e_min = EMakeMinTreeMultiset(bag).with_type(
            TMinTreeMultiset(bag.type.elem_type))
        min_elems = ETreeMultisetElems(
            EStateVar(e_min).with_type(e_min.type)).with_type(e.type)
        max_elems = ETreeMultisetElems(
            EStateVar(e_max).with_type(e_max.type)).with_type(e.type)
        yield _check(optimized_cond(e.asc, min_elems, max_elems), context, RUNTIME_POOL)
Ejemplo n.º 32
0
def is_lenof(e, xs):
    """Return whether `e` is alpha-equivalent to `ELen(xs)`.

    Both `e` and `xs` are passed through `strip_EStateVar` before the
    comparison is made.
    """
    candidate = strip_EStateVar(e)
    expected = ELen(strip_EStateVar(xs))
    return alpha_equivalent(candidate, expected)
Ejemplo n.º 33
0
def optimized_best(xs, keyfunc, op, args):
    """Build an expression for the best element of `xs` according to `keyfunc`.

    Parameters:
        xs - the collection expression to search
        keyfunc - the key function (an object with `body` and `apply_to`)
        op - "<" selects `EArgMin`; any other value selects `EArgMax`
        args - variables used to decide what may be placed inside EStateVar

    The shape of `xs` is matched against a series of rules.  The first rule
    that applies determines the result; if none applies, a plain
    `EArgMin`/`EArgMax` node over `xs` is returned.
    """
    if op == "<":
        argbest = EArgMin
    else:
        argbest = EArgMax
    elem_type = xs.type.t
    key_type = keyfunc.body.type

    # Rule: `xs` is some bag with one element excluded, and the bag does not
    # mention any of `args`.
    if excluded_element(xs, args) is not None:
        bag, x = excluded_element(xs, args)
        if not any(v in args for v in free_vars(bag)):
            if op == "<":
                heap_type, make_heap = TMinHeap, EMakeMinHeap
            else:
                heap_type, make_heap = TMaxHeap, EMakeMaxHeap
            state_bag = EStateVar(strip_EStateVar(bag)).with_type(bag.type)
            heap = make_heap(state_bag.e, keyfunc).with_type(
                heap_type(elem_type, key_type))
            best_of_bag = optimized_best(
                state_bag.e, keyfunc, op, args=args).with_type(elem_type)
            prev_min = EStateVar(best_of_bag).with_type(elem_type)
            heap_peek = EHeapPeek2(
                EStateVar(heap).with_type(heap.type),
                EStateVar(ELen(state_bag.e)).with_type(INT)).with_type(elem_type)
            # Only consult the heap when the excluded element is in the bag
            # and equals the bag's best element.
            excluded_is_best = EAll([
                optimized_in(x, state_bag),
                optimized_eq(x, prev_min)])
            return optimized_cond(excluded_is_best, heap_peek, prev_min)

    if isinstance(xs, EEmptyList):
        return construct_value(elem_type)

    if isinstance(xs, ESingleton):
        return xs.e

    if isinstance(xs, EBinOp) and xs.op == "+":
        if isinstance(xs.e1, EStateVar) or isinstance(xs.e2, EStateVar):
            # One operand is a state variable: take its best element and
            # compare it against the other operand.
            if isinstance(xs.e1, EStateVar):
                sv, other = xs.e1, xs.e2
            else:
                sv, other = xs.e2, xs.e1
            sv_best = optimized_best(sv, keyfunc, op, args=args)
            sv_nonempty = optimized_exists(sv)
            merged = EBinOp(
                ESingleton(sv_best).with_type(xs.type), "+", other
            ).with_type(xs.type)
            best_of_merged = argbest(merged, keyfunc).with_type(elem_type)
            best_of_other = optimized_best(other, keyfunc, op, args=args)
            return optimized_cond(sv_nonempty, best_of_merged, best_of_other)

        # Otherwise fold over the summands, tracking the best element so far
        # and whether any summand seen so far is nonempty.
        found = F
        best = construct_value(elem_type)
        for part in break_sum(xs):
            part_nonempty = optimized_exists(part)
            best_here = optimized_best(part, keyfunc, op, args=args)
            here_wins = EBinOp(
                keyfunc.apply_to(best_here), op,
                keyfunc.apply_to(best)).with_type(BOOL)
            winner = optimized_cond(here_wins, best_here, best)
            if_part_nonempty = optimized_cond(part_nonempty, winner, best)
            best = optimized_cond(found, if_part_nonempty, best_here)
            found = EAny([found, part_nonempty])
        return best

    if isinstance(xs, EMap):
        # Search the unmapped collection with the composed key, then apply
        # the map function to the result.
        inner_best = optimized_best(xs.e, compose(keyfunc, xs.f), op, args)
        return xs.f.apply_to(inner_best)

    if isinstance(xs, EStateVar) and not any(v in args for v in free_vars(keyfunc)):
        in_state = argbest(xs.e, keyfunc).with_type(elem_type)
        return EStateVar(in_state).with_type(elem_type)

    if isinstance(xs, ECond):
        then_best = optimized_best(xs.then_branch, keyfunc, op, args=args)
        else_best = optimized_best(xs.else_branch, keyfunc, op, args=args)
        return optimized_cond(xs.cond, then_best, else_best)

    if isinstance(xs, EUnaryOp) and xs.op == UOp.Distinct:
        return optimized_best(xs.e, keyfunc, op, args=args)

    # Disabled rule, kept for reference:
    # if isinstance(xs, EFilter):
    #     return optimized_cond(
    #         xs.p.apply_to(optimized_best(xs.e, keyfunc, op, args=args)),
    #         optimized_best(xs.e, keyfunc, op, args=args),
    #         argbest(xs, keyfunc).with_type(elem_type))

    # No rule matched.
    return argbest(xs, keyfunc).with_type(elem_type)
Ejemplo n.º 34
0
def _simple_filter(xs: Exp, p: ELambda, args: {EVar}):
    """Yield candidate expressions for `xs` filtered by the predicate `p`.

    Assumes the body of p is already in negation normal form.

    This is a recursive generator.  Trivial cases (constant predicate, empty
    list) yield a single result and stop.  Otherwise every matching rule below
    contributes its candidates; if no rule sets `yielded`, the plain
    `EFilter(xs, p)` is yielded as a fallback.

    Parameters:
        xs - the collection expression being filtered
        p - the filter predicate
        args - variables that prevent a predicate from being moved inside an
            EStateVar when they occur free in it
    """
    # Constant predicates and an empty input need no filter at all.
    if p.body == ETRUE:
        yield xs
        return
    if p.body == EFALSE:
        yield EEmptyList().with_type(xs.type)
        return
    if isinstance(xs, EEmptyList):
        yield xs
        return
    # Tracks whether any rule fired; decides whether the fallback at the end
    # is needed.
    yielded = False
    # Singleton: keep it or drop it depending on the predicate.
    if isinstance(xs, ESingleton):
        yielded = True
        yield optimized_cond(p.apply_to(xs.e), xs,
                             EEmptyList().with_type(xs.type))
    # Disjunction: filter by one disjunct, filter by "the other disjunct and
    # not the first", and concatenate.  Both orderings are tried.
    if isinstance(p.body, EBinOp) and p.body.op == BOp.Or:
        for e1, e2 in itertools.permutations([p.body.e1, p.body.e2]):
            for r1 in _simple_filter(xs, ELambda(p.arg, e1), args):
                for r2 in _simple_filter(xs,
                                         ELambda(p.arg, EAll([e2, ENot(e1)])),
                                         args):
                    yielded = True
                    yield EBinOp(r1, "+", r2).with_type(xs.type)
    # Conjunction: filter by one conjunct, then filter the result by the
    # other.  Both orderings are tried.
    if isinstance(p.body, EBinOp) and p.body.op == BOp.And:
        for e1, e2 in itertools.permutations([p.body.e1, p.body.e2]):
            for r1 in _simple_filter(xs, ELambda(p.arg, e1), args):
                yielded = True
                yield from _simple_filter(r1, ELambda(p.arg, e2), args)
    # State collection and a predicate that mentions none of `args`: move the
    # whole filter inside the EStateVar.
    if isinstance(xs, EStateVar) and not any(v in args for v in free_vars(p)):
        yielded = True
        yield EStateVar(EFilter(xs.e, strip_EStateVar(p)).with_type(
            xs.type)).with_type(xs.type)
    # Lookup in a state map with a predicate that mentions none of `args`:
    # filter the map's values (via map_values_multi) and look up the same key.
    if isinstance(xs, EMapGet) and isinstance(
            xs.map, EStateVar) and not any(v in args for v in free_vars(p)):
        for m in map_values_multi(xs.map.e,
                                  lambda ys: _simple_filter(ys, p, args)):
            yielded = True
            yield EMapGet(EStateVar(m).with_type(m.type),
                          xs.key).with_type(xs.type)
    # "+" or "-" of two collections: filter both operands and recombine them
    # with the same operator.
    if isinstance(xs, EBinOp) and xs.op in ("+", "-"):
        for e1 in _simple_filter(xs.e1, p, args):
            for e2 in _simple_filter(xs.e2, p, args):
                yielded = True
                yield EBinOp(e1, xs.op, e2).with_type(xs.type)
    # Equality predicate over a state collection: turn the filter into a map
    # lookup.
    if isinstance(p.body, EBinOp) and p.body.op == "==":
        e1 = p.body.e1
        e2 = p.body.e2
        fvs2 = free_vars(e2)
        fvs1 = free_vars(e1)
        # Try both orientations of the equality.  Note that the loop variables
        # rebind e1/fvs1/e2/fvs2.
        for (e1, fvs1), (e2, fvs2) in itertools.permutations([(e1, fvs1),
                                                              (e2, fvs2)]):
            # Requires: e1 mentions the predicate argument but none of `args`,
            # e2 does not mention the predicate argument, and xs is an
            # EStateVar.
            if p.arg in fvs1 and not any(
                    a in fvs1
                    for a in args) and p.arg not in fvs2 and isinstance(
                        xs, EStateVar):
                if e1 == p.arg:
                    # The predicate compares the element itself with e2.
                    # NOTE(review): this yield does not set `yielded`; the
                    # map-lookup yield below always follows and does set it.
                    yield optimized_cond(optimized_in(e2, xs),
                                         ESingleton(e2).with_type(xs.type),
                                         EEmptyList().with_type(xs.type))

                # Build a state map from each value of e1 to the elements of
                # xs.e having that value, then look up e2 in it.
                k = fresh_var(e1.type)
                e = EMapGet(
                    EStateVar(
                        EMakeMap2(
                            EMap(xs.e, ELambda(p.arg, e1)),
                            ELambda(k, EFilter(xs.e,
                                               ELambda(p.arg, EEq(e1, k)))))),
                    e2)
                # The expression above was built without types; retypecheck
                # is called on it and its result is asserted to be truthy.
                res = retypecheck(e)
                assert res
                yielded = True
                yield e
    # Fallback: no rule applied, so return the filter unchanged.
    if not yielded:
        yield EFilter(xs, p).with_type(xs.type)