def _generative_step(self, batch: dict) -> dict:
    """Run generation on one batch and gather losses, metrics and timing.

    Args:
        batch: Mapping read here for ``"input_ids"``, ``"attention_mask"``
            and ``"decoder_input_ids"``; it is also passed as-is to
            ``self._step``.

    Returns:
        A dict with one entry per name in ``self.loss_names`` (paired
        positionally with the values returned by ``self._step``), plus:

        * ``gen_time`` -- seconds spent in ``self.model.generate`` divided
          by ``batch["input_ids"].shape[0]``.
        * ``gen_len`` -- ``np.mean`` of ``len`` applied to each generated
          sequence (NOTE(review): presumably includes any padding ids;
          confirm against ``generate`` output).
        * ``preds`` / ``target`` -- decoded text for the generated ids and
          for ``batch["decoder_input_ids"]``.
        * every key returned by ``self.calc_generative_metrics``.
    """
    # perf_counter() is monotonic and high resolution; time.time() follows the
    # wall clock, which may be adjusted mid-measurement and skew the interval.
    start = time.perf_counter()
    generated_ids = self.model.generate(
        batch["input_ids"],
        attention_mask=batch["attention_mask"],
        use_cache=True,
        decoder_start_token_id=self.decoder_start_token_id,
    )
    # Normalise by the size of the first dimension of input_ids.
    gen_time = (time.perf_counter() - start) / batch["input_ids"].shape[0]

    preds: List[str] = self.ids_to_clean_text(generated_ids)
    target: List[str] = self.ids_to_clean_text(batch["decoder_input_ids"])

    loss_tensors = self._step(batch)
    base_metrics = dict(zip(self.loss_names, loss_tensors))

    rouge: Dict = self.calc_generative_metrics(preds, target)
    summ_len = np.mean(lmap(len, generated_ids))
    base_metrics.update(
        gen_time=gen_time,
        gen_len=summ_len,
        preds=preds,
        target=target,
        **rouge,
    )
    return base_metrics
def ids_to_clean_text(self, generated_ids: List[int]):
    """Decode token ids to text and strip surrounding whitespace.

    Args:
        generated_ids: Ids handed directly to ``self.tokenizer.batch_decode``.
            NOTE(review): the annotation says ``List[int]``, but the value is
            batch-decoded -- presumably a batch of id sequences; confirm.

    Returns:
        The result of ``lmap(str.strip, ...)`` over the decoded strings.
    """
    # Special tokens are skipped and tokenization spaces cleaned up by the
    # tokenizer itself; str.strip then trims each decoded string.
    return lmap(
        str.strip,
        self.tokenizer.batch_decode(
            generated_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        ),
    )