def __init__(self,
                 contexts=None,
                 fill_context_embeddings=True,
                 device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
        super(LongQAModel, self).__init__()
        self.device = device
        self.c_model = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base').to(device)
        self.c_tokenizer = DPRContextEncoderTokenizer.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
        self.q_model = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base').to(device)
        self.q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
        self.r_model = DPRReader.from_pretrained('facebook/dpr-reader-single-nq-base').to(device)
        self.r_tokenizer = DPRReaderTokenizerFast.from_pretrained('facebook/dpr-reader-single-nq-base')
        self.contexts = contexts
        # Not enough time to load context embeddings in AWS SageMaker,
        # but can fill weights from saved state dict after loading model.
        if not self.contexts:
            with open('code/contexts.json') as f:
                self.contexts = json.load(f)
#             output_features = self.c_model.ctx_encoder.bert_model.pooler.dense.out_features
#             self.context_embeddings = nn.Parameter(torch.zeros(len(self.contexts), output_features)).to(device)
#         else:
        context_embeddings = []
        with torch.no_grad():
            for context in self.contexts:
                input_ids = self.c_tokenizer(context, return_tensors='pt').to(device)["input_ids"]
                output = self.c_model(input_ids)
                context_embeddings.append(output.pooler_output)
        # Move the stacked embeddings first, then wrap: nn.Parameter(...).to(device)
        # returns a plain Tensor, silently dropping the Parameter wrapper.
        self.context_embeddings = nn.Parameter(torch.cat(context_embeddings, dim=0).to(device))
        print('cwd!:', os.getcwd())
        print(os.listdir('code'))
        self.noise_remover = joblib.load('code/filter_model.sav')
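To show how these pieces fit together, here is a minimal retrieval sketch (not part of the original class) that ranks the precomputed context embeddings against a question by inner product; the method name retrieve is hypothetical.

    def retrieve(self, question, top_k=5):
        # Hypothetical helper: encode the question, then score it against the
        # precomputed context embeddings by inner product.
        with torch.no_grad():
            input_ids = self.q_tokenizer(question, return_tensors='pt').to(self.device)["input_ids"]
            q_emb = self.q_model(input_ids).pooler_output      # (1, 768)
            scores = q_emb @ self.context_embeddings.T         # (1, num_contexts)
            best = torch.topk(scores, k=min(top_k, scores.shape[-1]), dim=-1)
        return [(self.contexts[int(i)], s.item())
                for s, i in zip(best.values[0], best.indices[0])]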
Example #2
    def __init__(self):
        self.tokenizer_q = DPRQuestionEncoderTokenizer.from_pretrained(
            'facebook/dpr-question_encoder-single-nq-base')
        self.model_q = DPRQuestionEncoder.from_pretrained(
            'facebook/dpr-question_encoder-single-nq-base')
        self.model_q.to(DEVICE)  # DEVICE is assumed to be defined elsewhere in this module
        self.tokenizer_d = DPRContextEncoderTokenizer.from_pretrained(
            'facebook/dpr-ctx_encoder-single-nq-base')
        self.model_d = DPRContextEncoder.from_pretrained(
            'facebook/dpr-ctx_encoder-single-nq-base')
        self.model_d.to(DEVICE)
Example #3
    def __init__(self):
        self.context_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
            'facebook/dpr-ctx_encoder-single-nq-base')
        self.context_model = DPRContextEncoder.from_pretrained(
            'facebook/dpr-ctx_encoder-single-nq-base', return_dict=True)

        self.query_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            'facebook/dpr-question_encoder-single-nq-base')
        self.query_encoder = DPRQuestionEncoder.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base")

        self.reader_tokenizer = DPRReaderTokenizer.from_pretrained(
            'facebook/dpr-reader-single-nq-base')
        self.reader_model = DPRReader.from_pretrained(
            'facebook/dpr-reader-single-nq-base', return_dict=True)
        self.vector_length = 768
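For illustration (not in the original), the reader components above jointly score passages and answer spans. Here `qa` is a hypothetical instance of the class above, and the question/title/text strings are placeholders:

# Hedged sketch of DPRReader usage with the components initialized above.
encoded = qa.reader_tokenizer(questions="who wrote hamlet?",
                              titles="Hamlet",
                              texts="Hamlet is a tragedy by William Shakespeare.",
                              return_tensors='pt')
outputs = qa.reader_model(**encoded)
start_logits = outputs.start_logits          # answer-span start scores
end_logits = outputs.end_logits              # answer-span end scores
relevance_logits = outputs.relevance_logits  # passage reranking score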
Example #4
def download_model(outputdir_question_tokenizer: str,
                   outputdir_question_encoder: str,
                   outputdir_ctx_tokenizer: str, outputdir_ctx_encoder: str):
    q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
        "facebook/dpr-question_encoder-single-nq-base")
    print("Save question tokenizer to ", outputdir_question_tokenizer)
    q_tokenizer.save_pretrained(outputdir_question_tokenizer)

    q_encoder = DPRQuestionEncoder.from_pretrained(
        "facebook/dpr-question_encoder-single-nq-base")
    print("Save question encoder to ", outputdir_question_encoder)
    q_encoder.save_pretrained(outputdir_question_encoder)

    ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
        "facebook/dpr-ctx_encoder-single-nq-base")
    print("Save context tokenizer to ", outputdir_ctx_tokenizer)
    ctx_tokenizer.save_pretrained(outputdir_ctx_tokenizer)

    ctx_encoder = DPRContextEncoder.from_pretrained(
        "facebook/dpr-ctx_encoder-single-nq-base")
    print("Save context encoder to", outputdir_ctx_encoder)
    ctx_encoder.save_pretrained(outputdir_ctx_encoder)
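As a sanity check (added here for illustration), the saved artifacts can be reloaded from the same output directories passed to download_model:

# Reload everything from the directories written above.
q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(outputdir_question_tokenizer)
q_encoder = DPRQuestionEncoder.from_pretrained(outputdir_question_encoder)
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(outputdir_ctx_tokenizer)
ctx_encoder = DPRContextEncoder.from_pretrained(outputdir_ctx_encoder)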
Example #5
model.load_state_dict(
    torch.load('Reader/weight_electra/weights_3.pth',
               map_location=torch.device('cpu')))
model.eval()
tokenizer = BertWordPieceTokenizer("Reader/electra_base_uncased/vocab.txt",
                                   lowercase=True)

torch.set_grad_enabled(False)
q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
    "facebook/dpr-question_encoder-single-nq-base")
q_encoder = DPRQuestionEncoder.from_pretrained(
    "Retrieval/question_encoder").to(device=torch.device('cpu'))
q_encoder.eval()

# ctx_tokenizer = BertWordPieceTokenizer("ctx_tokenizer/vocab.txt", lowercase=True)
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
    "facebook/dpr-ctx_encoder-single-nq-base")
ctx_encoder = DPRContextEncoder.from_pretrained("Retrieval/ctx_encoder").to(
    device=torch.device('cpu'))
ctx_encoder.eval()

app = Flask(__name__)


@app.route('/')
def home():
    return render_template('home.html')


@app.route('/', methods=['POST'])
def Answering():
    question = request.form['question']
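    # The original snippet is cut off here. A hedged continuation: embed the
    # question and rank precomputed passage embeddings by inner product.
    # `passages` and `passage_embeddings` are assumptions, not shown above.
    input_ids = q_tokenizer(question, return_tensors='pt')["input_ids"]
    q_emb = q_encoder(input_ids).pooler_output             # (1, 768)
    scores = torch.matmul(q_emb, passage_embeddings.T)     # (1, num_passages)
    best = int(torch.argmax(scores, dim=-1))
    return render_template('home.html', answer=passages[best])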
Example #6
        type=str,
        help='directory that contains corpus files to be encoded, in jsonl format.',
        required=True)
    parser.add_argument('--index',
                        type=str,
                        help='directory to store brute force index of corpus',
                        required=True)
    parser.add_argument('--batch', type=int, help='batch size', default=8)
    parser.add_argument('--device',
                        type=str,
                        help='device cpu or cuda [cuda:0, cuda:1...]',
                        default='cuda:0')
    args = parser.parse_args()

    tokenizer = DPRContextEncoderTokenizer.from_pretrained(args.encoder)
    model = DPRContextEncoder.from_pretrained(args.encoder)
    model.to(args.device)

    index = faiss.IndexFlatIP(args.dimension)

    if not os.path.exists(args.index):
        os.mkdir(args.index)

    titles = []
    texts = []
    with open(os.path.join(args.index, 'docid'), 'w') as id_file:
        for file in sorted(os.listdir(args.corpus)):
            file = os.path.join(args.corpus, file)
            if file.endswith(('json', 'jsonl')):
                print(f'Encoding {file}')
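                # Hedged continuation (the original is truncated here): read
                # each JSONL record, log its docid, and collect title/text
                # pairs; the field names 'id', 'title', and 'contents' are
                # assumptions about the corpus schema.
                with open(file) as corpus_file:
                    for line in corpus_file:
                        record = json.loads(line)
                        id_file.write(str(record['id']) + '\n')
                        titles.append(record.get('title', ''))
                        texts.append(record['contents'])
        # Encode title/text pairs in batches and add the pooled vectors to the
        # FAISS index (assumes `import json, torch` at the top of the script).
        for start in range(0, len(texts), args.batch):
            encoded = tokenizer(titles[start:start + args.batch],
                                texts[start:start + args.batch],
                                padding=True, truncation=True,
                                return_tensors='pt').to(args.device)
            with torch.no_grad():
                embeddings = model(**encoded).pooler_output.cpu().numpy()
            index.add(embeddings)
        faiss.write_index(index, os.path.join(args.index, 'index'))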
Example #7
    def load(cls,
             pretrained_model_name_or_path,
             tokenizer_class=None,
             use_fast=False,
             **kwargs):
        """
        Enables loading of different Tokenizer classes with a uniform interface. Either infer the class from
        `pretrained_model_name_or_path` or define it manually via `tokenizer_class`.

        :param pretrained_model_name_or_path:  The path of the saved pretrained model or its name (e.g. `bert-base-uncased`)
        :type pretrained_model_name_or_path: str
        :param tokenizer_class: (Optional) Name of the tokenizer class to load (e.g. `BertTokenizer`)
        :type tokenizer_class: str
        :param use_fast: (Optional, False by default) Indicate if FARM should try to load the fast version of the tokenizer (True) or
            use the Python one (False).
            Only DistilBERT, BERT and Electra fast tokenizers are supported.
        :type use_fast: bool
        :param kwargs:
        :return: Tokenizer
        """

        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
        # guess tokenizer type from name
        if tokenizer_class is None:
            if "albert" in pretrained_model_name_or_path.lower():
                tokenizer_class = "AlbertTokenizer"
            elif "xlm-roberta" in pretrained_model_name_or_path.lower():
                tokenizer_class = "XLMRobertaTokenizer"
            elif "roberta" in pretrained_model_name_or_path.lower():
                tokenizer_class = "RobertaTokenizer"
            elif 'codebert' in pretrained_model_name_or_path.lower():
                if "mlm" in pretrained_model_name_or_path.lower():
                    raise NotImplementedError(
                        "MLM part of codebert is currently not supported in FARM"
                    )
                else:
                    tokenizer_class = "RobertaTokenizer"
            elif "camembert" in pretrained_model_name_or_path.lower(
            ) or "umberto" in pretrained_model_name_or_path:
                tokenizer_class = "CamembertTokenizer"
            elif "distilbert" in pretrained_model_name_or_path.lower():
                tokenizer_class = "DistilBertTokenizer"
            elif "bert" in pretrained_model_name_or_path.lower():
                tokenizer_class = "BertTokenizer"
            elif "xlnet" in pretrained_model_name_or_path.lower():
                tokenizer_class = "XLNetTokenizer"
            elif "electra" in pretrained_model_name_or_path.lower():
                tokenizer_class = "ElectraTokenizer"
            elif "word2vec" in pretrained_model_name_or_path.lower() or \
                    "glove" in pretrained_model_name_or_path.lower() or \
                    "fasttext" in pretrained_model_name_or_path.lower():
                tokenizer_class = "EmbeddingTokenizer"
            elif "minilm" in pretrained_model_name_or_path.lower():
                tokenizer_class = "BertTokenizer"
            elif "dpr-question_encoder" in pretrained_model_name_or_path.lower(
            ):
                tokenizer_class = "DPRQuestionEncoderTokenizer"
            elif "dpr-ctx_encoder" in pretrained_model_name_or_path.lower():
                tokenizer_class = "DPRContextEncoderTokenizer"
            else:
                raise ValueError(
                    f"Could not infer tokenizer_class from name '{pretrained_model_name_or_path}'. Set "
                    f"arg `tokenizer_class` in Tokenizer.load() to one of: AlbertTokenizer, "
                    f"XLMRobertaTokenizer, RobertaTokenizer, DistilBertTokenizer, BertTokenizer, or "
                    f"XLNetTokenizer.")
            logger.info(f"Loading tokenizer of type '{tokenizer_class}'")
        # return appropriate tokenizer object
        ret = None
        if tokenizer_class == "AlbertTokenizer":
            if use_fast:
                logger.error(
                    'AlbertTokenizerFast is not supported! Using AlbertTokenizer instead.'
                )
                ret = AlbertTokenizer.from_pretrained(
                    pretrained_model_name_or_path, keep_accents=True, **kwargs)
            else:
                ret = AlbertTokenizer.from_pretrained(
                    pretrained_model_name_or_path, keep_accents=True, **kwargs)
        elif tokenizer_class == "XLMRobertaTokenizer":
            if use_fast:
                logger.error(
                    'XLMRobertaTokenizerFast is not supported! Using XLMRobertaTokenizer instead.'
                )
                ret = XLMRobertaTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = XLMRobertaTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        elif "RobertaTokenizer" in tokenizer_class:  # because it also might be fast tokekenizer we use "in"
            if use_fast:
                logger.error(
                    'RobertaTokenizerFast is not supported! Using RobertaTokenizer instead.'
                )
                ret = RobertaTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = RobertaTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        elif "DistilBertTokenizer" in tokenizer_class:  # because it also might be fast tokekenizer we use "in"
            if use_fast:
                ret = DistilBertTokenizerFast.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = DistilBertTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        elif "BertTokenizer" in tokenizer_class:  # because it also might be fast tokekenizer we use "in"
            if use_fast:
                ret = BertTokenizerFast.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = BertTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        elif tokenizer_class == "XLNetTokenizer":
            if use_fast:
                logger.error(
                    'XLNetTokenizerFast is not supported! Using XLNetTokenizer instead.'
                )
                ret = XLNetTokenizer.from_pretrained(
                    pretrained_model_name_or_path, keep_accents=True, **kwargs)
            else:
                ret = XLNetTokenizer.from_pretrained(
                    pretrained_model_name_or_path, keep_accents=True, **kwargs)
        elif "ElectraTokenizer" in tokenizer_class:  # because it also might be fast tokekenizer we use "in"
            if use_fast:
                ret = ElectraTokenizerFast.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = ElectraTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        elif tokenizer_class == "EmbeddingTokenizer":
            if use_fast:
                logger.error(
                    'EmbeddingTokenizerFast is not supported! Using EmbeddingTokenizer instead.'
                )
                ret = EmbeddingTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = EmbeddingTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        elif tokenizer_class == "CamembertTokenizer":
            if use_fast:
                logger.error(
                    'CamembertTokenizerFast is not supported! Using CamembertTokenizer instead.'
                )
                ret = CamembertTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = CamembertTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        elif tokenizer_class == "DPRQuestionEncoderTokenizer" or tokenizer_class == "DPRQuestionEncoderTokenizerFast":
            if use_fast or tokenizer_class == "DPRQuestionEncoderTokenizerFast":
                ret = DPRQuestionEncoderTokenizerFast.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = DPRQuestionEncoderTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        elif tokenizer_class == "DPRContextEncoderTokenizer" or tokenizer_class == "DPRContextEncoderTokenizerFast":
            if use_fast or tokenizer_class == "DPRContextEncoderTokenizerFast":
                ret = DPRContextEncoderTokenizerFast.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = DPRContextEncoderTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        if ret is None:
            raise Exception("Unable to load tokenizer")
        else:
            return ret
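Typical usage (illustrative, with model names from the examples above):

# Infer the tokenizer class from the model name...
tokenizer = Tokenizer.load("facebook/dpr-question_encoder-single-nq-base")
# ...or pin it explicitly.
tokenizer = Tokenizer.load("facebook/dpr-ctx_encoder-single-nq-base",
                           tokenizer_class="DPRContextEncoderTokenizer")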
Example #8
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from CustomDPRDataset import CustomDPRDataset
from tqdm import tqdm
import sys

from transformers import DPRContextEncoder, DPRContextEncoderTokenizer, DPRQuestionEncoder, DPRQuestionEncoderTokenizer, AdamW, get_linear_schedule_with_warmup

# initialize tokenizers and models for context encoder and question encoder
context_name = 'facebook/dpr-ctx_encoder-multiset-base'  # context encoder checkpoint to use
question_name = 'facebook/dpr-question_encoder-multiset-base'  # question encoder checkpoint to use
context_tokenizer = DPRContextEncoderTokenizer.from_pretrained(context_name)
context_model = DPRContextEncoder.from_pretrained(context_name).cuda()
question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(question_name)
question_model = DPRQuestionEncoder.from_pretrained(question_name).cuda()

nll = nn.NLLLoss()
# question_model.half()
# context_model.half()

# params
batch_size = 256
grad_accum = 8
lr = 1e-5
text_descrip = "batchsize256_gradaccum8_v2"

print("intialized models/tokenizers")

# initialize dataset
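The training loop itself is not shown; given the nll criterion above, a common objective is in-batch negatives. A minimal sketch, assuming each batch is a list of question strings paired with their positive passage strings:

def train_step(questions, passages):
    # In-batch negatives: passage i is the positive for question i and a
    # negative for every other question in the batch.
    q_inputs = question_tokenizer(questions, padding=True, truncation=True,
                                  return_tensors='pt').to('cuda')
    p_inputs = context_tokenizer(passages, padding=True, truncation=True,
                                 return_tensors='pt').to('cuda')
    q_emb = question_model(**q_inputs).pooler_output      # (B, 768)
    p_emb = context_model(**p_inputs).pooler_output       # (B, 768)
    scores = q_emb @ p_emb.T                              # (B, B) similarities
    log_probs = F.log_softmax(scores, dim=1)
    targets = torch.arange(len(questions), device=log_probs.device)
    loss = nll(log_probs, targets) / grad_accum           # scale for accumulation
    loss.backward()
    return loss.item()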
Example #9
    def load(cls,
             pretrained_model_name_or_path,
             revision=None,
             tokenizer_class=None,
             use_fast=True,
             **kwargs):
        """
        Enables loading of different Tokenizer classes with a uniform interface. Either infer the class from
        model config or define it manually via `tokenizer_class`.

        :param pretrained_model_name_or_path:  The path of the saved pretrained model or its name (e.g. `bert-base-uncased`)
        :type pretrained_model_name_or_path: str
        :param revision: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
        :type revision: str
        :param tokenizer_class: (Optional) Name of the tokenizer class to load (e.g. `BertTokenizer`)
        :type tokenizer_class: str
        :param use_fast: (Optional, True by default) Indicate if FARM should try to load the fast version of the tokenizer (True) or
            use the Python one (False).
            Only DistilBERT, BERT and Electra fast tokenizers are supported.
        :type use_fast: bool
        :param kwargs:
        :return: Tokenizer
        """
        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
        kwargs["revision"] = revision

        if tokenizer_class is None:
            tokenizer_class = cls._infer_tokenizer_class(
                pretrained_model_name_or_path)

        logger.info(f"Loading tokenizer of type '{tokenizer_class}'")
        # return appropriate tokenizer object
        ret = None
        if "AlbertTokenizer" in tokenizer_class:
            if use_fast:
                ret = AlbertTokenizerFast.from_pretrained(
                    pretrained_model_name_or_path, keep_accents=True, **kwargs)
            else:
                ret = AlbertTokenizer.from_pretrained(
                    pretrained_model_name_or_path, keep_accents=True, **kwargs)
        elif "XLMRobertaTokenizer" in tokenizer_class:
            if use_fast:
                ret = XLMRobertaTokenizerFast.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = XLMRobertaTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        elif "RobertaTokenizer" in tokenizer_class:
            if use_fast:
                ret = RobertaTokenizerFast.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = RobertaTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        elif "DistilBertTokenizer" in tokenizer_class:
            if use_fast:
                ret = DistilBertTokenizerFast.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = DistilBertTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        elif "BertTokenizer" in tokenizer_class:
            if use_fast:
                ret = BertTokenizerFast.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = BertTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        elif "XLNetTokenizer" in tokenizer_class:
            if use_fast:
                ret = XLNetTokenizerFast.from_pretrained(
                    pretrained_model_name_or_path, keep_accents=True, **kwargs)
            else:
                ret = XLNetTokenizer.from_pretrained(
                    pretrained_model_name_or_path, keep_accents=True, **kwargs)
        elif "ElectraTokenizer" in tokenizer_class:
            if use_fast:
                ret = ElectraTokenizerFast.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = ElectraTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        elif tokenizer_class == "EmbeddingTokenizer":
            if use_fast:
                logger.error(
                    'EmbeddingTokenizerFast is not supported! Using EmbeddingTokenizer instead.'
                )
                ret = EmbeddingTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = EmbeddingTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        elif "CamembertTokenizer" in tokenizer_class:
            if use_fast:
                ret = CamembertTokenizerFast.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = CamembertTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        elif "DPRQuestionEncoderTokenizer" in tokenizer_class:
            if use_fast:
                ret = DPRQuestionEncoderTokenizerFast.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = DPRQuestionEncoderTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        elif "DPRContextEncoderTokenizer" in tokenizer_class:
            if use_fast:
                ret = DPRContextEncoderTokenizerFast.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
            else:
                ret = DPRContextEncoderTokenizer.from_pretrained(
                    pretrained_model_name_or_path, **kwargs)
        if ret is None:
            raise Exception("Unable to load tokenizer")
        else:
            return ret
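Usage mirrors the earlier variant, with the added revision pin (illustrative):

# Pin a specific model revision from the HuggingFace hub.
tokenizer = Tokenizer.load("facebook/dpr-ctx_encoder-single-nq-base",
                           revision="main", use_fast=True)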
Example #10
    def __init__(self, model_name, tokenizer_name=None, device='cuda:0'):
        self.device = device
        self.model = DPRContextEncoder.from_pretrained(model_name)
        self.model.to(self.device)
        self.tokenizer = DPRContextEncoderTokenizer.from_pretrained(
            tokenizer_name or model_name)
Example #11
class DPRIndex(DocumentChunker):
    '''
    Class for indexing and searching documents, using a combination of
    vectors produced by DPR and keyword matching from Elastic TF-IDF. As a
    subclass of DocumentChunker, this class automatically handles document
    chunking as well.
    '''

    INDEX_NAME = 'dense-passage-retrieval'
    D = 768
    context_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
        'facebook/dpr-ctx_encoder-single-nq-base')
    context_model = DPRContextEncoder.from_pretrained(
        'facebook/dpr-ctx_encoder-single-nq-base', return_dict=True)
    question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
        'facebook/dpr-question_encoder-single-nq-base')
    question_model = DPRQuestionEncoder.from_pretrained(
        'facebook/dpr-question_encoder-single-nq-base', return_dict=True)

    def __init__(self, documents: List[DPRDocument]):
        super().__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if self.device == 'cuda':
            # The original referenced an undefined self.reader_model here;
            # move the two encoders defined above onto the GPU instead.
            self.context_model = self.context_model.cuda()
            self.question_model = self.question_model.cuda()
        self.faiss_index = faiss.IndexFlatIP(self.D)
        self._setup_elastic_index()
        self._build_index(documents)

    def _setup_elastic_index(self):
        '''Sets up the Elastic Index. Deletes old ones if needed.'''
        self.es = Elasticsearch()
        if self.es.indices.exists(self.INDEX_NAME):
            logging.warning(f'Deleting old index for {self.INDEX_NAME}.')
            self.es.indices.delete(self.INDEX_NAME)
        self.es.indices.create(index=self.INDEX_NAME)

    def _build_index(self, documents):
        '''
        Initializes the data structure to keep track of which chunks
        correspond to which documents.
        '''
        self.documents = documents
        self.doc_bodies = [doc.body for doc in self.documents]
        self.chunks = []
        self.chunk_index = {}  # {chunk: document}
        self.inverse_chunk_index = {}  # {document: [chunks]}
        chunk_counter = 0
        for doc_counter, doc_body in tqdm(enumerate(self.doc_bodies),
                                          total=len(self.doc_bodies)):
            self.inverse_chunk_index[doc_counter] = []
            chunked_docs = self.chunk_document(doc_body)
            self.chunks.extend(chunked_docs)
            for chunked_doc in chunked_docs:
                chunk_embedding = self.embed_context(chunked_doc)
                self.faiss_index.add(chunk_embedding)
                self.es.create(self.INDEX_NAME,
                               id=chunk_counter,
                               body={'chunk': chunked_doc})
                self.chunk_index[chunk_counter] = doc_counter
                self.inverse_chunk_index[doc_counter].append(chunk_counter)
                chunk_counter += 1
        self.total_docs = len(self.documents)
        self.total_chunks = len(self.chunks)

    def embed_question(self, question: str):
        '''Embed the question in vector space with the question encoder.'''
        input_ids = self.question_tokenizer(
            question, return_tensors='pt')['input_ids'].to(self.device)
        embeddings = self.question_model(
            input_ids).pooler_output.detach().cpu().numpy()
        return embeddings

    def embed_context(self, context: str):
        '''Embed the context (doc) in vector space with the context encoder.'''
        input_ids = self.context_tokenizer(
            context, return_tensors='pt')['input_ids'].to(self.device)
        embeddings = self.context_model(
            input_ids).pooler_output.detach().cpu().numpy()
        return embeddings

    def search_dense_index(self, question: str, k: int = 5):
        '''
        Search the vector index by encoding the question and then performing
        nearest neighbor on the FAISS index of context vectors.

        Args:
            question (str):
                The natural language question, e.g. `who is bill gates?`
            k (int):
                The number of documents to return from the index.
        '''
        if k > self.total_chunks:
            k = self.total_chunks
        question_embedding = self.embed_question(question)
        dists, chunk_ids = self.faiss_index.search(question_embedding, k)
        dists, chunk_ids = list(dists[0]), list(chunk_ids[0])
        dists = list(map(float, dists))  # For Flask
        structured_response = []
        for dist, chunk_id in zip(dists, chunk_ids):
            chunk = self.chunks[chunk_id]
            document_id = self.chunk_index[chunk_id]
            document = self.documents[document_id]
            blob = {
                'document': document,
                'document_id': document_id,
                'chunk': chunk,
                'chunk_id': int(chunk_id),  # For Flask
                'faiss_dist': dist
            }
            structured_response.append(blob)
        return structured_response

    def search_sparse_index(self, query):
        body = {'size': 10, 'query': {'match': {'chunk': query}}}
        results = self.es.search(index=self.INDEX_NAME, body=body)
        hits = results['hits']['hits']
        return hits

    def _merge_results(self, sparse_results, dense_results):
        '''Merges the results of sparse and dense retrieval.'''
        results_index = {}
        for sparse_result in sparse_results:
            id, score = sparse_result['_id'], sparse_result['_score']
            id = int(id)
            results_index[id] = {'elastic_score': score}
        for dense_result in dense_results:
            id, score = dense_result['chunk_id'], dense_result['faiss_dist']
            if id in results_index:
                results_index[id]['faiss_dist'] = score
            else:
                results_index[id] = {'faiss_dist': score}
        results = []
        for chunk_id, scores in results_index.items():
            document_id = self.chunk_index[chunk_id]
            document = self.documents[document_id]
            chunk = self.chunks[chunk_id]
            doc_profile = document.to_dict()
            result = {
                'chunk_id': chunk_id,
                'chunk': chunk,
                'document_id': document_id,
                'document': doc_profile,
                'scores': scores
            }
            results.append(result)
        return results

    def search_dual_index(self, query: str):
        '''Search both the sparse and dense indices and merge the results.'''
        sparse_result = self.search_sparse_index(query)
        dense_result = self.search_dense_index(query)
        merged_results = self._merge_results(sparse_result, dense_result)
        return merged_results
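A minimal end-to-end sketch (not from the original source), assuming DPRDocument exposes a body attribute and a to_dict() method, that DocumentChunker provides chunk_document, and that an Elasticsearch instance is running locally:

# Build the hybrid index and query both retrieval paths at once.
docs = [DPRDocument(body="Bill Gates co-founded Microsoft in 1975."),      # hypothetical constructor
        DPRDocument(body="Hamlet is a tragedy by William Shakespeare.")]
index = DPRIndex(docs)
for result in index.search_dual_index("who founded microsoft?"):
    print(result['chunk_id'], result['scores'])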