Example #1
class BaseTrainer(object):
    """Default class to train a model.

    Parameters
    ----------
    cfg : Config
        The run configuration.
    """
    def __init__(self, cfg: Config):
        super(BaseTrainer, self).__init__()
        self.cfg = cfg
        self.model = None
        self.optimizer = None
        self.loss_obj = None
        self.experiment_logger = None
        self.loader = None
        self.validator = None
        self.noise_sampler_y = None
        self._target_mean = None
        self._target_std = None
        self._scaler = {}
        self._allow_subsequent_nan_losses = cfg.allow_subsequent_nan_losses

        # load train basin list and add number of basins to the config
        self.basins = load_basin_file(cfg.train_basin_file)
        self.cfg.number_of_basins = len(self.basins)

        # check at which epoch the training starts
        self._epoch = self._get_start_epoch_number()

        self._create_folder_structure()
        setup_logging(str(self.cfg.run_dir / "output.log"))
        LOGGER.info(f"### Folder structure created at {self.cfg.run_dir}")

        if self.cfg.is_continue_training:
            LOGGER.info(
                f"### Continue training of run stored in {self.cfg.base_run_dir}"
            )

        if self.cfg.is_finetuning:
            LOGGER.info(
                f"### Start finetuning with pretrained model stored in {self.cfg.base_run_dir}"
            )

        LOGGER.info(f"### Run configurations for {self.cfg.experiment_name}")
        for key, val in self.cfg.as_dict().items():
            LOGGER.info(f"{key}: {val}")

        self._set_random_seeds()
        self._set_device()

    def _get_dataset(self) -> BaseDataset:
        return get_dataset(cfg=self.cfg,
                           period="train",
                           is_train=True,
                           scaler=self._scaler)

    def _get_model(self) -> torch.nn.Module:
        return get_model(cfg=self.cfg)

    def _get_optimizer(self) -> torch.optim.Optimizer:
        return get_optimizer(model=self.model, cfg=self.cfg)

    def _get_loss_obj(self) -> loss.BaseLoss:
        return get_loss_obj(cfg=self.cfg)

    def _set_regularization(self):
        self.loss_obj.set_regularization_terms(
            get_regularization_obj(cfg=self.cfg))

    def _get_tester(self) -> BaseTester:
        return get_tester(cfg=self.cfg,
                          run_dir=self.cfg.run_dir,
                          period="validation",
                          init_model=False)

    def _get_data_loader(self, ds: BaseDataset) -> torch.utils.data.DataLoader:
        return DataLoader(ds,
                          batch_size=self.cfg.batch_size,
                          shuffle=True,
                          num_workers=self.cfg.num_workers)

    def _freeze_model_parts(self):
        # freeze all model weights
        for param in self.model.parameters():
            param.requires_grad = False

        unresolved_modules = []

        # unfreeze parameters specified in config as tuneable parameters
        if isinstance(self.cfg.finetune_modules, list):
            for module_part in self.cfg.finetune_modules:
                if module_part in self.model.module_parts:
                    module = getattr(self.model, module_part)
                    for param in module.parameters():
                        param.requires_grad = True
                else:
                    unresolved_modules.append(module_part)
        else:
            # if it is not a list, it must be a dictionary
            for module_group, module_parts in self.cfg.finetune_modules.items():
                if module_group in self.model.module_parts:
                    if isinstance(module_parts, str):
                        module_parts = [module_parts]
                    for module_part in module_parts:
                        module = getattr(self.model, module_group)[module_part]
                        for param in module.parameters():
                            param.requires_grad = True
                else:
                    unresolved_modules.append(module_group)
        if unresolved_modules:
            LOGGER.warning(
                f"Could not resolve the following module parts for finetuning: {unresolved_modules}"
            )

    def initialize_training(self):
        """Initialize the training class.

        This method will load the model, initialize loss, regularization, optimizer, dataset and dataloader,
        tensorboard logging, and Tester class.
        If called in a ``continue_training`` context, this method will also restore the model and optimizer state.
        """
        self.model = self._get_model().to(self.device)
        if self.cfg.checkpoint_path is not None:
            LOGGER.info(
                f"Starting training from Checkpoint {self.cfg.checkpoint_path}"
            )
            self.model.load_state_dict(
                torch.load(str(self.cfg.checkpoint_path),
                           map_location=self.device))
        elif self.cfg.is_finetuning:
            # the default for finetuning is the last model state
            checkpoint_path = sorted(
                self.cfg.base_run_dir.glob('model_epoch*.pt'))[-1]
            LOGGER.info(f"Starting training from checkpoint {checkpoint_path}")
            self.model.load_state_dict(
                torch.load(str(checkpoint_path), map_location=self.device))

        # freeze model parts and load scaler from pre-trained model
        if self.cfg.is_finetuning:
            self._freeze_model_parts()

            with open(
                    self.cfg.base_run_dir / "train_data" /
                    "train_data_scaler.p", "rb") as fp:
                self._scaler = pickle.load(fp)

        self.optimizer = self._get_optimizer()
        self.loss_obj = self._get_loss_obj().to(self.device)

        # Add possible regularization terms to the loss function.
        self._set_regularization()

        # restore optimizer and model state if training is continued
        if self.cfg.is_continue_training:
            self._restore_training_state()

        ds = self._get_dataset()
        if len(ds) == 0:
            raise ValueError("Dataset contains no samples.")
        self.loader = self._get_data_loader(ds=ds)

        self.experiment_logger = Logger(cfg=self.cfg)
        if self.cfg.log_tensorboard:
            self.experiment_logger.start_tb()

        if self.cfg.is_continue_training:
            # set epoch and iteration step counter to continue from the selected checkpoint
            self.experiment_logger.epoch = self._epoch
            self.experiment_logger.update = len(self.loader) * self._epoch

        if self.cfg.validate_every is not None:
            if self.cfg.validate_n_random_basins < 1:
                warn_msg = [
                    f"Validation set to validate every {self.cfg.validate_every} epoch(s), but ",
                    "'validate_n_random_basins' not set or set to zero. Will validate on the entire validation set."
                ]
                LOGGER.warning("".join(warn_msg))
                self.cfg.validate_n_random_basins = self.cfg.number_of_basins
            self.validator = self._get_tester()

        if self.cfg.target_noise_std is not None:
            self.noise_sampler_y = torch.distributions.Normal(
                loc=0, scale=self.cfg.target_noise_std)
            self._target_mean = torch.from_numpy(
                ds.scaler["xarray_feature_center"]
                [self.cfg.target_variables].to_array().values).to(self.device)
            self._target_std = torch.from_numpy(
                ds.scaler["xarray_feature_scale"]
                [self.cfg.target_variables].to_array().values).to(self.device)

    def train_and_validate(self):
        """Train and validate the model.

        Train the model for the number of epochs specified in the run configuration, and perform validation every
        ``validate_every`` epochs. Model and optimizer state are saved every ``save_weights_every`` epochs.
        """
        for epoch in range(self._epoch + 1, self._epoch + self.cfg.epochs + 1):
            if epoch in self.cfg.learning_rate:
                LOGGER.info(
                    f"Setting learning rate to {self.cfg.learning_rate[epoch]}"
                )
                for param_group in self.optimizer.param_groups:
                    param_group["lr"] = self.cfg.learning_rate[epoch]

            self._train_epoch(epoch=epoch)
            avg_loss = self.experiment_logger.summarise()
            LOGGER.info(f"Epoch {epoch} average loss: {avg_loss}")

            if epoch % self.cfg.save_weights_every == 0:
                self._save_weights_and_optimizer(epoch)

            if (self.validator
                    is not None) and (epoch % self.cfg.validate_every == 0):
                self.validator.evaluate(
                    epoch=epoch,
                    save_results=self.cfg.save_validation_results,
                    metrics=self.cfg.metrics,
                    model=self.model,
                    experiment_logger=self.experiment_logger.valid())

                valid_metrics = self.experiment_logger.summarise()
                if valid_metrics:
                    print_msg = f" -- Median validation metrics:"
                    print_msg += ", ".join(
                        f"{key}: {val:.5f}"
                        for key, val in valid_metrics.items())
                    LOGGER.info(print_msg)

        # make sure to close tensorboard to avoid losing the last epoch
        if self.cfg.log_tensorboard:
            self.experiment_logger.stop_tb()

    def _get_start_epoch_number(self):
        if self.cfg.is_continue_training:
            if self.cfg.continue_from_epoch is not None:
                epoch = self.cfg.continue_from_epoch
            else:
                weight_path = sorted(
                    self.cfg.run_dir.glob('model_epoch*.pt'))[-1]
                epoch = weight_path.name[-6:-3]
        else:
            epoch = 0
        return int(epoch)

    def _restore_training_state(self):
        if self.cfg.continue_from_epoch is not None:
            epoch = f"{self.cfg.continue_from_epoch:03d}"
            weight_path = self.cfg.base_run_dir / f"model_epoch{epoch}.pt"
        else:
            weight_path = sorted(
                self.cfg.base_run_dir.glob('model_epoch*.pt'))[-1]
            epoch = weight_path.name[-6:-3]

        optimizer_path = self.cfg.base_run_dir / f"optimizer_state_epoch{epoch}.pt"

        LOGGER.info(f"Continue training from epoch {int(epoch)}")
        self.model.load_state_dict(
            torch.load(weight_path, map_location=self.device))
        self.optimizer.load_state_dict(
            torch.load(str(optimizer_path), map_location=self.device))

    def _save_weights_and_optimizer(self, epoch: int):
        weight_path = self.cfg.run_dir / f"model_epoch{epoch:03d}.pt"
        torch.save(self.model.state_dict(), str(weight_path))

        optimizer_path = self.cfg.run_dir / f"optimizer_state_epoch{epoch:03d}.pt"
        torch.save(self.optimizer.state_dict(), str(optimizer_path))

    def _train_epoch(self, epoch: int):
        self.model.train()
        self.experiment_logger.train()

        # progress bar handle
        pbar = tqdm(self.loader, file=sys.stdout)
        pbar.set_description(f'# Epoch {epoch}')

        # Iterate in batches over training set
        nan_count = 0
        for data in pbar:

            for key in data.keys():
                data[key] = data[key].to(self.device)

            # apply possible subclass pre-processing
            data = self._pre_model_hook(data)

            # get predictions
            predictions = self.model(data)

            if self.noise_sampler_y is not None:
                for key in filter(lambda k: 'y' in k, data.keys()):
                    noise = self.noise_sampler_y.sample(data[key].shape)
                    # make sure we add near-zero noise to originally near-zero targets
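                    # data[key] holds standardized targets, i.e., (y - mean) / std, so
                    # data[key] + mean / std equals y / std and the sampled noise scales
                    # with the magnitude of the unnormalized target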
                    data[key] += (data[key] + self._target_mean /
                                  self._target_std) * noise.to(self.device)

            loss = self.loss_obj(predictions, data)

            # early stop training if loss is NaN
            if torch.isnan(loss):
                nan_count += 1
                if nan_count > self._allow_subsequent_nan_losses:
                    raise RuntimeError(
                        f"Loss was NaN for {nan_count} times in a row. Stopped training."
                    )
                LOGGER.warning(
                    f"Loss is Nan; ignoring step. (#{nan_count}/{self._allow_subsequent_nan_losses})"
                )
            else:
                nan_count = 0

                # delete old gradients
                self.optimizer.zero_grad()

                # get gradients
                loss.backward()

                if self.cfg.clip_gradient_norm is not None:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.cfg.clip_gradient_norm)

                # update weights
                self.optimizer.step()

            pbar.set_postfix_str(f"Loss: {loss.item():.4f}")

            self.experiment_logger.log_step(loss=loss.item())

    def _set_random_seeds(self):
        if self.cfg.seed is None:
            self.cfg.seed = int(np.random.uniform(low=0, high=1e6))

        # fix random seeds for various packages
        random.seed(self.cfg.seed)
        np.random.seed(self.cfg.seed)
        torch.cuda.manual_seed(self.cfg.seed)
        torch.manual_seed(self.cfg.seed)

    def _set_device(self):
        if self.cfg.device is not None:
            if self.cfg.device.startswith("cuda"):
                gpu_id = int(self.cfg.device.split(':')[-1])
                if gpu_id >= torch.cuda.device_count():
                    raise RuntimeError(
                        f"This machine does not have GPU #{gpu_id}")
                else:
                    self.device = torch.device(self.cfg.device)
            else:
                self.device = torch.device("cpu")
        else:
            self.device = torch.device(
                "cuda:0" if torch.cuda.is_available() else "cpu")
        LOGGER.info(f"### Device {self.device} will be used for training")

    def _create_folder_structure(self):
        # create as subdirectory within run directory of base run
        if self.cfg.is_continue_training:
            folder_name = f"continue_training_from_epoch{self._epoch:03d}"

            # store dir of base run for easier access in weight loading
            self.cfg.base_run_dir = self.cfg.run_dir
            self.cfg.run_dir = self.cfg.run_dir / folder_name

        # create as new folder structure
        else:
            now = datetime.now()
            run_name = f"{self.cfg.experiment_name}_{now:%d%m}_{now:%H%M%S}"

            # if no directory for the runs is specified, a 'runs' folder will be created in the current working dir
            if self.cfg.run_dir is None:
                self.cfg.run_dir = Path().cwd() / "runs" / run_name
            else:
                self.cfg.run_dir = self.cfg.run_dir / run_name

        # create folder + necessary subfolder
        if not self.cfg.run_dir.is_dir():
            self.cfg.train_dir = self.cfg.run_dir / "train_data"
            self.cfg.train_dir.mkdir(parents=True)
        else:
            raise RuntimeError(
                f"There is already a folder at {self.cfg.run_dir}")
        if self.cfg.log_n_figures is not None:
            self.cfg.img_log_dir = self.cfg.run_dir / "img_log"
            self.cfg.img_log_dir.mkdir(parents=True)

    def _pre_model_hook(
            self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        return data
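A minimal usage sketch of the training workflow above, assuming a neuralhydrology-style ``Config`` built from a YAML
run configuration; the config path is a placeholder and the ``Config`` constructor signature is an assumption:

# illustrative only -- config path and Config(...) signature are assumptions
from pathlib import Path

cfg = Config(Path("configs/my_experiment.yml"))  # hypothetical run configuration

trainer = BaseTrainer(cfg=cfg)
trainer.initialize_training()   # build model, optimizer, loss, data loader, and logger
trainer.train_and_validate()    # run the epoch loop, saving weights and validating as configured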
Example #2
    def evaluate(self,
                 epoch: int = None,
                 save_results: bool = True,
                 metrics: Union[list, dict] = [],
                 model: torch.nn.Module = None,
                 experiment_logger: Logger = None) -> dict:
        """Evaluate the model.
        
        Parameters
        ----------
        epoch : int, optional
            Define a specific epoch to evaluate. By default, the weights of the last epoch are used.
        save_results : bool, optional
            If True, stores the evaluation results in the run directory. By default, True.
        metrics : Union[list, dict], optional
            List of metrics to compute during evaluation. Can also be a dict that specifies the metrics to compute per target variable.
        model : torch.nn.Module, optional
            If a model is passed, this is used for validation.
        experiment_logger : Logger, optional
            Logger that can be passed during training to log the metrics.

        Returns
        -------
        dict
            A dictionary containing one xarray per basin with the evaluation results.
        """
        if model is None:
            if self.init_model:
                self._load_weights(epoch=epoch)
                model = self.model
            else:
                raise RuntimeError(
                    "No model was initialized for the evaluation")

        # during validation, depending on settings, only evaluate on a random subset of basins
        basins = self.basins
        if self.period == "validation":
            if len(basins) > self.cfg.validate_n_random_basins:
                random.shuffle(basins)
                basins = basins[:self.cfg.validate_n_random_basins]

        # force model to train-mode when doing mc-dropout evaluation
        if self.cfg.mc_dropout:
            model.train()
        else:
            model.eval()

        results = defaultdict(dict)

        pbar = tqdm(basins, file=sys.stdout)
        pbar.set_description(
            '# Validation' if self.period == "validation" else '# Evaluation')

        for basin in pbar:

            if self.cfg.cache_validation_data and basin in self.cached_datasets:
                ds = self.cached_datasets[basin]
            else:
                try:
                    ds = self._get_dataset(basin)
                except NoTrainDataError:
                    # skip basin
                    continue
                if self.cfg.cache_validation_data and self.period == "validation":
                    self.cached_datasets[basin] = ds

            loader = DataLoader(ds,
                                batch_size=self.cfg.batch_size,
                                num_workers=0)

            y_hat, y = self._evaluate(model, loader, ds.frequencies)

            predict_last_n = self.cfg.predict_last_n
            seq_length = self.cfg.seq_length
            # if predict_last_n/seq_length are int, there's only one frequency
            if isinstance(predict_last_n, int):
                predict_last_n = {ds.frequencies[0]: predict_last_n}
            if isinstance(seq_length, int):
                seq_length = {ds.frequencies[0]: seq_length}
            lowest_freq = sort_frequencies(ds.frequencies)[0]
            for freq in ds.frequencies:
                if predict_last_n[freq] == 0:
                    continue  # this frequency is not being predicted
                results[basin][freq] = {}

                # rescale predictions
                y_hat_freq = \
                    y_hat[freq] * self.scaler["xarray_feature_scale"][self.cfg.target_variables].to_array().values \
                    + self.scaler["xarray_feature_center"][self.cfg.target_variables].to_array().values
                y_freq = y[freq] * self.scaler["xarray_feature_scale"][self.cfg.target_variables].to_array().values \
                    + self.scaler["xarray_feature_center"][self.cfg.target_variables].to_array().values

                # create xarray
                data = self._create_xarray(y_hat_freq, y_freq)

                # get maximum warmup-offset across all frequencies
                offsets = {
                    freq: (seq_length[freq] - predict_last_n[freq]) *
                    pd.to_timedelta(freq)
                    for freq in ds.frequencies
                }
                max_offset_freq = max(offsets, key=offsets.get)
                start_date = ds.get_period_start(
                    basin) + offsets[max_offset_freq]

                # determine the end of the first sequence (first target in sequence-to-one)
                # we use the end_date stored in the dataset, which also covers issues with per-basin different periods
                end_date = ds.dates[basin]["end_dates"][0] \
                    + pd.Timedelta(days=1, seconds=-1) \
                    - pd.to_timedelta(max_offset_freq) * (predict_last_n[max_offset_freq] - 1)
                date_range = pd.date_range(start=start_date,
                                           end=end_date,
                                           freq=lowest_freq)
                if len(date_range) != data[
                        f"{self.cfg.target_variables[0]}_obs"][1].shape[0]:
                    raise ValueError(
                        "Evaluation date range does not match generated predictions."
                    )

                frequency_factor = pd.to_timedelta(
                    lowest_freq) // pd.to_timedelta(freq)
                freq_range = pd.timedelta_range(end=(frequency_factor - 1) *
                                                pd.to_timedelta(freq),
                                                periods=predict_last_n[freq],
                                                freq=freq)

                xr = xarray.Dataset(data_vars=data,
                                    coords={
                                        'date': date_range,
                                        'time_step': freq_range
                                    })
                results[basin][freq]['xr'] = xr

                # only warn once per freq
                if frequency_factor < predict_last_n[freq] and basin == basins[
                        0]:
                    tqdm.write(
                        f'Metrics for {freq} are calculated over last {frequency_factor} elements only. '
                        f'Ignoring {predict_last_n[freq] - frequency_factor} predictions per sequence.'
                    )

                if metrics:
                    for target_variable in self.cfg.target_variables:
                        # stack dates and time_steps so we don't just evaluate every 24H when use_frequencies=[1D, 1H]
                        obs = xr.isel(time_step=slice(-frequency_factor, None)) \
                            .stack(datetime=['date', 'time_step'])[f"{target_variable}_obs"]
                        obs['datetime'] = obs.coords['date'] + obs.coords[
                            'time_step']
                        # check if there are observations for this period
                        if not all(obs.isnull()):
                            sim = xr.isel(time_step=slice(-frequency_factor, None)) \
                                .stack(datetime=['date', 'time_step'])[f"{target_variable}_sim"]
                            sim['datetime'] = sim.coords['date'] + sim.coords[
                                'time_step']

                            # clip negative predictions to zero if the variable is listed in config 'clip_targets_to_zero'
                            if target_variable in self.cfg.clip_targets_to_zero:
                                sim = xarray.where(sim < 0, 0, sim)

                            if 'samples' in sim.dims:
                                sim = sim.mean(dim='samples')

                            var_metrics = metrics if isinstance(
                                metrics, list) else metrics[target_variable]
                            if 'all' in var_metrics:
                                var_metrics = get_available_metrics()
                            try:
                                values = calculate_metrics(obs,
                                                           sim,
                                                           metrics=var_metrics,
                                                           resolution=freq)
                            except AllNaNError as err:
                                msg = f'Basin {basin} ' \
                                    + (f'{target_variable} ' if len(self.cfg.target_variables) > 1 else '') \
                                    + (f'{freq} ' if len(ds.frequencies) > 1 else '') \
                                    + str(err)
                                LOGGER.warning(msg)
                                values = {
                                    metric: np.nan
                                    for metric in var_metrics
                                }

                            # add variable identifier to metrics if needed
                            if len(self.cfg.target_variables) > 1:
                                values = {
                                    f"{target_variable}_{key}": val
                                    for key, val in values.items()
                                }
                            # add frequency identifier to metrics if needed
                            if len(ds.frequencies) > 1:
                                values = {
                                    f"{key}_{freq}": val
                                    for key, val in values.items()
                                }
                            if experiment_logger is not None:
                                experiment_logger.log_step(**values)
                            for k, v in values.items():
                                results[basin][freq][k] = v

        if (self.period == "validation") and (self.cfg.log_n_figures > 0) and (
                experiment_logger is not None):
            self._create_and_log_figures(results, experiment_logger, epoch)

        if save_results:
            self._save_results(results, epoch)

        return results
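A minimal sketch of calling the evaluation routine above outside of training, assuming the tester is created through
``get_tester`` exactly as in ``BaseTrainer._get_tester`` in Example #1; the epoch number and metric name are illustrative:

# illustrative only -- epoch and metric choices are assumptions
tester = get_tester(cfg=cfg, run_dir=cfg.run_dir, period="validation", init_model=True)
results = tester.evaluate(epoch=30, save_results=False, metrics=["NSE"])

# results maps basin -> frequency -> {'xr': xarray.Dataset, '<metric>': value}
for basin, freq_results in results.items():
    for freq, values in freq_results.items():
        print(basin, freq, {key: val for key, val in values.items() if key != 'xr'})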