コード例 #1
0
ファイル: proximal_operator_test.py プロジェクト: jakobsj/odl
def test_proximal_convconj_kl_simple_space():
    """Check the proximal of the KL convex conjugate on a simple space.

    The operator obtained from ``proximal_convex_conj_kl`` is evaluated
    into a preallocated element and compared with the closed-form
    expression computed directly from the same inputs.
    """
    # Scaling parameter and step size used throughout the test
    lam = 2
    sigma = 0.25

    # Image space
    space = odl.uniform_discr(0, 1, 10)

    # Evaluation point and data term, both elements of the image space
    x = space.element(np.arange(-5, 5))
    g = space.element(np.arange(10, 0, -1))

    # Build the factory and fix the step size to get the operator of F^*
    prox = proximal_convex_conj_kl(space, lam=lam, g=g)(sigma)
    assert isinstance(prox, odl.Operator)

    # Evaluate the operator into a preallocated output element
    x_opt = space.element()
    prox(x, x_opt)

    # Closed-form reference value
    radicand = (x - lam)**2 + 4 * lam * sigma * g
    x_verify = (lam + x - np.sqrt(radicand)) / 2

    assert all_almost_equal(x_opt, x_verify, HIGH_ACC)
コード例 #2
0
ファイル: proximal_operator_test.py プロジェクト: jakobsj/odl
def test_proximal_convconj_kl_product_space():
    """Check the proximal of the KL convex conjugate on a product space.

    Same comparison as for a single space, but the evaluation point and
    the data are elements of a two-component product space.
    """
    # Scaling parameter and step size used throughout the test
    lam = 2
    sigma = 0.25

    # Product space for matrix of operators
    op_domain = odl.ProductSpace(odl.uniform_discr(0, 1, 10), 2)

    # Component arrays of the evaluation point
    x0_arr = np.arange(-5, 5)
    x1_arr = np.arange(10, 0, -1)
    x = op_domain.element([x0_arr, x1_arr])

    # Data: copies of the same two arrays, in swapped order
    g = op_domain.element([x1_arr.copy(), x0_arr.copy()])

    # Build the factory and fix the step size to get the operator
    prox = proximal_convex_conj_kl(op_domain, lam=lam, g=g)(sigma)
    assert isinstance(prox, odl.Operator)

    # Evaluate the operator into a preallocated output element
    x_opt = op_domain.element()
    prox(x, x_opt)

    # Closed-form reference value
    radicand = (x - lam)**2 + 4 * lam * sigma * g
    x_verify = (lam + x - np.sqrt(radicand)) / 2

    # Compare components
    assert all_almost_equal(x_verify, x_opt)