예제 #1
0
def test_parser():
    """Verify that the config parser loads the sample config file correctly."""
    name = "sample_config"
    cfg = get_config(name)
    set_logger(cfg)
    write_message_logs("torch version = {}".format(torch.__version__))
    # The loaded config must carry the id it was requested under.
    assert cfg.general.id == name
예제 #2
0
    def save(
        self,
        epoch: int,
        optimizers: Optional[List[torch.optim.Optimizer]],
        is_best_model: bool = False,
    ) -> None:
        """Persist the model weights, optimizer state and RNG state to disk.

        Writes two files under ``config.model.save_dir``:
        ``epoch_state.tar`` (bookkeeping of current/best epoch) and
        ``<model_name>_epoch_<epoch>.tar`` (the full checkpoint).

        Args:
            epoch: index of the epoch being checkpointed.
            optimizers: optimizers whose ``state_dict`` is saved alongside
                the model.
            is_best_model: when True, this epoch is also recorded as the
                best epoch in ``epoch_state.tar`` and flagged in the
                checkpoint metadata.

        Note: this method is not well tested.
        """

        model_config = self.config.model
        # The model name may live at the top level of the config or be
        # nested under a model-specific key.
        if len(self.model_config_key) == 0:
            model_name = model_config.name
        else:
            model_name = model_config[self.model_config_key]["name"]

        # Update the bookkeeping file that tracks the current/best epoch.
        # exist_ok avoids the check-then-create race of the original
        # `if not os.path.exists(...): os.makedirs(...)` pattern.
        os.makedirs(model_config.save_dir, exist_ok=True)
        epoch_state_path = os.path.join(model_config.save_dir,
                                        "epoch_state.tar")

        if os.path.exists(epoch_state_path):
            epoch_state = torch.load(epoch_state_path)
        else:
            # First checkpoint: the current epoch is trivially the best.
            epoch_state = {"best": epoch}
        epoch_state["current"] = epoch
        if is_best_model:
            epoch_state["best"] = epoch
        torch.save(epoch_state, epoch_state_path)

        state = {
            "metadata": {
                "epoch": epoch,
                # Record the flag directly instead of writing False and
                # patching it to True further down, as the original did.
                "is_best_model": is_best_model,
            },
            "model": {
                "weights": self.weights,
                "weight_names": self.weight_names
            },
            "optimizers": [{
                "state_dict": optimizer.state_dict()
            } for optimizer in optimizers],
            "random_state": {
                "np": np.random.get_state(),
                "python": random.getstate(),
                "pytorch": torch.get_rng_state(),
            },
        }
        path = os.path.join(model_config.save_dir,
                            "{}_epoch_{}.tar".format(model_name, epoch))
        torch.save(state, path)
        write_message_logs("saved experiment to path = {}".format(path))
예제 #3
0
 def get_model_params(self):
     """Log the number of trainable parameters and return the parameter list."""
     trainable = [p for p in self.parameters() if p.requires_grad]
     total = sum([np.prod(p.size()) for p in trainable])
     if total == 0:
         # Nothing registered as trainable; fall back to the raw weights.
         trainable = self.weights
         total = sum([np.prod(p.size()) for p in trainable])
     write_message_logs("Total number of params = " + str(total))
     return trainable
예제 #4
0
    def load(
        self,
        epoch: int,
        should_load_optimizers: bool = True,
        optimizers: Optional[List[optim.Optimizer]] = None,
        schedulers: Optional[List[optim.lr_scheduler.ReduceLROnPlateau]] = None,
    ):
        """Public method to load the model checkpoint for the given epoch.

        Bug fixed here: the original signature used ``optimizers=Optional[...]``
        and ``schedulers=Optional[...]`` — assigning the ``typing`` objects as
        *default values* rather than annotations. With those defaults the
        ``optimizers is None`` check below could never fire, and iterating
        ``schedulers`` would raise ``TypeError``. The defaults are now ``None``
        with proper annotations. The original ``-> None`` return annotation was
        also wrong (the method returns a tuple) and a duplicated
        ``model_config = self.config.model`` line was removed.

        Args:
            epoch: epoch whose checkpoint should be restored.
            should_load_optimizers: when True, also restore optimizer (and,
                if present in the checkpoint, scheduler) state.
            optimizers: optimizers to restore into; fetched via
                ``self.get_optimizers()`` when None.
            schedulers: schedulers to restore into; skipped when None.

        Returns:
            Tuple of ``(optimizers, schedulers)``.

        Raises:
            FileNotFoundError: if no checkpoint exists at the derived path.
        """
        model_config = self.config.model
        if len(self.model_config_key) == 0:
            model_name = model_config.name
        else:
            model_name = model_config[self.model_config_key]["name"]
        path = os.path.join(model_config.save_dir,
                            "{}_epoch_{}.tar".format(model_name, epoch))
        if not os.path.exists(path):
            raise FileNotFoundError("Loading path {} not found!".format(path))
        write_message_logs("Loading model from path {}".format(path))
        if str(self.config.general.device) == "cuda":
            checkpoint = torch.load(path)
        else:
            # Remap CUDA-saved tensors onto the CPU when no GPU is used.
            checkpoint = torch.load(path,
                                    map_location=lambda storage, loc: storage)
        load_random_state(checkpoint["random_state"])
        self.weights = checkpoint["model"]["weights"]
        self.weight_names = checkpoint["model"]["weight_names"]

        if should_load_optimizers:
            if optimizers is None:
                optimizers = self.get_optimizers()
            for optim_index, optimizer in enumerate(optimizers):
                optimizer.load_state_dict(
                    checkpoint["optimizers"][optim_index]["state_dict"])

            # Older checkpoints may not contain scheduler state; also skip
            # when the caller supplied no schedulers to restore into.
            key = "schedulers"
            if key in checkpoint and schedulers is not None:
                for scheduler_index, scheduler in enumerate(schedulers):
                    scheduler.load_state_dict(
                        checkpoint[key][scheduler_index]["state_dict"])
        return optimizers, schedulers
예제 #5
0
    def run(self):
        """Run the task: log startup info, build the experiment, execute it."""

        start_time = time.asctime(time.localtime(time.time()))
        write_message_logs("Starting Experiment at {}".format(start_time))
        write_config_log(self.config)
        write_message_logs("torch version = {}".format(torch.__version__))

        # Guard clause: the meta-learning path is not implemented.
        if self.config.general.is_meta:
            raise NotImplementedError("NA")

        self.train_data = self.initialize_data(mode="train")
        self.valid_data = self.initialize_data(mode="valid")
        self.test_data = self.initialize_data(mode="test")
        self.experiment = MultitaskExperiment(
            config=self.config,
            model=self.model,
            data=[self.train_data, self.valid_data, self.test_data],
            logbook=self.logbook,
        )
        self.experiment.load_model()
        self.experiment.run()
예제 #6
0
File: logbook.py  Project: vargeus/GraphLog
 def write_message_logs(self, message):
     """Forward *message* to the filesystem logger tied to this experiment."""
     experiment_id = self._experiment_id
     fs_log.write_message_logs(message, experiment_id=experiment_id)