def test_check_compute_fn():
    """Exercise the ``check_compute_fn`` flag of ``_BaseRegressionEpoch``.

    With ``check_compute_fn=True``, the first ``update`` call with a
    ``compute_fn`` that always raises is expected to emit an
    ``EpochMetricWarning`` whose message mentions ``compute_fn``.
    With ``check_compute_fn=False``, the same ``update`` call is expected
    to complete without an exception escaping the test.
    """

    def failing_compute_fn(y_preds, y_targets):
        # Deliberately broken so the metric's compute_fn check has something to report.
        raise Exception

    # Same call order as before (rand, then randint) so the RNG stream is unchanged.
    y_pred = torch.rand(4, 1).float()
    y_true = torch.randint(0, 2, size=(4, 1), dtype=torch.float32)
    batch = (y_pred, y_true)

    checked_metric = _BaseRegressionEpoch(failing_compute_fn, check_compute_fn=True)
    checked_metric.reset()
    expected_msg = r"Probably, there can be a problem with `compute_fn`"
    with pytest.warns(EpochMetricWarning, match=expected_msg):
        checked_metric.update(batch)

    unchecked_metric = _BaseRegressionEpoch(failing_compute_fn, check_compute_fn=False)
    unchecked_metric.update(batch)
def test_base_regression_compute_fn():
    """Constructing ``_BaseRegressionEpoch`` with a non-callable must raise ``TypeError``."""
    # An int stands in for a "wrong compute function": it is not callable.
    not_a_function = 12345
    with pytest.raises(TypeError):
        _BaseRegressionEpoch(not_a_function)