def test_make_samples():
    """Check make_samples(0.5, ...) on a 2x2 matrix against an exact expected set.

    Each sample is a tuple with one entry per matrix row (a column index or
    None). Every produced sample must appear in ``expected`` with a matching
    weight (compared via ``almost_equal``). Every expected sample must be
    produced at least once, so a make_samples that yields too few samples
    (or nothing at all) fails.

    NOTE(review): the expected weights equal the mean of the selected entries
    with None counted as 0 -- inferred from the numbers; confirm against
    make_samples.
    """
    samples = make_samples(0.5, Matrix([
        [0.1, 0.7],
        [0.9, 0.3],
    ]))
    expected = {
        (None, 0): 0.45,  # 0.1 0.9
        (1, 0): 0.8,      # 0.7 0.9
        (1, None): 0.35,  # 0.7 0.3
    }
    seen = set()
    for sample, weight in samples:
        assert sample in expected, 'unexpected sample %s' % (sample,)
        assert almost_equal(weight, expected[sample])
        seen.add(sample)
    # Without this check an empty or incomplete result would pass silently.
    missing = set(expected) - seen
    assert not missing, 'expected samples not produced: %s' % (missing,)
def test_make_samples_2x3():
    """Check make_samples(0.5, ...) on a 3x2 matrix against an exact expected multiset.

    ``expected_samples`` is a multiset: some (sample, weight) pairs are listed
    more than once and each occurrence must be produced. Each produced pair
    consumes one matching expectation (weight within ``tolerance``). An
    unmatched produced pair fails immediately. Any expectation left
    unconsumed at the end also fails.
    """
    samples = make_samples(0.5, Matrix([
        [0.0, 0.9],
        [0.3, 0.3],
        [0.6, 0.9],
    ]))
    # Per-row choices appearing in the expected samples below:
    #   row 0 -> None or 1; row 1 -> None; row 2 -> 0 or 1 (None in the last entry).
    expected_samples = [
        ((None, None, 0), 0.2),   # 0.0 0.3 0.6
        ((None, None, 1), 0.3),   # 0.0 0.3 0.9
        ((None, None, 0), 0.2),   # 0.0 0.3 0.6
        ((None, None, 1), 0.3),   # 0.0 0.3 0.9
        ((1, None, 0), 0.5),      # 0.9 0.3 0.6
        ((1, None, 1), 0.6),      # 0.9 0.3 0.9
        ((1, None, 0), 0.5),      # 0.9 0.3 0.6
        ((1, None, 1), 0.6),      # 0.9 0.3 0.9
        ((1, None, None), 0.3),   # 0.9 None None
    ]
    tolerance = 0.0001
    for sample, weight in samples:
        for index, (expected_sample, expected_weight) in enumerate(expected_samples):
            if sample == expected_sample and abs(weight - expected_weight) < tolerance:
                # Consume exactly one matching expectation so duplicates count.
                del expected_samples[index]
                break
        else:
            # raise rather than `assert False`, which is stripped under -O.
            raise AssertionError(
                'unexpected sample %s with weight %s' % (sample, weight))
    # Without this check missing samples (or an empty result) would pass.
    assert not expected_samples, (
        'expected samples not produced: %s' % (expected_samples,))