def apply_transform(  # type: ignore
    self, input: torch.Tensor, label: torch.Tensor, params: Dict[str, torch.Tensor]  # type: ignore
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Paste box-shaped patches between batch samples and record the mixed labels.

    For every mix described in ``params``, each sample receives the pixels of
    its partner sample inside a box region, and a label record is built from
    the sample's own label, the partner's label and the relative box area.

    Args:
        input: Image batch. Dimensions 2 and 3 are read as height and width,
            and dimension 1 is used as the channel count when building the
            mask, so a 4-D ``(B, C, H, W)`` tensor is expected.
        label: Per-sample labels, indexed along dimension 0.
            NOTE(review): presumably shape ``(B,)`` -- confirm with callers.
        params: Must provide ``'mix_pairs'`` (partner indices, validated as
            ``(num_mixes, B)``) and ``'crop_src'`` (boxes, validated as
            ``(num_mixes, B, 4, 2)``).

    Returns:
        A tuple ``(images, labels)``: a copy of ``input`` with the boxed
        regions replaced by the partner samples' pixels, and the per-mix
        label records stacked along a new leading dimension. Each record
        holds ``[label, partner_label, lam]`` along dimension 1, where
        ``lam`` is the box area divided by ``width * height``.
    """
    img_h, img_w = input.size(2), input.size(3)

    mix_pairs = params['mix_pairs']
    n_mix, n_batch = mix_pairs.size(0), mix_pairs.size(1)
    _shape_validation(mix_pairs, [n_mix, n_batch], 'mix_pairs')

    crop_src = params['crop_src']
    _shape_validation(crop_src, [n_mix, n_batch, 4, 2], 'crop_src')

    mixed = input.clone()
    records = []
    for partner_idx, boxes in zip(mix_pairs, crop_src):
        # Gather the partner sample (and its label) for every batch position.
        donor_images = input.index_select(dim=0, index=partner_idx.to(input.device))
        donor_labels = label.index_select(dim=0, index=partner_idx.to(label.device))

        # Fraction of the image covered by the box (width_beta * height_beta).
        box_w, box_h = infer_box_shape(boxes)
        area_ratio = box_w.to(input.dtype) * box_h.to(input.dtype) / (img_w * img_h)

        # Per-sample box mask, replicated over channels to match the input shape.
        region = bbox_to_mask(boxes, img_w, img_h).bool()
        region = region.unsqueeze(dim=1).repeat(1, input.size(1), 1, 1)
        mixed[region] = donor_images[region]

        record = torch.stack(
            [label.to(input.dtype), donor_labels.to(input.dtype), area_ratio.to(label.device)],
            dim=1,
        )
        records.append(record)

    return mixed, torch.stack(records, dim=0)
def apply_cutmix(
    input: torch.Tensor, labels: torch.Tensor, params: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Apply cutmix to images in a batch.

    CutMix augmentation strategy: patches are cut and pasted among training images where the ground
    truth labels are also mixed proportionally to the area of the patches.

    Args:
        input (torch.Tensor): Tensor to be transformed with shape (B, C, H, W).
        labels (torch.Tensor): Label tensor with shape (B,).
        params (Dict[str, torch.Tensor]):
            - params['mix_pairs']: Indexes of the sample to paste from, with shape (num_mixes, B).
            - params['crop_src']: Boxes to paste, given as four (x, y) corner points per sample,
              with shape (num_mixes, B, 4, 2).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            - Adjusted image, shape of :math:`(B, C, H, W)`.
            - For each mix, the original label, the label of the pasted sample and the area
              ratio ``lam`` of the pasted box, shape of :math:`(num_mixes, B, 3)`.

    Examples:
        >>> input = torch.stack([torch.zeros(1, 5, 5), torch.ones(1, 5, 5)], dim=0)
        >>> labels = torch.tensor([0, 1])
        >>> params = {'mix_pairs': torch.tensor([[1, 0]]), 'crop_src': torch.tensor([[[
        ...        [1., 1.],
        ...        [2., 1.],
        ...        [2., 2.],
        ...        [1., 2.]],
        ...       [[1., 1.],
        ...        [3., 1.],
        ...        [3., 2.],
        ...        [1., 2.]]]])}
        >>> apply_cutmix(input, labels, params)
        (tensor([[[[0., 0., 0., 0., 0.],
                  [0., 1., 1., 0., 0.],
                  [0., 1., 1., 0., 0.],
                  [0., 0., 0., 0., 0.],
                  [0., 0., 0., 0., 0.]]],
        <BLANKLINE>
        <BLANKLINE>
                [[[1., 1., 1., 1., 1.],
                  [1., 0., 0., 0., 1.],
                  [1., 0., 0., 0., 1.],
                  [1., 1., 1., 1., 1.],
                  [1., 1., 1., 1., 1.]]]]), tensor([[[0.0000, 1.0000, 0.1600],
                 [1.0000, 0.0000, 0.2400]]]))
    """
    height, width = input.size(2), input.size(3)
    num_mixes = params['mix_pairs'].size(0)
    batch_size = params['mix_pairs'].size(1)

    _shape_validation(params['mix_pairs'], [num_mixes, batch_size], 'mix_pairs')
    _shape_validation(params['crop_src'], [num_mixes, batch_size, 4, 2], 'crop_src')

    # The un-permuted labels do not change between mixes; cast them once.
    labels_cast = labels.to(input.dtype)

    # No clone needed: torch.where allocates a fresh tensor on every mix, so
    # the caller's ``input`` is never written to.
    out_inputs = input
    out_labels = []
    for pair, crop in zip(params['mix_pairs'], params['crop_src']):
        input_permute = input.index_select(dim=0, index=pair.to(input.device))
        labels_permute = labels.index_select(dim=0, index=pair.to(labels.device))

        w, h = infer_box_shape(crop)
        lam = w.to(input.dtype) * h.to(input.dtype) / (width * height)  # width_beta * height_beta

        # (B, 1, H, W) mask; broadcasting over channels replaces an explicit
        # per-channel repeat. torch.where needs all operands on one device.
        mask = bbox_to_mask(crop, width, height).bool().unsqueeze(dim=1).to(input.device)
        out_inputs = torch.where(mask, input_permute, out_inputs)

        out_labels.append(
            torch.stack([labels_cast, labels_permute.to(input.dtype), lam.to(labels.device)], dim=1)
        )

    return out_inputs, torch.stack(out_labels, dim=0)