Ejemplo n.º 1
0
 def __init__(self, getter):
     """Store the getter and build the mapper used for item lookups.

     Args:
         getter: the fv3gfs object or a mock of it.
     """
     self._getter = getter
     # Wrap the raw getter first, then overlay the DerivedMapping on top.
     state_mapper = FV3StateMapper(getter, alternate_keys=None)
     self._mapper = DerivedMapping(state_mapper)
Ejemplo n.º 2
0
def test_DerivedMapping_dataset():
    """``dataset`` returns an ``xr.Dataset`` whose existing variables match ``ds``."""
    passthrough_names = ["T", "q"]
    requested = [*passthrough_names, "cos_zenith_angle"]
    result = DerivedMapping(ds).dataset(requested)
    assert isinstance(result, xr.Dataset)
    # Variables already present in ``ds`` must come back unchanged.
    for name in passthrough_names:
        np.testing.assert_array_almost_equal(result[name], ds[name])
Ejemplo n.º 3
0
def add_derived_data(variables: Sequence[str], ds: xr.Dataset) -> xr.Dataset:
    """
    Overlay the DerivedMapping on ``ds`` and grab a dataset of specified variables

    Args:
        variables: All variables (derived and non-derived) to include in the
            dataset.
        ds: Dataset that the DerivedMapping is built on.

    Returns:
        The result of ``DerivedMapping(ds).dataset(variables)``.
    """
    return DerivedMapping(ds).dataset(variables)
Ejemplo n.º 4
0
def test_find_all_required_inputs(dependency_map, derived_vars, reqs):
    """Register stub variables, then check the resolved set of required inputs."""
    for derived_name, inputs in dependency_map.items():
        # Each entry is registered with a do-nothing implementation; only its
        # declared ``required_inputs`` matter for this test.
        @DerivedMapping.register(derived_name, required_inputs=inputs)
        def var(self):
            return None

    found = DerivedMapping.find_all_required_inputs(derived_vars)
    # Same members as expected ...
    assert set(found) == set(reqs)
    # ... and the same length, so no duplicates were introduced.
    assert len(found) == len(reqs)
Ejemplo n.º 5
0
def test_net_downward_shortwave_sfc_flux_derived():
    """Albedos [0, 0.5, 1] with unit downward flux give net flux [1, 0.5, 0]."""
    albedo = xr.DataArray([0, 0.5, 1.0], dims=["x"])
    downward_flux = xr.DataArray([1.0, 1.0, 1.0], dims=["x"])
    inputs = xr.Dataset(
        {
            "surface_diffused_shortwave_albedo": albedo,
            "override_for_time_adjusted_total_sky_downward_shortwave_flux_at_surface": (
                downward_flux
            ),
        }
    )
    net_sw = DerivedMapping(inputs)["net_shortwave_sfc_flux_derived"]
    np.testing.assert_array_almost_equal(net_sw, [1.0, 0.5, 0.0])
Ejemplo n.º 6
0
def test_downward_shortwave_sfc_flux_via_transmissivity():
    """TOA flux [2, 1, 3] with transmissivity [0.5, 0.75, 1] gives [1, 0.75, 3]."""
    toa_flux = xr.DataArray([2.0, 1.0, 3.0], dims=["x"])
    transmissivity = xr.DataArray([0.5, 0.75, 1.0], dims=["x"])
    inputs = xr.Dataset(
        {
            "total_sky_downward_shortwave_flux_at_top_of_atmosphere": toa_flux,
            "shortwave_transmissivity_of_atmospheric_column": transmissivity,
        }
    )
    target = "downward_shortwave_sfc_flux_via_transmissivity"
    surface_flux = DerivedMapping(inputs)[target]
    np.testing.assert_array_almost_equal(surface_flux, [1.0, 0.75, 3.0])
Ejemplo n.º 7
0
def test_horizontal_wind_tendency_parallel_to_horizontal_wind(
    dQu, dQv, eastward, northward, projection
):
    """Check the derived wind-tendency projection against the expected value.

    Args:
        dQu: value used for the single-element ``dQu`` array.
        dQv: value used for the single-element ``dQv`` array.
        eastward: value used for the single-element ``eastward_wind`` array.
        northward: value used for the single-element ``northward_wind`` array.
        projection: expected value of
            ``horizontal_wind_tendency_parallel_to_horizontal_wind``.
    """
    data = xr.Dataset(
        {
            "dQu": xr.DataArray([dQu], dims=["x"]),
            "dQv": xr.DataArray([dQv], dims=["x"]),
            "eastward_wind": xr.DataArray([eastward], dims=["x"]),
            "northward_wind": xr.DataArray([northward], dims=["x"]),
        }
    )
    derived_mapping = DerivedMapping(data)
    result = derived_mapping[
        "horizontal_wind_tendency_parallel_to_horizontal_wind"
    ].values.item()
    # Compare with ``==``: asserting on a bare ``pytest.approx(value, projection)``
    # object never compares the two values, and it passes ``projection`` as the
    # relative tolerance instead of as the expected result.
    assert result == pytest.approx(projection)
Ejemplo n.º 8
0
class DerivedFV3State(MutableMapping):
    """A uniform mapping-like interface to the FV3GFS model state

    This class wraps the fv3gfs getters with the FV3StateMapper, that always returns
    DataArray and has time as an attribute (since this isn't a DataArray).

    This encapsulates callers from the details of Quantity: DataArrays passed
    to the setters are converted with ``_data_arrays_to_quantities`` before
    they are handed to the getter.
    """

    def __init__(self, getter):
        """
        Args:
            getter: the fv3gfs object or a mock of it.
        """
        self._getter = getter
        self._mapper = DerivedMapping(
            FV3StateMapper(getter, alternate_keys=None))

    @property
    def time(self) -> cftime.DatetimeJulian:
        """The getter's current ``"time"`` state as a rounded ``DatetimeJulian``."""
        state_time = self._getter.get_state(["time"])["time"]
        return round_time(cftime.DatetimeJulian(*state_time.timetuple()))

    def __getitem__(self, key: Hashable) -> xr.DataArray:
        """Look ``key`` up through the derived mapper."""
        return self._mapper[key]

    def __setitem__(self, key: str, value: xr.DataArray):
        """Set a single state variable via ``set_state_mass_conserving``.

        Raises:
            KeyError: if a ``ValueError`` is raised while converting or setting
                the value; the original error is chained as ``__cause__``.
        """
        state_update = _cast_single_to_double({key: value})
        try:
            self._getter.set_state_mass_conserving(
                _data_arrays_to_quantities(state_update))
        except ValueError as e:
            # Mapping callers expect KeyError; keep the ValueError as the cause.
            raise KeyError(e) from e

    def keys(self):
        """Return the keys exposed by the derived mapper."""
        return self._mapper.keys()

    def update_mass_conserving(
        self,
        items: State,
    ):
        """Update state from another mapping

        This may be faster than setting each item individually. Same as dict.update.

        All states except for pressure thicknesses are set in a mass-conserving fashion.

        Raises:
            KeyError: if a ``ValueError`` is raised while setting the
                non-pressure items; the original error is chained as
                ``__cause__``.
        """
        items_with_attrs = _cast_single_to_double(
            self._assign_attrs_from_mapper(items))

        # Pressure thickness is set directly with ``set_state``, before the
        # remaining items go through the mass-conserving setter.
        if DELP in items_with_attrs:
            self._getter.set_state(
                _data_arrays_to_quantities({DELP: items_with_attrs[DELP]}))

        not_pressure = dissoc(items_with_attrs, DELP)
        try:
            self._getter.set_state_mass_conserving(
                _data_arrays_to_quantities(not_pressure))
        except ValueError as e:
            # Mapping callers expect KeyError; keep the ValueError as the cause.
            raise KeyError(e) from e

    def _assign_attrs_from_mapper(self, dst: State) -> State:
        """Return a new dict of ``dst`` arrays carrying the mapper's attrs.

        Each array in ``dst`` gets the attrs of the same-named entry in the
        mapper, applied via ``assign_attrs``.
        """
        return {
            name: dst[name].assign_attrs(self._mapper[name].attrs)
            for name in dst
        }

    def __delitem__(self, key: str):
        """Deleting state variables is not supported."""
        raise NotImplementedError()

    def __iter__(self):
        return iter(self.keys())

    def __len__(self):
        return len(self.keys())
Ejemplo n.º 9
0
def test_DerivedMapping():
    """Indexing the mapping by ``"T"`` yields an ``xr.DataArray``."""
    temperature = DerivedMapping(ds)["T"]
    assert isinstance(temperature, xr.DataArray)
Ejemplo n.º 10
0
def test_wind_tendency_nonderived():
    """``dQu`` looked up through the mapping is approximately 1.0."""
    # dQu/dQv already exist in data
    wind_tendency = DerivedMapping(ds)["dQu"]
    np.testing.assert_array_almost_equal(wind_tendency, 1.0)
Ejemplo n.º 11
0
def test_rotated_winds(variable):
    """The requested derived ``variable`` is approximately zero for the fixture data."""
    winds = _dataset_with_d_grid_winds_and_tendencies()
    derived_value = DerivedMapping(winds)[variable]
    np.testing.assert_array_almost_equal(0.0, derived_value)
Ejemplo n.º 12
0
def test_is_sea_ice():
    """``is_sea_ice`` derived from ``ds_sfc`` is approximately [0, 0, 1]."""
    sea_ice_mask = DerivedMapping(ds_sfc)["is_sea_ice"]
    np.testing.assert_array_almost_equal(sea_ice_mask, [0.0, 0.0, 1.0])
Ejemplo n.º 13
0
def test_DerivedMapping_unregistered():
    """Looking up ``"latent_heat_flux"`` raises ``KeyError``."""
    mapping = DerivedMapping(ds)
    missing_key = "latent_heat_flux"
    with pytest.raises(KeyError):
        mapping[missing_key]
Ejemplo n.º 14
0
def test_DerivedMapping__data_arrays():
    """``_data_arrays`` returns a Mapping keyed by exactly the requested names."""
    requested = ["T", "q"]
    arrays = DerivedMapping(ds)._data_arrays(requested)
    assert isinstance(arrays, Mapping)
    assert set(requested) == set(arrays.keys())
Ejemplo n.º 15
0
def test_DerivedMapping_cos_zenith():
    """Indexing the mapping by ``"cos_zenith_angle"`` yields an ``xr.DataArray``."""
    cos_zenith = DerivedMapping(ds)["cos_zenith_angle"]
    assert isinstance(cos_zenith, xr.DataArray)