Esempio n. 1
0
def output_sim_data(model, surv, X_train, df_train, X_test, df_test):
    """ Compute evaluation metrics of the model on the test set.
    # Arguments
        model: neural network model trained with final parameters (not used
            here; kept for interface compatibility).
        surv: survival-function estimates for the test set, indexed by time.
        X_train: input variables of the training set (not used here; kept
            for interface compatibility).
        df_train: training dataset; must provide 'status' (event indicator)
            and 'yy' (observed time) columns.
        X_test: input variables of the test set (not used here; kept for
            interface compatibility).
        df_test: test dataset; must provide 'status' and 'yy' columns.
    # Returns
        res: one-row DataFrame with Uno C-index at the median survival time
            ('c_median') and the Integrated Brier Score ('ibs').
    """
    # Evaluate the Brier score only between the 10th and 90th percentiles of
    # the observed test times, where the estimate is stable.
    time_grid = np.linspace(np.percentile(df_test['yy'], 10),
                            np.percentile(df_test['yy'], 90), 100)
    median_time = np.percentile(df_test['yy'], 50)
    data_train = skSurv.from_arrays(event=df_train['status'],
                                    time=df_train['yy'])
    data_test = skSurv.from_arrays(event=df_test['status'], time=df_test['yy'])

    # Uno's IPCW C-index expects risk scores: negate the survival probability
    # at the median time so higher risk means lower predicted survival.
    c_med = concordance_index_ipcw(
        data_train, data_test,
        np.array(-determine_surv_prob(surv, median_time)), median_time)[0]
    ev = EvalSurv(surv,
                  np.array(df_test['yy']),
                  np.array(df_test['status']),
                  censor_surv='km')
    ibs = ev.integrated_brier_score(time_grid)
    res = pd.DataFrame([c_med, ibs]).T
    res.columns = ['c_median', 'ibs']
    return res
    def evaluate(self, predictions, predictions_interpolated):
        """
        Compute evaluation metrics (time-dependent C-index, integrated Brier
        score) on the test split and log them to MLflow via the Logger module.

        # Arguments
            predictions: raw model predictions; logged to MLflow as-is.
            predictions_interpolated: survival predictions interpolated onto
                a continuous time axis; used for metric computation.
        """
        self.logger.info("Evaluating Model Strength.. \n")
        # Fetch the test targets once instead of calling dataset_2 three
        # times; element [0] holds (targets, events, survival times).
        targets, events, survs = dataset_2(self.dataset_filtered, phase='test',
                                           targets=True)[0][:3]
        self.logger.info("Calculating concordance and brier score\n")
        # Kaplan-Meier censoring estimation ('km') for the evaluation.
        ev = EvalSurv(predictions_interpolated, survs, events, 'km')
        concordance = ev.concordance_td()
        time_grid = np.linspace(0, survs.max())
        integrated_brier_score = ev.integrated_brier_score(time_grid)
        self.logger.info(
            "Concordance: {}, Integrated Brier Score: {}\n".format(
                concordance, integrated_brier_score))

        logger = Logger(self.mlflow_url, "CNN_Predictions")
        logger.log_predictions(predictions, "{}".format(self.mode),
                               concordance, integrated_brier_score,
                               self.params, targets, self.cuts)
        self.logger.info("Logged Evaluation metrics to MLFlow\n")
def metric_fn_par(args):
    """Compute survival evaluation metrics for one prediction/target pair.

    # Arguments
        args: tuple (predicted_test_times, predicted_survival_functions,
            observed_y) where observed_y[:, 0] holds observed times and
            observed_y[:, 1] holds event indicators (1 = event, 0 = censored).
    # Returns
        results: list [concordance_antolini, concordance_median,
            integrated_brier, rmse, mae]; rmse and mae are 0 when every
            sample is censored.
    """
    predicted_test_times, predicted_survival_functions, observed_y = args

    # predicted survival function dim: n_train * n_test
    obs_test_times = observed_y[:, 0].astype('float32')
    obs_test_events = observed_y[:, 1].astype('float32')
    # Boolean mask of uncensored (event) samples. `np.bool` was removed in
    # NumPy >= 1.24; the builtin `bool` is the documented replacement.
    event_mask = obs_test_events.astype(bool)
    results = [0, 0, 0, 0, 0]

    ev = EvalSurv(predicted_survival_functions, obs_test_times, obs_test_events, censor_surv='km')
    results[0] = ev.concordance_td('antolini')  # concordance_antolini
    results[1] = concordance_index(obs_test_times, predicted_test_times, event_mask)  # concordance_median

    # we ignore brier scores at the highest test times because it becomes unstable
    time_grid = np.linspace(obs_test_times.min(), obs_test_times.max(), 100)[:80]
    results[2] = ev.integrated_brier_score(time_grid)  # integrated_brier

    if event_mask.any():
        # only noncensored samples are used for rmse/mae calculation
        pred_obs_differences = predicted_test_times[event_mask] - obs_test_times[event_mask]
        results[3] = np.sqrt(np.mean(pred_obs_differences ** 2))  # rmse
        results[4] = np.mean(np.abs(pred_obs_differences))  # mae
    else:
        # results[3] and results[4] keep their initial value of 0.
        print("[WARNING] All samples are censored.")

    return results
Esempio n. 4
0
def Coxnnet_evaluate(model, x1, x2, durations, events):
    """Return the time-dependent concordance of a Cox-nnet style model.

    Recomputes the baseline hazards, predicts survival curves for the given
    inputs (the pair (x1, x2) when x2 is provided, otherwise x1 alone), and
    scores them with Kaplan-Meier censoring estimation.
    """
    model.compute_baseline_hazards()
    model_input = x1 if x2 is None else (x1, x2)
    surv_curves = model.predict_surv_df(model_input)
    evaluator = EvalSurv(surv_curves, durations, events, censor_surv='km')
    return evaluator.concordance_td()
def get_metrics(val_data, surv_pred, time_grid):
    """Return a one-row DataFrame of survival evaluation metrics.

    Scores `surv_pred` against the observed times (`val_data['t']`) and
    event indicators (`val_data['y']`) using Kaplan-Meier censoring
    estimation, over the supplied `time_grid`.
    """
    evaluator = EvalSurv(surv_pred, val_data['t'], val_data['y'],
                         censor_surv='km')
    scores = {
        'dt_c_index': evaluator.concordance_td('antolini'),
        'int_brier_score': evaluator.integrated_brier_score(time_grid),
        'int_nbill': evaluator.integrated_nbll(time_grid),
    }
    return pd.DataFrame([scores])
Esempio n. 6
0
def test_quality(t_true,
                 y_true,
                 pred,
                 time_grid=np.linspace(0, 300, 30, dtype=int),
                 concordance_at_t=None,
                 plot=False):
    """Evaluate Weibull-parameterised survival predictions.

    # Arguments
        t_true: observed times.
        y_true: event indicators.
        pred: array of shape (n_samples, 2); pred[:, 0] is used as the scale
            and pred[:, 1] as the shape of a Weibull survival curve.
        time_grid: times at which survival probabilities are evaluated.
            (`dtype=int` replaces `np.int`, which was removed in
            NumPy >= 1.24.)
        concordance_at_t: if given, also compute Harrell's C-index from the
            survival probabilities at this time point.
        plot: if True, show sample survival curves, NBLL and Brier score.
    # Returns
        One-row DataFrame with the computed metrics.
    """
    # Survival probability S(t) = exp(-(t / scale)^shape); the epsilon
    # guards against division by zero. Rows are built once and stacked,
    # avoiding a quadratic pd.concat-in-a-loop.
    rows = [np.exp(-np.power(t / (pred[:, 0] + 1e-6), pred[:, 1]))
            for t in time_grid]
    all_surv_time = pd.DataFrame(np.vstack(rows), index=time_grid)

    ev = EvalSurv(surv=all_surv_time,
                  durations=t_true,
                  events=y_true,
                  censor_surv='km')
    dt_c_index = ev.concordance_td('antolini')
    int_brier_score = ev.integrated_brier_score(time_grid)
    int_nbill = ev.integrated_nbll(time_grid)

    if plot:
        fig, ax = plt.subplots(1, 3, figsize=(20, 7))
        # Plot a random handful of survival curves (skipping time 0).
        d = all_surv_time.sample(5, axis=1).loc[1:]
        obs = d.columns
        for o in obs:
            ax[0].plot(d.index, d[o])
        ax[0].set_xlabel('Time')
        ax[0].set_title("Sample survival curves")
        nb = ev.nbll(time_grid)
        ax[1].plot(time_grid, nb)
        ax[1].set_title('NBLL')
        ax[1].set_xlabel('Time')
        br = ev.brier_score(time_grid)
        ax[2].plot(time_grid, br)
        ax[2].set_title('Brier score')
        ax[2].set_xlabel('Time')
        plt.show()

    metrics_row = {
        'dt_c_index': dt_c_index,
        'int_brier_score': int_brier_score,
        'int_nbill': int_nbill
    }
    if concordance_at_t is not None:
        # Harrell's C-index on the survival probabilities at the requested
        # time point; placed first to keep the original column order.
        harell_c_index = concordance_index(
            predicted_scores=all_surv_time.loc[concordance_at_t, :].values,
            event_times=t_true,
            event_observed=y_true)
        metrics_row = {'harell_c_index': harell_c_index, **metrics_row}
    return pd.DataFrame([metrics_row])
Esempio n. 7
0
    def Bootstrap(self, surv, event: list, duration: list):
        """Bootstrap evaluation metrics over resampled subjects.

        # Arguments
            surv: DataFrame of survival curves, one column per subject.
            event: per-subject event indicators.
            duration: per-subject observed times.
        # Returns
            Tuple (cindex, brier, nbll) of per-replicate metric lists, each
            of length `self.bootstrap_n`.
        """
        # Sample with NumPy's RNG so this seed actually controls
        # reproducibility; the previous `random.choices` call ignored
        # `np.random.seed` entirely.
        np.random.seed(42)  # control reproducibility

        n_subjects = surv.shape[1]
        cindex, brier, nbll = [], [], []
        for _ in range(self.bootstrap_n):
            # Resample subjects with replacement (standard bootstrap).
            sampled_index = np.random.choice(n_subjects, size=n_subjects,
                                             replace=True)

            sampled_surv = surv.iloc[:, sampled_index]
            sampled_event = [event[i] for i in sampled_index]
            sampled_duration = [duration[i] for i in sampled_index]

            ev = EvalSurv(sampled_surv, np.array(sampled_duration),
                          np.array(sampled_event).astype(int), censor_surv='km')
            time_grid = np.linspace(min(sampled_duration), max(sampled_duration), 100)

            cindex.append(ev.concordance_td('antolini'))
            brier.append(ev.integrated_brier_score(time_grid))
            nbll.append(ev.integrated_nbll(time_grid))

        return cindex, brier, nbll
def run_baseline(runs=10):
    """Run a Cox PH baseline `runs` times on PCA-transformed survival data.

    Each run reloads and re-splits the data, fits a CoxPHFitter on the
    training principal components, and scores the test set with the
    time-dependent concordance and integrated Brier score. Running averages
    are printed after every iteration, and the survival curves of the final
    fit are plotted.
    """
    concordance = []
    ibs = []

    for run_idx in tqdm(range(runs)):
        df_train, df_test, df_val = load_data("./summaries/survival_data")

        x_mapper, labtrans, train, val, x_test, durations_test, events_test, pca = transform_data(
            df_train, df_test, df_val, 'LogisticHazard', "standard", cols_standardize, log_columns, num_durations=100)
        x_train, y_train = train

        # Build lifelines-style frames: one column per principal component,
        # plus duration/event targets on the training frame.
        pc_col = ['PC' + str(i) for i in range(x_train.shape[1])]
        cox_train = pd.DataFrame(x_train, columns=pc_col)
        cox_test = pd.DataFrame(x_test, columns=pc_col)

        cox_train.loc[:, ["duration"]] = y_train[0]
        cox_train.loc[:, 'event'] = y_train[1]
        cox_train = cox_train.dropna()
        cox_test = cox_test.dropna()
        cph = CoxPHFitter().fit(cox_train, 'duration', 'event')
        surv = cph.predict_survival_function(cox_test)
        ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
        concordance.append(ev.concordance_td('antolini'))
        time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)
        ibs.append(ev.integrated_brier_score(time_grid))

        # Running averages over the runs completed so far.
        print("Average concordance: %s" % np.mean(concordance))
        print("Average IBS: %s" % np.mean(ibs))

    # Plot survival curves for the model from the final run.
    plot_survival(cox_train,
                  pc_col, cph, './survival/cox', baseline=True)
Esempio n. 9
0
def main():
    """Train and evaluate a DeepHitSingle survival model on signal data.

    Command-line driven: reads the dataset, splits it into train/val/test,
    fits a ResNet-18 backbone with best-weight checkpointing, then saves the
    loss history, the final network, the test-set survival predictions and
    the test concordance index under `args.save_path/args.model_name`.
    """
    parser = setup_parser()
    args = parser.parse_args()

    # Restrict visible GPUs before any CUDA context is created.
    if args.which_gpu != 'none':
        os.environ["CUDA_VISIBLE_DEVICES"] = args.which_gpu

    # Create the per-model output directory if it does not exist yet.
    if not os.path.exists(os.path.join(args.save_path, args.model_name)):
        os.mkdir(os.path.join(args.save_path, args.model_name))

    # Discretize event times into `args.durations` intervals for DeepHit.
    labtrans = DeepHitSingle.label_transform(args.durations)

    # Data-reading settings: column names of the survival targets.
    singnal_data_path = args.signal_dataset_path
    table_path = args.table_path
    time_col = 'SurvivalDays'
    event_col = 'Mortality'

    # dataset
    data_pathes, times, events = read_dataset(singnal_data_path, table_path,
                                              time_col, event_col,
                                              args.sample_ratio)

    # 70/30 train/test split, then a further 80/20 train/val split,
    # both with a fixed random_state for reproducibility.
    data_pathes_train, data_pathes_test, times_train, times_test, events_train, events_test = train_test_split(
        data_pathes, times, events, test_size=0.3, random_state=369)
    data_pathes_train, data_pathes_val, times_train, times_val, events_train, events_val = train_test_split(
        data_pathes_train,
        times_train,
        events_train,
        test_size=0.2,
        random_state=369)

    # Training targets are fit_transformed so the discretization grid is
    # learned from the training split only.
    labels_train = label_transfer(times_train, events_train)
    target_train = labtrans.fit_transform(*labels_train)
    dataset_train = VsDatasetBatch(data_pathes_train, *target_train)
    dl_train = tt.data.DataLoaderBatch(dataset_train,
                                       args.train_batch_size,
                                       shuffle=True)

    labels_val = label_transfer(times_val, events_val)
    target_val = labtrans.transform(*labels_val)
    dataset_val = VsDatasetBatch(data_pathes_val, *target_val)
    dl_val = tt.data.DataLoaderBatch(dataset_val,
                                     args.train_batch_size,
                                     shuffle=True)

    # Test labels stay untransformed: EvalSurv scores on original times.
    labels_test = label_transfer(times_test, events_test)
    dataset_test_x = VsTestInput(data_pathes_test)
    dl_test_x = DataLoader(dataset_test_x, args.test_batch_size, shuffle=False)

    net = resnet18(args)
    model = DeepHitSingle(net,
                          tt.optim.Adam(lr=args.lr,
                                        betas=(0.9, 0.999),
                                        eps=1e-08,
                                        weight_decay=5e-4,
                                        amsgrad=False),
                          duration_index=labtrans.cuts)
    # callbacks = [tt.cb.EarlyStopping(patience=15)]
    # Keep the best validation weights on disk instead of early stopping.
    callbacks = [
        tt.cb.BestWeights(file_path=os.path.join(
            args.save_path, args.model_name, args.model_name + '_bestWeight'),
                          rm_file=False)
    ]
    verbose = True
    model_log = model.fit_dataloader(dl_train,
                                     args.epochs,
                                     callbacks,
                                     verbose,
                                     val_dataloader=dl_val)

    # Persist run artifacts: CLI args, loss curve, final network weights.
    save_args(os.path.join(args.save_path, args.model_name), args)
    model_log.to_pandas().to_csv(os.path.join(args.save_path, args.model_name,
                                              'loss.csv'),
                                 index=False)
    model.save_net(
        path=os.path.join(args.save_path, args.model_name, args.model_name +
                          '_final'))
    surv = model.predict_surv_df(dl_test_x)
    surv.to_csv(os.path.join(args.save_path, args.model_name,
                             'test_sur_df.csv'),
                index=False)
    # Time-dependent concordance with Kaplan-Meier censoring estimation.
    ev = EvalSurv(surv, *labels_test, 'km')
    print(ev.concordance_td())
    save_cindex(os.path.join(args.save_path, args.model_name),
                ev.concordance_td())
    print('done')
Esempio n. 10
0
def pycox_deep(filename, Y_train, Y_test, opt, choice):
    """Train a CoxPH (DeepSurv-style) model on autoencoder features, score it
    and log the results to a CSV tracking file.

    # Arguments
        filename: data file passed to the trained autoencoder extractor.
        Y_train / Y_test: DataFrames with 'SVDTEPC_G' (time) and 'PC_YN'
            (event indicator) columns.
        opt: target/gender option string, used in filenames and log output.
        choice: hyper-parameter dict; expected keys include 'lr_rate',
            'batch', 'decay', 'weighted_decay', 'nodes' and 'index'.
    # Returns
        surv_final: list of predicted survival probabilities at each test
            subject's observed time.
    """
    # choice = {'lr_rate': l, 'batch': b, 'decay': 0, 'weighted_decay': wd, 'net': net, 'index': index}
    # `index` is used in the saved-file names below; it was previously an
    # undefined name (NameError) — take it from `choice` as the comment
    # above documents.
    index = choice['index']
    X_train, X_test = enc_using_trained_ae(filename, TARGET=opt, ALPHA=0.01, N_ITER=100, L1R=-9999)
    path = './models/analysis/'
    check = 0
    savename = 'model_check_autoen_m5_test_batch+dropout+wd.csv'
    # r=root, d=directories, f = files: detect whether the tracking CSV
    # already exists so results can be appended to it.
    for r, d, f in os.walk(path):
        for file in f:
            if savename in file:
                check = 1

    # Attach the survival targets to the feature frames (note: x_train/x_test
    # alias X_train/X_test, so the columns are added in place).
    x_train = X_train
    x_test = X_test
    x_train['SVDTEPC_G'] = Y_train['SVDTEPC_G']
    x_train['PC_YN'] = Y_train['PC_YN']
    x_test['SVDTEPC_G'] = Y_test['SVDTEPC_G']
    x_test['PC_YN'] = Y_test['PC_YN']

    ## DataFrameMapper ##
    # Pass features through unscaled (None transformer); targets excluded.
    cols_standardize = list(X_train.columns)
    cols_standardize.remove('SVDTEPC_G')
    cols_standardize.remove('PC_YN')

    standardize = [(col, None) for col in cols_standardize]

    x_mapper = DataFrameMapper(standardize)

    _ = x_mapper.fit_transform(X_train).astype('float32')
    X_train = x_mapper.transform(X_train).astype('float32')
    X_test = x_mapper.transform(X_test).astype('float32')

    get_target = lambda df: (df['SVDTEPC_G'].values, df['PC_YN'].values)
    y_train = get_target(x_train)
    durations_test, events_test = get_target(x_test)
    in_features = X_train.shape[1]
    print(in_features)
    num_nodes = choice['nodes']
    out_features = 1
    batch_norm = True  # False disables batch normalization
    dropout = 0.01
    output_bias = False
    net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm, dropout, output_bias=output_bias)

    print("training")
    model = CoxPH(net, tt.optim.Adam)
    model.optimizer.set_lr(choice['lr_rate'])

    weighted_decay = choice['weighted_decay']
    verbose = True

    batch_size = choice['batch']
    epochs = 100

    if weighted_decay == 0:
        callbacks = [tt.callbacks.EarlyStopping(patience=epochs)]
    else:
        callbacks = [tt.callbacks.DecoupledWeightDecay(weight_decay=choice['decay'])]

    # Build the dataloader manually with a StratifiedSampler instead of
    # calling model.fit directly (presumably to balance event/censored
    # samples within batches — confirm against StratifiedSampler).
    datas = tt.tuplefy(X_train, y_train).to_tensor()
    print(datas)
    make_dataset = tt.data.DatasetTuple
    DataLoader = tt.data.DataLoaderBatch
    dataset = make_dataset(*datas)
    dataloader = DataLoader(dataset, batch_size, False, sampler=StratifiedSampler(datas, batch_size))
    model.fit_dataloader(dataloader, epochs, callbacks, verbose)

    print("predicting")
    baseline_hazards = model.compute_baseline_hazards(datas[0], datas[1])
    baseline_hazards = df(baseline_hazards)

    surv = model.predict_surv_df(X_test)
    # NOTE(review): survival curves are flipped to 1 - S(t) before EvalSurv;
    # confirm this orientation is intended.
    surv = 1 - surv
    ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')

    print("scoring")
    c_index = ev.concordance_td()
    print("c-index(", opt, "): ", c_index)

    # Use a distinct '_v2_..._0' naming scheme for degenerate (< 0.1) scores.
    if int(c_index * 10) == 0:
        hazardname = 'pycox_model_hazard_m5_v2_' + opt + '_0'
        netname = 'pycox_model_net_m5_v2_' + opt + '_0'
        weightname = 'pycox_model_weight_m5_v2_' + opt + '_0'
    else:
        hazardname = 'pycox_model_hazard_m5_' + opt + '_'
        netname = 'pycox_model_net_m5_' + opt + '_'
        weightname = 'pycox_model_weight_m5_' + opt + '_'

    baseline_hazards.to_csv('./test/'+hazardname + str(int(c_index * 100)) + '_' + str(index) + '.csv', index=False)
    netname = netname + str(int(c_index * 100)) + '_' + str(index) + '.sav'
    weightname = weightname + str(int(c_index * 100)) + '_' + str(index) + '.sav'
    model.save_net('./test/' + netname)
    model.save_model_weights('./test/' + weightname)

    # Per-subject probabilities evaluated at each subject's observed time.
    pred = df(surv)
    pred = pred.transpose()
    surv_final = []
    pred_final = []

    for i in range(len(pred)):
        pred_final.append(float(1-pred[Y_test['SVDTEPC_G'][i]][i]))
        surv_final.append(float(pred[Y_test['SVDTEPC_G'][i]][i]))

    Y_test_cox = CoxformY(Y_test)
    c_cox, concordant, discordant,_,_ = concordance_index_censored(Y_test_cox['PC_YN'], Y_test_cox['SVDTEPC_G'], surv_final)
    c_cox_pred = concordance_index_censored(Y_test_cox['PC_YN'], Y_test_cox['SVDTEPC_G'], pred_final)[0]
    print("c-index(", opt, ") - sksurv: ", round(c_cox, 4))
    print("cox-concordant(", opt, ") - sksurv: ", concordant)
    print("cox-disconcordant(", opt, ") - sksurv: ", discordant)
    print("c-index_pred(", opt, ") - sksurv: ", round(c_cox_pred, 4))

    fpr, tpr, _ = metrics.roc_curve(Y_test['PC_YN'], pred_final)
    auc = metrics.auc(fpr, tpr)
    print("auc(", opt, "): ", round(auc, 4))

    if check == 1:
        model_check = pd.read_csv(path+savename)
    else:
        model_check = df(columns=['option', 'gender', 'c-td', 'c-index', 'auc'])
    line_append = {'option':str(choice), 'gender':opt, 'c-td':round(c_index,4), 'c-index':round(c_cox_pred,4), 'auc':round(auc,4)}
    # DataFrame.append was removed in pandas 2.0; concat a one-row frame.
    model_check = pd.concat([model_check, pd.DataFrame([line_append])], ignore_index=True)
    model_check.to_csv(path+savename, index=False)

    del X_train
    del X_test

    return surv_final
Esempio n. 11
0
def train_deepsurv(data_df, r_splits):
  """Train a DeepSurv (CoxPH MLP) model over the given resampling splits.

  # Arguments
      data_df: full dataset DataFrame.
      r_splits: list of split index triples; r_splits[i] is consumed as
          (test, val, train) indices by prepare_datasets.
  # Returns
      Tuple (c_index_at, c_index_30, time_auc_30, time_auc_60, time_auc_365)
      of per-split metric lists.
  """
  epochs = 100

  num_nodes = [32]
  out_features = 1
  batch_norm = True
  dropout = 0.6
  output_bias = False

  c_index_at = []
  c_index_30 = []

  time_auc_30 = []
  time_auc_60 = []
  time_auc_365 = []
  # Dispatch table replaces the previous `eval("time_auc_" + ...)` call,
  # which was fragile and opaque.
  auc_by_horizon = {30: time_auc_30, 60: time_auc_60, 365: time_auc_365}

  for i in range(len(r_splits)):
    print("\nIteration %s"%(i))

    #DATA PREP
    df_train, df_val, df_test, df_test_30 = prepare_datasets(data_df, r_splits[i][2], r_splits[i][1], r_splits[i][0])

    # Standardize every column except identifiers and targets.
    xcols = list(df_train.columns)
    for col_name in ["subject_id", "event", "duration"]:
      if col_name in xcols:
        xcols.remove(col_name)

    standardize = [([col], StandardScaler()) for col in xcols]
    x_mapper = DataFrameMapper(standardize)

    x_train = x_mapper.fit_transform(df_train).astype('float32')
    x_val = x_mapper.transform(df_val).astype('float32')
    x_test = x_mapper.transform(df_test).astype('float32')
    x_test_30 = x_mapper.transform(df_test_30).astype('float32')

    labtrans = CoxTime.label_transform()
    get_target = lambda df: (df['duration'].values, df['event'].values)
    y_train = labtrans.fit_transform(*get_target(df_train))
    y_val = labtrans.transform(*get_target(df_val))

    durations_test, events_test = get_target(df_test)
    durations_test_30, events_test_30 = get_target(df_test_30)
    val = tt.tuplefy(x_val, y_val)

    (train_x, train_y), (val_x, val_y), (test_x, test_y), _ = df2array(data_df, df_train, df_val, df_test, df_test_30)

    #MODEL
    in_features = x_train.shape[1]
    callbacks = [tt.callbacks.EarlyStopping()]
    net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm, dropout, output_bias=output_bias)

    model = CoxPH(net, tt.optim.Adam)
    model.optimizer.set_lr(0.0001)

    # NOTE(review): 255 for odd sample counts — presumably to avoid a
    # trailing batch of size 1; confirm the intent.
    batch_size = 255 if x_train.shape[0] % 2 else 256

    model.fit(x_train, y_train, batch_size, epochs, callbacks, val_data=val, val_batch_size=batch_size)

    model.compute_baseline_hazards()

    surv = model.predict_surv_df(x_test)
    ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
    c_index_at.append(ev.concordance_td())

    surv_30 = model.predict_surv_df(x_test_30)
    ev_30 = EvalSurv(surv_30, durations_test_30, events_test_30, censor_surv='km')
    c_index_30.append(ev_30.concordance_td())

    # Cumulative/dynamic AUC at the 30-, 60- and 365-day horizons.
    for time_x in (30, 60, 365):
      va_auc, va_mean_auc = cumulative_dynamic_auc(train_y, test_y, model.predict(x_test).flatten(), time_x)
      auc_by_horizon[time_x].append(va_auc[0])

    print("C-index_30:", c_index_30[i])
    print("C-index_AT:", c_index_at[i])

    print("time_auc_30", time_auc_30[i])
    print("time_auc_60", time_auc_60[i])
    print("time_auc_365", time_auc_365[i])

  return c_index_at, c_index_30, time_auc_30, time_auc_60, time_auc_365
Esempio n. 12
0
def train_LSTMCox(data_df, r_splits):
  """Train an LSTM-based Cox model over the given resampling splits.

  # Arguments
      data_df: full dataset DataFrame; feature embeddings are expected as
          lists in column "x0".
      r_splits: list of split index triples; r_splits[i] is consumed as
          (test, val, train) indices by prepare_datasets.
  # Returns
      Tuple (c_index_at, c_index_30, time_auc_30, time_auc_60, time_auc_365)
      of per-split metric lists.
  """
  epochs = 100

  c_index_at = []
  c_index_30 = []

  time_auc_30 = []
  time_auc_60 = []
  time_auc_365 = []
  # Dispatch table replaces the previous `eval("time_auc_" + ...)` call,
  # which was fragile and opaque.
  auc_by_horizon = {30: time_auc_30, 60: time_auc_60, 365: time_auc_365}

  for i in range(len(r_splits)):
    print("\nIteration %s"%(i))

    #DATA PREP
    df_train, df_val, df_test, df_test_30 = prepare_datasets(data_df, r_splits[i][2], r_splits[i][1], r_splits[i][0])

    x_train = np.array(df_train["x0"].tolist()).astype("float32")
    x_val = np.array(df_val["x0"].tolist()).astype("float32")
    x_test = np.array(df_test["x0"].tolist()).astype("float32")
    x_test_30 = np.array(df_test_30["x0"].tolist()).astype("float32")

    labtrans = CoxTime.label_transform()
    get_target = lambda df: (df['duration'].values, df['event'].values)
    y_train = labtrans.fit_transform(*get_target(df_train))
    y_val = labtrans.transform(*get_target(df_val))

    durations_test, events_test = get_target(df_test)
    durations_test_30, events_test_30 = get_target(df_test_30)
    val = tt.tuplefy(x_val, y_val)

    (train_x, train_y), (val_x, val_y), (test_x, test_y), _ = df2array(data_df, df_train, df_val, df_test, df_test_30)

    #MODEL: 768-dim input, 32 hidden units, 1 layer, 1 output risk score.
    callbacks = [tt.callbacks.EarlyStopping()]
    net = LSTMCox(768, 32, 1, 1)

    model = CoxPH(net, tt.optim.Adam)
    model.optimizer.set_lr(0.0001)

    # NOTE(review): 255 for odd sample counts — presumably to avoid a
    # trailing batch of size 1; confirm the intent.
    batch_size = 255 if x_train.shape[0] % 2 else 256

    model.fit(x_train, y_train, batch_size, epochs, callbacks, val_data=val, val_batch_size=batch_size)

    model.compute_baseline_hazards()

    surv = model.predict_surv_df(x_test)
    ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
    c_index_at.append(ev.concordance_td())

    surv_30 = model.predict_surv_df(x_test_30)
    ev_30 = EvalSurv(surv_30, durations_test_30, events_test_30, censor_surv='km')
    c_index_30.append(ev_30.concordance_td())

    # Cumulative/dynamic AUC at the 30-, 60- and 365-day horizons.
    for time_x in (30, 60, 365):
      va_auc, va_mean_auc = cumulative_dynamic_auc(train_y, test_y, model.predict(x_test).flatten(), time_x)
      auc_by_horizon[time_x].append(va_auc[0])

    print("C-index_30:", c_index_30[i])
    print("C-index_AT:", c_index_at[i])

    print("time_auc_30", time_auc_30[i])
    print("time_auc_60", time_auc_60[i])
    print("time_auc_365", time_auc_365[i])

  return c_index_at, c_index_30, time_auc_30, time_auc_60, time_auc_365
Esempio n. 13
0
    y_pred_train_surv = np.cumprod((1 - ypred_surv_train_NN), axis=1)
    oneyr_surv_train = y_pred_train_surv[:, 50]
    oneyr_surv_valid = y_pred_valid_surv[:, 50]
    surv_valid = pd.DataFrame(np.transpose(y_pred_valid_surv))
    surv_valid.index = interval_l
    surv_train = pd.DataFrame(np.transpose(y_pred_train_surv))
    surv_train.index = interval_l
    dict_cv_cindex_train[key] = concordance_index(dataTrain.time,
                                                  oneyr_surv_train)
    dict_cv_cindex_valid[key] = concordance_index(dataValid.time,
                                                  oneyr_surv_valid)
    #scores_train += concordance_index(dataTrain.time,oneyr_surv_train)#,data_train.dead)
    #scores_test += concordance_index(dataValid.time,oneyr_surv_valid)
    #cta.append(concordance_index(dataTrain.time,oneyr_surv_train))
    #cte.append(concordance_index(dataValid.time,oneyr_surv_valid))
    ev_valid = EvalSurv(surv_valid, time_valid, dead_valid, censor_surv='km')
    scores_test += ev_valid.concordance_td()
    ev_train = EvalSurv(surv_train,
                        dataTrain['time'].values,
                        dataTrain['dead'].values,
                        censor_surv='km')
    scores_train += ev_train.concordance_td()
    cta.append(concordance_index(dataTrain.time, oneyr_surv_train))
    cte.append(concordance_index(dataValid.time, oneyr_surv_valid))
    ctda.append(ev_train.concordance_td())
    ctde.append(ev_valid.concordance_td())
    #scores_test += cindex_CV_score(Yvalid, ypred_test)
    #scores_train += cindex_CV_score(Ytrain, ypred_train)

save_loss_png = "output/loss_cv_" + str(h) + "_KIRC.png"
import seaborn as sns
Esempio n. 14
0
    epochs = args.epochs
    callbacks = [tt.callbacks.EarlyStopping(patience=patience)]
    verbose = True
    log = model.fit(x_train,
                    y_train_transformed,
                    batch_size,
                    epochs,
                    callbacks,
                    verbose,
                    val_data=val_transformed,
                    val_batch_size=batch_size)

    # Evaluation ===================================================================
    surv = get_surv(model, x_test)
    ev = EvalSurv(surv,
                  durations_test_transformed,
                  events_test,
                  censor_surv='km')
    # ctd = ev.concordance_td()
    ctd = concordance_index(event_times=durations_test_transformed,
                            predicted_scores=model.predict(x_test).reshape(-1),
                            event_observed=events_test)
    time_grid = np.linspace(durations_test_transformed.min(),
                            durations_test_transformed.max(), 100)
    ibs = ev.integrated_brier_score(time_grid)
    nbll = ev.integrated_nbll(time_grid)
    val_loss = min(log.monitors['val_'].scores['loss']['score'])
    wandb.log({'val_loss': val_loss, 'ctd': ctd, 'ibs': ibs, 'nbll': nbll})
    wandb.finish()
Esempio n. 15
0
    def _survLDA_train_variational_EM(self, word_count_matrix,
                                      survival_or_censoring_times,
                                      censoring_indicator_variables,
                                      val_word_count_matrix,
                                      val_survival_or_censoring_times,
                                      val_censoring_indicator_variables):
        """Fit survLDA (LDA topic model + Cox survival) by variational EM.

        Input:
        - word_count_matrix: 2D array (subjects x vocabulary) of word counts
        - survival_or_censoring_times: 1D array of observed times (train)
        - censoring_indicator_variables: 1D array; 1 means event observed
        - val_word_count_matrix / val_survival_or_censoring_times /
          val_censoring_indicator_variables: validation-set counterparts used
          only to monitor fit quality each EM iteration

        Fitted quantities are stored on ``self`` (the method returns None):
        - tau: 2D numpy array; g-th row is the word distribution of topic g
        - beta: 1D numpy array of Cox regression coefficients (one per topic)
        - log_h0_discretized: discretized log baseline hazard aligned with
          ``self.sorted_unique_times``
        - gamma: 2D numpy array; i-th row is the variational Dirichlet
          parameter for subject i (length: number of topics)
        - phi: list (one 2D array per subject, words x topics) of variational
          per-word topic responsibilities

        When ``self.verbose`` is set, validation c-index and median-survival
        RMSE/MAE are printed every iteration.  NOTE(review): the loop always
        runs the full ``self.max_iter`` iterations — despite the convergence
        comment below, nothing breaks out early.
        """
        # *********************************************************************
        # 1. SET UP TOPIC MODEL PARAMETERS BASED ON TRAINING DATA

        self.word_count_matrix = word_count_matrix
        num_subjects, num_words = self.word_count_matrix.shape  # train size and vocabulary size
        # the length of each patient's document (i.e. sum of all words including duplicates)
        doc_length = self.word_count_matrix.sum(axis=1).astype(np.int64)

        # tau tells us what the word distribution is per topic
        # - each row corresponds to a topic
        # - each row is a probability distribution over the vocabulary
        if self.random_init_tau:
            self.tau = np.random.rand(self.n_topics, num_words)
            self.tau /= self.tau.sum(axis=1)[:, np.newaxis]  # normalize tau
        else:
            # uniform initialization: every entry is 1 / num_words
            self.tau = np.full((self.n_topics, num_words),
                               1. / num_words,
                               dtype=np.float64)

        # variational parameter gamma: Dirichlet parameter for the per-subject
        # topic distribution; all ones corresponds to a uniform prior
        self.gamma = np.ones((num_subjects, self.n_topics), dtype=np.float64)

        # variational parameter phi: probability of each subject's words
        # coming from each of the K topics, initialized uniform (1/K) by the
        # helper below
        self.W, self.phi = self._word_count_matrix_to_word_vectors_phi(
            word_count_matrix)
        val_W, _ = self._word_count_matrix_to_word_vectors_phi(
            val_word_count_matrix)

        # *********************************************************************
        # 2. SET UP COX BASELINE HAZARD FUNCTION

        # Cox baseline hazard h0 is represented as a finite 1D vector over the
        # sorted unique observed times
        death_counter = {}
        for t in survival_or_censoring_times:
            if t not in death_counter:
                death_counter[t] = 1
            else:
                death_counter[t] += 1
        sorted_unique_times = np.sort(list(death_counter.keys()))
        self.sorted_unique_times = sorted_unique_times

        num_unique_times = len(sorted_unique_times)
        self.log_h0_discretized = np.zeros(num_unique_times)
        for r, t in enumerate(sorted_unique_times):
            self.log_h0_discretized[r] = np.log(death_counter[t])

        # log cumulative baseline hazard, built with a running logsumexp
        log_H0 = []
        for r in range(num_unique_times):
            log_H0.append(logsumexp(self.log_h0_discretized[:(r + 1)]))
        log_H0 = np.array(log_H0)

        # map each subject's time to its index within sorted_unique_times
        time_map = {t: r for r, t in enumerate(sorted_unique_times)}
        time_order = np.array(
            [time_map[t] for t in survival_or_censoring_times], dtype=np.int64)

        # *********************************************************************
        # 3. EM MAIN LOOP

        pool = multiprocessing.Pool(4)
        stop_iter = 0
        # BUGFIX: np.bool was deprecated in NumPy 1.20 and removed in 1.24;
        # the builtin bool is the documented replacement
        val_censoring_indicator_variables = val_censoring_indicator_variables.astype(
            bool)
        for EM_iter_idx in range(self.max_iter):
            # ------------------------------------------------------------------
            # E-step (update gamma, phi; uses helper variables psi, xi)
            if self.verbose:
                print('[Variational EM iteration %d]' % (EM_iter_idx + 1))
                print('  Running E-step...', end='', flush=True)
            tic = time()

            # update gamma (dimensions: `num_subjects` by `K`)
            for i, phi_i in enumerate(self.phi):
                self.gamma[i] = self.alpha + phi_i.sum(axis=0)

            # psi: expected log topic proportions under the variational
            # Dirichlet (dimensions: `num_subjects` by `K`)
            psi = digamma(self.gamma) - digamma(
                self.gamma.sum(axis=1))[:, np.newaxis]
            # update phi in parallel; the helper already normalizes each row
            self.phi = pool.map(SurvLDA._compute_updated_phi_i_pmap_helper,
                                [(i, self.phi[i], psi[i], self.W[i], self.tau,
                                  self.beta, np.exp(log_H0[time_order[i]]),
                                  censoring_indicator_variables[i],
                                  doc_length[i], self.n_topics)
                                 for i in range(num_subjects)])

            toc = time()
            if self.verbose:
                print(' Done. Time elapsed: %f second(s).' % (toc - tic),
                      flush=True)

            # --------------------------------------------------------------------
            # M-step (update tau, beta, h0, H0; uses helper variable phi_bar)
            if self.verbose:
                print('  Running M-step...', end='', flush=True)
            tic = time()

            # update tau: accumulate expected word-topic counts from phi
            tau = np.zeros((self.n_topics, num_words), dtype=np.float64)
            for i in range(num_subjects):
                for j in range(doc_length[i]):
                    word = self.W[i][j]
                    for k in range(self.n_topics):
                        tau[k][word] += self.phi[i][j, k]
            # BUGFIX: normalize the NEW counts.  The original did
            # `self.tau /= tau.sum(axis=1)[:, np.newaxis]`, which divides the
            # stale tau by the new row sums instead of normalizing `tau`.
            self.tau = tau / tau.sum(axis=1)[:, np.newaxis]

            # phi_bar[i]: average topic responsibility over subject i's words
            phi_bar = np.zeros((num_subjects, self.n_topics), dtype=np.float64)
            for i in range(num_subjects):
                phi_bar[i] = self.phi[i].sum(axis=0) / doc_length[i]

            beta_phi_bar = np.dot(phi_bar, self.beta)
            # baseline hazard implied by the CURRENT beta; frozen and reused
            # inside the beta objective below
            log_h0_discretized_ = np.zeros(num_unique_times)
            log_H0_ = []
            for r, t in enumerate(sorted_unique_times):
                R_t = np.where(survival_or_censoring_times >= t, True, False)
                log_h0_discretized_[r] = np.log(death_counter[t]) - logsumexp(
                    beta_phi_bar[R_t])
                log_H0_.append(logsumexp(log_h0_discretized_[:(r + 1)]))

            # update beta by maximizing the expected complete-data
            # log-likelihood (we minimize its negation)
            def obj_fun(beta_):
                beta_phi_bar_ = np.dot(phi_bar, beta_)
                fun_val = np.dot(censoring_indicator_variables, beta_phi_bar_)
                # finally, we add in the third term
                for i in range(num_subjects):
                    product_terms = np.dot(self.phi[i],
                                           np.exp(beta_ / doc_length[i]))
                    fun_val -= np.exp(log_H0_[time_order[i]] +
                                      np.sum(np.log(product_terms)))
                return -fun_val

            self.beta = minimize(obj_fun, self.beta).x

            # compute dot product <beta, phi_bar[i]> for each patient i
            # (resulting in a 1D array of length `num_subjects`)
            beta_phi_bar = np.dot(phi_bar, self.beta)

            # update h0 (Breslow-style estimator with the new beta)
            for r, t in enumerate(sorted_unique_times):
                R_t = np.where(survival_or_censoring_times >= t, True, False)
                self.log_h0_discretized[r] = np.log(
                    death_counter[t]) - logsumexp(beta_phi_bar[R_t])
            log_H0 = []
            for r in range(num_unique_times):
                log_H0.append(logsumexp(self.log_h0_discretized[:(r + 1)]))
            log_H0 = np.array(log_H0)

            toc = time()
            if self.verbose:
                print(' Done. Time elapsed: %f second(s).' % (toc - tic),
                      flush=True)
                print('  Beta:', self.beta)

            # validation-set monitoring: predict per-subject survival curves
            # and median survival times with the current parameters
            surv_estimates = []
            median_estimates = []
            for i in range(len(val_W)):
                preds = self._predict_survival(val_W[i])
                surv_estimates.append(preds[1])
                median_estimates.append(preds[0])
            surv_estimates = np.array(surv_estimates)
            median_estimates = np.array(median_estimates)

            val_ev = EvalSurv(pd.DataFrame(np.transpose(surv_estimates),
                                           index=sorted_unique_times),
                              val_survival_or_censoring_times,
                              val_censoring_indicator_variables,
                              censor_surv='km')
            val_cindex = val_ev.concordance_td('antolini')

            # RMSE/MAE of predicted median survival time, computed on
            # uncensored validation subjects only (boolean mask indexing)
            val_rmse = np.sqrt(np.mean((median_estimates[val_censoring_indicator_variables] - \
                            val_survival_or_censoring_times[val_censoring_indicator_variables])**2))
            val_mae = np.mean(np.abs(median_estimates[val_censoring_indicator_variables] - \
                            val_survival_or_censoring_times[val_censoring_indicator_variables]))

            if self.verbose:
                print('  Validation set survival time c-index: %f' %
                      val_cindex)
                print('  Validation set survival time rmse: %f' % val_rmse)
                print('  Validation set survival time mae: %f' % val_mae)

            stop_iter += 1

        pool.close()
        if self.verbose:
            print('  EM algorithm completes training at iteration ', stop_iter)

        return
Esempio n. 16
0
                        out_survival = utils.predict_survival_lognormal(
                            model, X.to(device), times)
                    elif arg.distribution == "exponient":
                        out_survival = utils.predict_survival_exponient(
                            model, X.to(device), times)
                    elif arg.distribution == "weibull":
                        out_survival = utils.predict_survival_weibull(
                            model, X.to(device), times)
                    elif arg.distribution == "combine":
                        out_survival = utils.predict_survival_multiple_distributions(
                            model, X.to(device), times)
                    surv = pd.DataFrame(out_survival, index=times)
                    durations_test, events_test = T.detach().numpy(), E.detach(
                    ).numpy()
                    ev = EvalSurv(surv,
                                  durations_test,
                                  events_test,
                                  censor_surv='km')
                    c_index_train = ev.concordance_td()

            ##### valid
            model.eval()
            batch_size = 8000
            with torch.no_grad():
                for X, T, E in dataloader.load_data(arg.dataname,
                                                    arg.path_clinical_val,
                                                    'test'):
                    T = T / ratio

                    #pdf = misc.cal_pdf(T)

                    if arg.distribution == "lognormal":
# Fit the model and evaluate it, including cause-specific evaluation via CIFs.
# NOTE(review): `model`, `x_train`, `y_train`, `batch_size`, `epochs`,
# `callbacks`, `val`, `x_test`, `durations_test` and `events_test` are defined
# earlier in the script (outside this excerpt).
verbose = True

#%%time # Magic command for Jupyter Notebook only
log = model.fit(x_train,
                y_train,
                batch_size,
                epochs,
                callbacks,
                verbose,
                val_data=val)

# Plot the training/validation loss history.
_ = log.plot()

# Evaluation
surv = model.predict_surv_df(x_test)
# All-cause evaluation: any non-zero event code counts as an event; 'km' uses
# Kaplan-Meier to estimate the censoring distribution for IPCW.
ev = EvalSurv(surv, durations_test, events_test != 0, censor_surv='km')

ev.concordance_td()

# Integrated Brier score over 100 time points from 0 to the max test duration.
ev.integrated_brier_score(np.linspace(0, durations_test.max(), 100))

# Cause-specific cumulative incidence functions (competing risks); cif[k] is
# the CIF for cause k+1, indexed by the model's discretization grid.
cif = model.predict_cif(x_test)
cif1 = pd.DataFrame(cif[0], model.duration_index)
cif2 = pd.DataFrame(cif[1], model.duration_index)

# 1 - CIF is treated as a "survival" curve for each cause so EvalSurv can
# score each cause separately.
ev1 = EvalSurv(1 - cif1, durations_test, events_test == 1, censor_surv='km')
ev2 = EvalSurv(1 - cif2, durations_test, events_test == 2, censor_surv='km')

ev1.concordance_td()

ev2.concordance_td()
Esempio n. 18
0
                       event_col='status',
                       show_progress=False,
                       step_size=.1)
        elapsed = time.time() - tic
        print('Time elapsed: %f second(s)' % elapsed)
        np.savetxt(time_elapsed_filename, np.array(elapsed).reshape(1, -1))

        # ---------------------------------------------------------------------
        # evaluation
        #

        sorted_y_test = np.unique(y_test[:, 0])
        surv_df = surv_model.predict_survival_function(X_test_std,
                                                       sorted_y_test)
        surv = surv_df.values.T
        ev = EvalSurv(surv_df, y_test[:, 0], y_test[:, 1], censor_surv='km')
        cindex_td = ev.concordance_td('antolini')
        print('c-index (td):', cindex_td)

        linear_predictors = \
            surv_model.predict_log_partial_hazard(X_test_std)
        cindex = concordance_index(y_test[:, 0], -linear_predictors, y_test[:,
                                                                            1])
        print('c-index:', cindex)

        time_grid = np.linspace(sorted_y_test[0], sorted_y_test[-1], 100)
        integrated_brier = ev.integrated_brier_score(time_grid)
        print('Integrated Brier score:', integrated_brier, flush=True)

        test_set_metrics = [cindex_td, integrated_brier]
Esempio n. 19
0
    ypred_train_NN = model.predict_proba(x_train_NN)
    ypred_test_NN = model.predict_proba(x_valid_NN)
    y_pred_valid_surv = np.cumprod((1 - ypred_test_NN), axis=1)
    y_pred_train_surv = np.cumprod((1 - ypred_train_NN), axis=1)
    oneyr_surv_train = y_pred_train_surv[:, 50]
    oneyr_surv_valid = y_pred_valid_surv[:, 50]
    surv_valid = pd.DataFrame(np.transpose(y_pred_valid_surv))
    surv_valid.index = interval_l
    surv_train = pd.DataFrame(np.transpose(y_pred_train_surv))
    surv_train.index = interval_l
    dict_cv_cindex_train[key] = concordance_index(dataTrain.time,
                                                  oneyr_surv_train)
    dict_cv_cindex_valid[key] = concordance_index(dataValid.time,
                                                  oneyr_surv_valid)
    ev_valid = EvalSurv(surv_valid,
                        dataValid['time'].values,
                        dataValid['dead'].values,
                        censor_surv='km')
    scores_test += ev_valid.concordance_td()
    ev_train = EvalSurv(surv_train,
                        dataTrain['time'].values,
                        dataTrain['dead'].values,
                        censor_surv='km')
    scores_train += ev_train.concordance_td()
    cta.append(concordance_index(dataTrain.time, oneyr_surv_train))
    cte.append(concordance_index(dataValid.time, oneyr_surv_valid))
    ctda.append(ev_train.concordance_td())
    ctde.append(ev_valid.concordance_td())
    #scores_train += concordance_index(dataTrain.time,oneyr_surv_train)
    #scores_test += concordance_index(dataValid.time,oneyr_surv_valid)

save_loss_png = "loss_cv_" + str(h) + "_KIRC_vK.png"
Esempio n. 20
0
    # Training ======================================================================

    epochs = args.epochs
    callbacks = [tt.callbacks.EarlyStopping()]
    verbose = True
    log = model.fit(x_train,
                    y_train,
                    batch_size,
                    epochs,
                    callbacks,
                    verbose,
                    val_data=val,
                    val_batch_size=batch_size)
    # log = model.fit(x_train, y_train_transformed, batch_size, epochs, callbacks, verbose, val_data = val_transformed, val_batch_size = batch_size)

    # Evaluation ===================================================================

    _ = model.compute_baseline_hazards()
    surv = model.predict_surv_df(x_test)
    ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')

    # ctd = ev.concordance_td()
    ctd = ev.concordance_td()
    time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)

    ibs = ev.integrated_brier_score(time_grid)
    nbll = ev.integrated_nbll(time_grid)
    val_loss = min(log.monitors['val_'].scores['loss']['score'])
    wandb.log({'val_loss': val_loss, 'ctd': ctd, 'ibs': ibs, 'nbll': nbll})
    wandb.finish()
         x_train, y_train = train
         for dim in hiddens:
             for lr in lrs:
                 outpath = "./survival/%s_%s_%s_%s_%s"%(mod,scale,dim,lr,seed)
                 if not os.path.exists(outpath):
                     os.mkdir(outpath)
                 
                 in_features = x_train.shape[1]
                 model = initialize_model(dim,labtrans,in_features)
                 model.optimizer.set_lr(0.001)
 
                 callbacks = [tt.callbacks.EarlyStopping()]
                 log = model.fit(x_train, y_train, batch_size, epochs, callbacks, val_data=val)
                 
                 surv = model.predict_surv_df(x_test)
                 ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
                 
                 result = pd.DataFrame([[0]*8],columns=["random","model","hiddens",
                                                "lr","scalers","c-index","brier","nll"])
                 
                 result["c-index"] = ev.concordance_td('antolini')  
                 print(ev.concordance_td('antolini') )
                 time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)
                 result["brier"] = ev.integrated_brier_score(time_grid) 
                 print(ev.integrated_brier_score(time_grid) )
                 result["nll"] = ev.integrated_nbll(time_grid) 
                 result["lr"] = lr
                 result["model"] = mod
                 result["scaler"] = scale
                 result["random"] = seed
                 result["hiddens"] = dim
Esempio n. 22
0
    def _run_training_loop(self, num_epochs, scheduler, info_freq, log_dir):
        """Train and validate the model for ``num_epochs`` epochs.

        Per epoch, runs a 'train' then a 'val' phase over the corresponding
        dataloader, accumulating losses, durations, censoring indicators and
        risk scores; computes the time-dependent concordance (Ctd) per phase;
        logs to TensorBoard; steps the scheduler on the validation Ctd; and
        keeps best-performing model weights in ``self.best_wts``.

        Args:
            num_epochs: total number of epochs to run.
            scheduler: optional LR scheduler; ``step(val_ctd)`` is called each
                epoch when truthy (presumably ReduceLROnPlateau-style — TODO
                confirm).
            info_freq: print a progress row every ``info_freq`` epochs
                (``None`` disables console output entirely).
            log_dir: directory for the TensorBoard ``SummaryWriter``.
        """
        logger = SummaryWriter(log_dir)
        log_info = True  # TensorBoard logging is always on

        if info_freq is not None:

            def print_header():
                # Fixed-width two-column (Training / Validation) table header.
                sub_header = ' Epoch     Loss     Ctd     Loss     Ctd'
                print('-' * (len(sub_header) + 2))
                print('             Training        Validation')
                print('           ------------     ------------')
                print(sub_header)
                print('-' * (len(sub_header) + 2))

            print()

            print_header()

        for epoch in range(1, num_epochs + 1):
            # Print a row for epoch 1 and every info_freq-th epoch thereafter.
            if info_freq is None:
                print_info = False
            else:
                print_info = epoch == 1 or epoch % info_freq == 0

            for phase in ['train', 'val']:
                if phase == 'train':
                    self.model.train()
                else:
                    self.model.eval()

                running_losses = []

                # Accumulators for this phase's per-batch outputs, kept on the
                # compute device until the epoch-level metrics are computed.
                if print_info or log_info:
                    running_durations = torch.FloatTensor().to(self.device)
                    running_censors = torch.LongTensor().to(self.device)
                    running_risks = torch.FloatTensor().to(self.device)

                # Iterate over data
                for data in self.dataloaders[phase]:
                    # _process_data_batch handles forward/backward depending
                    # on phase and returns (loss, risk, time, event).
                    batch_result = self._process_data_batch(data, phase)
                    loss, risk, time, event = batch_result

                    # Stats
                    running_losses.append(loss.item())
                    running_durations = torch.cat(
                        (running_durations, time.data.float()))
                    running_censors = torch.cat(
                        (running_censors, event.long().data))
                    running_risks = torch.cat((running_risks, risk.detach()))

                epoch_loss = torch.mean(torch.tensor(running_losses))

                # Convert accumulated risks to pycox-style survival curves and
                # score with the adjusted Antolini time-dependent concordance.
                surv_probs = self._predictions_to_pycox(running_risks,
                                                        time_points=None)
                running_durations = running_durations.cpu().numpy()
                running_censors = running_censors.cpu().numpy()
                epoch_concord = EvalSurv(
                    surv_probs,
                    running_durations,
                    running_censors,
                    censor_surv='km').concordance_td('adj_antolini')

                if print_info:
                    # `message` is created in the 'train' iteration and
                    # intentionally extended (not recreated) in the 'val'
                    # iteration of the same epoch, so one row shows both
                    # phases; it is only printed after the 'val' phase.
                    if phase == 'train':
                        message = f' {epoch}/{num_epochs}'
                    space = 10 if phase == 'train' else 27
                    message += ' ' * (space - len(message))
                    message += f'{epoch_loss:.4f}'
                    space = 19 if phase == 'train' else 36
                    message += ' ' * (space - len(message))
                    message += f'{epoch_concord:.3f}'
                    if phase == 'val':
                        print(message)

                if log_info:
                    self._log_info(phase=phase,
                                   logger=logger,
                                   epoch=epoch,
                                   epoch_loss=epoch_loss,
                                   epoch_concord=epoch_concord)

                if phase == 'val':
                    if scheduler:
                        scheduler.step(epoch_concord)

                    # Record current performance
                    # self.current_perf is a single-entry dict; pop its old
                    # key and re-key it as 'epoch<n>' with the new Ctd.
                    k = list(self.current_perf.keys())[0]
                    self.current_perf['epoch' +
                                      str(epoch)] = self.current_perf.pop(k)
                    self.current_perf['epoch' + str(epoch)] = epoch_concord
                    # Deep copy the model
                    # Replace the first best-performance entry this epoch
                    # beats (>=), re-keying both the score and the weights,
                    # then stop looking.
                    for k, v in self.best_perf.items():
                        if epoch_concord >= v:
                            self.best_perf['epoch' +
                                           str(epoch)] = self.best_perf.pop(k)
                            self.best_perf['epoch' +
                                           str(epoch)] = epoch_concord
                            self.best_wts['epoch' +
                                          str(epoch)] = self.best_wts.pop(k)
                            self.best_wts['epoch' +
                                          str(epoch)] = copy.deepcopy(
                                              self.model.state_dict())
                            break
Esempio n. 23
0
# Plot and evaluate discrete-time survival estimates on the test set.
# NOTE(review): `model`, `x_test`, `durations_test` and `events_test` are
# defined earlier in the script (outside this excerpt).
_ = plt.xlabel('Time')

# Interpolating the survival estimates because the survival estimates so far are
# only defined at the 10 times in the discretization grid and the survival
# estimates are therefore a step function rather than a continuous one
surv = model.interpolate(10).predict_surv_df(x_test)

# Show the curves for the first 5 test subjects.
surv.iloc[:, :5].plot(drawstyle='steps-post')
plt.ylabel('S(t | x)')
_ = plt.xlabel('Time')

# The EvalSurv class contains some useful evaluation criteria for time-to-event prediction.
# We set censor_surv = 'km' to state that we want to use Kaplan-Meier for estimating the
# censoring distribution.

ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')

# Antolini's time-dependent concordance index.
ev.concordance_td('antolini')

# Brier Score
# We can plot the the IPCW Brier score for a given set of times.
# Here we just use 100 time-points between the min and max duration in the test set.
# Note that the score becomes unstable for the highest times.
# It is therefore common to disregard the rightmost part of the graph.

time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)
ev.brier_score(time_grid).plot()
plt.ylabel('Brier score')
_ = plt.xlabel('Time')

# Negative binomial log-likelihood
Esempio n. 24
0
                                           fold_y_train_discrete,
                                           batch_size, n_epochs, verbose=False)
                        elapsed = time.time() - tic
                        print('Time elapsed: %f second(s)' % elapsed)
                        np.savetxt(time_elapsed_filename,
                                   np.array(elapsed).reshape(1, -1))
                        surv_model.save_net(model_filename)
                    else:
                        surv_model.load_net(model_filename)
                        elapsed = float(np.loadtxt(time_elapsed_filename))
                        print('Time elapsed (from previous fitting): '
                              + '%f second(s)' % elapsed)

                    surv_model.sub = 10
                    surv_df = surv_model.predict_surv_df(fold_X_val_std)
                    ev = EvalSurv(surv_df, fold_y_val[:, 0], fold_y_val[:, 1],
                                  censor_surv='km')

                    sorted_fold_y_val = np.sort(np.unique(fold_y_val[:, 0]))
                    time_grid = np.linspace(sorted_fold_y_val[0],
                                            sorted_fold_y_val[-1], 100)

                    surv = surv_df.to_numpy().T

                    cindex_scores.append(ev.concordance_td('antolini'))
                    integrated_brier_scores.append(
                            ev.integrated_brier_score(time_grid))
                    print('  c-index (td):', cindex_scores[-1])
                    print('  Integrated Brier score:',
                          integrated_brier_scores[-1])

                cross_val_cindex = np.mean(cindex_scores)
def main(data_root, cancer_type, anatomical_location, fold):
    """Train and evaluate a CoxPH survival model on multi-omics patient data.

    Loads pickled graph/index artifacts, assembles one flattened feature
    vector per patient from six omics files, performs a censoring-stratified
    5-fold train/valid/test split, fits a pycox ``CoxPH`` model wrapping
    ``MyNet``, and prints the time-dependent concordance index on the test
    fold.

    Args:
        data_root: root directory for the data files.
            NOTE(review): this parameter is overwritten by a hard-coded
            path below — confirm which one is intended.
        cancer_type: integer index into the 33-entry cancer-type one-hot.
        cancer_subtype/cell-type vectors are derived from it via the
            module-level ``CANCER_SUBTYPES`` / ``CELL_TYPES`` tables.
        anatomical_location: integer index into the 52-entry location one-hot.
        fold: cross-validation fold index, expected in [0, 4].

    Returns:
        None. Side effects: prints diagnostics and shows a training-loss plot.
    """

    # Import the RDF graph for PPI network
    f = open('seen.pkl', 'rb')
    seen = pickle.load(f)
    f.close()
    #####################

    # Edge index (graph connectivity) for the GNN, pickled as a 2 x E array.
    f = open('ei.pkl', 'rb')
    ei = pickle.load(f)
    f.close()

    # Select GPU when available; `device` is promoted to a module-level
    # global so that helpers defined elsewhere in the file can see it.
    global device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # device = torch.device('cpu')

    # One-hot encodings of the clinical context (type, subtype(s),
    # anatomical site, cell type(s)).
    cancer_type_vector = np.zeros((33, ), dtype=np.float32)
    cancer_type_vector[cancer_type] = 1

    # Multi-hot: a cancer type can map to several subtypes.
    cancer_subtype_vector = np.zeros((25, ), dtype=np.float32)
    for i in CANCER_SUBTYPES[cancer_type]:
        cancer_subtype_vector[i] = 1

    anatomical_location_vector = np.zeros((52, ), dtype=np.float32)
    anatomical_location_vector[anatomical_location] = 1
    cell_type_vector = np.zeros((10, ), dtype=np.float32)
    cell_type_vector[CELL_TYPES[cancer_type]] = 1

    # Move the context vectors and edge index to the chosen device once.
    pt_tensor_cancer_type = torch.FloatTensor(cancer_type_vector).to(device)
    pt_tensor_cancer_subtype = torch.FloatTensor(cancer_subtype_vector).to(
        device)
    pt_tensor_anatomical_location = torch.FloatTensor(
        anatomical_location_vector).to(device)
    pt_tensor_cell_type = torch.FloatTensor(cell_type_vector).to(device)
    edge_index = torch.LongTensor(ei).to(device)

    # Import a dictionary that maps proteins to their corresponding genes
    # by the Ensembl database.
    f = open('ens_dic.pkl', 'rb')
    dicty = pickle.load(f)
    f.close()
    # Invert the mapping: gene -> {protein: 1, ...}.
    dic = {}
    for d in dicty:
        key = dicty[d]
        if key not in dic:
            dic[key] = {}
        dic[key][d] = 1

    # Build a dictionary from ENSG -- ENST
    d = {}
    with open('data1/prot_names1.txt') as f:
        for line in f:
            tok = line.split()
            d[tok[1]] = tok[0]

    clin = [
    ]  # for clinical data (i.e. number of days to survive, days to death for dead patients and days to last followup for alive patients)
    feat_vecs = [
    ]  # list of lists ([[patient1],[patient2],.....[patientN]]) -- [patientX] = [gene_expression_value, diff_gene_expression_value, methylation_value, diff_methylation_value, VCF_value, CNV_value]
    suv_time = [
    ]  # list that include wheather a patient is alive or dead (i.e. 0 for dead and 1 for alive)
    can_types = ["BRCA_v2"]
    # NOTE(review): the `data_root` parameter is clobbered by this
    # hard-coded cluster path — confirm whether the argument should win.
    data_root = '/ibex/scratch/projects/c2014/sara/'
    for i in range(len(can_types)):
        # file that contain patients ID with their coressponding 6 differnt files names (i.e. files names for gene_expression, diff_gene_expression, methylation, diff_methylation, VCF and CNV)
        f = open(data_root + can_types[i] + '.txt')
        lines = f.read().splitlines()
        f.close()
        lines = lines[1:]
        count = 0
        # Feature matrix: one row per patient, 17186 features x 6 omics.
        # NOTE(review): reallocated on every outer iteration, and the row
        # counter below reuses (shadows) the outer loop variable `i` —
        # harmless only because can_types has a single entry.
        feat_vecs = np.zeros((len(lines), 17186 * 6), dtype=np.float32)
        i = 0
        for l in tqdm(lines):
            # Each manifest line: tab-separated file names for the six
            # omics modalities plus clinical/survival files.
            l = l.split('\t')
            clinical_file = l[6]
            surv_file = l[2]
            myth_file = 'myth/' + l[3]
            diff_myth_file = 'diff_myth/' + l[1]
            exp_norm_file = 'exp_count/' + l[-1]
            diff_exp_norm_file = 'diff_exp/' + l[0]
            cnv_file = 'cnv/' + l[4] + '.txt'
            vcf_file = 'vcf/' + 'OutputAnnoFile_' + l[
                5] + '.hg38_multianno.txt.dat'
            # Check if all 6 files are exist for a patient (that's because for some patients, their survival time not reported)
            all_files = [
                myth_file, diff_exp_norm_file, diff_myth_file, exp_norm_file,
                cnv_file, vcf_file
            ]
            for fname in all_files:
                if not os.path.exists(fname):
                    print('File ' + fname + ' does not exist!')
                    sys.exit(1)
    #         f = open(clinical_file)
    #         content = f.read().strip()
    #         f.close()
            # NOTE(review): `clin`/`suv_time` store the raw manifest fields
            # (the file-reading code above is commented out), and are later
            # cast with float() — assumes these columns hold numeric values.
            clin.append(clinical_file)
            #         f = open(surv_file)
            #         content = f.read().strip()
            #         f.close()
            suv_time.append(surv_file)
            # Assemble the flattened per-patient feature vector from the
            # six omics files via the project helpers.
            temp_myth = myth_data(myth_file, seen, d, dic)
            vec = np.array(get_data(exp_norm_file, diff_exp_norm_file,
                                    diff_myth_file, cnv_file, vcf_file,
                                    temp_myth, seen, dic),
                           dtype=np.float32)
            vec = vec.flatten()
            #         vec = np.concatenate([
            #             vec, cancer_type_vector, cancer_subtype_vector,
            #             anatomical_location_vector, cell_type_vector])
            feat_vecs[i, :] = vec
            i += 1

    # Scale features to [0, 1]; fit on train only (see below).
    min_max_scaler = MinMaxScaler(clip=True)
    labels_days = []
    labels_surv = []
    for days, surv in zip(clin, suv_time):
        #     if days.replace("-", "") != "":
        #         days = float(days)
        #     else:
        #         days = 0.0
        labels_days.append(float(days))
        labels_surv.append(float(surv))

    # Train by batch
    dataset = feat_vecs
    #print(dataset.shape)
    labels_days = np.array(labels_days)
    labels_surv = np.array(labels_surv)

    # Stratify the split by censoring status (1 = censored/alive,
    # 0 = event/dead) so each fold keeps the censoring ratio.
    censored_index = []
    uncensored_index = []
    for i in range(len(dataset)):
        if labels_surv[i] == 1:
            censored_index.append(i)
        else:
            uncensored_index.append(i)
    model = CoxPH(MyNet(edge_index).to(device), tt.optim.Adam(0.0001))

    censored_index = np.array(censored_index)
    uncensored_index = np.array(uncensored_index)

    # Each time test on a specific cancer type
    # total_cancers = ["TCGA-BRCA"]
    # for i in range(len(total_cancers)):
    # test_set = [d for t, d in zip(total_cancers, dataset) if t == total_cancers[i]]
    # train_set = [d for t, d in zip(total_cancers, dataset) if t != total_cancers[i]]

    # Censored split
    n = len(censored_index)
    index = np.arange(n)
    i = n // 5  # fold size
    np.random.seed(seed=0)  # fixed seed: reproducible fold assignment
    np.random.shuffle(index)
    if fold < 4:
        ctest_idx = index[fold * i:fold * i + i]
        # NOTE(review): `index` is a numpy array, so `+` here is ELEMENTWISE
        # addition, not list concatenation — this either raises on a length
        # mismatch or silently produces wrong indices. Likely intended
        # np.concatenate((index[:fold * i], index[fold * i + i:])).
        ctrain_idx = index[:fold * i] + index[fold * i + i:]
    else:
        # Last fold takes the remainder (handles n not divisible by 5).
        ctest_idx = index[fold * i:]
        ctrain_idx = index[:fold * i]
    # Carve 10% of the training indices out as a validation set.
    ctrain_n = len(ctrain_idx)
    cvalid_n = ctrain_n // 10
    cvalid_idx = ctrain_idx[:cvalid_n]
    ctrain_idx = ctrain_idx[cvalid_n:]

    # Uncensored split (same scheme as above)
    n = len(uncensored_index)
    index = np.arange(n)
    i = n // 5
    np.random.seed(seed=0)
    np.random.shuffle(index)
    if fold < 4:
        utest_idx = index[fold * i:fold * i + i]
        # NOTE(review): same elementwise-`+` bug as the censored split above.
        utrain_idx = index[:fold * i] + index[fold * i + i:]
    else:
        utest_idx = index[fold * i:]
        utrain_idx = index[:fold * i]
    utrain_n = len(utrain_idx)
    uvalid_n = utrain_n // 10
    uvalid_idx = utrain_idx[:uvalid_n]
    utrain_idx = utrain_idx[uvalid_n:]

    # Merge censored + uncensored indices per split and reshuffle.
    # NOTE(review): np.concatenate takes a SEQUENCE of arrays; passing two
    # arrays positionally makes the second one the `axis` argument, which
    # raises. Should be np.concatenate((a, b)) — same on the next two calls.
    train_idx = np.concatenate(censored_index[ctrain_idx],
                               uncensored_index[utrain_idx])
    np.random.seed(seed=0)
    np.random.shuffle(train_idx)
    valid_idx = np.concatenate(censored_index[cvalid_idx],
                               uncensored_index[uvalid_idx])
    np.random.seed(seed=0)
    np.random.shuffle(valid_idx)
    test_idx = np.concatenate(censored_index[ctest_idx],
                              uncensored_index[utest_idx])
    np.random.seed(seed=0)
    np.random.shuffle(test_idx)

    # Scaler is fit on train only, then applied to valid/test (no leakage).
    train_data = dataset[train_idx]
    train_data = min_max_scaler.fit_transform(train_data)
    train_labels_days = labels_days[train_idx]
    train_labels_surv = labels_surv[train_idx]
    # pycox CoxPH expects targets as a (durations, events) tuple.
    train_labels = (train_labels_days, train_labels_surv)

    val_data = dataset[valid_idx]
    val_data = min_max_scaler.transform(val_data)
    val_labels_days = labels_days[valid_idx]
    val_labels_surv = labels_surv[valid_idx]
    test_data = dataset[test_idx]
    test_data = min_max_scaler.transform(test_data)
    test_labels_days = labels_days[test_idx]
    test_labels_surv = labels_surv[test_idx]
    val_labels = (val_labels_days, val_labels_surv)
    print(val_labels)

    # Fit with early stopping on the validation loss; `True` is the
    # positional `verbose` flag.
    callbacks = [tt.callbacks.EarlyStopping()]
    batch_size = 16
    epochs = 100
    val = (val_data, val_labels)
    log = model.fit(train_data,
                    train_labels,
                    batch_size,
                    epochs,
                    callbacks,
                    True,
                    val_data=val,
                    val_batch_size=batch_size)
    log.plot()
    plt.show()
    # print(model.partial_log_likelihood(*val).mean())
    train = train_data, train_labels
    # Compute the evaluation measurements
    # CoxPH needs baseline hazards from the training set before it can
    # produce survival curves.
    model.compute_baseline_hazards(*train)
    surv = model.predict_surv_df(test_data)
    print(surv)
    # NOTE(review): no censor_surv='km' here (unlike other examples in this
    # file) — concordance_td still works, but Brier-score-based metrics
    # would need a censoring estimate; confirm this is intentional.
    ev = EvalSurv(surv, test_labels_days, test_labels_surv)
    print(ev.concordance_td())