Example 1
    def test_inverse(self, _, data_name, acceptable_diff, *transforms):
        name = _  # the first parameterized argument is the test case name

        data = self.all_data[data_name]

        forwards = [data.copy()]

        # Apply forwards
        for t in transforms:
            if isinstance(t, Randomizable):
                t.set_random_state(seed=get_seed())
            forwards.append(t(forwards[-1]))

        # Check that an error is thrown when inverses are used out of order.
        t = SpatialPadd("image", [10, 5])
        with self.assertRaises(RuntimeError):
            t.inverse(forwards[-1])

        # Apply inverses
        fwd_bck = forwards[-1].copy()
        for i, t in enumerate(reversed(transforms)):
            if isinstance(t, InvertibleTransform):
                fwd_bck = t.inverse(fwd_bck)
                self.check_inverse(name, data.keys(), forwards[-i - 2],
                                   fwd_bck, forwards[-1], acceptable_diff)
Example 2
def get_data(keys):
    """Get the example data to be used.

    Use MarsAtlas, as it contains only one image (quick to download) and
    that image is parcellated.
    """
    cache_dir = os.environ.get("MONAI_DATA_DIRECTORY") or tempfile.mkdtemp()
    fname = "MarsAtlas-MNI-Colin27.zip"
    url = "https://www.dropbox.com/s/ndz8qtqblkciole/" + fname + "?dl=1"
    out_path = os.path.join(cache_dir, "MarsAtlas-MNI-Colin27")
    zip_path = os.path.join(cache_dir, fname)

    download_and_extract(url, zip_path, out_path)

    image, label = sorted(glob(os.path.join(out_path, "*.nii")))

    data = {CommonKeys.IMAGE: image, CommonKeys.LABEL: label}

    transforms = Compose([
        LoadImaged(keys),
        AddChanneld(keys),
        ScaleIntensityd(CommonKeys.IMAGE),
        Rotate90d(keys, spatial_axes=[0, 2])
    ])
    data = transforms(data)
    max_size = max(data[keys[0]].shape)
    padder = SpatialPadd(keys, (max_size, max_size, max_size))
    return padder(data)
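
A hedged usage sketch for get_data (assuming the imports above; MONAI's CommonKeys.IMAGE/LABEL map to the "image"/"label" keys):

data = get_data(keys=(CommonKeys.IMAGE, CommonKeys.LABEL))
print(data[CommonKeys.IMAGE].shape)  # after SpatialPadd, every spatial dim equals the largest original dim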
Example 3
    def test_decollation(self, batch_size=2, num_workers=2):

        im = create_test_image_2d(100, 101)[0]
        data = [{
            "image": make_nifti_image(im) if has_nib else im
        } for _ in range(6)]

        transforms = Compose([
            AddChanneld("image"),
            SpatialPadd("image", 150),
            RandFlipd("image", prob=1.0, spatial_axis=1),
            ToTensord("image"),
        ])
        # If nibabel present, read from disk
        if has_nib:
            transforms = Compose([LoadImaged("image"), transforms])

        dataset = CacheDataset(data, transforms, progress=False)
        loader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=num_workers)

        for b, batch_data in enumerate(loader):
            decollated_1 = decollate_batch(batch_data)
            decollated_2 = Decollated()(batch_data)

            for decollated in [decollated_1, decollated_2]:
                for i, d in enumerate(decollated):
                    self.check_match(dataset[b * batch_size + i], d)
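
For context, a self-contained sketch of the decollation under test: decollate_batch splits a collated batch dict back into per-sample dicts (shapes here are illustrative):

import torch
from monai.data import decollate_batch

batch = {"image": torch.zeros(2, 1, 4, 4)}  # batch of 2 channel-first images
items = decollate_batch(batch)              # -> list of 2 per-sample dicts
assert len(items) == 2 and items[0]["image"].shape == (1, 4, 4)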
Example 4
    def test_fail(self):

        t1 = SpatialPadd("image", [10, 5])
        data = t1(self.all_data["2D"])

        # Check that an error is thrown when inverses are used out of order:
        # t2 was never applied forwards, so its inverse does not match the trace.
        t2 = ResizeWithPadOrCropd("image", [10, 5])
        with self.assertRaises(RuntimeError):
            t2.inverse(data)
Example 5
    def test_inverse_inferred_seg(self, extra_transform):

        test_data = []
        for _ in range(20):
            image, label = create_test_image_2d(100, 101)
            test_data.append({
                "image": image,
                "label": label.astype(np.float32)
            })

        batch_size = 10
        # multiprocessing workers only on linux (0 elsewhere, e.g. macOS/Windows)
        num_workers = 2 if sys.platform == "linux" else 0
        transforms = Compose([
            AddChanneld(KEYS),
            SpatialPadd(KEYS, (150, 153)), extra_transform
        ])

        dataset = CacheDataset(test_data, transform=transforms, progress=False)
        loader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=num_workers)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = UNet(spatial_dims=2,
                     in_channels=1,
                     out_channels=1,
                     channels=(2, 4),
                     strides=(1, )).to(device)

        data = first(loader)
        self.assertEqual(data["image"].shape[0], batch_size * NUM_SAMPLES)

        labels = data["label"].to(device)
        self.assertIsInstance(labels, MetaTensor)
        segs = model(labels).detach().cpu()
        segs_decollated = decollate_batch(segs)
        self.assertIsInstance(segs_decollated[0], MetaTensor)
        # inverse of individual segmentation
        seg_metatensor = first(segs_decollated)
        # test to convert interpolation mode for 1 data of model output batch
        convert_applied_interp_mode(seg_metatensor.applied_operations,
                                    mode="nearest",
                                    align_corners=None)

        # manually invert the last crop samples
        xform = seg_metatensor.applied_operations.pop(-1)
        shape_before_extra_xform = xform["orig_size"]
        resizer = ResizeWithPadOrCrop(spatial_size=shape_before_extra_xform)
        with resizer.trace_transform(False):
            seg_metatensor = resizer(seg_metatensor)

        with allow_missing_keys_mode(transforms):
            inv_seg = transforms.inverse({"label": seg_metatensor})["label"]
        self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)
Example 6
    def test_inverse_inferred_seg(self):

        test_data = []
        for _ in range(20):
            image, label = create_test_image_2d(100, 101)
            test_data.append({
                "image": image,
                "label": label.astype(np.float32)
            })

        batch_size = 10
        # num workers = 0 for mac
        num_workers = 2 if sys.platform != "darwin" else 0
        transforms = Compose([
            AddChanneld(KEYS),
            SpatialPadd(KEYS, (150, 153)),
            CenterSpatialCropd(KEYS, (110, 99))
        ])
        num_invertible_transforms = sum(1 for i in transforms.transforms
                                        if isinstance(i, InvertibleTransform))

        dataset = CacheDataset(test_data, transform=transforms, progress=False)
        loader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=num_workers)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = UNet(
            dimensions=2,
            in_channels=1,
            out_channels=1,
            channels=(2, 4),
            strides=(2, ),
        ).to(device)

        data = first(loader)
        labels = data["label"].to(device)
        segs = model(labels).detach().cpu()
        label_transform_key = "label" + InverseKeys.KEY_SUFFIX.value  # -> "label_transforms"
        segs_dict = {
            "label": segs,
            label_transform_key: data[label_transform_key]
        }

        segs_dict_decollated = decollate_batch(segs_dict)

        # inverse of individual segmentation
        seg_dict = first(segs_dict_decollated)
        with allow_missing_keys_mode(transforms):
            inv_seg = transforms.inverse(seg_dict)["label"]
        self.assertEqual(len(data["label_transforms"]),
                         num_invertible_transforms)
        self.assertEqual(len(seg_dict["label_transforms"]),
                         num_invertible_transforms)
        self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)
Example 7
    def get_transforms(self):
        self.logger.info("Getting transforms...")
        # Setup transforms of data sets
        train_transforms = Compose([
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            NormalizeIntensityd(keys=["image"]),
            SpatialPadd(keys=["image", "label"],
                        spatial_size=self.pad_crop_shape),
            RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
            RandSpatialCropd(keys=["image", "label"],
                             roi_size=self.pad_crop_shape,
                             random_center=True,
                             random_size=False),
            ToTensord(keys=["image", "label"]),
        ])

        val_transforms = Compose([
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            NormalizeIntensityd(keys=["image"]),
            SpatialPadd(keys=["image", "label"],
                        spatial_size=self.pad_crop_shape),
            RandSpatialCropd(
                keys=["image", "label"],
                roi_size=self.pad_crop_shape,
                random_center=True,
                random_size=False,
            ),
            ToTensord(keys=["image", "label"]),
        ])

        test_transforms = Compose([
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            NormalizeIntensityd(keys=["image"]),
            ToTensord(keys=["image", "label"]),
        ])

        return train_transforms, val_transforms, test_transforms
Example 8
def get_test_transforms(end_image_shape):
    test_transforms = Compose([
        LoadNiftid(keys=["image"]),
        AddChanneld(keys=["image"]),
        Orientationd(keys=["image"], axcodes="RAS"),
        Winsorized(keys=["image"]),
        NormalizeIntensityd(keys=["image"]),
        ScaleIntensityd(keys=["image"]),
        SpatialPadd(keys=["image"], spatial_size=end_image_shape),
        ToTensord(keys=["image"]),
    ])
    return test_transforms
Example 9
 def test_multiple(self):
     orig_states = [True, False]
     ts = [
         SpatialPadd(["image", "label"], 10, allow_missing_keys=i)
         for i in orig_states
     ]
     with allow_missing_keys_mode(ts):
         for t in ts:
             self.assertTrue(t.allow_missing_keys)
             # and that transform works even though key is missing
             _ = t(self.data)
     for t, o_s in zip(ts, orig_states):
         self.assertEqual(t.allow_missing_keys, o_s)
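
A plausible fixture for these allow_missing_keys tests (hypothetical; the real self.data is built in setUp). The "label" key is deliberately absent, so padding over ("image", "label") only succeeds while allow_missing_keys is in effect:

import numpy as np

self.data = {"image": np.zeros((1, 10, 10))}  # no "label" key on purpose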
Example 10
 def test_map_transform(self):
     for amk in [True, False]:
         t = SpatialPadd(["image", "label"], 10, allow_missing_keys=amk)
         with allow_missing_keys_mode(t):
             # check state is True
             self.assertTrue(t.allow_missing_keys)
             # and that transform works even though key is missing
             _ = t(self.data)
         # check it has returned to original state
         self.assertEqual(t.allow_missing_keys, amk)
         if not amk:
             # should fail because amks==False and key is missing
             with self.assertRaises(KeyError):
                 _ = t(self.data)
Example 11
def get_xforms(args, mode="train", keys=("image", "label")):
    """returns a composed transform for train/val/infer."""

    xforms = [
        LoadNiftid(keys),
        AddChanneld(keys),
        Orientationd(keys, axcodes="LPS"),
        Spacingd(keys,
                 pixdim=(1.25, 1.25, 5.0),
                 mode=("bilinear", "nearest")[:len(keys)]),
        ScaleIntensityRanged(keys[0],
                             a_min=-1000.0,
                             a_max=500.0,
                             b_min=0.0,
                             b_max=1.0,
                             clip=True),
    ]
    if mode == "train":
        xforms.extend([
            SpatialPadd(keys,
                        spatial_size=(args.patch_size, args.patch_size, -1),
                        mode="reflect"),  # ensure at least 192x192
            RandAffined(
                keys,
                prob=0.15,
                rotate_range=(-0.05, 0.05),
                scale_range=(-0.1, 0.1),
                mode=("bilinear", "nearest"),
                as_tensor_output=False,
            ),
            RandCropByPosNegLabeld(keys,
                                   label_key=keys[1],
                                   spatial_size=(args.patch_size,
                                                 args.patch_size,
                                                 args.n_slice),
                                   num_samples=3),
            RandGaussianNoised(keys[0], prob=0.15, std=0.01),
            RandFlipd(keys, spatial_axis=0, prob=0.5),
            RandFlipd(keys, spatial_axis=1, prob=0.5),
            RandFlipd(keys, spatial_axis=2, prob=0.5),
        ])
        dtype = (np.float32, np.uint8)
    if mode == "val":
        dtype = (np.float32, np.uint8)
    if mode == "infer":
        dtype = (np.float32, )
    xforms.extend([CastToTyped(keys, dtype=dtype), ToTensord(keys)])
    return monai.transforms.Compose(xforms)
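
A hedged usage sketch for get_xforms; the args fields are assumptions inferred from how they are read above:

from types import SimpleNamespace

args = SimpleNamespace(patch_size=192, n_slice=16)  # hypothetical values
train_transform = get_xforms(args, mode="train", keys=("image", "label"))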
Example 12
    def test_inverse_inferred_seg(self, extra_transform):

        test_data = []
        for _ in range(20):
            image, label = create_test_image_2d(100, 101)
            test_data.append({"image": image, "label": label.astype(np.float32)})

        batch_size = 10
        # multiprocessing workers only on linux (0 elsewhere, e.g. macOS/Windows)
        num_workers = 2 if sys.platform == "linux" else 0
        transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153)), extra_transform])
        num_invertible_transforms = sum(1 for i in transforms.transforms if isinstance(i, InvertibleTransform))

        dataset = CacheDataset(test_data, transform=transforms, progress=False)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = UNet(spatial_dims=2, in_channels=1, out_channels=1, channels=(2, 4), strides=(2,)).to(device)

        data = first(loader)
        self.assertEqual(len(data["label_transforms"]), num_invertible_transforms)
        self.assertEqual(data["image"].shape[0], batch_size * NUM_SAMPLES)

        labels = data["label"].to(device)
        segs = model(labels).detach().cpu()
        label_transform_key = "label" + InverseKeys.KEY_SUFFIX
        segs_dict = {"label": segs, label_transform_key: data[label_transform_key]}

        segs_dict_decollated = decollate_batch(segs_dict)
        # inverse of individual segmentation
        seg_dict = first(segs_dict_decollated)
        # test to convert interpolation mode for 1 data of model output batch
        convert_inverse_interp_mode(seg_dict, mode="nearest", align_corners=None)

        with allow_missing_keys_mode(transforms):
            inv_seg = transforms.inverse(seg_dict)["label"]
        self.assertEqual(len(data["label_transforms"]), num_invertible_transforms)
        self.assertEqual(len(seg_dict["label_transforms"]), num_invertible_transforms)
        self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)

        # Inverse of batch
        batch_inverter = BatchInverseTransform(transforms, loader, collate_fn=no_collation, detach=True)
        with allow_missing_keys_mode(transforms):
            inv_batch = batch_inverter(segs_dict)
        self.assertEqual(inv_batch[0]["label"].shape[1:], test_data[0]["label"].shape)
Example 13
 def test_compose(self):
     amks = [True, False, True]
     t = Compose([
         SpatialPadd(["image", "label"], 10, allow_missing_keys=amk)
         for amk in amks
     ])
     with allow_missing_keys_mode(t):
         # check states are all True
         for _t in t.transforms:
             self.assertTrue(_t.allow_missing_keys)
         # and that transform works even though key is missing
         _ = t(self.data)
     # check they've returned to original state
     for _t, amk in zip(t.transforms, amks):
         self.assertEqual(_t.allow_missing_keys, amk)
     # should fail because not all amks==True and key is missing
     with self.assertRaises((KeyError, RuntimeError)):
         _ = t(self.data)
Example 14
def get_xforms_with_synthesis(mode="synthesis", keys=("image", "label"), keys2=("image", "label", "synthetic_lesion")):
    """returns a composed transform for train/val/infer."""

    xforms = [
        LoadImaged(keys),
        AddChanneld(keys),
        Orientationd(keys, axcodes="LPS"),
        Spacingd(keys, pixdim=(1.25, 1.25, 5.0), mode=("bilinear", "nearest")[: len(keys)]),
        ScaleIntensityRanged(keys[0], a_min=-1000.0, a_max=500.0, b_min=0.0, b_max=1.0, clip=True),
        CopyItemsd(keys, times=1, names=["image_1", "label_1"]),
    ]
    if mode == "synthesis":
        xforms.extend([
                  SpatialPadd(keys, spatial_size=(192, 192, -1), mode="reflect"),  # ensure at least 192x192
                  RandCropByPosNegLabeld(keys, label_key=keys[1], spatial_size=(192, 192, 16), num_samples=3),
                  TransCustom(keys, path_synthesis, read_cea_aug_slice2, 
                              pseudo_healthy_with_texture, scans_syns, decreasing_sequence, GEN=15,
                              POST_PROCESS=True, mask_outer_ring=True, new_value=.5),
                  RandAffined(
                      # keys,
                      keys2,
                      prob=0.15,
                      rotate_range=(0.05, 0.05, None),  # 3 parameters control the transform on 3 dimensions
                      scale_range=(0.1, 0.1, None), 
                      mode=("bilinear", "nearest", "bilinear"),
                      # mode=("bilinear", "nearest"),
                      as_tensor_output=False
                  ),
                  RandGaussianNoised((keys2[0], keys2[2]), prob=0.15, std=0.01),
                  # RandGaussianNoised(keys[0], prob=0.15, std=0.01),
                  RandFlipd(keys, spatial_axis=0, prob=0.5),
                  RandFlipd(keys, spatial_axis=1, prob=0.5),
                  RandFlipd(keys, spatial_axis=2, prob=0.5),
                  TransCustom2(0.333)
              ])
    dtype = (np.float32, np.uint8)
    # dtype = (np.float32, np.uint8, np.float32)
    xforms.extend([CastToTyped(keys, dtype=dtype)])
    return monai.transforms.Compose(xforms)
Example 15
 def train_pre_transforms(self, context: Context):
     return [
         LoadImaged(keys=("image", "label"), reader="ITKReader"),
         NormalizeLabelsInDatasetd(
             keys="label",
             label_names=self._labels),  # Specially for missing labels
         EnsureChannelFirstd(keys=("image", "label")),
         Spacingd(keys=("image", "label"),
                  pixdim=self.target_spacing,
                  mode=("bilinear", "nearest")),
         CropForegroundd(keys=("image", "label"), source_key="image"),
         SpatialPadd(keys=("image", "label"),
                     spatial_size=self.spatial_size),
         ScaleIntensityRanged(keys="image",
                              a_min=-175,
                              a_max=250,
                              b_min=0.0,
                              b_max=1.0,
                              clip=True),
         RandCropByPosNegLabeld(
             keys=("image", "label"),
             label_key="label",
             spatial_size=self.spatial_size,
             pos=1,
             neg=1,
             num_samples=self.num_samples,
             image_key="image",
             image_threshold=0,
         ),
         EnsureTyped(keys=("image", "label"), device=context.device),
         RandFlipd(keys=("image", "label"), spatial_axis=[0], prob=0.10),
         RandFlipd(keys=("image", "label"), spatial_axis=[1], prob=0.10),
         RandFlipd(keys=("image", "label"), spatial_axis=[2], prob=0.10),
         RandRotate90d(keys=("image", "label"), prob=0.10, max_k=3),
         RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
         SelectItemsd(keys=("image", "label")),
     ]
Example 16
def run_training(train_file_list, valid_file_list, config_info):
    """
    Pipeline to train a dynUNet segmentation model in MONAI. It is composed of the following main blocks:
        * Data Preparation: Extract the filenames and prepare the training/validation processing transforms
        * Load Data: Load training and validation data to PyTorch DataLoader
        * Network Preparation: Define the network, loss function, optimiser and learning rate scheduler
        * MONAI Evaluator: Initialise the dynUNet evaluator, i.e. the class providing utilities to perform validation
            during training. Attach handlers to save the best model on the validation set. A 2D sliding window approach
            on the 3D volume is used at evaluation. The mean 3D Dice is used as validation metric.
        * MONAI Trainer: Initialise the dynUNet trainer, i.e. the class providing utilities to perform the training loop.
        * Run training: The MONAI trainer is run, performing training and validation during training.
    Args:
        train_file_list: .txt or .csv file (with no header) storing two-columns filenames for training:
            image filename in the first column and segmentation filename in the second column.
            The two columns should be separated by a comma.
            See monaifbs/config/mock_train_file_list_for_dynUnet_training.txt for an example of the expected format.
        valid_file_list: .txt or .csv file (with no header) storing two-columns filenames for validation:
            image filename in the first column and segmentation filename in the second column.
            The two columns should be separated by a comma.
            See monaifbs/config/mock_valid_file_list_for_dynUnet_training.txt for an example of the expected format.
        config_info: dict, contains configuration parameters for sampling, network and training.
            See monaifbs/config/monai_dynUnet_training_config.yml for an example of the expected fields.
    """

    """
    Read input and configuration parameters
    """
    # print MONAI config information
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print_config()

    # print to log the parameter setups
    print(yaml.dump(config_info))

    # extract network parameters, perform checks/set defaults if not present and print them to log
    if 'seg_labels' in config_info['training'].keys():
        seg_labels = config_info['training']['seg_labels']
    else:
        seg_labels = [1]
    nr_out_channels = len(seg_labels)
    print("Considering the following {} labels in the segmentation: {}".format(nr_out_channels, seg_labels))
    patch_size = config_info["training"]["inplane_size"] + [1]
    print("Considering patch size = {}".format(patch_size))

    spacing = config_info["training"]["spacing"]
    print("Bringing all images to spacing = {}".format(spacing))

    if 'model_to_load' in config_info['training'].keys() and config_info['training']['model_to_load'] is not None:
        model_to_load = config_info['training']['model_to_load']
        if not os.path.exists(model_to_load):
            raise FileNotFoundError("Cannot find model: {}".format(model_to_load))
        else:
            print("Loading model from {}".format(model_to_load))
    else:
        model_to_load = None

    # set up either GPU or CPU usage
    if torch.cuda.is_available():
        print("\n#### GPU INFORMATION ###")
        print("Using device number: {}, name: {}\n".format(torch.cuda.current_device(), torch.cuda.get_device_name()))
        current_device = torch.device("cuda:0")
    else:
        current_device = torch.device("cpu")
        print("Using device: {}".format(current_device))

    # set determinism if required
    if 'manual_seed' in config_info['training'].keys() and config_info['training']['manual_seed'] is not None:
        seed = config_info['training']['manual_seed']
    else:
        seed = None
    if seed is not None:
        print("Using determinism with seed = {}\n".format(seed))
        set_determinism(seed=seed)

    """
    Setup data output directory
    """
    out_model_dir = os.path.join(config_info['output']['out_dir'],
                                 datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' +
                                 config_info['output']['out_postfix'])
    print("Saving to directory {}\n".format(out_model_dir))
    # create cache directory to store results for Persistent Dataset
    if 'cache_dir' in config_info['output'].keys():
        out_cache_dir = config_info['output']['cache_dir']
    else:
        out_cache_dir = os.path.join(out_model_dir, 'persistent_cache')
    persistent_cache: Path = Path(out_cache_dir)
    persistent_cache.mkdir(parents=True, exist_ok=True)

    """
    Data preparation
    """
    # Read the input files for training and validation
    print("*** Loading input data for training...")

    train_files = create_data_list_of_dictionaries(train_file_list)
    print("Number of inputs for training = {}".format(len(train_files)))

    val_files = create_data_list_of_dictionaries(valid_file_list)
    print("Number of inputs for validation = {}".format(len(val_files)))

    # Define MONAI processing transforms for the training data. This includes:
    # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3
    # - CropForegroundd: Reduce the background from the MR image
    # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the
    #       last direction (lowest resolution) to avoid introducing motion artefact resampling errors
    # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed
    # - NormalizeIntensityd: Apply whitening
    # - RandSpatialCropd: Crop a random patch from the input with size [B, C, N, M, 1]
    # - SqueezeDimd: Convert the 3D patch to a 2D one as input to the network (i.e. bring it to size [B, C, N, M])
    # - Apply data augmentation (RandZoomd, RandRotated, RandGaussianNoised, RandGaussianSmoothd, RandScaleIntensityd,
    #       RandFlipd)
    # - ToTensor: convert to pytorch tensor
    train_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            InPlaneSpacingd(
                keys=["image", "label"],
                pixdim=spacing,
                mode=("bilinear", "nearest"),
            ),
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size,
                        mode=["constant", "edge"]),
            NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
            RandSpatialCropd(keys=["image", "label"], roi_size=patch_size, random_size=False),
            SqueezeDimd(keys=["image", "label"], dim=-1),
            RandZoomd(
                keys=["image", "label"],
                min_zoom=0.9,
                max_zoom=1.2,
                mode=("bilinear", "nearest"),
                align_corners=(True, None),
                prob=0.16,
            ),
            RandRotated(keys=["image", "label"], range_x=90, range_y=90, prob=0.2,
                        keep_size=True, mode=["bilinear", "nearest"],
                        padding_mode=["zeros", "border"]),
            RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
            RandGaussianSmoothd(
                keys=["image"],
                sigma_x=(0.5, 1.15),
                sigma_y=(0.5, 1.15),
                sigma_z=(0.5, 1.15),
                prob=0.15,
            ),
            RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.15),
            RandFlipd(["image", "label"], spatial_axis=[0, 1], prob=0.5),
            ToTensord(keys=["image", "label"]),
        ]
    )

    # Define MONAI processing transforms for the validation data
    # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3
    # - CropForegroundd: Reduce the background from the MR image
    # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the
    #       last direction (lowest resolution) to avoid introducing motion artefact resampling errors
    # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed
    # - NormalizeIntensityd: Apply whitening
    # - ToTensor: convert to pytorch tensor
    # NOTE: The validation data is kept 3D as a 2D sliding window approach is used throughout the volume at inference
    val_transforms = Compose(
        [
            LoadNiftid(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            InPlaneSpacingd(
                keys=["image", "label"],
                pixdim=spacing,
                mode=("bilinear", "nearest"),
            ),
            SpatialPadd(keys=["image", "label"], spatial_size=patch_size, mode=["constant", "edge"]),
            NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True),
            ToTensord(keys=["image", "label"]),
        ]
    )

    """
    Load data 
    """
    # create training data loader
    train_ds = PersistentDataset(data=train_files, transform=train_transforms,
                                 cache_dir=persistent_cache)
    train_loader = DataLoader(train_ds,
                              batch_size=config_info['training']['batch_size_train'],
                              shuffle=True,
                              num_workers=config_info['device']['num_workers'])
    check_train_data = misc.first(train_loader)
    print("Training data tensor shapes:")
    print("Image = {}; Label = {}".format(check_train_data["image"].shape, check_train_data["label"].shape))

    # create validation data loader
    if config_info['training']['batch_size_valid'] != 1:
        raise Exception("Batch sizes different from 1 are currently not supported at validation")
    val_ds = PersistentDataset(data=val_files, transform=val_transforms, cache_dir=persistent_cache)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            shuffle=False,
                            num_workers=config_info['device']['num_workers'])
    check_valid_data = misc.first(val_loader)
    print("Validation data tensor shapes (Example):")
    print("Image = {}; Label = {}\n".format(check_valid_data["image"].shape, check_valid_data["label"].shape))

    """
    Network preparation
    """
    print("*** Preparing the network ...")
    # automatically extracts the strides and kernels based on nnU-Net empirical rules
    spacings = spacing[:2]
    sizes = patch_size[:2]
    strides, kernels = [], []
    while True:
        spacing_ratio = [sp / min(spacings) for sp in spacings]
        stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)
    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])
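    # Worked example (assuming isotropic in-plane spacing and a 192x192 patch):
    # each pass halves the sizes while they stay >= 8 (192 -> 96 -> 48 -> 24 -> 12 -> 6),
    # yielding five [2, 2] strides, so after the insert/append above
    # strides == [[1, 1]] + [[2, 2]] * 5 and kernels == [[3, 3]] * 6.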

    # initialise the network
    net = DynUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=nr_out_channels,
        kernel_size=kernels,
        strides=strides,
        upsample_kernel_size=strides[1:],
        norm_name="instance",
        deep_supervision=True,
        deep_supr_num=2,
        res_block=False,
    ).to(current_device)
    print(net)

    # define the loss function
    loss_function = choose_loss_function(nr_out_channels, config_info)

    # define the optimiser and the learning rate scheduler
    opt = torch.optim.SGD(net.parameters(), lr=float(config_info['training']['lr']), momentum=0.95)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        opt, lr_lambda=lambda epoch: (1 - epoch / config_info['training']['nr_train_epochs']) ** 0.9
    )

    """
    MONAI evaluator
    """
    print("*** Preparing the dynUNet evaluator engine...\n")
    # val_post_transforms = Compose(
    #     [
    #         Activationsd(keys="pred", sigmoid=True),
    #     ]
    # )
    val_handlers = [
        StatsHandler(output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=os.path.join(out_model_dir, "valid"),
                                output_transform=lambda x: None,
                                global_epoch_transform=lambda x: trainer.state.iteration),
        CheckpointSaver(save_dir=out_model_dir, save_dict={"net": net, "opt": opt}, save_key_metric=True,
                        file_prefix='best_valid'),
    ]
    if config_info['output']['val_image_to_tensorboad']:
        val_handlers.append(TensorBoardImageHandler(log_dir=os.path.join(out_model_dir, "valid"),
                                                    batch_transform=lambda x: (x["image"], x["label"]),
                                                    output_transform=lambda x: x["pred"], interval=2))

    # Define customized evaluator
    class DynUNetEvaluator(SupervisedEvaluator):
        def _iteration(self, engine, batchdata):
            inputs, targets = self.prepare_batch(batchdata)
            inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device)
            flip_inputs_1 = torch.flip(inputs, dims=(2,))
            flip_inputs_2 = torch.flip(inputs, dims=(3,))
            flip_inputs_3 = torch.flip(inputs, dims=(2, 3))

            def _compute_pred():
                pred = self.inferer(inputs, self.network)
                # use random flipping as data augmentation at inference
                flip_pred_1 = torch.flip(self.inferer(flip_inputs_1, self.network), dims=(2,))
                flip_pred_2 = torch.flip(self.inferer(flip_inputs_2, self.network), dims=(3,))
                flip_pred_3 = torch.flip(self.inferer(flip_inputs_3, self.network), dims=(2, 3))
                return (pred + flip_pred_1 + flip_pred_2 + flip_pred_3) / 4

            # execute forward computation
            self.network.eval()
            with torch.no_grad():
                if self.amp:
                    with torch.cuda.amp.autocast():
                        predictions = _compute_pred()
                else:
                    predictions = _compute_pred()
            return {"image": inputs, "label": targets, "pred": predictions}

    evaluator = DynUNetEvaluator(
        device=current_device,
        val_data_loader=val_loader,
        network=net,
        inferer=SlidingWindowInferer2D(roi_size=patch_size, sw_batch_size=4, overlap=0.0),
        post_transform=None,
        key_val_metric={
            "Mean_dice": MeanDice(
                include_background=False,
                to_onehot_y=True,
                mutually_exclusive=True,
                output_transform=lambda x: (x["pred"], x["label"]),
            )
        },
        val_handlers=val_handlers,
        amp=False,
    )

    """
    MONAI trainer
    """
    print("*** Preparing the dynUNet trainer engine...\n")
    # train_post_transforms = Compose(
    #     [
    #         Activationsd(keys="pred", sigmoid=True),
    #     ]
    # )

    validation_every_n_epochs = config_info['training']['validation_every_n_epochs']
    epoch_len = len(train_ds) // train_loader.batch_size
    validation_every_n_iters = validation_every_n_epochs * epoch_len

    # define event handlers for the trainer
    writer_train = SummaryWriter(log_dir=os.path.join(out_model_dir, "train"))
    train_handlers = [
        LrScheduleHandler(lr_scheduler=scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=validation_every_n_iters, epoch_level=False),
        StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
        TensorBoardStatsHandler(summary_writer=writer_train,
                                log_dir=os.path.join(out_model_dir, "train"), tag_name="Loss",
                                output_transform=lambda x: x["loss"],
                                global_epoch_transform=lambda x: trainer.state.iteration),
        CheckpointSaver(save_dir=out_model_dir, save_dict={"net": net, "opt": opt},
                        save_final=True,
                        save_interval=2, epoch_level=True,
                        n_saved=config_info['output']['max_nr_models_saved']),
    ]
    if model_to_load is not None:
        train_handlers.append(CheckpointLoader(load_path=model_to_load, load_dict={"net": net, "opt": opt}))

    # define customized trainer
    class DynUNetTrainer(SupervisedTrainer):
        def _iteration(self, engine, batchdata):
            inputs, targets = self.prepare_batch(batchdata)
            inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device)

            def _compute_loss(preds, label):
                labels = [label] + [interpolate(label, pred.shape[2:]) for pred in preds[1:]]
                return sum([0.5 ** i * self.loss_function(p, l) for i, (p, l) in enumerate(zip(preds, labels))])

            self.network.train()
            self.optimizer.zero_grad()
            if self.amp and self.scaler is not None:
                with torch.cuda.amp.autocast():
                    predictions = self.inferer(inputs, self.network)
                    loss = _compute_loss(predictions, targets)
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                predictions = self.inferer(inputs, self.network)
                loss = _compute_loss(predictions, targets).mean()
                loss.backward()
                self.optimizer.step()
            return {"image": inputs, "label": targets, "pred": predictions, "loss": loss.item()}

    trainer = DynUNetTrainer(
        device=current_device,
        max_epochs=config_info['training']['nr_train_epochs'],
        train_data_loader=train_loader,
        network=net,
        optimizer=opt,
        loss_function=loss_function,
        inferer=SimpleInferer(),
        post_transform=None,
        key_train_metric=None,
        train_handlers=train_handlers,
        amp=False,
    )

    """
    Run training
    """
    print("*** Run training...")
    trainer.run()
    print("Done!")
Example 17
def get_task_transforms(mode, task_id, pos_sample_num, neg_sample_num,
                        num_samples):
    if mode != "test":
        keys = ["image", "label"]
    else:
        keys = ["image"]

    load_transforms = [
        LoadImaged(keys=keys),
        EnsureChannelFirstd(keys=keys),
    ]
    # 2. sampling
    sample_transforms = [
        PreprocessAnisotropic(
            keys=keys,
            clip_values=clip_values[task_id],
            pixdim=spacing[task_id],
            normalize_values=normalize_values[task_id],
            model_mode=mode,
        ),
    ]
    # 3. spatial transforms
    if mode == "train":
        other_transforms = [
            SpatialPadd(keys=["image", "label"],
                        spatial_size=patch_size[task_id]),
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=patch_size[task_id],
                pos=pos_sample_num,
                neg=neg_sample_num,
                num_samples=num_samples,
                image_key="image",
                image_threshold=0,
            ),
            RandZoomd(
                keys=["image", "label"],
                min_zoom=0.9,
                max_zoom=1.2,
                mode=("trilinear", "nearest"),
                align_corners=(True, None),
                prob=0.15,
            ),
            RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
            RandGaussianSmoothd(
                keys=["image"],
                sigma_x=(0.5, 1.15),
                sigma_y=(0.5, 1.15),
                sigma_z=(0.5, 1.15),
                prob=0.15,
            ),
            RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.15),
            RandFlipd(["image", "label"], spatial_axis=[0], prob=0.5),
            RandFlipd(["image", "label"], spatial_axis=[1], prob=0.5),
            RandFlipd(["image", "label"], spatial_axis=[2], prob=0.5),
            CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)),
            EnsureTyped(keys=["image", "label"]),
        ]
    elif mode == "validation":
        other_transforms = [
            CastToTyped(keys=["image", "label"], dtype=(np.float32, np.uint8)),
            EnsureTyped(keys=["image", "label"]),
        ]
    else:
        other_transforms = [
            CastToTyped(keys=["image"], dtype=(np.float32)),
            EnsureTyped(keys=["image"]),
        ]

    all_transforms = load_transforms + sample_transforms + other_transforms
    return Compose(all_transforms)
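
A hedged usage sketch; the task id and sampling counts are hypothetical, and the module-level lookup tables (clip_values, spacing, patch_size, normalize_values) must be defined as in the original task configs:

train_transform = get_task_transforms(
    mode="train", task_id="04", pos_sample_num=2, neg_sample_num=1, num_samples=3
)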
Example 18
        partial(CenterSpatialCropd, roi_size=-3),
        partial(RandSpatialCropd, roi_size=-3),
        partial(SpatialPadd, spatial_size=15),
        partial(BorderPadd, spatial_border=[15, 16]),
        partial(CenterSpatialCropd, roi_size=30),
        partial(SpatialCropd, roi_center=10, roi_size=100),
        partial(SpatialCropd, roi_start=3, roi_end=100),
):
    TESTS.append((t.func.__name__ + "bad 1D even", "1D even", 0,
                  t(KEYS)))  # type: ignore

TESTS.append((
    "SpatialPadd (x2) 2d",
    "2D",
    0,
    SpatialPadd(KEYS, spatial_size=[111, 113], method="end"),
    SpatialPadd(KEYS, spatial_size=[118, 117]),
))

TESTS.append((
    "SpatialPadd 3d",
    "3D",
    0,
    SpatialPadd(KEYS, spatial_size=[112, 113, 116]),
))

TESTS.append((
    "SpatialCropd 2d",
    "2D",
    0,
    SpatialCropd(KEYS, [49, 51], [90, 89]),
Example 19
    partial(CenterSpatialCropd, roi_size=-3),
    partial(RandSpatialCropd, roi_size=-3),
    partial(SpatialPadd, spatial_size=15),
    partial(BorderPadd, spatial_border=[15, 16]),
    partial(CenterSpatialCropd, roi_size=30),
    partial(SpatialCropd, roi_center=10, roi_size=100),
    partial(SpatialCropd, roi_start=3, roi_end=100),
):
    TESTS.append((t.func.__name__ + "bad 1D even", "1D even", 0, t(KEYS)))  # type: ignore

TESTS.append(
    (
        "SpatialPadd (x2) 2d",
        "2D",
        0,
        SpatialPadd(KEYS, spatial_size=[111, 113], method="end"),
        SpatialPadd(KEYS, spatial_size=[118, 117]),
    )
)

TESTS.append(("SpatialPadd 3d", "3D", 0, SpatialPadd(KEYS, spatial_size=[112, 113, 116])))

TESTS.append(("SpatialCropd 2d", "2D", 0, SpatialCropd(KEYS, [49, 51], [90, 89])))

TESTS.append(
    (
        "SpatialCropd 3d",
        "3D",
        0,
        SpatialCropd(KEYS, roi_slices=[slice(s, e) for s, e in zip([None, None, -99], [None, -2, None])]),
    )
Example 20
 def test_pad_shape(self, input_param, input_data, expected_val):
     padder = SpatialPadd(**input_param)
     result = padder(input_data)
      self.assertEqual(result["img"].shape, expected_val.shape)
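
A plausible parameterization for this test (values are assumptions): the channel dim is left untouched and each spatial dim is padded up to spatial_size:

import numpy as np

TEST_CASE_1 = [
    {"keys": "img", "spatial_size": [15, 8, 8]},  # input_param
    {"img": np.zeros((3, 8, 8, 4))},              # input_data: C=3, spatial (8, 8, 4)
    np.zeros((3, 15, 8, 8)),                      # expected_val after padding
]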
Example 21
def main():

    # TODO: define file paths & output directory path
    json_Path = os.path.normpath('/scratch/data_2021/tcia_covid19/dataset_split_debug.json')
    data_Root = os.path.normpath('/scratch/data_2021/tcia_covid19')
    logdir_path = os.path.normpath('/home/vishwesh/monai_tutorial_testing/issue_467')

    if not os.path.exists(logdir_path):
        os.mkdir(logdir_path)

    # Load Json & Append Root Path
    with open(json_Path, 'r') as json_f:
        json_Data = json.load(json_f)

    train_Data = json_Data['training']
    val_Data = json_Data['validation']

    for idx, each_d in enumerate(train_Data):
        train_Data[idx]['image'] = os.path.join(data_Root, train_Data[idx]['image'])

    for idx, each_d in enumerate(val_Data):
        val_Data[idx]['image'] = os.path.join(data_Root, val_Data[idx]['image'])

    print('Total Number of Training Data Samples: {}'.format(len(train_Data)))
    print(train_Data)
    print('#' * 10)
    print('Total Number of Validation Data Samples: {}'.format(len(val_Data)))
    print(val_Data)
    print('#' * 10)

    # Set Determinism
    set_determinism(seed=123)

    # Define Training Transforms
    train_Transforms = Compose(
        [
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys=["image"]),
        Spacingd(keys=["image"], pixdim=(
            2.0, 2.0, 2.0), mode=("bilinear")),
        ScaleIntensityRanged(
            keys=["image"], a_min=-57, a_max=164,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        CropForegroundd(keys=["image"], source_key="image"),
        SpatialPadd(keys=["image"], spatial_size=(96, 96, 96)),
        RandSpatialCropSamplesd(keys=["image"], roi_size=(96, 96, 96), random_size=False, num_samples=2),
        CopyItemsd(keys=["image"], times=2, names=["gt_image", "image_2"], allow_missing_keys=False),
        OneOf(transforms=[
            RandCoarseDropoutd(keys=["image"], prob=1.0, holes=6, spatial_size=5, dropout_holes=True,
                               max_spatial_size=32),
            RandCoarseDropoutd(keys=["image"], prob=1.0, holes=6, spatial_size=20, dropout_holes=False,
                               max_spatial_size=64),
            ]
        ),
        RandCoarseShuffled(keys=["image"], prob=0.8, holes=10, spatial_size=8),
        # Note: if image and image_2 were augmented in the same transform call,
        # determinism would give them identical augmentations, which is not
        # wanted here; hence two separate OneOf calls are made.
        OneOf(transforms=[
            RandCoarseDropoutd(keys=["image_2"], prob=1.0, holes=6, spatial_size=5, dropout_holes=True,
                               max_spatial_size=32),
            RandCoarseDropoutd(keys=["image_2"], prob=1.0, holes=6, spatial_size=20, dropout_holes=False,
                               max_spatial_size=64),
        ]
        ),
        RandCoarseShuffled(keys=["image_2"], prob=0.8, holes=10, spatial_size=8)
        ]
    )

    check_ds = Dataset(data=train_Data, transform=train_Transforms)
    check_loader = DataLoader(check_ds, batch_size=1)
    check_data = first(check_loader)
    image = (check_data["image"][0][0])
    print(f"image shape: {image.shape}")

    # Define Network ViT backbone & Loss & Optimizer
    device = torch.device("cuda:0")
    model = ViTAutoEnc(
                in_channels=1,
                img_size=(96, 96, 96),
                patch_size=(16, 16, 16),
                pos_embed='conv',
                hidden_size=768,
                mlp_dim=3072,
    )

    model = model.to(device)

    # Define hyper-parameters for training loop
    max_epochs = 500
    val_interval = 2
    batch_size = 4
    lr = 1e-4
    epoch_loss_values = []
    step_loss_values = []
    epoch_cl_loss_values = []
    epoch_recon_loss_values = []
    val_loss_values = []
    best_val_loss = 1000.0

    recon_loss = L1Loss()
    contrastive_loss = ContrastiveLoss(batch_size=batch_size*2, temperature=0.05)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Define DataLoaders using MONAI (a CacheDataset could be used here to speed up loading)
    train_ds = Dataset(data=train_Data, transform=train_Transforms)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)

    val_ds = Dataset(data=val_Data, transform=train_Transforms)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=True, num_workers=4)

    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        epoch_cl_loss = 0
        epoch_recon_loss = 0
        step = 0

        for batch_data in train_loader:
            step += 1
            start_time = time.time()

            inputs, inputs_2, gt_input = (
                batch_data["image"].to(device),
                batch_data["image_2"].to(device),
                batch_data["gt_image"].to(device),
            )
            optimizer.zero_grad()
            outputs_v1, hidden_v1 = model(inputs)
            outputs_v2, hidden_v2 = model(inputs_2)

            flat_out_v1 = outputs_v1.flatten(start_dim=1, end_dim=4)
            flat_out_v2 = outputs_v2.flatten(start_dim=1, end_dim=4)

            r_loss = recon_loss(outputs_v1, gt_input)
            cl_loss = contrastive_loss(flat_out_v1, flat_out_v2)

            # Adjust the CL loss by Recon Loss
            total_loss = r_loss + cl_loss * r_loss

            total_loss.backward()
            optimizer.step()
            epoch_loss += total_loss.item()
            step_loss_values.append(total_loss.item())

            # CL & Recon Loss Storage of Value
            epoch_cl_loss += cl_loss.item()
            epoch_recon_loss += r_loss.item()

            end_time = time.time()
            print(
                f"{step}/{len(train_ds) // train_loader.batch_size}, "
                f"train_loss: {total_loss.item():.4f}, "
                f"time taken: {end_time-start_time}s")

        epoch_loss /= step
        epoch_cl_loss /= step
        epoch_recon_loss /= step

        epoch_loss_values.append(epoch_loss)
        epoch_cl_loss_values.append(epoch_cl_loss)
        epoch_recon_loss_values.append(epoch_recon_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if epoch % val_interval == 0:
            print('Entering Validation for epoch: {}'.format(epoch+1))
            total_val_loss = 0
            val_step = 0
            model.eval()
            for val_batch in val_loader:
                val_step += 1
                start_time = time.time()
                inputs, gt_input = (
                    val_batch["image"].to(device),
                    val_batch["gt_image"].to(device),
                )
                print('Input shape: {}'.format(inputs.shape))
                outputs, _ = model(inputs)  # hidden states are unused at validation
                val_loss = recon_loss(outputs, gt_input)
                total_val_loss += val_loss.item()
                end_time = time.time()

            total_val_loss /= val_step
            val_loss_values.append(total_val_loss)
            print(f"epoch {epoch + 1} Validation average loss: {total_val_loss:.4f}, " f"time taken: {end_time-start_time}s")

            if total_val_loss < best_val_loss:
                print(f"Saving new model based on validation loss {total_val_loss:.4f}")
                best_val_loss = total_val_loss
                checkpoint = {'epoch': max_epochs,
                              'state_dict': model.state_dict(),
                              'optimizer': optimizer.state_dict()
                              }
                torch.save(checkpoint, os.path.join(logdir_path, 'best_model.pt'))

            plt.figure(1, figsize=(8, 8))
            plt.subplot(2, 2, 1)
            plt.plot(epoch_loss_values)
            plt.grid()
            plt.title('Training Loss')

            plt.subplot(2, 2, 2)
            plt.plot(val_loss_values)
            plt.grid()
            plt.title('Validation Loss')

            plt.subplot(2, 2, 3)
            plt.plot(epoch_cl_loss_values)
            plt.grid()
            plt.title('Training Contrastive Loss')

            plt.subplot(2, 2, 4)
            plt.plot(epoch_recon_loss_values)
            plt.grid()
            plt.title('Training Recon Loss')

            plt.savefig(os.path.join(logdir_path, 'loss_plots.png'))
            plt.close(1)

    print('Done')
    return None
Example 22
    ToTensor,
    ToTensord,
)
from monai.transforms.inverse_batch_transform import Decollated
from monai.transforms.spatial.dictionary import RandAffined, RandRotate90d
from monai.utils import optional_import, set_determinism
from monai.utils.enums import InverseKeys
from tests.utils import make_nifti_image

_, has_nib = optional_import("nibabel")

KEYS = ["image"]

TESTS_DICT: List[Tuple] = []
TESTS_DICT.append(
    (SpatialPadd(KEYS, 150), RandFlipd(KEYS, prob=1.0, spatial_axis=1)))
TESTS_DICT.append((RandRotate90d(KEYS, prob=0.0, max_k=1), ))
TESTS_DICT.append((RandAffined(KEYS, prob=0.0, translate_range=10), ))

TESTS_LIST: List[Tuple] = []
TESTS_LIST.append((SpatialPad(150), RandFlip(prob=1.0, spatial_axis=1)))
TESTS_LIST.append((RandRotate90(prob=0.0, max_k=1), ))
TESTS_LIST.append((RandAffine(prob=0.0, translate_range=10), ))

TEST_BASIC = [
    [("channel", "channel"), ["channel", "channel"]],
    [
        torch.Tensor([1, 2, 3]),
        [torch.tensor(1.0),
         torch.tensor(2.0),
         torch.tensor(3.0)]
Example 23
    SpatialPadd,
    ToTensor,
    ToTensord,
)
from monai.transforms.post.dictionary import Decollated
from monai.transforms.spatial.dictionary import RandAffined, RandRotate90d
from monai.utils import optional_import, set_determinism
from monai.utils.enums import InverseKeys
from tests.utils import make_nifti_image

_, has_nib = optional_import("nibabel")

KEYS = ["image"]

TESTS_DICT: List[Tuple] = []
TESTS_DICT.append((SpatialPadd(KEYS, 150), RandFlipd(KEYS, prob=1.0, spatial_axis=1)))
TESTS_DICT.append((RandRotate90d(KEYS, prob=0.0, max_k=1),))
TESTS_DICT.append((RandAffined(KEYS, prob=0.0, translate_range=10),))

TESTS_LIST: List[Tuple] = []
TESTS_LIST.append((SpatialPad(150), RandFlip(prob=1.0, spatial_axis=1)))
TESTS_LIST.append((RandRotate90(prob=0.0, max_k=1),))
TESTS_LIST.append((RandAffine(prob=0.0, translate_range=10),))


TEST_BASIC = [
    [("channel", "channel"), ["channel", "channel"]],
    [torch.Tensor([1, 2, 3]), [torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0)]],
    [
        [[torch.Tensor((1.0, 2.0, 3.0)), torch.Tensor((2.0, 3.0, 1.0))]],
        [
Example 24
 def test_pad_shape(self, input_param, input_data, expected_val):
     padder = SpatialPadd(**input_param)
     result = padder(input_data)
     np.testing.assert_allclose(result["img"].shape, expected_val.shape)
Example 25
    def __init__(self, 
            data_dir: Path, 
            cache_dir: Path, 
            splits: Sequence[Sequence[Dict]],
            batch_size: int,
            spacing: Sequence[float] = (1.5, 1.5, 2.0),
            crop_size: Sequence[int] = [48, 48, 36], 
            roi_size: Sequence[int] = [192, 192, 144], 
            seed: int = 47, **kwargs):
        """Module that deals with preparation of the LIDC dataset for training segmentation models.

        Args:
            data_dir (Path): Folder where preprocessed data is stored. See `LIDCReader` docs for expected structure.
            cache_dir (Path): Folder where deterministic data transformations should be cached.
            splits (Sequence[Sequence[Dict]]): Data dictionaries for the training and validation splits.
            batch_size (int): Number of training examples in each batch.
            spacing (Sequence[float]): Pixel and slice spacing. Defaults to 1.5x1.5x2mm.
            crop_size (Sequence[int]): Size of crop that is used for training. Defaults to 48x48x36px.
            roi_size (Sequence[int]): Size of crop that is used for validation. Defaults to 192x192x144px.
            seed (int, optional): Random seed used for deterministic sampling and transformations. Defaults to 47.
        """
        super().__init__()
        self.data_dir = data_dir
        self.cache_dir = cache_dir
        self.splits = splits
        self.batch_size = batch_size
        self.spacing = spacing
        self.crop_size = crop_size
        self.roi_size = roi_size
        self.seed = seed
        reader = LIDCReader(data_dir)
        self.train_transforms = Compose([
            LoadImaged(keys=["image", "label"], reader=reader),
            AddChanneld(keys=["image", "label"]),
            Spacingd(keys=["image", "label"], pixdim=self.spacing,
                     mode=("bilinear", "nearest")),
            ScaleIntensityd(keys=["image"]),
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=self.crop_size,
                pos=1,
                neg=1,
                num_samples=2,
                image_key="image",
                image_threshold=0,
                ),
            ToTensord(keys=["image", "label"]),
            SelectItemsd(keys=["image", "label"]),
        ])
        self.val_transforms = Compose([
            LoadImaged(keys=["image", "label"], reader=reader),
            AddChanneld(keys=["image", "label"]),
            Spacingd(keys=["image", "label"], pixdim=self.spacing,
                     mode=("bilinear", "nearest")),
            ScaleIntensityd(keys=["image"]),
            SpatialPadd(keys=["image", "label"], spatial_size=self.roi_size,
                        mode="constant"),
            CenterSpatialCropd(keys=["image", "label"], roi_size=self.roi_size),
            ToTensord(keys=["image", "label"]),
            SelectItemsd(keys=["image", "label"]),
        ])
        self.hparams = {
            "batch_size": self.batch_size,
            "spacing": self.spacing,
            "crop_size": self.crop_size,
            "roi_size": self.roi_size,
        }
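
The validation pipeline above pairs SpatialPadd with CenterSpatialCropd so every sample comes out at exactly roi_size, whether it starts smaller or larger. A standalone sketch of that idiom (shapes are illustrative):

import numpy as np
from monai.transforms import CenterSpatialCropd, Compose, SpatialPadd

fix_shape = Compose([
    SpatialPadd(keys=["image"], spatial_size=(192, 192, 144), mode="constant"),
    CenterSpatialCropd(keys=["image"], roi_size=(192, 192, 144)),
])
sample = {"image": np.zeros((1, 210, 180, 100), dtype=np.float32)}
# padding only grows dims below the target; the crop trims dims above it
print(fix_shape(sample)["image"].shape)  # (1, 192, 192, 144)
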
Example n. 26
def train(n_feat,
          crop_size,
          bs,
          ep,
          optimizer="rmsprop",
          lr=5e-4,
          pretrain=None):
    model_name = f"./HaN_{n_feat}_{bs}_{ep}_{crop_size}_{lr}_"
    print(f"save the best model as '{model_name}' during training.")

    crop_size = [int(cz) for cz in crop_size.split(",")]
    print(f"input image crop_size: {crop_size}")

    # starting training set loader
    train_images = ImageLabelDataset(path=TRAIN_PATH, n_class=N_CLASSES)
    if np.any([cz == -1 for cz in crop_size]):  # using full image
        train_transform = Compose([
            AddChannelDict(keys="image"),
            Rand3DElasticd(
                keys=("image", "label"),
                spatial_size=crop_size,
                sigma_range=(10, 50),  # 30
                magnitude_range=(600, 1200),  # 1000
                prob=0.8,
                rotate_range=(np.pi / 12, np.pi / 12, np.pi / 12),
                shear_range=(np.pi / 18, np.pi / 18, np.pi / 18),
                translate_range=tuple(sz * 0.05 for sz in crop_size),
                scale_range=(0.2, 0.2, 0.2),
                mode=("bilinear", "nearest"),
                padding_mode=("border", "zeros"),
            ),
        ])
        train_dataset = Dataset(train_images, transform=train_transform)
        # when bs > 1, the loader assumes that the full image sizes are the same across the dataset
        train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                       num_workers=4,
                                                       batch_size=bs,
                                                       shuffle=True)
    else:
        # draw balanced foreground/background window samples according to the ground truth label
        train_transform = Compose([
            AddChannelDict(keys="image"),
            SpatialPadd(
                keys=("image", "label"),
                spatial_size=crop_size),  # ensure image size >= crop_size
            RandCropByPosNegLabeld(keys=("image", "label"),
                                   label_key="label",
                                   spatial_size=crop_size,
                                   num_samples=bs),
            Rand3DElasticd(
                keys=("image", "label"),
                spatial_size=crop_size,
                sigma_range=(10, 50),  # 30
                magnitude_range=(600, 1200),  # 1000
                prob=0.8,
                rotate_range=(np.pi / 12, np.pi / 12, np.pi / 12),
                shear_range=(np.pi / 18, np.pi / 18, np.pi / 18),
                translate_range=tuple(sz * 0.05 for sz in crop_size),
                scale_range=(0.2, 0.2, 0.2),
                mode=("bilinear", "nearest"),
                padding_mode=("border", "zeros"),
            ),
        ])
        train_dataset = Dataset(train_images, transform=train_transform
                                )  # each dataset item is a list of windows
        train_dataloader = torch.utils.data.DataLoader(  # stack each dataset item into a single tensor
            train_dataset,
            num_workers=4,
            batch_size=1,
            shuffle=True,
            collate_fn=list_data_collate)
    first_sample = first(train_dataloader)
    print(first_sample["image"].shape)

    # starting validation set loader
    val_transform = Compose([AddChannelDict(keys="image")])
    val_dataset = Dataset(ImageLabelDataset(VAL_PATH, n_class=N_CLASSES),
                          transform=val_transform)
    val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                                 num_workers=1,
                                                 batch_size=1)
    print(val_dataset[0]["image"].shape)
    print(
        f"training images: {len(train_dataloader)}, validation images: {len(val_dataloader)}"
    )

    model = UNetPipe(spatial_dims=3,
                     in_channels=1,
                     out_channels=N_CLASSES,
                     n_feat=n_feat)
    model = flatten_sequential(model)
    lossweight = torch.from_numpy(
        np.array([2.22, 1.31, 1.99, 1.13, 1.93, 1.93, 1.0, 1.0, 1.90, 1.98],
                 np.float32))

    if optimizer.lower() == "rmsprop":
        optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)  # lr = 5e-4
    elif optimizer.lower() == "momentum":
        optimizer = torch.optim.SGD(model.parameters(), lr=lr,
                                    momentum=0.9)  # lr = 1e-4 for finetuning
    else:
        raise ValueError(
            f"Unknown optimizer type {optimizer}. (options are 'rmsprop' and 'momentum')."
        )

    # config GPipe
    x = first_sample["image"].float()
    x = torch.autograd.Variable(x.cuda())
    partitions = torch.cuda.device_count()
    print(f"partition: {partitions}, input: {x.size()}")
    balance = balance_by_size(partitions, model, x)
    model = GPipe(model, balance, chunks=4, checkpoint="always")

    # config loss functions
    dice_loss_func = DiceLoss(softmax=True, reduction="none")
    # use the same pipeline and loss as in
    # AnatomyNet: Deep learning for fast and fully automated whole-volume segmentation of head and neck anatomy,
    # Medical Physics, 2018.
    focal_loss_func = FocalLoss(reduction="none")

    if pretrain:
        print(f"loading from {pretrain}.")
        pretrained_dict = torch.load(pretrain)["weight"]
        model_dict = model.state_dict()
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        # load the merged dict so keys absent from the checkpoint keep their initial values
        model.load_state_dict(model_dict)

    b_time = time.time()
    best_val_loss = [0] * (N_CLASSES - 1)  # foreground
    for epoch in range(ep):
        model.train()
        trainloss = 0
        for b_idx, data_dict in enumerate(train_dataloader):
            x_train = data_dict["image"]
            y_train = data_dict["label"]
            flagvec = data_dict["with_complete_groundtruth"]

            x_train = torch.autograd.Variable(x_train.cuda())
            y_train = torch.autograd.Variable(y_train.cuda().float())
            optimizer.zero_grad()
            o = model(x_train).to(0, non_blocking=True).float()

            loss = (dice_loss_func(o, y_train.to(o)) * flagvec.to(o) *
                    lossweight.to(o)).mean()
            loss += 0.5 * (focal_loss_func(o, y_train.to(o)) * flagvec.to(o) *
                           lossweight.to(o)).mean()
            loss.backward()
            optimizer.step()
            trainloss += loss.item()

            if b_idx % 20 == 0:
                print(
                    f"Train Epoch: {epoch} [{b_idx}/{len(train_dataloader)}] \tLoss: {loss.item()}"
                )
        print(f"epoch {epoch} TRAIN loss {trainloss / len(train_dataloader)}")

        if epoch % 10 == 0:
            model.eval()
            # check validation dice
            val_loss = [0] * (N_CLASSES - 1)
            n_val = [0] * (N_CLASSES - 1)
            for data_dict in val_dataloader:
                x_val = data_dict["image"]
                y_val = data_dict["label"]
                with torch.no_grad():
                    x_val = torch.autograd.Variable(x_val.cuda())
                o = model(x_val).to(0, non_blocking=True)
                loss = compute_meandice(o,
                                        y_val.to(o),
                                        mutually_exclusive=True,
                                        include_background=False)
                val_loss = [
                    l.item() + tl if l == l else tl  # l != l filters out NaN scores
                    for l, tl in zip(loss[0], val_loss)
                ]
                n_val = [
                    n + 1 if l == l else n for l, n in zip(loss[0], n_val)
                ]
            val_loss = [l / n for l, n in zip(val_loss, n_val)]
            print(
                "validation scores %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f"
                % tuple(val_loss))
            for c in range(1, 10):
                if best_val_loss[c - 1] < val_loss[c - 1]:
                    best_val_loss[c - 1] = val_loss[c - 1]
                    state = {
                        "epoch": epoch,
                        "weight": model.state_dict(),
                        "score_" + str(c): best_val_loss[c - 1]
                    }
                    torch.save(state, f"{model_name}" + str(c))
            print(
                "best validation scores %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f"
                % tuple(best_val_loss))

    print("total time", time.time() - b_time)
Example n. 27
    monai_transforms = [
        LoadImaged(keys=['image', 'label']),
        AddChanneld(keys=['image', 'label']),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),
        # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
        CropForegroundd(
            keys=['image', 'label'],
            source_key='image',
            start_coord_key='foreground_start_coord',
            end_coord_key='foreground_end_coord',
        ),  # crop CropForeground
        NormalizeIntensityd(keys=['image']),
        ScaleIntensityd(keys=['image']),
        # Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest')),
        SpatialPadd(keys=['image', 'label'],
                    spatial_size=opt.patch_size,
                    method='end'),
        RandSpatialCropd(keys=['image', 'label'],
                         roi_size=opt.patch_size,
                         random_size=False),
        ToTensord(keys=['image', 'label',
                        'foreground_start_coord', 'foreground_end_coord']),
    ]

    transform = Compose(monai_transforms)
    check_ds = monai.data.Dataset(data=data_dicts, transform=transform)
    loader = DataLoader(check_ds,
                        batch_size=1,
                        shuffle=True,
                        num_workers=0,
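
The CropForegroundd call above records the crop bounds under foreground_start_coord / foreground_end_coord, which is what allows the crop to be undone later. A minimal sketch of that bookkeeping:

import numpy as np
from monai.transforms import CropForegroundd

crop = CropForegroundd(keys=['image'], source_key='image',
                       start_coord_key='foreground_start_coord',
                       end_coord_key='foreground_end_coord')
data = {'image': np.pad(np.ones((1, 4, 4, 4), np.float32),
                        ((0, 0), (2, 2), (2, 2), (2, 2)))}
out = crop(data)
# the foreground box (indices 2..6 on each spatial axis) is cropped out and recorded
print(out['image'].shape, out['foreground_start_coord'], out['foreground_end_coord'])
# e.g. (1, 4, 4, 4) [2 2 2] [6 6 6]
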
Example n. 28
File: test.py Project: ckbr0/RIS
def main(train_output):
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    print_config()

    # Setup directories
    dirs = setup_directories()

    # Setup torch device
    device, using_gpu = create_device("cuda")

    # Load and randomize images

    # HACKATON image and segmentation data
    hackathon_dir = os.path.join(dirs["data"], 'HACKATHON')
    map_fn = lambda x: (x[0], int(x[1]))
    with open(os.path.join(hackathon_dir, "train.txt"), 'r') as fp:
        train_info_hackathon = [
            map_fn(entry.strip().split(',')) for entry in fp.readlines()
        ]
    image_dir = os.path.join(hackathon_dir, 'images', 'train')
    seg_dir = os.path.join(hackathon_dir, 'segmentations', 'train')
    _train_data_hackathon = get_data_from_info(image_dir,
                                               seg_dir,
                                               train_info_hackathon,
                                               dual_output=False)
    large_image_splitter(_train_data_hackathon, dirs["cache"])

    balance_training_data(_train_data_hackathon, seed=72)

    # PSUF data
    """psuf_dir = os.path.join(dirs["data"], 'psuf')
    with open(os.path.join(psuf_dir, "train.txt"), 'r') as fp:
        train_info = [entry.strip().split(',') for entry in fp.readlines()]
    image_dir = os.path.join(psuf_dir, 'images')
    train_data_psuf = get_data_from_info(image_dir, None, train_info)"""
    # Split data into train, validate and test
    train_split, test_data_hackathon = train_test_split(_train_data_hackathon,
                                                        test_size=0.2,
                                                        shuffle=True,
                                                        random_state=42)
    #train_data_hackathon, valid_data_hackathon = train_test_split(train_split, test_size=0.2, shuffle=True, random_state=43)
    # Setup transforms

    # Crop foreground
    crop_foreground = CropForegroundd(
        keys=["image"],
        source_key="image",
        margin=(5, 5, 0),
        #select_fn = lambda x: x != 0
    )
    # Crop Z
    crop_z = RelativeCropZd(keys=["image"], relative_z_roi=(0.07, 0.12))
    # Window width and level (window center)
    WW, WL = 1500, -600
    ct_window = CTWindowd(keys=["image"], width=WW, level=WL)
    spatial_pad = SpatialPadd(keys=["image"], spatial_size=(-1, -1, 30))
    resize = Resized(keys=["image"],
                     spatial_size=(int(512 * 0.50), int(512 * 0.50), -1),
                     mode="trilinear")

    # Create transforms
    common_transform = Compose([
        LoadImaged(keys=["image"]),
        ct_window,
        CTSegmentation(keys=["image"]),
        AddChanneld(keys=["image"]),
        resize,
        crop_foreground,
        crop_z,
        spatial_pad,
    ])
    hackathon_train_transform = Compose([
        common_transform,
        ToTensord(keys=["image"]),
    ]).flatten()
    psuf_transforms = Compose([
        LoadImaged(keys=["image"]),
        AddChanneld(keys=["image"]),
        ToTensord(keys=["image"]),
    ])

    # Setup data
    #set_determinism(seed=100)
    test_dataset = PersistentDataset(data=test_data_hackathon[:],
                                     transform=hackathon_train_transform,
                                     cache_dir=dirs["persistent"])
    test_loader = DataLoader(test_dataset,
                             batch_size=2,
                             shuffle=True,
                             pin_memory=using_gpu,
                             num_workers=1,
                             collate_fn=PadListDataCollate(
                                 Method.SYMMETRIC, NumpyPadMode.CONSTANT))

    # Setup network, loss function, optimizer and scheduler
    network = nets.DenseNet121(spatial_dims=3, in_channels=1,
                               out_channels=1).to(device)

    # Setup validator and trainer
    valid_post_transforms = Compose([
        Activationsd(keys="pred", sigmoid=True),
    ])

    # Setup tester
    tester = Tester(device=device,
                    test_data_loader=test_loader,
                    load_dir=train_output,
                    out_dir=dirs["out"],
                    network=network,
                    post_transform=valid_post_transforms,
                    non_blocking=using_gpu,
                    amp=using_gpu)

    # Run tester
    tester.run()
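
The test loader above needs PadListDataCollate because batch_size=2 can combine images whose padded Z extents still differ. A minimal sketch of what the collate function does (shapes are illustrative):

import numpy as np
from monai.transforms import PadListDataCollate
from monai.utils.enums import Method, NumpyPadMode

collate = PadListDataCollate(Method.SYMMETRIC, NumpyPadMode.CONSTANT)
# items of unequal depth are padded to the largest shape, then stacked
batch = collate([{"image": np.zeros((1, 8, 8, 3), np.float32)},
                 {"image": np.zeros((1, 8, 8, 5), np.float32)}])
print(batch["image"].shape)  # torch.Size([2, 1, 8, 8, 5])
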
Example n. 29
from monai.data import CacheDataset, DataLoader, create_test_image_2d
from monai.data.utils import decollate_batch
from monai.transforms import AddChanneld, Compose, LoadImaged, RandFlipd, SpatialPadd, ToTensord
from monai.transforms.post.dictionary import Decollated
from monai.transforms.spatial.dictionary import RandAffined, RandRotate90d
from monai.utils import optional_import, set_determinism
from monai.utils.enums import InverseKeys
from tests.utils import make_nifti_image

_, has_nib = optional_import("nibabel")

KEYS = ["image"]

TESTS: List[Tuple] = []
TESTS.append((SpatialPadd(KEYS, 150), RandFlipd(KEYS, prob=1.0, spatial_axis=1)))
TESTS.append((RandRotate90d(KEYS, prob=0.0, max_k=1),))
TESTS.append((RandAffined(KEYS, prob=0.0, translate_range=10),))


class TestDeCollate(unittest.TestCase):
    def setUp(self) -> None:
        set_determinism(seed=0)

        im = create_test_image_2d(100, 101)[0]
        self.data = [{
            "image": make_nifti_image(im) if has_nib else im
        } for _ in range(6)]

    def tearDown(self) -> None:
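
The decollation being tested here is the inverse of default collation: a batched dict of stacked tensors turns back into a list of per-sample dicts. A minimal sketch:

import torch
from monai.data.utils import decollate_batch

batch = {"image": torch.zeros(2, 1, 10, 10), "label": torch.ones(2)}
samples = decollate_batch(batch)
# one dict per sample, with the batch dimension stripped
print(len(samples), samples[0]["image"].shape)  # 2 torch.Size([1, 10, 10])
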
Example n. 30
def segment(image, label, result, weights, resolution, patch_size, network,
            gpu_ids):

    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    if label is not None:
        uniform_img_dimensions_internal(image, label, True)
        files = [{"image": image, "label": label}]
    else:
        files = [{"image": image}]

    # original size, size after crop_background, cropped roi coordinates, cropped resampled roi size
    original_shape, crop_shape, coord1, coord2, resampled_size, original_resolution = statistics_crop(
        image, resolution)

    # -------------------------------

    if label is not None:
        if resolution is not None:

            val_transforms = Compose([
                LoadImaged(keys=['image', 'label']),
                AddChanneld(keys=['image', 'label']),
                # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),  # Threshold CT
                # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
                CropForegroundd(keys=['image', 'label'],
                                source_key='image'),  # crop CropForeground
                NormalizeIntensityd(keys=['image']),  # intensity
                ScaleIntensityd(keys=['image']),
                Spacingd(keys=['image', 'label'],
                         pixdim=resolution,
                         mode=('bilinear', 'nearest')),  # resolution
                SpatialPadd(keys=['image', 'label'],
                            spatial_size=patch_size,
                            method='end'),
                ToTensord(keys=['image', 'label'])
            ])
        else:

            val_transforms = Compose([
                LoadImaged(keys=['image', 'label']),
                AddChanneld(keys=['image', 'label']),
                # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),  # Threshold CT
                # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
                CropForegroundd(keys=['image', 'label'],
                                source_key='image'),  # crop CropForeground
                NormalizeIntensityd(keys=['image']),  # intensity
                ScaleIntensityd(keys=['image']),
                SpatialPadd(
                    keys=['image', 'label'],
                    spatial_size=patch_size,
                    method='end'),  # pad if the image is smaller than patch
                ToTensord(keys=['image', 'label'])
            ])

    else:
        if resolution is not None:

            val_transforms = Compose([
                LoadImaged(keys=['image']),
                AddChanneld(keys=['image']),
                # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),  # Threshold CT
                # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
                CropForegroundd(keys=['image'],
                                source_key='image'),  # crop CropForeground
                NormalizeIntensityd(keys=['image']),  # intensity
                ScaleIntensityd(keys=['image']),
                Spacingd(keys=['image'], pixdim=resolution,
                         mode='bilinear'),  # resolution
                SpatialPadd(
                    keys=['image'], spatial_size=patch_size,
                    method='end'),  # pad if the image is smaller than patch
                ToTensord(keys=['image'])
            ])
        else:

            val_transforms = Compose([
                LoadImaged(keys=['image']),
                AddChanneld(keys=['image']),
                # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),  # Threshold CT
                # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
                CropForegroundd(keys=['image'],
                                source_key='image'),  # crop CropForeground
                NormalizeIntensityd(keys=['image']),  # intensity
                ScaleIntensityd(keys=['image']),
                SpatialPadd(
                    keys=['image'], spatial_size=patch_size,
                    method='end'),  # pad if the image is smaller than patch
                ToTensord(keys=['image'])
            ])

    val_ds = monai.data.Dataset(data=files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            num_workers=0,
                            collate_fn=list_data_collate,
                            pin_memory=False)

    dice_metric = DiceMetric(include_background=True,
                             reduction="mean",
                             get_not_nans=False)
    post_trans = Compose([
        EnsureType(),
        Activations(sigmoid=True),
        AsDiscrete(threshold_values=True)
    ])

    if gpu_ids != '-1':

        # try to use all the available GPUs
        os.environ['CUDA_VISIBLE_DEVICES'] = gpu_ids
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    else:
        device = torch.device("cpu")

    # build the network
    if network == 'nnunet':
        net = build_net()  # nn build_net
    elif network == 'unetr':
        net = build_UNETR()  # UneTR
    else:
        raise ValueError(
            f"Unknown network {network}. (options are 'nnunet' and 'unetr').")

    net = net.to(device)

    if gpu_ids == '-1':

        net.load_state_dict(new_state_dict_cpu(weights))

    else:

        net.load_state_dict(new_state_dict(weights))

    # define sliding window size and batch size for windows inference
    roi_size = patch_size
    sw_batch_size = 4

    net.eval()
    with torch.no_grad():

        if label is None:
            for val_data in val_loader:
                val_images = val_data["image"].to(device)
                val_outputs = sliding_window_inference(val_images, roi_size,
                                                       sw_batch_size, net)
                val_outputs = [
                    post_trans(i) for i in decollate_batch(val_outputs)
                ]

        else:
            for val_data in val_loader:
                val_images, val_labels = val_data["image"].to(
                    device), val_data["label"].to(device)
                val_outputs = sliding_window_inference(val_images, roi_size,
                                                       sw_batch_size, net)
                val_outputs = [
                    post_trans(i) for i in decollate_batch(val_outputs)
                ]
                dice_metric(y_pred=val_outputs, y=val_labels)

            metric = dice_metric.aggregate().item()
            print("Evaluation Metric (Dice):", metric)

        result_array = val_outputs[0].squeeze().data.cpu().numpy()
        # Remove the pad if the image was smaller than the patch in some directions
        result_array = result_array[0:resampled_size[0], 0:resampled_size[1],
                                    0:resampled_size[2]]

        # resample back to the original resolution
        if resolution is not None:

            result_array_np = np.transpose(result_array, (2, 1, 0))
            result_array_temp = sitk.GetImageFromArray(result_array_np)
            result_array_temp.SetSpacing(resolution)

            # save temporary label
            writer = sitk.ImageFileWriter()
            writer.SetFileName('temp_seg.nii')
            writer.Execute(result_array_temp)

            files = [{"image": 'temp_seg.nii'}]

            files_transforms = Compose([
                LoadImaged(keys=['image']),
                AddChanneld(keys=['image']),
                Spacingd(keys=['image'],
                         pixdim=original_resolution,
                         mode='nearest'),
                Resized(keys=['image'],
                        spatial_size=crop_shape,
                        mode='nearest'),
            ])

            files_ds = Dataset(data=files, transform=files_transforms)
            files_loader = DataLoader(files_ds, batch_size=1, num_workers=0)

            for files_data in files_loader:
                files_images = files_data["image"]

                res = files_images.squeeze().data.numpy()

            result_array = np.rint(res)

            os.remove('./temp_seg.nii')

        # recover the cropped background before saving the image
        empty_array = np.zeros(original_shape)
        empty_array[coord1[0]:coord2[0], coord1[1]:coord2[1],
                    coord1[2]:coord2[2]] = result_array

        result_seg = from_numpy_to_itk(empty_array, image)

        # save label
        writer = sitk.ImageFileWriter()
        writer.SetFileName(result)
        writer.Execute(result_seg)
        print("Saved Result at:", str(result))