def augment_batch_for_generation(self, batch: Batch, model: RagModel) -> Batch:
    """
    Prepare a batch for RAG Sequence generation.

    Retrieval is performed up front, before any decoding, because the
    document probabilities are not taken into account until every beam
    has been generated.

    The batch is mutated in place:

    - ``src_text_vec`` receives the original ``text_vec``
    - ``text_vec`` is replaced by the expanded input returned from
      ``model.retrieve_and_concat``
    - ``doc_log_probs`` holds the log-softmax (over dim 1) of the
      returned document scores
    - ``batchsize`` becomes the size of dim 0 of the new ``text_vec``

    :param batch:
        batch to augment
    :param model:
        model providing ``retrieve_and_concat``

    :return batch:
        the same batch object, with its text vec swapped out.
    """
    original_text_vec = batch.text_vec
    # Per-row count of entries that differ from the null index.
    input_lengths = original_text_vec.ne(self.null_idx).sum(1)

    # The middle element of the returned triple is not needed here.
    expanded_input, _unused, doc_scores = model.retrieve_and_concat(
        original_text_vec,
        input_lengths,
        batch.query_vec,
        batch.input_turn_cnt_vec,
    )
    normalized_doc_scores = F.log_softmax(doc_scores, dim=1)

    # Preserve the un-expanded input before overwriting text_vec.
    batch.src_text_vec = original_text_vec
    batch.text_vec = expanded_input
    batch.doc_log_probs = normalized_doc_scores
    batch.batchsize = batch.text_vec.size(0)
    return batch
def augment_batch_for_generation(self, batch: Batch, model: RagModel) -> Batch:
    """
    Prepare a batch for ``doc_only`` turn marginalization.

    When ``self.turn_marginalize`` is anything other than ``'doc_only'``
    the batch is returned untouched. Otherwise the batch is mutated in
    place:

    - ``batchsize`` is set to the sum of ``input_turn_cnt_vec`` (as a
      Python scalar), which lets this interact nicely with
      TGA._generate
    - ``src_text_vec`` is set to the current ``text_vec``
    - ``input_turns_cnt`` is set to ``input_turn_cnt_vec``

    ``src_text_vec`` and ``input_turns_cnt`` are each used during beam
    re-ranking.

    :param batch:
        batch to augment
    :param model:
        model to possibly help with augmenting (not used by this
        implementation)

    :return batch:
        the same batch object, with appropriate augmentations.
    """
    # Guard clause: only the doc_only mode needs any augmentation.
    if self.turn_marginalize != 'doc_only':
        return batch

    turn_counts = batch.input_turn_cnt_vec
    batch.batchsize = turn_counts.sum().item()
    batch.src_text_vec = batch.text_vec
    batch.input_turns_cnt = turn_counts
    return batch