コード例 #1
0
def test_visualization_plot_precip_field(
    source,
    type,
    bbox,
    colorscale,
    probthr,
    title,
    colorbar,
    axis,
):
    """Smoke-test ``plot_precip_field`` for the "intensity", "depth" and "prob"
    plot types.

    ``type`` shadows the builtin; the name is kept because it is part of the
    test's signature (presumably supplied by a pytest parametrization).

    Raises
    ------
    ValueError
        If ``type`` is not one of the supported plot types, instead of failing
        later with an UnboundLocalError on ``field``.
    """
    if type in ("intensity", "depth"):
        # One field, converted to the units matching the requested plot type.
        field, metadata = get_precipitation_fields(0, 0, True, True, None,
                                                   source)
        field = field.squeeze()
        if type == "intensity":
            field, metadata = conversion.to_rainrate(field, metadata)
        else:
            field, metadata = conversion.to_raindepth(field, metadata)

    elif type == "prob":
        # A stack of fields (0 previous, 10 next files) turned into
        # exceedance probabilities for the threshold ``probthr``.
        field, metadata = get_precipitation_fields(0, 10, True, True, None,
                                                   source)
        field, metadata = conversion.to_rainrate(field, metadata)
        field = ensemblestats.excprob(field, probthr)

    else:
        raise ValueError(f"unsupported plot type: {type!r}")

    plot_precip_field(
        field,
        type=type,
        bbox=bbox,
        geodata=None,
        colorscale=colorscale,
        probthr=probthr,
        units=metadata["unit"],
        title=title,
        colorbar=colorbar,
        axis=axis,
    )
コード例 #2
0
def test_anvil_rainrate(n_cascade_levels, ar_order, ar_window_radius,
                        timesteps, min_csi):
    """Tests ANVIL nowcast using rain rate precipitation fields."""
    # Skip before any data is loaded when OpenCV is not installed; it is
    # presumably required by the LK optical flow used below.
    pytest.importorskip("cv2")

    # inputs (the metadata is not needed by this test)
    precip_input, _ = get_precipitation_fields(
        num_prev_files=4,
        num_next_files=0,
        return_raw=False,
        metadata=True,
        upscale=2000,
    )
    precip_input = precip_input.filled()

    # Verifying observations; the first field is dropped.
    precip_obs = get_precipitation_fields(num_prev_files=0,
                                          num_next_files=3,
                                          return_raw=False,
                                          upscale=2000)[1:, :, :]
    precip_obs = precip_obs.filled()

    oflow_method = motion.get_method("LK")
    retrieved_motion = oflow_method(precip_input)

    nowcast_method = nowcasts.get_method("anvil")

    # Only the last ar_order + 2 input fields are passed to the nowcast.
    precip_forecast = nowcast_method(
        precip_input[-(ar_order + 2):],
        retrieved_motion,
        timesteps=timesteps,
        rainrate=None,  # no R(VIL) conversion is done
        n_cascade_levels=n_cascade_levels,
        ar_order=ar_order,
        ar_window_radius=ar_window_radius,
    )

    assert precip_forecast.ndim == 3
    assert precip_forecast.shape[0] == (timesteps if isinstance(
        timesteps, int) else len(timesteps))

    result = verification.det_cat_fct(precip_forecast[-1],
                                      precip_obs[-1],
                                      thr=0.1,
                                      scores="CSI")["CSI"]
    assert result > min_csi, f"CSI={result:.2f}, required > {min_csi:.2f}"
コード例 #3
0
def test_sseps(n_ens_members, n_cascade_levels, ar_order, mask_method,
               probmatching_method, win_size, max_crps):
    """Tests SSEPS nowcast."""
    # Skip before any data is loaded when OpenCV is not installed; it is
    # presumably required by the LK optical flow used below.
    pytest.importorskip("cv2")

    # inputs
    precip_input, metadata = get_precipitation_fields(num_prev_files=2,
                                                      num_next_files=0,
                                                      return_raw=False,
                                                      metadata=True,
                                                      upscale=2000)
    precip_input = precip_input.filled()

    # Verifying observations; the first field is dropped.
    precip_obs = get_precipitation_fields(num_prev_files=0,
                                          num_next_files=3,
                                          return_raw=False,
                                          upscale=2000)[1:, :, :]
    precip_obs = precip_obs.filled()

    # Retrieve motion field
    oflow_method = motion.get_method("LK")
    retrieved_motion = oflow_method(precip_input)

    # Run nowcast
    nowcast_method = nowcasts.get_method("sseps")

    precip_forecast = nowcast_method(
        precip_input,
        metadata,
        retrieved_motion,
        win_size=win_size,
        n_timesteps=3,
        n_ens_members=n_ens_members,
        n_cascade_levels=n_cascade_levels,
        ar_order=ar_order,
        seed=42,
        mask_method=mask_method,
        probmatching_method=probmatching_method,
    )

    # result: report the score in the assertion message instead of stdout
    crps = verification.probscores.CRPS(precip_forecast[-1], precip_obs[-1])
    assert crps < max_crps, f"got CRPS={crps:.1f}, required < {max_crps:.1f}"
コード例 #4
0
 def test_sal_translation(self, converter, thr_factor):
     """Compare a field with a copy of itself rolled by 10 along axis 0.

     The first two components returned by ``sal`` must vanish, while the
     third one must not.
     """
     raw, metadata = get_precipitation_fields(
         num_prev_files=0, log_transform=False, metadata=True
     )
     precip, metadata = converter(raw.filled(np.nan), metadata)
     shifted = np.roll(precip, shift=10, axis=0)

     scores = sal(precip, shifted, thr_factor)

     assert np.allclose(scores[0], 0)
     assert np.allclose(scores[1], 0)
     assert not np.allclose(scores[2], 0)
コード例 #5
0
 def test_sal_same_image(self, converter, thr_factor):
     """Comparing a field with itself must yield the SAL triple (0, 0, 0)."""
     raw, metadata = get_precipitation_fields(
         num_prev_files=0, log_transform=False, metadata=True
     )
     precip, metadata = converter(raw.filled(np.nan), metadata)

     scores = sal(precip, precip, thr_factor)

     assert isinstance(scores, tuple)
     assert len(scores) == 3
     assert np.allclose(scores, [0, 0, 0])
コード例 #6
0
def test_visualization_plot_precip_field(source, plot_type, bbox, colorscale,
                                         probthr, title, colorbar, axis):
    """Plot a precipitation field and check that the input array is left
    unmodified by ``plot_precip_field``.

    Raises
    ------
    ValueError
        If ``plot_type`` is not "intensity", "depth" or "prob", instead of
        failing later with an UnboundLocalError on ``field``.
    """
    if plot_type in ("intensity", "depth"):
        # One field, converted to the units matching the requested plot type.
        field, metadata = get_precipitation_fields(0, 0, True, True, None,
                                                   source)
        field = field.squeeze()
        if plot_type == "intensity":
            field, metadata = conversion.to_rainrate(field, metadata)
        else:
            field, metadata = conversion.to_raindepth(field, metadata)

    elif plot_type == "prob":
        # A stack of fields (0 previous, 10 next files) turned into
        # exceedance probabilities for the threshold ``probthr``.
        field, metadata = get_precipitation_fields(0, 10, True, True, None,
                                                   source)
        field, metadata = conversion.to_rainrate(field, metadata)
        field = ensemblestats.excprob(field, probthr)

    else:
        raise ValueError(f"unsupported plot type: {plot_type!r}")

    # Keep a private reference copy and hand the *original* array to the
    # plotting function: passing a copy (as before) made the mutation check
    # below unable to detect anything.
    field_orig = field.copy()
    plot_precip_field(
        field,
        ptype=plot_type,
        bbox=bbox,
        geodata=None,
        colorscale=colorscale,
        probthr=probthr,
        units=metadata["unit"],
        title=title,
        colorbar=colorbar,
        axis=axis,
    )

    # Check that plot_precip_field does not modify the input data. NaNs never
    # compare equal, so invalid values are replaced by a sentinel first.
    field_orig = np.ma.masked_invalid(field_orig)
    field_orig.data[field_orig.mask] = -100
    field = np.ma.masked_invalid(field)
    field.data[field.mask] = -100
    assert np.array_equal(field_orig.data, field.data)
コード例 #7
0
def test_feature_tstorm_detection(source, output_feat, dry_input,
                                  max_num_features):
    """Check the output format of the tstorm ``detection`` routine.

    With ``output_feat`` the result must be an (N, 2) array. Otherwise it must
    be a (DataFrame, label array) tuple. With ``dry_input`` an all-zero 50x50
    field is used, for which no cell may be detected.
    """
    pytest.importorskip("skimage")
    pytest.importorskip("pandas")

    # ``input_field`` rather than ``input`` so the builtin is not shadowed.
    if not dry_input:
        input_field, metadata = get_precipitation_fields(0, 0, True, True,
                                                         None, source)
        input_field = input_field.squeeze()
        input_field, _ = to_reflectivity(input_field, metadata)
    else:
        input_field = np.zeros((50, 50))

    time = "000"
    output = detection(input_field,
                       time=time,
                       output_feat=output_feat,
                       max_num_features=max_num_features)

    if output_feat:
        assert isinstance(output, np.ndarray)
        assert output.ndim == 2
        assert output.shape[1] == 2
        if max_num_features is not None:
            assert output.shape[0] <= max_num_features
    else:
        assert isinstance(output, tuple)
        assert len(output) == 2
        cells, labels = output
        assert isinstance(cells, DataFrame)
        assert isinstance(labels, np.ndarray)
        if max_num_features is not None:
            assert cells.shape[0] <= max_num_features
        assert cells.shape[1] == 9
        assert list(cells.columns) == [
            "ID",
            "time",
            "x",
            "y",
            "cen_x",
            "cen_y",
            "max_ref",
            "cont",
            "area",
        ]
        assert (cells.time == time).all()
        assert labels.ndim == 2
        assert labels.shape == input_field.shape
        if not dry_input:
            assert cells.shape[0] > 0
            # np.unique is already sorted; dropping its first (smallest)
            # value, presumably the background label 0, leaves the cell IDs.
            assert sorted(cells.ID) == list(np.unique(labels)[1:])
        else:
            assert cells.shape[0] == 0
            assert labels.sum() == 0
コード例 #8
0
ファイル: test_exporters.py プロジェクト: pySTEPS/pysteps
def test_io_export_netcdf_one_member_one_time_step():
    """
    Export a single-member, single-time-step forecast to netCDF and check
    that the pysteps netCDF importer reads it back unchanged.
    """
    pytest.importorskip("pyproj")

    precip, metadata = get_precipitation_fields(
        return_raw=True, metadata=True, source="fmi"
    )
    precip = precip.squeeze()
    invalid_mask = get_invalid_mask(precip)

    with tempfile.TemporaryDirectory() as outpath:
        prefix = "test_netcdf_out"
        nc_path = os.path.join(outpath, f"{prefix}.nc")

        exporter = initialize_forecast_exporter_netcdf(
            outpath,
            prefix,
            metadata["timestamps"][0],  # startdate
            metadata["accutime"],  # timestep
            1,  # n_timesteps
            precip.shape,
            metadata,
            n_ens_members=1,
        )
        export_forecast_dataset(precip[np.newaxis, :], exporter)
        close_forecast_files(exporter)

        # The exporter must have written a non-empty file.
        assert os.path.exists(nc_path) and os.path.getsize(nc_path) > 0

        # Default import: values preserved, single precision.
        precip_new, _ = import_netcdf_pysteps(nc_path)
        assert_array_almost_equal(precip, precip_new.data)
        assert precip_new.dtype == "single"

        # Explicitly requested double precision.
        precip_new, _ = import_netcdf_pysteps(nc_path, dtype="double")
        assert_array_almost_equal(precip, precip_new.data)
        assert precip_new.dtype == "double"

        # Invalid pixels come back filled with the requested value.
        precip_new, _ = import_netcdf_pysteps(nc_path, fillna=-1000)
        assert ((precip_new == -1000) == invalid_mask).all()
コード例 #9
0
def test_feature(method, max_num_features):
    """Check that a feature detector returns a non-empty (N, 2) array."""
    field, _ = get_precipitation_fields(0, 0, True, True, None, "mch")
    detect = feature.get_method(method)

    points = detect(field.squeeze(), max_num_features=max_num_features)

    assert isinstance(points, np.ndarray)
    assert points.ndim == 2
    n_points, n_coords = points.shape
    assert n_points > 0
    if max_num_features is not None:
        assert n_points <= max_num_features
    assert n_coords == 2
コード例 #10
0
 def test_sal_zeros(self, converter, thr_factor):
     """Check SAL when one or both of the compared fields are zero.

     ``precip * 0`` zeroes the finite values; NaN pixels stay NaN.
     """
     raw, metadata = get_precipitation_fields(
         num_prev_files=0, log_transform=False, metadata=True
     )
     precip, metadata = converter(raw.filled(np.nan), metadata)

     # Both arguments zero: every component is NaN.
     assert np.isnan(sal(precip * 0, precip * 0, thr_factor)).all()

     # Zero first argument, real second argument.
     zero_vs_field = sal(precip * 0, precip, thr_factor)
     assert zero_vs_field[:2] == (-2, -2)
     assert np.isnan(zero_vs_field[2])

     # Real first argument, zero second argument.
     field_vs_zero = sal(precip, precip * 0, thr_factor)
     assert field_vs_zero[:2] == (2, 2)
     assert np.isnan(field_vs_zero[2])
コード例 #11
0
ファイル: test_plt_cartopy.py プロジェクト: AFansGH/pysteps
def test_visualization_plot_precip_field(source, map, drawlonlatlines, lw):
    """Plot a rain-rate field with the given map backend and line options."""
    precip, metadata = get_precipitation_fields(0, 0, True, True, None, source)
    rain_rate, _ = to_rainrate(precip.squeeze(), metadata)

    plot_precip_field(
        rain_rate,
        type="intensity",
        geodata=metadata,
        map=map,
        drawlonlatlines=drawlonlatlines,
        lw=lw,
    )
コード例 #12
0
ファイル: test_plt_cartopy.py プロジェクト: wolfidan/pysteps
def test_visualization_plot_precip_field(source, map_kwargs, pass_geodata):
    """Plot a rain-rate field, with or without its georeferencing metadata."""
    precip, metadata = get_precipitation_fields(0, 0, True, True, None, source)
    rain_rate, _ = to_rainrate(precip.squeeze(), metadata)

    geodata = metadata if pass_geodata else None
    plot_precip_field(
        rain_rate,
        type="intensity",
        geodata=geodata,
        map_kwargs=map_kwargs,
    )
コード例 #13
0
def test_visualization_motionfields_quiver(
    source,
    axis,
    step,
    quiver_kwargs,
    map_kwargs,
    upscale,
    pass_geodata,
):
    """Draw a motion field with ``quiver`` and check it is not modified."""
    if source is None:
        # Synthetic 100x100 field: u is constant (100), v grows with the row.
        n_rows, n_cols = 100, 100
        geodata = None
        ax = None
        u_comp, v_comp = np.meshgrid(np.ones(n_cols) * n_rows,
                                     np.arange(0, n_rows))
        motion_field = np.concatenate([u_comp[None, :], v_comp[None, :]])
    else:
        fields, geodata = get_precipitation_fields(0, 2, False, True, upscale,
                                                   source)
        if not pass_geodata:
            geodata = None
        ax = plot_precip_field(fields[-1], geodata=geodata)
        motion_field = motion.get_method("LK")(fields)

    motion_before = motion_field.copy()
    quiver(
        motion_field,
        ax,
        geodata,
        axis,
        step,
        quiver_kwargs,
        map_kwargs=map_kwargs,
    )

    # quiver must leave its input array untouched.
    assert np.array_equal(motion_field, motion_before)
コード例 #14
0
def test_visualization_motionfields_streamplot(
    source, axis, streamplot_kwargs, map_kwargs, upscale, pass_geodata
):
    """Draw a motion field with ``streamplot``.

    When ``pass_geodata`` is false the georeferencing metadata is dropped so
    that plotting without geodata is exercised as well.
    """
    if source is not None:
        fields, geodata = get_precipitation_fields(0, 2, False, True, upscale, source)
        if not pass_geodata:
            # Drop the geodata itself; resetting ``pass_geodata`` (as before)
            # had no effect and the no-geodata case was never tested.
            geodata = None
        ax = plot_precip_field(fields[-1], geodata=geodata)
        oflow_method = motion.get_method("LK")
        UV = oflow_method(fields)

    else:
        # Synthetic 100x100 field: u is constant (100), v grows with the row.
        shape = (100, 100)
        geodata = None
        ax = None
        u = np.ones(shape[1]) * shape[0]
        v = np.arange(0, shape[0])
        U, V = np.meshgrid(u, v)
        UV = np.concatenate([U[None, :], V[None, :]])

    streamplot(UV, ax, geodata, axis, streamplot_kwargs, map_kwargs=map_kwargs)
コード例 #15
0
def test_visualization_motionfields_quiver(
    source,
    map,
    drawlonlatlines,
    lw,
    axis,
    step,
    quiver_kwargs,
    upscale,
):
    """Draw a motion field with ``quiver`` on the requested map backend."""
    if map == "cartopy":
        pytest.importorskip("cartopy")
    elif map == "basemap":
        # Basemap is importable only as ``mpl_toolkits.basemap``; there is no
        # top-level ``basemap`` module, so importorskip("basemap") always
        # skipped this case even when Basemap was installed.
        pytest.importorskip("mpl_toolkits.basemap")

    if source is not None:
        fields, geodata = get_precipitation_fields(0, 2, False, True, upscale,
                                                   source)
        ax = plot_precip_field(
            fields[-1],
            map=map,
            geodata=geodata,
        )
        oflow_method = motion.get_method("LK")
        UV = oflow_method(fields)

    else:
        # Synthetic 100x100 field: u is constant (100), v grows with the row.
        shape = (100, 100)
        geodata = None
        ax = None
        u = np.ones(shape[1]) * shape[0]
        v = np.arange(0, shape[0])
        U, V = np.meshgrid(u, v)
        UV = np.concatenate([U[None, :], V[None, :]])

    quiver(UV, ax, map, geodata, drawlonlatlines, lw, axis, step,
           quiver_kwargs)
コード例 #16
0
def test_tracking_tdating_dating(source, dry_input):
    """Check the output structure of the tdating ``dating`` routine.

    With ``dry_input`` three all-zero 50x50 frames with dummy timestamps are
    used, for which nothing may be detected or tracked.
    """
    pytest.importorskip("skimage")
    pytest.importorskip("pandas")

    # ``input_fields`` rather than ``input`` so the builtin is not shadowed.
    if not dry_input:
        input_fields, metadata = get_precipitation_fields(0, 2, True, True,
                                                          4000, source)
        input_fields, _ = to_reflectivity(input_fields, metadata)
    else:
        input_fields = np.zeros((3, 50, 50))
        metadata = {"timestamps": ["00", "01", "02"]}

    timelist = metadata["timestamps"]

    output = dating(input_fields, timelist, mintrack=1)

    # Check output format: a 3-tuple of lists, the last two holding one
    # entry per input frame (presumably cell tables and label arrays).
    assert isinstance(output, tuple)
    assert len(output) == 3
    assert isinstance(output[0], list)
    assert isinstance(output[1], list)
    assert isinstance(output[2], list)
    assert len(output[1]) == input_fields.shape[0]
    assert len(output[2]) == input_fields.shape[0]
    assert isinstance(output[1][0], DataFrame)
    assert isinstance(output[2][0], np.ndarray)
    assert output[1][0].shape[1] == 8
    assert output[2][0].shape == input_fields.shape[1:]
    if not dry_input:
        assert len(output[0]) > 0
        assert isinstance(output[0][0], DataFrame)
        assert output[0][0].shape[1] == 8
    else:
        assert len(output[0]) == 0
        assert output[1][0].shape[0] == 0
        assert output[2][0].sum() == 0
コード例 #17
0
These tests check that the retrieved motion fields are within reasonable values.
They will also fail if any modification of the code decreases the quality of
the retrieval.
"""

import numpy as np
import pytest
from scipy.ndimage import uniform_filter

import pysteps as stp
from pysteps import motion
from pysteps.motion.vet import morph
from pysteps.tests.helpers import get_precipitation_fields, smart_assert

# Reference precipitation field to which the idealized motion fields are
# applied.
reference_field = get_precipitation_fields(
    num_prev_files=0,
)


def _create_motion_field(input_precip, motion_type):
    """
    Create idealized motion fields to be applied to the reference image.

    Parameters
    ----------

    input_precip: numpy array (lat, lon)

    motion_type : str
        The supported motion fields are:

            - linear_x: (u=2, v=0)
コード例 #18
0
# -*- coding: utf-8 -*-

import pytest

from pysteps import downscaling
from pysteps.tests.helpers import get_precipitation_fields
from pysteps.utils import aggregate_fields_space, square_domain

# load and preprocess input field
# Single input field: masked values filled, then cropped to a square domain.
precip, metadata = get_precipitation_fields(
    num_prev_files=0,
    num_next_files=0,
    return_raw=False,
    metadata=True,
)
precip, metadata = square_domain(precip.filled(), metadata, "crop")

# Argument names and value combinations for the RainFARM tests.
rainfarm_arg_names = ("alpha", "ds_factor", "threshold", "return_alpha")

rainfarm_arg_values = [
    (1.0, 1, 0, False),
    (1, 2, 0, False),
    (1, 4, 0, False),
]
コード例 #19
0
def test_steps_callback(tmp_path):
    """Test STEPS callback functionality to export the output as a netcdf."""
    n_ens_members = 2
    n_timesteps = 3

    precip_input, metadata = get_precipitation_fields(
        num_prev_files=2,
        num_next_files=0,
        return_raw=False,
        metadata=True,
        upscale=2000,
    )
    precip_input = precip_input.filled()
    field_shape = (precip_input.shape[1], precip_input.shape[2])
    startdate = metadata["timestamps"][-1]
    timestep = metadata["accutime"]

    # Zero advection field with one layer per motion component.
    motion_field = np.zeros((2, *field_shape))

    exporter = io.initialize_forecast_exporter_netcdf(
        outpath=tmp_path.as_posix(),
        outfnprefix="test_steps",
        startdate=startdate,
        timestep=timestep,
        n_timesteps=n_timesteps,
        shape=field_shape,
        n_ens_members=n_ens_members,
        metadata=metadata,
        incremental="timestep",
    )

    def callback(array):
        # Forward each chunk produced by the nowcast to the exporter
        # (presumably one time step at a time, given incremental="timestep").
        return io.export_forecast_dataset(array, exporter)

    try:
        precip_output = nowcasts.get_method("steps")(
            precip_input,
            motion_field,
            timesteps=n_timesteps,
            R_thr=metadata["threshold"],
            kmperpixel=2.0,
            timestep=timestep,
            seed=42,
            n_ens_members=n_ens_members,
            vel_pert_method=None,
            callback=callback,
            return_output=True,
        )
    finally:
        # Close the netCDF file even if the nowcast raises, so the file
        # handle is not leaked.
        io.close_forecast_files(exporter)

    # assert that netcdf exists and its size is not zero
    tmp_file = os.path.join(tmp_path, "test_steps.nc")
    assert os.path.exists(tmp_file) and os.path.getsize(tmp_file) > 0

    # assert that the file can be read by the nowcast importer
    precip_netcdf, metadata_netcdf = io.import_netcdf_pysteps(tmp_file, dtype="float64")

    # assert that the dimensionality of the array is as expected
    assert precip_netcdf.ndim == 4, "Wrong number of dimensions"
    assert precip_netcdf.shape[0] == n_ens_members, "Wrong ensemble size"
    assert precip_netcdf.shape[1] == n_timesteps, "Wrong number of lead times"
    assert precip_netcdf.shape[2:] == field_shape, "Wrong field shape"

    # assert that the saved output is the same as the original output
    assert np.allclose(
        precip_netcdf, precip_output, equal_nan=True
    ), "Wrong output values"

    # assert that leadtimes and timestamps are as expected
    td = timedelta(minutes=timestep)
    leadtimes = [(i + 1) * timestep for i in range(n_timesteps)]
    timestamps = [startdate + (i + 1) * td for i in range(n_timesteps)]
    assert (metadata_netcdf["leadtimes"] == leadtimes).all(), "Wrong leadtimes"
    assert (metadata_netcdf["timestamps"] == timestamps).all(), "Wrong timestamps"
コード例 #20
0
def test_lk(
    lk_kwargs,
    fd_method,
    dense,
    nr_std_outlier,
    k_outlier,
    size_opening,
    decl_scale,
    verbose,
):
    """Check the output format of the Lucas-Kanade optical flow method."""
    pytest.importorskip("cv2")
    # Extra optional dependencies of some feature-detection methods.
    if fd_method in ("blob", "tstorm"):
        pytest.importorskip("skimage")
    if fd_method == "tstorm":
        pytest.importorskip("pandas")

    precip, _ = get_precipitation_fields(
        num_prev_files=2,
        num_next_files=0,
        return_raw=False,
        metadata=True,
        upscale=2000,
    )
    precip = precip.filled()

    output = motion.get_method("LK")(
        precip,
        lk_kwargs=lk_kwargs,
        fd_method=fd_method,
        dense=dense,
        nr_std_outlier=nr_std_outlier,
        k_outlier=k_outlier,
        size_opening=size_opening,
        decl_scale=decl_scale,
        verbose=verbose,
    )

    if dense:
        # Dense output: a (2, rows, cols) array matching one input field.
        assert isinstance(output, np.ndarray)
        assert output.ndim == 3
        assert output.shape[0] == 2
        assert output.shape[1:] == precip[0].shape
        if nr_std_outlier == 0:
            assert output.sum() == 0
        return

    # Sparse output: two (N, 2) arrays with the same number of rows.
    assert isinstance(output, tuple)
    assert len(output) == 2
    for arr in output:
        assert isinstance(arr, np.ndarray)
        assert arr.ndim == 2
        assert arr.shape[1] == 2
    assert output[0].shape[0] == output[1].shape[0]
    if nr_std_outlier == 0:
        assert output[0].shape[0] == 0
        assert output[1].shape[0] == 0
コード例 #21
0
# -*- coding: utf-8 -*-

import numpy as np
import pytest
from numpy.testing import assert_array_almost_equal

from pysteps.tests.helpers import get_precipitation_fields
from pysteps.verification import ensscores

precip = get_precipitation_fields(num_next_files=10, return_raw=True)
np.random.seed(42)

# rankhist: the first ten fields act as ensemble members, the last one as the
# observation; the expected histogram size is 11, one more than the members.
test_data = [
    (precip[:10], precip[-1], None, normalize, 11)
    for normalize in (True, False)
]


@pytest.mark.parametrize("X_f, X_o, X_min, normalize, expected", test_data)
def test_rankhist_size(X_f, X_o, X_min, normalize, expected):
    """Check the number of bins returned by ``rankhist``."""
    hist = ensscores.rankhist(X_f, X_o, X_min, normalize)
    assert_array_almost_equal(hist.size, expected)


# ensemble_skill
test_data = [
    (
        precip[:10],
コード例 #22
0
# -*- coding: utf-8 -*-

import pytest
from numpy.testing import assert_array_almost_equal

from pysteps.tests.helpers import get_precipitation_fields
from pysteps.verification import spatialscores

# The BMSE case below uses a wavelet ("Haar"); it is only tested when pywt
# can be imported.
try:
    import pywt
except ImportError:
    PYWT_IMPORTED = False
else:
    PYWT_IMPORTED = True

R = get_precipitation_fields(num_prev_files=1, return_raw=True)
test_data = [(R[0], R[1], "FSS", [1], [10], None, 0.85161531)]
if PYWT_IMPORTED:
    test_data.append((R[0], R[1], "BMSE", [1], None, "Haar", 0.99989651))


@pytest.mark.parametrize("X_f, X_o, name, thrs, scales, wavelet, expected",
                         test_data)
def test_intensity_scale(X_f, X_o, name, thrs, scales, wavelet, expected):
    """Compare entry [0][0] of ``intensity_scale`` with a reference value."""
    scores = spatialscores.intensity_scale(X_f, X_o, name, thrs, scales,
                                           wavelet)
    assert_array_almost_equal(scores[0][0], expected)
コード例 #23
0
ファイル: test_plt_animate.py プロジェクト: pySTEPS/pysteps
# -*- coding: utf-8 -*-

import os

import numpy as np
import pytest
from unittest.mock import patch

from pysteps.tests.helpers import get_precipitation_fields
from pysteps.visualization.animations import animate

# Raw, upscaled input fields (current plus two previous files) and their
# metadata, shared by all animation tests.
PRECIP, METADATA = get_precipitation_fields(
    upscale=2000,
    metadata=True,
    return_raw=True,
    num_next_files=0,
    num_prev_files=2,
)

VALID_ARGS = (
    ([PRECIP], {}),
    ([PRECIP], {
        "title": "title"
    }),
    ([PRECIP], {
        "timestamps_obs": METADATA["timestamps"]
    }),
    ([PRECIP], {
        "geodata": METADATA,
        "map_kwargs": {
            "plot_map": None
コード例 #24
0
import numpy as np

from pysteps.noise import fftgenerators
from pysteps.tests.helpers import get_precipitation_fields

# Single upscaled field without metadata, with masked values filled in.
PRECIP = get_precipitation_fields(
    upscale=2000,
    metadata=False,
    return_raw=False,
    num_next_files=0,
    num_prev_files=0,
).filled()


def test_noise_param_2d_fft_filter():
    """Build a parametric 2D FFT filter and draw one noise field from it."""
    fft_filter = fftgenerators.initialize_param_2d_fft_filter(PRECIP)

    assert isinstance(fft_filter, dict)
    required_keys = ("field", "input_shape", "model", "pars")
    assert all(key in fft_filter for key in required_keys)

    noise = fftgenerators.generate_noise_2d_fft_filter(fft_filter)

    assert isinstance(noise, np.ndarray)
    assert noise.shape == PRECIP.shape


def test_noise_nonparam_2d_fft_filter():