def __call__(self, data):
    """Crop every selected key to the foreground bounding box of ``self.source_key``.

    The box comes from ``generate_spatial_bounding_box`` applied to
    ``d[self.source_key]``. When the box is strictly smaller than
    ``self.spatial_size`` in every dimension, a fixed ROI of
    ``self.spatial_size`` centred on the box is cropped instead, and the
    recorded start/end are taken from that ROI.

    For each key the meta dict receives the crop start/end and the shape
    before and after cropping.

    Args:
        data: mapping holding the images and their meta dicts.

    Returns:
        A shallow copy of ``data`` with the selected keys cropped.
    """
    d = dict(data)
    start, end = generate_spatial_bounding_box(
        d[self.source_key],
        self.select_fn,
        self.channel_indices,
        self.margin,
        self.allow_smaller,
    )
    roi_center = list(np.mean([start, end], axis=0).astype(int, copy=False))
    box_extent = list(np.subtract(end, start).astype(int, copy=False))

    if not np.all(np.less(box_extent, self.spatial_size)):
        cropper = SpatialCrop(roi_start=start, roi_end=end)
    else:
        cropper = SpatialCrop(roi_center=roi_center, roi_size=self.spatial_size)
        # The fixed-size ROI differs from the bounding box: record the region
        # that is actually cropped.
        start = np.array([sl.start for sl in cropper.slices])
        end = np.array([sl.stop for sl in cropper.slices])

    for key, meta_key, postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):
        meta = d[meta_key or f"{key}_{postfix}"]
        meta[self.start_coord_key] = start
        meta[self.end_coord_key] = end
        meta[self.original_shape_key] = d[key].shape
        cropped = cropper(d[key])
        meta[self.cropped_shape_key] = cropped.shape
        d[key] = cropped
    return d
def __call__(self, data):
    """Crop every selected key to the foreground bounding box of ``self.source_key``.

    If the bounding box is strictly smaller than ``self.spatial_size`` in
    every dimension, a fixed ROI of ``self.spatial_size`` centred on the box
    is used instead. Its ``roi_start``/``roi_end`` are then recorded.

    The crop start/end and the shapes before and after cropping are written
    to ``d[f"{key}_{self.meta_key_postfix}"]``.

    Args:
        data: mapping holding the images and their meta dicts.

    Returns:
        A shallow copy of ``data`` with the selected keys cropped.
    """
    d = dict(data)
    lo, hi = generate_spatial_bounding_box(
        d[self.source_key], self.select_fn, self.channel_indices, self.margin
    )
    midpoint = list(np.mean([lo, hi], axis=0).astype(int))
    extent = list(np.subtract(hi, lo).astype(int))

    if np.all(np.less(extent, self.spatial_size)):
        cropper = SpatialCrop(roi_center=midpoint, roi_size=self.spatial_size)
        # Report the ROI actually cropped rather than the raw bounding box.
        lo, hi = cropper.roi_start, cropper.roi_end
    else:
        cropper = SpatialCrop(roi_start=lo, roi_end=hi)

    for key in self.key_iterator(d):
        meta = d[f"{key}_{self.meta_key_postfix}"]
        meta[self.start_coord_key] = lo
        meta[self.end_coord_key] = hi
        meta[self.original_shape_key] = d[key].shape
        cropped = cropper(d[key])
        meta[self.cropped_shape_key] = cropped.shape
        d[key] = cropped
    return d
def __call__(self, data):
    """Crop all selected keys around the bounding box of the guidance clicks.

    ``d[self.guidance]`` holds ``[positive_clicks, negative_clicks]``. The
    bounding box of all clicks is computed by ``self.bounding_box`` against
    the spatial shape of the first key. If that box is strictly smaller than
    the requested spatial size in every dimension, a fixed ROI of that size
    centred on the box is cropped instead.

    The crop start/end and the image shapes are written to each key's meta
    dict, and the clicks are shifted into the cropped coordinate frame.

    Args:
        data: mapping holding the images, their meta dicts and the guidance.

    Returns:
        A shallow copy of ``data`` with cropped images and shifted guidance.
        The copy is returned unchanged when ``self.first_key`` yields ``[]``.

    Raises:
        RuntimeError: if the selected keys do not share the same spatial shape.
    """
    d: Dict = dict(data)
    first_key: Union[Hashable, List] = self.first_key(d)
    if first_key == []:
        return d

    guidance = d[self.guidance]
    # Spatial shape only: the channel dim (index 0) is dropped here.
    original_spatial_shape = d[first_key].shape[1:]
    box_start, box_end = self.bounding_box(
        np.array(guidance[0] + guidance[1]), original_spatial_shape
    )
    center = list(np.mean([box_start, box_end], axis=0).astype(int, copy=False))
    box_size = list(np.subtract(box_end, box_start).astype(int, copy=False))

    # Keep only the trailing dims. Converting to a list makes the
    # concatenation below work for tuple/array sizes too: ``list + ndarray``
    # would broadcast an element-wise addition instead of concatenating.
    spatial_size = list(self.spatial_size)[-len(box_size):]

    if len(spatial_size) < len(box_size):
        # If the data is in 3D and spatial_size is specified as 2D [256,256],
        # take all slices: prepend the leading spatial extents. Because
        # ``original_spatial_shape`` already excludes the channel dim, the
        # leading spatial dims start at index 0, not 1.
        diff = len(box_size) - len(spatial_size)
        spatial_size = list(original_spatial_shape[:diff]) + spatial_size

    if np.all(np.less(box_size, spatial_size)):
        if len(center) == 3:
            # 3D Deepgrow: set center to be middle of the depth dimension (D)
            center[0] = spatial_size[0] // 2
        cropper = SpatialCrop(roi_center=center, roi_size=spatial_size)
    else:
        cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)

    # update bounding box in case it was corrected by the SpatialCrop constructor
    box_start = np.array([s.start for s in cropper.slices])
    box_end = np.array([s.stop for s in cropper.slices])

    for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):
        if not np.array_equal(d[key].shape[1:], original_spatial_shape):
            raise RuntimeError("All the image specified in keys should have same spatial shape")
        meta_key = meta_key or f"{key}_{meta_key_postfix}"
        d[meta_key][self.start_coord_key] = box_start
        d[meta_key][self.end_coord_key] = box_end
        d[meta_key][self.original_shape_key] = d[key].shape

        image = cropper(d[key])
        d[meta_key][self.cropped_shape_key] = image.shape
        d[key] = image

    # Shift the clicks into the coordinate frame of the cropped images.
    pos_clicks, neg_clicks = guidance[0], guidance[1]
    pos = np.subtract(pos_clicks, box_start).tolist() if len(pos_clicks) else []
    neg = np.subtract(neg_clicks, box_start).tolist() if len(neg_clicks) else []

    d[self.guidance] = [pos, neg]
    return d
def __call__(self, data):
    """Crop, optionally resample, and normalise ``d["image"]`` (and its label).

    Cropping depends on the mode:
      - Training: cropping is delegated to ``self.crop_foreg`` on
        ``{"image", "label"}``.
      - Otherwise: the image is cropped to
        ``generate_spatial_bounding_box(image)``, and
        ``"original_shape"``, ``"bbox"`` and ``"crop_shape"`` are stored in
        the returned dict.

    Resampling is done only when ``self.target_spacing`` differs from the
    spacing in ``d["image_meta_dict"]["pixdim"][1:4]``. The label is
    resampled only in training mode.

    Normalisation depends on the clip bounds:
      - If either ``self.low`` or ``self.high`` is non-zero, the image is
        clipped to that range and standardised with ``self.mean``/``self.std``.
      - Otherwise ``self.normalize_intensity`` is applied to a copy.

    Note: negative label values are zeroed in place on the incoming label
    array.
    """
    d = dict(data)
    img = d["image"]
    spacing = d["image_meta_dict"]["pixdim"][1:4].tolist()
    has_label = "label" in self.keys

    if has_label:
        seg = d["label"]
        # In place: this also modifies the caller's label array.
        seg[seg < 0] = 0

    if not self.training:
        d["original_shape"] = np.array(img.shape[1:])
        bbox_lo, bbox_hi = generate_spatial_bounding_box(img)
        img = SpatialCrop(roi_start=bbox_lo, roi_end=bbox_hi)(img)
        d["bbox"] = np.vstack([bbox_lo, bbox_hi])
        d["crop_shape"] = np.array(img.shape[1:])
    else:
        # NOTE: the original author noted that only task 04 is not impacted here.
        cropped = self.crop_foreg({"image": img, "label": seg})
        img = cropped["image"]
        seg = cropped["label"]

    shape_before_resample = img.shape[1:]
    resample_flag = bool(self.target_spacing != spacing)
    anisotropic = False
    if resample_flag:
        new_shape = self.calculate_new_shape(spacing, shape_before_resample)
        anisotropic = self.check_anisotrophy(spacing)
        img = resample_image(img, new_shape, anisotropic)
        if self.training:
            seg = resample_label(seg, new_shape, anisotropic)

    d["resample_flag"] = resample_flag
    d["anisotrophy_flag"] = anisotropic

    if self.low == 0 and self.high == 0:
        img = self.normalize_intensity(img.copy())
    else:
        # clip image for CT dataset, then standardise
        img = (np.clip(img, self.low, self.high) - self.mean) / self.std

    d["image"] = img
    if has_label:
        d["label"] = seg
    return d
def __call__(self, data):
    """Crop all ``self.keys`` around the bounding box of the guidance clicks.

    ``d[self.guidance]`` holds ``[positive_clicks, negative_clicks]``. The
    bounding box of all clicks is computed by ``self.bounding_box`` against
    the spatial shape of ``self.keys[0]``. If that box is strictly smaller
    than the requested spatial size in every dimension, a fixed ROI of that
    size centred on the box is cropped instead.

    The crop start/end and the image shapes are written to each key's meta
    dict, and the clicks are shifted into the cropped coordinate frame.

    Args:
        data: mapping holding the images, their meta dicts and the guidance.

    Returns:
        A shallow copy of ``data`` with cropped images and shifted guidance.

    Raises:
        RuntimeError: if the keys do not share the same spatial shape.
    """
    d: Dict = dict(data)
    guidance = d[self.guidance]
    # Spatial shape only: the channel dim (index 0) is dropped here.
    original_spatial_shape = d[self.keys[0]].shape[1:]
    box_start, box_end = self.bounding_box(
        np.array(guidance[0] + guidance[1]), original_spatial_shape
    )
    center = list(np.mean([box_start, box_end], axis=0).astype(int))
    box_size = list(np.subtract(box_end, box_start).astype(int))

    # Keep only the trailing dims. Converting to a list makes the
    # concatenation below work for tuple/array sizes too: ``list + ndarray``
    # would broadcast an element-wise addition instead of concatenating.
    spatial_size = list(self.spatial_size)[-len(box_size):]

    if len(spatial_size) < len(box_size):
        # If the data is in 3D and spatial_size is specified as 2D [256,256],
        # take all slices: prepend the leading spatial extents. Because
        # ``original_spatial_shape`` already excludes the channel dim, the
        # leading spatial dims start at index 0, not 1.
        diff = len(box_size) - len(spatial_size)
        spatial_size = list(original_spatial_shape[:diff]) + spatial_size

    if np.all(np.less(box_size, spatial_size)):
        if len(center) == 3:
            # 3D Deepgrow: set center to be middle of the depth dimension (D)
            center[0] = spatial_size[0] // 2
        cropper = SpatialCrop(roi_center=center, roi_size=spatial_size)
    else:
        cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)
    # Use the ROI chosen by SpatialCrop, which may differ from the raw box.
    box_start, box_end = cropper.roi_start, cropper.roi_end

    for key in self.keys:
        if not np.array_equal(d[key].shape[1:], original_spatial_shape):
            raise RuntimeError("All the image specified in keys should have same spatial shape")

        meta_key = f"{key}_{self.meta_key_postfix}"
        d[meta_key][self.start_coord_key] = box_start
        d[meta_key][self.end_coord_key] = box_end
        d[meta_key][self.original_shape_key] = d[key].shape

        image = cropper(d[key])
        d[meta_key][self.cropped_shape_key] = image.shape
        d[key] = image

    # Shift the clicks into the coordinate frame of the cropped images.
    pos_clicks, neg_clicks = guidance[0], guidance[1]
    pos = np.subtract(pos_clicks, box_start).tolist() if len(pos_clicks) else []
    neg = np.subtract(neg_clicks, box_start).tolist() if len(neg_clicks) else []

    d[self.guidance] = [pos, neg]
    return d
def test_tensor_shape(self, input_param, input_shape, expected_shape):
    """Cropping a random binary torch tensor (on CUDA when available) yields ``expected_shape``."""
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    tensor = torch.randint(0, 2, size=input_shape, device=target_device)
    cropped = SpatialCrop(**input_param)(tensor)
    self.assertTupleEqual(cropped.shape, expected_shape)
def __call__(self, data):
    """Crop each of ``self.keys`` around the bounding box of the guidance clicks.

    ``d[self.guidance]`` holds ``[positive_clicks, negative_clicks]``. For each
    key, the bounding box of all clicks is computed with ``self.bounding_box``
    against that key's spatial shape.

    The requested size is ``d[self.spatial_size_key]`` when present, else
    ``self.spatial_size``. If the box is strictly smaller than that size in
    every dimension, a fixed ROI of that size centred on the box is cropped
    instead.

    The clicks are shifted by the ROI start of the last processed key.

    Args:
        data: mapping holding the images, their meta dicts and the guidance.
            It is no longer modified at the top level; nested meta dicts are
            still written to.

    Returns:
        A shallow copy of ``data`` with cropped images and shifted guidance.
        The copy is returned unchanged when ``self.keys`` is empty.
    """
    d = dict(data)
    guidance = d[self.guidance]
    box_start = None
    for key in self.keys:
        box_start, box_end = self.bounding_box(
            np.array(guidance[0] + guidance[1]), d[key].shape[1:]
        )
        center = np.mean([box_start, box_end], axis=0).astype(int).tolist()
        current_size = np.absolute(np.subtract(box_start, box_end)).astype(int).tolist()

        # Convert to a list so the concatenation below also works when the
        # size is a tuple or an ndarray (which would otherwise broadcast).
        spatial_size = list(d.get(self.spatial_size_key, self.spatial_size))
        spatial_size = spatial_size[-len(current_size):]
        if len(spatial_size) < len(current_size):
            # 3D spatial_size = [256,256] (include all slices in such case).
            # ``d[key].shape`` still has the channel dim at index 0, so the
            # leading spatial dims start at index 1.
            diff = len(current_size) - len(spatial_size)
            spatial_size = list(d[key].shape[1:(1 + diff)]) + spatial_size

        if np.all(np.less(current_size, spatial_size)):
            if len(center) == 3:
                # 3D: put the depth centre at half the ROI depth
                # (presumably so the crop starts at slice 0).
                center[0] = spatial_size[0] // 2
            cropper = SpatialCrop(roi_center=center, roi_size=spatial_size)
        else:
            cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)
        box_start, box_end = cropper.roi_start, cropper.roi_end

        meta_key = f"{key}_{self.meta_key_postfix}"
        d[meta_key][self.start_coord_key] = box_start
        d[meta_key][self.end_coord_key] = box_end
        d[meta_key][self.original_shape_key] = d[key].shape

        image = cropper(d[key])
        d[meta_key][self.cropped_shape_key] = image.shape
        d[key] = image

    if box_start is None:
        # No key was cropped, so there is no frame to shift the clicks into.
        return d

    pos_clicks, neg_clicks = guidance[0], guidance[1]
    pos = np.subtract(pos_clicks, box_start).tolist() if len(pos_clicks) else []
    neg = np.subtract(neg_clicks, box_start).tolist() if len(neg_clicks) else []

    d[self.guidance] = [pos, neg]
    return d
def __call__(
    self,
    img: np.ndarray,
    label: Optional[np.ndarray] = None,
    image: Optional[np.ndarray] = None,
    fg_indices: Optional[np.ndarray] = None,
    bg_indices: Optional[np.ndarray] = None,
) -> List[np.ndarray]:
    """
    Args:
        img: input data to crop samples from based on the pos/neg ratio of `label` and `image`.
            Assumes `img` is a channel-first array.
        label: the label image that is used for finding foreground/background, if None, use `self.label`.
        image: optional image data to help select valid area, can be same as `img` or another image array.
            use ``label == 0 & image > image_threshold`` to select the negative sample(background) center.
            so the crop center will only exist on valid image area. if None, use `self.image`.
        fg_indices: foreground indices to randomly select crop centers,
            need to provide `fg_indices` and `bg_indices` together.
        bg_indices: background indices to randomly select crop centers,
            need to provide `fg_indices` and `bg_indices` together.

    Returns:
        One crop per centre chosen by ``self.randomize``. The list is empty
        when no centres were produced.

    Raises:
        ValueError: if no label is given and ``self.label`` is None.
    """
    if label is None:
        label = self.label
    if label is None:
        raise ValueError("label should be provided.")
    if image is None:
        image = self.image

    # Binarise against the target label *before* deriving fg/bg indices.
    # Otherwise the indices computed below would cover every label class and
    # `target_label` would have no effect on where crops are centred.
    if self.target_label is not None:
        label = (label == self.target_label).astype(np.uint8)

    if fg_indices is None or bg_indices is None:
        if self.fg_indices is not None and self.bg_indices is not None:
            fg_indices = self.fg_indices
            bg_indices = self.bg_indices
        else:
            fg_indices, bg_indices = map_binary_to_indices(label, image, self.image_threshold)

    self.randomize(label, fg_indices, bg_indices, image)
    results: List[np.ndarray] = []
    if self.centers is not None:
        # Loop-invariant: the ROI exceeds the image in some spatial dim.
        roi_exceeds_image = np.any(np.greater(self.spatial_size, img.shape[1:]))
        for center in self.centers:
            if roi_exceeds_image:
                cropper = ResizeWithPadOrCrop(spatial_size=self.spatial_size)
            else:
                # SpatialCrop takes `roi_size`; `spatial_size` is not one of
                # its parameters.
                cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size)  # type: ignore
            results.append(cropper(img))
    return results
def __call__(
    self,
    img: np.ndarray,
    msk: Optional[np.ndarray] = None,
    center: Optional[tuple] = None,
    z_axis: Optional[int] = None,
):
    """
    Apply the transform to `img`, assuming `img` is channel-first and
    slicing doesn't apply to the channel dim.

    The crop centre is ``center`` when given, otherwise
    ``self.get_center_pos(mask, z_axis)``. The mask is ``msk > 0`` when
    ``msk`` is given, else ``self.mask_data > 0``.

    Output depends on ``self.crop_mode``:
      - ``"single"`` / ``"parallel"``: one crop of
        ``self.get_new_spatial_size(z_axis)``, with channel dim 0 squeezed and
        ``z_axis`` moved to the front.
      - Otherwise: an array stacking three slices, each with a size-1 extent
        inserted at axis ``k`` of ``self.roi_size``.

    Crops smaller than requested (e.g. at the image border) are padded or
    cropped to size with ``ResizeWithPadOrCrop``.

    Raises:
        ValueError: if no mask is available, or the mask's channel count is
            neither 1 nor equal to the image's.
    """
    if self.mask_data is None and msk is None:
        raise ValueError("Unknown mask_data.")
    # An explicitly passed mask takes precedence over the stored one.
    mask_data_ = np.asarray(msk > 0 if msk is not None else self.mask_data > 0)
    if mask_data_.shape[0] != 1 and mask_data_.shape[0] != img.shape[0]:
        raise ValueError(
            "When mask_data is not single channel, mask_data channels must match img, "
            f"got img={img.shape[0]} mask_data={mask_data_.shape[0]}.")

    z_axis_ = z_axis if z_axis is not None else self.z_axis
    if center is None:
        center = self.get_center_pos(mask_data_, z_axis_)

    if self.crop_mode in ["single", "parallel"]:
        size_ = [int(s) for s in self.get_new_spatial_size(z_axis_)]
        slice_ = SpatialCrop(roi_center=center, roi_size=size_)(img)
        # Compare as tuples: ``tuple != list`` is always True in Python, so
        # the old comparison resized unconditionally.
        if tuple(slice_.shape[1:]) != tuple(size_):
            slice_ = ResizeWithPadOrCrop(spatial_size=size_)(slice_)
        return np.moveaxis(slice_.squeeze(0), z_axis_, 0)

    # tuple() also accepts a list-valued roi_size.
    cross_slices = np.zeros(shape=(3, ) + tuple(self.roi_size))
    for k in range(3):
        size_ = np.insert(self.roi_size, k, 1)
        slice_ = SpatialCrop(roi_center=center, roi_size=size_)(img)
        if tuple(slice_.shape[1:]) != tuple(size_):
            slice_ = ResizeWithPadOrCrop(spatial_size=size_)(slice_)
        cross_slices[k] = slice_.squeeze()
    return cross_slices
def __call__(
        self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
    """Crop each selected key along its third spatial axis.

    With ``depth = shape[1:][2]``, the kept range is
    ``[int(depth * relative_z_roi[1]), depth - int(depth * relative_z_roi[0]))``.
    The first two spatial axes are kept whole.

    Returns:
        A shallow copy of ``data`` with the selected keys cropped.
    """
    out = dict(data)
    for key in self.key_iterator(out):
        spatial = out[key].shape[1:]
        depth = spatial[2]
        z_from = int(depth * self.relative_z_roi[1])
        z_to = depth - int(depth * self.relative_z_roi[0])
        crop = SpatialCrop(
            roi_start=np.array([0, 0, z_from]),
            roi_end=np.array([spatial[0], spatial[1], z_to]),
        )
        out[key] = crop(out[key])
    return out
def __call__(
        self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]:
    """Draw ``self.num_samples`` crops per selected key using pos/neg label sampling.

    Crop centres come from ``self.randomize`` on ``d[self.label_key]``,
    binarised against ``self.target_label`` when that is set. Optional
    image/fg/bg indices are read from ``data``. When ``self.spatial_size``
    exceeds a key's spatial shape, ``ResizeWithPadOrCrop`` is used instead of
    a centred ``SpatialCrop``.

    Keys not in ``self.keys`` are passed through unmodified. Each sample gets
    its own copy of the selected keys' meta dicts, with ``Key.PATCH_INDEX``
    set to the sample index.

    Returns:
        A list of ``self.num_samples`` dicts, one per crop centre.

    Raises:
        TypeError: if ``self.spatial_size`` is not a tuple.
        AssertionError: if ``self.randomize`` left ``self.centers`` as None.
    """
    d = dict(data)
    label = d[self.label_key]
    image = d[self.image_key] if self.image_key else None
    fg_indices = d.get(self.fg_indices_key) if self.fg_indices_key is not None else None
    bg_indices = d.get(self.bg_indices_key) if self.bg_indices_key is not None else None
    if self.target_label is not None:
        label = (label == self.target_label).astype(np.uint8)

    self.randomize(label, fg_indices, bg_indices, image)
    if not isinstance(self.spatial_size, tuple):
        raise TypeError(f"Expect spatial_size to be tuple, but got {type(self.spatial_size)}")
    if self.centers is None:
        raise AssertionError

    results: List[Dict[Hashable, np.ndarray]] = [{} for _ in range(self.num_samples)]
    # Keys passed through unmodified; loop-invariant, so computed once.
    extra_keys = set(data.keys()).difference(set(self.keys))

    for i, center in enumerate(self.centers):
        for key in self.key_iterator(d):
            img = d[key]
            if np.greater(self.spatial_size, img.shape[1:]).any():
                cropper = ResizeWithPadOrCrop(spatial_size=self.spatial_size)
            else:
                cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size)  # type: ignore
            results[i][key] = cropper(img)
        # fill in the extra keys with unmodified data
        for key in extra_keys:
            results[i][key] = data[key]
        # add `patch_index` to the meta data
        for key in self.key_iterator(d):
            meta_data_key = f"{key}_{self.meta_key_postfix}"
            # The passed-through meta dict is the same object for every sample
            # (and for the caller). Copy it so each sample keeps its own
            # PATCH_INDEX instead of all samples ending up with the last one.
            if meta_data_key in results[i]:
                results[i][meta_data_key] = dict(results[i][meta_data_key])  # type: ignore
            else:
                results[i][meta_data_key] = {}  # type: ignore
            results[i][meta_data_key][Key.PATCH_INDEX] = i

    return results
def test_shape(self, input_param, input_shape, expected_shape):
    """Check SpatialCrop across every input/parameter container combination.

    For each combination, the output keeps the input type (and torch device)
    and has ``expected_shape``. Every output must also numerically match the
    first one. ``roi_slices`` is never converted.
    """
    input_data = np.random.randint(0, 2, size=input_shape)
    outputs = []
    param_converters = TEST_NDARRAYS + (None, )
    for to_input in TEST_NDARRAYS:
        for to_param in param_converters:
            if to_param is None:
                params = dict(input_param)
            else:
                params = {
                    name: (value if name == "roi_slices" else to_param(value))
                    for name, value in input_param.items()
                }
            arr = to_input(input_data)
            cropped = SpatialCrop(**params)(arr)
            self.assertEqual(type(arr), type(cropped))
            if isinstance(cropped, torch.Tensor):
                self.assertEqual(cropped.device, arr.device)
            self.assertTupleEqual(cropped.shape, expected_shape)
            outputs.append(cropped)
            if len(outputs) > 1:
                assert_allclose(outputs[0], outputs[-1], type_test=False)
def __call__(
        self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]:
    """Draw ``self.num_samples`` slice crops per selected key around sampled centres.

    Crop centres come from ``self.randomize`` on ``d[self.label_key]``, with
    optional image/fg/bg indices read from ``data``. The output per key
    depends on ``self.crop_mode``:
      - ``"single"`` / ``"parallel"``: one crop of
        ``self.get_new_spatial_size()``, with channel dim 0 squeezed and
        ``self.z_axis`` moved to the front.
      - Otherwise: an array of shape ``(3,) + self.spatial_size`` stacking
        three slices, each with a size-1 extent inserted at axis ``k``.

    Keys not in ``self.keys`` are copied by reference into every sample.

    Returns:
        A list of ``self.num_samples`` dicts.

    Raises:
        AssertionError: if ``self.spatial_size`` is not a tuple, or
            ``self.randomize`` left ``self.centers`` as None.
    """
    d = dict(data)
    label = d[self.label_key]
    image = d[self.image_key] if self.image_key else None
    fg_indices = d.get(self.fg_indices_key) if self.fg_indices_key is not None else None
    bg_indices = d.get(self.bg_indices_key) if self.bg_indices_key is not None else None

    self.randomize(label, fg_indices, bg_indices, image)
    # Explicit raises instead of ``assert`` so the checks also run under
    # ``python -O``. AssertionError is kept so callers see the same type.
    if not isinstance(self.spatial_size, tuple):
        raise AssertionError(f"Expect spatial_size to be tuple, but got {type(self.spatial_size)}")
    if self.centers is None:
        raise AssertionError("self.centers is None after randomize().")

    results: List[Dict[Hashable, np.ndarray]] = [{} for _ in range(self.num_samples)]
    for key in data.keys():
        if key in self.keys:
            img = d[key]
            for i, center in enumerate(self.centers):
                if self.crop_mode in ["single", "parallel"]:
                    size_ = self.get_new_spatial_size()
                    slice_ = SpatialCrop(roi_center=tuple(center), roi_size=size_)(img)
                    results[i][key] = np.moveaxis(slice_.squeeze(0), self.z_axis, 0)
                else:
                    cross_slices = np.zeros(shape=(3, ) + self.spatial_size)
                    for k in range(3):
                        size_ = np.insert(self.spatial_size, k, 1)
                        slice_ = SpatialCrop(roi_center=tuple(center), roi_size=size_)(img)
                        cross_slices[k] = slice_.squeeze()
                    results[i][key] = cross_slices
        else:
            for i in range(self.num_samples):
                results[i][key] = data[key]

    return results
def test_shape(self, input_param, input_shape, expected_shape):
    """Cropping a random binary NumPy array of ``input_shape`` yields ``expected_shape``."""
    volume = np.random.randint(0, 2, size=input_shape)
    self.assertTupleEqual(SpatialCrop(**input_param)(volume).shape, expected_shape)
def test_shape(self, input_param, input_data, expected_shape):
    """SpatialCrop built from ``input_param`` maps ``input_data`` to ``expected_shape``."""
    crop = SpatialCrop(**input_param)
    output_shape = crop(input_data).shape
    self.assertTupleEqual(output_shape, expected_shape)
def test_error(self, input_param):
    """Constructing SpatialCrop from ``input_param`` must raise ValueError."""
    self.assertRaises(ValueError, SpatialCrop, **input_param)