Example #1
    def project(self, other, annotate=None, *args, **kwargs):
        '''To disable the annotation, just pass :py:data:`annotate=False` to this routine, and it acts exactly like the
        Firedrake project call.'''

        to_annotate = utils.to_annotate(annotate)

        if not to_annotate:
            # Pause annotation so the underlying Firedrake projection is not recorded on the tape.
            flag = misc.pause_annotation()

        res = firedrake_project(self, other, *args, **kwargs)

        if not to_annotate:
            misc.continue_annotation(flag)

        return res
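
A minimal usage sketch for the wrapped method above, assuming it is patched onto a Firedrake Function and that a function space V and a source field f already exist in the surrounding model (both names are illustrative):

# Illustrative setup: V and f are assumed to come from an existing Firedrake model.
u = Function(V)
u.project(f)                   # recorded on the annotation tape as usual
u.project(f, annotate=False)   # behaves exactly like the plain Firedrake project call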
Example #2
def solve(*args, **kwargs):
    """This solve routine wraps the real Dolfin solve call. Its purpose is to annotate the model,
    recording what solves occur and what forms are involved, so that the adjoint and tangent linear models may be
    constructed automatically by libadjoint.

    To disable the annotation, just pass :py:data:`annotate=False` to this routine, and it acts exactly like the
    Dolfin solve call. This is useful when the solve is known to be irrelevant or purely diagnostic
    as far as the adjoint computation is concerned (such as projecting fields to other function spaces
    just for visualisation)."""

    # First, decide if we should annotate or not.
    to_annotate = utils.to_annotate(kwargs.pop("annotate", None))
    if to_annotate:
        linear = annotate(*args, **kwargs)

    # Avoid recursive annotation
    flag = misc.pause_annotation()
    try:
        ret = backend.solve(*args, **kwargs)
    finally:
        misc.continue_annotation(flag)

    if to_annotate:
        # Finally, if we want to record all of the solutions of the real forward model
        # (for comparison with a libadjoint replay later),
        # then we should record the value of the variable we just solved for.
        if backend.parameters["adjoint"]["record_all"]:
            if isinstance(args[0], ufl.classes.Equation):
                unpacked_args = compatibility._extract_args(*args, **kwargs)
                u = unpacked_args[1]
                adjglobals.adjointer.record_variable(
                    adjglobals.adj_variables[u], libadjoint.MemoryStorage(adjlinalg.Vector(u))
                )
            elif isinstance(args[0], compatibility.matrix_types()):
                u = args[1].function
                adjglobals.adjointer.record_variable(
                    adjglobals.adj_variables[u], libadjoint.MemoryStorage(adjlinalg.Vector(u))
                )
            else:
                raise libadjoint.exceptions.LibadjointErrorInvalidInputs("Don't know how to record, sorry")

    return ret
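
A minimal sketch of how the wrapper is called, assuming a standard variational problem a == L with a solution function u and boundary conditions bcs already defined (all names are illustrative):

u = Function(V)
solve(a == L, u, bcs)                          # annotated: recorded for the adjoint/TLM models
u_viz = Function(V_viz)
solve(a_viz == L_viz, u_viz, annotate=False)   # diagnostic solve, deliberately left off the tape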
Example #3
def project_firedrake(*args, **kwargs):

  annotate = kwargs.pop("annotate", None)

  to_annotate = utils.to_annotate(annotate)

  # Projections of bare Expressions or Constants are not annotated unless annotate=True is passed explicitly.
  if isinstance(args[0], (backend.Expression, backend.Constant)) and (annotate is not True):
    to_annotate = False

  if to_annotate:
    result = backend.project(*args, **kwargs)
  else:
    flag = misc.pause_annotation()
    result = backend.project(*args, **kwargs)
    misc.continue_annotation(flag)

  return result
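
A minimal sketch of the free-function form, assuming the wrapper is exposed as project and that f and V exist in the surrounding model (names are illustrative):

p = project(f, V)                       # annotated projection
p_out = project(f, V, annotate=False)   # skipped by the annotation, e.g. output for visualisation only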
Example #4
    def __call__(self, m_dot, project=False):
        flag = misc.pause_annotation()
        hess_action_timer = backend.Timer("Hessian action")

        m_p = self.m.set_perturbation(m_dot)
        last_timestep = adjglobals.adjointer.timestep_count

        m_dot = enlist(m_dot)
        Hm = []
        for m_dot_cmp in m_dot:
            if hasattr(m_dot_cmp, 'function_space'):
                Hm.append(backend.Function(m_dot_cmp.function_space()))
            elif isinstance(m_dot_cmp, float):
                Hm.append(0.0)
            else:
                raise NotImplementedError("Sorry, don't know how to handle this")

        tlm_timer = backend.Timer("Hessian action (TLM)")
        # run the tangent linear model
        for (tlm, tlm_var) in compute_tlm(m_p, forget=None):
            pass

        tlm_timer.stop()

        # run the adjoint and second-order adjoint equations.
        for i in range(adjglobals.adjointer.equation_count)[::-1]:
            adj_var = adjglobals.adjointer.get_forward_variable(i).to_adjoint(self.J)
            # Only recompute the adjoint variable if we do not have it yet
            try:
                adj = adjglobals.adjointer.get_variable_value(adj_var)
            except (libadjoint.exceptions.LibadjointErrorHashFailed, libadjoint.exceptions.LibadjointErrorNeedValue):
                adj_timer = backend.Timer("Hessian action (ADM)")
                adj = adjglobals.adjointer.get_adjoint_solution(i, self.J)[1]
                adj_timer.stop()

                storage = libadjoint.MemoryStorage(adj)
                adjglobals.adjointer.record_variable(adj_var, storage)

            adj = adj.data

            soa_timer = backend.Timer("Hessian action (SOA)")
            (soa_var, soa_vec) = adjglobals.adjointer.get_soa_solution(i, self.J, m_p)
            soa_timer.stop()
            soa = soa_vec.data

            def hess_inner(Hm, out):
                assert len(out) == len(Hm)
                for i in range(len(out)):
                    if out[i] is not None:
                        if isinstance(Hm[i], backend.Function):
                            Hm[i].vector().axpy(1.0, out[i].vector())
                        elif isinstance(Hm[i], float):
                            Hm[i] += out[i]
                        else:
                            raise ValueError("Do not know what to do with this")
                return Hm

            func_timer = backend.Timer("Hessian action (derivative formula)")
            # now implement the Hessian action formula.
            out = self.m.equation_partial_derivative(adjglobals.adjointer, soa, i, soa_var.to_forward())
            Hm = hess_inner(Hm, out)

            out = self.m.equation_partial_second_derivative(adjglobals.adjointer, adj, i, soa_var.to_forward(), m_dot)
            Hm = hess_inner(Hm, out)

            if last_timestep > adj_var.timestep:
                # We have hit a new timestep, and need to compute this timestep's \partial^2 J/\partial m^2 contribution
                last_timestep = adj_var.timestep
                out = self.m.functional_partial_second_derivative(adjglobals.adjointer, self.J, adj_var.timestep, m_dot)
                Hm = hess_inner(Hm, out)

            func_timer.stop()

            storage = libadjoint.MemoryStorage(soa_vec)
            storage.set_overwrite(True)
            adjglobals.adjointer.record_variable(soa_var, storage)

        for Hm_cmp in Hm:
            if isinstance(Hm_cmp, backend.Function):
                Hm_cmp.rename("d^2(%s)/d(%s)^2" % (str(self.J), str(self.m)), "a Function from dolfin-adjoint")

        misc.continue_annotation(flag)
        return postprocess(Hm, project, list_type=self.enlisted_controls)
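
A minimal sketch of driving this Hessian action, assuming the usual dolfin-adjoint pattern in which hessian(J, m) returns an object whose __call__ is the method above (J, nu and m_dot are illustrative placeholders):

J = Functional(inner(u, u)*dx*dt[FINISH_TIME])
m = Control(nu)
H = hessian(J, m)
Hm_dot = H(m_dot)   # one Hessian-vector product: TLM sweep forward, then ADM/SOA sweep backward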
Example #5
def compute_gradient(J, param, forget=True, ignore=[], callback=lambda var, output: None, project=False):
    if not isinstance(J, Functional):
        raise ValueError("J must be of type dolfin_adjoint.Functional.")

    flag = misc.pause_annotation()

    enlisted_controls = enlist(param)
    param = ListControl(enlisted_controls)

    if backend.parameters["adjoint"]["allow_zero_derivatives"]:
        dJ_init = []
        for c in enlisted_controls:
            if isinstance(c.data(), backend.Constant):
                dJ_init.append(backend.Constant(0))
            elif isinstance(c.data(), backend.Function):
                space = c.data().function_space()
                dJ_init.append(backend.Function(space))

    else:
        dJ_init = [None] * len(enlisted_controls)

    dJdparam = enlisted_controls.__class__(dJ_init)

    last_timestep = adjglobals.adjointer.timestep_count

    ignorelist = []
    for fn in ignore:
        if isinstance(fn, backend.Function):
            ignorelist.append(adjglobals.adj_variables[fn])
        elif isinstance(fn, str):
            ignorelist.append(libadjoint.Variable(fn, 0, 0))
        else:
            ignorelist.append(fn)

    for i in range(adjglobals.adjointer.timestep_count):
        adjglobals.adjointer.set_functional_dependencies(J, i)

    for i in range(adjglobals.adjointer.equation_count)[::-1]:
        fwd_var = adjglobals.adjointer.get_forward_variable(i)
        if fwd_var in ignorelist:
            info("Ignoring the adjoint equation for %s" % fwd_var)
            continue

        (adj_var, output) = adjglobals.adjointer.get_adjoint_solution(i, J)

        callback(adj_var, output.data)

        storage = libadjoint.MemoryStorage(output)
        storage.set_overwrite(True)
        adjglobals.adjointer.record_variable(adj_var, storage)
        fwd_var = libadjoint.Variable(adj_var.name, adj_var.timestep, adj_var.iteration)

        out = param.equation_partial_derivative(adjglobals.adjointer, output.data, i, fwd_var)
        dJdparam = _add(dJdparam, out)

        if last_timestep > adj_var.timestep:
            # We have hit a new timestep, and need to compute this timestep's \partial J/\partial m contribution
            out = param.functional_partial_derivative(adjglobals.adjointer, J, adj_var.timestep)
            dJdparam = _add(dJdparam, out)

        last_timestep = adj_var.timestep

        if forget is None:
            pass
        elif forget:
            adjglobals.adjointer.forget_adjoint_equation(i)
        else:
            adjglobals.adjointer.forget_adjoint_values(i)

    rename(J, dJdparam, param)

    misc.continue_annotation(flag)

    return postprocess(dJdparam, project, list_type=enlisted_controls)
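
A minimal usage sketch, assuming a forward model has already been run and annotated, with nu a Constant control and u the final solution (names are illustrative):

J = Functional(inner(u, u)*dx*dt[FINISH_TIME])
m = Control(nu)
dJdnu = compute_gradient(J, m, forget=False)   # forget=False keeps adjoint values so the tape can be reused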