def random_crop_generator3d(
    batch_size: int,
    input_size: Tuple[int, int, int],
    size: Union[Tuple[int, int, int], torch.Tensor],
    resize_to: Optional[Tuple[int, int, int]] = None,
    same_on_batch: bool = False,
    device: torch.device = torch.device('cpu'),
    dtype: torch.dtype = torch.float32,
) -> Dict[str, torch.Tensor]:
    r"""Get parameters for ``crop`` transformation for crop transform.

    Args:
        batch_size (int): the tensor batch size.
        input_size (tuple): Input image shape, like (d, h, w).
        size (tuple): Desired size of the crop operation, like (d, h, w).
            If tensor, it must be (B, 3).
        resize_to (tuple): Desired output size of the crop, like (d, h, w). If None, no resize will be performed.
        same_on_batch (bool): apply the same transformation across the batch. Default: False.
        device (torch.device): the device on which the random numbers will be generated. Default: cpu.
        dtype (torch.dtype): the data type of the generated random numbers. Default: float32.

    Returns:
        params Dict[str, torch.Tensor]: parameters to be passed for transformation.
            - src (torch.Tensor): cropping bounding boxes with a shape of (B, 8, 3).
            - dst (torch.Tensor): output bounding boxes with a shape (B, 8, 3).

    Raises:
        AssertionError: if ``size`` is not shaped (B, 3), or if ``input_size`` /
            ``resize_to`` is not a tuple of 3 positive integers.
        ValueError: if the crop size exceeds ``input_size`` in any dimension.

    Note:
        The generated random numbers are not reproducible across different devices and dtypes.
    """
    # Output device/dtype follow ``size`` when it is a tensor; for a tuple ``size``
    # they are whatever ``_extract_device_dtype`` resolves for a ``None`` entry.
    _device, _dtype = _extract_device_dtype(
        [size if isinstance(size, torch.Tensor) else None])
    if isinstance(size, torch.Tensor):
        size = size.to(device=device, dtype=dtype)
    else:
        # Broadcast a single (d, h, w) crop size to every batch element.
        size = torch.tensor(size, device=device,
                            dtype=dtype).repeat(batch_size, 1)
    if size.shape != torch.Size([batch_size, 3]):
        raise AssertionError(
            "If `size` is a tensor, it must be shaped as (B, 3). "
            f"Got {size.shape} while expecting {torch.Size([batch_size, 3])}.")
    if not (len(input_size) == 3
            and all(isinstance(v, int) and v > 0 for v in input_size)):
        raise AssertionError(
            f"`input_size` must be a tuple of 3 positive integers. Got {input_size}."
        )

    # Number of valid integer start positions along each axis; ``size`` columns
    # are (d, h, w), so x <-> w, y <-> h, z <-> d.
    x_diff = input_size[2] - size[:, 2] + 1
    y_diff = input_size[1] - size[:, 1] + 1
    z_diff = input_size[0] - size[:, 0] + 1

    # At least one start position is required on every axis. A ``< 0`` test would
    # wrongly accept a crop one voxel larger than the input (diff == 0).
    if (x_diff < 1).any() or (y_diff < 1).any() or (z_diff < 1).any():
        raise ValueError(
            f"input_size {str(input_size)} cannot be smaller than crop size {str(size)} in any dimension."
        )

    if batch_size == 0:
        return dict(
            src=torch.zeros([0, 8, 3], device=_device, dtype=_dtype),
            dst=torch.zeros([0, 8, 3], device=_device, dtype=_dtype),
        )

    # Validate ``resize_to`` before drawing random numbers so an invalid argument
    # does not consume RNG state.
    if resize_to is not None and not (
            len(resize_to) == 3
            and all(isinstance(v, int) and v > 0 for v in resize_to)):
        raise AssertionError(
            f"`resize_to` must be a tuple of 3 positive integers. Got {resize_to}."
        )

    if same_on_batch:
        # If same_on_batch, select the first element's range then repeat.
        x_start = _adapted_uniform((batch_size, ), 0, x_diff[0],
                                   same_on_batch).floor()
        y_start = _adapted_uniform((batch_size, ), 0, y_diff[0],
                                   same_on_batch).floor()
        z_start = _adapted_uniform((batch_size, ), 0, z_diff[0],
                                   same_on_batch).floor()
    else:
        # ``high`` is a per-element (B,) tensor; the result is flattened below.
        x_start = _adapted_uniform((1, ), 0, x_diff, same_on_batch).floor()
        y_start = _adapted_uniform((1, ), 0, y_diff, same_on_batch).floor()
        z_start = _adapted_uniform((1, ), 0, z_diff, same_on_batch).floor()

    # Convert the crop size to the output device/dtype once instead of per column.
    size_out = size.to(device=_device, dtype=_dtype)
    crop_src = bbox_generator3d(
        x_start.to(device=_device, dtype=_dtype).view(-1),
        y_start.to(device=_device, dtype=_dtype).view(-1),
        z_start.to(device=_device, dtype=_dtype).view(-1),
        size_out[:, 2] - 1,
        size_out[:, 1] - 1,
        size_out[:, 0] - 1,
    )

    if resize_to is None:
        # Without resizing, the destination box is the crop anchored at the origin.
        crop_dst = bbox_generator3d(
            torch.zeros(batch_size, device=_device, dtype=_dtype),
            torch.zeros(batch_size, device=_device, dtype=_dtype),
            torch.zeros(batch_size, device=_device, dtype=_dtype),
            size_out[:, 2] - 1,
            size_out[:, 1] - 1,
            size_out[:, 0] - 1,
        )
    else:
        # Inclusive max coordinates of the (d, h, w) output volume.
        w_max, h_max, d_max = resize_to[2] - 1, resize_to[1] - 1, resize_to[0] - 1
        # The 8 corners are listed as (x, y, z) rows, x being the width axis.
        crop_dst = torch.tensor(
            [[
                [0, 0, 0],
                [w_max, 0, 0],
                [w_max, h_max, 0],
                [0, h_max, 0],
                [0, 0, d_max],
                [w_max, 0, d_max],
                [w_max, h_max, d_max],
                [0, h_max, d_max],
            ]],
            device=_device,
            dtype=_dtype,
        ).repeat(batch_size, 1, 1)

    return dict(src=crop_src.to(device=_device),
                dst=crop_dst.to(device=_device))
# Example 2
    def forward(
            self,
            batch_shape: torch.Size,
            same_on_batch: bool = False
    ) -> Dict[str, torch.Tensor]:  # type:ignore
        """Sample random 3D crop boxes for a batch.

        Args:
            batch_shape: shape of the input batch, unpacked as ``(B, C, D, H, W)``.
            same_on_batch: apply the same transformation across the batch.

        Returns:
            dict with ``src`` (crop boxes in the input volume) and ``dst``
            (boxes in the output volume), each presumably shaped (B, 8, 3),
            matching the empty-batch return below.

        Raises:
            AssertionError: if ``self.size`` is not shaped (B, 3), if D/H/W are not
                positive ints, or if ``self.resize_to`` is not 3 positive ints.
            ValueError: if the crop size exceeds (D, H, W) in any dimension.
        """
        batch_size, _, depth, height, width = batch_shape
        _common_param_check(batch_size, same_on_batch)
        _device, _dtype = _extract_device_dtype(
            [self.size if isinstance(self.size, torch.Tensor) else None])

        if isinstance(self.size, torch.Tensor):
            size = self.size.to(device=_device, dtype=_dtype)
        else:
            # Broadcast a single (d, h, w) crop size to every batch element.
            size = torch.tensor(self.size, device=_device,
                                dtype=_dtype).repeat(batch_size, 1)
        if size.shape != torch.Size([batch_size, 3]):
            raise AssertionError(
                "If `size` is a tensor, it must be shaped as (B, 3). "
                f"Got {size.shape} while expecting {torch.Size([batch_size, 3])}."
            )
        if not all(isinstance(v, int) and v > 0
                   for v in (depth, height, width)):
            raise AssertionError(
                f"`batch_shape` should not contain negative values. Got {(batch_shape)}."
            )

        # Number of valid integer start positions along each axis.
        x_diff = width - size[:, 2] + 1
        y_diff = height - size[:, 1] + 1
        z_diff = depth - size[:, 0] + 1

        # At least one start position is required on every axis. A ``< 0`` test
        # would accept a crop one voxel larger than the volume (diff == 0), which
        # then yields an out-of-bounds box starting at 0.
        if (x_diff < 1).any() or (y_diff < 1).any() or (z_diff < 1).any():
            raise ValueError(
                f"input_size {(depth, height, width)} cannot be smaller than crop size {str(size)} in any dimension."
            )

        if batch_size == 0:
            return dict(
                src=torch.zeros([0, 8, 3], device=_device, dtype=_dtype),
                dst=torch.zeros([0, 8, 3], device=_device, dtype=_dtype),
            )

        resize_to = self.resize_to
        # Validate ``resize_to`` before sampling so a bad value does not consume RNG state.
        if resize_to is not None and not (
                len(resize_to) == 3
                and all(isinstance(v, int) and v > 0 for v in resize_to)):
            raise AssertionError(
                f"`resize_to` must be a tuple of 3 positive integers. Got {self.resize_to}."
            )

        # NOTE(review): ``self.rand_sampler`` is assumed to draw values in [0, 1);
        # scaling by ``*_diff`` and flooring then gives integer start offsets.
        x_start = _adapted_rsampling((batch_size, ), self.rand_sampler,
                                     same_on_batch).to(device=_device,
                                                       dtype=_dtype)
        y_start = _adapted_rsampling((batch_size, ), self.rand_sampler,
                                     same_on_batch).to(device=_device,
                                                       dtype=_dtype)
        z_start = _adapted_rsampling((batch_size, ), self.rand_sampler,
                                     same_on_batch).to(device=_device,
                                                       dtype=_dtype)

        x_start = (x_start * x_diff).floor()
        y_start = (y_start * y_diff).floor()
        z_start = (z_start * z_diff).floor()

        crop_src = bbox_generator3d(x_start.view(-1), y_start.view(-1),
                                    z_start.view(-1), size[:, 2] - 1,
                                    size[:, 1] - 1, size[:, 0] - 1)

        if resize_to is None:
            # Without resizing, the destination box is the crop anchored at the origin.
            crop_dst = bbox_generator3d(
                torch.zeros(batch_size, device=_device, dtype=_dtype),
                torch.zeros(batch_size, device=_device, dtype=_dtype),
                torch.zeros(batch_size, device=_device, dtype=_dtype),
                size[:, 2] - 1,
                size[:, 1] - 1,
                size[:, 0] - 1,
            )
        else:
            # Inclusive max coordinates of the (d, h, w) output volume.
            w_max, h_max, d_max = resize_to[2] - 1, resize_to[1] - 1, resize_to[0] - 1
            # The 8 corners are listed as (x, y, z) rows, x being the width axis.
            crop_dst = torch.tensor(
                [[
                    [0, 0, 0],
                    [w_max, 0, 0],
                    [w_max, h_max, 0],
                    [0, h_max, 0],
                    [0, 0, d_max],
                    [w_max, 0, d_max],
                    [w_max, h_max, d_max],
                    [0, h_max, d_max],
                ]],
                device=_device,
                dtype=_dtype,
            ).repeat(batch_size, 1, 1)

        return dict(src=crop_src.to(device=_device),
                    dst=crop_dst.to(device=_device))