Example #1
import numpy as np
from lifelines import KaplanMeierFitter, NelsonAalenFitter


# x_test/y_test/delta_test are the held-out covariates, durations, and event indicators;
# get_score() is a companion helper assumed to be defined alongside this function.
def get_scores(model, x_test, y_test, delta_test, time_grid, surv_residual=False, cens_residual=False):
    n = y_test.shape[0]
    x_train, target = model.training_data
    y_train, delta_train = target

    # compute residual from training data
    exp_residual_train = np.nan_to_num(np.exp(np.log(y_train) - model.predict(x_train).reshape(-1)))
    exp_residual_test = np.nan_to_num(np.exp(np.log(y_test) - model.predict(x_test).reshape(-1)))

    # compute exp(-theta) from test data to evaluate accelerating component
    exp_predict_neg_test = np.nan_to_num(np.exp(-model.predict(x_test)).reshape(-1))

    naf_base = NelsonAalenFitter().fit(y_train, event_observed = delta_train)
    kmf_cens = KaplanMeierFitter().fit(y_train, event_observed = 1 - delta_train)
    
    if cens_residual:
        cens_test = kmf_cens.survival_function_at_times(exp_residual_test)
    else:
        cens_test = kmf_cens.survival_function_at_times(y_test)

    bss = []
    nblls = []
    for t in time_grid:
        bs, nbll = get_score(n, t, y_test, delta_test, naf_base, kmf_cens, cens_test, exp_predict_neg_test, surv_residual, cens_residual, model)
        bss.append(bs)
        nblls.append(-nbll)

    return (np.array(bss), np.array(nblls))
Example #2
import numpy as np
from lifelines import KaplanMeierFitter


def integrated_brier_score(time_true: np.ndarray, time_pred: np.ndarray,
                           event_observed: np.ndarray,
                           time_bins: np.ndarray) -> float:
    r"""Compute the integrated Brier score for a predicted survival function.

    The integrated Brier score is the mean squared difference between the
    observed survival status at time t (whether each sample is still event-free)
    and the predicted survival probability at time t, averaged over samples and
    integrated over all timepoints.

    Parameters
    ----------
    time_true : np.ndarray, shape=(n_samples,)
        The true time to event or censoring for each sample.
    time_pred : np.ndarray, shape=(n_samples, n_time_bins)
        The predicted survival probabilities for each sample in each time bin.
    event_observed : np.ndarray, shape=(n_samples,)
        The event indicator for each sample (1 = event, 0 = censoring).
    time_bins : np.ndarray, shape=(n_time_bins,)
        The time bins for which the survival function was computed.

    Returns
    -------
    float
        The integrated Brier score of the predictions.

    Notes
    -----
    This function uses the definition from [1]_ with inverse probability
    of censoring weighting (IPCW) to correct for censored observations. The weights
    are computed using the Kaplan-Meier estimate of the censoring distribution.

    References
    ----------
    .. [1] E. Graf, C. Schmoor, W. Sauerbrei, and M. Schumacher, ‘Assessment
       and comparison of prognostic classification schemes for survival data’,
       Statistics in Medicine, vol. 18, no. 17‐18, pp. 2529–2545, Sep. 1999.
    """

    # compute weights for inverse probability of censoring weighting (IPCW)
    censoring_km = KaplanMeierFitter()
    censoring_km.fit(time_true, 1 - event_observed)
    weights_event = censoring_km.survival_function_at_times(
        time_true).values.reshape(-1, 1)
    weights_no_event = censoring_km.survival_function_at_times(
        time_bins).values.reshape(1, -1)

    # scores for subjects with event before time t for each time bin
    had_event = (time_true[:, np.newaxis] <=
                 time_bins) & event_observed[:, np.newaxis]
    scores_event = np.where(had_event, (0 - time_pred)**2 / weights_event, 0)
    # scores for subjects with no event and no censoring before time t for each time bin
    scores_no_event = np.where((time_true[:, np.newaxis] > time_bins),
                               (1 - time_pred)**2 / weights_no_event, 0)

    scores = np.mean(scores_event + scores_no_event, axis=0)

    # integrate over all time bins
    score = np.trapz(scores, time_bins) / time_bins.max()
    return score
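A minimal usage sketch with made-up data (in practice `surv_probs` would come from a fitted survival model):

import numpy as np

# hypothetical toy inputs: 5 subjects evaluated on 4 time bins
times = np.array([2.0, 5.0, 6.5, 8.0, 10.0])   # observed time to event or censoring
events = np.array([1, 0, 1, 1, 1])             # 1 = event, 0 = censored
time_bins = np.array([1.0, 3.0, 6.0, 9.0])
# predicted survival probabilities, shape (n_samples, n_time_bins)
surv_probs = np.array([[0.90, 0.70, 0.40, 0.20],
                       [0.95, 0.90, 0.80, 0.70],
                       [0.90, 0.80, 0.50, 0.30],
                       [0.92, 0.85, 0.60, 0.35],
                       [0.97, 0.93, 0.88, 0.80]])

ibs = integrated_brier_score(times, surv_probs, events, time_bins)
print(f"Integrated Brier score: {ibs:.3f}")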
Example #3
from typing import Sequence, Union

import numpy as np
import pandas as pd
from lifelines import KaplanMeierFitter


def signedByMedianSurvival(
    T: pd.Series,
    E: pd.Series,
    mask: pd.Series,
    *,
    alternative_mask: Union[None, pd.Series] = None,
    timeline: Union[None, Sequence] = None,
) -> int:
    """
    Decide whether the group defined by mask has a better (+1) or worse (-1) prognosis.
    Since the mask typically defines low expression, +1 implies accelerating disease.

    Parameters
    ----------
    T
        A series of survival times.
    E
        A series of events, where 1 is the event (death).
    mask
        A Pandas mask for the main group.
    alternative_mask
        The second group is the negation of the first mask
        by default. This parameter sets a custom mask.
    timeline
        A series of time points (days in TCGA) when to sample survival probs.

    Returns
    -------
    Sign of the survival benefit for the grouping mask.
    """

    if alternative_mask is None:
        alternative_mask = ~mask
    if timeline is None:
        timeline = [0, 1500, 3000, 4500, 6000, 7500, 9000]
    kmf1 = KaplanMeierFitter()
    kmf1.fit(
        T[mask], E[mask],
    )
    kmf2 = KaplanMeierFitter()
    kmf2.fit(
        T[alternative_mask], E[alternative_mask],
    )
    # Compare the areas under the two survival curves sampled on the timeline;
    # the group with the larger area is taken to have the better prognosis.
    if np.trapz(kmf1.survival_function_at_times(timeline)) > np.trapz(
        kmf2.survival_function_at_times(timeline)
    ):
        return 1
    else:
        return -1
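A minimal usage sketch with synthetic data (the cohort and variable names below are illustrative only):

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n = 200
# synthetic cohort in which the masked (low-expression) half tends to live longer
low_expr = pd.Series(rng.random(n) < 0.5)
T = pd.Series(np.where(low_expr, rng.exponential(4000, n), rng.exponential(2500, n)))
E = pd.Series((rng.random(n) < 0.7).astype(int))  # ~70% observed deaths

print(signedByMedianSurvival(T, E, low_expr))  # expected: +1, the masked group survives longer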
Example #4
import matplotlib.pyplot as plt
import numpy as np
from lifelines import KaplanMeierFitter


def show_survival_curve(df,
                        t_col,
                        y_col,
                        max_time=None,
                        weight=None,
                        save_file=None):
    """Plot Kaplan-Meier retention curves for each treatment group in df[t_col],
    annotate survival probabilities at a few reference days, and return the
    difference between the second and first group's curves."""
    plt.figure(figsize=(8, 6))
    plt.rcParams["font.size"] = 14
    colors = ['blue', 'red', 'magenta']

    tr_uniq = np.sort(df[t_col].astype(int).unique())
    max_time = df[y_col].max() if max_time is None else max_time
    time = df[y_col].values
    event = np.where(df[y_col] < max_time, 1, 0)
    verbose_days = [
        0,
        int((max_time - 1) / 3),
        int((max_time - 1) * 2 / 3),
        int(max_time) - 1
    ]

    for d in verbose_days:
        plt.text(d,
                 0.6,
                 f'RR({d}day)',
                 horizontalalignment='center',
                 verticalalignment='center')

    curve_list = []
    elapsed_days = np.arange(int(max_time))
    kmf = KaplanMeierFitter()
    for i, tr in enumerate(tr_uniq):
        t_idx = (df[t_col] == tr)
        if weight is None:
            kmf.fit(time[t_idx], event[t_idx], label=f'tr={tr}')
        else:
            kmf.fit(time[t_idx],
                    event[t_idx],
                    label=f'tr={tr}',
                    weights=weight[t_idx])
        curve_list.append(kmf.survival_function_at_times(elapsed_days))
        ax = kmf.plot(c=colors[i])
        for d in verbose_days:
            surv_prob = kmf.survival_function_at_times(d).values[0]
            ax = plt.scatter(d, surv_prob, marker='o', c=colors[i])
            ax = plt.text(d,
                          0.6 - 0.02 * (i + 1),
                          f'{surv_prob :.3f}',
                          c=colors[i],
                          horizontalalignment='center',
                          verticalalignment='center')

    plt.xlim(-3, int(max_time) + 3)
    plt.ylim(0.5, 1.05)
    plt.xlabel('Followed days (elapsed days)')
    plt.ylabel('Survival probability (retention rate)')
    plt.legend(loc='best')
    plt.grid()
    plt.tight_layout()
    if save_file is not None:
        plt.savefig(save_file)
    plt.show()

    return (np.array(curve_list[1]) - np.array(curve_list[0])).reshape(-1)
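A hedged usage sketch (the DataFrame and column names below are made up for illustration):

import numpy as np
import pandas as pd

rng = np.random.default_rng(1)
n = 500
df = pd.DataFrame({
    "treatment": rng.integers(0, 2, n),                            # two groups: tr=0 and tr=1
    "followed_days": np.minimum(rng.exponential(400, n), 365.0),   # retention censored at one year
})

# difference between the tr=1 and tr=0 retention curves over elapsed days
diff = show_survival_curve(df, t_col="treatment", y_col="followed_days", max_time=365)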
Example #5
def index_of_survival(request: HttpRequest, all_parameter: str):
    """
    response = {
        data = 
    }
    """
    # all_parameter is a "&"-separated query string; by position:
    #   0: survival endpoint (OS or RFS), 1: cancer types, 2: cell types,
    #   3/4: percentile cutoffs for the high/low expression groups,
    #   5: deconvolution reference (EPIC/LM/QS)
    mm = all_parameter.split("&")

    st = mm[0].split("=")[1]
    if ("," in mm[1].split("=")[1]):
        ct = mm[1].split("=")[1].split(",")
    else:
        ct = [mm[1].split("=")[1]]

    b = API.DatabaseAPI("tcga")
    my_dict_b = b.query_collection_obs()
    my_df_b = pd.DataFrame(my_dict_b)
    select_part = my_df_b.loc[my_df_b["primary_disease"].isin(ct), :]

    if len(ct) > 8:
        response = {
            "error":
            "Too many datasets. You can select no more than eight datasets."
        }
        return JsonResponse(response)

    ref = mm[5].split("=")[1]
    if ("," in mm[2].split("=")[1]):
        cell = mm[2].split("=")[1].split(",")
    else:
        cell = [mm[2].split("=")[1]]
    up = mm[3].split("=")[1]
    dn = mm[4].split("=")[1]

    select = select_part["primary_disease"].tolist()
    if ref == "EPIC":
        columns_list = ["EPIC_cellFractions." + i for i in cell]
        ref = API.DatabaseAPI("ref")
    elif ref == "LM":
        columns_list = ["LM_" + i for i in cell]
        ref = API.DatabaseAPI("LM_ref")
    elif ref == "QS":
        columns_list = ["QS_" + i for i in cell]
        ref = API.DatabaseAPI("QS_ref")
    else:
        response = {"error": "reference error"}
        return JsonResponse(response)
    cellID = select_part["cellID"].tolist()
    my_df_d = select_part.loc[:, columns_list]
    my_df_d.index = cellID
    my_df_t = my_df_d.T

    genes = ref.query_collection_var()["geneSymbol"]

    gg = ref.query_collection_gene_X_var_by_obs(genes)
    gg = pd.DataFrame(gg)

    gg.columns = ref.query_collection_obs()["celltype"]
    gg_mean = pd.DataFrame(gg.T.mean(axis=1))
    gg_mean = gg_mean.loc[cell, :]
    expression = my_df_t.multiply(gg_mean.values)
    expression_t = expression.T
    expression_t = pd.DataFrame(expression_t.sum(axis=1), columns=["sum"])
    expression_t = expression_t.sort_values(by=["sum"], ascending=False)
    number = expression_t.shape[0]
    number1 = int(number / 100 * (100 - int(up)))
    number2 = int(number / 100 * (100 - int(dn)))
    samples = expression_t.index.tolist()
    sample = []
    for each in samples:
        names = each.split(".")
        sample.append(names[0] + "." + names[1] + "." + names[2])

    up_sample = sample[:number1]
    dn_sample = sample[number2 + 1:]

    matches = {"Dead": 1, "Alive": 0, "-": 0}
    a = API.DatabaseAPI("survival")

    #up part
    my_dict_a = a.query_collection_obs()
    my_df_a = pd.DataFrame(my_dict_a)
    my_df_a = my_df_a.loc[my_df_a["sample"].isin(up_sample), :]
    OSEVENT = my_df_a["OSEVENT"].tolist()
    E = [matches[i] for i in OSEVENT]
    if st == "OS":
        T = my_df_a["OSDAY"].tolist()
    else:
        T = my_df_a["RFSDAY"].tolist()

    E_end_up = [E[i] for i in range(len(T)) if T[i] != "-"]
    T_end = [T[i] for i in range(len(T)) if T[i] != "-"]
    T_end = list(map(float, T_end))
    T_end_up = list(map(lambda x: round(x / 30, 2), T_end))
    kmf = KaplanMeierFitter()
    kmf.fit(T_end_up, E_end_up)

    sf = kmf.survival_function_.T
    xa = sf.columns.tolist()
    y1a = list(map(lambda x: round(x, 3), sf.values[0].tolist()))
    ci = kmf.confidence_interval_survival_function_.T.values
    y2a = list(map(lambda x: round(x, 3), ci[1].tolist()))
    y3a = list(map(lambda x: round(x, 3), ci[0].tolist()))
    xca = [T_end_up[i] for i in range(len(T_end_up)) if E_end_up[i] == 0]
    xca = list(map(float, xca))
    yca = list(
        map(lambda x: round(x, 3),
            kmf.survival_function_at_times(xca).tolist()))

    #dn part
    my_dict_a = a.query_collection_obs()
    my_df_a = pd.DataFrame(my_dict_a)
    my_df_a = my_df_a.loc[my_df_a["sample"].isin(dn_sample), :]

    OSEVENT = my_df_a["OSEVENT"].tolist()
    E = [matches[i] for i in OSEVENT]
    if st == "OS":
        T = my_df_a["OSDAY"].tolist()
    else:
        T = my_df_a["RFSDAY"].tolist()

    E_end_dn = [E[i] for i in range(len(T)) if T[i] != "-"]
    T_end = [T[i] for i in range(len(T)) if T[i] != "-"]
    T_end = list(map(float, T_end))
    T_end_dn = list(map(lambda x: round(x / 30, 2), T_end))
    kmf = KaplanMeierFitter()
    kmf.fit(T_end_dn, E_end_dn)

    sf = kmf.survival_function_.T
    xb = sf.columns.tolist()
    y1b = list(map(lambda x: round(x, 3), sf.values[0].tolist()))
    ci = kmf.confidence_interval_survival_function_.T.values
    y2b = list(map(lambda x: round(x, 3), ci[1].tolist()))
    y3b = list(map(lambda x: round(x, 3), ci[0].tolist()))
    xcb = [T_end_dn[i] for i in range(len(T_end_dn)) if E_end_dn[i] == 0]
    xcb = list(map(float, xcb))
    ycb = list(
        map(lambda x: round(x, 3),
            kmf.survival_function_at_times(xcb).tolist()))

    results = logrank_test(T_end_up,
                           T_end_dn,
                           event_observed_A=E_end_up,
                           event_observed_B=E_end_dn)
    pValues1 = float(results.summary["p"].values[0])

    dfA = pd.DataFrame({'E': E_end_up, 'T': T_end_up, 'groupA': 1})
    dfB = pd.DataFrame({'E': E_end_dn, 'T': T_end_dn, 'groupA': 0})
    df = pd.concat([dfA, dfB])
    cph = CoxPHFitter().fit(df, 'T', 'E')
    pValues2 = float(cph.summary["p"].values[0])

    response = {
        "data": [{
            "pValues1": pValues1,
            "pValues2": pValues2
        }, {
            "line": {
                "dash": "solid",
                "color": "red",
                "shape": "hv",
                "width": 2
            },
            "mode": "lines",
            "name": "",
            "type": "scatter",
            "x": xa,
            "y": y1a,
            "xaxis": "x1",
            "yaxis": "y1",
            "showlegend": False
        }, {
            "line": {
                "dash": "dash",
                "color": "red",
                "shape": "hv",
                "width": 2
            },
            "mode": "lines",
            "name": "",
            "type": "scatter",
            "x": xa,
            "y": y2a,
            "xaxis": "x1",
            "yaxis": "y1",
            "showlegend": False
        }, {
            "line": {
                "dash": "dash",
                "color": "red",
                "shape": "hv",
                "width": 2
            },
            "mode": "lines",
            "name": "",
            "type": "scatter",
            "x": xa,
            "y": y3a,
            "xaxis": "x1",
            "yaxis": "y1",
            "showlegend": False
        }, {
            "mode": "markers",
            "name": "",
            "text": "",
            "type": "scatter",
            "x": xca,
            "y": yca,
            "xaxis": "x1",
            "yaxis": "y1",
            "marker": {
                "size": 10,
                "color": "black",
                "symbol": "cross-thin-open",
                "opacity": 1,
                "sizeref": 1,
                "sizemode": "area"
            },
            "showlegend": False
        }, {
            "line": {
                "dash": "solid",
                "color": "blue",
                "shape": "hv",
                "width": 2
            },
            "mode": "lines",
            "name": "",
            "type": "scatter",
            "x": xb,
            "y": y1b,
            "xaxis": "x1",
            "yaxis": "y1",
            "showlegend": False
        }, {
            "line": {
                "dash": "dash",
                "color": "blue",
                "shape": "hv",
                "width": 2
            },
            "mode": "lines",
            "name": "",
            "type": "scatter",
            "x": xb,
            "y": y2b,
            "xaxis": "x1",
            "yaxis": "y1",
            "showlegend": False
        }, {
            "line": {
                "dash": "dash",
                "color": "blue",
                "shape": "hv",
                "width": 2
            },
            "mode": "lines",
            "name": "",
            "type": "scatter",
            "x": xb,
            "y": y3b,
            "xaxis": "x1",
            "yaxis": "y1",
            "showlegend": False
        }, {
            "mode": "markers",
            "name": "",
            "text": "",
            "type": "scatter",
            "x": xcb,
            "y": ycb,
            "xaxis": "x1",
            "yaxis": "y1",
            "marker": {
                "size": 10,
                "color": "black",
                "symbol": "cross-thin-open",
                "opacity": 1,
                "sizeref": 1,
                "sizemode": "area"
            },
            "showlegend": False
        }]
    }
    return JsonResponse(response)
Example #6
S2 = data[data.Stage_group == 2]
km2 = KM()
km2.fit(S2.loc[:, "Time"], event_observed=S2.loc[:, 'Event'], label='Stage IV')

ax = km1.plot(ci_show=False)
km2.plot(ax=ax, ci_show=False)
plt.xlabel('time')
plt.ylabel('Survival probability estimate')
plt.savefig('two_km_curves', dpi=300)

# Let's compare the survival functions at 90, 180, 270, and 360 days

# In[37]:

survivals = pd.DataFrame([90, 180, 270, 360], columns=['time'])
survivals.loc[:, 'Group 1'] = km1.survival_function_at_times(
    survivals['time']).values
survivals.loc[:, 'Group 2'] = km2.survival_function_at_times(
    survivals['time']).values

# In[38]:

survivals

# This makes clear the difference in survival between the Stage III and IV cancer groups in the dataset.

# <a name='5-1'></a>
# ## 5.1 Bonus: Log-Rank Test
#
# To say whether there is a statistically significant difference between the survival curves we can run the log-rank test. This test gives the probability of observing a difference at least this large if the two underlying survival curves were actually the same. The derivation of the log-rank test is somewhat complicated, but luckily `lifelines` has a simple function to compute it.
#
# Run the next cell to compute a p-value using `lifelines.statistics.logrank_test`.
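# A minimal sketch of that cell (assuming `S1` is the Stage III subset, built analogously to `S2` above):

# In[ ]:

from lifelines.statistics import logrank_test

results = logrank_test(S1.loc[:, "Time"], S2.loc[:, "Time"],
                       event_observed_A=S1.loc[:, "Event"],
                       event_observed_B=S2.loc[:, "Event"])
print(f"p-value: {results.p_value:.4f}")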
Example #7
def filter_survival(filter_id):

    data_filtered = filtering(filter_id)
    data_filtered = data_filtered[data_filtered[filter_id['cell_full']] != "missing"]

    # Get the groups for ntiles and run the Kaplan Meier fitter for each of them
    # If the group_sizes are provided, use the binning function, otherwise the general ntiles
    if filter_id['group_sizes'] is not None:
        data_filtered['rank'] = binning(data_filtered.sort_values(by=filter_id['cell_full']), filter_id['cell_full'], filter_id['group_sizes'])
    else:
        data_filtered['rank'] = ntiles(data_filtered[filter_id['cell_full']], filter_id['num_groups'])
    points = []
    # OBS: checking the number of groups after filtering
    num_groups = len(uniq(data_filtered['rank']))
    if num_groups < 2:
        raise ValueError('Number of groups must be at least two.')
    points_dfs = []
    alive_dfs = []
    for g in range(num_groups):
        kmf = KaplanMeierFitter()
        data = data_filtered[data_filtered['rank'] == g + 1]
        kmf.fit(
            data['T'],
            data['E'],
            label='Kaplan_Meier',
        )
        df = pd.concat([
            kmf.survival_function_,
            kmf.confidence_interval_survival_function_,
        ], axis=1)
        df['group'] = g+1
        points_dfs += [df]

        alive_df = kmf.survival_function_at_times(data[data['E']==False]['T']).to_frame().reset_index()
        alive_df['group'] = g+1
        alive_dfs += [alive_df]

    # Curate points and alive points
    points_df = pd.concat(points_dfs).reset_index().rename(columns={
        'index': 'time',
        'Kaplan_Meier': 'fit',
        'Kaplan_Meier_lower_0.95': 'lower',
        'Kaplan_Meier_upper_0.95': 'upper',
    })
    points = points_df.to_dict(orient='records')

    alive_points_df = pd.concat(alive_dfs).rename(columns={
        'index': 'time',
        'Kaplan_Meier': 'fit',
        'group': 'group',
    })
    alive_points = alive_points_df.to_dict(orient='records')

    # Run multivariate log-rank test across the groups
    log_rank = multivariate_logrank_test(data_filtered['T'], data_filtered['rank'], data_filtered['E'])
    log = {
        'test_statistic_logrank': log_rank.summary['test_statistic'].iloc[0],
        'p_logrank': log_rank.summary['p'].iloc[0]
    }

    # Run cox regression
    cph = CoxPHFitter()
    cph.fit(data_filtered[['rank', 'T', 'E']], 'T', event_col='E')
    cox = {
        'coef': cph.summary['exp(coef)'].iloc[0],
        'lower': cph.summary['exp(coef) lower 95%'].iloc[0],
        'upper': cph.summary['exp(coef) upper 95%'].iloc[0],
        'p': cph.summary['p'].iloc[0]
    }

    # Replace infinite confidence-interval bounds with finite fallback values
    if cox['upper'] == float('inf'):
        cox['upper'] = 1.0
    if cox['lower'] == float('-inf'):
        cox['lower'] = 0.0
    return {'points': points, 'log_rank': log, 'cox_regression': cox, 'live_points': alive_points}
Example #8
class DGPSurv(gp.parameterized.Parameterized):
    def __init__(self,
                 X,
                 T,
                 c,
                 prediction_horizon,
                 layer_dim=30,
                 num_causes=1,
                 num_inducing=100,
                 calibration_fraction=0.5,
                 calibrate=False):

        super(DGPSurv, self).__init__()

        self.prediction_horizon = prediction_horizon
        self.calibrate = calibrate

        # Refine inputs
        inclusion_criteria = (T >= self.prediction_horizon) | (c != 0)
        X_ = np.array(X)[inclusion_criteria, :]
        c_ = np.array(
            ((np.array(c)[inclusion_criteria] != 0) &
             (np.array(T)[inclusion_criteria] < self.prediction_horizon)) *
            np.array(c)[inclusion_criteria])

        # Set all model attributes; despite its name, minmax_ standardizes
        # features to zero mean and unit variance
        self.minmax_ = StandardScaler()

        self.X = torch.tensor(np.array(self.minmax_.fit_transform(X_))).float()
        self.T = torch.tensor(np.array(T).astype(float) / 365).float()
        self.c = torch.tensor(np.array(c_).astype(float)).float()

        self.num_inducing = min(num_inducing, self.X.shape[0])
        self.num_causes = num_causes + 1
        self.num_dim = self.X.shape[1]

        self.Xu = torch.from_numpy(
            kmeans2(self.X.numpy(), self.num_inducing, minit='points')[0])

        # handle erroneous settings for the model's parameters

        try:

            self.layer_dim = layer_dim

            if self.layer_dim < 2:
                raise ValueError(
                    "Bad inputs: number of intermediate dimensions must be at least 2."
                )

        except ValueError as ve:
            print(ve)

        # computes the weight for mean function of the first layer using a PCA transformation
        _, _, V = np.linalg.svd(self.X.numpy(), full_matrices=False)
        W = torch.from_numpy(V[:self.layer_dim, :])

        mean_fn = LinearT(self.num_dim, self.layer_dim)
        mean_fn.linear.weight.data = W
        mean_fn.linear.weight.requires_grad_(False)

        self.mean_fn = mean_fn

        # Initialize the first DGP layer

        linear = torch.nn.Linear(self.num_dim, 20)
        pyro_linear_fn = lambda x: pyro.module("linear", linear)(x)
        kernel = gp.kernels.Matern32(input_dim=self.num_dim,
                                     lengthscale=torch.tensor(1.))
        warped_kernel = gp.kernels.Warping(kernel, pyro_linear_fn)

        self.layer_0 = gp.models.VariationalSparseGP(
            self.X,
            None,
            gp.kernels.Matern52(self.num_dim,
                                variance=torch.tensor(1.),
                                lengthscale=torch.ones(self.num_dim)),
            #warped_kernel,
            Xu=self.Xu,
            likelihood=None,
            mean_function=self.mean_fn,
            latent_shape=torch.Size([self.layer_dim]))

        h = self.mean_fn(self.X).t()
        hu = self.mean_fn(self.Xu).t()

        self.layer_1 = gp.models.VariationalSparseGP(
            h,
            self.c,
            gp.kernels.Matern52(self.layer_dim,
                                variance=torch.tensor(1.),
                                lengthscale=torch.tensor(1.)),
            Xu=hu,
            likelihood=gp.likelihoods.MultiClass(num_classes=self.num_causes),
            latent_shape=torch.Size([self.num_causes]))

        #self.layer_0.u_scale_tril = self.layer_0.u_scale_tril * 1e-5
        #self.layer_0.set_constraint("u_scale_tril", torch.distributions.constraints.lower_cholesky)

        if self.calibrate:

            self.kmf = KaplanMeierFitter()

            self.kmf.fit(T, event_observed=c)

            self.offset_probability = self.kmf.survival_function_at_times(
                times=[self.prediction_horizon]).values[0]
            self.calibration_fraction = calibration_fraction

    @autoname.name_count
    def model(self, X, c):

        self.layer_0.set_data(X, None)

        h_loc, h_var = self.layer_0.model()
        h = dist.Normal(h_loc, h_var.sqrt())()

        self.layer_1.set_data(h.t(), c)
        self.layer_1.model()

    @autoname.name_count
    def guide(self, X, c):

        self.layer_0.guide()
        self.layer_1.guide()

    # make prediction
    def forward(self, X_new):

        # because prediction is stochastic (due to the Monte Carlo sample of the hidden layer),
        # we draw 100 predictions and take the most common one (as in [4])

        pred = []
        num_MC_samples = 100

        for _ in range(num_MC_samples):

            h_loc, h_var = self.layer_0(X_new)
            h = dist.Normal(h_loc, h_var.sqrt())()

            f_loc, f_var = self.layer_1(h.t())

            pred.append(f_loc)  # change for multiclass

        return torch.stack(pred).mode(dim=0)[0]

    def train(self,
              num_epochs=5,
              num_iters=60,
              batch_size=1000,
              learning_rate=0.01):

        optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
        loss_fn = infer.TraceMeanField_ELBO().differentiable_loss

        self.loss = []

        for i in range(num_epochs):

            self.loss.append(
                self.train_update(optimizer, loss_fn, batch_size, num_iters,
                                  i))

        self.loss = np.array(self.loss).reshape((-1, 1))

        if self.calibrate:

            print("Calibrating the trained model...")

            calibration_indexes = np.random.choice(
                list(range(self.X.shape[0])),
                int(np.ceil(self.calibration_fraction * self.X.shape[0])))
            y_uncalibrated = self.predict_survival(
                self.X[calibration_indexes, :].detach().numpy(), calibrate=1)
            y_raw = np.log((1 - y_uncalibrated) / y_uncalibrated)

            self.calibration_constant = sigmoid_calibrate_survival_predictions(
                self, y_raw)

            print("Done training!")

        else:

            self.calibration_constant = 1

    def train_update(self, optimizer, loss_fn, batch_size, num_iters, epoch):

        losses = []

        for _ in range(num_iters):

            batch_indexes = np.random.choice(list(range(self.X.shape[0])),
                                             batch_size)

            features_ = self.X[batch_indexes, :]
            event_censor = self.c[batch_indexes]

            features_ = features_.reshape(-1, self.X.shape[1])

            optimizer.zero_grad()

            loss = loss_fn(self.model, self.guide, features_, event_censor)

            losses.append(loss)

            loss.backward()
            optimizer.step()

        print("Train Epoch: {:2d} \t[Iteration: {:2d}] \tLoss: {:.6f}".format(
            epoch, _, loss))

        return losses

    def predict_survival(self, X_new, calibrate=None):

        s_preds = []
        y_pred = []

        index = 0
        base_size = 1000
        predictor_size = np.min((X_new.shape[0], base_size))
        num_batches_ = int(np.ceil(X_new.shape[0] / predictor_size))

        if calibrate is None:

            calibration_factor = self.calibration_constant

        else:

            calibration_factor = calibrate

        for u in range(num_batches_):

            if (u == (num_batches_ -
                      1)) and (np.mod(X_new.shape[0], predictor_size) > 0):

                X_curr = np.array(X_new)[index:, :]

            else:

                X_curr = np.array(X_new)[index:index + predictor_size, :]

            X_new_numpy = self.minmax_.transform(X_curr)
            X_new_ = torch.tensor(X_new_numpy).float()

            f_output = self(X_new_).detach().numpy()

            if u == 0:
                s_preds = f_output
            else:
                s_preds = np.hstack((s_preds, f_output))

            index += predictor_size

        for v in range(self.num_causes):

            y_pred.append(
                output_layer(constant=calibration_factor, y=s_preds[v, :]))

        y_pred = ((1 - np.array(y_pred)) / np.sum(
            (1 - np.array(y_pred)), axis=0))

        return y_pred