def project(self, other, annotate=None, *args, **kwargs):
    '''Project ``other`` onto ``self`` via ``firedrake_project``, optionally annotated.

    To disable the annotation, just pass :py:data:`annotate=False` to this
    routine, and it acts exactly like the Firedrake project call.

    Note: ``annotate`` comes before ``*args``, so a third positional argument
    is taken as ``annotate``; pass further projection options by keyword.

    :returns: whatever ``firedrake_project`` returns.
    '''
    to_annotate = utils.to_annotate(annotate)
    if to_annotate:
        return firedrake_project(self, other, *args, **kwargs)

    flag = misc.pause_annotation()
    try:
        return firedrake_project(self, other, *args, **kwargs)
    finally:
        # Always restore the previous annotation state, even if the
        # projection raises; otherwise annotation would stay paused.
        misc.continue_annotation(flag)
def solve(*args, **kwargs):
    """This solve routine wraps the real backend solve call.

    Its purpose is to annotate the model, recording what solves occur and
    what forms are involved, so that the adjoint and tangent linear models
    may be constructed automatically by libadjoint.

    To disable the annotation, just pass :py:data:`annotate=False` to this
    routine, and it acts exactly like the backend solve call. This is useful
    in cases where the solve is known to be irrelevant or diagnostic for the
    purposes of the adjoint computation (such as projecting fields to other
    function spaces for the purposes of visualisation).

    :returns: whatever ``backend.solve`` returns.
    :raises libadjoint.exceptions.LibadjointErrorInvalidInputs: when
        annotating with ``parameters["adjoint"]["record_all"]`` set and the
        first argument is neither a UFL ``Equation`` nor a supported matrix.
    """
    # First, decide if we should annotate or not.
    to_annotate = utils.to_annotate(kwargs.pop("annotate", None))
    if to_annotate:
        # Record this solve on the tape; the return value is not needed here.
        annotate(*args, **kwargs)

    # Avoid recursive annotation while the real solve runs; the finally
    # clause restores the previous state even if the solve raises.
    flag = misc.pause_annotation()
    try:
        ret = backend.solve(*args, **kwargs)
    finally:
        misc.continue_annotation(flag)

    # If we want to record all of the solutions of the real forward model
    # (for comparison with a libadjoint replay later), record the value of
    # the variable we just solved for.
    if to_annotate and backend.parameters["adjoint"]["record_all"]:
        if isinstance(args[0], ufl.classes.Equation):
            u = compatibility._extract_args(*args, **kwargs)[1]
        elif isinstance(args[0], compatibility.matrix_types()):
            u = args[1].function
        else:
            raise libadjoint.exceptions.LibadjointErrorInvalidInputs("Don't know how to record, sorry")
        adjglobals.adjointer.record_variable(
            adjglobals.adj_variables[u], libadjoint.MemoryStorage(adjlinalg.Vector(u))
        )

    return ret
def project_firedrake(*args, **kwargs):
    """Call ``backend.project``, pausing annotation unless it is wanted.

    The optional ``annotate`` keyword is removed from ``kwargs`` before
    forwarding. Projections whose first argument is a ``backend.Expression``
    or ``backend.Constant`` are only annotated when ``annotate=True`` is
    passed explicitly.

    :returns: whatever ``backend.project`` returns.
    """
    annotate = kwargs.pop("annotate", None)
    to_annotate = utils.to_annotate(annotate)

    # Expressions and Constants are not annotated by default; an explicit
    # annotate=True overrides this.
    if isinstance(args[0], (backend.Expression, backend.Constant)) and annotate is not True:
        to_annotate = False

    if to_annotate:
        return backend.project(*args, **kwargs)

    flag = misc.pause_annotation()
    try:
        return backend.project(*args, **kwargs)
    finally:
        # Restore annotation even if the projection raises.
        misc.continue_annotation(flag)
def __call__(self, m_dot, project=False):
    """Evaluate the action of the Hessian of ``self.J`` w.r.t. ``self.m`` on ``m_dot``.

    The tangent linear model is run first for the perturbation ``m_dot``.
    Then, for every annotated equation in reverse order, the adjoint
    solution is fetched (or computed and recorded), the second-order
    adjoint (SOA) solution is computed, and the Hessian-action
    contributions are accumulated.

    Annotation is paused while this runs and is restored afterwards, even
    if an error is raised. ``postprocess`` runs after annotation has been
    restored.

    :param m_dot: the direction to apply the Hessian to (passed through
        ``enlist``); each component must either have a ``function_space()``
        method or be a ``float``.
    :param project: forwarded to ``postprocess``.
    :returns: ``postprocess`` applied to the accumulated Hessian action,
        one entry per component of ``m_dot``.
    :raises NotImplementedError: if a component of ``m_dot`` is neither
        function-space-valued nor a ``float``.
    :raises ValueError: if an accumulator is neither a ``backend.Function``
        nor a ``float``.
    """

    def accumulate(totals, contributions):
        # Add each non-None contribution into the matching running total
        # (in place for Functions) and return the totals list.
        assert len(contributions) == len(totals)
        for k, contribution in enumerate(contributions):
            if contribution is None:
                continue
            if isinstance(totals[k], backend.Function):
                totals[k].vector().axpy(1.0, contribution.vector())
            elif isinstance(totals[k], float):
                totals[k] += contribution
            else:
                raise ValueError("Do not know what to do with this")
        return totals

    flag = misc.pause_annotation()
    try:
        hess_action_timer = backend.Timer("Hessian action")

        m_p = self.m.set_perturbation(m_dot)
        last_timestep = adjglobals.adjointer.timestep_count

        m_dot = enlist(m_dot)
        Hm = []
        for m_dot_cmp in m_dot:
            if hasattr(m_dot_cmp, 'function_space'):
                Hm.append(backend.Function(m_dot_cmp.function_space()))
            elif isinstance(m_dot_cmp, float):
                Hm.append(0.0)
            else:
                raise NotImplementedError("Sorry, don't know how to handle this")

        tlm_timer = backend.Timer("Hessian action (TLM)")
        # Run the tangent linear model. The generator is drained only for its
        # side effects (presumably recording the TLM solutions that
        # get_soa_solution needs -- TODO confirm).
        for _tlm, _tlm_var in compute_tlm(m_p, forget=None):
            pass
        tlm_timer.stop()

        # Run the adjoint and second-order adjoint equations, last equation first.
        for i in reversed(range(adjglobals.adjointer.equation_count)):
            adj_var = adjglobals.adjointer.get_forward_variable(i).to_adjoint(self.J)
            # Only recompute the adjoint variable if we do not have it yet.
            try:
                adj = adjglobals.adjointer.get_variable_value(adj_var)
            except (libadjoint.exceptions.LibadjointErrorHashFailed,
                    libadjoint.exceptions.LibadjointErrorNeedValue):
                adj_timer = backend.Timer("Hessian action (ADM)")
                adj = adjglobals.adjointer.get_adjoint_solution(i, self.J)[1]
                adj_timer.stop()

                storage = libadjoint.MemoryStorage(adj)
                adjglobals.adjointer.record_variable(adj_var, storage)

            adj = adj.data

            soa_timer = backend.Timer("Hessian action (SOA)")
            (soa_var, soa_vec) = adjglobals.adjointer.get_soa_solution(i, self.J, m_p)
            soa_timer.stop()
            soa = soa_vec.data

            func_timer = backend.Timer("Hessian action (derivative formula)")
            # Now implement the Hessian action formula.
            out = self.m.equation_partial_derivative(adjglobals.adjointer, soa, i, soa_var.to_forward())
            Hm = accumulate(Hm, out)

            out = self.m.equation_partial_second_derivative(adjglobals.adjointer, adj, i, soa_var.to_forward(), m_dot)
            Hm = accumulate(Hm, out)

            if last_timestep > adj_var.timestep:
                # We have hit a new timestep, and need to compute this
                # timestep's \partial^2 J/\partial m^2 contribution.
                last_timestep = adj_var.timestep
                out = self.m.functional_partial_second_derivative(adjglobals.adjointer, self.J, adj_var.timestep, m_dot)
                Hm = accumulate(Hm, out)

            func_timer.stop()

            storage = libadjoint.MemoryStorage(soa_vec)
            storage.set_overwrite(True)
            adjglobals.adjointer.record_variable(soa_var, storage)

        for Hm_cmp in Hm:
            if isinstance(Hm_cmp, backend.Function):
                Hm_cmp.rename("d^2(%s)/d(%s)^2" % (str(self.J), str(self.m)), "a Function from dolfin-adjoint")

        hess_action_timer.stop()
    finally:
        # Restore the previous annotation state even if anything above raised.
        misc.continue_annotation(flag)

    return postprocess(Hm, project, list_type=self.enlisted_controls)
def compute_gradient(J, param, forget=True, ignore=(), callback=lambda var, output: None, project=False):
    """Compute the gradient of the functional ``J`` with respect to ``param``.

    The adjoint equations are solved from the last annotated equation to the
    first. Each adjoint solution is recorded, and its contribution to
    dJ/dparam, plus the explicit dJ/dparam contribution of each timestep, is
    accumulated. Annotation is paused during the computation and restored
    afterwards, even if an error is raised; ``postprocess`` runs after it
    has been restored.

    :param J: the functional; must be a ``Functional``.
    :param param: a control or list of controls (passed through ``enlist``).
    :param forget: after each adjoint equation is processed, ``True`` calls
        ``forget_adjoint_equation(i)``, ``False`` calls
        ``forget_adjoint_values(i)``, and ``None`` forgets nothing.
    :param ignore: forward variables whose adjoint equations are skipped.
        Each entry may be a ``backend.Function``, a variable name (``str``,
        taken at timestep 0 and iteration 0), or a libadjoint variable.
        ``None`` is treated as empty.
    :param callback: called as ``callback(adj_var, adjoint_data)`` after
        each adjoint solve.
    :param project: forwarded to ``postprocess``.
    :returns: ``postprocess`` applied to the accumulated gradient, one entry
        per control.
    :raises ValueError: if ``J`` is not a ``Functional``.
    """
    if not isinstance(J, Functional):
        raise ValueError("J must be of type dolfin_adjoint.Functional.")

    if ignore is None:
        ignore = ()

    flag = misc.pause_annotation()
    try:
        enlisted_controls = enlist(param)
        param = ListControl(enlisted_controls)

        if backend.parameters["adjoint"]["allow_zero_derivatives"]:
            dJ_init = []
            for c in enlisted_controls:
                data = c.data()
                if isinstance(data, backend.Constant):
                    dJ_init.append(backend.Constant(0))
                elif isinstance(data, backend.Function):
                    dJ_init.append(backend.Function(data.function_space()))
                else:
                    # Keep dJ_init aligned one-to-one with the controls;
                    # None is the same "no contribution yet" placeholder
                    # used by the branch below.
                    dJ_init.append(None)
        else:
            dJ_init = [None] * len(enlisted_controls)

        dJdparam = type(enlisted_controls)(dJ_init)

        last_timestep = adjglobals.adjointer.timestep_count

        ignorelist = []
        for fn in ignore:
            if isinstance(fn, backend.Function):
                ignorelist.append(adjglobals.adj_variables[fn])
            elif isinstance(fn, str):
                ignorelist.append(libadjoint.Variable(fn, 0, 0))
            else:
                ignorelist.append(fn)

        for i in range(adjglobals.adjointer.timestep_count):
            adjglobals.adjointer.set_functional_dependencies(J, i)

        for i in reversed(range(adjglobals.adjointer.equation_count)):
            fwd_var = adjglobals.adjointer.get_forward_variable(i)
            if fwd_var in ignorelist:
                info("Ignoring the adjoint equation for %s" % fwd_var)
                continue

            (adj_var, output) = adjglobals.adjointer.get_adjoint_solution(i, J)

            callback(adj_var, output.data)

            storage = libadjoint.MemoryStorage(output)
            storage.set_overwrite(True)
            adjglobals.adjointer.record_variable(adj_var, storage)

            # Forward variable matching the adjoint variable just solved for.
            adj_fwd_var = libadjoint.Variable(adj_var.name, adj_var.timestep, adj_var.iteration)

            out = param.equation_partial_derivative(adjglobals.adjointer, output.data, i, adj_fwd_var)
            dJdparam = _add(dJdparam, out)

            if last_timestep > adj_var.timestep:
                # We have hit a new timestep, and need to compute this
                # timestep's \partial J/\partial m contribution.
                out = param.functional_partial_derivative(adjglobals.adjointer, J, adj_var.timestep)
                dJdparam = _add(dJdparam, out)
                last_timestep = adj_var.timestep

            if forget is not None:
                if forget:
                    adjglobals.adjointer.forget_adjoint_equation(i)
                else:
                    adjglobals.adjointer.forget_adjoint_values(i)

        rename(J, dJdparam, param)
    finally:
        # Restore the previous annotation state even if anything above raised.
        misc.continue_annotation(flag)

    return postprocess(dJdparam, project, list_type=enlisted_controls)