def ensure_torch_and_prune_meta(im: NdarrayTensor, meta: dict, simple_keys: bool = False):
    """
    Turn ``im`` into a tensor and, when metadata tracking is active, wrap it in a ``MetaTensor``.

    Args:
        im: input image, either an ``np.ndarray`` or a ``torch.Tensor``.
        meta: metadata dictionary to attach to the result. ``None`` is also accepted and
            yields a plain ``torch.Tensor``.
        simple_keys: when ``True``, ``meta`` is mutated before use: its ``"affine"`` entry
            (if present) is replaced by a ``torch.Tensor`` version, and
            ``remove_extra_metadata(meta)`` is called on it. When ``False``, ``meta`` is
            attached as given.

    Returns:
        A plain ``torch.Tensor`` when ``get_track_meta()`` is ``False`` or ``meta`` is
        ``None``; otherwise a ``MetaTensor`` built from the tensor and ``meta``.
    """
    tensor_img = convert_to_tensor(im)  # potentially ascontiguousarray
    tracking_enabled = get_track_meta()
    if meta is None or not tracking_enabled:
        # nothing to attach (or tracking disabled): hand back the bare tensor
        return tensor_img

    if simple_keys:
        # reduce `meta` to the simple key subset; note this edits the caller's dict in place
        if "affine" in meta:
            meta["affine"] = convert_to_tensor(meta["affine"])  # bc-breaking
        remove_extra_metadata(meta)  # bc-breaking

    return MetaTensor(tensor_img, meta=meta)
def __call__(
    self, img: NdarrayOrTensor, randomize: bool = True, device: Optional[torch.device] = None
) -> NdarrayOrTensor:
    """
    Deform ``img`` by resampling it on a grid displaced by the smooth random field.

    Args:
        img: input image (``np.ndarray`` or ``torch.Tensor``). A batch dimension is added
            before calling ``grid_sample``, so the image is expected without one
            (i.e. channel-first ``(C, *spatial)``).
        randomize: if ``True``, call ``self.randomize()`` first to draw a new field and a new
            apply/skip decision; otherwise the previous ones are reused.
        device: device on which the resampling is performed; defaults to ``self.device``.

    Returns:
        The deformed image converted back to the type of the (tensor-converted) input, or the
        tensor-converted input itself when the transform is skipped.
    """
    img = convert_to_tensor(img, track_meta=get_track_meta())
    if randomize:
        self.randomize()

    if not self._do_transform:
        return img

    device = device if device is not None else self.device

    field = self.sfield()

    dgrid = self.grid + field.to(self.grid_dtype)
    dgrid = moveaxis(dgrid, 1, -1)  # type: ignore

    # add the batch dimension expected by `grid_sample`
    img_t = convert_to_tensor(img[None], torch.float32, device)

    # The sampling grid is derived from `self.grid` and the field, whose device is not
    # controlled by the `device` argument, while `img_t` was just placed on `device`.
    # `grid_sample` needs input and grid on the same device, so the grid follows the image.
    dgrid = dgrid.to(img_t.device)

    out = grid_sample(
        input=img_t,
        grid=dgrid,
        mode=look_up_option(self.grid_mode, GridSampleMode),
        align_corners=self.grid_align_corners,
        padding_mode=look_up_option(self.grid_padding_mode, GridSamplePadMode),
    )

    # drop the batch dimension and restore the input's type
    out_t, *_ = convert_to_dst_type(out.squeeze(0), img)

    return out_t
def __call__(self, kspace: NdarrayOrTensor) -> Sequence[Tensor]:
    """
    Args:
        kspace: The input k-space data. The shape is (...,num_coils,H,W,2)
            for complex 2D inputs and (...,num_coils,H,W,D) for real 3D
            data. The last spatial dim is selected for sampling. For the
            fastMRI dataset, k-space has the form
            (...,num_slices,num_coils,H,W) and sampling is done along W.
            For a general 3D data with the shape (...,num_coils,H,W,D),
            sampling is done along D.

    Returns:
        A tuple containing
            (1) the under-sampled kspace
            (2) the absolute value of the inverse Fourier transform of the
                under-sampled kspace, combined over coils by a root sum of squares
                (the coil dimension is assumed to directly precede the spatial dimensions)

    Note:
        The generated sampling mask is stored in ``self.mask``.
    """
    kspace_t = convert_to_tensor_complex(kspace)
    spatial_size = kspace_t.shape
    num_cols = spatial_size[-1]
    if self.is_complex:  # for complex data the last dim holds the stacked real/imaginary parts
        num_cols = spatial_size[-2]

    center_fraction, acceleration = self.randomize_choose_acceleration()

    # Create the mask: a fully-sampled block of low-frequency columns in the middle plus
    # randomly kept outer columns. The keep-probability is chosen so that, in expectation,
    # num_cols / acceleration columns are sampled in total.
    num_low_freqs = int(round(num_cols * center_fraction))
    num_high_freqs = num_cols - num_low_freqs
    # when the center block covers every column there are no outer columns left to draw,
    # and the probability formula would divide by zero
    prob = (num_cols / acceleration - num_low_freqs) / num_high_freqs if num_high_freqs != 0 else 0.0
    # the random draw is always performed so the RNG stream is consumed identically in all cases
    mask = self.R.uniform(size=num_cols) < prob
    pad = (num_cols - num_low_freqs + 1) // 2
    mask[pad : pad + num_low_freqs] = True

    # Reshape the mask so it broadcasts against kspace_t along the sampled dimension only
    mask_shape = [1 for _ in spatial_size]
    if self.is_complex:
        mask_shape[-2] = num_cols
    else:
        mask_shape[-1] = num_cols

    mask = convert_to_tensor(mask.reshape(*mask_shape).astype(np.float32))

    # under-sample the kspace
    masked = mask * kspace_t
    masked_kspace: Tensor = convert_to_tensor(masked)

    self.mask = mask

    # compute inverse fourier of the masked kspace
    masked_kspace_ifft: Tensor = convert_to_tensor(
        complex_abs(ifftn_centered(masked_kspace, spatial_dims=self.spatial_dims, is_complex=self.is_complex))
    )
    # combine coil images (it is assumed that the coil dimension is
    # the first dimension before spatial dimensions)
    masked_kspace_ifft_rss: Tensor = convert_to_tensor(
        root_sum_of_squares(masked_kspace_ifft, spatial_dim=-self.spatial_dims - 1)
    )
    return masked_kspace, masked_kspace_ifft_rss
def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:
    """
    Change the contrast of ``img`` by raising its normalized intensities to the power of a
    smooth random field.

    The image is converted to a tensor first. If ``randomize`` is ``True`` a new smooth field
    (and apply/skip decision) is drawn via ``self.randomize()``; otherwise the previous one is
    reused. When the transform is skipped, the tensor-converted image is returned unchanged.
    """
    img = convert_to_tensor(img, track_meta=get_track_meta())
    if randomize:
        self.randomize()
    if not self._do_transform:
        return img

    lo = img.min()
    hi = img.max()
    span = hi - lo

    # bring the field to the same container type as the image before combining them
    exponent, *_ = convert_to_dst_type(self.sfield(), img)

    # normalize to [0, 1]; the small epsilon avoids division by zero for constant images
    unit = (img - lo) / (span + 1e-10)
    # the field values act as exponents, which is what alters the contrast
    adjusted = unit**exponent
    # map back onto the original intensity range
    return adjusted * span + lo
def __call__(self, data: Mapping[Hashable, Tensor]) -> Dict[Hashable, Tensor]:
    """
    Center-crop every selected key to the spatial size of ``data[self.ref_key]``.

    Works for ND spatial, channel-first data, and also for pseudo-ND data (e.g. (C,H,W) is a
    pseudo-3D data point where C is the number of slices).

    Args:
        data: is a dictionary containing (key,value) pairs from the loaded dataset

    Returns:
        A new dictionary (shallow copy of ``data``) in which each key yielded by
        ``self.key_iterator`` holds its cropped value as a tensor.
    """
    out = dict(data)
    # the reference's shape minus its first (non-spatial, e.g. channel) dimension
    target_size = out[self.ref_key].shape[1:]
    for name in self.key_iterator(out):
        # crop around the middle of this item's own spatial extent
        center = tuple(extent // 2 for extent in out[name].shape[1:])
        crop = SpatialCrop(roi_center=center, roi_size=target_size)
        out[name] = convert_to_tensor(crop(out[name]))
    return out
def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:
    """
    Scale the intensities of ``img`` element-wise by a smooth random field.

    The image is converted to a tensor first. If ``randomize`` is ``True`` a new smooth field
    (and apply/skip decision) is drawn via ``self.randomize()``; otherwise the previous one is
    reused. When the transform is skipped, the tensor-converted image is returned unchanged.
    """
    img = convert_to_tensor(img, track_meta=get_track_meta())
    if randomize:
        self.randomize()
    if not self._do_transform:
        return img

    # match the field's container type to the image, then multiply
    scale, *_ = convert_to_dst_type(self.sfield(), img)
    return img * scale
def set_array(self, src, non_blocking=False, *_args, **_kwargs):
    """
    Copy the elements of ``src`` into this tensor and return ``self``.

    ``src`` is first converted to a plain tensor (metadata not tracked; sequences wrapped).
    The copy uses ``Tensor.copy_``, so ``src`` may differ in dtype or device but must be
    broadcastable to this tensor. If ``copy_`` raises a ``RuntimeError``, the underlying
    ``data`` is replaced by ``src`` outright, so no shape check applies in that case.
    See also: `https://pytorch.org/docs/stable/generated/torch.Tensor.copy_.html`

    Args:
        src: the source data to copy from.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur
            asynchronously with respect to the host. For other cases, this argument has no effect.
        _args: currently unused parameters.
        _kwargs: currently unused parameters.
    """
    src_tensor: torch.Tensor = convert_to_tensor(src, track_meta=False, wrap_sequence=True)
    try:
        return self.copy_(src_tensor, non_blocking=non_blocking)
    except RuntimeError:
        # `copy_` rejected the source (e.g. shapes not broadcastable): swap the storage instead
        self.data = src_tensor
    return self
def convert_to_tensor_complex(
    data,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    wrap_sequence: bool = True,
    track_meta: bool = False,
) -> Tensor:
    """
    Convert complex-valued data to a 2-channel PyTorch tensor.
    The real and imaginary parts are stacked along the last dimension.
    This function relies on 'monai.utils.type_conversion.convert_to_tensor'

    Args:
        data: input data can be PyTorch Tensor, numpy array, list, tuple, int, float and complex.
            Data that is not complex is forwarded unchanged to ``convert_to_tensor``.
            Complex data has its real and imaginary parts stacked along a new last dimension
            before conversion; a complex scalar therefore becomes a length-2 tensor.
        dtype: target data type when converting to Tensor.
        device: target device to put the converted Tensor data.
        wrap_sequence: passed to ``convert_to_tensor``. If `False`, lists are converted
            element by element, e.g. `[1, 2]` -> `[tensor(1), tensor(2)]`.
            If `True`, then `[1, 2]` -> `tensor([1, 2])`.
        track_meta: whether to track the meta information, if `True`, will convert to `MetaTensor`.
            default to `False`.

    Returns:
        PyTorch version of the data

    Example:
        .. code-block:: python

            import numpy as np
            data = np.array([ [1+1j, 1-1j], [2+2j, 2-2j] ])
            # the following line prints (2,2)
            print(data.shape)
            # the following line prints torch.Size([2, 2, 2])
            print(convert_to_tensor_complex(data).shape)
    """
    # if data is not complex, just turn it into a tensor
    if isinstance(data, Tensor):
        data_is_complex = torch.is_complex(data)
    else:
        data_is_complex = np.iscomplexobj(data)
    if not data_is_complex:
        converted_data: Tensor = convert_to_tensor(
            data, dtype=dtype, device=device, wrap_sequence=wrap_sequence, track_meta=track_meta
        )
        return converted_data

    # if data is complex, turn its stacked version into a tensor
    if isinstance(data, torch.Tensor):
        data = torch.stack([data.real, data.imag], dim=-1)
    elif isinstance(data, np.ndarray):
        if re.search(r"[SaUO]", data.dtype.str) is None:
            # numpy array with 0 dims is also sequence iterable,
            # `ascontiguousarray` will add 1 dim if img has no dim, so we only apply on data with dims
            if data.ndim > 0:
                data = np.ascontiguousarray(data)
            data = np.stack((data.real, data.imag), axis=-1)
    elif isinstance(data, (complex, np.complexfloating)):
        # Python or NumPy complex scalar. Real scalars never reach this point (they returned
        # above), so this is the only scalar case. Being 0-d, it stacks to shape (2,), exactly
        # like a 0-d complex ndarray does in the branch above.
        data = [float(data.real), float(data.imag)]
    elif isinstance(data, (list, tuple)):
        data = convert_to_numpy(data, wrap_sequence=True)
        data = np.stack((data.real, data.imag), axis=-1).tolist()

    converted_data = convert_to_tensor(
        data, dtype=dtype, device=device, wrap_sequence=wrap_sequence, track_meta=track_meta
    )
    return converted_data
def compute_matches(
    self, boxes: torch.Tensor, anchors: torch.Tensor, num_anchors_per_level: Sequence[int], num_anchors_per_loc: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute matches according to ATSS for a single image
    Adapted from
    (https://github.com/sfzhang15/ATSS/blob/79dfb28bd1/atss_core/modeling/rpn/atss/loss.py#L180-L184)

    For every box, the candidates are the anchors with the smallest center distance, picked
    separately within each pyramid level. A candidate is positive when its similarity to the box
    is at least the mean plus the standard deviation of that box's candidate similarities.
    If ``self.center_in_gt`` is set, the candidate's center must also lie inside the box.
    An anchor that is positive for several boxes is matched to the box with the highest
    similarity.

    Args:
        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
        anchors: anchors to match Mx4 or Mx6, also assumed to be ``StandardMode``.
        num_anchors_per_level: number of anchors per feature pyramid level
        num_anchors_per_loc: number of anchors per position

    Returns:
        - matrix which contains the similarity from each box to each anchor [N, M]
        - vector which contains the matched box index for every anchor [M]; anchors without a
          match are set to ``self.BELOW_LOW_THRESHOLD`` (or to ``-1`` in the corner case where
          at most one candidate is selected per box)

    Note:
        ``StandardMode`` = :class:`~monai.data.box_utils.CornerCornerModeTypeA`,
        also represented as "xyxy" ([xmin, ymin, xmax, ymax]) for 2D
        and "xyzxyz" ([xmin, ymin, zmin, xmax, ymax, zmax]) for 3D.
    """
    num_gt = boxes.shape[0]
    num_anchors = anchors.shape[0]

    distances_, _, anchors_center = boxes_center_distance(boxes, anchors)  # num_boxes x anchors
    distances = convert_to_tensor(distances_)

    # select candidates based on center distance
    candidate_idx_list = []
    start_idx = 0
    for _, apl in enumerate(num_anchors_per_level):
        # anchors of this level occupy columns [start_idx, end_idx) of `distances`
        end_idx = start_idx + apl * num_anchors_per_loc

        # topk: total number of candidates per position
        # NOTE(review): the slice spans `apl * num_anchors_per_loc` columns, yet `topk` is capped
        # at `apl`; confirm whether `num_anchors_per_level` counts positions or anchors.
        topk = min(self.num_candidates * num_anchors_per_loc, apl)
        # torch.topk() does not support float16 cpu, need conversion to float32 or float64
        # largest=False: keep the anchors whose centers are closest to each box
        _, idx = distances[:, start_idx:end_idx].to(COMPUTE_DTYPE).topk(topk, dim=1, largest=False)
        # idx: shape [num_boxes x topk]; shift from level-local to global anchor indices
        candidate_idx_list.append(idx + start_idx)

        start_idx = end_idx
    # [num_boxes x num_candidates] (index of candidate anchors)
    candidate_idx = torch.cat(candidate_idx_list, dim=1)

    match_quality_matrix = self.similarity_fn(boxes, anchors)  # [num_boxes x anchors]
    candidate_ious = match_quality_matrix.gather(1, candidate_idx)  # [num_boxes, n_candidates]

    # corner case, n_candidates<=1 will make iou_std_per_gt NaN
    if candidate_idx.shape[1] <= 1:
        # NOTE(review): every candidate anchor is assigned box index 0 regardless of which box
        # selected it; with more than one box this looks incorrect -- confirm intended behavior.
        # NOTE(review): presumably -1 equals self.BELOW_LOW_THRESHOLD -- verify.
        matches = -1 * torch.ones((num_anchors,), dtype=torch.long, device=boxes.device)
        matches[candidate_idx] = 0
        return match_quality_matrix, matches

    # compute adaptive iou threshold
    iou_mean_per_gt = candidate_ious.mean(dim=1)  # [num_boxes]
    iou_std_per_gt = candidate_ious.std(dim=1)  # [num_boxes]
    iou_thresh_per_gt = iou_mean_per_gt + iou_std_per_gt  # [num_boxes]
    is_pos = candidate_ious >= iou_thresh_per_gt[:, None]  # [num_boxes x n_candidates]
    if self.debug:
        print(f"Anchor matcher threshold: {iou_thresh_per_gt}")

    if self.center_in_gt:  # can discard all candidates in case of very small objects :/
        # center point of selected anchors needs to lie within the ground truth
        # boxes_idx[i, j] == i: the box each candidate belongs to, laid out like candidate_idx
        boxes_idx = (
            torch.arange(num_gt, device=boxes.device, dtype=torch.long)[:, None].expand_as(candidate_idx).contiguous()
        )  # [num_boxes x n_candidates]
        # pairwise (candidate center, owning box) test on the flattened candidate list
        is_in_gt_ = centers_in_boxes(
            anchors_center[candidate_idx.view(-1)], boxes[boxes_idx.view(-1)], eps=self.min_dist
        )
        is_in_gt = convert_to_tensor(is_in_gt_)
        is_pos = is_pos & is_in_gt.view_as(is_pos)  # [num_boxes x n_candidates]

    # in case one anchor is assigned to multiple boxes, use box with highest IoU
    # TODO: think about a better way to do this
    # offset row ng by ng * num_anchors so candidate_idx indexes the flattened
    # [num_boxes * num_anchors] view of match_quality_matrix (note: modifies candidate_idx in place)
    for ng in range(num_gt):
        candidate_idx[ng, :] += ng * num_anchors
    # keep similarities only for positive candidates; everything else becomes -INF
    ious_inf = torch.full_like(match_quality_matrix, -INF).view(-1)
    index = candidate_idx.view(-1)[is_pos.view(-1)]
    ious_inf[index] = match_quality_matrix.view(-1)[index]
    ious_inf = ious_inf.view_as(match_quality_matrix)

    # per anchor, pick the box with the highest remaining similarity
    matched_vals, matches = ious_inf.to(COMPUTE_DTYPE).max(dim=0)
    # anchors that are positive for no box are marked as unmatched
    matches[matched_vals == -INF] = self.BELOW_LOW_THRESHOLD
    # print(f"Num matches {(matches >= 0).sum()}, Adapt IoU {iou_thresh_per_gt}")
    return match_quality_matrix, matches