Example #1
    def forward(self, x):
        """ Forward pass of the DWT.

        Args:
            x (tensor): Input of shape :math:`(N, C_{in}, H_{in}, W_{in})`

        Returns:
            (yl, yh)
                tuple of lowpass (yl) and bandpass (yh)
                coefficients. yh is a list of length J with the first entry
                being the finest scale coefficients. yl has shape
                :math:`(N, C_{in}, H_{in}', W_{in}')` and yh has shape
                :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. The new
                dimension in yh iterates over the LH, HL and HH coefficients.

        Note:
            :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly
            downsampled shapes of the DWT pyramid.
        """
        yh = []
        ll = x
        mode = lowlevel.mode_to_int(self.mode)

        # Do a multilevel transform: every level replaces the running lowpass
        # with a coarser one and yields one stacked bandpass tensor.
        for _ in range(self.J):
            # All four analysis filters are handed over as clones, as in the
            # original code. NOTE(review): presumably so the autograd Function
            # never holds the registered buffers themselves -- confirm.
            ll, high = lowlevel.AFB2D.apply(
                ll,
                self.h0_col.clone(),
                self.h1_col.clone(),
                self.h0_row.clone(),
                self.h1_row.clone(),
                mode,
            )
            # Levels are produced finest-first, which is the documented order.
            yh.append(high)

        return ll, yh
Example #2
    def forward(self, coeffs):
        """ Reconstruct a 1D signal from its DWT coefficients.

        Args:
            coeffs (yl, yh): tuple of lowpass and bandpass coefficients, should
              match the format returned by DWT1DForward.

        Returns:
            Reconstructed input of shape :math:`(N, C_{in}, L_{in})`

        Note:
            Can have None for any of the highpass scales and will treat the
            values as zeros (not in an efficient way though).
        """
        x0, highs = coeffs
        assert x0.ndim == 3, "Can only handle 3d inputs (N, C, L)"
        mode = lowlevel.mode_to_int(self.mode)
        # Do a multilevel inverse transform. ``highs`` is stored finest-first,
        # so walk it backwards to start from the coarsest scale.
        for x1 in reversed(highs):
            if x1 is None:
                # Missing scale: substitute zeros shaped like the lowpass.
                x1 = torch.zeros_like(x0)

            # 'Unpad' added signal: the running lowpass may be one sample
            # longer than this scale's bandpass (presumably from padding an
            # odd length on the forward pass); trim so the two line up.
            if x0.shape[-1] > x1.shape[-1]:
                x0 = x0[..., :-1]
            x0 = lowlevel.SFB1D.apply(x0, x1, self.g0, self.g1, mode)
        return x0
Example #3
    def forward(self, coeffs=None):
        """
        Do the 2D DWT inverse reconstruction for a set of coefficients

        Parameters
        ----------
        coeffs: tuple, torch.Tensor
            tuple of lowpass and bandpass coefficients, where yl is a lowpass tensor of shape
            :math:`(N, C_{in}, H_{in}', W_{in}')` and yh is a list of bandpass tensors of shape
            :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. I.e. should match the format returned
            by DWTForward.
            If this input is a torch tensor, then this will assume that `coeffs` is the overwritten
            results returned by :func:`DWTForwardOverwrite <hyde.dwt3d.DWTForwardOverwrite>`

        Returns
        -------
        torch.Tensor
            Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})`

        Notes
        -----
        - :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly downsampled shapes
          of the DWT pyramid.
        - Can have None for any of the highpass scales and will treat the values as zeros (not
          in an efficient way though).
        - The synthesis filter buffers are moved to the dtype and device of the lowpass input
          on every call.
        """
        # NOTE(review): the docstring advertises a bare-tensor ("overwritten") input, but this
        # body only unpacks the ``(yl, yh)`` tuple form -- confirm where that case is converted.
        yl, yh = coeffs

        ll = yl
        padding_method = lowlevel.mode_to_int(self.padding_method)

        # Keep the synthesis filters on the same dtype/device as the coefficients.
        self.g0_col = self.g0_col.to(dtype=ll.dtype, device=ll.device)
        self.g1_col = self.g1_col.to(dtype=ll.dtype, device=ll.device)
        self.g0_row = self.g0_row.to(dtype=ll.dtype, device=ll.device)
        self.g1_row = self.g1_row.to(dtype=ll.dtype, device=ll.device)

        # Do a multilevel inverse transform, coarsest scale first (yh is stored finest-first).
        for h in reversed(yh):
            if h is None:
                # Missing scale: use zero bandpass coefficients. Match the lowpass dtype as well
                # as its device, so half/double inputs do not meet a float32 tensor in SFB2D.
                h = torch.zeros(
                    ll.shape[0],
                    ll.shape[1],
                    3,
                    ll.shape[-2],
                    ll.shape[-1],
                    device=ll.device,
                    dtype=ll.dtype,
                )

            # 'Unpad' added dimensions: trim a trailing row/column when the running lowpass is
            # larger than this scale's bandpass.
            if ll.shape[-2] > h.shape[-2]:
                ll = ll[..., :-1, :]
            if ll.shape[-1] > h.shape[-1]:
                ll = ll[..., :-1]
            ll = lowlevel.SFB2D.apply(
                ll, h, self.g0_col, self.g1_col, self.g0_row, self.g1_row, padding_method
            )
        return ll
Example #4
    def forward(self, coeffs):
        """ Reconstruct an image from its 2D DWT coefficients.

        Args:
            coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where:
              yl is a lowpass tensor of shape :math:`(N, C_{in}, H_{in}',
              W_{in}')` and yh is a list of bandpass tensors of shape
              :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. I.e. should match
              the format returned by DWTForward

        Returns:
            Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})`

        Note:
            :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly
            downsampled shapes of the DWT pyramid.

        Note:
            Can have None for any of the highpass scales and will treat the
            values as zeros (not in an efficient way though).
        """
        yl, yh = coeffs
        ll_prev = yl
        mode = lowlevel.mode_to_int(self.mode)

        # Do a multilevel inverse transform, coarsest scale first (yh is
        # stored finest-first).
        for h in reversed(yh):
            if h is None:
                # Missing scale: zero bandpass with the lowpass's dtype/device.
                h = torch.zeros(
                    ll_prev.shape[0],
                    ll_prev.shape[1],
                    3,
                    ll_prev.shape[-2],
                    ll_prev.shape[-1],
                    device=ll_prev.device,
                    dtype=ll_prev.dtype,
                )

            # 'Unpad' added dimensions. The trimmed views are cloned, as in
            # the original code, so SFB2D receives standalone tensors.
            if ll_prev.shape[-2] > h.shape[-2]:
                ll_prev = ll_prev[..., :-1, :].clone()
            if ll_prev.shape[-1] > h.shape[-1]:
                ll_prev = ll_prev[..., :-1].clone()
            # All four synthesis filters, in the same (col, col, row, row)
            # order used for SFB2D elsewhere.
            ll_prev = lowlevel.SFB2D.apply(
                ll_prev,
                h,
                self.g0_col.clone(),
                self.g1_col.clone(),
                self.g0_row.clone(),
                self.g1_row.clone(),
                mode,
            )
        return ll_prev
Example #5
    def forward(self, x):
        """ Forward pass of the DWT.

        Args:
            x (tensor): Input of shape :math:`(N, C_{in}, L_{in})`

        Returns:
            (yl, yh)
                tuple of lowpass (yl) and bandpass (yh) coefficients.
                yh is a list of length J with the first entry
                being the finest scale coefficients.
        """
        assert x.ndim == 3, "Can only handle 3d inputs (N, C, L)"
        highs = []
        x0 = x
        mode = lowlevel.mode_to_int(self.mode)

        # Do a multilevel transform: x0 is repeatedly split into a coarser
        # lowpass (kept for the next level) and a bandpass x1 (collected).
        for _ in range(self.J):
            x0, x1 = lowlevel.AFB1D.apply(x0, self.h0, self.h1, mode)
            highs.append(x1)

        return x0, highs
Example #6
    def __init__(self, J=1, wave="db1", padding_method="zero", device="cpu"):
        """
        Set up a multilevel 2D DWT with the given wavelet filters and padding.

        Parameters
        ----------
        J : int
            Number of decomposition levels (stored as ``self.J``).
        wave : str, pywt.Wavelet, or sequence
            A pywt wavelet name, a ``pywt.Wavelet``, or a sequence of filters: either
            ``(h0, h1)`` (used for both columns and rows) or
            ``(h0_col, h1_col, h0_row, h1_row)``.
        padding_method : str
            Padding mode name; kept in ``self.mode_str`` and converted with
            ``lowlevel.mode_to_int`` into ``self.mode``.
        device : str or torch.device
            Device forwarded to ``lowlevel.prep_filt_afb2d`` when building the filters.

        Raises
        ------
        ValueError
            If ``wave`` is a filter sequence whose length is neither 2 nor 4.
        """
        # Required before register_buffer can be used on a Module.
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        if isinstance(wave, pywt.Wavelet):
            h0_col, h1_col = wave.dec_lo, wave.dec_hi
            h0_row, h1_row = h0_col, h1_col
        elif len(wave) == 2:
            # Same lowpass/highpass pair for columns and rows.
            h0_col, h1_col = wave[0], wave[1]
            h0_row, h1_row = h0_col, h1_col
        elif len(wave) == 4:
            # Separate column and row filter pairs.
            h0_col, h1_col = wave[0], wave[1]
            h0_row, h1_row = wave[2], wave[3]
        else:
            raise ValueError(
                "wave must be a wavelet name, a pywt.Wavelet, or a sequence of 2 or 4 "
                f"filters; got a sequence of length {len(wave)}"
            )

        # Prepare the filters
        filts = lowlevel.prep_filt_afb2d(h0_col, h1_col, h0_row, h1_row, device=device)
        self.register_buffer("h0_col", filts[0])
        self.register_buffer("h1_col", filts[1])
        self.register_buffer("h0_row", filts[2])
        self.register_buffer("h1_row", filts[3])
        self.J = J
        self.mode_str = padding_method
        self.mode = lowlevel.mode_to_int(self.mode_str)
def setup():
    """Module-level setup: fill in the shared settings and grab a GPU.

    Binds the globals ``mode`` (``mode_to_int('symmetric')``), ``o_dim`` (2)
    and ``ri_dim`` (-1), then calls ``py3nvml.grab_gpus`` for one GPU with
    ``gpu_fraction=0.5``.
    """
    global mode, o_dim, ri_dim
    mode, o_dim, ri_dim = mode_to_int('symmetric'), 2, -1
    py3nvml.grab_gpus(1, gpu_fraction=0.5, env_set_ok=True)
Example #8
    def forward(self, coeffs):
        """ Inverse Dual Tree Complex Wavelet Transform.

        Args:
            coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where:
                yl is a tensor of shape :math:`(N, C_{in}, H_{in}', W_{in}')`
                and yh is a list of  the complex bandpass coefficients of shape
                :math:`list(N, C_{in}, 6, H_{in}'', W_{in}'', 2)`, or similar
                depending on o_dim and ri_dim

        Returns:
            Reconstructed output. If yh is empty or None there is nothing to
            invert and yl is returned unchanged.

        Note:
            Can accept Nones or an empty tensor (torch.tensor([])) for the
            lowpass or bandpass inputs. In these cases, an array of zeros
            replaces that input.
            NOTE(review): the checks below skip bandpass entries whose shape
            is ``torch.Size([])`` (a 0-dim tensor); ``torch.tensor([])`` has
            shape ``(0,)`` -- confirm which sentinel is intended.

        Note:
            :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` are the shapes of a
            DTCWT pyramid.

        Note:
            If include_scale was true for the forward pass, you should provide
            only the final lowpass output here, as normal for an inverse wavelet
            transform.
        """
        low, highs = coeffs
        # Zero decomposition levels: no bandpass to undo, so the lowpass is
        # already the signal (previously this crashed on len()/highs[0]).
        if not highs:
            return low
        mode = mode_to_int(self.mode)
        _, _, h_dim, w_dim = get_dimensions6(self.o_dim, self.ri_dim)

        # Do level 2 and greater inverse, coarsest scale first. highs[0] is
        # the finest scale and is handled by INV_J1 below.
        for s in reversed(highs[1:]):
            if s is not None and s.shape != torch.Size([]):
                # Check the rank first so the indexing in the following
                # asserts cannot raise IndexError on a malformed input.
                assert len(s.shape) == 6, "Bandpass inputs must have " \
                    "6 dimensions"
                assert s.shape[self.o_dim] == 6, "Inverse transform must " \
                    "have input with 6 orientations"
                assert s.shape[self.ri_dim] == 2, "Inputs must be complex " \
                    "with real and imaginary parts in the ri dimension"
                # Ensure the low and highpass are the right size: the lowpass
                # must be exactly twice the bandpass, otherwise drop the
                # first and last row/column.
                r, c = low.shape[2:]
                r1, c1 = s.shape[h_dim], s.shape[w_dim]
                if r != r1 * 2:
                    low = low[:, :, 1:-1]
                if c != c1 * 2:
                    low = low[:, :, :, 1:-1]

            low = INV_J2PLUS.forward(low, s, self.g0a, self.g1a, self.g0b,
                                     self.g1b, self.o_dim, self.ri_dim, mode)

        # Ensure the low and highpass are the right size
        if highs[0] is not None and highs[0].shape != torch.Size([]):
            r, c = low.shape[2:]
            r1, c1 = highs[0].shape[h_dim], highs[0].shape[w_dim]
            if r != r1 * 2:
                low = low[:, :, 1:-1]
            if c != c1 * 2:
                low = low[:, :, :, 1:-1]

        low = INV_J1.forward(low, highs[0], self.g0o, self.g1o, self.o_dim,
                             self.ri_dim, mode)
        return low
Example #9
    def forward(self, x):
        """ Forward Dual Tree Complex Wavelet Transform

        Args:
            x (tensor): Input to transform. Should be of shape
                :math:`(N, C_{in}, H_{in}, W_{in})`.

        Returns:
            (yl, yh)
                tuple of lowpass (yl) and bandpass (yh) coefficients.
                If include_scale was true, yl will be a list of lowpass
                coefficients, otherwise will be just the final lowpass
                coefficient of shape :math:`(N, C_{in}, H_{in}', W_{in}')`. Yh
                will be a list of the complex bandpass coefficients of shape
                :math:`list(N, C_{in}, 6, H_{in}'', W_{in}'', 2)`, or similar
                shape depending on o_dim and ri_dim. If J is 0, ``(x, None)``
                is returned.

        Note:
            :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` are the shapes of a
            DTCWT pyramid. Levels whose lowpass is not recorded in the scales
            list hold a 0-dim placeholder tensor.
        """
        mode = mode_to_int(self.mode)
        if self.J == 0:
            return x, None

        # One slot per level, filled with a 0-dim placeholder (shape
        # torch.Size([])). The original ``[] * self.J`` produced an empty list,
        # so the indexed assignments below raised IndexError.
        scales = [x.new_zeros([]) for _ in range(self.J)]
        highs = [x.new_zeros([]) for _ in range(self.J)]

        # If the row/col count of X is not divisible by 2 then we need to
        # extend X by repeating its last row/column.
        r, c = x.shape[2:]
        if r % 2 != 0:
            x = torch.cat((x, x[:, :, -1:]), dim=2)
        if c % 2 != 0:
            x = torch.cat((x, x[:, :, :, -1:]), dim=3)

        # Do the level 1 transform
        low, h = FWD_J1.forward(x, self.h0o, self.h1o, self.skip_hps[0],
                                self.o_dim, self.ri_dim, mode)
        highs[0] = h
        if self.include_scale[0]:
            scales[0] = low

        for j in range(1, self.J):
            # Ensure the lowpass is divisible by 4 by repeating its first and
            # last row/column.
            r, c = low.shape[2:]
            if r % 4 != 0:
                low = torch.cat((low[:, :, 0:1], low, low[:, :, -1:]), dim=2)
            if c % 4 != 0:
                low = torch.cat((low[:, :, :, 0:1], low, low[:, :, :, -1:]),
                                dim=3)

            low, h = FWD_J2PLUS.forward(low, self.h0a, self.h1a, self.h0b,
                                        self.h1b, self.skip_hps[j], self.o_dim,
                                        self.ri_dim, mode)
            highs[j] = h
            if self.include_scale[j]:
                scales[j] = low

        if True in self.include_scale:
            return scales, highs
        # Previously unreachable (no ``else``), so the method returned None
        # whenever no scale was requested.
        return low, highs
Example #10
    def forward(self, x):
        """
        Forward pass of the DWT.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape :math:`(N, C_{in}, H_{in}, W_{in})`

        Returns
        -------
        overwritten_results : torch.Tensor or None
            a single tensor holding every level's results. It has the first level's lowpass
            shape with both spatial dims doubled. At each level the lowpass is written to the
            top-left quadrant and the three bandpass sub-bands (``high[:, :, 0]``, ``1``, ``2``)
            to the top-right, bottom-left and bottom-right quadrants; the next level then
            overwrites the top-left quadrant. ``None`` if ``self.decomp_level`` is 0.
        yl : torch.Tensor
            the lowpass coefficients. yl has shape :math:`(N, C_{in}, H_{in}', W_{in}')`.
        yh : list of torch.Tensor
            the bandpass coefficients. yh is a list of length `self.decomp_level` with the first
            entry being the finest scale coefficients. it has shape :math:`list(N, C_{in}, 3,
            H_{in}'', W_{in}'')`. The new dimension in yh iterates over the LH, HL and HH
            coefficients.

        Notes
        -----
        :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly downsampled shapes of
        the DWT pyramid. The analysis filter buffers are moved to the input's dtype and device
        on every call.
        """
        yh = []
        ll = x
        padding_method = lowlevel.mode_to_int(self.padding_method)
        full = None

        # Keep the analysis filters on the same dtype/device as the input.
        self.h0_col = self.h0_col.to(dtype=x.dtype, device=x.device)
        self.h1_col = self.h1_col.to(dtype=x.dtype, device=x.device)
        self.h0_row = self.h0_row.to(dtype=x.dtype, device=x.device)
        self.h1_row = self.h1_row.to(dtype=x.dtype, device=x.device)

        # Do a multilevel transform
        for _ in range(self.decomp_level):
            # Do a level of the transform
            ll, high = lowlevel.AFB2D.apply(
                ll, self.h0_col, self.h1_col, self.h0_row, self.h1_row, padding_method
            )
            # Return the per-level bandpass too, as documented (finest first).
            yh.append(high)

            s0, s1 = ll.shape[-2:]
            if full is None:  # first iteration: the finest level fixes the output size
                full_shape = list(ll.shape)
                full_shape[-1] *= 2
                full_shape[-2] *= 2
                full = torch.zeros(full_shape, device=ll.device, dtype=ll.dtype)
            # NOTE(review): if a level's lowpass has an odd size and AFB2D pads, the next
            # level's quadrants may extend past the previous lowpass region and overwrite part
            # of the previous level's bandpass -- confirm against AFB2D's output sizes.
            full[:, :, :s0, :s1] = ll
            full[:, :, :s0, s1 : s1 * 2] = high[:, :, 0]
            full[:, :, s0 : s0 * 2, :s1] = high[:, :, 1]
            full[:, :, s0 : s0 * 2, s1 : s1 * 2] = high[:, :, 2]

        return full, ll, yh