Example #1
0
def test_trackers():
    """Test whether simple trackers can be used together in one simulation.

    A ``PrintTracker`` (writing to the null device), a ``CallbackTracker`` and
    a ``DataTracker`` run side by side with an explicit ``None`` entry, which
    should be ignored. A ``PlotTracker`` is added when matplotlib is available.
    Afterwards the times seen by the callback must match those recorded by the
    data tracker, and (with pandas) the data tracker's dataframe is checked.
    """
    times = []

    def store_time(state, t):
        # callback receiving the current time: record every invocation time
        times.append(t)

    def get_data(state):
        return {"integral": state.integral}

    data = trackers.DataTracker(get_data, interval=0.1)

    # The context manager closes the null device even if the simulation raises.
    # It has to stay open for the whole solve because the PrintTracker writes
    # to it.
    with open(os.devnull, "w") as devnull:
        tracker_list = [
            trackers.PrintTracker(interval=0.1, stream=devnull),
            trackers.CallbackTracker(store_time, interval=0.1),
            None,  # should be ignored
            data,
        ]
        if module_available("matplotlib"):
            tracker_list.append(trackers.PlotTracker(interval=0.1, show=False))

        grid = UnitGrid([16, 16])
        state = ScalarField.random_uniform(grid, 0.2, 0.3)
        pde = DiffusionPDE()
        pde.solve(state, t_range=1, dt=0.005, tracker=tracker_list)

    assert times == data.times
    if module_available("pandas"):
        df = data.dataframe
        np.testing.assert_allclose(df["time"], times)
        # every recorded integral must equal the final one, i.e. the test
        # expects the integral to be conserved during the run
        np.testing.assert_allclose(df["integral"], state.integral)
Example #2
0
def test_steady_state_tracker():
    """Check that the SteadyStateTracker ends a simulation early.

    Snapshots are stored every 100 time units over a range of 1000. A run that
    went the full distance would presumably yield 11 snapshots, so fewer than 9
    indicates that the steady-state tracker stopped the simulation.
    """
    storage = MemoryStorage()
    initial = ScalarField.random_uniform(UnitGrid([5]))
    equation = DiffusionPDE()
    steady = trackers.SteadyStateTracker(atol=1e-2, rtol=1e-2, progress=True)
    snapshot_tracker = storage.tracker(interval=1e2)
    equation.solve(initial, 1e3, dt=0.1, tracker=[steady, snapshot_tracker])
    assert len(storage) < 9  # finished early
Example #3
0
def test_small_tracker_dt():
    """Check storage when the tracker interval is finer than the time step.

    The tracker requests data every 1e-4 while the explicit solver advances in
    steps of 1e-3 up to t=1e-2. The test expects one snapshot per step plus the
    initial state, i.e. 11 entries.
    """
    storage = MemoryStorage()
    equation = DiffusionPDE()
    field = ScalarField.random_uniform(UnitGrid([4, 4]), 0.1, 0.2)
    fine_tracker = storage.tracker(interval=1e-4)
    equation.solve(field, 1e-2, dt=1e-3, method="explicit", tracker=fine_tracker)
    assert len(storage) == 11
Example #4
0
def test_plot_movie_tracker(tmp_path):
    """Check that a PlotTracker recording a movie writes a non-empty file."""
    movie_path = tmp_path / "movie.mov"

    field = ScalarField.random_uniform(UnitGrid([4, 4]))
    equation = DiffusionPDE()
    movie_tracker = trackers.PlotTracker(movie=movie_path, interval=0.1, show=False)

    equation.solve(
        field, t_range=0.5, dt=0.005, tracker=movie_tracker, backend="numpy"
    )

    assert movie_path.stat().st_size > 0
Example #5
0
def test_runtime_tracker():
    """Run a long simulation alongside a RuntimeTracker.

    The tracker is given the limit ``"0:01"`` while the controller is asked to
    integrate up to t=1e4 with dt=1e-3 (1e7 explicit steps), together with a
    progress tracker.
    NOTE(review): nothing asserts that the run actually stopped early; the test
    only checks that the controller completes without raising.
    """
    field = ScalarField.random_uniform(UnitGrid([128]))
    runtime_limit = trackers.RuntimeTracker("0:01")
    controller = Controller(
        ExplicitSolver(DiffusionPDE()),
        t_range=1e4,
        tracker=["progress", runtime_limit],
    )
    controller.run(field, dt=1e-3)
Example #6
0
def test_consistency_tracker():
    """Check that the "consistency" tracker stops a diverging simulation.

    ``DiffusionPDE(1e3)`` (presumably a large diffusivity) combined with dt=1
    is expected to make the explicit scheme blow up. Floating-point warnings are
    silenced during the run, which must end before reaching ``t_end``.
    """
    field = ScalarField.random_uniform(UnitGrid([128]))
    controller = Controller(
        ExplicitSolver(DiffusionPDE(1e3)), t_range=1e5, tracker=["consistency"]
    )
    with np.errstate(all="ignore"):
        controller.run(field, dt=1)
    run_info = controller.info
    assert run_info["t_final"] < run_info["t_end"]
Example #7
0
def test_plot_tracker(tmp_path):
    """Check that a PlotTracker with a callable title writes a non-empty image."""
    image_path = tmp_path / "img.png"

    def title_for(state, t):
        # compact title showing the current integral and time
        return f"{state.integral:g} at {t:g}"

    field = ScalarField.random_uniform(UnitGrid([4, 4]))
    equation = DiffusionPDE()
    image_tracker = trackers.PlotTracker(
        output_file=image_path, title=title_for, interval=0.1, show=False
    )

    equation.solve(
        field, t_range=0.5, dt=0.005, tracker=image_tracker, backend="numpy"
    )

    assert image_path.stat().st_size > 0
Example #8
0
def test_data_tracker(tmp_path):
    """Test the DataTracker with a scalar-valued and a dict-valued callback.

    ``data1`` writes its collected times and data to a pickle file, which is
    read back and checked. ``data2`` collects dicts in memory and is checked
    directly. Both use the default interval over a run up to t=10.
    """
    field = ScalarField(UnitGrid([4, 4]))
    eq = DiffusionPDE()

    path = tmp_path / "test_data_tracker.pickle"
    data1 = trackers.DataTracker(lambda f: f.average, filename=path)
    data2 = trackers.DataTracker(lambda f: {"avg": f.average, "int": f.integral})
    eq.solve(field, 10, tracker=[data1, data2])

    # the file must exist and be non-empty before we try to read it back
    assert path.stat().st_size > 0
    # the pickle was written by this test itself, so loading it is safe
    with path.open("br") as fp:
        time, data = pickle.load(fp)
    np.testing.assert_allclose(time, np.arange(11))
    assert isinstance(data, list)
    assert len(data) == 11

    # the dict-valued tracker was previously never inspected
    np.testing.assert_allclose(data2.times, np.arange(11))
    assert len(data2.data) == 11
    assert all(entry.keys() == {"avg", "int"} for entry in data2.data)
Example #9
0
def test_movie_scalar(movie_func, tmp_path):
    """Test creating a movie from stored scalar-field snapshots."""
    # run a short simulation, keeping one snapshot per time unit
    field = ScalarField.random_uniform(UnitGrid([4, 4]))
    equation = DiffusionPDE()
    storage = MemoryStorage()
    equation.solve(
        field,
        t_range=2,
        dt=1e-2,
        backend="numpy",
        tracker=storage.tracker(interval=1),
    )

    # render the stored data into a movie file
    movie_path = tmp_path / "test_movie.mov"
    try:
        movie_func(storage, filename=movie_path, progress=False)
    except RuntimeError:
        return  # can happen when ffmpeg is not installed
    assert movie_path.stat().st_size > 0
Example #10
0
def test_callback_tracker():
    """Test that CallbackTracker and DataTracker observe the same states.

    Two scenarios are checked, each with its own result list so the closures
    never share a variable:

    1. callbacks receiving only the state (the field average is recorded);
    2. callbacks additionally receiving the current time ``t``.
    """
    # scenario 1: callbacks that only receive the state
    mean_values = []

    def store_mean_data(state):
        mean_values.append(state.average)

    def get_mean_data(state):
        return state.average

    grid = UnitGrid([4, 4])
    state = ScalarField.random_uniform(grid, 0.2, 0.3)
    pde = DiffusionPDE()
    data_tracker = trackers.DataTracker(get_mean_data, interval=0.1)
    callback_tracker = trackers.CallbackTracker(store_mean_data, interval=0.1)
    pde.solve(
        state,
        t_range=0.5,
        dt=0.005,
        tracker=[data_tracker, callback_tracker],
        backend="numpy",
    )

    np.testing.assert_array_equal(mean_values, data_tracker.data)

    # scenario 2: callbacks that also receive the current time
    recorded_times = []

    def store_time(state, t):
        recorded_times.append(t)

    def get_time(state, t):
        return t

    grid = UnitGrid([4, 4])
    state = ScalarField.random_uniform(grid, 0.2, 0.3)
    pde = DiffusionPDE()
    data_tracker = trackers.DataTracker(get_time, interval=0.1)
    tracker_list = [trackers.CallbackTracker(store_time, interval=0.1), data_tracker]
    pde.solve(
        state,
        t_range=0.5,
        dt=0.005,
        tracker=tracker_list,
        backend="numpy",
    )

    # expected sampling times 0, 0.1, ..., 0.5; linspace avoids the
    # endpoint ambiguity of np.arange with a floating-point step
    ts = np.linspace(0, 0.5, 6)
    np.testing.assert_allclose(recorded_times, ts, atol=1e-2)
    np.testing.assert_allclose(data_tracker.data, ts, atol=1e-2)