Example #1
    def encode_bert(self,
                    query_tok,
                    query_mask,
                    doc_tok,
                    doc_mask,
                    customBert=None):
        BATCH, QLEN = query_tok.shape
        DIFF = 3  # = [CLS] and 2x[SEP]
        maxlen = self.bert.config.max_position_embeddings
        MAX_DOC_TOK_LEN = maxlen - QLEN - DIFF

        doc_toks, sbcount = modeling_util.subbatch(doc_tok, MAX_DOC_TOK_LEN)
        doc_mask, _ = modeling_util.subbatch(doc_mask, MAX_DOC_TOK_LEN)

        query_toks = torch.cat([query_tok] * sbcount, dim=0)
        query_mask = torch.cat([query_mask] * sbcount, dim=0)

        CLSS = torch.full_like(query_toks[:, :1],
                               self.tokenizer.vocab['[CLS]'])
        SEPS = torch.full_like(query_toks[:, :1],
                               self.tokenizer.vocab['[SEP]'])
        ONES = torch.ones_like(query_mask[:, :1])
        NILS = torch.zeros_like(query_mask[:, :1])

        # build BERT input sequences
        toks = torch.cat([CLSS, doc_toks, SEPS, query_toks, SEPS], dim=1)
        mask = torch.cat([ONES, doc_mask, ONES, query_mask, ONES], dim=1)
        segment_ids = torch.cat([NILS] * (2 + doc_toks.shape[1]) + [ONES] *
                                (1 + QLEN),
                                dim=1)
        toks[toks == -1] = 0  # remove padding (will be masked anyway)

        # execute BERT model
        if customBert is None:
            result = self.bert(toks, segment_ids.long(), mask)
        else:
            result = customBert(toks, segment_ids.long(), mask)

        # extract relevant subsequences for query and doc
        # (the sequence built above is [CLS] doc [SEP] query [SEP], so the
        # document comes first and the query follows its [SEP])
        DLEN = doc_toks.shape[1]
        query_results = [r[:BATCH, DLEN + 2:DLEN + 2 + QLEN] for r in result]
        doc_results = [r[:, 1:DLEN + 1] for r in result]
        doc_results = [
            modeling_util.un_subbatch(r, doc_tok, MAX_DOC_TOK_LEN)
            for r in doc_results
        ]

        # build CLS representation
        cls_results = []
        for layer in result:
            cls_output = layer[:, 0]
            cls_result = []
            for i in range(cls_output.shape[0] // BATCH):
                cls_result.append(cls_output[i * BATCH:(i + 1) * BATCH])
            cls_result = torch.stack(cls_result, dim=2).mean(dim=2)
            cls_results.append(cls_result)

        return cls_results, query_results, doc_results
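
This and the following examples all rely on modeling_util.subbatch / modeling_util.un_subbatch to split documents that exceed the model's input window into equal-width pieces stacked along the batch dimension, and to stitch the piecewise outputs back together afterwards. Those helpers are not shown on this page; the sketch below only illustrates the assumed CEDR-style contract and is not the original implementation.

# Minimal sketch of the helpers the examples assume (CEDR-style); the real
# modeling_util may differ in details.
import math
import torch


def subbatch(toks, maxlen):
    """Split (BATCH, DLEN) token ids into sbcount pieces of equal width
    (each <= maxlen), stacked along the batch dimension."""
    _, dlen = toks.shape
    sbcount = math.ceil(dlen / maxlen)
    if sbcount == 1:
        return toks, 1
    width = math.ceil(dlen / sbcount)
    pieces = []
    for s in range(sbcount):
        piece = toks[:, s * width:(s + 1) * width]
        if piece.shape[1] < width:  # zero-pad the last piece to full width
            pad = torch.zeros_like(toks[:, :width - piece.shape[1]])
            piece = torch.cat([piece, pad], dim=1)
        pieces.append(piece)
    return torch.cat(pieces, dim=0), sbcount  # (BATCH * sbcount, width)


def un_subbatch(embed, toks, maxlen):
    """Reverse subbatch(): reassemble per-piece outputs of shape
    (BATCH * sbcount, width, ...) into (BATCH, DLEN, ...)."""
    batch, dlen = toks.shape
    sbcount = math.ceil(dlen / maxlen)
    if sbcount == 1:
        return embed
    pieces = [embed[s * batch:(s + 1) * batch] for s in range(sbcount)]
    return torch.cat(pieces, dim=1)[:, :dlen]
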
Example #2
    def forward(self, query_tok, query_mask, doc_tok, doc_mask):
        BATCH, QLEN = query_tok.shape
        DIFF = 2  # 2x</s>
        maxlen = 512
        MAX_DOC_TOK_LEN = maxlen - QLEN - DIFF

        doc_toks, sbcount = modeling_util.subbatch(doc_tok, MAX_DOC_TOK_LEN)
        doc_mask, _ = modeling_util.subbatch(doc_mask, MAX_DOC_TOK_LEN)

        query_toks = torch.cat([query_tok] * sbcount, dim=0)
        query_mask = torch.cat([query_mask] * sbcount, dim=0)

        SEPS = torch.full_like(query_toks[:, :1], self.tokenizer.eos_token_id)
        ONES = torch.ones_like(query_mask[:, :1])

        # build mBART input sequences
        toks = torch.cat([query_toks, SEPS, doc_toks, SEPS], dim=1)
        mask = torch.cat([query_mask, ONES, doc_mask, ONES], dim=1)
        toks[toks == -1] = 0  # remove padding (will be masked anyway)

        # execute mBART model
        result = self.mbart(input_ids=toks, attention_mask=mask)

        cls_output = result.logits
        cls_result = []
        for i in range(cls_output.shape[0] // BATCH):
            cls_result.append(cls_output[i * BATCH:(i + 1) * BATCH])
        cls_result = torch.stack(cls_result, dim=2).mean(dim=2)
        return cls_result
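
The sub-batch averaging loop at the end of forward() above (and in the other examples) can also be written as a single reshape. The helper below is only an equivalent, vectorized restatement of that loop, not code from the original model:

import torch


def average_subbatches(cls_output: torch.Tensor, batch: int) -> torch.Tensor:
    """Vectorized form of the averaging loop: cls_output has shape
    (batch * sbcount, num_labels), with sub-batches stacked block-wise
    along dim 0 as modeling_util.subbatch produces them."""
    sbcount = cls_output.shape[0] // batch
    return cls_output.reshape(sbcount, batch, -1).mean(dim=0)  # (batch, num_labels)
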
Example #3
    def generate(self, query_tok, query_mask, doc_tok, doc_mask):
        BATCH, QLEN = query_tok.shape
        DIFF = 2  # 2x</s>
        maxlen = 512
        MAX_DOC_TOK_LEN = maxlen - QLEN - DIFF

        doc_toks, sbcount = modeling_util.subbatch(doc_tok, MAX_DOC_TOK_LEN)
        doc_mask, _ = modeling_util.subbatch(doc_mask, MAX_DOC_TOK_LEN)

        query_toks = torch.cat([query_tok] * sbcount, dim=0)
        query_mask = torch.cat([query_mask] * sbcount, dim=0)

        SEPS = torch.full_like(query_toks[:, :1],
                               self.tokenizer.get_vocab()['</s>'])
        ONES = torch.ones_like(query_mask[:, :1])

        # build mT5 input sequences
        toks = torch.cat([query_toks, SEPS, doc_toks, SEPS], dim=1)
        mask = torch.cat([query_mask, ONES, doc_mask, ONES], dim=1)
        toks[toks == -1] = 0  # remove padding (will be masked anyway)
        result = self.mt5.generate(input_ids=toks,
                                   attention_mask=mask,
                                   output_scores=True,
                                   return_dict_in_generate=True,
                                   max_length=2)
        cls_output = result.scores[0]
        cls_result = []
        for i in range(cls_output.shape[0] // BATCH):
            cls_result.append(cls_output[i * BATCH:(i + 1) * BATCH])
        cls_result = torch.stack(cls_result, dim=2).mean(dim=2)
        return cls_result
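
generate() above produces exactly one decoding step (max_length=2), so result.scores[0] holds the logits over the entire vocabulary for that first generated token, averaged over document sub-batches before being returned. How these logits become a single relevance score is not shown on this page; a monoT5-style reduction (an assumption, with caller-chosen verbalizer token ids) would look roughly like this:

# Hypothetical post-processing of the returned first-step logits (an
# assumption, monoT5-style): reduce them to one score per document by
# comparing two caller-chosen "verbalizer" token ids.
import torch


def verbalizer_score(first_step_logits: torch.Tensor,
                     pos_token_id: int,
                     neg_token_id: int) -> torch.Tensor:
    """first_step_logits: (BATCH, vocab_size), e.g. the cls_result above."""
    pair = first_step_logits[:, [pos_token_id, neg_token_id]]
    # log-probability of the "relevant" token after a softmax over the pair
    return torch.log_softmax(pair, dim=1)[:, 0]
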
Example #4
    def encode_bert(self, query_tok, doc_tok):
        # query_tok includes CLS token
        # doc_tok includes SEP token
        BATCH = query_tok.shape[0]
        QLEN = 20
        DIFF = 3
        maxlen = self.bert.config.max_position_embeddings # 512
        MAX_DLEN = maxlen - QLEN - DIFF # 489
        # 1(CLS) + 20(Q) + 1(SEP) + 489(D) + 1(SEP) = 512

        query_mask = torch.where(query_tok > 0, torch.ones_like(query_tok), torch.zeros_like(query_tok))
        doc_mask = torch.where(doc_tok > 0, torch.ones_like(doc_tok), torch.zeros_like(doc_tok))

        doc_toks, sbcount = modeling_util.subbatch(doc_tok, MAX_DLEN)
        doc_masks, _ = modeling_util.subbatch(doc_mask, MAX_DLEN)
        query_toks = torch.cat([query_tok] * sbcount, dim=0)
        query_masks = torch.cat([query_mask] * sbcount, dim=0)

        CLSS = torch.full_like(query_toks[:, :1], self.tokenizer.vocab['[CLS]'])
        SEPS = torch.full_like(query_toks[:, :1], self.tokenizer.vocab['[SEP]'])
        ONES = torch.ones_like(query_masks[:, :1])
        NILS = torch.zeros_like(query_masks[:, :1])

        # build BERT input sequences
        toks = torch.cat([CLSS, query_toks, SEPS, doc_toks, SEPS], dim=1)
        segs = torch.cat([NILS] * (2 + QLEN) + [ONES] * (1 + doc_toks.shape[1]), dim=1)
        # segs = torch.cat([NILS] * (2 + QLEN) + [doc_masks, ONES], dim=1)
        masks = torch.cat([ONES, query_masks, ONES, doc_masks, ONES], dim=1)

        # execute BERT model
        result = self.bert(toks, segs.long(), masks)

        # extract relevant subsequences for query and doc
        query_results = [r[:BATCH, 1:QLEN+1] for r in result]   # (N, QLEN)
        doc_results = [r[:, QLEN+2:-1] for r in result] # (N, MAX_DLEN)
        doc_results = [modeling_util.un_subbatch(r, doc_tok, MAX_DLEN) for r in doc_results]

        # build CLS representation
        cls_results = []
        for layer in result:
            cls_output = layer[:, 0]
            cls_result = []
            for i in range(cls_output.shape[0] // BATCH):
                cls_result.append(cls_output[i*BATCH:(i+1)*BATCH])
            cls_result = torch.stack(cls_result, dim=2).mean(dim=2)
            cls_results.append(cls_result)
            
        return query_tok, doc_tok, cls_results, query_results, doc_results
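
Unlike Example #1, this encode_bert() receives no masks and derives them from the zero-padded token ids via torch.where. The same mask can be obtained from a plain comparison cast; the snippet below only shows that equivalence for reference:

import torch


def padding_mask(tok: torch.Tensor) -> torch.Tensor:
    """Equivalent to torch.where(tok > 0, ones_like(tok), zeros_like(tok)):
    1 where a real token id is present, 0 where the sequence is padded."""
    return (tok > 0).to(tok.dtype)


toks = torch.tensor([[101, 2054, 2003, 0, 0]])
print(padding_mask(toks))  # tensor([[1, 1, 1, 0, 0]])
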
Example #5
    def forward(self, query_tok, query_mask, doc_tok, doc_mask):
        BATCH, QLEN = query_tok.shape
        DIFF = 2  # 2x</s>
        maxlen = 512
        MAX_DOC_TOK_LEN = maxlen - QLEN - DIFF

        doc_toks, sbcount = modeling_util.subbatch(doc_tok, MAX_DOC_TOK_LEN)
        doc_mask, _ = modeling_util.subbatch(doc_mask, MAX_DOC_TOK_LEN)

        query_toks = torch.cat([query_tok] * sbcount, dim=0)
        query_mask = torch.cat([query_mask] * sbcount, dim=0)

        SEPS = torch.full_like(query_toks[:, :1], self.tokenizer.eos_token_id)
        ONES = torch.ones_like(query_mask[:, :1])

        # build mT5 input sequences
        toks = torch.cat([query_toks, SEPS, doc_toks, SEPS], dim=1)
        mask = torch.cat([query_mask, ONES, doc_mask, ONES], dim=1)
        toks[toks == -1] = 0  # remove padding (will be masked anyway)

        decoder_input_ids = shift_tokens_right(
            toks, self.config.pad_token_id, self.config.decoder_start_token_id)
        # execute mT5 model
        outputs = self.mt5(input_ids=toks,
                           attention_mask=mask,
                           decoder_input_ids=decoder_input_ids)
        hidden_states = outputs[0]  # last hidden state

        eos_mask = toks.eq(self.config.eos_token_id)

        if len(torch.unique(eos_mask.sum(1))) > 1:
            raise ValueError(
                "All examples must have the same number of <eos> tokens.")
        sentence_representation = hidden_states[eos_mask, :].view(
            hidden_states.size(0), -1, hidden_states.size(-1))[:, -1, :]
        logits = self.classification_head(sentence_representation)
        cls_output = logits
        cls_result = []
        for i in range(cls_output.shape[0] // BATCH):
            cls_result.append(cls_output[i * BATCH:(i + 1) * BATCH])
        cls_result = torch.stack(cls_result, dim=2).mean(dim=2)
        return cls_result
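
shift_tokens_right is imported from code not shown on this page. Its three-argument call matches the helper in Hugging Face's BART/mBART modeling code, which places the decoder start token at position 0 and shifts everything else one position to the right to form teacher-forcing decoder inputs; a sketch of that assumed behaviour:

# Sketch of the assumed shift_tokens_right() behaviour (mirrors the
# three-argument helper in Hugging Face's BART/mBART code).
import torch


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int,
                       decoder_start_token_id: int) -> torch.Tensor:
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 1:] = input_ids[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    # label tensors sometimes use -100 for ignored positions; map to padding
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted
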
Example #6
    def encode_bert(self, query_tok, query_mask, doc_tok, doc_mask):
        BATCH, QLEN = query_tok.shape
        DIFF = 3  # = [CLS] and 2x[SEP]
        maxlen = self.bert.config.max_position_embeddings
        MAX_DOC_TOK_LEN = maxlen - QLEN - DIFF

        doc_toks, sbcount = modeling_util.subbatch(doc_tok, MAX_DOC_TOK_LEN)
        doc_mask, _ = modeling_util.subbatch(doc_mask, MAX_DOC_TOK_LEN)

        query_toks = torch.cat([query_tok] * sbcount, dim=0)
        query_mask = torch.cat([query_mask] * sbcount, dim=0)

        CLSS = torch.full_like(query_toks[:, :1],
                               self.tokenizer.vocab['[CLS]'])
        SEPS = torch.full_like(query_toks[:, :1],
                               self.tokenizer.vocab['[SEP]'])
        ONES = torch.ones_like(query_mask[:, :1])
        NILS = torch.zeros_like(query_mask[:, :1])

        # build separate BERT input sequences for query and doc
        q_toks = torch.cat([CLSS, query_toks, SEPS], dim=1)
        q_mask = torch.cat([ONES, query_mask, ONES], dim=1)
        q_segid = torch.cat([NILS] * (2 + QLEN), dim=1)
        q_toks[q_toks == -1] = 0

        d_toks = torch.cat([CLSS, doc_toks, SEPS], dim=1)
        d_mask = torch.cat([ONES, doc_mask, ONES], dim=1)
        d_segid = torch.cat([NILS] * (2 + doc_toks.shape[1]), dim=1)
        d_toks[d_toks == -1] = 0

        # execute BERT model
        q_result_tuple = self.bert(q_toks, q_mask, q_segid.long())
        d_result_tuple = self.bert(d_toks, d_mask, d_segid.long())
        q_result = q_result_tuple[2]
        d_result = d_result_tuple[2]

        # extract relevant subsequences for query and doc
        query_results = [r[:BATCH, 1:-1] for r in q_result]
        doc_results = [r[:, 1:-1] for r in d_result]
        doc_results = [
            modeling_util.un_subbatch(r, doc_tok, MAX_DOC_TOK_LEN)
            for r in doc_results
        ]

        # build CLS representation
        q_cls_results = []
        for layer in q_result:
            cls_output = layer[:, 0]
            cls_result = []
            for i in range(cls_output.shape[0] // BATCH):
                cls_result.append(cls_output[i * BATCH:(i + 1) * BATCH])
            cls_result = torch.stack(cls_result, dim=2).mean(dim=2)
            q_cls_results.append(cls_result)

        d_cls_results = []
        for layer in d_result:
            cls_output = layer[:, 0]
            cls_result = []
            for i in range(cls_output.shape[0] // BATCH):
                cls_result.append(cls_output[i * BATCH:(i + 1) * BATCH])
            cls_result = torch.stack(cls_result, dim=2).mean(dim=2)
            d_cls_results.append(cls_result)

        return q_cls_results, d_cls_results, query_results, doc_results
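
This last example encodes query and document in two separate BERT passes (a dual-encoder setup) and returns per-layer [CLS] vectors for both. How they are combined downstream is not shown; one common option, given here only as an assumed illustration, is a similarity between the last-layer [CLS] vectors:

import torch
import torch.nn.functional as F


def cls_similarity(q_cls_results, d_cls_results):
    """q_cls_results / d_cls_results: lists of (BATCH, hidden) tensors,
    one per layer, as returned by encode_bert() above."""
    q_last, d_last = q_cls_results[-1], d_cls_results[-1]
    return F.cosine_similarity(q_last, d_last, dim=1)  # one score per pair
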