Esempio n. 1
0
def gram_schmidt_basis_extension(basis, U, U_ind=None, product=None, copy_basis=True, copy_U=True):
    '''Extend basis using Gram-Schmidt orthonormalization.

    The selected vectors of `U` are appended to `basis` (or to a copy of it).
    `gram_schmidt` is then run on the result, starting at the first newly
    appended vector.

    Parameters
    ----------
    basis
        The basis to extend. If None, an empty `NumpyVectorArray` of
        dimension `U.dim` is used as the starting basis.
    U
        The new basis vectors.
    U_ind
        Indices of the new basis vectors in U.
    product
        The scalar product w.r.t. which to orthonormalize; if None, the l2-scalar
        product on the coefficient vector is used.
    copy_basis
        If copy_basis is False, the old basis is extended in-place.
    copy_U
        If copy_U is False, the new basis vectors are removed from U.

    Returns
    -------
    The new basis.

    Raises
    ------
    ExtensionError
        Raised if the basis did not grow after orthonormalization. Usually
        this is the case when U is not linearly independent from the basis.
        However this can also happen due to rounding errors.

        If copy_basis is False, `basis` has already been modified in-place
        when this error is raised. If copy_U is False, the selected vectors
        have already been removed from `U`.
    '''
    if basis is None:
        basis = NumpyVectorArray(np.zeros((0, U.dim)))

    # Remember the length *before* appending. When copy_basis is False,
    # new_basis is the very same object as basis, so len(basis) after the
    # append would already include the new vectors.
    basis_length = len(basis)

    new_basis = basis.copy() if copy_basis else basis
    new_basis.append(U, o_ind=U_ind, remove_from_other=(not copy_U))
    # Start orthonormalizing at the first appended vector. Using len(basis)
    # here would skip every new vector in the in-place case.
    gram_schmidt(new_basis, offset=basis_length, product=product)

    # Presumably gram_schmidt drops vectors it cannot orthonormalize
    # (numerically dependent ones). If none survived, the extension failed.
    if len(new_basis) <= basis_length:
        raise ExtensionError

    return new_basis
Esempio n. 2
0
def trivial_basis_extension(basis, U, U_ind=None, copy_basis=True, copy_U=True):
    '''Extend basis by appending the new vectors as they are.

    Before appending, the new vectors are compared against the basis using
    `U.almost_equal`. An exact (almost-)duplicate is rejected. Linear
    independence is NOT checked.

    Parameters
    ----------
    basis
        The basis to extend. If None, an empty `NumpyVectorArray` of
        dimension `U.dim` is used as the starting basis.
    U
        The new basis vector.
    U_ind
        Indices of the new basis vectors in U.
    copy_basis
        If copy_basis is False, the old basis is extended in-place.
    copy_U
        If copy_U is False, the new basis vectors are removed from U.

    Returns
    -------
    The new basis.

    Raises
    ------
    ExtensionError
        If any comparison from `U.almost_equal` reports a match, i.e. U is
        already contained in basis.
    '''
    if basis is None:
        basis = NumpyVectorArray(np.zeros((0, U.dim)))

    matches = U.almost_equal(basis, ind=U_ind)
    if np.any(matches):
        raise ExtensionError

    if copy_basis:
        extended = basis.copy()
    else:
        extended = basis
    keep_in_U = copy_U
    extended.append(U, o_ind=U_ind, remove_from_other=not keep_in_U)
    return extended