def OfflineOnlineExpansionStorageSize(problem_name):
    """Return the class that records expansion storage sizes per offline/online stage.

    One class is built per ``problem_name`` and memoized in the module-level
    ``_offline_online_expansion_storage_size_cache``. Repeated calls with the
    same name therefore return the very same class object.

    Instances behave like a small mapping ``term -> size``. Its contents are
    kept separately for the "offline" and "online" stages. Every access goes to
    the table of the stage that is current at the time of the access, as read
    from the switch base class's ``_current_stage``.
    """
    cache = _offline_online_expansion_storage_size_cache
    if problem_name in cache:
        return cache[problem_name]

    _OfflineOnlineExpansionStorageSize_Base = OfflineOnlineSwitch(problem_name)

    def _stage_table(storage):
        # Resolve the stage on every access so that switching stages is seen
        # by already-existing instances.
        return storage._content[_OfflineOnlineExpansionStorageSize_Base._current_stage]

    class _OfflineOnlineExpansionStorageSize(_OfflineOnlineExpansionStorageSize_Base):

        def __init__(self):
            _OfflineOnlineExpansionStorageSize_Base.__init__(self)
            # One independent term -> size table per stage.
            self._content = {"offline": dict(), "online": dict()}

        def __getitem__(self, term):
            return _stage_table(self)[term]

        def __setitem__(self, term, size):
            _stage_table(self)[term] = size

        def __contains__(self, term):
            return term in _stage_table(self)

    cache[problem_name] = _OfflineOnlineExpansionStorageSize
    return cache[problem_name]
def OfflineOnlineClassMethod(problem_name):
    """Build a class wrapping a problem method with per-stage replacements.

    An instance stores the original bound method ``problem.<name>``. For each
    stage, ``attach`` may register a replacement callable together with a
    predicate. Calling the instance with ``term`` evaluates the current stage's
    predicate on ``term``. If the predicate holds, the replacement is used;
    otherwise the original method is used.

    NOTE(review): calling an instance before ``attach`` was invoked for the
    current stage raises ``KeyError`` (the predicate lookup fails). ``self._content``
    is not created here and is presumably initialized as a dict by the
    ``OfflineOnlineSwitch`` base class; confirm there.
    """
    _OfflineOnlineClassMethod_Base = OfflineOnlineSwitch(problem_name)

    def _stage():
        # Read lazily: the active stage can change between calls.
        return _OfflineOnlineClassMethod_Base._current_stage

    class _OfflineOnlineClassMethod(_OfflineOnlineClassMethod_Base):

        def __init__(self, problem, original_class_method_name):
            _OfflineOnlineClassMethod_Base.__init__(self)
            assert hasattr(problem, original_class_method_name)
            self._original_class_method = getattr(problem, original_class_method_name)
            self._replacement_condition = dict()

        def __call__(self, term):
            should_replace = self._replacement_condition[_stage()](term)
            if not should_replace:
                return self._original_class_method(term)
            return self._content[_stage()](term)

        def attach(self, replaced_class_method, replacement_condition):
            stage = _stage()
            if stage in self._content:
                # Re-attaching in the same stage must reuse the same method.
                # The replacement condition is not compared because lambda
                # functions cannot easily be compared.
                assert replaced_class_method == self._content[stage]
            else:
                assert stage not in self._replacement_condition
                self._content[stage] = replaced_class_method
                self._replacement_condition[stage] = replacement_condition

    return _OfflineOnlineClassMethod
def OfflineOnlineRieszSolver(problem_name):
    """Build a class producing Riesz solvers whose behavior depends on the stage.

    For each stage, ``set_is_affine`` records whether solves must be delayed
    (``delay = not is_affine``). Calling an instance with a ``problem`` returns
    a ``_RieszSolver`` bound to that problem and to the current stage's delay
    flag.

    ``_RieszSolver.solve`` is overloaded:
      * ``solve(rhs)``: with no delay, it builds a ``LinearSolver``, solves, and
        returns ``problem._riesz_solve_storage``. With delay, it returns a
        configured, unsolved ``DelayedLinearSolver``.
      * ``solve(coef, matrix, basis_function)``: it forms the right-hand side
        (eagerly as ``coef * matrix * basis_function``, or lazily as a
        ``DelayedProduct``) and forwards to ``solve(rhs)``.
    """
    _OfflineOnlineRieszSolver_Base = OfflineOnlineSwitch(problem_name)

    class _OfflineOnlineRieszSolver(_OfflineOnlineRieszSolver_Base):

        def __call__(self, problem):
            solver_class = _OfflineOnlineRieszSolver._RieszSolver
            current_delay = self._content[_OfflineOnlineRieszSolver_Base._current_stage]
            return solver_class(problem, current_delay)

        def set_is_affine(self, is_affine):
            assert isinstance(is_affine, bool)
            # Non-affine problems cannot be solved right away, so delay them.
            delay = not is_affine
            stage = _OfflineOnlineRieszSolver_Base._current_stage
            if stage in self._content:
                assert delay is self._content[stage]
            else:
                self._content[stage] = delay

        def unset_is_affine(self):
            pass

        class _RieszSolver(object):
            def __init__(self, problem, delay):
                self.problem = problem
                self.delay = delay

            @overload
            def solve(self, rhs: object):
                problem = self.problem
                solver_args = (
                    problem._riesz_solve_inner_product,
                    problem._riesz_solve_storage,
                    rhs,
                    problem._riesz_solve_homogeneous_dirichlet_bc,
                )
                if self.delay:
                    delayed_solver = DelayedLinearSolver(*solver_args)
                    delayed_solver.set_parameters(problem._linear_solver_parameters)
                    return delayed_solver
                eager_solver = LinearSolver(*solver_args)
                eager_solver.set_parameters(problem._linear_solver_parameters)
                eager_solver.solve()
                return problem._riesz_solve_storage

            @overload
            def solve(self, coef: Number, matrix: object, basis_function: object):
                if self.delay:
                    rhs = DelayedProduct(coef)
                    rhs *= matrix
                    rhs *= basis_function
                else:
                    rhs = coef * matrix * basis_function
                return self.solve(rhs)

    return _OfflineOnlineRieszSolver
def OfflineOnlineBackend(problem_name):
    """Return the namespace of offline/online helper classes for ``problem_name``.

    The namespace is built once per ``problem_name`` and memoized in the
    module-level ``_offline_online_backend_cache``. Repeated calls therefore
    share the same helper classes, and hence the same offline/online switch.

    This is the single definition of ``OfflineOnlineBackend``. An earlier,
    uncached duplicate was removed: the cached definition immediately rebound
    the name, so the duplicate was unreachable dead code.
    """
    if problem_name not in _offline_online_backend_cache:
        _offline_online_backend_cache[problem_name] = types.SimpleNamespace(
            OfflineOnlineClassMethod=OfflineOnlineClassMethod(problem_name),
            OfflineOnlineExpansionStorage=OfflineOnlineExpansionStorage(problem_name),
            OfflineOnlineExpansionStorageSize=OfflineOnlineExpansionStorageSize(problem_name),
            OfflineOnlineRieszSolver=OfflineOnlineRieszSolver(problem_name),
            OfflineOnlineSwitch=OfflineOnlineSwitch(problem_name),
        )
    return _offline_online_backend_cache[problem_name]
def OfflineOnlineExpansionStorage(problem_name):
    """Build a class holding expansion storages separately per offline/online stage.

    An instance is a mapping ``term -> expansion_storage`` with one table per
    stage ("offline" / "online"). Every access uses the stage that is current
    at the time of the access.

    On construction, it sets ``problem.<expansion_storage_type_attribute>`` to
    ``None``. ``set_is_affine`` later sets that attribute to either
    ``AffineExpansionStorage`` or ``NonAffineExpansionStorage``, and
    ``unset_is_affine`` resets it to ``None``.

    Storing an object that has ``save``/``load`` patches those methods, at most
    once per object. The patched methods read or write inside a
    ``<directory>/<current stage>`` folder, which is created on demand.
    """
    _OfflineOnlineExpansionStorage_Base = OfflineOnlineSwitch(problem_name)

    def _stage_table(storage):
        # Resolve the stage on every access: it may change between calls.
        return storage._content[_OfflineOnlineExpansionStorage_Base._current_stage]

    def _patch_one_method(expansion_storage, method):
        # Wrap ``method`` so it operates in a per-stage subfolder; the
        # "<method>_patched" marker prevents wrapping the same object twice.
        marker = method + "_patched"
        if hasattr(expansion_storage, marker):
            return
        original_method = getattr(expansion_storage, method)

        def patched_method(self, directory, filename):
            # Get full directory name
            stage_directory = Folders.Folder(
                os.path.join(str(directory), _OfflineOnlineExpansionStorage_Base._current_stage))
            stage_directory.create()
            # Call original implementation
            return original_method(stage_directory, filename)

        PatchInstanceMethod(expansion_storage, method, patched_method).patch()
        setattr(expansion_storage, marker, True)

    def _patch_save_load(expansion_storage):
        # Objects are expected to provide both save and load, or neither.
        has_save = hasattr(expansion_storage, "save")
        assert has_save == hasattr(expansion_storage, "load")
        if has_save:
            for method in ("save", "load"):
                _patch_one_method(expansion_storage, method)

    class _OfflineOnlineExpansionStorage(_OfflineOnlineExpansionStorage_Base):

        def __init__(self, problem, expansion_storage_type_attribute):
            _OfflineOnlineExpansionStorage_Base.__init__(self)
            self._content = {"offline": dict(), "online": dict()}
            self._problem = problem
            self._expansion_storage_type_attribute = expansion_storage_type_attribute
            setattr(problem, expansion_storage_type_attribute, None)

        def set_is_affine(self, is_affine):
            assert isinstance(is_affine, bool)
            storage_type = AffineExpansionStorage if is_affine else NonAffineExpansionStorage
            setattr(self._problem, self._expansion_storage_type_attribute, storage_type)

        def unset_is_affine(self):
            setattr(self._problem, self._expansion_storage_type_attribute, None)

        def __getitem__(self, term):
            return _stage_table(self)[term]

        def __setitem__(self, term, expansion_storage):
            _patch_save_load(expansion_storage)
            _stage_table(self)[term] = expansion_storage

        def __contains__(self, term):
            return term in _stage_table(self)

        def __len__(self):
            return len(_stage_table(self))

        def items(self):
            return _stage_table(self).items()

    return _OfflineOnlineExpansionStorage