示例#1
0
def find_refinement(ast, state_map, lib, assumptions):
    """Look for a state variable of `ast` with a non-trivial refinement.

    The solver assumptions passed to `lib.impls` combine, in order:
    the caller's `assumptions`, the spec's own `ast.assumptions`, and one
    equality `EVar(name) == exp` per entry of `state_map`.

    State variables are examined in `ast.statevars` order.  For each
    `(name, type)` the candidate implementations from `lib.impls` are
    collected.  The first variable whose candidates are anything other than
    exactly one item equal to its own type is returned as
    `(name, candidates)`.  If no such variable exists, None is returned.
    """
    state_equalities = (
        EEq(EVar(var).with_type(exp.type), exp)
        for var, exp in state_map.items())
    combined = EAll(
        itertools.chain(assumptions, ast.assumptions, state_equalities))
    for var_name, var_type in ast.statevars:
        target = EVar(var_name).with_type(var_type)
        candidates = list(lib.impls(target, assumptions=combined))
        only_itself = len(candidates) == 1 and candidates[0] == var_type
        if not only_itself:
            return (var_name, candidates)
    return None
示例#2
0
    def visit_EUnaryOp(self, e):
        """Lower the unary-operator expression `e` into statements.

        Statements needed to compute the result are appended to `self.stms`.
        The return value is an expression (often a fresh variable) holding
        the result.

        - Distinct is delegated to `self.visit_iterable`; The to
          `self.find_one`.
        - Sum becomes an accumulator loop.
        - Length, All, Any and Exists are rewritten in terms of Sum / Empty
          and visited again.
        - Empty becomes a boolean flag plus an early-exit loop.
        - Every other operator falls through to `self.visit_Exp`.
        """
        op = e.op
        if op == UOp.Distinct:
            return self.visit_iterable(e)
        elif op == UOp.The:
            return self.find_one(e.e)
        elif op == UOp.Sum:
            sum_var = fresh_var(e.type, "sum")
            loop_var = fresh_var(e.e.type.elem_type, "x")
            # The accumulator has the type of the whole Sum expression, which
            # need not be INT (e.g. a sum of floats).  The addition is
            # therefore typed with e.type as well.
            self.stms.append(simplify_and_optimize(seq([
                SDecl(sum_var, ENum(0).with_type(e.type)),
                SForEach(loop_var, e.e,
                    SAssign(sum_var, EBinOp(sum_var, "+", loop_var).with_type(e.type)))])))
            return sum_var
        elif op == UOp.Length:
            # len(xs) == sum(1 for x in xs)
            arg = EVar("x").with_type(e.e.type.elem_type)
            return self.visit(EUnaryOp(UOp.Sum, EMap(e.e, ELambda(arg, ONE)).with_type(INT_BAG)).with_type(INT))
        elif op == UOp.All:
            # all(xs) == empty([x for x in xs if not x]).
            # Filtering keeps the collection's own type (a collection of
            # booleans, not INT_BAG), and Empty produces a BOOL.  With the
            # wrong type, the Empty lowering below would declare its loop
            # variable with the wrong element type.
            arg = EVar("x").with_type(e.e.type.elem_type)
            return self.visit(EUnaryOp(UOp.Empty, EFilter(e.e, ELambda(arg, ENot(arg))).with_type(e.e.type)).with_type(BOOL))
        elif op == UOp.Any:
            # any(xs) == exists([x for x in xs if x]); same typing as All.
            arg = EVar("x").with_type(e.e.type.elem_type)
            return self.visit(EUnaryOp(UOp.Exists, EFilter(e.e, ELambda(arg, arg)).with_type(e.e.type)).with_type(BOOL))
        elif op == UOp.Empty:
            # v = true; for x in iterable: { v = false; break out of block }
            iterable = e.e
            v = fresh_var(BOOL, "v")
            label = fresh_name("label")
            x = fresh_var(iterable.type.elem_type, "x")
            decl = SDecl(v, ETRUE)
            find = SEscapableBlock(label,
                SForEach(x, iterable, seq([
                    SAssign(v, EFALSE),
                    SEscapeBlock(label)])))
            self.stms.append(simplify_and_optimize(seq([decl, find])))
            return v
        elif op == UOp.Exists:
            return self.visit(ENot(EUnaryOp(UOp.Empty, e.e).with_type(BOOL)))

        # UOp.AreUnique and any other operator are not lowered here.
        return self.visit_Exp(e)
示例#3
0
def pull_temps(s: Stm, decls_out: [SDecl], exp_is_bad) -> Stm:
    """Pull "bad" expressions out of statement `s` into temporary declarations.

    Some expressions appear directly in the statement positions handled
    below: if-conditions, loop collections, assignment right-hand sides,
    call arguments, and map keys/values.  Each such expression for which
    `exp_is_bad(e)` is true is replaced by a fresh variable, and an
    `SDecl` for that variable is appended to `decls_out`.  Presumably the
    caller emits those declarations ahead of the returned statement.

    Inside `for x in bag`, a temporary whose value depends on `x` cannot be
    pulled out unchanged.  This includes temporaries that depend on `x`
    only through another such temporary.  Each one instead becomes a map
    `{x: value for x in bag}` declared outside the loop, and its uses in the
    loop body become lookups into that map.

    Inside a map update, a temporary that mentions the update's value
    variable has that variable replaced by a lookup of the map at the key.

    Parameters:
        s: the statement to transform.
        decls_out: receives the new declarations (mutated in place).
        exp_is_bad: predicate selecting the expressions to pull out.

    Returns:
        The transformed statement.

    Raises:
        NotImplementedError: for statement kinds not handled here.
    """
    def pull(e: Exp) -> Exp:
        # Replace `e` by a fresh variable (declared in decls_out) if it is bad.
        if exp_is_bad(e):
            v = fresh_var(e.type)
            decls_out.append(SDecl(v.id, e))
            return v
        return e

    if isinstance(s, SNoOp):
        return s
    if isinstance(s, SSeq):
        s1 = pull_temps(s.s1, decls_out, exp_is_bad)
        s2 = pull_temps(s.s2, decls_out, exp_is_bad)
        return SSeq(s1, s2)
    if isinstance(s, SIf):
        # NOTE: temporaries from both branches land in decls_out alongside
        # the condition's, i.e. they are not guarded by the condition.
        cond = pull(s.cond)
        s1 = pull_temps(s.then_branch, decls_out, exp_is_bad)
        s2 = pull_temps(s.else_branch, decls_out, exp_is_bad)
        return SIf(cond, s1, s2)
    if isinstance(s, SForEach):
        bag = pull(s.iter)
        d_tmp = []
        body = pull_temps(s.body, d_tmp, exp_is_bad)
        # Maps the id of each loop-dependent temporary to the expression that
        # replaces it: a lookup, keyed by the loop variable, into a map
        # declared outside the loop.
        replacements = {}
        # d_tmp is in creation order, and a temporary can only mention
        # temporaries created before it.  A single pass therefore finds every
        # temporary that depends on the loop variable, directly or via
        # another loop-dependent temporary.  Emitting in d_tmp order keeps
        # every declaration after the declarations it refers to.
        for d in d_tmp:
            fvs = free_vars(d.val)
            loop_dependent = s.id in fvs or any(fv.id in replacements for fv in fvs)
            if not loop_dependent:
                decls_out.append(d)
                continue
            # Rewrite references to earlier loop-dependent temporaries (which
            # will no longer be declared) into their map lookups.
            val = subst(d.val, replacements) if replacements else d.val
            v = EVar(d.id).with_type(d.val.type)
            mt = TMap(s.id.type, v.type)
            m = EMakeMap2(bag, ELambda(s.id, val)).with_type(mt)
            mv = fresh_var(m.type)
            decls_out.append(SDecl(mv.id, m))
            replacements[d.id] = EMapGet(mv, s.id).with_type(v.type)
        if replacements:
            body = subst(body, replacements)
        return SForEach(s.id, bag, body)
    if isinstance(s, SAssign):
        return SAssign(s.lhs, pull(s.rhs))
    if isinstance(s, SCall):
        return SCall(s.target, s.func, tuple(pull(arg) for arg in s.args))
    if isinstance(s, SMapDel):
        return SMapDel(s.map, pull(s.key))
    if isinstance(s, SMapPut):
        return SMapPut(s.map, pull(s.key), pull(s.value))
    if isinstance(s, SMapUpdate):
        key = pull(s.key)
        d_tmp = []
        change = pull_temps(s.change, d_tmp, exp_is_bad)
        for d in d_tmp:
            if s.val_var in free_vars(d.val):
                # Outside the update, val_var is unbound.  Read the current
                # value from the map instead.
                decls_out.append(
                    SDecl(
                        d.id,
                        subst(
                            d.val, {
                                s.val_var.id:
                                EMapGet(s.map, key).with_type(s.val_var.type)
                            })))
            else:
                decls_out.append(d)
        return SMapUpdate(s.map, key, s.val_var, change)
    raise NotImplementedError(s)
示例#4
0
文件: java.py 项目: wcphkust/cozy
 def visit_Query(self, q):
     """Emit the Java method for query `q`.

     Non-public queries produce no output ("" is returned).  A query whose
     return type is a collection becomes a `void` method with an extra
     `_callback` parameter of type `java.util.function.Consumer<Elem>`; each
     element of the result is passed to `_callback.accept`.  Any other query
     becomes a method that returns its value directly.
     """
     if q.visibility != Visibility.Public:
         return ""
     ret_type = q.ret.type
     # Both shapes of method start the same way: optional docstring, then a
     # new statement.
     if q.docstring:
         self.write(indent_lines(q.docstring, self.get_indent()), "\n")
     self.begin_statement()
     if is_collection(ret_type):
         x = EVar(self.fn("x")).with_type(ret_type.elem_type)
         self.write("public ", self.visit(TNative("void"), q.name), "(")
         callback_type = TNative("java.util.function.Consumer<{t}>".format(t=self.visit(ret_type.elem_type, "")))
         self.visit_args(itertools.chain(q.args, [("_callback", callback_type)]))
         self.write(") ")
         with self.block():
             # for each x in q.ret: _callback.accept(x);
             self.visit(simplify_and_optimize(SForEach(x, q.ret, SEscape("{indent}_callback.accept({x});\n", ["x"], [x]))))
     else:
         self.write("public ", self.visit(ret_type, q.name), "(")
         self.visit_args(q.args)
         self.write(") ")
         with self.block():
             self.visit(simplify_and_optimize(SReturn(q.ret)))
     self.end_statement()
示例#5
0
def queries_equivalent(q1: Query,
                       q2: Query,
                       state_vars: [EVar],
                       extern_funcs: {str: TFunc},
                       assumptions: Exp = T):
    """Determine whether two queries always return the same result.

    Queries with different return types or different argument lists are
    never equivalent.  Otherwise, a solver over the state variables and the
    query arguments (with `extern_funcs` and `assumptions`) must prove two
    things: the queries' preconditions are equivalent, and whenever they hold
    the return expressions are equal.

    Solvers are cached in `_qe_cache`.  The key covers every input the solver
    is built from: state vars, extern funcs, args and assumptions.
    """
    with task("checking query equivalence", q1=q1.name, q2=q2.name):
        if q1.ret.type != q2.ret.type:
            return False
        q1args = dict(q1.args)
        q2args = dict(q2.args)
        if q1args != q2args:
            return False
        args = FrozenDict(q1args)
        # Materialize once: the sequence is used both in the cache key and
        # (on a miss) to build the solver.
        state_vars = tuple(state_vars)

        # The cached solver depends on the state variables and extern
        # functions too.  Leaving them out of the key would let a solver
        # built for one set of variables/functions be reused for another.
        key = (state_vars, FrozenDict(extern_funcs), args, assumptions)
        checker = _qe_cache.get(key)
        if checker is None:
            checker = ModelCachingSolver(vars=list(
                itertools.chain(state_vars, (EVar(v).with_type(t)
                                             for v, t in args.items()))),
                                         funcs=extern_funcs,
                                         assumptions=assumptions)
            _qe_cache[key] = checker

        q1a = EAll(q1.assumptions)
        q2a = EAll(q2.assumptions)
        return checker.valid(EEq(q1a, q2a)) and checker.valid(
            EImplies(q1a, EEq(q1.ret, q2.ret)))
示例#6
0
文件: misc.py 项目: wcphkust/cozy
def queries_equivalent(q1: Query,
                       q2: Query,
                       state_vars: [EVar],
                       extern_funcs: {str: TFunc},
                       assumptions: Exp = ETRUE):
    """Decide whether two queries are interchangeable.

    Two queries are considered equivalent when both of the following hold:

    1. Their preconditions (`assumptions`) are semantically equivalent.
    2. Whenever those preconditions hold, their return expressions are equal.

    Condition 1 is required because a query means "if my preconditions hold,
    I compute and return my body expression".  Without it, one query could
    be obliged to return a value where the other has no defined behavior.

    Queries with different return types or argument lists are rejected
    without consulting a solver.
    """
    with task("checking query equivalence", q1=q1.name, q2=q2.name):
        # Cheap syntactic checks first.
        if q1.ret.type != q2.ret.type or dict(q1.args) != dict(q2.args):
            return False

        arg_vars = [EVar(arg_name).with_type(arg_type)
                    for (arg_name, arg_type) in q1.args]
        ctx = RootCtx(state_vars=state_vars, args=arg_vars, funcs=extern_funcs)
        solver = solver_for_context(context=ctx, assumptions=assumptions)

        pre1 = EAll(q1.assumptions)
        pre2 = EAll(q2.assumptions)
        preconditions_agree = solver.valid(EEq(pre1, pre2))
        return preconditions_agree and solver.valid(
            EImplies(pre1, EEq(q1.ret, q2.ret)))
示例#7
0
    def subset_of(self, other) -> bool:
        """Check subset inclusion against `other` on every recorded example.

        `EIsSubset` is evaluated on each pair of corresponding outputs
        (`self.outputs[i]`, `other.outputs[i]`).  A True result means only
        that the relation held on all examples.  An expression with this
        fingerprint might therefore always produce a subset of the elements
        produced by an expression with the other fingerprint.

        Raises:
            ValueError: if either fingerprint is not for a collection-type
                expression.  `_require_comparable_to` may raise as well.
        """
        if not is_collection(self.type):
            raise ValueError(
                "this fingerprint is not for a collection-type expression")
        if not is_collection(other.type):
            raise ValueError(
                "other fingerprint is not for a collection-type expression")
        self._require_comparable_to(other)

        lhs = EVar("x").with_type(self.type)
        rhs = EVar("y").with_type(other.type)
        check = EIsSubset(lhs, rhs)
        environments = []
        for mine, theirs in zip(self.outputs, other.outputs):
            environments.append({lhs.id: mine, rhs.id: theirs})
        return all(eval_bulk(check, environments))
示例#8
0
        def reconcile_jobs():
            """Bring the running improvement jobs in line with the current queries.

            A job is started for every query in `impl.query_specs` that does
            not yet have one.  `stop_jobs` is called on every job whose query
            no longer appears there (e.g. after dead-code elimination)."""

            # Queries that still need a job.
            covered = {job.q.name for job in improvement_jobs}
            fresh_jobs = []
            for query in impl.query_specs:
                if query.name in covered:
                    continue
                maintained = impl.states_maintained_by(query)
                print("STARTING IMPROVEMENT JOB {}".format(query.name))
                # Arguments are computed in the same order the job constructor
                # call evaluated them originally.
                state = impl.abstract_state
                job_assumptions = list(impl.spec.assumptions) + list(query.assumptions)
                job_context = impl.context_for_method(query)
                job_solutions = solutions_q.handle_for_subjobs()
                job_hints = [
                    EStateVar(c).with_type(c.type)
                    for c in impl.concretization_functions.values()
                ]
                job_freebies = [
                    e for (v, e) in impl.concretization_functions.items()
                    if EVar(v) in maintained
                ]
                fresh_jobs.append(
                    ImproveQueryJob(
                        state,
                        job_assumptions,
                        query,
                        context=job_context,
                        solutions_q=job_solutions,
                        hints=job_hints,
                        freebies=job_freebies,
                        ops=impl.op_specs,
                        improve_count=improve_count))

            # Jobs whose query has disappeared.
            live_names = {query.name for query in impl.query_specs}
            stale_jobs = [
                job for job in improvement_jobs if job.q.name not in live_names
            ]

            stop_jobs(stale_jobs)
            for job in fresh_jobs:
                job.start()
            improvement_jobs.extend(fresh_jobs)
示例#9
0
    def visit_Spec(self,
                   spec: Spec,
                   state_exps: {str: Exp},
                   sharing,
                   abstract_state=()):
        """Emit a C++ header implementing `spec` as a class.

        Output order:
        1. `#pragma once` and the standard includes (QHash or unordered_map
           depending on `self.use_qhash`).
        2. The spec header, if any.
        3. The class with its docstring, containing:
           - auxiliary types, each followed by a `_Hash<Name>` functor
             (except handle types);
           - protected state variables;
           - a default constructor initializing every state variable;
           - an explicit constructor over `abstract_state`, only when it is
             non-empty;
           - a deleted copy constructor;
           - one member per method in `spec.methods`.
        4. The spec footer, if any.

        Parameters:
            spec: the specification to generate code for.
            state_exps: maps each state variable name to the expression
                computing its initial value.
            sharing: passed through to `setup_types` and `define_type`.
            abstract_state: (name, type) arguments of the explicit
                constructor.
        """
        self.state_exps = state_exps
        self.funcs = {f.name: f for f in spec.extern_funcs}
        self.queries = {
            q.name: q
            for q in spec.methods if isinstance(q, Query)
        }
        self.vars = set(e.id for e in all_exps(spec) if isinstance(e, EVar))

        self.write("#pragma once\n")
        self.write("#include <algorithm>\n")
        self.write("#include <set>\n")
        self.write("#include <functional>\n")
        self.write("#include <vector>\n")
        self.write("#include <unordered_set>\n")
        self.write("#include <string>\n")
        if self.use_qhash:
            self.write("#include <QHash>\n")
        else:
            self.write("#include <unordered_map>\n")

        if spec.header:
            self.write("\n" + spec.header.strip() + "\n")

        self.write("{}\nclass {} {{\n".format(
            ("\n" + spec.docstring) if spec.docstring else "", spec.name))

        self.write("public:\n")

        print("Setting up auxiliary types...")
        self.setup_types(spec, state_exps, sharing)
        with self.indented():
            for t, name in self.types.items():
                self.define_type(spec.name, t, name, sharing)
                if isinstance(t, THandle):
                    # No overridden hash code! We use pointers instead.
                    # Checked before begin_statement() so that every
                    # begin_statement() below is matched by end_statement().
                    continue
                self.begin_statement()
                self.write("struct _Hash", name, " ")
                with self.block():
                    self.write_stmt("typedef ", spec.name, "::", name,
                                    " argument_type;")
                    self.write_stmt("typedef std::size_t result_type;")
                    self.begin_statement()
                    self.write(
                        "result_type operator()(const argument_type& x) const noexcept "
                    )
                    # The hash combines the component values of x.
                    x = EVar("x").with_type(t)
                    if isinstance(t, TEnum):
                        fields = [EEnumToInt(x).with_type(INT)]
                    elif isinstance(t, TRecord):
                        fields = [
                            EGetField(x, f).with_type(ft)
                            for (f, ft) in t.fields
                        ]
                    elif isinstance(t, TTuple):
                        fields = [
                            ETupleGet(x, n).with_type(tt)
                            for (n, tt) in enumerate(t.ts)
                        ]
                    else:
                        raise NotImplementedError(t)
                    with self.block():
                        self.visit(self.compute_hash(fields))
                    self.end_statement()
                self.write(";")
                self.end_statement()

        print("Setting up member variables...")
        self.write("protected:\n")
        with self.indented():
            for name, t in spec.statevars:
                self.statevar_name = name
                self.declare_field(name, t)

        self.write("public:\n")
        with self.indented():
            print("Generating constructors...")

            # default constructor: free variables in each initial-value
            # expression are replaced by evaluation.construct_value(type).
            self.begin_statement()
            self.write("inline ", spec.name, "() ")
            with self.block():
                for name, t in spec.statevars:
                    initial_value = state_exps[name]
                    fvs = free_vars(initial_value)
                    initial_value = subst(initial_value, {
                        v.id: evaluation.construct_value(v.type)
                        for v in fvs
                    })
                    stm = simplify_and_optimize(
                        SAssign(EVar(name).with_type(t), initial_value))
                    self.visit(stm)
            self.end_statement()

            # explicit constructor
            if abstract_state:
                self.begin_statement()
                self.write("explicit inline ", spec.name, "(")
                self.visit_args(abstract_state)
                self.write(") ")
                with self.block():
                    for name, t in spec.statevars:
                        initial_value = state_exps[name]
                        self.visit(
                            simplify_and_optimize(
                                SAssign(
                                    EVar(name).with_type(t), initial_value)))
                self.end_statement()

            # disable copy constructor (TODO: support this in the future?)
            self.begin_statement()
            self.write(spec.name, "(const ", spec.name, "& other) = delete;")
            self.end_statement()

            # generate methods
            for op in spec.methods:
                print("Generating method {}...".format(op.name))
                self.visit(op)
        self.write("};\n")

        if spec.footer:
            self.write("\n", spec.footer)
            if not spec.footer.endswith("\n"):
                self.write("\n")
示例#10
0
    def define_type(self, toplevel_name, t, name, sharing):
        """Emit a C++ definition of auxiliary type `t` under the name `name`.

        - TEnum: an `enum` listing its cases.
        - THandle: a struct with a public `val` field.
        - TRecord: a struct with one field per record field, a default
          constructor, a field-wise constructor (when there is at least one
          field), and a field-wise `operator==`.
        - TTuple: emitted as a record with fields `_0`, `_1`, ...
        - anything else: nothing is emitted and "" is returned.

        `toplevel_name` and `sharing` are only forwarded in the TTuple case.
        """
        if isinstance(t, TEnum):
            self.begin_statement()
            self.write("enum ", name, " ")
            with self.block():
                for case in t.cases:
                    self.begin_statement()
                    self.write(case, ",")
                    self.end_statement()
            self.write(";")
            self.end_statement()
        elif isinstance(t, THandle):
            fields = [("val", t.value_type)]
            self.begin_statement()
            self.write("struct ", name, " ")
            with self.block():
                with self.deindented():
                    self.write_stmt("public:")
                for (f, ft) in fields:
                    self.declare_field(f, ft)
                with self.deindented():
                    self.write_stmt("private:")
            self.write(";")
            self.end_statement()
            # NOTE: intrusive data from `sharing` is not emitted for handles.
        elif isinstance(t, TRecord):
            self.begin_statement()
            self.write("struct ", name, " ")
            with self.block():

                # TODO: sort fields by size descending for better packing
                for f, ft in t.fields:
                    self.declare_field(f, ft)

                self.write_stmt("inline ", name, "() { }")
                # With no fields, a field-wise constructor would emit an empty
                # member-initializer list (`X() :  { }`, invalid C++) and
                # duplicate the default constructor above.
                if t.fields:
                    self.begin_statement()
                    self.write("inline ", name, "(")
                    self.visit_args([("_" + f, ft) for (f, ft) in t.fields])
                    self.write(") : ")
                    for i, (f, ft) in enumerate(t.fields):
                        if i > 0:
                            self.write(", ")
                        self.write(f, "(::std::move(_", f, "))")
                    self.write(" { }")
                    self.end_statement()

                self.begin_statement()
                self.write("inline bool operator==(const ", name,
                           "& other) const ")
                with self.block():
                    # Equality is the conjunction of field-wise equalities.
                    this = EEscape("(*this)", (), ()).with_type(t)
                    other = EVar("other").with_type(t)
                    r = self.visit(
                        EAll([
                            EEq(
                                EGetField(this, f).with_type(ft),
                                EGetField(other, f).with_type(ft))
                            for f, ft in t.fields
                        ]))
                    self.begin_statement()
                    self.write("return ", r, ";")
                    self.end_statement()
                self.end_statement()
            self.write(";")
            self.end_statement()
        elif isinstance(t, TTuple):
            return self.define_type(
                toplevel_name,
                TRecord(
                    tuple(
                        ("_{}".format(i), elem_type)
                        for i, elem_type in enumerate(t.ts))),
                name, sharing)
        else:
            return ""
示例#11
0
文件: java.py 项目: wcphkust/cozy
    def visit_Spec(self, spec, state_exps, sharing, abstract_state=()):
        """Emit a Java source file implementing `spec` as a public class.

        Output order:
        1. Guava imports (when `guava.value` is set).
        2. The spec header and docstring, if any.
        3. The class, containing:
           - protected state variables;
           - a no-argument constructor calling `clear()`;
           - an explicit constructor over `abstract_state` (only when
             non-empty);
           - `clear()`, which re-initializes every state variable (free
             variables in initial values replaced by
             `evaluation.construct_value`);
           - one member per method;
           - the auxiliary types.
        4. The spec footer.

        Parameters:
            spec: the specification to generate code for.
            state_exps: maps each state variable name to its initial-value
                expression.
            sharing: passed through to `setup_types` and `define_type`.
            abstract_state: (name, type) arguments of the explicit
                constructor.
        """
        self.state_exps = state_exps
        self.funcs = { f.name: f for f in spec.extern_funcs }
        self.queries = { q.name: q for q in spec.methods if isinstance(q, Query) }
        self.vars = set(e.id for e in all_exps(spec) if isinstance(e, EVar))
        self.setup_types(spec, state_exps, sharing)

        if guava.value:
            self.write("import com.google.common.collect.TreeMultiset;\n")
            self.write("import com.google.common.collect.Iterators;\n")
        if spec.header:
            self.write(spec.header.strip() + "\n\n")

        if spec.docstring:
            self.write(spec.docstring + "\n")

        self.write("public class {} implements java.io.Serializable ".format(spec.name))
        with self.block():

            for name, t in spec.types:
                self.types[t] = name

            # member variables
            for name, t in spec.statevars:
                self.write("{}protected {};\n".format(INDENT, self.visit(t, name)))

            # constructor
            self.write(
                "{indent}public {name}() {{\n{indent2}clear();\n{indent}}}\n\n"
                .format(indent=INDENT, indent2=INDENT+INDENT, name=spec.name))

            # explicit constructor
            if abstract_state:
                self.begin_statement()
                self.write("public ", spec.name, "(")
                self.visit_args(abstract_state)
                self.write(") ")
                with self.block():
                    for name, t in spec.statevars:
                        initial_value = state_exps[name]
                        self.visit(simplify_and_optimize(SAssign(EVar(name).with_type(t), initial_value)))
                self.end_statement()

            # clear
            self.begin_statement()
            self.write("public void clear() ")
            with self.block():
                for name, t in spec.statevars:
                    initial_value = state_exps[name]
                    fvs = free_vars(initial_value)
                    initial_value = subst(initial_value,
                        {v.id : evaluation.construct_value(v.type) for v in fvs})
                    setup = simplify_and_optimize(SAssign(EVar(name).with_type(t), initial_value))
                    self.visit(setup)
            self.end_statement()

            # methods
            for op in spec.methods:
                self.visit(op)

            # generate auxiliary types
            for t, name in self.types.items():
                self.define_type(spec.name, t, name, sharing)

        self.write("\n")
        # Tolerate a missing footer, like the guarded `spec.header` above.
        # For any string footer the output is unchanged.
        footer = spec.footer or ""
        self.write(footer)
        if not footer.endswith("\n"):
            self.write("\n")
示例#12
0
文件: java.py 项目: wcphkust/cozy
    def define_type(self, toplevel_name, t, name, sharing):
        """Emit a Java definition of auxiliary type `t` under the name `name`.

        - TEnum: a `public enum` whose constants are the enum's cases.
        - THandle / TRecord: a `public static class` (Serializable) with
          private fields and public getters.
          - A handle whose value type is a record extends that record's class,
            and `getVal()` returns `this`.
          - Any other handle stores its value in a `val` field.
          - Handles also get private fields for the intrusive data of every
            type sharing them (from `sharing`).
          - Records get value-based `hashCode`/`equals`; handles do not
            override them.
        - TTuple: emitted as a record with fields `_0`, `_1`, ...
        - anything else: nothing is emitted and "" is returned.
        """
        if isinstance(t, TEnum):
            self.begin_statement()
            self.write("public enum ", name, " ")
            with self.block():
                for case in t.cases:
                    self.begin_statement()
                    # Java enum constants must be comma-separated.  A trailing
                    # comma after the last constant is permitted.
                    self.write(case, ",")
                    self.end_statement()
            self.end_statement()
        elif isinstance(t, THandle) or isinstance(t, TRecord):
            public_fields = []
            private_fields = []
            value_equality = True
            handle_val_is_this = False
            if isinstance(t, THandle):
                if isinstance(t.value_type, TRecord):
                    handle_val_is_this = True
                else:
                    public_fields = [("val", t.value_type)]
                value_equality = False
                for group in sharing.get(t, []):
                    for gt in group:
                        intrusive_data = gt.intrusive_data(t)
                        for (f, ft) in intrusive_data:
                            private_fields.append((f, ft))
            else:
                public_fields = list(t.fields)
            all_fields = public_fields + private_fields
            self.begin_statement()
            self.write("public static class ", name)
            if handle_val_is_this:
                self.write(" extends ", self.visit(t.value_type, ""))
            self.write(" implements java.io.Serializable ")
            with self.block():
                for (f, ft) in all_fields:
                    self.begin_statement()
                    self.write("private {field_decl};".format(field_decl=self.visit(ft, f)))
                    self.end_statement()
                for (f, ft) in public_fields:
                    self.begin_statement()
                    self.write("public {type} get{Field}() {{ return {field}; }}".format(
                        type=self.visit(ft, ""),
                        Field=common.capitalize(f),
                        field=f))
                    self.end_statement()
                if handle_val_is_this:
                    self.begin_statement()
                    self.write("public {type} getVal() {{ return this; }}".format(
                        type=self.visit(t.value_type, "")))
                    self.end_statement()

                def flatten(field_types):
                    # Expand record/tuple types into one constructor argument
                    # per scalar component.  Returns (args, exps), where exps
                    # rebuilds one value per entry of field_types from args.
                    args = []
                    exps = []
                    for ft in field_types:
                        if isinstance(ft, TRecord):
                            aa, ee = flatten([fft for (_, fft) in ft.fields])
                            args.extend(aa)
                            exps.append(EMakeRecord(tuple((f, e) for ((f, _), e) in zip(ft.fields, ee))).with_type(ft))
                        elif isinstance(ft, TTuple):
                            aa, ee = flatten(ft.ts)
                            args.extend(aa)
                            exps.append(ETuple(tuple(ee)).with_type(ft))
                        else:
                            v = self.fv(ft, "v")
                            args.append((v.id, ft))
                            exps.append(v)
                    return args, exps

                if isinstance(t, THandle):
                    args, exps = flatten([ft for (f, ft) in (t.value_type.fields if handle_val_is_this else public_fields)])
                else:
                    args = public_fields
                    exps = [EVar(f) for (f, ft) in args]
                self.begin_statement()
                self.write("public {ctor}({args}) ".format(ctor=name, args=", ".join(self.visit(ft, f) for (f, ft) in args)))
                with self.block():
                    if handle_val_is_this:
                        es = [self.visit(e) for e in exps]
                        self.begin_statement()
                        self.write("super({args});\n".format(
                            args=", ".join(es)))
                    for ((f, ft), e) in zip(public_fields, exps):
                        e = self.visit(e)
                        self.begin_statement()
                        self.write("this.{f} = {e};\n".format(f=f, e=e))
                self.end_statement()

                if value_equality:
                    self.begin_statement()
                    self.write("@Override")
                    self.end_statement()
                    self.begin_statement()
                    self.write("public int hashCode() ")
                    with self.block():
                        (compute, hc) = self.compute_hash(all_fields)
                        self.write(compute)
                        self.begin_statement()
                        self.write("return ", hc, ";")
                        self.end_statement()
                    self.end_statement()

                    self.begin_statement()
                    self.write("@Override")
                    self.end_statement()
                    self.begin_statement()
                    self.write("public boolean equals(Object other) ")
                    with self.block():
                        self.write(self.get_indent(), "if (other == null) return false;\n")
                        self.write(self.get_indent(), "if (other == this) return true;\n")
                        self.write(self.get_indent(), "if (!(other instanceof {name})) return false;\n".format(name=name))
                        self.write(self.get_indent(), "{name} o = ({name})other;\n".format(name=name))
                        eq = self.visit(EAll([EEq(
                            EEscape("this.{}".format(f), (), ()).with_type(ft),
                            EEscape("o.{}".format(f),    (), ()).with_type(ft))
                            for (f, ft) in all_fields]))
                        self.write(self.get_indent(), "return ", eq, ";\n")
                    self.end_statement()
            self.end_statement()
        elif isinstance(t, TTuple):
            return self.define_type(toplevel_name, TRecord(tuple(("_{}".format(i), elem_type) for i, elem_type in enumerate(t.ts))), name, sharing)
        else:
            return ""
示例#13
0
 def context_for_method(self, m: Method) -> Context:
     """Build the root context for expressions inside method `m`.

     The context's state variables are `self.abstract_state`.  Its arguments
     are `m`'s parameters as typed variables, and its functions are
     `self.extern_funcs`.
     """
     state = self.abstract_state
     params = [EVar(arg_name).with_type(arg_type) for (arg_name, arg_type) in m.args]
     return RootCtx(state_vars=state, args=params, funcs=self.extern_funcs)
示例#14
0
 def abstract_state(self) -> [EVar]:
     """Return one typed EVar per entry of `self.spec.statevars`, in order."""
     declared = self.spec.statevars
     return [EVar(var_name).with_type(var_type) for var_name, var_type in declared]
示例#15
0
def search_for_improvements(targets: [Exp], wf_solver: ModelCachingSolver,
                            context: Context, examples: [{
                                str: object
                            }], cost_model: CostModel,
                            stop_callback: Callable[[], bool], hints: [Exp],
                            ops: [Op], blacklist: {
                                (Exp, Context, Pool, Exp): str
                            }):
    """Search for potential improvements to any of the target expressions.

    This function yields expressions that look like improvements (or are
    ambiguous with respect to some target).  The expressions are only
    guaranteed to be correct on the given examples.

    This function may add new items to the given blacklist.

    Parameters:
        targets: expressions to improve; ``targets[0]`` is the reference used
            for the target fingerprint and for cost pruning.
        wf_solver: solver used by the well-formedness / usefulness checks.
        context: root context in which the targets live.
        examples: example environments used to fingerprint expressions.
        cost_model: used to prune candidates and compare rewrites.
        stop_callback: polled during the search; when it returns true the
            search raises ``StopException``.
        hints: extra expressions whose subexpressions seed the enumerator.
        ops: forwarded to ``possibly_useful``.
        blacklist: maps ``(e, ctx, pool, replacement)`` to a reason string.

    Raises:
        StopException: when ``stop_callback()`` returns true, or when the
            enumerator reports that no larger expressions can exist.
    """

    root_ctx = context

    def check_wf(e, ctx, pool):
        """Return True to keep candidate `e`, or a falsy reason to prune it."""
        with task("pruning", size=e.size()):
            is_wf = exp_wf(e, pool=pool, context=ctx, solver=wf_solver)
            if not is_wf:
                return is_wf
            res = possibly_useful(wf_solver, e, ctx, pool, ops=ops)
            if not res:
                return res
            # Optional cost pruning: runtime-pool candidates that are already
            # strictly more expensive than the first target are dropped.
            if cost_pruning.value and pool == RUNTIME_POOL and cost_model.compare(
                    e, targets[0], ctx, pool) == Order.GT:
                return No("too expensive")
            return True

    with task("setting up hints"):
        frags = list(
            unique(
                itertools.chain(
                    *[
                        all_subexpressions_with_context_information(
                            t, root_ctx) for t in targets
                    ], *[
                        all_subexpressions_with_context_information(
                            h, root_ctx) for h in hints
                    ])))
        frags.sort(key=hint_order)
        enum = Enumerator(examples=examples,
                          cost_model=cost_model,
                          check_wf=check_wf,
                          hints=frags,
                          heuristics=try_optimize,
                          stop_callback=stop_callback,
                          do_eviction=enable_eviction.value)

    target_fp = Fingerprint.of(targets[0], examples)

    with task("setting up watches"):
        # Every runtime-pool subexpression of every target, grouped by the
        # context it appears in.
        watches_by_context = OrderedDict()
        for target in targets:
            for e, ctx, pool in unique(
                    all_subexpressions_with_context_information(
                        target, context=root_ctx, pool=RUNTIME_POOL)):
                watches_by_context.setdefault(ctx, []).append(
                    (target, e, pool))

        # (fingerprint, ctx, pool) -> [(target, subexpression)].  Examples are
        # instantiated once per context rather than once per subexpression.
        watches = OrderedDict()
        for ctx, exprs in watches_by_context.items():
            exs = ctx.instantiate_examples(examples)
            for target, e, pool in exprs:
                fp = Fingerprint.of(e, exs)
                watches.setdefault((fp, ctx, pool), []).append((target, e))

        # Index the same entries by (ctx, pool) so the search loop only scans
        # watches that can match the context/pool being enumerated, rather
        # than visiting and filtering every entry for every candidate.
        # Insertion order is preserved within each group.
        watches_by_ctx_pool = OrderedDict()
        for (fpx, cc, pp), reses in watches.items():
            watches_by_ctx_pool.setdefault((cc, pp), []).append((fpx, reses))

        watched_ctxs = list(
            unique(
                (ctx, pool)
                for _, _, ctx, pool in exploration_order(targets, root_ctx)))

    search_info = SearchInfo(context=root_ctx,
                             targets=targets,
                             target_fingerprint=target_fp,
                             examples=examples,
                             check_wf=check_wf,
                             cost_model=cost_model,
                             blacklist=blacklist)

    size = 0
    while True:

        print("starting minor iteration {} with |cache|={}".format(
            size, enum.cache_size()))
        if stop_callback():
            raise StopException()

        for ctx, pool in watched_ctxs:
            candidate_watches = watches_by_ctx_pool.get((ctx, pool), ())
            with task("searching for obvious substitutions",
                      ctx=ctx,
                      pool=pool_name(pool)):
                for info in enum.enumerate_with_info(size=size,
                                                     context=ctx,
                                                     pool=pool):
                    with task("searching for obvious substitution",
                              expression=pprint(info.e)):
                        fp = info.fingerprint
                        replacement = info.e
                        for fpx, reses in candidate_watches:
                            # Same behavior on the examples => candidate swap.
                            if not fpx.equal_to(fp):
                                continue

                            for target, watched_e in reses:
                                event(
                                    "possible substitution: {} ---> {}".format(
                                        pprint(watched_e),
                                        pprint(replacement)))
                                event("replacement locations: {}".format(
                                    pprint(
                                        replace(target, root_ctx, RUNTIME_POOL,
                                                watched_e, ctx, pool,
                                                EVar("___")))))

                                if alpha_equivalent(watched_e, replacement):
                                    event("no change")
                                    continue

                                yield from _consider_replacement(
                                    target, watched_e, ctx, pool, replacement,
                                    search_info)

        if check_blind_substitutions.value:
            print("Guessing at substitutions...")
            for target, e, ctx, pool in exploration_order(targets, root_ctx):
                with task("checking substitutions",
                          target=pprint(
                              replace(target, root_ctx, RUNTIME_POOL, e, ctx,
                                      pool, EVar("___"))),
                          e=pprint(e)):
                    # Fingerprint of `e` in `ctx`: it does not depend on the
                    # enumerated candidate, so compute it at most once here.
                    e_fp = None
                    for info in enum.enumerate_with_info(size=size,
                                                         context=ctx,
                                                         pool=pool):
                        with task("checking substitution",
                                  expression=pprint(info.e)):
                            if stop_callback():
                                raise StopException()
                            replacement = info.e
                            if replacement.type != e.type:
                                event("wrong type (is {}, need {})".format(
                                    pprint(replacement.type), pprint(e.type)))
                                continue
                            if alpha_equivalent(replacement, e):
                                event("no change")
                                continue
                            if e_fp is None:
                                e_fp = Fingerprint.of(
                                    e, ctx.instantiate_examples(examples))
                            should_consider = should_consider_replacement(
                                target, root_ctx, e, ctx, pool, e_fp,
                                info.e, info.fingerprint)
                            if not should_consider:
                                event(
                                    "skipped; `should_consider_replacement` returned {}"
                                    .format(should_consider))
                                continue

                            yield from _consider_replacement(
                                target, e, ctx, pool, replacement, search_info)

        if not enum.expressions_may_exist_above_size(context, RUNTIME_POOL,
                                                     size):
            raise StopException(
                "no more expressions can exist above size={}".format(size))

        size += 1
示例#16
0
    def search(self):
        """Enumerate rewrites of ``self.targets`` and yield improved targets.

        Expressions are enumerated by increasing size.  At each size, two
        strategies are tried:

        1. obvious substitutions: an enumerated expression whose fingerprint
           on the examples matches a subexpression of some target (in the
           same context and pool) replaces that subexpression;
        2. if ``check_all_substitutions`` is enabled, blind substitutions of
           every enumerated expression of the right type into each position
           reported by ``exploration_order``.

        A rewritten target is yielded when it is well-formed, matches
        ``self.targets[0]`` on the examples, and the cost model ranks it
        cheaper than (or ambiguous with) the target it was derived from.
        Rejected substitutions are recorded in ``self.blacklist`` with a
        reason string.

        Raises:
            StopException: when ``self.stop_callback()`` returns true.
        """

        root_ctx = self.context

        def check_wf(e, ctx, pool):
            """Return True to keep candidate `e`, or a falsy reason to prune it."""
            with task("checking well-formedness", size=e.size()):
                is_wf = exp_wf(e,
                               pool=pool,
                               context=ctx,
                               solver=self.wf_solver)
                if not is_wf:
                    return is_wf
                res = good_idea_recursive(self.wf_solver,
                                          e,
                                          ctx,
                                          pool,
                                          ops=self.ops)
                if not res:
                    return res
                # Runtime-pool candidates strictly more expensive than the
                # first target are pruned.
                if pool == RUNTIME_POOL and self.cost_model.compare(
                        e, self.targets[0], ctx, pool) == Order.GT:
                    return No("too expensive")
                return True

        frags = list(
            unique(
                itertools.chain(*[shred(t, root_ctx) for t in self.targets],
                                *[shred(h, root_ctx) for h in self.hints])))
        frags.sort(key=hint_order)
        enum = Enumerator(examples=self.examples,
                          cost_model=self.cost_model,
                          check_wf=check_wf,
                          hints=frags,
                          heuristics=try_optimize,
                          stop_callback=self.stop_callback,
                          do_eviction=enable_eviction.value)

        size = 0
        target_fp = fingerprint(self.targets[0], self.examples)

        # (fingerprint, ctx, pool) -> [(target, subexpression)] for every
        # runtime-pool subexpression of every target.
        watches = OrderedDict()
        for target in self.targets:
            for e, ctx, pool in unique(
                    shred(target, context=root_ctx, pool=RUNTIME_POOL)):
                exs = ctx.instantiate_examples(self.examples)
                fp = fingerprint(e, exs)
                watches.setdefault((fp, ctx, pool), []).append((target, e))
        watched_ctxs = list(
            unique((ctx, pool) for fp, ctx, pool in watches.keys()))

        # Same entries indexed by (ctx, pool) so the search loop only scans
        # watches that can match the context/pool being enumerated.
        # Insertion order is preserved within each group.
        watches_by_ctx_pool = OrderedDict()
        for (fpx, cc, pp), reses in watches.items():
            watches_by_ctx_pool.setdefault((cc, pp), []).append((fpx, reses))

        def consider_new_target(old_target, e, ctx, pool, replacement):
            """Yield `old_target` with `e` replaced, if the rewrite is an improvement.

            Yields at most one expression.  Every rejection is reported via
            `event` and, except for blacklist hits, recorded in the blacklist.
            """
            nonlocal n
            n += 1
            k = (e, ctx, pool, replacement)
            if enable_blacklist.value and k in self.blacklist:
                event("blacklisted")
                print("skipping blacklisted substitution: {} ---> {} ({})".
                      format(pprint(e), pprint(replacement),
                             self.blacklist[k]))
                return
            # Rewrite the target passed in, not the enclosing loop variable
            # `target` (they only coincided because callers drive this
            # generator immediately with `yield from`).
            new_target = freshen_binders(
                replace(old_target, root_ctx, RUNTIME_POOL, e, ctx, pool,
                        replacement), root_ctx)
            if any(alpha_equivalent(t, new_target) for t in self.targets):
                event("already seen")
                return
            wf = check_wf(new_target, root_ctx, RUNTIME_POOL)
            if not wf:
                msg = "not well-formed [wf={}]".format(wf)
                event(msg)
                self.blacklist[k] = msg
                return
            if not fingerprints_match(fingerprint(new_target, self.examples),
                                      target_fp):
                msg = "not correct"
                event(msg)
                self.blacklist[k] = msg
                return
            if self.cost_model.compare(new_target, old_target, root_ctx,
                                       RUNTIME_POOL) not in (Order.LT,
                                                             Order.AMBIGUOUS):
                msg = "not an improvement"
                event(msg)
                self.blacklist[k] = msg
                return
            print("FOUND A GUESS AFTER {} CONSIDERED".format(n))
            print(" * in {}".format(pprint(old_target)))
            print(" * replacing {}".format(pprint(e)))
            print(" * with {}".format(pprint(replacement)))
            yield new_target

        while True:

            print("starting minor iteration {} with |cache|={}".format(
                size, enum.cache_size()))
            if self.stop_callback():
                raise StopException()

            # Number of substitutions considered this iteration (updated by
            # consider_new_target via `nonlocal`).
            n = 0

            for ctx, pool in watched_ctxs:
                candidate_watches = watches_by_ctx_pool.get((ctx, pool), ())
                with task("searching for obvious substitutions",
                          ctx=ctx,
                          pool=pool_name(pool)):
                    for info in enum.enumerate_with_info(size=size,
                                                         context=ctx,
                                                         pool=pool):
                        with task("searching for obvious substitution",
                                  expression=pprint(info.e)):
                            fp = info.fingerprint
                            replacement = info.e
                            for fpx, reses in candidate_watches:
                                if not fingerprints_match(fpx, fp):
                                    continue

                                for target, watched_e in reses:
                                    event("possible substitution: {} ---> {}".
                                          format(pprint(watched_e),
                                                 pprint(replacement)))
                                    event("replacement locations: {}".format(
                                        pprint(
                                            replace(target, root_ctx,
                                                    RUNTIME_POOL, watched_e,
                                                    ctx, pool, EVar("___")))))

                                    if alpha_equivalent(
                                            watched_e, replacement):
                                        event("no change")
                                        continue

                                    yield from consider_new_target(
                                        target, watched_e, ctx, pool,
                                        replacement)

            if check_all_substitutions.value:
                print("Guessing at substitutions...")
                for target, e, ctx, pool in exploration_order(
                        self.targets, root_ctx):
                    with task("checking substitutions",
                              target=pprint(
                                  replace(target, root_ctx, RUNTIME_POOL, e,
                                          ctx, pool, EVar("___"))),
                              e=pprint(e)):
                        # Fingerprint of `e` in `ctx` is independent of the
                        # enumerated candidate; compute it at most once here.
                        e_fp = None
                        for info in enum.enumerate_with_info(size=size,
                                                             context=ctx,
                                                             pool=pool):
                            with task("checking substitution",
                                      expression=pprint(info.e)):
                                if self.stop_callback():
                                    raise StopException()
                                replacement = info.e
                                if replacement.type != e.type:
                                    event("wrong type (is {}, need {})".format(
                                        pprint(replacement.type),
                                        pprint(e.type)))
                                    continue
                                if alpha_equivalent(replacement, e):
                                    event("no change")
                                    continue
                                if e_fp is None:
                                    e_fp = fingerprint(
                                        e,
                                        ctx.instantiate_examples(
                                            self.examples))
                                should_consider = should_consider_replacement(
                                    target, root_ctx, e, ctx, pool, e_fp,
                                    info.e, info.fingerprint)
                                if not should_consider:
                                    event(
                                        "skipped; `should_consider_replacement` returned {}"
                                        .format(should_consider))
                                    continue

                                yield from consider_new_target(
                                    target, e, ctx, pool, replacement)

            print("CONSIDERED {}".format(n))
            size += 1