Esempio n. 1
0
def optimized_bag_difference(xs, ys):
    """Build an expression for the bag difference ``xs - ys``.

    Common shapes are simplified.  Patterns are tried in order:

    1. ``EStateVar(distinct E) - (EStateVar(E) - [i])`` becomes
       ``count(i, E) == 1 ? [i] : []``.  The count is read from a histogram of
       E that is wrapped in a new EStateVar.
    2. ``xs - (xs - [i])`` becomes ``(i in xs) ? [i] : []``.
    3. ``[x] - ys`` becomes ``(x in ys) ? [] : [x]``.

    Otherwise a filter keeping the elements of ``xs`` that are not in ``ys`` is
    returned.  That is only equivalent to bag difference when ``xs`` has no
    duplicates.

    Parameters:
     - xs: the bag being subtracted from
     - ys: the bag being subtracted
    Returns an expression typed like ``xs``.
    """
    # EStateVar(distinct xs) - (EStateVar(xs) - [i])
    # ===> (count of i in xs) == 1 ? [i] : []
    # If i occurs exactly once, removing one copy leaves none, so i survives
    # the outer difference.  With zero or 2+ occurrences it does not.
    if (isinstance(ys, EBinOp) and ys.op == "-"
            and isinstance(ys.e1, EStateVar) and isinstance(ys.e2, ESingleton)
            and isinstance(xs, EStateVar) and isinstance(xs.e, EUnaryOp)
            and xs.e.op == UOp.Distinct and alpha_equivalent(xs.e.e, ys.e1.e)):
        elems = xs.e.e
        m = histogram(elems)
        m_rt = EStateVar(m).with_type(m.type)
        count = EMapGet(m_rt, ys.e2.e).with_type(INT)
        return optimized_cond(optimized_eq(count, ONE), ys.e2,
                              EEmptyList().with_type(xs.type))

    # xs - (xs - [i])
    # ===> (i in xs) ? [i] : []
    if isinstance(ys, EBinOp) and ys.op == "-" and isinstance(
            ys.e2, ESingleton) and alpha_equivalent(xs, ys.e1):
        return optimized_cond(optimized_in(ys.e2.e, xs), ys.e2,
                              EEmptyList().with_type(xs.type))

    # [x] - ys
    # ===> (x in ys) ? [] : [x]
    if isinstance(xs, ESingleton):
        return optimized_cond(optimized_in(xs.e, ys),
                              EEmptyList().with_type(xs.type), xs)

    # only legal if xs are distinct, but we'll give it a go...
    return EFilter(xs, mk_lambda(
        xs.type.t, lambda x: ENot(optimized_in(x, ys)))).with_type(xs.type)
Esempio n. 2
0
 def visit_EBinOp(self, e):
     """Simplify a binary-operator expression.

     Rewrites, tried in this order:
      - ``x in (a + b)``       -> ``any([x in a, x in b])``
      - ``x in distinct(xs)``  -> ``x in xs``
      - ``x in keys(m)``       -> ``has_key(m, x)``
      - ``a == a`` (after simplifying both operands) -> true
      - ``a == (c ? t : a)``   -> ``not c or a == t``
      - a conditional operand is pushed outward:
        ``(c ? t : f) op y``   -> ``c ? (t op y) : (f op y)``
        (and symmetrically for a conditional right operand).
     Otherwise both operands are visited and the operator is rebuilt.
     """
     if e.op == BOp.In:
         # Membership in a union: check each half.
         if isinstance(e.e2, EBinOp) and e.e2.op == "+":
             return self.visit(EAny([EIn(e.e1, e.e2.e1), EIn(e.e1, e.e2.e2)]))
         # distinct() does not change membership.
         elif isinstance(e.e2, EUnaryOp) and e.e2.op == UOp.Distinct:
             return self.visit(EIn(e.e1, e.e2.e))
         # Membership in a map's key set is a key lookup.
         elif isinstance(e.e2, EMapKeys):
             return self.visit(EHasKey(e.e2.e, e.e1).with_type(BOOL))
         # Any other `in` falls through to the conditional rewrites below.
     elif e.op in ("==", "==="):
         e1 = self.visit(e.e1)
         e2 = self.visit(e.e2)
         if alpha_equivalent(e1, e2):
             return ETRUE
         # e1 == (c ? t : e1) holds exactly when c is false or e1 == t.
         if isinstance(e2, ECond) and alpha_equivalent(e1, e2.else_branch):
             return self.visit(EBinOp(ENot(e2.cond), BOp.Or, EBinOp(e1, e.op, e2.then_branch).with_type(BOOL)).with_type(BOOL))
         # Continue with the simplified operands.  They are visited again by
         # the final return below.
         e = EBinOp(e1, e.op, e2).with_type(e.type)
     # Distribute the operator over a conditional left operand.
     if isinstance(e.e1, ECond):
         return self.visit(ECond(e.e1.cond,
             EBinOp(e.e1.then_branch, e.op, e.e2).with_type(e.type),
             EBinOp(e.e1.else_branch, e.op, e.e2).with_type(e.type)).with_type(e.type))
     # Distribute the operator over a conditional right operand.
     if isinstance(e.e2, ECond):
         return self.visit(ECond(e.e2.cond,
             EBinOp(e.e1, e.op, e.e2.then_branch).with_type(e.type),
             EBinOp(e.e1, e.op, e.e2.else_branch).with_type(e.type)).with_type(e.type))
     return EBinOp(self.visit(e.e1), e.op, self.visit(e.e2)).with_type(e.type)
Esempio n. 3
0
 def test_make_record_other(self):
     """A record is never alpha-equivalent to a boolean, in either order."""
     record = EMakeRecord((("x", ENum(0)), ("y", ETRUE)))
     assert not alpha_equivalent(record, ETRUE)
     assert not alpha_equivalent(ETRUE, record)
Esempio n. 4
0
 def visit_EBinOp(self, e):
     """Simplify a binary-operator expression.

     Rewrites, tried in this order:
      - ``x in (a + b)``       -> ``any([x in a, x in b])``
      - ``x in distinct(xs)``  -> ``x in xs``
      - ``x in keys(m)``       -> ``has_key(m, x)``
      - ``a == a`` (after simplifying both operands) -> T
      - ``a == (c ? t : a)``   -> ``not c or a == t``
      - a conditional operand is pushed outward:
        ``(c ? t : f) op y``   -> ``c ? (t op y) : (f op y)``
        (and symmetrically for a conditional right operand).
     Otherwise both operands are visited and the operator is rebuilt.
     """
     if e.op == BOp.In:
         # Membership in a union: check each half.
         if isinstance(e.e2, EBinOp) and e.e2.op == "+":
             return self.visit(EAny([EIn(e.e1, e.e2.e1), EIn(e.e1, e.e2.e2)]))
         # distinct() does not change membership.
         elif isinstance(e.e2, EUnaryOp) and e.e2.op == UOp.Distinct:
             return self.visit(EIn(e.e1, e.e2.e))
         # Membership in a map's key set is a key lookup.
         elif isinstance(e.e2, EMapKeys):
             return self.visit(EHasKey(e.e2.e, e.e1).with_type(BOOL))
         # Any other `in` falls through to the conditional rewrites below.
     elif e.op in ("==", "==="):
         e1 = self.visit(e.e1)
         e2 = self.visit(e.e2)
         if alpha_equivalent(e1, e2):
             return T
         # e1 == (c ? t : e1) holds exactly when c is false or e1 == t.
         if isinstance(e2, ECond) and alpha_equivalent(e1, e2.else_branch):
             return self.visit(EBinOp(ENot(e2.cond), BOp.Or, EBinOp(e1, e.op, e2.then_branch).with_type(BOOL)).with_type(BOOL))
         # Continue with the simplified operands.  They are visited again by
         # the final return below.
         e = EBinOp(e1, e.op, e2).with_type(e.type)
     # Distribute the operator over a conditional left operand.
     if isinstance(e.e1, ECond):
         return self.visit(ECond(e.e1.cond,
             EBinOp(e.e1.then_branch, e.op, e.e2).with_type(e.type),
             EBinOp(e.e1.else_branch, e.op, e.e2).with_type(e.type)).with_type(e.type))
     # Distribute the operator over a conditional right operand.
     if isinstance(e.e2, ECond):
         return self.visit(ECond(e.e2.cond,
             EBinOp(e.e1, e.op, e.e2.then_branch).with_type(e.type),
             EBinOp(e.e1, e.op, e.e2.else_branch).with_type(e.type)).with_type(e.type))
     return EBinOp(self.visit(e.e1), e.op, self.visit(e.e2)).with_type(e.type)
Esempio n. 5
0
def optimized_bag_difference(xs, ys):
    """Build an expression for the bag difference ``xs - ys``.

    Common shapes are simplified.  Patterns are tried in order:

    1. ``EStateVar(distinct E) - (EStateVar(E) - [i])`` becomes
       ``count(i, E) == 1 ? [i] : []``.  The count is read from a histogram of
       E that is wrapped in a new EStateVar.
    2. ``xs - (xs - [i])`` becomes ``(i in xs) ? [i] : []``.
    3. ``[x] - ys`` becomes ``(x in ys) ? [] : [x]``.

    Otherwise a filter keeping the elements of ``xs`` that are not in ``ys`` is
    returned.  That is only equivalent to bag difference when ``xs`` has no
    duplicates.

    Parameters:
     - xs: the bag being subtracted from
     - ys: the bag being subtracted
    Returns an expression typed like ``xs``.
    """
    # EStateVar(distinct xs) - (EStateVar(xs) - [i])
    # ===> (count of i in xs) == 1 ? [i] : []
    # If i occurs exactly once, removing one copy leaves none, so i survives
    # the outer difference.  With zero or 2+ occurrences it does not.
    if (isinstance(ys, EBinOp) and ys.op == "-" and
            isinstance(ys.e1, EStateVar) and
            isinstance(ys.e2, ESingleton) and
            isinstance(xs, EStateVar) and isinstance(xs.e, EUnaryOp) and xs.e.op == UOp.Distinct and
            alpha_equivalent(xs.e.e, ys.e1.e)):
        elems = xs.e.e
        m = histogram(elems)
        m_rt = EStateVar(m).with_type(m.type)
        count = EMapGet(m_rt, ys.e2.e).with_type(INT)
        return optimized_cond(
            optimized_eq(count, ONE),
            ys.e2,
            EEmptyList().with_type(xs.type))

    # xs - (xs - [i])
    # ===> (i in xs) ? [i] : []
    if isinstance(ys, EBinOp) and ys.op == "-" and isinstance(ys.e2, ESingleton) and alpha_equivalent(xs, ys.e1):
        return optimized_cond(optimized_in(ys.e2.e, xs),
            ys.e2,
            EEmptyList().with_type(xs.type))

    # [x] - ys
    # ===> (x in ys) ? [] : [x]
    if isinstance(xs, ESingleton):
        return optimized_cond(
            optimized_in(xs.e, ys),
            EEmptyList().with_type(xs.type),
            xs)

    # only legal if xs are distinct, but we'll give it a go...
    return EFilter(xs, mk_lambda(xs.type.elem_type, lambda x: ENot(optimized_in(x, ys)))).with_type(xs.type)
Esempio n. 6
0
 def visit_EBinOp(self, e):
     """Simplify a binary-operator expression.

     Rewrites, tried in this order:
      - ``x in (a + b)``       -> ``any([x in a, x in b])``
      - ``x in distinct(xs)``  -> ``x in xs``
      - ``a == a`` (after simplifying both operands) -> T
      - for ``==`` only, EWithAlteredValue wrappers are peeled off both
        operands, so their ``.handle`` values are compared
      - a conditional operand is pushed outward:
        ``(c ? t : f) op y``   -> ``c ? (t op y) : (f op y)``
        (and symmetrically for a conditional right operand).
     Otherwise both operands are visited and the operator is rebuilt.
     """
     if e.op == BOp.In:
         # Membership in a union: check each half.
         if isinstance(e.e2, EBinOp) and e.e2.op == "+":
             return self.visit(
                 EAny([EIn(e.e1, e.e2.e1),
                       EIn(e.e1, e.e2.e2)]))
         # distinct() does not change membership.
         elif isinstance(e.e2, EUnaryOp) and e.e2.op == UOp.Distinct:
             return self.visit(EIn(e.e1, e.e2.e))
         # Any other `in` falls through to the conditional rewrites below.
     elif e.op in ("==", "==="):
         e1 = self.visit(e.e1)
         e2 = self.visit(e.e2)
         if alpha_equivalent(e1, e2):
             return T
         # For `==` (but not `===`) only the underlying handle matters, so the
         # altered-value wrappers are dropped.
         # NOTE(review): this presumes `==` on handles ignores the altered
         # value; confirm against the evaluator.
         if e.op == "==":
             while isinstance(e1, EWithAlteredValue):
                 e1 = e1.handle
             while isinstance(e2, EWithAlteredValue):
                 e2 = e2.handle
         # Continue with the simplified operands.  They are visited again by
         # the final return below.
         e = EBinOp(e1, e.op, e2).with_type(e.type)
     # Distribute the operator over a conditional left operand.
     if isinstance(e.e1, ECond):
         return self.visit(
             ECond(e.e1.cond,
                   EBinOp(e.e1.then_branch, e.op, e.e2).with_type(e.type),
                   EBinOp(e.e1.else_branch, e.op,
                          e.e2).with_type(e.type)).with_type(e.type))
     # Distribute the operator over a conditional right operand.
     if isinstance(e.e2, ECond):
         return self.visit(
             ECond(e.e2.cond,
                   EBinOp(e.e1, e.op, e.e2.then_branch).with_type(e.type),
                   EBinOp(e.e1, e.op, e.e2.else_branch).with_type(
                       e.type)).with_type(e.type))
     return EBinOp(self.visit(e.e1), e.op,
                   self.visit(e.e2)).with_type(e.type)
Esempio n. 7
0
    def test_hint_instantation(self):
        """The hint f(x), registered under a binder, must be enumerated in each
        of the contexts listed below."""
        x = EVar("x").with_type(INT)
        y = EVar("y").with_type(INT)
        z = EVar("z").with_type(INT)
        hint = ECall("f", (x,)).with_type(INT)
        context = UnderBinder(RootCtx(args=[x]), v=y, bag=ESingleton(x).with_type(TBag(x.type)), bag_pool=RUNTIME_POOL)
        cost_model = CostModel()

        def plus_one(a):
            return a + 1

        examples = [{"x": 1, "f": plus_one}, {"x": 100, "f": plus_one}]
        enumerator = Enumerator(examples=examples, hints=[(hint, context, RUNTIME_POOL)], cost_model=cost_model)

        # Every context is built before any enumeration starts.
        contexts = [
            context,
            context.parent(),
            UnderBinder(context, v=z, bag=ESingleton(y).with_type(TBag(y.type)), bag_pool=RUNTIME_POOL),
            UnderBinder(context.parent(), v=z, bag=ESingleton(x).with_type(TBag(y.type)), bag_pool=RUNTIME_POOL),
            UnderBinder(context.parent(), v=y, bag=ESingleton(ONE).with_type(INT_BAG), bag_pool=RUNTIME_POOL),
        ]

        results = []
        for candidate_ctx in contexts:
            print("-" * 30)
            seen_hint = False
            for candidate in enumerator.enumerate(candidate_ctx, 0, RUNTIME_POOL):
                print(" -> {}".format(pprint(candidate)))
                if not seen_hint:
                    seen_hint = alpha_equivalent(candidate, hint)
            print("found? {}".format(seen_hint))
            results.append(seen_hint)

        assert all(results)
Esempio n. 8
0
 def visit(self, e, *args):
     """Replace the needle with the replacement where it occurs in an
     equivalent pool and context; otherwise defer to the base visitor."""
     is_needle_here = (
         isinstance(e, Exp)
         and _sametype(e, self.needle)
         and self.pool == self.needle_pool
         and alpha_equivalent(self.needle, e)
         and self.needle_context.alpha_equivalent(self.ctx))
     if not is_needle_here:
         return super().visit(e, *args)
     return self.ctx.adapt(self.replacement, self.needle_context)
Esempio n. 9
0
 def alpha_equivalent(self, other):
     """Decide whether `other` is an UnderBinder equivalent to this one.

     The bound variables must share a type, the parents must be equivalent,
     and other's bag, adapted into this parent, must be alpha-equivalent to
     this bag.
     """
     if not isinstance(other, UnderBinder):
         return False
     if not (self.var.type == other.var.type and self._parent.alpha_equivalent(other._parent)):
         return False
     other_bag_here = self._parent.adapt(other.bag, other._parent)
     return alpha_equivalent(self.bag, other_bag_here)
Esempio n. 10
0
 def test_optimized_in1(self):
     # `i in (EStateVar(xs) - [j])`: optimized_in must give a syntactically
     # different but semantically equal expression.
     xs = EVar("xs").with_type(INT_BAG)
     i = EVar("i").with_type(INT)
     j = EVar("j").with_type(INT)
     membership = EIn(i, EBinOp(EStateVar(xs), "-", ESingleton(j)))
     assert retypecheck(membership)
     optimized = optimized_in(i, membership.e2)
     assert not alpha_equivalent(membership, optimized)
     self.assert_same(membership, optimized)
Esempio n. 11
0
 def test_optimized_in2(self):
     # `i in (xs - ys)` for two plain bag variables.
     xs, ys = EVar("xs").with_type(INT_BAG), EVar("ys").with_type(INT_BAG)
     i = EVar("i").with_type(INT)
     membership = EIn(i, EBinOp(xs, "-", ys))
     assert retypecheck(membership)
     optimized = optimized_in(i, membership.e2)
     assert not alpha_equivalent(membership, optimized)
     self.assert_same(membership, optimized)
Esempio n. 12
0
 def test_optimized_in1(self):
     """optimized_in rewrites `i in (EStateVar(xs) - [j])` into an equivalent
     but syntactically different expression."""
     bag = EVar("xs").with_type(INT_BAG)
     needle = EVar("i").with_type(INT)
     removed = EVar("j").with_type(INT)
     original = EIn(needle, EBinOp(EStateVar(bag), "-", ESingleton(removed)))
     assert retypecheck(original)
     rewritten = optimized_in(needle, original.e2)
     assert not alpha_equivalent(original, rewritten)
     self.assert_same(original, rewritten)
Esempio n. 13
0
 def test_optimized_in2(self):
     """optimized_in rewrites `i in (xs - ys)` into an equivalent but
     syntactically different expression."""
     left = EVar("xs").with_type(INT_BAG)
     right = EVar("ys").with_type(INT_BAG)
     needle = EVar("i").with_type(INT)
     original = EIn(needle, EBinOp(left, "-", right))
     assert retypecheck(original)
     rewritten = optimized_in(needle, original.e2)
     assert not alpha_equivalent(original, rewritten)
     self.assert_same(original, rewritten)
Esempio n. 14
0
 def test_lambdas(self):
     """Two independently built but identical lambdas are alpha-equivalent."""
     employers = EVar("employers").with_type(TBag(TInt()))

     def build():
         return mk_lambda(
             employers.type.t,
             lambda employer: EGetField(EGetField(employer, "val"), "employer_name"))

     assert alpha_equivalent(build(), build())
Esempio n. 15
0
 def alpha_equivalent(self, other):
     """Decide whether `other` is an UnderBinder equivalent to this one.

     Requirements: same pool, same bound-variable type, equivalent parents,
     and other's bag (adapted into this parent) alpha-equivalent to this bag.
     """
     if not isinstance(other, UnderBinder) or self.pool != other.pool:
         return False
     if not self.var.type == other.var.type or not self._parent.alpha_equivalent(other._parent):
         return False
     return alpha_equivalent(self.bag, self._parent.adapt(other.bag, other._parent))
Esempio n. 16
0
 def visit_ECond(self, e):
     """Simplify a conditional expression.

     - A condition that simplifies to ETRUE/EFALSE selects that branch.
     - If the two simplified branches are alpha-equivalent, the condition is
       irrelevant and the simplified then-branch is returned.
     - Otherwise, occurrences of the condition inside each branch are replaced
       by the value it must have there (ETRUE in the then-branch, EFALSE in
       the else-branch) before the branches are simplified.
     """
     cond = self.visit(e.cond)
     if cond == ETRUE:
         return self.visit(e.then_branch)
     if cond == EFALSE:
         return self.visit(e.else_branch)
     # Simplify each branch once and reuse the result.  Previously the
     # then-branch was visited a second time just to return it, which
     # compounds at every nested conditional.
     then_simplified = self.visit(e.then_branch)
     else_simplified = self.visit(e.else_branch)
     if alpha_equivalent(then_simplified, else_simplified):
         return then_simplified
     tb = replace(e.then_branch, cond, ETRUE)
     eb = replace(e.else_branch, cond, EFALSE)
     return ECond(cond, self.visit(tb), self.visit(eb)).with_type(e.type)
Esempio n. 17
0
def try_optimize(e: Exp, context: Context, pool: Pool):
    """Generate candidate rewrites of `e` for the given context and pool.

    Candidates come from `_try_optimize` and tend to be equivalent to, and
    cheaper than, `e`, but neither property is guaranteed.

    Candidates that are alpha-equivalent to `e` itself are dropped.
    """
    candidates = _try_optimize(e, context, pool)
    yield from (c for c in candidates if not alpha_equivalent(e, c))
Esempio n. 18
0
 def visit_ECond(self, e):
     """Simplify a conditional expression.

     - A condition that simplifies to T/F selects that branch.
     - If the two simplified branches are alpha-equivalent, the condition is
       irrelevant and the simplified then-branch is returned.
     - Otherwise, occurrences of the condition inside each branch are replaced
       by the value it must have there (T in the then-branch, F in the
       else-branch) before the branches are simplified.
     """
     cond = self.visit(e.cond)
     if cond == T:
         return self.visit(e.then_branch)
     if cond == F:
         return self.visit(e.else_branch)
     # Simplify each branch once and reuse the result, instead of visiting
     # the then-branch a second time just to return it.
     then_simplified = self.visit(e.then_branch)
     else_simplified = self.visit(e.else_branch)
     if alpha_equivalent(then_simplified, else_simplified):
         return then_simplified
     tb = replace(e.then_branch, cond, T)
     eb = replace(e.else_branch, cond, F)
     return ECond(cond, self.visit(tb), self.visit(eb)).with_type(e.type)
Esempio n. 19
0
def try_optimize(e : Exp, context : Context, pool : Pool):
    """Generate candidate rewrites of `e` for the given context and pool.

    Each candidate is produced by `_try_optimize`.  It is probably equivalent
    to `e` and probably cheaper, but that is not promised.  A candidate that
    is syntactically (alpha-)equivalent to `e` is never yielded.
    """
    for candidate in _try_optimize(e, context, pool):
        if alpha_equivalent(e, candidate):
            continue
        yield candidate
Esempio n. 20
0
 def adapt(self, e : Exp, ctx, e_fvs=None) -> Exp:
     """Rewrite `e`, which was built in context `ctx`, for use in this context.

     Parameters:
      - e: the expression to translate
      - ctx: the context `e` comes from
      - e_fvs: precomputed free variables of `e`; computed when None and
        passed down through the recursion so it is only computed once
     Returns `e` itself when `ctx` equals this context.  Otherwise the result
     of walking up the two context chains, substituting this binder's
     variable for ctx's variable where their bags match.
     """
     if self == ctx:
         return e
     if e_fvs is None:
         e_fvs = free_vars(e)
     if isinstance(ctx, UnderBinder):
         # ctx's bound variable is unused by e: skip that binder entirely.
         if ctx.var not in e_fvs:
             return self.adapt(e, ctx.parent(), e_fvs=e_fvs)
         # Both binders range over the same bag (once ctx's bag is expressed
         # in our parent context): translate e between the parents, then
         # rename ctx's variable to ours.
         if alpha_equivalent(self.bag, self._parent.adapt(ctx.bag, ctx._parent)):
             e = self._parent.adapt(e, ctx._parent, e_fvs=e_fvs)
             return subst(e, { ctx.var.id : self.var })
     # Otherwise this binder plays no role; let the parent context handle it.
     return self._parent.adapt(e, ctx, e_fvs=e_fvs)
Esempio n. 21
0
    def order_cardinalities(self, other, assumptions : Exp = T, solver : IncrementalSolver = None) -> Exp:
        """Build a formula relating the cardinality variables of `self` and `other`.

        The `cardinalities` mappings of both objects are merged and inverted,
        so that each cardinality variable maps to the expression it measures.
        The result is the conjunction of:
         - `v >= 0` for every cardinality variable `v`;
         - `v1 == v2` when the two measured expressions are alpha-equivalent;
         - `v1 <= v2` whenever `cardinality_le` holds for the two measured
           expressions (under `assumptions`);
         - `v > N` for variables whose measured expression is an EVar, when
           `assume_large_cardinalities.value` is a nonzero N.

        Parameters:
         - other: object whose `cardinalities` are ordered together with ours
         - assumptions: facts assumed while comparing cardinalities
         - solver: solver to use.  A fresh IncrementalSolver is created if
           omitted.  In incremental mode its state is pushed on entry and is
           always popped again, even if a solver call raises.

        `incremental` and `use_indicators` are switches defined outside this
        method (not visible here).
        """
        if solver is None:
            solver = IncrementalSolver()

        # Merge both mappings, inverting them:
        # cardinality variable -> the expression whose cardinality it is.
        cardinalities = OrderedDict()
        for m in (self.cardinalities, other.cardinalities):
            for k, v in m.items():
                cardinalities[v] = k

        conds = []
        res = []
        if incremental:
            solver.push()
        try:
            if incremental:
                solver.add_assumption(assumptions)

            for (v1, c1) in cardinalities.items():
                res.append(EBinOp(v1, ">=", ZERO).with_type(BOOL))
                for (v2, c2) in cardinalities.items():
                    if v1 == v2:
                        continue
                    if alpha_equivalent(c1, c2):
                        res.append(EEq(v1, v2))
                        continue

                    if incremental and use_indicators:
                        # Defer the check: tie a fresh indicator variable to
                        # the comparison formula and test all of them in one
                        # batch below.
                        conds.append((v1, v2, fresh_var(BOOL), cardinality_le(c1, c2, as_f=True)))
                    else:
                        if incremental:
                            # `assumptions` is already asserted on the solver.
                            le = cardinality_le(c1, c2, solver=solver)
                        else:
                            le = cardinality_le(c1, c2, assumptions=assumptions, solver=solver)
                        if le:
                            res.append(EBinOp(v1, "<=", v2).with_type(BOOL))

            if incremental and use_indicators:
                solver.add_assumption(EAll(
                    [EEq(indicator, f) for (v1, v2, indicator, f) in conds]))
                for (v1, v2, indicator, f) in conds:
                    if solver.valid(indicator):
                        res.append(EBinOp(v1, "<=", v2).with_type(BOOL))
        finally:
            # Always undo the push, so a caller-supplied solver is not left
            # with this call's assumptions when a solver call raises.
            if incremental:
                solver.pop()

        if assume_large_cardinalities.value:
            min_cardinality = ENum(assume_large_cardinalities.value).with_type(INT)
            for cvar, exp in cardinalities.items():
                if isinstance(exp, EVar):
                    res.append(EBinOp(cvar, ">", min_cardinality).with_type(BOOL))

        return EAll(res)
Esempio n. 22
0
def mutate(e: syntax.Exp, op: syntax.Stm) -> syntax.Exp:
    """Compute the value `e` will have once statement `op` has run.

    The statement is folded into the expression:
     - assignments go through `_do_assignment`;
     - `add`/`remove` become `add_all`/`remove_all` of a singleton, which in
       turn become assignments `x = x + ys` / `x = x - ys`;
     - `if` becomes an ECond, or the shared result if both branches agree;
     - sequences are processed last statement first, inlining declarations.

    Raises Exception for an unknown call and NotImplementedError for an
    unsupported statement type.
    """
    if isinstance(op, syntax.SNoOp):
        return e
    if isinstance(op, syntax.SAssign):
        return _do_assignment(op.lhs, op.rhs, e)
    if isinstance(op, syntax.SCall):
        if op.func in ("add", "remove"):
            bulk_name = "add_all" if op.func == "add" else "remove_all"
            singleton = syntax.ESingleton(op.args[0]).with_type(op.target.type)
            return mutate(e, syntax.SCall(op.target, bulk_name, (singleton, )))
        if op.func in ("add_all", "remove_all"):
            bag_op = "+" if op.func == "add_all" else "-"
            new_value = syntax.EBinOp(op.target, bag_op, op.args[0]).with_type(op.target.type)
            return mutate(e, syntax.SAssign(op.target, new_value))
        raise Exception("Unknown func: {}".format(op.func))
    if isinstance(op, syntax.SIf):
        when_true = mutate(e, op.then_branch)
        when_false = mutate(e, op.else_branch)
        if alpha_equivalent(when_true, when_false):
            return when_true
        return syntax.ECond(op.cond, when_true, when_false).with_type(e.type)
    if isinstance(op, syntax.SSeq):
        first, second = op.s1, op.s2
        if isinstance(first, syntax.SSeq):
            # Re-associate ((a; b); c) as (a; (b; c)).
            return mutate(e, syntax.SSeq(first.s1, syntax.SSeq(first.s2, second)))
        result = mutate(mutate(e, second), first)
        if isinstance(first, syntax.SDecl):
            result = subst(result, {first.id: first.val})
        return result
    if isinstance(op, syntax.SDecl):
        return e
    raise NotImplementedError(type(op))
Esempio n. 23
0
def check_discovery(spec, expected, state_vars=None, args=None, examples=None, assumptions=ETRUE):
    """Run `improve` on `spec` and report whether it produces `expected`.

    Parameters:
     - spec: the expression to improve
     - expected: either an Exp (some result must be alpha-equivalent to it)
       or a predicate that is called on each result
     - state_vars, args: variables of the root context (default: none)
     - examples: examples passed to `improve` (default: none)
     - assumptions: assumptions passed to `improve`
    Returns True as soon as a matching result is produced.  Returns False
    once `improve` stops producing results.
    """
    # Build fresh lists per call.  `[]` defaults would be one shared object
    # reused across every call.
    state_vars = [] if state_vars is None else state_vars
    args = [] if args is None else args
    examples = [] if examples is None else examples
    ctx = RootCtx(state_vars=state_vars, args=args)
    for r in improve(spec,
            assumptions=assumptions,
            context=ctx,
            examples=examples):
        print("GOT RESULT ==> {}".format(pprint(r)))
        if isinstance(expected, Exp):
            if alpha_equivalent(r, expected):
                return True
        elif expected(r):
            return True
    return False
Esempio n. 24
0
 def adapt(self, e: Exp, ctx, e_fvs=None) -> Exp:
     """Translate `e` from context `ctx` into this context.

     The search walks up ctx's binders.  A binder whose variable `e` does not
     use is skipped.  A binder over a matching bag has its variable renamed to
     ours.  Anything else is delegated to our parent context.  `e_fvs` caches
     the free variables of `e` across the recursion.
     """
     if self == ctx:
         return e
     fvs = free_vars(e) if e_fvs is None else e_fvs
     if not isinstance(ctx, UnderBinder):
         return self._parent.adapt(e, ctx, e_fvs=fvs)
     if ctx.var not in fvs:
         return self.adapt(e, ctx.parent(), e_fvs=fvs)
     their_bag_here = self._parent.adapt(ctx.bag, ctx._parent)
     if not alpha_equivalent(self.bag, their_bag_here):
         return self._parent.adapt(e, ctx, e_fvs=fvs)
     lifted = self._parent.adapt(e, ctx._parent, e_fvs=fvs)
     return subst(lifted, {ctx.var.id: self.var})
Esempio n. 25
0
def check_ops_preserve_invariants(spec : Spec):
    """List the update operations that may break a spec invariant.

    For each Op method and each spec assumption, the assumption's value after
    the method body runs is computed with `mutate`.  If that value differs
    syntactically from the assumption, the solver checks whether the method's
    and spec's assumptions imply it.  Returns one message per possible
    violation, or [] when the check is disabled.
    """
    if not invariant_preservation_check.value:
        return []
    problems = []
    operations = [m for m in spec.methods if isinstance(m, Op)]
    for op in operations:
        for invariant in spec.assumptions:
            print("Checking that {} preserves {}...".format(op.name, pprint(invariant)))
            after = mutate(invariant, op.body)
            if alpha_equivalent(invariant, after):
                continue
            hypotheses = list(op.assumptions) + list(spec.assumptions)
            if not valid(EImplies(EAll(hypotheses), after)):
                problems.append("{.name!r} may not preserve invariant {}".format(op, pprint(invariant)))
    return problems
Esempio n. 26
0
def heap_func(e : Exp, concretization_functions : { str : Exp } = None) -> ELambda:
    """Return the key function of the heap expression `e`.

    - Heap constructors (EMakeMinHeap / EMakeMaxHeap) carry it as `.f`.
    - A variable is resolved through `concretization_functions[e.id]`, when
      that mapping is given and contains the variable.
    - For a conditional, the key functions of both branches are combined.
      If they match, that function is returned; otherwise a lambda branching
      on the same condition is built.
    Raises NotImplementedError for any other expression.
    """
    if isinstance(e, (EMakeMinHeap, EMakeMaxHeap)):
        return e.f
    if isinstance(e, EVar) and concretization_functions:
        ee = concretization_functions.get(e.id)
        if ee is not None:
            return heap_func(ee, concretization_functions)
    if isinstance(e, ECond):
        # Pass the mapping down.  Without it, a branch that is a heap variable
        # cannot be resolved and falls through to NotImplementedError.
        h1 = heap_func(e.then_branch, concretization_functions)
        h2 = heap_func(e.else_branch, concretization_functions)
        if alpha_equivalent(h1, h2):
            return h1
        # Build `lambda v: cond ? h1(v) : h2(v)`.
        v = fresh_var(h1.arg.type)
        return ELambda(v, ECond(e.cond, h1.apply_to(v), h2.apply_to(v)).with_type(h1.body.type))
    raise NotImplementedError(repr(e))
Esempio n. 27
0
 def test_make_map(self):
     """Two independently built but identical EMakeMap expressions are
     alpha-equivalent."""
     employers = EVar("employers").with_type(TBag(TInt()))

     def build():
         key_func = mk_lambda(
             employers.type.t,
             lambda employer: EGetField(EGetField(employer, "val"), "employer_name"))
         value_func = mk_lambda(employers.type, lambda es: es)
         return EMakeMap(employers, key_func, value_func)

     assert alpha_equivalent(build(), build())
Esempio n. 28
0
    def test_hint_instantation(self):
        """Check that hint f(x) is enumerated in each of the related contexts
        listed below."""
        x = EVar("x").with_type(INT)
        y = EVar("y").with_type(INT)
        z = EVar("z").with_type(INT)
        hint = ECall("f", (x, )).with_type(INT)

        def under(parent, var, bag):
            return UnderBinder(parent, v=var, bag=bag, bag_pool=RUNTIME_POOL)

        context = under(RootCtx(args=[x]), y, ESingleton(x).with_type(TBag(x.type)))
        cost_model = CostModel()

        f = lambda a: a + 1
        enumerator = Enumerator(
            examples=[{"x": 1, "f": f}, {"x": 100, "f": f}],
            hints=[(hint, context, RUNTIME_POOL)],
            cost_model=cost_model)

        contexts = (
            context,
            context.parent(),
            under(context, z, ESingleton(y).with_type(TBag(y.type))),
            under(context.parent(), z, ESingleton(x).with_type(TBag(y.type))),
            under(context.parent(), y, ESingleton(ONE).with_type(INT_BAG)),
        )

        def hint_found_in(ctx):
            print("-" * 30)
            hit = False
            for candidate in enumerator.enumerate(ctx, 0, RUNTIME_POOL):
                print(" -> {}".format(pprint(candidate)))
                hit = hit or alpha_equivalent(candidate, hint)
            print("found? {}".format(hit))
            return hit

        results = [hint_found_in(ctx) for ctx in contexts]
        assert all(results)
Esempio n. 29
0
def check_discovery(spec,
                    expected,
                    state_vars=None,
                    args=None,
                    examples=None,
                    assumptions=ETRUE):
    """Run `improve` on `spec` and report whether it produces `expected`.

    Parameters:
     - spec: the expression to improve
     - expected: either an Exp (some result must be alpha-equivalent to it)
       or a predicate that is called on each result
     - state_vars, args: variables of the root context (default: none)
     - examples: examples passed to `improve` (default: none)
     - assumptions: assumptions passed to `improve`
    Returns True as soon as a matching result is produced.  Returns False
    once `improve` stops producing results.
    """
    # Fresh lists per call, instead of `[]` defaults shared by every call.
    if state_vars is None:
        state_vars = []
    if args is None:
        args = []
    if examples is None:
        examples = []
    ctx = RootCtx(state_vars=state_vars, args=args)
    for r in improve(spec,
                     assumptions=assumptions,
                     context=ctx,
                     examples=examples):
        print("GOT RESULT ==> {}".format(pprint(r)))
        if isinstance(expected, Exp):
            if alpha_equivalent(r, expected):
                return True
        elif expected(r):
            return True
    return False
Esempio n. 30
0
File: heaps.py Progetto: uwplse/cozy
def heap_func(e : Exp, concretization_functions : { str : Exp } = None) -> ELambda:
    """
    Assuming 'e' produces a heap, this returns the function used to sort its elements.

    - Heap constructors (EMakeMinHeap / EMakeMaxHeap) carry it as
      `.key_function`.
    - A variable is resolved through `concretization_functions[e.id]`, when
      that mapping is given and contains the variable.
    - For a conditional, the key functions of both branches are combined.
      If they match, that function is returned; otherwise a lambda branching
      on the same condition is built.
    Raises NotImplementedError for any other expression.
    """
    if isinstance(e, (EMakeMinHeap, EMakeMaxHeap)):
        return e.key_function
    if isinstance(e, EVar) and concretization_functions:
        ee = concretization_functions.get(e.id)
        if ee is not None:
            return heap_func(ee, concretization_functions)
    if isinstance(e, ECond):
        # Pass the mapping down.  Without it, a branch that is a heap variable
        # cannot be resolved and falls through to NotImplementedError.
        h1 = heap_func(e.then_branch, concretization_functions)
        h2 = heap_func(e.else_branch, concretization_functions)
        if alpha_equivalent(h1, h2):
            return h1
        # Build `lambda v: cond ? h1(v) : h2(v)`.
        v = fresh_var(h1.arg.type)
        return ELambda(v, ECond(e.cond, h1.apply_to(v), h2.apply_to(v)).with_type(h1.body.type))
    raise NotImplementedError(repr(e))
Esempio n. 31
0
 def consider_new_target(old_target, e, ctx, pool, replacement):
     nonlocal n
     n += 1
     k = (e, ctx, pool, replacement)
     if enable_blacklist.value and k in self.blacklist:
         event("blacklisted")
         print("skipping blacklisted substitution: {} ---> {} ({})".
               format(pprint(e), pprint(replacement),
                      self.blacklist[k]))
         return
     new_target = freshen_binders(
         replace(target, root_ctx, RUNTIME_POOL, e, ctx, pool,
                 replacement), root_ctx)
     if any(alpha_equivalent(t, new_target) for t in self.targets):
         event("already seen")
         return
     wf = check_wf(new_target, root_ctx, RUNTIME_POOL)
     if not wf:
         msg = "not well-formed [wf={}]".format(wf)
         event(msg)
         self.blacklist[k] = msg
         return
     if not fingerprints_match(fingerprint(new_target, self.examples),
                               target_fp):
         msg = "not correct"
         event(msg)
         self.blacklist[k] = msg
         return
     if self.cost_model.compare(new_target, target, root_ctx,
                                RUNTIME_POOL) not in (Order.LT,
                                                      Order.AMBIGUOUS):
         msg = "not an improvement"
         event(msg)
         self.blacklist[k] = msg
         return
     print("FOUND A GUESS AFTER {} CONSIDERED".format(n))
     print(" * in {}".format(pprint(old_target), pprint(e),
                             pprint(replacement)))
     print(" * replacing {}".format(pprint(e)))
     print(" * with {}".format(pprint(replacement)))
     yield new_target
Esempio n. 32
0
def simplify_sum(e):
    """Simplify an integer sum by cancelling matching added and subtracted terms.

    `break_sum(e)` yields (flag, term) pairs.  Terms with a true flag are kept
    as added terms; the others end up negated.  Presumably the flag means
    "positive"; confirm against break_sum.  Each added term cancels one
    subtracted term that is alpha-equivalent to it once EStateVar wrappers
    are stripped.  The surviving terms are rebuilt as a left-nested sum:
    added terms first, then negated leftovers.  Returns ZERO if nothing
    survives.
    """
    pairs = list(break_sum(e))
    added, subtracted = partition(pairs, lambda p: p[0])
    added = [x[1] for x in added]
    subtracted = [x[1] for x in subtracted]
    terms = []
    for x in added:
        opp = find_one(
            subtracted,
            lambda y: alpha_equivalent(strip_EStateVar(x), strip_EStateVar(y)))
        # Compare against None explicitly: a found term must not be treated
        # as "not found" just because it is falsy.
        if opp is not None:
            subtracted.remove(opp)
        else:
            terms.append(x)
    terms.extend(EUnaryOp("-", x).with_type(INT) for x in subtracted)

    if not terms:
        return ZERO
    res = terms[0]
    for term in terms[1:]:
        res = EBinOp(res, "+", term).with_type(INT)
    return res
Esempio n. 33
0
def mutate(e : syntax.Exp, op : syntax.Stm) -> syntax.Exp:
    """Return the new value of `e` after executing `op`.

    Evaluating the returned expression will give the same output as running
    `op` followed by evaluating the input expression `e`.

    Supported calls are add, add_all, remove and remove_all.  Any other call
    raises Exception("Unknown func: ...").  An unsupported statement type
    raises NotImplementedError.
    """
    if isinstance(op, syntax.SNoOp):
        return e
    elif isinstance(op, syntax.SAssign):
        return _do_assignment(op.lhs, op.rhs, e)
    elif isinstance(op, syntax.SCall):
        # add(v) / remove(v) are rewritten as add_all([v]) / remove_all([v]).
        # Those are rewritten as target = target + args / target = target - args.
        if op.func == "add":
            return mutate(e, syntax.SCall(op.target, "add_all", (syntax.ESingleton(op.args[0]).with_type(op.target.type),)))
        elif op.func == "add_all":
            return mutate(e, syntax.SAssign(op.target, syntax.EBinOp(op.target, "+", op.args[0]).with_type(op.target.type)))
        elif op.func == "remove":
            return mutate(e, syntax.SCall(op.target, "remove_all", (syntax.ESingleton(op.args[0]).with_type(op.target.type),)))
        elif op.func == "remove_all":
            return mutate(e, syntax.SAssign(op.target, syntax.EBinOp(op.target, "-", op.args[0]).with_type(op.target.type)))
        else:
            raise Exception("Unknown func: {}".format(op.func))
    elif isinstance(op, syntax.SIf):
        then_branch = mutate(e, op.then_branch)
        else_branch = mutate(e, op.else_branch)
        # Both outcomes agree: the condition does not matter.
        if alpha_equivalent(then_branch, else_branch):
            return then_branch
        return syntax.ECond(op.cond, then_branch, else_branch).with_type(e.type)
    elif isinstance(op, syntax.SSeq):
        # Re-associate ((a; b); c) as (a; (b; c)).
        if isinstance(op.s1, syntax.SSeq):
            return mutate(e, syntax.SSeq(op.s1.s1, syntax.SSeq(op.s1.s2, op.s2)))
        # Work backwards: first account for s2, then for s1.
        e2 = mutate(mutate(e, op.s2), op.s1)
        # A declaration in s1 is inlined into the result.
        if isinstance(op.s1, syntax.SDecl):
            e2 = lightweight_subst(e2, op.s1.var, op.s1.val)
        return e2
    elif isinstance(op, syntax.SDecl):
        return e
    else:
        raise NotImplementedError(type(op))
Esempio n. 34
0
    def _compare(self, e1: Exp, e2: Exp, context: Context):
        """Order two expressions under `context`.

        - If neither expression has free variables or free functions, both
          are evaluated and their values are compared with `order_objects`.
        - Alpha-equivalent expressions are EQUAL.
        - Otherwise the solver checks `e1 <= e2` and `e1 >= e2` under the
          context's path conditions.  The result is EQUAL, LT, GT or
          AMBIGUOUS.
        """
        def is_closed(x):
            return not free_vars(x) and not free_funcs(x)

        e1_closed = is_closed(e1)
        e2_closed = is_closed(e2)
        if e1_closed and e2_closed:
            v1 = eval(e1, {})
            v2 = eval(e2, {})
            event("comparison obvious on constants: {} vs {}".format(v1, v2))
            return order_objects(v1, v2)
        if alpha_equivalent(e1, e2):
            event("shortcutting comparison of identical terms")
            return Order.EQUAL

        pc = EAll(context.path_conditions())
        le = self.solver.valid(EImplies(pc, ELe(e1, e2)))
        ge = self.solver.valid(EImplies(pc, EGe(e1, e2)))
        if le:
            return Order.EQUAL if ge else Order.LT
        return Order.GT if ge else Order.AMBIGUOUS
Esempio n. 35
0
 def test_make_record_order_dependent(self):
     """Field order matters: {x, y} and {y, x} are not alpha-equivalent."""
     x_first = EMakeRecord((("x", ENum(0)), ("y", ETRUE)))
     y_first = EMakeRecord((("y", ETRUE), ("x", ENum(0))))
     assert not alpha_equivalent(x_first, y_first)
Esempio n. 36
0
 def visit(self, e, *args):
     """Return the adapted replacement where the needle occurs (same type,
     pool and an equivalent context); otherwise delegate to the base
     visitor."""
     if isinstance(e, Exp) and _sametype(e, self.needle):
         same_pool = self.pool == self.needle_pool
         if same_pool and alpha_equivalent(self.needle, e) and self.needle_context.alpha_equivalent(self.ctx):
             return self.ctx.adapt(self.replacement, self.needle_context)
     return super().visit(e, *args)
Esempio n. 37
0
def optimized_eq(a, b):
    """Build `a == b`, folding it to ETRUE when `a` and `b` are alpha-equivalent."""
    return ETRUE if alpha_equivalent(a, b) else EEq(a, b)
Esempio n. 38
0
 def test_mixed_binders(self):
     """Shadowing matters: \\x. \\y. x differs from \\x. \\x. x."""
     x, y = EVar("x"), EVar("y")
     assert not alpha_equivalent(
         ELambda(x, ELambda(y, x)),
         ELambda(x, ELambda(x, x)))
Esempio n. 39
0
File: core.py Progetto: uwplse/cozy
def _consider_replacement(
        target      : Exp,
        e           : Exp,
        ctx         : Context,
        pool        : Pool,
        replacement : Exp,
        info        : SearchInfo):
    """Helper for search_for_improvements.

    This procedure decides whether replacing `e` with `replacement` in the
    given `target` is an improvement.  If yes, it yields the result of the
    replacement.  Otherwise it yields nothing.

    Checks run cheapest-first:
     1. blacklist lookup,
     2. the (syntactic) ETreeMultisetElems give-up rule,
     3. "already seen" among info.targets,
     4. well-formedness (info.check_wf),
     5. fingerprint equality on info.examples,
     6. cost comparison (must possibly be Order.LT).

    Parameters:
     - target: the top-level expression to improve
     - e: a subexpression of the target
     - ctx: e's context in the target
     - pool: e's pool in the target
     - replacement: a possible replacement for e
     - info: a SearchInfo object with auxiliary data

    This procedure may add items to info.blacklist (on failure of checks
    4-6).
    """
    # Imported locally, as in the original; presumably to avoid an import cycle.
    from cozy.structures.treemultiset import ETreeMultisetElems

    context = info.context
    blacklist = info.blacklist
    k = (e, ctx, pool, replacement)
    if enable_blacklist.value and k in blacklist:
        event("blacklisted")
        print("skipping blacklisted substitution: {} ---> {} ({})".format(pprint(e), pprint(replacement), blacklist[k]))
        return

    # FIXME(zhen): current enumerator will always try to make ETreeMultisetElems a state var
    # FIXME(zhen): we don't want this because we need to put TreeSet into state var, rather than its iterator
    # FIXME(zhen): I still don't know how to fix this in a sensible way, but giving up an "improvement"
    # FIXME(zhen): should be okay temporarily
    # This rule is purely syntactic, so it is checked before any expensive
    # solver / fingerprint / cost-model work (and before "FOUND A GUESS"
    # would be reported for a guess that is then discarded).
    if (isinstance(e, ETreeMultisetElems) and isinstance(e.e, EStateVar)
            and isinstance(replacement, EStateVar)
            and isinstance(replacement.e, ETreeMultisetElems)):
        event("give up {} -> {}".format(pprint(e), pprint(replacement)))
        return

    new_target = freshen_binders(replace(
        target, context, RUNTIME_POOL,
        e, ctx, pool,
        replacement), context)
    if any(alpha_equivalent(t, new_target) for t in info.targets):
        event("already seen")
        return
    wf = info.check_wf(new_target, context, RUNTIME_POOL)
    if not wf:
        msg = "not well-formed [wf={}]".format(wf)
        event(msg)
        blacklist[k] = msg
        return
    if not Fingerprint.of(new_target, info.examples).equal_to(info.target_fingerprint):
        msg = "not correct"
        event(msg)
        blacklist[k] = msg
        return
    if not info.cost_model.compare(new_target, target, context, RUNTIME_POOL).could_be(Order.LT):
        msg = "not an improvement"
        event(msg)
        blacklist[k] = msg
        return
    print("FOUND A GUESS")
    print(" * in {}".format(pprint(target)))
    print(" * replacing {}".format(pprint(e)))
    print(" * with {}".format(pprint(replacement)))
    yield new_target
Esempio n. 40
0
 def test_tuples(self):
     """A tuple expression is alpha-equivalent to itself."""
     one = ENum(1)
     pair = ETuple((one,) * 2)
     assert alpha_equivalent(pair, pair)
Esempio n. 41
0
File: core.py Progetto: uwplse/cozy
def search_for_improvements(
        targets       : [Exp],
        wf_solver     : ModelCachingSolver,
        context       : Context,
        examples      : [{str:object}],
        cost_model    : CostModel,
        stop_callback : Callable[[], bool],
        hints         : [Exp],
        ops           : [Op],
        blacklist     : {(Exp, Context, Pool, Exp) : str}):
    """Search for potential improvements to any of the target expressions.

    This function yields expressions that look like improvements (or are
    ambiguous with respect to some target).  The expressions are only
    guaranteed to be correct on the given examples.

    The search runs in "minor iterations" of increasing enumeration size.
    Each iteration:
     1. For every watched (context, pool), enumerates expressions of the
        current size and tries each one as a replacement for any watched
        target subexpression with an equal fingerprint ("obvious"
        substitutions).
     2. If `check_blind_substitutions` is enabled, tries every enumerated
        expression of the right type at every exploration point, filtered by
        `should_consider_replacement`.
    Candidate replacements are vetted by `_consider_replacement`.

    The loop never returns normally: it raises StopException when
    `stop_callback()` returns true or when the enumerator reports that no
    expressions can exist above the current size.

    This function may add new items to the given blacklist.
    """

    root_ctx = context
    # Well-formedness filter handed to the Enumerator and to SearchInfo.
    # Returns a falsy value (the failing result or No(...)) to reject `e`.
    def check_wf(e, ctx, pool):
        with task("pruning", size=e.size()):
            is_wf = exp_wf(e, pool=pool, context=ctx, solver=wf_solver)
            if not is_wf:
                return is_wf
            res = possibly_useful(wf_solver, e, ctx, pool, ops=ops)
            if not res:
                return res
            # Optionally prune runtime expressions already costlier than the first target.
            if cost_pruning.value and pool == RUNTIME_POOL and cost_model.compare(e, targets[0], ctx, pool) == Order.GT:
                return No("too expensive")
            return True

    with task("setting up hints"):
        # Seed the enumerator with every subexpression of the targets and hints.
        frags = list(unique(itertools.chain(
            *[all_subexpressions_with_context_information(t, root_ctx) for t in targets],
            *[all_subexpressions_with_context_information(h, root_ctx) for h in hints])))
        frags.sort(key=hint_order)
        enum = Enumerator(
            examples=examples,
            cost_model=cost_model,
            check_wf=check_wf,
            hints=frags,
            heuristics=try_optimize,
            stop_callback=stop_callback,
            do_eviction=enable_eviction.value)

    # All candidates must match the behavior of the first target on the examples.
    target_fp = Fingerprint.of(targets[0], examples)

    with task("setting up watches"):
        # Group target subexpressions by context so each context's examples
        # are instantiated once.
        watches_by_context = OrderedDict()
        for target in targets:
            for e, ctx, pool in unique(all_subexpressions_with_context_information(target, context=root_ctx, pool=RUNTIME_POOL)):
                l = watches_by_context.get(ctx)
                if l is None:
                    l = []
                    watches_by_context[ctx] = l
                l.append((target, e, pool))

        # watches: (fingerprint, ctx, pool) -> [(target, subexpression)]
        watches = OrderedDict()
        for ctx, exprs in watches_by_context.items():
            exs = ctx.instantiate_examples(examples)
            for target, e, pool in exprs:
                fp = Fingerprint.of(e, exs)
                k = (fp, ctx, pool)
                l = watches.get(k)
                if l is None:
                    l = []
                    watches[k] = l
                l.append((target, e))

        watched_ctxs = list(unique((ctx, pool) for _, _, ctx, pool in exploration_order(targets, root_ctx)))

    search_info = SearchInfo(
        context=root_ctx,
        targets=targets,
        target_fingerprint=target_fp,
        examples=examples,
        check_wf=check_wf,
        cost_model=cost_model,
        blacklist=blacklist)

    size = 0
    while True:

        print("starting minor iteration {} with |cache|={}".format(size, enum.cache_size()))
        if stop_callback():
            raise StopException()

        # Phase 1: replace watched subexpressions with enumerated expressions
        # that have the same fingerprint in the same (context, pool).
        for ctx, pool in watched_ctxs:
            with task("searching for obvious substitutions", ctx=ctx, pool=pool_name(pool)):
                for info in enum.enumerate_with_info(size=size, context=ctx, pool=pool):
                    with task("searching for obvious substitution", expression=pprint(info.e)):
                        fp = info.fingerprint
                        for ((fpx, cc, pp), reses) in watches.items():
                            if cc != ctx or pp != pool:
                                continue

                            if not fpx.equal_to(fp):
                                continue

                            for target, watched_e in reses:
                                replacement = info.e
                                event("possible substitution: {} ---> {}".format(pprint(watched_e), pprint(replacement)))
                                event("replacement locations: {}".format(pprint(replace(target, root_ctx, RUNTIME_POOL, watched_e, ctx, pool, EVar("___")))))

                                if alpha_equivalent(watched_e, replacement):
                                    event("no change")
                                    continue

                                yield from _consider_replacement(target, watched_e, ctx, pool, replacement, search_info)

        # Phase 2 (optional): try every same-typed enumerated expression at
        # every exploration point, regardless of fingerprint.
        if check_blind_substitutions.value:
            print("Guessing at substitutions...")
            for target, e, ctx, pool in exploration_order(targets, root_ctx):
                with task("checking substitutions",
                        target=pprint(replace(target, root_ctx, RUNTIME_POOL, e, ctx, pool, EVar("___"))),
                        e=pprint(e)):
                    for info in enum.enumerate_with_info(size=size, context=ctx, pool=pool):
                        with task("checking substitution", expression=pprint(info.e)):
                            if stop_callback():
                                raise StopException()
                            replacement = info.e
                            if replacement.type != e.type:
                                event("wrong type (is {}, need {})".format(pprint(replacement.type), pprint(e.type)))
                                continue
                            if alpha_equivalent(replacement, e):
                                event("no change")
                                continue
                            # NOTE(review): e's fingerprint is recomputed for every candidate;
                            # it only depends on (e, ctx) and could be cached per exploration point.
                            should_consider = should_consider_replacement(
                                target, root_ctx,
                                e, ctx, pool, Fingerprint.of(e, ctx.instantiate_examples(examples)),
                                info.e, info.fingerprint)
                            if not should_consider:
                                event("skipped; `should_consider_replacement` returned {}".format(should_consider))
                                continue

                            yield from _consider_replacement(target, e, ctx, pool, replacement, search_info)

        # Terminate once the enumerator says nothing larger can exist
        # (checked for the root context's runtime pool only).
        if not enum.expressions_may_exist_above_size(context, RUNTIME_POOL, size):
            raise StopException("no more expressions can exist above size={}".format(size))

        size += 1
Esempio n. 42
0
def _maintenance_cost(e: Exp, op: Op, freebies: [Exp] = None):
    """Determines the cost of maintaining the expression when there are
    freebies and ops being considered.

    The cost is the result of mutating the expression and getting the storage
    size of the difference between the mutated expression and the original.

    Parameters:
     - e: the expression whose maintenance cost is estimated
     - op: the operation; `e` is mutated by `op.body`
     - freebies: expressions forwarded to `storage_size` and to any extension
       handler; defaults to an empty list (a fresh one per call)

    Returns an expression for the cost: ZERO if the mutation leaves `e`
    alpha-equivalent, otherwise a type-specific storage-size expression.

    Raises NotImplementedError for types not handled below.
    """
    # Avoid a shared mutable default: this list is passed on to other code.
    if freebies is None:
        freebies = []

    e_prime = mutate(e, op.body)
    if alpha_equivalent(e, e_prime):
        return ZERO

    h = extension_handler(type(e.type))
    if h is not None:
        return h.maintenance_cost(old_value=e,
                                  new_value=e_prime,
                                  op=op,
                                  freebies=freebies,
                                  storage_size=storage_size,
                                  maintenance_cost=_maintenance_cost)

    def diff_size(new, old, t):
        # Storage size of `new - old`, typed as collection type `t`.
        return storage_size(
            EBinOp(new, "-", old).with_type(t), freebies).with_type(INT)

    def churn(old, new, t):
        # Things added plus things removed when going from `old` to `new`.
        return ESum([diff_size(new, old, t), diff_size(old, new, t)])

    if is_scalar(e.type):
        return storage_size(e, freebies)
    elif isinstance(e.type, (TBag, TSet)):
        return churn(e, e_prime, e.type)
    elif isinstance(e.type, TList):
        return storage_size(e_prime, freebies)
    elif isinstance(e.type, TMap):
        keys = EMapKeys(e).with_type(TBag(e.type.k))
        vals = EMap(
            keys,
            mk_lambda(e.type.k,
                      lambda k: EMapGet(e, k).with_type(e.type.v))).with_type(
                          TBag(e.type.v))

        keys_prime = EMapKeys(e_prime).with_type(TBag(e_prime.type.k))
        vals_prime = EMap(
            keys_prime,
            mk_lambda(e_prime.type.k, lambda k: EMapGet(e_prime, k).with_type(
                e_prime.type.v))).with_type(TBag(e_prime.type.v))

        keys_difference = churn(keys, keys_prime, keys.type)
        vals_difference = churn(vals, vals_prime, vals.type)
        return EBinOp(keys_difference, "*", vals_difference).with_type(INT)

    else:
        raise NotImplementedError(repr(e.type))
Esempio n. 43
0
File: core.py Progetto: timwee/cozy
    def next(self):
        """Generate candidate improvements to the targets, one at a time.

        For increasing enumeration sizes, every enumerated expression of the
        right type is tried as a replacement for every subexpression in
        `exploration_order(self.targets, ...)`.  A rewritten target is yielded
        when it is new (not alpha-equivalent to a known target), matches the
        first target's fingerprint on `self.examples`, is well-formed, and the
        cost model rates it Order.LT or Order.AMBIGUOUS against the target.
        Rejected (e, ctx, pool, replacement) keys are added to `self.blacklist`.

        Raises StopException when `self.stop_callback()` returns true.
        """
        # Falsy result carrying a reason, returned by check_wf on rejection.
        class No(object):
            def __init__(self, msg):
                self.msg = msg

            def __bool__(self):
                return False

            def __str__(self):
                return "no: {}".format(self.msg)

        # with task("pre-computing cardinalities"):
        #     cards = [self.cost_model.cardinality(ctx.e) for ctx in enumerate_fragments(self.target) if is_collection(ctx.e.type)]

        root_ctx = self.context

        # Returns True, or a falsy No(...) explaining why `e` is rejected
        # (not well-formed, or a runtime expression costlier than the first target).
        def check_wf(e, ctx, pool):
            with task("checking well-formedness", size=e.size()):
                try:
                    exp_wf(e,
                           pool=pool,
                           context=ctx,
                           assumptions=self.assumptions,
                           solver=self.wf_solver)
                except ExpIsNotWf as exc:
                    return No("at {}: {}".format(
                        pprint(exc.offending_subexpression), exc.reason))
                if pool == RUNTIME_POOL and self.cost_model.compare(
                        e, self.targets[0], ctx, pool) == Order.GT:
                    # from cozy.cost_model import debug_comparison
                    # debug_comparison(self.cost_model, e, self.target, ctx)
                    return No("too expensive")
                # if isinstance(e.type, TBag):
                #     c = self.cost_model.cardinality(e)
                #     if all(cc < c for cc in cards):
                #         # print("too big: {}".format(pprint(e)))
                #         return No("too big")
                return True

        # Seed the enumerator with fragments of the targets and hints.
        frags = list(
            unique(
                itertools.chain(*[shred(t, root_ctx) for t in self.targets],
                                *[shred(h, root_ctx) for h in self.hints])))
        enum = Enumerator(examples=self.examples,
                          cost_model=self.cost_model,
                          check_wf=check_wf,
                          hints=frags,
                          heuristics=try_optimize,
                          stop_callback=self.stop_callback)

        size = 0
        # target_cost = self.cost_model.cost(self.target, RUNTIME_POOL)
        target_fp = fingerprint(self.targets[0], self.examples)

        # Blacklist persists across calls on this object once created.
        if not hasattr(self, "blacklist"):
            self.blacklist = set()

        while True:

            print("starting minor iteration {} with |cache|={}".format(
                size, enum.cache_size()))
            if self.stop_callback():
                raise StopException()

            n = 0  # candidates actually considered this iteration
            for target, e, ctx, pool in exploration_order(
                    self.targets, root_ctx):
                with task("checking substitutions",
                          target=pprint(
                              replace(target, root_ctx, RUNTIME_POOL, e, ctx,
                                      pool, EVar("___"))),
                          e=pprint(e)):
                    for info in enum.enumerate_with_info(size=size,
                                                         context=ctx,
                                                         pool=pool):
                        with task("checking substitution",
                                  expression=pprint(info.e)):
                            if self.stop_callback():
                                raise StopException()
                            if info.e.type != e.type:
                                event("wrong type (is {}, need {})".format(
                                    pprint(info.e.type), pprint(e.type)))
                                continue
                            if alpha_equivalent(info.e, e):
                                event("no change")
                                continue

                            k = (e, ctx, pool, info.e)
                            if k in self.blacklist:
                                event("blacklisted")
                                continue

                            n += 1
                            ee = freshen_binders(
                                replace(target, root_ctx, RUNTIME_POOL, e, ctx,
                                        pool, info.e), root_ctx)
                            if any(
                                    alpha_equivalent(t, ee)
                                    for t in self.targets):
                                event("already seen")
                                continue
                            # Cheap example-based check before the solver-backed ones.
                            if not self.matches(fingerprint(ee, self.examples),
                                                target_fp):
                                event("incorrect")
                                self.blacklist.add(k)
                                continue
                            wf = check_wf(ee, root_ctx, RUNTIME_POOL)
                            if not wf:
                                event("not well-formed [wf={}]".format(wf))
                                # if "expensive" in str(wf):
                                #     print(repr(self.cost_model.examples))
                                #     print(repr(ee))
                                self.blacklist.add(k)
                                continue
                            if self.cost_model.compare(
                                    ee, target, root_ctx,
                                    RUNTIME_POOL) not in (Order.LT,
                                                          Order.AMBIGUOUS):
                                event("not an improvement")
                                self.blacklist.add(k)
                                continue
                            print(
                                "FOUND A GUESS AFTER {} CONSIDERED".format(n))
                            yield ee

            print("CONSIDERED {}".format(n))
            size += 1

        # NOTE(review): unreachable -- the `while True` loop above has no
        # `break`, so this generator only ends via StopException.
        raise NoMoreImprovements()
Esempio n. 44
0
def try_optimize(e, context, pool):
    """Lazily yield heuristic rewrites of `e` from `_try_optimize`.

    Rewrites that are alpha-equivalent to `e` itself are dropped, since
    they would not change anything.
    """
    yield from (candidate
                for candidate in _try_optimize(e, context, pool)
                if not alpha_equivalent(e, candidate))
Esempio n. 45
0
def map_accelerate(e, context):
    """Propose map-lookup rewrites of `e`, yielding (expression, pool) pairs.

    Walks the subexpressions of `e`.  For each subexpression `arg` whose free
    variables are legal in `context`, every occurrence of `arg` in `e` is
    replaced with a fresh binder of `arg.type` (giving `value`).  If that
    removes all of `context`'s runtime-pool variables, then for each
    collection of `arg.type` reachable from a state variable (empty-list
    literals excluded) this yields:

     - `(m, STATE_POOL)` with `m = EMakeMap2(keys, ELambda(binder, value))`
     - `(EMapGet(EStateVar(m), arg), RUNTIME_POOL)`, a lookup computing `e`

    Constant-time expressions are skipped entirely.  EStateVar nodes are not
    descended into, and a node's children are not visited when two or more
    of them mention runtime variables.
    """
    with task("map_accelerate", size=e.size()):
        if is_constant_time(e):
            event("skipping map lookup inference for constant-time exp: {}".
                  format(pprint(e)))
            return

        @lru_cache()
        def make_binder(t):
            # One fresh binder per key type, reused for all candidates here.
            return fresh_var(t, hint="key")

        args = OrderedSet(v for (v, p) in context.vars() if p == RUNTIME_POOL)
        possible_keys = {}  # type -> [exp]
        i = 0

        stk = [e]
        while stk:
            event("exp {} / {}".format(i, e.size()))
            i += 1
            arg = stk.pop()
            if isinstance(arg, tuple):
                stk.extend(arg)
                continue
            if not isinstance(arg, Exp):
                continue
            if isinstance(arg, ELambda):
                stk.append(arg.body)
                continue

            if context.legal_for(free_vars(arg)):
                # all the work happens here
                binder = make_binder(arg.type)
                value = replace(
                    e,
                    arg,
                    binder,
                    match=lambda e1, e2: (type(e1) == type(e2)
                                          and e1.type == e2.type
                                          and alpha_equivalent(e1, e2)))
                value = strip_EStateVar(value)
                if any(v in args for v in free_vars(value)):
                    event("not all args were eliminated")
                else:
                    if arg.type not in possible_keys:
                        candidates = [
                            reachable_values_of_type(sv, arg.type)
                            for (sv, p) in context.vars() if p == STATE_POOL
                        ]
                        possible_keys[arg.type] = OrderedSet(
                            x for x in candidates
                            if not isinstance(x, EEmptyList))
                    for keys in possible_keys[arg.type]:
                        m = EMakeMap2(keys, ELambda(binder, value)).with_type(
                            TMap(arg.type, e.type))
                        assert not any(
                            v in args
                            for v in free_vars(m)), "oops! {}; args={}".format(
                                pprint(m), ", ".join(pprint(a) for a in args))
                        yield (m, STATE_POOL)
                        mg = EMapGet(EStateVar(m).with_type(m.type),
                                     arg).with_type(e.type)
                        yield (mg, RUNTIME_POOL)

            if isinstance(arg, EStateVar):
                # do not visit state expressions
                continue

            # Count the children (flattening tuples) that mention runtime
            # variables; stop counting at two.
            num_with_args = 0
            stk2 = list(arg.children())
            while stk2:
                child = stk2.pop()
                if isinstance(child, tuple):
                    # Flatten into this scan.  (Pushing onto `stk` instead
                    # would skip counting these elements and schedule them
                    # for a visit even when we refuse below.)
                    stk2.extend(child)
                    continue
                if not isinstance(child, Exp):
                    continue
                fvs = free_vars(child)
                if fvs & args:
                    num_with_args += 1
                    if num_with_args >= 2:
                        break
            if num_with_args < 2:
                stk.extend(arg.children())
            else:
                event("refusing to visit children of {}".format(pprint(arg)))
Esempio n. 46
0
 def test_make_record_yes(self):
     """Two records with identical fields in identical order are alpha-equivalent."""
     def make():
         return EMakeRecord((("x", ENum(0)), ("y", ETRUE)))
     assert alpha_equivalent(make(), make())
Esempio n. 47
0
 def test_tuple_nontuple(self):
     """A tuple is not alpha-equivalent to one of its (non-tuple) components."""
     one = ENum(1)
     pair = ETuple((one,) * 2)
     assert not alpha_equivalent(pair, one)
Esempio n. 48
0
    def next(self):
        """Return the next candidate improvement produced by enumeration.

        Pending replacements for the most recently accepted expression are
        served first from `self.backlog` (replacements are recomputed and
        indexed by `self.backlog_counter`).  Otherwise expressions are pulled
        from `self.builder_iter`, optionally pre-optimized, culled when too
        expensive, de-duplicated by fingerprint / alpha-equivalence, and
        compared by cost against previously seen equivalents (evicting worse
        ones).  The first replacement produced for a newly added expression
        is returned, and the rest are left in the backlog.  When the builder
        is exhausted the enumeration size is increased.

        Raises StopException when `self.stop_callback()` returns true, and
        NoMoreImprovements when no progress was made for too many sizes.
        """
        target_cost = self.cost_model.cost(self.target, RUNTIME_POOL)
        self.ncount += 1
        while True:
            if self.backlog is not None:
                if self.stop_callback():
                    raise StopException()
                (e, pool, cost) = self.backlog
                improvements = list(self._possible_replacements(e, pool, cost))
                if self.backlog_counter < len(improvements):
                    i = improvements[self.backlog_counter]
                    self.backlog_counter += 1
                    return i
                else:
                    self.backlog = None
                    self.backlog_counter = 0
            for (e, pool) in self.builder_iter:
                self._on_exp(e, pool)
                if self.stop_callback():
                    raise StopException()

                # # Stopgap measure... long story --Calvin
                # bad = False
                # for x in all_exps(e):
                #     if isinstance(x, EStateVar):
                #         if any(v not in self.state_vars for v in free_vars(x.e)):
                #             bad = True
                #             _on_exp(e, "skipping due to illegal free vars under EStateVar")
                # if bad:
                #     continue

                new_e = self.pre_optimize(e, pool) if preopt.value else e
                if new_e is not e:
                    _on_exp(e, "preoptimized", new_e)
                    e = new_e

                cost = self.cost_model.cost(e, pool)

                # Early culling is only sound for monotonic cost models (or when forced).
                if pool == RUNTIME_POOL and (self.cost_model.is_monotonic() or hyperaggressive_culling.value) and self.compare_costs(cost, target_cost) == Cost.WORSE:
                    _on_exp(e, "too expensive", cost, target_cost)
                    continue

                fp = self._fingerprint(e)
                prev = list(self.seen.find_all(pool, fp))
                should_add = True
                if not prev:
                    _on_exp(e, "new", pool_name(pool))
                elif any(alpha_equivalent(e, ee) for (ee, _, _) in prev):
                    _on_exp(e, "duplicate")
                    should_add = False
                else:
                    better_than = None
                    worse_than = None
                    for prev_exp, prev_size, prev_cost in prev:
                        self._on_cost_cmp()
                        ordering = self.compare_costs(cost, prev_cost)
                        assert ordering in (Cost.WORSE, Cost.BETTER, Cost.UNORDERED)
                        if enforce_strong_progress.value and ordering != Cost.WORSE:
                            bad = find_one(all_exps(e), lambda ee: alpha_equivalent(ee, prev_exp))
                            if bad:
                                _on_exp(e, "failed strong progress requirement", bad)
                                should_add = False
                                break
                        _on_exp(e, ordering, pool_name(pool), prev_exp)
                        if ordering == Cost.UNORDERED:
                            continue
                        elif ordering == Cost.BETTER:
                            better_than = (prev_exp, prev_size, prev_cost)
                            _on_exp(prev_exp, "found better alternative", e)
                            self.cache.evict(prev_exp, size=prev_size, pool=pool)
                            self.seen.remove(prev_exp, pool, fp)
                            if (self.cost_model.is_monotonic() or hyperaggressive_culling.value) and hyperaggressive_eviction.value:
                                for (cached_e, size, p) in list(self.cache):
                                    if p != pool:
                                        continue
                                    if prev_exp in all_exps(cached_e):
                                        _on_exp(cached_e, "evicted since it contains", prev_exp)
                                        self.cache.evict(cached_e, size=size, pool=pool)
                        else:
                            should_add = False
                            worse_than = (prev_exp, prev_size, prev_cost)
                            # break
                    # Being both better and worse than equivalent expressions
                    # indicates a non-transitive cost model; dump diagnostics.
                    # Labels: (1) = e, (2) = worse_than, (3) = better_than.
                    if worse_than and better_than:
                        print("Uh-oh! Strange cost relationship between")
                        print("  (1) this exp: {}".format(pprint(e)))
                        print("  (2) prev. A:  {}".format(pprint(worse_than[0])))
                        print("  (3) prev. B:  {}".format(pprint(better_than[0])))
                        print("e1 = {}".format(repr(e)))
                        print("e2 = {}".format(repr(worse_than[0])))
                        print("e3 = {}".format(repr(better_than[0])))
                        print("(1) vs (2): {}".format(cost.compare_to(worse_than[2], self.assumptions)))
                        print("(2) vs (3): {}".format(worse_than[2].compare_to(better_than[2], self.assumptions)))
                        print("(3) vs (1): {}".format(better_than[2].compare_to(cost, self.assumptions)))
                        # raise Exception("insane cost model behavior")

                if should_add:
                    self.cache.add(e, pool=pool, size=self.current_size)
                    self.seen.add(e, pool, fp, self.current_size, cost)
                    self.last_progress = self.current_size
                else:
                    continue

                # Return the first replacement; the rest are served from the backlog.
                for pr in self._possible_replacements(e, pool, cost):
                    self.backlog = (e, pool, cost)
                    self.backlog_counter = 1
                    return pr

            # Give up when the last accepted expression is from less than
            # half of the current size.
            if self.last_progress < (self.current_size+1) // 2:
                raise NoMoreImprovements("hit termination condition")

            self.current_size += 1
            self.builder_iter = self.builder.build(self.cache, self.current_size)
            if self.current_size == 0:
                self.builder_iter = itertools.chain(self.builder_iter, list(self.roots))
            for f, ct in sorted(_fates.items(), key=lambda x: x[1], reverse=True):
                print("  {:6} | {}".format(ct, f))
            _fates.clear()
            self._start_minor_it()
Esempio n. 49
0
 def visit_Exp(self, e):
     """Wrap `e` in EStateVar if it is alpha-equivalent to available state.

     Expressions not matching anything in `available_state` are handed to
     the parent's generic `visit_ADT`.
     """
     if not any(alpha_equivalent(e, x) for x in available_state):
         return super().visit_ADT(e)
     return target_syntax.EStateVar(e).with_type(e.type)
Esempio n. 50
0
    def enumerate_with_info(self, context: Context, size: int,
                            pool: Pool) -> [EnumeratedExp]:
        canonical_context = self.canonical_context(context)
        if canonical_context is not context:
            print("adapting request: {} ---> {}".format(
                context, canonical_context))
            for info in self.enumerate_with_info(canonical_context, size,
                                                 pool):
                yield info._replace(e=context.adapt(info.e, canonical_context))
            return

        examples = context.instantiate_examples(self.examples)
        if context.parent() is not None:
            for info in self.enumerate_with_info(context.parent(), size, pool):
                e = info.e
                yield EnumeratedExp(e=e, fingerprint=fingerprint(e, examples))

        k = (pool, size, context)
        res = self.cache.get(k)
        if res is not None:
            for e in res:
                yield e
        else:
            assert k not in self.in_progress, "recursive enumeration?? {}".format(
                k)
            self.in_progress.add(k)
            res = []
            self.cache[k] = res
            queue = self.enumerate_core(context, size, pool)
            cost_model = self.cost_model
            while True:
                if self.stop_callback():
                    raise StopException()

                try:
                    e = next(queue)
                except StopIteration:
                    break

                fvs = free_vars(e)
                if not belongs_in_context(fvs, context):
                    continue

                e = freshen_binders(e, context)
                _consider(e, size, context, pool)

                wf = self.check_wf(e, context, pool)
                if not wf:
                    _skip(e, size, context, pool, "wf={}".format(wf))
                    continue

                fp = fingerprint(e, examples)

                # collect all expressions from parent contexts
                with task("collecting prev exps",
                          size=size,
                          context=context,
                          pool=pool_name(pool)):
                    prev = []
                    for sz in range(0, size + 1):
                        prev.extend(self.enumerate_with_info(
                            context, sz, pool))
                    prev = [p.e for p in prev if p.fingerprint == fp]

                if any(alpha_equivalent(e, p) for p in prev):
                    _skip(e, size, context, pool, "duplicate")
                    should_keep = False
                else:
                    # decide whether to keep this expression
                    should_keep = True
                    with task("comparing to cached equivalents"):
                        for prev_exp in prev:
                            event("previous: {}".format(pprint(prev_exp)))
                            to_keep = eviction_policy(e, context, prev_exp,
                                                      context, pool,
                                                      cost_model)
                            if e not in to_keep:
                                _skip(e, size, context, pool,
                                      "preferring {}".format(pprint(prev_exp)))
                                should_keep = False
                                break

                if should_keep:

                    if self.do_eviction:
                        with task("evicting"):
                            to_evict = []
                            for (key, exps) in self.cache.items():
                                (p, s, c) = key
                                if p == pool and c == context:
                                    for ee in exps:
                                        if ee.fingerprint == fp:
                                            event("considering eviction of {}".
                                                  format(pprint(ee.e)))
                                            to_keep = eviction_policy(
                                                e, context, ee.e, c, pool,
                                                cost_model)
                                            if ee.e not in to_keep:
                                                to_evict.append((key, ee))
                            for key, ee in to_evict:
                                (p, s, c) = key
                                _evict(ee.e, s, c, pool, e)
                                self.cache[key].remove(ee)
                                self.seen[(c, p, fp)].remove(ee.e)

                    _accept(e, size, context, pool)
                    seen_key = (context, pool, fp)
                    if seen_key not in self.seen:
                        self.seen[seen_key] = []
                    self.seen[seen_key].append(e)
                    info = EnumeratedExp(e=e, fingerprint=fp)
                    res.append(info)
                    yield info

                    with task("accelerating"):
                        to_try = make_random_access(
                            self.heuristics(e, context, pool))
                        if to_try:
                            event("trying {} accelerations of {}".format(
                                len(to_try), pprint(e)))
                            queue = itertools.chain(to_try, queue)

            self.in_progress.remove(k)
Esempio n. 51
0
 def test_lambdas(self):
     """Two independently built lambdas with the same body are alpha-equivalent.

     ``mk_lambda`` is invoked twice with the same element type and the same
     body-building callable; ``alpha_equivalent`` must treat the results as equal.
     """
     employers = EVar("employers").with_type(TBag(TInt()))
     elem_type = employers.type.elem_type

     # NOTE: the element type is TInt even though the body reads record
     # fields; only alpha-equivalence is checked here, not typing.
     def employer_name(employer):
         return EGetField(EGetField(employer, "val"), "employer_name")

     first = mk_lambda(elem_type, employer_name)
     second = mk_lambda(elem_type, employer_name)
     assert alpha_equivalent(first, second)
Esempio n. 52
0
    def enumerate_with_info(self, context: Context, size: int,
                            pool: Pool) -> [EnumeratedExp]:
        """Yield the enumerated expressions of ``size`` in ``pool`` for ``context``.

        This is a generator of ``EnumeratedExp`` records. Each record holds an
        expression and its fingerprint on the examples instantiated for
        ``context``.

        Order of results:
          1. If ``self.canonical_context(context)`` returns a different object,
             the request is delegated to that canonical context. Every result
             is rewritten with ``context.adapt(...)``, and nothing else runs.
          2. Everything enumerated for ``context.parent()`` (same size and
             pool), when a parent exists.
          3. The expressions owned by ``context`` itself. These are cached in
             ``self.cache[(pool, size, context)]``. A cache hit replays that
             list; a miss builds it by filtering candidates drawn from
             ``self.enumerate_core``.

        A candidate is dropped when any of these holds:
          - its free variables do not belong in ``context``;
          - ``self.check_wf`` rejects it;
          - it is alpha-equivalent to a previously enumerated expression
            (sizes 0..``size``) with the same fingerprint;
          - ``eviction_policy`` prefers such a previous expression.

        An accepted candidate is handled as follows:
          - it may evict same-fingerprint entries from the caches of this
            context and its parent contexts;
          - it is recorded in ``self.seen`` and yielded;
          - the expressions ``self.heuristics`` produces for it are queued
            ahead of the remaining candidates.

        Raises:
            StopException: when ``self.stop_callback()`` returns a truthy value
                before the next candidate is drawn.

        NOTE(review): the cache entry is registered before its list is fully
        built, and ``self.in_progress`` is only cleaned up when the loop runs
        to completion. If the generator is abandoned early (the consumer stops
        iterating, or StopException propagates), the key stays in
        ``self.in_progress`` and later requests replay the partial list.
        """
        canonical_context = self.canonical_context(context)
        if canonical_context is not context:
            # Delegate to the canonical context and adapt each result back.
            print("adapting request: {} ---> {}".format(
                context, canonical_context))
            for info in self.enumerate_with_info(canonical_context, size,
                                                 pool):
                yield info._replace(e=context.adapt(info.e, canonical_context))
            return

        # Everything enumerated for the parent context is yielded first.
        if context.parent() is not None:
            yield from self.enumerate_with_info(context.parent(), size, pool)

        k = (pool, size, context)
        res = self.cache.get(k)
        if res is not None:
            # Cache hit: replay the expressions already accepted for this key.
            for e in res:
                yield e
        else:
            # Fingerprints are computed on the examples instantiated for this
            # context.
            examples = context.instantiate_examples(self.examples)
            assert k not in self.in_progress, "recursive enumeration?? {}".format(
                k)
            self.in_progress.add(k)
            # Register the result list immediately. Nested requests for this
            # same key (e.g. the "collecting prev exps" step below when
            # sz == size) therefore see the partially built list instead of
            # starting a new enumeration.
            res = []
            self.cache[k] = res
            # `queue` is rebound below when accelerations are prepended.
            queue = self.enumerate_core(context, size, pool)
            cost_model = self.cost_model
            while True:
                if self.stop_callback():
                    raise StopException()

                try:
                    e = next(queue)
                except StopIteration:
                    break

                # Skip candidates whose free variables do not belong here.
                fvs = free_vars(e)
                if not belongs_in_context(fvs, context):
                    continue

                e = freshen_binders(e, context)
                _consider(e, context, pool)

                # The falsy check_wf result is also formatted into the skip
                # reason.
                wf = self.check_wf(e, context, pool)
                if not wf:
                    _skip(e, context, pool, "wf={}".format(wf))
                    continue

                fp = fingerprint(e, examples)

                # Collect previously enumerated expressions of sizes 0..size
                # that share this fingerprint. The recursive calls also replay
                # parent contexts and the partial list for this key.
                with task("collecting prev exps",
                          size=size,
                          context=context,
                          pool=pool_name(pool)):
                    prev = []
                    for sz in range(0, size + 1):
                        prev.extend(self.enumerate_with_info(
                            context, sz, pool))
                    prev = [p.e for p in prev if p.fingerprint == fp]

                if any(alpha_equivalent(e, p) for p in prev):
                    _skip(e, context, pool, "duplicate")
                    should_keep = False
                else:
                    # Keep `e` only if eviction_policy never prefers an
                    # existing same-fingerprint expression over it.
                    # eviction_policy appears to return the collection of
                    # expressions worth keeping; `e` is dropped if absent.
                    should_keep = True
                    with task("comparing to cached equivalents"):
                        for prev_exp in prev:
                            event("previous: {}".format(pprint(prev_exp)))
                            to_keep = eviction_policy(e, context, prev_exp,
                                                      context, pool,
                                                      cost_model)
                            if e not in to_keep:
                                _skip(e, context, pool,
                                      "preferring {}".format(pprint(prev_exp)))
                                should_keep = False
                                break

                if should_keep:

                    # Evict same-fingerprint expressions (same pool, any size)
                    # from this context and its parent contexts when
                    # eviction_policy no longer keeps them. Victims are
                    # collected first so self.cache is not mutated while being
                    # iterated.
                    # NOTE(review): the removed-from lists are the same objects
                    # a cache-hit generator (`for e in res`) may be iterating,
                    # which would shift that generator's position.
                    with task("evicting"):
                        to_evict = []
                        for (key, exps) in self.cache.items():
                            (p, s, c) = key
                            if p == pool and c in itertools.chain(
                                [context], parent_contexts(context)):
                                for ee in exps:
                                    if ee.fingerprint == fp:
                                        to_keep = eviction_policy(
                                            e, context, ee.e, c, pool,
                                            cost_model)
                                        if ee.e not in to_keep:
                                            to_evict.append((key, ee))
                        for key, ee in to_evict:
                            (p, s, c) = key
                            _evict(ee.e, c, pool, e)
                            self.cache[key].remove(ee)
                            # Accepted expressions are registered in self.seen
                            # under (context, pool, fingerprint); drop the
                            # evicted one there too.
                            self.seen[(c, p, fp)].remove(ee.e)

                    _accept(e, context, pool)
                    seen_key = (context, pool, fp)
                    if seen_key not in self.seen:
                        self.seen[seen_key] = []
                    self.seen[seen_key].append(e)
                    info = EnumeratedExp(e=e, fingerprint=fp, cost=None)
                    res.append(info)
                    yield info

                    # Prepend heuristic rewrites of `e` so they are considered
                    # before the remaining candidates.
                    with task("accelerating"):
                        to_try = make_random_access(
                            self.heuristics(e, context, pool))
                        if to_try:
                            queue = itertools.chain(to_try, queue)

            # Enumeration for this key ran to completion.
            self.in_progress.remove(k)
Esempio n. 53
0
def is_lenof(e, xs):
    """Return whether ``e`` is alpha-equivalent to ``ELen(xs)``.

    Both ``e`` and ``xs`` are first passed through ``strip_EStateVar``. That
    presumably removes EStateVar wrappers, so state/non-state forms of the
    same length expression compare equal.
    """
    candidate = strip_EStateVar(e)
    expected_length = ELen(strip_EStateVar(xs))
    return alpha_equivalent(candidate, expected_length)
Esempio n. 54
0
 def test_binders(self):
     """EMaps differing only in their lambda's binder id are alpha-equivalent."""
     source = EVar("foo")
     left, right = [EMap(source, mk_lambda(TInt(), lambda arg: source))
                    for _ in range(2)]
     # Each mk_lambda call must produce its own binder...
     assert left.transform_function.arg.id != right.transform_function.arg.id
     # ...yet the two maps are still equal up to renaming of that binder.
     assert alpha_equivalent(left, right)
Esempio n. 55
0
def optimized_eq(a, b):
    """Build an equality test between ``a`` and ``b``, folding the trivial case.

    Alpha-equivalent operands yield the constant ``T`` instead of a new
    ``EEq`` node; otherwise ``EEq(a, b)`` is returned.
    """
    return T if alpha_equivalent(a, b) else EEq(a, b)
Esempio n. 56
0
 def test_free_vars_not_equivalent(self):
     """Two differently named free variables are not alpha-equivalent."""
     first, second = EVar("_var3423"), EVar("_var3422")
     assert not alpha_equivalent(first, second)