def run(self) -> Type[torch.nn.Module]:
    """
    Trains a neural network, prints the statistics, and saves the
    final model weights. One can also pass kwargs for the
    utils.datatransform class to perform data augmentation "on-the-fly".
    """
    for e in range(self.training_cycles):
        # Single training step: full-epoch mode or batch-by-batch mode
        if self.full_epoch:
            self.step_full()
        else:
            self.step(e)
        # Store a copy of the current weights for stochastic weight averaging
        if self.swa:
            self.save_running_weights(e)
        if self.perturb_weights:
            self.weight_perturbation(e)
        # Print statistics on the first, last, and every print_loss-th cycle
        if any([e == 0,
                (e + 1) % self.print_loss == 0,
                e == self.training_cycles - 1]):
            self.print_statistics(e)
    # Save final model weights
    self.save_model(self.filename + "_metadict_final")
    if not self.full_epoch:
        self.eval_model()
    if self.swa:
        print("Performing stochastic weight averaging...")
        self.net.load_state_dict(average_weights(self.running_weights))
        self.eval_model()
    if self.plot_training_history:
        plot_losses(self.loss_acc["train_loss"], self.loss_acc["test_loss"])
    return self.net
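# The average_weights() helper used above is not shown in this section. Below
# is a minimal sketch of what such a state-dict averaging function could look
# like -- an assumption for illustration, not the module's confirmed
# implementation -- assuming the saved running weights are stored as a dict
# mapping checkpoint indices to OrderedDict state dicts.

import copy
from collections import OrderedDict
from typing import Dict

import torch


def average_weights_sketch(checkpoints: Dict[int, OrderedDict]) -> OrderedDict:
    """Element-wise average of parameter tensors across saved checkpoints."""
    keys = list(checkpoints.keys())
    averaged = copy.deepcopy(checkpoints[keys[0]])
    for name in averaged:
        # Stack the matching tensor from every checkpoint and take the mean.
        # Note: integer buffers (e.g. BatchNorm's num_batches_tracked) are
        # averaged as floats in this sketch.
        stacked = torch.stack([checkpoints[k][name].float() for k in keys])
        averaged[name] = stacked.mean(dim=0)
    return averaged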
def train_ensemble_from_scratch(self) -> ensemble_out:
    """
    Trains an ensemble of models, each trained from scratch with a
    different initialization (for both the weights and the batch shuffling)
    """
    print("Training ensemble models:")
    for i in range(self.n_models):
        print("Ensemble model {}".format(i + 1))
        # Different seeds give each member its own weight initialization
        # and batch shuffling order
        trainer_i = self.train_baseline(seed=i + 1, batch_seed=i + 1)
        self.ensemble_state_dict[i] = trainer_i.net.state_dict()
        self.save_ensemble_metadict(trainer_i.meta_state_dict)
    # Load the average of all members' weights into the last trained network
    averaged_weights = average_weights(self.ensemble_state_dict)
    trainer_i.net.load_state_dict(averaged_weights)
    return self.ensemble_state_dict, trainer_i.net
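# The method above returns the per-model state dicts together with a network
# loaded with the averaged weights. A hedged usage sketch for the returned
# ensemble follows; it assumes only what the return signature shows (a
# torch.nn.Module plus a dict of state dicts). ensemble_predict_sketch is a
# hypothetical helper, not part of this module.

import torch


def ensemble_predict_sketch(net: torch.nn.Module,
                            ensemble_state_dict: dict,
                            X: torch.Tensor):
    """Runs every ensemble member on X; returns prediction mean and variance."""
    preds = []
    net.eval()
    with torch.no_grad():
        for state_dict in ensemble_state_dict.values():
            net.load_state_dict(state_dict)  # swap in one member's weights
            preds.append(net(X))
    stacked = torch.stack(preds)
    # Mean serves as the ensemble prediction; variance as an uncertainty proxy
    return stacked.mean(dim=0), stacked.var(dim=0)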
def train_from_baseline(self,
                        basemodel: Union[OrderedDict, Type[torch.nn.Module]],
                        **kwargs: Dict) -> ensemble_out:
    """
    Trains an ensemble of models, each starting from the same baseline weights

    Args:
        basemodel (pytorch object):
            Baseline model or baseline weights
        **kwargs:
            Updates kwargs from the ensemble class initialization
            (can be useful for iterative training)
    """
    # Update the stored keyword arguments with any overrides
    for k, v in kwargs.items():
        self.kdict[k] = v
    if isinstance(basemodel, OrderedDict):
        initial_model_state_dict = copy.deepcopy(basemodel)
    else:
        initial_model_state_dict = copy.deepcopy(basemodel.state_dict())
    n_models = kwargs.get("n_models")
    if n_models is not None:
        self.n_models = n_models
    if "print_loss" not in self.kdict:
        self.kdict["print_loss"] = 10
    filename = kwargs.get("filename")
    if filename is not None:
        self.filename = filename
    training_cycles_ensemble = kwargs.get("training_cycles_ensemble")
    if training_cycles_ensemble is not None:
        self.iter_ensemble = training_cycles_ensemble
    print("Training ensemble models:")
    for i in range(self.n_models):
        print("Ensemble model {}".format(i + 1))
        trainer_i = trainer(
            self.X_train, self.y_train, self.X_test, self.y_test,
            self.iter_ensemble, self.model_type, batch_seed=i + 1,
            plot_training_history=False, **self.kdict)
        # Start each ensemble member from the baseline weights
        self.update_weights(trainer_i.net.state_dict().values(),
                            initial_model_state_dict.values())
        trained_model_i = trainer_i.run()
        self.ensemble_state_dict[i] = trained_model_i.state_dict()
        self.save_ensemble_metadict(trainer_i.meta_state_dict)
    # Load the average of all members' weights into the last trained network
    averaged_weights = average_weights(self.ensemble_state_dict)
    trainer_i.net.load_state_dict(averaged_weights)
    return self.ensemble_state_dict, trainer_i.net
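# train_from_baseline() relies on self.update_weights() to copy the baseline
# parameters into each freshly created trainer. That helper is not shown in
# this section; below is a minimal sketch of such an in-place copy -- an
# assumption about its behavior, inferred from the call site passing two
# state_dict().values() iterables (target first, source second).

from typing import Iterable

import torch


def update_weights_sketch(target_params: Iterable[torch.Tensor],
                          source_params: Iterable[torch.Tensor]) -> None:
    """Copies baseline tensors into a new network's parameters, in place.

    Sketch only: assumes both iterables yield tensors of identical shapes
    in the same order, i.e. state dicts of architecturally identical networks.
    """
    with torch.no_grad():
        for tgt, src in zip(target_params, source_params):
            tgt.copy_(src)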