def test_single_numpy(self):
    input_data = {"img": np.array([[0, 1], [1, 2]])}
    result = ConcatItemsd(keys="img", name="cat_img")(input_data)
    result["cat_img"] += 1
    np.testing.assert_allclose(result["img"], np.array([[0, 1], [1, 2]]))
    np.testing.assert_allclose(result["cat_img"], np.array([[1, 2], [2, 3]]))
def test_single_tensor(self):
    input_data = {"img": torch.tensor([[0, 1], [1, 2]])}
    result = ConcatItemsd(keys="img", name="cat_img")(input_data)
    result["cat_img"] += 1
    torch.testing.assert_allclose(result["img"], torch.tensor([[0, 1], [1, 2]]))
    torch.testing.assert_allclose(result["cat_img"], torch.tensor([[1, 2], [2, 3]]))
def test_numpy_values(self):
    input_data = {
        "img1": np.array([[0, 1], [1, 2]]),
        "img2": np.array([[0, 1], [1, 2]]),
    }
    result = ConcatItemsd(keys=["img1", "img2"], name="cat_img")(input_data)
    self.assertTrue("cat_img" in result)
    result["cat_img"] += 1
    np.testing.assert_allclose(result["img1"], np.array([[0, 1], [1, 2]]))
    np.testing.assert_allclose(result["cat_img"], np.array([[1, 2], [2, 3], [1, 2], [2, 3]]))
def test_tensor_values(self):
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu:0")
    input_data = {
        "img1": torch.tensor([[0, 1], [1, 2]], device=device),
        "img2": torch.tensor([[0, 1], [1, 2]], device=device),
    }
    result = ConcatItemsd(keys=["img1", "img2"], name="cat_img")(input_data)
    self.assertTrue("cat_img" in result)
    result["cat_img"] += 1
    torch.testing.assert_allclose(result["img1"], torch.tensor([[0, 1], [1, 2]], device=device))
    torch.testing.assert_allclose(result["cat_img"], torch.tensor([[1, 2], [2, 3], [1, 2], [2, 3]], device=device))
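# Illustrative sketch, not part of the original suite: ConcatItemsd joins the listed
# keys along dim=0 (the channel axis) by default, which is why two (2, 2) inputs above
# produce a (4, 2) "cat_img". The shapes and key names below are assumptions chosen
# for illustration only.
def test_concat_channel_shape_sketch(self):
    sample = {"img1": np.zeros((1, 64, 64)), "img2": np.ones((1, 64, 64))}
    stacked = ConcatItemsd(keys=["img1", "img2"], name="inputs")(sample)
    # two single-channel images stacked along the channel axis
    self.assertEqual(stacked["inputs"].shape, (2, 64, 64))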
def evaluate_model(test_files, model_name):
    test_transforms = Compose(
        [
            LoadNiftid(keys=modalDataKey),
            AddChanneld(keys=modalDataKey),
            NormalizeIntensityd(keys=modalDataKey),
            # ScaleIntensityd(keys=modalDataKey),
            # Resized(keys=modalDataKey, spatial_size=(48, 48), mode='bilinear'),
            ResizeWithPadOrCropd(keys=modalDataKey, spatial_size=(64, 64)),
            ConcatItemsd(keys=modalDataKey, name="inputs"),
            ToTensord(keys=["inputs"]),
        ]
    )
    # create a test data loader
    device = torch.device("cpu")
    print(len(test_files))
    test_ds = monai.data.Dataset(data=test_files, transform=test_transforms)
    test_loader = DataLoader(test_ds, batch_size=len(test_files), num_workers=2, pin_memory=True)
    # model = monai.networks.nets.se_resnet101(spatial_dims=2, in_ch=3, num_classes=6).to(device)
    model = DenseNetASPP(spatial_dims=2, in_channels=2, out_channels=5).to(device)
    # Evaluate the model on the test dataset
    # print(os.path.basename(model_name).split('.')[0])
    checkpoint = torch.load(model_name)
    model.load_state_dict(checkpoint['model'])
    # optimizer.load_state_dict(checkpoint['optimizer'])
    # epochs = checkpoint['epoch']
    # model.load_state_dict(torch.load(log_dir))
    model.eval()
    with torch.no_grad():
        saver = CSVSaver(output_dir="../result/GLeason/2d_output/",
                         filename=os.path.basename(model_name).split('.')[0] + '.csv')
        for test_data in test_loader:
            test_images, test_labels = test_data["inputs"].to(device), test_data["label"].to(device)
            pred = model(test_images)
            # Gleason classification
            # y_soft_label = (test_labels / 0.25).long()
            # y_soft_pred = (pred / 0.25).round().squeeze_().long()
            # print(test_data)
            probabilities = torch.sigmoid(pred)
            # pred2 = model(test_images).argmax(dim=1)
            # saver.save_batch(probabilities.argmax(dim=1), test_data["t2Img_meta_dict"])
            # zero = torch.zeros_like(probabilities)
            # one = torch.ones_like(probabilities)
            # y_pred_ordinal = torch.where(probabilities > 0.5, one, zero)
            # y_pred_acc = (y_pred_ordinal.sum(1)).to(torch.long)
            saver.save_batch(probabilities.argmax(dim=1), test_data["dwiImg_meta_dict"])
            # print(test_labels)
            # print(probabilities[:, 1])
            # for x in np.nditer(probabilities[:, 1]):
            #     print(x)
            #     prob_list.append(x)
            # falseList = []
            # trueList = []
            # for pre, label in zip(pred2.tolist(), test_labels.tolist()):
            #     if pre == 0 and label == 0:
            #         falseList.append(0)
            #     elif pre == 1 and label == 1:
            #         trueList.append(1)
            # specificity = (falseList.count(0) / test_labels.tolist().count(0))
            # sensitivity = (trueList.count(1) / test_labels.tolist().count(1))
            # print('specificity:' + '%.4f' % specificity + ',',
            #       'sensitivity:' + '%.4f' % sensitivity + ',',
            #       'accuracy:' + '%.4f' % ((specificity + sensitivity) / 2))
            # print(type(test_labels), type(pred))
            # fpr, tpr, thresholds = roc_curve(test_labels, probabilities[:, 1])
            # roc_auc = auc(fpr, tpr)
            # print('AUC = ' + str(roc_auc))
            # AUC_list.append(roc_auc)
            # print(accuracy_score(test_labels, pred2))
            # accuracy_list.append(accuracy_score(test_labels, pred2))
            # plt.plot(fpr, tpr, linewidth=2, label="ROC")
            # plt.xlabel("false positive rate")
            # plt.ylabel("true positive rate")
            # plt.ylim(0, 1.05)
            # plt.xlim(0, 1.05)
            # plt.legend(loc=4)  # legend location
            # plt.show()
        saver.finalize()
    # cm = confusion_matrix(test_labels, y_pred_acc)
    cm = confusion_matrix(test_labels, probabilities.argmax(dim=1))
    # cm = confusion_matrix(y_soft_label, y_soft_pred)
    # kappa_value = cohen_kappa_score(test_labels, y_pred_acc, weights='quadratic')
    kappa_value = cohen_kappa_score(test_labels, probabilities.argmax(dim=1), weights='quadratic')
    print('quadratic weighted kappa=' + str(kappa_value))
    kappa_list.append(kappa_value)
    plot_confusion_matrix(cm, 'confusion_matrix.png', title='confusion matrix')
    from sklearn.metrics import classification_report
    print(classification_report(test_labels, probabilities.argmax(dim=1), digits=4))
    accuracy_list.append(
        classification_report(test_labels, probabilities.argmax(dim=1), digits=4, output_dict=True)["accuracy"])
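# `plot_confusion_matrix` is called above but not defined in this file. A minimal
# matplotlib-based sketch of such a helper is given below; its signature is inferred
# from the call site and the implementation is an assumption, not the original code.
def plot_confusion_matrix(cm, save_path, title='confusion matrix'):
    plt.figure(figsize=(6, 5))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)  # heatmap of counts
    plt.title(title)
    plt.colorbar()
    plt.xlabel('predicted label')
    plt.ylabel('true label')
    plt.savefig(save_path)
    plt.close()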
def training(train_files, val_files, log_dir):
    # Define transforms for the images
    print(log_dir)
    train_transforms = Compose(
        [
            LoadNiftid(keys=modalDataKey),
            AddChanneld(keys=modalDataKey),
            NormalizeIntensityd(keys=modalDataKey),
            # ScaleIntensityd(keys=modalDataKey),
            ResizeWithPadOrCropd(keys=modalDataKey, spatial_size=(64, 64)),
            # Resized(keys=modalDataKey, spatial_size=(48, 48), mode='bilinear'),
            ConcatItemsd(keys=modalDataKey, name="inputs"),
            RandRotate90d(keys=["inputs"], prob=0.8, spatial_axes=[0, 1]),
            RandAffined(keys=["inputs"], prob=0.8, scale_range=[0.1, 0.5]),
            RandZoomd(keys=["inputs"], prob=0.8, max_zoom=1.5, min_zoom=0.5),
            # RandFlipd(keys=["inputs"], prob=0.5, spatial_axis=1),
            ToTensord(keys=["inputs"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadNiftid(keys=modalDataKey),
            AddChanneld(keys=modalDataKey),
            NormalizeIntensityd(keys=modalDataKey),
            # ScaleIntensityd(keys=modalDataKey),
            ResizeWithPadOrCropd(keys=modalDataKey, spatial_size=(64, 64)),
            # Resized(keys=modalDataKey, spatial_size=(48, 48), mode='bilinear'),
            ConcatItemsd(keys=modalDataKey, name="inputs"),
            ToTensord(keys=["inputs"]),
        ]
    )
    # data_size = len(full_files)
    # split = data_size // 2
    # indices = list(range(data_size))
    # train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:])
    # valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
    # full_loader = DataLoader(full_files, batch_size=64, sampler=sampler(full_files), pin_memory=True)
    # train_loader = DataLoader(full_files, batch_size=128, sampler=train_sampler, collate_fn=collate_fn)
    # val_loader = DataLoader(full_files, batch_size=split, sampler=valid_sampler, collate_fn=collate_fn)
    # DL = DataLoader(train_files, batch_size=64, shuffle=True, num_workers=0, drop_last=True, collate_fn=collate_fn)
    # randomBatch_sizeList = [8, 16, 32, 64, 128]
    # randomLRList = [1e-4, 1e-5, 5e-5, 5e-4, 1e-3]
    # batch_size = random.choice(randomBatch_sizeList)
    # lr = random.choice(randomLRList)
    lr = 0.01
    batch_size = 256
    # print(batch_size)
    # print(lr)
    # Define dataset and data loaders
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    check_loader = DataLoader(check_ds, batch_size=batch_size, num_workers=2, pin_memory=True)
    check_data = monai.utils.misc.first(check_loader)
    # print(check_data)
    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    # train_data = monai.utils.misc.first(train_loader)
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=2, pin_memory=True)
    # Create the network, CrossEntropyLoss and Adam optimizer
    # model = monai.networks.nets.se_resnet101(spatial_dims=2, in_ch=3, num_classes=6).to(device)
    # model = densenet121(spatial_dims=2, in_channels=3, out_channels=5).to(device)
    # im_size = (2,) + tuple(train_ds[0]["inputs"].shape)
    model = DenseNetASPP(spatial_dims=2, in_channels=2, out_channels=5).to(device)
    classes = np.array([0, 1, 2, 3, 4])
    # print(check_data["label"].numpy())
    class_weights = class_weight.compute_class_weight('balanced', classes, check_data["label"].numpy())
    class_weights_tensor = torch.Tensor(class_weights).to(device)
    # print(class_weights_tensor)
    # loss_function = nn.BCEWithLogitsLoss()
    loss_function = torch.nn.CrossEntropyLoss(weight=class_weights_tensor)
    # loss_function = torch.nn.MSELoss()
    # m = torch.nn.LogSoftmax(dim=1)
    optimizer = torch.optim.Adam(model.parameters(), lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 50, gamma=0.5, last_epoch=-1)
    # If a saved checkpoint exists, load it and resume training from there
    if os.path.exists(log_dir):
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        print('Loaded epoch {} successfully!'.format(start_epoch))
    else:
        start_epoch = 0
        print('No saved model found, training will start from scratch!')
    # start a typical PyTorch training loop
    epoch_num = 300
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    # checkpoint_interval = 100
    for epoch in range(start_epoch + 1, epoch_num):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{epoch_num}")
        # print(scheduler.get_last_lr())
        model.train()
        epoch_loss = 0
        step = 0
        # for i, (inputs, labels, imgName) in enumerate(train_loader):
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["inputs"].to(device), batch_data["label"].to(device)
            # batch_arr = []
            # for j in range(len(inputs)):
            #     batch_arr.append(inputs[i])
            # batch_img = Variable(torch.from_numpy(np.array(batch_arr)).to(device))
            # labels = Variable(torch.from_numpy(np.array(labels)).to(device))
            # batch_img = batch_img.type(torch.FloatTensor).to(device)
            outputs = model(inputs)
            # y_ordinal_encoding = transformOrdinalEncoding(labels, labels.shape[0], 5)
            # loss = loss_function(outputs, torch.from_numpy(y_ordinal_encoding).to(device))
            loss = loss_function(outputs, labels.long())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print(f"{step}/{len(train_loader)}, train_loss: {loss.item():.4f}")
            epoch_len = len(train_ds) // train_loader.batch_size
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        scheduler.step()
        print(epoch, 'lr={:.6f}'.format(scheduler.get_last_lr()[0]))
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        # if (epoch + 1) % checkpoint_interval == 0:  # save once every checkpoint_interval epochs
        #     checkpoint = {'model': model.state_dict(),
        #                   'optimizer': optimizer.state_dict(),
        #                   'epoch': epoch
        #                   }
        #     path_checkpoint = './model/checkpoint_{}_epoch.pth'.format(epoch)
        #     torch.save(checkpoint, path_checkpoint)
        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                y_pred = torch.tensor([], dtype=torch.float32, device=device)
                y = torch.tensor([], dtype=torch.long, device=device)
                # for i, (inputs, labels, imgName) in enumerate(val_loader):
                for val_data in val_loader:
                    val_images, val_labels = val_data["inputs"].to(device), val_data["label"].to(device)
                    # val_batch_arr = []
                    # for j in range(len(inputs)):
                    #     val_batch_arr.append(inputs[i])
                    # val_img = Variable(torch.from_numpy(np.array(val_batch_arr)).to(device))
                    # labels = Variable(torch.from_numpy(np.array(labels)).to(device))
                    # val_img = val_img.type(torch.FloatTensor).to(device)
                    y_pred = torch.cat([y_pred, model(val_images)], dim=0)
                    y = torch.cat([y, val_labels], dim=0)
                # y_ordinal_encoding = transformOrdinalEncoding(y, y.shape[0], 5)
                # y_pred = torch.sigmoid(y_pred)
                # y = (y / 0.25).long()
                # print(y)
                # auc_metric = compute_roc_auc(y_pred, y, to_onehot_y=True, softmax=True)
                # zero = torch.zeros_like(y_pred)
                # one = torch.ones_like(y_pred)
                # y_pred_label = torch.where(y_pred > 0.5, one, zero)
                # print((y_pred_label.sum(1)).to(torch.long))
                # y_pred_acc = (y_pred_label.sum(1)).to(torch.long)
                # print(y_pred.argmax(dim=1))
                # kappa_value = kappa(cm)
                kappa_value = cohen_kappa_score(y.to("cpu"), y_pred.argmax(dim=1).to("cpu"), weights='quadratic')
                # kappa_value = cohen_kappa_score(y.to("cpu"), y_pred_acc.to("cpu"), weights='quadratic')
                metric_values.append(kappa_value)
                acc_value = torch.eq(y_pred.argmax(dim=1), y)
                # print(acc_value)
                acc_metric = acc_value.sum().item() / len(acc_value)
                if kappa_value > best_metric:
                    best_metric = kappa_value
                    best_metric_epoch = epoch + 1
                    checkpoint = {'model': model.state_dict(),
                                  'optimizer': optimizer.state_dict(),
                                  'epoch': epoch
                                  }
                    torch.save(checkpoint, log_dir)
                    print("saved new best metric model")
                print(
                    "current epoch: {} current Kappa: {:.4f} current accuracy: {:.4f} best Kappa: {:.4f} at epoch {}".format(
                        epoch + 1, kappa_value, acc_metric, best_metric, best_metric_epoch
                    )
                )
                writer.add_scalar("val_accuracy", acc_metric, epoch + 1)
    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
    plt.figure('train', (12, 6))
    plt.subplot(1, 2, 1)
    plt.title("Epoch Average Loss")
    x = [i + 1 for i in range(len(epoch_loss_values))]
    y = epoch_loss_values
    plt.xlabel('epoch')
    plt.plot(x, y)
    plt.subplot(1, 2, 2)
    plt.title("Validation: Quadratic Weighted Kappa")
    x = [val_interval * (i + 1) for i in range(len(metric_values))]
    y = metric_values
    plt.xlabel('epoch')
    plt.plot(x, y)
    plt.show()
    evaluate_model(val_files, log_dir)
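# A hedged sketch of a possible entry point for this script. The modality key names
# ("t2Img", "dwiImg") are inferred from the meta-dict keys used above; the file paths,
# checkpoint location, and global list names are assumptions for illustration only.
if __name__ == "__main__":
    modalDataKey = ["t2Img", "dwiImg"]  # assumed module-level global used by the transforms
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    kappa_list, accuracy_list = [], []  # collected by evaluate_model
    train_files = [  # hypothetical data dictionaries, one per case
        {"t2Img": "../data/train/case000_t2.nii.gz", "dwiImg": "../data/train/case000_dwi.nii.gz", "label": 0},
        # ...
    ]
    val_files = [
        {"t2Img": "../data/val/case100_t2.nii.gz", "dwiImg": "../data/val/case100_dwi.nii.gz", "label": 1},
        # ...
    ]
    training(train_files, val_files, log_dir="./model/best_metric_model.pth")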