Example #1
0
def test_zip_0_gate(rate):
    """A zero-gate ZeroInflatedPoisson must score samples like a plain Poisson.

    Builds a ZeroInflatedPoisson whose gate is a zero tensor and a Poisson
    with the same rate. It draws 20 samples from the Poisson and requires the
    two log-probabilities to be equal (via ``assert_tensors_equal``).

    Args:
        rate: Rate value passed to ``torch.tensor`` for both distributions.
    """
    # Gate first, rate second: this is the positional order used by this
    # version of the ZeroInflatedPoisson constructor.
    zero_inflated = ZeroInflatedPoisson(torch.zeros(1), torch.tensor(rate))
    reference = Poisson(torch.tensor(rate))

    # Sample only from the reference Poisson; both models score the same draws.
    draws = reference.sample((20, ))
    expected, actual = reference.log_prob(draws), zero_inflated.log_prob(draws)
    assert_tensors_equal(actual, expected)
Example #2
0
def test_zip_0_gate(rate):
    """A ZeroInflatedPoisson with a (near-)zero gate must match a Poisson.

    Two zero-inflated models are checked against a plain Poisson of the same
    rate:

    * one built with an explicit ``gate`` of zeros;
    * one built with ``gate_logits=-99.9``, a large negative logit standing in
      for a gate of (approximately) zero.

    Both must give log-probabilities close to the Poisson's on 20 draws taken
    from the Poisson.

    Args:
        rate: Rate value passed to ``torch.tensor`` for every distribution.
    """
    candidates = (
        ZeroInflatedPoisson(torch.tensor(rate), gate=torch.zeros(1)),
        ZeroInflatedPoisson(torch.tensor(rate),
                            gate_logits=torch.tensor(-99.9)),
    )
    reference = Poisson(torch.tensor(rate))

    draws = reference.sample((20, ))
    # Evaluate every log-prob before asserting, in the same order as the
    # original implementation.
    scores = [model.log_prob(draws) for model in candidates]
    expected = reference.log_prob(draws)
    for score in scores:
        assert_close(score, expected)