def random_crop_generator3d(
    batch_size: int,
    input_size: Tuple[int, int, int],
    size: Union[Tuple[int, int, int], torch.Tensor],
    resize_to: Optional[Tuple[int, int, int]] = None,
    same_on_batch: bool = False,
    device: torch.device = torch.device('cpu'),
    dtype: torch.dtype = torch.float32
) -> Dict[str, torch.Tensor]:
    r"""Get parameters for ```crop``` transformation for crop transform.

    Args:
        batch_size (int): the tensor batch size.
        input_size (tuple): Input image shape, like (d, h, w). Must be three positive integers.
        size (tuple): Desired size of the crop operation, like (d, h, w).
            If tensor, it must be (B, 3).
        resize_to (tuple): Desired output size of the crop, like (d, h, w). If None, no resize will be performed
            and the destination box has the same extent as the crop.
        same_on_batch (bool): apply the same transformation across the batch. Default: False.
        device (torch.device): the device on which ``size`` (and therefore the sampling bounds derived
            from it) is placed. Default: cpu.
        dtype (torch.dtype): the data type used for ``size`` and the derived sampling bounds. Default: float32.

    Returns:
        params Dict[str, torch.Tensor]: parameters to be passed for transformation.
            - src (torch.Tensor): cropping bounding boxes with a shape of (B, 8, 3).
            - dst (torch.Tensor): output bounding boxes with a shape (B, 8, 3).

    Raises:
        AssertionError: if ``size`` is not shaped (B, 3), or ``input_size`` / ``resize_to`` is not a
            triple of positive integers.
        ValueError: if the crop size exceeds ``input_size`` along any dimension.

    Note:
        The generated random numbers are not reproducible across different devices and dtypes.
    """
    # Remember where a tensor-valued `size` came from before it is moved to `device`/`dtype`;
    # the results are moved back onto `_device` at the end.
    _device, _dtype = _extract_device_dtype([size if isinstance(size, torch.Tensor) else None])
    if not isinstance(size, torch.Tensor):
        size = torch.tensor(size, device=device, dtype=dtype).repeat(batch_size, 1)
    else:
        size = size.to(device=device, dtype=dtype)
    assert size.shape == torch.Size([batch_size, 3]), (
        "If `size` is a tensor, it must be shaped as (B, 3). "
        f"Got {size.shape} while expecting {torch.Size([batch_size, 3])}.")
    assert len(input_size) == 3 and all(isinstance(v, (int,)) and v > 0 for v in input_size), \
        f"`input_size` must be a tuple of 3 positive integers. Got {input_size}."

    # Number of admissible start offsets along each axis: input - crop + 1.
    x_diff = input_size[2] - size[:, 2] + 1
    y_diff = input_size[1] - size[:, 1] + 1
    z_diff = input_size[0] - size[:, 0] + 1

    # A crop fits only if at least one start offset exists (diff >= 1). diff == 0 means the crop is
    # one voxel larger than the input, so the bound must be `< 1`, not `< 0`.
    if (x_diff < 1).any() or (y_diff < 1).any() or (z_diff < 1).any():
        raise ValueError(
            "input_size %s cannot be smaller than crop size %s in any dimension." % (str(input_size), str(size)))

    if batch_size == 0:
        # Nothing to sample for an empty batch; return empty boxes of the documented shape.
        return dict(
            src=torch.zeros([0, 8, 3], device=_device, dtype=_dtype),
            dst=torch.zeros([0, 8, 3], device=_device, dtype=_dtype),
        )

    # NOTE(review): `_adapted_uniform` is assumed to sample from [0, diff) — confirm against its definition.
    if same_on_batch:
        # If same_on_batch, select the first then repeat.
        x_start = _adapted_uniform((batch_size,), 0, x_diff[0], same_on_batch)
        y_start = _adapted_uniform((batch_size,), 0, y_diff[0], same_on_batch)
        z_start = _adapted_uniform((batch_size,), 0, z_diff[0], same_on_batch)
    else:
        x_start = _adapted_uniform((1,), 0, x_diff, same_on_batch)
        y_start = _adapted_uniform((1,), 0, y_diff, same_on_batch)
        z_start = _adapted_uniform((1,), 0, z_diff, same_on_batch)

    # Box extents are passed as (size - 1); the result is truncated to integer coordinates.
    crop_src = bbox_generator3d(
        x_start.view(-1), y_start.view(-1), z_start.view(-1),
        size[:, 2] - 1, size[:, 1] - 1, size[:, 0] - 1).long()

    if resize_to is None:
        # No resize: the destination box starts at the origin and keeps the crop extent.
        crop_dst = bbox_generator3d(
            torch.zeros(batch_size, device=device, dtype=dtype),
            torch.zeros(batch_size, device=device, dtype=dtype),
            torch.zeros(batch_size, device=device, dtype=dtype),
            size[:, 2] - 1, size[:, 1] - 1, size[:, 0] - 1).long()
    else:
        assert len(resize_to) == 3 and all(isinstance(v, (int,)) and v > 0 for v in resize_to), \
            f"`resize_to` must be a tuple of 3 positive integers. Got {resize_to}."
        # `resize_to` is (d, h, w); the box corners are stored as (x, y, z).
        x_max, y_max, z_max = resize_to[-1] - 1, resize_to[-2] - 1, resize_to[-3] - 1
        crop_dst = torch.tensor([[
            [0, 0, 0],
            [x_max, 0, 0],
            [x_max, y_max, 0],
            [0, y_max, 0],
            [0, 0, z_max],
            [x_max, 0, z_max],
            [x_max, y_max, z_max],
            [0, y_max, z_max],
        ]], device=device, dtype=torch.long).repeat(batch_size, 1, 1)

    return dict(src=crop_src.to(device=_device), dst=crop_dst.to(device=_device))
def random_crop_generator3d(
    batch_size: int,
    input_size: Tuple[int, int, int],
    size: Union[Tuple[int, int, int], torch.Tensor],
    resize_to: Optional[Tuple[int, int, int]] = None,
    same_on_batch: bool = False
) -> Dict[str, torch.Tensor]:
    r"""Get parameters for ```crop``` transformation for crop transform.

    Args:
        batch_size (int): the tensor batch size.
        input_size (tuple): Input image shape, like (d, h, w).
        size (tuple): Desired size of the crop operation, like (d, h, w).
            If tensor, it must be (B, 3).
        resize_to (tuple): Desired output size of the crop, like (d, h, w). If None, no resize will be performed
            and the destination box has the same extent as the crop.
        same_on_batch (bool): apply the same transformation across the batch. Default: False.

    Returns:
        params Dict[str, torch.Tensor]: parameters to be passed for transformation.
            - src (torch.Tensor): cropping bounding boxes, as produced by ``bbox_generator3d``.
            - dst (torch.Tensor): output bounding boxes; (B, 8, 3) when ``resize_to`` is given.

    Raises:
        AssertionError: if ``size`` is not shaped (B, 3).
        ValueError: if the crop size exceeds ``input_size`` along any dimension.
    """
    if not isinstance(size, torch.Tensor):
        size = torch.tensor(size).repeat(batch_size, 1)
    assert size.shape == torch.Size([batch_size, 3]), \
        f"If `size` is a tensor, it must be shaped as (B, 3). Got {size.shape}."

    # Number of admissible start offsets along each axis: input - crop + 1.
    x_diff = input_size[2] - size[:, 2] + 1
    y_diff = input_size[1] - size[:, 1] + 1
    z_diff = input_size[0] - size[:, 0] + 1

    # A crop fits only if at least one start offset exists (diff >= 1). diff == 0 means the crop is
    # one voxel larger than the input, so the bound must be `< 1`, not `< 0`.
    if (x_diff < 1).any() or (y_diff < 1).any() or (z_diff < 1).any():
        raise ValueError(
            "input_size %s cannot be smaller than crop size %s in any dimension." % (str(input_size), str(size)))

    # NOTE(review): `_adapted_uniform` is assumed to sample from [0, diff) — confirm against its definition.
    # `.long()` truncates the sampled offsets to integer start coordinates.
    if same_on_batch:
        # If same_on_batch, select the first then repeat.
        x_start = _adapted_uniform((batch_size, ), 0, x_diff[0], same_on_batch).long()
        y_start = _adapted_uniform((batch_size, ), 0, y_diff[0], same_on_batch).long()
        z_start = _adapted_uniform((batch_size, ), 0, z_diff[0], same_on_batch).long()
    else:
        x_start = _adapted_uniform((1, ), 0, x_diff, same_on_batch).long()
        y_start = _adapted_uniform((1, ), 0, y_diff, same_on_batch).long()
        z_start = _adapted_uniform((1, ), 0, z_diff, same_on_batch).long()

    # Box extents are passed as (size - 1).
    crop_src = bbox_generator3d(
        x_start.view(-1), y_start.view(-1), z_start.view(-1),
        size[:, 2] - 1, size[:, 1] - 1, size[:, 0] - 1)

    if resize_to is None:
        # No resize: the destination box starts at the origin and keeps the crop extent.
        crop_dst = bbox_generator3d(
            torch.tensor([0] * batch_size),
            torch.tensor([0] * batch_size),
            torch.tensor([0] * batch_size),
            size[:, 2] - 1, size[:, 1] - 1, size[:, 0] - 1)
    else:
        # `resize_to` is (d, h, w); the box corners are stored as (x, y, z).
        x_max, y_max, z_max = resize_to[-1] - 1, resize_to[-2] - 1, resize_to[-3] - 1
        crop_dst = torch.tensor([[
            [0, 0, 0],
            [x_max, 0, 0],
            [x_max, y_max, 0],
            [0, y_max, 0],
            [0, 0, z_max],
            [x_max, 0, z_max],
            [x_max, y_max, z_max],
            [0, y_max, z_max],
        ]]).repeat(batch_size, 1, 1)

    return dict(src=crop_src, dst=crop_dst)