예제 #1
0
def test_fixed_templ_init():
    """Verify that the init method and checks work properly.

    A discretized template must be accepted and yield an operator with a
    non-empty string representation; the raw template function itself
    must be rejected with ``TypeError``.
    """
    space = odl.uniform_discr(0, 1, 5)
    template = space.element(template_function)

    # Valid input. Assert on the string forms instead of printing them, so
    # that a broken or empty representation actually fails the test.
    # ``str`` keeps the coverage the former ``print`` call provided.
    op = LinDeformFixedTempl(template)
    assert str(op) != ''
    assert repr(op) != ''

    # Invalid input
    with pytest.raises(TypeError):
        # template_function is the function the template was built from,
        # not an element of a discretized space
        LinDeformFixedTempl(template_function)
예제 #2
0
def test_fixed_templ_init():
    """Check construction of ``LinDeformFixedTempl`` and input validation.

    The operator is built once with its default domain and once with an
    explicit domain given as the one-component power of the ``float32``
    version of the space; both must have a non-empty ``repr``. Passing the
    raw template function instead of a space element must raise
    ``TypeError``.
    """
    space = odl.uniform_discr(0, 1, 5)
    template = space.element(template_function)

    # Accepted inputs: default domain first, then an explicit float32 one
    default_op = LinDeformFixedTempl(template)
    assert repr(default_op) != ''
    f32_domain = space.astype('float32') ** 1
    f32_op = LinDeformFixedTempl(template, domain=f32_domain)
    assert repr(f32_op) != ''

    # Rejected input: the bare function, not a discretized space element
    with pytest.raises(TypeError):
        LinDeformFixedTempl(template_function)
예제 #3
0
def test_fixed_templ_init():
    """Test init and props of linearized deformation with fixed template."""
    space = odl.uniform_discr(0, 1, 5)
    template = space.element(template_function)

    # Valid input: default domain, then an explicit float32 power-space domain
    valid_kwargs = ({}, {'domain': space.astype('float32')**1})
    for kwargs in valid_kwargs:
        deform_op = LinDeformFixedTempl(template, **kwargs)
        assert repr(deform_op) != '', kwargs

    # Invalid input: template_function is not a DiscretizedSpaceElement
    with pytest.raises(TypeError):
        LinDeformFixedTempl(template_function)
예제 #4
0
def test_fixed_templ_deriv(space):
    """Test the derivative of linearized deformation with fixed template.

    The derivative of ``LinDeformFixedTempl`` at a displacement field is
    applied to a vector field and compared with the analytic result
    ``fixed_templ_deriv``. The error is measured relative to the exact
    (reference) result.
    """
    if not space.is_rn:
        pytest.skip('derivative not implemented for complex dtypes')

    # Set up template, displacement field and the field the derivative acts on
    template = space.element(template_function)
    disp_field = disp_field_factory(space.ndim)
    vector_field = vector_field_factory(space.ndim)
    fixed_templ_op = LinDeformFixedTempl(template)

    # Calculate result
    fixed_templ_op_deriv = fixed_templ_op.derivative(disp_field)
    fixed_templ_deriv_comp = fixed_templ_op_deriv(vector_field)

    # Calculate the analytic result
    fixed_templ_deriv_exact = space.element(fixed_templ_deriv)

    # Verify that the result is within error limits. Normalize by the exact
    # result: it is the fixed reference, so an erroneous computed result
    # (too large, or vanishing) cannot distort the error measure.
    error = (fixed_templ_deriv_exact - fixed_templ_deriv_comp).norm()
    rlt_err = error / fixed_templ_deriv_exact.norm()
    assert rlt_err < error_bound(space.interp)
예제 #5
0
def test_fixed_templ_deriv(space):
    """Test the derivative of linearized deformation with fixed template.

    Compares the derivative of ``LinDeformFixedTempl`` (taken at a
    displacement field and applied to a vector field) with the analytic
    ``fixed_templ_deriv``, using the error relative to the exact result.
    """
    if not space.is_rn:
        pytest.skip('derivative not implemented for complex dtypes')

    # Set up template, displacement field and the field the derivative acts on
    template = space.element(template_function)
    disp_field = disp_field_factory(space.ndim)
    vector_field = vector_field_factory(space.ndim)
    fixed_templ_op = LinDeformFixedTempl(template)

    # Calculate result
    fixed_templ_op_deriv = fixed_templ_op.derivative(disp_field)
    fixed_templ_deriv_comp = fixed_templ_op_deriv(vector_field)

    # Calculate the analytic result
    fixed_templ_deriv_exact = space.element(fixed_templ_deriv)

    # Verify that the result is within error limits, relative to the exact
    # (reference) result rather than to the value under test.
    error = (fixed_templ_deriv_exact - fixed_templ_deriv_comp).norm()
    rlt_err = error / fixed_templ_deriv_exact.norm()
    assert rlt_err < error_bound(space.interp)
예제 #6
0
def test_fixed_templ_call(space, interp):
    """Test call of linearized deformation with fixed template.

    Deforms the discretized template with the displacement field from
    ``disp_field_factory`` and compares the result with the analytic
    ``deformed_template``, using the error relative to the exact result.
    """
    # Discretize the analytic template and build the deformation operator
    template = space.element(template_function)
    deform_op = LinDeformFixedTempl(template, interp=interp)

    # Calculate result and exact result
    true_deformed_templ = space.element(deformed_template)
    deformed_templ = deform_op(disp_field_factory(space.ndim))

    # Verify that the result is within error limits. Normalize by the exact
    # result, which is the fixed reference, not by the value under test.
    error = (true_deformed_templ - deformed_templ).norm()
    rlt_err = error / true_deformed_templ.norm()
    assert rlt_err < error_bound(interp)
예제 #7
0
# Compute Fourier transform of the kernel function in data matching term
ft_kernel_fitting = fitting_kernel_ft(kernel)

# Compute Fourier transform of the kernel function in shape regularization term
ft_kernel_shape = shape_kernel_ft(kernel)

# Create displacement operator (built with the data-matching kernel transform)
displacement_op = DisplacementOperator(vspace, cptsspace.grid, discr_space,
                                       ft_kernel_fitting)

# Compute the displacement at momenta
displ = displacement_op(momenta)

# Create linearized deformation operator
linear_deform_op = LinDeformFixedTempl(template)

# Compute the deformed template
deformed_template = linear_deform_op(displ)

# Apply the X-ray transform operator to the deformed template
proj_deformed_template = xray_trafo_op(deformed_template)

# Create L2 data matching (fitting) term on the range of the X-ray transform
l2_data_fit_func = L2DataMatchingFunctional(xray_trafo_op.range,
                                            noise_proj_data)

# Composition of the L2 fitting term with three operators (currently disabled):
# data_fitting_term = (l2_data_fit_func * xray_trafo_op * linear_deform_op *
#                      displacement_op)

# Compute the kernel matrix for the method without Fourier transform