def test_zip_1_gate(rate):
    """Check that a ZeroInflatedPoisson whose gate is 1 behaves like Delta(0).

    With the gate fully open the distribution collapses onto zero, so the
    log-probabilities at the points 0 and 1 must equal those of a point mass
    at zero.

    Args:
        rate: Poisson rate, converted to a tensor before construction.
    """
    # NOTE(review): a second ``test_zip_1_gate`` is defined later in this
    # source; at import time that later definition rebinds the name, so
    # pytest will only ever collect the later version.
    # NOTE(review): the gate is passed positionally ahead of the rate here,
    # unlike the keyword ``gate=`` form used by the later definition —
    # confirm against the installed ZeroInflatedPoisson signature.
    zip_dist = ZeroInflatedPoisson(torch.ones(1), torch.tensor(rate))
    zero_point_mass = Delta(torch.zeros(1))
    values = torch.tensor([0., 1.])
    assert_tensors_equal(
        zip_dist.log_prob(values),
        zero_point_mass.log_prob(values),
    )
def test_zinb_1_gate(total_count, probs):
    """Check that a ZeroInflatedNegativeBinomial with gate 1 matches Delta(0).

    A fully open gate puts all probability mass at zero, so the
    log-probabilities at 0 and 1 must agree with a point mass at zero.

    Args:
        total_count: negative-binomial ``total_count``, converted to a tensor.
        probs: negative-binomial ``probs``, converted to a tensor.
    """
    # NOTE(review): a second ``test_zinb_1_gate`` is defined later in this
    # source and rebinds the name at import time, so this version is never
    # collected by pytest.
    # NOTE(review): the gate is passed as the first positional argument here,
    # unlike the keyword ``gate=`` form of the later definition — confirm
    # against the installed constructor signature.
    zinb_dist = ZeroInflatedNegativeBinomial(
        torch.ones(1),
        total_count=torch.tensor(total_count),
        probs=torch.tensor(probs),
    )
    point_mass_at_zero = Delta(torch.zeros(1))
    support = torch.tensor([0.0, 1.0])
    actual = zinb_dist.log_prob(support)
    expected = point_mass_at_zero.log_prob(support)
    assert_close(actual, expected)
def test_zip_1_gate(rate):
    """Check that a ZeroInflatedPoisson with a saturated gate equals Delta(0).

    The gate is saturated in two ways: directly with ``gate=1`` and through
    ``gate_logits=+inf``. Both distributions must produce the same
    log-probabilities at 0 and 1 as a point mass at zero.

    Args:
        rate: Poisson rate, converted to a tensor for each distribution.
    """
    # One keyword set per way of expressing "gate is 1".
    gate_variants = (
        {"gate": torch.ones(1)},
        {"gate_logits": torch.tensor(math.inf)},
    )
    zip_dists = [
        ZeroInflatedPoisson(torch.tensor(rate), **variant)
        for variant in gate_variants
    ]
    point_mass_at_zero = Delta(torch.zeros(1))
    support = torch.tensor([0.0, 1.0])

    zip_log_probs = [dist.log_prob(support) for dist in zip_dists]
    expected = point_mass_at_zero.log_prob(support)
    for actual in zip_log_probs:
        assert_close(actual, expected)
def test_zinb_1_gate(total_count, probs):
    """Check that a ZeroInflatedNegativeBinomial with saturated gate is Delta(0).

    The gate is saturated both directly (``gate=1``) and through
    ``gate_logits=+inf``. Each resulting distribution must produce the same
    log-probabilities at 0 and 1 as a point mass at zero.

    Args:
        total_count: negative-binomial ``total_count``, converted to a tensor.
        probs: negative-binomial ``probs``, converted to a tensor.
    """
    # One keyword set per way of expressing "gate is 1".
    gate_variants = (
        {"gate": torch.ones(1)},
        {"gate_logits": torch.tensor(math.inf)},
    )
    zinb_dists = [
        ZeroInflatedNegativeBinomial(
            total_count=torch.tensor(total_count),
            probs=torch.tensor(probs),
            **variant,
        )
        for variant in gate_variants
    ]
    point_mass_at_zero = Delta(torch.zeros(1))
    support = torch.tensor([0.0, 1.0])

    zinb_log_probs = [dist.log_prob(support) for dist in zinb_dists]
    expected = point_mass_at_zero.log_prob(support)
    for actual in zinb_log_probs:
        assert_close(actual, expected)