Example #1
    def predict(self,
                X: Union[torch.Tensor, SliceDict],
                type: str = 'mean',
                *args,
                **kwargs) -> np.ndarray:
        """
        Return an attribute of the distribution (by default the mean) as a numpy array.
        """
        X = to_tensor(X, device=self.device, dtype=self.module_dtype_)
        y_out = []
        for params in self.forward_iter(X, training=False):
            batch_size = len(params[0])
            distribution_kwargs = dict(
                zip(self.distribution_param_names_, params))
            dist = self.distribution(**distribution_kwargs)
            # properties (e.g. `mean`) are already tensors; methods (e.g. `cdf`)
            # must be called with any extra args:
            yp = getattr(dist, type)
            if callable(yp):
                yp = yp(*args, **kwargs)
            yp = to_numpy(yp)
            if yp.shape[0] != batch_size:
                raise RuntimeError(
                    f"`{self.distribution.__name__}.{type}` produced a tensor whose leading dim is {yp.shape[0]}, "
                    f"expected {batch_size}.")
            y_out.append(yp)
        return np.concatenate(y_out, 0)
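
The `getattr` dispatch works because `torch.distributions` objects expose moments as properties (already tensors) and probability functions as methods (which need arguments). A minimal standalone sketch of that pattern, using `torch.distributions.Normal` as a stand-in for whatever `self.distribution` happens to be:

    import torch

    dist = torch.distributions.Normal(loc=torch.zeros(3), scale=torch.ones(3))

    mean = getattr(dist, 'mean')  # a property: already a tensor
    assert not callable(mean)

    cdf = getattr(dist, 'cdf')    # a method: must be called with extra args
    assert callable(cdf)
    probs = cdf(torch.tensor([0.0, 1.0, 2.0]))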
Example #2
    def estimate_laplace_params(self, X, y, **fit_params):
        means = torch.cat(
            [param.data.view(-1) for param in self.module_.parameters()])

        # get loss, hessian:
        y_pred = self.infer(X, **fit_params)
        y_true = to_tensor(y, device=self.device, dtype=self.module_dtype_)
        loss = self.get_loss(y_pred, y_true, reduction='sum')
        hess = hessian(output=loss,
                       inputs=list(self.module_.parameters()),
                       allow_unused=True,
                       progress=False)

        # create mvnorm for laplace approx:
        try:
            self.laplace_params_ = torch.distributions.MultivariateNormal(
                means, covariance_matrix=torch.inverse(hess))
            self.converged_ = True
        except RuntimeError as e:
            if 'lapack' in str(e) or 'cholesky' in str(e):
                warn(
                    "Model failed to converge; `laplace_params_` cannot be estimated"
                )
                # fall back to a deliberately inflated diagonal covariance:
                fake_cov = (2 * means.abs().max() * torch.eye(len(hess)))**2
                self.laplace_params_ = torch.distributions.MultivariateNormal(
                    means, covariance_matrix=fake_cov)
                self.converged_ = False
            else:
                raise
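
The Laplace approximation built here treats the fitted parameters as the mode of an approximately Gaussian posterior whose covariance is the inverse Hessian of the summed loss at that mode. A self-contained toy version using `torch.autograd.functional.hessian` (assuming the `hessian` helper above plays a similar role):

    import torch
    from torch.autograd.functional import hessian

    theta_map = torch.tensor([1.0, -0.5])  # pretend these are the fitted params

    def loss_fn(theta):
        # toy sum-reduced loss with its minimum at theta_map; Hessian is 2*I
        return ((theta - torch.tensor([1.0, -0.5])) ** 2).sum()

    H = hessian(loss_fn, theta_map)
    laplace = torch.distributions.MultivariateNormal(
        theta_map, covariance_matrix=torch.inverse(H))
    samples = laplace.sample((100,))  # parameter draws for predictive uncertainty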
Example #3
    def predict_dataframe(self, dataframe: 'DataFrame',
                          preprocessor: Union[ColumnTransformer, Sequence],
                          type: str, x: torch.Tensor):
        """
        Experimental.
        """
        from pandas import DataFrame

        # check x, broadcast:
        x = to_tensor(x, device=self.device, dtype=self.module_dtype_)
        if len(x.shape) == 2:
            assert x.shape[1] == 1
        elif len(x.shape) == 1:
            x = x[:, None]
        else:
            raise RuntimeError("Expected `x` to be 1D")

        # index must be unique so predictions can be joined back to rows:
        assert len(dataframe.index) == len(set(dataframe.index))

        # generate dist-param predictions:
        with torch.no_grad():
            X = to_tensor(preprocessor.transform(dataframe),
                          device=self.device,
                          dtype=self.module_dtype_)
            params = self.infer(X)
            distribution_kwargs = dict(
                zip(self.distribution_param_names_, params))
            dist = self.distribution(**distribution_kwargs)

            # plug x into method:
            pred_broadcasted = getattr(dist, type)(x)

        # flatten, joinable along original index:
        index_broadcasted = to_tensor(dataframe.index.values,
                                      device=self.device)[None, :].repeat(
                                          len(x), 1)
        x_broadcasted = x.repeat(1, len(dataframe.index))
        return DataFrame({
            'index': index_broadcasted.view(-1).cpu().numpy(),
            'x': x_broadcasted.view(-1).cpu().numpy(),
            type: pred_broadcasted.view(-1).cpu().numpy()
        })
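
The long format comes from broadcasting a column vector of query points against a row of per-observation distribution parameters. The shape logic in isolation (with `Normal` standing in for `self.distribution`):

    import torch
    from pandas import DataFrame

    locs = torch.tensor([0.0, 1.0, 2.0])      # one param per dataframe row
    x = torch.tensor([[-1.0], [0.0], [1.0]])  # query points, shape (n_x, 1)

    dist = torch.distributions.Normal(loc=locs, scale=1.0)
    pred = dist.cdf(x)                        # broadcasts to (n_x, n_rows)

    out = DataFrame({
        'index': torch.arange(len(locs))[None, :].repeat(len(x), 1).view(-1).numpy(),
        'x': x.repeat(1, len(locs)).view(-1).numpy(),
        'cdf': pred.view(-1).numpy(),
    })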
Example #4
    def get_loss(self,
                 y_pred: torch.Tensor,
                 y_true: torch.Tensor,
                 X: Optional[torch.Tensor] = None,
                 training: bool = False,
                 **kwargs):
        y_true = to_tensor(y_true,
                           device=self.device,
                           dtype=self.module_dtype_)
        # penalized negative log-likelihood:
        neg_log_lik = self.criterion_(y_pred, y_true, **kwargs)
        penalty = self.criterion_.get_penalty(y_true=y_true,
                                              module=self.module_)
        return neg_log_lik + penalty
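
This assumes a criterion whose forward pass returns the negative log-likelihood and which also exposes a `get_penalty` method. A minimal sketch of what such a criterion could look like; the Gaussian likelihood and L2 penalty here are illustrative assumptions, not the library's actual implementation:

    import torch

    class PenalizedNLL(torch.nn.Module):
        def forward(self, y_pred, y_true, reduction='mean'):
            # negative log-likelihood under a unit-variance Gaussian (up to a constant)
            nll = 0.5 * (y_pred - y_true) ** 2
            return nll.sum() if reduction == 'sum' else nll.mean()

        def get_penalty(self, y_true, module):
            # illustrative L2 penalty on all module parameters
            return 0.01 * sum((p ** 2).sum() for p in module.parameters())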
Example #5
    def partial_fit(self,
                    X,
                    y=None,
                    classes=None,
                    input_feature_names: Optional[Sequence[str]] = None,
                    **fit_params):
        # infer number of input features if appropriate:
        if self.module_input_feature_names_ is None:
            self.module_input_feature_names_ = self._infer_input_feature_names(
                X, input_feature_names)

        # convert y to the right dtype:
        y = to_tensor(y, device=self.device, dtype=self.module_dtype_)

        return super().partial_fit(X=X, y=y, classes=classes, **fit_params)
Example #6
    def infer(self, x: Union[torch.Tensor, SliceDict], **fit_params):
        # move inputs to the right device/dtype before delegating:
        x = to_tensor(x, device=self.device, dtype=self.module_dtype_)
        return super().infer(x=x, **fit_params)
Example #7
    def km_summary(self,
                   dataframe: 'DataFrame',
                   preprocessor: Union[ColumnTransformer, Sequence],
                   time_colname: str,
                   censor_colname: str,
                   start_time_colname: Optional[str] = None) -> 'DataFrame':
        """
        :param dataframe: A pandas DataFrame, or a DataFrameGroupBy (i.e., the result of calling `df.groupby([...])`).
        :param preprocessor: Either a sklearn ColumnTransformer that takes the dataframe and returns X, or a list of
        column-names (such that `X = dataframe.loc[:,preprocessor].values`)
        :param time_colname: The column-name in the dataframe for time-to-event.
        :param censor_colname: The column-name in the dataframe for the censoring indicator.
        :param start_time_colname: Optional, the column-name in the dataframe for start-times (for left-truncation).
        :return: A DataFrame with Kaplan-Meier estimates.
        """
        try:
            from pandas.core.groupby.generic import DataFrameGroupBy
        except ImportError as e:
            raise ImportError("Must install pandas for `km_summary`") from e

        if isinstance(dataframe, DataFrameGroupBy):
            df_applied = dataframe.apply(self.km_summary,
                                         preprocessor=preprocessor,
                                         time_colname=time_colname,
                                         censor_colname=censor_colname,
                                         start_time_colname=start_time_colname)
            index_idx = list(range(len(df_applied.index.names)))
            return df_applied.reset_index(level=index_idx[:-1], drop=False)
        else:

            # preprocess X:
            if hasattr(preprocessor, 'transform'):
                X = preprocessor.transform(dataframe)
            else:
                X = dataframe.loc[:, preprocessor].values
            X = to_tensor(X, device=self.device, dtype=self.module_dtype_)

            # km estimate:
            df_km = km_summary(time=dataframe[time_colname].values,
                               is_upper_cens=dataframe[censor_colname],
                               lower_trunc=dataframe[start_time_colname]
                               if start_time_colname else None)

            # generate predicted params, transpose as inputs to distribution:
            with torch.no_grad():
                y_preds = self.infer(X)
                kwargs = {
                    name: param[None, :]
                    for name, param in zip(self.distribution_param_names_,
                                           y_preds)
                }
                distribution = self.distribution(**kwargs)

            # get unique times in distribution-friendly format:
            y = df_km.loc[:, ['time']].values
            if self.scale_y:
                y = self.y_scaler_.transform(y)
            y = to_tensor(y, device=self.device, dtype=self.module_dtype_)

            # b/c dist-kwargs transposed, broadcasting logic means we get array with dims: (times, dataframe_rows)
            observed = y[:, [0]]
            surv = 1. - distribution.cdf(observed)
            if start_time_colname:
                # TODO: either figure out if taking the average is valid, or emit a warning
                min_ltrunc = np.full_like(
                    observed, fill_value=dataframe[start_time_colname].min())
                if self.scale_y:
                    min_ltrunc = self.y_scaler_.transform(min_ltrunc)
                min_ltrunc = to_tensor(min_ltrunc,
                                       device=self.device,
                                       dtype=self.module_dtype_)
                surv /= (1. - distribution.cdf(min_ltrunc))
            # this is then reduced, collapsing across dataframe rows, so that we get a mean estimate for this dataset
            df_km['model_estimate'] = torch.mean(surv, dim=1).cpu().numpy()
            return df_km
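
The broadcasting in the survival step can also be seen in isolation: the per-row parameters are transposed into a trailing row dimension, unique times enter as a column vector, and `1 - cdf` comes out with dims (times, rows) before being averaged across rows. A standalone sketch (again with `Normal` as a stand-in):

    import torch

    locs = torch.tensor([1.0, 2.0, 3.0])                # one param per row
    times = torch.tensor([[0.5], [1.5], [2.5], [3.5]])  # unique times, (n_times, 1)

    dist = torch.distributions.Normal(loc=locs[None, :], scale=1.0)
    surv = 1. - dist.cdf(times)        # shape (n_times, n_rows)
    model_estimate = surv.mean(dim=1)  # mean survival curve across rows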