예제 #1
0
def decompose_boolean_combination(e):
    """Return the non-boolean leaves of a boolean combination.

    'and' and 'or' applications are flattened recursively and a 'not' is
    looked through (its single operand is decomposed). Any other expression
    is a leaf. Leaves are returned left to right, duplicates included.
    """
    is_junction = (exprs.is_application_of(e, 'and')
                   or exprs.is_application_of(e, 'or'))
    if is_junction:
        return [leaf
                for operand in e.children
                for leaf in decompose_boolean_combination(operand)]
    if exprs.is_application_of(e, 'not'):
        return decompose_boolean_combination(e.children[0])
    return [e]
예제 #2
0
    def _verify_expr(self, term):
        """Check a candidate *term* against ``self.outvar_cnstr`` with SMT.

        The candidate is installed as the interpretation of the synth
        function(s) in ``self.smt_ctx``. The constraint is then checked in a
        temporary solver scope that is always popped again, even on error.

        Returns:
            ``[cex_point]`` (a one-element list with the point built from the
            solver's model) when the check is sat; otherwise *term* itself.
        """
        smt_ctx = self.smt_ctx
        smt_solver = self.smt_solver

        if len(self.synth_funs) == 1:
            smt_ctx.set_interpretation(self.synth_funs[0], term)
        else:
            # Several synth functions: term is a ',' tuple holding one body
            # per function, in the order of self.synth_funs.
            assert exprs.is_application_of(term, ',')
            for f, t in zip(self.synth_funs, term.children):
                smt_ctx.set_interpretation(f, t)
        eq_cnstr = _expr_to_smt(self.outvar_cnstr, smt_ctx)

        model = None
        smt_solver.push()
        try:
            smt_solver.add(eq_cnstr)
            if smt_solver.check() == z3.sat:
                # Read the model while the constraint is still asserted;
                # depending on the Z3 version, pop() may reset the model.
                model = smt_solver.model()
        finally:
            smt_solver.pop()

        if model is None:
            return term
        cex_point = model_to_point(model,
                                   self.var_smt_expr_list,
                                   self.var_info_list)
        return [cex_point]
예제 #3
0
파일: benchmarks.py 프로젝트: jiry17/IntSy
def rewrite_solution(synth_funs, solution, reverse_mapping):
    """Map a synthesized solution back onto the original problem.

    Two rewrites are applied:

    1. If *reverse_mapping* is given, each entry is unpacked as
       ``(function_info, cond, orig_expr_template, expr_template)``.
       Every application of ``function_info`` (a predicate introduced during
       grammar decomposition) is located. Its enclosing 'ite' is rebuilt with
       the predicate's first argument as the condition, matched against
       ``expr_template``, and replaced by the matching instance of
       ``orig_expr_template``.
    2. Formal parameters in each function body are replaced by the synth
       function's named variables, indexed by ``parameter_position``.

    With more than one synth function, *solution* is expected to be a comma
    combination whose children are the per-function bodies.

    Returns:
        A list with one rewritten expression per synth function.
    """
    if reverse_mapping is not None:
        for function_info, _cond, orig_expr_template, expr_template in reverse_mapping:
            fun_name = function_info.function_name
            app = exprs.find_application(solution, fun_name)
            while app is not None:
                assert exprs.is_application_of(expr_template, 'ite')
                ite = exprs.parent_of(solution, app)
                # Use the predicate application's first argument as the
                # condition so the ite lines up with expr_template.
                undummied = exprs.FunctionExpression(
                    ite.function_info,
                    (app.children[0], ite.children[1], ite.children[2]))
                bindings = exprs.match(expr_template, undummied)
                replacement = exprs.substitute_all(orig_expr_template,
                                                   bindings.items())
                solution = exprs.substitute(solution, ite, replacement)
                app = exprs.find_application(solution, fun_name)

    bodies = [solution] if len(synth_funs) == 1 else solution.children

    rewritten = []
    for body, synth_fun in zip(bodies, synth_funs):
        formals = exprs.get_all_formal_parameters(body)
        named_vars = synth_fun.get_named_vars()
        renaming = [(p, named_vars[p.parameter_position]) for p in formals]
        rewritten.append(exprs.substitute_all(body, renaming))
    return rewritten
    def term_signature(self, term, points):
        """Evaluate ``self.canon_spec`` with *term* plugged in, once per point.

        *term* is installed as the interpretation of the synth function(s)
        in ``self.eval_ctx``. When there are several functions it must be a
        ',' tuple with one body per function.

        Returns:
            A list with one entry per point: the raw evaluation result, or
            False when evaluation raised ``basetypes.PartialFunctionError`` or
            ``basetypes.UnboundLetVariableError``.
        """
        eval_ctx = self.eval_ctx
        funs = self.synth_funs
        if len(funs) <= 1:
            eval_ctx.set_interpretation(funs[0], term)
        else:
            assert exprs.is_application_of(term, ',')
            for fun, body in zip(funs, term.children):
                eval_ctx.set_interpretation(fun, body)

        signature = []
        for point in points:
            eval_ctx.set_valuation_map(point)
            try:
                outcome = evaluation.evaluate_expression_raw(self.canon_spec,
                                                             eval_ctx)
            except (basetypes.PartialFunctionError,
                    basetypes.UnboundLetVariableError):
                # Partial operators (div, mod, ...) may fail on some points;
                # such points are recorded as False.
                outcome = False
            signature.append(outcome)
        return signature
예제 #5
0
    def _verify_expr(self, term):
        """Check a candidate *term* against ``self.neg_canon_spec`` with SMT.

        The candidate is installed as the interpretation of the synth
        function(s) in ``self.smt_ctx``. ``self.neg_canon_spec`` (presumably
        the negated canonical spec, so sat means the candidate is wrong
        somewhere) is then checked in a temporary solver scope that is always
        popped again, even on error.

        Returns:
            ``[cex_point]`` (a one-element list with the point built from the
            solver's model) when the check is sat; otherwise *term* itself.
        """
        smt_ctx = self.smt_ctx
        smt_solver = self.smt_solver

        if len(self.synth_funs) == 1:
            smt_ctx.set_interpretation(self.synth_funs[0], term)
        else:
            # Several synth functions: term is a ',' tuple holding one body
            # per function, in the order of self.synth_funs.
            assert exprs.is_application_of(term, ',')
            for f, t in zip(self.synth_funs, term.children):
                smt_ctx.set_interpretation(f, t)
        full_constraint = _expr_to_smt(self.neg_canon_spec, smt_ctx)

        model = None
        smt_solver.push()
        try:
            smt_solver.add(full_constraint)
            if smt_solver.check() == z3.sat:
                # Read the model while the constraint is still asserted;
                # depending on the Z3 version, pop() may reset the model.
                model = smt_solver.model()
        finally:
            smt_solver.pop()

        if model is None:
            return term
        cex_point = model_to_point(model,
                                   self.var_smt_expr_list,
                                   self.var_info_list)
        return [cex_point]
예제 #6
0
def get_terms(e):
    """Return the leaf terms of an if-then-else tree, left to right.

    A non-'ite' expression is its own single leaf. For an 'ite' node, the
    leaves of the then-branch come before those of the else-branch. The
    conditions are not included.
    """
    leaves = []
    pending = [e]
    while pending:
        node = pending.pop()
        if exprs.is_application_of(node, 'ite'):
            # Push else first so the then-branch is fully visited before it.
            pending.append(node.children[2])
            pending.append(node.children[1])
        else:
            leaves.append(node)
    return leaves
예제 #7
0
def get_preds(e):
    """Return the conditions of an if-then-else tree in pre-order.

    Each 'ite' node contributes its condition (child 0). This is followed by
    the conditions found in its then-branch, then those in its else-branch.
    A non-'ite' expression contributes nothing.
    """
    preds = []
    stack = [e]
    while stack:
        node = stack.pop()
        if not exprs.is_application_of(node, 'ite'):
            continue
        preds.append(node.children[0])
        # Else is pushed first so the then-subtree is processed before it.
        stack.append(node.children[2])
        stack.append(node.children[1])
    return preds
예제 #8
0
파일: benchmarks.py 프로젝트: jiry17/IntSy
def get_pbe_valuations(constraints, synth_fun):
    """Extract input/output examples from programming-by-example constraints.

    Each constraint must satisfy all of the following:
    - it is an 'eq' or '=' application;
    - it contains no variables (per ``exprs.get_all_variables``);
    - one side is an application of *synth_fun* and another is not.

    Returns:
        A list of ``(call_arguments, expected_output)`` pairs, or None as
        soon as any constraint does not have this shape.
    """
    examples = []
    for constraint in constraints:
        is_equality = (exprs.is_application_of(constraint, 'eq')
                       or exprs.is_application_of(constraint, '='))
        if not is_equality or len(exprs.get_all_variables(constraint)) > 0:
            return None

        call_side = None
        other_side = None
        for side in constraint.children:
            if exprs.is_application_of(side, synth_fun):
                call_side = side
            else:
                other_side = side

        if call_side is None or other_side is None:
            return None
        examples.append((call_side.children, other_side))
    return examples
예제 #9
0
    def _verify_guard_term_list(self, guard_term_list, dt_tuple):
        """Verify a decision tree's leaves, one guard at a time.

        For each ``(pred, term_list)`` pair, *pred* is asserted in a solver
        scope. Each term in *term_list* is then tried in order against
        ``self.outvar_cnstr``:
        - if the check is sat, a counterexample point is recorded;
        - the first term whose check is not sat is selected as that guard's
          leaf.

        Returns:
            If some guard had no passing term: the sorted, de-duplicated
            counterexample points collected across all guards. Otherwise:
            the expression built by ``decision_tree_to_expr`` from the tree
            in *dt_tuple* and the selected leaf terms.
        """
        smt_ctx = self.smt_ctx
        smt_solver = self.smt_solver
        intro_vars = self.smt_intro_vars
        cex_points = []
        selected_leaf_terms = []

        at_least_one_branch_failed = False
        for (pred, term_list) in guard_term_list:
            smt_solver.push()
            try:
                smt_pred = _expr_to_smt(pred, smt_ctx, intro_vars)
                smt_solver.add(smt_pred)
                all_terms_failed = True
                for term in term_list:
                    if len(self.synth_funs) == 1:
                        smt_ctx.set_interpretation(self.synth_funs[0], term)
                    else:
                        # term is a ',' tuple with one body per synth function.
                        assert exprs.is_application_of(term, ',')
                        for f, t in zip(self.synth_funs, term.children):
                            smt_ctx.set_interpretation(f, t)
                    eq_cnstr = _expr_to_smt(self.outvar_cnstr, smt_ctx)

                    model = None
                    smt_solver.push()
                    try:
                        smt_solver.add(eq_cnstr)
                        if smt_solver.check() == z3.sat:
                            # Read the model before pop(); depending on the
                            # Z3 version, pop() may reset it.
                            model = smt_solver.model()
                    finally:
                        smt_solver.pop()

                    if model is not None:
                        cex_points.append(
                            model_to_point(model,
                                           self.var_smt_expr_list,
                                           self.var_info_list))
                    else:
                        all_terms_failed = False
                        selected_leaf_terms.append(term)
                        break

                if all_terms_failed:
                    at_least_one_branch_failed = True
            finally:
                smt_solver.pop()

        if at_least_one_branch_failed:
            return sorted(set(cex_points))

        _, _, pred_list, _, dt = dt_tuple
        return decision_tree_to_expr(dt, pred_list, self.syn_ctx,
                                     selected_leaf_terms)
예제 #10
0
def simplify_modus_ponens(syn_ctx, expr):
    """Simplify every condition of an if-then-else tree.

    Each 'ite' condition is passed through ``apply_modus_ponens`` and both
    branches are processed recursively; the tree is rebuilt with
    ``syn_ctx``. Non-'ite' expressions are returned unchanged.
    """
    if not exprs.is_application_of(expr, 'ite'):
        return expr
    cond, then_expr, else_expr = expr.children
    # Arguments are evaluated left to right: condition, then, else.
    return syn_ctx.make_function_expr(
        'ite',
        apply_modus_ponens(syn_ctx, cond),
        simplify_modus_ponens(syn_ctx, then_expr),
        simplify_modus_ponens(syn_ctx, else_expr))
예제 #11
0
def apply_modus_ponens(syn_ctx, pred):
    """Drop redundant 'or' conjuncts from an 'and' predicate.

    For every conjunct that is neither an 'and' nor an 'or', any 'or'
    conjunct that has it as a disjunct (compared with ``exprs.equals``) is
    removed. A single surviving conjunct is returned as-is; otherwise a new
    'and' is built. Non-'and' predicates are returned unchanged.
    """
    if not exprs.is_application_of(pred, 'and'):
        return pred

    conjuncts = pred.children
    kept = list(conjuncts)
    for atom in conjuncts:
        if is_and_or(atom):
            continue
        subsumed = [
            c for c in kept
            if exprs.is_application_of(c, 'or')
            and any(exprs.equals(d, atom) for d in c.children)
        ]
        kept = [c for c in kept if c not in subsumed]

    if len(kept) == 1:
        return kept[0]
    return syn_ctx.make_function_expr('and', *kept)
예제 #12
0
def rewrite_boolean_combs(syn_ctx, sol):
    """Replace boolean combinations in 'ite' conditions by nested 'ite's.

    Rules, applied recursively on both branches first:
    - ``ite(not c, t, e)`` becomes ``ite(c, e, t)``;
    - a one-operand 'and'/'or' condition is unwrapped;
    - ``ite(or(c1..cn), t, e)`` becomes
      ``ite(cn, t, ... ite(c1, t, e))``;
    - ``ite(and(c1..cn), t, e)`` becomes
      ``ite(cn, ite(... ite(c1, t, e) ..., e), e)``.

    If any operand itself contains an 'and'/'or', the resulting chain is
    rewritten again. Non-'ite' expressions are returned unchanged.
    """
    import functools

    if not exprs.is_application_of(sol, 'ite'):
        return sol

    cond = sol.children[0]
    new_then = rewrite_boolean_combs(syn_ctx, sol.children[1])
    new_else = rewrite_boolean_combs(syn_ctx, sol.children[2])

    if not exprs.is_function_expression(cond):
        return syn_ctx.make_function_expr('ite', cond, new_then, new_else)
    op = cond.function_info.function_name
    if op not in ('and', 'or', 'not'):
        return syn_ctx.make_function_expr('ite', cond, new_then, new_else)

    operands = cond.children
    if op == 'not':
        return syn_ctx.make_function_expr('ite', operands[0], new_else,
                                          new_then)
    if len(operands) == 1:
        return syn_ctx.make_function_expr('ite', operands[0], new_then,
                                          new_else)

    if op == 'or':
        seed = new_else

        def fold_step(acc, guard):
            # Any true disjunct selects the then-branch.
            return syn_ctx.make_function_expr('ite', guard, new_then, acc)
    else:
        seed = new_then

        def fold_step(acc, guard):
            # Any false conjunct selects the else-branch.
            return syn_ctx.make_function_expr('ite', guard, acc, new_else)

    needs_another_pass = any(
        exprs.find_application(c, 'and') is not None
        or exprs.find_application(c, 'or') is not None
        for c in operands)
    chained = functools.reduce(fold_step, operands, seed)
    if needs_another_pass:
        return rewrite_boolean_combs(syn_ctx, chained)
    return chained
예제 #13
0
    def _do_transform(expr, syn_ctx):
        """Recursively eliminate 'let' bindings by substitution.

        Children are transformed first (bottom-up). For a 'let' application,
        the last transformed child is the body. The preceding children are
        substituted for ``function_info.binding_vars`` in it. Other function
        applications are rebuilt with their transformed children; non-function
        expressions are returned unchanged.
        """
        if exprs.is_function_expression(expr):
            flattened = [LetFlattener._do_transform(c, syn_ctx)
                         for c in expr.children]
            if not exprs.is_application_of(expr, 'let'):
                return exprs.FunctionExpression(expr.function_info,
                                                tuple(flattened))
            body = flattened[-1]
            bindings = list(zip(expr.function_info.binding_vars,
                                flattened[:-1]))
            return exprs.substitute_all(body, bindings)
        return expr
예제 #14
0
 def _flatten_and_or(self, expr_object, syn_ctx):
     """Flatten nested applications of the same 'and'/'or' operator.

     A child that applies the same operator as its parent has its
     (recursively flattened) children spliced into the parent. Other
     children are flattened recursively and kept as single arguments.
     Expressions that are not function applications, or are not 'and'/'or',
     are returned unchanged.
     """
     if expr_object.expr_kind != exprs.ExpressionKinds.function_expression:
         return expr_object
     if not self._matches_expression_any(expr_object, 'and', 'or'):
         return expr_object

     op_info = expr_object.function_info
     flat_args = []
     for sub in expr_object.children:
         same_op = exprs.is_application_of(sub, op_info)
         flattened = self._flatten_and_or(sub, syn_ctx)
         if same_op:
             flat_args.extend(flattened.children)
         else:
             flat_args.append(flattened)
     return syn_ctx.make_function_expr(op_info, *flat_args)
예제 #15
0
    def verify_term_solve(self, terms):
        """Look for a point on which every candidate term violates the spec.

        For each term, ``self.canon_spec`` is translated to SMT with the term
        as the synth function interpretation. The method then checks whether
        the conjunction of the negated translations is satisfiable.

        On entry the solver's top scope is popped; it presumably holds
        ``self.frozen_smt_cnstr`` pushed earlier (TODO confirm against
        callers). Before returning, a new scope with
        ``self.frozen_smt_cnstr`` is pushed again, even if an exception is
        raised.

        Args:
            terms: non-empty sequence of candidate terms. With several synth
                functions, each term is a ',' tuple of per-function bodies.

        Returns:
            ``[cex_point]`` if the conjunction is sat, otherwise None.
        """
        smt_ctx = self.smt_ctx
        smt_solver = self.smt_solver
        smt_solver.pop()

        model = None
        try:
            eq_cnstrs = []
            for term in terms:
                if len(self.synth_funs) == 1:
                    smt_ctx.set_interpretation(self.synth_funs[0], term)
                else:
                    assert exprs.is_application_of(term, ',')
                    for f, t in zip(self.synth_funs, term.children):
                        smt_ctx.set_interpretation(f, t)
                eq_cnstrs.append(_expr_to_smt(self.canon_spec, smt_ctx))
            # z3.And accepts the context as its trailing argument.
            eq_cnstr = z3.And(*[z3.Not(ec) for ec in eq_cnstrs],
                              eq_cnstrs[0].ctx)

            smt_solver.push()
            try:
                smt_solver.add(eq_cnstr)
                if smt_solver.check() == z3.sat:
                    # Capture the model now: the pop() and the re-added frozen
                    # constraint below may reset it, depending on Z3 version.
                    model = smt_solver.model()
            finally:
                smt_solver.pop()
        finally:
            # Restore the frozen-constraint scope popped on entry.
            smt_solver.push()
            smt_solver.add(self.frozen_smt_cnstr)

        if model is None:
            return None
        cex_point = model_to_point(model,
                                   self.var_smt_expr_list,
                                   self.var_info_list)
        return [cex_point]
예제 #16
0
def canonicalize_specification(expr, syn_ctx, theory):
    """Canonicalize a synthesis specification.

    Steps performed:
    1. Checks that *expr* is "well-bound" to the *syn_ctx* object.
    2. Converts the specification to CNF clauses (``to_cnf``).
    3. Gathers the synth functions and the variables used in *expr*.
    4. Checks that the clauses have the single-invocation property.
    5. Rewrites the clauses via ``_intro_new_universal_vars``, introducing
       new variables that give a uniform way of invoking the synth function.
       Only the first synth function in the gathered list is passed.
    6. Assigns ``variable_eval_offset`` to every variable and
       ``synth_function_id`` to every synth function (their list positions).

    Returns a tuple ``(variable_list, synth_function_list, canon_spec,
    canon_clauses, intro_vars)``:
    1. ``variable_list``: 'variable_info' objects of the introduced variables,
       followed by those of the original variables sorted by name.
    2. ``synth_function_list``: the synth functions used in the spec.
    3. ``canon_spec``: the rewritten spec as one expression (the single
       clause, or an 'and' of all rewritten clauses).
    4. ``canon_clauses``: for each rewritten clause, its list of disjuncts
       (the children of an 'or', or the clause itself).
    5. ``intro_vars``: the newly introduced variable expressions.

    Raises:
        basetypes.ArgumentError: if the spec is not single-invocation.
    """
    check_expr_binding_to_context(expr, syn_ctx)
    clauses, _cnf_expr = to_cnf(expr, theory, syn_ctx)

    synth_function_list = list(gather_synth_functions(expr))

    orig_variable_list = sorted(
        (x.variable_info for x in gather_variables(expr)),
        key=lambda info: info.variable_name)

    # check single invocation/separability properties
    if not check_single_invocation_property(clauses, syn_ctx):
        raise basetypes.ArgumentError('Spec:\n%s\nis not single-invocation!' %
                                      exprs.expression_to_string(expr))

    # NOTE(review): raises IndexError if the spec has no synth functions.
    intro_clauses, intro_vars = _intro_new_universal_vars(
        clauses, syn_ctx, synth_function_list[0])

    # The introduced variables are placed at the head of the list.
    # NOTE: the original author noted this ordering is most likely not
    # necessary.
    variable_list = [x.variable_info for x in intro_vars] + orig_variable_list
    for offset, variable_info in enumerate(variable_list):
        variable_info.variable_eval_offset = offset
    for fun_id, synth_fun in enumerate(synth_function_list):
        synth_fun.synth_function_id = fun_id

    if len(intro_clauses) == 1:
        canon_spec = intro_clauses[0]
    else:
        canon_spec = syn_ctx.make_function_expr('and', *intro_clauses)

    canon_clauses = [
        ic.children if exprs.is_application_of(ic, 'or') else [ic]
        for ic in intro_clauses
    ]

    return (variable_list, synth_function_list, canon_spec, canon_clauses,
            intro_vars)
예제 #17
0
def is_and_or(e):
    """Return whether *e* is an application of 'and' or of 'or'."""
    return any(exprs.is_application_of(e, op) for op in ('and', 'or'))