def forward(self, x):
    """
    Forward pass of the 2D DWT.

    Args:
        x (tensor): Input of shape :math:`(N, C_{in}, H_{in}, W_{in})`

    Returns:
        (yl, yh)
            tuple of lowpass (yl) and bandpass (yh) coefficients.
            yh is a list of length J with the first entry
            being the finest scale coefficients. yl has shape
            :math:`(N, C_{in}, H_{in}', W_{in}')` and yh has shape
            :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. The new
            dimension in yh iterates over the LH, HL and HH coefficients.

    Note:
        :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly
        downsampled shapes of the DWT pyramid.
    """
    lowpass = x
    bandpass = []
    pad_mode = lowlevel.mode_to_int(self.mode)
    # Each level re-filters the previous level's lowpass, so the first
    # bandpass entry appended is the finest scale.
    for _ in range(self.J):
        # AFB2D receives fresh copies of the four analysis filters per level.
        filters = (
            self.h0_col.clone(),
            self.h1_col.clone(),
            self.h0_row.clone(),
            self.h1_row.clone(),
        )
        lowpass, detail = lowlevel.AFB2D.apply(lowpass, *filters, pad_mode)
        bandpass.append(detail)
    return lowpass, bandpass
def forward(self, coeffs):
    """
    Reconstruct a 1D signal from its DWT coefficients.

    Args:
        coeffs (yl, yh): tuple of lowpass and bandpass coefficients, should
            match the format returned by DWT1DForward.

    Returns:
        Reconstructed input of shape :math:`(N, C_{in}, L_{in})`

    Note:
        Can have None for any of the highpass scales and will treat the
        values as zeros (not in an efficient way though).
    """
    lowpass, bandpass = coeffs
    assert lowpass.ndim == 3, "Can only handle 3d inputs (N, C, L)"
    pad_mode = lowlevel.mode_to_int(self.mode)
    # Walk the bandpass list from its last entry (coarsest) to its first.
    for detail in bandpass[::-1]:
        detail = torch.zeros_like(lowpass) if detail is None else detail
        # 'Unpad': drop one trailing sample when the running lowpass is
        # longer than this level's bandpass.
        if lowpass.shape[-1] > detail.shape[-1]:
            lowpass = lowpass[..., :-1]
        lowpass = lowlevel.SFB1D.apply(lowpass, detail, self.g0, self.g1, pad_mode)
    return lowpass
def forward(self, coeffs=None):
    """
    Do the 2D DWT inverse reconstruction for a set of coefficients

    Parameters
    ----------
    coeffs: tuple
        tuple ``(yl, yh)`` of lowpass and bandpass coefficients, where yl is
        a lowpass tensor of shape :math:`(N, C_{in}, H_{in}', W_{in}')` and
        yh is a list of bandpass tensors of shape
        :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. I.e. should match the
        format returned by DWTForward.

    Returns
    -------
    torch.Tensor
        Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})`

    Raises
    ------
    TypeError
        If ``coeffs`` is None or a single tensor. The packed "overwritten"
        tensor produced by
        :func:`DWTForwardOverwrite <hyde.dwt3d.DWTForwardOverwrite>` is not
        decoded by this method; pass the ``(yl, yh)`` pair instead.

    Notes
    -----
    - :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly
      downsampled shapes of the DWT pyramid.
    - Can have None for any of the highpass scales and will treat the
      values as zeros (not in an efficient way though).
    """
    # Unpacking a tensor would silently split it along dim 0 (or fail with a
    # confusing message), so reject non-tuple input explicitly.
    if coeffs is None or isinstance(coeffs, torch.Tensor):
        raise TypeError(
            "coeffs must be a (yl, yh) tuple of lowpass and bandpass coefficients"
        )
    yl, yh = coeffs
    ll = yl
    padding_method = lowlevel.mode_to_int(self.padding_method)
    # Keep the synthesis filters on the same dtype/device as the coefficients.
    self.g0_col = self.g0_col.to(dtype=ll.dtype, device=ll.device)
    self.g1_col = self.g1_col.to(dtype=ll.dtype, device=ll.device)
    self.g0_row = self.g0_row.to(dtype=ll.dtype, device=ll.device)
    self.g1_row = self.g1_row.to(dtype=ll.dtype, device=ll.device)

    # Do a multilevel inverse transform, coarsest level first.
    for h in yh[::-1]:
        if h is None:
            # Match ll's dtype too: without it torch.zeros uses the global
            # default dtype, which disagrees with the filters cast above for
            # non-default-dtype inputs.
            h = torch.zeros(
                ll.shape[0],
                ll.shape[1],
                3,
                ll.shape[-2],
                ll.shape[-1],
                device=ll.device,
                dtype=ll.dtype,
            )
        # 'Unpad' added dimensions
        if ll.shape[-2] > h.shape[-2]:
            ll = ll[..., :-1, :]
        if ll.shape[-1] > h.shape[-1]:
            ll = ll[..., :-1]
        ll = lowlevel.SFB2D.apply(
            ll, h, self.g0_col, self.g1_col, self.g0_row, self.g1_row, padding_method
        )
    return ll
def forward(self, coeffs):
    """
    Reconstruct a 2D input from its DWT coefficients.

    Args:
        coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where:
            yl is a lowpass tensor of shape :math:`(N, C_{in}, H_{in}',
            W_{in}')` and yh is a list of bandpass tensors of shape
            :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. I.e. should match
            the format returned by DWTForward

    Returns:
        Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})`

    Note:
        :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly
        downsampled shapes of the DWT pyramid.

    Note:
        Can have None for any of the highpass scales and will treat the
        values as zeros (not in an efficient way though).
    """
    yl, yh = coeffs
    ll_prev = yl
    mode = lowlevel.mode_to_int(self.mode)

    # Do a multilevel inverse transform, coarsest level first.
    for h in yh[::-1]:
        if h is None:
            # Follow ll_prev's dtype as well as its device: torch.zeros would
            # otherwise use the global default dtype, which breaks e.g. after
            # the module and inputs have been converted to float64.
            h = torch.zeros(
                ll_prev.shape[0],
                ll_prev.shape[1],
                3,
                ll_prev.shape[-2],
                ll_prev.shape[-1],
                device=ll_prev.device,
                dtype=ll_prev.dtype,
            )
        # 'Unpad' added dimensions. NOTE(review): the .clone() calls (here and
        # on the filters below) are kept as in the original; presumably they
        # avoid handing views/buffers to the autograd Function — confirm.
        if ll_prev.shape[-2] > h.shape[-2]:
            ll_prev = ll_prev[..., :-1, :].clone()
        if ll_prev.shape[-1] > h.shape[-1]:
            ll_prev = ll_prev[..., :-1].clone()
        ll_prev = lowlevel.SFB2D.apply(
            ll_prev,
            h,
            self.g0_col.clone(),
            self.g1_col.clone(),
            self.g0_row.clone(),
            self.g1_row.clone(),
            mode,
        )
    return ll_prev
def forward(self, x):
    """
    Forward pass of the 1D DWT.

    Args:
        x (tensor): Input of shape :math:`(N, C_{in}, L_{in})`

    Returns:
        (yl, yh)
            tuple of lowpass (yl) and bandpass (yh) coefficients.
            yh is a list of length J with the first entry
            being the finest scale coefficients.
    """
    assert x.ndim == 3, "Can only handle 3d inputs (N, C, L)"
    lowpass, bandpass = x, []
    pad_mode = lowlevel.mode_to_int(self.mode)
    # Repeatedly split the running lowpass; details accumulate finest-first.
    for _ in range(self.J):
        lowpass, detail = lowlevel.AFB1D.apply(lowpass, self.h0, self.h1, pad_mode)
        bandpass.append(detail)
    return lowpass, bandpass
def __init__(self, J=1, wave="db1", padding_method="zero", device="cpu"):
    """
    Build the 2D analysis filters and store the transform settings.

    Args:
        J (int): presumably the number of decomposition levels; stored as
            ``self.J``.
        wave (str, pywt.Wavelet, or sequence): a wavelet name (passed to
            ``pywt.Wavelet``), a ``pywt.Wavelet`` (its ``dec_lo``/``dec_hi``
            are used for both columns and rows), a sequence of 2 filters
            ``(h0, h1)`` used for both directions, or a sequence of 4 filters
            ``(h0_col, h1_col, h0_row, h1_row)``.
        padding_method (str): padding mode name; stored as ``self.mode_str``
            and converted with ``lowlevel.mode_to_int`` into ``self.mode``.
        device: passed to ``lowlevel.prep_filt_afb2d`` when building the
            filter buffers.

    Raises:
        ValueError: if ``wave`` is a filter sequence whose length is not 2 or 4.
    """
    super().__init__()
    if isinstance(wave, str):
        wave = pywt.Wavelet(wave)
    if isinstance(wave, pywt.Wavelet):
        h0_col, h1_col = wave.dec_lo, wave.dec_hi
        h0_row, h1_row = h0_col, h1_col
    elif len(wave) == 2:
        h0_col, h1_col = wave[0], wave[1]
        h0_row, h1_row = h0_col, h1_col
    elif len(wave) == 4:
        h0_col, h1_col = wave[0], wave[1]
        h0_row, h1_row = wave[2], wave[3]
    else:
        # Previously this fell through and failed later with an
        # UnboundLocalError on h0_col.
        raise ValueError(
            "wave must be a wavelet name, a pywt.Wavelet, or a sequence of 2 "
            f"or 4 filters; got a sequence of length {len(wave)}"
        )

    # Prepare the filters
    filts = lowlevel.prep_filt_afb2d(h0_col, h1_col, h0_row, h1_row, device=device)
    self.register_buffer("h0_col", filts[0])
    self.register_buffer("h1_col", filts[1])
    self.register_buffer("h0_row", filts[2])
    self.register_buffer("h1_row", filts[3])
    self.J = J
    # self.mode holds the integer code; self.mode_str keeps the original name.
    self.mode_str = padding_method
    self.mode = lowlevel.mode_to_int(self.mode_str)
def setup():
    """Set the module-level ``mode``/``o_dim``/``ri_dim`` globals and request a GPU.

    Presumably a module-level test setup hook. ``mode`` becomes the integer
    code for 'symmetric' padding, ``o_dim`` is 2 and ``ri_dim`` is -1; then
    ``py3nvml.grab_gpus`` is asked for one GPU.
    """
    global mode, o_dim, ri_dim
    mode, o_dim, ri_dim = mode_to_int('symmetric'), 2, -1
    py3nvml.grab_gpus(1, gpu_fraction=0.5, env_set_ok=True)
def forward(self, coeffs):
    """
    Inverse Dual Tree Complex Wavelet Transform.

    Args:
        coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where:
            yl is a tensor of shape :math:`(N, C_{in}, H_{in}', W_{in}')`
            and yh is a list of the complex bandpass coefficients of shape
            :math:`list(N, C_{in}, 6, H_{in}'', W_{in}'', 2)`, or similar
            depending on o_dim and ri_dim

    Returns:
        Reconstructed output. If yh is None or empty (no bandpass levels,
        e.g. from a zero-level forward transform), yl is returned unchanged.

    Note:
        Can accept Nones or an empty tensor (torch.tensor([])) for the
        lowpass or bandpass inputs. In this cases, an array of zeros
        replaces that input.

    Note:
        :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` are the shapes of a
        DTCWT pyramid.

    Note:
        If include_scale was true for the forward pass, you should provide
        only the final lowpass output here, as normal for an inverse wavelet
        transform.
    """
    low, highs = coeffs
    # No bandpass levels means nothing to invert; previously this crashed
    # with len(None) or highs[0] IndexError.
    if highs is None or len(highs) == 0:
        return low

    mode = mode_to_int(self.mode)
    _, _, h_dim, w_dim = get_dimensions6(self.o_dim, self.ri_dim)

    # Levels 2..J, coarsest first, use the level-2+ synthesis filters.
    for s in highs[1:][::-1]:
        if s is not None and s.shape != torch.Size([]):
            # Check the rank first so a wrong-rank tensor gets this message
            # rather than an IndexError from the o_dim/ri_dim lookups below.
            assert len(s.shape) == 6, "Bandpass inputs must have 6 dimensions"
            assert s.shape[self.o_dim] == 6, (
                "Inverse transform must have input with 6 orientations"
            )
            assert s.shape[self.ri_dim] == 2, (
                "Inputs must be complex with real and imaginary parts in the "
                "ri dimension"
            )
            # Ensure the low and highpass are the right size
            r, c = low.shape[2:]
            r1, c1 = s.shape[h_dim], s.shape[w_dim]
            if r != r1 * 2:
                low = low[:, :, 1:-1]
            if c != c1 * 2:
                low = low[:, :, :, 1:-1]
        low = INV_J2PLUS.forward(
            low, s, self.g0a, self.g1a, self.g0b, self.g1b, self.o_dim, self.ri_dim, mode
        )

    # Ensure the low and highpass are the right size for level 1
    if highs[0] is not None and highs[0].shape != torch.Size([]):
        r, c = low.shape[2:]
        r1, c1 = highs[0].shape[h_dim], highs[0].shape[w_dim]
        if r != r1 * 2:
            low = low[:, :, 1:-1]
        if c != c1 * 2:
            low = low[:, :, :, 1:-1]

    low = INV_J1.forward(low, highs[0], self.g0o, self.g1o, self.o_dim, self.ri_dim, mode)
    return low
def forward(self, x):
    """
    Forward Dual Tree Complex Wavelet Transform.

    Args:
        x (tensor): Input to transform. Should be of shape
            :math:`(N, C_{in}, H_{in}, W_{in})`.

    Returns:
        (yl, yh)
            tuple of lowpass (yl) and bandpass (yh) coefficients.
            If any entry of include_scale is True, yl is the list of per-level
            lowpass coefficients (levels not selected keep a 0-dim zero
            placeholder); otherwise it is just the final lowpass of shape
            :math:`(N, C_{in}, H_{in}', W_{in}')`. yh is a list of the complex
            bandpass coefficients of shape
            :math:`list(N, C_{in}, 6, H_{in}'', W_{in}'', 2)`, or similar
            shape depending on o_dim and ri_dim. When J == 0 the input is
            returned as ``(x, None)``.

    Note:
        :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` are the shapes of a
        DTCWT pyramid.
    """
    scales = [x.new_zeros([])] * self.J
    highs = [x.new_zeros([])] * self.J
    mode = mode_to_int(self.mode)
    if self.J == 0:
        return x, None

    # Make odd row/column counts even by repeating the last row/column.
    n_rows, n_cols = x.shape[2:]
    if n_rows % 2:
        x = torch.cat((x, x[:, :, -1:]), dim=2)
    if n_cols % 2:
        x = torch.cat((x, x[:, :, :, -1:]), dim=3)

    # Level 1 uses its own filter pair (h0o/h1o).
    low, highs[0] = FWD_J1.forward(
        x, self.h0o, self.h1o, self.skip_hps[0], self.o_dim, self.ri_dim, mode
    )
    if self.include_scale[0]:
        scales[0] = low

    for level in range(1, self.J):
        # Pad both ends of each axis whose size is not a multiple of 4.
        n_rows, n_cols = low.shape[2:]
        if n_rows % 4:
            low = torch.cat((low[:, :, :1], low, low[:, :, -1:]), dim=2)
        if n_cols % 4:
            low = torch.cat((low[:, :, :, :1], low, low[:, :, :, -1:]), dim=3)

        low, highs[level] = FWD_J2PLUS.forward(
            low,
            self.h0a,
            self.h1a,
            self.h0b,
            self.h1b,
            self.skip_hps[level],
            self.o_dim,
            self.ri_dim,
            mode,
        )
        if self.include_scale[level]:
            scales[level] = low

    lowpass_out = scales if True in self.include_scale else low
    return lowpass_out, highs
def forward(self, x):
    """
    Forward pass of the DWT, also packing all levels into one tensor.

    Parameters
    ----------
    x : torch.Tensor
        Input of shape :math:`(N, C_{in}, H_{in}, W_{in})`

    Returns
    -------
    overwritten_results : torch.Tensor or None
        A 4D tensor of shape :math:`(N, C_{in}, 2 H_1, 2 W_1)`, where
        :math:`H_1, W_1` are the spatial dims of the first-level lowpass.
        For each level (finest first) with lowpass size :math:`(h, w)`, the
        block ``[:h, :w]`` receives the lowpass, ``[:h, w:2w]`` the first
        bandpass band, ``[h:2h, :w]`` the second and ``[h:2h, w:2w]`` the
        third. Coarser levels overwrite the lowpass block of the finer one.
        ``None`` when ``self.decomp_level`` is 0.
    yl : torch.Tensor
        the lowpass coefficients. yl has shape
        :math:`(N, C_{in}, H_{in}', W_{in}')`.
    yh : list of torch.Tensor
        the bandpass coefficients. yh is a list of length
        `self.decomp_level` with the first entry being the finest scale
        coefficients. Each entry has shape
        :math:`(N, C_{in}, 3, H_{in}'', W_{in}'')`. The new dimension in yh
        iterates over the LH, HL and HH coefficients.

    Notes
    -----
    :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly
    downsampled shapes of the DWT pyramid.

    NOTE(review): if a coarser level's doubled size exceeds the previous
    level's size (possible for odd sizes, depending on AFB2D's output shape),
    its blocks overlap and overwrite part of the finer level's bandpass
    blocks in ``overwritten_results`` — confirm whether that is acceptable.
    """
    yh = []
    ll = x
    padding_method = lowlevel.mode_to_int(self.padding_method)
    full = None

    # Keep the analysis filters on the same dtype/device as the input.
    self.h0_col = self.h0_col.to(dtype=x.dtype, device=x.device)
    self.h1_col = self.h1_col.to(dtype=x.dtype, device=x.device)
    self.h0_row = self.h0_row.to(dtype=x.dtype, device=x.device)
    self.h1_row = self.h1_row.to(dtype=x.dtype, device=x.device)

    # Do a multilevel transform
    for _ in range(self.decomp_level):
        # Do a level of the transform
        ll, high = lowlevel.AFB2D.apply(
            ll, self.h0_col, self.h1_col, self.h0_row, self.h1_row, padding_method
        )
        s = ll.shape[-2:]
        if full is None:
            # First level fixes the packed tensor's size: twice its lowpass.
            full_shape = list(ll.shape)
            full_shape[-1] *= 2
            full_shape[-2] *= 2
            full = torch.zeros(full_shape, device=ll.device, dtype=ll.dtype)
        full[:, :, : s[0], : s[1]] = ll
        full[:, :, : s[0], s[1] : s[1] * 2] = high[:, :, 0]
        full[:, :, s[0] : s[0] * 2, : s[1]] = high[:, :, 1]
        full[:, :, s[0] : s[0] * 2, s[1] : s[1] * 2] = high[:, :, 2]
        yh.append(high)
    return full, ll, yh