def handle(self) -> None: # Load archive archive = load_archive(Path(self.argument("archive"))) vocab = archive.model.vocab if self.option("log-model-info"): # Log model config logger.info("Config: {}".format( json.dumps(archive.config.as_flat_dict(), indent=2, ensure_ascii=False))) # Log model metrics log_metrics("Trained model", archive.metrics) # Parse options num_samples = int(self.option("num-samples")) items = self.parse_items() # Prepare data for the model dataset_reader_params = archive.config.get("dataset_reader") dataset_reader_params["sample_masking"] = False dataset_reader = DatasetReader.from_params(**dataset_reader_params) collate_batch = CollateBatch.by_name(dataset_reader_params.get("type")) input_dict = collate_batch( Batch([ vocab.encode(dataset_reader.item_to_instance(item)) for item in items ])).as_dict() # Set posterior samples archive.model.set_samples(samples=1) # Run it output_dict = archive.model.interpolate(input_dict["src_tokens"], samples=num_samples, random=self.option("random")) # Make it readable samples = archive.model.make_output_human_readable(output_dict) print(samples)
def evaluate(
    self,
    dataloader: DataIterator,
    desc="Validation",
    info: Dict[str, Union[float, int, str]] = None,
) -> Dict[str, float]:
    self._pytorch_model.eval()
    metrics = self._fit(dataloader, is_train=False)
    # Calculate mutual info
    metrics["mutual-info"] = self.calc_mutual_info(dataloader)
    # Add samples from the prior if needed
    if self._sampling_parameters is not None:
        metrics["samples"] = self._construct_samples_dataframe()
    # Log metrics only on master with run_on_rank_zero decorator
    training_util.log_metrics(mode_str=desc, info=info, metrics=metrics)
    # Pop samples as we do not want to save them in the checkpoint;
    # use a default so this also works when sampling is disabled.
    metrics.pop("samples", None)
    return metrics

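# The comments above mention a `run_on_rank_zero` decorator that guards logging in the
# distributed setting. Its implementation is not part of this excerpt; the following is
# only a minimal sketch of what such a decorator could look like, assuming
# torch.distributed is the backend. The actual helper in the codebase may differ.
import functools

import torch.distributed as dist


def run_on_rank_zero(func):
    """Run the wrapped function only on the master (rank-0) process."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # In the non-distributed case every process is effectively rank 0.
        if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
            return func(*args, **kwargs)
        return None

    return wrapper
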
def train(
    self,
    train_dataloader: DataIterator,
    validation_dataloader: DataIterator,
) -> Dict[str, float]:
    for epoch in range(self._epochs):
        # Train
        self._pytorch_model.train()
        logger.info("Training")
        train_metrics = self._fit(train_dataloader)
        # Log metrics only on master with run_on_rank_zero decorator
        training_util.log_metrics(
            mode_str="Training",
            info={"epoch": epoch},
            metrics=train_metrics,
        )
        # Validation
        logger.info("Validation")
        validation_metrics = self.evaluate(validation_dataloader, info={"epoch": epoch})
        if self._metric_patience:
            self._metric_patience(validation_metrics)
        # Save model state only on master
        if self._is_master:
            self._save_checkpoint(
                validation_metrics,
                is_best_so_far=self._metric_patience.improved if self._metric_patience else True,
                save_dict=self._get_save_dict(**validation_metrics),
            )
        # Wait for master process to save new checkpoint
        if self._distributed:
            dist.barrier()
        # Stop early once patience has run out
        if self._metric_patience and self._metric_patience.should_stop:
            logger.success("Patience reached. Stop training.")
            logger.info(
                "Best metrics: {}".format(
                    json.dumps(self._metric_patience.best_metrics, ensure_ascii=False, indent=2)
                )
            )
            break
    return self._metric_patience.best_metrics if self._metric_patience else validation_metrics

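# The training loops above call `self._metric_patience(validation_metrics)` and read its
# `improved`, `should_stop`, and `best_metrics` attributes. The real class is not shown in
# this excerpt; `MetricPatience` below is only a minimal compatible sketch, assuming a
# single monitored metric where lower is better (e.g. a validation loss).
from typing import Dict


class MetricPatience:
    def __init__(self, metric_name: str, patience: int = 3) -> None:
        self._metric_name = metric_name
        self._patience = patience
        self._bad_epochs = 0
        self.improved = False
        self.best_metrics: Dict[str, float] = {}

    @property
    def should_stop(self) -> bool:
        # Stop once the monitored metric has not improved for `patience` epochs in a row.
        return self._bad_epochs >= self._patience

    def __call__(self, metrics: Dict[str, float]) -> None:
        value = metrics[self._metric_name]
        best = self.best_metrics.get(self._metric_name)
        if best is None or value < best:
            self.improved = True
            self._bad_epochs = 0
            self.best_metrics = dict(metrics)
        else:
            self.improved = False
            self._bad_epochs += 1
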
def handle(self) -> None:
    # Load archive
    archive = load_archive(Path(self.argument("archive")))
    if self.option("log-model-info"):
        # Log model config
        logger.info("Config: {}".format(
            json.dumps(archive.config.as_flat_dict(), indent=2, ensure_ascii=False)))
        # Log model metrics
        log_metrics("Trained model", archive.metrics)
    num_samples = int(self.option("num-samples"))
    lengths = self.parse_lengths()
    # Sample sequences of the requested lengths from the model
    samples, samples_log_prob = archive.model.sample(num_samples, lengths)
    # TODO: Make better output by truncating <eos> tokens
    samples = archive.model.make_output_human_readable(samples)
    # Flatten samples and repeat each log prob for its group of texts
    df_dict = {"texts": [], "log_probs": []}
    for sample, log_prob in zip(samples["texts"], samples_log_prob.tolist()):
        df_dict["texts"].extend(sample)
        df_dict["log_probs"].extend([log_prob] * len(sample))
    df = pd.DataFrame(df_dict)
    print(list(df.texts))

def train(
    self,
    train_dataloader: DataIterator,
    validation_dataloader: DataIterator,
) -> Dict[str, float]:
    mi_not_improved = 0
    for epoch in range(self._epochs):
        # Train
        self._pytorch_model.train()
        logger.info("Training")
        train_metrics = self._fit(train_dataloader)
        # Log metrics only on master with run_on_rank_zero decorator
        training_util.log_metrics(
            mode_str="Training",
            info={"epoch": epoch, "aggressive": self._aggressive},
            metrics=train_metrics,
        )
        # Validation
        logger.info("Validation")
        validation_metrics = self.evaluate(
            validation_dataloader,
            info={"epoch": epoch, "aggressive": self._aggressive},
        )
        # Check mutual info to finish aggressive training if needed
        if self._aggressive and self._model.is_kl_used:
            mi_not_improved += 1
            # 5 is an expected number of aggressive epochs based on experiments from the paper
            if mi_not_improved == 5:
                self._aggressive = False
                logger.info("Stop aggressive burning.")
        if self._metric_patience:
            self._metric_patience(validation_metrics)
        # Save model state only on master
        if self._is_master:
            self._save_checkpoint(
                validation_metrics,
                is_best_so_far=self._metric_patience.improved if self._metric_patience else True,
                save_dict={
                    "model": self._model.state_dict(),
                    "encoder_optimizer": self._encoder_optimizer.state_dict(),
                    "decoder_optimizer": self._decoder_optimizer.state_dict(),
                    "encoder_scheduler": self._encoder_scheduler.state_dict(),
                    "decoder_scheduler": self._decoder_scheduler.state_dict(),
                    **validation_metrics,
                },
            )
        # Wait for master process to save new checkpoint
        if self._distributed:
            dist.barrier()
        # Stop early once patience has run out
        if self._metric_patience and self._metric_patience.should_stop:
            logger.success("Patience reached. Stop training.")
            logger.info("Best metrics: {}".format(
                json.dumps(self._metric_patience.best_metrics, ensure_ascii=False, indent=2)))
            break
    return self._metric_patience.best_metrics if self._metric_patience else validation_metrics
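
# Hypothetical usage sketch for the trainers above; only the `train` and `evaluate`
# signatures are taken from this excerpt. The `trainer` object and the dataloaders are
# assumed to be constructed elsewhere (config parsing, optimizers, distributed setup),
# and `logger`/`json` refer to the same module-level imports the functions above rely on.
def run_training(trainer, train_dataloader, validation_dataloader, test_dataloader) -> None:
    # `train` runs the full loop: per-epoch fitting, validation, checkpointing and
    # early stopping, and returns the best validation metrics.
    best_metrics = trainer.train(train_dataloader, validation_dataloader)
    logger.info("Best metrics: {}".format(
        json.dumps(best_metrics, ensure_ascii=False, indent=2)))
    # A final pass over held-out data can reuse `evaluate` with a custom description.
    trainer.evaluate(test_dataloader, desc="Test")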