def __init__(self, term_signature, spec):
        """Set up LIA clause data for ``spec``.

        Each synth fun gets a fresh integer "outvar" standing for its
        canonical application.  The canonical clauses and the canonical
        specification are rewritten with those outvars substituted for the
        applications.
        """
        super().__init__()
        self.term_signature = term_signature
        self.synth_funs = spec.synth_funs
        self.spec = spec
        self.syn_ctx = self.spec.syn_ctx
        self.point_var_exprs = [exprs.VariableExpression(pv) for pv in spec.point_vars]

        self.smt_ctx = z3smt.Z3SMTContext()
        self.eval_ctx = evaluation.EvaluationContext()
        self.canon_apps = [self.spec.canon_application[fun] for fun in self.synth_funs]

        # Outvars are numbered right after the point variables.
        first_outvar_index = len(self.point_var_exprs)
        self.outvars = [
            exprs.VariableExpression(exprs.VariableInfo(
                exprtypes.IntType(), 'outvar_' + fun.function_name,
                first_outvar_index + offset))
            for offset, fun in enumerate(self.synth_funs)
        ]
        self.all_vars = self.point_var_exprs + self.outvars
        self.all_vars_z3 = [_expr_to_smt(var, self.smt_ctx) for var in self.all_vars]

        # canonical application -> outvar, shared by both rewrites below
        app_to_outvar = list(zip(self.canon_apps, self.outvars))
        self.lia_clauses = [
            [LIAInequality.from_expr(exprs.substitute_all(disjunct, app_to_outvar))
             for disjunct in clause]
            for clause in spec.get_canon_clauses()
        ]
        self.rewritten_spec = exprs.substitute_all(
                self.spec.get_canonical_specification(), app_to_outvar)
Exemple #2
0
def rewrite_solution(synth_funs, solution, reverse_mapping):
    """Undo grammar-decomposition rewrites and restore named parameters.

    Args:
        synth_funs: the synthesized functions, in solution order.
        solution: the solution expression; with several synth funs its
            children are the per-function solutions.
        reverse_mapping: ``None`` or an iterable of
            ``(function_info, cond, orig_expr_template, expr_template)``
            tuples describing predicates introduced during decomposition.

    Returns:
        A list with one expression per synth fun, in which formal parameters
        are replaced by that function's named variables.
    """
    if reverse_mapping is not None:
        for function_info, _cond, orig_expr_template, expr_template in reverse_mapping:
            app = exprs.find_application(solution, function_info.function_name)
            while app is not None:
                assert exprs.is_application_of(expr_template, 'ite')
                # Drop the dummy predicate: rebuild the enclosing ite with the
                # application's first child as its condition.
                ite = exprs.parent_of(solution, app)
                stripped_ite = exprs.FunctionExpression(
                    ite.function_info,
                    (app.children[0], ite.children[1], ite.children[2]))
                bindings = exprs.match(expr_template, stripped_ite)
                restored = exprs.substitute_all(orig_expr_template, bindings.items())
                solution = exprs.substitute(solution, ite, restored)
                app = exprs.find_application(solution, function_info.function_name)

    # With several synth funs the solution combines one child per function.
    per_function = [solution] if len(synth_funs) == 1 else solution.children

    rewritten = []
    for fun_sol, synth_fun in zip(per_function, synth_funs):
        formal_params = exprs.get_all_formal_parameters(fun_sol)
        named_vars = synth_fun.get_named_vars()
        pairs = [(fp, named_vars[fp.parameter_position]) for fp in formal_params]
        rewritten.append(exprs.substitute_all(fun_sol, pairs))
    return rewritten
def solve_inequalities(model, outvars, inequalities, syn_ctx):
    """Solve a system of LIA (in)equalities for several output variables.

    Args:
        model: passed through to ``solve_inequalities_one_outvar``
            (presumably a mapping from variable expressions to values --
            TODO confirm).
        outvars: the output variables to solve for (must be non-empty).
        inequalities: LIA (in)equalities over ``outvars`` and other vars.
        syn_ctx: synthesis context used to build expressions.

    Returns:
        A list of expressions, one per entry of ``outvars`` in the same
        order.
    """
    if len(outvars) == 1:
        return solve_inequalities_one_outvar(model, outvars[0], inequalities,
                                             syn_ctx)

    # Prefer eliminating an outvar through an equality in which it has
    # coefficient 1: it is then expressed exactly in terms of the others.
    for ineq in inequalities:
        if not ineq.is_equality():
            continue
        for outvar in outvars:
            (coeff, (_, eq_lia_expr, _)) = ineq.get_bounds(outvar)
            if coeff == 1:
                eq_expr = eq_lia_expr.to_expr(syn_ctx)
                # The equality used for elimination becomes trivial once
                # outvar is substituted, so it is left out.
                rest_ineqs = [
                    e.substitute(outvar, eq_lia_expr) for e in inequalities
                    if e != ineq
                ]
                return _back_substitute_outvars(model, outvars, outvar,
                                                eq_expr, rest_ineqs, syn_ctx)

    # Otherwise, solve for the first outvar alone.  Every inequality may
    # still constrain the remaining outvars, so none is excluded here.
    outvar = outvars[0]
    [t] = solve_inequalities_one_outvar(model, outvar, inequalities, syn_ctx)
    lia_t = LIAExpression.from_expr(t)
    rest_ineqs = [e.substitute(outvar, lia_t) for e in inequalities]
    return _back_substitute_outvars(model, outvars, outvar, t, rest_ineqs,
                                    syn_ctx)


def _back_substitute_outvars(model, outvars, outvar, outvar_expr, rest_ineqs,
                             syn_ctx):
    """Solve ``rest_ineqs`` for all outvars but ``outvar``, then combine.

    ``outvar_expr`` (the solution chosen for ``outvar``) may mention the
    other outvars.  Their solutions are substituted into it repeatedly
    until it no longer changes.

    Returns:
        A list with one expression per entry of ``outvars``, in order.
    """
    rest_outvars = [o for o in outvars if o != outvar]
    rest_sol = solve_inequalities(model, rest_outvars, rest_ineqs, syn_ctx)
    sols = list(zip(rest_outvars, rest_sol))
    while True:
        substituted = exprs.substitute_all(outvar_expr, sols)
        if substituted == outvar_expr:
            break
        outvar_expr = substituted
    sols_dict = dict(sols)
    sols_dict[outvar] = outvar_expr
    return [sols_dict[o] for o in outvars]
Exemple #4
0
 def _instantiate(self, sub_exprs):
     """Return the template with its placeholders replaced by ``sub_exprs``.

     Placeholders and sub-expressions are paired positionally.
     """
     bindings = list(zip(self.place_holder_vars, sub_exprs))
     return exprs.substitute_all(self.expr_template, bindings)
    def solve(self):
        """Compute terms for the current points and record them by signature.

        With no points, returns ``self._trivial_solve()``.  Otherwise the
        points are grouped by their intro-variable prefix.  For each group,
        ``_single_solve`` yields one term per synth fun.  These terms are
        combined (as a comma expression when there are several synth funs)
        and stored in ``self.signature_to_term``, keyed by the set of point
        indices on which ``self.term_signature`` holds.  Prints debugging
        output throughout and returns True in the non-trivial case.
        """
        if len(self.points) == 0:
            print("Trivial solve!")
            return self._trivial_solve()

        print("-----------------")
        print("Nontrivial solve!")
        for point in self.points:
            print("POINT:", [ p.value_object for p in point])
        for sig, term in self.signature_to_term.items():
            print('SIGTOTERM:', str(sig), _expr_to_str(term))

        # Pair each point with its prefix of intro-variable values.
        intro_var_signature = []
        for point in self.points:
            ivs = point[:len(self.spec.intro_vars)]
            intro_var_signature.append((ivs, point))

        # Group points by intro-var prefix.  Only *consecutive* points with
        # equal prefixes share a group (what itertools.groupby does without
        # sorting).  Equal prefixes that are not adjacent form separate groups.
        # NOTE(review): the commented-out version below sorted first; confirm
        # whether non-adjacent duplicates can occur here.
        # ivs_groups = itertools.groupby(
        #         sorted(intro_var_signature, key=lambda a: a[0]),
        #         key=lambda a: a[0])
        curr_ivs = None
        ivs_groups = []
        for ivs, point in intro_var_signature:
            if ivs == curr_ivs:
                ivs_groups[-1][1].append(point)
            else:
                ivs_groups.append((ivs, [point]))
                curr_ivs = ivs

        for ivs, points in ivs_groups:
            print("A:")
            terms = self._single_solve(ivs, points)
            print("C:", [ exprs.expression_to_string(t) for t in terms ])
            # Re-express each term over its synth fun's formal parameters
            # instead of the intro vars.
            new_terms = []
            for term, sf in zip(terms, self.synth_funs):
                new_terms.append(exprs.substitute_all(term, 
                    list(zip(self.spec.intro_vars, self.spec.formal_params[sf]))))
            terms = new_terms
            print([ _expr_to_str(t) for t in terms ])

            sig = self.signature_factory()
            if len(self.synth_funs) > 1:
                domain_types = tuple([exprtypes.IntType()] * len(self.synth_funs))
                single_term = exprs.FunctionExpression(semantics_core.CommaFunction(domain_types),
                        tuple(terms))
            else:
                single_term = terms[0]

            # The signature is computed over *all* points, not just this group.
            for i, t in enumerate(self.term_signature(single_term, self.points)):
                if t:
                    sig.add(i)
            self.signature_to_term[sig] = single_term
        print("-----------------")

        return True
 def __init__(self, function_name, function_arity, domain_types, range_type, interpretation_expression, arg_vars):
     """Create a macro whose body is ``interpretation_expression``.

     The body is rewritten to reference this macro's positional formal
     parameters (one per entry of ``arg_vars``) instead of the named
     argument variables.
     """
     super().__init__(FunctionKinds.macro_function, function_name, function_arity, domain_types, range_type)
     self.formal_parameters = [
         exprs.FormalParameterExpression(self, var.variable_info.variable_type, position)
         for position, var in enumerate(arg_vars)
     ]
     named_to_formal = list(zip(arg_vars, self.formal_parameters))
     self.interpretation_expression = exprs.substitute_all(
         interpretation_expression, named_to_formal)
Exemple #7
0
            def gen(gens):
                # Yield the template instantiated with every combination that
                # product_fun draws from the child generators.  Each result's
                # score combines the preference for its rendered rule
                # (self.pref) with its children's scores.
                for combo in product_fun(*gens):
                    bindings = list(zip(self.place_holder_vars, combo))
                    candidate = exprs.substitute_all(self.expr_template, bindings)
                    rule_word = rule_to_word(exprs.expression_to_string(candidate))
                    rule_score = self.pref[rule_word]

                    child_score = score_children_combiner(
                        child.score for child in combo)
                    total = score_expr_combiner(rule_score, child_score)
                    yield candidate._replace(score=total)
Exemple #8
0
    def _do_transform(expr, syn_ctx):
        """Return ``expr`` with every ``let`` inlined, bottom-up.

        A ``let`` application is replaced by its last child (the body), with
        the binding variables substituted by the preceding (already
        flattened) children.  Non-function expressions are returned as is.
        """
        if not exprs.is_function_expression(expr):
            return expr

        flat_children = [LetFlattener._do_transform(child, syn_ctx)
                         for child in expr.children]
        if not exprs.is_application_of(expr, 'let'):
            return exprs.FunctionExpression(expr.function_info,
                                            tuple(flat_children))

        body = flat_children[-1]
        bindings = list(zip(expr.function_info.binding_vars, flat_children[:-1]))
        return exprs.substitute_all(body, bindings)
 def instantiate_all(self, expr):
     """Expand every application of an interpreted function in ``expr``.

     Sweeps over ``self.function_interpretations`` repeatedly, because an
     expansion can introduce new applications.  Stops after a full sweep
     that expands nothing.
     """
     while True:
         changed = False
         for fname, fint in self.function_interpretations.items():
             app = exprs.find_application(expr, fname)
             while app is not None:
                 changed = True
                 bindings = list(zip(fint.formal_parameters, app.children))
                 expansion = exprs.substitute_all(
                     fint.interpretation_expression, bindings)
                 expr = exprs.substitute(expr, app, expansion)
                 app = exprs.find_application(expr, fname)
         if not changed:
             return expr
Exemple #10
0
    def __init__(self, syn_ctx, spec):
        """Prepare an SMT solver holding the negated canonical specification.

        Each synth-fun application (over the intro vars) is replaced in the
        canonical spec by a fresh ``__output__<name>`` variable.  The negation
        of the result is asserted inside a freshly pushed solver scope.
        """
        self.syn_ctx = syn_ctx
        self.spec = spec
        self.synth_funs = syn_ctx.get_synth_funs()

        smt_ctx = z3smt.Z3SMTContext()
        self.smt_ctx = smt_ctx
        self.smt_solver = smt_ctx.make_solver()

        # Counterexample points list their values in exactly this order.
        self.var_info_list = spec.get_point_variables()
        self.var_smt_expr_list = [
            _expr_to_smt(exprs.VariableExpression(info), smt_ctx)
            for info in self.var_info_list
        ]

        self.intro_vars = spec.get_intro_vars()
        self.smt_intro_vars = [_expr_to_smt(iv, smt_ctx) for iv in self.intro_vars]

        make = syn_ctx.make_function_expr
        applications = [make(sf, *self.intro_vars) for sf in self.synth_funs]
        output_vars = [
            syn_ctx.make_variable_expr(sf.range_type, '__output__' + sf.function_name)
            for sf in self.synth_funs
        ]
        equalities = [make('eq', out, app) for (out, app) in zip(output_vars, applications)]
        self.outvar_cnstr = make('and', *equalities)

        self.canon_spec = spec.get_canonical_specification()
        spec_on_outputs = exprs.substitute_all(
            self.canon_spec, list(zip(applications, output_vars)))
        negated_spec = make('not', spec_on_outputs)
        self.frozen_smt_cnstr = _expr_to_smt(negated_spec, smt_ctx)

        solver = self.smt_solver
        solver.push()
        solver.add(self.frozen_smt_cnstr)
Exemple #11
0
def print_stat(benchmark_files, phog_file):
    """Print the PHOG score of the solution stored in each benchmark file.

    For each file, the last ``define-fun`` is taken as the solution for the
    first synth fun.  Its argument variables are replaced by positional
    formal parameters, and the expression is scored with an ``SPhog`` model
    (when ``options.use_sphog()``) or a ``Phog`` model.  Either model is
    built from that synth fun's grammar and ``phog_file``.  The file's
    basename and score are printed.

    If a file has no ``define-fun``, prints ``cannot find a solution!`` and
    terminates the process with exit status 0; later files are not
    processed.
    """
    import sys
    from os.path import basename

    for benchmark_file in benchmark_files:
        file_sexp = parser.sexpFromFile(benchmark_file)
        benchmark_tuple = parser.extract_benchmark(file_sexp)
        (theories, syn_ctx, synth_instantiator, macro_instantiator,
         uf_instantiator, constraints, grammar_map, forall_vars_map,
         default_grammar_sfs) = benchmark_tuple
        assert len(theories) == 1
        specification = get_specification(file_sexp)
        synth_funs = list(synth_instantiator.get_functions().values())
        grammar = grammar_map[synth_funs[0]]

        if options.use_sphog():
            phog = SPhog(grammar, phog_file, synth_funs[0].range_type,
                         specification)
        else:
            phog = Phog(grammar, phog_file, synth_funs[0].range_type)

        defs, _ = parser.filter_sexp_for('define-fun', file_sexp)
        if not defs:
            print('cannot find a solution!')
            # The bare ``exit`` builtin is added by ``site`` for interactive
            # use only; ``sys.exit`` is the programmatic equivalent.
            sys.exit(0)
        [name, args_data, ret_type_data, interpretation] = defs[-1]

        ((arg_vars, arg_types, arg_var_map),
         return_type) = parser._process_function_defintion(
             args_data, ret_type_data)
        expr = parser.sexp_to_expr(interpretation, syn_ctx, arg_var_map)
        # Named argument variables -> positional formal parameters.
        subs_pairs = [
            (var_expr, exprs.FormalParameterExpression(None, ty, i))
            for i, (var_expr, ty) in enumerate(zip(arg_vars, arg_types))
        ]
        expr = exprs.substitute_all(expr, subs_pairs)

        score = phog.get_score(expr)
        print(basename(benchmark_file), ' \t', score)
Exemple #12
0
def get_func_exprs_grammars(benchmark_files):
    # Load every benchmark file and collect, per file, the expressions of its
    # define-fun definitions plus the first non-default synth-fun grammar.
    # Returns a list of (fun_exprs, grammar) tuples, one per file that
    # parsed (files where sexpFromFile returns None are skipped).
    # grammar is None when every synth fun uses the default grammar.
    # When the global `eu` is truthy, each define-fun contributes two entries
    # to fun_exprs: one with ite-containing macros inlined, one without.
    global eu

    # Grammars
    results = []
    # for eusolver
    # NOTE(review): this list lives outside the per-file loop, so macro names
    # recorded for one benchmark are also looked up in later benchmarks
    # (and may be appended repeatedly) -- confirm this is intended.
    ite_related_macros = []
    for benchmark_file in benchmark_files:
        fun_exprs = []
        print('Loading : ', benchmark_file)

        file_sexp = parser.sexpFromFile(benchmark_file)
        if file_sexp is None:
            continue

        # Fresh instantiators and synthesis context for each file.
        core_instantiator = semantics_core.CoreInstantiator()
        theory_instantiators = [
            parser.get_theory_instantiator(theory)
            for theory in parser._known_theories
        ]

        macro_instantiator = semantics_core.MacroInstantiator()
        uf_instantiator = semantics_core.UninterpretedFunctionInstantiator()
        synth_instantiator = semantics_core.SynthFunctionInstantiator()

        syn_ctx = synthesis_context.SynthesisContext(core_instantiator,
                                                     *theory_instantiators,
                                                     macro_instantiator,
                                                     uf_instantiator,
                                                     synth_instantiator)
        syn_ctx.set_macro_instantiator(macro_instantiator)

        defs, _ = parser.filter_sexp_for('define-fun', file_sexp)
        if defs is None: defs = []

        for [name, args_data, ret_type_data, interpretation] in defs:
            for eusolver in ([True, False] if eu else [False]):
                ((arg_vars, arg_types, arg_var_map),
                 return_type) = parser._process_function_defintion(
                     args_data, ret_type_data)
                expr = parser.sexp_to_expr(interpretation, syn_ctx,
                                           arg_var_map)
                macro_func = semantics_types.MacroFunction(
                    name, len(arg_vars), tuple(arg_types), return_type, expr,
                    arg_vars)
                # for eusolver  (recording macro functions of which definition include ite)
                if eusolver:
                    app = exprs.find_application(expr, 'ite')
                    if app is not None: ite_related_macros.append(name)

                # Registered on every pass, so later define-funs can use it.
                macro_instantiator.add_function(name, macro_func)
                # Replace named argument variables by positional formal
                # parameters (with no owning function).
                i = 0
                subs_pairs = []
                for (var_expr, ty) in zip(arg_vars, arg_types):
                    param_expr = exprs.FormalParameterExpression(None, ty, i)
                    subs_pairs.append((var_expr, param_expr))
                    i += 1
                expr = exprs.substitute_all(expr, subs_pairs)
                # resolve macro functions involving ite (for enumeration of pred exprs (eusolver))
                if eusolver:
                    for fname in ite_related_macros:
                        app = exprs.find_application(expr, fname)
                        if app is None: continue
                        expr = macro_instantiator.instantiate_macro(
                            expr, fname)
                fun_exprs.append(expr)

        # Renames each synth-fun/inv to "__aux_name__<file><counter>".
        # Redefined per file; presumably static_var attaches a fresh `cnt`
        # attribute starting at 0 each time (TODO confirm).  The file name in
        # the prefix keeps names distinct across files.
        @static_var("cnt", 0)
        def rename(synth_funs_data):
            for synth_fun_data in synth_funs_data:
                # to avoid duplicated names
                synth_fun_data[0] = "__aux_name__" + benchmark_file + str(
                    rename.cnt)
                rename.cnt += 1

        # collect grammars
        # Fall back to synth-inv declarations when there is no synth-fun.
        synth_funs_data, _ = parser.filter_sexp_for('synth-fun', file_sexp)
        if len(synth_funs_data) == 0:
            synth_funs_data, _ = parser.filter_sexp_for('synth-inv', file_sexp)
            rename(synth_funs_data)
            synth_funs_grammar_data = parser.process_synth_invs(
                synth_funs_data, synth_instantiator, syn_ctx)
        else:
            rename(synth_funs_data)
            synth_funs_grammar_data = parser.process_synth_funcs(
                synth_funs_data, synth_instantiator, syn_ctx)

        # Keep only the first synth fun that has an explicit grammar.
        grammar = None
        for synth_fun, arg_vars, grammar_data in synth_funs_grammar_data:
            if grammar_data != 'Default grammar':
                # we only consider a single function synthesis for now
                grammar = parser.sexp_to_grammar(arg_vars, grammar_data,
                                                 synth_fun, syn_ctx)
                break

        results.append((fun_exprs, grammar))

    return results
Exemple #13
0
    def unify(self):
        """Generate a unified solution from the term solver's terms.

        Yields a single ``('TERM', expr)`` pair.  If
        ``_try_trivial_unification`` succeeds, its result is yielded as is.
        Otherwise terms are chosen greedily, each time taking the term whose
        signature covers the most still-uncovered points.  Each term gets a
        precondition from ``_compute_pre_condition``, and the
        (precondition, term) list is turned into an expression by
        ``_pred_term_list_to_expr``.  The result is then rewritten from the
        canonical application's arguments to the formal parameters.
        """
        sig_to_term = self.term_solver.get_signature_to_term()
        eval_ctx = self.eval_ctx
        self.last_dt_size = 0

        trivial = self._try_trivial_unification()
        if trivial is not None:
            yield ("TERM", trivial)
            return

        pred_terms = []
        # (full signature, part of it not yet covered) for every known term
        remaining = [(sig, sig) for sig in sig_to_term]
        while True:
            best_full, best_uncovered = max(remaining, key=lambda pair: len(pair[1]))
            if len(best_uncovered) == 0:
                # Every point is covered.
                break

            best_term = sig_to_term[best_full]
            guard = self._compute_pre_condition(best_full, best_uncovered, best_term)
            pred_terms.append((guard, best_term))

            guard_sig = BitSet(len(self.points))
            for idx in best_uncovered:
                eval_ctx.set_valuation_map(self.points[idx])
                if evaluation.evaluate_expression_raw(guard, eval_ctx):
                    guard_sig.add(idx)
            assert not guard_sig.is_empty()

            # Newly covered points no longer count for any term.
            remaining = [(full, part.difference(guard_sig)) for (full, part) in remaining]

        combined = self._pred_term_list_to_expr(pred_terms)
        if len(self.synth_funs) == 1:
            only_sf = self.synth_funs[0]
            act_params = self.spec.canon_application[only_sf].children
            form_params = self.spec.formal_params[only_sf]
            combined = exprs.substitute_all(combined, list(zip(act_params, form_params)))
        else:
            parts = [
                exprs.substitute_all(
                    part,
                    list(zip(self.spec.canon_application[sf].children,
                             self.spec.formal_params[sf])))
                for part, sf in zip(combined, self.synth_funs)
            ]
            domain_types = tuple([exprtypes.IntType()] * len(self.synth_funs))
            combined = exprs.FunctionExpression(
                semantics_core.CommaFunction(domain_types), tuple(parts))
        yield ('TERM', combined)
Exemple #14
0
    def _compute_pre_condition(self, coverable_sig, uncovered_sig, term):
        """Build a guard under which ``term`` may be used as the output.

        Args:
            coverable_sig: indices of points that ``term``'s signature marks.
            uncovered_sig: indices of points not yet covered by a previously
                chosen (predicate, term) pair.
            term: candidate term over formal parameters.  With several synth
                funs it is a comma expression whose children are the
                per-function terms, in ``self.synth_funs`` order.

        Returns:
            The expression ``_clauses_to_expr`` builds from the specification
            clauses with ``term`` plugged in, restricted by
            ``_filter_to_intro_vars`` to the intro vars.

        Raises:
            NotImplementedError: if some restricted clause is empty, if the
                restricted clauses fail to cover every relevant point, or if
                no expression could be built.
        """
        # Points that are both still uncovered and coverable by this term.
        relevant_points = [
            p for (i, p) in enumerate(self.points)
            if (i in uncovered_sig) and (i in coverable_sig)
        ]
        eval_ctx = self.eval_ctx
        synth_funs = self.synth_funs
        canon_application = self.spec.canon_application

        def to_actual_params(expr, sf):
            # Replace sf's formal parameters by the arguments of sf's
            # canonical application (the intro vars).
            form_params = self.spec.formal_params[sf]
            act_params = canon_application[sf].children
            return exprs.substitute_all(expr, list(zip(form_params, act_params)))

        # Each synth fun has its own formal parameters, so each component of
        # a multi-function term is rewritten with its own function's mapping.
        # This mirrors the reverse rewrite done per component in unify().
        if len(synth_funs) == 1:
            term_sub = to_actual_params(term, synth_funs[0])
            sub_pairs = None
        else:
            term_sub = None
            sub_pairs = [
                (canon_application[sf], to_actual_params(part, sf))
                for part, sf in zip(term.children, synth_funs)
            ]

        def eval_on_relevant_points(pred):
            values = []
            for p in relevant_points:
                eval_ctx.set_valuation_map(p)
                values.append(evaluation.evaluate_expression_raw(pred, eval_ctx))
            return values

        # Rewrite clauses with the current term in place of the synth-fun
        # applications.
        curr_clauses = []
        for clause in self.clauses:
            curr_clause = []
            for disjunct in clause:
                if len(synth_funs) == 1:
                    curr_disjunct = exprs.substitute(
                        disjunct, canon_application[synth_funs[0]], term_sub)
                else:
                    curr_disjunct = exprs.substitute_all(disjunct, sub_pairs)
                curr_clause.append(curr_disjunct)
            curr_clauses.append(curr_clause)

        only_intro_var_clauses = _filter_to_intro_vars(curr_clauses,
                                                       self.intro_vars)
        if not all(len(oivc) > 0 for oivc in only_intro_var_clauses):
            raise NotImplementedError

        # Each clause must cover every relevant point with at least one of
        # its disjuncts.  Disjuncts that hold on *all* relevant points are
        # collected as "good"; if every clause has one, only those are used.
        good_clauses = []
        for oivc in only_intro_var_clauses:
            covered = set()
            good_clause = []
            for d in oivc:
                values = eval_on_relevant_points(d)
                if all(values):
                    good_clause.append(d)
                covered.update(i for i, v in enumerate(values) if v)
            good_clauses.append(good_clause)
            if len(covered) != len(relevant_points):
                raise NotImplementedError

        if all(len(gc) > 0 for gc in good_clauses):
            pre_cond = _clauses_to_expr(self.syn_ctx, good_clauses)
        else:
            pre_cond = _clauses_to_expr(self.syn_ctx, only_intro_var_clauses)

        if pre_cond is None:
            raise NotImplementedError
        return pre_cond
Exemple #15
0
def get_func_exprs_grammars(benchmark_files):
    # Collect solution expressions from benchmark files, bucketed by category.
    # expected format:
    #   sygus format problem
    #   (check-synth)
    #   a single solution
    # Returns (exprs_per_category, all_vocabs):
    #   exprs_per_category: dict mapping (return_type, eusolver, spec_flag)
    #     to a set of (expr, fetchop_func) pairs;
    #   all_vocabs: set of vocabulary items from the specs and grammars.
    # Only define-funs whose name contains the synth fun's name are processed
    # (and registered as macros).
    # NOTE(review): synth_fun_name stays '' when no synth fun has an explicit
    # grammar, and '' is a substring of every name -- confirm intended.
    global eu

    # Currently unused: both call sites below are commented out.
    @static_var("cnt", 0)
    def rename(synth_funs_data):
        for synth_fun_data in synth_funs_data:
            # to avoid duplicated names
            synth_fun_data[0] = "__aux_name__" + benchmark_file + str(
                rename.cnt)
            rename.cnt += 1

    exprs_per_category = {}
    # decision tree : label -> exprs
    ## label : (ret_type, eu, STD spec / PBE spec, spec information ... )

    # for eusolver
    # NOTE(review): shared across all benchmark files, not reset per file.
    ite_related_macros = []
    # all vocabs
    all_vocabs = set([])

    for benchmark_file in benchmark_files:
        print('Loading : ', benchmark_file)
        file_sexp = parser.sexpFromFile(benchmark_file)
        if file_sexp is None:
            continue

        ## specification
        specification = get_specification(file_sexp)
        all_vocabs.update(basic_vocabs_for_spec(specification))

        # Fresh instantiators and synthesis context for each file.
        core_instantiator = semantics_core.CoreInstantiator()
        theory_instantiators = [
            parser.get_theory_instantiator(theory)
            for theory in parser._known_theories
        ]
        macro_instantiator = semantics_core.MacroInstantiator()
        uf_instantiator = semantics_core.UninterpretedFunctionInstantiator()
        synth_instantiator = semantics_core.SynthFunctionInstantiator()

        syn_ctx = synthesis_context.SynthesisContext(core_instantiator,
                                                     *theory_instantiators,
                                                     macro_instantiator,
                                                     uf_instantiator,
                                                     synth_instantiator)
        syn_ctx.set_macro_instantiator(macro_instantiator)

        # collect grammars
        # Fall back to synth-inv declarations when there is no synth-fun.
        synth_funs_data, _ = parser.filter_sexp_for('synth-fun', file_sexp)
        if len(synth_funs_data) == 0:
            synth_funs_data, _ = parser.filter_sexp_for('synth-inv', file_sexp)
            # rename(synth_funs_data)
            synth_funs_grammar_data = parser.process_synth_invs(
                synth_funs_data, synth_instantiator, syn_ctx)
        else:
            # rename(synth_funs_data)
            synth_funs_grammar_data = parser.process_synth_funcs(
                synth_funs_data, synth_instantiator, syn_ctx)

        # handling only single function problems for now
        # Defaults used when no synth fun has an explicit grammar.  The loop
        # does not break, so the last explicit-grammar synth fun wins.
        fetchop_func = fetchop
        spec_flag = ()
        synth_fun_name = ''
        for synth_fun, arg_vars, grammar_data in synth_funs_grammar_data:
            if grammar_data != 'Default grammar':
                synth_fun_name = synth_fun.function_name
                grammar = parser.sexp_to_grammar(arg_vars, grammar_data,
                                                 synth_fun, syn_ctx)
                # spec flag
                spec_flag = get_spec_flag(specification, grammar)
                # fetchop func
                fetchop_func = get_fetchop_func(specification, grammar)
                all_vocabs.update(
                    get_vocabs_from_grammar(grammar, fetchop_func))

        defs, _ = parser.filter_sexp_for('define-fun', file_sexp)
        if defs is None: defs = []
        if len(defs) > 0:
            for [name, args_data, ret_type_data, interpretation] in defs:
                print(name, ' ', synth_fun_name)
                # Substring match on the synth fun's name.
                if synth_fun_name in name:
                    for eusolver in ([True] if eu else [False]):
                        ((arg_vars, arg_types, arg_var_map),
                         return_type) = parser._process_function_defintion(
                             args_data, ret_type_data)
                        # category flag
                        flag = (return_type, eusolver, spec_flag)

                        expr = parser.sexp_to_expr(interpretation, syn_ctx,
                                                   arg_var_map)
                        macro_func = semantics_types.MacroFunction(
                            name, len(arg_vars), tuple(arg_types), return_type,
                            expr, arg_vars)
                        # for eusolver  (recording macro functions of which definition include ite)
                        if eusolver:
                            app = exprs.find_application(expr, 'ite')
                            if app is not None: ite_related_macros.append(name)

                        macro_instantiator.add_function(name, macro_func)
                        # Replace named argument variables by positional
                        # formal parameters (with no owning function).
                        i = 0
                        subs_pairs = []
                        for (var_expr, ty) in zip(arg_vars, arg_types):
                            param_expr = exprs.FormalParameterExpression(
                                None, ty, i)
                            subs_pairs.append((var_expr, param_expr))
                            i += 1
                        expr = exprs.substitute_all(expr, subs_pairs)
                        # resolve macro functions involving ite (for enumeration of pred exprs (eusolver))
                        if eusolver:
                            for fname in ite_related_macros:
                                app = exprs.find_application(expr, fname)
                                if app is None: continue
                                expr = macro_instantiator.instantiate_macro(
                                    expr, fname)
                        if flag not in exprs_per_category:
                            exprs_per_category[flag] = set([])
                        exprs_per_category[flag].add((expr, fetchop_func))

    return exprs_per_category, all_vocabs
Exemple #16
0
def dt_rewrite_boolean_combs(syn_ctx, sol, synth_fun):
    """Re-express ``sol`` as a decision tree over its own predicates and terms.

    The variables of ``sol`` are first renamed to dummy copies ("D" prefix).
    Starting from the first term, an SMT sample is drawn wherever the current
    candidate differs from the renamed solution.  Each sampled point extends
    the predicate/term signatures and a decision tree is re-learned.  The
    loop ends when no differing point exists.  The final candidate is mapped
    back to the original variables and returned.  ``synth_fun`` is not used.
    """
    smt_ctx = z3smt.Z3SMTContext()
    original_vars = exprs.get_all_variables(sol)
    dummies = [
        exprs.VariableExpression(
            exprs.VariableInfo(var.variable_info.variable_type,
                               "D" + var.variable_info.variable_name, idx))
        for idx, var in enumerate(original_vars)
    ]
    smt_args = [
        semantics.semantics_types.expression_to_smt(dummy, smt_ctx)
        for dummy in dummies
    ]
    renamed_sol = exprs.substitute_all(sol, list(zip(original_vars, dummies)))
    preds = get_atomic_preds(renamed_sol)
    terms = get_terms(renamed_sol)

    from exprs import evaluation
    eval_ctx = evaluation.EvaluationContext()

    pred_sigs = [BitSet(0) for _ in preds]
    term_sigs = [BitSet(0) for _ in terms]

    candidate = terms[0]
    while True:
        differing = syn_ctx.make_function_expr('ne', candidate, renamed_sol)
        z3point = exprs.sample(differing, smt_ctx, smt_args)
        if z3point is None:
            break

        point = [
            z3smt.z3value_to_value(value, dummy.variable_info)
            for value, dummy in zip(z3point, dummies)
        ]
        eval_ctx.set_valuation_map(point)
        expected = evaluation.evaluate_expression_raw(renamed_sol, eval_ctx)
        # Record each predicate's truth value, and whether each term agrees
        # with the solution, at the new point.
        pred_sigs = [
            utils.bitset_extend(sig, evaluation.evaluate_expression_raw(p, eval_ctx))
            for sig, p in zip(pred_sigs, preds)
        ]
        term_sigs = [
            utils.bitset_extend(
                sig, expected == evaluation.evaluate_expression_raw(t, eval_ctx))
            for sig, t in zip(term_sigs, terms)
        ]
        dt = eusolver.eus_learn_decision_tree_for_ml_data(pred_sigs, term_sigs)
        candidate = verifiers.naive_dt_to_expr(syn_ctx, dt, preds, terms)

    return exprs.substitute_all(candidate, list(zip(dummies, original_vars)))