def test_concat_enc_outs(self):
    enc_output = torch.rand(self.bsz * self.n_docs, self.seqlen, self.esz)
    enc_input, mask = self._create_input_and_mask()
    # Right padded
    mask = mask.repeat_interleave(self.n_docs, dim=0)
    _, new_mask = concat_enc_outs(
        enc_input, enc_output, mask, self.esz, self.pad_idx
    )
    ########################################################################
    # Assertion: new mask has `True` elements in first (n_docs * seqlen_i) #
    # tokens in concatenated output                                        #
    ########################################################################
    assert all(
        new_mask[i, : self.batch_lens[i] * self.n_docs].sum()
        == self.n_docs * self.batch_lens[i]
        for i in range(self.bsz)
    )

    # Left padded
    enc_input, mask = self._create_input_and_mask(right_padded=False)
    mask = mask.repeat_interleave(self.n_docs, dim=0)
    _, new_mask = concat_enc_outs(
        enc_input, enc_output, mask, self.esz, self.pad_idx, right_padded=False
    )
    #######################################################################
    # Assertion: new mask has `True` elements in last (n_docs * seqlen_i) #
    # tokens in concatenated output                                       #
    #######################################################################
    assert all(
        new_mask[i, -(self.batch_lens[i] * self.n_docs):].sum()
        == self.n_docs * self.batch_lens[i]
        for i in range(self.bsz)
    )
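
# Illustrative only: a minimal sketch of the right-padded concatenation that the
# test above exercises. This is NOT the library's concat_enc_outs; the helper
# name and simplified signature below are assumptions for demonstration. Each of
# the bsz examples has n_docs independently encoded rows of shape [seqlen, esz];
# the sketch gathers the unmasked tokens of each group and right-pads them into
# one long sequence per example, which is why the test expects the first
# (n_docs * seqlen_i) positions of the new mask to be True.
import torch


def _concat_enc_outs_sketch(enc_out, mask, n_docs, esz):
    # enc_out: [bsz * n_docs, seqlen, esz]; mask: [bsz * n_docs, seqlen], True = real token
    bsz = enc_out.size(0) // n_docs
    gathered = [
        enc_out[i * n_docs : (i + 1) * n_docs].reshape(-1, esz)[
            mask[i * n_docs : (i + 1) * n_docs].reshape(-1)
        ]
        for i in range(bsz)
    ]
    max_len = max(g.size(0) for g in gathered)
    new_out = enc_out.new_zeros(bsz, max_len, esz)
    new_mask = mask.new_zeros(bsz, max_len)
    for i, g in enumerate(gathered):
        new_out[i, : g.size(0)] = g  # right padded: real tokens first, padding last
        new_mask[i, : g.size(0)] = True
    return new_out, new_mask
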
def encoder(
    self,
    input: torch.LongTensor,
    input_lengths: torch.LongTensor,
    query_vec: torch.LongTensor,
    input_turns_cnt: torch.LongTensor,
    memory_vec: torch.LongTensor,
    num_memories: torch.LongTensor,
    query_generator_vec: torch.LongTensor,
    gold_doc_vec: torch.LongTensor,
    gold_doc_title_vec: torch.LongTensor,
    num_gold_docs: torch.LongTensor,
    memory_decoder_vec: torch.LongTensor,
    num_memory_decoder_vecs: torch.LongTensor,
    positions: Optional[torch.LongTensor] = None,
    segments: Optional[torch.LongTensor] = None,
) -> Tuple[
    torch.Tensor,
    torch.BoolTensor,
    Optional[torch.LongTensor],
    Optional[List[List[Document]]],
    Optional[torch.Tensor],
]:
    enc_out, mask, input_turns_cnt, top_docs, top_doc_scores = super().encoder(  # type: ignore
        input,
        input_lengths,
        query_vec,
        input_turns_cnt,
        memory_vec,
        num_memories,
        query_generator_vec,
        gold_doc_vec,
        gold_doc_title_vec,
        num_gold_docs,
        memory_decoder_vec,
        num_memory_decoder_vecs,
        positions,
        segments,
    )  # type: ignore

    if input_turns_cnt is not None:
        # Input Turns is a tensor of dim [bsz]
        input = input.repeat_interleave(input_turns_cnt, dim=0)  # type: ignore

    new_out, new_mask = concat_enc_outs(
        input, enc_out, mask, self.embedding_size, self.pad_idx
    )

    return new_out, new_mask, input_turns_cnt, top_docs, top_doc_scores
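
# Illustrative only: a toy example of the repeat_interleave step above. When
# input_turns_cnt is set, each batch example was split into several dialogue
# turns upstream, so the encoder output has sum(input_turns_cnt) * n_docs rows
# while `input` still has bsz rows; repeating each input row by its turn count
# realigns the two before concat_enc_outs divides the row counts. The tensor
# values below are made up for the demonstration.
import torch

input_toks = torch.tensor([[101, 7, 8], [101, 9, 0]])  # [bsz=2, seqlen=3]
input_turns_cnt = torch.tensor([2, 3])                  # example 0: 2 turns, example 1: 3 turns
expanded = input_toks.repeat_interleave(input_turns_cnt, dim=0)
print(expanded.shape)  # torch.Size([5, 3]) -> one row per (example, turn) pair
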
def encoder(
    self,
    input: torch.LongTensor,
    input_lengths: torch.LongTensor,
    query_vec: torch.LongTensor,
    input_turns_cnt: torch.LongTensor,
    target_lengths: Optional[torch.LongTensor],
    positions: Optional[torch.LongTensor] = None,
    segments: Optional[torch.LongTensor] = None,
) -> Tuple[
    torch.Tensor,
    torch.BoolTensor,
    Optional[torch.LongTensor],
    Optional[List[List[Document]]],
    Optional[torch.Tensor],
]:
    """
    Override FidModel.encoder to pack all the documents into one input example.

    :param input:
        2D [bsz, seqlen] input to the encoder
    :param input_lengths:
        1D [bsz] lengths of each input item
    :param query_vec:
        2D [bsz*n_turns, seqlen] input for the retriever
    :param input_turns_cnt:
        1D [bsz] number of dialogue turns for each input example
    :param target_lengths:
        1D [bsz] lengths of each target item (for each input item)

    :return (encoder_out, encoder_mask, input_turns_cnt, top_docs, top_doc_scores):
        encoder_out: *concatenated* encoded representations of context/document pairs
        encoder_mask: new mask for enc_out
        input_turns_cnt: pass along the input turns count for the decoder
        top_docs: List of top Documents for each batch example
        top_doc_scores: scores for each retrieved document.
    """
    enc_out, mask, input_turns_cnt, top_docs, top_doc_scores = RagModel.encoder(
        self, input, input_lengths, query_vec, input_turns_cnt, positions, segments
    )  # type: ignore
    seq_len, n_docs = enc_out.size(1), enc_out.size(0) // input.size(0)

    if input_turns_cnt is not None:
        # Input Turns is a tensor of dim [bsz]
        input = input.repeat_interleave(input_turns_cnt, dim=0)  # type: ignore

    doc_starts = (enc_out == self.pad_idx).sum(dim=1)  # left padded
    doc_lens = (seq_len - input_lengths.repeat_interleave(n_docs)) - doc_starts
    # if no padding, docs are assumed to be min doc length long
    doc_lens[doc_lens.le(0)] = self.min_doc_len
    new_enc_out = enc_out.clone()
    # BEFORE:
    # [
    #   pad...doc_0 / in_0
    #   pad...doc_1 / in_0
    #   ...
    #   pad...doc_n / in_m
    # ]
    total_length_i = 0
    for i, doc_len in enumerate(doc_lens):
        if i % n_docs == 0:
            total_length_i = 0
        # max doc length is determined by how much space we have after
        # subtracting the input (and target) length
        input_and_target = input_lengths[i // n_docs] + (
            target_lengths[i // n_docs] if target_lengths is not None else 0
        )
        max_doc_len = torch.div(
            (seq_len - input_and_target), n_docs, rounding_mode='floor'
        )
        max_doc_len = max(max_doc_len, self.min_doc_len)  # type: ignore
        doc_len = min(doc_len, max_doc_len)
        total_length_i += doc_len
        if i % n_docs == n_docs - 1:
            # keep the actual input context when processing the last doc.
            clamped_input_length = input_lengths[i // n_docs].clamp(
                max=self.truncate - total_length_i
            )
            if target_lengths is not None:
                clamped_input_length = input_lengths[i // n_docs].clamp(
                    max=self.truncate - total_length_i - target_lengths[i // n_docs]
                )
            pad_end = seq_len - clamped_input_length - doc_len
        else:
            pad_end = seq_len - doc_len
        # we simply move the portion of the doc we want to keep to the end of
        # the tensor, and mask out everything else.
        mask[i, :pad_end] = False
        new_enc_out[i, pad_end : pad_end + doc_len] = enc_out[
            i, doc_starts[i] : doc_starts[i] + doc_len
        ]
    new_out, new_mask = concat_enc_outs(
        input, new_enc_out.unsqueeze(-1), mask, 1, self.pad_idx, right_padded=False
    )
    # AFTER:
    # [
    #   doc_0 / doc_1 / doc_2 ... in_0
    #   doc_0 / doc_1 / doc_2 ... in_m
    # ]
    return new_out, new_mask, input_turns_cnt, top_docs, top_doc_scores
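
# Illustrative only: a toy walk-through of the per-document length budget used
# above. With seq_len = 128, an input of 40 tokens, a target of 24 tokens, and
# n_docs = 5 retrieved documents, each document may keep at most
# floor((128 - (40 + 24)) / 5) = 12 encoded positions, never fewer than
# min_doc_len. All numbers here are made up; only the arithmetic mirrors the
# budgeting in the loop above.
import torch

seq_len, n_docs, min_doc_len = 128, 5, 8
input_len, target_len = torch.tensor(40), torch.tensor(24)
max_doc_len = torch.div(
    seq_len - (input_len + target_len), n_docs, rounding_mode='floor'
)
max_doc_len = max(max_doc_len, min_doc_len)
print(int(max_doc_len))  # 12 positions of each (left-padded) document are kept
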