Example #1
def bert_pad(coll, from_left, seq_length):
    '''Zero-pad a 2-D tensor along its first (token) dimension until it
    has seq_length rows.

    coll: tensor of shape (tokens, hidden)
    from_left: if True, prepend zero rows; otherwise append them
    seq_length: target number of rows
    '''
    while len(coll) < seq_length:
        if from_left:
            # [left, right, top, bottom]: add one zero row at the top
            coll = torch_pad(coll, [0, 0, 1, 0], mode='constant', value=0)
        else:
            # add one zero row at the bottom
            coll = torch_pad(coll, [0, 0, 0, 1], mode='constant', value=0)
    return coll
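A minimal usage sketch (assuming torch_pad is an alias for torch.nn.functional.pad, as the call signature suggests, and that coll is a 2-D (tokens, hidden) tensor):

import torch
from torch.nn.functional import pad as torch_pad

emb = torch.randn(3, 8)                              # 3 tokens, hidden size 8
padded = bert_pad(emb, from_left=True, seq_length=5)
print(padded.shape)                                  # torch.Size([5, 8])
print(padded[:2].abs().sum().item())                 # 0.0 -- zero rows were prepended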
Example #2
def gaussian_blur(img: Tensor, kernel_size: List[int],
                  sigma: List[float]) -> Tensor:
    if not isinstance(img, torch.Tensor):
        raise TypeError(f"img should be Tensor. Got {type(img)}")

    _assert_image_tensor(img)

    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    kernel = _get_gaussian_kernel2d(kernel_size,
                                    sigma,
                                    dtype=dtype,
                                    device=img.device)
    # Expand to a depthwise kernel of shape (C, 1, kH, kW) for grouped conv2d.
    kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])

    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
        img,
        [
            kernel.dtype,
        ],
    )

    # padding = (left, right, top, bottom)
    padding = [
        kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2,
        kernel_size[1] // 2
    ]
    img = torch_pad(img, padding, mode="reflect")
    img = conv2d(img, kernel, groups=img.shape[-3])

    img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
    return img
Example #3
def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor:
    """PRIVATE METHOD. Performs Gaussian blurring on the img by given kernel.

    .. warning::

        Module ``transforms.functional_tensor`` is private and should not be used in user applications.
        Please consider using methods from the `transforms.functional` module instead.

    Args:
        img (Tensor): Image to be blurred
        kernel_size (sequence of int or int): Kernel size of the Gaussian kernel ``(kx, ky)``.
        sigma (sequence of float or float): Standard deviation of the Gaussian kernel ``(sx, sy)``.

    Returns:
        Tensor: An image blurred with a Gaussian kernel of the given parameters.
    """
    if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
        raise TypeError('img should be Tensor Image. Got {}'.format(type(img)))

    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
    kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])

    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype, ])

    # padding = (left, right, top, bottom)
    padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
    img = torch_pad(img, padding, mode="reflect")
    img = conv2d(img, kernel, groups=img.shape[-3])

    img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
    return img
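As the warning above suggests, user code should go through the public wrapper rather than this private kernel. A minimal usage sketch of that wrapper (assuming torchvision >= 0.8, where torchvision.transforms.functional.gaussian_blur dispatches to this tensor implementation):

import torch
import torchvision.transforms.functional as TF

img = torch.rand(3, 64, 64)                                   # CHW float image
blurred = TF.gaussian_blur(img, kernel_size=[5, 5], sigma=[1.5, 1.5])
print(blurred.shape)                                          # torch.Size([3, 64, 64])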
Example #4
def pad_tensor(inp, pad, value=0):
    '''Util function for padding inp tensor.
        inp: input data
        pad: padding size
        value: padding value
    '''
    return torch_pad(inp, [pad] * 4, 'constant', value)
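A minimal usage sketch (assuming torch_pad is torch.nn.functional.pad): [pad] * 4 expands to [left, right, top, bottom], so both spatial dimensions grow by 2 * pad.

import torch
from torch.nn.functional import pad as torch_pad

x = torch.ones(1, 1, 4, 4)
y = pad_tensor(x, pad=2)    # 2 pixels on each of the four borders
print(y.shape)              # torch.Size([1, 1, 8, 8])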
Example #5
def pad(img: Tensor,
        padding: List[int],
        fill: int = 0,
        padding_mode: str = "constant") -> Tensor:
    _assert_image_tensor(img)

    if not isinstance(padding, (int, tuple, list)):
        raise TypeError("Got inappropriate padding arg")
    if not isinstance(fill, (int, float)):
        raise TypeError("Got inappropriate fill arg")
    if not isinstance(padding_mode, str):
        raise TypeError("Got inappropriate padding_mode arg")

    if isinstance(padding, tuple):
        padding = list(padding)

    if isinstance(padding, list) and len(padding) not in [1, 2, 4]:
        raise ValueError(
            f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
        )

    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError(
            "Padding mode should be either constant, edge, reflect or symmetric"
        )

    p = _parse_pad_padding(padding)

    if padding_mode == "edge":
        # remap padding_mode str
        padding_mode = "replicate"
    elif padding_mode == "symmetric":
        # route to another implementation
        return _pad_symmetric(img, p)

    need_squeeze = False
    if img.ndim < 4:
        img = img.unsqueeze(dim=0)
        need_squeeze = True

    out_dtype = img.dtype
    need_cast = False
    if (padding_mode != "constant") and img.dtype not in (torch.float32,
                                                          torch.float64):
        # Temporarily cast the input tensor to float until the PyTorch
        # issue below is resolved:
        # https://github.com/pytorch/pytorch/issues/40763
        need_cast = True
        img = img.to(torch.float32)

    img = torch_pad(img, p, mode=padding_mode, value=float(fill))

    if need_squeeze:
        img = img.squeeze(dim=0)

    if need_cast:
        img = img.to(out_dtype)

    return img
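Since the helpers above (_assert_image_tensor, _parse_pad_padding, _pad_symmetric) are torchvision internals, a usage sketch is easiest through the public wrapper, which forwards to this tensor implementation (an assumption that holds for recent torchvision versions). Note the 4-element order is [left, top, right, bottom]:

import torch
import torchvision.transforms.functional as TF

img = torch.zeros(3, 4, 4, dtype=torch.uint8)
out = TF.pad(img, padding=[1, 2, 3, 4], fill=255)   # left, top, right, bottom
print(out.shape)                                    # torch.Size([3, 10, 8])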
Example #6
    def forward(self, input):
        """
        This is the fully manual implementation of the forward and backward
        passes via the torch.autograd.Function.

        :param input: the input map (e.g., an image)
        :return: the result of 2D convolution
        """
        # ctx, input, filter, bias, padding = (0, 0), stride = (1, 1),
        # args = None, out_size = None, is_manual = tensor([0]),
        # conv_index = None
        filter = self.weight
        # N - number of input maps (or images in the batch).
        # C - number of input channels.
        # H - height of the input map (e.g., height of an image).
        # W - width of the input map (e.g. width of an image).
        N, C, H, W = input.size()

        # F - number of filters.
        # C - number of channels in each filter.
        # HH - the height of the filter.
        # WW - the width of the filter (its length).
        F, C, HH, WW = filter.size()

        # Zero-pad the filter spatially to the input size so that the DCT
        # maps of the input and the filter have matching shapes.
        pad_filter_H = H - HH
        pad_filter_W = W - WW

        filter = torch_pad(filter, (0, pad_filter_W, 0, pad_filter_H),
                           'constant', 0)

        input = dct(input)
        filter = dct(filter)
        # permute from N, C, H, W to H, W, N, C
        input = input.permute(2, 3, 0, 1)
        # permute from F, C, H, W to H, W, C, F
        filter = filter.permute(2, 3, 1, 0)
        result = torch.matmul(input, filter)
        # permute from H, W, N, F to N, F, H, W
        result = result.permute(2, 3, 0, 1)
        result = idct(result)
        out_H, out_W = self.out_HW(H, W, HH, WW)
        result = result[..., :out_H, :out_W]
        if self.bias is not None:
            # Add the bias term for each filter (it has to be unsqueezed to
            # the dimension of the out to properly sum up the values).
            unsqueezed_bias = self.bias.unsqueeze(-1).unsqueeze(-1)
            result += unsqueezed_bias
        if (self.stride_H != 1 or self.stride_W != 1) and (
                self.stride_type is StrideType.STANDARD):
            result = result[:, :, ::self.stride_H, ::self.stride_W]
        return result
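The permute-and-matmul step above computes, for every spatial location (h, w), the (N, C) x (C, F) product that contracts the shared channel axis. A small self-contained check of that identity against an einsum formulation (hypothetical shapes, independent of the surrounding class):

import torch

N, C, F, H, W = 2, 3, 4, 5, 5
inp = torch.randn(N, C, H, W)
flt = torch.randn(F, C, H, W)

a = torch.matmul(inp.permute(2, 3, 0, 1),      # H, W, N, C
                 flt.permute(2, 3, 1, 0))      # H, W, C, F
a = a.permute(2, 3, 0, 1)                      # back to N, F, H, W

b = torch.einsum('nchw,fchw->nfhw', inp, flt)
print(torch.allclose(a, b, atol=1e-5))         # True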
Example #7
def pad(img: Tensor,
        padding: List[int],
        fill: int = 0,
        padding_mode: str = "constant") -> Tensor:
    _assert_image_tensor(img)

    if not isinstance(padding, (int, tuple, list)):
        raise TypeError("Got inappropriate padding arg")
    if not isinstance(fill, (int, float)):
        raise TypeError("Got inappropriate fill arg")
    if not isinstance(padding_mode, str):
        raise TypeError("Got inappropriate padding_mode arg")

    if isinstance(padding, tuple):
        padding = list(padding)

    if isinstance(padding, list) and len(padding) not in [1, 2, 4]:
        raise ValueError(
            f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
        )

    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError(
            "Padding mode should be either constant, edge, reflect or symmetric"
        )

    if isinstance(padding, int):
        if torch.jit.is_scripting():
            # This may be unreachable
            raise ValueError(
                "padding can't be an int while torchscripting, set it as a list [value, ]"
            )
        pad_left = pad_right = pad_top = pad_bottom = padding
    elif len(padding) == 1:
        pad_left = pad_right = pad_top = pad_bottom = padding[0]
    elif len(padding) == 2:
        pad_left = pad_right = padding[0]
        pad_top = pad_bottom = padding[1]
    else:
        pad_left = padding[0]
        pad_top = padding[1]
        pad_right = padding[2]
        pad_bottom = padding[3]

    p = [pad_left, pad_right, pad_top, pad_bottom]

    if padding_mode == "edge":
        # remap padding_mode str
        padding_mode = "replicate"
    elif padding_mode == "symmetric":
        # route to another implementation
        return _pad_symmetric(img, p)

    need_squeeze = False
    if img.ndim < 4:
        img = img.unsqueeze(dim=0)
        need_squeeze = True

    out_dtype = img.dtype
    need_cast = False
    if (padding_mode != "constant") and img.dtype not in (torch.float32,
                                                          torch.float64):
        # Temporarily cast the input tensor to float until the PyTorch
        # issue below is resolved:
        # https://github.com/pytorch/pytorch/issues/40763
        need_cast = True
        img = img.to(torch.float32)

    img = torch_pad(img, p, mode=padding_mode, value=float(fill))

    if need_squeeze:
        img = img.squeeze(dim=0)

    if need_cast:
        img = img.to(out_dtype)

    return img
Example #8
def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
    r"""PRIVATE METHOD. Pad the given Tensor Image on all sides with specified padding mode and fill value.

    .. warning::

        Module ``transforms.functional_tensor`` is private and should not be used in user applications.
        Please consider using methods from the `transforms.functional` module instead.

    Args:
        img (Tensor): Image to be padded.
        padding (int or tuple or list): Padding on each border. If a single int is provided this
            is used to pad all borders. If a tuple or list of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple or list of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively. In torchscript mode padding as single int is not supported, use a tuple or
            list of length 1: ``[padding, ]``.
        fill (int): Pixel fill value for constant fill. Default is 0.
            This value is only used when the padding_mode is constant.
        padding_mode (str): Type of padding. Should be: constant, edge or reflect. Default is constant.
            Mode symmetric is not yet supported for Tensor inputs.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value on the edge of the image

            - reflect: pads with reflection of image (without repeating the last value on the edge)

                       padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
                       will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image (repeating the last value on the edge)

                         padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
                         will result in [2, 1, 1, 2, 3, 4, 4, 3]

    Returns:
        Tensor: Padded image.
    """
    if not _is_tensor_a_torch_image(img):
        raise TypeError("tensor is not a torch image.")

    if not isinstance(padding, (int, tuple, list)):
        raise TypeError("Got inappropriate padding arg")
    if not isinstance(fill, (int, float)):
        raise TypeError("Got inappropriate fill arg")
    if not isinstance(padding_mode, str):
        raise TypeError("Got inappropriate padding_mode arg")

    if isinstance(padding, tuple):
        padding = list(padding)

    if isinstance(padding, list) and len(padding) not in [1, 2, 4]:
        raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " +
                         "{} element tuple".format(len(padding)))

    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")

    if isinstance(padding, int):
        if torch.jit.is_scripting():
            # This may be unreachable
            raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]")
        pad_left = pad_right = pad_top = pad_bottom = padding
    elif len(padding) == 1:
        pad_left = pad_right = pad_top = pad_bottom = padding[0]
    elif len(padding) == 2:
        pad_left = pad_right = padding[0]
        pad_top = pad_bottom = padding[1]
    else:
        pad_left = padding[0]
        pad_top = padding[1]
        pad_right = padding[2]
        pad_bottom = padding[3]

    p = [pad_left, pad_right, pad_top, pad_bottom]

    if padding_mode == "edge":
        # remap padding_mode str
        padding_mode = "replicate"
    elif padding_mode == "symmetric":
        # route to another implementation
        return _pad_symmetric(img, p)

    need_squeeze = False
    if img.ndim < 4:
        img = img.unsqueeze(dim=0)
        need_squeeze = True

    out_dtype = img.dtype
    need_cast = False
    if (padding_mode != "constant") and img.dtype not in (torch.float32, torch.float64):
        # Temporarily cast the input tensor to float until the PyTorch
        # issue below is resolved:
        # https://github.com/pytorch/pytorch/issues/40763
        need_cast = True
        img = img.to(torch.float32)

    img = torch_pad(img, p, mode=padding_mode, value=float(fill))

    if need_squeeze:
        img = img.squeeze(dim=0)

    if need_cast:
        img = img.to(out_dtype)

    return img
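The reflect example in the docstring can be checked directly with torch.nn.functional.pad (symmetric mode is not provided by torch_pad, which is why the function above routes it to the separate _pad_symmetric helper):

import torch
from torch.nn.functional import pad as torch_pad

x = torch.tensor([[[1., 2., 3., 4.]]])          # (N, C, W) for 1-D reflect padding
print(torch_pad(x, [2, 2], mode="reflect"))
# tensor([[[3., 2., 1., 2., 3., 4., 3., 2.]]])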
Example #9
    def forward(ctx,
                input,
                args,
                val=0,
                get_mask=get_hyper_mask,
                onesided=True):
        """
        In the forward pass we receive a Tensor containing the input
        and return a Tensor containing the output. ctx is a context
        object that can be used to stash information for backward
        computation. You can cache arbitrary objects for use in the
        backward pass using the ctx.save_for_backward method.

        :param input: the input image
        :param args: arguments that define: compress_rate - the compression
        rate, interpolate - the interpolation within the mask: const, linear,
        exp, log, etc.
        :param val: the value (to change coefficients to) for the mask
        :param onesided: whether to use the one-sided FFT (exploiting the
        conjugate symmetry) or to preserve all the coefficients
        """
        # ctx.save_for_backward(input)
        # print("round forward")
        FFTBandFunctionComplexMask2D.mark_dirty(input)

        N, C, H, W = input.size()

        if H != W:
            raise Exception("We support only squared input.")

        if args.next_power2:
            H_fft = next_power2(H)
            W_fft = next_power2(W)
            pad_H = H_fft - H
            pad_W = W_fft - W
            input = torch_pad(input, (0, pad_W, 0, pad_H), 'constant', 0)
        else:
            H_fft = H
            W_fft = W
        xfft = torch.rfft(input,
                          signal_ndim=FFTBandFunctionComplexMask2D.signal_ndim,
                          onesided=onesided)
        del input

        _, _, H_xfft, W_xfft, _ = xfft.size()
        # assert H_fft == W_xfft, "The input tensor has to be squared."

        mask, _ = get_mask(H=H_xfft,
                           W=W_xfft,
                           compress_rate=args.compress_fft_layer,
                           val=val,
                           interpolate=args.interpolate,
                           onesided=onesided)
        mask = mask[:, 0:W_xfft, :]
        # print(mask)
        mask = mask.to(xfft.dtype).to(xfft.device)
        xfft = xfft * mask

        if ctx is not None:
            ctx.xfft = xfft
            if args.is_DC_shift:
                ctx.xfft = shift_DC(xfft, onesided=onesided)

        # xfft = shift_DC(xfft, onesided=onesided, shift_to="center")
        # xfft = shift_DC(xfft, onesided=onesided, shift_to="corner")
        out = torch.irfft(input=xfft,
                          signal_ndim=FFTBandFunctionComplexMask2D.signal_ndim,
                          signal_sizes=(H_fft, W_fft),
                          onesided=onesided)
        out = out[..., :H, :W]
        return out
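torch.rfft and torch.irfft used above are the legacy API; they were removed in PyTorch 1.8 in favor of the torch.fft module, which returns complex tensors instead of a trailing real/imaginary dimension. A hedged sketch of the same pad-then-transform round trip with the modern API:

import torch
from torch.nn.functional import pad as torch_pad

x = torch.randn(1, 3, 30, 30)
H_fft = W_fft = 32                                  # next power of two
x = torch_pad(x, (0, W_fft - 30, 0, H_fft - 30), 'constant', 0)
xfft = torch.fft.rfft2(x)                           # complex, shape (1, 3, 32, 17)
out = torch.fft.irfft2(xfft, s=(H_fft, W_fft))[..., :30, :30]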
Example #10
    def forward(ctx, input, args, onesided=True, is_test=False):
        """
        In the forward pass we receive a Tensor containing the input
        and return a Tensor containing the output. ctx is a context
        object that can be used to stash information for backward
        computation. You can cache arbitrary objects for use in the
        backward pass using the ctx.save_for_backward method.

        :param input: the input image
        :param args: for compress rate and next_power2.
        :param onesided: if True, the FFT leverages the conjugate symmetry and
        returns only roughly half of the FFT map; otherwise the full map is
        returned
        :param is_test: check whether the number of zeroed-out coefficients is
        correct
        """
        # ctx.save_for_backward(input)
        # print("round forward")
        FFTBandFunction2D.mark_dirty(input)

        N, C, H, W = input.size()

        if H != W:
            raise Exception(f"We support only squared input but the width: {W}"
                            f" is differnt from height: {H}")

        if args.next_power2:
            H_fft = next_power2(H)
            W_fft = next_power2(W)
            pad_H = H_fft - H
            pad_W = W_fft - W
            input = torch_pad(input, (0, pad_W, 0, pad_H), 'constant', 0)
        else:
            H_fft = H
            W_fft = W

        xfft = torch.rfft(input,
                          signal_ndim=FFTBandFunction2D.signal_ndim,
                          onesided=onesided)

        del input
        _, _, H_xfft, W_xfft, _ = xfft.size()

        # r - is the side of the retained square in one of the quadrants
        # 4 * r ** 2 / (H * W) = (1 - c)
        # r = np.sqrt((1 - c) * (H * W) / 4)

        compress_rate = args.compress_rate / 100

        if onesided:
            divisor = 2
        else:
            divisor = 4

        # r - is the length of the side that we retain after compression.
        r = np.sqrt((1 - compress_rate) * H_xfft * W_xfft / divisor)
        # r = np.floor(r)
        r = np.ceil(r)
        r = int(r)

        # zero out high energy coefficients
        if is_test:
            # Divide by 2 so the zeros of a complex number are not counted twice.
            zero1 = torch.sum(xfft == 0.0).item() / 2
            # print(zero1)

        xfft[..., r:H_fft - r, :, :] = 0.0
        if onesided:
            xfft[..., :, r:, :] = 0.0
        else:
            xfft[..., :, r:W_fft - r, :] = 0.0

        if ctx is not None:
            ctx.xfft = xfft
            if args.is_DC_shift is True:
                ctx.xfft = shift_DC(xfft, onesided=onesided)

        if is_test:
            zero2 = torch.sum(xfft == 0.0).item() / 2
            # print(zero2)
            total_size = C * H_xfft * W_xfft
            # print("total size: ", total_size)
            fraction_zeroed = (zero2 - zero1) / total_size
            ctx.fraction_zeroed = fraction_zeroed
            # print("compress rate: ", compress_rate, " fraction of zeroed out: ", fraction_zeroed)
            error = 0.08
            if fraction_zeroed > compress_rate + error or (
                    fraction_zeroed < compress_rate - error):
                raise Exception(
                    f"The compression is wrong: for the compression "
                    f"rate {compress_rate}, the fraction "
                    f"of zeroed-out coefficients "
                    f"is {fraction_zeroed}")

        # N, C, H_fft, W_fft = xfft
        out = torch.irfft(input=xfft,
                          signal_ndim=FFTBandFunction2D.signal_ndim,
                          signal_sizes=(H_fft, W_fft),
                          onesided=onesided)
        out = out[..., :H, :W]
        return out
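A worked instance of the retained-side formula above: with a one-sided map of H_xfft = W_xfft = 32 and compress_rate = 0.75, the retained side is r = ceil(sqrt((1 - 0.75) * 32 * 32 / 2)) = ceil(sqrt(128)) = ceil(11.31...) = 12.

import numpy as np

H_xfft = W_xfft = 32
compress_rate = 0.75
r = int(np.ceil(np.sqrt((1 - compress_rate) * H_xfft * W_xfft / 2)))
print(r)  # 12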