Beispiel #1
0
    def test_kmf_add_at_risk_counts_with_single_row_multi_groups(
            self, block, kmf):
        """Visual check: one "At risk" row rendered for two fitted groups.

        A small "test" cohort (100 subjects) and a larger "con" cohort
        (1000 subjects) are drawn from exponential durations with binomial
        event indicators, plotted on a shared axis, and annotated with an
        at-risk table restricted to the "At risk" row.
        """
        # "test" cohort: 100 subjects, event probability 0.8.
        test_durations = np.random.exponential(10, size=(100))
        test_events = np.random.binomial(1, 0.8, size=(100))
        kmf_test = KaplanMeierFitter().fit(test_durations, test_events,
                                           label="test")

        # "con" cohort: 1000 subjects, event probability 0.6.
        con_durations = np.random.exponential(15, size=(1000))
        con_events = np.random.binomial(1, 0.6, size=(1000))
        kmf_con = KaplanMeierFitter().fit(con_durations, con_events,
                                          label="con")

        figure = self.plt.figure()
        axis = figure.subplots(1, 1)
        for fitter in (kmf_test, kmf_con):
            fitter.plot(ax=axis)

        axis.set_ylim([0.0, 1.1])
        axis.set_xlim([0.0, 100])
        axis.set_xlabel("Days")
        axis.set_ylabel("Survival probability")

        # ypos pushes the at-risk table further below the axis.
        add_at_risk_counts(kmf_test, kmf_con, ax=axis,
                           rows_to_show=["At risk"], ypos=-0.4)
        self.plt.title(
            "test_kmf_add_at_risk_counts_with_single_row_multi_groups")
        self.plt.tight_layout()
        self.plt.show(block=block)
def do_KM_analysis(durations, groups, events, group_labels, xlabel=None):
    """Plot per-group Kaplan-Meier curves and report log-rank statistics.

    A KaplanMeierFitter is fitted for every distinct value of ``groups``
    (visited in sorted order and paired positionally with ``group_labels``).
    All curves are drawn without confidence bands on one axis, an at-risk
    table is attached, and the multivariate log-rank p-value is written on
    the plot. With more than two groups, a pairwise log-rank summary is
    printed as well. The figure is displayed with ``plt.show()``.

    Returns the fitter of the last group in sorted order.
    """
    sns.set(palette="colorblind", font_scale=1.35,
            rc={"figure.figsize": (8, 6), "axes.facecolor": ".92"})

    fitters = []
    ax = None
    for position, level in enumerate(sorted(set(groups))):
        in_group = groups == level
        fitter = KaplanMeierFitter()
        fitter.fit(durations[in_group], events[in_group],
                   label=group_labels[position])
        fitters.append(fitter)
        # The first curve picks the axis; every later curve is drawn onto it.
        if ax is None:
            ax = fitter.plot(ci_show=False)
        else:
            ax = fitter.plot(ax=ax, ci_show=False)

    add_at_risk_counts(*fitters, labels=group_labels)
    ax.set_ylim(0, 1.1)
    if xlabel is not None:
        ax.set_xlabel(xlabel)

    overall = multivariate_logrank_test(durations, groups, events)
    # NOTE(review): (0.1, 0.01) are data coordinates (matplotlib's default
    # text transform), so where the label lands depends on the time scale.
    ax.text(0.1, 0.01, 'P-value=%.3f' % overall.p_value)

    if len(set(groups)) > 2:
        pairwise_logrank_test(durations, groups, events).print_summary()

    plt.show()

    return fitter
Beispiel #3
0
    def plot_kaplan_meier(self, column, value):
        """Plot Kaplan-Meier curves for one attribute value versus all other rows.

        Rows of ``self.data`` are split into those with ``column == value``
        and those with ``column != value``. Each subset is fitted on
        ``overall_survival_months``, with ``death_from_cancer`` as the event
        indicator. Both curves share one axis, and an at-risk table is added.

        Args:
            column (str): column of the cleaned METABRIC clinical data holding
                a patient attribute (for example HER2 receptor status).
            value (str or int): value of ``column`` compared against every
                other value (for example 'negative').
        """
        subsets = (
            (self.data[self.data[column] == value], value),
            (self.data[self.data[column] != value], f'not {value}'),
        )

        fitters = []
        ax = None
        for subset, curve_label in subsets:
            fitter = KaplanMeierFitter()
            fitter.fit(subset.overall_survival_months,
                       event_observed=subset['death_from_cancer'],
                       label=curve_label)
            ax = fitter.plot() if ax is None else fitter.plot(ax=ax)
            fitters.append(fitter)

        add_at_risk_counts(*fitters, ax=ax)
        ax.set_ylim([0.0, 1.0])
        ax.set_xlabel('Timeline (Months)')
        ax.set_title(f'Kaplan Meier plot in months of {column} variable')
        plt.tight_layout()
        plt.show()
Beispiel #4
0
def plot_km_survf(data, t_col="t", e_col="e", datatype='train_data',
                  save_dir='/home/kyro_zhang/ZQX'):
    """
    Plot a KM survival curve with an at-risk table and save it as a PNG.

    Parameters
    ----------
    data: pandas.DataFrame
        Survival data to plot.
    t_col: str
        Column name in data indicating time.
    e_col: str
        Column name in data indicating events or status.
    datatype: str
        Selects the output file name: 'train_data' -> train_data.png,
        'test_data' -> test_data.png, anything else -> predict_data.png.
    save_dir: str
        Directory the PNG is written into (defaults to the previously
        hard-coded location).
    """
    import os

    from lifelines import KaplanMeierFitter
    from lifelines.plotting import add_at_risk_counts

    fig, ax = plt.subplots(figsize=(6, 4))
    kmfh = KaplanMeierFitter()
    kmfh.fit(data[t_col], event_observed=data[e_col], label="KM Survival Curve")
    kmfh.survival_function_.plot(ax=ax)
    plt.ylim(0, 1.01)
    plt.xlabel("Time")
    plt.ylabel("Probalities")
    plt.legend(loc="best")
    add_at_risk_counts(kmfh, ax=ax)

    if datatype == 'train_data':
        file_name = 'train_data.png'
    elif datatype == 'test_data':
        file_name = 'test_data.png'
    else:
        file_name = 'predict_data.png'
    plt.savefig(os.path.join(save_dir, file_name))
Beispiel #5
0
    def test_at_risk_looks_right_when_scales_are_magnitudes_of_order_larger_single_attribute(self, block):
        """Visual check: the at-risk table stays legible across very different time scales.

        Five cohorts, with durations spanning 0..32000 down to 0..9, are each
        discretised into 100 equal-width bins; a subject's duration is the
        right edge of its bin. Each cohort is fitted and drawn on one axis,
        and only the "At risk" row of the table is shown.
        """
        def right_edges(count):
            # Right edge of the pd.cut bin that each of 0..count-1 falls into.
            bins = pd.cut(np.arange(count), 100, retbins=False)
            return [interval.right for interval in bins]

        cohorts = [
            (32000, "Category A"),
            (9000, "Category"),
            (900, "CatB"),
            (90, "Categ"),
            (9, "Categowdary B"),
        ]
        fitters = [KaplanMeierFitter().fit(right_edges(count), label=name)
                   for count, name in cohorts]

        ax = fitters[0].plot()
        for fitter in fitters[1:]:
            ax = fitter.plot(ax=ax)

        add_at_risk_counts(*fitters, ax=ax, rows_to_show=["At risk"])

        self.plt.title("test_at_risk_looks_right_when_scales_are_magnitudes_of_order_larger")
        self.plt.tight_layout()
        self.plt.show(block=block)
Beispiel #6
0
def plot_km_survf(data, t_col="t", e_col="e", save_file=''):
    """
    Plot a KM survival curve with an at-risk table and save it as a PDF.

    Parameters
    ----------
    data: pandas.DataFrame
        Survival data to plot.
    t_col: str
        Column name in data indicating time.
    e_col: str
        Column name in data indicating events or status.
    save_file: str
        Output path without extension; '.pdf' is appended.
    """
    from lifelines import KaplanMeierFitter
    from lifelines.plotting import add_at_risk_counts

    # Draw on, and later save, this one figure. An extra plt.figure() here
    # would be a separate, empty figure.
    fig, ax = plt.subplots(figsize=(6, 4))
    kmfh = KaplanMeierFitter()
    kmfh.fit(data[t_col],
             event_observed=data[e_col],
             label="KM Survival Curve")
    kmfh.survival_function_.plot(ax=ax)
    plt.ylim(0, 1.01)
    plt.xlabel("Time")
    plt.ylabel("Probalities")
    plt.legend(loc="best")
    add_at_risk_counts(kmfh, ax=ax)
    fig.savefig(save_file + '.pdf', bbox_inches='tight')
Beispiel #7
0
def plot_riskGroups(data_groups, event_col, duration_col, labels=None, plot_join=False,
                    xlabel="Survival time (Month)", ylabel="Survival Rate", legend="Risk Groups",
                    title="Survival function of Risk groups", save_fig_as=""):
    """Plot survival curves for different risk groups.

    Parameters
    ----------
    data_groups : list(`pandas.DataFrame`)
        List of DataFrame[['E', 'T']], risk groups from lowest to highest.
    event_col : str
        Column in each DataFrame indicating events.
    duration_col : str
        Column in each DataFrame indicating durations.
    labels : list(str), optional
        One text label per group. When None or empty, groups are labelled
        "1", "2", ... The caller's list is never modified.
    plot_join : bool, default False
        Also plot the union of every pair of adjacent risk groups (i, i + 1).
    xlabel, ylabel, legend, title : str
        Axis labels, legend title and plot title.
    save_fig_as : str
        File name to save the figure to as PNG (600 dpi); "" disables saving.

    Returns
    -------
    None
        Plots the figure of each risk group.

    Examples
    --------
    >>> plot_riskGroups(df_list, "E", "T", labels=["Low", "Mid", "High"])
    """
    n_groups = len(data_groups)
    # Build a fresh list so neither a shared default nor the caller's list
    # is mutated between calls.
    if labels is None or len(labels) == 0:
        labels = [str(i + 1) for i in range(n_groups)]

    fig, ax = plt.subplots(figsize=(8, 6))
    kmfit_groups = []
    for i, sub_group in enumerate(data_groups):
        kmfh = KaplanMeierFitter()
        kmfh.fit(sub_group[duration_col], event_observed=sub_group[event_col], label=labels[i])
        kmfh.survival_function_.plot(ax=ax)
        kmfit_groups.append(kmfh)

    # Optionally plot the union of each pair of adjacent groups (i, i + 1).
    if plot_join:
        for i in range(n_groups - 1):
            kmfh = KaplanMeierFitter()
            sub_group = pd.concat([data_groups[i], data_groups[i + 1]], axis=0)
            kmfh.fit(sub_group[duration_col], event_observed=sub_group[event_col],
                     label=labels[i] + '&' + labels[i + 1])
            kmfh.survival_function_.plot(ax=ax)
            kmfit_groups.append(kmfh)

    plt.ylim(0, 1.01)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend(loc="best", title=legend)
    add_at_risk_counts(*kmfit_groups, ax=ax)
    # Save before show(): some backends tear the figure down once show() returns.
    if save_fig_as != "":
        fig.savefig(save_fig_as, format='png', dpi=600, bbox_inches='tight')
    plt.show()
Beispiel #8
0
    def test_kmf_add_at_risk_counts_with_specific_rows(self, block, kmf):
        """Visual check: at-risk table limited to the "Censored" and "At risk" rows."""
        durations = np.random.exponential(10, size=(100))
        observed = np.random.binomial(1, 0.8, size=(100))
        kmf.fit(durations, observed)

        axis = self.plt.figure().subplots(1, 1)
        kmf.plot(ax=axis)
        add_at_risk_counts(kmf, ax=axis, rows_to_show=["Censored", "At risk"])
        # Layout is tightened before the title is added.
        self.plt.tight_layout()
        self.plt.title("test_kmf_add_at_risk_counts_with_specific_rows")
        self.plt.show(block=block)
Beispiel #9
0
def plot_km(data, T_col='T', E_col='E'):
    """Plot a single Kaplan-Meier curve with an at-risk table and show it.

    Parameters
    ----------
    data : pandas.DataFrame
        Survival data.
    T_col : str
        Column holding the durations.
    E_col : str
        Column holding the event indicator.
    """
    _, ax = plt.subplots(figsize=(6, 4))
    kmf = KaplanMeierFitter()
    kmf.fit(data[T_col], event_observed=data[E_col], label='KM Curve')
    kmf.survival_function_.plot(ax=ax)
    plt.ylim(0, 1.01)
    plt.xlabel('Time')
    # Raw string: '\h' is not a valid escape sequence in a normal literal
    # (DeprecationWarning, SyntaxWarning since Python 3.12). Same text.
    plt.ylabel(r'$\hat{S}(t)$')
    plt.legend(loc='best')
    add_at_risk_counts(kmf, ax=ax)
    plt.show()
Beispiel #10
0
    def test_kmf_add_at_risk_counts_with_subplot(self, block, kmf):
        """Visual check: at-risk table attached to the left panel of a 1x2 grid."""
        kmf.fit(np.random.exponential(10, size=(100)))

        left, right = self.plt.figure().subplots(1, 2)
        kmf.plot(ax=left)
        add_at_risk_counts(kmf, ax=left)
        kmf.plot(ax=right)

        self.plt.title("test_kmf_add_at_risk_counts_with_subplot")
        self.plt.show(block=block)
def stratifiedSurvival(t,
                       eventTime,
                       eventIndicator=None,
                       followupTime=None,
                       group=None):
    """Plot Kaplan-Meier survival curves stratified by a grouping column.

    Parameters
    ----------
    t : pandas.DataFrame
        Input data.
    eventTime : str
        Column with the time to event. When ``eventIndicator`` is None, a
        missing value means "no event observed" and the duration is taken
        from ``followupTime`` instead.
    eventIndicator : str, optional
        Column with the event indicator. When given, ``eventTime`` is used
        as the duration for every row.
    followupTime : str, optional
        Column with the follow-up time used for rows without an event time;
        needed when ``eventIndicator`` is None.
    group : str, optional
        Column to stratify by. If None, all rows form a single group
        labelled 'Population'.
    """
    import matplotlib.pyplot as plt
    import lifelines as lf
    from lifelines.plotting import add_at_risk_counts
    import pandas as pd

    tm = t[eventTime].copy()

    if group is None:
        grp = pd.Series('Population', index=t.index)
    else:
        grp = t[group]

    if eventIndicator is None:
        # Event observed iff an event time is recorded; other rows are
        # censored at their follow-up time.
        ev = ~t[eventTime].isnull()
        tm[tm.isnull()] = t.loc[tm.isnull(), followupTime]
    else:
        # Use the supplied indicator column (previously `ev` was left undefined).
        ev = t[eventIndicator]

    _, ax = plt.subplots()
    fitters = []
    # Sort (by string form) so group order and colours are reproducible;
    # iterating a bare set of strings depends on hash randomisation.
    for g in sorted(set(grp), key=str):
        in_group = grp == g
        kmf = lf.KaplanMeierFitter()
        kmf.fit(tm[in_group], ev[in_group], label=g)
        kmf.plot(ax=ax)
        fitters.append(kmf)

    add_at_risk_counts(*fitters, ax=ax)

    plt.legend(loc='lower left')
    plt.ylim([0, 1])
    plt.xlabel('Time (years)')
    plt.ylabel('Survival')
    plt.title('Kaplan-Meier survival curve')
    def estimate_kaplan_meier(self):
        """Fit and plot one Kaplan-Meier curve per label, annotated with log-rank results.

        Reads ``self.survival_label`` (a DataFrame with a 'label' column and
        the ``self.duration_column`` / ``self.observed_column`` columns).

        Side effects:
          * ``self.median_survival_time[label]`` is set to each fitter's ``median_``.
          * ``self.test_statistic`` and ``self.p_value`` are set from
            ``multivariate_logrank_test``. The p-value is stored as a string,
            and an exact 0 is rendered as '< 0.0001'.
          * ``self.survival_rate_result`` holds every group's survival
            function side by side (outer-aligned on time, gaps interpolated).

        The p-value, plus the HR with its CI when ``self.CI`` is non-empty,
        is written on the plot, which is then shown.
        """
        labels = self.survival_label['label']  # Series of group labels
        ordered_labels = sorted(labels.unique())
        sfs = {}
        ax = plt.subplot()
        fitters = []

        for label in ordered_labels:
            # Rows carrying this label that are also present in survival_label.
            data_label_index = list(
                set(labels[labels == label].index)
                & set(self.survival_label.index))
            subset = self.survival_label.loc[data_label_index]
            kmf = KaplanMeierFitter()
            kmf.fit(subset[self.duration_column],
                    subset[self.observed_column],
                    label=label)
            # Keep every fitter so the at-risk table can cover all labels.
            fitters.append(kmf)

            sfs[label] = kmf.survival_function_  # survival function of this label
            self.median_survival_time[label] = kmf.median_

            ax = kmf.plot(ax=ax)  # draw the survival curve

        # Number at risk over time for each label.
        add_at_risk_counts(*fitters)

        # Log-rank test: is the survival difference between groups significant?
        # NOTE(review): this unpacks a (statistic, p_value) pair and passes the
        # whole DataFrame; lifelines' multivariate_logrank_test returns a
        # StatisticalResult instead - confirm which implementation is imported.
        self.test_statistic, self.p_value = multivariate_logrank_test(
            self.survival_label, labels)
        if self.p_value == 0:
            self.p_value = '< 0.0001'
            # Use "p < x", not "p=< x", whether or not the HR line is shown.
            p_text = 'log_rank p %s' % self.p_value
        else:
            self.p_value = str(self.p_value)
            p_text = 'log_rank p=%s' % self.p_value

        # Survival rates of all groups, one column per label.
        self.survival_rate_result = pd.concat(
            [sfs[k] for k in ordered_labels],
            axis=1).interpolate()

        text_style = dict(transform=ax.transAxes, va='top', fontsize=12)
        ax.text(0.35, 0.8, p_text, **text_style)
        if len(self.CI) > 0:
            ax.text(0.35,
                    0.9,
                    "HR=%.3f(95%% CI:%.3f-%.3f)" %
                    (self.HR, self.CI[0], self.CI[1]),
                    **text_style)
        plt.title('Full Data')
        print("Median survival time of data: %s" % self.median_survival_time)
        plt.show()
Beispiel #13
0
# #### lifelines

# %%
from lifelines import KaplanMeierFitter
from lifelines.plotting import add_at_risk_counts

# Overlay one Kaplan-Meier curve (no confidence band) per 'ER Status' value on
# a shared axis, then attach a single at-risk table for all curves.
# NOTE(review): assumes `dat` and `plt` are defined in an earlier cell.
# NOTE(review): a NaN status would give an all-False mask (NaN == NaN is
# False) - confirm the column has no missing values.
ax = plt.subplot(111)
kmfs = []
for status in dat['ER Status'].unique():
    mask = dat['ER Status']==status
    kmf = KaplanMeierFitter()
    kmf.fit(dat["OS Time"][mask], event_observed = dat['OS event'][mask], label = status)
    ax = kmf.plot_survival_function(ax=ax,ci_show=False)
    kmfs.append(kmf)

add_at_risk_counts(*kmfs, ax=ax) # *kmfs expands to the elements of kmfs
plt.tight_layout()

# %% [markdown]
# Note that `lifelines` and `scikit-survival` build on `matplotlib`, so each curve has to be drawn onto the shared axes inside a for loop. This is typical `matplotlib` behavior.
#
# The `kaplanmeier` function essentially packages up the `lifelines` Kaplan-Meier curve and makes some choices for you while creating the graph.

# %% [markdown]
# ## Question 1b

# %%
# Start of a per-'Node' overlay, visiting node values in sorted order.
# NOTE(review): this cell appears truncated in this view - only the mask is
# built here; any fit/plot steps are not visible.
ax = plt.subplot(111)
kmfs = []
for status in sorted(dat['Node'].unique()):
    mask = dat['Node']==status
Beispiel #14
0
def kaplan_meier(
        file,
        model=None,
        cohorts=None,
        event_type="biochemicalRecurrence",
        event_time="bcrTime",
        figsize=(9, 6),
):
    """Split patients at the median risk score and compare the two halves.

    Draws Kaplan-Meier curves for the high- and low-hazard halves with an
    at-risk table, fits a Cox model on the score, runs a log-rank test
    between the halves, and prints a short summary.

    Parameters
    ----------
    file : str
        Either a ``.tsv`` score file (tab-separated, index column "ID",
        with a "score" column), or a parameter file read by ``get_params``
        whose ``means`` contain "x_t", "x_f" and "logHR". In the latter case
        the score is ``[x_t, x_f] @ logHR``.
    model : optional
        Object exposing ``pheno``; ``mc_model()`` is used when None.
    cohorts : str or list of str, default ["UKDP"]
        Cohort abbreviations, matched case-insensitively against
        ``pheno["CohortAbb"]``.
    event_type : str
        Phenotype column with the event indicator.
    event_time : str
        Phenotype column with the event time in days (converted to years
        for the plot).
    figsize : tuple
        Size of the Kaplan-Meier figure.

    Returns
    -------
    (Axes, CoxPHFitter, logrank test result)

    Raises
    ------
    TypeError
        If a parameter file does not contain "logHR".
    """
    if cohorts is None:
        cohorts = ["UKDP"]
    elif isinstance(cohorts, str):
        cohorts = [cohorts]
    md = mc_model() if model is None else model

    # Usable rows: both event columns present and not blacklisted.
    valid = np.logical_and(
        ~md.pheno.loc[:, [event_type, event_time]].isna().any(axis=1),
        md.pheno["blacklisted"] == 0,
    )
    chs = pd.Series(cohorts).str.upper()
    ind = np.logical_and(md.pheno["CohortAbb"].str.upper().isin(chs), valid)
    mpheno = md.pheno.loc[ind, :].copy()

    if file.endswith(".tsv"):
        # Score file: one precomputed score per sample ID.
        score_df = pd.read_csv(file, delimiter="\t", index_col="ID")
        # reindex (rather than .loc) so IDs missing from the file become NaN
        # and are dropped below instead of raising KeyError.
        score = score_df["score"].reindex(mpheno.index)
        ind[ind] = ~score.isna()
        score = score[~score.isna()].values
        mpheno = md.pheno.loc[ind, :].copy()
    else:
        pars = get_params(file)
        if "logHR" not in pars["means"]:
            raise TypeError(
                "The parameter in the file do not seem to contain hazard "
                "prediction.")
        expressions = np.concatenate(
            [pars["means"]["x_t"][ind, :], pars["means"]["x_f"][ind, :]],
            axis=1,
        )
        score = np.dot(expressions, pars["means"]["logHR"])[:, 0]

    event = md.pheno.loc[ind, event_type].values
    time = md.pheno.loc[ind, event_time].values / 365.25  # days -> years

    # Median split: above-median score = high hazard.
    threshold = np.median(score)
    grouping = score > threshold
    g1 = grouping
    g2 = ~grouping

    # Kaplan-Meier plot
    kmfh = KaplanMeierFitter()
    kmfh.fit(time[g1], event[g1], label="High Hazard")
    ax = kmfh.plot(figsize=figsize)
    kmfl = KaplanMeierFitter()
    kmfl.fit(time[g2], event[g2], label="Low Hazard")
    ax = kmfl.plot(ax=ax)
    plt.xlabel("years")
    add_at_risk_counts(kmfh, kmfl, ax=ax)

    # Cox regression on the score alone
    mpheno["score"] = score
    cph = CoxPHFitter()
    cph.fit(mpheno,
            duration_col=event_time,
            event_col=event_type,
            formula="score")

    # Log-rank test on the raw (day) times; the test is invariant to rescaling time.
    logr = statistics.logrank_test(
        mpheno.loc[g1, event_time],
        mpheno.loc[g2, event_time],
        mpheno.loc[g1, event_type],
        mpheno.loc[g2, event_type],
    )

    print("Cohorts: {}, event: {}, time: {}".format(cohorts, event_type,
                                                    event_time))
    print("Concordance: {:.2%}".format(cph.concordance_index_))
    print("Cox p-value: {}".format(cph.summary.loc["score", "p"]))
    print("Logrank p-value: {}".format(logr.p_value))

    return ax, cph, logr
Beispiel #15
0
def plot(out,
         fontsize=12,
         savepath='',
         width=10,
         height=6,
         cmap='Set1',
         cii_alpha=0.05,
         cii_lines='dense',
         methodtype='lifeline',
         title='Survival function'):
    '''Plot Kaplan-Meier survival curves for every class in ``out``.

    Parameters
    ----------
    out :       [dict] Dictionary derived from the fit function. Keys read
                here: 'labx', 'uilabx', 'time_event', 'censoring'; with
                methodtype='lifeline' also 'logrank' and 'logrank_P'.

    fontsize :  [INT],  Font size for the graph
                default is 12.

    savepath:   [STRING], Path to store the figure (methodtype='lifeline'
                only); '' (default) does not save.

    width:      [INT], Width of the figure
                10 (default)

    height:     [INT], Height of the figure
                6 (default)

    cmap:       [STRING], Specify your own colors for each class-label or use a colormap:  https://matplotlib.org/examples/color/colormaps_reference.html
                [(1, 0, 0),(0, 0, 1),(..)]
                'Set1'       (default)
                'Set2'       Discrete colors
                'Pastel1'    Discrete colors
                'Paired'     Discrete colors
                'rainbow'
                'bwr'        Blue-white-red
                'binary' or 'binary_r'
                'seismic'    Blue-white-red
                'Blues'      white-to-blue
                'Reds'       white-to-red

    cii_alpha:  [FLOAT], Confidence-interval parameter (methodtype='lifeline'
                only), passed to lifelines as ``alpha=1 - cii_alpha``.
                0.05 (default)

    cii_lines:  [STRING], How confidence intervals are drawn (methodtype='lifeline' only)
                'dense'    (default) filled bands
                'line'     interval bounds drawn as lines
                None or '' no interval (cii_alpha is then set to 0)

    methodtype: [STRING], Implementation type
                'lifeline' (default) lifelines KaplanMeierFitter plus at-risk counts
                'custom'   compute_coord / plotkm helpers

    title:      [STRING], Title prefix; shown, with the log-rank p-value
                appended, when out['logrank'] is non-empty.

    Returns
    -------
    None.

    '''
    KMcoord = {}
    Param = {}
    Param['width'] = width
    Param['height'] = height
    Param['fontsize'] = fontsize
    Param['savepath'] = savepath
    labx = out['labx']

    # Combine data and gather class labels
    data = np.vstack((out['time_event'], out['censoring'])).T

    # Make colors and legend-names for class-labels
    [class_colors, classlabel] = make_class_color_names(data,
                                                        out['labx'],
                                                        out['uilabx'],
                                                        cmap=cmap)

    if methodtype == 'lifeline':
        kmf_all = []

        fig = plt.figure(figsize=(Param['width'], Param['height']))
        ax = fig.add_subplot(111)
        if out['logrank'] != []:
            plt.title('%s, Logrank Test P-Value = %.5f' %
                      (title, out['logrank_P']))

        # Translate the cii_lines option into lifelines' ci_force_lines flag.
        if cii_lines == 'dense':
            cii_lines = False
        if cii_lines == 'line':
            cii_lines = True
        if cii_lines == '' or cii_lines is None or cii_alpha is None:
            cii_lines = False
            cii_alpha = 0

        for i, class_value in enumerate(out['uilabx']):
            kmf = KaplanMeierFitter()
            idx = np.where(labx == class_value)[0]
            # NOTE(review): `alpha=1 - cii_alpha` matches lifelines releases in
            # which alpha was the confidence level (e.g. 0.95). In current
            # lifelines alpha is the significance level (default 0.05), so
            # this would draw a 5% interval - confirm the installed version.
            kmf.fit(out['time_event'][idx],
                    event_observed=out['censoring'][idx],
                    label=classlabel[i],
                    ci_labels=None,
                    alpha=1 - cii_alpha)
            kmf.plot(ax=ax,
                     ci_force_lines=cii_lines,
                     color=class_colors[i],
                     show_censors=True)
            # fit() returns the fitter itself, so there is no need to refit.
            kmf_all.append(kmf)

        # At-risk tables are only drawn for 1 to 10 classes.
        if 1 <= len(kmf_all) <= 10:
            add_at_risk_counts(*kmf_all, ax=ax)
        else:
            print('[KM] Maximum of 10 classes is reached.')

        ax.tick_params(axis='x',
                       length=15,
                       width=1,
                       direction='out',
                       labelsize=Param['fontsize'])
        ax.tick_params(axis='y',
                       length=15,
                       width=1,
                       direction='out',
                       labelsize=Param['fontsize'])
        ax.spines['bottom'].set_position(['outward', Param['fontsize']])
        ax.spines['left'].set_position(['outward', Param['fontsize']])

        if Param['savepath'] != '':
            savefig(fig, Param['savepath'])

    if methodtype == 'custom':
        # Compute KM survival coordinates per class
        for i, class_value in enumerate(out['uilabx']):
            idx = np.where(labx == class_value)[0]
            KMcoord[i] = compute_coord(data[idx, :].tolist())

        # Plot KM survival lines
        plotkm(KMcoord,
               classlabel,
               cmap=class_colors,
               width=Param['width'],
               height=Param['height'],
               fontsize=Param['fontsize'])
Beispiel #16
0
              label='Train set')
# Draw the train-set curve, overlay the test-set curve on the same axis
# (censor marks shown, confidence bands hidden), add a shared at-risk table
# and label the axes.
# NOTE(review): this starts mid-cell - kmf_train's fit call, `ax`, `t`
# and `test` are defined above this view.
ax = kmf_train.plot_survival_function(show_censors=True,
                                      ci_force_lines=False,
                                      ci_show=False,
                                      ax=ax)

kmf_test.fit(test['PFS'],
             event_observed=test['disease_progress'],
             timeline=t,
             label='Test set')
ax = kmf_test.plot_survival_function(show_censors=True,
                                     ci_force_lines=False,
                                     ci_show=False,
                                     ax=ax)

add_at_risk_counts(kmf_train, kmf_test, ax=ax, fontsize=12)

ax.set_xlabel('Time (months)', fontsize=12, fontweight='bold')
ax.set_ylabel('Survival probability', fontsize=12, fontweight='bold')
ax.legend(fontsize=12)

# NOTE(review): the p-value is a hard-coded literal, not computed here;
# it must be updated by hand if the data changes.
ax.text(10, 0.75, 'p=0.85', fontsize=12, fontweight='bold')

# In[27]:

# Median survival time of each cohort.
print(kmf_train.median_survival_time_)
print(kmf_test.median_survival_time_)

# In[28]:

# Confidence interval of the train-set median survival time.
print(median_survival_times(kmf_train.confidence_interval_))
def plot_category_kmf(survival_df,
                      cat_col,
                      cat_list=None,
                      labels=None,
                      title='',
                      xlabel='',
                      ax=None,
                      q=None,
                      SZ=14,
                      weights=None,
                      xticks=None,
                      return_data=False):
    """Plot one Kaplan-Meier curve per category of ``cat_col``, with an at-risk table.

    Parameters
    ----------
    survival_df : pandas.DataFrame
        Survival data; passed to ``get_kmf`` together with a row mask.
    cat_col : str
        Column defining the categories.
    cat_list : sequence, optional
        Category values to plot, in this order. Ignored when ``q`` is given.
        When None (and ``q`` is None), the sorted non-missing values of
        ``cat_col`` are used.
    labels : sequence of str, optional
        Legend label per category; defaults to ``str(category)``.
    title, xlabel : str
        Axis title and x-axis label.
    ax : matplotlib Axes, optional
        Axis to draw on; a new 8x5 figure is created when None.
    q : int or sequence, optional
        If given, ``cat_col`` is binned with ``pd.qcut`` into a new column
        ``cat_col + '_cat'``. That column is added to ``survival_df`` in place.
    SZ : int
        Base font size (the title uses SZ + 2).
    weights : optional
        Forwarded to ``get_kmf``.
    xticks : sequence, optional
        Explicit x tick positions, set before the at-risk table is added.
    return_data : bool
        If True, also collect each curve's survival function and confidence
        interval.

    Returns
    -------
    (DataFrame, list)
        When ``return_data`` is True: survival functions merged on
        'timeline', and the list of confidence-interval frames.
    (None, None)
        Otherwise.
    """
    from functools import reduce

    if q is not None:
        survival_df[cat_col + '_cat'] = pd.qcut(survival_df[cat_col],
                                                q,
                                                precision=1)
        cat_col = cat_col + '_cat'
        cat_list = survival_df[cat_col].cat.categories
    elif cat_list is None:
        # No explicit categories: use the observed values in sorted order.
        cat_list = sorted(survival_df[cat_col].dropna().unique())
    if labels is None:
        # str() per item works for plain lists as well as arrays / Index.
        labels = [str(x) for x in cat_list]
    categories = list(zip(cat_list, labels))

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(8, 5))

    kmfs = []
    survival_plot_dfs = []
    survival_ci_dfs = []
    for cat, label in categories:
        idx = (survival_df[cat_col] == cat)
        cat_kmf = get_kmf(survival_df, idx, label, weights)
        kmfs.append(cat_kmf)
        cat_kmf.plot(ax=ax)
        if return_data:
            survival_plot_dfs.append(cat_kmf.survival_function_)
            survival_ci_dfs.append(
                cat_kmf.confidence_interval_survival_function_)

    # Labels are applied after all curves so later plot() calls cannot override them.
    ax.set_title(title, size=SZ + 2)
    ax.set_xlabel(xlabel, size=SZ)
    ax.set_ylabel('Survival Probability', size=SZ)

    if xticks is not None:
        ax.set_xticks(xticks)
    add_at_risk_counts(*kmfs, ax=ax)

    if return_data:
        survival_plot_df = reduce(
            lambda df1, df2: pd.merge(df1, df2, on='timeline'),
            survival_plot_dfs)
        return survival_plot_df, survival_ci_dfs

    return None, None
Beispiel #18
0
def plot(out,
         fontsize=12,
         savepath='',
         width=10,
         height=6,
         cmap='Set1',
         cii_alpha=0.05,
         cii_lines='dense',
         methodtype='lifeline',
         title='Survival function',
         full_ylim=False,
         y_percentage=False):
    """Plot Kaplan-Meier survival curves, one per class label in ``out``.

    Parameters
    ----------
    out : dict
        Results from the fit function. Keys read here: ``'labx'`` (class
        label per sample), ``'uilabx'`` (class labels to plot, presumably the
        unique values of ``'labx'``), ``'time_event'``, ``'censoring'``
        (passed to lifelines as ``event_observed``), ``'logrank'`` and
        ``'logrank_P'``.
    fontsize : int, optional
        Tick-label font size; also used as the outward offset of the bottom
        and left spines. The default is 12.
    savepath : str, optional
        When non-empty, the figure is saved to this path
        (``methodtype='lifeline'`` only). The default is ''.
    width : int, optional
        Width of the figure. The default is 10.
    height : int, optional
        Height of the figure. The default is 6.
    cmap : str or list of tuple, optional
        Colormap name or explicit colors per class label, e.g.
        ``[(1, 0, 0), (0, 0, 1), ...]``. See
        https://matplotlib.org/examples/color/colormaps_reference.html.
        The default is 'Set1'. Examples:
        'Set1'       (default)
        'Set2'       Discrete colors
        'Pastel1'    Discrete colors
        'Paired'     Discrete colors
        'rainbow'
        'bwr'        Blue-white-red
        'binary' or 'binary_r'
        'seismic'    Blue-white-red
        'Blues'      white-to-blue
        'Reds'       white-to-red
    cii_alpha : float or None, optional
        Confidence-interval setting; lifelines receives
        ``alpha=1 - cii_alpha``. ``None`` is treated as 0. Only used when
        ``methodtype='lifeline'``. The default is 0.05.
    cii_lines : str or None, optional
        How confidence intervals are drawn (``methodtype='lifeline'`` only).
        The default is 'dense'.
        'dense'       filled band (``ci_force_lines=False``)
        'line'        interval drawn as lines (``ci_force_lines=True``)
        '' or None    no forced lines, and ``cii_alpha`` is set to 0
    methodtype : str, optional
        Implementation type. The default is 'lifeline'.
        'lifeline'    lifelines ``KaplanMeierFitter`` curves plus at-risk counts
        'custom'      coordinates from ``compute_coord`` drawn by ``plotkm``
        Any other value draws nothing.
    title : str, optional
        Title prefix. It is shown as ``"<title>, Logrank Test P-Value = <p>"``
        only when ``out['logrank']`` is not ``[]`` and
        ``methodtype='lifeline'``; otherwise no title is set.
        The default is 'Survival function'.
    full_ylim : bool, optional
        If True, fix the y-axis limits to [0, 1.05] (lifeline only).
        The default is False.
    y_percentage : bool, optional
        If True, format y tick labels as percentages of 1.0 (lifeline only).
        The default is False.

    Returns
    -------
    None.

    """
    labx = out['labx']

    # One row per sample: (event time, censoring flag).
    data = np.vstack((out['time_event'], out['censoring'])).T

    # Make colors and legend-names for class-labels
    class_colors, classlabel = make_class_color_names(data,
                                                      out['labx'],
                                                      out['uilabx'],
                                                      cmap=cmap)

    if methodtype == 'lifeline':
        kmf_all = []

        # Startup figure
        fig = plt.figure(figsize=(width, height))
        ax = fig.add_subplot(111)
        if full_ylim:
            ax.set_ylim([0.0, 1.05])
        if y_percentage:
            ax.yaxis.set_major_formatter(PercentFormatter(1.0))
        if out['logrank'] != []:
            plt.title('%s, Logrank Test P-Value = %.5f' %
                      (title, out['logrank_P']))

        # Translate the cii_lines option into lifelines' ci_force_lines flag.
        if cii_lines == 'dense':
            cii_lines = False
        if cii_lines == 'line':
            cii_lines = True
        if cii_lines == '' or cii_lines is None or cii_alpha is None:
            cii_lines = False
            cii_alpha = 0

        for i, uilab in enumerate(out['uilabx']):
            idx = np.where(labx == uilab)[0]
            kmf = KaplanMeierFitter()
            # NOTE(review): ``alpha=1 - cii_alpha`` treats lifelines' alpha as
            # a confidence level (older lifelines); newer releases interpret
            # alpha as the significance level -- confirm the pinned version.
            kmf.fit(out['time_event'][idx],
                    event_observed=out['censoring'][idx],
                    label=classlabel[i],
                    ci_labels=None,
                    alpha=(1 - cii_alpha))
            kmf.plot(ax=ax,
                     ci_force_lines=cii_lines,
                     color=class_colors[i],
                     show_censors=True)
            # fit() updates the fitter in place, so store it instead of
            # refitting the same data a second time.
            kmf_all.append(kmf)

        add_at_risk_counts(*kmf_all, ax=ax)

        ax.tick_params(axis='both',
                       length=15,
                       width=1,
                       direction='out',
                       labelsize=fontsize)
        ax.spines['bottom'].set_position(['outward', fontsize])
        ax.spines['left'].set_position(['outward', fontsize])

        if savepath != '':
            savefig(fig, savepath)

    elif methodtype == 'custom':
        # Compute KM survival coordinates per class
        KMcoord = {}
        for i, uilab in enumerate(out['uilabx']):
            idx = np.where(labx == uilab)[0]
            tmpdata = data[idx, :].tolist()
            KMcoord[i] = compute_coord(tmpdata)

        # Plot KM survival lines
        plotkm(KMcoord,
               classlabel,
               cmap=class_colors,
               width=width,
               height=height,
               fontsize=fontsize)
Beispiel #19
0
# Split the cohort once per group, then take durations and event flags.
group1 = df.query('group==1')
group2 = df.query('group==2')

T1 = group1['time']
T2 = group2['time']

C1 = group1['Died']
C2 = group2['Died']

fig, ax = plt.subplots()
kmf1 = KaplanMeierFitter()
kmf1.fit(T1, C1, label='Group 1')
kmf2 = KaplanMeierFitter()
kmf2.fit(T2, C2, label='Group 2')

# Both curves share one axes; each fitter already carries its legend label
# from fit(), so no label override is needed here.
ax = kmf1.plot_survival_function(ax=ax, ci_show=False)
ax = kmf2.plot_survival_function(ax=ax, ci_show=False)
add_at_risk_counts(kmf1, kmf2, ax=ax)

plt.tight_layout()
# %% [markdown]
# ## Kaplan-Meier curve using _lifelines_
#
# We can also format the axes to show percentages rather than proportions. 
#
# We'll demo this with a single KM curve
# %%
import matplotlib.ticker as mtick

# NOTE(review): ``kmf`` is not defined in this snippet; presumably a single
# fitted KaplanMeierFitter from an earlier cell -- confirm.
ax = kmf.plot()
ax.set_xlabel('days')
ax.set_ylabel('Probability of survival')
# Render the survival proportions (0-1) as percentages, as this cell intends.
ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1.0))
ax.get_legend().remove()
# NOTE(review): this snippet targets an older lifelines release
# (``CoxPHFitter(normalize=...)``, ``fit(..., include_likelihood=...)``, and
# ``alpha`` given as a confidence level) -- confirm the pinned version.
df = df[pd.notnull(df[survival_col])]


tx = df['history_of_neoadjuvant_treatment'] == 'Yes'
ax = plt.subplot(111)

# ``.loc`` replaces the removed ``.ix`` indexer: with a boolean row mask and a
# column label it selects the same values.
kmf1 = KaplanMeierFitter(alpha=0.95)
kmf1.fit(durations=df.loc[tx, survival_col],
         event_observed=df.loc[tx, censor_col],
         label='Tx==Yes')
kmf1.plot(ax=ax, show_censors=True, ci_show=False)


kmf2 = KaplanMeierFitter(alpha=0.95)
kmf2.fit(durations=df.loc[~tx, survival_col],
         event_observed=df.loc[~tx, censor_col],
         label='Tx==No')
kmf2.plot(ax=ax, show_censors=True, ci_show=False)

add_at_risk_counts(kmf1, kmf2, ax=ax)
plt.title('Acute myeloid leukemia survival analysis with Tx and without Tx')
plt.xlabel(survival_col)
plt.savefig('km.png')

results = logrank_test(df.loc[tx, survival_col], df.loc[~tx, survival_col],
                       df.loc[tx, censor_col], df.loc[~tx, censor_col],
                       alpha=.99)
results.print_summary()

cox = CoxPHFitter(normalize=False)
df_age = df[[survival_col, censor_col, 'age_at_initial_pathologic_diagnosis']]
df_age = df_age[pd.notnull(df_age['age_at_initial_pathologic_diagnosis'])]
cox = cox.fit(df_age, survival_col, event_col=censor_col, include_likelihood=True)
cox.print_summary()

scores = k_fold_cross_validation(cox, df_age, survival_col, event_col=censor_col, k=10)
# Function-call form works under both Python 2 and Python 3.
print(scores)