Example #1
0
import torchtuples as tt
from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper
from pycox.models import CoxPH, CoxTime
from pycox.evaluation import EvalSurv
from sksurv.metrics import cumulative_dynamic_auc

# prepare_datasets and df2array are project-specific helpers assumed to be in scope.
def train_deepsurv(data_df, r_splits):
  epochs = 100
  verbose = True

  num_nodes = [32]
  out_features = 1
  batch_norm = True
  dropout = 0.6
  output_bias = False

  c_index_at = []
  c_index_30 = []

  time_auc_30 = []
  time_auc_60 = []
  time_auc_365 = []

  for i in range(len(r_splits)):
    print("\nIteration %s"%(i))
    
    #DATA PREP
    df_train, df_val, df_test, df_test_30 = prepare_datasets(data_df, r_splits[i][2], r_splits[i][1], r_splits[i][0])
    
    xcols = list(df_train.columns)

    for col_name in ["subject_id", "event", "duration"]:
      if col_name in xcols:
        xcols.remove(col_name)

    cols_standardize = xcols

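    # Standardize every feature column; the scaler is fit on the training split and reused for val/test.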
    standardize = [([col], StandardScaler()) for col in cols_standardize]

    x_mapper = DataFrameMapper(standardize)

    x_train = x_mapper.fit_transform(df_train).astype('float32')
    x_val = x_mapper.transform(df_val).astype('float32')
    x_test = x_mapper.transform(df_test).astype('float32')
    x_test_30 =  x_mapper.transform(df_test_30).astype('float32')

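    # CoxTime's label transform standardizes durations; the targets become (duration, event) tuples.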
    labtrans = CoxTime.label_transform()
    get_target = lambda df: (df['duration'].values, df['event'].values)
    y_train = labtrans.fit_transform(*get_target(df_train))
    y_val = labtrans.transform(*get_target(df_val))

    durations_test, events_test = get_target(df_test)
    durations_test_30, events_test_30 = get_target(df_test_30)
    val = tt.tuplefy(x_val, y_val)

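    # df2array (project helper) presumably returns sksurv-style structured arrays; train_y/test_y feed cumulative_dynamic_auc below.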
    (train_x, train_y), (val_x, val_y), (test_x, test_y), _ = df2array(data_df, df_train, df_val, df_test, df_test_30)

    #MODEL
    in_features = x_train.shape[1]

    callbacks = [tt.callbacks.EarlyStopping()]

    net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm, dropout, output_bias=output_bias)

    model = CoxPH(net, tt.optim.Adam)
    model.optimizer.set_lr(0.0001)

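    # Presumably chosen so the last mini-batch is never a single sample (which would break batch norm).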
    if x_train.shape[0] % 2:
      batch_size = 255
    else:
      batch_size = 256

    log = model.fit(x_train, y_train, batch_size, epochs, callbacks, verbose, val_data=val, val_batch_size=batch_size)

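    # CoxPH needs baseline hazards computed before predict_surv_df can produce survival curves.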
    model.compute_baseline_hazards()

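    # Time-dependent concordance (EvalSurv) on the full test set and on the df_test_30 split.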
    surv = model.predict_surv_df(x_test)
    ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
    c_index_at.append(ev.concordance_td())

    surv_30 = model.predict_surv_df(x_test_30)
    ev_30 = EvalSurv(surv_30, durations_test_30, events_test_30, censor_surv='km')
    c_index_30.append(ev_30.concordance_td())

    for time_x, auc_list in [(30, time_auc_30), (60, time_auc_60), (365, time_auc_365)]:
      va_auc, va_mean_auc = cumulative_dynamic_auc(train_y, test_y, model.predict(x_test).flatten(), time_x)
      auc_list.append(va_auc[0])

    print("C-index_30:", c_index_30[i])
    print("C-index_AT:", c_index_at[i])

    print("time_auc_30", time_auc_30[i])
    print("time_auc_60", time_auc_60[i])
    print("time_auc_365", time_auc_365[i])

  return c_index_at, c_index_30, time_auc_30, time_auc_60, time_auc_365
Example #2
0
            lrfinder = model.lr_finder(x_train, y_train, batch_size, tolerance=10)
            best = lrfinder.get_best_lr()

            model.optimizer.set_lr(best)
            
            epochs = args.epochs
            callbacks = [tt.callbacks.EarlyStopping(patience=patience)]
            verbose = True
            log = model.fit(x_train, y_train_transformed, batch_size, epochs, callbacks, verbose, val_data = val_transformed, val_batch_size = batch_size)

            # Evaluation ===================================================================
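            # Best validation loss, time-dependent concordance (Ctd), and a time grid for numerical integration (IBS/IBLL).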
            val_loss = min(log.monitors['val_'].scores['loss']['score'])
            
            # get Ctd
            ctd = concordance_index(event_times = durations_test_transformed,
                                    predicted_scores = model.predict(x_test).reshape(-1),
                                    event_observed = events_test)
            
            # set time grid for numerical integration to get IBS and IBLL
            if durations_test.min()>0:
                time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)
            else:
                durations_test_copy = durations_test.copy()
                durations_test_copy.sort()
                time_grid = np.linspace(durations_test_copy[1], durations_test.max(), 100)
            # transform time grid into DSAFT scale for fair comparison
            time_grid = np.exp(scaler_train.transform(np.log(time_grid.reshape(-1, 1)))).reshape(-1)
            # grid interval for numerical integration
            ds = np.array(time_grid - np.array([0.0] + time_grid[:-1].tolist()))
Example #3
0
import numpy as np
import torchtuples as tt
from pycox.models import CoxPH, CoxTime
from pycox.evaluation import EvalSurv
from sksurv.metrics import cumulative_dynamic_auc

# prepare_datasets, df2array, and the LSTMCox network are project-specific and assumed to be in scope.
def train_LSTMCox(data_df, r_splits):
  epochs = 100
  verbose = True

  in_features = 768
  out_features = 1
  batch_norm = True
  dropout = 0.6
  output_bias = False

  c_index_at = []
  c_index_30 = []

  time_auc_30 = []
  time_auc_60 = []
  time_auc_365 = []

  for i in range(len(r_splits)):
    print("\nIteration %s"%(i))
    
    #DATA PREP
    df_train, df_val, df_test, df_test_30 = prepare_datasets(data_df, r_splits[i][2], r_splits[i][1], r_splits[i][0])

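    # "x0" is assumed to hold fixed-length embedding vectors (matching in_features = 768).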
    x_train = np.array(df_train["x0"].tolist()).astype("float32")
    x_val = np.array(df_val["x0"].tolist()).astype("float32")
    x_test = np.array(df_test["x0"].tolist()).astype("float32")
    x_test_30 = np.array(df_test_30["x0"].tolist()).astype("float32")

    labtrans = CoxTime.label_transform()
    get_target = lambda df: (df['duration'].values, df['event'].values)
    y_train = labtrans.fit_transform(*get_target(df_train))
    y_val = labtrans.transform(*get_target(df_val))

    durations_test, events_test = get_target(df_test)
    durations_test_30, events_test_30 = get_target(df_test_30)
    val = tt.tuplefy(x_val, y_val)
    
    (train_x, train_y), (val_x, val_y), (test_x, test_y), _ = df2array(data_df, df_train, df_val, df_test, df_test_30)

    #MODEL
    callbacks = [tt.callbacks.EarlyStopping()]

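    # LSTMCox is a project-defined network; arguments assumed to be (input size, hidden size, num layers, output size).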
    net = LSTMCox(in_features, 32, 1, 1)

    model = CoxPH(net, tt.optim.Adam)
    model.optimizer.set_lr(0.0001)

    if x_train.shape[0] % 2:
      batch_size = 255
    else:
      batch_size = 256
      
    log = model.fit(x_train, y_train, batch_size, epochs, callbacks, verbose, val_data=val, val_batch_size=batch_size)

    model.compute_baseline_hazards()

    surv = model.predict_surv_df(x_test)
    ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
    c_index_at.append(ev.concordance_td())

    surv_30 = model.predict_surv_df(x_test_30)
    ev_30 = EvalSurv(surv_30, durations_test_30, events_test_30, censor_surv='km')
    c_index_30.append(ev_30.concordance_td())

    for time_x, auc_list in [(30, time_auc_30), (60, time_auc_60), (365, time_auc_365)]:
      va_auc, va_mean_auc = cumulative_dynamic_auc(train_y, test_y, model.predict(x_test).flatten(), time_x)
      auc_list.append(va_auc[0])

    print("C-index_30:", c_index_30[i])
    print("C-index_AT:", c_index_at[i])

    print("time_auc_30", time_auc_30[i])
    print("time_auc_60", time_auc_60[i])
    print("time_auc_365", time_auc_365[i])

  return c_index_at, c_index_30, time_auc_30, time_auc_60, time_auc_365