예제 #1
0
파일: wf.py 프로젝트: uwplse/cozy
def exp_wf_nonrecursive(solver, e : Exp, context : Context, pool = RUNTIME_POOL, assumptions : Exp = ETRUE):
    """Check the well-formedness of `e` but do not recurse into its children.

    Returns True or an instance of No explaining why `e` is not well-formed.

    See `exp_wf` for an explanation of well-formedness and the parameters that
    this function requires.

    Successful checks are memoized in `e._wf`, keyed on
    (pool, context, assumptions).  All three influence the verdict:
      - the pool decides which variables may appear;
      - the context decides which variables are state vars and which are
        arguments;
      - extension handlers consult the assumptions.
    A single boolean flag would let a result computed for one pool (e.g. the
    state pool) be reused for another (runtime), hiding a bare state variable
    used at runtime.
    """

    # NOTE(review): assumes Context, Exp and the pool values are hashable,
    # as their use elsewhere as cache keys suggests.
    cache_key = (pool, context, assumptions)
    checked = getattr(e, "_wf", None)
    if isinstance(checked, set) and cache_key in checked:
        return True

    state_vars = OrderedSet(v for v, p in context.vars() if p == STATE_POOL)
    args       = OrderedSet(v for v, p in context.vars() if p == RUNTIME_POOL)

    h = extension_handler(type(e))
    if h is not None:
        # Extension types decide their own well-formedness; they may assume
        # the caller's assumptions plus the context's path condition.
        assumptions = EAll([assumptions, context.path_condition()])
        msg = h.check_wf(e, state_vars=state_vars, args=args, pool=pool, assumptions=assumptions, is_valid=solver.valid)
        if msg is not None:
            return No(msg)
    else:
        at_runtime = pool == RUNTIME_POOL
        if isinstance(e, EStateVar) and not at_runtime:
            return No("EStateVar in state pool position")
        if isinstance(e, EVar):
            if at_runtime and e in state_vars:
                return No("state var at runtime")
            elif not at_runtime and e in args:
                return No("arg in state exp")

    # Record success for this (pool, context, assumptions) combination only.
    # Any non-set value left in `_wf` is replaced by a fresh memo set.
    if not isinstance(checked, set):
        checked = set()
        e._wf = checked
    checked.add(cache_key)
    return True
예제 #2
0
파일: wf.py 프로젝트: paulswartz/cozy
def exp_wf_nonrecursive(solver,
                        e: Exp,
                        context: Context,
                        pool=RUNTIME_POOL,
                        assumptions: Exp = T):
    """Check the well-formedness of `e` without recursing into its children.

    Returns None when `e` passes every check; otherwise raises
    ExpIsNotWf(e, e, msg) describing the first check that failed.

    Parameters:
        solver      - its `valid` method is handed to extension handlers
        e           - the expression to check
        context     - the variables in scope and the pool each belongs to
        pool        - where `e` is used; anything other than RUNTIME_POOL is
                      treated as a state position
        assumptions - facts extension handlers may assume; conjoined with
                      the context's path condition before being passed on
    """
    state_vars = OrderedSet(v for v, p in context.vars() if p == STATE_POOL)
    args = OrderedSet(v for v, p in context.vars() if p == RUNTIME_POOL)

    h = extension_handler(type(e))
    if h is not None:
        # Only extension handlers consult the assumptions, so the path
        # condition is conjoined here instead of on every call.
        msg = h.check_wf(e,
                         state_vars=state_vars,
                         args=args,
                         pool=pool,
                         assumptions=EAll([assumptions,
                                           context.path_condition()]),
                         is_valid=solver.valid)
        if msg is not None:
            raise ExpIsNotWf(e, e, msg)
        return

    at_runtime = pool == RUNTIME_POOL
    if isinstance(e, EStateVar) and not at_runtime:
        raise ExpIsNotWf(e, e, "EStateVar in state pool position")
    if isinstance(e, EVar):
        if at_runtime and e in state_vars:
            raise ExpIsNotWf(e, e, "state var at runtime")
        if not at_runtime and e in args:
            raise ExpIsNotWf(e, e, "arg in state exp")
예제 #3
0
    def enumerate_with_info(self, context: Context, size: int,
                            pool: Pool) -> [EnumeratedExp]:
        """Enumerate expressions (and fingerprints) of the given size.

        The output of this function is cached, so subsequent calls are very
        cheap.

        Arguments:
            context : a Context object describing the vars in scope
            size    : size of expressions to enumerate
            pool    : expression pool to visit

        A request for a non-canonical context is served by enumerating the
        canonical context and adapting each result back to `context`.

        While the enumeration for (pool, size, context) runs, its key is held
        in `self.in_progress` so accidental re-entrant requests are caught.
        The key is released even if the consumer abandons this generator
        early or the underlying enumeration raises.  It is recorded in
        `self.complete` only after the enumeration has run to the end.
        """

        canonical_context = self.canonical_context(context)
        if canonical_context is not context:
            print("adapting request: {} ---> {}".format(
                context, canonical_context))
            for info in self.enumerate_with_info(canonical_context, size,
                                                 pool):
                yield info._replace(e=context.adapt(info.e, canonical_context))
            return

        k = (pool, size, context)

        if k in self.complete:
            yield from self.cache.find_expressions_of_size(context, pool, size)
            return

        assert k not in self.in_progress, "recursive enumeration?? {}".format(
            k)
        self.in_progress.add(k)
        try:
            yield from self._enumerate_with_info(context, size, pool)
        finally:
            # Runs on normal exhaustion, on an exception, and when the
            # generator is closed early (GeneratorExit).  Without this, an
            # abandoned enumeration would leave `k` in `in_progress` forever
            # and every later request would trip the assertion above.
            self.in_progress.remove(k)
        # Reached only when the enumeration finished completely.
        self.complete.add(k)
예제 #4
0
def exp_wf_nonrecursive(solver,
                        e: Exp,
                        context: Context,
                        pool=RUNTIME_POOL,
                        assumptions: Exp = ETRUE):
    """Check the well-formedness of `e` but do not recurse into its children.

    Returns True or an instance of No explaining why `e` is not well-formed.

    See `exp_wf` for an explanation of well-formedness and the parameters that
    this function requires.

    Successful checks are memoized in `e._wf`, keyed on
    (pool, context, assumptions).  All three influence the verdict:
      - the pool decides which variables may appear;
      - the context decides which variables are state vars and which are
        arguments;
      - extension handlers consult the assumptions.
    A single boolean flag would let a result computed for one pool (e.g. the
    state pool) be reused for another (runtime), hiding a bare state variable
    used at runtime.
    """

    # NOTE(review): assumes Context, Exp and the pool values are hashable,
    # as their use elsewhere as cache keys suggests.
    cache_key = (pool, context, assumptions)
    checked = getattr(e, "_wf", None)
    if isinstance(checked, set) and cache_key in checked:
        return True

    state_vars = OrderedSet(v for v, p in context.vars() if p == STATE_POOL)
    args = OrderedSet(v for v, p in context.vars() if p == RUNTIME_POOL)

    h = extension_handler(type(e))
    if h is not None:
        # Extension types decide their own well-formedness; they may assume
        # the caller's assumptions plus the context's path condition.
        assumptions = EAll([assumptions, context.path_condition()])
        msg = h.check_wf(e,
                         state_vars=state_vars,
                         args=args,
                         pool=pool,
                         assumptions=assumptions,
                         is_valid=solver.valid)
        if msg is not None:
            return No(msg)
    else:
        at_runtime = pool == RUNTIME_POOL
        if isinstance(e, EStateVar) and not at_runtime:
            return No("EStateVar in state pool position")
        if isinstance(e, EVar):
            if at_runtime and e in state_vars:
                return No("state var at runtime")
            elif not at_runtime and e in args:
                return No("arg in state exp")

    # Record success for this (pool, context, assumptions) combination only.
    # Any non-set value left in `_wf` is replaced by a fresh memo set.
    if not isinstance(checked, set):
        checked = set()
        e._wf = checked
    checked.add(cache_key)
    return True
예제 #5
0
def repair_well_formedness(e: Exp,
                           context: Context,
                           extra_available_state: [Exp] = ()) -> Exp:
    """Repair the EStateVar nodes in an expression that is not well-formed.

    Parameters:
        e                     - the expression to repair
        context               - the intended context for e
        extra_available_state - extra state expressions that e can use; any
                                iterable (it is only read, never modified)

    Assuming that all expressions in extra_available_state are well-formed
    state expressions, the output will be a well-formed runtime expression that
    behaves like `e`.

    Raises AssertionError (when assertions are enabled) if the repaired
    expression still fails `exp_wf` in RUNTIME_POOL.
    """

    with task("repairing"):

        # Remove every existing EStateVar wrapper; the replacements below
        # decide afresh which subexpressions come from state.
        e = strip_EStateVar(e)

        # State expressions in decreasing order of size, so larger state
        # expressions get the first chance to be wrapped.
        available_state = sorted(
            unique(itertools.chain(
                (v for v, p in context.vars() if p == STATE_POOL),
                extra_available_state)),
            key=lambda state_exp: -state_exp.size())

        with task("making replacements", size=e.size()):
            for state_exp in available_state:
                # NOTE(review): `filter` presumably stops `replace` from
                # touching nodes that are already EStateVar; confirm against
                # `replace`.
                e = replace(e,
                            state_exp,
                            EStateVar(state_exp).with_type(state_exp.type),
                            match=alpha_equivalent,
                            filter=lambda node: not isinstance(node, EStateVar))

        with task("freshening binders"):
            e = freshen_binders(e, context)

        with task("checking correctness"):
            res = exp_wf(e, context, RUNTIME_POOL)
            assert res, str(res)

        return e
예제 #6
0
    def _compare(self, e1: Exp, e2: Exp, context: Context):
        """Return an Order relating `e1` to `e2`.

        Expressions with no free variables and no free functions are evaluated
        and their values compared directly.  Alpha-equivalent expressions are
        EQUAL.  Otherwise the solver decides whether e1 <= e2 and/or e1 >= e2
        holds under the context's path conditions; when neither is provable
        the result is Order.AMBIGUOUS.
        """
        def closed(x):
            # True when `x` mentions no free variables and no free functions.
            return not free_vars(x) and not free_funcs(x)

        lhs_closed = closed(e1)
        rhs_closed = closed(e2)
        if lhs_closed and rhs_closed:
            lhs_val = eval(e1, {})
            rhs_val = eval(e2, {})
            event("comparison obvious on constants: {} vs {}".format(lhs_val, rhs_val))
            return order_objects(lhs_val, rhs_val)

        if alpha_equivalent(e1, e2):
            event("shortcutting comparison of identical terms")
            return Order.EQUAL

        guard = EAll(context.path_conditions())
        le_holds = self.solver.valid(EImplies(guard, ELe(e1, e2)))
        ge_holds = self.solver.valid(EImplies(guard, EGe(e1, e2)))

        if le_holds:
            return Order.EQUAL if ge_holds else Order.LT
        return Order.GT if ge_holds else Order.AMBIGUOUS
예제 #7
0
파일: wf.py 프로젝트: uwplse/cozy
def repair_well_formedness(e : Exp, context : Context, extra_available_state : [Exp] = ()) -> Exp:
    """Repair the EStateVar nodes in an expression that is not well-formed.

    Parameters:
        e                     - the expression to repair
        context               - the intended context for e
        extra_available_state - extra state expressions that e can use; any
                                iterable (it is only read, never modified)

    Assuming that all expressions in extra_available_state are well-formed
    state expressions, the output will be a well-formed runtime expression that
    behaves like `e`.

    Raises AssertionError (when assertions are enabled) if the repaired
    expression still fails `exp_wf` in RUNTIME_POOL.
    """

    with task("repairing"):

        # Remove every existing EStateVar wrapper; the replacements below
        # decide afresh which subexpressions come from state.
        e = strip_EStateVar(e)

        # State expressions in decreasing order of size, so larger state
        # expressions get the first chance to be wrapped.
        available_state = sorted(unique(itertools.chain(
            (v for v, p in context.vars() if p == STATE_POOL),
            extra_available_state)), key=lambda state_exp: -state_exp.size())

        with task("making replacements", size=e.size()):
            for state_exp in available_state:
                # NOTE(review): `filter` presumably stops `replace` from
                # touching nodes that are already EStateVar; confirm against
                # `replace`.
                e = replace(e, state_exp, EStateVar(state_exp).with_type(state_exp.type),
                    match=alpha_equivalent,
                    filter=lambda node: not isinstance(node, EStateVar))

        with task("freshening binders"):
            e = freshen_binders(e, context)

        with task("checking correctness"):
            res = exp_wf(e, context, RUNTIME_POOL)
            assert res, str(res)

        return e
예제 #8
0
def improve(target: Exp,
            context: Context,
            assumptions: Exp = ETRUE,
            stop_callback: Callable[[], bool] = never_stop,
            hints: [Exp] = (),
            examples: [{
                str: object
            }] = (),
            cost_model: CostModel = None,
            ops: [Op] = (),
            improve_count: Value = None):
    """Improve the target expression using enumerative synthesis.

    This function is a generator that yields increasingly better and better
    versions of the input expression `target` in the given `context`.  The
    `cost_model` defines "better".

    It periodically calls `stop_callback` and exits gracefully when
    `stop_callback` returns True.

    Other parameters:
        - assumptions: a precondition.  The yielded improvements will only be
          correct when the assumptions are true.
        - hints: expressions that might be useful.  These will be explored
          first when looking for improvements.
        - examples: inputs that will be used internally to differentiate
          semantically distinct expressions.  This procedure discovers more
          examples as it runs, so there usually isn't a reason to provide any.
        - ops: update operations.  This function may make different choices
          about what expressions are state expressions based on what changes
          can happen to that state.
        - improve_count: optional shared counter (anything with `.value` and
          `.get_lock()`) that is incremented once per round of the main
          search loop.

    Key differences from "regular" enumerative synthesis:
        - Expressions are either "state" expressions or "runtime" expressions,
          allowing this algorithm to choose what things to store on the data
          structure and what things to compute at query execution time. (The
          cost model is ultimately responsible for this choice.)
        - If a better version of *any subexpression* for the target is found,
          it is immediately substituted in and the overall expression is
          returned. This "smooths out" the search space a little, allowing us
          to find kinda-good solutions very quickly, even if the best possible
          solution is out of reach.  This is more desirable than running for
          an indeterminate amount of time doing nothing.
    """

    print("call to improve:")
    print("""improve(
        target={target!r},
        context={context!r},
        assumptions={assumptions!r},
        stop_callback={stop_callback!r},
        hints={hints!r},
        examples={examples!r},
        cost_model={cost_model!r},
        ops={ops!r})""".format(target=target,
                               context=context,
                               assumptions=assumptions,
                               stop_callback=stop_callback,
                               hints=hints,
                               examples=examples,
                               cost_model=cost_model,
                               ops=ops))

    # Normalize the inputs before searching: `inline_lets` rewrites the
    # target, and `freshen_binders` renames binders relative to `context`
    # (presumably to avoid capturing context variables).
    target = inline_lets(target)
    target = freshen_binders(target, context)
    assumptions = freshen_binders(assumptions, context)

    # Nothing to do if the target already satisfies the doneness heuristic.
    if heuristic_done(target):
        print("The target already looks great!")
        return

    print()
    print("improving: {}".format(pprint(target)))
    print("subject to: {}".format(pprint(assumptions)))
    print()

    # exp_wf returns a falsy value describing the problem when the target is
    # not well-formed; its text is included in the assertion message.
    is_wf = exp_wf(target, context=context, assumptions=assumptions)
    assert is_wf, "initial target is not well-formed: {}".format(is_wf)

    state_vars = [v for (v, p) in context.vars() if p == STATE_POOL]
    if eliminate_vars.value and can_elim_vars(target, assumptions, state_vars):
        print("This job does not depend on state_vars.")
        # TODO: what can we do about it?

    # Search hints: the caller's hints, each conjunct of the assumptions
    # (with bare state vars wrapped by wrap_naked_statevars), and the target.
    hints = ([freshen_binders(h, context) for h in hints] + [
        freshen_binders(wrap_naked_statevars(a, state_vars), context)
        for a in break_conj(assumptions)
    ] + [target])
    print("{} hints".format(len(hints)))
    for h in hints:
        print(" - {}".format(pprint(h)))
    # NOTE(review): `vars` is never read below (and shadows the builtin).
    vars = list(v for (v, p) in context.vars())
    funcs = context.funcs()

    # The solver is built with `assumptions`, so its queries below are
    # relative to them.
    solver = solver_for_context(context, assumptions=assumptions)

    # Unsatisfiable assumptions mean the query is never called, so any value
    # of the right type is an acceptable "improvement".
    if not solver.satisfiable(ETRUE):
        print("assumptions are unsat; this query will never be called")
        yield construct_value(target.type)
        return

    is_good = possibly_useful(solver, target, context)
    assert is_good, "WARNING: this target is already a bad idea\n is_good = {}, target = {}".format(
        is_good, target)

    examples = list(examples)

    if cost_model is None:
        cost_model = CostModel(funcs=funcs, assumptions=assumptions)

    # `watched_targets` is the frontier of equivalent targets being improved;
    # `blacklist` is owned and updated by search_for_improvements.
    watched_targets = [target]
    blacklist = {}

    # NOTE(review): when search_for_improvements finishes without producing
    # a candidate that reaches a `break`, this loop starts the same search
    # again.  Confirm that search_for_improvements only returns in situations
    # where that is intended (e.g. once stop_callback fires).
    while True:
        # 1. find any potential improvement to any sub-exp of target
        for new_target in search_for_improvements(targets=watched_targets,
                                                  wf_solver=solver,
                                                  context=context,
                                                  examples=examples,
                                                  cost_model=cost_model,
                                                  stop_callback=stop_callback,
                                                  hints=hints,
                                                  ops=ops,
                                                  blacklist=blacklist):
            print("Found candidate improvement: {}".format(pprint(new_target)))

            # 2. check
            # Look for an input on which the candidate and the current target
            # disagree; None means they are equivalent under the assumptions.
            with task("verifying candidate"):
                counterexample = solver.satisfy(ENot(EEq(target, new_target)))

            if counterexample is not None:
                # A repeated counterexample means the search proposed a
                # candidate that the known examples should already rule out.
                if counterexample in examples:
                    print("assumptions = {!r}".format(assumptions))
                    print("duplicate example: {!r}".format(counterexample))
                    print("old target = {!r}".format(target))
                    print("new target = {!r}".format(new_target))
                    raise Exception("got a duplicate example")
                # a. if incorrect: add example, restart
                examples.append(counterexample)
                print("new example: {!r}".format(counterexample))
                print("wrong; restarting with {} examples".format(
                    len(examples)))
                break
            else:
                # b. if correct: yield it, watch the new target, goto 1
                print("The candidate is valid!")
                print(repr(new_target))
                print("Determining whether to yield it...")
                with task("updating frontier"):
                    # retention_policy returns the targets worth keeping when
                    # comparing the new target against one watched target.
                    to_evict = []
                    keep = True
                    old_better = None
                    for old_target in watched_targets:
                        evc = retention_policy(new_target, context, old_target,
                                               context, RUNTIME_POOL,
                                               cost_model)
                        if old_target not in evc:
                            to_evict.append(old_target)
                        if new_target not in evc:
                            old_better = old_target
                            keep = False
                            break
                    for t in to_evict:
                        watched_targets.remove(t)
                    if not keep:
                        print(
                            "Whoops! Looks like we already found something better."
                        )
                        print(" --> {}".format(pprint(old_better)))
                        # Keep consuming candidates from the same search.
                        continue
                    # Only yield when the new target displaced the current
                    # best target.
                    if target in to_evict:
                        print("Yep, it's an improvement!")
                        yield new_target
                        if heuristic_done(new_target):
                            print("target now matches doneness heuristic")
                            return
                        target = new_target
                    else:
                        print("Nope, it isn't substantially better!")

                watched_targets.append(new_target)
                print("Now watching {} targets".format(len(watched_targets)))
                break

        # Count completed rounds for whoever holds the shared counter.
        if improve_count is not None:
            with improve_count.get_lock():
                improve_count.value += 1
예제 #9
0
파일: wf.py 프로젝트: Nathan-Fenner/cozy
def exp_wf_nonrecursive(solver, e : Exp, context : Context, pool = RUNTIME_POOL, assumptions : Exp = T):
    """Check the well-formedness of `e` without recursing into its children.

    Returns None when every check passes.  Otherwise raises
    ExpIsNotWf(e, e, msg) for the first check that fails, so the order of the
    checks below determines which message is reported.

    Parameters:
        solver      - its `valid` method discharges the expensive semantic
                      checks and is handed to extension handlers
        e           - the expression to check
        context     - the variables in scope and the pool each belongs to
        pool        - where `e` is used; anything other than RUNTIME_POOL is
                      treated as a state position
        assumptions - facts that may be assumed; conjoined with the
                      context's path condition before use

    Several checks are gated on module-level option flags
    (allow_int_arithmetic_state, allow_conditional_state,
    do_expensive_checks).
    """
    state_vars = OrderedSet(v for v, p in context.vars() if p == STATE_POOL)
    args       = OrderedSet(v for v, p in context.vars() if p == RUNTIME_POOL)
    assumptions = EAll([assumptions, context.path_condition()])

    # Extension types decide their own well-formedness; none of the built-in
    # checks below apply to them.
    h = extension_handler(type(e))
    if h is not None:
        msg = h.check_wf(e, state_vars=state_vars, args=args, pool=pool, assumptions=assumptions, is_valid=solver.valid)
        if msg is not None:
            raise ExpIsNotWf(e, e, msg)
        return
    at_runtime = pool == RUNTIME_POOL
    if isinstance(e, EStateVar) and not at_runtime:
        raise ExpIsNotWf(e, e, "EStateVar in state pool position")
    # An EStateVar must mention at least one variable, and every free
    # variable it mentions must be a state var.
    if isinstance(e, EStateVar):
        fvs = free_vars(e.e)
        if not fvs:
            raise ExpIsNotWf(e, e, "constant value in state position")
        bad = [v for v in fvs if v not in state_vars]
        if bad:
            raise ExpIsNotWf(e, e, "free non-statevars in state position: {}".format(", ".join(v.id for v in bad)))
    # Constructs that are never allowed in state positions.
    if (isinstance(e, EDropFront) or isinstance(e, EDropBack)) and not at_runtime:
        raise ExpIsNotWf(e, e, "EDrop* in state position")
    if isinstance(e, EFlatMap) and not at_runtime:
        raise ExpIsNotWf(e, e, "EFlatMap in state position")
    if not allow_int_arithmetic_state.value and not at_runtime and isinstance(e, EBinOp) and e.type == INT:
        raise ExpIsNotWf(e, e, "integer arithmetic in state position")
    # if isinstance(e, EUnaryOp) and e.op == UOp.Distinct and not at_runtime:
    #     raise ExpIsNotWf(e, e, "'distinct' in state position")
    # if isinstance(e, EMapKeys) and not at_runtime:
    #     raise ExpIsNotWf(e, e, "'mapkeys' in state position")
    # Variables must be used in their own pool: state vars only inside state
    # expressions, arguments only at runtime.
    if isinstance(e, EVar):
        if at_runtime and e in state_vars:
            raise ExpIsNotWf(e, e, "state var at runtime")
        elif not at_runtime and e in args:
            raise ExpIsNotWf(e, e, "arg in state exp")
    # if is_collection(e.type) and is_collection(e.type.t):
    #     raise ExpIsNotWf(e, e, "collection of collection")
    # Type-shape restrictions, applied in every pool.
    if is_collection(e.type) and not is_scalar(e.type.t):
        raise ExpIsNotWf(e, e, "collection of nonscalar")
    if isinstance(e.type, TMap) and not is_scalar(e.type.k):
        raise ExpIsNotWf(e, e, "bad key type {}".format(pprint(e.type.k)))
    if isinstance(e.type, TMap) and isinstance(e.type.v, TMap):
        raise ExpIsNotWf(e, e, "map to map")
    # This check is probably a bad idea: whether `the` is legal may depend on
    # the context that the expression is embedded within, so we can't skip it
    # during synthesis just because it looks invalid now.
    # if isinstance(e, EUnaryOp) and e.op == UOp.The:
    #     len = EUnaryOp(UOp.Length, e.e).with_type(INT)
    #     if not valid(EImplies(assumptions, EBinOp(len, "<=", ENum(1).with_type(INT)).with_type(BOOL))):
    #         raise ExpIsNotWf(e, e, "illegal application of 'the': could have >1 elems")
    if not at_runtime and isinstance(e, EBinOp) and e.op == "-" and is_collection(e.type):
        raise ExpIsNotWf(e, e, "collection subtraction in state position")
    if not at_runtime and isinstance(e, ESingleton):
        raise ExpIsNotWf(e, e, "singleton in state position")
    # if not at_runtime and isinstance(e, ENum) and e.val != 0 and e.type == INT:
    #     raise ExpIsNotWf(e, e, "nonzero integer constant in state position")
    if not allow_conditional_state.value and not at_runtime and isinstance(e, ECond):
        raise ExpIsNotWf(e, e, "conditional in state position")
    if isinstance(e, EMakeMap2) and isinstance(e.e, EEmptyList):
        raise ExpIsNotWf(e, e, "trivially empty map")
    # Solver-backed checks, only when do_expensive_checks is enabled.
    if do_expensive_checks.value and not at_runtime and isinstance(e, EFilter):
        # catch "peels": removal of zero or one elements
        # (the elements rejected by the predicate number at most ONE under the
        # assumptions)
        if solver.valid(EImplies(assumptions, ELe(ELen(EFilter(e.e, ELambda(e.p.arg, ENot(e.p.body))).with_type(e.type)), ONE))):
            raise ExpIsNotWf(e, e, "filter is a peel")
    if do_expensive_checks.value and not at_runtime and isinstance(e, EMakeMap2) and is_collection(e.type.v):
        # Reject a state map of collections unless, under the assumptions, the
        # total length of its value collections (over the distinct keys) is
        # bounded by the summed lengths of all collection-typed state vars.
        all_collections = [sv for sv in state_vars if is_collection(sv.type)]
        total_size = ENum(0).with_type(INT)
        for c in all_collections:
            total_size = EBinOp(total_size, "+", EUnaryOp(UOp.Length, c).with_type(INT)).with_type(INT)
        my_size = EUnaryOp(UOp.Length, EFlatMap(EUnaryOp(UOp.Distinct, e.e).with_type(e.e.type), e.value).with_type(e.type.v)).with_type(INT)
        s = EImplies(
            assumptions,
            EBinOp(total_size, ">=", my_size).with_type(BOOL))
        if not solver.valid(s):
            # from cozy.evaluation import eval
            # from cozy.solver import satisfy
            # model = satisfy(EAll([assumptions, EBinOp(total_size, "<", my_size).with_type(BOOL)]), collection_depth=3, validate_model=True)
            # assert model is not None
            # raise ExpIsNotWf(e, e, "non-polynomial-sized map ({}); total_size={}, this_size={}".format(model, eval(total_size, model), eval(my_size, model)))
            raise ExpIsNotWf(e, e, "non-polynomial-sized map")
예제 #10
0
파일: core.py 프로젝트: timwee/cozy
def improve(target: Exp,
            context: Context,
            assumptions: Exp = T,
            stop_callback=never_stop,
            hints: [Exp] = (),
            examples: [{
                str: object
            }] = (),
            cost_model: CostModel = None):
    """
    Improve the target expression using enumerative synthesis.
    This function is a generator that yields increasingly better and better
    versions of the input expression `target`.

    Parameters:
        target        - the expression to improve
        context       - the Context describing the variables in scope
        assumptions   - a precondition; candidates are only checked to be
                        equivalent to `target` when the assumptions hold
        stop_callback - handed to the Learner (presumably polled to decide
                        when to stop searching)
        hints         - expressions used to seed the search
        examples      - initial inputs used to tell candidates apart; more
                        are discovered (as counterexamples) while running
        cost_model    - defines "better"; when omitted, a CostModel over the
                        context's functions and the assumptions is built

    Notes on internals of this algorithm follow.

    Key differences from "regular" enumerative synthesis:
        - Expressions are either "state" expressions or "runtime" expressions,
          allowing this algorithm to choose what things to store on the data
          structure and what things to compute at query execution time. (The
          cost model is ultimately responsible for this choice.)
        - If a better version of *any subexpression* for the target is found,
          it is immediately substituted in and the overall expression is
          returned. This "smooths out" the search space a little, and lets us
          find kinda-good solutions very quickly, even if the best possible
          solution is out of reach.
    """

    print("call to improve:")
    print("""improve(
        target={target!r},
        context={context!r},
        assumptions={assumptions!r},
        stop_callback={stop_callback!r},
        hints={hints!r},
        examples={examples!r},
        cost_model={cost_model!r})""".format(target=target,
                                             context=context,
                                             assumptions=assumptions,
                                             stop_callback=stop_callback,
                                             hints=hints,
                                             examples=examples,
                                             cost_model=cost_model))

    target = freshen_binders(target, context)
    assumptions = freshen_binders(assumptions, context)

    print()
    print("improving: {}".format(pprint(target)))
    print("subject to: {}".format(pprint(assumptions)))
    print()

    # Call exp_wf outside of `assert` so the well-formedness warning is still
    # produced when Python runs with -O (which strips assert statements).
    try:
        is_wf = exp_wf(target, context=context, assumptions=assumptions)
    except ExpIsNotWf as ex:
        print(
            "WARNING: initial target is not well-formed [{}]; this might go poorly..."
            .format(str(ex)))
        print(pprint(ex.offending_subexpression))
        print(pprint(ex.offending_subexpression.type))
        # raise
    else:
        assert is_wf

    state_vars = [v for (v, p) in context.vars() if p == STATE_POOL]
    if eliminate_vars.value and can_elim_vars(target, assumptions, state_vars):
        print("This job does not depend on state_vars.")
        # TODO: what can we do about it?

    # Search hints: the caller's hints, each conjunct of the assumptions
    # (with bare state vars wrapped), and the target itself.
    hints = ([freshen_binders(h, context) for h in hints] + [
        freshen_binders(wrap_naked_statevars(a, state_vars), context)
        for a in break_conj(assumptions)
    ] + [target])
    print("{} hints".format(len(hints)))
    for h in hints:
        print(" - {}".format(pprint(h)))
    all_vars = list(v for (v, p) in context.vars())
    funcs = context.funcs()

    # `_sat(formula)` returns a model of `formula`, or None if unsatisfiable.
    # The incremental solver already holds `assumptions` as a permanent fact.
    solver = None
    if incremental.value:
        solver = IncrementalSolver(vars=all_vars, funcs=funcs)
        solver.add_assumption(assumptions)
        _sat = solver.satisfy
    else:
        _sat = lambda e: satisfy(e, vars=all_vars, funcs=funcs)

    # Unsatisfiable assumptions mean the query is never called, so any value
    # of the right type is acceptable.
    if _sat(assumptions) is None:
        print("assumptions are unsat; this query will never be called")
        yield construct_value(target.type)
        return

    examples = list(examples)

    if cost_model is None:
        cost_model = CostModel(funcs=funcs, assumptions=assumptions)

    watched_targets = [target]
    learner = Learner(watched_targets, assumptions, context, examples,
                      cost_model, stop_callback, hints)
    try:
        while True:
            # 1. find any potential improvement to any sub-exp of target
            for new_target in learner.next():
                print("Found candidate improvement: {}".format(
                    pprint(new_target)))

                # 2. check: look for an input on which the candidate and the
                # current target disagree.
                with task("verifying candidate"):
                    differs = ENot(
                        EBinOp(target, "==", new_target).with_type(BOOL))
                    if incremental.value:
                        # The temporary "differs" assumption must be popped
                        # before any of the `break`/`continue` statements
                        # below; otherwise it would leak into every later
                        # query (and after one valid candidate the solver
                        # would stay unsat, accepting every candidate).
                        solver.push()
                        try:
                            solver.add_assumption(differs)
                            counterexample = _sat(T)
                        finally:
                            solver.pop()
                    else:
                        counterexample = _sat(EAll([assumptions, differs]))

                if counterexample is not None:
                    if counterexample in examples:
                        print("assumptions = {!r}".format(assumptions))
                        print("duplicate example: {!r}".format(counterexample))
                        print("old target = {!r}".format(target))
                        print("new target = {!r}".format(new_target))
                        raise Exception("got a duplicate example")
                    # a. if incorrect: add example, reset the learner
                    examples.append(counterexample)
                    event("new example: {!r}".format(counterexample))
                    print("wrong; restarting with {} examples".format(
                        len(examples)))
                    learner.reset(examples)
                    break
                else:
                    # b. if correct: yield it, watch the new target, goto 1
                    print("The candidate is valid!")
                    print(repr(new_target))
                    print("Determining whether to yield it...")
                    with task("updating frontier"):
                        # eviction_policy returns the targets worth keeping
                        # when comparing the new target to one watched target.
                        to_evict = []
                        keep = True
                        old_better = None
                        for old_target in watched_targets:
                            evc = eviction_policy(new_target, context,
                                                  old_target, context,
                                                  RUNTIME_POOL, cost_model)
                            if old_target not in evc:
                                to_evict.append(old_target)
                            if new_target not in evc:
                                old_better = old_target
                                keep = False
                                break
                        for t in to_evict:
                            watched_targets.remove(t)
                        if not keep:
                            print(
                                "Whoops! Looks like we already found something better."
                            )
                            print(" --> {}".format(pprint(old_better)))
                            continue
                        # Only yield when the current best target was evicted.
                        if target in to_evict:
                            print("Yep, it's an improvement!")
                            yield new_target
                            if heuristic_done(new_target):
                                print("target now matches doneness heuristic")
                                raise NoMoreImprovements()
                            target = new_target
                        else:
                            print("Nope, it isn't substantially better!")

                    watched_targets.append(new_target)
                    print("Now watching {} targets".format(
                        len(watched_targets)))
                    learner.watch(watched_targets)
                    break
    except NoMoreImprovements:
        return
예제 #11
0
    def _enumerate_core(self, context: Context, size: int,
                        pool: Pool) -> [Exp]:
        """Build new expressions of the given size.

        Arguments:
            context : a Context object describing the vars in scope
            size    : size of expressions to enumerate; each expression in
                      the output will have this size
            pool    : pool to enumerate

        This is a generator.

        Size-0 requests produce the leaves of the grammar:
            - literals;
            - default values of "interesting" types;
            - the variables of the requested pool;
            - applicable hints.

        Larger requests (only when `do_enumerate` is enabled) combine smaller
        expressions obtained from `self.enumerate`, then append whatever the
        registered extension handlers produce.

        This function is not cached.  Clients should call `enumerate` instead.

        This function tries to be a clean description of the Cozy grammar.  It
        does not concern itself with deduplication (which is handled
        efficiently by equivalence class deduplication).
        """

        if size < 0:
            return

        if size == 0:
            for e in LITERALS:
                yield e

            # Default values for every type mentioned by a hint or reachable
            # from the type of an in-scope variable (skipping ones that are
            # already in LITERALS).
            all_interesting_types = OrderedSet(self.hint_types)
            for v, _ in context.vars():
                all_interesting_types |= all_types(v.type)
            for t in all_interesting_types:
                l = construct_value(t)
                if l not in LITERALS:
                    yield l

            # Variables belonging to the requested pool.
            for (v, p) in context.vars():
                if p == pool:
                    yield v
            # Hints for this pool whose context matches ours (up to alpha
            # equivalence) are adapted into this context.
            for (e, ctx, p) in self.hints:
                if p == pool:
                    fvs = free_vars(e)
                    if ctx.alpha_equivalent(context.generalize(fvs)):
                        yield context.adapt(e, ctx, e_fvs=fvs)
            return

        # When enumeration is disabled only size-0 expressions exist.
        if not do_enumerate.value:
            return

        def build_lambdas(bag, pool, body_size):
            # Bind a fresh variable (one not already used in `context`) of
            # the bag's element type and enumerate lambda bodies of
            # `body_size` in the context extended by that binder.
            v = fresh_var(bag.type.elem_type,
                          omit=set(v for v, p in context.vars()))
            inner_context = UnderBinder(context, v=v, bag=bag, bag_pool=pool)
            for lam_body in self.enumerate(inner_context, body_size, pool):
                yield ELambda(v, lam_body)

        # Load all smaller expressions in this context and pool.
        # cache[S] contains expressions of size S in this context and pool.
        cache = [list(self.enumerate(context, sz, pool)) for sz in range(size)]

        # Enable use of a state-pool expression at runtime
        # (the EStateVar wrapper itself accounts for one unit of size).
        if pool == RUNTIME_POOL:
            for e in self.enumerate(context.root(), size - 1, STATE_POOL):
                yield EStateVar(e).with_type(e.type)

        # Arity-1 expressions
        # (one child of size `size - 1` plus the new node)
        for e in cache[size - 1]:
            if is_collection(e.type):
                elem_type = e.type.elem_type

                # This method of generating EEmptyList() ensures that we visit
                # empty collections of all possible types.
                yield EEmptyList().with_type(e.type)

                if is_numeric(elem_type):
                    yield EUnaryOp(UOp.Sum, e).with_type(elem_type)

                yield EUnaryOp(UOp.Length, e).with_type(INT)
                yield EUnaryOp(UOp.Empty, e).with_type(BOOL)
                yield EUnaryOp(UOp.Exists, e).with_type(BOOL)
                yield EUnaryOp(UOp.The, e).with_type(elem_type)
                yield EUnaryOp(UOp.Distinct, e).with_type(e.type)
                yield EUnaryOp(UOp.AreUnique, e).with_type(BOOL)

                if elem_type == BOOL:
                    yield EUnaryOp(UOp.Any, e).with_type(BOOL)
                    yield EUnaryOp(UOp.All, e).with_type(BOOL)

            yield ESingleton(e).with_type(TBag(e.type))

            if isinstance(e.type, TRecord):
                for (f, t) in e.type.fields:
                    yield EGetField(e, f).with_type(t)

            if isinstance(e.type, THandle):
                yield EGetField(e, "val").with_type(e.type.value_type)

            if isinstance(e.type, TTuple):
                for n in range(len(e.type.ts)):
                    yield ETupleGet(e, n).with_type(e.type.ts[n])

            if e.type == BOOL:
                yield EUnaryOp(UOp.Not, e).with_type(BOOL)

            if is_numeric(e.type):
                yield EUnaryOp("-", e).with_type(e.type)

            if isinstance(e.type, TMap):
                yield EMapKeys(e).with_type(TBag(e.type.k))

        # Arity-2 expressions
        for (sz1, sz2) in pick_to_sum(2, size - 1):
            # sz1 + sz2 = size - 1
            for e1 in cache[sz1]:
                t = e1.type

                if is_numeric(t):
                    for a2 in of_type(cache[sz2], t):
                        yield EBinOp(e1, "+", a2).with_type(t)
                        yield EBinOp(e1, "-", a2).with_type(t)

                if is_ordered(t):
                    for a2 in of_type(cache[sz2], t):
                        yield EBinOp(e1, ">", a2).with_type(BOOL)
                        yield EBinOp(e1, "<", a2).with_type(BOOL)
                        yield EBinOp(e1, ">=", a2).with_type(BOOL)
                        yield EBinOp(e1, "<=", a2).with_type(BOOL)

                if t == BOOL:
                    for a2 in of_type(cache[sz2], BOOL):
                        yield EBinOp(e1, BOp.And, a2).with_type(BOOL)
                        yield EBinOp(e1, BOp.Or, a2).with_type(BOOL)
                        # Cozy supports the implication operator "=>", but this
                        # function does not enumerate it because
                        #  - (a => b) is equivalent to ((not a) or b)
                        #  - there isn't an implication operator in any of our
                        #    current target languages, so we would need to
                        #    desugar it to ((not a) or b) anyway.

                # Equality/inequality is enumerated for every non-map type.
                if not isinstance(t, TMap):
                    for a2 in of_type(cache[sz2], t):
                        yield EEq(e1, a2)
                        yield EBinOp(e1, "!=", a2).with_type(BOOL)

                if isinstance(t, TMap):
                    for k in of_type(cache[sz2], t.k):
                        yield EMapGet(e1, k).with_type(t.v)
                        yield EHasKey(e1, k).with_type(BOOL)

                if isinstance(t, TList):
                    for i in of_type(cache[sz2], INT):
                        yield EListGet(e1, i).with_type(e1.type.elem_type)

                if is_collection(t):
                    elem_type = t.elem_type
                    for e2 in of_type(cache[sz2], t):
                        yield EBinOp(e1, "+", e2).with_type(t)
                        yield EBinOp(e1, "-", e2).with_type(t)
                    for e2 in of_type(cache[sz2], elem_type):
                        yield EBinOp(e2, BOp.In, e1).with_type(BOOL)
                    # Higher-order operators: the lambda body has size sz2.
                    for f in build_lambdas(e1, pool, sz2):
                        body_type = f.body.type
                        yield EMap(e1, f).with_type(TBag(body_type))
                        if body_type == BOOL:
                            yield EFilter(e1, f).with_type(t)
                        if is_numeric(body_type):
                            yield EArgMin(e1, f).with_type(elem_type)
                            yield EArgMax(e1, f).with_type(elem_type)
                        if is_collection(body_type):
                            yield EFlatMap(e1, f).with_type(
                                TBag(body_type.elem_type))

                        # Maps are only built in the state pool.
                        if pool == STATE_POOL and is_hashable(elem_type):
                            yield EMakeMap2(e1, f).with_type(
                                TMap(elem_type, body_type))

                # `let x = e1 in body`: the binder is introduced by building
                # lambdas over the singleton bag {e1}.
                e1_singleton = ESingleton(e1).with_type(TBag(e1.type))
                for f in build_lambdas(e1_singleton, pool, sz2):
                    yield ELet(e1, f).with_type(f.body.type)

        # Arity-3 expressions
        for (sz1, sz2, sz3) in pick_to_sum(3, size - 1):
            # sz1 + sz2 + sz3 = size - 1
            for e1 in cache[sz1]:
                if e1.type == BOOL:
                    cond = e1
                    for then_branch in cache[sz2]:
                        for else_branch in of_type(cache[sz3],
                                                   then_branch.type):
                            yield ECond(cond, then_branch,
                                        else_branch).with_type(
                                            then_branch.type)
                if isinstance(e1.type, TList):
                    for start in of_type(cache[sz2], INT):
                        for end in of_type(cache[sz3], INT):
                            yield EListSlice(e1, start, end).with_type(e1.type)
                            # It is not necessary to create slice expressions of
                            # the form a[:i] or a[i:].  Those are desugared
                            # after parsing to a[0:i] and a[i:len(a)]
                            # respectively, and Cozy is perfectly capable of
                            # discovering these expanded forms as well.

        # Let each registered extension contribute its own constructs.
        for h in all_extension_handlers():
            yield from h.enumerate(context, size, pool, self.enumerate,
                                   build_lambdas)
예제 #12
0
    def _enumerate_with_info(self, context: Context, size: int,
                             pool: Pool) -> [EnumeratedExp]:
        """Helper for enumerate_with_info that bypasses the cache.

        Note that this method DOES affect the cache: it writes its output into
        the cache and may do evictions.  The enumerate_with_info method ensures
        that there is only ever one call to this method for a given (context,
        size, pool).

        Arguments:
            context : a Context object describing the vars in scope
            size    : size of the expressions to produce
            pool    : pool to enumerate

        Yields an EnumeratedExp for every candidate from `_enumerate_core`
        that:
            - is well-formed;
            - is not an alpha-equivalent duplicate of a known equivalent;
            - is not rejected by `retention_policy` against a cached
              fingerprint-equivalent expression.
        Each yielded record is added to the cache after the `yield`, i.e. when
        the consumer asks for the next element.

        Raises StopException as soon as `self.stop_callback()` returns True.
        """

        examples = context.instantiate_examples(self.examples)
        cache = self.cache
        queue = self._enumerate_core(context, size, pool)
        cost_model = self.cost_model

        while True:
            if self.stop_callback():
                raise StopException()

            # `queue` may be replaced below (accelerations are chained onto
            # its front), so it is driven manually with next() instead of a
            # for-loop.
            try:
                e = next(queue)
            except StopIteration:
                # StopIteration is a "control flow exception" indicating that
                # there isn't a next element.  Since the queue is exhausted,
                # breaking out of the loop is the right thing to do.
                break

            self.stat_timer.check()

            e = freshen_binders(e, context)
            _consider(e, size, context, pool)

            wf = self.check_wf(e, context, pool)
            if not wf:
                _skip(e, size, context, pool, "wf={}".format(wf))
                continue

            fp = Fingerprint.of(e, examples)

            # Collect all expressions from parent contexts that are
            # fingerprint-equivalent to this one.  There might be more than one
            # because of how `retention_policy` works.
            known_equivalents = list(
                cache.find_equivalent_expressions(context, pool, fp))
            to_evict = []

            if any(e.type == prev_entry.e.type
                   and alpha_equivalent(e, prev_entry.e)
                   for prev_entry in known_equivalents):
                _skip(e, size, context, pool, "duplicate")
                should_keep = False
            else:
                # decide whether to keep this expression; while comparing,
                # remember which cached equivalents the policy would drop
                should_keep = True
                if known_equivalents:
                    with task("comparing to cached equivalents",
                              count=len(known_equivalents)):
                        for entry in known_equivalents:
                            prev_exp = entry.e
                            event("previous: {}".format(pprint(prev_exp)))
                            to_keep = retention_policy(e, context, prev_exp,
                                                       context, pool,
                                                       cost_model)
                            if e not in to_keep:
                                _skip(e, size, context, pool,
                                      "preferring {}".format(pprint(prev_exp)))
                                should_keep = False
                                break
                            if prev_exp not in to_keep:
                                to_evict.append(entry)

            # Evictions are only ever scheduled for expressions we keep.
            assert not (to_evict and not should_keep)

            if should_keep:

                if self.do_eviction and to_evict:
                    # Report the number of evictions (consistent with the
                    # "comparing" task above), not the list of entries.
                    with task("evicting", count=len(to_evict)):
                        for entry in to_evict:
                            _evict(entry.e, entry.size, context, pool, e, size)
                            cache.remove(context, pool, entry)

                _accept(e, size, context, pool, fp)
                info = EnumeratedExp(e=e, fingerprint=fp, size=size)
                yield info
                cache.add(context, pool, info)

                # Leaves (size 0) additionally seed heuristic "accelerations",
                # which are explored before the rest of the queue.
                if size == 0:
                    with task("accelerating"):
                        to_try = make_random_access(
                            self.heuristics(e, context, pool))
                        if to_try:
                            event("trying {} accelerations of {}".format(
                                len(to_try), pprint(e)))
                            queue = itertools.chain(to_try, queue)
예제 #13
0
파일: core.py 프로젝트: uwplse/cozy
def improve(
        target        : Exp,
        context       : Context,
        assumptions   : Exp                = ETRUE,
        stop_callback : Callable[[], bool] = never_stop,
        hints         : [Exp]              = (),
        examples      : [{str:object}]     = (),
        cost_model    : CostModel          = None,
        ops           : [Op]               = (),
        improve_count : Value              = None):
    """Improve the target expression using enumerative synthesis.

    This function is a generator that yields increasingly better
    versions of the input expression `target` in the given `context`.  The
    `cost_model` defines "better".

    It periodically calls `stop_callback` and exits gracefully when
    `stop_callback` returns True.

    Other parameters:
        - assumptions: a precondition.  The yielded improvements will only be
          correct when the assumptions are true.
        - hints: expressions that might be useful.  These will be explored
          first when looking for improvements.
        - examples: inputs that will be used internally to differentiate
          semantically distinct expressions.  This procedure discovers more
          examples as it runs, so there usually isn't a reason to provide any.
        - cost_model: the cost model used to compare candidates.  When None,
          a CostModel is built from the context's functions and the
          assumptions.
        - ops: update operations.  This function may make different choices
          about what expressions are state expressions based on what changes
          can happen to that state.
        - improve_count: optional shared counter exposing `get_lock()` and a
          `value` attribute (e.g. a multiprocessing.Value).  It is
          incremented at the end of every round of the outer search loop.

    Key differences from "regular" enumerative synthesis:
        - Expressions are either "state" expressions or "runtime" expressions,
          allowing this algorithm to choose what things to store on the data
          structure and what things to compute at query execution time. (The
          cost model is ultimately responsible for this choice.)
        - If a better version of *any subexpression* for the target is found,
          it is immediately substituted in and the overall expression is
          returned. This "smooths out" the search space a little, allowing us
          to find kinda-good solutions very quickly, even if the best possible
          solution is out of reach.  This is more desirable than running for
          an indeterminate amount of time doing nothing.
    """

    print("call to improve:")
    print("""improve(
        target={target!r},
        context={context!r},
        assumptions={assumptions!r},
        stop_callback={stop_callback!r},
        hints={hints!r},
        examples={examples!r},
        cost_model={cost_model!r},
        ops={ops!r})""".format(
            target=target,
            context=context,
            assumptions=assumptions,
            stop_callback=stop_callback,
            hints=hints,
            examples=examples,
            cost_model=cost_model,
            ops=ops))

    target = inline_lets(target)
    target = freshen_binders(target, context)
    assumptions = freshen_binders(assumptions, context)

    if heuristic_done(target):
        print("The target already looks great!")
        return

    print()
    print("improving: {}".format(pprint(target)))
    print("subject to: {}".format(pprint(assumptions)))
    print()

    is_wf = exp_wf(target, context=context, assumptions=assumptions)
    assert is_wf, "initial target is not well-formed: {}".format(is_wf)

    state_vars = [v for (v, p) in context.vars() if p == STATE_POOL]
    if eliminate_vars.value and can_elim_vars(target, assumptions, state_vars):
        print("This job does not depend on state_vars.")
        # TODO: what can we do about it?

    # Hints = user hints + each conjunct of the assumptions (with naked state
    # vars wrapped) + the target itself.
    hints = ([freshen_binders(h, context) for h in hints]
        + [freshen_binders(wrap_naked_statevars(a, state_vars), context) for a in break_conj(assumptions)]
        + [target])
    print("{} hints".format(len(hints)))
    for h in hints:
        print(" - {}".format(pprint(h)))
    funcs = context.funcs()

    solver = solver_for_context(context, assumptions=assumptions)

    if not solver.satisfiable(ETRUE):
        print("assumptions are unsat; this query will never be called")
        yield construct_value(target.type)
        return

    is_good = possibly_useful(solver, target, context)
    assert is_good, "WARNING: this target is already a bad idea\n is_good = {}, target = {}".format(is_good, target)

    # Private copy: counterexamples are appended to it below.
    examples = list(examples)

    if cost_model is None:
        cost_model = CostModel(funcs=funcs, assumptions=assumptions)

    # The frontier of targets being improved; shared with the search
    # generator, along with a blacklist that persists across rounds.
    watched_targets = [target]
    blacklist = {}

    while True:
        # 1. find any potential improvement to any sub-exp of target
        for new_target in search_for_improvements(
                targets=watched_targets,
                wf_solver=solver,
                context=context,
                examples=examples,
                cost_model=cost_model,
                stop_callback=stop_callback,
                hints=hints,
                ops=ops,
                blacklist=blacklist):
            print("Found candidate improvement: {}".format(pprint(new_target)))

            # 2. check
            with task("verifying candidate"):
                counterexample = solver.satisfy(ENot(EEq(target, new_target)))

            if counterexample is not None:
                if counterexample in examples:
                    print("assumptions = {!r}".format(assumptions))
                    print("duplicate example: {!r}".format(counterexample))
                    print("old target = {!r}".format(target))
                    print("new target = {!r}".format(new_target))
                    raise Exception("got a duplicate example")
                # a. if incorrect: add example, restart
                examples.append(counterexample)
                print("new example: {!r}".format(counterexample))
                print("wrong; restarting with {} examples".format(len(examples)))
                break
            else:
                # b. if correct: yield it, watch the new target, goto 1
                print("The candidate is valid!")
                print(repr(new_target))
                print("Determining whether to yield it...")
                with task("updating frontier"):
                    to_evict = []
                    keep = True
                    old_better = None
                    for old_target in watched_targets:
                        evc = retention_policy(new_target, context, old_target, context, RUNTIME_POOL, cost_model)
                        if old_target not in evc:
                            to_evict.append(old_target)
                        if new_target not in evc:
                            old_better = old_target
                            keep = False
                            break
                    for t in to_evict:
                        watched_targets.remove(t)
                    if not keep:
                        print("Whoops! Looks like we already found something better.")
                        print(" --> {}".format(pprint(old_better)))
                        continue
                    # Only yield when the candidate displaced the current
                    # target from the frontier.
                    if target in to_evict:
                        print("Yep, it's an improvement!")
                        yield new_target
                        if heuristic_done(new_target):
                            print("target now matches doneness heuristic")
                            return
                        target = new_target
                    else:
                        print("Nope, it isn't substantially better!")

                watched_targets.append(new_target)
                print("Now watching {} targets".format(len(watched_targets)))
                break

        if improve_count is not None:
            with improve_count.get_lock():
                improve_count.value += 1
예제 #14
0
    def enumerate_with_info(self, context: Context, size: int,
                            pool: Pool) -> [EnumeratedExp]:
        """Enumerate expressions of exactly `size` in `context` and `pool`.

        Requests for a non-canonical context are answered from the canonical
        context and adapted back.  Otherwise, expressions from the parent
        context (if any) are yielded first, followed by this context's own
        expressions.  The results for each (pool, size, context) are computed
        once and stored in `self.cache`; later calls replay the stored list.

        Computing a fresh entry filters each candidate from `enumerate_core`.
        A candidate is dropped if:
            - it does not belong in the context;
            - it is not well-formed;
            - it is an alpha-equivalent duplicate;
            - the eviction policy prefers a previously kept equivalent.
        A kept expression may evict worse fingerprint-equivalent entries of
        this context or its parents.

        Raises StopException as soon as `self.stop_callback()` returns True.
        """
        canonical_context = self.canonical_context(context)
        if canonical_context is not context:
            print("adapting request: {} ---> {}".format(
                context, canonical_context))
            for info in self.enumerate_with_info(canonical_context, size,
                                                 pool):
                yield info._replace(e=context.adapt(info.e, canonical_context))
            return

        # NOTE(review): parent entries are yielded with the fingerprints that
        # were computed in the parent context (from the parent's instantiated
        # examples), while candidates below are fingerprinted with this
        # context's examples.  If `instantiate_examples` differs between the
        # two contexts, the fingerprint comparisons below cannot match parent
        # entries -- confirm this is intended.
        if context.parent() is not None:
            yield from self.enumerate_with_info(context.parent(), size, pool)

        k = (pool, size, context)
        res = self.cache.get(k)
        if res is not None:
            for e in res:
                yield e
        else:
            examples = context.instantiate_examples(self.examples)
            assert k not in self.in_progress, "recursive enumeration?? {}".format(
                k)
            self.in_progress.add(k)
            # Publish the (initially empty) result list before searching so
            # that lookups of this same key during the search (the "collecting
            # prev exps" step) read the partial results instead of recursing.
            res = []
            self.cache[k] = res
            try:
                queue = self.enumerate_core(context, size, pool)
                cost_model = self.cost_model
                # Contexts whose cache entries a kept expression may evict.
                evictable_contexts = [context] + list(parent_contexts(context))
                while True:
                    if self.stop_callback():
                        raise StopException()

                    # `queue` may be replaced below (accelerations are chained
                    # onto its front), so it is driven manually with next().
                    try:
                        e = next(queue)
                    except StopIteration:
                        break

                    fvs = free_vars(e)
                    if not belongs_in_context(fvs, context):
                        continue

                    e = freshen_binders(e, context)
                    _consider(e, context, pool)

                    wf = self.check_wf(e, context, pool)
                    if not wf:
                        _skip(e, context, pool, "wf={}".format(wf))
                        continue

                    fp = fingerprint(e, examples)

                    # collect all expressions from parent contexts
                    with task("collecting prev exps",
                              size=size,
                              context=context,
                              pool=pool_name(pool)):
                        prev = []
                        for sz in range(0, size + 1):
                            prev.extend(self.enumerate_with_info(
                                context, sz, pool))
                        prev = [p.e for p in prev if p.fingerprint == fp]

                    if any(alpha_equivalent(e, p) for p in prev):
                        _skip(e, context, pool, "duplicate")
                        should_keep = False
                    else:
                        # decide whether to keep this expression: it is
                        # dropped as soon as the eviction policy prefers any
                        # fingerprint-equivalent previous expression over it
                        should_keep = True
                        with task("comparing to cached equivalents"):
                            for prev_exp in prev:
                                event("previous: {}".format(pprint(prev_exp)))
                                to_keep = eviction_policy(e, context, prev_exp,
                                                          context, pool,
                                                          cost_model)
                                if e not in to_keep:
                                    _skip(e, context, pool,
                                          "preferring {}".format(pprint(prev_exp)))
                                    should_keep = False
                                    break

                    if should_keep:

                        with task("evicting"):
                            # Collect first, remove afterwards: the cache must
                            # not be mutated while it is being iterated.
                            to_evict = []
                            for (key, exps) in self.cache.items():
                                (p, s, c) = key
                                if p == pool and c in evictable_contexts:
                                    for ee in exps:
                                        if ee.fingerprint == fp:
                                            to_keep = eviction_policy(
                                                e, context, ee.e, c, pool,
                                                cost_model)
                                            if ee.e not in to_keep:
                                                to_evict.append((key, ee))
                            for key, ee in to_evict:
                                (p, s, c) = key
                                _evict(ee.e, c, pool, e)
                                self.cache[key].remove(ee)
                                self.seen[(c, p, fp)].remove(ee.e)

                        _accept(e, context, pool)
                        seen_key = (context, pool, fp)
                        if seen_key not in self.seen:
                            self.seen[seen_key] = []
                        self.seen[seen_key].append(e)
                        info = EnumeratedExp(e=e, fingerprint=fp, cost=None)
                        res.append(info)
                        yield info

                        with task("accelerating"):
                            to_try = make_random_access(
                                self.heuristics(e, context, pool))
                            if to_try:
                                # Explore accelerated variants of `e` before
                                # the rest of the queue.
                                queue = itertools.chain(to_try, queue)
            finally:
                # Release the re-entrancy guard even when enumeration ends
                # early (StopException, or the consumer closing this
                # generator).
                self.in_progress.remove(k)
예제 #15
0
    def enumerate_with_info(self, context: Context, size: int,
                            pool: Pool) -> [EnumeratedExp]:
        """Enumerate expressions of exactly `size` in `context` and `pool`.

        Requests for a non-canonical context are answered from the canonical
        context and adapted back.  Otherwise the output is produced in order:
            1. expressions from the parent context (if any), re-fingerprinted
               against this context's instantiated examples;
            2. this context's own expressions.
        The results for each (pool, size, context) are computed once and
        stored in `self.cache`; later calls replay the stored list.

        Computing a fresh entry filters each candidate from `enumerate_core`.
        A candidate is dropped if:
            - it does not belong in the context;
            - it is not well-formed;
            - it is an alpha-equivalent duplicate;
            - the eviction policy prefers a previously kept equivalent.
        When `self.do_eviction` is set, a kept expression may evict worse
        fingerprint-equivalent entries of this same context and pool.

        Raises StopException as soon as `self.stop_callback()` returns True.
        """
        canonical_context = self.canonical_context(context)
        if canonical_context is not context:
            print("adapting request: {} ---> {}".format(
                context, canonical_context))
            for info in self.enumerate_with_info(canonical_context, size,
                                                 pool):
                yield info._replace(e=context.adapt(info.e, canonical_context))
            return

        examples = context.instantiate_examples(self.examples)
        if context.parent() is not None:
            # Re-fingerprint parent expressions with this context's examples
            # so they are comparable with the fingerprints computed below.
            for info in self.enumerate_with_info(context.parent(), size, pool):
                e = info.e
                yield EnumeratedExp(e=e, fingerprint=fingerprint(e, examples))

        k = (pool, size, context)
        res = self.cache.get(k)
        if res is not None:
            for e in res:
                yield e
        else:
            assert k not in self.in_progress, "recursive enumeration?? {}".format(
                k)
            self.in_progress.add(k)
            # Publish the (initially empty) result list before searching so
            # that lookups of this same key during the search (the "collecting
            # prev exps" step) read the partial results instead of recursing.
            res = []
            self.cache[k] = res
            try:
                queue = self.enumerate_core(context, size, pool)
                cost_model = self.cost_model
                while True:
                    if self.stop_callback():
                        raise StopException()

                    # `queue` may be replaced below (accelerations are chained
                    # onto its front), so it is driven manually with next().
                    try:
                        e = next(queue)
                    except StopIteration:
                        break

                    fvs = free_vars(e)
                    if not belongs_in_context(fvs, context):
                        continue

                    e = freshen_binders(e, context)
                    _consider(e, size, context, pool)

                    wf = self.check_wf(e, context, pool)
                    if not wf:
                        _skip(e, size, context, pool, "wf={}".format(wf))
                        continue

                    fp = fingerprint(e, examples)

                    # collect all expressions from parent contexts
                    with task("collecting prev exps",
                              size=size,
                              context=context,
                              pool=pool_name(pool)):
                        prev = []
                        for sz in range(0, size + 1):
                            prev.extend(self.enumerate_with_info(
                                context, sz, pool))
                        prev = [p.e for p in prev if p.fingerprint == fp]

                    if any(alpha_equivalent(e, p) for p in prev):
                        _skip(e, size, context, pool, "duplicate")
                        should_keep = False
                    else:
                        # decide whether to keep this expression: it is
                        # dropped as soon as the eviction policy prefers any
                        # fingerprint-equivalent previous expression over it
                        should_keep = True
                        with task("comparing to cached equivalents"):
                            for prev_exp in prev:
                                event("previous: {}".format(pprint(prev_exp)))
                                to_keep = eviction_policy(e, context, prev_exp,
                                                          context, pool,
                                                          cost_model)
                                if e not in to_keep:
                                    _skip(e, size, context, pool,
                                          "preferring {}".format(pprint(prev_exp)))
                                    should_keep = False
                                    break

                    if should_keep:

                        if self.do_eviction:
                            with task("evicting"):
                                # Collect first, remove afterwards: the cache
                                # must not be mutated while it is iterated.
                                to_evict = []
                                for (key, exps) in self.cache.items():
                                    (p, s, c) = key
                                    if p == pool and c == context:
                                        for ee in exps:
                                            if ee.fingerprint == fp:
                                                event("considering eviction of {}".
                                                      format(pprint(ee.e)))
                                                to_keep = eviction_policy(
                                                    e, context, ee.e, c, pool,
                                                    cost_model)
                                                if ee.e not in to_keep:
                                                    to_evict.append((key, ee))
                                for key, ee in to_evict:
                                    (p, s, c) = key
                                    _evict(ee.e, s, c, pool, e)
                                    self.cache[key].remove(ee)
                                    self.seen[(c, p, fp)].remove(ee.e)

                        _accept(e, size, context, pool)
                        seen_key = (context, pool, fp)
                        if seen_key not in self.seen:
                            self.seen[seen_key] = []
                        self.seen[seen_key].append(e)
                        info = EnumeratedExp(e=e, fingerprint=fp)
                        res.append(info)
                        yield info

                        with task("accelerating"):
                            to_try = make_random_access(
                                self.heuristics(e, context, pool))
                            if to_try:
                                event("trying {} accelerations of {}".format(
                                    len(to_try), pprint(e)))
                                queue = itertools.chain(to_try, queue)
            finally:
                # Release the re-entrancy guard even when enumeration ends
                # early (StopException, or the consumer closing this
                # generator).
                self.in_progress.remove(k)
예제 #16
0
def good_idea(solver,
              e: Exp,
              context: Context,
              pool=RUNTIME_POOL,
              assumptions: Exp = T) -> bool:
    """Heuristic filter to ignore expressions that are almost certainly useless.

    Arguments:
        solver      : used for its `valid(formula)` method by the "peel" and
                      "big map" checks
        e           : the candidate expression
        context     : Context describing the variables in scope
        pool        : pool `e` would belong to; anything other than
                      RUNTIME_POOL is treated as a state position
        assumptions : extra facts, conjoined with `context.path_condition()`

    Returns True if no check rejects `e`; otherwise an instance of No whose
    message describes the first check that rejected it.
    """

    assumptions = EAll([assumptions, context.path_condition()])
    at_runtime = pool == RUNTIME_POOL

    if isinstance(e, EStateVar) and not free_vars(e.e):
        return No("constant value in state position")
    if isinstance(e, (EDropFront, EDropBack)) and not at_runtime:
        return No("EDrop* in state position")
    if not allow_big_sets.value and isinstance(e, EFlatMap) and not at_runtime:
        return No("EFlatMap in state position")
    if (not allow_int_arithmetic_state.value and not at_runtime
            and isinstance(e, EBinOp) and e.type == INT):
        return No("integer arithmetic in state position")
    if is_collection(e.type) and not is_scalar(e.type.t):
        return No("collection of nonscalar")
    if isinstance(e.type, TMap) and not is_scalar(e.type.k):
        return No("bad key type {}".format(pprint(e.type.k)))
    if isinstance(e.type, TMap) and isinstance(e.type.v, TMap):
        return No("map to map")
    # This check is probably a bad idea: whether `the` is legal may depend on
    # the context that the expression is embedded within, so we can't skip it
    # during synthesis just because it looks invalid now.
    # if isinstance(e, EUnaryOp) and e.op == UOp.The:
    #     len = EUnaryOp(UOp.Length, e.e).with_type(INT)
    #     if not valid(EImplies(assumptions, EBinOp(len, "<=", ENum(1).with_type(INT)).with_type(BOOL))):
    #         return No("illegal application of 'the': could have >1 elems")
    if (not at_runtime and isinstance(e, EBinOp) and e.op == "-"
            and is_collection(e.type)):
        return No("collection subtraction in state position")
    # if not at_runtime and isinstance(e, ESingleton):
    #     return No("singleton in state position")
    # if not at_runtime and isinstance(e, ENum) and e.val != 0 and e.type == INT:
    #     return No("nonzero integer constant in state position")
    if (at_runtime and isinstance(e, EStateVar) and isinstance(e.e, EBinOp)
            and is_scalar(e.e.e1.type) and is_scalar(e.e.e2.type)):
        return No("constant-time binary operator in state position")
    if not allow_conditional_state.value and not at_runtime and isinstance(
            e, ECond):
        return No("conditional in state position")
    if isinstance(e, EMakeMap2) and isinstance(e.e, EEmptyList):
        return No("trivially empty map")
    if not allow_peels.value and not at_runtime and isinstance(e, EFilter):
        # Catch "peels": filters that remove zero or one elements, i.e. the
        # elements the predicate rejects provably number at most one.
        rejected = EFilter(e.e, ELambda(e.p.arg, ENot(
            e.p.body))).with_type(e.type)
        if solver.valid(EImplies(assumptions, ELe(ELen(rejected), ONE))):
            return No("filter is a peel")
    if (not allow_big_maps.value and not at_runtime
            and isinstance(e, EMakeMap2) and is_collection(e.type.v)):
        # Reject the map unless the total size of its values is provably
        # bounded by the combined length of all state collections.
        # (State vars are only needed here, so they are collected lazily.)
        state_vars = OrderedSet(
            v for v, p in context.vars() if p == STATE_POOL)
        all_collections = [sv for sv in state_vars if is_collection(sv.type)]
        total_size = ENum(0).with_type(INT)
        for c in all_collections:
            total_size = EBinOp(total_size, "+",
                                EUnaryOp(UOp.Length,
                                         c).with_type(INT)).with_type(INT)
        my_size = EUnaryOp(
            UOp.Length,
            EFlatMap(EUnaryOp(UOp.Distinct, e.e).with_type(e.e.type),
                     e.value).with_type(e.type.v)).with_type(INT)
        s = EImplies(assumptions,
                     EBinOp(total_size, ">=", my_size).with_type(BOOL))
        if not solver.valid(s):
            # from cozy.evaluation import eval
            # from cozy.solver import satisfy
            # model = satisfy(EAll([assumptions, EBinOp(total_size, "<", my_size).with_type(BOOL)]), collection_depth=3, validate_model=True)
            # assert model is not None
            # return No("non-polynomial-sized map ({}); total_size={}, this_size={}".format(model, eval(total_size, model), eval(my_size, model)))
            return No("non-polynomial-sized map")

    return True
예제 #17
0
def _try_optimize(e: Exp, context: Context, pool: Pool):
    """Yield candidate rewrites of `e` produced by the acceleration rules below.

    Nothing is yielded unless the `accelerate` option is enabled and `pool` is
    RUNTIME_POOL. Every candidate is passed through `_check` before it is
    yielded, so a debugging hook placed in `_check` sees the output of every
    rule.
    """
    if not accelerate.value:
        return

    if pool != RUNTIME_POOL:
        return

    state_vars = [v for v, p in context.vars() if p == STATE_POOL]
    args = [v for v, p in context.vars() if p == RUNTIME_POOL]

    # ---------------------------------------------------------------------
    # "Rewrite schemes": these trigger on many different AST shapes
    # They are listed first because they are more powerful than the
    # specific rewrite rules below.

    # Closed expressions (no free vars or funcs) are evaluated now and
    # replaced by the expression form of their value.
    if not free_vars(e) and not free_funcs(e):
        try:
            yield _check(uneval(e.type, eval(e, {})), context, RUNTIME_POOL)
        except NotImplementedError:
            print("Unable to evaluate {!r}".format(e))

    # Expressions mentioning only state variables can be wrapped as one
    # state expression.
    if all(v in state_vars for v in free_vars(e)):
        nsv = strip_EStateVar(e)
        sv = EStateVar(nsv).with_type(e.type)
        yield _check(sv, context, RUNTIME_POOL)

    for ee in fold_into_map(e, context):
        yield _check(ee, context, pool)

    # ---------------------------------------------------------------------
    # "Rewrites": these trigger on specific AST nodes

    if isinstance(e, EBinOp):

        if e.op == "-" and is_collection(e.type):
            ee = optimized_bag_difference(e.e1, e.e2)
            yield _check(ee, context, RUNTIME_POOL)

        if e.op == "===" and isinstance(e.e1.type, THandle):
            # Compare handles by both address and value.
            yield _check(
                EAll([
                    optimized_eq(optimized_addr(e.e1), optimized_addr(e.e2)),
                    optimized_eq(optimized_val(e.e1),
                                 optimized_val(e.e2)).with_type(BOOL)
                ]), context, RUNTIME_POOL)

        if e.op == BOp.In:
            ee = optimized_in(e.e1, e.e2)
            yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, ECond):
        yield _check(optimized_cond(e.cond, e.then_branch, e.else_branch),
                     context, RUNTIME_POOL)

    if isinstance(e, EGetField):
        for ee in optimized_get_field(e.e, e.field_name, args):
            yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, EListGet) and e.index == ZERO:
        for res in optimized_the(e.e, args):
            yield _check(res, context, RUNTIME_POOL)

    if isinstance(e, EListGet) and isinstance(e.e, ECond):
        # Push the index into both branches of the conditional.
        yield _check(
            optimized_cond(
                e.e.cond,
                EListGet(e.e.then_branch, e.index).with_type(e.type),
                EListGet(e.e.else_branch, e.index).with_type(e.type)),
            context, RUNTIME_POOL)

    # Local import kept from the original layout (presumably to avoid an
    # import cycle); all tree-multiset names needed below are imported once.
    from cozy.structures.treemultiset import (
        EMakeMaxTreeMultiset, EMakeMinTreeMultiset, ETreeMultisetElems,
        ETreeMultisetPeek, TMaxTreeMultiset, TMinTreeMultiset)
    if isinstance(e, EListGet) and isinstance(e.e, ETreeMultisetElems):
        yield _check(ETreeMultisetPeek(e.e.e, e.index).with_type(e.type),
                     context, RUNTIME_POOL)

    if isinstance(e, EMapGet):
        ee = inline_mapget(e, context)
        yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, EUnaryOp):

        if e.op == UOp.Sum:
            for ee in optimized_sum(e.e, args):
                yield _check(ee, context, RUNTIME_POOL)

        if e.op == UOp.Length:
            ee = optimized_len(e.e)
            yield _check(ee, context, RUNTIME_POOL)

        if e.op == UOp.Empty:
            ee = optimized_empty(e.e)
            yield _check(ee, context, RUNTIME_POOL)

        if e.op == UOp.Exists:
            ee = optimized_exists(e.e)
            yield _check(ee, context, RUNTIME_POOL)

        if e.op == UOp.Distinct:
            for ee in optimized_distinct(e.e, args):
                yield _check(ee, context, RUNTIME_POOL)

        if e.op == UOp.The:
            for ee in optimized_the(e.e, args):
                yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, (EArgMin, EArgMax)):
        for ee in optimized_best(e.e,
                                 e.key_function,
                                 "<" if isinstance(e, EArgMin) else ">",
                                 args=args):
            yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, EFilter):
        for ee in optimized_filter(e.e, e.predicate, args=args):
            yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, EMap):
        for ee in optimized_map(e.e, e.transform_function, args=args):
            yield _check(ee, context, RUNTIME_POOL)

    from cozy.syntax import ESorted
    if isinstance(e, ESorted) and isinstance(e.e, EStateVar):
        # Sorting a state collection: read the elements of a min- or
        # max-ordered tree multiset built in the state instead, choosing
        # between them with `e.asc`.
        elem_type = e.e.e.type.elem_type
        e_max = EMakeMaxTreeMultiset(e.e.e).with_type(
            TMaxTreeMultiset(elem_type))
        e_min = EMakeMinTreeMultiset(e.e.e).with_type(
            TMinTreeMultiset(elem_type))
        ee = optimized_cond(
            e.asc,
            ETreeMultisetElems(EStateVar(e_min).with_type(
                e_min.type)).with_type(e.type),
            ETreeMultisetElems(EStateVar(e_max).with_type(
                e_max.type)).with_type(e.type))
        yield _check(ee, context, RUNTIME_POOL)
예제 #18
0
    def enumerate_core(self, context: Context, size: int, pool: Pool) -> [Exp]:
        """
        Arguments:
            context : a Context object describing the vars in scope
            size    : size to enumerate
            pool    : pool to enumerate

        Yields all expressions of the given size legal in the given context and
        pool.

        Size 0 yields the pool's literals, variables, and adapted hints, plus
        `construct_value` for the types they mention. For larger sizes, most
        compound expressions are built from children obtained through
        `self.enumerate`, with child sizes chosen by `pick_to_sum` so that
        they add up to `size - 1`.
        """

        if size < 0:
            return

        if size == 0:
            for (e, p) in LITERALS:
                if p == pool:
                    yield e
            for (v, p) in context.vars():
                if p == pool:
                    yield v
                for t in all_types(v):
                    yield construct_value(t)
            for (e, ctx, p) in self.hints:
                if p == pool and ctx.alpha_equivalent(context):
                    yield context.adapt(e, ctx)
                for t in all_types(e):
                    yield construct_value(t)
            return

        yield from self.heuristic_enumeration(context, size, pool)

        for e in collections(self.enumerate(context, size - 1, pool)):
            yield EEmptyList().with_type(e.type)
            if is_numeric(e.type.t):
                yield EUnaryOp(UOp.Sum, e).with_type(e.type.t)

        for e in self.enumerate(context, size - 1, pool):
            yield ESingleton(e).with_type(TBag(e.type))

        for e in self.enumerate(context, size - 1, pool):
            if isinstance(e.type, TRecord):
                for (f, t) in e.type.fields:
                    yield EGetField(e, f).with_type(t)

        for e in self.enumerate(context, size - 1, pool):
            if isinstance(e.type, THandle):
                yield EGetField(e, "val").with_type(e.type.value_type)

        for e in self.enumerate(context, size - 1, pool):
            if isinstance(e.type, TTuple):
                for n in range(len(e.type.ts)):
                    yield ETupleGet(e, n).with_type(e.type.ts[n])

        for e in of_type(self.enumerate(context, size - 1, pool), BOOL):
            yield EUnaryOp(UOp.Not, e).with_type(BOOL)

        for e in self.enumerate(context, size - 1, pool):
            if is_numeric(e.type):
                yield EUnaryOp("-", e).with_type(e.type)

        for m in self.enumerate(context, size - 1, pool):
            if isinstance(m.type, TMap):
                yield EMapKeys(m).with_type(TBag(m.type.k))

        # Binary constructs: children of sizes sz1 and sz2.
        for (sz1, sz2) in pick_to_sum(2, size - 1):
            for a1 in self.enumerate(context, sz1, pool):
                t = a1.type
                if not is_numeric(t):
                    continue
                for a2 in of_type(self.enumerate(context, sz2, pool), t):
                    yield EBinOp(a1, "+", a2).with_type(t)
                    yield EBinOp(a1, "-", a2).with_type(t)
                    yield EBinOp(a1, ">", a2).with_type(BOOL)
                    yield EBinOp(a1, "<", a2).with_type(BOOL)
                    yield EBinOp(a1, ">=", a2).with_type(BOOL)
                    yield EBinOp(a1, "<=", a2).with_type(BOOL)
            for a1 in collections(self.enumerate(context, sz1, pool)):
                for a2 in of_type(self.enumerate(context, sz2, pool), a1.type):
                    yield EBinOp(a1, "+", a2).with_type(a1.type)
                    yield EBinOp(a1, "-", a2).with_type(a1.type)
                for a2 in of_type(self.enumerate(context, sz2, pool),
                                  a1.type.t):
                    yield EBinOp(a2, BOp.In, a1).with_type(BOOL)
            for a1 in of_type(self.enumerate(context, sz1, pool), BOOL):
                for a2 in of_type(self.enumerate(context, sz2, pool), BOOL):
                    yield EBinOp(a1, BOp.And, a2).with_type(BOOL)
                    yield EBinOp(a1, BOp.Or, a2).with_type(BOOL)
            for a1 in self.enumerate(context, sz1, pool):
                if not isinstance(a1.type, TMap):
                    for a2 in of_type(self.enumerate(context, sz2, pool),
                                      a1.type):
                        yield EEq(a1, a2)
                        yield EBinOp(a1, "!=", a2).with_type(BOOL)
            for m in self.enumerate(context, sz1, pool):
                if isinstance(m.type, TMap):
                    for k in of_type(self.enumerate(context, sz2, pool),
                                     m.type.k):
                        yield EMapGet(m, k).with_type(m.type.v)
                        yield EHasKey(m, k).with_type(BOOL)
            for l in self.enumerate(context, sz1, pool):
                if not isinstance(l.type, TList):
                    continue
                for i in of_type(self.enumerate(context, sz2, pool), INT):
                    yield EListGet(l, i).with_type(l.type.t)

        # Ternary constructs: each child gets its own share of the budget.
        for (sz1, sz2, sz3) in pick_to_sum(3, size - 1):
            for cond in of_type(self.enumerate(context, sz1, pool), BOOL):
                for then_branch in self.enumerate(context, sz2, pool):
                    # The else-branch must come from sz3. Reusing sz2 would
                    # ignore sz3, produce conditionals of the wrong size,
                    # and never pair branches of different sizes.
                    for else_branch in of_type(
                            self.enumerate(context, sz3, pool),
                            then_branch.type):
                        yield ECond(cond, then_branch,
                                    else_branch).with_type(then_branch.type)

            for l in self.enumerate(context, sz1, pool):
                if not isinstance(l.type, TList):
                    continue
                for st in of_type(self.enumerate(context, sz2, pool), INT):
                    for ed in of_type(self.enumerate(context, sz3, pool), INT):
                        yield EListSlice(l, st, ed).with_type(l.type)

        for bag in collections(self.enumerate(context, size - 1, pool)):
            # len of bag
            count = EUnaryOp(UOp.Length, bag).with_type(INT)
            yield count
            # empty?
            yield EUnaryOp(UOp.Empty, bag).with_type(BOOL)
            # exists?
            yield EUnaryOp(UOp.Exists, bag).with_type(BOOL)
            # singleton?
            yield EEq(count, ONE)

            yield EUnaryOp(UOp.The, bag).with_type(bag.type.t)
            yield EUnaryOp(UOp.Distinct, bag).with_type(bag.type)
            yield EUnaryOp(UOp.AreUnique, bag).with_type(BOOL)

            if bag.type.t == BOOL:
                yield EUnaryOp(UOp.Any, bag).with_type(BOOL)
                yield EUnaryOp(UOp.All, bag).with_type(BOOL)

        def build_lambdas(bag, pool, body_size):
            # Bind a fresh variable ranging over `bag` and enumerate lambda
            # bodies of `body_size` in the extended context.
            v = fresh_var(bag.type.t, omit=set(v for v, p in context.vars()))
            inner_context = UnderBinder(context, v=v, bag=bag, bag_pool=pool)
            for lam_body in self.enumerate(inner_context, body_size, pool):
                yield ELambda(v, lam_body)

        # Iteration
        for (sz1, sz2) in pick_to_sum(2, size - 1):
            for bag in collections(self.enumerate(context, sz1, pool)):
                for lam in build_lambdas(bag, pool, sz2):
                    body_type = lam.body.type
                    yield EMap(bag, lam).with_type(TBag(body_type))
                    if body_type == BOOL:
                        yield EFilter(bag, lam).with_type(bag.type)
                    if is_numeric(body_type):
                        yield EArgMin(bag, lam).with_type(bag.type.t)
                        yield EArgMax(bag, lam).with_type(bag.type.t)
                    if is_collection(body_type):
                        yield EFlatMap(bag, lam).with_type(TBag(body_type.t))

        # Enable use of a state-pool expression at runtime
        if pool == RUNTIME_POOL:
            for e in self.enumerate(context, size - 1, STATE_POOL):
                yield EStateVar(e).with_type(e.type)

        # Create maps
        if pool == STATE_POOL:
            for (sz1, sz2) in pick_to_sum(2, size - 1):
                for bag in collections(self.enumerate(context, sz1,
                                                      STATE_POOL)):
                    if not is_scalar(bag.type.t):
                        continue
                    for lam in build_lambdas(bag, STATE_POOL, sz2):
                        t = TMap(bag.type.t, lam.body.type)
                        m = EMakeMap2(bag, lam).with_type(t)
                        yield m
예제 #19
0
def possibly_useful_nonrecursive(
    solver,
    e: Exp,
    context: Context,
    pool=RUNTIME_POOL,
    assumptions: Exp = ETRUE,
    ops: [Op] = ()) -> bool:
    """Heuristic filter to ignore expressions that are almost certainly useless.

    Arguments:
        solver      : used for `valid(formula)` (peel and big-map checks) and
                      `satisfiable(formula)` (min/max check); it is also
                      passed to extension handlers
        e           : the candidate expression
        context     : Context describing the variables in scope
        pool        : pool `e` would belong to; anything other than
                      RUNTIME_POOL is treated as a state position
        assumptions : extra facts, conjoined with `context.path_condition()`
        ops         : operations consulted by the min/max check (each is read
                      for `.body`, `.assumptions`, and `.name`)

    If an extension handler exists for `type(e)` and its `possibly_useful`
    verdict is falsy, that verdict is returned as-is; otherwise the generic
    checks below still run.

    Returns True if no check rejects `e`; otherwise an instance of No whose
    message describes the first check that rejected it.
    """

    assumptions = EAll([assumptions, context.path_condition()])
    at_runtime = pool == RUNTIME_POOL

    h = extension_handler(type(e))
    if h is not None:
        res = h.possibly_useful(e, context, pool, assumptions, ops, solver)
        if not res:
            return res

    if isinstance(e, EStateVar) and not free_vars(e.e):
        return No("constant value in state position")
    if isinstance(e, (EDropFront, EDropBack)) and not at_runtime:
        return No("EDrop* in state position")
    if not allow_big_sets.value and isinstance(e, EFlatMap) and not at_runtime:
        return No("EFlatMap in state position")
    if (not allow_int_arithmetic_state.value and not at_runtime
            and isinstance(e, EBinOp) and e.type == INT):
        return No("integer arithmetic in state position")
    if is_collection(e.type) and not is_scalar(e.type.elem_type):
        return No("collection of nonscalar: e {}\n elem_type: {}\n".format(
            e, e.type.elem_type))
    if isinstance(e.type, TMap) and not is_scalar(e.type.k):
        return No("bad key type {}".format(pprint(e.type.k)))
    if isinstance(e.type, TMap) and isinstance(e.type.v, TMap):
        return No("map to map")
    # This check is probably a bad idea: whether `the` is legal may depend on
    # the context that the expression is embedded within, so we can't skip it
    # during synthesis just because it looks invalid now.
    # if isinstance(e, EUnaryOp) and e.op == UOp.The:
    #     len = EUnaryOp(UOp.Length, e.e).with_type(INT)
    #     if not valid(EImplies(assumptions, EBinOp(len, "<=", ENum(1).with_type(INT)).with_type(BOOL))):
    #         return No("illegal application of 'the': could have >1 elems")
    if (not at_runtime and isinstance(e, EBinOp) and e.op == "-"
            and is_collection(e.type)):
        return No("collection subtraction in state position")
    # if not at_runtime and isinstance(e, ESingleton):
    #     return No("singleton in state position")
    if (not allow_nonzero_state_constants.value and not at_runtime
            and isinstance(e, ENum) and e.val != 0):
        return No("nonzero integer constant in state position")
    if (not allow_binop_state.value and at_runtime
            and isinstance(e, EStateVar) and isinstance(e.e, EBinOp)
            and is_scalar(e.e.e1.type) and is_scalar(e.e.e2.type)):
        return No(
            "constant-time binary operator {!r} in state position".format(
                e.e.op))
    if not allow_conditional_state.value and not at_runtime and isinstance(
            e, ECond):
        return No("conditional in state position")
    if isinstance(e, EMakeMap2) and isinstance(e.e, EEmptyList):
        return No("trivially empty map")
    if isinstance(e, EMakeMap2) and isinstance(e.e, ESingleton):
        return No("really tiny map")
    if not at_runtime and isinstance(e, (EArgMin, EArgMax)):
        # Cozy has no way to efficiently implement mins/maxes when more than
        # one element may leave the collection.
        from cozy.state_maintenance import mutate
        elems = e.e
        for op in ops:
            elems_prime = mutate(elems, op.body)
            # Satisfiable => some execution of `op` removes >1 element.
            formula = EAll([assumptions] + list(op.assumptions) + [
                EGt(
                    ELen(
                        EBinOp(elems, "-", elems_prime).with_type(elems.type)),
                    ONE)
            ])
            if solver.satisfiable(formula):
                return No(
                    "more than one element might be removed during {}".format(
                        op.name))
    if not allow_peels.value and not at_runtime and isinstance(e, EFilter):
        # Catch "peels": filters that remove zero or one elements, i.e. the
        # elements the predicate rejects provably number at most one.
        rejected = EFilter(
            e.e,
            ELambda(e.predicate.arg, ENot(
                e.predicate.body))).with_type(e.type)
        if solver.valid(EImplies(assumptions, ELe(ELen(rejected), ONE))):
            return No("filter is a peel")
    if (not allow_big_maps.value and not at_runtime
            and isinstance(e, EMakeMap2) and is_collection(e.type.v)):
        # Reject the map unless the total size of its values is provably
        # bounded by the combined length of all state collections.
        # (State vars are only needed here, so they are collected lazily.)
        state_vars = OrderedSet(
            v for v, p in context.vars() if p == STATE_POOL)
        all_collections = [sv for sv in state_vars if is_collection(sv.type)]
        total_size = ENum(0).with_type(INT)
        for c in all_collections:
            total_size = EBinOp(total_size, "+",
                                EUnaryOp(UOp.Length,
                                         c).with_type(INT)).with_type(INT)
        my_size = EUnaryOp(
            UOp.Length,
            EFlatMap(
                EUnaryOp(UOp.Distinct, e.e).with_type(e.e.type),
                e.value_function).with_type(e.type.v)).with_type(INT)
        s = EImplies(assumptions,
                     EBinOp(total_size, ">=", my_size).with_type(BOOL))
        if not solver.valid(s):
            return No("non-polynomial-sized map")

    return True
예제 #20
0
파일: acceleration.py 프로젝트: uwplse/cozy
def _try_optimize(e : Exp, context : Context, pool : Pool):
    """Yield candidate rewrites of `e` produced by the acceleration rules below.

    Nothing is yielded unless the `accelerate` option is enabled and `pool` is
    RUNTIME_POOL. Every candidate is passed through `_check` before it is
    yielded, so a debugging hook placed in `_check` sees the output of every
    rule.
    """
    if not accelerate.value:
        return

    if pool != RUNTIME_POOL:
        return

    state_vars = [v for v, p in context.vars() if p == STATE_POOL]
    args = [v for v, p in context.vars() if p == RUNTIME_POOL]

    # ---------------------------------------------------------------------
    # "Rewrite schemes": these trigger on many different AST shapes
    # They are listed first because they are more powerful than the
    # specific rewrite rules below.

    # Closed expressions (no free vars or funcs) are evaluated now and
    # replaced by the expression form of their value.
    if not free_vars(e) and not free_funcs(e):
        try:
            yield _check(uneval(e.type, eval(e, {})), context, RUNTIME_POOL)
        except NotImplementedError:
            print("Unable to evaluate {!r}".format(e))

    # Expressions mentioning only state variables can be wrapped as one
    # state expression.
    if all(v in state_vars for v in free_vars(e)):
        nsv = strip_EStateVar(e)
        sv = EStateVar(nsv).with_type(e.type)
        yield _check(sv, context, RUNTIME_POOL)

    for ee in fold_into_map(e, context):
        yield _check(ee, context, pool)

    # ---------------------------------------------------------------------
    # "Rewrites": these trigger on specific AST nodes

    if isinstance(e, EBinOp):

        if e.op == "-" and is_collection(e.type):
            ee = optimized_bag_difference(e.e1, e.e2)
            yield _check(ee, context, RUNTIME_POOL)

        if e.op == "===" and isinstance(e.e1.type, THandle):
            # Compare handles by both address and value.
            yield _check(EAll([
                optimized_eq(optimized_addr(e.e1), optimized_addr(e.e2)),
                optimized_eq(optimized_val(e.e1),  optimized_val(e.e2)).with_type(BOOL)]), context, RUNTIME_POOL)

        if e.op == BOp.In:
            ee = optimized_in(e.e1, e.e2)
            yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, ECond):
        yield _check(optimized_cond(e.cond, e.then_branch, e.else_branch), context, RUNTIME_POOL)

    if isinstance(e, EGetField):
        for ee in optimized_get_field(e.e, e.field_name, args):
            yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, EListGet) and e.index == ZERO:
        for res in optimized_the(e.e, args):
            yield _check(res, context, RUNTIME_POOL)

    if isinstance(e, EListGet) and isinstance(e.e, ECond):
        # Push the index into both branches of the conditional.
        yield _check(optimized_cond(e.e.cond,
                                    EListGet(e.e.then_branch, e.index).with_type(e.type),
                                    EListGet(e.e.else_branch, e.index).with_type(e.type)),
                     context, RUNTIME_POOL)

    # Local import kept from the original layout (presumably to avoid an
    # import cycle); all tree-multiset names needed below are imported once.
    from cozy.structures.treemultiset import (
        EMakeMaxTreeMultiset, EMakeMinTreeMultiset, ETreeMultisetElems,
        ETreeMultisetPeek, TMaxTreeMultiset, TMinTreeMultiset)
    if isinstance(e, EListGet) and isinstance(e.e, ETreeMultisetElems):
        yield _check(ETreeMultisetPeek(e.e.e, e.index).with_type(e.type), context, RUNTIME_POOL)

    if isinstance(e, EMapGet):
        ee = inline_mapget(e, context)
        yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, EUnaryOp):

        if e.op == UOp.Sum:
            for ee in optimized_sum(e.e, args):
                yield _check(ee, context, RUNTIME_POOL)

        if e.op == UOp.Length:
            ee = optimized_len(e.e)
            yield _check(ee, context, RUNTIME_POOL)

        if e.op == UOp.Empty:
            ee = optimized_empty(e.e)
            yield _check(ee, context, RUNTIME_POOL)

        if e.op == UOp.Exists:
            ee = optimized_exists(e.e)
            yield _check(ee, context, RUNTIME_POOL)

        if e.op == UOp.Distinct:
            for ee in optimized_distinct(e.e, args):
                yield _check(ee, context, RUNTIME_POOL)

        if e.op == UOp.The:
            for ee in optimized_the(e.e, args):
                yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, (EArgMin, EArgMax)):
        for ee in optimized_best(e.e, e.key_function, "<" if isinstance(e, EArgMin) else ">", args=args):
            yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, EFilter):
        for ee in optimized_filter(e.e, e.predicate, args=args):
            yield _check(ee, context, RUNTIME_POOL)

    if isinstance(e, EMap):
        for ee in optimized_map(e.e, e.transform_function, args=args):
            yield _check(ee, context, RUNTIME_POOL)

    from cozy.syntax import ESorted
    if isinstance(e, ESorted) and isinstance(e.e, EStateVar):
        # Sorting a state collection: read the elements of a min- or
        # max-ordered tree multiset built in the state instead, choosing
        # between them with `e.asc`.
        elem_type = e.e.e.type.elem_type
        e_max = EMakeMaxTreeMultiset(e.e.e).with_type(TMaxTreeMultiset(elem_type))
        e_min = EMakeMinTreeMultiset(e.e.e).with_type(TMinTreeMultiset(elem_type))
        ee = optimized_cond(e.asc,
                            ETreeMultisetElems(EStateVar(e_min).with_type(e_min.type)).with_type(e.type),
                            ETreeMultisetElems(EStateVar(e_max).with_type(e_max.type)).with_type(e.type))
        yield _check(ee, context, RUNTIME_POOL)
예제 #21
0
파일: core.py 프로젝트: uwplse/cozy
def possibly_useful_nonrecursive(solver, e : Exp, context : Context, pool = RUNTIME_POOL, assumptions : Exp = ETRUE, ops : [Op] = ()) -> bool:
    """Heuristic filter to ignore expressions that are almost certainly useless.

    Only `e` itself is inspected; its children are not visited. The checks
    run in a fixed order, and the first one that fires determines the reason
    that is returned.

    Parameters:
        solver      - object providing `valid(formula)` and
                      `satisfiable(formula)`. It is used by the checks that
                      reason about collection sizes.
        e           - the candidate expression.
        context     - supplies `(var, pool)` pairs via `vars()` and the path
                      condition at the position of `e` via `path_condition()`.
        pool        - STATE_POOL or RUNTIME_POOL: the position `e` would
                      occupy.
        assumptions - extra facts known to hold. They are conjoined with the
                      context's path condition before any solver query.
        ops         - operations consulted only for EArgMin/EArgMax in state
                      position. Each one must expose `body`, `assumptions`
                      and `name`.

    Returns True if `e` survives every check. Otherwise it returns either a
    `No` describing why `e` was rejected, or the falsy result produced by the
    extension handler registered for `type(e)`.
    """

    state_vars  = OrderedSet(v for v, p in context.vars() if p == STATE_POOL)
    assumptions = EAll([assumptions, context.path_condition()])
    at_runtime  = pool == RUNTIME_POOL

    # Extension expression types get a say first. A falsy verdict from the
    # handler is final; a truthy one still goes through the generic checks.
    h = extension_handler(type(e))
    if h is not None:
        res = h.possibly_useful(e, context, pool, assumptions, ops, solver)
        if not res:
            return res

    if isinstance(e, EStateVar) and not free_vars(e.e):
        return No("constant value in state position")
    if isinstance(e, (EDropFront, EDropBack)) and not at_runtime:
        return No("EDrop* in state position")
    if not allow_big_sets.value and isinstance(e, EFlatMap) and not at_runtime:
        return No("EFlatMap in state position")
    if not allow_int_arithmetic_state.value and not at_runtime and isinstance(e, EBinOp) and e.type == INT:
        return No("integer arithmetic in state position")
    if is_collection(e.type) and not is_scalar(e.type.elem_type):
        return No("collection of nonscalar: e {}\n elem_type: {}\n".format(e, e.type.elem_type))
    if isinstance(e.type, TMap) and not is_scalar(e.type.k):
        return No("bad key type {}".format(pprint(e.type.k)))
    if isinstance(e.type, TMap) and isinstance(e.type.v, TMap):
        return No("map to map")
    # This check is probably a bad idea: whether `the` is legal may depend on
    # the context that the expression is embedded within, so we can't skip it
    # during synthesis just because it looks invalid now.
    # if isinstance(e, EUnaryOp) and e.op == UOp.The:
    #     len = EUnaryOp(UOp.Length, e.e).with_type(INT)
    #     if not valid(EImplies(assumptions, EBinOp(len, "<=", ENum(1).with_type(INT)).with_type(BOOL))):
    #         return No("illegal application of 'the': could have >1 elems")
    if not at_runtime and isinstance(e, EBinOp) and e.op == "-" and is_collection(e.type):
        return No("collection subtraction in state position")
    # if not at_runtime and isinstance(e, ESingleton):
    #     return No("singleton in state position")
    if not allow_nonzero_state_constants.value and not at_runtime and isinstance(e, ENum) and e.val != 0:
        return No("nonzero integer constant in state position")
    if not allow_binop_state.value and at_runtime and isinstance(e, EStateVar) and isinstance(e.e, EBinOp) and is_scalar(e.e.e1.type) and is_scalar(e.e.e2.type):
        return No("constant-time binary operator {!r} in state position".format(e.e.op))
    if not allow_conditional_state.value and not at_runtime and isinstance(e, ECond):
        return No("conditional in state position")
    if isinstance(e, EMakeMap2) and isinstance(e.e, EEmptyList):
        return No("trivially empty map")
    if isinstance(e, EMakeMap2) and isinstance(e.e, ESingleton):
        return No("really tiny map")
    if not at_runtime and isinstance(e, (EArgMin, EArgMax)):
        # Cozy has no way to efficiently implement mins/maxes when more than
        # one element may leave the collection.
        # Imported locally; presumably this avoids a circular import.
        from cozy.state_maintenance import mutate
        elems = e.e
        for op in ops:
            elems_prime = mutate(elems, op.body)
            # Elements of `elems` that are no longer present after `op` runs.
            removed = EBinOp(elems, "-", elems_prime).with_type(elems.type)
            formula = EAll([assumptions] + list(op.assumptions) + [EGt(ELen(removed), ONE)])
            if solver.satisfiable(formula):
                return No("more than one element might be removed during {}".format(op.name))
    if not allow_peels.value and not at_runtime and isinstance(e, EFilter):
        # Catch "peels", i.e. filters that remove zero or one elements. That
        # holds exactly when the complementary filter keeps at most one.
        rejected = EFilter(e.e, ELambda(e.predicate.arg, ENot(e.predicate.body))).with_type(e.type)
        if solver.valid(EImplies(assumptions, ELe(ELen(rejected), ONE))):
            return No("filter is a peel")
    if not allow_big_maps.value and not at_runtime and isinstance(e, EMakeMap2) and is_collection(e.type.v):
        # The map's total value size must be bounded by the combined size of
        # all collection-typed state variables.
        total_size = ENum(0).with_type(INT)
        for c in state_vars:
            if is_collection(c.type):
                total_size = EBinOp(total_size, "+", EUnaryOp(UOp.Length, c).with_type(INT)).with_type(INT)
        my_size = EUnaryOp(UOp.Length, EFlatMap(EUnaryOp(UOp.Distinct, e.e).with_type(e.e.type), e.value_function).with_type(e.type.v)).with_type(INT)
        s = EImplies(
            assumptions,
            EBinOp(total_size, ">=", my_size).with_type(BOOL))
        if not solver.valid(s):
            return No("non-polynomial-sized map")

    return True