class BaseTrainer(object):
    """Default class to train a model.

    Parameters
    ----------
    cfg : Config
        The run configuration.
    """

    def __init__(self, cfg: Config):
        super(BaseTrainer, self).__init__()
        self.cfg = cfg
        self.model = None
        self.optimizer = None
        self.loss_obj = None
        self.experiment_logger = None
        self.loader = None
        self.validator = None
        self.noise_sampler_y = None
        self._target_mean = None
        self._target_std = None
        self._scaler = {}
        self._allow_subsequent_nan_losses = cfg.allow_subsequent_nan_losses

        # load train basin list and add number of basins to the config
        self.basins = load_basin_file(cfg.train_basin_file)
        self.cfg.number_of_basins = len(self.basins)

        # check at which epoch the training starts
        self._epoch = self._get_start_epoch_number()

        self._create_folder_structure()
        setup_logging(str(self.cfg.run_dir / "output.log"))
        LOGGER.info(f"### Folder structure created at {self.cfg.run_dir}")

        if self.cfg.is_continue_training:
            LOGGER.info(f"### Continue training of run stored in {self.cfg.base_run_dir}")

        if self.cfg.is_finetuning:
            LOGGER.info(f"### Start finetuning with pretrained model stored in {self.cfg.base_run_dir}")

        LOGGER.info(f"### Run configurations for {self.cfg.experiment_name}")
        for key, val in self.cfg.as_dict().items():
            LOGGER.info(f"{key}: {val}")

        self._set_random_seeds()
        self._set_device()

    def _get_dataset(self) -> BaseDataset:
        return get_dataset(cfg=self.cfg, period="train", is_train=True, scaler=self._scaler)

    def _get_model(self) -> torch.nn.Module:
        return get_model(cfg=self.cfg)

    def _get_optimizer(self) -> torch.optim.Optimizer:
        return get_optimizer(model=self.model, cfg=self.cfg)

    def _get_loss_obj(self) -> loss.BaseLoss:
        return get_loss_obj(cfg=self.cfg)

    def _set_regularization(self):
        self.loss_obj.set_regularization_terms(get_regularization_obj(cfg=self.cfg))

    def _get_tester(self) -> BaseTester:
        return get_tester(cfg=self.cfg, run_dir=self.cfg.run_dir, period="validation", init_model=False)

    def _get_data_loader(self, ds: BaseDataset) -> torch.utils.data.DataLoader:
        return DataLoader(ds, batch_size=self.cfg.batch_size, shuffle=True, num_workers=self.cfg.num_workers)

    def _freeze_model_parts(self):
        # freeze all model weights
        for param in self.model.parameters():
            param.requires_grad = False

        unresolved_modules = []

        # unfreeze parameters specified in config as tuneable parameters
        if isinstance(self.cfg.finetune_modules, list):
            for module_part in self.cfg.finetune_modules:
                if module_part in self.model.module_parts:
                    module = getattr(self.model, module_part)
                    for param in module.parameters():
                        param.requires_grad = True
                else:
                    unresolved_modules.append(module_part)
        else:
            # if it is not a list, it has to be a dictionary mapping module groups to module parts
            for module_group, module_parts in self.cfg.finetune_modules.items():
                if module_group in self.model.module_parts:
                    if isinstance(module_parts, str):
                        module_parts = [module_parts]
                    for module_part in module_parts:
                        module = getattr(self.model, module_group)[module_part]
                        for param in module.parameters():
                            param.requires_grad = True
                else:
                    unresolved_modules.append(module_group)
        if unresolved_modules:
            LOGGER.warning(f"Could not resolve the following module parts for finetuning: {unresolved_modules}")

    def initialize_training(self):
        """Initialize the training class.

        This method will load the model, initialize loss, regularization, optimizer, dataset and dataloader,
        tensorboard logging, and Tester class.
        If called in a ``continue_training`` context, this method will also restore the model and optimizer state.
        """
        self.model = self._get_model().to(self.device)
        if self.cfg.checkpoint_path is not None:
            LOGGER.info(f"Starting training from checkpoint {self.cfg.checkpoint_path}")
            self.model.load_state_dict(torch.load(str(self.cfg.checkpoint_path), map_location=self.device))
        elif self.cfg.checkpoint_path is None and self.cfg.is_finetuning:
            # the default for finetuning is the last model state
            checkpoint_path = sorted(self.cfg.base_run_dir.glob('model_epoch*.pt'))[-1]
            LOGGER.info(f"Starting training from checkpoint {checkpoint_path}")
            self.model.load_state_dict(torch.load(str(checkpoint_path), map_location=self.device))

        # freeze model parts and load scaler from pre-trained model
        if self.cfg.is_finetuning:
            self._freeze_model_parts()
            with open(self.cfg.base_run_dir / "train_data" / "train_data_scaler.p", "rb") as fp:
                self._scaler = pickle.load(fp)

        self.optimizer = self._get_optimizer()
        self.loss_obj = self._get_loss_obj().to(self.device)

        # add possible regularization terms to the loss function
        self._set_regularization()

        # restore optimizer and model state if training is continued
        if self.cfg.is_continue_training:
            self._restore_training_state()

        ds = self._get_dataset()
        if len(ds) == 0:
            raise ValueError("Dataset contains no samples.")
        self.loader = self._get_data_loader(ds=ds)

        self.experiment_logger = Logger(cfg=self.cfg)
        if self.cfg.log_tensorboard:
            self.experiment_logger.start_tb()

        if self.cfg.is_continue_training:
            # set epoch and iteration step counter to continue from the selected checkpoint
            self.experiment_logger.epoch = self._epoch
            self.experiment_logger.update = len(self.loader) * self._epoch

        if self.cfg.validate_every is not None:
            if self.cfg.validate_n_random_basins < 1:
                warn_msg = [
                    f"Validation set to validate every {self.cfg.validate_every} epoch(s), but ",
                    "'validate_n_random_basins' not set or set to zero. Will validate on the entire validation set."
                ]
                LOGGER.warning("".join(warn_msg))
                self.cfg.validate_n_random_basins = self.cfg.number_of_basins
            self.validator = self._get_tester()

        if self.cfg.target_noise_std is not None:
            self.noise_sampler_y = torch.distributions.Normal(loc=0, scale=self.cfg.target_noise_std)
            self._target_mean = torch.from_numpy(
                ds.scaler["xarray_feature_center"][self.cfg.target_variables].to_array().values).to(self.device)
            self._target_std = torch.from_numpy(
                ds.scaler["xarray_feature_scale"][self.cfg.target_variables].to_array().values).to(self.device)

    def train_and_validate(self):
        """Train and validate the model.

        Train the model for the number of epochs specified in the run configuration, and perform validation after
        every ``validate_every`` epochs. Model and optimizer state are saved after every ``save_weights_every``
        epochs.
""" for epoch in range(self._epoch + 1, self._epoch + self.cfg.epochs + 1): if epoch in self.cfg.learning_rate.keys(): LOGGER.info( f"Setting learning rate to {self.cfg.learning_rate[epoch]}" ) for param_group in self.optimizer.param_groups: param_group["lr"] = self.cfg.learning_rate[epoch] self._train_epoch(epoch=epoch) avg_loss = self.experiment_logger.summarise() LOGGER.info(f"Epoch {epoch} average loss: {avg_loss}") if epoch % self.cfg.save_weights_every == 0: self._save_weights_and_optimizer(epoch) if (self.validator is not None) and (epoch % self.cfg.validate_every == 0): self.validator.evaluate( epoch=epoch, save_results=self.cfg.save_validation_results, metrics=self.cfg.metrics, model=self.model, experiment_logger=self.experiment_logger.valid()) valid_metrics = self.experiment_logger.summarise() if valid_metrics: print_msg = f" -- Median validation metrics:" print_msg += ", ".join( f"{key}: {val:.5f}" for key, val in valid_metrics.items()) LOGGER.info(print_msg) # make sure to close tensorboard to avoid losing the last epoch if self.cfg.log_tensorboard: self.experiment_logger.stop_tb() def _get_start_epoch_number(self): if self.cfg.is_continue_training: if self.cfg.continue_from_epoch is not None: epoch = self.cfg.continue_from_epoch else: weight_path = [ x for x in sorted( list(self.cfg.run_dir.glob('model_epoch*.pt'))) ][-1] epoch = weight_path.name[-6:-3] else: epoch = 0 return int(epoch) def _restore_training_state(self): if self.cfg.continue_from_epoch is not None: epoch = f"{self.cfg.continue_from_epoch:03d}" weight_path = self.cfg.base_run_dir / f"model_epoch{epoch}.pt" else: weight_path = [ x for x in sorted( list(self.cfg.base_run_dir.glob('model_epoch*.pt'))) ][-1] epoch = weight_path.name[-6:-3] optimizer_path = self.cfg.base_run_dir / f"optimizer_state_epoch{epoch}.pt" LOGGER.info(f"Continue training from epoch {int(epoch)}") self.model.load_state_dict( torch.load(weight_path, map_location=self.device)) self.optimizer.load_state_dict( torch.load(str(optimizer_path), map_location=self.device)) def _save_weights_and_optimizer(self, epoch: int): weight_path = self.cfg.run_dir / f"model_epoch{epoch:03d}.pt" torch.save(self.model.state_dict(), str(weight_path)) optimizer_path = self.cfg.run_dir / f"optimizer_state_epoch{epoch:03d}.pt" torch.save(self.optimizer.state_dict(), str(optimizer_path)) def _train_epoch(self, epoch: int): self.model.train() self.experiment_logger.train() # process bar handle pbar = tqdm(self.loader, file=sys.stdout) pbar.set_description(f'# Epoch {epoch}') # Iterate in batches over training set nan_count = 0 for data in pbar: for key in data.keys(): data[key] = data[key].to(self.device) # apply possible subclass pre-processing data = self._pre_model_hook(data) # get predictions predictions = self.model(data) if self.noise_sampler_y is not None: for key in filter(lambda k: 'y' in k, data.keys()): noise = self.noise_sampler_y.sample(data[key].shape) # make sure we add near-zero noise to originally near-zero targets data[key] += (data[key] + self._target_mean / self._target_std) * noise.to(self.device) loss = self.loss_obj(predictions, data) # early stop training if loss is NaN if torch.isnan(loss): nan_count += 1 if nan_count > self._allow_subsequent_nan_losses: raise RuntimeError( f"Loss was NaN for {nan_count} times in a row. Stopped training." ) LOGGER.warning( f"Loss is Nan; ignoring step. 
(#{nan_count}/{self._allow_subsequent_nan_losses})" ) else: nan_count = 0 # delete old gradients self.optimizer.zero_grad() # get gradients loss.backward() if self.cfg.clip_gradient_norm is not None: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.clip_gradient_norm) # update weights self.optimizer.step() pbar.set_postfix_str(f"Loss: {loss.item():.4f}") self.experiment_logger.log_step(loss=loss.item()) def _set_random_seeds(self): if self.cfg.seed is None: self.cfg.seed = int(np.random.uniform(low=0, high=1e6)) # fix random seeds for various packages random.seed(self.cfg.seed) np.random.seed(self.cfg.seed) torch.cuda.manual_seed(self.cfg.seed) torch.manual_seed(self.cfg.seed) def _set_device(self): if self.cfg.device is not None: if self.cfg.device.startswith("cuda"): gpu_id = int(self.cfg.device.split(':')[-1]) if gpu_id > torch.cuda.device_count(): raise RuntimeError( f"This machine does not have GPU #{gpu_id} ") else: self.device = torch.device(self.cfg.device) else: self.device = torch.device("cpu") else: self.device = torch.device( "cuda:0" if torch.cuda.is_available() else "cpu") LOGGER.info(f"### Device {self.device} will be used for training") def _create_folder_structure(self): # create as subdirectory within run directory of base run if self.cfg.is_continue_training: folder_name = f"continue_training_from_epoch{self._epoch:03d}" # store dir of base run for easier access in weight loading self.cfg.base_run_dir = self.cfg.run_dir self.cfg.run_dir = self.cfg.run_dir / folder_name # create as new folder structure else: now = datetime.now() day = f"{now.day}".zfill(2) month = f"{now.month}".zfill(2) hour = f"{now.hour}".zfill(2) minute = f"{now.minute}".zfill(2) second = f"{now.second}".zfill(2) run_name = f'{self.cfg.experiment_name}_{day}{month}_{hour}{minute}{second}' # if no directory for the runs is specified, a 'runs' folder will be created in the current working dir if self.cfg.run_dir is None: self.cfg.run_dir = Path().cwd() / "runs" / run_name else: self.cfg.run_dir = self.cfg.run_dir / run_name # create folder + necessary subfolder if not self.cfg.run_dir.is_dir(): self.cfg.train_dir = self.cfg.run_dir / "train_data" self.cfg.train_dir.mkdir(parents=True) else: raise RuntimeError( f"There is already a folder at {self.cfg.run_dir}") if self.cfg.log_n_figures is not None: self.cfg.img_log_dir = self.cfg.run_dir / "img_log" self.cfg.img_log_dir.mkdir(parents=True) def _pre_model_hook( self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: return data
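

# Minimal usage sketch (an illustration, not part of the original module): shows how the trainer above is
# typically driven end to end. The configuration file name is hypothetical; ``Config`` is assumed to accept
# the path to a YAML run configuration, as suggested by the type hints used in this module.
def _example_training_run():
    cfg = Config(Path("my_run_config.yml"))  # hypothetical run-configuration file
    trainer = BaseTrainer(cfg=cfg)

    # build model, optimizer, loss, data loader, and logger, then run the training loop
    trainer.initialize_training()
    trainer.train_and_validate()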
# The ``evaluate`` method below belongs to the tester class created via ``get_tester`` (``BaseTester``),
# not to ``BaseTrainer``; it implements the validation/evaluation loop that the trainer triggers every
# ``validate_every`` epochs.
def evaluate(self,
             epoch: int = None,
             save_results: bool = True,
             metrics: Union[list, dict] = [],
             model: torch.nn.Module = None,
             experiment_logger: Logger = None) -> dict:
    """Evaluate the model.

    Parameters
    ----------
    epoch : int, optional
        Define a specific epoch to evaluate. By default, the weights of the last epoch are used.
    save_results : bool, optional
        If True, stores the evaluation results in the run directory. By default, True.
    metrics : Union[list, dict], optional
        List of metrics to compute during evaluation. Can also be a dict that specifies per-target metrics.
    model : torch.nn.Module, optional
        If a model is passed, this is used for validation.
    experiment_logger : Logger, optional
        Logger can be passed during training to log metrics.

    Returns
    -------
    dict
        A dictionary containing one xarray per basin with the evaluation results.
    """
    if model is None:
        if self.init_model:
            self._load_weights(epoch=epoch)
            model = self.model
        else:
            raise RuntimeError("No model was initialized for the evaluation")

    # during validation, depending on settings, only evaluate on a random subset of basins
    basins = self.basins
    if self.period == "validation":
        if len(basins) > self.cfg.validate_n_random_basins:
            random.shuffle(basins)
            basins = basins[:self.cfg.validate_n_random_basins]

    # force model to train-mode when doing mc-dropout evaluation
    if self.cfg.mc_dropout:
        model.train()
    else:
        model.eval()

    results = defaultdict(dict)

    pbar = tqdm(basins, file=sys.stdout)
    pbar.set_description('# Validation' if self.period == "validation" else "# Evaluation")
    for basin in pbar:

        if self.cfg.cache_validation_data and basin in self.cached_datasets.keys():
            ds = self.cached_datasets[basin]
        else:
            try:
                ds = self._get_dataset(basin)
            except NoTrainDataError:
                # skip basin
                continue
            if self.cfg.cache_validation_data and self.period == "validation":
                # cache datasets in memory to avoid re-creating them in the next validation epoch
                self.cached_datasets[basin] = ds

        loader = DataLoader(ds, batch_size=self.cfg.batch_size, num_workers=0)

        y_hat, y = self._evaluate(model, loader, ds.frequencies)

        predict_last_n = self.cfg.predict_last_n
        seq_length = self.cfg.seq_length
        # if predict_last_n/seq_length are int, there's only one frequency
        if isinstance(predict_last_n, int):
            predict_last_n = {ds.frequencies[0]: predict_last_n}
        if isinstance(seq_length, int):
            seq_length = {ds.frequencies[0]: seq_length}
        lowest_freq = sort_frequencies(ds.frequencies)[0]
        for freq in ds.frequencies:
            if predict_last_n[freq] == 0:
                continue  # this frequency is not being predicted
            results[basin][freq] = {}

            # rescale predictions
            y_hat_freq = \
                y_hat[freq] * self.scaler["xarray_feature_scale"][self.cfg.target_variables].to_array().values \
                + self.scaler["xarray_feature_center"][self.cfg.target_variables].to_array().values
            y_freq = y[freq] * self.scaler["xarray_feature_scale"][self.cfg.target_variables].to_array().values \
                + self.scaler["xarray_feature_center"][self.cfg.target_variables].to_array().values

            # create xarray
            data = self._create_xarray(y_hat_freq, y_freq)

            # get maximum warmup-offset across all frequencies
            offsets = {
                freq: (seq_length[freq] - predict_last_n[freq]) * pd.to_timedelta(freq)
                for freq in ds.frequencies
            }
            max_offset_freq = max(offsets, key=offsets.get)
            start_date = ds.get_period_start(basin) + offsets[max_offset_freq]

            # determine the end of the first sequence (first target in sequence-to-one)
            # we use the end_date stored in the dataset, which also covers issues with per-basin different periods
            end_date = ds.dates[basin]["end_dates"][0] \
                + pd.Timedelta(days=1, seconds=-1) \
                - pd.to_timedelta(max_offset_freq) * (predict_last_n[max_offset_freq] - 1)
            date_range = pd.date_range(start=start_date, end=end_date, freq=lowest_freq)
            if len(date_range) != data[f"{self.cfg.target_variables[0]}_obs"][1].shape[0]:
                raise ValueError("Evaluation date range does not match generated predictions.")

            frequency_factor = pd.to_timedelta(lowest_freq) // pd.to_timedelta(freq)
            freq_range = pd.timedelta_range(end=(frequency_factor - 1) * pd.to_timedelta(freq),
                                            periods=predict_last_n[freq],
                                            freq=freq)

            xr = xarray.Dataset(data_vars=data, coords={'date': date_range, 'time_step': freq_range})
            results[basin][freq]['xr'] = xr

            # only warn once per freq
            if frequency_factor < predict_last_n[freq] and basin == basins[0]:
                tqdm.write(f'Metrics for {freq} are calculated over last {frequency_factor} elements only. '
                           f'Ignoring {predict_last_n[freq] - frequency_factor} predictions per sequence.')

            if metrics:
                for target_variable in self.cfg.target_variables:
                    # stack dates and time_steps so we don't just evaluate every 24H when use_frequencies=[1D, 1H]
                    obs = xr.isel(time_step=slice(-frequency_factor, None)) \
                        .stack(datetime=['date', 'time_step'])[f"{target_variable}_obs"]
                    obs['datetime'] = obs.coords['date'] + obs.coords['time_step']
                    # check if there are observations for this period
                    if not all(obs.isnull()):
                        sim = xr.isel(time_step=slice(-frequency_factor, None)) \
                            .stack(datetime=['date', 'time_step'])[f"{target_variable}_sim"]
                        sim['datetime'] = sim.coords['date'] + sim.coords['time_step']

                        # clip negative predictions to zero, if variable is listed in config 'clip_targets_to_zero'
                        if target_variable in self.cfg.clip_targets_to_zero:
                            sim = xarray.where(sim < 0, 0, sim)

                        if 'samples' in sim.dims:
                            sim = sim.mean(dim='samples')

                        var_metrics = metrics if isinstance(metrics, list) else metrics[target_variable]
                        if 'all' in var_metrics:
                            var_metrics = get_available_metrics()
                        try:
                            values = calculate_metrics(obs, sim, metrics=var_metrics, resolution=freq)
                        except AllNaNError as err:
                            msg = f'Basin {basin} ' \
                                + (f'{target_variable} ' if len(self.cfg.target_variables) > 1 else '') \
                                + (f'{freq} ' if len(ds.frequencies) > 1 else '') \
                                + str(err)
                            LOGGER.warning(msg)
                            values = {metric: np.nan for metric in var_metrics}

                        # add variable identifier to metrics if needed
                        if len(self.cfg.target_variables) > 1:
                            values = {f"{target_variable}_{key}": val for key, val in values.items()}
                        # add frequency identifier to metrics if needed
                        if len(ds.frequencies) > 1:
                            values = {f"{key}_{freq}": val for key, val in values.items()}
                        if experiment_logger is not None:
                            experiment_logger.log_step(**values)
                        for k, v in values.items():
                            results[basin][freq][k] = v

    if (self.period == "validation") and (self.cfg.log_n_figures > 0) and (experiment_logger is not None):
        self._create_and_log_figures(results, experiment_logger, epoch)

    if save_results:
        self._save_results(results, epoch)

    return results
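

# Minimal usage sketch (an illustration, not part of the original module): shows how ``evaluate`` is typically
# reached through the tester object returned by ``get_tester`` and how its nested result dictionary is laid out.
# The configuration file, run directory, period value, and the "NSE" metric name are assumptions.
def _example_evaluation_run():
    tester = get_tester(cfg=Config(Path("my_run_config.yml")),  # hypothetical run-configuration file
                        run_dir=Path("runs/my_run"),            # hypothetical run directory
                        period="test",
                        init_model=True)
    results = tester.evaluate(save_results=False, metrics=["NSE"])

    # results are keyed by basin and frequency; each entry holds an xarray.Dataset under 'xr' plus one scalar
    # per computed metric (metric keys may carry variable/frequency suffixes when several are evaluated)
    for basin, freqs in results.items():
        for freq, values in freqs.items():
            print(basin, freq, values["xr"], values.get("NSE"))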