def random_crop_generator3d(
    batch_size: int,
    input_size: Tuple[int, int, int],
    size: Union[Tuple[int, int, int], torch.Tensor],
    resize_to: Optional[Tuple[int, int, int]] = None,
    same_on_batch: bool = False,
    device: torch.device = torch.device('cpu'),
    dtype: torch.dtype = torch.float32,
) -> Dict[str, torch.Tensor]:
    r"""Get parameters for the 3D ``crop`` transformation.

    Args:
        batch_size (int): the tensor batch size.
        input_size (tuple): Input image shape, like (d, h, w).
        size (tuple): Desired size of the crop operation, like (d, h, w).
            If tensor, it must be (B, 3).
        resize_to (tuple): Desired output size of the crop, like (d, h, w).
            If None, no resize will be performed.
        same_on_batch (bool): apply the same transformation across the batch.
            Default: False.
        device (torch.device): the device on which the random numbers will be
            generated. Default: cpu.
        dtype (torch.dtype): the data type of the generated random numbers.
            Default: float32.

    Returns:
        params Dict[str, torch.Tensor]: parameters to be passed for transformation.
            - src (torch.Tensor): cropping bounding boxes with a shape of (B, 8, 3).
            - dst (torch.Tensor): output bounding boxes with a shape (B, 8, 3).

    Raises:
        AssertionError: if ``size`` does not end up shaped (B, 3), or if
            ``input_size`` / ``resize_to`` is not a tuple of 3 positive integers.
        ValueError: if the crop size exceeds ``input_size`` in any dimension.

    Note:
        The generated random numbers are not reproducible across different
        devices and dtypes.
    """
    # Keep a handle on the caller's tensor (if any) before `size` is rebound:
    # its device/dtype decide the device/dtype of the returned boxes.
    size_ref = size if isinstance(size, torch.Tensor) else None

    if not isinstance(size, torch.Tensor):
        size = torch.tensor(size, device=device, dtype=dtype).repeat(batch_size, 1)
    else:
        size = size.to(device=device, dtype=dtype)
    if size.shape != torch.Size([batch_size, 3]):
        raise AssertionError(
            "If `size` is a tensor, it must be shaped as (B, 3). "
            f"Got {size.shape} while expecting {torch.Size([batch_size, 3])}."
        )
    if not (
        len(input_size) == 3
        and all(isinstance(dim, int) and dim > 0 for dim in input_size)
    ):
        raise AssertionError(
            f"`input_size` must be a tuple of 3 positive integers. Got {input_size}."
        )

    # Number of valid start offsets per axis. A crop that fits exactly leaves
    # one valid offset (0), so anything below 1 means the crop is too large.
    x_diff = input_size[2] - size[:, 2] + 1
    y_diff = input_size[1] - size[:, 1] + 1
    z_diff = input_size[0] - size[:, 0] + 1
    if (x_diff < 1).any() or (y_diff < 1).any() or (z_diff < 1).any():
        raise ValueError(
            f"input_size {str(input_size)} cannot be smaller than crop size {str(size)} in any dimension."
        )

    _device, _dtype = _extract_device_dtype([size_ref])

    if batch_size == 0:
        return dict(
            src=torch.zeros([0, 8, 3], device=_device, dtype=_dtype),
            dst=torch.zeros([0, 8, 3], device=_device, dtype=_dtype),
        )

    if same_on_batch:
        # If same_on_batch, select the first then repeat.
        x_start = _adapted_uniform((batch_size, ), 0, x_diff[0], same_on_batch).floor()
        y_start = _adapted_uniform((batch_size, ), 0, y_diff[0], same_on_batch).floor()
        z_start = _adapted_uniform((batch_size, ), 0, z_diff[0], same_on_batch).floor()
    else:
        # The per-sample upper bounds are passed as a tensor; the result is
        # flattened with view(-1) below.
        x_start = _adapted_uniform((1, ), 0, x_diff, same_on_batch).floor()
        y_start = _adapted_uniform((1, ), 0, y_diff, same_on_batch).floor()
        z_start = _adapted_uniform((1, ), 0, z_diff, same_on_batch).floor()

    # Box extents are passed as (size - 1), matching the `resize_to` corners below.
    extent = size.to(device=_device, dtype=_dtype) - 1

    crop_src = bbox_generator3d(
        x_start.to(device=_device, dtype=_dtype).view(-1),
        y_start.to(device=_device, dtype=_dtype).view(-1),
        z_start.to(device=_device, dtype=_dtype).view(-1),
        extent[:, 2],
        extent[:, 1],
        extent[:, 0],
    )

    if resize_to is None:
        # No resize: the destination box is the crop itself anchored at the origin.
        crop_dst = bbox_generator3d(
            torch.zeros(batch_size, device=_device, dtype=_dtype),
            torch.zeros(batch_size, device=_device, dtype=_dtype),
            torch.zeros(batch_size, device=_device, dtype=_dtype),
            extent[:, 2],
            extent[:, 1],
            extent[:, 0],
        )
    else:
        if not (
            len(resize_to) == 3
            and all(isinstance(dim, int) and dim > 0 for dim in resize_to)
        ):
            raise AssertionError(
                f"`resize_to` must be a tuple of 3 positive integers. Got {resize_to}."
            )
        x_max = resize_to[-1] - 1
        y_max = resize_to[-2] - 1
        z_max = resize_to[-3] - 1
        # Eight (x, y, z) corners: the z == 0 face first, then the z == z_max face.
        crop_dst = torch.tensor(
            [[
                [0, 0, 0],
                [x_max, 0, 0],
                [x_max, y_max, 0],
                [0, y_max, 0],
                [0, 0, z_max],
                [x_max, 0, z_max],
                [x_max, y_max, z_max],
                [0, y_max, z_max],
            ]],
            device=_device,
            dtype=_dtype,
        ).repeat(batch_size, 1, 1)

    return dict(src=crop_src.to(device=_device), dst=crop_dst.to(device=_device))
def forward(
    self, batch_shape: torch.Size, same_on_batch: bool = False
) -> Dict[str, torch.Tensor]:  # type:ignore
    """Sample 3D crop boxes for a batch of the given shape.

    Args:
        batch_shape: shape unpacked as (batch_size, _, depth, height, width);
            the second entry is ignored.
        same_on_batch: forwarded to the parameter check and to the sampling
            helper. Default: False.

    Returns:
        Dict with two entries:
            - src (torch.Tensor): cropping bounding boxes, (B, 8, 3).
            - dst (torch.Tensor): output bounding boxes, (B, 8, 3).

    Raises:
        AssertionError: if ``self.size`` does not end up shaped (B, 3), if
            depth/height/width are not positive integers, or if
            ``self.resize_to`` is not a tuple of 3 positive integers.
        ValueError: if the crop size exceeds the input in any dimension.
    """
    batch_size, _, depth, height, width = batch_shape
    _common_param_check(batch_size, same_on_batch)
    _device, _dtype = _extract_device_dtype(
        [self.size if isinstance(self.size, torch.Tensor) else None])

    if not isinstance(self.size, torch.Tensor):
        size = torch.tensor(self.size, device=_device, dtype=_dtype).repeat(batch_size, 1)
    else:
        size = self.size.to(device=_device, dtype=_dtype)
    if size.shape != torch.Size([batch_size, 3]):
        raise AssertionError(
            "If `size` is a tensor, it must be shaped as (B, 3). "
            f"Got {size.shape} while expecting {torch.Size([batch_size, 3])}."
        )
    if not all(isinstance(dim, int) and dim > 0 for dim in (depth, height, width)):
        raise AssertionError(
            f"`batch_shape` should not contain negative values. Got {(batch_shape)}."
        )

    # Number of valid start offsets per axis. A crop that fits exactly leaves
    # one valid offset (0), so anything below 1 means the crop is too large.
    x_diff = width - size[:, 2] + 1
    y_diff = height - size[:, 1] + 1
    z_diff = depth - size[:, 0] + 1
    if (x_diff < 1).any() or (y_diff < 1).any() or (z_diff < 1).any():
        raise ValueError(
            f"input_size {(depth, height, width)} cannot be smaller than crop size {str(size)} in any dimension."
        )

    if batch_size == 0:
        return dict(
            src=torch.zeros([0, 8, 3], device=_device, dtype=_dtype),
            dst=torch.zeros([0, 8, 3], device=_device, dtype=_dtype),
        )

    # NOTE(review): presumably `self.rand_sampler` draws from [0, 1) so that
    # scaling by the offset count and flooring yields a valid start -- confirm.
    x_start = _adapted_rsampling(
        (batch_size, ), self.rand_sampler, same_on_batch).to(device=_device, dtype=_dtype)
    y_start = _adapted_rsampling(
        (batch_size, ), self.rand_sampler, same_on_batch).to(device=_device, dtype=_dtype)
    z_start = _adapted_rsampling(
        (batch_size, ), self.rand_sampler, same_on_batch).to(device=_device, dtype=_dtype)

    x_start = (x_start * x_diff).floor()
    y_start = (y_start * y_diff).floor()
    z_start = (z_start * z_diff).floor()

    # Box extents are passed as (size - 1), matching the `resize_to` corners below.
    extent = size - 1

    crop_src = bbox_generator3d(
        x_start.view(-1),
        y_start.view(-1),
        z_start.view(-1),
        extent[:, 2],
        extent[:, 1],
        extent[:, 0],
    )

    if self.resize_to is None:
        # No resize: the destination box is the crop itself anchored at the origin.
        crop_dst = bbox_generator3d(
            torch.zeros(batch_size, device=_device, dtype=_dtype),
            torch.zeros(batch_size, device=_device, dtype=_dtype),
            torch.zeros(batch_size, device=_device, dtype=_dtype),
            extent[:, 2],
            extent[:, 1],
            extent[:, 0],
        )
    else:
        if not (
            len(self.resize_to) == 3
            and all(isinstance(dim, int) and dim > 0 for dim in self.resize_to)
        ):
            raise AssertionError(
                f"`resize_to` must be a tuple of 3 positive integers. Got {self.resize_to}."
            )
        x_max = self.resize_to[-1] - 1
        y_max = self.resize_to[-2] - 1
        z_max = self.resize_to[-3] - 1
        # Eight (x, y, z) corners: the z == 0 face first, then the z == z_max face.
        crop_dst = torch.tensor(
            [[
                [0, 0, 0],
                [x_max, 0, 0],
                [x_max, y_max, 0],
                [0, y_max, 0],
                [0, 0, z_max],
                [x_max, 0, z_max],
                [x_max, y_max, z_max],
                [0, y_max, z_max],
            ]],
            device=_device,
            dtype=_dtype,
        ).repeat(batch_size, 1, 1)

    return dict(src=crop_src.to(device=_device), dst=crop_dst.to(device=_device))