def test_rank_hist_tied():
    """Check rank_histogram's output when every rank is tied.

    Observations and all ensemble members are identically zero, so every
    observation ties with every member. The bounds require each bin to hold
    more than 3 counts and no bin to reach 30, i.e. the tied ranks must not
    pile up in a single bin.
    """
    n_samples, n_members = 100, 10
    obs = xr.DataArray(np.zeros(n_samples), dims="a")
    ens = xr.DataArray(np.zeros((n_samples, n_members)), dims=["a", "member"])
    hist = rank_histogram(obs, ens)
    assert hist.min() > 3
    assert hist.max() < 30
def test_rank_histogram_sum(o, f_prob, dim, obj):
    """Check that the rank histogram total equals the observation count.

    ``obj`` selects the input flavour: if it contains "ds", both inputs are
    wrapped in a Dataset under the variable "var"; if it contains "chunked",
    both inputs are chunked. An empty ``dim`` list must raise ``ValueError``.
    """
    var_name = "var"
    as_dataset = "ds" in obj
    if as_dataset:
        o, f_prob = (arr.to_dataset(name=var_name) for arr in (o, f_prob))
    if "chunked" in obj:
        o, f_prob = o.chunk(), f_prob.chunk()

    if dim == []:
        with pytest.raises(ValueError):
            rank_histogram(o, f_prob, dim=dim)
        return

    rank_hist = rank_histogram(o, f_prob, dim=dim)
    if as_dataset:
        # compare the single data variable rather than the whole Dataset
        rank_hist, o = rank_hist[var_name], o[var_name]
    assert_allclose(rank_hist.sum(), o.count())
def test_rank_histogram_sum(o, f_prob, dim, chunk_bool, input_type):
    """Test that the number of samples in the rank histogram is correct.

    The histogram total must equal ``o.count()``, the number of non-missing
    observations. Passing an empty ``dim`` list must raise ``ValueError``.
    Also checks that the output container type matches the input and that
    chunked inputs give a chunked result.
    """
    o, f_prob = modify_inputs(o, f_prob, input_type, chunk_bool)
    if dim == []:
        with pytest.raises(ValueError):
            rank_histogram(o, f_prob, dim=dim)
        return

    rank_hist = rank_histogram(o, f_prob, dim=dim)
    # test that input types equal output types. This must run before any
    # Dataset is unpacked below; afterwards both sides are always DataArrays
    # and the check could never fail.
    # NOTE(review): assumes assign_type_input_output compares container types.
    assign_type_input_output(rank_hist, o)

    if "Dataset" in input_type:
        # only the first data variable is checked
        name = list(o.data_vars)[0]
        rank_hist = rank_hist[name]
        o = o[name]
    assert_allclose(rank_hist.sum(), o.count())
    # test that returns chunks
    assert_chunk(rank_hist, chunk_bool)
    # TODO: test that attributes are kept
    # assert_keep_attrs(rank_hist, o, keep_attrs)
def test_rank_histogram_accessor(o, f_prob, outer_bool):
    """Check that the ``xs`` Dataset accessor matches calling rank_histogram.

    When ``outer_bool`` is true, the forecast is dropped from the Dataset and
    passed to the accessor as an object. Otherwise it is referenced by its
    variable name.
    """
    direct = rank_histogram(o, f_prob)
    ds = xr.Dataset()
    ds["o"] = o
    ds["f_prob"] = f_prob
    if outer_bool:
        ds = ds.drop_vars("f_prob")
        forecast = f_prob
    else:
        forecast = "f_prob"
    via_accessor = ds.xs.rank_histogram("o", forecast)
    assert_allclose(direct, via_accessor)
def test_rank_histogram_values(o, f_prob):
    """Check the edge bins when observations lie outside the forecast range.

    Observations strictly below every forecast value must all land in the
    first bin. Observations strictly above every forecast value must all
    land in the last bin.
    """
    # ``0 * o`` broadcasts the scalar bound onto the shape/coords of ``o``
    below = (f_prob.min() - 1) + 0 * o
    above = (f_prob.max() + 1) + 0 * o
    assert rank_histogram(below, f_prob)[0] == o.size
    assert rank_histogram(above, f_prob)[-1] == o.size
def test_rank_histogram_dask(o_dask, f_prob_dask):
    """Check that dask-backed inputs produce a chunked (dask-backed) result."""
    result = rank_histogram(o_dask, f_prob_dask)
    assert result.chunks is not None
def test_rank_histogram_values(o, f_prob):
    """Test the extreme cases where observations lie outside the forecasts.

    Observations strictly below every forecast value must all land in the
    first bin. Observations strictly above every forecast value must all
    land in the last bin. The shifted observations are anchored to the
    forecast extremes rather than a fixed ``±10`` offset, so the premise
    holds whatever the fixtures' value range.
    """
    # ``0 * o`` broadcasts the scalar bound onto the shape/coords of ``o``
    below = (f_prob.min() - 1) + 0 * o
    above = (f_prob.max() + 1) + 0 * o
    assert rank_histogram(below, f_prob)[0] == o.size
    assert rank_histogram(above, f_prob)[-1] == o.size