Beispiel #1
0
class TotalVerticalMotion:
    """Aggregate the vertical motions of the bedrock and of the
    topographic surface.

    Contributions are read from the 'any_upward', 'bedrock_upward',
    'surface_upward' and 'surface_downward' groups and combined into
    the two output fields at every step.

    """
    # TODO: remove any_upward_vars
    # see https://github.com/benbovy/xarray-simlab/issues/64
    any_upward_vars = xs.group("any_upward")
    bedrock_upward_vars = xs.group("bedrock_upward")
    surface_upward_vars = xs.group("surface_upward")
    surface_downward_vars = xs.group("surface_downward")

    bedrock_upward = xs.variable(
        dims=("y", "x"),
        intent="out",
        description="bedrock motion in upward direction",
    )
    surface_upward = xs.variable(
        dims=("y", "x"),
        intent="out",
        description="topographic surface motion in upward direction",
    )

    def run_step(self):
        """Recompute both totals from the current group members."""
        # Members of the 'any_upward' group contribute to both outputs.
        shared = sum(self.any_upward_vars)

        bedrock_only = sum(self.bedrock_upward_vars)
        self.bedrock_upward = shared + bedrock_only

        # Downward surface motions are subtracted from the upward total.
        raised = sum(self.surface_upward_vars)
        lowered = sum(self.surface_downward_vars)
        self.surface_upward = shared + raised - lowered
Beispiel #2
0
class TotalVerticalMotion:
    """Aggregate the vertical motions of the bedrock and of the
    topographic surface.

    Contributions are read from the 'bedrock_upward', 'surface_upward'
    and 'surface_downward' groups and combined into the two output
    fields at every step.

    """
    bedrock_upward_vars = xs.group("bedrock_upward")
    surface_upward_vars = xs.group("surface_upward")
    surface_downward_vars = xs.group("surface_downward")

    bedrock_upward = xs.variable(
        dims=("y", "x"),
        intent="out",
        description="bedrock motion in upward direction",
    )
    surface_upward = xs.variable(
        dims=("y", "x"),
        intent="out",
        description="topographic surface motion in upward direction",
    )

    def run_step(self):
        """Recompute both totals from the current group members."""
        self.bedrock_upward = sum(self.bedrock_upward_vars)

        # Net surface motion: upward contributions minus downward ones.
        raised = sum(self.surface_upward_vars)
        lowered = sum(self.surface_downward_vars)
        self.surface_upward = raised - lowered
Beispiel #3
0
class ThirdInit(Context):
    """Context subclass that publishes the value 3 in the 'ThirdInit' group.

    It also declares references to the 'FirstInit' and 'SecondInit' groups.
    """
    firstinit = xs.group("FirstInit")
    secondinit = xs.group("SecondInit")
    group = xs.variable(intent="out", groups="ThirdInit")

    def initialize(self):
        """Run the parent initialization, then set the output value."""
        super(ThirdInit, self).initialize()
        self.group = 3
class Profile:
    """Hold quantity `u` along `x` and update it from the 'diff' group."""

    u = xs.variable(
        dims="x",
        description="quantity u",
        intent="inout",
        encoding={"fill_value": np.nan},
    )
    u_diffs = xs.group("diff")
    u_opp = xs.on_demand(dims="x")

    def initialize(self):
        # Buffer for the per-step change, shaped like `u`.
        self.u_change = np.zeros_like(self.u)

    def run_step(self):
        # Written in place so the buffer created in initialize() is reused.
        total_diff = sum(self.u_diffs)
        self.u_change[:] = total_diff

    def finalize_step(self):
        # Apply the accumulated change once the step is complete.
        self.u += self.u_change

    def finalize(self):
        # Reset the profile in place at the end of the simulation.
        self.u[:] = 0.0

    @u_opp.compute
    def _get_u_opposite(self):
        """Return the opposite (negated) profile, computed on demand."""
        return -self.u
Beispiel #5
0
def _make_phydra_flux(label, variable):
    """Build the xsimlab variable declarations for a flux.

    The returned mapping always holds a ``<label>_value`` output variable.
    Depending on the 'group' and 'group_to_arg' entries of
    ``variable.metadata`` it also holds a ``<label>_label`` output variable
    registered in that group and/or a group reference stored under the
    ``group_to_arg`` name.
    """
    flux_vars = defaultdict()
    flux_vars[label + "_value"] = _convert_2_xsimlabvar(
        var=variable,
        intent="out",
        value_store=True,
        description_label="output of flux value / ",
    )

    group = variable.metadata.get("group")
    group_to_arg = variable.metadata.get("group_to_arg")

    # NOTE: setting both 'group' and 'group_to_arg' is currently accepted;
    # the check that rejected this combination is disabled.

    if group:
        # Scalar label variable that registers the flux in its group.
        flux_vars[label + "_label"] = _convert_2_xsimlabvar(
            var=variable,
            intent="out",
            groups=group,
            var_dims=(),
            description_label="label reference with group / ",
            attrs=False,
        )

    if group_to_arg:
        # The flux takes the members of this group as an argument.
        flux_vars[group_to_arg] = xs.group(group_to_arg)

    return flux_vars
class ProfileZ:
    """Compute the evolution of the elevation (z) profile."""

    # Several processes may contribute to the thickness change
    # (e.g. diffusion and subsidence); they are collected via this group.
    h_vars = xs.group("h_vars")

    z = xs.variable(
        dims="x",
        intent="inout",
        description="surface elevation z",
        attrs={"units": "m"},
    )
    br = xs.variable(
        dims=[(), "x"],
        intent="in",
        description="bedrock_elevation",
        attrs={"units": "m"},
    )
    h = xs.variable(
        dims="x",
        intent="inout",
        description="sed_thickness",
        attrs={"units": "m"},
    )

    def run_step(self):
        # Bedrock is not updated here: `br` is an input (intent="in").
        self._delta_h = sum(self.h_vars)

    def finalize_step(self):
        # Update the sediment thickness first ...
        self.h += self._delta_h
        # ... then stack it on the bedrock to get the surface elevation.
        self.z = self.br + self.h
class ExampleProcess:
    """A process with complete interface for testing."""

    # --- regular variables ---
    in_var = xs.variable(
        dims=["x", ("x", "y")],
        description="input variable",
    )
    out_var = xs.variable(groups="example_group", intent="out")
    inout_var = xs.variable(intent="inout", converter=int)
    od_var = xs.on_demand()
    obj_var = xs.any_object(description="arbitrary object")

    # --- references to variables declared in other processes ---
    in_foreign_var = xs.foreign(SomeProcess, "some_var")
    in_foreign_var2 = xs.foreign(AnotherProcess, "some_var")
    out_foreign_var = xs.foreign(
        AnotherProcess,
        "another_var",
        intent="out",
    )
    in_foreign_od_var = xs.foreign(SomeProcess, "some_od_var")

    # --- global references ---
    in_global_var = xs.global_ref("some_global_var")
    out_global_var = xs.global_ref("another_global_var", intent="out")

    # --- group references ---
    group_var = xs.group("some_group")
    group_dict_var = xs.group_dict("some_group")

    # --- attributes that are not xsimlab variables ---
    other_attrib = attr.attrib(init=False, repr=False)
    other_attr = "this is not a xsimlab variable attribute"

    @od_var.compute
    def compute_od_var(self):
        """Compute method of `od_var`; always returns 0."""
        return 0
Beispiel #8
0
class Solver(Context):
    """Solver process executed last."""

    # References to the five initialization groups.
    firstinit = xs.group("FirstInit")
    secondinit = xs.group("SecondInit")
    thirdinit = xs.group("ThirdInit")
    fourthinit = xs.group("FourthInit")
    fifthinit = xs.group("FifthInit")

    def initialize(self):
        """Print the selected solver and assemble the model."""
        print("assembling model")
        print("SOLVER :", self.m.Solver)
        self.m.assemble()

    @xs.runtime(args="step_delta")
    def run_step(self, dt):
        """Solve the model for one step of length `dt`."""
        self.m.solve(dt)
Beispiel #9
0
class TectonicForcing:
    """Aggregate all tectonic forcing processes into the vertical motion
    they impose on the bedrock surface and on the topographic surface.

    """
    # TODO: remove any_forcing_vars
    # see https://github.com/benbovy/xarray-simlab/issues/64
    any_forcing_vars = xs.group("any_forcing_upward")
    bedrock_forcing_vars = xs.group("bedrock_forcing_upward")
    surface_forcing_vars = xs.group("surface_forcing_upward")

    bedrock_upward = xs.variable(
        dims=[(), ("y", "x")],
        intent="out",
        group="bedrock_upward",
        description="imposed vertical motion of bedrock surface",
    )

    surface_upward = xs.variable(
        dims=[(), ("y", "x")],
        intent="out",
        group="surface_upward",
        description="imposed vertical motion of topographic surface",
    )

    grid_area = xs.foreign(UniformRectilinearGrid2D, "area")

    domain_rate = xs.on_demand(
        description="domain-integrated volumetric tectonic rate",
    )

    @xs.runtime(args="step_delta")
    def run_step(self, dt):
        """Sum the forcing groups; `dt` is kept for the on-demand rate."""
        self._dt = dt

        # Members of the 'any_forcing_upward' group apply to both surfaces.
        shared = sum(self.any_forcing_vars)

        bedrock_only = sum(self.bedrock_forcing_vars)
        self.bedrock_upward = shared + bedrock_only

        surface_only = sum(self.surface_forcing_vars)
        self.surface_upward = shared + surface_only

    @domain_rate.compute
    def _domain_rate(self):
        """Return sum(surface_upward) * grid_area / dt of the last step."""
        integrated = np.sum(self.surface_upward) * self.grid_area
        return integrated / self._dt
class ProfileU:
    """Compute the evolution of the profile of quantity `u`."""

    u_vars = xs.group("u_vars")
    u = xs.variable(
        dims="x",
        intent="inout",
        description="quantity u",
        attrs={"units": "m"},
    )

    def run_step(self):
        # Total change contributed by every member of the 'u_vars' group.
        self._delta_u = sum(self.u_vars)

    def finalize_step(self):
        # Apply the change computed in run_step, in place.
        self.u += self._delta_u
Beispiel #11
0
class ProfileU(object):
    """Compute the evolution of the profile of quantity `u`."""

    u_vars = xs.group("u_vars")
    u = xs.variable(
        dims="x",
        intent="inout",
        description="quantity u",
        attrs={"units": "m"},
    )

    def run_step(self, *args):
        # Total change contributed by every member of the 'u_vars' group;
        # any positional runtime arguments are ignored.
        self._delta_u = sum(self.u_vars)

    def finalize_step(self):
        # Apply the change computed in run_step, in place.
        self.u += self._delta_u
Beispiel #12
0
class Profile(object):
    """Hold quantity `u` along `x` and update it from the 'diff' group."""

    u = xs.variable(dims='x', description='quantity u', intent='inout')
    u_diffs = xs.group('diff')
    u_opp = xs.on_demand(dims='x')

    def initialize(self):
        # Buffer for the per-step change, shaped like `u`.
        self.u_change = np.zeros_like(self.u)

    def run_step(self, *args):
        # Use the builtin sum: passing a generator to np.sum() is deprecated
        # in NumPy (it warns and falls back to the builtin anyway). The
        # result is unchanged; positional runtime arguments are ignored.
        self.u_change[:] = sum(self.u_diffs)

    def finalize_step(self):
        # Apply the accumulated change once the step is complete.
        self.u += self.u_change

    def finalize(self):
        # Reset the profile in place at the end of the simulation.
        self.u[:] = 0.

    @u_opp.compute
    def _get_u_opposite(self):
        """Return the opposite (negated) profile, computed on demand."""
        return -self.u
Beispiel #13
0
class ExampleProcess(object):
    """A process with complete interface for testing."""

    # --- regular variables ---
    in_var = xs.variable(
        dims=["x", ("x", "y")],
        description="input variable",
    )
    out_var = xs.variable(group="example_group", intent="out")
    inout_var = xs.variable(intent="inout")
    od_var = xs.on_demand()

    # --- references to variables declared in other processes ---
    in_foreign_var = xs.foreign(SomeProcess, "some_var")
    in_foreign_var2 = xs.foreign(AnotherProcess, "some_var")
    out_foreign_var = xs.foreign(
        AnotherProcess,
        "another_var",
        intent="out",
    )
    in_foreign_od_var = xs.foreign(SomeProcess, "some_od_var")

    # --- group reference ---
    group_var = xs.group("some_group")

    # --- attributes that are not xsimlab variables ---
    other_attrib = attr.attrib(init=False, cmp=False, repr=False)
    other_attr = "this is not a xsimlab variable attribute"

    @od_var.compute
    def compute_od_var(self):
        """Compute method of `od_var`; always returns 0."""
        return 0
Beispiel #14
0
class TotalErosion:
    """Aggregate the contributions of all erosion processes."""

    erosion_vars = xs.group("erosion")

    cumulative_height = xs.variable(
        dims=[(), ("y", "x")],
        intent="inout",
        description="erosion height accumulated over time",
    )

    height = xs.variable(
        dims=[(), ("y", "x")],
        intent="out",
        description="total erosion height at current step",
        groups="surface_downward",
    )

    rate = xs.on_demand(
        dims=[(), ("y", "x")],
        description="total erosion rate at current step",
    )

    grid_area = xs.foreign(UniformRectilinearGrid2D, "area")

    domain_rate = xs.on_demand(
        description="domain-integrated volumetric erosion rate",
    )

    @xs.runtime(args="step_delta")
    def run_step(self, dt):
        """Sum the 'erosion' group and add it to the running total."""
        # The step length is kept for the on-demand rate variables.
        self._dt = dt

        self.height = sum(self.erosion_vars)
        self.cumulative_height += self.height

    @rate.compute
    def _rate(self):
        """Return the erosion height of the last step divided by dt."""
        return self.height / self._dt

    @domain_rate.compute
    def _domain_rate(self):
        """Return sum(height) * grid_area / dt of the last step."""
        integrated = np.sum(self.height) * self.grid_area
        return integrated / self._dt
Beispiel #15
0
 class B:
     """Declares references to two groups, 'g1' and 'g2'."""
     g1 = xs.group("g1")
     g2 = xs.group("g2")
 class Foo:
     """Declares a single reference, `bar`, to the group 'g'."""
     bar = xs.group("g")
Beispiel #17
0
class TestSimlabAccessor:
    """Tests for the ``xsimlab`` dataset accessor (``SimlabAccessor``)."""

    # Attribute keys taken from the accessor class. The tests use them to
    # tag coordinates as (main) clocks and to read output-variable attrs.
    _clock_key = xr_accessor.SimlabAccessor._clock_key
    _main_clock_key = xr_accessor.SimlabAccessor._main_clock_key
    _output_vars_key = xr_accessor.SimlabAccessor._output_vars_key

    def test_clock_coords(self):
        """Only coordinates tagged with the clock key are listed as clocks."""
        main_attrs = {self._clock_key: 1, self._main_clock_key: 1}
        snapshot_attrs = {self._clock_key: 1}
        coords = {
            "mclock": ("mclock", [0, 1, 2], main_attrs),
            "sclock": ("sclock", [0, 2], snapshot_attrs),
            "no_clock": ("no_clock", [3, 4]),
        }
        ds = xr.Dataset(coords=coords)

        found = set(ds.xsimlab.clock_coords)
        assert found == {"mclock", "sclock"}

    def test_main_clock_coords(self):
        """`main_clock_coord` is the coordinate tagged as main clock."""
        main_attrs = {self._clock_key: 1, self._main_clock_key: 1}
        snapshot_attrs = {self._clock_key: 1}
        coords = {
            "mclock": ("mclock", [0, 1, 2], main_attrs),
            "sclock": ("sclock", [0, 2], snapshot_attrs),
            "no_clock": ("no_clock", [3, 4]),
        }
        ds = xr.Dataset(coords=coords)

        xr.testing.assert_equal(ds.xsimlab.main_clock_coord, ds.mclock)

    def test_master_clock_coords_warning(self):
        """Accessing `master_clock_coord` emits a FutureWarning."""
        main_attrs = {self._clock_key: 1, self._main_clock_key: 1}
        snapshot_attrs = {self._clock_key: 1}
        coords = {
            "mclock": ("mclock", [0, 1, 2], main_attrs),
            "sclock": ("sclock", [0, 2], snapshot_attrs),
            "no_clock": ("no_clock", [3, 4]),
        }
        ds = xr.Dataset(coords=coords)

        expected_msg = (
            "master_clock_coord is to be deprecated in favour of main_clock"
        )
        with pytest.warns(FutureWarning, match=expected_msg):
            # The attribute access alone must trigger the warning.
            ds.xsimlab.master_clock_coord

    def test_clock_sizes(self):
        """`clock_sizes` maps each clock coordinate to its length."""
        coords = {
            "clock1": ("clock1", [0, 1, 2], {self._clock_key: 1}),
            "clock2": ("clock2", [0, 2], {self._clock_key: 1}),
            "no_clock": ("no_clock", [3, 4]),
        }
        ds = xr.Dataset(coords=coords)

        expected_sizes = {"clock1": 3, "clock2": 2}
        assert ds.xsimlab.clock_sizes == expected_sizes

    def test_main_clock_dim(self):
        """`main_clock_dim` returns the main clock name, or None without one."""
        clock_attrs = {self._clock_key: 1, self._main_clock_key: 1}
        coords = {"clock": ("clock", [1, 2], clock_attrs)}
        ds = xr.Dataset(coords=coords)

        # First access computes the value, the private attribute then holds
        # the cached result, and the third access reads it back.
        assert ds.xsimlab.main_clock_dim == "clock"
        assert ds.xsimlab._main_clock_dim == "clock"
        assert ds.xsimlab.main_clock_dim == "clock"

        empty_ds = xr.Dataset()
        assert empty_ds.xsimlab.main_clock_dim is None

    def test_master_clock_dim_warning(self):
        """Every access to `master_clock_dim` emits a FutureWarning."""
        expected_msg = "master_clock is to be deprecated in favour of main_clock"

        clock_attrs = {self._clock_key: 1, self._main_clock_key: 1}
        ds = xr.Dataset(coords={"clock": ("clock", [1, 2], clock_attrs)})

        with pytest.warns(FutureWarning, match=expected_msg):
            assert ds.xsimlab.master_clock_dim == "clock"

        # internally, _main_clock_dim is used (cache)
        assert ds.xsimlab._main_clock_dim == "clock"

        # second access: the cached value is returned, still with a warning
        with pytest.warns(FutureWarning, match=expected_msg):
            assert ds.xsimlab.master_clock_dim == "clock"

        empty_ds = xr.Dataset()
        with pytest.warns(FutureWarning, match=expected_msg):
            assert empty_ds.xsimlab.master_clock_dim is None

    def test_nsteps(self):
        """Three main-clock values give 2 steps; an empty dataset gives 0."""
        clock_attrs = {self._clock_key: 1, self._main_clock_key: 1}
        coords = {"clock": ("clock", [1, 2, 3], clock_attrs)}
        ds = xr.Dataset(coords=coords)
        assert ds.xsimlab.nsteps == 2

        empty_ds = xr.Dataset()
        assert empty_ds.xsimlab.nsteps == 0

    def test_get_output_save_steps(self):
        """Each snapshot clock becomes a boolean mask along the main clock."""
        attrs = {self._clock_key: 1, self._main_clock_key: 1}
        main_clock = ("clock", [0, 1, 2, 3, 4], attrs)

        ds = xr.Dataset(
            coords={
                "clock": main_clock,
                "clock1": ("clock1", [0, 2, 4], {self._clock_key: 1}),
                "clock2": ("clock2", [0, 4], {self._clock_key: 1}),
            }
        )

        # True where a main clock value is also a value of the snapshot clock.
        masks = {
            "clock1": ("clock", [True, False, True, False, True]),
            "clock2": ("clock", [True, False, False, False, True]),
        }
        expected = xr.Dataset(coords={"clock": main_clock}, data_vars=masks)

        actual = ds.xsimlab.get_output_save_steps()
        xr.testing.assert_identical(actual, expected)

    def test_set_input_vars(self, model, in_dataset):
        """Check `_set_input_vars`: valid inputs, invalid keys, implicit dims."""
        in_vars = {
            ("init_profile", "n_points"): 5,
            ("roll", "shift"): 1,
            ("add", "offset"): ("clock", [1, 2, 3, 4, 5]),
        }

        ds = xr.Dataset(coords={"clock": [0, 2, 4, 6, 8]})
        ds.xsimlab._set_input_vars(model, in_vars)

        names = ("init_profile__n_points", "roll__shift", "add__offset")
        for name in names:
            # xr.testing.assert_identical also checks attrs of coordinates
            # (not needed here), hence values and attrs are compared apart.
            xr.testing.assert_equal(ds[name], in_dataset[name])
            assert ds[name].attrs == in_dataset[name].attrs

        # an unknown (process, variable) key is rejected
        in_vars[("not_an", "input_var")] = None
        with pytest.raises(KeyError) as excinfo:
            ds.xsimlab._set_input_vars(model, in_vars)
        assert "not valid key(s)" in str(excinfo.value)

        # a bare list gets its dimension label implicitly
        implicit_vars = {("add", "offset"): [1, 2, 3, 4, 5]}
        ds.xsimlab._set_input_vars(model, implicit_vars)
        assert ds["add__offset"].dims == ("x", )

        # ... which fails when no dimension label can be determined
        bad_vars = {("roll", "shift"): [1, 2]}
        with pytest.raises(TypeError,
                           match=r"Could not get dimension labels.*"):
            ds.xsimlab._set_input_vars(model, bad_vars)

    def test_update_clocks(self, model):
        """Check `update_clocks`: error cases first, then valid updates.

        Each error case starts from a fresh empty dataset. The valid cases
        build on one another, so their order matters.
        """
        # no clock given at all
        ds = xr.Dataset()
        with pytest.raises(ValueError, match="Cannot determine which clock.*"):
            ds.xsimlab.update_clocks(model=model, clocks={})

        # two clocks given but no main clock specified
        ds = xr.Dataset()
        with pytest.raises(ValueError, match="Cannot determine which clock.*"):
            ds.xsimlab.update_clocks(model=model,
                                     clocks={
                                         "clock": [0, 1, 2],
                                         "out": [0, 2]
                                     })

        # main clock name that is not one of the given clocks
        ds = xr.Dataset()
        with pytest.raises(KeyError, match="Main clock dimension name.*"):
            ds.xsimlab.update_clocks(
                model=model,
                clocks={"clock": [0, 1, 2]},
                main_clock="non_existing_clock_dim",
            )

        # clock data given with a dimension label other than the clock name
        ds = xr.Dataset()
        with pytest.raises(ValueError, match="Invalid dimension.*"):
            ds.xsimlab.update_clocks(
                model=model,
                clocks={"clock": ("x", [0, 1, 2])},
            )

        # "out" holds a value (0.5) that is not a main clock value
        ds = xr.Dataset()
        with pytest.raises(ValueError, match=".*not synchronized.*"):
            ds.xsimlab.update_clocks(
                model=model,
                clocks={
                    "clock": [0, 1, 2],
                    "out": [0, 0.5, 2]
                },
                main_clock="clock",
            )

        # a single clock becomes the main clock
        ds = xr.Dataset()
        ds = ds.xsimlab.update_clocks(model=model, clocks={"clock": [0, 1, 2]})
        assert ds.xsimlab.main_clock_dim == "clock"

        # set an output-variable attribute to check below that it survives
        ds.clock.attrs[self._output_vars_key] = "profile__u"

        # main clock given as a dict with extra units/calendar entries
        ds = ds.xsimlab.update_clocks(
            model=model,
            clocks={"clock": [0, 1, 2]},
            main_clock={
                "dim": "clock",
                "units": "days since 1-1-1 0:0:0",
                "calendar": "365_days",
            },
        )
        np.testing.assert_array_equal(ds.clock.values, [0, 1, 2])
        assert "units" in ds.clock.attrs
        assert "calendar" in ds.clock.attrs
        assert ds.clock.attrs[self._output_vars_key] == "profile__u"

        # a new clock can take over as main clock
        new_ds = ds.xsimlab.update_clocks(
            model=model,
            clocks={"clock2": [0, 0.5, 1, 1.5, 2]},
            main_clock="clock2",
        )
        assert new_ds.xsimlab.main_clock_dim == "clock2"

        # adding another clock without main_clock keeps the existing one
        new_ds = ds.xsimlab.update_clocks(model=model, clocks={"out2": [0, 2]})
        assert new_ds.xsimlab.main_clock_dim == "clock"

        # new values for the main clock replace the old ones
        new_ds = ds.xsimlab.update_clocks(model=model,
                                          clocks={"clock": [0, 2, 4]})
        assert new_ds.xsimlab.main_clock_dim == "clock"
        np.testing.assert_array_equal(new_ds.clock.values, [0, 2, 4])

    def test_update_clocks_master_clock_warning(self, model):
        """The `master_clock` argument still works but emits a FutureWarning."""
        expected_msg = "master_clock is to be deprecated in favour of main_clock"

        ds = xr.Dataset()
        ds = ds.xsimlab.update_clocks(model=model, clocks={"clock": [0, 1, 2]})
        assert ds.xsimlab.main_clock_dim == "clock"

        ds.clock.attrs[self._output_vars_key] = "profile__u"

        # assert that a warning is raised with correct use of update master clock
        clock_spec = {
            "dim": "clock",
            "units": "days since 1-1-1 0:0:0",
            "calendar": "365_days",
        }
        with pytest.warns(FutureWarning, match=expected_msg):
            ds = ds.xsimlab.update_clocks(
                model=model,
                clocks={"clock": [0, 1, 2]},
                master_clock=clock_spec,
            )

        np.testing.assert_array_equal(ds.clock.values, [0, 1, 2])
        clock_attrs = ds.clock.attrs
        assert "units" in clock_attrs
        assert "calendar" in clock_attrs
        assert clock_attrs[self._output_vars_key] == "profile__u"

        # same when the deprecated argument is given as a plain name
        with pytest.warns(FutureWarning, match=expected_msg):
            new_ds = ds.xsimlab.update_clocks(
                model=model,
                clocks={"clock2": [0, 0.5, 1, 1.5, 2]},
                master_clock="clock2",
            )
        assert new_ds.xsimlab.main_clock_dim == "clock2"

    def test_update_vars(self, model, in_dataset):
        """`update_vars` changes both the given input and the output setup."""
        new_inputs = {("roll", "shift"): 2}
        new_outputs = {("profile", "u"): "out"}

        ds = in_dataset.xsimlab.update_vars(
            model=model,
            input_vars=new_inputs,
            output_vars=new_outputs,
        )

        assert not ds["roll__shift"].equals(in_dataset["roll__shift"])
        assert not ds["out"].identical(in_dataset["out"])

    def test_update_vars_promote_to_coords(self, model, in_dataset):
        """An input whose dimension label equals its name becomes a coord."""
        # It should be possible to update an input variable with a dimension
        # label that corresponds to its name (turned into a coordinate). This
        # should not raise any merge conflict error.
        promoted = {"roll__shift": ("roll__shift", [1, 2])}

        ds = in_dataset.xsimlab.update_vars(model=model, input_vars=promoted)

        assert "roll__shift" in ds.coords

    def test_reset_vars(self, model, in_dataset):
        """`reset_vars` sets `roll__shift` to 2, whether or not it existed."""
        # variable missing from the dataset: it gets added
        fresh_ds = xr.Dataset().xsimlab.reset_vars(model)
        assert fresh_ds["roll__shift"] == 2

        # variable already in the dataset: it gets overwritten
        reset_ds = in_dataset.xsimlab.reset_vars(model)
        assert reset_ds["roll__shift"] == 2

    def test_filter_vars(self, simple_model, in_dataset):
        """`filter_vars` drops what is not in the model, leaving the input intact."""
        key = self._output_vars_key
        in_dataset["not_a_xsimlab_model_input"] = 1

        filtered_ds = in_dataset.xsimlab.filter_vars(model=simple_model)

        for dropped in ("add__offset", "not_a_xsimlab_model_input"):
            assert dropped not in filtered_ds
        assert sorted(filtered_ds.xsimlab.clock_coords) == ["clock", "out"]
        assert filtered_ds.out.attrs[key] == "roll__u_diff"

        # test unchanged attributes in original dataset
        assert in_dataset.out.attrs[key] == "roll__u_diff,add__u_diff"
        assert in_dataset.attrs[key] == "profile__u_opp"

    def test_set_output_vars(self, model):
        """Check `_set_output_vars`: attrs written, errors and deprecated form."""
        key = self._output_vars_key
        main_attrs = {self._clock_key: 1, self._main_clock_key: 1}

        ds = xr.Dataset()
        ds["clock"] = ("clock", [0, 2, 4, 6, 8], main_attrs)
        ds["out"] = ("out", [0, 4, 8], {self._clock_key: 1})
        ds["not_a_clock"] = ("not_a_clock", [0, 1])

        # unknown (process, variable) key
        with pytest.raises(KeyError, match=r".*not valid key.*"):
            ds.xsimlab._set_output_vars(model, {("invalid", "var"): None})

        # clock None: the variable name goes to the dataset attributes
        ds.xsimlab._set_output_vars(model, {("profile", "u_opp"): None})
        assert ds.attrs[key] == "profile__u_opp"

        # named clock: the names go to that clock's attributes, comma-joined
        on_out_clock = {("roll", "u_diff"): "out", ("add", "u_diff"): "out"}
        ds.xsimlab._set_output_vars(model, on_out_clock)
        expected = "roll__u_diff,add__u_diff"
        assert ds["out"].attrs[key] == expected

        # a coordinate that is not tagged as a clock is rejected
        with pytest.raises(ValueError, match=r".not a valid clock.*"):
            ds.xsimlab._set_output_vars(model,
                                        {("profile", "u"): "not_a_clock"})

        # the {clock: variable} form only triggers a FutureWarning
        for deprecated_form in ({None: ("profile", "u_opp")},
                                {"out": ("profile", "u_opp")}):
            with pytest.warns(FutureWarning):
                ds.xsimlab._set_output_vars(model, deprecated_form)

    @pytest.mark.parametrize(
        "field",
        [
            xs.any_object(),
            xs.group("g"),
            xs.group_dict("g"),
        ],
    )
    def test_set_output_object_or_group_vars(self, field):
        """Object, group and group_dict variables can't be set as outputs."""
        @xs.process
        class P:
            var = field

        m = xs.Model({"p": P})
        ds = xr.Dataset()
        requested = {("p", "var"): None}

        with pytest.raises(ValueError,
                           match=r"Object or group variables can't be set.*"):
            ds.xsimlab._set_output_vars(m, requested)

    def test_output_vars(self, model):
        """`output_vars` returns the mapping that was passed to `create_setup`."""
        o_vars = {
            ("profile", "u_opp"): None,
            ("profile", "u"): "clock",
            ("roll", "u_diff"): "out",
            ("add", "u_diff"): "out",
        }
        clocks = {
            "clock": [0, 2, 4, 6, 8],
            "out": [0, 4, 8],
            # snapshot clock with no output variable
            "out2": [0, 8],
        }

        ds = xs.create_setup(
            model=model,
            clocks=clocks,
            main_clock="clock",
            output_vars=o_vars,
        )

        assert ds.xsimlab.output_vars == o_vars

    def test_output_vars_by_clock(self, model):
        """`output_vars_by_clock` groups the output variables per clock."""
        o_vars = {("roll", "u_diff"): "clock", ("add", "u_diff"): None}
        ds = xs.create_setup(
            model=model,
            clocks={"clock": [0, 2, 4, 6, 8]},
            output_vars=o_vars,
        )

        # Variables saved without a clock are listed under the None key.
        expected = {
            "clock": [("roll", "u_diff")],
            None: [("add", "u_diff")],
        }
        assert ds.xsimlab.output_vars_by_clock == expected

    def test_run(self, model, in_dataset, out_dataset, parallel, scheduler):
        """Running the model on `in_dataset` yields `out_dataset`."""
        @xs.process
        class ProfileFix(Profile):
            # limitation of using distributed for single-model parallelism
            # internal instance attributes created and used in multiple stage
            # methods are not supported.
            u_change = xs.any_object()

        m = model.update_processes({"profile": ProfileFix})

        run_options = dict(model=m, parallel=parallel, scheduler=scheduler)
        out_ds = in_dataset.xsimlab.run(**run_options)

        xr.testing.assert_equal(out_ds.load(), out_dataset)

    def test_run_safe_mode(self, model, in_dataset):
        """`safe_mode` decides whether the model passed to `run` is touched."""
        # safe mode True: ensure model is cloned (empty state)
        in_dataset.xsimlab.run(model=model, safe_mode=True)
        assert model.state == {}

        # safe mode False: model not cloned (non empty state)
        in_dataset.xsimlab.run(model=model, safe_mode=False)
        assert model.state != {}

    def test_run_check_dims(self):
        """Check the `check_dims` option of `run` (None, strict, transpose).

        The input is given with dimensions ('y', 'x') while the variable
        declares 'x' or ('x', 'y').
        """
        @xs.process
        class P:
            var = xs.variable(dims=["x", ("x", "y")])

        m = xs.Model({"p": P})

        arr = np.array([[1, 2], [3, 4]])

        in_ds = xs.create_setup(
            model=m,
            clocks={"clock": [1, 2]},
            input_vars={"p__var": (("y", "x"), arr)},
            output_vars={"p__var": None},
        )

        # no dimension check: the run succeeds and the data comes back as given
        out_ds = in_ds.xsimlab.run(model=m, check_dims=None)
        actual = out_ds.p__var.values
        np.testing.assert_array_equal(actual, arr)

        # strict check: ('y', 'x') is not among the declared dimensions
        with pytest.raises(ValueError, match=r"Invalid dimension.*"):
            in_ds.xsimlab.run(model=m, check_dims="strict")

        # transpose: safe_mode=False lets the state of `m` be inspected after
        # the run; it holds the transposed array, the output the original one
        out_ds = in_ds.xsimlab.run(model=m,
                                   check_dims="transpose",
                                   safe_mode=False)
        actual = out_ds.p__var.values
        np.testing.assert_array_equal(actual, arr)
        np.testing.assert_array_equal(m.p.var, arr.transpose())

        in_ds2 = in_ds.xsimlab.update_vars(model=m,
                                           output_vars={"p__var": "clock"})
        # TODO: fix update output vars time-independent -> dependent
        # currently need the workaround below
        in_ds2.attrs = {}

        # same check with the variable saved along the clock (last snapshot)
        out_ds = in_ds2.xsimlab.run(model=m, check_dims="transpose")
        actual = out_ds.p__var.isel(clock=-1).values
        np.testing.assert_array_equal(actual, arr)

    def test_run_validate(self, model, in_dataset):
        """Check the `validate` option of `run` (None, 'inputs', 'all').

        A non-integer `roll__shift` (2.5) is fed in three ways: as a static
        input, as a time-varying input, and set by another process.
        """
        in_dataset["roll__shift"] = 2.5

        # no input validation -> raises within np.roll()
        with pytest.raises(TypeError,
                           match=r"slice indices must be integers.*"):
            in_dataset.xsimlab.run(model=model, validate=None)

        # input validation at initialization -> raises within attr.validate()
        with pytest.raises(TypeError, match=r".*'int'.*"):
            in_dataset.xsimlab.run(model=model, validate="inputs")

        # time-varying input: only the second clock value is invalid
        in_dataset["roll__shift"] = ("clock", [1, 2.5, 1, 1, 1])

        # input validation at runtime -> raises within attr.validate()
        with pytest.raises(TypeError, match=r".*'int'.*"):
            in_dataset.xsimlab.run(model=model, validate="inputs")

        # process that sets the invalid value itself, through a foreign
        # variable with intent="out"
        @xs.process
        class SetRollShift:
            shift = xs.foreign(Roll, "shift", intent="out")

            def initialize(self):
                self.shift = 2.5

        m = model.update_processes({"set_shift": SetRollShift})

        # no internal validation -> raises within np.roll()
        with pytest.raises(TypeError,
                           match=r"slice indices must be integers.*"):
            in_dataset.xsimlab.run(model=m, validate="inputs")

        # internal validation -> raises within attr.validate()
        with pytest.raises(TypeError, match=r".*'int'.*"):
            in_dataset.xsimlab.run(model=m, validate="all")

    @pytest.mark.parametrize(
        "dims,data,clock",
        [
            ("batch", [1, 2], None),
            (("batch", "clock"), [[1, 1, 1], [2, 2, 2]], "clock"),
            (("batch", "x"), [[1, 1], [2, 2]], None),
        ],
    )
    def test_run_batch_dim(self, dims, data, clock, parallel, scheduler):
        """With `batch_dim`, the output is twice the input in every layout."""
        @xs.process
        class P:
            in_var = xs.variable(dims=[(), "x"])
            out_var = xs.variable(dims=[(), "x"], intent="out")
            idx_var = xs.index(dims="x")

            def initialize(self):
                self.idx_var = [0, 1]

            def run_step(self):
                self.out_var = self.in_var * 2

        m = xs.Model({"p": P})

        in_ds = xs.create_setup(
            model=m,
            clocks={"clock": [0, 1, 2]},
            input_vars={"p__in_var": (dims, data)},
            output_vars={"p__out_var": clock},
        )

        out_ds = in_ds.xsimlab.run(
            model=m,
            batch_dim="batch",
            parallel=parallel,
            scheduler=scheduler,
            store=zarr.TempStore(),
        )

        # The clock coordinate is only expected when the output is saved
        # along the clock.
        coords = {} if clock is None else {"clock": in_ds["clock"]}
        doubled = xr.DataArray(data, dims=dims, coords=coords) * 2

        xr.testing.assert_equal(out_ds["p__out_var"], doubled)

    @pytest.mark.parametrize(
        "decoding,expected",
        [
            # mask_and_scale=True by default: fill values are read as nan
            (None, [nan, nan]),
            # without masking, the raw fill values come back
            ({"mask_and_scale": False}, [-1, -1]),
        ],
    )
    def test_run_decoding(self, decoding, expected):
        """The `decoding` option controls how fill values are read back."""
        @xs.process
        class P:
            var = xs.variable(
                dims="x",
                intent="out",
                encoding={"fill_value": -1},
            )

            def initialize(self):
                # every value equals the declared fill value
                self.var = [-1, -1]

        m = xs.Model({"p": P})
        in_ds = xs.create_setup(
            model=m,
            clocks={"clock": [0, 1]},
            output_vars={"p__var": None},
        )

        out_ds = in_ds.xsimlab.run(model=m, decoding=decoding)

        np.testing.assert_array_equal(out_ds.p__var, expected)