Ejemplo n.º 1
0
    def assemble_data(self):
        """Assemble ``self.data`` (a form) with the backend and return the result.

        Forms whose first argument carries a ``_V_multi`` attribute are
        assembled with ``backend.assemble_multimesh``; all other forms go
        through the backend's regular assembler.  Under Firedrake the
        ``mat_type`` entry of ``self.solver_parameters`` is forwarded to
        ``backend.assemble``.

        When ``self.cache`` is truthy the assembled object is stored in
        ``caching.assembled_adj_forms`` keyed by the form, and later calls
        with the same form return the stored object.
        """
        assert not isinstance(self.data, IdentityMatrix)

        if backend.__name__ == "firedrake":
            # Firedrake specifies assembled matrix type as part of the
            # solver parameters.
            mat_type = self.solver_parameters.get("mat_type")

            def assemble(form):
                return backend.assemble(form, mat_type=mat_type)
        else:
            assemble = backend.assemble

        def assemble_form(form):
            # Multimesh forms need the dedicated multimesh assembler; all
            # others use `assemble` so that the Firedrake mat_type is honoured.
            if hasattr(form.arguments()[0], '_V_multi'):
                return backend.assemble_multimesh(form)
            return assemble(form)

        if not self.cache:
            return assemble_form(self.data)

        if self.data in caching.assembled_adj_forms:
            if backend.parameters["adjoint"]["debug_cache"]:
                backend.info_green("Got an assembly cache hit")
            return caching.assembled_adj_forms[self.data]

        if backend.parameters["adjoint"]["debug_cache"]:
            backend.info_red("Got an assembly cache miss")

        M = assemble_form(self.data)
        caching.assembled_adj_forms[self.data] = M
        return M
Ejemplo n.º 2
0
    def __call__(self, adjointer, timestep, dependencies, values):
        """Evaluate the functional by assembling the substituted form.

        The form is obtained from ``self._substitute_form``; if that returns
        None, 0.0 is returned.  Forms whose first coefficient carries a
        ``_V`` attribute are assembled with ``backend.assemble_multimesh``,
        all others with ``backend.assemble``.

        :raises libadjoint.exceptions.LibadjointErrorInvalidInputs: if the
            substituted form still has arguments (i.e. it is not rank 0).
        """
        functional_value = self._substitute_form(adjointer, timestep, dependencies, values)

        if functional_value is not None:
            args = ufl.algorithms.extract_arguments(functional_value)
            if len(args) > 0:
                backend.info_red("The form passed into Functional must be rank-0 (a scalar)! You have passed in a rank-%s form." % len(args))
                raise libadjoint.exceptions.LibadjointErrorInvalidInputs

            # A rank-0 form need not have any coefficients; guard the index so
            # such forms fall through to the regular assembler instead of
            # raising IndexError.
            coefficients = functional_value.coefficients()
            if coefficients and hasattr(coefficients[0], '_V'):
                return backend.assemble_multimesh(functional_value)
            else:
                return backend.assemble(functional_value)
        else:
            return 0.0
Ejemplo n.º 3
0
    def __call__(self, adjointer, timestep, dependencies, values):
        """Return the assembled value of the functional.

        ``self._substitute_form`` produces the form to evaluate; when it
        yields None the result is 0.0.  A form that still has arguments
        (rank > 0) is reported via ``backend.info_red`` and rejected with
        ``libadjoint.exceptions.LibadjointErrorInvalidInputs``.  Forms for
        which ``_has_multimesh`` is true are assembled with
        ``backend.assemble_multimesh``; all others with ``backend.assemble``.
        """
        form = self._substitute_form(adjointer, timestep,
                                     dependencies, values)
        if form is None:
            return 0.0

        rank = len(ufl.algorithms.extract_arguments(form))
        if rank > 0:
            backend.info_red(
                "The form passed into Functional must be rank-0 (a scalar)! You have passed in a rank-%s form."
                % rank)
            raise libadjoint.exceptions.LibadjointErrorInvalidInputs

        from .utils import _has_multimesh
        if _has_multimesh(form):
            assembler = backend.assemble_multimesh
        else:
            assembler = backend.assemble
        return assembler(form)
Ejemplo n.º 4
0
def wrap_assemble(form, test):
    '''Assemble ``form``, returning a zero vector if ``form`` is empty.

    If you do
       F = inner(grad(TrialFunction(V)), grad(TestFunction(V)))
       a = lhs(F); L = rhs(F)
       solve(a == L, ...)

    it works, even though L is empty. But if you try to assemble(L) as we do
    here, you get a crash.

    This function wraps assemble to catch that crash and return an empty RHS
    (a zero vector on the function space of ``test``) instead.

    :param form: the form to assemble; multimesh forms (first argument has a
        ``_V_multi`` attribute) use ``backend.assemble_multimesh``.
    :param test: test function whose function space sizes the zero vector
        returned for an empty form.
    :returns: the assembled vector.
    :raises RuntimeError: re-raised when assembly fails for a form that does
        have integrals.
    '''
    try:
        # An empty form has no arguments, so indexing arguments()[0] blindly
        # raised IndexError, which escaped the RuntimeError handler below.
        arguments = form.arguments()
        if arguments and hasattr(arguments[0], '_V_multi'):
            b = backend.assemble_multimesh(form)
        else:
            b = backend.assemble(form)
    except RuntimeError:
        # Only the empty-form crash is expected; any other assembly failure
        # must propagate rather than be masked by a zero vector (an `assert`
        # here would be stripped under `python -O`).
        if len(form.integrals()) != 0:
            raise
        b = backend.Function(test.function_space()).vector()

    return b
Ejemplo n.º 5
0
    def axpy(self, alpha, x):
        """Perform ``self <- self + alpha * x`` in place.

        ``self.data`` and ``x.data`` may each be None (empty), a backend
        Function/Coefficient, a MultiMeshFunction, or a UFL form; the branches
        below handle the supported combinations.  Any nonlinear-solve metadata
        carried by ``x`` (``nonlinear_form``, ``nonlinear_u``,
        ``nonlinear_bcs``, ``nonlinear_J``) is copied onto ``self``.

        :param alpha: scalar multiplier applied to ``x``.
        :param x: the object to add; nothing is added if ``x.zero`` is set.
        :raises NotImplementedError: if ``self.data`` is a MultiMeshFunction
            not handled by an earlier branch.
        :raises AssertionError: for any other unsupported combination of
            ``self.data`` / ``x.data`` types, or incompatible form arguments.
        """
        if hasattr(x, 'nonlinear_form'):
            self.nonlinear_form = x.nonlinear_form
            self.nonlinear_u = x.nonlinear_u
            self.nonlinear_bcs = x.nonlinear_bcs
            self.nonlinear_J = x.nonlinear_J

        if x.zero:
            return

        if self.data is None:
            # self is empty: initialise it as alpha * (a copy of) x.
            # These must be mutually exclusive branches; previously a plain
            # Function was copied and then overwritten by the `else` branch.
            if isinstance(x.data, backend.MultiMeshFunction):
                # Copy the vector: the (space, vector) constructor may share
                # the vector with x.data, in which case scaling would modify x.
                self.data = backend.MultiMeshFunction(x.data.function_space(),
                                                      x.data.vector().copy())
                self.data.vector()._scale(alpha)
            elif isinstance(x.data, backend.Function):
                self.data = x.data.copy(deepcopy=True)
                self.data.vector()._scale(alpha)
            else:
                self.data = alpha * x.data

        elif x.data is None:
            pass
        elif isinstance(self.data, backend.Coefficient):
            if isinstance(x.data, backend.Coefficient):
                try:
                    self.data.vector().axpy(alpha, x.data.vector())
                except Exception:
                    # Handle subfunctions: project x onto self's function
                    # space and retry.
                    # FIXME: use FunctionAssigner instead of a projection.
                    projected = backend.project(x.data, self.data.function_space())
                    self.data.vector().axpy(alpha, projected.vector())
            else:
                # This occurs when adding a RHS derivative to an adjoint equation
                # corresponding to the initial conditions.
                coefficients = x.data.coefficients()
                if coefficients and hasattr(coefficients[0], '_V'):
                    assembled = backend.assemble_multimesh(x.data)
                else:
                    assembled = backend.assemble(x.data)
                self.data.vector().axpy(alpha, assembled)
                self.data.form = alpha * x.data
        elif isinstance(x.data, ufl.form.Form) and isinstance(self.data, ufl.form.Form):
            # Let's do a bit of argument shuffling, shall we?
            xargs = ufl.algorithms.extract_arguments(x.data)
            sargs = ufl.algorithms.extract_arguments(self.data)

            if xargs != sargs:
                # Check that the arguments are compatible before substituting.
                assert len(xargs) == len(sargs)
                for xarg, sarg in zip(xargs, sargs):
                    assert xarg.element() == sarg.element()
                    assert xarg.function_space() == sarg.function_space()

                # Now that we are happy, replace x's arguments with self's.
                x_form = backend.replace(x.data, dict(zip(xargs, sargs)))
            else:
                x_form = x.data

            self.data += alpha * x_form
        elif isinstance(self.data, ufl.form.Form) and isinstance(x.data, backend.Function):
            # Assemble self's form, add alpha * x's vector, and turn self into
            # a Function on x's function space.  axpy only reads x's vector,
            # so no defensive copy is needed.
            self_vec = backend.assemble(self.data)
            self_vec.axpy(alpha, x.data.vector())
            new_fn = backend.Function(x.data.function_space())
            new_fn.vector()[:] = self_vec
            self.data = new_fn
            self.fn_space = self.data.function_space()
        elif isinstance(self.data, backend.MultiMeshFunction):
            raise NotImplementedError
        else:
            # Raise explicitly instead of `assert False` so the check survives
            # `python -O` (otherwise self would silently be marked non-zero).
            raise AssertionError(
                "axpy: unsupported types; self.data.__class__: %s, x.data.__class__: %s"
                % (self.data.__class__, x.data.__class__))

        self.zero = False