Example #1
0
 def repl(t):
     """
     Return a new :class:`Term` whose trial function is replaced by `new`.

     :arg t: the original :class:`Term`. Its form must have exactly two
         arguments; the second one (index 1) is treated as the trial function.

     :raises TypeError: if the form of `t` does not have exactly two arguments.
     """
     arguments = t.form.arguments()
     # Only a form with both a test and a trial function (bilinear) has a
     # trial function to replace
     if len(arguments) != 2:
         raise TypeError(
             'Trying to replace trial function of a form that is not bilinear'
         )
     trial = arguments[1]
     new_form = ufl.replace(t.form, {trial: new})
     return Term(new_form, t.labels)
Example #2
0
 def replace_transporting_velocity(self, uadv):
     """
     Replace the transporting velocity in every term of the residual that
     carries the ``transporting_velocity`` label.

     :arg uadv: either the string "prognostic", in which case each term's
         transporting velocity is replaced by the first component of that
         term's subject, or an expression used directly as the new
         transporting velocity. After the replacement, the
         ``transporting_velocity`` label value is updated to `uadv`.

     :raises ValueError: if some term has the ``transporting_velocity`` label
         but `uadv` is None.
     """
     # Nothing to do unless at least one term is transported by a velocity
     if not any(t.has_label(transporting_velocity) for t in self.residual):
         return

     # Explicit check rather than assert, so it is not stripped under -O
     if uadv is None:
         raise ValueError(
             'uadv must be provided when the residual contains terms '
             'with a transporting_velocity label')

     use_prognostic = isinstance(uadv, str) and uadv == "prognostic"

     def _replace_velocity(t):
         # Choose the velocity that replaces this term's transporting velocity
         if use_prognostic:
             new_velocity = split(t.get(subject))[0]
         else:
             new_velocity = uadv
         new_form = ufl.replace(
             t.form, {t.get(transporting_velocity): new_velocity})
         return Term(new_form, t.labels)

     self.residual = self.residual.label_map(
         lambda t: t.has_label(transporting_velocity),
         map_if_true=_replace_velocity
     )
     self.residual = transporting_velocity.update_value(self.residual, uadv)
Example #3
0
    def linearise_equation_set(self):
        """
        Linearise the whole equation set in place.

        Each term of ``self.residual`` that has a ``linearisation`` label is
        replaced by the form of that linearisation. The term keeps all of its
        labels, so the original linearisation (still containing the trial
        function) remains attached as a label. Terms without a linearisation
        are dropped. Finally, the trial functions are replaced by the
        prognostic variables ``self.X``.
        """
        def use_linearisation(term):
            # New term built from the linearised form, labels carried over
            return Term(term.get(linearisation).form, term.labels)

        def is_linearised(term):
            return term.has_label(linearisation)

        self.residual = self.residual.label_map(
            is_linearised,
            map_if_true=use_linearisation,
            map_if_false=drop)

        # Swap the trial functions for the prognostic fields
        trial_replacer = replace_trial_function(self.X)
        self.residual = self.residual.label_map(all_terms, trial_replacer)
Example #4
0
    def __init__(self, equation, alpha):
        """
        Build a linear variational solver from the linearised `equation`.

        Only terms of ``equation.residual`` with a ``linearisation`` label are
        kept, each replaced by the form of its linearisation. In the left-hand
        side, every term except the linearised time derivative is multiplied
        by ``beta = dt * alpha``. The right-hand side keeps only the linearised
        time derivative, applied to ``self.xrhs``. The solution is stored in
        ``self.dy``.

        :arg equation: the equation providing ``residual``, ``state.dt``,
            ``function_space`` and ``bcs['u']``.
        :arg alpha: scaling factor; ``dt * alpha`` multiplies the
            non-time-derivative terms.
        """
        def is_linear_time_derivative(term):
            return (term.has_label(time_derivative)
                    and term.has_label(linearisation))

        linear_residual = equation.residual.label_map(
            lambda term: term.has_label(linearisation),
            map_if_true=lambda term: Term(term.get(linearisation).form,
                                          term.labels),
            map_if_false=drop)

        timestep = equation.state.dt
        space = equation.function_space
        beta = timestep * alpha

        # Right-hand side vector (the forms below split it symbolically)
        self.xrhs = Function(space)

        # Left-hand side: time derivative unchanged, other terms scaled
        lhs = linear_residual.label_map(
            is_linear_time_derivative,
            map_if_false=lambda term: beta * term)
        # Right-hand side: only the time derivative term survives
        rhs = linear_residual.label_map(
            is_linear_time_derivative,
            map_if_false=drop)

        # Holds the result of the linear solve
        self.dy = Function(space)

        velocity_bcs = equation.bcs['u']
        problem = LinearVariationalProblem(
            lhs.form, action(rhs.form, self.xrhs), self.dy,
            bcs=velocity_bcs)

        self.solver = LinearVariationalSolver(
            problem,
            solver_parameters=self.solver_parameters,
            options_prefix='linear_solver')
Example #5
0
    def __init__(
            self,
            state,
            family,
            degree,
            fexpr=None,
            bexpr=None,
            terms_to_linearise=None,
            u_transport_option="vector_invariant_form",
            no_normal_flow_bc_ids=None,
            active_tracers=None):
        """
        Set up the equations with the parent class, then replace them with
        their linearisation.

        After ``self.linearise_equation_set()``, each transport term whose
        prognostic is "D" is replaced by a linear continuity form for D that
        includes a facet term.

        :arg terms_to_linearise: dict mapping field names to lists of labels
            of the terms to linearise. Defaults to
            ``{'D': [time_derivative, transport],
            'u': [time_derivative, pressure_gradient, coriolis]}``.

        All other arguments are passed unchanged to the parent constructor.
        """
        # Build the default per call so the dict and its lists are never
        # shared (and possibly mutated) between instances
        if terms_to_linearise is None:
            terms_to_linearise = {
                'D': [time_derivative, transport],
                'u': [time_derivative, pressure_gradient, coriolis]
            }

        super().__init__(state,
                         family,
                         degree,
                         fexpr=fexpr,
                         bexpr=bexpr,
                         terms_to_linearise=terms_to_linearise,
                         u_transport_option=u_transport_option,
                         no_normal_flow_bc_ids=no_normal_flow_bc_ids,
                         active_tracers=active_tracers)

        # Use the underlying routine to do a first linearisation of the equations
        self.linearise_equation_set()

        # D transport term is a special case -- add facet term
        _, D = split(self.X)
        _, phi = self.tests
        D_adv = prognostic(
            linear_continuity_form(state, phi, D, facet_term=True), "D")
        self.residual = self.residual.label_map(
            lambda t: t.has_label(transport) and t.get(prognostic) == "D",
            map_if_true=lambda t: Term(D_adv.form, t.labels))
Example #6
0
    def repl(t):
        """
        Function returned by replace_subject to return a new :class:`Term` with
        the subject replaced by the variable `new`. It is built around the ufl
        replace routine.

        Returns a new :class:`Term`.

        :arg t: the original :class:`Term`.

        :raises ValueError: if `new` has an unsupported type, if the number of
            components of `new` does not match that of the subject, or if
            `idx` is needed to select a component but is None.
        """

        subj = t.get(subject)

        # Build a dictionary to pass to the ufl.replace routine
        # The dictionary matches variables in the old term with those in the new
        replace_dict = {}

        # Consider cases that subj is normal Function or MixedFunction
        # vs cases of new being Function vs MixedFunction vs tuple
        # Unsupported combinations fail with a ValueError
        if type(subj.ufl_element()) is MixedElement:
            if type(new) is tuple:
                if len(new) != len(subj.function_space()):
                    raise ValueError(
                        'new must have the same number of components as the'
                        f' subject: got {len(new)}, expected'
                        f' {len(subj.function_space())}')
                replace_dict.update(zip(split(subj), new))

            elif type(new) is ufl.algebra.Sum:
                # A sum expression replaces the whole mixed subject
                replace_dict[subj] = new

            # Otherwise fail if new is not a function
            elif not isinstance(new, Function):
                raise ValueError(
                    f'new must be a tuple or Function, not type {type(new)}')

            # Now handle MixedElements separately as these need indexing
            elif type(new.ufl_element()) is MixedElement:
                if len(new.function_space()) != len(subj.function_space()):
                    raise ValueError(
                        'new must have the same number of components as the'
                        f' subject: got {len(new.function_space())}, expected'
                        f' {len(subj.function_space())}')
                # If idx specified, replace only that component
                if idx is not None:
                    replace_dict[split(subj)[idx]] = split(new)[idx]
                # Otherwise replace all components
                else:
                    replace_dict.update(zip(split(subj), split(new)))

            # Otherwise 'new' is a normal Function replacing one component
            else:
                if idx is None:
                    raise ValueError(
                        'idx must be specified to replace_subject' +
                        ' when the subject is mixed and new is not')
                replace_dict[split(subj)[idx]] = new

        # subj is a normal Function
        else:
            if type(new) is tuple:
                if idx is None:
                    raise ValueError(
                        'idx must be specified to replace_subject' +
                        ' when new is a tuple')
                replace_dict[subj] = new[idx]
            elif not isinstance(new, Function):
                raise ValueError(
                    f'new must be a Function, not type {type(new)}')
            elif type(new.ufl_element()) is MixedElement:
                if idx is None:
                    raise ValueError(
                        'idx must be specified to replace_subject' +
                        ' when new is a mixed Function')
                replace_dict[subj] = split(new)[idx]
            else:
                replace_dict[subj] = new

        new_form = ufl.replace(t.form, replace_dict)

        return Term(new_form, t.labels)
Example #7
0
 def repl(t):
     """
     Return a copy of Term `t` whose test function is swapped for `new_test`.

     The test function is taken to be the first entry of the form's
     arguments. The labels of `t` are carried over unchanged.

     :arg t: the original :class:`Term`.
     """
     replacement = {t.form.arguments()[0]: new_test}
     return Term(ufl.replace(t.form, replacement), t.labels)
Example #8
0
    def setup(self, equation, uadv=None, apply_bcs=True, *active_labels):
        """
        Set up the discretisation for `equation`.

        Extracts the relevant part of the equation's residual and deals with
        the transport terms and the transporting velocity. It then builds the
        extra objects needed by ``self.discretisation_option``: embedding
        spaces, projectors/recoverers, a modified test function and boundary
        conditions.

        :arg equation: the equation whose ``residual``, ``bcs``,
            ``function_space`` and ``field_name``/``field_names`` are read.
        :arg uadv: the transporting velocity. It is passed to
            ``self.replace_transporting_velocity`` and, for SUPG, used to build
            the modified test function. Defaults to None.
        :arg apply_bcs: if False, ``self.bcs`` is set to None.
        :arg active_labels: if any are given, only terms with the
            ``time_derivative`` label or one of these labels are kept.
        """

        self.residual = equation.residual

        if self.field_name is not None:
            self.idx = equation.field_names.index(self.field_name)
            self.fs = self.state.fields(self.field_name).function_space()
            # Keep only this field's part of the residual
            self.residual = self.residual.label_map(
                lambda t: t.get(prognostic) == self.field_name,
                lambda t: Term(
                    split_form(t.form)[self.idx].form,
                    t.labels),
                drop)
            bcs = equation.bcs[self.field_name]

        else:
            self.field_name = equation.field_name
            self.fs = equation.function_space
            self.idx = None
            if type(self.fs.ufl_element()) is MixedElement:
                # Gather the boundary conditions of every field into one list
                bcs = [bc for field_bcs in equation.bcs.values()
                       for bc in field_bcs]
            else:
                bcs = equation.bcs[self.field_name]

        if active_labels:
            self.residual = self.residual.label_map(
                lambda t: any(t.has_label(time_derivative, *active_labels)),
                map_if_false=drop)

        options = self.options

        # -------------------------------------------------------------------- #
        # Routines relating to transport
        # -------------------------------------------------------------------- #

        if hasattr(self.options, 'ibp'):
            self.replace_transport_term()
        self.replace_transporting_velocity(uadv)

        # -------------------------------------------------------------------- #
        # Wrappers for embedded / recovery methods
        # -------------------------------------------------------------------- #

        if self.discretisation_option in ["embedded_dg", "recovered"]:
            # construct the embedding space if not specified
            if options.embedding_space is None:
                V_elt = BrokenElement(self.fs.ufl_element())
                self.fs = FunctionSpace(self.state.mesh, V_elt)
            else:
                self.fs = options.embedding_space
            self.xdg_in = Function(self.fs)
            self.xdg_out = Function(self.fs)
            if self.idx is None:
                self.x_projected = Function(equation.function_space)
            else:
                self.x_projected = Function(self.state.fields(self.field_name).function_space())
            new_test = TestFunction(self.fs)
            parameters = {'ksp_type': 'cg',
                          'pc_type': 'bjacobi',
                          'sub_pc_type': 'ilu'}

        # -------------------------------------------------------------------- #
        # Make boundary conditions
        # -------------------------------------------------------------------- #

        if not apply_bcs:
            self.bcs = None
        elif self.discretisation_option in ["embedded_dg", "recovered"]:
            # Transfer boundary conditions onto test function space
            self.bcs = [DirichletBC(self.fs, bc.function_arg, bc.sub_domain) for bc in bcs]
        else:
            self.bcs = bcs

        # -------------------------------------------------------------------- #
        # Modify test function for SUPG methods
        # -------------------------------------------------------------------- #

        if self.discretisation_option == "supg":
            # construct tau, if it is not specified
            dim = self.state.mesh.topological_dimension()
            if options.tau is not None:
                # if tau is provided, check that it has the right size
                tau = options.tau
                assert as_ufl(tau).ufl_shape == (dim, dim), "Provided tau has incorrect shape!"
            else:
                # create tuple of default values of size dim
                default_vals = [options.default*self.dt]*dim
                # check for directions in which the space is discontinuous
                # so that we don't apply supg in that direction
                if is_cg(self.fs):
                    vals = default_vals
                else:
                    space = self.fs.ufl_element().sobolev_space()
                    if space.name in ["HDiv", "DirectionalH"]:
                        vals = [default_vals[i] if space[i].name == "H1"
                                else 0. for i in range(dim)]
                    else:
                        raise ValueError("I don't know what to do with space %s" % space)
                # Diagonal dim x dim tau built from vals
                tau = Constant(tuple([
                    tuple(
                        [vals[j] if i == j else 0. for i, v in enumerate(vals)]
                    ) for j in range(dim)])
                )
                # NOTE(review): these parameters are only set when tau is not
                # supplied by the user -- confirm this is intended
                self.solver_parameters = {'ksp_type': 'gmres',
                                          'pc_type': 'bjacobi',
                                          'sub_pc_type': 'ilu'}

            test = TestFunction(self.fs)
            new_test = test + dot(dot(uadv, tau), grad(test))

        # NOTE(review): new_test is only defined for the "embedded_dg",
        # "recovered" and "supg" options; any other non-None option would
        # fail here with a NameError.
        if self.discretisation_option is not None:
            # replace the original test function with one defined on
            # the embedding space, as this is the space where
            # the problem will be solved
            self.residual = self.residual.label_map(
                all_terms,
                map_if_true=replace_test_function(new_test))

        if self.discretisation_option == "embedded_dg":
            if self.limiter is None:
                self.x_out_projector = Projector(self.xdg_out, self.x_projected,
                                                 solver_parameters=parameters)
            else:
                self.x_out_projector = Recoverer(self.xdg_out, self.x_projected)

        if self.discretisation_option == "recovered":
            # set up the necessary functions
            self.x_in = Function(self.state.fields(self.field_name).function_space())
            x_rec = Function(options.recovered_space)
            x_brok = Function(options.broken_space)

            # set up interpolators and projectors
            self.x_rec_projector = Recoverer(self.x_in, x_rec, VDG=self.fs, boundary_method=options.boundary_method)  # recovered function
            self.x_brok_projector = Projector(x_rec, x_brok)  # function projected back
            self.xdg_interpolator = Interpolator(self.x_in + x_rec - x_brok, self.xdg_in)
            if self.limiter is not None:
                self.x_brok_interpolator = Interpolator(self.xdg_out, x_brok)
                self.x_out_projector = Recoverer(x_brok, self.x_projected)
            else:
                self.x_out_projector = Projector(self.xdg_out, self.x_projected)

        # setup required functions
        self.dq = Function(self.fs)
        self.q1 = Function(self.fs)
Example #9
0
    def __init__(
            self,
            state,
            family,
            degree,
            Omega=None,
            terms_to_linearise=None,
            u_transport_option="vector_invariant_form",
            no_normal_flow_bc_ids=None,
            active_tracers=None):
        """
        Set up the Boussinesq equations for the fields u, p and b.

        The residual includes mass, transport, pressure gradient, gravity and
        incompressibility terms, plus a Coriolis term if `Omega` is given.
        Linearisations are then attached via ``self.generate_linear_terms``.

        :arg state: the model state, providing fields, spaces and ``k``.
        :arg family: passed to the parent constructor.
        :arg degree: passed to the parent constructor.
        :arg Omega: if not None, adds the term ``inner(w, cross(2*Omega, u))``.
        :arg terms_to_linearise: dict mapping field names to lists of labels
            of the terms to linearise. Defaults to
            ``{'u': [time_derivative], 'p': [time_derivative],
            'b': [time_derivative, transport]}``.
        :arg u_transport_option: one of "vector_invariant_form",
            "vector_advection_form", "vector_manifold_advection_form" or
            "circulation_form".
        :arg no_normal_flow_bc_ids: passed to the parent constructor.
        :arg active_tracers: must be None; tracers are not implemented.

        :raises NotImplementedError: if `active_tracers` is not None.
        :raises ValueError: if `u_transport_option` is not recognised.
        """
        # Build the default per call so the dict and its lists are never
        # shared (and possibly mutated) between instances
        if terms_to_linearise is None:
            terms_to_linearise = {
                'u': [time_derivative],
                'p': [time_derivative],
                'b': [time_derivative, transport]
            }

        field_names = ['u', 'p', 'b']

        if active_tracers is not None:
            raise NotImplementedError(
                'Tracers not implemented for Boussinesq equations')

        if active_tracers is None:
            active_tracers = []

        super().__init__(field_names,
                         state,
                         family,
                         degree,
                         terms_to_linearise=terms_to_linearise,
                         no_normal_flow_bc_ids=no_normal_flow_bc_ids,
                         active_tracers=active_tracers)

        w, phi, gamma = self.tests[0:3]
        u, p, b = split(self.X)
        bbar = state.fields("bbar", space=state.spaces("theta"), dump=False)
        # Reference pressure field; it is created here but not used below.
        # (Previously this was assigned to bbar, clobbering the buoyancy
        # reference used in the linearised buoyancy transport.)
        state.fields("pbar", space=state.spaces("DG"), dump=False)

        # -------------------------------------------------------------------- #
        # Time Derivative Terms
        # -------------------------------------------------------------------- #
        mass_form = self.generate_mass_terms()

        # -------------------------------------------------------------------- #
        # Transport Terms
        # -------------------------------------------------------------------- #
        # Velocity transport term -- depends on formulation
        if u_transport_option == "vector_invariant_form":
            u_adv = prognostic(vector_invariant_form(state, w, u), "u")
        elif u_transport_option == "vector_advection_form":
            u_adv = prognostic(advection_form(state, w, u), "u")
        elif u_transport_option == "vector_manifold_advection_form":
            u_adv = prognostic(vector_manifold_advection_form(state, w, u),
                               "u")
        elif u_transport_option == "circulation_form":
            ke_form = kinetic_energy_form(state, w, u)
            ke_form = transport.remove(ke_form)
            ke_form = ke_form.label_map(
                lambda t: t.has_label(transporting_velocity), lambda t: Term(
                    ufl.replace(t.form, {t.get(transporting_velocity): u}),
                    t.labels))
            ke_form = transporting_velocity.remove(ke_form)
            u_adv = advection_equation_circulation_form(state, w, u) + ke_form
        else:
            raise ValueError("Invalid u_transport_option: %s" %
                             u_transport_option)

        # Buoyancy transport, linearised about the reference buoyancy bbar
        b_adv = prognostic(advection_form(state, gamma, b), "b")
        linear_b_adv = linear_advection_form(state, gamma, bbar).label_map(
            lambda t: t.has_label(transporting_velocity), lambda t: Term(
                ufl.replace(t.form, {
                    t.get(transporting_velocity): self.trials[0]
                }), t.labels))
        b_adv = linearisation(b_adv, linear_b_adv)

        adv_form = subject(u_adv + b_adv, self.X)

        # Add transport of tracers
        if len(active_tracers) > 0:
            adv_form += self.generate_tracer_transport_terms(
                state, active_tracers)

        # -------------------------------------------------------------------- #
        # Pressure Gradient Term
        # -------------------------------------------------------------------- #
        pressure_gradient_form = subject(prognostic(-div(w) * p * dx, "u"),
                                         self.X)

        # -------------------------------------------------------------------- #
        # Gravitational Term
        # -------------------------------------------------------------------- #
        gravity_form = subject(prognostic(-b * inner(w, state.k) * dx, "u"),
                               self.X)

        # -------------------------------------------------------------------- #
        # Divergence Term
        # -------------------------------------------------------------------- #
        # This enforces that div(u) = 0
        # The p features here so that the div(u) evaluated in the "forcing" step
        # replaces the whole pressure field, rather than merely providing an
        # increment to it.
        divergence_form = name(
            subject(prognostic(phi * (p - div(u)) * dx, "p"), self.X),
            "incompressibility")

        residual = (mass_form + adv_form + divergence_form +
                    pressure_gradient_form + gravity_form)

        # -------------------------------------------------------------------- #
        # Extra Terms (Coriolis)
        # -------------------------------------------------------------------- #
        if Omega is not None:
            residual += subject(
                prognostic(inner(w, cross(2 * Omega, u)) * dx, "u"), self.X)

        # -------------------------------------------------------------------- #
        # Linearise equations
        # -------------------------------------------------------------------- #
        # TODO: add linearisation states for variables
        # Add linearisations to equations
        self.residual = self.generate_linear_terms(residual,
                                                   self.terms_to_linearise)
Example #10
0
    def __init__(
            self,
            state,
            family,
            degree,
            Omega=None,
            sponge=None,
            extra_terms=None,
            terms_to_linearise=None,
            u_transport_option="vector_invariant_form",
            diffusion_options=None,
            no_normal_flow_bc_ids=None,
            active_tracers=None):
        """
        Set up the equations for the fields u, rho and theta, plus any tracers.

        The residual includes mass, transport, an Exner-pressure gradient
        term, gravity and, when tracers are active, a moist thermodynamic
        divergence term for theta. Optional Coriolis, sponge, diffusion and
        extra terms are added as requested. Linearisations are then attached
        via ``self.generate_linear_terms``.

        :arg state: the model state, providing parameters, fields, spaces,
            the mesh and ``k``.
        :arg family: passed to the parent constructor.
        :arg degree: passed to the parent constructor.
        :arg Omega: if not None, adds the term ``inner(w, cross(2*Omega, u))``.
        :arg sponge: if not None, an object with ``H``, ``z_level`` and
            ``mubar`` attributes, used to add a sponge term acting on the
            ``k`` component of u above ``z_level``.
        :arg extra_terms: iterable of ``(field_name, term)`` pairs; each adds
            ``inner(test, term)*dx`` to that field's equation.
        :arg terms_to_linearise: dict mapping field names to lists of labels
            of the terms to linearise. Defaults to
            ``{'u': [time_derivative], 'rho': [time_derivative, transport],
            'theta': [time_derivative, transport]}``.
        :arg u_transport_option: one of "vector_invariant_form",
            "vector_advection_form", "vector_manifold_advection_form" or
            "circulation_form".
        :arg diffusion_options: iterable of ``(field_name, diffusion)`` pairs;
            each adds an interior penalty diffusion term for that field.
        :arg no_normal_flow_bc_ids: passed to the parent constructor.
        :arg active_tracers: list of tracers; only mixing ratio tracers are
            supported. Defaults to no tracers.

        :raises ValueError: if `u_transport_option` is not recognised.
        :raises NotImplementedError: if an active tracer is not a mixing ratio.
        """
        # Build the default per call so the dict and its lists are never
        # shared (and possibly mutated) between instances
        if terms_to_linearise is None:
            terms_to_linearise = {
                'u': [time_derivative],
                'rho': [time_derivative, transport],
                'theta': [time_derivative, transport]
            }

        field_names = ['u', 'rho', 'theta']

        if active_tracers is None:
            active_tracers = []

        super().__init__(field_names,
                         state,
                         family,
                         degree,
                         terms_to_linearise=terms_to_linearise,
                         no_normal_flow_bc_ids=no_normal_flow_bc_ids,
                         active_tracers=active_tracers)

        g = state.parameters.g
        cp = state.parameters.cp

        w, phi, gamma = self.tests[0:3]
        u, rho, theta = split(self.X)[0:3]
        rhobar = state.fields("rhobar", space=state.spaces("DG"), dump=False)
        thetabar = state.fields("thetabar",
                                space=state.spaces("theta"),
                                dump=False)
        # Zero expression used as the starting value for tracer sums below
        zero_expr = Constant(0.0) * theta
        exner = exner_pressure(state.parameters, rho, theta)
        n = FacetNormal(state.mesh)

        # -------------------------------------------------------------------- #
        # Time Derivative Terms
        # -------------------------------------------------------------------- #
        mass_form = self.generate_mass_terms()

        # -------------------------------------------------------------------- #
        # Transport Terms
        # -------------------------------------------------------------------- #
        # Velocity transport term -- depends on formulation
        if u_transport_option == "vector_invariant_form":
            u_adv = prognostic(vector_invariant_form(state, w, u), "u")
        elif u_transport_option == "vector_advection_form":
            u_adv = prognostic(advection_form(state, w, u), "u")
        elif u_transport_option == "vector_manifold_advection_form":
            u_adv = prognostic(vector_manifold_advection_form(state, w, u),
                               "u")
        elif u_transport_option == "circulation_form":
            ke_form = kinetic_energy_form(state, w, u)
            ke_form = transport.remove(ke_form)
            ke_form = ke_form.label_map(
                lambda t: t.has_label(transporting_velocity), lambda t: Term(
                    ufl.replace(t.form, {t.get(transporting_velocity): u}),
                    t.labels))
            ke_form = transporting_velocity.remove(ke_form)
            u_adv = advection_equation_circulation_form(state, w, u) + ke_form
        else:
            raise ValueError("Invalid u_transport_option: %s" %
                             u_transport_option)

        # Density transport (conservative form)
        rho_adv = prognostic(continuity_form(state, phi, rho), "rho")
        # Transport term needs special linearisation
        if transport in terms_to_linearise['rho']:
            linear_rho_adv = linear_continuity_form(
                state, phi, rhobar).label_map(
                    lambda t: t.has_label(transporting_velocity),
                    lambda t: Term(
                        ufl.replace(t.form, {
                            t.get(transporting_velocity):
                            self.trials[0]
                        }), t.labels))
            rho_adv = linearisation(rho_adv, linear_rho_adv)

        # Potential temperature transport (advective form)
        theta_adv = prognostic(advection_form(state, gamma, theta), "theta")
        # Transport term needs special linearisation
        if transport in terms_to_linearise['theta']:
            linear_theta_adv = linear_advection_form(
                state, gamma, thetabar).label_map(
                    lambda t: t.has_label(transporting_velocity),
                    lambda t: Term(
                        ufl.replace(t.form, {
                            t.get(transporting_velocity):
                            self.trials[0]
                        }), t.labels))
            theta_adv = linearisation(theta_adv, linear_theta_adv)

        adv_form = subject(u_adv + rho_adv + theta_adv, self.X)

        # Add transport of tracers
        if len(active_tracers) > 0:
            adv_form += self.generate_tracer_transport_terms(
                state, active_tracers)

        # -------------------------------------------------------------------- #
        # Pressure Gradient Term
        # -------------------------------------------------------------------- #
        # First get total mass of tracers
        tracer_mr_total = zero_expr
        for tracer in active_tracers:
            if tracer.variable_type == TracerVariableType.mixing_ratio:
                idx = self.field_names.index(tracer.name)
                tracer_mr_total += split(self.X)[idx]
            else:
                raise NotImplementedError(
                    'Only mixing ratio tracers are implemented')
        theta_v = theta / (Constant(1.0) + tracer_mr_total)

        pressure_gradient_form = name(
            subject(
                prognostic(
                    cp * (-div(theta_v * w) * exner * dx +
                          jump(theta_v * w, n) * avg(exner) * dS_v), "u"),
                self.X), "pressure_gradient")

        # -------------------------------------------------------------------- #
        # Gravitational Term
        # -------------------------------------------------------------------- #
        gravity_form = subject(
            prognostic(Term(g * inner(state.k, w) * dx), "u"), self.X)

        residual = (mass_form + adv_form + pressure_gradient_form +
                    gravity_form)

        # -------------------------------------------------------------------- #
        # Moist Thermodynamic Divergence Term
        # -------------------------------------------------------------------- #
        if len(active_tracers) > 0:
            cv = state.parameters.cv
            c_vv = state.parameters.c_vv
            c_pv = state.parameters.c_pv
            c_pl = state.parameters.c_pl
            R_d = state.parameters.R_d
            R_v = state.parameters.R_v

            # Get gas and liquid moisture mixing ratios
            mr_l = zero_expr
            mr_v = zero_expr

            for tracer in active_tracers:
                if tracer.is_moisture:
                    if tracer.variable_type == TracerVariableType.mixing_ratio:
                        idx = self.field_names.index(tracer.name)
                        if tracer.phase == Phases.gas:
                            mr_v += split(self.X)[idx]
                        elif tracer.phase == Phases.liquid:
                            mr_l += split(self.X)[idx]
                    else:
                        raise NotImplementedError(
                            'Only mixing ratio tracers are implemented')

            c_vml = cv + mr_v * c_vv + mr_l * c_pl
            c_pml = cp + mr_v * c_pv + mr_l * c_pl
            R_m = R_d + mr_v * R_v

            residual += subject(
                prognostic(
                    gamma * theta * div(u) * (R_m / c_vml - (R_d * c_pml) /
                                              (cp * c_vml)) * dx, "theta"),
                self.X)

        # -------------------------------------------------------------------- #
        # Extra Terms (Coriolis, Sponge, Diffusion and others)
        # -------------------------------------------------------------------- #
        if Omega is not None:
            residual += subject(
                prognostic(inner(w, cross(2 * Omega, u)) * dx, "u"), self.X)

        if sponge is not None:
            W_DG = FunctionSpace(state.mesh, "DG", 2)
            x = SpatialCoordinate(state.mesh)
            z = x[len(x) - 1]
            H = sponge.H
            zc = sponge.z_level
            assert float(zc) < float(
                H
            ), "you have set the sponge level above the height of your domain"
            mubar = sponge.mubar
            # Zero below zc, rising as sin^2 to mubar at the top (z = H)
            muexpr = conditional(
                z <= zc, 0.0,
                mubar * sin((pi / 2.) * (z - zc) / (H - zc))**2)
            self.mu = Function(W_DG).interpolate(muexpr)

            residual += name(
                subject(
                    prognostic(
                        self.mu * inner(w, state.k) * inner(u, state.k) * dx,
                        "u"), self.X), "sponge")

        if diffusion_options is not None:
            for field, diffusion in diffusion_options:
                idx = self.field_names.index(field)
                test = self.tests[idx]
                fn = split(self.X)[idx]
                residual += subject(
                    prognostic(
                        interior_penalty_diffusion_form(
                            state, test, fn, diffusion), field), self.X)

        if extra_terms is not None:
            for field, term in extra_terms:
                idx = self.field_names.index(field)
                test = self.tests[idx]
                residual += subject(prognostic(inner(test, term) * dx, field),
                                    self.X)

        # -------------------------------------------------------------------- #
        # Linearise equations
        # -------------------------------------------------------------------- #
        # TODO: add linearisation states for variables
        # Add linearisations to equations
        self.residual = self.generate_linear_terms(residual,
                                                   self.terms_to_linearise)
Example #11
0
    def __init__(
            self,
            state,
            family,
            degree,
            fexpr=None,
            bexpr=None,
            terms_to_linearise=None,
            u_transport_option="vector_invariant_form",
            no_normal_flow_bc_ids=None,
            active_tracers=None):
        """
        Set up the shallow water equations for the fields u and D.

        The residual includes mass, transport and pressure gradient terms,
        plus a Coriolis term if `fexpr` is given and a topography term if
        `bexpr` is given. The D component of ``self.X`` is set to the
        parameter H before linearisations are attached via
        ``self.generate_linear_terms``.

        :arg state: the model state, providing parameters, fields, spaces,
            the mesh and ``perp``.
        :arg family: passed to the parent constructor.
        :arg degree: passed to the parent constructor.
        :arg fexpr: if not None, an expression interpolated into a CG1
            "coriolis" field used in the Coriolis term.
        :arg bexpr: if not None, an expression interpolated into a DG
            "topography" field used in the topography term.
        :arg terms_to_linearise: dict mapping field names to lists of labels
            of the terms to linearise. Defaults to
            ``{'D': [time_derivative, transport],
            'u': [time_derivative, pressure_gradient]}``.
        :arg u_transport_option: one of "vector_invariant_form",
            "vector_advection_form", "vector_manifold_advection_form" or
            "circulation_form".
        :arg no_normal_flow_bc_ids: passed to the parent constructor.
        :arg active_tracers: must be None; tracers are not implemented.

        :raises NotImplementedError: if `active_tracers` is not None.
        :raises ValueError: if `u_transport_option` is not recognised.
        """
        # Build the default per call so the dict and its lists are never
        # shared (and possibly mutated) between instances
        if terms_to_linearise is None:
            terms_to_linearise = {
                'D': [time_derivative, transport],
                'u': [time_derivative, pressure_gradient]
            }

        field_names = ["u", "D"]

        if active_tracers is not None:
            raise NotImplementedError(
                'Tracers not implemented for shallow water equations')

        if active_tracers is None:
            active_tracers = []

        super().__init__(field_names,
                         state,
                         family,
                         degree,
                         terms_to_linearise=terms_to_linearise,
                         no_normal_flow_bc_ids=no_normal_flow_bc_ids,
                         active_tracers=active_tracers)

        g = state.parameters.g
        H = state.parameters.H

        w, phi = self.tests
        u, D = split(self.X)

        # -------------------------------------------------------------------- #
        # Time Derivative Terms
        # -------------------------------------------------------------------- #
        mass_form = self.generate_mass_terms()

        # -------------------------------------------------------------------- #
        # Transport Terms
        # -------------------------------------------------------------------- #
        # Velocity transport term -- depends on formulation
        if u_transport_option == "vector_invariant_form":
            u_adv = prognostic(vector_invariant_form(state, w, u), "u")
        elif u_transport_option == "vector_advection_form":
            u_adv = prognostic(advection_form(state, w, u), "u")
        elif u_transport_option == "vector_manifold_advection_form":
            u_adv = prognostic(vector_manifold_advection_form(state, w, u),
                               "u")
        elif u_transport_option == "circulation_form":
            ke_form = kinetic_energy_form(state, w, u)
            ke_form = transport.remove(ke_form)
            ke_form = ke_form.label_map(
                lambda t: t.has_label(transporting_velocity), lambda t: Term(
                    ufl.replace(t.form, {t.get(transporting_velocity): u}),
                    t.labels))
            ke_form = transporting_velocity.remove(ke_form)
            u_adv = advection_equation_circulation_form(state, w, u) + ke_form
        else:
            raise ValueError("Invalid u_transport_option: %s" %
                             u_transport_option)

        # Depth transport term
        D_adv = prognostic(continuity_form(state, phi, D), "D")
        # Transport term needs special linearisation
        if transport in terms_to_linearise['D']:
            linear_D_adv = linear_continuity_form(state, phi, H).label_map(
                lambda t: t.has_label(transporting_velocity), lambda t: Term(
                    ufl.replace(t.form, {
                        t.get(transporting_velocity): self.trials[0]
                    }), t.labels))
            # Add linearisation to D_adv
            D_adv = linearisation(D_adv, linear_D_adv)

        adv_form = subject(u_adv + D_adv, self.X)

        # Add transport of tracers
        if len(active_tracers) > 0:
            adv_form += self.generate_tracer_transport_terms(
                state, active_tracers)

        # -------------------------------------------------------------------- #
        # Pressure Gradient Term
        # -------------------------------------------------------------------- #
        pressure_gradient_form = pressure_gradient(
            subject(prognostic(-g * div(w) * D * dx, "u"), self.X))

        residual = (mass_form + adv_form + pressure_gradient_form)

        # -------------------------------------------------------------------- #
        # Extra Terms (Coriolis and Topography)
        # -------------------------------------------------------------------- #
        if fexpr is not None:
            V = FunctionSpace(state.mesh, "CG", 1)
            f = state.fields("coriolis", space=V)
            f.interpolate(fexpr)
            coriolis_form = coriolis(
                subject(prognostic(f * inner(state.perp(u), w) * dx, "u"),
                        self.X))
            residual += coriolis_form

        if bexpr is not None:
            b = state.fields("topography", state.spaces("DG"))
            b.interpolate(bexpr)
            topography_form = subject(prognostic(-g * div(w) * b * dx, "u"),
                                      self.X)
            residual += topography_form

        # -------------------------------------------------------------------- #
        # Linearise equations
        # -------------------------------------------------------------------- #
        # Linearise about D = H by setting the D component of self.X
        # TODO: add linearisation state for u
        _, D_func = self.X.split()
        D_func.assign(Constant(H))

        # Add linearisations to equations
        self.residual = self.generate_linear_terms(residual,
                                                   self.terms_to_linearise)