Ejemplo n.º 1
0
def test_regrid():
    """End-to-end check of conservative regridding with the ESMPy and scipy paths.

    Conservative regridding is used as the example since it is the most
    well-studied method in the literature.

    Steps:
    1. Building a conservative regrid without cell corners must raise
       ``ValueError``.
    2. After adding corners, the ESMPy result must be within 5% relative
       error of ``data_ref``.
    3. The weights ESMPy wrote to disk, re-applied with scipy
       (``read_weights`` + ``apply_weights``), must equal the ESMPy result.
    4. Applying the same weights to ``data4D_in`` must preserve its means
       over axes 2 and 3 (to 10 decimals).

    The ESMF regrid object and ``test_weights.nc`` (written to the current
    working directory) are cleaned up in ``finally``. A failing assertion
    therefore does not leak them into later runs.
    """
    # TODO: possible to break this long test into smaller tests?
    # Not easy due to strong dependencies between the steps.

    grid_in = esmf_grid(lon_in.T, lat_in.T)
    grid_out = esmf_grid(lon_out.T, lat_out.T)

    # no corner info yet, should not be able to use conservative
    with pytest.raises(ValueError):
        esmf_regrid_build(grid_in, grid_out, 'conservative')

    # now add corners
    add_corner(grid_in, lon_b_in.T, lat_b_in.T)
    add_corner(grid_out, lon_b_out.T, lat_b_out.T)

    # also write to file for scipy regridding
    filename = 'test_weights.nc'
    if os.path.exists(filename):
        os.remove(filename)  # stale file left by an earlier run

    regrid = None
    try:
        regrid = esmf_regrid_build(grid_in,
                                   grid_out,
                                   'conservative',
                                   filename=filename)
        assert regrid.regrid_method is ESMF.RegridMethod.CONSERVE

        # apply regridding using ESMPy's native method
        data_out_esmpy = esmf_regrid_apply(regrid, data_in.T).T

        rel_err = (data_out_esmpy - data_ref) / data_ref  # relative error
        assert np.max(np.abs(rel_err)) < 0.05

        # apply regridding using scipy
        weights = read_weights(filename, lon_in.size, lon_out.size)
        shape_in = lon_in.shape
        shape_out = lon_out.shape
        data_out_scipy = apply_weights(weights, data_in, shape_in, shape_out)

        # must be exactly the same as esmpy's result!
        # TODO: this failed once but could not be replicated.
        # Maybe assert_equal is too strict for the scipy vs esmpy comparison.
        assert_equal(data_out_scipy, data_out_esmpy)

        # finally, test broadcasting with scipy
        # TODO: need to test broadcasting with ESMPy backend?
        # We only use Scipy in frontend, and ESMPy is just for backend benchmark.
        # However, it is useful to compare performance and show scipy is 3x faster.
        data4D_out = apply_weights(weights, data4D_in, shape_in, shape_out)

        # data over broadcasting dimensions should agree
        assert_almost_equal(data4D_in.mean(axis=(2, 3)),
                            data4D_out.mean(axis=(2, 3)),
                            decimal=10)
    finally:
        # clean-up, even when an assertion above fails
        if regrid is not None:
            esmf_regrid_finalize(regrid)
        if os.path.exists(filename):
            os.remove(filename)
Ejemplo n.º 2
0
def test_read_weights(tmp_path):
    """Check that ``read_weights`` gives identical matrices for every input form.

    The weights dict from an in-memory bilinear regrid is the reference. The
    same weights given as a netCDF path (``Path`` and ``str``), as an
    ``xr.Dataset``, and as the matrix returned by ``read_weights`` itself must
    produce the same dense matrix. A missing file must raise ``IOError``; an
    empty dict or a dataset lacking ``col`` must raise ``ValueError``.

    Every dataset opened here is closed with ``with``, so no file handle is
    left open inside ``tmp_path``.
    """
    fn = tmp_path / "weights.nc"
    n_in, n_out = lon_in.size, lon_out.size

    grid_in = esmf_grid(lon_in.T, lat_in.T)
    grid_out = esmf_grid(lon_out.T, lat_out.T)

    regrid_memory = esmf_regrid_build(grid_in, grid_out, method='bilinear')
    # Second build is only used to write the same weights to ``fn``.
    esmf_regrid_build(grid_in, grid_out, method='bilinear', filename=str(fn))

    w = regrid_memory.get_weights_dict(deep_copy=True)
    sm = read_weights(w, n_in, n_out)
    expected = sm.todense()

    # Test Path and string to netCDF file against weights dictionary
    np.testing.assert_array_equal(read_weights(fn, n_in, n_out).todense(), expected)
    np.testing.assert_array_equal(read_weights(str(fn), n_in, n_out).todense(), expected)

    # Test xr.Dataset
    with xr.open_dataset(fn) as ds:
        np.testing.assert_array_equal(read_weights(ds, n_in, n_out).todense(), expected)

    # Test COO matrix
    np.testing.assert_array_equal(read_weights(sm, n_in, n_out).todense(), expected)

    # Test failures
    with pytest.raises(IOError):
        read_weights(tmp_path / "wrong_file.nc", n_in, n_out)

    with pytest.raises(ValueError):
        read_weights({}, n_in, n_out)

    # Open the dataset outside ``pytest.raises`` so that only ``read_weights``
    # can satisfy the expected ValueError.
    with xr.open_dataset(fn) as ds:
        with pytest.raises(ValueError):
            read_weights(ds.drop_vars("col"), n_in, n_out)
Ejemplo n.º 3
0
def test_esmf_extrapolation():
    """Check the corner output value of bilinear regridding with and without extrapolation.

    Without extrapolation the ``[0, 0]`` output cell is exactly 0. With
    inverse-distance extrapolation (3 source points, exponent 1) it rounds
    to 2.0.
    """
    grid_in = esmf_grid(lon_in.T, lat_in.T)
    grid_out = esmf_grid(lon_out.T, lat_out.T)

    # Plain bilinear: the first and last lines/columns are left at 0.
    plain = esmf_regrid_build(grid_in, grid_out, 'bilinear')
    corner_plain = esmf_regrid_apply(plain, data_in.T).T[0, 0]
    assert corner_plain == 0

    # Same regrid, now extrapolating from the nearest source points.
    extrap_options = dict(
        extrap_method='inverse_dist',
        extrap_num_src_pnts=3,
        extrap_dist_exponent=1,
    )
    extrapolated = esmf_regrid_build(grid_in, grid_out, 'bilinear', **extrap_options)
    corner_extrap = esmf_regrid_apply(extrapolated, data_in.T).T[0, 0]
    # The 3 closest points in data_in are 2.010, 2.005 and 1.992,
    # so the extrapolated value should be roughly 2.0.
    assert np.round(corner_extrap, 1) == 2.0
Ejemplo n.º 4
0
def test_esmf_locstream():
    """Check that ``LocStream.from_xarray`` accepts only 1D coordinates.

    The resulting LocStream must work as both the target and the source of
    ``esmf_regrid_build``.
    """
    lon = np.arange(5)
    lat = np.arange(5)

    locstream = LocStream.from_xarray(lon, lat)
    assert isinstance(locstream, ESMF.LocStream)

    # Any 2D coordinate argument must be rejected.
    lon2d, lat2d = np.meshgrid(lon, lat)
    bad_pairs = ((lon2d, lat2d), (lon, lat2d), (lon2d, lat))
    for bad_lon, bad_lat in bad_pairs:
        with pytest.raises(ValueError):
            LocStream.from_xarray(bad_lon, bad_lat)

    # Grid -> LocStream and LocStream -> Grid regrids must both build.
    periodic_grid = Grid.from_xarray(lon_in.T, lat_in.T, periodic=True)
    esmf_regrid_build(periodic_grid, locstream, 'bilinear')
    esmf_regrid_build(locstream, periodic_grid, 'nearest_s2d')
Ejemplo n.º 5
0
def test_to_netcdf(tmp_path):
    """Check that frontend and backend weight files are identical.

    The weights file written by ``Regridder.to_netcdf`` must be identical to
    the one the ESMPy backend writes for the same grids and method.

    Both files are opened with ``with`` so their handles are closed after
    the comparison.
    """
    from xesmf.backend import Grid, esmf_regrid_build

    method = 'bilinear'

    # Let the frontend write the weights to disk
    xfn = tmp_path / 'ESMF_weights.nc'
    regridder = xe.Regridder(ds_in, ds_out, method)
    regridder.to_netcdf(filename=xfn)

    grid_in = Grid.from_xarray(ds_in['lon'].values.T, ds_in['lat'].values.T)
    grid_out = Grid.from_xarray(ds_out['lon'].values.T, ds_out['lat'].values.T)

    # Let the ESMPy backend write the weights to disk
    efn = tmp_path / 'weights.nc'
    esmf_regrid_build(grid_in, grid_out, method=method, filename=str(efn))

    with xr.open_dataset(xfn) as x, xr.open_dataset(efn) as e:
        xr.testing.assert_identical(x, e)
Ejemplo n.º 6
0
    def _write_weight_file(self):
        """Write ESMF regridding weights to ``self.filename``.

        If the file already exists and ``self.reuse_weights`` is true,
        nothing is done and the existing file is kept. Otherwise any existing
        file is deleted first and new weights are written.
        """
        weights_on_disk = os.path.exists(self.filename)
        if weights_on_disk and self.reuse_weights:
            return  # do not compute it again, just read it
        if weights_on_disk:
            os.remove(self.filename)

        regrid = esmf_regrid_build(
            self._grid_in, self._grid_out, self.method, filename=self.filename
        )
        # Only the weights on disk are needed, not the regrid object.
        esmf_regrid_finalize(regrid)
Ejemplo n.º 7
0
def test_esmf_build_bilinear():
    """Check that a bilinear regrid built from ``esmf_grid`` grids is set up correctly.

    The regrid must record the expected unmapped action and method, and its
    source and destination fields must reference the original grid objects.
    """
    grid_in = esmf_grid(lon_in.T, lat_in.T)
    grid_out = esmf_grid(lon_out.T, lat_out.T)

    regrid = esmf_regrid_build(grid_in, grid_out, 'bilinear')
    try:
        # 1 and 0 presumably correspond to ESMF.UnmappedAction.IGNORE and
        # ESMF.RegridMethod.BILINEAR (cf. the enum-based variant of this test).
        assert regrid.unmapped_action == 1
        assert regrid.regrid_method == 0

        # they should share the same memory; these identity checks were
        # previously bare expressions whose result was discarded.
        assert regrid.srcfield.grid is grid_in
        assert regrid.dstfield.grid is grid_out
    finally:
        esmf_regrid_finalize(regrid)
Ejemplo n.º 8
0
def test_esmf_build_bilinear():
    """Check that a bilinear regrid built from ``Grid.from_xarray`` grids is set up correctly.

    The regrid must use ``UnmappedAction.IGNORE`` and ``RegridMethod.BILINEAR``,
    and its source and destination fields must reference the original grid
    objects.
    """
    grid_in = Grid.from_xarray(lon_in.T, lat_in.T)
    grid_out = Grid.from_xarray(lon_out.T, lat_out.T)

    regrid = esmf_regrid_build(grid_in, grid_out, 'bilinear')
    try:
        assert regrid.unmapped_action is ESMF.UnmappedAction.IGNORE
        assert regrid.regrid_method is ESMF.RegridMethod.BILINEAR

        # they should share the same memory; these identity checks were
        # previously bare expressions whose result was discarded.
        assert regrid.srcfield.grid is grid_in
        assert regrid.dstfield.grid is grid_out
    finally:
        esmf_regrid_finalize(regrid)
Ejemplo n.º 9
0
def test_regrid_periodic_correct():
    """Check bilinear regridding from a periodic input grid against ``data_ref``.

    The maximum relative error against ``data_ref`` must stay below 6.5%.
    """
    # Periodicity only needs to be specified for the input grid.
    grid_in = Grid.from_xarray(lon_in.T, lat_in.T, periodic=True)
    grid_out = Grid.from_xarray(lon_out.T, lat_out.T)

    # One periodic dimension: the first axis, longitude.
    assert grid_in.num_peri_dims == 1
    assert grid_in.periodic_dim == 0

    regrid = esmf_regrid_build(grid_in, grid_out, 'bilinear')
    regridded = esmf_regrid_apply(regrid, data_in.T).T

    relative_error = (regridded - data_ref) / data_ref
    assert np.max(np.abs(relative_error)) < 0.065

    # clean-up
    esmf_regrid_finalize(regrid)
Ejemplo n.º 10
0
def test_regrid_periodic_wrong():
    """Check that bilinear regridding from a non-periodic input grid loses data.

    Without a periodic input grid, some output data will be missing, and the
    maximum relative error against ``data_ref`` must be exactly 1.0.
    """
    # Non-periodic grids.
    grid_in = esmf_grid(lon_in.T, lat_in.T)
    grid_out = esmf_grid(lon_out.T, lat_out.T)

    assert grid_in.num_peri_dims == 0
    assert grid_in.periodic_dim is None

    regrid = esmf_regrid_build(grid_in, grid_out, 'bilinear')
    regridded = esmf_regrid_apply(regrid, data_in.T).T

    relative_error = (regridded - data_ref) / data_ref
    assert np.max(np.abs(relative_error)) == 1.0  # some data will be missing

    # clean-up
    esmf_regrid_finalize(regrid)