예제 #1
0
def engine_apply_transform(batch: Any, output: Any, transform: Callable[..., Dict]):
    """
    Apply `transform` to the engine's `batch` and `output`.

    When both `batch` and `output` are dicts, they are merged (values from `output`
    win on key clashes), the transform is applied to the merged dict, and each
    resulting key is written back. A key goes to `output` if it is already in
    `output` or is new; only keys present solely in `batch` are written to `batch`.
    Otherwise, only `output` is transformed.

    Returns:
        The ``(batch, output)`` pair (the dicts are updated in place in the dict case).

    Raises:
        AssertionError: in the dict case, if the transform does not return a dict.
    """
    if not (isinstance(batch, dict) and isinstance(output, dict)):
        return batch, apply_transform(transform, output)

    merged = dict(batch)
    merged.update(output)
    result = apply_transform(transform, merged)
    if not isinstance(result, dict):
        raise AssertionError(
            "With a dict supplied to apply_transform a single dict return is expected."
        )

    for key, value in result.items():
        # new keys and keys already in `output` go to `output`; keys only in `batch` stay there
        destination = output if (key in output or key not in batch) else batch
        destination[key] = value
    return batch, output
    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        iter_start, iter_end = 0, 1
        try:
            iter_end = len(self.dataset)  # TODO: support iterable self.dataset
        except TypeError:
            raise NotImplementedError(
                "image dataset must implement `len()`.") from None

        if worker_info is not None:
            # split workload
            per_worker = int(
                np.ceil(
                    (iter_end - iter_start) / float(worker_info.num_workers)))
            iter_start += worker_info.id * per_worker
            iter_end = min(iter_start + per_worker, iter_end)

        for index in range(iter_start, iter_end):
            image = self.dataset[index]
            if not self.with_coordinates:
                for patch, *_ in self.patch_iter(
                        image):  # patch_iter to yield at least 1 item: patch
                    out_patch = (patch if self.transform is None else
                                 apply_transform(
                                     self.transform, patch, map_items=False))
                    yield out_patch
            else:
                for patch, slices, *_ in self.patch_iter(
                        image
                ):  # patch_iter to yield at least 2 items: patch, coords
                    out_patch = (patch if self.transform is None else
                                 apply_transform(
                                     self.transform, patch, map_items=False))
                    yield out_patch, slices
    def __getitem__(self, i):
        """
        Load and transform the i-th (image, label) pair.

        The image is read from `self.X_path[i]` and the label from `self.y_path[i]`
        with `LoadNifti`. They are passed through `self.X_transform` and
        `self.y_transform`. Each transform is seeded with `self._seed` right before
        use when it is `Randomizable`, so image and label start from the same random
        state.

        When `self.using_flair` is set, a FLAIR volume is also loaded. Its path is
        built from the image path by replacing "t1" with "flair", or "t2" with
        "flair" if "t1" is absent. It goes through `self.X_transform` (re-seeded, so
        it starts from the same random state as the image) and is concatenated to the
        image along dim 0.

        Returns:
            ``(X_img, y_img)``
        """
        self.randomize()

        def _seeded(xform):
            # Seed each transform independently; previously `y_transform` was seeded only
            # when `X_transform` was Randomizable, and crashed if it was not itself one.
            if isinstance(xform, Randomizable):
                xform.set_random_state(seed=self._seed)
            return xform

        loadnifti = LoadNifti()
        X_img, _ = loadnifti(self.X_path[i])
        if int(len(self.X_path)) < 1000 and i < 5:  # only print the val x_path
            print(f"No. {i} file, path: {self.X_path[i]}")
        y_img, _ = loadnifti(self.y_path[i])

        X_img = apply_transform(_seeded(self.X_transform), X_img)
        y_img = apply_transform(_seeded(self.y_transform), y_img)

        if self.using_flair:
            X_path_str = str(self.X_path[i])
            # NOTE(review): str.replace rewrites every occurrence in the full path
            # (directories included) — confirm paths contain "t1"/"t2" only in the file name.
            if "t1" in X_path_str:
                X_flair_path = X_path_str.replace("t1", "flair")
            else:
                X_flair_path = X_path_str.replace("t2", "flair")
            X_flair, _ = loadnifti(Path(X_flair_path))
            # Re-seed before the second application: without it, a random X_transform
            # draws new parameters and the FLAIR channel no longer matches X_img.
            X_flair_img = apply_transform(_seeded(self.X_transform), X_flair)
            X_img = torch.cat((X_img, X_flair_img), 0)

        return X_img, y_img
예제 #4
0
    def __getitem__(self, index: int):
        self.randomize()
        meta_data, seg_meta_data, seg, label = None, None, None, None

        # load data and optionally meta
        if self.image_only:
            img = self.loader(self.image_files[index])
            if self.seg_files is not None:
                seg = self.loader(self.seg_files[index])
        else:
            img, meta_data = self.loader(self.image_files[index])
            if self.seg_files is not None:
                seg, seg_meta_data = self.loader(self.seg_files[index])

        # apply the transforms
        if self.transform is not None:
            if isinstance(self.transform, Randomizable):
                self.transform.set_random_state(seed=self._seed)

            if self.transform_with_metadata:
                img, meta_data = apply_transform(self.transform,
                                                 (img, meta_data),
                                                 map_items=False,
                                                 unpack_items=True)
            else:
                img = apply_transform(self.transform, img, map_items=False)

        if self.seg_files is not None and self.seg_transform is not None:
            if isinstance(self.seg_transform, Randomizable):
                self.seg_transform.set_random_state(seed=self._seed)

            if self.transform_with_metadata:
                seg, seg_meta_data = apply_transform(self.seg_transform,
                                                     (seg, seg_meta_data),
                                                     map_items=False,
                                                     unpack_items=True)
            else:
                seg = apply_transform(self.seg_transform, seg, map_items=False)

        if self.labels is not None:
            label = self.labels[index]
            if self.label_transform is not None:
                label = apply_transform(self.label_transform,
                                        label,
                                        map_items=False)  # type: ignore

        # construct outputs
        data = [img]
        if seg is not None:
            data.append(seg)
        if label is not None:
            data.append(label)
        if not self.image_only and meta_data is not None:
            data.append(meta_data)
        if not self.image_only and seg_meta_data is not None:
            data.append(seg_meta_data)
        if len(data) == 1:
            return data[0]
        # use tuple instead of list as the default collate_fn callback of MONAI DataLoader flattens nested lists
        return tuple(data)
예제 #5
0
    def __iter__(self):
        self.source = iter(self.data)
        for data in self.source:
            if self.transform is not None:
                data = apply_transform(self.transform, data)

            yield data
예제 #6
0
    def __getitem__(self, index):
        """
        Load the JPEG slice series for row `index` of `self.csv` as a float32 volume.

        Slices are read from ``data_dir/<StudyInstanceUID>/<SeriesInstanceUID>/*.jpg`` in
        sorted filename order. The channel axis of each `cv2.imread` result is reversed.
        In ``'train'`` mode, roughly half of the samples are randomly cropped from both
        ends of the slice axis and of both in-plane axes. The volume is then passed
        through `self.transform`, which is seeded with `self._seed` when Randomizable.

        Returns:
            The volume alone in ``'test'`` mode. Otherwise ``(volume, targets)``, where
            `targets` is ``row[target_cols]`` as a float tensor.
        """
        self.randomize()
        row = self.csv.iloc[index]
        jpg_lst = sorted(glob(os.path.join(data_dir, row.StudyInstanceUID, row.SeriesInstanceUID, '*.jpg')))
        img_lst = np.array([cv2.imread(jpg)[:, :, ::-1] for jpg in jpg_lst])

        # NOTE(review): these draws use the global `np.random` state, not the dataset's seeded
        # generator, so they are not controlled by `self._seed` — confirm this is intended.
        if np.random.randint(2) == 0 and self.mode == 'train':
            z = np.random.randint(1, max(len(img_lst) // 6, 10))
            crop_w = np.random.randint(3, 30)  # removed from both ends of axis 2
            crop_h = np.random.randint(3, 30)  # removed from both ends of axis 1
            # Keep at least one slice. `max(len // 6, 10)` lets `z` reach 9 even for short
            # series, where `img_lst[z:-z]` would be empty and make `np.stack` below fail.
            n_slices = len(img_lst)
            z = max(0, min(z, (n_slices - 1) // 2))
            img_lst = img_lst[z:n_slices - z, crop_h:-crop_h, crop_w:-crop_w]

        img = np.stack([image.astype(np.float32) for image in img_lst], axis=2).transpose(3, 0, 1, 2)

        if self.transform is not None:
            if isinstance(self.transform, Randomizable):
                self.transform.set_random_state(seed=self._seed)
            img = apply_transform(self.transform, img)

        if self.mode == 'test':
            return img
        return img, torch.tensor(row[target_cols]).float()
예제 #7
0
파일: dataset.py 프로젝트: yaritzabg/MONAI
 def _load_cache_item(self, item, transforms):
     for _transform in transforms:
         # execute all the deterministic transforms
         if isinstance(_transform, Randomizable) or not isinstance(_transform, Transform):
             break
         item = apply_transform(_transform, item)
     return item
예제 #8
0
    def get(self, i):
        """
        Load the DICOM series `self.dicom_folders[i]` and build 2-D and 3-D views of it.

        The raw volume is windowed three times via `window` with (WL, WW) pairs
        (-600, 1500), (40, 400) and (100, 700). These are presumably window level and
        width — confirm. The results are stacked on a new axis 3 in the order
        mediastinal, PE-specific, lung. Each element along axis 0 is then passed through
        `self.transforms` and `self.preprocessing`, when set. Both are called as
        ``f(image=...)['image']``.

        Returns:
            ``(torch.from_numpy(images), names, img)``, where:
            - `images` is the array of processed elements;
            - `names` holds the last three "/"-separated components of each file path,
              with everything from ".dcm" onwards dropped;
            - `img` is `images` with axis 1 reversed and transposed by (1, 2, 3, 0),
              passed through `self.transform3d` when set.
        """
        # unused timing/ratio locals (`s`, `rat`) removed
        image, files = load_dicom_array(self.dicom_folders[i])
        image_lung = np.expand_dims(window(image, WL=-600, WW=1500), axis=3)
        image_mediastinal = np.expand_dims(window(image, WL=40, WW=400), axis=3)
        image_pe_specific = np.expand_dims(window(image, WL=100, WW=700), axis=3)
        image = np.concatenate(
            [image_mediastinal, image_pe_specific, image_lung], axis=3)
        names = [row.split(".dcm")[0].split("/")[-3:] for row in files]

        images = []
        for img in image:
            if self.transforms:
                img = self.transforms(image=img)['image']
            if self.preprocessing:
                img = self.preprocessing(image=img)['image']
            images.append(img)
        images = np.array(images)

        img = images[:, ::-1].transpose(1, 2, 3, 0)
        if self.transform3d is not None:
            if isinstance(self.transform3d, Randomizable):
                self.transform3d.set_random_state(seed=self._seed)
            img = apply_transform(self.transform3d, img)

        return torch.from_numpy(images), names, img
예제 #9
0
    def _transform(self, index: int):
        data = {k: v[index] for k, v in self.arrays.items()}

        if self.transform is not None:
            data = apply_transform(self.transform, data)

        return data
예제 #10
0
    def __getitem__(self, index: int):
        """
        Return one sampled patch for flat index `index`.

        Images ``index / samples_per_image`` (truncated) are read with
        `self.image_reader`, stacked, and given a leading axis. `self.sampler` must then
        return exactly `self.samples_per_image` patches; the patch at the remaining
        offset is selected. When `self.transform` is set, channels 7, 8 and 9 and
        channels :7 of the patch are first run through `preprocess` with the
        magnification level of the first file, then the transform is applied.

        Raises:
            RuntimeWarning: if the sampler returns the wrong number of patches.
        """
        spi = self.samples_per_image
        image_id = int(index / spi)
        image_paths = self.data[image_id]
        stacked = np.stack([self.image_reader(path) for path in image_paths])
        images = np.expand_dims(stacked, axis=0)

        # magnification level is read from the first file of the group
        mag_level = get_mag_level(image_paths[0])

        patches = self.sampler(images)
        if len(patches) != spi:
            raise RuntimeWarning(
                f"`patch_func` must return a sequence of length: samples_per_image={self.samples_per_image}."
            )

        offset = index - image_id * spi
        patch = patches[-offset if index < 0 else offset]
        if self.transform is None:
            return patch

        # Preprocessing - 1,10,256,256
        # NOTE(review): preprocessing only runs when a transform is configured — confirm intended.
        for channel, code in ((7, "C01"), (8, "C02"), (9, "C03")):
            patch[0, channel, :, :] = preprocess(patch[0, channel, :, :], mag_level, code)
        patch[0, :7, :, :] = preprocess(patch[0, :7, :, :], mag_level, "C04")
        return apply_transform(self.transform, patch, map_items=False)
예제 #11
0
 def _transform(self, index: int):
     """
     Fetch single data item from `self.data`.
     """
     data_i = self.data[index]
     return apply_transform(
         self.transform, data_i) if self.transform is not None else data_i
예제 #12
0
def monai_preprocess(imgs512):
    """
    Prepare `imgs512` as a CUDA batch for the MONAI model.

    The steps are:
    - crop ``[43:-55]`` on axes 2 and 3;
    - keep the slices between 25% and 75% of axis 0;
    - move axis 0 to the end;
    - apply `val_transforms`;
    - add a leading batch axis;
    - return the result as a CUDA tensor.
    """
    cropped = imgs512[:, :, 43:-55, 43:-55]
    depth = cropped.shape[0]
    middle = cropped[int(depth * 0.25):int(depth * 0.75)]
    volume = apply_transform(val_transforms, np.transpose(middle, (1, 2, 3, 0)))
    batch = np.expand_dims(volume, axis=0)
    return torch.from_numpy(batch).cuda()
예제 #13
0
 def __iter__(self):
     """
     Yield patches from every image produced by the parent iterator.

     Each patch from `self.patch_iter(image)` is passed through `self.transform`
     (``map_items=False``) when one is set. With `self.with_coordinates`, items are
     ``(patch, slices)``; otherwise they are the patch alone.
     """
     def _maybe_transform(patch):
         if self.transform is None:
             return patch
         return apply_transform(self.transform, patch, map_items=False)

     for image in super().__iter__():
         if self.with_coordinates:
             # patch_iter to yield at least 2 items: patch, coords
             for patch, slices, *_ in self.patch_iter(image):
                 yield _maybe_transform(patch), slices
         else:
             # patch_iter to yield at least 1 item: patch
             for patch, *_ in self.patch_iter(image):
                 yield _maybe_transform(patch)
예제 #14
0
파일: dataset.py 프로젝트: yaritzabg/MONAI
    def __getitem__(self, index: int):
        def to_list(x):
            return list(x) if isinstance(x, (tuple, list)) else [x]

        data = list()
        for dataset in self.data:
            data.extend(to_list(dataset[index]))
        if self.transform is not None:
            data = apply_transform(self.transform, data, map_items=False)  # transform the list data
        return data
예제 #15
0
def pred_monai(imgs512):
    """
    Run `monai_model` on the central part of `imgs512` and return sigmoid scores.

    The input is prepared as follows:
    - crop ``[43:-55]`` on axes 2 and 3;
    - keep the slices between 25% and 75% of axis 0;
    - move axis 0 to the end;
    - apply `val_transforms`;
    - add a batch axis and move the result to CUDA.

    Returns:
        The squeezed sigmoid of the model output, as a NumPy array.
    """
    imgs = imgs512[:, :, 43:-55, 43:-55]
    img_monai = imgs[int(imgs.shape[0] * 0.25):int(imgs.shape[0] * 0.75)]
    img_monai = np.transpose(img_monai, (1, 2, 3, 0))
    img_monai = apply_transform(val_transforms, img_monai)
    img_monai = np.expand_dims(img_monai, axis=0)
    img_monai = torch.from_numpy(img_monai).cuda()
    # Inference only: no_grad stops autograd from recording the forward pass, which
    # would otherwise keep intermediate activations alive and waste GPU memory.
    with torch.no_grad():
        monai_preds = torch.sigmoid(
            monai_model(img_monai)).cpu().detach().numpy().squeeze()

    return monai_preds
예제 #16
0
    def __iter__(self):
        info = get_worker_info()
        num_workers = info.num_workers if info is not None else 1
        id = info.id if info is not None else 0

        self.source = iter(self.data)
        for i, item in enumerate(self.source):
            if i % num_workers == id:
                if self.transform is not None:
                    item = apply_transform(self.transform, item)
                yield item
예제 #17
0
    def __getitem__(self, index: int):
        def to_list(x):
            return list(x) if isinstance(x, (tuple, list)) else [x]

        data = list()
        for dataset in self.data:
            data.extend(to_list(dataset[index]))
        if self.transform is not None:
            data = apply_transform(self.transform, data, map_items=False)  # transform the list data
        # use tuple instead of list as the default collate_fn callback of MONAI DataLoader flattens nested lists
        return tuple(data)
예제 #18
0
    def _transform(self, index: int):
        # Get a single entry of data
        sample: Dict = self.data[index]
        # Extract patch image and associated metadata
        image, metadata = self._get_data(sample)
        # Get the label
        label = self._get_label(sample)

        # Create put all patch information together and apply transforms
        patch = {"image": image, "label": label, "metadata": metadata}
        return apply_transform(self.transform,
                               patch) if self.transform else patch
예제 #19
0
 def _load_cache_item(self, item: Any, transforms: Sequence[Callable]):
     """
     Args:
         item: input item to load and transform to generate dataset for model.
         transforms: transforms to execute operations on input item.
     """
     for _transform in transforms:
         # execute all the deterministic transforms
         if isinstance(_transform, Randomizable) or not isinstance(_transform, Transform):
             break
         item = apply_transform(_transform, item)
     return item
예제 #20
0
    def __getitem__(self, index: int):
        """
        Load and transform the sample at `index`.

        Data sources:
        - The image is read with a `LoadNifti` configured from
          `self.as_closest_canonical`, `self.image_only` and `self.dtype`.
        - The segmentation is read image-only when `self.seg_files` is set.
        - The label is ``self.labels[index]`` when labels are given.

        Randomizable transforms are seeded with `self._seed` before use.

        Returns:
            The image alone when nothing else is available. Otherwise a tuple of the
            image, the seg (if not None), the label (if not None) and, when
            `self.image_only` is False, the image metadata (if not None).
        """
        self.randomize()
        meta_data = None
        img_loader = LoadNifti(as_closest_canonical=self.as_closest_canonical,
                               image_only=self.image_only,
                               dtype=self.dtype)
        if self.image_only:
            img = img_loader(self.image_files[index])
        else:
            img, meta_data = img_loader(self.image_files[index])
        seg = None
        if self.seg_files is not None:
            seg_loader = LoadNifti(image_only=True)
            seg = seg_loader(self.seg_files[index])
        label = None
        if self.labels is not None:
            label = self.labels[index]

        if self.transform is not None:
            if isinstance(self.transform, Randomizable):
                self.transform.set_random_state(seed=self._seed)
            img = apply_transform(self.transform, img)

        data = [img]

        # Only transform a segmentation that was actually loaded: without seg_files, `seg`
        # is None and running seg_transform on it is meaningless (and may raise).
        if self.seg_files is not None and self.seg_transform is not None:
            if isinstance(self.seg_transform, Randomizable):
                self.seg_transform.set_random_state(seed=self._seed)
            seg = apply_transform(self.seg_transform, seg)

        if seg is not None:
            data.append(seg)
        if label is not None:
            data.append(label)
        if not self.image_only and meta_data is not None:
            data.append(meta_data)
        if len(data) == 1:
            return data[0]
        # use tuple instead of list as the default collate_fn callback of MONAI DataLoader flattens nested lists
        return tuple(data)
예제 #21
0
파일: utils.py 프로젝트: gagandaroach/MONAI
def engine_apply_transform(batch: Any, output: Any, transform: Callable):
    """
    Apply transform for the engine.state.batch and engine.state.output.
    If `batch` and `output` are dictionaries, temporarily combine them for the transform,
    otherwise, apply the transform for `output` data only.

    In the dict case, each key of the transformed mapping is written back. It goes to
    `output` if it is already there or is new; keys that exist only in `batch` are
    written to `batch`.

    Returns:
        The ``(batch, output)`` pair (updated in place in the dict case).

    Raises:
        AssertionError: in the dict case, if `transform` does not return a mapping.
    """
    from collections.abc import Mapping

    if isinstance(batch, dict) and isinstance(output, dict):
        data = dict(batch)
        data.update(output)
        data = apply_transform(transform, data)
        # fail with a clear message instead of an AttributeError on `.items()`
        if not isinstance(data, Mapping):
            raise AssertionError(
                "With a dict supplied to apply_transform a single dict return is expected."
            )
        for k, v in data.items():
            # new keys and keys already in `output` go to `output`; keys only in `batch` stay there
            if k in output or k not in batch:
                output[k] = v
            else:
                batch[k] = v
    else:
        output = apply_transform(transform, output)

    return batch, output
예제 #22
0
파일: dataset.py 프로젝트: tuan-cs/MONAI
 def _load_cache_item(self, idx: int):
     """
     Args:
         idx: the index of the input data sequence.
     """
     item = self.data[idx]
     for _transform in self.transform.transforms:  # type:ignore
         # execute all the deterministic transforms
         if isinstance(_transform, Randomizable) or not isinstance(_transform, Transform):
             break
         _xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
         item = apply_transform(_xform, item)
     return item
예제 #23
0
 def _transform(self, index: int):
     image_id = int(index / self.samples_per_image)
     image = self.data[image_id]
     patches = self.patch_func(image)
     if len(patches) != self.samples_per_image:
         raise RuntimeWarning(
             f"`patch_func` must return a sequence of length: samples_per_image={self.samples_per_image}."
         )
     patch_id = (index - image_id * self.samples_per_image) * (-1 if index < 0 else 1)
     patch = patches[patch_id]
     if self.transform is not None:
         patch = apply_transform(self.transform, patch, map_items=False)
     return patch
 def _rand_pop(self, buffer: List, num_workers: int = 1, id: int = 0):
     length = len(buffer)
     for i in range(min(length, num_workers)):
         if self.shuffle:
             # randomly select an item for every worker and pop
             self.randomize(length)
         # switch random index data and the last index data
         item, buffer[self._idx] = buffer[self._idx], buffer[-1]
         buffer.pop()
         if i == id:
             if self.transform is not None:
                 item = apply_transform(self.transform, item)
             yield item
예제 #25
0
 def _load_cache_item(self, idx: int):
     """
     Compute the cacheable form of ``self.data[idx]``.

     Applies the transforms of the `Compose` in `self.transform` in order, stopping at
     the first one that is `Randomizable` or not a `Transform`.

     Args:
         idx: the index of the input data sequence.

     Raises:
         ValueError: if `self.transform` is not a `Compose`.
     """
     item = self.data[idx]
     if not isinstance(self.transform, Compose):
         raise ValueError("transform must be an instance of monai.transforms.Compose.")
     for xform in self.transform.transforms:
         if isinstance(xform, Transform) and not isinstance(xform, Randomizable):
             item = apply_transform(xform, item)
         else:
             # only the leading deterministic prefix is applied here
             break
     return item
예제 #26
0
 def __iter__(self):
     """
     Yield patches from every image produced by the parent iterator.

     Each patch from `self.patch_iter(image)` is passed through `self.patch_transform`
     (``map_items=False``) when one is set. When `self.with_coordinates` is truthy and
     `patch_iter` yielded extra items, the item is ``(patch, first_extra)``; otherwise it
     is the patch alone.
     """
     for image in super().__iter__():
         for patch, *others in self.patch_iter(image):
             if self.patch_transform is None:
                 out_patch = patch
             else:
                 out_patch = apply_transform(self.patch_transform, patch, map_items=False)
             # patch_iter to yield at least 2 items: patch, coords
             if self.with_coordinates and others:
                 yield out_patch, others[0]
             else:
                 yield out_patch
예제 #27
0
 def __getitem__(self, index):
     """
     Load the JPEG slice series for row `index` of `self.csv` as a float32 volume.

     Slices are read from ``data_dir/<StudyInstanceUID>/<SeriesInstanceUID>/*.jpg`` in
     sorted filename order. The channel axis of each `cv2.imread` result is reversed.
     The slices are stacked, transposed by (3, 0, 1, 2), and passed through
     `self.transform`, which is seeded with `self._seed` when Randomizable.

     Returns:
         The volume alone in ``'test'`` mode. Otherwise ``(volume, targets)``, where
         `targets` is ``row[target_cols]`` as a float tensor.
     """
     self.randomize()
     row = self.csv.iloc[index]
     series_dir = os.path.join(data_dir, row.StudyInstanceUID, row.SeriesInstanceUID)
     jpg_paths = sorted(glob(os.path.join(series_dir, '*.jpg')))
     slices = np.array([cv2.imread(path)[:, :, ::-1] for path in jpg_paths])  # z,y,x
     volume = np.stack([s.astype(np.float32) for s in slices], axis=2).transpose(3, 0, 1, 2)

     transform = self.transform
     if transform is not None:
         if isinstance(transform, Randomizable):
             transform.set_random_state(seed=self._seed)
         volume = apply_transform(transform, volume)

     if self.mode == 'test':
         return volume
     return volume, torch.tensor(row[target_cols]).float()
    def __getitem__(self, i):
        """
        Load sample `i`: a label volume plus one or several MGH scans.

        The label is read from `self.y_path[i]` with `LoadNifti` and passed through
        `self.y_transform`. If ``self.num_scan_training > 1``, `self.X_path[i]` is a
        sequence of scan paths whose transformed volumes are concatenated along dim 0.
        Otherwise it is a single scan path. In every scan, values below 3 are set to 0.0
        before `self.X_transform` is applied.

        Randomizable transforms are seeded with `self._seed` immediately before each use.
        The label and every scan therefore start from the same random state (assuming
        X_transform and y_transform draw their parameters the same way — confirm).

        Returns:
            ``(X_img, y_img)``
        """
        self.randomize()

        def _seeded(xform):
            # Seed right before applying. Seeding after applying y_transform left the label
            # on the previous sample's random state, misaligned with the image.
            if isinstance(xform, Randomizable):
                xform.set_random_state(seed=self._seed)
            return xform

        loadnifti = LoadNifti()
        y_img, _ = loadnifti(self.y_path[i])
        y_img = apply_transform(_seeded(self.y_transform), y_img)

        def _load_scan(path):
            img = MGHImage.load(path).get_fdata()
            img[img < 3] = 0.0
            # re-seed per scan so concatenated scans get the same random parameters
            return apply_transform(_seeded(self.X_transform), img)

        if self.num_scan_training > 1:
            X_img = torch.cat([_load_scan(scan) for scan in self.X_path[i]], dim=0)
        else:
            X_img = _load_scan(self.X_path[i])

        return X_img, y_img
예제 #29
0
파일: dataset.py 프로젝트: yaritzabg/MONAI
 def __getitem__(self, index):
     if index < self.cache_num:
         # load data from cache and execute from the first random transform
         start_run = False
         data = self._cache[index]
         for _transform in self.transform.transforms:  # pytype: disable=attribute-error
             if not start_run and not isinstance(_transform, Randomizable) and isinstance(_transform, Transform):
                 continue
             else:
                 start_run = True
             data = apply_transform(_transform, data)
     else:
         # no cache for this data, execute all the transforms directly
         data = super(CacheDataset, self).__getitem__(index)
     return data
예제 #30
0
파일: dataset.py 프로젝트: staffik/MONAI
 def __getitem__(self, index):
     if index >= self.cache_num:
         # no cache for this index, execute all the transforms directly
         return super(CacheDataset, self).__getitem__(index)
     # load data from cache and execute from the first random transform
     start_run = False
     if self._cache is None:
         self._cache = self._fill_cache()
     data = self._cache[index]
     for _transform in self.transform.transforms:  # pytype: disable=attribute-error
         if start_run or isinstance(
                 _transform,
                 Randomizable) or not isinstance(_transform, Transform):
             start_run = True
             data = apply_transform(_transform, data)
     return data