Пример #1
0
    def __init__(self, tranforms):
        """Build the pipeline from an ordered sequence of transform names.

        Every supported name maps to a transform instance with fixed keys and
        parameters (see ``factories`` below). The instances are collected, in
        order, into ``self.tranform_list``, which is then passed to the parent
        class constructor.

        Args:
            tranforms: Iterable of transform names such as ``"LoadImaged"`` or
                ``"ToTensord"``. Names may repeat; each occurrence yields a new
                transform object.

        Raises:
            ValueError: If a name is not one of the supported transforms.
        """
        # Name -> zero-argument factory. The lambdas defer construction so a
        # new object is created per occurrence (as the former if/elif chain
        # did) and transforms that are never requested are never instantiated.
        factories = {
            "LoadImaged": lambda: LoadImaged(keys=["image", "label"]),
            "AsChannelFirstd": lambda: AsChannelFirstd(keys="image"),
            "ConvertToMultiChannelBasedOnBratsClassesd": lambda: ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            "Spacingd": lambda: Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
            "Orientationd": lambda: Orientationd(keys=["image", "label"], axcodes="RAS"),
            "CenterSpatialCropd": lambda: CenterSpatialCropd(keys=["image", "label"], roi_size=[128, 128, 64]),
            "NormalizeIntensityd": lambda: NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            "ToTensord": lambda: ToTensord(keys=["image", "label"]),
            "Activations": lambda: Activations(sigmoid=True),
            "AsDiscrete": lambda: AsDiscrete(threshold_values=True),
        }

        self.tranform_list = []
        for tranform in tranforms:
            # Only strings can name a transform. Checking the type first keeps
            # unhashable entries on the ValueError path instead of raising a
            # TypeError from the dict lookup.
            factory = factories.get(tranform) if isinstance(tranform, str) else None
            if factory is None:
                raise ValueError(
                    f"Unsupported transform: {tranform}. Please add it to support it."
                )
            self.tranform_list.append(factory())

        super().__init__(self.tranform_list)
Пример #2
0
 def prepare_data(self):
     """Build cached training and validation datasets from text file lists.

     ``self.hparams.data_dir`` must contain ``train_imgs.txt`` and
     ``train_masks.txt``. Each non-blank line is a path relative to
     ``data_dir``. Line *i* of the image list is paired with line *i* of the
     mask list, and ``train_test_split`` holds out 20% of the pairs for
     validation.

     Sets ``self.train_dataset`` (random positive/negative patches, 4 per
     volume) and ``self.val_dataset`` (center crop). Both are sized by
     ``self.hparams.patch_size`` and fully cached (``cache_rate=1.0``).

     Raises:
         ValueError: If the image and mask lists have different lengths.
     """
     import os

     data_dir = self.hparams.data_dir

     def read_paths(list_name):
         # os.path.join is correct whether or not data_dir ends with a
         # separator (plain "+" built wrong paths without one) and it also
         # accepts pathlib.Path. Blank lines, e.g. a trailing newline, are
         # skipped rather than turned into the bare directory path.
         with open(os.path.join(data_dir, list_name), 'r') as f:
             entries = [line.rstrip() for line in f]
         return [os.path.join(data_dir, entry) for entry in entries if entry]

     # Train imgs/masks
     train_imgs = read_paths('train_imgs.txt')
     train_masks = read_paths('train_masks.txt')
     if len(train_imgs) != len(train_masks):
         # zip() would silently drop the surplus. A length mismatch means the
         # two lists are out of sync, so the pairing cannot be trusted.
         raise ValueError(
             f"{len(train_imgs)} image paths but {len(train_masks)} mask paths "
             f"listed under {data_dir}"
         )
     train_dicts = [{'image': image, 'mask': mask} for (image, mask) in zip(train_imgs, train_masks)]
     train_dicts, val_dicts = train_test_split(train_dicts, test_size=0.2)

     # Basic transforms shared by the training and validation pipelines
     data_keys = ["image", "mask"]
     data_transforms = Compose(
         [
             LoadNiftid(keys=data_keys),
             AddChanneld(keys=data_keys),
             ScaleIntensityRangePercentilesd(
                 keys='image',
                 lower=25,
                 upper=75,
                 b_min=-0.5,
                 b_max=0.5
             )
         ]
     )

     self.train_dataset = monai.data.CacheDataset(
         data=train_dicts,
         transform=Compose(
             [
                 data_transforms,
                 RandCropByPosNegLabeld(
                     keys=data_keys,
                     label_key="mask",
                     spatial_size=self.hparams.patch_size,
                     num_samples=4,
                     image_key="image",
                     pos=0.8,
                     neg=0.2
                 ),
                 ToTensord(keys=data_keys)
             ]
         ),
         cache_rate=1.0
     )

     self.val_dataset = monai.data.CacheDataset(
         data=val_dicts,
         transform=Compose(
             [
                 data_transforms,
                 CenterSpatialCropd(keys=data_keys, roi_size=self.hparams.patch_size),
                 ToTensord(keys=data_keys)
             ]
         ),
         cache_rate=1.0
     )
Пример #3
0
    def test_inverse_inferred_seg(self):
        """Invert the spatial transforms on a model prediction.

        A batch is loaded through AddChanneld / SpatialPadd / CenterSpatialCropd,
        the labels are passed through a small UNet, and the prediction is paired
        with the label's recorded transform history. After decollating, the
        inverse of one sample must have the spatial shape of the original label.
        """
        test_data = []
        for _ in range(20):
            image, label = create_test_image_2d(100, 101)
            test_data.append({
                "image": image,
                "label": label.astype(np.float32)
            })

        batch_size = 10
        # num workers = 0 for mac
        num_workers = 2 if sys.platform != "darwin" else 0
        transforms = Compose([
            AddChanneld(KEYS),
            SpatialPadd(KEYS, (150, 153)),
            CenterSpatialCropd(KEYS, (110, 99))
        ])
        num_invertible_transforms = sum(
            isinstance(t, InvertibleTransform) for t in transforms.transforms)

        dataset = CacheDataset(test_data, transform=transforms, progress=False)
        loader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=num_workers)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = UNet(
            dimensions=2,
            in_channels=1,
            out_channels=1,
            channels=(2, 4),
            strides=(2, ),
        ).to(device)

        data = first(loader)
        labels = data["label"].to(device)
        # Inference only: no autograd graph is needed for the prediction.
        with torch.no_grad():
            segs = model(labels).detach().cpu()
        # Key under which the label's applied-transform history is stored.
        # Every lookup below uses it instead of a hard-coded string.
        label_transform_key = "label" + InverseKeys.KEY_SUFFIX.value
        segs_dict = {
            "label": segs,
            label_transform_key: data[label_transform_key]
        }

        segs_dict_decollated = decollate_batch(segs_dict)

        # inverse of individual segmentation
        seg_dict = first(segs_dict_decollated)
        # seg_dict only carries "label" and its history, so let the transforms
        # skip any other configured keys that are absent.
        with allow_missing_keys_mode(transforms):
            inv_seg = transforms.inverse(seg_dict)["label"]
        self.assertEqual(len(data[label_transform_key]),
                         num_invertible_transforms)
        self.assertEqual(len(seg_dict[label_transform_key]),
                         num_invertible_transforms)
        self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)

    def prepare_data(self):
        """Build cached training and validation datasets for the Yale prostate MR list.

        Reads the hard-coded CSV and keeps its first 295 rows. It pairs the
        ``IMAGE`` and ``SEGM`` columns into ``{'image', 'mask'}`` records, and
        ``train_test_split`` holds out 20% of them for validation.

        Sets ``self.train_dataset`` (random positive/negative patches, 4 per
        volume) and ``self.val_dataset`` (center crop). Both are sized by
        ``self.hparams.patch_size`` and fully cached (``cache_rate=1.0``).
        """
        table = pd.read_csv(
            '/data/shared/prostate/yale_prostate/input_lists/MR_yale.csv')
        head = slice(0, 295)  # only the first 295 cases are used
        records = [
            dict(image=img_path, mask=mask_path)
            for img_path, mask_path in zip(table['IMAGE'][head].tolist(),
                                           table['SEGM'][head].tolist())
        ]
        train_dicts, val_dicts = train_test_split(records, test_size=0.2)

        keys = ["image", "mask"]
        # Deterministic preprocessing shared by both datasets.
        preprocess = Compose([
            LoadNiftid(keys=keys),
            AddChanneld(keys=keys),
            ScaleIntensityRangePercentilesd(
                keys='image', lower=25, upper=75, b_min=-0.5, b_max=0.5),
        ])
        patch_size = self.hparams.patch_size

        train_pipeline = Compose([
            preprocess,
            RandCropByPosNegLabeld(keys=keys, label_key="mask",
                                   spatial_size=patch_size, num_samples=4,
                                   image_key="image", pos=0.8, neg=0.2),
            ToTensord(keys=keys),
        ])
        self.train_dataset = monai.data.CacheDataset(
            data=train_dicts, transform=train_pipeline, cache_rate=1.0)

        val_pipeline = Compose([
            preprocess,
            CenterSpatialCropd(keys=keys, roi_size=patch_size),
            ToTensord(keys=keys),
        ])
        self.val_dataset = monai.data.CacheDataset(
            data=val_dicts, transform=val_pipeline, cache_rate=1.0)
Пример #5
0
    0,
    DivisiblePadd(KEYS, k=4),
))

# Each case: (test name, "2D"/"3D" data tag, 0, transform under test).
TESTS.extend([
    ("DivisiblePadd 3d", "3D", 0, DivisiblePadd(KEYS, k=[4, 8, 11])),
    ("CenterSpatialCropd 2d", "2D", 0, CenterSpatialCropd(KEYS, roi_size=95)),
    ("CenterSpatialCropd 3d", "3D", 0, CenterSpatialCropd(KEYS, roi_size=[95, 97, 98])),
    ("CropForegroundd 2d", "2D", 0, CropForegroundd(KEYS, source_key="label", margin=2)),
    ("CropForegroundd 3d", "3D", 0, CropForegroundd(KEYS, source_key="label")),
])
Пример #6
0
def main_worker(args):
    """Train a BraTS segmentation network in one DistributedDataParallel process.

    One process is expected per GPU. ``args.local_rank`` selects the GPU, and
    the NCCL process group is initialized from environment variables
    (``init_method="env://"``). Only global rank 0 writes TensorBoard logs and
    saves ``final_model.pth`` in the working directory.

    Args:
        args: Parsed command-line namespace. This function reads
            ``local_rank``, ``dir``, ``cache_rate``, ``batch_size``,
            ``workers``, ``log_dir``, ``network``, ``lr``, ``epochs``,
            ``print_freq`` and ``val_interval``. The whole namespace is also
            passed to ``train``.

    Raises:
        FileNotFoundError: If ``args.dir`` does not exist.
    """
    # disable logging for processes except 0 on every node
    if args.local_rank != 0:
        # Deliberately left open: it stays this process's stdout/stderr until exit.
        f = open(os.devnull, "w")
        sys.stdout = sys.stderr = f
    if not os.path.exists(args.dir):
        raise FileNotFoundError(f"Missing directory {args.dir}")

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")
    # Pin this process to its own GPU before any CUDA work. Without this, code
    # that uses the *current* device (NCCL communicators, the DataLoader
    # pin-memory thread) runs on cuda:0 in every process.
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)

    total_start = time.time()
    train_transforms = Compose([
        # load 4 Nifti images and stack them together
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Spacingd(keys=["image", "label"],
                 pixdim=(1.5, 1.5, 2.0),
                 mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        RandSpatialCropd(keys=["image", "label"],
                         roi_size=[128, 128, 64],
                         random_size=False),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
        ToTensord(keys=["image", "label"]),
    ])

    # create a training data loader
    train_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=train_transforms,
        section="training",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=True,
    )
    train_loader = DataLoader(train_ds,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True)

    # validation transforms and dataset
    val_transforms = Compose([
        LoadNiftid(keys=["image", "label"]),
        AsChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Spacingd(keys=["image", "label"],
                 pixdim=(1.5, 1.5, 2.0),
                 mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        CenterSpatialCropd(keys=["image", "label"], roi_size=[128, 128, 64]),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        ToTensord(keys=["image", "label"]),
    ])
    val_ds = BratsCacheDataset(
        root_dir=args.dir,
        transform=val_transforms,
        section="validation",
        num_workers=4,
        cache_rate=args.cache_rate,
        shuffle=False,
    )
    val_loader = DataLoader(val_ds,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True)

    if dist.get_rank() == 0:
        # Logging for TensorBoard (rank 0 only; `writer` is undefined elsewhere)
        writer = SummaryWriter(log_dir=args.log_dir)

    # create UNet, DiceLoss and Adam optimizer
    if args.network == "UNet":
        model = UNet(
            dimensions=3,
            in_channels=4,
            out_channels=3,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)
    else:
        model = SegResNet(in_channels=4,
                          out_channels=3,
                          init_filters=16,
                          dropout_prob=0.2).to(device)
    loss_function = DiceLoss(to_onehot_y=False,
                             sigmoid=True,
                             squared_pred=True)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=1e-5,
                                 amsgrad=True)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[args.local_rank])

    # start a typical PyTorch training
    total_epoch = args.epochs
    best_metric = -1000000
    best_metric_epoch = -1
    epoch_time = AverageMeter("Time", ":6.3f")
    progress = ProgressMeter(total_epoch, [epoch_time], prefix="Epoch: ")
    end = time.time()
    print(f"Time elapsed before training: {end-total_start}")
    for epoch in range(total_epoch):

        train_loss = train(train_loader, model, loss_function, optimizer,
                           epoch, args, device)
        epoch_time.update(time.time() - end)

        if epoch % args.print_freq == 0:
            progress.display(epoch)

        if dist.get_rank() == 0:
            writer.add_scalar("Loss/train", train_loss, epoch)

        if (epoch + 1) % args.val_interval == 0:
            metric, metric_tc, metric_wt, metric_et = evaluate(
                model, val_loader, device)

            if dist.get_rank() == 0:
                writer.add_scalar("Mean Dice/val", metric, epoch)
                writer.add_scalar("Mean Dice TC/val", metric_tc, epoch)
                writer.add_scalar("Mean Dice WT/val", metric_wt, epoch)
                writer.add_scalar("Mean Dice ET/val", metric_et, epoch)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                print(
                    f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                    f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
                    f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
                )
        end = time.time()
        print(f"Time elapsed after epoch {epoch + 1} is {end - total_start}")

    if dist.get_rank() == 0:
        print(
            f"train completed, best_metric: {best_metric:.4f}  at epoch: {best_metric_epoch}"
        )
        # all processes should see same parameters as they all start from same
        # random parameters and gradients are synchronized in backward passes,
        # therefore, saving it in one process is sufficient
        torch.save(model.state_dict(), "final_model.pth")
        # close() flushes pending events and releases the event file.
        writer.close()
    dist.destroy_process_group()
Пример #7
0
    def __init__(self,
            data_dir: Path,
            cache_dir: Path,
            splits: Sequence[Sequence[Dict]],
            batch_size: int,
            spacing: Sequence[float] = (1.5, 1.5, 2.0),
            crop_size: Sequence[int] = [48, 48, 36],
            roi_size: Sequence[int] = [192, 192, 144],
            seed: int = 47, **kwargs):
        """Module that deals with preparation of the LIDC dataset for training segmentation models.

        Args:
            data_dir (Path): Folder where preprocessed data is stored. See `LIDCReader` docs for expected structure.
            cache_dir (Path): Folder where deterministic data transformations should be cached.
            splits (Sequence[Sequence[Dict]]): Data dictionaries for training
            and validation split.
            batch_size (int): Number of training examples in each batch.
            spacing (Sequence[float]): Pixel and slice spacing. Defaults to 1.5x1.5x2mm.
            crop_size (Sequence[int]): Size of crop that is used for training. Defaults to 48x48x36px.
            roi_size (Sequence[int]): Size of crop that is used for validation. Defaults to 192x192x144px.
            seed (int, optional): Random seed used for deterministic sampling and transformations. Defaults to 47.
            **kwargs: Extra keyword arguments. Only ``val_split`` is read; it is
                stored as ``self.val_split`` and defaults to ``None``.
        """
        super().__init__()
        self.data_dir = data_dir
        self.cache_dir = cache_dir
        self.splits = splits
        self.batch_size = batch_size
        # `val_split` is not a named parameter. The previous code referenced an
        # undefined name here, raising NameError on every construction, so read
        # it from the catch-all keyword arguments instead.
        self.val_split = kwargs.get("val_split")
        self.spacing = spacing
        self.crop_size = crop_size
        self.roi_size = roi_size
        self.seed = seed
        reader = LIDCReader(data_dir)
        self.train_transforms = Compose([
            LoadImaged(keys=["image", "label"], reader=reader),
            AddChanneld(keys=["image", "label"]),
            Spacingd(keys=["image", "label"], pixdim=self.spacing,
                     mode=("bilinear", "nearest")),
            ScaleIntensityd(keys=["image"]),
            # Random patches of crop_size, balanced 1:1 between positive and
            # negative centers, 2 samples per volume.
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=self.crop_size,
                pos=1,
                neg=1,
                num_samples=2,
                image_key="image",
                image_threshold=0,
            ),
            ToTensord(keys=["image", "label"]),
            SelectItemsd(keys=["image", "label"]),
        ])
        self.val_transforms = Compose([
            LoadImaged(keys=["image", "label"], reader=reader),
            AddChanneld(keys=["image", "label"]),
            Spacingd(keys=["image", "label"], pixdim=self.spacing,
                     mode=("bilinear", "nearest")),
            ScaleIntensityd(keys=["image"]),
            # Pad up to roi_size, then center-crop down to it.
            SpatialPadd(keys=["image", "label"], spatial_size=self.roi_size,
                        mode="constant"),
            CenterSpatialCropd(keys=["image", "label"], roi_size=self.roi_size),
            ToTensord(keys=["image", "label"]),
            SelectItemsd(keys=["image", "label"]),
        ])
        # NOTE(review): recent PyTorch Lightning versions make `hparams`
        # read-only (use `self.save_hyperparameters()`). Confirm against the
        # pinned version.
        self.hparams = {
            "batch_size": self.batch_size,
            "val_split": self.val_split,
            "spacing": self.spacing,
            "crop_size": self.crop_size,
            "roi_size": self.roi_size,
            "seed": self.seed,
        }
Пример #8
0
TESTS.append(("RandSpatialCropd 2d", "2D", 0, RandSpatialCropd(KEYS, [96, 93], None, True, False)))

TESTS.append(("RandSpatialCropd 3d", "3D", 0, RandSpatialCropd(KEYS, [96, 93, 92], None, False, False)))

# Two 2D BorderPadd cases: four border values, then one value per spatial
# axis. They carry distinct names so failures can be told apart in reports.
TESTS.append(("BorderPadd 2d", "2D", 0, BorderPadd(KEYS, [3, 7, 2, 5])))

TESTS.append(("BorderPadd 2d per-axis", "2D", 0, BorderPadd(KEYS, [3, 7])))

TESTS.append(("BorderPadd 3d", "3D", 0, BorderPadd(KEYS, [4])))

TESTS.append(("DivisiblePadd 2d", "2D", 0, DivisiblePadd(KEYS, k=4)))

TESTS.append(("DivisiblePadd 3d", "3D", 0, DivisiblePadd(KEYS, k=[4, 8, 11])))

TESTS.append(("CenterSpatialCropd 2d", "2D", 0, CenterSpatialCropd(KEYS, roi_size=95)))

TESTS.append(("CenterSpatialCropd 3d", "3D", 0, CenterSpatialCropd(KEYS, roi_size=[95, 97, 98])))

TESTS.append(("CropForegroundd 2d", "2D", 0, CropForegroundd(KEYS, source_key="label", margin=2)))

TESTS.append(("CropForegroundd 3d", "3D", 0, CropForegroundd(KEYS, source_key="label", k_divisible=[5, 101, 2])))

TESTS.append(("ResizeWithPadOrCropd 3d", "3D", 0, ResizeWithPadOrCropd(KEYS, [201, 150, 105])))

TESTS.append(("Flipd 3d", "3D", 0, Flipd(KEYS, [1, 2])))

TESTS.append(("RandFlipd 3d", "3D", 0, RandFlipd(KEYS, 1, [1, 2])))

TESTS.append(("RandAxisFlipd 3d", "3D", 0, RandAxisFlipd(KEYS, 1)))
Пример #9
0
 def test_value(self, input_param, input_data, expected_value):
     """Center-crop ``input_data`` and compare its ``img`` entry to ``expected_value``."""
     cropper = CenterSpatialCropd(**input_param)
     cropped = cropper(input_data)["img"]
     assert_allclose(cropped, expected_value, type_test=False)
Пример #10
0
 def test_shape(self, input_param, input_data, expected_shape):
     """Center-crop ``input_data`` and check the shape of its ``img`` entry."""
     output = CenterSpatialCropd(**input_param)(input_data)
     cropped_shape = output["img"].shape
     self.assertTupleEqual(cropped_shape, expected_shape)
Пример #11
0
    def __init__(self,
                 data_dir: Path,
                 cache_dir: Path,
                 splits: Sequence[Sequence[Dict]],
                 target: str = "malignancy",
                 min_anns: int = 3,
                 exclude_labels: Sequence[int] = [3],
                 label_mapping: Tuple[Sequence[int], Sequence[int]] = ([1, 2, 4, 5], [0, 0, 1, 1]),
                 batch_size: int = 16,
                 spacing: Sequence[float] = (1.5, 1.5, 2.0),
                 roi_size: Sequence[int] = [40, 40, 30],
                 aug_prob: float = 0.0,
                 seed: int = 47):
        """Handles all things data related for classifying lung nodules from the LIDC-IDRI dataset. Adheres to the PyTorch Lightning DataModule interface.

        Args:
            data_dir (Path): Directory with preprocessed LIDC dataset, as outputted by `preprocess_data` script.
            cache_dir (Path): Directory where deterministic transformations of input samples will be cached.
            splits (Sequence[Sequence[Dict]]): Dictionaries containing metadata of training and validation sets. See `split_data` script for more information.
            target (str): Target variable, as denoted in splits dictionary. Defaults to malignancy.
            min_anns (int): Minimum number of annotations required for including nodule. Defaults to 3.
            exclude_labels (Sequence[int]): Label values to exclude in dataset. Defaults to [3].
            label_mapping (Tuple[Sequence[int], Sequence[int]]): ``(orig_labels, target_labels)`` passed to
                ``MapLabelValued`` for discretization. Defaults to ([1, 2, 4, 5], [0, 0, 1, 1]).
            batch_size (int, optional): Batch size for training and validation. Defaults to 16.
            spacing (Sequence[float], optional): Pixel spacing (in mm) that inputs will be transformed into. Defaults to (1.5, 1.5, 2.0).
            roi_size (Sequence[int], optional): Shape that inputs will be transformed into. Defaults to [40, 40, 30].
            aug_prob (float): Probability passed to the training-only ``RandAffined`` augmentation. Defaults to 0.0.
            seed (int, optional): Random seed for transformations etc. Defaults to 47.
        """
        super().__init__()
        self.data_dir = data_dir
        self.cache_dir = cache_dir
        self.splits = splits
        self.batch_size = batch_size
        self.spacing = spacing
        self.roi_size = roi_size
        self.seed = seed
        self.target = target
        self.min_anns = min_anns
        self.exclude_labels = exclude_labels
        self.label_mapping = label_mapping
        self.aug_prob = aug_prob
        self.hparams = {
            "batch_size": self.batch_size,
            "spacing": self.spacing,
            "roi_size": self.roi_size,
            "seed": self.seed,
            "target": self.target,
            "min_anns": self.min_anns,
            "exclude_labels": self.exclude_labels,
            "label_mapping": self.label_mapping,
            # Previously omitted, so the augmentation strength was never logged.
            "aug_prob": self.aug_prob,
        }
        reader = LIDCReader(self.data_dir, nodule_mode=True)

        def preprocessing():
            # Deterministic steps shared by training and validation. Each call
            # builds new transform objects, so the two pipelines share none.
            return [
                LoadImaged(keys=["image"], reader=reader),
                AddChanneld(keys=["image"]),
                Spacingd(keys=["image"], pixdim=self.spacing, mode="bilinear"),
                ScaleIntensityd(keys=["image"]),
                # Pad up to roi_size, then center-crop down to it.
                SpatialPadd(keys=["image"],
                            spatial_size=self.roi_size,
                            mode="constant"),
                CenterSpatialCropd(keys=["image"], roi_size=self.roi_size),
                MapLabelValued(keys=["label"],
                               orig_labels=self.label_mapping[0],
                               target_labels=self.label_mapping[1]),
            ]

        def finalize():
            # Convert to tensors and keep only the image and label entries.
            return [
                ToTensord(keys=["image", "label"]),
                SelectItemsd(keys=["image", "label"]),
            ]

        self.train_transforms = Compose(
            preprocessing()
            + [
                # Training-only random affine augmentation, applied with probability aug_prob.
                RandAffined(
                    keys=["image"],
                    spatial_size=self.roi_size,
                    prob=self.aug_prob,
                    mode="bilinear",
                    rotate_range=(np.pi / 18, np.pi / 18, np.pi / 4),
                    scale_range=(0.1, 0.1, 0.1),
                    padding_mode="border",
                ),
            ]
            + finalize()
        )
        self.val_transforms = Compose(preprocessing() + finalize())
    CastToTyped(keys="input", dtype=np.float32),
    NormalizeIntensityd(keys="input", nonzero=False),
    Lambdad(keys="input", func=lambda x: x.clip(-20, 20)),
    EnsureTyped(keys=("input", "mask")),
])

cfg.val_aug = Compose(
    [
        # Geometry: resize (size_mode="longest") to 1120, pad to 1120x1120,
        # then center-crop to the first two entries of cfg.img_size.
        Resized(keys=("input", "mask"), spatial_size=1120, size_mode="longest",
                mode="bilinear", align_corners=False),
        SpatialPadd(keys=("input", "mask"), spatial_size=(1120, 1120)),
        CenterSpatialCropd(keys=("input", "mask"),
                           roi_size=(cfg.img_size[0], cfg.img_size[1])),
        # Intensity: cast the input to float32, normalize it, clip to [-20, 20].
        CastToTyped(keys="input", dtype=np.float32),
        NormalizeIntensityd(keys="input", nonzero=False),
        Lambdad(keys="input", func=lambda x: x.clip(-20, 20)),
        EnsureTyped(keys=("input", "mask")),
    ]
)

cfg.test_aug = Compose([
    Resized(
        keys=("input", "mask"),
        spatial_size=1120,
        size_mode="longest",
        mode="bilinear",
        align_corners=False,
    ),
    SpatialPadd(keys=("input", "mask"), spatial_size=(1120, 1120)),
Пример #13
0
 def test_value(self, input_param, input_data, expected_value):
     result = CenterSpatialCropd(**input_param)(input_data)
     np.testing.assert_allclose(result["img"], expected_value)
Пример #14
0
    NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
    RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
    ToTensord(keys=["image", "label"]),
])
# Validation pipeline: the same preprocessing as training, but with a fixed
# center crop and no random augmentation.
val_transform = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        AsChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        CenterSpatialCropd(keys=["image", "label"], roi_size=[128, 128, 64]),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        ToTensord(keys=["image", "label"]),
    ]
)
"""## Quickly load data with DecathlonDataset

Here we use `DecathlonDataset` to automatically download and extract the dataset.
It inherits from MONAI's `CacheDataset`, so we set `cache_num=100` to cache 100 items for training and use the default args to cache all the items for validation.
"""

train_ds = DecathlonDataset(
    root_dir=root_dir,
    task="Task01_BrainTumour",
    transform=train_transform,
    section="training",
    download=True,