def est_map(self, r, rvar, return_cost, ind_out):
    """
    MAP estimation.

    In this case, we wish to minimize

        cost = (z0-r0)^2/(2*rvar0) + (z1-r1)^2/(2*rvar1)

    where z1 = max(0,z0).  The minimizer is found element-wise by solving
    the problem on each branch (z0 >= 0 and z0 <= 0) and keeping the branch
    with the lower cost.

    :param r: Pair :code:`(r0, r1)` of proximal means
    :param rvar: Pair :code:`(rvar0, rvar1)` of proximal variances
    :param boolean return_cost: Flag indicating if :code:`cost` is to be
        returned
    :param ind_out: Collection of output indices; index 0 selects the
        :code:`z0` terms and index 1 selects the :code:`z1` terms
    :returns: :code:`zhat, zhatvar, [cost]`.  :code:`zhat` and
        :code:`zhatvar` are lists with the estimates and the axis-averaged
        variances of the requested outputs.  :code:`cost` is the minimum
        cost summed over all elements.
    """
    # Unpack the terms
    r0, r1 = r
    rvar0, rvar1 = rvar

    # Clip the variance: rvar1 is capped at 1e8*rvar0
    # (presumably so that an infinite rvar1 stays finite -- TODO confirm)
    rvar1 = np.minimum(1e8 * rvar0, rvar1)

    # Reshape the variances
    rvar0 = common.repeat_axes(rvar0, self.shape[0], self.var_axes[0])
    rvar1 = common.repeat_axes(rvar1, self.shape[1], self.var_axes[1])

    # Positive case:  z0 >= 0 and hence z1 = z0
    z0p = np.maximum(0, (rvar0 * r1 + rvar1 * r0) / (rvar0 + rvar1))
    z1p = z0p
    zvar0p = rvar0 * rvar1 / (rvar0 + rvar1)
    zvar1p = zvar0p
    costp = 0.5 * ((z0p - r0)**2 / rvar0 + (z1p - r1)**2 / rvar1)

    # Negative case:  z0 <= 0 and hence z1 = 0
    z0n = np.minimum(0, r0)
    z1n = 0
    zvar0n = rvar0
    zvar1n = 0
    costn = 0.5 * ((z0n - r0)**2 / rvar0 + (z1n - r1)**2 / rvar1)

    # Find the lower cost and select the correct choice for each element.
    # Ip is 1 where the positive branch wins and In is its complement.
    Ip = (costp < costn)
    In = 1 - Ip
    zhat0 = z0p * Ip + z0n * In
    zhat1 = z1p * Ip + z1n * In
    zhatvar0 = zvar0p * Ip + zvar0n * In
    zhatvar1 = zvar1p * Ip + zvar1n * In
    cost = np.sum(costp * Ip + costn * In)

    # Average the variance over the specified axes
    zhatvar0 = np.mean(zhatvar0, axis=self.var_axes[0])
    zhatvar1 = np.mean(zhatvar1, axis=self.var_axes[1])

    # Pack the requested items
    zhat = []
    zhatvar = []
    if 0 in ind_out:
        zhat.append(zhat0)
        zhatvar.append(zhatvar0)
    if 1 in ind_out:
        zhat.append(zhat1)
        zhatvar.append(zhatvar1)

    if not return_cost:
        return zhat, zhatvar
    return zhat, zhatvar, cost
def est_init(self, return_cost=False, ind_out=None, avg_var_cost=True):
    """
    Initial estimator.

    See the base class :class:`vampyre.estim.base.Estim` for
    a complete description.

    :param boolean return_cost: Flag indicating if :code:`cost` is
        to be returned
    :param ind_out: Must be :code:`[0]` or :code:`None`
    :param boolean avg_var_cost: Must be true; disabling the averaging
        is not supported
    :returns: :code:`zmean, zvar, [cost]` which are the
        prior mean and variance
    :raises ValueError: if :code:`ind_out` or :code:`avg_var_cost` has an
        unsupported value
    """
    # Check parameters
    if ind_out is not None and ind_out != [0]:
        raise ValueError("ind_out must be either [0] or None")
    if not avg_var_cost:
        raise ValueError(
            "disabling variance averaging not supported for LinEst")

    # Get the diagonal parameters
    s, sshape, srep_axes = self.A.get_svd_diag()
    shape0 = self.A.shape0

    # Reshape the variances to the transformed space
    s1 = common.repeat_axes(s, sshape, srep_axes)
    wvar1 = common.repeat_axes(self.wvar, sshape, self.wrep_axes,
                               rep=False)

    # Compute the estimate within the transformed space
    q = (1 / s1) * self.p
    qvar = wvar1 / (np.abs(s1)**2)
    qvar_mean = np.mean(qvar, axis=self.var_axes)

    # rdim is the ratio of the number of diagonal terms to the number of
    # input elements.  np.prod is used because np.product was removed
    # in NumPy 2.0.
    rdim = np.prod(sshape) / np.prod(shape0)
    zmean = self.A.Vsvd(q)
    zvar = rdim * qvar_mean + (1 - rdim) * self.rvar_init

    # Exit if cost does not need to be computed
    if not return_cost:
        return zmean, zvar

    # Computes the MAP output cost
    wvar_pos = np.all(self.wvar > 0)
    if wvar_pos:
        cost = self.ypnorm
    else:
        cost = 0

    # Compute the output variance cost
    if wvar_pos and self.map_est:
        clog = np.log(2 * np.pi * self.wvar)
        cost += common.repeat_sum(clog, self.zshape, self.wrep_axes)

    # Scale for real case
    if not self.is_complex:
        cost = 0.5 * cost

    return zmean, zvar, cost
def init_svd(self):
    """
    Set up the cached quantities used by the SVD method.

    Writing the SVD as A=USV', this stores :code:`bt = U'(b)` and the
    residual norm :code:`bpnorm = ||b-UU'(b)||^2/wvar` (zero unless all
    entries of :code:`wvar` are positive), then verifies that every output
    axis is either an SVD repetition axis or a variance-averaging axis.

    :raises common.VpException: if the transform has no SVD or if a
        variance is not constant over an axis on which A operates
    """
    transform = self.A
    if not transform.svd_avail:
        raise common.VpException("Transform must support an SVD")

    # Project b onto the column space coordinates
    self.bt = transform.UsvdH(self.b)
    srep_axes = transform.srep_axes

    # Norm of the component of b outside the range of U, scaled by wvar
    if not np.all(self.wvar > 0):
        self.bpnorm = 0
    else:
        b_proj = transform.Usvd(self.bt)
        wvar_rep = common.repeat_axes(self.wvar, self.shape[1],
                                      self.wrep_axes, rep=False)
        resid_sq = np.abs(self.b - b_proj)**2
        self.bpnorm = np.sum(resid_sq / wvar_rep)

    # Every axis that is not an SVD repetition axis must be averaged over
    # by each of the three variances
    for axis in range(len(self.shape[1])):
        if axis in srep_axes:
            continue
        if axis not in self.var_axes[1]:
            raise common.VpException(
                "Variance must be constant over output axis")
        if axis not in self.wrep_axes:
            raise common.VpException(
                "Noise variance must be constant over output axis")
        if axis not in self.var_axes[0]:
            raise common.VpException(
                "Variance must be constant over input axis")
def est(self, r, rvar, return_cost=False, avg_var_cost=True):
    """
    Estimation function

    The proximal estimation function as
    described in the base class :class:`vampyre.estim.base.Estim`

    :param r: Proximal mean
    :param rvar: Proximal variance
    :param boolean return_cost: Flag indicating if :code:`cost` is
        to be returned
    :param Boolean avg_var_cost: Average variance and cost.
        This should be disabled to obtain per element values.
        (Default=True)

    :returns: :code:`zhat, zhatvar, [cost]` which are the posterior
        mean, variance and optional cost.
    """
    # Infinite variance case: fall back on the initial estimate.
    # np.inf is used because the np.Inf alias was removed in NumPy 2.0.
    if np.any(rvar == np.inf):
        return self.est_init(return_cost, avg_var_cost)

    # Convert to 1D vectors
    r1 = r.ravel()
    rvar1 = common.repeat_axes(rvar, self.shape, self.var_axes)
    rvar1 = rvar1.ravel()

    # Compute the augmented penalty for each value.  Rows index the
    # elements of r and columns index the candidate values in self.zval.
    faug = (np.abs(self.zval[None, :] - r1[:, None])**2) / rvar1[:, None]
    if not self.is_complex:
        faug *= 0.5
    faug = faug + self.fz[None, :]

    # Compute the conditional probability of each value.  The row minimum
    # is subtracted before exponentiating and added back into the cost.
    fmin = np.min(faug, axis=1)
    pzr = np.exp(-faug + fmin[:, None])
    psum = np.sum(pzr, axis=1)
    pzr = pzr / psum[:, None]
    cost = -np.log(psum) + fmin

    # Posterior mean and variance under the conditional probabilities
    zhat = pzr.dot(self.zval)
    zerr = np.abs(self.zval[None, :] - zhat[:, None])**2
    zhatvar = np.sum(pzr * zerr, axis=1)

    # Reshape values
    cost = np.reshape(cost, self.shape)
    zhat = np.reshape(zhat, self.shape)
    zhatvar = np.reshape(zhatvar, self.shape)

    # Keep the conditional probabilities on the instance
    self.pzr = pzr

    # Average values
    if avg_var_cost:
        cost = np.sum(cost)
        zhatvar = np.mean(zhatvar, axis=self.var_axes)

    if return_cost:
        return zhat, zhatvar, cost
    return zhat, zhatvar
def compute_cost_terms(self,idir):
    """
    Computes the Gaussian cost and gradient terms in belief propagation.

    The previously stored values for direction :code:`idir` are used:
    :math:`z` from :code:`zprev`, :math:`r` from :code:`rprev`,
    :math:`\\tau_z` from :code:`zvar_prev` and :math:`\\tau_r` from
    :code:`rvar_prev`.  The results are written to :code:`self.grad[idir]`
    and :code:`self.cost[idir]`; nothing is returned.  If no previous
    variance :math:`\\tau_r` is stored, the method does nothing.

    :param idir: Index of the direction whose terms are updated
    """
    rvar = self.rvar_prev[idir]

    # Skip the update if there is no previous variance
    if rvar is None:
        return

    rvar_rep = common.repeat_axes(rvar, self.shape, self.var_axes,
                                  rep=False)

    # Gradient term
    resid = self.zprev[idir] - self.rprev[idir]
    self.grad[idir] = resid/rvar_rep

    # Quadratic cost term, plus the variance term in the non-MAP case
    total = np.sum((1/rvar_rep)*(np.abs(resid)**2))
    if not self.map_est:
        total += np.prod(self.shape)*np.mean(self.zvar_prev[idir]/rvar)

    # Scale for the real case
    if not self.is_complex:
        total *= 0.5
    self.cost[idir] = total
def est(self, r, rvar, return_cost=False, ind_out=None, avg_var_cost=True):
    """
    Estimation function

    The proximal estimation function as
    described in the base class :class:`vampyre.estim.base.Estim`

    :param r: Proximal mean
    :param rvar: Proximal variance
    :param Boolean return_cost: Flag indicating if :code:`cost` is
        to be returned
    :param ind_out: Must be :code:`[0]` or :code:`None`
    :param Boolean avg_var_cost: Must be true; disabling the averaging
        is not supported

    :returns: :code:`zhat, zhatvar, [cost]` which are the posterior
        mean, variance and optional cost.
    :raises ValueError: if :code:`ind_out` or :code:`avg_var_cost` has an
        unsupported value
    """
    # Check parameters
    if ind_out is not None and ind_out != [0]:
        raise ValueError("ind_out must be either [0] or None")
    if not avg_var_cost:
        raise ValueError(
            "disabling variance averaging not supported for HardThreshEst")

    # Repeat the variance and reshape r and rvar to 1D vectors
    rvar1 = common.repeat_axes(rvar, self.shape, self.var_axes)
    r1 = r.ravel()
    rvar1 = rvar1.ravel()
    y1 = self.y.ravel()

    # Compute the values for Ai
    #   A0i = \int_{-\infty}^thresh z^i exp(-(z-r)^2/(2*rvar))
    #   A1i = \int_thresh^\infty z^i exp(-(z-r)^2/(2*rvar))
    # The upper-tail terms are the lower-tail terms subtracted from
    # rsig*(1, r, rvar + r^2).  This assumes gauss_integral is normalized
    # so that those are the full-line values -- TODO confirm.
    # np.inf is used because the np.Inf alias was removed in NumPy 2.0.
    rsig = np.sqrt(rvar1)
    A00, A01, A02 = gauss_integral(-np.inf, self.thresh, r1, rvar1)
    A10 = rsig - A00
    A11 = rsig * r1 - A01
    A12 = rsig * (rvar1 + r1**2) - A02

    # Compute probability y==1 before flipping
    py1 = y1 * (1 - self.perr) + (1 - y1) * self.perr

    # Set Ai = A0i for y==0 and Ai=A1i for y==1, weighted by py1
    A0 = A10 * py1 + A00 * (1 - py1)
    A1 = A11 * py1 + A01 * (1 - py1)
    A2 = A12 * py1 + A02 * (1 - py1)

    # Compute posterior mean and variance
    zhat = A1 / A0
    zhatvar = A2 / A0 - (zhat**2)
    cost = -np.sum(np.log(A0))

    # Reshape and average values
    zhat = np.reshape(zhat, self.shape)
    zhatvar = np.reshape(zhatvar, self.shape)
    zhatvar = np.mean(zhatvar, axis=self.var_axes)

    if return_cost:
        return zhat, zhatvar, cost
    return zhat, zhatvar
def __init__(self, A, y, wvar=0,
             wrep_axes='all', var_axes=(0,), name=None, map_est=False,
             is_complex=False, rvar_init=1e5, tune_wvar=False):
    # Constructor: stores the transform A, the observation y and the
    # noise settings, then precomputes the SVD-domain terms
    # (p = U'(y) and the out-of-range norm ypnorm) used by the estimator.
    BaseEst.__init__(self, shape=A.shape0, var_axes=var_axes,
                     dtype=A.dtype0, name=name,
                     type_name='LinEstim', nvars=1, cost_avail=True)

    # Input and output shapes of the transform
    self.zshape = A.shape0
    self.yshape = A.shape1

    # Plain attribute copies of the arguments
    self.A = A
    self.y = y
    self.wvar = wvar
    self.map_est = map_est
    self.is_complex = is_complex
    self.cost_avail = True
    self.rvar_init = rvar_init
    self.tune_wvar = tune_wvar

    # Noise repetition axes; 'all' expands to every output axis
    num_axes = len(self.yshape)
    if wrep_axes == 'all':
        wrep_axes = tuple(range(num_axes))
    self.wrep_axes = wrep_axes

    # Take an SVD A=USV'.  Then write p = SV'z + w
    if not A.svd_avail:
        raise common.VpException("Transform must support an SVD")
    self.p = A.UsvdH(y)
    _, _, srep_axes = A.get_svd_diag()

    # Norm ||y-UU*(y)||^2/wvar, only defined for positive noise variance
    if not np.all(self.wvar > 0):
        self.ypnorm = 0
    else:
        y_proj = A.Usvd(self.p)
        wvar_rep = common.repeat_axes(wvar, self.yshape, self.wrep_axes,
                                      rep=False)
        self.ypnorm = np.sum(np.abs(y - y_proj)**2 / wvar_rep)

    # Check that all axes on which A operates are repeated
    for axis in range(num_axes):
        if axis in srep_axes:
            continue
        if axis not in self.wrep_axes:
            raise common.VpException(
                "Variance must be constant over output axis")
        if axis not in self.var_axes:
            raise common.VpException(
                "Variance must be constant over input axis")
def min_primal_avg(self):
    """
    Minimizes the primal average variables, zhat.

    For every message handler between consecutive estimators, this
    updates :code:`self.zhat[i]` as a variance-weighted combination of the
    reverse and forward estimates and records the mean squared sum of the
    two gradients in :code:`self.fgrad[i]`.
    """
    num_msgs = len(self.est_list) - 1
    self.fgrad = np.zeros(num_msgs)

    for i in range(num_msgs):
        hdl = self.msg_hdl_list[i]

        # Variances reshaped for use against the message values
        var_rev = common.repeat_axes(self.rvarrev[i], hdl.shape,
                                     hdl.rep_axes, rep=False)
        var_fwd = common.repeat_axes(self.rvarfwd[i], hdl.shape,
                                     hdl.rep_axes, rep=False)

        # Weighted average of the two shifted estimates
        weighted = (var_rev*(self.zhatrev[i]-self.sfwd[i]) +
                    var_fwd*(self.zhatfwd[i]-self.srev[i]))
        self.zhat[i] = weighted/(var_fwd + var_rev)

        # Compute the gradients
        g_rev = (self.zhatrev[i] - self.rfwd[i]) / var_fwd
        g_fwd = (self.zhatfwd[i] - self.rrev[i]) / var_rev
        self.fgrad[i] = np.mean((g_rev + g_fwd)**2)
def svd_dotH(self, s1, q1):
    """
    Multiplies by the conjugate of the diagonal matrix.

    Implements :math:`q_0 = \\mathrm{diag}(s_1)^* q_1`.

    :param s1: diagonal parameters
    :param q1: input to the diagonal multiplication
    :returns: :code:`q0` diagonal multiplication output
    """
    s_conj = np.conj(s1)
    diag = common.repeat_axes(s_conj, self.shape0, (0, 1), rep=False)
    return diag * q1
def svd_dot(self, s1, q0):
    """
    Multiplies by the diagonal matrix.

    Implements :math:`q_1 = \\mathrm{diag}(s_1) q_0`.

    :param s1: diagonal parameters
    :param q0: input to the diagonal multiplication
    :returns: :code:`q1` diagonal multiplication output
    """
    diag = common.repeat_axes(s1, self.shape0, (0, 1), rep=False)
    return diag * q0
def __init__(self, A, b, rvar, wvar, z0rep_axes, z1rep_axes, wrep_axes,
             shape0, shape1, is_complex):
    # Constructor: stores A, b and the two shapes, precomputes the square
    # roots of the variances, and sets the (nout, nin) shape and the dtype
    # of the transform F.
    self.A = A
    self.b = b
    self.shape0 = shape0
    self.shape1 = shape1
    self.n0 = np.prod(shape0)
    self.n1 = np.prod(shape1)

    # Scale factors: square roots of the proximal variances
    var0, var1 = rvar
    self.rsqrt0 = common.repeat_axes(np.sqrt(var0), self.shape0,
                                     z0rep_axes, rep=False)
    self.rsqrt1 = common.repeat_axes(np.sqrt(var1), self.shape1,
                                     z1rep_axes, rep=False)

    # Dimensions of the transform F.  With positive noise variance there
    # are n1 more inputs and n1 more outputs, and wsqrt is stored too.
    num_in = self.n0
    num_out = self.n0 + self.n1
    self.wvar_pos = np.all(wvar > 0)
    if self.wvar_pos:
        self.wsqrt = common.repeat_axes(np.sqrt(wvar), self.shape1,
                                        wrep_axes, rep=False)
        num_in = self.n0 + self.n1
        num_out = self.n0 + 2 * self.n1
    self.shape = (num_out, num_in)

    self.dtype = np.dtype(complex if is_complex else float)
def cost_adjust(self, r, z, rvar, zvar, shape, var_axes):
    """
    Computes the cost adjustment term for the Bethe Free Energy:

        J = beta*[log(2*pi*rvar) + (|z-r|**2 + zvar)/rvar]

    summed over all elements, where beta = 1 for complex problems
    and beta = 1/2 for real problems.

    :param r: Proximal mean
    :param z: Estimate
    :param rvar: Proximal variance
    :param zvar: Variance of the estimate
    :param shape: Shape of :code:`r` and :code:`z`
    :param var_axes: Axes over which the variances are averaged
    :returns: The adjustment term :code:`J`
    """
    # np.prod is used because np.product was removed in NumPy 2.0
    nelem = np.prod(shape)

    # Log-variance term
    J0 = np.mean(np.log(2 * np.pi * rvar)) * nelem

    # Quadratic term
    rvar_rep = common.repeat_axes(rvar, shape, var_axes, rep=False)
    J1 = np.sum(np.abs(r - z)**2 / rvar_rep)

    # Variance-ratio term
    J2 = np.mean(zvar / rvar) * nelem

    J = J0 + J1 + J2

    # Scale for the real case
    if not self.is_complex:
        J = J / 2
    return J
def est_mmse(self, r, rvar, return_cost, ind_out):
    r"""
    MMSE estimation.

    In the MMSE estimation case, we wish to estimate z0 and z1 with priors
    zi = N(ri,rvari) and z1=f(z0)

    Substituting in z1 = f(z0), we have the density of z0:

        p(z0) \propto qn(z0)1_{z0 < 0} + qp(z0)1_{z0 > 0}

    where

        qp(z0) = exp[-(z0-r0)^2/(2*rvar0) - (z0-r1)^2/(2*rvar1)]
        qn(z0) = exp[-(z0-r0)^2/(2*rvar0) - r1^2/(2*rvar1)]

    First, we complete the squares and write:

        qp(z0) = exp(Amax)*Cp*exp(-(z0-rp)^2/(2*zvarp))/sqrt(2*pi)
        qn(z0) = exp(Amax)*Cn*exp(-(z0-rn)^2/(2*zvarn))/sqrt(2*pi)

    :param r: Pair :code:`(r0, r1)` of proximal means
    :param rvar: Pair :code:`(rvar0, rvar1)` of proximal variances
    :param boolean return_cost: Flag indicating if :code:`cost` is to be
        returned
    :param ind_out: Collection of output indices; index 0 selects the
        :code:`z0` terms and index 1 selects the :code:`z1` terms
    :returns: :code:`zhat, zhatvar, [cost]` where :code:`zhat` and
        :code:`zhatvar` are lists with the requested estimates and their
        axis-averaged variances
    """
    # Unpack the terms
    r0, r1 = r
    rvar0, rvar1 = rvar

    # Compute the MAP estimate, which replaces poorly conditioned points
    # below.  It is computed before branching so that the fallback is also
    # defined in the infinite variance case.
    zhat_map, zvar_map = self.est_map(r, rvar, return_cost=False,
                                      ind_out=[0, 1])
    zhat0_map, zhat1_map = zhat_map
    zvar0_map, zvar1_map = zvar_map

    # Reshape the variances
    rvar0 = common.repeat_axes(rvar0, self.shape[0], self.var_axes[0])
    rvar1 = common.repeat_axes(rvar1, self.shape[1], self.var_axes[1])

    # np.inf is used because the np.Inf alias was removed in NumPy 2.0
    if np.any(rvar1 == np.inf):
        # Infinite variance case: r1 carries no information, so both
        # regions use the prior on z0 alone.
        zvarp = rvar0
        zvarn = rvar0
        rp = r0
        rn = r0
        Cp = 1
        Cn = 1
        Amax = 0
    else:
        # Compute the conditional Gaussian terms for z > 0 and z < 0
        zvarp = rvar0 * rvar1 / (rvar0 + rvar1)
        zvarn = rvar0
        rp = (rvar1 * r0 + rvar0 * r1) / (rvar0 + rvar1)
        rn = r0

        # Compute scaling constants for each region.  The larger exponent
        # is factored out as Amax so that neither np.exp overflows.
        Ap = 0.5 * ((rp**2) / zvarp - (r0**2) / rvar0 - (r1**2) / rvar1)
        An = 0.5 * (-(r1**2) / rvar1)
        Amax = np.maximum(Ap, An)
        Ap = Ap - Amax
        An = An - Amax
        Cp = np.exp(Ap)
        Cn = np.exp(An)

    # Compute moments for each region
    zp = Cp * gauss_integral(0, np.inf, rp, zvarp)
    zn = Cn * gauss_integral(-np.inf, 0, rn, zvarn)

    # Find poorly conditioned points.  Adding Ibad makes the denominator
    # about 1 at those points, so the divisions below stay finite.
    Ibad = (zp[0] + zn[0] < 1e-6)
    zpsum = zp[0] + zn[0] + Ibad

    # Compute mean
    zhat0 = (zp[1] + zn[1]) / zpsum
    zhat1 = zp[1] / zpsum

    # Compute the variance
    zhatvar0 = (zp[2] + zn[2]) / zpsum - zhat0**2
    zhatvar1 = zp[2] / zpsum - zhat1**2

    # Replace bad points with MAP estimate
    zhat0 = zhat0 * (1 - Ibad) + zhat0_map * Ibad
    zhat1 = zhat1 * (1 - Ibad) + zhat1_map * Ibad
    zhatvar0 = zhatvar0 * (1 - Ibad) + zvar0_map * Ibad
    zhatvar1 = zhatvar1 * (1 - Ibad) + zvar1_map * Ibad

    # Average the variance over the specified axes
    zhatvar0 = np.mean(zhatvar0, axis=self.var_axes[0])
    zhatvar1 = np.mean(zhatvar1, axis=self.var_axes[1])

    # Pack the requested items
    zhat = []
    zhatvar = []
    if 0 in ind_out:
        zhat.append(zhat0)
        zhatvar.append(zhatvar0)
    if 1 in ind_out:
        zhat.append(zhat1)
        zhatvar.append(zhatvar1)

    if not return_cost:
        return zhat, zhatvar

    # Compute the cost
    #     cost = -log int p(z0) = -Amax - log(zp[0] + zn[0])
    # summed over all elements.  zpsum is zp[0] + zn[0] except at the
    # poorly conditioned points, where it also includes the +1 from Ibad.
    nz = np.prod(self.shape[0])
    cost = -nz * np.mean(Amax + np.log(zpsum))
    return zhat, zhatvar, cost
def msg_sub(self,z,zvar,idir=0):
    """
    Variance subtraction for message passing.

    See base class :class:`MsgHdl` for more details.  The incoming values
    are stored for direction :code:`idir`, and the damped, bounded outgoing
    message is stored for the opposite direction and returned.

    :param z: Mean from the factor node
    :param zvar: Variance from the factor node
    :param idir: Index of the factor node for incoming message
    :returns: :code:`r1, rvar1`, the outgoing mean and variance
    """
    # With a fixed variance (no variance damping), overwrite zvar so that
    # it is consistent with the two stored prior variances.
    if self.damp_var == 0:
        prev0 = self.rvar_prev[0]
        prev1 = self.rvar_prev[1]
        zvar = prev0*prev1/(prev0 + prev1)

    # Remember the incoming message
    self.zprev[idir] = z
    self.zvar_prev[idir] = zvar

    # Message previously passed to the incoming node
    r_in = self.rprev[idir]
    rvar_in = self.rvar_prev[idir]

    if rvar_in is not None:
        # Compute cost and gradient
        self.compute_cost_terms(idir)

        # Threshold the decrease in variance
        alpha = np.maximum(self.alpha_min, zvar/rvar_in)
        alpha = np.minimum(self.alpha_max, alpha)

        # Output variance and mean
        rvar_out = alpha/(1-alpha)*rvar_in
        alpha_rep = common.repeat_axes(alpha,self.shape,self.var_axes,
                                       rep=False)
        r_out = (z-alpha_rep*r_in)/(1-alpha_rep)
    else:
        # Special case where there is no previous variance
        rvar_out = zvar
        r_out = z

    # Bound the variance
    rvar_out = np.minimum(np.maximum(rvar_out, self.rvar_min),
                          self.rvar_max)

    # Damp against the last outgoing message, if there is one.  The
    # variance is damped in the precision (inverse variance) domain.
    jdir = (idir+1)%2
    r_last = self.rprev[jdir]
    rvar_last = self.rvar_prev[jdir]
    if r_last is not None:
        r_out = self.damp*r_out + (1-self.damp)*r_last
    if rvar_last is not None:
        gam_last = 1/rvar_last
        gam_out = self.damp_var/rvar_out + (1-self.damp_var)*gam_last
        rvar_out = 1/gam_out

    # Bound the variance again after damping
    rvar_out = np.minimum(np.maximum(rvar_out, self.rvar_min),
                          self.rvar_max)

    # Save outgoing message
    self.rprev[jdir] = r_out
    self.rvar_prev[jdir] = rvar_out

    return r_out, rvar_out
def est(self, r, rvar, return_cost=False, ind_out=None, avg_var_cost=True):
    """
    Estimation function

    The proximal estimation function as
    described in the base class :class:`vampyre.estim.base.Estim`

    :param r: Proximal mean
    :param rvar: Proximal variance
    :param Boolean return_cost: Flag indicating if :code:`cost` is
        to be returned
    :param ind_out: Must be :code:`[0]` or :code:`None`
    :param Boolean avg_var_cost: Must be true; disabling the averaging
        is not supported

    :returns: :code:`zhat, zhatvar, [cost]` which are the posterior
        mean, variance and optional cost.
    :raises ValueError: if :code:`ind_out` or :code:`avg_var_cost` has an
        unsupported value
    """
    # Check parameters
    if ind_out is not None and ind_out != [0]:
        raise ValueError("ind_out must be either [0] or None")
    if not avg_var_cost:
        raise ValueError(
            "disabling variance averaging not supported for LinEst")

    # Get the diagonal parameters
    s, sshape, srep_axes = self.A.get_svd_diag()

    # Get dimensions
    nz = np.prod(self.zshape)
    ny = np.prod(self.yshape)
    ns = np.prod(sshape)

    # Reshape the variances to the transformed space
    s1 = common.repeat_axes(s, sshape, srep_axes, rep=False)
    rvar1 = common.repeat_axes(rvar, sshape, self.var_axes, rep=False)
    wvar1 = common.repeat_axes(self.wvar, sshape, self.wrep_axes,
                               rep=False)

    # Compute the estimate within the transformed space
    qbar = self.A.VsvdH(r)
    d = 1 / (rvar1 * (np.abs(s1)**2) + wvar1)
    q = d * (rvar1 * s1.conj() * self.p + wvar1 * qbar)
    qvar = rvar1 * wvar1 * d
    qvar_mean = np.mean(qvar, axis=self.var_axes)

    zhat = self.A.Vsvd(q - qbar) + r
    zhatvar = ns / nz * qvar_mean + (1 - ns / nz) * rvar

    # Update the variance estimate if tuning is enabled
    if self.tune_wvar:
        yerr = np.abs(self.y - self.A.Usvd(s1 * q))**2
        self.wvar = np.mean(yerr, self.wrep_axes) + np.mean(
            qvar * (np.abs(s1)**2), self.wrep_axes)

    # Exit if cost does not need to be computed
    if not return_cost:
        return zhat, zhatvar

    # Evaluated once, after the optional tuning step above, since
    # self.wvar does not change from here on.
    wvar_pos = np.all(self.wvar > 0)

    # Computes the MAP output cost
    if wvar_pos:
        err = np.abs(self.p - s1 * q)**2
        cost = self.ypnorm + np.sum(err / wvar1)
    else:
        cost = 0

    # Add the MAP input cost
    err = np.abs(q - qbar)**2
    cost = cost + np.sum(err / rvar1)

    # Compute the variance cost.
    if self.map_est:
        # For the MAP case, this is log(2*pi*wvar)
        if wvar_pos:
            cost += ny * np.mean(np.log(2 * np.pi * self.wvar))
    else:
        # For the MMSE case, this is 1 + log(2*pi*wvar) - H(b)
        # where b is the Gaussian with variance wvar*rvar*d
        cost += (-ns * np.mean(np.log(rvar1 * d))
                 - (nz - ns) * np.mean(np.log(2 * np.pi * rvar1)))
        if wvar_pos:
            cost += (ny - ns) * np.mean(np.log(2 * np.pi * self.wvar))

    # Scale for real case
    if not self.is_complex:
        cost = 0.5 * cost

    return zhat, zhatvar, cost
def est_svd(self, r, rvar, return_cost=False, ind_out=[0, 1]):
    """
    SVD-based estimation function

    The proximal estimation function as
    described in the base class :class:`vampyre.estim.base.Estim`

    :param r: Pair :code:`(r0, r1)` of proximal means
    :param rvar: Pair :code:`(rvar0, rvar1)` of proximal variances
    :param Boolean return_cost: Flag indicating if :code:`cost` is
        to be returned
    :param ind_out: Collection of output indices (0 and/or 1) selecting
        which estimates are returned.  It is only read, never mutated.

    :returns: :code:`zhat, zhatvar, [cost]` which are the posterior
        mean, variance and optional cost.
    """

    def select(pair):
        # Keep the entries of (item0, item1) whose index is in ind_out
        return [pair[i] for i in (0, 1) if i in ind_out]

    # Unpack the variables
    r0, r1 = r
    rvar0, rvar1 = rvar

    # Get the diagonal parameters
    s, sshape, srep_axes = self.A.get_svd_diag()

    # Get dimensions
    nz1 = np.prod(self.shape[1])
    nz0 = np.prod(self.shape[0])
    ns = np.prod(sshape)

    # Reshape the variances to match the dimensions of the first-order
    # terms
    s_rep = common.repeat_axes(s, sshape, srep_axes, rep=False)
    rvar0_rep = common.repeat_axes(rvar0, self.shape[0],
                                   self.var_axes[0], rep=False)
    rvar1_rep = common.repeat_axes(rvar1, self.shape[1],
                                   self.var_axes[1], rep=False)
    wvar_rep = common.repeat_axes(self.wvar, sshape, self.wrep_axes,
                                  rep=False)

    # To compute the estimates, we write:
    #
    #     z0 = V*q0 + z0perp,  z1 = U*q1 + z1perp.
    #
    # We separately compute the estimates of q=(q0,q1), z0perp and z1perp.
    #
    # First, we compute the estimates for q=(q0,q1).  The joint density
    # of q is given by p(q) propto exp(-J(q)) where
    #
    #     J(q) = ||q1-s*q0 - bt||^2/wvar + ||q1-q1bar||^2/rvar1
    #            + ||q0-q0bar||^2/rvar0
    #
    # where q0bar = V'(r0), q1bar = U'(r1), bt = U'(b).  Now define
    #
    #     c := [-s, 1]
    #     P = diag(rvar0,rvar1)
    #
    # Hence
    #
    #     J(q) = ||c'q-bt||^2/wvar + (q-qbar)'P^{-1}(q-qbar)
    #          = q'Q^{-1}q - 2q'*g + const
    #
    # where
    #
    #     Q = (P^{-1} + cc'/wvar)^{-1} = P - Pcc'P/(wvar + c'Pc)
    #     g = bt*c/wvar + P^{-1}qbar
    #
    # Now, we can verify the following:
    #
    #     Q = cov(q)
    #     d := 1/(wvar + c'Pc) = 1/(wvar + rvar1 + |s|^2*rvar0)
    #     Q*c/wvar = d*P*c = d*[-rvar0*s, rvar1]
    #     Q*P^{-1}*qbar = qbar - Pc d*c'*qbar
    #                   = qbar - [-rvar0*s,rvar1]*d*(qbar1-s*qbar0)
    #
    # Hence
    #
    #     qhat = E(q) = qbar - [-rvar0*s,rvar1]*d*(qbar1-s*qbar0-bt)

    # Infinite variance case.  np.inf is used because the np.Inf alias
    # was removed in NumPy 2.0.
    if np.any(rvar1 == np.inf):
        zhat0 = r0
        zhatvar0 = rvar0
        zhat1 = self.A.dot(r0) + self.b
        qvar1 = np.mean(
            np.abs(s_rep)**2 * rvar0_rep + wvar_rep, self.var_axes[1])
        zhatvar1 = ns / nz1 * qvar1 + (1 - ns / nz1) * self.wvar

        # TODO: the cost is not computed in this case; 0 is a placeholder
        cost = 0

        zhat = select((zhat0, zhat1))
        zhatvar = select((zhatvar0, zhatvar1))
        if return_cost:
            return zhat, zhatvar, cost
        return zhat, zhatvar

    # Compute the offset terms
    qbar0 = self.A.VsvdH(r0)
    qbar1 = self.A.UsvdH(r1)

    # Compute E(q)
    d = 1 / ((np.abs(s_rep)**2) * rvar0_rep + rvar1_rep + wvar_rep)
    e = d * (qbar1 - s_rep * qbar0 - self.bt)
    qhat0 = qbar0 + rvar0_rep * s_rep * e
    qhat1 = qbar1 - rvar1_rep * e

    # Compute E(z)
    #
    #     zhat0 = (I-VV')*r0 + V*qhat0 = r0 + V*(qhat0-qbar0)
    #     zhat1 = (I-UU')*z1p + U*qhat1 = z1p + U*(qhat1-qbar1p)
    #
    # where z1p = (wvar*r1+rvar1*b)/(wvar+rvar1) and qbar1p is the same
    # combination of qbar1 and bt.
    z1p = (wvar_rep * r1 + rvar1_rep * self.b) / (wvar_rep + rvar1_rep)
    qbar1p = (wvar_rep * qbar1 + rvar1_rep * self.bt) / (
        wvar_rep + rvar1_rep)
    zhat0 = r0 + self.A.Vsvd(qhat0 - qbar0)
    zhat1 = z1p + self.A.Usvd(qhat1 - qbar1p)
    zhat = select((zhat0, zhat1))

    # Compute the variance.  From the above calculations, we have
    #
    #     cov(q) = Q = P - Pcc'P/(wvar + c'Pc)
    #     var(q0) = rvar0 - d*(rvar0^2)|s|^2 = rvar0*(1-d*rvar0*|s|^2)
    #     var(q1) = rvar1 - d*(rvar1^2) = rvar1*(1-d*rvar1)
    qvar0 = rvar0_rep * (1 - d * rvar0_rep * np.abs(s_rep)**2)
    qvar1 = rvar1_rep * (1 - d * rvar1_rep)
    qvar0 = np.mean(qvar0, axis=self.var_axes[0])
    qvar1 = np.mean(qvar1, axis=self.var_axes[1])

    # Compute the variance of z
    zhatvar0 = ns / nz0 * qvar0 + (1 - ns / nz0) * rvar0
    zhatvar1 = ns / nz1 * qvar1 + (1 - ns / nz1) * rvar1 * self.wvar / (
        self.wvar + rvar1)
    zhatvar = select((zhatvar0, zhatvar1))

    if not return_cost:
        return zhat, zhatvar

    # Compute costs from the first order terms:
    #
    #     cost1_perp = min_{z1} ||(I-UU')*(b-z1)||^2/wvar
    #                           + ||(I-UU')*(r-z1)||^2/rvar1
    #                = ||(I-UU')*(b-r1)||^2/(wvar+rvar1)
    #                = ||b-r1 - U*(bt-qbar1)||^2/(wvar+rvar1)
    #     costq = ||q1-s*q0-bt||^2/wvar
    #     cost0_perp = 0
    e = (self.b - r1 - self.A.Usvd(self.bt - qbar1))
    cost1_perp = np.sum((np.abs(e)**2) / (wvar_rep + rvar1_rep))
    if np.all(self.wvar > 0):
        e = qhat1 - s_rep * qhat0 - self.bt
        costq = np.sum((np.abs(e)**2) / wvar_rep)
    else:
        costq = 0
    cost0 = np.sum((np.abs(qhat0 - qbar0)**2) / rvar0_rep)
    cost1 = np.sum((np.abs(qhat1 - qbar1)**2) / rvar1_rep)
    cost = cost1_perp + costq + cost0 + cost1

    # Compute the costs for the second-order terms.
    #
    # For the MAP case, the code below adds nz1*log(cscale*pi*wvar).
    # For the MMSE case, we compute the Gaussian entropies:
    #
    #     H1p = H((I-UU')*z1) - (nz-ns)log(2*pi*wvar)
    #     Hq  = H(q) - ns*log(2*pi*wvar)
    #     H0p = H((I-VV')*z0)
    #     cost = cost - H1p - Hq - H0p
    if self.is_complex:
        cscale = 1
    else:
        cscale = 2
    if self.map_est:
        if np.all(self.wvar > 0):
            cost += nz1 * np.mean(np.log(cscale * np.pi * self.wvar))
    else:
        a = cscale * np.pi
        H1p = (nz1 - ns) * np.mean(np.log(rvar1 / (rvar1 + self.wvar)))
        Hq = ns * np.mean(np.log(a * rvar1_rep * rvar0_rep * d))
        H0p = (nz0 - ns) * np.mean(np.log(a * rvar0))
        cost = cost - H1p - Hq - H0p

    # Scale by 2 for the real case
    cost /= cscale

    return zhat, zhatvar, cost
def solve(self):
    """
    Runs the main GAMP algorithm.

    The final estimates are saved in :code:`z0` and :code:`z1`
    along with variances :code:`zvar0` and :code:`zvar1`.  Each iteration
    also updates the cost, the constraint value :code:`con`, the step size
    and the timing attributes, and calls :code:`save_hist`.
    """
    # The cost is only computed when both estimators can provide it
    if not (self.est0.cost_avail and self.est1.cost_avail):
        self.comp_cost = False

    # Initial estimate from the input node
    if self.comp_cost:
        z0, zvar0, cost0 = self.est0.est_init(return_cost=True)
    else:
        z0, zvar0 = self.est0.est_init(return_cost=False)
        cost0 = 0
    self.z0, self.zvar0, self.cost0 = z0, zvar0, cost0

    # Initialize other variables
    self.var_cost0 = 0
    self.var_cost1 = 0
    self.cost = 0
    self.s = np.zeros(self.shape1)

    for it in range(self.nit):
        tic = time.time()

        # Forward transform to est1
        rvar1_new = self.A.var_dot(self.zvar0)
        rvar1_rep = common.repeat_axes(rvar1_new, self.shape1,
                                       self.var_axes1, rep=False)
        z1_mult = self.A.dot(self.z0)
        r1_new = z1_mult - rvar1_rep * self.s

        # Damping (nothing to damp against on the first iteration)
        if it == 0:
            self.r1 = r1_new
            self.rvar1 = rvar1_new
        else:
            self.r1 = (1 - self.step) * self.r1 + self.step * r1_new
            self.rvar1 = (1 - self.step) * self.rvar1 + \
                self.step * rvar1_new

        # Estimator 1
        if self.comp_cost:
            z1, zvar1, cost1 = self.est1.est(self.r1, self.rvar1,
                                             return_cost=True)
            if not self.map_est:
                cost1 -= self.cost_adjust(self.r1, z1, self.rvar1, zvar1,
                                          self.shape1, self.var_axes1)
        else:
            z1, zvar1 = self.est1.est(self.r1, self.rvar1,
                                      return_cost=False)
            cost1 = 0
        self.z1 = z1
        self.zvar1 = zvar1
        self.cost1 = cost1
        con_new = np.mean(np.abs(z1 - z1_mult)**2)

        # Reverse nonlinear transform to est 0
        self.s = (self.z1 - self.r1) / rvar1_rep
        self.sprec = 1 / self.rvar1 * (1 - self.zvar1 / self.rvar1)
        mid = time.time()
        self.time_est1 = mid - tic

        # Reverse linear transform to est 0
        rvar0_new = 1 / self.A.var_dotH(self.sprec)
        rvar0_rep = common.repeat_axes(rvar0_new, self.shape0,
                                       self.var_axes0, rep=False)
        r0_new = self.z0 + rvar0_rep * self.A.dotH(self.s)

        # Damping
        if it == 0:
            self.r0 = r0_new
            self.rvar0 = rvar0_new
        else:
            self.r0 = (1 - self.step) * self.r0 + self.step * r0_new
            self.rvar0 = (1 - self.step) * self.rvar0 + \
                self.step * rvar0_new

        # Estimator 0
        if self.comp_cost:
            z0, zvar0, cost0 = self.est0.est(self.r0, self.rvar0,
                                             return_cost=True)
            if not self.map_est:
                cost0 -= self.cost_adjust(self.r0, z0, self.rvar0, zvar0,
                                          self.shape0, self.var_axes0)
        else:
            z0, zvar0 = self.est0.est(self.r0, self.rvar0,
                                      return_cost=False)
            cost0 = 0
        self.z0 = z0
        self.zvar0 = zvar0
        self.cost0 = cost0

        # Compute total cost and constraint
        cost_new = self.cost0 + self.cost1
        if not self.map_est:
            cost_new += self.cost_gauss()

        # Step size adaptation: grow the step while the constraint value
        # decreases, shrink it otherwise
        if self.step_adapt and it > 0:
            if con_new < self.con:
                self.step = np.minimum(1, self.step_inc * self.step)
            else:
                self.step = np.maximum(self.step_min,
                                       self.step_dec * self.step)
        self.cost = cost_new
        self.con = con_new

        toc = time.time()
        self.time_est0 = toc - mid
        self.time_iter = toc - tic

        # Print progress
        if self.prt_period > 0 and it % self.prt_period == 0:
            if self.comp_cost:
                print("it={0:4d} cost={1:12.4e} con={2:12.4e} step={3:12.4e}".format(
                    it, self.cost, self.con, self.step))
            else:
                print("it={0:4d} con={1:12.4e}".format(
                    it, self.con))

        # Save history
        self.save_hist()