def _create_ensemble(results_files: List[Path], frequencies: List[str], config: Config) -> dict:
    """Average the predictions of the passed runs and re-calculate metrics.

    Parameters
    ----------
    results_files : List[Path]
        Paths to the pickled per-run result dicts (basin -> freq -> {'xr': xarray.Dataset, ...}).
    frequencies : List[str]
        Frequencies (pandas offset strings) for which results were generated.
    config : Config
        Run configuration; provides target variables, metrics, and clipping settings.

    Returns
    -------
    dict
        basin -> freq -> {'xr': averaged xarray.Dataset, '<metric>_<freq>': value, ...}
    """
    lowest_freq = sort_frequencies(frequencies)[0]
    ensemble_sum = defaultdict(dict)
    target_vars = config.target_variables

    print('Loading results for each run.')
    for run in tqdm(results_files):
        # use a context manager so the file handle is closed even if unpickling fails
        # (original used pickle.load(open(run, 'rb')), leaking the handle)
        with open(run, 'rb') as fp:
            run_results = pickle.load(fp)
        for basin, basin_results in run_results.items():
            for freq in frequencies:
                freq_results = basin_results[freq]['xr']
                # sum up the predictions of all basins
                if freq not in ensemble_sum[basin]:
                    # first run initializes the accumulator with its dataset
                    ensemble_sum[basin][freq] = freq_results
                else:
                    # subsequent runs add their simulated values in place
                    for target_var in target_vars:
                        ensemble_sum[basin][freq][f'{target_var}_sim'] += freq_results[f'{target_var}_sim']

    # divide the prediction sum by number of runs to get the mean prediction for each basin and frequency
    print('Combining results and calculating metrics.')
    ensemble = defaultdict(lambda: defaultdict(dict))
    for basin in tqdm(ensemble_sum.keys()):
        for freq in frequencies:
            ensemble[basin][freq]['xr'] = ensemble_sum[basin][freq]

            # combine date and time to a single index to calculate metrics
            # number of predicted time steps of this freq per step of the lowest freq
            frequency_factor = pd.to_timedelta(lowest_freq) // pd.to_timedelta(freq)
            ensemble[basin][freq]['xr'] = ensemble[basin][freq]['xr'].isel(
                time_step=slice(-frequency_factor, None)).stack(datetime=['date', 'time_step'])
            # stacked coords are (date, time_step offset) tuples; add them to get absolute datetimes
            ensemble[basin][freq]['xr']['datetime'] = [
                c[0] + c[1] for c in ensemble[basin][freq]['xr'].coords['datetime'].values
            ]
            for target_var in target_vars:
                # average predictions
                ensemble[basin][freq]['xr'][f'{target_var}_sim'] = \
                    ensemble[basin][freq]['xr'][f'{target_var}_sim'] / len(results_files)

                # clip predictions to zero (only for metric calculation; the stored xr keeps raw means)
                sim = ensemble[basin][freq]['xr'][f'{target_var}_sim']
                if target_var in config.clip_targets_to_zero:
                    sim = xr.where(sim < 0, 0, sim)

                # calculate metrics; config.metrics may be a list (shared) or a per-target dict
                ensemble_metrics = calculate_metrics(
                    ensemble[basin][freq]['xr'][f'{target_var}_obs'],
                    sim,
                    metrics=config.metrics if isinstance(config.metrics, list) else config.metrics[target_var],
                    resolution=freq)

                # add variable identifier to metrics if needed
                if len(target_vars) > 1:
                    ensemble_metrics = {f'{target_var}_{key}': val for key, val in ensemble_metrics.items()}
                for metric, val in ensemble_metrics.items():
                    ensemble[basin][freq][f'{metric}_{freq}'] = val

    return dict(ensemble)
def evaluate(self,
             epoch: int = None,
             save_results: bool = True,
             metrics: Union[list, dict] = [],
             model: torch.nn.Module = None,
             experiment_logger: Logger = None) -> dict:
    """Evaluate the model.

    Parameters
    ----------
    epoch : int, optional
        Define a specific epoch to evaluate. By default, the weights of the last epoch are used.
    save_results : bool, optional
        If True, stores the evaluation results in the run directory. By default, True.
    metrics : Union[list, dict], optional
        List of metrics to compute during evaluation. Can also be a dict that specifies per-target metrics
    model : torch.nn.Module, optional
        If a model is passed, this is used for validation.
    experiment_logger : Logger, optional
        Logger can be passed during training to log metrics

    Returns
    -------
    dict
        A dictionary containing one xarray per basin with the evaluation results.
    """
    if model is None:
        if self.init_model:
            # load weights of the requested (or last) epoch into the pre-initialized model
            self._load_weights(epoch=epoch)
            model = self.model
        else:
            raise RuntimeError(
                "No model was initialized for the evaluation")

    # during validation, depending on settings, only evaluate on a random subset of basins
    basins = self.basins
    if self.period == "validation":
        if len(basins) > self.cfg.validate_n_random_basins:
            # NOTE(review): shuffles self.basins in place before slicing — presumably acceptable here
            random.shuffle(basins)
            basins = basins[:self.cfg.validate_n_random_basins]

    # force model to train-mode when doing mc-dropout evaluation
    if self.cfg.mc_dropout:
        model.train()
    else:
        model.eval()

    results = defaultdict(dict)

    pbar = tqdm(basins, file=sys.stdout)
    pbar.set_description('# Validation' if self.period == "validation" else "# Evaluation")
    for basin in pbar:

        # reuse a cached dataset if available, otherwise build one for this basin
        if self.cfg.cache_validation_data and basin in self.cached_datasets.keys():
            ds = self.cached_datasets[basin]
        else:
            try:
                ds = self._get_dataset(basin)
            except NoTrainDataError as error:
                # skip basin
                continue
            if self.cfg.cache_validation_data and self.period == "validation":
                self.cached_datasets[basin] = ds

        loader = DataLoader(ds, batch_size=self.cfg.batch_size, num_workers=0)

        # forward pass over the whole period; returns per-frequency predictions and observations
        y_hat, y = self._evaluate(model, loader, ds.frequencies)

        predict_last_n = self.cfg.predict_last_n
        seq_length = self.cfg.seq_length
        # if predict_last_n/seq_length are int, there's only one frequency
        if isinstance(predict_last_n, int):
            predict_last_n = {ds.frequencies[0]: predict_last_n}
        if isinstance(seq_length, int):
            seq_length = {ds.frequencies[0]: seq_length}
        lowest_freq = sort_frequencies(ds.frequencies)[0]
        for freq in ds.frequencies:
            if predict_last_n[freq] == 0:
                continue  # this frequency is not being predicted
            results[basin][freq] = {}

            # rescale predictions: undo the feature normalization applied during training
            y_hat_freq = \
                y_hat[freq] * self.scaler["xarray_feature_scale"][self.cfg.target_variables].to_array().values \
                + self.scaler["xarray_feature_center"][self.cfg.target_variables].to_array().values
            y_freq = y[freq] * self.scaler["xarray_feature_scale"][self.cfg.target_variables].to_array().values \
                + self.scaler["xarray_feature_center"][self.cfg.target_variables].to_array().values

            # create xarray
            data = self._create_xarray(y_hat_freq, y_freq)

            # get maximum warmup-offset across all frequencies
            offsets = {
                freq: (seq_length[freq] - predict_last_n[freq]) * pd.to_timedelta(freq)
                for freq in ds.frequencies
            }
            max_offset_freq = max(offsets, key=offsets.get)
            start_date = ds.get_period_start(basin) + offsets[max_offset_freq]

            # determine the end of the first sequence (first target in sequence-to-one)
            # we use the end_date stored in the dataset, which also covers issues with per-basin different periods
            end_date = ds.dates[basin]["end_dates"][0] \
                + pd.Timedelta(days=1, seconds=-1) \
                - pd.to_timedelta(max_offset_freq) * (predict_last_n[max_offset_freq] - 1)
            date_range = pd.date_range(start=start_date, end=end_date, freq=lowest_freq)
            # sanity check: the generated date range must match the number of predicted samples
            if len(date_range) != data[f"{self.cfg.target_variables[0]}_obs"][1].shape[0]:
                raise ValueError(
                    "Evaluation date range does not match generated predictions.")

            # number of time steps of this freq within one step of the lowest freq
            frequency_factor = pd.to_timedelta(lowest_freq) // pd.to_timedelta(freq)
            # intra-step offsets for the last predict_last_n[freq] predictions of each sequence
            freq_range = pd.timedelta_range(end=(frequency_factor - 1) * pd.to_timedelta(freq),
                                            periods=predict_last_n[freq],
                                            freq=freq)

            xr = xarray.Dataset(data_vars=data,
                                coords={'date': date_range, 'time_step': freq_range})
            results[basin][freq]['xr'] = xr

            # only warn once per freq
            if frequency_factor < predict_last_n[freq] and basin == basins[0]:
                tqdm.write(
                    f'Metrics for {freq} are calculated over last {frequency_factor} elements only. '
                    f'Ignoring {predict_last_n[freq] - frequency_factor} predictions per sequence.')

            if metrics:
                for target_variable in self.cfg.target_variables:
                    # stack dates and time_steps so we don't just evaluate every 24H when use_frequencies=[1D, 1H]
                    obs = xr.isel(time_step=slice(-frequency_factor, None)) \
                        .stack(datetime=['date', 'time_step'])[f"{target_variable}_obs"]
                    # combine date + intra-day offset into an absolute datetime coordinate
                    obs['datetime'] = obs.coords['date'] + obs.coords['time_step']
                    # check if there are observations for this period
                    if not all(obs.isnull()):
                        sim = xr.isel(time_step=slice(-frequency_factor, None)) \
                            .stack(datetime=['date', 'time_step'])[f"{target_variable}_sim"]
                        sim['datetime'] = sim.coords['date'] + sim.coords['time_step']

                        # clip negative predictions to zero, if variable is listed in config 'clip_target_to_zero'
                        if target_variable in self.cfg.clip_targets_to_zero:
                            sim = xarray.where(sim < 0, 0, sim)

                        # mc-dropout/uncertainty runs carry a 'samples' dim; average it out for metrics
                        if 'samples' in sim.dims:
                            sim = sim.mean(dim='samples')

                        var_metrics = metrics if isinstance(metrics, list) else metrics[target_variable]
                        if 'all' in var_metrics:
                            var_metrics = get_available_metrics()
                        try:
                            values = calculate_metrics(obs, sim, metrics=var_metrics, resolution=freq)
                        except AllNaNError as err:
                            # metrics are undefined for all-NaN periods: warn and report NaN instead of aborting
                            msg = f'Basin {basin} ' \
                                + (f'{target_variable} ' if len(self.cfg.target_variables) > 1 else '') \
                                + (f'{freq} ' if len(ds.frequencies) > 1 else '') \
                                + str(err)
                            LOGGER.warning(msg)
                            values = {metric: np.nan for metric in var_metrics}

                        # add variable identifier to metrics if needed
                        if len(self.cfg.target_variables) > 1:
                            values = {f"{target_variable}_{key}": val for key, val in values.items()}
                        # add frequency identifier to metrics if needed
                        if len(ds.frequencies) > 1:
                            values = {f"{key}_{freq}": val for key, val in values.items()}
                        if experiment_logger is not None:
                            experiment_logger.log_step(**values)
                        for k, v in values.items():
                            results[basin][freq][k] = v

    if (self.period == "validation") and (self.cfg.log_n_figures > 0) and (experiment_logger is not None):
        self._create_and_log_figures(results, experiment_logger, epoch)

    if save_results:
        self._save_results(results, epoch)

    return results
def _create_ensemble(results_files: List[Path], frequencies: List[str], config: Config) -> dict:
    """Average the predictions of the passed runs and re-calculate metrics.

    Parameters
    ----------
    results_files : List[Path]
        Paths to the pickled per-run result dicts (basin -> freq -> {'xr': xarray.Dataset, ...}).
    frequencies : List[str]
        Frequencies (pandas offset strings) for which results were generated.
    config : Config
        Run configuration; provides target variables, metrics, and clipping settings.

    Returns
    -------
    dict
        basin -> freq -> {'xr': averaged xarray.Dataset, '<metric>_<freq>': value, ...}
    """
    lowest_freq = sort_frequencies(frequencies)[0]
    ensemble_sum = defaultdict(dict)
    target_vars = config.target_variables

    print('Loading results for each run.')
    for run in tqdm(results_files):
        # use a context manager so the file handle is closed even if unpickling fails
        # (original used pickle.load(open(run, 'rb')), leaking the handle)
        with open(run, 'rb') as fp:
            run_results = pickle.load(fp)
        for basin, basin_results in run_results.items():
            for freq in frequencies:
                freq_results = basin_results[freq]['xr']
                # sum up the predictions of all basins
                if freq not in ensemble_sum[basin]:
                    # first run initializes the accumulator with its dataset
                    ensemble_sum[basin][freq] = freq_results
                else:
                    # subsequent runs add their simulated values in place
                    for target_var in target_vars:
                        ensemble_sum[basin][freq][f'{target_var}_sim'] += freq_results[f'{target_var}_sim']

    # divide the prediction sum by number of runs to get the mean prediction for each basin and frequency
    print('Combining results and calculating metrics.')
    ensemble = defaultdict(lambda: defaultdict(dict))
    for basin in tqdm(ensemble_sum.keys()):
        for freq in frequencies:
            ensemble_xr = ensemble_sum[basin][freq]

            # combine date and time to a single index to calculate metrics
            # create datetime range at the current frequency, removing time steps that are not being predicted
            frequency_factor = int(get_frequency_factor(lowest_freq, freq))
            # make sure the last day is fully contained in the range
            freq_date_range = pd.date_range(start=ensemble_xr.coords['date'].values[0],
                                            end=ensemble_xr.coords['date'].values[-1] \
                                                + pd.Timedelta(days=1, seconds=-1),
                                            freq=freq)
            # keep only the last len(time_step) entries of each lowest-freq step
            # (steps earlier than predict_last_n are not predicted)
            mask = np.ones(frequency_factor).astype(bool)
            mask[:-len(ensemble_xr.coords['time_step'])] = False
            freq_date_range = freq_date_range[np.tile(mask, len(ensemble_xr.coords['date']))]

            ensemble_xr = ensemble_xr.isel(time_step=slice(-frequency_factor, None)).stack(
                datetime=['date', 'time_step'])
            ensemble_xr['datetime'] = freq_date_range
            for target_var in target_vars:
                # average predictions
                ensemble_xr[f'{target_var}_sim'] = ensemble_xr[f'{target_var}_sim'] / len(results_files)

                # clip predictions to zero (only for metric calculation; the stored xr keeps raw means)
                sim = ensemble_xr[f'{target_var}_sim']
                if target_var in config.clip_targets_to_zero:
                    sim = xr.where(sim < 0, 0, sim)

                # calculate metrics; config.metrics may be a list (shared) or a per-target dict
                metrics = config.metrics if isinstance(config.metrics, list) else config.metrics[target_var]
                if 'all' in metrics:
                    metrics = get_available_metrics()
                try:
                    ensemble_metrics = calculate_metrics(ensemble_xr[f'{target_var}_obs'],
                                                         sim,
                                                         metrics=metrics,
                                                         resolution=freq)
                except AllNaNError as err:
                    # metrics are undefined for all-NaN periods: report NaN instead of aborting the ensemble
                    msg = f'Basin {basin} ' \
                        + (f'{target_var} ' if len(target_vars) > 1 else '') \
                        + (f'{freq} ' if len(frequencies) > 1 else '') \
                        + str(err)
                    print(msg)
                    ensemble_metrics = {metric: np.nan for metric in metrics}

                # add variable identifier to metrics if needed
                if len(target_vars) > 1:
                    ensemble_metrics = {
                        f'{target_var}_{key}': val for key, val in ensemble_metrics.items()
                    }
                for metric, val in ensemble_metrics.items():
                    ensemble[basin][freq][f'{metric}_{freq}'] = val
            ensemble[basin][freq]['xr'] = ensemble_xr

    return dict(ensemble)