コード例 #1
0
ファイル: test_grid.py プロジェクト: jamesp/xgcm
def test_grid_ops(all_datasets):
    """
    Verify that ``Grid`` methods and the matching ``Axis`` methods agree.

    For every axis of the grid, ``interp`` and ``diff`` are compared for both
    ``data_c`` and ``data_g`` under each applicable boundary condition
    (``None`` for periodic axes, ``'fill'``/``'extend'`` otherwise);
    ``cumsum`` is compared only when a boundary condition is given.
    """
    ds, periodic, expected = all_datasets
    grid = Grid(ds, periodic=periodic)

    for axis_name in grid.axes.keys():
        # ``periodic`` is either a container of axis names or a plain bool;
        # a bool does not support ``in`` and raises TypeError.
        try:
            ax_periodic = axis_name in periodic
        except TypeError:
            ax_periodic = periodic
        axis = Axis(ds, axis_name, periodic=ax_periodic)

        if ax_periodic:
            boundaries = [None]
        else:
            boundaries = ['fill', 'extend']

        for varname in ['data_c', 'data_g']:
            for bc in boundaries:
                assert grid.interp(ds[varname], axis_name,
                                   boundary=bc).equals(
                    axis.interp(ds[varname], boundary=bc))
                assert grid.diff(ds[varname], axis_name,
                                 boundary=bc).equals(
                    axis.diff(ds[varname], boundary=bc))
                if bc is None:
                    continue
                assert grid.cumsum(ds[varname], axis_name,
                                   boundary=bc).equals(
                    axis.cumsum(ds[varname], boundary=bc))
コード例 #2
0
ファイル: test_grid.py プロジェクト: jamesp/xgcm
def test_axis_diff_and_interp_nonperiodic_1d(nonperiodic_1d, boundary, from_center):
    """
    Check ``Axis.interp`` and ``Axis.diff`` on a non-periodic 1D axis.

    The expected left/right neighbors are built by hand from the raw data,
    padding with ``_pad_left``/``_pad_right`` where the target grid needs a
    value beyond the domain edge. Results are checked both with an explicit
    ``to`` position and with the axis' default target position.
    """
    ds, periodic, expected = nonperiodic_1d
    axis = Axis(ds, 'X', periodic=periodic)

    # +1: XG has one more point than XC (outer); -1: one fewer (inner);
    # 0: same length (shifted grid)
    dim_len_diff = len(ds.XG) - len(ds.XC)

    if from_center:
        # a non-center position of this axis (presumably the only one)
        to = (set(expected['axes']['X'].keys()) - {'center'}).pop()
        coord_to = 'XG'
        da = ds.data_c
    else:
        to = 'center'
        coord_to = 'XC'
        da = ds.data_g

    shift = expected.get('shift') or False

    data = da.data
    if ((dim_len_diff == 1 and not from_center) or
            (dim_len_diff == -1 and from_center)):
        # source grid is longer: neighbors are consecutive source points
        data_left = data[:-1]
        data_right = data[1:]
    elif ((dim_len_diff == 1 and from_center) or
          (dim_len_diff == -1 and not from_center)):
        # source grid is shorter: each neighbor array is padded on one side
        data_left = _pad_left(data, boundary)
        data_right = _pad_right(data, boundary)
    elif (shift and not from_center) or (not shift and from_center):
        data_left = _pad_left(data[:-1], boundary)
        data_right = data
    else:
        data_left = data
        data_right = _pad_right(data[1:], boundary)

    # interpolate
    data_interp_expected = xr.DataArray(0.5 * (data_left + data_right),
                                        dims=[coord_to],
                                        coords={coord_to: ds[coord_to]})
    data_interp = axis.interp(da, to, boundary=boundary)
    assert data_interp_expected.equals(data_interp)
    # check without "to" specified
    assert data_interp.equals(axis.interp(da, boundary=boundary))

    # difference
    data_diff_expected = xr.DataArray(data_right - data_left,
                                      dims=[coord_to],
                                      coords={coord_to: ds[coord_to]})
    data_diff = axis.diff(da, to, boundary=boundary)
    assert data_diff_expected.equals(data_diff)
    # check without "to" specified
    assert data_diff.equals(axis.diff(da, boundary=boundary))
コード例 #3
0
ファイル: test_grid.py プロジェクト: jamesp/xgcm
def test_axis_neighbor_pairs_2d(periodic_2d, varname, axis_name, to, roll,
                                roll_axis, swap_order):
    """
    Check ``Axis._get_neighbor_data_pairs`` on a periodic 2D dataset.

    After optionally swapping the returned pair (``swap_order``), the first
    member must equal the raw data rolled by ``roll`` along ``roll_axis`` and
    the second member must equal the raw data itself.
    """
    ds, _periodic, _expected = periodic_2d

    axis = Axis(ds, axis_name)
    field = ds[varname]

    first, second = axis._get_neighbor_data_pairs(field, to)
    if swap_order:
        first, second = second, first

    rolled = np.roll(field.data, roll, axis=roll_axis)
    np.testing.assert_allclose(first, rolled)
    np.testing.assert_allclose(second, field.data)
コード例 #4
0
ファイル: test_grid.py プロジェクト: jamesp/xgcm
def test_axis_wrap_and_replace_nonperiodic(nonperiodic_1d):
    """
    Check that ``Axis._wrap_and_replace_coords`` attaches the target coords.

    Raw arrays of ones are rewrapped from one grid position onto the other
    and must equal reference DataArrays built directly on the target
    position's coordinate.
    """
    ds, _periodic, expected = nonperiodic_1d
    axis = Axis(ds, 'X')

    ones_c = 0 * ds.XC + 1
    ones_g = 0 * ds.XG + 1

    other_positions = set(expected['axes']['X'].keys())
    other_positions -= {'center'}
    target = other_positions.pop()

    assert ones_g.equals(
        axis._wrap_and_replace_coords(ones_c, ones_g.data, target))
    assert ones_c.equals(
        axis._wrap_and_replace_coords(ones_g, ones_c.data, 'center'))
コード例 #5
0
ファイル: test_grid.py プロジェクト: jamesp/xgcm
def test_axis_neighbor_pairs_nonperiodic_1d(nonperiodic_1d, boundary, from_center):
    """
    Check ``Axis._get_neighbor_data_pairs`` on a non-periodic 1D axis.

    Transitions that need a value beyond the domain edge must raise
    ``ValueError`` when ``boundary`` is None. Otherwise the returned pair is
    compared against left/right neighbors built by hand from the raw data.
    """
    ds, periodic, expected = nonperiodic_1d
    axis = Axis(ds, 'X', periodic=periodic)

    # detect whether this is an outer or inner case
    # outer --> dim_len_diff = 1
    # inner --> dim_len_diff = -1
    # same length (left/right shifted) --> dim_len_diff = 0
    dim_len_diff = len(ds.XG) - len(ds.XC)

    if from_center:
        to = (set(expected['axes']['X'].keys()) - {'center'}).pop()
        da = ds.data_c
    else:
        to = 'center'
        da = ds.data_g

    shift = expected.get('shift') or False

    # A boundary condition is needed for every transition except
    # outer -> center and center -> inner, which use interior points only.
    needs_boundary = (dim_len_diff == 0 or
                      (dim_len_diff == 1 and from_center) or
                      (dim_len_diff == -1 and not from_center))
    if boundary is None and needs_boundary:
        with pytest.raises(ValueError):
            axis._get_neighbor_data_pairs(da, to, boundary=boundary)
    else:
        data_left, data_right = axis._get_neighbor_data_pairs(
            da, to, boundary=boundary)
        if (((dim_len_diff == 1) and not from_center) or
                ((dim_len_diff == -1) and from_center)):
            # source grid is longer: neighbors are consecutive source points
            expected_left = da.data[:-1]
            expected_right = da.data[1:]
        elif (((dim_len_diff == 1) and from_center) or
              ((dim_len_diff == -1) and not from_center)):
            # source grid is shorter: each neighbor array is padded once
            expected_left = _pad_left(da.data, boundary)
            expected_right = _pad_right(da.data, boundary)
        elif (shift and not from_center) or (not shift and from_center):
            expected_right = da.data
            expected_left = _pad_left(da.data, boundary)[:-1]
        else:
            expected_left = da.data
            expected_right = _pad_right(da.data, boundary)[1:]

        np.testing.assert_allclose(data_left, expected_left)
        np.testing.assert_allclose(data_right, expected_right)
コード例 #6
0
ファイル: test_grid.py プロジェクト: jamesp/xgcm
def test_axis_cumsum(nonperiodic_1d, boundary):
    """
    Check ``Axis.cumsum`` between the center and staggered positions.

    Both directions are tested with an explicit ``to`` and with the default
    target. The values are compared against ``np.cumsum`` of the raw data.
    One direction yields the plain running sum (possibly minus its last
    element). The other starts with a boundary value (0 for ``'fill'``, the
    first running sum otherwise) followed by the source's running sum.
    """
    ds, periodic, expected = nonperiodic_1d
    axis = Axis(ds, 'X', periodic=periodic)

    positions = expected['axes']['X']

    cumsum_g = axis.cumsum(ds.data_g, to='center', boundary=boundary)
    assert cumsum_g.dims == ds.data_c.dims
    # check default "to"
    assert cumsum_g.equals(axis.cumsum(ds.data_g, boundary=boundary))

    to = set(positions).difference({'center'}).pop()
    cumsum_c = axis.cumsum(ds.data_c, to=to, boundary=boundary)
    assert cumsum_c.dims == ds.data_g.dims
    # check default "to"
    assert cumsum_c.equals(axis.cumsum(ds.data_c, boundary=boundary))

    raw_c = np.cumsum(ds.data_c.data)
    raw_g = np.cumsum(ds.data_g.data)

    # pick which result is the plain running sum and which one is padded
    if to in ('right', 'inner'):
        plain, plain_raw, padded, padded_raw = cumsum_c, raw_c, cumsum_g, raw_g
    elif to in ('left', 'outer'):
        plain, plain_raw, padded, padded_raw = cumsum_g, raw_g, cumsum_c, raw_c
    else:
        return

    if to in ('right', 'left'):
        # equal lengths: the padded result drops the last running sum
        plain_expected, padded_tail = plain_raw, padded_raw[:-1]
    else:
        # inner/outer: the plain result loses one element, the padded
        # result keeps the full running sum
        plain_expected, padded_tail = plain_raw[:-1], padded_raw

    np.testing.assert_allclose(plain.data, plain_expected)
    fill_value = 0. if boundary == 'fill' else padded_raw[0]
    np.testing.assert_allclose(padded.data,
                               np.hstack([fill_value, padded_tail]))
コード例 #7
0
ファイル: test_grid.py プロジェクト: jamesp/xgcm
def test_axis_errors():
    """
    Check the errors raised for malformed axes and invalid interpolation.

    Each expected error is pinned with ``pytest.raises(..., match=...)``,
    which searches the given regex in the exception text. The previous
    ``message=`` keyword never checked the exception text (it only set the
    failure message shown when nothing was raised). It was removed in
    pytest 5.0, where passing it raises ``TypeError``. The patterns below are
    distinctive substrings of the expected messages.
    """
    ds = datasets['1d_left']

    ds_noattr = ds.copy()
    del ds_noattr.XC.attrs['axis']
    with pytest.raises(ValueError,
                       match="Couldn't find a center coordinate for axis X"):
        Axis(ds_noattr, 'X', periodic=True)

    del ds_noattr.XG.attrs['axis']
    with pytest.raises(ValueError,
                       match="Couldn't find any coordinates for axis X"):
        Axis(ds_noattr, 'X', periodic=True)

    # XG truncated by three points: its length no longer fits the axis
    ds_chopped = ds.copy()
    del ds_chopped['data_g']
    ds_chopped['XG'] = ds_chopped['XG'][:-3]
    with pytest.raises(ValueError,
                       match="coordinate XG has incompatible length"):
        Axis(ds_chopped, 'X', periodic=True)

    ds_chopped.XG.attrs['c_grid_axis_shift'] = -0.5
    with pytest.raises(ValueError,
                       match="coordinate XG has incompatible length"):
        Axis(ds_chopped, 'X', periodic=True)

    del ds_chopped.XG.attrs['c_grid_axis_shift']
    with pytest.raises(ValueError, match="c_grid_axis_shift"):
        Axis(ds_chopped, 'X', periodic=True)

    ax = Axis(ds, 'X', periodic=True)

    with pytest.raises(ValueError,
                       match="Can't get neighbor pairs for the same position"):
        ax.interp(ds.data_c, 'center')

    with pytest.raises(KeyError,
                       match="Position 'right' was not found in axis"):
        ax.interp(ds.data_c, 'right')

    with pytest.raises(ValueError,
                       match="is not allowed with periodic axis X"):
        ax.interp(ds.data_c, 'right', boundary='fill')
コード例 #8
0
ファイル: test_grid.py プロジェクト: xgcm/xgcm
def _get_axes(ds):
    """
    Build one ``Axis`` per distinct ``axis`` attribute on the dims of ``ds``.

    Returns a dict mapping each axis name found in the ``axis`` attribute of
    a dimension coordinate to ``Axis(ds, name)``.
    """
    axis_names = set()
    for dim in ds.dims:
        attrs = ds[dim].attrs
        if "axis" in attrs:
            axis_names.add(attrs["axis"])
    return {axis_name: Axis(ds, axis_name) for axis_name in axis_names}
コード例 #9
0
ファイル: test_grid.py プロジェクト: roxyboy/xgcm
def test_axis_errors():
    """
    Check the errors raised for malformed axes and invalid interpolation.

    Each expected error is pinned with ``pytest.raises(..., match=...)``,
    which searches the given regex in the exception text. The previous
    ``message=`` keyword never checked the exception text (it only set the
    failure message shown when nothing was raised). It was removed in
    pytest 5.0, where passing it raises ``TypeError``. The patterns below are
    distinctive substrings of the expected messages.
    """
    ds = datasets['1d_left']

    ds_noattr = ds.copy()
    del ds_noattr.XC.attrs['axis']
    with pytest.raises(ValueError,
                       match="Couldn't find a center coordinate for axis X"):
        Axis(ds_noattr, 'X', periodic=True)

    del ds_noattr.XG.attrs['axis']
    with pytest.raises(ValueError,
                       match="Couldn't find any coordinates for axis X"):
        Axis(ds_noattr, 'X', periodic=True)

    # XG truncated by three points: its length no longer fits the axis
    ds_chopped = ds.copy()
    del ds_chopped['data_g']
    ds_chopped['XG'] = ds_chopped['XG'][:-3]
    with pytest.raises(ValueError,
                       match="coordinate XG has incompatible length"):
        Axis(ds_chopped, 'X', periodic=True)

    ds_chopped.XG.attrs['c_grid_axis_shift'] = -0.5
    with pytest.raises(ValueError,
                       match="coordinate XG has incompatible length"):
        Axis(ds_chopped, 'X', periodic=True)

    del ds_chopped.XG.attrs['c_grid_axis_shift']
    with pytest.raises(ValueError,
                       match="c_grid_axis_shift"):
        Axis(ds_chopped, 'X', periodic=True)

    ax = Axis(ds, 'X', periodic=True)

    with pytest.raises(ValueError,
                       match="Can't get neighbor pairs for "
                       "the same position"):
        ax.interp(ds.data_c, 'center')

    with pytest.raises(
            KeyError,
            match="Position 'right' was not found in axis"):
        ax.interp(ds.data_c, 'right')

    with pytest.raises(ValueError,
                       match="is not allowed "
                       "with periodic axis X"):
        ax.interp(ds.data_c, 'right', boundary='fill')
コード例 #10
0
ファイル: test_grid.py プロジェクト: roxyboy/xgcm
def test_axis_diff_and_interp_nonperiodic_2d(
    all_2d,
    boundary,
    axis_name,
    varname,
    this,
    to,
):
    """
    Check ``Axis.interp`` and ``Axis.diff`` along one axis of a 2D dataset.

    Expected neighbors are built with ``np.roll`` on periodic axes and with
    ``np.pad`` on non-periodic ones (``'edge'`` for ``extend``,
    ``'constant'`` for ``fill``). They are then wrapped onto the coordinates
    of position ``to`` and compared with the axis results.
    """
    ds, periodic, expected = all_2d

    # ``periodic`` is either a container of axis names or a plain bool
    try:
        ax_periodic = axis_name in periodic
    except TypeError:
        ax_periodic = periodic

    axis = Axis(ds, axis_name, periodic=ax_periodic)
    da = ds[varname]

    # everything is left shift
    data = ds[varname].data

    axis_num = da.get_axis_num(axis.coords[this].name)

    # lookups for numpy.pad
    numpy_pad_arg = {'extend': 'edge', 'fill': 'constant'}
    # args for numpy.pad
    pad_left = (1, 0)
    pad_right = (0, 1)
    pad_none = (0, 0)

    if this == 'center':
        if ax_periodic:
            data_left = np.roll(data, 1, axis=axis_num)
            data_right = data
        else:
            pad_width = [
                pad_left if i == axis_num else pad_none
                for i in range(data.ndim)
            ]
            # multidimensional basic indexing needs a tuple of slices; a
            # list is deprecated in numpy and rejected by recent versions
            the_slice = tuple(
                slice(0, -1) if i == axis_num else slice(None)
                for i in range(data.ndim)
            )
            data_left = np.pad(data, pad_width,
                               numpy_pad_arg[boundary])[the_slice]
            data_right = data
    elif this == 'left':
        if ax_periodic:
            data_left = data
            data_right = np.roll(data, -1, axis=axis_num)
        else:
            pad_width = [
                pad_right if i == axis_num else pad_none
                for i in range(data.ndim)
            ]
            the_slice = tuple(
                slice(1, None) if i == axis_num else slice(None)
                for i in range(data.ndim)
            )
            data_right = np.pad(data, pad_width,
                                numpy_pad_arg[boundary])[the_slice]
            data_left = data

    data_interp = 0.5 * (data_left + data_right)
    data_diff = data_right - data_left

    # determine new dims
    dims = list(da.dims)
    dims[axis_num] = axis.coords[to].name
    coords = {dim: ds[dim] for dim in dims}

    da_interp_expected = xr.DataArray(data_interp, dims=dims, coords=coords)
    da_diff_expected = xr.DataArray(data_diff, dims=dims, coords=coords)

    # periodic axes take no boundary condition
    boundary_arg = boundary if not ax_periodic else None
    da_interp = axis.interp(da, to, boundary=boundary_arg)
    da_diff = axis.diff(da, to, boundary=boundary_arg)

    assert da_interp_expected.equals(da_interp)
    assert da_diff_expected.equals(da_diff)
コード例 #11
0
def generate_axis(ds,
                  axis,
                  name,
                  axis_dim,
                  pos_from='center',
                  pos_to='left',
                  boundary_discontinuity=None,
                  pad='auto',
                  new_name=None,
                  attrs_from_scratch=True):
    """
    Create c-grid dimensions (or coordinates) along an axis of a dataset.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with grid information used to construct the c-grid.
    axis : str
        The appropriate xgcm axis. E.g. 'X' for longitudes.
    name : str
        The name of the variable in ds, providing the original grid.
    axis_dim : str
        The dimension of ds[name] corresponding to axis. If name itself is a
        dimension, this should be equal to name. Not used by this
        implementation; kept for interface compatibility.
    pos_from : {'center','left','right'}, optional
        Position of the gridpoints given in 'ds'.
    pos_to : {'left','center','right'}, optional
        Position of the gridpoints to be generated.
    boundary_discontinuity : {None, float}, optional
        If specified, marks the value of discontinuity across boundary, e.g.
        360 for global longitude values and 180 for global latitudes. Selects
        a periodic axis; must not be combined with ``pad``.
    pad : {'auto', None, float}, optional
        If specified, determines the padding to be applied across boundary.
        If float is specified, that value is used as padding (``'fill'``
        boundary). 'auto' attempts to pad linearly extrapolated values
        (``'extrapolate'`` boundary). Can be useful for e.g. depth
        coordinates (to reconstruct 0 depth). Can lead to unexpected values
        when coordinate is multidimensional. Must be None when
        ``boundary_discontinuity`` is given.
    new_name : str, optional
        Name of the inferred grid variable. Defaults to
        ``name + '_' + pos_to``.
    attrs_from_scratch : bool, optional
        Determines if the attributes are created from scratch. Should be
        enabled for dimensions and deactivated for multidimensional
        coordinates. These can only be calculated after the dims are created.

    Returns
    -------
    xarray.Dataset
        A copy of ``ds`` with the coordinate ``new_name`` added.

    Raises
    ------
    ValueError
        If ``ds`` is not an ``xarray.Dataset``, or if not exactly one of
        ``boundary_discontinuity`` and ``pad`` is given.
    """
    if not isinstance(ds, xr.Dataset):
        raise ValueError("'ds' needs to be xarray.Dataset")

    if new_name is None:
        new_name = name + '_' + pos_to

    # Determine the relative position to interpolate to based on current and
    # desired position
    relative_pos_to = _position_to_relative(pos_from, pos_to)

    # This is bloated. We can probably retire the 'auto' logic in favor of
    # using 'boundary' and 'fill_value'. But first let's see if this all works.

    # Exactly one of ``boundary_discontinuity`` (wrap) and ``pad`` must be
    # set. The messages use implicit string concatenation; a backslash
    # continuation inside the literal used to embed the source indentation.
    if (boundary_discontinuity is not None) and (pad is not None):
        raise ValueError('Coordinate cannot be wrapped and padded at the '
                         'same time')
    elif (boundary_discontinuity is None) and (pad is None):
        raise ValueError('Either "boundary_discontinuity" or "pad" have '
                         'to be specified')

    if pad is None:
        fill_value = 0.0
        boundary = None
        periodic = True
    elif pad == 'auto':
        fill_value = 0.0
        boundary = 'extrapolate'
        periodic = False
    else:
        fill_value = pad
        boundary = 'fill'
        periodic = False

    kwargs = dict(
        boundary_discontinuity=boundary_discontinuity,
        fill_value=fill_value,
        boundary=boundary,
        position_check=False,
    )

    ds = ds.copy()

    # For a set of coordinates there are two fundamental cases. The coordinates
    # are a) one dimensional (dimensions) or b) multidimensional. These are
    # separated by the keyword attrs_from_scratch.
    # These two cases are treated differently because for each dataset we need
    # to recreate all a) cases before we can proceed to b), hence this is
    # really the 'raw' data processing step. If we have working one dimensional
    # coordinates (e.g. after we looped over the axes_dims_dict), we can use
    # the regular xgcm.Axis to interpolate multidimensional coordinates.
    # This assures that any changes to the Axis.interp method can directly
    # propagate to this module.

    if attrs_from_scratch:
        # Input coordinate has to be declared as center,
        # or xgcm.Axis throws error. Will be rewrapped below.
        ds[name] = _fill_attrs(ds[name], 'center', axis)

        ax = Axis(ds, axis, periodic=periodic)
        args = ds[name], raw_interp_function, relative_pos_to
        ds.coords[new_name] = ax._neighbor_binary_func_raw(*args, **kwargs)

        # Place the correct attributes
        ds[name] = _fill_attrs(ds[name], pos_from, axis)
        ds[new_name] = _fill_attrs(ds[new_name], pos_to, axis)
    else:
        # ``Axis.interp`` is called without the ``position_check`` keyword
        kwargs.pop('position_check', None)
        ax = Axis(ds, axis, periodic=periodic)
        args = ds[name], pos_to
        ds.coords[new_name] = ax.interp(*args, **kwargs)
    return ds
コード例 #12
0
def test_axis_errors():
    """
    Check the errors raised for malformed axes and bad target positions.
    """

    def assert_axis_fails(dataset, pattern):
        # Building a periodic X axis from ``dataset`` must raise a
        # ValueError whose text matches the regex ``pattern``.
        with pytest.raises(ValueError, match=pattern):
            Axis(dataset, "X", periodic=True)

    ds = datasets["1d_left"]

    # missing ``axis`` attributes: first on XC only, then on XG as well
    ds_noattr = ds.copy()
    del ds_noattr.XC.attrs["axis"]
    assert_axis_fails(ds_noattr,
                      "Couldn't find a center coordinate for axis X")
    del ds_noattr.XG.attrs["axis"]
    assert_axis_fails(ds_noattr, "Couldn't find any coordinates for axis X")

    # XG reduced to its first 3 points
    ds_chopped = ds.copy().isel(XG=slice(None, 3))
    del ds_chopped["data_g"]
    assert_axis_fails(ds_chopped, "coordinate XG has incompatible length")
    ds_chopped.XG.attrs["c_grid_axis_shift"] = -0.5
    assert_axis_fails(ds_chopped, "coordinate XG has incompatible length")

    # XG's shift attribute removed as well
    del ds_chopped.XG.attrs["c_grid_axis_shift"]
    assert_axis_fails(
        ds_chopped,
        "Found two coordinates without `c_grid_axis_shift` attribute for axis X",
    )

    ax = Axis(ds, "X", periodic=True)

    with pytest.raises(
            ValueError,
            match="Can't get neighbor pairs for the same position."):
        ax.interp(ds.data_c, "center")

    with pytest.raises(ValueError,
                       match="This axis doesn't contain a `right` position"):
        ax.interp(ds.data_c, "right")
コード例 #13
0
ファイル: autogenerate.py プロジェクト: jamesp/xgcm
def generate_axis(ds,
                  axis,
                  name,
                  axis_dim,
                  pos_from='center',
                  pos_to='left',
                  boundary_discontinuity=None,
                  pad='auto',
                  new_name=None,
                  attrs_from_scratch=True):
    """
    Create c-grid dimensions (or coordinates) along an axis of a dataset.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with grid information used to construct the c-grid.
    axis : str
        The appropriate xgcm axis. E.g. 'X' for longitudes.
    name : str
        The name of the variable in ds, providing the original grid.
    axis_dim : str
        The dimension of ds[name] corresponding to axis. If name itself is a
        dimension, this should be equal to name. Passed to ``_auto_pad``
        when ``pad='auto'``.
    pos_from : {'center','left','right'}, optional
        Position of the gridpoints given in 'ds'.
    pos_to : {'left','center','right'}, optional
        Position of the gridpoints to be generated.
    boundary_discontinuity : {None, float}, optional
        If specified, marks the value of discontinuity across boundary, e.g.
        360 for global longitude values and 180 for global latitudes. Selects
        a periodic axis; must not be combined with ``pad``.
    pad : {'auto', None, float}, optional
        If specified, determines the padding to be applied across boundary.
        If float is specified, that value is used as padding. Auto attempts to
        pad linearly extrapolated values. Can be useful for e.g. depth
        coordinates (to reconstruct 0 depth). Can lead to unexpected values
        when coordinate is multidimensional. Must be None when
        ``boundary_discontinuity`` is given.
    new_name : str, optional
        Name of the inferred grid variable. Defaults to
        ``name + '_' + pos_to``.
    attrs_from_scratch : bool, optional
        Determines if the attributes are created from scratch. Should be
        enabled for dimensions and deactivated for multidimensional
        coordinates. These can only be calculated after the dims are created.

    Returns
    -------
    xarray.Dataset
        A copy of ``ds`` with the coordinate ``new_name`` added.

    Raises
    ------
    RuntimeError
        If ``ds`` is not an ``xarray.Dataset``, or if not exactly one of
        ``boundary_discontinuity`` and ``pad`` is given.
    """
    if not isinstance(ds, xr.Dataset):
        raise RuntimeError("'ds' needs to be xarray.Dataset")

    if new_name is None:
        new_name = name + '_' + pos_to

    # Determine the relative position to interpolate to based on current and
    # desired position
    relative_pos_to = _position_to_relative(pos_from, pos_to)

    # Exactly one of ``boundary_discontinuity`` (periodic wrap) and ``pad``
    # must be given. The messages use implicit string concatenation; a
    # backslash continuation inside the literal used to embed the source
    # indentation into the message text.
    if (boundary_discontinuity is not None) and (pad is not None):
        raise RuntimeError('Coordinate cannot be wrapped and padded at the '
                           'same time')
    elif (boundary_discontinuity is None) and (pad is not None):
        if pad == 'auto':
            fill_value = _auto_pad(ds[name], axis_dim, relative_pos_to)
        else:
            fill_value = pad
        periodic = False
        boundary = 'fill'
    elif (boundary_discontinuity is not None) and (pad is None):
        periodic = True
        fill_value = 0.0
        boundary = None
    else:
        raise RuntimeError('Either "boundary_discontinuity" or "pad" have '
                           'to be specified')

    ds = ds.copy()

    # For a set of coordinates there are two fundamental cases. The coordinates
    # are a) one dimensional (dimensions) or b) multidimensional. These are
    # separated by the keyword attrs_from_scratch.
    # These two cases are treated differently because for each dataset we need
    # to recreate all a) cases before we can proceed to b), hence this is
    # really the 'raw' data processing step. If we have working one dimensional
    # coordinates (e.g. after we looped over the axes_dims_dict), we can use
    # the regular xgcm.Axis to interpolate multidimensional coordinates.
    # This assures that any changes to the Axis.interp method can directly
    # propagate to this module.

    if attrs_from_scratch:
        # Input coordinate has to be declared as center,
        # or xgcm.Axis throws error. Will be rewrapped below.
        ds[name] = _fill_attrs(ds[name], 'center', axis)

        ax = Axis(ds, axis, periodic=periodic)
        ds.coords[new_name] = ax._neighbor_binary_func_raw(
            ds[name],
            raw_interp_function,
            relative_pos_to,
            boundary=boundary,
            fill_value=fill_value,
            boundary_discontinuity=boundary_discontinuity)

        # Place the correct attributes
        ds[name] = _fill_attrs(ds[name], pos_from, axis)
        ds[new_name] = _fill_attrs(ds[new_name], pos_to, axis)
    else:
        ax = Axis(ds, axis, periodic=periodic)
        ds.coords[new_name] = ax.interp(
            ds[name],
            pos_to,
            boundary=boundary,
            fill_value=fill_value,
            boundary_discontinuity=boundary_discontinuity)
    return ds
コード例 #14
0
ファイル: test_grid.py プロジェクト: jamesp/xgcm
def test_axis_diff_and_interp_nonperiodic_2d(all_2d, boundary, axis_name,
                                             varname, this, to,):
    """
    Check ``Axis.interp`` and ``Axis.diff`` along one axis of a 2D dataset.

    Expected neighbors are built with ``np.roll`` on periodic axes and with
    ``np.pad`` on non-periodic ones (``'edge'`` for ``extend``,
    ``'constant'`` for ``fill``). They are then wrapped onto the coordinates
    of position ``to`` and compared with the axis results.
    """
    ds, periodic, expected = all_2d

    # ``periodic`` is either a container of axis names or a plain bool
    try:
        ax_periodic = axis_name in periodic
    except TypeError:
        ax_periodic = periodic

    axis = Axis(ds, axis_name, periodic=ax_periodic)
    da = ds[varname]

    # everything is left shift
    data = ds[varname].data

    axis_num = da.get_axis_num(axis.coords[this].name)

    # lookups for numpy.pad
    numpy_pad_arg = {'extend': 'edge', 'fill': 'constant'}
    # args for numpy.pad
    pad_left = (1, 0)
    pad_right = (0, 1)
    pad_none = (0, 0)

    if this == 'center':
        if ax_periodic:
            data_left = np.roll(data, 1, axis=axis_num)
            data_right = data
        else:
            pad_width = [pad_left if i == axis_num else pad_none
                         for i in range(data.ndim)]
            # multidimensional basic indexing needs a tuple of slices; a
            # list is deprecated in numpy and rejected by recent versions
            the_slice = tuple(slice(0, -1) if i == axis_num else slice(None)
                              for i in range(data.ndim))
            data_left = np.pad(data, pad_width,
                               numpy_pad_arg[boundary])[the_slice]
            data_right = data
    elif this == 'left':
        if ax_periodic:
            data_left = data
            data_right = np.roll(data, -1, axis=axis_num)
        else:
            pad_width = [pad_right if i == axis_num else pad_none
                         for i in range(data.ndim)]
            the_slice = tuple(slice(1, None) if i == axis_num else slice(None)
                              for i in range(data.ndim))
            data_right = np.pad(data, pad_width,
                                numpy_pad_arg[boundary])[the_slice]
            data_left = data

    data_interp = 0.5 * (data_left + data_right)
    data_diff = data_right - data_left

    # determine new dims
    dims = list(da.dims)
    dims[axis_num] = axis.coords[to].name
    coords = {dim: ds[dim] for dim in dims}

    da_interp_expected = xr.DataArray(data_interp, dims=dims, coords=coords)
    da_diff_expected = xr.DataArray(data_diff, dims=dims, coords=coords)

    # periodic axes take no boundary condition
    boundary_arg = boundary if not ax_periodic else None
    da_interp = axis.interp(da, to, boundary=boundary_arg)
    da_diff = axis.diff(da, to, boundary=boundary_arg)

    assert da_interp_expected.equals(da_interp)
    assert da_diff_expected.equals(da_diff)
コード例 #15
0
ファイル: test_grid.py プロジェクト: xgcm/xgcm
def test_axis_diff_and_interp_nonperiodic_1d(nonperiodic_1d, boundary,
                                             from_center):
    """
    Check ``interp``, ``diff``, ``max`` and ``min`` on a non-periodic 1D axis.

    Left/right neighbors are built by hand from the raw data, padded with
    ``_pad_left``/``_pad_right`` where the target grid needs a value beyond
    the domain edge. Each method is combined from them, and its result is
    checked with an explicit ``to`` and with the default target position.
    """
    ds, periodic, expected = nonperiodic_1d
    axis = Axis(ds, "X", periodic=periodic)

    dim_len_diff = len(ds.XG) - len(ds.XC)

    if not from_center:
        to, coord_to, da = "center", "XC", ds.data_g
    else:
        to = (set(expected["axes"]["X"].keys()) - {"center"}).pop()
        coord_to, da = "XG", ds.data_c

    shift = expected.get("shift") or False

    data = da.data
    source_is_longer = ((dim_len_diff == 1 and not from_center)
                        or (dim_len_diff == -1 and from_center))
    source_is_shorter = ((dim_len_diff == 1 and from_center)
                         or (dim_len_diff == -1 and not from_center))
    if source_is_longer:
        # neighbors are consecutive source points
        data_left, data_right = data[:-1], data[1:]
    elif source_is_shorter:
        # each neighbor array is padded on one side
        data_left = _pad_left(data, boundary)
        data_right = _pad_right(data, boundary)
    elif bool(shift) != bool(from_center):
        data_left = _pad_left(data[:-1], boundary)
        data_right = data
    else:
        data_left = data
        data_right = _pad_right(data[1:], boundary)

    # how each Axis method combines the (left, right) neighbor arrays
    combiners = {
        "interp": lambda left, right: 0.5 * (left + right),
        "diff": lambda left, right: right - left,
        "max": lambda left, right: np.maximum(right, left),
        "min": lambda left, right: np.minimum(right, left),
    }
    for method_name, combine in combiners.items():
        method = getattr(axis, method_name)
        reference = xr.DataArray(combine(data_left, data_right),
                                 dims=[coord_to],
                                 coords={coord_to: ds[coord_to]})
        result = method(da, to, boundary=boundary)
        assert reference.equals(result)
        # check without "to" specified
        assert result.equals(method(da, boundary=boundary))
コード例 #16
0
ファイル: test_grid.py プロジェクト: xgcm/xgcm
def test_axis_diff_and_interp_nonperiodic_2d(all_2d, boundary, axis_name,
                                             varname, this, to):
    """Compare ``Axis.interp`` / ``Axis.diff`` on 2D data with a NumPy reference.

    The reference is built by pairing every value of ``ds[varname]`` with its
    one-cell neighbor along the tested axis (``np.roll`` for periodic axes,
    ``np.pad`` + slicing for non-periodic ones), then averaging (interp) or
    differencing (diff) each pair. For non-periodic axes the test also checks
    that passing the *other* boundary condition changes the result.

    Parameters (pytest fixtures / parametrizations)
    ----------
    all_2d : tuple
        ``(ds, periodic, expected)``; ``periodic`` is either a bool or a
        collection of periodic axis names (``expected`` is unused here).
    boundary : str
        ``"extend"`` or ``"fill"``; only used for non-periodic axes.
    axis_name : str
        Name of the axis under test.
    varname : str
        Name of the data variable in ``ds`` to operate on.
    this : str
        Position of ``varname`` on the axis; ``"center"`` or ``"left"``.
    to : str
        Target position passed to ``interp`` / ``diff``.

    Raises
    ------
    ValueError
        If ``this`` is not one of the supported positions.
    """
    ds, periodic, _ = all_2d

    try:
        ax_periodic = axis_name in periodic
    except TypeError:
        # ``periodic`` is a plain bool rather than a collection of axis names
        ax_periodic = periodic

    # a boundary condition is only passed for non-periodic axes
    boundary_arg = boundary if not ax_periodic else None
    axis = Axis(ds, axis_name, periodic=ax_periodic, boundary=boundary_arg)
    da = ds[varname]
    data = da.data

    axis_num = da.get_axis_num(axis.coords[this])

    # translate xgcm boundary names into numpy.pad modes
    numpy_pad_arg = {"extend": "edge", "fill": "constant"}

    def _neighbor(offset):
        """Return ``data`` shifted so that element i holds ``data[i + offset]``.

        ``offset`` is -1 (left neighbor) or +1 (right neighbor) along
        ``axis_num``. Periodic axes wrap around; non-periodic axes fill the
        missing edge value by padding according to ``boundary``.
        """
        if ax_periodic:
            return np.roll(data, -offset, axis=axis_num)
        if offset < 0:
            pad, keep = (1, 0), slice(0, -1)
        else:
            pad, keep = (0, 1), slice(1, None)
        pad_width = [pad if i == axis_num else (0, 0)
                     for i in range(data.ndim)]
        the_slice = tuple(keep if i == axis_num else slice(None)
                          for i in range(data.ndim))
        return np.pad(data, pad_width, numpy_pad_arg[boundary])[the_slice]

    if this == "center":
        # pair each value with its left neighbor
        data_left, data_right = _neighbor(-1), data
    elif this == "left":
        # pair each value with its right neighbor
        data_left, data_right = data, _neighbor(+1)
    else:
        # fail loudly instead of with an UnboundLocalError further down
        raise ValueError(
            "unsupported source position this={!r}; expected 'center' "
            "or 'left'".format(this))

    data_interp = 0.5 * (data_left + data_right)
    data_diff = data_right - data_left

    # the tested axis dimension is replaced by the target coordinate
    dims = list(da.dims)
    dims[axis_num] = axis.coords[to]
    coords = {dim: ds[dim] for dim in dims}

    da_interp_expected = xr.DataArray(data_interp, dims=dims, coords=coords)
    da_diff_expected = xr.DataArray(data_diff, dims=dims, coords=coords)

    da_interp = axis.interp(da, to)
    da_diff = axis.diff(da, to)

    assert da_interp_expected.equals(da_interp)
    assert da_diff_expected.equals(da_diff)

    if boundary_arg is not None:
        # overriding with the opposite boundary condition must change results
        bad_boundary = {"extend": "fill", "fill": "extend"}[boundary]

        da_interp_wrong = axis.interp(da, to, boundary=bad_boundary)
        assert not da_interp_expected.equals(da_interp_wrong)
        da_diff_wrong = axis.diff(da, to, boundary=bad_boundary)
        assert not da_diff_expected.equals(da_diff_wrong)
コード例 #17
0
def test_axis_errors():
    """Check that invalid Axis setups and interpolation requests raise ValueError.

    Only the exception type is asserted. The previous ``message=`` keyword of
    ``pytest.raises`` never validated the error text: it was merely the
    failure message used when nothing was raised. It is also no longer
    accepted by current pytest, so it made every ``pytest.raises`` call fail
    with a TypeError. The error texts this test was written against are kept
    as comments for reference.
    """
    ds = datasets["1d_left"]

    ds_noattr = ds.copy()
    del ds_noattr.XC.attrs["axis"]
    # expected: "Couldn't find a center coordinate for axis X"
    with pytest.raises(ValueError):
        Axis(ds_noattr, "X", periodic=True)

    del ds_noattr.XG.attrs["axis"]
    # expected: "Couldn't find any coordinates for axis X"
    with pytest.raises(ValueError):
        Axis(ds_noattr, "X", periodic=True)

    ds_chopped = ds.copy()
    del ds_chopped["data_g"]
    ds_chopped["XG"] = ds_chopped["XG"][:-3]
    # expected: "Left coordinate XG has incompatible length 7 (axis_len=9)"
    with pytest.raises(ValueError):
        Axis(ds_chopped, "X", periodic=True)

    ds_chopped.XG.attrs["c_grid_axis_shift"] = -0.5
    # expected: "Right coordinate XG has incompatible length 7 (axis_len=9)"
    with pytest.raises(ValueError):
        Axis(ds_chopped, "X", periodic=True)

    del ds_chopped.XG.attrs["c_grid_axis_shift"]
    # expected: "Coordinate XC has invalid or missing c_grid_axis_shift
    # attribute `None`"
    with pytest.raises(ValueError):
        Axis(ds_chopped, "X", periodic=True)

    ax = Axis(ds, "X", periodic=True)

    # expected: "Can't get neighbor pairs for the same position."
    with pytest.raises(ValueError):
        ax.interp(ds.data_c, "center")

    # expected: "This axis doesn't contain a `right` position"
    with pytest.raises(ValueError):
        ax.interp(ds.data_c, "right")

    # expected: "`boundary=fill` is not allowed with periodic axis X."
    with pytest.raises(ValueError):
        ax.interp(ds.data_c, "right", boundary="fill")