def test_callback_tracker():
    """Test trackers driven by user callbacks, with and without a time argument."""

    def make_problem():
        """Return a fresh random initial state together with the PDE to solve."""
        grid = UnitGrid([4, 4])
        initial = ScalarField.random_uniform(grid, 0.2, 0.3)
        return initial, DiffusionPDE()

    solve_kwargs = {"t_range": 0.5, "dt": 0.005, "backend": "numpy"}

    # Scenario 1: callbacks that only receive the state. The callback tracker
    # and the data tracker must record identical averages.
    averages = []

    def store_mean_data(state):
        averages.append(state.average)

    def get_mean_data(state):
        return state.average

    initial, pde = make_problem()
    mean_tracker = trackers.DataTracker(get_mean_data, interval=0.1)
    mean_callback = trackers.CallbackTracker(store_mean_data, interval=0.1)
    pde.solve(initial, tracker=[mean_tracker, mean_callback], **solve_kwargs)
    np.testing.assert_array_equal(averages, mean_tracker.data)

    # Scenario 2: callbacks that also receive the time. Both trackers should
    # fire at (approximately) every 0.1 time units, including t=0 and t=0.5.
    recorded_times = []

    def store_time(state, t):
        recorded_times.append(t)

    def get_time(state, t):
        return t

    initial, pde = make_problem()
    time_tracker = trackers.DataTracker(get_time, interval=0.1)
    pde.solve(
        initial,
        tracker=[trackers.CallbackTracker(store_time, interval=0.1), time_tracker],
        **solve_kwargs,
    )
    expected_times = np.arange(0, 0.55, 0.1)
    np.testing.assert_allclose(recorded_times, expected_times, atol=1e-2)
    np.testing.assert_allclose(time_tracker.data, expected_times, atol=1e-2)
def test_trackers():
    """Test whether simple trackers can be used together in one simulation.

    The tracker list contains:

    - a print tracker writing to ``os.devnull``;
    - a callback tracker;
    - a ``None`` entry, which should be ignored;
    - a data tracker;
    - a plot tracker, only when matplotlib is available.

    The test checks that the callback tracker and the data tracker recorded the
    same times. When pandas is available, it also checks the data tracker's
    dataframe.
    """
    times = []

    def store_time(state, t):
        times.append(t)

    def get_data(state):
        return {"integral": state.integral}

    # A context manager guarantees the devnull handle is closed even when
    # tracker construction or the solve raises.
    with open(os.devnull, "w") as devnull:
        data = trackers.DataTracker(get_data, interval=0.1)
        tracker_list = [
            trackers.PrintTracker(interval=0.1, stream=devnull),
            trackers.CallbackTracker(store_time, interval=0.1),
            None,  # should be ignored
            data,
        ]
        if module_available("matplotlib"):
            tracker_list.append(trackers.PlotTracker(interval=0.1, show=False))

        grid = UnitGrid([16, 16])
        state = ScalarField.random_uniform(grid, 0.2, 0.3)
        pde = DiffusionPDE()
        pde.solve(state, t_range=1, dt=0.005, tracker=tracker_list)

    assert times == data.times
    if module_available("pandas"):
        df = data.dataframe
        np.testing.assert_allclose(df["time"], times)
        # Every recorded integral is expected to match the initial state's,
        # i.e. the integral stays constant over the run.
        np.testing.assert_allclose(df["integral"], state.integral)
def test_data_tracker(tmp_path):
    """Test the DataTracker, both with a backing file and purely in memory.

    ``data1`` records a scalar per sample and is given ``filename=path``. After
    solving, the test reads a pickled ``(times, data)`` pair back from
    ``path``.

    ``data2`` records a dict per sample without a file, so it is checked
    through its ``times`` and ``data`` attributes.

    Args:
        tmp_path: pytest-provided temporary directory for the pickle file.
    """
    field = ScalarField(UnitGrid([4, 4]))
    eq = DiffusionPDE()
    path = tmp_path / "test_data_tracker.pickle"
    data1 = trackers.DataTracker(lambda f: f.average, filename=path)
    data2 = trackers.DataTracker(lambda f: {"avg": f.average, "int": f.integral})
    eq.solve(field, 10, tracker=[data1, data2])

    # The file-backed tracker must have written something before we read it.
    assert path.stat().st_size > 0
    with path.open("br") as fp:
        # Named ``times`` (not ``time``) to avoid shadowing the stdlib module.
        # Unpickling is safe here: the file was produced by this test.
        times, data = pickle.load(fp)
    np.testing.assert_allclose(times, np.arange(11))
    assert isinstance(data, list)
    assert len(data) == 11

    # Also verify the in-memory tracker, which uses a dict-valued callback.
    np.testing.assert_allclose(data2.times, np.arange(11))
    assert len(data2.data) == 11
    assert all(set(entry) == {"avg", "int"} for entry in data2.data)