Example #1
                                        init_num_durations))
                    time_elapsed_filename = model_filename[:-3] + '_time.txt'
                    print('*** Pre-training...')
                    # assert os.path.isfile(model_filename)
                    if not os.path.isfile(model_filename):
                        tic = time.time()  # start timing the pre-training fit
                        # print('*** Fitting with hyperparam:', hyperparam, flush=True)
                        surv_model.fit(X_train_std,
                                       y_train_discrete,
                                       init_batch_size,
                                       init_n_epochs,
                                       verbose=False)
                        elapsed = time.time() - tic
                        # print('Time elapsed: %f second(s)' % elapsed)
                        np.savetxt(time_elapsed_filename,
                                   np.array(elapsed).reshape(1, -1))
                        surv_model.save_net(model_filename)
                    else:
                        # print('*** Loading ***', flush=True)
                        surv_model.load_net(model_filename)
                        elapsed = float(np.loadtxt(time_elapsed_filename))
                        # print('Time elapsed (from previous fitting): %f second(s)'
                        #       % elapsed)
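                    # make sure the network is in training mode before
                    # fine-tuning (a freshly loaded net may be in eval mode)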
                    surv_model.net.train()

                    pretrain_time = elapsed

                    print('*** Fine-tuning with DKSA...')
                    tic = time.time()
                    torch.manual_seed(fine_tune_random_seed)
                    np.random.seed(fine_tune_random_seed)
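                    # append a softmax head so the net outputs a probability
                    # distribution over the discrete time intervals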
                    net = nn.Sequential(*surv_model.net.net, nn.Softmax(dim=1))
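
The wrapped net now maps a batch of feature vectors to a probability
distribution over the discrete time intervals. A minimal sanity-check
sketch, reusing net, X_train_std, and the already-imported torch from the
snippet above (the batch slice of 4 is arbitrary):

with torch.no_grad():
    x = torch.tensor(X_train_std[:4], dtype=torch.float32)
    probs = net(x)           # shape: (4, init_num_durations)
    print(probs.sum(dim=1))  # each row sums to 1 thanks to the softmax head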
Example #2

import os

import torchtuples as tt
from pycox.evaluation import EvalSurv
from pycox.models import DeepHitSingle
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

# Imports above are inferred from the names used below. Project-specific
# helpers (setup_parser, read_dataset, label_transfer, VsDatasetBatch,
# VsTestInput, resnet18, save_args, save_cindex) are assumed to come from
# the surrounding project and are not shown.

def main():
    parser = setup_parser()
    args = parser.parse_args()

    if args.which_gpu != 'none':
        os.environ["CUDA_VISIBLE_DEVICES"] = args.which_gpu

    # create the output directory for this model (creates parents too)
    os.makedirs(os.path.join(args.save_path, args.model_name), exist_ok=True)

    # label transform: discretize continuous survival times into
    # args.durations equidistant intervals, as DeepHit requires
    labtrans = DeepHitSingle.label_transform(args.durations)

    # data reading settings
    signal_data_path = args.signal_dataset_path
    table_path = args.table_path
    time_col = 'SurvivalDays'
    event_col = 'Mortality'

    # dataset: per-sample file paths plus survival times and event indicators
    data_pathes, times, events = read_dataset(signal_data_path, table_path,
                                              time_col, event_col,
                                              args.sample_ratio)

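    # hold out 30% of the samples for testing, then carve 20% of the
    # remaining training data out for validation; random_state is fixed
    # so the splits are reproducible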
    data_pathes_train, data_pathes_test, times_train, times_test, events_train, events_test = train_test_split(
        data_pathes, times, events, test_size=0.3, random_state=369)
    data_pathes_train, data_pathes_val, times_train, times_val, events_train, events_val = train_test_split(
        data_pathes_train,
        times_train,
        events_train,
        test_size=0.2,
        random_state=369)

    labels_train = label_transfer(times_train, events_train)
    target_train = labtrans.fit_transform(*labels_train)
    dataset_train = VsDatasetBatch(data_pathes_train, *target_train)
    dl_train = tt.data.DataLoaderBatch(dataset_train,
                                       args.train_batch_size,
                                       shuffle=True)

    labels_val = label_transfer(times_val, events_val)
    target_val = labtrans.transform(*labels_val)
    dataset_val = VsDatasetBatch(data_pathes_val, *target_val)
    dl_val = tt.data.DataLoaderBatch(dataset_val,
                                     args.train_batch_size,
                                     shuffle=True)

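    # the test inputs go through a plain torch DataLoader without labels
    # and without shuffling, so predictions stay aligned with labels_test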
    labels_test = label_transfer(times_test, events_test)
    dataset_test_x = VsTestInput(data_pathes_test)
    dl_test_x = DataLoader(dataset_test_x, args.test_batch_size, shuffle=False)

    net = resnet18(args)
    model = DeepHitSingle(net,
                          tt.optim.Adam(lr=args.lr,
                                        betas=(0.9, 0.999),
                                        eps=1e-08,
                                        weight_decay=5e-4,
                                        amsgrad=False),
                          duration_index=labtrans.cuts)
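    # duration_index=labtrans.cuts maps the discrete output grid back to
    # real time points when predict_surv_df builds its index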
    # callbacks = [tt.cb.EarlyStopping(patience=15)]
    callbacks = [
        tt.cb.BestWeights(file_path=os.path.join(
            args.save_path, args.model_name, args.model_name + '_bestWeight'),
                          rm_file=False)
    ]
    verbose = True
    model_log = model.fit_dataloader(dl_train,
                                     args.epochs,
                                     callbacks,
                                     verbose,
                                     val_dataloader=dl_val)

    save_args(os.path.join(args.save_path, args.model_name), args)
    model_log.to_pandas().to_csv(os.path.join(args.save_path, args.model_name,
                                              'loss.csv'),
                                 index=False)
    model.save_net(
        path=os.path.join(args.save_path, args.model_name, args.model_name +
                          '_final'))
    surv = model.predict_surv_df(dl_test_x)
    surv.to_csv(os.path.join(args.save_path, args.model_name,
                             'test_sur_df.csv'),
                index=False)
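    # censor_surv='km' estimates the censoring distribution with Kaplan-Meier;
    # concordance_td() is the time-dependent concordance of Antolini et al.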
    ev = EvalSurv(surv, *labels_test, 'km')
    print(ev.concordance_td())
    save_cindex(os.path.join(args.save_path, args.model_name),
                ev.concordance_td())
    print('done')
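
Both fitting and evaluation rely on a label_transfer helper whose
definition is not shown here. A minimal sketch of a plausible
implementation, assuming it merely packages times and events as the
(durations, events) NumPy pair that labtrans and EvalSurv expect:

import numpy as np

def label_transfer(times, events):
    # hypothetical reconstruction: cast to the dtypes pycox works with
    return (np.asarray(times, dtype='float32'),
            np.asarray(events, dtype='int32'))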