def set_input_turn_cnt_vec(
    self, observation: Message, model: RagModel, query_str: str
) -> Message:
    """
    Compute the number of turns of input, and set the vec accordingly.

    :param observation:
        observation in which to set the vec
    :param model:
        model provided for access to retriever tokenizer
    :param query_str:
        the query string for computation of the input turns.

    :return observation:
        return the observation with the input turn vec set appropriately.
    """
    delimiter = model.get_retriever_delimiter()
    split_text_raw = query_str.split(delimiter)
    split_text: List[str] = []
    if self.n_turns > 1 and len(split_text_raw) > self.n_turns:
        end_off = self.n_turns - 1
        split_text = [delimiter.join(split_text_raw[:-end_off])] + split_text_raw[
            -end_off:
        ]
    else:
        split_text = split_text_raw

    input_turns_cnt = torch.LongTensor([len(split_text)])
    query_vecs = [model.tokenize_query(q) for q in split_text]
    # Override query vec
    observation.force_set('query_vec', query_vecs)
    observation['input_turn_cnt_vec'] = input_turns_cnt
    return observation
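# A minimal illustrative sketch (not part of the agent) of the turn-splitting
# above: with n_turns=3 and five delimited turns, the oldest turns are merged
# so exactly n_turns query strings remain. The '\n' delimiter and turn strings
# are assumptions; the real delimiter comes from model.get_retriever_delimiter().
def _demo_split_turns():
    delimiter, n_turns = '\n', 3
    split_text_raw = ['turn1', 'turn2', 'turn3', 'turn4', 'turn5']
    end_off = n_turns - 1
    split_text = [delimiter.join(split_text_raw[:-end_off])] + split_text_raw[-end_off:]
    # the three oldest turns collapse into one; the newest two stay separate
    assert split_text == ['turn1\nturn2\nturn3', 'turn4', 'turn5']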
def build_regret_model(self) -> RagModel:
    """
    Build and return regret RagModel.

    Assume dictionary is the same.
    """
    model_file = self.opt['regret_model_file']
    if model_file:
        assert os.path.exists(
            model_file
        ), 'specify correct path for --regret-model-file'
        regret_opt = Opt.load(f'{model_file}.opt')
        regret_opt['n_docs'] = self.opt['n_docs']  # Urgent that this is the same
        # add keys that were not in this model when originally trained
        regret_opt.update({k: v for k, v in self.opt.items() if k not in regret_opt})
        retriever_shared = None
        if all(
            regret_opt[k] == self.opt[k]
            for k in ['rag_retriever_type', 'path_to_index', 'path_to_dpr_passages']
        ):
            logging.warning('Sharing retrievers between model and regret model!')
            retriever_shared = self.model.encoder.retriever.share()

        model = RagModel(regret_opt, self.dict, retriever_shared=retriever_shared)
        with PathManager.open(self.opt['regret_model_file'], 'rb') as f:
            states = torch.load(
                f,
                map_location=lambda cpu, _: cpu,
                pickle_module=parlai.utils.pickle,
            )
        assert 'model' in states
        model.load_state_dict(states['model'])
        # operate on the local `model` here; self.regret_model is not yet
        # assigned when this builder runs.
        if self.model_parallel:
            ph = PipelineHelper()
            ph.check_compatibility(self.opt)
            model = ph.make_parallel(model)
        else:
            model.cuda()
        if self.fp16:
            model = model.half()

        sync_parameters(model)
        train_params = trainable_parameters(model)
        total_params = total_parameters(model)
        logging.info(
            f"Total regret parameters: {total_params:,d} ({train_params:,d} trainable)"
        )
    else:
        model = self.model

    return model
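# A minimal sketch (plain dicts stand in for ParlAI's Opt) of the opt backfill
# performed above: regret_opt.update(...) only fills keys missing from the
# loaded opt, so the regret model's original settings win on any conflict.
def _demo_opt_backfill():
    regret_opt = {'n_docs': 5, 'beam_size': 3}
    current_opt = {'n_docs': 10, 'beam_size': 1, 'fp16': True}
    regret_opt.update({k: v for k, v in current_opt.items() if k not in regret_opt})
    assert regret_opt == {'n_docs': 5, 'beam_size': 3, 'fp16': True}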
def augment_batch_for_generation(self, batch: Batch, model: RagModel) -> Batch:
    """
    Augment batch for generation.

    For RAG Sequence, we retrieve prior to generation, as we do not consider the
    document probabilities until after generating all of the beams.

    :param batch:
        batch to augment
    :param model:
        model to possibly help with augmenting

    :return batch:
        return batch with text vec swapped out.
    """
    (expanded_input, _, doc_scores) = model.retrieve_and_concat(
        batch.text_vec,
        batch.text_vec.ne(self.null_idx).sum(1),
        batch.query_vec,
        batch.input_turn_cnt_vec,
    )
    doc_log_probs = F.log_softmax(doc_scores, dim=1)
    batch.src_text_vec = batch.text_vec
    batch.text_vec = expanded_input
    batch.doc_log_probs = doc_log_probs
    batch.batchsize = batch.text_vec.size(0)
    return batch
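# A minimal sketch of the normalization above: F.log_softmax(doc_scores, dim=1)
# converts raw retrieval scores into per-example log p(doc | input) over the
# n_docs dimension. The scores and shapes (bsz=2, n_docs=3) are assumptions.
def _demo_doc_log_probs():
    import torch
    import torch.nn.functional as F

    doc_scores = torch.tensor([[2.0, 1.0, 0.0], [0.5, 0.5, 0.5]])  # [bsz, n_docs]
    doc_log_probs = F.log_softmax(doc_scores, dim=1)
    # each row is a proper distribution over documents
    assert torch.allclose(doc_log_probs.exp().sum(dim=1), torch.ones(2))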
@classmethod
def thorough_generation(
    cls,
    hyps: List[torch.LongTensor],
    new_input: torch.LongTensor,
    null_idx: int,
    model: RagModel,
) -> List[Tuple[torch.LongTensor, torch.Tensor]]:
    """
    Apply RAG-sequence thorough generation for a single batch item.

    Recomputes model scores with given hypotheses, sorts accordingly.

    :param hyps:
        list of candidate hypotheses
    :param new_input:
        input for the model

    :return sorted_hyps:
        return list of (hyp, score) tuples, sorted by their score.
    """
    # deduplicate, exclude BOS Token
    hyps = list({str(h.tolist()): h[1:] for h in hyps}.values())  # type: ignore
    new_input = new_input.repeat(len(hyps), 1)  # type: ignore
    new_ys, _ = padded_tensor(
        hyps, fp16friendly=new_input.size(1) % FP16_PAD_SIZE == 0, pad_idx=null_idx
    )
    new_ys = new_ys.to(new_input.device)
    scores, *_ = model.seq2seq_forward_pass(new_input, new_ys)
    loss = cls._rag_sequence_loss(
        new_ys.unsqueeze(1).unsqueeze(-1), scores.unsqueeze(1), null_idx
    )  # type: ignore
    sorted_by_score = [
        (hyps[idx], loss[idx]) for idx in loss.sort()[-1]
    ]  # sort by loss, ascending
    return sorted_by_score
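# A minimal sketch of the dedup-and-strip step above: keying a dict by the
# stringified token list removes duplicate hypotheses, and h[1:] drops the BOS
# token in the same pass. Token ids here are made up for illustration.
def _demo_dedup_hyps():
    import torch

    hyps = [
        torch.LongTensor([1, 5, 6]),  # duplicate of the next entry
        torch.LongTensor([1, 5, 6]),
        torch.LongTensor([1, 7]),
    ]
    deduped = list({str(h.tolist()): h[1:] for h in hyps}.values())
    assert [h.tolist() for h in deduped] == [[5, 6], [7]]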
def build_rag_model(opt: Opt, dictionary: DictionaryAgent) -> RagModel:
    return RagModel(opt, dictionary)
def encoder(
    self,
    input: torch.LongTensor,
    input_lengths: torch.LongTensor,
    query_vec: torch.LongTensor,
    input_turns_cnt: torch.LongTensor,
    target_lengths: Optional[torch.LongTensor],
    positions: Optional[torch.LongTensor] = None,
    segments: Optional[torch.LongTensor] = None,
) -> Tuple[
    torch.Tensor,
    torch.BoolTensor,
    Optional[torch.LongTensor],
    Optional[List[List[Document]]],
    Optional[torch.Tensor],
]:
    """
    Override FidModel.encoder to pack all the documents into one input example.

    :param input:
        2D [bsz, seqlen] input to the encoder
    :param input_lengths:
        1D [bsz] lengths of each input item
    :param query_vec:
        2D [bsz*n_turns, seqlen] input for the retriever
    :param input_turns_cnt:
        1D [bsz] number of dialogue turns for each input example
    :param target_lengths:
        1D [bsz] lengths of each target item (for each input item)

    :return (encoder_out, encoder_mask, input_turns_cnt, top_docs, top_doc_scores):
        encoder_out: *concatenated* encoded representations of context/document pairs
        encoder_mask: new mask for enc_out
        input_turns_cnt: pass along the input turns count for the decoder
        top_docs: List of top Documents for each batch example
        top_doc_scores: scores for each retrieved document.
    """
    enc_out, mask, input_turns_cnt, top_docs, top_doc_scores = RagModel.encoder(
        self, input, input_lengths, query_vec, input_turns_cnt, positions, segments
    )  # type: ignore
    seq_len, n_docs = enc_out.size(1), enc_out.size(0) // input.size(0)
    if input_turns_cnt is not None:
        # Input Turns is a tensor of dim [bsz]
        input = input.repeat_interleave(input_turns_cnt, dim=0)  # type: ignore

    doc_starts = (enc_out == self.pad_idx).sum(dim=1)  # left padded
    doc_lens = (seq_len - input_lengths.repeat_interleave(n_docs)) - doc_starts
    # if no padding, docs are assumed to be min doc length long
    doc_lens[doc_lens.le(0)] = self.min_doc_len
    new_enc_out = enc_out.clone()
    # BEFORE:
    # [
    #   pad...doc_0 / in_0
    #   pad...doc_1 / in_0
    #   ...
    #   pad...doc_n / in_m
    # ]
    total_length_i = 0
    for i, doc_len in enumerate(doc_lens):
        if i % n_docs == 0:
            total_length_i = 0
        # max doc length is determined by how much space we have after
        # subtracting input length
        input_and_target = input_lengths[i // n_docs] + (
            target_lengths[i // n_docs] if target_lengths is not None else 0
        )
        max_doc_len = torch.div(
            (seq_len - input_and_target), n_docs, rounding_mode='floor'
        )
        max_doc_len = max(max_doc_len, self.min_doc_len)  # type: ignore
        doc_len = min(doc_len, max_doc_len)
        total_length_i += doc_len
        if i % n_docs == n_docs - 1:
            # keep the actual input context when processing the last doc.
            clamped_input_length = input_lengths[i // n_docs].clamp(
                max=self.truncate - total_length_i
            )
            if target_lengths is not None:
                clamped_input_length = input_lengths[i // n_docs].clamp(
                    max=self.truncate - total_length_i - target_lengths[i // n_docs]
                )
            pad_end = seq_len - clamped_input_length - doc_len
        else:
            pad_end = seq_len - doc_len
        # we simply move the portion of the doc we want to keep to the end of
        # the tensor, and mask out everything else.
        mask[i, :pad_end] = False
        new_enc_out[i, pad_end : pad_end + doc_len] = enc_out[
            i, doc_starts[i] : doc_starts[i] + doc_len
        ]
    new_out, new_mask = concat_enc_outs(
        input, new_enc_out.unsqueeze(-1), mask, 1, self.pad_idx, right_padded=False
    )
    # AFTER:
    # [
    #   doc_0 / doc_1 / doc_2 ... in_0
    #   doc_0 / doc_1 / doc_2 ... in_m
    # ]
    return new_out, new_mask, input_turns_cnt, top_docs, top_doc_scores
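# A minimal sketch of the per-document budget computed in the loop above: with
# an assumed seq_len=128, input length 40, no target, and n_docs=4, each
# document keeps at most (128 - 40) // 4 = 22 tokens, floored and subject to
# the min_doc_len lower bound. All values here are assumptions.
def _demo_doc_budget():
    import torch

    seq_len, n_docs, min_doc_len = 128, 4, 8
    input_and_target = torch.tensor(40)  # input_lengths[i // n_docs], no target
    max_doc_len = torch.div(seq_len - input_and_target, n_docs, rounding_mode='floor')
    max_doc_len = max(max_doc_len, min_doc_len)
    assert int(max_doc_len) == 22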