Example #1
0
def _create_grid(dataset, coords, periodic, face_connections):
    """
    Create xgcm grid by adding comodo attributes to the
    dimensions of the dataset.

    Any ``axis`` / ``c_grid_axis_shift`` attributes already present on the
    dataset dimensions are removed first, so only the axes listed in
    ``coords`` are seen by xgcm. Attributes are modified in place on
    ``dataset``.

    Parameters
    ----------
    dataset: xarray.Dataset
    coords: dict
        E.g., {'Y': {Y: None, Yp1: 0.5}}
    periodic: list
        List of periodic axes.
    face_connections: dict or str or None
        dictionary specifying grid topology. ``None`` or a string means
        that no face connections are passed to xgcm.

    Returns
    -------
    grid: xgcm.Grid or None
        ``None`` when xgcm identifies no axis.

    Warns
    -----
    UserWarning
        If some names in ``coords`` are not dimensions of ``dataset``.
    """
    # Clean up comodo (currently force user to specify axis using set_coords).
    for dim in dataset.dims:
        dataset[dim].attrs.pop("axis", None)
        dataset[dim].attrs.pop("c_grid_axis_shift", None)

    # Add comodo attributes.
    # TODO: it is possible to pass grid dict in xgcm.
    #       Should we implement it?
    warn_dims = []
    if coords:
        for axis, dims in coords.items():
            for dim, shift in dims.items():
                if dim not in dataset.dims:
                    warn_dims.append(dim)
                    continue
                dataset[dim].attrs["axis"] = axis
                # A falsy shift (e.g. None or 0) sets no shift attribute.
                if shift:
                    dataset[dim].attrs["c_grid_axis_shift"] = str(shift)
    if warn_dims:
        warnings.warn(
            "{} are not dimensions"
            " and are not added"
            " to the grid object.".format(warn_dims),
            stacklevel=2,
        )

    # Create grid. A string ``face_connections`` carries no topology and is
    # treated exactly like ``None``.
    if face_connections is None or isinstance(face_connections, str):
        grid = xgcm.Grid(dataset, periodic=periodic)
    else:
        grid = xgcm.Grid(dataset,
                         periodic=periodic,
                         face_connections=face_connections)
    if len(grid.axes) == 0:
        grid = None

    return grid
Example #2
0
def test_get_dx_dims(datagrid_w_attrs, test_coord):
    # get_dx must return arrays on exactly the dims of its input, on X and Y.
    ds = datagrid_w_attrs
    xgrid = xgcm.Grid(ds)
    field = ds[test_coord]
    spacings = [get_dx(xgrid, field, axis_name) for axis_name in ("X", "Y")]
    for spacing in spacings:
        assert spacing.dims == field.dims
Example #3
0
def _create_grid(dataset, coords, periodic):
    """
    Create an xgcm grid by adding comodo attributes to the dataset dimensions.

    Existing ``axis`` / ``c_grid_axis_shift`` attributes are removed from
    every dimension first; then the axes described by ``coords`` are written
    (in place) onto the dataset before building the grid.

    Parameters
    ----------
    dataset : xarray.Dataset
    coords : dict
        Mapping ``{axis: {dim: shift}}``, e.g. ``{'Y': {'Y': None, 'Yp1': 0.5}}``.
    periodic : list
        List of periodic axes.

    Returns
    -------
    grid : xgcm.Grid or None
        ``None`` when xgcm identifies no axis.

    Warns
    -----
    UserWarning
        If some names in ``coords`` are not dimensions of ``dataset``.
    """
    # Clean up comodo (currently force user to specify axis using set_coords).
    for dim in dataset.dims:
        dataset[dim].attrs.pop('axis', None)
        dataset[dim].attrs.pop('c_grid_axis_shift', None)

    # Add comodo attributes.
    # We won't need this step in the future because future versions of xgcm will allow to pass coords in Grid.
    warn_dims = []
    if coords:
        for axis, dims in coords.items():
            for dim, shift in dims.items():
                if dim not in dataset.dims:
                    warn_dims.append(dim)
                    continue
                dataset[dim].attrs['axis'] = axis
                # A falsy shift (e.g. None or 0) sets no shift attribute.
                if shift:
                    dataset[dim].attrs['c_grid_axis_shift'] = str(shift)
    if warn_dims:
        _warnings.warn(
            '{} are not dimensions of the dataset and will be omitted'.format(
                warn_dims),
            stacklevel=2)

    # Create grid
    grid = _xgcm.Grid(dataset, periodic=periodic)
    if len(grid.axes) == 0:
        grid = None

    return grid
Example #4
0
def grid_from_id(g):
    """
    XGCM representation of Aus400 grid 'g'.

    Parameters
    ----------
    g : str
        Grid identifier; the dataset is read from ``root / "grids" / f"{g}.nc"``.
        Its last character selects the staggering: ``"u"`` places longitude
        on the left of the X axis, ``"v"`` places latitude on the left of the
        Y axis, anything else keeps both at cell centres.

    Returns
    -------
    xgcm.Grid
    """
    subt = g[-1]

    ds = xarray.open_dataset(root / "grids" / f"{g}.nc")

    coords = {
        "X": {"center": "longitude"},
        "Y": {"center": "latitude"},
    }

    if subt == "u":
        coords["X"] = {"left": "longitude"}
    elif subt == "v":
        coords["Y"] = {"left": "latitude"}

    # ``coords`` must be passed by keyword: per the xgcm ``Grid`` API the
    # second positional parameter is not ``coords`` (it is ``check_dims``),
    # so a positional call silently ignores the axis description.
    return xgcm.Grid(ds, coords=coords)
Example #5
0
    def __init__(self, dset, grid=None):
        """
        Construct a Budget instance using a Dataset

        Parameters
        ----------
        dset : xarray.Dataset
            a given Dataset containing MITgcm output diagnostics
        grid : xgcm.Grid, optional
            a given grid that accounted for grid metrics; when None, a plain
            ``xgcm.Grid(dset)`` is built.

        Notes
        -----
        No budget term is computed here: ``self.terms`` starts as None.
        ``self.coords`` holds the coordinates of ``dset`` as plain variables,
        and ``self.BCx`` / ``self.BCy`` hold the boundary condition
        ('periodic' or 'fill') matching the periodicity of the grid's X and Y
        axes.
        """
        self.grid = xgcm.Grid(dset) if grid is None else grid
        self.coords = dset.coords.to_dataset().reset_coords()
        self.dset = dset
        self.terms = None

        # NOTE(review): ``_periodic`` is a private attribute of xgcm's Axis and
        # may change between xgcm releases.
        self.BCx = 'periodic' if self.grid.axes['X']._periodic else 'fill'
        self.BCy = 'periodic' if self.grid.axes['Y']._periodic else 'fill'
Example #6
0
def calculate_speed(ds):
    """Return the horizontal speed on the central (T) grid.

    The ``vozocrtx`` (X) and ``vomecrty`` (Y) velocities are first
    interpolated to the cell centres, then combined as ``(u**2 + v**2)**0.5``.

    Parameters
    ----------
    ds : xarray dataset
        A grid-aware dataset as produced by `xorca.lib.preprocess_orca`.

    Returns
    -------
    speed : xarray data array
        A grid-aware data array with the speed in `[m/s]`.

    """
    xgrid = xgcm.Grid(ds, periodic=["Y", "X"])

    u_center, v_center = (
        xgrid.interp(velocity, axis_name, to="center")
        for velocity, axis_name in ((ds.vozocrtx, "X"), (ds.vomecrty, "Y"))
    )

    return (u_center**2 + v_center**2)**0.5
Example #7
0
def test_T_1_auto_get_scale_factor_to():
    # Remap a masked field of ones WITHOUT passing scale_factor_to (the
    # target grid carries metrics instead), then compare integrated values.
    source_cfg = open_domcfg_fr()
    nemo = xr.open_dataset("data/xnemogcm.nemo.nc")
    nemo.load()
    target_cfg = open_domcfg_to()

    source_grid = xgcm.Grid(source_cfg, periodic=False)
    target_grid = xgcm.Grid(target_cfg, periodic=False, metrics=_metrics)

    ones = (nemo["thetao"] * 0 + 1) * source_cfg.tmask
    remapped = remap_vertical(
        ones, source_grid, target_grid, axis="Z", scale_factor_fr=source_cfg.e3t_0
    )
    _assert_same_integrated_value(
        ones, remapped, e3_fr=source_cfg.e3t_0, e3_to=target_cfg.e3t_0
    )
Example #8
0
def compute_missing_metrics(ds,
                            all_scale_factors=all_scale_factors,
                            time_varying=True,
                            periodic=False):
    """
    Add all possible scale factors to the dataset.

    For the moment, e3t at least (or e3t_0) needs to be present in the dataset.
    May have some boundary issues.
    Will add the metrics to the given dataset. To avoid this, use a ds.copy()

    Parameters
    ----------
    ds : xarray.Dataset
        dataset containing the scale factors. Must be xgcm compatible (e.g. opened with xnemogcm)
    all_scale_factors : list
        list of the scale factors to compute (nothing is done for the scale factors
        already present in *ds*)
        Must be a sublist of: ['e3t', 'e3u', 'e3v', 'e3f', 'e3w', 'e3uw', 'e3vw', 'e3fw']
    time_varying : bool
        Whether to use the time varying scale factors (True) or the constant ones (False, 'e3x_0')
    periodic : bool or list, optional
        Periodicity passed to ``xgcm.Grid`` (default False: no periodic axis).

    Returns
    -------
    the new dataset with the scale factors added

    Raises
    ------
    ModuleNotFoundError
        If xgcm is not installed.
    """
    try:
        import xgcm
    except ModuleNotFoundError as e:
        raise ModuleNotFoundError(
            "xgcm is not installed, you need xgcm for this function") from e
    from warnings import warn

    warn(
        "This function is in pre-phase. Do not expect a high precision, but a good estimate. Some boundary issues may arise.",
        stacklevel=2,
    )

    # Honour the caller's periodicity (previously hard-coded to False).
    grid = xgcm.Grid(ds, periodic=periodic)

    if not time_varying:
        all_scale_factors = [i + "_0" for i in all_scale_factors]

    for i in all_scale_factors:
        if i in ds.variables:
            continue
        # dep_graph is keyed by the time-varying names (without "_0").
        vertex = dep_graph[i if time_varying else i[:-2]]
        for e3, axis in vertex.items():
            e3_nme = e3 if time_varying else e3 + "_0"
            if e3_nme in ds.variables:
                # we stop at the first one matching
                ds[i] = grid.interp(ds[e3_nme], axis, boundary="extend")
                break
    return ds
Example #9
0
def compute_geostrophic_velocities(ds, lat, lon, day_offset, days, zF, α, β, g,
                                   f):
    """Estimate geostrophic velocities at one point from the thermal-wind relation.

    Horizontal gradients of a linear-equation-of-state buoyancy,
    ``b = g * (α * Θ - β * S)``, are integrated upward from the bottom, and
    the resulting shear is added to the model velocity at the deepest level:
    ``U_geo = U_bottom - Σdz(db/dy) / f`` and ``V_geo = V_bottom + Σdz(db/dx) / f``.

    Parameters
    ----------
    ds : xarray.Dataset
        Model output using MITgcm/xmitgcm names: ``UVEL``, ``VVEL``,
        ``THETA``, ``SALT``, ``hFacW``/``hFacS``/``hFacC``, ``drF``,
        ``dxC``/``dxG``/``dyC``/``dyG``, ``rA``/``rAz``/``rAs``/``rAw`` on
        vertical coordinates ``Z`` and ``Zl``. The caller's dataset is not
        modified (all additions go to a reindexed copy).
    lat, lon : float
        Target point; the nearest grid points are selected.
    day_offset, days : int
        Time indices ``day_offset`` .. ``day_offset + days`` (inclusive) are used.
    zF : object
        Unused; kept for call compatibility.
    α, β : float
        Linear equation-of-state coefficients for temperature and salinity.
    g : float
        Gravitational acceleration.
    f : float
        Coriolis parameter.

    Returns
    -------
    U_geo, V_geo : numpy.ndarray
        Geostrophic velocity components at the selected point.
    """
    logging.info(
        f"Computing geostrophic velocities at ({lat}°N, {lon}°E) for {days} days..."
    )

    # Reverse z index so we calculate cumulative integrals bottom up
    ds = ds.reindex(Z=ds.Z[::-1], Zl=ds.Zl[::-1])

    # Only pull out the data we need as time has chunk size 1.
    time_slice = slice(day_offset, day_offset + days + 1)

    U = ds.UVEL.isel(time=time_slice)
    V = ds.VVEL.isel(time=time_slice)
    Θ = ds.THETA.isel(time=time_slice)
    S = ds.SALT.isel(time=time_slice)

    # Set up grid metric
    # See: https://xgcm.readthedocs.io/en/latest/grid_metrics.html#Using-metrics-with-xgcm
    ds["drW"] = ds.hFacW * ds.drF  # vertical cell size at u point
    ds["drS"] = ds.hFacS * ds.drF  # vertical cell size at v point
    ds["drC"] = ds.hFacC * ds.drF  # vertical cell size at tracer point

    metrics = {
        ('X', ): ['dxC', 'dxG'],  # X distances
        ('Y', ): ['dyC', 'dyG'],  # Y distances
        ('Z', ): ['drW', 'drS', 'drC'],  # Z distances
        # Areas at tracer, vorticity, v and u points ('rA' must not repeat).
        ('X', 'Y'): ['rA', 'rAz', 'rAs', 'rAw']
    }

    # xgcm grid for calculating derivatives and interpolating
    # Not sure why it's periodic in Y but copied it from the xgcm SOSE example:
    # https://pangeo.io/use_cases/physical-oceanography/SOSE.html#create-xgcm-grid
    grid = xgcm.Grid(ds, metrics=metrics, periodic=('X', 'Y'))

    # Vertical integrals from z'=-Lz to z'=z (cumulative integrals)
    Σdz_dΘdx = grid.cumint(grid.derivative(Θ, 'X'), 'Z', boundary="extend")
    Σdz_dΘdy = grid.cumint(grid.derivative(Θ, 'Y'), 'Z', boundary="extend")
    Σdz_dSdx = grid.cumint(grid.derivative(S, 'X'), 'Z', boundary="extend")
    Σdz_dSdy = grid.cumint(grid.derivative(S, 'Y'), 'Z', boundary="extend")

    # Assuming linear equation of state
    Σdz_dBdx = g * (α * Σdz_dΘdx - β * Σdz_dSdx)
    Σdz_dBdy = g * (α * Σdz_dΘdy - β * Σdz_dSdy)

    # Velocities at depth (Z was reversed above, so index 0 is the bottom)
    z_bottom = ds.Z.values[0]
    U_d = U.sel(XG=lon, YC=lat, Z=z_bottom, method="nearest")
    V_d = V.sel(XC=lon, YG=lat, Z=z_bottom, method="nearest")

    with ProgressBar():
        U_geo = (U_d - 1 / f * Σdz_dBdy).sel(XC=lon, YG=lat,
                                             method="nearest").values
        V_geo = (V_d + 1 / f * Σdz_dBdx).sel(XG=lon, YC=lat,
                                             method="nearest").values

    return U_geo, V_geo
Example #10
0
def test_T_0_same_fr_and_to():
    # Identity remap: source and target use the same domcfg object, and the
    # remapped zero field is checked with _assert_same_domcfg.
    cfg = open_domcfg_fr()

    nemo = xr.open_dataset("data/xnemogcm.nemo.nc")
    nemo.load()

    grid_src = xgcm.Grid(cfg, periodic=False)
    grid_dst = xgcm.Grid(cfg, periodic=False)

    zeros = nemo["thetao"] * 0 * cfg.tmask
    remapped = remap_vertical(
        zeros,
        grid_src,
        grid_dst,
        axis="Z",
        scale_factor_fr=cfg.e3t_0,
        scale_factor_to=cfg.e3t_0,
    )
    _assert_same_domcfg(zeros, remapped)
Example #11
0
def create_grid(domcfg):
    """Build a non-periodic xgcm grid from *domcfg* with scale-factor metrics.

    X uses ``e1t``/``e1u``, Y uses ``e2t``/``e2v`` and Z uses the ``_0``
    thickness variables ``e3t_0``/``e3w_0``.
    """
    scale_factor_metrics = {
        ("X", ): ["e1t", "e1u"],
        ("Y", ): ["e2t", "e2v"],
        ("Z", ): ["e3t_0", "e3w_0"],
    }
    return xgcm.Grid(domcfg, periodic=False, metrics=scale_factor_metrics)
Example #12
0
def test_W_0_same_fr_and_to():
    """Identity remap of a zero ``woce`` field (same domcfg for source and target).

    If ``remap_vertical`` raises ``NotImplementedError`` for this field, the
    test ends early without checking the result.
    """
    domcfg_fr = open_domcfg_fr()

    nemo_ds = xr.open_dataset("data/xnemogcm.nemo.nc")
    nemo_ds.load()
    domcfg_to = domcfg_fr

    grid_fr = xgcm.Grid(domcfg_fr, periodic=False)
    grid_to = xgcm.Grid(domcfg_to, periodic=False)

    v_fr = nemo_ds["woce"] * 0
    try:
        v_to = remap_vertical(
            v_fr,
            grid_fr,
            grid_to,
            axis="Z",
            scale_factor_fr=domcfg_fr.e3t_0,
            scale_factor_to=domcfg_to.e3t_0,
        )
    except NotImplementedError:
        # Test functions must return None: pytest warns about tests that
        # return a value (PytestReturnNotNoneWarning).
        return
    _assert_same_domcfg(v_fr, v_to)
Example #13
0
def test_U():
    # Remap the U velocity between the two domcfgs with explicit depths and
    # compare e3u_0-weighted integrated values.
    src = open_domcfg_fr()
    nemo = xr.open_dataset("data/xnemogcm.nemo.nc")
    nemo.load()
    dst = open_domcfg_to()

    grid_src = xgcm.Grid(src, periodic=False)
    grid_dst = xgcm.Grid(dst, periodic=False)

    u_src = nemo["uo"]

    # Source W-depths interpolated along X; target W-depths taken from the
    # single column x_c=1, y_c=1 with those coordinates removed.
    depth_src = grid_src.interp(src.gdepw_0, "X", boundary="extend")
    depth_dst = dst.gdepw_0.isel({"x_c": 1, "y_c": 1}).drop_vars(["x_c", "y_c"])

    u_dst = remap_vertical(
        u_src,
        grid_src,
        grid_dst,
        axis="Z",
        scale_factor_fr=src.e3u_0,
        scale_factor_to=dst.e3u_0,
        z_fr=depth_src,
        z_to=depth_dst,
    )
    _assert_same_integrated_value(
        u_src, u_dst, e3_fr=src.e3u_0, e3_to=dst.e3u_0
    )
Example #14
0
 def __init__(self, dset):
     """
     Build the instance from a Dataset of MITgcm output diagnostics.

     Parameters
     ----------
     dset : xarray.Dataset
         a given Dataset containing MITgcm output diagnostics

     Sets ``grid`` (``xgcm.Grid(dset)``), ``coords`` (the coordinates of
     ``dset`` as plain variables), ``dset`` (the input with its non-index
     coordinates dropped), ``volume`` (``drF * hFacC * rA``) and ``terms``
     (initialised to None).
     """
     self.grid = xgcm.Grid(dset)
     coordinate_vars = dset.coords.to_dataset()
     self.coords = coordinate_vars.reset_coords()
     self.dset = dset.reset_coords(drop=True)
     # cell volume = thickness (drF * hFacC) times horizontal area (rA)
     thickness = dset.drF * dset.hFacC
     self.volume = thickness * dset.rA
     self.terms = None
Example #15
0
def get_grid(ds, coords=None, metrics=None, topology="PPN", **kwargs):
    """Return an xgcm grid for *ds*.

    Missing ``coords`` / ``metrics`` are derived from *ds* with
    ``get_coords`` / ``get_metrics`` using the same *topology*. The i-th
    character of *topology* describes axis ``"xyz"[i]``; axes marked ``"P"``
    or ``"F"`` are declared periodic. Extra keyword arguments are forwarded
    to ``xgcm.Grid``.
    """
    import xgcm as xg

    coords = get_coords(ds, topology=topology) if coords is None else coords
    metrics = get_metrics(ds, topology=topology) if metrics is None else metrics

    periodic = []
    for axis_name, kind in zip("xyz", topology):
        if kind in "PF":
            periodic.append(axis_name)

    return xg.Grid(ds, coords=coords, metrics=metrics, periodic=periodic, **kwargs)
Example #16
0
def create_fv3_grid(
    ds: xr.Dataset,
    x_center: str = constants.COORD_X_CENTER,
    x_outer: str = constants.COORD_X_OUTER,
    y_center: str = constants.COORD_Y_CENTER,
    y_outer: str = constants.COORD_Y_OUTER,
) -> xgcm.Grid:
    """Create an XGCM_ grid from a dataset of FV3 tile data

    The ``x`` and ``y`` axes each have a ``center`` and an ``outer`` position,
    and the tiles are linked with ``FV3_FACE_CONNECTIONS``.

    Args:
        ds: dataset with a valid tiles dimension. The tile dimension must have a
            corresponding coordinate. To avoid xgcm bugs, this coordinate should start
            with 0. You can make it like this::

                ds = ds.assign_coords(tile=np.arange(6))

            The tile coordinate is checked with ``_validate_tile_coord`` first.
        x_center (optional): the dimension name for the x cell centers
        x_outer (optional): the dimension name for the x edges
        y_center (optional): the dimension name for the y cell centers
        y_outer (optional): the dimension name for the y edges

    Returns:
        an xgcm grid object. This object can be used to interpolate and differentiate
        cubed sphere data, please see the XGCM_ documentation for more information.

    Notes:
        See this notebook_ for usage.


    .. _XGCM: https://xgcm.readthedocs.io/en/latest/
    .. _notebook: https://github.com/VulcanClimateModeling/explore/blob/master/noahb/2019-12-06-XGCM.ipynb # noqa

    """

    _validate_tile_coord(ds)

    coords = {
        "x": {
            "center": x_center,
            "outer": x_outer
        },
        "y": {
            "center": y_center,
            "outer": y_outer
        },
    }
    return xgcm.Grid(ds, coords=coords, face_connections=FV3_FACE_CONNECTIONS)
Example #17
0
def calculate_moc(ds, region=""):
    """Calculate the meridional overturning circulation (MOC).

    Parameters
    ----------
    ds : xarray dataset
        A grid-aware dataset as produced by `xorca.lib.preprocess_orca`.
    region : str
        A region string.  Examples: `"atl"`, `"pac"`, `"ind"`.
        Defaults to `""`.  Selects the mask ``"vmask" + region`` and is
        appended to the output names.

    Returns
    -------
    moc : xarray data array
        A grid-aware data array named ``"moc" + region``, in `[Sv]`, with
        the moc for the specified region.  It carries a coordinate called
        `"lat_moc{region}"`: the weighted horizontal and vertical average of
        the latitude of all points for the given point on the y-axis.

    """
    xgrid = xgcm.Grid(ds, periodic=["Y", "X"])

    mask_name = "vmask" + region
    out_name = "moc" + region
    lat_name = "lat_moc" + region

    weights = ds[mask_name] * ds.e3v * ds.e1v
    transport = weights * ds.vomecrty

    # Indefinite vertical integral of V from bottom to top, then zonal
    # integral, conversion to [Sv] and renaming after the region.
    moc = xgrid.cumsum(transport, "Z", to="left", boundary="fill") - transport.sum("z_c")
    moc = moc.sum("x_c")
    moc /= 1.0e6
    moc = moc.rename(out_name)

    # Weighted zonal and vertical mean of the latitude at each y point.
    summed_dims = ["z_c", "x_c"]
    lat_moc = (weights * ds.llat_rc).sum(dim=summed_dims) / weights.sum(dim=summed_dims)
    moc.coords[lat_name] = (["y_r"], lat_moc.data)

    # Also copy the relevant depth coordinate.
    moc.coords["depth_l"] = ds.coords["depth_l"]

    return moc
Example #18
0
 def __init__(self, dset, grid=None):
     '''
     Build a Dynamics instance from a Dataset.

     Parameters
     ----------
     dset : xarray.Dataset
         a given Dataset containing MITgcm output diagnostics
     grid : xgcm.Grid, optional
         a given grid that accounted for grid metrics; ``xgcm.Grid(dset)``
         is used when omitted
     '''
     if grid is None:
         grid = xgcm.Grid(dset)
     self.grid = grid
     self.coords = dset.coords.to_dataset().reset_coords()
     self.dset = dset.reset_coords(drop=True)
     # cell volume: drF * hFacC gives the thickness, rA the horizontal area
     cell_thickness = dset.drF * dset.hFacC
     self.volume = cell_thickness * dset.rA
     self.terms = None
Example #19
0
def get_grid(ds: xr.Dataset) -> xgcm.Grid:
    """Build the xgcm grid for a dataset on ``xt_ocean``/``yt_ocean`` coordinates.

    Axes: X (centre ``xt_ocean``, left ``xt_ocean_left``; periodic),
    Y (centre ``yt_ocean``, left ``yt_ocean_left``) and T (centre ``time``).
    """
    # ``xr.Dataset`` is the public name of the class; the internal module
    # path ``xr.core.dataset`` is not part of xarray's public API.
    coords = {
        'X': {
            'center': 'xt_ocean',
            'left': 'xt_ocean_left'
        },
        'Y': {
            'center': 'yt_ocean',
            'left': 'yt_ocean_left'
        },
        'T': {
            'center': 'time'
        }
    }
    grid = xgcm.Grid(ds, periodic=['X'], coords=coords)
    return grid
Example #20
0
def grid(ds):
    """
    XGCM grid of Aus400 variable 'ds'.

    The X axis is centred on ``longitude`` and the Y axis on ``latitude``;
    when ``model_level_number`` is a coordinate of *ds* it becomes the centre
    of a Z axis. No axis is periodic.
    """

    coords = {
        # X is the longitude axis (it must not share latitude with Y).
        "X": {
            "center": "longitude"
        },
        "Y": {
            "center": "latitude"
        },
    }

    if "model_level_number" in ds.coords:
        coords["Z"] = {"center": "model_level_number"}

    return xgcm.Grid(ds, coords=coords, periodic=False)
Example #21
0
def complete_dataset(ds):
    """Add cell areas and volumes as coordinates of *ds* (in place) and return it.

    Areas are ``e1 * e2`` at T, U, V and F points. When ``thkcello`` is a
    variable of *ds* it becomes ``e3t`` and is interpolated (``boundary="fill"``)
    along X, Y and Z to give ``e3u``, ``e3v`` and ``e3w``; otherwise those
    ``e3*`` variables must already be present. Volumes are area times ``e3``.
    """
    xgrid = xgcm.Grid(ds, periodic=["Y", "X"])

    for area_name, e1_name, e2_name in (
        ('tarea', 'e1t', 'e2t'),
        ('uarea', 'e1u', 'e2u'),
        ('varea', 'e1v', 'e2v'),
        ('farea', 'e1f', 'e2f'),
    ):
        ds.coords[area_name] = ds[e1_name] * ds[e2_name]

    if 'thkcello' in ds.variables:
        ds.coords['e3t'] = ds.thkcello
        for e3_name, axis_name in (('e3u', "X"), ('e3v', "Y"), ('e3w', "Z")):
            ds.coords[e3_name] = xgrid.interp(ds.thkcello, axis_name, boundary="fill")

    for vol_name, area_name, e3_name in (
        ('tvol', 'tarea', 'e3t'),
        ('uvol', 'uarea', 'e3u'),
        ('vvol', 'varea', 'e3v'),
        ('wvol', 'tarea', 'e3w'),
    ):
        ds.coords[vol_name] = ds[area_name] * ds[e3_name]

    return ds
Example #22
0
def rotate_llc_to_geo(u, v, grid, face_connections, boundary='extend'):
    """Interpolate a velocity pair to cell centres and rotate it to geographic axes.

    Parameters
    ----------
    u : data array for the velocity along the model X axis
    v : data array for the velocity along the model Y axis
    grid : model grid Dataset, must contain CS and SN (presumably the cosine
        and sine of the local grid orientation angle)
    face_connections : tile topology passed to ``xgcm.Grid``
    boundary : boundary option for ``interp_2d_vector`` (default 'extend')

    Returns
    -------
    u_geo, v_geo : the rotated components ``u*CS - v*SN`` and ``v*CS + u*SN``
    """
    xgrid = xgcm.Grid(grid, face_connections=face_connections)
    centred = xgrid.interp_2d_vector({'X': u, 'Y': v}, boundary=boundary)

    u_c, v_c = centred['X'], centred['Y']
    cos_a, sin_a = grid['CS'], grid['SN']

    u_geo = u_c * cos_a - v_c * sin_a
    v_geo = v_c * cos_a + u_c * sin_a
    return u_geo, v_geo
Example #23
0
def calculate_psi(ds):
    """Calculate the barotropic stream function.

    The depth-integrated zonal transport is accumulated along Y and the
    result is shifted so that it is zero in the upper right corner.

    Parameters
    ----------
    ds : xarray dataset
        A grid-aware dataset as produced by `xorca.lib.preprocess_orca`.

    Returns
    -------
    psi : xarray data array
        A grid-aware data array named ``"psi"`` with the barotropic stream
        function in `[Sv]`.

    """
    xgrid = xgcm.Grid(ds, periodic=["Y", "X"])

    u_barotropic = (ds.vozocrtx * ds.e3u).sum("z_c")

    psi = xgrid.cumsum(-u_barotropic * ds.e2u, "Y") / 1.0e6
    # Reference value: the upper right corner is set to zero.
    psi -= psi.isel(y_r=-1, x_r=-1)
    return psi.rename("psi")
Example #24
0
def test_shift_position_to_T():
    # Shift U, V and W fields to T points one by one and as a list; the
    # single-field results must end up at position 'T'.
    domcfg = xr.open_dataset("data/xnemogcm.domcfg_to.nc")
    nemo_ds = xr.open_dataset("data/xnemogcm.nemo.nc")
    ops = Grid_ops(xgcm.Grid(domcfg, metrics=_metrics, periodic=False))

    velocities = [nemo_ds.uo, nemo_ds.vo, nemo_ds.woce]

    shifted = [ops._shift_position(field, output_position='T') for field in velocities]

    # The list-input path is exercised as well (its result is not checked).
    ops._shift_position(velocities, output_position='T')

    _assert_same_position(ops, shifted, 'T')
Example #25
0
# Basin names, in the order in which ``oceanmasks`` returns its masks.
_LOADGRID_BASINS = ("atlantic", "pacific", "indian", "so", "arctic")


def _add_loadgrid_basin_masks(grd, mask, lon, lat, dims):
    """Add basin and hemisphere sub-masks of ``grd[mask]`` to *grd* in place.

    Parameters
    ----------
    grd : xarray.Dataset
        Grid dataset; modified in place.
    mask : str
        Name of the mask (``"cmask"``, ``"umask"`` or ``"vmask"``); also the
        prefix of the new variables ``<mask>_<basin>``, ``<mask>_nh`` and
        ``<mask>_sh``.
    lon, lat : str
        Names of the longitude / latitude variables at the mask's points.
    dims : tuple of str
        ``(x_dim, y_dim, z_dim)`` of the mask. The hemisphere split uses the
        sign of the ``y_dim`` coordinate.
    """
    horizontal = dims[:2]
    # Explicit unpacking keeps the "exactly five masks" check of the original.
    atlantic, pacific, indian, so, arctic = oceanmasks(
        grd[lon].transpose(*horizontal).data,
        grd[lat].transpose(*horizontal).data,
        grd[mask].transpose(*dims).data)
    coords = [grd[dim].data for dim in dims]
    for basin, basin_mask in zip(_LOADGRID_BASINS,
                                 (atlantic, pacific, indian, so, arctic)):
        grd[mask + '_' + basin] = xr.DataArray(
            basin_mask, coords=coords, dims=list(dims))
    y_dim = dims[1]
    grd[mask + '_nh'] = grd[mask].where(grd.coords[y_dim] > 0)
    grd[mask + '_sh'] = grd[mask].where(grd.coords[y_dim] <= 0)


def loadgrid(fname='grid.glob.nc', basin_masks=True, chunking=None):
    """Load a netCDF grid file and build an xgcm grid with metrics from it.

    The file must use MITgcm-like variable names (``XC``, ``YC``, ``XG``,
    ``YG``, ``HFacC``/``HFacW``/``HFacS``, ``drF``, ``R_low``,
    ``rA``/``rAs``/``rAw``/``rAz``, ``dxC``/``dyC``/``dxV`` ...). Derived
    variables are added: point positions (``lonc``, ``latu`` ...), masks
    (``cmask``, ``umask``, ``vmask``), ``depth``, vertical cell sizes
    (``dzC``, ``dzW``, ``dzS``) and cell volumes (``cvol``, ``uvol``, ``vvol``).

    Parameters
    ----------
    fname : str or path-like
        Name of the netCDF grid file.
    basin_masks : bool
        If True, also add basin masks (``<mask>_atlantic``, ``_pacific``,
        ``_indian``, ``_so``, ``_arctic``) and hemisphere masks
        (``<mask>_nh``, ``<mask>_sh``) for ``cmask``, ``umask`` and ``vmask``.
    chunking : dict, optional
        Passed as ``chunks`` to ``xarray.open_dataset``.

    Returns
    -------
    grd : xarray.Dataset
        The augmented grid dataset, with axes conformed by ``conform_axes``.
    xgrd : xgcm.Grid
        Grid periodic in X and Y, with distance, thickness and area metrics.
    """
    grd = xr.open_dataset(fname, chunks=chunking)

    if "T" in grd.coords:
        grd = grd.squeeze('T')

    # Preserve these arrays (XC, YC, XG, YG are dropped further down)
    grd['lonc'] = grd.XC
    grd['latc'] = grd.YC
    grd['lonu'] = grd.dxC * 0 + grd.XG[0, :]
    grd['latu'] = grd.dxC * 0 + grd.YC[:, 0]
    grd['lonv'] = grd.dyC * 0 + grd.XC[0, :]
    grd['latv'] = grd.dyC * 0 + grd.YG[:, 0]
    grd['lonz'] = grd.dxV * 0 + grd.XG[0, :]
    grd['latz'] = grd.dxV * 0 + grd.YG[:, 0]

    grd['cmask'] = grd.HFacC.where(grd.HFacC > grd.HFacC.min())
    grd['umask'] = grd.HFacW.where(grd.HFacW > grd.HFacW.min())
    grd['vmask'] = grd.HFacS.where(grd.HFacS > grd.HFacS.min())
    grd['depth'] = (grd.R_low * grd.cmask)

    grd['dzC'] = grd.HFacC * grd.drF  #vertical cell size at tracer point
    grd['dzW'] = grd.HFacW * grd.drF  #vertical cell size at u point
    grd['dzS'] = grd.HFacS * grd.drF  #vertical cell size at v point

    # Reshape axes to have the same dimensions
    grd['cvol'] = (grd.HFacC * grd.rA *
                   grd.drF).where(grd.HFacC >= grd.HFacC.min())
    grd['uvol'] = (grd.HFacW * grd.rAw *
                   grd.drF).where(grd.HFacW >= grd.HFacW.min())
    grd['vvol'] = (grd.HFacS * grd.rAs *
                   grd.drF).where(grd.HFacS >= grd.HFacS.min())

    if basin_masks:
        # Get basin masks at tracer, u and v points
        _add_loadgrid_basin_masks(grd, 'cmask', 'lonc', 'latc',
                                  ('X', 'Y', 'Z'))
        _add_loadgrid_basin_masks(grd, 'umask', 'lonu', 'latu',
                                  ('Xp1', 'Y', 'Z'))
        _add_loadgrid_basin_masks(grd, 'vmask', 'lonv', 'latv',
                                  ('X', 'Yp1', 'Z'))

    grd.close()

    # These variables conflict with future axis names.
    # ``drop_vars`` replaces the deprecated ``Dataset.drop``.
    grd = grd.drop_vars(['XC', 'YC', 'XG', 'YG'])

    # Attempt to conform axes to conventions
    grd = conform_axes(grd)

    # generate XGCM grid, with metrics for grid aware calculations
    # Have to make sure the metrics are properly masked
    # issue for area, but not volume...
    grd['rA'] = grd.rA * grd.HFacC.isel(ZC=0)
    grd['rAs'] = grd.rAs * grd.HFacS.isel(ZC=0)
    grd['rAw'] = grd.rAw * grd.HFacW.isel(ZC=0)
    # This is dodgy, but not sure what else to do...
    grd['rAz'] = grd.rAz * (grd.HFacW * grd.HFacS).isel(ZC=0)

    metrics = {
        ('X', ): ['dxC', 'dxG'],  # X distances
        ('Y', ): ['dyC', 'dyG'],  # Y distances
        ('Z', ): ['dzW', 'dzS', 'dzC'],  # Z distances
        ('X', 'Y'): ['rA', 'rAz', 'rAs', 'rAw']  # Areas
    }
    xgrd = xgcm.Grid(grd, periodic=['X', 'Y'], metrics=metrics)

    return grd, xgrd
Example #26
0
def get_llc_grid(ds,domain='global'):
    """
    Define xgcm Grid object for the LLC grid
    See example usage in the xgcm documentation:
    https://xgcm.readthedocs.io/en/latest/example_eccov4.html#Spatially-Integrated-Heat-Content-Anomaly

    Parameters
    ----------
    ds : xarray Dataset
        formed from LLC90 grid, must have the basic coordinates:
        i,j,i_g,j_g,k,k_l,k_u,k_p1
    domain : str, optional
        Tile layout: 'global' (13 tiles, numbered 0-12) or 'aste' (6 tiles,
        numbered 0-5). If ``ds.attrs`` has a 'domain' entry, that value takes
        precedence over this argument.

    Returns
    -------
    grid : xgcm Grid object
        defines horizontal connections between LLC tiles; no axis is
        periodic, all inter-tile links come from ``face_connections``

    Raises
    ------
    TypeError
        If the domain is neither 'global' nor 'aste'.

    """
    # The dataset's own 'domain' attribute overrides the argument.
    if 'domain' in ds.attrs:
        domain = ds.attrs['domain']

    if domain == 'global':
        # Establish grid topology
        # Layout (xgcm face_connections format, see the xgcm docs):
        # {face_dim: {face: {axis: (left_neighbour, right_neighbour)}}},
        # each neighbour being a (face, axis, reverse) tuple, or None when the
        # edge has no neighbouring tile.
        tile_connections = {'tile':  {
                0: {'X': ((12, 'Y', False), (3, 'X', False)),
                    'Y': (None, (1, 'Y', False))},
                1: {'X': ((11, 'Y', False), (4, 'X', False)),
                    'Y': ((0, 'Y', False), (2, 'Y', False))},
                2: {'X': ((10, 'Y', False), (5, 'X', False)),
                    'Y': ((1, 'Y', False), (6, 'X', False))},
                3: {'X': ((0, 'X', False), (9, 'Y', False)),
                    'Y': (None, (4, 'Y', False))},
                4: {'X': ((1, 'X', False), (8, 'Y', False)),
                    'Y': ((3, 'Y', False), (5, 'Y', False))},
                5: {'X': ((2, 'X', False), (7, 'Y', False)),
                    'Y': ((4, 'Y', False), (6, 'Y', False))},
                6: {'X': ((2, 'Y', False), (7, 'X', False)),
                    'Y': ((5, 'Y', False), (10, 'X', False))},
                7: {'X': ((6, 'X', False), (8, 'X', False)),
                    'Y': ((5, 'X', False), (10, 'Y', False))},
                8: {'X': ((7, 'X', False), (9, 'X', False)),
                    'Y': ((4, 'X', False), (11, 'Y', False))},
                9: {'X': ((8, 'X', False), None),
                    'Y': ((3, 'X', False), (12, 'Y', False))},
                10: {'X': ((6, 'Y', False), (11, 'X', False)),
                     'Y': ((7, 'Y', False), (2, 'X', False))},
                11: {'X': ((10, 'X', False), (12, 'X', False)),
                     'Y': ((8, 'Y', False), (1, 'X', False))},
                12: {'X': ((11, 'X', False), None),
                     'Y': ((9, 'Y', False), (0, 'X', False))}
        }}

        grid = xgcm.Grid(ds,
                periodic=False,
                face_connections=tile_connections
        )
    elif domain == 'aste':
        # Same face_connections layout as above, for the 6 ASTE tiles.
        tile_connections = {'tile':{
                    0:{'X':((5,'Y',False),None),
                       'Y':(None,(1,'Y',False))},
                    1:{'X':((4,'Y',False),None),
                       'Y':((0,'Y',False),(2,'X',False))},
                    2:{'X':((1,'Y',False),(3,'X',False)),
                       'Y':(None,(4,'X',False))},
                    3:{'X':((2,'X',False),None),
                       'Y':(None,None)},
                    4:{'X':((2,'Y',False),(5,'X',False)),
                       'Y':(None,(1,'X',False))},
                    5:{'X':((4,'X',False),None),
                       'Y':(None,(0,'X',False))}
                   }}
        grid = xgcm.Grid(ds,periodic=False,face_connections=tile_connections)
    else:
        # NOTE(review): an unknown domain is a bad value, so ValueError would be
        # more conventional; TypeError is kept because callers may catch it.
        raise TypeError(f'Domain {domain} not recognized')


    return grid
Example #27
0
def rebuild_grid(grid,
                 x_index_name='i',
                 y_index_name='j',
                 x_name='X',
                 y_name='Y',
                 g_index_suffix='_g',
                 g_suffix='G',
                 c_suffix='C',
                 g_shift=-0.5,
                 x_wrap=360,
                 y_wrap=180,
                 ll_dist=True):
    """Rebuild an xgcm-compatible grid on `grid` from its cell-center positions.

    `grid` is modified in place and also returned.

    Steps:

    1. For each index dimension (`x_index_name`, `y_index_name`) add a
       staggered twin named ``<index><g_index_suffix>`` holding the same
       values. Tag all four index coordinates with comodo attributes
       (``axis``, plus ``c_grid_axis_shift=g_shift`` on the staggered ones) so
       ``xgcm.Grid`` can discover the axes.
    2. Build the staggered positions ``<x_name><g_suffix>`` and
       ``<y_name><g_suffix>``:
       - interpolate the center positions along their own axis with
         ``wrap_func(..., func='interp', idx=0)``;
       - then interpolate along the other axis with ``xgcm.Grid.interp``.
    3. Compute ``dx``/``dy`` for centers (``c_suffix``) and staggered points
       (``g_suffix``) with ``wrap_func``. If `ll_dist` is true, pass each
       pair through ``dll_dist`` (presumably a lon/lat-to-distance
       conversion; verify against its definition).

    Parameters
    ----------
    grid : xarray.Dataset
        Must hold index coordinates `x_index_name`, `y_index_name` and center
        positions ``x_name + c_suffix``, ``y_name + c_suffix``.
    x_index_name, y_index_name : str
        Names of the center index dimensions.
    x_name, y_name : str
        Prefixes of the position coordinates.
    g_index_suffix : str
        Suffix appended to the index names for the staggered dimensions.
    g_suffix, c_suffix : str
        Suffixes of the staggered / center position and length coordinates.
    g_shift : float
        Value written to ``c_grid_axis_shift`` on the staggered index dims.
    x_wrap, y_wrap : number
        Wrap values forwarded to ``wrap_func`` for the X and Y axes.
    ll_dist : bool
        Whether to post-process the lengths with ``dll_dist``.

    Returns
    -------
    xarray.Dataset
        The same object as `grid`, with the new coordinates added.
    """
    # 1. Staggered index dimensions: same values as the centers, new names.
    for center_dim in (x_index_name, y_index_name):
        staggered_dim = center_dim + g_index_suffix
        index_values = grid.coords[center_dim].data
        grid.coords[staggered_dim] = xr.DataArray(
            index_values,
            coords={staggered_dim: ([staggered_dim], index_values)},
            dims=[staggered_dim])

    # Comodo attributes so that xgcm.Grid can detect the X/Y axes.
    grid[x_index_name].attrs = {
        'axis': 'X',
        'standard_name': 'x_grid_index',
        'long_name': 'x-dimension of the grid',
    }
    grid[y_index_name].attrs = {
        'axis': 'Y',
        'standard_name': 'y_grid_index',
        'long_name': 'y-dimension of the grid',
    }
    grid[x_index_name + g_index_suffix].attrs = {
        'axis': 'X',
        'standard_name': 'x_grid_index_at_u_location',
        'long_name': 'x-dimension of the grid',
        'c_grid_axis_shift': g_shift,
    }
    grid[y_index_name + g_index_suffix].attrs = {
        'axis': 'Y',
        'standard_name': 'y_grid_index_at_v_location',
        'long_name': 'y-dimension of the grid',
        'c_grid_axis_shift': g_shift,
    }
    xgrid = xgcm.Grid(grid)

    # Names of every derived coordinate.
    x_center, y_center = x_name + c_suffix, y_name + c_suffix
    x_stag, y_stag = x_name + g_suffix, y_name + g_suffix
    dx_center, dy_center = 'dx' + c_suffix, 'dy' + c_suffix
    dx_stag, dy_stag = 'dx' + g_suffix, 'dy' + g_suffix

    # 2. Staggered positions. First interpolate along each position's own
    # axis (stored immediately), then along the other axis (overwriting).
    # Coordinates are re-fetched from `grid` at every use on purpose: each
    # assignment above can change which coords a fetched DataArray carries.
    x_half = wrap_func(xgrid, grid.coords[x_center], 'X', x_wrap,
                       func='interp', idx=0)
    grid.coords[x_stag] = x_half
    y_half = wrap_func(xgrid, grid.coords[y_center], 'Y', y_wrap,
                       func='interp', idx=0)
    grid.coords[y_stag] = y_half

    grid.coords[x_stag] = xgrid.interp(x_half, 'Y')
    grid.coords[y_stag] = xgrid.interp(y_half, 'X')

    # 3. Cell lengths at centers (idx=0) and staggered points (idx=-1).
    grid.coords[dx_center] = wrap_func(xgrid, grid.coords[x_center],
                                       'X', x_wrap, idx=0)
    grid.coords[dy_center] = wrap_func(xgrid, grid.coords[y_center],
                                       'Y', y_wrap, idx=0)
    grid.coords[dx_stag] = wrap_func(xgrid, grid.coords[x_stag],
                                     'X', x_wrap, idx=-1)
    grid.coords[dy_stag] = wrap_func(xgrid, grid.coords[y_stag],
                                     'Y', y_wrap, idx=-1)

    if ll_dist:
        for dx_key, dy_key, x_key, y_key in (
                (dx_center, dy_center, x_center, y_center),
                (dx_stag, dy_stag, x_stag, y_stag)):
            grid.coords[dx_key], grid.coords[dy_key] = dll_dist(
                grid.coords[dx_key], grid.coords[dy_key],
                grid.coords[x_key], grid.coords[y_key])
    return grid
Example #28
0
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import xgcm
import cartopy.crs as ccrs

from xmitgcm import open_mdsdataset
from matplotlib.mlab import bivariate_normal
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER

# MITgcm run directory for this experiment (hard-coded).
dir0 = '/homedata/bderembl/runmit/test_southatlgyre3'

# Load every available iteration of the U and V output fields.
ds0 = open_mdsdataset(dir0, iters='all', prefix=['U', 'V'])

grid = xgcm.Grid(ds0)
print(grid)

# Relative vorticity: (d(V*dyC)/dX - d(U*dxC)/dY) / rAz.
# Velocities are masked where hFacW / hFacS are 0 (presumably closed/land
# faces in MITgcm).
Vorticity = (
    -grid.diff(ds0.U.where(ds0.hFacW > 0) * ds0.dxC, 'Y')
    + grid.diff(ds0.V.where(ds0.hFacS > 0) * ds0.dyC, 'X')
) / ds0.rAz
print('Vorticity')

# Vertical index to plot (second index of Vorticity; the index order is
# presumably time, depth, y, x -- confirm against the dataset).
nz = 0

# Plot indices 1..150 of the first dimension; index 0 is never plotted.
# Every frame is drawn into figure 1.
for i in range(1, 151):
    print(i)
    plt.figure(1)
    ax = plt.subplot(projection=ccrs.PlateCarree())
    Vorticity[i, nz, :, :].plot.pcolormesh(
        'XG', 'YG', ax=ax, vmin=-0.00020, vmax=0.00020, cmap='ocean'
    )
    plt.title('Case 4 : Vorticity')
Example #29
0
def roms_dataset(ds, Vtransform=None, add_verts=False, proj=None):
    """Modify Dataset to be aware of ROMS coordinates, with matching xgcm grid object.

    Inputs
    ------
    ds: Dataset
        xarray Dataset with model output
    Vtransform: int, optional
        Vertical transform for ROMS model. Should be either 1 or 2 and only needs
        to be input if not available in ds. A ``Vtransform`` variable present in
        `ds` takes precedence over this argument.
    add_verts: boolean, optional
        Add 'verts' horizontal grid to ds if True. This requires a cartopy projection
        to be input too, and the `pygridgen` package to be installed.
    proj: cartopy crs projection, optional
        Should match geographic area of model domain. Required if `add_verts=True`,
        otherwise not used. Example:
        >>> proj = cartopy.crs.LambertConformal(central_longitude=-98, central_latitude=30)

    Returns
    -------
    ds: Dataset
        Same dataset as input, but with dimensions renamed to be consistent with `xgcm` and
        with vertical coordinates and metrics added.
    grid: xgcm grid object
        Includes ROMS metrics so can be used for xgcm grid operations, which mostly have
        been wrapped into xroms.

    Raises
    ------
    AssertionError
        If `add_verts` is True but `proj` is None, or if no Vtransform is
        available from either `ds` or the argument.
    ValueError
        If the Vtransform is neither 1 nor 2.

    Notes
    -----
    Note that this could be very slow if dask is not on.

    This does not need to be run by the user if `xroms` functions `open_netcdf` or
    `open_zarr` are used for reading in model output, since run in those functions.

    This also uses `cf-xarray` to manage dimensions of variables.

    The final xgcm grid is stored in ``attrs["grid"]`` of every variable with
    more than one dimension. It is not stored in ``ds.attrs`` because that
    caused a recursion error.

    Example usage
    -------------
    >>> ds, grid = xroms.roms_dataset(ds)
    """

    if add_verts:
        assert proj is not None, 'To add "verts" grid, input projection "proj".'

    # Merge ROMS dimension aliases so each staggered location uses one
    # dimension name per axis (e.g. eta_u -> eta_rho, xi_psi -> xi_u).
    aliases = {
        "eta_u": "eta_rho",
        "xi_v": "xi_rho",
        "xi_psi": "xi_u",
        "eta_psi": "eta_v",
    }
    ds = ds.rename({old: new for old, new in aliases.items() if old in ds.dims})

    # make sure psi grid in coords
    ds = ds.assign_coords({"lon_psi": ds.lon_psi, "lat_psi": ds.lat_psi})

    # modify attributes for using cf-xarray: integer index coordinates tagged
    # with their axis for the horizontal dims, plus T and Z axis tags.
    tdims = [dim for dim in ds.dims if dim[:3] == "xi_"]
    for dim in tdims:
        ds[dim] = (dim, np.arange(ds.sizes[dim]), {"axis": "X"})
    tdims = [dim for dim in ds.dims if dim[:4] == "eta_"]
    for dim in tdims:
        ds[dim] = (dim, np.arange(ds.sizes[dim]), {"axis": "Y"})
    ds.ocean_time.attrs["axis"] = "T"
    ds.ocean_time.attrs["standard_name"] = "time"
    tcoords = [coord for coord in ds.coords if coord[:2] == "s_"]
    for coord in tcoords:
        ds[coord].attrs["axis"] = "Z"
    # make sure lon/lat have standard names
    tcoords = [coord for coord in ds.coords if coord[:4] == "lon_"]
    for coord in tcoords:
        ds[coord].attrs["standard_name"] = "longitude"
    tcoords = [coord for coord in ds.coords if coord[:4] == "lat_"]
    for coord in tcoords:
        ds[coord].attrs["standard_name"] = "latitude"

    coords = {
        "X": {"center": "xi_rho", "inner": "xi_u"},
        "Y": {"center": "eta_rho", "inner": "eta_v"},
        "Z": {"center": "s_rho", "outer": "s_w"},
    }

    # Provisional grid (no metrics yet), used below to interpolate and
    # difference the vertical coordinates and horizontal scale factors.
    grid = xgcm.Grid(ds, coords=coords, periodic=[])

    if "Vtransform" in ds.variables.keys():
        Vtransform = ds.Vtransform

    assert (
        Vtransform is not None
    ), "Need a Vtransform of 1 or 2, either in the Dataset or input to the function."

    if Vtransform == 1:
        Zo_rho = ds.hc * (ds.s_rho - ds.Cs_r) + ds.Cs_r * ds.h
        z_rho = Zo_rho + ds.zeta * (1 + Zo_rho / ds.h)
        Zo_w = ds.hc * (ds.s_w - ds.Cs_w) + ds.Cs_w * ds.h
        z_w = Zo_w + ds.zeta * (1 + Zo_w / ds.h)
        # also include z coordinates with mean sea level (constant over time)
        z_rho0 = Zo_rho
        z_w0 = Zo_w
    elif Vtransform == 2:
        Zo_rho = (ds.hc * ds.s_rho + ds.Cs_r * ds.h) / (ds.hc + ds.h)
        z_rho = ds.zeta + (ds.zeta + ds.h) * Zo_rho
        Zo_w = (ds.hc * ds.s_w + ds.Cs_w * ds.h) / (ds.hc + ds.h)
        z_w = ds.zeta + (ds.zeta + ds.h) * Zo_w
        # also include z coordinates with mean sea level (constant over time)
        z_rho0 = ds.h * Zo_rho
        z_w0 = ds.h * Zo_w
    else:
        # Any other value would leave z_rho/z_w undefined further down.
        raise ValueError(f"Vtransform must be 1 or 2, got {Vtransform!r}.")

    def _cf_order(da):
        # Transpose to (T, Z, Y, X) order, keeping only the axes `da` has.
        return da.cf.transpose(
            *[dim for dim in ["T", "Z", "Y", "X"] if dim in da.cf.get_valid_keys()]
        )

    ds.coords["z_w"] = _cf_order(z_w)
    ds.coords["z_w_u"] = grid.interp(ds.z_w, "X")
    ds.coords["z_w_v"] = grid.interp(ds.z_w, "Y")
    ds.coords["z_w_psi"] = grid.interp(ds.z_w_u, "Y")

    ds.coords["z_rho"] = _cf_order(z_rho)
    ds.coords["z_rho_u"] = grid.interp(ds.z_rho, "X")
    ds.coords["z_rho_v"] = grid.interp(ds.z_rho, "Y")
    ds.coords["z_rho_psi"] = grid.interp(ds.z_rho_u, "Y")
    # also include z coordinates with mean sea level (constant over time)
    ds.coords["z_rho0"] = _cf_order(z_rho0)
    ds.coords["z_rho_u0"] = grid.interp(ds.z_rho0, "X")
    ds.coords["z_rho_v0"] = grid.interp(ds.z_rho0, "Y")
    ds.coords["z_rho_psi0"] = grid.interp(ds.z_rho_u0, "Y")
    ds.coords["z_w0"] = _cf_order(z_w0)
    ds.coords["z_w_u0"] = grid.interp(ds.z_w0, "X")
    ds.coords["z_w_v0"] = grid.interp(ds.z_w0, "Y")
    ds.coords["z_w_psi0"] = grid.interp(ds.z_w_u0, "Y")

    # add vert grid, esp for plotting pcolormesh
    if add_verts:
        import cartopy.crs
        import pygridgen

        pc = cartopy.crs.PlateCarree()
        # project rho points for this calculation (names avoid shadowing the
        # `xr` xarray alias)
        x_rho, y_rho = proj.transform_points(pc, ds.lon_rho.values, ds.lat_rho.values)[
            ..., :2
        ].T
        x_rho = x_rho.T
        y_rho = y_rho.T
        # calculate vert locations
        xv, yv = pygridgen.grid.rho_to_vert(x_rho, y_rho, ds.pm, ds.pn, ds.angle)
        # project back
        lon_vert, lat_vert = pc.transform_points(proj, xv, yv)[..., :2].T
        lon_vert = lon_vert.T
        lat_vert = lat_vert.T
        # add new coords to ds
        ds.coords["lon_vert"] = (("eta_vert", "xi_vert"), lon_vert)
        ds.coords["lat_vert"] = (("eta_vert", "xi_vert"), lat_vert)

    ds["pm_v"] = grid.interp(ds.pm, "Y")
    ds["pn_u"] = grid.interp(ds.pn, "X")
    ds["pm_u"] = grid.interp(ds.pm, "X")
    ds["pn_v"] = grid.interp(ds.pn, "Y")
    ds["pm_psi"] = grid.interp(
        grid.interp(ds.pm, "Y"), "X"
    )  # at psi points (eta_v, xi_u)
    ds["pn_psi"] = grid.interp(
        grid.interp(ds.pn, "X"), "Y"
    )  # at psi points (eta_v, xi_u)

    # Horizontal grid spacings are the reciprocals of pm / pn.
    ds["dx"] = 1 / ds.pm
    ds["dx_u"] = 1 / ds.pm_u
    ds["dx_v"] = 1 / ds.pm_v
    ds["dx_psi"] = 1 / ds.pm_psi

    ds["dy"] = 1 / ds.pn
    ds["dy_u"] = 1 / ds.pn_u
    ds["dy_v"] = 1 / ds.pn_v
    ds["dy_psi"] = 1 / ds.pn_psi

    ds["dz"] = grid.diff(ds.z_w, "Z")
    ds["dz_w"] = grid.diff(ds.z_rho, "Z", boundary="fill")
    ds["dz_u"] = grid.interp(ds.dz, "X")
    ds["dz_w_u"] = grid.interp(ds.dz_w, "X")
    ds["dz_v"] = grid.interp(ds.dz, "Y")
    ds["dz_w_v"] = grid.interp(ds.dz_w, "Y")
    ds["dz_psi"] = grid.interp(ds.dz_v, "X")
    ds["dz_w_psi"] = grid.interp(ds.dz_w_v, "X")

    # also include z coordinates with mean sea level (constant over time)
    ds["dz0"] = grid.diff(ds.z_w0, "Z")
    ds["dz_w0"] = grid.diff(ds.z_rho0, "Z", boundary="fill")
    ds["dz_u0"] = grid.interp(ds.dz0, "X")
    ds["dz_w_u0"] = grid.interp(ds.dz_w0, "X")
    ds["dz_v0"] = grid.interp(ds.dz0, "Y")
    ds["dz_w_v0"] = grid.interp(ds.dz_w0, "Y")
    ds["dz_psi0"] = grid.interp(ds.dz_v0, "X")
    ds["dz_w_psi0"] = grid.interp(ds.dz_w_v0, "X")

    # grid areas
    ds["dA"] = ds.dx * ds.dy
    ds["dA_u"] = ds.dx_u * ds.dy_u
    ds["dA_v"] = ds.dx_v * ds.dy_v
    ds["dA_psi"] = ds.dx_psi * ds.dy_psi

    # volume
    ds["dV"] = ds.dz * ds.dx * ds.dy  # rho vertical, rho horizontal
    ds["dV_w"] = ds.dz_w * ds.dx * ds.dy  # w vertical, rho horizontal
    ds["dV_u"] = ds.dz_u * ds.dx_u * ds.dy_u  # rho vertical, u horizontal
    ds["dV_w_u"] = ds.dz_w_u * ds.dx_u * ds.dy_u  # w vertical, u horizontal
    ds["dV_v"] = ds.dz_v * ds.dx_v * ds.dy_v  # rho vertical, v horizontal
    ds["dV_w_v"] = ds.dz_w_v * ds.dx_v * ds.dy_v  # w vertical, v horizontal
    ds["dV_psi"] = ds.dz_psi * ds.dx_psi * ds.dy_psi  # rho vertical, psi horizontal
    ds["dV_w_psi"] = ds.dz_w_psi * ds.dx_psi * ds.dy_psi  # w vertical, psi horizontal

    if "rho0" not in ds:
        ds["rho0"] = 1025  # kg/m^3

    # TODO: cf-xarray "cell_measures" attributes (cell_area / cell_volume) were
    # sketched here previously but are not enabled.

    # Drop "positive" from the s-coordinates and mark the time-varying z
    # coordinates (names without "0") as positive up.
    ds.s_rho.attrs.pop("positive", None)
    ds.s_w.attrs.pop("positive", None)
    tcoords = [coord for coord in ds.coords if coord[:2] == "z_" and "0" not in coord]
    for coord in tcoords:
        ds[coord].attrs["positive"] = "up"

    # Remove stale "coordinates" encodings carried over from the input file.
    for var in ds.variables:
        if "coordinates" in ds[var].encoding:
            del ds[var].encoding["coordinates"]

    metrics = {
        ("X",): ["dx", "dx_u", "dx_v", "dx_psi"],  # X distances
        ("Y",): ["dy", "dy_u", "dy_v", "dy_psi"],  # Y distances
        ("Z",): [
            "dz",
            "dz_u",
            "dz_v",
            "dz_w",
            "dz_w_u",
            "dz_w_v",
            "dz_psi",
            "dz_w_psi",
        ],  # Z distances
        ("X", "Y"): ["dA"],  # Areas
    }
    grid = xgcm.Grid(ds, coords=coords, metrics=metrics, periodic=[])

    # Storing the grid in ds.attrs causes a recursion error, so instead put
    # it into every variable with at least 2D.
    for var in ds.variables:
        if ds[var].ndim > 1:
            ds[var].attrs["grid"] = grid

    return ds, grid
Example #30
0
def calculate_AMOC_sigma_z(domain, ds, fn=None):
    """Calculate the AMOC in depth and density space.

    Parameters
    ----------
    domain : str
        POP domain, 'ocn' or 'ocn_low'. Passed to ``xr_DZ_xgcm`` and used in
        the file name of the cached AMOC region mask.
    ds : xarray.Dataset
        Model output containing 'PD', 'VVEL', 'DXT', 'DYT', 'DXU', 'DYU' and
        'REGION_MASK' (its 'z_t' coordinate is also read).
    fn : str, optional
        If given, both results are merged and written to this netCDF file.

    Returns
    -------
    AMOC_yz : xarray.DataArray
        Overturning streamfunction over (y, z), converted to Sv.
    AMOC_sz : xarray.DataArray
        Overturning streamfunction over (y, PD bins), converted to Sv.

    Raises
    ------
    AssertionError
        If `domain` is not recognized or a required variable is missing.
    ValueError
        If the surface-mean PD truncates to neither 0 nor 1, so the PD
        rescaling (and bin range) cannot be chosen.

    Notes
    -----
    The region mask interpolated to U points is cached at
    ``{path_prace}/MOC/AMOC_MASK_uu_{domain}.nc`` and reused when present.
    """
    assert domain in ['ocn', 'ocn_low']
    for q in ['PD', 'VVEL', 'DXT', 'DYT', 'DXU', 'DYU', 'REGION_MASK']:
        assert q in ds

    (grid, ds_) = pop_tools.to_xgcm_grid_dataset(ds)
    ds_['DZU'] = xr_DZ_xgcm(domain=domain, grid='U')

    # xgcm metric keys are tuples of axis names; without the trailing comma
    # ('X') is just the string 'X'.
    metrics = {
        ('X',): ['DXT', 'DXU'],  # X distances
        ('Y',): ['DYT', 'DYU'],  # Y distances
        ('Z',): ['DZU'],  # Z distances
    }
    coords = {
        'X': {
            'center': 'nlon_t',
            'right': 'nlon_u'
        },
        'Y': {
            'center': 'nlat_t',
            'right': 'nlat_u'
        },
        'Z': {
            'center': 'z_t',
            'left': 'z_w_top',
            'right': 'z_w_bot'
        }
    }
    # Replace the pop_tools grid with one carrying explicit coords and metrics.
    grid = xgcm.Grid(ds_, metrics=metrics, coords=coords)

    print('merged annual datasets do not convert to U/T-lat/lons')
    if 'nlat' in ds_.VVEL.dims:
        rn = {'nlat': 'nlat_u', 'nlon': 'nlon_u'}
        ac = {'nlat_u': ds_.nlat_u, 'nlon_u': ds_.nlon_u}
        # Attach the U-point coordinates, mirroring the PD branch below.
        ds_['VVEL'] = ds_.VVEL.rename(rn).assign_coords(ac)
    if 'nlat' in ds_.PD.dims:
        rn = {'nlat': 'nlat_t', 'nlon': 'nlon_t'}
        ac = {'nlat_t': ds_.nlat_t, 'nlon_t': ds_.nlon_t}
        ds_['PD'] = ds_.PD.rename(rn).assign_coords(ac)

    print('interpolating density to UU point')
    ds_['PD'] = grid.interp(grid.interp(ds_['PD'], 'X'), 'Y')

    print('interpolating REGION_MASK to UU point')
    fn_MASK = f'{path_prace}/MOC/AMOC_MASK_uu_{domain}.nc'
    if os.path.exists(fn_MASK):
        AMOC_MASK_uu = xr.open_dataarray(fn_MASK)
    else:
        MASK_uu = grid.interp(grid.interp(ds_.REGION_MASK, 'Y'), 'X')
        # True where the interpolated REGION_MASK equals one of these region codes.
        AMOC_MASK_uu = xr.DataArray(
            np.isin(MASK_uu, [-12, 6, 7, 8, 9, 11, 12]),
            dims=MASK_uu.dims,
            coords=MASK_uu.coords,
        )
        AMOC_MASK_uu.to_netcdf(fn_MASK)

    print('AMOC(y,z);  [cm^3/s] -> [Sv]')
    AMOC_yz = (grid.integrate(
        grid.cumint(ds_.VVEL.where(AMOC_MASK_uu), 'Z', boundary='fill'), 'X') /
               1e12)
    AMOC_yz = AMOC_yz.rename({'z_w_top': 'z_t'}).assign_coords({'z_t': ds.z_t})
    AMOC_yz.name = 'AMOC(y,z)'

    print('AMOC(sigma_0,z);  [cm^3/s] -> [Sv]')
    # Choose the PD rescaling and bin range from the truncated surface mean.
    surface_PD = int(ds_.PD.isel(z_t=0).mean().values)
    if surface_PD == 0:
        PD, PDbins = ds_.PD * 1000, np.arange(-10, 7, .05)
    elif surface_PD == 1:
        PD, PDbins = (ds_.PD - 1) * 1000, np.arange(5, 33, .05)
    else:
        raise ValueError(
            f'cannot infer PD units: surface mean truncates to {surface_PD}, '
            'expected 0 or 1')

    print('histogram')
    weights = ds_.VVEL.where(AMOC_MASK_uu) * ds_.DZU * ds_.DXU / 1e12
    AMOC_sz = histogram(PD, bins=[PDbins], dim=['z_t'],
                        weights=weights).sum('nlon_u',
                                             skipna=True).cumsum('PD_bin').T
    AMOC_sz.name = 'AMOC(y,PD)'

    # output to file
    if fn is not None:
        xr.merge([AMOC_yz, AMOC_sz]).to_netcdf(fn)
    return AMOC_yz, AMOC_sz