Exemplo n.º 1
0
def load_model(model_file, data, clinical, surv_time, edge_index):
    """Restore a trained CoxPH network from disk and run inference.

    Builds a fresh ``MyNet`` on the global ``device``, loads the saved
    weights from ``model_file``, and predicts survival curves for ``data``.
    ``clinical`` and ``surv_time`` are accepted for interface compatibility
    but are not used here.

    Returns:
        (prediction, fs): the survival DataFrame from ``predict_surv_df``
        and the intermediate representation extracted by ``features``.
    """
    net = MyNet(edge_index).to(device)
    cox_model = CoxPH(net, tt.optim.Adam(0.0001))
    cox_model.load_net(model_file)
    prediction = cox_model.predict_surv_df(data)
    # Extract the learned feature representation from the restored network.
    fs = features(cox_model.net, torch.from_numpy(data).to(device))
    return prediction, fs
Exemplo n.º 2
0
    # Training ======================================================================

    epochs = args.epochs
    # Stop training once the validation loss stops improving.
    callbacks = [tt.callbacks.EarlyStopping()]
    verbose = True
    log = model.fit(x_train,
                    y_train,
                    batch_size,
                    epochs,
                    callbacks,
                    verbose,
                    val_data=val,
                    val_batch_size=batch_size)
    # log = model.fit(x_train, y_train_transformed, batch_size, epochs, callbacks, verbose, val_data = val_transformed, val_batch_size = batch_size)

    # Evaluation ===================================================================

    # Cox models need baseline hazards estimated before survival prediction.
    _ = model.compute_baseline_hazards()
    surv = model.predict_surv_df(x_test)
    ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')

    # ctd = ev.concordance_td()
    # Time-dependent concordance index.
    ctd = ev.concordance_td()
    # 100-point grid spanning the observed test times, for integrated scores.
    time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)

    ibs = ev.integrated_brier_score(time_grid)
    nbll = ev.integrated_nbll(time_grid)
    # Best (lowest) validation loss observed during training.
    val_loss = min(log.monitors['val_'].scores['loss']['score'])
    wandb.log({'val_loss': val_loss, 'ctd': ctd, 'ibs': ibs, 'nbll': nbll})
    wandb.finish()
Exemplo n.º 3
0
def pycox_deep(filename, Y_train, Y_test, opt, choice):
    """Train a CoxPH (DeepSurv-style) model on autoencoder-compressed features.

    Parameters:
        filename: source-data identifier handed to ``enc_using_trained_ae``.
        Y_train, Y_test: label frames with 'SVDTEPC_G' (time) and
            'PC_YN' (event indicator) columns.
        opt: target/gender tag used in output file names and log messages.
        choice: hyper-parameter dict (see the inline comment below).

    Returns:
        List of per-subject survival probabilities at each subject's
        observed time.

    NOTE(review): relies on several names from the enclosing module that are
    not visible here (``df`` -- presumably ``pandas.DataFrame`` --, ``index``,
    ``StratifiedSampler``, ``CoxformY``); confirm they are defined there.
    """
    # choice = {'lr_rate': l, 'batch': b, 'decay': 0, 'weighted_decay': wd, 'net': net, 'index': index}
    X_train, X_test = enc_using_trained_ae(filename, TARGET=opt, ALPHA=0.01, N_ITER=100, L1R=-9999)
    path = './models/analysis/'
    check = 0
    savename = 'model_check_autoen_m5_test_batch+dropout+wd.csv'
    # r=root, d=directories, f = files
    # Detect whether a previous results CSV exists so new rows get appended.
    for r, d, f in os.walk(path):
        for file in f:
            if savename in file:
                check = 1

    # X_train = X_train.drop('UR_SG3', axis=1)
    # X_test = X_test.drop('UR_SG3', axis=1)

    # NOTE(review): plain assignment -- x_train/x_test alias X_train/X_test,
    # so the label columns added below also appear in X_train/X_test
    # (which is why the .remove() calls further down succeed).
    x_train = X_train
    x_test = X_test
    x_train['SVDTEPC_G'] = Y_train['SVDTEPC_G']
    x_train['PC_YN'] = Y_train['PC_YN']
    x_test['SVDTEPC_G'] = Y_test['SVDTEPC_G']
    x_test['PC_YN'] = Y_test['PC_YN']

    ## DataFrameMapper ##
    # Pass the feature columns through unchanged (no scaler attached).
    cols_standardize = list(X_train.columns)
    cols_standardize.remove('SVDTEPC_G')
    cols_standardize.remove('PC_YN')

    standardize = [(col, None) for col in cols_standardize]

    x_mapper = DataFrameMapper(standardize)

    _ = x_mapper.fit_transform(X_train).astype('float32')
    X_train = x_mapper.transform(X_train).astype('float32')
    X_test = x_mapper.transform(X_test).astype('float32')

    # (durations, events) tuples in the format pycox expects.
    get_target = lambda df: (df['SVDTEPC_G'].values, df['PC_YN'].values)
    y_train = get_target(x_train)
    durations_test, events_test = get_target(x_test)
    in_features = X_train.shape[1]
    print(in_features)
    num_nodes = choice['nodes']
    out_features = 1
    batch_norm = True  # False for batch_normalization
    dropout = 0.01
    output_bias = False
    # net = choice['net']
    net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm, dropout, output_bias=output_bias)


    print("training")
    model = CoxPH(net, tt.optim.Adam)
    # lrfinder = model.lr_finder(X_train, y_train, batch_size)

    # lr_best = lrfinder.get_best_lr()
    # NOTE(review): lr_best is assigned but never used; the learning rate
    # actually applied is choice['lr_rate'] below.
    lr_best = 0.0001
    model.optimizer.set_lr(choice['lr_rate'])

    weighted_decay = choice['weighted_decay']
    verbose = True

    batch_size = choice['batch']
    epochs = 100

    # Either early stopping or decoupled weight decay, never both.
    if weighted_decay == 0:
        callbacks = [tt.callbacks.EarlyStopping(patience=epochs)]
        # model.fit(X_train, y_train, batch_size, epochs, callbacks, verbose=verbose)
    else:
        callbacks = [tt.callbacks.DecoupledWeightDecay(weight_decay=choice['decay'])]
        # model.fit(X_train, y_train, batch_size, epochs, callbacks, verbose)

    ''''''
    # Hand-built dataloader so a stratified sampler can be plugged in.
    # dataloader = model.make_dataloader(tt.tuplefy(X_train, y_train),batch_size,True)
    datas = tt.tuplefy(X_train, y_train).to_tensor()
    print(datas)
    make_dataset = tt.data.DatasetTuple;
    DataLoader = tt.data.DataLoaderBatch
    dataset = make_dataset(*datas)
    dataloader = DataLoader(dataset, batch_size, False, sampler=StratifiedSampler(datas, batch_size))
    # dataloader = DataLoader(dataset,batch_size, True)
    model.fit_dataloader(dataloader, epochs, callbacks, verbose)
    # model.fit(X_train, y_train, batch_size, epochs, callbacks, verbose)
    # model.partial_log_likelihood(*val).mean()

    print("predicting")
    baseline_hazards = model.compute_baseline_hazards(datas[0], datas[1])
    baseline_hazards = df(baseline_hazards)

    surv = model.predict_surv_df(X_test)
    # NOTE(review): curves are flipped (1 - survival) before evaluation and
    # flipped back per-element below -- confirm this is intentional.
    surv = 1 - surv
    ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')

    print("scoring")
    c_index = ev.concordance_td()
    print("c-index(", opt, "): ", c_index)

    # File-name scheme: '_0' suffix marks runs whose c-index rounds to 0.x0.
    if int(c_index * 10) == 0:
        hazardname = 'pycox_model_hazard_m5_v2_' + opt + '_0'
        netname = 'pycox_model_net_m5_v2_' + opt + '_0'
        weightname = 'pycox_model_weight_m5_v2_' + opt + '_0'
    else:
        hazardname = 'pycox_model_hazard_m5_' + opt + '_'
        netname = 'pycox_model_net_m5_' + opt + '_'
        weightname = 'pycox_model_weight_m5_' + opt + '_'

    # NOTE(review): `index` is not defined in this function -- it must come
    # from the enclosing module; verify before running.
    baseline_hazards.to_csv('./test/'+hazardname + str(int(c_index * 100)) + '_' + str(index) + '.csv', index=False)
    netname = netname + str(int(c_index * 100)) + '_' + str(index) + '.sav'
    weightname = weightname + str(int(c_index * 100)) + '_' + str(index) + '.sav'
    model.save_net('./test/' + netname)
    model.save_model_weights('./test/' + weightname)

    # Per-subject probability at each subject's own observed time.
    pred = df(surv)
    pred = pred.transpose()
    surv_final = []
    pred_final = []

    for i in range(len(pred)):
        pred_final.append(float(1-pred[Y_test['SVDTEPC_G'][i]][i]))
        surv_final.append(float(pred[Y_test['SVDTEPC_G'][i]][i]))

    Y_test_cox = CoxformY(Y_test)
    #print(surv_final)
    # Harrell's C via scikit-survival, computed on both orientations.
    c_cox, concordant, discordant,_,_ = concordance_index_censored(Y_test_cox['PC_YN'], Y_test_cox['SVDTEPC_G'], surv_final)
    c_cox_pred = concordance_index_censored(Y_test_cox['PC_YN'], Y_test_cox['SVDTEPC_G'], pred_final)[0]
    print("c-index(", opt, ") - sksurv: ", round(c_cox, 4))
    print("cox-concordant(", opt, ") - sksurv: ", concordant)
    print("cox-disconcordant(", opt, ") - sksurv: ", discordant)
    print("c-index_pred(", opt, ") - sksurv: ", round(c_cox_pred, 4))

    fpr, tpr, _ = metrics.roc_curve(Y_test['PC_YN'], pred_final)
    auc = metrics.auc(fpr, tpr)
    print("auc(", opt, "): ", round(auc, 4))

    # Append this run's scores to the cumulative results CSV.
    if check == 1:
        model_check = pd.read_csv(path+savename)
    else:
        model_check = df(columns=['option', 'gender', 'c-td', 'c-index', 'auc'])
    line_append = {'option':str(choice), 'gender':opt, 'c-td':round(c_index,4), 'c-index':round(c_cox_pred,4), 'auc':round(auc,4)}
    model_check = model_check.append(line_append, ignore_index=True)
    model_check.to_csv(path+savename, index=False)

    del X_train
    del X_test

    return surv_final
Exemplo n.º 4
0
def train_deepsurv(data_df, r_splits):
  """Train and evaluate a DeepSurv (MLP + CoxPH) model over multiple splits.

  Parameters:
    data_df: full dataset frame handed to ``prepare_datasets``/``df2array``.
    r_splits: sequence of index triples (test, val, train), one per run.

  Returns:
    Tuple of five lists, one entry per split:
    (c_index_at, c_index_30, time_auc_30, time_auc_60, time_auc_365).
  """
  epochs = 100

  # MLP architecture shared across splits.
  num_nodes = [32]
  out_features = 1
  batch_norm = True
  dropout = 0.6
  output_bias = False

  c_index_at = []
  c_index_30 = []

  time_auc_30 = []
  time_auc_60 = []
  time_auc_365 = []
  # FIX: dispatch through a dict instead of eval() on a constructed name --
  # eval was fragile/unsafe and hid the data flow from readers. The three
  # named lists are kept so the return signature is unchanged.
  time_aucs = {30: time_auc_30, 60: time_auc_60, 365: time_auc_365}

  for i in range(len(r_splits)):
    print("\nIteration %s" % (i))

    # DATA PREP
    df_train, df_val, df_test, df_test_30 = prepare_datasets(data_df, r_splits[i][2], r_splits[i][1], r_splits[i][0])

    # Standardize every column except identifiers and labels.
    xcols = list(df_train.columns)
    for col_name in ["subject_id", "event", "duration"]:
      if col_name in xcols:
        xcols.remove(col_name)
    cols_standardize = xcols

    standardize = [([col], StandardScaler()) for col in cols_standardize]
    x_mapper = DataFrameMapper(standardize)

    x_train = x_mapper.fit_transform(df_train).astype('float32')
    x_val = x_mapper.transform(df_val).astype('float32')
    x_test = x_mapper.transform(df_test).astype('float32')
    x_test_30 = x_mapper.transform(df_test_30).astype('float32')

    labtrans = CoxTime.label_transform()
    get_target = lambda df: (df['duration'].values, df['event'].values)
    y_train = labtrans.fit_transform(*get_target(df_train))
    y_val = labtrans.transform(*get_target(df_val))

    durations_test, events_test = get_target(df_test)
    durations_test_30, events_test_30 = get_target(df_test_30)
    val = tt.tuplefy(x_val, y_val)

    (train_x, train_y), (val_x, val_y), (test_x, test_y), _ = df2array(data_df, df_train, df_val, df_test, df_test_30)

    # MODEL
    in_features = x_train.shape[1]

    callbacks = [tt.callbacks.EarlyStopping()]

    net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm, dropout, output_bias=output_bias)

    model = CoxPH(net, tt.optim.Adam)
    model.optimizer.set_lr(0.0001)

    # Odd training-set sizes use 255 -- presumably to avoid a trailing batch
    # of size 1 (problematic with batch norm); TODO confirm intent.
    if x_train.shape[0] % 2:
      batch_size = 255
    else:
      batch_size = 256

    log = model.fit(x_train, y_train, batch_size, epochs, callbacks, val_data=val, val_batch_size=batch_size)

    model.compute_baseline_hazards()

    # Time-dependent concordance on the full test set and the 30-day subset.
    surv = model.predict_surv_df(x_test)
    ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
    c_index_at.append(ev.concordance_td())

    surv_30 = model.predict_surv_df(x_test_30)
    ev_30 = EvalSurv(surv_30, durations_test_30, events_test_30, censor_surv='km')
    c_index_30.append(ev_30.concordance_td())

    # Cumulative/dynamic AUC at fixed horizons (days).
    for time_x in (30, 60, 365):
      va_auc, va_mean_auc = cumulative_dynamic_auc(train_y, test_y, model.predict(x_test).flatten(), time_x)
      time_aucs[time_x].append(va_auc[0])

    print("C-index_30:", c_index_30[i])
    print("C-index_AT:", c_index_at[i])

    print("time_auc_30", time_auc_30[i])
    print("time_auc_60", time_auc_60[i])
    print("time_auc_365", time_auc_365[i])

  return c_index_at, c_index_30, time_auc_30, time_auc_60, time_auc_365
Exemplo n.º 5
0
def train_LSTMCox(data_df, r_splits):
  """Train and evaluate an LSTM-based CoxPH model over multiple splits.

  Same evaluation protocol as ``train_deepsurv`` but features are the raw
  768-dim embedding sequences stored in each frame's 'x0' column.

  Parameters:
    data_df: full dataset frame handed to ``prepare_datasets``/``df2array``.
    r_splits: sequence of index triples (test, val, train), one per run.

  Returns:
    Tuple of five lists, one entry per split:
    (c_index_at, c_index_30, time_auc_30, time_auc_60, time_auc_365).
  """
  epochs = 100

  in_features = 768
  out_features = 1
  batch_norm = True
  dropout = 0.6
  output_bias = False

  c_index_at = []
  c_index_30 = []

  time_auc_30 = []
  time_auc_60 = []
  time_auc_365 = []
  # FIX: dispatch through a dict instead of eval() on a constructed name --
  # eval was fragile/unsafe and hid the data flow from readers. The three
  # named lists are kept so the return signature is unchanged.
  time_aucs = {30: time_auc_30, 60: time_auc_60, 365: time_auc_365}

  for i in range(len(r_splits)):
    print("\nIteration %s" % (i))

    # DATA PREP
    df_train, df_val, df_test, df_test_30 = prepare_datasets(data_df, r_splits[i][2], r_splits[i][1], r_splits[i][0])

    x_train = np.array(df_train["x0"].tolist()).astype("float32")
    x_val = np.array(df_val["x0"].tolist()).astype("float32")
    x_test = np.array(df_test["x0"].tolist()).astype("float32")
    x_test_30 = np.array(df_test_30["x0"].tolist()).astype("float32")

    labtrans = CoxTime.label_transform()
    get_target = lambda df: (df['duration'].values, df['event'].values)
    y_train = labtrans.fit_transform(*get_target(df_train))
    y_val = labtrans.transform(*get_target(df_val))

    durations_test, events_test = get_target(df_test)
    durations_test_30, events_test_30 = get_target(df_test_30)
    val = tt.tuplefy(x_val, y_val)

    (train_x, train_y), (val_x, val_y), (test_x, test_y), _ = df2array(data_df, df_train, df_val, df_test, df_test_30)

    # MODEL
    callbacks = [tt.callbacks.EarlyStopping()]

    net = LSTMCox(768, 32, 1, 1)

    model = CoxPH(net, tt.optim.Adam)
    model.optimizer.set_lr(0.0001)

    # Odd training-set sizes use 255 -- presumably to avoid a trailing batch
    # of size 1 (problematic with batch norm); TODO confirm intent.
    if x_train.shape[0] % 2:
      batch_size = 255
    else:
      batch_size = 256

    log = model.fit(x_train, y_train, batch_size, epochs, callbacks, val_data=val, val_batch_size=batch_size)

    model.compute_baseline_hazards()

    # Time-dependent concordance on the full test set and the 30-day subset.
    surv = model.predict_surv_df(x_test)
    ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
    c_index_at.append(ev.concordance_td())

    surv_30 = model.predict_surv_df(x_test_30)
    ev_30 = EvalSurv(surv_30, durations_test_30, events_test_30, censor_surv='km')
    c_index_30.append(ev_30.concordance_td())

    # Cumulative/dynamic AUC at fixed horizons (days).
    for time_x in (30, 60, 365):
      va_auc, va_mean_auc = cumulative_dynamic_auc(train_y, test_y, model.predict(x_test).flatten(), time_x)
      time_aucs[time_x].append(va_auc[0])

    print("C-index_30:", c_index_30[i])
    print("C-index_AT:", c_index_at[i])

    print("time_auc_30", time_auc_30[i])
    print("time_auc_60", time_auc_60[i])
    print("time_auc_365", time_auc_365[i])

  return c_index_at, c_index_30, time_auc_30, time_auc_60, time_auc_365
Exemplo n.º 6
0
                              output_bias=output_bias)

# Adam optimizer with the configured learning rate.
optimizer = tt.optim.Adam(lr=lr)

surv_model = CoxPH(net, optimizer)

# Checkpoint path encodes the full experiment configuration
# (estimator, dataset, experiment index, validation tag, batch size,
# epochs, layers, nodes, learning rate).
model_filename = \
    os.path.join(output_dir, 'models',
                 '%s_%s_exp%d_%s_bs%d_nep%d_nla%d_nno%d_lr%f_test.pt'
                 % (survival_estimator_name, dataset, experiment_idx,
                    val_string, batch_size, n_epochs, n_layers, n_nodes, lr))
assert os.path.isfile(model_filename)
print('*** Loading ***', flush=True)
surv_model.load_net(model_filename)

# Predicted survival curves for the test set; transposed so rows are
# subjects (presumably -- predict_surv_df's index is time; verify).
surv_df = surv_model.predict_surv_df(X_test_std)
surv = surv_df.to_numpy().T

# Summary statistics of observed times and censoring in the test labels.
# y_test column 0 holds observed times, column 1 the event indicator.
print()
print('[Test data statistics]')
sorted_y_test_times = np.sort(y_test[:, 0])
print('Quartiles:')
print('- Min observed time:', np.min(y_test[:, 0]))
print('- Q1 observed time:',
      sorted_y_test_times[int(0.25 * len(sorted_y_test_times))])
print('- Median observed time:', np.median(y_test[:, 0]))
print('- Q3 observed time:',
      sorted_y_test_times[int(0.75 * len(sorted_y_test_times))])
print('- Max observed time:', np.max(y_test[:, 0]))
print('Mean observed time:', np.mean(y_test[:, 0]))
print('Fraction censored:', 1. - np.mean(y_test[:, 1]))
Exemplo n.º 7
0
                            n_epochs,
                            verbose=False)
                        surv_model.compute_baseline_hazards()
                        elapsed = time.time() - tic
                        print('Time elapsed: %f second(s)' % elapsed)
                        # Persist wall-clock fitting time next to the model
                        # so re-runs can report it without refitting.
                        np.savetxt(time_elapsed_filename,
                                   np.array(elapsed).reshape(1, -1))
                        surv_model.save_net(model_filename)
                    else:
                        # print('*** Loading ***', flush=True)
                        surv_model.load_net(model_filename)
                        elapsed = float(np.loadtxt(time_elapsed_filename))
                        print('Time elapsed (from previous fitting): ' +
                              '%f second(s)' % elapsed)

                    # Evaluate on this fold's validation split; columns 0/1
                    # of fold_y_val are durations/events respectively.
                    surv_df = surv_model.predict_surv_df(fold_X_val_std)
                    ev = EvalSurv(surv_df,
                                  fold_y_val[:, 0],
                                  fold_y_val[:, 1],
                                  censor_surv='km')

                    # 100-point grid spanning the observed validation times.
                    sorted_fold_y_val = np.sort(np.unique(fold_y_val[:, 0]))
                    time_grid = np.linspace(sorted_fold_y_val[0],
                                            sorted_fold_y_val[-1], 100)

                    surv = surv_df.to_numpy().T

                    cindex_scores.append(ev.concordance_td('antolini'))
                    integrated_brier_scores.append(
                        ev.integrated_brier_score(time_grid))
                    print('  c-index (td):', cindex_scores[-1])
Exemplo n.º 8
0
def main():
    """Entry point: train a ResNet-18 CoxPH survival model on signal data.

    Reads paths and hyper-parameters from the command line, builds batched
    dataloaders for train/val/test, fits with ``fit_dataloader``, then saves
    the loss log, baseline hazards, test survival curves and c-index under
    ``<save_path>/<model_name>/``.
    """
    parser = setup_parser()
    args = parser.parse_args()

    # Restrict visible GPUs before any CUDA work happens.
    if args.which_gpu != 'none':
        os.environ["CUDA_VISIBLE_DEVICES"] = args.which_gpu

    # save setting
    if not os.path.exists(os.path.join(args.save_path, args.model_name)):
        os.mkdir(os.path.join(args.save_path, args.model_name))

    # data reading setting
    singnal_data_path = args.signal_dataset_path
    table_path = args.table_path
    time_col = 'SurvivalDays'
    event_col = 'Mortality'

    # dataset
    data_pathes, times, events = read_dataset(singnal_data_path, table_path,
                                              time_col, event_col,
                                              args.sample_ratio)

    # 70/30 train/test, then 80/20 train/val, with fixed seeds.
    data_pathes_train, data_pathes_test, times_train, times_test, events_train, events_test = train_test_split(
        data_pathes, times, events, test_size=0.3, random_state=369)
    data_pathes_train, data_pathes_val, times_train, times_val, events_train, events_val = train_test_split(
        data_pathes_train,
        times_train,
        events_train,
        test_size=0.2,
        random_state=369)

    labels_train = label_transfer(times_train, events_train)
    dataset_train = VsDatasetBatch(data_pathes_train, *labels_train)
    dl_train = tt.data.DataLoaderBatch(dataset_train,
                                       args.train_batch_size,
                                       shuffle=True)

    labels_val = label_transfer(times_val, events_val)
    dataset_val = VsDatasetBatch(data_pathes_val, *labels_val)
    dl_val = tt.data.DataLoaderBatch(dataset_val,
                                     args.train_batch_size,
                                     shuffle=True)

    # Test loader yields inputs only; labels are kept aside for evaluation.
    labels_test = label_transfer(times_test, events_test)
    dataset_test_x = VsTestInput(data_pathes_test)
    dl_test_x = DataLoader(dataset_test_x, args.test_batch_size, shuffle=False)

    net = resnet18(args)
    model = CoxPH(
        net,
        tt.optim.Adam(lr=args.lr,
                      betas=(0.9, 0.999),
                      eps=1e-08,
                      weight_decay=5e-4,
                      amsgrad=False))
    # callbacks = [tt.cb.EarlyStopping(patience=15)]
    # Keep the best validation weights on disk during training.
    callbacks = [
        tt.cb.BestWeights(file_path=os.path.join(
            args.save_path, args.model_name, args.model_name + '_bestWeight'),
                          rm_file=False)
    ]
    verbose = True
    model_log = model.fit_dataloader(dl_train,
                                     args.epochs,
                                     callbacks,
                                     verbose,
                                     val_dataloader=dl_val)

    # Persist run configuration and the per-epoch loss curve.
    save_args(os.path.join(args.save_path, args.model_name), args)
    model_log.to_pandas().to_csv(os.path.join(args.save_path, args.model_name,
                                              'loss.csv'),
                                 index=False)

    # Baseline hazards are required before predict_surv_df on a Cox model.
    _ = model.compute_baseline_hazards(
        get_vs_data(dataset_train), (dataset_train.time, dataset_train.event))
    model.save_net(
        path=os.path.join(args.save_path, args.model_name, args.model_name +
                          '_final'))
    surv = model.predict_surv_df(dl_test_x)
    surv.to_csv(os.path.join(args.save_path, args.model_name,
                             'test_sur_df.csv'),
                index=False)
    ev = EvalSurv(surv, np.array(labels_test[0]), np.array(labels_test[1]),
                  'km')
    print(ev.concordance_td())
    save_cindex(os.path.join(args.save_path, args.model_name),
                ev.concordance_td())
    print('done')
Exemplo n.º 9
0
class DeepSurv_pycox():
    """DeepSurv survival model (MLP + Cox proportional hazards) via pycox.

    Wraps data formatting, optional standardization, training with early
    stopping, and median-survival-time prediction behind a simple
    fit/predict interface.
    """
    def __init__(self,
                 layers,
                 nodes_per_layer,
                 dropout,
                 weight_decay,
                 batch_size,
                 lr=0.01,
                 seed=47):
        """Configure the network architecture and training hyper-parameters.

        Parameters:
            layers: number of hidden layers.
            nodes_per_layer: width of each hidden layer.
            dropout: dropout rate for the MLP.
            weight_decay: L2 penalty passed to Adam.
            batch_size: training batch size (may be bumped in fit()).
            lr: learning rate.
            seed: seed for numpy and torch RNGs.
        """
        # set seed for reproducibility across numpy and torch
        np.random.seed(seed)
        _ = torch.manual_seed(seed)
        self.standardalizer = None
        self.standardize_data = True

        # column names used internally for the survival labels
        self._duration_col = "duration"
        self._event_col = "event"

        # fixed model settings; in_features is discovered in fit()
        self.in_features = None
        self.out_features = 1
        self.batch_norm = True
        self.output_bias = False
        self.activation = torch.nn.ReLU
        self.epochs = 512
        self.num_workers = 2
        self.callbacks = [tt.callbacks.EarlyStopping()]

        # parameters tuned
        self.num_nodes = [int(nodes_per_layer) for _ in range(int(layers))]
        self.dropout = dropout
        self.weight_decay = weight_decay
        self.lr = lr
        self.batch_size = int(batch_size)

    def set_standardize(self, standardize_bool):
        """Enable/disable feature standardization (must be True; see below)."""
        self.standardize_data = standardize_bool

    def _format_to_pycox(self, X, Y, F):
        """Build a DataFrame from features X (columns F); append duration and
        event columns from Y (n x 2: [duration, event]) when Y is given."""
        # from numpy to pandas df
        df = pd.DataFrame(data=X, columns=F)
        if Y is not None:
            df[self._duration_col] = Y[:, 0]
            df[self._event_col] = Y[:, 1]
        return df

    def _standardize_df(self, df, flag):
        """Standardize features and extract labels.

        flag is 'train' (fit the scaler), 'val' (reuse it), or 'test'
        (reuse it; the df carries no label columns and y is None).
        Raises NotImplementedError when standardization is disabled or the
        flag is unknown.
        """
        # if flag = test, the df passed in does not contain Y labels
        if self.standardize_data:
            df_x = df if flag == 'test' else df.drop(
                columns=[self._duration_col, self._event_col])
            if flag == "train":
                # Binary {0,1} columns are passed through; all others are
                # z-scored with a scaler fitted on the training data.
                cols_leave = []
                cols_standardize = []
                for column in df_x.columns:
                    if set(pd.unique(df[column])) == set([0, 1]):
                        cols_leave.append(column)
                    else:
                        cols_standardize.append(column)
                standardize = [([col], StandardScaler())
                               for col in cols_standardize]
                leave = [(col, None) for col in cols_leave]
                self.standardalizer = DataFrameMapper(standardize + leave)

                x = self.standardalizer.fit_transform(df_x).astype('float32')
                y = (df[self._duration_col].values.astype('float32'),
                     df[self._event_col].values.astype('float32'))

            elif flag == "val":
                x = self.standardalizer.transform(df_x).astype('float32')
                y = (df[self._duration_col].values.astype('float32'),
                     df[self._event_col].values.astype('float32'))

            elif flag == "test":
                x = self.standardalizer.transform(df_x).astype('float32')
                y = None

            else:
                raise NotImplementedError

            return x, y
        else:
            raise NotImplementedError

    def fit(self, X, y, column_names):
        """Fit the DeepSurv model on features X and labels y (n x 2).

        A random 20% of rows is held out for validation/early stopping.
        Also computes the baseline hazards needed for survival prediction.
        """
        # format data
        self.column_names = column_names
        full_df = self._format_to_pycox(X, y, self.column_names)
        val_df = full_df.sample(frac=0.2)
        train_df = full_df.drop(val_df.index)
        train_x, train_y = self._standardize_df(train_df, "train")
        val_x, val_y = self._standardize_df(val_df, "val")
        # configure model
        self.in_features = train_x.shape[1]
        net = tt.practical.MLPVanilla(in_features=self.in_features,
                                      num_nodes=self.num_nodes,
                                      out_features=self.out_features,
                                      batch_norm=self.batch_norm,
                                      dropout=self.dropout,
                                      activation=self.activation,
                                      output_bias=self.output_bias)
        self.model = CoxPH(
            net, tt.optim.Adam(lr=self.lr, weight_decay=self.weight_decay))
        # self.model.optimizer.set_lr(self.lr)

        n_train = train_x.shape[0]
        while n_train % self.batch_size == 1:  # this will cause issues in batch norm
            self.batch_size += 1

        self.model.fit(train_x,
                       train_y,
                       self.batch_size,
                       self.epochs,
                       self.callbacks,
                       verbose=True,
                       val_data=(val_x, val_y),
                       val_batch_size=self.batch_size,
                       num_workers=self.num_workers)
        self.model.compute_baseline_hazards()

    def predict(self, test_x, time_list):
        """Predict each subject's median survival time.

        Returns (pred_medians, proba_matrix_): the per-subject median
        survival times and the raw survival-probability DataFrame.
        """
        # format data
        test_df = self._format_to_pycox(test_x, None, self.column_names)
        test_x, _ = self._standardize_df(test_df, "test")

        proba_matrix_ = self.model.predict_surv_df(test_x)
        proba_matrix = np.transpose(proba_matrix_.values)
        pred_medians = []
        for test_idx, survival_proba in enumerate(proba_matrix):
            # If the predicted probability never drops to 0.5, predict the
            # largest seen time. BUG FIX: this fallback must be reset for
            # every subject -- previously it was initialized once before the
            # loop, so a subject whose curve never crossed 0.5 inherited the
            # *previous* subject's median instead of max(time_list).
            median_time = max(time_list)
            # the survival_proba is in descending order
            for col, proba in enumerate(survival_proba):
                if proba > 0.5:
                    continue
                if proba == 0.5 or col == 0:
                    median_time = time_list[col]
                else:
                    median_time = (time_list[col - 1] + time_list[col]) / 2
                break
            pred_medians.append(median_time)

        return np.array(pred_medians), proba_matrix_
def main(data_root, cancer_type, anatomical_location, fold):
    """Train and evaluate a graph-based CoxPH survival model for one CV fold.

    Loads the PPI graph and per-patient multi-omics features, builds a
    censoring-stratified 5-fold train/valid/test split, fits a CoxPH model
    on ``MyNet`` and reports the time-dependent concordance index.

    Parameters:
        data_root: base directory of the patient data. NOTE(review): the
            original code overwrites this with a hard-coded path below;
            that behavior is preserved but flagged.
        cancer_type: index into the 33-way cancer-type one-hot vector.
        anatomical_location: index into the 52-way location one-hot vector.
        fold: cross-validation fold in [0, 4].
    """

    # Import the RDF graph for PPI network
    f = open('seen.pkl', 'rb')
    seen = pickle.load(f)
    f.close()
    #####################

    f = open('ei.pkl', 'rb')
    ei = pickle.load(f)
    f.close()

    global device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # device = torch.device('cpu')

    # One-hot encodings of tumour type / subtype / location / cell type.
    cancer_type_vector = np.zeros((33, ), dtype=np.float32)
    cancer_type_vector[cancer_type] = 1

    cancer_subtype_vector = np.zeros((25, ), dtype=np.float32)
    for i in CANCER_SUBTYPES[cancer_type]:
        cancer_subtype_vector[i] = 1

    anatomical_location_vector = np.zeros((52, ), dtype=np.float32)
    anatomical_location_vector[anatomical_location] = 1
    cell_type_vector = np.zeros((10, ), dtype=np.float32)
    cell_type_vector[CELL_TYPES[cancer_type]] = 1

    pt_tensor_cancer_type = torch.FloatTensor(cancer_type_vector).to(device)
    pt_tensor_cancer_subtype = torch.FloatTensor(cancer_subtype_vector).to(
        device)
    pt_tensor_anatomical_location = torch.FloatTensor(
        anatomical_location_vector).to(device)
    pt_tensor_cell_type = torch.FloatTensor(cell_type_vector).to(device)
    edge_index = torch.LongTensor(ei).to(device)

    # Import a dictionary that maps proteins to their corresponding genes
    # via the Ensembl database.
    f = open('ens_dic.pkl', 'rb')
    dicty = pickle.load(f)
    f.close()
    dic = {}
    for d in dicty:
        key = dicty[d]
        if key not in dic:
            dic[key] = {}
        dic[key][d] = 1

    # Build a dictionary from ENSG -- ENST
    d = {}
    with open('data1/prot_names1.txt') as f:
        for line in f:
            tok = line.split()
            d[tok[1]] = tok[0]

    clin = [
    ]  # clinical data: days to death (dead) or to last follow-up (alive)
    feat_vecs = [
    ]  # per-patient features: [gene_expression, diff_gene_expression, methylation, diff_methylation, VCF, CNV]
    suv_time = [
    ]  # event indicator per patient: 0 for dead, 1 for alive (censored)
    can_types = ["BRCA_v2"]
    # NOTE(review): this clobbers the `data_root` parameter with a
    # hard-coded cluster path -- confirm whether that is intentional.
    data_root = '/ibex/scratch/projects/c2014/sara/'
    for i in range(len(can_types)):
        # Patient manifest: patient ID plus the six per-modality file names.
        f = open(data_root + can_types[i] + '.txt')
        lines = f.read().splitlines()
        f.close()
        lines = lines[1:]
        count = 0
        feat_vecs = np.zeros((len(lines), 17186 * 6), dtype=np.float32)
        # NOTE(review): the row counter below shadows the outer loop index;
        # harmless only while can_types has a single entry.
        i = 0
        for l in tqdm(lines):
            l = l.split('\t')
            clinical_file = l[6]
            surv_file = l[2]
            myth_file = 'myth/' + l[3]
            diff_myth_file = 'diff_myth/' + l[1]
            exp_norm_file = 'exp_count/' + l[-1]
            diff_exp_norm_file = 'diff_exp/' + l[0]
            cnv_file = 'cnv/' + l[4] + '.txt'
            vcf_file = 'vcf/' + 'OutputAnnoFile_' + l[
                5] + '.hg38_multianno.txt.dat'
            # Check that all 6 files exist for this patient (for some
            # patients the survival time is not reported).
            all_files = [
                myth_file, diff_exp_norm_file, diff_myth_file, exp_norm_file,
                cnv_file, vcf_file
            ]
            for fname in all_files:
                if not os.path.exists(fname):
                    print('File ' + fname + ' does not exist!')
                    sys.exit(1)
            clin.append(clinical_file)
            suv_time.append(surv_file)
            # Assemble the six modality vectors into one flat feature row.
            temp_myth = myth_data(myth_file, seen, d, dic)
            vec = np.array(get_data(exp_norm_file, diff_exp_norm_file,
                                    diff_myth_file, cnv_file, vcf_file,
                                    temp_myth, seen, dic),
                           dtype=np.float32)
            vec = vec.flatten()
            feat_vecs[i, :] = vec
            i += 1

    min_max_scaler = MinMaxScaler(clip=True)
    labels_days = []
    labels_surv = []
    for days, surv in zip(clin, suv_time):
        labels_days.append(float(days))
        labels_surv.append(float(surv))

    # Train by batch
    dataset = feat_vecs
    labels_days = np.array(labels_days)
    labels_surv = np.array(labels_surv)

    # Separate censored (alive) from uncensored (dead) patients so the
    # train/valid/test split preserves the censoring ratio.
    censored_index = []
    uncensored_index = []
    for i in range(len(dataset)):
        if labels_surv[i] == 1:
            censored_index.append(i)
        else:
            uncensored_index.append(i)
    model = CoxPH(MyNet(edge_index).to(device), tt.optim.Adam(0.0001))

    censored_index = np.array(censored_index)
    uncensored_index = np.array(uncensored_index)

    # Censored split: 5 folds, with 10% of the training part as validation.
    n = len(censored_index)
    index = np.arange(n)
    i = n // 5
    np.random.seed(seed=0)
    np.random.shuffle(index)
    if fold < 4:
        ctest_idx = index[fold * i:fold * i + i]
        # FIX: `index` is a numpy array, so `+` adds elementwise (and raises
        # on shape mismatch); the two index ranges must be concatenated.
        ctrain_idx = np.concatenate((index[:fold * i], index[fold * i + i:]))
    else:
        ctest_idx = index[fold * i:]
        ctrain_idx = index[:fold * i]
    ctrain_n = len(ctrain_idx)
    cvalid_n = ctrain_n // 10
    cvalid_idx = ctrain_idx[:cvalid_n]
    ctrain_idx = ctrain_idx[cvalid_n:]

    # Uncensored split: same scheme as above.
    n = len(uncensored_index)
    index = np.arange(n)
    i = n // 5
    np.random.seed(seed=0)
    np.random.shuffle(index)
    if fold < 4:
        utest_idx = index[fold * i:fold * i + i]
        # FIX: same elementwise-add bug as the censored split above.
        utrain_idx = np.concatenate((index[:fold * i], index[fold * i + i:]))
    else:
        utest_idx = index[fold * i:]
        utrain_idx = index[:fold * i]
    utrain_n = len(utrain_idx)
    uvalid_n = utrain_n // 10
    uvalid_idx = utrain_idx[:uvalid_n]
    utrain_idx = utrain_idx[uvalid_n:]

    # FIX: np.concatenate takes a sequence of arrays; the original passed
    # the second array as the `axis` argument, which raises a TypeError.
    train_idx = np.concatenate((censored_index[ctrain_idx],
                                uncensored_index[utrain_idx]))
    np.random.seed(seed=0)
    np.random.shuffle(train_idx)
    valid_idx = np.concatenate((censored_index[cvalid_idx],
                                uncensored_index[uvalid_idx]))
    np.random.seed(seed=0)
    np.random.shuffle(valid_idx)
    test_idx = np.concatenate((censored_index[ctest_idx],
                               uncensored_index[utest_idx]))
    np.random.seed(seed=0)
    np.random.shuffle(test_idx)

    # Scale features with min-max fitted on the training split only.
    train_data = dataset[train_idx]
    train_data = min_max_scaler.fit_transform(train_data)
    train_labels_days = labels_days[train_idx]
    train_labels_surv = labels_surv[train_idx]
    train_labels = (train_labels_days, train_labels_surv)

    val_data = dataset[valid_idx]
    val_data = min_max_scaler.transform(val_data)
    val_labels_days = labels_days[valid_idx]
    val_labels_surv = labels_surv[valid_idx]
    test_data = dataset[test_idx]
    test_data = min_max_scaler.transform(test_data)
    test_labels_days = labels_days[test_idx]
    test_labels_surv = labels_surv[test_idx]
    val_labels = (val_labels_days, val_labels_surv)
    print(val_labels)

    callbacks = [tt.callbacks.EarlyStopping()]
    batch_size = 16
    epochs = 100
    val = (val_data, val_labels)
    log = model.fit(train_data,
                    train_labels,
                    batch_size,
                    epochs,
                    callbacks,
                    True,
                    val_data=val,
                    val_batch_size=batch_size)
    log.plot()
    plt.show()
    # Compute the evaluation measurements
    train = train_data, train_labels
    model.compute_baseline_hazards(*train)
    surv = model.predict_surv_df(test_data)
    print(surv)
    ev = EvalSurv(surv, test_labels_days, test_labels_surv)
    print(ev.concordance_td())