Exemple #1
0
def fit_cox(
    train_df: Union[pd.DataFrame, str],
    covariates: List[str],
    test_df: Union[pd.DataFrame, str] = None,
    strata: List[str] = None,
    plot: bool = False,
    process_dir: str = None,
):
    """Fit a Cox proportional-hazards model and collect summary metrics.

    Args:
        train_df: Training data, or a path to a CSV file holding it. Must
            contain the columns ``duration`` and ``event`` plus every
            covariate and stratum column.
        covariates: Names of the columns used as regressors.
        test_df: Optional held-out data (frame or CSV path) with the same
            columns; when given, the fitted model is scored on it.
        strata: Optional column name(s) forwarded to ``CoxPHFitter.fit`` as
            ``strata``.
        plot: When true and ``process_dir`` is set, save the coefficient plot
            as ``hazard_plot.pdf``.
        process_dir: Optional output directory for ``hazard_plot.pdf``,
            ``summary.csv`` and ``results.json``.

    Returns:
        A ``(results, cphf)`` tuple: the metrics dictionary and the fitted
        ``CoxPHFitter``.
    """
    if isinstance(train_df, str):
        train_df = pd.read_csv(train_df)
    if isinstance(test_df, str):
        test_df = pd.read_csv(test_df)

    if strata is None:
        strata_cols = []
    elif isinstance(strata, str):
        strata_cols = [strata]
    else:
        strata_cols = list(strata)

    # The stratum columns have to be part of the frame handed to the fitter,
    # otherwise stratified fitting fails unless the caller also lists them as
    # covariates. Duplicates are skipped so a column is never selected twice.
    included_cols = ["duration", "event"]
    for col in list(covariates) + strata_cols:
        if col not in included_cols:
            included_cols.append(col)

    cphf = CoxPHFitter()
    cphf.fit(
        train_df[included_cols],
        duration_col="duration",
        event_col="event",
        strata=strata,
    )

    results = {
        "log_likelihood": cphf.log_likelihood_,
        "concordance_index": cphf.concordance_index_,
        "log_likelihood_ratio_test_pvalue":
        cphf.log_likelihood_ratio_test().p_value,
    }

    if test_df is not None:
        test_subset = test_df[included_cols]
        results["test_log_likelihood"] = cphf.score(
            test_subset, scoring_method="log_likelihood")
        results["test_concordance_index"] = cphf.score(
            test_subset, scoring_method="concordance_index")

    if plot and process_dir is not None:
        fig = plt.figure(figsize=(5, 10))
        try:
            cphf.plot()
            fig.savefig(os.path.join(process_dir, "hazard_plot.pdf"))
        finally:
            # Release the figure so repeated calls do not accumulate them.
            plt.close(fig)

    if process_dir is not None:
        cphf.summary.to_csv(os.path.join(process_dir, "summary.csv"))
        save_dict_to_json(os.path.join(process_dir, "results.json"), results)

    return results, cphf
Exemple #2
0
def cox_fit(df, prior, later, nexposed, nindivs, lagged_hr_cut_year,
            is_sex_specific, res_writer, error_writer):
    """Fit a Cox regression for one (prior, later) pair and write one result row.

    The model is fitted with ``step_size=1.0`` first and, if that attempt
    raises ``ConvergenceError`` or a ``Warning``, once more with
    ``step_size=0.1``. A second failure is reported through ``error_writer``.
    A row is always written to ``res_writer``; every statistic that could not
    be computed is written as NaN. The last field of the row is the step size
    of the final attempt.
    """
    logger.info(f"Fitting data to Cox model  (lag: {lagged_hr_cut_year})")

    # SEX is only a covariate when the outcome is not sex-specific.
    covariates = ["pred_prior", "birth_year", "SMOKER", "BMI"]
    if not is_sex_specific:
        covariates.append("SEX")
    df = df.loc[:, ["duration", "outcome"] + covariates]

    # Median duration among rows having both the prior and the outcome.
    has_both = df.pred_prior & df.outcome
    median_duration = df.loc[has_both, "duration"].median()

    # Column order of the per-covariate blocks in the output row. Each block
    # is [coef, se, hr, ci_lower, ci_upper, pval, zval] and stays NaN unless
    # the fit succeeds (SEX stays NaN for sex-specific outcomes).
    report_order = ("pred_prior", "birth_year", "SEX", "SMOKER", "BMI")
    per_covariate = {name: [np.nan] * 7 for name in report_order}
    # [nsubjects, nevents, partial log-likelihood, concordance,
    #  LR test statistic, LR degrees of freedom, LR p-value]
    model_level = [np.nan] * 7

    cph = CoxPHFitter()

    def _attempt(step):
        cph.fit(
            df,
            duration_col="duration",
            event_col="outcome",
            show_progress=False,
            step_size=step,
        )

    # Start with a large step to go fast; fall back to a smaller one to help
    # convergence.
    step_size = 1.0
    try:
        _attempt(step_size)
        converged = True
    except (ConvergenceError, Warning):
        logger.debug(
            f"Failed to fit Cox model for pair ({prior}, {later}) with step_size={step_size}, retrying with lower step_size"
        )
        step_size = 0.1
        try:
            _attempt(step_size)
        except (ConvergenceError, Warning) as e:
            logger.warning(
                f"Failed to fit Cox model for pair ({prior}, {later}) after lowering step_size to {step_size} to fit Cox model"
            )
            error_writer.writerow(
                [prior, later, lagged_hr_cut_year,
                 type(e), e])
            converged = False
        else:
            logger.debug(
                f"Success when retrying with lower step_size={step_size}")
            converged = True

    if converged:
        summary = cph.summary
        for name in covariates:
            coef = cph.params_[name]
            se = cph.standard_errors_[name]
            per_covariate[name] = [
                coef,
                se,
                np.exp(coef),
                np.exp(coef - 1.96 * se),
                np.exp(coef + 1.96 * se),
                summary.p[name],
                summary.z[name],
            ]

        model_level[0] = cph._n_examples
        model_level[1] = cph.event_observed.sum()
        model_level[2] = cph._log_likelihood
        model_level[3] = cph.score_
        with np.errstate(invalid="ignore", divide="ignore"):
            lr_test = cph.log_likelihood_ratio_test()
            model_level[4] = lr_test.test_statistic
            model_level[5] = lr_test.degrees_freedom
            model_level[6] = lr_test.p_value

    row = [
        prior, later, nexposed, nindivs, lagged_hr_cut_year, median_duration
    ]
    for name in report_order:
        row.extend(per_covariate[name])
    row.extend(model_level)
    row.append(step_size)
    res_writer.writerow(row)
Exemple #3
0
def coxph_from_python(values,
                      isdead,
                      nbdays,
                      do_KM_plot=False,
                      png_path='./',
                      metadata_mat=None,
                      dichotomize_afterward=False,
                      fig_name='KM_plot.pdf',
                      penalizer=0.01,
                      l1_ratio=0.0,
                      isfactor=False):
    """Fit a Cox model on ``values`` and return its likelihood-ratio p-value.

    Args:
        values: Per-sample covariate (for the optional plot, also the group
            label of each sample).
        isdead: Per-sample event indicator, used as the event column.
        nbdays: Per-sample duration, used as the duration column.
        do_KM_plot: When true, draw one Kaplan-Meier curve per distinct value
            of ``values`` and save the figure as a PDF.
        png_path: Directory the figure is written to.
        metadata_mat: Optional mapping of extra column name -> values added to
            the regression frame. When it is ``None`` the penalizer is forced
            to ``0.0``.
        dichotomize_afterward: Accepted for interface compatibility; not used
            by this function.
        fig_name: Figure file name; a ``.pdf``/``.png`` suffix is stripped and
            ``.pdf`` is appended.
        penalizer: Penalizer passed to ``CoxPHFitter`` (only effective when
            ``metadata_mat`` is given).
        l1_ratio: ``l1_ratio`` passed to ``CoxPHFitter``.
        isfactor: When true, ``values`` is converted to strings before use.

    Returns:
        The p-value of the log-likelihood ratio test, or ``np.nan`` when the
        fit raises.
    """
    values = np.asarray(values)
    isdead = np.asarray(isdead)
    nbdays = np.asarray(nbdays)

    if isfactor:
        values = values.astype("str")

    frame = {"values": values, "isdead": isdead, "nbdays": nbdays}

    if metadata_mat is not None:
        for key in metadata_mat:
            frame[key] = metadata_mat[key]
    else:
        penalizer = 0.0

    frame = pd.DataFrame(frame)

    cph = CoxPHFitter(penalizer=penalizer, l1_ratio=l1_ratio)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        try:
            cph.fit(frame, "nbdays", "isdead")
        except Exception:
            # Best effort by design: any fitting failure is reported to the
            # caller as NaN instead of propagating.
            return np.nan

    pvalue = cph.log_likelihood_ratio_test().p_value
    cindex = cph.concordance_index_

    if do_KM_plot:
        fig, ax = plt.subplots(figsize=(10, 10))

        try:
            kaplan = KaplanMeierFitter()

            # pd.unique keeps first-appearance order, so curve order (and
            # with it colours and legend) is the same on every run.
            for label in pd.unique(values):
                in_group = values == label

                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    kaplan.fit(
                        nbdays[in_group],
                        event_observed=isdead[in_group],
                        label='cluster nb. {0}'.format(label))

                kaplan.plot(ax=ax, ci_alpha=0.15)

            ax.set_xlabel('time unit')
            ax.set_title('pval.: {0: .1e} CI: {1: .2f}'.format(
                pvalue, cindex),
                         fontsize=16,
                         fontweight='bold')

            figname = "{0}/{1}.pdf".format(
                png_path,
                fig_name.replace('.pdf', '').replace('.png', ''))

            fig.savefig(figname)
            print('Figure saved in: {0}'.format(figname))
        finally:
            # Release the figure so repeated calls do not accumulate them.
            plt.close(fig)

    return pvalue
Exemple #4
0
def _km_flag(value):
    """Map the "off"/".off" markers to False and anything else to True."""
    return value not in ["off", ".off"]


def _km_choose(settings, written_key, selected_key):
    """Return the free-text setting unless it is empty, else the selected one."""
    written = settings[written_key]
    if str(written) != "":
        return written
    return settings[selected_key]


def _km_table(km):
    """Join a fitted estimator's survival, confidence and event tables."""
    table = pd.merge(km.survival_function_,
                     km.confidence_interval_,
                     how='left',
                     left_index=True,
                     right_index=True)
    table = pd.merge(table,
                     km.event_table,
                     how='left',
                     left_index=True,
                     right_index=True)
    table['time'] = table.index.tolist()
    return table.reset_index(drop=True)


def _km_style_axes(pl, pa, pa_):
    """Apply the spine, tick, grid, limit and label settings to axes ``pl``."""
    pl.spines['right'].set_visible(pa_["right_axis"])
    pl.spines['top'].set_visible(pa_["upper_axis"])
    pl.spines['left'].set_visible(pa_["left_axis"])
    pl.spines['bottom'].set_visible(pa_["lower_axis"])

    pl.spines['right'].set_linewidth(pa["axis_line_width"])
    pl.spines['left'].set_linewidth(pa["axis_line_width"])
    pl.spines['top'].set_linewidth(pa["axis_line_width"])
    pl.spines['bottom'].set_linewidth(pa["axis_line_width"])

    pl.tick_params(axis="both",
                   direction=pa["ticks_direction_value"],
                   length=float(pa["ticks_length"]))

    pl.tick_params(axis='x',
                   which='both',
                   bottom=pa_["tick_lower_axis"],
                   top=pa_["tick_upper_axis"],
                   labelbottom=pa_["lower_axis"],
                   labelrotation=float(pa["xticks_rotation"]),
                   labelsize=float(pa["xticks_fontsize"]))

    pl.tick_params(axis='y',
                   which='both',
                   left=pa_["tick_left_axis"],
                   right=pa_["tick_right_axis"],
                   labelleft=pa_["left_axis"],
                   labelrotation=float(pa["yticks_rotation"]),
                   labelsize=float(pa["yticks_fontsize"]))

    if str(pa["grid_value"]) != "None":
        pl.grid(True,
                which='both',
                axis=pa["grid_value"],
                color=pa_["grid_color_write"],
                linewidth=float(pa["grid_linewidth"]))

    # Axis limits are only applied when both ends are provided.
    if str(pa["x_lower_limit"]) != "" and str(pa["x_upper_limit"]) != "":
        pl.set_xlim(float(pa["x_lower_limit"]), float(pa["x_upper_limit"]))
    if str(pa["y_lower_limit"]) != "" and str(pa["y_upper_limit"]) != "":
        pl.set_ylim(float(pa["y_lower_limit"]), float(pa["y_upper_limit"]))

    pl.set_title(pa["title"], fontdict={'fontsize': float(pa['titles'])})
    pl.set_xlabel(pa["xlabel"], fontdict={'fontsize': float(pa['xlabels'])})
    pl.set_ylabel(pa["ylabel"], fontdict={'fontsize': float(pa['ylabels'])})


def make_figure(df, pa):
    """Draw Kaplan-Meier curve(s) described by ``pa`` and return their tables.

    Args:
        df: Input data frame; it is copied and never modified.
        pa: Mapping of plot arguments. ``pa["xvals"]`` and ``pa["yvals"]``
            name the duration and event columns; ``pa["groups_value"]`` names
            the grouping column, or is ``"None"`` (as a string or ``None``)
            for a single curve. The remaining keys carry styling settings.

    Returns:
        Without grouping: ``(table, axes)`` where ``table`` holds the event
        table, the estimate and its confidence bounds.
        With grouping: ``(table, axes, cph_coeff, cph_stats)`` where ``table``
        is the per-group tables merged on ``time``, ``cph_coeff`` is the
        Cox model summary and ``cph_stats`` a Statistic/Value frame. If
        ``pa["list_of_groups"]`` is empty, ``table`` is the input ``df`` and
        ``axes`` is ``None``.
    """
    df_ls = df.copy()

    durations = df_ls[pa["xvals"]]
    event_observed = df_ls[pa["yvals"]]

    km = KaplanMeierFitter()

    pl = None
    # Creates (and makes current) the figure with the requested size; the
    # km.plot() calls below receive no explicit axes and are expected to
    # draw on it.
    plt.figure(frameon=False,
               figsize=(float(pa["fig_width"]), float(pa["fig_height"])))

    axis_args = ("left_axis", "right_axis", "upper_axis", "lower_axis",
                 "tick_left_axis", "tick_right_axis", "tick_upper_axis",
                 "tick_lower_axis")

    if str(pa["groups_value"]) == "None":
        km.fit(durations, event_observed, label='Kaplan Meier Estimate')

        df = _km_table(km)[[
            "time", "at_risk", "removed", "observed", "censored", "entrance",
            "Kaplan Meier Estimate", "Kaplan Meier Estimate_lower_0.95",
            "Kaplan Meier Estimate_upper_0.95"
        ]]

        curve_args = ("Conf_Interval", "show_censors", "ci_legend",
                      "ci_force_lines")
        pa_ = {arg: _km_flag(pa[arg]) for arg in curve_args + axis_args}
        pa_["marker_fc"] = _km_choose(pa, "markerc_write", "markerc")
        pa_["marker_ec"] = _km_choose(pa, "edgecolor_write", "edgecolor")
        pa_["grid_color_write"] = _km_choose(pa, "grid_color_text",
                                             "grid_color_value")

        pl = km.plot(
            show_censors=pa_["show_censors"],
            censor_styles={
                "marker": marker_dict[pa["censor_marker_value"]],
                "markersize": float(pa["censor_marker_size_val"]),
                "markeredgecolor": pa_["marker_ec"],
                "markerfacecolor": pa_["marker_fc"],
                "alpha": float(pa["marker_alpha"]),
            },
            ci_alpha=float(pa["ci_alpha"]),
            ci_force_lines=pa_["ci_force_lines"],
            ci_show=pa_["Conf_Interval"],
            ci_legend=pa_["ci_legend"],
            linestyle=pa["linestyle_value"],
            linewidth=float(pa["linewidth_write"]),
            color=pa["line_color_value"])

        _km_style_axes(pl, pa, pa_)

        return df, pl

    # ---- grouped analysis -------------------------------------------------
    group_col = str(pa["groups_value"])

    # Expand the aggregated input (a count of events or of censored subjects
    # per row) into one record per subject. Records are gathered in a list
    # and turned into a frame once: DataFrame.append no longer exists in
    # pandas 2.x and was quadratic. Values are read positionally so the
    # input index does not have to be a default RangeIndex.
    day_vals = df_ls[pa["xvals"]].tolist()
    event_vals = df_ls[pa["yvals"]].tolist()
    group_vals = df_ls[pa["groups_value"]].tolist()
    censor_vals = None  # only looked up for rows without events
    records = []
    for pos, n_events in enumerate(event_vals):
        count = int(n_events)
        status = 1
        if count < 1:
            if censor_vals is None:
                censor_vals = df_ls[pa["censors_val"]].tolist()
            count = int(censor_vals[pos])
            status = 0
        if count >= 1:
            record = {
                'day': int(day_vals[pos]),
                'status': status,
                group_col: str(group_vals[pos]),
            }
            records.extend(dict(record) for _ in range(count))

    df_long = pd.DataFrame(records, columns=['day', 'status', group_col])

    df_dummy = pd.get_dummies(df_long,
                              drop_first=True,
                              columns=[pa["groups_value"]])

    is_event = df_dummy['status'] == 1
    is_censored = df_dummy['status'] == 0
    # NOTE(review): the two samples handed to logrank_test are the event rows
    # and the censored rows, not the groups selected by pa["groups_value"];
    # kept unchanged to preserve the reported numbers - confirm the intent.
    results = logrank_test(df_dummy.loc[is_event, 'day'].tolist(),
                           df_dummy.loc[is_censored, 'day'].tolist(),
                           df_dummy.loc[is_event, 'status'].tolist(),
                           df_dummy.loc[is_censored, 'status'].tolist(),
                           alpha=.99)

    cph = CoxPHFitter()
    cph.fit(df_dummy, duration_col='day', event_col='status')

    cph_coeff = cph.summary.reset_index()

    lr_test = cph.log_likelihood_ratio_test()
    df_info = {}
    df_info['model'] = 'lifelines.CoxPHFitter'
    df_info['duration col'] = cph.duration_col
    df_info['event col'] = cph.event_col
    df_info['baseline estimation'] = 'breslow'
    df_info['number of observations'] = cph._n_examples
    df_info['number of events observed'] = len(df_dummy.loc[is_event, ])
    df_info['partial log-likelihood'] = cph.log_likelihood_
    df_info['Concordance'] = cph.concordance_index_
    df_info['Partial AIC'] = cph.AIC_partial_
    df_info['log-likelihood ratio test'] = lr_test.test_statistic
    df_info['P.value(log-likelihood ratio test)'] = lr_test.p_value
    df_info['log rank test'] = results.test_statistic
    df_info['P.value(log rank test)'] = results.p_value

    cph_stats = pd.DataFrame(list(df_info.items()))
    cph_stats = cph_stats.rename(columns={0: 'Statistic', 1: 'Value'})

    tables = []

    for cond in pa["list_of_groups"]:
        df_tmp = df_ls.loc[df_ls[pa["groups_value"]] == cond]

        km.fit(df_tmp[pa["xvals"]], df_tmp[pa["yvals"]], label=cond)

        table = _km_table(km).rename(
            columns={
                "at_risk": cond + "_at_risk",
                "removed": cond + "_removed",
                "observed": cond + "_observed",
                "censored": cond + "_censored",
                "entrance": cond + "_entrance",
                cond: cond + "_KMestimate"
            })
        tables.append(table[[
            "time", cond + "_at_risk", cond + "_removed", cond + "_observed",
            cond + "_censored", cond + "_entrance", cond + "_KMestimate",
            cond + "_lower_0.95", cond + "_upper_0.95"
        ]])

        # Per-group styling comes from the matching "groups_settings" entry.
        PA_ = [g for g in pa["groups_settings"] if g["name"] == cond][0]

        pl = km.plot(
            show_censors=_km_flag(PA_["show_censors"]),
            censor_styles={
                "marker": marker_dict[PA_["censor_marker_value"]],
                "markersize": float(PA_["censor_marker_size_val"]),
                "markeredgecolor": _km_choose(PA_, "edgecolor_write",
                                              "edgecolor"),
                "markerfacecolor": _km_choose(PA_, "markerc_write",
                                              "markerc"),
                "alpha": float(PA_["marker_alpha"]),
                "mew": float(PA_["edge_linewidth"]),
            },
            ci_alpha=float(PA_["ci_alpha"]),
            ci_force_lines=_km_flag(PA_["ci_force_lines"]),
            ci_show=_km_flag(PA_["Conf_Interval"]),
            ci_legend=_km_flag(PA_["ci_legend"]),
            linestyle=_km_choose(PA_, "linestyle_write", "linestyle_value"),
            linewidth=float(PA_["linewidth_write"]),
            color=_km_choose(PA_, "linecolor_write", "line_color_value"))

    # Merge the per-group tables once, after every group has been fitted.
    if tables:
        df = reduce(lambda left, right: pd.merge(left, right, on='time'),
                    tables)

    # The axes-level settings do not depend on the group, so they are applied
    # once after the last curve has been drawn.
    if pl is not None:
        pa_ = {arg: _km_flag(pa[arg]) for arg in axis_args}
        pa_["grid_color_write"] = _km_choose(pa, "grid_color_text",
                                             "grid_color_value")
        _km_style_axes(pl, pa, pa_)

    return df, pl, cph_coeff, cph_stats