Example #1
from typing import Dict, Optional, Tuple, Union

import torch

# Note: _extract_device_dtype, _adapted_uniform and bbox_generator3d are
# kornia-internal helpers assumed to be importable alongside this function.


def random_crop_generator3d(
    batch_size: int,
    input_size: Tuple[int, int, int],
    size: Union[Tuple[int, int, int], torch.Tensor],
    resize_to: Optional[Tuple[int, int, int]] = None,
    same_on_batch: bool = False,
    device: torch.device = torch.device('cpu'),
    dtype: torch.dtype = torch.float32
) -> Dict[str, torch.Tensor]:
    r"""Get parameters for ```crop``` transformation for crop transform.

    Args:
        batch_size (int): the tensor batch size.
        input_size (tuple): Input image shape, like (d, h, w).
        size (tuple): Desired size of the crop operation, like (d, h, w).
            If tensor, it must be (B, 3).
        resize_to (tuple): Desired output size of the crop, like (d, h, w). If None, no resize will be performed.
        same_on_batch (bool): apply the same transformation across the batch. Default: False.
        device (torch.device): the device on which the random numbers will be generated. Default: cpu.
        dtype (torch.dtype): the data type of the generated random numbers. Default: float32.

    Returns:
        params Dict[str, torch.Tensor]: parameters to be passed to the transformation.
            - src (torch.Tensor): cropping bounding boxes with a shape of (B, 8, 3).
            - dst (torch.Tensor): output bounding boxes with a shape (B, 8, 3).

    Note:
        The generated random numbers are not reproducible across different devices and dtypes.
    """
    # Infer the output device/dtype from `size` if it was passed as a tensor.
    _device, _dtype = _extract_device_dtype([size if isinstance(size, torch.Tensor) else None])
    if not isinstance(size, torch.Tensor):
        size = torch.tensor(size, device=device, dtype=dtype).repeat(batch_size, 1)
    else:
        size = size.to(device=device, dtype=dtype)
    assert size.shape == torch.Size([batch_size, 3]), (
        "If `size` is a tensor, it must be shaped as (B, 3). "
        f"Got {size.shape} while expecting {torch.Size([batch_size, 3])}.")
    assert len(input_size) == 3 and all(isinstance(s, int) and s > 0 for s in input_size), \
        f"`input_size` must be a tuple of 3 positive integers. Got {input_size}."

    # +1: valid start offsets span [0, input - crop] inclusive.
    x_diff = input_size[2] - size[:, 2] + 1
    y_diff = input_size[1] - size[:, 1] + 1
    z_diff = input_size[0] - size[:, 0] + 1

    # diff <= 0 means the crop exceeds the input along that dimension.
    if (x_diff <= 0).any() or (y_diff <= 0).any() or (z_diff <= 0).any():
        raise ValueError(
            f"input_size {input_size} cannot be smaller than crop size {size} in any dimension.")

    if batch_size == 0:
        # Degenerate batch: return empty corner tensors on the inferred device/dtype.
        return dict(
            src=torch.zeros([0, 8, 3], device=_device, dtype=_dtype),
            dst=torch.zeros([0, 8, 3], device=_device, dtype=_dtype),
        )

    if same_on_batch:
        # Draw one offset from the first sample's valid range; the helper
        # repeats it across the batch when same_on_batch=True.
        x_start = _adapted_uniform((batch_size,), 0, x_diff[0], same_on_batch)
        y_start = _adapted_uniform((batch_size,), 0, y_diff[0], same_on_batch)
        z_start = _adapted_uniform((batch_size,), 0, z_diff[0], same_on_batch)
    else:
        # Broadcasting against the (B,)-shaped diffs yields one independent
        # offset per batch element; view(-1) below flattens (1, B) to (B,).
        x_start = _adapted_uniform((1,), 0, x_diff, same_on_batch)
        y_start = _adapted_uniform((1,), 0, y_diff, same_on_batch)
        z_start = _adapted_uniform((1,), 0, z_diff, same_on_batch)

    # Corner coordinates are inclusive, hence size - 1 for the box extents.
    crop_src = bbox_generator3d(
        x_start.view(-1),
        y_start.view(-1),
        z_start.view(-1),
        size[:, 2] - 1,
        size[:, 1] - 1,
        size[:, 0] - 1).long()

    if resize_to is None:
        # No resize: the destination box is the crop itself, anchored at the origin.
        crop_dst = bbox_generator3d(
            torch.tensor([0] * batch_size, device=device, dtype=dtype),
            torch.tensor([0] * batch_size, device=device, dtype=dtype),
            torch.tensor([0] * batch_size, device=device, dtype=dtype),
            size[:, 2] - 1,
            size[:, 1] - 1,
            size[:, 0] - 1).long()
    else:
        assert len(resize_to) == 3 and all(isinstance(s, int) and s > 0 for s in resize_to), \
            f"`resize_to` must be a tuple of 3 positive integers. Got {resize_to}."
        # Eight box corners in (x, y, z): the z = 0 face first, then the
        # matching z = d - 1 face.
        crop_dst = torch.tensor([[
            [0, 0, 0],
            [resize_to[-1] - 1, 0, 0],
            [resize_to[-1] - 1, resize_to[-2] - 1, 0],
            [0, resize_to[-2] - 1, 0],
            [0, 0, resize_to[-3] - 1],
            [resize_to[-1] - 1, 0, resize_to[-3] - 1],
            [resize_to[-1] - 1, resize_to[-2] - 1, resize_to[-3] - 1],
            [0, resize_to[-2] - 1, resize_to[-3] - 1],
        ]], device=device, dtype=torch.long).repeat(batch_size, 1, 1)

    return dict(src=crop_src.to(device=_device),
                dst=crop_dst.to(device=_device))
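
A minimal usage sketch (not part of the original example; it assumes the function above and its kornia-internal helpers are in scope):

import torch

# Sample crop parameters for a batch of two 8x16x16 volumes, taking 4x8x8
# patches and resizing each patch back to 8x16x16.
params = random_crop_generator3d(
    batch_size=2,
    input_size=(8, 16, 16),   # (d, h, w)
    size=(4, 8, 8),           # (d, h, w)
    resize_to=(8, 16, 16),
)
print(params["src"].shape)  # torch.Size([2, 8, 3]): 8 box corners in (x, y, z)
print(params["dst"].shape)  # torch.Size([2, 8, 3])

The src boxes vary with the RNG state, while the dst boxes are the fixed corners of the resize_to volume.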
Example #2
from typing import Dict, Optional, Tuple, Union

import torch

# As in Example #1, _adapted_uniform and bbox_generator3d are kornia-internal
# helpers assumed to be in scope.


def random_crop_generator3d(
        batch_size: int,
        input_size: Tuple[int, int, int],
        size: Union[Tuple[int, int, int], torch.Tensor],
        resize_to: Optional[Tuple[int, int, int]] = None,
        same_on_batch: bool = False) -> Dict[str, torch.Tensor]:
    r"""Get parameters for ```crop``` transformation for crop transform.

    Args:
        batch_size (int): the tensor batch size.
        input_size (tuple): Input image shape, like (d, h, w).
        size (tuple): Desired size of the crop operation, like (d, h, w).
            If tensor, it must be (B, 3).
        resize_to (tuple): Desired output size of the crop, like (d, h, w). If None, no resize will be performed.
        same_on_batch (bool): apply the same transformation across the batch. Default: False.

    Returns:
        params Dict[str, torch.Tensor]: parameters to be passed to the transformation.
            - src (torch.Tensor): cropping bounding boxes with a shape of (B, 8, 3).
            - dst (torch.Tensor): output bounding boxes with a shape (B, 8, 3).
    """
    if not isinstance(size, torch.Tensor):
        size = torch.tensor(size).repeat(batch_size, 1)
    assert size.shape == torch.Size([batch_size, 3]), \
        f"If `size` is a tensor, it must be shaped as (B, 3). Got {size.shape}."

    # +1: valid start offsets span [0, input - crop] inclusive.
    x_diff = input_size[2] - size[:, 2] + 1
    y_diff = input_size[1] - size[:, 1] + 1
    z_diff = input_size[0] - size[:, 0] + 1

    # diff <= 0 means the crop exceeds the input along that dimension.
    if (x_diff <= 0).any() or (y_diff <= 0).any() or (z_diff <= 0).any():
        raise ValueError(
            f"input_size {input_size} cannot be smaller than crop size {size} in any dimension.")

    if same_on_batch:
        # Draw one offset from the first sample's valid range; the helper
        # repeats it across the batch when same_on_batch=True.
        x_start = _adapted_uniform((batch_size,), 0, x_diff[0], same_on_batch).long()
        y_start = _adapted_uniform((batch_size,), 0, y_diff[0], same_on_batch).long()
        z_start = _adapted_uniform((batch_size,), 0, z_diff[0], same_on_batch).long()
    else:
        # One independent offset per batch element; view(-1) below flattens
        # the broadcast (1, B) sample to (B,).
        x_start = _adapted_uniform((1,), 0, x_diff, same_on_batch).long()
        y_start = _adapted_uniform((1,), 0, y_diff, same_on_batch).long()
        z_start = _adapted_uniform((1,), 0, z_diff, same_on_batch).long()

    # Corner coordinates are inclusive, hence size - 1 for the box extents.
    crop_src = bbox_generator3d(x_start.view(-1), y_start.view(-1),
                                z_start.view(-1), size[:, 2] - 1,
                                size[:, 1] - 1, size[:, 0] - 1)

    if resize_to is None:
        # No resize: the destination box is the crop itself, anchored at the origin.
        crop_dst = bbox_generator3d(torch.tensor([0] * batch_size),
                                    torch.tensor([0] * batch_size),
                                    torch.tensor([0] * batch_size),
                                    size[:, 2] - 1, size[:, 1] - 1,
                                    size[:, 0] - 1)
    else:
        # Eight box corners in (x, y, z): the z = 0 face first, then the
        # matching z = d - 1 face.
        crop_dst = torch.tensor([[
            [0, 0, 0],
            [resize_to[-1] - 1, 0, 0],
            [resize_to[-1] - 1, resize_to[-2] - 1, 0],
            [0, resize_to[-2] - 1, 0],
            [0, 0, resize_to[-3] - 1],
            [resize_to[-1] - 1, 0, resize_to[-3] - 1],
            [resize_to[-1] - 1, resize_to[-2] - 1, resize_to[-3] - 1],
            [0, resize_to[-2] - 1, resize_to[-3] - 1],
        ]]).repeat(batch_size, 1, 1)

    return dict(src=crop_src, dst=crop_dst)
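
A short sketch of the same_on_batch flag (again assuming the helpers are in scope; the sampled offsets depend on the global RNG state):

import torch

torch.manual_seed(0)
params = random_crop_generator3d(
    batch_size=4,
    input_size=(8, 16, 16),
    size=(4, 8, 8),
    same_on_batch=True,
)
# A single offset is drawn and repeated, so every batch element
# receives the identical crop box.
assert bool((params["src"] == params["src"][0]).all())

Compared with Example #1, this variant omits the device/dtype plumbing, the zero-batch guard, and the resize_to validation, so it always samples on the default device.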