Пример #1
0
 def __init__(self,
              assumptions: Exp = T,
              examples=(),
              funcs=(),
              freebies: [Exp] = [],
              ops: [Op] = []):
     """Set up a cost model and the solver that backs its comparisons.

     Args:
         assumptions: expression assumed to be true whenever two
             expressions are compared.
         examples: initial examples handed to the solver.  The right set
             can speed up some cost comparisons; leaving it empty is
             always safe.
         funcs: the in-scope functions.
         freebies: state variables that can be used for free.
         ops: mutators used to determine how expensive it is to maintain
             a state variable.

     Note that `freebies` and `ops` are stored by reference, not copied.
     """
     # Plain bookkeeping first; none of it depends on the solver.
     self.assumptions = assumptions
     self.freebies = freebies
     self.ops = ops
     # The solver receives `funcs` before it is copied into an ordered map.
     solver = ModelCachingSolver(
         vars=(),
         funcs=funcs,
         examples=examples,
         assumptions=assumptions)
     self.solver = solver
     # The examples are not stored on the instance; they are only passed
     # to the solver above.
     self.funcs = OrderedDict(funcs)
Пример #2
0
def queries_equivalent(q1: Query,
                       q2: Query,
                       state_vars: [EVar],
                       extern_funcs: {str: TFunc},
                       assumptions: Exp = T):
    """Decide whether two queries can stand in for one another.

    The answer is False as soon as the return types differ or the argument
    dictionaries differ.  Otherwise a solver is asked whether the queries'
    assumptions are equal and, if so, whether the return values are equal
    whenever the assumptions of `q1` hold.

    Args:
        q1, q2: the queries to compare.
        state_vars: variables given to a newly created solver, ahead of
            the query arguments.
        extern_funcs: functions given to a newly created solver.
        assumptions: assumptions given to a newly created solver.

    Returns:
        False for mismatched signatures, otherwise the outcome of the
        solver's validity checks.
    """
    with task("checking query equivalence", q1=q1.name, q2=q2.name):
        if q1.ret.type != q2.ret.type:
            return False
        first_args = dict(q1.args)
        second_args = dict(q2.args)
        if first_args != second_args:
            return False
        args = FrozenDict(first_args)

        # Solvers are shared between calls through the module-level cache.
        # NOTE(review): the key ignores `state_vars` and `extern_funcs`, so
        # a cached solver is reused even if those differ between calls --
        # confirm that callers always pass consistent values.
        cache_key = (args, assumptions)
        checker = _qe_cache.get(cache_key)
        if checker is None:
            arg_vars = (EVar(name).with_type(ty) for name, ty in args.items())
            checker = ModelCachingSolver(
                vars=list(itertools.chain(state_vars, arg_vars)),
                funcs=extern_funcs,
                assumptions=assumptions)
            _qe_cache[cache_key] = checker

        pre1 = EAll(q1.assumptions)
        pre2 = EAll(q2.assumptions)
        same_preconditions = EEq(pre1, pre2)
        # The second formula is only built if the first check succeeds.
        return checker.valid(same_preconditions) and checker.valid(
            EImplies(pre1, EEq(q1.ret, q2.ret)))
Пример #3
0
    def __init__(self, spec: Spec, concretization_functions: [(EVar, Exp)],
                 query_specs: [Query], query_impls: OrderedDict,
                 updates: defaultdict, handle_updates: defaultdict):
        """Build an implementation object from its parts.

        Treat this constructor as private to the `impls` module and call
        `construct_initial_implementation` instead: it is easy to create a
        malformed implementation by calling this constructor directly.

        Args:
            spec: the original specification.
            concretization_functions: (v, e) pairs, each saying that this
                implementation stores a state variable `v` whose value
                tracks `e`, a function of the state in `spec`.
            query_specs: specifications, in terms of the state in `spec`,
                for every query in `spec` and for any private helper
                queries that have been added.
            query_impls: map from query name to implementation, covering
                every query in `spec` and the private helper queries.  The
                implementations are in terms of the state variables in
                `concretization_functions`.
            updates: map from (concrete_var_name, op_name) to a statement.
                The variable is one described by
                `concretization_functions` and the op is an update
                operation in `spec`.  The statement may use private helper
                queries; it exists only to maintain the relationship
                between the variable and the state in the specification,
                and is written in terms of the original state, before the
                update started executing.
            handle_updates: map from (handle_type, op_name) to a statement,
                where handle_type is a THandle instance and the op is an
                update operation in `spec`.  The statement updates the
                value of every reachable instance of that handle (aka
                pointer) and, like `updates`, is in terms of the original
                state before the update started executing.
        """
        self.spec = spec
        self._concretization_functions = concretization_functions
        self.query_specs = query_specs
        self.query_impls = query_impls
        self.updates = updates
        self.handle_updates = handle_updates

        # Read the solver inputs only after the attributes above are set.
        solver_vars = self.abstract_state
        solver_funcs = self.extern_funcs
        solver_assumptions = EAll(spec.assumptions)
        self.state_solver = ModelCachingSolver(
            vars=solver_vars,
            funcs=solver_funcs,
            assumptions=solver_assumptions)
Пример #4
0
 def __init__(self,
              assumptions: Exp = T,
              examples=(),
              funcs=(),
              freebies: [Exp] = []):
     """Set up a cost model and the solver that backs its comparisons.

     Args:
         assumptions: stored on the instance and passed to the solver.
         examples: passed to the solver only; not stored on the instance.
         funcs: passed to the solver, then copied into an OrderedDict.
         freebies: stored by reference on the instance.
     """
     self.assumptions = assumptions
     self.freebies = freebies
     # The solver receives `funcs` before it is copied into an ordered map.
     solver = ModelCachingSolver(
         vars=(),
         funcs=funcs,
         examples=examples,
         assumptions=assumptions)
     self.solver = solver
     self.funcs = OrderedDict(funcs)
Пример #5
0
 def __init__(self, spec: Spec, concrete_state: [(EVar, Exp)],
              query_specs: [Query], query_impls: OrderedDict,
              updates: defaultdict, handle_updates: defaultdict):
     """Store the parts of an implementation and create its state solver.

     Args:
         spec: the specification; its assumptions are given to the solver.
         concrete_state: (variable, expression) pairs.
         query_specs: list of queries.
         query_impls: ordered map of query implementations.
         updates: maps (concrete_var_name, op_name) to stm.
         handle_updates: maps (handle_type, op_name) to stm.
     """
     self.spec = spec
     self.concrete_state = concrete_state
     self.query_specs = query_specs
     self.query_impls = query_impls
     self.updates = updates
     self.handle_updates = handle_updates

     # Read the solver inputs only after the attributes above are set.
     solver_vars = self.abstract_state
     solver_funcs = self.extern_funcs
     solver_assumptions = EAll(spec.assumptions)
     self.state_solver = ModelCachingSolver(
         vars=solver_vars,
         funcs=solver_funcs,
         assumptions=solver_assumptions)
Пример #6
0
class CostModel(object):
    """Orders expressions by cost, consulting a solver where needed."""

    def __init__(self,
                 assumptions: Exp = T,
                 examples=(),
                 funcs=(),
                 freebies: [Exp] = []):
        """Set up the cost model and the solver behind its comparisons.

        Args:
            assumptions: stored on the instance and passed to the solver.
            examples: passed to the solver only; read them back through
                the `examples` property.
            funcs: passed to the solver, then copied into an OrderedDict.
            freebies: stored by reference; handed to the storage-size
                helpers during comparisons.
        """
        self.assumptions = assumptions
        self.freebies = freebies
        # The solver receives `funcs` before it is copied into an ordered map.
        solver = ModelCachingSolver(
            vars=(),
            funcs=funcs,
            examples=examples,
            assumptions=assumptions)
        self.solver = solver
        self.funcs = OrderedDict(funcs)

    @property
    def examples(self):
        """The solver's current examples, as a tuple."""
        return tuple(self.solver.examples)

    def _compare(self, e1: Exp, e2: Exp, context: Context):
        """Order two expressions under the context's path conditions.

        Two constant expressions (no free variables, no free functions) are
        evaluated and their values ordered directly.  Alpha-equivalent
        expressions are EQUAL.  Otherwise the solver is asked whether
        `e1 <= e2` and `e1 >= e2` always hold: both give EQUAL, only the
        first gives LT, only the second gives GT, neither gives AMBIGUOUS.
        """
        first_is_constant = not (free_vars(e1) or free_funcs(e1))
        second_is_constant = not (free_vars(e2) or free_funcs(e2))
        if first_is_constant and second_is_constant:
            value1 = eval(e1, {})
            value2 = eval(e2, {})
            event("comparison obvious on constants: {} vs {}".format(value1, value2))
            return order_objects(value1, value2)
        if alpha_equivalent(e1, e2):
            event("shortcutting comparison of identical terms")
            return Order.EQUAL

        guard = EAll(context.path_conditions())
        # Both questions are always put to the solver, in this order.
        never_greater = self.solver.valid(EImplies(guard, ELe(e1, e2)))
        never_smaller = self.solver.valid(EImplies(guard, EGe(e1, e2)))

        if never_greater:
            return Order.EQUAL if never_smaller else Order.LT
        return Order.GT if never_smaller else Order.AMBIGUOUS

    def compare(self, e1: Exp, e2: Exp, context: Context, pool: Pool) -> Order:
        """Compare the cost of `e1` and `e2` in the given pool.

        In the runtime pool, `composite_order` receives, in order:
        asymptotic runtime, maximum storage size, runtime, and expression
        size.  In any other pool it receives storage size and then
        expression size.
        """
        def by_size():
            return order_objects(e1.size(), e2.size())

        with task("compare costs", context=context):
            if pool == RUNTIME_POOL:
                def by_asymptotic_runtime():
                    return order_objects(asymptotic_runtime(e1),
                                         asymptotic_runtime(e2))

                def by_max_storage():
                    return self._compare(
                        max_storage_size(e1, self.freebies),
                        max_storage_size(e2, self.freebies),
                        context)

                def by_runtime():
                    return self._compare(rt(e1), rt(e2), context)

                return composite_order(by_asymptotic_runtime,
                                       by_max_storage,
                                       by_runtime,
                                       by_size)

            def by_storage():
                return self._compare(storage_size(e1, self.freebies),
                                     storage_size(e2, self.freebies),
                                     context)

            return composite_order(by_storage, by_size)
Пример #7
0
    def __init__(self,
            spec : Spec,
            concretization_functions : [(EVar, Exp)],
            query_specs : [Query],
            query_impls : OrderedDict,
            updates : defaultdict,
            handle_updates : defaultdict):
        """Build an implementation object from its parts.

        Treat this constructor as private to the `impls` module and call
        `construct_initial_implementation` instead: it is easy to create a
        malformed implementation by calling this constructor directly.

        Args:
            spec: the original specification.
            concretization_functions: (v, e) pairs, each saying that this
                implementation stores a state variable `v` whose value
                tracks `e`, a function of the state in `spec`.
            query_specs: specifications, in terms of the state in `spec`,
                for every query in `spec` and for any private helper
                queries that have been added.
            query_impls: map from query name to implementation, covering
                every query in `spec` and the private helper queries.  The
                implementations are in terms of the state variables in
                `concretization_functions`.
            updates: map from (concrete_var_name, op_name) to a statement.
                The variable is one described by
                `concretization_functions` and the op is an update
                operation in `spec`.  The statement may use private helper
                queries; it exists only to maintain the relationship
                between the variable and the state in the specification,
                and is written in terms of the original state, before the
                update started executing.
            handle_updates: map from (handle_type, op_name) to a statement,
                where handle_type is a THandle instance and the op is an
                update operation in `spec`.  The statement updates the
                value of every reachable instance of that handle (aka
                pointer) and, like `updates`, is in terms of the original
                state before the update started executing.
        """
        self.spec = spec
        self._concretization_functions = concretization_functions
        self.query_specs = query_specs
        self.query_impls = query_impls
        self.updates = updates
        self.handle_updates = handle_updates

        # Read the solver inputs only after the attributes above are set.
        solver_vars = self.abstract_state
        solver_funcs = self.extern_funcs
        solver_assumptions = EAll(spec.assumptions)
        self.state_solver = ModelCachingSolver(
            vars=solver_vars,
            funcs=solver_funcs,
            assumptions=solver_assumptions)
Пример #8
0
 def __init__(self, targets, assumptions, context, examples, cost_model,
              stop_callback, hints):
     """Store the configuration and prepare the initial state.

     The context, stop callback, cost model and assumptions are stored on
     the instance as given; `hints` is copied into a new list.  Then
     `self.reset(examples)` and `self.watch(targets)` are called, in that
     order, and finally a solver over the context's variables and functions
     is created and stored as `self.wf_solver`.
     """
     self.context = context
     self.stop_callback = stop_callback
     self.cost_model = cost_model
     self.assumptions = assumptions
     self.hints = list(hints)
     # These run before the solver exists, with all attributes above set.
     self.reset(examples)
     self.watch(targets)
     # context.vars() yields pairs; only the first component of each is
     # given to the solver.
     # NOTE(review): presumably used for well-formedness checks -- confirm.
     self.wf_solver = ModelCachingSolver(
         vars=[v for (v, p) in context.vars()], funcs=context.funcs())
Пример #9
0
def exp_wf(e : Exp, context : Context, pool = RUNTIME_POOL, assumptions : Exp = T, solver = None):
    """
    Check every piece of `e` produced by `shred` for well-formedness.

    Returns True, or raises ExpIsNotWf naming `e`, the offending piece and
    the reason reported for it.  A new solver is created when none is
    given.  For each piece, `assumptions` is adapted from `context` to the
    context of that piece.
    """
    checker = ModelCachingSolver(vars=[], funcs={}) if solver is None else solver
    for piece, piece_context, piece_pool in shred(e, context, pool):
        try:
            exp_wf_nonrecursive(
                checker,
                piece,
                piece_context,
                piece_pool,
                assumptions=piece_context.adapt(assumptions, context))
        except ExpIsNotWf as failure:
            # Re-raise against the whole expression, keeping the reason.
            raise ExpIsNotWf(e, piece, failure.reason)
    return True
Пример #10
0
def exp_wf(e: Exp,
           context: Context,
           pool=RUNTIME_POOL,
           assumptions: Exp = ETRUE,
           solver=None):
    """Check the well-formedness of `e`.

    Every subexpression of `e` is checked in turn and the first failure
    ends the check.

    Parameters:
        e - an expression to check
        context - a context describing e's variables
        pool - what pool e lives in
        assumptions - facts that are true whenever e begins executing
            (NOTE: this does NOT need to include the path conditions from
            the context, but it is fine if it does.)
        solver - a ModelCachingSolver to use for solving formulas; a new
            one is created when this is None

    Returns:
        True if no check fails.  If a check fails with a `No` result, an
        ExpIsNotWf instance built from `e`, the failing subexpression and
        the result's message.  Any other falsy check result is returned
        unchanged.

    This function requires that:
     - all free variables in `e` are used in the correct pool
     - EStateVar only occurs in runtime expressions
    """
    checker = solver if solver is not None else ModelCachingSolver(vars=[], funcs={})
    pieces = all_subexpressions_with_context_information(e, context, pool)
    for piece, piece_context, piece_pool in pieces:
        verdict = exp_wf_nonrecursive(
            checker,
            piece,
            piece_context,
            piece_pool,
            assumptions=piece_context.adapt(assumptions, context))
        if verdict:
            continue
        if isinstance(verdict, No):
            return ExpIsNotWf(e, piece, verdict.msg)
        return verdict
    return True
Пример #11
0
class Implementation(object):

    @typechecked
    def __init__(self,
            spec : Spec,
            concrete_state : [(EVar, Exp)],
            query_specs : [Query],
            query_impls : OrderedDict,
            updates : defaultdict,
            handle_updates : defaultdict):
        """
        Store the parts of an implementation and create its state solver.

        spec           : the specification
        concrete_state : (variable, expression) pairs
        query_specs    : list of queries
        query_impls    : ordered map of query implementations
        updates        : maps (concrete_var_name, op_name) to stm
        handle_updates : maps (handle_type, op_name) to stm
        """
        self.spec = spec
        self.concrete_state = concrete_state
        self.query_specs = query_specs
        self.query_impls = query_impls
        self.updates = updates
        self.handle_updates = handle_updates

        # Read the solver inputs only after the attributes above are set.
        solver_vars = self.abstract_state
        solver_funcs = self.extern_funcs
        self.state_solver = ModelCachingSolver(vars=solver_vars, funcs=solver_funcs)

    def __getstate__(self):
        """
        Return a copy of the instance dictionary without "state_solver",
        plus the value of every attribute named in `__slots__` (if any).
        """
        state = dict(self.__dict__)
        state.pop("state_solver", None)
        for slot in getattr(self, "__slots__", ()):
            state[slot] = getattr(self, slot)
        return state

    def add_query(self, q : Query):
        """
        Given a query in terms of abstract state, add an initial concrete
        implementation.

        The query is appended to `self.query_specs`.  Its return expression
        (with naked state variables wrapped) is passed through
        `tease_apart`, and the resulting representation and return
        expression are installed with `set_impl`.
        """
        print("Adding query {}...".format(q.name))
        self.query_specs.append(q)
        # initial rep
        rep, ret = tease_apart(wrap_naked_statevars(q.ret, self.abstract_state))
        self.set_impl(q, rep, ret)

    @property
    def op_specs(self):
        """The methods of the specification that are `Op` instances."""
        ops = []
        for method in self.spec.methods:
            if isinstance(method, Op):
                ops.append(method)
        return ops

    @property
    def abstract_state(self) -> [EVar]:
        """One typed variable per (name, type) pair in the spec's statevars."""
        state = []
        for (name, t) in self.spec.statevars:
            state.append(EVar(name).with_type(t))
        return state

    @property
    def extern_funcs(self) -> { str : TFunc }:
        """
        Ordered map from the name of each extern function in the spec to a
        `TFunc` built from its argument types and its output type.
        """
        signatures = OrderedDict()
        for f in self.spec.extern_funcs:
            arg_types = tuple(t for a, t in f.args)
            signatures[f.name] = TFunc(arg_types, f.out_type)
        return signatures

    def _add_subquery(self, sub_q : Query, used_by : Stm) -> Stm:
        """
        Make the helper query `sub_q` available to the statement `used_by`.

        A copy of `sub_q` is extended with implicit handle assumptions (if
        they do not already follow from its assumptions) and simplified.
        If query deduplication is enabled and an equivalent query already
        exists in `self.query_specs`, calls to `sub_q` inside `used_by` are
        rewritten to call that query instead.  Otherwise the copy is added
        with `add_query`.

        Returns `used_by`, rewritten if an equivalent query was found.
        """
        with task("adding query", query=sub_q.name):
            # work on a copy: assumptions and return value are replaced below
            sub_q = shallow_copy(sub_q)
            with task("checking whether we need more handle assumptions"):
                new_a = implicit_handle_assumptions_for_method(
                    reachable_handles_at_method(self.spec, sub_q),
                    sub_q)
                # only extend the assumptions if they do not already imply new_a
                if not valid(EImplies(EAll(sub_q.assumptions), EAll(new_a))):
                    event("we do!")
                    sub_q.assumptions = list(itertools.chain(sub_q.assumptions, new_a))

            with task("simplifying"):
                orig_a = sub_q.assumptions
                orig_a_size = sum(a.size() for a in sub_q.assumptions)
                orig_ret_size = sub_q.ret.size()
                sub_q.assumptions = tuple(simplify_or_ignore(a) for a in sub_q.assumptions)
                sub_q.ret = simplify(sub_q.ret)
                a_size = sum(a.size() for a in sub_q.assumptions)
                ret_size = sub_q.ret.size()
                event("|assumptions|: {} -> {}".format(orig_a_size, a_size))
                event("|ret|: {} -> {}".format(orig_ret_size, ret_size))

                # sanity check: simplification must not make the assumptions
                # bigger; if it does, print both versions and abort
                if a_size > orig_a_size:
                    print("NO, BAD SIMPLIFICATION")
                    print("original")
                    for a in orig_a:
                        print(" - {}".format(pprint(a)))
                    print("simplified")
                    for a in sub_q.assumptions:
                        print(" - {}".format(pprint(a)))
                    assert False

            state_vars = self.abstract_state
            funcs = self.extern_funcs
            # look for an existing equivalent query; never matches when
            # dedup_queries.value is falsy
            qq = find_one(self.query_specs, lambda qq: dedup_queries.value and queries_equivalent(qq, sub_q, state_vars=state_vars, extern_funcs=funcs))
            if qq is not None:
                event("subgoal {} is equivalent to {}".format(sub_q.name, qq.name))
                # for each argument of qq (in order), its position among the
                # arguments of sub_q
                arg_reorder = [[x[0] for x in sub_q.args].index(a) for (a, t) in qq.args]
                # rewrites calls to sub_q into calls to qq with the arguments
                # permuted accordingly; other calls are rebuilt unchanged
                class Repl(BottomUpRewriter):
                    def visit_ECall(self, e):
                        args = tuple(self.visit(a) for a in e.args)
                        if e.func == sub_q.name:
                            args = tuple(args[idx] for idx in arg_reorder)
                            return ECall(qq.name, args).with_type(e.type)
                        else:
                            return ECall(e.func, args).with_type(e.type)
                used_by = Repl().visit(used_by)
            else:
                self.add_query(sub_q)
            return used_by

    def _setup_handle_updates(self):
        """
        This method creates update code for handle objects modified by each op.
        Must be called once after all user-specified queries have been added.

        For every op and every handle type reachable at that op, the
        resulting statement is stored in
        `self.handle_updates[(handle_type, op.name)]`.
        """
        for op in self.op_specs:
            print("Setting up handle updates for {}...".format(op.name))
            # maps each handle type to a bag expression (see the loop below)
            handles = reachable_handles_at_method(self.spec, op)
            # print("-"*60)
            for t, bag in handles.items():
                # print("  {} : {}".format(pprint(t), pprint(bag)))
                # h stands for an arbitrary handle of type t; lval is its
                # value and new_val the value after applying the op's body
                h = fresh_var(t)
                lval = EGetField(h, "val").with_type(t.value_type)
                new_val = inc.mutate(lval, op.body)

                # get set of modified handles
                modified_handles = Query(
                    fresh_name("modified_handles"),
                    Visibility.Internal, [], op.assumptions,
                    EFilter(EUnaryOp(UOp.Distinct, bag).with_type(bag.type), ELambda(h, ENot(EEq(lval, new_val)))).with_type(bag.type),
                    "[{}] modified handles of type {}".format(op.name, pprint(t)))
                # free variables that are not abstract state become the
                # arguments of the helper query
                query_vars = [v for v in free_vars(modified_handles) if v not in self.abstract_state]
                modified_handles.args = [(arg.id, arg.type) for arg in query_vars]

                # modify each one
                subqueries = []
                state_update_stm = inc.mutate_in_place(
                    lval,
                    lval,
                    op.body,
                    abstract_state=self.abstract_state,
                    assumptions=list(op.assumptions) + [EDeepIn(h, bag), EIn(h, modified_handles.ret)],
                    subgoals_out=subqueries)
                # register the helper queries collected in subqueries,
                # tagging their docstrings with the op name
                for sub_q in subqueries:
                    sub_q.docstring = "[{}] {}".format(op.name, sub_q.docstring)
                    state_update_stm = self._add_subquery(sub_q=sub_q, used_by=state_update_stm)
                # only loop over the modified handles if there is work to do
                if state_update_stm != SNoOp():
                    state_update_stm = SForEach(h, ECall(modified_handles.name, query_vars).with_type(bag.type), state_update_stm)
                    state_update_stm = self._add_subquery(sub_q=modified_handles, used_by=state_update_stm)
                self.handle_updates[(t, op.name)] = state_update_stm

    def set_impl(self, q : Query, rep : [(EVar, Exp)], ret : Exp):
        """
        Install `ret` as the implementation of query `q`, backed by the new
        state variables in `rep`.

        State variables in `rep` that the state solver proves equal to an
        existing concrete state variable of the same type are dropped, and
        `ret` is rewritten to use the existing variable.  The remaining
        variables are appended to `self.concrete_state`, and for every op
        an update statement is stored in `self.updates[(var, op.name)]`.
        """
        with task("updating implementation", query=q.name):
            with task("finding duplicated state vars"):
                to_remove = set()
                for (v, e) in rep:
                    # an existing variable whose expression has the same type
                    # and is equal to e under the spec's assumptions
                    aeq = find_one(vv for (vv, ee) in self.concrete_state if e.type == ee.type and self.state_solver.valid(EImplies(EAll(self.spec.assumptions), EEq(e, ee))))
                    # aeq = find_one(vv for (vv, ee) in self.concrete_state if e.type == ee.type and alpha_equivalent(e, ee))
                    if aeq is not None:
                        event("state var {} is equivalent to {}".format(v.id, aeq.id))
                        ret = subst(ret, { v.id : aeq })
                        to_remove.add(v)
                rep = [ x for x in rep if x[0] not in to_remove ]

            self.concrete_state.extend(rep)
            # the implementation keeps q's signature but returns `ret`
            self.query_impls[q.name] = rewrite_ret(q, lambda prev: ret, keep_assumptions=False)

            for op in self.op_specs:
                with task("incrementalizing query", query=q.name, op=op.name):
                    for new_member, projection in rep:
                        subqueries = []
                        state_update_stm = inc.mutate_in_place(
                            new_member,
                            projection,
                            op.body,
                            abstract_state=self.abstract_state,
                            assumptions=op.assumptions,
                            subgoals_out=subqueries)
                        # register the helper queries collected in
                        # subqueries, tagging their docstrings with the op name
                        for sub_q in subqueries:
                            sub_q.docstring = "[{}] {}".format(op.name, sub_q.docstring)
                            state_update_stm = self._add_subquery(sub_q=sub_q, used_by=state_update_stm)
                        self.updates[(new_member, op.name)] = state_update_stm

    @property
    def code(self) -> Spec:
        """
        Assemble a `Spec` from the current implementation.

        The result has the name, types, extern functions, header, footer and
        docstring of the original spec.  Its state is the concrete state
        variables, and its methods are the query implementations followed by
        one new `Op` per op in the original spec.

        For each op, the update statements of the concrete state variables
        are ordered and temporaries are lifted to prevent read-after-write
        problems.  The order is computed separately for each op, and each op
        is emitted in its own order.
        """

        state_read_by_query = {
            query_name : free_vars(query)
            for query_name, query in self.query_impls.items() }

        # prevent read-after-write by lifting reads before writes.

        # maps op name to a list of SDecls
        temps = defaultdict(list)
        updates = dict(self.updates)
        # maps op name to the order in which its state updates are emitted;
        # the temporaries of an op are only valid for that op's own order
        ordered_state_by_op = {}

        for operator in self.op_specs:
            # Compute order constraints between statements:
            #   v1 -> v2 means that the update code for v1 should (if possible)
            #   appear before the update code for v2
            #   (i.e. the update code for v1 reads v2)
            g = igraph.Graph().as_directed()
            g.add_vertices(len(self.concrete_state))
            for (i, (v1, _)) in enumerate(self.concrete_state):
                v1_update_code = self.updates[(v1, operator.name)]
                v1_queries = list(self.queries_used_by(v1_update_code))
                for (j, (v2, _)) in enumerate(self.concrete_state):
                    # if v1_update_code reads v2...
                    if any(v2 in state_read_by_query[q] for q in v1_queries):
                        # then v1->v2
                        g.add_edges([(i, j)])

            # Find the minimum set of edges we need to break (see "feedback arc
            # set problem")
            edges_to_break = safe_feedback_arc_set(g, method="ip")
            g.delete_edges(edges_to_break)
            ordered_concrete_state = [self.concrete_state[i] for i in g.topological_sorting(mode="OUT")]
            ordered_state_by_op[operator.name] = ordered_concrete_state

            # Lift auxiliary declarations as needed
            things_updated = []
            for v, _ in ordered_concrete_state:
                things_updated.append(v)
                stm = updates[(v, operator.name)]
                # an expression is problematic if it calls a query that reads
                # a state variable which has already been updated
                def problematic(e):
                    for x in all_exps(e):
                        if isinstance(x, ECall) and x.func in [q.name for q in self.query_specs]:
                            problems = set(things_updated) & state_read_by_query[x.func]
                            if problems:
                                return True
                    return False
                stm = pull_temps(stm,
                    decls_out=temps[operator.name],
                    exp_is_bad=problematic)
                updates[(v, operator.name)] = stm

        # construct new op implementations
        new_ops = []
        for op in self.op_specs:

            # use the order computed for this op, not the last one computed
            stms = [ updates[(v, op.name)] for (v, _) in ordered_state_by_op[op.name] ]
            stms.extend(hup for ((t, op_name), hup) in self.handle_updates.items() if op.name == op_name)
            new_stms = seq(temps[op.name] + stms)
            new_ops.append(Op(
                op.name,
                op.args,
                [],
                new_stms,
                op.docstring))

        # assemble final result
        return Spec(
            self.spec.name,
            self.spec.types,
            self.spec.extern_funcs,
            [(v.id, e.type) for (v, e) in self.concrete_state],
            [],
            list(self.query_impls.values()) + new_ops,
            self.spec.header,
            self.spec.footer,
            self.spec.docstring)

    @property
    def concretization_functions(self) -> { str : Exp }:
        """
        Ordered map from the id of each concrete state variable to the
        expression paired with it in `self.concrete_state`.
        """
        return OrderedDict((var.id, exp) for (var, exp) in self.concrete_state)

    def cleanup(self):
        """
        Remove unused state, queries, and updates.

        Starting from the public queries, the set of queries to keep is
        grown until nothing changes.  It gains the queries called by kept
        query implementations, by handle updates, and by the updates of
        state variables that kept queries read.  Everything else is then
        removed from `query_specs`, `query_impls`, `concrete_state` and
        `updates`.
        """

        # sort of like mark-and-sweep
        queries_to_keep = OrderedSet(q.name for q in self.query_specs if q.visibility == Visibility.Public)
        state_vars_to_keep = OrderedSet()
        changed = True
        # iterate to a fixpoint: stop once a full pass adds nothing
        while changed:
            changed = False
            for qname in list(queries_to_keep):
                if qname in self.query_impls:
                    # keep every variable the implementation reads...
                    for sv in free_vars(self.query_impls[qname]):
                        if sv not in state_vars_to_keep:
                            state_vars_to_keep.add(sv)
                            changed = True
                    # ...and every query its return expression calls
                    for e in all_exps(self.query_impls[qname].ret):
                        if isinstance(e, ECall):
                            if e.func not in queries_to_keep:
                                queries_to_keep.add(e.func)
                                changed = True
            for op in self.op_specs:
                # queries used by this op's handle updates are always kept
                for ((ht, op_name), code) in self.handle_updates.items():
                    if op.name == op_name:
                        for qname in self.queries_used_by(code):
                            if qname not in queries_to_keep:
                                queries_to_keep.add(qname)
                                changed = True

                # queries used to maintain a kept variable are kept too
                # NOTE(review): `updates` is annotated as a defaultdict in the
                # constructor, so this lookup may create entries for missing
                # keys -- confirm that is intended.
                for sv in state_vars_to_keep:
                    for qname in self.queries_used_by(self.updates[(sv, op.name)]):
                        if qname not in queries_to_keep:
                            queries_to_keep.add(qname)
                            changed = True

        # remove old specs
        for q in list(self.query_specs):
            if q.name not in queries_to_keep:
                self.query_specs.remove(q)

        # remove old implementations
        for qname in list(self.query_impls.keys()):
            if qname not in queries_to_keep:
                del self.query_impls[qname]

        # remove old state vars
        # (a variable survives if some remaining query implementation reads it)
        self.concrete_state = [ v for v in self.concrete_state if any(v[0] in free_vars(q) for q in self.query_impls.values()) ]

        # remove old method implementations
        for k in list(self.updates.keys()):
            v, op_name = k
            if v not in [var for (var, exp) in self.concrete_state]:
                del self.updates[k]

    def queries_used_by(self, stm):
        """
        Yield the function name of every call in `stm` that names a query
        in `self.query_specs` (once per call, so names may repeat).
        """
        for exp in all_exps(stm):
            if not isinstance(exp, ECall):
                continue
            known_names = [spec.name for spec in self.query_specs]
            if exp.func in known_names:
                yield exp.func

    def states_maintained_by(self, q : Query) -> [EVar]:
        """
        List the first key component of every entry in `self.updates`
        whose statement uses the query `q`, one item per matching entry.
        """
        return [
            var_name
            for (var_name, op_name), stm in self.updates.items()
            if q.name in self.queries_used_by(stm)]
Пример #12
0
class CostModel(object):
    """Orders expressions by cost, consulting a solver where needed."""

    def __init__(self,
                 assumptions: Exp = T,
                 examples=(),
                 funcs=(),
                 freebies: [Exp] = [],
                 ops: [Op] = []):
        """Set up the cost model and the solver behind its comparisons.

        Args:
            assumptions: expression assumed to be true whenever two
                expressions are compared.
            examples: initial examples handed to the solver.  The right
                set can speed up some cost comparisons; leaving it empty
                is always safe.
            funcs: the in-scope functions.
            freebies: state variables that can be used for free.
            ops: mutators used to determine how expensive it is to
                maintain a state variable.

        Note that `freebies` and `ops` are stored by reference, not copied.
        """
        self.assumptions = assumptions
        self.freebies = freebies
        self.ops = ops
        # The solver receives `funcs` before it is copied into an ordered map.
        solver = ModelCachingSolver(
            vars=(),
            funcs=funcs,
            examples=examples,
            assumptions=assumptions)
        self.solver = solver
        # The examples are read back from the solver by the `examples`
        # property; they are not stored on the instance.
        self.funcs = OrderedDict(funcs)

    def __repr__(self):
        """Render the model's assumptions, examples, funcs, freebies and ops."""
        template = "CostModel(assumptions={!r}, examples={!r}, funcs={!r}, freebies={!r}, ops={!r})"
        return template.format(
            self.assumptions,
            self.examples,
            self.funcs,
            self.freebies,
            self.ops)

    @property
    def examples(self):
        """The solver's current examples, as a tuple."""
        return tuple(self.solver.examples)

    def _compare(self, e1: Exp, e2: Exp, context: Context):
        """Order two expressions under the context's path conditions.

        Two constant expressions (no free variables, no free functions) are
        evaluated and their values ordered directly.  Alpha-equivalent
        expressions are EQUAL.  Otherwise the solver is asked whether
        `e1 <= e2` and `e1 >= e2` always hold: both give EQUAL, only the
        first gives LT, only the second gives GT, neither gives AMBIGUOUS.
        """
        first_is_constant = not (free_vars(e1) or free_funcs(e1))
        second_is_constant = not (free_vars(e2) or free_funcs(e2))
        if first_is_constant and second_is_constant:
            value1 = eval(e1, {})
            value2 = eval(e2, {})
            event("comparison obvious on constants: {} vs {}".format(value1, value2))
            return order_objects(value1, value2)
        if alpha_equivalent(e1, e2):
            event("shortcutting comparison of identical terms")
            return Order.EQUAL

        guard = EAll(context.path_conditions())
        # Both questions are always put to the solver, in this order.
        never_greater = self.solver.valid(EImplies(guard, ELe(e1, e2)))
        never_smaller = self.solver.valid(EImplies(guard, EGe(e1, e2)))

        if never_greater:
            return Order.EQUAL if never_smaller else Order.LT
        return Order.GT if never_smaller else Order.AMBIGUOUS

    def compare(self, e1: Exp, e2: Exp, context: Context, pool: Pool) -> Order:
        """Compare the cost of `e1` and `e2` in the given pool.

        Outside the runtime pool, `prioritized_order` receives storage size
        and then expression size.

        In the runtime pool without maintenance costs it receives, in
        order: asymptotic runtime, maximum storage size, runtime, and
        expression size.

        In the runtime pool with maintenance costs it receives asymptotic
        runtime, then an `unprioritized_order` over (maximum storage size
        followed by runtime) and one maintenance-cost comparison per op in
        `self.ops`, then expression size.
        """
        def by_size():
            return order_objects(e1.size(), e2.size())

        def by_asymptotic_runtime():
            return order_objects(asymptotic_runtime(e1),
                                 asymptotic_runtime(e2))

        def by_max_storage():
            return self._compare(
                max_storage_size(e1, self.freebies),
                max_storage_size(e2, self.freebies),
                context)

        def by_runtime():
            return self._compare(rt(e1), rt(e2), context)

        def by_storage():
            return self._compare(storage_size(e1, self.freebies),
                                 storage_size(e2, self.freebies),
                                 context)

        def by_maintenance(op):
            return self._compare(
                maintenance_cost(e1, op, self.freebies),
                maintenance_cost(e2, op, self.freebies),
                context)

        def by_storage_then_runtime():
            return prioritized_order(by_max_storage, by_runtime)

        def by_resources_and_maintenance():
            # `op=op` pins each comparison to its own op.
            per_op = [lambda op=op: by_maintenance(op) for op in self.ops]
            return unprioritized_order(by_storage_then_runtime, *per_op)

        with task("compare costs", context=context):
            maintenance_matters = consider_maintenance_cost.value
            if pool == RUNTIME_POOL:
                if maintenance_matters:
                    return prioritized_order(by_asymptotic_runtime,
                                             by_resources_and_maintenance,
                                             by_size)
                return prioritized_order(by_asymptotic_runtime,
                                         by_max_storage,
                                         by_runtime,
                                         by_size)
            # Storage comparisons are the same with or without maintenance.
            return prioritized_order(by_storage, by_size)
Пример #13
0
class Implementation(object):
    """A (possibly partial) implementation of a specification.

    An implementation bundles the original `spec` with the concrete state
    variables chosen to represent it, an implementation of each query in
    terms of that concrete state, and the statements that maintain the
    concrete state and reachable handles when an update operation runs.
    """

    @typechecked
    def __init__(self,
            spec : Spec,
            concretization_functions : [(EVar, Exp)],
            query_specs : [Query],
            query_impls : OrderedDict,
            updates : defaultdict,
            handle_updates : defaultdict):
        """Construct an implementation.

        This constructor should be considered private to the `impls` module.

        You should call `construct_initial_implementation` instead of calling
        this constructor directly; this constructor makes it easy to
        accidentally create a malformed implementation.

        Parameters:

         - spec: the original specification

         - concretization_functions: pairs of (v, e) indicating that this
           implementation stores a state variable `v` whose value tracks `e`,
           a function of the state in `spec`

         - query_specs: specifications for all queries in `spec` plus
           additional private helper queries that may have been added, in terms
           of the state in spec

         - query_impls: implementations for all queries in `spec` plus
           additional private helper queries, in terms of the state variables
           in `concretization_functions`.  The query implementations are stored
           in a map keyed by query name.

         - updates: a map from (concrete_var_name, op_name) to statements `stm`,
           where the concrete_var_name is one of the state variables described
           by `concretization_functions`, op_name is one of the update
           operations in `spec`, and `stm` is a statement that may use private
           helper queries and exists entirely to maintain the relationship
           between `concrete_var_name` and the state in the specification.
           The statements are all in terms of the original state, before the
           update started executing.

         - handle_updates: a map from (handle_type, op_name) to statements,
           where handle_type is a THandle instance and op_name is one of the
           update operations in `spec`.  These statements update the values of
           every reachable instance of the given handle (aka pointer).  Like
           updates, these statements are in terms of the original state before
           the update started executing.

        The arguments are stored by reference (not copied).  A solver over
        the abstract state, assuming the spec's invariants, is also created
        and stored as `self.state_solver`.
        """
        self.spec = spec
        self._concretization_functions = concretization_functions
        self.query_specs = query_specs
        self.query_impls = query_impls
        self.updates = updates
        self.handle_updates = handle_updates
        self.state_solver = ModelCachingSolver(
            vars=self.abstract_state,
            funcs=self.extern_funcs,
            assumptions=EAll(spec.assumptions))

    def safe_copy(self):
        """Create a copy of this implementation.

        The copy is "safe" in the sense that modifications made to the copy
        through methods on the Implementation class do not affect self (and
        vice versa).  However, note that the copy is not truly a deep copy; it
        may still be possible to get strange behavior by manually writing to
        properties of the underlying members (for instance, self and the copy
        will still share a Spec object).
        """
        return Implementation(
            self.spec,
            list(self._concretization_functions),
            list(self.query_specs),
            OrderedDict(self.query_impls),
            defaultdict(SNoOp, self.updates),
            defaultdict(SNoOp, self.handle_updates))

    def __getstate__(self):
        """Return the state to serialize.

        The result is a copy of the instance `__dict__` without the
        `state_solver` entry, plus the value of every attribute named in
        `__slots__` (if the object has any).
        """
        # During serialization, do not save the solver object.
        d = dict(self.__dict__)
        d.pop("state_solver", None)
        if hasattr(self, "__slots__"):
            for a in self.__slots__:
                d[a] = getattr(self, a)
        return d

    def __setstate__(self, state):
        """Restore this object from a dictionary made by `__getstate__`.

        The object is rebuilt by calling `__init__`, which also creates a
        fresh `state_solver`.  States written in the older format (where the
        concretization functions are stored under `"concrete_state"`) are
        accepted as well.
        """
        spec = state["spec"]
        try:
            concretization_functions = state["_concretization_functions"]
        except KeyError:
            # older format
            concretization_functions = state["concrete_state"]
        query_specs = state["query_specs"]
        query_impls = state["query_impls"]
        updates = state["updates"]
        handle_updates = state["handle_updates"]
        self.__init__(
            spec,
            concretization_functions,
            query_specs,
            query_impls,
            updates,
            handle_updates)

    def add_query(self, q : Query):
        """
        Given a query in terms of abstract state, add an initial concrete
        implementation.

        The query is appended to `query_specs`; its return expression is
        repaired for well-formedness, split into a representation and a
        computation, and installed through `set_impl`.
        """
        print("Adding query {}...".format(q.name))
        self.query_specs.append(q)
        safe_ret = repair_well_formedness(
            q.ret,
            context=self.context_for_method(q),
            extra_available_state=[e for v, e in self._concretization_functions])
        rep, ret = unpack_representation(safe_ret)
        self.set_impl(q, rep, ret)

    @property
    def op_specs(self) -> [Op]:
        """Returns the specifications for all the update operations."""
        return [m for m in self.spec.methods if isinstance(m, Op)]

    @property
    def abstract_state(self) -> [EVar]:
        """Returns the abstract state of this data structure."""
        return [EVar(name).with_type(t) for (name, t) in self.spec.statevars]

    @property
    def abstract_invariants(self) -> [Exp]:
        """Returns any user-specified invariants about the abstract state."""
        return list(self.spec.assumptions)

    @property
    def concretization_functions(self) -> { str : Exp }:
        """Returns a mapping from concrete state variables to their meanings.

        The "meaning" of a concrete state variable is an expression (in terms
        of abstract state) that it tracks.  The mapping is keyed by variable
        name (`v.id`) and is rebuilt on every access."""
        state_var_exps = OrderedDict()
        for (v, e) in self._concretization_functions:
            state_var_exps[v.id] = e
        return state_var_exps

    @property
    def extern_funcs(self) -> { str : TFunc }:
        """Return the types of all user-declared external functions."""
        return OrderedDict((f.name, TFunc(tuple(t for a, t in f.args), f.out_type)) for f in self.spec.extern_funcs)

    def context_for_method(self, m : Method) -> Context:
        """Construct a context describing expressions in the given method."""
        return RootCtx(
            state_vars=self.abstract_state,
            args=[EVar(a).with_type(t) for (a, t) in m.args],
            funcs=self.extern_funcs)

    def _add_subquery(self, sub_q : Query, used_by : Stm) -> Stm:
        """Add a query that helps maintain some other state.

        Parameters:
            sub_q - the specification of the helper query
            used_by - the statement that calls `sub_q`

        If a query already exists that is equivalent to `sub_q`, this method
        returns `used_by` rewritten to use the existing query and does not add
        the query to the implementation.  Otherwise it returns `used_by`
        unchanged.

        Raises AssertionError if simplifying the query's assumptions makes
        them larger.
        """

        with task("adding query", query=sub_q.name):
            # Work on a copy so the caller's query object is not modified.
            sub_q = shallow_copy(sub_q)
            with task("checking whether we need more handle assumptions"):
                new_a = implicit_handle_assumptions(
                    reachable_handles_at_method(self.spec, sub_q))
                if not valid(EImplies(EAll(sub_q.assumptions), EAll(new_a))):
                    event("we do!")
                    sub_q.assumptions = list(itertools.chain(sub_q.assumptions, new_a))

            with task("repairing state var boundaries"):
                extra_available_state = [e for v, e in self._concretization_functions]
                sub_q.ret = repair_well_formedness(
                    strip_EStateVar(sub_q.ret),
                    self.context_for_method(sub_q),
                    extra_available_state)

            with task("simplifying"):
                orig_a = sub_q.assumptions
                orig_a_size = sum(a.size() for a in sub_q.assumptions)
                orig_ret_size = sub_q.ret.size()
                sub_q.assumptions = tuple(simplify_or_ignore(a) for a in sub_q.assumptions)
                sub_q.ret = simplify(sub_q.ret)
                a_size = sum(a.size() for a in sub_q.assumptions)
                ret_size = sub_q.ret.size()
                event("|assumptions|: {} -> {}".format(orig_a_size, a_size))
                event("|ret|: {} -> {}".format(orig_ret_size, ret_size))

                if a_size > orig_a_size:
                    print("NO, BAD SIMPLIFICATION")
                    print("original")
                    for a in orig_a:
                        print(" - {}".format(pprint(a)))
                    print("simplified")
                    for a in sub_q.assumptions:
                        print(" - {}".format(pprint(a)))
                    # Raise explicitly: a bare `assert False` is removed when
                    # Python runs with -O, which would silently accept the
                    # bad simplification.
                    raise AssertionError(
                        "simplification increased the size of the assumptions "
                        "({} -> {})".format(orig_a_size, a_size))

            state_vars = self.abstract_state
            funcs = self.extern_funcs
            qq = find_one(
                self.query_specs,
                lambda candidate: dedup_queries.value and queries_equivalent(
                    candidate,
                    sub_q,
                    state_vars=state_vars,
                    extern_funcs=funcs,
                    assumptions=EAll(self.abstract_invariants)))
            if qq is not None:
                event("subgoal {} is equivalent to {}".format(sub_q.name, qq.name))
                # arg_reorder[i] is the position, in sub_q's argument list, of
                # the i-th argument of the existing query qq.
                sub_q_arg_names = [x[0] for x in sub_q.args]
                arg_reorder = [sub_q_arg_names.index(a) for (a, t) in qq.args]
                class Repl(BottomUpRewriter):
                    """Redirects calls to `sub_q` to the equivalent query `qq`."""
                    def visit_ECall(self, e):
                        args = tuple(self.visit(a) for a in e.args)
                        if e.func == sub_q.name:
                            args = tuple(args[idx] for idx in arg_reorder)
                            return ECall(qq.name, args).with_type(e.type)
                        else:
                            return ECall(e.func, args).with_type(e.type)
                used_by = Repl().visit(used_by)
            else:
                self.add_query(sub_q)
            return used_by

    def _setup_handle_updates(self):
        """
        This method creates update code for handle objects modified by each op.
        Must be called once after all user-specified queries have been added.

        For each (handle type, op) pair the resulting statement is stored in
        `self.handle_updates[(handle_type, op.name)]`.
        """
        for op in self.op_specs:
            print("Setting up handle updates for {}...".format(op.name))
            handles = reachable_handles_at_method(self.spec, op)
            for t, bag in handles.items():
                h = fresh_var(t)
                lval = EGetField(h, "val").with_type(t.value_type)
                new_val = inc.mutate(lval, op.body)

                # get set of modified handles
                modified_handles = Query(
                    fresh_name("modified_handles"),
                    Visibility.Internal, [], op.assumptions,
                    EFilter(EUnaryOp(UOp.Distinct, bag).with_type(bag.type), ELambda(h, ENot(EEq(lval, new_val)))).with_type(bag.type),
                    "[{}] modified handles of type {}".format(op.name, pprint(t)))
                # Every free variable that is not abstract state becomes an
                # argument of the helper query.
                query_vars = [v for v in free_vars(modified_handles) if v not in self.abstract_state]
                modified_handles.args = [(arg.id, arg.type) for arg in query_vars]

                # modify each one
                subqueries = []
                state_update_stm = inc.mutate_in_place(
                    lval,
                    lval,
                    op.body,
                    abstract_state=self.abstract_state,
                    assumptions=list(op.assumptions) + [EDeepIn(h, bag), EIn(h, modified_handles.ret)],
                    invariants=self.abstract_invariants,
                    subgoals_out=subqueries)
                for sub_q in subqueries:
                    sub_q.docstring = "[{}] {}".format(op.name, sub_q.docstring)
                    state_update_stm = self._add_subquery(sub_q=sub_q, used_by=state_update_stm)
                # Only loop over the modified handles (and register the
                # helper query) when there is actually something to do.
                if state_update_stm != SNoOp():
                    state_update_stm = SForEach(h, ECall(modified_handles.name, query_vars).with_type(bag.type), state_update_stm)
                    state_update_stm = self._add_subquery(sub_q=modified_handles, used_by=state_update_stm)
                self.handle_updates[(t, op.name)] = state_update_stm

    def set_impl(self, q : Query, rep : [(EVar, Exp)], ret : Exp):
        """Update the implementation of a query.

        The query having the same name as `q` will have its implementation
        replaced by the given concrete representation and computation.

        This call may add additional "subqueries" to the implementation to
        maintain the new representation when each update operation is called.

        Parameters:
            q - the query whose implementation is being set
            rep - pairs (v, e) of new concrete state variables and the
                  expressions they track
            ret - the new return expression, in terms of the variables in
                  `rep`
        """
        with task("updating implementation", query=q.name):
            with task("finding duplicated state vars"):
                to_remove = set()
                for (v, e) in rep:
                    # Reuse an existing state variable when the solver proves
                    # that it tracks an equal expression of the same type.
                    aeq = find_one(vv for (vv, ee) in self._concretization_functions if e.type == ee.type and self.state_solver.valid(EEq(e, ee)))
                    if aeq is not None:
                        event("state var {} is equivalent to {}".format(v.id, aeq.id))
                        ret = subst(ret, { v.id : aeq })
                        to_remove.add(v)
                rep = [ x for x in rep if x[0] not in to_remove ]

            self._concretization_functions.extend(rep)
            self.query_impls[q.name] = rewrite_ret(q, lambda prev: ret, keep_assumptions=False)

            # Generate code that maintains each new state variable under
            # each update operation.
            for op in self.op_specs:
                with task("incrementalizing query", query=q.name, op=op.name):
                    for new_member, projection in rep:
                        subqueries = []
                        state_update_stm = inc.mutate_in_place(
                            new_member,
                            projection,
                            op.body,
                            abstract_state=self.abstract_state,
                            assumptions=op.assumptions,
                            invariants=self.abstract_invariants,
                            subgoals_out=subqueries)
                        for sub_q in subqueries:
                            sub_q.docstring = "[{}] {}".format(op.name, sub_q.docstring)
                            state_update_stm = self._add_subquery(sub_q=sub_q, used_by=state_update_stm)
                        self.updates[(new_member, op.name)] = state_update_stm

    @property
    def code(self) -> Spec:
        """Get the current code corresponding to this implementation.

        The code is returned as a Cozy specification object, but the returned
        object throws away any unused abstract state as well as all invariants
        and assumptions on methods. It implements the same data structure, but
        probably more efficiently.
        """

        state_read_by_query = {
            query_name : free_vars(query)
            for query_name, query in self.query_impls.items() }

        # Names of all known queries; computed once instead of once per
        # inspected expression.
        query_names = {q.name for q in self.query_specs}

        # prevent read-after-write by lifting reads before writes.

        # per-operation list of lifted declarations (SDecls)
        temps = defaultdict(list)
        updates = dict(self.updates)

        _concretization_functions = [v for v, e in self._concretization_functions]

        for operator in self.op_specs:

            # Compute order constraints between statements:
            #   v1 -> v2 means that the update code for v1 should (if possible)
            #   appear before the update code for v2
            #   (i.e. the update code for v1 reads v2)
            def state_used_during_update(v1 : EVar) -> [EVar]:
                v1_update_code = self.updates[(v1, operator.name)]
                v1_queries = list(self.queries_used_by(v1_update_code))
                res = OrderedSet()
                for q in v1_queries:
                    res |= state_read_by_query[q]
                return res
            g = DirectedGraph(
                nodes=_concretization_functions,
                successors=state_used_during_update)

            # Find the minimum set of edges we need to break cycles (see
            # "feedback arc set problem")
            edges_to_break = g.minimum_feedback_arc_set()
            g.delete_edges(edges_to_break)
            # NOTE(review): this ordering is overwritten on every iteration,
            # so the ordering computed for the last operator is the one used
            # for all operators when the new ops are assembled below --
            # confirm that this is intended.
            _concretization_functions = list(g.toposort())

            # Lift auxiliary declarations as needed
            things_updated = []
            for v in _concretization_functions:
                things_updated.append(v)
                stm = updates[(v, operator.name)]
                def problematic(e):
                    # An expression is a problem if it calls a query that
                    # reads state which has already been updated.
                    for x in all_exps(e):
                        if isinstance(x, ECall) and x.func in query_names:
                            problems = set(things_updated) & state_read_by_query[x.func]
                            if problems:
                                return True
                    return False
                stm = pull_temps(stm,
                    decls_out=temps[operator.name],
                    exp_is_bad=problematic)
                updates[(v, operator.name)] = stm

        # construct new op implementations
        new_ops = []
        for op in self.op_specs:

            stms = [ updates[(v, op.name)] for v in _concretization_functions ]
            stms.extend(hup for ((t, op_name), hup) in self.handle_updates.items() if op.name == op_name)
            new_stms = seq(temps[op.name] + stms)
            new_ops.append(Op(
                op.name,
                op.args,
                [],
                new_stms,
                op.docstring))

        # assemble final result
        return Spec(
            self.spec.name,
            self.spec.types,
            self.spec.extern_funcs,
            [(v.id, e.type) for (v, e) in self._concretization_functions],
            [],
            list(self.query_impls.values()) + new_ops,
            self.spec.header,
            self.spec.footer,
            self.spec.docstring)

    def cleanup(self):
        """
        Remove unused state, queries, and updates.

        A query is kept if it is reachable from a public query or from a
        query used by some handle update.  State variables are kept if some
        remaining query implementation mentions them, and update statements
        are kept only for the remaining state variables.
        """

        def deps(thing):
            # Dependency edges: query name -> state variables it reads,
            # state variable -> its update statements, statement -> names of
            # the queries it calls.
            if isinstance(thing, str):
                yield from free_vars(self.query_impls[thing])
            elif isinstance(thing, EVar):
                for op in self.op_specs:
                    yield self.updates[(thing, op.name)]
            elif isinstance(thing, Stm):
                yield from self.queries_used_by(thing)
            else:
                raise ValueError(repr(thing))

        g = DirectedGraph(
            nodes=itertools.chain(self.query_impls.keys(), (v for v, _ in self._concretization_functions), self.updates.values()),
            successors=deps)
        roots = [q.name for q in self.query_specs if q.visibility == Visibility.Public]
        roots.extend(itertools.chain(*[self.queries_used_by(code) for ((ht, op_name), code) in self.handle_updates.items()]))
        queries_to_keep = set(q for q in g.reachable_nodes(roots) if isinstance(q, str))

        # remove old specs
        for q in list(self.query_specs):
            if q.name not in queries_to_keep:
                self.query_specs.remove(q)

        # remove old implementations
        for qname in list(self.query_impls.keys()):
            if qname not in queries_to_keep:
                del self.query_impls[qname]

        # remove old state vars (free variables of each remaining query
        # implementation are computed once, not once per state variable)
        impl_free_vars = [free_vars(q) for q in self.query_impls.values()]
        self._concretization_functions = [
            entry for entry in self._concretization_functions
            if any(entry[0] in fvs for fvs in impl_free_vars) ]

        # remove old method implementations
        kept_vars = {var for (var, exp) in self._concretization_functions}
        for k in list(self.updates.keys()):
            v, op_name = k
            if v not in kept_vars:
                del self.updates[k]

    def queries_used_by(self, stm):
        """Yield the name of each known query called somewhere in `stm`.

        A name is yielded once per call expression, so it may be repeated.
        """
        query_names = {q.name for q in self.query_specs}
        for e in all_exps(stm):
            if isinstance(e, ECall) and e.func in query_names:
                yield e.func

    def states_maintained_by(self, q : Query) -> [EVar]:
        """Return the concrete state variables whose update code calls `q`.

        A variable appears once for every update statement of it that calls
        the query.
        """
        concrete_vars = []
        for (var_name, op_name), stm in self.updates.items():
            if q.name in self.queries_used_by(stm):
                concrete_vars.append(var_name)
        return concrete_vars
Пример #14
0
class Implementation(object):
    @typechecked
    def __init__(self,
                 spec: Spec,
                 concretization_functions: [(EVar, Exp)],
                 query_specs: [Query],
                 query_impls: OrderedDict,
                 updates: defaultdict,
                 handle_updates: defaultdict):
        """Build an implementation from its parts.

        Treat this constructor as private to the `impls` module: it performs
        no consistency checks, so it is easy to create a malformed
        implementation with it.  Use `construct_initial_implementation`
        instead.

        Parameters:

         - spec: the specification being implemented.

         - concretization_functions: (v, e) pairs; the implementation stores
           a state variable `v` whose value tracks the expression `e` over
           the state of `spec`.

         - query_specs: the specification of every query in `spec`, plus any
           private helper queries added so far, written in terms of the
           state in `spec`.

         - query_impls: map from query name to the implementation of that
           query (public and private helper queries alike), written in terms
           of the state variables in `concretization_functions`.

         - updates: map from (concrete_var_name, op_name) to a statement
           that keeps that concrete state variable in sync with the
           specification state when that update operation runs.  The
           statement may call private helper queries and is expressed over
           the original state, i.e. the state before the update began.

         - handle_updates: map from (handle_type, op_name) to a statement
           that updates the value of every reachable handle (pointer) of
           that THandle type when that update operation runs.  As with
           `updates`, the statement is expressed over the state before the
           update began.

        The arguments are stored by reference.  A solver over the abstract
        state that assumes the spec's invariants is created as
        `self.state_solver`.
        """
        self.spec = spec
        self._concretization_functions = concretization_functions
        self.query_specs = query_specs
        self.query_impls = query_impls
        self.updates = updates
        self.handle_updates = handle_updates

        # The solver is derived from the spec alone, so it can always be
        # rebuilt from the fields above.
        solver_vars = self.abstract_state
        solver_funcs = self.extern_funcs
        invariants = EAll(spec.assumptions)
        self.state_solver = ModelCachingSolver(
            vars=solver_vars,
            funcs=solver_funcs,
            assumptions=invariants)

    def safe_copy(self):
        """Return a copy that can be modified independently of `self`.

        Changes made through Implementation methods on the copy do not show
        up in `self`, and vice versa.  The copy is shallow, though: the
        containers are duplicated but their contents (and the Spec object)
        are shared, so writing directly to those shared objects can still
        affect both.
        """
        state_copy = list(self._concretization_functions)
        specs_copy = list(self.query_specs)
        impls_copy = OrderedDict(self.query_impls)
        # Missing entries default to a no-op statement.
        updates_copy = defaultdict(SNoOp, self.updates)
        handle_updates_copy = defaultdict(SNoOp, self.handle_updates)
        return Implementation(
            self.spec,
            state_copy,
            specs_copy,
            impls_copy,
            updates_copy,
            handle_updates_copy)

    def __getstate__(self):
        # During serialization, do not save the solver object.
        d = dict(self.__dict__)
        if "state_solver" in d:
            del d["state_solver"]
        if hasattr(self, "__slots__"):
            for a in self.__slots__:
                d[a] = getattr(self, a)
        return d

    def __setstate__(self, state):
        """Restore the object from a dictionary made by `__getstate__`.

        The object is rebuilt by calling `__init__`, which also creates a
        new `state_solver`.  The older format, which stores the
        concretization functions under `"concrete_state"`, is accepted too.
        """
        spec = state["spec"]
        if "_concretization_functions" in state:
            concretization_functions = state["_concretization_functions"]
        else:
            # older format
            concretization_functions = state["concrete_state"]
        self.__init__(
            spec,
            concretization_functions,
            state["query_specs"],
            state["query_impls"],
            state["updates"],
            state["handle_updates"])

    def add_query(self, q: Query):
        """Register `q` and give it an initial concrete implementation.

        `q` is written in terms of abstract state.  It is appended to
        `query_specs`; its return expression is repaired for
        well-formedness, split into a representation and a computation, and
        installed with `set_impl`.
        """
        print("Adding query {}...".format(q.name))
        self.query_specs.append(q)

        method_context = self.context_for_method(q)
        tracked_exps = [exp for _, exp in self._concretization_functions]
        repaired_ret = repair_well_formedness(
            q.ret,
            context=method_context,
            extra_available_state=tracked_exps)

        rep, ret = unpack_representation(repaired_ret)
        self.set_impl(q, rep, ret)

    @property
    def op_specs(self) -> [Op]:
        """The spec's methods that are update operations (`Op` instances)."""
        ops = []
        for method in self.spec.methods:
            if isinstance(method, Op):
                ops.append(method)
        return ops

    @property
    def abstract_state(self) -> [EVar]:
        """The spec's state variables, as typed `EVar`s (a new list each time)."""
        state_vars = []
        for name, var_type in self.spec.statevars:
            state_vars.append(EVar(name).with_type(var_type))
        return state_vars

    @property
    def abstract_invariants(self) -> [Exp]:
        """A new list holding the spec's assumptions (the user's invariants)."""
        invariants = []
        invariants.extend(self.spec.assumptions)
        return invariants

    @property
    def concretization_functions(self) -> {str: Exp}:
        """Map each concrete state variable's name to the expression it tracks.

        The tracked expression is written in terms of abstract state.  The
        ordered mapping is rebuilt on every access."""
        return OrderedDict(
            (var.id, meaning)
            for (var, meaning) in self._concretization_functions)

    @property
    def extern_funcs(self) -> {str: TFunc}:
        """Map each user-declared external function's name to its type."""
        signatures = OrderedDict()
        for func in self.spec.extern_funcs:
            # Only the argument types matter; argument names are dropped.
            arg_types = tuple(arg_type for _, arg_type in func.args)
            signatures[func.name] = TFunc(arg_types, func.out_type)
        return signatures

    def context_for_method(self, m: Method) -> Context:
        """Build the root context for expressions appearing in method `m`.

        The context contains the abstract state variables, the method's
        arguments as typed variables, and the external functions.
        """
        state_vars = self.abstract_state
        arg_vars = [EVar(name).with_type(arg_type) for (name, arg_type) in m.args]
        return RootCtx(
            state_vars=state_vars,
            args=arg_vars,
            funcs=self.extern_funcs)

    def _add_subquery(self, sub_q: Query, used_by: Stm) -> Stm:
        """Add a query that helps maintain some other state.

        Parameters:
            sub_q - the specification of the helper query
            used_by - the statement that calls `sub_q`

        If a query already exists that is equivalent to `sub_q`, this method
        returns `used_by` rewritten to use the existing query and does not add
        the query to the implementation.  Otherwise it returns `used_by`
        unchanged.

        Raises AssertionError if simplifying the query's assumptions makes
        them larger.
        """

        with task("adding query", query=sub_q.name):
            # Work on a copy so the caller's query object is not modified.
            sub_q = shallow_copy(sub_q)
            with task("checking whether we need more handle assumptions"):
                new_a = implicit_handle_assumptions(
                    reachable_handles_at_method(self.spec, sub_q))
                if not valid(EImplies(EAll(sub_q.assumptions), EAll(new_a))):
                    event("we do!")
                    sub_q.assumptions = list(
                        itertools.chain(sub_q.assumptions, new_a))

            with task("repairing state var boundaries"):
                extra_available_state = [
                    e for v, e in self._concretization_functions
                ]
                sub_q.ret = repair_well_formedness(
                    strip_EStateVar(sub_q.ret), self.context_for_method(sub_q),
                    extra_available_state)

            with task("simplifying"):
                orig_a = sub_q.assumptions
                orig_a_size = sum(a.size() for a in sub_q.assumptions)
                orig_ret_size = sub_q.ret.size()
                sub_q.assumptions = tuple(
                    simplify_or_ignore(a) for a in sub_q.assumptions)
                sub_q.ret = simplify(sub_q.ret)
                a_size = sum(a.size() for a in sub_q.assumptions)
                ret_size = sub_q.ret.size()
                event("|assumptions|: {} -> {}".format(orig_a_size, a_size))
                event("|ret|: {} -> {}".format(orig_ret_size, ret_size))

                if a_size > orig_a_size:
                    print("NO, BAD SIMPLIFICATION")
                    print("original")
                    for a in orig_a:
                        print(" - {}".format(pprint(a)))
                    print("simplified")
                    for a in sub_q.assumptions:
                        print(" - {}".format(pprint(a)))
                    # Raise explicitly: a bare `assert False` is removed when
                    # Python runs with -O, which would silently accept the
                    # bad simplification.
                    raise AssertionError(
                        "simplification increased the size of the "
                        "assumptions ({} -> {})".format(orig_a_size, a_size))

            state_vars = self.abstract_state
            funcs = self.extern_funcs
            qq = find_one(
                self.query_specs,
                lambda candidate: dedup_queries.value and queries_equivalent(
                    candidate,
                    sub_q,
                    state_vars=state_vars,
                    extern_funcs=funcs,
                    assumptions=EAll(self.abstract_invariants)))
            if qq is not None:
                event("subgoal {} is equivalent to {}".format(
                    sub_q.name, qq.name))
                # arg_reorder[i] is the position, in sub_q's argument list, of
                # the i-th argument of the existing query qq.
                sub_q_arg_names = [x[0] for x in sub_q.args]
                arg_reorder = [sub_q_arg_names.index(a)
                               for (a, t) in qq.args]

                class Repl(BottomUpRewriter):
                    """Redirects calls to `sub_q` to the equivalent query `qq`."""
                    def visit_ECall(self, e):
                        args = tuple(self.visit(a) for a in e.args)
                        if e.func == sub_q.name:
                            args = tuple(args[idx] for idx in arg_reorder)
                            return ECall(qq.name, args).with_type(e.type)
                        else:
                            return ECall(e.func, args).with_type(e.type)

                used_by = Repl().visit(used_by)
            else:
                self.add_query(sub_q)
            return used_by

    def _setup_handle_updates(self):
        """Create the update code for handles modified by each operation.

        For every update operation and every handle type reachable at that
        operation, a statement is stored in
        `self.handle_updates[(handle_type, op.name)]`.  Helper queries
        produced along the way are registered through `_add_subquery`.

        Must be called once after all user-specified queries have been added.
        """
        for op in self.op_specs:
            print("Setting up handle updates for {}...".format(op.name))
            reachable = reachable_handles_at_method(self.spec, op)
            for handle_type, bag in reachable.items():
                handle = fresh_var(handle_type)
                old_val = EGetField(handle, "val").with_type(
                    handle_type.value_type)
                new_val = inc.mutate(old_val, op.body)

                # Helper query: the distinct handles in `bag` whose value
                # differs after the operation.
                query_name = fresh_name("modified_handles")
                distinct_bag = EUnaryOp(UOp.Distinct, bag).with_type(bag.type)
                value_changes = ELambda(handle, ENot(EEq(old_val, new_val)))
                changed_handles = EFilter(
                    distinct_bag, value_changes).with_type(bag.type)
                query_doc = "[{}] modified handles of type {}".format(
                    op.name, pprint(handle_type))
                modified_handles = Query(
                    query_name,
                    Visibility.Internal,
                    [],
                    op.assumptions,
                    changed_handles,
                    query_doc)

                # Every free variable that is not abstract state becomes an
                # argument of the helper query.
                query_vars = []
                for var in free_vars(modified_handles):
                    if var not in self.abstract_state:
                        query_vars.append(var)
                modified_handles.args = [(var.id, var.type)
                                         for var in query_vars]

                # Code that updates a single modified handle in place.
                handle_facts = [EDeepIn(handle, bag),
                                EIn(handle, modified_handles.ret)]
                subqueries = []
                update_stm = inc.mutate_in_place(
                    old_val,
                    old_val,
                    op.body,
                    abstract_state=self.abstract_state,
                    assumptions=list(op.assumptions) + handle_facts,
                    invariants=self.abstract_invariants,
                    subgoals_out=subqueries)
                for sub_q in subqueries:
                    sub_q.docstring = "[{}] {}".format(op.name,
                                                       sub_q.docstring)
                    update_stm = self._add_subquery(
                        sub_q=sub_q, used_by=update_stm)

                # Only loop over the modified handles (and register the
                # helper query) when there is something to do.
                if update_stm != SNoOp():
                    modified_call = ECall(
                        modified_handles.name, query_vars).with_type(bag.type)
                    update_stm = SForEach(handle, modified_call, update_stm)
                    update_stm = self._add_subquery(
                        sub_q=modified_handles, used_by=update_stm)
                self.handle_updates[(handle_type, op.name)] = update_stm

    def set_impl(self, q: Query, rep: [(EVar, Exp)], ret: Exp):
        """Replace the implementation of the query named like `q`.

        `rep` is the proposed concrete representation, given as pairs of
        (state variable, expression it stores); `ret` is the expression that
        computes the query result from that representation.

        Proposed state variables whose expression the state solver proves
        equal to an already-tracked one are dropped, and `ret` is rewritten to
        use the existing variable instead.

        For every update operation, code maintaining each remaining new state
        variable is stored in `self.updates`; this may register additional
        "subqueries" through `_add_subquery`.
        """
        with task("updating implementation", query=q.name):
            with task("finding duplicated state vars"):
                redundant = set()
                for (new_var, new_exp) in rep:
                    # Lazy search: the (potentially costly) solver call is
                    # only made for same-typed candidates, and only until
                    # `find_one` is satisfied.
                    existing = find_one(
                        old_var
                        for (old_var,
                             old_exp) in self._concretization_functions
                        if new_exp.type == old_exp.type
                        and self.state_solver.valid(EEq(new_exp, old_exp)))
                    if existing is None:
                        continue
                    event("state var {} is equivalent to {}".format(
                        new_var.id, existing.id))
                    ret = subst(ret, {new_var.id: existing})
                    redundant.add(new_var)
                rep = [pair for pair in rep if pair[0] not in redundant]

            self._concretization_functions.extend(rep)
            self.query_impls[q.name] = rewrite_ret(q,
                                                   lambda prev: ret,
                                                   keep_assumptions=False)

            for op in self.op_specs:
                with task("incrementalizing query", query=q.name, op=op.name):
                    for state_var, state_exp in rep:
                        helper_queries = []
                        maintenance_code = inc.mutate_in_place(
                            state_var,
                            state_exp,
                            op.body,
                            abstract_state=self.abstract_state,
                            assumptions=op.assumptions,
                            invariants=self.abstract_invariants,
                            subgoals_out=helper_queries)
                        for helper in helper_queries:
                            helper.docstring = "[{}] {}".format(
                                op.name, helper.docstring)
                            maintenance_code = self._add_subquery(
                                sub_q=helper, used_by=maintenance_code)
                        self.updates[(state_var, op.name)] = maintenance_code

    @property
    def code(self) -> Spec:
        """Get the current code corresponding to this implementation.

        The code is returned as a Cozy specification object, but the returned
        object throws away any unused abstract state as well as all invariants
        and assumptions on methods. It implements the same data structure, but
        probably more efficiently.

        For every op, the update statements of the concrete state variables
        are ordered so that (where possible) a statement that reads a state
        variable runs before the statement that writes it.  Reads that would
        still happen after a write are lifted into temporaries declared at the
        start of the op.  The ordering is computed -- and used -- per op.
        """

        state_read_by_query = {
            query_name: free_vars(query)
            for query_name, query in self.query_impls.items()
        }

        # Computed once; consulted for every sub-expression inspected by
        # `problematic` below.
        known_query_names = {q.name for q in self.query_specs}

        # prevent read-after-write by lifting reads before writes.

        # list of SDecls
        temps = defaultdict(list)
        updates = dict(self.updates)

        state_vars = [v for v, e in self._concretization_functions]

        # op name -> order in which that op's update statements are emitted.
        # The temporaries lifted for an op are only valid for the order they
        # were computed against, so each op must remember its own order.
        update_order = {}

        for operator in self.op_specs:

            # Compute order constraints between statements:
            #   v1 -> v2 means that the update code for v1 should (if possible)
            #   appear before the update code for v2
            #   (i.e. the update code for v1 reads v2)
            def state_used_during_update(v1: EVar) -> [EVar]:
                v1_update_code = self.updates[(v1, operator.name)]
                v1_queries = list(self.queries_used_by(v1_update_code))
                res = OrderedSet()
                for q in v1_queries:
                    res |= state_read_by_query[q]
                return res

            g = DirectedGraph(nodes=state_vars,
                              successors=state_used_during_update)

            # Find the minimum set of edges we need to break cycles (see
            # "feedback arc set problem")
            edges_to_break = g.minimum_feedback_arc_set()
            g.delete_edges(edges_to_break)
            state_vars = list(g.toposort())
            update_order[operator.name] = state_vars

            # Lift auxiliary declarations as needed
            things_updated = []
            for v in state_vars:
                things_updated.append(v)
                stm = updates[(v, operator.name)]

                def problematic(e):
                    # An expression is problematic when it calls a query that
                    # reads state already overwritten at this point (this
                    # includes `v` itself).
                    for x in all_exps(e):
                        if isinstance(x, ECall) and \
                                x.func in known_query_names:
                            problems = set(
                                things_updated) & state_read_by_query[x.func]
                            if problems:
                                return True
                    return False

                stm = pull_temps(stm,
                                 decls_out=temps[operator.name],
                                 exp_is_bad=problematic)
                updates[(v, operator.name)] = stm

        # construct new op implementations
        new_ops = []
        for op in self.op_specs:

            # Emit in the order computed for *this* op, not in whatever order
            # the last op happened to leave behind.
            stms = [updates[(v, op.name)] for v in update_order[op.name]]
            stms.extend(hup
                        for ((t, op_name), hup) in self.handle_updates.items()
                        if op.name == op_name)
            new_stms = seq(temps[op.name] + stms)
            new_ops.append(Op(op.name, op.args, [], new_stms, op.docstring))

        # assemble final result
        return Spec(self.spec.name, self.spec.types, self.spec.extern_funcs,
                    [(v.id, e.type)
                     for (v, e) in self._concretization_functions], [],
                    list(self.query_impls.values()) + new_ops,
                    self.spec.header, self.spec.footer, self.spec.docstring)

    def cleanup(self):
        """
        Remove unused state, queries, and updates.

        A query is kept when it is reachable from a public query or from the
        code of any handle update, following these dependencies:
          - a query (by name) depends on the free variables of its
            implementation,
          - a state variable depends on its update code for every op,
          - a statement depends on the queries it calls.

        Query specs and implementations that are not kept are removed; then
        state variables not read by any remaining implementation are removed,
        together with their update code.
        """
        def deps(thing):
            if isinstance(thing, str):
                yield from free_vars(self.query_impls[thing])
            elif isinstance(thing, EVar):
                for op in self.op_specs:
                    yield self.updates[(thing, op.name)]
            elif isinstance(thing, Stm):
                yield from self.queries_used_by(thing)
            else:
                raise ValueError(repr(thing))

        g = DirectedGraph(nodes=itertools.chain(
            self.query_impls.keys(),
            (v for v, _ in self._concretization_functions),
            self.updates.values()),
                          successors=deps)
        roots = [
            q.name for q in self.query_specs
            if q.visibility == Visibility.Public
        ]
        roots.extend(
            itertools.chain(*[
                self.queries_used_by(code)
                for ((ht, op_name), code) in self.handle_updates.items()
            ]))
        queries_to_keep = set(q for q in g.reachable_nodes(roots)
                              if isinstance(q, str))

        # remove old specs
        for q in list(self.query_specs):
            if q.name not in queries_to_keep:
                self.query_specs.remove(q)

        # remove old implementations
        for qname in list(self.query_impls.keys()):
            if qname not in queries_to_keep:
                del self.query_impls[qname]

        # remove old state vars
        # Collect the variables read by the surviving implementations once,
        # instead of recomputing free_vars() for every state variable.
        vars_still_read = set()
        for impl in self.query_impls.values():
            vars_still_read.update(free_vars(impl))
        self._concretization_functions = [
            entry for entry in self._concretization_functions
            if entry[0] in vars_still_read
        ]

        # remove old method implementations
        # Built once; the state list does not change inside the loop below.
        live_state_vars = {
            var for (var, exp) in self._concretization_functions
        }
        for k in list(self.updates.keys()):
            v, op_name = k
            if v not in live_state_vars:
                del self.updates[k]

    def queries_used_by(self, stm):
        """Yield the name of every known query called inside `stm`.

        A name is yielded once for each `ECall` produced by `all_exps(stm)`
        whose function is the name of a query in `self.query_specs`, so the
        same name can be yielded several times.

        The set of known query names is taken from `self.query_specs` once,
        when iteration starts, rather than being rebuilt for every expression
        visited.
        """
        known_query_names = {q.name for q in self.query_specs}
        for e in all_exps(stm):
            if isinstance(e, ECall) and e.func in known_query_names:
                yield e.func

    def states_maintained_by(self, q: Query) -> [EVar]:
        """Return the state variables whose update code calls query `q`.

        One entry is produced per (state variable, op) pair in `self.updates`
        whose statement uses `q.name`, in the iteration order of
        `self.updates`; a variable maintained by several ops therefore
        appears several times.
        """
        return [
            state_var
            for (state_var, _op_name), update_code in self.updates.items()
            if q.name in self.queries_used_by(update_code)
        ]