Exemple #1
0
    def __init__(self, est1, est2, msg_hdl=None, hist_list=None, nit=10,
                 comp_cost=False, prt_period=0, is_complex=False, map_est=False):
        """
        Two-estimator solver constructor.

        :param est1: First estimator
        :param est2: Second estimator; must have the same ``shape`` and
            ``var_axes`` as ``est1``
        :param msg_hdl: Message handler.  If ``None``, an
            ``estim.MsgHdlSimp`` is built from ``is_complex``, ``map_est``
            and ``est1.shape``.
        :param hist_list: Quantities to record in the history, passed to
            ``Solver.__init__``.  ``None`` (the default) means ``[]``.
        :param nit: Number of iterations (stored as ``self.nit``)
        :param comp_cost: Flag indicating if the cost is to be computed
        :param prt_period: Stored as ``self.prt_period``
        :param is_complex: Only used to build the default message handler
        :param map_est: Only used to build the default message handler

        :raises common.VpException: if the ``shape`` or ``var_axes`` of
            ``est1``, ``est2`` and ``msg_hdl`` do not all agree.
        """
        # Build a fresh list per call.  A ``[]`` default would be a single
        # list shared by every solver created without an explicit hist_list.
        if hist_list is None:
            hist_list = []
        Solver.__init__(self, hist_list)
        self.est1 = est1
        self.est2 = est2
        if msg_hdl is None:
            msg_hdl = estim.MsgHdlSimp(is_complex=is_complex,
                                       map_est=map_est,
                                       shape=self.est1.shape)
        self.msg_hdl = msg_hdl
        self.nit = nit
        self.comp_cost = comp_cost
        self.prt_period = prt_period

        # Check if dimensions match
        if self.est1.shape != self.est2.shape:
            err_str = '%s shape %s does not match %s shape %s' %\
                (self.est1.name, str(self.est1.shape),
                 self.est2.name, str(self.est2.shape))
            raise common.VpException(err_str)
        if self.est1.shape != self.msg_hdl.shape:
            err_str = '%s shape %s does not match msg_hdl shape %s' %\
                (self.est1.name, str(self.est1.shape),
                 str(self.msg_hdl.shape))
            raise common.VpException(err_str)

        # Check if the variance-repetition axes match
        if self.est1.var_axes != self.est2.var_axes:
            err_str = '%s var_axes %s does not match %s var_axes %s' %\
                (self.est1.name, str(self.est1.var_axes),
                 self.est2.name, str(self.est2.var_axes))
            raise common.VpException(err_str)
        if self.est1.var_axes != self.msg_hdl.var_axes:
            err_str = '%s var_axes %s does not match msg_hdl var_axes %s' %\
                (self.est1.name, str(self.est1.var_axes),
                 str(self.msg_hdl.var_axes))
            raise common.VpException(err_str)
Exemple #2
0
    def init_svd(self):
        """
        Prepare the quantities needed by the SVD estimation method.

        Writing ``A = U S V'``, the output equation becomes
        ``p = S V' z + w``.  This stores ``self.bt = A.UsvdH(b)`` and
        ``self.bpnorm = sum(|b - A.Usvd(bt)|^2 / wvar)``.  ``bpnorm`` is
        0 unless every ``wvar`` entry is positive.  It then verifies that
        every axis outside ``A.srep_axes`` is a repetition axis of both
        variances and of the noise.

        :raises common.VpException: if ``A`` has no SVD, or a variance is
            not constant over an axis that ``A`` operates on.
        """
        A = self.A
        if not A.svd_avail:
            raise common.VpException("Transform must support an SVD")
        self.bt = A.UsvdH(self.b)
        srep_axes = A.srep_axes

        # Residual energy ||b - U U*(b)||^2 / wvar, only for positive wvar
        if not np.all(self.wvar > 0):
            self.bpnorm = 0
        else:
            b_proj = A.Usvd(self.bt)
            wvar_full = common.repeat_axes(self.wvar, self.shape[1],
                                           self.wrep_axes, rep=False)
            self.bpnorm = np.sum(np.abs(self.b - b_proj)**2 / wvar_full)

        # Every axis that A acts on (i.e. not an SVD repetition axis) must
        # be a repetition axis for the output variance, the noise variance
        # and the input variance, checked in that order.
        for axis in range(len(self.shape[1])):
            if axis in srep_axes:
                continue
            if axis not in self.var_axes[1]:
                raise common.VpException(
                    "Variance must be constant over output axis")
            if axis not in self.wrep_axes:
                raise common.VpException(
                    "Noise variance must be constant over output axis")
            if axis not in self.var_axes[0]:
                raise common.VpException(
                    "Variance must be constant over input axis")
Exemple #3
0
    def __init__(self, A, y, wvar=0,
                 wrep_axes='all', var_axes=(0,), name=None, map_est=False,
                 is_complex=False, rvar_init=1e5, tune_wvar=False):
        """
        SVD-based linear estimator constructor.

        :param A: Linear transform; must have ``svd_avail`` true
        :param y: Observed output
        :param wvar: Noise variance
        :param wrep_axes: Axes over which ``wvar`` is repeated, or ``'all'``
            for every output axis
        :param var_axes: Variance repetition axes passed to ``BaseEst``
        :param name: Estimator name
        :param map_est: Stored as ``self.map_est``
        :param is_complex: Stored as ``self.is_complex``
        :param rvar_init: Stored as ``self.rvar_init``
        :param tune_wvar: Stored as ``self.tune_wvar``

        :raises common.VpException: if ``A`` has no SVD, or an output axis
            outside the SVD repetition axes is missing from ``wrep_axes``
            or ``var_axes``.
        """
        BaseEst.__init__(self, shape=A.shape0, var_axes=var_axes,
                         dtype=A.dtype0, name=name, type_name='LinEstim',
                         nvars=1, cost_avail=True)

        # Problem data and options
        self.A = A
        self.y = y
        self.wvar = wvar
        self.map_est = map_est
        self.is_complex = is_complex
        self.cost_avail = True
        self.rvar_init = rvar_init
        self.tune_wvar = tune_wvar

        # Input (z) and output (y) shapes
        self.zshape = A.shape0
        self.yshape = A.shape1

        # Expand 'all' into every output axis
        ndim = len(self.yshape)
        self.wrep_axes = tuple(range(ndim)) if wrep_axes == 'all' else wrep_axes

        # With the SVD A = USV', the model becomes p = SV'z + w
        if not A.svd_avail:
            raise common.VpException("Transform must support an SVD")
        self.p = A.UsvdH(y)
        srep_axes = A.get_svd_diag()[2]

        # ||y - UU*(y)||^2 / wvar, only when every wvar entry is positive
        if not np.all(self.wvar > 0):
            self.ypnorm = 0
        else:
            y_proj = A.Usvd(self.p)
            wvar_rep = common.repeat_axes(wvar, self.yshape, self.wrep_axes,
                                          rep=False)
            sq_resid = np.abs(y - y_proj)**2
            self.ypnorm = np.sum(sq_resid / wvar_rep)

        # Axes that A acts on must be repetition axes of both variances
        for axis in range(ndim):
            if axis in srep_axes:
                continue
            if axis not in self.wrep_axes:
                raise common.VpException(
                    "Variance must be constant over output axis")
            if axis not in self.var_axes:
                raise common.VpException(
                    "Variance must be constant over input axis")
Exemple #4
0
    def est(self, r, rvar, return_cost=False, ind_out=None, avg_var_cost=True):
        """
        Estimation function

        The proximal estimation function as
        described in the base class :class:`vampyre.estim.base.Estim`.
        Dispatches to :code:`est_svd` or :code:`est_cg` according to
        :code:`self.est_meth`.

        :param r: Proximal mean
        :param rvar: Proximal variance
        :param Boolean return_cost:  Flag indicating if :code:`cost` is
            to be returned
        :param ind_out: Indices of the variables whose estimates are
            returned.  ``None`` (the default) means ``[0, 1]``.
        :param Boolean avg_var_cost: Must be true.  Disabling variance
            averaging is not supported.

        :returns: :code:`zhat, zhatvar, [cost]` which are the posterior
            mean, variance and optional cost.
        :raises ValueError: if ``avg_var_cost`` is false.
        :raises common.VpException: if ``self.est_meth`` is not
            ``'svd'`` or ``'cg'``.
        """
        # Check parameters
        if ind_out is None:
            ind_out = [0, 1]
        if not avg_var_cost:
            # Use the concrete class name so the message is accurate here
            # and in any subclass.
            raise ValueError(
                "disabling variance averaging not supported for {0}".format(
                    type(self).__name__))

        if self.est_meth == 'svd':
            return self.est_svd(r, rvar, return_cost, ind_out)
        elif self.est_meth == 'cg':
            return self.est_cg(r, rvar, return_cost, ind_out)
        else:
            # '{0}' rather than '{0:s}': a non-string est_meth would make the
            # ':s' format spec itself raise, hiding the intended VpException.
            raise common.VpException("Unknown estimation method {0}".format(
                self.est_meth))
Exemple #5
0
    def __init__(self, nrow=256, ncol=256, wavelet='db4', level=3,
                 fwd_mode='recon', dtype=np.float64, name=None):
        """
        2D wavelet transform constructor.

        The input and output shapes are both ``(nrow, ncol)`` and both use
        ``dtype``.  An SVD is reported as available only when the chosen
        wavelet is orthogonal.

        :param nrow: Number of rows
        :param ncol: Number of columns
        :param wavelet: PyWavelets wavelet name
        :param level: Decomposition level
        :param fwd_mode: ``'recon'`` or ``'analysis'``
        :param dtype: dtype of the input and output
        :param name: Transform name

        :raises common.VpException: if ``fwd_mode`` is invalid.
        """
        # Wavelet parameters
        self.wavelet = wavelet
        self.level = level

        # The SVD computation assumes an orthogonal wavelet
        svd_avail = bool(pywt.Wavelet(wavelet).orthogonal)
        BaseLinTrans.__init__(self, (nrow, ncol), (nrow, ncol), dtype, dtype,
                              svd_avail=svd_avail, name=name)

        # Periodic extension keeps the wavelet transform orthogonal
        self.mode = 'periodization'

        # Decompose an all-zero image only to learn the coefficient layout
        blank = np.zeros((nrow, ncol))
        blank_coeffs = pywt.wavedec2(blank, wavelet=self.wavelet,
                                     level=self.level, mode=self.mode)
        _unused_arr, slices = pywt.coeffs_to_array(blank_coeffs)
        self.coeff_slices = slices

        # Validate the forward mode
        if fwd_mode not in ('recon', 'analysis'):
            raise common.VpException('fwd_mode must be recon or analysis')
        self.fwd_mode = fwd_mode
Exemple #6
0
    def __init__(self, nrow=256, ncol=256, wavelet='db4', level=3,
                 fwd_mode='recon'):
        """
        2D wavelet transform constructor.

        :param nrow: Number of rows
        :param ncol: Number of columns
        :param wavelet: PyWavelets wavelet name
        :param level: Decomposition level
        :param fwd_mode: ``'recon'`` or ``'analysis'``

        :raises common.VpException: if ``fwd_mode`` is invalid.
        """
        LinTrans.__init__(self)

        # Wavelet parameters and (identical) input/output shapes
        self.wavelet = wavelet
        self.level = level
        self.shape0 = self.shape1 = (nrow, ncol)

        # Periodic extension keeps the wavelet transform orthogonal
        self.mode = 'periodization'

        # Decompose an all-zero image only to learn the coefficient layout
        blank = np.zeros((nrow, ncol))
        blank_coeffs = pywt.wavedec2(blank, wavelet=self.wavelet,
                                     level=self.level, mode=self.mode)
        _unused_arr, slices = pywt.coeffs_to_array(blank_coeffs)
        self.coeff_slices = slices

        # Validate the forward mode
        if fwd_mode not in ('recon', 'analysis'):
            raise common.VpException('fwd_mode must be recon or analysis')
        self.fwd_mode = fwd_mode
Exemple #7
0
    def __init__(self, est_list, w, name=None):
        """
        Mixture estimator constructor.

        :param est_list: Component estimators.  All must share the same
            ``shape`` and ``var_axes`` and have ``cost_avail`` true.
        :param w: Component weights (stored as ``self.w``; presumably one
            per estimator, not validated here)
        :param name: Estimator name

        :raises common.VpException: if ``est_list`` is empty or the
            component estimators are incompatible.
        """
        self.est_list = est_list
        self.w = w

        # Fail with a clear message instead of an IndexError below
        if len(est_list) == 0:
            raise common.VpException(
                'Estimator list for a mixture must not be empty')

        shape = est_list[0].shape
        var_axes = est_list[0].var_axes
        dtype = est_list[0].dtype

        # Check that all estimators have matching shape / var_axes and
        # have the cost available
        for est in est_list:
            if est.shape != shape:
                raise common.VpException('Estimators must have the same shape')
            if est.var_axes != var_axes:
                raise common.VpException(
                    'Estimators must have the same var_axes')
            if not est.cost_avail:
                raise common.VpException(
                    "Estimators in a mixture distribution "
                    "must have cost_avail==True")
        BaseEst.__init__(self, shape=shape, var_axes=var_axes, dtype=dtype,
                         name=name, type_name='Mixture', nvars=1,
                         cost_avail=True)
Exemple #8
0
    def __init__(self, est_list, w):
        """
        Mixture estimator constructor.

        :param est_list: Component estimators; the first one defines
            ``shape`` and ``var_axes``, and all must have ``cost_avail``
            true.
        :param w: Component weights (stored as ``self.w``; presumably one
            per estimator, not validated here)

        :raises common.VpException: if ``est_list`` is empty or an
            estimator cannot compute its cost.
        """
        Estim.__init__(self)
        self.est_list = est_list
        self.w = w

        # Fail with a clear message instead of an IndexError below
        if len(est_list) == 0:
            raise common.VpException(
                'Estimator list for a mixture must not be empty')
        self.shape = est_list[0].shape
        self.var_axes = est_list[0].var_axes

        # Check that all estimators have cost available
        for est in est_list:
            if not est.cost_avail:
                raise common.VpException(
                    "Estimators in a mixture distribution "
                    "must have cost_avail==True")
        self.cost_avail = True
Exemple #9
0
    def __init__(self, A, b, wvar=0,
                 z1rep_axes=(0,), z0rep_axes=(0,), wrep_axes='all',
                 map_est=False, is_complex=False, est_meth='svd',
                 nit_cg=100, atol_cg=1e-3, save_stats=False):
        """
        Two-variable linear estimator constructor.

        :param A: Linear transform with input shape ``A.shape0`` (z0) and
            output shape ``A.shape1`` (z1)
        :param b: Stored as ``self.b``
        :param wvar: Noise variance
        :param z1rep_axes: Repetition axes for z1, or ``'all'``
        :param z0rep_axes: Repetition axes for z0, or ``'all'``
        :param wrep_axes: Repetition axes for ``wvar``, or ``'all'``
        :param map_est: Stored as ``self.map_est``
        :param is_complex: Stored as ``self.is_complex``
        :param est_meth: ``'svd'`` or ``'cg'``
        :param nit_cg: Stored as ``self.nit_cg``
        :param atol_cg: Stored as ``self.atol_cg``
        :param save_stats: Stored as ``self.save_stats``

        :raises common.VpException: if ``est_meth`` is unknown.
        """
        Estim.__init__(self)
        self.A = A
        self.b = b
        self.wvar = wvar
        self.map_est = map_est
        self.is_complex = is_complex
        self.cost_avail = True

        # Initial variance.  This is large value since the quantities
        # are underdetermined.  (np.inf: the np.Inf alias was removed in
        # NumPy 2.0.)
        self.zvar0_init = np.inf
        self.zvar1_init = np.inf

        # Get the input and output shape
        self.shape0 = A.shape0
        self.shape1 = A.shape1

        # Set the repetition axes.  'all' must expand against the variable's
        # own shape: z0 lives on A's input side (shape0), while z1 and the
        # noise live on its output side (shape1).
        ndim0 = len(self.shape0)
        ndim1 = len(self.shape1)
        if z0rep_axes == 'all':
            z0rep_axes = tuple(range(ndim0))
        if z1rep_axes == 'all':
            z1rep_axes = tuple(range(ndim1))
        if wrep_axes == 'all':
            wrep_axes = tuple(range(ndim1))
        self.z0rep_axes = z0rep_axes
        self.z1rep_axes = z1rep_axes
        self.wrep_axes = wrep_axes

        # Initialization depending on the estimation method
        self.est_meth = est_meth
        if self.est_meth == 'svd':
            self.init_svd()
        elif self.est_meth == 'cg':
            self.init_cg()
        else:
            raise common.VpException(
                "Unknown estimation method {0:s}".format(est_meth))

        # CG parameters
        self.nit_cg = nit_cg
        self.atol_cg = atol_cg
        self.save_stats = save_stats
        self.init_hist_dict()
Exemple #10
0
    def __init__(self, A, b, wvar=0, var_axes=[(0,), (0,)], wrep_axes='all',
                 name=None, map_est=False, is_complex=False, est_meth='svd',
                 nit_cg=100, atol_cg=1e-3, save_stats=False):
        """
        Two-variable linear estimator constructor.

        :param A: Linear transform; variable 0 has shape ``A.shape0`` and
            dtype ``A.dtype0``, variable 1 has ``A.shape1`` / ``A.dtype1``
        :param b: Stored as ``self.b``
        :param wvar: Noise variance
        :param var_axes: Pair ``[var_axes0, var_axes1]``; either entry may be
            ``'all'``.  The argument is copied and never modified.
        :param wrep_axes: Repetition axes for ``wvar``, or ``'all'`` for
            every axis of ``A.shape1``
        :param name: Estimator name
        :param map_est: Stored as ``self.map_est``
        :param is_complex: Stored as ``self.is_complex``
        :param est_meth: ``'svd'`` or ``'cg'``
        :param nit_cg: Stored as ``self.nit_cg``
        :param atol_cg: Stored as ``self.atol_cg``
        :param save_stats: Stored as ``self.save_stats``

        :raises common.VpException: if ``est_meth`` is unknown.
        """
        self.A = A
        self.b = b
        self.wvar = wvar
        self.map_est = map_est
        self.is_complex = is_complex
        self.cost_avail = True

        # Initial variance.  This is large value since the quantities
        # are underdetermined.  (np.inf: the np.Inf alias was removed in
        # NumPy 2.0.)
        self.zvar0_init = np.inf
        self.zvar1_init = np.inf

        # Set parameters of the base estimator
        shape = [A.shape0, A.shape1]
        dtype = [A.dtype0, A.dtype1]
        nvars = 2

        # Resolve 'all' on a copy.  Assigning into ``var_axes`` directly would
        # modify the caller's list, or the default list shared by all calls.
        var_axes_res = list(var_axes)
        for i in range(nvars):
            if var_axes_res[i] == 'all':
                var_axes_res[i] = tuple(range(len(shape[i])))
        # Keep the container type the caller supplied
        if isinstance(var_axes, tuple):
            var_axes_res = tuple(var_axes_res)

        if wrep_axes == 'all':
            wrep_axes = tuple(range(len(shape[1])))
        self.wrep_axes = wrep_axes

        BaseEst.__init__(self, shape=shape, var_axes=var_axes_res, dtype=dtype,
                         name=name, type_name='LinEstTwo', nvars=nvars,
                         cost_avail=True)

        # Initialization depending on the estimation method
        self.est_meth = est_meth
        if self.est_meth == 'svd':
            self.init_svd()
        elif self.est_meth == 'cg':
            self.init_cg()
        else:
            raise common.VpException(
                "Unknown estimation method {0:s}".format(est_meth))

        # CG parameters
        self.nit_cg = nit_cg
        self.atol_cg = atol_cg
        self.save_stats = save_stats
        self.init_hist_dict()
Exemple #11
0
    def __init__(self, est_list, msg_hdl_list=None, hist_list=None, nit=10,
                 comp_cost=False, prt_period=0):
        """
        Multi-layer solver constructor.

        :param est_list: Estimators, one per layer
        :param msg_hdl_list: Message handlers (stored as
            ``self.msg_hdl_list``).  ``None`` (the default) means ``[]``.
        :param hist_list: Quantities to record in the history, passed to
            ``Solver.__init__``.  ``None`` (the default) means ``[]``.
        :param nit: Number of iterations (stored as ``self.nit``)
        :param comp_cost: Flag indicating if the cost is to be computed
        :param prt_period: Stored as ``self.prt_period``

        :raises common.VpException: if ``comp_cost`` is requested but an
            estimator has ``cost_avail == False``.
        """
        # Build fresh lists per call.  ``[]`` defaults would be shared by
        # every solver created without explicit arguments.
        if msg_hdl_list is None:
            msg_hdl_list = []
        if hist_list is None:
            hist_list = []
        Solver.__init__(self, hist_list)
        self.est_list = est_list
        self.msg_hdl_list = msg_hdl_list
        self.nit = nit
        self.comp_cost = comp_cost
        self.prt_period = prt_period

        # Check if all estimators can compute the cost
        for i, esti in enumerate(self.est_list):
            if self.comp_cost and not esti.cost_avail:
                errstr = "Requested cost computation, but cost_avail==False"\
                    + " for estimator " + str(i)
                raise common.VpException(errstr)
            self.comp_cost = self.comp_cost and esti.cost_avail
Exemple #12
0
 def est(self, r, rvar, return_cost=False):
     """
     Estimation function

     Computes the proximal estimate described in the base class
     :class:`vampyre.estim.base.Estim` by delegating to :code:`est_svd`
     or :code:`est_cg`, according to :code:`self.est_meth`.

     :param r: Proximal mean
     :param rvar: Proximal variance
     :param Boolean return_cost: Whether :code:`cost` is also returned

     :returns: :code:`zhat, zhatvar, [cost]`, i.e. the posterior mean,
         variance and optional cost.
     :raises common.VpException: if :code:`self.est_meth` is neither
         ``'svd'`` nor ``'cg'``.
     """
     method = self.est_meth
     if method == 'svd':
         return self.est_svd(r, rvar, return_cost)
     if method == 'cg':
         return self.est_cg(r, rvar, return_cost)
     raise common.VpException(
         "Unknown estimation method {0:s}".format(method))
Exemple #13
0
    def est_cg(self, r, rvar, return_cost=False, ind_out=None):
        """
        CG-based estimation function

        The proximal estimation function as
        described in the base class :class:`vampyre.estim.base.Estim`.
        The posterior means come from a warm-started LSQR solve.  The
        posterior variances are estimated numerically: the problem is
        re-solved with ``r0`` (resp. ``r1``) perturbed by ``self.dr0``
        (resp. ``self.dr1``), and the change in the solution is
        correlated with the perturbation.

        :param r: Proximal means ``[r0, r1]``
        :param rvar: Proximal variances ``[rvar0, rvar1]``
        :param Boolean return_cost:  Flag indicating if :code:`cost` is
            to be returned
        :param ind_out: Indices (subset of ``[0, 1]``) of the variables
            whose estimates are returned.  ``None`` (the default) means
            ``[0, 1]``.

        :returns: :code:`zhat, zhatvar, [cost]` which are the posterior
            mean, variance and optional cost.
        :raises common.VpException: if ``rvar0`` is infinite while
            ``rvar1`` is finite (not yet implemented).
        """
        # Avoid a mutable default argument
        if ind_out is None:
            ind_out = [0, 1]

        # Unpack the inputs
        r0, r1 = r
        rvar0, rvar1 = rvar

        # Infinite variance on r1: no LSQR solve.  zhat0 = r0 and
        # zhat1 = A r0 + b.  (np.inf: the np.Inf alias was removed in
        # NumPy 2.0.)
        if np.any(rvar1 == np.inf):
            zhat0 = r0
            zhatvar0 = rvar0
            zhat1 = self.A.dot(r0) + self.b

            # Compute variance numerically.
            yvar = np.abs(self.A.dot(self.dr0))**2
            zhatvar1 = np.mean(yvar, self.var_axes[1])*rvar0/\
                np.mean(self.dr0_norm_sq)

            zhat = []
            zhatvar = []
            if 0 in ind_out:
                zhat.append(zhat0)
                zhatvar.append(zhatvar0)
            if 1 in ind_out:
                zhat.append(zhat1)
                zhatvar.append(zhatvar1)
            cost = 0
            if return_cost:
                return zhat, zhatvar, cost
            else:
                return zhat, zhatvar

        elif np.any(rvar0 == np.inf):
            raise common.VpException("Infinite variance case for rvar0 "+\
               "is not yet implemented")

        # Get dimensions
        self.n0 = np.prod(self.shape[0])
        self.n1 = np.prod(self.shape[1])

        # ---- First-order terms ----
        # Create the LSQR transform for the problem
        # The VAMP problem is equivalent to minimizing ||F(z)-g||^2
        F = LSQROp(self.A,self.b,rvar, self.wvar,\
            self.var_axes[0], self.var_axes[1],self.wrep_axes,\
            self.shape[0], self.shape[1], self.is_complex)

        def solve(rhs, zinit):
            # Warm-started LSQR: solve for the correction dz minimising
            # ||F(dz) - (rhs - F(zinit))||^2 and return (zinit + dz, raw
            # lsqr output).  In the raw output, [2] is the iteration count
            # and [3] the residual norm.
            rhs -= F.dot(zinit)
            out = scipy.sparse.linalg.lsqr(F,
                                           rhs,
                                           iter_lim=self.nit_cg,
                                           atol=self.atol_cg)
            return out[0] + zinit, out

        g = F.get_tgt_vec(r)

        # Start from the previous solution when one is available
        if self.zlast is None:
            zinit = F.pack(r0, r1)
        else:
            zinit = self.zlast
        zvec, lsqr_out = solve(g, zinit)
        self.zlast = zvec
        zhat0, zhat1 = F.unpack(zvec)
        zhat = []
        if 0 in ind_out:
            zhat.append(zhat0)
        if 1 in ind_out:
            zhat.append(zhat1)

        # Save stats
        if 'zhat_nit' in self.hist_list:
            self.hist_dict['zhat_nit'].append(lsqr_out[2])

        # ---- Cost ----
        if return_cost:
            # Squared residual norm of the main LSQR solve
            cost = lsqr_out[3]**2

            # Add the cost for the second order terms.
            #
            # We only consider the MAP case, where the second-order cost is
            # (1/2)*nz1*log(2*pi*wvar)
            if self.is_complex:
                cscale = 1
            else:
                cscale = 2
            cost /= cscale
            if F.wvar_pos:
                if np.all(self.wvar > 0):
                    cost += (1 / cscale) * self.n1 * np.mean(
                        np.log(cscale * np.pi * self.wvar))

        # ---- Second-order terms ----
        # Computed via the numerical gradient along a random direction
        zhatvar = []
        if 0 in ind_out:
            # Perturb r0 and re-solve, warm-started from the last such solve
            r0p = r0 + self.dr0
            g0 = F.get_tgt_vec([r0p, r1])
            if self.zvec0_last is None:
                zinit = F.pack(r0p, r1)
            else:
                zinit = self.zvec0_last
            zvec0, lsqr_out = solve(g0, zinit)
            self.zvec0_last = zvec0
            dz0, _ = F.unpack(zvec0 - zvec)
            if 'zvar0_nit' in self.hist_list:
                self.hist_dict['zvar0_nit'].append(lsqr_out[2])

            # Compute the correlations
            alpha0 = np.mean(np.real(self.dr0.conj()*dz0),self.var_axes[0]) /\
                self.dr0_norm_sq
            zhatvar0 = alpha0 * rvar0
            zhatvar.append(zhatvar0)

        if 1 in ind_out:
            # Perturb r1 and re-solve, warm-started from the last such solve
            r1p = r1 + self.dr1
            g1 = F.get_tgt_vec([r0, r1p])
            if self.zvec1_last is None:
                zinit = F.pack(r0, r1p)
            else:
                zinit = self.zvec1_last
            zvec1, lsqr_out = solve(g1, zinit)
            self.zvec1_last = zvec1
            _, dz1 = F.unpack(zvec1 - zvec)
            if 'zvar1_nit' in self.hist_list:
                self.hist_dict['zvar1_nit'].append(lsqr_out[2])

            # Compute the correlations
            alpha1 = np.mean(np.real(self.dr1.conj()*dz1),self.var_axes[1]) /\
                self.dr1_norm_sq
            zhatvar1 = alpha1 * rvar1
            zhatvar.append(zhatvar1)

        if return_cost:
            return zhat, zhatvar, cost
        else:
            return zhat, zhatvar
Exemple #14
0
    def __init__(self, est_list, msg_hdl_list=None, hist_list=None, nit=10,
                 comp_cost=False, prt_period=0):
        """
        Multi-layer solver constructor.

        Validates that each pair of adjacent estimators, and the message
        handler between them, agree on shape and ``var_axes``.

        :param est_list: Estimators, one per layer.  The first and last
            must have ``nvars == 1``; every intermediate one ``nvars == 2``.
        :param msg_hdl_list: Message handlers; entry ``i`` is checked
            against the variable shared by ``est_list[i]`` and
            ``est_list[i+1]``.  ``None`` (the default) means ``[]``.
        :param hist_list: Quantities to record in the history, passed to
            ``Solver.__init__``.  ``None`` (the default) means ``[]``.
        :param nit: Number of iterations (stored as ``self.nit``)
        :param comp_cost: Flag indicating if the cost is to be computed
        :param prt_period: Stored as ``self.prt_period``

        :raises ValueError: on an empty ``est_list``, too few message
            handlers, or an ``nvars`` / shape / ``var_axes`` mismatch.
        :raises common.VpException: if ``comp_cost`` is requested but an
            estimator has ``cost_avail == False``.
        """
        # Build fresh lists per call.  ``[]`` defaults would be shared by
        # every solver created without explicit arguments.
        if msg_hdl_list is None:
            msg_hdl_list = []
        if hist_list is None:
            hist_list = []
        Solver.__init__(self, hist_list)
        self.est_list = est_list
        self.msg_hdl_list = msg_hdl_list
        self.nit = nit
        self.comp_cost = comp_cost
        self.prt_period = prt_period
        self.time_iter = 0  # Computation time for last iteration

        # Check dimensions
        nlayers = len(self.est_list)
        if nlayers == 0:
            raise ValueError('est_list must contain at least one estimator')
        if self.est_list[0].nvars != 1:
            raise ValueError('First estimator must take 1 variable')
        if self.est_list[-1].nvars != 1:
            raise ValueError('Last estimator must take 1 variable')
        if len(self.msg_hdl_list) < nlayers - 1:
            raise ValueError(
                '%d estimators require %d message handlers, got %d'
                % (nlayers, nlayers - 1, len(self.msg_hdl_list)))
        for i in range(0, nlayers - 1):
            msg_hdl_shape = self.msg_hdl_list[i].shape
            msg_hdl_var_axes = self.msg_hdl_list[i].var_axes
            if i > 0:
                if (self.est_list[i].nvars != 2):
                    errstr = 'Estimator %s must take 2 variables'\
                        % self.est_list[i].name
                    raise ValueError(errstr)

            # Output side of layer i (a 1-variable estimator exposes its
            # shape/var_axes directly, a 2-variable one as [in, out])
            if i == 0:
                shape0 = self.est_list[i].shape
                var_axes0 = self.est_list[i].var_axes
            else:
                shape0 = self.est_list[i].shape[1]
                var_axes0 = self.est_list[i].var_axes[1]
            # Input side of layer i+1
            if i == nlayers - 2:
                shape1 = self.est_list[i + 1].shape
                var_axes1 = self.est_list[i + 1].var_axes
            else:
                shape1 = self.est_list[i + 1].shape[0]
                var_axes1 = self.est_list[i + 1].var_axes[0]

            if shape0 != shape1:
                errstr = 'Est %s shape %s does not match est %s shape %s'\
                    % (est_list[i].name, str(shape0), est_list[i+1].name, str(shape1))
                raise ValueError(errstr)
            if shape0 != msg_hdl_shape:
                errstr = 'Est %s shape %s does not match msg_hdl shape %s'\
                    % (est_list[i].name, str(shape0), str(msg_hdl_shape))
                raise ValueError(errstr)

            if var_axes0 != var_axes1:
                errstr = 'Est %s var_axes %s does not match est %s var_axes %s'\
                    % (est_list[i].name, str(var_axes0), est_list[i+1].name, str(var_axes1))
                raise ValueError(errstr)
            if var_axes0 != msg_hdl_var_axes:
                errstr = 'Est %s var_axes %s does not match msg_hdl var_axes %s'\
                    % (est_list[i].name, str(var_axes0), str(msg_hdl_var_axes))
                raise ValueError(errstr)

        # Check if all estimators can compute the cost
        for i, esti in enumerate(self.est_list):
            if self.comp_cost and not esti.cost_avail:
                errstr = "Requested cost computation, but cost_avail==False"\
                    + " for estimator " + str(i)
                raise common.VpException(errstr)
            self.comp_cost = self.comp_cost and esti.cost_avail