Exemple #1
0
    def dstep(self, W):
        """Update the dictionary using the training data processed by
        the preceding :meth:`xstep`, with a mask applied to the
        spatial-domain residual.

        Parameters
        ----------
        W : array_like
          Mask array, multiplied elementwise with the spatial-domain
          residual
        """

        cri = self.cri

        # Frequency-domain residual X D - S
        resid_f = sl.inner(self.Zf, self.Df, axis=cri.axisM) - self.Sf
        # Mask the residual in the spatial domain and take it back to
        # the frequency domain; the result is written in place so that
        # the residual array keeps its shape and dtype
        masked = W * sl.irfftn(resid_f, cri.Nv, cri.axisN)
        resid_f[:] = sl.rfftn(masked, None, cri.axisN)

        # Gradient used for the dictionary update
        grad_f = sl.inner(np.conj(self.Zf), resid_f, axis=cri.axisK)
        # A single channel dictionary with a multi-channel signal
        # requires the gradient to be summed over the channel axis
        if cri.Cd == 1 and cri.C > 1:
            grad_f = np.sum(grad_f, axis=cri.axisC, keepdims=True)

        # Step size for this update: eta_a / (j + eta_b)
        self.eta = self.eta_a / (self.j + self.eta_b)

        # Gradient descent step in the frequency domain, followed by
        # its spatial-domain counterpart
        self.Gf[:] = self.Df - self.eta * grad_f
        self.G = sl.irfftn(self.Gf, cri.Nv, cri.axisN)

        # Retain the previous dictionary, then apply the proximal
        # operator to the gradient descent result
        self.Dprv[:] = self.D
        self.D[:] = self.Pcn(self.G)
Exemple #2
0
    def init_vars(self, S, dimK):
        """Initialise the variables required for sparse coding and
        dictionary update for training data `S`.

        The indexing object and the working arrays are rebuilt only
        when no indexing object exists yet or when the shape of the
        leading `dimN` axes of `S` differs from the stored one;
        otherwise only the frequency-domain dictionary is refreshed.

        Parameters
        ----------
        S : ndarray
          Training data array
        dimK : int
          Value passed through to the indexing object constructor
          together with `S`

        Raises
        ------
        ValueError
          If the ``CUDA_CBPDN`` option is set and the problem has more
          than one channel or more than one signal
        """

        shape_n = S.shape[0:self.dimN]
        rebuild = self.cri is None or shape_n != self.cri.Nv

        if not rebuild:
            # Same signal shape as before: refresh the frequency-domain
            # dictionary in the existing array
            self.Df[:] = sl.rfftn(self.D, self.cri.Nv, self.cri.axisN)
            return

        self.cri = cr.CDU_ConvRepIndexing(self.dsz, S, dimK, self.dimN)

        # Restrictions that apply when the CUDA solver is selected
        if self.opt['CUDA_CBPDN']:
            if self.cri.Cd > 1 or self.cri.Cx > 1:
                raise ValueError(
                    'CUDA CBPDN solver can only be used for single '
                    'channel problems')
            if self.cri.K > 1:
                raise ValueError(
                    'CUDA CBPDN solver can not be used with mini-batches')

        # Allocate aligned working arrays: frequency-domain dictionary,
        # gradient descent result of matching shape and dtype, and the
        # coefficient array
        dict_f = sl.rfftn(self.D, self.cri.Nv, self.cri.axisN)
        self.Df = sl.pyfftw_byte_aligned(dict_f)
        self.Gf = sl.pyfftw_empty_aligned(self.Df.shape, self.Df.dtype)
        self.Z = sl.pyfftw_empty_aligned(self.cri.shpX, self.dtype)
Exemple #3
0
    def xstep(self, S, lmbda, dimK):
        """Solve the CSC problem for training data `S`, storing the
        coefficient array and the frequency-domain representations of
        the coefficients and of `S`.

        Parameters
        ----------
        S : ndarray
          Training data array
        lmbda : float
          Parameter passed to the sparse coding solver (presumably the
          regularisation weight)
        dimK : int
          Passed to the non-CUDA solver as its `dimK` argument; unused
          when the ``CUDA_CBPDN`` option is set
        """

        if not self.opt['CUDA_CBPDN']:
            # Construct and run the sparse coding solver; the
            # dictionary is passed with its singleton axes removed
            solver = cbpdn.ConvBPDN(self.D.squeeze(), S, lmbda,
                                    self.opt['CBPDN'], dimK=dimK,
                                    dimN=self.cri.dimN)
            solver.solve()
            self.Sf = solver.Sf
            self.setcoef(solver.getcoef())
            # Record the final iteration statistics, if there are any
            if solver.itstat:
                self.xstep_itstat = solver.itstat[-1]
            else:
                self.xstep_itstat = None
            return

        # CUDA solver: operates on the first slice of the final axis of
        # S, and its result is reshaped to include singleton axes
        # ahead of the filter axis
        coef = cuda.cbpdn(self.D.squeeze(), S[..., 0], lmbda,
                          self.opt['CBPDN'])
        coef = coef.reshape(self.cri.Nv + (1, 1, self.cri.M))
        self.Z[:] = np.asarray(coef, dtype=self.dtype)
        self.Zf = sl.rfftn(self.Z, self.cri.Nv, self.cri.axisN)
        signal = S.reshape(self.cri.shpS)
        self.Sf = sl.rfftn(signal, self.cri.Nv, self.cri.axisN)
        # No iteration statistics are recorded for the CUDA solver
        self.xstep_itstat = None
Exemple #4
0
    def setcoef(self, Z):
        """Store coefficient array `Z` together with its
        frequency-domain representation.

        Parameters
        ----------
        Z : ndarray
          Coefficient array
        """

        cri = self.cri

        if cri.Cd == 1 and cri.C > 1:
            # With a single channel dictionary and a multi-channel
            # input (and hence multi-channel coefficient maps), the
            # channel index and the multiple image index behave in the
            # same way in the dictionary update equation, so the
            # channels are simply folded into the multiple image index.
            folded_shape = cri.Nv + (1, cri.Cx * cri.K, cri.M)
            Z = Z.reshape(folded_shape)
            # Reallocate the internal array if the folded shape does
            # not match it
            if Z.shape != self.Z.shape:
                self.Z = sl.pyfftw_empty_aligned(Z.shape, self.dtype)

        self.Z[:] = np.asarray(Z, dtype=self.dtype)
        self.Zf = sl.rfftn(self.Z, cri.Nv, cri.axisN)