Пример #1
0
def label_gen(configs):
    '''Create base labels for one patient batch directly from imputed data and endpoints.

    Reads the imputed data and endpoints of batch ``configs["batch_idx"]`` of
    ``configs["dataset"]`` ("bern" or "mimic"), applies the label transformer
    selected by ``configs["label_key"]`` for the window
    ``[configs["lhours"], configs["rhours"]]`` and writes the labels of all
    selected patients into one HDF5 table file
    ``<output base>[/reduced]/<split_key>/<label_key>/<lhours>To<rhours>Hours/batch_<batch_idx>.h5``
    (any previous file is deleted first; the first patient is written with
    mode 'w', later ones are appended).

    Patients whose batch file is missing, whose endpoints cannot be read, whose
    data is empty, or for which the transformer returns None are skipped with a
    warning. In ``configs["debug_mode"]`` no directory is created and nothing is
    deleted or written.

    Exits the process with status 1 for an unknown dataset or label key.
    '''
    label_key = configs["label_key"]
    split_key = configs["split_key"]
    lhours = configs["lhours"]
    rhours = configs["rhours"]
    dim_reduced_data = configs["data_mode"] == "reduced"

    if configs["verbose"]:
        print("Creating label: {} [{},{}] for reduced data: {}".format(
            label_key, lhours, rhours, dim_reduced_data),
              flush=True)

    dataset = configs["dataset"]
    if dataset not in ("bern", "mimic"):
        # Fail early instead of crashing later on undefined directory variables.
        print("ERROR: Invalid data-set specified..", flush=True)
        sys.exit(1)

    label_base_dir = configs["{}_output_base_dir".format(dataset)]
    endpoint_base_dir = configs["{}_endpoints_dir".format(dataset)]
    imputed_base_dir = configs["{}_imputed_dir".format(dataset)]

    # Reduced-dimension data lives in an extra "reduced" sub-directory of each tree.
    reduced_subdir = ["reduced"] if dim_reduced_data else []
    window_dir = "{}To{}Hours".format(lhours, rhours)
    base_dir = os.path.join(label_base_dir, *reduced_subdir, split_key,
                            label_key, window_dir)

    try:
        if not configs["debug_mode"]:
            mlhc_fs.create_dir_if_not_exist(base_dir, recursive=True)
    except Exception:
        # Several batch jobs may try to create the same directory concurrently.
        print(
            "WARNING: Race condition when creating directory from different jobs..."
        )

    data_split = mlhc_io.load_pickle(
        configs["temporal_data_split_binary"])[split_key]

    if dataset == "bern":
        all_pids = data_split["train"] + data_split["val"] + data_split["test"]
    else:
        all_pids = list(
            map(int,
                mlhc_io.read_list_from_file(
                    configs["mimic_all_pid_list_path"])))

    if configs["verbose"]:
        print("Number of patient IDs: {}".format(len(all_pids)), flush=True)

    batch_map = mlhc_io.load_pickle(
        configs["{}_pid_batch_map_binary".format(dataset)])["chunk_to_pids"]

    batch_idx = configs["batch_idx"]
    output_path = os.path.join(base_dir, "batch_{}.h5".format(batch_idx))

    if not configs["debug_mode"]:
        mlhc_fs.delete_if_exist(output_path)

    selected_pids = list(set(batch_map[batch_idx]).intersection(all_pids))
    n_skipped_patients = 0
    first_write = True

    if label_key == "Deterioration":
        tf_model = bern_tf_labels.DeteriorationLabel(lhours, rhours)
    elif label_key == "WorseState":
        tf_model = bern_tf_labels.WorseStateLabel(lhours, rhours)
    elif label_key == "WorseStateSoft":
        tf_model = bern_tf_labels.WorseStateSoftLabel(lhours, rhours)
    elif label_key == "AllLabels":
        tf_model = bern_tf_labels.AllLabel(lhours, rhours, dataset=dataset)
    else:
        print("ERROR: Invalid label requested...", flush=True)
        sys.exit(1)

    print("Number of selected PIDs: {}".format(len(selected_pids)), flush=True)

    # The imputed and endpoint files depend only on the batch, not on the
    # patient, so resolve them once rather than globbing for every patient.
    patient_path = os.path.join(imputed_base_dir, *reduced_subdir, split_key,
                                "batch_{}.h5".format(batch_idx))
    endpoint_prefix = "reduced_endpoints" if dim_reduced_data else "endpoints"
    endpoint_path = None
    if selected_pids:
        cand_files = glob.glob(
            os.path.join(endpoint_base_dir, *reduced_subdir,
                         "{}_{}_*.h5".format(endpoint_prefix, batch_idx)))
        assert len(cand_files) == 1
        endpoint_path = cand_files[0]

    for pidx, pid in enumerate(selected_pids):

        if not os.path.exists(patient_path):
            print(
                "WARNING: Patient {} does not exists, skipping...".format(pid),
                flush=True)
            n_skipped_patients += 1
            continue

        pid_query = "PatientID={}".format(pid)

        try:
            df_endpoint = pd.read_hdf(endpoint_path, mode='r', where=pid_query)
        except Exception:
            print(
                "WARNING: Issue while reading endpoints of patient {}".format(
                    pid),
                flush=True)
            n_skipped_patients += 1
            continue

        df_pat = pd.read_hdf(patient_path, mode='r', where=pid_query)

        if df_pat.shape[0] == 0 or df_endpoint.shape[0] == 0:
            print(
                "WARNING: Empty endpoints or empty imputed data in patient {}".
                format(pid),
                flush=True)
            n_skipped_patients += 1
            continue

        if not is_df_sorted(df_endpoint, "Datetime"):
            # Stable sort keeps the original order of equal timestamps.
            df_endpoint = df_endpoint.sort_values(by="Datetime",
                                                  kind="mergesort")

        df_label = tf_model.transform(df_pat, df_endpoint, pid=pid)

        if df_label is None:
            print(
                "WARNING: Label could not be created for PID: {}".format(pid),
                flush=True)
            n_skipped_patients += 1
            continue

        # One label row per imputed time point.
        assert df_label.shape[0] == df_pat.shape[0]

        if not configs["debug_mode"]:
            # First patient creates the file; all later ones append to the table.
            df_label.to_hdf(output_path,
                            configs["label_dset_id"],
                            complevel=configs["hdf_comp_level"],
                            complib=configs["hdf_comp_alg"],
                            format="table",
                            append=not first_write,
                            mode='w' if first_write else 'a',
                            data_columns=["PatientID"])

        gc.collect()
        first_write = False

        if (pidx + 1) % 100 == 0 and configs["verbose"]:
            print("Progress for batch {}: {:.2f} %".format(
                batch_idx, (pidx + 1) / len(selected_pids) * 100),
                  flush=True)
            print("Number of skipped patients: {}".format(n_skipped_patients))
Пример #2
0
# (ml model, column description, MIMIC split key) of the MIMIC-only models whose
# scores are interpolated with the HiRID model, one per MIMIC random split.
_MIMIC_ONLY_MODELS = [("lightgbm",
                       "shap_top20_variables_MIMIConly_random_{}".format(i),
                       "random_{}".format(i)) for i in range(5)]

# Suffixes of the replicate curves stored next to the original HiRID curve.
_HIRID_REPLICATES = ("t1", "t2", "t3", "t4", "t5")


def _load_best_model(configs, split_key, task_key, left_hours, right_hours,
                     col_desc, ml_model):
    '''Unpickle ``best_model.pickle`` of a "_full" prediction run below ``configs["predictions_dir"]/reduced``.'''
    model_dir = os.path.join(
        configs["predictions_dir"], "reduced", split_key,
        "{}_{}_{}_{}_{}_full".format(task_key, left_hours, right_hours,
                                     col_desc, ml_model))
    # NOTE: pickle can execute code on load; only use trusted model directories.
    with open(os.path.join(model_dir, "best_model.pickle"), 'rb') as fp:
        return pickle.load(fp)


def _select_ip_coeff_from_val_results(result_dir, split_key):
    '''Return the interpolation coefficient with the best validation AUPRC for one split.

    Scans ``result_val_<coeff>.tsv`` files in *result_dir* (comma-separated
    rows ``split,auroc,auprc[,model_key]`` after a header row) and keeps, per
    coefficient, the AUPRC of the "interpolated" model on *split_key*. Rows
    without a ``model_key`` column are treated as "interpolated" results.

    Raises ValueError if no matching row is found.
    '''
    val_dict = {}
    for rpath in sorted(glob.glob(os.path.join(result_dir,
                                               "result_val_*.tsv"))):
        # result_val_<coeff>.tsv -> "<coeff>"
        ip_coeff_val = float(os.path.basename(rpath).split("_")[-1][:-4])
        with open(rpath, 'r') as fp:
            csv_fp = csv.reader(fp)
            next(csv_fp, None)  # header row
            for row in csv_fp:
                if len(row) < 3:
                    continue
                split, auprc = row[0], row[2]
                model_key = row[3] if len(row) > 3 else "interpolated"
                if split != split_key or model_key != "interpolated":
                    continue
                val_dict[ip_coeff_val] = float(auprc)
    return max(val_dict, key=val_dict.get)


def _build_output_frame(pat_df, pat_label_df, pid, valid_mask, label_col,
                        pred_scores_ip, pred_scores_hirid, pred_scores_mimic):
    '''Spread the per-VALID-row scores of one patient onto all of its rows.

    Rows that are not VALID keep the fill value of ``mlhc_array.empty_nan``
    (presumably NaN).
    '''
    abs_dt = pat_df["AbsDatetime"]
    n_rows = abs_dt.size

    def _scatter(valid_scores):
        full_vect = mlhc_array.empty_nan(n_rows)
        # NOTE(review): assumes the shapelet merge kept every VALID row;
        # otherwise the number of scores and VALID rows disagree.
        full_vect[valid_mask] = valid_scores
        return full_vect

    return pd.DataFrame({
        "PatientID": mlhc_array.value_empty(n_rows, pid),
        "PredScoreInterpolated": _scatter(pred_scores_ip),
        "PredScoreHiRiD": _scatter(pred_scores_hirid),
        "PredScoreMIMIC": _scatter(pred_scores_mimic),
        "TrueLabel": np.array(pat_label_df[label_col]),
        "AbsDatetime": abs_dt,
        "RelDatetime": pat_df["RelDatetime"],
    })


def _plot_hirid_reference_roc(configs):
    '''Overlay the original HiRID ROC curve (±1 std band over replicates t1-t5), format the axes and save the ROC figure.'''
    aux_curves_orig = mlhc_io.load_pickle(configs["aux_curves_path"])
    ref_fpr = aux_curves_orig["bern_fpr"]
    ref_tpr = aux_curves_orig["bern_tpr"]
    auroc_held_out = skmetrics.auc(ref_fpr, ref_tpr)

    rep_aurocs = []
    rep_tprs = []
    for rep in _HIRID_REPLICATES:
        rep_fpr = aux_curves_orig["bern_fpr_" + rep]
        rep_tpr = aux_curves_orig["bern_tpr_" + rep]
        rep_aurocs.append(skmetrics.auc(rep_fpr, rep_tpr))
        rep_tprs.append(np.interp(ref_fpr, rep_fpr, rep_tpr))
    std_tprs = np.std(rep_tprs, axis=0)
    tprs_upper = np.minimum(ref_tpr + std_tprs, 1)
    tprs_lower = np.maximum(ref_tpr - std_tprs, 0)

    plt.plot(ref_fpr, ref_tpr, color="C4",
             label="Original HiRID, AUROC: {:.3f} ({:.3f})".format(
                 auroc_held_out, np.std(rep_aurocs)))
    plt.fill_between(ref_fpr, tprs_lower, tprs_upper, color="C4", alpha=0.2)

    plt.plot([0, 1], [0, 1], color='grey', lw=0.5, linestyle='--',
             rasterized=True)
    ax = plt.gca()
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.set_aspect(1.0)
    ax.grid(which="both", lw=0.5)
    plt.xlabel('1 - Specificity')
    plt.ylabel('Sensitivity')
    plt.ylim([0.0, 1.0])
    plt.xlim([0.0, 1.0])
    plt.legend(loc="lower right")
    plt.title("Interpolated score performance")
    plt.tight_layout()
    plt.savefig(os.path.join(configs["plot_dir"], "interpolated_score_roc.pdf"),
                bbox_inches="tight", dpi=1200, transparent=True)
    plt.savefig(os.path.join(configs["plot_dir"], "interpolated_score_roc.png"),
                bbox_inches="tight")
    plt.clf()


def _plot_hirid_reference_prc(configs):
    '''Overlay the original HiRID PR curve (±1 std band over replicates t1-t5), format the axes and save the PR figure.'''
    aux_curves_orig = mlhc_io.load_pickle(configs["aux_curves_path_pr"])
    ref_recalls = aux_curves_orig["bern_recalls"]
    ref_precs = aux_curves_orig["bern_precs"]
    auprc_held_out = skmetrics.auc(ref_recalls, ref_precs)

    rep_auprcs = []
    rep_precs = []
    for rep in _HIRID_REPLICATES:
        rep_recalls = aux_curves_orig["bern_recalls_" + rep]
        rep_prec = aux_curves_orig["bern_precs_" + rep]
        rep_auprcs.append(skmetrics.auc(rep_recalls, rep_prec))
        # Reversed so that the recall axis is increasing, as np.interp requires
        # (presumably the stored recalls are decreasing).
        rep_precs.append(np.interp(ref_recalls, rep_recalls[::-1],
                                   rep_prec[::-1]))
    std_precs = np.std(rep_precs, axis=0)
    precs_upper = np.minimum(ref_precs + std_precs, 1)
    precs_lower = np.maximum(ref_precs - std_precs, 0)

    plt.plot(ref_recalls, ref_precs, color="C4",
             label="Original HiRID, AUPRC: {:.3f} ({:.3f})".format(
                 auprc_held_out, np.std(rep_auprcs)))
    plt.fill_between(ref_recalls, precs_lower, precs_upper, color="C4",
                     alpha=0.2)

    ax = plt.gca()
    ax.set_aspect(1.0)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.grid(which="both", lw=0.5)
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.ylim([0.0, 1.0])
    plt.xlim([0.0, 1.0])
    plt.title('Interpolated score performance')
    plt.legend(loc="upper right")
    plt.tight_layout()
    plt.savefig(os.path.join(configs["plot_dir"], "interpolated_score_prc.pdf"),
                bbox_inches="tight", dpi=1200, transparent=True)
    plt.savefig(os.path.join(configs["plot_dir"], "interpolated_score_prc.png"),
                bbox_inches="tight")
    plt.clf()


def _write_results_and_plot(configs, scores_dict, labels_dict):
    '''Write per-split AUROC/AUPRC rows for every model key and draw the mean curves.

    Rows ``split,auroc,auprc,model_key`` are written (comma-separated) to
    ``result_test.tsv`` or ``result_<val_type>_<ip_coeff>.tsv`` in
    ``configs["result_dir"]``.
    '''
    if configs["val_type"] == "test":
        fpath = os.path.join(configs["result_dir"],
                             "result_{}.tsv".format(configs["val_type"]))
    else:
        fpath = os.path.join(
            configs["result_dir"],
            "result_{}_{}.tsv".format(configs["val_type"], configs["ip_coeff"]))

    color_dict = {"interpolated": "C0", "valid": "C1", "retrain": "C2"}
    name_dict = {"interpolated": "Interpolated", "valid": "MIMICval",
                 "retrain": "MIMICretrain"}
    plot_type = configs["plot_type"]

    with open(fpath, 'w') as fp:
        csv_fp = csv.writer(fp)
        csv_fp.writerow(["split", "auroc", "auprc", "model_key"])

        for model_key in ["interpolated", "valid", "retrain"]:
            all_aurocs = []
            all_auprcs = []
            fpr_grid = None
            recall_grid = None
            tprs = []
            precs = []

            for ml_model, col_desc, split in _MIMIC_ONLY_MODELS:
                result_key = (ml_model, col_desc, split, model_key)
                labels = labels_dict[result_key]
                scores = scores_dict[result_key]

                # Correction factor between the split prevalence and the
                # target Bern prevalence, handed to score_metrics (presumably
                # to rescale precision to that prevalence).
                split_prevalence = np.sum(labels == 1.0) / labels.size
                prevalence_bern = configs["target_prevalence_bern"]
                local_correct_factor = (
                    (1 - prevalence_bern) * split_prevalence /
                    (prevalence_bern * (1 - split_prevalence)))
                fpr_split, tpr_split, _ = skmetrics.roc_curve(labels, scores)
                precs_split, recalls_split, _ = score_metrics(
                    labels, scores, correct_factor=local_correct_factor)

                # The first split's curve points are the common x grid.
                if fpr_grid is None:
                    fpr_grid = fpr_split
                if recall_grid is None:
                    recall_grid = recalls_split

                tprs.append(np.interp(fpr_grid, fpr_split, tpr_split))
                # Reversed so the recall axis is increasing for np.interp.
                precs.append(np.interp(recall_grid, recalls_split[::-1],
                                       precs_split[::-1]))

                auroc = skmetrics.roc_auc_score(labels, scores)
                auprc = skmetrics.auc(recalls_split, precs_split)
                all_aurocs.append(auroc)
                all_auprcs.append(auprc)

                csv_fp.writerow([split, str(auroc), str(auprc), model_key])

            color = color_dict[model_key]
            if plot_type == "roc":
                mean_tprs = np.mean(tprs, axis=0)
                std_tprs = np.std(tprs, axis=0)
                plt.plot(fpr_grid, mean_tprs, color=color,
                         label="{}, AUROC: {:.3f} ({:.3f})".format(
                             name_dict[model_key], np.mean(all_aurocs),
                             np.std(all_aurocs)))
                plt.fill_between(fpr_grid,
                                 np.maximum(mean_tprs - std_tprs, 0),
                                 np.minimum(mean_tprs + std_tprs, 1),
                                 color=color, alpha=0.2)
            else:
                mean_precs = np.mean(precs, axis=0)
                std_precs = np.std(precs, axis=0)
                plt.plot(recall_grid, mean_precs, color=color,
                         label="{}: AUPRC: {:.3f} ({:.3f})".format(
                             name_dict[model_key], np.mean(all_auprcs),
                             np.std(all_auprcs)))
                plt.fill_between(recall_grid,
                                 np.maximum(mean_precs - std_precs, 0),
                                 np.minimum(mean_precs + std_precs, 1),
                                 color=color, alpha=0.2)

        if plot_type == "roc":
            _plot_hirid_reference_roc(configs)
        elif plot_type == "prc":
            _plot_hirid_reference_prc(configs)
        else:
            print("No plot is produced...", flush=True)


def interpolated_mimic_hirid(configs):
    '''Evaluate interpolated HiRID/MIMIC LightGBM scores on MIMIC patients.

    For each of the five MIMIC random splits, the HiRID model (from the
    "held_out" predictions) and that split's MIMIC-only model are applied to
    the MIMIC patients of ``configs["val_type"]``. Their positive-class
    probabilities are mixed as ``ip_coeff * hirid + (1 - ip_coeff) * mimic``.
    ``ip_coeff`` is ``configs["ip_coeff"]`` when ``val_type`` is "val" or in
    ``full_explore_mode``; otherwise it is the coefficient with the best
    validation AUPRC read from ``result_val_*.tsv`` in ``configs["result_dir"]``.

    Per-patient predictions are written to HDF5 when ``configs["write_output"]``
    is set. If ``configs["plot_type"]`` is "NONE" the process exits with status
    0; otherwise per-split AUROC/AUPRC rows are written and, for "roc"/"prc",
    a plot is saved to ``configs["plot_dir"]``.

    Side effect: ``configs["split_key"]`` is set to each MIMIC split key in turn.
    '''
    random.seed(configs["random_state"])
    np_rand.seed(configs["random_state"])
    task_key = configs["task_key"]
    left_hours = configs["lhours"]
    right_hours = configs["rhours"]
    val_type = configs["val_type"]
    assert configs["data_mode"] in ["reduced", "non_reduced"]

    status_col = "SampleStatus_WorseStateFromZero0.0To8.0Hours"
    label_col = "Label_WorseStateFromZero0.0To8.0Hours"

    batch_map = mlhc_io.load_pickle(
        configs["mimic_pid_map_path"])["pid_to_chunk"]
    scores_dict = {}
    labels_dict = {}

    # HiRID model; its split key also selects the MIMIC input directories.
    hirid_split_key = "held_out"
    hirid_model = _load_best_model(configs, hirid_split_key, task_key,
                                   left_hours, right_hours,
                                   "shap_top20_variables_MIMIC", "lightgbm")
    hirid_feat_order = list(hirid_model._Booster.feature_name())

    ml_input_dir = os.path.join(configs["mimic_ml_input_dir"], "reduced",
                                hirid_split_key, "AllLabels_0.0_8.0")
    feat_dir = os.path.join(ml_input_dir, "X")
    labels_dir = os.path.join(ml_input_dir, "y")
    df_shapelet_path = configs["mimic_shapelets_path"]
    mimic_splits = mlhc_io.load_pickle(configs["mimic_split_path"])

    for mimic_ml_model, mimic_col_desc, mimic_split_key in _MIMIC_ONLY_MODELS:
        configs["split_key"] = mimic_split_key
        print("Analyzing model ({},{},{})".format(mimic_ml_model,
                                                  mimic_col_desc,
                                                  mimic_split_key),
              flush=True)
        pred_pids = mimic_splits[mimic_split_key][val_type]
        print("Number of test PIDs: {}".format(len(pred_pids)), flush=True)

        mimic_model = _load_best_model(configs, hirid_split_key, task_key,
                                       left_hours, right_hours,
                                       mimic_col_desc, mimic_ml_model)
        # Both models score the same feature matrix, so feature order must agree.
        assert hirid_feat_order == list(mimic_model._Booster.feature_name())

        if val_type == "val" or configs["full_explore_mode"]:
            ip_coeff = configs["ip_coeff"]
        else:
            ip_coeff = _select_ip_coeff_from_val_results(
                configs["result_dir"], mimic_split_key)
            print("Best IP coeff on val set: {}".format(ip_coeff), flush=True)

        out_dir = os.path.join(configs["result_dir"],
                               "full_{}_set_results".format(val_type),
                               str(ip_coeff), mimic_split_key)
        mlhc_fs.create_dir_if_not_exist(out_dir, recursive=True)

        cum_scores = {"interpolated": [], "valid": [], "retrain": []}
        cum_labels = []
        n_valid_count = 0
        skip_reason_key = skip_reason_ns_bef = skip_reason_ns_after = skip_reason_shapelet = 0

        for pidx, pid in enumerate(pred_pids):

            if (pidx + 1) % 100 == 0 and configs["verbose"]:
                print("{}/{}, KEY: {}, NS BEF: {}, NS AFT: {}, SHAPELET: {}".format(
                    pidx + 1, len(pred_pids), skip_reason_key,
                    skip_reason_ns_bef, skip_reason_ns_after,
                    skip_reason_shapelet),
                      flush=True)

            if pidx >= 100 and configs["debug_mode"]:
                break

            batch_pat = batch_map[pid]
            batch_file = "batch_{}.h5".format(batch_pat)

            try:
                pat_df = pd.read_hdf(os.path.join(feat_dir, batch_file),
                                     "/{}".format(pid), mode='r')
                pat_label_df = pd.read_hdf(os.path.join(labels_dir, batch_file),
                                           "/{}".format(pid), mode='r')
                assert pat_df.shape[0] == pat_label_df.shape[0]
                valid_mask = pat_df[status_col] == "VALID"
                df_feat_valid = pat_df[valid_mask]
                df_label_valid = pat_label_df[pat_label_df[status_col] == "VALID"]
                assert df_feat_valid.shape[0] == df_label_valid.shape[0]
            except KeyError:
                # Missing patient key or status column in the ML input files.
                skip_reason_key += 1
                continue

            if df_feat_valid.shape[0] == 0:
                skip_reason_ns_bef += 1
                continue

            shapelet_df = pd.read_hdf(df_shapelet_path, '/{}'.format(pid),
                                      mode='r')
            shapelet_df["AbsDatetime"] = pd.to_datetime(
                shapelet_df["AbsDatetime"])
            shapelet_cols = [
                col for col in sorted(shapelet_df.columns.values.tolist())
                if "_dist-set" in col
            ]
            shapelet_df = shapelet_df[["AbsDatetime", "PatientID"] +
                                      shapelet_cols]

            if shapelet_df.shape[0] == 0:
                skip_reason_shapelet += 1
                continue

            # Inner joins: keep only time points that also have shapelet features.
            df_feat_valid = pd.merge(df_feat_valid, shapelet_df,
                                     on=["AbsDatetime", "PatientID"])
            label_cols = sorted(df_label_valid.columns.values.tolist())
            df_label_valid = pd.merge(df_label_valid, shapelet_df,
                                      on=["AbsDatetime", "PatientID"])[label_cols]

            if df_feat_valid.shape[0] == 0:
                skip_reason_ns_after += 1
                continue

            sel_feat_cols = [
                col for col in sorted(df_feat_valid.columns.values.tolist())
                if "Patient" not in col
            ]
            X_df = df_feat_valid[sel_feat_cols]
            true_labels = df_label_valid[label_col]
            assert true_labels.shape[0] == X_df.shape[0]
            X_full = np.asarray(X_df[hirid_feat_order])

            pred_scores_mimic = mimic_model.predict_proba(X_full)[:, 1]
            pred_scores_hirid = hirid_model.predict_proba(X_full)[:, 1]
            pred_scores_ip = (ip_coeff * pred_scores_hirid +
                              (1 - ip_coeff) * pred_scores_mimic)

            df_out = _build_output_frame(pat_df, pat_label_df, pid, valid_mask,
                                         label_col, pred_scores_ip,
                                         pred_scores_hirid, pred_scores_mimic)

            if configs["write_output"]:
                df_out.to_hdf(os.path.join(out_dir, batch_file),
                              "/p{}".format(pid), complevel=5,
                              complib="blosc:lz4", fletcher32=True)

            cum_scores["interpolated"].append(pred_scores_ip)
            cum_scores["valid"].append(pred_scores_hirid)
            cum_scores["retrain"].append(pred_scores_mimic)
            cum_labels.append(true_labels)
            n_valid_count += 1

        all_split_labels = np.concatenate(cum_labels)
        for model_key, model_scores in cum_scores.items():
            result_key = (mimic_ml_model, mimic_col_desc, mimic_split_key,
                          model_key)
            scores_dict[result_key] = np.concatenate(model_scores)
            labels_dict[result_key] = all_split_labels

        print("Number of processed prediction set PIDs: {}/{}".format(
            n_valid_count, len(pred_pids)),
              flush=True)

    if configs["plot_type"] == "NONE":
        sys.exit(0)

    _write_results_and_plot(configs, scores_dict, labels_dict)
Пример #3
0
def cluster_ml_input(configs):
    ''' Submit one feature-generation cluster job per data mode, split, label window and batch.

    For every combination of ``configs["DATA_SCHEMAS"]``,
    ``configs["SPLIT_SCHEMAS"]``, the label key "AllLabels",
    ``configs["LABEL_SCHEMAS"]`` and batch 0-49, a ``bsub`` command running
    ``configs["compute_script_path"]`` in CLUSTER mode is built.

    In a dry run (``configs["dry_run"]``) the commands are only printed. No
    output directory or log file is created or deleted. Otherwise the commands
    are submitted; with ``configs["debug_mode"]`` the process exits after the
    first submission.

    Exits the process with status 1 for an unknown dataset.
    '''
    job_index = 0
    # NOTE(review): this runs in a separate, short-lived shell, so it cannot
    # activate an environment for this process or for the jobs submitted below.
    subprocess.call(["source activate default_py36"], shell=True)
    mem_in_mbytes = configs["mbytes_per_job"]
    n_cpu_cores = 1
    n_compute_hours = configs["hours_per_job"]
    is_dry_run = configs["dry_run"]
    bad_hosts = ["lm-a2-003", "lm-a2-004"]
    # Host exclusions are identical for every job, so build them once.
    host_exclusions = " ".join(
        '-R "select[hname!=\'{}\']"'.format(bad_host) for bad_host in bad_hosts)

    if configs["dataset"] == "bern":
        features_output_dir = configs["bern_features_dir"]
    elif configs["dataset"] == "mimic":
        features_output_dir = configs["mimic_features_dir"]
    else:
        print("ERROR: Invalid data-set specified..")
        sys.exit(1)
    features_tag = features_output_dir.split("/")[-1]

    if not is_dry_run and not configs["preserve_logs"]:
        print("Deleting previous log files...")
        for logf in os.listdir(configs["log_dir"]):
            os.remove(os.path.join(configs["log_dir"], logf))

    for reduce_config in configs["DATA_SCHEMAS"]:

        for split_key in configs["SPLIT_SCHEMAS"]:

            if reduce_config == "reduced":
                split_dir = os.path.join(features_output_dir, "reduced",
                                         split_key)
            else:
                split_dir = os.path.join(features_output_dir, split_key)

            if not is_dry_run:
                mlhc_fs.create_dir_if_not_exist(split_dir)

            for label_key in ["AllLabels"]:

                for lhours, rhours in configs["LABEL_SCHEMAS"]:

                    output_base_key = "{}_{}_{}".format(
                        label_key, float(lhours), float(rhours))
                    ml_output_dir = os.path.join(split_dir, output_base_key)
                    X_output_dir = os.path.join(ml_output_dir, "X")
                    y_output_dir = os.path.join(ml_output_dir, "y")

                    if not is_dry_run:
                        mlhc_fs.create_dir_if_not_exist(ml_output_dir)
                        mlhc_fs.create_dir_if_not_exist(X_output_dir)
                        mlhc_fs.create_dir_if_not_exist(y_output_dir)

                    for batch_idx in range(50):

                        print(
                            "Create features for split {} with reduced data: {}, label: {} [{},{}], batch: {}"
                            .format(split_key, reduce_config, label_key,
                                    lhours, rhours, batch_idx))
                        job_name = "featgen_{}_{}_{}_{}_{}_{}_{}".format(
                            split_key, reduce_config, label_key, lhours,
                            rhours, batch_idx, features_tag)
                        log_result_file = os.path.join(
                            configs["log_dir"],
                            "{}_RESULT.txt".format(job_name))

                        # A dry run must not touch the file system.
                        if not is_dry_run:
                            mlhc_fs.delete_if_exist(log_result_file)

                        cmd_line = " ".join([
                            "bsub", "-R",
                            "rusage[mem={}]".format(mem_in_mbytes), "-n",
                            "{}".format(n_cpu_cores), "-r", "-W",
                            "{}:00".format(n_compute_hours), host_exclusions,
                            "-J", "{}".format(job_name), "-o",
                            log_result_file, "python3",
                            configs["compute_script_path"],
                            "--run_mode CLUSTER",
                            "--dataset {}".format(configs["dataset"]),
                            "--missing_values_mode {}".format(
                                configs["missing_values_mode"]),
                            "--split_key {}".format(split_key),
                            "--data_mode {}".format(reduce_config),
                            "--label_key {}".format(label_key),
                            "--lhours {}".format(lhours),
                            "--rhours {}".format(rhours),
                            "--batch_idx {}".format(batch_idx)
                        ])
                        # Guard against destructive commands in the shell string.
                        assert " rm " not in cmd_line
                        job_index += 1

                        if is_dry_run:
                            print("CMD: {}".format(cmd_line))
                        else:
                            subprocess.call([cmd_line], shell=True)

                            if configs["debug_mode"]:
                                sys.exit(0)

    print("Generated {} jobs...".format(job_index))
Пример #4
0
def cluster_labels(configs):
    '''Submit one LSF (``bsub``) label-generation job per data mode, split,
    label schema and patient batch.

    For every combination of ``configs["DATA_MODES"]``,
    ``configs["SPLIT_MODES"]``, the label key ``"AllLabels"``,
    ``configs["LABEL_SCHEMAS"]`` and batch indices 0..49, a ``bsub`` command
    running ``configs["compute_script_path"]`` with ``--run_mode CLUSTER`` is
    built. In dry-run mode the command is only printed; otherwise the output
    directories are created and the command is executed through the shell.

    Args:
        configs: dict-like configuration. Keys read include ``dataset``
            (``"bern"`` or ``"mimic"``), ``<dataset>_output_label_path``,
            ``compute_script_path``, ``compute_mem``, ``compute_n_cores``,
            ``compute_n_hours``, ``dry_run``, ``preserve_logs``, ``log_dir``,
            ``DATA_MODES``, ``SPLIT_MODES``, ``LABEL_SCHEMAS`` and
            ``debug_mode``.

    Exits the process with status 1 if ``configs["dataset"]`` is unknown, and
    with status 0 right after the first submitted job when
    ``configs["debug_mode"]`` is set.
    '''
    # Resolve the output root first so that an unknown dataset fails before any
    # side effect (log deletion, directory creation, job submission) happens.
    if configs["dataset"] == "bern":
        label_base_path = configs["bern_output_label_path"]
    elif configs["dataset"] == "mimic":
        label_base_path = configs["mimic_output_label_path"]
    else:
        print("ERROR: Invalid data-set specified")
        sys.exit(1)

    compute_script_path = configs["compute_script_path"]
    job_index = 0
    # NOTE(review): this activates the environment only inside a throw-away
    # child shell, so it does not change the environment of the shells that
    # submit jobs below; kept as-is.
    subprocess.call(["source activate default_py36"], shell=True)
    mem_in_mbytes = configs["compute_mem"]
    n_cpu_cores = configs["compute_n_cores"]
    n_compute_hours = configs["compute_n_hours"]
    bad_hosts = ["lm-a2-003", "lm-a2-004"]
    # LSF resource requirements keeping the jobs off the hosts listed above.
    host_exclusions = " ".join(
        '-R "select[hname!=\'{}\']"'.format(bad_host) for bad_host in bad_hosts)
    is_dry_run = configs["dry_run"]

    if not is_dry_run and not configs["preserve_logs"]:
        print("Deleting previous log-files...")
        for logf in os.listdir(configs["log_dir"]):
            os.remove(os.path.join(configs["log_dir"], logf))

    for reduce_config in configs["DATA_MODES"]:
        for split_key in configs["SPLIT_MODES"]:

            if reduce_config == "reduced":
                split_base_dir = os.path.join(label_base_path, "reduced",
                                              split_key)
            else:
                split_base_dir = os.path.join(label_base_path, split_key)

            # A dry run only prints the commands; it creates no directories.
            if not is_dry_run:
                mlhc_fs.create_dir_if_not_exist(split_base_dir)

            for label_key in ["AllLabels"]:
                label_base_dir = os.path.join(split_base_dir, label_key)

                if not is_dry_run:
                    mlhc_fs.create_dir_if_not_exist(label_base_dir)

                for lhours, rhours in configs["LABEL_SCHEMAS"]:

                    # One job per batch index 0..49.
                    for batch_idx in range(50):

                        print(
                            "Create label patient data for split {} with reduced data: {}, label: {} [{},{}], batch {}"
                            .format(split_key, reduce_config, label_key,
                                    lhours, rhours, batch_idx))
                        job_name = "labelgen_{}_{}_{}_{}_{}_{}".format(
                            split_key, reduce_config, label_key, lhours,
                            rhours, batch_idx)
                        log_result_file = os.path.join(
                            configs["log_dir"],
                            "{}_RESULT.txt".format(job_name))
                        mlhc_fs.delete_if_exist(log_result_file)
                        cmd_line = " ".join([
                            "bsub", "-R",
                            "rusage[mem={}]".format(mem_in_mbytes), "-n",
                            "{}".format(n_cpu_cores), "-r", "-W",
                            "{}:00".format(n_compute_hours), host_exclusions,
                            "-J", "{}".format(job_name), "-o",
                            log_result_file, "python3", compute_script_path,
                            "--run_mode CLUSTER",
                            "--split_key {}".format(split_key),
                            "--data_mode {}".format(reduce_config),
                            "--label_key {}".format(label_key),
                            "--lhours {}".format(lhours),
                            "--dataset {}".format(configs["dataset"]),
                            "--rhours {}".format(rhours),
                            "--batch_idx {}".format(batch_idx)
                        ])
                        # Safety net against accidentally destructive commands.
                        assert (" rm " not in cmd_line)
                        job_index += 1

                        if is_dry_run:
                            print("CMD: {}".format(cmd_line))
                        else:
                            subprocess.call([cmd_line], shell=True)

                            if configs["debug_mode"]:
                                sys.exit(0)

    print("Number of generated jobs: {}".format(job_index))
Пример #5
0
def cluster_learning_serial(configs, best_hps=None):
    '''Submit one LSF (``bsub``) ML-fitting job per data configuration,
    split, task and label schema.

    Tasks are fixed to ``["WorseStateFromZero"]`` and the label schema to
    ``[(0, 8)]``. For each combination, a prediction directory below
    ``configs["pred_dir"]`` is created (skipped in dry-run mode). The job's
    previous result log is deleted. Then a ``bsub`` command running
    ``configs["compute_script_path"]`` with ``--run_mode CLUSTER`` is built.
    In dry-run mode the command is printed; otherwise it is executed through
    the shell.

    Args:
        configs: dict-like configuration. Keys read include ``ml_model``
            (``"lightgbm"`` or ``"logreg"``), ``col_desc``, ``random_seed``,
            ``mbytes_per_job``, ``num_cpu_cores``, ``hours_per_job``,
            ``dry_run``, ``preserve_logs``, ``log_dir``, ``pred_dir``,
            ``DATA_CONFIGS``, ``SPLIT_CONFIGS``, ``xinrui_subsample``,
            ``compute_script_path``, ``debug_mode`` and, for LightGBM, the
            options forwarded to the compute script.
        best_hps: optional hyperparameter mapping. It is required when
            ``ml_model == "logreg"``; its ``"alpha"`` entry is passed as
            ``--logreg_alpha``. It is ignored for LightGBM.

    Raises:
        ValueError: if ``ml_model`` is unsupported, or is ``"logreg"`` and no
            ``best_hps`` is given.

    Exits the process with status 0 right after the first submitted job when
    ``configs["debug_mode"]`` is set.
    '''
    ml_model = configs["ml_model"]

    # Validate up front so that a bad configuration fails before log files are
    # deleted or directories are created.
    if ml_model not in ("lightgbm", "logreg"):
        raise ValueError("Unsupported ml_model: {}".format(ml_model))
    if ml_model == "logreg" and best_hps is None:
        raise ValueError(
            "best_hps with an 'alpha' entry is required for ml_model 'logreg'")

    LABEL_SCHEMAS = [(0, 8)]
    ALL_TASKS = ["WorseStateFromZero"]
    # Boolean LightGBM options forwarded as "--<key>" when configs[<key>] is set.
    LGBM_BOOL_FLAGS = ("add_shapelets", "negative_subsampling", "use_catboost",
                       "50percent_sample_train", "20percent_sample_train",
                       "10percent_sample_train", "5percent_sample_train",
                       "1percent_sample_train", "0.1percent_sample_train",
                       "decision_tree_mode", "logreg_mode", "mlp_mode",
                       "decision_tree_baseline", "1percent_sample")

    random.seed(configs["random_seed"])
    job_index = 0
    # NOTE(review): this activates the environment only inside a throw-away
    # child shell, so it does not change the environment of the shells that
    # submit jobs below; kept as-is.
    subprocess.call(["source activate default_py36"], shell=True)
    mem_in_mbytes = configs["mbytes_per_job"]
    n_cpu_cores = configs["num_cpu_cores"]
    n_compute_hours = configs["hours_per_job"]
    is_dry_run = configs["dry_run"]
    col_desc = configs["col_desc"]
    bad_hosts = ["lm-a2-003", "lm-a2-004"]
    # LSF resource requirements keeping the jobs off the hosts listed above.
    host_exclusions = " ".join(
        '-R "select[hname!=\'{}\']"'.format(bad_host) for bad_host in bad_hosts)

    if not is_dry_run and not configs["preserve_logs"]:
        print("Deleting previous log files...")
        for logf in os.listdir(configs["log_dir"]):
            os.remove(os.path.join(configs["log_dir"], logf))

    for reduce_config in configs["DATA_CONFIGS"]:
        for split_key in configs["SPLIT_CONFIGS"]:

            if reduce_config == "reduced":
                split_dir = os.path.join(configs["pred_dir"], "reduced", split_key)
            else:
                split_dir = os.path.join(configs["pred_dir"], split_key)

            if not is_dry_run:
                mlhc_fs.create_dir_if_not_exist(split_dir)

            for task_key in ALL_TASKS:
                for lhours, rhours in LABEL_SCHEMAS:

                    subsample_tag = "xinrui" if configs["xinrui_subsample"] else "full"
                    output_base_key = "{}_{}_{}_{}_{}_{}".format(
                        task_key, float(lhours), float(rhours), col_desc,
                        ml_model, subsample_tag)
                    pred_output_dir = os.path.join(split_dir, output_base_key)

                    if not is_dry_run:
                        mlhc_fs.create_dir_if_not_exist(pred_output_dir)

                    print("Fit ML model for split {} with reduced data: {}, task: {} [{},{}], ML model: {}".format(
                        split_key, reduce_config, task_key, lhours, rhours, ml_model))

                    job_name = "mlfit_{}_{}_{}_{}_{}_{}_{}".format(
                        col_desc, split_key, reduce_config, task_key, lhours,
                        rhours, ml_model)
                    log_result_file = os.path.join(
                        configs["log_dir"], "{}_RESULT.txt".format(job_name))
                    mlhc_fs.delete_if_exist(log_result_file)

                    if ml_model == "lightgbm":
                        cmd_parts = [
                            "bsub", "-R", "span[hosts=1]",
                            "-R", "rusage[mem={}]".format(mem_in_mbytes),
                            "-n", "{}".format(n_cpu_cores), "-r",
                            "-W", "{}:00".format(n_compute_hours),
                            "-J", "{}".format(job_name), "-o", log_result_file,
                            host_exclusions,
                            "python3", configs["compute_script_path"], "--run_mode CLUSTER",
                            "--ml_model {}".format(ml_model),
                            "--split_key {}".format(split_key),
                            "--data_mode {}".format(reduce_config),
                            "--special_development_split {}".format(configs["special_development_split"]),
                            "--column_set {}".format(col_desc)]
                        cmd_parts.extend("--{}".format(flag) if configs[flag] else ""
                                         for flag in LGBM_BOOL_FLAGS)
                        cmd_parts.extend([
                            "--dataset {}".format(configs["dataset"]),
                            "--mimic_split_key {}".format(configs["mimic_split_key"]),
                            ("--special_year {}".format(configs["special_year"])
                             if configs["special_year"] is not None else ""),
                            "--special_test_set {}".format(configs["special_test_set"]),
                            "--task_key {}".format(task_key),
                            "--lhours {}".format(lhours),
                            "--rhours {}".format(rhours),
                            "--ml_model {}".format(ml_model)])
                    else:  # "logreg"; other models were rejected above
                        cmd_parts = [
                            "bsub", "-R", "span[hosts=1]",
                            "-R", "rusage[mem={}]".format(mem_in_mbytes),
                            "-R", "rusage[ngpus_excl_t=1]",
                            "-n", "{}".format(n_cpu_cores), "-r",
                            "-W", "{}:00".format(n_compute_hours),
                            "-J", "{}".format(job_name), "-o", log_result_file,
                            "python3", configs["compute_script_path"], "--run_mode CLUSTER",
                            "--ml_model {}".format(ml_model),
                            "--split_key {}".format(split_key),
                            "--data_mode {}".format(reduce_config),
                            "--column_set {}".format(col_desc),
                            "--logreg_alpha {}".format(best_hps["alpha"]),
                            "--task_key {}".format(task_key),
                            "--lhours {}".format(lhours),
                            "--rhours {}".format(rhours)]

                    # Empty strings stand for switched-off options; drop them.
                    cmd_line = " ".join(part for part in cmd_parts if part)
                    # Safety net against accidentally destructive commands.
                    assert (" rm " not in cmd_line)
                    job_index += 1

                    if is_dry_run:
                        print("CMD: {}".format(cmd_line))
                    else:
                        subprocess.call([cmd_line], shell=True)

                        if configs["debug_mode"]:
                            sys.exit(0)

    print("Generated {} jobs...".format(job_index))
Пример #6
0
def save_ml_input(configs):
    '''Generate ML feature and label matrices for one patient batch and write
    them to HDF5.

    The imputed data and the label data of batch ``configs["batch_idx"]`` are
    read patient by patient. Each patient is transformed with the feature model
    selected by ``configs["missing_values_mode"]`` (``"finite"`` or
    ``"missing"``). The resulting frames are written to
    ``<ml_output_dir>/X/batch_<idx>.h5`` and ``<ml_output_dir>/y/batch_<idx>.h5``
    under the HDF key ``/<pid>``. Nothing is written when
    ``configs["debug_mode"]`` is set.

    Args:
        configs: dict-like job configuration. Keys read include ``split_key``,
            ``data_mode`` (``"reduced"`` or ``"non_reduced"``), ``label_key``,
            ``lhours``, ``rhours``, ``batch_idx``, ``dataset`` (``"bern"`` or
            ``"mimic"``), the dataset-specific input/output paths, the
            variable-encoding/parameter map paths, ``temporal_split_path``,
            ``missing_values_mode``, ``impute_grid_period_secs``,
            ``hdf_comp_level``, ``hdf_comp_alg``, ``verbose`` and
            ``debug_mode``.

    Returns:
        0. This includes the case where the batch's input or label file does
        not exist (nothing is written then).

    Exits the process with status 1 on an unknown ``dataset`` or
    ``missing_values_mode``. An unknown ``data_mode`` fails the assertion below.
    '''
    split_key = configs["split_key"]
    reduced_data_str = configs["data_mode"]
    label_key = configs["label_key"]
    lhours = configs["lhours"]
    rhours = configs["rhours"]
    batch_idx = configs["batch_idx"]

    assert(reduced_data_str in ["reduced", "non_reduced"])
    dim_reduced_data = reduced_data_str == "reduced"

    print("Job SPLIT: {}, REDUCED?: {}, LABEL: {}, INTERVAL: [{},{}], BATCH: {}".format(split_key, dim_reduced_data, label_key, lhours, rhours, batch_idx), flush=True)

    output_base_key = "{}_{}_{}".format(label_key, float(lhours), float(rhours))

    if configs["dataset"] == "bern":
        imputed_base_path = configs["bern_imputed_path"]
        label_base_path = configs["bern_label_path"]
        output_base_path = configs["bern_output_features_path"]
    elif configs["dataset"] == "mimic":
        imputed_base_path = configs["mimic_imputed_path"]
        label_base_path = configs["mimic_label_path"]
        output_base_path = configs["mimic_output_features_path"]
    else:
        print("ERROR: Invalid data-set specified")
        sys.exit(1)

    # Label interval directory, e.g. lhours=0, rhours=8 -> "0.0To8.0Hours".
    label_interval_dir = "{:.1f}To{:.1f}Hours".format(int(lhours), int(rhours))
    # Only reduced data has a pharma acting-period map configured here.
    # NOTE(review): confirm whether the non-reduced feature model needs one.
    pharma_dict = None

    if dim_reduced_data:
        fmat_dir = os.path.join(imputed_base_path, "reduced", split_key)
        lmat_path = os.path.join(label_base_path, "reduced", split_key, label_key, label_interval_dir)
        ml_output_dir = os.path.join(output_base_path, "reduced", split_key, output_base_key)
        var_encoding_dict = mlhc_io.load_pickle(configs["meta_varenc_map_path"])
        var_parameter_dict = mlhc_io.load_pickle(os.path.join(configs["meta_varprop_map_path"], "interval_median_{}.pickle".format(split_key)))
        pharma_dict = np.load(configs["pharma_acting_period_map_path"], allow_pickle=True).item()
    else:
        fmat_dir = os.path.join(imputed_base_path, split_key)
        lmat_path = os.path.join(label_base_path, split_key, label_key, label_interval_dir)
        ml_output_dir = os.path.join(output_base_path, split_key, output_base_key)
        var_encoding_dict = mlhc_io.load_pickle(configs["varenc_map_path"])
        var_parameter_dict = mlhc_io.load_pickle(os.path.join(configs["varprop_map_path"], "interval_median_{}.pickle".format(split_key)))

    data_split = mlhc_io.load_pickle(configs["temporal_split_path"])[split_key]

    if configs["dataset"] == "bern":
        all_pids = data_split["train"] + data_split["val"] + data_split["test"]
    elif configs["dataset"] == "mimic":
        all_pids = list(map(int, mlhc_io.read_list_from_file(configs["mimic_all_pid_list_path"])))

    if configs["verbose"]:
        print("Number of patient IDs: {}".format(len(all_pids)), flush=True)

    if configs["dataset"] == "bern":
        batch_map = mlhc_io.load_pickle(configs["bern_pid_map_path"])["chunk_to_pids"]
    elif configs["dataset"] == "mimic":
        batch_map = mlhc_io.load_pickle(configs["mimic_pid_map_path"])["chunk_to_pids"]

    pids_batch = batch_map[batch_idx]
    selected_pids = list(set(pids_batch).intersection(all_pids))
    print("Number of selected PIDs for this batch: {}".format(len(selected_pids)), flush=True)
    batch_path = os.path.join(fmat_dir, "batch_{}.h5".format(batch_idx))
    lab_path = os.path.join(lmat_path, "batch_{}.h5".format(batch_idx))
    n_skipped_patients = 0

    # Check both inputs before building the model or creating output dirs.
    if not os.path.exists(batch_path):
        print("WARNING: No input data for batch, skipping...", flush=True)
        print("Generated path: {}".format(batch_path))
        return 0

    if not os.path.exists(lab_path):
        print("WARNING: No input label data for batch, skipping...", flush=True)
        return 0

    if configs["missing_values_mode"] == "finite":
        tf_model = bern_tf_features.Features(dim_reduced_data=dim_reduced_data,
                                             impute_grid_unit=configs["impute_grid_period_secs"],
                                             dataset=configs["dataset"])
    elif configs["missing_values_mode"] == "missing":
        tf_model = bern_tf_features_nan.FeaturesWithMissingVals(dim_reduced_data=dim_reduced_data,
                                                                impute_grid_unit=configs["impute_grid_period_secs"],
                                                                dataset=configs["dataset"])
    else:
        print("ERROR: Invalid missing value mode specified...")
        sys.exit(1)

    tf_model.set_varencoding_dict(var_encoding_dict)
    tf_model.set_varparameters_dict(var_parameter_dict)
    if pharma_dict is not None:
        tf_model.set_pharma_dict(pharma_dict)

    X_output_dir = os.path.join(ml_output_dir, "X")
    y_output_dir = os.path.join(ml_output_dir, "y")
    X_batch_path = os.path.join(X_output_dir, "batch_{}.h5".format(batch_idx))
    y_batch_path = os.path.join(y_output_dir, "batch_{}.h5".format(batch_idx))

    if not configs["debug_mode"]:
        try:
            mlhc_fs.create_dir_if_not_exist(X_output_dir, recursive=True)
            mlhc_fs.create_dir_if_not_exist(y_output_dir, recursive=True)
        except OSError:
            # Parallel batch jobs may try to create the same directory.
            print("WARNING: Race condition when creating directory from different jobs...")

    first_write = True
    t_begin = timeit.default_timer()

    for pidx, pid in enumerate(selected_pids):
        df_pat = pd.read_hdf(batch_path, mode='r', where="PatientID={}".format(pid))
        df_label_pat = pd.read_hdf(lab_path, mode='r', where="PatientID={}".format(pid))

        if df_pat.shape[0] == 0 or df_label_pat.shape[0] == 0:
            n_skipped_patients += 1
            continue

        assert(df_pat.shape[0] == df_label_pat.shape[0])
        assert(is_df_sorted(df_pat, "AbsDatetime"))
        assert(is_df_sorted(df_label_pat, "AbsDatetime"))
        df_X, df_y = tf_model.transform(df_pat, df_label_pat, pid=pid)

        if df_X is None or df_y is None:
            print("WARNING: Features could not be generated for PID: {}".format(pid), flush=True)
            n_skipped_patients += 1
            continue

        assert(df_X.shape[0] == df_y.shape[0])
        assert(df_X.shape[0] == df_pat.shape[0])

        # The first written patient truncates the batch files; later ones append.
        open_mode = 'w' if first_write else 'a'

        if not configs["debug_mode"]:
            df_X.to_hdf(X_batch_path, "/{}".format(pid), format="fixed", append=False, mode=open_mode,
                        complevel=configs["hdf_comp_level"], complib=configs["hdf_comp_alg"])
            df_y.to_hdf(y_batch_path, "/{}".format(pid), format="fixed", append=False, mode=open_mode,
                        complevel=configs["hdf_comp_level"], complib=configs["hdf_comp_alg"])

        first_write = False

        print("Job {}: {:.2f} %".format(batch_idx, (pidx + 1) / len(selected_pids) * 100), flush=True)
        t_current = timeit.default_timer()
        tpp = (t_current - t_begin) / (pidx + 1)
        eta_minutes = (len(selected_pids) - (pidx + 1)) * tpp / 60.0
        print("ETA [minutes]: {:.2f}".format(eta_minutes), flush=True)
        print("Number of skipped patients: {}".format(n_skipped_patients), flush=True)

    return 0