Example #1
    def apply_transform(  # type: ignore
        self, input: torch.Tensor, label: torch.Tensor, params: Dict[str, torch.Tensor]  # type: ignore
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply CutMix to the batch: paste crops taken from permuted batch elements
        into ``input`` and return the mixed images together with
        (label, permuted label, lambda) triplets for each mix."""
        height, width = input.size(2), input.size(3)
        num_mixes = params['mix_pairs'].size(0)
        batch_size = params['mix_pairs'].size(1)

        _shape_validation(params['mix_pairs'], [num_mixes, batch_size], 'mix_pairs')
        _shape_validation(params['crop_src'], [num_mixes, batch_size, 4, 2], 'crop_src')

        out_inputs = input.clone()
        out_labels = []
        for pair, crop in zip(params['mix_pairs'], params['crop_src']):
            input_permute = input.index_select(dim=0, index=pair.to(input.device))
            labels_permute = label.index_select(dim=0, index=pair.to(label.device))
            w, h = infer_box_shape(crop)
            lam = w.to(input.dtype) * h.to(input.dtype) / (width * height)  # width_beta * height_beta
            # compute mask to match input shape
            mask = bbox_to_mask(crop, width, height).bool().unsqueeze(dim=1).repeat(1, input.size(1), 1, 1)
            out_inputs[mask] = input_permute[mask]
            out_labels.append(
                torch.stack([label.to(input.dtype), labels_permute.to(input.dtype), lam.to(label.device)], dim=1)
            )

        return out_inputs, torch.stack(out_labels, dim=0)
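The mixing weight ``lam`` above is simply the area of the pasted crop divided by the image area. A quick sketch of that arithmetic (illustrative only, not part of the original source), using the 2x2 box from the doctest in Example #3 below:

# Illustrative check of the area-fraction formula used for ``lam``:
# a 2x2 crop pasted into a 5x5 image mixes the labels with weight
# 4 / 25 = 0.16, matching the doctest output in Example #3.
import torch

box_w, box_h = torch.tensor(2.0), torch.tensor(2.0)
img_w, img_h = 5, 5
lam = box_w * box_h / (img_w * img_h)
assert torch.isclose(lam, torch.tensor(0.16))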
Example #2
def apply_erase_rectangles(input: torch.Tensor, params: Dict[str, torch.Tensor]) -> torch.Tensor:
    r"""Apply rectangle erase by params.

    Generate a {0, 1} mask with drawn rectangles whose parameters are defined by ``params`` and whose size matches ``input.size()``.

    Args:
        input (torch.Tensor): Tensor to be transformed with shape :math:`(B, C, H, W)`.
        params (Dict[str, torch.Tensor]):
            - params['widths']: rectangle widths, shape :math:`(B,)`.
            - params['heights']: rectangle heights, shape :math:`(B,)`.
            - params['xs']: x coordinates of the top-left corners, shape :math:`(B,)`.
            - params['ys']: y coordinates of the top-left corners, shape :math:`(B,)`.
            - params['values']: fill value for each image, shape :math:`(B,)`.

    Returns:
        torch.Tensor: Erased image with shape :math:`(B, C, H, W)`.
    """
    if not (params['widths'].size() == params['heights'].size() == params['xs'].size() == params['ys'].size()):
        raise TypeError(
            "rectangle params components must have same shape. "
            f"Got ({params['widths'].size()}, {params['heights'].size()}) "
            f"and ({params['xs'].size()}, {params['ys'].size()})"
        )

    values = params['values'].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, *input.shape[1:]).to(input)

    _, c, h, w = input.size()

    bboxes = bbox_generator(params['xs'], params['ys'], params['widths'], params['heights'])
    mask = bbox_to_mask(bboxes, w, h)  # Returns B, H, W
    mask = mask.unsqueeze(1).repeat(1, c, 1, 1).to(input)  # Transform to B, c, H, W
    transformed = torch.where(mask == 1., values, input)
    return transformed
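A minimal usage sketch for ``apply_erase_rectangles`` (an illustration, not from the original source; it assumes the ``bbox_generator`` and ``bbox_to_mask`` helpers referenced above are available in the surrounding module): erase one rectangle per image in a batch of two.

import torch

imgs = torch.ones(2, 3, 5, 5)
erase_params = {
    'xs': torch.tensor([1., 0.]),       # top-left x of each rectangle
    'ys': torch.tensor([1., 2.]),       # top-left y of each rectangle
    'widths': torch.tensor([3., 2.]),   # rectangle widths
    'heights': torch.tensor([2., 2.]),  # rectangle heights
    'values': torch.tensor([0., 0.5]),  # fill value per image
}
erased = apply_erase_rectangles(imgs, erase_params)  # -> shape (2, 3, 5, 5)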
Example #3
def apply_cutmix(
        input: torch.Tensor, labels: torch.Tensor,
        params: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Apply cutmix to images in a batch.

    CutMix augmentation strategy: patches are cut and pasted among training images where the ground
    truth labels are also mixed proportionally to the area of the patches.

    Args:
        input (torch.Tensor): Tensor to be transformed with shape (B, C, H, W).
        labels (torch.Tensor): Label tensor with shape (B,).
        params (Dict[str, torch.Tensor]):
            - params['mix_pairs']: Mixup indexes with shape (num_mixes, B).
            - params['crop_src']: Source crop boxes as corner coordinates, with shape (num_mixes, B, 4, 2).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
        - Adjusted image, shape of :math:`(B, C, H, W)`.
        - Original label, permuted label and lambda for each mix, shape of :math:`(num_mixes, B, 3)`.

    Examples:
        >>> input = torch.stack([torch.zeros(1, 5, 5), torch.ones(1, 5, 5)], dim=0)
        >>> labels = torch.tensor([0, 1])
        >>> params = {'mix_pairs': torch.tensor([[1, 0]]), 'crop_src': torch.tensor([[[
        ...        [1., 1.],
        ...        [2., 1.],
        ...        [2., 2.],
        ...        [1., 2.]],
        ...       [[1., 1.],
        ...        [3., 1.],
        ...        [3., 2.],
        ...        [1., 2.]]]])}
        >>> apply_cutmix(input, labels, params)
        (tensor([[[[0., 0., 0., 0., 0.],
                  [0., 1., 1., 0., 0.],
                  [0., 1., 1., 0., 0.],
                  [0., 0., 0., 0., 0.],
                  [0., 0., 0., 0., 0.]]],
        <BLANKLINE>
        <BLANKLINE>
                [[[1., 1., 1., 1., 1.],
                  [1., 0., 0., 0., 1.],
                  [1., 0., 0., 0., 1.],
                  [1., 1., 1., 1., 1.],
                  [1., 1., 1., 1., 1.]]]]), tensor([[[0.0000, 1.0000, 0.1600],
                 [1.0000, 0.0000, 0.2400]]]))
    """
    height, width = input.size(2), input.size(3)
    num_mixes = params['mix_pairs'].size(0)
    batch_size = params['mix_pairs'].size(1)

    _shape_validation(params['mix_pairs'], [num_mixes, batch_size], 'mix_pairs')
    _shape_validation(params['crop_src'], [num_mixes, batch_size, 4, 2], 'crop_src')

    out_inputs = input.clone()
    out_labels = []
    for pair, crop in zip(params['mix_pairs'], params['crop_src']):
        input_permute = input.index_select(dim=0, index=pair.to(input.device))
        labels_permute = labels.index_select(dim=0, index=pair.to(labels.device))
        w, h = infer_box_shape(crop)
        lam = w.to(input.dtype) * h.to(input.dtype) / (width * height)  # width_beta * height_beta
        # compute mask to match input shape
        mask = bbox_to_mask(crop, width, height).bool().unsqueeze(dim=1).repeat(1, input.size(1), 1, 1)
        out_inputs[mask] = input_permute[mask]
        out_labels.append(
            torch.stack([labels.to(input.dtype), labels_permute.to(input.dtype), lam.to(labels.device)], dim=1)
        )

    return out_inputs, torch.stack(out_labels, dim=0)
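One way the ``(num_mixes, B, 3)`` output could be consumed downstream is a mixed classification loss that weights the two labels by the cut area. This is a hypothetical sketch, not part of the original source: ``cutmix_loss`` and its use of only the first mix are assumptions.

import torch
import torch.nn.functional as F

def cutmix_loss(logits: torch.Tensor, out_labels: torch.Tensor) -> torch.Tensor:
    # out_labels[0] holds (original label, permuted label, lam) for the first mix
    label_a = out_labels[0, :, 0].long()
    label_b = out_labels[0, :, 1].long()
    lam = out_labels[0, :, 2]
    loss_a = F.cross_entropy(logits, label_a, reduction='none')
    loss_b = F.cross_entropy(logits, label_b, reduction='none')
    # pixels inside the pasted crop come from the permuted sample, so its label
    # gets weight ``lam`` while the original label gets ``1 - lam``
    return ((1 - lam) * loss_a + lam * loss_b).mean()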