def __init__(self, est1, est2, msg_hdl=None, hist_list=None, nit=10,
             comp_cost=False, prt_period=0, is_complex=False, map_est=False):
    """
    Constructor for a solver that couples two estimators via a message handler.

    :param est1: First estimator.  Must expose ``shape``, ``var_axes`` and
        ``name``.
    :param est2: Second estimator.  Its ``shape`` and ``var_axes`` must
        equal those of ``est1``.
    :param msg_hdl: Message handler.  When ``None``, an
        ``estim.MsgHdlSimp`` is created from ``is_complex``, ``map_est``
        and ``est1.shape``.  Its ``shape`` and ``var_axes`` must equal
        those of ``est1``.
    :param hist_list: Passed to ``Solver.__init__``.  ``None`` (default)
        means an empty list; a fresh list is created on every call so
        instances never share a default list object.
    :param nit: Stored as ``self.nit``.
    :param comp_cost: Stored as ``self.comp_cost``.
    :param prt_period: Stored as ``self.prt_period``.
    :param is_complex: Only used when building the default message handler.
    :param map_est: Only used when building the default message handler.
    :raises common.VpException: if the shapes or ``var_axes`` of ``est1``,
        ``est2`` and the message handler do not all agree.
    """
    if hist_list is None:
        hist_list = []
    Solver.__init__(self, hist_list)
    self.est1 = est1
    self.est2 = est2
    if msg_hdl is None:
        msg_hdl = estim.MsgHdlSimp(is_complex=is_complex, map_est=map_est,
                                   shape=self.est1.shape)
    self.msg_hdl = msg_hdl
    self.nit = nit
    self.comp_cost = comp_cost
    self.prt_period = prt_period

    # Check if dimensions match
    if self.est1.shape != self.est2.shape:
        err_str = '%s shape %s does not match %s shape %s' %\
            (self.est1.name, str(self.est1.shape),
             self.est2.name, str(self.est2.shape))
        raise common.VpException(err_str)
    if self.est1.shape != self.msg_hdl.shape:
        err_str = '%s shape %s does not match msg_hdl shape %s' %\
            (self.est1.name, str(self.est1.shape), str(self.msg_hdl.shape))
        raise common.VpException(err_str)
    if self.est1.var_axes != self.est2.var_axes:
        err_str = '%s var_axes %s does not match %s var_axes %s' %\
            (self.est1.name, str(self.est1.var_axes),
             self.est2.name, str(self.est2.var_axes))
        raise common.VpException(err_str)
    if self.est1.var_axes != self.msg_hdl.var_axes:
        err_str = '%s var_axes %s does not match msg_hdl var_axes %s' %\
            (self.est1.name, str(self.est1.var_axes),
             str(self.msg_hdl.var_axes))
        # Like every other consistency check above, a mismatch is fatal.
        raise common.VpException(err_str)
def init_svd(self):
    """
    Initialization for the SVD method.

    Stores ``self.bt = A.UsvdH(b)`` and ``self.bpnorm``, the weighted
    energy ``||b - A.Usvd(bt)||^2 / wvar`` (0 unless every entry of
    ``self.wvar`` is positive).  It then checks that each output axis not
    covered by ``A.srep_axes`` is repeated in ``var_axes[1]``,
    ``wrep_axes`` and ``var_axes[0]``.

    :raises common.VpException: if ``A`` has no SVD or an axis check fails.
    """
    lin_op = self.A
    if not lin_op.svd_avail:
        raise common.VpException("Transform must support an SVD")

    # With the SVD A = USV', the problem is rewritten as p = SV'z + w.
    self.bt = lin_op.UsvdH(self.b)
    svd_rep = lin_op.srep_axes

    # Weighted norm ||b - UU*(b)||^2 / wvar of the out-of-range component.
    if not np.all(self.wvar > 0):
        self.bpnorm = 0
    else:
        b_proj = lin_op.Usvd(self.bt)
        wvar_full = common.repeat_axes(self.wvar, self.shape[1],
                                       self.wrep_axes, rep=False)
        self.bpnorm = np.sum(np.abs(self.b - b_proj)**2 / wvar_full)

    # Every axis the transform acts on must be a repeated axis.
    for axis in range(len(self.shape[1])):
        requirements = (
            (self.var_axes[1], "Variance must be constant over output axis"),
            (self.wrep_axes,
             "Noise variance must be constant over output axis"),
            (self.var_axes[0], "Variance must be constant over input axis"),
        )
        for allowed, message in requirements:
            if axis not in allowed and axis not in svd_rep:
                raise common.VpException(message)
def __init__(self,A,y,wvar=0,
             wrep_axes='all', var_axes=(0,),name=None,map_est=False,
             is_complex=False,rvar_init=1e5,tune_wvar=False):
    """
    Constructor for a single-variable linear estimator ('LinEstim').

    :param A: Linear transform.  Must expose ``shape0``, ``shape1``,
        ``dtype0``, ``svd_avail``, ``UsvdH``, ``Usvd`` and
        ``get_svd_diag``.
    :param y: Observed output, combined with ``A`` through its SVD.
    :param wvar: Noise variance.
    :param wrep_axes: Output axes over which ``wvar`` is repeated;
        ``'all'`` expands to every output axis.
    :param var_axes: Input axes over which the variance is repeated.
    :param name: Optional estimator name.
    :param map_est: Stored as ``self.map_est``.
    :param is_complex: Stored as ``self.is_complex``.
    :param rvar_init: Stored as ``self.rvar_init``.
    :param tune_wvar: Stored as ``self.tune_wvar``.
    :raises common.VpException: if ``A`` has no SVD or an axis is neither
        repeated nor covered by the SVD's repeated axes.
    """
    BaseEst.__init__(self, shape=A.shape0, var_axes=var_axes,
                     dtype=A.dtype0, name=name,
                     type_name='LinEstim', nvars=1, cost_avail=True)

    # Model data and options
    self.A, self.y, self.wvar = A, y, wvar
    self.map_est, self.is_complex = map_est, is_complex
    self.cost_avail = True
    self.rvar_init, self.tune_wvar = rvar_init, tune_wvar

    # Input (z) and output (y) shapes
    self.zshape, self.yshape = A.shape0, A.shape1
    out_ndim = len(self.yshape)
    self.wrep_axes = tuple(range(out_ndim)) if wrep_axes == 'all' \
        else wrep_axes

    # SVD terms: with A = USV', write p = SV'z + w.
    if not A.svd_avail:
        raise common.VpException("Transform must support an SVD")
    self.p = A.UsvdH(y)
    svd_rep = A.get_svd_diag()[2]

    # Weighted norm ||y - UU*(y)||^2 / wvar of the out-of-range component.
    if np.all(self.wvar > 0):
        y_proj = A.Usvd(self.p)
        wvar_full = common.repeat_axes(wvar, self.yshape, self.wrep_axes,
                                       rep=False)
        self.ypnorm = np.sum(np.abs(y - y_proj)**2 / wvar_full)
    else:
        self.ypnorm = 0

    # Every axis the transform acts on must be a repeated axis.
    for axis in range(out_ndim):
        requirements = (
            (self.wrep_axes, "Variance must be constant over output axis"),
            (self.var_axes, "Variance must be constant over input axis"),
        )
        for allowed, message in requirements:
            if axis not in allowed and axis not in svd_rep:
                raise common.VpException(message)
def est(self, r, rvar, return_cost=False, ind_out=None, avg_var_cost=True):
    """
    Estimation function

    The proximal estimation function as
    described in the base class :class:`vampyre.estim.base.Estim`.
    Dispatches to ``self.est_svd`` or ``self.est_cg`` according to
    ``self.est_meth``.

    :param r: Proximal mean
    :param rvar: Proximal variance
    :param Boolean return_cost: Flag indicating if :code:`cost` is to be
        returned
    :param ind_out: Indices of the variables whose estimates are returned.
        ``None`` (default) means ``[0, 1]``.
    :param avg_var_cost: Must be ``True``; disabling variance averaging is
        not supported.
    :returns: :code:`zhat, zhatvar, [cost]` which are the posterior
        mean, variance and optional cost.
    :raises ValueError: if ``avg_var_cost`` is ``False``.
    :raises common.VpException: if ``self.est_meth`` is neither ``'svd'``
        nor ``'cg'``.
    """
    # Check parameters
    if ind_out is None:
        ind_out = [0, 1]
    if not avg_var_cost:
        # Report the actual estimator class rather than a hard-coded name.
        raise ValueError(
            "disabling variance averaging not supported for {0}".format(
                type(self).__name__))

    if self.est_meth == 'svd':
        return self.est_svd(r, rvar, return_cost, ind_out)
    elif self.est_meth == 'cg':
        return self.est_cg(r, rvar, return_cost, ind_out)
    else:
        raise common.VpException("Unknown estimation method {0:s}".format(
            self.est_meth))
def __init__(self,nrow=256,ncol=256,wavelet='db4',level=3,fwd_mode='recon',
             dtype=np.float64,name=None):
    """
    Constructor for a 2-D wavelet transform on ``nrow x ncol`` images.

    :param nrow: Number of image rows.
    :param ncol: Number of image columns.
    :param wavelet: Wavelet name passed to ``pywt``.
    :param level: Decomposition level passed to ``pywt.wavedec2``.
    :param fwd_mode: ``'recon'`` or ``'analysis'``; stored as
        ``self.fwd_mode``.
    :param dtype: Used as both the input and the output dtype.
    :param name: Optional transform name.
    :raises common.VpException: if ``fwd_mode`` is not ``'recon'`` or
        ``'analysis'``.
    """
    self.wavelet = wavelet
    self.level = level

    # Input and output are both nrow x ncol.  The SVD path is enabled only
    # for orthogonal wavelets, because that calculation assumes one.
    BaseLinTrans.__init__(self, (nrow, ncol), (nrow, ncol), dtype, dtype,
                          svd_avail=bool(pywt.Wavelet(wavelet).orthogonal),
                          name=name)

    # Periodization keeps the wavelet transform orthogonal
    self.mode = 'periodization'

    # Decompose an all-zero image once, only to learn the coefficient layout
    zero_img = np.zeros((nrow, ncol))
    zero_coeffs = pywt.wavedec2(zero_img, wavelet=self.wavelet,
                                level=self.level, mode=self.mode)
    _packed, self.coeff_slices = pywt.coeffs_to_array(zero_coeffs)

    if fwd_mode not in ('recon', 'analysis'):
        raise common.VpException('fwd_mode must be recon or analysis')
    self.fwd_mode = fwd_mode
def __init__(self, nrow=256, ncol=256, wavelet='db4', level=3,
             fwd_mode='recon'):
    """
    Constructor for a 2-D wavelet transform on ``nrow x ncol`` images.

    :param nrow: Number of image rows.
    :param ncol: Number of image columns.
    :param wavelet: Wavelet name passed to ``pywt``.
    :param level: Decomposition level passed to ``pywt.wavedec2``.
    :param fwd_mode: ``'recon'`` or ``'analysis'``; stored as
        ``self.fwd_mode``.
    :raises common.VpException: if ``fwd_mode`` is not ``'recon'`` or
        ``'analysis'``.
    """
    LinTrans.__init__(self)

    self.wavelet = wavelet
    self.level = level
    # Input and output shapes are both the image size
    self.shape0 = (nrow, ncol)
    self.shape1 = (nrow, ncol)

    # Periodization keeps the wavelet transform orthogonal
    self.mode = 'periodization'

    # Decompose an all-zero image once, only to learn the coefficient layout
    blank = np.zeros((nrow, ncol))
    blank_coeffs = pywt.wavedec2(blank, wavelet=self.wavelet,
                                 level=self.level, mode=self.mode)
    _packed, self.coeff_slices = pywt.coeffs_to_array(blank_coeffs)

    if fwd_mode not in ('recon', 'analysis'):
        raise common.VpException('fwd_mode must be recon or analysis')
    self.fwd_mode = fwd_mode
def __init__(self, est_list, w, name=None):
    """
    Constructor for a mixture of estimators.

    :param est_list: Non-empty list of component estimators.  All must
        share the same ``shape`` and ``var_axes`` and have
        ``cost_avail == True``.  The ``dtype`` of the first component is
        used for the mixture.
    :param w: Stored unchanged as ``self.w`` (presumably the mixture
        weights; not validated here).
    :param name: Optional estimator name passed to ``BaseEst``.
    :raises common.VpException: if ``est_list`` is empty or the components
        are incompatible.
    """
    # Fail with a clear message instead of an IndexError on est_list[0].
    if len(est_list) == 0:
        raise common.VpException(
            'A mixture requires at least one component estimator')
    self.est_list = est_list
    self.w = w
    shape = est_list[0].shape
    var_axes = est_list[0].var_axes
    dtype = est_list[0].dtype

    # Check that all estimators are compatible and have cost available
    for est in est_list:
        if est.shape != shape:
            raise common.VpException('Estimators must have the same shape')
        if est.var_axes != var_axes:
            raise common.VpException(
                'Estimators must have the same var_axes')
        if not est.cost_avail:
            raise common.VpException(
                "Estimators in a mixture distribution "
                "must have cost_avail==True")
    BaseEst.__init__(self, shape=shape, var_axes=var_axes, dtype=dtype,
                     name=name, type_name='Mixture', nvars=1,
                     cost_avail=True)
def __init__(self, est_list, w):
    """
    Constructor for a mixture of estimators.

    ``shape`` and ``var_axes`` are taken from the first component.

    :param est_list: Non-empty list of component estimators, each with
        ``cost_avail == True``.
    :param w: Stored unchanged as ``self.w`` (presumably the mixture
        weights; not validated here).
    :raises common.VpException: if ``est_list`` is empty or a component
        has ``cost_avail == False``.
    """
    Estim.__init__(self)
    # Fail with a clear message instead of an IndexError on est_list[0].
    if len(est_list) == 0:
        raise common.VpException(
            'A mixture requires at least one component estimator')
    self.est_list = est_list
    self.w = w
    self.shape = est_list[0].shape
    self.var_axes = est_list[0].var_axes

    # Check that all estimators have cost available
    for est in est_list:
        if not est.cost_avail:
            raise common.VpException(
                "Estimators in a mixture distribution "
                "must have cost_avail==True")
    self.cost_avail = True
def __init__(self,A,b,wvar=0,
             z1rep_axes=(0,), z0rep_axes=(0,),wrep_axes='all',
             map_est=False,is_complex=False,est_meth='svd',
             nit_cg=100, atol_cg=1e-3, save_stats=False):
    """
    Constructor for a two-variable linear estimator (legacy ``Estim`` API).

    Variable 0 lives in ``A.shape0`` (input) and variable 1 in
    ``A.shape1`` (output).

    :param A: Linear transform exposing ``shape0`` and ``shape1``.
    :param b: Stored as ``self.b``.
    :param wvar: Noise variance.
    :param z1rep_axes: Output axes over which variable 1's variance is
        repeated; ``'all'`` means every output axis.
    :param z0rep_axes: Input axes over which variable 0's variance is
        repeated; ``'all'`` means every input axis.
    :param wrep_axes: Output axes over which ``wvar`` is repeated;
        ``'all'`` means every output axis.
    :param map_est: Stored as ``self.map_est``.
    :param is_complex: Stored as ``self.is_complex``.
    :param est_meth: ``'svd'`` or ``'cg'``; selects ``init_svd`` or
        ``init_cg``.
    :param nit_cg: Stored as ``self.nit_cg``.
    :param atol_cg: Stored as ``self.atol_cg``.
    :param save_stats: Stored as ``self.save_stats``.
    :raises common.VpException: if ``est_meth`` is unknown.
    """
    Estim.__init__(self)
    self.A = A
    self.b = b
    self.wvar = wvar
    self.map_est = map_est
    self.is_complex = is_complex
    self.cost_avail = True

    # Initial variance.  This is a large value since the quantities are
    # underdetermined.  (np.Inf was removed in NumPy 2.0; np.inf is the
    # portable spelling.)
    self.zvar0_init = np.inf
    self.zvar1_init = np.inf

    # Get the input and output shape
    self.shape0 = A.shape0
    self.shape1 = A.shape1

    # Set the repetition axes.  z0 is the input variable, so its 'all'
    # expands over the input dimensions; z1 and the noise live on the
    # output.
    ndim0 = len(self.shape0)
    ndim1 = len(self.shape1)
    if z0rep_axes == 'all':
        z0rep_axes = tuple(range(ndim0))
    if z1rep_axes == 'all':
        z1rep_axes = tuple(range(ndim1))
    if wrep_axes == 'all':
        wrep_axes = tuple(range(ndim1))
    self.z0rep_axes = z0rep_axes
    self.z1rep_axes = z1rep_axes
    self.wrep_axes = wrep_axes

    # Initialization depending on the estimation method
    self.est_meth = est_meth
    if self.est_meth == 'svd':
        self.init_svd()
    elif self.est_meth == 'cg':
        self.init_cg()
    else:
        # '{0}' rather than '{0:s}' so that a non-string value still
        # produces the intended VpException.
        raise common.VpException(
            "Unknown estimation method {0}".format(est_meth))

    # CG parameters
    self.nit_cg = nit_cg
    self.atol_cg = atol_cg
    self.save_stats = save_stats
    self.init_hist_dict()
def __init__(self,A,b,wvar=0,var_axes=((0,), (0,)), wrep_axes='all',
             name=None,map_est=False,is_complex=False,est_meth='svd',
             nit_cg=100, atol_cg=1e-3, save_stats=False):
    """
    Constructor for the two-variable linear estimator ('LinEstTwo').

    Variable 0 has shape ``A.shape0`` / dtype ``A.dtype0``; variable 1 has
    shape ``A.shape1`` / dtype ``A.dtype1``.

    :param A: Linear transform exposing ``shape0``, ``shape1``,
        ``dtype0`` and ``dtype1``.
    :param b: Stored as ``self.b``.
    :param wvar: Noise variance.
    :param var_axes: Pair of variance-repetition axes, one entry per
        variable; an entry ``'all'`` expands to every axis of that
        variable.  The argument itself is never modified.
    :param wrep_axes: Output axes over which ``wvar`` is repeated;
        ``'all'`` means every output axis.
    :param name: Optional estimator name.
    :param map_est: Stored as ``self.map_est``.
    :param is_complex: Stored as ``self.is_complex``.
    :param est_meth: ``'svd'`` or ``'cg'``; selects ``init_svd`` or
        ``init_cg``.
    :param nit_cg: Stored as ``self.nit_cg``.
    :param atol_cg: Stored as ``self.atol_cg``.
    :param save_stats: Stored as ``self.save_stats``.
    :raises common.VpException: if ``est_meth`` is unknown.
    """
    self.A = A
    self.b = b
    self.wvar = wvar
    self.map_est = map_est
    self.is_complex = is_complex
    self.cost_avail = True

    # Initial variance.  This is a large value since the quantities are
    # underdetermined.  (np.Inf was removed in NumPy 2.0.)
    self.zvar0_init = np.inf
    self.zvar1_init = np.inf

    # Set parameters of the base estimator
    shape = [A.shape0, A.shape1]
    dtype = [A.dtype0, A.dtype1]
    nvars = 2
    # Work on a copy so that expanding 'all' never mutates the caller's
    # list (or a shared default).
    var_axes = list(var_axes)
    for i in range(nvars):
        if var_axes[i] == 'all':
            var_axes[i] = tuple(range(len(shape[i])))
    if wrep_axes == 'all':
        wrep_axes = tuple(range(len(shape[1])))
    self.wrep_axes = wrep_axes
    BaseEst.__init__(self, shape=shape, var_axes=var_axes, dtype=dtype,
                     name=name, type_name='LinEstTwo', nvars=nvars,
                     cost_avail=True)

    # Initialization depending on the estimation method
    self.est_meth = est_meth
    if self.est_meth == 'svd':
        self.init_svd()
    elif self.est_meth == 'cg':
        self.init_cg()
    else:
        # '{0}' rather than '{0:s}' so a non-string value still raises
        # the intended VpException.
        raise common.VpException(
            "Unknown estimation method {0}".format(est_meth))

    # CG parameters
    self.nit_cg = nit_cg
    self.atol_cg = atol_cg
    self.save_stats = save_stats
    self.init_hist_dict()
def __init__(self, est_list, msg_hdl_list=None, hist_list=None, nit=10,
             comp_cost=False, prt_period=0):
    """
    Constructor for a multi-layer solver over a chain of estimators.

    :param est_list: List of estimators.
    :param msg_hdl_list: Message handlers, stored as ``self.msg_hdl_list``.
        ``None`` (default) means a fresh empty list.
    :param hist_list: Passed to ``Solver.__init__``.  ``None`` (default)
        means a fresh empty list.
    :param nit: Stored as ``self.nit``.
    :param comp_cost: Requests cost computation.
    :param prt_period: Stored as ``self.prt_period``.
    :raises common.VpException: if ``comp_cost`` is set but some estimator
        has ``cost_avail == False``.
    """
    # Fresh lists per instance: shared mutable defaults would leak state
    # between solvers.
    if msg_hdl_list is None:
        msg_hdl_list = []
    if hist_list is None:
        hist_list = []
    Solver.__init__(self, hist_list)
    self.est_list = est_list
    self.msg_hdl_list = msg_hdl_list
    self.nit = nit
    self.comp_cost = comp_cost
    self.prt_period = prt_period

    # Check if all estimators can compute the cost
    for i, esti in enumerate(self.est_list):
        if self.comp_cost and not esti.cost_avail:
            errstr = "Requested cost computation, but cost_avail==False"\
                + " for estimator " + str(i)
            raise common.VpException(errstr)
        self.comp_cost = self.comp_cost and esti.cost_avail
def est(self, r, rvar, return_cost=False):
    """
    Estimation function

    The proximal estimation function as described in the base class
    :class:`vampyre.estim.base.Estim`.  The work is delegated to
    ``self.est_svd`` or ``self.est_cg`` depending on ``self.est_meth``.

    :param r: Proximal mean
    :param rvar: Proximal variance
    :param Boolean return_cost: Flag indicating if :code:`cost` is to be
        returned
    :returns: :code:`zhat, zhatvar, [cost]` which are the posterior
        mean, variance and optional cost.
    :raises common.VpException: if ``self.est_meth`` is neither ``'svd'``
        nor ``'cg'``.
    """
    handler_name = {'svd': 'est_svd', 'cg': 'est_cg'}.get(self.est_meth)
    if handler_name is None:
        raise common.VpException("Unknown estimation method {0:s}".format(
            self.est_meth))
    handler = getattr(self, handler_name)
    return handler(r, rvar, return_cost)
def est_cg(self, r, rvar, return_cost=False, ind_out=None):
    """
    CG-based estimation function

    The proximal estimation function as
    described in the base class :class:`vampyre.estim.base.Estim`.
    The first-order solution is obtained with
    :func:`scipy.sparse.linalg.lsqr`, warm-started from the previous call.
    The variances are estimated numerically by re-solving with the inputs
    perturbed along the stored directions ``self.dr0`` / ``self.dr1``.

    :param r: Proximal means ``[r0, r1]``
    :param rvar: Proximal variances ``[rvar0, rvar1]``
    :param Boolean return_cost: Flag indicating if :code:`cost` is to be
        returned
    :param ind_out: Indices (0 and/or 1) of the variables to return.
        ``None`` (default) means ``[0, 1]``.  Returned lists always hold
        variable 0 before variable 1.
    :returns: :code:`zhat, zhatvar, [cost]` which are the posterior
        mean, variance and optional cost.
    :raises common.VpException: if ``rvar0`` is infinite while ``rvar1``
        is not (not yet implemented).
    """
    if ind_out is None:
        ind_out = [0, 1]

    # Unpack the inputs
    r0, r1 = r
    rvar0, rvar1 = rvar

    # Infinite variance case.  np.inf is used because np.Inf was removed
    # in NumPy 2.0.
    if np.any(rvar1 == np.inf):
        zhat0 = r0
        zhatvar0 = rvar0
        zhat1 = self.A.dot(r0) + self.b

        # Compute variance numerically.
        yvar = np.abs(self.A.dot(self.dr0))**2
        zhatvar1 = np.mean(yvar, self.var_axes[1])*rvar0/\
            np.mean(self.dr0_norm_sq)
        zhat = []
        zhatvar = []
        if 0 in ind_out:
            zhat.append(zhat0)
            zhatvar.append(zhatvar0)
        if 1 in ind_out:
            zhat.append(zhat1)
            zhatvar.append(zhatvar1)
        cost = 0
        if return_cost:
            return zhat, zhatvar, cost
        else:
            return zhat, zhatvar
    elif np.any(rvar0 == np.inf):
        raise common.VpException("Infinite variance case for rvar0 "+\
                                 "is not yet implemented")

    # Get dimensions
    self.n0 = np.prod(self.shape[0])
    self.n1 = np.prod(self.shape[1])

    # ---- First-order terms ----
    # Create the LSQR transform for the problem.
    # The VAMP problem is equivalent to minimizing ||F(z)-g||^2
    F = LSQROp(self.A, self.b, rvar, self.wvar,
               self.var_axes[0], self.var_axes[1], self.wrep_axes,
               self.shape[0], self.shape[1], self.is_complex)
    g = F.get_tgt_vec(r)

    # Get the initial condition (warm start from the last solution)
    if self.zlast is None:
        zinit = F.pack(r0, r1)
    else:
        zinit = self.zlast
    g -= F.dot(zinit)

    # Run the LSQR optimization on the residual problem
    lsqr_out = scipy.sparse.linalg.lsqr(F, g, iter_lim=self.nit_cg,
                                        atol=self.atol_cg)
    zvec = lsqr_out[0] + zinit
    self.zlast = zvec
    zhat0, zhat1 = F.unpack(zvec)
    zhat = []
    if 0 in ind_out:
        zhat.append(zhat0)
    if 1 in ind_out:
        zhat.append(zhat1)

    # Save stats (lsqr_out[2] is the LSQR iteration count)
    if 'zhat_nit' in self.hist_list:
        self.hist_dict['zhat_nit'].append(lsqr_out[2])

    # ---- Cost ----
    if return_cost:
        # lsqr_out[3] is the residual norm of the first-order solve
        cost = lsqr_out[3]**2

        # Add the cost for the second order terms.
        #
        # We only consider the MAP case, where the second-order cost is
        # (1/2)*nz1*log(2*pi*wvar)
        if self.is_complex:
            cscale = 1
        else:
            cscale = 2
        cost /= cscale
        if F.wvar_pos:
            if np.all(self.wvar > 0):
                cost += (1 / cscale) * self.n1 * np.mean(
                    np.log(cscale * np.pi * self.wvar))

    # ---- Second-order terms ----
    # These are computed via the numerical gradient along a random
    # direction.
    zhatvar = []
    if 0 in ind_out:
        # Perturb r0
        r0p = r0 + self.dr0
        g0 = F.get_tgt_vec([r0p, r1])

        # Get the initial condition
        if self.zvec0_last is None:
            zinit = F.pack(r0p, r1)
        else:
            zinit = self.zvec0_last
        g0 -= F.dot(zinit)

        # Run the LSQR optimization
        lsqr_out = scipy.sparse.linalg.lsqr(F, g0, iter_lim=self.nit_cg,
                                            atol=self.atol_cg)
        zvec0 = lsqr_out[0] + zinit
        self.zvec0_last = zvec0
        dzvec = zvec0 - zvec
        dz0, dz1 = F.unpack(dzvec)
        if 'zvar0_nit' in self.hist_list:
            self.hist_dict['zvar0_nit'].append(lsqr_out[2])

        # Compute the correlations
        alpha0 = np.mean(np.real(self.dr0.conj()*dz0), self.var_axes[0]) /\
            self.dr0_norm_sq
        zhatvar0 = alpha0 * rvar0
        zhatvar.append(zhatvar0)

    if 1 in ind_out:
        # Perturb r1
        r1p = r1 + self.dr1
        g1 = F.get_tgt_vec([r0, r1p])

        # Get the initial condition
        if self.zvec1_last is None:
            zinit = F.pack(r0, r1p)
        else:
            zinit = self.zvec1_last
        g1 -= F.dot(zinit)

        # Run the LSQR optimization
        lsqr_out = scipy.sparse.linalg.lsqr(F, g1, iter_lim=self.nit_cg,
                                            atol=self.atol_cg)
        zvec1 = lsqr_out[0] + zinit
        self.zvec1_last = zvec1
        dzvec = zvec1 - zvec
        dz0, dz1 = F.unpack(dzvec)
        if 'zvar1_nit' in self.hist_list:
            self.hist_dict['zvar1_nit'].append(lsqr_out[2])

        # Compute the correlations
        alpha1 = np.mean(np.real(self.dr1.conj()*dz1), self.var_axes[1]) /\
            self.dr1_norm_sq
        zhatvar1 = alpha1 * rvar1
        zhatvar.append(zhatvar1)

    if return_cost:
        return zhat, zhatvar, cost
    else:
        return zhat, zhatvar
def __init__(self, est_list, msg_hdl_list=None, hist_list=None, nit=10,
             comp_cost=False, prt_period=0):
    """
    Constructor for a multi-layer solver over a chain of estimators.

    :param est_list: Non-empty chain of estimators.  The first and last
        must have ``nvars == 1``.  Every intermediate estimator must have
        ``nvars == 2``.  Its ``shape[0]``/``var_axes[0]`` describe the
        variable shared with the previous layer, and its
        ``shape[1]``/``var_axes[1]`` the variable shared with the next one.
    :param msg_hdl_list: Message handlers.  Handler ``i`` sits between
        estimators ``i`` and ``i+1`` and must match the ``shape`` and
        ``var_axes`` of the variable they share.  At least
        ``len(est_list) - 1`` handlers are required.  ``None`` (default)
        means an empty list.
    :param hist_list: Passed to ``Solver.__init__``.  ``None`` (default)
        means a fresh empty list.
    :param nit: Stored as ``self.nit``.
    :param comp_cost: Requests cost computation.
    :param prt_period: Stored as ``self.prt_period``.
    :raises ValueError: on an empty ``est_list``, too few message
        handlers, wrong ``nvars``, or mismatched ``shape``/``var_axes``.
    :raises common.VpException: if ``comp_cost`` is set but some estimator
        has ``cost_avail == False``.
    """
    # Fresh lists per instance: shared mutable defaults would leak state.
    if msg_hdl_list is None:
        msg_hdl_list = []
    if hist_list is None:
        hist_list = []
    Solver.__init__(self, hist_list)
    self.est_list = est_list
    self.msg_hdl_list = msg_hdl_list
    self.nit = nit
    self.comp_cost = comp_cost
    self.prt_period = prt_period
    self.time_iter = 0  # Computation time for last iteration

    # Check dimensions
    nlayers = len(self.est_list)
    if nlayers == 0:
        raise ValueError('est_list must contain at least one estimator')
    if self.est_list[0].nvars != 1:
        raise ValueError('First estimator must take 1 variable')
    if self.est_list[-1].nvars != 1:
        raise ValueError('Last estimator must take 1 variable')
    # One handler per adjacent pair of estimators; report a clear error
    # instead of an IndexError in the loop below.
    if len(self.msg_hdl_list) < nlayers - 1:
        raise ValueError(
            'Expected %d message handlers for %d estimators, got %d'
            % (nlayers - 1, nlayers, len(self.msg_hdl_list)))

    for i in range(0, nlayers - 1):
        msg_hdl_shape = self.msg_hdl_list[i].shape
        msg_hdl_var_axes = self.msg_hdl_list[i].var_axes
        if i > 0:
            if (self.est_list[i].nvars != 2):
                errstr = 'Estimator %s must take 2 variables'\
                    % self.est_list[i].name
                raise ValueError(errstr)

        # Output-side variable of estimator i
        if i == 0:
            shape0 = self.est_list[i].shape
            var_axes0 = self.est_list[i].var_axes
        else:
            shape0 = self.est_list[i].shape[1]
            var_axes0 = self.est_list[i].var_axes[1]
        # Input-side variable of estimator i+1
        if i == nlayers - 2:
            shape1 = self.est_list[i + 1].shape
            var_axes1 = self.est_list[i + 1].var_axes
        else:
            shape1 = self.est_list[i + 1].shape[0]
            var_axes1 = self.est_list[i + 1].var_axes[0]

        if shape0 != shape1:
            errstr = 'Est %s shape %s does not match est %s shape %s'\
                % (est_list[i].name, str(shape0), est_list[i+1].name,
                   str(shape1))
            raise ValueError(errstr)
        if shape0 != msg_hdl_shape:
            errstr = 'Est %s shape %s does not match msg_hdl shape %s'\
                % (est_list[i].name, str(shape0), str(msg_hdl_shape))
            raise ValueError(errstr)
        if var_axes0 != var_axes1:
            errstr = 'Est %s var_axes %s does not match est %s var_axes %s'\
                % (est_list[i].name, str(var_axes0), est_list[i+1].name,
                   str(var_axes1))
            # Like the shape checks, a var_axes mismatch is fatal.
            raise ValueError(errstr)
        if var_axes0 != msg_hdl_var_axes:
            errstr = 'Est %s var_axes %s does not match msg_hdl var_axes %s'\
                % (est_list[i].name, str(var_axes0), str(msg_hdl_var_axes))
            raise ValueError(errstr)

    # Check if all estimators can compute the cost
    for i, esti in enumerate(self.est_list):
        if self.comp_cost and not esti.cost_avail:
            errstr = "Requested cost computation, but cost_avail==False"\
                + " for estimator " + str(i)
            raise common.VpException(errstr)
        self.comp_cost = self.comp_cost and esti.cost_avail