def test_discrimination_sum(o, f_prob, dim, chunk_bool, input_type):
    """Check that each discrimination histogram sums to 1.

    The inputs are first converted by ``modify_inputs`` according to
    ``input_type`` and ``chunk_bool``. An empty ``dim`` list must raise
    ``ValueError``. Otherwise the result type must mirror the input type,
    the sums over ``forecast_probability`` for ``event=True`` and
    ``event=False`` must equal 1 wherever they are not NaN, and the result
    must be chunked according to ``chunk_bool``.
    """
    o, f_prob = modify_inputs(o, f_prob, input_type, chunk_bool)
    observed = o > 0.5
    forecast = (f_prob > 0.5).mean("member")

    if dim == []:
        with pytest.raises(ValueError):
            discrimination(observed, forecast, dim=dim)
        return

    disc = discrimination(observed, forecast, dim=dim)
    # the output container type has to follow the input container type
    assign_type_input_output(disc, o)
    if "Dataset" in input_type:
        # compare against the first data variable of the observations
        first_var = list(o.data_vars)[0]
        disc = disc[first_var]

    # NOTE(review): this warning presumably comes from a 0/0 division for
    # empty histograms. Newer NumPy releases word it as "... in divide",
    # so confirm that this filter still matches.
    with suppress_warnings("invalid value encountered in true_divide"):
        totals = {
            flag: disc.sel(event=flag)
            .sum("forecast_probability", skipna=False)
            .values
            for flag in (True, False)
        }

    # numpy is used because xarray's assert_allclose won't compare to a scalar
    for total in totals.values():
        assert np.allclose(total[~np.isnan(total)], 1)

    # the result must be chunked exactly when the inputs were
    assert_chunk(disc, chunk_bool)


def test_discrimination_perfect_values(o):
    """Check the discrimination histograms of a perfect forecast.

    The ensemble consists of 10 identical copies of ``o``, so the
    forecast probability of the event is either 0 or 1 and always agrees
    with the observation.
    """
    ensemble = xr.concat([o] * 10, dim="member")
    disc = discrimination(o > 0.5, (ensemble > 0.5).mean("member"))
    event_hist = disc.sel(event=True)
    no_event_hist = disc.sel(event=False)

    # observed events: all weight in the last (presumably highest) bin
    assert np.allclose(event_hist[-1], 1)
    assert np.allclose(event_hist[:-1], 0)
    # observed non-events: all weight in the first (presumably lowest) bin
    assert np.allclose(no_event_hist[0], 1)
    assert np.allclose(no_event_hist[1:], 0)


def test_discrimination_accessor(o, f_prob, threshold, outer_bool):
    """Check that the ``.xs`` Dataset accessor agrees with the function.

    When ``outer_bool`` is set, the forecast is removed from the Dataset
    and passed to the accessor as an object instead of a variable name.
    """
    obs_binary = o > threshold
    fct_prob = (f_prob > threshold).mean("member")
    actual = discrimination(obs_binary, fct_prob)

    ds = xr.Dataset()
    ds["o"] = obs_binary
    ds["f_prob"] = fct_prob
    if not outer_bool:
        expected = ds.xs.discrimination("o", "f_prob")
    else:
        ds = ds.drop_vars("f_prob")
        expected = ds.xs.discrimination("o", fct_prob)
    assert_allclose(actual, expected)


def test_discrimination_sum(o, f_prob, dim, obj):
    """Check that each discrimination histogram sums to 1.

    ``obj`` selects the input flavour. If it contains ``"ds"``, both
    inputs are wrapped into single-variable Datasets. If it contains
    ``"chunked"``, both inputs are chunked. An empty ``dim`` list must
    raise ``ValueError``.

    NOTE(review): this module-level ``def`` reuses the name of another
    ``test_discrimination_sum`` defined earlier in this file. Only the
    later binding survives, so pytest collects just one of them. Consider
    renaming one.
    """
    wrap_as_dataset = "ds" in obj
    var_name = "var"
    if wrap_as_dataset:
        o = o.to_dataset(name=var_name)
        f_prob = f_prob.to_dataset(name=var_name)
    if "chunked" in obj:
        o, f_prob = o.chunk(), f_prob.chunk()

    observed = o > 0.5
    forecast = (f_prob > 0.5).mean("member")

    if dim == []:
        with pytest.raises(ValueError):
            discrimination(observed, forecast, dim=dim)
        return

    disc = discrimination(observed, forecast, dim=dim)
    if wrap_as_dataset:
        disc = disc[var_name]

    totals = [
        disc.sel(event=flag).sum("forecast_probability", skipna=False).values
        for flag in (True, False)
    ]
    # numpy is used because xarray's assert_allclose won't compare to a scalar
    for total in totals:
        assert np.allclose(total[~np.isnan(total)], 1)


def test_discrimination_dask(o_dask, f_prob_dask):
    """Check that dask-backed inputs produce a dask-backed result."""
    observed = o_dask > 0.5
    forecast = (f_prob_dask > 0.5).mean("member")
    result = discrimination(observed, forecast)
    # a non-None ``chunks`` attribute signals the result is still lazy
    assert result.chunks is not None