Exemplo n.º 1
0
 def test_jaccard(self):
     """Jaccard query returns every indexed set at or above the threshold."""
     candidate_sets = [[1, 2, 3], [3, 4, 5], [2, 3, 4], [5, 6, 7]]
     index = SearchIndex(
         candidate_sets,
         similarity_func_name="jaccard",
         similarity_threshold=0.1,
     )
     expected = {(0, 0.2), (1, 1.0), (2, 0.5), (3, 0.2)}
     self.assertEqual(set(index.query([3, 5, 4])), expected)
Exemplo n.º 2
0
def cross_collection_all_pairs(indexed_set_IDs, query_set_IDs, indexed_sets,
                               query_sets, similarity_func_name,
                               similarity_threshold):
    """Yield all cross-collection set pairs with similarity >= threshold.

    Builds a ``SearchIndex`` over ``indexed_sets`` and probes it with every
    query set, yielding tuples of
    ``(query_ID, indexed_ID, len(query_set), len(indexed_set), similarity)``.

    Per-query timing statistics are logged, but only after the generator has
    been fully consumed (the logging lines run when iteration finishes).
    """
    logging.info("Building search index on {} sets.".format(len(indexed_sets)))
    index = SearchIndex(indexed_sets,
                        similarity_func_name=similarity_func_name,
                        similarity_threshold=similarity_threshold)
    logging.info("Finished building search index.")
    logging.info(
        "Find pairs with similarity >= {}.".format(similarity_threshold))
    count = 0
    query_times = deque([])
    for set_ID, s in zip(query_set_IDs, query_sets):
        start = time.time()
        results = []
        for i, sim in index.query(s):
            results.append((set_ID, indexed_set_IDs[i], len(s),
                            len(indexed_sets[i]), sim))
            count += 1
        # Only the index probe is timed; yielding happens outside the timer.
        query_time = time.time() - start
        query_times.append(query_time)
        for result in results:
            yield result
    logging.info("Found {} pairs.".format(count))
    # Compute percentiles
    query_times = np.array(list(query_times))
    logging.info("Average query time: {}.".format(np.mean(query_times)))
    # Bug fix: the original logged np.mean here under the "Median" label.
    logging.info("Median query time: {}.".format(np.median(query_times)))
    logging.info("90pct query time: {}.".format(np.percentile(query_times,
                                                              90)))
def process(opt):
    """Fuzzy-match every source line against the TM lines and write scores.

    Reads source lines (``opt.src``) and translation-memory lines
    (``opt.tms``), restores the pickled IDF dictionary (``opt.idf_dict``),
    builds a containment-based search index over the TM token sets, scores
    the fuzzy candidates of each source line with a RoBERTa model, and
    writes one output line per source line to ``opt.output`` in the format
    ``idx score ||| idx score ...`` (empty line when there is no candidate).
    """
    print('Reading in source lines...')
    with open(opt.src, 'r', encoding='utf-8') as f:
        src_lines = [l.strip() for l in f]

    print('Reading in tms lines...')
    with open(opt.tms, 'r', encoding='utf-8') as f:
        tms_lines = [l.strip() for l in f]

    print('Reading in idf_dict cache...')
    with open(opt.idf_dict, 'rb') as f:
        non_default_idf_dict = pickle.load(f)
    # Bug fix: defaultdict calls its factory with NO arguments, so the
    # original ``lambda x: ...`` raised TypeError on any missing key (and
    # the ``pop`` inside it would have mutated the cache on every miss).
    # Capture the default once and use a zero-argument factory instead.
    default_idf = non_default_idf_dict.pop('default_value')
    idf_dict = collections.defaultdict(lambda: default_idf)
    idf_dict.update(non_default_idf_dict)

    print('Building search index...')
    search_index = SearchIndex([set(l.strip().split()) for l in tms_lines],
                               similarity_threshold=opt.sss_lambda,
                               similarity_func_name='containment_min')

    print('Building tokenizer...')
    tokenizer = AutoTokenizer.from_pretrained('roberta-large')

    print('Fetching model...')
    model = get_model('roberta-large', 17, False)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)

    res = {}
    print('Starting matching ...')
    # enumerate replaces the original hand-maintained ``i`` counter.
    for i, src_line in enumerate(tqdm(src_lines, mininterval=1.0, ncols=50)):
        cand_is = match_fuzzy(src_line, search_index, opt)
        if len(cand_is) == 0:
            res[i] = []
        else:
            res[i] = calc_bert_score(src_line,
                                     tms_lines,
                                     cand_is,
                                     model,
                                     tokenizer,
                                     opt,
                                     idf_dict=idf_dict)

    print(f'Writing match file to {opt.output}...')
    # Context manager guarantees the file is closed even if a write fails.
    with open(opt.output, 'w') as fw:
        for src_i in sorted(res):
            match_line = ' ||| '.join([f'{i} {v}' for v, i in res[src_i]])
            fw.write(f'{match_line}\n')
Exemplo n.º 4
0
 def test_containment(self):
     """Containment query results must honor the configured threshold."""
     candidate_sets = [[1,2,3], [3,4,5], [2,3,4], [5,6,7]]
     cases = [
         (0.1, {(1, 1.0), (0, 1.0/3.0), (2, 2.0/3.0), (3, 1.0/3.0)}),
         (0.5, {(1, 1.0), (2, 2.0/3.0)}),
     ]
     for threshold, expected in cases:
         index = SearchIndex(candidate_sets,
                 similarity_func_name="containment",
                 similarity_threshold=threshold)
         self.assertEqual(set(index.query([3,5,4])), expected)
Exemplo n.º 5
0
def search_jaccard_topk(index_data, query_data, k):
    """Run a top-k Jaccard query for every query set against the index.

    Returns ``(results, times)`` where ``results`` is a list of
    ``(query_key, [[index_key, similarity], ...])`` pairs and ``times``
    holds the per-query wall-clock durations.
    """
    index_sets, index_keys = index_data
    query_sets, query_keys = query_data
    print("Building jaccard search index.")
    build_start = time.perf_counter()
    # Threshold 0 forces the index to keep every token.
    index = SearchIndex(index_sets, similarity_func_name="jaccard",
            similarity_threshold=0.0)
    print("Finished building index in {:.3f}.".format(
        time.perf_counter() - build_start))
    times = []
    results = []
    for query_set, query_key in zip(query_sets, query_keys):
        query_start = time.perf_counter()
        result = []
        for i, similarity in _query_jaccard_topk(index, query_set, k):
            result.append([index_keys[i], similarity])
        times.append(time.perf_counter() - query_start)
        results.append((query_key, result))
        sys.stdout.write("\rQueried {} sets.".format(len(results)))
    sys.stdout.write("\n")
    return (results, times)
Exemplo n.º 6
0
    def test_containment_max(self):
        """containment_max is measured against the larger of the two sets."""
        cases = [
            # Query smaller than the indexed sets.
            ([[1,2,3], [3,4,5], [2,3,4], [5,6,7]],
             [1,2],
             {(0, 1.0), (2, 0.5)}),
            # Query larger than the indexed sets.
            ([[1,2], [3,4], [2,3,4,5], [6,7], [1,6,7]],
             [1,2,3,4],
             {(0, 1.0), (1, 1.0), (2, 0.75)}),
        ]
        for indexed_sets, query, expected in cases:
            index = SearchIndex(
                indexed_sets,
                similarity_func_name="containment_max",
                similarity_threshold=0.4
            )
            self.assertEqual(set(index.query(query)), expected)
Exemplo n.º 7
0
    if os.path.exists(query_data_cache):
        print("Using cached query sets {}".format(query_data_cache))
        with open(query_data_cache, "rb") as d:
            query_data = pickle.load(d)
    else:
        print("Using query sets {}".format(args.query_sets))
        query_data = bootstrap_sets(args.query_sets, 1.0, num_perms, skip=0)
        with open(query_data_cache, "wb") as d:
            pickle.dump(query_data, d)

    if not args.skip_ground_truth:
        rows = []
        # Build search index separately, only works for containment.
        print("Building search index...")
        index = SearchIndex(index_data[1],
                            similarity_func_name="containment",
                            similarity_threshold=0.1)
        for threshold in thresholds:
            index.similarity_threshold = threshold
            print("Running ground truth benchmark threshold = {}".format(
                threshold))
            ground_truth_results, ground_truth_times = \
                    benchmark_ground_truth(threshold, index, query_data)
            for t, r, query_set, query_key in zip(ground_truth_times,
                                                  ground_truth_results,
                                                  query_data[1],
                                                  query_data[2]):
                rows.append((query_key, len(query_set), threshold, t,
                             ",".join(str(k) for k in r)))
        df_groundtruth = pd.DataFrame.from_records(rows,
                                                   columns=[