def solve(psi_fns, omega_fns, lmb=1.0, mu=None, quad_funcs=None,
          max_iters=1000, eps_abs=1e-3, eps_rel=1e-3,
          lin_solver="cg", lin_solver_options=None,
          try_diagonalize=True, try_fast_norm=True, scaled=False,
          metric=None, convlog=None, verbose=0):
    # Can only have one omega function.
    assert len(omega_fns) <= 1
    prox_fns = psi_fns + omega_fns
    stacked_ops = vstack([fn.lin_op for fn in psi_fns])
    K = CompGraph(stacked_ops)

    # Select optimal parameters if wanted
    if lmb is None or mu is None:
        lmb, mu = est_params_lin_admm(K, lmb, verbose, scaled, try_fast_norm)

    # Initialize everything to zero.
    v = np.zeros(K.input_size)
    z = np.zeros(K.output_size)
    u = np.zeros(K.output_size)

    # Buffers.
    Kv = np.zeros(K.output_size)
    KTu = np.zeros(K.input_size)
    s = np.zeros(K.input_size)
    Kvzu = np.zeros(K.output_size)
    v_prev = np.zeros(K.input_size)
    z_prev = np.zeros(K.output_size)

    # Log for prox ops.
    prox_log = TimingsLog(prox_fns)
    # Time iterations.
    iter_timing = TimingsEntry("LIN-ADMM iteration")

    # Convergence log for initial iterate
    if convlog is not None:
        K.update_vars(v)
        objval = sum([fn.value for fn in prox_fns])
        convlog.record_objective(objval)
        convlog.record_timing(0.0)

    for i in range(max_iters):
        iter_timing.tic()
        if convlog is not None:
            convlog.tic()

        v_prev[:] = v
        z_prev[:] = z

        # Update v.
        K.forward(v, Kv)
        Kvzu[:] = Kv - z + u
        K.adjoint(Kvzu, v)
        v[:] = v_prev - (mu / lmb) * v

        if len(omega_fns) > 0:
            v[:] = omega_fns[0].prox(1.0 / mu, v, x_init=v_prev.copy(),
                                     lin_solver=lin_solver, options=lin_solver_options)

        # Update z.
        K.forward(v, Kv)
        Kv_u = Kv + u
        offset = 0
        for fn in psi_fns:
            slc = slice(offset, offset + fn.lin_op.size, None)
            Kv_u_slc = np.reshape(Kv_u[slc], fn.lin_op.shape)
            # Apply and time prox.
            prox_log[fn].tic()
            z[slc] = fn.prox(1.0 / lmb, Kv_u_slc, i).flatten()
            prox_log[fn].toc()
            offset += fn.lin_op.size

        # Update u.
        u += Kv - z
        K.adjoint(u, KTu)

        # Check convergence.
        r = Kv - z
        K.adjoint((1.0 / lmb) * (z - z_prev), s)
        eps_pri = np.sqrt(K.output_size) * eps_abs + \
            eps_rel * max([np.linalg.norm(Kv), np.linalg.norm(z)])
        eps_dual = np.sqrt(K.input_size) * eps_abs + \
            eps_rel * np.linalg.norm(KTu) / (1.0 / lmb)

        # Convergence log
        if convlog is not None:
            convlog.toc()
            K.update_vars(v)
            objval = sum([fn.value for fn in prox_fns])
            convlog.record_objective(objval)

        # Show progress
        if verbose > 0:
            # Evaluate objective only if required (expensive!)
            objstr = ''
            if verbose == 2:
                K.update_vars(v)
                objstr = ", obj_val = %02.03e" % sum([fn.value for fn in prox_fns])

            # Evaluate metric potentially
            metstr = '' if metric is None else ", {}".format(metric.message(v))
            print("iter %d: ||r||_2 = %.3f, eps_pri = %.3f, ||s||_2 = %.3f, eps_dual = %.3f%s%s" % (
                i, np.linalg.norm(r), eps_pri, np.linalg.norm(s), eps_dual, objstr, metstr))

        iter_timing.toc()
        if np.linalg.norm(r) <= eps_pri and np.linalg.norm(s) <= eps_dual:
            break

    # Print out timings info.
    if verbose > 0:
        print(iter_timing)
        print("prox funcs:")
        print(prox_log)
        print("K forward ops:")
        print(K.forward_log)
        print("K adjoint ops:")
        print(K.adjoint_log)

    # Assign values to variables.
    K.update_vars(v)
    # Return optimal value.
    return sum([fn.value for fn in prox_fns])
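

# A minimal, self-contained NumPy sketch of the linearized-ADMM iteration implemented
# above, on the toy problem min_v ||K v - b||_1 + (alpha/2)||v||_2^2. Everything below
# (helper name, problem choice, step-size rule mu <= lmb / ||K||_2^2) is an illustrative
# assumption; it is not part of the solver API above and only mirrors its v/z/u updates
# with explicit matrices.
def _toy_linearized_admm(K, b, alpha=0.1, lmb=1.0, mu=None, iters=500):
    import numpy as np
    if mu is None:
        # Common step-size rule of thumb for linearized ADMM.
        mu = 0.9 * lmb / np.linalg.norm(K, 2) ** 2
    m, n = K.shape
    v = np.zeros(n)
    z = np.zeros(m)
    u = np.zeros(m)
    for _ in range(iters):
        # v-step: gradient step on the augmented term, then prox of (alpha/2)||.||^2.
        v = v - (mu / lmb) * K.T.dot(K.dot(v) - z + u)
        v = v / (1.0 + mu * alpha)
        # z-step: prox of the l1 term with weight lmb is soft-thresholding around b.
        w = K.dot(v) + u
        z = b + np.sign(w - b) * np.maximum(np.abs(w - b) - lmb, 0.0)
        # Scaled dual update.
        u = u + K.dot(v) - z
    return v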


def solve(psi_fns, omega_fns, tau=None, sigma=None, theta=None,
          max_iters=1000, eps_abs=1e-3, eps_rel=1e-3, x0=None,
          lin_solver="cg", lin_solver_options=None,
          try_diagonalize=True, try_fast_norm=False, scaled=True,
          metric=None, convlog=None, verbose=0):
    # Can only have one omega function.
    assert len(omega_fns) <= 1
    prox_fns = psi_fns + omega_fns
    stacked_ops = vstack([fn.lin_op for fn in psi_fns])
    K = CompGraph(stacked_ops)

    # Select optimal parameters if wanted
    if tau is None or sigma is None or theta is None:
        tau, sigma, theta = est_params_pc(K, tau, sigma, verbose, scaled, try_fast_norm)

    # Initialize
    x = np.zeros(K.input_size)
    y = np.zeros(K.output_size)
    xbar = np.zeros(K.input_size)
    u = np.zeros(K.output_size)
    z = np.zeros(K.output_size)

    if x0 is not None:
        x[:] = np.reshape(x0, K.input_size)
        K.forward(x, y)
    xbar[:] = x

    # Buffers.
    Kxbar = np.zeros(K.output_size)
    Kx = np.zeros(K.output_size)
    KTy = np.zeros(K.input_size)
    KTu = np.zeros(K.input_size)
    s = np.zeros(K.input_size)

    prev_x = x.copy()
    prev_Kx = Kx.copy()
    prev_z = z.copy()
    prev_u = u.copy()

    # Log for prox ops.
    prox_log = TimingsLog(prox_fns)
    # Time iterations.
    iter_timing = TimingsEntry("PC iteration")

    # Convergence log for initial iterate
    if convlog is not None:
        K.update_vars(x)
        objval = sum([fn.value for fn in prox_fns])
        convlog.record_objective(objval)
        convlog.record_timing(0.0)

    for i in range(max_iters):
        iter_timing.tic()
        if convlog is not None:
            convlog.tic()

        # Keep track of previous iterates
        np.copyto(prev_x, x)
        np.copyto(prev_z, z)
        np.copyto(prev_u, u)
        np.copyto(prev_Kx, Kx)

        # Compute z
        K.forward(xbar, Kxbar)
        z = y + sigma * Kxbar

        # Update y.
        offset = 0
        for fn in psi_fns:
            slc = slice(offset, offset + fn.lin_op.size, None)
            z_slc = np.reshape(z[slc], fn.lin_op.shape)
            # Moreau identity: apply and time prox.
            prox_log[fn].tic()
            y[slc] = (z_slc - sigma * fn.prox(sigma, z_slc / sigma, i)).flatten()
            prox_log[fn].toc()
            offset += fn.lin_op.size
        y[offset:] = 0

        # Update x
        K.adjoint(y, KTy)
        x -= tau * KTy
        if len(omega_fns) > 0:
            xtmp = np.reshape(x, omega_fns[0].lin_op.shape)
            x[:] = omega_fns[0].prox(1.0 / tau, xtmp, x_init=prev_x,
                                     lin_solver=lin_solver, options=lin_solver_options).flatten()

        # Update xbar
        np.copyto(xbar, x)
        xbar += theta * (x - prev_x)

        # Convergence log
        if convlog is not None:
            convlog.toc()
            K.update_vars(x)
            objval = sum([fn.value for fn in prox_fns])
            convlog.record_objective(objval)

        """ Old convergence check
        #Very basic convergence check.
        r_x = np.linalg.norm(x - prev_x)
        r_xbar = np.linalg.norm(xbar - prev_xbar)
        r_ybar = np.linalg.norm(y - prev_y)
        error = r_x + r_xbar + r_ybar
        """

        # Residual based convergence check
        K.forward(x, Kx)
        u = 1.0 / sigma * y + theta * (Kx - prev_Kx)
        z = prev_u + prev_Kx - 1.0 / sigma * y

        # Iteration order is different than
        # lin-admm (--> start checking at iteration 1)
        if i > 0:
            # Check convergence
            r = prev_Kx - z
            K.adjoint(sigma * (z - prev_z), s)
            eps_pri = np.sqrt(K.output_size) * eps_abs + \
                eps_rel * max([np.linalg.norm(prev_Kx), np.linalg.norm(z)])
            K.adjoint(u, KTu)
            eps_dual = np.sqrt(K.input_size) * eps_abs + \
                eps_rel * np.linalg.norm(KTu) / sigma

            # Progress
            if verbose > 0:
                # Evaluate objective only if required (expensive!)
                objstr = ''
                if verbose == 2:
                    K.update_vars(x)
                    objstr = ", obj_val = %02.03e" % sum([fn.value for fn in prox_fns])

                """ Old convergence check
                #Evaluate metric potentially
                metstr = '' if metric is None else ", {}".format(metric.message(x.copy()))
                print "iter [%04d]:" \
                      "||x - x_prev||_2 = %02.02e " \
                      "||xbar - xbar_prev||_2 = %02.02e " \
                      "||y - y_prev||_2 = %02.02e " \
                      "SUM = %02.02e (eps=%02.03e)%s%s" \
                      % (i, r_x, r_xbar, r_ybar, error, eps, objstr, metstr)
                """

                # Evaluate metric potentially
                metstr = '' if metric is None else ", {}".format(metric.message(x))
                print("iter %d: ||r||_2 = %.3f, eps_pri = %.3f, ||s||_2 = %.3f, eps_dual = %.3f%s%s" % (
                    i, np.linalg.norm(r), eps_pri, np.linalg.norm(s), eps_dual, objstr, metstr))

            iter_timing.toc()
            if np.linalg.norm(r) <= eps_pri and np.linalg.norm(s) <= eps_dual:
                break
        else:
            iter_timing.toc()

        """ Old convergence check
        if error <= eps:
            break
        """

    # Print out timings info.
    if verbose > 0:
        print(iter_timing)
        print("prox funcs:")
        print(prox_log)
        print("K forward ops:")
        print(K.forward_log)
        print("K adjoint ops:")
        print(K.adjoint_log)

    # Assign values to variables.
    K.update_vars(x)
    # Return optimal value.
    return sum([fn.value for fn in prox_fns])
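

# The y-update above relies on the Moreau identity
#     prox_{sigma f*}(z) = z - sigma * prox_{f / sigma}(z / sigma),
# which is why only prox_f is ever evaluated. Below is a small self-contained NumPy check
# of the identity for f = ||.||_1 (soft-thresholding vs. projection onto the l_inf ball).
# The helper name and the test setup are illustrative assumptions, not library code.
def _check_moreau_identity_l1(sigma=0.7, seed=0):
    import numpy as np
    rng = np.random.default_rng(seed)
    z = rng.standard_normal(10)
    # prox of (1/sigma)*||.||_1 at z/sigma, i.e. soft-thresholding with threshold 1/sigma.
    prox_f = np.sign(z / sigma) * np.maximum(np.abs(z / sigma) - 1.0 / sigma, 0.0)
    lhs = z - sigma * prox_f
    # prox of sigma*f*, where f* is the indicator of {||y||_inf <= 1}: a projection.
    rhs = np.clip(z, -1.0, 1.0)
    assert np.allclose(lhs, rhs)
    return lhs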


def solve(psi_fns, omega_fns, rho_0=1.0, rho_scale=math.sqrt(2.0) * 2.0, rho_max=2**8,
          max_iters=-1, max_inner_iters=100, x0=None,
          eps_rel=1e-3, eps_abs=1e-3,
          lin_solver="cg", lin_solver_options=None,
          try_diagonalize=True, scaled=False, try_fast_norm=False,
          metric=None, convlog=None, verbose=0):
    prox_fns = psi_fns + omega_fns
    stacked_ops = vstack([fn.lin_op for fn in psi_fns])
    K = CompGraph(stacked_ops)

    # Rescale so (1/2)||x - b||^2_2
    rescaling = np.sqrt(2.)
    quad_ops = []
    quad_weights = []
    const_terms = []
    for fn in omega_fns:
        fn = fn.absorb_params()
        quad_ops.append(scale(rescaling * fn.beta, fn.lin_op))
        quad_weights.append(rescaling * fn.beta)
        const_terms.append(fn.b.flatten() * rescaling)

    # Get optimized inverse (tries spatial and frequency diagonalization)
    op_list = [func.lin_op for func in psi_fns] + quad_ops
    stacked_ops = vstack(op_list)
    x_update = get_least_squares_inverse(op_list, None, try_diagonalize, verbose)

    # Initialize
    if x0 is not None:
        x = np.reshape(x0, K.input_size)
    else:
        x = np.zeros(K.input_size)

    Kx = np.zeros(K.output_size)
    w = Kx.copy()

    # Temporary iteration counts
    x_prev = x.copy()

    # Log for prox ops.
    prox_log = TimingsLog(prox_fns)
    # Time iterations.
    iter_timing = TimingsEntry("HQS iteration")
    inner_iter_timing = TimingsEntry("HQS inner iteration")

    # Convergence log for initial iterate
    if convlog is not None:
        K.update_vars(x)
        objval = sum([func.value for func in prox_fns])
        convlog.record_objective(objval)
        convlog.record_timing(0.0)

    # Rho schedule
    rho = rho_0
    i = 0
    while rho < rho_max and i < max_iters:
        iter_timing.tic()
        if convlog is not None:
            convlog.tic()

        # Update rho for quadratics
        for idx, op in enumerate(quad_ops):
            op.scalar = quad_weights[idx] / np.sqrt(rho)
        x_update = get_least_squares_inverse(op_list, CompGraph(stacked_ops),
                                             try_diagonalize, verbose)

        for ii in range(max_inner_iters):
            inner_iter_timing.tic()

            # Update Kx.
            K.forward(x, Kx)

            # Prox update to get w.
            offset = 0
            w_prev = w.copy()
            for fn in psi_fns:
                slc = slice(offset, offset + fn.lin_op.size, None)
                # Apply and time prox.
                prox_log[fn].tic()
                w[slc] = fn.prox(rho, np.reshape(Kx[slc], fn.lin_op.shape), ii).flatten()
                prox_log[fn].toc()
                offset += fn.lin_op.size

            # Update x.
            x_prev[:] = x
            tmp = np.hstack([w] + [cterm / np.sqrt(rho) for cterm in const_terms])
            x = x_update.solve(tmp, x_init=x, lin_solver=lin_solver, options=lin_solver_options)

            # Very basic convergence check.
            r_x = np.linalg.norm(x_prev - x)
            eps_x = eps_rel * np.prod(K.input_size)

            r_w = np.linalg.norm(w_prev - w)
            eps_w = eps_rel * np.prod(K.output_size)

            # Convergence log
            if convlog is not None:
                convlog.toc()
                K.update_vars(x)
                objval = sum([fn.value for fn in prox_fns])
                convlog.record_objective(objval)

            # Show progress
            if verbose > 0:
                # Evaluate objective only if required (expensive!)
                objstr = ''
                if verbose == 2:
                    K.update_vars(x)
                    objstr = ", obj_val = %02.03e" % sum([fn.value for fn in prox_fns])

                # Evaluate metric potentially
                metstr = '' if metric is None else ", {}".format(metric.message(x))
                print("iter [%02d (rho=%2.1e) || %02d]: "
                      "||w - w_prev||_2 = %02.02e (eps=%02.03e) "
                      "||x - x_prev||_2 = %02.02e (eps=%02.03e)%s%s"
                      % (i, rho, ii, r_w, eps_w, r_x, eps_x, objstr, metstr))

            inner_iter_timing.toc()
            if r_x < eps_x and r_w < eps_w:
                break

        # Update rho
        rho = np.minimum(rho * rho_scale, rho_max)
        i += 1
        iter_timing.toc()

    # Print out timings info.
    if verbose > 0:
        print(iter_timing)
        print(inner_iter_timing)
        print("prox funcs:")
        print(prox_log)
        print("K forward ops:")
        print(K.forward_log)
        print("K adjoint ops:")
        print(K.adjoint_log)

    # Assign values to variables.
    K.update_vars(x)
    # Return optimal value.
    return sum([fn.value for fn in prox_fns])
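

# A minimal, self-contained NumPy sketch of half-quadratic splitting as scheduled above:
# alternate a prox step in w with a least-squares step in x while increasing rho. The toy
# problem min_x ||D x||_1 + (1/2)||x - b||_2^2 and all names below are illustrative
# assumptions; the solver above builds the analogous least-squares system via CompGraph.
def _toy_hqs(D, b, rho_0=1.0, rho_scale=2.0 * 2.0 ** 0.5, rho_max=2 ** 8, inner_iters=20):
    import numpy as np
    m, n = D.shape
    x = b.astype(float).copy()
    rho = rho_0
    while rho < rho_max:
        for _ in range(inner_iters):
            # w-step: prox of ||.||_1 with weight rho at Dx, i.e. soft-thresholding by 1/rho.
            Dx = D.dot(x)
            w = np.sign(Dx) * np.maximum(np.abs(Dx) - 1.0 / rho, 0.0)
            # x-step: least squares  (I + rho D^T D) x = b + rho D^T w.
            x = np.linalg.solve(np.eye(n) + rho * D.T.dot(D), b + rho * D.T.dot(w))
        rho *= rho_scale
    return x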


def solve(psi_fns, omega_fns, tau=None, sigma=None, theta=None,
          max_iters=1000, eps_abs=1e-3, eps_rel=1e-3, x0=None,
          lin_solver="cg", lin_solver_options=None, conv_check=100,
          try_diagonalize=True, try_fast_norm=False, scaled=True,
          metric=None, convlog=None, verbose=0):
    # Can only have one omega function.
    assert len(omega_fns) <= 1
    prox_fns = psi_fns + omega_fns
    stacked_ops = vstack([fn.lin_op for fn in psi_fns])
    K = CompGraph(stacked_ops)

    # Select optimal parameters if wanted
    if tau is None or sigma is None or theta is None:
        tau, sigma, theta = est_params_pc(K, tau, sigma, verbose, scaled, try_fast_norm)

    # Initialize
    x = np.zeros(K.input_size)
    y = np.zeros(K.output_size)
    xbar = np.zeros(K.input_size)
    u = np.zeros(K.output_size)
    z = np.zeros(K.output_size)

    if x0 is not None:
        x[:] = np.reshape(x0, K.input_size)
        K.forward(x, y)
    xbar[:] = x

    # Buffers.
    Kxbar = np.zeros(K.output_size)
    Kx = np.zeros(K.output_size)
    KTy = np.zeros(K.input_size)
    KTu = np.zeros(K.input_size)
    s = np.zeros(K.input_size)

    prev_x = x.copy()
    prev_Kx = Kx.copy()
    prev_z = z.copy()
    prev_u = u.copy()

    # Log for prox ops.
    prox_log = TimingsLog(prox_fns)
    # Time iterations.
    iter_timing = TimingsEntry("PC iteration")

    # Convergence log for initial iterate
    if convlog is not None:
        K.update_vars(x)
        objval = sum([fn.value for fn in prox_fns])
        convlog.record_objective(objval)
        convlog.record_timing(0.0)

    for i in range(max_iters):
        iter_timing.tic()
        if convlog is not None:
            convlog.tic()

        # Keep track of previous iterates
        np.copyto(prev_x, x)
        np.copyto(prev_z, z)
        np.copyto(prev_u, u)
        np.copyto(prev_Kx, Kx)

        # Compute z
        K.forward(xbar, Kxbar)
        z = y + sigma * Kxbar

        # Update y.
        offset = 0
        for fn in psi_fns:
            slc = slice(offset, offset + fn.lin_op.size, None)
            z_slc = np.reshape(z[slc], fn.lin_op.shape)
            # Moreau identity: apply and time prox.
            prox_log[fn].tic()
            y[slc] = (z_slc - sigma * fn.prox(sigma, z_slc / sigma, i)).flatten()
            prox_log[fn].toc()
            offset += fn.lin_op.size
        y[offset:] = 0

        # Update x
        K.adjoint(y, KTy)
        x -= tau * KTy
        if len(omega_fns) > 0:
            xtmp = np.reshape(x, omega_fns[0].lin_op.shape)
            x[:] = omega_fns[0].prox(1.0 / tau, xtmp, x_init=prev_x,
                                     lin_solver=lin_solver, options=lin_solver_options).flatten()

        # Update xbar
        np.copyto(xbar, x)
        xbar += theta * (x - prev_x)

        # Convergence log
        if convlog is not None:
            convlog.toc()
            K.update_vars(x)
            objval = sum([fn.value for fn in prox_fns])
            convlog.record_objective(objval)

        """ Old convergence check
        #Very basic convergence check.
        r_x = np.linalg.norm(x - prev_x)
        r_xbar = np.linalg.norm(xbar - prev_xbar)
        r_ybar = np.linalg.norm(y - prev_y)
        error = r_x + r_xbar + r_ybar
        """

        # Residual based convergence check
        K.forward(x, Kx)
        u = 1.0 / sigma * y + theta * (Kx - prev_Kx)
        z = prev_u + prev_Kx - 1.0 / sigma * y

        # Iteration order is different than
        # lin-admm (--> start checking at iteration 1)
        if i > 0 and i % conv_check == 0:
            # Check convergence
            r = prev_Kx - z
            K.adjoint(sigma * (z - prev_z), s)
            eps_pri = np.sqrt(K.output_size) * eps_abs + \
                eps_rel * max([np.linalg.norm(prev_Kx), np.linalg.norm(z)])
            K.adjoint(u, KTu)
            eps_dual = np.sqrt(K.input_size) * eps_abs + \
                eps_rel * np.linalg.norm(KTu) / sigma

            # Progress
            if verbose > 0:
                # Evaluate objective only if required (expensive!)
                objstr = ''
                if verbose == 2:
                    K.update_vars(x)
                    objstr = ", obj_val = %02.03e" % sum([fn.value for fn in prox_fns])

                """ Old convergence check
                #Evaluate metric potentially
                metstr = '' if metric is None else ", {}".format(metric.message(x.copy()))
                print "iter [%04d]:" \
                      "||x - x_prev||_2 = %02.02e " \
                      "||xbar - xbar_prev||_2 = %02.02e " \
                      "||y - y_prev||_2 = %02.02e " \
                      "SUM = %02.02e (eps=%02.03e)%s%s" \
                      % (i, r_x, r_xbar, r_ybar, error, eps, objstr, metstr)
                """

                # Evaluate metric potentially
                metstr = '' if metric is None else ", {}".format(metric.message(x))
                print("iter %d: ||r||_2 = %.3f, eps_pri = %.3f, ||s||_2 = %.3f, eps_dual = %.3f%s%s" % (
                    i, np.linalg.norm(r), eps_pri, np.linalg.norm(s), eps_dual, objstr, metstr))

            iter_timing.toc()
            if np.linalg.norm(r) <= eps_pri and np.linalg.norm(s) <= eps_dual:
                break
        else:
            iter_timing.toc()

        """ Old convergence check
        if error <= eps:
            break
        """

    # Print out timings info.
    if verbose > 0:
        print(iter_timing)
        print("prox funcs:")
        print(prox_log)
        print("K forward ops:")
        print(K.forward_log)
        print("K adjoint ops:")
        print(K.adjoint_log)

    # Assign values to variables.
    K.update_vars(x)
    # Return optimal value.
    return sum([fn.value for fn in prox_fns])
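

# A minimal, self-contained NumPy sketch of the primal-dual (Pock-Chambolle) iteration
# used above, on the toy problem min_x ||K x - b||_1 + (alpha/2)||x||_2^2. The helper
# name, problem choice, and the step-size rule tau * sigma * ||K||_2^2 < 1 are
# illustrative assumptions; the solver above evaluates prox_{sigma f*} through the
# Moreau identity instead of in closed form.
def _toy_pock_chambolle(K, b, alpha=0.1, iters=500, theta=1.0):
    import numpy as np
    L = np.linalg.norm(K, 2)
    tau = sigma = 0.9 / L
    m, n = K.shape
    x = np.zeros(n)
    xbar = x.copy()
    y = np.zeros(m)
    for _ in range(iters):
        # Dual step: prox of sigma*f* for f = ||. - b||_1 is a shifted l_inf projection.
        y = np.clip(y + sigma * K.dot(xbar) - sigma * b, -1.0, 1.0)
        # Primal step: prox of tau*(alpha/2)||.||^2 is a simple shrinkage.
        x_prev = x
        x = (x - tau * K.T.dot(y)) / (1.0 + tau * alpha)
        # Over-relaxation.
        xbar = x + theta * (x - x_prev)
    return x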


def solve(psi_fns, omega_fns, rho=1.0,
          max_iters=1000, eps_abs=1e-3, eps_rel=1e-3, x0=None,
          lin_solver="cg", lin_solver_options=None,
          try_diagonalize=True, try_fast_norm=False, scaled=True,
          conv_check=100, implem=None,
          metric=None, convlog=None, verbose=0):
    prox_fns = psi_fns + omega_fns
    stacked_ops = vstack([fn.lin_op for fn in psi_fns])
    K = CompGraph(stacked_ops)

    # Rescale so (rho/2)||x - b||^2_2
    rescaling = np.sqrt(2. / rho)
    quad_ops = []
    const_terms = []
    for fn in omega_fns:
        fn = fn.absorb_params()
        quad_ops.append(scale(rescaling * fn.beta, fn.lin_op))
        const_terms.append(fn.b.flatten() * rescaling)

    # Check for fast inverse.
    op_list = [func.lin_op for func in psi_fns] + quad_ops
    stacked_ops = vstack(op_list)

    # Get optimized inverse (tries spatial and frequency diagonalization)
    v_update = get_least_squares_inverse(op_list, None, try_diagonalize, verbose)

    # Initialize everything to zero.
    v = np.zeros(K.input_size)
    z = np.zeros(K.output_size)
    u = np.zeros(K.output_size)

    # Initialize
    if x0 is not None:
        v[:] = np.reshape(x0, K.input_size)
        K.forward(v, z)

    # Buffers.
    Kv = np.zeros(K.output_size)
    KTu = np.zeros(K.input_size)
    s = np.zeros(K.input_size)

    # Log for prox ops.
    prox_log = TimingsLog(prox_fns)
    # Time iterations.
    iter_timing = TimingsEntry("ADMM iteration")

    # Convergence log for initial iterate
    if convlog is not None:
        K.update_vars(v)
        objval = sum([func.value for func in prox_fns])
        convlog.record_objective(objval)
        convlog.record_timing(0.0)

    for i in range(max_iters):
        iter_timing.tic()
        if convlog is not None:
            convlog.tic()

        z_prev = z.copy()

        # Update v.
        tmp = np.hstack([z - u] + const_terms)
        v = v_update.solve(tmp, x_init=v, lin_solver=lin_solver, options=lin_solver_options)

        # Update z.
        K.forward(v, Kv)
        Kv_u = Kv + u
        offset = 0
        for fn in psi_fns:
            slc = slice(offset, offset + fn.lin_op.size, None)
            Kv_u_slc = np.reshape(Kv_u[slc], fn.lin_op.shape)
            # Apply and time prox.
            prox_log[fn].tic()
            z[slc] = fn.prox(rho, Kv_u_slc, i).flatten()
            prox_log[fn].toc()
            offset += fn.lin_op.size

        # Update u.
        ne.evaluate('u + Kv - z', out=u)

        # Check convergence.
        if i % conv_check == 0:
            r = Kv - z
            K.adjoint(u, KTu)
            K.adjoint(rho * (z - z_prev), s)
            eps_pri = np.sqrt(K.output_size) * eps_abs + \
                eps_rel * max([np.linalg.norm(Kv), np.linalg.norm(z)])
            eps_dual = np.sqrt(K.input_size) * eps_abs + \
                eps_rel * np.linalg.norm(KTu) * rho

        # Convergence log
        if convlog is not None:
            convlog.toc()
            K.update_vars(v)
            objval = sum([fn.value for fn in prox_fns])
            convlog.record_objective(objval)

        # Show progress
        if verbose > 0 and i % conv_check == 0:
            # Evaluate objective only if required (expensive!)
            objstr = ''
            if verbose == 2:
                K.update_vars(v)
                objstr = ", obj_val = %02.03e" % sum([fn.value for fn in prox_fns])

            # Evaluate metric potentially
            metstr = '' if metric is None else ", {}".format(metric.message(v))
            print("iter %d: ||r||_2 = %.3f, eps_pri = %.3f, ||s||_2 = %.3f, eps_dual = %.3f%s%s" % (
                i, np.linalg.norm(r), eps_pri, np.linalg.norm(s), eps_dual, objstr, metstr))

        iter_timing.toc()
        # Exit if converged.
        if np.linalg.norm(r) <= eps_pri and np.linalg.norm(s) <= eps_dual:
            break

    # Print out timings info.
    if verbose > 0:
        print(iter_timing)
        print("prox funcs:")
        print(prox_log)
        print("K forward ops:")
        print(K.forward_log)
        print("K adjoint ops:")
        print(K.adjoint_log)

    # Assign values to variables.
    K.update_vars(v)
    # Return optimal value.
    return sum([fn.value for fn in prox_fns])
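

# A minimal, self-contained NumPy sketch of the ADMM iteration above, on the toy problem
# min_v ||D v - b||_1 + (1/2)||v - c||_2^2, where the v-update is the same kind of stacked
# least-squares solve that v_update.solve() performs. All names below are illustrative
# assumptions and not part of the solver API.
def _toy_admm(D, b, c, rho=1.0, iters=300):
    import numpy as np
    m, n = D.shape
    v = c.astype(float).copy()
    z = D.dot(v)
    u = np.zeros(m)
    # Normal-equation matrix of the rho-weighted stacked system [sqrt(rho) D; I].
    A = rho * D.T.dot(D) + np.eye(n)
    for _ in range(iters):
        # v-step: least squares  (I + rho D^T D) v = c + rho D^T (z - u).
        v = np.linalg.solve(A, c + rho * D.T.dot(z - u))
        # z-step: prox of ||. - b||_1 with weight rho, i.e. soft-thresholding around b by 1/rho.
        w = D.dot(v) + u
        z = b + np.sign(w - b) * np.maximum(np.abs(w - b) - 1.0 / rho, 0.0)
        # Scaled dual update (the solver above evaluates this expression with numexpr).
        u = u + D.dot(v) - z
    return v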


def solve(psi_fns, omega_fns, rho=1.0,
          max_iters=1000, eps_abs=1e-3, eps_rel=1e-3, x0=None,
          lin_solver="cg", lin_solver_options=None,
          try_diagonalize=True, try_fast_norm=False, scaled=True,
          conv_check=100,
          metric=None, convlog=None, verbose=0):
    prox_fns = psi_fns + omega_fns
    stacked_ops = vstack([fn.lin_op for fn in psi_fns])
    K = CompGraph(stacked_ops)

    # Rescale so (rho/2)||x - b||^2_2
    rescaling = np.sqrt(2. / rho)
    quad_ops = []
    const_terms = []
    for fn in omega_fns:
        fn = fn.absorb_params()
        quad_ops.append(scale(rescaling * fn.beta, fn.lin_op))
        const_terms.append(fn.b.flatten() * rescaling)

    # Check for fast inverse.
    op_list = [func.lin_op for func in psi_fns] + quad_ops
    stacked_ops = vstack(op_list)

    # Get optimized inverse (tries spatial and frequency diagonalization)
    v_update = get_least_squares_inverse(op_list, None, try_diagonalize, verbose)

    # Initialize everything to zero.
    v = np.zeros(K.input_size)
    z = np.zeros(K.output_size)
    u = np.zeros(K.output_size)

    # Initialize
    if x0 is not None:
        v[:] = np.reshape(x0, K.input_size)
        K.forward(v, z)

    # Buffers.
    Kv = np.zeros(K.output_size)
    KTu = np.zeros(K.input_size)
    s = np.zeros(K.input_size)

    # Log for prox ops.
    prox_log = TimingsLog(prox_fns)
    # Time iterations.
    iter_timing = TimingsEntry("ADMM iteration")

    # Convergence log for initial iterate
    if convlog is not None:
        K.update_vars(v)
        objval = sum([func.value for func in prox_fns])
        convlog.record_objective(objval)
        convlog.record_timing(0.0)

    for i in range(max_iters):
        iter_timing.tic()
        if convlog is not None:
            convlog.tic()

        z_prev = z.copy()

        # Update v.
        tmp = np.hstack([z - u] + const_terms)
        v = v_update.solve(tmp, x_init=v, lin_solver=lin_solver, options=lin_solver_options)

        # Update z.
        K.forward(v, Kv)
        Kv_u = Kv + u
        offset = 0
        for fn in psi_fns:
            slc = slice(offset, offset + fn.lin_op.size, None)
            Kv_u_slc = np.reshape(Kv_u[slc], fn.lin_op.shape)
            # Apply and time prox.
            prox_log[fn].tic()
            z[slc] = fn.prox(rho, Kv_u_slc, i).flatten()
            prox_log[fn].toc()
            offset += fn.lin_op.size

        # Update u.
        u += Kv - z

        # Check convergence.
        if i % conv_check == 0:
            r = Kv - z
            K.adjoint(u, KTu)
            K.adjoint(rho * (z - z_prev), s)
            eps_pri = np.sqrt(K.output_size) * eps_abs + \
                eps_rel * max([np.linalg.norm(Kv), np.linalg.norm(z)])
            eps_dual = np.sqrt(K.input_size) * eps_abs + \
                eps_rel * np.linalg.norm(KTu) * rho

        # Convergence log
        if convlog is not None:
            convlog.toc()
            K.update_vars(v)
            objval = sum([fn.value for fn in prox_fns])
            convlog.record_objective(objval)

        # Show progress
        if verbose > 0 and i % conv_check == 0:
            # Evaluate objective only if required (expensive!)
            objstr = ''
            if verbose == 2:
                K.update_vars(v)
                objstr = ", obj_val = %02.03e" % sum([fn.value for fn in prox_fns])

            # Evaluate metric potentially
            metstr = '' if metric is None else ", {}".format(metric.message(v))
            print("iter %d: ||r||_2 = %.3f, eps_pri = %.3f, ||s||_2 = %.3f, eps_dual = %.3f%s%s" % (
                i, np.linalg.norm(r), eps_pri, np.linalg.norm(s), eps_dual, objstr, metstr))

        iter_timing.toc()
        # Exit if converged.
        if np.linalg.norm(r) <= eps_pri and np.linalg.norm(s) <= eps_dual:
            break

    # Print out timings info.
    if verbose > 0:
        print(iter_timing)
        print("prox funcs:")
        print(prox_log)
        print("K forward ops:")
        print(K.forward_log)
        print("K adjoint ops:")
        print(K.adjoint_log)

    # Assign values to variables.
    K.update_vars(v)
    # Return optimal value.
    return sum([fn.value for fn in prox_fns])
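

# The convergence test in the ADMM solvers above follows the standard residual criterion
# of Boyd et al., "Distributed Optimization and Statistical Learning via the Alternating
# Direction Method of Multipliers", Sec. 3.3: stop when the primal residual r = Kv - z and
# the dual residual s = rho * K^T (z - z_prev) fall below tolerances combining an absolute
# and a relative term. A small self-contained sketch with explicit matrices; the helper
# name and arguments are illustrative assumptions.
def _toy_admm_stopping(K, v, z, z_prev, u, rho, eps_abs=1e-3, eps_rel=1e-3):
    import numpy as np
    r = K.dot(v) - z                   # primal residual
    s = rho * K.T.dot(z - z_prev)      # dual residual
    eps_pri = np.sqrt(K.shape[0]) * eps_abs + \
        eps_rel * max(np.linalg.norm(K.dot(v)), np.linalg.norm(z))
    eps_dual = np.sqrt(K.shape[1]) * eps_abs + \
        eps_rel * rho * np.linalg.norm(K.T.dot(u))
    return np.linalg.norm(r) <= eps_pri and np.linalg.norm(s) <= eps_dual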