Exemple #1
0
    def test_admm_two_prox_fn(self):
        """Check ADMM on an l1 + least-squares objective against CVXPY.

        The problem is solved three times: with both functions passed as
        prox functions, with the split chosen by ``admm.partition``, and
        with the quadratic term passed explicitly as the second argument.
        """
        tol = {'eps_rel': 1e-5, 'eps_abs': 1e-5}
        X = px.Variable((10, 5))
        B = 1. * np.arange(50).reshape((10, 5))

        fns = [px.norm1(X), px.sum_squares(X, b=B)]
        admm_cost = admm.solve(fns, [], 1.0, **tol)

        # Reference solution computed with CVXPY.
        ref_X = cvx.Variable((10, 5))
        ref_cost = cvx.sum_squares(ref_X - B) + cvx.norm(ref_X, 1)
        ref_prob = cvx.Problem(cvx.Minimize(ref_cost))
        ref_prob.solve()
        # This first comparison uses a looser element tolerance than the rest.
        self.assertItemsAlmostEqual(X.value, ref_X.value, eps=2e-1)
        self.assertAlmostEqual(admm_cost, ref_prob.value)

        # Same problem, split automatically.
        psi, omega = admm.partition(fns)
        admm_cost = admm.solve(psi, omega, 1.0, **tol)
        self.assertItemsAlmostEqual(X.value, ref_X.value, eps=2e-2)
        self.assertAlmostEqual(admm_cost, ref_prob.value)

        # Same problem, split by hand.
        l1_terms = [px.norm1(X)]
        quad_terms = [px.sum_squares(X, b=B)]
        admm_cost = admm.solve(l1_terms, quad_terms, 1.0, **tol)
        self.assertItemsAlmostEqual(X.value, ref_X.value, eps=2e-2)
        self.assertAlmostEqual(admm_cost, ref_prob.value)
Exemple #2
0
    def test_admm_extra_param(self):
        """Test ADMM with extra ``px.sum_squares`` parameters and linear operators.

        Every ADMM result is compared against a CVXPY reference problem.
        """
        X = px.Variable((10, 5))
        B = np.reshape(np.arange(50), (10, 5)) * 1.

        prox_fns = [px.norm1(X)]
        quad_funcs = [px.sum_squares(X, b=B, alpha=0.1, beta=2., gamma=1, c=B)]
        sltn = admm.solve(prox_fns,
                          quad_funcs,
                          1.0,
                          eps_rel=1e-5,
                          eps_abs=1e-5)

        # CVXPY reference: the cost the parametrized problem above is
        # expected to match.
        cvx_X = cvx.Variable((10, 5))
        cost = 0.1 * cvx.sum_squares(2 * cvx_X - B) + cvx.sum_squares(cvx_X) + \
            cvx.norm(cvx_X, 1) + cvx.trace(B.T @ cvx_X)
        prob = cvx.Problem(cvx.Minimize(cost))
        prob.solve()
        self.assertItemsAlmostEqual(X.value, cvx_X.value, eps=2e-2)
        self.assertAlmostEqual(sltn, prob.value, eps=1e-3)

        # Same check with the offset written into the linear operator and
        # then folded back into the prox function by ``absorb_offset``.
        prox_fns = [px.norm1(X)]
        quad_funcs = [px.sum_squares(X - B, alpha=0.1, beta=2., gamma=1, c=B)]
        quad_funcs[0] = absorb_offset(quad_funcs[0])
        sltn = admm.solve(prox_fns,
                          quad_funcs,
                          1.0,
                          eps_rel=1e-5,
                          eps_abs=1e-5)
        self.assertItemsAlmostEqual(X.value, cvx_X.value, eps=2e-2)
        self.assertAlmostEqual(sltn, prob.value, eps=1e-3)

        # With linear operators.
        kernel = np.array([1, 2, 3])
        x = px.Variable(3)
        b = np.array([-41, 413, 2])
        prox_fns = [px.nonneg(x), px.sum_squares(px.conv(kernel, x), b=b)]
        sltn = admm.solve(prox_fns, [], 1.0, eps_abs=1e-5, eps_rel=1e-5)

        # Dense matrix used as the CVXPY counterpart of ``px.conv(kernel, x)``.
        kernel_mat = np.array([[2, 1, 3], [3, 2, 1], [1, 3, 2]])
        cvx_X = cvx.Variable(3)
        cost = cvx.norm(kernel_mat @ cvx_X - b)
        prob = cvx.Problem(cvx.Minimize(cost), [cvx_X >= 0])
        prob.solve()
        self.assertItemsAlmostEqual(x.value, cvx_X.value, eps=2e-2)
        # CVXPY minimizes the (unsquared) norm, hence the square root.
        self.assertAlmostEqual(np.sqrt(sltn), prob.value, eps=2e-2)

        prox_fns = [px.nonneg(x)]
        quad_funcs = [px.sum_squares(px.conv(kernel, x), b=b)]
        sltn = admm.solve(prox_fns,
                          quad_funcs,
                          1.0,
                          eps_abs=1e-5,
                          eps_rel=1e-5)
        self.assertItemsAlmostEqual(x.value, cvx_X.value, eps=2e-2)
        self.assertAlmostEqual(np.sqrt(sltn), prob.value, eps=2e-2)
Exemple #3
0
def testTotalVariation():
    """Evaluate the prox of an l1-norm-of-gradient function on a test image."""
    # Load the sample image and append a trailing singleton axis.
    image = scipy.misc.ascent()
    image = image[:, :, np.newaxis].astype(np.uint8)

    tv_var = proximal.Variable(image.shape)
    tv_fn = proximal.norm1(proximal.grad(tv_var))
    thresh = 6

    out = tv_fn.prox(thresh, image)
Exemple #4
0
    def test_lin_admm_two_prox_fn(self):
        """Test linearized ADMM with two prox functions against CVXPY.

        Covers an l1 + least-squares problem and a non-negative least-squares
        problem with a convolution operator, each solved both unsplit and
        split by ``ladmm.partition``.
        """
        X = px.Variable((10, 5))
        B = np.reshape(np.arange(50), (10, 5))
        prox_fns = [px.norm1(X), px.sum_squares(X, b=B)]
        sltn = ladmm.solve(prox_fns, [],
                           0.1,
                           max_iters=500,
                           eps_rel=1e-5,
                           eps_abs=1e-5)

        # CVXPY reference solution.
        cvx_X = cvx.Variable((10, 5))
        cost = cvx.sum_squares(cvx_X - B) + cvx.norm(cvx_X, 1)
        prob = cvx.Problem(cvx.Minimize(cost))
        prob.solve()
        self.assertItemsAlmostEqual(X.value, cvx_X.value, eps=2e-2)
        self.assertAlmostEqual(sltn, prob.value)

        psi_fns, omega_fns = ladmm.partition(prox_fns)
        sltn = ladmm.solve(psi_fns,
                           omega_fns,
                           0.1,
                           max_iters=500,
                           eps_rel=1e-5,
                           eps_abs=1e-5)
        self.assertItemsAlmostEqual(X.value, cvx_X.value, eps=2e-2)
        self.assertAlmostEqual(sltn, prob.value)

        # With linear operators.
        kernel = np.array([1, 2, 3])
        # Dense matrix used as the CVXPY counterpart of ``px.conv(kernel, x)``.
        kernel_mat = np.array([[2, 1, 3], [3, 2, 1], [1, 3, 2]])
        x = px.Variable(3)
        b = np.array([-41, 413, 2])
        prox_fns = [px.nonneg(x), px.sum_squares(px.conv(kernel, x), b=b)]
        # Only the minimizer is checked here, so the objective value
        # returned by the solver is not kept.
        ladmm.solve(prox_fns, [],
                    0.1,
                    max_iters=3000,
                    eps_abs=1e-5,
                    eps_rel=1e-5)

        cvx_X = cvx.Variable(3)
        cost = cvx.norm(kernel_mat @ cvx_X - b)
        prob = cvx.Problem(cvx.Minimize(cost), [cvx_X >= 0])
        prob.solve()
        self.assertItemsAlmostEqual(x.value, cvx_X.value, eps=2e-2)

        psi_fns, omega_fns = ladmm.partition(prox_fns)
        ladmm.solve(psi_fns,
                    omega_fns,
                    0.1,
                    max_iters=3000,
                    eps_abs=1e-5,
                    eps_rel=1e-5)
        self.assertItemsAlmostEqual(x.value, cvx_X.value, eps=2e-2)
Exemple #5
0
def testL1NormProx():
    """Check that FlexISP's L1 prox agrees with ProxImaL's ``norm1`` prox."""
    isp = FlexISP()

    samples = np.arange(-12, 12, dtype=float)
    thresh = 6

    expected = isp.L1_norm_prox(samples, thresh)

    # ProxImaL is called with the reciprocal of the threshold given to FlexISP.
    l1_fn = proximal.norm1(proximal.Variable(samples.shape))
    actual = l1_fn.prox(1 / thresh, samples)

    np.testing.assert_almost_equal(actual, expected)
Exemple #6
0
    def test_admm_single_prox_fn(self):
        """Run ADMM on problems that consist of one prox function only."""
        X = px.Variable((10, 5))
        B = 1. * np.arange(50).reshape((10, 5))

        # Least squares alone: expect X == B and a zero objective.
        lsq_only = [px.sum_squares(X, b=B)]
        objective = admm.solve(lsq_only, [], 1.0, eps_abs=1e-4, eps_rel=1e-4)
        self.assertItemsAlmostEqual(X.value, B, eps=2e-2)
        self.assertAlmostEqual(objective, 0)

        # l1 norm with beta=2 and offset B: expect X == B / 2.
        l1_only = [px.norm1(X, b=B, beta=2)]
        objective = admm.solve(l1_only, [], 1.0)
        self.assertItemsAlmostEqual(X.value, B / 2., eps=2e-2)
        self.assertAlmostEqual(objective, 0)
def solver(f, kernel_img, metric, cnn_func, elemental):
    """
    Reconstruct a deblurred image from a corrupted input and its blur kernel.

    The objective is a weighted least-squares data term on the convolved
    variable, an optional l1 penalty on the gradient (enabled when
    ``elemental['alpha_tv'] > 0``) and the terms returned by
    ``init_denoising_prior``.

    :param f: Corrupted input image
    :type f: np.ndarray
    :param kernel_img: Blur kernel
    :type kernel_img: np.ndarray
    :param metric: Preinitialized metric
    :type metric: proximal.utils.metrics
    :param cnn_func: Preinitialized deployment CNN
    :type cnn_func: function
    :param elemental: General experiment configuration parameters; the keys
        ``alpha_data``, ``alpha_tv``, ``sigma`` and ``sigma_scale`` are read
    :type elemental: Dict

    :returns: Reconstructed output image, clipped to ``[0, 1]``
    :rtype: np.ndarray
    """
    # pylint:disable=no-value-for-parameter
    cg_opts = px.cg_options(tol=1e-4, num_iters=100, verbose=True)

    u = px.Variable(f.shape)

    data_weight = elemental['alpha_data'] / 2.0
    blurred_u = px.conv(kernel_img, u)

    # Start with the data term; regularizers are appended with ``+=``.
    prox_fns = px.sum_squares(blurred_u - f, alpha=data_weight)

    tv_weight = elemental['alpha_tv']
    if tv_weight > 0.0:
        prox_fns += px.norm1(tv_weight * px.grad(u))

    prox_fns += init_denoising_prior(
        u,
        cnn_func,
        sigma=elemental['sigma'],
        sigma_scale=elemental['sigma_scale'],
    )

    problem = init_problem(prox_fns)
    # The solver is started from a copy of the corrupted input.
    solve_problem(
        problem,
        x0=f.copy(),
        metric=metric,
        sigma=elemental['sigma'],
        lin_solver_options=cg_opts,
    )

    return np.clip(u.value, 0.0, 1.0)
Exemple #8
0
    def test_half_quadratic_splitting_single_prox_fn(self):
        """Run half quadratic splitting on single-prox-function problems."""
        hqs_args = {'eps_rel': 1e-4, 'max_iters': 100, 'max_inner_iters': 50}
        X = px.Variable((10, 5))
        B = np.arange(50).reshape((10, 5))

        # Least squares alone: expect X == B and a zero objective.
        lsq_only = [px.sum_squares(X, b=B)]
        objective = hqs.solve(lsq_only, [], **hqs_args)
        self.assertItemsAlmostEqual(X.value, B, eps=2e-2)
        self.assertAlmostEqual(objective, 0)

        # l1 norm with beta=2 and offset B: expect X == B / 2.
        l1_only = [px.norm1(X, b=B, beta=2)]
        objective = hqs.solve(l1_only, [], **hqs_args)
        self.assertItemsAlmostEqual(X.value, B / 2., eps=2e-2)
        self.assertAlmostEqual(objective, 0)
Exemple #9
0
    def test_lin_admm_single_prox_fn(self):
        """Run linearized ADMM on single-prox-function problems."""
        solver_args = {'max_iters': 500, 'eps_rel': 1e-5, 'eps_abs': 1e-5}
        X = px.Variable((10, 5))
        B = np.arange(50).reshape((10, 5))

        # Least squares alone: expect X == B and a zero objective.
        lsq_only = [px.sum_squares(X, b=B)]
        objective = ladmm.solve(lsq_only, [], 0.1, **solver_args)
        self.assertItemsAlmostEqual(X.value, B, eps=2e-2)
        self.assertAlmostEqual(objective, 0)

        # l1 norm with beta=2 and offset B: expect X == B / 2.
        l1_only = [px.norm1(X, b=B, beta=2)]
        objective = ladmm.solve(l1_only, [], 0.1, **solver_args)
        self.assertItemsAlmostEqual(X.value, B / 2., eps=2e-2)
        self.assertAlmostEqual(objective, 0, eps=2e-2)
Exemple #10
0
# Example script: recover an image from a noisy Laplacian of it by solving a
# least-squares + l1-of-gradient problem with ProxImaL, using an ODL operator.

# Discretized space the phantom and the Laplacian are defined on; the last
# argument is also used below as the shape of the ProxImaL variable.
space = odl.uniform_discr([0, 0], [100, 100], [100, 100])

# Create ODL operator for the Laplacian
laplacian = odl.Laplacian(space)

# Create right hand side: the Laplacian of a Shepp-Logan phantom plus white
# noise scaled to 10% of the standard deviation of the clean right hand side.
phantom = odl.phantom.shepp_logan(space, modified=True)
phantom.show('original image')
rhs = laplacian(phantom)
rhs += odl.phantom.white_noise(space) * np.std(rhs) * 0.1
rhs.show('rhs')

# Convert laplacian to ProxImaL operator
proximal_lang_laplacian = odl.as_proximal_lang_operator(laplacian)

# Convert to array (the ProxImaL expression below works on plain arrays)
rhs_arr = rhs.asarray()

# Set up optimization problem: weighted data-fit term on the Laplacian of x
# plus an l1 penalty on the gradient of x.
x = proximal.Variable(space.shape)
funcs = [10 * proximal.sum_squares(proximal_lang_laplacian(x) - rhs_arr),
         proximal.norm1(proximal.grad(x))]

# Solve the problem using ProxImaL
prob = proximal.Problem(funcs)
prob.solve(verbose=True)

# Convert back to odl and display result
result_odl = space.element(x.value)
result_odl.show('result from ProxImaL')
Exemple #11
0
# Example script: recover an image from a noisy Laplacian of it by solving a
# least-squares + l1-of-gradient problem with ProxImaL, using an ODL operator.
# NOTE(review): `space` is not defined in this snippet; presumably it is
# created earlier in the original script -- confirm.

# Create ODL operator for the Laplacian
laplacian = odl.Laplacian(space)

# Create right hand side: the Laplacian of a Shepp-Logan phantom plus white
# noise scaled to 10% of the standard deviation of the clean right hand side.
phantom = odl.phantom.shepp_logan(space, modified=True)
phantom.show('original image')
rhs = laplacian(phantom)
rhs += odl.phantom.white_noise(space) * np.std(rhs) * 0.1
rhs.show('rhs')

# Convert laplacian to ProxImaL operator
proximal_lang_laplacian = odl.as_proximal_lang_operator(laplacian)

# Convert to array (the ProxImaL expression below works on plain arrays)
rhs_arr = rhs.asarray()

# Set up optimization problem: weighted data-fit term on the Laplacian of x
# plus an l1 penalty on the gradient of x.
x = proximal.Variable(space.shape)
funcs = [
    10 * proximal.sum_squares(proximal_lang_laplacian(x) - rhs_arr),
    proximal.norm1(proximal.grad(x))
]

# Solve the problem using ProxImaL
prob = proximal.Problem(funcs)
prob.solve(verbose=True)

# Convert back to odl and display result
result_odl = space.element(x.value)
result_odl.show('result from ProxImaL')
def solver(f, x0, metric, cnn_func, elemental):
    """
    Reconstruct a demosaicked image from a mosaicked input.

    The objective is a weighted least-squares data term on the Bayer-masked
    variable, an optional l1 penalty on the gradient
    (``elemental['alpha_tv'] > 0``), an optional l1 cross term coupling the
    gradients of ``u`` and ``x0`` (``elemental['alpha_cross'] > 0``) and the
    terms returned by ``init_denoising_prior``.

    :param f: Corrupted input image
    :type f: np.ndarray
    :param x0: Predemosaicked initialization image
    :type x0: np.ndarray
    :param metric: Preinitialized metric
    :type metric: proximal.utils.metrics
    :param cnn_func: Preinitialized deployment CNN
    :type cnn_func: function
    :param elemental: General experiment configuration parameters; the keys
        ``alpha_data``, ``alpha_tv``, ``alpha_cross``, ``sigma`` and
        ``sigma_scale`` are read
    :type elemental: Dict

    :returns: Reconstructed output image, clipped to ``[0, 1]``
    :rtype: np.ndarray
    """
    # pylint:disable=no-value-for-parameter
    cg_opts = px.cg_options(tol=1e-4, num_iters=100, verbose=True)

    u = px.Variable(f.shape)
    mask = bayer_mask(f.shape)
    masked_u = px.mul_elemwise(mask, u)

    data_weight = elemental['alpha_data'] / 2.0

    # Start with the data term; regularizers are appended with ``+=``.
    prox_fns = px.sum_squares(masked_u - f, alpha=data_weight)

    tv_weight = elemental['alpha_tv']
    if tv_weight > 0.0:
        prox_fns += px.norm1(tv_weight * px.grad(u, dims=2))

    cross_weight = elemental['alpha_cross']
    if cross_weight > 0.0:
        grad_u = px.grad(u, dims=2)
        grad_x0 = px.grad(x0, dims=2).value
        # Duplicate x0 and u along a trailing axis of length 2.
        stacked_shape = x0.shape + (2, )
        x0_stacked = np.array([x0, x0]).reshape(stacked_shape)
        u_stacked = px.reshape(px.hstack([u, u]), stacked_shape)
        # Each term pairs one image, rolled by 1 and by 2 along axis 2,
        # with the (gradient of the) other image.
        cross_1 = px.vstack([
            px.mul_elemwise(np.roll(x0_stacked, shift, 2), grad_u)
            for shift in (1, 2)
        ])
        cross_2 = px.vstack([
            px.mul_elemwise(np.roll(grad_x0, shift, 2), u_stacked)
            for shift in (1, 2)
        ])

        prox_fns += px.norm1(0.5 * cross_weight * (cross_1 - cross_2))

    prox_fns += init_denoising_prior(
        u,
        cnn_func,
        sigma=elemental['sigma'],
        sigma_scale=elemental['sigma_scale'],
    )

    problem = init_problem(prox_fns)
    solve_problem(
        problem,
        x0=x0,
        metric=metric,
        sigma=elemental['sigma'],
        lin_solver_options=cg_opts,
    )
    return np.clip(u.value, 0.0, 1.0)
Exemple #13
0
    def test_half_quadratic_splitting(self):
        """Test half quadratic splitting.

        Covers single-prox-function problems, an l1 + least-squares problem
        and a non-negative least-squares problem with a convolution operator.
        The latter two are checked against CVXPY, both unsplit and split by
        ``hqs.partition``.
        """
        X = px.Variable((10, 5))
        B = np.reshape(np.arange(50), (10, 5))
        prox_fns = [px.sum_squares(X, b=B)]
        sltn = hqs.solve(prox_fns, [],
                         eps_rel=1e-4,
                         max_iters=100,
                         max_inner_iters=50)
        self.assertItemsAlmostEqual(X.value, B, places=2)
        self.assertAlmostEqual(sltn, 0)

        prox_fns = [px.norm1(X, b=B, beta=2)]
        sltn = hqs.solve(prox_fns, [],
                         eps_rel=1e-4,
                         max_iters=100,
                         max_inner_iters=50)
        self.assertItemsAlmostEqual(X.value, B / 2., places=2)
        self.assertAlmostEqual(sltn, 0)

        prox_fns = [px.norm1(X), px.sum_squares(X, b=B)]
        sltn = hqs.solve(prox_fns, [],
                         eps_rel=1e-7,
                         rho_0=1.0,
                         rho_scale=np.sqrt(2.0) * 2.0,
                         rho_max=2**16,
                         max_iters=20,
                         max_inner_iters=500)

        # CVXPY reference solution. The shape is passed as a tuple; a second
        # positional argument is not a dimension in the shape-tuple API.
        cvx_X = cvx.Variable((10, 5))
        cost = cvx.sum_squares(cvx_X - B) + cvx.norm(cvx_X, 1)
        prob = cvx.Problem(cvx.Minimize(cost))
        prob.solve()
        self.assertItemsAlmostEqual(X.value, cvx_X.value, places=2)
        self.assertAlmostEqual(sltn, prob.value, places=3)

        psi_fns, omega_fns = hqs.partition(prox_fns)
        sltn = hqs.solve(psi_fns,
                         omega_fns,
                         eps_rel=1e-7,
                         rho_0=1.0,
                         rho_scale=np.sqrt(2.0) * 2.0,
                         rho_max=2**16,
                         max_iters=20,
                         max_inner_iters=500)
        self.assertItemsAlmostEqual(X.value, cvx_X.value, places=2)
        self.assertAlmostEqual(sltn, prob.value, places=3)

        # With linear operators.
        kernel = np.array([1, 2, 3])
        # Dense matrix used as the CVXPY counterpart of ``px.conv(kernel, x)``.
        kernel_mat = np.array([[2, 1, 3], [3, 2, 1], [1, 3, 2]])
        x = px.Variable(3)
        b = np.array([-41, 413, 2])
        prox_fns = [px.nonneg(x), px.sum_squares(px.conv(kernel, x), b=b)]
        hqs.solve(prox_fns, [],
                  eps_rel=1e-9,
                  rho_0=4,
                  rho_scale=np.sqrt(2.0) * 1.0,
                  rho_max=2**16,
                  max_iters=30,
                  max_inner_iters=500)

        cvx_X = cvx.Variable(3)
        cost = cvx.norm(kernel_mat @ cvx_X - b)
        prob = cvx.Problem(cvx.Minimize(cost), [cvx_X >= 0])
        prob.solve()
        # Only coarse agreement (places=0) is required for this problem.
        self.assertItemsAlmostEqual(x.value, cvx_X.value, places=0)

        psi_fns, omega_fns = hqs.partition(prox_fns)
        hqs.solve(psi_fns,
                  omega_fns,
                  eps_rel=1e-9,
                  rho_0=4,
                  rho_scale=np.sqrt(2.0) * 1.0,
                  rho_max=2**16,
                  max_iters=30,
                  max_inner_iters=500)
        self.assertItemsAlmostEqual(x.value, cvx_X.value, places=0)
Exemple #14
0
    def test_pock_chambolle(self):
        """Test pock chambolle algorithm.

        Covers single-prox-function problems, an l1 + least-squares problem
        and a non-negative least-squares problem with a convolution operator.
        The latter two are checked against CVXPY, both unsplit and split by
        ``pc.partition``.
        """
        X = px.Variable((10, 5))
        B = np.reshape(np.arange(50), (10, 5))
        prox_fns = [px.sum_squares(X, b=B)]
        sltn = pc.solve(prox_fns, [],
                        1.0,
                        1.0,
                        1.0,
                        eps_rel=1e-5,
                        eps_abs=1e-5)
        self.assertItemsAlmostEqual(X.value, B, places=2)
        self.assertAlmostEqual(sltn, 0)

        prox_fns = [px.norm1(X, b=B, beta=2)]
        sltn = pc.solve(prox_fns, [],
                        1.0,
                        1.0,
                        1.0,
                        eps_rel=1e-5,
                        eps_abs=1e-5)
        self.assertItemsAlmostEqual(X.value, B / 2., places=2)
        self.assertAlmostEqual(sltn, 0, places=2)

        prox_fns = [px.norm1(X), px.sum_squares(X, b=B)]
        sltn = pc.solve(prox_fns, [],
                        0.5,
                        1.0,
                        1.0,
                        eps_rel=1e-5,
                        eps_abs=1e-5)

        # CVXPY reference solution. The shape is passed as a tuple; a second
        # positional argument is not a dimension in the shape-tuple API.
        cvx_X = cvx.Variable((10, 5))
        cost = cvx.sum_squares(cvx_X - B) + cvx.norm(cvx_X, 1)
        prob = cvx.Problem(cvx.Minimize(cost))
        prob.solve()
        self.assertItemsAlmostEqual(X.value, cvx_X.value, places=2)
        self.assertAlmostEqual(sltn, prob.value)

        psi_fns, omega_fns = pc.partition(prox_fns)
        sltn = pc.solve(psi_fns,
                        omega_fns,
                        0.5,
                        1.0,
                        1.0,
                        eps_abs=1e-5,
                        eps_rel=1e-5)
        self.assertItemsAlmostEqual(X.value, cvx_X.value, places=2)
        self.assertAlmostEqual(sltn, prob.value)

        # With linear operators.
        kernel = np.array([1, 2, 3])
        # Dense matrix used as the CVXPY counterpart of ``px.conv(kernel, x)``.
        kernel_mat = np.array([[2, 1, 3], [3, 2, 1], [1, 3, 2]])
        x = px.Variable(3)
        b = np.array([-41, 413, 2])
        prox_fns = [px.nonneg(x), px.sum_squares(px.conv(kernel, x), b=b)]
        # Only the minimizer is checked here, so the objective value
        # returned by the solver is not kept.
        pc.solve(prox_fns, [],
                 0.1,
                 0.1,
                 1.0,
                 max_iters=3000,
                 eps_abs=1e-5,
                 eps_rel=1e-5)
        cvx_X = cvx.Variable(3)
        cost = cvx.norm(kernel_mat @ cvx_X - b)
        prob = cvx.Problem(cvx.Minimize(cost), [cvx_X >= 0])
        prob.solve()
        self.assertItemsAlmostEqual(x.value, cvx_X.value, places=2)

        psi_fns, omega_fns = pc.partition(prox_fns)
        pc.solve(psi_fns,
                 omega_fns,
                 0.1,
                 0.1,
                 1.0,
                 max_iters=3000,
                 eps_abs=1e-5,
                 eps_rel=1e-5)
        self.assertItemsAlmostEqual(x.value, cvx_X.value, places=2)
Exemple #15
0
    def _test_pock_chambolle(self, impl):
        """Test pock chambolle algorithm for one implementation.

        Covers single-prox-function problems, an l1 + least-squares problem,
        a non-negative least-squares problem with a convolution operator
        (the latter two checked against CVXPY) and a two-variable problem.

        :param impl: Implementation name; ``'pycuda'`` makes the solver calls
            receive a ``PyCudaAdapter`` via the ``adapter`` keyword, any other
            value uses the solver defaults.
        """
        kw = {'adapter': PyCudaAdapter()} if impl == 'pycuda' else {}

        X = px.Variable((10, 5))
        B = np.reshape(np.arange(50), (10, 5))
        prox_fns = [px.sum_squares(X, b=B)]
        sltn = pc.solve(prox_fns, [],
                        1.0,
                        1.0,
                        1.0,
                        eps_rel=1e-5,
                        eps_abs=1e-5,
                        **kw)
        self.assertItemsAlmostEqual(X.value, B, eps=2e-2)
        self.assertAlmostEqual(sltn, 0)

        prox_fns = [px.norm1(X, b=B, beta=2)]
        sltn = pc.solve(prox_fns, [],
                        1.0,
                        1.0,
                        1.0,
                        eps_rel=1e-5,
                        eps_abs=1e-5,
                        **kw)
        self.assertItemsAlmostEqual(X.value, B / 2., eps=2e-2)
        self.assertAlmostEqual(sltn, 0, eps=2e-2)

        prox_fns = [px.norm1(X), px.sum_squares(X, b=B)]
        sltn = pc.solve(prox_fns, [],
                        0.5,
                        1.0,
                        1.0,
                        eps_rel=1e-5,
                        eps_abs=1e-5,
                        conv_check=1,
                        **kw)

        # CVXPY reference solution.
        cvx_X = cvx.Variable((10, 5))
        cost = cvx.sum_squares(cvx_X - B) + cvx.norm(cvx_X, 1)
        prob = cvx.Problem(cvx.Minimize(cost))
        prob.solve()
        self.assertItemsAlmostEqual(X.value, cvx_X.value, eps=2e-2)
        self.assertAlmostEqual(sltn, prob.value)

        psi_fns, omega_fns = pc.partition(prox_fns)
        sltn = pc.solve(psi_fns,
                        omega_fns,
                        0.5,
                        1.0,
                        1.0,
                        eps_abs=1e-5,
                        eps_rel=1e-5,
                        **kw)
        self.assertItemsAlmostEqual(X.value, cvx_X.value, eps=2e-2)
        self.assertAlmostEqual(sltn, prob.value)

        # With linear operators.
        kernel = np.array([1, 2, 3])
        # Dense matrix used as the CVXPY counterpart of ``px.conv(kernel, x)``.
        kernel_mat = np.array([[2, 1, 3], [3, 2, 1], [1, 3, 2]])
        x = px.Variable(3)
        b = np.array([-41, 413, 2])
        prox_fns = [px.nonneg(x), px.sum_squares(px.conv(kernel, x), b=b)]
        # From here on only the minimizers are checked, so the objective
        # values returned by the solver are not kept.
        pc.solve(prox_fns, [],
                 0.1,
                 0.1,
                 1.0,
                 max_iters=3000,
                 eps_abs=1e-5,
                 eps_rel=1e-5,
                 **kw)
        cvx_X = cvx.Variable(3)
        cost = cvx.norm(kernel_mat @ cvx_X - b)
        prob = cvx.Problem(cvx.Minimize(cost), [cvx_X >= 0])
        prob.solve()
        self.assertItemsAlmostEqual(x.value, cvx_X.value, eps=2e-2)

        psi_fns, omega_fns = pc.partition(prox_fns)
        pc.solve(psi_fns,
                 omega_fns,
                 0.1,
                 0.1,
                 1.0,
                 max_iters=3000,
                 eps_abs=1e-5,
                 eps_rel=1e-5,
                 **kw)
        self.assertItemsAlmostEqual(x.value, cvx_X.value, eps=2e-2)

        # TODO
        # Multiple variables.
        # NOTE(review): the two calls below do not forward **kw, so they run
        # with the default adapter even when impl == 'pycuda' -- confirm
        # whether that is intentional (possibly what the TODO refers to).
        x = px.Variable(1)
        y = px.Variable(1)
        prox_fns = [
            px.nonneg(x),
            px.sum_squares(vstack([x, y]), b=np.arange(2))
        ]
        pc.solve(prox_fns, [prox_fns[-1]],
                 0.1,
                 0.1,
                 1.0,
                 max_iters=3000,
                 eps_abs=1e-5,
                 eps_rel=1e-5,
                 try_diagonalize=False)
        self.assertItemsAlmostEqual(x.value, [0])
        self.assertItemsAlmostEqual(y.value, [1])

        pc.solve(prox_fns, [prox_fns[-1]],
                 0.1,
                 0.1,
                 1.0,
                 max_iters=3000,
                 eps_abs=1e-5,
                 eps_rel=1e-5,
                 try_diagonalize=True)
        self.assertItemsAlmostEqual(x.value, [0])
        self.assertItemsAlmostEqual(y.value, [1])
# Example script fragment: tomographic reconstruction with ProxImaL from a
# noisy sinogram produced by an ODL ray transform.
# NOTE(review): `ray_trafo` and `reco_space` are not defined in this snippet;
# presumably they are created earlier in the original script -- confirm.

# Convert ray transform to proximal language operator
proximal_lang_ray_trafo = odl.as_proximal_lang_operator(ray_trafo)

# Create sinogram of forward projected phantom with noise; the white noise is
# scaled to 10% of the mean of the clean data.
phantom = odl.phantom.shepp_logan(reco_space, modified=True)
phantom.show('phantom')
data = ray_trafo(phantom)
data += odl.phantom.white_noise(ray_trafo.range) * np.mean(data) * 0.1
data.show('noisy data')

# Convert to array for ProxImaL
rhs_arr = data.asarray()

# Set up optimization problem
# Note that proximal is not aware of the underlying space and only works with
# matrices. Hence the norm in proximal does not match the norm in the ODL space
# exactly.
# The terms are: data fit, weighted l1 norm of the gradient, and the two
# constraints x >= 0 and 1 - x >= 0 (i.e. x is kept within [0, 1]).
x = proximal.Variable(reco_space.shape)
funcs = [proximal.sum_squares(proximal_lang_ray_trafo(x) - rhs_arr),
         0.2 * proximal.norm1(proximal.grad(x)),
         proximal.nonneg(x),
         proximal.nonneg(1 - x)]

# Solve the problem using ProxImaL
prob = proximal.Problem(funcs)
prob.solve(verbose=True)

# Convert back to odl and display result
result_odl = reco_space.element(x.value)
result_odl.show('ProxImaL result')