Exemplo n.º 1
0
    def __batch_prob_generator__(self, batch_shape: torch.Size, p: float,
                                 p_batch: float,
                                 same_on_batch: bool) -> Tensor:
        """Sample the boolean "apply" mask for a batch.

        A single batch-level draw (governed by ``p_batch``) decides whether the
        batch is eligible at all; if it is, each element is then drawn with
        probability ``p``. If it is not, the (False) batch decision is repeated
        for every element.

        Args:
            batch_shape: shape of the input batch; only ``batch_shape[0]`` is used.
            p: per-element probability; 0 and 1 are short-circuited without sampling.
            p_batch: batch-level probability; 0 and 1 are short-circuited without sampling.
            same_on_batch: forwarded to ``_adapted_sampling``.

        Returns:
            Boolean mask with one entry per batch element.
        """
        batch_size = batch_shape[0]
        batch_prob: Tensor
        if p_batch == 1:
            batch_prob = torch.tensor([True])
        elif p_batch == 0:
            batch_prob = torch.tensor([False])
        else:
            batch_prob = _adapted_sampling((1, ), self._p_batch_gen,
                                           same_on_batch).bool()

        if batch_prob.sum().item() == 1:
            elem_prob: Tensor
            # Explicit bool dtype: ``torch.tensor([])`` (what ``[True] * 0`` used to
            # produce for an empty batch) would be float32 and turn the mask into floats.
            if p == 1:
                elem_prob = torch.ones(batch_size, dtype=torch.bool)
            elif p == 0:
                elem_prob = torch.zeros(batch_size, dtype=torch.bool)
            else:
                elem_prob = _adapted_sampling((batch_size, ), self._p_gen,
                                              same_on_batch).bool()
            # bool * bool is an element-wise logical AND that stays bool.
            batch_prob = batch_prob * elem_prob
        else:
            batch_prob = batch_prob.repeat(batch_size)
        return batch_prob
Exemplo n.º 2
0
def random_prob_generator(
    batch_size: int,
    p: float = 0.5,
    same_on_batch: bool = False,
    device: torch.device = torch.device('cpu'),
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    r"""Generate a random boolean mask selecting which batch elements are transformed.

    Args:
        batch_size (int): the number of images.
        p (float): probability that an element's mask entry is ``True``. Default value is 0.5.
        same_on_batch (bool): apply the same transformation across the batch. Default: False.
        device (torch.device): the device on which the random numbers will be generated. Default: cpu.
        dtype (torch.dtype): the data type of the Bernoulli probability tensor. Default: float32.

    Returns:
        torch.Tensor: boolean mask (``probs``) with a shape of (B,).

    Raises:
        TypeError: if ``p`` is not an int/float within [0, 1] (NaN is rejected too).

    Note:
        The generated random numbers are not reproducible across different devices and dtypes.
    """
    _common_param_check(batch_size, same_on_batch)
    # ``not 0 <= p <= 1`` also rejects NaN, for which both ``p > 1`` and ``p < 0``
    # are False. The isinstance check short-circuits before any comparison.
    if not isinstance(p, (int, float)) or not 0 <= p <= 1:
        raise TypeError(
            f"The probability should be a float number within [0, 1]. Got {p!r} of type {type(p)}."
        )

    _bernoulli = Bernoulli(torch.tensor(float(p), device=device, dtype=dtype))
    probs_mask: torch.Tensor = _adapted_sampling((batch_size,), _bernoulli, same_on_batch).bool()

    return probs_mask
Exemplo n.º 3
0
    def forward(self,
                batch_shape: torch.Size,
                same_on_batch: bool = False) -> Dict[str, torch.Tensor]:
        """Sample MixUp parameters for a batch.

        Args:
            batch_shape: shape of the input batch; only ``batch_shape[0]`` is used.
            same_on_batch: forwarded to the sampling helpers.

        Returns:
            dict with
              - ``mixup_pairs``: long tensor holding a random permutation of
                ``range(batch_size)``.
              - ``mixup_lambdas``: sampled lambdas multiplied by the sampled
                per-element probabilities, cast to the dtype of ``self.lambda_val``.
        """
        batch_size = batch_shape[0]

        _common_param_check(batch_size, same_on_batch)
        _device, _dtype = _extract_device_dtype([self.lambda_val])

        with torch.no_grad():
            batch_probs: torch.Tensor = _adapted_sampling(
                (batch_size, ), self.prob_sampler, same_on_batch)
        # Generate indices directly as int64. Using the floating ``_dtype`` made
        # ``randperm`` reject batch sizes beyond the exact-integer range of
        # half/bfloat16, and the float -> long round trip served no purpose.
        mixup_pairs: torch.Tensor = torch.randperm(batch_size,
                                                   device=_device,
                                                   dtype=torch.long)
        mixup_lambdas: torch.Tensor = _adapted_rsampling(
            (batch_size, ), self.lambda_sampler, same_on_batch)
        # presumably batch_probs is a 0/1 mask, so unselected elements get lambda 0
        # -- confirm against prob_sampler.
        mixup_lambdas = mixup_lambdas * batch_probs

        return dict(
            mixup_pairs=mixup_pairs.to(device=_device, dtype=torch.long),
            mixup_lambdas=mixup_lambdas.to(device=_device, dtype=_dtype),
        )
Exemplo n.º 4
0
    def forward(self,
                batch_shape: torch.Size,
                same_on_batch: bool = False) -> Dict[str, torch.Tensor]:
        """Sample CutMix parameters for a batch.

        Args:
            batch_shape: input shape; ``batch_shape[0]`` is the batch size and the
                last two entries are the image height and width.
            same_on_batch: if True, one cut size/position is shared by every box.

        Returns:
            dict with
              - ``mix_pairs``: long tensor of shape ``(num_mix, batch_size)``; each
                row is a random permutation of ``range(batch_size)``.
              - ``crop_src``: boxes from ``bbox_generator`` reshaped to
                ``(num_mix, batch_size, 4, 2)`` and floored.

        Raises:
            AssertionError: if height or width is not a positive ``int``.
        """
        batch_size = batch_shape[0]
        height = batch_shape[-2]
        width = batch_shape[-1]

        if not (type(height) is int and height > 0 and type(width) is int
                and width > 0):
            raise AssertionError(
                f"'height' and 'width' must be integers. Got {height}, {width}."
            )
        _device, _dtype = _extract_device_dtype([self.beta, self.cut_size])
        _common_param_check(batch_size, same_on_batch)

        if batch_size == 0:
            # NOTE(review): these empty shapes/dtypes differ from the non-empty
            # path ((num_mix, 0) pairs, float (num_mix, 0, 4, 2) boxes) -- confirm
            # downstream code expects them.
            return dict(
                mix_pairs=torch.zeros([0, 3], device=_device,
                                      dtype=torch.long),
                crop_src=torch.zeros([0, 4, 2],
                                     device=_device,
                                     dtype=torch.long),
            )

        with torch.no_grad():
            batch_probs: torch.Tensor = _adapted_sampling(
                (batch_size * self.num_mix, ), self.prob_sampler,
                same_on_batch)
        # argsort of i.i.d. uniforms gives an independent permutation per mix row.
        mix_pairs: torch.Tensor = torch.rand(self.num_mix,
                                             batch_size,
                                             device=_device,
                                             dtype=_dtype).argsort(dim=1)
        cutmix_betas: torch.Tensor = _adapted_rsampling(
            (batch_size * self.num_mix, ), self.beta_sampler, same_on_batch)

        # Note: torch.clamp does not accept tensor, cutmix_betas.clamp(cut_size[0], cut_size[1]) throws:
        # Argument 1 to "clamp" of "_TensorBase" has incompatible type "Tensor"; expected "float"
        cutmix_betas = torch.min(torch.max(cutmix_betas, self._cut_size[0]),
                                 self._cut_size[1])
        cutmix_rate = torch.sqrt(1.0 - cutmix_betas) * batch_probs

        cut_height = (cutmix_rate * height).floor().to(device=_device,
                                                       dtype=_dtype)
        cut_width = (cutmix_rate * width).floor().to(device=_device,
                                                     dtype=_dtype)

        # Draw one start fraction per box. A ``(1,)`` draw would be broadcast
        # over every box, giving all crops in the batch the same relative
        # position. With same_on_batch this relies on ``_adapted_rsampling``
        # repeating a single draw (assumed; as the original same_on_batch
        # path already did).
        _gen_shape = (cut_height.size(0), )

        if same_on_batch:
            cut_height = cut_height[0]
            cut_width = cut_width[0]

        # Reserve at least 1 pixel for cropping.
        x_start: torch.Tensor = _adapted_rsampling(
            _gen_shape, self.rand_sampler,
            same_on_batch) * (width - cut_width - 1)
        y_start: torch.Tensor = _adapted_rsampling(
            _gen_shape, self.rand_sampler,
            same_on_batch) * (height - cut_height - 1)
        x_start = x_start.floor().to(device=_device, dtype=_dtype)
        y_start = y_start.floor().to(device=_device, dtype=_dtype)

        crop_src = bbox_generator(x_start.squeeze(), y_start.squeeze(),
                                  cut_width, cut_height)

        # (B * num_mix, 4, 2) => (num_mix, batch_size, 4, 2)
        crop_src = crop_src.view(self.num_mix, batch_size, 4, 2)

        return dict(
            mix_pairs=mix_pairs.to(device=_device, dtype=torch.long),
            crop_src=crop_src.floor().to(device=_device, dtype=_dtype),
        )
Exemplo n.º 5
0
 def forward(self, batch_shape: torch.Size, same_on_batch: bool = False) -> Dict[str, torch.Tensor]:  # type:ignore
     """Draw a per-element selection mask for a batch.

     Args:
         batch_shape: shape of the input batch; only ``batch_shape[0]`` is used.
         same_on_batch: forwarded to ``_adapted_sampling``.

     Returns:
         dict with a single ``probs`` entry: the values sampled from
         ``self.sampler`` cast to ``bool``.
     """
     num_samples = batch_shape[0]
     draws = _adapted_sampling((num_samples,), self.sampler, same_on_batch)
     return {'probs': draws.bool()}