% (survival_estimator_name, dataset,
                    experiment_idx, val_string)
                 +
                 'bs%d_nep%d_nla%d_nno%d_'
                 % (batch_size, n_epochs, n_layers, n_nodes)
                 +
                 'lr%f_nd%d_test.pt'
                 % (lr, num_durations))
assert os.path.isfile(model_filename)
print('*** Loading ***', flush=True)
surv_model.load_net(model_filename)

if num_durations > 0:
    surv_df = surv_model.interpolate(10).predict_surv_df(X_test_std)
else:
    surv_df = surv_model.predict_surv_df(X_test_std)
surv = surv_df.to_numpy().T

print()
print('[Test data statistics]')
sorted_y_test_times = np.sort(y_test[:, 0])
print('Quartiles:')
print('- Min observed time:', np.min(y_test[:, 0]))
print('- Q1 observed time:',
      sorted_y_test_times[int(0.25 * len(sorted_y_test_times))])
print('- Median observed time:', np.median(y_test[:, 0]))
print('- Q3 observed time:',
      sorted_y_test_times[int(0.75 * len(sorted_y_test_times))])
print('- Max observed time:', np.max(y_test[:, 0]))
print('Mean observed time:', np.mean(y_test[:, 0]))
print('Fraction censored:', 1. - np.mean(y_test[:, 1]))
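# The same summary statistics could also be computed with NumPy's percentile
# helper rather than manual index arithmetic; a minimal sketch, assuming
# y_test[:, 0] holds the observed times and y_test[:, 1] the event indicators
# as above.
observed_times = y_test[:, 0]
for label, q in [('Min', 0), ('Q1', 25), ('Median', 50), ('Q3', 75), ('Max', 100)]:
    print('- %s observed time: %f' % (label, np.percentile(observed_times, q)))
print('Mean observed time: %f' % np.mean(observed_times))
print('Fraction censored: %f' % (1. - np.mean(y_test[:, 1])))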
Example #2
def main():
    parser = setup_parser()
    args = parser.parse_args()

    if args.which_gpu != 'none':
        os.environ["CUDA_VISIBLE_DEVICES"] = args.which_gpu

    # create the output directory for this model if it does not exist
    if not os.path.exists(os.path.join(args.save_path, args.model_name)):
        os.mkdir(os.path.join(args.save_path, args.model_name))

    # label transform
    labtrans = DeepHitSingle.label_transform(args.durations)

    # data reading settings
    signal_data_path = args.signal_dataset_path
    table_path = args.table_path
    time_col = 'SurvivalDays'
    event_col = 'Mortality'

    # dataset
    data_pathes, times, events = read_dataset(signal_data_path, table_path,
                                              time_col, event_col,
                                              args.sample_ratio)

    data_pathes_train, data_pathes_test, times_train, times_test, events_train, events_test = train_test_split(
        data_pathes, times, events, test_size=0.3, random_state=369)
    data_pathes_train, data_pathes_val, times_train, times_val, events_train, events_val = train_test_split(
        data_pathes_train,
        times_train,
        events_train,
        test_size=0.2,
        random_state=369)

    labels_train = label_transfer(times_train, events_train)
    target_train = labtrans.fit_transform(*labels_train)
    dataset_train = VsDatasetBatch(data_pathes_train, *target_train)
    dl_train = tt.data.DataLoaderBatch(dataset_train,
                                       args.train_batch_size,
                                       shuffle=True)

    labels_val = label_transfer(times_val, events_val)
    target_val = labtrans.transform(*labels_val)
    dataset_val = VsDatasetBatch(data_pathes_val, *target_val)
    dl_val = tt.data.DataLoaderBatch(dataset_val,
                                     args.train_batch_size,
                                     shuffle=True)

    labels_test = label_transfer(times_test, events_test)
    dataset_test_x = VsTestInput(data_pathes_test)
    dl_test_x = DataLoader(dataset_test_x, args.test_batch_size, shuffle=False)

    net = resnet18(args)
    model = DeepHitSingle(net,
                          tt.optim.Adam(lr=args.lr,
                                        betas=(0.9, 0.999),
                                        eps=1e-08,
                                        weight_decay=5e-4,
                                        amsgrad=False),
                          duration_index=labtrans.cuts)
    # callbacks = [tt.cb.EarlyStopping(patience=15)]
    callbacks = [
        tt.cb.BestWeights(file_path=os.path.join(
            args.save_path, args.model_name, args.model_name + '_bestWeight'),
                          rm_file=False)
    ]
    verbose = True
    model_log = model.fit_dataloader(dl_train,
                                     args.epochs,
                                     callbacks,
                                     verbose,
                                     val_dataloader=dl_val)

    save_args(os.path.join(args.save_path, args.model_name), args)
    model_log.to_pandas().to_csv(os.path.join(args.save_path, args.model_name,
                                              'loss.csv'),
                                 index=False)
    model.save_net(
        path=os.path.join(args.save_path, args.model_name, args.model_name +
                          '_final'))
    surv = model.predict_surv_df(dl_test_x)
    surv.to_csv(os.path.join(args.save_path, args.model_name,
                             'test_sur_df.csv'),
                index=False)
    ev = EvalSurv(surv, *labels_test, 'km')
    print(ev.concordance_td())
    save_cindex(os.path.join(args.save_path, args.model_name),
                ev.concordance_td())
    print('done')
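# setup_parser() is referenced in main() above but is not shown in this example.
# A minimal sketch of what it could look like, reconstructed from the arguments
# accessed in main(); every flag name is taken from main(), but the types and
# default values below are assumptions, not the original definition.
import argparse

def setup_parser():
    parser = argparse.ArgumentParser(
        description='Train a DeepHitSingle model on signal data')
    parser.add_argument('--which_gpu', default='none')
    parser.add_argument('--save_path', default='./results')
    parser.add_argument('--model_name', default='deephit_resnet18')
    parser.add_argument('--durations', type=int, default=10)
    parser.add_argument('--signal_dataset_path', default='./data/signals')
    parser.add_argument('--table_path', default='./data/table.csv')
    parser.add_argument('--sample_ratio', type=float, default=1.0)
    parser.add_argument('--train_batch_size', type=int, default=32)
    parser.add_argument('--test_batch_size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--epochs', type=int, default=100)
    return parser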
                        print('Time elapsed (from previous fitting): ' +
                              '%f second(s)' % elapsed)

                    fine_tune_time = elapsed

                    total_time = pretrain_time + fine_tune_time
                    print('Total time: %f second(s)' % total_time)
                    time_elapsed_filename = model_filename[:-3] + '_total_time.txt'
                    np.savetxt(time_elapsed_filename,
                               np.array(total_time).reshape(1, -1))

                    if num_durations > 0:
                        surv_df = \
                            surv_model.interpolate(10).predict_surv_df(fold_X_val_std)
                    else:
                        surv_df = surv_model.predict_surv_df(fold_X_val_std)
                    ev = EvalSurv(surv_df,
                                  fold_y_val[:, 0],
                                  fold_y_val[:, 1],
                                  censor_surv='km')

                    sorted_fold_y_val = np.sort(np.unique(fold_y_val[:, 0]))
                    time_grid = np.linspace(sorted_fold_y_val[0],
                                            sorted_fold_y_val[-1], 100)

                    surv = surv_df.to_numpy().T

                    cindex_scores.append(ev.concordance_td('antolini'))
                    integrated_brier_scores.append(
                        ev.integrated_brier_score(time_grid))
                    print('  c-index (td):', cindex_scores[-1])
lr_finder = model.lr_finder(x_train, y_train, batch_size, tolerance=3)
_ = lr_finder.plot()

lr_finder.get_best_lr()

model.optimizer.set_lr(0.01)
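# If you would rather use the rate suggested by the finder than a hand-picked
# value, a minimal variant of the two calls above (same torchtuples API) is:
best_lr = lr_finder.get_best_lr()
model.optimizer.set_lr(best_lr)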

# Training with best learning rate:
epochs = 100
callbacks = [tt.callbacks.EarlyStopping()]
log = model.fit(x_train, y_train, batch_size, epochs, callbacks, val_data=val)

_ = log.plot()

# Prediction:
surv = model.predict_surv_df(x_test)

surv.iloc[:, :5].plot(drawstyle='steps-post')
plt.ylabel('S(t | x)')
_ = plt.xlabel('Time')

# Interpolate the survival estimates: so far they are only defined at the 10
# times in the discretization grid, so each curve is a step function rather
# than a continuous one.
surv = model.interpolate(10).predict_surv_df(x_test)

surv.iloc[:, :5].plot(drawstyle='steps-post')
plt.ylabel('S(t | x)')
_ = plt.xlabel('Time')

# The EvalSurv class contains some useful evaluation criteria for time-to-event prediction.
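# A minimal sketch of how EvalSurv is typically used with the predictions above.
# The names durations_test and events_test are assumptions here: they stand for
# the observed times and event indicators of x_test. The criteria shown
# (concordance_td, integrated_brier_score, integrated_nbll) are the same ones
# used elsewhere in these examples.
ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
print('Time-dependent concordance:', ev.concordance_td('antolini'))

# Brier score and negative binomial log-likelihood integrated over a time grid
time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)
print('Integrated Brier score:', ev.integrated_brier_score(time_grid))
print('Integrated NBLL:', ev.integrated_nbll(time_grid))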