示例#1
0
def test_callback_tracker():
    """Check that CallbackTracker and DataTracker record matching values.

    Two scenarios are exercised on a fresh 4x4 unit grid each time:
    callbacks taking only the state, and callbacks taking the state and
    the time. In both, the values stored by the callback must agree with
    what the data tracker collected.
    """

    def run_diffusion(tracker_list):
        # Every scenario starts from its own random field on a new grid.
        field = ScalarField.random_uniform(UnitGrid([4, 4]), 0.2, 0.3)
        DiffusionPDE().solve(
            field,
            t_range=0.5,
            dt=0.005,
            tracker=tracker_list,
            backend="numpy",
        )

    # Scenario 1: callbacks that receive only the state.
    collected_means = []

    def store_mean_data(state):
        collected_means.append(state.average)

    def get_mean_data(state):
        return state.average

    mean_tracker = trackers.DataTracker(get_mean_data, interval=0.1)
    mean_callback = trackers.CallbackTracker(store_mean_data, interval=0.1)
    run_diffusion([mean_tracker, mean_callback])

    np.testing.assert_array_equal(collected_means, mean_tracker.data)

    # Scenario 2: callbacks that receive the state and the current time.
    collected_times = []

    def store_time(state, t):
        collected_times.append(t)

    def get_time(state, t):
        return t

    time_tracker = trackers.DataTracker(get_time, interval=0.1)
    time_callback = trackers.CallbackTracker(store_time, interval=0.1)
    run_diffusion([time_callback, time_tracker])

    # Both trackers should have recorded the same expected time points,
    # compared with an absolute tolerance of 1e-2.
    expected_times = np.arange(0, 0.55, 0.1)
    np.testing.assert_allclose(collected_times, expected_times, atol=1e-2)
    np.testing.assert_allclose(time_tracker.data, expected_times, atol=1e-2)
示例#2
0
def test_trackers():
    """Test whether simple trackers can be used together.

    A print tracker (writing to the null device), a callback tracker, a
    ``None`` entry, a data tracker and, when matplotlib is available, a
    plot tracker are all attached to one diffusion simulation. Afterwards
    the times seen by the callback must equal the times stored by the data
    tracker; when pandas is available the data frame is checked as well.
    """
    times = []

    def store_time(state, t):
        times.append(t)

    def get_data(state):
        return {"integral": state.integral}

    data = trackers.DataTracker(get_data, interval=0.1)

    # Use a context manager so the null-device stream is closed even when
    # the simulation or one of the trackers raises.
    with open(os.devnull, "w") as devnull:
        tracker_list = [
            trackers.PrintTracker(interval=0.1, stream=devnull),
            trackers.CallbackTracker(store_time, interval=0.1),
            None,  # should be ignored
            data,
        ]
        if module_available("matplotlib"):
            tracker_list.append(trackers.PlotTracker(interval=0.1, show=False))

        grid = UnitGrid([16, 16])
        state = ScalarField.random_uniform(grid, 0.2, 0.3)
        pde = DiffusionPDE()
        pde.solve(state, t_range=1, dt=0.005, tracker=tracker_list)

    assert times == data.times
    if module_available("pandas"):
        df = data.dataframe
        np.testing.assert_allclose(df["time"], times)
        # Every recorded integral is compared against the integral of
        # `state` as it is after the solve call.
        np.testing.assert_allclose(df["integral"], state.integral)
示例#3
0
def test_data_tracker(tmp_path):
    """Check the DataTracker, including writing its data to a pickle file.

    Args:
        tmp_path: Directory (pytest fixture) in which the pickle file is
            created.
    """
    pickle_file = tmp_path / "test_data_tracker.pickle"

    def scalar_result(f):
        return f.average

    def dict_result(f):
        return {"avg": f.average, "int": f.integral}

    # One tracker returns a plain value and writes to a file, the other
    # returns a dictionary and keeps its data in memory only.
    file_tracker = trackers.DataTracker(scalar_result, filename=pickle_file)
    dict_tracker = trackers.DataTracker(dict_result)

    initial_field = ScalarField(UnitGrid([4, 4]))
    DiffusionPDE().solve(initial_field, 10, tracker=[file_tracker, dict_tracker])

    # The file is expected to unpickle into a (times, values) pair.
    with pickle_file.open("br") as stream:
        loaded = pickle.load(stream)
    stored_times, stored_values = loaded

    np.testing.assert_allclose(stored_times, np.arange(11))
    assert isinstance(stored_values, list)
    assert len(stored_values) == 11

    assert pickle_file.stat().st_size > 0