def test_distribution_collection_serialization():
    """Round-trip a DistributionCollection through JSON and compare it to the original.

    The reloaded collection must:
      * expose identical structural attributes (names, dtypes, distributions,
        normalizations, element counts);
      * normalize the same input sample to the same values;
      * produce samples whose per-field mean and std match the original's
        to within an absolute tolerance of 0.10.
    """
    dists = DistributionCollection([('norm', Normal(5, 2), 2), ('bern', Bernoulli(), 5)])
    json_str = dists.to_json()
    dists_loaded = load_from_json(json_str)

    # Structural attributes must survive the round trip exactly. Check them
    # up front (and only once) so a broken serializer fails here with a
    # clear message rather than later inside sampling/normalization.
    assert dists_loaded.names == dists.names
    assert dists_loaded.dtype == dists.dtype
    assert dists_loaded.norm_dtype == dists.norm_dtype
    assert dists_loaded.distributions == dists.distributions
    assert dists_loaded.normalizations == dists.normalizations
    assert dists_loaded.nb_elems == dists.nb_elems

    # Both collections are fed the SAME sample, so their normalized outputs
    # are compared field by field with assert_allclose.
    arr = dists.sample(50)
    arr_norm = dists.normalize(arr)
    arr_norm_loaded = dists_loaded.normalize(arr)
    for name in dists.names:
        np.testing.assert_allclose(arr_norm[name], arr_norm_loaded[name])

    # The two samples here are drawn independently, so only summary
    # statistics can be compared; n is large enough that sampling noise in
    # the mean/std stays well below the 0.10 tolerance for these fields.
    n = 10000
    arr = dists.sample(n)
    arr_loaded = dists_loaded.sample(n)
    for name in dists.names:
        mean_diff = abs(arr[name].mean() - arr_loaded[name].mean())
        std_diff = abs(arr[name].std() - arr_loaded[name].std())
        assert mean_diff <= 0.10, f"mean mismatch for field {name!r}: {mean_diff}"
        assert std_diff <= 0.10, f"std mismatch for field {name!r}: {std_diff}"
def check_serialization_distribution(dist, n=100000, tol=0.10):
    """Assert that ``dist`` survives a JSON round trip.

    The reloaded distribution must compare equal to the original. In
    addition, ``n`` independent samples are drawn from each, and their mean
    and standard deviation must agree to within ``tol`` (absolute).

    Args:
        dist: object exposing ``to_json()``, ``sample(shape)`` and ``==``.
        n: number of samples drawn from each distribution.
        tol: absolute tolerance for the mean/std comparison. Defaults to the
            previously hard-coded 0.10. Widen it for distributions with a
            large spread, where sampling noise alone can exceed 0.10.
    """
    json_str = dist.to_json()
    loaded_dist = load_from_json(json_str)
    assert dist == loaded_dist

    # The two sample sets are independent draws, so only summary statistics
    # are comparable, not the individual values.
    arr = dist.sample((n,))
    arr_loaded = loaded_dist.sample((n,))
    mean_diff = abs(arr.mean() - arr_loaded.mean())
    std_diff = abs(arr.std() - arr_loaded.std())
    assert mean_diff <= tol, f"mean mismatch after round trip: {mean_diff} > {tol}"
    assert std_diff <= tol, f"std mismatch after round trip: {std_diff} > {tol}"
def check_serialization_normalization(norm, n=100):
    """Assert that ``norm`` is unchanged by a JSON round trip.

    The object rebuilt from ``norm.to_json()`` must compare equal to
    ``norm``. Both objects must also map the same standard-normal sample of
    length ``n`` to numerically equal outputs, using
    ``np.testing.assert_allclose`` default tolerances.
    """
    restored = load_from_json(norm.to_json())
    assert norm == restored

    sample = np.random.normal(0, 1, (n,))
    expected, actual = norm.normalize(sample), restored.normalize(sample)
    np.testing.assert_allclose(expected, actual)