# Example #1
def bootstrap(config_id):
    """Set up an experiment run: load the config, configure logging and
    seeding, and log some startup diagnostics.

    Returns the loaded configuration object so the caller can drive the run.
    """
    config_dict = get_config(config_id=config_id)
    print(config_dict.log)
    set_logger(config_dict)
    start_stamp = time.asctime(time.localtime(time.time()))
    write_message_logs("Starting Experiment at {}".format(start_stamp))
    write_message_logs("torch version = {}".format(torch.__version__))
    write_config_log(config_dict)
    set_seed(seed=config_dict.general.seed)
    return config_dict
 def load_model(self, optimizers):
     """Restore model weights, metadata and optimizer states from the
     checkpoint at ``config.model.load_path``.

     Returns the (updated) optimizers together with the epoch count that
     was stored in the checkpoint.
     """
     path = self.config.model.load_path
     write_message_logs("Loading model from path {}".format(path))
     if self.config.device == "cuda":
         checkpoint = torch.load(path)
     else:
         # Remap any CUDA tensors onto the CPU when no GPU is configured.
         checkpoint = torch.load(
             path, map_location=lambda storage, loc: storage)
     epochs = checkpoint["epochs"]
     self._load_metadata(checkpoint)
     self._load_model_params(checkpoint["state_dict"])
     for index, optimizer in enumerate(optimizers):
         optimizer.load_state_dict(checkpoint["optimizers"][index])
     return optimizers, epochs
 def save_model(self, epochs=-1, optimizers=None, is_best_model=False):
     """Persist the model, optimizer states and RNG states to disk.

     Args:
         epochs: index of the last completed epoch; the checkpoint stores
             ``epochs + 1`` so a resumed run starts at the next epoch.
         optimizers: iterable of optimizers whose state should be saved.
             ``None`` (the default) is treated as "no optimizers".
         is_best_model: when True, write to the fixed ``best_model.tar``
             path instead of an epoch/timestamp-stamped file.
     """
     model_config = self.config.model
     state = {
         "epochs": epochs + 1,
         "state_dict": self.state_dict(),
         # Guard the default optimizers=None: iterating None would raise
         # TypeError before anything was saved.
         "optimizers": [optimizer.state_dict()
                        for optimizer in (optimizers or [])],
         # Capture all RNG states so a resumed run is reproducible.
         "np_random_state": np.random.get_state(),
         "python_random_state": random.getstate(),
         "pytorch_random_state": torch.get_rng_state()
     }
     if is_best_model:
         path = os.path.join(model_config.save_dir, "best_model.tar")
     else:
         # This file uses the `time` module elsewhere (bootstrap calls
         # time.time()/time.asctime), so the original bare time() call
         # would raise TypeError: 'module' object is not callable.
         path = os.path.join(
             model_config.save_dir, "model_epoch_" + str(epochs + 1) +
             "_timestamp_" + str(int(time.time())) + ".tar")
     torch.save(state, path)
     write_message_logs("saved model to path = {}".format(path))
def _run_epochs(experiment):
    """Main training loop: run epochs until the configured budget is
    exhausted or the early-stopping criterion on the chosen validation
    metric fires, checkpointing and logging along the way.
    """
    validation_metrics_dict = experiment.validation_metrics_dict
    metric_to_perform_early_stopping = experiment.metric_to_perform_early_stopping
    config = experiment.config
    # Start from a clean slate: clear any metric state from a previous run.
    for key in validation_metrics_dict:
        validation_metrics_dict[key].reset()
    while experiment.epoch_index < config.model.num_epochs:
        _run_one_epoch_all_modes(experiment)
        # Step every LR scheduler; "plateau" schedulers additionally need
        # the current value of the early-stopping metric to decide whether
        # to decay the learning rate.
        for scheduler in experiment.schedulers:
            if (config.model.scheduler_type == "exp"):
                scheduler.step()
            elif (config.model.scheduler_type == "plateau"):
                scheduler.step(
                    validation_metrics_dict[metric_to_perform_early_stopping].
                    current_value)
        # Periodic checkpoint every `persist_per_epoch` epochs (disabled
        # when the setting is <= 0).
        if (config.model.persist_per_epoch > 0):
            if (experiment.epoch_index % config.model.persist_per_epoch == 0):
                experiment.model.save_model(epochs=experiment.epoch_index,
                                            optimizers=experiment.optimizers)

        # Separately persist a "best so far" checkpoint whenever the
        # tracked metric improves.
        if (config.model.persist_best_model):
            if (validation_metrics_dict[metric_to_perform_early_stopping].
                    is_best_so_far()):
                experiment.model.save_model(epochs=experiment.epoch_index,
                                            optimizers=experiment.optimizers,
                                            is_best_model=True)

        if (validation_metrics_dict[metric_to_perform_early_stopping].
                should_stop_early()):
            # Best epoch is `time_span` epochs before the stop point —
            # presumably the metric's patience window; confirm against the
            # metric class.
            best_epoch_index = experiment.epoch_index - validation_metrics_dict[
                metric_to_perform_early_stopping].time_span
            write_metadata_logs(best_epoch_index=best_epoch_index)
            write_message_logs("Early stopping after running {} epochs".format(
                experiment.epoch_index))
            write_message_logs(
                "Best performing model corresponds to epoch id {}".format(
                    best_epoch_index))
            for key, value in validation_metrics_dict.items():
                write_message_logs(
                    "{} of the best performing model = {}".format(
                        key, value.get_best_so_far()))
            break
        experiment.epoch_index += 1
    else:
        # while/else: this branch runs only when the loop exhausts the
        # epoch budget WITHOUT breaking, i.e. no early stop happened.
        # NOTE(review): the "Early stopping" message below is misleading in
        # this branch, and `.counter` is used here where the break path
        # uses `.time_span` — looks like a copy/paste divergence between
        # the two exit paths; confirm which attribute is intended.
        best_epoch_index = experiment.epoch_index - validation_metrics_dict[
            metric_to_perform_early_stopping].counter
        write_metadata_logs(best_epoch_index=best_epoch_index)
        write_message_logs("Early stopping after running {} epochs".format(
            experiment.epoch_index))
        write_message_logs(
            "Best performing model corresponds to epoch id {}".format(
                best_epoch_index))
        for key, value in validation_metrics_dict.items():
            write_message_logs("{} of the best performing model = {}".format(
                key, value.get_best_so_far()))
 def get_model_params(self):
     """Collect the trainable parameters of the model and log how many
     scalar weights they contain.

     Returns the list of parameters with ``requires_grad`` set.
     """
     trainable = [p for p in self.parameters() if p.requires_grad]
     total = sum(np.prod(p.size()) for p in trainable)
     write_message_logs("Total number of params = " + str(total))
     return trainable