def test_pointwise_norm_real(exponent):
    """Compare ``PointwiseNorm`` on real vector fields with NumPy.

    The operator is built on product spaces with 1 and with 3 components.
    For each, the out-of-place result and the result written through
    ``out`` are compared with ``np.linalg.norm`` taken along axis 0 of the
    input array.
    """
    cases = [
        # (number of components, values of the vector field)
        (1, np.array([[[1, 2], [3, 4]]])),
        (3, np.array([[[1, 2], [3, 4]],
                      [[0, -1], [0, 1]],
                      [[1, 1], [1, 1]]])),
    ]

    for num_components, values in cases:
        fspace = odl.uniform_discr([0, 0], [1, 1], (2, 2))
        vfspace = ProductSpace(fspace, num_components)
        pwnorm = PointwiseNorm(vfspace, exponent)

        expected = np.linalg.norm(values, ord=exponent, axis=0).reshape(-1)
        vector_field = vfspace.element(values)

        # Out-of-place evaluation
        assert all_almost_equal(pwnorm(vector_field), expected)

        # In-place evaluation
        result = fspace.element()
        pwnorm(vector_field, out=result)
        assert all_almost_equal(result, expected)
def test_pointwise_inner_adjoint_weighted():
    """Check ``PointwiseInner.adjoint`` on a weighted complex product space.

    Two operators are exercised on a product space with weights
    ``[2, 4, 6]``: one created without a ``weighting`` argument and one
    created with ``weighting=[4, 8, 12]``. For each, the out-of-place and
    the in-place adjoint results are compared with the expected array.
    """
    fspace = odl.uniform_discr([0, 0], [1, 1], (2, 2), dtype=complex)
    vfspace = ProductSpace(fspace, 3, weighting=[2, 4, 6])
    vecfield = np.array([[[-1 - 1j, -3], [2, 2j]],
                         [[-1j, 0], [0, 1]],
                         [[-1, 1 + 2j], [1, 1]]])
    scalar_values = np.array([[1 + 1j, 2], [3, 4 - 2j]])

    cases = [
        # Weighted product space only: same as unweighted case
        ({},
         scalar_values[None, :, :] * vecfield),
        # Using different weighting in the inner product: w / v = (2, 2, 2)
        ({'weighting': [4, 8, 12]},
         2 * scalar_values[None, :, :] * vecfield),
    ]

    for op_kwargs, expected in cases:
        pwinner = PointwiseInner(vfspace, vecfield=vecfield, **op_kwargs)
        scalar_func = fspace.element(scalar_values)

        # Out-of-place evaluation
        assert all_almost_equal(pwinner.adjoint(scalar_func), expected)

        # In-place evaluation
        result = vfspace.element()
        pwinner.adjoint(scalar_func, out=result)
        assert all_almost_equal(result, expected)
def test_pointwise_norm_gradient_real(exponent):
    """Compare ``PointwiseNorm.derivative`` with a hand-built expression.

    For exponent ``inf`` the test only checks that requesting the
    derivative raises ``NotImplementedError``. For every other exponent
    ``p`` the derivative at a random point ``F``, applied to a random
    direction, is compared with the pointwise inner product of that
    direction with the field ``|F|_p**(1 - p) * F_i * |F_i|**(p - 2)``.
    This is done for product spaces with 1 and with 3 components.
    """
    # The operator is not differentiable for exponent 'inf'
    if exponent == float('inf'):
        fspace = odl.uniform_discr([0, 0], [1, 1], (2, 2))
        vfspace = ProductSpace(fspace, 1)
        pwnorm = PointwiseNorm(vfspace, exponent)
        point = vfspace.one()
        with pytest.raises(NotImplementedError):
            pwnorm.derivative(point)
        # The comparison below needs the derivative, so stop here.
        return

    for num_components in (1, 3):
        fspace = odl.uniform_discr([0, 0], [1, 1], (2, 2))
        vfspace = ProductSpace(fspace, num_components)
        pwnorm = PointwiseNorm(vfspace, exponent)

        point = noise_element(vfspace)
        direction = noise_element(vfspace)

        # Computing expected result
        tmp = pwnorm(point).ufuncs.power(1 - exponent)
        v_field = vfspace.element()
        for i in range(len(v_field)):
            v_field[i] = tmp * point[i] * np.abs(point[i])**(exponent - 2)
        pwinner = odl.PointwiseInner(vfspace, v_field)
        expected_result = pwinner(direction)

        func_pwnorm = pwnorm.derivative(point)
        assert all_almost_equal(func_pwnorm(direction), expected_result)
def test_pointwise_inner_real():
    """Compare ``PointwiseInner`` on real vector fields with NumPy.

    For product spaces with 1 and with 3 components, the operator result
    (out-of-place and in-place) is compared with the sum over axis 0 of
    the elementwise product of the input and the fixed vector field.
    """
    cases = [
        # (number of components, fixed vector field, input values)
        (1,
         np.array([[[-1, -3],
                    [2, 0]]]),
         np.array([[[1, 2],
                    [3, 4]]])),
        (3,
         np.array([[[-1, -3],
                    [2, 0]],
                   [[0, 0],
                    [0, 1]],
                   [[-1, 1],
                    [1, 1]]]),
         np.array([[[1, 2],
                    [3, 4]],
                   [[0, -1],
                    [0, 1]],
                   [[1, 1],
                    [1, 1]]])),
    ]

    for num_components, vecfield, values in cases:
        fspace = odl.uniform_discr([0, 0], [1, 1], (2, 2))
        vfspace = ProductSpace(fspace, num_components)
        pwinner = PointwiseInner(vfspace, vecfield=vecfield)

        expected = np.sum(values * vecfield, axis=0).reshape(-1)
        vector_field = vfspace.element(values)

        # Out-of-place evaluation
        assert all_almost_equal(pwinner(vector_field), expected)

        # In-place evaluation
        result = fspace.element()
        pwinner(vector_field, out=result)
        assert all_almost_equal(result, expected)
def test_pointwise_inner_real():
    """Check ``PointwiseInner`` against a NumPy reference for real input.

    The reference value is ``np.sum(input * vecfield, axis=0)``, flattened.
    Both calling conventions of the operator (returning the result and
    writing into ``out``) are verified, first on a product space with a
    single component and then on one with three components.
    """
    one_component = (
        np.array([[[-1, -3],
                   [2, 0]]]),
        np.array([[[1, 2],
                   [3, 4]]]),
    )
    three_components = (
        np.array([[[-1, -3],
                   [2, 0]],
                  [[0, 0],
                   [0, 1]],
                  [[-1, 1],
                   [1, 1]]]),
        np.array([[[1, 2],
                   [3, 4]],
                  [[0, -1],
                   [0, 1]],
                  [[1, 1],
                   [1, 1]]]),
    )

    for size, (field_arr, input_arr) in ((1, one_component),
                                         (3, three_components)):
        fspace = odl.uniform_discr([0, 0], [1, 1], (2, 2))
        vfspace = ProductSpace(fspace, size)
        pwinner = PointwiseInner(vfspace, vecfield=field_arr)

        reference = np.sum(input_arr * field_arr, axis=0).reshape(-1)
        elem = vfspace.element(input_arr)

        # Result returned by the call
        assert all_almost_equal(pwinner(elem), reference)

        # Result written into a preallocated element
        target = fspace.element()
        pwinner(elem, out=target)
        assert all_almost_equal(target, reference)
def test_pointwise_norm_weighted(exponent):
    """Compare a weighted ``PointwiseNorm`` with a NumPy reference.

    The operator is created on a 3-component product space with
    ``weighting=[1.0, 2.0, 3.0]``. The reference is ``np.linalg.norm``
    along axis 0 of the rescaled input: for exponents ``1`` and ``inf``
    the components are multiplied by the weights themselves, otherwise by
    the weights raised to ``1 / exponent``. Both the out-of-place and the
    in-place results are checked.
    """
    fspace = odl.uniform_discr([0, 0], [1, 1], (2, 2))
    vfspace = ProductSpace(fspace, 3)
    weight = np.array([1.0, 2.0, 3.0])
    pwnorm = PointwiseNorm(vfspace, exponent, weighting=weight)

    testarr = np.array([[[1, 2], [3, 4]], [[0, -1], [0, 1]], [[1, 1], [1, 1]]])

    if exponent in (1.0, float('inf')):
        true_norm = np.linalg.norm(weight[:, None, None] * testarr,
                                   ord=exponent, axis=0)
    else:
        # Moving the weights inside the p-th power turns them into w**(1/p)
        true_norm = np.linalg.norm(weight[:, None, None]**(1 / exponent) *
                                   testarr, ord=exponent, axis=0)

    func = vfspace.element(testarr)
    func_pwnorm = pwnorm(func)
    assert all_almost_equal(func_pwnorm, true_norm.reshape(-1))

    out = fspace.element()
    pwnorm(func, out=out)
    assert all_almost_equal(out, true_norm.reshape(-1))
def test_pointwise_inner_complex():
    """Compare ``PointwiseInner`` on a complex vector field with NumPy.

    The expected value is the sum over axis 0 of the input multiplied by
    the complex conjugate of the fixed vector field, flattened. The
    out-of-place and in-place operator results are both checked.
    """
    fspace = odl.uniform_discr([0, 0], [1, 1], (2, 2), dtype=complex)
    vfspace = ProductSpace(fspace, 3)

    vecfield = np.array([[[-1 - 1j, -3],
                          [2, 2j]],
                         [[-1j, 0],
                          [0, 1]],
                         [[-1, 1 + 2j],
                          [1, 1]]])
    values = np.array([[[1 + 1j, 2],
                        [3, 4 - 2j]],
                       [[0, -1],
                        [0, 1]],
                       [[1j, 1j],
                        [1j, 1j]]])

    pwinner = PointwiseInner(vfspace, vecfield=vecfield)
    expected = np.sum(values * vecfield.conj(), axis=0).reshape(-1)
    vector_field = vfspace.element(values)

    # Out-of-place evaluation
    assert all_almost_equal(pwinner(vector_field), expected)

    # In-place evaluation
    result = fspace.element()
    pwinner(vector_field, out=result)
    assert all_almost_equal(result, expected)
def test_pointwise_inner_weighted():
    """Compare a weighted ``PointwiseInner`` with a NumPy reference.

    The operator is created with ``weighting=[1.0, 2.0, 3.0]`` on a
    3-component real product space. The expected value is the sum over
    axis 0 of ``weight * input * vecfield``, flattened. The out-of-place
    and in-place operator results are both checked.
    """
    fspace = odl.uniform_discr([0, 0], [1, 1], (2, 2))
    vfspace = ProductSpace(fspace, 3)

    vecfield = np.array([[[-1, -3],
                          [2, 0]],
                         [[0, 0],
                          [0, 1]],
                         [[-1, 1],
                          [1, 1]]])
    values = np.array([[[1, 2],
                        [3, 4]],
                       [[0, -1],
                        [0, 1]],
                       [[1, 1],
                        [1, 1]]])
    weight = np.array([1.0, 2.0, 3.0])

    pwinner = PointwiseInner(vfspace, vecfield=vecfield, weighting=weight)
    weighted_product = weight[:, None, None] * values * vecfield
    expected = np.sum(weighted_product, axis=0).reshape(-1)
    vector_field = vfspace.element(values)

    # Out-of-place evaluation
    assert all_almost_equal(pwinner(vector_field), expected)

    # In-place evaluation
    result = fspace.element()
    pwinner(vector_field, out=result)
    assert all_almost_equal(result, expected)
def test_pointwise_inner_complex():
    """Check ``PointwiseInner`` against NumPy for complex-valued input.

    The reference is ``np.sum(input * vecfield.conj(), axis=0)``,
    flattened, on a 3-component complex product space. The operator is
    called once returning its result and once writing into ``out``.
    """
    field_arr = np.array([[[-1 - 1j, -3],
                           [2, 2j]],
                          [[-1j, 0],
                           [0, 1]],
                          [[-1, 1 + 2j],
                           [1, 1]]])
    input_arr = np.array([[[1 + 1j, 2],
                           [3, 4 - 2j]],
                          [[0, -1],
                           [0, 1]],
                          [[1j, 1j],
                           [1j, 1j]]])
    reference = np.sum(input_arr * field_arr.conj(), axis=0).reshape(-1)

    fspace = odl.uniform_discr([0, 0], [1, 1], (2, 2), dtype=complex)
    vfspace = ProductSpace(fspace, 3)
    pwinner = PointwiseInner(vfspace, vecfield=field_arr)
    elem = vfspace.element(input_arr)

    # Result returned by the call
    assert all_almost_equal(pwinner(elem), reference)

    # Result written into a preallocated element
    target = fspace.element()
    pwinner(elem, out=target)
    assert all_almost_equal(target, reference)
def test_pointwise_norm_weighted(exponent):
    """Compare a weighted ``PointwiseNorm`` with a NumPy reference.

    The operator is created on a 3-component product space with
    ``weighting=[1.0, 2.0, 3.0]``. The reference is ``np.linalg.norm``
    along axis 0 of the rescaled input: for exponents ``1`` and ``inf``
    the components are multiplied by the weights themselves, otherwise by
    the weights raised to ``1 / exponent``. Both the out-of-place and the
    in-place results are checked.
    """
    fspace = odl.uniform_discr([0, 0], [1, 1], (2, 2))
    vfspace = ProductSpace(fspace, 3)
    weight = np.array([1.0, 2.0, 3.0])
    pwnorm = PointwiseNorm(vfspace, exponent, weighting=weight)

    testarr = np.array([[[1, 2],
                         [3, 4]],
                        [[0, -1],
                         [0, 1]],
                        [[1, 1],
                         [1, 1]]])

    if exponent in (1.0, float('inf')):
        true_norm = np.linalg.norm(weight[:, None, None] * testarr,
                                   ord=exponent, axis=0)
    else:
        # Moving the weights inside the p-th power turns them into w**(1/p)
        true_norm = np.linalg.norm(
            weight[:, None, None] ** (1 / exponent) * testarr, ord=exponent,
            axis=0)

    func = vfspace.element(testarr)
    func_pwnorm = pwnorm(func)
    assert all_almost_equal(func_pwnorm, true_norm.reshape(-1))

    out = fspace.element()
    pwnorm(func, out=out)
    assert all_almost_equal(out, true_norm.reshape(-1))
def test_pointwise_inner_weighted():
    """Check a weighted ``PointwiseInner`` against NumPy for real input.

    With ``weighting=[1.0, 2.0, 3.0]`` the reference is
    ``np.sum(weight[:, None, None] * input * vecfield, axis=0)``,
    flattened. The operator is called once returning its result and once
    writing into ``out``.
    """
    weight = np.array([1.0, 2.0, 3.0])
    field_arr = np.array([[[-1, -3],
                           [2, 0]],
                          [[0, 0],
                           [0, 1]],
                          [[-1, 1],
                           [1, 1]]])
    input_arr = np.array([[[1, 2],
                           [3, 4]],
                          [[0, -1],
                           [0, 1]],
                          [[1, 1],
                           [1, 1]]])
    reference = np.sum(weight[:, None, None] * input_arr * field_arr,
                       axis=0).reshape(-1)

    fspace = odl.uniform_discr([0, 0], [1, 1], (2, 2))
    vfspace = ProductSpace(fspace, 3)
    pwinner = PointwiseInner(vfspace, vecfield=field_arr, weighting=weight)
    elem = vfspace.element(input_arr)

    # Result returned by the call
    assert all_almost_equal(pwinner(elem), reference)

    # Result written into a preallocated element
    target = fspace.element()
    pwinner(elem, out=target)
    assert all_almost_equal(target, reference)
def test_pointwise_norm_gradient_real_with_zeros(exponent):
    """Check that the ``PointwiseNorm`` derivative yields no NaN at zeros.

    The derivative is taken at points whose components contain zero
    entries and applied to a fixed direction; the result must not contain
    NaN. The test is skipped for exponents below 2 and for ``inf``. It
    runs on product spaces with 1 and with 3 components.
    """
    # The gradient is only well-defined in points with zeros if the exponent is
    # >= 2 and < inf
    if exponent < 2 or exponent == float('inf'):
        pytest.skip('differential of operator has singularity for this '
                    'exponent')

    cases = [
        # (number of components, point, direction)
        (1,
         np.array([[[0, 0],  # This makes the point singular for p < 2
                    [1, 2]]]),
         np.array([[[1, 2],
                    [4, 5]]])),
        (3,
         np.array([[[0, 0],  # This makes the point singular for p < 2
                    [1, 2]],
                   [[3, 4],
                    [0, 0]],  # This makes the point singular for p < 2
                   [[5, 6],
                    [7, 8]]]),
         np.array([[[0, 1],
                    [2, 3]],
                   [[4, 5],
                    [6, 7]],
                   [[8, 9],
                    [0, 1]]])),
    ]

    for num_components, test_point, test_direction in cases:
        fspace = odl.uniform_discr([0, 0], [1, 1], (2, 2))
        vfspace = ProductSpace(fspace, num_components)
        pwnorm = PointwiseNorm(vfspace, exponent)

        point = vfspace.element(test_point)
        direction = vfspace.element(test_direction)
        func_pwnorm = pwnorm.derivative(point)

        assert not np.any(np.isnan(func_pwnorm(direction)))
def test_pointwise_inner_adjoint():
    """Compare ``PointwiseInner.adjoint`` with a NumPy reference.

    On complex product spaces with 1 and with 3 components, the adjoint
    applied to a scalar function is compared with the input array
    broadcast against the fixed vector field and multiplied elementwise,
    reshaped to one row per component. The out-of-place and in-place
    results are both checked.
    """
    scalar_values = np.array([[1 + 1j, 2],
                              [3, 4 - 2j]])
    cases = [
        # (number of components, fixed vector field)
        (1, np.array([[[-1, -3],
                       [2, 0]]])),
        (3, np.array([[[-1 - 1j, -3],
                       [2, 2j]],
                      [[-1j, 0],
                       [0, 1]],
                      [[-1, 1 + 2j],
                       [1, 1]]])),
    ]

    for num_components, vecfield in cases:
        fspace = odl.uniform_discr([0, 0], [1, 1], (2, 2), dtype=complex)
        vfspace = ProductSpace(fspace, num_components)
        pwinner = PointwiseInner(vfspace, vecfield=vecfield)

        product = scalar_values[None, :, :] * vecfield
        expected = product.reshape([num_components, -1])
        scalar_func = fspace.element(scalar_values)

        # Out-of-place evaluation
        assert all_almost_equal(pwinner.adjoint(scalar_func), expected)

        # In-place evaluation
        result = vfspace.element()
        pwinner.adjoint(scalar_func, out=result)
        assert all_almost_equal(result, expected)
def test_pointwise_inner_adjoint():
    """Check the adjoint of ``PointwiseInner`` against NumPy.

    The reference is ``input[None, :, :] * vecfield`` reshaped to
    ``[number of components, -1]``. It is compared with the adjoint called
    out-of-place and in-place, first for a single-component and then for a
    three-component complex product space.
    """
    input_arr = np.array([[1 + 1j, 2],
                          [3, 4 - 2j]])
    single_field = np.array([[[-1, -3],
                              [2, 0]]])
    triple_field = np.array([[[-1 - 1j, -3],
                              [2, 2j]],
                             [[-1j, 0],
                              [0, 1]],
                             [[-1, 1 + 2j],
                              [1, 1]]])

    for size, field_arr in ((1, single_field), (3, triple_field)):
        fspace = odl.uniform_discr([0, 0], [1, 1], (2, 2), dtype=complex)
        vfspace = ProductSpace(fspace, size)
        pwinner = PointwiseInner(vfspace, vecfield=field_arr)

        reference = (input_arr[None, :, :] * field_arr).reshape([size, -1])
        elem = fspace.element(input_arr)

        # Result returned by the call
        assert all_almost_equal(pwinner.adjoint(elem), reference)

        # Result written into a preallocated element
        target = vfspace.element()
        pwinner.adjoint(elem, out=target)
        assert all_almost_equal(target, reference)
def test_pointwise_inner_adjoint_weighted():
    """Check ``PointwiseInner.adjoint`` on a weighted complex product space.

    Two operators are exercised on a product space with weights
    ``[2, 4, 6]``: one created without a ``weighting`` argument and one
    created with ``weighting=[4, 8, 12]``. For each, the out-of-place and
    the in-place adjoint results are compared with the expected array,
    reshaped to one row per component.
    """
    fspace = odl.uniform_discr([0, 0], [1, 1], (2, 2), dtype=complex)
    vfspace = ProductSpace(fspace, 3, weighting=[2, 4, 6])
    vecfield = np.array([[[-1 - 1j, -3],
                          [2, 2j]],
                         [[-1j, 0],
                          [0, 1]],
                         [[-1, 1 + 2j],
                          [1, 1]]])
    scalar_values = np.array([[1 + 1j, 2],
                              [3, 4 - 2j]])

    cases = [
        # Weighted product space only: same as unweighted case
        ({},
         scalar_values[None, :, :] * vecfield),
        # Using different weighting in the inner product: w / v = (2, 2, 2)
        ({'weighting': [4, 8, 12]},
         2 * scalar_values[None, :, :] * vecfield),
    ]

    for op_kwargs, product in cases:
        pwinner = PointwiseInner(vfspace, vecfield=vecfield, **op_kwargs)
        expected = product.reshape([3, -1])
        scalar_func = fspace.element(scalar_values)

        # Out-of-place evaluation
        assert all_almost_equal(pwinner.adjoint(scalar_func), expected)

        # In-place evaluation
        result = vfspace.element()
        pwinner.adjoint(scalar_func, out=result)
        assert all_almost_equal(result, expected)
def test_pointwise_norm_real(exponent):
    """Check ``PointwiseNorm`` against ``np.linalg.norm`` for real input.

    The reference is the norm of the input array along axis 0 with
    ``ord=exponent``, flattened. It is compared with the operator called
    out-of-place and in-place, first on a single-component and then on a
    three-component product space.
    """
    single_input = np.array([[[1, 2],
                              [3, 4]]])
    triple_input = np.array([[[1, 2],
                              [3, 4]],
                             [[0, -1],
                              [0, 1]],
                             [[1, 1],
                              [1, 1]]])

    for size, input_arr in ((1, single_input), (3, triple_input)):
        fspace = odl.uniform_discr([0, 0], [1, 1], (2, 2))
        vfspace = ProductSpace(fspace, size)
        pwnorm = PointwiseNorm(vfspace, exponent)

        reference = np.linalg.norm(input_arr, ord=exponent, axis=0)
        reference = reference.reshape(-1)
        elem = vfspace.element(input_arr)

        # Result returned by the call
        assert all_almost_equal(pwnorm(elem), reference)

        # Result written into a preallocated element
        target = fspace.element()
        pwnorm(elem, out=target)
        assert all_almost_equal(target, reference)
def test_pointwise_norm_complex(exponent):
    """Compare ``PointwiseNorm`` on a complex vector field with NumPy.

    On a 3-component complex product space, the out-of-place and in-place
    operator results are compared with ``np.linalg.norm`` of the input
    array along axis 0 with ``ord=exponent``.
    """
    values = np.array([[[1 + 1j, 2], [3, 4 - 2j]],
                       [[0, -1], [0, 1]],
                       [[1j, 1j], [1j, 1j]]])
    expected = np.linalg.norm(values, ord=exponent, axis=0)

    fspace = odl.uniform_discr([0, 0], [1, 1], (2, 2), dtype=complex)
    vfspace = ProductSpace(fspace, 3)
    pwnorm = PointwiseNorm(vfspace, exponent)
    vector_field = vfspace.element(values)

    # Out-of-place evaluation
    assert all_almost_equal(pwnorm(vector_field), expected)

    # In-place evaluation
    result = fspace.element()
    pwnorm(vector_field, out=result)
    assert all_almost_equal(result, expected)
class DataFitL2DispAffRest(Functional):
    """L2 data fit of an image deformed by an affine and a residual part.

    The functional acts on triples ``x = (image, affine, rest)`` whose
    parts live in ``space[0]``, ``space[1]`` and ``space[2]``. On
    evaluation, ``affine`` and ``rest`` are combined into one displacement
    through ``ops.Embedding_Affine_Rest``, the image is deformed with it
    (``defm.LinDeformFixedDisp``), passed through ``forward`` and fed into
    ``0.5 * L2NormSquared(image_space).translated(data)``.

    NOTE(review): several statements of this class were incomplete in the
    reviewed source; the trailing arguments marked below were reconstructed
    from the surrounding code and should be confirmed.
    """

    def __init__(self, space, data, forward=None):
        """Initialize the functional.

        Parameters
        ----------
        space :
            Domain of the functional. It is indexed with ``0``, ``1`` and
            ``2`` to obtain the image, affine and residual spaces.
        data :
            Element the deformed, forward-mapped image is compared with.
        forward : optional
            Operator applied after the deformation. If ``None``, the
            identity operator on the image space is used.
        """
        self.space = space
        self.image_space = self.space[0]
        self.affine_space = self.space[1]
        self.rest_space = self.space[2]
        # Reconstructed second factor: `__call__` builds elements of this
        # space from the pair ``[xaff, xrest]``.
        self.deformation_space = ProductSpace(self.affine_space,
                                              self.rest_space)
        self.data = data
        if forward is None:
            self.forward = IdentityOperator(self.image_space)
        else:
            self.forward = forward

        # Reconstructed argument: translate the squared norm by the data.
        self.datafit = 0.5 * L2NormSquared(self.image_space).translated(
            self.data)
        self.embedding_affine_rest = ops.Embedding_Affine_Rest(
            self.deformation_space, self.image_space.tangent_bundle)
        self.embedding_affine = ops.Embedding_Affine(
            self.affine_space, self.image_space.tangent_bundle)

        # NOTE(review): the remaining arguments of this call were cut off;
        # ``linear=False`` is assumed -- confirm against the base class.
        super(DataFitL2DispAffRest, self).__init__(space=space,
                                                   linear=False)

    def __call__(self, x):
        """Evaluate the data fit at ``x = (image, affine, rest)``."""
        xim = x[0]
        xaff = x[1]
        xrest = x[2]
        xdeform = self.deformation_space.element([xaff, xrest])
        transl_operator = self.transl_op_fixed_vf(xdeform)
        fctl = self.datafit * self.forward * transl_operator
        return fctl(xim)

    def transl_op_fixed_im_aff(self, im, aff):
        """Return the deformation operator for fixed image and affine part.

        The image is first deformed with the embedded affine displacement;
        the result is used as the template of ``LinDeformFixedTempl``.
        """
        affine_deform = defm.LinDeformFixedDisp(self.embedding_affine(aff))
        deform_op = defm.LinDeformFixedTempl(affine_deform(im))
        transl_operator = deform_op
        return transl_operator

    def transl_op_fixed_im_rest(self, im, rest):
        """Return the deformation operator for fixed image and residual.

        The image is first deformed with ``rest``; the result is used as
        the template of ``LinDeformFixedTempl``, which is composed with the
        affine embedding.
        """
        rest_deform = defm.LinDeformFixedDisp(rest)
        deformed_im = rest_deform(im)
        transl_operator = defm.LinDeformFixedTempl(
            deformed_im) * self.embedding_affine
        return transl_operator

    def transl_op_fixed_vf(self, disp):
        """Return the deformation operator for a fixed displacement.

        ``disp`` is an element of ``self.deformation_space``; it is mapped
        through ``self.embedding_affine_rest`` before being handed to
        ``LinDeformFixedDisp``.
        """
        deform_op = defm.LinDeformFixedDisp(self.embedding_affine_rest(disp))
        transl_operator = deform_op
        return transl_operator

    def partial_gradient(self, i):
        """Return the gradient operator with respect to the ``i``-th part.

        Parameters
        ----------
        i : int
            ``0`` for the image, ``1`` for the affine part, ``2`` for the
            residual part.

        Returns
        -------
        Operator
            Operator defined on ``self.space``. At ``x`` it freezes the two
            other parts, builds the data fit as a functional of part ``i``
            only and evaluates that functional's gradient at ``x[i]``.

        Raises
        ------
        ValueError
            If ``i`` is not ``0``, ``1`` or ``2``.

        Notes
        -----
        NOTE(review): in the reviewed source the range argument of each
        ``auxOperator`` and the statement writing to ``out`` were cut off.
        They were reconstructed as the ``i``-th component space and
        ``out.assign(grad(x[i]))`` -- confirm.
        """
        if i == 0:
            functional = self

            class auxOperator(Operator):
                def __init__(self):
                    super(auxOperator, self).__init__(functional.space,
                                                      functional.image_space)

                def _call(self, x, out):
                    xim = x[0]
                    xaff = x[1]
                    xrest = x[2]
                    xdeform = functional.deformation_space.element(
                        [xaff, xrest])
                    transl_operator = functional.transl_op_fixed_vf(xdeform)
                    func = functional.datafit * functional.forward * transl_operator
                    grad = func.gradient
                    out.assign(grad(xim))

            return auxOperator()
        elif i == 1:
            functional = self

            class auxOperator(Operator):
                def __init__(self):
                    super(auxOperator, self).__init__(functional.space,
                                                      functional.affine_space)

                def _call(self, x, out):
                    xim = x[0]
                    xaff = x[1]
                    xrest = x[2]
                    transl_operator = functional.transl_op_fixed_im_rest(
                        xim, xrest)
                    func = functional.datafit * functional.forward * transl_operator
                    grad = func.gradient
                    out.assign(grad(xaff))

            return auxOperator()
        elif i == 2:
            functional = self

            class auxOperator(Operator):
                def __init__(self):
                    super(auxOperator, self).__init__(functional.space,
                                                      functional.rest_space)

                def _call(self, x, out):
                    xim = x[0]
                    xaff = x[1]
                    xrest = x[2]
                    transl_operator = functional.transl_op_fixed_im_aff(
                        xim, xaff)
                    func = functional.datafit * functional.forward * transl_operator
                    grad = func.gradient
                    out.assign(grad(xrest))

            return auxOperator()
        else:
            raise ValueError('No gradient defined for this variable')

    # NOTE(review): ``func.gradient`` is accessed as an attribute on other
    # functionals in this class; confirm whether this override is meant to
    # be a ``@property`` as well. It is kept as a plain method here so that
    # existing ``obj.gradient()`` callers keep working.
    def gradient(self):
        """Return the full gradient as a broadcast of the three partials."""
        return BroadcastOperator(*[self.partial_gradient(i) for i in range(3)])