def test_reliability(o, f_prob, dim, obj):
    """Test that a reliability object can be generated.

    ``o`` and ``f_prob`` are converted to datasets when ``obj`` contains
    ``"ds"`` and chunked when it contains ``"chunked"``. Both are then
    thresholded at 0.5 and handed to ``reliability``. An empty ``dim`` list
    must raise ``ValueError``; any other ``dim`` must run without error.
    """
    if "ds" in obj:
        name = "var"
        o = o.to_dataset(name=name)
        f_prob = f_prob.to_dataset(name=name)
    if "chunked" in obj:
        o = o.chunk()
        f_prob = f_prob.chunk()

    # Build the inputs once, outside ``pytest.raises``, so that a ValueError
    # raised while preparing them cannot be mistaken for the expected one.
    obs_event = o > 0.5
    fcst_prob = (f_prob > 0.5).mean("member")

    if dim == []:
        with pytest.raises(ValueError):
            reliability(obs_event, fcst_prob, dim=dim)
    else:
        reliability(obs_event, fcst_prob, dim=dim)
def test_reliability_api_and_inputs(o, f_prob, dim, chunk_bool, input_type):
    """Test that reliability keeps chunking and input types.

    The inputs are first passed through ``modify_inputs`` together with
    ``input_type`` and ``chunk_bool``. An empty ``dim`` list must raise
    ``ValueError``; otherwise the result is checked with ``assert_chunk``
    and ``assign_type_input_output``.
    """
    o, f_prob = modify_inputs(o, f_prob, input_type, chunk_bool)

    # Build the inputs once, outside ``pytest.raises``, so that a ValueError
    # raised while preparing them cannot be mistaken for the expected one.
    obs_event = o > 0.5
    fcst_prob = (f_prob > 0.5).mean("member")

    if dim == []:
        with pytest.raises(ValueError):
            reliability(obs_event, fcst_prob, dim=dim)
        return

    actual = reliability(obs_event, fcst_prob, dim=dim)
    # test that returns chunks
    assert_chunk(actual, chunk_bool)
    # TODO: implement a check that attributes are kept, e.g.
    # assert_keep_attrs(actual, o, keep_attrs)
    # test that input types equal output types
    assign_type_input_output(actual, o)
def test_reliability_perfect_values(o):
    """Check reliability output when every member equals the observations."""
    # Ten identical copies of the observations act as the ensemble members.
    members = xr.concat([o] * 10, dim="member")
    obs_event = o > 0.5
    result = reliability(obs_event, (members > 0.5).mean("member"))
    samples = result["samples"]

    n_true = obs_event.sum()
    n_false = (o <= 0.5).sum()

    # first and last entries of the result
    assert np.allclose(result[0], 0)
    assert np.allclose(result[-1], 1)
    # sample counts in the first and last entries, and in total
    assert np.allclose(samples[0], n_false)
    assert np.allclose(samples[-1], n_true)
    assert np.allclose(samples.sum(), o.size)
def test_reliability_accessor(o, f_prob, threshold, outer_bool):
    """Compare the ``xs.reliability`` dataset accessor with the function call.

    With ``outer_bool`` set, the forecast variable is dropped from the
    dataset and handed to the accessor as an object; otherwise both inputs
    are referenced by their variable names.
    """
    obs_event = o > threshold
    fcst_prob = (f_prob > threshold).mean("member")
    actual = reliability(obs_event, fcst_prob)

    ds = xr.Dataset()
    ds["o"] = obs_event
    ds["f_prob"] = fcst_prob
    if not outer_bool:
        expected = ds.xs.reliability("o", "f_prob")
    else:
        ds = ds.drop_vars("f_prob")
        expected = ds.xs.reliability("o", fcst_prob)
    assert_allclose(actual, expected)
def test_reliability_values(o, f_prob):
    """Test 1D reliability values against sklearn calibration_curve.

    For every (lon, lat) point the observations and the member-mean forecast
    probabilities (both thresholded at 0.5) are passed to ``reliability`` and
    to ``calibration_curve``. The non-null reliability values must match the
    sklearn result, and the sample counts must add up to the number of
    observations at that point.
    """
    # sklearn and xhistogram treat bin edges differently: sklearn bins are
    # only left-edge inclusive with 1e-8 added to the last edge, whereas
    # xhistogram includes the rightmost edge. Nudge the upper edge so the
    # binning mimics sklearn's behaviour.
    bin_edges = np.linspace(0, 1 + 1e-8, 6)
    for lon in f_prob.lon:
        for lat in f_prob.lat:
            o_1d = o.sel(lon=lon, lat=lat) > 0.5
            f_1d = (f_prob.sel(lon=lon, lat=lat) > 0.5).mean("member")
            actual = reliability(o_1d, f_1d, probability_bin_edges=bin_edges)
            # ``normalize`` is deliberately not passed: ``False`` was its
            # default and newer scikit-learn releases removed the argument.
            expected, _ = calibration_curve(o_1d, f_1d, n_bins=5, strategy="uniform")
            npt.assert_allclose(actual.where(actual.notnull(), drop=True), expected)
            npt.assert_allclose(actual["samples"].sum(), o_1d.size)
def test_reliability_dask(o_dask, f_prob_dask):
    """Check that dask-backed inputs give a result that is still chunked."""
    obs_event = o_dask > 0.5
    fcst_prob = (f_prob_dask > 0.5).mean("member")
    result = reliability(obs_event, fcst_prob)
    assert result.chunks is not None