Example #1
0
 def test_single_numpy(self):
     """A single numpy key is copied into ``cat_img`` without aliasing ``img``."""
     source = np.array([[0, 1], [1, 2]])
     output = ConcatItemsd(keys="img", name="cat_img")({"img": source})
     output["cat_img"] += 1
     # Editing the concatenated result in place must leave the input entry untouched.
     np.testing.assert_allclose(output["img"], np.array([[0, 1], [1, 2]]))
     expected = np.array([[1, 2], [2, 3]])
     np.testing.assert_allclose(output["cat_img"], expected)
Example #2
0
 def test_single_tensor(self):
     """A single tensor key is copied into ``cat_img`` without aliasing ``img``."""
     input_data = {"img": torch.tensor([[0, 1], [1, 2]])}
     result = ConcatItemsd(keys="img", name="cat_img")(input_data)
     result["cat_img"] += 1
     # torch.testing.assert_allclose is deprecated; assert_close is its replacement.
     torch.testing.assert_close(result["img"], torch.tensor([[0, 1], [1, 2]]))
     torch.testing.assert_close(result["cat_img"], torch.tensor([[1, 2], [2, 3]]))
Example #3
0
 def test_numpy_values(self):
     """Two numpy keys are stacked along the first axis into a new ``cat_img`` array."""
     input_data = {
         "img1": np.array([[0, 1], [1, 2]]),
         "img2": np.array([[0, 1], [1, 2]]),
     }
     result = ConcatItemsd(keys=["img1", "img2"], name="cat_img")(input_data)
     # assertIn reports the container contents on failure, unlike assertTrue(x in y).
     self.assertIn("cat_img", result)
     result["cat_img"] += 1
     # The source entry must not change when the concatenated copy is edited.
     np.testing.assert_allclose(result["img1"], np.array([[0, 1], [1, 2]]))
     np.testing.assert_allclose(result["cat_img"],
                                np.array([[1, 2], [2, 3], [1, 2], [2, 3]]))
Example #4
0
 def test_tensor_values(self):
     """Two tensors on the same device are stacked along the first axis on that device."""
     device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu:0")
     input_data = {
         "img1": torch.tensor([[0, 1], [1, 2]], device=device),
         "img2": torch.tensor([[0, 1], [1, 2]], device=device),
     }
     result = ConcatItemsd(keys=["img1", "img2"], name="cat_img")(input_data)
     self.assertIn("cat_img", result)
     result["cat_img"] += 1
     # torch.testing.assert_allclose is deprecated; assert_close also checks device.
     torch.testing.assert_close(
         result["img1"], torch.tensor([[0, 1], [1, 2]], device=device))
     torch.testing.assert_close(
         result["cat_img"],
         torch.tensor([[1, 2], [2, 3], [1, 2], [2, 3]], device=device))
Example #5
0
def evaluta_model(test_files, model_name):
    """Evaluate the checkpoint at ``model_name`` on ``test_files`` and report metrics.

    Runs the deterministic (non-augmented) preprocessing pipeline and predicts the
    class of every case with a 2-D ``DenseNetASPP`` (5 outputs) on the CPU. It then:

    - writes the argmax predictions through ``CSVSaver`` to
      ``../result/GLeason/2d_output/<checkpoint stem>.csv``;
    - plots the confusion matrix to ``confusion_matrix.png``;
    - prints the quadratic-weighted Cohen's kappa and an sklearn classification
      report;
    - appends kappa and accuracy to the ``kappa_list`` / ``accuracy_list`` lists
      defined outside this function.

    Metrics are computed over all batches, not just the last one.

    Args:
        test_files: non-empty list of data dicts for ``monai.data.Dataset``; each
            needs the image keys in ``modalDataKey`` plus ``"label"``. The saver
            reads ``"dwiImg_meta_dict"``, so ``modalDataKey`` presumably includes
            ``"dwiImg"`` -- TODO confirm.
        model_name: path to a checkpoint dict whose ``'model'`` entry is the
            state dict to load.

    Raises:
        ValueError: if ``test_files`` is empty.
    """
    test_transforms = Compose(
        [
            LoadNiftid(keys=modalDataKey),
            AddChanneld(keys=modalDataKey),
            NormalizeIntensityd(keys=modalDataKey),
            ResizeWithPadOrCropd(keys=modalDataKey, spatial_size=(64, 64)),
            ConcatItemsd(keys=modalDataKey, name="inputs"),
            ToTensord(keys=["inputs"]),
        ]
    )
    device = torch.device("cpu")
    print(len(test_files))
    if not test_files:
        raise ValueError("test_files is empty; nothing to evaluate")
    test_ds = monai.data.Dataset(data=test_files, transform=test_transforms)
    # The whole test set is requested as one batch. Pinned memory only helps
    # host->GPU copies (the original passed the always-truthy torch.device class).
    test_loader = DataLoader(test_ds, batch_size=len(test_files), num_workers=2,
                             pin_memory=device.type == "cuda")
    model = DenseNetASPP(spatial_dims=2, in_channels=2, out_channels=5).to(device)
    # map_location lets a checkpoint saved during a CUDA run load on this CPU device.
    checkpoint = torch.load(model_name, map_location=device)
    model.load_state_dict(checkpoint['model'])
    model.eval()

    all_labels = []
    all_preds = []
    with torch.no_grad():
        saver = CSVSaver(output_dir="../result/GLeason/2d_output/",
                         filename=os.path.basename(model_name).split('.')[0] + '.csv')
        for test_data in test_loader:
            test_images = test_data["inputs"].to(device)
            logits = model(test_images)  # Gleason classification
            # sigmoid is monotonic, so this argmax equals the argmax of the logits.
            probabilities = torch.sigmoid(logits)
            batch_pred = probabilities.argmax(dim=1)
            saver.save_batch(batch_pred, test_data["dwiImg_meta_dict"])
            all_labels.append(test_data["label"].cpu())
            all_preds.append(batch_pred.cpu())
        saver.finalize()

    y_true = torch.cat(all_labels).numpy()
    y_pred = torch.cat(all_preds).numpy()
    cm = confusion_matrix(y_true, y_pred)
    kappa_value = cohen_kappa_score(y_true, y_pred, weights='quadratic')
    print('quadratic weighted kappa=' + str(kappa_value))
    kappa_list.append(kappa_value)
    plot_confusion_matrix(cm, 'confusion_matrix.png', title='confusion matrix')
    from sklearn.metrics import classification_report
    print(classification_report(y_true, y_pred, digits=4))
    accuracy_list.append(
        classification_report(y_true, y_pred, digits=4, output_dict=True)["accuracy"])
Example #6
0
def training(train_files, val_files, log_dir):
    """Train a 2-D DenseNetASPP classifier (5 classes) and keep the best checkpoint.

    Trains for up to 300 epochs with Adam (lr 0.01, halved every 50 epochs by
    StepLR) and class-balanced cross-entropy, validating every 2 epochs.

    Whenever the validation quadratic-weighted Cohen's kappa improves, a checkpoint
    dict is saved to ``log_dir``. It holds ``model``, ``optimizer``, ``scheduler``,
    ``epoch`` and ``best_metric``. If ``log_dir`` already exists, training resumes
    from it, including the LR schedule and best kappa when present.

    After training, loss/kappa curves are plotted and ``evaluta_model`` is run on
    ``val_files``. Relies on the module-level ``device`` and ``modalDataKey``
    defined outside this function.

    Args:
        train_files: list of data dicts for ``monai.data.Dataset``; each must hold
            the image keys in ``modalDataKey`` and a numeric ``"label"``
            (presumably 0..4 -- TODO confirm against the data-list builder).
        val_files: same format as ``train_files``; used for validation and for
            the final evaluation.
        log_dir: checkpoint *file* path (it is passed to ``torch.save`` /
            ``torch.load`` directly, despite the name).
    """
    print(log_dir)
    train_transforms = Compose(
        [
            LoadNiftid(keys=modalDataKey),
            AddChanneld(keys=modalDataKey),
            NormalizeIntensityd(keys=modalDataKey),
            ResizeWithPadOrCropd(keys=modalDataKey, spatial_size=(64, 64)),
            ConcatItemsd(keys=modalDataKey, name="inputs"),
            # Augment after concatenation so all modality channels share one random transform.
            RandRotate90d(keys=["inputs"], prob=0.8, spatial_axes=[0, 1]),
            RandAffined(keys=["inputs"], prob=0.8, scale_range=[0.1, 0.5]),
            RandZoomd(keys=["inputs"], prob=0.8, max_zoom=1.5, min_zoom=0.5),
            ToTensord(keys=["inputs"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadNiftid(keys=modalDataKey),
            AddChanneld(keys=modalDataKey),
            NormalizeIntensityd(keys=modalDataKey),
            ResizeWithPadOrCropd(keys=modalDataKey, spatial_size=(64, 64)),
            ConcatItemsd(keys=modalDataKey, name="inputs"),
            ToTensord(keys=["inputs"]),
        ]
    )
    lr = 0.01
    batch_size = 256
    # Pinned host memory only speeds up host->GPU copies; the original passed the
    # torch.device class itself here, which is always truthy.
    pin_memory = torch.device(device).type == "cuda"
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=2, pin_memory=pin_memory)
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=2, pin_memory=pin_memory)

    model = DenseNetASPP(spatial_dims=2, in_channels=2, out_channels=5).to(device)
    classes = np.array([0, 1, 2, 3, 4])
    # Balance the loss over the label distribution of the whole training set.
    # Using only the first, unshuffled batch can miss a class entirely, which
    # makes compute_class_weight raise. sklearn>=1.0 requires keyword arguments.
    train_labels = np.asarray([item["label"] for item in train_files])
    class_weights = class_weight.compute_class_weight(
        class_weight='balanced', classes=classes, y=train_labels)
    class_weights_tensor = torch.Tensor(class_weights).to(device)
    loss_function = torch.nn.CrossEntropyLoss(weight=class_weights_tensor)
    optimizer = torch.optim.Adam(model.parameters(), lr)
    # Halve the learning rate every 50 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 50, gamma=0.5, last_epoch=-1)

    best_metric = -1
    best_metric_epoch = -1
    # If a saved checkpoint exists, load it and continue training from there.
    if os.path.exists(log_dir):
        checkpoint = torch.load(log_dir, map_location=device)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Older checkpoints lack these keys; fall back to a fresh schedule / best.
        if 'scheduler' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler'])
        loaded_epoch = checkpoint['epoch']
        if 'best_metric' in checkpoint:
            best_metric = checkpoint['best_metric']
            best_metric_epoch = loaded_epoch + 1
        # The checkpoint stores the index of an already completed epoch.
        start_epoch = loaded_epoch + 1
        print('加载 epoch {} 成功!'.format(loaded_epoch))  # "Loaded epoch {} successfully!"
    else:
        start_epoch = 0
        print('无保存模型,将从头开始训练!')  # "No saved model; training from scratch!"

    epoch_num = 300
    val_interval = 2
    epoch_loss_values = []
    metric_values = []
    writer = SummaryWriter()
    try:
        for epoch in range(start_epoch, epoch_num):
            print("-" * 10)
            print(f"epoch {epoch + 1}/{epoch_num}")
            epoch_loss = _train_one_epoch(model, train_loader, loss_function, optimizer, writer, epoch)
            scheduler.step()
            print(epoch, 'lr={:.6f}'.format(scheduler.get_last_lr()[0]))
            epoch_loss_values.append(epoch_loss)
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
            if (epoch + 1) % val_interval != 0:
                continue
            kappa_value, acc_metric = _validate_kappa_accuracy(model, val_loader)
            metric_values.append(kappa_value)
            if kappa_value > best_metric:
                best_metric = kappa_value
                best_metric_epoch = epoch + 1
                checkpoint = {'model': model.state_dict(),
                              'optimizer': optimizer.state_dict(),
                              'scheduler': scheduler.state_dict(),
                              'epoch': epoch,
                              'best_metric': best_metric,
                              }
                torch.save(checkpoint, log_dir)
                print("saved new best metric model")
            print(
                "current epoch: {} current Kappa: {:.4f} current accuracy: {:.4f} best Kappa: {:.4f} at epoch {}".format(
                    epoch + 1, kappa_value, acc_metric, best_metric, best_metric_epoch
                )
            )
            writer.add_scalar("val_accuracy", acc_metric, epoch + 1)
    finally:
        # Flush and release the TensorBoard writer even if training fails.
        writer.close()
    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    _plot_training_curves(epoch_loss_values, metric_values, val_interval)
    evaluta_model(val_files, log_dir)


def _train_one_epoch(model, loader, loss_function, optimizer, writer, epoch):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Each step's loss is printed and logged to ``writer`` as ``"train_loss"`` at
    global step ``len(loader) * epoch + step``. Returns 0.0 for an empty loader.
    """
    model.train()
    # len(DataLoader) is already the number of batches, so no division by batch size.
    epoch_len = len(loader)
    epoch_loss = 0.0
    step = 0
    for step, batch_data in enumerate(loader, start=1):
        inputs = batch_data["inputs"].to(device)
        labels = batch_data["label"].to(device)
        outputs = model(inputs)
        loss = loss_function(outputs, labels.long())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
        epoch_loss += loss_value
        print(f"{step}/{epoch_len}, train_loss: {loss_value:.4f}")
        writer.add_scalar("train_loss", loss_value, epoch_len * epoch + step)
    return epoch_loss / max(step, 1)


def _validate_kappa_accuracy(model, loader):
    """Predict every batch of ``loader`` and return ``(quadratic_kappa, accuracy)``."""
    model.eval()
    preds = []
    labels = []
    with torch.no_grad():
        for val_data in loader:
            val_images = val_data["inputs"].to(device)
            preds.append(model(val_images).argmax(dim=1).cpu())
            labels.append(val_data["label"].cpu())
    y_pred = torch.cat(preds)
    y = torch.cat(labels)
    # Plain float so the value stores cleanly in the checkpoint (no numpy scalar).
    kappa_value = float(cohen_kappa_score(y.numpy(), y_pred.numpy(), weights='quadratic'))
    acc_metric = torch.eq(y_pred, y).sum().item() / len(y)
    return kappa_value, acc_metric


def _plot_training_curves(epoch_loss_values, metric_values, val_interval):
    """Plot per-epoch mean training loss and validation kappa side by side."""
    plt.figure('train', (12, 6))
    plt.subplot(1, 2, 1)
    plt.title("Epoch Average Loss")
    plt.xlabel('epoch')
    plt.plot([i + 1 for i in range(len(epoch_loss_values))], epoch_loss_values)
    plt.subplot(1, 2, 2)
    # metric_values holds quadratic-weighted kappa, not ROC AUC.
    plt.title("Validation: Quadratic Weighted Kappa")
    plt.xlabel('epoch')
    plt.plot([val_interval * (i + 1) for i in range(len(metric_values))], metric_values)
    plt.show()