def predict(
        self,
        X,
        time_bins=None,
        return_ci=False,
        ci_width=0.683,
        return_interval_probs=False,
    ):
        """
        Make queries to nearest neighbor search index build on the transformed XGBoost space.
        Compute a Kaplan-Meier estimator for each neighbor-set. Predict the KM estimators.

        Args:
            X (pd.DataFrame): Dataframe with samples for which to generate predictions

            time_bins (np.array): Specified time windows to use when making survival predictions

            return_ci (Bool): Whether to return confidence intervals via the Exponential Greenwood formula

            ci_width (Float): Width of confidence interval

            return_interval_probs (Bool): Whether to return interval probabilities
                instead of cumulative survival probabilities.

        Returns:
            (pd.DataFrame): A dataframe of survival probabilities
            for all times (columns), from a time_bins array, for all samples of X
            (rows). If return_interval_probs is True, the interval probabilities are returned
            instead of the cumulative survival probabilities.

            upper_ci (np.array): Upper confidence interval for the survival
            probability values (returned only if return_ci is True)

            lower_ci (np.array): Lower confidence interval for the survival
            probability values (returned only if return_ci is True)
        """

        # converting to xgb format
        d_matrix = xgb.DMatrix(X)

        # getting leaves and extracting neighbors
        leaves = self.bst.predict(d_matrix, pred_leaf=True)

        if self.radius:
            assert self.radius > 0, "Radius must be positive"

            neighs, _ = self.tree.query_radius(
                leaves, r=self.radius, return_distance=True
            )

            number_of_neighbors = np.array([len(neigh) for neigh in neighs])

            if (number_of_neighbors == 1).any():
                # if any sample has no neighbor apart from itself,
                # warn and suggest increasing the radius
                warnings.warn(
                    "Some samples have no neighbors apart from themselves. Increase the radius.",
                    RuntimeWarning,
                )
        else:
            _, neighs = self.tree.query(leaves, k=self.n_neighbors)

        # gathering times and events/censors for neighbor sets
        T_neighs = self.T_train[neighs]
        E_neighs = self.E_train[neighs]

        if time_bins is None:
            time_bins = self.time_bins

        # calculating z-score from width
        z = st.norm.ppf(0.5 + ci_width / 2)
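        # e.g., the default ci_width=0.683 gives z = st.norm.ppf(0.8415) ≈ 1.0,
        # i.e. a one-standard-deviation band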

        # vectorized (very fast!) implementation of Kaplan-Meier curves
        preds_df, upper_ci, lower_ci = calculate_kaplan_vectorized(
            T_neighs, E_neighs, time_bins, z
        )

        if return_ci and return_interval_probs:
            raise ValueError(
                "Confidence intervals for interval probabilities is not supported. Choose between return_ci and return_interval_probs."
            )

        if return_interval_probs:
            preds_df = calculate_interval_failures(preds_df)
            return preds_df

        if return_ci:
            return preds_df, upper_ci, lower_ci

        return preds_df

def fit(
        self,
        X,
        y,
        persist_train=True,
        index_id=None,
        time_bins=None,
        ci_width=0.683,
        **xgb_kwargs,
    ):
        """
        Fit a single decision tree using xgboost. For each leaf in the tree,
        build a Kaplan-Meier estimator.

        !!! Note
            * Unlike `XGBSEKaplanNeighbors`, `XGBSEKaplanTree` requires the width of
            the confidence interval (`ci_width`) to be specified at fit time.

        Args:

            X ([pd.DataFrame, np.array]): Design matrix to fit XGBoost model

            y (structured array(numpy.bool_, numpy.number)): Binary event indicator as first field,
                and time of event or time of censoring as second field.

            persist_train (Bool): Whether or not to persist training data to use explainability
                through prototypes

            index_id (pd.Index): User-defined index, used if explainability through
                prototypes is intended

            time_bins (np.array): Specified time windows to use when making survival predictions

            ci_width (Float): Width of confidence interval

        Returns:
            XGBSEKaplanTree: Trained instance of XGBSEKaplanTree
        """

        E_train, T_train = convert_y(y)
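        # when no explicit grid is given, derive default time windows from the
        # training time-to-event distribution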
        if time_bins is None:
            time_bins = get_time_bins(T_train, E_train)
        self.time_bins = time_bins

        # converting data to xgb format
        dtrain = convert_data_to_xgb_format(X, y, self.xgb_params["objective"])

        # training XGB
        self.bst = xgb.train(self.xgb_params, dtrain, num_boost_round=1, **xgb_kwargs)
        self.feature_importances_ = self.bst.get_score()

        # getting leaves
        leaves = self.bst.predict(dtrain, pred_leaf=True)

        # organizing elements per leaf
        leaf_neighs = (
            pd.DataFrame({"leaf": leaves})
            .groupby("leaf")
            .apply(lambda x: list(x.index))
        )
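
        # leaf_neighs is a pd.Series indexed by leaf id, mapping each leaf to the
        # list of training-row positions it contains, e.g. leaf 5 -> [0, 3, 7]
        # (values illustrative)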

        # getting T and E for each leaf
        T_leaves = _align_leaf_target(leaf_neighs, T_train)
        E_leaves = _align_leaf_target(leaf_neighs, E_train)

        # calculating z-score from width
        z = st.norm.ppf(0.5 + ci_width / 2)

        # vectorized (very fast!) implementation of Kaplan Meier curves
        (
            self._train_survival,
            self._train_upper_ci,
            self._train_lower_ci,
        ) = calculate_kaplan_vectorized(T_leaves, E_leaves, time_bins, z)

        # adding leaf indexes
        self._train_survival = self._train_survival.set_index(leaf_neighs.index)
        self._train_upper_ci = self._train_upper_ci.set_index(leaf_neighs.index)
        self._train_lower_ci = self._train_lower_ci.set_index(leaf_neighs.index)

        if persist_train:
            self.persist_train = True
            if index_id is None:
                index_id = X.index.copy()
            self.tree = BallTree(leaves.reshape(-1, 1), metric="hamming", leaf_size=40)
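            # with one leaf index per sample, the hamming metric is 0 for samples
            # in the same leaf and 1 otherwise, so tree queries recover same-leaf
            # training samples for prototype-based explainability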
        self.index_id = index_id

        return self
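
# A minimal, self-contained usage sketch for the estimators above, on synthetic
# data (the names below are illustrative, not part of the library; only
# XGBSEKaplanTree and convert_to_structured come from xgbse):
import numpy as np
import pandas as pd
from xgbse import XGBSEKaplanTree
from xgbse.converters import convert_to_structured

rng = np.random.default_rng(0)
X_demo = pd.DataFrame(rng.normal(size=(200, 5)))
y_demo = convert_to_structured(
    rng.exponential(10, size=200),  # times of event/censoring
    rng.integers(0, 2, size=200).astype(bool),  # event indicators
)

demo_tree = XGBSEKaplanTree()
demo_tree.fit(X_demo, y_demo, ci_width=0.95)
demo_surv = demo_tree.predict(X_demo)  # (n_samples, n_time_bins) survival probs
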
(
    X_train,
    X_test,
    X_valid,
    T_train,
    T_test,
    T_valid,
    E_train,
    E_test,
    E_valid,
    y_train,
    y_test,
    y_valid,
    features,
) = get_data()

# generating a Kaplan-Meier baseline curve shared by all samples

time_bins = get_time_bins(T_train, E_train, 100)

mean, high, low = calculate_kaplan_vectorized(T_train.values.reshape(1, -1),
                                              E_train.values.reshape(1, -1),
                                              time_bins)
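
# the (1, n) reshape treats the whole training set as a single neighbor set, so
# calculate_kaplan_vectorized returns one population-level KM curve; replicating
# it below gives every sample the same baseline survival estimate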

km_survival = pd.concat([mean] * len(y_train))
km_survival = km_survival.reset_index(drop=True)

# fitting an xgbse model to generate comparable survival predictions

xgbse_model = XGBSEDebiasedBCE()

xgbse_model.fit(
    X_train,
    y_train,
    num_boost_round=1000,
    validation_data=(X_valid, y_valid),
    early_stopping_rounds=10,