def rewrite_arbitrary_arity_and_or(syn_ctx, sol):
    """Rewrite every 'and' / 'or' application in ``sol`` into binary form.

    Unary applications ``op(x)`` become ``op(x, x)``; applications with more
    than two children are left-folded into nested binary applications,
    e.g. ``and(a, b, c)`` becomes ``and(and(a, b), c)``.  Binary applications
    are left untouched.

    Args:
        syn_ctx: synthesis context used to build the new function expressions.
        sol: the expression to rewrite.

    Returns:
        The rewritten expression, in which every 'and'/'or' application has
        exactly two children (checked by an assertion).
    """
    import functools

    for op in ('and', 'or'):
        apps = exprs.find_all_applications(sol, op)
        # Rewrite the largest (outermost) applications first: an enclosing
        # app is strictly larger than any app nested inside it, so the outer
        # one is replaced while it still matches ``sol`` verbatim, and the
        # inner app survives unchanged as a subtree of the replacement.
        # Rewriting inner apps first would change the outer app's children,
        # making its later substitution a silent no-op.
        for app in sorted(apps, key=exprs.get_expression_size, reverse=True):
            arity = len(app.children)
            if arity == 2:
                continue
            if arity == 1:
                new_app = syn_ctx.make_function_expr(op, app.children[0],
                                                     app.children[0])
            else:
                new_app = functools.reduce(
                    lambda a, b: syn_ctx.make_function_expr(op, a, b),
                    app.children)
            sol = exprs.substitute(sol, app, new_app)

        # Internal invariant: every application of ``op`` is now binary.
        assert all(len(rewritten.children) == 2
                   for rewritten in exprs.find_all_applications(sol, op))
    return sol
def preprocess_operators(term_exprs, pred_exprs):
    """Rewrite unsigned division by a constant power of two into a shift.

    Every sub-expression ``bvudiv(x, c)`` whose divisor ``c`` is a constant
    exact power of two (``c == 2**k``) is replaced by ``bvlshr(x, k)``.
    Divisions by any other constant, including 0, are left unchanged,
    because a right shift is not equivalent to them.

    Args:
        term_exprs: iterable of ``(term_expr, f)`` pairs.
        pred_exprs: iterable of ``(pred_expr, f)`` pairs.

    Returns:
        A tuple ``(new_term_exprs, new_pred_exprs)`` of sets of
        ``(rewritten_expr, f)`` pairs; the ``f`` components are passed
        through unchanged.
    """
    eval_context = evaluation.EvaluationContext()
    bitsize = 64
    bvlshr = semantics_bv.BVLShR(bitsize)

    def _rewrite_udiv(expr):
        # Return ``expr`` with every bvudiv-by-power-of-two replaced by bvlshr.
        subst_pairs = set()
        for e in exprs.get_all_exprs(expr):
            if not exprs.is_function_expression(e):
                continue
            if e.function_info.function_name != 'bvudiv':
                continue
            divisor_expr = e.children[1]
            if not exprs.is_constant_expression(divisor_expr):
                continue
            divisor = int(evaluation.evaluate_expression_raw(
                divisor_expr, eval_context).value)
            # x / d == x >> log2(d) holds only when d is an exact power of
            # two.  Truncating log2 for other d (e.g. 3 -> shift by 1) is
            # unsound, and d == 0 would make log2 raise.
            if divisor <= 0 or divisor & (divisor - 1) != 0:
                continue
            # Exact integer log2; avoids floating-point rounding.
            shift_amount = divisor.bit_length() - 1
            new_right_child = exprs.ConstantExpression(
                exprs.Value(
                    BitVector(shift_amount, bitsize),
                    exprtypes.BitVectorType(bitsize)))
            subst_pairs.add(
                (e, exprs.FunctionExpression(
                    bvlshr, (e.children[0], new_right_child))))

        new_expr = expr
        for old_term, new_term in subst_pairs:
            new_expr = exprs.substitute(new_expr, old_term, new_term)
        return new_expr

    new_term_exprs = {(_rewrite_udiv(term_expr), f)
                      for term_expr, f in term_exprs}
    new_pred_exprs = {(_rewrite_udiv(pred_expr), f)
                      for pred_expr, f in pred_exprs}
    return (new_term_exprs, new_pred_exprs)
def massage_full_lia_solution(syn_ctx, synth_funs, final_solution, massaging):
    """Rewrite each synthesized solution to fit its restricted LIA grammar.

    For each ``(sf, sol)`` pair, if ``sf`` has an entry in ``massaging`` the
    solution's terms and atomic predicates are rewritten, the result is
    simplified, boolean structure is adjusted, and the result is checked
    with ``verify``.  Solutions for functions not present in ``massaging``
    are passed through unchanged.

    Args:
        syn_ctx: synthesis context passed to the rewrite helpers.
        synth_funs: the functions being synthesized.
        final_solution: one solution expression per entry of ``synth_funs``.
        massaging: maps a synth function to a 7-tuple of flags
            ``(boolean_combs, comparators, consts, negatives,
            constant_multiplication, div, mod)``.

    Returns:
        The list of rewritten solutions (same order as ``synth_funs``), or
        ``None`` as soon as any solution cannot be rewritten or verified.
    """
    new_final_solution = []
    for sf, sol in zip(synth_funs, final_solution):
        if sf not in massaging:
            new_final_solution.append(sol)
            continue

        (boolean_combs, comparators, consts, negatives,
         constant_multiplication, div, mod) = massaging[sf]

        # Don't try to rewrite div's and mod's -- it is futile.
        if not div and exprs.find_application(sol, 'div') is not None:
            return None
        if not mod and exprs.find_application(sol, 'mod') is not None:
            return None

        for term in get_terms(sol):
            termp = rewrite_term(syn_ctx, term, negatives, consts,
                                 constant_multiplication)
            if termp is None:
                return None
            sol = exprs.substitute(sol, term, termp)

        for ap in get_atomic_preds(sol):
            new_ap = rewrite_pred(syn_ctx, ap, boolean_combs, comparators,
                                  negatives, consts, constant_multiplication)
            if new_ap is None:
                return None
            sol = exprs.substitute(sol, ap, new_ap)

        sol = simplify(syn_ctx, sol)
        if not boolean_combs:
            # Boolean combinations are not allowed in the target grammar.
            sol = dt_rewrite_boolean_combs(syn_ctx, sol, sf)
        else:
            sol = rewrite_arbitrary_arity_and_or(syn_ctx, sol)

        if not verify(sol, boolean_combs, comparators, consts, negatives,
                      constant_multiplication, div, mod):
            return None
        new_final_solution.append(sol)
    return new_final_solution
def _intro_new_universal_vars(clauses, syn_ctx, uf_info): intro_vars = [ syn_ctx.make_variable_expr(uf_info.domain_types[i], '_intro_var_%d' % i) for i in range(len(uf_info.domain_types)) ] retval = [] for clause in clauses: arg_tuples = _get_synth_function_invocation_args(clause) if len(arg_tuples) == 0: continue arg_tuple = arg_tuples.pop() eq_constraints = [] for i in range(len(arg_tuple)): arg = arg_tuple[i] var = intro_vars[i] eq_constraints.append(syn_ctx.make_function_expr('ne', arg, var)) if (clause.expr_kind == exprs.ExpressionKinds.function_expression and clause.function_info.function_name == 'or'): clause_disjuncts = clause.children else: clause_disjuncts = [clause] for i in range(len(arg_tuple)): clause_disjuncts = [ exprs.substitute(c, arg_tuple[i], intro_vars[i]) for c in clause_disjuncts ] eq_constraints.extend(clause_disjuncts) retval.append(syn_ctx.make_ac_function_expr('or', *eq_constraints)) return (retval, intro_vars)
def rewrite_solution(synth_funs, solution, reverse_mapping):
    """Undo grammar-decomposition rewrites and restore the original parameters.

    First, for every ``(function_info, cond, orig_expr_template,
    expr_template)`` entry of ``reverse_mapping`` (skipped when it is
    ``None``), each application of ``function_info`` is located.  Its parent
    'ite' node is matched against ``expr_template`` and replaced by the
    corresponding instance of ``orig_expr_template``.

    Then the solution is split per synth function (it is a multi-child
    combination when there is more than one).  Each formal parameter is
    replaced by the synth function's named variable at the same position.

    Returns:
        A list with one rewritten solution per synth function.
    """
    mappings = () if reverse_mapping is None else reverse_mapping
    for function_info, _cond, orig_expr_template, expr_template in mappings:
        target_name = function_info.function_name
        app = exprs.find_application(solution, target_name)
        while app is not None:
            assert exprs.is_application_of(expr_template, 'ite')
            parent_ite = exprs.parent_of(solution, app)
            # Put the dummy application's first child back as the ite
            # condition so the node has the shape of ``expr_template``.
            stripped_ite = exprs.FunctionExpression(
                parent_ite.function_info,
                (app.children[0], parent_ite.children[1],
                 parent_ite.children[2]))
            bindings = exprs.match(expr_template, stripped_ite)
            replacement = exprs.substitute_all(orig_expr_template,
                                               bindings.items())
            solution = exprs.substitute(solution, parent_ite, replacement)
            app = exprs.find_application(solution, target_name)

    per_function = [solution] if len(synth_funs) == 1 else solution.children

    rewritten = []
    for sol, synth_fun in zip(per_function, synth_funs):
        formals = exprs.get_all_formal_parameters(sol)
        named_vars = synth_fun.get_named_vars()
        pairs = [(formal, named_vars[formal.parameter_position])
                 for formal in formals]
        rewritten.append(exprs.substitute_all(sol, pairs))
    return rewritten
def apply(constraints, syn_ctx):
    """Eliminate 'ite' applications from constraints by case splitting.

    Each constraint ``C`` containing an application ``ite(c, t, f)`` is
    replaced by the two constraints ``C[ite := t] or not c`` and
    ``C[ite := f] or c``, in that order.  Constraints without an 'ite' are
    kept in place.  The splitting is repeated until no constraint contains
    an 'ite'.

    Returns:
        A new list of 'ite'-free constraints.
    """
    pending = constraints
    while True:
        expanded = []
        split_any = False
        for constraint in pending:
            ite = exprs.find_application(constraint, 'ite')
            if ite is None:
                expanded.append(constraint)
                continue
            split_any = True
            cond, then_branch, else_branch = ite.children
            expanded.append(syn_ctx.make_function_expr(
                'or',
                exprs.substitute(constraint, ite, then_branch),
                syn_ctx.make_function_expr('not', cond)))
            expanded.append(syn_ctx.make_function_expr(
                'or',
                exprs.substitute(constraint, ite, else_branch),
                cond))
        if not split_any:
            return expanded
        pending = expanded
def instantiate_all(self, expr):
    """Inline every interpreted function application in ``expr``.

    Each application of a function named in
    ``self.function_interpretations`` is replaced by that interpretation's
    expression, with its formal parameters substituted by the application's
    arguments.  Passes are repeated until a full pass performs no
    replacement, so applications exposed by inlining are handled as well.
    """
    changed = True
    while changed:
        changed = False
        for fname, fint in self.function_interpretations.items():
            app = exprs.find_application(expr, fname)
            while app is not None:
                changed = True
                bindings = list(zip(fint.formal_parameters, app.children))
                inlined = exprs.substitute_all(
                    fint.interpretation_expression, bindings)
                expr = exprs.substitute(expr, app, inlined)
                app = exprs.find_application(expr, fname)
    return expr
def apply(constraints, uf_instantiator, syn_ctx):
    """Replace uninterpreted-function applications by fresh variables.

    For every pair of distinct applications ``f(a...)`` and ``f(b...)`` of
    the same uninterpreted function, the condition
    ``f(a...) != f(b...) and a_1 = b_1 and ...`` is built.  If any such
    condition exists, every constraint ``C`` becomes ``cond_1 or ... or C``.
    Each application is then replaced, everywhere, by a fresh variable of
    its range type named ``ufcall_<fname>_<random int>``.  (This resembles
    Ackermann's reduction.)

    Returns:
        The rewritten list of constraints (the input list itself when no
        applications were found).
    """
    import random

    conds = []
    all_apps = set()
    for uf_name, _uf_info in uf_instantiator.get_functions().items():
        uf_apps = set()
        for c in constraints:
            uf_apps |= set(exprs.find_all_applications(c, uf_name))
        all_apps |= uf_apps
        # Build one condition per unordered pair of applications.
        while uf_apps:
            uf_app1 = uf_apps.pop()
            for uf_app2 in uf_apps:
                args_eq_expr = [
                    syn_ctx.make_function_expr('=', a1, a2)
                    for (a1, a2) in zip(uf_app1.children, uf_app2.children)
                ]
                output_neq_expr = syn_ctx.make_function_expr(
                    'ne', uf_app1, uf_app2)
                conds.append(syn_ctx.make_function_expr(
                    'and', output_neq_expr, *args_eq_expr))

    if conds:
        constraints = [
            syn_ctx.make_function_expr('or', *conds, constraint)
            for constraint in constraints
        ]

    if not all_apps:
        return constraints

    # Substitute the largest applications first.  For nested calls such as
    # f(g(x)), replacing g(x) first would turn f(g(x)) into f(v) so the
    # later substitution of f(g(x)) would match nothing and leave a UF call
    # behind.  Outermost-first keeps every recorded application matchable.
    used_names = set()
    for app in sorted(all_apps, key=exprs.get_expression_size, reverse=True):
        fname = app.function_info.function_name
        # Distinct applications must never share a variable name.
        while True:
            var_name = ('ufcall_' + fname + '_' +
                        str(random.randint(1, 1000000)))
            if var_name not in used_names:
                break
        used_names.add(var_name)
        var = syn_ctx.make_variable_expr(app.function_info.range_type,
                                         var_name)
        constraints = [
            exprs.substitute(constraint, app, var)
            for constraint in constraints
        ]
    return constraints
def substitute_expr(self, old, new):
    """Replace ``old`` with ``new`` inside ``self.expr`` (updates it in place)."""
    updated = exprs.substitute(self.expr, old, new)
    self.expr = updated
def _compute_pre_condition(self, coverable_sig, uncovered_sig, term):
    """Build a precondition, over the intro variables, under which ``term`` is correct.

    The relevant points are those whose index is in both ``uncovered_sig``
    and ``coverable_sig``.  Each clause in ``self.clauses`` is rewritten by
    substituting the (parameter-substituted) ``term`` for the synth-function
    application(s).  The result is filtered to its intro-var-only disjuncts
    via ``_filter_to_intro_vars``.

    Returns:
        The expression built by ``_clauses_to_expr``.  It is built from the
        single disjuncts that hold on all relevant points when every clause
        has one, and otherwise from all intro-var-only disjuncts.

    Raises:
        NotImplementedError: if some clause has no intro-var-only disjunct,
            if some relevant point is covered by no disjunct of a clause, or
            if ``_clauses_to_expr`` returns ``None``.
    """
    relevent_points = [
        p for (i, p) in enumerate(self.points)
        if (i in uncovered_sig) and (i in coverable_sig)
    ]
    eval_ctx = self.eval_ctx

    # Change term to use introvars
    # NOTE(review): term_sub is reassigned from the original ``term`` on each
    # iteration, so only the last synth function's parameter substitution
    # survives; it is also unbound if self.synth_funs is empty.  Confirm
    # whether substitutions were meant to be chained.
    for sf in self.synth_funs:
        act_params = self.spec.canon_application[sf].children
        form_params = self.spec.formal_params[sf]
        term_sub = exprs.substitute_all(term,
                                        list(zip(form_params, act_params)))

    def eval_on_relevent_points(pred):
        # Evaluate pred at every relevant point; note this overwrites the
        # shared eval_ctx valuation map as a side effect.
        ret = []
        for p in relevent_points:
            eval_ctx.set_valuation_map(p)
            ret.append(evaluation.evaluate_expression_raw(pred, eval_ctx))
        return ret

    # Rewrite clauses with current term instead of synth_fun application
    curr_clauses = []
    for clause in self.clauses:
        curr_clause = []
        for disjunct in clause:
            curr_disjunct = disjunct
            if len(self.synth_funs) == 1:
                curr_disjunct = exprs.substitute(
                    disjunct, self.spec.canon_application[self.synth_funs[0]],
                    term_sub)
            else:
                # Multiple synth functions: term_sub's children are paired
                # positionally with each function's canonical application.
                sub_pairs = list(
                    zip([
                        self.spec.canon_application[sf]
                        for sf in self.synth_funs
                    ], term_sub.children))
                curr_disjunct = exprs.substitute_all(disjunct, sub_pairs)
            curr_clause.append(curr_disjunct)
        curr_clauses.append(curr_clause)

    only_intro_var_clauses = _filter_to_intro_vars(curr_clauses,
                                                   self.intro_vars)
    if not all([len(oivc) > 0 for oivc in only_intro_var_clauses]):
        raise NotImplementedError
    else:
        # Do the only_intro_var_clauses cover all relevent points?
        some_point_uncovered = False
        good_clauses = [
        ]  # If there are single disjuncts that cover everything
        for oivc in only_intro_var_clauses:
            sig = set()  # indices of relevant points covered by this clause
            good_clause = []  # disjuncts true on every relevant point
            for d in oivc:
                s = eval_on_relevent_points(d)
                if all(s):
                    good_clause.append(d)
                for i, t in enumerate(s):
                    if t:
                        sig.add(i)
            good_clauses.append(good_clause)
            if len(sig) != len(relevent_points):
                some_point_uncovered = True
                break
        if some_point_uncovered:
            raise NotImplementedError
        elif all([len(gc) > 0 for gc in good_clauses]):
            pre_cond = _clauses_to_expr(self.syn_ctx, good_clauses)
        else:
            pre_cond = _clauses_to_expr(self.syn_ctx, only_intro_var_clauses)
        if pre_cond is not None:
            return pre_cond
        else:
            raise NotImplementedError