def bootstrap(config_id):
    """Bootstrap an experiment run.

    Loads the config for *config_id*, wires up logging, records the start
    time and torch version, persists the config, and seeds the RNGs.

    Returns:
        The loaded config object.
    """
    config = get_config(config_id=config_id)
    print(config.log)
    set_logger(config)
    started_at = time.asctime(time.localtime(time.time()))
    write_message_logs("Starting Experiment at {}".format(started_at))
    write_message_logs("torch version = {}".format(torch.__version__))
    write_config_log(config)
    set_seed(seed=config.general.seed)
    return config
def load_model(self, optimizers):
    """Method to load the model and optimizer state from ``model.load_path``.

    Args:
        optimizers: optimizers whose state should be restored, in the same
            order they were saved by ``save_model``.

    Returns:
        Tuple ``(optimizers, epochs)``: the optimizers with restored state
        and the epoch count recorded in the checkpoint.
    """
    model_config = self.config.model
    path = model_config.load_path
    write_message_logs("Loading model from path {}".format(path))
    if self.config.device == "cuda":
        checkpoint = torch.load(path)
    else:
        # Remap tensors saved on GPU onto CPU storage when not on cuda.
        checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
    epochs = checkpoint["epochs"]
    self._load_metadata(checkpoint)
    self._load_model_params(checkpoint["state_dict"])
    # Restore optimizer state pairwise, in the order it was checkpointed.
    for optimizer, optimizer_state in zip(optimizers, checkpoint["optimizers"]):
        optimizer.load_state_dict(optimizer_state)
    return optimizers, epochs
def save_model(self, epochs=-1, optimizers=None, is_best_model=False):
    """Method to persist the model, optimizer state, and RNG states to disk.

    Args:
        epochs: index of the last completed epoch; stored as ``epochs + 1``.
        optimizers: optimizers whose state should be checkpointed. ``None``
            is treated as "no optimizers" rather than raising a TypeError.
        is_best_model: when True, write to the fixed ``best_model.tar`` path
            so the best checkpoint can be found without knowing the epoch.
    """
    model_config = self.config.model
    state = {
        "epochs": epochs + 1,
        "state_dict": self.state_dict(),
        # Guard the default: iterating None would raise a TypeError.
        "optimizers": [optimizer.state_dict()
                       for optimizer in (optimizers or [])],
        "np_random_state": np.random.get_state(),
        "python_random_state": random.getstate(),
        "pytorch_random_state": torch.get_rng_state()
    }
    if is_best_model:
        path = os.path.join(model_config.save_dir, "best_model.tar")
    else:
        # Bug fix: this file imports the module as ``time`` (bootstrap uses
        # time.asctime/time.localtime), so bare ``time()`` would raise
        # "TypeError: 'module' object is not callable" — use time.time().
        path = os.path.join(
            model_config.save_dir,
            "model_epoch_" + str(epochs + 1) +
            "_timestamp_" + str(int(time.time())) + ".tar")
    torch.save(state, path)
    write_message_logs("saved model to path = {}".format(path))
def _run_epochs(experiment):
    """Run the training loop for ``experiment`` with early stopping.

    Each iteration runs one epoch over all modes, steps the LR schedulers,
    optionally checkpoints the model (periodically and/or on best validation
    metric), and stops early when the early-stopping metric says so.
    """
    validation_metrics_dict = experiment.validation_metrics_dict
    metric_to_perform_early_stopping = experiment.metric_to_perform_early_stopping
    config = experiment.config
    # Reset all validation metrics before (re)starting the loop.
    for key in validation_metrics_dict:
        validation_metrics_dict[key].reset()
    while experiment.epoch_index < config.model.num_epochs:
        _run_one_epoch_all_modes(experiment)
        # Step every scheduler; "plateau" schedulers need the current value
        # of the early-stopping metric to decide whether to reduce the LR.
        for scheduler in experiment.schedulers:
            if (config.model.scheduler_type == "exp"):
                scheduler.step()
            elif (config.model.scheduler_type == "plateau"):
                scheduler.step(
                    validation_metrics_dict[metric_to_perform_early_stopping].
                    current_value)
        # Periodic checkpoint every `persist_per_epoch` epochs (when > 0).
        if (config.model.persist_per_epoch > 0):
            if (experiment.epoch_index % config.model.persist_per_epoch == 0):
                experiment.model.save_model(epochs=experiment.epoch_index,
                                            optimizers=experiment.optimizers)
        # Additionally checkpoint whenever the early-stopping metric hits a
        # new best, overwriting the dedicated "best model" file.
        if (config.model.persist_best_model):
            if (validation_metrics_dict[metric_to_perform_early_stopping].
                    is_best_so_far()):
                experiment.model.save_model(epochs=experiment.epoch_index,
                                            optimizers=experiment.optimizers,
                                            is_best_model=True)
        if (validation_metrics_dict[metric_to_perform_early_stopping].
                should_stop_early()):
            # Early stop: the best epoch is `time_span` epochs back from the
            # current one (the patience window of the metric).
            best_epoch_index = experiment.epoch_index - validation_metrics_dict[
                metric_to_perform_early_stopping].time_span
            write_metadata_logs(best_epoch_index=best_epoch_index)
            write_message_logs("Early stopping after running {} epochs".format(
                experiment.epoch_index))
            write_message_logs(
                "Best performing model corresponds to epoch id {}".format(
                    best_epoch_index))
            for key, value in validation_metrics_dict.items():
                write_message_logs(
                    "{} of the best performing model = {}".format(
                        key, value.get_best_so_far()))
            break
        experiment.epoch_index += 1
    else:
        # while/else: runs only when the loop exhausted num_epochs without
        # triggering the early-stopping break above.
        # NOTE(review): this branch offsets by `.counter` while the break
        # branch uses `.time_span`, and it still logs "Early stopping" even
        # though the full epoch budget was used — confirm both are intended.
        best_epoch_index = experiment.epoch_index - validation_metrics_dict[
            metric_to_perform_early_stopping].counter
        write_metadata_logs(best_epoch_index=best_epoch_index)
        write_message_logs("Early stopping after running {} epochs".format(
            experiment.epoch_index))
        write_message_logs(
            "Best performing model corresponds to epoch id {}".format(
                best_epoch_index))
        for key, value in validation_metrics_dict.items():
            write_message_logs("{} of the best performing model = {}".format(
                key, value.get_best_so_far()))
def get_model_params(self):
    """Return the list of trainable parameters, logging their total count.

    Returns:
        List of parameters for which ``requires_grad`` is True.
    """
    # Comprehension instead of list(filter(lambda ...)) — same result,
    # clearer and avoids the lambda call per parameter.
    model_parameters = [p for p in self.parameters() if p.requires_grad]
    # Generator expression: no need to materialize a throwaway list in sum().
    params = sum(np.prod(p.size()) for p in model_parameters)
    write_message_logs("Total number of params = " + str(params))
    return model_parameters