Example #1
import pytest

# Assumed import path; adjust if sort_frequencies lives elsewhere in your codebase.
from neuralhydrology.datautils.utils import sort_frequencies


def test_sort_frequencies():
    """Test the sorting of frequencies."""
    assert sort_frequencies(['1D', '1H', '2D', '3H']) == ['2D', '1D', '3H', '1H']
    assert sort_frequencies(['1M', '1Y']) == ['1Y', '1M']
    assert sort_frequencies(['1D', '48H']) == ['48H', '1D']
    assert sort_frequencies(['1D']) == ['1D']
    assert sort_frequencies([]) == []

    pytest.raises(ValueError, sort_frequencies, ['1D', '1XYZ'])  # not a frequency
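
The test above pins down the contract of sort_frequencies: frequencies are returned from lowest to highest frequency (largest time step first), and strings that are not valid frequencies raise a ValueError. Below is a minimal sketch of that behavior, not the library's actual implementation; the approximate 30/365-day handling of month and year frequencies is an assumption made only so the example runs on its own.

import re

import pandas as pd


def sort_frequencies_sketch(frequencies):
    """Order pandas frequency strings from lowest to highest frequency (largest step first)."""

    def approx_delta(freq):
        # Months and years have no fixed timedelta; approximate them (30/365 days)
        # so that '1Y' sorts before '1M', as the test above expects.
        match = re.fullmatch(r'(\d*)([YM])', freq)
        if match:
            n = int(match.group(1) or 1)
            return pd.Timedelta(days=n * (365 if match.group(2) == 'Y' else 30))
        # pd.to_timedelta raises ValueError for invalid strings such as '1XYZ'.
        return pd.to_timedelta(freq)

    return sorted(frequencies, key=approx_delta, reverse=True)


assert sort_frequencies_sketch(['1D', '1H', '2D', '3H']) == ['2D', '1D', '3H', '1H']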
Example #2
def _create_ensemble(results_files: List[Path], frequencies: List[str], config: Config) -> dict:
    """Averages the predictions of the passed runs and re-calculates metrics. """
    lowest_freq = sort_frequencies(frequencies)[0]
    ensemble_sum = defaultdict(dict)
    target_vars = config.target_variables

    print('Loading results for each run.')
    for run in tqdm(results_files):
        with open(run, 'rb') as fp:
            run_results = pickle.load(fp)
        for basin, basin_results in run_results.items():
            for freq in frequencies:
                freq_results = basin_results[freq]['xr']

                # sum up the predictions of all runs for this basin
                if freq not in ensemble_sum[basin]:
                    ensemble_sum[basin][freq] = freq_results
                else:
                    for target_var in target_vars:
                        ensemble_sum[basin][freq][f'{target_var}_sim'] += freq_results[f'{target_var}_sim']

    # divide the prediction sum by number of runs to get the mean prediction for each basin and frequency
    print('Combining results and calculating metrics.')
    ensemble = defaultdict(lambda: defaultdict(dict))
    for basin in tqdm(ensemble_sum.keys()):
        for freq in frequencies:
            ensemble[basin][freq]['xr'] = ensemble_sum[basin][freq]

            # combine date and time into a single index to calculate metrics
            frequency_factor = pd.to_timedelta(lowest_freq) // pd.to_timedelta(freq)
            ensemble[basin][freq]['xr'] = ensemble[basin][freq]['xr'].isel(
                time_step=slice(-frequency_factor, None)).stack(datetime=['date', 'time_step'])
            ensemble[basin][freq]['xr']['datetime'] = [
                c[0] + c[1] for c in ensemble[basin][freq]['xr'].coords['datetime'].values
            ]

            for target_var in target_vars:
                # average predictions
                ensemble[basin][freq]['xr'][
                    f'{target_var}_sim'] = ensemble[basin][freq]['xr'][f'{target_var}_sim'] / len(results_files)

                # clip predictions to zero
                sim = ensemble[basin][freq]['xr'][f'{target_var}_sim']
                if target_var in config.clip_targets_to_zero:
                    sim = xr.where(sim < 0, 0, sim)

                # calculate metrics
                ensemble_metrics = calculate_metrics(
                    ensemble[basin][freq]['xr'][f'{target_var}_obs'],
                    sim,
                    metrics=config.metrics if isinstance(config.metrics, list) else config.metrics[target_var],
                    resolution=freq)
                # add variable identifier to metrics if needed
                if len(target_vars) > 1:
                    ensemble_metrics = {f'{target_var}_{key}': val for key, val in ensemble_metrics.items()}
                for metric, val in ensemble_metrics.items():
                    ensemble[basin][freq][f'{metric}_{freq}'] = val

    return dict(ensemble)
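
In the function above, frequency_factor is plain timedelta arithmetic: it counts how many steps of frequency freq fit into one step of the lowest frequency, which is also how many time_step entries each date carries. A short stand-alone illustration with hypothetical values:

import pandas as pd

lowest_freq, freq = '1D', '1H'  # hypothetical frequencies
frequency_factor = pd.to_timedelta(lowest_freq) // pd.to_timedelta(freq)
print(frequency_factor)  # 24 -> every daily sample carries 24 hourly time steps

# Dividing the accumulated predictions by the number of runs, as done above with
# len(results_files), turns the running sum into the ensemble mean:
n_runs = 3  # hypothetical number of runs
prediction_sum = 6.0 + 9.0 + 12.0
ensemble_mean = prediction_sum / n_runs  # 9.0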
Example #3
    def __init__(self, cfg: Config):
        super(TiedFrequencyMSERegularization, self).__init__(cfg)
        self._frequencies = sort_frequencies([
            f for f in cfg.use_frequencies
            if cfg.predict_last_n[f] > 0 and f not in cfg.no_loss_frequencies
        ])

        if len(self._frequencies) < 2:
            raise ValueError(
                "TiedFrequencyMSERegularization needs at least two frequencies."
            )
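
The list comprehension in the constructor above keeps only frequencies that are actually predicted (predict_last_n > 0) and that are not excluded from the loss. A small stand-alone sketch with plain lists and dicts standing in for the Config object (all values hypothetical):

use_frequencies = ['1D', '3H', '1H']
predict_last_n = {'1D': 1, '3H': 0, '1H': 24}
no_loss_frequencies = ['1D']

kept = [f for f in use_frequencies if predict_last_n[f] > 0 and f not in no_loss_frequencies]
print(kept)  # ['1H'] -> fewer than two frequencies left, so the constructor would raise ValueError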
Example #4
    def __init__(self, cfg: Config):
        super(ODELSTM, self).__init__(cfg=cfg)
        if len(cfg.use_frequencies) < 2:
            raise ValueError('ODELSTM needs at least two frequencies.')
        if isinstance(cfg.dynamic_inputs, dict) or isinstance(
                cfg.hidden_size, dict):
            raise ValueError(
                'ODELSTM does not support per-frequency input variables or hidden sizes.'
            )

        # Note: be aware that frequency_factors and slice_timesteps have a slightly different meaning here vs. in
        # MTSLSTM. Here, the frequency_factor is relative to the _lowest_ (not the next-lower) frequency.
        # slice_timesteps[freq] is the input step (counting backwards) in the next-*lower* frequency from which
        # point onward input data at frequency freq is available.
        self._frequency_factors = {}
        self._slice_timesteps = {}
        self._frequencies = sort_frequencies(cfg.use_frequencies)
        self._init_frequency_factors_and_slice_timesteps()

        # start to count the number of inputs
        self.input_size = len(cfg.dynamic_inputs + cfg.static_attributes +
                              cfg.hydroatlas_attributes +
                              cfg.evolving_attributes)

        if cfg.use_basin_id_encoding:
            self.input_size += cfg.number_of_basins
        if cfg.head.lower() == 'umal':
            self.input_size += 1

        self.lstm_cell = _LSTMCell(self.input_size, self.cfg.hidden_size,
                                   cfg.initial_forget_bias)
        self.ode_cell = _ODERNNCell(self.cfg.hidden_size,
                                    self.cfg.hidden_size,
                                    num_unfolds=self.cfg.ode_num_unfolds,
                                    method=self.cfg.ode_method)
        self.dropout = nn.Dropout(p=cfg.output_dropout)
        self.head = get_head(cfg=cfg,
                             n_in=self.cfg.hidden_size,
                             n_out=self.output_size)
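
The note in the constructor above distinguishes two kinds of frequency factors: ODELSTM stores them relative to the lowest frequency, while MTSLSTM stores them relative to the next-lower one. A hypothetical numeric sketch for three frequencies makes the difference concrete:

import pandas as pd

frequencies = ['1D', '6H', '1H']  # hypothetical, sorted from lowest to highest frequency

# Factors relative to the lowest frequency, as this class uses them:
lowest = frequencies[0]
factors_vs_lowest = {f: pd.to_timedelta(lowest) // pd.to_timedelta(f) for f in frequencies}
print(factors_vs_lowest)  # {'1D': 1, '6H': 4, '1H': 24}

# Factors relative to the next-lower frequency, as MTSLSTM uses them:
factors_vs_next_lower = {
    frequencies[i]: pd.to_timedelta(frequencies[i - 1]) // pd.to_timedelta(frequencies[i])
    for i in range(1, len(frequencies))
}
print(factors_vs_next_lower)  # {'6H': 4, '1H': 6}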
Example #5
    def __init__(self, cfg: Config):
        super(MTSLSTM, self).__init__(cfg=cfg)
        self.lstms = None
        self.transfer_fcs = None
        self.heads = None
        self.dropout = None

        self._slice_timestep = {}
        self._frequency_factors = []

        self._seq_lengths = cfg.seq_length
        self._is_shared_mtslstm = self.cfg.shared_mtslstm  # default: a distinct LSTM per timescale
        self._transfer_mtslstm_states = self.cfg.transfer_mtslstm_states  # default: linear transfer layer
        transfer_modes = [None, "None", "identity", "linear"]
        if self._transfer_mtslstm_states["h"] not in transfer_modes \
                or self._transfer_mtslstm_states["c"] not in transfer_modes:
            raise ValueError(
                f"MTS-LSTM supports state transfer modes {transfer_modes}")

        if len(cfg.use_frequencies) < 2:
            raise ValueError("MTS-LSTM expects more than one input frequency")
        self._frequencies = sort_frequencies(cfg.use_frequencies)

        # start to count the number of inputs
        input_sizes = len(cfg.static_attributes + cfg.hydroatlas_attributes +
                          cfg.evolving_attributes)

        # if is_shared_mtslstm, the LSTM gets an additional frequency flag as input.
        if self._is_shared_mtslstm:
            input_sizes += len(self._frequencies)

        if cfg.use_basin_id_encoding:
            input_sizes += cfg.number_of_basins
        if cfg.head.lower() == "umal":
            input_sizes += 1

        if isinstance(cfg.dynamic_inputs, list):
            input_sizes = {
                freq: input_sizes + len(cfg.dynamic_inputs)
                for freq in self._frequencies
            }
        else:
            if self._is_shared_mtslstm:
                raise ValueError(
                    'Per-frequency dynamic inputs are not allowed if shared_mtslstm is used.')
            input_sizes = {
                freq: input_sizes + len(cfg.dynamic_inputs[freq])
                for freq in self._frequencies
            }

        if not isinstance(cfg.hidden_size, dict):
            LOGGER.info(
                "No frequency-specific hidden sizes specified. The same hidden size is used for all frequencies."
            )
            self._hidden_size = {
                freq: cfg.hidden_size
                for freq in self._frequencies
            }
        else:
            self._hidden_size = cfg.hidden_size

        if (self._is_shared_mtslstm
            or self._transfer_mtslstm_states["h"] == "identity"
            or self._transfer_mtslstm_states["c"] == "identity") \
                and any(size != self._hidden_size[self._frequencies[0]] for size in self._hidden_size.values()):
            raise ValueError(
                "All hidden sizes must be equal if shared_mtslstm is used or state transfer=identity."
            )

        # create layer depending on selected frequencies
        self._init_modules(input_sizes)
        self._reset_parameters()

        # frequency factors are needed to determine the time step of information transfer
        self._init_frequency_factors_and_slice_timesteps()
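
The input-size bookkeeping above counts the static inputs once, optionally adds a one-hot frequency flag (shared_mtslstm) and the basin-id encoding, and then adds the dynamic inputs per frequency. A hypothetical arithmetic sketch with plain lists and dicts in place of the Config:

static_attributes = ['elevation', 'area']  # hypothetical attribute names
dynamic_inputs = {'1D': ['precip_daily', 'temp_daily'], '1H': ['precip_hourly']}
frequencies = ['1D', '1H']
is_shared_mtslstm = False  # per-frequency inputs require a distinct LSTM per timescale

base = len(static_attributes)  # 2
if is_shared_mtslstm:
    base += len(frequencies)  # one-hot frequency flag

input_sizes = {freq: base + len(dynamic_inputs[freq]) for freq in frequencies}
print(input_sizes)  # {'1D': 4, '1H': 3}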
Example #6
    def evaluate(self,
                 epoch: int = None,
                 save_results: bool = True,
                 metrics: Union[list, dict] = [],
                 model: torch.nn.Module = None,
                 experiment_logger: Logger = None) -> dict:
        """Evaluate the model.
        
        Parameters
        ----------
        epoch : int, optional
            Define a specific epoch to evaluate. By default, the weights of the last epoch are used.
        save_results : bool, optional
            If True, stores the evaluation results in the run directory. By default, True.
        metrics : Union[list, dict], optional
            List of metrics to compute during evaluation. Can also be a dict that specifies per-target metrics.
        model : torch.nn.Module, optional
            If a model is passed, it is used for validation instead of loading the model weights from disk.
        experiment_logger : Logger, optional
            Logger that can be passed during training to log metrics.

        Returns
        -------
        dict
            A dictionary containing one xarray per basin with the evaluation results.
        """
        if model is None:
            if self.init_model:
                self._load_weights(epoch=epoch)
                model = self.model
            else:
                raise RuntimeError(
                    "No model was initialized for the evaluation")

        # during validation, depending on settings, only evaluate on a random subset of basins
        basins = self.basins
        if self.period == "validation":
            if len(basins) > self.cfg.validate_n_random_basins:
                random.shuffle(basins)
                basins = basins[:self.cfg.validate_n_random_basins]

        # force model to train-mode when doing mc-dropout evaluation
        if self.cfg.mc_dropout:
            model.train()
        else:
            model.eval()

        results = defaultdict(dict)

        pbar = tqdm(basins, file=sys.stdout)
        pbar.set_description('# Validation' if self.period ==
                             "validation" else "# Evaluation")

        for basin in pbar:

            if self.cfg.cache_validation_data and basin in self.cached_datasets:
                ds = self.cached_datasets[basin]
            else:
                try:
                    ds = self._get_dataset(basin)
                except NoTrainDataError:
                    # skip basin
                    continue
                if self.cfg.cache_validation_data and self.period == "validation":
                    self.cached_datasets[basin] = ds

            loader = DataLoader(ds,
                                batch_size=self.cfg.batch_size,
                                num_workers=0)

            y_hat, y = self._evaluate(model, loader, ds.frequencies)

            predict_last_n = self.cfg.predict_last_n
            seq_length = self.cfg.seq_length
            # if predict_last_n/seq_length are int, there's only one frequency
            if isinstance(predict_last_n, int):
                predict_last_n = {ds.frequencies[0]: predict_last_n}
            if isinstance(seq_length, int):
                seq_length = {ds.frequencies[0]: seq_length}
            lowest_freq = sort_frequencies(ds.frequencies)[0]
            for freq in ds.frequencies:
                if predict_last_n[freq] == 0:
                    continue  # this frequency is not being predicted
                results[basin][freq] = {}

                # rescale predictions
                y_hat_freq = \
                    y_hat[freq] * self.scaler["xarray_feature_scale"][self.cfg.target_variables].to_array().values \
                    + self.scaler["xarray_feature_center"][self.cfg.target_variables].to_array().values
                y_freq = y[freq] * self.scaler["xarray_feature_scale"][self.cfg.target_variables].to_array().values \
                    + self.scaler["xarray_feature_center"][self.cfg.target_variables].to_array().values

                # create xarray
                data = self._create_xarray(y_hat_freq, y_freq)

                # get maximum warmup-offset across all frequencies
                offsets = {
                    freq: (seq_length[freq] - predict_last_n[freq]) *
                    pd.to_timedelta(freq)
                    for freq in ds.frequencies
                }
                max_offset_freq = max(offsets, key=offsets.get)
                start_date = ds.get_period_start(
                    basin) + offsets[max_offset_freq]

                # determine the end of the first sequence (first target in sequence-to-one)
                # we use the end_date stored in the dataset, which also covers issues with per-basin different periods
                end_date = ds.dates[basin]["end_dates"][0] \
                    + pd.Timedelta(days=1, seconds=-1) \
                    - pd.to_timedelta(max_offset_freq) * (predict_last_n[max_offset_freq] - 1)
                date_range = pd.date_range(start=start_date,
                                           end=end_date,
                                           freq=lowest_freq)
                if len(date_range) != data[
                        f"{self.cfg.target_variables[0]}_obs"][1].shape[0]:
                    raise ValueError(
                        "Evaluation date range does not match generated predictions."
                    )

                frequency_factor = pd.to_timedelta(
                    lowest_freq) // pd.to_timedelta(freq)
                freq_range = pd.timedelta_range(end=(frequency_factor - 1) *
                                                pd.to_timedelta(freq),
                                                periods=predict_last_n[freq],
                                                freq=freq)

                xr = xarray.Dataset(data_vars=data,
                                    coords={
                                        'date': date_range,
                                        'time_step': freq_range
                                    })
                results[basin][freq]['xr'] = xr

                # only warn once per freq
                if frequency_factor < predict_last_n[freq] and basin == basins[
                        0]:
                    tqdm.write(
                        f'Metrics for {freq} are calculated over last {frequency_factor} elements only. '
                        f'Ignoring {predict_last_n[freq] - frequency_factor} predictions per sequence.'
                    )

                if metrics:
                    for target_variable in self.cfg.target_variables:
                        # stack dates and time_steps so we don't just evaluate every 24H when use_frequencies=[1D, 1H]
                        obs = xr.isel(time_step=slice(-frequency_factor, None)) \
                            .stack(datetime=['date', 'time_step'])[f"{target_variable}_obs"]
                        obs['datetime'] = obs.coords['date'] + obs.coords[
                            'time_step']
                        # check if there are observations for this period
                        if not all(obs.isnull()):
                            sim = xr.isel(time_step=slice(-frequency_factor, None)) \
                                .stack(datetime=['date', 'time_step'])[f"{target_variable}_sim"]
                            sim['datetime'] = sim.coords['date'] + sim.coords[
                                'time_step']

                            # clip negative predictions to zero, if variable is listed in config 'clip_target_to_zero'
                            if target_variable in self.cfg.clip_targets_to_zero:
                                sim = xarray.where(sim < 0, 0, sim)

                            if 'samples' in sim.dims:
                                sim = sim.mean(dim='samples')

                            var_metrics = metrics if isinstance(
                                metrics, list) else metrics[target_variable]
                            if 'all' in var_metrics:
                                var_metrics = get_available_metrics()
                            try:
                                values = calculate_metrics(obs,
                                                           sim,
                                                           metrics=var_metrics,
                                                           resolution=freq)
                            except AllNaNError as err:
                                msg = f'Basin {basin} ' \
                                    + (f'{target_variable} ' if len(self.cfg.target_variables) > 1 else '') \
                                    + (f'{freq} ' if len(ds.frequencies) > 1 else '') \
                                    + str(err)
                                LOGGER.warning(msg)
                                values = {
                                    metric: np.nan
                                    for metric in var_metrics
                                }

                            # add variable identifier to metrics if needed
                            if len(self.cfg.target_variables) > 1:
                                values = {
                                    f"{target_variable}_{key}": val
                                    for key, val in values.items()
                                }
                            # add frequency identifier to metrics if needed
                            if len(ds.frequencies) > 1:
                                values = {
                                    f"{key}_{freq}": val
                                    for key, val in values.items()
                                }
                            if experiment_logger is not None:
                                experiment_logger.log_step(**values)
                            for k, v in values.items():
                                results[basin][freq][k] = v

        if (self.period == "validation") and (self.cfg.log_n_figures > 0) and (
                experiment_logger is not None):
            self._create_and_log_figures(results, experiment_logger, epoch)

        if save_results:
            self._save_results(results, epoch)

        return results
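
The least obvious part of the evaluation above is how the output coordinates are built: dates run at the lowest frequency, and each date carries a time_step axis of timedeltas at the current frequency. A self-contained sketch with hypothetical daily/hourly settings:

import pandas as pd

lowest_freq, freq = '1D', '1H'  # hypothetical frequencies
predict_last_n = {'1D': 1, '1H': 24}

date_range = pd.date_range(start='2000-01-01', end='2000-01-03', freq=lowest_freq)

frequency_factor = pd.to_timedelta(lowest_freq) // pd.to_timedelta(freq)  # 24
freq_range = pd.timedelta_range(end=(frequency_factor - 1) * pd.to_timedelta(freq),
                                periods=predict_last_n[freq],
                                freq=freq)
print(len(date_range), len(freq_range))  # 3 dates, 24 hourly time steps per date
print(freq_range[0], freq_range[-1])     # 0 days 00:00:00 ... 0 days 23:00:00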
Example #7
def _create_ensemble(results_files: List[Path], frequencies: List[str],
                     config: Config) -> dict:
    """Averages the predictions of the passed runs and re-calculates metrics. """
    lowest_freq = sort_frequencies(frequencies)[0]
    ensemble_sum = defaultdict(dict)
    target_vars = config.target_variables

    print('Loading results for each run.')
    for run in tqdm(results_files):
        with open(run, 'rb') as fp:
            run_results = pickle.load(fp)
        for basin, basin_results in run_results.items():
            for freq in frequencies:
                freq_results = basin_results[freq]['xr']

                # sum up the predictions of all runs for this basin
                if freq not in ensemble_sum[basin]:
                    ensemble_sum[basin][freq] = freq_results
                else:
                    for target_var in target_vars:
                        ensemble_sum[basin][freq][
                            f'{target_var}_sim'] += freq_results[
                                f'{target_var}_sim']

    # divide the prediction sum by number of runs to get the mean prediction for each basin and frequency
    print('Combining results and calculating metrics.')
    ensemble = defaultdict(lambda: defaultdict(dict))
    for basin in tqdm(ensemble_sum.keys()):
        for freq in frequencies:
            ensemble_xr = ensemble_sum[basin][freq]

            # combine date and time into a single index to calculate metrics
            # create datetime range at the current frequency, removing time steps that are not being predicted
            frequency_factor = int(get_frequency_factor(lowest_freq, freq))
            # make sure the last day is fully contained in the range
            freq_date_range = pd.date_range(start=ensemble_xr.coords['date'].values[0],
                                            end=ensemble_xr.coords['date'].values[-1] \
                                                + pd.Timedelta(days=1, seconds=-1),
                                            freq=freq)
            mask = np.ones(frequency_factor).astype(bool)
            mask[:-len(ensemble_xr.coords['time_step'])] = False
            freq_date_range = freq_date_range[np.tile(
                mask, len(ensemble_xr.coords['date']))]

            ensemble_xr = ensemble_xr.isel(
                time_step=slice(-frequency_factor, None)).stack(
                    datetime=['date', 'time_step'])
            ensemble_xr['datetime'] = freq_date_range
            for target_var in target_vars:
                # average predictions
                ensemble_xr[f'{target_var}_sim'] = ensemble_xr[
                    f'{target_var}_sim'] / len(results_files)

                # clip predictions to zero
                sim = ensemble_xr[f'{target_var}_sim']
                if target_var in config.clip_targets_to_zero:
                    sim = xr.where(sim < 0, 0, sim)

                # calculate metrics
                metrics = config.metrics if isinstance(
                    config.metrics, list) else config.metrics[target_var]
                if 'all' in metrics:
                    metrics = get_available_metrics()
                try:
                    ensemble_metrics = calculate_metrics(
                        ensemble_xr[f'{target_var}_obs'],
                        sim,
                        metrics=metrics,
                        resolution=freq)
                except AllNaNError as err:
                    msg = f'Basin {basin} ' \
                        + (f'{target_var} ' if len(target_vars) > 1 else '') \
                        + (f'{freq} ' if len(frequencies) > 1 else '') \
                        + str(err)
                    print(msg)
                    ensemble_metrics = {metric: np.nan for metric in metrics}

                # add variable identifier to metrics if needed
                if len(target_vars) > 1:
                    ensemble_metrics = {
                        f'{target_var}_{key}': val
                        for key, val in ensemble_metrics.items()
                    }
                for metric, val in ensemble_metrics.items():
                    ensemble[basin][freq][f'{metric}_{freq}'] = val

            ensemble[basin][freq]['xr'] = ensemble_xr

    return dict(ensemble)
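
The mask/np.tile construction above keeps, for each date, only the last len(time_step) entries of its block of frequency_factor timestamps, so the flattened datetime index lines up with the stacked predictions. A numeric sketch with a hypothetical daily/6-hourly setup, two dates and two predicted time steps per date:

import numpy as np
import pandas as pd

frequency_factor = 4  # four 6H steps per 1D step (hypothetical)
n_time_steps = 2      # hypothetical number of predicted time steps per date
dates = pd.date_range('2000-01-01', periods=2, freq='1D')

freq_date_range = pd.date_range(start=dates[0],
                                end=dates[-1] + pd.Timedelta(days=1, seconds=-1),
                                freq='6H')
mask = np.ones(frequency_factor).astype(bool)
mask[:-n_time_steps] = False  # keep only the last n_time_steps entries per date
print(freq_date_range[np.tile(mask, len(dates))])
# -> 2000-01-01 12:00, 2000-01-01 18:00, 2000-01-02 12:00, 2000-01-02 18:00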
Example #8
    def _create_lookup_table(self, xr: xarray.Dataset):
        lookup = []
        if not self._disable_pbar:
            LOGGER.info("Create lookup table and convert to pytorch tensor")

        # list to collect basins ids of basins without a single training sample
        basins_without_samples = []
        basin_coordinates = xr["basin"].values.tolist()
        for basin in tqdm(basin_coordinates,
                          file=sys.stdout,
                          disable=self._disable_pbar):

            # store data of each frequency as numpy array of shape [time steps, features]
            x_d, x_s, y = {}, {}, {}

            # keys: frequencies, values: array mapping each lowest-frequency
            # sample to its corresponding sample in this frequency
            frequency_maps = {}
            lowest_freq = utils.sort_frequencies(self.frequencies)[0]

            # converting from xarray to pandas DataFrame because resampling is much faster in pandas.
            df_native = xr.sel(basin=basin).to_dataframe()
            for freq in self.frequencies:
                if isinstance(self.cfg.dynamic_inputs, list):
                    dynamic_cols = self.cfg.dynamic_inputs
                else:
                    dynamic_cols = self.cfg.dynamic_inputs[freq]

                df_resampled = df_native[
                    dynamic_cols + self.cfg.target_variables +
                    self.cfg.evolving_attributes].resample(freq).mean()
                x_d[freq] = df_resampled[dynamic_cols].values
                y[freq] = df_resampled[self.cfg.target_variables].values
                if self.cfg.evolving_attributes:
                    x_s[freq] = df_resampled[
                        self.cfg.evolving_attributes].values

                # number of frequency steps in one lowest-frequency step
                frequency_factor = pd.to_timedelta(
                    lowest_freq) // pd.to_timedelta(freq)
                # array position i is the last entry of this frequency that belongs to the lowest-frequency sample i.
                frequency_maps[freq] = np.arange(len(df_resampled) // frequency_factor) \
                                       * frequency_factor + (frequency_factor - 1)

            # store first date of sequence to be able to restore dates during inference
            if not self.is_train:
                self.period_starts[basin] = pd.to_datetime(
                    xr.sel(basin=basin)["date"].values[0])

            # we can ignore the deprecation warning about lists because we don't use the passed lists
            # after the validate_samples call. The alternative numba.typed.Lists is still experimental.
            with warnings.catch_warnings():
                warnings.simplefilter('ignore',
                                      category=NumbaPendingDeprecationWarning)

                # checks inputs and outputs for each sequence. valid: flag = 1, invalid: flag = 0
                # manually unroll the dicts into lists to make sure the order of frequencies is consistent.
                # during inference, we want all samples with sufficient history (even if input is NaN), so
                # we pass x_d, x_s, y as None.
                flag = validate_samples(
                    x_d=[x_d[freq] for freq in self.frequencies]
                    if self.is_train else None,
                    x_s=[x_s[freq] for freq in self.frequencies]
                    if self.is_train and x_s else None,
                    y=[y[freq] for freq in self.frequencies]
                    if self.is_train else None,
                    frequency_maps=[
                        frequency_maps[freq] for freq in self.frequencies
                    ],
                    seq_length=self.seq_len,
                    predict_last_n=self._predict_last_n)
            valid_samples = np.argwhere(flag == 1)
            for f in valid_samples:
                # store pointer to basin and the sample's index in each frequency
                lookup.append((basin, [
                    frequency_maps[freq][int(f)] for freq in self.frequencies
                ]))

            # only store data if this basin has at least one valid sample in the given period
            if valid_samples.size > 0:
                self.x_d[basin] = {
                    freq: torch.from_numpy(_x_d.astype(np.float32))
                    for freq, _x_d in x_d.items()
                }
                self.y[basin] = {
                    freq: torch.from_numpy(_y.astype(np.float32))
                    for freq, _y in y.items()
                }
                if x_s:
                    self.x_s[basin] = {
                        freq: torch.from_numpy(_x_s.astype(np.float32))
                        for freq, _x_s in x_s.items()
                    }
            else:
                basins_without_samples.append(basin)

        if basins_without_samples:
            LOGGER.info(
                f"These basins do not have a single valid sample in the {self.period} period: {basins_without_samples}"
            )
        self.lookup_table = {i: elem for i, elem in enumerate(lookup)}
        self.num_samples = len(self.lookup_table)
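
The frequency_maps arithmetic above records, for each lowest-frequency sample i, the index of the last higher-frequency row that belongs to it; those indices are what end up in the lookup table. A small numeric sketch with a hypothetical daily/hourly setup and 48 hourly rows:

import numpy as np
import pandas as pd

lowest_freq, freq = '1D', '1H'  # hypothetical frequencies
n_rows = 48                     # e.g. two days of hourly data

frequency_factor = pd.to_timedelta(lowest_freq) // pd.to_timedelta(freq)  # 24
frequency_map = np.arange(n_rows // frequency_factor) * frequency_factor + (frequency_factor - 1)
print(frequency_map)  # [23 47] -> hourly rows 23 and 47 close daily samples 0 and 1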