def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx,
                           prepared=None):
    """Return this solve block's adjoint contribution for one dependency.

    ``prepared`` is indexed for the keys ``"form"``, ``"adj_sol"`` and
    ``"adj_sol_bdy"``.  ``inputs``, ``adj_inputs`` and ``idx`` are not used.
    The result depends on the type of ``block_variable.output``:

    * ``None`` when the solve is nonlinear and the dependency is the solution
      itself (i.e. its initial guess);
    * a one-element list holding the boundary condition produced by
      ``self.compat.create_bc`` for a ``firedrake.DirichletBC``;
    * the assembled coordinate derivative of ``-action(form, adj_sol)`` for a
      mesh (``self.compat.MeshType``);
    * otherwise (``firedrake.Function`` / ``firedrake.Constant``) the
      assembled product of ``adjoint(-d(self.lhs)/d(dependency))`` with
      ``adj_sol``, with coefficients replaced by their checkpointed values.
    """
    if not self.linear and self.func == block_variable.output:
        # Derivatives with respect to the initial guess cannot be computed.
        return None

    form = prepared["form"]
    adj_sol = prepared["adj_sol"]
    adj_sol_bdy = prepared["adj_sol_bdy"]
    dependency = block_variable.output
    checkpoint = block_variable.saved_output

    if isinstance(dependency, firedrake.Function):
        trial_function = firedrake.TrialFunction(dependency.function_space())
    elif isinstance(dependency, firedrake.Constant):
        mesh = self.compat.extract_mesh_from_form(form)
        trial_function = firedrake.TrialFunction(
            dependency._ad_function_space(mesh))
    elif isinstance(dependency, firedrake.DirichletBC):
        boundary_value = self.compat.extract_subfunction(
            adj_sol_bdy, dependency.function_space())
        return [self.compat.create_bc(dependency, value=boundary_value)]
    elif isinstance(dependency, self.compat.MeshType):
        # CoordinateDerivative requires taking the action before
        # differentiating; this might change in the future.
        weak_action = firedrake.action(form, adj_sol)
        coords = firedrake.SpatialCoordinate(checkpoint)
        shape_form = firedrake.derivative(
            -weak_action, coords,
            firedrake.TestFunction(dependency._ad_function_space()))
        return self.compat.assemble_adjoint_value(
            shape_form, **self.assemble_kwargs)
    # NOTE(review): any other dependency type reaches the cache below with
    # `trial_function` unbound, so it only works if already cached.

    # The cache is keyed on the original dependency, not its checkpoint.
    if dependency not in self._dFdm_cache:
        negated = -firedrake.derivative(self.lhs, dependency, trial_function)
        self._dFdm_cache[dependency] = firedrake.adjoint(negated)
    sensitivity = self._dFdm_cache[dependency]

    # Swap the form's coefficients for their checkpointed values.
    substitutions = self._replace_map(sensitivity)
    substitutions[self.func] = self.get_outputs()[0].saved_output
    sensitivity = replace(sensitivity, substitutions)

    return self.compat.assemble_adjoint_value(
        sensitivity * adj_sol, **self.assemble_kwargs)
def update_adjoint_state(self):
    r"""Recompute the adjoint state for the current state and parameters.

    Solves the linear system whose operator is ``adjoint(self._dF_du)`` and
    whose right-hand side is ``-self._dE``, writing the solution into
    ``self.adjoint_state`` with the Dirichlet conditions ``self._bc``.  The
    refreshed adjoint state is what later derivative evaluations use.
    """
    adj_state = self.adjoint_state
    operator = adjoint(self._dF_du)
    equation = operator == -self._dE
    firedrake.solve(
        equation,
        adj_state,
        self._bc,
        solver_parameters=self._solver_params,
        form_compiler_parameters=self._fc_params,
    )
def gauss_newton_mult(self, q):
    """Apply the Gauss-Newton operator to the parameter field ``q``.

    Performs two linear solves, both with ``self._bc``,
    ``self._solver_params`` and ``self._fc_params``:

    1. tangent linear:  ``dF/du w = (dF/dp) q``
    2. adjoint:         ``(dF/du)^* v = (d^2 E/du^2) w``

    and returns the (unassembled) form ``(dF/dp)^* v + (d^2 R/dp^2) q``.
    """
    state, param = self.state, self.parameter
    misfit_grad = derivative(self._E, state)
    reg_grad = derivative(self._R, param)
    jac_state = self._dF_du
    jac_param = derivative(self._F, param)
    options = dict(solver_parameters=self._solver_params,
                   form_compiler_parameters=self._fc_params)

    # Tangent linear solve: how the state responds to a perturbation q.
    tangent = firedrake.Function(state.function_space())
    firedrake.solve(jac_state == action(jac_param, q), tangent, self._bc,
                    **options)

    # Adjoint solve driven by the misfit Hessian applied to the tangent.
    hessian_rhs = derivative(misfit_grad, state, tangent)
    adj_tangent = firedrake.Function(state.function_space())
    firedrake.solve(adjoint(jac_state) == hessian_rhs, adj_tangent, self._bc,
                    **options)

    data_term = action(adjoint(jac_param), adj_tangent)
    return data_term + derivative(reg_grad, param, q)
def _setup(self, problem, callback=(lambda s: None)):
    r"""Create the working copies, functionals and weak forms for ``problem``.

    Deep copies of ``problem.parameter`` and ``problem.state`` become
    ``self._p`` and ``self._u``; every form built here refers to those
    copies.  This sets the objective ``self._E``, regularization ``self._R``
    and their sum ``self._J``; the forward-model weak form ``self._F`` and
    its Jacobian ``self._dF_du``; a search direction ``self._q``; the
    adjoint state ``self._λ`` with zero Dirichlet conditions ``self._bc`` on
    ``problem.dirichlet_ids``; ``self._dE``; and the gradient form
    ``self._dJ``.

    Parameters
    ----------
    problem
        The inverse problem to solve.
    callback
        Stored as ``self._callback``; the default does nothing.
    """
    self._problem = problem
    self._callback = callback
    self._p = problem.parameter.copy(deepcopy=True)
    self._u = problem.state.copy(deepcopy=True)
    self._model_args = dict(**problem.model_args,
                            dirichlet_ids=problem.dirichlet_ids)
    u_name, p_name = problem.state_name, problem.parameter_name
    args = dict(**self._model_args, **{u_name: self._u, p_name: self._p})

    # Make the form compiler use a reasonable number of quadrature points
    degree = problem.model.quadrature_degree(**args)
    self._fc_params = {'quadrature_degree': degree}

    # Create the error and regularization functionals
    self._E = problem.objective(self._u)
    self._R = problem.regularization(self._p)
    self._J = self._E + self._R

    # Create the weak form of the forward model and its Jacobian
    self._F = derivative(problem.model.action(**args), self._u)
    self._dF_du = derivative(self._F, self._u)

    self._solver_params = {'ksp_type': 'preonly', 'pc_type': 'lu'}

    # Create a search direction
    Q = self._p.function_space()
    self._q = firedrake.Function(Q)

    # Create the adjoint state variable
    V = self.state.function_space()
    self._λ = firedrake.Function(V)
    dF_dp = derivative(self._F, self._p)

    # Create Dirichlet BCs where they apply for the adjoint solve
    rank = self._λ.ufl_element().num_sub_elements()
    if rank == 0:
        zero = firedrake.Constant(0)
    else:
        zero = firedrake.as_vector((0, ) * rank)
    self._bc = firedrake.DirichletBC(V, zero, problem.dirichlet_ids)

    # Create the derivative of the objective functional.  The derivative of
    # the regularization is built only here, where it is actually used
    # (previously it was also computed earlier and then discarded).
    self._dE = derivative(self._E, self._u)
    dR = derivative(self._R, self._p)
    self._dJ = action(adjoint(dF_dp), self._λ) + dR
def wrapper(self, *args, **kwargs):
    """Run the wrapped ``init`` and record what adjoint annotation needs.

    After calling ``init`` (captured from the enclosing scope), this copies
    ``F``, ``u``, ``bcs`` and ``J`` into the matching ``_ad_*`` attributes.
    It stores the adjoint of ``dF/du`` as ``self._ad_adj_F``, or ``None``
    if building it raises ``TypeError``. It collects ``Jp``,
    ``form_compiler_parameters`` and ``is_linear`` into
    ``self._ad_kwargs`` and starts ``self._ad_count_map`` as an empty dict.
    """
    from firedrake import TrialFunction, adjoint, derivative

    init(self, *args, **kwargs)

    for attr in ("F", "u", "bcs", "J"):
        setattr(self, "_ad_" + attr, getattr(self, attr))

    try:
        # Some forms (e.g. SLATE tensors) are not currently differentiable.
        trial = TrialFunction(self.u.function_space())
        adjoint_form = adjoint(derivative(self.F, self.u, trial))
    except TypeError:
        adjoint_form = None
    self._ad_adj_F = adjoint_form

    self._ad_kwargs = dict(
        Jp=self.Jp,
        form_compiler_parameters=self.form_compiler_parameters,
        is_linear=self.is_linear,
    )
    self._ad_count_map = {}
def eval_dJdw(self):
    """Return the assembled derivative of the Lagrangian w.r.t. ``self.X``.

    The derivative is taken in direction ``self.w``. The steps are:

    1. solve the state equation ``self.F == 0`` for ``self.u``;
    2. solve the assembled adjoint system (operator ``adjoint(dF/du)``,
       right-hand side ``-dJ/du``) for a new ``Function`` on ``self.V``;
    3. form the Lagrangian ``J + F`` with ``self.v`` replaced by that
       adjoint solution.

    Side effects: stores the Lagrangian as ``self.L`` and the adjoint
    operator as ``self.bil_form``.  Both solves use ``self.bc`` and
    ``self.params``.
    """
    state = self.u
    test_fn = self.v
    objective = self.J
    residual = self.F
    coords = self.X
    direction = self.w
    space = self.V
    opts = self.params

    solve(self.F == 0, state, bcs=self.bc, solver_parameters=opts)

    adjoint_operator = adjoint(derivative(residual, state))
    adjoint_rhs = -derivative(objective, state)
    adjoint_state = Function(space)
    solve(assemble(adjoint_operator), adjoint_state, assemble(adjoint_rhs),
          bcs=self.bc, solver_parameters=opts)

    lagrangian = objective + replace(self.F, {test_fn: adjoint_state})
    self.L = lagrangian
    self.bil_form = adjoint_operator
    return assemble(derivative(lagrangian, coords, direction))
def __init__(self, solver):
    r"""Set up the preconditioned conjugate-gradient state machine for the
    Gauss-Newton subproblem of ``solver``.

    The forms come from ``solver``: the misfit ``_E``, the regularization
    ``_R`` and the forward-model weak form ``_F``. They are differentiated
    with respect to ``solver.state`` and ``solver.parameter``. The
    right-hand side is ``solver.gradient``. Four linear solvers are built,
    all with ``solver._fc_params``, ``solver._solver_params`` and
    ``constant_jacobian=False``:

    * residual:               ``M z = -dJ`` where ``M = mass + d^2R``
    * tangent linear:         ``dF/du w = (dF/dp) s``
    * adjoint tangent linear: ``(dF/du)^* v = (d^2E/du^2) w``
    * residual update:        ``M dz = (dF/dp)^* v + (d^2R) s``

    Finishes by calling ``self.reinit()``.
    """
    self._assemble = solver._assemble
    state = solver.state
    param = solver.parameter
    misfit_grad = derivative(solver._E, state)
    reg_grad = derivative(solver._R, param)
    weak_form = solver._F
    jac_state = derivative(weak_form, state)
    jac_param = derivative(weak_form, param)
    # TODO: Make this an arbitrary RHS -- the solver can set it to the
    # gradient if we want
    rhs = solver.gradient
    bc = solver._bc
    V = state.function_space()
    Q = param.function_space()

    def make_linear_solver(*problem_args):
        # Positional args are forwarded as (a, L, u[, bcs]).
        problem = firedrake.LinearVariationalProblem(
            *problem_args,
            form_compiler_parameters=solver._fc_params,
            constant_jacobian=False)
        return firedrake.LinearVariationalSolver(
            problem, solver_parameters=solver._solver_params)

    # Preconditioned residual and its solver
    residual = firedrake.Function(Q)
    direction = firedrake.Function(Q)
    test_fn, trial_fn = firedrake.TestFunction(Q), firedrake.TrialFunction(Q)
    precond = test_fn * trial_fn * dx + derivative(reg_grad, param)
    residual_solver = make_linear_solver(precond, -rhs, residual)
    self._preconditioner = precond
    self._residual = residual
    self._search_direction = direction
    self._residual_solver = residual_solver

    # Current Gauss-Newton solution and the auxiliary tangent solutions
    solution = firedrake.Function(Q)
    adj_tangent = firedrake.Function(V)
    tangent = firedrake.Function(V)

    tangent_solver = make_linear_solver(
        jac_state, action(jac_param, direction), tangent, bc)
    adj_tangent_solver = make_linear_solver(
        adjoint(jac_state), derivative(misfit_grad, state, tangent),
        adj_tangent, bc)

    self._rhs = rhs
    self._solution = solution
    self._tangent_linear_solution = tangent
    self._tangent_linear_solver = tangent_solver
    self._adjoint_tangent_linear_solution = adj_tangent
    self._adjoint_tangent_linear_solver = adj_tangent_solver
    self._product = (action(adjoint(jac_param), adj_tangent)
                     + derivative(reg_grad, param, direction))

    # Update to the residual and its solver
    residual_delta = firedrake.Function(Q)
    delta_solver = make_linear_solver(precond, self._product, residual_delta)
    self._delta_residual = residual_delta
    self._delta_residual_solver = delta_solver

    self._residual_energy = 0.
    self._search_direction_energy = 0.
    self.reinit()
def _setup(self, problem, callback=(lambda s: None)):
    r"""Create the working copies, functionals, forms and solvers for ``problem``.

    Deep copies of ``problem.parameter`` and ``problem.state`` become
    ``self._p`` and ``self._u``; every form built here refers to those
    copies.  This also:

    * constructs the forward solver ``self._solver`` from
      ``self.problem.solver_type`` and ``self.problem.solver_kwargs``;
    * sets the objective ``self._E``, regularization ``self._R`` and their
      sum ``self._J``;
    * builds the forward-model weak form ``self._F`` and its Jacobian
      ``self._dF_du``;
    * creates a search direction ``self._q`` and the adjoint state
      ``self._λ`` with zero Dirichlet conditions ``self._bc`` on
      ``problem.dirichlet_ids``;
    * forms ``self._dE`` and the gradient ``self._dJ``;
    * builds ``self._adjoint_solver``, which solves
      ``adjoint(dF/du) λ = -dE``.

    Parameters
    ----------
    problem
        The inverse problem to solve.
    callback
        Stored as ``self._callback``; the default does nothing.
    """
    self._problem = problem
    self._callback = callback
    self._p = problem.parameter.copy(deepcopy=True)
    self._u = problem.state.copy(deepcopy=True)

    self._solver = self.problem.solver_type(self.problem.model,
                                            **self.problem.solver_kwargs)
    u_name, p_name = problem.state_name, problem.parameter_name
    solve_kwargs = dict(**problem.diagnostic_solve_kwargs,
                        **{u_name: self._u, p_name: self._p})

    # Make the form compiler use a reasonable number of quadrature points
    degree = problem.model.quadrature_degree(**solve_kwargs)
    self._fc_params = {'quadrature_degree': degree}

    # Create the error and regularization functionals
    self._E = problem.objective(self._u)
    self._R = problem.regularization(self._p)
    self._J = self._E + self._R

    # Create the weak form of the forward model and its Jacobian
    A = problem.model.action(**solve_kwargs)
    self._F = derivative(A, self._u)
    self._dF_du = derivative(self._F, self._u)

    # TODO: Make this customizable
    self._solver_params = default_solver_parameters

    # Create a search direction
    Q = self._p.function_space()
    self._q = firedrake.Function(Q)

    # Create the adjoint state variable
    V = self.state.function_space()
    self._λ = firedrake.Function(V)
    dF_dp = derivative(self._F, self._p)

    # Create Dirichlet BCs where they apply for the adjoint solve
    rank = self._λ.ufl_element().num_sub_elements()
    if rank == 0:
        zero = Constant(0)
    else:
        zero = firedrake.as_vector((0, ) * rank)
    self._bc = firedrake.DirichletBC(V, zero, problem.dirichlet_ids)

    # Create the derivative of the objective functional.  The derivative of
    # the regularization is built only here, where it is actually used
    # (previously it was also computed earlier and then discarded).
    self._dE = derivative(self._E, self._u)
    dR = derivative(self._R, self._p)
    self._dJ = action(adjoint(dF_dp), self._λ) + dR

    # Create problem and solver objects for the adjoint state
    L = adjoint(self._dF_du)
    adjoint_problem = firedrake.LinearVariationalProblem(
        L, -self._dE, self._λ, self._bc,
        form_compiler_parameters=self._fc_params,
        constant_jacobian=False)
    self._adjoint_solver = firedrake.LinearVariationalSolver(
        adjoint_problem, solver_parameters=self._solver_params)