def encode_bert(self, query_tok, query_mask, doc_tok, doc_mask, customBert=None):
    BATCH, QLEN = query_tok.shape
    DIFF = 3  # = [CLS] and 2x [SEP]
    maxlen = self.bert.config.max_position_embeddings
    MAX_DOC_TOK_LEN = maxlen - QLEN - DIFF

    doc_toks, sbcount = modeling_util.subbatch(doc_tok, MAX_DOC_TOK_LEN)
    doc_mask, _ = modeling_util.subbatch(doc_mask, MAX_DOC_TOK_LEN)

    query_toks = torch.cat([query_tok] * sbcount, dim=0)
    query_mask = torch.cat([query_mask] * sbcount, dim=0)

    CLSS = torch.full_like(query_toks[:, :1], self.tokenizer.vocab['[CLS]'])
    SEPS = torch.full_like(query_toks[:, :1], self.tokenizer.vocab['[SEP]'])
    ONES = torch.ones_like(query_mask[:, :1])
    NILS = torch.zeros_like(query_mask[:, :1])

    # build BERT input sequences (document first, then query)
    toks = torch.cat([CLSS, doc_toks, SEPS, query_toks, SEPS], dim=1)
    mask = torch.cat([ONES, doc_mask, ONES, query_mask, ONES], dim=1)
    # segment_ids = torch.cat([NILS] * (2 + QLEN) + [ONES] * (1 + doc_toks.shape[1]), dim=1)
    segment_ids = torch.cat([NILS] * (2 + doc_toks.shape[1]) + [ONES] * (1 + QLEN), dim=1)
    toks[toks == -1] = 0  # remove padding (will be masked anyway)

    # execute BERT model
    if not customBert:
        result = self.bert(toks, segment_ids.long(), mask)
    else:
        result = customBert(toks, segment_ids.long(), mask)

    # extract relevant subsequences for query and doc
    # (positions follow the document-first layout built above: [CLS] doc [SEP] query [SEP])
    DLEN = doc_toks.shape[1]
    doc_results = [r[:, 1:DLEN + 1] for r in result]
    query_results = [r[:BATCH, DLEN + 2:-1] for r in result]
    doc_results = [modeling_util.un_subbatch(r, doc_tok, MAX_DOC_TOK_LEN)
                   for r in doc_results]

    # build [CLS] representation, averaged over the document sub-batches
    cls_results = []
    for layer in result:
        cls_output = layer[:, 0]
        cls_result = []
        for i in range(cls_output.shape[0] // BATCH):
            cls_result.append(cls_output[i * BATCH:(i + 1) * BATCH])
        cls_result = torch.stack(cls_result, dim=2).mean(dim=2)
        cls_results.append(cls_result)

    return cls_results, query_results, doc_results
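# NOTE: modeling_util.subbatch / un_subbatch are not part of this listing. The code
# below is only a minimal sketch of the assumed CEDR-style semantics (split over-long
# documents into equal-length pieces stacked along the batch dimension, then re-join
# them afterwards); it is not the project's actual implementation.
import math
import torch
import torch.nn.functional as F

def subbatch(toks, maxlen):
    _, dlen = toks.shape
    subbatch_count = math.ceil(dlen / maxlen)             # pieces per document
    if subbatch_count == 1:
        return toks, 1
    pad = subbatch_count * maxlen - dlen
    toks = F.pad(toks, (0, pad), value=-1)                # right-pad so pieces split evenly
    pieces = torch.cat(torch.split(toks, maxlen, dim=1), dim=0)
    return pieces, subbatch_count

def un_subbatch(embed, toks, maxlen):
    batch, dlen = toks.shape
    subbatch_count = math.ceil(dlen / maxlen)
    if subbatch_count == 1:
        return embed
    pieces = [embed[i * batch:(i + 1) * batch] for i in range(subbatch_count)]
    return torch.cat(pieces, dim=1)[:, :dlen]             # re-join pieces and trim padding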
def forward(self, query_tok, query_mask, doc_tok, doc_mask):
    BATCH, QLEN = query_tok.shape
    DIFF = 2  # = 2x </s>
    maxlen = 512
    MAX_DOC_TOK_LEN = maxlen - QLEN - DIFF

    doc_toks, sbcount = modeling_util.subbatch(doc_tok, MAX_DOC_TOK_LEN)
    doc_mask, _ = modeling_util.subbatch(doc_mask, MAX_DOC_TOK_LEN)

    query_toks = torch.cat([query_tok] * sbcount, dim=0)
    query_mask = torch.cat([query_mask] * sbcount, dim=0)

    SEPS = torch.full_like(query_toks[:, :1], self.tokenizer.eos_token_id)
    ONES = torch.ones_like(query_mask[:, :1])

    # build mBART input sequences
    toks = torch.cat([query_toks, SEPS, doc_toks, SEPS], dim=1)
    mask = torch.cat([query_mask, ONES, doc_mask, ONES], dim=1)
    toks[toks == -1] = 0  # remove padding (will be masked anyway)

    # execute mBART model
    result = self.mbart(input_ids=toks, attention_mask=mask)

    # average the classification logits over the document sub-batches
    cls_output = result.logits
    cls_result = []
    for i in range(cls_output.shape[0] // BATCH):
        cls_result.append(cls_output[i * BATCH:(i + 1) * BATCH])
    cls_result = torch.stack(cls_result, dim=2).mean(dim=2)

    return cls_result
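# The forward pass above assumes that `self.mbart` returns classification logits and
# that `self.tokenizer` exposes eos_token_id. A hypothetical way to wire this up with
# Hugging Face transformers (the checkpoint name and label count are assumptions, not
# taken from this listing):
import torch
from transformers import MBartForSequenceClassification, MBartTokenizer

class MBartRankerSketch(torch.nn.Module):
    def __init__(self, model_name='facebook/mbart-large-cc25'):
        super().__init__()
        self.tokenizer = MBartTokenizer.from_pretrained(model_name)
        # one relevance logit per (query, document sub-batch) pair
        self.mbart = MBartForSequenceClassification.from_pretrained(model_name, num_labels=1)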
def generate(self, query_tok, query_mask, doc_tok, doc_mask):
    BATCH, QLEN = query_tok.shape
    DIFF = 2  # = 2x </s>
    maxlen = 512
    MAX_DOC_TOK_LEN = maxlen - QLEN - DIFF

    doc_toks, sbcount = modeling_util.subbatch(doc_tok, MAX_DOC_TOK_LEN)
    doc_mask, _ = modeling_util.subbatch(doc_mask, MAX_DOC_TOK_LEN)

    query_toks = torch.cat([query_tok] * sbcount, dim=0)
    query_mask = torch.cat([query_mask] * sbcount, dim=0)

    SEPS = torch.full_like(query_toks[:, :1], self.tokenizer.get_vocab()['</s>'])
    ONES = torch.ones_like(query_mask[:, :1])

    # build mT5 input sequences
    toks = torch.cat([query_toks, SEPS, doc_toks, SEPS], dim=1)
    mask = torch.cat([query_mask, ONES, doc_mask, ONES], dim=1)
    toks[toks == -1] = 0  # remove padding (will be masked anyway)

    # generate a single output token and keep its vocabulary scores
    result = self.mt5.generate(input_ids=toks,
                               attention_mask=mask,
                               output_scores=True,
                               return_dict_in_generate=True,
                               max_length=2)

    # average the first-step scores over the document sub-batches
    cls_output = result.scores[0]
    cls_result = []
    for i in range(cls_output.shape[0] // BATCH):
        cls_result.append(cls_output[i * BATCH:(i + 1) * BATCH])
    cls_result = torch.stack(cls_result, dim=2).mean(dim=2)

    return cls_result
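# `generate` above returns the vocabulary-sized scores of the first decoded token,
# averaged over document sub-batches. A hypothetical monoT5-style readout of a relevance
# score from these scores (the '▁yes' / '▁no' token strings are assumptions, not taken
# from this code):
import torch

def relevance_from_scores(scores, tokenizer, pos_token='▁yes', neg_token='▁no'):
    vocab = tokenizer.get_vocab()
    pos_id, neg_id = vocab[pos_token], vocab[neg_token]
    # softmax over the two candidate tokens; keep the log-probability of the positive one
    pair = torch.stack([scores[:, neg_id], scores[:, pos_id]], dim=1)
    return torch.log_softmax(pair, dim=1)[:, 1]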
def encode_bert(self, query_tok, doc_tok):
    # query_tok includes the [CLS] token
    # doc_tok includes the [SEP] token
    BATCH = query_tok.shape[0]
    QLEN = 20
    DIFF = 3
    maxlen = self.bert.config.max_position_embeddings  # 512
    MAX_DLEN = maxlen - QLEN - DIFF  # 489
    # 1 ([CLS]) + 20 (query) + 1 ([SEP]) + 489 (doc) + 1 ([SEP]) = 512

    # derive attention masks from the zero-padded token ids
    query_mask = torch.where(query_tok > 0, torch.ones_like(query_tok), torch.zeros_like(query_tok))
    doc_mask = torch.where(doc_tok > 0, torch.ones_like(doc_tok), torch.zeros_like(doc_tok))

    doc_toks, sbcount = modeling_util.subbatch(doc_tok, MAX_DLEN)
    doc_masks, _ = modeling_util.subbatch(doc_mask, MAX_DLEN)

    query_toks = torch.cat([query_tok] * sbcount, dim=0)
    query_masks = torch.cat([query_mask] * sbcount, dim=0)

    CLSS = torch.full_like(query_toks[:, :1], self.tokenizer.vocab['[CLS]'])
    SEPS = torch.full_like(query_toks[:, :1], self.tokenizer.vocab['[SEP]'])
    ONES = torch.ones_like(query_masks[:, :1])
    NILS = torch.zeros_like(query_masks[:, :1])

    # build BERT input sequences
    toks = torch.cat([CLSS, query_toks, SEPS, doc_toks, SEPS], dim=1)
    segs = torch.cat([NILS] * (2 + QLEN) + [ONES] * (1 + doc_toks.shape[1]), dim=1)
    # segs = torch.cat([NILS] * (2 + QLEN) + [doc_masks, ONES], dim=1)
    masks = torch.cat([ONES, query_masks, ONES, doc_masks, ONES], dim=1)

    # execute BERT
    result = self.bert(toks, segs.long(), masks)

    # extract relevant subsequences for query and doc
    query_results = [r[:BATCH, 1:QLEN + 1] for r in result]  # (N, QLEN)
    doc_results = [r[:, QLEN + 2:-1] for r in result]        # (N, MAX_DLEN)
    doc_results = [modeling_util.un_subbatch(r, doc_tok, MAX_DLEN) for r in doc_results]

    # build [CLS] representation, averaged over the document sub-batches
    cls_results = []
    for layer in result:
        cls_output = layer[:, 0]
        cls_result = []
        for i in range(cls_output.shape[0] // BATCH):
            cls_result.append(cls_output[i * BATCH:(i + 1) * BATCH])
        cls_result = torch.stack(cls_result, dim=2).mean(dim=2)
        cls_results.append(cls_result)

    return query_tok, doc_tok, cls_results, query_results, doc_results
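# The comments above state that query_tok arrives with a leading [CLS], doc_tok with a
# trailing [SEP], and that attention masks are recovered from zero padding. A hypothetical
# way such inputs could be built (tokenizer checkpoint, lengths, and padding id 0 are
# assumptions, not taken from this listing):
import torch
from transformers import BertTokenizer

def prepare_inputs(queries, docs, qlen=20, dlen=489, model_name='bert-base-uncased'):
    tok = BertTokenizer.from_pretrained(model_name)

    def pad(ids, length):
        ids = ids[:length]
        return ids + [0] * (length - len(ids))             # zero-pad; masks use `> 0`

    q = [pad([tok.cls_token_id] + tok.encode(x, add_special_tokens=False), qlen) for x in queries]
    d = [pad(tok.encode(x, add_special_tokens=False) + [tok.sep_token_id], dlen) for x in docs]
    return torch.tensor(q, dtype=torch.long), torch.tensor(d, dtype=torch.long)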
def forward(self, query_tok, query_mask, doc_tok, doc_mask):
    BATCH, QLEN = query_tok.shape
    DIFF = 2  # = 2x </s>
    maxlen = 512
    MAX_DOC_TOK_LEN = maxlen - QLEN - DIFF

    doc_toks, sbcount = modeling_util.subbatch(doc_tok, MAX_DOC_TOK_LEN)
    doc_mask, _ = modeling_util.subbatch(doc_mask, MAX_DOC_TOK_LEN)

    query_toks = torch.cat([query_tok] * sbcount, dim=0)
    query_mask = torch.cat([query_mask] * sbcount, dim=0)

    SEPS = torch.full_like(query_toks[:, :1], self.tokenizer.eos_token_id)
    ONES = torch.ones_like(query_mask[:, :1])

    # build mT5 input sequences
    toks = torch.cat([query_toks, SEPS, doc_toks, SEPS], dim=1)
    mask = torch.cat([query_mask, ONES, doc_mask, ONES], dim=1)
    toks[toks == -1] = 0  # remove padding (will be masked anyway)

    # decoder input is the encoder input shifted one position to the right
    decoder_input_ids = shift_tokens_right(
        toks, self.config.pad_token_id, self.config.decoder_start_token_id)

    # execute mT5 model
    outputs = self.mt5(input_ids=toks,
                       attention_mask=mask,
                       decoder_input_ids=decoder_input_ids)
    hidden_states = outputs[0]  # last hidden state of the decoder

    # pool the hidden state at the final </s> position (BART-style sequence classification)
    eos_mask = toks.eq(self.config.eos_token_id)
    if len(torch.unique(eos_mask.sum(1))) > 1:
        raise ValueError("All examples must have the same number of <eos> tokens.")
    sentence_representation = hidden_states[eos_mask, :].view(
        hidden_states.size(0), -1, hidden_states.size(-1))[:, -1, :]
    logits = self.classification_head(sentence_representation)

    # average the classification logits over the document sub-batches
    cls_output = logits
    cls_result = []
    for i in range(cls_output.shape[0] // BATCH):
        cls_result.append(cls_output[i * BATCH:(i + 1) * BATCH])
    cls_result = torch.stack(cls_result, dim=2).mean(dim=2)

    return cls_result
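# `shift_tokens_right` is assumed here to be the BART-style helper from transformers
# (prepend the decoder start token and shift everything one position to the right);
# a sketch of that three-argument behaviour:
import torch

def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 1:] = input_ids[:, :-1].clone()             # shift right by one position
    shifted[:, 0] = decoder_start_token_id                 # decoder starts from this token
    shifted.masked_fill_(shifted == -100, pad_token_id)    # replace label padding ids
    return shifted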
def encode_bert(self, query_tok, query_mask, doc_tok, doc_mask):
    BATCH, QLEN = query_tok.shape
    DIFF = 3  # = [CLS] and 2x [SEP]
    maxlen = self.bert.config.max_position_embeddings
    MAX_DOC_TOK_LEN = maxlen - QLEN - DIFF

    doc_toks, sbcount = modeling_util.subbatch(doc_tok, MAX_DOC_TOK_LEN)
    doc_mask, _ = modeling_util.subbatch(doc_mask, MAX_DOC_TOK_LEN)

    query_toks = torch.cat([query_tok] * sbcount, dim=0)
    query_mask = torch.cat([query_mask] * sbcount, dim=0)

    CLSS = torch.full_like(query_toks[:, :1], self.tokenizer.vocab['[CLS]'])
    SEPS = torch.full_like(query_toks[:, :1], self.tokenizer.vocab['[SEP]'])
    ONES = torch.ones_like(query_mask[:, :1])
    NILS = torch.zeros_like(query_mask[:, :1])

    # build separate BERT input sequences for query and doc
    q_toks = torch.cat([CLSS, query_toks, SEPS], dim=1)
    q_mask = torch.cat([ONES, query_mask, ONES], dim=1)
    q_segid = torch.cat([NILS] * (2 + QLEN), dim=1)
    q_toks[q_toks == -1] = 0

    d_toks = torch.cat([CLSS, doc_toks, SEPS], dim=1)
    d_mask = torch.cat([ONES, doc_mask, ONES], dim=1)
    d_segid = torch.cat([NILS] * (2 + doc_toks.shape[1]), dim=1)
    d_toks[d_toks == -1] = 0

    # execute BERT model on query and doc independently
    q_result_tuple = self.bert(q_toks, q_mask, q_segid.long())
    d_result_tuple = self.bert(d_toks, d_mask, d_segid.long())
    q_result = q_result_tuple[2]
    d_result = d_result_tuple[2]

    # extract relevant subsequences for query and doc
    query_results = [r[:BATCH, 1:-1] for r in q_result]
    doc_results = [r[:, 1:-1] for r in d_result]
    doc_results = [modeling_util.un_subbatch(r, doc_tok, MAX_DOC_TOK_LEN)
                   for r in doc_results]

    # build [CLS] representations, averaged over the document sub-batches
    q_cls_results = []
    for layer in q_result:
        cls_output = layer[:, 0]
        cls_result = []
        for i in range(cls_output.shape[0] // BATCH):
            cls_result.append(cls_output[i * BATCH:(i + 1) * BATCH])
        cls_result = torch.stack(cls_result, dim=2).mean(dim=2)
        q_cls_results.append(cls_result)

    d_cls_results = []
    for layer in d_result:
        cls_output = layer[:, 0]
        cls_result = []
        for i in range(cls_output.shape[0] // BATCH):
            cls_result.append(cls_output[i * BATCH:(i + 1) * BATCH])
        cls_result = torch.stack(cls_result, dim=2).mean(dim=2)
        d_cls_results.append(cls_result)

    return q_cls_results, d_cls_results, query_results, doc_results
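# The same sub-batch pooling loop (slice per sub-batch, stack, mean) recurs in the
# functions above. A hypothetical helper capturing it (`pool_subbatches` is not a name
# from the original code):
import torch

def pool_subbatches(cls_output, batch):
    # cls_output: (sbcount * batch, dim) -> (batch, dim), averaging the sub-batch copies
    pieces = [cls_output[i * batch:(i + 1) * batch]
              for i in range(cls_output.shape[0] // batch)]
    return torch.stack(pieces, dim=2).mean(dim=2)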