Пример #1
0
def test_normal_comparison():
    """Check that dense and structured-covariance normals agree.

    Two scenarios are tested: a dense diagonal covariance versus the same
    covariance wrapped in ``Diagonal``, and a dense scaled identity versus
    the same covariance as a ``UniformlyDiagonal``. In each scenario both
    distributions share the same mean and covariance, so their
    log-densities, entropies, pairwise KL divergences and pairwise W2
    distances must agree. Every check is labelled with its scenario so a
    failure identifies which covariance representation disagrees.
    """

    def check_equivalent(dist1, dist2, label):
        # Compare two normals that should represent the same distribution.
        # Samples come from ``dist1`` so both log-pdfs are evaluated at
        # identical points.
        samples = dist1.sample(100)
        allclose(dist1.logpdf(samples), dist2.logpdf(samples),
                 desc=label + ': logpdf')
        allclose(dist1.entropy(), dist2.entropy(), desc=label + ': entropy')

        # The KL divergence between identical distributions is zero.
        for p, q, name in [(dist1, dist2, 'kl 1'),
                           (dist1, dist1, 'kl 2'),
                           (dist2, dist2, 'kl 3'),
                           (dist2, dist1, 'kl 4')]:
            allclose(p.kl(q), 0., desc=label + ': ' + name)

        # The W2 distance between identical distributions is zero; allow a
        # small tolerance for numerical error.
        for p, q, name in [(dist1, dist1, 'w2 1'),
                           (dist1, dist2, 'w2 2'),
                           (dist2, dist1, 'w2 3'),
                           (dist2, dist2, 'w2 4')]:
            distance = p.w2(q)
            assert distance <= 1e-3, \
                '{}: {} = {}'.format(label, name, distance)

    # Compare a diagonal normal and dense normal.
    mean = np.random.randn(3, 1)
    var_diag = np.random.randn(3)**2
    check_equivalent(Normal(np.diag(var_diag), mean),
                     Normal(Diagonal(var_diag), mean),
                     'diagonal')

    # Compare a uniformly diagonal normal and dense normal.
    mean = np.random.randn(3, 1)
    var_diag_scale = np.random.randn()**2
    check_equivalent(Normal(np.eye(3) * var_diag_scale, mean),
                     Normal(UniformlyDiagonal(var_diag_scale, 3), mean),
                     'uniformly diagonal')
Пример #2
0
def test_normal_comparison():
    """Yield nose-style checks that dense and structured normals agree.

    Two scenarios are tested: a dense diagonal covariance versus the same
    covariance wrapped in ``Diagonal``, and a dense scaled identity versus
    the same covariance as a ``UniformlyDiagonal``. Within a scenario both
    distributions share mean and covariance, so log-densities, entropies,
    pairwise KL divergences and pairwise W2 distances must agree.

    Each yielded tuple is a nose test-generator entry: ``(ok, flag, label)``
    or ``(le, value, bound, label)``. Presumably ``ok`` asserts ``flag`` is
    true and ``le`` asserts ``value <= bound`` -- confirm against their
    definitions. Labels are prefixed with the scenario name so a failure
    report identifies which covariance representation disagrees.
    """

    def checks(dist1, dist2, label):
        # Generate the checks for two normals that should be identical.
        # Samples come from ``dist1`` so both log-pdfs are evaluated at
        # identical points.
        samples = dist1.sample(100)
        yield (ok, allclose(dist1.logpdf(samples), dist2.logpdf(samples)),
               label + ': logpdf')
        yield (ok, allclose(dist1.entropy(), dist2.entropy()),
               label + ': entropy')

        # The KL divergence between identical distributions is zero.
        for p, q, name in [(dist1, dist2, 'kl 1'),
                           (dist1, dist1, 'kl 2'),
                           (dist2, dist2, 'kl 3'),
                           (dist2, dist1, 'kl 4')]:
            yield ok, allclose(p.kl(q), 0.), label + ': ' + name

        # The W2 distance between identical distributions is zero; allow a
        # small tolerance for numerical error.
        for p, q, name in [(dist1, dist1, 'w2 1'),
                           (dist1, dist2, 'w2 2'),
                           (dist2, dist1, 'w2 3'),
                           (dist2, dist2, 'w2 4')]:
            yield le, p.w2(q), 5e-4, label + ': ' + name

    # Compare a diagonal normal and dense normal.
    mean = np.random.randn(3, 1)
    var_diag = np.random.randn(3)**2
    dense = Normal(np.diag(var_diag), mean)
    structured = Normal(Diagonal(var_diag), mean)
    # Re-yield each check explicitly (no ``yield from``) so the generator
    # stays compatible with older interpreters used by nose.
    for check in checks(dense, structured, 'diagonal'):
        yield check

    # Compare a uniformly diagonal normal and dense normal.
    mean = np.random.randn(3, 1)
    var_diag_scale = np.random.randn()**2
    dense = Normal(np.eye(3) * var_diag_scale, mean)
    structured = Normal(UniformlyDiagonal(var_diag_scale, 3), mean)
    for check in checks(dense, structured, 'uniformly diagonal'):
        yield check