예제 #1
0
 def test_10(self):
     # Build a signal from a sparse random coefficient map and check that
     # ParConvBPDN recovers both the coefficients and the signal itself.
     img_size = 63
     n_filters = 4
     filt_size = 8
     D = np.random.randn(filt_size, filt_size, n_filters)
     X0 = np.zeros((img_size, img_size, n_filters))
     # Keep only the (rare) entries whose magnitude exceeds 3 as support
     support = np.abs(np.random.randn(img_size, img_size, n_filters)) > 3
     X0[support] = np.random.randn(X0[support].size)
     # Synthesize S as the sum over filters of D * X0 (convolution via FFT)
     Df = fftn(D, (img_size, img_size), (0, 1))
     Xf = fftn(X0, None, (0, 1))
     S = np.sum(ifftn(Df * Xf, None, (0, 1)).real, axis=2)
     opt = parcbpdn.ParConvBPDN.Options({
         'Verbose': False,
         'MaxMainIter': 1000,
         'RelStopTol': 1e-3,
         'rho': 3e-3,
         'alpha': 6,
         'AutoRho': {'Enabled': False},
     })
     solver = parcbpdn.ParConvBPDN(D, S, 1e-4, opt=opt)
     solver.solve()
     assert sl.rrs(X0, solver.Y.squeeze()) < 5e-5
     assert sl.rrs(S, solver.reconstruct().squeeze()) < 1e-4
예제 #2
0
    def xstep(self, S, lmbda, dimK):
        """Solve CSC problem for training data `S`.

        The solver is chosen from the options:
        - ``opt['CUDA_CBPDN']`` set: the CUDA CBPDN solver.
        - otherwise, ``opt['PAR_CBPDN']`` set: ParConvBPDN.
        - otherwise: ConvBPDN.

        Afterwards the following attributes reflect the new solution:
        - ``self.Sf``.
        - the coefficient maps (written to ``self.Z`` directly in the CUDA
          path, via ``self.setcoef`` otherwise).
        - ``self.xstep_itstat``.
        """
        if self.opt['CUDA_CBPDN']:
            # The CUDA solver handles one signal at a time: solve for every
            # slice along the last axis of S and stack the results.
            per_signal = [
                cucbpdn.cbpdn(self.D.squeeze(), S[..., k], lmbda,
                              self.opt['CBPDN'])
                for k in range(S.shape[-1])
            ]
            Zcu = np.stack(per_signal, axis=-2)
            Zcu = Zcu.reshape(self.cri.Nv + (1, self.cri.K, self.cri.M,))
            self.Z[:] = np.asarray(Zcu, dtype=self.dtype)
            self.Zf = sl.rfftn(self.Z, self.cri.Nv, self.cri.axisN)
            self.Sf = sl.rfftn(S.reshape(self.cri.shpS), self.cri.Nv,
                               self.cri.axisN)
            # The CUDA path records no iteration statistics
            self.xstep_itstat = None
            return

        if self.opt['PAR_CBPDN']:
            popt = parcbpdn.ParConvBPDN.Options(dict(self.opt['CBPDN']))
            solver = parcbpdn.ParConvBPDN(self.D.squeeze(), S, lmbda,
                                          opt=popt, dimK=dimK,
                                          dimN=self.cri.dimN)
        else:
            # X update object (external representation is expected!)
            solver = cbpdn.ConvBPDN(self.D.squeeze(), S, lmbda,
                                    self.opt['CBPDN'], dimK=dimK,
                                    dimN=self.cri.dimN)

        # Both CPU solvers expose the same interface from here on
        solver.solve()
        self.Sf = solver.Sf
        self.setcoef(solver.getcoef())
        stats = solver.itstat
        self.xstep_itstat = stats[-1] if len(stats) > 0 else None
예제 #3
0
 def test_05(self):
     # Single-channel dictionary with a 3-d signal: the trailing signal axis
     # must be indexed as the multiple-signal (K) axis, with no channel axis.
     N, Nd, K, M = 16, 5, 2, 4
     D = np.random.randn(Nd, Nd, M)
     s = np.random.randn(N, N, K)
     b = parcbpdn.ParConvBPDN(D, s, 1e-1)
     assert (b.cri.dimC, b.cri.dimK) == (0, 1)
예제 #4
0
 def test_03(self):
     # Multi-channel dictionary (Cd channels) with a Cd-channel signal: one
     # channel axis and no multiple-signal axis are expected.
     N, Nd, Cd, M = 16, 5, 3, 4
     D = np.random.randn(Nd, Nd, Cd, M)
     s = np.random.randn(N, N, Cd)
     b = parcbpdn.ParConvBPDN(D, s, 1e-1)
     assert (b.cri.dimC, b.cri.dimK) == (1, 0)
예제 #5
0
 def test_02(self):
     # Single-channel dictionary with a multi-channel, multi-signal input:
     # both a channel axis and a multiple-signal axis are expected.
     N, Nd, Cs, K, M = 16, 5, 3, 5, 4
     D = np.random.randn(Nd, Nd, M)
     s = np.random.randn(N, N, Cs, K)
     b = parcbpdn.ParConvBPDN(D, s, 1e-1)
     assert (b.cri.dimC, b.cri.dimK) == (1, 1)
예제 #6
0
 def test_08(self):
     # ParConvBPDN must construct and solve without an explicit lmbda or
     # options. Exceptions are deliberately not caught: the test runner then
     # reports the real traceback, whereas an ``except: assert 0`` fallback
     # hides it and becomes a silent pass under ``python -O`` (asserts are
     # stripped).
     N = 16
     Nd = 5
     M = 4
     D = np.random.randn(Nd, Nd, M)
     s = np.random.randn(N, N)
     b = parcbpdn.ParConvBPDN(D, s)
     b.solve()
예제 #7
0
 def test_14(self):
     # A Cd-channel dictionary with a Cd-channel signal must construct and
     # solve. Exceptions propagate so the runner shows the real traceback
     # (an ``except: assert 0`` fallback would silently pass under
     # ``python -O``, which strips asserts).
     N = 16
     Nd = 5
     Cd = 3
     M = 4
     D = np.random.randn(Nd, Nd, Cd, M)
     s = np.random.randn(N, N, Cd)
     lmbda = 1e-1
     b = parcbpdn.ParConvBPDN(D, s, lmbda)
     b.solve()
예제 #8
0
 def test_17(self):
     # Solving and reconstructing with an all-ones mask W must succeed.
     # Exceptions propagate so the runner shows the real traceback (an
     # ``except: assert 0`` fallback would silently pass under
     # ``python -O``, which strips asserts).
     N = 16
     Nd = 5
     M = 4
     D = np.random.randn(Nd, Nd, M)
     s = np.random.randn(N, N)
     w = np.ones(s.shape)
     lmbda = 1e-1
     b = parcbpdn.ParConvBPDN(D, s, lmbda, W=w)
     b.solve()
     b.reconstruct()
예제 #9
0
 def test_19(self):
     # A solver restored from a pickle round-trip must produce exactly the
     # same solution as the object it was pickled from.
     N, Nd, M = 16, 5, 4
     D = np.random.randn(Nd, Nd, M)
     s = np.random.randn(N, N)
     opt = parcbpdn.ParConvBPDN.Options({'Verbose': False, 'MaxMainIter': 10})
     original = parcbpdn.ParConvBPDN(D, s, 1e-1, opt=opt)
     clone = pickle.loads(pickle.dumps(original))
     Xorig = original.solve()
     Xclone = clone.solve()
     assert np.linalg.norm(Xorig - Xclone) == 0.0
예제 #10
0
 def test_06(self):
     # Requesting float32 through the 'DataType' option must carry through
     # to the ADMM variables X, Y and U.
     N, Nd, K, M = 16, 5, 2, 4
     D = np.random.randn(Nd, Nd, M)
     s = np.random.randn(N, N, K)
     dt = np.float32
     opt = parcbpdn.ParConvBPDN.Options({
         'Verbose': False,
         'MaxMainIter': 20,
         'AutoRho': {'Enabled': True},
         'DataType': dt,
     })
     b = parcbpdn.ParConvBPDN(D, s, 1e-1, opt=opt)
     b.solve()
     for arr in (b.X, b.Y, b.U):
         assert arr.dtype == dt
예제 #11
0
# Solve with the serial solver `b` (constructed earlier in the script).
X = b.solve()
"""
Initialise and run parallel CSC solver using ADMM dictionary partition method :cite:`skau-2018-fast`.
"""

# Options for the parallel solver: verbose output, at most 200 iterations,
# fixed penalty (AutoRho disabled), 'AuxVarObj' off, and 'alpha' = 2.5.
opt_par = parcbpdn.ParConvBPDN.Options({
    'Verbose': True,
    'MaxMainIter': 200,
    'RelStopTol': 1e-2,
    'AuxVarObj': False,
    'AutoRho': {
        'Enabled': False
    },
    'alpha': 2.5
})
# NOTE(review): dimK=0 presumably declares that `sh` has no multiple-signal
# axis — confirm against the ParConvBPDN documentation.
b_par = parcbpdn.ParConvBPDN(D, sh, lmbda, opt=opt_par, dimK=0)
X_par = b_par.solve()
"""
Report runtimes of different methods of solving the same problem.
"""

# Both solvers are timed with the 'solve_wo_rsdl' timer label. The printed
# ratio is serial time / parallel time, so a value below 1 means the
# parallel solver was actually slower despite the message wording.
print("ConvBPDN solve time: %.2fs" % b.timer.elapsed('solve_wo_rsdl'))
print("ParConvBPDN solve time: %.2fs" % b_par.timer.elapsed('solve_wo_rsdl'))
print(
    "ParConvBPDN was %.2f times faster than ConvBPDN\n" %
    (b.timer.elapsed('solve_wo_rsdl') / b_par.timer.elapsed('solve_wo_rsdl')))
"""
Reconstruct images from sparse representations.
"""

# NOTE(review): only the serial solution is reconstructed on this line;
# confirm whether `b_par` should be reconstructed as well.
shr = b.reconstruct().squeeze()
예제 #12
0
                            {'Enabled': False, 'StdResiduals': False}})
# Serial reference: masked (decoupled) ConvBPDN solver with mask `mskp`.
b = cbpdn.ConvBPDNMaskDcpl(D, sh, lmbda, mskp, opt=opt)
X = b.solve()


"""
Initialise and run parallel CSC solver using an ADMM dictionary partition :cite:`skau-2018-fast`.
"""

# Options for the parallel solver. rho is chosen as an affine function of
# lmbda (5e1*lmbda + 1e-1) and kept fixed because AutoRho is disabled.
opt_par = parcbpdn.ParConvBPDN.Options({'Verbose': True, 'MaxMainIter': 200,
                            'HighMemSolve': True, 'RelStopTol': 1e-2,
                            'AuxVarObj': False, 'RelaxParam': 1.8,
                            'rho': 5e1*lmbda + 1e-1, 'alpha': 1.5,
                            'AutoRho': {'Enabled': False,
                            'StdResiduals': False}})
# NOTE(review): `mskp` is passed as the 4th positional argument; presumably
# this binds to the mask/weight parameter `W` — confirm against the
# ParConvBPDN constructor signature.
b_par = parcbpdn.ParConvBPDN(D, sh, lmbda, mskp, opt=opt_par)
X_par = b_par.solve()


"""
Report runtimes of different methods of solving the same problem.
"""

# The printed ratio is serial time / parallel time (both from the
# 'solve_wo_rsdl' timer); a value below 1 means the parallel solver was slower.
print("ConvBPDNMaskDcpl solve time: %.2fs" % b.timer.elapsed('solve_wo_rsdl'))
print("ParConvBPDN solve time: %.2fs" % b_par.timer.elapsed('solve_wo_rsdl'))
print("ParConvBPDN was %.2f times faster than ConvBPDNMaskDcpl\n" %
      (b.timer.elapsed('solve_wo_rsdl')/b_par.timer.elapsed('solve_wo_rsdl')))


"""
Reconstruct images from sparse representations.
예제 #13
0
    def solve(self, S, W=None):
        """Solve for given signal S, optionally with mask W.

        Perform one online dictionary update:
        1. Compute the coefficient maps for `S` with the current dictionary.
        2. Update the ``At``/``Bt`` statistics from them.
        3. Refine the dictionary with ``StripeSliceFISTA``.

        Iteration statistics are appended to ``self.itstat`` and ``self.j``
        is incremented.

        Parameters
        ----------
        S : array_like
          Training signal. The caller's array is not modified.
        W : array_like or None, optional (default None)
          Mask for `S`. NOTE(review): only the plain ConvBPDN path (neither
          ``('OCDL', 'CUCBPDN')`` nor ``('OCDL', 'PARCBPDN')`` set) uses
          it; the CUDA and parallel paths ignore `W`.

        Returns
        -------
        D : ndarray
          The updated dictionary, as returned by ``self.getdict()``.
        """
        # Index the problem with two singleton axes inserted after the first
        # two axes of both the dictionary and the signal (dimN=4).
        self.cri = cr.CSC_ConvRepIndexing(self.D.squeeze()[:, :, None, None,
                                                           ...],
                                          S[:, :, None, None, ...],
                                          dimK=None,
                                          dimN=4)

        self.timer.start(['solve', 'solve_wo_eval'])

        # Sparse coding with the current dictionary
        self.timer.start('xstep')
        copt = copy.deepcopy(self.opt['CBPDN'])
        # Final iteration statistics of the sparse coding solver. Stays None
        # for the CUDA solver, or when the solver recorded no iterations.
        xitstat = None
        if self.opt['OCDL', 'CUCBPDN']:
            X = cucbpdn.cbpdn(self.getdict(),
                              S.squeeze(),
                              self.lmbda,
                              opt=copt)
            X = np.asarray(X.reshape(self.cri.shpX), dtype=self.dtype)
        elif self.opt['OCDL', 'PARCBPDN']:
            popt = parcbpdn.ParConvBPDN.Options(dict(self.opt['CBPDN']))
            xstep = parcbpdn.ParConvBPDN(self.getdict(),
                                         S,
                                         self.lmbda,
                                         opt=popt,
                                         nproc=self.opt['OCDL', 'nproc'])
            X = xstep.solve()
            X = np.asarray(X.reshape(self.cri.shpX), dtype=self.dtype)
            xitstat = xstep.itstat[-1] if len(xstep.itstat) > 0 else None
        else:
            if W is None:
                xstep = cbpdn.ConvBPDN(self.getdict(), S, self.lmbda, opt=copt)
                xstep.solve()
                X = np.asarray(xstep.getcoef().reshape(self.cri.shpX),
                               dtype=self.dtype)
            else:
                xstep = cbpdn.AddMaskSim(cbpdn.ConvBPDN,
                                         self.getdict(),
                                         S,
                                         W,
                                         self.lmbda,
                                         opt=copt)
                X = xstep.solve()
                X = np.asarray(X.reshape(self.cri.shpX), dtype=self.dtype)
                # The additive component is removed from the masked signal.
                # Subtract in place on a private copy: this keeps S's dtype,
                # as the in-place operation would, without modifying the
                # caller's array.
                add_cpnt = reconstruct_additive_component(xstep)
                S = np.array(S, copy=True)
                S -= add_cpnt.reshape(S.shape)
            xitstat = xstep.itstat[-1] if len(xstep.itstat) > 0 else None

        self.timer.stop('xstep')

        # Update the At and Bt statistics from the new coefficient maps
        self.timer.start('hessian')
        patches = self.im2slices(S)
        self.update_At(X)
        self.update_Bt(X, patches)
        self.timer.stop('hessian')
        # Recursive update Lmbda <- alpha * Lmbda + 1 (presumably the
        # accumulated weight of past samples under forgetting factor alpha)
        self.Lmbda = self.dtype.type(self.alpha * self.Lmbda + 1)

        # Update the dictionary with FISTA, starting from the current one
        fopt = copy.deepcopy(self.opt['CCMOD'])
        fopt['X0'] = self.D
        if self.opt['OCDL', 'DiminishingTol']:
            # Tighten the FISTA tolerance as 1 / (1 + j) over iterations
            fopt['RelStopTol'] = \
                self.dtype.type(self.opt['CCMOD', 'RelStopTol']/(1.+self.j))
        self.timer.start('dstep')
        dstep = StripeSliceFISTA(self.At, self.Bt, opt=fopt)
        dstep.solve()
        self.timer.stop('dstep')

        # set dictionary
        self.setdict(dstep.getmin())

        # Keep the objective evaluation out of the 'solve_wo_eval' timer
        self.timer.stop('solve_wo_eval')
        evl = self.evaluate(S, X)
        self.timer.start('solve_wo_eval')

        t = self.timer.elapsed(self.opt['IterTimer'])
        # A None sparse-coding entry requires a slightly modified dictlrn
        # iteration-stats object that accepts it.
        itst = self.isc.iterstats(self.j, t, xitstat, dstep.itstat[-1], evl)
        self.itstat.append(itst)

        if self.opt['Verbose']:
            self.isc.printiterstats(itst)

        self.j += 1

        self.timer.stop(['solve', 'solve_wo_eval'])

        return self.getdict()