Code Example #1
def cmd_overlap_score(gt, pred):
    # Jaccard overlap between the ground-truth and predicted token sets.
    if hasattr(gt, 'parts'):
        # gt is already a parsed AST: collect its word-kind leaf tokens.
        gt_token_set = set(n.word for n in gt.parts if n.kind == "word")
    else:
        gt_token_set = set(data_tools.basic_tokenizer(gt))
    pred_token_set = set(data_tools.basic_tokenizer(pred))
    # (The original retried the identical tokenizer call when it returned no
    # tokens, which was a no-op; the retry is dropped here.)
    union = gt_token_set | pred_token_set
    if not union:
        # Guard against division by zero when both token sets are empty.
        return 0.0
    return len(gt_token_set & pred_token_set) / len(union)
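This score is the Jaccard coefficient of the two token sets (token_overlap in Code Example #5 computes the same quantity after stop-word filtering). A minimal self-contained illustration of the formula, using a plain whitespace split as a hypothetical stand-in for data_tools.basic_tokenizer:

def jaccard(s1, s2):
    # Whitespace split is an assumption for the demo, not the project's tokenizer.
    a, b = set(s1.split()), set(s2.split())
    return len(a & b) / len(a | b) if a | b else 0.0

print(jaccard("list all files", "list files recursively"))  # -> 0.5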
Code Example #2
import collections

# Assumes the project-level data_tools module and DBConnection class are importable.
def extract_rewrites(data):
    # Group (nl, command) pairs by normalized nl string, merge groups that
    # share at least two command templates, and store the resulting
    # template-to-template rewrites in the database.
    nls, cms = data
    group_pairs_by_nl = {}
    for nl, cm in zip(nls, cms):
        nl = nl.strip()
        cm = cm.strip()
        if nl.lower() == "na":
            continue
        if not nl:
            continue
        if not cm:
            continue
        # Normalize the nl string into a whitespace-joined token template
        # (Python 3 strings are already unicode, so no decode is needed).
        nl_temp = ' '.join(data_tools.basic_tokenizer(nl))
        if nl_temp not in group_pairs_by_nl:
            group_pairs_by_nl[nl_temp] = {}
        cm_temp = data_tools.cmd2template(cm)
        if cm_temp not in group_pairs_by_nl[nl_temp]:
            group_pairs_by_nl[nl_temp][cm_temp] = collections.defaultdict(int)
        group_pairs_by_nl[nl_temp][cm_temp][cm] += 1

    merged = set()
    nls = list(group_pairs_by_nl.keys())  # Python 3: make keys indexable
    for i in range(len(nls)):
        nl = nls[i]
        cm_set = set(group_pairs_by_nl[nl].keys())
        for j in range(i + 1, len(nls)):
            nl2 = nls[j]
            cm_set2 = set(group_pairs_by_nl[nl2].keys())
            # Treat the two nl groups as paraphrases if they share at least
            # two command templates, and merge the first into the second.
            if len(cm_set & cm_set2) >= 2:
                for cm_temp in cm_set:
                    if cm_temp not in group_pairs_by_nl[nl2]:
                        group_pairs_by_nl[nl2][cm_temp] = \
                            group_pairs_by_nl[nl][cm_temp]
                    else:
                        for cm in group_pairs_by_nl[nl][cm_temp]:
                            group_pairs_by_nl[nl2][cm_temp][cm] += \
                                group_pairs_by_nl[nl][cm_temp][cm]
                merged.add(i)

    bash_paraphrases = {}
    for i in range(len(nls)):
        if i in merged:
            continue
        bash_paraphrases[nls[i]] = group_pairs_by_nl[nls[i]]

    with DBConnection() as db:
        db.create_schema()
        for nl, cm_temps in sorted(bash_paraphrases.items(),
                                   key=lambda x: len(x[1]), reverse=True):
            if len(cm_temps) >= 2:
                for cm_temp1 in cm_temps:
                    for cm_temp2 in cm_temps:
                        if cm_temp1 == cm_temp2:
                            continue
                        if not db.exist_rewrite((cm_temp1, cm_temp2)):
                            db.add_rewrite((cm_temp1, cm_temp2))
                            print("* {} --> {}".format(cm_temp1, cm_temp2))
                print()
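The merge step treats two natural-language groups as paraphrases once their template sets share at least two command templates. A toy illustration of that criterion (the template strings are hypothetical):

group_a = {"find Path -name Regex", "ls Path"}
group_b = {"find Path -name Regex", "ls Path", "du -sh Path"}
print(len(group_a & group_b) >= 2)  # -> True: the groups would be merged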
Code Example #3
def add_to_set(nl_data, cm_data, split):
    # Appends parallel representations of each (nl, cm) pair to the
    # module-level accumulator lists (nl_list, cm_list, nl_char_list, ...).
    with_parent = True
    for nl, cm in zip(getattr(nl_data, split), getattr(cm_data, split)):
        ast = data_tools.bash_parser(cm)
        if ast:
            if is_simple(ast):
                # Character- and token-level views of both strings.
                nl_chars = data_tools.char_tokenizer(nl, data_tools.basic_tokenizer,
                                                     normalize_digits=False,
                                                     normalize_long_pattern=False)
                cm_chars = data_tools.char_tokenizer(cm, data_tools.bash_tokenizer,
                                                     normalize_digits=False,
                                                     normalize_long_pattern=False)
                nl_tokens = data_tools.basic_tokenizer(nl)
                cm_tokens = data_tools.ast2tokens(ast, with_parent=with_parent)
                cm_seq = data_tools.ast2list(ast, list=[], with_parent=with_parent)
                # Pruned, argument-normalized, and flag-order-canonicalized views.
                pruned_ast = normalizer.prune_ast(ast)
                cm_pruned_tokens = data_tools.ast2tokens(
                    pruned_ast, loose_constraints=True, with_parent=with_parent)
                cm_pruned_seq = data_tools.ast2list(
                    pruned_ast, list=[], with_parent=with_parent)
                cm_normalized_tokens = data_tools.ast2tokens(
                    ast, loose_constraints=True, arg_type_only=True, with_parent=with_parent)
                cm_normalized_seq = data_tools.ast2list(
                    ast, arg_type_only=True, list=[], with_parent=with_parent)
                cm_canonical_tokens = data_tools.ast2tokens(
                    ast, loose_constraints=True, arg_type_only=True, ignore_flag_order=True,
                    with_parent=with_parent)
                cm_canonical_seq = data_tools.ast2list(
                    ast, arg_type_only=True, ignore_flag_order=True, list=[],
                    with_parent=with_parent)
                getattr(nl_list, split).append(nl)
                getattr(cm_list, split).append(cm)
                getattr(nl_char_list, split).append(nl_chars)
                getattr(nl_token_list, split).append(nl_tokens)
                getattr(cm_char_list, split).append(cm_chars)
                getattr(cm_token_list, split).append(cm_tokens)
                getattr(cm_seq_list, split).append(cm_seq)
                getattr(cm_pruned_token_list, split).append(cm_pruned_tokens)
                getattr(cm_pruned_seq_list, split).append(cm_pruned_seq)
                getattr(cm_normalized_token_list, split).append(cm_normalized_tokens)
                getattr(cm_normalized_seq_list, split).append(cm_normalized_seq)
                getattr(cm_canonical_token_list, split).append(cm_canonical_tokens)
                getattr(cm_canonical_seq_list, split).append(cm_canonical_seq)
            else:
                print("Rare command: " + cm)
Code Example #4
from functools import reduce  # Python 3: reduce is no longer a builtin

def group_data_by_nl(dataset, use_bucket=False, use_nl_temp=True):
    if use_bucket:
        # Flatten the bucketed dataset into a single list.
        dataset = reduce(lambda x, y: x + y, dataset)
    grouped_dataset = {}
    for i in range(len(dataset)):
        nl_str, cm_str, nl, search_history = dataset[i]
        if use_nl_temp:
            # Normalize the nl string into a whitespace-joined token template.
            nl_template = " ".join(data_tools.basic_tokenizer(nl_str))
        else:
            nl_template = nl_str
        if nl_template in grouped_dataset:
            grouped_dataset[nl_template][0].append(nl_str)
            grouped_dataset[nl_template][1].append(cm_str)
            grouped_dataset[nl_template][2].append(nl)
            grouped_dataset[nl_template][3].append(search_history)
        else:
            grouped_dataset[nl_template] = [[nl_str], [cm_str], [nl], [search_history]]

    return grouped_dataset
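A sketch of how the grouped structure might be consumed, assuming data_tools is importable; the two-entry dataset below is hypothetical:

dataset = [("list all files", "ls -a", None, None),
           ("List  all  files", "ls -la", None, None)]
grouped = group_data_by_nl(dataset, use_bucket=False, use_nl_temp=True)
for nl_temp, (nl_strs, cm_strs, nls, histories) in grouped.items():
    # If basic_tokenizer lower-cases and collapses whitespace, both surface
    # variants land under a single template key.
    print(nl_temp, "->", cm_strs)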
Code Example #5
def token_overlap(s1, s2):
    # Jaccard overlap between the stop-word-filtered token sets of s1 and s2.
    tokens1 = set(w for w in basic_tokenizer(s1) if not is_stopword(w))
    tokens2 = set(w for w in basic_tokenizer(s2) if not is_stopword(w))
    union = tokens1 | tokens2
    # Guard against division by zero when both strings are all stop words.
    return len(tokens1 & tokens2) / len(union) if union else 0.0
Code Example #6
    def dump_data(self, data_dir, num_folds=10):
        # Module-level requirements: import collections, os, random.
        # First-pass: group pairs by URLs
        pairs = self.unique_pairs("find")

        # Second-pass: group url clusters by nls
        templates = {}
        urls = list(pairs.keys())  # Python 3: make keys indexable
        print("%d urls in the database" % len(urls))

        merged_urls_by_nl = []
        for i in range(len(urls)):
            url = urls[i]
            merge = False
            for j in range(i + 1, len(urls)):
                url2 = urls[j]
                for nl in pairs[url]:
                    if nl in templates:
                        nl_template1 = templates[nl]
                    else:
                        nl_template1 = " ".join(basic_tokenizer(nl))
                        templates[nl] = nl_template1
                    for nl2 in pairs[url2]:
                        if nl2 in templates:
                            nl_template2 = templates[nl2]
                        else:
                            nl_template2 = " ".join(basic_tokenizer(nl2))
                            templates[nl2] = nl_template2
                        if nl_template1 == nl_template2:
                            merge = True
                            break
                    if merge:
                        break
                if merge:
                    break
            if merge:
                # url2 still points at the cluster whose matching template was
                # found; fold this url's nl->cmds sets into that cluster.
                for nl in pairs[url]:
                    if nl in pairs[url2]:
                        pairs[url2][nl] = pairs[url][nl] | pairs[url2][nl]
                    else:
                        pairs[url2][nl] = pairs[url][nl]
                merged_urls_by_nl.append(i)
        print("%d urls merged by nl" % len(merged_urls_by_nl))

        # Third-pass (currently disabled): group url clusters by commands

        merged_urls_by_cmd = []
        """templates = {}
        for i in xrange(len(urls)):
            if i in merged_urls_by_nl:
                continue
            url = urls[i]
            merge = False
            for j in xrange(i+1, len(urls)):
                if j in merged_urls_by_nl:
                    continue
                url2 = urls[j]
                for _, cmds in pairs[url].items():
                    for cmd in cmds:
                        if cmd in templates:
                            template = templates[cmd]
                        else:
                            template = cmd2template(cmd, arg_type_only=split_by_template)
                            templates[cmd] = template
                        for _, cmd2s in pairs[url2].items():
                            for cmd2 in cmd2s:
                                if cmd2 in templates:
                                    template2 = templates[cmd2]
                                else:
                                    template2 = cmd2template(cmd2, arg_type_only=split_by_template)
                                    templates[cmd2] = template2
                                if template == template2:
                                    merge = True
                                    break
                            if merge:
                                break
                        if merge:
                            break
                    if merge:
                        break
                if merge:
                    break
            if merge:
                for nl in pairs[url]:
                    if nl in pairs[url2]:
                        pairs[url2][nl] = pairs[url][nl] | pairs[url2][nl]
                    else:
                        pairs[url2][nl] = pairs[url][nl]
                merged_urls_by_cmd.append(i)
        print("%d urls merged by cmd" % len(merged_urls_by_cmd))
        """

        remained_urls = []
        for i in range(len(urls)):
            if i in merged_urls_by_cmd:
                continue
            if i in merged_urls_by_nl:
                continue
            remained_urls.append(urls[i])
        # Order the surviving clusters by total number of (nl, cmd) pairs.
        sorted_urls = sorted(remained_urls,
                             key=lambda x: sum(len(pairs[x][nl]) for nl in pairs[x]),
                             reverse=True)

        data = collections.defaultdict(list)

        num_pairs = 0
        num_train = 0
        num_train_pairs = 0
        num_dev = 0
        num_dev_pairs = 0
        num_test = 0
        num_test_pairs = 0
        num_urls = 0

        top_k = 0  # set > 0 to pin the largest top_k clusters to the training folds

        for i in range(len(sorted_urls)):
            url = sorted_urls[i]
            url_size = sum(len(pairs[url][nl]) for nl in pairs[url])
            # print("url %d (%d)" % (i, url_size))
            if i < top_k:
                for nl in pairs[url]:
                    print(nl)
                print("-------------")
                ind = random.randrange(num_folds - 2)
                num_train += 1
                num_train_pairs += url_size
            else:
                ind = random.randrange(num_folds)
                if ind < num_folds - 2:
                    num_train += 1
                    num_train_pairs += url_size
                elif ind == num_folds - 2:
                    num_dev += 1
                    num_dev_pairs += url_size
                elif ind == num_folds - 1:
                    num_test += 1
                    num_test_pairs += url_size
            num_urls += 1

            fold = data[ind]  # renamed from `bin` to avoid shadowing the builtin
            for nl in pairs[url]:
                for cmd in pairs[url][nl]:
                    num_pairs += 1
                    cmd = cmd.strip().replace('\n', ' ').replace('\r', ' ')
                    nl = nl.strip().replace('\n', ' ').replace('\r', ' ')
                    # Python 3: decode any byte strings into unicode.
                    if isinstance(nl, bytes):
                        nl = nl.decode('utf-8')
                    if isinstance(cmd, bytes):
                        cmd = cmd.decode('utf-8')
                    fold.append((nl, cmd))

        print("Total number of pairs: %d" % num_pairs)
        print("Total number of url clusters: %d" % num_urls)
        print("Total number of train clusters: %d (%d pairs)" % (num_train, num_train_pairs))
        print("Total number of dev clusters: %d (%d pairs)" % (num_dev, num_dev_pairs))
        print("Total number of test clusters: %d (%d pairs)" % (num_test, num_test_pairs))
        print("%.2f pairs per url cluster" % ((num_pairs + 0.0) / num_urls))

        # if split_by_template:
        #     split_by = "template"
        # else:
        #     split_by = "command"
        # with open(data_dir + "/data.by.%s.dat" % split_by, 'w') as o_f:
        #     pickle.dump(data, o_f)

        train_nl_list = []
        train_cm_list = []
        dev_nl_list = []
        dev_cm_list = []
        test_nl_list = []
        test_cm_list = []

        # Folds 0..num_folds-3 are train, fold num_folds-2 is dev, and the last
        # fold is test, matching the fold assignment above. (The original used
        # len(data), which undercounts when a fold receives no clusters.)
        for i in range(num_folds):
            if i < num_folds - 2:
                for j in range(len(data[i])):
                    train_nl_list.append(data[i][j][0])
                    train_cm_list.append(data[i][j][1])
            elif i == num_folds - 2:
                for j in range(len(data[i])):
                    dev_nl_list.append(data[i][j][0])
                    dev_cm_list.append(data[i][j][1])
            elif i == num_folds - 1:
                for j in range(len(data[i])):
                    test_nl_list.append(data[i][j][0])
                    test_cm_list.append(data[i][j][1])

        def write_data(data_path, data):
            if not os.path.exists(data_path):
                # Python 3: open with an explicit encoding instead of
                # concatenating encoded bytes with str.
                with open(data_path, 'w', encoding='utf-8') as o_f:
                    for line in data:
                        o_f.write(line + '\n')

        train_path = os.path.join(data_dir, "train")
        dev_path = os.path.join(data_dir, "dev")
        test_path = os.path.join(data_dir, "test")
        write_data(train_path + ".nl", train_nl_list)
        write_data(train_path + ".cm", train_cm_list)
        write_data(dev_path + ".nl", dev_nl_list)
        write_data(dev_path + ".cm", dev_cm_list)
        write_data(test_path + ".nl", test_nl_list)
        write_data(test_path + ".cm", test_cm_list)
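The split above assigns each url cluster to one of num_folds folds, reserving the last two folds for dev and test. A minimal standalone sketch of the same assignment scheme:

import random

def assign_split(num_folds=10):
    # Mirrors dump_data's fold arithmetic: folds 0..num_folds-3 are train,
    # fold num_folds-2 is dev, fold num_folds-1 is test.
    ind = random.randrange(num_folds)
    if ind < num_folds - 2:
        return 'train'
    if ind == num_folds - 2:
        return 'dev'
    return 'test'

print([assign_split() for _ in range(5)])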