Example #1
0
def _abs(matrix: Matrix.Type()):
    """Locate the entry of ``matrix`` with the largest absolute value.

    Returns an ``AbsOutput`` holding that entry (with its original sign, not its
    absolute value) and its ``(i, j)`` index as a pair of plain Python ints.
    """
    magnitudes = numpy_abs(matrix)
    flat_position = argmax(magnitudes)
    row, col = unravel_index(flat_position, magnitudes.shape)
    # numpy.intXX scalars are not subclasses of int, so convert them explicitly
    location = (int(row), int(col))
    return AbsOutput(matrix[location], location)
Example #2
0
File: abs.py Project: mfkiwl/RBniCS
def _abs(matrix: Matrix.Type()):
    """Locate the entry of ``matrix`` with the largest absolute value.

    Returns an ``AbsOutput`` holding that entry (with its original sign, not its
    absolute value) and its ``(i, j)`` index as a pair of plain Python ints.
    """
    magnitudes = numpy_abs(matrix)
    flat_position = argmax(magnitudes)
    row, col = unravel_index(flat_position, magnitudes.shape)
    # numpy.intXX scalars are not subclasses of int, so convert them explicitly
    location = (int(row), int(col))
    return AbsOutput(matrix[location], location)
Example #3
0
File: abs.py Project: mfkiwl/RBniCS
# Copyright (C) 2015-2021 by the RBniCS authors
#
# This file is part of RBniCS.
#
# SPDX-License-Identifier: LGPL-3.0-or-later

from numpy import argmax, abs as numpy_abs, unravel_index
from rbnics.backends.online.numpy.matrix import Matrix
from rbnics.backends.online.numpy.vector import Vector
from rbnics.utils.decorators import backend_for, overload


# abs function for EIM: finds the entry of maximum absolute value of a matrix or vector.
# To be used in combination with max, even though here we actually carry out both the max and the abs!
@backend_for("numpy", inputs=((Matrix.Type(), Vector.Type()), ))
def abs(expression):
    """Dispatch to the ``_abs`` overload matching the type of ``expression``.

    For a matrix, the result is an ``AbsOutput`` carrying the (signed) entry of
    largest magnitude together with its ``(i, j)`` index.
    """
    result = _abs(expression)
    return result


@overload
def _abs(matrix: Matrix.Type()):
    """Matrix overload: locate the entry with the largest absolute value.

    Returns an ``AbsOutput`` holding that entry (with its original sign, not its
    absolute value) and its ``(i, j)`` index as a pair of plain Python ints.
    """
    magnitudes = numpy_abs(matrix)
    flat_position = argmax(magnitudes)
    row, col = unravel_index(flat_position, magnitudes.shape)
    # numpy.intXX scalars are not subclasses of int, so convert them explicitly
    location = (int(row), int(col))
    return AbsOutput(matrix[location], location)


@overload
def _abs(vector: Vector.Type()):
    abs_vector = numpy_abs(vector)
Example #4
0
#
# SPDX-License-Identifier: LGPL-3.0-or-later

from numpy.linalg import solve
from rbnics.backends.abstract import LinearProblemWrapper
from rbnics.backends.online.basic import LinearSolver as BasicLinearSolver
from rbnics.backends.online.numpy.function import Function
from rbnics.backends.online.numpy.matrix import Matrix
from rbnics.backends.online.numpy.transpose import DelayedTransposeWithArithmetic
from rbnics.backends.online.numpy.vector import Vector
from rbnics.utils.decorators import BackendFor, DictOfThetaType, ModuleWrapper, ThetaType

# Numpy online types handed to the basic online linear solver factory.
backend = ModuleWrapper(Function, Matrix, Vector)
# Extra helpers the factory needs, registered under an explicit keyword name.
wrapping = ModuleWrapper(DelayedTransposeWithArithmetic=DelayedTransposeWithArithmetic)
# Numpy-specialized base class; LinearSolver below subclasses it.
LinearSolver_Base = BasicLinearSolver(backend, wrapping)


@BackendFor("numpy", inputs=((Matrix.Type(), DelayedTransposeWithArithmetic, LinearProblemWrapper),
                             Function.Type(),
                             (Vector.Type(), DelayedTransposeWithArithmetic, None),
                             ThetaType + DictOfThetaType + (None,)))
class LinearSolver(LinearSolver_Base):
    """Dense linear solver for the numpy online backend.

    Solves ``self.lhs x = self.rhs`` with ``numpy.linalg.solve`` and copies the
    result into ``self.solution``.
    """

    def set_parameters(self, parameters):
        """Reject solver parameters: this backend does not support any yet."""
        assert len(parameters) == 0, "NumPy linear solver does not accept parameters yet"

    def solve(self):
        """Solve the system, store it in ``self.solution`` and notify the monitor, if set."""
        # `solve` below is the module-level numpy.linalg.solve, not this method.
        coefficients = solve(self.lhs, self.rhs)
        self.solution.vector()[:] = coefficients
        if self.monitor is None:
            return
        self.monitor(self.solution)
Example #5
0
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with RBniCS. If not, see <http://www.gnu.org/licenses/>.
#

from numpy import real, imag
from scipy.linalg import eig, eigh
from rbnics.backends.abstract import FunctionsList as AbstractFunctionsList
from rbnics.backends.abstract import EigenSolver as AbstractEigenSolver
from rbnics.backends.online.numpy.function import Function
from rbnics.backends.online.numpy.matrix import Matrix
from rbnics.backends.online.numpy.vector import Vector
from rbnics.utils.decorators import BackendFor, DictOfThetaType, ThetaType

@BackendFor("numpy", inputs=((AbstractFunctionsList, None), Matrix.Type(), (Matrix.Type(), None), ThetaType + DictOfThetaType + (None,)))
class EigenSolver(AbstractEigenSolver):
    def __init__(self, basis_functions, A, B=None, bcs=None):
        """Store the operator(s) and reset the eigenpair attributes.

        :param basis_functions: accepted for interface compatibility; unused here.
        :param A: square matrix.
        :param B: optional square matrix of the same size as ``A``.
        :param bcs: must be None; boundary conditions are not supported yet.
        """
        # A must be square; B, when given, must be square and sized like A.
        assert A.N == A.M
        assert B is None or (B.N == B.M and A.N == B.M)

        self.A, self.B = A, B
        self.parameters = {}
        self.eigs = self.eigv = None  # no eigenpairs computed yet
        assert bcs is None  # the case bcs != None has not been implemented yet
        
    def set_parameters(self, parameters):
#
# SPDX-License-Identifier: LGPL-3.0-or-later

from rbnics.backends.abstract import FunctionsList as AbstractFunctionsList
from rbnics.backends.abstract import ProperOrthogonalDecomposition as AbstractProperOrthogonalDecomposition
from rbnics.backends.basic import ProperOrthogonalDecompositionBase as BasicProperOrthogonalDecomposition
from rbnics.backends.online.numpy.eigen_solver import EigenSolver
from rbnics.backends.online.numpy.functions_list import FunctionsList
from rbnics.backends.online.numpy.matrix import Matrix
from rbnics.backends.online.numpy.snapshots_matrix import SnapshotsMatrix
from rbnics.backends.online.numpy.transpose import transpose
from rbnics.backends.online.numpy.wrapping import get_mpi_comm
from rbnics.utils.decorators import BackendFor, ModuleWrapper

# Backend operations and wrapping helpers required by the basic POD factory.
backend = ModuleWrapper(transpose)
wrapping = ModuleWrapper(get_mpi_comm)
# The numpy EigenSolver is registered under the name OnlineEigenSolver for the factory.
online_backend = ModuleWrapper(OnlineEigenSolver=EigenSolver)
online_wrapping = ModuleWrapper()
# Numpy-specialized base class; ProperOrthogonalDecomposition below subclasses it.
ProperOrthogonalDecomposition_Base = BasicProperOrthogonalDecomposition(
    backend, wrapping, online_backend, online_wrapping, AbstractProperOrthogonalDecomposition, SnapshotsMatrix,
    FunctionsList)


@BackendFor("numpy", inputs=(AbstractFunctionsList, Matrix.Type(), (str, None)))
class ProperOrthogonalDecomposition(ProperOrthogonalDecomposition_Base):
    """Proper orthogonal decomposition for the numpy online backend."""

    def __init__(self, basis_functions, inner_product, component=None):
        """Forward all arguments unchanged to the numpy-specialized base class."""
        base_class = ProperOrthogonalDecomposition_Base
        base_class.__init__(self, basis_functions, inner_product, component)

    def store_snapshot(self, snapshot, component=None, weight=None):
        """Add ``snapshot`` to the snapshots matrix, forwarding ``component`` and ``weight`` to ``enrich``."""
        enrich = self.snapshots_matrix.enrich
        enrich(snapshot, component, weight)
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with RBniCS. If not, see <http://www.gnu.org/licenses/>.
#

from rbnics.backends.online.basic import AffineExpansionStorage as BasicAffineExpansionStorage
from rbnics.backends.online.numpy.copy import function_copy, tensor_copy
from rbnics.backends.online.numpy.function import Function
from rbnics.backends.online.numpy.matrix import Matrix
from rbnics.backends.online.numpy.vector import Vector
from rbnics.backends.online.numpy.wrapping import function_load, function_save, tensor_load, tensor_save
from rbnics.utils.decorators import BackendFor, ModuleWrapper, tuple_of

# Numpy online types used by the basic affine expansion storage factory.
backend = ModuleWrapper(Function, Matrix, Vector)
# Load/save/copy helpers the basic storage delegates to (copy helpers passed by keyword).
wrapping = ModuleWrapper(function_load,
                         function_save,
                         tensor_load,
                         tensor_save,
                         function_copy=function_copy,
                         tensor_copy=tensor_copy)
# Numpy-specialized base class; AffineExpansionStorage below subclasses it.
AffineExpansionStorage_Base = BasicAffineExpansionStorage(backend, wrapping)


@BackendFor("numpy",
            inputs=((int, tuple_of(Matrix.Type()), tuple_of(Vector.Type())),
                    (int, None)))
class AffineExpansionStorage(AffineExpansionStorage_Base):
    """Affine expansion storage for the numpy online backend.

    ``arg1`` is an int or a tuple of numpy matrices/vectors; ``arg2`` is an
    optional int. Their meaning is defined by the basic implementation.
    """

    def __init__(self, arg1, arg2=None):
        """Delegate construction to the numpy-specialized base class."""
        initialize = AffineExpansionStorage_Base.__init__
        initialize(self, arg1, arg2)
# Copyright (C) 2015-2020 by the RBniCS authors
#
# This file is part of RBniCS.
#
# SPDX-License-Identifier: LGPL-3.0-or-later

from rbnics.backends.online.basic import NonAffineExpansionStorage as BasicNonAffineExpansionStorage
from rbnics.backends.online.numpy.matrix import Matrix
from rbnics.backends.online.numpy.vector import Vector
from rbnics.utils.decorators import BackendFor, tuple_of

# No numpy-specific wiring is needed: the basic class is used directly as the base.
NonAffineExpansionStorage_Base = BasicNonAffineExpansionStorage


@BackendFor("numpy", inputs=((int, tuple_of(Matrix.Type()), tuple_of(Vector.Type())), (int, None)))
class NonAffineExpansionStorage(NonAffineExpansionStorage_Base):
    """Numpy online registration of the basic non-affine expansion storage.

    All behavior is inherited unchanged; the subclass only carries the
    ``BackendFor`` registration for the input types listed above.
    """
Example #9
0
# Copyright (C) 2015-2020 by the RBniCS authors
#
# This file is part of RBniCS.
#
# SPDX-License-Identifier: LGPL-3.0-or-later

from rbnics.backends.basic import export as basic_export
from rbnics.backends.online.numpy.function import Function
from rbnics.backends.online.numpy.matrix import Matrix
from rbnics.backends.online.numpy.vector import Vector
from rbnics.backends.online.numpy.wrapping import function_save, tensor_save
from rbnics.utils.decorators import backend_for, ModuleWrapper
from rbnics.utils.io import Folders

# Numpy online types and the save helpers required by the basic export implementation.
backend = ModuleWrapper(Function, Matrix, Vector)
wrapping = ModuleWrapper(function_save, tensor_save)
# Numpy-specialized export callable; export() below forwards to it.
export_base = basic_export(backend, wrapping)


# Export a solution to file
@backend_for("numpy", inputs=((Function.Type(), Matrix.Type(), Vector.Type()), (Folders.Folder, str),
                              str, (int, None), (int, str, None)))
def export(solution, directory, filename, suffix=None, component=None):
    """Export a numpy function, matrix or vector via the basic export implementation.

    All arguments are forwarded unchanged; nothing is returned.
    """
    export_base(
        solution,
        directory,
        filename,
        suffix,
        component,
    )
Example #10
0
#
# SPDX-License-Identifier: LGPL-3.0-or-later

from numpy import real, imag
from scipy.linalg import eig, eigh
from rbnics.backends.abstract import FunctionsList as AbstractFunctionsList
from rbnics.backends.abstract import EigenSolver as AbstractEigenSolver
from rbnics.backends.online.numpy.function import Function
from rbnics.backends.online.numpy.matrix import Matrix
from rbnics.backends.online.numpy.vector import Vector
from rbnics.utils.decorators import BackendFor, DictOfThetaType, ThetaType


@BackendFor("numpy",
            inputs=((AbstractFunctionsList, None),
                    Matrix.Type(),
                    (Matrix.Type(), None),
                    ThetaType + DictOfThetaType + (None, )))
class EigenSolver(AbstractEigenSolver):
    """Eigenvalue solver for numpy online matrices (an optional second matrix ``B`` may be given)."""

    def __init__(self, basis_functions, A, B=None, bcs=None):
        """Store the operator(s) and reset the eigenpair attributes.

        :param basis_functions: accepted for interface compatibility; unused here.
        :param A: square matrix.
        :param B: optional square matrix of the same size as ``A``.
        :param bcs: must be None; boundary conditions are not supported yet.
        """
        # A must be square; B, when given, must be square and sized like A.
        assert A.N == A.M
        assert B is None or (B.N == B.M and A.N == B.M)

        self.A, self.B = A, B
        self.parameters = {}
        self.eigs = self.eigv = None  # no eigenpairs computed yet
        assert bcs is None  # the case bcs != None has not been implemented yet
Example #11
0
# Copyright (C) 2015-2021 by the RBniCS authors
#
# This file is part of RBniCS.
#
# SPDX-License-Identifier: LGPL-3.0-or-later

from rbnics.backends.online.basic.assign import assign as basic_assign
from rbnics.backends.online.numpy.function import Function
from rbnics.backends.online.numpy.matrix import Matrix
from rbnics.backends.online.numpy.vector import Vector
from rbnics.utils.decorators import backend_for, list_of, ModuleWrapper

# Numpy online types consumed by the basic online assign implementation.
backend = ModuleWrapper(Function, Matrix, Vector)
# Numpy-specialized assign callable; assign() below forwards to it.
assign_base = basic_assign(backend)


@backend_for("numpy", inputs=((Function.Type(), list_of(Function.Type()), Matrix.Type(), Vector.Type()),
                              (Function.Type(), list_of(Function.Type()), Matrix.Type(), Vector.Type())))
def assign(object_to, object_from):
    """Assign ``object_from`` into ``object_to`` via the basic online assign implementation.

    Each argument may be a function, a list of functions, a matrix or a vector.
    Nothing is returned.
    """
    assign_base(
        object_to,
        object_from,
    )
Example #12
0
def _adjoint(arg: tuple_of(Matrix.Type())):
    """Return a tuple holding the adjoint of each matrix in ``arg``, in order."""
    return tuple(map(_adjoint, arg))
Example #13
0
def _adjoint(arg: Matrix.Type()):
    """Return a new matrix holding the transposed content of ``arg``.

    The dimensions are passed to the constructor in swapped order (``arg.M, arg.N``).
    """
    matrix_type = Matrix.Type()
    swapped_dims = (arg.M, arg.N)
    return matrix_type(*swapped_dims, arg.content.T)
Example #14
0
# This file is part of RBniCS.
#
# RBniCS is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# RBniCS is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with RBniCS. If not, see <http://www.gnu.org/licenses/>.
#

from rbnics.backends.online.numpy.matrix import Matrix
from rbnics.utils.decorators import backend_for, overload, tuple_of

@backend_for("numpy", inputs=((Matrix.Type(), tuple_of(Matrix.Type())), ))
def adjoint(arg):
    """Return the adjoint (transpose) of a numpy matrix, or a tuple of adjoints for a tuple of matrices."""
    adjoint_of_arg = _adjoint(arg)
    return adjoint_of_arg
    
@overload
def _adjoint(arg: Matrix.Type()):
    """Matrix overload: return a new matrix holding the transposed content of ``arg``.

    The dimensions are passed to the constructor in swapped order (``arg.M, arg.N``).
    """
    matrix_type = Matrix.Type()
    swapped_dims = (arg.M, arg.N)
    return matrix_type(*swapped_dims, arg.content.T)
    
@overload
def _adjoint(arg: tuple_of(Matrix.Type())):
    """Tuple overload: return a tuple holding the adjoint of each matrix in ``arg``, in order."""
    return tuple(map(_adjoint, arg))
Example #15
0
#
# This file is part of RBniCS.
#
# RBniCS is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# RBniCS is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with RBniCS. If not, see <http://www.gnu.org/licenses/>.
#

from rbnics.backends.basic import GramSchmidt as BasicGramSchmidt
from rbnics.backends.online.numpy.matrix import Matrix
from rbnics.backends.online.numpy.transpose import transpose
from rbnics.backends.online.numpy.wrapping import gram_schmidt_projection_step
from rbnics.utils.decorators import BackendFor, ModuleWrapper

# Backend operation and projection helper required by the basic Gram-Schmidt factory.
backend = ModuleWrapper(transpose)
wrapping = ModuleWrapper(gram_schmidt_projection_step)
# Numpy-specialized base class; GramSchmidt below subclasses it unchanged.
GramSchmidt_Base = BasicGramSchmidt(backend, wrapping)

@BackendFor("numpy", inputs=(Matrix.Type(), ))
class GramSchmidt(GramSchmidt_Base):
    """Gram-Schmidt procedure for the numpy online backend.

    All behavior is inherited from the numpy-specialized base class; the
    subclass only carries the ``BackendFor`` registration.
    """
Example #16
0
# SPDX-License-Identifier: LGPL-3.0-or-later

# from rbnics.backends.online.basic import evaluate as basic_evaluate
# from rbnics.backends.online.numpy.function import Function
from rbnics.backends.online.numpy.matrix import Matrix
# from rbnics.backends.online.numpy.parametrized_expression_factory import ParametrizedExpressionFactory
# from rbnics.backends.online.numpy.parametrized_tensor_factory import ParametrizedTensorFactory
# from rbnics.backends.online.numpy.reduced_mesh import ReducedMesh
# from rbnics.backends.online.numpy.reduced_vertices import ReducedVertices
# from rbnics.backends.online.numpy.tensors_list import TensorsList
from rbnics.backends.online.numpy.vector import Vector
from rbnics.utils.decorators import backend_for, tuple_of

# backend = ModuleWrapper(Function, FunctionsList, Matrix, ParametrizedExpressionFactory, ParametrizedTensorFactory,
#                         ReducedMesh, ReducedVertices, TensorsList, Vector)
# wrapping = ModuleWrapper(evaluate_and_vectorize_sparse_matrix_at_dofs, evaluate_sparse_function_at_dofs,
#                          evaluate_sparse_vector_at_dofs, expression_on_reduced_mesh, expression_on_truth_mesh,
#                          form_on_reduced_function_space, form_on_truth_function_space)
# online_backend = ModuleWrapper(OnlineFunction=Function, OnlineMatrix=Matrix, OnlineVector=Vector)
# online_wrapping = ModuleWrapper()
# evaluate_base = basic_evaluate(backend, wrapping, online_backend, online_wrapping)
# Placeholder: the numpy online evaluate is not implemented yet (see the commented-out
# wiring above), so there is no base implementation for evaluate() to forward to.
evaluate_base = None  # TODO


# Evaluate a parametrized expression, possibly at a specific location
@backend_for("numpy",
             inputs=((Matrix.Type(), Vector.Type()),
                     (tuple_of(int), tuple_of(tuple_of(int)), None)))
def evaluate(expression, at=None):
    """Evaluate a numpy matrix or vector, optionally at the given location(s).

    :param expression: numpy online ``Matrix`` or ``Vector``.
    :param at: optional tuple of ints, or tuple of tuples of ints, selecting where to evaluate.
    :raises NotImplementedError: while the module-level ``evaluate_base`` is still the
        ``None`` placeholder.
    """
    if evaluate_base is None:
        # Fail with an explicit message instead of "'NoneType' object is not callable".
        raise NotImplementedError("evaluate is not implemented yet for the numpy online backend")
    return evaluate_base(expression, at)
Example #17
0
from numpy.linalg import solve
from rbnics.backends.abstract import LinearProblemWrapper
from rbnics.backends.online.basic import LinearSolver as BasicLinearSolver
from rbnics.backends.online.numpy.function import Function
from rbnics.backends.online.numpy.matrix import Matrix
from rbnics.backends.online.numpy.transpose import DelayedTransposeWithArithmetic
from rbnics.backends.online.numpy.vector import Vector
from rbnics.utils.decorators import BackendFor, DictOfThetaType, ModuleWrapper, ThetaType

# Numpy online types handed to the basic online linear solver factory.
backend = ModuleWrapper(Function, Matrix, Vector)
# Extra helpers the factory needs, registered under an explicit keyword name.
wrapping = ModuleWrapper(
    DelayedTransposeWithArithmetic=DelayedTransposeWithArithmetic)
# Numpy-specialized base class; LinearSolver below subclasses it.
LinearSolver_Base = BasicLinearSolver(backend, wrapping)


@BackendFor("numpy",
            inputs=((Matrix.Type(), DelayedTransposeWithArithmetic,
                     LinearProblemWrapper), Function.Type(),
                    (Vector.Type(), DelayedTransposeWithArithmetic,
                     None), ThetaType + DictOfThetaType + (None, )))
class LinearSolver(LinearSolver_Base):
    """Dense linear solver for the numpy online backend.

    Solves ``self.lhs x = self.rhs`` with ``numpy.linalg.solve`` and copies the
    result into ``self.solution``.
    """

    def set_parameters(self, parameters):
        """Reject solver parameters: this backend does not support any yet."""
        assert len(parameters) == 0, "NumPy linear solver does not accept parameters yet"

    def solve(self):
        """Solve the system, store it in ``self.solution`` and notify the monitor, if set."""
        # `solve` below is the module-level numpy.linalg.solve, not this method.
        coefficients = solve(self.lhs, self.rhs)
        self.solution.vector()[:] = coefficients
        if self.monitor is None:
            return
        self.monitor(self.solution)
Example #18
0
from rbnics.backends.abstract import FunctionsList as AbstractFunctionsList
from rbnics.backends.abstract import ProperOrthogonalDecomposition as AbstractProperOrthogonalDecomposition
from rbnics.backends.basic import ProperOrthogonalDecompositionBase as BasicProperOrthogonalDecomposition
from rbnics.backends.online.numpy.eigen_solver import EigenSolver
from rbnics.backends.online.numpy.functions_list import FunctionsList
from rbnics.backends.online.numpy.matrix import Matrix
from rbnics.backends.online.numpy.snapshots_matrix import SnapshotsMatrix
from rbnics.backends.online.numpy.transpose import transpose
from rbnics.backends.online.numpy.wrapping import get_mpi_comm
from rbnics.utils.decorators import BackendFor, ModuleWrapper

# Backend operations and wrapping helpers required by the basic POD factory.
backend = ModuleWrapper(transpose)
wrapping = ModuleWrapper(get_mpi_comm)
# The numpy EigenSolver is registered under the name OnlineEigenSolver for the factory.
online_backend = ModuleWrapper(OnlineEigenSolver=EigenSolver)
online_wrapping = ModuleWrapper()
# Numpy-specialized base class; ProperOrthogonalDecomposition below subclasses it.
ProperOrthogonalDecomposition_Base = BasicProperOrthogonalDecomposition(
    backend, wrapping, online_backend, online_wrapping,
    AbstractProperOrthogonalDecomposition, SnapshotsMatrix, FunctionsList)


@BackendFor("numpy",
            inputs=(AbstractFunctionsList, Matrix.Type(), (str, None)))
class ProperOrthogonalDecomposition(ProperOrthogonalDecomposition_Base):
    """Proper orthogonal decomposition for the numpy online backend."""

    def __init__(self, basis_functions, inner_product, component=None):
        """Forward all arguments unchanged to the numpy-specialized base class."""
        base_class = ProperOrthogonalDecomposition_Base
        base_class.__init__(self, basis_functions, inner_product, component)

    def store_snapshot(self, snapshot, component=None, weight=None):
        """Add ``snapshot`` to the snapshots matrix, forwarding ``component`` and ``weight`` to ``enrich``."""
        enrich = self.snapshots_matrix.enrich
        enrich(snapshot, component, weight)
Example #19
0
#
# This file is part of RBniCS.
#
# SPDX-License-Identifier: LGPL-3.0-or-later

from numpy import real, imag
from scipy.linalg import eig, eigh
from rbnics.backends.abstract import FunctionsList as AbstractFunctionsList
from rbnics.backends.abstract import EigenSolver as AbstractEigenSolver
from rbnics.backends.online.numpy.function import Function
from rbnics.backends.online.numpy.matrix import Matrix
from rbnics.backends.online.numpy.vector import Vector
from rbnics.utils.decorators import BackendFor, DictOfThetaType, ThetaType


@BackendFor("numpy",
            inputs=((AbstractFunctionsList, None),
                    Matrix.Type(),
                    (Matrix.Type(), None),
                    ThetaType + DictOfThetaType + (None,)))
class EigenSolver(AbstractEigenSolver):
    """Eigenvalue solver for numpy online matrices (an optional second matrix ``B`` may be given)."""

    def __init__(self, basis_functions, A, B=None, bcs=None):
        """Store the operator(s) and reset the eigenpair attributes.

        :param basis_functions: accepted for interface compatibility; unused here.
        :param A: square matrix.
        :param B: optional square matrix of the same size as ``A``.
        :param bcs: must be None; boundary conditions are not supported yet.
        """
        # A must be square; B, when given, must be square and sized like A.
        assert A.N == A.M
        assert B is None or (B.N == B.M and A.N == B.M)

        self.A, self.B = A, B
        self.parameters = {}
        self.eigs = self.eigv = None  # no eigenpairs computed yet
        assert bcs is None  # the case bcs != None has not been implemented yet
Example #20
0
# Copyright (C) 2015-2021 by the RBniCS authors
#
# This file is part of RBniCS.
#
# SPDX-License-Identifier: LGPL-3.0-or-later

from rbnics.backends.basic import copy as basic_copy
from rbnics.backends.online.numpy.function import Function
from rbnics.backends.online.numpy.matrix import Matrix
from rbnics.backends.online.numpy.vector import Vector
from rbnics.backends.online.numpy.wrapping.function_copy import basic_function_copy
from rbnics.backends.online.numpy.wrapping.tensor_copy import basic_tensor_copy
from rbnics.utils.decorators import backend_for, list_of, ModuleWrapper

# Numpy online types shared by all the copy factories below.
backend = ModuleWrapper(Function, Matrix, Vector)
# The low-level copy helpers need no extra wrapping functions.
wrapping_for_wrapping = ModuleWrapper()
# Numpy-specialized function/tensor copy helpers built from the basic factories.
function_copy = basic_function_copy(backend, wrapping_for_wrapping)
tensor_copy = basic_tensor_copy(backend, wrapping_for_wrapping)
# Those helpers are handed (by keyword) to the basic copy factory; copy() below forwards to copy_base.
wrapping = ModuleWrapper(function_copy=function_copy, tensor_copy=tensor_copy)
copy_base = basic_copy(backend, wrapping)


@backend_for("numpy",
             inputs=((Function.Type(), list_of(Function.Type()), Matrix.Type(),
                      Vector.Type()), ))
def copy(arg):
    """Return the copy of a numpy function, list of functions, matrix or vector.

    The copy is produced by the basic copy implementation.
    """
    duplicate = copy_base(arg)
    return duplicate
Example #21
0
# RBniCS is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with RBniCS. If not, see <http://www.gnu.org/licenses/>.
#

from numpy.linalg import solve
from rbnics.backends.online.basic import LinearSolver as BasicLinearSolver
from rbnics.backends.online.numpy.matrix import Matrix
from rbnics.backends.online.numpy.product import DelayedTransposeWithArithmetic
from rbnics.backends.online.numpy.vector import Vector
from rbnics.backends.online.numpy.function import Function
from rbnics.utils.decorators import BackendFor, DictOfThetaType, ModuleWrapper, ThetaType

# Numpy online types handed to the basic online linear solver factory.
# NOTE(review): Function is not included here even though LinearSolver accepts a
# Function; confirm the basic solver does not need it.
backend = ModuleWrapper(Matrix, Vector)
# Extra helpers the factory needs, registered under an explicit keyword name.
wrapping = ModuleWrapper(DelayedTransposeWithArithmetic=DelayedTransposeWithArithmetic)
# Numpy-specialized base class; LinearSolver below subclasses it.
LinearSolver_Base = BasicLinearSolver(backend, wrapping)

@BackendFor("numpy", inputs=(Matrix.Type(), Function.Type(), Vector.Type(), ThetaType + DictOfThetaType + (None,)))
class LinearSolver(LinearSolver_Base):
    """Dense linear solver for the numpy online backend.

    Solves ``self.lhs x = self.rhs`` with ``numpy.linalg.solve``.
    """

    def set_parameters(self, parameters):
        """Reject solver parameters: this backend does not support any yet."""
        assert len(parameters) == 0, "NumPy linear solver does not accept parameters yet"

    def solve(self):
        """Solve the system, copy the result into ``self.solution`` and return it."""
        # `solve` below is the module-level numpy.linalg.solve, not this method.
        coefficients = solve(self.lhs, self.rhs)
        self.solution.vector()[:] = coefficients
        return self.solution
Example #22
0
# Copyright (C) 2015-2021 by the RBniCS authors
#
# This file is part of RBniCS.
#
# SPDX-License-Identifier: LGPL-3.0-or-later

from rbnics.backends.basic import export as basic_export
from rbnics.backends.online.numpy.function import Function
from rbnics.backends.online.numpy.matrix import Matrix
from rbnics.backends.online.numpy.vector import Vector
from rbnics.backends.online.numpy.wrapping import function_save, tensor_save
from rbnics.utils.decorators import backend_for, ModuleWrapper
from rbnics.utils.io import Folders

# Numpy online types and the save helpers required by the basic export implementation.
backend = ModuleWrapper(Function, Matrix, Vector)
wrapping = ModuleWrapper(function_save, tensor_save)
# Numpy-specialized export callable; export() below forwards to it.
export_base = basic_export(backend, wrapping)


# Export a solution to file
@backend_for("numpy",
             inputs=((Function.Type(), Matrix.Type(), Vector.Type()),
                     (Folders.Folder, str), str, (int, None), (int, str, None)))
def export(solution, directory, filename, suffix=None, component=None):
    """Export a numpy function, matrix or vector via the basic export implementation.

    All arguments are forwarded unchanged; nothing is returned.
    """
    export_base(
        solution,
        directory,
        filename,
        suffix,
        component,
    )
Example #23
0
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with RBniCS. If not, see <http://www.gnu.org/licenses/>.
#

from numpy.linalg import solve
from rbnics.backends.online.basic import LinearSolver as BasicLinearSolver
from rbnics.backends.online.numpy.matrix import Matrix
from rbnics.backends.online.numpy.vector import Vector
from rbnics.backends.online.numpy.function import Function
from rbnics.utils.decorators import BackendFor, DictOfThetaType, ThetaType

# The basic online linear solver class is used directly as the base (no backend wiring here).
LinearSolver_Base = BasicLinearSolver


@BackendFor("numpy",
            inputs=(Matrix.Type(), Function.Type(), Vector.Type(),
                    ThetaType + DictOfThetaType + (None, )))
class LinearSolver(LinearSolver_Base):
    """Dense linear solver for the numpy online backend.

    Solves ``self.lhs x = self.rhs`` with ``numpy.linalg.solve``.
    """

    def set_parameters(self, parameters):
        """Reject solver parameters: this backend does not support any yet."""
        assert len(parameters) == 0, "NumPy linear solver does not accept parameters yet"

    def solve(self):
        """Solve the system, copy the result into ``self.solution`` and return it."""
        # `solve` below is the module-level numpy.linalg.solve, not this method.
        coefficients = solve(self.lhs, self.rhs)
        self.solution.vector()[:] = coefficients
        return self.solution