def train(rundir, diagnosis, epochs, learning_rate, use_gpu):
    """Train an MRNet, checkpointing whenever the validation loss improves.

    Every epoch does one training pass and one validation pass through
    ``run_model``, prints loss/AUC for both, steps a ``ReduceLROnPlateau``
    scheduler on the validation loss, and, when that loss is a new minimum,
    saves the model's ``state_dict`` into ``rundir``. The checkpoint file name
    encodes the validation loss, the training loss and the 1-based epoch number
    (``val<loss>_train<loss>_epoch<N>``).

    Args:
        rundir: Directory the checkpoints are written into.
        diagnosis: Forwarded to ``load_data``.
        epochs: Number of epochs to run.
        learning_rate: Initial learning rate for Adam.
        use_gpu: Move the model to CUDA; also forwarded to ``load_data``.
    """
    def report(tag, loss, auc):
        # Shared printer for the train/valid metric lines.
        print(f'{tag} loss: {loss:0.4f}')
        print(f'{tag} AUC: {auc:0.4f}')

    train_loader, valid_loader, _test_loader = load_data(diagnosis, use_gpu)

    net = MRNet()
    if use_gpu:
        net = net.cuda()

    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=.01)
    # Multiply the LR by 0.3 once validation loss has not improved (by more
    # than the 1e-4 relative threshold) for 5 consecutive epochs.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=5, factor=.3, threshold=1e-4)

    best_val_loss = float('inf')
    started = datetime.now()
    for epoch_num in range(1, epochs + 1):
        elapsed = datetime.now() - started
        print('starting epoch {}. time passed: {}'.format(epoch_num, str(elapsed)))

        train_loss, train_auc, _, _ = run_model(net, train_loader, train=True, optimizer=optimizer)
        report('train', train_loss, train_auc)

        val_loss, val_auc, _, _ = run_model(net, valid_loader)
        report('valid', val_loss, val_auc)

        scheduler.step(val_loss)

        # Written as `not <` (rather than `>=`) so a NaN loss is never saved.
        if not val_loss < best_val_loss:
            continue
        best_val_loss = val_loss
        checkpoint = Path(rundir) / f'val{val_loss:0.4f}_train{train_loss:0.4f}_epoch{epoch_num}'
        torch.save(net.state_dict(), checkpoint)
def evaluate(split, model_path, diagnosis, use_gpu):
    """Evaluate an MRNet state-dict checkpoint on one data split.

    NOTE(review): a later ``def evaluate(...)`` in this module rebinds the
    name, so this definition is unreachable by name once the module loads.

    Args:
        split: One of ``'train'``, ``'valid'`` or ``'test'``.
        model_path: Path to a saved ``state_dict`` for ``MRNet``.
        diagnosis: Forwarded to ``load_data``.
        use_gpu: Load onto/move to CUDA; otherwise tensors are mapped to CPU.

    Returns:
        ``(preds, labels)`` as produced by ``run_model`` for the chosen split.

    Raises:
        ValueError: if ``split`` is not one of the three names above.
    """
    train_loader, valid_loader, test_loader = load_data(diagnosis, use_gpu)

    net = MRNet()
    weights = torch.load(model_path, map_location=(None if use_gpu else 'cpu'))
    net.load_state_dict(weights)
    if use_gpu:
        net = net.cuda()

    named_loaders = (('train', train_loader), ('valid', valid_loader), ('test', test_loader))
    for name, candidate in named_loaders:
        if split == name:
            loader = candidate
            break
    else:
        raise ValueError("split must be 'train', 'valid', or 'test'")

    loss, auc, preds, labels = run_model(net, loader)
    print(f'{split} loss: {loss:0.4f}')
    print(f'{split} AUC: {auc:0.4f}')
    return preds, labels
def _latest_checkpoint(directory):
    """Return the name of the entry in *directory* with the highest ``epoch<N>`` suffix.

    Names are parsed as ``...epoch<N>``; entries without ``"epoch"`` in the
    name are ignored. If the directory was filled by ``train``, which only
    saves on a new best validation loss, the highest epoch is also the
    lowest-validation-loss checkpoint.

    Raises:
        ValueError: if *directory* contains no ``epoch`` entries.
    """
    candidates = [name for name in os.listdir(directory) if 'epoch' in name]
    if not candidates:
        raise ValueError(f'no epoch checkpoints found in {directory!r}')
    # max() keeps the first maximal entry, matching the original index(max()) pick.
    return max(candidates, key=lambda name: int(name.split('epoch')[1]))


def _load_view_model(checkpoint_path, attention, max_layers, map_location, use_gpu):
    """Build an MRNet for one view and load the state dict at *checkpoint_path*."""
    state_dict = torch.load(checkpoint_path, map_location=map_location)
    model = MRNet(useMultiHead=attention, max_layers=max_layers)
    model.load_state_dict(state_dict)
    return model.cuda() if use_gpu else model


def _select_loader(split, train_loader, valid_loader, test_loader):
    """Return the loader named by *split*; raise ValueError for any other value."""
    if split == 'train':
        return train_loader
    if split == 'valid':
        return valid_loader
    if split == 'test':
        return test_loader
    raise ValueError("split must be 'train', 'valid', or 'test'")


def _evaluate_view_ensemble(model_path, diagnosis, use_gpu, attention, map_location):
    """Stack per-view MRNet predictions with logistic regression; return valid probs and labels."""
    # train_shuffle=False is required: the three views' train predictions are
    # stacked column-wise and paired with a single label vector, so every view
    # loader must yield exams in the same order. Independently shuffled
    # loaders misalign the rows across views.
    # NOTE(review): assumes mr_load_data builds each view's dataset over the
    # same exam list in the same order — confirm.
    train_loaders, valid_loaders = mr_load_data(diagnosis, use_gpu, train_shuffle=False)

    # (view sub-directory, MRNet max_layers); values carried over unchanged.
    views = (('sagittal', 51), ('axial', 61), ('coronal', 58))
    view_dirs = [os.path.join(model_path, view) for view, _ in views]
    checkpoints = [_latest_checkpoint(view_dir) for view_dir in view_dirs]
    print("{} {} {}".format(*checkpoints))

    train_columns, valid_columns = [], []
    train_labels = valid_labels = None
    for (view, max_layers), view_dir, checkpoint, train_loader, valid_loader in zip(
            views, view_dirs, checkpoints, train_loaders, valid_loaders):
        model = _load_view_model(os.path.join(view_dir, checkpoint), attention,
                                 max_layers, map_location, use_gpu)
        loss, auc, view_train_preds, train_labels = run_model(model, train_loader)
        _, _, view_valid_preds, valid_labels = run_model(model, valid_loader)
        # These metrics come from the train loader, whatever split was requested.
        print(f'{view} train loss: {loss:0.4f}')
        print(f'{view} train AUC: {auc:0.4f}')
        train_columns.append(view_train_preds)
        valid_columns.append(view_valid_preds)

    # One row per exam, one column per view (sagittal, axial, coronal).
    X_train = np.asarray(train_columns, dtype=float).T
    X_valid = np.asarray(valid_columns, dtype=float).T

    ensemble = LogisticRegression(solver='lbfgs')
    ensemble.fit(X_train, np.array(train_labels))

    y_true = np.array(valid_labels)
    # ROC AUC needs continuous scores, not hard 0/1 predictions. Column 1 of
    # predict_proba is the probability of ensemble.classes_[1] (the positive
    # class for 0/1 labels).
    valid_probs = ensemble.predict_proba(X_valid)[:, 1]
    print(metrics.roc_auc_score(y_true, valid_probs))
    print(metrics.classification_report(
        y_true, ensemble.predict(X_valid), target_names=['class 0', 'class 1']))
    return valid_probs, valid_labels


def evaluate(split, model_path, diagnosis, dataset, use_gpu, attention):
    """Evaluate saved MRNet checkpoint(s).

    ``dataset`` selects the mode:

    * ``0``: ``model_path`` is a whole pickled model (not a state dict). It is
      run on the ``split`` loader from ``external_load_data``, and loss/AUC
      are printed.
    * ``1``: ``model_path`` is a directory with ``sagittal``, ``axial`` and
      ``coronal`` sub-directories of state-dict checkpoints. The highest-epoch
      checkpoint of each is loaded into an ``MRNet(useMultiHead=attention)``
      and run on that view's train and valid loaders from ``mr_load_data``.
      A logistic regression is fit on the three views' train predictions and
      evaluated on their valid predictions; ROC AUC and a classification
      report are printed. ``split`` is not used in this mode.

    Args:
        split: ``'train'``, ``'valid'`` or ``'test'`` (dataset 0 only).
        model_path: Model file (dataset 0) or checkpoint root directory (dataset 1).
        diagnosis: Forwarded to the data-loading function.
        dataset: ``0`` or ``1``, as described above.
        use_gpu: Move models to CUDA; otherwise checkpoints are mapped to CPU.
        attention: Forwarded as ``MRNet(useMultiHead=...)`` (dataset 1 only).

    Returns:
        ``(preds, labels)``. Dataset 0: as returned by ``run_model`` for the
        chosen split. Dataset 1: the ensemble's positive-class probabilities
        on the valid set and the valid labels.

    Raises:
        ValueError: for an unknown ``split`` (dataset 0) or ``dataset`` not in {0, 1}.
    """
    map_location = None if use_gpu else 'cpu'

    if dataset == 0:
        train_loader, valid_loader, test_loader = external_load_data(diagnosis, use_gpu)
        # map_location lets a GPU-saved model load on a CPU-only machine,
        # matching how the state-dict checkpoints are loaded.
        # NOTE(review): this unpickles arbitrary objects, so only load trusted
        # files. torch >= 2.6 defaults to weights_only=True, which rejects
        # whole-model pickles, so weights_only=False may be needed there.
        model = torch.load(model_path, map_location=map_location)
        if use_gpu:
            model = model.cuda()
        loader = _select_loader(split, train_loader, valid_loader, test_loader)
        loss, auc, preds, labels = run_model(model, loader)
        print(f'{split} loss: {loss:0.4f}')
        print(f'{split} AUC: {auc:0.4f}')
        return preds, labels

    if dataset == 1:
        return _evaluate_view_ensemble(model_path, diagnosis, use_gpu, attention, map_location)

    raise ValueError(f'dataset must be 0 or 1, got {dataset!r}')