示例#1
0
    def __init__(self, xstep, dstep, opt=None, isc=None):
        """
        Initialise a dictionary learning solver that alternates between
        a coefficient (X) update step and a dictionary (D) update step.

        Parameters
        ----------
        xstep : bpdn (or similar interface) object
          Solver object used for the X (coefficient) update step
        dstep : cmod (or similar interface) object
          Solver object used for the D (dictionary) update step
        opt : :class:`DictLearn.Options` object, optional
          Algorithm options; a default :class:`DictLearn.Options` is
          constructed when None
        isc : :class:`IterStatsConfig` object, optional
          Iteration statistics and header display configuration; when
          None, a default configuration is built that reports, for each
          of the X and D steps, a functional value, the primal and dual
          residuals and the penalty parameter
        """

        self.opt = DictLearn.Options() if opt is None else opt

        if isc is None:
            # Fields of the per-iteration statistics record: the X-step
            # and D-step quantities, bracketed by the iteration count and
            # the cumulative run time
            fields = ['Iter', 'ObjFunX', 'XPrRsdl', 'XDlRsdl', 'XRho',
                      'ObjFunD', 'DPrRsdl', 'DDlRsdl', 'DRho', 'Time']
            # Record field -> name of the statistic reported by the
            # X-step and D-step solver objects respectively
            xmap = dict(ObjFunX='ObjFun', XPrRsdl='PrimalRsdl',
                        XDlRsdl='DualRsdl', XRho='Rho')
            dmap = dict(ObjFunD='DFid', DPrRsdl='PrimalRsdl',
                        DDlRsdl='DualRsdl', DRho='Rho')
            # Display column titles, paired one-to-one (in order) with
            # every record field except the trailing 'Time'
            titles = ['Itn', 'FncX', 'r_X', 's_X', u('ρ_X'),
                      'FncD', 'r_D', 's_D', u('ρ_D')]
            isc = IterStatsConfig(isfld=fields, isxmap=xmap, isdmap=dmap,
                                  evlmap={}, hdrtxt=titles,
                                  hdrmap=dict(zip(titles, fields[:-1])))
        self.isc = isc

        # Solver objects for the two alternating sub-problems
        self.xstep = xstep
        self.dstep = dstep

        # History of iteration statistics and the iteration counter
        self.itstat = []
        self.j = 0
示例#2
0
    def hdrval(cls):
        """Return a dict mapping each status display column title to the
        name of the ``IterationStats`` field shown in that column.

        The iteration column comes first, then the objective function
        columns taken from the class attribute ``hdrval_objfun``, and
        finally the primal residual, dual residual and penalty parameter
        columns.
        """

        columns = dict(Itn='Iter')
        columns.update(cls.hdrval_objfun)
        for title, field in (('r', 'PrimalRsdl'), ('s', 'DualRsdl'),
                             (u('ρ'), 'Rho')):
            columns[title] = field
        return columns
示例#3
0
def hdrtxt(xmethod, dmethod, opt):
    """Return ``hdrtxt`` argument for ``.IterStatsConfig`` initialiser.

    The list always starts with the columns ``Itn``, ``Fnc``, ``DFid``,
    ``ℓ1`` and ``Cnstr``. These are followed by the X step columns and
    then the D step columns.

    Parameters
    ----------
    xmethod : str
      X step method. ``'admm'`` adds the ``r_X``, ``s_X`` and ``ρ_X``
      columns. Any other value adds ``L_X``, preceded by ``F_X``,
      ``Q_X`` and ``It_X`` when ``opt['CBPDN', 'BackTrack', 'Enabled']``
      is true.
    dmethod : str
      D step method. ``'fista'`` adds ``L_D``, preceded by ``F_D``,
      ``Q_D`` and ``It_D`` when ``opt['CCMOD', 'BackTrack', 'Enabled']``
      is true. Any other value adds ``r_D``, ``s_D`` and ``ρ_D``.
    opt : options object
      Options indexed with tuple keys; only the two backtracking flags
      above are read, and only in the branches that need them

    Returns
    -------
    txt : list
      Display column titles
    """

    titles = ['Itn', 'Fnc', 'DFid', u('ℓ1'), 'Cnstr']

    # Columns contributed by the X step
    if xmethod == 'admm':
        titles += ['r_X', 's_X', u('ρ_X')]
    elif opt['CBPDN', 'BackTrack', 'Enabled']:
        titles += ['F_X', 'Q_X', 'It_X', 'L_X']
    else:
        titles += ['L_X']

    # Columns contributed by the D step
    if dmethod == 'fista':
        if opt['CCMOD', 'BackTrack', 'Enabled']:
            titles += ['F_D', 'Q_D', 'It_D', 'L_D']
        else:
            titles += ['L_D']
    else:
        titles += ['r_D', 's_D', u('ρ_D')]

    return titles
示例#4
0
def hdrmap(xmethod, dmethod, opt):
    """Return ``hdrmap`` argument for ``.IterStatsConfig`` initialiser.

    The returned dict maps display column titles to ``IterationStats``
    field names. Its keys are the same titles that the module-level
    ``hdrtxt`` function produces for the same arguments.

    Parameters
    ----------
    xmethod : str
      X step method. ``'admm'`` maps the residual and penalty columns
      (``XPrRsdl``, ``XDlRsdl``, ``XRho``). Any other value maps the
      ``X_L`` column, plus the backtracking columns when
      ``opt['CBPDN', 'BackTrack', 'Enabled']`` is true.
    dmethod : str
      D step method. ``'fista'`` maps the ``D_L`` column, plus the
      backtracking columns when ``opt['CCMOD', 'BackTrack', 'Enabled']``
      is true. Any other value maps ``DPrRsdl``, ``DDlRsdl`` and
      ``DRho``.
    opt : options object
      Options indexed with tuple keys; only the two backtracking flags
      above are read, and only in the branches that need them

    Returns
    -------
    hdr : dict
      Mapping from display column title to statistics field name
    """

    columns = dict([
        ('Itn', 'Iter'),
        ('Fnc', 'ObjFun'),
        ('DFid', 'DFid'),
        (u('ℓ1'), 'RegL1'),
        ('Cnstr', 'Cnstr'),
    ])

    # Columns contributed by the X step
    if xmethod == 'admm':
        xcols = {'r_X': 'XPrRsdl', 's_X': 'XDlRsdl', u('ρ_X'): 'XRho'}
    elif opt['CBPDN', 'BackTrack', 'Enabled']:
        xcols = {'F_X': 'X_F_Btrack', 'Q_X': 'X_Q_Btrack',
                 'It_X': 'X_ItBt', 'L_X': 'X_L'}
    else:
        xcols = {'L_X': 'X_L'}
    columns.update(xcols)

    # Columns contributed by the D step
    if dmethod == 'fista':
        if opt['CCMOD', 'BackTrack', 'Enabled']:
            dcols = {'F_D': 'D_F_Btrack', 'Q_D': 'D_Q_Btrack',
                     'It_D': 'D_ItBt', 'L_D': 'D_L'}
        else:
            dcols = {'L_D': 'D_L'}
    else:
        dcols = {'r_D': 'DPrRsdl', 's_D': 'DDlRsdl', u('ρ_D'): 'DRho'}
    columns.update(dcols)

    return columns
示例#5
0
class ConvBPDN(fista.FISTADFT):
    r"""
    Base class for FISTA algorithm for the Convolutional BPDN (CBPDN)
    :cite:`garcia-2018-convolutional1` problem.

    |

    .. inheritance-diagram:: ConvBPDN
       :parts: 2

    |

    The generic problem form is

    .. math::
       \mathrm{argmin}_\mathbf{x} \;
        f( \{ \mathbf{x}_m \} ) + \lambda g( \{ \mathbf{x}_m \} )

    where :math:`f = (1/2) \left\| \sum_m \mathbf{d}_m * \mathbf{x}_m -
    \mathbf{s} \right\|_2^2`, and :math:`g(\cdot)` is a penalty
    term or the indicator function of a constraint; with input
    image :math:`\mathbf{s}`, dictionary filters :math:`\mathbf{d}_m`,
    and coefficient maps :math:`\mathbf{x}_m`. It is solved via the
    FISTA formulation, which alternates the following two steps.

    Proximal step

    .. math::
       \mathbf{x}_k = \mathrm{prox}_{t_k}(g) (\mathbf{y}_k - 1/L \nabla
       f(\mathbf{y}_k) ) \;\;.

    Combination step

    .. math::
       \mathbf{y}_{k+1} = \mathbf{x}_k + \left( \frac{t_k - 1}{t_{k+1}}
       \right) (\mathbf{x}_k - \mathbf{x}_{k-1}) \;\;,

    with :math:`t_{k+1} = \frac{1 + \sqrt{1 + 4 t_k^2}}{2}`.


    After termination of the :meth:`solve` method, attribute
    :attr:`itstat` is a list of tuples representing statistics of each
    iteration. The fields of the named tuple ``IterationStats`` are:

       ``Iter`` : Iteration number

       ``ObjFun`` : Objective function value

       ``DFid`` : Value of data fidelity term :math:`(1/2) \| \sum_m
       \mathbf{d}_m * \mathbf{x}_m - \mathbf{s} \|_2^2`

       ``RegL1`` : Value of regularisation term :math:`\sum_m \|
       \mathbf{x}_m \|_1`

       ``Rsdl`` : Residual

       ``L`` : Inverse of gradient step parameter

       ``Time`` : Cumulative run time
    """
    class Options(fista.FISTADFT.Options):
        r"""ConvBPDN algorithm options

        Options include all of those defined in
        :class:`.fista.FISTADFT.Options`, together with
        additional options:

          ``NonNegCoef`` : Flag indicating whether to force solution to
          be non-negative.

          ``NoBndryCross`` : Flag indicating whether all solution
          coefficients corresponding to filters crossing the image
          boundary should be forced to zero.

          ``L1Weight`` : An array of weights for the :math:`\ell_1`
          norm. The array shape must be such that the array is
          compatible for multiplication with the X/Y variables. If this
          option is defined, the regularization term is :math:`\lambda
          \sum_m \| \mathbf{w}_m \odot \mathbf{x}_m \|_1` where
          :math:`\mathbf{w}_m` denotes slices of the weighting array on
          the filter index axis.

        The ``NonNegCoef`` and ``NoBndryCross`` constraints are enforced
        in :meth:`ConvBPDN.eval_proxop`.
        """

        defaults = copy.deepcopy(fista.FISTADFT.Options.defaults)
        defaults.update({'NonNegCoef': False, 'NoBndryCross': False})
        defaults.update({'L1Weight': 1.0})
        defaults.update({'L': 500.0})

        def __init__(self, opt=None):
            """
            Parameters
            ----------
            opt : dict or None, optional (default None)
              ConvBPDN algorithm options
            """

            if opt is None:
                opt = {}
            fista.FISTADFT.Options.__init__(self, opt)

        def __setitem__(self, key, value):
            """Set options."""

            fista.FISTADFT.Options.__setitem__(self, key, value)

    itstat_fields_objfn = ('ObjFun', 'DFid', 'RegL1')
    hdrtxt_objfn = ('Fnc', 'DFid', u('Regℓ1'))
    hdrval_objfun = {'Fnc': 'ObjFun', 'DFid': 'DFid', u('Regℓ1'): 'RegL1'}

    def __init__(self, D, S, lmbda=None, opt=None, dimK=None, dimN=2):
        """
        This class supports an arbitrary number of spatial dimensions,
        `dimN`, with a default of 2. The input dictionary `D` is either
        `dimN` + 1 dimensional, in which case each spatial component
        (image in the default case) is assumed to consist of a single
        channel, or `dimN` + 2 dimensional, in which case the final
        dimension is assumed to contain the channels (e.g. colour
        channels in the case of images). The input signal set `S` is
        either `dimN` dimensional (no channels, only one signal),
        `dimN` + 1 dimensional (either multiple channels or multiple
        signals), or `dimN` + 2 dimensional (multiple channels and
        multiple signals). Determination of problem dimensions is
        handled by :class:`.cnvrep.CSC_ConvRepIndexing`.


        |

        **Call graph**

        .. image:: ../_static/jonga/fista_cbpdn_init.svg
           :width: 20%
           :target: ../_static/jonga/fista_cbpdn_init.svg

        |


        Parameters
        ----------
        D : array_like
          Dictionary array
        S : array_like
          Signal array
        lmbda : float
          Regularisation parameter. If None, it defaults to 0.1 times
          the largest magnitude of the DFT-domain product of the
          conjugated dictionary and the signal.
        opt : :class:`ConvBPDN.Options` object
          Algorithm options
        dimK : 0, 1, or None, optional (default None)
          Number of dimensions in input signal corresponding to multiple
          independent signals
        dimN : int, optional (default 2)
          Number of spatial/temporal dimensions
        """

        # Set default options if none specified
        if opt is None:
            opt = ConvBPDN.Options()

        # Infer problem dimensions and set relevant attributes of self.
        # A derived class may already have set self.cri before calling
        # this initialiser, in which case it is kept.
        if not hasattr(self, 'cri'):
            self.cri = cr.CSC_ConvRepIndexing(D, S, dimK=dimK, dimN=dimN)

        # Set dtype attribute based on S.dtype and opt['DataType']
        self.set_dtype(opt, S.dtype)

        # Set default lambda value if not specified
        if lmbda is None:
            cri = cr.CSC_ConvRepIndexing(D, S, dimK=dimK, dimN=dimN)
            Df = sl.rfftn(D.reshape(cri.shpD), cri.Nv, axes=cri.axisN)
            Sf = sl.rfftn(S.reshape(cri.shpS), axes=cri.axisN)
            b = np.conj(Df) * Sf
            lmbda = 0.1 * abs(b).max()

        # Set l1 term scaling and weight array
        self.lmbda = self.dtype.type(lmbda)
        self.wl1 = np.asarray(opt['L1Weight'], dtype=self.dtype)

        # Call parent class __init__
        self.Xf = None
        xshape = self.cri.shpX
        super(ConvBPDN, self).__init__(xshape, S.dtype, opt)

        # Reshape D and S to standard layout
        self.D = np.asarray(D.reshape(self.cri.shpD), dtype=self.dtype)
        self.S = np.asarray(S.reshape(self.cri.shpS), dtype=self.dtype)

        # Compute signal in DFT domain
        self.Sf = sl.rfftn(self.S, None, self.cri.axisN)

        # Create byte aligned arrays for FFT calls
        self.Y = self.X.copy()
        self.X = sl.pyfftw_empty_aligned(self.Y.shape, dtype=self.dtype)
        self.X[:] = self.Y

        # Initialise auxiliary variable Vf: Create byte aligned arrays
        # for FFT calls
        self.Vf = sl.pyfftw_rfftn_empty_aligned(self.X.shape, self.cri.axisN,
                                                self.dtype)

        self.Xf = sl.rfftn(self.X, None, self.cri.axisN)
        self.Yf = self.Xf.copy()
        self.store_prev()
        # Offset the stored previous iterate so that the initial
        # fixed point residual (see rsdl) is not zero
        self.Yfprv = self.Yf.copy() + 1e5

        self.setdict()

        # Initialization needed for back tracking (if selected)
        self.postinitialization_backtracking_DFT()

    def setdict(self, D=None):
        """Set dictionary array and update its DFT, :attr:`Df`."""

        if D is not None:
            self.D = np.asarray(D, dtype=self.dtype)
        self.Df = sl.rfftn(self.D, self.cri.Nv, self.cri.axisN)

    def getcoef(self):
        """Get final coefficient array."""

        return self.X

    def eval_grad(self):
        """Compute gradient in Fourier domain."""

        # Compute D X - S
        Ryf = self.eval_Rf(self.Yf)
        # Compute D^H Ryf
        gradf = np.conj(self.Df) * Ryf

        # Multiple channel signal, multiple channel dictionary
        if self.cri.Cd > 1:
            gradf = np.sum(gradf, axis=self.cri.axisC, keepdims=True)

        return gradf

    def eval_Rf(self, Vf):
        """Evaluate smooth term in Vf."""

        return sl.inner(self.Df, Vf, axis=self.cri.axisM) - self.Sf

    def eval_proxop(self, V):
        r"""Compute proximal operator of :math:`g`.

        Apply soft thresholding with threshold :math:`(\lambda / L)
        \mathbf{w}`. Then, if selected with the ``NonNegCoef`` and
        ``NoBndryCross`` options, project onto the corresponding
        constraint sets. Both constraints are separable, so applying them
        after the shrinkage gives the exact proximal operator of the
        combined penalty.
        """

        X = sl.shrink1(V, (self.lmbda / self.L) * self.wl1)
        if self.opt['NonNegCoef']:
            X[X < 0.0] = 0.0
        if self.opt['NoBndryCross']:
            # Zero the trailing D.shape[n] - 1 coefficients on each
            # spatial axis n; these correspond to filters that would
            # cross the signal boundary. Spatial axes come first in the
            # standard layout of X and D.
            for n in range(self.cri.dimN):
                nd = self.D.shape[n]
                # A filter of size 1 cannot cross the boundary, and
                # slice(0, None) would wrongly zero the whole axis
                if nd > 1:
                    X[(slice(None),) * n + (slice(1 - nd, None),)] = 0.0
        return X

    def rsdl(self):
        """Compute fixed point residual in Fourier domain."""

        diff = self.Xf - self.Yfprv
        return sl.rfl2norm2(diff, self.X.shape, axis=self.cri.axisN)

    def eval_objfn(self):
        """Compute components of objective function as well as total
        contribution to objective function.

        Returns a tuple ``(ObjFun, DFid, RegL1)``.
        """

        dfd = self.obfn_dfd()
        reg = self.obfn_reg()
        obj = dfd + reg[0]
        return (obj, dfd) + reg[1:]

    def obfn_dfd(self):
        r"""Compute data fidelity term :math:`(1/2) \| \sum_m
        \mathbf{d}_m * \mathbf{x}_m - \mathbf{s} \|_2^2`.
        This function takes into account the unnormalised DFT scaling,
        i.e. given that the variables are the DFT of multi-dimensional
        arrays computed via :func:`rfftn`, this returns the data fidelity
        term in the original (spatial) domain.
        """

        Ef = self.eval_Rf(self.Xf)
        return sl.rfl2norm2(Ef, self.S.shape, axis=self.cri.axisN) / 2.0

    def obfn_reg(self):
        """Compute regularisation term and contribution to objective
        function.

        Returns a tuple ``(lmbda * rl1, rl1)``, where ``rl1`` is the
        weighted :math:`\\ell_1` norm of :attr:`X`.
        """

        rl1 = np.linalg.norm((self.wl1 * self.X).ravel(), 1)
        return (self.lmbda * rl1, rl1)

    def obfn_f(self, Xf=None):
        r"""Compute data fidelity term :math:`(1/2) \| \sum_m
        \mathbf{d}_m * \mathbf{x}_m - \mathbf{s} \|_2^2`.
        This is used for backtracking. Since the backtracking is
        computed in the DFT, it is important to preserve the
        DFT scaling.
        """

        if Xf is None:
            Xf = self.Xf

        Rf = self.eval_Rf(Xf)
        return 0.5 * np.linalg.norm(Rf.flatten(), 2)**2

    def reconstruct(self, X=None):
        """Reconstruct representation."""

        if X is None:
            X = self.X
        Xf = sl.rfftn(X, None, self.cri.axisN)
        Sf = np.sum(self.Df * Xf, axis=self.cri.axisM)
        return sl.irfftn(Sf, self.cri.Nv, self.cri.axisN)
示例#6
0
    def hdrtxt(cls):
        """Return the tuple of status display column titles.

        The iteration column comes first, then the objective function
        columns from the class attribute ``hdrtxt_objfn``, and finally
        the primal residual, dual residual and penalty parameter columns.
        """

        leading = ('Itn', )
        trailing = ('r', 's', u('ρ'))
        return leading + cls.hdrtxt_objfn + trailing
示例#7
0
    def solve(self):
        """Start (or re-start) optimisation. This method implements the
        framework for the alternation between `X` and `D` updates in a
        dictionary learning algorithm.

        If options ``Verbose`` and ``StatusHeader`` are both ``True``,
        the status header and separator lines are printed. Per-iteration
        status rows are not printed in this variant. Instead, a
        ``[Wohlberg:PROGRESS]`` line is printed at every iteration. At
        termination of this method, attribute :attr:`itstat` is a list
        of tuples representing statistics of each iteration.

        When attribute :attr:`nproc` is positive, a
        :class:`multiprocessing.Pool` with that many processes is stored
        as attribute :attr:`pool` for the duration of the call. It is
        closed and joined on normal completion. If an exception escapes
        the iteration loop, it is terminated and joined before the
        exception is re-raised.

        Attribute :attr:`timer` is an instance of :class:`.util.Timer`
        that provides the following labelled timers:

          ``init``: Time taken for object initialisation by
          :meth:`__init__`

          ``solve``: Total time taken by call(s) to :meth:`solve`

          ``solve_wo_eval``: Total time taken by call(s) to
          :meth:`solve`, excluding time taken to compute functional
          value and related iteration statistics

        The benchmark objective values are computed with all timers
        stopped, so they do not count towards these times.

        Returns
        -------
        D
          Final dictionary, as returned by :meth:`getdict`
        pobjs : list
          Values returned by ``compute_X_and_objective``: the value
          before the first iteration, followed by one value per
          completed iteration
        """

        # Construct the status display header. Only the header and
        # separator are printed in this variant, so the row format string
        # returned by solve_status_str is not needed.
        hdrtxt = ['Itn', 'Fnc', 'DFid', u('Regℓ1')]
        hdrstr, _, nsep = common.solve_status_str(
            hdrtxt, fwdth0=type(self).fwiter, fprec=type(self).fpothr)

        # Print header and separator strings
        if self.opt['Verbose']:
            if self.opt['StatusHeader']:
                print(hdrstr)
                print("-" * nsep)

        pobjs = []
        # Signal reordered to (n_trials, n_channels, *sig_support), with a
        # singleton leading trial axis; the squeezed xstep.S must be 3-D
        X = np.transpose(self.xstep.S.squeeze(), (2, 1, 0))[None]
        sig_support = X.shape[2:]

        # Dictionary reordered to (n_atoms, n_channels, *atom_support)
        d_hat = np.transpose(self.getdict().squeeze(), (3, 2, 1, 0))
        atom_support = d_hat.shape[2:]
        # Index that adds a leading trial axis to the coefficient maps and
        # crops each trailing spatial axis to size_ax - size_atom_ax + 1
        z_slice = tuple([None, Ellipsis] + [
            slice(size_ax - size_atom_ax + 1)
            for size_ax, size_atom_ax in zip(sig_support, atom_support)
        ])
        Z_hat = self.getcoef().squeeze().swapaxes(0, 2)[z_slice]
        pobjs.append(
            compute_X_and_objective(X, Z_hat, d_hat, reg=self.xstep.lmbda))

        # Reset timer
        self.timer.start(['solve', 'solve_wo_eval'])

        # Create process pool
        if self.nproc > 0:
            self.pool = mp.Pool(processes=self.nproc)

        try:
            for self.j in range(self.j, self.j + self.opt['MaxMainIter']):

                # Perform a set of update steps
                self.step()

                # Evaluate functional
                self.timer.stop('solve_wo_eval')
                fnev = self.evaluate()
                self.timer.start('solve_wo_eval')

                # Record iteration stats
                tk = self.timer.elapsed('solve')
                itst = self.IterationStats(*((self.j, ) + fnev + (tk, )))
                self.itstat.append(itst)

                # Compute the benchmark objective with all timers stopped
                # so that it does not count towards the reported run time
                self.timer.stop(['solve', 'solve_wo_eval'])
                d_hat = np.transpose(self.getdict().squeeze(), (3, 2, 1, 0))
                Z_hat = self.getcoef().squeeze().swapaxes(0, 2)[z_slice]
                pobjs.append(
                    compute_X_and_objective(X, Z_hat, d_hat,
                                            reg=self.xstep.lmbda))
                tk = self.timer.elapsed('solve')
                print("[Wohlberg:PROGRESS] Iteration {} - {:.3e} ({:.0f}s)"
                      .format(self.j, pobjs[-1], tk))
                self.timer.start(['solve', 'solve_wo_eval'])

                # Call callback function if defined
                if self.opt['Callback'] is not None:
                    if self.opt['Callback'](self):
                        break
        except BaseException:
            # Do not leave worker processes behind when an update step,
            # the evaluation or the callback raises: abandon outstanding
            # work, reap the workers, then propagate the exception
            if self.nproc > 0:
                self.pool.terminate()
                self.pool.join()
            raise

        # Clean up process pool
        if self.nproc > 0:
            self.pool.close()
            self.pool.join()

        # Increment iteration count
        self.j += 1

        # Record solve time
        self.timer.stop(['solve', 'solve_wo_eval'])

        # Print final separator string if Verbose option enabled
        if self.opt['Verbose'] and self.opt['StatusHeader']:
            print("-" * nsep)

        # Return final dictionary
        return self.getdict(), pobjs
示例#8
0
    def solve(self):
        """Start (or re-start) optimisation. This method implements the
        framework for the alternation between `X` and `D` updates in a
        dictionary learning algorithm.

        If option ``Verbose`` is ``True``, the progress of the
        optimisation is displayed at every iteration. At termination
        of this method, attribute :attr:`itstat` is a list of tuples
        representing statistics of each iteration.

        When attribute :attr:`nproc` is positive, a
        :class:`multiprocessing.Pool` with that many processes is stored
        as attribute :attr:`pool` for the duration of the call. It is
        closed and joined on normal completion. If an exception escapes
        the iteration loop, it is terminated and joined before the
        exception is re-raised.

        Attribute :attr:`timer` is an instance of :class:`.util.Timer`
        that provides the following labelled timers:

          ``init``: Time taken for object initialisation by
          :meth:`__init__`

          ``solve``: Total time taken by call(s) to :meth:`solve`

          ``solve_wo_eval``: Total time taken by call(s) to
          :meth:`solve`, excluding time taken to compute functional
          value and related iteration statistics

        Returns the final dictionary, as given by :meth:`getdict`.
        """

        # Construct tuple of status display column titles and set status
        # display strings
        hdrtxt = ['Itn', 'Fnc', 'DFid', u('Regℓ1')]
        hdrstr, fmtstr, nsep = common.solve_status_str(
            hdrtxt, fwdth0=type(self).fwiter, fprec=type(self).fpothr)

        # Print header and separator strings
        if self.opt['Verbose'] and self.opt['StatusHeader']:
            print(hdrstr)
            print("-" * nsep)

        # Reset timer
        self.timer.start(['solve', 'solve_wo_eval'])

        # Create process pool
        if self.nproc > 0:
            self.pool = mp.Pool(processes=self.nproc)

        try:
            for self.j in range(self.j, self.j + self.opt['MaxMainIter']):

                # Perform a set of update steps
                self.step()

                # Evaluate functional
                self.timer.stop('solve_wo_eval')
                fnev = self.evaluate()
                self.timer.start('solve_wo_eval')

                # Record iteration stats
                tk = self.timer.elapsed('solve')
                itst = self.IterationStats(*((self.j, ) + fnev + (tk, )))
                self.itstat.append(itst)

                # Display iteration stats if Verbose option enabled
                if self.opt['Verbose']:
                    print(fmtstr % itst[:-1])

                # Call callback function if defined
                if self.opt['Callback'] is not None:
                    if self.opt['Callback'](self):
                        break
        except BaseException:
            # Do not leave worker processes behind when an update step,
            # the evaluation or the callback raises: abandon outstanding
            # work, reap the workers, then propagate the exception
            if self.nproc > 0:
                self.pool.terminate()
                self.pool.join()
            raise

        # Clean up process pool
        if self.nproc > 0:
            self.pool.close()
            self.pool.join()

        # Increment iteration count
        self.j += 1

        # Record solve time
        self.timer.stop(['solve', 'solve_wo_eval'])

        # Print final separator string if Verbose option enabled
        if self.opt['Verbose'] and self.opt['StatusHeader']:
            print("-" * nsep)

        # Return final dictionary
        return self.getdict()
示例#9
0
class ParConvBPDN(GenericConvBPDN):
    r"""
    Parallel ADMM algorithm for Convolutional BPDN (CBPDN) with or
    without a spatial mask :cite:`skau-2018-fast`.

    |

    .. inheritance-diagram:: ParConvBPDN
       :parts: 2

    |

    Solve the optimisation problem

    .. math::
       \mathrm{argmin}_\mathbf{x} \;
       (1/2) \left\|  W \left(\sum_m \mathbf{d}_m * \mathbf{x}_m -
       \mathbf{s}\right) \right\|_2^2 + \lambda \sum_m
       \| \mathbf{x}_m \|_1 \;\;,

    where :math:`W` is a mask array, via the ADMM problem

    .. math::
       \mathrm{argmin}_{\mathbf{x},\mathbf{y}_0,\mathbf{y}_1} \;
       (1/2) \| W \left( \sum_l \mathbf{y}_{0,l} - \mathbf{s} \right)
       \|_2^2 + \lambda \| \mathbf{y}_1 \|_1 \;\text{such that}\;
       \left( \begin{array}{c} D_{G_0} \\ \vdots \\ D_{G_{L-1}} \\
       \alpha I \end{array} \right) \mathbf{x} - \left( \begin{array}{c}
       \mathbf{y}_{0,0} \\ \vdots \\ \mathbf{y}_{0,L-1} \\ \alpha
       \mathbf{y}_1 \end{array} \right) = \left( \begin{array}{c}
       \mathbf{0} \\ \vdots \\ \mathbf{0} \\ \mathbf{0} \end{array}
       \right) \;\;,

    where the :math:`M` dictionary filters are partitioned into
    :math:`L` groups, :math:`\{G_l\}_{l \in \{0,\dots,L-1\}}` where

    .. math::
       G_i \cap G_j = \emptyset \text{ for } i \neq j \text{
       and } \bigcup_l G_l = \{0, \dots, M-1\} \;,

    and :math:`D_{G_l}` is a linear operator such that :math:`D_{G_l}
    \mathbf{x} = \sum_{g \in G_l} \mathbf{d}_g * \mathbf{x}_g`.

    Multi-image and multi-channel problems are also supported. The
    multi-image problem is

    .. math::
       \mathrm{argmin}_\mathbf{x} \;
       (1/2) \sum_k \left\| W_k \left( \sum_m \mathbf{d}_m *
       \mathbf{x}_{k,m} - \mathbf{s}_k \right) \right\|_2^2 + \lambda
       \sum_k \sum_m \| \mathbf{x}_{k,m} \|_1

    with input images :math:`\mathbf{s}_k`, masks :math:`W_k`, and
    coefficient maps :math:`\mathbf{x}_{k,m}`. The multi-channel
    problem with input image channels :math:`\mathbf{s}_c` and a
    multi-channel mask :math:`W_c` is either

    .. math::
       \mathrm{argmin}_\mathbf{x} \;
       (1/2) \sum_c \left\| W_c \left( \sum_m \mathbf{d}_m *
       \mathbf{x}_{c,m} - \mathbf{s}_c \right) \right\|_2^2 +
       \lambda \sum_c \sum_m \| \mathbf{x}_{c,m} \|_1

    with single-channel dictionary filters :math:`\mathbf{d}_m` and
    multi-channel coefficient maps :math:`\mathbf{x}_{c,m}`, or

    .. math::
       \mathrm{argmin}_\mathbf{x} \;
       (1/2) \sum_c \left\| W_c \left( \sum_m \mathbf{d}_{c,m} *
       \mathbf{x}_m - \mathbf{s}_c \right) \right\|_2^2 + \lambda
       \sum_m \| \mathbf{x}_m \|_1

    with multi-channel dictionary filters :math:`\mathbf{d}_{c,m}` and
    single-channel coefficient maps :math:`\mathbf{x}_m`.

    After termination of the :meth:`solve` method, AttributeError
    :attr:`itstat` is a list of tuples representing statistics of each
    iteration. The fields of the named tuple ``IterationStats`` are:

       ``Iter`` : Iteration number

       ``ObjFun`` : Objective function value

       ``DFid`` : Value of data fidelity term :math:`(1/2) \| W \left(
       \sum_m \mathbf{d}_m * \mathbf{x}_m - \mathbf{s} \right) \|_2^2`

       ``RegL1`` : Value of regularisation term :math:`\sum_m \|
       \mathbf{x}_m \|_1`

       ``PrimalRsdl`` : Norm of primal residual

       ``DualRsdl`` : Norm of dual residual

       ``EpsPrimal`` : Primal residual stopping tolerance
       :math:`\epsilon_{\mathrm{pri}}`

       ``EpsDual`` : Dual residual stopping tolerance
       :math:`\epsilon_{\mathrm{dua}}`

       ``Rho`` : Penalty parameter

       ``XSlvRelRes`` : Not Implemented (relative residual of X step solver)

       ``Time`` : Cumulative run time
    """
    class Options(GenericConvBPDN.Options):
        r"""ParConvBPDN algorithm options

        Options include all of those defined in
        :class:`.admm.ADMMEqual.Options`, together with additional options:

          ``alpha`` : A float indicating the relative weight between
          the constraint :math:`D_{G_l} \mathbf{x} = \mathbf{y}_{0,l}`
          and :math:`\alpha \mathbf{x} = \mathbf{y}_1`. None value
          effectively defaults to no weight or :math:`\alpha = 1`.

          ``Y0`` : Initial value for :math:`\mathbf{y}_0`.

          ``U0`` : Initial value for :math:`\mathbf{u}_0`.

          ``Y1`` : Initial value for :math:`\mathbf{y}_1`.

          ``U1`` : Initial value for :math:`\mathbf{u}_1`.


        and the exceptions:

          ``AutoRho`` : Not implemented.

          ``LinSolveCheck`` : Not implemented.

        """
        defaults = copy.deepcopy(GenericConvBPDN.Options.defaults)
        defaults.update({
            'L1Weight': 1.0,
            'alpha': None,
            'Y1': None,
            'U1': None
        })

        def __init__(self, opt=None):
            """
            Parameters
            ----------
            opt : dict or None, optional (default None)
               ParConvBPDN algorithm options
            """

            if opt is None:
                opt = {}
            GenericConvBPDN.Options.__init__(self, opt)

    itstat_fields_objfn = ('ObjFun', 'DFid', 'RegL1')
    hdrtxt_objfn = ('Fnc', 'DFid', u('Regl1'))
    hdrval_objfun = {'Fnc': 'ObjFun', 'DFid': 'DFid', u('Regl1'): 'RegL1'}

    def __init__(self,
                 D,
                 S,
                 lmbda=None,
                 W=None,
                 opt=None,
                 nproc=None,
                 ngrp=None,
                 dimK=None,
                 dimN=2):
        """
        Parameters
        ----------
        D : array_like
          Dictionary matrix
        S : array_like
          Signal vector or matrix
        lmbda : float
          Regularisation parameter
        W : array_like
          Mask array. The array shape must be such that the array is
          compatible for multiplication with input array S (see
          :func:`.cnvrep.mskWshape` for more details).
        opt : :class:`ParConvBPDN.Options` object
          Algorithm options
        nproc : int
          Number of processes
        ngrp : int
          Number of groups in partition of filter indices
        dimK : 0, 1, or None, optional (default None)
          Number of dimensions in input signal corresponding to multiple
          independent signals
        dimN : int, optional (default 2)
          Number of spatial dimensions
        """

        self.pool = None

        # Set default options if none specified
        if opt is None:
            opt = ParConvBPDN.Options()

        # Set dtype attribute based on S.dtype and opt['DataType']
        self.set_dtype(opt, S.dtype)

        # Set default lambda value if not specified
        if lmbda is None:
            cri = cr.CSC_ConvRepIndexing(D, S, dimK=dimK, dimN=dimN)
            Df = sl.rfftn(D.reshape(cri.shpD), cri.Nv, axes=cri.axisN)
            Sf = sl.rfftn(S.reshape(cri.shpS), axes=cri.axisN)
            b = np.conj(Df) * Sf
            lmbda = 0.1 * abs(b).max()

        # Set l1 term scaling and weight array
        self.lmbda = self.dtype.type(lmbda)

        # Set penalty parameter
        self.set_attr('rho',
                      opt['rho'],
                      dval=(50.0 * self.lmbda + 1.0),
                      dtype=self.dtype)
        self.set_attr('alpha', opt['alpha'], dval=1.0, dtype=self.dtype)

        # Set rho_xi attribute (see Sec. VI.C of wohlberg-2015-adaptive)
        # if self.lmbda != 0.0:
        #     rho_xi = (1.0 + (18.3)**(np.log10(self.lmbda) + 1.0))
        # else:
        #     rho_xi = 1.0
        # self.set_attr('rho_xi', opt['AutoRho', 'RsdlTarget'], dval=rho_xi,
        #               dtype=self.dtype)

        # Call parent class __init__
        super(ParConvBPDN, self).__init__(D, S, opt, dimK, dimN)

        if nproc is None:
            if ngrp is None:
                self.nproc = min(mp.cpu_count(), self.cri.M)
                self.ngrp = self.nproc
            else:
                self.nproc = min(mp.cpu_count(), ngrp, self.cri.M)
                self.ngrp = ngrp
        else:
            if ngrp is None:
                self.ngrp = nproc
                self.nproc = nproc
            else:
                self.ngrp = ngrp
                self.nproc = nproc

        if W is None:
            W = np.array([1.0], dtype=self.dtype)
        self.W = np.asarray(W.reshape(cr.mskWshape(W, self.cri)),
                            dtype=self.dtype)
        self.wl1 = np.asarray(opt['L1Weight'], dtype=self.dtype)
        self.wl1 = self.wl1.reshape(cr.l1Wshape(self.wl1, self.cri))

        self.xrrs = None

        # Initialise global variables
        # Conv Rep Indexing and parameter values for multiprocessing
        global mp_nproc
        mp_nproc = self.nproc
        global mp_ngrp
        mp_ngrp = self.ngrp
        global mp_Nv
        mp_Nv = self.cri.Nv
        global mp_axisN
        mp_axisN = tuple(i + 1 for i in self.cri.axisN)
        global mp_C
        mp_C = self.cri.C
        global mp_Cd
        mp_Cd = self.cri.Cd
        global mp_axisC
        mp_axisC = self.cri.axisC + 1
        global mp_axisM
        mp_axisM = 0
        global mp_NonNegCoef
        mp_NonNegCoef = self.opt['NonNegCoef']
        global mp_NoBndryCross
        mp_NoBndryCross = self.opt['NoBndryCross']
        global mp_Dshp
        mp_Dshp = self.D.shape

        # Parameters for optimization
        global mp_lmbda
        mp_lmbda = self.lmbda
        global mp_rho
        mp_rho = self.rho
        global mp_alpha
        mp_alpha = self.alpha
        global mp_rlx
        mp_rlx = self.rlx
        global mp_wl1
        init_mpraw('mp_wl1', np.moveaxis(self.wl1, self.cri.axisM, mp_axisM))

        # Matrices used in optimization
        global mp_S
        init_mpraw('mp_S',
                   np.moveaxis(self.S * self.W**2, self.cri.axisM, mp_axisM))
        global mp_Df
        init_mpraw('mp_Df', np.moveaxis(self.Df, self.cri.axisM, mp_axisM))
        global mp_X
        init_mpraw('mp_X', np.moveaxis(self.Y, self.cri.axisM, mp_axisM))
        shp_X = list(mp_X.shape)
        global mp_Xnr
        mp_Xnr = mpraw_as_np(mp_X.shape, mp_X.dtype)
        global mp_Y0
        shp_Y0 = shp_X[:]
        shp_Y0[0] = self.ngrp
        shp_Y0[mp_axisC] = mp_C
        if self.opt['Y0'] is not None:
            init_mpraw(
                'Y0',
                np.moveaxis(self.opt['Y0'].astype(self.dtype, copy=True),
                            self.cri.axisM, mp_axisM))
        else:
            mp_Y0 = mpraw_as_np(shp_Y0, mp_X.dtype)
        global mp_Y0old
        mp_Y0old = mpraw_as_np(shp_Y0, mp_X.dtype)
        global mp_Y1
        if self.opt['Y1'] is not None:
            init_mpraw(
                'Y1',
                np.moveaxis(self.opt['Y1'].astype(self.dtype, copy=True),
                            self.cri.axisM, mp_axisM))
        else:
            mp_Y1 = mpraw_as_np(shp_X, mp_X.dtype)
        global mp_Y1old
        mp_Y1old = mpraw_as_np(shp_X, mp_X.dtype)
        global mp_U0
        if self.opt['U0'] is not None:
            init_mpraw(
                'U0',
                np.moveaxis(self.opt['U0'].astype(self.dtype, copy=True),
                            self.cri.axisM, mp_axisM))
        else:
            mp_U0 = mpraw_as_np(shp_Y0, mp_X.dtype)
        global mp_U1
        if self.opt['U1'] is not None:
            init_mpraw(
                'U1',
                np.moveaxis(self.opt['U1'].astype(self.dtype, copy=True),
                            self.cri.axisM, mp_axisM))
        else:
            mp_U1 = mpraw_as_np(shp_X, mp_X.dtype)
        global mp_DX
        mp_DX = mpraw_as_np(shp_Y0, mp_X.dtype)
        global mp_DXnr
        mp_DXnr = mpraw_as_np(shp_Y0, mp_X.dtype)

        # Variables used to solve the optimization efficiently
        global mp_inv_off_diag
        if self.W.ndim is self.cri.axisM + 1:
            init_mpraw(
                'mp_inv_off_diag',
                np.moveaxis(
                    -self.W**2 / (mp_rho * (mp_rho + self.W**2 * mp_ngrp)),
                    self.cri.axisM, mp_axisM))
        else:
            init_mpraw('mp_inv_off_diag',
                       -self.W**2 / (mp_rho * (mp_rho + self.W**2 * mp_ngrp)))
        global mp_grp
        mp_grp = [
            np.min(i)
            for i in np.array_split(np.array(range(self.cri.M)), mp_ngrp)
        ] + [
            self.cri.M,
        ]
        global mp_cache
        if self.opt['HighMemSolve'] and self.cri.Cd == 1:
            mp_cache = [
                sl.solvedbi_sm_c(mp_Df[k], np.conj(mp_Df[k]), mp_alpha**2,
                                 mp_axisM)
                for k in np.array_split(np.array(range(self.cri.M)), self.ngrp)
            ]
        else:
            mp_cache = [None for k in mp_grp]
        global mp_b
        shp_b = shp_Y0[:]
        shp_b[0] = 1
        mp_b = mpraw_as_np(shp_b, mp_X.dtype)

        # Residual and stopping criteria variables
        global mp_ry0
        mp_ry0 = mpraw_as_np((self.ngrp, ), mp_X.dtype)
        global mp_ry1
        mp_ry1 = mpraw_as_np((self.ngrp, ), mp_X.dtype)
        global mp_sy0
        mp_sy0 = mpraw_as_np((self.ngrp, ), mp_X.dtype)
        global mp_sy1
        mp_sy1 = mpraw_as_np((self.ngrp, ), mp_X.dtype)
        global mp_nrmAx
        mp_nrmAx = mpraw_as_np((self.ngrp, ), mp_X.dtype)
        global mp_nrmBy
        mp_nrmBy = mpraw_as_np((self.ngrp, ), mp_X.dtype)
        global mp_nrmu
        mp_nrmu = mpraw_as_np((self.ngrp, ), mp_X.dtype)

    def solve(self):
        """Start (or re-start) optimisation. This method implements the
        framework for the iterations of an ADMM algorithm.

        When option ``FastSolve`` is ``False``, the statistics of every
        iteration are appended to attribute :attr:`itstat` and passed to
        :meth:`display_status` (progress is shown if option ``Verbose``
        is ``True``). When ``FastSolve`` is ``True`` no per-iteration
        statistics are recorded.

        Residuals (and hence the residual-based stopping test) are
        computed when ``FastSolve`` is ``False`` or option
        ``AutoRho, Enabled`` is ``True``. Note that automatic adaptation
        of the penalty parameter itself is currently disabled in this
        implementation (see the comment inside the iteration loop).

        Attribute :attr:`timer` is an instance of :class:`.util.Timer`
        that provides the following labelled timers:

          ``init``: Time taken for object initialisation by
          :meth:`__init__`

          ``solve``: Total time taken by call(s) to :meth:`solve`

          ``solve_wo_func``: Total time taken by call(s) to
          :meth:`solve`, excluding time taken to compute functional
          value and related iteration statistics

          ``solve_wo_rsdl`` : Total time taken by call(s) to
          :meth:`solve`, excluding time taken to compute functional
          value and related iteration statistics as well as time taken
          to compute residuals and implemented ``AutoRho`` mechanism

        The multiprocessing pool created at the start of this method is
        always terminated before it exits, including when an exception
        is raised during the iterations.

        Returns
        -------
        object
          The value returned by :meth:`getmin`, called after the final
          solution has been exposed as :attr:`X` and :attr:`Y`.
        """

        self.init_pool()
        try:
            fmtstr, nsep = self.display_start()

            # Start solve timer
            self.timer.start(['solve', 'solve_wo_func', 'solve_wo_rsdl'])

            # The iteration index is matched with ``==``: comparing ints
            # with ``is`` only works inside CPython's small-int cache, so
            # it silently fails once the counter exceeds 256.
            first_iteration = self.k
            last_iteration = self.k + self.opt['MaxMainIter'] - 1
            # Main optimisation iterations
            for self.k in range(self.k, self.k + self.opt['MaxMainIter']):
                # Save the previous Y values into the preallocated
                # mp_Y*old arrays in place (no rebinding, no temporary)
                mp_Y0old[:] = mp_Y0
                mp_Y1old[:] = mp_Y1

                # Perform the variable updates. The first and last
                # iterations use dedicated parallel step groups; y0astep
                # runs in this process between the parallel groups.
                if self.k == first_iteration:
                    self.distribute(par_initial_stepgrp, mp_ngrp)
                y0astep()
                if self.k == last_iteration:
                    self.distribute(par_final_stepgrp, mp_ngrp)
                else:
                    self.distribute(par_stepgrp, mp_ngrp)

                # Compute the residual variables
                self.timer.stop('solve_wo_rsdl')
                if self.opt['AutoRho', 'Enabled'] or \
                        not self.opt['FastSolve']:
                    self.distribute(par_compute_residuals, mp_ngrp)
                    # Per-group partial sums are combined here
                    r = np.sqrt(np.sum(mp_ry0) + np.sum(mp_ry1))
                    s = np.sqrt(np.sum(mp_sy0) + np.sum(mp_sy1))

                    epri = np.sqrt(self.Nc) * self.opt['AbsStopTol'] + \
                        np.max([np.sqrt(np.sum(mp_nrmAx)),
                                np.sqrt(np.sum(mp_nrmBy))]) * \
                        self.opt['RelStopTol']

                    edua = np.sqrt(self.Nx) * self.opt['AbsStopTol'] + \
                        np.sqrt(np.sum(mp_nrmu)) * self.opt['RelStopTol']

                # Compute and record other iteration statistics and
                # display iteration stats if Verbose option enabled
                self.timer.stop(['solve_wo_func', 'solve_wo_rsdl'])
                if not self.opt['FastSolve']:
                    itst = self.iteration_stats(self.k, r, s, epri, edua)
                    self.itstat.append(itst)
                    self.display_status(fmtstr, itst)
                self.timer.start(['solve_wo_func', 'solve_wo_rsdl'])

                # Automatic rho adjustment is disabled. NOTE(review):
                # presumably because mp_rho and quantities derived from it
                # (e.g. mp_inv_off_diag) are fixed at initialisation.
                # self.timer.stop('solve_wo_rsdl')
                # if self.opt['AutoRho', 'Enabled'] or \
                #         not self.opt['FastSolve']:
                #     self.update_rho(self.k, r, s)
                # self.timer.start('solve_wo_rsdl')

                # Call callback function if defined
                if self.opt['Callback'] is not None:
                    if self.opt['Callback'](self):
                        break

                # Stop if residual-based stopping tolerances reached
                if self.opt['AutoRho', 'Enabled'] or \
                        not self.opt['FastSolve']:
                    if r < epri and s < edua:
                        break

            # Increment iteration count
            self.k += 1

            # Record solve time
            self.timer.stop(['solve', 'solve_wo_func', 'solve_wo_rsdl'])

            # Print final separator string if Verbose option enabled
            self.display_end(nsep)

            # Expose the solution with the filter axis moved back to its
            # original position (np.moveaxis returns views)
            self.Y = np.moveaxis(mp_Y1, mp_axisM, self.cri.axisM)
            self.X = np.moveaxis(mp_X, mp_axisM, self.cri.axisM)
        finally:
            # Always release the worker processes, even on error
            self.terminate_pool()

        return self.getmin()

    def init_pool(self):
        """Create the multiprocessing pool used by :meth:`distribute`.

        A pool of ``self.nproc`` worker processes is created only when no
        pool exists yet and more than one process was requested; with a
        single process :attr:`pool` is left as ``None`` so that work runs
        serially. If a pool already exists, a diagnostic message is
        printed and the existing pool is kept unchanged.
        """
        if self.pool is not None:
            print('pool already initialized?')
            return

        self.pool = mp.Pool(processes=self.nproc) if self.nproc > 1 else None

    def distribute(self, f, n):
        """Evaluate ``f`` on every index in ``range(n)``.

        When a multiprocessing pool exists the calls are spread over its
        workers via ``Pool.map``; otherwise they run serially in the
        current process.

        Parameters
        ----------
        f : function
          Function to be distributed to the processors
        n : int
          The values in range(0,n) will be passed as arguments to the
          function f.

        Returns
        -------
        list
          ``[f(0), f(1), ..., f(n-1)]`` in index order.
        """

        indices = range(n)
        if self.pool is not None:
            return self.pool.map(f, indices)
        return list(map(f, indices))

    def terminate_pool(self):
        """Shut down the multiprocessing pool, if one exists.

        The worker processes are stopped with ``terminate`` and then
        waited for with ``join``. Afterwards :attr:`pool` is reset to
        ``None`` so that :meth:`init_pool` can create a fresh pool later.
        Does nothing when no pool exists.
        """

        pool = self.pool
        if pool is None:
            return
        pool.terminate()
        pool.join()
        self.pool = None

    def obfn_gvar(self):
        """Select the array used when evaluating :meth:`ADMM.obfn_g`.

        Returns
        -------
        array
          The module-level ``mp_Y1`` array when option ``gEvalY`` is
          true, otherwise ``mp_X``.
        """
        if not self.opt['gEvalY']:
            return mp_X
        return mp_Y1

    def obfn_fvar(self):
        """Select the array used when evaluating :meth:`ADMM.obfn_f`.

        Returns
        -------
        array
          The module-level ``mp_X`` array when option ``fEvalX`` is
          true, otherwise ``mp_Y1``.
        """
        if not self.opt['fEvalX']:
            return mp_Y1
        return mp_X

    def obfn_reg(self):
        r"""Compute regularisation term, :math:`\| x \|_1`, and
        contribution to objective function.
        """
        l1 = np.sum(mp_wl1 * np.abs(self.obfn_gvar()))
        return (self.lmbda * l1, l1)

    def obfn_dfd(self):
        r"""Compute data fidelity term :math:`(1/2) \| W \left( \sum_m
        \mathbf{d}_m * \mathbf{x}_m - \mathbf{s} \right) \|_2^2`.

        The convolution is evaluated in the DFT domain on the array chosen
        by :meth:`obfn_fvar`. That array and ``mp_Df`` use the
        multiprocessing layout (filter axis at ``mp_axisM``), so the
        result is moved back to ``self.cri.axisM`` before it is compared
        with ``self.S``.
        """
        Xf = sl.rfftn(self.obfn_fvar(), mp_Nv, mp_axisN)
        # Sum over the filter index in the DFT domain, then invert
        DXf = sl.inner(mp_Df, Xf, mp_axisM)
        DX = sl.irfftn(DXf, mp_Nv, mp_axisN)
        DX = np.moveaxis(DX, mp_axisM, self.cri.axisM)
        weighted_residual = self.W * (DX - self.S)
        return np.sum(weighted_residual**2) / 2.0