def test_distribution_collection_normalization():
    """Check that normalize() followed by denormalize() reproduces a sampled batch.

    Builds a collection from a 'const' component (Constant(5)) and a 'bern'
    component (Bernoulli()), draws a batch of 10 samples, and asserts the
    round trip gives back exactly the original values.
    """
    # NOTE(review): the third tuple element (2 and 5) is presumably a
    # per-component size/width; confirm against DistributionCollection.
    batch_size = 10
    collection = DistributionCollection([
        ('const', Constant(5), 2),
        ('bern', Bernoulli(), 5),
    ])

    samples = collection.sample(batch_size)
    round_trip = collection.denormalize(collection.normalize(samples))

    # Exact element-wise equality is required, not approximate closeness.
    assert (samples == round_trip).all()