Esempio n. 1
0
    def xstep(self, S, lmbda, dimK):
        """Run the sparse coding (X) step for training data `S`.

        The CBPDN back end is selected from ``self.opt``: the CUDA solver
        when ``'CUDA_CBPDN'`` is set, the parallel solver when
        ``'PAR_CBPDN'`` is set, and ``cbpdn.ConvBPDN`` otherwise.  On return
        the coefficients and ``self.Sf`` have been refreshed, and
        ``self.xstep_itstat`` holds the solver's final iteration statistics
        (``None`` on the CUDA path, or when the solver recorded none).
        """

        if self.opt['CUDA_CBPDN']:
            # The CUDA solver is invoked once per slice along the last axis
            # of S; the per-slice results are stacked on a new axis at
            # position -2.
            slices = []
            for k in range(S.shape[-1]):
                slices.append(cucbpdn.cbpdn(self.D.squeeze(), S[..., k],
                                            lmbda, self.opt['CBPDN']))
            Z = np.stack(slices, axis=-2)
            zshape = self.cri.Nv + (1, self.cri.K, self.cri.M)
            self.Z[:] = np.asarray(Z.reshape(zshape), dtype=self.dtype)
            # Frequency-domain versions are computed here because no solver
            # object is available to provide them on this path.
            self.Zf = sl.rfftn(self.Z, self.cri.Nv, self.cri.axisN)
            self.Sf = sl.rfftn(S.reshape(self.cri.shpS), self.cri.Nv,
                               self.cri.axisN)
            self.xstep_itstat = None
            return

        # Build the X update object (external representation is expected!)
        if self.opt['PAR_CBPDN']:
            popt = parcbpdn.ParConvBPDN.Options(dict(self.opt['CBPDN']))
            solver = parcbpdn.ParConvBPDN(self.D.squeeze(), S, lmbda,
                                          opt=popt, dimK=dimK,
                                          dimN=self.cri.dimN)
        else:
            solver = cbpdn.ConvBPDN(self.D.squeeze(), S, lmbda,
                                    self.opt['CBPDN'], dimK=dimK,
                                    dimN=self.cri.dimN)

        solver.solve()
        self.Sf = solver.Sf
        self.setcoef(solver.getcoef())
        if len(solver.itstat) > 0:
            self.xstep_itstat = solver.itstat[-1]
        else:
            self.xstep_itstat = None
Esempio n. 2
0
 def test_01(self):
     """Check that the CUDA CBPDN solver agrees with ``cbpdn.ConvBPDN``.

     Both solvers are run on the same random float32 dictionary and signal
     with identical options; the mean squared error between the two
     solutions must be below 1e-10.
     """
     rows, cols = 32, 31
     fltsz, nflt = 5, 4
     D = np.random.randn(fltsz, fltsz, nflt).astype(np.float32)
     s = np.random.randn(rows, cols).astype(np.float32)
     lmbda = 1e-1
     settings = {'Verbose': False, 'MaxMainIter': 50,
                 'AutoRho': {'Enabled': False}}
     opt = cbpdn.ConvBPDN.Options(settings)
     # Reference solution from the CPU solver
     X_ref = cbpdn.ConvBPDN(D, s, lmbda, opt).solve()
     # Solution from the CUDA solver under test
     X_gpu = cucbpdn.cbpdn(D, s, lmbda, opt)
     assert sm.mse(X_ref, X_gpu) < 1e-10
Esempio n. 3
0
def run_cbpdn_gpu(D, sl, sh, lmbda, test_blob=None, opt=None):
    """Run the GPU version of CBPDN.  Only supports grayscale images.

    Each slice of `sh` along its last axis is sparse coded with the CUDA
    solver, reconstructed from the dictionary `D`, and scored.

    Parameters
    ----------
    D : dictionary passed to the solver and to the reconstruction
    sl : lowpass components, indexed along the last axis
    sh : highpass components, indexed along the last axis
    lmbda : weight of the l1 term
    test_blob : reference images for PSNR; defaults to ``sl + sh``
    opt : options forwarded to ``cbpdn.ConvBPDN.Options``

    Returns
    -------
    fnc : sum over all images of ``norm(residual) + lmbda * norm(X, 1)``
    psnr : PSNR averaged over all images

    NOTE(review): the data term is the plain norm of the residual rather
    than half its square, and `fnc` is summed while `psnr` is averaged --
    confirm this is the intended metric.
    """
    assert _cfg.TEST.DATASET.GRAY, 'Only grayscale images are supported'
    assert cucbpdn is not None, 'GPU CBPDN is not supported'
    opt = cbpdn.ConvBPDN.Options(opt)
    if test_blob is None:
        test_blob = sl + sh

    num_img = test_blob.shape[-1]
    fnc = 0.
    psnr = 0.
    for idx in range(num_img):
        highpass = sh[..., idx]
        X = cucbpdn.cbpdn(D, highpass.squeeze(), lmbda, opt=opt)
        # Reconstruct the highpass component by summing over the filter axis
        recon = np.sum(spl.fftconv(D, X), axis=2)
        fidelity = linalg.norm(recon.ravel() - highpass.ravel())
        sparsity = linalg.norm(X.ravel(), 1)
        fnc += fidelity + lmbda * sparsity
        # The full image estimate adds the lowpass component back
        estimate = sl[..., idx] + recon
        psnr += sm.psnr(estimate, test_blob[..., idx].squeeze(), rng=1.)
    psnr /= num_img

    if _cfg.VERBOSE:
        print('.', end='', flush=True)
    return fnc, psnr
Esempio n. 4
0
# Load the dictionary stored under the key 'G:12x12x72' in the collection
# returned by util.convdicts()
D = util.convdicts()['G:12x12x72']


# Set up ConvBPDN options: fixed rho (automatic rho adaptation disabled),
# at most 200 iterations, relative stopping tolerance 5e-3
lmbda = 1e-2
opt = cbpdn.ConvBPDN.Options({'Verbose': True, 'MaxMainIter': 200,
                    'HighMemSolve': True, 'RelStopTol': 5e-3,
                    'AuxVarObj': False, 'AutoRho': {'Enabled': False},
                    'rho': 5.0})


# Time the CUDA cbpdn solve.
# NOTE(review): `sh` (and `sl` below) are defined outside this excerpt --
# presumably the highpass and lowpass components of the test image; the
# first print requires `sh.shape` to have exactly two entries.
t = util.Timer()
with util.ContextTimer(t):
    X = cucbpdn.cbpdn(D, sh, lmbda, opt)
print("Image size: %d x %d" % sh.shape)
print("GPU ConvBPDN solve time: %.2fs" % t.elapsed())


# Reconstruct the image from the sparse representation: convolve the
# dictionary with the coefficient maps, sum over axis 2, then add `sl` back
shr = np.sum(spl.fftconv(D, X), axis=2)
imgr = sl + shr


# Display representation and reconstructed image on a 2x2 grid of subplots;
# the first panel shows `sl`.
fig = plot.figure(figsize=(14, 14))
plot.subplot(2, 2, 1)
plot.imview(sl, fig=fig, title='Lowpass component')
plot.subplot(2, 2, 2)
plot.imview(np.sum(abs(X), axis=2).squeeze(), fig=fig,
Esempio n. 5
0
    def solve(self, S):
        """Perform one online dictionary learning step for signal(s) `S`.

        The step consists of three stages:

        1. sparse coding of `S` against the current dictionary, using the
           CUDA, parallel, or default CBPDN solver as selected by
           ``self.opt['OCDL', 'CUCBPDN']`` / ``self.opt['OCDL', 'PARCBPDN']``;
        2. updating the accumulated statistics via ``update_At`` and
           ``update_Bt``;
        3. a dictionary update with ``SpatialFISTA`` started from ``self.D``.

        Iteration statistics are appended to ``self.itstat`` and the
        iteration counter ``self.j`` is incremented.

        Parameters
        ----------
        S : ndarray
          Training signal(s).  On the CUDA path it is processed one slice at
          a time along its last axis.  It is reshaped to ``self.cri.shpS``
          before the statistics update.

        Returns
        -------
        D : object
          The updated dictionary as returned by ``self.getdict()``.
        """
        self.cri = cr.CSC_ConvRepIndexing(self.getdict(), S,
                                          dimK=self.cri.dimK,
                                          dimN=self.cri.dimN)

        self.timer.start(['solve', 'solve_wo_eval'])

        # Initialize with CBPDN.  `xstep_itstat` stays None when no X-step
        # statistics are available: the CUDA solver exposes none, and a
        # solver that performed no iterations has an empty `itstat`.
        self.timer.start('xstep')
        copt = copy.deepcopy(self.opt['CBPDN'])
        xstep_itstat = None
        if self.opt['OCDL', 'CUCBPDN']:
            # The CUDA solver handles one slice per call; stack the results
            # on a new axis at position -2.
            D = self.getdict()
            X = np.stack([
                cucbpdn.cbpdn(D, S[..., i].squeeze(), self.lmbda, opt=copt)
                for i in range(S.shape[-1])
            ], axis=-2)
        else:
            if self.opt['OCDL', 'PARCBPDN']:
                popt = parcbpdn.ParConvBPDN.Options(dict(self.opt['CBPDN']))
                xstep = parcbpdn.ParConvBPDN(self.getdict(), S, self.lmbda,
                                             opt=popt,
                                             nproc=self.opt['OCDL', 'nproc'])
                X = xstep.solve()
            else:
                xstep = cbpdn.ConvBPDN(self.getdict(), S, self.lmbda,
                                       opt=copt)
                xstep.solve()
                X = xstep.getcoef()
            # Guard against an empty statistics list instead of raising
            # IndexError on itstat[-1].
            if len(xstep.itstat) > 0:
                xstep_itstat = xstep.itstat[-1]
        X = np.asarray(X.reshape(self.cri.shpX), dtype=self.dtype)
        self.timer.stop('xstep')

        S = np.asarray(S.reshape(self.cri.shpS), dtype=self.dtype)

        # update At and Bt
        # (H, W, 1, K, M) -> (H, W, Hc, Wc, 1, K, M)
        self.timer.start('hessian')
        Xe = self.extend_code(X)
        self.update_At(Xe)
        self.update_Bt(Xe, S)
        self.timer.stop('hessian')
        self.Lmbda = self.dtype.type(self.alpha*self.Lmbda+1)

        # update dictionary with FISTA
        fopt = copy.deepcopy(self.opt['CCMOD'])
        fopt['X0'] = self.D
        if self.opt['OCDL', 'DiminishingTol']:
            # The stopping tolerance shrinks as 1/(1+j), bounded below by
            # MinTol (0 when MinTol is not set).
            if self.opt['OCDL', 'MinTol'] is None:
                min_tol = 0.
            else:
                min_tol = self.opt['OCDL', 'MinTol']
            fopt['RelStopTol'] = max(
                self.dtype.type(self.opt['CCMOD', 'RelStopTol']/(1.+self.j)),
                min_tol
            )
        self.timer.start('dstep')
        dstep = SpatialFISTA(self.At, self.Bt, opt=fopt)
        dstep.solve()
        self.timer.stop('dstep')

        # set dictionary
        self.setdict(dstep.getmin())

        # Evaluation time is excluded from the 'solve_wo_eval' timer
        self.timer.stop('solve_wo_eval')
        evl = self.evaluate(S, X)
        self.timer.start('solve_wo_eval')

        t = self.timer.elapsed(self.opt['IterTimer'])
        # Passing None for the X-step statistics requires a slight
        # modification of dictlrn (per the original author's note).
        itst = self.isc.iterstats(self.j, t, xstep_itstat, dstep.itstat[-1],
                                  evl)
        self.itstat.append(itst)

        if self.opt['Verbose']:
            self.isc.printiterstats(itst)

        self.j += 1

        self.timer.stop(['solve', 'solve_wo_eval'])

        return self.getdict()
Esempio n. 6
0
    def solve(self, S, W=None):
        """Solve for given signal S, optionally with mask W.

        Performs one online dictionary learning step:

        1. sparse coding of `S` against the current dictionary, using the
           CUDA, parallel, or default CBPDN solver as selected by
           ``self.opt['OCDL', 'CUCBPDN']`` / ``self.opt['OCDL', 'PARCBPDN']``;
        2. updating the accumulated statistics via ``update_At`` and
           ``update_Bt`` from the codes and the slices of `S`;
        3. a dictionary update with ``StripeSliceFISTA`` started from
           ``self.D``.

        Iteration statistics are appended to ``self.itstat`` and the
        iteration counter ``self.j`` is incremented.

        Parameters
        ----------
        S : ndarray
          Training signal.  It is never modified: when a mask is used, the
          additive component is subtracted from a private copy.
        W : ndarray or None, optional
          Mask for the additive-mask-simulation solver.  It is only used by
          the default solver branch; the CUDA and parallel branches ignore
          it.

        Returns
        -------
        D : object
          The updated dictionary as returned by ``self.getdict()``.
        """
        self.cri = cr.CSC_ConvRepIndexing(self.D.squeeze()[:, :, None, None,
                                                           ...],
                                          S[:, :, None, None, ...],
                                          dimK=None,
                                          dimN=4)

        self.timer.start(['solve', 'solve_wo_eval'])

        # Initialize with CBPDN.  `xstep_itstat` stays None when no X-step
        # statistics are available: the CUDA solver exposes none, and a
        # solver that performed no iterations has an empty `itstat`.
        self.timer.start('xstep')
        copt = copy.deepcopy(self.opt['CBPDN'])
        xstep_itstat = None
        if self.opt['OCDL', 'CUCBPDN']:
            X = cucbpdn.cbpdn(self.getdict(),
                              S.squeeze(),
                              self.lmbda,
                              opt=copt)
        else:
            if self.opt['OCDL', 'PARCBPDN']:
                popt = parcbpdn.ParConvBPDN.Options(dict(self.opt['CBPDN']))
                xstep = parcbpdn.ParConvBPDN(self.getdict(),
                                             S,
                                             self.lmbda,
                                             opt=popt,
                                             nproc=self.opt['OCDL', 'nproc'])
                X = xstep.solve()
            elif W is None:
                xstep = cbpdn.ConvBPDN(self.getdict(), S, self.lmbda, opt=copt)
                xstep.solve()
                X = xstep.getcoef()
            else:
                xstep = cbpdn.AddMaskSim(cbpdn.ConvBPDN,
                                         self.getdict(),
                                         S,
                                         W,
                                         self.lmbda,
                                         opt=copt)
                X = xstep.solve()
                # The additive component is removed from the masked signal.
                # The result is written to a fresh array (same dtype and
                # casting rules as an in-place subtraction) so the caller's
                # array is left untouched and repeated calls with the same
                # data do not subtract the component cumulatively.
                add_cpnt = reconstruct_additive_component(xstep)
                S = np.subtract(S, add_cpnt.reshape(S.shape),
                                out=np.empty_like(S))
            # Guard against an empty statistics list instead of raising
            # IndexError on itstat[-1].
            if len(xstep.itstat) > 0:
                xstep_itstat = xstep.itstat[-1]
        X = np.asarray(X.reshape(self.cri.shpX), dtype=self.dtype)

        self.timer.stop('xstep')

        # update At and Bt
        self.timer.start('hessian')
        patches = self.im2slices(S)
        self.update_At(X)
        self.update_Bt(X, patches)
        self.timer.stop('hessian')
        self.Lmbda = self.dtype.type(self.alpha * self.Lmbda + 1)

        # update dictionary with FISTA
        fopt = copy.deepcopy(self.opt['CCMOD'])
        fopt['X0'] = self.D
        if self.opt['OCDL', 'DiminishingTol']:
            # The stopping tolerance shrinks as 1/(1+j)
            fopt['RelStopTol'] = \
                self.dtype.type(self.opt['CCMOD', 'RelStopTol']/(1.+self.j))
        self.timer.start('dstep')
        dstep = StripeSliceFISTA(self.At, self.Bt, opt=fopt)
        dstep.solve()
        self.timer.stop('dstep')

        # set dictionary
        self.setdict(dstep.getmin())

        # Evaluation time is excluded from the 'solve_wo_eval' timer
        self.timer.stop('solve_wo_eval')
        evl = self.evaluate(S, X)
        self.timer.start('solve_wo_eval')

        t = self.timer.elapsed(self.opt['IterTimer'])
        # Passing None for the X-step statistics requires a slight
        # modification of dictlrn (per the original author's note).
        itst = self.isc.iterstats(self.j, t, xstep_itstat, dstep.itstat[-1],
                                  evl)
        self.itstat.append(itst)

        if self.opt['Verbose']:
            self.isc.printiterstats(itst)

        self.j += 1

        self.timer.stop(['solve', 'solve_wo_eval'])

        return self.getdict()