def KM_median(array, upper_lim_flags, left_censor=True, return_type='percentile'):
    """Fit a Kaplan-Meier estimator to ``array`` and return its median.

    Args:
        array: Values handed to the fitter as durations.
        upper_lim_flags: Per-element flags forwarded to the fitter as its
            event-observed argument, or ``None`` to fit without flags.
            NOTE(review): the name suggests upper-limit markers; confirm
            the caller's True/False convention matches what lifelines
            expects for ``event_observed``.
        left_censor: When truthy and flags are given, fit with
            ``fit_left_censoring``; otherwise use the ordinary ``fit``.
        return_type: One of ``'percentile'``, ``'ci'`` or ``'median'``.

    Returns:
        * ``'percentile'``: ``(median, kmf.percentile(0.25),
          kmf.percentile(0.75))`` -- in that order.
        * ``'ci'``: ``(median, lo, hi)`` where ``lo``/``hi`` are the first
          and second entries of the first row of
          ``median_survival_times(kmf.confidence_interval_)``.
        * ``'median'``: the median alone.

        The median is whatever ``median_survival_times`` returns for the
        fitted survival function.

    Raises:
        ValueError: If ``return_type`` is not one of the supported values.
            Checked before any fitting so a typo fails fast instead of
            silently returning ``None``.
    """
    valid_return_types = ('percentile', 'ci', 'median')
    if return_type not in valid_return_types:
        raise ValueError(
            f"unknown return_type: {return_type!r},"
            f" use one of {valid_return_types}"
        )

    kmf = KaplanMeierFitter()
    if upper_lim_flags is not None and left_censor:
        kmf.fit_left_censoring(array, upper_lim_flags)
    else:
        # Right censoring. When the flags are None the fitter's own default
        # for event_observed applies.
        kmf.fit(array, event_observed=upper_lim_flags)

    median = median_survival_times(kmf.survival_function_)

    if return_type == 'percentile':
        upper_perc = kmf.percentile(0.25)
        lower_perc = kmf.percentile(0.75)
        print(
            f'median and 1st/3rd quartiles: {median}, {lower_perc}, {upper_perc}'
        )
        return median, upper_perc, lower_perc

    if return_type == 'ci':
        median_ci = median_survival_times(kmf.confidence_interval_).values
        print(f'median and CI: {median}, {median_ci}')
        return median, median_ci[0][0], median_ci[0][1]

    # Only 'median' remains after the validation above.
    return median
def test_kmf_left_censorship_plots(self, block):
    """Draw left-censored Kaplan-Meier fits for two LCD groups on one axes.

    The same fitter is refit for the ``alluvial_fan`` and ``basin_trough``
    rows of the dataset returned by ``load_lcd()``; the second plot reuses
    the axes returned by the first. ``block`` is forwarded to
    ``self.plt.show``.
    """
    fitter = KaplanMeierFitter()
    lcd = load_lcd()
    group_column = lcd["group"]
    fan_rows = lcd.loc[group_column == "alluvial_fan"]
    trough_rows = lcd.loc[group_column == "basin_trough"]

    # First group creates the axes ...
    fitter.fit_left_censoring(fan_rows["T"], fan_rows["E"], label="alluvial_fan")
    shared_ax = fitter.plot()

    # ... and the second group is overlaid on them.
    fitter.fit_left_censoring(
        trough_rows["T"], trough_rows["E"], label="basin_trough"
    )
    fitter.plot(ax=shared_ax)

    self.plt.title("test_kmf_left_censorship_plots")
    self.plt.show(block=block)
def estimate_survival(df, plot=True, censoring='right', fig_path=None):
    """Fit one Kaplan-Meier curve per ``domain_system`` and collect medians.

    Args:
        df: Table with ``domain_system``, ``time`` and ``spotted`` columns.
            ``time`` is passed as the durations. ``spotted`` is passed as
            the event flags for right censoring and inverted with ``~`` for
            left censoring (presumably boolean -- ``~`` on an integer
            column would not negate it logically; verify against caller).
        plot: When truthy, draw every fitted survival function on a shared
            axes and save the figure.
        censoring: ``'right'`` (uses ``fit``) or ``'left'`` (uses
            ``fit_left_censoring``).
        fig_path: Where to save the figure; falls back to
            ``"survival.png"`` when falsy.

    Returns:
        dict mapping each system (visited in sorted order) to the fitter's
        ``median_survival_time_`` after fitting that system's rows.

    Raises:
        ValueError: If ``censoring`` is neither ``'right'`` nor ``'left'``.
    """
    if censoring not in {'right', 'left'}:
        raise ValueError(f"unknown fit type: {censoring},"
                         f" use one of {{'left', 'right'}}")

    kmf = KaplanMeierFitter(alpha=1.0)  # disable confidence interval

    ax = None
    if plot:
        ax = plt.figure(figsize=(20, 10)).add_subplot(111)

    medians = {}
    for system in sorted(df.domain_system.unique()):
        # Select this system's rows once and reuse them for both columns.
        rows = df.loc[df['domain_system'] == system]
        if censoring == 'right':
            kmf.fit(rows.time, rows.spotted, label=system)
        else:
            # `censoring` was validated above, so this is the 'left' case.
            kmf.fit_left_censoring(rows.time, ~rows.spotted, label=system)

        if plot:
            kmf.plot_survival_function(ax=ax)
        medians[system] = kmf.median_survival_time_

    if plot:
        plt.ylim(0.0, 1.0)
        plt.xlabel("Turns")
        plt.ylabel("Survival Probability")
        plt.title("Estimated Survival Function of different systems")
        save_path = fig_path or "survival.png"
        print(f'saving plot of estimated survival functions to: {save_path}')
        plt.savefig(save_path)

    return medians
def plot_KM(arrays, labels, upper_lim_flags, savepath, left_censor=True,
            cdf=False, plot_quantity='Mdust', noerr_inds=None):
    """Plot Kaplan-Meier curves for several samples and save the figure.

    Args:
        arrays: Sequence of samples; one curve is fitted and drawn per entry.
        labels: Label for each sample (printed, used as the fit label and
            legend entry, and used to look up the fitted columns).
        upper_lim_flags: Per-sample flags forwarded to the fitter as its
            event-observed argument; an entry of ``None`` fits that sample
            without flags.
        savepath: Destination passed to ``plt.savefig``.
        left_censor: When truthy and flags are given, use
            ``fit_left_censoring``; otherwise use ``fit``.
        cdf: When truthy, draw each fit with ``kmf.plot(ax=ax)`` instead of
            the manual curve/band drawing below.
        plot_quantity: ``'Mdust'`` or ``'Rdisk'`` selects x-limits and axis
            labels; any other value leaves them untouched.
        noerr_inds: Indices of samples whose confidence band is not drawn.
            If it holds at least two entries, the sample at
            ``noerr_inds[1]`` additionally gets a filled region between its
            curve (interpolated onto the previous sample's x values) and
            the previous sample's curve. ``None`` or an empty sequence
            means every sample gets a confidence band and no such region
            is drawn.

    Side effects:
        Creates a new 10x10 figure, saves it to ``savepath`` and prints each
        label plus a confirmation message.
    """
    # A None default avoids a shared mutable default argument; treating it as
    # empty also keeps the `noerr_inds[1]` lookup from raising IndexError.
    if noerr_inds is None:
        noerr_inds = ()
    # Only the second entry triggers the between-curves fill.
    band_ind = noerr_inds[1] if len(noerr_inds) > 1 else None

    kmf = KaplanMeierFitter()
    plt.figure(figsize=(10, 10))
    ax = plt.axes()
    colors = [
        'tab:red', 'tab:blue', 'tab:green', 'tab:orange', 'tab:purple',
        'gray', 'brown'
    ]

    prev_size = None
    prev_prob = None
    for ind, values in enumerate(arrays):
        label = labels[ind]
        flags = upper_lim_flags[ind]
        # Cycle through the palette so more than len(colors) samples still plot.
        col = colors[ind % len(colors)]
        print(label)

        if flags is not None and left_censor:
            kmf.fit_left_censoring(values, flags, label=label)
        else:
            # Right censoring. When the flags are None the fitter's own
            # default for event_observed applies.
            kmf.fit(values, event_observed=flags, label=label)

        if cdf:
            kmf.plot(ax=ax)
            continue

        survival = kmf.survival_function_[label]
        bounds = kmf.confidence_interval_survival_function_
        # The first row of each fitted table is skipped.
        size = np.array(survival.axes).flatten()[1:]
        prob = survival.values[1:]
        lower = np.array(bounds[f'{label}_lower_0.95'].values[1:])
        upper = np.array(bounds[f'{label}_upper_0.95'].values[1:])

        ax.plot(size, prob, label=label, color=col)
        if ind not in noerr_inds:
            ax.fill_between(size, lower, upper, color=col, alpha=0.25)
        if ind == band_ind and prev_size is not None:
            # Shade the gap between this curve and the previous one, sampled
            # at the previous curve's x values.
            interp_prob = interp1d(size, prob)(prev_size)
            ax.fill_between(prev_size, prev_prob, interp_prob, color=col,
                            alpha=0.25)

        prev_size, prev_prob = size, prob

    plt.legend(fontsize=16)
    ax.set_xscale('log')
    ax.tick_params(axis='both', which='major', labelsize=14)
    if plot_quantity == 'Mdust':
        ax.set_xlim(0.03, 2000)
        ax.set_xlabel(r'$M_{dust}$ ($M_{\oplus}$)', fontsize=18)
        ax.set_ylabel(r'$P \geq M_{dust}$', fontsize=18)
    elif plot_quantity == 'Rdisk':
        ax.set_xlim(3, 200)
        ax.set_xlabel(r'$R_{disk}$ (AU)', fontsize=18)
        ax.set_ylabel(r'$P \geq R_{disk}$', fontsize=18)

    plt.savefig(savepath)
    print(f'saved figure at {savepath}')