def test_pm_comparison_stack_dims_when_deterministic(
    PM_da_initialized_1d, comparison, metric
):
    """Check which outputs of a PM comparison keep the ``member`` dimension.

    The first (forecast) output must always carry ``member``. The second
    (reference) output must carry it only for deterministic metrics; for
    probabilistic metrics it must not.
    """
    metric_obj = get_metric_class(metric, PM_METRICS)
    comparison_obj = get_comparison_class(comparison, PM_COMPARISONS)
    forecast, reference = comparison_obj.function(
        PM_da_initialized_1d, metric=metric_obj
    )
    # Required regardless of metric type, so check it up front.
    assert "member" in forecast.dims
    if metric_obj.probabilistic:
        assert "member" not in reference.dims
    else:
        assert "member" in reference.dims
def test_all(PM_da_initialized_1d, comparison, metric):
    """Check PM comparison outputs are NaN-free and dimensionally consistent.

    Deterministic metrics require the forecast and observation outputs to have
    identical ``dims``. For a probabilistic metric whose comparison is listed in
    ``PROBABILISTIC_PM_COMPARISONS``, the forecast dims minus ``member`` must
    equal the observation dims. No dimension check is made otherwise.
    """
    metric_obj = get_metric_class(metric, PM_METRICS)
    comparison_obj = get_comparison_class(comparison, PM_COMPARISONS)
    fcst, verif = comparison_obj.function(PM_da_initialized_1d, metric=metric_obj)

    for output in (fcst, verif):
        assert not output.isnull().any()

    if not metric_obj.probabilistic:
        # Deterministic metrics: dimensions must match exactly.
        assert fcst.dims == verif.dims
    elif comparison_obj.name in PROBABILISTIC_PM_COMPARISONS:
        # Probabilistic: identical apart from the forecast's ``member`` dim.
        assert set(fcst.dims) - {"member"} == set(verif.dims)
def test_get_comparison_class_fail():
    """An unknown comparison name raises KeyError mentioning valid choices.

    The error message is expected to contain "Specify comparison from".
    """
    with pytest.raises(KeyError) as err_info:
        get_comparison_class("not_comparison", PM_COMPARISONS)
    message = str(err_info.value)
    assert "Specify comparison from" in message
def test_get_comparison_class():
    """Looking up ``"m2c"`` yields the ``__m2c`` comparison (matched by name)."""
    assert get_comparison_class("m2c", PM_COMPARISONS).name == __m2c.name