Example #1
 def handle(self) -> None:
     # Load archive
     archive = load_archive(Path(self.argument("archive")))
     vocab = archive.model.vocab
     if self.option("log-model-info"):
         # Log model config
         logger.info("Config: {}".format(
             json.dumps(archive.config.as_flat_dict(),
                        indent=2,
                        ensure_ascii=False)))
         # Log model metrics
         log_metrics("Trained model", archive.metrics)
     # Parse options
     num_samples = int(self.option("num-samples"))
     items = self.parse_items()
     # Prepare data for the model
     dataset_reader_params = archive.config.get("dataset_reader")
     dataset_reader_params["sample_masking"] = False
     dataset_reader = DatasetReader.from_params(**dataset_reader_params)
     collate_batch = CollateBatch.by_name(dataset_reader_params.get("type"))
     input_dict = collate_batch(
         Batch([
             vocab.encode(dataset_reader.item_to_instance(item))
             for item in items
         ])).as_dict()
     # Set posterior samples
     archive.model.set_samples(samples=1)
     # Run it
     output_dict = archive.model.interpolate(input_dict["src_tokens"],
                                             samples=num_samples,
                                             random=self.option("random"))
     # Make it readable
     samples = archive.model.make_output_human_readable(output_dict)
     print(samples)
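
The `parse_items` helper is not shown here. A plausible minimal sketch, assuming the raw items arrive as a single comma-separated CLI option (the `items` option name is hypothetical, not confirmed by the source):

 def parse_items(self) -> List[str]:
     # Hypothetical helper: split a comma-separated --items option
     # into individual text items for the dataset reader
     raw = self.option("items")  # assumed option name
     return [item.strip() for item in raw.split(",") if item.strip()]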
Example #2
 def evaluate(
     self,
     dataloader: DataIterator,
     desc: str = "Validation",
     info: Optional[Dict[str, Union[float, int, str]]] = None,
 ) -> Dict[str, float]:
     self._pytorch_model.eval()
     metrics = self._fit(dataloader, is_train=False)
     # Calculate mutual info
     metrics["mutual-info"] = self.calc_mutual_info(dataloader)
     # Add samples from the prior if needed
     if self._sampling_parameters is not None:
         metrics["samples"] = self._construct_samples_dataframe()
     # Log metrics only on master with run_on_rank_zero decorator
     training_util.log_metrics(mode_str=desc, info=info, metrics=metrics)
     # Pop samples, as we do not want to save them in the checkpoint;
     # use a default since they are only present when sampling is configured
     metrics.pop("samples", None)
     return metrics
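
The comment above refers to a `run_on_rank_zero` decorator that is not shown. A minimal sketch of such a decorator, assuming `functools` and `torch.distributed as dist` are in scope:

 def run_on_rank_zero(func):
     # Sketch: call the wrapped function only on the rank-0 (master)
     # process, so metrics are logged once per distributed run
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
         if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
             return None
         return func(*args, **kwargs)
     return wrapper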
Example #3
 def train(
     self,
     train_dataloader: DataIterator,
     validation_dataloader: DataIterator,
 ) -> Dict[str, float]:
     for epoch in range(self._epochs):
         # Train
         self._pytorch_model.train()
         logger.info("Training")
         train_metrics = self._fit(train_dataloader)
         # Log metrics only on master with run_on_rank_zero decorator
         training_util.log_metrics(
             mode_str="Training",
             info={"epoch": epoch},
             metrics=train_metrics,
         )
         # Validation
         logger.info("Validation")
         validation_metrics = self.evaluate(validation_dataloader, info={"epoch": epoch})
         if self._metric_patience:
             self._metric_patience(validation_metrics)
         # Save model state only on master
         if self._is_master:
             self._save_checkpoint(
                 validation_metrics,
                 is_best_so_far=self._metric_patience.improved if self._metric_patience else True,
                 save_dict=self._get_save_dict(**validation_metrics),
             )
         # Wait for master process to save new checkpoint
         if self._distributed:
             dist.barrier()
         if self._metric_patience and self._metric_patience.should_stop:
             logger.success("Patience reached. Stop training.")
             logger.info(
                 "Best metrics: {}".format(
                     json.dumps(self._metric_patience.best_metrics, ensure_ascii=False, indent=2)
                 )
             )
             break
     return self._metric_patience.best_metrics if self._metric_patience else validation_metrics
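
`self._metric_patience` is used as a callable exposing `improved`, `should_stop`, and `best_metrics`; its implementation is not shown. A minimal early-stopping sketch matching that interface, assuming a lower-is-better metric such as a loss:

 class MetricPatience:
     # Hypothetical early-stopping helper matching the interface above
     def __init__(self, metric: str = "loss", patience: int = 3) -> None:
         self._metric = metric
         self._patience = patience
         self._num_bad_epochs = 0
         self.improved = False
         self.should_stop = False
         self.best_metrics: Dict[str, float] = {}

     def __call__(self, metrics: Dict[str, float]) -> None:
         best = self.best_metrics.get(self._metric)
         self.improved = best is None or metrics[self._metric] < best
         if self.improved:
             self.best_metrics = dict(metrics)
             self._num_bad_epochs = 0
         else:
             self._num_bad_epochs += 1
         self.should_stop = self._num_bad_epochs >= self._patience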
Example #4
 def handle(self) -> None:
     # Load archive
     archive = load_archive(Path(self.argument("archive")))
     if self.option("log-model-info"):
         # Log model config
         logger.info("Config: {}".format(
             json.dumps(archive.config.as_flat_dict(),
                        indent=2,
                        ensure_ascii=False)))
         # Log model metrics
         log_metrics("Trained model", archive.metrics)
     num_samples = int(self.option("num-samples"))
     lengths = self.parse_lengths()
     samples, samples_log_prob = archive.model.sample(num_samples, lengths)
     # TODO: Make better output by truncating <eos> tokens
     samples = archive.model.make_output_human_readable(samples)
     df_dict = {"texts": [], "log_probs": []}
     for sample, log_prob in zip(samples["texts"],
                                 samples_log_prob.tolist()):
         df_dict["texts"].extend(sample)
         df_dict["log_probs"].extend([log_prob] * len(sample))
     df = pd.DataFrame(df_dict)
     print(df.texts.tolist())
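
One way to address the TODO above is to cut each sampled token sequence at the first end-of-sequence marker before joining it into text. A small hypothetical helper (the `<eos>` literal is taken from the comment; a token-list representation is assumed):

 def truncate_at_eos(tokens: List[str], eos: str = "<eos>") -> List[str]:
     # Hypothetical helper for the TODO above: keep tokens only up to
     # the first end-of-sequence marker
     return tokens[: tokens.index(eos)] if eos in tokens else tokens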
Example #5
 def train(
     self,
     train_dataloader: DataIterator,
     validation_dataloader: DataIterator,
 ) -> Dict[str, float]:
     mi_not_improved = 0
     for epoch in range(self._epochs):
         # Train
         self._pytorch_model.train()
         logger.info("Training")
         train_metrics = self._fit(train_dataloader)
         # Log metrics only on master with run_on_rank_zero decorator
         training_util.log_metrics(
             mode_str="Training",
             info={"epoch": epoch, "aggressive": self._aggressive},
             metrics=train_metrics,
         )
         # Validation
         logger.info("Validation")
         validation_metrics = self.evaluate(
             validation_dataloader,
             info={"epoch": epoch, "aggressive": self._aggressive},
         )
         # Decide whether aggressive training can be stopped: once the
         # model starts using the KL term, count epochs up to a fixed budget
         if self._aggressive and self._model.is_kl_used:
             mi_not_improved += 1
             # Five is the expected number of aggressive epochs, based on experiments from the paper
             if mi_not_improved == 5:
                 self._aggressive = False
                 logger.info("Stopping aggressive burn-in.")
         if self._metric_patience:
             self._metric_patience(validation_metrics)
         # Save model state only on master
         if self._is_master:
             self._save_checkpoint(
                 validation_metrics,
                 is_best_so_far=self._metric_patience.improved if self._metric_patience else True,
                 save_dict={
                     "model": self._model.state_dict(),
                     "encoder_optimizer": self._encoder_optimizer.state_dict(),
                     "decoder_optimizer": self._decoder_optimizer.state_dict(),
                     "encoder_scheduler": self._encoder_scheduler.state_dict(),
                     "decoder_scheduler": self._decoder_scheduler.state_dict(),
                     **validation_metrics,
                 },
             )
         # Wait for master process to save new checkpoint
         if self._distributed:
             dist.barrier()
         if self._metric_patience and self._metric_patience.should_stop:
             logger.success("Patience reached. Stop training.")
             logger.info("Best metrics: {}".format(
                 json.dumps(self._metric_patience.best_metrics,
                            ensure_ascii=False,
                            indent=2)))
             break
     return self._metric_patience.best_metrics if self._metric_patience else validation_metrics
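
For context, "aggressive" training alternates extra encoder-only updates with regular joint updates, which is why separate encoder and decoder optimizers and schedulers appear in the checkpoint above. A rough sketch of one aggressive step in that style, assuming the model returns a dict with a "loss" entry; the inner-loop budget attribute is hypothetical, not taken from the source:

 def _aggressive_step(self, batch) -> None:
     # Sketch: several encoder-only updates, then one decoder-only update
     for _ in range(self._inner_encoder_steps):  # hypothetical budget
         self._encoder_optimizer.zero_grad()
         self._decoder_optimizer.zero_grad()
         self._model(**batch)["loss"].backward()
         self._encoder_optimizer.step()
     self._encoder_optimizer.zero_grad()
     self._decoder_optimizer.zero_grad()
     self._model(**batch)["loss"].backward()
     self._decoder_optimizer.step()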