Example 1: _create_ensemble (averages the predictions of several runs and re-calculates metrics)
# Imports needed to run this snippet; the module paths of the project-local
# helpers are assumptions based on the neuralhydrology code base.
import pickle
from collections import defaultdict
from pathlib import Path
from typing import List

import pandas as pd
import xarray as xr
from tqdm import tqdm

from neuralhydrology.datautils.utils import sort_frequencies      # assumed path
from neuralhydrology.evaluation.metrics import calculate_metrics  # assumed path
from neuralhydrology.utils.config import Config                   # assumed path

def _create_ensemble(results_files: List[Path], frequencies: List[str], config: Config) -> dict:
    """Averages the predictions of the passed runs and re-calculates metrics. """
    lowest_freq = sort_frequencies(frequencies)[0]
    ensemble_sum = defaultdict(dict)
    target_vars = config.target_variables

    print('Loading results for each run.')
    for run in tqdm(results_files):
        with open(run, 'rb') as fp:
            run_results = pickle.load(fp)
        for basin, basin_results in run_results.items():
            for freq in frequencies:
                freq_results = basin_results[freq]['xr']

                # sum up the predictions of all runs for each basin and frequency
                if freq not in ensemble_sum[basin]:
                    ensemble_sum[basin][freq] = freq_results
                else:
                    for target_var in target_vars:
                        ensemble_sum[basin][freq][f'{target_var}_sim'] += freq_results[f'{target_var}_sim']

    # divide the prediction sum by number of runs to get the mean prediction for each basin and frequency
    print('Combining results and calculating metrics.')
    ensemble = defaultdict(lambda: defaultdict(dict))
    for basin in tqdm(ensemble_sum):
        for freq in frequencies:
            ensemble[basin][freq]['xr'] = ensemble_sum[basin][freq]

            # combine date and time to a single index to calculate metrics
            frequency_factor = pd.to_timedelta(lowest_freq) // pd.to_timedelta(freq)
            ensemble[basin][freq]['xr'] = ensemble[basin][freq]['xr'].isel(
                time_step=slice(-frequency_factor, None)).stack(datetime=['date', 'time_step'])
            ensemble[basin][freq]['xr']['datetime'] = [
                c[0] + c[1] for c in ensemble[basin][freq]['xr'].coords['datetime'].values
            ]

            for target_var in target_vars:
                # average predictions
                ensemble[basin][freq]['xr'][f'{target_var}_sim'] = \
                    ensemble[basin][freq]['xr'][f'{target_var}_sim'] / len(results_files)

                # clip negative predictions to zero if the variable is configured for clipping
                sim = ensemble[basin][freq]['xr'][f'{target_var}_sim']
                if target_var in config.clip_targets_to_zero:
                    sim = xr.where(sim < 0, 0, sim)

                # calculate metrics
                ensemble_metrics = calculate_metrics(
                    ensemble[basin][freq]['xr'][f'{target_var}_obs'],
                    sim,
                    metrics=config.metrics if isinstance(config.metrics, list) else config.metrics[target_var],
                    resolution=freq)
                # add variable identifier to metrics if needed
                if len(target_vars) > 1:
                    ensemble_metrics = {f'{target_var}_{key}': val for key, val in ensemble_metrics.items()}
                for metric, val in ensemble_metrics.items():
                    ensemble[basin][freq][f'{metric}_{freq}'] = val

    return dict(ensemble)
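
For orientation, here is a minimal usage sketch of the function above. The run-directory layout, the pickle file name, and the config location are assumptions for illustration, not taken from the code:

# Hypothetical invocation of _create_ensemble; all paths and file names are assumed.
from pathlib import Path

run_dirs = [Path(f'runs/lstm_seed{i}') for i in range(3)]
results_files = [run_dir / 'test' / 'test_results.p' for run_dir in run_dirs]
config = Config(run_dirs[0] / 'config.yml')

ensemble = _create_ensemble(results_files, frequencies=['1D', '1H'], config=config)
# each ensemble[basin][freq] holds the averaged xarray under 'xr' plus metric keys such as 'NSE_1H'
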
Example 2: the evaluate method of the evaluator class
# Note: this is a method excerpt from an evaluator class in the neuralhydrology
# code base. It relies on instance attributes (self.cfg, self.model, self.scaler,
# self.basins, self.period, self.cached_datasets, ...) and on module-level names
# assumed to be imported there: random, sys, defaultdict (collections),
# Union (typing), np (numpy), pd (pandas), xarray, torch,
# DataLoader (torch.utils.data), tqdm, a module-level LOGGER, and project-local
# helpers (sort_frequencies, calculate_metrics, get_available_metrics,
# NoTrainDataError, AllNaNError, Logger).

    def evaluate(self,
                 epoch: int = None,
                 save_results: bool = True,
                 metrics: Union[list, dict] = [],
                 model: torch.nn.Module = None,
                 experiment_logger: Logger = None) -> dict:
        """Evaluate the model.
        
        Parameters
        ----------
        epoch : int, optional
            Define a specific epoch to evaluate. By default, the weights of the last epoch are used.
        save_results : bool, optional
            If True, stores the evaluation results in the run directory. By default, True.
        metrics : Union[list, dict], optional
            List of metrics to compute during evaluation. Can also be a dict that specifies per-target metrics.
        model : torch.nn.Module, optional
            If a model is passed, this is used for validation.
        experiment_logger : Logger, optional
            Logger can be passed during training to log metrics.

        Returns
        -------
        dict
            A dictionary containing one xarray per basin with the evaluation results.
        """
        if model is None:
            if self.init_model:
                self._load_weights(epoch=epoch)
                model = self.model
            else:
                raise RuntimeError("No model was initialized for the evaluation.")

        # during validation, depending on settings, only evaluate on a random subset of basins
        basins = self.basins
        if self.period == "validation":
            if len(basins) > self.cfg.validate_n_random_basins:
                # sample without replacement instead of shuffling in place, so self.basins is not mutated
                basins = random.sample(basins, self.cfg.validate_n_random_basins)

        # force model to train-mode when doing mc-dropout evaluation
        if self.cfg.mc_dropout:
            model.train()
        else:
            model.eval()

        results = defaultdict(dict)

        pbar = tqdm(basins, file=sys.stdout)
        pbar.set_description('# Validation' if self.period == "validation" else '# Evaluation')

        for basin in pbar:

            if self.cfg.cache_validation_data and basin in self.cached_datasets:
                ds = self.cached_datasets[basin]
            else:
                try:
                    ds = self._get_dataset(basin)
                except NoTrainDataError:
                    # no data available for this basin, skip it
                    continue
                if self.cfg.cache_validation_data and self.period == "validation":
                    self.cached_datasets[basin] = ds

            loader = DataLoader(ds,
                                batch_size=self.cfg.batch_size,
                                num_workers=0)

            y_hat, y = self._evaluate(model, loader, ds.frequencies)

            predict_last_n = self.cfg.predict_last_n
            seq_length = self.cfg.seq_length
            # if predict_last_n/seq_length are int, there's only one frequency
            if isinstance(predict_last_n, int):
                predict_last_n = {ds.frequencies[0]: predict_last_n}
            if isinstance(seq_length, int):
                seq_length = {ds.frequencies[0]: seq_length}
            lowest_freq = sort_frequencies(ds.frequencies)[0]
            for freq in ds.frequencies:
                if predict_last_n[freq] == 0:
                    continue  # this frequency is not being predicted
                results[basin][freq] = {}

                # rescale predictions and observations back to the original target units
                scale = self.scaler["xarray_feature_scale"][self.cfg.target_variables].to_array().values
                center = self.scaler["xarray_feature_center"][self.cfg.target_variables].to_array().values
                y_hat_freq = y_hat[freq] * scale + center
                y_freq = y[freq] * scale + center

                # create xarray
                data = self._create_xarray(y_hat_freq, y_freq)

                # get maximum warmup-offset across all frequencies
                offsets = {f: (seq_length[f] - predict_last_n[f]) * pd.to_timedelta(f) for f in ds.frequencies}
                max_offset_freq = max(offsets, key=offsets.get)
                start_date = ds.get_period_start(basin) + offsets[max_offset_freq]

                # determine the end of the first sequence (first target in sequence-to-one)
                # we use the end_date stored in the dataset, which also covers issues with per-basin different periods
                end_date = ds.dates[basin]["end_dates"][0] \
                    + pd.Timedelta(days=1, seconds=-1) \
                    - pd.to_timedelta(max_offset_freq) * (predict_last_n[max_offset_freq] - 1)
                date_range = pd.date_range(start=start_date,
                                           end=end_date,
                                           freq=lowest_freq)
                if len(date_range) != data[f"{self.cfg.target_variables[0]}_obs"][1].shape[0]:
                    raise ValueError("Evaluation date range does not match generated predictions.")

                frequency_factor = pd.to_timedelta(lowest_freq) // pd.to_timedelta(freq)
                freq_range = pd.timedelta_range(end=(frequency_factor - 1) * pd.to_timedelta(freq),
                                                periods=predict_last_n[freq],
                                                freq=freq)

                xr = xarray.Dataset(data_vars=data,
                                    coords={
                                        'date': date_range,
                                        'time_step': freq_range
                                    })
                results[basin][freq]['xr'] = xr

                # only warn once per frequency
                if frequency_factor < predict_last_n[freq] and basin == basins[0]:
                    tqdm.write(f'Metrics for {freq} are calculated over last {frequency_factor} elements only. '
                               f'Ignoring {predict_last_n[freq] - frequency_factor} predictions per sequence.')

                if metrics:
                    for target_variable in self.cfg.target_variables:
                        # stack dates and time_steps so we don't just evaluate every 24H when use_frequencies=[1D, 1H]
                        obs = xr.isel(time_step=slice(-frequency_factor, None)) \
                            .stack(datetime=['date', 'time_step'])[f"{target_variable}_obs"]
                        obs['datetime'] = obs.coords['date'] + obs.coords['time_step']
                        # check if there are observations for this period
                        if not all(obs.isnull()):
                            sim = xr.isel(time_step=slice(-frequency_factor, None)) \
                                .stack(datetime=['date', 'time_step'])[f"{target_variable}_sim"]
                            sim['datetime'] = sim.coords['date'] + sim.coords['time_step']

                            # clip negative predictions to zero, if variable is listed in config 'clip_targets_to_zero'
                            if target_variable in self.cfg.clip_targets_to_zero:
                                sim = xarray.where(sim < 0, 0, sim)

                            if 'samples' in sim.dims:
                                sim = sim.mean(dim='samples')

                            var_metrics = metrics if isinstance(metrics, list) else metrics[target_variable]
                            if 'all' in var_metrics:
                                var_metrics = get_available_metrics()
                            try:
                                values = calculate_metrics(obs, sim, metrics=var_metrics, resolution=freq)
                            except AllNaNError as err:
                                msg = f'Basin {basin} ' \
                                    + (f'{target_variable} ' if len(self.cfg.target_variables) > 1 else '') \
                                    + (f'{freq} ' if len(ds.frequencies) > 1 else '') \
                                    + str(err)
                                LOGGER.warning(msg)
                                values = {metric: np.nan for metric in var_metrics}

                            # add variable identifier to metrics if needed
                            if len(self.cfg.target_variables) > 1:
                                values = {f"{target_variable}_{key}": val for key, val in values.items()}
                            # add frequency identifier to metrics if needed
                            if len(ds.frequencies) > 1:
                                values = {f"{key}_{freq}": val for key, val in values.items()}
                            if experiment_logger is not None:
                                experiment_logger.log_step(**values)
                            for k, v in values.items():
                                results[basin][freq][k] = v

        if (self.period == "validation") and (self.cfg.log_n_figures > 0) and (
                experiment_logger is not None):
            self._create_and_log_figures(results, experiment_logger, epoch)

        if save_results:
            self._save_results(results, epoch)

        return results
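
A sketch of how this method might be invoked for a test-period evaluation. The tester class name and constructor signature are assumptions based on the neuralhydrology code base, not confirmed by the excerpt above:

# Hypothetical invocation; RegressionTester and its constructor arguments are assumed.
from pathlib import Path

run_dir = Path('runs/lstm_seed0')
cfg = Config(run_dir / 'config.yml')
tester = RegressionTester(cfg=cfg, run_dir=run_dir, period='test', init_model=True)
results = tester.evaluate(save_results=False, metrics=['NSE', 'KGE'])

for basin, freq_results in results.items():
    for freq, res in freq_results.items():
        print(basin, freq, {k: v for k, v in res.items() if k != 'xr'})
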
Example 3: _create_ensemble, a revised version of Example 1 with NaN-safe metric calculation and explicit construction of the datetime index
# Imports needed to run this snippet; as in Example 1, the module paths of the
# project-local helpers are assumptions based on the neuralhydrology code base.
import pickle
from collections import defaultdict
from pathlib import Path
from typing import List

import numpy as np
import pandas as pd
import xarray as xr
from tqdm import tqdm

from neuralhydrology.datautils.utils import get_frequency_factor, sort_frequencies       # assumed path
from neuralhydrology.evaluation.metrics import calculate_metrics, get_available_metrics  # assumed path
from neuralhydrology.utils.config import Config                                          # assumed path
from neuralhydrology.utils.errors import AllNaNError                                     # assumed path

def _create_ensemble(results_files: List[Path], frequencies: List[str],
                     config: Config) -> dict:
    """Averages the predictions of the passed runs and re-calculates metrics. """
    lowest_freq = sort_frequencies(frequencies)[0]
    ensemble_sum = defaultdict(dict)
    target_vars = config.target_variables

    print('Loading results for each run.')
    for run in tqdm(results_files):
        with open(run, 'rb') as fp:
            run_results = pickle.load(fp)
        for basin, basin_results in run_results.items():
            for freq in frequencies:
                freq_results = basin_results[freq]['xr']

                # sum up the predictions of all runs for each basin and frequency
                if freq not in ensemble_sum[basin]:
                    ensemble_sum[basin][freq] = freq_results
                else:
                    for target_var in target_vars:
                        ensemble_sum[basin][freq][f'{target_var}_sim'] += freq_results[f'{target_var}_sim']

    # divide the prediction sum by number of runs to get the mean prediction for each basin and frequency
    print('Combining results and calculating metrics.')
    ensemble = defaultdict(lambda: defaultdict(dict))
    for basin in tqdm(ensemble_sum):
        for freq in frequencies:
            ensemble_xr = ensemble_sum[basin][freq]

            # combine date and time to a single index to calculate metrics
            # create datetime range at the current frequency, removing time steps that are not being predicted
            frequency_factor = int(get_frequency_factor(lowest_freq, freq))
            # make sure the last day is fully contained in the range
            freq_date_range = pd.date_range(start=ensemble_xr.coords['date'].values[0],
                                            end=ensemble_xr.coords['date'].values[-1]
                                                + pd.Timedelta(days=1, seconds=-1),
                                            freq=freq)
            # keep only the time steps that are actually predicted at this frequency
            mask = np.ones(frequency_factor, dtype=bool)
            mask[:-len(ensemble_xr.coords['time_step'])] = False
            freq_date_range = freq_date_range[np.tile(mask, len(ensemble_xr.coords['date']))]

            ensemble_xr = ensemble_xr.isel(time_step=slice(-frequency_factor, None)) \
                                     .stack(datetime=['date', 'time_step'])
            ensemble_xr['datetime'] = freq_date_range
            for target_var in target_vars:
                # average predictions
                ensemble_xr[f'{target_var}_sim'] = ensemble_xr[f'{target_var}_sim'] / len(results_files)

                # clip negative predictions to zero if the variable is configured for clipping
                sim = ensemble_xr[f'{target_var}_sim']
                if target_var in config.clip_targets_to_zero:
                    sim = xr.where(sim < 0, 0, sim)

                # calculate metrics
                metrics = config.metrics if isinstance(config.metrics, list) else config.metrics[target_var]
                if 'all' in metrics:
                    metrics = get_available_metrics()
                try:
                    ensemble_metrics = calculate_metrics(
                        ensemble_xr[f'{target_var}_obs'],
                        sim,
                        metrics=metrics,
                        resolution=freq)
                except AllNaNError as err:
                    msg = f'Basin {basin} ' \
                        + (f'{target_var} ' if len(target_vars) > 1 else '') \
                        + (f'{freq} ' if len(frequencies) > 1 else '') \
                        + str(err)
                    print(msg)
                    ensemble_metrics = {metric: np.nan for metric in metrics}

                # add variable identifier to metrics if needed
                if len(target_vars) > 1:
                    ensemble_metrics = {f'{target_var}_{key}': val for key, val in ensemble_metrics.items()}
                for metric, val in ensemble_metrics.items():
                    ensemble[basin][freq][f'{metric}_{freq}'] = val

            ensemble[basin][freq]['xr'] = ensemble_xr

    return dict(ensemble)
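
Both ensemble versions, like the evaluate method in Example 2, rely on the same indexing trick: predictions are stored per (date, time_step) pair, where date runs at the lowest frequency and time_step holds the offsets of the higher-frequency predictions within that date, and stacking the two coordinates yields a single datetime index for metric calculation. A self-contained sketch of that mechanic on dummy data (values and frequencies are made up for illustration; very recent xarray versions may require resetting the stacked MultiIndex before reassigning the coordinate):

import numpy as np
import pandas as pd
import xarray

# two daily dates, each carrying 24 hourly predictions as time-step offsets
dates = pd.date_range('2000-01-01', periods=2, freq='1D')
steps = pd.timedelta_range(end=pd.Timedelta(hours=23), periods=24, freq='1H')
ds = xarray.Dataset({'q_sim': (('date', 'time_step'), np.random.rand(2, 24))},
                    coords={'date': dates, 'time_step': steps})

# collapse (date, time_step) into one index and turn it into real timestamps,
# mirroring the pattern used in the examples above
stacked = ds.stack(datetime=['date', 'time_step'])
stacked['datetime'] = stacked.coords['date'] + stacked.coords['time_step']
print(stacked['q_sim'].sizes)  # {'datetime': 48}
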