예제 #1
0
        # Parse pre-encoded topics from a JSONL file: one record per topic,
        # carrying its id and its dense query vector.
        for line in topic_f:
            info = json.loads(line)
            topic_ids.append(info['id'])
            topic_vectors.append(info['vector'])

    # Bail out early if the searcher failed to initialize.
    if not searcher:
        exit()

    # build output path
    output_path = args.output

    print(f'Running {args.topics} topics, saving to {output_path}...')
    tag = 'HNSW'  # run tag recorded in the output file

    # support trec and msmarco format only for now
    output_writer = get_output_writer(output_path, OutputFormat(args.output_format), max_hits=args.hits, tag=tag)

    # Wall-clock seconds accumulated inside searcher.search (sequential path only).
    search_time = 0
    with output_writer:
        batch_topic_vectors = list()
        batch_topic_ids = list()
        for index, (topic_id, vec) in enumerate(tqdm(zip(topic_ids, topic_vectors))):
            if args.batch_size <= 1 and args.threads <= 1:
                # Sequential path: search one query vector at a time, timing
                # only the search call itself.
                start = time.time()
                hits = searcher.search(vec, args.hits)
                search_time += time.time() - start
                results = [(topic_id, hits)]
            else:
                # Batched path: buffer ids and vectors until a full batch is
                # ready (the flush condition continues beyond this excerpt).
                batch_topic_ids.append(str(topic_id))
                batch_topic_vectors.append(vec)
                if (index + 1) % args.batch_size == 0 or \
예제 #2
0
            # Classifier-fusion ('prcl') run: encode the hyperparameters
            # (n, alpha, classifier rankers, vectorizer) into the filename.
            n_str = f'prcl.n_{args.n}'
            a_str = f'prcl.alpha_{args.alpha}'
            clf_str = 'prcl_' + '+'.join(clf_rankers)
            tokens1 = ['run', args.topics, '+'.join(search_rankers)]
            tokens2 = [args.vectorizer, clf_str, r_str, n_str, a_str]
            output_path = '.'.join(tokens1) + '-' + '-'.join(tokens2) + ".txt"
        else:
            # Plain search run: filename of the form run.<topics>.<rankers>.txt
            tokens = ['run', args.topics, '+'.join(search_rankers), 'txt']
            output_path = '.'.join(tokens)

    print(f'Running {args.topics} topics, saving to {output_path}...')
    # Tag is derived from the generated filename (minus the 4-char '.txt')
    # when no explicit output was given; otherwise the default 'Anserini'.
    tag = output_path[:-4] if args.output is None else 'Anserini'

    output_writer = get_output_writer(
        output_path,
        OutputFormat(args.output_format),
        'w',
        max_hits=args.hits,
        tag=tag,
        topics=topics,
        use_max_passage=args.max_passage,
        max_passage_delimiter=args.max_passage_delimiter,
        max_passage_hits=args.max_passage_hits)

    with output_writer:
        batch_topics = list()
        batch_topic_ids = list()
        for index, (topic_id, text) in enumerate(
                tqdm(query_iterator, total=len(topics.keys()))):
            # Optionally pre-tokenize the query text before searching.
            # NOTE(review): `!= None` should be `is not None` per PEP 8 —
            # behavior is the same here, flagged for a future cleanup.
            if (args.tokenizer != None):
                toks = tokenizer.tokenize(text)
예제 #3
0
    # Abort if the sparse searcher could not be created.
    if not ssearcher:
        exit()

    # Apply the user-supplied BM25 parameters to the sparse index.
    set_bm25_parameters(ssearcher, args.sparse.index, args.sparse.k1, args.sparse.b)

    # Combine the dense and sparse searchers into one hybrid searcher.
    hsearcher = HybridSearcher(dsearcher, ssearcher)
    if not hsearcher:
        exit()

    # build output path
    output_path = args.run.output

    print(f'Running {args.run.topics} topics, saving to {output_path}...')
    tag = 'hybrid'  # run tag recorded in the output file

    output_writer = get_output_writer(output_path, OutputFormat(args.run.output_format), 'w',
                                      max_hits=args.run.hits, tag=tag, topics=topics,
                                      use_max_passage=args.run.max_passage,
                                      max_passage_delimiter=args.run.max_passage_delimiter,
                                      max_passage_hits=args.run.max_passage_hits)

    with output_writer:
        batch_topics = list()
        batch_topic_ids = list()
        for index, (topic_id, text) in enumerate(tqdm(query_iterator, total=len(topics.keys()))):
            if args.run.batch_size <= 1 and args.run.threads <= 1:
                # Sequential path: one hybrid query at a time; alpha is
                # presumably the dense/sparse score-interpolation weight —
                # confirm against HybridSearcher.search.
                hits = hsearcher.search(text, args.run.hits, args.fusion.alpha)
                results = [(topic_id, hits)]
            else:
                # Batched path: accumulate ids and query texts for a batch
                # search (the flush logic continues beyond this excerpt).
                batch_topic_ids.append(str(topic_id))
                batch_topics.append(text)