def test_plot(self):
        set_determinism(0)

        testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data")

        net = torch.nn.Conv2d(1, 1, 3, padding=1)

        opt = torch.optim.Adam(net.parameters())

        img = torch.rand(1, 16, 16)
        data = {CommonKeys.IMAGE: img, CommonKeys.LABEL: img}
        loader = DataLoader([data for _ in range(10)])

        trainer = SupervisedTrainer(
            device=torch.device("cpu"),
            max_epochs=1,
            train_data_loader=loader,
            network=net,
            optimizer=opt,
            loss_function=torch.nn.L1Loss(),
        )

        logger = MetricLogger()
        logger.attach(trainer)

        con = ThreadContainer(trainer)
        con.start()
        con.join()

        fig = con.plot_status(logger)

        with tempfile.TemporaryDirectory() as tempdir:
            tempimg = f"{tempdir}/threadcontainer_plot_test.png"
            fig.savefig(tempimg)
            comp = compare_images(f"{testing_dir}/threadcontainer_plot_test.png", tempimg, 1e-3)

            self.assertIsNone(comp, comp)  # None indicates test passed
Example #2

    def test_training(self):
        """
        check the quality of AffineTransform backpropagation
        """
        atol = 1e-5
        set_determinism(seed=0)
        out_ref, loss_ref, init_loss_ref = compare_2d(True, self.device)
        print(out_ref.shape, loss_ref, init_loss_ref)

        set_determinism(seed=0)
        out, loss, init_loss = compare_2d(False, self.device)
        print(out.shape, loss, init_loss)
        np.testing.assert_allclose(out_ref, out, atol=atol)
        np.testing.assert_allclose(init_loss_ref, init_loss, atol=atol)
        np.testing.assert_allclose(loss_ref, loss, atol=atol)

        set_determinism(seed=0)
        out, loss, init_loss = compare_2d(False, self.device, True)
        print(out.shape, loss, init_loss)
        np.testing.assert_allclose(out_ref, out, atol=atol)
        np.testing.assert_allclose(init_loss_ref, init_loss, atol=atol)
        np.testing.assert_allclose(loss_ref, loss, atol=atol)
Example #3
    def test_pickle(self):
        set_determinism(0)
        data1 = np.random.rand(10)
        data2 = np.random.rand(10)
        set_determinism(0)
        data3 = np.random.rand(10)
        data4 = np.random.rand(10)
        set_determinism(None)

        h1 = pickle_hashing(data1)
        h2 = pickle_hashing(data3)
        self.assertEqual(h1, h2)

        data_dict1 = {"b": data2, "a": data1}
        data_dict2 = {"a": data3, "b": data4}

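        # the two dicts hold equal arrays inserted in different key orders;
        # pickle_hashing is expected to be invariant to dict insertion order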
        h1 = pickle_hashing(data_dict1)
        h2 = pickle_hashing(data_dict2)
        self.assertEqual(h1, h2)

        with self.assertRaises(TypeError):
            json_hashing(data_dict1)
Example #4
 def tearDown(self):
     set_determinism(seed=None)
Example #5
    def test_invert(self):
        set_determinism(seed=0)
        im_fname, seg_fname = (
            make_nifti_image(i)
            for i in create_test_image_3d(101, 100, 107, noise_max=100))
        transform = Compose([
            LoadImaged(KEYS),
            AddChanneld(KEYS),
            Orientationd(KEYS, "RPS"),
            Spacingd(KEYS,
                     pixdim=(1.2, 1.01, 0.9),
                     mode=["bilinear", "nearest"],
                     dtype=np.float32),
            ScaleIntensityd("image", minv=1, maxv=10),
            RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
            RandAxisFlipd(KEYS, prob=0.5),
            RandRotate90d(KEYS, spatial_axes=(1, 2)),
            RandZoomd(KEYS,
                      prob=0.5,
                      min_zoom=0.5,
                      max_zoom=1.1,
                      keep_size=True),
            RandRotated(KEYS,
                        prob=0.5,
                        range_x=np.pi,
                        mode="bilinear",
                        align_corners=True,
                        dtype=np.float64),
            RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
            ResizeWithPadOrCropd(KEYS, 100),
            # test EnsureTensor for complicated dict data and invert it
            CopyItemsd(PostFix.meta("image"), times=1, names="test_dict"),
            # test to support Tensor, Numpy array and dictionary when inverting
            EnsureTyped(keys=["image", "test_dict"]),
            ToTensord("image"),
            CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
            CopyItemsd("label",
                       times=2,
                       names=["label_inverted", "label_inverted1"]),
            CopyItemsd("image",
                       times=2,
                       names=["image_inverted", "image_inverted1"]),
        ])
        data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

        # num workers = 0 for mac or gpu transforms
        num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2

        dataset = CacheDataset(data, transform=transform, progress=False)
        loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)
        inverter = Invertd(
            # `image` was not copied, invert the original value directly
            keys=["image_inverted", "label_inverted", "test_dict"],
            transform=transform,
            orig_keys=["label", "label", "test_dict"],
            meta_keys=[
                PostFix.meta("image_inverted"),
                PostFix.meta("label_inverted"), None
            ],
            orig_meta_keys=[
                PostFix.meta("label"),
                PostFix.meta("label"), None
            ],
            nearest_interp=True,
            to_tensor=[True, False, False],
            device="cpu",
        )

        inverter_1 = Invertd(
            # `image` was not copied, invert the original value directly
            keys=["image_inverted1", "label_inverted1"],
            transform=transform,
            orig_keys=["image", "image"],
            meta_keys=[
                PostFix.meta("image_inverted1"),
                PostFix.meta("label_inverted1")
            ],
            orig_meta_keys=[PostFix.meta("image"),
                            PostFix.meta("image")],
            nearest_interp=[True, False],
            to_tensor=[True, True],
            device="cpu",
        )

        expected_keys = [
            "image",
            "image_inverted",
            "image_inverted1",
            PostFix.meta("image_inverted1"),
            PostFix.meta("image_inverted"),
            PostFix.meta("image"),
            "image_transforms",
            "label",
            "label_inverted",
            "label_inverted1",
            PostFix.meta("label_inverted1"),
            PostFix.meta("label_inverted"),
            PostFix.meta("label"),
            "label_transforms",
            "test_dict",
            "test_dict_transforms",
        ]
        # execute 1 epoch
        for d in loader:
            d = decollate_batch(d)
            for item in d:
                item = inverter(item)
                item = inverter_1(item)

                self.assertListEqual(sorted(item), expected_keys)
                self.assertTupleEqual(item["image"].shape[1:], (100, 100, 100))
                self.assertTupleEqual(item["label"].shape[1:], (100, 100, 100))
                # check the nearest interpolation mode
                i = item["image_inverted"]
                torch.testing.assert_allclose(
                    i.to(torch.uint8).to(torch.float), i.to(torch.float))
                self.assertTupleEqual(i.shape[1:], (100, 101, 107))
                i = item["label_inverted"]
                torch.testing.assert_allclose(
                    i.to(torch.uint8).to(torch.float), i.to(torch.float))
                self.assertTupleEqual(i.shape[1:], (100, 101, 107))
                # test inverted test_dict
                self.assertTrue(
                    isinstance(item["test_dict"]["affine"], np.ndarray))
                self.assertTrue(
                    isinstance(item["test_dict"]["filename_or_obj"], str))

                # check the case that different items use different interpolation mode to invert transforms
                d = item["image_inverted1"]
                # if the interpolation mode is nearest, accumulated diff should be smaller than 1
                self.assertLess(
                    torch.sum(
                        d.to(torch.float) -
                        d.to(torch.uint8).to(torch.float)).item(), 1.0)
                self.assertTupleEqual(d.shape, (1, 100, 101, 107))

                d = item["label_inverted1"]
                # if the interpolation mode is not nearest, accumulated diff should be greater than 10000
                self.assertGreater(
                    torch.sum(
                        d.to(torch.float) -
                        d.to(torch.uint8).to(torch.float)).item(), 10000.0)
                self.assertTupleEqual(d.shape, (1, 100, 101, 107))

        # check labels match
        reverted = item["label_inverted"].detach().cpu().numpy().astype(
            np.int32)
        original = LoadImaged(KEYS)(data[-1])["label"]
        n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
        reverted_name = item[PostFix.meta("label_inverted")]["filename_or_obj"]
        original_name = data[-1]["label"]
        self.assertEqual(reverted_name, original_name)
        print("invert diff", reverted.size - n_good)
        # 34007: 2 workers (cpu, non-macos)
        # 1812: 0 workers (gpu or macos)
        # 1821: windows torch 1.10.0
        self.assertTrue((reverted.size - n_good) in (34007, 1812, 1821),
                        f"diff.  {reverted.size - n_good}")

        set_determinism(seed=None)
                                             "*.nii.gz")))
train_labels = sorted(glob.glob(os.path.join(data_dir, "labelsTr",
                                             "*.nii.gz")))
data_dicts = [{
    "image": image_name,
    "label": label_name
} for image_name, label_name in zip(train_images, train_labels)]
#n = len(data_dicts)
#train_files, val_files = data_dicts[:-3], data_dicts[-3:]
#train_files, val_files = data_dicts[:int(n*0.8)], data_dicts[int(n*0.2):]

val_files, train_files, test_files = data_dicts[0:8], data_dicts[8:40], data_dicts[40:50]
"""## Set deterministic training for reproducibility"""

set_determinism(seed=0)
"""## Setup transforms for training and validation

Here we use several transforms to augment the dataset:
1. `LoadImaged` loads the spleen CT images and labels from NIfTI format files.
1. `AddChanneld` adds a channel dimension, since the original data doesn't have one, to construct a "channel first" shape.
1. `Spacingd` adjusts the spacing to `pixdim=(1.5, 1.5, 2.)` based on the affine matrix.
1. `Orientationd` unifies the data orientation based on the affine matrix.
1. `ScaleIntensityRanged` clips intensities to the range [-57, 164] and scales them to [0, 1].
1. `CropForegroundd` removes all-zero borders to focus on the valid body area of the images and labels.
1. `RandCropByPosNegLabeld` randomly crops patch samples from the large image based on the positive / negative ratio.  
The image centers of negative samples must be in the valid body area.
1. `RandAffined` efficiently performs `rotate`, `scale`, `shear`, `translate`, etc. together based on the PyTorch affine transform.
1. `ToTensord` converts the numpy array to a PyTorch tensor for further steps.
"""
Example #7

 def test_compute(self):
     set_determinism(123)
     self._compute()
Example #8

 def setUp(self):
     set_determinism(seed=1234)
Example #9

 def setUp(self) -> None:
     set_determinism(seed=0)
Example #10
def train(args):
    # load hyper parameters
    task_id = args.task_id
    fold = args.fold
    val_output_dir = "./runs_{}_fold{}_{}/".format(task_id, fold,
                                                   args.expr_name)
    log_filename = "nnunet_task{}_fold{}.log".format(task_id, fold)
    log_filename = os.path.join(val_output_dir, log_filename)
    interval = args.interval
    learning_rate = args.learning_rate
    max_epochs = args.max_epochs
    multi_gpu_flag = args.multi_gpu
    amp_flag = args.amp
    lr_decay_flag = args.lr_decay
    sw_batch_size = args.sw_batch_size
    tta_val = args.tta_val
    batch_dice = args.batch_dice
    window_mode = args.window_mode
    eval_overlap = args.eval_overlap
    local_rank = args.local_rank
    determinism_flag = args.determinism_flag
    determinism_seed = args.determinism_seed
    if determinism_flag:
        set_determinism(seed=determinism_seed)
        if local_rank == 0:
            print("Using deterministic training.")

    # transforms
    train_batch_size = data_loader_params[task_id]["batch_size"]
    if multi_gpu_flag:
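        # init_method="env://" reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE
        # from environment variables set by the distributed launcher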
        dist.init_process_group(backend="nccl", init_method="env://")

        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cuda")

    properties, val_loader = get_data(args, mode="validation")
    _, train_loader = get_data(args, batch_size=train_batch_size, mode="train")

    # produce the network
    checkpoint = args.checkpoint
    net = get_network(properties, task_id, val_output_dir, checkpoint)
    net = net.to(device)

    if multi_gpu_flag:
        net = DistributedDataParallel(module=net,
                                      device_ids=[device],
                                      find_unused_parameters=True)

    optimizer = torch.optim.SGD(
        net.parameters(),
        lr=learning_rate,
        momentum=0.99,
        weight_decay=3e-5,
        nesterov=True,
    )

    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda epoch: (1 - epoch / max_epochs)**0.9)
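    # polynomial ("poly") decay as in nnU-Net: the factor is 1.0 at epoch 0,
    # about 0.54 at the halfway point, and approaches 0 as epoch -> max_epochs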
    # produce evaluator
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        CheckpointSaver(save_dir=val_output_dir,
                        save_dict={"net": net},
                        save_key_metric=True),
    ]

    evaluator = DynUNetEvaluator(
        device=device,
        val_data_loader=val_loader,
        network=net,
        n_classes=len(properties["labels"]),
        inferer=SlidingWindowInferer(
            roi_size=patch_size[task_id],
            sw_batch_size=sw_batch_size,
            overlap=eval_overlap,
            mode=window_mode,
        ),
        post_transform=None,
        key_val_metric={
            "val_mean_dice":
            MeanDice(
                include_background=False,
                output_transform=lambda x: (x["pred"], x["label"]),
            )
        },
        val_handlers=val_handlers,
        amp=amp_flag,
        tta_val=tta_val,
    )
    # produce trainer
    loss = DiceCELoss(to_onehot_y=True, softmax=True, batch=batch_dice)
    train_handlers = []
    if lr_decay_flag:
        train_handlers += [
            LrScheduleHandler(lr_scheduler=scheduler, print_lr=True)
        ]

    train_handlers += [
        ValidationHandler(validator=evaluator,
                          interval=interval,
                          epoch_level=True),
        StatsHandler(tag_name="train_loss",
                     output_transform=lambda x: x["loss"]),
    ]

    trainer = DynUNetTrainer(
        device=device,
        max_epochs=max_epochs,
        train_data_loader=train_loader,
        network=net,
        optimizer=optimizer,
        loss_function=loss,
        inferer=SimpleInferer(),
        post_transform=None,
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=amp_flag,
    )

    # run
    logger = logging.getLogger()

    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s")

    # Setup file handler
    fhandler = logging.FileHandler(log_filename)
    fhandler.setLevel(logging.INFO)
    fhandler.setFormatter(formatter)

    # Configure stream handler for the cells
    chandler = logging.StreamHandler()
    chandler.setLevel(logging.INFO)
    chandler.setFormatter(formatter)

    # Add both handlers
    if local_rank == 0:
        logger.addHandler(fhandler)
        logger.addHandler(chandler)
        logger.setLevel(logging.INFO)

    trainer.run()
Example #11
    def setUp(self) -> None:
        set_determinism(seed=0)

        im = create_test_image_2d(100, 101)[0]
        self.data_dict = [{"image": make_nifti_image(im) if has_nib else im} for _ in range(6)]
        self.data_list = [make_nifti_image(im) if has_nib else im for _ in range(6)]
Example #12
    def test_training(self):
        repeated = []
        test_rounds = 3 if monai.config.get_torch_version_tuple() >= (1, 6) else 2
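        # the third round (i == 2) repeats training with AMP enabled, which requires PyTorch >= 1.6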
        for i in range(test_rounds):
            set_determinism(seed=0)

            repeated.append([])
            best_metric = run_training_test(self.data_dir,
                                            device=self.device,
                                            amp=(i == 2))
            print("best metric", best_metric)
            if i == 2:
                self.assertTrue(
                    test_integration_value(TASK,
                                           key="best_metric_2",
                                           data=best_metric,
                                           rtol=1e-2))
            else:
                self.assertTrue(
                    test_integration_value(TASK,
                                           key="best_metric",
                                           data=best_metric,
                                           rtol=1e-2))
            repeated[i].append(best_metric)

            model_file = sorted(
                glob(os.path.join(self.data_dir, "net_key_metric*.pt")))[-1]
            infer_metric = run_inference_test(self.data_dir,
                                              model_file,
                                              device=self.device,
                                              amp=(i == 2))
            print("infer metric", infer_metric)
            # check inference properties
            if i == 2:
                self.assertTrue(
                    test_integration_value(TASK,
                                           key="infer_metric_2",
                                           data=infer_metric,
                                           rtol=1e-2))
            else:
                self.assertTrue(
                    test_integration_value(TASK,
                                           key="infer_metric",
                                           data=infer_metric,
                                           rtol=1e-2))
            repeated[i].append(infer_metric)

            output_files = sorted(
                glob(os.path.join(self.data_dir, "img*", "*.nii.gz")))
            for output in output_files:
                ave = np.mean(nib.load(output).get_fdata())
                repeated[i].append(ave)
            if i == 2:
                self.assertTrue(
                    test_integration_value(TASK,
                                           key="output_sums_2",
                                           data=repeated[i][2:],
                                           rtol=1e-2))
            else:
                self.assertTrue(
                    test_integration_value(TASK,
                                           key="output_sums",
                                           data=repeated[i][2:],
                                           rtol=1e-2))
        np.testing.assert_allclose(repeated[0], repeated[1])
Example #13
 def test_value(self, input_param, input_data, expected_value):
     set_determinism(seed=0)
     result = TorchVision(**input_param)(input_data)
     torch.testing.assert_allclose(result, expected_value)
Example #14
 def test_values(self):
     # check system default flags
     set_determinism(None)
     self.assertTrue(not torch.backends.cudnn.deterministic)
     self.assertTrue(get_seed() is None)
     # set default seed
     set_determinism()
     self.assertTrue(get_seed() is not None)
     self.assertTrue(torch.backends.cudnn.deterministic)
     self.assertTrue(not torch.backends.cudnn.benchmark)
     # resume default
     set_determinism(None)
     self.assertTrue(not torch.backends.cudnn.deterministic)
     self.assertTrue(not torch.backends.cudnn.benchmark)
     self.assertTrue(get_seed() is None)
     # test seeds
     seed = 255
     set_determinism(seed=seed)
     self.assertEqual(seed, get_seed())
     a = np.random.randint(seed)
     b = torch.randint(seed, (1,))
     set_determinism(seed=seed)
     c = np.random.randint(seed)
     d = torch.randint(seed, (1,))
     self.assertEqual(a, c)
     self.assertEqual(b, d)
     self.assertTrue(torch.backends.cudnn.deterministic)
     self.assertTrue(not torch.backends.cudnn.benchmark)
     set_determinism(seed=None)
Example #15
 def prepare_data(self):
     # set deterministic training for reproducibility
     set_determinism(seed=0)
Example #16
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest
from unittest import skipUnless

import numpy as np
from numpy.testing import assert_array_equal
from parameterized import parameterized

from monai.data import SlidingPatchWSIDataset
from monai.utils import WSIPatchKeys, optional_import, set_determinism
from tests.utils import download_url_or_skip_test, testing_data_config

set_determinism(0)

cucim, has_cucim = optional_import("cucim")
has_cucim = has_cucim and hasattr(cucim, "CuImage")
openslide, has_osl = optional_import("openslide")
imwrite, has_tiff = optional_import("tifffile", name="imwrite")
_, has_codec = optional_import("imagecodecs")
has_tiff = has_tiff and has_codec

FILE_KEY = "wsi_img"
FILE_URL = testing_data_config("images", FILE_KEY, "url")
base_name, extension = os.path.basename(f"{FILE_URL}"), ".tiff"
FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data",
                         "temp_" + base_name + extension)

FILE_PATH_SMALL_0 = os.path.join(os.path.dirname(__file__), "testing_data",
Example #17
 def tearDown(self):
     set_determinism(seed=None)
     if os.path.exists(self.img_name):
         os.remove(self.img_name)
     if os.path.exists(self.seg_name):
         os.remove(self.seg_name)
Example #18
 def tearDown(self):
     set_determinism(seed=None)
     os.remove(os.path.join(self.data_dir, "best_metric_model.pth"))
Example #19
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-d",
                        "--dir",
                        default="./testdata",
                        type=str,
                        help="directory of Brain Tumor dataset.")
    # must parse the command-line argument: ``--local_rank=LOCAL_PROCESS_RANK``, which will be provided by DDP
    parser.add_argument("--local_rank",
                        type=int,
                        help="node rank for distributed training")
    parser.add_argument("-j",
                        "--workers",
                        default=1,
                        type=int,
                        metavar="N",
                        help="number of data loading workers (default: 1)")
    parser.add_argument("--epochs",
                        default=90,
                        type=int,
                        metavar="N",
                        help="number of total epochs to run")
    parser.add_argument("--lr", default=1e-4, type=float, help="learning rate")
    parser.add_argument(
        "-b",
        "--batch_size",
        default=4,
        type=int,
        metavar="N",
        help="mini-batch size (default: 256), this is the total "
        "batch size of all GPUs on the current node when "
        "using Data Parallel or Distributed Data Parallel",
    )
    parser.add_argument("-p",
                        "--print_freq",
                        default=10,
                        type=int,
                        metavar="N",
                        help="print frequency (default: 10)")
    parser.add_argument("-e",
                        "--evaluate",
                        dest="evaluate",
                        action="store_true",
                        help="evaluate model on validation set")
    parser.add_argument("--seed",
                        default=None,
                        type=int,
                        help="seed for initializing training.")
    parser.add_argument("--cache_rate", type=float, default=1.0)
    parser.add_argument("--val_interval", type=int, default=5)
    parser.add_argument("--network",
                        type=str,
                        default="UNet",
                        choices=["UNet", "SegResNet"])
    parser.add_argument("--log_dir", type=str, default=None)
    args = parser.parse_args()

    if args.seed is not None:
        set_determinism(seed=args.seed)
        warnings.warn("You have chosen to seed training. "
                      "This will turn on the CUDNN deterministic setting, "
                      "which can slow down your training considerably! "
                      "You may see unexpected behavior when restarting "
                      "from checkpoints.")

    main_worker(args=args)
Example #20
    def prepare_data(self):
        data_images = sorted([
            os.path.join(data_path, x) for x in os.listdir(data_path)
            if x.startswith("data")
        ])
        data_labels = sorted([
            os.path.join(data_path, x) for x in os.listdir(data_path)
            if x.startswith("label")
        ])
        data_dicts = [{
            "image": image_name,
            "label": label_name,
            "patient": image_name.split("/")[-1].replace("data", "").replace(".nii.gz", ""),
        } for image_name, label_name in zip(data_images, data_labels)]
        train_files, val_files = train_val_split(data_dicts)
        print(
            f"Training patients: {len(train_files)}, Validation patients: {len(val_files)}"
        )

        set_determinism(seed=0)

        train_transforms = Compose([
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            Spacingd(keys=["image", "label"],
                     pixdim=PIXDIM,
                     mode=("bilinear", "nearest")),
            DataStatsdWithPatient(keys=["image", "label"]),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-100,
                a_max=300,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=PATCH_SIZE,
                pos=1,
                neg=1,
                num_samples=16,
                image_key="image",
                image_threshold=0,
            ),
            RandFlipd(["image", "label"], spatial_axis=[0, 1, 2], prob=0.5),
            ToTensord(keys=["image", "label"]),
        ])
        val_transforms = Compose([
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            Spacingd(keys=["image", "label"],
                     pixdim=PIXDIM,
                     mode=("bilinear", "nearest")),
            DataStatsdWithPatient(keys=["image", "label"]),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-100,
                a_max=300,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            StoreShaped(keys=['image']),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            ToTensord(keys=["image", "label"]),
        ])

        self.train_ds = PersistentDataset(data=train_files,
                                          transform=train_transforms,
                                          cache_dir=cache_path)
        self.val_ds = PersistentDataset(data=val_files,
                                        transform=val_transforms,
                                        cache_dir=cache_path)
Example #21
 def tearDown(self) -> None:
     set_determinism(None)
Example #22
    def train(self,
              train_info,
              valid_info,
              hyperparameters,
              run_data_check=False):

        logging.basicConfig(stream=sys.stdout, level=logging.INFO)

        if not run_data_check:
            start_dt = datetime.datetime.now()
            start_dt_string = start_dt.strftime('%d/%m/%Y %H:%M:%S')
            print(f'Training started: {start_dt_string}')

            # 1. Create folders to save the model
            timedate_info = str(datetime.datetime.now()).split(' ')[0] + '_' + \
                datetime.datetime.now().strftime("%H:%M:%S").replace(':', '-')
            path_to_model = os.path.join(
                self.out_dir, 'trained_models',
                self.unique_name + '_' + timedate_info)
            os.mkdir(path_to_model)

        # 2. Load hyperparameters
        learning_rate = hyperparameters['learning_rate']
        weight_decay = hyperparameters['weight_decay']
        total_epoch = hyperparameters['total_epoch']
        multiplicator = hyperparameters['multiplicator']
        batch_size = hyperparameters['batch_size']
        validation_epoch = hyperparameters['validation_epoch']
        validation_interval = hyperparameters['validation_interval']
        H = hyperparameters['H']
        L = hyperparameters['L']

        # 3. Consider class imbalance
        negative, positive = 0, 0
        for _, label in train_info:
            if int(label) == 0:
                negative += 1
            elif int(label) == 1:
                positive += 1

        pos_weight = torch.Tensor([(negative / positive)]).to(self.device)
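        # e.g. with 800 negative and 200 positive samples, pos_weight = 4.0, which
        # up-weights the positive term in BCEWithLogitsLoss to counter the imbalance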

        # 4. Create the train and validation data loaders

        train_data = get_data_from_info(self.image_data_dir, self.seg_data_dir,
                                        train_info)
        valid_data = get_data_from_info(self.image_data_dir, self.seg_data_dir,
                                        valid_info)
        large_image_splitter(train_data, self.cache_dir)

        set_determinism(seed=100)
        train_trans, valid_trans = self.transformations(H, L)
        train_dataset = PersistentDataset(
            data=train_data[:],
            transform=train_trans,
            cache_dir=self.persistent_dataset_dir)
        valid_dataset = PersistentDataset(
            data=valid_data[:],
            transform=valid_trans,
            cache_dir=self.persistent_dataset_dir)

        train_loader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  pin_memory=self.pin_memory,
                                  num_workers=self.num_workers,
                                  collate_fn=PadListDataCollate(
                                      Method.SYMMETRIC, NumpyPadMode.CONSTANT))
        valid_loader = DataLoader(valid_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  pin_memory=self.pin_memory,
                                  num_workers=self.num_workers,
                                  collate_fn=PadListDataCollate(
                                      Method.SYMMETRIC, NumpyPadMode.CONSTANT))

        # Perform data checks
        if run_data_check:
            check_data = monai.utils.misc.first(train_loader)
            print(check_data["image"].shape, check_data["label"])
            for i in range(batch_size):
                multi_slice_viewer(
                    check_data["image"][i, 0, :, :, :],
                    check_data["image_meta_dict"]["filename_or_obj"][i])
            exit()
        """c = 1
        for d in train_loader:
            img = d["image"]
            seg = d["seg"][0]
            seg, _ = nrrd.read(seg)
            img_name = d["image_meta_dict"]["filename_or_obj"][0]
            print(c, "Name:", img_name, "Size:", img.nelement()*img.element_size()/1024/1024, "MB", "shape:", img.shape)
            multi_slice_viewer(img[0, 0, :, :, :], d["image_meta_dict"]["filename_or_obj"][0])
            #multi_slice_viewer(seg, d["image_meta_dict"]["filename_or_obj"][0])
            c += 1
        exit()"""

        # 5. Prepare model
        model = ModelCT().to(self.device)

        # 6. Define loss function, optimizer and scheduler
        loss_function = torch.nn.BCEWithLogitsLoss(
            pos_weight)  # pos_weight for class imbalance
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=learning_rate,
                                     weight_decay=weight_decay)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                           multiplicator,
                                                           last_epoch=-1)
        # 7. Create post validation transforms and handlers
        path_to_tensorboard = os.path.join(self.out_dir, 'tensorboard')
        writer = SummaryWriter(log_dir=path_to_tensorboard)
        valid_post_transforms = Compose([
            Activationsd(keys="pred", sigmoid=True),
        ])
        valid_handlers = [
            StatsHandler(output_transform=lambda x: None),
            TensorBoardStatsHandler(summary_writer=writer,
                                    output_transform=lambda x: None),
            CheckpointSaver(save_dir=path_to_model,
                            save_dict={"model": model},
                            save_key_metric=True),
            MetricsSaver(save_dir=path_to_model,
                         metrics=['Valid_AUC', 'Valid_ACC']),
        ]
        # 8. Create validator
        discrete = AsDiscrete(threshold_values=True)
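        # AsDiscrete(threshold_values=True) binarises the sigmoid outputs at the
        # default 0.5 threshold before the accuracy metric is computed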
        evaluator = SupervisedEvaluator(
            device=self.device,
            val_data_loader=valid_loader,
            network=model,
            post_transform=valid_post_transforms,
            key_val_metric={
                "Valid_AUC":
                ROCAUC(output_transform=lambda x: (x["pred"], x["label"]))
            },
            additional_metrics={
                "Valid_ACC":
                Accuracy(output_transform=lambda x:
                         (discrete(x["pred"]), x["label"]))
            },
            val_handlers=valid_handlers,
            amp=self.amp,
        )
        # 9. Create trainer

        # Loss function does the last sigmoid, so we don't need it here.
        train_post_transforms = Compose([
            # Empty
        ])
        logger = MetricLogger(evaluator=evaluator)
        train_handlers = [
            logger,
            LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
            ValidationHandlerCT(validator=evaluator,
                                start=validation_epoch,
                                interval=validation_interval,
                                epoch_level=True),
            StatsHandler(tag_name="loss",
                         output_transform=lambda x: x["loss"]),
            TensorBoardStatsHandler(summary_writer=writer,
                                    tag_name="Train_Loss",
                                    output_transform=lambda x: x["loss"]),
            CheckpointSaver(save_dir=path_to_model,
                            save_dict={
                                "model": model,
                                "opt": optimizer
                            },
                            save_interval=1,
                            n_saved=1),
        ]

        trainer = SupervisedTrainer(
            device=self.device,
            max_epochs=total_epoch,
            train_data_loader=train_loader,
            network=model,
            optimizer=optimizer,
            loss_function=loss_function,
            post_transform=train_post_transforms,
            train_handlers=train_handlers,
            amp=self.amp,
        )
        # 10. Run trainer
        trainer.run()
        # 11. Save results
        np.save(path_to_model + '/AUCS.npy',
                np.array(logger.metrics['Valid_AUC']))
        np.save(path_to_model + '/ACCS.npy',
                np.array(logger.metrics['Valid_ACC']))
        np.save(path_to_model + '/LOSSES.npy', np.array(logger.loss))
        np.save(path_to_model + '/PARAMETERS.npy', np.array(hyperparameters))

        return path_to_model
Example #23
from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord
from monai.utils import optional_import, set_determinism
from tests.utils import skip_if_downloading_fails

if TYPE_CHECKING:
    import matplotlib.pyplot as plt

    has_matplotlib = True
    has_pil = True
else:
    plt, has_matplotlib = optional_import("matplotlib.pyplot")
    _, has_pil = optional_import("PIL.Image")

RAND_SEED = 42
random.seed(RAND_SEED)
set_determinism(seed=RAND_SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"


@unittest.skipUnless(sys.platform == "linux", "requires linux")
@unittest.skipUnless(has_pil, "requires PIL")
class TestLRFinder(unittest.TestCase):
    def setUp(self):

        self.root_dir = os.environ.get("MONAI_DATA_DIRECTORY")
        if not self.root_dir:
            self.root_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data")

        self.transforms = Compose(
            [
Example #24

 def setUp(self):
     set_determinism(0)
Example #25
 def tearDown(self):
     set_determinism(seed=None)
     shutil.rmtree(self.data_dir)
Example #26

 def test_value(self, input_param, input_data, expected_value):
     set_determinism(seed=0)
     transform = RandTorchVisiond(**input_param)
     result = transform(input_data)
     self.assertTrue(isinstance(transform, Randomizable))
     torch.testing.assert_allclose(result["img"], expected_value)
Example #27

def run_training(train_file_list, valid_file_list, config_info):
    """
    Pipeline to train a dynUNet segmentation model in MONAI. It is composed of the following main blocks:
        * Data Preparation: Extract the filenames and prepare the training/validation processing transforms
        * Load Data: Load training and validation data to PyTorch DataLoader
        * Network Preparation: Define the network, loss function, optimiser and learning rate scheduler
        * MONAI Evaluator: Initialise the dynUNet evaluator, i.e. the class providing utilities to perform validation
            during training. Attach handlers to save the best model on the validation set. A 2D sliding window approach
            on the 3D volume is used at evaluation. The mean 3D Dice is used as validation metric.
        * MONAI Trainer: Initialise the dynUNet trainer, i.e. the class providing utilities to perform the training loop.
        * Run training: The MONAI trainer is run, performing training and validation during training.
    Args:
        train_file_list: .txt or .csv file (with no header) storing two-columns filenames for training:
            image filename in the first column and segmentation filename in the second column.
            The two columns should be separated by a comma.
            See monaifbs/config/mock_train_file_list_for_dynUnet_training.txt for an example of the expected format.
        valid_file_list: .txt or .csv file (with no header) storing two-columns filenames for validation:
            image filename in the first column and segmentation filename in the second column.
            The two columns should be separated by a comma.
            See monaifbs/config/mock_valid_file_list_for_dynUnet_training.txt for an example of the expected format.
        config_info: dict, contains configuration parameters for sampling, network and training.
            See monaifbs/config/monai_dynUnet_training_config.yml for an example of the expected fields.
    """

    """
    Read input and configuration parameters
    """
    # print MONAI config information
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print_config()

    # print to log the parameter setups
    print(yaml.dump(config_info))

    # extract network parameters, perform checks/set defaults if not present and print them to log
    if 'seg_labels' in config_info['training'].keys():
        seg_labels = config_info['training']['seg_labels']
    else:
        seg_labels = [1]
    nr_out_channels = len(seg_labels)
    print("Considering the following {} labels in the segmentation: {}".format(nr_out_channels, seg_labels))
    patch_size = config_info["training"]["inplane_size"] + [1]
    print("Considering patch size = {}".format(patch_size))

    spacing = config_info["training"]["spacing"]
    print("Bringing all images to spacing = {}".format(spacing))

    if 'model_to_load' in config_info['training'].keys() and config_info['training']['model_to_load'] is not None:
        model_to_load = config_info['training']['model_to_load']
        if not os.path.exists(model_to_load):
            raise FileNotFoundError("Cannot find model: {}".format(model_to_load))
        else:
            print("Loading model from {}".format(model_to_load))
    else:
        model_to_load = None

    # set up either GPU or CPU usage
    if torch.cuda.is_available():
        print("\n#### GPU INFORMATION ###")
        print("Using device number: {}, name: {}\n".format(torch.cuda.current_device(), torch.cuda.get_device_name()))
        current_device = torch.device("cuda:0")
    else:
        current_device = torch.device("cpu")
        print("Using device: {}".format(current_device))

    # set determinism if required
    if 'manual_seed' in config_info['training'].keys() and config_info['training']['manual_seed'] is not None:
        seed = config_info['training']['manual_seed']
    else:
        seed = None
    if seed is not None:
        print("Using determinism with seed = {}\n".format(seed))
        set_determinism(seed=seed)

    """
    Setup data output directory
    """
    out_model_dir = os.path.join(config_info['output']['out_dir'],
                                 datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' +
                                 config_info['output']['out_postfix'])
    print("Saving to directory {}\n".format(out_model_dir))
    # create cache directory to store results for Persistent Dataset
    if 'cache_dir' in config_info['output'].keys():
        out_cache_dir = config_info['output']['cache_dir']
    else:
        out_cache_dir = os.path.join(out_model_dir, 'persistent_cache')
    persistent_cache: Path = Path(out_cache_dir)
    persistent_cache.mkdir(parents=True, exist_ok=True)

    """
    Data preparation
    """
    # Read the input files for training and validation
    print("*** Loading input data for training...")

    train_files = create_data_list_of_dictionaries(train_file_list)
    print("Number of inputs for training = {}".format(len(train_files)))

    val_files = create_data_list_of_dictionaries(valid_file_list)
    print("Number of inputs for validation = {}".format(len(val_files)))

    # Define MONAI processing transforms for the training data. This includes:
    # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3
    # - CropForegroundd: Reduce the background from the MR image
    # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the
    #       last direction (lowest resolution) to avoid introducing motion artefact resampling errors
    # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed
    # - NormalizeIntensityd: Apply whitening
    # - RandSpatialCropd: Crop a random patch from the input with size [B, C, N, M, 1]
    # - SqueezeDimd: Convert the 3D patch to a 2D one as input to the network (i.e. bring it to size [B, C, N, M])
    # - Apply data augmentation (RandZoomd, RandRotated, RandGaussianNoised, RandGaussianSmoothd, RandScaleIntensityd,
    #       RandFlipd)
    # - ToTensor: convert to pytorch tensor
    train_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            InPlaneSpacingd(
                keys=["image", "label"],
                pixdim=spacing,
                mode=("bilinear", "nearest"),
            ),
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size,
                        mode=["constant", "edge"]),
            NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
            RandSpatialCropd(keys=["image", "label"], roi_size=patch_size, random_size=False),
            SqueezeDimd(keys=["image", "label"], dim=-1),
            RandZoomd(
                keys=["image", "label"],
                min_zoom=0.9,
                max_zoom=1.2,
                mode=("bilinear", "nearest"),
                align_corners=(True, None),
                prob=0.16,
            ),
            RandRotated(keys=["image", "label"], range_x=90, range_y=90, prob=0.2,
                        keep_size=True, mode=["bilinear", "nearest"],
                        padding_mode=["zeros", "border"]),
            RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
            RandGaussianSmoothd(
                keys=["image"],
                sigma_x=(0.5, 1.15),
                sigma_y=(0.5, 1.15),
                sigma_z=(0.5, 1.15),
                prob=0.15,
            ),
            RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.15),
            RandFlipd(["image", "label"], spatial_axis=[0, 1], prob=0.5),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # Define MONAI processing transforms for the validation data
    # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3
    # - CropForegroundd: Reduce the background from the MR image
    # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the
    #       last direction (lowest resolution) to avoid introducing motion artefact resampling errors
    # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed
    # - NormalizeIntensityd: Apply whitening
    # - ToTensor: convert to pytorch tensor
    # NOTE: The validation data is kept 3D as a 2D sliding window approach is used throughout the volume at inference
    val_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            InPlaneSpacingd(
                keys=["image", "label"],
                pixdim=spacing,
                mode=("bilinear", "nearest"),
            ),
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size, mode=["constant", "edge"]),
            NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
            ToTensord(keys=["image", "label"]),
        ]
    )

    """
    Load data 
    """
    # create training data loader
    train_ds = PersistentDataset(data=train_files, transform=train_transforms,
                                 cache_dir=persistent_cache)
    train_loader = DataLoader(train_ds,
                              batch_size=config_info['training']['batch_size_train'],
                              shuffle=True,
                              num_workers=config_info['device']['num_workers'])
    check_train_data = misc.first(train_loader)
    print("Training data tensor shapes:")
    print("Image = {}; Label = {}".format(check_train_data["image"].shape, check_train_data["label"].shape))

    # create validation data loader
    if config_info['training']['batch_size_valid'] != 1:
        raise Exception("Batch size different from 1 at validation ar currently not supported")
    val_ds = PersistentDataset(data=val_files, transform=val_transforms, cache_dir=persistent_cache)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            shuffle=False,
                            num_workers=config_info['device']['num_workers'])
    check_valid_data = misc.first(val_loader)
    print("Validation data tensor shapes (Example):")
    print("Image = {}; Label = {}\n".format(check_valid_data["image"].shape, check_valid_data["label"].shape))

    """
    Network preparation
    """
    print("*** Preparing the network ...")
    # automatically extracts the strides and kernels based on nnU-Net empirical rules
    spacings = spacing[:2]
    sizes = patch_size[:2]
    strides, kernels = [], []
    while True:
        spacing_ratio = [sp / min(spacings) for sp in spacings]
        stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)
    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])
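    # e.g. (assumed values, not from the config) an isotropic in-plane spacing of
    # [1.0, 1.0] with a [224, 224] patch yields
    # strides = [[1, 1], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2]] and six [3, 3] kernels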

    # initialise the network
    net = DynUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=nr_out_channels,
        kernel_size=kernels,
        strides=strides,
        upsample_kernel_size=strides[1:],
        norm_name="instance",
        deep_supervision=True,
        deep_supr_num=2,
        res_block=False,
    ).to(current_device)
    print(net)

    # define the loss function
    loss_function = choose_loss_function(nr_out_channels, config_info)

    # define the optimiser and the learning rate scheduler
    opt = torch.optim.SGD(net.parameters(), lr=float(config_info['training']['lr']), momentum=0.95)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        opt, lr_lambda=lambda epoch: (1 - epoch / config_info['training']['nr_train_epochs']) ** 0.9
    )

    """
    MONAI evaluator
    """
    print("*** Preparing the dynUNet evaluator engine...\n")
    # val_post_transforms = Compose(
    #     [
    #         Activationsd(keys="pred", sigmoid=True),
    #     ]
    # )
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=os.path.join(out_model_dir, "valid"),
                                output_transform=lambda x: None,
                                global_epoch_transform=lambda x: trainer.state.iteration),
        CheckpointSaver(save_dir=out_model_dir, save_dict={"net": net, "opt": opt}, save_key_metric=True,
                        file_prefix='best_valid'),
    ]
    if config_info['output']['val_image_to_tensorboad']:
        val_handlers.append(TensorBoardImageHandler(log_dir=os.path.join(out_model_dir, "valid"),
                                                    batch_transform=lambda x: (x["image"], x["label"]),
                                                    output_transform=lambda x: x["pred"], interval=2))

    # Define customized evaluator
    class DynUNetEvaluator(SupervisedEvaluator):
        def _iteration(self, engine, batchdata):
            inputs, targets = self.prepare_batch(batchdata)
            inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device)
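            # test-time augmentation: build copies flipped along H (dim 2), W (dim 3) and both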
            flip_inputs_1 = torch.flip(inputs, dims=(2,))
            flip_inputs_2 = torch.flip(inputs, dims=(3,))
            flip_inputs_3 = torch.flip(inputs, dims=(2, 3))

            def _compute_pred():
                pred = self.inferer(inputs, self.network)
                # use random flipping as data augmentation at inference
                flip_pred_1 = torch.flip(self.inferer(flip_inputs_1, self.network), dims=(2,))
                flip_pred_2 = torch.flip(self.inferer(flip_inputs_2, self.network), dims=(3,))
                flip_pred_3 = torch.flip(self.inferer(flip_inputs_3, self.network), dims=(2, 3))
                return (pred + flip_pred_1 + flip_pred_2 + flip_pred_3) / 4

            # execute forward computation
            self.network.eval()
            with torch.no_grad():
                if self.amp:
                    with torch.cuda.amp.autocast():
                        predictions = _compute_pred()
                else:
                    predictions = _compute_pred()
            return {"image": inputs, "label": targets, "pred": predictions}

    evaluator = DynUNetEvaluator(
        device=current_device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer2D(roi_size=patch_size, sw_batch_size=4, overlap=0.0),
        post_transform=None,
        key_val_metric={
            "Mean_dice": MeanDice(
                include_background=False,
                to_onehot_y=True,
                mutually_exclusive=True,
                output_transform=lambda x: (x["pred"], x["label"]),
            )
        },
        val_handlers=val_handlers,
        amp=False,
    )

    """
    MONAI trainer
    """
    print("*** Preparing the dynUNet trainer engine...\n")
    # train_post_transforms = Compose(
    #     [
    #         Activationsd(keys="pred", sigmoid=True),
    #     ]
    # )

    validation_every_n_epochs = config_info['training']['validation_every_n_epochs']
    epoch_len = len(train_ds) // train_loader.batch_size
    validation_every_n_iters = validation_every_n_epochs * epoch_len

    # define event handlers for the trainer
    writer_train = SummaryWriter(log_dir=os.path.join(out_model_dir, "train"))
    train_handlers = [
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=validation_every_n_iters, epoch_level=False),
        StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(summary_writer=writer_train,
                                log_dir=os.path.join(out_model_dir, "train"), tag_name="Loss",
                                output_transform=lambda x: x["loss"],
                                global_epoch_transform=lambda x: trainer.state.iteration),
        CheckpointSaver(save_dir=out_model_dir, save_dict={"net": net, "opt": opt},
                        save_final=True,
                        save_interval=2, epoch_level=True,
                        n_saved=config_info['output']['max_nr_models_saved']),
    ]
    if model_to_load is not None:
        train_handlers.append(CheckpointLoader(load_path=model_to_load, load_dict={"net": net, "opt": opt}))

    # define customized trainer
    class DynUNetTrainer(SupervisedTrainer):
        def _iteration(self, engine, batchdata):
            inputs, targets = self.prepare_batch(batchdata)
            inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device)

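            # deep supervision: match each auxiliary prediction to a resized label and
            # weight the losses by 0.5**i, so deeper (coarser) heads contribute less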
            def _compute_loss(preds, label):
                labels = [label] + [interpolate(label, pred.shape[2:]) for pred in preds[1:]]
                return sum([0.5 ** i * self.loss_function(p, l) for i, (p, l) in enumerate(zip(preds, labels))])

            self.network.train()
            self.optimizer.zero_grad()
            if self.amp and self.scaler is not None:
                with torch.cuda.amp.autocast():
                    predictions = self.inferer(inputs, self.network)
                    loss = _compute_loss(predictions, targets)
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                predictions = self.inferer(inputs, self.network)
                loss = _compute_loss(predictions, targets).mean()
                loss.backward()
                self.optimizer.step()
            return {"image": inputs, "label": targets, "pred": predictions, "loss": loss.item()}

    trainer = DynUNetTrainer(
        device=current_device,
        max_epochs=config_info['training']['nr_train_epochs'],
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss_function,
        inferer=SimpleInferer(),
        post_transform=None,
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=False,
    )

    """
    Run training
    """
    print("*** Run training...")
    trainer.run()
    print("Done!")
Example #28
    def __init__(
            self,
            architecture: SegmentationArchitectures = SegmentationArchitectures.ResidualUNet2D,
            loss: SegmentationLosses = SegmentationLosses.GeneralizedDiceLoss,
            optimizer: Optimizers = Optimizers.Adam,
            mask_type: MaskType = MaskType.TIFF_LABELS,
            in_channels: int = 1,
            out_channels: int = 3,
            roi_size: Tuple[int, int] = (384, 384),
            num_filters_in_first_layer: int = 16,
            learning_rate: float = 0.001,
            weight_decay: float = 0.0001,
            momentum: float = 0.9,
            num_epochs: int = 400,
            batch_sizes: Tuple[int, int, int, int] = (8, 1, 1, 1),
            num_workers: Tuple[int, int, int, int] = (4, 4, 1, 1),
            validation_step: int = 2,
            sliding_window_batch_size: int = 4,
            class_names: Tuple[str, ...] = ("Background", "Object", "Border"),
            experiment_name: str = "Unet",
            model_name: str = "best_model",
            seed: int = 4294967295,
            working_dir: str = '.',
            stdout: TextIOWrapper = sys.stdout,
            stderr: TextIOWrapper = sys.stderr
    ):
        """Constructor.

        @param mask_type: MaskType
            Type of mask: defines file type, mask geometry and the way pixels
            are assigned to the various classes.

            @see qu.data.model.MaskType

        @param architecture: SegmentationArchitectures
            Core network architecture: one of (SegmentationArchitectures.ResidualUNet2D, SegmentationArchitectures.AttentionUNet2D)

        @param loss: SegmentationLosses
            Loss function: currently only SegmentationLosses.GeneralizedDiceLoss is supported

        @param optimizer: Optimizers
            Optimizer: one of (Optimizers.Adam, Optimizers.SGD)

        @param in_channels: int, optional: default = 1
            Number of channels in the input (e.g. 1 for gray-value images).

        @param out_channels: int, optional: default = 3
            Number of channels in the output (classes).

        @param roi_size: Tuple[int, int], optional: default = (384, 384)
            Crop area (and input size of the U-Net network) used for training and validation/prediction.

        @param num_filters_in_first_layer: int
            Number of filters in the first layer. Every subsequent layer doubles the number of filters.

        @param learning_rate: float, optional: default = 1e-3
            Initial learning rate for the optimizer.

        @param weight_decay: float, optional: default = 1e-4
            Weight decay (L2 penalty) for the optimizer.
            Used by the Adam optimizer.

        @param momentum: float, optional: default = 0.9
            Momentum of the accelerated gradient for the optimizer.
            Used by the SGD optimizer.

        @param num_epochs: int, optional: default = 400
            Number of epochs for training.

        @param batch_sizes: Tuple[int, int, int, int], optional: default = (8, 1, 1, 1)
            Batch sizes for training, validation, testing, and prediction, respectively.

        @param num_workers: Tuple[int, int, int, int], optional: default = (4, 4, 1, 1)
            Number of workers for training, validation, testing, and prediction, respectively.

        @param validation_step: int, optional: default = 2
            Number of training steps before the next validation is performed.

        @param sliding_window_batch_size: int, optional: default = 4
            Number of batches for sliding window inference during validation and prediction.

        @param class_names: Tuple[str, ...], optional: default = ("Background", "Object", "Border")
            Name of the classes for logging validation curves.

        @param experiment_name: str, optional: default = "Unet"
            Name of the experiment that maps to the folder that contains training information (to
            be used by tensorboard). Please note, current datetime will be appended.

        @param model_name: str, optional: default = "best_model"
            Name of the file that stores the best model. Please note, current datetime will be appended
            (before the extension).

        @param seed: int, optional; default = 4294967295
            Random seed passed to set_determinism() to make training reproducible.

        @param working_dir: str, optional, default = "."
            Working folder where to save the model weights and the logs for tensorboard.

        @param stdout: TextIOWrapper, optional: default = sys.stdout
            Standard output stream wrapper.

        @param stderr: TextIOWrapper, optional: default = sys.stderr
            Standard error stream wrapper.

        """

        # Call base constructor
        super().__init__()

        # Standard pipe wrappers
        self._stdout = stdout
        self._stderr = stderr

        # Device (initialize as "cpu")
        self._device = "cpu"

        # Architecture, loss function and optimizer
        self._option_architecture = architecture
        self._option_loss = loss
        self._option_optimizer = optimizer
        self._learning_rate = learning_rate
        self._weight_decay = weight_decay
        self._momentum = momentum

        # Mask type
        self._mask_type = mask_type

        # Input and output channels
        self._in_channels = in_channels
        self._out_channels = out_channels

        # Define hyper parameters
        self._roi_size = roi_size
        self._num_filters_in_first_layer = num_filters_in_first_layer
        self._training_batch_size = batch_sizes[0]
        self._validation_batch_size = batch_sizes[1]
        self._test_batch_size = batch_sizes[2]
        self._prediction_batch_size = batch_sizes[3]
        self._training_num_workers = num_workers[0]
        self._validation_num_workers = num_workers[1]
        self._test_num_workers = num_workers[2]
        self._prediction_num_workers = num_workers[3]
        self._n_epochs = num_epochs
        self._validation_step = validation_step
        self._sliding_window_batch_size = sliding_window_batch_size

        # Other parameters
        self._class_names = out_channels * ["Unknown"]
        for i in range(min(out_channels, len(class_names))):
            self._class_names[i] = class_names[i]

        # Set monai seed
        set_determinism(seed=seed)

        # All file names
        self._train_image_names: list = []
        self._train_mask_names: list = []
        self._validation_image_names: list = []
        self._validation_mask_names: list = []
        self._test_image_names: list = []
        self._test_mask_names: list = []

        # Transforms
        self._train_image_transforms = None
        self._train_mask_transforms = None
        self._validation_image_transforms = None
        self._validation_mask_transforms = None
        self._test_image_transforms = None
        self._test_mask_transforms = None
        self._prediction_image_transforms = None
        self._validation_post_transforms = None
        self._test_post_transforms = None
        self._prediction_post_transforms = None

        # Datasets and data loaders
        self._train_dataset = None
        self._train_dataloader = None
        self._validation_dataset = None
        self._validation_dataloader = None
        self._test_dataset = None
        self._test_dataloader = None
        self._prediction_dataset = None
        self._prediction_dataloader = None

        # Set model architecture, loss function, metric and optimizer
        self._model = None
        self._training_loss_function = None
        self._optimizer = None
        self._validation_metric = None

        # Working directory, model file name and experiment name for Tensorboard logs.
        # The file names will be redefined at the beginning of the training.
        self._working_dir = Path(working_dir).resolve()
        self._raw_experiment_name = experiment_name
        self._raw_model_file_name = model_name

        # Keep track of the full path of the best model
        self._best_model = ''

        # Keep track of last error message
        self._message = ""
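
# A hedged usage sketch (added): the enclosing class definition is not shown in this
# excerpt, so `SegmentationSettings` below is a hypothetical stand-in for its name.
#
#     settings = SegmentationSettings(
#         architecture=SegmentationArchitectures.ResidualUNet2D,
#         loss=SegmentationLosses.GeneralizedDiceLoss,
#         optimizer=Optimizers.Adam,
#         roi_size=(384, 384),
#         batch_sizes=(8, 1, 1, 1),
#         num_epochs=400,
#         working_dir="./runs",
#     )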
Example #29
    def test_training(self):
        set_determinism(seed=0)
        loss, step = run_test(device=self.device)
        print(f"Deterministic loss {loss} at training step {step}")
        np.testing.assert_allclose(step, 4)
        np.testing.assert_allclose(loss, 0.536134, rtol=1e-4)
Example #30
def main():

    # Define file paths & output directory path
    json_Path = os.path.normpath('/scratch/data_2021/tcia_covid19/dataset_split_debug.json')
    data_Root = os.path.normpath('/scratch/data_2021/tcia_covid19')
    logdir_path = os.path.normpath('/home/vishwesh/monai_tutorial_testing/issue_467')

    if not os.path.exists(logdir_path):
        os.mkdir(logdir_path)

    # Load Json & Append Root Path
    with open(json_Path, 'r') as json_f:
        json_Data = json.load(json_f)

    train_Data = json_Data['training']
    val_Data = json_Data['validation']

    for d in train_Data:
        d['image'] = os.path.join(data_Root, d['image'])

    for d in val_Data:
        d['image'] = os.path.join(data_Root, d['image'])

    print('Total Number of Training Data Samples: {}'.format(len(train_Data)))
    print(train_Data)
    print('#' * 10)
    print('Total Number of Validation Data Samples: {}'.format(len(val_Data)))
    print(val_Data)
    print('#' * 10)

    # Set Determinism
    set_determinism(seed=123)

    # Define Training Transforms
    train_Transforms = Compose(
        [
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys=["image"]),
        Spacingd(keys=["image"], pixdim=(2.0, 2.0, 2.0), mode=("bilinear")),
        ScaleIntensityRanged(
            keys=["image"], a_min=-57, a_max=164,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        CropForegroundd(keys=["image"], source_key="image"),
        SpatialPadd(keys=["image"], spatial_size=(96, 96, 96)),
        RandSpatialCropSamplesd(keys=["image"], roi_size=(96, 96, 96), random_size=False, num_samples=2),
        CopyItemsd(keys=["image"], times=2, names=["gt_image", "image_2"], allow_missing_keys=False),
        OneOf(transforms=[
            RandCoarseDropoutd(keys=["image"], prob=1.0, holes=6, spatial_size=5, dropout_holes=True,
                               max_spatial_size=32),
            RandCoarseDropoutd(keys=["image"], prob=1.0, holes=6, spatial_size=20, dropout_holes=False,
                               max_spatial_size=64),
            ]
        ),
        RandCoarseShuffled(keys=["image"], prob=0.8, holes=10, spatial_size=8),
        # Note: if image and image_2 were augmented in the same transform call, determinism
        # would make them receive identical augmentations, which is not wanted here; hence
        # the two separate OneOf calls (see the illustrative sketch after this Compose)
        OneOf(transforms=[
            RandCoarseDropoutd(keys=["image_2"], prob=1.0, holes=6, spatial_size=5, dropout_holes=True,
                               max_spatial_size=32),
            RandCoarseDropoutd(keys=["image_2"], prob=1.0, holes=6, spatial_size=20, dropout_holes=False,
                               max_spatial_size=64),
        ]
        ),
        RandCoarseShuffled(keys=["image_2"], prob=0.8, holes=10, spatial_size=8)
        ]
    )
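
    # Added sketch: dictionary transforms share a single random draw per call, so one
    # transform over both keys (constructed below for illustration only, and not used)
    # would corrupt "image" and "image_2" identically; hence the two separate OneOf
    # blocks in the Compose above.
    shared_drop = RandCoarseDropoutd(keys=["image", "image_2"], prob=1.0, holes=6,
                                     spatial_size=5, dropout_holes=True, max_spatial_size=32)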

    check_ds = Dataset(data=train_Data, transform=train_Transforms)
    check_loader = DataLoader(check_ds, batch_size=1)
    check_data = first(check_loader)
    image = check_data["image"][0][0]
    print(f"image shape: {image.shape}")

    # Define Network ViT backbone & Loss & Optimizer
    device = torch.device("cuda:0")
    model = ViTAutoEnc(
                in_channels=1,
                img_size=(96, 96, 96),
                patch_size=(16, 16, 16),
                pos_embed='conv',
                hidden_size=768,
                mlp_dim=3072,
    )
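
    # Note (added): with img_size=(96, 96, 96) and patch_size=(16, 16, 16) the encoder
    # operates on (96 // 16) ** 3 = 216 patches, each projected to hidden_size=768.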

    model = model.to(device)

    # Define hyper-parameters for training loop
    max_epochs = 500
    val_interval = 2
    batch_size = 4
    lr = 1e-4
    epoch_loss_values = []
    step_loss_values = []
    epoch_cl_loss_values = []
    epoch_recon_loss_values = []
    val_loss_values = []
    best_val_loss = 1000.0

    recon_loss = L1Loss()
    contrastive_loss = ContrastiveLoss(batch_size=batch_size*2, temperature=0.05)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Define Dataset and DataLoader using MONAI (CacheDataset could be used here to speed up loading)
    train_ds = Dataset(data=train_Data, transform=train_Transforms)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)

    # validation reuses the training transforms so that gt_image and the corrupted views exist
    val_ds = Dataset(data=val_Data, transform=train_Transforms)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=True, num_workers=4)

    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        epoch_cl_loss = 0
        epoch_recon_loss = 0
        step = 0

        for batch_data in train_loader:
            step += 1
            start_time = time.time()

            inputs, inputs_2, gt_input = (
                batch_data["image"].to(device),
                batch_data["image_2"].to(device),
                batch_data["gt_image"].to(device),
            )
            optimizer.zero_grad()
            outputs_v1, hidden_v1 = model(inputs)
            outputs_v2, hidden_v2 = model(inputs_2)

            flat_out_v1 = outputs_v1.flatten(start_dim=1, end_dim=4)
            flat_out_v2 = outputs_v2.flatten(start_dim=1, end_dim=4)

            r_loss = recon_loss(outputs_v1, gt_input)
            cl_loss = contrastive_loss(flat_out_v1, flat_out_v2)

            # Scale the contrastive term by the reconstruction loss so that its
            # influence fades as reconstruction improves
            total_loss = r_loss + cl_loss * r_loss

            total_loss.backward()
            optimizer.step()
            epoch_loss += total_loss.item()
            step_loss_values.append(total_loss.item())

            # CL & Recon Loss Storage of Value
            epoch_cl_loss += cl_loss.item()
            epoch_recon_loss += r_loss.item()

            end_time = time.time()
            print(
                f"{step}/{len(train_ds) // train_loader.batch_size}, "
                f"train_loss: {total_loss.item():.4f}, "
                f"time taken: {end_time-start_time}s")

        epoch_loss /= step
        epoch_cl_loss /= step
        epoch_recon_loss /= step

        epoch_loss_values.append(epoch_loss)
        epoch_cl_loss_values.append(epoch_cl_loss)
        epoch_recon_loss_values.append(epoch_recon_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if epoch % val_interval == 0:
            print('Entering Validation for epoch: {}'.format(epoch+1))
            total_val_loss = 0
            val_step = 0
            model.eval()
            with torch.no_grad():  # no gradients are needed for validation
                for val_batch in val_loader:
                    val_step += 1
                    start_time = time.time()
                    inputs, gt_input = (
                        val_batch["image"].to(device),
                        val_batch["gt_image"].to(device),
                    )
                    print('Input shape: {}'.format(inputs.shape))
                    outputs, _ = model(inputs)  # the second output (hidden states) is unused here
                    val_loss = recon_loss(outputs, gt_input)
                    total_val_loss += val_loss.item()
                    end_time = time.time()

            total_val_loss /= val_step
            val_loss_values.append(total_val_loss)
            print(f"epoch {epoch + 1} Validation average loss: {total_val_loss:.4f}, " f"time taken: {end_time-start_time}s")

            if total_val_loss < best_val_loss:
                print(f"Saving new model based on validation loss {total_val_loss:.4f}")
                best_val_loss = total_val_loss
                checkpoint = {'epoch': epoch + 1,  # record the current epoch rather than max_epochs
                              'state_dict': model.state_dict(),
                              'optimizer': optimizer.state_dict()
                              }
                torch.save(checkpoint, os.path.join(logdir_path, 'best_model.pt'))

            plt.figure(1, figsize=(8, 8))
            plt.subplot(2, 2, 1)
            plt.plot(epoch_loss_values)
            plt.grid()
            plt.title('Training Loss')

            plt.subplot(2, 2, 2)
            plt.plot(val_loss_values)
            plt.grid()
            plt.title('Validation Loss')

            plt.subplot(2, 2, 3)
            plt.plot(epoch_cl_loss_values)
            plt.grid()
            plt.title('Training Contrastive Loss')

            plt.subplot(2, 2, 4)
            plt.plot(epoch_recon_loss_values)
            plt.grid()
            plt.title('Training Recon Loss')

            plt.savefig(os.path.join(logdir_path, 'loss_plots.png'))
            plt.close(1)

    print('Done')
    return None
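

# Added entry-point guard (an assumption; the original excerpt ends inside main()).
if __name__ == "__main__":
    main()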