def extract_rewrites(data):
    """Mine command-template rewrite pairs from NL/command data and store them.

    Each (description, command) pair is grouped by its tokenized description
    (``' '.join(data_tools.basic_tokenizer(...))``). Within a group, commands
    are bucketed by ``data_tools.cmd2template`` and the raw commands under
    each template are counted. Next, NL groups that share at least two
    templates are merged: the earlier group (by position in the key list) is
    folded into the later one and dropped.

    For every surviving group with two or more distinct templates, each
    ordered pair ``(cm_temp1, cm_temp2)`` of distinct templates is checked
    against the database. Pairs not already present are added and printed.
    Groups are visited in descending order of template count.

    Args:
        data: a pair ``(nls, cms)`` of parallel sequences of natural-language
            descriptions and commands. The call ``nl.decode('utf-8')`` below
            suggests the NL entries are byte strings (Python 2 ``str``) --
            TODO confirm against callers.

    Returns:
        None. Results go to the database via ``DBConnection`` and to stdout.
    """
    nls, cms = data
    # tokenized NL -> {command template -> {raw command -> occurrence count}}
    group_pairs_by_nl = {}
    for nl, cm in zip(nls, cms):
        nl = nl.strip()
        cm = cm.strip()
        # Skip descriptions equal to "na" (any case) -- presumably a
        # "not available" marker; verify with the data source.
        if nl.lower() == "na":
            continue
        # Skip pairs whose description or command is empty after stripping.
        if not nl:
            continue
        if not cm:
            continue
        nl_temp = ' '.join(data_tools.basic_tokenizer(nl.decode('utf-8')))
        if not nl_temp in group_pairs_by_nl:
            group_pairs_by_nl[nl_temp] = {}
        cm_temp = data_tools.cmd2template(cm)
        if not cm_temp in group_pairs_by_nl[nl_temp]:
            group_pairs_by_nl[nl_temp][cm_temp] = collections.defaultdict(int)
        group_pairs_by_nl[nl_temp][cm_temp][cm] += 1

    # Indices of NL groups that were folded into a later group; they are
    # excluded from the output.
    merged = set()
    # NOTE(review): indexing nls[i] relies on Python 2, where keys() returns a
    # list. Merging only goes from a lower index to a higher one, so which
    # groups survive can depend on the dict's key order.
    nls = group_pairs_by_nl.keys()
    for i in xrange(len(nls)):
        nl = nls[i]
        cm_set = set(group_pairs_by_nl[nl].keys())
        for j in xrange(i+1, len(nls)):
            nl2 = nls[j]
            cm_set2 = set(group_pairs_by_nl[nl2].keys())
            # Two NL groups count as paraphrases if they share >= 2 templates.
            if len(cm_set & cm_set2) >= 2:
                for cm_temp in cm_set:
                    if not cm_temp in group_pairs_by_nl[nl2]:
                        # NOTE(review): this stores the same counter object
                        # under both groups (no copy), so a later "+=" merge
                        # into one also changes the other. The counts are
                        # never read below, so DB output is unaffected.
                        group_pairs_by_nl[nl2][cm_temp] = \
                            group_pairs_by_nl[nl][cm_temp]
                    else:
                        # Add group i's per-command counts into group j's.
                        for cm in group_pairs_by_nl[nl][cm_temp]:
                            group_pairs_by_nl[nl2][cm_temp][cm] += \
                                group_pairs_by_nl[nl][cm_temp][cm]
                merged.add(i)

    # Keep only NL groups that were not folded into another group.
    bash_paraphrases = {}
    for i in xrange(len(nls)):
        if i in merged:
            continue
        bash_paraphrases[nls[i]] = group_pairs_by_nl[nls[i]]

    with DBConnection() as db:
        db.create_schema()
        # Visit groups with the most distinct templates first.
        for nl, cm_temps in sorted(bash_paraphrases.items(),
                                   key=lambda x: len(x[1]), reverse=True):
            if len(cm_temps) >= 2:
                # Record every ordered pair of distinct templates (both
                # directions) as a rewrite, skipping ones already stored.
                for cm_temp1 in cm_temps:
                    for cm_temp2 in cm_temps:
                        if cm_temp1 == cm_temp2:
                            continue
                        if not db.exist_rewrite((cm_temp1, cm_temp2)):
                            db.add_rewrite((cm_temp1, cm_temp2))
                            print("* {} --> {}".format(cm_temp1, cm_temp2))
                # NOTE(review): under Python 2 without
                # "from __future__ import print_function" this prints "()"
                # instead of a blank line -- confirm the file imports it.
                print()
# Example #2 (0)
def group_data_by_cm(dataset, use_bucket=False, use_cm_temp=True):
    """Group (nl_str, cm_str, nl, search_history) examples by command.

    Args:
        dataset: a sequence of 4-tuples ``(nl_str, cm_str, nl, search_history)``.
            When ``use_bucket`` is true, it is instead a sequence of such
            sequences (buckets), which is flattened one level first.
        use_bucket: flatten ``dataset`` one level before grouping.
        use_cm_temp: group by ``data_tools.cmd2template(cm_str)`` rather than
            by the raw command string.

    Returns:
        dict mapping each grouping key to a list of four parallel lists
        ``[nl_strs, cm_strs, nls, search_histories]``. Entries within each
        group appear in input order.
    """
    import itertools

    if use_bucket:
        # Flatten in linear time. The previous reduce(+) copied the growing
        # list once per bucket (quadratic) and raised TypeError when given
        # an empty list of buckets.
        dataset = itertools.chain.from_iterable(dataset)
    grouped_dataset = {}
    for nl_str, cm_str, nl, search_history in dataset:
        if use_cm_temp:
            cm_template = data_tools.cmd2template(cm_str)
        else:
            cm_template = cm_str
        group = grouped_dataset.get(cm_template)
        if group is None:
            group = grouped_dataset[cm_template] = [[], [], [], []]
        group[0].append(nl_str)
        group[1].append(cm_str)
        group[2].append(nl)
        group[3].append(search_history)

    return grouped_dataset