def _generative_step(self, batch: dict) -> dict:
    """Generate predictions for one batch and gather loss and generation metrics.

    The batch is wrapped in a ``BatchEncoding`` and moved to the model's
    device, ``self.model.generate`` is run with deduplication disabled, and
    both the generated ids and ``batch["decoder_input_ids"]`` are decoded to
    text through ``self.ids_to_clean_text``. The values returned by
    ``self._step`` are keyed by ``self.loss_names`` and merged with the output
    of ``self.calc_generative_metrics``.

    Args:
        batch: Mapping that provides at least ``input_ids``,
            ``attention_mask`` and ``decoder_input_ids``.

    Returns:
        The named loss values plus ``gen_time`` (elapsed generation time
        divided by ``input_ids.shape[0]``), ``gen_len`` (mean of ``len`` over
        the generated ids), ``preds``, ``target`` and every entry of the
        generative metrics.
    """
    tic = time.time()
    batch = BatchEncoding(batch).to(device=self.model.device)

    generated_ids = self.model.generate(
        batch["input_ids"],
        attention_mask=batch["attention_mask"],
        do_deduplication=False,  # rag specific parameter
        use_cache=True,
        min_length=1,
        max_length=self.target_lens["val"],
    )
    # Only the generate() call is timed; the result is averaged over the
    # first dimension of input_ids.
    elapsed = time.time() - tic
    gen_time = elapsed / batch["input_ids"].shape[0]

    preds: List[str] = self.ids_to_clean_text(generated_ids)
    target: List[str] = self.ids_to_clean_text(batch["decoder_input_ids"])

    # Pair each loss name with the matching value returned by _step.
    metrics = dict(zip(self.loss_names, self._step(batch)))

    gen_metrics: Dict = self.calc_generative_metrics(preds, target)
    mean_gen_len = np.mean(lmap(len, generated_ids))

    metrics.update(
        gen_time=gen_time,
        gen_len=mean_gen_len,
        preds=preds,
        target=target,
        **gen_metrics,
    )
    return metrics
def _generative_step(self, batch: dict) -> dict:
    """Generate predictions for one batch and gather loss and generation metrics.

    The batch is wrapped in a ``BatchEncoding`` and moved to the model's
    device, then ``self.model.generate`` produces the predicted ids. The
    reference ids come from ``batch["decoder_input_ids"]``; every ``-100``
    entry is replaced by ``self.tokenizer.pad_token_id`` before decoding.
    The values returned by ``self._step`` are keyed by ``self.loss_names``
    and merged with the output of ``self.calc_generative_metrics``.

    Args:
        batch: Mapping that provides at least ``input_ids``,
            ``attention_mask`` and ``decoder_input_ids``.

    Returns:
        The named loss values plus ``gen_time`` (elapsed generation time
        divided by ``input_ids.shape[0]``), ``gen_len`` (mean of ``len`` over
        the generated ids), ``preds``, ``target`` and every entry of the
        generative metrics.
    """
    # NOTE: a disabled variant copied batch['labels'] into
    # batch['decoder_input_ids'] and deleted batch['labels'] at this point.
    tic = time.time()
    batch = BatchEncoding(batch).to(device=self.model.device)

    generated_ids = self.model.generate(
        batch["input_ids"],
        attention_mask=batch["attention_mask"],
        use_cache=True,
        min_length=1,
        max_length=self.target_lens["val"],
    )
    # Only the generate() call is timed; the result is averaged over the
    # first dimension of input_ids.
    elapsed = time.time() - tic
    gen_time = elapsed / batch["input_ids"].shape[0]

    preds: List[str] = self.ids_to_clean_text(generated_ids)

    # Swap -100 entries (presumably ignore-index label padding -- confirm
    # against the data collator) for the pad token id so they can be decoded.
    reference_ids = batch["decoder_input_ids"].cpu().detach().numpy()
    reference_ids = np.where(
        reference_ids == -100, self.tokenizer.pad_token_id, reference_ids
    )
    target: List[str] = self.ids_to_clean_text(reference_ids)

    # Pair each loss name with the matching value returned by _step.
    metrics = dict(zip(self.loss_names, self._step(batch)))

    gen_metrics: Dict = self.calc_generative_metrics(preds, target)
    mean_gen_len = np.mean(lmap(len, generated_ids))

    metrics.update(
        gen_time=gen_time,
        gen_len=mean_gen_len,
        preds=preds,
        target=target,
        **gen_metrics,
    )
    return metrics
def ids_to_clean_text(self, generated_ids: List[int]):
    """Decode a batch of token ids into whitespace-stripped strings.

    Args:
        generated_ids: Batch of token id sequences handed to
            ``self.tokenizer.batch_decode``.

    Returns:
        The result of mapping ``str.strip`` over the decoded strings via
        ``lmap``.
    """
    # Special tokens are dropped and tokenization spacing is cleaned up
    # by the tokenizer itself.
    decode_options = {
        "skip_special_tokens": True,
        "clean_up_tokenization_spaces": True,
    }
    decoded = self.tokenizer.batch_decode(generated_ids, **decode_options)
    return lmap(str.strip, decoded)