def OfflineOnlineExpansionStorageSize(problem_name):
    """Return the per-stage expansion-storage-size class for ``problem_name``.

    The class is built on first request and memoized in
    ``_offline_online_expansion_storage_size_cache``; later calls with the
    same ``problem_name`` return the same class object.

    Instances behave like a small mapping ``term -> size`` whose contents are
    kept separately for the "offline" and "online" stages. The stage used by
    each access is read from the switch base class at call time.
    """
    cache = _offline_online_expansion_storage_size_cache
    if problem_name in cache:
        return cache[problem_name]

    switch_base = OfflineOnlineSwitch(problem_name)

    def _stage_dict(storage):
        # Dictionary of ``storage`` belonging to the switch's current stage.
        return storage._content[switch_base._current_stage]

    class _OfflineOnlineExpansionStorageSize(switch_base):
        def __init__(self):
            switch_base.__init__(self)
            # One independent term -> size dictionary per stage.
            self._content = {"offline": dict(), "online": dict()}

        def __getitem__(self, term):
            return _stage_dict(self)[term]

        def __setitem__(self, term, size):
            _stage_dict(self)[term] = size

        def __contains__(self, term):
            return term in _stage_dict(self)

    cache[problem_name] = _OfflineOnlineExpansionStorageSize
    return cache[problem_name]
# Example #2
# 0
def OfflineOnlineClassMethod(problem_name):
    """Build a stage-aware wrapper class around a method of a problem object.

    A new class (derived from ``OfflineOnlineSwitch(problem_name)``) is created
    on every call. Its instances wrap ``getattr(problem, name)``. For each stage,
    ``attach`` may register a replacement method together with a predicate
    ``replacement_condition(term)``. Calling the wrapper with ``term`` invokes
    the replacement when the current stage's predicate is true, and the
    original method otherwise.
    """
    switch_base = OfflineOnlineSwitch(problem_name)

    class _OfflineOnlineClassMethod(switch_base):
        def __init__(self, problem, original_class_method_name):
            switch_base.__init__(self)
            assert hasattr(problem, original_class_method_name)
            self._original_class_method = getattr(
                problem, original_class_method_name)
            # stage -> predicate deciding whether the replacement applies.
            self._replacement_condition = dict()

        def __call__(self, term):
            # NOTE(review): raises KeyError if nothing was attached for the
            # current stage.
            condition = self._replacement_condition[
                switch_base._current_stage]
            if not condition(term):
                return self._original_class_method(term)
            return self._content[switch_base._current_stage](term)

        def attach(self, replaced_class_method, replacement_condition):
            stage = switch_base._current_stage
            if stage in self._content:
                # Re-attaching in the same stage must supply the same method.
                # The condition is not compared: lambdas cannot easily be
                # compared for equality.
                assert replaced_class_method == self._content[stage]
                return
            assert stage not in self._replacement_condition
            self._content[stage] = replaced_class_method
            self._replacement_condition[stage] = replacement_condition

    return _OfflineOnlineClassMethod
# Example #3
# 0
def OfflineOnlineRieszSolver(problem_name):
    """Build a stage-aware factory class for Riesz-representation solvers.

    A new class (derived from ``OfflineOnlineSwitch(problem_name)``) is created
    on every call. For each stage, ``set_is_affine`` records whether solves
    should be delayed (delayed exactly when the problem is *not* affine).
    Calling an instance with a problem returns a ``_RieszSolver`` configured
    with the current stage's flag.
    """
    switch_base = OfflineOnlineSwitch(problem_name)

    class _OfflineOnlineRieszSolver(switch_base):
        def __call__(self, problem):
            delay = self._content[switch_base._current_stage]
            return _OfflineOnlineRieszSolver._RieszSolver(problem, delay)

        def set_is_affine(self, is_affine):
            assert isinstance(is_affine, bool)
            delay = not is_affine
            stage = switch_base._current_stage
            if stage in self._content:
                # A stage's flag may be set again only with the same value.
                assert delay is self._content[stage]
            else:
                self._content[stage] = delay

        def unset_is_affine(self):
            """Do nothing; the recorded delay flags are kept."""

        class _RieszSolver:
            def __init__(self, problem, delay):
                self.problem = problem
                self.delay = delay

            @overload
            def solve(self, rhs: object):
                # Either solve now and return the storage filled by the
                # solver, or return a configured but unsolved delayed solver.
                problem = self.problem
                solver_args = (problem._riesz_solve_inner_product,
                               problem._riesz_solve_storage, rhs,
                               problem._riesz_solve_homogeneous_dirichlet_bc)
                if self.delay:
                    delayed_solver = DelayedLinearSolver(*solver_args)
                    delayed_solver.set_parameters(
                        problem._linear_solver_parameters)
                    return delayed_solver
                eager_solver = LinearSolver(*solver_args)
                eager_solver.set_parameters(problem._linear_solver_parameters)
                eager_solver.solve()
                return problem._riesz_solve_storage

            @overload
            def solve(self, coef: Number, matrix: object,
                      basis_function: object):
                # Build the right-hand side coef * matrix * basis_function,
                # eagerly or as a DelayedProduct, then dispatch to solve(rhs).
                if self.delay:
                    rhs = DelayedProduct(coef)
                    rhs *= matrix
                    rhs *= basis_function
                else:
                    rhs = coef * matrix * basis_function
                return self.solve(rhs)

    return _OfflineOnlineRieszSolver
# Example #4
# 0
def OfflineOnlineBackend(problem_name):
    """Bundle the offline/online helper classes for ``problem_name``.

    Not memoized here: every call invokes each factory again and returns a
    fresh ``types.SimpleNamespace``.
    """
    class_method = OfflineOnlineClassMethod(problem_name)
    expansion_storage = OfflineOnlineExpansionStorage(problem_name)
    expansion_storage_size = OfflineOnlineExpansionStorageSize(problem_name)
    riesz_solver = OfflineOnlineRieszSolver(problem_name)
    switch = OfflineOnlineSwitch(problem_name)
    return types.SimpleNamespace(
        OfflineOnlineClassMethod=class_method,
        OfflineOnlineExpansionStorage=expansion_storage,
        OfflineOnlineExpansionStorageSize=expansion_storage_size,
        OfflineOnlineRieszSolver=riesz_solver,
        OfflineOnlineSwitch=switch,
    )
# Example #5
# 0
def OfflineOnlineBackend(problem_name):
    """Return the (memoized) bundle of offline/online helper classes.

    On the first call for ``problem_name`` every factory is invoked and the
    results are stored, as a ``types.SimpleNamespace``, in
    ``_offline_online_backend_cache``. Later calls return that same namespace.
    """
    cache = _offline_online_backend_cache
    if problem_name in cache:
        return cache[problem_name]

    class_method = OfflineOnlineClassMethod(problem_name)
    expansion_storage = OfflineOnlineExpansionStorage(problem_name)
    expansion_storage_size = OfflineOnlineExpansionStorageSize(problem_name)
    riesz_solver = OfflineOnlineRieszSolver(problem_name)
    switch = OfflineOnlineSwitch(problem_name)
    cache[problem_name] = types.SimpleNamespace(
        OfflineOnlineClassMethod=class_method,
        OfflineOnlineExpansionStorage=expansion_storage,
        OfflineOnlineExpansionStorageSize=expansion_storage_size,
        OfflineOnlineRieszSolver=riesz_solver,
        OfflineOnlineSwitch=switch,
    )
    return cache[problem_name]
def OfflineOnlineExpansionStorage(problem_name):
    """Return a stage-aware container class for affine-expansion storages.

    A new class (derived from ``OfflineOnlineSwitch(problem_name)``) is created
    on every call. Instances keep one ``term -> expansion storage`` dictionary
    per stage ("offline" and "online"). Item access and assignment,
    membership, ``len``, iteration and ``items()`` all act on the dictionary of
    the switch's current stage, read at the time of each call.
    """
    _OfflineOnlineExpansionStorage_Base = OfflineOnlineSwitch(problem_name)

    def _stage():
        # Current stage name, read from the switch class on every call.
        return _OfflineOnlineExpansionStorage_Base._current_stage

    def _patch_method(expansion_storage, method):
        # Make ``expansion_storage.<method>(directory, filename)`` operate on
        # the folder ``<directory>/<current stage>`` (created first). The
        # "<method>_patched" marker prevents patching the same object twice.
        if hasattr(expansion_storage, method + "_patched"):
            return
        original_method = getattr(expansion_storage, method)

        def patched_method(self, directory, filename):
            full_directory = Folders.Folder(
                os.path.join(str(directory), _stage()))
            full_directory.create()
            return original_method(full_directory, filename)

        PatchInstanceMethod(expansion_storage, method, patched_method).patch()
        setattr(expansion_storage, method + "_patched", True)

    def _patch_save_load(expansion_storage):
        # Storages must provide both save and load, or neither; only those
        # that provide them are patched.
        assert (hasattr(expansion_storage, "save")
                == hasattr(expansion_storage, "load"))
        if hasattr(expansion_storage, "save"):
            for method in ("save", "load"):
                _patch_method(expansion_storage, method)

    class _OfflineOnlineExpansionStorage(_OfflineOnlineExpansionStorage_Base):
        def __init__(self, problem, expansion_storage_type_attribute):
            """Create empty per-stage storages attached to ``problem``.

            :param problem: object on which ``set_is_affine`` and
                ``unset_is_affine`` set the storage-type attribute.
            :param expansion_storage_type_attribute: name of that attribute;
                it is initialised to ``None`` here.
            """
            _OfflineOnlineExpansionStorage_Base.__init__(self)
            self._content = {"offline": dict(), "online": dict()}
            self._problem = problem
            self._expansion_storage_type_attribute = expansion_storage_type_attribute
            setattr(problem, expansion_storage_type_attribute, None)

        def set_is_affine(self, is_affine):
            """Set the problem's storage-type attribute.

            The attribute becomes ``AffineExpansionStorage`` when ``is_affine``
            is true and ``NonAffineExpansionStorage`` otherwise.
            """
            assert isinstance(is_affine, bool)
            if is_affine:
                storage_type = AffineExpansionStorage
            else:
                storage_type = NonAffineExpansionStorage
            setattr(self._problem, self._expansion_storage_type_attribute,
                    storage_type)

        def unset_is_affine(self):
            """Reset the problem's storage-type attribute to ``None``."""
            setattr(self._problem, self._expansion_storage_type_attribute,
                    None)

        def __getitem__(self, term):
            return self._content[_stage()][term]

        def __setitem__(self, term, expansion_storage):
            # Redirect save/load of the stored object into a per-stage
            # subfolder before storing it under the current stage.
            _patch_save_load(expansion_storage)
            self._content[_stage()][term] = expansion_storage

        def __contains__(self, term):
            return term in self._content[_stage()]

        def __len__(self):
            return len(self._content[_stage()])

        def __iter__(self):
            # Without this, Python falls back to calling __getitem__(0),
            # __getitem__(1), ... and iteration fails with KeyError instead of
            # yielding the stored terms.
            return iter(self._content[_stage()])

        def items(self):
            return self._content[_stage()].items()

    return _OfflineOnlineExpansionStorage