Example #1
    def test_tfidf(self):
        env = pyndri.TFIDFQueryEnvironment(self.index)

        self.assertEqual(env.query('ipsum'), ((1, 0.7098885466183784), ))

        self.assertEqual(env.query('his'),
                         ((2, 0.16955104430709383), (3, 0.07757942488345955)))
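The assertions above presuppose a small fixture index built in the test's setUp. A minimal sketch of such a setUp is shown below; the test-class name and the index path are assumptions for illustration, not part of the original test.

import unittest

import pyndri


class QueryEnvironmentTest(unittest.TestCase):
    def setUp(self):
        # Hypothetical fixture: a tiny pre-built index whose documents
        # contain the query terms used above ('ipsum' and 'his').
        self.index = pyndri.Index('test-index/')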
Example #2
    def __init__(self, env: str = 'default', verbose: bool = False, avg_len: bool = False):
        if verbose:
            helpers.log(f'Loading index {INDRI_INDEX_DIR} with {env} query environment.')
        start = datetime.now()

        self.index = pyndri.Index(f'{INDRI_INDEX_DIR}')
        self.token2id, self.id2token, self.id2df = self.index.get_dictionary()
        self.id2tf = self.index.get_term_frequencies()

        if avg_len:
            # Average document length, computed over all documents in the index:
            doc_lengths = np.empty(self.index.document_count(), dtype=np.float64)
            for (idx, doc_iid) in enumerate(range(self.index.document_base(), self.index.maximum_document())):
                doc_lengths[idx] = self.index.document_length(doc_iid)
            self.avg_doc_len = float(doc_lengths.mean())

        self.tokenizer = Tokenizer()

        if os.path.isfile(TITLE2WID):
            with open(TITLE2WID, 'rb') as file:
                self.title2wid = pickle.load(file)

        if os.path.isfile(WID2TITLE):
            with open(WID2TITLE, 'rb') as file:
                self.wid2title = pickle.load(file)
        try:
            if os.path.isfile(WID2INT):
                with open(WID2INT, 'rb') as file:
                    self.wid2int = pickle.load(file)

            if os.path.isfile(INT2WID):
                with open(INT2WID, 'rb') as file:
                    self.int2wid = pickle.load(file)
        except FileNotFoundError:
            helpers.log('ID mappings do not exist yet. Not loaded.')

        if env == 'default':
            self.env = pyndri.QueryEnvironment(self.index)
        elif env == 'tfidf':
            self.env = pyndri.TFIDFQueryEnvironment(self.index, k1=1.2, b=0.75)
        elif env == 'prf':
            base_env = pyndri.QueryEnvironment(self.index)
            self.env = pyndri.PRFQueryEnvironment(base_env, fb_docs=10, fb_terms=10)
        else:
            raise ValueError(f'Unknown environment configuration {env}')

        stop = datetime.now()
        if verbose:
            helpers.log(f'Loaded index in {stop - start}.')
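A rough usage sketch for the wrapper above; the class name Index and the query string are assumptions for illustration, while env.query() and ext_document_id() are used exactly as in the other examples on this page.

    # Hedged usage sketch; the wrapper's class name and the query are assumed.
    searcher = Index(env='tfidf', verbose=True)
    for int_doc_id, score in searcher.env.query('sample query', results_requested=10):
        # Map the Indri-internal document id back to its external identifier.
        print(searcher.index.ext_document_id(int_doc_id), score)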
Example #3
    print('export {}LENGTH_MEAN={}'.format(prefix, mean))
    print('export {}LENGTH_MIN={}'.format(prefix, min_))
    print('export {}LENGTH_MAX={}'.format(prefix, max_))
    print('export {}LENGTH_STD={}'.format(prefix, std))
    print('export {}TOTAL_TERMS={}'.format(prefix, index.total_terms()))
    print('export {}UNIQUE_TERMS={}'.format(prefix, index.unique_terms()))

    with pyndri.open(sys.argv[1]) as index:
        # Constructs a QueryEnvironment that uses a
        # language model with Dirichlet smoothing.
        lm_query_env = pyndri.QueryEnvironment(
            index, rules=('method:dirichlet,mu:5000', ))
        print(
            lm_query_env.query('hello world',
                               results_requested=5,
                               include_snippets=True))

        # Constructs a QueryEnvironment that uses the TF-IDF retrieval model.
        #
        # See "Baseline (non-LM) retrieval"
        # (https://lemurproject.org/doxygen/lemur/html/IndriRunQuery.html)
        tfidf_query_env = pyndri.TFIDFQueryEnvironment(index)
        print(tfidf_query_env.query('hello world'))

        # Constructs a QueryEnvironment that uses the Okapi BM25 retrieval model.
        #
        # See "Baseline (non-LM) retrieval"
        # (https://lemurproject.org/doxygen/lemur/html/IndriRunQuery.html)
        bm25_query_env = pyndri.OkapiQueryEnvironment(index)
        print(bm25_query_env.query('hello world'))
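Each query() call above returns (internal document id, score) tuples. A hedged follow-up, not part of the original example, showing how the external document identifiers could be recovered; ext_document_id() follows its use in Example #4 below and is assumed to be available on the index opened with pyndri.open().

        # Hedged follow-up: print external ids for the BM25 results above.
        for int_doc_id, score in bm25_query_env.query('hello world'):
            print(index.ext_document_id(int_doc_id), score)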
Example #4
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('--loglevel', type=str, default='INFO')

    parser.add_argument('--num_workers',
                        type=argparse_utils.positive_int, default=16)

    parser.add_argument('--topics', nargs='+',
                        type=argparse_utils.existing_file_path)

    parser.add_argument('model', type=argparse_utils.existing_file_path)

    parser.add_argument('--index', required=True)

    parser.add_argument('--linear', action='store_true', default=False)
    parser.add_argument('--self_information',
                        action='store_true',
                        default=False)
    parser.add_argument('--l2norm_phrase', action='store_true', default=False)

    parser.add_argument('--bias_coefficient',
                        type=argparse_utils.ratio,
                        default=0.0)

    parser.add_argument('--rerank_exact_matching_documents',
                        action='store_true',
                        default=False)

    parser.add_argument('--strict', action='store_true', default=False)

    parser.add_argument('--top_k', default=None)

    parser.add_argument('--num_queries',
                        type=argparse_utils.positive_int,
                        default=None)

    parser.add_argument('run_out')

    args = parser.parse_args()

    args.index = pyndri.Index(args.index)

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    if not args.top_k:
        args.top_k = 1000
    elif args.top_k == 'all':
        args.top_k = \
            args.index.maximum_document() - args.index.document_base()
    elif args.top_k.isdigit():
        args.top_k = int(args.top_k)
    elif all(map(os.path.exists, args.top_k.split())):
        topics_and_documents = {}

        for qrel_path in args.top_k.split():
            with open(qrel_path, 'r') as f_qrel:
                for topic_id, judgments in trec_utils.parse_qrel(f_qrel):
                    if topic_id not in topics_and_documents:
                        topics_and_documents[topic_id] = set()

                    for doc_id, _ in judgments:
                        topics_and_documents[topic_id].add(doc_id)

        args.top_k = topics_and_documents
    else:
        raise RuntimeError(
            'Unable to interpret --top_k value: {}'.format(args.top_k))

    logging.info('Loading dictionary.')
    dictionary = pyndri.extract_dictionary(args.index)

    logging.info('Loading model.')
    model_base, epoch_and_ext = args.model.rsplit('_', 1)
    epoch = int(epoch_and_ext.split('.')[0])

    if not os.path.exists('{}_meta'.format(model_base)):
        model_meta_base, batch_idx = model_base.rsplit('_', 1)
    else:
        model_meta_base = model_base

    kwargs = {
        'strict': args.strict,
    }

    if args.self_information:
        kwargs['self_information'] = True

    if args.linear:
        kwargs['bias_coefficient'] = args.bias_coefficient
        kwargs['nonlinearity'] = None

    if args.l2norm_phrase:
        kwargs['l2norm_phrase'] = True

    model = nvsm.load_model(
        nvsm.load_meta(model_meta_base),
        model_base, epoch, **kwargs)

    for topic_path in args.topics:
        run_out_path = '{}-{}'.format(
            args.run_out, os.path.basename(topic_path))

        if os.path.exists(run_out_path):
            logging.warning('Run for topics %s already exists (%s); skipping.',
                            topic_path, run_out_path)

            continue

        queries = list(pyndri.utils.parse_queries(
            args.index, dictionary, topic_path,
            strict=args.strict,
            num_queries=args.num_queries))

        if args.rerank_exact_matching_documents:
            assert not isinstance(args.top_k, dict)

            topics_and_documents = {}

            query_env = pyndri.TFIDFQueryEnvironment(args.index)

            for topic_id, topic_token_ids in queries:
                topics_and_documents[topic_id] = set()

                query_str = ' '.join(
                    dictionary[term_id] for term_id in topic_token_ids
                    if term_id is not None)

                for int_doc_id, score in query_env.query(
                        query_str, results_requested=1000):
                    topics_and_documents[topic_id].add(
                        args.index.ext_document_id(int_doc_id))

            args.top_k = topics_and_documents

        run = trec_utils.OnlineTRECRun(
            'cuNVSM', rank_cutoff=(
                args.top_k if isinstance(args.top_k, int)
                else sys.maxsize))

        rank_fn = RankFn(
            args.num_workers,
            args=args, model=model)

        for idx, (topic_id, topic_data) in enumerate(rank_fn(queries)):
            if topic_data is None:
                continue

            logging.info('Query %s (%d/%d)', topic_id, idx + 1, len(queries))

            (topic_repr,
             topic_scores_and_documents) = topic_data

            run.add_ranking(topic_id, topic_scores_and_documents)

            del topic_scores_and_documents

        run.close_and_write(run_out_path, overwrite=False)

        logging.info('Run written to %s.', run_out_path)

    del rank_fn
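The example ends without an entry point. A conventional guard, with a hypothetical invocation in a comment, is sketched below; the script and file names are assumptions, the flags come from the parser in main(), and sys is already imported for sys.maxsize above.

if __name__ == '__main__':
    # Hypothetical invocation (names are assumptions, flags are from main()):
    #   python rerank_nvsm.py model_epoch_12.bin --index /path/to/indri-index \
    #       --topics topics.robust04.txt --top_k 1000 runs/nvsm
    sys.exit(main())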