def preprocess_operators(term_exprs, pred_exprs):
    """Rewrite unsigned division by a constant power of two into a shift.

    In every expression, each ``bvudiv`` application whose divisor is a
    constant equal to ``2**k`` (``k >= 0``) is replaced by the logical right
    shift ``bvlshr(dividend, k)``. Divisions by any other constant
    (including zero) are left untouched, because they cannot be expressed
    as a shift.

    The bit-vector width is hard-coded to 64 bits.

    Args:
        term_exprs: iterable of ``(term_expr, f)`` pairs.
        pred_exprs: iterable of ``(pred_expr, f)`` pairs.

    Returns:
        A tuple ``(new_term_exprs, new_pred_exprs)`` of sets of
        ``(rewritten_expr, f)`` pairs; ``f`` is passed through unchanged.
    """
    eval_context = evaluation.EvaluationContext()
    bitsize = 64
    bvlshr = semantics_bv.BVLShR(bitsize)

    def shift_amount(divisor_expr):
        """Return k if *divisor_expr* evaluates to 2**k, otherwise None."""
        value = evaluation.evaluate_expression_raw(divisor_expr, eval_context)
        divisor = int(value.value)
        # Only exact, non-zero powers of two are equivalent to a right shift.
        # Previously log2 was applied to any constant: x/3 became x>>1, and
        # division by zero crashed in math.log2.
        if divisor <= 0 or divisor & (divisor - 1):
            return None
        return divisor.bit_length() - 1

    def rewrite(expr):
        """Return *expr* with every qualifying bvudiv replaced by bvlshr."""
        subst_pairs = set()
        for e in exprs.get_all_exprs(expr):
            if not exprs.is_function_expression(e):
                continue
            if e.function_info.function_name != 'bvudiv':
                continue
            if not exprs.is_constant_expression(e.children[1]):
                continue
            shift = shift_amount(e.children[1])
            if shift is None:
                continue
            new_right_child = exprs.ConstantExpression(
                exprs.Value(
                    BitVector(shift, bitsize),
                    exprtypes.BitVectorType(bitsize)))
            subst_pairs.add(
                (e, exprs.FunctionExpression(
                    bvlshr, (e.children[0], new_right_child))))
        for old_term, new_term in subst_pairs:
            expr = exprs.substitute(expr, old_term, new_term)
        return expr

    new_term_exprs = {(rewrite(term_expr), f) for term_expr, f in term_exprs}
    new_pred_exprs = {(rewrite(pred_expr), f) for pred_expr, f in pred_exprs}
    return (new_term_exprs, new_pred_exprs)
def get_partial_ast(expr, addr):
    """Return the prefix of *expr* that leads to the node at path *addr*.

    *addr* is a sequence of child indices taken from the root. At each step
    the current node keeps its children before the addressed index, the
    addressed child is replaced by its own truncated version, and all later
    children are dropped. When the path is exhausted, a function
    application is returned with no children at all; any other expression
    is returned unchanged.
    """
    if len(addr) > 0:
        head, *tail = addr
        partial_child = get_partial_ast(expr.children[head], tail)
        return exprs.FunctionExpression(
            expr.function_info, expr.children[:head] + (partial_child,))
    if exprs.is_function_expression(expr):
        return exprs.FunctionExpression(expr.function_info, ())
    return expr
def _dummy_spec(self, synth_fun):
    """Create a Boolean indicator function mirroring *synth_fun*'s signature.

    Returns ``(func, compute_indicator)``:

    * ``func`` is a new ``SynthFunction`` named ``pred_indicator_<n>``. The
      ``n`` is drawn with ``random.randint(1, 10000000)``, so names are not
      guaranteed to be unique. It has the same arity and domain types as
      *synth_fun* and a Boolean range.
    * ``compute_indicator(term, points)`` interprets ``func`` as *term* and
      returns, for each point in order, the value of ``func`` applied to its
      formal parameters. If evaluation at any point raises
      ``UnboundLetVariableError`` or ``PartialFunctionError``, it returns
      ``[False] * len(points)`` instead.

    One evaluation context is created here and reused by every call to
    ``compute_indicator``.
    """
    indicator_name = 'pred_indicator_' + str(random.randint(1, 10000000))
    func = semantics_types.SynthFunction(
        indicator_name, synth_fun.function_arity, synth_fun.domain_types,
        exprtypes.BoolType())
    formals = tuple(
        exprs.FormalParameterExpression(func, arg_type, position)
        for position, arg_type in enumerate(synth_fun.domain_types))
    indicator_expr = exprs.FunctionExpression(func, formals)
    eval_ctx = evaluation.EvaluationContext()

    def compute_indicator(term, points):
        eval_ctx.set_interpretation(func, term)
        values = []
        for point in points:
            eval_ctx.set_valuation_map(point)
            try:
                value = evaluation.evaluate_expression_raw(
                    indicator_expr, eval_ctx)
            except (basetypes.UnboundLetVariableError,
                    basetypes.PartialFunctionError):
                # Predicates must never be wrong: report false everywhere.
                return [False] * len(points)
            values.append(value)
        return values

    return func, compute_indicator
def rewrite_solution(synth_funs, solution, reverse_mapping):
    """Undo grammar-decomposition rewrites and restore named parameters.

    Args:
        synth_funs: the synthesized functions, in the order their solutions
            appear in *solution*.
        solution: a single solution expression or, when there are several
            synth functions, an expression whose ``children`` are the
            per-function solutions (a comma-operator combination).
        reverse_mapping: ``None``, or an iterable of
            ``(function_info, cond, orig_expr_template, expr_template)``
            tuples; ``cond`` is not used here.

    For each reverse-mapping entry, every application of ``function_info``
    in the solution is handled in turn:

    * its parent node is rebuilt with its first child replaced by the
      application's first argument, keeping the parent's second and third
      children;
    * the rebuilt node is matched against ``expr_template``, which is
      asserted to be an ``ite`` application;
    * the parent is replaced by ``orig_expr_template`` instantiated with
      that match.

    Returns:
        A list with one expression per synth function. In each one, every
        formal parameter is replaced by the synth function's named variable
        at the same position (from ``synth_fun.get_named_vars()``).
    """
    if reverse_mapping is not None:
        for function_info, _cond, orig_expr_template, expr_template in reverse_mapping:
            fn_name = function_info.function_name
            app = exprs.find_application(solution, fn_name)
            while app is not None:
                assert exprs.is_application_of(expr_template, 'ite')
                ite = exprs.parent_of(solution, app)
                ite_without_dummy = exprs.FunctionExpression(
                    ite.function_info,
                    (app.children[0], ite.children[1], ite.children[2]))
                var_mapping = exprs.match(expr_template, ite_without_dummy)
                new_ite = exprs.substitute_all(
                    orig_expr_template, var_mapping.items())
                solution = exprs.substitute(solution, ite, new_ite)
                app = exprs.find_application(solution, fn_name)

    # With several synth functions the solution is a comma-operator
    # combination of one solution per function.
    per_fun_solutions = [solution] if len(synth_funs) == 1 else solution.children

    rewritten_solutions = []
    for sol, synth_fun in zip(per_fun_solutions, synth_funs):
        formals = exprs.get_all_formal_parameters(sol)
        named_vars = synth_fun.get_named_vars()
        pairs = [(v, named_vars[v.parameter_position]) for v in formals]
        rewritten_solutions.append(exprs.substitute_all(sol, pairs))
    return rewritten_solutions
def get_history(expr, pick=None):
    """Return the history for *expr* by delegating to ``get_history_new``.

    Args:
        expr: expression whose history is requested.
        pick: optional selector, forwarded unchanged.

    Returns:
        Exactly what ``get_history_new(expr, pick)`` returns.
    """
    # The body used to continue after this return with an older in-line
    # implementation: nested add_child/remove_children helpers and a DFS
    # walk. It could never execute and has been removed.
    return get_history_new(expr, pick)
def _trivial_solve(self):
    """Install a constant-zero solution and report success.

    The solution is the integer constant ``0``. When there is more than one
    synth function, it is instead a comma-function application with one
    ``0`` per function, each component typed ``IntType``. The solution is
    stored in a fresh ``self.signature_to_term`` under the key ``None``,
    replacing any previous mapping.

    Returns:
        True
    """
    zero = exprs.ConstantExpression(exprs.Value(0, exprtypes.IntType()))
    fun_count = len(self.synth_funs)
    if fun_count > 1:
        comma = semantics_core.CommaFunction((exprtypes.IntType(),) * fun_count)
        solution = exprs.FunctionExpression(comma, (zero,) * fun_count)
    else:
        solution = zero
    self.signature_to_term = {None: solution}
    return True
def get_partial_ast(expr, addr):
    """Return the prefix of *expr* along the child-index path *addr*.

    At each step along *addr*, the current node keeps its children before
    the addressed index, the addressed child is replaced by its recursively
    truncated version, and later children are dropped. The walk stops, and
    the node is returned unchanged (with its whole subtree), when the path
    is exhausted or the node is not a function application.
    """
    if exprs.is_function_expression(expr) and len(addr) > 0:
        head, *tail = addr
        partial_child = get_partial_ast(expr.children[head], tail)
        kept_children = expr.children[:head] + (partial_child,)
        return exprs.FunctionExpression(expr.function_info, kept_children)
    return expr
def add_dummy_pred(expr):
    """Wrap *expr* in a one-argument Boolean macro named ``dummy_pred_name``.

    The macro has a single Boolean formal ``d``, and that same variable
    expression is passed as its body. It therefore presumably acts as the
    identity on Booleans. This assumes the ``MacroFunction`` argument order
    is name, arity, domain types, range type, body, formals (TODO confirm).

    Returns:
        ``FunctionExpression(macro, (expr,))``.
    """
    param = exprs.VariableExpression(
        exprs.VariableInfo(exprtypes.BoolType(), 'd', 0))
    identity_macro = semantics_types.MacroFunction(
        dummy_pred_name, 1, (exprtypes.BoolType(),), exprtypes.BoolType(),
        param, [param])
    return exprs.FunctionExpression(identity_macro, (expr,))
def solve(self):
    """Solve for terms over ``self.points`` and record them by signature.

    With no points, prints ``Trivial solve!`` and returns
    ``self._trivial_solve()``.

    Otherwise:

    * Points are split into runs of *consecutive* points that share the
      same prefix of length ``len(self.spec.intro_vars)``.
    * Each run is solved with ``self._single_solve``.
    * Intro variables in the resulting terms are replaced by each synth
      function's formal parameters.
    * The terms are combined (with a comma function when there are several
      synth functions).
    * The combined term is stored in ``self.signature_to_term``. Its key is
      the signature of point indices for which
      ``self.term_signature(term, self.points)`` is truthy.

    Debug output is printed to stdout throughout.

    Returns:
        True, or the result of ``_trivial_solve`` when there are no points.
    """
    if len(self.points) == 0:
        print("Trivial solve!")
        return self._trivial_solve()

    print("-----------------")
    print("Nontrivial solve!")
    for point in self.points:
        print("POINT:", [p.value_object for p in point])
    for sig, term in self.signature_to_term.items():
        print('SIGTOTERM:', str(sig), _expr_to_str(term))

    # Group consecutive points with equal intro-variable prefixes. The points
    # are not sorted first, so equal prefixes that are not adjacent end up in
    # separate groups.
    num_intro_vars = len(self.spec.intro_vars)
    groups = []
    for point in self.points:
        prefix = point[:num_intro_vars]
        if groups and prefix == groups[-1][0]:
            groups[-1][1].append(point)
        else:
            groups.append((prefix, [point]))

    for ivs, group_points in groups:
        print("A:")
        raw_terms = self._single_solve(ivs, group_points)
        print("C:", [exprs.expression_to_string(t) for t in raw_terms])
        terms = [
            exprs.substitute_all(
                term,
                list(zip(self.spec.intro_vars, self.spec.formal_params[sf])))
            for term, sf in zip(raw_terms, self.synth_funs)
        ]
        print([_expr_to_str(t) for t in terms])

        sig = self.signature_factory()
        num_funs = len(self.synth_funs)
        if num_funs > 1:
            comma = semantics_core.CommaFunction(
                tuple([exprtypes.IntType()] * num_funs))
            single_term = exprs.FunctionExpression(comma, tuple(terms))
        else:
            single_term = terms[0]

        satisfied = self.term_signature(single_term, self.points)
        for index, holds in enumerate(satisfied):
            if holds:
                sig.add(index)
        self.signature_to_term[sig] = single_term
    print("-----------------")

    return True
def _do_transform(self, expr_object, syn_ctx):
    """Normalize *expr_object* into a canonical comparison/arithmetic form.

    Rules:

    * ``and`` / ``or``: children transformed; rebuilt with the same function.
    * Comparisons (``<=``, ``>=``, ``<``, ``>``, ``=``, ``eq``, ``ne``):
      children transformed; rebuilt via ``syn_ctx.make_function_expr``.
    * ``not`` applied directly to a comparison: replaced by the negated
      comparison over the transformed operands, e.g. ``not (a <= b)``
      becomes ``a > b``.
    * ``add``: children transformed, and children that are themselves
      ``add`` are spliced in. One level suffices because the children were
      already flattened recursively.
    * ``sub``, ``mul``, ``-``: children transformed; rebuilt.
    * Anything else, including ``not`` of a non-comparison, is returned
      unchanged without recursing.
    """
    comparison_ops = ('<=', '>=', '<', '>', '=', 'eq', 'ne')
    # Negation of every comparison in comparison_ops. 'eq' must be present
    # too: without it, `not (eq a b)` raised KeyError.
    neg = {
        '<=': '>',
        '<': '>=',
        '>=': '<',
        '>': '<=',
        '=': 'ne',
        'eq': 'ne',
        'ne': '=',
    }

    if not exprs.is_function_expression(expr_object):
        return expr_object

    function_info = expr_object.function_info
    function_name = function_info.function_name

    def transform_children(node):
        return [self._do_transform(child, syn_ctx) for child in node.children]

    if function_name in ('and', 'or'):
        return exprs.FunctionExpression(
            function_info, tuple(transform_children(expr_object)))
    if function_name in comparison_ops:
        return syn_ctx.make_function_expr(
            function_name, *transform_children(expr_object))
    if function_name == 'not':
        child = expr_object.children[0]
        if (exprs.is_function_expression(child)
                and child.function_info.function_name in comparison_ops):
            # Transform the operands as the plain-comparison branch does, so
            # negated and non-negated comparisons are normalized alike.
            return syn_ctx.make_function_expr(
                neg[child.function_info.function_name],
                *transform_children(child))
        return expr_object
    if function_name == 'add':
        flattened = []
        for child in transform_children(expr_object):
            if (exprs.is_function_expression(child)
                    and child.function_info.function_name == 'add'):
                flattened.extend(child.children)
            else:
                flattened.append(child)
        return syn_ctx.make_function_expr('add', *flattened)
    if function_name in ('sub', 'mul', '-'):
        return syn_ctx.make_function_expr(
            function_name, *transform_children(expr_object))
    return expr_object
def to_template_expr(self):
    """Build an expression template for this node from its children.

    Calls ``to_template_expr()`` on each child in order. Each call returns
    a triple ``(ph_vars, nts, template)``.

    Returns:
        A triple ``(ph_vars, nts, expr_template)``:

        * ``ph_vars``: the concatenation of the children's ``ph_vars``;
        * ``nts``: the concatenation of the children's ``nts``;
        * ``expr_template``: a ``FunctionExpression`` of
          ``self.function_info`` over the children's templates.
    """
    child_results = [child.to_template_expr() for child in self.children]
    ph_vars = [v for child_ph_vars, _, _ in child_results for v in child_ph_vars]
    nts = [nt for _, child_nts, _ in child_results for nt in child_nts]
    expr_template = exprs.FunctionExpression(
        self.function_info,
        tuple(template for _, _, template in child_results))
    return ph_vars, nts, expr_template
def __init__(self, expr_valuations, synth_fun, theory):
    """Set up evaluation state for *synth_fun* over *theory*.

    * Stores the synth function and theory.
    * Creates a fresh evaluation context.
    * Initializes valuations from *expr_valuations* via
      ``self._initialize_valuations``.
    * Prebuilds ``synth_fun_expr``: *synth_fun* applied to one formal
      parameter per domain type, numbered by position.
    * Starts ``is_multipoint`` as False.
    """
    self.synth_fun = synth_fun
    self.eval_ctx = evaluation.EvaluationContext()
    self.theory = theory
    # NOTE(review): _initialize_valuations may depend on the attributes set
    # above; keep this ordering unless that helper is checked.
    self._initialize_valuations(expr_valuations)

    formal_params = tuple(
        exprs.FormalParameterExpression(synth_fun, param_type, position)
        for position, param_type in enumerate(synth_fun.domain_types))
    self.synth_fun_expr = exprs.FunctionExpression(synth_fun, formal_params)
    self.is_multipoint = False
def add_child(expr, parent_addr, child_expr):
    """Return a copy of *expr* with *child_expr* appended under *parent_addr*.

    *parent_addr* is a sequence of child indices leading from *expr* to the
    node that receives *child_expr* as its new last child. Nodes on the path
    are rebuilt, and untouched sibling subtrees are reused as-is. A node on
    the path (including *expr* itself) that is not a function application
    is returned unchanged, so nothing is added below it.
    """
    if not exprs.is_function_expression(expr):
        return expr
    if len(parent_addr) == 0:
        return exprs.FunctionExpression(
            expr.function_info, expr.children + (child_expr,))
    head, *tail = parent_addr
    updated = list(expr.children)
    updated[head] = add_child(updated[head], tail, child_expr)
    return exprs.FunctionExpression(expr.function_info, tuple(updated))
def _do_transform(expr, syn_ctx):
    """Recursively eliminate ``let`` bindings from *expr* by substitution.

    Children are flattened first (bottom-up).

    * For a ``let`` application, the last child is the body and the
      preceding children are the bound values. The values are paired
      positionally with ``expr.function_info.binding_vars``, and the result
      is the flattened body with those pairs substituted.
    * Other function applications are rebuilt over their flattened
      children.
    * Non-function expressions are returned unchanged.

    *syn_ctx* is only forwarded to the recursive calls.
    """
    if not exprs.is_function_expression(expr):
        return expr
    flat_children = [
        LetFlattener._do_transform(child, syn_ctx) for child in expr.children
    ]
    if not exprs.is_application_of(expr, 'let'):
        return exprs.FunctionExpression(expr.function_info, tuple(flat_children))
    body = flat_children[-1]
    bound_values = flat_children[:-1]
    bindings = list(zip(expr.function_info.binding_vars, bound_values))
    return exprs.substitute_all(body, bindings)
def make_function_expr(self, function_name_or_info, *child_exps):
    """Makes a typed function expression applied to the given child expressions.

    Args:
        function_name_or_info: either a function name (``str``), resolved
            with ``self.make_function`` against *child_exps*, or an
            already-resolved ``semantics_types.FunctionBase``.
        *child_exps: the argument expressions.

    Returns:
        An ``exprs.FunctionExpression`` applying the function to
        ``tuple(child_exps)``.

    Raises:
        basetypes.ArgumentError: if resolving a function name yields
            ``None``. The message lists the argument types.
        AssertionError: if *function_name_or_info* is neither a ``str`` nor
            a ``FunctionBase``.
    """
    if isinstance(function_name_or_info, str):
        function_info = self.make_function(function_name_or_info, *child_exps)
        function_name = function_name_or_info
    else:
        assert isinstance(function_name_or_info, semantics_types.FunctionBase)
        function_info = function_name_or_info
        function_name = function_info.function_name

    # Identity check (PEP 8): `== None` could be hijacked by a custom __eq__.
    if function_info is None:
        arg_types = ', '.join(
            str(exprs.get_expression_type(x)) for x in child_exps)
        raise basetypes.ArgumentError(
            'Could not instantiate function named "' + function_name +
            '" with argument types: (' + arg_types + ')')
    return exprs.FunctionExpression(function_info, tuple(child_exps))
def remove_children(expr):
    """Return *expr* stripped of its children.

    A function application is rebuilt with the same ``function_info`` and
    an empty child tuple. Any other expression is returned unchanged.
    """
    if not exprs.is_function_expression(expr):
        return expr
    return exprs.FunctionExpression(expr.function_info, ())
def _instantiate(self, sub_exprs):
    """Apply ``self.function_descriptor`` to *sub_exprs*.

    Returns an ``exprs.FunctionExpression`` whose children are *sub_exprs*
    exactly as given. No conversion to a tuple is performed.
    """
    function_info = self.function_descriptor
    return exprs.FunctionExpression(function_info, sub_exprs)
def unify(self):
    """Greedily combine solved terms into a single solution (generator).

    Sets ``self.last_dt_size = 0``. If ``self._try_trivial_unification()``
    returns a non-``None`` expression, yields ``("TERM", expr)`` and stops.

    Otherwise it works from the term solver's ``get_signature_to_term()``
    map and repeats the following:

    1. Pick the signature with the most still-uncovered points (ties go to
       the earliest).
    2. Compute a precondition for its term with
       ``self._compute_pre_condition``.
    3. Mark as covered every uncovered point of that signature at which the
       precondition evaluates truthy.

    The loop stops once no signature has an uncovered point. The collected
    ``(pred, term)`` pairs are combined by
    ``self._pred_term_list_to_expr``. Each synth function's canonical
    application arguments are then replaced by its formal parameters, with
    per-function parts combined by a comma function when there are several.
    Finally ``('TERM', expr)`` is yielded.

    NOTE(review): an empty signature map makes ``max`` raise ValueError.
    """
    sig_to_term = self.term_solver.get_signature_to_term()
    eval_ctx = self.eval_ctx
    self.last_dt_size = 0

    trivial = self._try_trivial_unification()
    if trivial is not None:
        yield ("TERM", trivial)
        return

    pred_terms = []
    # Each entry: (full signature, its points not yet covered).
    candidates = [(sig, sig) for sig in sig_to_term]
    while True:
        full_sig, uncovered = max(candidates, key=lambda pair: len(pair[1]))
        if len(uncovered) == 0:
            break  # every point is covered
        term = sig_to_term[full_sig]
        pred = self._compute_pre_condition(full_sig, uncovered, term)
        pred_terms.append((pred, term))

        newly_covered = BitSet(len(self.points))
        for index in uncovered:
            eval_ctx.set_valuation_map(self.points[index])
            if evaluation.evaluate_expression_raw(pred, eval_ctx):
                newly_covered.add(index)
        # An empty cover would leave every candidate unchanged: no progress.
        assert not newly_covered.is_empty()
        candidates = [
            (full, rest.difference(newly_covered)) for (full, rest) in candidates
        ]

    def to_formals(expr, synth_fun):
        actual = self.spec.canon_application[synth_fun].children
        formal = self.spec.formal_params[synth_fun]
        return exprs.substitute_all(expr, list(zip(actual, formal)))

    combined = self._pred_term_list_to_expr(pred_terms)
    if len(self.synth_funs) == 1:
        result = to_formals(combined, self.synth_funs[0])
    else:
        per_fun = [
            to_formals(part, sf) for part, sf in zip(combined, self.synth_funs)
        ]
        domain_types = tuple([exprtypes.IntType()] * len(self.synth_funs))
        result = exprs.FunctionExpression(
            semantics_core.CommaFunction(domain_types), tuple(per_fun))
    yield ('TERM', result)