def bert_pad(coll, from_left, seq_length):
    ''' Perform zero padding '''
    while len(coll) < seq_length:
        if from_left:
            coll = torch_pad(coll, [0, 0, 1, 0], mode='constant', value=0)
        else:
            coll = torch_pad(coll, [0, 0, 0, 1], mode='constant', value=0)
    return coll
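# A minimal usage sketch for bert_pad, assuming torch_pad is
# torch.nn.functional.pad imported under that alias (as the call
# signature above suggests); the shapes are illustrative:
import torch
from torch.nn.functional import pad as torch_pad

embeddings = torch.randn(5, 768)            # (seq_len, hidden) token embeddings
padded = bert_pad(embeddings, from_left=False, seq_length=8)
print(padded.shape)                         # torch.Size([8, 768])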
def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor:
    if not isinstance(img, torch.Tensor):
        raise TypeError(f"img should be Tensor. Got {type(img)}")

    _assert_image_tensor(img)

    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
    kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])

    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype])

    # padding = (left, right, top, bottom)
    padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
    img = torch_pad(img, padding, mode="reflect")
    img = conv2d(img, kernel, groups=img.shape[-3])

    img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
    return img
def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor:
    """PRIVATE METHOD. Performs Gaussian blurring on the img by the given kernel.

    .. warning::

        Module ``transforms.functional_tensor`` is private and should not be
        used in user applications. Please, consider instead using methods from
        the `transforms.functional` module.

    Args:
        img (Tensor): Image to be blurred.
        kernel_size (sequence of int or int): Kernel size of the Gaussian kernel ``(kx, ky)``.
        sigma (sequence of float or float, optional): Standard deviation of the Gaussian kernel ``(sx, sy)``.

    Returns:
        Tensor: An image that is blurred using a Gaussian kernel of the given parameters.
    """
    if not (isinstance(img, torch.Tensor) or _is_tensor_a_torch_image(img)):
        raise TypeError('img should be Tensor Image. Got {}'.format(type(img)))

    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
    kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])

    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype])

    # padding = (left, right, top, bottom)
    padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
    img = torch_pad(img, padding, mode="reflect")
    img = conv2d(img, kernel, groups=img.shape[-3])

    img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
    return img
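# The docstring above points users to the public API; a minimal sketch with
# torchvision.transforms.functional.gaussian_blur, which exposes the same
# behavior as this private helper:
import torch
import torchvision.transforms.functional as TF

img = torch.rand(3, 32, 32)                 # CHW image tensor in [0, 1]
blurred = TF.gaussian_blur(img, kernel_size=[5, 5], sigma=[1.5, 1.5])
print(blurred.shape)                        # torch.Size([3, 32, 32])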
def pad_tensor(inp, pad, value=0):
    '''Utility function for padding the input tensor.

    inp: input data
    pad: padding size (applied to all four sides of the last two dimensions)
    value: padding value
    '''
    return torch_pad(inp, [pad] * 4, 'constant', value)
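# A minimal usage sketch for pad_tensor, again assuming torch_pad is
# torch.nn.functional.pad imported under that alias:
import torch
from torch.nn.functional import pad as torch_pad

x = torch.ones(1, 1, 4, 4)
y = pad_tensor(x, pad=2)                    # pad 2 on all four sides of the last two dims
print(y.shape)                              # torch.Size([1, 1, 8, 8])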
def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
    _assert_image_tensor(img)

    if not isinstance(padding, (int, tuple, list)):
        raise TypeError("Got inappropriate padding arg")
    if not isinstance(fill, (int, float)):
        raise TypeError("Got inappropriate fill arg")
    if not isinstance(padding_mode, str):
        raise TypeError("Got inappropriate padding_mode arg")

    if isinstance(padding, tuple):
        padding = list(padding)

    if isinstance(padding, list) and len(padding) not in [1, 2, 4]:
        raise ValueError(
            f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
        )

    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")

    p = _parse_pad_padding(padding)

    if padding_mode == "edge":
        # remap padding_mode str
        padding_mode = "replicate"
    elif padding_mode == "symmetric":
        # route to another implementation
        return _pad_symmetric(img, p)

    need_squeeze = False
    if img.ndim < 4:
        img = img.unsqueeze(dim=0)
        need_squeeze = True

    out_dtype = img.dtype
    need_cast = False
    if (padding_mode != "constant") and img.dtype not in (torch.float32, torch.float64):
        # Temporarily cast the input tensor to float until this PyTorch issue
        # is resolved: https://github.com/pytorch/pytorch/issues/40763
        need_cast = True
        img = img.to(torch.float32)

    img = torch_pad(img, p, mode=padding_mode, value=float(fill))

    if need_squeeze:
        img = img.squeeze(dim=0)

    if need_cast:
        img = img.to(out_dtype)

    return img
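# A minimal usage sketch via the public wrapper of this private helper,
# torchvision.transforms.functional.pad; values are illustrative:
import torch
import torchvision.transforms.functional as TF

img = torch.zeros(3, 8, 8, dtype=torch.uint8)
out = TF.pad(img, padding=[1, 2], fill=255)  # 1 px left/right, 2 px top/bottom
print(out.shape)                             # torch.Size([3, 12, 10])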
def forward(self, input):
    """
    This is the fully manual implementation of the forward and backward
    passes via the torch.autograd.Function.

    :param input: the input map (e.g., an image)
    :return: the result of 2D convolution
    """
    # ctx, input, filter, bias, padding = (0, 0), stride = (1, 1),
    # args = None, out_size = None, is_manual = tensor([0]),
    # conv_index = None
    filter = self.weight
    # N - number of input maps (or images in the batch).
    # C - number of input channels.
    # H - height of the input map (e.g., height of an image).
    # W - width of the input map (e.g., width of an image).
    N, C, H, W = input.size()
    # F - number of filters.
    # C - number of channels in each filter.
    # HH - the height of the filter.
    # WW - the width of the filter (its length).
    F, C, HH, WW = filter.size()

    pad_filter_H = H - HH
    pad_filter_W = W - WW
    filter = torch_pad(filter, (0, pad_filter_W, 0, pad_filter_H), 'constant', 0)

    input = dct(input)
    filter = dct(filter)

    # permute from N, C, H, W to H, W, N, C
    input = input.permute(2, 3, 0, 1)
    # permute from F, C, H, W to H, W, C, F
    filter = filter.permute(2, 3, 1, 0)
    result = torch.matmul(input, filter)
    # permute from H, W, N, F to N, F, H, W
    result = result.permute(2, 3, 0, 1)

    result = idct(result)
    out_H, out_W = self.out_HW(H, W, HH, WW)
    result = result[..., :out_H, :out_W]

    if self.bias is not None:
        # Add the bias term for each filter (it has to be unsqueezed to
        # the dimension of the output to properly sum up the values).
        unsqueezed_bias = self.bias.unsqueeze(-1).unsqueeze(-1)
        result += unsqueezed_bias

    if (self.stride_H != 1 or self.stride_W != 1) and (
            self.stride_type is StrideType.STANDARD):
        result = result[:, :, ::self.stride_H, ::self.stride_W]

    return result
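# A shape-only sketch of the transform-domain contraction performed above;
# all names and sizes are illustrative, not from the original module:
import torch

N, C, F, H, W = 2, 3, 8, 16, 16
inp = torch.randn(H, W, N, C)               # input after permute(2, 3, 0, 1)
flt = torch.randn(H, W, C, F)               # filter after permute(2, 3, 1, 0)
res = torch.matmul(inp, flt)                # batched matmul over the (H, W) dims
print(res.shape)                            # torch.Size([16, 16, 2, 8]) == (H, W, N, F)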
def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
    _assert_image_tensor(img)

    if not isinstance(padding, (int, tuple, list)):
        raise TypeError("Got inappropriate padding arg")
    if not isinstance(fill, (int, float)):
        raise TypeError("Got inappropriate fill arg")
    if not isinstance(padding_mode, str):
        raise TypeError("Got inappropriate padding_mode arg")

    if isinstance(padding, tuple):
        padding = list(padding)

    if isinstance(padding, list) and len(padding) not in [1, 2, 4]:
        raise ValueError(
            f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
        )

    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")

    if isinstance(padding, int):
        if torch.jit.is_scripting():
            # This may be unreachable.
            raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]")
        pad_left = pad_right = pad_top = pad_bottom = padding
    elif len(padding) == 1:
        pad_left = pad_right = pad_top = pad_bottom = padding[0]
    elif len(padding) == 2:
        pad_left = pad_right = padding[0]
        pad_top = pad_bottom = padding[1]
    else:
        pad_left = padding[0]
        pad_top = padding[1]
        pad_right = padding[2]
        pad_bottom = padding[3]

    p = [pad_left, pad_right, pad_top, pad_bottom]

    if padding_mode == "edge":
        # remap padding_mode str
        padding_mode = "replicate"
    elif padding_mode == "symmetric":
        # route to another implementation
        return _pad_symmetric(img, p)

    need_squeeze = False
    if img.ndim < 4:
        img = img.unsqueeze(dim=0)
        need_squeeze = True

    out_dtype = img.dtype
    need_cast = False
    if (padding_mode != "constant") and img.dtype not in (torch.float32, torch.float64):
        # Temporarily cast the input tensor to float until this PyTorch issue
        # is resolved: https://github.com/pytorch/pytorch/issues/40763
        need_cast = True
        img = img.to(torch.float32)

    img = torch_pad(img, p, mode=padding_mode, value=float(fill))

    if need_squeeze:
        img = img.squeeze(dim=0)

    if need_cast:
        img = img.to(out_dtype)

    return img
def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
    r"""PRIVATE METHOD. Pad the given Tensor Image on all sides with the specified padding mode and fill value.

    .. warning::

        Module ``transforms.functional_tensor`` is private and should not be
        used in user applications. Please, consider instead using methods from
        the `transforms.functional` module.

    Args:
        img (Tensor): Image to be padded.
        padding (int or tuple or list): Padding on each border. If a single int is provided this
            is used to pad all borders. If a tuple or list of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple or list of length 4 is provided
            this is the padding for the left, top, right and bottom borders
            respectively. In torchscript mode padding as a single int is not supported, use a tuple or
            list of length 1: ``[padding, ]``.
        fill (int): Pixel fill value for constant fill. Default is 0.
            This value is only used when the padding_mode is constant.
        padding_mode (str): Type of padding. Should be: constant, edge or reflect. Default is constant.
            Mode symmetric is not yet supported for Tensor inputs.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value on the edge of the image

            - reflect: pads with reflection of image (without repeating the last value on the edge)

                padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
                will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image (repeating the last value on the edge)

                padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
                will result in [2, 1, 1, 2, 3, 4, 4, 3]

    Returns:
        Tensor: Padded image.
    """
    if not _is_tensor_a_torch_image(img):
        raise TypeError("tensor is not a torch image.")

    if not isinstance(padding, (int, tuple, list)):
        raise TypeError("Got inappropriate padding arg")
    if not isinstance(fill, (int, float)):
        raise TypeError("Got inappropriate fill arg")
    if not isinstance(padding_mode, str):
        raise TypeError("Got inappropriate padding_mode arg")

    if isinstance(padding, tuple):
        padding = list(padding)

    if isinstance(padding, list) and len(padding) not in [1, 2, 4]:
        raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a "
                         "{} element tuple".format(len(padding)))

    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")

    if isinstance(padding, int):
        if torch.jit.is_scripting():
            # This may be unreachable.
            raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]")
        pad_left = pad_right = pad_top = pad_bottom = padding
    elif len(padding) == 1:
        pad_left = pad_right = pad_top = pad_bottom = padding[0]
    elif len(padding) == 2:
        pad_left = pad_right = padding[0]
        pad_top = pad_bottom = padding[1]
    else:
        pad_left = padding[0]
        pad_top = padding[1]
        pad_right = padding[2]
        pad_bottom = padding[3]

    p = [pad_left, pad_right, pad_top, pad_bottom]

    if padding_mode == "edge":
        # remap padding_mode str
        padding_mode = "replicate"
    elif padding_mode == "symmetric":
        # route to another implementation
        return _pad_symmetric(img, p)

    need_squeeze = False
    if img.ndim < 4:
        img = img.unsqueeze(dim=0)
        need_squeeze = True

    out_dtype = img.dtype
    need_cast = False
    if (padding_mode != "constant") and img.dtype not in (torch.float32, torch.float64):
        # Temporarily cast the input tensor to float until this PyTorch issue
        # is resolved: https://github.com/pytorch/pytorch/issues/40763
        need_cast = True
        img = img.to(torch.float32)

    img = torch_pad(img, p, mode=padding_mode, value=float(fill))

    if need_squeeze:
        img = img.squeeze(dim=0)

    if need_cast:
        img = img.to(out_dtype)

    return img
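# A quick check of the reflect example from the docstring, using the public
# torch.nn.functional.pad on the 1-D row [1, 2, 3, 4]:
import torch
import torch.nn.functional as F

row = torch.tensor([[[1., 2., 3., 4.]]])    # reflect mode expects an (N, C, W) input
print(F.pad(row, (2, 2), mode="reflect"))   # tensor([[[3., 2., 1., 2., 3., 4., 3., 2.]]])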
def forward(ctx, input, args, val=0, get_mask=get_hyper_mask, onesided=True):
    """
    In the forward pass we receive a Tensor containing the input and return
    a Tensor containing the output. ctx is a context object that can be used
    to stash information for the backward computation. You can cache
    arbitrary objects for use in the backward pass using the
    ctx.save_for_backward method.

    :param input: the input image
    :param args: arguments that define: compress_rate - the compression
        ratio, interpolate - the interpolation within the mask: const,
        linear, exp, log, etc.
    :param val: the value (to change coefficients to) for the mask
    :param onesided: whether to use the one-sided FFT (thanks to the
        conjugate symmetry) or to preserve all the coefficients
    """
    # ctx.save_for_backward(input)
    # print("round forward")
    FFTBandFunctionComplexMask2D.mark_dirty(input)

    N, C, H, W = input.size()

    if H != W:
        raise Exception("We support only squared input.")

    if args.next_power2:
        H_fft = next_power2(H)
        W_fft = next_power2(W)
        pad_H = H_fft - H
        pad_W = W_fft - W
        input = torch_pad(input, (0, pad_W, 0, pad_H), 'constant', 0)
    else:
        H_fft = H
        W_fft = W

    xfft = torch.rfft(input,
                      signal_ndim=FFTBandFunctionComplexMask2D.signal_ndim,
                      onesided=onesided)
    del input

    _, _, H_xfft, W_xfft, _ = xfft.size()
    # assert H_fft == W_xfft, "The input tensor has to be squared."

    mask, _ = get_mask(H=H_xfft, W=W_xfft,
                       compress_rate=args.compress_fft_layer,
                       val=val, interpolate=args.interpolate,
                       onesided=onesided)
    mask = mask[:, 0:W_xfft, :]
    # print(mask)
    mask = mask.to(xfft.dtype).to(xfft.device)
    xfft = xfft * mask

    if ctx is not None:
        ctx.xfft = xfft
        if args.is_DC_shift:
            ctx.xfft = shift_DC(xfft, onesided=onesided)

    # xfft = shift_DC(xfft, onesided=onesided, shift_to="center")
    # xfft = shift_DC(xfft, onesided=onesided, shift_to="corner")

    out = torch.irfft(input=xfft,
                      signal_ndim=FFTBandFunctionComplexMask2D.signal_ndim,
                      signal_sizes=(H_fft, W_fft),
                      onesided=onesided)
    out = out[..., :H, :W]
    return out
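# An illustrative check of the next-power-of-two padding step above
# (next_power2 is assumed to round up to the nearest power of two, as its
# usage implies; the 28x28 size is an arbitrary example):
import torch
from torch.nn.functional import pad as torch_pad

H = W = 28
H_fft = W_fft = 32                          # next_power2(28) == 32
x = torch.zeros(1, 3, H, W)
x = torch_pad(x, (0, W_fft - W, 0, H_fft - H), 'constant', 0)
print(x.shape)                              # torch.Size([1, 3, 32, 32])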
def forward(ctx, input, args, onesided=True, is_test=False):
    """
    In the forward pass we receive a Tensor containing the input and return
    a Tensor containing the output. ctx is a context object that can be used
    to stash information for the backward computation. You can cache
    arbitrary objects for use in the backward pass using the
    ctx.save_for_backward method.

    :param input: the input image
    :param args: for compress rate and next_power2.
    :param onesided: FFT convolution leverages the conjugate symmetry and
        returns only roughly half of the FFT map, otherwise the full map is
        returned
    :param is_test: test if the number of zeroed-out coefficients is correct
    """
    # ctx.save_for_backward(input)
    # print("round forward")
    FFTBandFunction2D.mark_dirty(input)

    N, C, H, W = input.size()

    if H != W:
        raise Exception(f"We support only squared input but the width: {W}"
                        f" is different from the height: {H}")

    if args.next_power2:
        H_fft = next_power2(H)
        W_fft = next_power2(W)
        pad_H = H_fft - H
        pad_W = W_fft - W
        input = torch_pad(input, (0, pad_W, 0, pad_H), 'constant', 0)
    else:
        H_fft = H
        W_fft = W

    xfft = torch.rfft(input,
                      signal_ndim=FFTBandFunction2D.signal_ndim,
                      onesided=onesided)
    del input

    _, _, H_xfft, W_xfft, _ = xfft.size()

    # r - the side of the retained square in one of the quadrants
    # 4 * r ** 2 / (H * W) = (1 - c)
    # r = np.sqrt((1 - c) * (H * W) / 4)
    compress_rate = args.compress_rate / 100

    if onesided:
        divisor = 2
    else:
        divisor = 4

    # r - the length of the side that we retain after compression.
    r = np.sqrt((1 - compress_rate) * H_xfft * W_xfft / divisor)
    # r = np.floor(r)
    r = np.ceil(r)
    r = int(r)

    # zero out the high energy coefficients
    if is_test:
        # We divide by 2 to avoid counting the zeros of a complex number twice.
        zero1 = torch.sum(xfft == 0.0).item() / 2
        # print(zero1)

    xfft[..., r:H_fft - r, :, :] = 0.0
    if onesided:
        xfft[..., :, r:, :] = 0.0
    else:
        xfft[..., :, r:W_fft - r, :] = 0.0

    if ctx is not None:
        ctx.xfft = xfft
        if args.is_DC_shift is True:
            ctx.xfft = shift_DC(xfft, onesided=onesided)

    if is_test:
        zero2 = torch.sum(xfft == 0.0).item() / 2
        # print(zero2)
        total_size = C * H_xfft * W_xfft
        # print("total size: ", total_size)
        fraction_zeroed = (zero2 - zero1) / total_size
        ctx.fraction_zeroed = fraction_zeroed
        # print("compress rate: ", compress_rate, " fraction zeroed out: ", fraction_zeroed)
        error = 0.08
        if fraction_zeroed > compress_rate + error or (
                fraction_zeroed < compress_rate - error):
            raise Exception(
                f"The compression is wrong: for compression "
                f"rate {compress_rate}, the fraction "
                f"of zeroed-out coefficients "
                f"is: {fraction_zeroed}")

    # N, C, H_fft, W_fft = xfft
    out = torch.irfft(input=xfft,
                      signal_ndim=FFTBandFunction2D.signal_ndim,
                      signal_sizes=(H_fft, W_fft),
                      onesided=onesided)
    out = out[..., :H, :W]
    return out
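# A quick numeric check of the retained-side formula derived in the comments
# above (illustrative values, not from the original setup; divisor = 4
# corresponds to onesided=False):
import numpy as np

H_xfft = W_xfft = 32
compress_rate = 0.75                        # zero out 75% of the coefficients
r = int(np.ceil(np.sqrt((1 - compress_rate) * H_xfft * W_xfft / 4)))
print(r)                                    # 8 -> keep an 8x8 square per quadrant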