Example #1
    def encode_bert(self,
                    query_tok,
                    query_mask,
                    doc_tok,
                    doc_mask,
                    customBert=None):
        BATCH, QLEN = query_tok.shape
        DIFF = 3  # = [CLS] and 2x[SEP]
        maxlen = self.bert.config.max_position_embeddings
        MAX_DOC_TOK_LEN = maxlen - QLEN - DIFF

        doc_toks, sbcount = modeling_util.subbatch(doc_tok, MAX_DOC_TOK_LEN)
        doc_mask, _ = modeling_util.subbatch(doc_mask, MAX_DOC_TOK_LEN)

        query_toks = torch.cat([query_tok] * sbcount, dim=0)
        query_mask = torch.cat([query_mask] * sbcount, dim=0)

        CLSS = torch.full_like(query_toks[:, :1],
                               self.tokenizer.vocab['[CLS]'])
        SEPS = torch.full_like(query_toks[:, :1],
                               self.tokenizer.vocab['[SEP]'])
        ONES = torch.ones_like(query_mask[:, :1])
        NILS = torch.zeros_like(query_mask[:, :1])

        # build BERT input sequences
        toks = torch.cat([CLSS, doc_toks, SEPS, query_toks, SEPS], dim=1)
        mask = torch.cat([ONES, doc_mask, ONES, query_mask, ONES], dim=1)
        # segment_ids = torch.cat([NILS] * (2 + QLEN) + [ONES] * (1 + doc_toks.shape[1]), dim=1)
        segment_ids = torch.cat([NILS] * (2 + doc_toks.shape[1]) + [ONES] *
                                (1 + QLEN),
                                dim=1)
        toks[toks == -1] = 0  # remove padding (will be masked anyway)

        # print(MAX_DOC_TOK_LEN, doc_tok.shape)

        # execute BERT model
        if not customBert:
            result = self.bert(toks, segment_ids.long(), mask)
        else:
            result = customBert(toks, segment_ids.long(), mask)

        # extract relevant subsequences for query and doc
        # (inputs were built as [CLS] doc [SEP] query [SEP], so the document
        #  tokens come first and the query tokens follow them)
        DLEN = doc_toks.shape[1]
        doc_results = [r[:, 1:DLEN + 1] for r in result]
        doc_results = [
            modeling_util.un_subbatch(r, doc_tok, MAX_DOC_TOK_LEN)
            for r in doc_results
        ]
        query_results = [r[:BATCH, DLEN + 2:-1] for r in result]

        # build CLS representation (average the [CLS] vector over document sub-batches)
        cls_results = []
        for layer in result:
            cls_output = layer[:, 0]
            cls_result = []
            for i in range(cls_output.shape[0] // BATCH):
                cls_result.append(cls_output[i * BATCH:(i + 1) * BATCH])
            cls_result = torch.stack(cls_result, dim=2).mean(dim=2)
            cls_results.append(cls_result)

        return cls_results, query_results, doc_results
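
All of these examples lean on modeling_util.subbatch / un_subbatch to fit documents longer than BERT's position limit. The real helpers are not shown here; the following is only a minimal sketch of the behaviour they are assumed to have (split the document into maxlen-sized passages stacked along the batch dimension, then stitch the per-token outputs back together):

import math
import torch
import torch.nn.functional as F

def subbatch(toks, maxlen):
    # split over-long documents into passages of at most `maxlen` tokens,
    # stacked along the batch dimension; returns (passages, passage_count)
    _, dlen = toks.shape
    count = math.ceil(dlen / maxlen) if dlen > 0 else 1
    if count == 1:
        return toks, 1
    toks = F.pad(toks, (0, count * maxlen - dlen), value=-1)  # -1 marks padding
    chunks = [toks[:, i * maxlen:(i + 1) * maxlen] for i in range(count)]
    return torch.cat(chunks, dim=0), count

def un_subbatch(embed, toks, maxlen):
    # inverse for per-token outputs: re-join the passage outputs along the
    # sequence dimension and trim back to the original document length
    batch, dlen = toks.shape
    count = math.ceil(dlen / maxlen) if dlen > 0 else 1
    if count == 1:
        return embed
    chunks = [embed[i * batch:(i + 1) * batch] for i in range(count)]
    return torch.cat(chunks, dim=1)[:, :dlen]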
Example #2
    def encode_bert(self, query_tok, doc_tok):
        # query_tok includes CLS token
        # doc_tok includes SEP token
        BATCH = query_tok.shape[0]
        QLEN = 20
        DIFF = 3
        maxlen = self.bert.config.max_position_embeddings # 512
        MAX_DLEN = maxlen - QLEN - DIFF # 489
        # 1(CLS) + 20(Q) + 1(SEP) + 489(D) + 1(SEP) = 512

        query_mask = torch.where(query_tok > 0, torch.ones_like(query_tok), torch.zeros_like(query_tok))
        doc_mask = torch.where(doc_tok > 0, torch.ones_like(doc_tok), torch.zeros_like(doc_tok))

        doc_toks, sbcount = modeling_util.subbatch(doc_tok, MAX_DLEN)
        doc_masks, _ = modeling_util.subbatch(doc_mask, MAX_DLEN)
        query_toks = torch.cat([query_tok] * sbcount, dim=0)
        query_masks = torch.cat([query_mask] * sbcount, dim=0)

        CLSS = torch.full_like(query_toks[:, :1], self.tokenizer.vocab['[CLS]'])
        SEPS = torch.full_like(query_toks[:, :1], self.tokenizer.vocab['[SEP]'])
        ONES = torch.ones_like(query_masks[:, :1])
        NILS = torch.zeros_like(query_masks[:, :1])

        # build BERT input sequences
        toks = torch.cat([CLSS, query_toks, SEPS, doc_toks, SEPS], dim=1)
        segs = torch.cat([NILS] * (2 + QLEN) + [ONES] * (1 + doc_toks.shape[1]), dim=1)
        # segs = torch.cat([NILS] * (2 + QLEN) + [doc_masks, ONES], dim=1)
        masks = torch.cat([ONES, query_masks, ONES, doc_masks, ONES], dim=1)

        # execute BERT 
        result = self.bert(toks, segs.long(), masks)

        # extract relevant subsequences for query and doc
        query_results = [r[:BATCH, 1:QLEN+1] for r in result]   # (N, QLEN)
        doc_results = [r[:, QLEN+2:-1] for r in result] # (N, MAX_DLEN)
        doc_results = [modeling_util.un_subbatch(r, doc_tok, MAX_DLEN) for r in doc_results]

        # build CLS representation
        cls_results = []
        for layer in result:
            cls_output = layer[:, 0]
            cls_result = []
            for i in range(cls_output.shape[0] // BATCH):
                cls_result.append(cls_output[i*BATCH:(i+1)*BATCH])
            cls_result = torch.stack(cls_result, dim=2).mean(dim=2)
            cls_results.append(cls_result)
            
        return query_tok, doc_tok, cls_results, query_results, doc_results
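
For context, a hedged sketch of how the returned representations could feed a scoring head. The head below (VanillaRankerHead, its hidden size and dropout) is illustrative and not taken from the source; it simply scores each query-document pair from the final-layer CLS vector:

import torch

class VanillaRankerHead(torch.nn.Module):
    # illustrative only: score a query-document pair from the last CLS vector
    def __init__(self, hidden_size=768, dropout=0.1):
        super().__init__()
        self.dropout = torch.nn.Dropout(dropout)
        self.score = torch.nn.Linear(hidden_size, 1)

    def forward(self, cls_results):
        # cls_results: list of (BATCH, hidden_size) tensors, one per BERT layer
        return self.score(self.dropout(cls_results[-1]))  # (BATCH, 1)

# e.g. (hypothetical caller):
#   _, _, cls_reps, q_reps, d_reps = ranker.encode_bert(query_tok, doc_tok)
#   scores = head(cls_reps)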
Example #3
    def encode_bert(self, query_tok, query_mask, doc_tok, doc_mask):
        BATCH, QLEN = query_tok.shape
        DIFF = 3  # = [CLS] and 2x[SEP]
        maxlen = self.bert.config.max_position_embeddings
        MAX_DOC_TOK_LEN = maxlen - QLEN - DIFF

        doc_toks, sbcount = modeling_util.subbatch(doc_tok, MAX_DOC_TOK_LEN)
        doc_mask, _ = modeling_util.subbatch(doc_mask, MAX_DOC_TOK_LEN)

        query_toks = torch.cat([query_tok] * sbcount, dim=0)
        query_mask = torch.cat([query_mask] * sbcount, dim=0)

        CLSS = torch.full_like(query_toks[:, :1],
                               self.tokenizer.vocab['[CLS]'])
        SEPS = torch.full_like(query_toks[:, :1],
                               self.tokenizer.vocab['[SEP]'])
        ONES = torch.ones_like(query_mask[:, :1])
        NILS = torch.zeros_like(query_mask[:, :1])

        # build BERT input sequences query & doc
        q_toks = torch.cat([CLSS, query_toks, SEPS], dim=1)
        q_mask = torch.cat([ONES, query_mask, ONES], dim=1)
        q_segid = torch.cat([NILS] * (2 + QLEN), dim=1)
        q_toks[q_toks == -1] = 0

        d_toks = torch.cat([CLSS, doc_toks, SEPS], dim=1)
        d_mask = torch.cat([ONES, doc_mask, ONES], dim=1)
        d_segid = torch.cat([NILS] * (2 + doc_toks.shape[1]), dim=1)
        d_toks[d_toks == -1] = 0

        # execute BERT model
        q_result_tuple = self.bert(q_toks, q_mask, q_segid.long())
        d_result_tuple = self.bert(d_toks, d_mask, d_segid.long())
        q_result = q_result_tuple[2]
        d_result = d_result_tuple[2]

        # extract relevant subsequences for query and doc
        query_results = [r[:BATCH, 1:-1] for r in q_result]
        doc_results = [r[:, 1:-1] for r in d_result]
        doc_results = [
            modeling_util.un_subbatch(r, doc_tok, MAX_DOC_TOK_LEN)
            for r in doc_results
        ]

        # build CLS representation
        q_cls_results = []
        for layer in q_result:
            cls_output = layer[:, 0]
            cls_result = []
            for i in range(cls_output.shape[0] // BATCH):
                cls_result.append(cls_output[i * BATCH:(i + 1) * BATCH])
            cls_result = torch.stack(cls_result, dim=2).mean(dim=2)
            q_cls_results.append(cls_result)

        d_cls_results = []
        for layer in d_result:
            cls_output = layer[:, 0]
            cls_result = []
            for i in range(cls_output.shape[0] // BATCH):
                cls_result.append(cls_output[i * BATCH:(i + 1) * BATCH])
            cls_result = torch.stack(cls_result, dim=2).mean(dim=2)
            d_cls_results.append(cls_result)

        return q_cls_results, d_cls_results, query_results, doc_results
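
In this third variant the query and the document are encoded in two independent BERT passes, so any query-document interaction has to happen downstream. A hedged sketch (assumed, not part of the source) of one common follow-up step: a cosine similarity matrix between the last-layer query and document token embeddings, of the kind consumed by KNRM- or PACRR-style interaction heads:

import torch
import torch.nn.functional as F

def cosine_simmat(query_results, doc_results):
    # query_results[-1]: (BATCH, QLEN, hidden); doc_results[-1]: (BATCH, DLEN, hidden)
    q = F.normalize(query_results[-1], p=2, dim=-1)
    d = F.normalize(doc_results[-1], p=2, dim=-1)
    # (BATCH, QLEN, DLEN) cosine similarity between every query/document token pair
    return torch.bmm(q, d.transpose(1, 2))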