def _process_rule(non_terminals, nt_type, syn_ctx, arg_vars, var_map, synth_fun, rule_data):
    """Translate one right-hand side of a grammar rule (s-expression form) into
    a rewrite object.

    Parameters:
        non_terminals: collection of non-terminal names of the grammar.
        nt_type:       map from non-terminal name to its type.
        syn_ctx:       synthesis context used to resolve/build function symbols.
        arg_vars:      formal-parameter variable expressions of the synth function.
        var_map:       map from bound-variable name to its expression (let scopes).
        synth_fun:     the function being synthesized (for parameter expressions).
        rule_data:     the rule RHS as a parsed s-expression (tuple literal,
                       string atom, or list application).

    Returns a triple (ph_let_bound_vars, let_bound_vars, rewrite):
        ph_let_bound_vars — placeholder variables created for names that could
                            not be resolved (presumably let-bound; see below),
        let_bound_vars    — variables actually bound by encountered `let`s,
        rewrite           — the resulting grammars.*Rewrite object.
    """
    ph_let_bound_vars, let_bound_vars = [], []
    if type(rule_data) == tuple:
        # A literal constant, e.g. (Int 0) parsed into a (type, value) tuple.
        value = sexp_to_value(rule_data)
        ret = grammars.ExpressionRewrite(exprs.ConstantExpression(value))
    elif rule_data[0] == 'Constant':
        # (Constant <Type>): any constant of the given type, modeled as a
        # synthetic non-terminal named 'Constant<Type>'.
        typ = sexp_to_type(rule_data[1])
        ret = grammars.NTRewrite('Constant' + str(typ), typ)
    elif rule_data[0] in [ 'Variable', 'InputVariable', 'LocalVariable' ]:
        raise NotImplementedError('Variable rules in grammars')
    elif type(rule_data) == str:
        # A bare atom: formal parameter, non-terminal, bound variable,
        # 0-arity function, or (as a last resort) a let-bound name.
        if rule_data in [ a.variable_info.variable_name for a in arg_vars ]:
            (parameter_position, variable) = next((i, x) for (i, x) in enumerate(arg_vars)
                    if x.variable_info.variable_name == rule_data)
            expr = exprs.FormalParameterExpression(synth_fun,
                    variable.variable_info.variable_type,
                    parameter_position)
            ret = grammars.ExpressionRewrite(expr)
        elif rule_data in non_terminals:
            ret = grammars.NTRewrite(rule_data, nt_type[rule_data])
        elif rule_data in var_map:
            ret = grammars.ExpressionRewrite(var_map[rule_data])
        else:
            # Could be a 0 arity function
            func = syn_ctx.make_function(rule_data)
            if func != None:
                ret = grammars.ExpressionRewrite(syn_ctx.make_function_expr(rule_data))
            else:
                # Could be a let bound variable.
                # NOTE(review): the placeholder is hard-coded to BoolType();
                # presumably the real type is patched in later — TODO confirm.
                bound_var_ph = exprs.VariableExpression(exprs.VariableInfo(exprtypes.BoolType(), 'ph_' + rule_data))
                ph_let_bound_vars.append(bound_var_ph)
                ret = grammars.ExpressionRewrite(bound_var_ph)
    elif type(rule_data) == list:
        # A function application (f arg1 arg2 ...) or a let binding.
        function_name = rule_data[0]
        if function_name != 'let':
            # Recursively process children, accumulating any (placeholder)
            # let-bound variables they create.
            function_args = []
            for child in rule_data[1:]:
                ph_lbv, lbv, arg = _process_rule(non_terminals, nt_type, syn_ctx,
                        arg_vars, var_map, synth_fun, child)
                ph_let_bound_vars.extend(ph_lbv)
                let_bound_vars.extend(lbv)
                function_args.append(arg)
            function_arg_types = tuple([ x.type for x in function_args ])
            function = syn_ctx.make_function(function_name, *function_arg_types)
        else:
            # `let`: delegate to sexp_to_let, which calls back into
            # child_processing_func with an extended variable map.
            def child_processing_func(rd, syn_ctx, new_var_map):
                ph_lbv, lbv, a = _process_rule(non_terminals, nt_type, syn_ctx,
                        arg_vars, new_var_map, synth_fun, rd)
                ph_let_bound_vars.extend(ph_lbv)
                let_bound_vars.extend(lbv)
                return a
            def get_return_type(r):
                return r.type
            function, function_args = sexp_to_let(rule_data, syn_ctx,
                    child_processing_func, get_return_type, var_map)
            # NOTE(review): binding_vars is read before the `is not None`
            # assert below — if sexp_to_let could return None this would raise
            # AttributeError first. TODO confirm sexp_to_let never returns None.
            let_bound_vars.extend(function.binding_vars)
        assert function is not None
        ret = grammars.FunctionRewrite(function, *function_args)
    else:
        raise Exception('Unknown right hand side: %s' % rule_data)
    return ph_let_bound_vars, let_bound_vars, ret
def get_score(self, expr):
    """Score `expr` by replaying its left-most derivation in self.grammar.

    The score is the sum of -log2 probabilities of the production choices,
    looked up in self.stat_map keyed by the context string of the expansion
    site (probability defaults to 0.001 for unseen (context, symbol) pairs).
    Lower scores therefore mean more likely expressions.

    Returns 200.0 when no statistics model is loaded.
    """
    if self.stat_map is None:
        # No learned model: return a large constant score for every expression.
        return 200.0
    grammar = self.grammar
    # start symbol
    start = grammars.NTRewrite(grammar.start, grammar.nt_type[grammar.start])
    score = 0.0
    current = start
    # Addresses (paths) of the remaining non-terminals in `current`,
    # left-most first; [[]] is the root address of the start symbol.
    nts_addrs = [[]]
    changed = True
    while changed:
        if (len(nts_addrs) == 0):
            # No non-terminals left: derivation of `expr` is complete.
            break
        changed = False
        _, _, current_expr = current.to_template_expr()
        # Context of the left-most expansion site, used as the key into the
        # statistics map.
        _, ctxt = get_ctxt(self.fetchop_func, current_expr, nts_addrs[0],
                self.instrs, {tuple(addr) for addr in nts_addrs[1:]})
        cond = ','.join(ctxt)
        # Find the one-step expansion whose top symbol matches the
        # corresponding production in `expr`, and accumulate its cost.
        for (rule, next_rewrite, generated_nts_addrs) in grammar.one_step_expand(current, nts_addrs[0]):
            _, _, rule_expr = rule.to_template_expr()
            rule_topsymb = self.fetchop_func(rule_expr)
            topsymb = self.fetchop_func(fetch_prod(expr, nts_addrs[0]))
            if rule_topsymb == topsymb or isinstance(rule, grammars.NTRewrite):
                if isinstance(rule, grammars.NTRewrite):
                    # Pure non-terminal renaming carries no probability cost.
                    expand_cost = 0.0
                else:
                    # 0.001 is the smoothing probability for unseen symbols.
                    expand_cost = -1.0 * math.log2(self.stat_map.get(cond, {}).get(topsymb, 0.001))
                score += expand_cost
                nts_addrs = nts_addrs[1:] if len(generated_nts_addrs) == 0 else generated_nts_addrs + nts_addrs[1:]
                # print(' -> ', str(next_rewrite), ' ', expand_cost)
                current = next_rewrite
                changed = True
                break
    return score
def _merge_grammars(sf_grammar_list):
    """Merge per-synth-function grammars into a single grammar.

    Each grammar's symbols are prefixed with its synth-function name to avoid
    clashes, and a fresh "MergedStart" symbol expands to a comma-tuple of all
    the individual start symbols.
    """
    start = "MergedStart"
    nts = [start]
    nt_type = {}
    rules = {}
    starts = []
    for sf_name, sf_obj, grammar in sf_grammar_list:
        # Namespace every symbol of this grammar by its synth-function name.
        prefixed = grammar.add_prefix(sf_name)
        nts.extend(prefixed.non_terminals)
        nt_type.update(prefixed.nt_type)
        rules.update(prefixed.rules)
        starts.append(prefixed.start)
    # The merged start produces one comma-joined tuple of all start symbols.
    comma_function = semantics_core.CommaFunction([nt_type[s] for s in starts])
    start_operands = [grammars.NTRewrite(s, nt_type[s]) for s in starts]
    rules[start] = [grammars.FunctionRewrite(comma_function, *tuple(start_operands))]
    # The merged start symbol has no single type of its own.
    nt_type[start] = None
    return grammars.Grammar(nts, nt_type, rules, start)
def get_derivation_sequences(expr, grammar):
    """Collect, for each non-terminal of `grammar`, the left-most derivation
    sequence that produces `expr` (if one exists).

    For every non-terminal `nt`, the derivation of `expr` is replayed from
    `nt` one left-most expansion at a time; each sentential form reached by a
    non-trivial rule is recorded. Sequences of length 1 (no real expansion
    happened) are discarded.

    Returns a list of sequences, each a list of rewrite objects starting with
    the NTRewrite for the originating non-terminal.
    """
    result = []
    nts = grammar.non_terminals
    for nt in nts:
        seq = []
        start = grammars.NTRewrite(nt, grammar.nt_type[nt])
        seq.append(start)
        current = start
        # Addresses of the remaining non-terminals, left-most first.
        nts_addrs = [[]]
        changed = True
        while changed:
            if len(nts_addrs) == 0:
                break
            changed = False
            # Fix: iterate the generator directly instead of materializing it
            # into a throwaway list first (the old `expanded_result` list
            # comprehension) — same iteration order, no extra allocation.
            for rule, next_rewrite, generated_nts_addrs in grammar.one_step_expand(current, nts_addrs[0]):
                rule_topsymb = fetchop_rewrite(rule)
                topsymb = fetchop(fetch_prod(expr, nts_addrs[0]))
                if rule_topsymb == topsymb or isinstance(rule, grammars.NTRewrite):
                    # Only record genuine productions; NT-to-NT renamings do
                    # not contribute a new sentential form to the sequence.
                    if not isinstance(rule, grammars.NTRewrite):
                        seq.append(next_rewrite)
                    nts_addrs = nts_addrs[1:] if len(generated_nts_addrs) == 0 else generated_nts_addrs + nts_addrs[1:]
                    current = next_rewrite
                    changed = True
                    break
        if len(seq) > 1:
            result.append(seq)
    return result
def get_corresponding_rule(nt, deriv_seq):
    """Replay `deriv_seq` starting from non-terminal `nt` and return the last
    rewrite reached, or None if `nt` is unknown or the sequence cannot be
    followed step by step.

    NOTE: `grammar` is a free variable taken from the enclosing scope.
    """
    if nt not in grammar.non_terminals:
        return None
    current = grammars.NTRewrite(nt, grammar.nt_type[nt])
    # Addresses of the remaining non-terminals, left-most first.
    pending_addrs = [[]]
    last_rewrite = None
    for target_sentential in deriv_seq[1:]:
        # Look for the one-step expansion whose string form matches the next
        # sentential form in the sequence; bail out if none does.
        for rule, candidate, fresh_addrs in grammar.one_step_expand(current, pending_addrs[0]):
            if str(candidate) == target_sentential:
                pending_addrs = pending_addrs[1:] if len(fresh_addrs) == 0 else fresh_addrs + pending_addrs[1:]
                current = candidate
                last_rewrite = candidate
                break
        else:
            return None
    return last_rewrite
def generate(self, compute_term_signature, points):
    """Enumerate candidate expressions best-first, cheapest (most probable)
    first, yielding each complete expression as a one-element list.

    Candidates are partial rewrites expanded left-most one step at a time;
    the priority of a rewrite is its accumulated -log2 probability plus a
    heuristic estimate of the remaining cost (A*-style). Rewrites that are
    indistinguishable on the current `points` (same signature under
    `compute_term_signature`) are merged into one representative unless
    options.noindis is set. When new points arrive and options.inc is set,
    previously-ignored rewrites are re-enqueued (incremental restart).
    """
    grammar = self.grammar
    m = self.m
    m_sens = self.m_sens
    instrs = self.instrs
    stat_map = self.stat_map
    # ********************************************** Data structures ***************************************************
    # priority queue : list of str(rewrite)
    frontier = PriorityQueue()
    # str(rewrite)-> rewrite * addr of leftmost non-terminal symbol * (placeholder variables * non-terminals * expr)
    # roles:
    # 1. rewrite cannot be added to the priority queue as it is unhashable.
    #    We add str(rewrite) to the queue and this map stores its corresponding rewrite object.
    # 2. to avoid repeating calling to_template_expr
    strrewrite_to_rewrite = {}
    # directed graph of str(rewrite)
    # expansion_history = nx.DiGraph()
    # str(rewrite) -> normalized str(rewrite)
    strrewrite_to_normstrrewrite = {}
    # normalized str(rewrite) -> str(rewrite) (representative of equivalence class)
    normstrrewrite_to_strrewrite = {}
    # sum of log probabilities of expansions so far
    # strrewrite -> R
    cost_so_far = {}
    # str(rewrite) -> priority
    strrewrite_to_priority = {}
    # set of ignored str(rewrite)
    ignored = []
    # ******************************************************************************************************************
    # start symbol
    start = grammars.NTRewrite(grammar.start, grammar.nt_type[grammar.start])
    # init for start symbol
    start_str = str(start)
    (start_ph_vars, start_nts, start_expr) = start.to_template_expr()
    strrewrite_to_rewrite[start_str] = (start, [[]], (start_ph_vars, start_nts, start_expr))
    strrewrite_to_normstrrewrite[start_str] = start_str
    normstrrewrite_to_strrewrite[start_str] = start_str
    frontier.put(start_str, 0)
    cost_so_far[start_str] = 0
    # init for non-terminals
    for nt in grammar.non_terminals:
        strrewrite_to_normstrrewrite[nt] = nt

    def incremental_update(ignored, frontier):
        # New points invalidate the old indistinguishability classes: reset
        # the normalization maps and put every ignored rewrite back on the
        # frontier at its recorded priority.
        strrewrite_to_normstrrewrite.clear()
        normstrrewrite_to_strrewrite.clear()
        for nt in grammar.non_terminals:
            strrewrite_to_normstrrewrite[nt] = nt
        for rewrite_str in ignored:
            frontier.put(rewrite_str, strrewrite_to_priority[rewrite_str])
        ignored.clear()

    num_points = len(points)
    while not frontier.empty():
        # incremental search with indistinguishability
        if len(points) > num_points and options.inc:
            incremental_update(ignored, frontier)
            num_points = len(points)
        _,current_str = frontier.get()
        (current, nts_addrs, (ph_vars, nts, current_expr)) = strrewrite_to_rewrite[current_str]
        if len(nts) == 0:
            # No non-terminals left: a complete candidate expression.
            #print('%50s :\t %.2f' % (exprs.expression_to_string(current_expr), cost_so_far[current_str]), flush=True)
            yield [current_expr]
        else:
            assert (len(nts_addrs) > 0)
            # Context of the left-most expansion site — key into stat_map.
            _, ctxt = get_ctxt(self.fetchop_func, current_expr, nts_addrs[0], instrs,
                    {tuple(addr) for addr in nts_addrs[1:]})
            cond = ','.join(ctxt)
            # one step left-most expansion
            for rule, next_rewrite, generated_nts_addrs in grammar.one_step_expand(current, nts_addrs[0]):
                next_nts_addrs = nts_addrs[1:] if len(generated_nts_addrs) == 0 else generated_nts_addrs + nts_addrs[1:]
                next_ph_vars, next_nts, next_expr = next_rewrite.to_template_expr()
                # update rewrite_forest and get a string of next_rewrite
                next_str = str(next_rewrite)
                strrewrite_to_rewrite[next_str] = (next_rewrite, next_nts_addrs, (next_ph_vars, next_nts, next_expr))
                # if it is non-terminal rewriting, it causes no cost.
                _, _, rule_expr = rule.to_template_expr()
                topsymb = self.fetchop_func(rule_expr)
                if isinstance(rule, grammars.NTRewrite):
                    expand_cost = 0.0
                else:
                    # 0.001 is the smoothing probability for unseen symbols.
                    expand_cost = -1.0 * math.log2(stat_map.get(cond, {}).get(topsymb, 0.001))
                new_cost = cost_so_far[current_str] + expand_cost
                # print(current_str, ' -> ', next_str, ' ', expand_cost)
                if next_str not in cost_so_far or new_cost < cost_so_far[next_str]:
                    cost_so_far[next_str] = new_cost
                    future_cost = 0.0 if options.noheuristic else heuristic_sens(self.fetchop_func, m, m_sens, instrs, next_nts_addrs, next_nts, next_expr)
                    priority = new_cost + future_cost
                    strrewrite_to_priority[next_str] = priority
                    # get representative of eq class which current rewrite belongs to
                    if not options.noindis:
                        normalized_next_str = normalize_rewritestr(next_expr,
                                strrewrite_to_normstrrewrite,
                                compute_term_signature,
                                grammar.non_terminals)
                    # no normalization
                    else:
                        normalized_next_str = next_str
                    rep = normstrrewrite_to_strrewrite[normalized_next_str] \
                        if normalized_next_str in normstrrewrite_to_strrewrite else next_str
                    # switch representative
                    if rep == next_str or priority < strrewrite_to_priority.get(rep, max_score):
                        normstrrewrite_to_strrewrite[normalized_next_str] = next_str
                        frontier.put(rep, priority, replace=next_str)
                        # rep is ignored and replaced by new one.
                        if rep != next_str and options.inc:
                            ignored.append(rep)
                    else:
                        # next_str is ignored because it is worse than the current representative.
                        if options.inc:
                            ignored.append(next_str)
    return