def inverse(self, coeffs):
    """Reconstruct a 1D signal from its multilevel wavelet coefficients.

    Args:
        coeffs: sequence whose first entry is the lowpass tensor and whose
            remaining entries are the bandpass tensors, finest scale first --
            i.e. the flat tuple ``(yl, yh_1, ..., yh_J)`` returned by the 1D
            forward transform (not a ``(yl, [yh...])`` pair). Any bandpass
            entry may be ``None``; it is then replaced by zeros shaped like
            the current lowpass tensor (simple, but not efficient).

    Returns:
        Reconstructed input of shape :math:`(N, C_{in}, L_{in})`.

    Raises:
        AssertionError: if the lowpass tensor is not 3-dimensional.
    """
    remaining = list(coeffs)
    lowpass = remaining.pop(0)
    assert lowpass.ndim == 3, "Can only handle 3d inputs (N, C, L)"

    mode = lowlevel.mode_to_int(self.mode)
    g0 = self.h0
    g1 = low_to_high(g0)

    # Bands are stored finest-first, so synthesis walks them coarsest-first.
    for band in reversed(remaining):
        band = torch.zeros_like(lowpass) if band is None else band
        # The lowpass may be one sample longer than the band (presumably
        # padding added by the forward pass); trim it so lengths agree.
        if lowpass.shape[-1] > band.shape[-1]:
            lowpass = lowpass[..., :-1]
        lowpass = lowlevel.SFB1D.forward(lowpass, band, g0, g1, mode)
    return lowpass
def _hsum_loss(w_transform):
    """Return half the squared sum of the highpass filter taps.

    The highpass filter is derived from ``w_transform.h0`` via
    ``low_to_high``. The result is ``0.5 * (sum of taps) ** 2``, which is
    zero exactly when the highpass taps sum to zero (presumably used as a
    penalty encouraging a zero-DC-gain highpass filter).
    """
    highpass = low_to_high(w_transform.h0)
    tap_total = highpass.sum()
    return 0.5 * tap_total ** 2
def inverse(self, coeffs):
    """Reconstruct a 2D input from its multilevel wavelet coefficients.

    Args:
        coeffs: sequence whose first entry is the lowpass tensor ``yl`` of
            shape :math:`(N, C_{in}, H_{in}', W_{in}')` and whose remaining
            entries are bandpass tensors of shape
            :math:`(N, C_{in}, 3, H_{in}'', W_{in}'')`, finest scale first --
            i.e. the flat tuple ``(yl, yh_1, ..., yh_J)`` returned by the 2D
            forward transform (not a ``(yl, [yh...])`` pair).

    Returns:
        Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Note:
        :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly
        downsampled shapes of the DWT pyramid.

    Note:
        Any bandpass entry may be ``None``; it is then treated as zeros
        (not in an efficient way though).
    """
    coeffs = list(coeffs)
    yl = coeffs.pop(0)
    yh = coeffs
    ll = yl
    mode = lowlevel.mode_to_int(self.mode)

    h1 = low_to_high(self.h0)
    # Column filters run along dim -2, row filters along dim -1.
    g0_col = self.h0.reshape((1, 1, -1, 1))
    g1_col = h1.reshape((1, 1, -1, 1))
    g0_row = self.h0.reshape((1, 1, 1, -1))
    g1_row = h1.reshape((1, 1, 1, -1))

    # Bands are stored finest-first, so synthesis walks them coarsest-first.
    for h in yh[::-1]:
        if h is None:
            # new_zeros inherits ll's dtype *and* device; plain torch.zeros
            # would use the global default dtype and mismatch e.g. float64
            # or half-precision inputs inside SFB2D.
            h = ll.new_zeros(
                (ll.shape[0], ll.shape[1], 3, ll.shape[-2], ll.shape[-1]))

        # 'Unpad' dimensions that are one larger than the band (presumably
        # padding added by the forward pass for odd sizes).
        if ll.shape[-2] > h.shape[-2]:
            ll = ll[..., :-1, :]
        if ll.shape[-1] > h.shape[-1]:
            ll = ll[..., :-1]
        ll = lowlevel.SFB2D.forward(ll, h, g0_col, g1_col, g0_row, g1_row,
                                    mode)
    return ll
def forward(self, x):
    """Run the J-level 2D forward discrete wavelet transform.

    Args:
        x (tensor): Input of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Returns:
        Flat tuple ``(yl, yh_1, ..., yh_J)``: the lowpass tensor followed by
        one bandpass tensor per level, finest scale first. ``yl`` has shape
        :math:`(N, C_{in}, H_{in}', W_{in}')` and each bandpass tensor has
        shape :math:`(N, C_{in}, 3, H_{in}'', W_{in}'')`, where the size-3
        axis iterates over the LH, HL and HH coefficients.

    Note:
        :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly
        downsampled shapes of the DWT pyramid.
    """
    mode = lowlevel.mode_to_int(self.mode)
    lo = self.h0
    hi = low_to_high(lo)

    # Column filters put their taps on dim -2, row filters on dim -1.
    col_shape, row_shape = (1, 1, -1, 1), (1, 1, 1, -1)
    lo_col, hi_col = lo.reshape(col_shape), hi.reshape(col_shape)
    lo_row, hi_row = lo.reshape(row_shape), hi.reshape(row_shape)

    lowpass = x
    bands = []
    for _ in range(self.J):
        # Each level splits the current lowpass into a coarser lowpass and
        # one stacked bandpass tensor.
        lowpass, band = lowlevel.AFB2D.forward(
            lowpass, lo_col, hi_col, lo_row, hi_row, mode)
        bands.append(band)
    return (lowpass, *bands)
def forward(self, x):
    """Run the J-level 1D forward discrete wavelet transform.

    Args:
        x (tensor): Input of shape :math:`(N, C_{in}, L_{in})`.

    Returns:
        Flat tuple ``(yl, yh_1, ..., yh_J)``: the lowpass tensor followed by
        one bandpass tensor per level, finest scale first.

    Raises:
        AssertionError: if ``x`` is not 3-dimensional.
    """
    assert x.ndim == 3, "Can only handle 3d inputs (N, C, L)"
    mode = lowlevel.mode_to_int(self.mode)
    lo = self.h0
    hi = low_to_high(lo)

    lowpass = x
    bands = []
    for _ in range(self.J):
        # Split the current lowpass into a coarser lowpass and a bandpass.
        lowpass, band = lowlevel.AFB1D.forward(lowpass, lo, hi, mode)
        bands.append(band)
    return (lowpass, *bands)