예제 #1
0
def test_metric():
    """Check ``ProductSpace.dist`` against component distances for p = 1, 2, inf."""
    H = odl.rn(2)
    v11 = H.element([1, 2])
    v12 = H.element([5, 3])

    v21 = H.element([1, 2])
    v22 = H.element([8, 9])

    # Distances between matching components of (v11, v12) and (v21, v22)
    d_first = H.dist(v11, v21)
    d_second = H.dist(v12, v22)

    # Expected product-space distance for each exponent
    expected_by_exponent = [
        (1.0, d_first + d_second),
        (2.0, (d_first ** 2 + d_second ** 2) ** (1 / 2.0)),
        (float('inf'), max(d_first, d_second)),
    ]

    for exponent, expected in expected_by_exponent:
        HxH = odl.ProductSpace(H, H, exponent=exponent)
        w1 = HxH.element([v11, v12])
        w2 = HxH.element([v21, v22])
        assert almost_equal(HxH.dist(w1, w2), expected)
예제 #2
0
def test_element_getitem_multi():
    """Index a three-level nested product space element with index tuples."""
    pspace = odl.ProductSpace(odl.rn(1), odl.rn(2))
    pspace2 = odl.ProductSpace(pspace, 3)
    pspace3 = odl.ProductSpace(pspace2, 2)

    # Values of the two outer components; each holds three (rn(1), rn(2)) pairs
    outer0 = [[[1], [2, 3]],
              [[4], [5, 6]],
              [[7], [8, 9]]]
    outer1 = [[[10], [12, 13]],
              [[14], [15, 16]],
              [[17], [18, 19]]]
    z = pspace3.element([outer0, outer1])

    assert pspace3.shape == (2, 3, 2)

    # Scalar entry and a single innermost component
    assert z[0, 0, 0, 0] == 1
    assert all_equal(z[0, 0, 1], [2, 3])

    # Whole sub-elements, including slices at the second level
    assert all_equal(z[0, 0], outer0[0])
    assert all_equal(z[0, 1:], outer0[1:])

    # Slicing combined with indexing on deeper levels
    assert all_equal(z[0, 1:, 1], [[5, 6], [8, 9]])
    assert all_equal(z[0, 1:, :, 0], [[[4], [5]], [[7], [8]]])
예제 #3
0
def functional(request, linear_offset, quadratic_offset, dual):
    """Return functional whose proximal should be tested.

    The base functional is selected by ``request.param`` (surrounding
    whitespace stripped). The names ``'groupl1-<p>'`` and
    ``'nuclearnorm-<outer>-<singular>'`` carry their exponents after dashes
    and switch to a suitable product space.

    The base functional is then optionally modified:

    - ``quadratic_offset``: wrapped in ``FunctionalQuadraticPerturb`` with
      quadratic coefficient 1.32 and, if ``linear_offset`` is also set, a
      random linear term ``g``;
    - only ``linear_offset``: translated by a random element ``g``;
    - ``dual``: replaced by its convex conjugate.

    Raises
    ------
    ValueError
        If ``request.param`` does not name a known functional.
    """
    name = request.param.strip()

    space = odl.uniform_discr(0, 1, 2)

    if name == 'l1':
        func = odl.solvers.L1Norm(space)
    elif name == 'l2':
        func = odl.solvers.L2Norm(space)
    elif name == 'l2^2':
        func = odl.solvers.L2NormSquared(space)
    elif name == 'kl':
        func = odl.solvers.KullbackLeibler(space)
    elif name == 'kl_cross_ent':
        func = odl.solvers.KullbackLeiblerCrossEntropy(space)
    elif name == 'const':
        func = odl.solvers.ConstantFunctional(space, constant=2)
    elif name.startswith('groupl1'):
        exponent = float(name.split('-')[1])
        space = odl.ProductSpace(space, 2)
        func = odl.solvers.GroupL1Norm(space, exponent=exponent)
    elif name.startswith('nuclearnorm'):
        outer_exp = float(name.split('-')[1])
        singular_vector_exp = float(name.split('-')[2])

        space = odl.ProductSpace(odl.ProductSpace(space, 2), 3)
        func = odl.solvers.NuclearNorm(space,
                                       outer_exp=outer_exp,
                                       singular_vector_exp=singular_vector_exp)
    elif name == 'quadratic':
        func = odl.solvers.QuadraticForm(operator=odl.IdentityOperator(space),
                                         vector=space.one(),
                                         constant=0.623)
    elif name == 'linear':
        func = odl.solvers.QuadraticForm(vector=space.one(), constant=0.623)
    else:
        # An explicit raise survives ``python -O``, unlike ``assert False``
        raise ValueError('unknown functional name {!r}'.format(name))

    # The random linear term is drawn at most once and only when requested.
    if linear_offset:
        g = noise_element(space)
        if name.startswith('kl'):
            # Keep the offset non-negative; presumably required because the
            # KL-type functionals are only finite for non-negative input.
            g = np.abs(g)
    else:
        g = None

    if quadratic_offset:
        quadratic_coeff = 1.32
        func = odl.solvers.FunctionalQuadraticPerturb(
            func, quadratic_coeff=quadratic_coeff, linear_term=g)
    elif linear_offset:
        func = func.translated(g)

    if dual:
        func = func.convex_conj

    return func
예제 #4
0
def test_fixed_disp_init():
    """Verify that the init method and checks work properly."""
    space = odl.uniform_discr(0, 1, 5)
    disp_field = space.tangent_bundle.element(disp_field_factory(space.ndim))

    # Valid input: construction succeeds and the operator can be printed
    op = LinDeformFixedDisp(disp_field)
    assert str(op) != ''
    assert repr(op) != ''
    op = LinDeformFixedDisp(disp_field, templ_space=space)
    assert str(op) != ''
    assert repr(op) != ''

    # Non-valid input. Every bad argument is built before entering
    # ``pytest.raises`` so that only the ``LinDeformFixedDisp`` call itself
    # can satisfy the expected exception.
    not_pspace_elem = space.one()
    with pytest.raises(TypeError):  # displacement not ProductSpaceElement
        LinDeformFixedDisp(not_pspace_elem)

    tangent_bundle = space.tangent_bundle
    with pytest.raises(TypeError):  # templ_space not DiscreteLp
        LinDeformFixedDisp(disp_field, tangent_bundle)

    non_power_pspace = odl.ProductSpace(space, odl.rn(3))
    with pytest.raises(TypeError):  # templ_space not a power space
        LinDeformFixedDisp(disp_field, non_power_pspace)

    non_discr_pspace = odl.ProductSpace(odl.rn(2), 1)
    with pytest.raises(TypeError):  # templ_space not based on DiscreteLp
        LinDeformFixedDisp(disp_field, non_discr_pspace)

    wrong_dtype = odl.ProductSpace(space.astype(complex), 1)
    with pytest.raises(TypeError):  # wrong dtype on templ_space
        LinDeformFixedDisp(disp_field, wrong_dtype)

    bad_space = odl.uniform_discr(0, 1, 10)
    with pytest.raises(ValueError):  # vector field spaces don't match
        LinDeformFixedDisp(disp_field, bad_space)
예제 #5
0
    def __init__(self, DomainField, Ntrans, kernel, partialderdim1kernel,
                 partialder2kernel):
        """Initialize a new instance.

        Parameters
        ----------
        DomainField : space on which the vector fields are defined
        Ntrans : number of translations
        kernel : kernel
        partialderdim1kernel : partial derivative of the kernel with respect
            to one component of its first argument, called as
            ``partialderdim1kernel(x, y, d)`` with ``d`` in ``[0, dim - 1]``
        partialder2kernel : partial derivative of the kernel with respect to
            its second argument, called as ``partialder2kernel(x, y, u)``
            with points ``x`` and ``y`` and differentiation vectors ``u``
            (as many vectors as there are points in ``y``)
        """
        # User-supplied callables and sizes
        self.Ntrans = Ntrans
        self.kernel = kernel
        self.partialderdim1kernel = partialderdim1kernel
        self.partialder2kernel = partialder2kernel
        self.dim = DomainField.ndim

        # Helpers derived from the domain of the vector fields
        self.get_unstructured_op = usefun.get_from_structured_to_unstructured(
            DomainField, kernel)
        self.gradient_op = Gradient(DomainField)

        # Geometric descriptors and controls: both (R^dim)^Ntrans
        gd_space = odl.ProductSpace(odl.space.rn(self.dim), self.Ntrans)
        cont_space = odl.ProductSpace(odl.space.rn(self.dim), self.Ntrans)
        self.dimCont = self.dim * self.Ntrans

        super().__init__(gd_space, cont_space, DomainField)
예제 #6
0
def test_element_getitem_int():
    """Test indexing of product space elements with one or several integers."""
    pspace = odl.ProductSpace(odl.rn(1), odl.rn(2))

    # One level of product space
    x0 = pspace[0].element([0])
    x1 = pspace[1].element([1, 2])
    x = pspace.element([x0, x1])

    assert x[0] is x0
    assert x[1] is x1
    assert x[-2] is x0
    assert x[-1] is x1
    # Each out-of-range index needs its own ``raises`` block; inside a shared
    # block the second access would never be executed.
    with pytest.raises(IndexError):
        x[-3]
    with pytest.raises(IndexError):
        x[2]
    assert x[0, 0] == 0
    assert x[1, 0] == 1

    # Two levels of product spaces
    pspace2 = odl.ProductSpace(pspace, 3)
    z = pspace2.element([x, x, x])
    assert z[0] is x
    assert z[1, 0] is x0
    assert z[1, 1, 1] == 2
예제 #7
0
def test_metric():
    """Compare product-space distances with their component-wise definition."""
    H = odl.rn(2)
    v11, v12 = H.element([1, 2]), H.element([5, 3])
    v21, v22 = H.element([1, 2]), H.element([8, 9])

    def pspace_dist(exponent):
        """Distance of (v11, v12) and (v21, v22) in ``H x H`` with ``exponent``."""
        HxH = odl.ProductSpace(H, H, exponent=exponent)
        w1 = HxH.element([v11, v12])
        w2 = HxH.element([v21, v22])
        return HxH.dist(w1, w2)

    dist1 = H.dist(v11, v21)
    dist2 = H.dist(v12, v22)

    # 1-norm: sum of the component distances
    assert pspace_dist(1.0) == pytest.approx(dist1 + dist2)
    # 2-norm: Euclidean combination of the component distances
    assert pspace_dist(2.0) == pytest.approx((dist1**2 + dist2**2)**0.5)
    # inf-norm: largest component distance
    assert pspace_dist(float('inf')) == pytest.approx(max(dist1, dist2))
예제 #8
0
def test_fixed_disp_init():
    """Test init and props of lin. deformation with fixed displacement."""
    space = odl.uniform_discr(0, 1, 5)
    disp_field = space.tangent_bundle.element(disp_field_factory(space.ndim))

    # Valid input
    op = LinDeformFixedDisp(disp_field)
    assert repr(op) != ''
    op = LinDeformFixedDisp(disp_field, templ_space=space)
    assert repr(op) != ''

    # Non-valid input. Every bad argument is built before entering
    # ``pytest.raises`` so that only the ``LinDeformFixedDisp`` call itself
    # can satisfy the expected exception.
    not_pspace_elem = space.one()
    with pytest.raises(TypeError):  # displacement not ProductSpaceElement
        LinDeformFixedDisp(not_pspace_elem)

    tangent_bundle = space.tangent_bundle
    with pytest.raises(TypeError):  # templ_space not DiscretizedSpace
        LinDeformFixedDisp(disp_field, tangent_bundle)

    non_power_pspace = odl.ProductSpace(space, odl.rn(3))
    with pytest.raises(TypeError):  # templ_space not a power space
        LinDeformFixedDisp(disp_field, non_power_pspace)

    non_discr_pspace = odl.ProductSpace(odl.rn(2), 1)
    with pytest.raises(TypeError):  # templ_space not based on DiscretizedSpace
        LinDeformFixedDisp(disp_field, non_discr_pspace)

    wrong_dtype = odl.ProductSpace(space.astype(complex), 1)
    with pytest.raises(TypeError):  # wrong dtype on templ_space
        LinDeformFixedDisp(disp_field, wrong_dtype)

    bad_space = odl.uniform_discr(0, 1, 10)
    with pytest.raises(ValueError):  # vector field spaces don't match
        LinDeformFixedDisp(disp_field, bad_space)
예제 #9
0
def test_is_power_space():
    """A product of identical factors is a power space equal to its explicit form."""
    r2 = odl.Rn(2)
    power = odl.ProductSpace(r2, 3)
    assert power.is_power_space

    explicit = odl.ProductSpace(r2, r2, r2)
    assert power == explicit
예제 #10
0
def test_getitem_fancy():
    """Index a product space with a list of component indices."""
    r1, r2, r3 = odl.rn(1), odl.rn(2), odl.rn(3)
    H = odl.ProductSpace(r1, r2, r3)

    selected = H[[0, 2]]
    assert selected == odl.ProductSpace(r1, r3)
    # The selected factors are the original space objects, not copies
    assert selected[0] is r1
    assert selected[1] is r3
예제 #11
0
def test_emptyproduct():
    """An empty product space needs an explicit field and has no components."""
    # Without factors and without a field, construction is rejected
    with pytest.raises(ValueError):
        odl.ProductSpace()

    reals = odl.RealNumbers()
    empty = odl.ProductSpace(field=reals)
    assert empty.field == reals
    assert empty.size == 0

    # There is no component to index
    with pytest.raises(IndexError):
        empty[0]
예제 #12
0
def test_is_power_space():
    """Check length, power-space flag and factor identity of ``rn(2)^3``."""
    r2 = odl.rn(2)
    r2x3 = odl.ProductSpace(r2, 3)

    assert len(r2x3) == 3
    assert r2x3.is_power_space
    # Each factor is the very same space object that was passed in
    for i in range(3):
        assert r2x3.spaces[i] is r2

    assert r2x3 == odl.ProductSpace(r2, r2, r2)
예제 #13
0
def test_getitem_slice():
    """Slicing a product space yields the product of the selected factors."""
    r1, r2, r3 = odl.rn(1), odl.rn(2), odl.rn(3)
    H = odl.ProductSpace(r1, r2, r3)

    head = H[:2]
    assert head == odl.ProductSpace(r1, r2)
    assert head[0] is r1
    assert head[1] is r2

    # A slice starting past the last factor gives an empty product space
    assert H[3:] == odl.ProductSpace(field=r1.field)
예제 #14
0
def space(request):
    """Return the complex product space named by ``request.param``.

    ``'product_space'`` gives a discretized complex space times ``cn(2)``;
    ``'power_space'`` gives two copies of the discretized complex space.
    Any other name raises ``ValueError``.
    """
    name = request.param.strip()

    if name not in ('product_space', 'power_space'):
        raise ValueError('undefined space')

    complex_discr = odl.uniform_discr(0, 1, 3, dtype=complex)
    if name == 'product_space':
        return odl.ProductSpace(complex_discr, odl.cn(2))
    return odl.ProductSpace(complex_discr, 2)
    def ConvolveIntegrate(self, grad_S_init, H, j0, vector_field_list,
                          zeta_list):
        """Accumulate kernel-smoothed terms backwards in time for index ``j0``.

        With ``k_j = self.k_j_list[j0]``, ``grad_S`` (starting from
        ``grad_S_init``) and ``detDphi`` are transported with
        ``linear_deform`` through the steps ``k_j - 1, k_j - 2, ..., 0``
        using ``vector_field_list[k]``. The first step has length
        ``delta0 = data_time_points[j0] - k_j / N``; the others have length
        ``self.inv_N``. At every visited step ``k``:

        - each component of ``H[k]`` is multiplied by ``grad_S * detDphi``,
          convolved with the fitting kernel through ``vectorial_ft_fit_op``
          and ``ft_kernel_fitting``, scaled by ``(2 pi)^(dim / 2)`` and
          added to ``h[k]``;
        - ``grad_S * detDphi`` is added to ``eta[k]``.

        Step ``k_j`` itself is only visited when ``delta0 != 0``.

        Parameters
        ----------
        grad_S_init :
            Presumably an element of ``self.image_domain`` -- TODO confirm.
        H :
            Sequence indexed by time step. ``H[k]`` has ``dim`` components,
            presumably an element of the tangent bundle -- TODO confirm.
        j0 : int
            Index into ``self.k_j_list`` and ``self.data_time_points``.
        vector_field_list :
            Sequence of vector fields indexed by time step.
        zeta_list :
            Not used by this method.

        Returns
        -------
        [h, eta] : list
            ``h`` is an element of
            ``ProductSpace(image_domain.tangent_bundle, k_j + 1)`` and
            ``eta`` of ``ProductSpace(image_domain, k_j + 1)``. Entries for
            steps that are not visited stay zero.
        """
        dim = self.image_domain.ndim

        k_j = self.k_j_list[j0]
        h = odl.ProductSpace(self.image_domain.tangent_bundle, k_j + 1).zero()
        eta = odl.ProductSpace(self.image_domain, k_j + 1).zero()

        grad_op = Gradient(domain=self.image_domain, method='forward',
                           pad_mode='symmetric')
        # Create the divergence op
        div_op = -grad_op.adjoint

        # Normalization factor of the FT-based convolution (loop-invariant)
        ft_scaling = (2 * np.pi) ** (dim / 2.0)

        def accumulate(k, grad_S, detDphi):
            """Add the contributions of time step ``k`` to ``h[k]``, ``eta[k]``."""
            weight = grad_S * detDphi
            tmp = H[k].copy()  # copy so the caller's H is left untouched
            for d in range(dim):
                tmp[d] *= weight
            h[k] += ft_scaling * self.vectorial_ft_fit_op.inverse(
                self.vectorial_ft_fit_op(tmp) * self.ft_kernel_fitting)
            eta[k] += weight

        # Possibly fractional first step from the data time point to the grid
        delta0 = self.data_time_points[j0] - (k_j / self.N)

        detDphi = self.image_domain.element(
            1 + delta0 * div_op(vector_field_list[k_j])).copy()
        grad_S = self.image_domain.element(
            linear_deform(grad_S_init,
                          delta0 * vector_field_list[k_j])).copy()

        if delta0 != 0:
            accumulate(k_j, grad_S, detDphi)

        # Regular steps of length inv_N, walking back from k_j - 1 to 0
        delta_t = self.inv_N
        for k in reversed(range(k_j)):
            detDphi = self.image_domain.element(
                linear_deform(detDphi, delta_t * vector_field_list[k])).copy()
            detDphi = self.image_domain.element(
                detDphi * self.image_domain.element(
                    1 + delta_t * div_op(vector_field_list[k]))).copy()
            grad_S = self.image_domain.element(
                linear_deform(grad_S, delta_t * vector_field_list[k])).copy()

            accumulate(k, grad_S, detDphi)

        return [h, eta]
예제 #16
0
def test_equals_space(exponent):
    """Equality and hashing of power spaces follow their construction parameters."""
    r2 = odl.rn(2)
    first = odl.ProductSpace(r2, 3, exponent=exponent)
    twin = odl.ProductSpace(r2, 3, exponent=exponent)
    longer = odl.ProductSpace(r2, 4, exponent=exponent)

    # Identity: separately constructed spaces are distinct objects
    assert first is first
    assert first is not twin
    assert first is not longer

    # Equality: same parameters compare equal, different length does not
    assert first == first
    assert first == twin
    assert first != longer

    # Hashing agrees with equality
    assert hash(first) == hash(twin)
    assert hash(first) != hash(longer)
예제 #17
0
def test_vector_setitem_single():
    """Test assignment of single components of a product space vector."""
    H = odl.ProductSpace(odl.Rn(1), odl.Rn(2))

    x1 = H[0].element([0])
    x2 = H[1].element([1, 2])
    x = H.element([x1, x2])

    # Negative indices
    x1_1 = H[0].element([1])
    x[-2] = x1_1
    assert x[-2] is x1_1

    x2_1 = H[1].element([3, 4])
    x[-1] = x2_1
    assert x[-1] is x2_1

    # Non-negative indices
    x1_2 = H[0].element([5])
    x[0] = x1_2
    assert x[0] is x1_2

    x2_2 = H[1].element([3, 4])
    x[1] = x2_2
    assert x[1] is x2_2

    # Each out-of-range index needs its own ``raises`` block; inside a shared
    # block the second assignment would never be executed.
    with pytest.raises(IndexError):
        x[-3] = x2
    with pytest.raises(IndexError):
        x[2] = x1
 def __init__(self, space, energies, spectrum):
     """Initialize a new instance.

     Parameters
     ----------
     space :
         Range of the operator; the domain is ``ProductSpace(space, 2)``.
     energies : array-like
         Energy values, stored as a NumPy array.
     spectrum : array-like
         Spectrum weights, stored normalized so that they sum to one.

     Raises
     ------
     ValueError
         If ``spectrum`` sums to zero and therefore cannot be normalized
         (dividing by zero would silently produce NaN or inf weights).
     """
     self.energies = np.array(energies)
     spectrum = np.array(spectrum)
     total = spectrum.sum()
     if total == 0:
         raise ValueError('`spectrum` must not sum to zero')
     self.spectrum = spectrum / total
     # NOTE(review): the third argument is presumably the base operator's
     # `linear` flag -- confirm against the base class.
     super().__init__(odl.ProductSpace(space, 2),
                      space,
                      False)
예제 #19
0
def test_reductions():
    """Reductions via ``x.ufuncs`` act on all entries of all components."""
    H = odl.ProductSpace(odl.rn(1), odl.rn(2))
    # Entries across both components are 1, 2 and 3
    x = H.element([[1], [2, 3]])

    expected = {'sum': 6.0, 'prod': 6.0, 'min': 1.0, 'max': 3.0}
    for reduction, value in expected.items():
        assert getattr(x.ufuncs, reduction)() == value
예제 #20
0
def test_operators(arithmetic_op):
    """Test that arithmetic operators on product space elements match NumPy.

    ``arithmetic_op`` is presumably a function from the ``operator`` module
    (e.g. ``operator.add`` or ``operator.iadd``). Each ODL result is compared
    with the same operation on the corresponding NumPy arrays, including the
    effect on the operands for in-place variants.
    """
    space = odl.rn(3)
    pspace = odl.ProductSpace(space, 2)
    division_ops = [operator.truediv, operator.itruediv]
    inplace_ops = [operator.iadd, operator.isub, operator.itruediv,
                   operator.imul]

    # Interactions with scalars
    for scalar in [-31.2, -1, 0, 1, 2.13]:
        # Scalar as right operand
        x_arr, x = noise_elements(pspace)
        if scalar == 0 and arithmetic_op in division_ops:
            # Division by the scalar zero must raise
            with pytest.raises(ZeroDivisionError):
                arithmetic_op(x, scalar)
        else:
            y_arr = arithmetic_op(x_arr, scalar)
            y = arithmetic_op(x, scalar)
            assert all_almost_equal([x, y], [x_arr, y_arr])

        # Scalar as left operand
        x_arr, x = noise_elements(pspace)
        y_arr = arithmetic_op(scalar, x_arr)
        y = arithmetic_op(scalar, x)
        assert all_almost_equal([x, y], [x_arr, y_arr])

    # Mixed operands: an element of `space` with an element of `pspace`
    x_arr, x = noise_elements(space, 1)
    y_arr, y = noise_elements(pspace, 1)

    # `space` element on the left (no aliasing)
    if arithmetic_op not in inplace_ops:
        z_arr = arithmetic_op(x_arr, y_arr)
        z = arithmetic_op(x, y)
        assert all_almost_equal([x, y, z], [x_arr, y_arr, z_arr])
    else:
        # The result cannot be stored in place in the smaller `space` element
        with pytest.raises(TypeError):
            arithmetic_op(x, y)

    # `space` element on the right (no aliasing)
    z_arr = arithmetic_op(y_arr, x_arr)
    z = arithmetic_op(y, x)
    assert all_almost_equal([x, y, z], [x_arr, y_arr, z_arr])

    # Both operands are the same object
    z_arr = arithmetic_op(y_arr, y_arr)
    z = arithmetic_op(y, y)
    assert all_almost_equal([x, y, z], [x_arr, y_arr, z_arr])
예제 #21
0
def test_ufuncs():
    """Test ``x.ufuncs`` methods on product space elements, with and without ``out``."""
    # Cannot use fixture due to bug in pytest
    H = odl.ProductSpace(odl.rn(1), odl.rn(2))

    # one arg
    x = H.element([[-1], [-2, -3]])

    z = x.ufuncs.absolute()
    assert all_almost_equal(z, [[1], [2, 3]])

    # one arg with out: the result is written into, and returned as, `out`
    x = H.element([[-1], [-2, -3]])
    y = H.element()

    z = x.ufuncs.absolute(out=y)
    assert y is z
    assert all_almost_equal(z, [[1], [2, 3]])

    # Two args
    x = H.element([[1], [2, 3]])
    y = H.element([[4], [5, 6]])

    z = x.ufuncs.add(y)
    assert all_almost_equal(z, [[5], [7, 9]])

    # Two args with out
    x = H.element([[1], [2, 3]])
    y = H.element([[4], [5, 6]])
    w = H.element()

    z = x.ufuncs.add(y, out=w)
    assert w is z
    assert all_almost_equal(z, [[5], [7, 9]])
예제 #22
0
def test_element_setitem_fancy():
    """Test assignment of pspace parts with lists."""
    pspace = odl.ProductSpace(odl.rn(1), odl.rn(2), odl.rn(3))

    x = pspace.element([pspace[0].element([0]),
                        pspace[1].element([1, 2]),
                        pspace[2].element([3, 4, 5])])
    # Handles to the original parts, to check that assignment writes into
    # them instead of replacing them
    part0, part2 = x[0], x[2]

    # Assign from a product space element: values change, identity is kept
    new_x0 = pspace[0].element([6])
    new_x2 = pspace[2].element([7, 8, 9])
    x[[0, 2]] = pspace[[0, 2]].element([new_x0, new_x2])
    selected = x[[0, 2]]
    assert selected[0] is part0
    assert selected[0] == new_x0
    assert selected[1] is part2
    assert selected[1] == new_x2

    # Assign from a sequence of scalars, one scalar per selected part
    x[[0, 2]] = [-1, -2]
    selected = x[[0, 2]]
    assert selected[0] is part0
    assert all_equal(selected[0], [-1])
    assert selected[1] is part2
    assert all_equal(selected[1], [-2, -2, -2])
예제 #23
0
def test_element_setitem_single():
    """Test assignment of pspace parts with single indices."""
    pspace = odl.ProductSpace(odl.rn(1), odl.rn(2))

    x0 = pspace[0].element([0])
    x1 = pspace[1].element([1, 2])
    x = pspace.element([x0, x1])
    first_part, second_part = x[0], x[1]

    # Negative indices: values are copied in, the part objects stay the same
    new_x0 = pspace[0].element([1])
    x[-2] = new_x0
    assert x[-2] == new_x0
    assert x[-2] is first_part

    new_x1 = pspace[1].element([3, 4])
    x[-1] = new_x1
    assert x[-1] == new_x1
    assert x[-1] is second_part

    # A scalar is broadcast over the whole part
    x[1] = -1
    assert all_equal(x[1], [-1, -1])
    assert x[1] is second_part

    # Out-of-bounds indices raise IndexError, each checked on its own
    for bad_index, value in [(-3, x1), (2, x0)]:
        with pytest.raises(IndexError):
            x[bad_index] = value
예제 #24
0
def test_equals_vec(exponent):
    """Element equality compares values and spaces, not object identity."""
    r2 = odl.rn(2)
    r2x3 = odl.ProductSpace(r2, 3, exponent=exponent)
    r2x4 = odl.ProductSpace(r2, 4, exponent=exponent)

    zero_a = r2x3.zero()
    zero_b = r2x3.zero()
    ones = r2x3.one()
    zero_longer = r2x4.zero()

    # Identity
    assert zero_a is zero_a
    assert zero_a is not zero_b
    assert zero_a is not ones

    # Equality
    assert zero_a == zero_a
    assert zero_a == zero_b
    assert zero_a != ones
    assert zero_a != zero_longer
예제 #25
0
def test_pspace_op_weighted_init():
    """Building a ProductSpaceOperator with a weighted range raises NotImplementedError."""
    r3 = odl.rn(3)
    weighted_range = odl.ProductSpace(r3, 2, weighting=[1, 2])
    identity = odl.IdentityOperator(r3)

    with pytest.raises(NotImplementedError):
        odl.ProductSpaceOperator([[identity], [0]], range=weighted_range)
예제 #26
0
def functional(request, space):
    """Return the functional on ``space`` named by ``request.param``.

    ``request.param`` is stripped of surrounding whitespace and must be one
    of the names handled below. For ``'groupl1'`` the functional is defined
    on ``ProductSpace(space, 3)``; the test is skipped if ``space`` is
    already a product space.

    Raises
    ------
    ValueError
        If ``request.param`` does not name a known functional.
    """
    name = request.param.strip()

    if name == 'l1':
        func = odl.solvers.functional.L1Norm(space)
    elif name == 'l2':
        func = odl.solvers.functional.L2Norm(space)
    elif name == 'l2^2':
        func = odl.solvers.functional.L2NormSquared(space)
    elif name == 'constant':
        func = odl.solvers.functional.ConstantFunctional(space, 2)
    elif name == 'zero':
        func = odl.solvers.functional.ZeroFunctional(space)
    elif name == 'ind_unit_ball_1':
        func = odl.solvers.functional.IndicatorLpUnitBall(space, 1)
    elif name == 'ind_unit_ball_2':
        func = odl.solvers.functional.IndicatorLpUnitBall(space, 2)
    elif name == 'ind_unit_ball_pi':
        func = odl.solvers.functional.IndicatorLpUnitBall(space, np.pi)
    elif name == 'ind_unit_ball_inf':
        func = odl.solvers.functional.IndicatorLpUnitBall(space, np.inf)
    elif name == 'product':
        left = odl.solvers.functional.L2Norm(space)
        right = odl.solvers.functional.ConstantFunctional(space, 2)
        func = odl.solvers.functional.FunctionalProduct(left, right)
    elif name == 'quotient':
        dividend = odl.solvers.functional.L2Norm(space)
        divisor = odl.solvers.functional.ConstantFunctional(space, 2)
        func = odl.solvers.functional.FunctionalQuotient(dividend, divisor)
    elif name == 'kl':
        func = odl.solvers.functional.KullbackLeibler(space)
    elif name == 'kl_cc':
        func = odl.solvers.KullbackLeibler(space).convex_conj
    elif name == 'kl_cross_ent':
        func = odl.solvers.functional.KullbackLeiblerCrossEntropy(space)
    elif name == 'kl_cc_cross_ent':
        func = odl.solvers.KullbackLeiblerCrossEntropy(space).convex_conj
    elif name == 'huber':
        func = odl.solvers.Huber(space, gamma=0.1)
    elif name == 'groupl1':
        if isinstance(space, odl.ProductSpace):
            pytest.skip("The `GroupL1Norm` is not supported on `ProductSpace`")
        space = odl.ProductSpace(space, 3)
        func = odl.solvers.GroupL1Norm(space)
    elif name == 'bregman_l2squared':
        point = noise_element(space)
        l2_squared = odl.solvers.L2NormSquared(space)
        subgrad = l2_squared.gradient(point)
        func = odl.solvers.BregmanDistance(l2_squared, point, subgrad)
    elif name == 'bregman_l1':
        point = noise_element(space)
        l1 = odl.solvers.L1Norm(space)
        subgrad = l1.gradient(point)
        func = odl.solvers.BregmanDistance(l1, point, subgrad)
    else:
        # An explicit raise survives ``python -O``, unlike ``assert False``
        raise ValueError('unknown functional name {!r}'.format(name))

    return func
예제 #27
0
def test_array_wrap_method():
    """Verify that the __array_wrap__ method for NumPy works."""
    space = odl.ProductSpace(odl.uniform_discr(0, 1, 10), 2)
    x_arr, x = noise_elements(space)

    # A NumPy ufunc applied to the ODL element should again give an element
    # of `space`, with the same values as the ufunc applied to the array
    expected = np.sin(x_arr)
    result = np.sin(x)

    assert result in space
    assert all_equal(result, expected)
    def __init__(self, ModulesList):
        """Combine several modules into one acting on a shared domain.

        Parameters
        ----------
        ModulesList : sequence
            Modules, each providing ``DomainField``, ``GDspace``,
            ``Contspace`` and ``dimCont``. All modules are expected to share
            the ``DomainField`` of the first one.

        Warns
        -----
        UserWarning
            If a module's ``DomainField`` differs from that of the first
            module. Construction still proceeds with the first module's
            domain.
        """
        import warnings

        self.ModulesList = ModulesList
        domain = ModulesList[0].DomainField
        self.Nmod = len(ModulesList)
        for i in range(1, self.Nmod):
            if not (ModulesList[i].DomainField == domain):
                # Non-fatal, but reported through `warnings` so that it can be
                # filtered or turned into an error, unlike a bare print
                warnings.warn(
                    'Problem domains: module {} has a DomainField that differs '
                    'from the one of module 0'.format(i), stacklevel=2)

        # Generalized coordinates and controls are the products of the
        # individual modules' spaces, in module order
        GDspace = odl.ProductSpace(
            *[module.GDspace for module in ModulesList])
        Contspace = odl.ProductSpace(
            *[module.Contspace for module in ModulesList])

        self.dim = domain.ndim

        self.dimCont = sum(module.dimCont for module in ModulesList)

        super().__init__(GDspace, Contspace, domain)
예제 #29
0
def test_comp_proj_indices():
    """Projection onto components [0, 2] agrees with fancy indexing."""
    r33 = odl.ProductSpace(odl.rn(3), 3)
    x = r33.element([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    proj = odl.ComponentProjection(r33, [0, 2])

    expected = x[[0, 2]]
    # Out-of-place evaluation
    assert expected == proj(x)
    # In-place evaluation into a fresh element of the range
    out = proj.range.element()
    assert expected == proj(x, out=out)
예제 #30
0
def test_comp_proj_slice():
    """Projection onto a slice of components agrees with slicing."""
    r33 = odl.ProductSpace(odl.Rn(3), 3)
    x = r33.element([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    proj = odl.ComponentProjection(r33, slice(0, 2))

    expected = x[0:2]
    # Out-of-place evaluation
    assert expected == proj(x)
    # In-place evaluation into a fresh element of the range
    out = proj.range.element()
    assert expected == proj(x, out=out)