def test_shape(self, input_param, input_data, expected_shape, expected_last):
    """Check each cropped sample's shape and compare the final sample's values to ``expected_last``."""
    cropper = RandSpatialCropSamplesd(**input_param)
    cropper.set_random_state(1234)
    samples = cropper(input_data)
    for sample, shape in zip(samples, expected_shape):
        self.assertTupleEqual(sample["img"].shape, shape)
        self.assertTupleEqual(sample["seg"].shape, shape)
    # ``sample`` intentionally leaks out of the loop: it is the last crop produced.
    np.testing.assert_allclose(sample["img"], expected_last["img"])
    np.testing.assert_allclose(sample["seg"], expected_last["seg"])
# Example #2
 def test_shape(self, input_param, input_data, expected_shape, expected_last):
     """Verify sample shapes, per-sample ``patch_index`` metadata, and the last sample's values."""
     cropper = RandSpatialCropSamplesd(**input_param)
     cropper.set_random_state(1234)
     samples = cropper(input_data)
     for sample, shape in zip(samples, expected_shape):
         self.assertTupleEqual(sample["img"].shape, shape)
         self.assertTupleEqual(sample["seg"].shape, shape)
     # Every output should record which patch (crop) it came from.
     for idx, sample in enumerate(samples):
         self.assertEqual(sample["img"].meta["patch_index"], idx)
         self.assertEqual(sample["seg"].meta["patch_index"], idx)
     # ``sample`` is the final element after the enumerate loop above.
     assert_allclose(sample["img"], expected_last["img"], type_test=False)
     assert_allclose(sample["seg"], expected_last["seg"], type_test=False)
    def test_samples(self):
        """Render a list of random crops with ``matshow3d`` and compare the saved
        figure against the reference PNG stored in ``testing_data``.
        """
        testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data")
        keys = "image"
        xforms = Compose(
            [
                LoadImaged(keys=keys),
                AddChanneld(keys=keys),
                ScaleIntensityd(keys=keys),
                RandSpatialCropSamplesd(keys=keys, roi_size=(8, 8, 5), random_size=True, num_samples=10),
            ]
        )
        image_path = os.path.join(testing_dir, "anatomical.nii")
        xforms.set_random_state(0)
        ims = xforms({keys: image_path})
        fig, mat = matshow3d(
            [im[keys] for im in ims], title=f"testing {keys}", figsize=(2, 2), frames_per_row=5, every_n=2, show=False
        )
        # assertEqual reports both values on failure; assertTrue(x == y) would only say "False is not true".
        self.assertEqual(mat.dtype, np.float32)

        with tempfile.TemporaryDirectory() as tempdir:
            tempimg = f"{tempdir}/matshow3d_patch_test.png"
            fig.savefig(tempimg)
            comp = compare_images(f"{testing_dir}/matshow3d_patch_test.png", tempimg, 5e-2, in_decorator=True)
            if comp:
                # A non-None result means matplotlib reported a difference dict;
                # accept it only when the RMS error matches the known value.
                print("not none comp: ", comp)  # matplotlib 3.2.2
                np.testing.assert_allclose(comp["rms"], 30.786983, atol=1e-3, rtol=1e-3)
            else:
                self.assertIsNone(comp, f"value of comp={comp}")  # None indicates test passed
# Example #4
 def test_deep_copy(self):
     """Each returned sample must carry its own applied_operations trace, one entry per transform."""
     n_samples = 3
     cropper = RandSpatialCropSamplesd(
         keys=["img"], roi_size=(3, 3, 3), num_samples=n_samples, random_center=True, random_size=False
     )
     pipeline = Compose([DivisiblePadd(keys="img", k=5), cropper])
     outputs = pipeline({"img": np.ones((1, 10, 11, 12))})
     self.assertEqual(len(outputs), n_samples)
     for out in outputs:
         self.assertEqual(len(out["img"].applied_operations), len(pipeline))
 def test_deep_copy(self):
     """Each returned sample must carry its own ``img_transforms`` record, one entry per transform."""
     n_samples = 3
     cropper = RandSpatialCropSamplesd(
         keys=["img"], roi_size=(3, 3, 3), num_samples=n_samples, random_center=True, random_size=False
     )
     pipeline = Compose([ToTensord(keys="img"), cropper])
     outputs = pipeline({"img": np.ones((1, 10, 11, 12))})
     self.assertEqual(len(outputs), n_samples)
     for out in outputs:
         self.assertEqual(len(out["img_transforms"]), len(pipeline))
# Example #6
def main():
    """Self-supervised pre-training of a ViT autoencoder on the TCIA COVID-19 data.

    Loads a JSON train/validation split, builds two independently augmented views
    of each 96^3 crop ("image", "image_2") plus a clean copy ("gt_image"), and
    trains ``ViTAutoEnc`` with an L1 reconstruction loss combined with a
    contrastive loss between the two views.  The best checkpoint (lowest
    validation reconstruction loss) and a 2x2 loss-curve figure are written to
    ``logdir_path``.
    """
    # Hard-coded experiment paths (original tutorial setup).
    json_Path = os.path.normpath('/scratch/data_2021/tcia_covid19/dataset_split_debug.json')
    data_Root = os.path.normpath('/scratch/data_2021/tcia_covid19')
    logdir_path = os.path.normpath('/home/vishwesh/monai_tutorial_testing/issue_467')

    # makedirs(exist_ok=True) replaces the racy exists()/mkdir() pair and also
    # creates any missing parent directories.
    os.makedirs(logdir_path, exist_ok=True)

    # Load the JSON split and prepend the data root to every image path.
    with open(json_Path, 'r') as json_f:
        json_Data = json.load(json_f)

    train_Data = json_Data['training']
    val_Data = json_Data['validation']

    for entry in train_Data:
        entry['image'] = os.path.join(data_Root, entry['image'])

    for entry in val_Data:
        entry['image'] = os.path.join(data_Root, entry['image'])

    print('Total Number of Training Data Samples: {}'.format(len(train_Data)))
    print(train_Data)
    print('#' * 10)
    print('Total Number of Validation Data Samples: {}'.format(len(val_Data)))
    print(val_Data)
    print('#' * 10)

    # Fix all random seeds for reproducibility.
    set_determinism(seed=123)

    # Training transforms: crop, duplicate into gt_image/image_2, then corrupt
    # "image" and "image_2" independently.
    train_Transforms = Compose(
        [
            LoadImaged(keys=["image"]),
            EnsureChannelFirstd(keys=["image"]),
            Spacingd(keys=["image"], pixdim=(2.0, 2.0, 2.0), mode=("bilinear")),
            ScaleIntensityRanged(
                keys=["image"], a_min=-57, a_max=164,
                b_min=0.0, b_max=1.0, clip=True,
            ),
            CropForegroundd(keys=["image"], source_key="image"),
            SpatialPadd(keys=["image"], spatial_size=(96, 96, 96)),
            RandSpatialCropSamplesd(keys=["image"], roi_size=(96, 96, 96), random_size=False, num_samples=2),
            CopyItemsd(keys=["image"], times=2, names=["gt_image", "image_2"], allow_missing_keys=False),
            OneOf(
                transforms=[
                    RandCoarseDropoutd(
                        keys=["image"], prob=1.0, holes=6, spatial_size=5, dropout_holes=True, max_spatial_size=32
                    ),
                    RandCoarseDropoutd(
                        keys=["image"], prob=1.0, holes=6, spatial_size=20, dropout_holes=False, max_spatial_size=64
                    ),
                ]
            ),
            RandCoarseShuffled(keys=["image"], prob=0.8, holes=10, spatial_size=8),
            # Please note that if image and image_2 were augmented in the same
            # transform call, determinism would give them identical augmentations,
            # which is not wanted here -- hence the second, separate set of calls.
            OneOf(
                transforms=[
                    RandCoarseDropoutd(
                        keys=["image_2"], prob=1.0, holes=6, spatial_size=5, dropout_holes=True, max_spatial_size=32
                    ),
                    RandCoarseDropoutd(
                        keys=["image_2"], prob=1.0, holes=6, spatial_size=20, dropout_holes=False, max_spatial_size=64
                    ),
                ]
            ),
            RandCoarseShuffled(keys=["image_2"], prob=0.8, holes=10, spatial_size=8),
        ]
    )

    # Sanity-check the pipeline on a single batch before training.
    check_ds = Dataset(data=train_Data, transform=train_Transforms)
    check_loader = DataLoader(check_ds, batch_size=1)
    check_data = first(check_loader)
    image = check_data["image"][0][0]
    print(f"image shape: {image.shape}")

    # Network (ViT backbone autoencoder), losses & optimizer.
    device = torch.device("cuda:0")
    model = ViTAutoEnc(
        in_channels=1,
        img_size=(96, 96, 96),
        patch_size=(16, 16, 16),
        pos_embed='conv',
        hidden_size=768,
        mlp_dim=3072,
    )
    model = model.to(device)

    # Hyper-parameters for the training loop.
    max_epochs = 500
    val_interval = 2
    batch_size = 4
    lr = 1e-4
    epoch_loss_values = []
    step_loss_values = []
    epoch_cl_loss_values = []
    epoch_recon_loss_values = []
    val_loss_values = []
    best_val_loss = 1000.0

    recon_loss = L1Loss()
    # batch_size * 2: presumably because each batch yields two augmented views
    # -- confirm against the ContrastiveLoss API.
    contrastive_loss = ContrastiveLoss(batch_size=batch_size * 2, temperature=0.05)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    train_ds = Dataset(data=train_Data, transform=train_Transforms)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)

    # NOTE(review): shuffle=True on a validation loader is unnecessary; kept to
    # preserve the original behaviour.
    val_ds = Dataset(data=val_Data, transform=train_Transforms)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=True, num_workers=4)

    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        epoch_cl_loss = 0
        epoch_recon_loss = 0
        step = 0

        for batch_data in train_loader:
            step += 1
            start_time = time.time()

            inputs, inputs_2, gt_input = (
                batch_data["image"].to(device),
                batch_data["image_2"].to(device),
                batch_data["gt_image"].to(device),
            )
            optimizer.zero_grad()
            outputs_v1, hidden_v1 = model(inputs)
            outputs_v2, hidden_v2 = model(inputs_2)

            # Flatten everything after the batch dim so the contrastive loss
            # sees one vector per sample.
            flat_out_v1 = outputs_v1.flatten(start_dim=1, end_dim=4)
            flat_out_v2 = outputs_v2.flatten(start_dim=1, end_dim=4)

            r_loss = recon_loss(outputs_v1, gt_input)
            cl_loss = contrastive_loss(flat_out_v1, flat_out_v2)

            # Adjust the CL loss by the recon loss.
            total_loss = r_loss + cl_loss * r_loss

            total_loss.backward()
            optimizer.step()
            epoch_loss += total_loss.item()
            step_loss_values.append(total_loss.item())

            # Accumulate the individual terms for per-epoch averages.
            epoch_cl_loss += cl_loss.item()
            epoch_recon_loss += r_loss.item()

            end_time = time.time()
            print(
                f"{step}/{len(train_ds) // train_loader.batch_size}, "
                f"train_loss: {total_loss.item():.4f}, "
                f"time taken: {end_time-start_time}s")

        epoch_loss /= step
        epoch_cl_loss /= step
        epoch_recon_loss /= step

        epoch_loss_values.append(epoch_loss)
        epoch_cl_loss_values.append(epoch_cl_loss)
        epoch_recon_loss_values.append(epoch_recon_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if epoch % val_interval == 0:
            print('Entering Validation for epoch: {}'.format(epoch+1))
            total_val_loss = 0
            val_step = 0
            model.eval()
            # no_grad: validation needs no autograd graph; the original built
            # one per batch, wasting memory.
            with torch.no_grad():
                for val_batch in val_loader:
                    val_step += 1
                    start_time = time.time()
                    inputs, gt_input = (
                        val_batch["image"].to(device),
                        val_batch["gt_image"].to(device),
                    )
                    print('Input shape: {}'.format(inputs.shape))
                    outputs, outputs_v2 = model(inputs)
                    val_loss = recon_loss(outputs, gt_input)
                    total_val_loss += val_loss.item()
                    end_time = time.time()

            total_val_loss /= val_step
            val_loss_values.append(total_val_loss)
            # Note: reports the time of the last validation batch only.
            print(f"epoch {epoch + 1} Validation average loss: {total_val_loss:.4f}, " f"time taken: {end_time-start_time}s")

            if total_val_loss < best_val_loss:
                print(f"Saving new model based on validation loss {total_val_loss:.4f}")
                best_val_loss = total_val_loss
                checkpoint = {'epoch': max_epochs,
                              'state_dict': model.state_dict(),
                              'optimizer': optimizer.state_dict()
                              }
                torch.save(checkpoint, os.path.join(logdir_path, 'best_model.pt'))

            # Refresh the 2x2 loss-curve figure after every validation round.
            plt.figure(1, figsize=(8, 8))
            plt.subplot(2, 2, 1)
            plt.plot(epoch_loss_values)
            plt.grid()
            plt.title('Training Loss')

            plt.subplot(2, 2, 2)
            plt.plot(val_loss_values)
            plt.grid()
            plt.title('Validation Loss')

            plt.subplot(2, 2, 3)
            plt.plot(epoch_cl_loss_values)
            plt.grid()
            plt.title('Training Contrastive Loss')

            plt.subplot(2, 2, 4)
            plt.plot(epoch_recon_loss_values)
            plt.grid()
            plt.title('Training Recon Loss')

            plt.savefig(os.path.join(logdir_path, 'loss_plots.png'))
            plt.close(1)

    print('Done')
    return None
# Example #7
    "RandCropByLabelClassesd 2d",
    "2D",
    1e-7,
    True,
    RandCropByLabelClassesd(KEYS,
                            "label", (99, 96),
                            ratios=[1, 2, 3, 4, 5],
                            num_classes=5,
                            num_samples=10),
))

# Remaining 2-D multi-sample crop cases; tuple layout appears to be
# (name, dim, tolerance, flag, transform) -- consumed positionally elsewhere.
TESTS.extend(
    [
        ("RandCropByPosNegLabeld 2d", "2D", 1e-7, True, RandCropByPosNegLabeld(KEYS, "label", (99, 96), num_samples=10)),
        ("RandSpatialCropSamplesd 2d", "2D", 1e-7, True, RandSpatialCropSamplesd(KEYS, (90, 91), num_samples=10)),
        ("RandWeightedCropd 2d", "2D", 1e-7, True, RandWeightedCropd(KEYS, "label", (90, 91), num_samples=10)),
    ]
)

# Duplicate every case with its transforms wrapped in a doubly-nested Compose.
TESTS_COMPOSE_X2 = [
    (case[0] + " Compose", case[1], case[2], case[3], Compose(Compose(case[4:]))) for case in TESTS
]

TESTS = TESTS + TESTS_COMPOSE_X2  # type: ignore

# Number of crops each multi-sample transform should emit in the N-samples tests.
NUM_SAMPLES = 5
N_SAMPLES_TESTS = [
    [
        RandCropByLabelClassesd(KEYS,
                                "label", (110, 99), [1, 2, 3, 4, 5],
                                num_classes=5,
 def test_shape(self, input_param, input_data, expected_shape):
     """Every sample produced by the crop must have the expected shape for both keys."""
     samples = RandSpatialCropSamplesd(**input_param)(input_data)
     for sample in samples:
         for key in ("img", "seg"):
             self.assertTupleEqual(sample[key].shape, expected_shape)
# Example #9
)

# RandAffine with prob=0 as a 3-D case; tuple layout appears to be
# (name, dim, tolerance, transform) -- consumed positionally elsewhere.
TESTS.append(("RandAffine 3d", "3D", 0, RandAffined(KEYS, spatial_size=None, prob=0)))

# 2-D multi-sample crop cases.
TESTS.extend(
    [
        (
            "RandCropByLabelClassesd 2d",
            "2D",
            1e-7,
            RandCropByLabelClassesd(KEYS, "label", (99, 96), ratios=[1, 2, 3, 4, 5], num_classes=5, num_samples=10),
        ),
        ("RandCropByPosNegLabeld 2d", "2D", 1e-7, RandCropByPosNegLabeld(KEYS, "label", (99, 96), num_samples=10)),
        ("RandSpatialCropSamplesd 2d", "2D", 1e-7, RandSpatialCropSamplesd(KEYS, (90, 91), num_samples=10)),
        ("RandWeightedCropd 2d", "2D", 1e-7, RandWeightedCropd(KEYS, "label", (90, 91), num_samples=10)),
    ]
)

# Duplicate every case with its transform wrapped in a doubly-nested Compose.
TESTS_COMPOSE_X2 = [(case[0] + " Compose", case[1], case[2], Compose(Compose(case[3:]))) for case in TESTS]

TESTS = TESTS + TESTS_COMPOSE_X2  # type: ignore

# Number of crops each multi-sample transform should emit in the cases below.
NUM_SAMPLES = 5
N_SAMPLES_TESTS = [
    [RandCropByLabelClassesd(KEYS, "label", (110, 99), [1, 2, 3, 4, 5], num_classes=5, num_samples=NUM_SAMPLES)],
    [RandCropByPosNegLabeld(KEYS, "label", (110, 99), num_samples=NUM_SAMPLES)],
    [RandSpatialCropSamplesd(KEYS, (90, 91), num_samples=NUM_SAMPLES, random_size=False)],
    [RandWeightedCropd(KEYS, "label", (90, 91), num_samples=NUM_SAMPLES)],
]