def test_proximal_convconj_kl_simple_space():
    """Test for proximal factory for the convex conjugate of KL divergence."""
    # Step size for the proximal operator and the KL weight
    sigma = 0.25
    lam = 2

    # Discretized space on [0, 1] with 10 cells
    discr_space = odl.uniform_discr(0, 1, 10)

    # Evaluation point (values -5 .. 4) and data (values 10 .. 1)
    point = discr_space.element(np.arange(-5, 5))
    data = discr_space.element(np.arange(10, 0, -1))

    # Build the factory, then instantiate the proximal of F^* for sigma
    make_prox = proximal_convex_conj_kl(discr_space, lam=lam, g=data)
    prox_op = make_prox(sigma)
    assert isinstance(prox_op, odl.Operator)

    # Evaluate into a pre-allocated output element
    result = discr_space.element()
    prox_op(point, result)

    # Closed-form reference value
    discriminant = (point - lam)**2 + 4 * lam * sigma * data
    expected = (lam + point - np.sqrt(discriminant)) / 2

    assert all_almost_equal(result, expected, HIGH_ACC)
def test_proximal_convconj_kl_product_space():
    """Test for product spaces in proximal for conjugate of KL divergence.

    Builds a two-component ``ProductSpace`` over a 10-cell uniform
    discretization of [0, 1], applies the proximal operator of the convex
    conjugate of the KL divergence, and compares the result with the same
    closed-form expression used in the single-space test.
    """
    # Product space for matrix of operators
    op_domain = odl.ProductSpace(odl.uniform_discr(0, 1, 10), 2)

    # Element in the product space where the proximal operator is evaluated
    x0_arr = np.arange(-5, 5)
    x1_arr = np.arange(10, 0, -1)
    x = op_domain.element([x0_arr, x1_arr])

    # Data element: the two components of ``x`` swapped.
    # NOTE: the second data component contains negative values; for these
    # particular inputs the discriminant (x - lam)**2 + 4*lam*sigma*g is
    # still strictly positive in every entry, so the reference is finite.
    g0_arr = x1_arr.copy()
    g1_arr = x0_arr.copy()
    g = op_domain.element([g0_arr, g1_arr])

    # Factory function returning the proximal operator
    lam = 2
    prox_factory = proximal_convex_conj_kl(op_domain, lam=lam, g=g)

    # Initialize the proximal operator
    sigma = 0.25
    prox = prox_factory(sigma)
    assert isinstance(prox, odl.Operator)

    # Allocate an output element
    x_opt = op_domain.element()

    # Apply the proximal operator returning its optimal point
    prox(x, x_opt)

    # Explicit computation (same formula as the single-space test)
    x_verify = (lam + x - np.sqrt((x - lam)**2 + 4 * lam * sigma * g)) / 2

    # Compare the whole product-space element with the same accuracy as
    # the single-space test; the default tolerance was needlessly loose.
    assert all_almost_equal(x_opt, x_verify, HIGH_ACC)