Ejemplo n.º 1
0
    def est_map(self, r, rvar, return_cost, ind_out):
        """
        MAP estimate of the pair (z0, z1) subject to z1 = max(0, z0).

        Element by element, the method minimizes

            cost = (z0-r0)^2/(2*rvar0) + (z1-r1)^2/(2*rvar1)

        by solving the problem once under the hypothesis z0 >= 0 and once
        under z0 <= 0, and keeping whichever solution is cheaper.

        :param r: Pair :code:`(r0, r1)` of proximal means
        :param rvar: Pair :code:`(rvar0, rvar1)` of proximal variances
        :param return_cost: Flag indicating if :code:`cost` is returned
        :param ind_out: Container holding the indices (0 and/or 1) of the
            variables whose estimates are returned
        :returns: :code:`zhat, zhatvar, [cost]` where :code:`zhat` and
            :code:`zhatvar` are lists with one entry per requested index
            (variances averaged over :code:`self.var_axes`) and
            :code:`cost` is the summed minimum cost
        """
        mean0, mean1 = r
        var0, var1 = rvar

        # Keep the second variance within a factor of 1e8 of the first
        var1 = np.minimum(1e8 * var0, var1)

        # Expand the variances with the project helper
        var0 = common.repeat_axes(var0, self.shape[0], self.var_axes[0])
        var1 = common.repeat_axes(var1, self.shape[1], self.var_axes[1])

        # Hypothesis z0 >= 0:  z1 = z0, so both share one value and variance
        pos_z = np.maximum(0, (var0 * mean1 + var1 * mean0) / (var0 + var1))
        pos_var = var0 * var1 / (var0 + var1)
        pos_cost = 0.5 * ((pos_z - mean0)**2 / var0 +
                          (pos_z - mean1)**2 / var1)

        # Hypothesis z0 <= 0:  z1 = 0 and only z0 is free
        neg_z0 = np.minimum(0, mean0)
        neg_cost = 0.5 * ((neg_z0 - mean0)**2 / var0 +
                          (0 - mean1)**2 / var1)

        # Per-element indicator of the cheaper hypothesis
        take_pos = (pos_cost < neg_cost)
        take_neg = 1 - take_pos
        zhat0 = pos_z * take_pos + neg_z0 * take_neg
        zhat1 = pos_z * take_pos + 0 * take_neg
        zhatvar0 = pos_var * take_pos + var0 * take_neg
        zhatvar1 = pos_var * take_pos + 0 * take_neg
        cost = np.sum(pos_cost * take_pos + neg_cost * take_neg)

        # Average the variances over the specified axes
        zhatvar0 = np.mean(zhatvar0, axis=self.var_axes[0])
        zhatvar1 = np.mean(zhatvar1, axis=self.var_axes[1])

        # Collect only the requested outputs
        zhat = []
        zhatvar = []
        candidates = ((0, zhat0, zhatvar0), (1, zhat1, zhatvar1))
        for index, estimate, variance in candidates:
            if index in ind_out:
                zhat.append(estimate)
                zhatvar.append(variance)

        if return_cost:
            return zhat, zhatvar, cost
        return zhat, zhatvar
Ejemplo n.º 2
0
    def est_init(self, return_cost=False, ind_out=None, avg_var_cost=True):
        """
        Initial estimator.

        See the base class :class:`vampyre.estim.base.Estim` for
        a complete description.

        :param boolean return_cost:  Flag indicating if :code:`cost` is
            to be returned
        :param ind_out:  Must be :code:`None` or :code:`[0]`
        :param boolean avg_var_cost:  Must be true; disabling the
            averaging is not supported by this estimator
        :returns: :code:`zmean, zvar, [cost]` which are the
            prior mean and variance
        :raises ValueError:  If :code:`ind_out` or :code:`avg_var_cost`
            has an unsupported value
        """
        # Check parameters
        if (ind_out is not None) and (ind_out != [0]):
            raise ValueError("ind_out must be either [0] or None")
        if not avg_var_cost:
            raise ValueError(
                "disabling variance averaging not supported for LinEst")

        # Get the diagonal parameters
        s, sshape, srep_axes = self.A.get_svd_diag()
        shape0 = self.A.shape0

        # Reshape the variances to the transformed space
        s1 = common.repeat_axes(s, sshape, srep_axes)
        wvar1 = common.repeat_axes(self.wvar,
                                   sshape,
                                   self.wrep_axes,
                                   rep=False)

        # Compute the estimate within the transformed space
        q = (1 / s1) * self.p
        qvar = wvar1 / (np.abs(s1)**2)
        qvar_mean = np.mean(qvar, axis=self.var_axes)

        # Fraction of the input dimensions covered by the diagonal terms.
        # np.prod is used because the np.product alias was removed in
        # NumPy 2.0.
        rdim = np.prod(sshape) / np.prod(shape0)
        zmean = self.A.Vsvd(q)
        zvar = rdim * qvar_mean + (1 - rdim) * self.rvar_init

        # Exit if cost does not need to be computed
        if not return_cost:
            return zmean, zvar

        wvar_pos = np.all(self.wvar > 0)

        # Computes the MAP output cost
        cost = self.ypnorm if wvar_pos else 0

        # Compute the output variance cost
        if wvar_pos and self.map_est:
            clog = np.log(2 * np.pi * self.wvar)
            cost += common.repeat_sum(clog, self.zshape, self.wrep_axes)

        # Scale for real case
        if not self.is_complex:
            cost = 0.5 * cost
        return zmean, zvar, cost
Ejemplo n.º 3
0
    def init_svd(self):
        """
        Prepares the quantities used by the SVD-based method.

        Stores :code:`self.bt` (the bias mapped through :code:`A.UsvdH`)
        and :code:`self.bpnorm` (the noise-weighted energy of the part of
        :code:`b` that :code:`A.Usvd(A.UsvdH(.))` does not reproduce), then
        checks the variance axes against the SVD repetition axes.

        :raises common.VpException: If the transform has no SVD available
            or a variance is not constant over an axis that is not one of
            the SVD repetition axes
        """
        transform = self.A

        # With an SVD A=USV', the model is rewritten as p = SV'z + w
        if not transform.svd_avail:
            raise common.VpException("Transform must support an SVD")
        self.bt = transform.UsvdH(self.b)
        svd_axes = transform.srep_axes

        # Norm ||b-UU*(b)||^2/wvar, only defined for positive noise variance
        if np.all(self.wvar > 0):
            residual = self.b - transform.Usvd(self.bt)
            noise_var = common.repeat_axes(self.wvar,
                                           self.shape[1],
                                           self.wrep_axes,
                                           rep=False)
            self.bpnorm = np.sum(np.abs(residual)**2 / noise_var)
        else:
            self.bpnorm = 0

        # Every output axis must be an SVD repetition axis or appear in all
        # three variance axis sets
        for axis in range(len(self.shape[1])):
            if axis in svd_axes:
                continue
            if axis not in self.var_axes[1]:
                raise common.VpException(
                    "Variance must be constant over output axis")
            if axis not in self.wrep_axes:
                raise common.VpException(
                    "Noise variance must be constant over output axis")
            if axis not in self.var_axes[0]:
                raise common.VpException(
                    "Variance must be constant over input axis")
Ejemplo n.º 4
0
    def est(self, r, rvar, return_cost=False, avg_var_cost=True):
        """
        Estimation function

        The proximal estimation function as
        described in the base class :class:`vampyre.estim.base.Estim`

        :param r: Proximal mean
        :param rvar: Proximal variance
        :param boolean return_cost:  Flag indicating if :code:`cost`
            is to be returned
        :param Boolean avg_var_cost: Average variance and cost.
            This should be disabled to obtain per element values.
            (Default=True)

        :returns: :code:`zhat, zhatvar, [cost]` which are the posterior
            mean, variance and optional cost.
        """
        # Infinite variance case:  defer to the initial estimator.
        # np.inf is used because the np.Inf alias was removed in NumPy 2.0.
        if np.any(rvar == np.inf):
            return self.est_init(return_cost, avg_var_cost)

        # Convert to 1D vectors
        r1 = r.ravel()
        rvar1 = common.repeat_axes(rvar, self.shape, self.var_axes)
        rvar1 = rvar1.ravel()

        # Compute the augmented penalty for each (element, value) pair
        faug = (np.abs(self.zval[None, :] - r1[:, None])**2) / rvar1[:, None]
        if not self.is_complex:
            faug *= 0.5
        faug = faug + self.fz[None, :]

        # Compute the conditional probability of each value.  The row
        # minimum is subtracted before exponentiating and added back in
        # the cost.
        fmin = np.min(faug, axis=1)
        pzr = np.exp(-faug + fmin[:, None])
        psum = np.sum(pzr, axis=1)
        pzr = pzr / psum[:, None]
        cost = -np.log(psum) + fmin

        # Posterior mean and variance under the conditional probabilities
        zhat = pzr.dot(self.zval)
        zerr = np.abs(self.zval[None, :] - zhat[:, None])**2
        zhatvar = np.sum(pzr * zerr, axis=1)

        # Reshape values
        cost = np.reshape(cost, self.shape)
        zhat = np.reshape(zhat, self.shape)
        zhatvar = np.reshape(zhatvar, self.shape)

        # Keep the conditional probabilities on the instance
        self.pzr = pzr

        # Average values
        if avg_var_cost:
            cost = np.sum(cost)
            zhatvar = np.mean(zhatvar, axis=self.var_axes)

        if return_cost:
            return zhat, zhatvar, cost
        return zhat, zhatvar
Ejemplo n.º 5
0
 def compute_cost_terms(self,idir):
     """
     Updates the stored gradient and Gaussian cost for one direction.

     See base class :class:`MsgHdl` for more details.

     The method works on the previously stored values
     :code:`zprev[idir]`, :code:`rprev[idir]`, :code:`zvar_prev[idir]`
     and :code:`rvar_prev[idir]`, and does nothing when no previous
     variance is stored.

     :param idir:  Index selecting which stored messages are used
     :returns: Nothing.  The results are written to
         :code:`self.grad[idir]` and :code:`self.cost[idir]`.
     """
     prev_rvar = self.rvar_prev[idir]

     # Nothing to do until a previous variance is available
     if prev_rvar is None:
         return

     rvar_full = common.repeat_axes(prev_rvar, self.shape, self.var_axes,
                                    rep=False)

     # Gradient and quadratic cost of the previous message pair
     resid = self.zprev[idir] - self.rprev[idir]
     self.grad[idir] = resid/rvar_full
     total = np.sum((1/rvar_full)*(np.abs(resid)**2))

     # Outside MAP estimation add the variance ratio term
     if not self.map_est:
         ratio = np.mean(self.zvar_prev[idir]/prev_rvar)
         total += np.prod(self.shape)*ratio

     # Real-valued problems carry a factor of one half
     if not self.is_complex:
         total *= 0.5

     self.cost[idir] = total
Ejemplo n.º 6
0
    def est(self, r, rvar, return_cost=False, ind_out=None, avg_var_cost=True):
        """
        Estimation function

        The proximal estimation function as
        described in the base class :class:`vampyre.estim.base.Estim`

        :param r: Proximal mean
        :param rvar: Proximal variance
        :param Boolean return_cost:  Flag indicating if :code:`cost` is
            to be returned
        :param ind_out:  Must be :code:`None` or :code:`[0]`
        :param Boolean avg_var_cost:  Must be true; disabling the
            averaging is not supported by this estimator

        :returns: :code:`zhat, zhatvar, [cost]` which are the posterior
            mean, variance and optional cost.
        :raises ValueError:  If :code:`ind_out` or :code:`avg_var_cost`
            has an unsupported value
        """
        # Check parameters
        if (ind_out is not None) and (ind_out != [0]):
            raise ValueError("ind_out must be either [0] or None")
        if not avg_var_cost:
            raise ValueError(
                "disabling variance averaging not supported for HardThreshEst")

        # Repeat the variance and reshape r and rvar to 1D vectors
        rvar1 = common.repeat_axes(rvar, self.shape, self.var_axes)
        r1 = r.ravel()
        rvar1 = rvar1.ravel()
        y1 = self.y.ravel()

        # Compute the values for Ai
        #   A0i = \int_{-\infty}^thresh z^i exp(-(z-r)^2/(2*rvar))
        #   A1i = \int_thresh^\infty z^i exp(-(z-r)^2/(2*rvar))
        # The upper-region values are obtained by subtracting the
        # lower-region values from rsig-scaled full-line moments.
        # np.inf is used because the np.Inf alias was removed in NumPy 2.0.
        rsig = np.sqrt(rvar1)
        A00, A01, A02 = gauss_integral(-np.inf, self.thresh, r1, rvar1)
        A10 = rsig - A00
        A11 = rsig * r1 - A01
        A12 = rsig * (rvar1 + r1**2) - A02

        # Compute probability y==1 before flipping
        py1 = y1 * (1 - self.perr) + (1 - y1) * self.perr

        # Set Ai = A0i for y==0 and Ai=A1i for y==1
        A0 = A10 * py1 + A00 * (1 - py1)
        A1 = A11 * py1 + A01 * (1 - py1)
        A2 = A12 * py1 + A02 * (1 - py1)

        # Compute posterior mean and variance
        zhat = A1 / A0
        zhatvar = A2 / A0 - (zhat**2)
        cost = -np.sum(np.log(A0))

        # Reshape and average values
        zhat = np.reshape(zhat, self.shape)
        zhatvar = np.reshape(zhatvar, self.shape)
        zhatvar = np.mean(zhatvar, axis=self.var_axes)

        if return_cost:
            return zhat, zhatvar, cost
        return zhat, zhatvar
Ejemplo n.º 7
0
    def __init__(self, A, y, wvar=0,
                 wrep_axes='all', var_axes=(0,), name=None, map_est=False,
                 is_complex=False, rvar_init=1e5, tune_wvar=False):
        """
        Stores the transform, data and noise settings and precomputes the
        SVD-domain terms.

        :param A:  Transform object.  It must expose :code:`shape0`,
            :code:`shape1`, :code:`dtype0`, :code:`svd_avail`,
            :code:`UsvdH`, :code:`Usvd` and :code:`get_svd_diag`.
        :param y:  Data array passed through :code:`A.UsvdH`
        :param wvar:  Noise variance (Default=0)
        :param wrep_axes:  Axes used when expanding :code:`wvar`, or
            :code:`'all'` for every axis of the output shape
        :param var_axes:  Variance axes forwarded to the base class
        :param name:  Name forwarded to the base class
        :param map_est:  Stored as :code:`self.map_est`
        :param is_complex:  Stored as :code:`self.is_complex`
        :param rvar_init:  Stored as :code:`self.rvar_init`
        :param tune_wvar:  Stored as :code:`self.tune_wvar`
        :raises common.VpException:  If the transform has no SVD or a
            variance axis requirement is violated
        """
        BaseEst.__init__(self, shape=A.shape0, var_axes=var_axes,
                         dtype=A.dtype0, name=name,
                         type_name='LinEstim', nvars=1, cost_avail=True)

        # Plain configuration
        self.A = A
        self.y = y
        self.wvar = wvar
        self.map_est = map_est
        self.is_complex = is_complex
        self.cost_avail = True
        self.rvar_init = rvar_init
        self.tune_wvar = tune_wvar

        # Input and output shapes of the transform
        self.zshape = A.shape0
        self.yshape = A.shape1

        # Resolve the noise repetition axes
        num_axes = len(self.yshape)
        if wrep_axes == 'all':
            wrep_axes = tuple(range(num_axes))
        self.wrep_axes = wrep_axes

        # With an SVD A=USV', the model is rewritten as p = SV'z + w
        if not A.svd_avail:
            raise common.VpException("Transform must support an SVD")
        self.p = A.UsvdH(y)
        svd_axes = A.get_svd_diag()[2]

        # Norm ||y-UU*(y)||^2/wvar, only defined for positive noise variance
        if np.all(self.wvar > 0):
            residual = y - A.Usvd(self.p)
            noise_var = common.repeat_axes(wvar,
                                           self.yshape,
                                           self.wrep_axes,
                                           rep=False)
            self.ypnorm = np.sum(np.abs(residual)**2 / noise_var)
        else:
            self.ypnorm = 0

        # Each axis must be an SVD repetition axis or appear in both the
        # noise and the variance axis sets
        for axis in range(num_axes):
            if axis in svd_axes:
                continue
            if axis not in self.wrep_axes:
                raise common.VpException(
                    "Variance must be constant over output axis")
            if axis not in self.var_axes:
                raise common.VpException(
                    "Variance must be constant over input axis")
Ejemplo n.º 8
0
    def min_primal_avg(self):
        """
        Minimizes the primal average variables, zhat

        For every pair of adjacent layers this rewrites
        :code:`self.zhat[i]` as a variance-weighted combination of the
        reverse and forward estimates (each shifted by the corresponding
        :code:`sfwd` / :code:`srev` term) and stores the mean squared sum
        of the two gradients in :code:`self.fgrad[i]`.
        """
        num_msg = len(self.est_list) - 1
        self.fgrad = np.zeros(num_msg)

        for k in range(num_msg):
            handler = self.msg_hdl_list[k]

            # Expand both variances to the handler's shape
            var_rev = common.repeat_axes(self.rvarrev[k], handler.shape,
                                         handler.rep_axes, rep=False)
            var_fwd = common.repeat_axes(self.rvarfwd[k], handler.shape,
                                         handler.rep_axes, rep=False)

            # Variance-weighted average of the two shifted estimates
            weighted = (var_rev*(self.zhatrev[k]-self.sfwd[k]) +
                        var_fwd*(self.zhatfwd[k]-self.srev[k]))
            self.zhat[k] = weighted/(var_fwd + var_rev)

            # Mean squared sum of the two gradients
            grad_rev = (self.zhatrev[k] - self.rfwd[k]) / var_fwd
            grad_fwd = (self.zhatfwd[k] - self.rrev[k]) / var_rev
            self.fgrad[k] = np.mean((grad_rev + grad_fwd)**2)
Ejemplo n.º 9
0
    def svd_dotH(self, s1, q1):
        """
        Multiplies by the conjugate of the diagonal matrix.

        Implements :math:`q_0 = \\mathrm{diag}(s_1)^* q_1`.

        :param s1: diagonal parameters
        :param q1: input to the diagonal multiplication
        :returns: :code:`q0` diagonal multiplication output
        """
        # Conjugate first, then expand to the input shape via the helper
        conj_diag = np.conj(s1)
        expanded = common.repeat_axes(conj_diag, self.shape0, (0, 1),
                                      rep=False)
        return expanded * q1
Ejemplo n.º 10
0
    def svd_dot(self, s1, q0):
        """
        Multiplies by the diagonal matrix.

        Implements :math:`q_1 = \\mathrm{diag}(s_1) q_0`.

        :param s1: diagonal parameters
        :param q0: input to the diagonal multiplication
        :returns: :code:`q1` diagonal multiplication output
        """
        # Expand the diagonal to the input shape via the helper
        expanded = common.repeat_axes(s1, self.shape0, (0, 1), rep=False)
        return expanded * q0
Ejemplo n.º 11
0
    def __init__(self, A, b, rvar, wvar, z0rep_axes, z1rep_axes, wrep_axes,
                 shape0, shape1, is_complex):
        """
        Stores the transform data and precomputes the scale factors and the
        dimensions of the stacked transform F.

        :param A:  Transform, stored as :code:`self.A`
        :param b:  Bias term, stored as :code:`self.b`
        :param rvar:  Pair :code:`(rvar0, rvar1)` of variances
        :param wvar:  Noise variance
        :param z0rep_axes:  Axes used when expanding :code:`sqrt(rvar0)`
        :param z1rep_axes:  Axes used when expanding :code:`sqrt(rvar1)`
        :param wrep_axes:  Axes used when expanding :code:`sqrt(wvar)`
        :param shape0:  Shape used for the first variable
        :param shape1:  Shape used for the second variable
        :param is_complex:  Selects a complex (true) or float (false)
            :code:`dtype`
        """
        self.A = A
        self.b = b
        self.shape0 = shape0
        self.shape1 = shape1
        self.n0 = np.prod(shape0)
        self.n1 = np.prod(shape1)

        # Square-root scale factors, expanded to each variable's shape
        var0, var1 = rvar
        self.rsqrt0 = common.repeat_axes(np.sqrt(var0), self.shape0,
                                         z0rep_axes, rep=False)
        self.rsqrt1 = common.repeat_axes(np.sqrt(var1), self.shape1,
                                         z1rep_axes, rep=False)

        # The noise scale and the extra rows/columns of F exist only when
        # the noise variance is strictly positive
        self.wvar_pos = np.all(wvar > 0)
        if self.wvar_pos:
            self.wsqrt = common.repeat_axes(np.sqrt(wvar), self.shape1,
                                            wrep_axes, rep=False)
            num_in = self.n0 + self.n1
            num_out = self.n0 + 2 * self.n1
        else:
            num_in = self.n0
            num_out = self.n0 + self.n1
        self.shape = (num_out, num_in)

        self.dtype = np.dtype(complex if is_complex else float)
Ejemplo n.º 12
0
 def cost_adjust(self, r, z, rvar, zvar, shape, var_axes):
     """
     Computes the cost adjustment term for the
     Bethe Free Energy:

         J = beta*[log(2*pi*rvar) + (|z-r|**2 + zvar)/rvar]

     summed over all elements, where beta = 1 for complex problems
     and 1/2 for real problems.

     :param r:  Mean :code:`r`
     :param z:  Estimate :code:`z`
     :param rvar:  Variance associated with :code:`r`
     :param zvar:  Variance associated with :code:`z`
     :param shape:  Shape whose number of elements scales the averaged terms
     :param var_axes:  Axes used when expanding :code:`rvar`
     :returns:  The scalar adjustment :code:`J`
     """
     # np.prod is used because the np.product alias was removed in
     # NumPy 2.0.
     nelem = np.prod(shape)

     # Log-variance term
     J0 = np.mean(np.log(2 * np.pi * rvar)) * nelem

     # Quadratic term, using the variance expanded to the full shape
     rvar_rep = common.repeat_axes(rvar, shape, var_axes, rep=False)
     J1 = np.sum(np.abs(r - z)**2 / rvar_rep)

     # Variance ratio term
     J2 = np.mean(zvar / rvar) * nelem

     J = J0 + J1 + J2
     if not self.is_complex:
         J = J / 2
     return J
Ejemplo n.º 13
0
    def est_mmse(self, r, rvar, return_cost, ind_out):
        r"""
        In the MMSE estimation case, we wish to estimate
        z0 and z1 with priors zi = N(ri,rvari) and z1=f(z0)

        Substituting in z1 = f(z0), we have the density of z0:

           p(z0)  \propto qn(z0)1_{z0 < 0}  + qp(z0)1_{z0 > 0}

        where

           qp(z0)  = exp[-(z0-r0)^2/(2*rvar0) - (z0-r1)^2/(2*rvar1)]
           qn(z0)  = exp[-(z0-r0)^2/(2*rvar0) - r1^2/(2*rvar1)]

        First, we complete the squares and write:

           qp(z0) = exp(Amax)*Cp*exp(-(z0-rp)^2/(2*zvarp))/sqrt(2*pi)
           qn(z0) = exp(Amax)*Cn*exp(-(z0-rn)^2/(2*zvarn))/sqrt(2*pi)

        :param r: Pair :code:`(r0, r1)` of proximal means
        :param rvar: Pair :code:`(rvar0, rvar1)` of proximal variances
        :param return_cost: Flag indicating if :code:`cost` is returned
        :param ind_out: Container holding the indices (0 and/or 1) of the
            variables whose estimates are returned
        :returns: :code:`zhat, zhatvar, [cost]` where :code:`zhat` and
            :code:`zhatvar` are lists with one entry per requested index
        """

        # Unpack the terms
        r0, r1 = r
        rvar0, rvar1 = rvar

        # Reshape the variances
        rvar0 = common.repeat_axes(rvar0, self.shape[0], self.var_axes[0])
        rvar1 = common.repeat_axes(rvar1, self.shape[1], self.var_axes[1])

        # np.inf is used because the np.Inf alias was removed in NumPy 2.0
        if np.any(rvar1 == np.inf):
            # Infinite variance case:  r1 carries no information, so both
            # regions use the prior on z0 with equal weights.
            zvarp = rvar0
            zvarn = rvar0
            rp = r0
            rn = r0
            Cp = 1
            Cn = 1
            Amax = 0

        else:

            # Compute the conditional Gaussian terms for z > 0 and z < 0
            zvarp = rvar0 * rvar1 / (rvar0 + rvar1)
            zvarn = rvar0
            rp = (rvar1 * r0 + rvar0 * r1) / (rvar0 + rvar1)
            rn = r0

            # Compute scaling constants for each region.  The larger
            # exponent is factored out as Amax before exponentiating.
            Ap = 0.5 * ((rp**2) / zvarp - (r0**2) / rvar0 - (r1**2) / rvar1)
            An = 0.5 * (-(r1**2) / rvar1)
            Amax = np.maximum(Ap, An)
            Ap = Ap - Amax
            An = An - Amax
            Cp = np.exp(Ap)
            Cn = np.exp(An)

        # Compute moments for each region
        zp = Cp * gauss_integral(0, np.inf, rp, zvarp)
        zn = Cn * gauss_integral(-np.inf, 0, rn, zvarn)

        # Find poorly conditioned points.  Adding Ibad keeps the
        # denominator away from zero at those points.
        Ibad = (zp[0] + zn[0] < 1e-6)
        zpsum = zp[0] + zn[0] + Ibad

        # Compute mean
        zhat0 = (zp[1] + zn[1]) / zpsum
        zhat1 = zp[1] / zpsum

        # Compute the variance
        zhatvar0 = (zp[2] + zn[2]) / zpsum - zhat0**2
        zhatvar1 = zp[2] / zpsum - zhat1**2

        # Replace bad points with the MAP estimate.  The MAP estimate is
        # computed here, for both branches above, so that it is always
        # defined when needed (previously it was missing in the infinite
        # variance case), and only when at least one point requires it.
        if np.any(Ibad):
            zhat_map, zvar_map = self.est_map(r,
                                              rvar,
                                              return_cost=False,
                                              ind_out=[0, 1])
            zhat0_map, zhat1_map = zhat_map
            zvar0_map, zvar1_map = zvar_map

            zhat0 = zhat0 * (1 - Ibad) + zhat0_map * Ibad
            zhat1 = zhat1 * (1 - Ibad) + zhat1_map * Ibad
            zhatvar0 = zhatvar0 * (1 - Ibad) + zvar0_map * Ibad
            zhatvar1 = zhatvar1 * (1 - Ibad) + zvar1_map * Ibad

        # Average the variance over the specified axes
        zhatvar0 = np.mean(zhatvar0, axis=self.var_axes[0])
        zhatvar1 = np.mean(zhatvar1, axis=self.var_axes[1])

        # Pack the items
        zhat = []
        zhatvar = []
        if 0 in ind_out:
            zhat.append(zhat0)
            zhatvar.append(zhatvar0)
        if 1 in ind_out:
            zhat.append(zhat1)
            zhatvar.append(zhatvar1)

        if not return_cost:
            return zhat, zhatvar

        # Compute the
        #     cost = -log int p(z_0)
        #          = -Amax - log(zp[0] + zn[0])
        # summed over all elements.
        # NOTE(review): the log term previously entered with the opposite
        # sign, contradicting the formula above; confirm against any stored
        # reference cost values.
        nz = np.prod(self.shape[0])
        cost = -nz * np.mean(Amax + np.log(zpsum))
        return zhat, zhatvar, cost
Ejemplo n.º 14
0
    def msg_sub(self,z,zvar,idir=0):
        """
        Variance subtraction for message passing

        See base class :class:`MsgHdl` for more details.

        :param z: Mean from the factor node
        :param zvar: Variance from the factor node
        :param idir: Index of the factor node for incoming message
        :returns: :code:`r1, rvar1`, the damped and bounded outgoing mean
            and variance, which are also stored as the previous message
            for the opposite direction
        """

        # If variance is fixed, then overwrite the variance to be consistent
        # with the prior variances.  This is only possible once both prior
        # variances exist; before that the supplied zvar is used as is.
        prev0 = self.rvar_prev[0]
        prev1 = self.rvar_prev[1]
        if self.damp_var == 0 and prev0 is not None and prev1 is not None:
            zvar = prev0*prev1/(prev0 + prev1)

        # Save the z value
        self.zprev[idir] = z
        self.zvar_prev[idir] = zvar

        # Get variance passed to the incoming node
        r0 = self.rprev[idir]
        rvar0 = self.rvar_prev[idir]

        if rvar0 is None:
            # Special case where there is no previous variance
            rvar1 = zvar
            r1 = z
        else:

            # Compute cost and gradient
            self.compute_cost_terms(idir)

            # Threshold the decrease in variance
            alpha = zvar/rvar0
            alpha = np.maximum(self.alpha_min, alpha)
            alpha = np.minimum(self.alpha_max, alpha)

            # Compute output variance
            rvar1 = alpha/(1-alpha)*rvar0

            # Compute the message
            alpha_rep = common.repeat_axes(alpha,self.shape,self.var_axes,rep=False)
            r1 = (z-alpha_rep*r0)/(1-alpha_rep)

        # Bound the variance
        rvar1 = np.maximum(rvar1, self.rvar_min)
        rvar1 = np.minimum(rvar1, self.rvar_max)

        # Get the last output mean and variance
        jdir = (idir+1)%2
        rvar1_prev = self.rvar_prev[jdir]
        r1_prev = self.rprev[jdir]

        # Damp the mean directly and the variance through its inverse
        if r1_prev is not None:
            r1 = self.damp*r1 + (1-self.damp)*r1_prev
        if rvar1_prev is not None:
            gam1_prev = 1/rvar1_prev
            gam1 = self.damp_var/rvar1 + (1-self.damp_var)*gam1_prev
            rvar1 = 1/gam1

        # Bound the variance again, since damping may have moved it
        rvar1 = np.maximum(rvar1, self.rvar_min)
        rvar1 = np.minimum(rvar1, self.rvar_max)

        # Save outgoing message
        self.rprev[jdir] = r1
        self.rvar_prev[jdir] = rvar1

        return r1, rvar1
Ejemplo n.º 15
0
    def est(self, r, rvar, return_cost=False, ind_out=None,
            avg_var_cost=True):
        """
        Estimation function

        The proximal estimation function as
        described in the base class :class:`vampyre.estim.base.Estim`

        :param r: Proximal mean
        :param rvar: Proximal variance
        :param Boolean return_cost:  Flag indicating if :code:`cost` is
            to be returned
        :param ind_out:  Must be :code:`None` or :code:`[0]`
        :param Boolean avg_var_cost:  Must be true; disabling the
            averaging is not supported by this estimator

        :returns: :code:`zhat, zhatvar, [cost]` which are the posterior
            mean, variance and optional cost.
        :raises ValueError:  If :code:`ind_out` or :code:`avg_var_cost`
            has an unsupported value
        """

        # Check parameters
        if (ind_out is not None) and (ind_out != [0]):
            raise ValueError("ind_out must be either [0] or None")
        if not avg_var_cost:
            raise ValueError(
                "disabling variance averaging not supported for LinEst")

        # Get the diagonal parameters
        s, sshape, srep_axes = self.A.get_svd_diag()

        # Get dimensions
        nz = np.prod(self.zshape)
        ny = np.prod(self.yshape)
        ns = np.prod(sshape)

        # Reshape the variances to the transformed space
        s1 = common.repeat_axes(s, sshape, srep_axes, rep=False)
        rvar1 = common.repeat_axes(rvar, sshape, self.var_axes, rep=False)
        wvar1 = common.repeat_axes(self.wvar,
                                   sshape,
                                   self.wrep_axes,
                                   rep=False)
        # Positivity of the noise variance that wvar1 was built from.  The
        # data-cost term below divides by wvar1, so its guard must refer to
        # this value rather than to a wvar updated by the tuning step.
        wvar1_pos = np.all(self.wvar > 0)

        # Compute the estimate within the transformed space
        qbar = self.A.VsvdH(r)
        d = 1 / (rvar1 * (np.abs(s1)**2) + wvar1)
        q = d * (rvar1 * s1.conj() * self.p + wvar1 * qbar)
        qvar = rvar1 * wvar1 * d
        qvar_mean = np.mean(qvar, axis=self.var_axes)

        # Map back, and mix in rvar for the dimensions not covered by the
        # diagonal terms
        zhat = self.A.Vsvd(q - qbar) + r
        zhatvar = ns / nz * qvar_mean + (1 - ns / nz) * rvar

        # Update the variance estimate if tuning is enabled
        if self.tune_wvar:
            yerr = np.abs(self.y - self.A.Usvd(s1 * q))**2
            self.wvar = np.mean(yerr, self.wrep_axes) + np.mean(
                qvar * (np.abs(s1)**2), self.wrep_axes)

        # Exit if cost does not need to be computed
        if not return_cost:
            return zhat, zhatvar

        # Positivity of the current (possibly just tuned) noise variance,
        # used for the log-variance terms that read self.wvar directly
        wvar_pos = np.all(self.wvar > 0)

        # Computes the MAP output cost
        if wvar1_pos:
            err = np.abs(self.p - s1 * q)**2
            cost = self.ypnorm + np.sum(err / wvar1)
        else:
            cost = 0

        # Add the MAP input cost
        err = np.abs(q - qbar)**2
        cost = cost + np.sum(err / rvar1)

        # Compute the variance cost.
        if self.map_est:
            # For the MAP case, this is log(2*pi*wvar)
            if wvar_pos:
                cost += ny * np.mean(np.log(2 * np.pi * self.wvar))
        else:
            # For the MMSE case, this is 1 + log(2*pi*wvar) - H(b)
            # where b is the Gaussian with variance wvar*rvar*d
            cost += (-ns * np.mean(np.log(rvar1 * d))
                     - (nz - ns) * np.mean(np.log(2 * np.pi * rvar1)))
            if wvar_pos:
                cost += (ny - ns) * np.mean(np.log(2 * np.pi * self.wvar))

        # Scale for real case
        if not self.is_complex:
            cost = 0.5 * cost

        return zhat, zhatvar, cost
Ejemplo n.º 16
0
    def est_svd(self, r, rvar, return_cost=False, ind_out=(0, 1)):
        """
        SVD-based estimation function

        The proximal estimation function as
        described in the base class :class:`vampyre.estim.base.Estim`

        :param r: Proximal mean, a pair :code:`(r0, r1)`
        :param rvar: Proximal variance, a pair :code:`(rvar0, rvar1)`
        :param Boolean return_cost:  Flag indicating if :code:`cost` is
            to be returned
        :param ind_out:  Container holding the indices (0 and/or 1) of the
            variables whose estimates are returned

        :returns: :code:`zhat, zhatvar, [cost]` which are the posterior
            mean, variance and optional cost.
        """

        def pack(first, second):
            # Keep only the entries requested through ind_out
            selected = []
            if 0 in ind_out:
                selected.append(first)
            if 1 in ind_out:
                selected.append(second)
            return selected

        # Unpack the variables
        r0, r1 = r
        rvar0, rvar1 = rvar

        # Get the diagonal parameters
        s, sshape, srep_axes = self.A.get_svd_diag()

        # Get dimensions
        nz1 = np.prod(self.shape[1])
        nz0 = np.prod(self.shape[0])
        ns = np.prod(sshape)

        # Reshape the variances to match the dimensions of the first-order
        # terms
        s_rep = common.repeat_axes(s, sshape, srep_axes, rep=False)
        rvar0_rep = common.repeat_axes(rvar0,
                                       self.shape[0],
                                       self.var_axes[0],
                                       rep=False)
        rvar1_rep = common.repeat_axes(rvar1,
                                       self.shape[1],
                                       self.var_axes[1],
                                       rep=False)
        wvar_rep = common.repeat_axes(self.wvar,
                                      sshape,
                                      self.wrep_axes,
                                      rep=False)

        # To compute the estimates, we write:
        #
        #     z0 = V*q0 + z0perp,  z1 = U*q1 + z1perp.
        #
        # We separately compute the estimates of q=(q0,q1), z0perp and
        # z1perp.
        #
        # First, we compute the estimates for q=(q0,q1).  The joint density
        # of q is given by p(q) \propto exp(-J(q)) where
        #
        #     J(q) = ||q1-s*q0 - bt||^2/wvar + ||q1-q1bar||^2/rvar1
        #          + ||q0-q0bar||^2/rvar0
        # where q0bar = V'(r0), q1bar = U'(r1), bt = U'(b).
        #
        # Now define c such that c'q = q1 - s*q0 and P = diag(rvar0,rvar1).
        # Hence
        #     J(q) = ||c'q-bt||^2/wvar + (q-qbar)'P^{-1}(q-qbar)
        #          = q'Q^{-1}q - 2q'*g + const
        # where
        #     Q = (P^{-1} + cc'/wvar)^{-1}  = P - Pcc'P/(wvar + c'Pc)
        #     g = bt*c/wvar + P^{-1}qbar
        #
        # With d = 1/(wvar + rvar1 + |s|^2*rvar0) we can verify:
        #    Q = cov(q)
        #    Q*c/wvar = d*P*c = d*[-rvar0*conj(s), rvar1]
        #    Q*P^{-1}*qbar = qbar - Pc*d*c'*qbar
        #        = qbar - [-rvar0*conj(s),rvar1]*d*(qbar1-s*qbar0)
        # Hence
        #    qhat = E(q)
        #         = qbar - [-rvar0*conj(s),rvar1]*d*(qbar1-s*qbar0-bt)

        # Infinite variance case.
        # np.inf is used because the np.Inf alias was removed in NumPy 2.0.
        if np.any(rvar1 == np.inf):
            zhat0 = r0
            zhatvar0 = rvar0
            zhat1 = self.A.dot(r0) + self.b
            qvar1 = np.mean(
                np.abs(s_rep)**2 * rvar0_rep + wvar_rep, self.var_axes[1])
            zhatvar1 = ns / nz1 * qvar1 + (1 - ns / nz1) * self.wvar
            # TODO: the cost is not computed in this case; 0 is returned
            cost = 0
            zhat = pack(zhat0, zhat1)
            zhatvar = pack(zhatvar0, zhatvar1)
            if return_cost:
                return zhat, zhatvar, cost
            return zhat, zhatvar

        # Compute the offset terms
        qbar0 = self.A.VsvdH(r0)
        qbar1 = self.A.UsvdH(r1)

        # Compute E(q).  The conjugate of s is required for complex
        # diagonal terms and leaves real ones unchanged.
        d = 1 / ((np.abs(s_rep)**2) * rvar0_rep + rvar1_rep + wvar_rep)
        e = d * (qbar1 - s_rep * qbar0 - self.bt)
        qhat0 = qbar0 + rvar0_rep * np.conj(s_rep) * e
        qhat1 = qbar1 - rvar1_rep * e

        # Compute E(z):
        #
        # zhat0 = (I-VV')*r0 + V*qhat0 = r0 + V*(qhat0-qbar0)
        # zhat1 = z1p + U*(qhat1-qbar1p)
        # where z1p = (wvar*r1+rvar1*b)/(wvar+rvar1) and qbar1p is the
        # same combination formed from qbar1 and bt.
        z1p = (wvar_rep * r1 + rvar1_rep * self.b) / (wvar_rep + rvar1_rep)
        qbar1p = (wvar_rep * qbar1 + rvar1_rep * self.bt) / (wvar_rep +
                                                             rvar1_rep)
        zhat0 = r0 + self.A.Vsvd(qhat0 - qbar0)
        zhat1 = z1p + self.A.Usvd(qhat1 - qbar1p)
        zhat = pack(zhat0, zhat1)

        # Compute the variance.
        # From the above calculations, we have
        #     cov(q) = Q = P - Pcc'P/(wvar + c'Pc)
        #
        # var(q0) = rvar0 - d*(rvar0^2)|s|^2 = rvar0*(1-d*rvar0*|s|^2)
        # var(q1) = rvar1 - d*(rvar1^2) = rvar1*(1-d*rvar1)
        qvar0 = rvar0_rep * (1 - d * rvar0_rep * np.abs(s_rep)**2)
        qvar1 = rvar1_rep * (1 - d * rvar1_rep)
        qvar0 = np.mean(qvar0, axis=self.var_axes[0])
        qvar1 = np.mean(qvar1, axis=self.var_axes[1])

        # Compute the variance of z by mixing the transformed-domain
        # variance with the variance of the remaining dimensions
        zhatvar0 = ns / nz0 * qvar0 + (1 - ns / nz0) * rvar0
        zhatvar1 = ns / nz1 * qvar1 + (1 - ns / nz1) * rvar1 * self.wvar / (
            self.wvar + rvar1)
        zhatvar = pack(zhatvar0, zhatvar1)

        if not return_cost:
            return zhat, zhatvar

        # Compute costs from the first order terms:
        #
        # cost1_perp = min_{z1} ||(I-UU')*(b-z1)||^2/wvar
        #                     + ||(I-UU')*(r1-z1)||^2/rvar1
        #       = ||(I-UU')*(b-r1)||^2/(wvar+rvar1)
        #       = ||b-r1 - U*(bt-qbar1)||^2/(wvar+rvar1)
        # costq = ||q1-s*q0-bt||^2/wvar
        # cost0_perp = 0
        e = (self.b - r1 - self.A.Usvd(self.bt - qbar1))
        cost1_perp = np.sum((np.abs(e)**2) / (wvar_rep + rvar1_rep))
        if np.all(self.wvar > 0):
            e = qhat1 - s_rep * qhat0 - self.bt
            costq = np.sum((np.abs(e)**2) / wvar_rep)
        else:
            costq = 0
        cost0 = np.sum((np.abs(qhat0 - qbar0)**2) / rvar0_rep)
        cost1 = np.sum((np.abs(qhat1 - qbar1)**2) / rvar1_rep)
        cost = cost1_perp + costq + cost0 + cost1

        # Compute the costs for the second-order terms.
        #
        # For the MAP case, nz1*log(cscale*pi*wvar) is added when the
        # noise variance is positive.
        #
        # For the MMSE case, we compute the Gaussian entropies:
        #     H1p = H((I-UU')*z1) - (nz1-ns)log(cscale*pi*wvar)
        #     Hq  = H(q) - ns*log(cscale*pi*wvar)
        #     H0p = H((I-VV')*z0)
        #     cost = cost - H1p - Hq - H0p
        # with cscale = 1 for complex and 2 for real problems.
        if self.is_complex:
            cscale = 1
        else:
            cscale = 2
        if self.map_est:
            if np.all(self.wvar > 0):
                cost += nz1 * np.mean(np.log(cscale * np.pi * self.wvar))
        else:
            a = cscale * np.pi
            H1p = (nz1 - ns) * np.mean(np.log(rvar1 / (rvar1 + self.wvar)))
            Hq = ns * np.mean(np.log(a * rvar1_rep * rvar0_rep * d))
            H0p = (nz0 - ns) * np.mean(np.log(a * rvar0))
            cost = cost - H1p - Hq - H0p

        # Scale by 2 for the real case
        cost /= cscale

        return zhat, zhatvar, cost
Ejemplo n.º 17
0
    def solve(self):
        """
        Run the GAMP iterations.

        Starting from the initial estimate produced by :code:`est0`, each
        of the :code:`self.nit` iterations performs a forward pass through
        :code:`self.A` into :code:`est1`, followed by a reverse pass back
        into :code:`est0`.  On every iteration after the first, the new
        messages are blended with the previous ones using :code:`self.step`.

        The final estimates are stored in :code:`z0` and :code:`z1`,
        with their variances in :code:`zvar0` and :code:`zvar1`.
        Cost computation is disabled when either estimator reports that
        it cannot supply a cost.
        """

        # Cost tracking requires both estimators to be able to report a cost
        if not (self.est0.cost_avail and self.est1.cost_avail):
            self.comp_cost = False

        def damp(prev, new):
            # Convex combination of the previous and new message, weighted
            # by the step size in effect at the time of the call
            return (1 - self.step) * prev + self.step * new

        # Initial estimate from the input node
        if self.comp_cost:
            z0, zvar0, cost0 = self.est0.est_init(return_cost=True)
        else:
            z0, zvar0 = self.est0.est_init(return_cost=False)
            cost0 = 0
        self.z0, self.zvar0, self.cost0 = z0, zvar0, cost0

        # Remaining solver state
        self.var_cost0 = 0
        self.var_cost1 = 0
        self.cost = 0
        self.s = np.zeros(self.shape1)

        for it in range(self.nit):
            first_pass = (it == 0)

            # Forward linear transform towards est1
            t_start = time.time()
            rvar1_prop = self.A.var_dot(self.zvar0)
            rvar1_rep = common.repeat_axes(
                rvar1_prop, self.shape1, self.var_axes1, rep=False)
            z1_mult = self.A.dot(self.z0)
            r1_prop = z1_mult - rvar1_rep * self.s

            # Damping: the first iteration has nothing to blend with
            if first_pass:
                self.r1 = r1_prop
                self.rvar1 = rvar1_prop
            else:
                self.r1 = damp(self.r1, r1_prop)
                self.rvar1 = damp(self.rvar1, rvar1_prop)

            # Estimator 1
            if not self.comp_cost:
                z1, zvar1 = self.est1.est(
                    self.r1, self.rvar1, return_cost=False)
                cost1 = 0
            else:
                z1, zvar1, cost1 = self.est1.est(
                    self.r1, self.rvar1, return_cost=True)
                if not self.map_est:
                    cost1 -= self.cost_adjust(
                        self.r1, z1, self.rvar1, zvar1,
                        self.shape1, self.var_axes1)
            self.z1, self.zvar1, self.cost1 = z1, zvar1, cost1

            # Mean squared mismatch between est1's output and A.dot(z0)
            con_now = np.mean(np.abs(z1 - z1_mult)**2)

            # Reverse nonlinear transform towards est0
            # NOTE(review): s is scaled by rvar1_rep, which is built from this
            # iteration's undamped variance, whereas sprec uses the damped
            # self.rvar1 -- confirm that this asymmetry is intentional.
            self.s = (self.z1 - self.r1) / rvar1_rep
            self.sprec = 1 / self.rvar1 * (1 - self.zvar1 / self.rvar1)
            t_mid = time.time()
            self.time_est1 = t_mid - t_start

            # Reverse linear transform towards est0
            rvar0_prop = 1 / self.A.var_dotH(self.sprec)
            rvar0_rep = common.repeat_axes(
                rvar0_prop, self.shape0, self.var_axes0, rep=False)
            r0_prop = self.z0 + rvar0_rep * self.A.dotH(self.s)

            # Damping: the first iteration has nothing to blend with
            if first_pass:
                self.r0 = r0_prop
                self.rvar0 = rvar0_prop
            else:
                self.r0 = damp(self.r0, r0_prop)
                self.rvar0 = damp(self.rvar0, rvar0_prop)

            # Estimator 0
            if not self.comp_cost:
                z0, zvar0 = self.est0.est(
                    self.r0, self.rvar0, return_cost=False)
                cost0 = 0
            else:
                z0, zvar0, cost0 = self.est0.est(
                    self.r0, self.rvar0, return_cost=True)
                if not self.map_est:
                    cost0 -= self.cost_adjust(
                        self.r0, z0, self.rvar0, zvar0,
                        self.shape0, self.var_axes0)
            self.z0, self.zvar0, self.cost0 = z0, zvar0, cost0

            # Total cost for this iteration
            cost_now = self.cost0 + self.cost1
            if not self.map_est:
                cost_now += self.cost_gauss()

            # Step size adaptation: compare against the previous iteration's
            # constraint value, so this must run before self.con is updated
            if self.step_adapt and not first_pass:
                if con_now < self.con:
                    self.step = np.minimum(1, self.step_inc * self.step)
                else:
                    self.step = np.maximum(
                        self.step_min, self.step_dec * self.step)
            self.cost = cost_now
            self.con = con_now

            # Timing for the est0 half and for the whole iteration
            t_end = time.time()
            self.time_est0 = t_end - t_mid
            self.time_iter = t_end - t_start

            # Print progress every prt_period iterations (disabled if <= 0)
            if self.prt_period > 0 and it % self.prt_period == 0:
                if self.comp_cost:
                    print("it={0:4d} cost={1:12.4e} con={2:12.4e} step={3:12.4e}".format(
                        it, self.cost, self.con, self.step))
                else:
                    print("it={0:4d} con={1:12.4e}".format(it, self.con))

            # Save history
            self.save_hist()