def apply_transform(self,
                    input: Tensor,
                    params: Dict[str, Tensor],
                    transform: Optional[Tensor] = None) -> Tensor:
    _, c, h, w = input.size()
    values = params["values"].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, *input.shape[1:]).to(input)

    bboxes = bbox_generator(params["xs"], params["ys"], params["widths"], params["heights"])
    mask = bbox_to_mask(bboxes, w, h)  # Returns (B, H, W)
    mask = mask.unsqueeze(1).repeat(1, c, 1, 1).to(input)  # Broadcast to (B, C, H, W)
    transformed = torch.where(mask == 1.0, values, input)
    return transformed
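
# Usage sketch (illustrative, not part of the library): builds the params dict
# by hand to show the expected keys and shapes; in practice the dict comes from
# the matching parameter generator. `aug` is assumed to be an instance of the
# class that defines `apply_transform` above.
def _example_apply_transform(aug) -> None:
    img = torch.rand(2, 3, 32, 32)
    params = {
        "xs": torch.tensor([4.0, 8.0]),         # top-left x per sample
        "ys": torch.tensor([4.0, 8.0]),         # top-left y per sample
        "widths": torch.tensor([10.0, 10.0]),   # box width per sample
        "heights": torch.tensor([10.0, 10.0]),  # box height per sample
        "values": torch.tensor([0.0, 0.0]),     # fill value per sample
    }
    out = aug.apply_transform(img, params)
    assert out.shape == img.shape  # each box region is replaced by its fill value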
def forward(self, batch_shape: torch.Size, same_on_batch: bool = False) -> Dict[str, torch.Tensor]:
    batch_size = batch_shape[0]
    height = batch_shape[-2]
    width = batch_shape[-1]

    if not (type(height) is int and height > 0 and type(width) is int and width > 0):
        raise AssertionError(f"'height' and 'width' must be positive integers. Got {height}, {width}.")

    _device, _dtype = _extract_device_dtype([self.beta, self.cut_size])
    _common_param_check(batch_size, same_on_batch)

    if batch_size == 0:
        return dict(
            mix_pairs=torch.zeros([0, 3], device=_device, dtype=torch.long),
            crop_src=torch.zeros([0, 4, 2], device=_device, dtype=torch.long),
        )

    with torch.no_grad():
        batch_probs: torch.Tensor = _adapted_sampling(
            (batch_size * self.num_mix,), self.prob_sampler, same_on_batch)
    mix_pairs: torch.Tensor = torch.rand(self.num_mix, batch_size, device=_device, dtype=_dtype).argsort(dim=1)
    cutmix_betas: torch.Tensor = _adapted_rsampling(
        (batch_size * self.num_mix,), self.beta_sampler, same_on_batch)

    # Note: torch.clamp does not accept tensor bounds; cutmix_betas.clamp(cut_size[0], cut_size[1]) throws:
    # Argument 1 to "clamp" of "_TensorBase" has incompatible type "Tensor"; expected "float"
    cutmix_betas = torch.min(torch.max(cutmix_betas, self._cut_size[0]), self._cut_size[1])
    cutmix_rate = torch.sqrt(1.0 - cutmix_betas) * batch_probs

    cut_height = (cutmix_rate * height).floor().to(device=_device, dtype=_dtype)
    cut_width = (cutmix_rate * width).floor().to(device=_device, dtype=_dtype)
    _gen_shape = (1,)

    if same_on_batch:
        _gen_shape = (cut_height.size(0),)
        cut_height = cut_height[0]
        cut_width = cut_width[0]

    # Reserve at least 1 pixel for cropping.
    x_start: torch.Tensor = _adapted_rsampling(
        _gen_shape, self.rand_sampler, same_on_batch) * (width - cut_width - 1)
    y_start: torch.Tensor = _adapted_rsampling(
        _gen_shape, self.rand_sampler, same_on_batch) * (height - cut_height - 1)
    x_start = x_start.floor().to(device=_device, dtype=_dtype)
    y_start = y_start.floor().to(device=_device, dtype=_dtype)

    crop_src = bbox_generator(x_start.squeeze(), y_start.squeeze(), cut_width, cut_height)

    # (B * num_mix, 4, 2) => (num_mix, batch_size, 4, 2)
    crop_src = crop_src.view(self.num_mix, batch_size, 4, 2)

    return dict(
        mix_pairs=mix_pairs.to(device=_device, dtype=torch.long),
        crop_src=crop_src.floor().to(device=_device, dtype=_dtype),
    )
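
# Shape-oriented sketch (assumption: `gen` is a configured instance of the
# cutmix generator class this `forward` belongs to; the constructor is not
# shown in this section).
def _example_cutmix_forward(gen) -> None:
    params = gen.forward(torch.Size([8, 3, 224, 224]))
    assert params["mix_pairs"].shape == (gen.num_mix, 8)       # shuffled mix indices
    assert params["crop_src"].shape == (gen.num_mix, 8, 4, 2)  # box corner points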
def forward(self, batch_shape: torch.Size, same_on_batch: bool = False) -> Dict[str, torch.Tensor]:  # type: ignore
    batch_size = batch_shape[0]
    _common_param_check(batch_size, same_on_batch)
    _device, _dtype = _extract_device_dtype([self.size if isinstance(self.size, torch.Tensor) else None])

    if batch_size == 0:
        return dict(
            src=torch.zeros([0, 4, 2], device=_device, dtype=_dtype),
            dst=torch.zeros([0, 4, 2], device=_device, dtype=_dtype),
        )

    input_size = (batch_shape[-2], batch_shape[-1])
    if not isinstance(self.size, torch.Tensor):
        size = torch.tensor(self.size, device=_device, dtype=_dtype).repeat(batch_size, 1)
    else:
        size = self.size.to(device=_device, dtype=_dtype)
    if size.shape != torch.Size([batch_size, 2]):
        raise AssertionError(
            "If `size` is a tensor, it must be shaped as (B, 2). "
            f"Got {size.shape} while expecting {torch.Size([batch_size, 2])}.")
    if not (input_size[0] > 0 and input_size[1] > 0 and (size > 0).all()):
        raise AssertionError(f"Got non-positive input size or size. {input_size}, {size}.")
    size = size.floor()

    x_diff = input_size[1] - size[:, 1] + 1
    y_diff = input_size[0] - size[:, 0] + 1

    # Start point will be 0 if diff < 0.
    x_diff = x_diff.clamp(0)
    y_diff = y_diff.clamp(0)

    if same_on_batch:
        # If same_on_batch, sample once against the first diff and repeat across the batch.
        x_start = (_adapted_rsampling((batch_size,), self.rand_sampler, same_on_batch).to(x_diff) * x_diff[0]).floor()
        y_start = (_adapted_rsampling((batch_size,), self.rand_sampler, same_on_batch).to(y_diff) * y_diff[0]).floor()
    else:
        x_start = (_adapted_rsampling((batch_size,), self.rand_sampler, same_on_batch).to(x_diff) * x_diff).floor()
        y_start = (_adapted_rsampling((batch_size,), self.rand_sampler, same_on_batch).to(y_diff) * y_diff).floor()

    crop_src = bbox_generator(
        x_start.view(-1).to(device=_device, dtype=_dtype),
        y_start.view(-1).to(device=_device, dtype=_dtype),
        torch.where(size[:, 1] == 0, torch.tensor(input_size[1], device=_device, dtype=_dtype), size[:, 1]),
        torch.where(size[:, 0] == 0, torch.tensor(input_size[0], device=_device, dtype=_dtype), size[:, 0]),
    )

    if self.resize_to is None:
        crop_dst = bbox_generator(
            torch.tensor([0] * batch_size, device=_device, dtype=_dtype),
            torch.tensor([0] * batch_size, device=_device, dtype=_dtype),
            size[:, 1],
            size[:, 0],
        )
        _output_size = size.to(dtype=torch.long)
    else:
        if not (
            len(self.resize_to) == 2
            and isinstance(self.resize_to[0], (int,))
            and isinstance(self.resize_to[1], (int,))
            and self.resize_to[0] > 0
            and self.resize_to[1] > 0
        ):
            raise AssertionError(f"`resize_to` must be a tuple of 2 positive integers. Got {self.resize_to}.")
        crop_dst = torch.tensor(
            [[
                [0, 0],
                [self.resize_to[1] - 1, 0],
                [self.resize_to[1] - 1, self.resize_to[0] - 1],
                [0, self.resize_to[0] - 1],
            ]],
            device=_device,
            dtype=_dtype,
        ).repeat(batch_size, 1, 1)
        _output_size = torch.tensor(self.resize_to, device=_device, dtype=torch.long).expand(batch_size, -1)

    _input_size = torch.tensor(input_size, device=_device, dtype=torch.long).expand(batch_size, -1)

    return dict(src=crop_src, dst=crop_dst, input_size=_input_size, output_size=_output_size)
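
# Shape-oriented sketch (assumption: `gen` is a configured instance of the crop
# generator class this `forward` belongs to).
def _example_crop_forward(gen) -> None:
    params = gen.forward(torch.Size([4, 3, 30, 30]))
    assert params["src"].shape == (4, 4, 2)       # sampled crop corners in the input
    assert params["dst"].shape == (4, 4, 2)       # target corners in the output
    assert params["input_size"].shape == (4, 2)   # per-sample (h, w) of the input
    assert params["output_size"].shape == (4, 2)  # per-sample (h, w) of the output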
def random_cutmix_generator(
    batch_size: int,
    width: int,
    height: int,
    p: float = 0.5,
    num_mix: int = 1,
    beta: Optional[torch.Tensor] = None,
    cut_size: Optional[torch.Tensor] = None,
    same_on_batch: bool = False,
    device: torch.device = torch.device('cpu'),
    dtype: torch.dtype = torch.float32,
) -> Dict[str, torch.Tensor]:
    r"""Generate cutmix indexes and lambdas for a batch of inputs.

    Args:
        batch_size (int): the number of images. If batch_size == 1, the output will be the same as the input.
        width (int): image width.
        height (int): image height.
        p (float): probability of applying cutmix.
        num_mix (int): number of images to mix with. Default is 1.
        beta (torch.Tensor, optional): hyperparameter for generating cut size from beta distribution.
            If None, it will be set to 1.
        cut_size (torch.Tensor, optional): controlling the minimum and maximum cut ratio from [0, 1].
            If None, it will be set to [0, 1], which means no restriction.
        same_on_batch (bool): apply the same transformation across the batch. Default: False.
        device (torch.device): the device on which the random numbers will be generated. Default: cpu.
        dtype (torch.dtype): the data type of the generated random numbers. Default: float32.

    Returns:
        params Dict[str, torch.Tensor]: parameters to be passed for transformation.
            - mix_pairs (torch.Tensor): the shuffled image indices to mix with, with a shape of (num_mix, B).
            - crop_src (torch.Tensor): the cropping bounding boxes, with a shape of (num_mix, B, 4, 2).

    Note:
        The generated random numbers are not reproducible across different devices and dtypes.

    Examples:
        >>> rng = torch.manual_seed(0)
        >>> random_cutmix_generator(3, 224, 224, p=0.5, num_mix=2)
        {'mix_pairs': tensor([[2, 0, 1],
                [1, 2, 0]]), 'crop_src': tensor([[[[ 35.,  25.],
                  [208.,  25.],
                  [208., 198.],
                  [ 35., 198.]],
        <BLANKLINE>
                 [[156., 137.],
                  [155., 137.],
                  [155., 136.],
                  [156., 136.]],
        <BLANKLINE>
                 [[  3.,  12.],
                  [210.,  12.],
                  [210., 219.],
                  [  3., 219.]]],
        <BLANKLINE>
        <BLANKLINE>
                [[[ 83., 125.],
                  [177., 125.],
                  [177., 219.],
                  [ 83., 219.]],
        <BLANKLINE>
                 [[ 54.,   8.],
                  [205.,   8.],
                  [205., 159.],
                  [ 54., 159.]],
        <BLANKLINE>
                 [[ 97.,  70.],
                  [ 96.,  70.],
                  [ 96.,  69.],
                  [ 97.,  69.]]]])}
    """
    _device, _dtype = _extract_device_dtype([beta, cut_size])
    beta = torch.as_tensor(1.0 if beta is None else beta, device=device, dtype=dtype)
    cut_size = torch.as_tensor([0.0, 1.0] if cut_size is None else cut_size, device=device, dtype=dtype)
    if not (num_mix >= 1 and isinstance(num_mix, (int,))):
        raise AssertionError(f"`num_mix` must be an integer greater than or equal to 1. Got {num_mix}.")
    if not (type(height) is int and height > 0 and type(width) is int and width > 0):
        raise AssertionError(f"'height' and 'width' must be positive integers. Got {height}, {width}.")
    _joint_range_check(cut_size, 'cut_size', bounds=(0, 1))
    _common_param_check(batch_size, same_on_batch)

    if batch_size == 0:
        return dict(
            mix_pairs=torch.zeros([0, 3], device=_device, dtype=torch.long),
            crop_src=torch.zeros([0, 4, 2], device=_device, dtype=torch.long),
        )

    batch_probs: torch.Tensor = random_prob_generator(
        batch_size * num_mix, p, same_on_batch, device=device, dtype=dtype)
    mix_pairs: torch.Tensor = torch.rand(num_mix, batch_size, device=device, dtype=dtype).argsort(dim=1)
    cutmix_betas: torch.Tensor = _adapted_beta((batch_size * num_mix,), beta, beta, same_on_batch=same_on_batch)

    # Note: torch.clamp does not accept tensor bounds; cutmix_betas.clamp(cut_size[0], cut_size[1]) throws:
    # Argument 1 to "clamp" of "_TensorBase" has incompatible type "Tensor"; expected "float"
    cutmix_betas = torch.min(torch.max(cutmix_betas, cut_size[0]), cut_size[1])
    cutmix_rate = torch.sqrt(1.0 - cutmix_betas) * batch_probs

    cut_height = (cutmix_rate * height).floor().to(device=device, dtype=_dtype)
    cut_width = (cutmix_rate * width).floor().to(device=device, dtype=_dtype)
    _gen_shape = (1,)

    if same_on_batch:
        _gen_shape = (cut_height.size(0),)
        cut_height = cut_height[0]
        cut_width = cut_width[0]

    # Reserve at least 1 pixel for cropping.
    x_start = (_adapted_uniform(
        _gen_shape,
        torch.zeros_like(cut_width, device=device, dtype=dtype),
        (width - cut_width - 1).to(device=device, dtype=dtype),
        same_on_batch,
    ).floor().to(device=device, dtype=_dtype))
    y_start = (_adapted_uniform(
        _gen_shape,
        torch.zeros_like(cut_height, device=device, dtype=dtype),
        (height - cut_height - 1).to(device=device, dtype=dtype),
        same_on_batch,
    ).floor().to(device=device, dtype=_dtype))

    crop_src = bbox_generator(x_start.squeeze(), y_start.squeeze(), cut_width, cut_height)

    # (B * num_mix, 4, 2) => (num_mix, batch_size, 4, 2)
    crop_src = crop_src.view(num_mix, batch_size, 4, 2)

    return dict(
        mix_pairs=mix_pairs.to(device=_device, dtype=torch.long),
        crop_src=crop_src.floor().to(device=_device, dtype=_dtype),
    )
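
# Usage sketch for `random_cutmix_generator`: `mix_pairs` indexes the partner
# image for each sample and `crop_src` gives the four corners of the region to
# paste. The seed and sizes below are arbitrary choices for illustration.
def _example_random_cutmix() -> None:
    torch.manual_seed(0)
    params = random_cutmix_generator(4, width=64, height=64, p=1.0, num_mix=1)
    assert params["mix_pairs"].shape == (1, 4)
    assert params["crop_src"].shape == (1, 4, 4, 2)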
def random_crop_generator(
    batch_size: int,
    input_size: Tuple[int, int],
    size: Union[Tuple[int, int], torch.Tensor],
    resize_to: Optional[Tuple[int, int]] = None,
    same_on_batch: bool = False,
    device: torch.device = torch.device('cpu'),
    dtype: torch.dtype = torch.float32,
) -> Dict[str, torch.Tensor]:
    r"""Get parameters for the ``crop`` transformation.

    Args:
        batch_size (int): the tensor batch size.
        input_size (tuple): input image shape, like (h, w).
        size (tuple): desired size of the crop operation, like (h, w). If tensor, it must be (B, 2).
        resize_to (tuple): desired output size of the crop, like (h, w). If None, no resize will be performed.
        same_on_batch (bool): apply the same transformation across the batch. Default: False.
        device (torch.device): the device on which the random numbers will be generated. Default: cpu.
        dtype (torch.dtype): the data type of the generated random numbers. Default: float32.

    Returns:
        params Dict[str, torch.Tensor]: parameters to be passed for transformation.
            - src (torch.Tensor): cropping bounding boxes with a shape of (B, 4, 2).
            - dst (torch.Tensor): output bounding boxes with a shape of (B, 4, 2).
            - input_size (torch.Tensor): the input image sizes with a shape of (B, 2).

    Note:
        The generated random numbers are not reproducible across different devices and dtypes.

    Example:
        >>> _ = torch.manual_seed(0)
        >>> crop_size = torch.tensor([[25, 28], [27, 29], [26, 28]])
        >>> random_crop_generator(3, (30, 30), size=crop_size, same_on_batch=False)
        {'src': tensor([[[ 1.,  0.],
                 [28.,  0.],
                 [28., 24.],
                 [ 1., 24.]],
        <BLANKLINE>
                [[ 1.,  1.],
                 [29.,  1.],
                 [29., 27.],
                 [ 1., 27.]],
        <BLANKLINE>
                [[ 0.,  3.],
                 [27.,  3.],
                 [27., 28.],
                 [ 0., 28.]]]), 'dst': tensor([[[ 0.,  0.],
                 [27.,  0.],
                 [27., 24.],
                 [ 0., 24.]],
        <BLANKLINE>
                [[ 0.,  0.],
                 [28.,  0.],
                 [28., 26.],
                 [ 0., 26.]],
        <BLANKLINE>
                [[ 0.,  0.],
                 [27.,  0.],
                 [27., 25.],
                 [ 0., 25.]]]), 'input_size': tensor([[30, 30],
                [30, 30],
                [30, 30]])}
    """
    _common_param_check(batch_size, same_on_batch)
    _device, _dtype = _extract_device_dtype([size if isinstance(size, torch.Tensor) else None])
    # Use a floating point dtype for the sampling arithmetic below.
    _dtype = _dtype if _dtype in [torch.float16, torch.float32, torch.float64] else dtype
    if not isinstance(size, torch.Tensor):
        size = torch.tensor(size, device=_device, dtype=_dtype).repeat(batch_size, 1)
    else:
        size = size.to(device=_device, dtype=_dtype)
    if size.shape != torch.Size([batch_size, 2]):
        raise AssertionError(
            "If `size` is a tensor, it must be shaped as (B, 2). "
            f"Got {size.shape} while expecting {torch.Size([batch_size, 2])}.")
    if not (input_size[0] > 0 and input_size[1] > 0 and (size > 0).all()):
        raise AssertionError(f"Got non-positive input size or size. {input_size}, {size}.")
    size = size.floor()

    x_diff = input_size[1] - size[:, 1] + 1
    y_diff = input_size[0] - size[:, 0] + 1

    # Start point will be 0 if diff < 0.
    x_diff = x_diff.clamp(0)
    y_diff = y_diff.clamp(0)

    if batch_size == 0:
        return dict(
            src=torch.zeros([0, 4, 2], device=_device, dtype=_dtype),
            dst=torch.zeros([0, 4, 2], device=_device, dtype=_dtype),
        )

    if same_on_batch:
        # If same_on_batch, sample once from the first diff and repeat across the batch.
        x_start = _adapted_uniform((batch_size,), 0, x_diff[0].to(device=device, dtype=dtype), same_on_batch).floor()
        y_start = _adapted_uniform((batch_size,), 0, y_diff[0].to(device=device, dtype=dtype), same_on_batch).floor()
    else:
        x_start = _adapted_uniform((1,), 0, x_diff.to(device=device, dtype=dtype), same_on_batch).floor()
        y_start = _adapted_uniform((1,), 0, y_diff.to(device=device, dtype=dtype), same_on_batch).floor()

    crop_src = bbox_generator(
        x_start.view(-1).to(device=_device, dtype=_dtype),
        y_start.view(-1).to(device=_device, dtype=_dtype),
        torch.where(size[:, 1] == 0, torch.tensor(input_size[1], device=_device, dtype=_dtype), size[:, 1]),
        torch.where(size[:, 0] == 0, torch.tensor(input_size[0], device=_device, dtype=_dtype), size[:, 0]),
    )

    if resize_to is None:
        crop_dst = bbox_generator(
            torch.tensor([0] * batch_size, device=_device, dtype=_dtype),
            torch.tensor([0] * batch_size, device=_device, dtype=_dtype),
            size[:, 1],
            size[:, 0],
        )
    else:
        if not (
            len(resize_to) == 2
            and isinstance(resize_to[0], (int,))
            and isinstance(resize_to[1], (int,))
            and resize_to[0] > 0
            and resize_to[1] > 0
        ):
            raise AssertionError(f"`resize_to` must be a tuple of 2 positive integers. Got {resize_to}.")
        crop_dst = torch.tensor(
            [[
                [0, 0],
                [resize_to[1] - 1, 0],
                [resize_to[1] - 1, resize_to[0] - 1],
                [0, resize_to[0] - 1],
            ]],
            device=_device,
            dtype=_dtype,
        ).repeat(batch_size, 1, 1)

    _input_size = torch.tensor(input_size, device=_device, dtype=torch.long).expand(batch_size, -1)

    return dict(src=crop_src, dst=crop_dst, input_size=_input_size)
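
# Usage sketch for `random_crop_generator` with per-sample crop sizes and a
# fixed resize target. The sizes below are arbitrary choices for illustration;
# `src` holds the sampled crop corners inside the input and `dst` the corners
# of the (resized) output box.
def _example_random_crop() -> None:
    torch.manual_seed(0)
    crop_size = torch.tensor([[25, 28], [27, 29]])
    params = random_crop_generator(2, (30, 30), size=crop_size, resize_to=(32, 32))
    assert params["src"].shape == (2, 4, 2)
    assert params["dst"].shape == (2, 4, 2)
    assert params["input_size"].shape == (2, 2)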