Example #1
0
 def init_query_encoder(self, opt):
     """
     Create ``self.query_encoder`` unless the attribute already exists.

     :param opt:
         options mapping; ``opt['query_model']`` and ``opt['dpr_model_file']``
         are passed to ``DprQueryEncoder`` as ``dpr_model`` and
         ``pretrained_path`` respectively.
     """
     if not hasattr(self, 'query_encoder'):
         # First call on this object: nothing has been instantiated yet.
         self.query_encoder = DprQueryEncoder(
             opt,
             dpr_model=opt['query_model'],
             pretrained_path=opt['dpr_model_file'],
         )
Example #2
0
 def __init__(self, opt: Opt, dictionary: DictionaryAgent, shared=None):
     """
     Set up the DPR retriever.

     After the parent initializer runs, this calls ``self.load_index(opt,
     shared)``, stores ``opt['n_docs']`` on ``self.n_docs`` and constructs
     the query encoder.

     :param opt:
         options; reads ``n_docs``, ``query_model`` and ``dpr_model_file``.
     :param dictionary:
         dictionary agent, forwarded to the parent initializer.
     :param shared:
         forwarded to the parent initializer and to ``load_index``.
     """
     super().__init__(opt, dictionary, shared=shared)
     self.load_index(opt, shared)
     self.n_docs = opt['n_docs']
     # Gather the encoder settings first; lookups happen in the same order
     # as they would when written inline in the call.
     encoder_kwargs = {
         'dpr_model': opt['query_model'],
         'pretrained_path': opt['dpr_model_file'],
     }
     self.query_encoder = DprQueryEncoder(opt, **encoder_kwargs)
Example #3
0
 def __init__(
     self,
     opt: Opt,
     dictionary: DictionaryAgent,
     query_encoder: Optional[torch.nn.Module] = None,
     shared=None,
 ):
     """
     Set up the encoders, tokenizer and per-slot memory buffers.

     :param opt:
         options; the keys read here are ``n_docs``, ``memory_reader_model``,
         ``dpr_model_file``, ``memory_writer_model``,
         ``memory_writer_model_file``, ``datapath``, ``query_model``,
         ``memory_retriever_truncate``, ``rag_query_truncate``,
         ``max_doc_token_length`` and ``retriever_embedding_size``, plus the
         optional ``max_memories`` (default 100), ``batchsize`` (default 1)
         and ``rag_turn_n_turns`` (default 1).
     :param dictionary:
         dictionary agent; passed to the parent, the tokenizer, and kept on
         ``self.dict``.
     :param query_encoder:
         encoder to use as ``self.query_encoder``; when ``None`` a
         ``DprQueryEncoder`` is built from ``opt``.
     :param shared:
         forwarded to the parent initializer.
     """
     super().__init__(opt, dictionary, shared)
     self.n_docs = opt['n_docs']

     # Prefer the caller-supplied encoder; otherwise build one from opt.
     if query_encoder is not None:
         self.query_encoder = query_encoder
     else:
         self.query_encoder = DprQueryEncoder(
             opt,
             dpr_model=opt['memory_reader_model'],
             pretrained_path=opt['dpr_model_file'],
         )

     # The document encoder is switched to eval mode right after creation.
     writer = DprDocumentEncoder(
         opt,
         dpr_model=opt['memory_writer_model'],
         pretrained_path=opt['memory_writer_model_file'],
     )
     self.memory_encoder = writer.eval()

     datapath = opt['datapath']
     query_model = opt['query_model']
     # A non-positive memory truncation falls back to the query truncation.
     if opt['memory_retriever_truncate'] > 0:
         truncate_len = opt['memory_retriever_truncate']
     else:
         truncate_len = opt['rag_query_truncate']
     self._tokenizer = RagRetrieverTokenizer(
         datapath=datapath,
         query_model=query_model,
         dictionary=dictionary,
         delimiter='\n',
         max_length=truncate_len,
     )

     self.max_memories = opt.get('max_memories', 100)
     batch_size = opt.get('batchsize', 1)
     n_turns = opt.get('rag_turn_n_turns', 1)
     self.num_memory_slots = batch_size * n_turns

     # One zero-filled int64 token buffer per slot.
     token_buffers = {}
     for slot in range(self.num_memory_slots):
         zeros = torch.zeros(self.max_memories, opt['max_doc_token_length'])
         token_buffers[slot] = zeros.to(torch.int64)
     self.memory_vec_dict: Dict[int, torch.LongTensor] = token_buffers  # type: ignore

     # One zero-filled encoding buffer per slot.
     encoding_buffers = {}
     for slot in range(self.num_memory_slots):
         encoding_buffers[slot] = torch.zeros(
             self.max_memories, opt['retriever_embedding_size']
         )
     self.memory_enc_dict: Dict[int, torch.Tensor] = encoding_buffers

     self.active_memory_slots: List[int] = []
     self.dict = dictionary
Example #4
0
 def __init__(self, opt: Opt, dictionary: DictionaryAgent, shared=None):
     """
     Set up the DPR retriever, reusing shared state when it is provided.

     With a truthy ``shared`` the indexer and passages are taken from it.
     Otherwise an indexer is created with ``indexer_factory``, deserialized
     from the index (and optional dense embeddings) path, and the passages
     are loaded with ``load_passages_dict``.

     :param opt:
         options; reads ``datapath``, ``path_to_index``,
         ``path_to_dpr_passages``, ``path_to_dense_embeddings``, ``n_docs``,
         ``query_model`` and ``dpr_model_file``.
     :param dictionary:
         dictionary agent, forwarded to the parent initializer.
     :param shared:
         optional mapping with ``'indexer'`` and ``'passages'`` entries.
     """
     super().__init__(opt, dictionary, shared=shared)
     if shared:
         self.indexer = shared['indexer']
         self.passages = shared['passages']
     else:
         self.indexer = indexer_factory(opt)
         index_path = modelzoo_path(opt['datapath'], opt['path_to_index'])
         passages_path = modelzoo_path(
             opt['datapath'], opt['path_to_dpr_passages']
         )
         # The dense embeddings file is optional; None means "not provided".
         dense_embeddings = opt['path_to_dense_embeddings']
         embeddings_path = (
             modelzoo_path(opt['datapath'], dense_embeddings)
             if dense_embeddings is not None
             else None
         )
         self.indexer.deserialize_from(index_path, embeddings_path)
         self.passages = load_passages_dict(passages_path)

     self.n_docs = opt['n_docs']
     encoder_kwargs = {
         'dpr_model': opt['query_model'],
         'pretrained_path': opt['dpr_model_file'],
     }
     self.query_encoder = DprQueryEncoder(opt, **encoder_kwargs)
Example #5
0
    def test_load_dpr(self):
        # Checks which query-encoder embedding weights a RAG agent ends up
        # with: first with no DPR override, then with each of two explicit
        # ``dpr_model_file`` overrides.

        def query_weights(encoder):
            # Embedding weights as a float CPU tensor, ready for comparison.
            return encoder.embeddings.weight.float().cpu()

        opt = ParlaiParser(True, True).parse_args([])

        # Reference encoders: one from the zoo DPR file, one from the
        # RAG-sequence zoo model.
        default_query_encoder = DprQueryEncoder(
            opt, dpr_model='bert', pretrained_path=DPR_ZOO_MODEL
        )
        rag_sequence_query_encoder = DprQueryEncoder(
            opt,
            dpr_model='bert_from_parlai_rag',
            pretrained_path=RAG_SEQUENCE_ZOO_MODEL,
        )
        # The two references must differ from each other.
        assert not torch.allclose(
            query_weights(default_query_encoder),
            query_weights(rag_sequence_query_encoder),
        )

        # 1. Zoo RAG agent with no DPR override: its query encoder should
        #    match neither reference encoder.
        model_file = modelzoo_path(opt['datapath'], RAG_TOKEN_ZOO_MODEL)
        override = {'retriever_debug_index': 'compressed', 'fp16': False}
        rag = create_agent(Opt({'model_file': model_file, 'override': override}))
        assert not torch.allclose(
            query_weights(rag_sequence_query_encoder),
            query_weights(rag.model.retriever.query_encoder),
        )
        assert not torch.allclose(
            query_weights(default_query_encoder),
            query_weights(rag.model.retriever.query_encoder),
        )

        # 2. Same agent, overriding the DPR model file with the RAG-sequence
        #    zoo model: the weights should now equal that reference encoder.
        model_file = modelzoo_path(opt['datapath'], RAG_TOKEN_ZOO_MODEL)
        override = {
            'retriever_debug_index': 'compressed',
            'dpr_model_file': modelzoo_path(
                opt['datapath'], RAG_SEQUENCE_ZOO_MODEL
            ),
            'query_model': 'bert_from_parlai_rag',
            'fp16': False,
        }
        rag = create_agent(Opt({'model_file': model_file, 'override': override}))
        assert torch.allclose(
            query_weights(rag_sequence_query_encoder),
            query_weights(rag.model.retriever.query_encoder),
        )

        # 3. Same agent, overriding the DPR model file with the default DPR
        #    zoo model: the explicit override should give weights equal to
        #    the default reference encoder.
        model_file = modelzoo_path(opt['datapath'], RAG_TOKEN_ZOO_MODEL)
        override = {
            'retriever_debug_index': 'compressed',
            'dpr_model_file': modelzoo_path(opt['datapath'], DPR_ZOO_MODEL),
            'fp16': False,
        }
        rag = create_agent(Opt({'model_file': model_file, 'override': override}))
        assert torch.allclose(
            query_weights(default_query_encoder),
            query_weights(rag.model.retriever.query_encoder),
        )
Example #6
0
 def __init__(self, opt: Opt, dictionary: DictionaryAgent, shared=None):
     """
     Run the parent initializer, then build the DPR query encoder.

     :param opt:
         options; ``opt['query_model']`` and ``opt['dpr_model_file']`` select
         the encoder model and its pretrained weights.
     :param dictionary:
         dictionary agent, forwarded to the parent initializer.
     :param shared:
         forwarded to the parent initializer.
     """
     super().__init__(opt, dictionary, shared)
     model_name = opt['query_model']
     weights_path = opt['dpr_model_file']
     self.query_encoder = DprQueryEncoder(
         opt, dpr_model=model_name, pretrained_path=weights_path
     )