def get_next_decoder_input(
    self,
    prev_input: torch.LongTensor,
    selection: torch.LongTensor,
    incr_state_inds: torch.LongTensor,
) -> torch.LongTensor:
    """
    Append the newly selected tokens to the decoder input for RAG Turn beam search.

    During generation the decoder input holds one row per (example, turn, doc)
    while beam selection happens once per (example, turn). The same selected
    token is therefore copied to all ``self.n_docs`` document rows of a beam,
    so that decoding keeps marginalizing over documents.

    :param prev_input:
        previous [bsz*n_docs*n_turns, seqlen] input to the decoder.
    :param selection:
        beam selection ids for the decoder. They are repeated along dim 1, so a
        2D tensor is expected (presumably [bsz*n_turns, 1] -- TODO confirm).
    :param incr_state_inds:
        beam indices that are continuing in the next generation step.

    :return dec_input:
        decoder input with the selection appended for every document copy.
    """
    # Group the rows per document (via _unstack_ctxt), then keep surviving beams only.
    grouped = _unstack_ctxt(prev_input, self.n_docs)
    surviving = grouped.index_select(0, incr_state_inds)  # type: ignore
    # Each document copy of a beam receives the same freshly selected token.
    next_tokens = selection.repeat_interleave(self.n_docs, 1).unsqueeze(-1)
    extended = torch.cat([surviving, next_tokens], dim=-1)
    # Flatten the document grouping back out.
    return _stack_ctxt(extended)  # type: ignore
def corrupt_batch(
    self, positive_batch: torch.LongTensor
) -> torch.LongTensor:  # noqa: D102
    """
    Corrupt either the head or the tail of every (repeated) positive triple.

    The choice is drawn per row: the head is corrupted with probability
    ``self.corrupt_head_probability[relation]``, otherwise the tail is.
    The replacement id always differs from the original value.

    :param positive_batch: [n, 3] tensor of triples; column 1 indexes the
        per-relation head-corruption probabilities.
    :return: [n, num_negs_per_pos, 3] tensor of negative triples.
    """
    k = self.num_negs_per_pos
    if k > 1:
        # Keep all negatives of one positive adjacent so the final view groups them.
        positive_batch = positive_batch.repeat_interleave(repeats=k, dim=0)

    total = positive_batch.shape[0]
    dev = positive_batch.device
    # Copy positive batch for corruption.
    # Do not detach, as no gradients should flow into the indices.
    corrupted = positive_batch.clone()

    p_head = self.corrupt_head_probability[positive_batch[:, 1]].to(device=dev)
    head = torch.rand(total, device=dev) < p_head
    tail = ~head

    # Draw from [0, num_entities - 1) and shift values >= the original up by one,
    # which excludes the original id while keeping the draw uniform.
    draws = torch.randint(self.num_entities - 1, size=(total,), device=dev)

    for column, rows in ((0, head), (2, tail)):
        picked = draws[rows]
        corrupted[rows, column] = picked + (picked >= positive_batch[rows, column]).long()

    return corrupted.view(-1, k, 3)
def retrieve_and_concat(
    self,
    input: torch.LongTensor,
    input_lengths: torch.LongTensor,
    query_vec: torch.LongTensor,
    input_turns_cnt: torch.LongTensor,
) -> Tuple[torch.LongTensor, List[List[Document]], torch.Tensor]:
    """
    Retrieve documents for every query and join them with the encoder input.

    :param input:
        2D [bsz, seqlen] input to the encoder
    :param input_lengths:
        1D [bsz] lengths of each input item
    :param query_vec:
        2D [bsz*n_turns, seqlen] input for the retriever
    :param input_turns_cnt:
        optional 1D [bsz] number of dialogue turns per input example; when
        given, each context row is repeated once per turn before concatenation.

    :return (expanded_input, top_docs, top_doc_scores):
        expanded_input: [bsz * n_docs, seqlen+doc_len] tensor of context/document inputs
        top_docs: List of top documents for each input
        top_doc_scores: document scores for each document
    """
    # 1. Retrieve
    top_docs, top_doc_scores = self.retriever.retrieve(query_vec)

    # 2. Expand the input so there is one context row per retrieval query.
    if input_turns_cnt is not None:
        input, input_lengths = (  # type: ignore
            tensor.repeat_interleave(input_turns_cnt, dim=0)
            for tensor in (input, input_lengths)
        )

    n_docs = top_doc_scores.size(1)
    expanded_input = self.concat_docs_and_input(input, input_lengths, top_docs, n_docs)
    return expanded_input, top_docs, top_doc_scores
def encoder(
    self,
    input: torch.LongTensor,
    input_lengths: torch.LongTensor,
    query_vec: torch.LongTensor,
    input_turns_cnt: torch.LongTensor,
    memory_vec: torch.LongTensor,
    num_memories: torch.LongTensor,
    query_generator_vec: torch.LongTensor,
    gold_doc_vec: torch.LongTensor,
    gold_doc_title_vec: torch.LongTensor,
    num_gold_docs: torch.LongTensor,
    memory_decoder_vec: torch.LongTensor,
    num_memory_decoder_vecs: torch.LongTensor,
    positions: Optional[torch.LongTensor] = None,
    segments: Optional[torch.LongTensor] = None,
) -> Tuple[
    torch.Tensor,
    torch.BoolTensor,
    Optional[torch.LongTensor],
    Optional[List[List[Document]]],
    Optional[torch.Tensor],
]:
    """
    Encode with the parent encoder, then concatenate the encoder outputs.

    All arguments are forwarded unchanged to ``super().encoder``. Its encoder
    outputs and mask are then passed to ``concat_enc_outs`` (defined elsewhere;
    presumably merges the per-document encodings of an example -- TODO confirm).

    :return (encoder_out, encoder_mask, input_turns_cnt, top_docs, top_doc_scores):
        the concatenated encoder outputs and mask, followed by the turn counts,
        retrieved documents and document scores returned by the parent encoder.
    """
    parent_out = super().encoder(  # type: ignore
        input,
        input_lengths,
        query_vec,
        input_turns_cnt,
        memory_vec,
        num_memories,
        query_generator_vec,
        gold_doc_vec,
        gold_doc_title_vec,
        num_gold_docs,
        memory_decoder_vec,
        num_memory_decoder_vecs,
        positions,
        segments,
    )  # type: ignore
    enc_out, mask, input_turns_cnt, top_docs, top_doc_scores = parent_out

    # Input Turns is a tensor of dim [bsz]; repeat each context once per turn.
    context = (
        input
        if input_turns_cnt is None
        else input.repeat_interleave(input_turns_cnt, dim=0)  # type: ignore
    )
    new_out, new_mask = concat_enc_outs(
        context, enc_out, mask, self.embedding_size, self.pad_idx
    )
    return new_out, new_mask, input_turns_cnt, top_docs, top_doc_scores
def encoder(
    self,
    input: torch.LongTensor,
    input_lengths: torch.LongTensor,
    query_vec: torch.LongTensor,
    input_turns_cnt: torch.LongTensor,
    positions: Optional[torch.LongTensor] = None,
    segments: Optional[torch.LongTensor] = None,
) -> Tuple[
    torch.Tensor,
    torch.BoolTensor,
    Optional[torch.LongTensor],
    Optional[List[List[Document]]],
    Optional[torch.Tensor],
]:
    """
    Run the parent encoder and concatenate all of its outputs in model forward.

    :param input:
        2D [bsz, seqlen] input to the encoder
    :param input_lengths:
        1D [bsz] lengths of each input item
    :param query_vec:
        2D [bsz*n_turns, seqlen] input for the retriever
    :param input_turns_cnt:
        optional 1D [bsz] number of dialogue turns for each input example
    :param positions:
        optional position ids forwarded to the parent encoder
    :param segments:
        optional segment ids forwarded to the parent encoder

    :return (encoder_out, encoder_mask, input_turns_cnt, top_docs, top_doc_scores):
        encoder_out: *concatenated* encoded representations of context/document pairs
        encoder_mask: new mask for enc_out
        input_turns_cnt: pass along the input turns count for the decoder
        top_docs: List of top Documents for each batch example
        top_doc_scores: scores for each retrieved document.
    """
    (
        enc_out,
        mask,
        input_turns_cnt,
        top_docs,
        top_doc_scores,
    ) = super().encoder(
        input, input_lengths, query_vec, input_turns_cnt, positions, segments
    )  # type: ignore

    # Input Turns is a tensor of dim [bsz]; repeat each context once per turn.
    context = (
        input
        if input_turns_cnt is None
        else input.repeat_interleave(input_turns_cnt, dim=0)  # type: ignore
    )
    new_out, new_mask = concat_enc_outs(
        context, enc_out, mask, self.embedding_size, self.pad_idx
    )
    return new_out, new_mask, input_turns_cnt, top_docs, top_doc_scores
def corrupt_batch(
    self, positive_batch: torch.LongTensor
) -> torch.LongTensor:  # noqa: D102
    """
    Create negatives by replacing one column of each (repeated) positive triple.

    The (repeated) batch is split into ``len(self._corruption_indices)``
    contiguous chunks. In chunk ``j``, column ``self._corruption_indices[j]``
    is replaced by a random id that is guaranteed to differ from the original
    value in that column. Column 1 draws from ``num_relations`` ids; every
    other column draws from ``num_entities`` ids.

    :param positive_batch: [n, 3] tensor of triples (n may be 0).
    :return: [n, num_negs_per_pos, 3] tensor of negative triples.
    """
    if self.num_negs_per_pos > 1:
        positive_batch = positive_batch.repeat_interleave(
            repeats=self.num_negs_per_pos, dim=0
        )

    # Bind number of negatives to sample
    num_negs = positive_batch.shape[0]

    # Equally corrupt all sides. The step is clamped to at least 1: an empty
    # batch would otherwise give a step of 0, and range() raises on a zero step.
    split_idx = max(1, int(math.ceil(num_negs / len(self._corruption_indices))))

    # Copy positive batch for corruption.
    # Do not detach, as no gradients should flow into the indices.
    negative_batch = positive_batch.clone()

    for index, start in zip(self._corruption_indices, range(0, num_negs, split_idx)):
        stop = min(start + split_idx, num_negs)

        # Relations have a different index maximum than entities
        # At least make sure to not replace the triples by the original value
        index_max = (self.num_relations if index == 1 else self.num_entities) - 1

        negative_batch[start:stop, index] = torch.randint(
            high=index_max,
            size=(stop - start,),
            device=positive_batch.device,
        )

        # To make sure we don't replace the {head, relation, tail} by the
        # original value we shift all values greater or equal than the original value by one up
        # for that reason we choose the random value from [0, num_{heads, relations, tails} -1]
        negative_batch[start:stop, index] += (
            negative_batch[start:stop, index] >= positive_batch[start:stop, index]
        ).long()

    return negative_batch.view(-1, self.num_negs_per_pos, 3)
def get_initial_forced_decoder_input(
    self,
    bsz: int,
    inputs: torch.LongTensor,
    n_docs: int,
    start_idx: int,
    end_idx: int,
    input_turns_cnt: Optional[torch.LongTensor] = None,
) -> torch.LongTensor:
    """
    Return the initial input to the decoder during training.

    When turn counts are given, each target row is first repeated once per turn.
    ``bsz`` then becomes ``input_turns_cnt.sum()``, which is a 0-d tensor rather
    than an int. The forced decoder inputs are then built, and each resulting
    row is repeated ``n_docs`` times consecutively.

    :param bsz: batchsize
    :param inputs: inputs to decode
    :param n_docs: number of docs per input
    :param start_idx: start token idx
    :param end_idx: end token idx
    :param input_turns_cnt:
        an optional tensor containing the number of turns of each corresponding context.

    :return initial_input:
        initial input for the decoder.
    """
    if input_turns_cnt is not None:
        inputs = inputs.repeat_interleave(input_turns_cnt, dim=0)  # type: ignore
        bsz = input_turns_cnt.sum()  # type: ignore

    forced = get_forced_decoder_inputs(
        inputs, bsz, start_idx, end_idx, self.generation_model
    )
    width = forced.size(1)
    # Tile each row n_docs times along dim 1, then fold back so copies are adjacent rows.
    return forced.repeat(1, n_docs).reshape(-1, width)  # type: ignore
def retrieve_and_concat(
    self,
    input: torch.LongTensor,
    input_lengths: torch.LongTensor,
    query_generator_vec: torch.LongTensor,
    query_vec: torch.LongTensor,
    input_turns_cnt: torch.LongTensor,
    memory_vec: torch.LongTensor,
    num_memories: torch.LongTensor,
    gold_doc_vec: torch.LongTensor,
    gold_doc_title_vec: torch.LongTensor,
    num_gold_docs: torch.LongTensor,
    memory_decoder_vec: torch.LongTensor,
    num_memory_decoder_vecs: torch.LongTensor,
    skip_search: torch.BoolTensor,
) -> Tuple[torch.LongTensor, List[List[Document]], torch.Tensor]:
    """
    Override RagModel.retrieve_and_concat to perform different retrieval, depending on the RetrieverType.

    Steps:
      0. When turn counts are given, repeat the per-example auxiliary tensors
         once per turn. Optionally generate memories, then optionally classify
         which retrieval type each row should use.
      1. Fill ``top_docs`` / ``doc_scores`` row by row from three sources:
         search (1a), long-term memory (1b) and a dummy "no search" retriever (1c).
      2. Expand the input per turn, stack the scores and concatenate docs with input.

    :return (expanded_input, top_docs, top_doc_scores):
        output of ``self.concat_docs_and_input``, the per-row document lists, and
        the stacked per-row document scores.
    """
    self.flush_previous_retriever_search_results()
    start = time.time()
    logging.debug(f'Begin encoder: {time.time() - start:.2f}')
    if input_turns_cnt is not None:
        # Align every optional per-example tensor with the per-turn rows.
        if query_generator_vec is not None:
            query_generator_vec = query_generator_vec.repeat_interleave(
                input_turns_cnt, dim=0
            )  # type: ignore
        if memory_vec is not None:
            memory_vec = memory_vec.repeat_interleave(
                input_turns_cnt, dim=0
            )  # type: ignore
        if num_memories is not None:
            num_memories = num_memories.repeat_interleave(
                input_turns_cnt, dim=0
            )  # type: ignore
        if memory_decoder_vec is not None:
            memory_decoder_vec = memory_decoder_vec.repeat_interleave(
                input_turns_cnt, dim=0
            )  # type: ignore
        if num_memory_decoder_vecs is not None:
            num_memory_decoder_vecs = num_memory_decoder_vecs.repeat_interleave(
                input_turns_cnt, dim=0
            )  # type: ignore
    # Number of rows that receive documents (one per turn when turns are given).
    n_input = (
        input_turns_cnt.sum().item()
        if input_turns_cnt is not None
        else input.size(0)
    )
    # 0a. Classify retrieval type, if necessary
    generated_memories = [[] for _ in range(int(n_input))]
    if memory_decoder_vec is not None:
        generated_memories = self.memory_decoder.generate_memories(
            memory_decoder_vec, num_memory_decoder_vecs
        )
    if self.should_generate_query:
        assert self.has_query_generator()
        retrieval_type, search_queries = self.query_generator.classify_retrieval(
            query_generator_vec, num_memories, generated_memories, skip_search
        )
        logging.debug(f'Classify Retrieval: {time.time() - start:.2f}')
    else:
        # NOTE(review): torch.LongTensor(n) allocates an *uninitialized* tensor,
        # sized by input.size(0) rather than n_input. Presumably
        # get_retrieval_indices only uses its size in this mode -- confirm.
        retrieval_type = torch.LongTensor(input.size(0))
        search_queries = None
    # 1. Retrieve
    top_docs: List[List[Document]] = [[] for _ in range(int(n_input))]
    doc_scores: List[List[torch.Tensor]] = [[] for _ in range(int(n_input))]
    # 1a. retrieve from faiss or search
    search_indices = self.get_retrieval_indices(retrieval_type, RetrievalType.SEARCH)
    if search_indices.numel() > 0:
        search_docs, search_doc_scores = self.perform_search(
            search_queries, query_vec, search_indices
        )
        logging.debug(f'Search Complete: {time.time() - start:.2f}')
        logging.debug(f'search: {search_docs}')
        if gold_doc_vec is not None:
            logging.debug(f'num gold docs: {num_gold_docs}')
        # Gold-doc arguments are forwarded for the search rows only.
        self._fill_docs_and_scores(
            top_docs,
            doc_scores,
            search_indices,
            search_docs,
            search_doc_scores,
            gold_doc_vec,
            gold_doc_title_vec,
            num_gold_docs,
        )
    # 1b. memory search
    memory_indices = self.get_retrieval_indices(retrieval_type, RetrievalType.MEMORY)
    if memory_indices.numel() > 0:
        memories, memory_scores = self.access_long_term_memory(
            query_vec,
            memory_indices,
            memory_vec,
            num_memories,
            memory_decoder_vec,
            generated_memories,
        )
        logging.debug(f'Memory Access Complete: {time.time() - start:.2f}')
        if memories is not None and memory_scores is not None:
            self._fill_docs_and_scores(
                top_docs, doc_scores, memory_indices, memories, memory_scores
            )
    # 1c. no search
    no_search_indices = self.get_retrieval_indices(retrieval_type, RetrievalType.NONE)
    if no_search_indices.numel() > 0:
        dummy_docs, dummy_scores = self.dummy_retriever.retrieve(
            query_vec[no_search_indices]  # type: ignore
        )
        logging.debug('no search')
        self._fill_docs_and_scores(
            top_docs, doc_scores, no_search_indices, dummy_docs, dummy_scores
        )
    # 2. Expand the input
    if input_turns_cnt is not None:
        input = input.repeat_interleave(input_turns_cnt, dim=0)  # type: ignore
        input_lengths = input_lengths.repeat_interleave(
            input_turns_cnt, dim=0
        )  # type: ignore
    # Filtering empty doc_scores added due to dynamic batching (if used)
    # NOTE(review): top_docs is not filtered the same way, so its length may
    # differ from top_doc_scores when rows are dropped -- confirm this is intended.
    doc_scores = [[s for s in ds if s is not None] for ds in doc_scores if ds]
    top_doc_scores = torch.stack(
        [torch.cat([s_i for s_i in scores_i]) for scores_i in doc_scores]
    )
    expanded_input = self.concat_docs_and_input(
        input, input_lengths, top_docs, top_doc_scores.size(1)
    )
    return expanded_input, top_docs, top_doc_scores
def encoder(
    self,
    input: torch.LongTensor,
    input_lengths: torch.LongTensor,
    query_vec: torch.LongTensor,
    input_turns_cnt: torch.LongTensor,
    target_lengths: Optional[torch.LongTensor],
    positions: Optional[torch.LongTensor] = None,
    segments: Optional[torch.LongTensor] = None,
) -> Tuple[
    torch.Tensor,
    torch.BoolTensor,
    Optional[torch.LongTensor],
    Optional[List[List[Document]]],
    Optional[torch.Tensor],
]:
    """
    Override FidModel.encoder to pack all the documents into one input example.

    Each document of an example is truncated to a share of the space left after
    the input (and the target, when ``target_lengths`` is given). The kept part
    is moved to the right end of its row, and everything before it is masked
    out. The last document of each example also keeps the (clamped) input context.

    :param input:
        2D [bsz, seqlen] input to the encoder
    :param input_lengths:
        1D [bsz] lengths of each input item
    :param query_vec:
        2D [bsz*n_turns, seqlen] input for the retriever
    :param input_turns_cnt:
        1D [bsz] number of dialogue turns for each input example
    :param target_lengths:
        optional 1D [bsz] lengths of each target item (for each input item)
    :param positions:
        optional position ids forwarded to RagModel.encoder
    :param segments:
        optional segment ids forwarded to RagModel.encoder

    :return (encoder_out, encoder_mask, input_turns_cnt, top_docs, top_doc_scores):
        encoder_out: *concatenated* encoded representations of context/document pairs
        encoder_mask: new mask for enc_out
        input_turns_cnt: pass along the input turns count for the decoder
        top_docs: List of top Documents for each batch example
        top_doc_scores: scores for each retrieved document.
    """
    enc_out, mask, input_turns_cnt, top_docs, top_doc_scores = RagModel.encoder(
        self, input, input_lengths, query_vec, input_turns_cnt, positions, segments
    )  # type: ignore
    seq_len, n_docs = enc_out.size(1), enc_out.size(0) // input.size(0)

    if input_turns_cnt is not None:
        # Input Turns is a tensor of dim [bsz]
        input = input.repeat_interleave(input_turns_cnt, dim=0)  # type: ignore

    doc_starts = (enc_out == self.pad_idx).sum(dim=1)  # left padded
    doc_lens = (seq_len - input_lengths.repeat_interleave(n_docs)) - doc_starts
    # if no padding, docs are assumed to be min doc length long
    doc_lens[doc_lens.le(0)] = self.min_doc_len
    new_enc_out = enc_out.clone()
    # BEFORE:
    # [
    #   pad...doc_0 / in_0
    #   pad...doc_1 / in_0
    #   ...
    #   pad...doc_n / in_m
    # ]
    total_length_i = 0
    for i, doc_len in enumerate(doc_lens):
        example = i // n_docs  # input example this document row belongs to
        if i % n_docs == 0:
            total_length_i = 0
        # Per-example target length (scalar); 0 when no targets are given.
        target_len_i = target_lengths[example] if target_lengths is not None else 0
        # max doc length is determined by how much space we have after subtracting input length
        input_and_target = input_lengths[example] + target_len_i
        max_doc_len = torch.div(
            (seq_len - input_and_target), n_docs, rounding_mode='floor'
        )
        max_doc_len = max(max_doc_len, self.min_doc_len)  # type: ignore
        doc_len = min(doc_len, max_doc_len)
        total_length_i += doc_len
        if i % n_docs == n_docs - 1:
            # keep the actual input context when processing the last doc.
            # Only this example's target length is reserved, so the clamp bound
            # (and hence pad_end) stays a scalar that can be used for slicing.
            clamped_input_length = input_lengths[example].clamp(
                max=self.truncate - total_length_i - target_len_i
            )
            pad_end = seq_len - clamped_input_length - doc_len
        else:
            pad_end = seq_len - doc_len
        # we simply move the portion of the doc we want to keep to the end of the tensor,
        # and mask out everything else.
        mask[i, :pad_end] = False
        new_enc_out[i, pad_end : pad_end + doc_len] = enc_out[
            i, doc_starts[i] : doc_starts[i] + doc_len
        ]

    new_out, new_mask = concat_enc_outs(
        input, new_enc_out.unsqueeze(-1), mask, 1, self.pad_idx, right_padded=False
    )
    # After:
    # [
    #   doc_0 / doc_1 / doc_2 ... in_0
    #   doc_0 / doc_1 / doc_2 ... in_m
    # ]
    return new_out, new_mask, input_turns_cnt, top_docs, top_doc_scores
def compute_loss(
    self,
    criterion: torch.nn.Module,
    scores: torch.Tensor,
    preds: torch.LongTensor,
    enc_state: Tuple[Any, ...],
    label_vec: torch.LongTensor,
) -> Tuple[torch.Tensor, List[int], torch.Tensor, torch.Tensor]:
    """
    Compute Loss for Rag Turn.

    RAG Turn Doc-Then-Turn computes loss with a normal NLL Loss
    (everything is marginalized beforehand).

    RAG Turn Doc-Only computes loss for each input turn; this loss can be
    weighted with a discount factor, applying less weight to prior turns
    (only for backpropagation purposes). The metric loss is always the
    undiscounted per-example sum over turns.

    :param criterion:
        torch criterion module (expected to return per-token losses, since the
        result is reshaped to the score layout)
    :param scores:
        model scores
    :param preds:
        model "predictions" of tokens
    :param enc_state:
        encoder states; element 2 holds the per-example turn counts
    :param label_vec:
        target tokens

    :return (loss, metric_loss, correct_tokens, target_tokens):
        loss: the loss through which we backpropagate
        metric_loss: loss we use for metrics
        correct_tokens: correct predictions from the model
        target_tokens: the ground truth tokens.
    """
    if scores.size(1) != label_vec.size(1):
        # ignore start
        scores = scores[:, 1:, :]
        preds = preds[:, 1:]  # type: ignore

    input_turns_cnt = enc_state[2]
    real_bsz = label_vec.size(0)
    # Scores carry one row per (example, turn) in doc_only mode.
    resize_label = real_bsz != scores.size(0)
    if resize_label:
        assert self.turn_marginalize == 'doc_only'
        label_vec = label_vec.repeat_interleave(
            input_turns_cnt, dim=0
        )  # type: ignore

    # compute loss
    score_view = scores.reshape(-1, scores.size(-1))
    loss = criterion(score_view, label_vec.view(-1))
    loss = loss.view(scores.shape[:-1]).sum(dim=1)
    metric_loss = loss.tolist()

    if resize_label:
        assert self.turn_marginalize == 'doc_only'
        # The metric must be reduced from the per-turn losses *before* `loss` is
        # replaced by its discounted per-example reduction; reducing the already
        # reduced tensor again would mix up turns and examples.
        metric_loss = sum_across_turns(loss, input_turns_cnt).tolist()
        loss = sum_across_turns(loss, input_turns_cnt, discount=self.discount_factor)

    # compute metric counters
    notnull = label_vec.ne(self.null_idx)
    target_tokens = metric_target_tokens = notnull.long().sum(dim=-1)
    correct = metric_correct = ((label_vec == preds) * notnull).sum(dim=-1)
    if resize_label:
        metric_target_tokens = sum_across_turns(target_tokens, input_turns_cnt)
        metric_correct = sum_across_turns(correct, input_turns_cnt)

    loss = loss.sum()
    loss /= target_tokens.sum()  # average loss per token
    return loss, metric_loss, metric_correct, metric_target_tokens