Ejemplo n.º 1
0
def _process_rule(non_terminals, nt_type, syn_ctx, arg_vars, var_map, synth_fun, rule_data):
    """Convert one parsed grammar right-hand side into a rewrite object.

    ``rule_data`` is dispatched on its shape, in this order:

    * a ``tuple`` is converted with ``sexp_to_value`` and wrapped as a
      constant-expression rewrite;
    * a sequence whose head is ``'Constant'`` becomes an ``NTRewrite`` named
      ``'Constant' + str(type)``;
    * a sequence whose head is ``'Variable'``, ``'InputVariable'`` or
      ``'LocalVariable'`` is rejected as unsupported;
    * a ``str`` is resolved, in order, as a formal parameter of
      ``synth_fun``, a non-terminal, an entry of ``var_map``, a function that
      ``syn_ctx.make_function`` can build from the name alone, and finally
      as a placeholder for a let-bound variable;
    * a ``list`` is a function application, or a ``let`` form handled through
      ``sexp_to_let``, whose children are processed recursively.

    Args:
        non_terminals: container of non-terminal names (tested with ``in``).
        nt_type: mapping from non-terminal name to its type.
        syn_ctx: object providing ``make_function`` and ``make_function_expr``.
        arg_vars: iterable of formal-parameter variables; each exposes
            ``variable_info.variable_name`` and ``variable_info.variable_type``.
            The position in the iteration is used as the parameter position.
        var_map: mapping from variable name to the expression to substitute.
        synth_fun: passed through to ``exprs.FormalParameterExpression``.
        rule_data: the parsed right-hand side (tuple, str or list).

    Returns:
        A triple ``(ph_let_bound_vars, let_bound_vars, rewrite)``: the
        placeholder variables created for unresolved names, the binding
        variables collected from ``let`` forms, and the resulting rewrite.

    Raises:
        NotImplementedError: for ``Variable``/``InputVariable``/``LocalVariable`` rules.
        Exception: if ``rule_data`` matches none of the shapes above.
        AssertionError: if no function could be resolved for a list rule.
    """
    ph_let_bound_vars, let_bound_vars = [], []
    if isinstance(rule_data, tuple):
        value = sexp_to_value(rule_data)
        ret = grammars.ExpressionRewrite(exprs.ConstantExpression(value))
    # NOTE: the two head checks below run before the str check on purpose
    # (original ordering). For a str, rule_data[0] is a single character, so
    # neither of them can match and strings fall through to the str branch.
    elif rule_data[0] == 'Constant':
        typ = sexp_to_type(rule_data[1])
        ret = grammars.NTRewrite('Constant' + str(typ), typ)
    elif rule_data[0] in ('Variable', 'InputVariable', 'LocalVariable'):
        raise NotImplementedError('Variable rules in grammars')
    elif isinstance(rule_data, str):
        # Single pass over arg_vars: first formal parameter with this name.
        formal = next(
            ((position, var) for position, var in enumerate(arg_vars)
             if var.variable_info.variable_name == rule_data),
            None)
        if formal is not None:
            parameter_position, variable = formal
            expr = exprs.FormalParameterExpression(synth_fun,
                    variable.variable_info.variable_type,
                    parameter_position)
            ret = grammars.ExpressionRewrite(expr)
        elif rule_data in non_terminals:
            ret = grammars.NTRewrite(rule_data, nt_type[rule_data])
        elif rule_data in var_map:
            ret = grammars.ExpressionRewrite(var_map[rule_data])
        else:
            # Could be a 0 arity function
            func = syn_ctx.make_function(rule_data)
            if func is not None:
                ret = grammars.ExpressionRewrite(syn_ctx.make_function_expr(rule_data))
            else:
                # Could be a let bound variable: emit a placeholder (typed as
                # Bool here) and report it to the caller.
                bound_var_ph = exprs.VariableExpression(
                    exprs.VariableInfo(exprtypes.BoolType(), 'ph_' + rule_data))
                ph_let_bound_vars.append(bound_var_ph)
                ret = grammars.ExpressionRewrite(bound_var_ph)
    elif isinstance(rule_data, list):
        function_name = rule_data[0]
        if function_name != 'let':
            function_args = []
            for child in rule_data[1:]:
                ph_lbv, lbv, arg = _process_rule(non_terminals, nt_type, syn_ctx, arg_vars, var_map, synth_fun, child)
                ph_let_bound_vars.extend(ph_lbv)
                let_bound_vars.extend(lbv)
                function_args.append(arg)
            # The function is looked up by name and the types of its arguments.
            function_arg_types = tuple(x.type for x in function_args)
            function = syn_ctx.make_function(function_name, *function_arg_types)
        else:
            # Callbacks handed to sexp_to_let; the first one accumulates the
            # variables found in children into this call's result lists.
            def child_processing_func(rd, syn_ctx, new_var_map):
                ph_lbv, lbv, a = _process_rule(non_terminals, nt_type, syn_ctx, arg_vars, new_var_map, synth_fun, rd)
                ph_let_bound_vars.extend(ph_lbv)
                let_bound_vars.extend(lbv)
                return a

            def get_return_type(r):
                return r.type

            function, function_args = sexp_to_let(rule_data, syn_ctx, child_processing_func, get_return_type, var_map)
            let_bound_vars.extend(function.binding_vars)
        assert function is not None
        ret = grammars.FunctionRewrite(function, *function_args)
    else:
        raise Exception('Unknown right hand side: %s' % rule_data)
    return ph_let_bound_vars, let_bound_vars, ret
Ejemplo n.º 2
0
    def get_score(self, expr):
        """Score ``expr`` by replaying its derivation in ``self.grammar``.

        Starting from the grammar's start symbol, the first pending
        non-terminal address is expanded with the first rule that either is a
        non-terminal rewrite or whose top symbol equals the top symbol found
        in ``expr`` at that address. Non-terminal rewrites cost nothing; any
        other step costs ``-log2(self.stat_map[cond][topsymb])``, with 0.001
        used when the entry is missing. The walk ends when no address is
        pending or no expansion matches.

        Returns:
            The accumulated cost as a float, or ``200.0`` when
            ``self.stat_map`` is ``None``.
        """
        if self.stat_map is None:
            return 200.0

        grammar = self.grammar
        sentential = grammars.NTRewrite(grammar.start, grammar.nt_type[grammar.start])
        total = 0.0
        pending = [[]]

        while len(pending) > 0:
            _, _, sentential_expr = sentential.to_template_expr()
            _, ctxt = get_ctxt(self.fetchop_func, sentential_expr, pending[0], self.instrs,
                               {tuple(addr) for addr in pending[1:]})
            cond = ','.join(ctxt)
            for rule, successor, new_addrs in grammar.one_step_expand(sentential, pending[0]):
                _, _, rule_expr = rule.to_template_expr()
                rule_topsymb = self.fetchop_func(rule_expr)
                topsymb = self.fetchop_func(fetch_prod(expr, pending[0]))
                is_nt_step = isinstance(rule, grammars.NTRewrite)
                if not (rule_topsymb == topsymb or is_nt_step):
                    continue

                if is_nt_step:
                    step_cost = 0.0
                else:
                    step_cost = -1.0 * math.log2(self.stat_map.get(cond, {}).get(topsymb, 0.001))
                total += step_cost

                # Addresses created by this step are handled before the rest.
                remaining = pending[1:]
                pending = remaining if len(new_addrs) == 0 else new_addrs + remaining
                sentential = successor
                break
            else:
                # No expansion matched at this address: stop scoring.
                break
        return total
Ejemplo n.º 3
0
def _merge_grammars(sf_grammar_list):
    """Combine several per-function grammars into one grammar.

    Each entry of ``sf_grammar_list`` is a triple ``(name, object, grammar)``;
    only the name and the grammar are used. Every grammar is renamed with
    ``grammar.add_prefix(name)`` and its non-terminals, types and rules are
    pooled. A new start symbol ``"MergedStart"`` gets a single rule: a
    ``CommaFunction`` applied to the start symbols of the renamed grammars.
    The type recorded for the merged start symbol is ``None``.

    Returns:
        The ``grammars.Grammar`` built from the pooled data.
    """
    merged_start = "MergedStart"
    all_nts = [merged_start]
    all_types = {}
    all_rules = {}
    sub_starts = []

    for sf_name, _sf_obj, grammar in sf_grammar_list:
        prefixed = grammar.add_prefix(sf_name)
        all_nts.extend(prefixed.non_terminals)
        all_types.update(prefixed.nt_type)
        all_rules.update(prefixed.rules)
        sub_starts.append(prefixed.start)

    start_types = [all_types[s] for s in sub_starts]
    comma_function = semantics_core.CommaFunction(start_types)
    start_refs = [grammars.NTRewrite(s, all_types[s]) for s in sub_starts]
    all_rules[merged_start] = [grammars.FunctionRewrite(comma_function, *start_refs)]
    all_types[merged_start] = None

    return grammars.Grammar(all_nts, all_types, all_rules, merged_start)
Ejemplo n.º 4
0
def get_derivation_sequences(expr, grammar):
    """Collect, per non-terminal, the derivation steps that reproduce ``expr``.

    For every non-terminal of ``grammar`` the function starts from that
    non-terminal and repeatedly expands the first pending address with the
    first rule that is a non-terminal rewrite or whose top symbol equals the
    top symbol of ``expr`` at that address. Sentential forms produced by
    non-terminal rewrites are not recorded; all others are appended to the
    sequence.

    Returns:
        A list of sequences (lists of rewrites, each beginning with the
        starting non-terminal). Sequences with no recorded step beyond the
        starting non-terminal are omitted.
    """
    sequences = []

    for nt in grammar.non_terminals:
        root = grammars.NTRewrite(nt, grammar.nt_type[nt])
        steps = [root]
        sentential = root
        pending = [[]]

        while len(pending) > 0:
            # All expansions are collected before any of them is inspected.
            candidates = list(grammar.one_step_expand(sentential, pending[0]))
            for rule, successor, new_addrs in candidates:
                rule_topsymb = fetchop_rewrite(rule)
                topsymb = fetchop(fetch_prod(expr, pending[0]))
                is_nt_step = isinstance(rule, grammars.NTRewrite)
                if not (rule_topsymb == topsymb or is_nt_step):
                    continue
                if not is_nt_step:
                    steps.append(successor)
                remaining = pending[1:]
                pending = remaining if len(new_addrs) == 0 else new_addrs + remaining
                sentential = successor
                break
            else:
                # Nothing matched at this address: give up on this non-terminal.
                break

        if len(steps) > 1:
            sequences.append(steps)

    return sequences
Ejemplo n.º 5
0
    def get_corresponding_rule(nt, deriv_seq):
        """Replay ``deriv_seq`` from non-terminal ``nt`` and return the last rewrite.

        ``deriv_seq[0]`` is skipped; each later element is compared with
        ``str()`` of the candidate rewrites produced by expanding the first
        pending address of the current sentential form. ``grammar`` comes
        from the enclosing scope.

        Returns:
            The rewrite matching the last element of ``deriv_seq``; ``None``
            if ``nt`` is not a non-terminal of ``grammar``, if some element
            cannot be reproduced, or if ``deriv_seq`` has no element after
            the first.
        """
        if nt not in grammar.non_terminals:
            return None

        sentential = grammars.NTRewrite(nt, grammar.nt_type[nt])
        pending = [[]]
        matched = None

        for target in deriv_seq[1:]:
            for _rule, candidate, new_addrs in grammar.one_step_expand(sentential, pending[0]):
                if str(candidate) == target:
                    remaining = pending[1:]
                    pending = remaining if len(new_addrs) == 0 else new_addrs + remaining
                    sentential = candidate
                    matched = candidate
                    break
            else:
                # No expansion yields the expected sentential form.
                return None
        return matched
Ejemplo n.º 6
0
    def generate(self, compute_term_signature, points):
        """Generator yielding ``[expr]`` for each fully expanded rewrite taken from the frontier.

        The search starts from the grammar's start symbol. Each rewrite taken
        from the frontier that still contains non-terminals is expanded one
        step at its first pending non-terminal address (``nts_addrs[0]``);
        every successor is costed and, depending on how it compares with the
        representative of its normalized form, queued or put on the ignored
        list. A rewrite with no non-terminals left is yielded as a
        one-element list holding its expression.

        Args:
            compute_term_signature: forwarded unchanged to
                ``normalize_rewritestr``; only used when ``options.noindis``
                is false.
            points: only ``len(points)`` is read here. When it has grown since
                the previous loop iteration and ``options.inc`` is set, the
                normalization maps are reset and the ignored rewrites are put
                back on the frontier.

        Yields:
            ``[current_expr]`` for every dequeued rewrite without
            non-terminals.

        NOTE(review): ``PriorityQueue``, ``options``, ``max_score``,
        ``get_ctxt``, ``heuristic_sens`` and ``normalize_rewritestr`` are
        defined outside this block. The queue presumably returns the entry
        with the smallest priority first and ``put(item, priority,
        replace=...)`` presumably swaps an existing entry -- confirm against
        its implementation.
        """
        grammar = self.grammar
        m = self.m
        m_sens = self.m_sens
        instrs = self.instrs
        stat_map = self.stat_map

    # ********************************************** Data structures ***************************************************
        # priority queue of str(rewrite)
        frontier = PriorityQueue()

        # str(rewrite) -> (rewrite, addresses of pending non-terminals,
        #                  (placeholder variables, non-terminals, expr))
        # roles:
        #   1. rewrite cannot be added to the priority queue as it is unhashable.
        #      We add str(rewrite) to the queue and this map stores its corresponding rewrite object.
        #   2. to avoid calling to_template_expr repeatedly
        strrewrite_to_rewrite = {}

        # str(rewrite) -> normalized str(rewrite)
        strrewrite_to_normstrrewrite = {}
        # normalized str(rewrite) -> str(rewrite)   (representative of equivalence class)
        normstrrewrite_to_strrewrite = {}

        # accumulated expansion cost so far (sum of -log2 terms, see expand_cost below)
        # strrewrite -> R
        cost_so_far = {}

        # str(rewrite) -> priority
        strrewrite_to_priority = {}

        # list of ignored str(rewrite); only filled when options.inc is set
        ignored = []

    # ******************************************************************************************************************

        # start symbol
        start = grammars.NTRewrite(grammar.start, grammar.nt_type[grammar.start])

        # init for start symbol: it is its own normal form and representative
        start_str = str(start)
        (start_ph_vars, start_nts, start_expr) = start.to_template_expr()
        strrewrite_to_rewrite[start_str] = (start, [[]], (start_ph_vars, start_nts, start_expr))
        strrewrite_to_normstrrewrite[start_str] = start_str
        normstrrewrite_to_strrewrite[start_str] = start_str
        frontier.put(start_str, 0)
        cost_so_far[start_str] = 0

        # init for non-terminals: each one normalizes to itself
        for nt in grammar.non_terminals:
            strrewrite_to_normstrrewrite[nt] = nt

        def incremental_update(ignored, frontier):
            # Reset both normalization maps, re-seed the non-terminals, and
            # put every ignored rewrite back on the frontier with the
            # priority recorded for it; the ignored list is then emptied.
            strrewrite_to_normstrrewrite.clear()
            normstrrewrite_to_strrewrite.clear()

            for nt in grammar.non_terminals:
                strrewrite_to_normstrrewrite[nt] = nt

            for rewrite_str in ignored:
                frontier.put(rewrite_str, strrewrite_to_priority[rewrite_str])
            ignored.clear()

        num_points = len(points)
        while not frontier.empty():
            # incremental search with indistinguishability: triggered when
            # points has grown since the previous iteration
            if len(points) > num_points and options.inc:
                incremental_update(ignored, frontier)

            num_points = len(points)
            _,current_str = frontier.get()
            (current, nts_addrs, (ph_vars, nts, current_expr)) = strrewrite_to_rewrite[current_str]
            if len(nts) == 0:
                # no non-terminal left: hand the expression to the consumer
                yield [current_expr]
            else:
                assert (len(nts_addrs) > 0)
                # context of the first pending non-terminal; the remaining
                # pending addresses are passed as a set of tuples
                _, ctxt = get_ctxt(self.fetchop_func, current_expr, nts_addrs[0], instrs, {tuple(addr) for addr in nts_addrs[1:]})
                cond = ','.join(ctxt)

                # one step expansion at the first pending address
                for rule, next_rewrite, generated_nts_addrs in grammar.one_step_expand(current, nts_addrs[0]):
                    # addresses generated by this step come before the remaining ones
                    next_nts_addrs = nts_addrs[1:] if len(generated_nts_addrs) == 0 else generated_nts_addrs + nts_addrs[1:]
                    next_ph_vars, next_nts, next_expr = next_rewrite.to_template_expr()
                    # cache the successor under its string form
                    next_str = str(next_rewrite)
                    strrewrite_to_rewrite[next_str] = (next_rewrite, next_nts_addrs, (next_ph_vars, next_nts, next_expr))
                    # if it is non-terminal rewriting, it causes no cost;
                    # otherwise the cost is -log2 of the stat_map entry for
                    # (cond, topsymb), with 0.001 used when the entry is missing.
                    _, _, rule_expr = rule.to_template_expr()
                    topsymb = self.fetchop_func(rule_expr)
                    if isinstance(rule, grammars.NTRewrite):
                        expand_cost = 0.0
                    else:
                        expand_cost = -1.0 * math.log2(stat_map.get(cond, {}).get(topsymb, 0.001))
                    new_cost = cost_so_far[current_str] + expand_cost

                    # only proceed when this is the first or a cheaper way to reach next_str
                    if next_str not in cost_so_far or new_cost < cost_so_far[next_str]:
                        cost_so_far[next_str] = new_cost
                        future_cost = 0.0 if options.noheuristic else heuristic_sens(self.fetchop_func, m, m_sens, instrs, next_nts_addrs, next_nts, next_expr)
                        priority = new_cost + future_cost
                        strrewrite_to_priority[next_str] = priority

                        # get representative of eq class which current rewrite belongs to
                        if not options.noindis:
                            normalized_next_str = normalize_rewritestr(next_expr, strrewrite_to_normstrrewrite,
                                                                       compute_term_signature,
                                                                       grammar.non_terminals)
                        # no normalization
                        else:
                            normalized_next_str = next_str

                        # fall back to next_str itself when the class has no representative yet
                        rep = normstrrewrite_to_strrewrite[normalized_next_str] \
                            if normalized_next_str in normstrrewrite_to_strrewrite else next_str

                        # switch representative when next_str is new to its
                        # class or has a smaller priority than the current one
                        if rep == next_str or priority < strrewrite_to_priority.get(rep, max_score):
                            normstrrewrite_to_strrewrite[normalized_next_str] = next_str
                            frontier.put(rep, priority, replace=next_str)
                            # rep is ignored and replaced by new one.
                            if rep != next_str and options.inc: ignored.append(rep)
                        else:
                            # next_str is ignored because it is worse than the current representative.
                            if options.inc:
                                ignored.append(next_str)

        return