def apply_transform(  # type: ignore
    self, input: torch.Tensor, label: torch.Tensor, params: Dict[str, torch.Tensor]  # type: ignore
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply CutMix: paste random crops between paired images and mix the labels by patch area."""
    height, width = input.size(2), input.size(3)
    num_mixes = params['mix_pairs'].size(0)
    batch_size = params['mix_pairs'].size(1)

    _shape_validation(params['mix_pairs'], [num_mixes, batch_size], 'mix_pairs')
    _shape_validation(params['crop_src'], [num_mixes, batch_size, 4, 2], 'crop_src')

    out_inputs = input.clone()
    out_labels = []
    for pair, crop in zip(params['mix_pairs'], params['crop_src']):
        input_permute = input.index_select(dim=0, index=pair.to(input.device))
        labels_permute = label.index_select(dim=0, index=pair.to(label.device))
        w, h = infer_box_shape(crop)
        # lambda is the patch area fraction: width_beta * height_beta
        lam = w.to(input.dtype) * h.to(input.dtype) / (width * height)
        # compute mask to match input shape
        mask = bbox_to_mask(crop, width, height).bool().unsqueeze(dim=1).repeat(1, input.size(1), 1, 1)
        out_inputs[mask] = input_permute[mask]
        out_labels.append(
            torch.stack([label.to(input.dtype), labels_permute.to(input.dtype), lam.to(label.device)], dim=1)
        )
    return out_inputs, torch.stack(out_labels, dim=0)
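
# Illustrative sketch (not part of the library API): the core CutMix step above is a
# masked copy between a batch and a permuted view of itself. This minimal pure-torch
# demo reproduces that step with a hand-built mask in place of bbox_to_mask; the
# function name and fixed 2x2 patch are hypothetical, chosen for illustration only.
def _demo_masked_patch_swap() -> torch.Tensor:
    images = torch.stack([torch.zeros(1, 4, 4), torch.ones(1, 4, 4)])  # batch of 2
    permuted = images.flip(0)  # pair each image with the other one
    mask = torch.zeros(2, 1, 4, 4, dtype=torch.bool)
    mask[:, :, 1:3, 1:3] = True  # a 2x2 patch, like the mask bbox_to_mask would produce
    out = images.clone()
    out[mask] = permuted[mask]  # paste the patch from the paired image
    return out  # the zeros image now holds a ones patch, and vice versa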
def apply_erase_rectangles(input: torch.Tensor, params: Dict[str, torch.Tensor]) -> torch.Tensor:
    r"""Apply rectangle erasing by params.

    Generates a {0, 1} mask with drawn rectangles whose parameters are defined by
    params and whose size matches input.size().

    Args:
        input (torch.Tensor): Tensor to be transformed with shape :math:`(B, C, H, W)`.
        params (Dict[str, torch.Tensor]):
            - params['widths']: widths tensor
            - params['heights']: heights tensor
            - params['xs']: x positions tensor
            - params['ys']: y positions tensor
            - params['values']: the value to fill in

    Returns:
        torch.Tensor: Erased image with shape :math:`(B, C, H, W)`.
    """
    if not (params['widths'].size() == params['heights'].size() == params['xs'].size() == params['ys'].size()):
        raise TypeError(
            "rectangle params components must have same shape. "
            f"Got ({params['widths'].size()}, {params['heights'].size()}) "
            f"and ({params['xs'].size()}, {params['ys'].size()})"
        )

    values = params['values'].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, *input.shape[1:]).to(input)

    _, c, h, w = input.size()

    bboxes = bbox_generator(params['xs'], params['ys'], params['widths'], params['heights'])
    mask = bbox_to_mask(bboxes, w, h)  # returns (B, H, W)
    mask = mask.unsqueeze(1).repeat(1, c, 1, 1).to(input)  # broadcast to (B, C, H, W)
    transformed = torch.where(mask == 1., values, input)
    return transformed
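
# Illustrative usage sketch (not part of the library API): erases a 2x2 patch at
# (x, y) = (1, 1) in each image of a batch, filling with zeros. The demo name is
# hypothetical; the asserted window assumes the inclusive-corner semantics of the
# bbox_generator / bbox_to_mask helpers imported by this module.
def _demo_apply_erase_rectangles() -> None:
    input = torch.ones(2, 3, 5, 5)
    params = {
        'widths': torch.tensor([2, 2]),
        'heights': torch.tensor([2, 2]),
        'xs': torch.tensor([1, 1]),
        'ys': torch.tensor([1, 1]),
        'values': torch.tensor([0., 0.]),  # one fill value per batch element
    }
    out = apply_erase_rectangles(input, params)
    assert out.shape == (2, 3, 5, 5)
    assert (out[:, :, 1:3, 1:3] == 0.).all()  # the erased window is filled with 0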
def apply_cutmix(
    input: torch.Tensor, labels: torch.Tensor, params: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Apply cutmix to images in a batch.

    CutMix augmentation strategy: patches are cut and pasted among training images,
    and the ground truth labels are mixed proportionally to the area of the patches.

    Args:
        input (torch.Tensor): Tensor to be transformed with shape :math:`(B, C, H, W)`.
        labels (torch.Tensor): Label tensor with shape :math:`(B,)`.
        params (Dict[str, torch.Tensor]):
            - params['mix_pairs']: Mixup indexes with shape :math:`(\text{num_mixes}, B)`.
            - params['crop_src']: Coordinates of the source crops with shape
              :math:`(\text{num_mixes}, B, 4, 2)`.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            - Adjusted image, shape of :math:`(B, C, H, W)`.
            - Original label, permuted label and lambda for each mix,
              shape of :math:`(\text{num_mixes}, B, 3)`.

    Examples:
        >>> input = torch.stack([torch.zeros(1, 5, 5), torch.ones(1, 5, 5)], dim=0)
        >>> labels = torch.tensor([0, 1])
        >>> params = {'mix_pairs': torch.tensor([[1, 0]]), 'crop_src': torch.tensor([[[
        ...     [1., 1.],
        ...     [2., 1.],
        ...     [2., 2.],
        ...     [1., 2.]],
        ...    [[1., 1.],
        ...     [3., 1.],
        ...     [3., 2.],
        ...     [1., 2.]]]])}
        >>> apply_cutmix(input, labels, params)
        (tensor([[[[0., 0., 0., 0., 0.],
                  [0., 1., 1., 0., 0.],
                  [0., 1., 1., 0., 0.],
                  [0., 0., 0., 0., 0.],
                  [0., 0., 0., 0., 0.]]],
        <BLANKLINE>
        <BLANKLINE>
                [[[1., 1., 1., 1., 1.],
                  [1., 0., 0., 0., 1.],
                  [1., 0., 0., 0., 1.],
                  [1., 1., 1., 1., 1.],
                  [1., 1., 1., 1., 1.]]]]), tensor([[[0.0000, 1.0000, 0.1600],
                 [1.0000, 0.0000, 0.2400]]]))
    """
    height, width = input.size(2), input.size(3)
    num_mixes = params['mix_pairs'].size(0)
    batch_size = params['mix_pairs'].size(1)

    _shape_validation(params['mix_pairs'], [num_mixes, batch_size], 'mix_pairs')
    _shape_validation(params['crop_src'], [num_mixes, batch_size, 4, 2], 'crop_src')

    out_inputs = input.clone()
    out_labels = []
    for pair, crop in zip(params['mix_pairs'], params['crop_src']):
        input_permute = input.index_select(dim=0, index=pair.to(input.device))
        labels_permute = labels.index_select(dim=0, index=pair.to(labels.device))
        w, h = infer_box_shape(crop)
        # lambda is the patch area fraction: width_beta * height_beta
        lam = w.to(input.dtype) * h.to(input.dtype) / (width * height)
        # compute mask to match input shape
        mask = bbox_to_mask(crop, width, height).bool().unsqueeze(dim=1).repeat(1, input.size(1), 1, 1)
        out_inputs[mask] = input_permute[mask]
        out_labels.append(
            torch.stack([labels.to(input.dtype), labels_permute.to(input.dtype), lam.to(labels.device)], dim=1)
        )
    return out_inputs, torch.stack(out_labels, dim=0)
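
# Illustrative sketch (not part of the library API): one way the (label_a, label_b, lam)
# triples returned by apply_cutmix are typically consumed in a classification loss.
# `model` and `criterion` are hypothetical placeholders; `criterion` is assumed to be a
# per-sample loss such as torch.nn.CrossEntropyLoss(reduction='none'), so that each
# sample can be weighted by its own lambda. Note the convention of this function:
# lam is the fraction of the image covered by the *permuted* patch, so the permuted
# label gets weight lam and the original label gets weight (1 - lam).
def _demo_cutmix_loss(
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    mixed_input: torch.Tensor,
    mix_labels: torch.Tensor,
) -> torch.Tensor:
    logits = model(mixed_input)
    # use the first (and often only) mix; columns are (original, permuted, lambda)
    label_a, label_b, lam = mix_labels[0, :, 0], mix_labels[0, :, 1], mix_labels[0, :, 2]
    loss_a = criterion(logits, label_a.long())  # .long() undoes the float cast of the labels
    loss_b = criterion(logits, label_b.long())
    return ((1 - lam) * loss_a + lam * loss_b).mean()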