def load_model(model_file, data, clinical, surv_time, edge_index):
    """Load a saved CoxPH survival model and run inference on ``data``.

    Builds a ``MyNet`` wrapped in a pycox ``CoxPH`` model, loads the saved
    network from ``model_file``, predicts survival curves for ``data`` and
    extracts features from the loaded network with ``features``.

    Args:
        model_file: Path of the saved network, passed to ``CoxPH.load_net``.
        data: NumPy array of model inputs. It is passed to
            ``predict_surv_df`` and converted with ``torch.from_numpy`` for
            feature extraction.
        clinical: Unused. Kept so existing callers keep working; it was only
            referenced by evaluation code that has been removed.
        surv_time: Unused. Kept for the same reason as ``clinical``.
        edge_index: Passed to the ``MyNet`` constructor.

    Returns:
        A tuple ``(prediction, fs)``:
            - ``prediction`` is the result of ``model.predict_surv_df(data)``.
            - ``fs`` is the result of ``features(model.net, <data as a tensor>)``.
    """
    net = MyNet(edge_index).to(device)
    # The optimizer is never stepped in this function; the model is only
    # used for inference.
    model = CoxPH(net, tt.optim.Adam(0.0001))

    # map_location lets a checkpoint saved on a different device (e.g. a GPU)
    # be loaded onto the current ``device``; load_net forwards it to torch.load.
    # NOTE(review): torchtuples' load_net assigns the loaded object to
    # model.net, which is why the code below uses model.net, not ``net``.
    model.load_net(model_file, map_location=device)

    prediction = model.predict_surv_df(data)

    # torchtuples' predict path calls net.train() after predicting (with the
    # default eval_=True). Put the net back in eval mode so dropout and
    # batch-norm behave deterministically during feature extraction.
    # NOTE(review): harmless if ``features`` already sets eval mode itself.
    model.net.eval()
    fs = features(model.net, torch.from_numpy(data).to(device))
    return prediction, fs
batch_norm, dropout, output_bias=output_bias) optimizer = tt.optim.Adam(lr=lr) surv_model = CoxPH(net, optimizer) model_filename = \ os.path.join(output_dir, 'models', '%s_%s_exp%d_%s_bs%d_nep%d_nla%d_nno%d_lr%f_test.pt' % (survival_estimator_name, dataset, experiment_idx, val_string, batch_size, n_epochs, n_layers, n_nodes, lr)) assert os.path.isfile(model_filename) print('*** Loading ***', flush=True) surv_model.load_net(model_filename) surv_df = surv_model.predict_surv_df(X_test_std) surv = surv_df.to_numpy().T print() print('[Test data statistics]') sorted_y_test_times = np.sort(y_test[:, 0]) print('Quartiles:') print('- Min observed time:', np.min(y_test[:, 0])) print('- Q1 observed time:', sorted_y_test_times[int(0.25 * len(sorted_y_test_times))]) print('- Median observed time:', np.median(y_test[:, 0])) print('- Q3 observed time:', sorted_y_test_times[int(0.75 * len(sorted_y_test_times))]) print('- Max observed time:', np.max(y_test[:, 0]))