Example #1
0
def KM_median(array,
              upper_lim_flags,
              left_censor=True,
              return_type='percentile'):
    """Fit a Kaplan-Meier estimator to ``array`` and summarise its median.

    Parameters
    ----------
    array : array-like
        Values handed to the fitter as its first (durations) argument.
    upper_lim_flags : array-like or None
        Forwarded unchanged as the fitter's ``event_observed`` argument.
        When ``None``, a plain ``fit`` is performed regardless of
        ``left_censor``.
        NOTE(review): the parameter name suggests "True means upper limit",
        while ``event_observed`` presumably means "True means observed" --
        confirm that callers pass flags with the intended polarity.
    left_censor : bool
        When truthy (and flags are given) use ``fit_left_censoring``;
        otherwise use ``fit``.
    return_type : {'percentile', 'ci', 'median'}
        Selects what is returned (see Returns).

    Returns
    -------
    'percentile'
        ``(median, kmf.percentile(0.25), kmf.percentile(0.75))``. Note that
        the printed message lists the 0.75 value before the 0.25 value; the
        returned order is kept as-is for existing callers.
    'ci'
        ``(median, ci[0][0], ci[0][1])`` where ``ci`` is the result of
        ``median_survival_times`` applied to ``kmf.confidence_interval_``.
    'median'
        The median alone.

    Raises
    ------
    ValueError
        If ``return_type`` is not one of the supported values. This is
        checked before fitting so that a typo fails fast instead of
        silently returning ``None``.
    """
    valid_return_types = ('percentile', 'ci', 'median')
    if return_type not in valid_return_types:
        raise ValueError(f"unknown return_type: {return_type!r},"
                         f" use one of {valid_return_types}")

    kmf = KaplanMeierFitter()

    if upper_lim_flags is not None and left_censor:
        kmf.fit_left_censoring(array, upper_lim_flags)
    else:
        # Covers both right censoring and the "no flags" case, where
        # event_observed is simply None.
        kmf.fit(array, event_observed=upper_lim_flags)

    median = median_survival_times(kmf.survival_function_)

    if return_type == 'median':
        return median

    if return_type == 'ci':
        median_ci = median_survival_times(kmf.confidence_interval_).values
        print(f'median and CI: {median}, {median_ci}')
        return median, median_ci[0][0], median_ci[0][1]

    # return_type == 'percentile'
    upper_perc = kmf.percentile(0.25)
    lower_perc = kmf.percentile(0.75)

    print(
        f'median and 1st/3rd quartiles: {median}, {lower_perc}, {upper_perc}'
    )
    return median, upper_perc, lower_perc
Example #2
0
    def test_kmf_left_censorship_plots(self, block):
        """Overlay left-censored KM fits for two ``load_lcd`` groups on one axes.

        Each group is selected by its ``"group"`` column value, fitted with
        ``fit_left_censoring`` on the ``"T"`` / ``"E"`` columns, and plotted;
        the figure is then titled and shown via ``self.plt``.
        """
        lcd_dataset = load_lcd()
        kmf = KaplanMeierFitter()

        ax = None
        for group_name in ("alluvial_fan", "basin_trough"):
            subset = lcd_dataset.loc[lcd_dataset["group"] == group_name]
            kmf.fit_left_censoring(subset["T"], subset["E"], label=group_name)
            # The first curve creates the axes; later curves are drawn onto it.
            ax = kmf.plot() if ax is None else kmf.plot(ax=ax)

        self.plt.title("test_kmf_left_censorship_plots")
        self.plt.show(block=block)
def estimate_survival(df, plot=True, censoring='right', fig_path=None):
    """Fit one Kaplan-Meier curve per ``domain_system`` and collect medians.

    Parameters
    ----------
    df : DataFrame
        Must provide ``domain_system``, ``time`` and ``spotted`` columns.
        NOTE(review): for ``censoring='left'`` the ``spotted`` column is
        inverted with ``~``, which presumably requires a boolean dtype --
        confirm against the caller.
    plot : bool
        When true, draw every fitted survival function on a shared axes and
        save the figure.
    censoring : {'right', 'left'}
        ``'right'`` fits with ``fit(time, spotted)``; ``'left'`` fits with
        ``fit_left_censoring(time, ~spotted)``.
    fig_path : str or None
        Destination of the saved figure; falls back to ``"survival.png"``.

    Returns
    -------
    dict
        Maps each system to the fitter's ``median_survival_time_``.

    Raises
    ------
    ValueError
        If ``censoring`` is not ``'right'`` or ``'left'``.
    """
    if censoring not in {'right', 'left'}:
        raise ValueError(f"unknown fit type: {censoring},"
                         f" use one of {{'left', 'right'}}")

    kmf = KaplanMeierFitter(alpha=1.0)  # disable confidence interval

    fig = ax = None
    if plot:
        fig = plt.figure(figsize=(20, 10))
        ax = fig.add_subplot(111)

    medians = {}
    for system in sorted(df.domain_system.unique()):
        # Select this system's rows once and reuse them for both columns.
        rows = df.loc[df['domain_system'] == system]

        # `censoring` was validated above, so only two cases remain.
        if censoring == 'right':
            kmf.fit(rows.time, rows.spotted, label=system)
        else:
            kmf.fit_left_censoring(rows.time, ~rows.spotted, label=system)

        if plot:
            kmf.plot_survival_function(ax=ax)

        medians[system] = kmf.median_survival_time_

    if plot:
        # Address the figure/axes created above directly instead of relying
        # on pyplot's notion of the "current" figure.
        ax.set_ylim(0.0, 1.0)
        ax.set_xlabel("Turns")
        ax.set_ylabel("Survival Probability")
        ax.set_title("Estimated Survival Function of different systems")
        save_path = fig_path or "survival.png"
        print(f'saving plot of estimated survival functions to: {save_path}')
        fig.savefig(save_path)

    return medians
Example #4
0
def plot_KM(arrays,
            labels,
            upper_lim_flags,
            savepath,
            left_censor=True,
            cdf=False,
            plot_quantity='Mdust',
            noerr_inds=None):
    """Plot Kaplan-Meier curves for several samples on one log-x axes.

    Parameters
    ----------
    arrays : sequence
        One data array per curve.
    labels : sequence of str
        Label for each curve, indexed in step with ``arrays``.
    upper_lim_flags : sequence
        Per-curve flags forwarded as the fitter's ``event_observed``
        argument; an entry of ``None`` means a plain ``fit`` for that curve.
    savepath : str
        Where the figure is written.
    left_censor : bool
        When truthy (and flags are given) use ``fit_left_censoring``;
        otherwise use ``fit``.
    cdf : bool
        When truthy, delegate drawing to ``kmf.plot``; otherwise draw the
        survival function and its 0.95 confidence band manually.
    plot_quantity : {'Mdust', 'Rdisk'}
        Chooses the axis limits and labels; other values leave them unset.
    noerr_inds : sequence of int, optional
        Curve indices drawn without a confidence band. If it has a second
        entry, the region between that curve and the one plotted just
        before it is shaded instead. Defaults to no exclusions.
    """
    # Avoid a shared mutable default; an empty tuple behaves like "none".
    noerr_inds = () if noerr_inds is None else noerr_inds
    # Only the second entry triggers between-curve shading. Guarding the
    # lookup keeps the default (empty) case from raising IndexError.
    shade_ind = noerr_inds[1] if len(noerr_inds) > 1 else None

    kmf = KaplanMeierFitter()

    fig = plt.figure(figsize=(10, 10))
    ax = plt.axes()

    colors = [
        'tab:red', 'tab:blue', 'tab:green', 'tab:orange', 'tab:purple', 'gray',
        'brown'
    ]

    prev_size = prev_prob = None
    for ind, values in enumerate(arrays):
        label = labels[ind]
        flags = upper_lim_flags[ind]
        # Cycle through the palette so more than len(colors) curves still work.
        col = colors[ind % len(colors)]
        print(label)

        if flags is not None and left_censor:
            kmf.fit_left_censoring(values, flags, label=label)
        else:
            # Right censoring, or no flags at all (event_observed=None).
            kmf.fit(values, event_observed=flags, label=label)

        if cdf:
            kmf.plot(ax=ax)
            continue

        # Drop the first row of the fitted tables before plotting.
        survival = kmf.survival_function_[label]
        size = np.asarray(survival.index)[1:]
        prob = survival.values[1:]

        bounds = kmf.confidence_interval_survival_function_
        lower = np.array(bounds[f'{label}_lower_0.95'].values[1:])
        upper = np.array(bounds[f'{label}_upper_0.95'].values[1:])

        ax.plot(size, prob, label=label, color=col)
        if ind not in noerr_inds:
            ax.fill_between(size, lower, upper, color=col, alpha=0.25)

        # Shade between this curve and the previous one; needs a previous
        # curve to exist, so the very first curve can never trigger it.
        if ind == shade_ind and prev_size is not None:
            interp_new = interp1d(size, prob)
            interp_prob = interp_new(prev_size)
            ax.fill_between(prev_size,
                            prev_prob,
                            interp_prob,
                            color=col,
                            alpha=0.25)

        prev_size, prev_prob = size, prob

    ax.legend(fontsize=16)

    ax.set_xscale('log')
    ax.tick_params(axis='both', which='major', labelsize=14)

    if plot_quantity == 'Mdust':
        ax.set_xlim(0.03, 2000)
        ax.set_xlabel(r'$M_{dust}$ ($M_{\oplus}$)', fontsize=18)
        ax.set_ylabel(r'$P \geq M_{dust}$', fontsize=18)
    elif plot_quantity == 'Rdisk':
        ax.set_xlim(3, 200)
        ax.set_xlabel(r'$R_{disk}$ (AU)', fontsize=18)
        ax.set_ylabel(r'$P \geq R_{disk}$', fontsize=18)

    fig.savefig(savepath)
    print(f'saved figure at {savepath}')