def __init__(self, A, y, proxg, eps, x=None, G=None, max_iter=100,
             tau=None, sigma=None, show_pbar=True):
    """Set up a primal-dual (PDHG) solver with an L2-ball data constraint.

    The data term is the prox ``prox.L2Proj(A.oshape, eps, y=y)``, which is
    used through its conjugate. It is paired with ``proxg``. Presumably this
    solves ``min_x g(x)  s.t.  ||A x - y||_2 <= eps`` (or ``g(G x)`` when
    ``G`` is given); confirm against the class documentation.

    Args:
        A: Linear operator from the unknown space (``A.ishape``) to the
            measurement space (``A.oshape``).
        y: Measurement array. Its device is used for the primal and dual
            variables, and its dtype is used for their zero initialisation.
        proxg: Proximal operator of the regulariser. Without ``G`` it is the
            primal prox. With ``G`` it is stacked on the dual side and the
            primal prox becomes ``prox.NoOp``.
        eps: Radius passed to ``prox.L2Proj``.
        x: Optional initial primal variable. Defaults to zeros of shape
            ``A.ishape`` on ``y``'s device.
        G: Optional linear operator. When given, ``A`` and ``G`` are stacked
            with ``linop.Vstack``.
        max_iter: Maximum number of PDHG iterations.
        tau: Primal step size. Derived from the operator norm when omitted.
        sigma: Dual step size. Derived from the operator norm when omitted.
        show_pbar: Whether to display progress bars.
    """
    self.y = y
    self.x = x
    self.y_device = backend.get_device(y)
    if self.x is None:
        with self.y_device:
            self.x = self.y_device.xp.zeros(A.ishape, dtype=self.y.dtype)

    self.x_device = backend.get_device(self.x)
    if G is None:
        # NOTE(review): max_eig_app is not used by this method. It is kept
        # because other code may read the attribute; confirm before removing.
        self.max_eig_app = MaxEig(A.H * A, dtype=self.x.dtype,
                                  device=self.x_device, show_pbar=show_pbar)
        proxfc = prox.Conj(prox.L2Proj(A.oshape, eps, y=y))
    else:
        # Move the regulariser to the dual side. The stacked operator
        # [A; G] pairs the L2-ball projection with proxg, and the primal
        # prox becomes a no-op. proxf1/NoOp use A before it is stacked.
        proxf1 = prox.L2Proj(A.oshape, eps, y=y)
        proxf2 = proxg
        proxfc = prox.Conj(prox.Stack([proxf1, proxf2]))
        proxg = prox.NoOp(A.ishape)
        A = linop.Vstack([A, G])

    if tau is None or sigma is None:
        max_eig = MaxEig(A.H * A, dtype=self.x.dtype, device=self.x_device,
                         show_pbar=show_pbar).run()
        # Keep tau * sigma * max_eig == 1 and never discard a step size the
        # caller supplied: only the missing one is derived.
        if tau is None and sigma is None:
            tau = 1
            sigma = 1 / max_eig
        elif tau is None:
            tau = 1 / (sigma * max_eig)
        else:
            sigma = 1 / (tau * max_eig)

    with self.y_device:
        self.u = self.y_device.xp.zeros(A.oshape, dtype=self.y.dtype)

    alg = PrimalDualHybridGradient(proxfc, proxg, A, A.H, self.x, self.u,
                                   tau, sigma, max_iter=max_iter)

    super().__init__(alg, show_pbar=show_pbar)
def _get_PrimalDualHybridGradient(self):
    """Build ``self.alg`` as a PDHG solver for linear least squares.

    Data term: handled on the dual side with
    ``prox.L2Reg(y.shape, 1, y=-self.y)``.

    Regularisation:
        * When ``self.lamda > 0``, the gradient of
          ``(lamda / 2) * ||x - z||_2^2`` is passed as ``gradh`` (``z`` is
          treated as 0 when ``self.z`` is None), and
          ``gamma_primal = lamda``.
        * Without ``self.G``, ``self.proxg`` (or a NoOp) is the primal prox.
        * With ``self.G``, ``A`` and ``G`` are stacked. ``self.proxg`` is
          applied to ``G x`` through its conjugate on the dual side, and the
          primal prox becomes a NoOp.

    Step sizes (written back to ``self.tau`` / ``self.sigma``):
        * ``tau`` missing: ``sigma`` defaults to 1 and
          ``tau = 1 / (max_eig(A^H sigma A) + lamda)``.
        * only ``sigma`` missing: ``sigma = 1 / max_eig(A tau A^H)``.
        * both given: used as-is (a supplied ``sigma`` is no longer
          overwritten).
    """
    with self.y_device:
        y = -self.y
        A = self.A

        if self.proxg is None:
            proxg = prox.NoOp(self.x.shape)
        else:
            proxg = self.proxg

        if self.lamda > 0:
            def gradh(x):
                # Gradient of (lamda / 2) * ||x - z||_2^2.
                with backend.get_device(self.x):
                    if self.z is None:
                        return self.lamda * x
                    return self.lamda * (x - self.z)

            gamma_primal = self.lamda
        else:
            gradh = None
            gamma_primal = 0

        if self.G is None:
            proxfc = prox.L2Reg(y.shape, 1, y=y)
            gamma_dual = 1
        else:
            A = linop.Vstack([A, self.G])
            proxf1c = prox.L2Reg(self.y.shape, 1, y=y)
            # g acts on G x, so its prox must live on G's output space.
            # prox.Conj(None) would fail, so a missing proxg becomes a NoOp
            # of shape G.oshape.
            if self.proxg is None:
                proxf2c = prox.Conj(prox.NoOp(self.G.oshape))
            else:
                proxf2c = prox.Conj(self.proxg)
            proxfc = prox.Stack([proxf1c, proxf2c])
            proxg = prox.NoOp(self.x.shape)
            gamma_dual = 0

        if self.tau is None:
            if self.sigma is None:
                self.sigma = 1

            S = linop.Multiply(A.oshape, self.sigma)
            AHA = A.H * S * A
            max_eig = MaxEig(AHA, dtype=self.x.dtype, device=self.x_device,
                             max_iter=self.max_power_iter,
                             show_pbar=self.show_pbar).run()

            self.tau = 1 / (max_eig + self.lamda)
        elif self.sigma is None:
            # Only derive sigma when the caller did not supply one.
            T = linop.Multiply(A.ishape, self.tau)
            AAH = A * T * A.H
            max_eig = MaxEig(AAH, dtype=self.x.dtype, device=self.x_device,
                             max_iter=self.max_power_iter,
                             show_pbar=self.show_pbar).run()

            self.sigma = 1 / max_eig

    with self.y_device:
        u = self.y_device.xp.zeros(A.oshape, dtype=self.y.dtype)

    self.alg = PrimalDualHybridGradient(
        proxfc, proxg, A, A.H, self.x, u, self.tau, self.sigma,
        gamma_primal=gamma_primal, gamma_dual=gamma_dual, gradh=gradh,
        max_iter=self.max_iter)
def _get_PrimalDualHybridGradient(self):
    """Build ``self.alg`` as a PDHG solver for linear least squares.

    The intended objective is presumably
    ``1/2 ||A x - y||^2 + g(G x) + lamda/2 ||x - z||^2`` (``G`` = identity
    when ``self.G`` is None); confirm against the class documentation.

    Data term: dual side, ``prox.L2Reg(y.shape, 1, y=-self.y)``.

    Without ``self.G``:
        The primal prox is ``prox.L2Reg(x.shape, lamda, y=z, proxh=proxg)``
        when ``lamda > 0``. Otherwise it is ``self.proxg`` (or a NoOp).

    With ``self.G``:
        ``A`` and ``G`` are stacked. ``self.proxg`` is conjugated on G's
        output space (dual side). The L2 term on ``x`` stays in the primal
        prox. Previously the L2 term was folded into the conjugated prox,
        which regularised ``G x`` instead of ``x`` and mismatched shapes
        when ``G.oshape != x.shape``.

    Step sizes (written back to ``self.tau`` / ``self.sigma``):
        * ``tau`` missing: ``sigma`` defaults to 1 and
          ``tau = 1 / max_eig(A^H sigma A)``.
        * only ``sigma`` missing: ``sigma = 1 / max_eig(A tau A^H)``.
        * both given: used as-is.
    """
    with self.y_device:
        A = self.A
        # The primal prox is lamda-strongly convex whenever it carries the
        # L2 term, which is what the acceleration parameter relies on.
        gamma_primal = self.lamda if self.lamda > 0 else 0
        proxf1c = prox.L2Reg(self.y.shape, 1, y=-self.y)

        if self.G is None:
            if self.lamda > 0:
                proxg = prox.L2Reg(self.x.shape, self.lamda, y=self.z,
                                   proxh=self.proxg)
            elif self.proxg is None:
                proxg = prox.NoOp(self.x.shape)
            else:
                proxg = self.proxg

            proxfc = proxf1c
            gamma_dual = 1
        else:
            A = linop.Vstack([A, self.G])
            # g acts on G x: its conjugate lives on G's output space.
            if self.proxg is None:
                proxf2c = prox.Conj(prox.NoOp(self.G.oshape))
            else:
                proxf2c = prox.Conj(self.proxg)
            proxfc = prox.Stack([proxf1c, proxf2c])

            # The L2 term on x (if any) remains the primal prox.
            if self.lamda > 0:
                proxg = prox.L2Reg(self.x.shape, self.lamda, y=self.z)
            else:
                proxg = prox.NoOp(self.x.shape)
            gamma_dual = 0

        if self.tau is None:
            if self.sigma is None:
                self.sigma = 1

            S = linop.Multiply(A.oshape, self.sigma)
            AHA = A.H * S * A
            max_eig = MaxEig(AHA, dtype=self.x.dtype, device=self.x_device,
                             max_iter=self.max_power_iter,
                             show_pbar=self.show_pbar).run()

            self.tau = 1 / max_eig
        elif self.sigma is None:
            T = linop.Multiply(A.ishape, self.tau)
            AAH = A * T * A.H
            max_eig = MaxEig(AAH, dtype=self.x.dtype, device=self.x_device,
                             max_iter=self.max_power_iter,
                             show_pbar=self.show_pbar).run()

            self.sigma = 1 / max_eig

    with self.y_device:
        u = self.y_device.xp.zeros(A.oshape, dtype=self.y.dtype)

    self.alg = PrimalDualHybridGradient(
        proxfc, proxg, A, A.H, self.x, u, self.tau, self.sigma,
        gamma_primal=gamma_primal, gamma_dual=gamma_dual,
        max_iter=self.max_iter)