def test_constructor(self, in_dataset, model):
    """Check that the driver rejects input datasets missing required data.

    Removing the ``clock`` variable must raise ``ValueError`` and removing
    the ``init_profile__n_points`` variable must raise ``KeyError``; each
    error message is matched against the expected pattern.
    """
    # (variable to remove, expected exception type, message regex)
    expected_failures = [
        ("clock", ValueError, r"Missing master clock.*"),
        ("init_profile__n_points", KeyError, r"Missing variables.*"),
    ]
    for missing_var, error_cls, message in expected_failures:
        # Build the invalid dataset outside ``pytest.raises`` so only the
        # driver constructor is expected to fail.
        invalid_ds = in_dataset.drop(missing_var)
        with pytest.raises(error_cls, match=message):
            XarraySimulationDriver(invalid_ds, model)
def test_multi_index(self, in_dataset, model):
    """Check that a pandas MultiIndex passes through a full model run.

    Per the original note, the index goes through reset -> zarr -> rebuilt
    during the run; this test only verifies that the index found in the
    results equals the one given as input.
    """
    expected_index = pd.MultiIndex.from_tuples(
        [(0, 1), (0, 2)], names=["a", "b"]
    )
    in_dataset["dummy"] = ("dummy", expected_index)

    driver = XarraySimulationDriver(in_dataset, model)
    driver.run_model()

    results = driver.get_results()
    pd.testing.assert_index_equal(results.indexes["dummy"], expected_index)
def test_static_var_as_scalar_coord(self, in_dataset, out_dataset, model):
    """Check that a static model input given as a scalar coordinate works.

    ``init_profile__n_points`` is re-attached as a coordinate of the input
    dataset. After running the model, the results (with coordinates reset
    to data variables) must equal the expected ``out_dataset``.
    """
    n_points = in_dataset["init_profile__n_points"]
    in_dataset.coords["init_profile__n_points"] = n_points

    driver = XarraySimulationDriver(in_dataset, model)
    driver.run_model()

    actual = driver.get_results().reset_coords()
    xr.testing.assert_equal(actual, out_dataset)
def test_constructor(self, in_dataset, model):
    """Check that the driver rejects input datasets missing required data.

    The driver is built with an empty dict as ``store`` and a fresh
    ``InMemoryOutputStore`` as ``out_store``. Removing ``clock`` must raise
    ``ValueError`` mentioning "Missing master clock". Removing
    ``init_profile__n_points`` must raise ``KeyError`` mentioning
    "Missing variables".
    """
    store = {}
    out_store = InMemoryOutputStore()

    invalid_ds = in_dataset.drop('clock')
    # ``match`` checks the message of the raised exception directly, so the
    # check cannot be skipped by a stray assert that ends up inside the
    # ``with`` block (where it would never run after the raise).
    with pytest.raises(ValueError, match=r"Missing master clock"):
        XarraySimulationDriver(invalid_ds, model, store, out_store)

    invalid_ds = in_dataset.drop('init_profile__n_points')
    with pytest.raises(KeyError, match=r"Missing variables"):
        XarraySimulationDriver(invalid_ds, model, store, out_store)
def test_runtime_context_in_model(in_dataset, model):
    """Check that requesting an unknown runtime argument fails at run time.

    A process whose ``run_step`` asks for ``not_a_runtime_arg`` is added to
    the model. Building the driver must succeed, but running the model must
    raise a ``KeyError`` whose message contains the quoted argument name.
    """

    @xs.process
    class P:
        @xs.runtime(args="not_a_runtime_arg")
        def run_step(self, arg):
            pass

    updated_model = model.update_processes({"p": P})
    driver = XarraySimulationDriver(in_dataset, updated_model)

    with pytest.raises(KeyError, match="'not_a_runtime_arg'"):
        driver.run_model()
def test_finalize_always_called():
    """Check that ``finalize`` runs even when ``initialize`` raises.

    The process below raises ``RuntimeError`` during initialization. The
    test requires that this error propagates out of ``run_model`` and that
    the process's ``finalize`` hook still ran, leaving ``var`` set to
    "finalized" in the model state.
    """

    @xs.process
    class P:
        var = xs.variable(intent="out")

        def initialize(self):
            self.var = "initialized"
            raise RuntimeError()

        def finalize(self):
            self.var = "finalized"

    model = xs.Model({"p": P})
    in_dataset = xs.create_setup(model=model, clocks={"clock": [0, 1]})
    driver = XarraySimulationDriver(in_dataset, model)

    # pytest.raises (rather than try/except/pass) also fails the test if the
    # driver silently swallows the error raised in ``initialize``.
    with pytest.raises(RuntimeError):
        driver.run_model()

    assert model.state[("p", "var")] == "finalized"
def xarray_driver(in_dataset, model):
    """Build an ``XarraySimulationDriver`` backed by in-memory stores.

    An empty dict is passed as ``store`` and a fresh ``InMemoryOutputStore``
    as ``out_store``; the new driver is returned.
    """
    return XarraySimulationDriver(
        in_dataset,
        model,
        {},
        InMemoryOutputStore(),
    )
def xarray_driver(in_dataset, model):
    """Return an ``XarraySimulationDriver`` for ``in_dataset`` and ``model``."""
    driver = XarraySimulationDriver(in_dataset, model)
    return driver