Example #1
    def forward(self, input):

        latent_map = self.encoder(input)

        contents = self.content_decoder(latent_map)
        spectral_contents = fft2(contents)

        attention = self.attention_decoder(latent_map)
        spectral_attention = self.spectral_attention_decoder(latent_map)

        # Unfold the tensor into a list of images
        content_imgs = []
        attention_imgs = []
        spectral_attention_imgs = []

        a_background = attention[:, self.content_dim:self.content_dim + 1, :, :]
        sa_background = spectral_attention[:, self.content_dim:self.content_dim + 1, :, :]

        content_imgs.append(input)
        attention_imgs.append(a_background)
        spectral_attention_imgs.append(sa_background)

        img = input * a_background.repeat(1, 3, 1, 1)
        s_img = self.sf_ratio * torch.abs(
            ifft2(fft2(input) * sa_background.repeat(1, 3, 1, 1)))

        for i in range(self.content_dim):
            c = contents[:, i * 3:(i + 1) * 3, :, :]
            content_imgs.append(c)
            sc = spectral_contents[:, i * 3:(i + 1) * 3, :, :]

            a = attention[:, i:i + 1, :, :]
            attention_imgs.append(a)

            sa = spectral_attention[:, i:i + 1, :, :]
            spectral_attention_imgs.append(sa)

            img += c * a.repeat(1, 3, 1, 1)
            s_img += self.sf_ratio * torch.abs(
                ifft2(sc * sa.repeat(1, 3, 1, 1)))

        gen_img = self.fuser(torch.cat([img, s_img], 1))
        # if self.visualize_hidden:
        return gen_img, content_imgs[0], attention_imgs[0], spectral_attention_imgs[0]
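The spectral branch above relies on the convolution theorem: multiplying an image's spectrum by a mask and inverting is a circular convolution in the pixel domain. A minimal sketch of that masking step, with a hypothetical soft low-pass mask standing in for the learned spectral attention map:

import torch
from torch.fft import fft2, ifft2

x = torch.rand(1, 3, 64, 64)  # dummy image batch (N, C, H, W)

# Hypothetical stand-in for a learned spectral attention map: a soft
# low-pass mask over the (unshifted) frequency grid.
fy = torch.fft.fftfreq(64).abs()
fx = torch.fft.fftfreq(64).abs()
mask = torch.exp(-(fy[:, None]**2 + fx[None, :]**2) / 0.02)  # (H, W)

# Same pattern as the forward pass above: mask the spectrum, invert, and
# take the modulus to return to a real-valued image.
s_img = torch.abs(ifft2(fft2(x) * mask))
print(s_img.shape)  # torch.Size([1, 3, 64, 64])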
Example #2
def sdsp(
    x: Tensor,
    filtr: Tensor,
    value_range: float = 1.,
    sigma_c: float = 0.001,
    sigma_d: float = 145.,
) -> Tensor:
    r"""Detects salient regions from :math:`x`.

    Args:
        x: An input tensor, :math:`(N, 3, H, W)`.
        filtr: The frequency domain filter, :math:`(H, W)`.
        value_range: The value range :math:`L` of the input (usually `1.` or `255`).

    Note:
        For the remaining arguments, refer to [Zhang2013]_.

    Returns:
        The visual saliency tensor, :math:`(N, H, W)`.

    Example:
        >>> x = torch.rand(5, 3, 256, 256)
        >>> filtr = sdsp_filter(x)
        >>> vs = sdsp(x, filtr)
        >>> vs.size()
        torch.Size([5, 256, 256])
    """

    x_lab = xyz_to_lab(rgb_to_xyz(x, value_range))

    # Frequency prior
    x_f = fft.ifft2(fft.fft2(x_lab) * filtr)
    x_f = cx.real(torch.view_as_real(x_f))

    s_f = l2_norm(x_f, dims=[1])

    # Color prior
    x_ab = x_lab[:, 1:]

    lo, _ = x_ab.flatten(-2).min(dim=-1)
    up, _ = x_ab.flatten(-2).max(dim=-1)

    lo = lo.view(lo.shape + (1, 1))
    up = up.view(lo.shape)
    span = torch.where(up > lo, up - lo, torch.tensor(1.).to(lo))

    x_ab = (x_ab - lo) / span

    s_c = 1. - torch.exp(-torch.sum(x_ab**2, dim=1) / sigma_c**2)

    # Location prior
    a, b = [torch.arange(n).to(x) - (n - 1) / 2 for n in x.shape[-2:]]

    s_d = torch.exp(-(a[:, None]**2 + b[None, :]**2) / sigma_d**2)  # (H, W)

    # Visual saliency
    vs = s_f * s_c * s_d

    return vs
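The location prior builds a centered Gaussian from two broadcast coordinate vectors rather than a full meshgrid. A standalone sketch of the same construction (sizes are arbitrary):

import torch

H, W, sigma_d = 4, 6, 145.

# Centered coordinates along each spatial axis, as in sdsp above.
a = torch.arange(H) - (H - 1) / 2  # rows
b = torch.arange(W) - (W - 1) / 2  # columns

# A column broadcast against a row yields the full (H, W) grid, so the
# Gaussian peaks at the image center.
s_d = torch.exp(-(a[:, None]**2 + b[None, :]**2) / sigma_d**2)
print(s_d.shape)  # torch.Size([4, 6])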
Example #3
    def reconstruct_from_pyramid(self, pyramid):
        # Work out pyramid parameters, and check they match current settings.
        n_orientations = len(pyramid[0]['b'])
        n_levels = len(pyramid)
        size = pyramid[0]['h'].size()
        if n_orientations != self.n_orientations or \
            n_levels != self.n_levels or \
            size != self.image_size:
            self.n_orientations = n_orientations
            self.image_size = size
            self.n_levels = n_levels
            self.calculate_filters()

        curr = fftshift(fft2(pyramid[-1]['l']))
        for lev in range(len(pyramid) - 2, -1, -1):
            # Upsample the current reconstruction: zero-pad the centred
            # spectrum; the factor 4 keeps the pixel scale after ifft2.
            tmp = torch.zeros(
                curr.size(0), curr.size(1),
                curr.size(2) * 2, curr.size(3) * 2,
                dtype=torch.complex64, device=curr.device)
            offsety = curr.size(-2) // 2
            offsetx = curr.size(-1) // 2
            tmp[:, :, offsety:3 * offsety, offsetx:3 * offsetx] = curr * 4
            curr = tmp

            curr = curr * self.L_FILT[lev]

            for b in range(len(self.B_FILT[lev])):
                curr += self.B_FILT[lev][b] * fftshift(fft2(pyramid[lev]['b'][b]))

        reconstruction = curr * self.L0_FILT \
            + fftshift(fft2(pyramid[0]['h'])) * self.H0_FILT

        return torch.real(ifft2(ifftshift(reconstruction)))
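Zero-padding the centered spectrum doubles the spatial resolution, and the factor 4 compensates the unnormalized inverse FFT so pixel values keep their scale. A quick standalone check on a constant image:

import torch
from torch.fft import fft2, ifft2, fftshift, ifftshift

x = torch.full((1, 1, 8, 8), 2.0)  # constant test image
spec = fftshift(fft2(x))

# Embed the 8x8 spectrum at the center of a 16x16 zero spectrum, as in
# reconstruct_from_pyramid above.
up = torch.zeros(1, 1, 16, 16, dtype=torch.complex64)
up[:, :, 4:12, 4:12] = spec * 4  # x4 preserves the pixel scale

y = torch.real(ifft2(ifftshift(up)))
print(y.shape, y.mean())  # (1, 1, 16, 16), ~2.0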
Example #4
def fft(z):
    """
    Torch 2D FFT wrapper. No padding; the FFT is applied to the last two dimensions.

    Parameters
    ----------
    z : tensor
        Input.

    Returns
    -------
    tensor
        Output.

    """
    return fft2(z, s=(-1, -1))
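For reference, s=(-1, -1) asks torch.fft.fft2 to keep the current sizes of the two transformed dimensions (no cropping or zero-padding), so the wrapper behaves like a plain fft2(z):

import torch
from torch.fft import fft2

z = torch.rand(2, 3, 32, 32)
# -1 means "keep this dimension's size as-is", i.e. no padding.
print(torch.allclose(fft2(z, s=(-1, -1)), fft2(z)))  # True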
Example #5
    def construct_pyramid(self, image, n_levels, n_orientations):
        if image.size() != self.image_size or \
            n_levels != self.n_levels or \
            n_orientations != self.n_orientations:
            # Need to recalculate the filters.
            self.image_size = image.size()
            self.n_levels = n_levels
            self.n_orientations = n_orientations
            self.calculate_filters()
        ft = fftshift(fft2(image))

        curr_level = {}
        pyramid = []
        h0 = self.H0_FILT * ft
        curr_level['h'] = torch.real(ifft2(ifftshift(h0)))

        l0 = self.L0_FILT * ft
        curr_level['l'] = torch.real(ifft2(ifftshift(l0)))

        # Apply the bandpass filters (B) and downsample iteratively, saving each level.
        _last = l0
        for i in range(self.n_levels):
            curr_level['b'] = []
            for j in range(len(self.B_FILT[i])):
                lb = _last * self.B_FILT[i][j]
                curr_level['b'].append(torch.real(ifft2(ifftshift(lb))))

            # Apply the lowpass filter (L) in the Fourier domain.
            l1 = _last * self.L_FILT[i]

            # Downsample by extracting the central part of the DFT; the /4
            # compensates the change in FFT normalization.
            down_size = [l1.size(-2) // 4, l1.size(-1) // 4]
            down_image = l1[:, :, down_size[0]:3 * down_size[0],
                            down_size[1]:3 * down_size[1]] / 4

            _last = down_image.clone()
            pyramid.append(curr_level)
            curr_level = {}

        # lowpass residual
        curr_level['l'] = torch.real(ifft2(ifftshift(_last)))
        pyramid.append(curr_level)
        return pyramid
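Cropping the central half of the centered spectrum and dividing by 4 is the frequency-domain counterpart of 2x downsampling, the exact inverse of the x4 zero-padding used in reconstruct_from_pyramid. A standalone sketch:

import torch
from torch.fft import fft2, ifft2, fftshift, ifftshift

x = torch.full((1, 1, 16, 16), 2.0)  # constant test image
spec = fftshift(fft2(x))

# Keep the central 8x8 block of the spectrum; /4 undoes the FFT
# normalization change, mirroring construct_pyramid above.
q = 16 // 4
down = spec[:, :, q:3 * q, q:3 * q] / 4

y = torch.real(ifft2(ifftshift(down)))
print(y.shape, y.mean())  # (1, 1, 8, 8), ~2.0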
Example #6
    def forward(self, input):

        latent_map = self.encoder(input)

        contents = self.content_decoder(latent_map)
        spectral_contents = fft2(contents)

        attention = self.attention_decoder(latent_map)
        spectral_attention = self.spectral_attention_decoder(latent_map)

        # Unfold the tensor into a list of images
        content_imgs = {}
        attention_imgs = {}
        spectral_attention_imgs = {}

        a_background = attention[:, self.content_dim:self.content_dim+1, :, :]
        sa_background = spectral_attention[:, self.content_dim:self.content_dim+1, :, :]

        content_imgs['bg'] = input
        attention_imgs['bg'] = a_background
        spectral_attention_imgs['bg'] = sa_background

        gen_img = self.sf_ratio * input * a_background.repeat(1, 3, 1, 1) \
            + self.sf_ratio * torch.abs(ifft2(fft2(input) * sa_background.repeat(1, 3, 1, 1)))

        for i in range(self.content_dim):
            c = contents[:, i * 3:(i + 1) * 3, :, :]
            content_imgs['layer' + str(i)] = c
            sc = spectral_contents[:, i * 3:(i + 1) * 3, :, :]

            a = attention[:, i:i+1, :, :]
            attention_imgs['layer' + str(i)] = a

            sa = spectral_attention[:, i:i+1, :, :]
            spectral_attention_imgs['layer' + str(i)] = sa

            gen_img += c * a.repeat(1, 3, 1, 1) \
                + self.sf_ratio * torch.abs(ifft2(sc * sa.repeat(1, 3, 1, 1)))

        gen_img = torch.tanh(gen_img)  # F.tanh is deprecated
        # if self.visualize_hidden:
        return gen_img, content_imgs['layer1'], attention_imgs['bg'], spectral_attention_imgs['bg']
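As an aside, the .repeat(1, 3, 1, 1) calls are not strictly required: a (N, 1, H, W) attention map broadcasts against (N, 3, H, W) content automatically. A quick check:

import torch

c = torch.rand(2, 3, 16, 16)  # content layer
a = torch.rand(2, 1, 16, 16)  # single-channel attention map

# Broadcasting matches the explicit channel repeat.
print(torch.equal(c * a, c * a.repeat(1, 3, 1, 1)))  # True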
Example #7
    def filters_tensor_morlet(self):

        J = self.J
        M = self.M
        N = self.N
        L = self.L

        if self.path is not None:
            hatpsi_ = torch.load(self.path + 'morlet_N' + str(N) + '_J' +
                                 str(J) + '_L' + str(L) + '.pt')  # (J,L,M,N,2)
            hatpsi = torch.cat((hatpsi_, torch.flip(hatpsi_, (2, 3))),
                               dim=1).numpy()  # (J,L2,M,N,2)
            fftpsi = hatpsi[..., 0] + hatpsi[..., 1] * 1.0j
            hatphi = torch.load(self.path + '/morlet_lp_N' + str(N) + '_J' +
                                str(J) + '_L' + str(L) +
                                '.pt').numpy()  # (M,N,2)
        else:
            Sihao_filters = FiltersSet(M=N, N=N, J=J, L=L).generate_morlet()
            hatpsi_ = Sihao_filters['psi']  # (J,L,M,N)
            hatpsi_ = torch.cat((hatpsi_[..., None], hatpsi_[..., None] * 0),
                                dim=-1)  # (J,L,M,N,2)
            hatpsi = torch.cat((hatpsi_, torch.flip(hatpsi_, (2, 3))),
                               dim=1).numpy()  # (J,L2,M,N,2)
            fftpsi = hatpsi[..., 0] + hatpsi[..., 1] * 1.0j
            hatphi = Sihao_filters['phi']  # (M,N)
            hatphi = torch.cat((hatphi[..., None], hatphi[..., None] * 0),
                               dim=-1).numpy()  # (M,N,2)

        A = self.A
        A_prime = self.A_prime

        alphas = np.arange(A, dtype=np.float32) / (max(A, 1)) * np.pi * 2
        alphas = np.exp(1j * alphas)

        alphas_prime = np.arange(A_prime, dtype=np.float32) / (max(
            A_prime, 1)) * np.pi * 2
        alphas_prime = np.exp(1j * alphas_prime)

        filt = np.zeros((J, 2 * L, A, self.M, self.N), dtype=np.complex128)
        filt_prime = np.zeros((J, 2 * L, A_prime, self.M, self.N),
                              dtype=np.complex128)

        for alpha in range(A):
            for j in range(J):
                for theta in range(L):
                    psi_signal = fftpsi[j, theta, ...]
                    filt[j, theta, alpha, :, :] = alphas[alpha] * psi_signal
                    filt[j, L+theta, alpha, :, :] = np.conj(alphas[alpha]) \
                        * psi_signal

        for alpha in range(A_prime):
            for j in range(J):
                for theta in range(L):
                    psi_signal = fftpsi[j, theta, ...]
                    filt_prime[j, theta,
                               alpha, :, :] = alphas_prime[alpha] * psi_signal
                    filt_prime[j, L + theta, alpha, :, :] = np.conj(
                        alphas_prime[alpha]) * psi_signal

        filters = np.stack((np.real(filt), np.imag(filt)), axis=-1)
        filters_prime = np.stack((np.real(filt_prime), np.imag(filt_prime)),
                                 axis=-1)

        self.hatphi = torch.view_as_complex(torch.FloatTensor(hatphi)).type(
            torch.cfloat)  # (M,N)
        self.hatpsi = torch.view_as_complex(torch.FloatTensor(filters)).type(
            torch.cfloat)
        self.hatpsi_prime = torch.view_as_complex(
            torch.FloatTensor(filters_prime)).type(torch.cfloat)

        # add haar
        self.hathaar2d = torch.view_as_complex(torch.zeros(3, M, N, 2))
        psi = torch.zeros(M, N, 2)
        psi[1, 1, 1] = 1 / 4
        psi[1, 2, 1] = -1 / 4
        psi[2, 1, 1] = 1 / 4
        psi[2, 2, 1] = -1 / 4
        self.hathaar2d[0, :, :] = fft.fft2(torch.view_as_complex(psi))

        psi[1, 1, 1] = 1 / 4
        psi[1, 2, 1] = 1 / 4
        psi[2, 1, 1] = -1 / 4
        psi[2, 2, 1] = -1 / 4
        self.hathaar2d[1, :, :] = fft.fft2(torch.view_as_complex(psi))

        psi[1, 1, 1] = 1 / 4
        psi[1, 2, 1] = -1 / 4
        psi[2, 1, 1] = -1 / 4
        psi[2, 2, 1] = 1 / 4
        self.hathaar2d[2, :, :] = fft.fft2(torch.view_as_complex(psi))

        # load masks for aperiodicity
        self.masks = maskns(J, M, N).unsqueeze(1).unsqueeze(1)  # (J,1,1,M,N)
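Each Haar response above is built by writing a 2x2 stencil into an otherwise zero (real, imaginary) image and taking its FFT. A standalone sketch of the first (horizontal-difference) filter, with arbitrary sizes:

import torch
from torch import fft

M = N = 8
psi = torch.zeros(M, N, 2)  # (..., 2) real/imag layout, as above
psi[1, 1, 1] = 1 / 4        # 2x2 stencil on the imaginary part
psi[1, 2, 1] = -1 / 4
psi[2, 1, 1] = 1 / 4
psi[2, 2, 1] = -1 / 4

hathaar = fft.fft2(torch.view_as_complex(psi))
print(hathaar.shape, hathaar.dtype)  # torch.Size([8, 8]) torch.complex64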
Example #8
    def forward(self, input):

        J = self.J
        M = self.M
        N = self.N
        A = self.A
        L = self.L
        L2 = 2 * L
        wavelets = self.wavelets

        x_c = padc(input)  # add zeros to imag part -> (nb,M,N)
        hatx_c = fft.fft2(torch.view_as_complex(x_c)).type(
            torch.cfloat)  # fft2 -> (nb,M,N)

        if self.chunk_id < self.nb_chunks:
            nb = hatx_c.shape[0]
            hatpsi_la = self.hatpsi[:, :L, ...]  # (J,L,A,M,N)
            nb_channels = self.this_wph['la1'].shape[0]
            t = 3 if wavelets == 'morlet' else 1 if wavelets == 'steer' else 0
            if self.chunk_id < self.nb_chunks - 1:
                Sout = input.new(nb, nb_channels, M, N)
            else:
                Sout = input.new(nb, nb_channels + 1 + t, M, N)
            idxb = 0
            hatx_bc = hatx_c[idxb, :, :]  # (M,N)

            hatxpsi_bc = hatpsi_la * hatx_bc.view(1, 1, 1, M, N)  # (J,L,A,M,N)
            xpsi_bc = fft.ifft2(hatxpsi_bc)
            xpsi_bc_ = torch.real(xpsi_bc).relu()
            xpsi_bc_ = xpsi_bc_ * self.masks
            xpsi_bc0 = self.subinitmean1(xpsi_bc_)
            xpsi_bc0_n = self.divinitstd1(xpsi_bc0)
            xpsi_bc0_ = xpsi_bc0_n.view(1, J * L * A, M, N)

            xpsi_bc_la1 = xpsi_bc0_[:, self.this_wph['la1'],
                                    ...]  # (1,P_c,M,N)
            xpsi_bc_la2 = xpsi_bc0_[:, self.this_wph['la2'],
                                    ...]  # (1,P_c,M,N)

            x1 = torch.view_as_complex(padc(xpsi_bc_la1))
            x2 = torch.view_as_complex(padc(xpsi_bc_la2))
            hatconv_xpsi_bc = fft.fft2(x1) * torch.conj(fft.fft2(x2))
            conv_xpsi_bc = torch.real(fft.ifft2(hatconv_xpsi_bc))
            masks_shift = self.masks_shift[self.this_wph['shifted'],
                                           ...].view(1, -1, M, N)
            corr_bc = conv_xpsi_bc * masks_shift

            Sout[idxb, 0:nb_channels, ...] = corr_bc[0, ...]

            if self.chunk_id == self.nb_chunks - 1:
                # ADD 1 channel for spatial phiJ
                # add l2 phiJ to last channel
                hatxphi_c = hatx_c * self.hatphi.view(1, M, N)  # (nb,M,N)
                xphi_c = fft.ifft2(hatxphi_c)  # back to the spatial domain
                # haar filters
                if wavelets == 'morlet':
                    for hid in range(3):
                        hatxpsih_c = hatx_c * self.hathaar2d[hid, :, :].view(
                            1, M, N)  # (nb,nc,M,N)
                        xpsih_c = fft.ifft2(hatxpsih_c)
                        xpsih_c = self.divinitstdH[hid](xpsih_c)
                        xpsih_c = xpsih_c * self.masks[0, ...].view(1, M, N)
                        xpsih_mod = fft.fft2(
                            torch.view_as_complex(padc(xpsih_c.abs())))
                        xpsih_mod2 = fft.ifft2(xpsih_mod *
                                               torch.conj(xpsih_mod))
                        xpsih_mod2 = torch.real(
                            xpsih_mod2[0, ...]) * self.masks_shift[-1, ...]
                        Sout[idxb, -4 + hid, ...] = xpsih_mod2

                # submean from spatial M N
                xphi0_c = self.subinitmeanJ(xphi_c)
                xphi0_c = self.divinitstdJ(xphi0_c)
                xphi0_c = xphi0_c * self.masks[-1, ...].view(1, M, N)
                xphi0_mod = fft.fft2(torch.view_as_complex(padc(
                    xphi0_c.abs())))  # (nb,nc,M,N)
                xphi0_mod2 = fft.ifft2(xphi0_mod *
                                       torch.conj(xphi0_mod))  # (nb,nc,M,N)
                xphi0_mean = torch.real(xphi0_mod2) * self.masks_shift[
                    -1, ...].view(1, M, N)
                '''
                # low-high corr
                l_h_1 = padc(xpsi_bc0_).fft(2)
                l_h_2 = self.subinitmeanin(hatx_bc)
                l_h_2 = self.divinitstdin(l_h_2)
                l_h = mulcu(l_h_1, conjugate(l_h_2)).ifft(2)[...,0]
                l_h = l_h * self.masks_shift[1,...].view(1,1,M,N)
                Sout[idxb, idxc, nb_channels+t:-(1+t),...] = l_h[0,...]
                '''
                Sout[idxb, -1, ...] = xphi0_mean[idxb, ...]

            Sout = Sout.view(nb, -1, M * N)[..., self.indices]
            Sout = Sout.view(-1)
            Sout = torch.cat(
                (Sout, input.mean().view(1), input.std().view(1))) * 1e-4

        return Sout
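The central correlation step uses the identity that ifft2(fft2(x1) * conj(fft2(x2))) is the circular cross-correlation of x1 and x2; at shift (0, 0) it reduces to the plain inner product. A minimal numeric check:

import torch
from torch.fft import fft2, ifft2

x1 = torch.rand(8, 8)
x2 = torch.rand(8, 8)

# FFT-based circular cross-correlation.
corr = torch.real(ifft2(fft2(x1) * torch.conj(fft2(x2))))

# The zero-shift entry is the inner product <x1, x2>.
print(torch.allclose(corr[0, 0], (x1 * x2).sum()))  # True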
Example #9
    def preprocess(self, X):
        Z = F.normalize(X)
        return fft2(Z, norm='ortho', dim=(2, 3))
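With norm='ortho' the transform is unitary, so it preserves the L2 norm (Parseval's theorem); a quick check:

import torch
from torch.fft import fft2

z = torch.rand(2, 3, 16, 16)
Z = fft2(z, norm='ortho', dim=(2, 3))

# A unitary FFT preserves energy.
print(torch.allclose(Z.abs().pow(2).sum(), z.pow(2).sum()))  # True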
Example #10
def phase_congruency(
    x: Tensor,
    filters: Tensor,
    value_range: float = 1.,
    k: float = 2.,
    rescale: float = 1.7,
    eps: float = 1e-8,
) -> Tensor:
    r"""Returns the Phase Congruency (PC) of :math:`x`.

    Args:
        x: An input tensor, :math:`(N, 1, H, W)`.
        filters: The frequency domain filters, :math:`(S_1, S_2, H, W)`.
        value_range: The value range :math:`L` of the input (usually `1.` or `255`).

    Note:
        For the remaining arguments, refer to [Kovesi1999]_.

    Returns:
        The PC tensor, :math:`(N, H, W)`.

    Example:
        >>> x = torch.rand(5, 1, 256, 256)
        >>> filters = pc_filters(x)
        >>> pc = phase_congruency(x, filters)
        >>> pc.size()
        torch.Size([5, 256, 256])
    """

    x = x * (255. / value_range)

    # Filters
    M_hat = filters
    M = fft.ifft2(M_hat)
    M = cx.real(torch.view_as_real(M))

    # Even & odd (real and imaginary) responses
    eo = fft.ifft2(fft.fft2(x[:, None]) * M_hat)
    eo = torch.view_as_real(eo)

    # Amplitude
    A = cx.mod(eo)

    # Expected E^2
    A2 = A[:, 0]**2
    median_A2, _ = A2.flatten(-2).median(dim=-1)
    expect_A2 = median_A2 / math.log(2)

    expect_M2_hat = (M_hat[0]**2).mean(dim=(-1, -2))
    expect_MiMj = (M[:, None] * M[None, :]).sum(dim=(0, 1, 3, 4))

    expect_E2 = expect_A2 * expect_MiMj / expect_M2_hat

    # Threshold
    sigma_G = expect_E2.sqrt()
    mu_R = sigma_G * (math.pi / 2)**0.5
    sigma_R = sigma_G * (2 - math.pi / 2)**0.5

    T = mu_R + k * sigma_R
    T = T / rescale  # empirical rescaling
    T = T[..., None, None]

    # Phase deviation
    FH = eo.sum(dim=1, keepdim=True)
    phi_eo = FH / (cx.mod(FH)[..., None] + eps)

    E = cx.dot(eo, phi_eo) - cx.dot(eo, cx.turn(phi_eo)).abs()
    E = E.sum(dim=1)

    # Phase congruency
    pc = (E - T).relu().sum(dim=1) / (A.sum(dim=(1, 2)) + eps)

    return pc
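The threshold block assumes Rayleigh-distributed noise amplitudes, so the squared amplitude A^2 is exponentially distributed and its median equals mean * ln(2); dividing the median by log 2 therefore recovers the mean, which is what expect_A2 does above. A quick sanity check on synthetic draws:

import math

import torch

# Exponential with mean 2: the median is 2 * ln(2), so median / ln(2) ~ 2.
a2 = torch.distributions.Exponential(rate=0.5).sample((100_000,))
print(a2.median() / math.log(2))  # ~2.0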