示例#1
0
    def _init_frequency_factors_and_slice_timesteps(self):
        """Validate the frequency configuration and fill the frequency-factor and slice-timestep lookups.

        ``self._frequencies[0]`` is treated as the lowest frequency. For every frequency, the number of its steps per
        lowest-frequency step is stored in ``self._frequency_factors``. For every other frequency, its ``seq_length``
        expressed in steps of the next-lower frequency is stored in ``self._slice_timesteps``.

        Raises
        ------
        ValueError
            If a frequency is not an integer multiple of the lowest frequency, if ``predict_last_n`` or ``seq_length``
            of a frequency does not align with the steps of the next-lower frequency, or if a frequency's input
            sequence spans more time than the one of the next-lower frequency.
        NotImplementedError
            If a frequency's ``predict_last_n`` spans more time than the next-lower frequency's, or if any
            ``predict_last_n`` spans more time than the highest-frequency slice.
        """
        frequencies = self._frequencies
        for position, freq in enumerate(frequencies):
            # number of `freq` steps per lowest-frequency step
            factor_to_lowest = get_frequency_factor(frequencies[0], freq)
            if int(factor_to_lowest) != factor_to_lowest:
                raise ValueError('Frequencies must be multiples of the lowest frequency.')
            self._frequency_factors[freq] = int(factor_to_lowest)

            if position == 0:
                # the lowest frequency has no next-lower frequency to be checked against
                continue

            lower_freq = frequencies[position - 1]
            # number of `freq` steps per step of the next-lower frequency
            step_ratio = get_frequency_factor(lower_freq, freq)

            if self.cfg.predict_last_n[freq] % step_ratio != 0:
                raise ValueError(
                    'At all frequencies, predict_last_n must align with the steps of the next-lower frequency.')

            seq_len = self.cfg.seq_length[freq]
            if seq_len > self.cfg.seq_length[lower_freq] * step_ratio:
                raise ValueError('Higher frequencies must have shorter input sequences than lower frequencies.')

            # we want to pass the state of the day _before_ the next higher frequency starts,
            # because e.g. the mean of a day is stored at the same date at 00:00 in the morning.
            slice_timestep = seq_len / step_ratio
            if int(slice_timestep) != slice_timestep:
                raise ValueError('At all frequencies, seq_length must align with the next-lower frequency steps.')
            self._slice_timesteps[freq] = int(slice_timestep)

            # in theory, the following configuration could be supported, but it would make the implementation
            # quite complex and is probably hardly ever useful.
            if self.cfg.predict_last_n[lower_freq] < self.cfg.predict_last_n[freq] / step_ratio:
                raise NotImplementedError(
                    'Lower frequencies cannot have smaller predict_last_n values than higher ones.')

        # compare every predict_last_n (in lowest-frequency steps) against the highest-frequency slice, which is
        # stored in steps of the second-highest frequency and converted to lowest-frequency steps here.
        exceeds_slice = (self.cfg.predict_last_n[f] / self._frequency_factors[f] >
                         self._slice_timesteps[frequencies[-1]] / self._frequency_factors[frequencies[-2]]
                         for f in frequencies)
        if any(exceeds_slice):
            raise NotImplementedError('predict_last_n cannot be larger than sequence length of highest frequency.')
示例#2
0
    def _randomize_freq(self, x_d: torch.Tensor, low_frequency: str,
                        high_frequency: str) -> torch.Tensor:
        """Randomly re-aggregate a high-frequency input sequence chunk by chunk.

        The sequence is split along dimension 0 into chunks of ``frequency_factor`` steps, where
        ``frequency_factor`` is the number of high-frequency steps per low-frequency step. For each chunk, a divisor
        of ``frequency_factor`` is drawn at random, and the chunk is averaged over groups of that many consecutive
        steps. The last feature of the averaged chunk (the frequency indicator) is then set to the group size divided
        by ``self._frequency_factors[high_frequency]``. A trailing chunk whose length is not divisible by the drawn
        group size is kept unchanged.

        Parameters
        ----------
        x_d : torch.Tensor
            High-frequency dynamic inputs; must be 3-dimensional with time along dimension 0
            (presumably [time, batch, features] — TODO confirm).
        low_frequency : str
            The lower frequency that defines the chunk length.
        high_frequency : str
            The frequency of `x_d`.

        Returns
        -------
        torch.Tensor
            All processed chunks concatenated along dimension 0.
        """
        factor = int(get_frequency_factor(low_frequency, high_frequency))
        divisors = [n for n in range(1, factor + 1) if factor % n == 0]

        # ceiling division: a partial trailing chunk is processed as well
        n_chunks = -(-x_d.shape[0] // factor)
        chunks = []
        for chunk_idx in range(n_chunks):
            chunk = x_d[chunk_idx * factor:(chunk_idx + 1) * factor]

            # aggregate to a random frequency between low and high
            group_size = np.random.choice(divisors)
            if chunk.shape[0] % group_size != 0:
                # do not randomize last slice if it doesn't align with aggregation steps
                chunks.append(chunk)
                continue

            grouped = chunk.view(-1, group_size, chunk.shape[1], chunk.shape[2])
            aggregated = grouped.mean(dim=1)
            # update the frequency indicators.
            aggregated[:, :, -1] = group_size / self._frequency_factors[high_frequency]
            chunks.append(aggregated)

        return torch.cat(chunks, dim=0)
    def forward(self, prediction: Dict[str, torch.Tensor],
                ground_truth: Dict[str, torch.Tensor], *args) -> torch.Tensor:
        """Sum the mean squared deviations between the predictions of each pair of adjacent frequencies.

        For each pair (lower, higher) of adjacent entries in ``self._frequencies``, the higher-frequency predictions
        are averaged over blocks of consecutive steps (one block per lower-frequency step). They are then compared
        with the last lower-frequency predictions covering the same number of steps.

        Parameters
        ----------
        prediction : Dict[str, torch.Tensor]
            Dictionary containing ``y_hat_{frequency}`` for each frequency.
        ground_truth : Dict[str, torch.Tensor]
            Dictionary containing ``y_{frequency}`` for each frequency. Not used by this regularization.

        Returns
        -------
        torch.Tensor
            The sum of mean squared deviations for each pair of adjacent frequencies (the integer 0 if there are
            fewer than two frequencies).
        """
        loss = 0
        adjacent_pairs = zip(self._frequencies[:-1], self._frequencies[1:])
        for lower_freq, higher_freq in adjacent_pairs:
            # number of higher-frequency steps per lower-frequency step
            factor = int(get_frequency_factor(lower_freq, higher_freq))

            higher_pred = prediction[f'y_hat_{higher_freq}']
            n_batch, n_steps = higher_pred.shape[0], higher_pred.shape[1]
            # average each block of `factor` consecutive higher-frequency steps
            higher_pred_aggregated = higher_pred.view(n_batch, n_steps // factor, factor, -1).mean(dim=2)

            n_overlap = higher_pred_aggregated.shape[1]
            lower_pred = prediction[f'y_hat_{lower_freq}'][:, -n_overlap:]

            squared_deviation = (lower_pred - higher_pred_aggregated)**2
            loss = loss + torch.mean(squared_deviation)

        return loss
def test_get_frequency_factor():
    """Check the ratios of compatible frequency pairs and the errors raised for unsupported pairs."""
    # (first frequency, second frequency, expected number of second-frequency steps per first-frequency step)
    compatible_pairs = [
        ('1H', '1H', 1),
        ('1A', '1Y', 1),
        ('1Y', '4Q', 1),
        ('1H', '1D', 1 / 24),
        ('1D', '1H', 24),
        ('2D', '12H', 4),
        ('1W', '1D', 7),
        ('1W-MON', '1D', 7),
        ('1Y', '1M', 12),
        ('0D', '0H', 1),
    ]
    for first, second, expected in compatible_pairs:
        assert get_frequency_factor(first, second) == expected

    with pytest.raises(ValueError):
        get_frequency_factor('1YS', '1M')  # year-start vs. month-end
    with pytest.raises(ValueError):
        get_frequency_factor('1Q', '1W')  # quarter vs. week
    with pytest.raises(ValueError):
        get_frequency_factor('1XYZ', '1D')  # not a frequency
    with pytest.raises(ValueError):
        get_frequency_factor('1Y', '1D')  # disallowed because to_timedelta('1Y') is deprecated
    with pytest.raises(ValueError):
        get_frequency_factor('1M', '1D')  # disallowed because to_timedelta('1M') is deprecated
    with pytest.raises(NotImplementedError):
        get_frequency_factor('-1D', '1H')  # we should never need negative frequencies
示例#5
0
 def _init_frequency_factors_and_slice_timesteps(self):
     """Derive step ratios and state-transfer slice points for each pair of adjacent frequencies.

     For every frequency except the last one in ``self._frequencies``, this method:

     - appends to ``self._frequency_factors`` the number of steps of the next frequency per step of this frequency;
     - stores in ``self._slice_timestep`` the sequence length of the next frequency, expressed in steps of this
       frequency.

     Raises
     ------
     ValueError
         If two adjacent frequencies are not integer multiples of each other, or if the sequence length of a
         frequency does not align with the steps of the preceding (lower) frequency.
     """
     frequencies = self._frequencies
     for idx, (freq, next_freq) in enumerate(zip(frequencies[:-1], frequencies[1:])):
         frequency_factor = get_frequency_factor(freq, next_freq)
         if frequency_factor != int(frequency_factor):
             raise ValueError('Adjacent frequencies must be multiples of each other.')
         self._frequency_factors.append(int(frequency_factor))

         # we want to pass the state of the day _before_ the next higher frequency starts,
         # because e.g. the mean of a day is stored at the same date at 00:00 in the morning.
         slice_timestep = self._seq_lengths[next_freq] / self._frequency_factors[idx]
         # a truncating int() would silently round a misaligned seq_length down, so that the slice point no longer
         # coincides with the start of the higher-frequency sequence. Reject such configurations instead.
         if slice_timestep != int(slice_timestep):
             raise ValueError('At all frequencies, seq_length must align with the next-lower frequency steps.')
         self._slice_timestep[freq] = int(slice_timestep)
示例#6
0
def _create_ensemble(results_files: List[Path], frequencies: List[str],
                     config: Config) -> dict:
    """Average the predictions of several runs and re-calculate the metrics of the averaged predictions.

    Parameters
    ----------
    results_files : List[Path]
        Pickled evaluation results of the individual runs. Each file holds a dict ``{basin: {freq: {'xr': ...}}}``.
    frequencies : List[str]
        Frequencies to combine.
    config : Config
        Run configuration; provides the target variables, the metrics and the targets to clip to zero.

    Returns
    -------
    dict
        ``{basin: {freq: {'<metric>_<freq>': value, ..., 'xr': xarray.Dataset}}}``. The stored dataset holds the
        averaged (unclipped) predictions, stacked along a single ``datetime`` dimension.
    """
    lowest_freq = sort_frequencies(frequencies)[0]
    ensemble_sum = defaultdict(dict)
    target_vars = config.target_variables

    print('Loading results for each run.')
    for run in tqdm(results_files):
        # NOTE: pickle can execute arbitrary code while loading; only pass result files of trusted runs.
        # the context manager makes sure the file handle is closed again (it used to leak).
        with open(run, 'rb') as results_file:
            run_results = pickle.load(results_file)
        for basin, basin_results in run_results.items():
            for freq in frequencies:
                freq_results = basin_results[freq]['xr']

                # sum up the predictions of all runs
                if freq not in ensemble_sum[basin]:
                    ensemble_sum[basin][freq] = freq_results
                else:
                    for target_var in target_vars:
                        ensemble_sum[basin][freq][f'{target_var}_sim'] += freq_results[f'{target_var}_sim']

    # divide the prediction sum by number of runs to get the mean prediction for each basin and frequency
    print('Combining results and calculating metrics.')
    ensemble = defaultdict(lambda: defaultdict(dict))
    for basin in tqdm(ensemble_sum.keys()):
        for freq in frequencies:
            ensemble_xr = ensemble_sum[basin][freq]

            # combine date and time to a single index to calculate metrics
            # create datetime range at the current frequency, removing time steps that are not being predicted
            frequency_factor = int(get_frequency_factor(lowest_freq, freq))
            # make sure the last day is fully contained in the range
            freq_date_range = pd.date_range(start=ensemble_xr.coords['date'].values[0],
                                            end=ensemble_xr.coords['date'].values[-1]
                                            + pd.Timedelta(days=1, seconds=-1),
                                            freq=freq)
            # keep only the last len(time_step) steps of each lowest-frequency step
            mask = np.ones(frequency_factor).astype(bool)
            mask[:-len(ensemble_xr.coords['time_step'])] = False
            freq_date_range = freq_date_range[np.tile(mask, len(ensemble_xr.coords['date']))]

            ensemble_xr = ensemble_xr.isel(time_step=slice(-frequency_factor, None)).stack(
                datetime=['date', 'time_step'])
            ensemble_xr['datetime'] = freq_date_range
            for target_var in target_vars:
                # average predictions
                ensemble_xr[f'{target_var}_sim'] = ensemble_xr[f'{target_var}_sim'] / len(results_files)

                # clip predictions to zero (only for the metric calculation; the stored dataset stays unclipped)
                sim = ensemble_xr[f'{target_var}_sim']
                if target_var in config.clip_targets_to_zero:
                    sim = xr.where(sim < 0, 0, sim)

                # calculate metrics
                metrics = config.metrics if isinstance(config.metrics, list) else config.metrics[target_var]
                if 'all' in metrics:
                    metrics = get_available_metrics()
                try:
                    ensemble_metrics = calculate_metrics(ensemble_xr[f'{target_var}_obs'],
                                                         sim,
                                                         metrics=metrics,
                                                         resolution=freq)
                except AllNaNError as err:
                    msg = f'Basin {basin} ' \
                        + (f'{target_var} ' if len(target_vars) > 1 else '') \
                        + (f'{freq} ' if len(frequencies) > 1 else '') \
                        + str(err)
                    print(msg)
                    ensemble_metrics = {metric: np.nan for metric in metrics}

                # add variable identifier to metrics if needed
                if len(target_vars) > 1:
                    ensemble_metrics = {f'{target_var}_{key}': val for key, val in ensemble_metrics.items()}
                for metric, val in ensemble_metrics.items():
                    ensemble[basin][freq][f'{metric}_{freq}'] = val

            ensemble[basin][freq]['xr'] = ensemble_xr

    return dict(ensemble)
示例#7
0
def mean_peak_timing(obs: DataArray,
                     sim: DataArray,
                     window: int = None,
                     resolution: str = '1D',
                     datetime_coord: str = None) -> float:
    """Mean difference in peak flow timing.

    Peaks of the observed time series are detected with ``scipy.signal.find_peaks``. Only peaks whose prominence is
    at least the standard deviation of the observations are kept, and lower peaks are dropped until all remaining
    peaks are at least 100 steps apart. For each remaining observed peak, the simulated peak is:

    - the simulated value at the same time step, if it is larger than both of its neighbors;
    - otherwise, the maximum of the simulations within `window` steps on either side of the observed peak.

    The timing error of a peak is the absolute time difference between the observed and simulated peak, measured in
    steps of `resolution`. The metric is the mean timing error across all evaluated peaks. Peaks closer than `window`
    steps to either end of the series, and peaks whose window spans a gap in the NaN-filtered observations, are
    skipped. For more details, see Appendix of [#]_

    Parameters
    ----------
    obs : DataArray
        Observed time series.
    sim : DataArray
        Simulated time series.
    window : int, optional
        Size of window to consider on each side of the observed peak for finding the simulated peak. That is, the total
        window length to find the peak in the simulations is :math:`2 * \\text{window} + 1` centered at the observed
        peak. By default, the window covers 12 hours at the given resolution, but at least 3 steps (e.g., 3 for a
        resolution of '1D' and 12 for a resolution of '1H').
    resolution : str, optional
        Temporal resolution of the time series in pandas format, e.g. '1D' for daily and '1H' for hourly.
    datetime_coord : str, optional
        Name of datetime coordinate. Tries to infer it automatically if not specified.

    Returns
    -------
    float
        Mean peak time difference in steps of `resolution`; NaN if no peak could be evaluated.

    References
    ----------
    .. [#] Kratzert, F., Klotz, D., Hochreiter, S., and Nearing, G. S.: A note on leveraging synergy in multiple
        meteorological datasets with deep learning for rainfall-runoff modeling, Hydrol. Earth Syst. Sci. Discuss.,
        https://doi.org/10.5194/hess-2020-221, in review, 2020.
    """
    _validate_inputs(obs, sim)

    # keep only time steps with valid observations (scipy's find_peaks doesn't guarantee correctness with NaNs)
    obs, sim = _mask_valid(obs, sim)

    # heuristic to get the indices of the observed peaks
    obs_values = obs.values
    peaks, _ = signal.find_peaks(obs_values, distance=100, prominence=np.std(obs_values))

    if datetime_coord is None:
        datetime_coord = utils.infer_datetime_coord(obs)

    if window is None:
        # infer a reasonable window size: 12 hours at the given resolution, but at least 3 steps
        window = max(int(utils.get_frequency_factor('12H', resolution)), 3)

    n_obs = len(obs)
    expected_window_len = 2 * window + 1
    timing_errors = []
    for peak_idx in peaks:
        first, last = peak_idx - window, peak_idx + window

        # skip peaks whose window does not fit into the sequence
        if first < 0 or last >= n_obs:
            continue
        # skip peaks around removed (NaN) observations, whose windows would span too much time
        window_dates = pd.date_range(obs[first][datetime_coord].values,
                                     obs[last][datetime_coord].values,
                                     freq=resolution)
        if window_dates.size != expected_window_len:
            continue

        # use the simulation at the observed peak if it is a local maximum, else the maximum inside the window
        if sim[peak_idx] > sim[peak_idx - 1] and sim[peak_idx] > sim[peak_idx + 1]:
            peak_sim = sim[peak_idx]
        else:
            window_sim = sim[first:last + 1]
            peak_sim = window_sim[window_sim.argmax()]

        # time offset between observed and simulated peak, in steps of `resolution`
        delta = obs[peak_idx].coords[datetime_coord] - peak_sim.coords[datetime_coord]
        timing_errors.append(np.abs(delta.values / pd.to_timedelta(resolution)))

    return np.mean(timing_errors) if timing_errors else np.nan
    def _create_lookup_table(self, xr: xarray.Dataset):
        """Build the sample lookup table and store the per-basin data as torch tensors.

        For every basin in `xr`:

        - the data is resampled to each frequency in ``self.frequencies``;
        - the samples are checked with ``validate_samples``;
        - one lookup entry ``(basin, [sample index per frequency])`` is created for each valid sample.

        Data of basins with at least one valid sample is stored in ``self.x_d``, ``self.y`` and (if evolving attributes
        are used) ``self.x_s``. Basins without any valid sample are logged.

        Parameters
        ----------
        xr : xarray.Dataset
            Dataset with a ``basin`` coordinate containing the dynamic inputs, targets and evolving attributes of all
            basins. During inference, its ``date`` values are also read to store the period start of each basin.
        """
        lookup = []
        if not self._disable_pbar:
            LOGGER.info("Create lookup table and convert to pytorch tensor")

        # list to collect basins ids of basins without a single training sample
        basins_without_samples = []
        basin_coordinates = xr["basin"].values.tolist()
        # the lowest frequency is the same for all basins, so determine it only once
        lowest_freq = utils.sort_frequencies(self.frequencies)[0]
        for basin in tqdm(basin_coordinates, file=sys.stdout, disable=self._disable_pbar):

            # store data of each frequency as numpy array of shape [time steps, features]
            x_d, x_s, y = {}, {}, {}

            # keys: frequencies, values: array mapping each lowest-frequency
            # sample to its corresponding sample in this frequency
            frequency_maps = {}

            basin_data = xr.sel(basin=basin)
            # converting from xarray to pandas DataFrame because resampling is much faster in pandas.
            df_native = basin_data.to_dataframe()
            for freq in self.frequencies:
                # make sure that possible mass inputs are sorted to the beginning of the dynamic feature list
                if isinstance(self.cfg.dynamic_inputs, list):
                    dynamic_cols = self.cfg.mass_inputs + self.cfg.dynamic_inputs
                else:
                    dynamic_cols = self.cfg.mass_inputs + self.cfg.dynamic_inputs[freq]

                df_resampled = df_native[dynamic_cols + self.cfg.target_variables +
                                         self.cfg.evolving_attributes].resample(freq).mean()
                x_d[freq] = df_resampled[dynamic_cols].values
                y[freq] = df_resampled[self.cfg.target_variables].values
                if self.cfg.evolving_attributes:
                    x_s[freq] = df_resampled[self.cfg.evolving_attributes].values

                # number of frequency steps in one lowest-frequency step
                frequency_factor = int(utils.get_frequency_factor(lowest_freq, freq))
                # array position i is the last entry of this frequency that belongs to the lowest-frequency sample i.
                frequency_maps[freq] = np.arange(len(df_resampled) // frequency_factor) \
                                       * frequency_factor + (frequency_factor - 1)

            # store first date of sequence to be able to restore dates during inference
            if not self.is_train:
                self.period_starts[basin] = pd.to_datetime(basin_data["date"].values[0])

            # we can ignore the deprecation warning about lists because we don't use the passed lists
            # after the validate_samples call. The alternative numba.typed.Lists is still experimental.
            with warnings.catch_warnings():
                warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning)

                # checks inputs and outputs for each sequence. valid: flag = 1, invalid: flag = 0
                # manually unroll the dicts into lists to make sure the order of frequencies is consistent.
                # during inference, we want all samples with sufficient history (even if input is NaN), so
                # we pass x_d, x_s, y as None.
                flag = validate_samples(x_d=[x_d[freq] for freq in self.frequencies] if self.is_train else None,
                                        x_s=[x_s[freq] for freq in self.frequencies] if self.is_train and x_s else None,
                                        y=[y[freq] for freq in self.frequencies] if self.is_train else None,
                                        frequency_maps=[frequency_maps[freq] for freq in self.frequencies],
                                        seq_length=self.seq_len,
                                        predict_last_n=self._predict_last_n)

            # 1D array of valid sample indices. np.argwhere would yield 1-element rows, and converting those with
            # int() is deprecated since NumPy 1.25.
            valid_samples = np.flatnonzero(flag == 1)
            for sample_idx in valid_samples:
                # store pointer to basin and the sample's index in each frequency
                lookup.append((basin, [frequency_maps[freq][sample_idx] for freq in self.frequencies]))

            # only store data if this basin has at least one valid sample in the given period
            if valid_samples.size > 0:
                self.x_d[basin] = {freq: torch.from_numpy(_x_d.astype(np.float32)) for freq, _x_d in x_d.items()}
                self.y[basin] = {freq: torch.from_numpy(_y.astype(np.float32)) for freq, _y in y.items()}
                if x_s:
                    self.x_s[basin] = {freq: torch.from_numpy(_x_s.astype(np.float32)) for freq, _x_s in x_s.items()}
            else:
                basins_without_samples.append(basin)

        if basins_without_samples:
            LOGGER.info(
                f"These basins do not have a single valid sample in the {self.period} period: {basins_without_samples}")
        self.lookup_table = dict(enumerate(lookup))
        self.num_samples = len(self.lookup_table)
示例#9
0
    def evaluate(self,
                 epoch: int = None,
                 save_results: bool = True,
                 metrics: Union[list, dict, None] = None,
                 model: torch.nn.Module = None,
                 experiment_logger: Logger = None) -> dict:
        """Evaluate the model.

        Parameters
        ----------
        epoch : int, optional
            Define a specific epoch to evaluate. By default, the weights of the last epoch are used.
        save_results : bool, optional
            If True, stores the evaluation results in the run directory. By default, True.
        metrics : Union[list, dict], optional
            List of metrics to compute during evaluation. Can also be a dict that specifies per-target metrics.
            If empty or None (the default), no metrics are computed.
        model : torch.nn.Module, optional
            If a model is passed, this is used for validation.
        experiment_logger : Logger, optional
            Logger can be passed during training to log metrics.

        Returns
        -------
        dict
            Nested dictionary ``{basin: {frequency: {'xr': xarray.Dataset, <metric name>: value, ...}}}``.
            Frequencies with ``predict_last_n == 0`` and basins skipped because of missing data are not included.

        Raises
        ------
        RuntimeError
            If no model is passed and none was initialized, or if the simulations are neither 3- nor 4-dimensional.
        ValueError
            If the evaluation date range does not match the generated predictions.
        """
        if model is None:
            if self.init_model:
                self._load_weights(epoch=epoch)
                model = self.model
            else:
                raise RuntimeError("No model was initialized for the evaluation")

        # during validation, depending on settings, only evaluate on a random subset of basins.
        # shuffle a copy so that self.basins is not reordered as a side effect.
        basins = self.basins
        if self.period == "validation":
            if len(basins) > self.cfg.validate_n_random_basins:
                basins = list(basins)
                random.shuffle(basins)
                basins = basins[:self.cfg.validate_n_random_basins]

        # force model to train-mode when doing mc-dropout evaluation
        if self.cfg.mc_dropout:
            model.train()
        else:
            model.eval()

        results = defaultdict(dict)

        pbar = tqdm(basins, file=sys.stdout)
        pbar.set_description('# Validation' if self.period == "validation" else "# Evaluation")

        # frequencies for which the "metrics are calculated over last ... elements" note was already written
        warned_freqs = set()

        for basin in pbar:

            if self.cfg.cache_validation_data and basin in self.cached_datasets:
                ds = self.cached_datasets[basin]
            else:
                try:
                    ds = self._get_dataset(basin)
                except NoTrainDataError:
                    # skip basin
                    continue
                if self.cfg.cache_validation_data and self.period == "validation":
                    self.cached_datasets[basin] = ds

            loader = DataLoader(ds, batch_size=self.cfg.batch_size, num_workers=0)

            y_hat, y, loss = self._evaluate(model, loader, ds.frequencies)

            # log loss of this basin plus number of samples in the logger to compute epoch aggregates later
            if experiment_logger is not None:
                experiment_logger.log_step(loss=(loss, len(loader)))

            predict_last_n = self.cfg.predict_last_n
            seq_length = self.cfg.seq_length
            # if predict_last_n/seq_length are int, there's only one frequency
            if isinstance(predict_last_n, int):
                predict_last_n = {ds.frequencies[0]: predict_last_n}
            if isinstance(seq_length, int):
                seq_length = {ds.frequencies[0]: seq_length}
            lowest_freq = sort_frequencies(ds.frequencies)[0]
            for freq in ds.frequencies:
                if predict_last_n[freq] == 0:
                    continue  # this frequency is not being predicted
                results[basin][freq] = {}

                # rescale observations
                feature_scaler = self.scaler["xarray_feature_scale"][self.cfg.target_variables].to_array().values
                feature_center = self.scaler["xarray_feature_center"][self.cfg.target_variables].to_array().values
                y_freq = y[freq] * feature_scaler + feature_center
                # rescale predictions
                if y_hat[freq].ndim == 3 or (len(feature_scaler) == 1):
                    y_hat_freq = y_hat[freq] * feature_scaler + feature_center
                elif y_hat[freq].ndim == 4:
                    # if y_hat has 4 dim and we have multiple features we expand the dimensions for scaling
                    feature_scaler = np.expand_dims(feature_scaler, (0, 1, 3))
                    feature_center = np.expand_dims(feature_center, (0, 1, 3))
                    y_hat_freq = y_hat[freq] * feature_scaler + feature_center
                else:
                    raise RuntimeError(f"Simulations have {y_hat[freq].ndim} dimension. Only 3 and 4 are supported.")

                # create xarray
                data = self._create_xarray(y_hat_freq, y_freq)

                # get warmup-offsets across all frequencies
                offsets = {
                    offset_freq: ds.get_period_start(basin) + (seq_length[offset_freq] - 1) * to_offset(offset_freq)
                    for offset_freq in ds.frequencies
                }
                max_offset_freq = max(offsets, key=offsets.get)
                start_date = offsets[max_offset_freq]

                # determine the end of the first sequence (first target in sequence-to-one)
                # we use the end_date stored in the dataset, which also covers issues with per-basin different periods
                end_date = ds.dates[basin]["end_dates"][0] + pd.Timedelta(days=1, seconds=-1)

                # date range at the lowest frequency
                date_range = pd.date_range(start=start_date, end=end_date, freq=lowest_freq)
                if len(date_range) != data[f"{self.cfg.target_variables[0]}_obs"][1].shape[0]:
                    raise ValueError("Evaluation date range does not match generated predictions.")

                # freq_range are the steps of the current frequency at each lowest-frequency step
                frequency_factor = int(get_frequency_factor(lowest_freq, freq))
                freq_range = list(range(frequency_factor - predict_last_n[freq], frequency_factor))

                # create datetime range at the current frequency
                freq_date_range = pd.date_range(start=start_date, end=end_date, freq=freq)
                # remove datetime steps that are not being predicted from the datetime range
                mask = np.ones(frequency_factor).astype(bool)
                mask[:-predict_last_n[freq]] = False
                freq_date_range = freq_date_range[np.tile(mask, len(date_range))]

                freq_xr = xarray.Dataset(data_vars=data, coords={'date': date_range, 'time_step': freq_range})
                results[basin][freq]['xr'] = freq_xr

                # only warn once per freq (also when the first basin was skipped)
                if frequency_factor < predict_last_n[freq] and freq not in warned_freqs:
                    warned_freqs.add(freq)
                    tqdm.write(
                        f'Metrics for {freq} are calculated over last {frequency_factor} elements only. '
                        f'Ignoring {predict_last_n[freq] - frequency_factor} predictions per sequence.'
                    )

                if metrics:
                    for target_variable in self.cfg.target_variables:
                        # stack dates and time_steps so we don't just evaluate every 24H when use_frequencies=[1D, 1H]
                        obs = freq_xr.isel(time_step=slice(-frequency_factor, None)) \
                            .stack(datetime=['date', 'time_step'])[f"{target_variable}_obs"]
                        obs['datetime'] = freq_date_range
                        # check if there are observations for this period (vectorized instead of a Python-level all())
                        if not obs.isnull().all():
                            sim = freq_xr.isel(time_step=slice(-frequency_factor, None)) \
                                .stack(datetime=['date', 'time_step'])[f"{target_variable}_sim"]
                            sim['datetime'] = freq_date_range

                            # clip negative predictions to zero, if variable is listed in config 'clip_target_to_zero'
                            if target_variable in self.cfg.clip_targets_to_zero:
                                sim = xarray.where(sim < 0, 0, sim)

                            if 'samples' in sim.dims:
                                sim = sim.mean(dim='samples')

                            var_metrics = metrics if isinstance(metrics, list) else metrics[target_variable]
                            if 'all' in var_metrics:
                                var_metrics = get_available_metrics()
                            try:
                                values = calculate_metrics(obs, sim, metrics=var_metrics, resolution=freq)
                            except AllNaNError as err:
                                msg = f'Basin {basin} ' \
                                    + (f'{target_variable} ' if len(self.cfg.target_variables) > 1 else '') \
                                    + (f'{freq} ' if len(ds.frequencies) > 1 else '') \
                                    + str(err)
                                LOGGER.warning(msg)
                                values = {metric: np.nan for metric in var_metrics}

                            # add variable identifier to metrics if needed
                            if len(self.cfg.target_variables) > 1:
                                values = {f"{target_variable}_{key}": val for key, val in values.items()}
                            # add frequency identifier to metrics if needed
                            if len(ds.frequencies) > 1:
                                values = {f"{key}_{freq}": val for key, val in values.items()}
                            if experiment_logger is not None:
                                experiment_logger.log_step(**values)
                            for k, v in values.items():
                                results[basin][freq][k] = v

        # convert default dict back to normal Python dict to avoid unexpected behavior when trying to access
        # a non-existing basin
        results = dict(results)

        if (self.period == "validation") and (self.cfg.log_n_figures > 0) and (experiment_logger is not None):
            self._create_and_log_figures(results, experiment_logger, epoch)

        if save_results:
            self._save_results(results, epoch)

        return results