Exemplo n.º 1
0
    def __init__(self):
        """Set up the evidence dataset, the query encoder and the faiss index.

        Configuration is read from ``get_args()``: ``hidden_size`` is stored
        as ``self.embedding_size``, ``faiss_use_gpu`` is copied to
        ``self.faiss_use_gpu``, and ``biencoder_shared_query_context_model``
        decides whether only the query model is built and loaded.
        """
        args = get_args()
        self.embedding_size = args.hidden_size
        self.faiss_use_gpu = args.faiss_use_gpu
        # Start every lazily-filled attribute as None.
        self.evidence_embedder_obj = None
        self.evidence_dataset = None
        self.mips_index = None
        self.eval_dataset = None

        # Get Evidence (Wikipedia) dataset
        self.get_evidence_dataset()

        # Load query encoder checkpoint. A shared query/context model cannot
        # use the query-only path, so the full model is requested instead.
        only_query_model = True
        if args.biencoder_shared_query_context_model:
            only_query_model = False

        def model_provider(**provider_kwargs):
            # get_model may invoke its provider with keyword arguments such
            # as pre_process/post_process (the convention the other model
            # providers follow). Forward them rather than failing with a
            # TypeError; with no keywords this is exactly the original call.
            return biencoder_model_provider(
                only_query_model=only_query_model,
                biencoder_shared_query_context_model=\
                    args.biencoder_shared_query_context_model,
                **provider_kwargs)

        model = get_model(model_provider)

        self.model = load_biencoder_checkpoint(
            model, only_query_model=only_query_model)

        assert len(self.model) == 1, \
            'expected exactly one model, got {}'.format(len(self.model))
        self.model[0].eval()

        # Load faiss indexer
        self.faiss_wrapper()
Exemplo n.º 2
0
def pretrain_ict_model_provider():
    """Model provider used for ICT pretraining.

    Requests the full biencoder (neither query-only nor context-only); the
    ``biencoder_shared_query_context_model`` flag comes from ``get_args()``.
    """
    shared_flag = get_args().biencoder_shared_query_context_model
    return biencoder_model_provider(
        only_context_model=False,
        only_query_model=False,
        biencoder_shared_query_context_model=shared_flag,
    )
Exemplo n.º 3
0
def pretrain_ict_model_provider(pre_process=True, post_process=True):
    """Model provider used for ICT pretraining.

    Requests the full biencoder (neither query-only nor context-only).
    ``pre_process`` and ``post_process`` are passed straight through to
    ``biencoder_model_provider``; the shared-model flag comes from
    ``get_args()``.
    """
    provider_kwargs = dict(
        only_context_model=False,
        only_query_model=False,
        biencoder_shared_query_context_model=(
            get_args().biencoder_shared_query_context_model),
        pre_process=pre_process,
        post_process=post_process,
    )
    return biencoder_model_provider(**provider_kwargs)
Exemplo n.º 4
0
    def model_provider(pre_process=True, post_process=True):
        """Build the retriever model for the configured task.

        Logs the task name on rank 0, then requests the full biencoder
        (neither query-only nor context-only) from
        ``biencoder_model_provider``, forwarding ``pre_process`` and
        ``post_process`` unchanged.
        """
        cfg = get_args()
        banner = 'building retriever model for {} ...'.format(cfg.task)
        print_rank_0(banner)

        return biencoder_model_provider(
            only_context_model=False,
            only_query_model=False,
            biencoder_shared_query_context_model=(
                cfg.biencoder_shared_query_context_model),
            pre_process=pre_process,
            post_process=post_process,
        )
Exemplo n.º 5
0
    def load_attributes(self):
        """
        Load the necessary attributes: model, dataloader and empty BlockData

        Builds the biencoder (context model only, unless
        ``self.biencoder_shared_query_context_model`` is set), loads its
        checkpoint and puts it in eval mode. Then creates the dataset from
        ``get_open_retrieval_wiki_dataset()``, an iterator over a one-epoch
        dataloader with ``self.batch_size``, and an
        ``OpenRetreivalDataStore`` constructed with ``load_from_path=False``.
        """
        # A shared query/context model cannot use the context-only path.
        only_context_model = True
        if self.biencoder_shared_query_context_model:
            only_context_model = False

        def model_provider(**provider_kwargs):
            # get_model may invoke its provider with keyword arguments such
            # as pre_process/post_process (the convention the other model
            # providers follow). Forward them rather than failing with a
            # TypeError; with no keywords this is exactly the original call.
            return biencoder_model_provider(
                only_context_model=only_context_model,
                biencoder_shared_query_context_model=\
                    self.biencoder_shared_query_context_model,
                **provider_kwargs)

        model = get_model(model_provider)

        self.model = load_biencoder_checkpoint(
            model, only_context_model=only_context_model)

        assert len(self.model) == 1, \
            'expected exactly one model, got {}'.format(len(self.model))
        self.model[0].eval()

        self.dataset = get_open_retrieval_wiki_dataset()
        self.dataloader = iter(get_one_epoch_dataloader(self.dataset, \
            self.batch_size))

        self.evidence_embedder_obj = OpenRetreivalDataStore( \
            load_from_path=False)