def test_grid_ops(all_datasets):
    """Grid-level and Axis-level operations must give identical results.

    For every axis of the grid, ``interp`` and ``diff`` are compared for the
    center and the staggered variable under each applicable boundary
    condition; ``cumsum`` is compared only when a boundary is given.
    """
    ds, periodic, _ = all_datasets
    grid = Grid(ds, periodic=periodic)

    for axis_name in grid.axes.keys():
        # ``periodic`` is either a container of axis names or a plain bool.
        try:
            ax_periodic = axis_name in periodic
        except TypeError:
            ax_periodic = periodic
        axis = Axis(ds, axis_name, periodic=ax_periodic)
        boundaries = [None] if ax_periodic else ['fill', 'extend']

        for varname in ['data_c', 'data_g']:
            for boundary in boundaries:
                via_grid = grid.interp(ds[varname], axis_name, boundary=boundary)
                via_axis = axis.interp(ds[varname], boundary=boundary)
                assert via_grid.equals(via_axis)

                via_grid = grid.diff(ds[varname], axis_name, boundary=boundary)
                via_axis = axis.diff(ds[varname], boundary=boundary)
                assert via_grid.equals(via_axis)

                if boundary is None:
                    continue
                via_grid = grid.cumsum(ds[varname], axis_name, boundary=boundary)
                via_axis = axis.cumsum(ds[varname], boundary=boundary)
                assert via_grid.equals(via_axis)
def test_axis_diff_and_interp_nonperiodic_1d(nonperiodic_1d, boundary, from_center):
    """Check Axis.interp / Axis.diff on a non-periodic 1D axis.

    The expected result is built from the raw data by forming left/right
    neighbor arrays for each staggering case. It is then compared against
    ``axis.interp`` and ``axis.diff``, both with an explicit ``to`` and with
    the default target position.
    """
    ds, periodic, expected = nonperiodic_1d
    axis = Axis(ds, 'X', periodic=periodic)
    # +1: XG is one point longer than XC; -1: one shorter; 0: same length.
    dim_len_diff = len(ds.XG) - len(ds.XC)

    if from_center:
        to = (set(expected['axes']['X'].keys()) - {'center'}).pop()
        coord_to = 'XG'
        da = ds.data_c
    else:
        to = 'center'
        coord_to = 'XC'
        da = ds.data_g

    shift = expected.get('shift') or False

    data = da.data
    if ((dim_len_diff == 1 and not from_center) or
            (dim_len_diff == -1 and from_center)):
        # Target is one point shorter: pair consecutive points, no padding.
        data_left = data[:-1]
        data_right = data[1:]
    elif ((dim_len_diff == 1 and from_center) or
            (dim_len_diff == -1 and not from_center)):
        # Target is one point longer: extend both ends via the pad helpers.
        data_left = _pad_left(data, boundary)
        data_right = _pad_right(data, boundary)
    elif (shift and not from_center) or (not shift and from_center):
        data_left = _pad_left(data[:-1], boundary)
        data_right = data
    else:
        data_left = data
        data_right = _pad_right(data[1:], boundary)

    # interpolate
    data_interp_expected = xr.DataArray(0.5 * (data_left + data_right),
                                        dims=[coord_to],
                                        coords={coord_to: ds[coord_to]})
    data_interp = axis.interp(da, to, boundary=boundary)
    assert data_interp_expected.equals(data_interp)
    # check without "to" specified
    assert data_interp.equals(axis.interp(da, boundary=boundary))

    # difference
    data_diff_expected = xr.DataArray(data_right - data_left,
                                      dims=[coord_to],
                                      coords={coord_to: ds[coord_to]})
    data_diff = axis.diff(da, to, boundary=boundary)
    assert data_diff_expected.equals(data_diff)
    # check without "to" specified
    assert data_diff.equals(axis.diff(da, boundary=boundary))
def test_axis_neighbor_pairs_2d(periodic_2d, varname, axis_name, to, roll,
                                roll_axis, swap_order):
    """On a periodic 2D axis, neighbor pairs are the data and a rolled copy.

    The left member must equal ``np.roll(data, roll, axis=roll_axis)`` and
    the right member the unchanged data, after optionally swapping the pair.
    """
    ds, _, _ = periodic_2d
    axis = Axis(ds, axis_name)
    da = ds[varname]

    first, second = axis._get_neighbor_data_pairs(da, to)
    data_left, data_right = (second, first) if swap_order else (first, second)

    rolled = np.roll(da.data, roll, axis=roll_axis)
    np.testing.assert_allclose(data_left, rolled)
    np.testing.assert_allclose(data_right, da.data)
def test_axis_wrap_and_replace_nonperiodic(nonperiodic_1d):
    """``_wrap_and_replace_coords`` relabels raw data onto the target position.

    Arrays of ones on the center (XC) and staggered (XG) coordinates are used
    so that only the coordinate relabelling is under test.
    """
    ds, _, expected = nonperiodic_1d
    axis = Axis(ds, 'X')
    ones_center = 0 * ds.XC + 1
    ones_stagger = 0 * ds.XG + 1
    stagger_position = (set(expected['axes']['X'].keys()) - {'center'}).pop()

    # center -> staggered position
    wrapped = axis._wrap_and_replace_coords(ones_center, ones_stagger.data,
                                            stagger_position)
    assert ones_stagger.equals(wrapped)

    # staggered position -> center
    wrapped = axis._wrap_and_replace_coords(ones_stagger, ones_center.data,
                                            'center')
    assert ones_center.equals(wrapped)
def test_axis_neighbor_pairs_nonperiodic_1d(nonperiodic_1d, boundary, from_center):
    """Check ``_get_neighbor_data_pairs`` on a non-periodic 1D axis.

    Every transition except the shrinking ones (outer -> center and
    center -> inner) needs a boundary condition. Without one, a ValueError
    is expected; otherwise the pair is compared to a NumPy reference.
    """
    ds, periodic, expected = nonperiodic_1d
    axis = Axis(ds, 'X', periodic=periodic)
    # outer --> dim_len_diff = 1, inner --> dim_len_diff = -1
    dim_len_diff = len(ds.XG) - len(ds.XC)

    if from_center:
        to = (set(expected['axes']['X'].keys()) - {'center'}).pop()
        da = ds.data_c
    else:
        to = 'center'
        da = ds.data_g

    shift = expected.get('shift') or False

    shrinks = ((dim_len_diff == 1 and not from_center) or
               (dim_len_diff == -1 and from_center))
    grows = ((dim_len_diff == 1 and from_center) or
             (dim_len_diff == -1 and not from_center))

    if boundary is None and (dim_len_diff == 0 or grows):
        with pytest.raises(ValueError):
            data_left, data_right = axis._get_neighbor_data_pairs(
                da, to, boundary=boundary)
        return

    data_left, data_right = axis._get_neighbor_data_pairs(da, to, boundary=boundary)

    raw = da.data
    if shrinks:
        expected_left, expected_right = raw[:-1], raw[1:]
    elif grows:
        expected_left = _pad_left(raw, boundary)
        expected_right = _pad_right(raw, boundary)
    elif bool(shift) != bool(from_center):
        # exactly one of "shifted" / "from center" holds
        expected_left = _pad_left(raw, boundary)[:-1]
        expected_right = raw
    else:
        expected_left = raw
        expected_right = _pad_right(raw, boundary)[1:]

    np.testing.assert_allclose(data_left, expected_left)
    np.testing.assert_allclose(data_right, expected_right)
def test_axis_cumsum(nonperiodic_1d, boundary):
    """Axis.cumsum must match ``np.cumsum`` shifted onto the target position.

    Both directions (staggered -> center, center -> staggered) are checked,
    together with the default ``to`` argument. The value prepended at the
    boundary is 0 for ``'fill'`` and otherwise the first cumulative sum.
    """
    ds, periodic, expected = nonperiodic_1d
    axis = Axis(ds, 'X', periodic=periodic)
    positions = expected['axes']['X']

    # staggered -> center
    cumsum_g = axis.cumsum(ds.data_g, to='center', boundary=boundary)
    assert cumsum_g.dims == ds.data_c.dims
    # check default "to"
    assert cumsum_g.equals(axis.cumsum(ds.data_g, boundary=boundary))

    # center -> staggered
    to = set(positions).difference({'center'}).pop()
    cumsum_c = axis.cumsum(ds.data_c, to=to, boundary=boundary)
    assert cumsum_c.dims == ds.data_g.dims
    # check default "to"
    assert cumsum_c.equals(axis.cumsum(ds.data_c, boundary=boundary))

    raw_c = np.cumsum(ds.data_c.data)
    raw_g = np.cumsum(ds.data_g.data)

    def _prepend_boundary(raw, tail):
        head = 0. if boundary == 'fill' else raw[0]
        return np.hstack([head, tail])

    check = np.testing.assert_allclose
    if to == 'right':
        check(cumsum_c.data, raw_c)
        check(cumsum_g.data, _prepend_boundary(raw_g, raw_g[:-1]))
    elif to == 'left':
        check(cumsum_g.data, raw_g)
        check(cumsum_c.data, _prepend_boundary(raw_c, raw_c[:-1]))
    elif to == 'inner':
        check(cumsum_c.data, raw_c[:-1])
        check(cumsum_g.data, _prepend_boundary(raw_g, raw_g))
    elif to == 'outer':
        check(cumsum_g.data, raw_g[:-1])
        check(cumsum_c.data, _prepend_boundary(raw_c, raw_c))
def test_axis_errors():
    """Invalid Axis construction and invalid interp requests must raise.

    Only the exception type is checked. The expected messages are kept as
    comments. The former ``pytest.raises(..., message=...)`` keyword never
    compared them against the raised exception; it only set the failure
    text used when nothing was raised. It was also removed in pytest 5.0,
    where passing it is itself an error.
    """
    ds = datasets['1d_left']

    ds_noattr = ds.copy()
    del ds_noattr.XC.attrs['axis']
    # expected: "Couldn't find a center coordinate for axis X"
    with pytest.raises(ValueError):
        Axis(ds_noattr, 'X', periodic=True)

    del ds_noattr.XG.attrs['axis']
    # expected: "Couldn't find any coordinates for axis X"
    with pytest.raises(ValueError):
        Axis(ds_noattr, 'X', periodic=True)

    ds_chopped = ds.copy()
    del ds_chopped['data_g']
    ds_chopped['XG'] = ds_chopped['XG'][:-3]
    # expected: "Left coordinate XG has incompatible length 7 (axis_len=9)"
    with pytest.raises(ValueError):
        Axis(ds_chopped, 'X', periodic=True)

    ds_chopped.XG.attrs['c_grid_axis_shift'] = -0.5
    # expected: "Right coordinate XG has incompatible length 7 (axis_len=9)"
    with pytest.raises(ValueError):
        Axis(ds_chopped, 'X', periodic=True)

    del ds_chopped.XG.attrs['c_grid_axis_shift']
    # expected: "Coordinate XC has invalid or missing c_grid_axis_shift
    # attribute `None`"
    with pytest.raises(ValueError):
        Axis(ds_chopped, 'X', periodic=True)

    ax = Axis(ds, 'X', periodic=True)
    # expected: "Can't get neighbor pairs for the same position."
    with pytest.raises(ValueError):
        ax.interp(ds.data_c, 'center')

    # expected: "Position 'right' was not found in axis.coords."
    with pytest.raises(KeyError):
        ax.interp(ds.data_c, 'right')

    # expected: "`boundary=fill` is not allowed with periodic axis X."
    # NOTE(review): 'right' is not a position on this axis (see the previous
    # check), so this call may fail on the missing position before the
    # boundary is validated -- confirm it really exercises the periodic/fill
    # conflict.
    with pytest.raises(ValueError):
        ax.interp(ds.data_c, 'right', boundary='fill')
def _get_axes(ds): all_axes = {ds[c].attrs["axis"] for c in ds.dims if "axis" in ds[c].attrs} axis_objs = {ax: Axis(ds, ax) for ax in all_axes} return axis_objs
def test_axis_errors():
    """Invalid Axis construction and invalid interp requests must raise.

    Only the exception type is checked. The expected messages are kept as
    comments. The former ``pytest.raises(..., message=...)`` keyword never
    compared them against the raised exception; it only set the failure
    text used when nothing was raised. It was also removed in pytest 5.0,
    where passing it is itself an error.
    """
    ds = datasets['1d_left']

    ds_noattr = ds.copy()
    del ds_noattr.XC.attrs['axis']
    # expected: "Couldn't find a center coordinate for axis X"
    with pytest.raises(ValueError):
        Axis(ds_noattr, 'X', periodic=True)

    del ds_noattr.XG.attrs['axis']
    # expected: "Couldn't find any coordinates for axis X"
    with pytest.raises(ValueError):
        Axis(ds_noattr, 'X', periodic=True)

    ds_chopped = ds.copy()
    del ds_chopped['data_g']
    ds_chopped['XG'] = ds_chopped['XG'][:-3]
    # expected: "Left coordinate XG has incompatible length 7 (axis_len=9)"
    with pytest.raises(ValueError):
        Axis(ds_chopped, 'X', periodic=True)

    ds_chopped.XG.attrs['c_grid_axis_shift'] = -0.5
    # expected: "Right coordinate XG has incompatible length 7 (axis_len=9)"
    with pytest.raises(ValueError):
        Axis(ds_chopped, 'X', periodic=True)

    del ds_chopped.XG.attrs['c_grid_axis_shift']
    # expected: "Coordinate XC has invalid or missing c_grid_axis_shift
    # attribute `None`"
    with pytest.raises(ValueError):
        Axis(ds_chopped, 'X', periodic=True)

    ax = Axis(ds, 'X', periodic=True)
    # expected: "Can't get neighbor pairs for the same position."
    with pytest.raises(ValueError):
        ax.interp(ds.data_c, 'center')

    # expected: "Position 'right' was not found in axis.coords."
    with pytest.raises(KeyError):
        ax.interp(ds.data_c, 'right')

    # expected: "`boundary=fill` is not allowed with periodic axis X."
    # NOTE(review): 'right' is not a position on this axis (see the previous
    # check), so this call may fail on the missing position before the
    # boundary is validated -- confirm it really exercises the periodic/fill
    # conflict.
    with pytest.raises(ValueError):
        ax.interp(ds.data_c, 'right', boundary='fill')
def test_axis_diff_and_interp_nonperiodic_2d(
    all_2d, boundary, axis_name, varname, this, to,
):
    """Compare Axis.interp / Axis.diff on 2D data against a NumPy reference.

    The neighbor arrays are built with ``np.roll`` on periodic axes. On
    non-periodic axes, ``np.pad`` is used with the mode matching
    ``boundary``, followed by a slice that restores the original length.
    """
    ds, periodic, expected = all_2d
    # ``periodic`` is either a container of axis names or a plain bool.
    try:
        ax_periodic = axis_name in periodic
    except TypeError:
        ax_periodic = periodic
    axis = Axis(ds, axis_name, periodic=ax_periodic)
    da = ds[varname]
    # everything is left shift
    data = ds[varname].data
    axis_num = da.get_axis_num(axis.coords[this].name)

    # lookups for numpy.pad
    numpy_pad_arg = {'extend': 'edge', 'fill': 'constant'}
    # args for numpy.pad
    pad_left = (1, 0)
    pad_right = (0, 1)
    pad_none = (0, 0)

    if this == 'center':
        if ax_periodic:
            data_left = np.roll(data, 1, axis=axis_num)
            data_right = data
        else:
            pad_width = [
                pad_left if i == axis_num else pad_none
                for i in range(data.ndim)
            ]
            # Multi-dimensional indexing needs a *tuple* of slices; a list of
            # slices is rejected by NumPy >= 1.23.
            the_slice = tuple(
                slice(0, -1) if i == axis_num else slice(None)
                for i in range(data.ndim)
            )
            data_left = np.pad(data, pad_width, numpy_pad_arg[boundary])[the_slice]
            data_right = data
    elif this == 'left':
        if ax_periodic:
            data_left = data
            data_right = np.roll(data, -1, axis=axis_num)
        else:
            pad_width = [
                pad_right if i == axis_num else pad_none
                for i in range(data.ndim)
            ]
            the_slice = tuple(
                slice(1, None) if i == axis_num else slice(None)
                for i in range(data.ndim)
            )
            data_right = np.pad(data, pad_width, numpy_pad_arg[boundary])[the_slice]
            data_left = data

    data_interp = 0.5 * (data_left + data_right)
    data_diff = data_right - data_left

    # determine new dims
    dims = list(da.dims)
    dims[axis_num] = axis.coords[to].name
    coords = {dim: ds[dim] for dim in dims}

    da_interp_expected = xr.DataArray(data_interp, dims=dims, coords=coords)
    da_diff_expected = xr.DataArray(data_diff, dims=dims, coords=coords)

    # periodic axes take no boundary argument
    boundary_arg = boundary if not ax_periodic else None
    da_interp = axis.interp(da, to, boundary=boundary_arg)
    da_diff = axis.diff(da, to, boundary=boundary_arg)

    assert da_interp_expected.equals(da_interp)
    assert da_diff_expected.equals(da_diff)
def generate_axis(ds, axis, name, axis_dim, pos_from='center', pos_to='left',
                  boundary_discontinuity=None, pad='auto', new_name=None,
                  attrs_from_scratch=True):
    """
    Create c-grid dimensions (or coordinates) along an axis of a dataset.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with grid information used to construct the c-grid.
    axis : str
        The appropriate xgcm axis, e.g. 'X' for longitudes.
    name : str
        The name of the variable in ds providing the original grid.
    axis_dim : str
        The dimension of ds[name] corresponding to axis. If name itself is a
        dimension, this should be equal to name. Not used by this
        implementation; kept for API compatibility.
    pos_from : {'center','left','right'}, optional
        Position of the gridpoints given in 'ds'.
    pos_to : {'left','center','right'}, optional
        Position of the gridpoints to be generated.
    boundary_discontinuity : {None, float}, optional
        If specified, marks the value of discontinuity across the boundary,
        e.g. 360 for global longitude values and 180 for global latitudes.
        Implies a periodic axis.
    pad : {'auto', None, float}, optional
        If specified, determines the padding applied across the boundary.
        If a float is given, that value is used as padding. 'auto' uses
        ``boundary='extrapolate'`` and attempts to pad linearly
        extrapolated values. This can be useful for e.g. depth coordinates
        (to reconstruct 0 depth). It can lead to unexpected values when the
        coordinate is multidimensional.
    new_name : str, optional
        Name of the inferred grid variable. Defaults to name + '_' + pos_to.
    attrs_from_scratch : bool, optional
        Determines if the attributes are created from scratch. Should be
        enabled for dimensions and disabled for multidimensional coordinates;
        the latter can only be calculated after the dims are created.

    Returns
    -------
    xarray.Dataset
        A copy of ``ds`` with ``new_name`` added as a coordinate.

    Raises
    ------
    ValueError
        If ``ds`` is not an ``xarray.Dataset``, or if both or neither of
        ``boundary_discontinuity`` and ``pad`` are given.
    """
    if not isinstance(ds, xr.Dataset):
        raise ValueError("'ds' needs to be xarray.Dataset")

    if new_name is None:
        new_name = name + '_' + pos_to

    # Determine the relative position to interpolate to based on current and
    # desired position
    relative_pos_to = _position_to_relative(pos_from, pos_to)

    # This is bloated. We can probably retire the 'auto' logic in favor of
    # using 'boundary' and 'fill_value'. But first lets see if this all works.
    # Implicit string concatenation keeps the messages free of the stray
    # backslash / indentation whitespace a continued literal would embed.
    if (boundary_discontinuity is not None) and (pad is not None):
        raise ValueError('Coordinate cannot be wrapped and padded at the '
                         'same time')
    elif (boundary_discontinuity is None) and (pad is None):
        raise ValueError('Either "boundary_discontinuity" or "pad" have '
                         'to be specified')

    if pad is None:
        fill_value = 0.0
        boundary = None
        periodic = True
    elif pad == 'auto':
        fill_value = 0.0
        boundary = 'extrapolate'
        periodic = False
    else:
        fill_value = pad
        boundary = 'fill'
        periodic = False

    kwargs = dict(
        boundary_discontinuity=boundary_discontinuity,
        fill_value=fill_value,
        boundary=boundary,
        position_check=False,
    )

    ds = ds.copy()

    # For a set of coordinates there are two fundamental cases. The
    # coordinates are a) one dimensional (dimensions) or b) multidimensional,
    # separated by the keyword attrs_from_scratch. All a) cases of a dataset
    # must be recreated before b) can be handled, so a) is the 'raw' data
    # processing step. Once working one dimensional coordinates exist, the
    # regular xgcm.Axis is used to interpolate multidimensional coordinates,
    # so changes to Axis.interp propagate directly to this module.
    if attrs_from_scratch:
        # Input coordinate has to be declared as center,
        # or xgcm.Axis throws error. Will be rewrapped below.
        ds[name] = _fill_attrs(ds[name], 'center', axis)

        ax = Axis(ds, axis, periodic=periodic)
        args = ds[name], raw_interp_function, relative_pos_to
        ds.coords[new_name] = ax._neighbor_binary_func_raw(*args, **kwargs)

        # Place the correct attributes
        ds[name] = _fill_attrs(ds[name], pos_from, axis)
        ds[new_name] = _fill_attrs(ds[new_name], pos_to, axis)
    else:
        # position_check is only accepted by the raw neighbor function
        kwargs.pop('position_check', None)
        ax = Axis(ds, axis, periodic=periodic)
        args = ds[name], pos_to
        ds.coords[new_name] = ax.interp(*args, **kwargs)
    return ds
def test_axis_errors():
    """Axis construction and interp must fail with specific ValueErrors."""
    ds = datasets["1d_left"]

    # Remove the axis attribute, first from the center coordinate, then
    # also from the staggered one.
    missing_attr = ds.copy()
    del missing_attr.XC.attrs["axis"]
    with pytest.raises(ValueError,
                       match="Couldn't find a center coordinate for axis X"):
        Axis(missing_attr, "X", periodic=True)
    del missing_attr.XG.attrs["axis"]
    with pytest.raises(ValueError,
                       match="Couldn't find any coordinates for axis X"):
        Axis(missing_attr, "X", periodic=True)

    # XG truncated to 3 points: incompatible length with its original
    # c_grid_axis_shift and with -0.5, then no shift attribute at all.
    chopped = ds.copy().isel(XG=slice(None, 3))
    del chopped["data_g"]
    with pytest.raises(ValueError, match="coordinate XG has incompatible length"):
        Axis(chopped, "X", periodic=True)
    chopped.XG.attrs["c_grid_axis_shift"] = -0.5
    with pytest.raises(ValueError, match="coordinate XG has incompatible length"):
        Axis(chopped, "X", periodic=True)
    del chopped.XG.attrs["c_grid_axis_shift"]
    with pytest.raises(
        ValueError,
        match="Found two coordinates without `c_grid_axis_shift` attribute for axis X",
    ):
        Axis(chopped, "X", periodic=True)

    # Invalid interpolation targets on a valid axis.
    ax = Axis(ds, "X", periodic=True)
    with pytest.raises(ValueError,
                       match="Can't get neighbor pairs for the same position."):
        ax.interp(ds.data_c, "center")
    with pytest.raises(ValueError,
                       match="This axis doesn't contain a `right` position"):
        ax.interp(ds.data_c, "right")
def generate_axis(ds, axis, name, axis_dim, pos_from='center', pos_to='left',
                  boundary_discontinuity=None, pad='auto', new_name=None,
                  attrs_from_scratch=True):
    """
    Create c-grid dimensions (or coordinates) along an axis of a dataset.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with grid information used to construct the c-grid.
    axis : str
        The appropriate xgcm axis, e.g. 'X' for longitudes.
    name : str
        The name of the variable in ds providing the original grid.
    axis_dim : str
        The dimension of ds[name] corresponding to axis. If name itself is a
        dimension, this should be equal to name. Used by ``pad='auto'``.
    pos_from : {'center','left','right'}, optional
        Position of the gridpoints given in 'ds'.
    pos_to : {'left','center','right'}, optional
        Position of the gridpoints to be generated.
    boundary_discontinuity : {None, float}, optional
        If specified, marks the value of discontinuity across the boundary,
        e.g. 360 for global longitude values and 180 for global latitudes.
        Implies a periodic axis.
    pad : {'auto', None, float}, optional
        If specified, determines the padding applied across the boundary.
        If a float is given, that value is used as padding. 'auto' derives
        the value with ``_auto_pad`` and attempts to pad linearly
        extrapolated values. This can be useful for e.g. depth coordinates
        (to reconstruct 0 depth). It can lead to unexpected values when the
        coordinate is multidimensional.
    new_name : str, optional
        Name of the inferred grid variable. Defaults to name + '_' + pos_to.
    attrs_from_scratch : bool, optional
        Determines if the attributes are created from scratch. Should be
        enabled for dimensions and disabled for multidimensional coordinates;
        the latter can only be calculated after the dims are created.

    Returns
    -------
    xarray.Dataset
        A copy of ``ds`` with ``new_name`` added as a coordinate.

    Raises
    ------
    RuntimeError
        If ``ds`` is not an ``xarray.Dataset``, or if both or neither of
        ``boundary_discontinuity`` and ``pad`` are given.
    """
    if not isinstance(ds, xr.Dataset):
        raise RuntimeError("'ds' needs to be xarray.Dataset")

    if new_name is None:
        new_name = name + '_' + pos_to

    # Determine the relative position to interpolate to based on current and
    # desired position
    relative_pos_to = _position_to_relative(pos_from, pos_to)

    # Implicit string concatenation keeps the messages free of the stray
    # backslash / indentation whitespace a continued literal would embed.
    if (boundary_discontinuity is not None) and (pad is not None):
        raise RuntimeError('Coordinate cannot be wrapped and padded at the '
                           'same time')
    elif (boundary_discontinuity is None) and (pad is not None):
        if pad == 'auto':
            fill_value = _auto_pad(ds[name], axis_dim, relative_pos_to)
        else:
            fill_value = pad
        periodic = False
        boundary = 'fill'
    elif (boundary_discontinuity is not None) and (pad is None):
        periodic = True
        fill_value = 0.0
        boundary = None
    else:
        raise RuntimeError('Either "boundary_discontinuity" or "pad" have '
                           'to be specified')

    ds = ds.copy()

    # For a set of coordinates there are two fundamental cases. The
    # coordinates are a) one dimensional (dimensions) or b) multidimensional,
    # separated by the keyword attrs_from_scratch. All a) cases of a dataset
    # must be recreated before b) can be handled, so a) is the 'raw' data
    # processing step. Once working one dimensional coordinates exist, the
    # regular xgcm.Axis is used to interpolate multidimensional coordinates,
    # so changes to Axis.interp propagate directly to this module.
    if attrs_from_scratch:
        # Input coordinate has to be declared as center,
        # or xgcm.Axis throws error. Will be rewrapped below.
        ds[name] = _fill_attrs(ds[name], 'center', axis)

        ax = Axis(ds, axis, periodic=periodic)
        ds.coords[new_name] = ax._neighbor_binary_func_raw(
            ds[name], raw_interp_function, relative_pos_to,
            boundary=boundary, fill_value=fill_value,
            boundary_discontinuity=boundary_discontinuity)

        # Place the correct attributes
        ds[name] = _fill_attrs(ds[name], pos_from, axis)
        ds[new_name] = _fill_attrs(ds[new_name], pos_to, axis)
    else:
        ax = Axis(ds, axis, periodic=periodic)
        ds.coords[new_name] = ax.interp(
            ds[name], pos_to, boundary=boundary, fill_value=fill_value,
            boundary_discontinuity=boundary_discontinuity)
    return ds
def test_axis_diff_and_interp_nonperiodic_2d(all_2d, boundary, axis_name,
                                             varname, this, to,):
    """Compare Axis.interp / Axis.diff on 2D data against a NumPy reference.

    The neighbor arrays are built with ``np.roll`` on periodic axes. On
    non-periodic axes, ``np.pad`` is used with the mode matching
    ``boundary``, followed by a slice that restores the original length.
    """
    ds, periodic, expected = all_2d
    # ``periodic`` is either a container of axis names or a plain bool.
    try:
        ax_periodic = axis_name in periodic
    except TypeError:
        ax_periodic = periodic
    axis = Axis(ds, axis_name, periodic=ax_periodic)
    da = ds[varname]
    # everything is left shift
    data = ds[varname].data
    axis_num = da.get_axis_num(axis.coords[this].name)

    # lookups for numpy.pad
    numpy_pad_arg = {'extend': 'edge', 'fill': 'constant'}
    # args for numpy.pad
    pad_left = (1, 0)
    pad_right = (0, 1)
    pad_none = (0, 0)

    if this == 'center':
        if ax_periodic:
            data_left = np.roll(data, 1, axis=axis_num)
            data_right = data
        else:
            pad_width = [pad_left if i == axis_num else pad_none
                         for i in range(data.ndim)]
            # Multi-dimensional indexing needs a *tuple* of slices; a list of
            # slices is rejected by NumPy >= 1.23.
            the_slice = tuple(slice(0, -1) if i == axis_num else slice(None)
                              for i in range(data.ndim))
            data_left = np.pad(data, pad_width, numpy_pad_arg[boundary])[the_slice]
            data_right = data
    elif this == 'left':
        if ax_periodic:
            data_left = data
            data_right = np.roll(data, -1, axis=axis_num)
        else:
            pad_width = [pad_right if i == axis_num else pad_none
                         for i in range(data.ndim)]
            the_slice = tuple(slice(1, None) if i == axis_num else slice(None)
                              for i in range(data.ndim))
            data_right = np.pad(data, pad_width, numpy_pad_arg[boundary])[the_slice]
            data_left = data

    data_interp = 0.5 * (data_left + data_right)
    data_diff = data_right - data_left

    # determine new dims
    dims = list(da.dims)
    dims[axis_num] = axis.coords[to].name
    coords = {dim: ds[dim] for dim in dims}

    da_interp_expected = xr.DataArray(data_interp, dims=dims, coords=coords)
    da_diff_expected = xr.DataArray(data_diff, dims=dims, coords=coords)

    # periodic axes take no boundary argument
    boundary_arg = boundary if not ax_periodic else None
    da_interp = axis.interp(da, to, boundary=boundary_arg)
    da_diff = axis.diff(da, to, boundary=boundary_arg)

    assert da_interp_expected.equals(da_interp)
    assert da_diff_expected.equals(da_diff)
def test_axis_diff_and_interp_nonperiodic_1d(nonperiodic_1d, boundary, from_center):
    """Check interp/diff/max/min of a non-periodic 1D Axis against NumPy.

    Left/right neighbor arrays are built from the raw data for each
    staggering case. Each reduction is then compared with an explicit
    ``to`` and with the default target position.
    """
    ds, periodic, expected = nonperiodic_1d
    axis = Axis(ds, "X", periodic=periodic)
    dim_len_diff = len(ds.XG) - len(ds.XC)

    if from_center:
        to = (set(expected["axes"]["X"].keys()) - {"center"}).pop()
        coord_to, da = "XG", ds.data_c
    else:
        to, coord_to, da = "center", "XC", ds.data_g

    shift = expected.get("shift") or False
    data = da.data

    shrinks = ((dim_len_diff == 1 and not from_center) or
               (dim_len_diff == -1 and from_center))
    grows = ((dim_len_diff == 1 and from_center) or
             (dim_len_diff == -1 and not from_center))
    if shrinks:
        data_left, data_right = data[:-1], data[1:]
    elif grows:
        data_left = _pad_left(data, boundary)
        data_right = _pad_right(data, boundary)
    elif (shift and not from_center) or (not shift and from_center):
        data_left = _pad_left(data[:-1], boundary)
        data_right = data
    else:
        data_left = data
        data_right = _pad_right(data[1:], boundary)

    # (Axis method name, expected raw values), checked in this order.
    cases = [
        ("interp", 0.5 * (data_left + data_right)),
        ("diff", data_right - data_left),
        ("max", np.maximum(data_right, data_left)),
        ("min", np.minimum(data_right, data_left)),
    ]
    for method_name, values in cases:
        op = getattr(axis, method_name)
        want = xr.DataArray(values, dims=[coord_to],
                            coords={coord_to: ds[coord_to]})
        got = op(da, to, boundary=boundary)
        assert want.equals(got)
        # check without "to" specified
        assert got.equals(op(da, boundary=boundary))
def test_axis_diff_and_interp_nonperiodic_2d(all_2d, boundary, axis_name, varname, this, to):
    """Compare Axis.interp / Axis.diff on 2D data against a NumPy reference.

    The boundary is passed to the ``Axis`` constructor (``None`` for periodic
    axes). For non-periodic axes, the test also checks that passing the
    opposite boundary explicitly to interp/diff yields a different result,
    i.e. that the per-call ``boundary`` argument takes effect.
    """
    ds, periodic, _ = all_2d
    try:
        ax_periodic = axis_name in periodic
    except TypeError:
        ax_periodic = periodic

    boundary_arg = None if ax_periodic else boundary
    axis = Axis(ds, axis_name, periodic=ax_periodic, boundary=boundary_arg)
    da = ds[varname]
    # everything is left shift
    data = ds[varname].data
    axis_num = da.get_axis_num(axis.coords[this])

    pad_mode = {"extend": "edge", "fill": "constant"}

    def _shift_padded(pad, keep):
        # Pad one end along axis_num, then slice so the length is unchanged.
        widths = [pad if i == axis_num else (0, 0) for i in range(data.ndim)]
        index = tuple(keep if i == axis_num else slice(None)
                      for i in range(data.ndim))
        return np.pad(data, widths, pad_mode[boundary])[index]

    if this == "center":
        if ax_periodic:
            data_left = np.roll(data, 1, axis=axis_num)
        else:
            data_left = _shift_padded((1, 0), slice(0, -1))
        data_right = data
    elif this == "left":
        if ax_periodic:
            data_right = np.roll(data, -1, axis=axis_num)
        else:
            data_right = _shift_padded((0, 1), slice(1, None))
        data_left = data

    dims = list(da.dims)
    dims[axis_num] = axis.coords[to]
    coords = {dim: ds[dim] for dim in dims}

    da_interp_expected = xr.DataArray(0.5 * (data_left + data_right),
                                      dims=dims, coords=coords)
    da_diff_expected = xr.DataArray(data_right - data_left,
                                    dims=dims, coords=coords)

    da_interp = axis.interp(da, to)
    da_diff = axis.diff(da, to)
    assert da_interp_expected.equals(da_interp)
    assert da_diff_expected.equals(da_diff)

    if boundary_arg is None:
        return

    if boundary == "extend":
        bad_boundary = "fill"
    elif boundary == "fill":
        bad_boundary = "extend"

    da_interp_wrong = axis.interp(da, to, boundary=bad_boundary)
    assert not da_interp_expected.equals(da_interp_wrong)
    da_diff_wrong = axis.diff(da, to, boundary=bad_boundary)
    assert not da_diff_expected.equals(da_diff_wrong)
def test_axis_errors():
    """Invalid Axis construction and invalid interp requests must raise.

    Only the exception type is checked. The expected messages are kept as
    comments. The former ``pytest.raises(..., message=...)`` keyword never
    compared them against the raised exception; it only set the failure
    text used when nothing was raised. It was also removed in pytest 5.0,
    where passing it is itself an error.
    """
    ds = datasets["1d_left"]

    ds_noattr = ds.copy()
    del ds_noattr.XC.attrs["axis"]
    # expected: "Couldn't find a center coordinate for axis X"
    with pytest.raises(ValueError):
        Axis(ds_noattr, "X", periodic=True)

    del ds_noattr.XG.attrs["axis"]
    # expected: "Couldn't find any coordinates for axis X"
    with pytest.raises(ValueError):
        Axis(ds_noattr, "X", periodic=True)

    ds_chopped = ds.copy()
    del ds_chopped["data_g"]
    ds_chopped["XG"] = ds_chopped["XG"][:-3]
    # expected: "Left coordinate XG has incompatible length 7 (axis_len=9)"
    with pytest.raises(ValueError):
        Axis(ds_chopped, "X", periodic=True)

    ds_chopped.XG.attrs["c_grid_axis_shift"] = -0.5
    # expected: "Right coordinate XG has incompatible length 7 (axis_len=9)"
    with pytest.raises(ValueError):
        Axis(ds_chopped, "X", periodic=True)

    del ds_chopped.XG.attrs["c_grid_axis_shift"]
    # expected: "Coordinate XC has invalid or missing c_grid_axis_shift
    # attribute `None`"
    with pytest.raises(ValueError):
        Axis(ds_chopped, "X", periodic=True)

    ax = Axis(ds, "X", periodic=True)
    # expected: "Can't get neighbor pairs for the same position."
    with pytest.raises(ValueError):
        ax.interp(ds.data_c, "center")

    # expected: "This axis doesn't contain a `right` position"
    with pytest.raises(ValueError):
        ax.interp(ds.data_c, "right")

    # expected: "`boundary=fill` is not allowed with periodic axis X."
    # NOTE(review): 'right' is not a position on this axis (see the previous
    # check), so this call may fail on the missing position before the
    # boundary is validated -- confirm it really exercises the periodic/fill
    # conflict.
    with pytest.raises(ValueError):
        ax.interp(ds.data_c, "right", boundary="fill")