Code example #1
File: test_saver.py Project: jakobes/cbcpost-fork
def test_get_casedir(casedir):
    pp = PostProcessor(dict(casedir=casedir))

    assert os.path.isdir(pp.get_casedir())
    assert os.path.samefile(pp.get_casedir(), casedir)

    pp.update_all({}, 0.0, 0)

    assert len(os.listdir(pp.get_casedir())) == 1
    pp.clean_casedir()
    assert len(os.listdir(pp.get_casedir())) == 0
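
A note on context: the snippets on this page omit their module-level imports. A minimal preamble covering most of them might look as follows (this is an assumption based on the names used below, not a verbatim excerpt; the casedir and mesh arguments are pytest fixtures supplying a temporary case directory and a dolfin mesh):

import os
import pickle
import shelve

from dolfin import MPI, mpi_comm_world, Function, HDF5File, norm
from cbcpost import (PostProcessor, SpacePool, SolutionField, Field,
                     ParamDict, Norm, TimeIntegral, Replay, Restart)
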
Code example #2
File: test_plotter.py Project: jakobes/cbcpost-fork
def test_dolfinplot(mesh):
    pp = PostProcessor()
    spacepool = SpacePool(mesh)
    Q = spacepool.get_space(1, 0)
    V = spacepool.get_space(1, 1)

    pp.add_field(MockFunctionField(Q, dict(plot=True)))
    pp.add_field(MockVectorFunctionField(V, dict(plot=True)))
    pp.update_all({}, 0.0, 0)
    pp.update_all({}, 0.1, 1)
    pp.update_all({}, 0.6, 2)
    pp.update_all({}, 1.6, 3)
Code example #3
File: test_plotter.py Project: jakobes/cbcpost-fork
def test_pyplot():
    pp = PostProcessor()

    pp.add_field(MockScalarField(dict(plot=True)))
    pp.update_all({}, 0.0, 0)
    pp.update_all({}, 0.1, 1)
    pp.update_all({}, 0.6, 2)
    pp.update_all({}, 1.6, 3)
Code example #4
File: Poiseuille3D.py Project: fmoabreu/fluxo
def main():
    problem = Poiseuille3D({"refinement_level": 0})
    scheme = IPCS({"u_degree": 1})

    casedir = "results_demo_%s_%s" % (problem.shortname(), scheme.shortname())
    plot_and_save = dict(plot=True, save=True)
    fields = [
        Pressure(plot_and_save),
        Velocity(plot_and_save),
        ]
    postproc = PostProcessor({"casedir": casedir})
    postproc.add_fields(fields)

    solver = NSSolver(problem, scheme, postproc)
    solver.solve()
Code example #5
File: PipeAneurysm.py Project: fmoabreu/fluxo
def main():
    problem = PipeAneurysm()
    scheme = IPCS()

    casedir = "results_demo_%s_%s" % (problem.shortname(), scheme.shortname())
    plot_and_save = dict(plot=True, save=True)
    fields = [
        Pressure(plot_and_save),
        Velocity(plot_and_save),
        ]
    postproc = PostProcessor({"casedir": casedir})
    postproc.add_fields(fields)

    solver = NSSolver(problem, scheme, postproc)
    solver.solve()
Code example #6
File: LidDrivenCavity.py Project: fmoabreu/fluxo
def main():
    problem = LidDrivenCavity({"refinement_level": 1})
    scheme = IPCS({"u_degree": 2, "solver_p_neumann": ("cg", "ilu")}) # Displays pressure oscillations

    casedir = "results_demo_%s_%s" % (problem.shortname(), scheme.shortname())
    plot_and_save = dict(plot=True, save=True)
    fields = [
        Pressure(plot_and_save),
        Velocity(plot_and_save),
        ]
    postproc = PostProcessor({"casedir": casedir})
    postproc.add_fields(fields)

    solver = NSSolver(problem, scheme, postproc)
    solver.solve()
Code example #7
def main():
    problem = Channel()
    scheme = IPCS()

    casedir = "results_demo_%s_%s" % (problem.shortname(), scheme.shortname())
    plot_and_save = dict(plot=True, save=True)
    fields = [
        Pressure(plot_and_save),
        Velocity(plot_and_save),
    ]
    postproc = PostProcessor({"casedir": casedir})
    postproc.add_fields(fields)

    solver = NSSolver(problem, scheme, postproc)
    solver.solve()
Code example #8
def main():
    problem = Womersley3D({"refinement_level": 2})
    scheme = IPCS()

    casedir = "results_demo_%s_%s_%d" % (problem.shortname(), scheme.shortname(), problem.params.refinement_level)
    plot_and_save = dict(plot=True, save=True)
    fields = [
        Pressure(plot_and_save),
        Velocity(plot_and_save),
        ]
    postproc = PostProcessor({"casedir": casedir})
    postproc.add_fields(fields)

    solver = NSSolver(problem, scheme, postproc)
    solver.solve()
Code example #9
File: Poiseuille2D.py Project: fmoabreu/fluxo
def main():
    set_log_level(100)
    problem = Poiseuille2D({"dt": 1e-3, "T": 1e-1, "num_periods": None, "refinement_level": 1})
    scheme = IPCS_Naive({"u_degree": 2})

    casedir = "results_demo_%s_%s" % (problem.shortname(), scheme.shortname())
    plot_and_save = dict(plot=True, save=True)
    fields = [
        Pressure(plot_and_save),
        Velocity(plot_and_save),
        ]
    postproc = PostProcessor({"casedir": casedir})
    postproc.add_fields(fields)

    solver = NSSolver(problem, scheme, postproc)
    solver.solve()
Code example #10
File: test_saver.py Project: jakobes/cbcpost-fork
def test_store_mesh(casedir):
    pp = PostProcessor(dict(casedir=casedir))

    from dolfin import (UnitSquareMesh, CellFunction, FacetFunction,
                        AutoSubDomain, Mesh, HDF5File, assemble, Expression,
                        ds, dx)

    # Store mesh
    mesh = UnitSquareMesh(6, 6)
    celldomains = CellFunction("size_t", mesh)
    celldomains.set_all(0)
    AutoSubDomain(lambda x: x[0] < 0.5).mark(celldomains, 1)

    facetdomains = FacetFunction("size_t", mesh)
    AutoSubDomain(lambda x, on_boundary: x[0] < 0.5 and on_boundary).mark(
        facetdomains, 1)

    pp.store_mesh(mesh, celldomains, facetdomains)

    # Read mesh back
    mesh2 = Mesh()
    f = HDF5File(mpi_comm_world(), os.path.join(pp.get_casedir(), "mesh.hdf5"),
                 'r')
    f.read(mesh2, "Mesh", False)

    celldomains2 = CellFunction("size_t", mesh2)
    f.read(celldomains2, "CellDomains")
    facetdomains2 = FacetFunction("size_t", mesh2)
    f.read(facetdomains2, "FacetDomains")

    e = Expression("1+x[1]", degree=1)

    dx1 = dx(1, domain=mesh, subdomain_data=celldomains)
    dx2 = dx(1, domain=mesh2, subdomain_data=celldomains2)
    C1 = assemble(e * dx1)
    C2 = assemble(e * dx2)
    assert abs(C1 - C2) < 1e-10

    ds1 = ds(1, domain=mesh, subdomain_data=facetdomains)
    ds2 = ds(1, domain=mesh2, subdomain_data=facetdomains2)
    F1 = assemble(e * ds1)
    F2 = assemble(e * ds2)
    assert abs(F1 - F2) < 1e-10
Code example #11
File: FlowAroundCylinder.py Project: vegarvi/CSF
def main():
    # Create problem and scheme instances
    problem = FlowAroundCylinder({"refinement_level": 2})
    scheme = IPCS_Stable()

    # Create postprocessor instance pointing to a case directory
    casedir = "results_demo_%s_%s" % (problem.shortname(), scheme.shortname())
    postprocessor = PostProcessor({"casedir": casedir})

    # Creating fields to plot and save
    plot_and_save = dict(plot=False, save=True)
    fields = [
        Pressure(plot_and_save),
        Velocity(plot_and_save),
        #StreamFunction(plot_and_save),
    ]

    # Add fields to postprocessor
    postprocessor.add_fields(fields)

    # Create NSSolver instance and solve problem
    solver = NSSolver(problem, scheme, postprocessor)
    solver.solve()
Code example #12
def test_get():
    pp = PostProcessor()
    velocity = MockVelocity()
    pp.add_field(velocity)

    # Check that compute is triggered
    assert velocity.touched == 0
    assert pp.get("MockVelocity") == "u"
    assert velocity.touched == 1

    # Check that get doesn't trigger second compute count
    pp.get("MockVelocity")
    assert velocity.touched == 1
Code example #13
def test_finalize_all(casedir):
    pp = PostProcessor(dict(casedir=casedir))

    velocity = MockVelocity(dict(finalize=True))
    pressure = MockPressure()
    pp.add_fields([velocity, pressure])

    pp.get("MockVelocity")
    pp.get("MockPressure")

    # Nothing finalized yet
    assert pp._finalized == {}
    assert velocity.finalized is False

    # finalize_all should finalize velocity only
    pp.finalize_all()
    assert pp._finalized == {"MockVelocity": "u"}
    assert velocity.finalized is True

    # Still able to get it
    assert pp.get("MockVelocity") == "u"
Code example #14
File: test_saver.py Project: jakobes/cbcpost-fork
def test_store_params(casedir):
    pp = PostProcessor(dict(casedir=casedir))
    params = ParamDict(Field=Field.default_params(),
                       PostProcessor=PostProcessor.default_params())

    pp.store_params(params)

    # Read back params
    params2 = None
    with open(os.path.join(pp.get_casedir(), "params.pickle"), 'r') as f:
        params2 = pickle.load(f)
    assert params2 == params

    str_params2 = open(os.path.join(pp.get_casedir(), "params.txt"),
                       'r').read()
    assert str_params2 == str(params)
Code example #15
def test_add_field():
    pp = PostProcessor()

    pp.add_field(SolutionField("foo"))
    assert "foo" in pp._fields.keys()

    pp += SolutionField("bar")
    assert set(["foo", "bar"]) == set(pp._fields.keys())

    pp += [SolutionField("a"), SolutionField("b")]
    assert set(["foo", "bar", "a", "b"]) == set(pp._fields.keys())

    pp.add_fields([
        MetaField("foo"),
        MetaField2("foo", "bar"),
    ])

    assert set(["foo", "bar", "a", "b", "MetaField_foo",
                "MetaField2_foo_bar"]) == set(pp._fields.keys())
Code example #16
def test_compute_calls():
    pressure = MockPressure()
    velocity = MockVelocity()
    Du = MockVelocityGradient()
    epsilon = MockStrain()
    sigma = MockStress()

    # Add fields to postprocessor
    pp = PostProcessor()
    pp.add_fields([pressure, velocity, Du, epsilon, sigma])

    # Nothing has been computed yet
    assert velocity.touched == 0
    assert Du.touched == 0
    assert epsilon.touched == 0
    assert pressure.touched == 0
    assert sigma.touched == 0

    # Get strain twice
    for i in range(2):
        strain = pp.get("MockStrain")
        # Check value
        assert strain == "epsilon(grad(u))"
        # Check the right things are computed but only the first time
        assert velocity.touched == 1  # Only increased first iteration!
        assert Du.touched == 1  # ...
        assert epsilon.touched == 1  # ...
        assert pressure.touched == 0  # Not computed!
        assert sigma.touched == 0  # ...

    # Get stress twice
    for i in range(2):
        stress = pp.get("MockStress")
        # Check value
        assert stress == "sigma(epsilon(grad(u)), p)"
        # Check the right things are computed but only the first time
        assert velocity.touched == 1  # Not recomputed!
        assert Du.touched == 1  # ...
        assert epsilon.touched == 1  # ...
        assert pressure.touched == 1  # Only increased first iteration!
        assert sigma.touched == 1  # ...
Code example #17
def test_update_all():
    pressure = SolutionField("MockPressure")  #MockPressure()
    velocity = SolutionField("MockVelocity")  #MockVelocity()
    Du = MockVelocityGradient()
    epsilon = MockStrain(dict(start_timestep=3))
    sigma = MockStress(dict(start_time=0.5, end_time=0.8))

    # Add fields to postprocessor
    pp = PostProcessor()
    pp.add_fields([pressure, velocity, Du, epsilon, sigma])

    N = 11
    T = [(i, float(i) / (N - 1)) for i in xrange(N)]

    for timestep, t in T:
        pp.update_all(
            {
                "MockPressure": lambda: "p" + str(timestep),
                "MockVelocity": lambda: "u" + str(timestep)
            }, t, timestep)

        assert Du.touched == timestep + 1

        assert pp._cache[0]["MockPressure"] == "p%d" % timestep
        assert pp._cache[0]["MockVelocity"] == "u%d" % timestep
        assert pp._cache[0]["MockVelocityGradient"] == "grad(u%d)" % timestep

        if timestep >= 3:
            assert pp._cache[0][
                "MockStrain"] == "epsilon(grad(u%d))" % timestep
        else:
            assert "MockStrain" not in pp._cache[0]

        if 0.5 <= t <= 0.8:
            assert pp._cache[0][
                "MockStress"] == "sigma(epsilon(grad(u%d)), p%d)" % (timestep,
                                                                     timestep)
        else:
            assert "MockStress" not in pp._cache[0]
Code example #18
    def replay(self):
        "Replay problem with given postprocessor."
        # Backup play log
        self.backup_playlog()

        # Set up for replay
        replay_plan = self._fetch_history()
        postprocessors = []
        for fieldname, field in self.postproc._fields.items():
            if not (field.params.save or field.params.plot):
                continue

            # Check timesteps covered by current field
            keys = self._check_field_coverage(replay_plan, fieldname)

            # Get the time dependency for the field
            t_dep = min(
                [dep[1]
                 for dep in self.postproc._dependencies[fieldname]] + [0])

            dep_fields = []
            for dep in self.postproc._full_dependencies[fieldname]:
                if dep[0] in ["t", "timestep"]:
                    continue

                if dep[0] in dep_fields:
                    continue

                # Copy dependency and set save/plot to False. If dependency should be
                # plotted/saved, this field will be added separately.
                dependency = self.postproc._fields[dep[0]]
                dependency = copy.copy(dependency)
                dependency.params.save = False
                dependency.params.plot = False
                dependency.params.safe = False

                dep_fields.append(dependency)

            added_to_postprocessor = False
            for i, (ppkeys, ppt_dep, pp) in enumerate(postprocessors):
                if t_dep == ppt_dep and set(keys) == set(ppkeys):
                    pp.add_fields(dep_fields, exists_reaction="ignore")
                    pp.add_field(field, exists_reaction="replace")

                    added_to_postprocessor = True
                    break
                else:
                    continue

            # Create new postprocessor if no suitable postprocessor found
            if not added_to_postprocessor:
                pp = PostProcessor(self.postproc.params, self.postproc._timer)
                pp.add_fields(dep_fields, exists_reaction="ignore")
                pp.add_field(field, exists_reaction="replace")

                postprocessors.append([keys, t_dep, pp])

        postprocessors = sorted(postprocessors,
                                key=itemgetter(1),
                                reverse=True)

        t_independent_fields = []
        for fieldname in self.postproc._fields:
            if self.postproc._full_dependencies[fieldname] == []:
                t_independent_fields.append(fieldname)
            elif min(t for dep, t in
                     self.postproc._full_dependencies[fieldname]) == 0:
                t_independent_fields.append(fieldname)

        # Run replay
        sorted_keys = sorted(replay_plan.keys())
        N = max(sorted_keys)
        for timestep in sorted_keys:
            cbc_print("Processing timestep %d of %d. %.3f%% complete." %
                      (timestep, N, 100.0 * (timestep) / N))

            # Load solution at this timestep (all available fields)
            solution = replay_plan[timestep]
            t = solution.pop("t")

            # Cycle through postprocessors and update if required
            for ppkeys, ppt_dep, pp in postprocessors:
                if timestep in ppkeys:
                    # Add dummy solutions to avoid error when handling dependencies
                    # We know this should work, because it has already been established that
                    # the fields to be computed at this timestep can be computed from stored
                    # solutions.
                    for field in pp._sorted_fields_keys:
                        for dep in reversed(pp._dependencies[field]):
                            if not have_necessary_deps(solution, pp, dep[0]):
                                solution[dep[0]] = lambda: None
                    pp.update_all(solution, t, timestep)

                    # Clear None-objects from solution
                    [
                        solution.pop(k) for k in solution.keys()
                        if not solution[k]
                    ]

                    # Update solution to avoid re-computing data
                    for fieldname, value in pp._cache[0].items():
                        if fieldname in t_independent_fields:
                            value = pp._cache[0][fieldname]
                            #solution[fieldname] = lambda value=value: value # Memory leak!
                            solution[fieldname] = MiniCallable(value)

            self.timer.increment()
            if self.params.check_memory_frequency != 0 and timestep % self.params.check_memory_frequency == 0:
                cbc_print('Memory usage is: %s' %
                          MPI.sum(mpi_comm_world(), get_memory_usage()))

            # Clean up solution: Required to avoid memory leak for some reason...
            for f, v in solution.items():
                if isinstance(v, MiniCallable):
                    v.value = None
                    del v
                    solution.pop(f)

        for ppkeys, ppt_dep, pp in postprocessors:
            pp.finalize_all()
Code example #19
File: test_replay.py Project: jakobes/cbcpost-fork
def test_basic_replay(mesh, casedir):
    spacepool = SpacePool(mesh)
    Q = spacepool.get_space(1,0)
    V = spacepool.get_space(1,1)

    pp = PostProcessor(dict(casedir=casedir))
    pp.add_fields([
        MockFunctionField(Q, dict(save=True)),
        MockVectorFunctionField(V, dict(save=True))
    ])

    replay_fields = lambda save: [Norm("MockFunctionField", dict(save=save)),
                             Norm("MockVectorFunctionField", dict(save=save)),
                             TimeIntegral("Norm_MockFunctionField", dict(save=save)),]
    rf_names = [f.name for f in replay_fields(False)]

    # Add fields, but don't save (for testing)
    pp.add_fields(replay_fields(False))

    # Solutions to check against
    checks = {}
    pp.update_all({}, 0.0, 0)
    checks[0] = dict([(name, pp.get(name)) for name in rf_names])
    pp.update_all({}, 0.1, 1)
    checks[1] = dict([(name, pp.get(name)) for name in rf_names])
    pp.update_all({}, 0.2, 2)
    pp.finalize_all()
    checks[2] = dict([(name, pp.get(name)) for name in rf_names])

    # Make sure that nothing is saved yet
    for name in rf_names:
        assert not os.path.isfile(os.path.join(pp.get_savedir(name), name+".db"))
    # ----------- Replay -----------------
    pp = PostProcessor(dict(casedir=casedir))

    pp.add_fields([
        MockFunctionField(Q),
        MockVectorFunctionField(V),
    ])

    # This time, save the fields
    pp.add_fields(replay_fields(True))

    replayer = Replay(pp)
    replayer.replay()

    # Test that replayed solution is the same as computed in the original "solve"
    for name in rf_names:
        data = shelve.open(os.path.join(pp.get_savedir(name), name+".db"), 'r')

        for i in range(3):
            assert data.get(str(i), None) == checks[i][name] or abs(data.get(str(i), None) - checks[i][name]) < 1e-8
        data.close()
Code example #20
File: test_saver.py Project: jakobes/cbcpost-fork
def test_shelve_save(mesh, casedir):
    mtf = MockTupleField(dict(save=True, save_as="shelve"))
    msf = MockScalarField(dict(save=True, save_as="shelve"))

    pp = PostProcessor(dict(casedir=casedir))
    pp.add_fields([mtf, msf])

    pp.update_all({}, 0.0, 0)
    pp.update_all({}, 0.1, 1)
    pp.update_all({}, 0.2, 2)
    pp.finalize_all()

    for mf in [mtf, msf]:
        assert os.path.isdir(pp.get_savedir(mf.name))
        assert os.path.isfile(
            os.path.join(pp.get_savedir(mf.name), "metadata.db"))
        assert os.path.isfile(
            os.path.join(pp.get_savedir(mf.name), mf.name + ".db"))

        md = shelve.open(os.path.join(pp.get_savedir(mf.name), "metadata.db"),
                         'r')
        assert 'shelve' in md["0"]
        assert md['saveformats'] == ['shelve']
        md.close()

        # Read back
        data = shelve.open(
            os.path.join(pp.get_savedir(mf.name), mf.name + ".db"), 'r')
        for i in ["0", "1", "2"]:
            d = data[i]
        data.close()

        assert d == pp.get(mf.name)
Code example #21
def test_run_problem(problem_factory, scheme_factory, refinement_level, dt):

    problem = problem_factory(refinement_level, dt)
    scheme = scheme_factory()

    print
    print "."*100
    print "** Problem: %16s ** Solver: %16s" % (problem.shortname(), scheme.shortname())
    print "** Refinement level: %2d ** dt: %.8f" % (refinement_level, dt)
    print "**"

    pp = PostProcessor({"casedir": "test"})
    test_fields = problem.test_fields()
    pp.add_fields(test_fields)

    solver = NSSolver(problem, scheme, pp)

    # Define variables
    values = {f.name: 1e16 for f in test_fields}
    num_dofs = 0
    T = 0

    # Disable printing from solve
    # TODO: Move to NSSolver/set option
    original_stdout = sys.stdout
    sys.stdout = NoOutput()

    try:
        t1 = time.time()
        ns = solver.solve()
        t2 = time.time()
        T = t2-t1

        spaces = ns["spaces"]
        t = float(ns["t"])
        num_dofs = spaces.V.dim()+spaces.Q.dim()

        references = problem.test_references(spaces, t)
        if references:
            assert len(references) == len(test_fields)
            for field, ref in zip(test_fields, references):
                value = pp.get(field.name)
                values[field.name] = l2norm(value, ref)
        else:
            for field in test_fields:
                value = float(pp.get(field.name)) # Only support scalar values in reference data
                values[field.name] = value

    except RuntimeError as re:
        print re.message

    # Enable printing again, and print values
    sys.stdout = original_stdout

    print "** dofs: %d Time spent: %f" % (num_dofs , T)

    #assert values, "No values calculated. Solver most likely failed."
    if all([v==1e16 for v in values.values()]):
        print "No values calculated. Solver most likely failed."

    for tfname, err in values.items():
        print "**** Fieldname: %20s ** Error: %.8e" % (tfname, err)

    # Store solve metadata
    metadata = {}
    metadata["scheme"] = {}
    metadata["scheme"]["name"] = scheme.shortname()
    metadata["scheme"]["params"] = scheme.params

    metadata["problem"] = {}
    metadata["problem"]["name"] = problem.shortname()
    metadata["problem"]["params"] = problem.params

    metadata["num_dofs"] = num_dofs
    metadata["time"] = T

    # Find hash from problem and scheme name+parameters
    hash = sha1()
    hash.update(str(metadata["scheme"]))
    hash.update(str(metadata["problem"]))
    filename = hash.hexdigest() + ".db"

    # Always write to output
    write_output_data(filename, metadata, values)

    # Read reference data values
    ref_metadata, ref_values = read_reference_data(filename)
    if ref_values == {}:
        print "WARNING: Found no reference data"
        return
    assert ref_values != {}, "Found no reference data!"

    # Check each value against reference
    for key in values:
        if key in ref_values:

            # Compute absolute and relative errors
            abs_err = abs(values[key] - ref_values[key])
            if abs(ref_values[key]) > 1e-12:
                err = abs_err / abs(ref_values[key])
            else:
                err = abs_err

            # TODO: Find necessary condition of this check!

            # This one should be chosen such that it always passes when nothing has changed
            strict_tolerance = 1e-8
            if err > strict_tolerance:
                msg = "Error not matching reference with tolerance %e:\n    key=%s,  error=%e,  ref_error=%e  diff=%e,  relative=%e" % (
                    strict_tolerance, key, values[key], ref_values[key], abs_err, err)
                print msg

            # This one should be chosen so that it passes when we're happy
            loose_tolerance = 1e-3
            assert err < loose_tolerance

    # After comparing what we can, check that we have references for everything we computed
    assert set(values.keys()) == set(ref_values.keys()), "Value keys computed and in references are different."
Code example #22
File: test_saver.py Project: jakobes/cbcpost-fork
def test_pvd_save(mesh, casedir):
    spacepool = SpacePool(mesh)
    Q = spacepool.get_space(1, 0)
    V = spacepool.get_space(1, 1)

    mff = MockFunctionField(Q, dict(save=True, save_as="pvd"))
    mvff = MockVectorFunctionField(V, dict(save=True, save_as="pvd"))

    pp = PostProcessor(dict(casedir=casedir))
    pp.add_fields([mff, mvff])

    pp.update_all({}, 0.0, 0)
    pp.update_all({}, 0.1, 1)
    pp.update_all({}, 0.2, 2)
    pp.finalize_all()

    for mf, FS in [(mff, Q), (mvff, V)]:
        assert os.path.isdir(pp.get_savedir(mf.name))
        assert os.path.isfile(
            os.path.join(pp.get_savedir(mf.name), "metadata.db"))
        assert os.path.isfile(
            os.path.join(pp.get_savedir(mf.name), mf.name + ".pvd"))
        if MPI.size(mpi_comm_world()) == 1:
            assert os.path.isfile(
                os.path.join(pp.get_savedir(mf.name),
                             mf.name + "%0.6d.vtu" % 0))
            assert os.path.isfile(
                os.path.join(pp.get_savedir(mf.name),
                             mf.name + "%0.6d.vtu" % 1))
            assert os.path.isfile(
                os.path.join(pp.get_savedir(mf.name),
                             mf.name + "%0.6d.vtu" % 2))
        else:
            assert os.path.isfile(
                os.path.join(pp.get_savedir(mf.name),
                             mf.name + "%0.6d.pvtu" % 0))
            assert os.path.isfile(
                os.path.join(pp.get_savedir(mf.name),
                             mf.name + "%0.6d.pvtu" % 1))
            assert os.path.isfile(
                os.path.join(pp.get_savedir(mf.name),
                             mf.name + "%0.6d.pvtu" % 2))

            for i in range(MPI.size(mpi_comm_world())):
                assert os.path.isfile(
                    os.path.join(pp.get_savedir(mf.name),
                                 mf.name + "_p%d_%0.6d.vtu" % (i, 0)))
                assert os.path.isfile(
                    os.path.join(pp.get_savedir(mf.name),
                                 mf.name + "_p%d_%0.6d.vtu" % (i, 1)))
                assert os.path.isfile(
                    os.path.join(pp.get_savedir(mf.name),
                                 mf.name + "_p%d_%0.6d.vtu" % (i, 2)))

        md = shelve.open(os.path.join(pp.get_savedir(mf.name), "metadata.db"),
                         'r')
        assert 'pvd' in md["0"]
        assert 'pvd' in md["1"]
        assert 'pvd' in md["2"]
        assert md['saveformats'] == ['pvd']
        md.close()

        assert len(os.listdir(pp.get_savedir(mf.name))) == 1 + 1 + 3 + int(
            MPI.size(mpi_comm_world()) != 1) * MPI.size(mpi_comm_world()) * 3
Code example #23
File: test_saver.py Project: jakobes/cbcpost-fork
def test_playlog(casedir):
    pp = PostProcessor(dict(casedir=casedir))

    # Test playlog
    assert not os.path.isfile(os.path.join(casedir, 'play.db'))
    MPI.barrier(mpi_comm_world())

    pp.update_all({}, 0.0, 0)
    pp.finalize_all()

    playlog = pp.get_playlog('r')
    assert playlog == {"0": {"t": 0.0}}
    playlog.close()

    pp.update_all({}, 0.1, 1)
    pp.finalize_all()
    playlog = pp.get_playlog('r')
    assert playlog == {"0": {"t": 0.0}, "1": {"t": 0.1}}
    playlog.close()
Code example #24
File: test_nssolver.py Project: fmoabreu/fluxo
 def default_params(cls):
     params = PostProcessor.default_params()
     params.update(pp=2)
     return params
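
This default_params() override and the __init__() override in code example #30 come from the same test file and read as methods of one PostProcessor subclass. Put together, the minimal subclassing pattern is a sketch like the following (the class name is hypothetical; pp=2 is just the test's dummy parameter):

class MockPostProcessor(PostProcessor):
    @classmethod
    def default_params(cls):
        params = PostProcessor.default_params()
        params.update(pp=2)  # dummy extra parameter from the test
        return params

    def __init__(self, params):
        PostProcessor.__init__(self, params)
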
Code example #25
File: monodomain.py Project: daveb-dev/xalbrain
def run_splitting_solver(domain, dt, T):

    # Create cardiac model problem description
    cell_model = Tentusscher_panfilov_2006_epi_cell()
    heart = setup_model(cell_model, domain)

    # Customize and create monodomain solver
    ps = SplittingSolver.default_parameters()
    ps["pde_solver"] = "monodomain"
    ps["apply_stimulus_current_to_pde"] = True

    # 2nd order splitting scheme
    ps["theta"] = 0.5

    # Use explicit first-order Rush-Larsen scheme for the ODEs
    ps["ode_solver_choice"] = "CardiacODESolver"
    ps["CardiacODESolver"]["scheme"] = "RL1"

    # Crank-Nicolson discretization for PDEs in time:
    ps["MonodomainSolver"]["theta"] = 0.5
    ps["MonodomainSolver"]["linear_solver_type"] = "iterative"
    ps["MonodomainSolver"]["algorithm"] = "cg"
    ps["MonodomainSolver"]["preconditioner"] = "petsc_amg"

    # Create solver
    solver = SplittingSolver(heart, params=ps)

    # Extract the solution fields and set the initial conditions
    (vs_, vs, vur) = solver.solution_fields()
    vs_.assign(cell_model.initial_conditions())
    solutions = solver.solve((0, T), dt)


    postprocessor = PostProcessor(dict(casedir="test", clean_casedir=True))
    postprocessor.store_mesh(heart.domain())

    field_params = dict(
        save=True,
        save_as=["hdf5", "xdmf"],
        plot=False,
        start_timestep=-1,
        stride_timestep=1
    )

    postprocessor.add_field(SolutionField("v", field_params))
    theta = ps["theta"]

    # Solve
    total = Timer("XXX Total cbcbeat solver time")
    for i, (timestep, (vs_, vs, vur)) in enumerate(solutions):

        t0, t1 = timestep
        current_t = t0 + theta*(t1 - t0)    
        postprocessor.update_all({"v": lambda: vur}, current_t, i)
        print("Solving on %s" % str(timestep))

        # Print memory usage (just for the fun of it)
        print(memory_usage())

    total.stop()

    # Plot result (as sanity check)
    #plot(vs[0], interactive=True)

    # Stop timer and list timings
    if MPI.rank(mpi_comm_world()) == 0:
        list_timings(TimingClear_keep, [TimingType_wall])
Code example #26
File: test_saver.py Project: jakobes/cbcpost-fork
def test_hdf5_save(mesh, casedir):
    spacepool = SpacePool(mesh)
    Q = spacepool.get_space(1, 0)
    V = spacepool.get_space(1, 1)

    mff = MockFunctionField(Q, dict(save=True, save_as="hdf5"))
    mvff = MockVectorFunctionField(V, dict(save=True, save_as="hdf5"))

    pp = PostProcessor(dict(casedir=casedir))
    pp.add_fields([mff, mvff])

    pp.update_all({}, 0.0, 0)
    pp.update_all({}, 0.1, 1)
    pp.update_all({}, 0.2, 2)
    pp.finalize_all()

    for mf, FS in [(mff, Q), (mvff, V)]:
        assert os.path.isdir(pp.get_savedir(mf.name))
        assert os.path.isfile(
            os.path.join(pp.get_savedir(mf.name), "metadata.db"))
        assert os.path.isfile(
            os.path.join(pp.get_savedir(mf.name), mf.name + ".hdf5"))

        md = shelve.open(os.path.join(pp.get_savedir(mf.name), "metadata.db"),
                         'r')
        assert 'hdf5' in md["0"]
        assert 'hdf5' in md["1"]
        assert 'hdf5' in md["2"]
        assert 'hdf5' in md['saveformats']

        assert md['saveformats'] == ['hdf5']
        md.close()

        assert len(os.listdir(pp.get_savedir(mf.name))) == 2

        # Read back
        hdf5file = HDF5File(
            mpi_comm_world(),
            os.path.join(pp.get_savedir(mf.name), mf.name + ".hdf5"), 'r')
        f = Function(FS)
        for i in ["0", "1", "2"]:
            hdf5file.read(f, mf.name + i)

        assert norm(f) == norm(pp.get(mf.name))
Code example #27
File: test_saver.py Project: jakobes/cbcpost-fork
def test_default_save(mesh, casedir):
    spacepool = SpacePool(mesh)
    Q = spacepool.get_space(1, 0)
    V = spacepool.get_space(1, 1)

    mff = MockFunctionField(Q, dict(save=True))
    mvff = MockVectorFunctionField(V, dict(save=True))
    mtf = MockTupleField(dict(save=True))
    msf = MockScalarField(dict(save=True))

    pp = PostProcessor(dict(casedir=casedir))
    pp.add_fields([mff, mvff, mtf, msf])

    pp.update_all({}, 0.0, 0)
    pp.update_all({}, 0.1, 1)
    pp.update_all({}, 0.2, 2)
    pp.finalize_all()

    for mf in [mff, mvff]:
        assert os.path.isdir(pp.get_savedir(mf.name))
        assert os.path.isfile(
            os.path.join(pp.get_savedir(mf.name), "metadata.db"))
        assert os.path.isfile(
            os.path.join(pp.get_savedir(mf.name), mf.name + ".hdf5"))
        assert os.path.isfile(
            os.path.join(pp.get_savedir(mf.name), mf.name + ".h5"))
        assert os.path.isfile(
            os.path.join(pp.get_savedir(mf.name), mf.name + ".xdmf"))

        assert len(os.listdir(pp.get_savedir(mf.name))) == 4

        md = shelve.open(os.path.join(pp.get_savedir(mf.name), "metadata.db"),
                         'r')
        assert 'hdf5' in md["0"]
        assert 'hdf5' in md['saveformats']
        assert 'xdmf' in md["0"]
        assert 'xdmf' in md['saveformats']
        assert set(md['saveformats']) == set(['hdf5', 'xdmf'])
        md.close()

    for mf in [mtf, msf]:
        assert os.path.isdir(pp.get_savedir(mf.name))
        assert os.path.isfile(
            os.path.join(pp.get_savedir(mf.name), "metadata.db"))
        assert os.path.isfile(
            os.path.join(pp.get_savedir(mf.name), mf.name + ".db"))
        assert os.path.isfile(
            os.path.join(pp.get_savedir(mf.name), mf.name + ".txt"))

        md = shelve.open(os.path.join(pp.get_savedir(mf.name), "metadata.db"),
                         'r')
        assert 'txt' in md["0"]
        assert 'txt' in md['saveformats']
        assert 'shelve' in md["0"]
        assert 'shelve' in md['saveformats']
        assert set(md['saveformats']) == set(['txt', 'shelve'])
        md.close()
Code example #28
File: test_saver.py Project: jakobes/cbcpost-fork
def test_xmlgz_save(mesh, casedir):
    spacepool = SpacePool(mesh)
    Q = spacepool.get_space(1, 0)
    V = spacepool.get_space(1, 1)

    mff = MockFunctionField(Q, dict(save=True, save_as="xml.gz"))
    mvff = MockVectorFunctionField(V, dict(save=True, save_as="xml.gz"))

    pp = PostProcessor(dict(casedir=casedir))
    pp.add_fields([mff, mvff])

    pp.update_all({}, 0.0, 0)
    pp.update_all({}, 0.1, 1)
    pp.update_all({}, 0.2, 2)
    pp.finalize_all()

    for mf, FS in [(mff, Q), (mvff, V)]:
        assert os.path.isdir(pp.get_savedir(mf.name))
        assert os.path.isfile(
            os.path.join(pp.get_savedir(mf.name), "metadata.db"))
        assert os.path.isfile(
            os.path.join(pp.get_savedir(mf.name), "mesh.hdf5"))
        assert os.path.isfile(
            os.path.join(pp.get_savedir(mf.name), mf.name + "0.xml.gz"))
        assert os.path.isfile(
            os.path.join(pp.get_savedir(mf.name), mf.name + "1.xml.gz"))
        assert os.path.isfile(
            os.path.join(pp.get_savedir(mf.name), mf.name + "2.xml.gz"))

        md = shelve.open(os.path.join(pp.get_savedir(mf.name), "metadata.db"),
                         'r')
        assert 'xml.gz' in md["0"]
        assert 'xml.gz' in md["1"]
        assert 'xml.gz' in md["2"]
        assert 'xml.gz' in md['saveformats']

        assert md['saveformats'] == ['xml.gz']
        md.close()

        assert len(os.listdir(pp.get_savedir(mf.name))) == 1 + 1 + 3

        # Read back
        for i in ["0", "1", "2"]:
            f = Function(
                FS,
                os.path.join(pp.get_savedir(mf.name), mf.name + i + ".xml.gz"))

        assert norm(f) == norm(pp.get(mf.name))
Code example #29
File: restart.py Project: jakobes/cbcpost-fork
class Restart(Parameterized):
    """Class to fetch restart conditions through."""
    #def __init__(self, params=None):
    #    Parameterized.__init__(self, params)

    @classmethod
    def default_params(cls):
        """
        Default parameters are:

        +----------------------+-----------------------+-------------------------------------------------------------------+
        |Key                   | Default value         |  Description                                                      |
        +======================+=======================+===================================================================+
        | casedir              | '.'                   | Case directory - relative path to read solutions from             |
        +----------------------+-----------------------+-------------------------------------------------------------------+
        | restart_times        | -1                    | float or list of floats to find restart times from. If -1,        |
        |                      |                       | restart from last available time.                                 |
        +----------------------+-----------------------+-------------------------------------------------------------------+
        | solution_names       | 'default'             | Solution names to look for. If 'default', will fetch all          |
        |                      |                       | fields stored as SolutionField.                                   |
        +----------------------+-----------------------+-------------------------------------------------------------------+
        | rollback_casedir     | False                 | Rollback case directory by removing all items stored after        |
        |                      |                       | largest restart time. This allows for saving data from a          |
        |                      |                       | restarted simulation in the same case directory.                  |
        +----------------------+-----------------------+-------------------------------------------------------------------+

        """
        params = ParamDict(
            casedir='.',
            restart_times=-1,
            #restart_timesteps=-1,
            solution_names="default",
            rollback_casedir=False,
            #interpolate=True,
            #dt=None,
        )
        return params

    def get_restart_conditions(self, function_spaces="default"):
        """ Return restart conditions as requested.

        :param dict function_spaces: A dict of dolfin.FunctionSpace objects, keyed by solution name, on which to return the restart conditions.

        """
        self._pp = PostProcessor(
            dict(casedir=self.params.casedir, clean_casedir=False))

        playlog = self._pp.get_playlog('r')
        assert playlog != {}, "Playlog is empty! Unable to find restart data."

        loadable_solutions = find_solution_presence(self._pp, playlog,
                                                    self.params.solution_names)
        loadables = find_restart_items(self.params.restart_times,
                                       loadable_solutions)

        if function_spaces != "default":
            assert isinstance(
                function_spaces,
                dict), "Expecting function_spaces kwarg to be a dict"
            assert set(loadables.values()[0].keys()) == set(
                function_spaces.keys(
                )), "Expecting a function space for each solution variable"

        def restart_conditions(spaces, loadables):
            # loadables[restart_time0][solution_name] = [(t0, Lt0)] # will load Lt0
            # loadables[restart_time0][solution_name] = [(t0, Lt0), (t1, Lt1)] # will interpolate to restart_time
            functions = {}
            for t in loadables:
                functions[t] = dict()
                for solution_name in loadables[t]:
                    assert len(loadables[t][solution_name]) in [1, 2]

                    if len(loadables[t][solution_name]) == 1:
                        f = loadables[t][solution_name][0][1]()
                    elif len(loadables[t][solution_name]) == 2:
                        # Interpolate
                        t0, Lt0 = loadables[t][solution_name][0]
                        t1, Lt1 = loadables[t][solution_name][1]

                        assert t0 <= t <= t1
                        if Lt0.function is not None:

                            # The copy function raises a PETSc error in parallel
                            #f = Function(Lt0())
                            f0 = Lt0()
                            f = Function(f0.function_space())
                            f.vector().axpy(1.0, f0.vector())
                            del f0

                            df = Lt1().vector()
                            df.axpy(-1.0, f.vector())
                            f.vector().axpy((t - t0) / (t1 - t0), df)
                        else:
                            f0 = Lt0()
                            f1 = Lt1()
                            datatype = type(f0)
                            if not issubclass(datatype, Iterable):
                                f0 = [f0]
                                f1 = [f1]

                            f = []
                            for _f0, _f1 in zip(f0, f1):
                                val = _f0 + (t - t0) / (t1 - t0) * (_f1 - _f0)
                                f.append(val)

                            if not issubclass(datatype, Iterable):
                                f = f[0]
                            else:
                                f = datatype(f)

                    if solution_name in spaces:
                        space = spaces[solution_name]
                        if space != f.function_space():
                            #from fenicstools import interpolate_nonmatching_mesh
                            #f = interpolate_nonmatching_mesh(f, space)
                            try:
                                f = interpolate(f, space)
                            except:
                                f = project(f, space)

                    functions[t][solution_name] = f

            return functions

        if function_spaces == "default":
            function_spaces = {}
            for fieldname in loadables.values()[0]:
                try:
                    function_spaces[fieldname] = loadables.values(
                    )[0][fieldname][0][1].function.function_space()
                except AttributeError:
                    # This was not a function field
                    pass

        result = restart_conditions(function_spaces, loadables)

        ts = 0
        while playlog[str(ts)]["t"] < max(loadables) - 1e-14:
            ts += 1
        self.restart_timestep = ts
        playlog.close()
        MPI.barrier(mpi_comm_world())
        if self.params.rollback_casedir:
            self._correct_postprocessing(ts)

        return result

    def _correct_postprocessing(self, restart_timestep):
        "Removes data from casedir found at timestep>restart_timestep."
        playlog = self._pp.get_playlog('r')
        playlog_to_remove = {}
        for k, v in playlog.items():
            if int(k) >= restart_timestep:
                #playlog_to_remove[k] = playlog.pop(k)
                playlog_to_remove[k] = playlog[k]
        playlog.close()

        MPI.barrier(mpi_comm_world())
        if on_master_process():
            playlog = self._pp.get_playlog()
            [playlog.pop(k) for k in playlog_to_remove.keys()]
            playlog.close()

        MPI.barrier(mpi_comm_world())
        all_fields_to_clean = []

        for k, v in playlog_to_remove.items():
            if "fields" not in v:
                continue
            else:
                all_fields_to_clean += v["fields"].keys()
        all_fields_to_clean = list(set(all_fields_to_clean))
        for fieldname in all_fields_to_clean:
            self._clean_field(fieldname, restart_timestep)

    def _clean_field(self, fieldname, restart_timestep):
        "Deletes data from field found at timestep>restart_timestep."
        metadata = shelve.open(
            os.path.join(self._pp.get_savedir(fieldname), 'metadata.db'), 'r')
        metadata_to_remove = {}
        for k in metadata.keys():
            #MPI.barrier(mpi_comm_world())
            try:
                k = int(k)
            except:
                continue
            if k >= restart_timestep:
                #metadata_to_remove[str(k)] = metadata.pop(str(k))
                metadata_to_remove[str(k)] = metadata[str(k)]
        metadata.close()
        MPI.barrier(mpi_comm_world())
        if on_master_process():
            metadata = shelve.open(
                os.path.join(self._pp.get_savedir(fieldname), 'metadata.db'),
                'w')
            [metadata.pop(key) for key in metadata_to_remove.keys()]
            metadata.close()
        MPI.barrier(mpi_comm_world())

        # Remove files and data for all save formats
        self._clean_hdf5(fieldname, metadata_to_remove)
        MPI.barrier(mpi_comm_world())
        self._clean_files(fieldname, metadata_to_remove)
        MPI.barrier(mpi_comm_world())

        self._clean_txt(fieldname, metadata_to_remove)
        MPI.barrier(mpi_comm_world())

        self._clean_shelve(fieldname, metadata_to_remove)
        MPI.barrier(mpi_comm_world())

        self._clean_xdmf(fieldname, metadata_to_remove)
        MPI.barrier(mpi_comm_world())

        self._clean_pvd(fieldname, metadata_to_remove)
        MPI.barrier(mpi_comm_world())

    def _clean_hdf5(self, fieldname, del_metadata):
        delete_from_hdf5_file = '''
        namespace dolfin {
            #include <hdf5.h>
            void delete_from_hdf5_file(const MPI_Comm comm,
                                       const std::string hdf5_filename,
                                       const std::string dataset,
                                       const bool use_mpiio)
            {
                //const hid_t plist_id = H5Pcreate(H5P_FILE_ACCESS);
                // Open file existing file for append
                //hid_t file_id = H5Fopen(filename.c_str(), H5F_ACC_RDWR, plist_id);
                hid_t hdf5_file_id = HDF5Interface::open_file(comm, hdf5_filename, "a", use_mpiio);

                H5Ldelete(hdf5_file_id, dataset.c_str(), H5P_DEFAULT);
                HDF5Interface::close_file(hdf5_file_id);
            }
        }
        '''
        cpp_module = compile_extension_module(
            delete_from_hdf5_file,
            additional_system_headers=["dolfin/io/HDF5Interface.h"])

        hdf5filename = os.path.join(self._pp.get_savedir(fieldname),
                                    fieldname + '.hdf5')

        if not os.path.isfile(hdf5filename):
            return

        for k, v in del_metadata.items():
            if 'hdf5' not in v:
                continue
            else:
                cpp_module.delete_from_hdf5_file(
                    mpi_comm_world(), hdf5filename, v['hdf5']['dataset'],
                    MPI.size(mpi_comm_world()) > 1)

        hdf5tmpfilename = os.path.join(self._pp.get_savedir(fieldname),
                                       fieldname + '_tmp.hdf5')
        #import ipdb; ipdb.set_trace()
        MPI.barrier(mpi_comm_world())
        if on_master_process():
            # status, result = getstatusoutput("h5repack -V")
            status, result = -1, -1
            if status != 0:
                cbc_warning(
                    "Unable to run h5repack. Will not repack hdf5-files before replay, which may cause bloated hdf5-files."
                )
            else:
                subprocess.call("h5repack %s %s" %
                                (hdf5filename, hdf5tmpfilename),
                                shell=True)
                os.remove(hdf5filename)
                os.rename(hdf5tmpfilename, hdf5filename)
        MPI.barrier(mpi_comm_world())

    def _clean_files(self, fieldname, del_metadata):
        for k, v in del_metadata.items():
            for i in v.values():
                MPI.barrier(mpi_comm_world())
                try:
                    i["filename"]
                except:
                    continue

                fullpath = os.path.join(self._pp.get_savedir(fieldname),
                                        i['filename'])

                if on_master_process():
                    os.remove(fullpath)
                MPI.barrier(mpi_comm_world())
            """
            #print k,v
            if 'filename' not in v:
                continue
            else:
                fullpath = os.path.join(self.postprocesor.get_savedir(fieldname), v['filename'])
                os.remove(fullpath)
            """

    def _clean_txt(self, fieldname, del_metadata):
        txtfilename = os.path.join(self._pp.get_savedir(fieldname),
                                   fieldname + ".txt")
        if on_master_process() and os.path.isfile(txtfilename):
            txtfile = open(txtfilename, 'r')
            txtfilelines = txtfile.readlines()
            txtfile.close()

            num_lines_to_strip = ['txt' in v
                                  for v in del_metadata.values()].count(True)

            txtfile = open(txtfilename, 'w')
            # Guard against a zero count: txtfilelines[:-0] would drop every
            # line instead of none.
            if num_lines_to_strip:
                txtfilelines = txtfilelines[:-num_lines_to_strip]
            [txtfile.write(l) for l in txtfilelines]

            txtfile.close()

    def _clean_shelve(self, fieldname, del_metadata):
        shelvefilename = os.path.join(self._pp.get_savedir(fieldname),
                                      fieldname + ".db")
        if on_master_process():
            if os.path.isfile(shelvefilename):
                shelvefile = shelve.open(shelvefilename, 'c')
                for k, v in del_metadata.items():
                    if 'shelve' in v:
                        shelvefile.pop(str(k))
                shelvefile.close()
        MPI.barrier(mpi_comm_world())

    def _clean_xdmf(self, fieldname, del_metadata):
        basename = os.path.join(self._pp.get_savedir(fieldname), fieldname)
        if os.path.isfile(basename + ".xdmf"):
            MPI.barrier(mpi_comm_world())

            i = 0
            while True:
                h5_filename = basename + "_RS" + str(i) + ".h5"
                if not os.path.isfile(h5_filename):
                    break
                i = i + 1

            xdmf_filename = basename + "_RS" + str(i) + ".xdmf"
            MPI.barrier(mpi_comm_world())

            if on_master_process():
                os.rename(basename + ".h5", h5_filename)
                os.rename(basename + ".xdmf", xdmf_filename)

                f = open(xdmf_filename, 'r').read()

                new_f = open(xdmf_filename, 'w')
                new_f.write(
                    f.replace(
                        os.path.split(basename)[1] + ".h5",
                        os.path.split(h5_filename)[1]))
                new_f.close()
        MPI.barrier(mpi_comm_world())

    def _clean_pvd(self, fieldname, del_metadata):
        if os.path.isfile(
                os.path.join(self._pp.get_savedir(fieldname),
                             fieldname + '.pvd')):
            cbc_warning(
                "No functionality for cleaning pvd-files for restart. Will overwrite."
            )
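
A minimal usage sketch for the Restart class above, assuming a case directory written by an earlier run. The parameter names follow default_params(); restart_times=-1 selects the last stored time:

restart = Restart(dict(casedir="results_demo", restart_times=-1,
                       rollback_casedir=True))
restart_data = restart.get_restart_conditions()
# restart_data maps each restart time to a dict {solution_name: value},
# and restart.restart_timestep holds the timestep to continue from.
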
Code example #30
File: test_nssolver.py Project: fmoabreu/fluxo
 def __init__(self, params):
     PostProcessor.__init__(self, params)