示例#1
0
def test_playlog(casedir):
    """Check that every update_all/finalize_all round is recorded in the playlog.

    After each timestep the playlog (opened read-only) must map the string
    timestep index to a dict holding the time ``t`` of that step.
    """
    pp = PostProcessor(dict(casedir=casedir))

    # A fresh case directory must not contain a playlog yet
    assert not os.path.isfile(os.path.join(casedir, 'play.db'))
    MPI.barrier(mpi_comm_world())

    expected = {}
    for timestep, t in enumerate([0.0, 0.1]):
        pp.update_all({}, t, timestep)
        pp.finalize_all()
        expected[str(timestep)] = {"t": t}

        playlog = pp.get_playlog('r')
        try:
            assert playlog == expected
        finally:
            # Close even when the assertion fails, so the shelve is not left
            # open for subsequent tests.
            playlog.close()
示例#2
0
class Restart(Parameterized):
    """Class to fetch restart conditions through.

    Loads solutions stored by a PostProcessor in ``params.casedir`` and returns
    them (linearly interpolated in time when needed) as restart conditions.
    With ``rollback_casedir`` set, data stored from the restart timestep on is
    removed so a restarted simulation can save into the same case directory.
    """

    @classmethod
    def default_params(cls):
        """
        Default parameters are:

        +----------------------+-----------------------+-------------------------------------------------------------------+
        |Key                   | Default value         |  Description                                                      |
        +======================+=======================+===================================================================+
        | casedir              | '.'                   | Case directory - relative path to read solutions from             |
        +----------------------+-----------------------+-------------------------------------------------------------------+
        | restart_times        | -1                    | float or list of floats to find restart times from. If -1,        |
        |                      |                       | restart from last available time.                                 |
        +----------------------+-----------------------+-------------------------------------------------------------------+
        | solution_names       | 'default'             | Solution names to look for. If 'default', will fetch all          |
        |                      |                       | fields stored as SolutionField.                                   |
        +----------------------+-----------------------+-------------------------------------------------------------------+
        | rollback_casedir     | False                 | Rollback case directory by removing all items stored after        |
        |                      |                       | largest restart time. This allows for saving data from a          |
        |                      |                       | restarted simulation in the same case directory.                  |
        +----------------------+-----------------------+-------------------------------------------------------------------+

        """
        params = ParamDict(
            casedir='.',
            restart_times=-1,
            #restart_timesteps=-1,
            solution_names="default",
            rollback_casedir=False,
            #interpolate=True,
            #dt=None,
        )
        return params

    def get_restart_conditions(self, function_spaces="default"):
        """ Return restart conditions as requested.

        :param dict function_spaces: A dict of dolfin.FunctionSpace on which to
            return the restart conditions with solution name as key. If
            "default", the space of each stored function field (taken from
            the first loadable item of that field) is used.
        :return: dict keyed by restart time; each value is a dict mapping
            solution name to the restart value (a Function, or the loaded
            non-function data).

        Side effects: sets ``self.restart_timestep`` to the first playlog
        timestep whose time reaches the largest restart time and, if
        ``params.rollback_casedir`` is set, removes stored data from that
        timestep on.
        """
        self._pp = PostProcessor(
            dict(casedir=self.params.casedir, clean_casedir=False))

        playlog = self._pp.get_playlog('r')
        assert playlog != {}, "Playlog is empty! Unable to find restart data."

        loadable_solutions = find_solution_presence(self._pp, playlog,
                                                    self.params.solution_names)
        loadables = find_restart_items(self.params.restart_times,
                                       loadable_solutions)

        # The entries of the first restart time are used as representative of
        # all restart times. next(iter(...)) works on Python 2 and 3, whereas
        # dict.values()[0] fails on Python 3.
        first_loadables = next(iter(loadables.values()))

        if function_spaces != "default":
            assert isinstance(
                function_spaces,
                dict), "Expecting function_spaces kwarg to be a dict"
            assert set(first_loadables.keys()) == set(
                function_spaces.keys(
                )), "Expecting a function space for each solution variable"

        def restart_conditions(spaces, loadables):
            # loadables[restart_time][solution_name] is a list of (time, loader):
            #   [(t0, Lt0)]            -> load Lt0 directly
            #   [(t0, Lt0), (t1, Lt1)] -> interpolate linearly to restart_time
            functions = {}
            for t in loadables:
                functions[t] = dict()
                for solution_name in loadables[t]:
                    items = loadables[t][solution_name]
                    assert len(items) in [1, 2]

                    if len(items) == 1:
                        f = items[0][1]()
                    else:
                        t0, Lt0 = items[0]
                        t1, Lt1 = items[1]

                        assert t0 <= t <= t1
                        theta = (t - t0) / (t1 - t0)
                        if Lt0.function is not None:
                            # The copy-function raises a PETSc-error in
                            # parallel, so copy f0 into a fresh Function.
                            f0 = Lt0()
                            f = Function(f0.function_space())
                            f.vector().axpy(1.0, f0.vector())
                            del f0

                            # f += theta * (f1 - f0)
                            df = Lt1().vector()
                            df.axpy(-1.0, f.vector())
                            f.vector().axpy(theta, df)
                        else:
                            # Non-function data: interpolate element-wise,
                            # wrapping non-iterables as one-element lists.
                            f0 = Lt0()
                            f1 = Lt1()
                            datatype = type(f0)
                            is_scalar = not issubclass(datatype, Iterable)
                            if is_scalar:
                                f0 = [f0]
                                f1 = [f1]

                            f = [_f0 + theta * (_f1 - _f0)
                                 for _f0, _f1 in zip(f0, f1)]
                            f = f[0] if is_scalar else datatype(f)

                    if solution_name in spaces:
                        space = spaces[solution_name]
                        if space != f.function_space():
                            # Alternative for non-matching meshes:
                            # fenicstools.interpolate_nonmatching_mesh(f, space)
                            try:
                                f = interpolate(f, space)
                            except Exception:
                                # Fall back to L2 projection when
                                # interpolation fails (presumably e.g.
                                # non-matching meshes).
                                f = project(f, space)

                    functions[t][solution_name] = f

            return functions

        if function_spaces == "default":
            function_spaces = {}
            for fieldname, items in first_loadables.items():
                try:
                    function_spaces[fieldname] = \
                        items[0][1].function.function_space()
                except AttributeError:
                    # This was not a function field
                    pass

        result = restart_conditions(function_spaces, loadables)

        # First playlog timestep at (or, within tolerance, just below) the
        # largest restart time.
        restart_time = max(loadables)
        ts = 0
        while playlog[str(ts)]["t"] < restart_time - 1e-14:
            ts += 1
        self.restart_timestep = ts
        playlog.close()
        MPI.barrier(mpi_comm_world())
        if self.params.rollback_casedir:
            self._correct_postprocessing(ts)

        return result

    def _correct_postprocessing(self, restart_timestep):
        "Removes data from casedir found at timestep>=restart_timestep."
        # Every process reads the entries to remove; only master writes.
        playlog = self._pp.get_playlog('r')
        playlog_to_remove = {}
        for k, v in playlog.items():
            if int(k) >= restart_timestep:
                playlog_to_remove[k] = v
        playlog.close()

        MPI.barrier(mpi_comm_world())
        if on_master_process():
            playlog = self._pp.get_playlog()
            for k in playlog_to_remove:
                playlog.pop(k)
            playlog.close()

        MPI.barrier(mpi_comm_world())

        fields_to_clean = set()
        for v in playlog_to_remove.values():
            if "fields" in v:
                fields_to_clean.update(v["fields"].keys())

        # _clean_field performs collective (MPI) operations, so every process
        # must visit the fields in the same order; set iteration order is not
        # guaranteed to agree across processes (hash randomization).
        for fieldname in sorted(fields_to_clean):
            self._clean_field(fieldname, restart_timestep)

    def _clean_field(self, fieldname, restart_timestep):
        "Deletes data from field found at timestep>=restart_timestep."
        metadata_filename = os.path.join(self._pp.get_savedir(fieldname),
                                         'metadata.db')
        metadata = shelve.open(metadata_filename, 'r')
        metadata_to_remove = {}
        for key in metadata.keys():
            try:
                timestep = int(key)
            except ValueError:
                # Not a per-timestep entry
                continue
            if timestep >= restart_timestep:
                metadata_to_remove[key] = metadata[key]
        metadata.close()
        MPI.barrier(mpi_comm_world())
        if on_master_process():
            metadata = shelve.open(metadata_filename, 'w')
            for key in metadata_to_remove:
                metadata.pop(key)
            metadata.close()
        MPI.barrier(mpi_comm_world())

        # Remove files and data for all save formats
        self._clean_hdf5(fieldname, metadata_to_remove)
        MPI.barrier(mpi_comm_world())
        self._clean_files(fieldname, metadata_to_remove)
        MPI.barrier(mpi_comm_world())

        self._clean_txt(fieldname, metadata_to_remove)
        MPI.barrier(mpi_comm_world())

        self._clean_shelve(fieldname, metadata_to_remove)
        MPI.barrier(mpi_comm_world())

        self._clean_xdmf(fieldname, metadata_to_remove)
        MPI.barrier(mpi_comm_world())

        self._clean_pvd(fieldname, metadata_to_remove)
        MPI.barrier(mpi_comm_world())

    def _clean_hdf5(self, fieldname, del_metadata):
        """Delete the datasets of removed timesteps from ``<fieldname>.hdf5``.

        No-op when the file does not exist. A repack with h5repack is
        currently disabled, so a warning is issued instead.
        """
        hdf5filename = os.path.join(self._pp.get_savedir(fieldname),
                                    fieldname + '.hdf5')

        if not os.path.isfile(hdf5filename):
            return

        # Deletes are collective: issue them in timestep order so all
        # processes agree on the sequence.
        ordered = sorted(del_metadata.items(), key=lambda kv: int(kv[0]))
        datasets = [v['hdf5']['dataset'] for _, v in ordered if 'hdf5' in v]

        if datasets:
            # JIT-compile the helper only when there is something to delete.
            delete_from_hdf5_file = '''
        namespace dolfin {
            #include <hdf5.h>
            void delete_from_hdf5_file(const MPI_Comm comm,
                                       const std::string hdf5_filename,
                                       const std::string dataset,
                                       const bool use_mpiio)
            {
                // Open existing file for append
                hid_t hdf5_file_id = HDF5Interface::open_file(comm, hdf5_filename, "a", use_mpiio);

                H5Ldelete(hdf5_file_id, dataset.c_str(), H5P_DEFAULT);
                HDF5Interface::close_file(hdf5_file_id);
            }
        }
        '''
            cpp_module = compile_extension_module(
                delete_from_hdf5_file,
                additional_system_headers=["dolfin/io/HDF5Interface.h"])

            use_mpiio = MPI.size(mpi_comm_world()) > 1
            for dataset in datasets:
                cpp_module.delete_from_hdf5_file(
                    mpi_comm_world(), hdf5filename, dataset, use_mpiio)

        hdf5tmpfilename = os.path.join(self._pp.get_savedir(fieldname),
                                       fieldname + '_tmp.hdf5')
        MPI.barrier(mpi_comm_world())
        if on_master_process():
            # Probing for h5repack (getstatusoutput("h5repack -V")) is
            # disabled, so the status is fixed to failure and the repack
            # branch below is currently never taken.
            h5repack_status = -1
            if h5repack_status != 0:
                cbc_warning(
                    "Unable to run h5repack. Will not repack hdf5-files before replay, which may cause bloated hdf5-files."
                )
            else:
                # argv list (no shell) so paths are never shell-interpreted
                subprocess.call(["h5repack", hdf5filename, hdf5tmpfilename])
                os.remove(hdf5filename)
                os.rename(hdf5tmpfilename, hdf5filename)
        MPI.barrier(mpi_comm_world())

    def _clean_files(self, fieldname, del_metadata):
        """Remove the per-timestep files referenced by removed metadata.

        Every value of a removed metadata entry that is a mapping with a
        ``"filename"`` key is treated as a file in the field's save directory.
        """
        savedir = self._pp.get_savedir(fieldname)
        for v in del_metadata.values():
            for entry in v.values():
                MPI.barrier(mpi_comm_world())
                try:
                    filename = entry["filename"]
                except (KeyError, TypeError):
                    # Not a file-describing entry (e.g. a plain value)
                    continue

                if on_master_process():
                    os.remove(os.path.join(savedir, filename))
                MPI.barrier(mpi_comm_world())

    def _clean_txt(self, fieldname, del_metadata):
        """Strip one trailing line of ``<fieldname>.txt`` per removed txt entry."""
        txtfilename = os.path.join(self._pp.get_savedir(fieldname),
                                   fieldname + ".txt")
        if not (on_master_process() and os.path.isfile(txtfilename)):
            return

        num_lines_to_strip = sum(1 for v in del_metadata.values() if 'txt' in v)
        if num_lines_to_strip == 0:
            # Nothing to strip. (Slicing with [:-0] would empty the file.)
            return

        with open(txtfilename, 'r') as txtfile:
            txtfilelines = txtfile.readlines()

        num_kept = max(len(txtfilelines) - num_lines_to_strip, 0)
        with open(txtfilename, 'w') as txtfile:
            txtfile.writelines(txtfilelines[:num_kept])

    def _clean_shelve(self, fieldname, del_metadata):
        """Remove removed timesteps from the field's ``<fieldname>.db`` shelve."""
        shelvefilename = os.path.join(self._pp.get_savedir(fieldname),
                                      fieldname + ".db")
        if on_master_process() and os.path.isfile(shelvefilename):
            shelvefile = shelve.open(shelvefilename, 'c')
            try:
                for k, v in del_metadata.items():
                    if 'shelve' in v:
                        shelvefile.pop(str(k))
            finally:
                shelvefile.close()
        MPI.barrier(mpi_comm_world())

    def _clean_xdmf(self, fieldname, del_metadata):
        """Archive an existing xdmf/h5 pair under the first free ``_RS<i>`` suffix.

        The pair is renamed rather than trimmed (``del_metadata`` is unused),
        and the renamed xdmf file is rewritten to point to the renamed h5 file.
        """
        basename = os.path.join(self._pp.get_savedir(fieldname), fieldname)
        if os.path.isfile(basename + ".xdmf"):
            MPI.barrier(mpi_comm_world())

            i = 0
            while os.path.isfile(basename + "_RS" + str(i) + ".h5"):
                i += 1
            h5_filename = basename + "_RS" + str(i) + ".h5"
            xdmf_filename = basename + "_RS" + str(i) + ".xdmf"
            MPI.barrier(mpi_comm_world())

            if on_master_process():
                os.rename(basename + ".h5", h5_filename)
                os.rename(basename + ".xdmf", xdmf_filename)

                with open(xdmf_filename, 'r') as xdmf_file:
                    content = xdmf_file.read()

                with open(xdmf_filename, 'w') as xdmf_file:
                    xdmf_file.write(
                        content.replace(
                            os.path.split(basename)[1] + ".h5",
                            os.path.split(h5_filename)[1]))
        MPI.barrier(mpi_comm_world())

    def _clean_pvd(self, fieldname, del_metadata):
        """Warn that pvd output cannot be cleaned and will be overwritten."""
        if os.path.isfile(
                os.path.join(self._pp.get_savedir(fieldname),
                             fieldname + '.pvd')):
            cbc_warning(
                "No functionality for cleaning pvd-files for restart. Will overwrite."
            )