Пример #1
0
def test_construct_tripolar_grid_ncar(attr_fmt):
    """Check an NCAR-style tripolar grid against the default tripolar grid.

    ``attr_fmt`` is presumably supplied by a fixture/parametrization defined
    outside this view. The reference arrays normally come from
    ``pytest.tripolar_t``, which ``test_construct_tripolar_grid_t`` stores
    when run with ``retain_coords=True``. If that test has not run first
    (e.g. ``-k ncar``, random ordering, or distributed runs), an equivalent
    reference grid is built here instead of failing with AttributeError.
    """
    result = construct_tripolar_grid(attr_fmt=attr_fmt,
                                     add_attrs=True,
                                     retain_coords=True)
    assert sorted(result.coords) == ["nlat", "nlon"]
    assert sorted(result.dims) == ["nlat", "nlon"]
    varlist = ["depth", "lat", "lon", "mask", "nlat", "nlon", "wet"]
    assert sorted(result.variables) == varlist
    assert result.nlat.sum() == 666
    assert result.nlon.sum() == 2628

    reference = getattr(pytest, "tripolar_t", None)
    if reference is None:
        # Same call test_construct_tripolar_grid_t makes with retain_coords=True.
        reference = construct_tripolar_grid(retain_coords=True)

    # Field values must match the default (non-NCAR) grid, masks included.
    for name in ("depth", "mask", "wet"):
        assert np.allclose(result[name].to_masked_array(),
                           reference[name].to_masked_array())
Пример #2
0
def test_construct_tripolar_grid_c(retain_coords):
    """Tripolar grid with point_type="c": dims xq/yq, variables, checksums."""
    grid = construct_tripolar_grid(point_type="c",
                                   retain_coords=retain_coords)
    assert isinstance(grid, xr.Dataset)

    index_names = ["xq", "yq"]
    assert sorted(grid.coords) == index_names
    assert sorted(grid.dims) == index_names

    if retain_coords:
        expected_vars = ["geolat_c", "geolon_c", "mask", "wet_c", "xq", "yq"]
    else:
        expected_vars = ["mask", "xq", "yq"]
    assert sorted(grid.variables) == expected_vars

    # Exact comparisons for mask/index values; tolerance for the near-zero
    # yq sum and for the geographic sums below.
    assert grid.mask.sum() == 1426.0
    assert grid.xq.sum() == -7300.0
    assert np.allclose(grid.yq.sum(), 1.42108547e-14)
    assert grid.yq.min() == -90.0

    if not retain_coords:
        return

    assert np.allclose(grid.geolon_c.sum(), -270100.0)
    assert np.allclose(grid.geolat_c.sum(), -2184.3828)
    assert grid.wet_c.sum() == 1426.0
Пример #3
0
def test_construct_tripolar_grid_v(retain_coords):
    """Tripolar grid with point_type="v": dims xh/yq, variables, checksums."""
    ds = construct_tripolar_grid(point_type="v", retain_coords=retain_coords)
    assert isinstance(ds, xr.Dataset)

    axis_names = ["xh", "yq"]
    assert sorted(ds.coords) == axis_names
    assert sorted(ds.dims) == axis_names

    base_vars = ["mask", "xh", "yq"]
    full_vars = ["geolat_v", "geolon_v", "mask", "wet_v", "xh", "yq"]
    assert sorted(ds.variables) == (full_vars if retain_coords else base_vars)

    assert ds.mask.sum() == 1512.0
    assert ds.xh.sum() == -7200.0
    assert np.allclose(ds.yq.sum(), 1.42108547e-14)
    assert ds.yq.min() == -90.0

    if retain_coords:
        # geographic coordinate sums are float-valued, so use a tolerance
        assert np.allclose(ds.geolon_v.sum(), -266400.0)
        assert np.allclose(ds.geolat_v.sum(), -2106.3906)
        assert ds.wet_v.sum() == 1512.0
Пример #4
0
def test_construct_tripolar_grid_u(retain_coords):
    """Tripolar grid with point_type="u": dims xq/yh, variables, checksums."""
    out = construct_tripolar_grid(point_type="u",
                                  retain_coords=retain_coords)
    assert isinstance(out, xr.Dataset)

    expected_dims = ["xq", "yh"]
    assert sorted(out.coords) == expected_dims
    assert sorted(out.dims) == expected_dims

    expected_vars = (
        ["geolat_u", "geolon_u", "mask", "wet_u", "xq", "yh"]
        if retain_coords
        else ["mask", "xq", "yh"]
    )
    assert sorted(out.variables) == expected_vars

    assert out.mask.sum() == 1561.0
    assert out.xq.sum() == -7300.0
    assert np.allclose(out.yh.sum(), 1.42108547e-14)
    assert out.yh.min() == -87.5

    if not retain_coords:
        return
    assert np.allclose(out.geolon_u.sum(), -262800.0)
    assert np.allclose(out.geolat_u.sum(), -1744.9609)
    assert out.wet_u.sum() == 1561.0
Пример #5
0
def test_construct_tripolar_grid_t(retain_coords):
    """Default tripolar grid: dims xh/yh, variables, and checksum values.

    When ``retain_coords`` is true the resulting grid is stored on the
    ``pytest`` module as ``tripolar_t`` so other tests can reuse it as a
    reference.
    """
    grid = construct_tripolar_grid(retain_coords=retain_coords)
    assert isinstance(grid, xr.Dataset)

    dim_names = ["xh", "yh"]
    assert sorted(grid.coords) == dim_names
    assert sorted(grid.dims) == dim_names

    if retain_coords:
        expected_vars = ["depth", "geolat", "geolon", "mask", "wet", "xh",
                         "yh"]
    else:
        expected_vars = ["depth", "mask", "xh", "yh"]
    assert sorted(grid.variables) == expected_vars

    assert grid.depth.sum() == 5558217.0
    assert grid.mask.sum() == 1640.0
    assert grid.xh.sum() == -7200.0
    assert np.allclose(grid.yh.sum(), 1.42108547e-14)
    assert grid.yh.min() == -87.5

    if not retain_coords:
        return

    assert np.allclose(grid.geolon.sum(), -259200.0)
    assert np.allclose(grid.geolat.sum(), -1679.41633401)
    assert grid.wet.sum() == 1640.0
    # shared reference grid, read by test_construct_tripolar_grid_ncar
    pytest.tripolar_t = grid
Пример #6
0
# Global attributes added (each with an empty-string value) to CMIP-format
# synthetic datasets.
_CMIP_GLOBAL_ATTS = (
    "external_variables",
    "history",
    "table_id",
    "activity_id",
    "branch_method",
    "branch_time_in_child",
    "branch_time_in_parent",
    "comment",
    "contact",
    "Conventions",
    "creation_date",
    "data_specs_version",
    "experiment",
    "experiment_id",
    "forcing_index",
    "frequency",
    "further_info_url",
    "grid",
    "grid_label",
    "initialization_index",
    "institution",
    "institution_id",
    "license",
    "mip_era",
    "nominal_resolution",
    "parent_activity_id",
    "parent_experiment_id",
    "parent_mip_era",
    "parent_source_id",
    "parent_time_units",
    "parent_variant_label",
    "physics_index",
    "product",
    "realization_index",
    "realm",
    "source",
    "source_id",
    "source_type",
    "sub_experiment",
    "sub_experiment_id",
    "title",
    "tracking_id",
    "variable_id",
    "variant_info",
    "references",
    "variant_label",
)


def _synthetic_time_axis(timeres, startyear, nyears, fmt):
    """Return the time-axis dataset for ``timeres`` ("mon", "day", "3hr", "1hr").

    Raises
    ------
    ValueError
        If ``timeres`` is not one of the supported resolutions.
    """
    if timeres == "mon":
        return generate_monthly_time_axis(startyear, nyears, timefmt=fmt)
    if timeres == "day":
        return generate_daily_time_axis(startyear, nyears, timefmt=fmt)
    if timeres == "3hr":
        return generate_hourly_time_axis(startyear, nyears, 3, timefmt=fmt)
    if timeres == "1hr":
        return generate_hourly_time_axis(startyear, nyears, 1, timefmt=fmt)
    raise ValueError(f"Unknown time resolution requested: {timeres!r}")


def _synthetic_vertical_coord(dset, stats, fmt, grid):
    """Merge the vertical coordinate implied by ``stats`` into ``dset``.

    A vertical coordinate is added only when ``stats`` has more than one
    entry and ``fmt`` is "ncar", "gfdl" or "cmip".

    Returns
    -------
    tuple
        ``(dset, lev)``, where ``lev`` is the vertical coordinate, or None
        when no vertical coordinate was added.
    """
    if stats is None or len(stats) <= 1:
        return dset, None

    if fmt == "ncar":
        dset = dset.merge(ncar_hybrid_coord())
        return dset, dset.lev

    if fmt == "gfdl":
        if len(stats) == 19:
            dset = dset.merge(gfdl_plev19_vertical_coord())
            return dset, dset.plev19
        dset = dset.merge(gfdl_vertical_coord())
        return dset, dset.pfull

    if fmt == "cmip":
        if grid == "tripolar":
            dset = dset.merge(mom6_z_coord())
            lev = dset.lev
        else:
            dset = dset.merge(cmip_vertical_coord())
            lev = dset.plev
        # Report len(stats): the data array has not been generated yet here.
        assert len(stats) == len(lev), (
            f" Length of stats {len(stats)} must match number of levels "
            f"{len(lev)}."
        )
        return dset, lev

    return dset, None


def generate_synthetic_dataset(
    dlon,
    dlat,
    startyear,
    nyears,
    varname,
    timeres="mon",
    attrs=None,
    fmt="ncar",
    coords=None,
    generator="normal",
    generator_kwargs=None,
    stats=None,
    static=False,
    data=None,
    grid="standard",
):
    """Generate an xarray dataset of synthetic data.

    Parameters
    ----------
    dlon : float
        Grid spacing in the x-dimension (longitude); used only when
        ``grid`` is "standard"
    dlat : float
        Grid spacing in the y-dimension (latitude); used only when
        ``grid`` is "standard"
    startyear : int
        Start year for requested time axis
    nyears : int
        Number of years in requested time axis
    varname : str
        Variable name in output dataset
    timeres : str, optional
        Time resolution, one of "mon", "day", "3hr" or "1hr",
        by default "mon". Ignored when ``static`` is true.
    attrs : dict, optional
        Variable attributes, by default None
    fmt : str, optional
        Output convention ("ncar", "gfdl" or "cmip"), by default "ncar".
        Also stored in the dataset's "convention" global attribute.
    coords : dict, optional
        Extra variable to add, given as a dict with keys "name", "value"
        and "atts"; its name is recorded in the variable's "coordinates"
        attribute. By default None
    generator : str, optional
        Name of a function in the ``generators`` module, by default "normal"
    generator_kwargs : dict, optional
        Keyword arguments for the generator, by default None. The caller's
        dict is not modified.
    stats : tuple or list of tuples, optional
        Array statistics in the format of [(mean,stddev)]. More than one
        entry requests a vertical coordinate with one level per entry.
    static : bool, optional
        Flag denoting if variable is static (no time axis), by default False
    data : numpy.ndarray, optional
        Pre-computed data used instead of calling the generator
    grid : str, optional
        Type of output grid, either "standard" or "tripolar",
        by default "standard"

    Returns
    -------
    xarray.Dataset
        Dataset of synthetic data

    Raises
    ------
    ValueError
        If ``timeres`` is unknown, or the data has a level dimension but no
        vertical coordinate was generated.
    AssertionError
        If ``generator`` is unknown or level counts do not match.
    """
    attrs = {} if attrs is None else attrs

    # Step 1: set up the horizontal grid
    if grid == "tripolar":
        dset = construct_tripolar_grid(attr_fmt=fmt,
                                       retain_coords=True,
                                       add_attrs=True)
        xyshape = dset["mask"].shape
        # NCAR-style tripolar grids are indexed by nlat/nlon, others by yh/xh
        lat = dset["nlat" if "nlat" in dset.variables else "yh"]
        lon = dset["nlon" if "nlon" in dset.variables else "xh"]
    else:
        dset = construct_rect_grid(dlon,
                                   dlat,
                                   add_attrs=True,
                                   attr_fmt=fmt,
                                   bounds=(fmt == "cmip"))
        lat = dset.lat
        lon = dset.lon
        xyshape = (len(dset["lat"]), len(dset["lon"]))

    # Step 2: set up the time axis
    time = None
    if static:
        ntimes = 1
    else:
        dset = _synthetic_time_axis(timeres, startyear, nyears, fmt).merge(dset)
        time = dset["time"]
        ntimes = len(time)

    # Step 3: generate the vertical coordinate
    if stats is not None and not isinstance(stats, list):
        stats = [stats]
    dset, lev = _synthetic_vertical_coord(dset, stats, fmt, grid)

    # Step 4: define the synthetic data generator kernel
    # Copy so injecting "stats" does not mutate the caller's dict.
    generator_kwargs = ({} if generator_kwargs is None
                        else dict(generator_kwargs))
    if stats is not None:
        generator_kwargs["stats"] = stats

    assert generator in generators.__dict__, \
        f"Unknown generator method: {generator}"
    generator_func = generators.__dict__[generator]

    # Step 5: generate the synthetic data array
    if data is None:
        data = generators.generate_random_array(
            xyshape,
            ntimes,
            generator=generator_func,
            generator_kwargs=generator_kwargs,
        )
    data = data.squeeze()

    # Step 6: convert to Xarray DataArray by assigning coords
    mask = dset["mask"].values if "mask" in dset.variables else 1.0
    data = np.array(data * mask, dtype=np.float32)

    # After squeeze(), static data (ntimes == 1) has no time axis, so a level
    # axis shows up as ndim 3 for static data and ndim 4 for time-varying
    # data. NOTE(review): assumes generate_random_array returns
    # (ntimes, [nlev,] ny, nx) - confirm.
    if data.ndim == (3 if static else 4):
        if lev is None:
            raise ValueError(
                f"Data for {varname!r} has a level dimension but no vertical "
                "coordinate was generated; pass `stats` with one entry per "
                "level and fmt 'ncar', 'gfdl' or 'cmip'."
            )
        level_axis = 0 if static else 1
        assert data.shape[level_axis] == len(lev), (
            f" Length of stats {data.shape[level_axis]} must match number "
            f"of levels {len(lev)}."
        )
        da_coords = (lev, lat, lon)
    else:
        da_coords = (lat, lon)
    if not static:
        da_coords = (time, *da_coords)
    dset[varname] = xr.DataArray(data, coords=da_coords, attrs=attrs)
    # NOTE(review): lat/lon are not promoted to coordinates. On tripolar
    # grids they are plain variables; `dset = dset.set_coords(...)` would
    # change the output layout - confirm with downstream users first.

    if coords is not None:
        dset[coords["name"]] = xr.DataArray(coords["value"],
                                            attrs=coords["atts"])
        dset[varname].attrs = {
            **dset[varname].attrs, "coordinates": coords["name"]
        }

    dset.attrs["convention"] = fmt

    if fmt == "cmip":
        if "bnds" in dset.variables:
            dset["bnds"].attrs = {"long_name": "vertex number"}
        dset.attrs = {**dset.attrs, **dict.fromkeys(_CMIP_GLOBAL_ATTS, "")}

    # remove grid helper fields that are not part of the synthetic output
    if grid == "tripolar":
        dset = dset.drop_vars(["mask", "wet", "depth"])

    return dset