def test_fit_single_compare():
    """Check that ``from_samples_scipy`` and ``from_samples`` agree on a tiny sample.

    Both constructors are fitted to the same two samples, ``[0.1, 0.2]``. The
    fitted ``loc`` and ``scale`` must then match within an absolute tolerance
    of 0.1.
    """
    samples = [0.1, 0.2]
    scipy_fit = Logistic.from_samples_scipy(onp.array(samples))
    jax_fit = Logistic.from_samples(np.array(samples))

    # The from_samples values are turned into Python floats before being wrapped
    # in pytest.approx, presumably so approx receives a plain scalar rather than
    # an array type.
    jax_loc = float(jax_fit.loc)
    assert scipy_fit.loc == pytest.approx(jax_loc, abs=0.1)

    jax_scale = float(jax_fit.scale)
    assert scipy_fit.scale == pytest.approx(jax_scale, abs=0.1)
def test_fit_single_jax():
    """Check that ``Logistic.from_samples`` centres a symmetric sample at its midpoint.

    For the two-point sample ``[0.1, 0.2]``, the fitted ``loc`` is expected to be
    0.15 (the midpoint), within an absolute tolerance of 0.02.
    """
    dist = Logistic.from_samples(np.array([0.1, 0.2]))
    # Coerce to a Python float first, as test_fit_single_compare does, so the
    # comparison is between plain scalars. Comparing an array object directly
    # against pytest.approx depends on the array's __eq__ deferring correctly.
    assert float(dist.loc) == pytest.approx(0.15, abs=0.02)