Example #1
0
    def forward(self, x):
        """Run the multilevel 2D DWT analysis on ``x``.

        Args:
            x (tensor): Input of shape :math:`(N, C_{in}, H_{in}, W_{in})`

        Returns:
            (yl, yh)
                tuple of lowpass (yl) and bandpass (yh)
                coefficients. yh is a list of length J with the first entry
                being the finest scale coefficients. yl has shape
                :math:`(N, C_{in}, H_{in}', W_{in}')` and yh has shape
                :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. The new
                dimension in yh iterates over the LH, HL and HH coefficients.

        Note:
            :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly
            downsampled shapes of the DWT pyramid.
        """
        lowpass = x
        bands = []
        pad_mode = lowlevel.mode_to_int(self.mode)

        # Each pass splits the running lowpass into a coarser lowpass plus one
        # set of bandpass coefficients. Fresh clones of the four filters are
        # handed to AFB2D on every pass.
        for _ in range(self.J):
            filters = (
                self.h0_col.clone(),
                self.h1_col.clone(),
                self.h0_row.clone(),
                self.h1_row.clone(),
            )
            lowpass, band = lowlevel.AFB2D.apply(lowpass, *filters, pad_mode)
            bands.append(band)

        return lowpass, bands
Example #2
0
    def forward(self, coeffs):
        """Invert a multilevel 1D DWT.

        Args:
            coeffs (yl, yh): tuple of lowpass and bandpass coefficients, should
              match the format returned by DWT1DForward.

        Returns:
            Reconstructed input of shape :math:`(N, C_{in}, L_{in})`

        Note:
            Any highpass entry may be None; it is replaced by zeros shaped
            like the current lowpass (simple, but not efficient).
        """
        recon, bands = coeffs
        assert recon.ndim == 3, "Can only handle 3d inputs (N, C, L)"
        pad_mode = lowlevel.mode_to_int(self.mode)

        # Walk from the coarsest band back to the finest.
        for band in reversed(bands):
            band = torch.zeros_like(recon) if band is None else band

            # 'Unpad': drop one trailing sample when the lowpass is longer
            # than the bandpass at this level.
            if recon.shape[-1] > band.shape[-1]:
                recon = recon[..., :-1]
            recon = lowlevel.SFB1D.apply(recon, band, self.g0, self.g1, pad_mode)
        return recon
Example #3
0
    def forward(self, coeffs=None):
        """
        Do the 2D DWT inverse reconstruction for a set of coefficients

        Parameters
        ----------
        coeffs: tuple
            tuple of lowpass and bandpass coefficients, where yl is a lowpass tensor of shape
            :math:`(N, C_{in}, H_{in}', W_{in}')` and yh is a list of bandpass tensors of shape
            :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. I.e. should match the format returned
            by DWTForward.

        Returns
        -------
        torch.Tensor
            Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})`

        Raises
        ------
        TypeError
            If ``coeffs`` is None.

        Notes
        -----
        - :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly downsampled shapes
          of the DWT pyramid.
        - Can have None for any of the highpass scales and will treat the values as zeros (not
          in an efficient way though). The zeros are created with the lowpass's dtype and
          device.
        - The synthesis filters are moved (and stored back on ``self``) to the lowpass's dtype
          and device on every call.
        - NOTE(review): this method only unpacks a ``(yl, yh)`` tuple. A single overwritten
          tensor (as produced by ``DWTForwardOverwrite``) is not decoded here; confirm whether
          that input form is handled elsewhere.
        """
        if coeffs is None:
            raise TypeError("coeffs must be a (yl, yh) tuple, got None")
        yl, yh = coeffs

        ll = yl
        padding_method = lowlevel.mode_to_int(self.padding_method)

        # Keep the synthesis filters on the same dtype/device as the coefficients.
        self.g0_col = self.g0_col.to(dtype=ll.dtype, device=ll.device)
        self.g1_col = self.g1_col.to(dtype=ll.dtype, device=ll.device)
        self.g0_row = self.g0_row.to(dtype=ll.dtype, device=ll.device)
        self.g1_row = self.g1_row.to(dtype=ll.dtype, device=ll.device)

        # Do a multilevel inverse transform, coarsest scale first
        for h in yh[::-1]:
            if h is None:
                # Missing scale: zeros that share ll's dtype as well as its
                # device, so they agree with the filters converted above.
                h = torch.zeros(
                    ll.shape[0],
                    ll.shape[1],
                    3,
                    ll.shape[-2],
                    ll.shape[-1],
                    dtype=ll.dtype,
                    device=ll.device,
                )

            # 'Unpad' added dimensions: drop one trailing row/column when the
            # lowpass is larger than this level's bandpass.
            if ll.shape[-2] > h.shape[-2]:
                ll = ll[..., :-1, :]
            if ll.shape[-1] > h.shape[-1]:
                ll = ll[..., :-1]
            ll = lowlevel.SFB2D.apply(
                ll, h, self.g0_col, self.g1_col, self.g0_row, self.g1_row, padding_method
            )
        return ll
Example #4
0
    def forward(self, coeffs):
        """
        Args:
            coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where:
              yl is a lowpass tensor of shape :math:`(N, C_{in}, H_{in}',
              W_{in}')` and yh is a list of bandpass tensors of shape
              :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. I.e. should match
              the format returned by DWTForward

        Returns:
            Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})`

        Note:
            :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly
            downsampled shapes of the DWT pyramid.

        Note:
            Can have None for any of the highpass scales and will treat the
            values as zeros (not in an efficient way though). The zeros are
            created with the lowpass's dtype and device.
        """
        yl, yh = coeffs
        ll_prev = yl
        mode = lowlevel.mode_to_int(self.mode)

        # Do a multilevel inverse transform, coarsest scale first
        for h in yh[::-1]:
            if h is None:
                # Missing scale: zeros with ll_prev's dtype as well as device,
                # so a non-default-dtype lowpass is not paired with
                # default-dtype bands.
                h = torch.zeros(ll_prev.shape[0],
                                ll_prev.shape[1],
                                3,
                                ll_prev.shape[-2],
                                ll_prev.shape[-1],
                                dtype=ll_prev.dtype,
                                device=ll_prev.device)

            # 'Unpad' added dimensions: drop one trailing row/column when the
            # lowpass is larger than this level's bandpass.
            if ll_prev.shape[-2] > h.shape[-2]:
                ll_prev = ll_prev[..., :-1, :].clone()
            if ll_prev.shape[-1] > h.shape[-1]:
                ll_prev = ll_prev[..., :-1].clone()
            ll_prev = lowlevel.SFB2D.apply(ll_prev, h, self.g0_col.clone(),
                                           self.g1_col.clone(),
                                           self.g0_row.clone(),
                                           self.g1_row.clone(), mode)
        return ll_prev
Example #5
0
    def forward(self, x):
        """Run the multilevel 1D DWT analysis on ``x``.

        Args:
            x (tensor): Input of shape :math:`(N, C_{in}, L_{in})`

        Returns:
            (yl, yh)
                tuple of lowpass (yl) and bandpass (yh) coefficients.
                yh is a list of length J with the first entry
                being the finest scale coefficients.
        """
        assert x.ndim == 3, "Can only handle 3d inputs (N, C, L)"
        pad_mode = lowlevel.mode_to_int(self.mode)
        lowpass, bands = x, []

        for _ in range(self.J):
            # Split the running lowpass into (coarser lowpass, detail band).
            lowpass, detail = lowlevel.AFB1D.apply(lowpass, self.h0, self.h1, pad_mode)
            bands.append(detail)

        return lowpass, bands
Example #6
0
    def __init__(self, J=1, wave="db1", padding_method="zero", device="cpu"):
        """Build the 2D analysis filter bank and register it as buffers.

        Parameters
        ----------
        J : int
            Number of decomposition levels; stored as ``self.J``.
        wave : str, pywt.Wavelet, or sequence
            A wavelet name (converted with ``pywt.Wavelet``), a ``pywt.Wavelet``
            (its ``dec_lo``/``dec_hi`` filters are used for both columns and
            rows), or a sequence of filters: ``(h0, h1)`` used for both
            directions, or ``(h0_col, h1_col, h0_row, h1_row)``.
        padding_method : str
            Padding mode name; kept as ``self.mode_str`` and converted with
            ``lowlevel.mode_to_int`` into ``self.mode``.
        device : str or torch.device
            Passed to ``lowlevel.prep_filt_afb2d`` when building the filters.

        Raises
        ------
        ValueError
            If ``wave`` is a sequence whose length is neither 2 nor 4.
        """
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        if isinstance(wave, pywt.Wavelet):
            h0_col, h1_col = wave.dec_lo, wave.dec_hi
            h0_row, h1_row = h0_col, h1_col
        elif len(wave) == 2:
            h0_col, h1_col = wave[0], wave[1]
            h0_row, h1_row = h0_col, h1_col
        elif len(wave) == 4:
            h0_col, h1_col = wave[0], wave[1]
            h0_row, h1_row = wave[2], wave[3]
        else:
            # Without this branch the filter names below would be unbound.
            raise ValueError(
                "wave must be a wavelet name, a pywt.Wavelet, or a sequence of "
                f"2 or 4 filters; got a sequence of length {len(wave)}"
            )

        # Prepare the filters and register them as buffers
        filts = lowlevel.prep_filt_afb2d(h0_col, h1_col, h0_row, h1_row, device=device)
        self.register_buffer("h0_col", filts[0])
        self.register_buffer("h1_col", filts[1])
        self.register_buffer("h0_row", filts[2])
        self.register_buffer("h1_row", filts[3])
        self.J = J
        self.mode_str = padding_method
        self.mode = lowlevel.mode_to_int(self.mode_str)
def setup():
    """Initialise module-level globals and ask py3nvml to grab one GPU.

    Sets ``mode`` to ``mode_to_int('symmetric')``, ``o_dim`` to 2 and
    ``ri_dim`` to -1 (presumably the orientation and real/imaginary axes
    used by the DTCWT code under test), then calls ``py3nvml.grab_gpus``
    with one GPU, ``gpu_fraction=0.5`` and ``env_set_ok=True``.
    """
    global mode, o_dim, ri_dim
    mode, o_dim, ri_dim = mode_to_int('symmetric'), 2, -1
    py3nvml.grab_gpus(1, gpu_fraction=0.5, env_set_ok=True)
Example #8
0
    def forward(self, coeffs):
        """Inverse dual-tree complex wavelet transform.

        Reconstructs levels J..2 with ``INV_J2PLUS`` (coarsest first) and then
        level 1 with ``INV_J1``.

        Args:
            coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where:
                yl is a tensor of shape :math:`(N, C_{in}, H_{in}', W_{in}')`
                and yh is a list of the complex bandpass coefficients of shape
                :math:`list(N, C_{in}, 6, H_{in}'', W_{in}'', 2)`, or similar
                depending on o_dim and ri_dim

        Returns:
            Reconstructed output

        Note:
            Can accept Nones or an empty tensor (torch.tensor([])) for the
            lowpass or bandpass inputs. In these cases, an array of zeros
            replaces that input.
            NOTE(review): the bandpass check below skips ``None`` and 0-dim
            tensors (``shape == torch.Size([])``); ``torch.tensor([])`` has
            shape ``(0,)`` and would reach the asserts instead. Whenever a
            bandpass is present ``low.shape`` is read, so a ``None`` lowpass
            fails there. The zero substitution presumably happens inside
            INV_J2PLUS / INV_J1, which are not visible here. Confirm the
            intended contract.

        Note:
            :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` are the shapes of a
            DTCWT pyramid.

        Note:
            If include_scale was true for the forward pass, you should provide
            only the final lowpass output here, as normal for an inverse wavelet
            transform.
        """
        low, highs = coeffs
        J = len(highs)
        mode = mode_to_int(self.mode)
        # Presumably the height/width axis indices of the 6-D bandpass tensors
        # for this o_dim/ri_dim layout.
        _, _, h_dim, w_dim = get_dimensions6(self.o_dim, self.ri_dim)
        # Levels J..2, coarsest first (highs[0] is handled after the loop);
        # `j` is not used in the body.
        for j, s in zip(range(J - 1, 0, -1), highs[1:][::-1]):
            if s is not None and s.shape != torch.Size([]):
                assert s.shape[self.o_dim] == 6, "Inverse transform must " \
                    "have input with 6 orientations"
                assert len(s.shape) == 6, "Bandpass inputs must have " \
                    "6 dimensions"
                assert s.shape[self.ri_dim] == 2, "Inputs must be complex " \
                    "with real and imaginary parts in the ri dimension"
                # Ensure the low and highpass are the right size: when the
                # lowpass is not exactly twice the bandpass size, drop one
                # row / column from each edge.
                r, c = low.shape[2:]
                r1, c1 = s.shape[h_dim], s.shape[w_dim]
                if r != r1 * 2:
                    low = low[:, :, 1:-1]
                if c != c1 * 2:
                    low = low[:, :, :, 1:-1]

            # NOTE(review): `.forward` is called directly rather than `.apply`;
            # if INV_J2PLUS is a torch.autograd.Function this bypasses the
            # autograd machinery. Confirm this is intended.
            low = INV_J2PLUS.forward(low, s, self.g0a, self.g1a, self.g0b,
                                     self.g1b, self.o_dim, self.ri_dim, mode)

        # Ensure the low and highpass are the right size (level 1), using the
        # same edge-trimming rule as above.
        if highs[0] is not None and highs[0].shape != torch.Size([]):
            r, c = low.shape[2:]
            r1, c1 = highs[0].shape[h_dim], highs[0].shape[w_dim]
            if r != r1 * 2:
                low = low[:, :, 1:-1]
            if c != c1 * 2:
                low = low[:, :, :, 1:-1]

        # Final (finest) level uses the level-1 filters g0o / g1o.
        low = INV_J1.forward(low, highs[0], self.g0o, self.g1o, self.o_dim,
                             self.ri_dim, mode)
        return low
Example #9
0
    def forward(self, x):
        """Forward dual-tree complex wavelet transform.

        Args:
            x (tensor): Input to transform. Should be of shape
                :math:`(N, C_{in}, H_{in}, W_{in})`.

        Returns:
            (yl, yh)
                tuple of lowpass (yl) and bandpass (yh) coefficients.
                If include_scale was true, yl will be a list of lowpass
                coefficients, otherwise will be just the final lowpass
                coefficient of shape :math:`(N, C_{in}, H_{in}', W_{in}')`. Yh
                will be a list of the complex bandpass coefficients of shape
                :math:`list(N, C_{in}, 6, H_{in}'', W_{in}'', 2)`, or similar
                shape depending on o_dim and ri_dim

        Note:
            :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` are the shapes of a
            DTCWT pyramid.
        """
        # Placeholder lists: every slot initially refers to one shared 0-dim
        # zero tensor (one tensor per list).
        scales = [x.new_zeros([])] * self.J
        highs = [x.new_zeros([])] * self.J
        mode = mode_to_int(self.mode)
        if self.J == 0:
            return x, None

        # Level 1: make height / width even by repeating the last row / column.
        rows, cols = x.shape[2:]
        if rows % 2:
            x = torch.cat((x, x[:, :, -1:]), dim=2)
        if cols % 2:
            x = torch.cat((x, x[:, :, :, -1:]), dim=3)

        low, highs[0] = FWD_J1.forward(x, self.h0o, self.h1o, self.skip_hps[0],
                                       self.o_dim, self.ri_dim, mode)
        if self.include_scale[0]:
            scales[0] = low

        for level in range(1, self.J):
            # When a dim is not a multiple of 4, replicate one row / column at
            # each edge (intended to make the lowpass divisible by 4).
            rows, cols = low.shape[2:]
            if rows % 4:
                low = torch.cat((low[:, :, 0:1], low, low[:, :, -1:]), dim=2)
            if cols % 4:
                low = torch.cat((low[:, :, :, 0:1], low, low[:, :, :, -1:]),
                                dim=3)

            low, highs[level] = FWD_J2PLUS.forward(
                low, self.h0a, self.h1a, self.h0b, self.h1b,
                self.skip_hps[level], self.o_dim, self.ri_dim, mode)
            if self.include_scale[level]:
                scales[level] = low

        lowpass_out = scales if True in self.include_scale else low
        return lowpass_out, highs
Example #10
0
    def forward(self, x):
        """
        Forward pass of the DWT, also packing every level into a single tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape :math:`(N, C_{in}, H_{in}, W_{in})`

        Returns
        -------
        overwritten_results : torch.Tensor or None
            4-D tensor of shape :math:`(N, C_{in}, 2 H_1, 2 W_1)`, where :math:`(H_1, W_1)` is
            the spatial size of the first-level lowpass.

            Each level with lowpass size :math:`(h, w)` writes into the region ``[:2h, :2w]`` of
            this tensor:

            - its lowpass into the top-left quadrant;
            - ``high[:, :, 0]`` into the top-right quadrant;
            - ``high[:, :, 1]`` into the bottom-left quadrant;
            - ``high[:, :, 2]`` into the bottom-right quadrant.

            Deeper levels therefore overwrite earlier levels' top-left region. ``None`` when
            ``self.decomp_level`` is 0.
        yl : torch.Tensor
            the lowpass coefficients. yl has shape :math:`(N, C_{in}, H_{in}', W_{in}')`.
        yh : list of torch.Tensor
            the bandpass coefficients. yh is a list of length `self.decomp_level` with the first
            entry being the finest scale coefficients. Each entry has shape
            :math:`(N, C_{in}, 3, H_{in}'', W_{in}'')`. The new dimension iterates over the LH,
            HL and HH coefficients.

        Notes
        -----
        - :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly downsampled shapes of
          the DWT pyramid.
        - The analysis filters are moved (and stored back on ``self``) to ``x``'s dtype and device
          on every call.
        - NOTE(review): if a level's lowpass is more than half the size of the previous level's
          lowpass (this depends on AFB2D's padding mode), its quadrants overlap the previous
          level's bandpass quadrants in the packed tensor.
        """
        yh = []
        ll = x
        padding_method = lowlevel.mode_to_int(self.padding_method)
        full = None

        # Match every analysis filter to the input's dtype and device.
        self.h0_col = self.h0_col.to(dtype=x.dtype, device=x.device)
        self.h1_col = self.h1_col.to(dtype=x.dtype, device=x.device)
        self.h0_row = self.h0_row.to(dtype=x.dtype, device=x.device)
        self.h1_row = self.h1_row.to(dtype=x.dtype, device=x.device)

        # Do a multilevel transform
        for _ in range(self.decomp_level):
            # Do a level of the transform
            ll, high = lowlevel.AFB2D.apply(
                ll, self.h0_col, self.h1_col, self.h0_row, self.h1_row, padding_method
            )

            h, w = ll.shape[-2:]
            if full is None:
                # First level: size the packed tensor at twice this lowpass.
                full_shape = list(ll.shape)
                full_shape[-2] *= 2
                full_shape[-1] *= 2
                full = torch.zeros(full_shape, device=ll.device, dtype=ll.dtype)
            full[:, :, :h, :w] = ll
            full[:, :, :h, w : 2 * w] = high[:, :, 0]
            full[:, :, h : 2 * h, :w] = high[:, :, 1]
            full[:, :, h : 2 * h, w : 2 * w] = high[:, :, 2]

            yh.append(high)

        return full, ll, yh