예제 #1
0
class State(object):
    """
    Build a model state to keep the variables in, and specify parameters.

    :arg mesh: The :class:`Mesh` to use.
    :arg vertical_degree: integer, required for vertically extruded meshes.
    Specifies the degree for the pressure space in the vertical
    (the degrees for other spaces are inferred). Defaults to None.
    :arg horizontal_degree: integer, the degree for spaces in the horizontal
    (specifies the degree for the pressure space, other spaces are inferred)
    defaults to 1.
    :arg family: string, specifies the velocity space family to use.
    Options:
    "RT": The Raviart-Thomas family (default, recommended for quads)
    "BDM": The BDM family
    "BDFM": The BDFM family
    :arg Coriolis: (optional) Coriolis function.
    :arg sponge_function: (optional) Function specifying a sponge layer.
    :arg hydrostatic: (optional) if truthy, ``h_project`` removes the
    component of its argument along ``k``; otherwise ``h_project`` is the
    identity.
    :arg timestepping: (optional) class containing timestepping parameters;
    its attributes are logged when provided.
    :arg output: class containing output parameters (required).
    :arg parameters: class containing physical parameters
    :arg diagnostics: class containing diagnostic methods
    :arg fieldlist: list of prognostic field names (required).
    :arg diagnostic_fields: list of diagnostic field classes
    :arg u_bc_ids: a list containing the ids of boundaries with no normal
                   component of velocity. These ids are passed to `DirichletBC`s. For
                   extruded meshes, top and bottom are added automatically.

    :raises RuntimeError: if ``output`` or ``fieldlist`` is not provided.
    """
    def __init__(self,
                 mesh,
                 vertical_degree=None,
                 horizontal_degree=1,
                 family="RT",
                 Coriolis=None,
                 sponge_function=None,
                 hydrostatic=None,
                 timestepping=None,
                 output=None,
                 parameters=None,
                 diagnostics=None,
                 fieldlist=None,
                 diagnostic_fields=None,
                 u_bc_ids=None):

        self.family = family
        self.vertical_degree = vertical_degree
        self.horizontal_degree = horizontal_degree
        self.Omega = Coriolis
        self.mu = sponge_function
        self.hydrostatic = hydrostatic
        self.timestepping = timestepping
        if output is None:
            raise RuntimeError(
                "You must provide a directory name for dumping results")
        self.output = output
        self.parameters = parameters
        if fieldlist is None:
            raise RuntimeError(
                "You must provide a fieldlist containing the names of the prognostic fields"
            )
        self.fieldlist = fieldlist
        if diagnostics is not None:
            self.diagnostics = diagnostics
        else:
            self.diagnostics = Diagnostics(*fieldlist)
        # Fresh lists per instance (never a shared mutable default).
        self.diagnostic_fields = (diagnostic_fields
                                  if diagnostic_fields is not None else [])
        self.u_bc_ids = u_bc_ids if u_bc_ids is not None else []

        # The mesh
        self.mesh = mesh

        # Build the spaces
        self._build_spaces(mesh, vertical_degree, horizontal_degree, family)

        # Allocate state
        self._allocate_state()
        if self.output.dumplist is None:
            self.output.dumplist = fieldlist
        self.fields = FieldCreator(fieldlist, self.xn, self.output.dumplist)

        # set up bcs: no normal velocity on the listed boundaries (and on
        # top/bottom of extruded meshes)
        V = self.fields('u').function_space()
        self.bcs = []
        if V.extruded:
            self.bcs.append(DirichletBC(V, 0.0, "bottom"))
            self.bcs.append(DirichletBC(V, 0.0, "top"))
        for bc_id in self.u_bc_ids:
            self.bcs.append(DirichletBC(V, 0.0, bc_id))

        # Output/checkpoint bookkeeping. These are filled in by setup_dump
        # and/or pickup_from_checkpoint, which may run in either order.
        self.dumpfile = None
        self.dumpdir = None
        self.dumpcount = None
        self.to_pickup = None
        self.chkpt = None

        # figure out if we're on a sphere
        try:
            self.on_sphere = (mesh._base_mesh.geometric_dimension() == 3
                              and mesh._base_mesh.topological_dimension() == 2)
        except AttributeError:
            self.on_sphere = (mesh.geometric_dimension() == 3
                              and mesh.topological_dimension() == 2)

        #  build the vertical normal and define perp for 2d geometries
        dim = mesh.topological_dimension()
        if self.on_sphere:
            x = SpatialCoordinate(mesh)
            R = sqrt(inner(x, x))
            self.k = interpolate(x / R, mesh.coordinates.function_space())
            if dim == 2:
                outward_normals = CellNormal(mesh)
                self.perp = lambda u: cross(outward_normals, u)
        else:
            kvec = [0.0] * dim
            kvec[dim - 1] = 1.0
            self.k = Constant(kvec)
            if dim == 2:
                self.perp = lambda u: as_vector([-u[1], u[0]])

        # project test function for hydrostatic case
        if self.hydrostatic:
            self.h_project = lambda u: u - self.k * inner(u, self.k)
        else:
            self.h_project = lambda u: u

        #  Constant to hold current time
        self.t = Constant(0.0)

        # setup logger
        logger.setLevel(output.log_level)
        set_log_handler(mesh.comm)
        # timestepping defaults to None; vars(None) would raise TypeError.
        if timestepping is not None:
            logger.info("Timestepping parameters that take non-default values:")
            logger.info(", ".join("%s: %s" % item
                                  for item in vars(timestepping).items()))
        if parameters is not None:
            logger.info("Physical parameters that take non-default values:")
            logger.info(", ".join("%s: %s" % item
                                  for item in vars(parameters).items()))

    def setup_diagnostics(self):
        """
        Add special case diagnostic fields (perturbations and steady state
        errors requested in the output parameters), order all diagnostic
        fields by their dependencies, set them up and register them.
        """
        for name in self.output.perturbation_fields:
            f = Perturbation(name)
            self.diagnostic_fields.append(f)

        for name in self.output.steady_state_error_fields:
            f = SteadyStateError(self, name)
            self.diagnostic_fields.append(f)

        # Only dependencies that are not already state fields constrain the
        # order in which diagnostics are computed.
        fields = {f.name() for f in self.fields}
        field_deps = [(d, sorted(set(d.required_fields).difference(fields), ))
                      for d in self.diagnostic_fields]
        schedule = topo_sort(field_deps)
        self.diagnostic_fields = schedule
        for diagnostic in self.diagnostic_fields:
            diagnostic.setup(self)
            self.diagnostics.register(diagnostic.name)

    def setup_dump(self, t, tmax, pickup=False):
        """
        Setup dump files
        Check for existence of directory so as not to overwrite
        output files
        Setup checkpoint file

        :arg t: the current model time, used for the initial dump.
        :arg tmax: model stop time
        :arg pickup: recover state from the checkpointing file if true,
        otherwise dump and checkpoint to disk. (default is False).
        :raises IOError: if the results directory already exists and this is
        not a pickup run.
        """

        if any([
                self.output.dump_vtus, self.output.dumplist_latlon,
                self.output.dump_diagnostics, self.output.point_data,
                self.output.checkpoint and not pickup
        ]):
            # setup output directory and check that it does not already exist
            self.dumpdir = path.join("results", self.output.dirname)
            running_tests = '--running-tests' in sys.argv or "pytest" in self.output.dirname
            if self.mesh.comm.rank == 0 and not running_tests:
                if path.exists(self.dumpdir):
                    # A fresh run must not clobber earlier results, but a
                    # pickup run continues in the directory of the run it
                    # resumes (creating it again would raise).
                    if not pickup:
                        raise IOError("results directory '%s' already exists" %
                                      self.dumpdir)
                else:
                    makedirs(self.dumpdir)

        if self.output.dump_vtus:

            # setup pvd output file
            outfile = path.join(self.dumpdir, "field_output.pvd")
            self.dumpfile = File(outfile,
                                 project_output=self.output.project_fields,
                                 comm=self.mesh.comm)

            # make list of fields to dump
            self.to_dump = [field for field in self.fields if field.dump]

            # make dump counter
            self.dumpcount = itertools.count()

        # if there are fields to be dumped in latlon coordinates,
        # setup the latlon coordinate mesh and make output file
        if len(self.output.dumplist_latlon) > 0:
            mesh_ll = get_latlon_mesh(self.mesh)
            outfile_ll = path.join(self.dumpdir, "field_output_latlon.pvd")
            self.dumpfile_ll = File(outfile_ll,
                                    project_output=self.output.project_fields,
                                    comm=self.mesh.comm)

            # make functions on latlon mesh, as specified by dumplist_latlon
            self.to_dump_latlon = []
            for name in self.output.dumplist_latlon:
                f = self.fields(name)
                field = Function(functionspaceimpl.WithGeometry(
                    f.function_space(), mesh_ll),
                                 val=f.topological,
                                 name=name + '_ll')
                self.to_dump_latlon.append(field)

        # we create new netcdf files to write to, unless pickup=True, in
        # which case we just need the filenames
        if self.output.dump_diagnostics:
            diagnostics_filename = path.join(self.dumpdir, "diagnostics.nc")
            self.diagnostic_output = DiagnosticsOutput(diagnostics_filename,
                                                       self.diagnostics,
                                                       self.output.dirname,
                                                       self.mesh.comm,
                                                       create=not pickup)

        if len(self.output.point_data) > 0:
            pointdata_filename = path.join(self.dumpdir, "point_data.nc")
            ndt = int(tmax / self.timestepping.dt)
            self.pointdata_output = PointDataOutput(pointdata_filename,
                                                    ndt,
                                                    self.output.point_data,
                                                    self.output.dirname,
                                                    self.fields,
                                                    self.mesh.comm,
                                                    create=not pickup)

        if self.output.checkpoint:
            # Only a fresh run creates a new checkpoint file here; a pickup
            # run gets its file from pickup_from_checkpoint.
            if not pickup:
                self.chkpt = DumbCheckpoint(path.join(self.dumpdir, "chkpt"),
                                            mode=FILE_CREATE)
            # make list of fields to pickup (this doesn't include
            # diagnostic fields); needed by dump() for every checkpointing run
            self.to_pickup = [field for field in self.fields if field.pickup]
            # checkpoint counter
            self.chkptcount = itertools.count()

        # dump initial fields
        self.dump(t)

    def pickup_from_checkpoint(self):
        """
        Recover the pickup fields and the model time from the checkpoint file
        ``results/<dirname>/chkpt``, then open a fresh checkpoint file there
        for subsequent dumps.

        Works whether it is called before or after :meth:`setup_dump`.

        :returns: the model time read from the checkpoint file.
        :raises ValueError: if checkpointing is not enabled in the output
        parameters.
        """
        if not self.output.checkpoint:
            raise ValueError("Must set checkpoint True if pickup")

        # setup_dump may not have run yet (or may not have set these).
        if self.to_pickup is None:
            self.to_pickup = [field for field in self.fields if field.pickup]
        if self.dumpdir is None:
            self.dumpdir = path.join("results", self.output.dirname)

        chkfile = path.join(self.dumpdir, "chkpt")
        # Open the checkpointing file for reading
        with DumbCheckpoint(chkfile, mode=FILE_READ) as chk:
            # Recover all the fields from the checkpoint
            for field in self.to_pickup:
                chk.load(field)
            t = chk.read_attribute("/", "time")
        # The dump counter only exists once setup_dump enabled VTU output.
        if self.dumpcount is not None:
            next(self.dumpcount)
        # Setup new checkpoint
        self.chkpt = DumbCheckpoint(chkfile, mode=FILE_CREATE)

        return t

    def dump(self, t):
        """
        Compute diagnostic fields and write all requested output
        (diagnostics, point data, checkpoint, VTU files) for time ``t``.
        """
        output = self.output

        # Diagnostics:
        # Compute diagnostic fields
        for field in self.diagnostic_fields:
            field(self)

        if output.dump_diagnostics:
            # Output diagnostic data
            self.diagnostic_output.dump(self, t)

        if len(output.point_data) > 0:
            # Output pointwise data
            self.pointdata_output.dump(self.fields, t)

        # Dump all the fields to the checkpointing file (backup version)
        if output.checkpoint and (next(self.chkptcount) %
                                  output.chkptfreq) == 0:
            # During a pickup run, setup_dump may dump before
            # pickup_from_checkpoint has opened a checkpoint file.
            if self.chkpt is not None:
                for field in self.to_pickup:
                    self.chkpt.store(field)
                self.chkpt.write_attribute("/", "time", t)

        if output.dump_vtus and (next(self.dumpcount) % output.dumpfreq) == 0:
            # dump fields
            self.dumpfile.write(*self.to_dump)

            # dump fields on latlon mesh
            if len(output.dumplist_latlon) > 0:
                self.dumpfile_ll.write(*self.to_dump_latlon)

    def initialise(self, initial_conditions):
        """
        Initialise state variables

        :arg initial_conditions: An iterable of pairs (field_name, pointwise_value)
        """
        for name, ic in initial_conditions:
            f_init = getattr(self.fields, name)
            f_init.assign(ic)
            f_init.rename(name)

    def set_reference_profiles(self, reference_profiles):
        """
        Initialise reference profiles

        For each pair, a new field ``<field_name>bar`` is created in the same
        function space as the named field (with dumping disabled) and the
        profile is interpolated into it.

        :arg reference_profiles: An iterable of pairs (field_name, interpolatory_value)
        """
        for name, profile in reference_profiles:
            field = getattr(self.fields, name)
            ref = self.fields(name + 'bar', field.function_space(), False)
            ref.interpolate(profile)

    def _build_spaces(self, mesh, vertical_degree, horizontal_degree, family):
        """
        Build the function spaces through a :class:`SpaceCreator` and set
        ``self.spaces``, ``self.DG1_space`` and the mixed space ``self.W``.

        Extruded case (``vertical_degree`` given): the spaces are registered
        as "HDiv" (velocity), "DG" (pressure), "HDiv_v" (the scalar space
        sharing the vertical velocity's element) and "Vv" (stored as
        ``self.Vv``), and ``self.W = HDiv x DG x HDiv_v``.

        Non-extruded case: "HDiv" and "DG", with ``self.W = HDiv x DG``.
        """

        self.spaces = SpaceCreator()
        if vertical_degree is not None:
            # horizontal base spaces
            cell = mesh._base_mesh.ufl_cell().cellname()
            S1 = FiniteElement(family,
                               cell,
                               horizontal_degree + 1,
                               variant="equispaced")
            S2 = FiniteElement("DG",
                               cell,
                               horizontal_degree,
                               variant="equispaced")

            # vertical base spaces
            T0 = FiniteElement("CG",
                               interval,
                               vertical_degree + 1,
                               variant="equispaced")
            T1 = FiniteElement("DG",
                               interval,
                               vertical_degree,
                               variant="equispaced")

            # build spaces V2, V3, Vt
            V2h_elt = HDiv(TensorProductElement(S1, T1))
            V2t_elt = TensorProductElement(S2, T0)
            V3_elt = TensorProductElement(S2, T1)
            V2v_elt = HDiv(V2t_elt)
            V2_elt = V2h_elt + V2v_elt

            V0 = self.spaces("HDiv", mesh, V2_elt)
            V1 = self.spaces("DG", mesh, V3_elt)
            V2 = self.spaces("HDiv_v", mesh, V2t_elt)

            self.Vv = self.spaces("Vv", mesh, V2v_elt)

            DG1_hori_elt = FiniteElement("DG", cell, 1, variant="equispaced")
            DG1_vert_elt = FiniteElement("DG",
                                         interval,
                                         1,
                                         variant="equispaced")
            DG1_elt = TensorProductElement(DG1_hori_elt, DG1_vert_elt)
            self.DG1_space = self.spaces("DG1", mesh, DG1_elt)

            self.W = MixedFunctionSpace((V0, V1, V2))

        else:
            cell = mesh.ufl_cell().cellname()
            V1_elt = FiniteElement(family,
                                   cell,
                                   horizontal_degree + 1,
                                   variant="equispaced")
            DG_elt = FiniteElement("DG",
                                   cell,
                                   horizontal_degree,
                                   variant="equispaced")
            DG1_elt = FiniteElement("DG", cell, 1, variant="equispaced")

            V0 = self.spaces("HDiv", mesh, V1_elt)
            V1 = self.spaces("DG", mesh, DG_elt)
            self.DG1_space = self.spaces("DG1", mesh, DG1_elt)

            self.W = MixedFunctionSpace((V0, V1))

    def _allocate_state(self):
        """
        Construct Functions on the mixed space ``self.W`` to store the state
        variables.
        """

        W = self.W
        self.xn = Function(W)
        self.xstar = Function(W)
        self.xp = Function(W)
        self.xnp1 = Function(W)
        self.xrhs = Function(W)
        self.xb = Function(W)  # store the old state for diagnostics
        self.dy = Function(W)
예제 #2
0
class State(object):
    """
    Build a model state to keep the variables in, and specify parameters.

    :arg mesh: The :class:`Mesh` to use.
    :arg dt: The time step as a :class:`Constant`. If a float or int is passed,
             it will be cast to a :class:`Constant`.
    :arg output: class containing output parameters (required).
    :arg parameters: class containing physical parameters
    :arg diagnostics: class containing diagnostic methods
    :arg diagnostic_fields: list of diagnostic field classes

    :raises RuntimeError: if ``output`` is not provided.
    :raises TypeError: if ``dt`` is not a :class:`Constant`, float or int
             (``bool`` is rejected).
    """
    def __init__(self,
                 mesh,
                 dt,
                 output=None,
                 parameters=None,
                 diagnostics=None,
                 diagnostic_fields=None):

        if output is None:
            raise RuntimeError(
                "You must provide a directory name for dumping results")
        self.output = output
        self.parameters = parameters

        if diagnostics is not None:
            self.diagnostics = diagnostics
        else:
            self.diagnostics = Diagnostics()
        # Fresh list per instance (never a shared mutable default).
        self.diagnostic_fields = (diagnostic_fields
                                  if diagnostic_fields is not None else [])

        # The mesh
        self.mesh = mesh

        self.spaces = SpaceCreator(mesh)

        if self.output.dumplist is None:
            self.output.dumplist = []

        self.fields = StateFields(*self.output.dumplist)

        # Output/checkpoint bookkeeping. These are filled in by setup_dump
        # and/or pickup_from_checkpoint, which may run in either order.
        self.dumpdir = None
        self.dumpfile = None
        self.to_pickup = None
        self.chkpt = None

        # figure out if we're on a sphere
        try:
            self.on_sphere = (mesh._base_mesh.geometric_dimension() == 3
                              and mesh._base_mesh.topological_dimension() == 2)
        except AttributeError:
            self.on_sphere = (mesh.geometric_dimension() == 3
                              and mesh.topological_dimension() == 2)

        #  build the vertical normal and define perp for 2d geometries
        dim = mesh.topological_dimension()
        if self.on_sphere:
            x = SpatialCoordinate(mesh)
            R = sqrt(inner(x, x))
            self.k = interpolate(x / R, mesh.coordinates.function_space())
            if dim == 2:
                outward_normals = CellNormal(mesh)
                self.perp = lambda u: cross(outward_normals, u)
        else:
            kvec = [0.0] * dim
            kvec[dim - 1] = 1.0
            self.k = Constant(kvec)
            if dim == 2:
                self.perp = lambda u: as_vector([-u[1], u[0]])

        # setup logger
        logger.setLevel(output.log_level)
        set_log_handler(mesh.comm)
        if parameters is not None:
            logger.info("Physical parameters that take non-default values:")
            logger.info(", ".join("%s: %s" % (k, float(v))
                                  for (k, v) in vars(parameters).items()))

        #  Constant to hold current time
        self.t = Constant(0.0)
        # isinstance also accepts subclasses (e.g. numpy.float64); bool is a
        # subclass of int but is still rejected, as before.
        if isinstance(dt, Constant):
            self.dt = dt
        elif isinstance(dt, (float, int)) and not isinstance(dt, bool):
            self.dt = Constant(dt)
        else:
            raise TypeError(
                f'dt must be a Constant, float or int, not {type(dt)}')

    def setup_diagnostics(self):
        """
        Add special case diagnostic fields (perturbations and steady state
        errors requested in the output parameters), order all diagnostic
        fields by their dependencies, set them up and register them.
        """
        for name in self.output.perturbation_fields:
            f = Perturbation(name)
            self.diagnostic_fields.append(f)

        for name in self.output.steady_state_error_fields:
            f = SteadyStateError(self, name)
            self.diagnostic_fields.append(f)

        # Only dependencies that are not already state fields constrain the
        # order in which diagnostics are computed.
        fields = {f.name() for f in self.fields}
        field_deps = [(d, sorted(set(d.required_fields).difference(fields), ))
                      for d in self.diagnostic_fields]
        schedule = topo_sort(field_deps)
        self.diagnostic_fields = schedule
        for diagnostic in self.diagnostic_fields:
            diagnostic.setup(self)
            self.diagnostics.register(diagnostic.name)

    def setup_dump(self, t, tmax, pickup=False):
        """
        Setup dump files
        Check for existence of directory so as not to overwrite
        output files
        Setup checkpoint file

        :arg t: the current model time, used for the initial dump.
        :arg tmax: model stop time
        :arg pickup: recover state from the checkpointing file if true,
        otherwise dump and checkpoint to disk. (default is False).
        :raises IOError: if the results directory already exists and this is
        not a pickup run.
        """

        if any([
                self.output.dump_vtus, self.output.dumplist_latlon,
                self.output.dump_diagnostics, self.output.point_data,
                self.output.checkpoint and not pickup
        ]):
            # setup output directory and check that it does not already exist
            self.dumpdir = path.join("results", self.output.dirname)
            running_tests = '--running-tests' in sys.argv or "pytest" in self.output.dirname
            if self.mesh.comm.rank == 0 and not running_tests:
                if path.exists(self.dumpdir):
                    # A fresh run must not clobber earlier results, but a
                    # pickup run continues in the directory of the run it
                    # resumes (creating it again would raise).
                    if not pickup:
                        raise IOError("results directory '%s' already exists" %
                                      self.dumpdir)
                else:
                    makedirs(self.dumpdir)

        if self.output.dump_vtus:

            # setup pvd output file
            outfile = path.join(self.dumpdir, "field_output.pvd")
            self.dumpfile = File(outfile,
                                 project_output=self.output.project_fields,
                                 comm=self.mesh.comm)

            # make list of fields to dump
            self.to_dump = [
                f for f in self.fields if f.name() in self.fields.to_dump
            ]

            # make dump counter
            self.dumpcount = itertools.count()

        # if there are fields to be dumped in latlon coordinates,
        # setup the latlon coordinate mesh and make output file
        if len(self.output.dumplist_latlon) > 0:
            mesh_ll = get_latlon_mesh(self.mesh)
            outfile_ll = path.join(self.dumpdir, "field_output_latlon.pvd")
            self.dumpfile_ll = File(outfile_ll,
                                    project_output=self.output.project_fields,
                                    comm=self.mesh.comm)

            # make functions on latlon mesh, as specified by dumplist_latlon
            self.to_dump_latlon = []
            for name in self.output.dumplist_latlon:
                f = self.fields(name)
                field = Function(functionspaceimpl.WithGeometry.create(
                    f.function_space(), mesh_ll),
                                 val=f.topological,
                                 name=name + '_ll')
                self.to_dump_latlon.append(field)

        # we create new netcdf files to write to, unless pickup=True, in
        # which case we just need the filenames
        if self.output.dump_diagnostics:
            diagnostics_filename = path.join(self.dumpdir, "diagnostics.nc")
            self.diagnostic_output = DiagnosticsOutput(diagnostics_filename,
                                                       self.diagnostics,
                                                       self.output.dirname,
                                                       self.mesh.comm,
                                                       create=not pickup)

        if len(self.output.point_data) > 0:
            # set up point data output
            pointdata_filename = path.join(self.dumpdir, "point_data.nc")
            ndt = int(tmax / float(self.dt))
            self.pointdata_output = PointDataOutput(pointdata_filename,
                                                    ndt,
                                                    self.output.point_data,
                                                    self.output.dirname,
                                                    self.fields,
                                                    self.mesh.comm,
                                                    self.output.tolerance,
                                                    create=not pickup)

            # make point data dump counter
            self.pddumpcount = itertools.count()

            # set frequency of point data output - defaults to
            # dumpfreq if not set by user
            if self.output.pddumpfreq is None:
                self.output.pddumpfreq = self.output.dumpfreq

        if self.output.checkpoint:
            # Only a fresh run creates a new checkpoint file here; a pickup
            # run gets its file from pickup_from_checkpoint.
            if not pickup:
                self.chkpt = DumbCheckpoint(path.join(self.dumpdir, "chkpt"),
                                            mode=FILE_CREATE)
            # make list of fields to pickup (this doesn't include
            # diagnostic fields)
            self.to_pickup = [
                f for f in self.fields if f.name() in self.fields.to_pickup
            ]
            # checkpoint counter
            self.chkptcount = itertools.count()

        # dump initial fields
        self.dump(t)

    def pickup_from_checkpoint(self):
        """
        Recover the pickup fields and the model time from the checkpoint file
        (``output.checkpoint_pickup_filename`` if set, otherwise
        ``results/<dirname>/chkpt``), then open a fresh checkpoint file in the
        results directory for subsequent dumps.

        Works whether it is called before or after :meth:`setup_dump`.

        :returns: the model time read from the checkpoint file.
        :raises ValueError: if checkpointing is not enabled in the output
        parameters.
        """
        if not self.output.checkpoint:
            raise ValueError("Must set checkpoint True if pickup")

        # TODO: this duplicates some code from setup_dump. Can this be avoided?
        # It is because we don't know if we are picking up or setting dump first
        if self.to_pickup is None:
            self.to_pickup = [
                f for f in self.fields if f.name() in self.fields.to_pickup
            ]
        # Set dumpdir if has not been done already
        if self.dumpdir is None:
            self.dumpdir = path.join("results", self.output.dirname)

        # Open the checkpointing file for reading
        if self.output.checkpoint_pickup_filename is not None:
            chkfile = self.output.checkpoint_pickup_filename
        else:
            chkfile = path.join(self.dumpdir, "chkpt")
        with DumbCheckpoint(chkfile, mode=FILE_READ) as chk:
            # Recover all the fields from the checkpoint
            for field in self.to_pickup:
                chk.load(field)
            t = chk.read_attribute("/", "time")
        # Setup new checkpoint
        self.chkpt = DumbCheckpoint(path.join(self.dumpdir, "chkpt"),
                                    mode=FILE_CREATE)

        return t

    def dump(self, t):
        """
        Compute diagnostic fields and write all requested output
        (diagnostics, point data, checkpoint, VTU files) for time ``t``.
        """
        output = self.output

        # Diagnostics:
        # Compute diagnostic fields
        for field in self.diagnostic_fields:
            field(self)

        if output.dump_diagnostics:
            # Output diagnostic data
            self.diagnostic_output.dump(self, t)

        if len(output.point_data) > 0 and (next(self.pddumpcount) %
                                           output.pddumpfreq) == 0:
            # Output pointwise data
            self.pointdata_output.dump(self.fields, t)

        # Dump all the fields to the checkpointing file (backup version)
        if output.checkpoint and (next(self.chkptcount) %
                                  output.chkptfreq) == 0:
            # If setup_dump ran for a pickup before pickup_from_checkpoint,
            # no checkpoint file is open yet; skip writing rather than fail.
            if self.chkpt is not None:
                for field in self.to_pickup:
                    self.chkpt.store(field)
                self.chkpt.write_attribute("/", "time", t)

        if output.dump_vtus and (next(self.dumpcount) % output.dumpfreq) == 0:
            # dump fields
            self.dumpfile.write(*self.to_dump)

            # dump fields on latlon mesh
            if len(output.dumplist_latlon) > 0:
                self.dumpfile_ll.write(*self.to_dump_latlon)

    def initialise(self, initial_conditions):
        """
        Initialise state variables

        :arg initial_conditions: An iterable of pairs (field_name, pointwise_value)
        """
        for name, ic in initial_conditions:
            f_init = getattr(self.fields, name)
            f_init.assign(ic)
            f_init.rename(name)

    def set_reference_profiles(self, reference_profiles):
        """
        Initialise reference profiles

        If ``<field_name>bar`` already exists in the state, the profile is
        interpolated into it (so it may be an expression). Otherwise the
        profile must be a :class:`Function`: a new non-dumped field is created
        in its function space and the profile is interpolated into it.

        :arg reference_profiles: An iterable of pairs (field_name, interpolatory_value)
        :raises ValueError: if a new reference field is needed and the profile
        is not a :class:`Function`.
        """
        for name, profile in reference_profiles:
            if name + 'bar' in self.fields:
                # For reference profiles already added to state, allow
                # interpolation from expressions
                ref = self.fields(name + 'bar')
            elif isinstance(profile, Function):
                # Need to add reference profile to state so profile must be
                # a Function
                ref = self.fields(name + 'bar',
                                  space=profile.function_space(),
                                  dump=False)
            else:
                raise ValueError(
                    f'When initialising reference profile {name}' +
                    ' the passed profile must be a Function')
            ref.interpolate(profile)
예제 #3
0
class State(object):
    """
    Build a model state to keep the variables in, and specify parameters.

    :arg mesh: The :class:`Mesh` to use.
    :arg vertical_degree: integer, the degree for spaces in the vertical
    (specifies the degree for the pressure space, other spaces are inferred)
    defaults to 1.
    :arg horizontal_degree: integer, the degree for spaces in the horizontal
    (specifies the degree for the pressure space, other spaces are inferred)
    defaults to 1.
    :arg family: string, specifies the velocity space family to use.
    Options:
    "RT": The Raviart-Thomas family (default, recommended for quads)
    "BDM": The BDM family
    "BDFM": The BDFM family
    :arg timestepping: class containing timestepping parameters
    :arg output: class containing output parameters
    :arg parameters: class containing physical parameters
    :arg diagnostics: class containing diagnostic methods
    :arg fieldlist: list of prognostic field names

    """
    __metaclass__ = ABCMeta

    def __init__(self, mesh, vertical_degree=1, horizontal_degree=1,
                 family="RT", z=None, k=None, Omega=None, mu=None,
                 timestepping=None,
                 output=None,
                 parameters=None,
                 diagnostics=None,
                 fieldlist=None,
                 diagnostic_fields=[]):

        self.z = z
        self.k = k
        self.Omega = Omega
        self.mu = mu
        self.timestepping = timestepping
        self.output = output
        self.parameters = parameters
        if fieldlist is None:
            raise RuntimeError("You must provide a fieldlist containing the names of the prognostic fields")
        else:
            self.fieldlist = fieldlist
        if diagnostics is not None:
            self.diagnostics = diagnostics
        else:
            self.diagnostics = Diagnostics(*fieldlist)
        self.diagnostic_fields = diagnostic_fields

        # The mesh
        self.mesh = mesh

        # Build the spaces
        self._build_spaces(mesh, vertical_degree,
                           horizontal_degree, family)

        # Allocate state
        self._allocate_state()
        self.field_dict = {name: func for (name, func) in
                           zip(self.fieldlist, self.xn.split())}

        self.dumpfile = None

    def dump(self, t=0, pickup=False):
        """
        Dump output, or recover the state from a checkpoint.

        On every call the list of output fields is rebuilt: the prognostic
        fields named in ``output.dumplist``, the diagnostic fields, any
        steady-state error fields, any perturbation/mean fields, and the
        lat-lon copies requested in ``output.dumplist_latlon``.

        :arg t: the current model time (default is zero).
        :arg pickup: if True, recover the fields from the ``chkpt``
            checkpointing file; otherwise write the output fields (every
            ``output.dumpfreq`` calls) and checkpoint them to disk.
            (default is False).
        :returns: the model time. When ``pickup`` is True this is the time
            stored in the checkpoint; otherwise it is ``t`` unchanged.
        """

        # default behaviour is to dump all prognostic fields
        if self.output.dumplist is None:
            self.output.dumplist = self.fieldlist

        # if there are fields to be dumped in latlon coordinates,
        # setup the latlon coordinate mesh
        if len(self.output.dumplist_latlon) > 0:
            mesh_ll = get_latlon_mesh(self.mesh)

        funcs = self.xn.split()
        field_dict = {name: func for (name, func) in zip(self.fieldlist, funcs)}
        to_dump = []  # fields to output to dump and checkpoint
        to_pickup = []  # fields to pick up from checkpoint
        for name, f in field_dict.items():
            if name in self.output.dumplist:
                to_dump.append(f)
                to_pickup.append(f)
            f.rename(name=name)

        # append diagnostic fields for to_dump
        for diagnostic in self.diagnostic_fields:
            to_dump.append(diagnostic(self))

        # check if we are running a steady state simulation and if so
        # set up the error fields and save the
        # initial fields so that we can compute the error fields
        steady_state_dump_err = defaultdict(bool)
        steady_state_dump_err.update(self.output.steady_state_dump_err)
        for name, f, f_init in zip(self.fieldlist, funcs, self.x_init.split()):
            if steady_state_dump_err[name]:
                err = Function(f.function_space(), name=name+'err').assign(f-f_init)
                field_dict[name+"err"] = err
                self.diagnostics.register(name+"err")
                to_dump.append(err)
                f_init.rename(f.name()+"_init")
                to_dump.append(f_init)
                to_pickup.append(f_init)

        # check if we are dumping perturbation fields. If we are, the
        # meanfields are provided in a dictionary. Here we set up the
        # perturbation fields.
        meanfields = defaultdict(lambda: None)
        meanfields.update(self.output.meanfields)
        for name, meanfield in meanfields.items():
            if meanfield is None:
                # no mean supplied: nothing to subtract, rename or dump
                continue
            field = field_dict[name]
            diff = Function(
                field.function_space(),
                name=field.name()+"_perturbation").assign(field - meanfield)
            self.diagnostics.register(name+"perturbation")
            field_dict[name+"perturbation"] = diff
            to_dump.append(diff)
            # The mean field itself is also dumped and checkpointed. This
            # must stay inside the not-None branch: ``field`` is only bound
            # here, and ``None`` has no ``rename``.
            meanfield.rename(name=field.name() + "_bar")
            to_dump.append(meanfield)
            to_pickup.append(meanfield)

        # make functions on latlon mesh, as specified by dumplist_latlon;
        # each shares its data (val=f.topological) with the original field
        to_dump_latlon = []
        for name in self.output.dumplist_latlon:
            f = field_dict[name]
            f_ll = Function(
                functionspaceimpl.WithGeometry(f.function_space(), mesh_ll),
                val=f.topological, name=name+'_ll')
            to_dump_latlon.append(f_ll)

        self.dumpdir = path.join("results", self.output.dirname)
        outfile = path.join(self.dumpdir, "field_output.pvd")
        if self.dumpfile is None:
            # first call: refuse to overwrite an existing results directory
            # unless we are picking up from a checkpoint
            if self.mesh.comm.rank == 0 and path.exists(self.dumpdir) and not pickup:
                exit("results directory '%s' already exists" % self.dumpdir)
            self.dumpcount = itertools.count()
            self.dumpfile = File(outfile, project_output=self.output.project_fields, comm=self.mesh.comm)
            self.diagnostic_data = defaultdict(partial(defaultdict, list))

            # make output file for fields on latlon mesh if required
            if len(self.output.dumplist_latlon) > 0:
                outfile_latlon = path.join(self.dumpdir, "field_output_latlon.pvd")
                self.dumpfile_latlon = File(outfile_latlon, project_output=self.output.project_fields,
                                            comm=self.mesh.comm)

        if pickup:
            # Open the checkpointing file for reading
            chkfile = path.join(self.dumpdir, "chkpt")
            with DumbCheckpoint(chkfile, mode=FILE_READ) as chk:
                # Recover all the fields from the checkpoint
                for field in to_pickup:
                    chk.load(field)
                t = chk.read_attribute("/", "time")
                # advance the dump counter once for the recovered state
                next(self.dumpcount)

        elif (next(self.dumpcount) % self.output.dumpfreq) == 0:

            # dump fields
            self.dumpfile.write(*to_dump)

            # dump fields on latlon mesh
            if len(self.output.dumplist_latlon) > 0:
                self.dumpfile_latlon.write(*to_dump_latlon)

            # compute diagnostics
            for name in self.diagnostics.fields:
                data = self.diagnostics.l2(field_dict[name])
                self.diagnostic_data[name]["l2"].append(data)

            # Write the backup checkpoint first, then the main one
            for fname in ("chkptbk", "chkpt"):
                chkfile = path.join(self.dumpdir, fname)
                with DumbCheckpoint(chkfile, mode=FILE_CREATE) as chk:
                    # Dump all the fields to a checkpoint
                    for field in to_dump:
                        chk.store(field)
                    chk.write_attribute("/", "time", t)

        return t

    def diagnostic_dump(self):
        """
        Write the collected diagnostics to disk as JSON.

        The contents of ``self.diagnostic_data`` are serialised with an
        indent of 4 into ``diagnostics.json`` inside ``self.dumpdir``. Any
        existing file there is overwritten.
        """
        target = path.join(self.dumpdir, "diagnostics.json")
        with open(target, "w") as handle:
            handle.write(json.dumps(self.diagnostic_data, indent=4))

    def initialise(self, initial_conditions):
        """
        Assign initial conditions to the components of ``self.x_init``.

        :arg initial_conditions: an iterable of values, matched in order
            against the sub-functions returned by ``self.x_init.split()``.
            Pairing stops at the shorter of the two sequences.
        """
        components = self.x_init.split()
        for component, value in zip(components, initial_conditions):
            component.assign(value)

    @abstractmethod
    def _build_spaces(self, mesh, vertical_degree, horizontal_degree, family):

        """
        Build function spaces:
        """
        pass

    def _allocate_state(self):
        """
        Construct Functions to store the state variables.

        Every buffer is a ``Function`` on ``self.W``. ``_build_spaces`` is
        expected to have set ``self.W``, because ``__init__`` calls it just
        before this method.
        """

        W = self.W
        # current state; __init__ and dump() split it into the named
        # prognostic fields listed in ``self.fieldlist``
        self.xn = Function(W)
        # initial state; initialise() assigns the initial conditions here
        # and dump() compares against it for steady-state error fields
        self.x_init = Function(W)
        # NOTE(review): the roles of xstar, xp, xnp1, xrhs and dy are defined
        # by code outside this view (presumably the timestepper) -- confirm
        self.xstar = Function(W)
        self.xp = Function(W)
        self.xnp1 = Function(W)
        self.xrhs = Function(W)
        self.dy = Function(W)