def test_distribution_collection_serialization():
    """Round-trip a DistributionCollection through JSON and compare it to the original.

    Checks, cheapest first:
    1. the structural attributes of the reloaded collection (names, dtypes,
       distributions, normalizations, element counts);
    2. that both collections normalize the same sampled array identically;
    3. that large independent samples from each collection have, for every
       name in ``dists.names``, means and standard deviations agreeing within
       an absolute tolerance of 0.10.
    """
    dists = DistributionCollection([('norm', Normal(5, 2), 2), ('bern', Bernoulli(), 5)])
    dists_loaded = load_from_json(dists.to_json())

    # Structural checks first so a broken round-trip fails fast, before the
    # comparatively expensive sampling below. (Each is asserted exactly once.)
    assert dists_loaded.names == dists.names
    assert dists_loaded.dtype == dists.dtype
    assert dists_loaded.norm_dtype == dists.norm_dtype
    assert dists_loaded.distributions == dists.distributions
    assert dists_loaded.normalizations == dists.normalizations
    assert dists_loaded.nb_elems == dists.nb_elems

    # Normalizing the *same* array with either collection must agree.
    n_normalize = 50
    arr = dists.sample(n_normalize)
    arr_norm = dists.normalize(arr)
    arr_norm_loaded = dists_loaded.normalize(arr)
    for name in dists.names:
        np.testing.assert_allclose(arr_norm[name], arr_norm_loaded[name])

    # Independent samples can only be compared statistically; 0.10 is an
    # absolute tolerance on the mean/std difference for each name.
    n_stats = 10000
    arr = dists.sample(n_stats)
    arr_loaded = dists_loaded.sample(n_stats)
    for name in dists.names:
        assert abs(arr[name].mean() - arr_loaded[name].mean()) <= 0.10, name
        assert abs(arr[name].std() - arr_loaded[name].std()) <= 0.10, name


def check_serialization_distribution(dist, n=100000, atol=0.10):
    """Assert that ``dist`` survives a JSON round-trip unchanged.

    The reloaded object must compare equal to ``dist``, and ``n`` independent
    samples drawn from each must have sample means and sample standard
    deviations that agree within ``atol``.

    Args:
        dist: object exposing ``to_json()``, ``sample(shape)`` and ``==``.
        n: number of samples drawn from each of ``dist`` and its reloaded copy.
        atol: absolute tolerance applied to both the mean difference and the
            std difference. Defaults to 0.10, the value previously hard-coded.
            Widen it for distributions with a large spread, where sampling
            noise alone can exceed 0.10.
    """
    loaded_dist = load_from_json(dist.to_json())
    assert dist == loaded_dist

    arr = dist.sample((n,))
    arr_loaded = loaded_dist.sample((n,))
    mean_diff = abs(arr.mean() - arr_loaded.mean())
    std_diff = abs(arr.std() - arr_loaded.std())
    assert mean_diff <= atol, f"sample means differ by {mean_diff} (> {atol})"
    assert std_diff <= atol, f"sample stds differ by {std_diff} (> {atol})"


def check_serialization_normalization(norm, n=100):
    """Assert that ``norm`` round-trips through JSON without changing behavior.

    The reloaded object must compare equal to ``norm``, and both must map the
    same ``n`` draws of ``np.random.normal(0, 1)`` to matching values
    (compared with ``np.testing.assert_allclose``).
    """
    roundtripped = load_from_json(norm.to_json())
    assert norm == roundtripped

    samples = np.random.normal(0, 1, (n,))
    from_original = norm.normalize(samples)
    from_roundtripped = roundtripped.normalize(samples)
    np.testing.assert_allclose(from_original, from_roundtripped)