def test_k(self):
     """RandRotate90(max_k=2) with default prob and seed 123: output equals input (k == 0)."""
     xform = RandRotate90(max_k=2)
     for convert in TEST_NDARRAYS:
         xform.set_random_state(123)
         result = xform(convert(self.imt[0]))
         ref = np.stack([np.rot90(ch, 0, (0, 1)) for ch in self.imt[0]])
         assert_allclose(result, convert(ref), rtol=1.0e-5, atol=1.0e-8)
 def test_prob_k_spatial_axes(self):
     """prob=1.0 with seed 234 applies exactly one 90-degree rotation about axes (0, 1)."""
     xform = RandRotate90(prob=1.0, max_k=2, spatial_axes=(0, 1))
     for convert in TEST_NDARRAYS:
         xform.set_random_state(234)
         result = xform(convert(self.imt[0]))
         ref = np.stack([np.rot90(ch, 1, (0, 1)) for ch in self.imt[0]])
         assert_allclose(result, convert(ref), rtol=1.0e-5, atol=1.0e-8)
def run_test(batch_size=64, train_steps=200, device=torch.device("cuda:0")):
    """Train a small 2D UNet on synthetic image/segmentation pairs for one epoch.

    Returns a tuple ``(mean_epoch_loss, number_of_steps)``.
    """

    class _TestBatch(Dataset):
        """Produces a fresh synthetic (image, segmentation) pair on every access."""

        def __init__(self, transforms):
            self.transforms = transforms

        def __getitem__(self, _unused_id):
            img, lbl = create_test_image_2d(128, 128, noise_max=1, num_objs=4, num_seg_classes=1)
            # re-seed before each call so image and label get identical random transforms
            seed = np.random.randint(2147483647)
            self.transforms.set_random_state(seed=seed)
            img = self.transforms(img)
            self.transforms.set_random_state(seed=seed)
            lbl = self.transforms(lbl)
            return img, lbl

        def __len__(self):
            return train_steps

    net = UNet(
        dimensions=2, in_channels=1, out_channels=1, channels=(4, 8, 16, 32), strides=(2, 2, 2), num_res_units=2
    ).to(device)

    loss = DiceLoss(do_sigmoid=True)
    opt = torch.optim.Adam(net.parameters(), 1e-2)
    train_transforms = Compose(
        [AddChannel(), ScaleIntensity(), RandSpatialCrop((96, 96), random_size=False), RandRotate90(), ToTensor()]
    )

    src = DataLoader(_TestBatch(train_transforms), batch_size=batch_size, shuffle=True)

    # standard single-epoch training loop
    net.train()
    epoch_loss = 0
    step = 0
    for img, seg in src:
        step += 1
        opt.zero_grad()
        output = net(img.to(device))
        step_loss = loss(output, seg.to(device))
        step_loss.backward()
        opt.step()
        epoch_loss += step_loss.item()
    epoch_loss /= step

    return epoch_loss, step
Ejemplo n.º 4
0
 def test_prob_k_spatial_axes(self):
     """prob=1.0 with seed 234 yields a single 90-degree rotation about axes (0, 1)."""
     xform = RandRotate90(prob=1.0, max_k=2, spatial_axes=(0, 1))
     xform.set_random_state(234)
     result = xform(self.imt[0])
     ref = np.stack([np.rot90(ch, 1, (0, 1)) for ch in self.imt[0]])
     self.assertTrue(np.allclose(result, ref))
Ejemplo n.º 5
0
 def test_spatial_axes(self):
     """Default prob with seed 234 applies no rotation (k == 0)."""
     xform = RandRotate90(spatial_axes=(0, 1))
     xform.set_random_state(234)
     result = xform(self.imt[0])
     ref = np.stack([np.rot90(ch, 0, (0, 1)) for ch in self.imt[0]])
     self.assertTrue(np.allclose(result, ref))
Ejemplo n.º 6
0
 def test_default(self):
     """Default construction with seed 123: output equals input (k == 0)."""
     xform = RandRotate90()
     xform.set_random_state(123)
     result = xform(self.imt[0])
     ref = np.stack([np.rot90(ch, 0, (0, 1)) for ch in self.imt[0]])
     self.assertTrue(np.allclose(result, ref))
Ejemplo n.º 7
0
 def test_k(self):
     """max_k=2 with default prob and seed 234: output equals input (k == 0)."""
     xform = RandRotate90(max_k=2)
     xform.set_random_state(234)
     result = xform(self.imt[0])
     ref = np.stack([np.rot90(ch, 0, (0, 1)) for ch in self.imt[0]])
     self.assertTrue(np.allclose(result, ref))
Ejemplo n.º 8
0
 def test_default(self):
     """Seed 123 with default args: output equals input and the transform round-trips via inversion."""
     xform = RandRotate90()
     for convert in TEST_NDARRAYS_ALL:
         xform.set_random_state(123)
         img = convert(self.imt[0])
         result = xform(img)
         test_local_inversion(xform, result, img)
         ref = np.stack([np.rot90(ch, 0, (0, 1)) for ch in self.imt[0]])
         assert_allclose(result, convert(ref), rtol=1.0e-5, atol=1.0e-8, type_test="tensor")
Ejemplo n.º 9
0
 def test_prob_k_spatial_axes(self):
     """prob=1.0 with seed 234: one 90-degree rotation, and inversion round-trips."""
     xform = RandRotate90(prob=1.0, max_k=2, spatial_axes=(0, 1))
     for convert in TEST_NDARRAYS_ALL:
         xform.set_random_state(234)
         img = convert(self.imt[0])
         result = xform(img)
         test_local_inversion(xform, result, img)
         ref = np.stack([np.rot90(ch, 1, (0, 1)) for ch in self.imt[0]])
         assert_allclose(result, convert(ref), rtol=1.0e-5, atol=1.0e-8, type_test="tensor")
Ejemplo n.º 10
0
    def test_invert(self):
        """Run a chain of invertible spatial transforms on a synthetic label image,
        invert them with ``Invert``, and check shapes, the nearest-interpolation
        guarantee, and voxel agreement with the original volume.
        """
        set_determinism(seed=0)
        im_fname = make_nifti_image(create_test_image_3d(101, 100, 107, noise_max=100)[1])  # label image, discrete
        # the same file repeated 12 times; each pass samples different random transforms
        data = [im_fname for _ in range(12)]
        transform = Compose(
            [
                LoadImage(image_only=True),
                EnsureChannelFirst(),
                Orientation("RPS"),
                Spacing(pixdim=(1.2, 1.01, 0.9), mode="bilinear", dtype=np.float32),
                RandFlip(prob=0.5, spatial_axis=[1, 2]),
                RandAxisFlip(prob=0.5),
                RandRotate90(prob=0, spatial_axes=(1, 2)),
                RandZoom(prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
                RandRotate(prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True, dtype=np.float64),
                RandAffine(prob=0.5, rotate_range=np.pi, mode="nearest"),
                ResizeWithPadOrCrop(100),
                CastToType(dtype=torch.uint8),
            ]
        )

        # num workers = 0 for mac or gpu transforms
        num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2
        dataset = Dataset(data, transform=transform)
        self.assertIsInstance(transform.inverse(dataset[0]), MetaTensor)
        loader = DataLoader(dataset, num_workers=num_workers, batch_size=1)
        inverter = Invert(transform=transform, nearest_interp=True, device="cpu")

        for d in loader:
            d = decollate_batch(d)
            for item in d:
                orig = deepcopy(item)
                i = inverter(item)
                # forward pipeline pads/crops to a 100^3 volume
                self.assertTupleEqual(orig.shape[1:], (100, 100, 100))
                # check the nearest interpolation mode
                # NOTE(review): torch.testing.assert_allclose is deprecated in newer
                # torch releases in favor of assert_close — confirm the pinned version.
                torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
                # inversion restores the original spatial shape
                self.assertTupleEqual(i.shape[1:], (100, 101, 107))
        # check labels match
        # NOTE(review): `i` here is intentionally the inverted item from the
        # final loop iteration above.
        reverted = i.detach().cpu().numpy().astype(np.int32)
        original = LoadImage(image_only=True)(data[-1])
        n_good = np.sum(np.isclose(reverted, original.numpy(), atol=1e-3))
        reverted_name = i.meta["filename_or_obj"]
        original_name = original.meta["filename_or_obj"]
        self.assertEqual(reverted_name, original_name)
        print("invert diff", reverted.size - n_good)
        # lossy resampling is expected; only require that most voxels round-trip
        self.assertTrue((reverted.size - n_good) < 300000, f"diff. {reverted.size - n_good}")
        set_determinism(seed=None)
Ejemplo n.º 11
0
 def test_spatial_axes(self):
     """prob=1.0: exactly one op is recorded, output matches np.rot90 with the sampled k, inversion round-trips."""
     xform = RandRotate90(spatial_axes=(0, 1), prob=1.0)
     for convert in TEST_NDARRAYS_ALL:
         xform.set_random_state(1234)
         img = convert(self.imt[0])
         result = xform(img)
         self.assertEqual(len(result.applied_operations), 1)
         # compare against numpy using the k the transform actually sampled
         ref = np.stack([np.rot90(ch, xform._rand_k, (0, 1)) for ch in self.imt[0]])
         assert_allclose(result, convert(ref), rtol=1.0e-5, atol=1.0e-8, type_test="tensor")
         test_local_inversion(xform, result, img)
Ejemplo n.º 12
0
    def test_k(self):
        """With meta-tracking off the result is a plain tensor; with it on, the
        rotation matches np.rot90 and can be locally inverted."""
        xform = RandRotate90(max_k=2)
        for convert in TEST_NDARRAYS_ALL:
            img = convert(self.imt[0])
            set_track_meta(False)
            result = xform(img)
            self.assertNotIsInstance(result, MetaTensor)
            self.assertIsInstance(result, torch.Tensor)

            set_track_meta(True)
            xform.set_random_state(123)
            result = xform(img)
            test_local_inversion(xform, result, img)
            ref = np.stack([np.rot90(ch, 0, (0, 1)) for ch in self.imt[0]])
            assert_allclose(result, convert(ref), rtol=1.0e-5, atol=1.0e-8, type_test="tensor")
Ejemplo n.º 13
0
                                prob=1,
                                range_x=np.pi,
                                keep_size=False)))
# each case pairs the input container type with one randomized transform
TESTS.append((dict, RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False)))
TESTS.append((dict, RandRotate90d("image", prob=1, max_k=2)))

TESTS.append((list, RandSpatialCrop(roi_size=[8, 7], random_size=True)))
TESTS.append((list, RandRotate(prob=1, range_x=np.pi, keep_size=False)))
TESTS.append((list, RandZoom(prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False)))
TESTS.append((list, RandRotate90(prob=1, max_k=2)))


class _Dataset(torch.utils.data.Dataset):
    """Map-style dataset: applies `transforms` to each image and pairs it with its label."""

    def __init__(self, images, labels, transforms):
        self.images = images
        self.labels = labels
        self.transforms = transforms

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        transformed = self.transforms(self.images[index])
        return transformed, self.labels[index]
Ejemplo n.º 14
0
    # a dict-based composed pipeline followed by list-based single-image pipelines
    TESTS.append((dict, pad_collate, Compose([RandRotate90d("image", prob=1, max_k=2), ToTensord("image")])))

    TESTS.append((list, pad_collate, RandSpatialCrop(roi_size=[8, 7], random_size=True)))
    TESTS.append((list, pad_collate, RandRotate(prob=1, range_x=np.pi, keep_size=False)))
    TESTS.append((list, pad_collate, RandZoom(prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False)))
    TESTS.append((list, pad_collate, Compose([RandRotate90(prob=1, max_k=2), ToTensor()])))


class _Dataset(torch.utils.data.Dataset):
    """Wraps parallel image/label sequences; images are transformed on access."""

    def __init__(self, images, labels, transforms):
        self.images = images
        self.labels = labels
        self.transforms = transforms

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        item = self.images[index]
        return self.transforms(item), self.labels[index]
# image transforms: tensor conversion, channel dim, then a fixed-size random crop
train_imtrans = Compose([ToTensor(), AddChannel(), RandSpatialCrop((96, 96), random_size=False)])
#  RandRotate90(prob=0.5, spatial_axes=(0, 1)),
# AddChannel(),
# ToTensor(),
# ScaleIntensity(),
# AddChannel(),
# RandSpatialCrop((96, 96), random_size=False),
# LoadNifti(),
# segmentation transforms: load NIfTI, add channel, random 90-degree rotation, tensor conversion
train_segtrans = Compose([LoadNifti(), AddChannel(), RandRotate90(prob=0.5, spatial_axes=(0, 1)), ToTensor()])
# For testing
# datadir1 = "/home1/quanquan/datasets/lsw/benign_65/fpAML_55/slices/"
# image_files = np.array([x.path for x in os.scandir(datadir1+"image") if x.name.endswith(".npy")])
# label_files = np.array([x.path for x in os.scandir(datadir1+"label") if x.name.endswith(".npy")])

###  Data Collection for Kits19
datadir_kits = "/home1/quanquan/datasets/kits19/resampled_data"
# collect one imaging volume per "case_*" sub-directory (scandir order preserved)
image_files = np.array(
    [os.path.join(sub.path, "imaging.nii.gz") for sub in os.scandir(datadir_kits) if sub.name.startswith("case_")]
)
Ejemplo n.º 16
0
    def _define_training_transforms(self):
        """Define and initialize all training data transforms.

          * training set images transform
          * training set masks transform
          * validation set images transform
          * validation set masks transform
          * validation set images post-transform
          * test set images transform
          * test set masks transform
          * test set images post-transform
          * prediction set images transform
          * prediction set images post-transform

        The transforms are stored on `self`; nothing is returned.

        @raises RuntimeError if the mask type is unknown or unsupported.
        """

        if self._mask_type == MaskType.UNKNOWN:
            raise RuntimeError("The mask type is unknown. Cannot continue!")

        # Depending on the mask type, we will need to adapt the Mask Loader
        # and Transform. We start by initializing the most common types.
        MaskLoader = LoadMask(self._mask_type)
        # FIX: instantiate Identity; previously the bare class was stored, so
        # Compose would call the class on the data and produce an Identity
        # object instead of passing the mask through unchanged.
        MaskTransform = Identity()

        # Adapt the transform for the LABEL types
        if self._mask_type == MaskType.TIFF_LABELS or self._mask_type == MaskType.NUMPY_LABELS:
            MaskTransform = ToOneHot(num_classes=self._out_channels)

        # The H5_ONE_HOT type requires a different loader
        if self._mask_type == MaskType.H5_ONE_HOT:
            # MaskLoader: still missing
            raise RuntimeError("HDF5 one-hot masks are not supported yet!")

        # Define transforms for training: intensity scaling plus random
        # crop/rotation augmentation, identical spatial augmentation on masks.
        self._train_image_transforms = Compose(
            [
                LoadImage(image_only=True),
                ScaleIntensity(),
                AddChannel(),
                RandSpatialCrop(self._roi_size, random_size=False),
                RandRotate90(prob=0.5, spatial_axes=(0, 1)),
                ToTensor()
            ]
        )
        self._train_mask_transforms = Compose(
            [
                MaskLoader,
                MaskTransform,
                RandSpatialCrop(self._roi_size, random_size=False),
                RandRotate90(prob=0.5, spatial_axes=(0, 1)),
                ToTensor()
            ]
        )

        # Define transforms for validation (no random augmentation)
        self._validation_image_transforms = Compose(
            [
                LoadImage(image_only=True),
                ScaleIntensity(),
                AddChannel(),
                ToTensor()
            ]
        )
        self._validation_mask_transforms = Compose(
            [
                MaskLoader,
                MaskTransform,
                ToTensor()
            ]
        )

        # Define transforms for testing (same as validation)
        self._test_image_transforms = Compose(
            [
                LoadImage(image_only=True),
                ScaleIntensity(),
                AddChannel(),
                ToTensor()
            ]
        )
        self._test_mask_transforms = Compose(
            [
                MaskLoader,
                MaskTransform,
                ToTensor()
            ]
        )

        # Post transforms: softmax followed by thresholding to a discrete mask
        self._validation_post_transforms = Compose(
            [
                Activations(softmax=True),
                AsDiscrete(threshold_values=True)
            ]
        )

        self._test_post_transforms = Compose(
            [
                Activations(softmax=True),
                AsDiscrete(threshold_values=True)
            ]
        )
Ejemplo n.º 17
0
    '/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz',
    '/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz',
    '/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz',
    '/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz',
    '/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz',
    '/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz'
]
# 2 binary labels for gender classification: man and woman
labels = np.array([0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0])

# Define transforms: training adds a random 90-degree rotation; validation does not
train_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), RandRotate90(), ToTensor()])
val_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), ToTensor()])

# Define nifti dataset, data loader
check_ds = NiftiDataset(image_files=images, labels=labels, transform=train_transforms)
check_loader = DataLoader(check_ds,
                          batch_size=2,
                          num_workers=2,
Ejemplo n.º 18
0
def main():
    """Train DenseNet121 for binary gender classification on 20 IXI T1 volumes,
    validating every 2 epochs and saving the best model checkpoint.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    # NOTE(review): os.sep.join on these relative parts yields paths without a
    # leading separator — presumably resolved against the working directory; verify.
    images = [
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI314-IOP-0889-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI249-Guys-1072-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI609-HH-2600-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI173-HH-1590-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI020-Guys-0700-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI342-Guys-0909-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI134-Guys-0780-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI577-HH-2661-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI066-Guys-0731-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI130-HH-1528-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI607-Guys-1097-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI175-HH-1570-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI385-HH-2078-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI344-Guys-0905-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI409-Guys-0960-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI584-Guys-1129-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI253-HH-1694-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI092-HH-1436-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI574-IOP-1156-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI585-Guys-1130-T1.nii.gz"
        ]),
    ]

    # 2 binary labels for gender classification: man and woman
    labels = np.array(
        [0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0],
        dtype=np.int64)

    # Define transforms
    train_transforms = Compose([
        ScaleIntensity(),
        AddChannel(),
        Resize((96, 96, 96)),
        RandRotate90(),
        ToTensor()
    ])
    val_transforms = Compose(
        [ScaleIntensity(),
         AddChannel(),
         Resize((96, 96, 96)),
         ToTensor()])

    # Define image dataset, data loader
    check_ds = ImageDataset(image_files=images,
                            labels=labels,
                            transform=train_transforms)
    check_loader = DataLoader(check_ds,
                              batch_size=2,
                              num_workers=2,
                              pin_memory=torch.cuda.is_available())
    im, label = monai.utils.misc.first(check_loader)
    print(type(im), im.shape, label)

    # create a training data loader (first 10 subjects)
    train_ds = ImageDataset(image_files=images[:10],
                            labels=labels[:10],
                            transform=train_transforms)
    train_loader = DataLoader(train_ds,
                              batch_size=2,
                              shuffle=True,
                              num_workers=2,
                              pin_memory=torch.cuda.is_available())

    # create a validation data loader (last 10 subjects)
    val_ds = ImageDataset(image_files=images[-10:],
                          labels=labels[-10:],
                          transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=2,
                            num_workers=2,
                            pin_memory=torch.cuda.is_available())

    # Create DenseNet121, CrossEntropyLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.DenseNet121(spatial_dims=3,
                                            in_channels=1,
                                            out_channels=2).to(device)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), 1e-5)

    # start a typical PyTorch training
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    for epoch in range(5):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{5}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(),
                              epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            # validation: count correct argmax predictions over the val set
            model.eval()
            with torch.no_grad():
                num_correct = 0.0
                metric_count = 0
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(
                        device), val_data[1].to(device)
                    val_outputs = model(val_images)
                    value = torch.eq(val_outputs.argmax(dim=1), val_labels)
                    metric_count += len(value)
                    num_correct += value.sum().item()
                metric = num_correct / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(),
                               "best_metric_model_classification3d_array.pth")
                    print("saved new best metric model")
                print(
                    "current epoch: {} current accuracy: {:.4f} best accuracy: {:.4f} at epoch {}"
                    .format(epoch + 1, metric, best_metric, best_metric_epoch))
                writer.add_scalar("val_accuracy", metric, epoch + 1)
    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )
    writer.close()
Ejemplo n.º 19
0
def main():
    """Train DenseNet121 for binary gender classification on IXI volumes using an
    ignite supervised trainer with checkpointing, stats/TensorBoard handlers,
    periodic validation and early stopping.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    images = [
        "/workspace/data/medical/ixi/IXI-T1/IXI314-IOP-0889-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI249-Guys-1072-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI609-HH-2600-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI173-HH-1590-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI020-Guys-0700-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI342-Guys-0909-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI134-Guys-0780-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI577-HH-2661-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI066-Guys-0731-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI130-HH-1528-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI607-Guys-1097-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI175-HH-1570-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI385-HH-2078-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI344-Guys-0905-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI409-Guys-0960-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI584-Guys-1129-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI253-HH-1694-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI092-HH-1436-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI574-IOP-1156-T1.nii.gz",
        "/workspace/data/medical/ixi/IXI-T1/IXI585-Guys-1130-T1.nii.gz",
    ]
    # 2 binary labels for gender classification: man and woman
    labels = np.array([0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0])

    # define transforms
    train_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), RandRotate90(), ToTensor()])
    val_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), ToTensor()])

    # define nifti dataset, data loader
    check_ds = NiftiDataset(image_files=images, labels=labels, transform=train_transforms)
    check_loader = DataLoader(check_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())
    im, label = monai.utils.misc.first(check_loader)
    print(type(im), im.shape, label)

    # create DenseNet121, CrossEntropyLoss and Adam optimizer
    net = monai.networks.nets.densenet.densenet121(spatial_dims=3, in_channels=1, out_channels=2,)
    loss = torch.nn.CrossEntropyLoss()
    lr = 1e-5
    opt = torch.optim.Adam(net.parameters(), lr)
    device = torch.device("cuda:0")

    # ignite trainer expects batch=(img, label) and returns output=loss at every iteration,
    # user can add output_transform to return other values, like: y_pred, y, etc.
    trainer = create_supervised_trainer(net, opt, loss, device, False)

    # adding checkpoint handler to save models (network params and optimizer stats) during training
    checkpoint_handler = ModelCheckpoint("./runs/", "net", n_saved=10, require_empty=False)
    trainer.add_event_handler(
        event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={"net": net, "opt": opt}
    )

    # StatsHandler prints loss at every iteration and print metrics at every epoch,
    # we don't set metrics for trainer here, so just print loss, user can also customize print functions
    # and can use output_transform to convert engine.state.output if it's not loss value
    train_stats_handler = StatsHandler(name="trainer")
    train_stats_handler.attach(trainer)

    # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
    train_tensorboard_stats_handler = TensorBoardStatsHandler()
    train_tensorboard_stats_handler.attach(trainer)

    # set parameters for validation
    validation_every_n_epochs = 1

    metric_name = "Accuracy"
    # add evaluation metric to the evaluator engine
    val_metrics = {metric_name: Accuracy()}
    # ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
    # user can add output_transform to return other values
    evaluator = create_supervised_evaluator(net, val_metrics, device, True)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch,
    )  # fetch global epoch number from trainer
    val_stats_handler.attach(evaluator)

    # add handler to record metrics to TensorBoard at every epoch
    val_tensorboard_stats_handler = TensorBoardStatsHandler(
        output_transform=lambda x: None,  # no need to plot loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch,
    )  # fetch global epoch number from trainer
    val_tensorboard_stats_handler.attach(evaluator)

    # add early stopping handler to evaluator
    early_stopper = EarlyStopping(patience=4, score_function=stopping_fn_from_metric(metric_name), trainer=trainer)
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

    # create a validation data loader
    val_ds = NiftiDataset(image_files=images[-10:], labels=labels[-10:], transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())

    @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
    def run_validation(engine):
        # run the evaluator over the validation set after each training epoch
        evaluator.run(val_loader)

    # create a training data loader
    train_ds = NiftiDataset(image_files=images[:10], labels=labels[:10], transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())

    train_epochs = 30
    # final engine state kept for inspection; currently unused
    state = trainer.run(train_loader, train_epochs)
Ejemplo n.º 20
0
from monai.utils import optional_import, set_determinism
from monai.utils.enums import InverseKeys
from tests.utils import make_nifti_image

_, has_nib = optional_import("nibabel")

KEYS = ["image"]

# dictionary-based transform chains (each case is a tuple of transforms)
TESTS_DICT: List[Tuple] = [
    (SpatialPadd(KEYS, 150), RandFlipd(KEYS, prob=1.0, spatial_axis=1)),
    (RandRotate90d(KEYS, prob=0.0, max_k=1),),
    (RandAffined(KEYS, prob=0.0, translate_range=10),),
]

# array-based equivalents of the cases above
TESTS_LIST: List[Tuple] = [
    (SpatialPad(150), RandFlip(prob=1.0, spatial_axis=1)),
    (RandRotate90(prob=0.0, max_k=1),),
    (RandAffine(prob=0.0, translate_range=10),),
]


TEST_BASIC = [
    [("channel", "channel"), ["channel", "channel"]],
    [torch.Tensor([1, 2, 3]), [torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0)]],
    [
        [[torch.Tensor((1.0, 2.0, 3.0)), torch.Tensor((2.0, 3.0, 1.0))]],
        [
            [[torch.tensor(1.0), torch.tensor(2.0)]],
            [[torch.tensor(2.0), torch.tensor(3.0)]],
            [[torch.tensor(3.0), torch.tensor(1.0)]],
        ],
    ],
    [torch.Tensor((True, True, False, False)), [1.0, 1.0, 0.0, 0.0]],
Ejemplo n.º 21
0
def main():
    """Train a 3D UNet on synthetic NIfTI data and validate with mean Dice.

    Generates 40 random 3D image/segmentation pairs as NIfTI files in a
    temporary directory, trains for 5 epochs with Dice loss, runs sliding
    window validation every 2 epochs, saves the best-metric checkpoint and
    logs scalars/images to TensorBoard.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask pairs
    tempdir = tempfile.mkdtemp()
    print('generating synthetic data to {} (this may take a while)'.format(tempdir))
    for i in range(40):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

    # sorted() keeps the image and segmentation file lists aligned by index
    images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
    segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))

    # define transforms for image and segmentation
    # NOTE(review): image and seg pipelines each contain random crop/rotate;
    # NiftiDataset is presumably responsible for synchronizing the random
    # state so image and mask stay spatially aligned -- confirm.
    train_imtrans = Compose([
        ScaleIntensity(),
        AddChannel(),
        RandSpatialCrop((96, 96, 96), random_size=False),
        RandRotate90(prob=0.5, spatial_axes=(0, 2)),
        ToTensor()
    ])
    train_segtrans = Compose([
        AddChannel(),
        RandSpatialCrop((96, 96, 96), random_size=False),
        RandRotate90(prob=0.5, spatial_axes=(0, 2)),
        ToTensor()
    ])
    # validation uses deterministic transforms only (no crop/rotate)
    val_imtrans = Compose([
        ScaleIntensity(),
        AddChannel(),
        ToTensor()
    ])
    val_segtrans = Compose([
        AddChannel(),
        ToTensor()
    ])

    # define nifti dataset, data loader
    check_ds = NiftiDataset(images, segs, transform=train_imtrans, seg_transform=train_segtrans)
    check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
    # sanity-check one batch shape before training
    im, seg = monai.utils.misc.first(check_loader)
    print(im.shape, seg.shape)

    # create a training data loader (first 20 pairs)
    train_ds = NiftiDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans)
    train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=8, pin_memory=torch.cuda.is_available())
    # create a validation data loader (last 20 pairs)
    val_ds = NiftiDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, pin_memory=torch.cuda.is_available())

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device('cuda:0')
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(do_sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # start a typical PyTorch training
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    for epoch in range(5):
        print('-' * 10)
        print('epoch {}/{}'.format(epoch + 1, 5))
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print('{}/{}, train_loss: {:.4f}'.format(step, epoch_len, loss.item()))
            writer.add_scalar('train_loss', loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print('epoch {} average loss: {:.4f}'.format(epoch + 1, epoch_loss))

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.
                metric_count = 0
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                    # sliding window inference stitches 96^3 patches into a
                    # full-volume prediction
                    roi_size = (96, 96, 96)
                    sw_batch_size = 4
                    val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                    value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=True,
                                             to_onehot_y=False, add_sigmoid=True)
                    metric_count += len(value)
                    metric_sum += value.sum().item()
                # running mean dice over all validation items
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), 'best_metric_model.pth')
                    print('saved new best metric model')
                print('current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}'.format(
                    epoch + 1, metric, best_metric, best_metric_epoch))
                writer.add_scalar('val_mean_dice', metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag='image')
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag='label')
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag='output')
    # NOTE(review): tempdir is only removed on the happy path; an exception
    # above leaks the directory (a try/finally or TemporaryDirectory would fix)
    shutil.rmtree(tempdir)
    print('train completed, best_metric: {:.4f} at epoch: {}'.format(best_metric, best_metric_epoch))
    writer.close()
Ejemplo n.º 22
0
class Loader():
    """Loader for different image datasets with built in split function and download if needed.

    Functions:
        load_IXIT1: Loads the IXIT1 3D brain MRI dataset.
        load_MedNIST: Loads the MedNIST 2D image dataset.
    """

    # Default transform pipelines. These are evaluated once at class-definition
    # time and shared as default arguments of the loader methods below.
    ixi_train_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), RandRotate90()])
    ixi_test_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96))])

    mednist_train_transforms = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity(),
                                        RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
                                        RandFlip(spatial_axis=0, prob=0.5),
                                        RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5)])
    mednist_test_transforms = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity()])


    @staticmethod
    def load_IXIT1(download: bool = False, train_transforms: object = ixi_train_transforms,
                   test_transforms: object = ixi_test_transforms, test_size: float = 0.2,
                   val_size: float = 0.0, sample_size: float = 0.01, shuffle: bool = True):
        """Loads the IXIT1 3D Brain MRI dataset.

        Consists of ~566 images of 3D Brain MRI scans and labels (0) for male and (1) for female.

        Args:
            download (bool): If true, then data is downloaded before loading it as dataset.
            train_transforms (Compose): Specify the transformations to be applied to the training dataset.
            test_transforms (Compose): Specify the transformations to be applied to the test dataset.
            sample_size (float): Percentage of available images to be used.
            test_size (float): Percentage of sample to be used as test data.
            val_size (float): Percentage of sample to be used as validation data.
            shuffle (bool): Whether or not the data should be shuffled after loading.

        Returns:
            (train_ds, test_ds) when ``val_size == 0``, otherwise
            (train_ds, test_ds, val_ds).
        """
        # Data directory is needed both for downloading and for listing files,
        # so define it up front (mirrors load_MedNIST).
        data_dir = os.sep.join(['Data', 'IXI-T1'])

        # Download data if needed
        if download:
            data_url = 'http://biomedic.doc.ic.ac.uk/brain-development/downloads/IXI/IXI-T1.tar'
            compressed_file = os.sep.join(['Data', 'IXI-T1.tar'])

            # Data download; archive is extracted into data_dir
            monai.apps.download_and_extract(data_url, compressed_file, data_dir)

            # Labels document download
            labels_url = 'http://biomedic.doc.ic.ac.uk/brain-development/downloads/IXI/IXI.xls'
            monai.apps.download_url(labels_url, './Data/IXI.xls')

        # Get all the images and corresponding labels
        images = os.listdir(data_dir)

        df = pd.read_excel('./Data/IXI.xls')

        data = []
        labels = []
        for i in images:
            # File names embed the numeric IXI id at characters 3-5,
            # e.g. 'IXI002-...' -> 2
            ixi_id = int(i[3:6])
            row = df.loc[df['IXI_ID'] == ixi_id]
            if not row.empty:
                data.append(os.sep.join(['Data', 'IXI-T1', i]))
                labels.append(int(row.iat[0, 1] - 1))  # Sex labels are 1/2 but need to be 0/1

        # Keep only the requested fraction of the sample
        cut = int(len(data) * sample_size)
        data, labels = data[:cut], labels[:cut]

        # Make train test validation split
        train_data, train_labels, test_data, test_labels, val_data, val_labels = _split(data, labels,
                                                                                        test_size, val_size)

        # Construct and return Datasets
        train_ds = IXIT1Dataset(train_data, train_labels, train_transforms, shuffle)
        test_ds = IXIT1Dataset(test_data, test_labels, test_transforms, shuffle)

        if val_size == 0:
            return train_ds, test_ds
        else:
            val_ds = IXIT1Dataset(val_data, val_labels, test_transforms, shuffle)
            return train_ds, test_ds, val_ds


    @staticmethod
    def load_MedNIST(download: bool = False, train_transforms: object = mednist_train_transforms,
                   test_transforms: object = mednist_test_transforms, test_size: float = 0.2,
                   val_size: float = 0.0, sample_size: float = 0.01, shuffle: bool = True):
        """Loads the MedNIST 2D image dataset.

        Consists of ~60.000 2D images from 6 classes: AbdomenCT, BreastMRI, ChestCT, CXR, Hand, HeadCT.

        Args:
            download (bool): If true, then data is downloaded before loading it as dataset.
            train_transforms (Compose): Specify the transformations to be applied to the training dataset.
            test_transforms (Compose): Specify the transformations to be applied to the test dataset.
            sample_size (float): Percentage of available images to be used.
            test_size (float): Percentage of sample to be used as test data.
            val_size (float): Percentage of sample to be used as validation data.
            shuffle (bool): Whether or not the data should be shuffled after loading.

        Returns:
            (train_ds, test_ds) when ``val_size == 0``, otherwise
            (train_ds, test_ds, val_ds).
        """

        root_dir = './Data'
        resource = "https://www.dropbox.com/s/5wwskxctvcxiuea/MedNIST.tar.gz?dl=1"
        md5 = "0bc7306e7427e00ad1c5526a6677552d"

        compressed_file = os.path.join(root_dir, "MedNIST.tar.gz")
        data_dir = os.path.join(root_dir, "MedNIST")

        if download:
            monai.apps.download_and_extract(resource, compressed_file, root_dir, md5)

        # Reading image filenames from dataset folders and assigning labels:
        # each subdirectory of data_dir is one class.
        class_names = sorted(x for x in os.listdir(data_dir)
                             if os.path.isdir(os.path.join(data_dir, x)))
        num_class = len(class_names)

        image_files = [
            [
                os.path.join(data_dir, class_names[i], x)
                for x in os.listdir(os.path.join(data_dir, class_names[i]))
            ]
            for i in range(num_class)
        ]

        # Subsample each class independently so class balance is preserved
        image_files = [images[:int(len(images) * sample_size)] for images in image_files]

        # Constructing data and labels
        num_each = [len(image_files[i]) for i in range(num_class)]
        data = []
        labels = []

        for i in range(num_class):
            data.extend(image_files[i])
            labels.extend([int(i)] * num_each[i])

        if shuffle:
            # NOTE(review): seeds the global NumPy RNG as a side effect;
            # a local Generator would avoid leaking state to callers.
            np.random.seed(42)
            indices = np.arange(len(data))
            np.random.shuffle(indices)

            data = [data[i] for i in indices]
            labels = [labels[i] for i in indices]

        # Make train test validation split
        train_data, train_labels, test_data, test_labels, val_data, val_labels = _split(data, labels,
                                                                                        test_size, val_size)

        # Construct and return datasets
        train_ds = MedNISTDataset(train_data, train_labels, train_transforms, shuffle)
        test_ds = MedNISTDataset(test_data, test_labels, test_transforms, shuffle)

        if val_size == 0:
            return train_ds, test_ds
        else:
            val_ds = MedNISTDataset(val_data, val_labels, test_transforms, shuffle)
            return train_ds, test_ds, val_ds
Ejemplo n.º 23
0
                                             N_valid_per_magn=4,
                                             is_val_split=is_val_split)

    # data preprocessing/augmentation
    trans_train = MozartTheComposer([

        #ScaleIntensity(),
        #             AddChannel(),
        #             RandSpatialCrop(roi_size=256, random_size=False),
        #CenterSpatialCrop(roi_size=2154),  # 2154
        #             RandScaleIntensity(factors=0.25, prob=aug_prob),
        RandRotate(range_x=15,
                   prob=aug_prob,
                   keep_size=True,
                   padding_mode="reflection"),
        RandRotate90(prob=aug_prob, spatial_axes=(1, 2)),
        RandFlip(spatial_axis=(1, 2), prob=aug_prob),
        ToTensor()
    ])

    trans_val = MozartTheComposer([
        #         LoadImage(PILReader(), image_only=True),
        #ScaleIntensity(),
        #         AddChannel(),
        #         RandSpatialCrop(roi_size=256, random_size=False),
        #CenterSpatialCrop(roi_size=2154),
        ToTensor()
    ])

    # create dataset class
    train_dataset = OurDataset(data=train_split,
Ejemplo n.º 24
0
                            prob=1,
                            min_zoom=1.1,
                            max_zoom=2.0,
                            keep_size=False)))
    TESTS.append((dict, pad_collate, RandRotate90d("image", prob=1, max_k=2)))

    TESTS.append(
        (list, pad_collate, RandSpatialCrop(roi_size=[8, 7],
                                            random_size=True)))
    TESTS.append(
        (list, pad_collate, RandRotate(prob=1, range_x=np.pi,
                                       keep_size=False)))
    TESTS.append((list, pad_collate,
                  RandZoom(prob=1, min_zoom=1.1, max_zoom=2.0,
                           keep_size=False)))
    TESTS.append((list, pad_collate, RandRotate90(prob=1, max_k=2)))


class _Dataset(torch.utils.data.Dataset):
    def __init__(self, images, labels, transforms):
        self.images = images
        self.labels = labels
        self.transforms = transforms

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        return self.transforms(self.images[index]), self.labels[index]

Ejemplo n.º 25
0
def main(tempdir):
    """Train a 2D UNet on synthetic PNG data and validate with mean Dice.

    Writes 40 random 2D image/segmentation pairs into ``tempdir``, trains for
    10 epochs with Dice loss, validates every 2 epochs via sliding window
    inference, saves the best-metric checkpoint and logs scalars/images to
    TensorBoard.

    Args:
        tempdir: existing directory that receives the generated PNG files
            (caller owns its lifetime/cleanup).
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask pairs
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        Image.fromarray(im.astype("uint8")).save(
            os.path.join(tempdir, f"img{i:d}.png"))
        Image.fromarray(seg.astype("uint8")).save(
            os.path.join(tempdir, f"seg{i:d}.png"))

    # sorted() keeps img/seg file lists aligned by index
    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))
    # NOTE(review): train_files/val_files are built but never used below --
    # the ArrayDatasets are constructed from `images`/`segs` directly.
    train_files = [{
        "img": img,
        "seg": seg
    } for img, seg in zip(images[:20], segs[:20])]
    val_files = [{
        "img": img,
        "seg": seg
    } for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    # NOTE(review): image and seg pipelines each contain random crop/rotate;
    # ArrayDataset is presumably responsible for synchronizing their random
    # state so image and mask stay aligned -- confirm.
    train_imtrans = Compose([
        LoadImage(image_only=True),
        ScaleIntensity(),
        AddChannel(),
        RandSpatialCrop((96, 96), random_size=False),
        RandRotate90(prob=0.5, spatial_axes=(0, 1)),
        ToTensor(),
    ])
    train_segtrans = Compose([
        LoadImage(image_only=True),
        AddChannel(),
        RandSpatialCrop((96, 96), random_size=False),
        RandRotate90(prob=0.5, spatial_axes=(0, 1)),
        ToTensor(),
    ])
    # validation uses deterministic transforms only (no crop/rotate)
    val_imtrans = Compose([
        LoadImage(image_only=True),
        ScaleIntensity(),
        AddChannel(),
        ToTensor()
    ])
    val_segtrans = Compose(
        [LoadImage(image_only=True),
         AddChannel(), ToTensor()])

    # define array dataset, data loader
    check_ds = ArrayDataset(images, train_imtrans, segs, train_segtrans)
    check_loader = DataLoader(check_ds,
                              batch_size=10,
                              num_workers=2,
                              pin_memory=torch.cuda.is_available())
    # sanity-check one batch shape before training
    im, seg = monai.utils.misc.first(check_loader)
    print(im.shape, seg.shape)

    # create a training data loader (first 20 pairs)
    train_ds = ArrayDataset(images[:20], train_imtrans, segs[:20],
                            train_segtrans)
    train_loader = DataLoader(train_ds,
                              batch_size=4,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=torch.cuda.is_available())
    # create a validation data loader (last 20 pairs)
    val_ds = ArrayDataset(images[-20:], val_imtrans, segs[-20:], val_segtrans)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=4,
                            pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, reduction="mean")
    # post-processing: sigmoid activation then binarization before Dice
    post_trans = Compose(
        [Activations(sigmoid=True),
         AsDiscrete(threshold_values=True)])
    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # start a typical PyTorch training
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    for epoch in range(10):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{10}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(),
                              epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.0
                metric_count = 0
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(
                        device), val_data[1].to(device)
                    # sliding window inference stitches 96x96 patches into a
                    # full-image prediction
                    roi_size = (96, 96)
                    sw_batch_size = 4
                    val_outputs = sliding_window_inference(
                        val_images, roi_size, sw_batch_size, model)
                    val_outputs = post_trans(val_outputs)
                    value, _ = dice_metric(y_pred=val_outputs, y=val_labels)
                    # value is presumably the batch-mean dice; multiplying by
                    # len(value) recovers the per-batch sum -- confirm against
                    # DiceMetric's return contract.
                    metric_count += len(value)
                    metric_sum += value.item() * len(value)
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(),
                               "best_metric_model_segmentation2d_array.pth")
                    print("saved new best metric model")
                print(
                    "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}"
                    .format(epoch + 1, metric, best_metric, best_metric_epoch))
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="image")
                plot_2d_or_3d_image(val_labels,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="label")
                plot_2d_or_3d_image(val_outputs,
                                    epoch + 1,
                                    writer,
                                    index=0,
                                    tag="output")

    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )
    writer.close()
Ejemplo n.º 26
0
    def _define_transforms(self):
        """Define and initialize all data transforms.

          * training set images transform
          * training set targets transform
          * validation set images transform
          * validation set targets transform
          * validation set images post-transform
          * test set images transform
          * test set targets transform
          * test set images post-transform
          * prediction set images transform
          * prediction set images post-transform

        @return True if data transforms could be instantiated, False otherwise.
        """
        # Define transforms for training
        self._train_image_transforms = Compose([
            LoadImage(image_only=True),
            ScaleIntensityRange(0, 65535, 0.0, 1.0, clip=False),
            AddChannel(),
            RandSpatialCrop(self._roi_size, random_size=False),
            RandRotate90(prob=0.5, spatial_axes=(0, 1)),
            ToTensor()
        ])
        self._train_target_transforms = Compose([
            LoadImage(image_only=True),
            ScaleIntensityRange(0, 65535, 0.0, 1.0, clip=False),
            AddChannel(),
            RandSpatialCrop(self._roi_size, random_size=False),
            RandRotate90(prob=0.5, spatial_axes=(0, 1)),
            ToTensor()
        ])

        # Define transforms for validation
        self._validation_image_transforms = Compose([
            LoadImage(image_only=True),
            ScaleIntensityRange(0, 65535, 0.0, 1.0, clip=False),
            AddChannel(),
            ToTensor()
        ])
        self._validation_target_transforms = Compose([
            LoadImage(image_only=True),
            ScaleIntensityRange(0, 65535, 0.0, 1.0, clip=False),
            AddChannel(),
            ToTensor()
        ])

        # Define transforms for testing
        self._test_image_transforms = Compose([
            LoadImage(image_only=True),
            ScaleIntensityRange(0, 65535, 0.0, 1.0, clip=False),
            AddChannel(),
            ToTensor()
        ])
        self._test_target_transforms = Compose([
            LoadImage(image_only=True),
            ScaleIntensityRange(0, 65535, 0.0, 1.0, clip=False),
            AddChannel(),
            ToTensor()
        ])

        # Define transforms for prediction
        self._prediction_image_transforms = Compose(
            [LoadImage(image_only=True),
             AddChannel(),
             ToTensor()])

        # Post transforms
        self._validation_post_transforms = Compose([Identity()])

        self._test_post_transforms = Compose(
            [ToNumpy(), ScaleIntensity(0, 65535)])

        self._prediction_post_transforms = Compose(
            [ToNumpy(), ScaleIntensity(0, 65535)])