Пример #1
0
def Coxnnet_pipeline(mod1, mod2, x1, x2, y1, y2, hyperparameters, save, path):
    """Train a single-input Cox-nnet model through pycox's ``CoxPH`` wrapper.

    Args:
        mod1: Name of the (only) input modality.
        mod2: Name of a second modality; must be the string 'None' because
            ``Coxnnet`` is constructed from a single input width.
        x1: Feature matrix for ``mod1``; ``x1.shape[1]`` is the input width.
        x2: Unused; kept so the signature matches the other pipelines.
        y1, y2: Survival targets, passed to ``CoxPH.fit`` as ``(y1, y2)``.
        hyperparameters: Dict with keys 'Dense size', 'Dropout',
            'batch_size', 'Epoch' and 'Learning rate'.
        save: If truthy, save the trained network's ``state_dict``.
        path: File stem; the weights go to ``SAVE_FOLDER + path + ".pt"``.

    Returns:
        ``(model, log)``: the fitted ``CoxPH`` model and its training log.

    Raises:
        NotImplementedError: if ``mod2`` is not 'None'. The two-modality branch
            never built a network, so it used to crash with UnboundLocalError.
    """
    dense_size = hyperparameters['Dense size']  # number of nodes in dense layers
    dropout_p = hyperparameters['Dropout']
    batch_size = hyperparameters['batch_size']
    epochs = hyperparameters['Epoch']
    lr = hyperparameters['Learning rate']
    verbose = True

    if mod2 != 'None':
        raise NotImplementedError(
            "Coxnnet_pipeline supports a single modality only "
            "(mod2 must be 'None'), got %r" % (mod2,))

    in_features_one = x1.shape[1]
    net = Coxnnet(in_features_one, dense_size, dropout_p).to(device)
    net.train()
    model = CoxPH(net, tt.optim.SGD)
    model.optimizer.set_lr(lr)
    # model.optimizer.set('momentum', 0.9)

    log = model.fit(x1, (y1, y2), batch_size, epochs, verbose=verbose)
    net.eval()
    if save:
        torch.save(net.state_dict(), SAVE_FOLDER + path + ".pt")
    return model, log
Пример #2
0
def VAESurv_pipeline(mod1, mod2, x1, x2, y1, y2, hyperparameters, save, path):
    """Fine-tune the survival head of a pre-trained two-modality VAESurv model.

    Builds ``VAESurv`` for the modality pair (``mod1``, ``mod2``) and loads
    pre-trained weights from
    ``hyperparameters['State file path'] + mod1 + '+' + mod2 + '.pt'``.
    Loading uses ``strict=False``, so missing or unexpected keys are tolerated.
    Every parameter whose name does not contain ``'surv_net'`` is frozen, and
    the rest are trained with Adam through pycox's ``CoxPH``.

    Args:
        mod1, mod2: Modality names; ``mod2`` must not be the string 'None'.
        x1, x2: Feature matrices for the two modalities; ``shape[1]`` of each
            gives the input width.
        y1, y2: Survival targets, passed to ``CoxPH.fit`` as ``(y1, y2)``.
        hyperparameters: Dict with keys 'D dims', 'Dense size',
            'Latent size', 'Neuron size', 'Dropout', 'State file path',
            'batch_size', 'Epoch', 'Learning rate' and 'L2 reg'.
        save: If truthy, save the network's ``state_dict``.
        path: Weights are written to ``SAVE_FOLDER + "VAESurv_" + path``
            (no extension is appended).

    Returns:
        ``(model, log)``: the fitted ``CoxPH`` model and its training log.

    Raises:
        NotImplementedError: if ``mod2`` is 'None'. The single-modality branch
            never built a network, so it used to crash with UnboundLocalError.
    """
    d_dims = hyperparameters['D dims']
    dense_size = hyperparameters['Dense size']  # number of nodes in dense layers
    latent_size = hyperparameters['Latent size']  # dimensionality of the encoded data
    neuron_size = hyperparameters['Neuron size']  # dimensions for the survival network
    dropout_p = hyperparameters['Dropout']

    if mod2 == 'None':
        raise NotImplementedError(
            "VAESurv_pipeline requires two modalities (mod2 must not be 'None')")

    in_features_one = x1.shape[1]
    in_features_two = x2.shape[1]
    net = VAESurv(in_features_one, in_features_two, d_dims, dense_size,
                  latent_size, dropout_p, neuron_size, device).to(device)

    # Load the pre-trained VAE. map_location lets a checkpoint saved on GPU
    # be loaded on a CPU-only machine.
    state_path = hyperparameters['State file path'] + mod1 + '+' + mod2 + '.pt'
    net.load_state_dict(torch.load(state_path, map_location=device), strict=False)

    # Freeze everything except the survival sub-network.
    for name, param in net.named_parameters():
        if 'surv_net' not in name:
            param.requires_grad = False

    net.train()
    model = CoxPH(net, tt.optim.Adam)

    batch_size = hyperparameters['batch_size']
    epochs = hyperparameters['Epoch']
    verbose = True
    model.optimizer.set_lr(hyperparameters['Learning rate'])
    model.optimizer.set('weight_decay', hyperparameters['L2 reg'])
    log = model.fit((x1, x2), (y1, y2), batch_size, epochs, verbose=verbose)
    net.eval()
    if save:
        torch.save(net.state_dict(), SAVE_FOLDER + "VAESurv_" + path)
    return model, log
Пример #3
0
        f'L{args.num_layers}N{args.num_nodes}D{args.dropout}W{args.weight_decay}B{args.batch_size}',
        config=args)

    wandb.watch(net)

    # Loss configuration ============================================================

    # Training ======================================================================

    epochs = args.epochs
    callbacks = [tt.callbacks.EarlyStopping()]
    verbose = True
    log = model.fit(x_train,
                    y_train,
                    batch_size,
                    epochs,
                    callbacks,
                    verbose,
                    val_data=val,
                    val_batch_size=batch_size)
    # log = model.fit(x_train, y_train_transformed, batch_size, epochs, callbacks, verbose, val_data = val_transformed, val_batch_size = batch_size)

    # Evaluation ===================================================================

    _ = model.compute_baseline_hazards()
    surv = model.predict_surv_df(x_test)
    ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')

    # ctd = ev.concordance_td()
    ctd = ev.concordance_td()
    time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)
Пример #4
0
def train_deepsurv(data_df, r_splits):
  """Train and evaluate a DeepSurv model (pycox CoxPH + MLP) on every split.

  For each entry of ``r_splits`` the function:
  - splits the data with ``prepare_datasets``;
  - standardises every feature column, fitting on the training part only;
  - trains a one-hidden-layer MLP with early stopping on the validation part;
  - evaluates the model on the test part and on the 30-day test part.

  Args:
    data_df: Full dataset with 'duration' and 'event' columns. A 'subject_id'
      column, if present, is excluded from the features.
    r_splits: Sequence of split specs. Element ``i`` is forwarded to
      ``prepare_datasets`` as ``r_splits[i][2], r_splits[i][1], r_splits[i][0]``.

  Returns:
    ``(c_index_at, c_index_30, time_auc_30, time_auc_60, time_auc_365)``,
    each a list with one value per split.
  """
  epochs = 100
  verbose = True

  num_nodes = [32]
  out_features = 1
  batch_norm = True
  dropout = 0.6
  output_bias = False

  c_index_at = []
  c_index_30 = []
  # Horizon (days) -> per-split cumulative/dynamic AUC at that horizon.
  time_aucs = {30: [], 60: [], 365: []}

  def get_target(df):
    return df['duration'].values, df['event'].values

  def to_float32(target):
    return tuple(arr.astype('float32') for arr in target)

  for i, split in enumerate(r_splits):
    print("\nIteration %s" % (i))

    # DATA PREP
    df_train, df_val, df_test, df_test_30 = prepare_datasets(data_df, split[2], split[1], split[0])

    xcols = [col for col in df_train.columns if col not in ("subject_id", "event", "duration")]
    x_mapper = DataFrameMapper([([col], StandardScaler()) for col in xcols])

    x_train = x_mapper.fit_transform(df_train).astype('float32')
    x_val = x_mapper.transform(df_val).astype('float32')
    x_test = x_mapper.transform(df_test).astype('float32')
    x_test_30 = x_mapper.transform(df_test_30).astype('float32')

    # Raw durations are used on purpose. CoxTime.label_transform() standardises
    # durations, and CoxPH never maps them back, so the baseline hazards (and
    # predict_surv_df's time index) ended up on a different scale than the raw
    # durations handed to EvalSurv. The Cox partial likelihood depends only on
    # the ordering of durations, so training is unaffected by this change.
    y_train = to_float32(get_target(df_train))
    y_val = to_float32(get_target(df_val))

    durations_test, events_test = get_target(df_test)
    durations_test_30, events_test_30 = get_target(df_test_30)
    val = tt.tuplefy(x_val, y_val)

    (_, train_y), _, (_, test_y), _ = df2array(data_df, df_train, df_val, df_test, df_test_30)

    # MODEL
    in_features = x_train.shape[1]
    callbacks = [tt.callbacks.EarlyStopping()]
    net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm, dropout, output_bias=output_bias)
    model = CoxPH(net, tt.optim.Adam)
    model.optimizer.set_lr(0.0001)

    n_train = x_train.shape[0]
    batch_size = 255 if n_train % 2 else 256
    # A trailing batch of a single sample breaks BatchNorm in training mode.
    # The odd/even rule alone does not prevent it (e.g. 511 % 255 == 1).
    while n_train % batch_size == 1:
      batch_size -= 1

    model.fit(x_train, y_train, batch_size, epochs, callbacks, verbose=verbose, val_data=val, val_batch_size=batch_size)

    model.compute_baseline_hazards()

    surv = model.predict_surv_df(x_test)
    ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
    c_index_at.append(ev.concordance_td())

    surv_30 = model.predict_surv_df(x_test_30)
    ev_30 = EvalSurv(surv_30, durations_test_30, events_test_30, censor_surv='km')
    c_index_30.append(ev_30.concordance_td())

    # Risk scores do not depend on the horizon, so predict once.
    risk_test = model.predict(x_test).flatten()
    for horizon, aucs in time_aucs.items():
      va_auc, _ = cumulative_dynamic_auc(train_y, test_y, risk_test, horizon)
      aucs.append(va_auc[0])

    print("C-index_30:", c_index_30[i])
    print("C-index_AT:", c_index_at[i])

    print("time_auc_30", time_aucs[30][i])
    print("time_auc_60", time_aucs[60][i])
    print("time_auc_365", time_aucs[365][i])

  return c_index_at, c_index_30, time_aucs[30], time_aucs[60], time_aucs[365]
Пример #5
0
def train_LSTMCox(data_df, r_splits):
  """Train and evaluate an ``LSTMCox`` network (via pycox CoxPH) on every split.

  Each data frame produced by ``prepare_datasets`` must hold the model input
  in its 'x0' column, as per-row sequences convertible to a float32 array.
  The network is built with an input size of 768 (presumably an embedding
  width; confirm against ``LSTMCox``).

  Args:
    data_df: Full dataset with 'x0', 'duration' and 'event' columns.
    r_splits: Sequence of split specs. Element ``i`` is forwarded to
      ``prepare_datasets`` as ``r_splits[i][2], r_splits[i][1], r_splits[i][0]``.

  Returns:
    ``(c_index_at, c_index_30, time_auc_30, time_auc_60, time_auc_365)``,
    each a list with one value per split.
  """
  epochs = 100
  verbose = True
  in_features = 768

  c_index_at = []
  c_index_30 = []
  # Horizon (days) -> per-split cumulative/dynamic AUC at that horizon.
  time_aucs = {30: [], 60: [], 365: []}

  def get_target(df):
    return df['duration'].values, df['event'].values

  def to_float32(target):
    return tuple(arr.astype('float32') for arr in target)

  for i, split in enumerate(r_splits):
    print("\nIteration %s" % (i))

    # DATA PREP
    df_train, df_val, df_test, df_test_30 = prepare_datasets(data_df, split[2], split[1], split[0])

    x_train = np.array(df_train["x0"].tolist()).astype("float32")
    x_val = np.array(df_val["x0"].tolist()).astype("float32")
    x_test = np.array(df_test["x0"].tolist()).astype("float32")
    x_test_30 = np.array(df_test_30["x0"].tolist()).astype("float32")

    # Raw durations are used on purpose. CoxTime.label_transform() standardises
    # durations, and CoxPH never maps them back, so the baseline hazards (and
    # predict_surv_df's time index) ended up on a different scale than the raw
    # durations handed to EvalSurv. The Cox partial likelihood depends only on
    # the ordering of durations, so training is unaffected by this change.
    y_train = to_float32(get_target(df_train))
    y_val = to_float32(get_target(df_val))

    durations_test, events_test = get_target(df_test)
    durations_test_30, events_test_30 = get_target(df_test_30)
    val = tt.tuplefy(x_val, y_val)

    (_, train_y), _, (_, test_y), _ = df2array(data_df, df_train, df_val, df_test, df_test_30)

    # MODEL
    callbacks = [tt.callbacks.EarlyStopping()]
    net = LSTMCox(in_features, 32, 1, 1)
    model = CoxPH(net, tt.optim.Adam)
    model.optimizer.set_lr(0.0001)

    n_train = x_train.shape[0]
    batch_size = 255 if n_train % 2 else 256
    # Avoid a trailing batch of a single sample. It breaks BatchNorm layers
    # and gives a degenerate Cox partial likelihood. The odd/even rule alone
    # does not prevent it (e.g. 511 % 255 == 1).
    while n_train % batch_size == 1:
      batch_size -= 1

    model.fit(x_train, y_train, batch_size, epochs, callbacks, verbose=verbose, val_data=val, val_batch_size=batch_size)

    model.compute_baseline_hazards()

    surv = model.predict_surv_df(x_test)
    ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
    c_index_at.append(ev.concordance_td())

    surv_30 = model.predict_surv_df(x_test_30)
    ev_30 = EvalSurv(surv_30, durations_test_30, events_test_30, censor_surv='km')
    c_index_30.append(ev_30.concordance_td())

    # Risk scores do not depend on the horizon, so predict once.
    risk_test = model.predict(x_test).flatten()
    for horizon, aucs in time_aucs.items():
      va_auc, _ = cumulative_dynamic_auc(train_y, test_y, risk_test, horizon)
      aucs.append(va_auc[0])

    print("C-index_30:", c_index_30[i])
    print("C-index_AT:", c_index_at[i])

    print("time_auc_30", time_aucs[30][i])
    print("time_auc_60", time_aucs[60][i])
    print("time_auc_365", time_aucs[365][i])

  return c_index_at, c_index_30, time_aucs[30], time_aucs[60], time_aucs[365]
def _train_dcph(x, t, e, folds):
  """Fit one deep Cox proportional-hazards model (DeepSurv, Faraggi-Simon) per CV fold.

  For every distinct fold id ``f``, the model is trained on all samples whose
  fold is not ``f``. A random 15% of those samples, drawn with NumPy's global
  RNG, is held out as validation data for early stopping.

  Args:
    x: numpy array of input features (training data).
    t: numpy vector of event times (training data).
    e: numpy vector of event indicators (1 if the event occurred, 0 otherwise)
      (training data).
    folds: numpy vector giving the CV fold of each training sample.

  Returns:
    dict mapping each fold id to the pycox ``CoxPH`` model trained without
    that fold, with baseline hazards already computed.
  """
  n_features = x.shape[1]
  hidden_layers = [100, 100]
  batch_size = 256
  models_by_fold = {}

  for fold_id in set(folds):
    keep = folds != fold_id
    x_fold, t_fold, e_fold = x[keep], t[keep], e[keep]
    n_fold = len(x_fold)

    # Boolean mask of the randomly held-out validation rows.
    held_out = np.random.choice(n_fold, size=(int(0.15 * n_fold)), replace=False)
    is_val = np.zeros(n_fold, dtype=bool)
    is_val[held_out] = True
    is_train = ~is_val

    net = ttup.practical.MLPVanilla(
        n_features,
        hidden_layers,
        1,  # out_features
        False,  # batch_norm
        0.0,  # dropout
        output_bias=False).double()
    model = CoxPH(net, torch.optim.Adam)
    model.optimizer.set_lr(0.001)

    validation = x_fold[is_val], (t_fold[is_val], e_fold[is_val])
    model.fit(
        x_fold[is_train],
        (t_fold[is_train], e_fold[is_train]),
        batch_size,
        20,  # epochs
        [ttup.callbacks.EarlyStopping()],
        True,  # verbose
        val_data=validation,
        val_batch_size=batch_size)
    model.compute_baseline_hazards()

    models_by_fold[fold_id] = model

  return models_by_fold
Пример #7
0
                    model_filename = \
                        os.path.join(output_dir, 'models',
                                     '%s_%s_exp%d_bs%d_nep%d_nla%d_nno%d_'
                                     % (survival_estimator_name, dataset,
                                        experiment_idx, batch_size, n_epochs,
                                        n_layers, n_nodes)
                                     +
                                     'lr%f_cv%d.pt'
                                     % (lr, cross_val_idx))
                    time_elapsed_filename = model_filename[:-3] + '_time.txt'
                    if not os.path.isfile(model_filename):
                        # print('*** Fitting with hyperparam:', hyperparam,
                        #       '-- cross val index:', cross_val_idx, flush=True)
                        surv_model.fit(
                            fold_X_train_std,
                            (fold_y_train[:, 0], fold_y_train[:, 1]),
                            batch_size,
                            n_epochs,
                            verbose=False)
                        surv_model.compute_baseline_hazards()
                        elapsed = time.time() - tic
                        print('Time elapsed: %f second(s)' % elapsed)
                        np.savetxt(time_elapsed_filename,
                                   np.array(elapsed).reshape(1, -1))
                        surv_model.save_net(model_filename)
                    else:
                        # print('*** Loading ***', flush=True)
                        surv_model.load_net(model_filename)
                        elapsed = float(np.loadtxt(time_elapsed_filename))
                        print('Time elapsed (from previous fitting): ' +
                              '%f second(s)' % elapsed)
Пример #8
0
class DeepSurv_pycox():
    """DeepSurv (an MLP trained with the Cox partial likelihood) via pycox.

    Wraps ``CoxPH`` behind a numpy-oriented interface:
    - ``fit(X, y, column_names)`` takes durations in ``y[:, 0]`` and event
      indicators in ``y[:, 1]``;
    - ``predict(test_x, time_list)`` returns median-survival estimates together
      with the predicted survival curves.
    """

    def __init__(self,
                 layers,
                 nodes_per_layer,
                 dropout,
                 weight_decay,
                 batch_size,
                 lr=0.01,
                 seed=47):
        """Store hyper-parameters and seed NumPy's and PyTorch's global RNGs.

        Args:
            layers: Number of hidden layers (cast to int).
            nodes_per_layer: Width of every hidden layer (cast to int).
            dropout: Dropout probability for ``MLPVanilla``.
            weight_decay: Adam weight decay.
            batch_size: Initial training batch size (cast to int).
            lr: Adam learning rate.
            seed: Seed for ``np.random`` and ``torch``.
        """
        # set seed
        np.random.seed(seed)
        _ = torch.manual_seed(seed)
        self.standardalizer = None
        self.standardize_data = True

        self._duration_col = "duration"
        self._event_col = "event"

        self.in_features = None
        self.out_features = 1
        self.batch_norm = True
        self.output_bias = False
        self.activation = torch.nn.ReLU
        self.epochs = 512
        self.num_workers = 2
        # Template callbacks. fit() trains with a deep copy, so early-stopping
        # state (best loss, patience counter) never leaks between fit() calls.
        self.callbacks = [tt.callbacks.EarlyStopping()]

        # parameters tuned
        self.num_nodes = [int(nodes_per_layer) for _ in range(int(layers))]
        self.dropout = dropout
        self.weight_decay = weight_decay
        self.lr = lr
        self.batch_size = int(batch_size)

    def set_standardize(self, standardize_bool):
        """Enable or disable feature standardisation (disabling is not implemented)."""
        self.standardize_data = standardize_bool

    def _format_to_pycox(self, X, Y, F):
        """Build a DataFrame from features ``X`` (columns ``F``) and optional labels ``Y``.

        When ``Y`` is given, ``Y[:, 0]`` becomes the duration column and
        ``Y[:, 1]`` the event column.
        """
        df = pd.DataFrame(data=X, columns=F)
        if Y is not None:
            df[self._duration_col] = Y[:, 0]
            df[self._event_col] = Y[:, 1]
        return df

    def _standardize_df(self, df, flag):
        """Return ``(x, y)`` float32 arrays for ``flag`` in {'train', 'val', 'test'}.

        On 'train', columns whose unique values are exactly {0, 1} are passed
        through unchanged; every other column gets a ``StandardScaler``. The
        fitted mapper is reused for 'val' and 'test'. For 'test', ``df`` has no
        label columns and ``y`` is None.

        Raises:
            NotImplementedError: for an unknown ``flag`` or when
                standardisation is disabled.
        """
        if self.standardize_data:
            df_x = df if flag == 'test' else df.drop(
                columns=[self._duration_col, self._event_col])
            if flag == "train":
                cols_leave = []
                cols_standardize = []
                for column in df_x.columns:
                    if set(pd.unique(df[column])) == set([0, 1]):
                        cols_leave.append(column)
                    else:
                        cols_standardize.append(column)
                standardize = [([col], StandardScaler())
                               for col in cols_standardize]
                leave = [(col, None) for col in cols_leave]
                self.standardalizer = DataFrameMapper(standardize + leave)

                x = self.standardalizer.fit_transform(df_x).astype('float32')
                y = (df[self._duration_col].values.astype('float32'),
                     df[self._event_col].values.astype('float32'))

            elif flag == "val":
                x = self.standardalizer.transform(df_x).astype('float32')
                y = (df[self._duration_col].values.astype('float32'),
                     df[self._event_col].values.astype('float32'))

            elif flag == "test":
                x = self.standardalizer.transform(df_x).astype('float32')
                y = None

            else:
                raise NotImplementedError

            return x, y
        else:
            raise NotImplementedError

    def fit(self, X, y, column_names):
        """Train the network, early-stopping on a random 20% hold-out of the rows.

        Args:
            X: Feature matrix.
            y: Array with durations in ``y[:, 0]`` and events in ``y[:, 1]``.
            column_names: Feature names matching the columns of ``X``.
        """
        import copy

        # format data
        self.column_names = column_names
        full_df = self._format_to_pycox(X, y, self.column_names)
        val_df = full_df.sample(frac=0.2)
        train_df = full_df.drop(val_df.index)
        train_x, train_y = self._standardize_df(train_df, "train")
        val_x, val_y = self._standardize_df(val_df, "val")
        # configure model
        self.in_features = train_x.shape[1]
        net = tt.practical.MLPVanilla(in_features=self.in_features,
                                      num_nodes=self.num_nodes,
                                      out_features=self.out_features,
                                      batch_norm=self.batch_norm,
                                      dropout=self.dropout,
                                      activation=self.activation,
                                      output_bias=self.output_bias)
        self.model = CoxPH(
            net, tt.optim.Adam(lr=self.lr, weight_decay=self.weight_decay))

        # A final batch of one sample breaks BatchNorm, so grow the batch size
        # until that cannot happen. Work on a local copy so repeated fits start
        # from the configured size. Guard n_train == 1, which would otherwise
        # loop forever.
        n_train = train_x.shape[0]
        batch_size = self.batch_size
        while n_train > 1 and n_train % batch_size == 1:
            batch_size += 1

        self.model.fit(train_x,
                       train_y,
                       batch_size,
                       self.epochs,
                       copy.deepcopy(self.callbacks),
                       verbose=True,
                       val_data=(val_x, val_y),
                       val_batch_size=batch_size,
                       num_workers=self.num_workers)
        self.model.compute_baseline_hazards()

    @staticmethod
    def _median_survival_times(proba_matrix, time_list):
        """Estimate each sample's median survival time from its survival curve.

        Args:
            proba_matrix: 2-D array with one row per sample. Column ``j`` holds
                the survival probability at ``time_list[j]``; rows are assumed
                non-increasing.
            time_list: Times matching the columns of ``proba_matrix``.

        Returns:
            numpy array with one estimate per row, chosen as follows:
            - the time at which the curve first equals 0.5;
            - ``time_list[0]`` if the curve is already at or below 0.5 at the
              first column;
            - the midpoint of the two times bracketing the first drop below
              0.5;
            - ``max(time_list)`` if the curve never reaches 0.5.
        """
        fallback = max(time_list)
        medians = []
        for survival_proba in proba_matrix:
            # Reset per sample: a curve that never reaches 0.5 must get the
            # fallback, not the previous sample's median.
            median_time = fallback
            for col, proba in enumerate(survival_proba):
                if proba > 0.5:
                    continue
                if proba == 0.5 or col == 0:
                    median_time = time_list[col]
                else:
                    median_time = (time_list[col - 1] + time_list[col]) / 2
                break
            medians.append(median_time)
        return np.array(medians)

    def predict(self, test_x, time_list):
        """Predict median survival times and survival curves for ``test_x``.

        NOTE(review): ``time_list`` is indexed by the columns of the transposed
        ``predict_surv_df`` output (the model's baseline-hazard time index), so
        it must line up with that index. Confirm at the call sites.

        Returns:
            ``(medians, surv_df)``: a numpy array of median estimates and the
            ``predict_surv_df`` DataFrame (rows = times, columns = samples).
        """
        test_df = self._format_to_pycox(test_x, None, self.column_names)
        test_x, _ = self._standardize_df(test_df, "test")

        proba_matrix_ = self.model.predict_surv_df(test_x)
        proba_matrix = np.transpose(proba_matrix_.values)
        return self._median_survival_times(proba_matrix, time_list), proba_matrix_
Пример #9
0
 model_filename = \
     os.path.join(output_dir, 'models',
                  '%s_%s_exp%d_%s_bs%d_nep%d_nla%d_nno%d_'
                  % (survival_estimator_name, dataset,
                     experiment_idx, val_string, batch_size,
                     n_epochs, n_layers, n_nodes)
                  +
                  'lr%f_fold%d.pt'
                  % (lr, fold_idx))
 time_elapsed_filename = model_filename[:-3] + '_time.txt'
 if not os.path.isfile(model_filename):
     if use_early_stopping and not use_cross_val:
         surv_model.fit(
             fold_X_train_std,
             (fold_y_train[:, 0], fold_y_train[:, 1]),
             batch_size,
             n_epochs, [tt.callbacks.EarlyStopping()],
             val_data=(fold_X_val_std, (fold_y_val[:, 0],
                                        fold_y_val[:, 1])),
             verbose=False)
     else:
         surv_model.fit(
             fold_X_train_std,
             (fold_y_train[:, 0], fold_y_train[:, 1]),
             batch_size,
             n_epochs,
             verbose=False)
     surv_model.compute_baseline_hazards()
     elapsed = time.time() - tic
     print('Time elapsed: %f second(s)' % elapsed)
     np.savetxt(time_elapsed_filename,
                np.array(elapsed).reshape(1, -1))
def _load_pickle(path):
    """Return the object unpickled from the local file *path*."""
    # NOTE(review): pickle.load can execute arbitrary code; only use it on
    # trusted, locally produced files.
    with open(path, 'rb') as fh:
        return pickle.load(fh)


def _fold_split(n, fold):
    """Split positions 0..n-1 into (train, valid, test) index arrays for 5-fold CV.

    The positions are shuffled right after reseeding NumPy's global RNG with 0,
    so the permutation depends only on ``n``. With ``block = n // 5``:
    - for folds 0-3, the test part is the ``fold``-th block of the shuffled
      positions;
    - otherwise, the test part is everything from ``fold * block`` onward.
    The first ``len(train) // 10`` of the remaining training positions become
    the validation part.
    """
    index = np.arange(n)
    block = n // 5
    np.random.seed(seed=0)
    np.random.shuffle(index)
    start = fold * block
    if fold < 4:
        test = index[start:start + block]
        # np.concatenate, not `+`: `+` on NumPy arrays adds element-wise.
        train = np.concatenate((index[:start], index[start + block:]))
    else:
        test = index[start:]
        train = index[:start]
    valid_n = len(train) // 10
    return train[valid_n:], train[:valid_n], test


def main(data_root, cancer_type, anatomical_location, fold):
    """Train and evaluate a graph-based CoxPH survival model on one CV fold.

    Steps:
    - Loads the PPI artefacts ``seen.pkl`` and ``ei.pkl`` and the
      protein/gene lookup tables.
    - Builds per-patient multi-omics feature vectors from the manifest
      ``data_root + 'BRCA_v2.txt'``.
    - Splits censored and uncensored patients separately into
      train/valid/test parts.
    - Fits ``CoxPH(MyNet(edge_index))`` with early stopping.
    - Plots the training log and prints the test time-dependent concordance.

    Args:
        data_root: Path prefix (with trailing separator) of the manifest files.
            Falls back to the historical hard-coded location when empty or
            None; the original code ignored this argument entirely.
        cancer_type: Index into the 33-way cancer-type one-hot and key into
            ``CANCER_SUBTYPES`` / ``CELL_TYPES``.
        anatomical_location: Index into the 52-way location one-hot.
        fold: Cross-validation fold (0-4) selecting the test fifth.
    """
    # Import the RDF graph for PPI network
    seen = _load_pickle('seen.pkl')
    ei = _load_pickle('ei.pkl')

    global device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # device = torch.device('cpu')

    # NOTE(review): the one-hot context vectors and tensors below are not
    # consumed anywhere in this function (their concatenation onto the
    # features is disabled).
    cancer_type_vector = np.zeros((33, ), dtype=np.float32)
    cancer_type_vector[cancer_type] = 1

    cancer_subtype_vector = np.zeros((25, ), dtype=np.float32)
    for subtype in CANCER_SUBTYPES[cancer_type]:
        cancer_subtype_vector[subtype] = 1

    anatomical_location_vector = np.zeros((52, ), dtype=np.float32)
    anatomical_location_vector[anatomical_location] = 1
    cell_type_vector = np.zeros((10, ), dtype=np.float32)
    cell_type_vector[CELL_TYPES[cancer_type]] = 1

    pt_tensor_cancer_type = torch.FloatTensor(cancer_type_vector).to(device)
    pt_tensor_cancer_subtype = torch.FloatTensor(cancer_subtype_vector).to(device)
    pt_tensor_anatomical_location = torch.FloatTensor(
        anatomical_location_vector).to(device)
    pt_tensor_cell_type = torch.FloatTensor(cell_type_vector).to(device)
    edge_index = torch.LongTensor(ei).to(device)

    # Dictionary mapping proteins to their corresponding genes (Ensembl);
    # invert it into gene -> {protein: 1}.
    dicty = _load_pickle('ens_dic.pkl')
    dic = {}
    for prot in dicty:
        dic.setdefault(dicty[prot], {})[prot] = 1

    # Build a dictionary from ENSG -- ENST (2nd column -> 1st column of the file).
    prot_names = {}
    with open('data1/prot_names1.txt') as f:
        for line in f:
            tok = line.split()
            prot_names[tok[1]] = tok[0]

    if not data_root:
        data_root = '/ibex/scratch/projects/c2014/sara/'

    clin = []      # survival days per patient (manifest column 6)
    suv_time = []  # alive/dead status per patient (manifest column 2); noted as 0 = dead, 1 = alive
    feat_blocks = []  # one (patients x features) array per cancer type
    can_types = ["BRCA_v2"]
    for can_type in can_types:
        # Manifest: a header line, then one tab-separated row per patient
        # naming that patient's six omics files plus the survival labels.
        with open(data_root + can_type + '.txt') as f:
            lines = f.read().splitlines()[1:]
        type_feats = np.zeros((len(lines), 17186 * 6), dtype=np.float32)
        for row, line in enumerate(tqdm(lines)):
            fields = line.split('\t')
            days = fields[6]
            status = fields[2]
            myth_file = 'myth/' + fields[3]
            diff_myth_file = 'diff_myth/' + fields[1]
            exp_norm_file = 'exp_count/' + fields[-1]
            diff_exp_norm_file = 'diff_exp/' + fields[0]
            cnv_file = 'cnv/' + fields[4] + '.txt'
            vcf_file = 'vcf/' + 'OutputAnnoFile_' + fields[5] + '.hg38_multianno.txt.dat'
            # Abort if any of the six per-patient input files is missing.
            for fname in (myth_file, diff_exp_norm_file, diff_myth_file,
                          exp_norm_file, cnv_file, vcf_file):
                if not os.path.exists(fname):
                    print('File ' + fname + ' does not exist!')
                    sys.exit(1)
            clin.append(days)
            suv_time.append(status)
            temp_myth = myth_data(myth_file, seen, prot_names, dic)
            vec = np.array(get_data(exp_norm_file, diff_exp_norm_file,
                                    diff_myth_file, cnv_file, vcf_file,
                                    temp_myth, seen, dic),
                           dtype=np.float32)
            type_feats[row, :] = vec.flatten()
        feat_blocks.append(type_feats)
    # Stack every cancer type. Previously the array was reallocated per type,
    # keeping only the last type's features while labels accumulated for all.
    dataset = np.concatenate(feat_blocks, axis=0)

    labels_days = np.array([float(days) for days in clin])
    labels_surv = np.array([float(status) for status in suv_time])

    # Status 1 is treated as censored here.
    # NOTE(review): labels_surv is also passed to CoxPH and EvalSurv as the
    # *event* indicator, where 1 means "event observed". If 1 really means
    # "alive" (as noted above), the events are inverted and should be
    # 1 - labels_surv. Confirm the encoding.
    censored_index = np.flatnonzero(labels_surv == 1)
    uncensored_index = np.flatnonzero(labels_surv != 1)

    model = CoxPH(MyNet(edge_index).to(device), tt.optim.Adam(0.0001))

    # Split censored and uncensored patients separately so each part keeps a
    # similar censoring mix.
    ctrain_idx, cvalid_idx, ctest_idx = _fold_split(len(censored_index), fold)
    utrain_idx, uvalid_idx, utest_idx = _fold_split(len(uncensored_index), fold)

    train_idx = np.concatenate((censored_index[ctrain_idx],
                                uncensored_index[utrain_idx]))
    np.random.seed(seed=0)
    np.random.shuffle(train_idx)
    valid_idx = np.concatenate((censored_index[cvalid_idx],
                                uncensored_index[uvalid_idx]))
    np.random.seed(seed=0)
    np.random.shuffle(valid_idx)
    test_idx = np.concatenate((censored_index[ctest_idx],
                               uncensored_index[utest_idx]))
    np.random.seed(seed=0)
    np.random.shuffle(test_idx)

    # Min-max scaling is fit on the training part only.
    min_max_scaler = MinMaxScaler(clip=True)
    train_data = min_max_scaler.fit_transform(dataset[train_idx])
    train_labels = (labels_days[train_idx], labels_surv[train_idx])

    val_data = min_max_scaler.transform(dataset[valid_idx])
    val_labels = (labels_days[valid_idx], labels_surv[valid_idx])
    test_data = min_max_scaler.transform(dataset[test_idx])
    test_labels_days = labels_days[test_idx]
    test_labels_surv = labels_surv[test_idx]
    print(val_labels)

    callbacks = [tt.callbacks.EarlyStopping()]
    batch_size = 16
    epochs = 100
    val = (val_data, val_labels)
    log = model.fit(train_data,
                    train_labels,
                    batch_size,
                    epochs,
                    callbacks,
                    True,
                    val_data=val,
                    val_batch_size=batch_size)
    log.plot()
    plt.show()
    # Compute the evaluation measurements; baseline hazards come from the
    # scaled training data.
    model.compute_baseline_hazards(train_data, train_labels)
    surv = model.predict_surv_df(test_data)
    print(surv)
    ev = EvalSurv(surv, test_labels_days, test_labels_surv)
    print(ev.concordance_td())