示例#1
0
    def update_jacobians(self):
        """Run one update of the Jacobian state ``Jx``, ``Ju`` and ``Jv``.

        Each frequency plane of ``Jxt`` is rebuilt from ``Jx``, the filtered
        ``Jx`` cube minus ``Hn``, the ``Recomp`` of ``Ju`` and the inverse
        DCT of ``Jv``.  ``Ju`` and ``Jv`` are then advanced from the
        extrapolation ``2 * Jxt - Jx`` and ``Jxt`` is copied into ``Jx``.
        Finally ``self.wmsesure()`` is appended to ``wmselistsure`` and, when
        ``any(self.truesky)`` is true, ``self.psnrsure()`` is appended to
        ``psnrlistsure``.

        Returns:
            The last element of ``self.wmselistsure``.
        """
        # Inverse DCT of Jv along axis 2 (the axis indexed by frequency below).
        spectral_term = idct(self.Jv, axis=2, norm='ortho')

        # Jx multiplied by hth_fft in the Fourier domain, then offset by Hn.
        filtered = myifftshift(self.ifft2(self.fft2(self.Jx) * self.hth_fft))
        data_term = filtered.real - self.Hn

        for k in range(self.nfreq):
            # Recomposition of the Ju coefficients of this frequency.
            spatial_term = self.Recomp(self.Ju[k], self.nbw_recomp)
            jx_plane = self.Jx[:, :, k]

            # Step of size tau on Jx, masked by heavy(xtt).
            # heavy() is defined elsewhere -- presumably a step function.
            descent = (data_term[:, :, k]
                       + self.mu_s * self.alpha_s[k] * spatial_term
                       + self.mu_l * self.alpha_l * spectral_term[:, :, k])
            candidate = jx_plane - self.tau * descent
            self.Jxt[:, :, k] = heavy(self.xtt[:, :, k]) * candidate

            # Ju update from the decomposition of 2 * Jxt - Jx, masked by
            # rect(utt) (rect() is defined elsewhere).
            extrapolated = 2 * self.Jxt[:, :, k] - jx_plane
            coeffs = self.Decomp(extrapolated, self.nbw_decomp)
            dual_step = self.sigma * self.mu_s * self.alpha_s[k]
            for band in self.nbw_decomp:
                raw = self.Ju[k][band] + dual_step * coeffs[band]
                self.Ju[k][band] = rect(self.utt[k][band]) * raw

        # Jv update from the DCT (axis 2) of 2 * Jxt - Jx, masked by rect(vtt).
        v_step = self.sigma * self.mu_l * self.alpha_l[..., None]
        v_raw = self.Jv + v_step * dct(
            2 * self.Jxt - self.Jx, axis=2, norm='ortho')
        self.Jv = rect(self.vtt) * v_raw

        # Jxt becomes the new Jx (Fortran-ordered copy).
        self.Jx = self.Jxt.copy(order='F')

        self.wmselistsure.append(self.wmsesure())

        # NOTE(review): if `any` is the builtin and truesky is a
        # multi-dimensional ndarray this test raises ValueError -- confirm
        # which `any` is in scope in this module.
        if any(self.truesky):
            self.psnrlistsure.append(self.psnrsure())

        return self.wmselistsure[-1]
示例#2
0
    def dx2_mu(self):
        """Advance the ``*_s`` and ``*_l`` auxiliary states by one iteration.

        Two passes with the same layout are run one after the other:

        * the first updates ``dxt2_s``, ``du2_s``, ``dv2_s`` and ``dx2_s``;
          its source terms come from ``u2`` and ``2 * xt2 - x2``;
        * the second updates ``dxt2_l``, ``du2_l``, ``dv2_l`` and ``dx2_l``;
          its source terms come from ``v2`` and ``2 * xt2 - x2``.

        Presumably these are the derivatives of the iterates with respect
        to ``mu_s`` and ``mu_l`` -- TODO confirm against the caller.

        Nothing is returned; all results are stored on ``self``.
        """
        # ------------------------------------------------------------------
        # Pass 1: "_s" quantities
        # ------------------------------------------------------------------
        spectral_s = idct(self.dv2_s, axis=2, norm='ortho')

        # dx2_s multiplied by hth_fft in the Fourier domain.  A `- self.fty`
        # offset was commented out in the original code; only the real part
        # is used here.
        filtered_s = myifftshift(
            self.ifft2(self.fft2(self.dx2_s) * self.hth_fft))
        grad_s = filtered_s.real

        for k in range(self.nfreq):
            alpha_k = self.alpha_s[k]

            # Recomposition of u2 and of du2_s, each with its own weight.
            from_u = alpha_k * self.Recomp(self.u2[k], self.nbw_recomp)
            from_du = self.mu_s * alpha_k * self.Recomp(
                self.du2_s[k], self.nbw_recomp)
            adjoint_s = from_u + from_du

            # Step of size tau on dx2_s, masked by heavy(xtt2).
            descent_s = (grad_s[:, :, k] + adjoint_s
                         + self.mu_l * self.alpha_l * spectral_s[:, :, k])
            candidate_s = self.dx2_s[:, :, k] - self.tau * descent_s
            self.dxt2_s[:, :, k] = heavy(self.xtt2[:, :, k]) * candidate_s

            # du2_s update, masked by rect(utt2).
            from_x = alpha_k * (2 * self.xt2[:, :, k] - self.x2[:, :, k])
            from_dx = self.mu_s * alpha_k * (
                2 * self.dxt2_s[:, :, k] - self.dx2_s[:, :, k])
            coeffs_s = self.Decomp(from_x + from_dx, self.nbw_decomp)
            for band in self.nbw_decomp:
                raw_s = self.du2_s[k][band] + self.sigma * coeffs_s[band]
                self.du2_s[k][band] = rect(self.utt2[k][band]) * raw_s

        # dv2_s update, masked by rect(vtt2).
        v_step_s = self.sigma * self.mu_l * self.alpha_l[..., None]
        v_raw_s = self.dv2_s + v_step_s * dct(
            2 * self.dxt2_s - self.dx2_s, axis=2, norm='ortho')
        self.dv2_s = rect(self.vtt2) * v_raw_s

        self.dx2_s = self.dxt2_s.copy(order='F')

        # ------------------------------------------------------------------
        # Pass 2: "_l" quantities
        # ------------------------------------------------------------------
        # Inverse DCT of the weighted dv2_l plus the weighted v2.
        weighted_dv = self.dv2_l * self.mu_l * self.alpha_l[..., None]
        weighted_v = self.v2 * self.alpha_l[..., None]
        spectral_l = np.asfortranarray(
            idct(weighted_dv + weighted_v, axis=2, norm='ortho'))

        # Same filtering as in pass 1, applied to dx2_l.
        filtered_l = myifftshift(
            self.ifft2(self.fft2(self.dx2_l) * self.hth_fft))
        grad_l = filtered_l.real

        for k in range(self.nfreq):
            scale_k = self.mu_s * self.alpha_s[k]

            adjoint_l = scale_k * self.Recomp(self.du2_l[k], self.nbw_recomp)

            # Step of size tau on dx2_l, masked by heavy(xtt2).
            descent_l = grad_l[:, :, k] + adjoint_l + spectral_l[:, :, k]
            candidate_l = self.dx2_l[:, :, k] - self.tau * descent_l
            self.dxt2_l[:, :, k] = heavy(self.xtt2[:, :, k]) * candidate_l

            # du2_l update, masked by rect(utt2).
            coeffs_l = self.Decomp(
                scale_k * (2 * self.dxt2_l[:, :, k] - self.dx2_l[:, :, k]),
                self.nbw_decomp)
            for band in self.nbw_decomp:
                raw_l = self.du2_l[k][band] + self.sigma * coeffs_l[band]
                self.du2_l[k][band] = rect(self.utt2[k][band]) * raw_l

        # dv2_l update, masked by rect(vtt2); it carries an additional term
        # built from the DCT of 2 * xt2 - x2.
        from_dxl = self.sigma * self.mu_l * self.alpha_l[..., None] * dct(
            2 * self.dxt2_l - self.dx2_l, axis=2, norm='ortho')
        from_xl = self.sigma * self.alpha_l[..., None] * dct(
            2 * self.xt2 - self.x2, axis=2, norm='ortho')
        v_raw_l = self.dv2_l + from_dxl + from_xl
        self.dv2_l = rect(self.vtt2) * v_raw_l

        self.dx2_l = self.dxt2_l.copy(order='F')