コード例 #1
0
ファイル: test_mixture.py プロジェクト: ibab/carl
def test_mixture_api():
    """Exercise the public API of ``Mixture``.

    Covers:
    - component/weight bookkeeping when fewer weights than components are given;
    - classification of Theano variables into parameters, constants and observeds;
    - sharing of ``X`` and ``ndim`` with the components;
    - default weights;
    - symbolic weights that depend on an observed variable;
    - rejection of a weights list that is too short.
    """
    # Two normal components. Only p1's sigma is wrapped in T.constant; the
    # assertions below expect every plain float to become a parameter.
    first = Normal(mu=0.0, sigma=T.constant(1.0))
    second = Normal(mu=1.0, sigma=2.0)

    # --- One explicit weight for two components ----------------------------
    mixture = Mixture(components=[first, second], weights=[0.25])

    assert len(mixture.components) == 2
    # Only one weight was supplied, yet two are stored. Mixture fills in the
    # remaining one (NOTE(review): presumably 1 - 0.25; confirm in Mixture).
    assert len(mixture.weights) == 2

    # Expected counts of collected variables: 4 parameters, 1 constant
    # (p1.sigma), 0 observeds.
    # NOTE(review): the 4th parameter is presumably the mixture weight.
    assert len(mixture.parameters_) == 4
    assert len(mixture.constants_) == 1
    assert len(mixture.observeds_) == 0

    for free_var in (first.mu, second.mu, second.sigma):
        assert free_var in mixture.parameters_
    assert first.sigma in mixture.constants_

    # The mixture must expose the same symbolic input and dimensionality
    # as each of its components.
    for component in (first, second):
        assert mixture.X == component.X
    for component in (first, second):
        assert mixture.ndim == component.ndim

    # --- No weights given: expect a uniform split --------------------------
    uniform = Mixture(components=[first, second])
    assert_array_equal(uniform.compute_weights(), [0.5, 0.5])

    # --- Symbolic weights: y feeds a weight and must be tracked as observed -
    y = T.dscalar(name="y")
    symbolic = Mixture(components=[first, second],
                       weights=[T.constant(0.25), y * 2])
    assert y in symbolic.observeds_

    # --- Errors: a single weight is not enough for three components --------
    assert_raises(ValueError, Mixture,
                  components=[first, first, first], weights=[1.0])
コード例 #2
0
ファイル: test_mixture.py プロジェクト: glouppe/carl
def test_mixture_api():
    """Check the basic ``Mixture`` API and its argument validation.

    Builds several mixtures of two normal distributions and verifies:
    - how weights are completed or defaulted;
    - how variables are sorted into ``parameters_`` / ``constants_`` /
      ``observeds_``;
    - that the mixture shares ``X`` and ``ndim`` with its components;
    - that an inconsistent weights list raises ``ValueError``.
    """
    comp_a = Normal(mu=0.0, sigma=T.constant(1.0))  # sigma held constant
    comp_b = Normal(mu=1.0, sigma=2.0)

    # Two components, one weight: the mixture is expected to store two
    # weights (NOTE(review): the missing one is presumably derived so the
    # weights sum to one; confirm in Mixture).
    mix = Mixture(components=[comp_a, comp_b], weights=[0.25])
    assert len(mix.components) == 2
    assert len(mix.weights) == 2

    # Bookkeeping of the collected Theano variables.
    assert len(mix.parameters_) == 4
    assert len(mix.constants_) == 1
    assert len(mix.observeds_) == 0
    assert all(v in mix.parameters_ for v in (comp_a.mu, comp_b.mu, comp_b.sigma))
    assert comp_a.sigma in mix.constants_

    # Shared symbolic input and dimensionality.
    assert mix.X == comp_a.X
    assert mix.X == comp_b.X
    assert mix.ndim == comp_a.ndim
    assert mix.ndim == comp_b.ndim

    # Without explicit weights the two components are weighted equally.
    default_mix = Mixture(components=[comp_a, comp_b])
    assert_array_equal(default_mix.compute_weights(), [0.5, 0.5])

    # A weight expression built from the scalar `obs` must surface it as an
    # observed variable of the mixture.
    obs = T.dscalar(name="y")
    fixed_weight = T.constant(0.25)
    obs_weight = obs * 2
    observed_mix = Mixture(components=[comp_a, comp_b],
                           weights=[fixed_weight, obs_weight])
    assert obs in observed_mix.observeds_

    # Three components paired with a single weight must be rejected.
    assert_raises(ValueError, Mixture,
                  components=[comp_a, comp_a, comp_a], weights=[1.0])