Beispiel #1
0
def test_generate_grid_ds():
    # simple case...just the dims, then dims together with coordinates
    dims_by_axis = {'X': 'lon', 'Y': 'lat', 'Z': 'z'}
    coords_by_axis = {'X': 'llon', 'Y': 'llat', 'Z': 'zz'}

    # Case 1: only the dimensions are staggered, so the reference dataset
    # must not contain the staggered coordinate variables.
    source = ds_original.copy()
    dims_only = generate_grid_ds(
        source,
        dims_by_axis,
        boundary_discontinuity={'lon': 360, 'lat': 180},
        pad={'z': 'auto'},
    )
    expected_dims_only = ds_out_left.drop(['llon_left', 'llat_left', 'zz_left'])
    assert_equal(dims_only, expected_dims_only)

    # Case 2: dimensions and coordinates are staggered together.
    # TODO why are they not identical ? assert identical fails
    with_coords = generate_grid_ds(
        ds_original,
        dims_by_axis,
        coords_by_axis,
        boundary_discontinuity={'lon': 360, 'lat': 180, 'llon': 360, 'llat': 180},
        pad={'z': 'auto', 'zz': 'auto'},
    )
    assert_equal(with_coords, ds_out_left)
Beispiel #2
0
def recreate_grid_simple(ds, lon_name="lon", lat_name="lat"):
    """Stagger the X/Y axes of ``ds`` and recreate its horizontal metrics.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset with tracer-point dimensions ``x`` and ``y``.
    lon_name, lat_name : str, optional
        Currently unused; kept for backward compatibility.

    Returns
    -------
    The result of ``recreate_metrics`` applied to the staggered dataset.
    NOTE(review): elsewhere ``recreate_metrics`` is unpacked into
    ``(dataset, metrics_dict)``; confirm which form callers expect here.
    """
    # Add the staggered ``x_right``/``y_right`` axes next to ``x``/``y``.
    ds_full = generate_grid_ds(ds, {"X": "x", "Y": "y"},
                               position=("center", "right"))
    grid = Grid(ds_full, periodic=["X"])
    # The metrics have to be built on the dataset the grid was created from;
    # the original ``ds`` lacks the staggered dimensions ``grid`` refers to.
    return recreate_metrics(ds_full, grid)
def load_gos_data(gos_filenames):
    """Load satellite geostrophic-velocity files and derive relative vorticity.

    The files are opened with ``xr.open_mfdataset``; ``latitude`` and
    ``longitude`` are renamed to ``lat`` and ``lon``; and an xgcm grid with a
    periodic X axis is generated, following
    https://xgcm.readthedocs.io/en/latest/autogenerate_examples.html

    Parameters
    ----------
    gos_filenames
        Path(s) or glob pattern accepted by ``xr.open_mfdataset``. The data
        must contain the geostrophic velocities ``ugos`` and ``vgos``.

    Returns
    -------
    xr.Dataset
        The combined dataset, with:

        - the staggered ``lon_left``/``lat_left`` coordinates;
        - the distance coordinates ``dxg``, ``dyg``, ``dxc`` and ``dyc``;
        - the variables ``dv_dx``, ``du_dy``, ``Rel_Vort`` (relative
          vorticity) and ``Ro`` (vorticity Rossby number).
    """
    from xgcm import Grid
    from xgcm.autogenerate import generate_grid_ds

    # ====== load in all .nc files and combine into one xarray dataset
    gos_map = xr.open_mfdataset(gos_filenames)
    # Optionally subset here, e.g. .sel(time=..., lon=slice(...), lat=slice(...))
    gos_select = gos_map.rename({'latitude': 'lat', 'longitude': 'lon'})

    # create grid for interpolation, differencing (satellite data has only
    # cell centers, so the cell faces are generated)
    ds_full = generate_grid_ds(gos_select, {'X': 'lon', 'Y': 'lat'})
    grid = Grid(ds_full, periodic=['X'])

    # compute the difference (in degrees) along the longitude and latitude for
    # both the cell center and the cell face; the boundary_discontinuity avoids
    # artefacts where longitude wraps around at the boundary
    dlong = grid.diff(ds_full.lon, 'X', boundary_discontinuity=360)
    dlonc = grid.diff(ds_full.lon_left, 'X', boundary_discontinuity=360)
    dlatg = grid.diff(ds_full.lat, 'Y', boundary='fill', fill_value=np.nan)
    dlatc = grid.diff(ds_full.lat_left,
                      'Y',
                      boundary='fill',
                      fill_value=np.nan)

    # converted into approximate cartesian distances on a globe.
    ds_full.coords['dxg'], ds_full.coords['dyg'] = dll_dist(
        dlong, dlatg, ds_full.lon, ds_full.lat)
    ds_full.coords['dxc'], ds_full.coords['dyc'] = dll_dist(
        dlonc, dlatc, ds_full.lon, ds_full.lat)

    # Relative vorticity: ζ = ∂ v/∂ x – ∂ u/∂ y
    ds_full['dv_dx'] = grid.diff(ds_full.vgos, 'X') / ds_full.dxg
    ds_full['du_dy'] = grid.diff(
        ds_full.ugos, 'Y', boundary='fill', fill_value=np.nan) / ds_full.dyg
    # get dv_dx and du_dy on the same (X-left, Y-left) grid before subtracting
    dv_dx = grid.interp(ds_full['dv_dx'],
                        'Y',
                        boundary='fill',
                        fill_value=np.nan)
    du_dy = grid.interp(ds_full['du_dy'],
                        'X',
                        boundary='fill',
                        fill_value=np.nan)
    ds_full['Rel_Vort'] = dv_dx - du_dy

    # Vorticity Rossby Number = ζ / f
    ds_full['Ro'] = ds_full.Rel_Vort / coriolis(ds_full.Rel_Vort.lat_left)

    return ds_full
Beispiel #4
0
def test_generate_grid_ds():
    # This needs more cases
    # Positional arguments after the dataset, in order: axis->dim mapping,
    # axis->coord mapping, position, boundary_discontinuity, pad.
    call_args = (
        {"X": "lon", "Y": "lat", "Z": "z"},
        {"X": "llon", "Y": "llat", "Z": "zz"},
        ("center", "outer"),
        {"lon": 360, "llon": 360, "lat": 180, "llat": 180},
        {"z": "auto", "zz": "auto"},
    )
    result = generate_grid_ds(ds_original_left, *call_args)
    assert_equal(result, ds_out_outer)
Beispiel #5
0
def test_generate_grid_ds():
    # This needs more cases
    generated = generate_grid_ds(
        ds_original_left,
        {'X': 'lon', 'Y': 'lat', 'Z': 'z'},      # axis -> dimension
        {'X': 'llon', 'Y': 'llat', 'Z': 'zz'},   # axis -> coordinate
        ('center', 'outer'),                     # position
        {'lon': 360, 'llon': 360, 'lat': 180, 'llat': 180},  # boundary_discontinuity
        {'z': 'auto', 'zz': 'auto'},             # pad
    )
    assert_equal(generated, ds_out_outer)
Beispiel #6
0
def create_full_grid(base_ds, grid_dict=None):
    """Generate a full xgcm-compatible dataset from a reference dataset `base_ds`.

    `base_ds` should represent tracer fields, i.e. values at the cell center.

    Parameters
    ----------
    base_ds : xr.Dataset
        The reference ('base') dataset, assumed to be at the tracer
        position/cell center. Must carry the ``source_id`` and ``grid_label``
        attributes; a missing attribute raises KeyError.
    grid_dict : dict, optional
        Dictionary with info about the grid staggering, keyed by the
        ``source_id`` and then the ``grid_label`` attribute of `base_ds`
        (e.g. {'model_name': {'gn': {'axis_shift': {'X': 'left', ...}}}}).
        If None (default), it is loaded from the internal YAML database for
        CMIP6 models.

    Returns
    -------
    xr.Dataset or None
        xgcm compatible dataset, or None (after emitting a warning) when the
        source_id/grid_label combination is not found in `grid_dict`.
    """
    # load dict with grid shift info for each axis; `with` closes the file
    # even if parsing fails
    if grid_dict is None:
        with open(grid_spec, "r") as ff:
            grid_dict = yaml.safe_load(ff)

    source_id = base_ds.attrs["source_id"]
    grid_label = base_ds.attrs["grid_label"]

    # if the source_id/grid_label combo is not in the dict, warn and ask the
    # user to submit an issue
    try:
        axis_shift = grid_dict[source_id][grid_label]["axis_shift"]
    except KeyError:
        warnings.warn(
            f"Could not find the source_id/grid_label ({source_id}/{grid_label}) combo in `grid_dict`, returning `None`. Please submit an issue to github: https://github.com/jbusecke/cmip6_preprocessing/issues"
        )
        return None

    # every axis starts at the cell center and is shifted as specified
    position = {axis: ("center", shift) for axis, shift in axis_shift.items()}

    axis_dict = {"X": "x", "Y": "y"}

    # NOTE(review): other call sites in this file key `boundary_discontinuity`
    # by dimension/coordinate name (e.g. {'lon': 360}); confirm that the axis
    # key "X" is what generate_grid_ds expects here.
    ds_grid = generate_grid_ds(base_ds,
                               axis_dict,
                               position=position,
                               boundary_discontinuity={"X": 360})

    # TODO: parse lev and lev_bounds as center and outer dims.
    # I should also be able to do this with `generate_grid_ds`, but here we
    # have the `lev_bounds` with most models, so that is probably more reliable.
    # cheapest solution right now: just tag `lev` as the Z axis
    if "lev" in ds_grid.dims:
        ds_grid["lev"].attrs["axis"] = "Z"

    return ds_grid
Beispiel #7
0
def test_generate_grid_ds():
    # simple case...just the dims
    dim_names = {'X': 'lon', 'Y': 'lat', 'Z': 'z'}
    coord_names = {'X': 'llon', 'Y': 'llat', 'Z': 'zz'}

    # Case 1: stagger only the dimensions; compare against the reference
    # without its staggered coordinate variables.
    result = generate_grid_ds(ds_original.copy(), dim_names,
                              boundary_discontinuity={'lon': 360, 'lat': 180},
                              pad={'z': 'auto'})
    assert_equal(result,
                 ds_out_left.drop(['llon_left', 'llat_left', 'zz_left']))

    # Case 2: stagger dimensions and coordinates together.
    # TODO why are they not identical ? assert identical fails
    result = generate_grid_ds(
        ds_original, dim_names, coord_names,
        boundary_discontinuity={'lon': 360, 'lat': 180, 'llon': 360, 'llat': 180},
        pad={'z': 'auto', 'zz': 'auto'},
    )
    assert_equal(result, ds_out_left)
def recreate_full_grid_old(ds,
                           lon_bound_name='lon_bounds',
                           lat_bound_name='lat_bounds'):
    """Build a staggered dataset and derive distance metrics from cell vertices.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset with tracer dimensions ``x``/``y`` and lon/lat vertex bounds
        that have a ``vertex`` dimension.
    lon_bound_name, lat_bound_name : str
        Names of the longitude/latitude vertex-bound variables.

    Returns
    -------
    xr.Dataset
        ``generate_grid_ds`` output (``position=('center', 'right')``) with
        the coordinates ``dye``, ``dxn``, ``dxt``, ``dxne``, ``dyt`` and
        ``dyne``.

    NOTE(review): vertices 1-2 are used as the face on ``x_right`` and
    vertices 2-3 as the face on ``y_right``; confirm that vertex ordering
    against the data source.
    """
    ds_full = generate_grid_ds(ds, {'X': 'x', 'Y': 'y'},
                               position=('center', 'right'))

    grid = Grid(ds_full)

    lon_b = ds[lon_bound_name]
    lat_b = ds[lat_bound_name]

    # Derive distances at u point (face between vertices 1 and 2)
    dlon = 0  # I'll interpolate that later
    dlat = lat_b.isel(vertex=1) - lat_b.isel(vertex=2)
    # center position of that face
    lat = (lat_b.isel(vertex=1) + lat_b.isel(vertex=2)) / 2
    lon = (lon_b.isel(vertex=1) + lon_b.isel(vertex=2)) / 2
    _, dy = dll_dist(dlon, dlat, lon, lat)
    # strip coords and rename dims
    ds_full.coords['dye'] = (['y', 'x_right'], dy.data)

    # Derive distances at v point (face between vertices 2 and 3)
    dlon = lon_b.isel(vertex=3) - lon_b.isel(vertex=2)
    # negative differences come from crossing the 0/360 longitude seam; wrap
    # them (NaNs stay NaN)
    dlon = dlon.where(dlon >= 0, dlon + 360.0)

    dlat = 0  # I'll interpolate that later
    # center position of that face
    lat = (lat_b.isel(vertex=3) + lat_b.isel(vertex=2)) / 2
    lon = (lon_b.isel(vertex=3) + lon_b.isel(vertex=2)) / 2
    dx, _ = dll_dist(dlon, dlat, lon, lat)
    # strip coords and rename dims
    ds_full.coords['dxn'] = (['y_right', 'x'], dx.data)

    # interpolate the missing metrics
    ds_full.coords['dxt'] = grid.interp(ds_full.coords['dxn'], 'Y')
    ds_full.coords['dxne'] = grid.interp(ds_full.coords['dxn'], 'X')

    ds_full.coords['dyt'] = grid.interp(ds_full.coords['dye'], 'X')
    ds_full.coords['dyne'] = grid.interp(ds_full.coords['dye'], 'Y')

    return ds_full
def recreate_full_grid(ds, lon_name="lon", lat_name="lat"):
    """Build a staggered dataset and infer distance/area metrics from tracer points.

    ``dxe`` (eastern face) and ``dyn`` (northern face) are computed with
    ``distance`` between neighbouring tracer points. All other dx/dy metrics
    are interpolated from them, and areas are formed as dx * dy for each
    position (t, e, ne, n).

    Parameters
    ----------
    ds : xr.Dataset
        Dataset with tracer dimensions ``x``/``y`` and ``lon``/``lat``.
    lon_name, lat_name : str, optional
        Not used by this function.

    Returns
    -------
    xr.Dataset
        ``generate_grid_ds`` output with the metric coordinates added.
    """
    ds_full = generate_grid_ds(ds, {"X": "x", "Y": "y"},
                               position=("center", "right"))
    grid = Grid(ds_full, periodic=['X'])
    extrap = "extrapolate"

    # dx on the eastern face, from neighbouring tracer longitudes
    lon_here, lon_next = grid.axes["X"]._get_neighbor_data_pairs(
        ds_full.lon.load(), "right")
    lat_tracer = ds_full.lat.load().data
    dx_east = distance(lon_here, lat_tracer, lon_next, lat_tracer)
    east_coords = grid.interp(ds_full.lon, "X").coords
    ds_full.coords["dxe"] = xr.DataArray(dx_east, coords=east_coords)

    # dy on the northern face, from neighbouring tracer latitudes
    lat_here, lat_next = grid.axes["Y"]._get_neighbor_data_pairs(
        ds_full.lat.load(), "right", boundary=extrap)
    lon_tracer = ds_full.lon.load().data
    dy_north = distance(lon_tracer, lat_here, lon_tracer, lat_next)
    north_coords = grid.interp(ds_full.lat, "Y", boundary=extrap).coords
    ds_full.coords["dyn"] = xr.DataArray(dy_north, coords=north_coords)

    # everything else is interpolated from those two metrics
    coords = ds_full.coords
    coords['dxt'] = grid.interp(coords['dxe'], 'X')
    coords['dxne'] = grid.interp(coords['dxe'], 'Y', boundary=extrap)
    coords['dxn'] = grid.interp(coords['dxt'], 'Y', boundary=extrap)

    coords['dyt'] = grid.interp(coords['dyn'], 'Y', boundary=extrap)
    coords['dyne'] = grid.interp(coords['dyn'], 'X')
    coords['dye'] = grid.interp(coords['dyt'], 'X')

    for area, dx_name, dy_name in (('area_t', 'dxt', 'dyt'),
                                   ('area_e', 'dxe', 'dye'),
                                   ('area_ne', 'dxne', 'dyne'),
                                   ('area_n', 'dxn', 'dyn')):
        coords[area] = coords[dx_name] * coords[dy_name]

    # should i return the coords to dask?
    return ds_full
def merge_variables_on_staggered_grid(data_dict,
                                      modelname,
                                      tracer_ref='thetao',
                                      u_ref='uo',
                                      v_ref='vo',
                                      plot=True,
                                      verbose=False):
    """Parses datavariables according to their staggered grid position.
    Should also work for gr variables, which are assumed to be on an A-grid.

    The grid type is inferred by comparing the ``lon`` coordinates of the
    tracer, u and v reference entries:

    - tracer == u gives an 'A' grid;
    - otherwise u == v gives a 'B' grid;
    - anything else is a 'C' grid.

    For an A grid everything is merged and passed to ``generate_grid_ds``.
    For B/C grids, the X/Y shift is taken from the sign of the mean
    tracer-minus-velocity offset and the staggered grid is regenerated from
    the tracer. Every entry of ``data_dict`` is then renamed onto the
    matching staggered dims/coords, and its data variables that are not
    already present are added.

    Parameters
    ----------
    data_dict : dict
        name -> dataset. Must contain ``tracer_ref``, ``u_ref`` and ``v_ref``
        (a missing key raises KeyError).
    modelname : str
        Not used in this function.
    tracer_ref, u_ref, v_ref : str
        Keys of the tracer, u and v reference entries.
    plot : bool
        For B/C grids, plot the reference positions at index (x=3, y=3).
    verbose : bool
        Print progress messages.

    Returns
    -------
    xr.Dataset
        The regenerated grid dataset with all data variables merged in.
    """

    # extract reference dataarrays (those need to be in there)
    tracer = data_dict[tracer_ref]
    u = data_dict[u_ref]
    v = data_dict[v_ref]

    # determine grid type from the lon positions of the references
    if tracer.lon.equals(u.lon):
        grid_type = 'A'
    else:
        if u.lon.equals(v.lon):
            grid_type = 'B'
        else:
            grid_type = 'C'

    print('Grid Type: %s detected' % grid_type)

    if grid_type == 'A':
        # this should also work with interpolated and obs datasets
        # Just merge everything together
        ds_combined = xr.merge([v for v in data_dict.values()])
        ds_full = generate_grid_ds(ds_combined, {
            "X": "x",
            "Y": "y"
        },
                                   position=("center", "left"))
    else:
        # now determine the axis shift
        lon = {}
        lat = {}

        lon['tracer'] = tracer.lon
        lon['u'] = u.lon
        lon['v'] = v.lon

        lat['tracer'] = tracer.lat
        lat['u'] = u.lat
        lat['v'] = v.lat

        # visualize the position of each reference point at (x=3, y=3)
        if plot:
            ref_idx = 3
            plt.figure()
            for vi, var in enumerate(['tracer', 'u', 'v']):
                plt.plot(lon[var].isel(x=ref_idx, y=ref_idx),
                         lat[var].isel(x=ref_idx, y=ref_idx),
                         marker='*',
                         markersize=25)
                plt.text(lon[var].isel(x=ref_idx, y=ref_idx),
                         lat[var].isel(x=ref_idx, y=ref_idx),
                         var,
                         ha='center',
                         va='center')
            plt.title('Staggered Grid visualizaton')
            plt.show()

        if verbose:
            print('Determine grid shift')

        lon_diff = lon['tracer'] - lon['u']
        # eliminate large values caused by the longitude boundary
        # discontinuity (differences of 180 degrees or more are masked)
        lon_diff = lon_diff.where(abs(lon_diff) < 180)

        lat_diff = lat['tracer'] - lat['v']

        # negative mean offset (tracer - velocity) -> 'left'; anything else,
        # including a NaN mean, -> 'right'
        position = dict()
        for axis, diff in zip(['X', 'Y'], [lon_diff, lat_diff]):
            if np.sign(diff.mean().load()) < 0:
                position[axis] = ('center', 'left')
            else:
                position[axis] = ('center', 'right')

        if verbose:
            print('Regenerate grid')

        ds_full = generate_grid_ds(tracer, {
            "X": "x",
            "Y": "y"
        },
                                   position=position)

        # now sort all other variables in accordingly
        def rename(da,
                   da_tracer,
                   da_u,
                   da_v,
                   grid_type,
                   position,
                   verbose=False):
            """Rename `da` onto the staggered names of the reference it matches.

            The match is made by comparing ``lon`` with the tracer/u/v
            references. A match with the tracer leaves `da` unchanged. On a
            B grid, a match with both u and v is treated as u. Any other
            ambiguous match, or no match, raises RuntimeError. Only keys
            present in ``da.variables`` are renamed. `verbose` is unused.
            """
            # target names per grid type and reference location
            rename_dict = {
                'B': {
                    'u': {
                        'x': 'x_' + position['X'][1],
                        'y': 'y_' + position['Y'][1],
                        'lon': 'lon_ne',
                        'lat': 'lat_ne',
                        'vertices_latitude': 'vertices_latitude_ne',
                        'vertices_longitude': 'vertices_longitude_ne',
                    },
                    'v': {
                        'x': 'x_' + position['X'][1],
                        'y': 'y_' + position['Y'][1],
                        'lon': 'lon_ne',
                        'lat': 'lat_ne',
                        'vertices_latitude': 'vertices_latitude_ne',
                        'vertices_longitude': 'vertices_longitude_ne',
                    },
                },
                'C': {
                    'u': {
                        'x': 'x_' + position['X'][1],
                        'lon': 'lon_e',
                        'lat': 'lat_e',
                        'vertices_latitude': 'vertices_latitude_e',
                        'vertices_longitude': 'vertices_longitude_e',
                    },
                    'v': {
                        'y': 'y_' + position['Y'][1],
                        'lon': 'lon_n',
                        'lat': 'lat_n',
                        'vertices_latitude': 'vertices_latitude_n',
                        'vertices_longitude': 'vertices_longitude_n',
                    },
                },
            }

            loc = []

            # check with which variable the lon and lat agree
            for data, name in zip([da_tracer, da_u, da_v],
                                  ['tracer', 'u', 'v']):
                if da.lon.equals(data.lon):
                    loc.append(name)
            if len(loc) != 1:
                if grid_type == 'B' and set(loc) == set(['u', 'v']):
                    loc = ['u']
                else:
                    raise RuntimeError('somthing went wrong')

            loc = loc[0]
            if loc != 'tracer':
                re_dict = {
                    k: v
                    for k, v in rename_dict[grid_type][loc].items()
                    if k in da.variables
                }
                da = da.rename(re_dict)
            return da

        if verbose:
            print('Renaming and Merging')
        for k, da in data_dict.items():
            #             print('Merging: %s' %k)
            da_renamed = rename(da,
                                tracer,
                                u,
                                v,
                                grid_type,
                                position,
                                verbose=verbose)
            # copy all coordinate values from the reconstructed dataset onto
            # the renamed data, or the merging will create intermediate steps
            for co in da_renamed.coords:
                if co in ds_full.coords:
                    da_renamed.coords[co] = ds_full.coords[co]

            if verbose:
                print('Merge')

            # place all new data_variables into the dataset (existing names
            # in ds_full are kept, not overwritten)
            for dvar in da_renamed.data_vars:
                if dvar not in ds_full.data_vars:
                    ds_full[dvar] = da_renamed[dvar]


#             ds_full = xr.merge([ds_full, da_renamed])
    return ds_full
Beispiel #11
0
def merge_variables_on_staggered_grid(
    data_dict,
    modelname,
    tracer_ref="thetao",
    u_ref="uo",
    v_ref="vo",
    plot=True,
    verbose=False,
):
    """Parses datavariables according to their staggered grid position.
    Should also work for gr variables, which are assumed to be on an A-grid.

    If any of the reference keys is missing from ``data_dict``, all entries
    are treated as living on one A grid ("non-reference mode"). Otherwise the
    grid type is inferred from the ``lon`` coordinates of the references:

    - tracer == u gives an 'A' grid;
    - otherwise u == v gives a 'B' grid;
    - anything else is a 'C' grid.

    Parameters
    ----------
    data_dict : dict
        name -> dataset to merge.
    modelname : str
        Not used in this function; kept for backward compatibility.
    tracer_ref, u_ref, v_ref : str
        Keys of the tracer, u and v reference entries in ``data_dict``.
    plot : bool
        For B/C grids, plot the reference positions at index (x=3, y=3).
    verbose : bool
        Print progress messages.

    Returns
    -------
    xr.Dataset
        Dataset produced by ``generate_grid_ds`` with all data variables
        merged in. For B/C grids, data variables already present in the
        regenerated grid dataset are not overwritten.

    Notes
    -----
    For B/C grids this calls a module-level ``rename(da, tracer, u, v,
    grid_type, position, verbose=...)`` helper that is not defined in this
    function. NOTE(review): confirm it is provided by the module.
    """
    reference_keys = (tracer_ref, u_ref, v_ref)
    if any(key not in data_dict for key in reference_keys):
        print(
            "NON-REFERENCE MODE. This should just be used for a bunch of variables on the same grid"
        )
        grid_type = "A"
    else:
        # extract reference dataarrays (those need to be in there)
        tracer = data_dict[tracer_ref]
        u = data_dict[u_ref]
        v = data_dict[v_ref]

        # determine grid type from the lon positions of the references
        if tracer.lon.equals(u.lon):
            grid_type = "A"
        elif u.lon.equals(v.lon):
            grid_type = "B"
        else:
            grid_type = "C"

        print("Grid Type: %s detected" % grid_type)

    if grid_type == "A":
        # this should also work with interpolated and obs datasets
        # Just merge everything together
        # NOTE(review): the other version of this function uses
        # position=("center", "left") for A grids; confirm which is intended.
        ds_combined = xr.merge(list(data_dict.values()))
        ds_full = generate_grid_ds(ds_combined, {"X": "x", "Y": "y"},
                                   position=("center", "right"))
    else:
        # now determine the axis shift
        lon = {"tracer": tracer.lon, "u": u.lon, "v": v.lon}
        lat = {"tracer": tracer.lat, "u": u.lat, "v": v.lat}

        # visualize the position of each reference point at (x=3, y=3)
        if plot:
            ref_idx = 3
            plt.figure()
            for var in ["tracer", "u", "v"]:
                plt.plot(
                    lon[var].isel(x=ref_idx, y=ref_idx),
                    lat[var].isel(x=ref_idx, y=ref_idx),
                    marker="*",
                    markersize=25,
                )
                plt.text(
                    lon[var].isel(x=ref_idx, y=ref_idx),
                    lat[var].isel(x=ref_idx, y=ref_idx),
                    var,
                    ha="center",
                    va="center",
                )
            plt.title("Staggered Grid visualizaton")
            plt.show()

        if verbose:
            print("Determine grid shift")

        lon_diff = lon["tracer"] - lon["u"]
        # eliminate large values caused by the longitude boundary
        # discontinuity (differences of 180 degrees or more are masked)
        lon_diff = lon_diff.where(abs(lon_diff) < 180)

        lat_diff = lat["tracer"] - lat["v"]

        # negative mean offset (tracer - velocity) -> 'left'; anything else,
        # including a NaN mean, -> 'right'
        position = {}
        for axis, diff in zip(["X", "Y"], [lon_diff, lat_diff]):
            if np.sign(diff.mean().load()) < 0:
                position[axis] = ("center", "left")
            else:
                position[axis] = ("center", "right")

        if verbose:
            print("Regenerate grid")

        ds_full = generate_grid_ds(tracer, {"X": "x", "Y": "y"},
                                   position=position)

        if verbose:
            print("Renaming and Merging")
        for da in data_dict.values():
            da_renamed = rename(da,
                                tracer,
                                u,
                                v,
                                grid_type,
                                position,
                                verbose=verbose)
            # copy all coordinate values from the reconstructed dataset onto
            # the renamed data, or the merging will create intermediate steps
            for co in da_renamed.coords:
                if co in ds_full.coords:
                    da_renamed.coords[co] = ds_full.coords[co]

            if verbose:
                print("Merge")

            # place all new data_variables into the dataset
            for dvar in da_renamed.data_vars:
                if dvar not in ds_full.data_vars:
                    ds_full[dvar] = da_renamed[dvar]
    return ds_full
Beispiel #12
0
def add_latlon_metrics(dset, dims=None):
    """
    Infer 2D metrics (latitude/longitude) from gridded data file.

    Parameters
    ----------
    dset : xarray.Dataset
        A dataset open from a file
    dims : dict, optional
        Dimension pair in a dict, e.g., {'lat':'latitude', 'lon':'longitude'}.
        If None, the first names from ``dimXList`` / ``dimYList`` found in
        ``dset`` are used.

    Return
    -------
    ds : xarray.Dataset
        The dataset returned by ``generate_grid_ds(dset, ...)`` with the
        metric coordinates dxG, dyG, dxC, dyC and rAc added
    grid : xgcm.Grid
        The grid with appropriated metrics

    Raises
    ------
    ValueError
        If ``dims`` is None and no known lon/lat name is found in ``dset``.
    """
    if dims is None:
        lon = next((d for d in dimXList if d in dset or d in dset.coords),
                   None)
        lat = next((d for d in dimYList if d in dset or d in dset.coords),
                   None)

        if lon is None or lat is None:
            raise ValueError('unknown dimension names in dset, should be in ' +
                             str(dimXList + dimYList))
    else:
        lon, lat = dims['lon'], dims['lat']

    ds = generate_grid_ds(dset, {'X': lon, 'Y': lat})

    coords = ds.coords

    # X is periodic only if __is_periodic says the longitudes wrap with a
    # 360-degree period. A list (not a bare string) is passed to Grid, which
    # matches the other Grid(..., periodic=['X']) calls.
    periodic = ['X'] if __is_periodic(coords[lon], 360.0) else []

    grid = Grid(ds, periodic=periodic)

    na = np.nan

    if 'X' in periodic:
        # account for the 360-degree jump when differencing across the seam
        dlonG = grid.diff(ds[lon], 'X', boundary_discontinuity=360)
        dlonC = grid.diff(ds[lon + '_left'], 'X', boundary_discontinuity=360)
    else:
        dlonG = grid.diff(ds[lon], 'X', boundary='fill', fill_value=na)
        dlonC = grid.diff(ds[lon + '_left'],
                          'X',
                          boundary='fill',
                          fill_value=na)

    dlatG = grid.diff(ds[lat], 'Y', boundary='fill', fill_value=na)
    dlatC = grid.diff(ds[lat + '_left'], 'Y', boundary='fill', fill_value=na)

    # convert degree differences to distances
    coords['dxG'], coords['dyG'] = __dll_dist(dlonG, dlatG, ds[lon], ds[lat])
    coords['dxC'], coords['dyC'] = __dll_dist(dlonC, dlatC, ds[lon], ds[lat])
    coords['rAc'] = ds['dyC'] * ds['dxC']

    metrics = {
        ('X', ): ['dxG', 'dxC'],  # X distances
        ('Y', ): ['dyG', 'dyC'],  # Y distances
        ('X', 'Y'): ['rAc']
    }

    grid._assign_metrics(metrics)

    return ds, grid
Beispiel #13
0
def test_recreate_metrics(xshift, yshift, z_axis):
    """Compare the output of `recreate_metrics` with metrics rebuilt by hand.

    Parameters (presumably supplied by pytest parametrization; not visible here)
    ----------
    xshift, yshift : str
        Staggering of the X/Y axis; the branches below handle "left" and
        "right".
    z_axis : bool
        Whether the test dataset has a vertical ``lev`` axis.
    """

    # reconstruct all the metrics by hand and compare to inferred output

    # * For now this is a regular lon lat grid. Might need to add some tests for more complex grids.
    # Then again. This will not do a great job for those....

    # create test dataset
    ds = _test_data(z_axis=z_axis)

    # TODO: generalize so this also works with e.g. zonal average sections (which dont have a X axis)
    coord_dict = {"X": "x", "Y": "y"}
    if z_axis:
        coord_dict["Z"] = "lev"

    ds_full = generate_grid_ds(
        ds, coord_dict, position={"X": ("center", xshift), "Y": ("center", yshift)},
    )

    grid = Grid(ds_full)

    ds_metrics, metrics_dict = recreate_metrics(ds_full, grid)

    if z_axis:
        # Check that the bound values are intact (previously those got alterd due to unexpected behaviour of .assign_coords())
        assert "bnds" in ds_metrics.lev_bounds.dims

    # compute the more complex metrics (I could wrap this into a function I guess?)
    # dx between neighbouring tracer points along X (both lon and lat pairs
    # are taken along X)
    lon0, lon1 = grid.axes["X"]._get_neighbor_data_pairs(ds.lon.load(), xshift)
    lat0, lat1 = grid.axes["X"]._get_neighbor_data_pairs(ds.lat.load(), xshift)
    dx_gx_expected = distance(lon0, lat0, lon1, lat1)

    # dy between neighbouring tracer points along Y
    lon0, lon1 = grid.axes["Y"]._get_neighbor_data_pairs(ds.lon.load(), yshift)
    lat0, lat1 = grid.axes["Y"]._get_neighbor_data_pairs(ds.lat.load(), yshift)
    dy_gy_expected = distance(lon0, lat0, lon1, lat1)

    # corner metrics
    # dx
    if yshift == "left":
        # dx along the lower lat bound (bnds=0)
        lon0, lon1 = grid.axes["X"]._get_neighbor_data_pairs(
            _interp_vertex_to_bounds(ds_metrics.lon_verticies, "y").isel(bnds=0),
            xshift,
        )
        lat0, lat1 = grid.axes["X"]._get_neighbor_data_pairs(
            ds_metrics.lat_bounds.isel(bnds=0), xshift
        )
    elif yshift == "right":
        # dx along the upper lat bound (bnds=1)
        lon0, lon1 = grid.axes["X"]._get_neighbor_data_pairs(
            _interp_vertex_to_bounds(ds_metrics.lon_verticies, "y").isel(bnds=1),
            xshift,
        )
        lat0, lat1 = grid.axes["X"]._get_neighbor_data_pairs(
            ds_metrics.lat_bounds.isel(bnds=1), xshift
        )
    dx_gxgy_expected = distance(lon0, lat0, lon1, lat1)

    # dy
    if xshift == "left":
        # dy along the lower lon bound (bnds=0)
        lat0, lat1 = grid.axes["Y"]._get_neighbor_data_pairs(
            _interp_vertex_to_bounds(ds_metrics.lat_verticies, "x").isel(bnds=0),
            yshift,
        )
        lon0, lon1 = grid.axes["Y"]._get_neighbor_data_pairs(
            ds_metrics.lon_bounds.isel(bnds=0), yshift
        )
    elif xshift == "right":
        # dy along the upper lon bound (bnds=1)
        lat0, lat1 = grid.axes["Y"]._get_neighbor_data_pairs(
            _interp_vertex_to_bounds(ds_metrics.lat_verticies, "x").isel(bnds=1),
            yshift,
        )
        lon0, lon1 = grid.axes["Y"]._get_neighbor_data_pairs(
            ds_metrics.lon_bounds.isel(bnds=1), yshift
        )
    dy_gxgy_expected = distance(lon0, lat0, lon1, lat1)

    # dy on the X-shifted face: distance between two cell vertices
    # (vertices 0/1 for "left", 2/3 otherwise)
    if xshift == "left":
        vertex_points = [0, 1]
    else:
        vertex_points = [2, 3]
    lon0, lon1 = (
        ds_metrics.lon_verticies.isel(vertex=vertex_points[0]),
        ds_metrics.lon_verticies.isel(vertex=vertex_points[1]),
    )
    lat0, lat1 = (
        ds_metrics.lat_verticies.isel(vertex=vertex_points[0]),
        ds_metrics.lat_verticies.isel(vertex=vertex_points[1]),
    )
    dy_gx_expected = distance(lon0, lat0, lon1, lat1)

    # dx on the Y-shifted face: distance between two cell vertices
    # (vertices 0/3 for "left", 1/2 otherwise)
    if yshift == "left":
        vertex_points = [0, 3]
    else:
        vertex_points = [1, 2]
    lon0, lon1 = (
        ds_metrics.lon_verticies.isel(vertex=vertex_points[0]),
        ds_metrics.lon_verticies.isel(vertex=vertex_points[1]),
    )
    lat0, lat1 = (
        ds_metrics.lat_verticies.isel(vertex=vertex_points[0]),
        ds_metrics.lat_verticies.isel(vertex=vertex_points[1]),
    )
    dx_gy_expected = distance(lon0, lat0, lon1, lat1)

    # vertical cell thickness from the level bounds
    if z_axis:
        dz_t_expected = ds.lev_bounds.diff("bnds").squeeze().data
    else:
        dz_t_expected = None

    # compare each recreated metric with its hand-built counterpart; entries
    # with expected=None are skipped, and the control is transposed if its
    # shape differs from the expected one
    for var, expected in [
        ("dz_t", dz_t_expected),
        (
            "dx_t",
            distance(
                ds_metrics.lon_bounds.isel(bnds=0).data,
                ds_metrics.lat.data,
                ds_metrics.lon_bounds.isel(bnds=1).data,
                ds_metrics.lat.data,
            ),
        ),
        (
            "dy_t",
            distance(
                ds_metrics.lon.data,
                ds_metrics.lat_bounds.isel(bnds=0).data,
                ds_metrics.lon.data,
                ds_metrics.lat_bounds.isel(bnds=1).data,
            ),
        ),
        ("dx_gx", dx_gx_expected),
        ("dy_gy", dy_gy_expected),
        ("dy_gx", dy_gx_expected),
        ("dx_gy", dx_gy_expected),
        ("dy_gxgy", dy_gxgy_expected),
        ("dx_gxgy", dx_gxgy_expected),
    ]:
        if expected is not None:
            print(var)
            control = ds_metrics[var].data
            if expected.shape != control.shape:
                control = control.T
            np.testing.assert_allclose(control, expected)

    # the metrics dict must include a Z entry exactly when a z axis exists
    if z_axis:
        assert set(["X", "Y", "Z"]).issubset(set(metrics_dict.keys()))
    else:
        assert set(["X", "Y"]).issubset(set(metrics_dict.keys()))
        assert not "Z" in list(metrics_dict.keys())
Beispiel #14
0
def load_wrf(date_in, wrf_nc_dir):
    '''Load the WRF output file closest in time to ``date_in``.

    Only files from the same day as ``date_in`` are considered. These are
    files in ``wrf_nc_dir`` matching ``wrfout_*_YYYY-MM-DD_*``.

    Because loading large files could cost much time,
    please use "frames_per_outfile=1" in "namelist.input"
    to generate wrfout* files.

    Parameters
    ----------
    date_in : datetime.datetime
        Target time.
    wrf_nc_dir : str
        Directory containing the wrfout files.

    Returns
    -------
    wrf_file : str
        Path of the selected file (first match wins on equal time distance).
    wrf : xr.Dataset
        First ``Time`` step of that file. ``XLONG``/``XLAT`` are renamed to
        ``lon``/``lat`` and the xgcm grid is generated with
        ``position=('center', 'outer')``. ``T`` is converted to absolute
        temperature, and ``lon_b``/``lat_b`` bound coordinates are added.

    Raises
    ------
    FileNotFoundError
        If no wrfout file matches the pattern for that day.
    '''
    # get all wrf output files in the same day
    f_wrf_pattern = os.path.join(wrf_nc_dir,
                                 date_in.strftime('wrfout_*_%Y-%m-%d_*'))
    wrf_list = glob.glob(f_wrf_pattern)
    if not wrf_list:
        # fail clearly instead of with "min() arg is an empty sequence"
        raise FileNotFoundError(f'No WRF output files match {f_wrf_pattern}')

    # omit the directory and get "yyyy-mm-dd_hh:mm:ss"; keep each file paired
    # with its timestamp
    dated_files = [
        (datetime.strptime(
            ntpath.basename(name).split('_', maxsplit=2)[-1],
            '%Y-%m-%d_%H:%M:%S'), name) for name in wrf_list
    ]

    # pick the file with the closest datetime
    _, wrf_file = min(dated_files, key=lambda pair: abs(pair[0] - date_in))

    # read selected wrf data
    logging.info(' ' * 4 + f'Reading {wrf_file} ...')
    wrf = xr.open_dataset(wrf_file)
    wrf.attrs['wrfchem_filename'] = os.path.basename(wrf_file)

    # generate lon_bounds and lat_bounds
    wrf = wrf.rename({'XLONG': 'lon', 'XLAT': 'lat'}).isel(Time=0)
    wrf = generate_grid_ds(wrf, {
        'X': 'west_east',
        'Y': 'south_north'
    },
                           position=('center', 'outer'))

    # convert from potential temperature to absolute temperature (K);
    # T is presumably the perturbation potential temperature (hence +300 K)
    # and 0.2865 approximates R_d/c_p
    # http://mailman.ucar.edu/pipermail/wrf-users/2010/001896.html
    # http://mailman.ucar.edu/pipermail/wrf-users/2013/003117.html
    wrf['T'] = (wrf['T'] + 300) * ((wrf['P'] + wrf['PB']) / 1e5)**0.2865

    # generate grid and bounds
    grid_ds = Grid(wrf, periodic=False)
    bnd = 'extrapolate'

    def _to_outer(da):
        """Interpolate a cell-center field along X and then Y onto the outer positions."""
        along_x = grid_ds.interp(da, 'X', boundary=bnd, fill_value=np.nan)
        return grid_ds.interp(along_x, 'Y', boundary=bnd, fill_value=np.nan)

    wrf.coords['lon_b'] = _to_outer(wrf['lon'])
    wrf.coords['lat_b'] = _to_outer(wrf['lat'])
    logging.info(' ' * 8 + 'Finish reading')

    return wrf_file, wrf