Example #1
0
class UnsupervisedTransformerReranker(Reranker):
    """Rerank texts by collapsing a query/document similarity matrix to a score.

    The query and every candidate text are encoded with a ``LongBatchEncoder``.
    For each text, ``sim_matrix_provider.compute_matrix`` builds a similarity
    matrix against the query, which is reduced to a single number by the
    reduction selected through ``method`` (one of the keys of ``methods``).
    """

    # Reductions from a similarity tensor to a Python scalar, keyed by the
    # ``method`` constructor argument.
    methods = {
        'max': lambda sim: sim.max().item(),
        'mean': lambda sim: sim.mean().item(),
        'absmean': lambda sim: sim.abs().mean().item(),
        'absmax': lambda sim: sim.abs().max().item(),
    }

    def __init__(self,
                 model: PreTrainedModel,
                 tokenizer: BatchTokenizer,
                 sim_matrix_provider: SimilarityMatrixProvider,
                 method: str = 'max',
                 clean_special: bool = True,
                 argmax_only: bool = False):
        """Store the scoring configuration and build the helper objects.

        Args:
            model: transformer handed to ``LongBatchEncoder``; the device of
                its first parameter is recorded as ``self.device``.
            tokenizer: tokenizer wrapper; its ``.tokenizer`` attribute is
                passed to ``SpecialTokensCleaner``.
            sim_matrix_provider: builds the query/document similarity matrix.
            method: key into ``methods`` selecting the matrix-to-score reduction.
            clean_special: when true, each document encoding is passed through
                ``SpecialTokensCleaner.clean`` before scoring.
            argmax_only: when true, every text not tied for the best score is
                re-scored to ``best - 10000`` after scoring.
        """
        assert method in self.methods, 'inappropriate scoring method'
        self.model, self.tokenizer = model, tokenizer
        self.encoder = LongBatchEncoder(model, tokenizer)
        self.sim_matrix_provider = sim_matrix_provider
        self.method, self.clean_special = method, clean_special
        self.cleaner = SpecialTokensCleaner(tokenizer.tokenizer)
        first_param = next(self.model.parameters(), None)
        self.device = first_param.device
        self.argmax_only = argmax_only

    @torch.no_grad()
    def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
        """Score ``texts`` against ``query`` and return copies, best first.

        The input list is deep-copied before scores are written, so the
        caller's ``Text`` objects are left untouched. A text whose similarity
        matrix is empty along dimension 1 receives a score of -10000.

        Returns:
            The deep-copied texts with ``score`` set, sorted by descending
            score (ties keep their input order).
        """
        encoded_query = self.encoder.encode_single(query)
        encoded_documents = self.encoder.encode(texts)
        ranked = deepcopy(texts)

        best_score = None
        for encoding, text in zip(encoded_documents, ranked):
            if self.clean_special:
                encoding = self.cleaner.clean(encoding)
            sim = self.sim_matrix_provider.compute_matrix(encoded_query, encoding)
            if sim.size(1) == 0:
                # Nothing to reduce over: fall back to a sentinel low score.
                value = -10000
            else:
                value = self.methods[self.method](sim)
            text.score = value
            # Keep the first-seen maximum, exactly as ``max(best, value)`` would.
            if best_score is None or value > best_score:
                best_score = value

        if self.argmax_only:
            for text in ranked:
                if text.score != best_score:
                    text.score = best_score - 10000

        return sorted(ranked, key=lambda t: t.score, reverse=True)
Example #2
0
class UnsupervisedTransformerReranker(Reranker):
    """Score documents against a query using sentence-level encodings.

    The query and each document are split into sentences with
    ``SpacySenticizer``, tokenized with ``batch_encode_plus``, and run through
    ``LongBatchEncoder.align``. The vector at position 0 of each sentence
    encoding is used as that sentence's representation. Query and document
    representations are compared with
    ``sim_matrix_provider.compute_matrix_v2``.
    """

    # Reductions from a similarity tensor to a Python scalar, keyed by the
    # ``method`` constructor argument.
    methods = dict(max=lambda x: x.max().item(),
                   mean=lambda x: x.mean().item(),
                   absmean=lambda x: x.abs().mean().item(),
                   absmax=lambda x: x.abs().max().item())

    def __init__(self,
                 model: PreTrainedModel,
                 tokenizer: PreTrainedTokenizer,
                 sim_matrix_provider: SimilarityMatrixProvider,
                 method: str = 'max',
                 clean_special: bool = True,
                 argmax_only: bool = False):
        """Store the configuration and build the encoding helpers.

        Args:
            model: transformer used by ``LongBatchEncoder``; its
                ``config.max_position_embeddings`` caps the encoder's
                sequence length.
            tokenizer: tokenizer wrapper exposing the underlying tokenizer as
                ``.tokenizer``. NOTE(review): annotated as
                ``PreTrainedTokenizer`` but accessed via ``.tokenizer``, so it
                looks like a wrapper type; confirm the annotation.
            sim_matrix_provider: provides ``compute_matrix_v2`` for scoring.
            method: key into ``methods``; validated here but not consulted by
                the current ``rerank``.
            clean_special: stored; not consulted by the current ``rerank``.
            argmax_only: stored; not consulted by the current ``rerank``.
        """
        assert method in self.methods, 'inappropriate scoring method'
        self.model = model
        # Cap encoder inputs at the model's positional-embedding limit.
        max_seq_length = self.model.config.max_position_embeddings
        self.tokenizer = tokenizer
        self.batch_encoder = LongBatchEncoder(model, tokenizer,
                                              max_seq_length=max_seq_length)
        self.sim_matrix_provider = sim_matrix_provider
        self.method = method
        self.clean_special = clean_special
        self.cleaner = SpecialTokensCleaner(tokenizer.tokenizer)
        self.device = next(self.model.parameters(), None).device
        self.argmax_only = argmax_only
        # Build the sentence splitter once instead of on every ``split`` call
        # (``rerank`` calls ``split`` twice).
        self.senticizer = SpacySenticizer()

    def split(self, documents: List[Text]) -> List[Text]:
        """Yield one tokenizer encoding per document, covering its sentences.

        Each document's ``text`` is split into sentences, and the sentences of
        that document are encoded together with ``batch_encode_plus`` using a
        128-token ``max_length``.

        NOTE(review): despite the ``List[Text]`` annotation this is a
        generator of tokenizer outputs (indexed by ``'input_ids'`` in
        ``rerank``).
        """
        for doc in documents:
            sentences = self.senticizer(doc.text)
            # ``max_length`` is the tokenizer keyword for the length cap; the
            # previous ``max_len`` keyword is not recognized by the tokenizer.
            yield self.tokenizer.tokenizer.batch_encode_plus(sentences,
                                                             max_length=128)

    @torch.no_grad()
    def rerank(self, query: Query, documents: List[Text]) -> List[Text]:
        """Return raw similarity scores between ``query`` and ``documents``.

        Document representations are stacked in batches of 128 and compared
        with the query representation through
        ``sim_matrix_provider.compute_matrix_v2``. The squeezed values of each
        batch's matrix are collected in document order: appended as one item
        when ``matrix.size(1) == 1``, otherwise extended element-wise.

        NOTE(review): despite the ``List[Text]`` annotation this returns the
        values produced by ``matrix.squeeze().tolist()``, not reordered
        ``Text`` objects. ``method``, ``clean_special`` and ``argmax_only`` are
        not used here.

        NOTE(review): ``torch.stack`` needs every representation in a batch to
        have the same shape, which presumably means every document must
        produce the same number of sentences; confirm.
        """
        query_features = next(self.split([query]))
        encoded_query = self.batch_encoder.align(query_features['input_ids'])
        # Position 0 of each sentence encoding is the sentence representation.
        encoded_query = encoded_query[:, 0, :]

        document_features = [
            torch.squeeze(self.batch_encoder.align(features['input_ids'])[:, 0, :])
            for features in self.split(documents)
        ]

        result = []
        for b in batch(document_features, 128):
            matrix = self.sim_matrix_provider.compute_matrix_v2(encoded_query,
                                                                torch.stack(b))
            scores = matrix.squeeze().tolist()
            if matrix.size(1) == 1:
                # A single column squeezes down to one value: keep it as one item.
                result.append(scores)
            else:
                result.extend(scores)
        return result