Example #1
    def main(self, args):
        dp = DatasetPaths()
        if self.dataset_type == 'oc':
            corpus = Corpus.load_pkl(dp.get_pkl_path(self.dataset_type))
        else:
            corpus = Corpus.load(dp.get_db_path(self.dataset_type))

        authors = Counter()
        key_phrases = Counter()
        years = Counter()
        venues = Counter()
        num_docs_with_kp = 0

        in_citations_counts = []
        out_citations_counts = []
        for doc in corpus:
            authors.update(doc.authors)
            key_phrases.update(doc.key_phrases)
            if len(doc.key_phrases) > 0:
                num_docs_with_kp += 1
            in_citations_counts.append(doc.in_citation_count)
            out_citations_counts.append(doc.out_citation_count)
            years.update([doc.year])
            venues.update([doc.venue])

        training_years = [corpus[doc_id].year for doc_id in corpus.train_ids]
        validation_years = [corpus[doc_id].year for doc_id in corpus.valid_ids]
        testing_years = [corpus[doc_id].year for doc_id in corpus.test_ids]

        print("No. of documents = {}".format(len(corpus)))
        print("Unique number of authors = {}".format(len(authors)))
        print("Unique number of key phrases = {}".format(len(key_phrases)))
        print("Unique number of venues = {}".format(len(venues)))
        print("No. of docs with key phrases = {}".format(num_docs_with_kp))
        print("Average in citations = {} (+/- {})".format(np.mean(in_citations_counts),
                                                          np.std(in_citations_counts)))
        print("Average out citations = {} (+/- {})".format(np.mean(out_citations_counts),
                                                           np.std(out_citations_counts)))
        print("No. of training examples = {} ({} to {})".format(len(corpus.train_ids),
                                                                np.min(training_years),
                                                                np.max(training_years)))
        print("No. of validation examples = {} ({} to {})".format(len(corpus.valid_ids),
                                                                  np.min(validation_years),
                                                                  np.max(validation_years)))
        print("No. of testing examples = {} ({} to {})".format(len(corpus.test_ids),
                                                               np.min(testing_years),
                                                               np.max(testing_years)))
        print(authors.most_common(10))
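For reference, the Counter-based aggregation above reduces to this minimal, self-contained pattern (toy data, not from the corpus):

from collections import Counter

docs = [
    {'authors': ['A. Smith', 'B. Jones'], 'year': 2015},
    {'authors': ['A. Smith'], 'year': 2016},
]

authors = Counter()
years = Counter()
for doc in docs:
    authors.update(doc['authors'])  # counts each author occurrence
    years.update([doc['year']])     # a single value must be wrapped in a list

print(authors.most_common(1))  # [('A. Smith', 2)]
print(years)                   # Counter({2015: 1, 2016: 1})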
Example #2
    def main(self, args):
        dp = DatasetPaths()

        corpus_json = dp.get_json_path(self.dataset_name)
        index_location = dp.get_bm25_index_path(self.dataset_name)

        assert not os.path.exists(index_location), \
            "BM25 index already exists: {}".format(index_location)
        os.mkdir(index_location)

        bm25_index = create_in(index_location, schema)
        writer = bm25_index.writer()

        for doc in tqdm.tqdm(file_util.read_json_lines(corpus_json)):
            writer.add_document(id=doc['id'],
                                title=doc['title'],
                                abstract=doc['abstract'])

        writer.commit()
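The snippet references a module-level schema that is not shown; whoosh requires one before create_in can build the index. A plausible definition consistent with the add_document call above (an assumption, not the project's confirmed schema):

from whoosh.fields import Schema, ID, TEXT

# Assumed schema: a stored unique id plus searchable title/abstract text,
# matching the three fields passed to add_document above.
schema = Schema(
    id=ID(stored=True, unique=True),
    title=TEXT(stored=True),
    abstract=TEXT(stored=True),
)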
Example #3
def model_from_directory(dirname: str, on_cpu=False) -> Tuple[Featurizer, Any]:
    dp = DatasetPaths()

    options_json = file_util.read_json(
        os.path.join(dirname, dp.OPTIONS_FILENAME))
    options = ModelOptions(**json.loads(options_json))

    featurizer_file_prefix = 'pretrained_' if options.use_pretrained else 'corpus_fit_'

    featurizer = file_util.read_pickle(
        os.path.join(dirname, featurizer_file_prefix +
                     dp.FEATURIZER_FILENAME))  # type: Featurizer

    options.n_authors = featurizer.n_authors
    options.n_features = featurizer.n_features
    options.n_venues = featurizer.n_venues
    options.n_keyphrases = featurizer.n_keyphrases
    create_model = import_from('citeomatic.models.%s' % options.model_name,
                               'create_model')
    if on_cpu:
        with tf.device('/cpu:0'):
            models = create_model(options)
    else:
        models = create_model(options)

    print("Loading model from %s " % dirname)
    print(models['citeomatic'].summary())
    if dirname.startswith('s3://'):
        models['citeomatic'].load_weights(
            file_util.cache_file(
                os.path.join(dirname, dp.CITEOMATIC_WEIGHTS_FILENAME)))
        models['embedding'].load_weights(
            file_util.cache_file(
                os.path.join(dirname, dp.EMBEDDING_WEIGHTS_FILENAME)))
    else:
        models['citeomatic'].load_weights(
            os.path.join(dirname, dp.CITEOMATIC_WEIGHTS_FILENAME))
        if models['embedding'] is not None:
            models['embedding'].load_weights(
                os.path.join(dirname, dp.EMBEDDING_WEIGHTS_FILENAME))
    return featurizer, models
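A hypothetical call site for model_from_directory (the directory path is an assumption):

# Load a trained model pair; on_cpu=True builds the graph on the CPU device.
featurizer, models = model_from_directory('/path/to/model_dir', on_cpu=True)
citeomatic_model = models['citeomatic']  # ranking network
embedding_model = models['embedding']    # paper-embedding network (may be None)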
Example #4
    def main(self, args):
        logging.info("Reading Open Corpus file from: {}".format(
            self.input_path))
        logging.info("Writing json file to: {}".format(self.output_path))

        dp = DatasetPaths()

        assert os.path.exists(self.input_path)
        assert not os.path.exists(self.output_path)
        assert not os.path.exists(dp.get_pkl_path('oc'))

        with open(self.output_path, 'w') as f:
            for obj in tqdm.tqdm(file_util.read_json_lines(self.input_path)):
                if 'year' not in obj:
                    continue
                translated_obj = {
                    FieldNames.PAPER_ID:
                    obj['id'],
                    FieldNames.TITLE_RAW:
                    obj['title'],
                    FieldNames.ABSTRACT_RAW:
                    obj['paperAbstract'],
                    FieldNames.AUTHORS: [a['name'] for a in obj['authors']],
                    FieldNames.IN_CITATION_COUNT:
                    len(obj['inCitations']),
                    FieldNames.KEY_PHRASES:
                    obj['keyPhrases'],
                    FieldNames.OUT_CITATIONS:
                    obj['outCitations'],
                    FieldNames.URLS:
                    obj['pdfUrls'],
                    FieldNames.S2_URL:
                    obj['s2Url'],
                    FieldNames.VENUE:
                    obj['venue'],
                    FieldNames.YEAR:
                    obj['year'],
                    FieldNames.TITLE:
                    ' '.join(global_tokenizer(obj['title'])),
                    FieldNames.ABSTRACT:
                    ' '.join(global_tokenizer(obj['paperAbstract']))
                }
                f.write(json.dumps(translated_obj))
                f.write("\n")
        oc_corpus = Corpus.build(dp.get_db_path('oc'), dp.get_json_path('oc'))
        with open(dp.get_pkl_path('oc'), 'wb') as pkl_file:
            pickle.dump(oc_corpus, pkl_file)
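Reading the pickled corpus back then mirrors Example #1 (a short sketch using the same DatasetPaths helper):

dp = DatasetPaths()
corpus = Corpus.load_pkl(dp.get_pkl_path('oc'))
print(len(corpus))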
Example #5
from citeomatic.common import DatasetPaths, FieldNames, global_tokenizer
from citeomatic.corpus import Corpus
import os
import json
import logging
import pickle

import tqdm

from citeomatic import file_util

input_path = '/data/split_opencorpus/1.json'
output_path = '/data/temp/1.json'
output_pkl_path = '/data/temp/1.pkl'

logging.info("Reading Open Corpus file from: {}".format(input_path))
logging.info("Writing json file to: {}".format(output_path))

dp = DatasetPaths()

assert os.path.exists(input_path)
assert not os.path.exists(output_path)
assert not os.path.exists(output_pkl_path)

with open(output_path, 'w') as f:
    for obj in tqdm.tqdm(file_util.read_json_lines(input_path)):
        if 'year' not in obj:
            continue
        translated_obj = {
            FieldNames.PAPER_ID: obj['id'],
            FieldNames.TITLE_RAW: obj['title'],
            FieldNames.ABSTRACT_RAW: obj['paperAbstract'],
            FieldNames.AUTHORS: [a['name'] for a in obj['authors']],
            FieldNames.IN_CITATION_COUNT: len(obj['inCitations']),
            FieldNames.KEY_PHRASES: obj['keyPhrases'],
            FieldNames.OUT_CITATIONS: obj['outCitations'],
            FieldNames.URLS: obj['pdfUrls'],
            FieldNames.S2_URL: obj['s2Url'],
            FieldNames.VENUE: obj['venue'],
            FieldNames.YEAR: obj['year'],
            FieldNames.TITLE: ' '.join(global_tokenizer(obj['title'])),
            FieldNames.ABSTRACT: ' '.join(global_tokenizer(obj['paperAbstract']))
        }
        f.write(json.dumps(translated_obj))
        f.write("\n")
Example #6
from citeomatic.common import DatasetPaths, FieldNames, global_tokenizer
from citeomatic.corpus import Corpus
import os
import json
import logging
import pickle

import tqdm

from citeomatic import file_util

input_path = '/data/split_opencorpus/1.json'
output_path = '/data/split_stripped_opencorpus/1.json'
output_pkl_path = '/data/split_stripped_opencorpus/1.pkl'

logging.info("Reading Open Corpus file from: {}".format(input_path))
logging.info("Writing json file to: {}".format(output_path))

dp = DatasetPaths()

assert os.path.exists(input_path)
assert not os.path.exists(output_path)
assert not os.path.exists(output_pkl_path)

with open(output_path, 'w') as f:
    for obj in tqdm.tqdm(file_util.read_json_lines(input_path)):
        if 'year' not in obj:
            continue
        translated_obj = {
            FieldNames.PAPER_ID:
            obj['id'],
            FieldNames.TITLE_RAW:
            obj['title'],
            FieldNames.ABSTRACT_RAW:
            obj['paperAbstract'],
            FieldNames.AUTHORS: [a['name'] for a in obj['authors']],
            FieldNames.IN_CITATION_COUNT:
            len(obj['inCitations']),
            FieldNames.KEY_PHRASES:
            obj['keyPhrases'],
            FieldNames.OUT_CITATIONS:
            obj['outCitations'],
            FieldNames.URLS:
            obj['pdfUrls'],
            FieldNames.S2_URL:
            obj['s2Url'],
            FieldNames.VENUE:
            obj['venue'],
            FieldNames.YEAR:
            obj['year'],
            FieldNames.TITLE:
            ' '.join(global_tokenizer(obj['title'])),
            FieldNames.ABSTRACT:
            ' '.join(global_tokenizer(obj['paperAbstract']))
        }
        f.write(json.dumps(translated_obj))
        f.write("\n")
Example #7
    def main(self, args):

        if self.dataset_name == 'dblp':
            input_path = DatasetPaths.DBLP_GOLD_DIR
            output_path = DatasetPaths.DBLP_CORPUS_JSON
        elif self.dataset_name == 'pubmed':
            input_path = DatasetPaths.PUBMED_GOLD_DIR
            output_path = DatasetPaths.PUBMED_CORPUS_JSON
        else:
            assert False

        logging.info("Reading Gold data from {}".format(input_path))
        logging.info("Writing corpus to {}".format(output_path))
        assert os.path.exists(input_path)
        assert not os.path.exists(output_path)

        papers_file = os.path.join(input_path, "papers.txt")
        abstracts_file = os.path.join(input_path, "abstracts.txt")
        keyphrases_file = os.path.join(input_path, "paper_keyphrases.txt")
        citations_file = os.path.join(input_path, "paper_paper.txt")
        authors_file = os.path.join(input_path, "paper_author.txt")

        venues_file = os.path.join(input_path, "paper_venue.txt")

        paper_titles = {}
        paper_years = {}
        paper_abstracts = {}
        paper_keyphrases = {}
        paper_citations = {}
        paper_in_citations = {}
        paper_authors = {}
        paper_venues = {}

        bad_ids = set()
        for line in file_util.read_lines(abstracts_file):
            parts = line.split("\t")
            paper_id = int(parts[0])
            if len(parts) == 2:
                paper_abstracts[paper_id] = parts[1]
            else:
                paper_abstracts[paper_id] = ""

            if paper_abstracts[paper_id] == "":
                bad_ids.add(paper_id)

        for line in file_util.read_lines(papers_file):
            parts = line.split('\t')
            paper_id = int(parts[0])
            paper_years[paper_id] = int(parts[2])
            paper_titles[paper_id] = parts[3]

        for line in file_util.read_lines(keyphrases_file):
            parts = line.split("\t")
            paper_id = int(parts[0])
            if paper_id not in paper_keyphrases:
                paper_keyphrases[paper_id] = []

            for kp in parts[1:]:
                kp = kp.strip()
                if len(kp) > 0:
                    paper_keyphrases[paper_id].append(kp[:-4])

        for line in file_util.read_lines(citations_file):
            parts = line.split("\t")
            paper_id = int(parts[0])
            if paper_id not in paper_citations:
                paper_citations[paper_id] = []
            c = int(parts[1])
            if c in bad_ids:
                continue
            paper_citations[paper_id].append(str(c))

            if c not in paper_in_citations:
                paper_in_citations[c] = []
            if paper_id not in paper_in_citations:
                paper_in_citations[paper_id] = []

            paper_in_citations[c].append(paper_id)

        for line in file_util.read_lines(authors_file):
            parts = line.split("\t")
            paper_id = int(parts[0])
            if paper_id not in paper_authors:
                paper_authors[paper_id] = []

            paper_authors[paper_id].append(parts[1])

        for line in file_util.read_lines(venues_file):
            parts = line.split("\t")
            paper_id = int(parts[0])
            paper_venues[paper_id] = parts[1]

        test_paper_id = 13
        print("==== Test Paper Details ====")
        print(paper_titles[test_paper_id])
        print(paper_years[test_paper_id])
        print(paper_abstracts[test_paper_id])
        print(paper_keyphrases[test_paper_id])
        print(paper_citations[test_paper_id])
        print(paper_in_citations[test_paper_id])
        print(paper_authors[test_paper_id])
        print(paper_venues[test_paper_id])
        print("==== Test Paper Details ====")

        def _print_len(x, name=''):
            print("No. of {} = {}".format(name, len(x)))

        _print_len(paper_titles, 'Titles')
        _print_len(paper_years, 'Years')
        _print_len(paper_abstracts, 'Abstracts')
        _print_len(paper_keyphrases, 'KeyPhrases')
        _print_len(paper_citations, 'Paper Citations')
        _print_len(paper_in_citations, 'Paper In citations')
        _print_len(paper_authors, 'Authors')
        _print_len(paper_venues, 'Venues')

        logging.info("Skipped {} papers due to insufficient data".format(
            len(bad_ids)))

        corpus = {}
        for id, title in tqdm.tqdm(paper_titles.items()):
            if id in bad_ids:
                continue
            doc = document_from_dict({
                FieldNames.PAPER_ID:
                str(id),
                FieldNames.TITLE:
                ' '.join(global_tokenizer(title)),
                FieldNames.ABSTRACT:
                ' '.join(global_tokenizer(paper_abstracts[id])),
                FieldNames.OUT_CITATIONS:
                paper_citations.get(id, []),
                FieldNames.YEAR:
                paper_years[id],
                FieldNames.AUTHORS:
                paper_authors.get(id, []),
                FieldNames.KEY_PHRASES:
                paper_keyphrases[id],
                FieldNames.OUT_CITATION_COUNT:
                len(paper_citations.get(id, [])),
                FieldNames.IN_CITATION_COUNT:
                len(paper_in_citations.get(id, [])),
                FieldNames.VENUE:
                paper_venues.get(id, ''),
                FieldNames.TITLE_RAW:
                title,
                FieldNames.ABSTRACT_RAW:
                paper_abstracts[id]
            })
            corpus[id] = doc

        with open(output_path, 'w') as f:
            for _, doc in corpus.items():
                doc_json = dict_from_document(doc)
                f.write(json.dumps(doc_json))
                f.write("\n")

        dp = DatasetPaths()
        Corpus.build(dp.get_db_path(self.dataset_name),
                     dp.get_json_path(self.dataset_name))
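For orientation, the parsers above assume simple tab-separated files; judging from the indices used, a papers.txt row looks roughly like <paper_id> TAB <unused> TAB <year> TAB <title>. A standalone parse sketch under that assumption (the row contents are hypothetical):

# Hypothetical papers.txt row; only columns 0, 2 and 3 are read above.
line = "13\t_\t2004\tLearning to Rank Citations"

parts = line.split("\t")
paper_id = int(parts[0])  # column 0: integer paper id
year = int(parts[2])      # column 2: publication year
title = parts[3]          # column 3: title text
print(paper_id, year, title)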
Example #8
    def main(self, args):
        dp = DatasetPaths()
        if self.dataset_type == 'oc':
            corpus = Corpus.load_pkl(dp.get_pkl_path(self.dataset_type))
        else:
            corpus = Corpus.load(dp.get_db_path(self.dataset_type))

        if self.ranker_type == 'none':
            citation_ranker = NoneRanker()
        elif self.ranker_type == 'neural':
            assert self.citation_ranker_dir is not None
            ranker_featurizer, ranker_models = model_from_directory(
                self.citation_ranker_dir, on_cpu=True)
            citation_ranker = Ranker(
                corpus=corpus,
                featurizer=ranker_featurizer,
                citation_ranker=ranker_models['citeomatic'],
                num_candidates_to_rank=100)
        else:
            assert False

        candidate_results_map = {}
        if self.num_candidates is None:
            if self.dataset_type == 'oc':
                num_candidates_list = [100]
            else:
                num_candidates_list = [1, 5, 10, 15, 25, 50, 75, 100]
        else:
            num_candidates_list = [self.num_candidates]

        for num_candidates in num_candidates_list:

            if self.candidate_selector_type == 'bm25':
                index_path = dp.get_bm25_index_path(self.dataset_type)
                candidate_selector = BM25CandidateSelector(
                    corpus, index_path, num_candidates, False)
            elif self.candidate_selector_type == 'ann':
                assert self.paper_embedder_dir is not None
                featurizer, models = model_from_directory(
                    self.paper_embedder_dir, on_cpu=True)
                candidate_selector = self._make_ann_candidate_selector(
                    corpus=corpus,
                    featurizer=featurizer,
                    embedding_model=models['embedding'],
                    num_candidates=num_candidates)
            elif self.candidate_selector_type == 'oracle':
                candidate_selector = OracleCandidateSelector(corpus)
            else:
                assert False

            results = eval_text_model(corpus,
                                      candidate_selector,
                                      citation_ranker,
                                      papers_source=self.split,
                                      n_eval=self.n_eval)
            candidate_results_map[num_candidates] = results

        best_k = -1
        best_metric = 0.0
        metric_key = self.metric + "_1"
        split_key = EVAL_DATASET_KEYS[self.dataset_type]
        for k, v in candidate_results_map.items():
            if best_metric < v[metric_key][split_key]:
                best_k = k
                best_metric = v[metric_key][split_key]

        print(json.dumps(candidate_results_map, indent=4, sort_keys=True))
        print(best_k)
        print(best_metric)
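Equivalently, the best-k scan reduces to a single max() call (sketch with the same names; note the original keeps best_k = -1 when every metric is <= 0.0, which max() would not):

best_k = max(candidate_results_map,
             key=lambda k: candidate_results_map[k][metric_key][split_key])
best_metric = candidate_results_map[best_k][metric_key][split_key]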
Example #9
import re
import resource

import numpy as np
import pandas as pd
import six
import tqdm
from keras.preprocessing.sequence import pad_sequences
from sklearn.feature_extraction.text import CountVectorizer

from citeomatic.candidate_selectors import CandidateSelector
from citeomatic.utils import flatten
from citeomatic.common import DatasetPaths
from citeomatic.models.options import ModelOptions

dp = DatasetPaths()

CLEAN_TEXT_RE = re.compile('[^ a-z]')

# filters for authors and docs
MAX_AUTHORS_PER_DOCUMENT = 8
MAX_KEYPHRASES_PER_DOCUMENT = 20
MIN_TRUE_CITATIONS = {
    'pubmed': 2,
    'dblp': 1,
    'oc': 2
}
MAX_TRUE_CITATIONS = 100

# Adjustments to how we boost heavily cited documents.
CITATION_SLOPE = 0.01
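CLEAN_TEXT_RE keeps only spaces and lowercase letters; a quick standalone demonstration:

import re

CLEAN_TEXT_RE = re.compile('[^ a-z]')
print(CLEAN_TEXT_RE.sub('', 'Hello, World! 123'.lower()))  # -> 'hello world '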
Example #10
def end_to_end_training(model_options: ModelOptions,
                        dataset_type,
                        models_dir,
                        models_ann_dir=None):
    # step 1: make the directory
    if not os.path.exists(models_dir):
        os.makedirs(models_dir)

    # step 2: load the corpus DB
    print("Loading corpus db...")
    dp = DatasetPaths()
    db_file = dp.get_db_path(dataset_type)
    json_file = dp.get_json_path(dataset_type)
    if not os.path.isfile(db_file):
        print(
            "Have to build the database! This may take a while, but should only happen once."
        )
        Corpus.build(db_file, json_file)

    if dataset_type == 'oc':
        corpus = Corpus.load_pkl(dp.get_pkl_path(dataset_type))
    else:
        corpus = Corpus.load(db_file, model_options.train_frac)

    # step 3: load/make the featurizer (once per hyperopt run)
    print("Making feautrizer")
    featurizer_file_prefix = 'pretrained_' if model_options.use_pretrained else 'corpus_fit_'

    featurizer_file = os.path.join(
        models_dir, featurizer_file_prefix + dp.FEATURIZER_FILENAME)

    if os.path.isfile(featurizer_file):
        featurizer = file_util.read_pickle(featurizer_file)
    else:
        featurizer = Featurizer(
            max_features=model_options.max_features,
            max_title_len=model_options.max_title_len,
            max_abstract_len=model_options.max_abstract_len,
            use_pretrained=model_options.use_pretrained,
            min_author_papers=model_options.min_author_papers,
            min_venue_papers=model_options.min_venue_papers,
            min_keyphrase_papers=model_options.min_keyphrase_papers)
        featurizer.fit(corpus,
                       is_featurizer_for_test=model_options.train_for_test_set)
        file_util.write_pickle(featurizer_file, featurizer)

    # update model options after featurization
    model_options.n_authors = featurizer.n_authors
    model_options.n_venues = featurizer.n_venues
    model_options.n_keyphrases = featurizer.n_keyphrases
    model_options.n_features = featurizer.n_features
    if model_options.use_pretrained:
        model_options.dense_dim = model_options.dense_dim_pretrained

    # step 4: train the model
    citeomatic_model, embedding_model = train_text_model(
        corpus,
        featurizer,
        model_options,
        models_ann_dir=models_ann_dir,
        debug=True,
        tensorboard_dir=None)

    # step 5: save the model
    citeomatic_model.save_weights(os.path.join(models_dir,
                                               dp.CITEOMATIC_WEIGHTS_FILENAME),
                                  overwrite=True)

    if embedding_model is not None:
        embedding_model.save_weights(os.path.join(
            models_dir, dp.EMBEDDING_WEIGHTS_FILENAME),
                                     overwrite=True)

    file_util.write_json(
        os.path.join(models_dir, dp.OPTIONS_FILENAME),
        model_options.to_json(),
    )

    return corpus, featurizer, model_options, citeomatic_model, embedding_model
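A hypothetical driver for end_to_end_training; the ModelOptions arguments below are assumptions, since the full set of its traits is not shown in this listing:

# Illustrative only: ModelOptions fields and values are assumed.
options = ModelOptions(model_name='paper_embedder', use_pretrained=False)
corpus, featurizer, options, model, embedding_model = end_to_end_training(
    model_options=options,
    dataset_type='dblp',                 # or 'pubmed' / 'oc'
    models_dir='/tmp/citeomatic_model',
    models_ann_dir=None)                 # None: embed with the freshly trained model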
Example #11
def train_text_model(
    corpus: Corpus,
    featurizer: Featurizer,
    model_options: ModelOptions,
    models_ann_dir=None,
    debug=False,
    tensorboard_dir=None,
):
    """
    Utility function for training citeomatic models.
    """

    # load pretrained embeddings
    if model_options.use_pretrained:
        dp = DatasetPaths()
        pretrained_embeddings_file = dp.embeddings_weights_for_corpus('shared')
        with h5py.File(pretrained_embeddings_file, 'r') as f:
            pretrained_embeddings = f['embedding'][...]
    else:
        pretrained_embeddings = None

    create_model = import_from(
        'citeomatic.models.%s' % model_options.model_name, 'create_model')
    models = create_model(model_options, pretrained_embeddings)
    model, embedding_model = models['citeomatic'], models['embedding']

    model.summary(print_fn=logging.info)  # summary() returns None; route its output to logging

    if model_options.train_for_test_set:
        paper_ids_for_training = corpus.train_ids + corpus.valid_ids
        candidates_for_training = corpus.train_ids + corpus.valid_ids + corpus.test_ids
    else:
        paper_ids_for_training = corpus.train_ids
        candidates_for_training = corpus.train_ids + corpus.valid_ids

    training_dg = DataGenerator(
        corpus=corpus,
        featurizer=featurizer,
        margin_multiplier=model_options.margin_multiplier,
        use_variable_margin=model_options.use_variable_margin)
    training_generator = training_dg.triplet_generator(
        paper_ids=paper_ids_for_training,
        candidate_ids=candidates_for_training,
        batch_size=model_options.batch_size,
        neg_to_pos_ratio=model_options.neg_to_pos_ratio)

    validation_dg = DataGenerator(
        corpus=corpus,
        featurizer=featurizer,
        margin_multiplier=model_options.margin_multiplier,
        use_variable_margin=model_options.use_variable_margin)
    validation_generator = validation_dg.triplet_generator(
        paper_ids=corpus.valid_ids,
        candidate_ids=corpus.train_ids + corpus.valid_ids,
        batch_size=1024,
        neg_to_pos_ratio=model_options.neg_to_pos_ratio)

    if model_options.optimizer == 'tfopt':
        optimizer = TFOptimizer(
            tf.contrib.opt.LazyAdamOptimizer(learning_rate=model_options.lr))
    else:
        optimizer = import_from('keras.optimizers',
                                model_options.optimizer)(lr=model_options.lr)

    model.compile(optimizer=optimizer, loss=layers.triplet_loss)

    # training calculation
    model_options.samples_per_epoch = int(
        np.minimum(model_options.samples_per_epoch,
                   model_options.total_samples))
    epochs = int(
        np.ceil(model_options.total_samples / model_options.samples_per_epoch))
    steps_per_epoch = int(model_options.samples_per_epoch /
                          model_options.batch_size)

    # callbacks
    callbacks_list = []
    if debug:
        callbacks_list.append(MemoryUsageCallback())
    if model_options.tb_dir is not None:
        callbacks_list.append(
            TensorBoard(log_dir=model_options.tb_dir,
                        histogram_freq=1,
                        write_graph=True))
    if model_options.reduce_lr_flag:
        if model_options.optimizer != 'tfopt':
            callbacks_list.append(
                ReduceLROnPlateau(verbose=1,
                                  patience=2,
                                  epsilon=0.01,
                                  min_lr=1e-6,
                                  factor=0.5))

    if models_ann_dir is None:
        ann_featurizer = featurizer
        paper_embedding_model = embedding_model
        embed_at_epoch_end = True
        embed_at_train_begin = False
    else:
        ann_featurizer, ann_models = model_from_directory(models_ann_dir,
                                                          on_cpu=True)
        paper_embedding_model = ann_models['embedding']
        paper_embedding_model._make_predict_function()
        embed_at_epoch_end = False
        embed_at_train_begin = True
    callbacks_list.append(
        UpdateANN(corpus, ann_featurizer, paper_embedding_model, training_dg,
                  validation_dg, embed_at_epoch_end, embed_at_train_begin))

    if model_options.tb_dir is None:
        validation_data = validation_generator
    else:
        validation_data = next(validation_generator)

    # fit; validation_data is the generator, or one materialized batch when
    # TensorBoard histograms are enabled (computed above but previously unused)
    model.fit_generator(generator=training_generator,
                        steps_per_epoch=steps_per_epoch,
                        epochs=epochs,
                        callbacks=callbacks_list,
                        validation_data=validation_data,
                        validation_steps=10)

    return model, embedding_model
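To make the schedule arithmetic above concrete (toy numbers, not from the source):

import numpy as np

total_samples = 1000000
samples_per_epoch = 100000
batch_size = 512

samples_per_epoch = int(np.minimum(samples_per_epoch, total_samples))
epochs = int(np.ceil(total_samples / samples_per_epoch))  # -> 10
steps_per_epoch = int(samples_per_epoch / batch_size)     # -> 195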
Example #12
    def train_and_evaluate(self, eval_params):
        # Needed especially for hyperopt runs
        K.clear_session()

        model_kw = {
            name: getattr(self, name)
            for name in ModelOptions.class_traits().keys()
        }
        model_kw.update(eval_params)
        model_options = ModelOptions(**model_kw)

        if model_options.use_metadata:
            model_options.use_keyphrases = True
            model_options.use_authors = True
            model_options.use_venue = True

        print("====== OPTIONS =====")
        print(model_options)
        print("======")

        if model_options.train_for_test_set:
            logging.info(
                "\n\n============== TRAINING FOR TEST SET =============\n\n")

        training_outputs = end_to_end_training(model_options,
                                               self.dataset_type,
                                               self.models_dir,
                                               self.models_ann_dir)
        corpus, featurizer, model_options, citeomatic_model, embedding_model = training_outputs

        if self.candidate_selector_type == 'ann':
            # if no ann_dir is provided, then we use the model that was just trained
            # and have to rebuild the ANN
            if self.models_ann_dir is None:
                print(
                    'Using embedding model that was just trained for eval. Building...'
                )
                paper_embedding_model = EmbeddingModel(featurizer,
                                                       embedding_model)
                self.ann = ANN.build(paper_embedding_model, corpus)
            # if a dir is provided, then go ahead and load it
            else:
                featurizer_for_ann, ann_models = model_from_directory(
                    self.models_ann_dir, on_cpu=True)
                paper_embedding_model = EmbeddingModel(featurizer_for_ann,
                                                       ann_models['embedding'])
                # the ANN itself needs to be only built once
                if self.ann is None:
                    if corpus.corpus_type == 'oc' and os.path.exists(
                            DatasetPaths.OC_ANN_FILE + ".pickle"):
                        self.ann = ANN.load(DatasetPaths.OC_ANN_FILE)
                    else:
                        self.ann = ANN.build(paper_embedding_model, corpus)

            candidate_selector = ANNCandidateSelector(
                corpus=corpus,
                ann=self.ann,
                paper_embedding_model=paper_embedding_model,
                top_k=model_options.num_ann_nbrs_to_fetch,
                extend_candidate_citations=model_options.extend_candidate_citations)
        elif self.candidate_selector_type == 'bm25':
            dp = DatasetPaths()
            candidate_selector = BM25CandidateSelector(
                corpus=corpus,
                index_path=dp.get_bm25_index_path(self.dataset_type),
                top_k=model_options.num_ann_nbrs_to_fetch,
                extend_candidate_citations=model_options.extend_candidate_citations)
        else:
            # Should not come here. Adding this to make pycharm happy.
            assert False

        if self.citation_ranker_type == 'neural':
            ranker = Ranker(
                corpus=corpus,
                featurizer=featurizer,
                citation_ranker=citeomatic_model,
                num_candidates_to_rank=model_options.num_candidates_to_rank)
        elif self.citation_ranker_type == 'none':
            ranker = NoneRanker()
        else:
            # Should not come here. Adding this to make pycharm happy.
            assert False

        if self.mode != 'hyperopt' or model_options.total_samples == self.total_samples_secondary:
            results_training = eval_text_model(corpus,
                                               candidate_selector,
                                               ranker,
                                               papers_source='train',
                                               n_eval=self.n_eval)
        else:
            results_training = {}

        results_validation = eval_text_model(corpus,
                                             candidate_selector,
                                             ranker,
                                             papers_source='valid',
                                             n_eval=self.n_eval)

        logging.info("===== Validation Results ===== ")
        logging.info("Validation Precision\n\n{}".format(
            results_validation['precision_1']))
        logging.info("Validation Recall\n\n{}".format(
            results_validation['recall_1']))

        split_key = EVAL_DATASET_KEYS[self.dataset_type]
        p = results_validation['precision_1'][split_key]
        r = results_validation['recall_1'][split_key]
        f1 = results_validation['f1_1'][split_key]

        if self.model_name == PAPER_EMBEDDING_MODEL:
            # optimizing for recall
            loss = -r
        else:
            # optimizing for F1
            loss = -f1

        out = {
            'loss': loss,  # negated above since hyperopt minimizes
            'losses_training': results_training,
            'losses_validation': results_validation,
            'status': STATUS_FAIL if np.isnan(f1) else STATUS_OK,
            'params': eval_params
        }

        return out
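The returned dict follows the hyperopt objective contract ('loss' to minimize, 'status', plus arbitrary extras). A minimal sketch of how such a method is typically handed to hyperopt.fmin (the search space is purely illustrative, and 'app' stands for an instance of the command class above):

from hyperopt import fmin, tpe, hp, Trials

# Illustrative search space; the real tunable parameters are not shown here.
space = {
    'lr': hp.loguniform('lr', -10, -2),
    'batch_size': hp.choice('batch_size', [32, 64, 128]),
}

trials = Trials()
best = fmin(fn=app.train_and_evaluate,  # app: command instance (assumed)
            space=space,
            algo=tpe.suggest,
            max_evals=25,
            trials=trials)
print(best)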