Example #1
File: erasing.py  Project: kornia/kornia
    def apply_transform(self,
                        input: Tensor,
                        params: Dict[str, Tensor],
                        transform: Optional[Tensor] = None) -> Tensor:
        _, c, h, w = input.size()
        values = (params["values"].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
                  .repeat(1, *input.shape[1:]).to(input))

        bboxes = bbox_generator(params["xs"], params["ys"], params["widths"],
                                params["heights"])
        mask = bbox_to_mask(bboxes, w, h)  # Returns B, H, W
        mask = mask.unsqueeze(1).repeat(1, c, 1, 1).to(input)  # Transform to B, c, H, W
        transformed = torch.where(mask == 1.0, values, input)
        return transformed
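
For reference, the erase step is nothing more than a broadcasted torch.where over a rectangular mask. A minimal standalone sketch of the same idea (illustrative names, not kornia API):

import torch

img = torch.arange(2 * 3 * 4 * 4, dtype=torch.float32).view(2, 3, 4, 4)
fill = torch.zeros(2, 3, 4, 4)       # per-pixel replacement values
mask = torch.zeros(2, 1, 4, 4)
mask[:, :, 1:3, 1:3] = 1.0           # the rectangle to erase
mask = mask.repeat(1, 3, 1, 1)       # B, 1, H, W -> B, C, H, W

erased = torch.where(mask == 1.0, fill, img)
assert (erased[:, :, 1:3, 1:3] == 0).all()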
Example #2
File: cutmix.py  Project: kornia/kornia
    def forward(self,
                batch_shape: torch.Size,
                same_on_batch: bool = False) -> Dict[str, torch.Tensor]:
        batch_size = batch_shape[0]
        height = batch_shape[-2]
        width = batch_shape[-1]

        if not (type(height) is int and height > 0 and type(width) is int
                and width > 0):
            raise AssertionError(
                f"'height' and 'width' must be integers. Got {height}, {width}."
            )
        _device, _dtype = _extract_device_dtype([self.beta, self.cut_size])
        _common_param_check(batch_size, same_on_batch)

        if batch_size == 0:
            return dict(
                mix_pairs=torch.zeros([0, 3], device=_device,
                                      dtype=torch.long),
                crop_src=torch.zeros([0, 4, 2],
                                     device=_device,
                                     dtype=torch.long),
            )

        with torch.no_grad():
            batch_probs: torch.Tensor = _adapted_sampling(
                (batch_size * self.num_mix, ), self.prob_sampler,
                same_on_batch)
        mix_pairs: torch.Tensor = torch.rand(self.num_mix,
                                             batch_size,
                                             device=_device,
                                             dtype=_dtype).argsort(dim=1)
        cutmix_betas: torch.Tensor = _adapted_rsampling(
            (batch_size * self.num_mix, ), self.beta_sampler, same_on_batch)

        # Note: torch.clamp does not accept tensor, cutmix_betas.clamp(cut_size[0], cut_size[1]) throws:
        # Argument 1 to "clamp" of "_TensorBase" has incompatible type "Tensor"; expected "float"
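        # (Newer PyTorch releases, 1.9+, do accept Tensor bounds:
        # torch.clamp(cutmix_betas, self._cut_size[0], self._cut_size[1]).)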
        cutmix_betas = torch.min(torch.max(cutmix_betas, self._cut_size[0]),
                                 self._cut_size[1])
        cutmix_rate = torch.sqrt(1.0 - cutmix_betas) * batch_probs

        cut_height = (cutmix_rate * height).floor().to(device=_device,
                                                       dtype=_dtype)
        cut_width = (cutmix_rate * width).floor().to(device=_device,
                                                     dtype=_dtype)
        _gen_shape = (1, )

        if same_on_batch:
            _gen_shape = (cut_height.size(0), )
            cut_height = cut_height[0]
            cut_width = cut_width[0]

        # Reserve at least 1 pixel for cropping.
        x_start: torch.Tensor = _adapted_rsampling(
            _gen_shape, self.rand_sampler,
            same_on_batch) * (width - cut_width - 1)
        y_start: torch.Tensor = _adapted_rsampling(
            _gen_shape, self.rand_sampler,
            same_on_batch) * (height - cut_height - 1)
        x_start = x_start.floor().to(device=_device, dtype=_dtype)
        y_start = y_start.floor().to(device=_device, dtype=_dtype)

        crop_src = bbox_generator(x_start.squeeze(), y_start.squeeze(),
                                  cut_width, cut_height)

        # (B * num_mix, 4, 2) => (num_mix, batch_size, 4, 2)
        crop_src = crop_src.view(self.num_mix, batch_size, 4, 2)

        return dict(
            mix_pairs=mix_pairs.to(device=_device, dtype=torch.long),
            crop_src=crop_src.floor().to(device=_device, dtype=_dtype),
        )
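
One detail worth calling out: the mix_pairs line builds an independent random permutation per mix round by argsort-ing uniform noise. A tiny self-contained illustration (plain PyTorch, nothing kornia-specific):

import torch

torch.manual_seed(0)
num_mix, batch_size = 2, 5
# Each row is a random permutation of 0..batch_size-1, one per mix round.
pairs = torch.rand(num_mix, batch_size).argsort(dim=1)
print(pairs.shape)  # torch.Size([2, 5])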
Example #3
File: crop.py  Project: kornia/kornia
    def forward(
            self,
            batch_shape: torch.Size,
            same_on_batch: bool = False
    ) -> Dict[str, torch.Tensor]:  # type:ignore
        batch_size = batch_shape[0]
        _common_param_check(batch_size, same_on_batch)
        _device, _dtype = _extract_device_dtype(
            [self.size if isinstance(self.size, torch.Tensor) else None])

        if batch_size == 0:
            return dict(
                src=torch.zeros([0, 4, 2], device=_device, dtype=_dtype),
                dst=torch.zeros([0, 4, 2], device=_device, dtype=_dtype),
            )

        input_size = (batch_shape[-2], batch_shape[-1])
        if not isinstance(self.size, torch.Tensor):
            size = torch.tensor(self.size, device=_device,
                                dtype=_dtype).repeat(batch_size, 1)
        else:
            size = self.size.to(device=_device, dtype=_dtype)
        if size.shape != torch.Size([batch_size, 2]):
            raise AssertionError(
                "If `size` is a tensor, it must be shaped as (B, 2). "
                f"Got {size.shape} while expecting {torch.Size([batch_size, 2])}."
            )
        if not (input_size[0] > 0 and input_size[1] > 0 and (size > 0).all()):
            raise AssertionError(
                f"Got a non-positive input size or crop size: {input_size}, {size}.")
        size = size.floor()

        x_diff = input_size[1] - size[:, 1] + 1
        y_diff = input_size[0] - size[:, 0] + 1

        # Start point will be 0 if diff < 0
        x_diff = x_diff.clamp(0)
        y_diff = y_diff.clamp(0)

        if same_on_batch:
            # If same_on_batch, select the first then repeat.
            x_start = (_adapted_rsampling(
                (batch_size, ), self.rand_sampler, same_on_batch).to(x_diff) *
                       x_diff[0]).floor()
            y_start = (_adapted_rsampling(
                (batch_size, ), self.rand_sampler, same_on_batch).to(y_diff) *
                       y_diff[0]).floor()
        else:
            x_start = (_adapted_rsampling(
                (batch_size, ), self.rand_sampler, same_on_batch).to(x_diff) *
                       x_diff).floor()
            y_start = (_adapted_rsampling(
                (batch_size, ), self.rand_sampler, same_on_batch).to(y_diff) *
                       y_diff).floor()
        crop_src = bbox_generator(
            x_start.view(-1).to(device=_device, dtype=_dtype),
            y_start.view(-1).to(device=_device, dtype=_dtype),
            torch.where(
                size[:, 1] == 0,
                torch.tensor(input_size[1], device=_device, dtype=_dtype),
                size[:, 1]),
            torch.where(
                size[:, 0] == 0,
                torch.tensor(input_size[0], device=_device, dtype=_dtype),
                size[:, 0]),
        )

        if self.resize_to is None:
            crop_dst = bbox_generator(
                torch.tensor([0] * batch_size, device=_device, dtype=_dtype),
                torch.tensor([0] * batch_size, device=_device, dtype=_dtype),
                size[:, 1],
                size[:, 0],
            )
            _output_size = size.to(dtype=torch.long)
        else:
            if not (len(self.resize_to) == 2
                    and isinstance(self.resize_to[0], int)
                    and isinstance(self.resize_to[1], int)
                    and self.resize_to[0] > 0 and self.resize_to[1] > 0):
                raise AssertionError(
                    f"`resize_to` must be a tuple of 2 positive integers. Got {self.resize_to}."
                )
            crop_dst = torch.tensor(
                [[
                    [0, 0],
                    [self.resize_to[1] - 1, 0],
                    [self.resize_to[1] - 1, self.resize_to[0] - 1],
                    [0, self.resize_to[0] - 1],
                ]],
                device=_device,
                dtype=_dtype,
            ).repeat(batch_size, 1, 1)
            _output_size = torch.tensor(self.resize_to,
                                        device=_device,
                                        dtype=torch.long).expand(
                                            batch_size, -1)

        _input_size = torch.tensor(input_size,
                                   device=_device,
                                   dtype=torch.long).expand(batch_size, -1)

        return dict(src=crop_src,
                    dst=crop_dst,
                    input_size=_input_size,
                    output_size=_output_size)
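
The src/dst quads returned here are corner points in (x, y) order (top-left, top-right, bottom-right, bottom-left). For an axis-aligned box, recovering the crop as a plain tensor slice only needs the first and third corners; a hedged sketch (crop_from_quad is an illustrative helper, not kornia API):

import torch

def crop_from_quad(img: torch.Tensor, quad: torch.Tensor) -> torch.Tensor:
    # quad: (4, 2) corners in (x, y) order; the boxes are inclusive,
    # hence the +1 on the upper bounds.
    x0, y0 = quad[0].long().tolist()
    x1, y1 = quad[2].long().tolist()
    return img[..., y0:y1 + 1, x0:x1 + 1]

img = torch.rand(3, 30, 30)
quad = torch.tensor([[1., 0.], [28., 0.], [28., 24.], [1., 24.]])
assert crop_from_quad(img, quad).shape == (3, 25, 28)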
Example #4
File: cutmix.py  Project: kornia/kornia
def random_cutmix_generator(
    batch_size: int,
    width: int,
    height: int,
    p: float = 0.5,
    num_mix: int = 1,
    beta: Optional[torch.Tensor] = None,
    cut_size: Optional[torch.Tensor] = None,
    same_on_batch: bool = False,
    device: torch.device = torch.device('cpu'),
    dtype: torch.dtype = torch.float32,
) -> Dict[str, torch.Tensor]:
    r"""Generate cutmix indexes and lambdas for a batch of inputs.

    Args:
        batch_size (int): the number of images. If batch_size == 1, the output will be the same as the input.
        width (int): image width.
        height (int): image height.
        p (float): probability of applying cutmix.
        num_mix (int): number of images to mix with. Default is 1.
        beta (torch.Tensor, optional): hyperparameter of the beta distribution used to
            generate the cut size. If None, it will be set to 1.
        cut_size (torch.Tensor, optional): the minimum and maximum cut ratio, each
            within [0, 1]. If None, it will be set to [0, 1], which means no restriction.
        same_on_batch (bool): apply the same transformation across the batch. Default: False.
        device (torch.device): the device on which the random numbers will be generated. Default: cpu.
        dtype (torch.dtype): the data type of the generated random numbers. Default: float32.

    Returns:
        params Dict[str, torch.Tensor]: parameters to be passed for transformation.
            - mix_pairs (torch.Tensor): mixing indices with a shape of (num_mix, B).
            - crop_src (torch.Tensor): cropping bounding boxes with a shape of (num_mix, B, 4, 2).

    Note:
        The generated random numbers are not reproducible across different devices and dtypes.

    Examples:
        >>> rng = torch.manual_seed(0)
        >>> random_cutmix_generator(3, 224, 224, p=0.5, num_mix=2)
        {'mix_pairs': tensor([[2, 0, 1],
                [1, 2, 0]]), 'crop_src': tensor([[[[ 35.,  25.],
                  [208.,  25.],
                  [208., 198.],
                  [ 35., 198.]],
        <BLANKLINE>
                 [[156., 137.],
                  [155., 137.],
                  [155., 136.],
                  [156., 136.]],
        <BLANKLINE>
                 [[  3.,  12.],
                  [210.,  12.],
                  [210., 219.],
                  [  3., 219.]]],
        <BLANKLINE>
        <BLANKLINE>
                [[[ 83., 125.],
                  [177., 125.],
                  [177., 219.],
                  [ 83., 219.]],
        <BLANKLINE>
                 [[ 54.,   8.],
                  [205.,   8.],
                  [205., 159.],
                  [ 54., 159.]],
        <BLANKLINE>
                 [[ 97.,  70.],
                  [ 96.,  70.],
                  [ 96.,  69.],
                  [ 97.,  69.]]]])}
    """
    _device, _dtype = _extract_device_dtype([beta, cut_size])
    beta = torch.as_tensor(1.0 if beta is None else beta,
                           device=device,
                           dtype=dtype)
    cut_size = torch.as_tensor([0.0, 1.0] if cut_size is None else cut_size,
                               device=device,
                               dtype=dtype)
    if not (isinstance(num_mix, int) and num_mix >= 1):
        raise AssertionError(
            f"`num_mix` must be an integer greater than or equal to 1. Got {num_mix}.")
    if not (type(height) is int and height > 0 and type(width) is int
            and width > 0):
        raise AssertionError(
            f"'height' and 'width' must be integers. Got {height}, {width}.")
    _joint_range_check(cut_size, 'cut_size', bounds=(0, 1))
    _common_param_check(batch_size, same_on_batch)

    if batch_size == 0:
        return dict(
            mix_pairs=torch.zeros([0, 3], device=_device, dtype=torch.long),
            crop_src=torch.zeros([0, 4, 2], device=_device, dtype=torch.long),
        )

    batch_probs: torch.Tensor = random_prob_generator(batch_size * num_mix,
                                                      p,
                                                      same_on_batch,
                                                      device=device,
                                                      dtype=dtype)
    mix_pairs: torch.Tensor = torch.rand(num_mix,
                                         batch_size,
                                         device=device,
                                         dtype=dtype).argsort(dim=1)
    cutmix_betas: torch.Tensor = _adapted_beta((batch_size * num_mix, ),
                                               beta,
                                               beta,
                                               same_on_batch=same_on_batch)
    # Note: torch.clamp does not accept tensor, cutmix_betas.clamp(cut_size[0], cut_size[1]) throws:
    # Argument 1 to "clamp" of "_TensorBase" has incompatible type "Tensor"; expected "float"
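    # (Newer PyTorch releases, 1.9+, do accept Tensor bounds:
    # torch.clamp(cutmix_betas, cut_size[0], cut_size[1]).)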
    cutmix_betas = torch.min(torch.max(cutmix_betas, cut_size[0]), cut_size[1])
    cutmix_rate = torch.sqrt(1.0 - cutmix_betas) * batch_probs

    cut_height = (cutmix_rate * height).floor().to(device=device, dtype=_dtype)
    cut_width = (cutmix_rate * width).floor().to(device=device, dtype=_dtype)
    _gen_shape = (1, )

    if same_on_batch:
        _gen_shape = (cut_height.size(0), )
        cut_height = cut_height[0]
        cut_width = cut_width[0]

    # Reserve at least 1 pixel for cropping.
    x_start = (_adapted_uniform(
        _gen_shape,
        torch.zeros_like(cut_width, device=device, dtype=dtype),
        (width - cut_width - 1).to(device=device, dtype=dtype),
        same_on_batch,
    ).floor().to(device=device, dtype=_dtype))
    y_start = (_adapted_uniform(
        _gen_shape,
        torch.zeros_like(cut_height, device=device, dtype=dtype),
        (height - cut_height - 1).to(device=device, dtype=dtype),
        same_on_batch,
    ).floor().to(device=device, dtype=_dtype))

    crop_src = bbox_generator(x_start.squeeze(), y_start.squeeze(), cut_width,
                              cut_height)

    # (B * num_mix, 4, 2) => (num_mix, batch_size, 4, 2)
    crop_src = crop_src.view(num_mix, batch_size, 4, 2)

    return dict(
        mix_pairs=mix_pairs.to(device=_device, dtype=torch.long),
        crop_src=crop_src.floor().to(device=_device, dtype=_dtype),
    )
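
Applying the generated parameters is then mechanical: for each mix round, paste the crop_src region from the paired image over the original. A hedged sketch, assuming random_cutmix_generator above is importable (apply_cutmix_once is an illustrative helper, not the kornia transform itself):

import torch

def apply_cutmix_once(imgs, mix_pairs, crop_src):
    # imgs: (B, C, H, W); mix_pairs: (B,) pairing indices; crop_src: (B, 4, 2).
    out = imgs.clone()
    for i, j in enumerate(mix_pairs.tolist()):
        x0, y0 = crop_src[i, 0].long().tolist()
        x1, y1 = crop_src[i, 2].long().tolist()
        # Degenerate boxes (cutmix not applied) yield empty slices: a no-op.
        out[i, :, y0:y1 + 1, x0:x1 + 1] = imgs[j, :, y0:y1 + 1, x0:x1 + 1]
    return out

params = random_cutmix_generator(3, 224, 224, p=0.5, num_mix=1)
mixed = apply_cutmix_once(torch.rand(3, 3, 224, 224),
                          params["mix_pairs"][0], params["crop_src"][0])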
Example #5
File: crop.py  Project: kornia/kornia
def random_crop_generator(
    batch_size: int,
    input_size: Tuple[int, int],
    size: Union[Tuple[int, int], torch.Tensor],
    resize_to: Optional[Tuple[int, int]] = None,
    same_on_batch: bool = False,
    device: torch.device = torch.device('cpu'),
    dtype: torch.dtype = torch.float32,
) -> Dict[str, torch.Tensor]:
    r"""Get parameters for ```crop``` transformation for crop transform.

    Args:
        batch_size (int): the tensor batch size.
        input_size (tuple): Input image shape, like (h, w).
        size (tuple): Desired size of the crop operation, like (h, w).
            If tensor, it must be (B, 2).
        resize_to (tuple): Desired output size of the crop, like (h, w). If None, no resize will be performed.
        same_on_batch (bool): apply the same transformation across the batch. Default: False.
        device (torch.device): the device on which the random numbers will be generated. Default: cpu.
        dtype (torch.dtype): the data type of the generated random numbers. Default: float32.

    Returns:
        params Dict[str, torch.Tensor]: parameters to be passed for transformation.
            - src (torch.Tensor): cropping bounding boxes with a shape of (B, 4, 2).
            - dst (torch.Tensor): output bounding boxes with a shape of (B, 4, 2).
            - input_size (torch.Tensor): the input image size with a shape of (B, 2).

    Note:
        The generated random numbers are not reproducible across different devices and dtypes.

    Example:
        >>> _ = torch.manual_seed(0)
        >>> crop_size = torch.tensor([[25, 28], [27, 29], [26, 28]])
        >>> random_crop_generator(3, (30, 30), size=crop_size, same_on_batch=False)
        {'src': tensor([[[ 1.,  0.],
                 [28.,  0.],
                 [28., 24.],
                 [ 1., 24.]],
        <BLANKLINE>
                [[ 1.,  1.],
                 [29.,  1.],
                 [29., 27.],
                 [ 1., 27.]],
        <BLANKLINE>
                [[ 0.,  3.],
                 [27.,  3.],
                 [27., 28.],
                 [ 0., 28.]]]), 'dst': tensor([[[ 0.,  0.],
                 [27.,  0.],
                 [27., 24.],
                 [ 0., 24.]],
        <BLANKLINE>
                [[ 0.,  0.],
                 [28.,  0.],
                 [28., 26.],
                 [ 0., 26.]],
        <BLANKLINE>
                [[ 0.,  0.],
                 [27.,  0.],
                 [27., 25.],
                 [ 0., 25.]]]), 'input_size': tensor([[30, 30],
                [30, 30],
                [30, 30]])}
    """
    _common_param_check(batch_size, same_on_batch)
    _device, _dtype = _extract_device_dtype(
        [size if isinstance(size, torch.Tensor) else None])
    # Fall back to a floating point dtype if `size` carried an integer dtype.
    _dtype = _dtype if _dtype in (torch.float16, torch.float32,
                                  torch.float64) else dtype
    if not isinstance(size, torch.Tensor):
        size = torch.tensor(size, device=_device,
                            dtype=_dtype).repeat(batch_size, 1)
    else:
        size = size.to(device=_device, dtype=_dtype)
    if size.shape != torch.Size([batch_size, 2]):
        raise AssertionError(
            "If `size` is a tensor, it must be shaped as (B, 2). "
            f"Got {size.shape} while expecting {torch.Size([batch_size, 2])}.")
    if not (input_size[0] > 0 and input_size[1] > 0 and (size > 0).all()):
        raise AssertionError(
            f"Got a non-positive input size or crop size: {input_size}, {size}.")
    size = size.floor()

    x_diff = input_size[1] - size[:, 1] + 1
    y_diff = input_size[0] - size[:, 0] + 1

    # Start point will be 0 if diff < 0
    x_diff = x_diff.clamp(0)
    y_diff = y_diff.clamp(0)

    if batch_size == 0:
        return dict(
            src=torch.zeros([0, 4, 2], device=_device, dtype=_dtype),
            dst=torch.zeros([0, 4, 2], device=_device, dtype=_dtype),
        )

    if same_on_batch:
        # If same_on_batch, select the first then repeat.
        x_start = _adapted_uniform((batch_size, ), 0,
                                   x_diff[0].to(device=device, dtype=dtype),
                                   same_on_batch).floor()
        y_start = _adapted_uniform((batch_size, ), 0,
                                   y_diff[0].to(device=device, dtype=dtype),
                                   same_on_batch).floor()
    else:
        x_start = _adapted_uniform((1, ), 0,
                                   x_diff.to(device=device, dtype=dtype),
                                   same_on_batch).floor()
        y_start = _adapted_uniform((1, ), 0,
                                   y_diff.to(device=device, dtype=dtype),
                                   same_on_batch).floor()
    crop_src = bbox_generator(
        x_start.view(-1).to(device=_device, dtype=_dtype),
        y_start.view(-1).to(device=_device, dtype=_dtype),
        torch.where(size[:, 1] == 0,
                    torch.tensor(input_size[1], device=_device, dtype=_dtype),
                    size[:, 1]),
        torch.where(size[:, 0] == 0,
                    torch.tensor(input_size[0], device=_device, dtype=_dtype),
                    size[:, 0]),
    )

    if resize_to is None:
        crop_dst = bbox_generator(
            torch.tensor([0] * batch_size, device=_device, dtype=_dtype),
            torch.tensor([0] * batch_size, device=_device, dtype=_dtype),
            size[:, 1],
            size[:, 0],
        )
    else:
        if not (len(resize_to) == 2
                and isinstance(resize_to[0], int)
                and isinstance(resize_to[1], int)
                and resize_to[0] > 0 and resize_to[1] > 0):
            raise AssertionError(
                f"`resize_to` must be a tuple of 2 positive integers. Got {resize_to}."
            )
        crop_dst = torch.tensor(
            [[[0, 0], [resize_to[1] - 1, 0],
              [resize_to[1] - 1, resize_to[0] - 1], [0, resize_to[0] - 1]]],
            device=_device,
            dtype=_dtype,
        ).repeat(batch_size, 1, 1)

    _input_size = torch.tensor(input_size, device=_device,
                               dtype=torch.long).expand(batch_size, -1)

    return dict(src=crop_src, dst=crop_dst, input_size=_input_size)
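
Downstream, the src/dst quads pair up as 4-point correspondences, so a fixed-size crop-and-resize can be expressed as a perspective warp. A hedged usage sketch, assuming random_crop_generator above is importable and using kornia's public warp helpers:

import torch
import kornia

params = random_crop_generator(
    3, (30, 30),
    size=torch.tensor([[25, 28], [27, 29], [26, 28]]),
    resize_to=(32, 32))
imgs = torch.rand(3, 3, 30, 30)
# Solve the homography mapping each src quad onto its dst quad, then warp.
M = kornia.geometry.get_perspective_transform(params['src'], params['dst'])
out = kornia.geometry.warp_perspective(imgs, M, dsize=(32, 32))
assert out.shape == (3, 3, 32, 32)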