Ejemplo n.º 1
0
    def apply_transform(  # type: ignore
        self, input: torch.Tensor, label: torch.Tensor, params: Dict[str, torch.Tensor]  # type: ignore
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Paste CutMix patches between batch samples and record how the labels were mixed.

        Args:
            input: image batch. Dims 2 and 3 are read as height and width, so at least
                4 dims are required (presumably ``(B, C, H, W)`` -- confirm with callers).
            label: per-sample labels, gathered along dim 0 with the pair indices.
            params: must provide ``'mix_pairs'`` with shape ``(num_mixes, B)`` and
                ``'crop_src'`` with shape ``(num_mixes, B, 4, 2)``; both shapes are validated.

        Returns:
            The mixed image batch, and a tensor of shape ``(num_mixes, B, 3)`` whose last
            axis holds ``(label, paired label, patch-area ratio)`` for every mix.
        """
        img_h, img_w = input.size(2), input.size(3)
        pairs = params['mix_pairs']
        n_mix, n_batch = pairs.size(0), pairs.size(1)

        _shape_validation(pairs, [n_mix, n_batch], 'mix_pairs')
        _shape_validation(params['crop_src'], [n_mix, n_batch, 4, 2], 'crop_src')

        n_channels = input.size(1)
        pixel_count = img_w * img_h
        mixed = input.clone()
        label_rows = []
        for idx_pair, box in zip(pairs, params['crop_src']):
            # Sample i receives its patch from sample idx_pair[i] of the ORIGINAL batch.
            donor_img = input.index_select(dim=0, index=idx_pair.to(input.device))
            donor_lbl = label.index_select(dim=0, index=idx_pair.to(label.device))
            box_w, box_h = infer_box_shape(box)
            # Fraction of the image covered by the pasted patch (the mixing weight).
            area_ratio = box_w.to(input.dtype) * box_h.to(input.dtype) / pixel_count
            # Replicate the per-sample mask over channels so it indexes the 4-D image.
            patch = bbox_to_mask(box, img_w, img_h).bool()
            patch = patch.unsqueeze(dim=1).repeat(1, n_channels, 1, 1)
            mixed[patch] = donor_img[patch]
            row = torch.stack(
                [label.to(input.dtype), donor_lbl.to(input.dtype), area_ratio.to(label.device)],
                dim=1,
            )
            label_rows.append(row)

        return mixed, torch.stack(label_rows, dim=0)
Ejemplo n.º 2
0
def apply_cutmix(
        input: torch.Tensor, labels: torch.Tensor,
        params: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Apply cutmix to images in a batch.

    CutMix augmentation strategy: patches are cut and pasted among training images where the ground
    truth labels are also mixed proportionally to the area of the patches.

    Args:
        input (torch.Tensor): Image batch with shape :math:`(B, C, H, W)`. Only 4-D input is
            supported: height and width are read from dims 2 and 3.
        labels (torch.Tensor): Label tensor with shape :math:`(B,)`.
        params (Dict[str, torch.Tensor]):
            - params['mix_pairs']: For every mix, the batch index of the image each sample takes
              its patch from, with shape :math:`(num_mixes, B)`.
            - params['crop_src']: Corner coordinates of the patch box of every sample for every
              mix, with shape :math:`(num_mixes, B, 4, 2)`.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
        - Adjusted image, shape of :math:`(B, C, H, W)`.
        - For each mix, the original label, the paired label and the patch-area ratio
          (lambda), shape of :math:`(num_mixes, B, 3)`.

    Examples:
        >>> input = torch.stack([torch.zeros(1, 5, 5), torch.ones(1, 5, 5)], dim=0)
        >>> labels = torch.tensor([0, 1])
        >>> params = {'mix_pairs': torch.tensor([[1, 0]]), 'crop_src': torch.tensor([[[
        ...        [1., 1.],
        ...        [2., 1.],
        ...        [2., 2.],
        ...        [1., 2.]],
        ...       [[1., 1.],
        ...        [3., 1.],
        ...        [3., 2.],
        ...        [1., 2.]]]])}
        >>> apply_cutmix(input, labels, params)
        (tensor([[[[0., 0., 0., 0., 0.],
                  [0., 1., 1., 0., 0.],
                  [0., 1., 1., 0., 0.],
                  [0., 0., 0., 0., 0.],
                  [0., 0., 0., 0., 0.]]],
        <BLANKLINE>
        <BLANKLINE>
                [[[1., 1., 1., 1., 1.],
                  [1., 0., 0., 0., 1.],
                  [1., 0., 0., 0., 1.],
                  [1., 1., 1., 1., 1.],
                  [1., 1., 1., 1., 1.]]]]), tensor([[[0.0000, 1.0000, 0.1600],
                 [1.0000, 0.0000, 0.2400]]]))
    """
    height, width = input.size(2), input.size(3)
    num_mixes = params['mix_pairs'].size(0)
    batch_size = params['mix_pairs'].size(1)

    _shape_validation(params['mix_pairs'], [num_mixes, batch_size],
                      'mix_pairs')
    _shape_validation(params['crop_src'], [num_mixes, batch_size, 4, 2],
                      'crop_src')

    out_inputs = input.clone()
    out_labels = []
    for pair, crop in zip(params['mix_pairs'], params['crop_src']):
        # Patches always come from the ORIGINAL batch, not from earlier mixes.
        input_permute = input.index_select(dim=0, index=pair.to(input.device))
        labels_permute = labels.index_select(dim=0,
                                             index=pair.to(labels.device))
        w, h = infer_box_shape(crop)
        # lambda: fraction of the image area covered by the pasted patch.
        lam = w.to(input.dtype) * h.to(input.dtype) / (width * height)
        # (B, 1, H, W) mask, broadcast over channels by torch.where instead of being copied
        # per channel. It is moved to the image device because `crop` may live elsewhere
        # (the pair indices are moved the same way above).
        mask = bbox_to_mask(crop, width,
                            height).bool().unsqueeze(dim=1).to(input.device)
        out_inputs = torch.where(mask, input_permute, out_inputs)
        out_labels.append(
            torch.stack([
                labels.to(input.dtype),
                labels_permute.to(input.dtype),
                lam.to(labels.device)
            ],
                        dim=1))

    return out_inputs, torch.stack(out_labels, dim=0)