def predict_dependency(
    self,
    data: Union[DataLoader, pd.DataFrame, TimeSeriesDataSet],
    variable: str,
    values: Iterable,
    mode: str = "dataframe",
    target="decoder",
    show_progress_bar: bool = False,
    **kwargs,
) -> Union[np.ndarray, torch.Tensor, pd.Series, pd.DataFrame]:
    """
    Predict partial dependency.

    Repeatedly overwrites ``variable`` in the dataset with each entry of
    ``values``, predicts, and aggregates the predictions.

    Args:
        data (Union[DataLoader, pd.DataFrame, TimeSeriesDataSet]): data
        variable (str): variable which to modify
        values (Iterable): array of values to probe
        mode (str, optional): Output mode. Defaults to "dataframe". Either

            * "series": values are average prediction and index are probed values
            * "dataframe": columns are as obtained by the ``dataset.x_to_index()`` method,
              prediction (which is the mean prediction over the time horizon),
              normalized_prediction (which are predictions divided by the prediction for
              the first probed value) and the variable name for the probed values
            * "raw": outputs a tensor of shape len(values) x prediction_shape

        target: Defines which values are overwritten for making a prediction.
            Same as in
            :py:meth:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet.set_overwrite_values`.
            Defaults to "decoder".
        show_progress_bar: if to show progress bar. Defaults to False.
        **kwargs: additional kwargs to :py:meth:`~predict` method

    Returns:
        Union[np.ndarray, torch.Tensor, pd.Series, pd.DataFrame]: output
    """
    probe = np.asarray(values)

    # normalize input to a TimeSeriesDataSet
    if isinstance(data, pd.DataFrame):
        data = TimeSeriesDataSet.from_parameters(self.dataset_parameters, data, predict=True)
    elif isinstance(data, DataLoader):
        data = data.dataset

    # default prediction mode for the underlying predict() calls
    kwargs.setdefault("mode", "prediction")

    progress_bar = tqdm(
        desc="Predict", unit=" batches", total=len(probe), disable=not show_progress_bar
    )

    collected = []
    index = None
    for value in probe:
        data.set_overwrite_values(variable=variable, values=value, target=target)
        if index is None and mode == "dataframe":
            # first pass: also fetch the index needed to assemble the dataframe
            prediction, index = self.predict(data, return_index=True, **kwargs)
        else:
            prediction = self.predict(data, **kwargs)
        collected.append(prediction)
        progress_bar.update()
    # reset overwrite values to avoid side-effect
    data.reset_overwrite_values()

    stacked = torch.stack(collected, dim=0)

    # convert results to requested output format
    if mode == "series":
        # average samples and prediction horizon (NaN positions are padding)
        valid = ~torch.isnan(stacked[0])
        return pd.Series(stacked[:, valid].mean(1), index=probe)

    if mode == "dataframe":
        # take mean over time, ignoring NaN padding
        nan_positions = torch.isnan(stacked)
        stacked[nan_positions] = 0
        mean_prediction = stacked.sum(-1) / (~nan_positions).float().sum(-1)

        # build dataframe: one copy of the index per probed value
        dependencies = (
            index.iloc[np.tile(np.arange(len(index)), len(probe))]
            .reset_index(drop=True)
            .assign(prediction=mean_prediction.flatten())
        )
        dependencies[variable] = probe.repeat(len(data))
        first_prediction = dependencies.groupby(
            data.group_ids, observed=True
        ).prediction.transform("first")
        dependencies["normalized_prediction"] = dependencies["prediction"] / first_prediction
        dependencies["id"] = dependencies.groupby(data.group_ids, observed=True).ngroup()
        return dependencies

    if mode == "raw":
        return stacked

    raise ValueError(f"mode {mode} is unknown - see documentation for available modes")
def predict(
    self,
    data: Union[DataLoader, pd.DataFrame, TimeSeriesDataSet],
    mode: Union[str, Tuple[str, str]] = "prediction",
    return_index: bool = False,
    return_decoder_lengths: bool = False,
    batch_size: int = 64,
    num_workers: int = 0,
    fast_dev_run: bool = False,
    show_progress_bar: bool = False,
    return_x: bool = False,
):
    """
    Run prediction over a dataloader, dataframe or dataset.

    Args:
        data: dataloader, dataframe or dataset
        mode: one of "prediction", "quantiles" or "raw", or tuple ``("raw", output_name)``
            where output_name is a name in the dictionary returned by ``forward()``
        return_index: if to return the prediction index
        return_decoder_lengths: if to return decoder_lengths
        batch_size: batch size for dataloader - only used if data is not a dataloader
        num_workers: number of workers for dataloader - only used if data is not a dataloader
        fast_dev_run: if to only return results of first batch
        show_progress_bar: if to show progress bar. Defaults to False.
        return_x: if to return network inputs

    Returns:
        output, x, index, decoder_lengths: some elements might not be present depending on
        what is configured to be returned
    """
    # convert to dataloader
    if isinstance(data, pd.DataFrame):
        data = TimeSeriesDataSet.from_parameters(self.dataset_parameters, data, predict=True)
    if isinstance(data, TimeSeriesDataSet):
        dataloader = data.to_dataloader(batch_size=batch_size, train=False, num_workers=num_workers)
    else:
        dataloader = data

    # ensure passed dataloader is correct
    assert isinstance(
        dataloader.dataset, TimeSeriesDataSet
    ), "dataset behind dataloader must be TimeSeriesDataSet"

    # prepare model: no dropout, etc.; gradients disabled below
    self.eval()

    # run predictions
    output = []
    decoder_lengths_collected = []
    x_list = []
    index = []
    progress_bar = tqdm(
        desc="Predict", unit=" batches", total=len(dataloader), disable=not show_progress_bar
    )
    with torch.no_grad():
        for x, _ in dataloader:
            # move data to appropriate device
            # BUG FIX: Tensor.to() is not in-place - the result must be reassigned,
            # otherwise batches silently stay on their original device
            for name in x.keys():
                if x[name].device != self.device:
                    x[name] = x[name].to(self.device)

            # make prediction; raw output is a dictionary
            out = self(x)
            out["prediction"] = self.transform_output(out)

            lengths = x["decoder_lengths"]
            if return_decoder_lengths:
                decoder_lengths_collected.append(lengths)
            # mask marking positions beyond each sample's decoder length
            nan_mask = self._get_mask(out["prediction"].size(1), lengths)

            if isinstance(mode, (tuple, list)):
                if mode[0] == "raw":
                    out = out[mode[1]]
                else:
                    raise ValueError(
                        f"If a tuple is specified, the first element must be 'raw' - got {mode[0]} instead"
                    )
            elif mode == "prediction":
                out = self.loss.to_prediction(out["prediction"])
                # mask non-predictions
                out = out.masked_fill(nan_mask, torch.tensor(float("nan")))
            elif mode == "quantiles":
                out = self.loss.to_quantiles(out["prediction"])
                # mask non-predictions (extra quantile dimension)
                out = out.masked_fill(nan_mask.unsqueeze(-1), torch.tensor(float("nan")))
            elif mode == "raw":
                pass
            else:
                raise ValueError(f"Unknown mode {mode} - see docs for valid arguments")

            output.append(out)
            if return_x:
                x_list.append(x)
            if return_index:
                index.append(dataloader.dataset.x_to_index(x))
            progress_bar.update()
            if fast_dev_run:
                break
    progress_bar.close()

    # concatenate batches
    if isinstance(mode, (tuple, list)) or mode != "raw":
        # single tensor per batch
        output = torch.cat(output, dim=0)
    elif mode == "raw":
        # dictionary of tensors per batch - concatenate each entry
        output = {
            name: torch.cat([out[name] for out in output], dim=0) for name in output[0].keys()
        }

    # generate output: append optional elements in documented order
    if return_x or return_index or return_decoder_lengths:
        output = [output]
    if return_x:
        x_cat = {name: torch.cat([x[name] for x in x_list], dim=0) for name in x_list[0].keys()}
        output.append(x_cat)
    if return_index:
        output.append(pd.concat(index, axis=0, ignore_index=True))
    if return_decoder_lengths:
        output.append(torch.cat(decoder_lengths_collected, dim=0))
    return output