def test_mixed_2(histogram):
    """Fit a 3-component logistic mixture to a histogram plus interval conditions.

    The test makes two kinds of checks:

    * The fitted density is near zero at x=-5 and x=6.
    * A second, independently built but identical condition tuple hashes to
      the same value as the first. It can also retrieve a dict entry stored
      under the first tuple.
    """

    def build_conditions():
        # Build brand-new condition objects on every call so the hash/lookup
        # assertions below compare by value, not by object identity.
        return (
            HistogramCondition(histogram["xs"], histogram["densities"]),
            IntervalCondition(p=0.4, max=1),
            IntervalCondition(p=0.45, max=1.2),
            IntervalCondition(p=0.48, max=1.3),
            IntervalCondition(p=0.5, max=2),
            IntervalCondition(p=0.7, max=2.2),
            IntervalCondition(p=0.9, max=2.3),
        )

    conditions = build_conditions()
    dist = LogisticMixture.from_conditions(conditions, num_components=3, verbose=True)

    for far_x in (-5, 6):
        assert dist.pdf1(far_x) == pytest.approx(0, abs=0.1)

    cache = {conditions: 3}
    conditions_2 = build_conditions()
    assert hash(conditions) == hash(conditions_2)
    assert cache[conditions_2] == 3
def test_normalization_histogram_condition(histogram):
    """Check HistogramCondition.normalize / denormalize.

    1. Normalizing with range (10, 1000) and then denormalizing with the same
       range should reproduce the original xs and densities to within 0.1%
       relative error.
    2. As a rough directional check, ``normalize(1, 4)`` should yield every x
       smaller than the original and every density larger than the original.
    """
    xs = histogram["xs"]
    densities = histogram["densities"]
    original = HistogramCondition(xs, densities)

    round_tripped = original.normalize(10, 1000).denormalize(10, 1000)
    # zip() silently truncates to the shortest input. Without pinning the
    # lengths, an empty or shortened result would make every per-element
    # assertion below pass vacuously.
    assert len(round_tripped.xs) == len(xs)
    assert len(round_tripped.densities) == len(densities)
    for density, rt_density in zip(densities, round_tripped.densities):
        assert density == pytest.approx(rt_density, rel=0.001)
    for x, rt_x in zip(xs, round_tripped.xs):
        assert x == pytest.approx(rt_x, rel=0.001)

    # Rough check that xs and densities are at least transformed in the
    # expected direction (not an exact check of the normalization formula).
    normalized = original.normalize(1, 4)
    assert len(normalized.xs) == len(xs)
    assert len(normalized.densities) == len(densities)
    for orig_x, orig_density, norm_x, norm_density in zip(
        xs, densities, normalized.xs, normalized.densities
    ):
        assert orig_x > norm_x
        assert orig_density < norm_density
def test_mixture_from_histogram(histogram):
    """Fit a 3-component logistic mixture to a single histogram condition.

    The fitted pdf must match each histogram density, at its x, to within an
    absolute tolerance of 0.2.
    """
    xs, densities = histogram["xs"], histogram["densities"]
    mixture = LogisticMixture.from_conditions(
        [HistogramCondition(xs, densities)],
        num_components=3,
        verbose=True,
    )
    for point, expected in zip(xs, densities):
        assert mixture.pdf1(point) == pytest.approx(expected, abs=0.2)
def compile_mixture_loss_functions(num_bins: int = 201):
    """Print a notice and run one throwaway 3-component mixture fit.

    The fit target is a HistogramDist built from an array of ``num_bins``
    copies of ``-float(num_bins)``. The fitted result is discarded.

    NOTE(review): judging by the name and the printed message, this presumably
    exists to trigger compilation of the mixture loss functions ahead of real
    use; confirm against callers.

    Args:
        num_bins: Number of entries in the array passed to HistogramDist.
    """
    print("Compiling mixture loss functions")
    flat_values = np.array([-float(num_bins)] * num_bins)
    histogram_arrays = HistogramDist(flat_values).to_arrays()
    LogisticMixture.from_conditions(
        [HistogramCondition(*histogram_arrays)],
        num_components=3,
    )