def test_parser():
    """Method to test if the config parser can load the config file correctly"""
    config_name = "sample_config"
    config = get_config(config_name)
    set_logger(config)
    write_message_logs("torch version = {}".format(torch.__version__))
    assert config.general.id == config_name
def save(
    self,
    epoch: int,
    optimizers: Optional[List[torch.optim.Optimizer]],
    is_best_model: bool = False,
) -> None:
    """Method to persist the model. Note this method is not well tested"""
    model_config = self.config.model
    if len(self.model_config_key) == 0:
        model_name = model_config.name
    else:
        model_name = model_config[self.model_config_key]["name"]
    # Update the information about the epoch.
    # Check if the epoch_state is already saved on the file system.
    if not os.path.exists(model_config.save_dir):
        os.makedirs(model_config.save_dir)
    epoch_state_path = os.path.join(model_config.save_dir, "epoch_state.tar")
    if os.path.exists(epoch_state_path):
        epoch_state = torch.load(epoch_state_path)
    else:
        epoch_state = {"best": epoch}
    epoch_state["current"] = epoch
    if is_best_model:
        epoch_state["best"] = epoch
    torch.save(epoch_state, epoch_state_path)
    state = {
        "metadata": {
            "epoch": epoch,
            "is_best_model": is_best_model,
        },
        "model": {
            "weights": self.weights,
            "weight_names": self.weight_names,
        },
        "optimizers": [
            # Guard against the Optional: skip optimizer state when None.
            {"state_dict": optimizer.state_dict()}
            for optimizer in (optimizers if optimizers is not None else [])
        ],
        "random_state": {
            "np": np.random.get_state(),
            "python": random.getstate(),
            "pytorch": torch.get_rng_state(),
        },
    }
    path = os.path.join(
        model_config.save_dir, "{}_epoch_{}.tar".format(model_name, epoch)
    )
    torch.save(state, path)
    write_message_logs("saved experiment to path = {}".format(path))
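# A minimal sketch of how save() might be called from a training loop.
# The names `model`, `num_epochs`, `train_one_epoch`, and `validate` are
# hypothetical and only illustrate the call pattern, not this repository's
# actual training loop:
#
#     optimizer = torch.optim.Adam(model.get_model_params())
#     best_loss = float("inf")
#     for epoch in range(num_epochs):
#         train_one_epoch(model, optimizer)
#         val_loss = validate(model)
#         is_best = val_loss < best_loss
#         best_loss = min(best_loss, val_loss)
#         model.save(epoch, optimizers=[optimizer], is_best_model=is_best)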
def get_model_params(self):
    """Method to get the model params"""
    model_parameters = list(filter(lambda p: p.requires_grad, self.parameters()))
    params = sum([np.prod(p.size()) for p in model_parameters])
    if params == 0:
        # No trainable parameters are registered on the module directly,
        # so fall back to counting the explicitly tracked weights.
        model_parameters = self.weights
        params = sum([np.prod(p.size()) for p in model_parameters])
    write_message_logs("Total number of params = " + str(params))
    return model_parameters
def load(
    self,
    epoch: int,
    should_load_optimizers: bool = True,
    optimizers: Optional[List[optim.Optimizer]] = None,
    schedulers: Optional[List[optim.lr_scheduler.ReduceLROnPlateau]] = None,
) -> Tuple[
    Optional[List[optim.Optimizer]],
    Optional[List[optim.lr_scheduler.ReduceLROnPlateau]],
]:
    """Public method to load the model"""
    model_config = self.config.model
    if len(self.model_config_key) == 0:
        model_name = model_config.name
    else:
        model_name = model_config[self.model_config_key]["name"]
    path = os.path.join(
        model_config.save_dir, "{}_epoch_{}.tar".format(model_name, epoch)
    )
    if not os.path.exists(path):
        raise FileNotFoundError("Loading path {} not found!".format(path))
    write_message_logs("Loading model from path {}".format(path))
    if str(self.config.general.device) == "cuda":
        checkpoint = torch.load(path)
    else:
        # Map CUDA tensors onto the CPU when no GPU is available.
        checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
    load_random_state(checkpoint["random_state"])
    self.weights = checkpoint["model"]["weights"]
    self.weight_names = checkpoint["model"]["weight_names"]
    if should_load_optimizers:
        if optimizers is None:
            optimizers = self.get_optimizers()
        for optim_index, optimizer in enumerate(optimizers):
            optimizer.load_state_dict(
                checkpoint["optimizers"][optim_index]["state_dict"]
            )
        key = "schedulers"
        if key in checkpoint and schedulers is not None:
            for scheduler_index, scheduler in enumerate(schedulers):
                scheduler.load_state_dict(
                    checkpoint[key][scheduler_index]["state_dict"]
                )
    return optimizers, schedulers
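# A minimal sketch of resuming training from a checkpoint written by save()
# above. `model` and `start_epoch` are hypothetical names used only for
# illustration; the restored optimizers, schedulers, and RNG state let the
# loop continue from `start_epoch + 1`:
#
#     optimizers, schedulers = model.load(
#         epoch=start_epoch, should_load_optimizers=True
#     )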
def run(self):
    """Method to run the task"""
    write_message_logs(
        "Starting Experiment at {}".format(
            time.asctime(time.localtime(time.time()))
        )
    )
    write_config_log(self.config)
    write_message_logs("torch version = {}".format(torch.__version__))
    if not self.config.general.is_meta:
        self.train_data = self.initialize_data(mode="train")
        self.valid_data = self.initialize_data(mode="valid")
        self.test_data = self.initialize_data(mode="test")
        self.experiment = MultitaskExperiment(
            config=self.config,
            model=self.model,
            data=[self.train_data, self.valid_data, self.test_data],
            logbook=self.logbook,
        )
    else:
        raise NotImplementedError("NA")
    self.experiment.load_model()
    self.experiment.run()
def write_message_logs(self, message):
    """Write message logs"""
    fs_log.write_message_logs(message, experiment_id=self._experiment_id)