Example #1
 def _generative_step(self, batch: dict) -> dict:
     pad_token_id = self.tokenizer.pad_token_id
     source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(
         batch, pad_token_id)
     t0 = time.time()
     generated_ids = self.model.generate(
         input_ids=source_ids,
         attention_mask=source_mask,
         use_cache=True,
         decoder_start_token_id=self.decoder_start_token_id,
     )
     gen_time = (time.time() - t0) / source_ids.shape[0]
     preds = self.ids_to_clean_text(generated_ids)
     target = self.ids_to_clean_text(y)
     loss_tensors = self._step(batch)
     base_metrics = {
         name: loss
         for name, loss in zip(self.loss_names, loss_tensors)
     }
     rouge: Dict = self.calc_generative_metrics(preds, target)
     summ_len = np.mean(lmap(len, generated_ids))
     base_metrics.update(gen_time=gen_time,
                         summ_len=summ_len,
                         preds=preds,
                         target=target,
                         **rouge)
     return base_metrics
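The ids_to_clean_text and lmap helpers used above are not part of this snippet. A minimal sketch of how they are commonly defined, assuming the tokenizer exposes the standard batch_decode API (the exact implementation is an assumption, not taken from the example):

def lmap(f, x):
    # Shorthand for list(map(f, x)).
    return list(map(f, x))

def ids_to_clean_text(self, generated_ids):
    # Assumed to be a method on the same LightningModule: decode token ids
    # back to text and strip surrounding whitespace.
    gen_text = self.tokenizer.batch_decode(
        generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return lmap(str.strip, gen_text)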
Example #2
 def get_dataset(self, type_path) -> SummarizationDataset:
     n_obs = self.n_obs[type_path]
     dataset = SummarizationDataset(self.tokenizer,
                                    type_path=type_path,
                                    n_obs=n_obs,
                                    **self.dataset_kwargs)
     return dataset
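This snippet relies on self.n_obs and self.dataset_kwargs being prepared elsewhere in the module. A hypothetical construction, with hparams names assumed rather than taken from the example:

# Hypothetical setup in the LightningModule's __init__; the hparams names
# (data_dir, max_source_length, n_train, ...) are assumptions.
self.dataset_kwargs = dict(
    data_dir=self.hparams.data_dir,
    max_source_length=self.hparams.max_source_length,
)
self.n_obs = {
    "train": self.hparams.n_train,
    "val": self.hparams.n_val,
    "test": self.hparams.n_test,
}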
Example #3
    def test_step(self, batch, batch_idx):
        pad_token_id = self.tokenizer.pad_token_id
        source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(
            batch, pad_token_id)
        # NOTE: the following kwargs give faster generation but lower-quality summaries than those in evaluate_cnn.py
        generated_ids = self.model.generate(
            input_ids=source_ids,
            attention_mask=source_mask,
            num_beams=1,
            max_length=80,
            repetition_penalty=2.5,
            length_penalty=1.0,
            early_stopping=True,
            use_cache=True,
        )
        preds = [
            self.tokenizer.decode(g,
                                  skip_special_tokens=True,
                                  clean_up_tokenization_spaces=True)
            for g in generated_ids
        ]
        target = [
            self.tokenizer.decode(t,
                                  skip_special_tokens=True,
                                  clean_up_tokenization_spaces=True) for t in y
        ]
        loss = self._step(batch)

        return {"val_loss": loss, "preds": preds, "target": target}
Example #4
 def val_dataloader(self):
     val_dataset = SummarizationDataset(
         self.tokenizer,
         data_dir=self.hparams.data_dir,
         type_path="val",
         block_size=self.hparams.max_seq_length)
     return DataLoader(val_dataset, batch_size=self.hparams.eval_batch_size)
Example #5
 def get_dataloader(self, type_path: str, batch_size: int) -> DataLoader:
     dataset = SummarizationDataset(self.tokenizer,
                                    type_path=type_path,
                                    **self.dataset_kwargs)
     dataloader = DataLoader(dataset,
                             batch_size=batch_size,
                             collate_fn=dataset.collate_fn)
     return dataloader
Example #6
 def get_dataset(self, type_path) -> SummarizationDataset:
     n_obs = self.n_obs[type_path]
     max_target_length = self.target_lens[type_path]
     dataset = SummarizationDataset(
         self.tokenizer,
         type_path=type_path,
         n_obs=n_obs,
         max_target_length=max_target_length,
         **self.dataset_kwargs,
     )
     return dataset
Example #7
 def get_dataloader(self,
                    type_path: str,
                    batch_size: int,
                    shuffle: bool = False) -> DataLoader:
     dataset = SummarizationDataset(self.tokenizer,
                                    type_path=type_path,
                                    **self.dataset_kwargs)
     dataloader = DataLoader(dataset,
                             batch_size=batch_size,
                             collate_fn=dataset.collate_fn,
                             shuffle=shuffle,
                             num_workers=8)
     return dataloader
Example #8
    def _generative_entailment_step(self, batch: dict) -> Tensor:
        """
        Decodes the generated summaries and computes the entailment loss of the predictions against the reference summaries for the current training step.
        :param batch:
        :return:
        """
        pad_token_id = self.tokenizer.pad_token_id
        source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(
            batch, pad_token_id)
        t0 = time.time()
        generated_ids = self.model.generate(
            input_ids=source_ids,
            attention_mask=source_mask,
            use_cache=True,
            decoder_start_token_id=self.decoder_start_token_id,
        )
        gen_time = (time.time() - t0) / source_ids.shape[0]
        preds = self.ids_to_clean_text(generated_ids)
        target = self.ids_to_clean_text(y)

        # Pair each reference summary (premise) with its generated prediction
        # (hypothesis) and use "entailment" as the gold label, so the loss below
        # penalizes predictions the NLI model does not judge to be entailed.
        entailment_input = [
            InputExample(text_a=target[idx],
                         text_b=pred,
                         guid="",
                         label="entailment") for idx, pred in enumerate(preds)
        ]
        entailment_features = glue_convert_examples_to_features(
            entailment_input,
            tokenizer=self.entailment_tokenizer,
            label_list=['contradiction', 'neutral', 'entailment'],
            output_mode="classification")
        all_input_ids = torch.tensor(
            [f.input_ids for f in entailment_features], dtype=torch.long)
        all_attention_mask = torch.tensor(
            [f.attention_mask for f in entailment_features], dtype=torch.long)
        all_labels = torch.tensor([f.label for f in entailment_features],
                                  dtype=torch.long)

        all_input_ids = all_input_ids.to('cuda')
        all_attention_mask = all_attention_mask.to('cuda')
        all_labels = all_labels.to('cuda')

        with torch.no_grad():
            entailment_output = self.entailment_model(
                input_ids=all_input_ids,
                attention_mask=all_attention_mask,
                labels=all_labels)
            entailment_loss = entailment_output[0].detach()

        return entailment_loss
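The step above assumes that self.entailment_model and self.entailment_tokenizer have already been loaded. A minimal setup sketch, assuming an off-the-shelf MNLI checkpoint such as roberta-large-mnli (the checkpoint name is an assumption and its label order must match the label_list used above):

from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Hypothetical initialization, e.g. in the LightningModule's __init__.
# The checkpoint name is an assumption; its labels should line up with
# ['contradiction', 'neutral', 'entailment'].
self.entailment_tokenizer = AutoTokenizer.from_pretrained("roberta-large-mnli")
self.entailment_model = (
    AutoModelForSequenceClassification.from_pretrained("roberta-large-mnli")
    .eval()
    .to("cuda")
)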
Example #9
    def _generative_step(self, batch):
        pad_token_id = self.tokenizer.pad_token_id
        source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
        # TODO(SS): task specific params

        t0 = time.time()
        generated_ids = self.model.generate(
            input_ids=source_ids,
            attention_mask=source_mask,
            use_cache=True,
        )
        gen_time = time.time() - t0
        preds = self.ids_to_clean_text(generated_ids)
        target = self.ids_to_clean_text(y)
        loss_tensors = self._step(batch)
        base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        rouge: Dict = calculate_rouge(preds, target)
        summ_len = np.mean(lmap(len, generated_ids))
        base_metrics.update(gen_time=gen_time, summ_len=summ_len, preds=preds, target=target, **rouge)
        return base_metrics
Example #10
 def train_dataloader(self):
     train_dataset = SummarizationDataset(
         self.tokenizer,
         data_dir=self.hparams.data_dir,
         type_path="train",
         block_size=self.hparams.max_seq_length)
     dataloader = DataLoader(train_dataset,
                             batch_size=self.hparams.train_batch_size)
     t_total = (
         (len(dataloader.dataset) //
          (self.hparams.train_batch_size * max(1, self.hparams.n_gpu))) //
         self.hparams.gradient_accumulation_steps *
         float(self.hparams.num_train_epochs))
     scheduler = get_linear_schedule_with_warmup(
         self.opt,
         num_warmup_steps=self.hparams.warmup_steps,
         num_training_steps=t_total)
     self.lr_scheduler = scheduler
     return dataloader
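As a sanity check on the t_total arithmetic above, a small worked example with illustrative numbers (none of them come from the example itself):

# Illustrative values only: 10,000 training examples, batch size 4, 1 GPU,
# gradient accumulation of 8, 3 epochs.
dataset_size = 10_000
train_batch_size, n_gpu = 4, 1
gradient_accumulation_steps, num_train_epochs = 8, 3

t_total = (
    (dataset_size // (train_batch_size * max(1, n_gpu)))
    // gradient_accumulation_steps
    * float(num_train_epochs)
)
print(t_total)  # 936.0 scheduler steps for the linear warmup schedule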
Example #11
 def get_dataloader(self, type_path, batch_size, shuffle=False) -> DataLoader:
     dataset = SummarizationDataset(self.tokenizer, type_path=type_path, **self.dataset_kwargs)
     dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=shuffle)
     return dataloader
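dataset.collate_fn is passed to every DataLoader here but never shown. A plausible sketch, assuming each dataset item is a dict of fixed-length 1-D tensors (both the key names and the pre-padding assumption are mine, not from the examples):

import torch

def collate_fn(self, batch):
    # Stack per-example tensors into [batch_size, seq_len] tensors; sequences
    # are assumed to already be padded to a fixed length by the dataset.
    return {
        "source_ids": torch.stack([x["source_ids"] for x in batch]),
        "source_mask": torch.stack([x["source_mask"] for x in batch]),
        "target_ids": torch.stack([x["target_ids"] for x in batch]),
    }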