def test_zinb_1_gate(total_count, probs):
    """Check that a ZINB whose gate is 1 reduces to a point mass at zero.

    ``total_count`` and ``probs`` are plain Python scalars, presumably
    supplied by pytest parametrization (decorators not visible here).

    NOTE(review): a later function in this module is also named
    ``test_zinb_1_gate``. The later ``def`` rebinds the module attribute,
    so pytest only collects that one and this body never runs. Consider
    renaming or deleting this definition.
    """
    # if gate is 1 ZINB is Delta(0)
    # Pass ``gate`` by keyword, as the later gate/gate_logits tests do.
    # Passed positionally it may bind to ``total_count`` and collide with
    # the explicit ``total_count=`` keyword — TODO confirm against the
    # installed pyro signature.
    zinb_ = ZeroInflatedNegativeBinomial(
        total_count=torch.tensor(total_count),
        gate=torch.ones(1),
        probs=torch.tensor(probs),
    )
    delta = Delta(torch.zeros(1))
    # One point on the support (0) and one off it (1): Delta(0) gives
    # finite log-mass at 0 and -inf at 1, and the ZINB must agree.
    s = torch.tensor([0.0, 1.0])
    zinb_prob = zinb_.log_prob(s)
    delta_prob = delta.log_prob(s)
    assert_close(zinb_prob, delta_prob)
def test_zinb_0_gate(total_count, probs):
    """Check that a ZINB whose gate is 0 reduces to a plain NegativeBinomial.

    ``total_count`` and ``probs`` are plain Python scalars, presumably
    supplied by pytest parametrization (decorators not visible here).

    NOTE(review): a later function in this module is also named
    ``test_zinb_0_gate``. The later ``def`` rebinds the module attribute,
    so pytest only collects that one and this body never runs. Consider
    renaming or deleting this definition.
    """
    # if gate is 0 ZINB is NegativeBinomial
    # Pass ``gate`` by keyword, as the later gate/gate_logits tests do.
    # Passed positionally it may bind to ``total_count`` and collide with
    # the explicit ``total_count=`` keyword — TODO confirm against the
    # installed pyro signature.
    zinb_ = ZeroInflatedNegativeBinomial(
        total_count=torch.tensor(total_count),
        gate=torch.zeros(1),
        probs=torch.tensor(probs),
    )
    neg_bin = NegativeBinomial(torch.tensor(total_count), probs=torch.tensor(probs))
    # Evaluate both densities on 20 draws from the reference distribution.
    s = neg_bin.sample((20,))
    zinb_prob = zinb_.log_prob(s)
    neg_bin_prob = neg_bin.log_prob(s)
    assert_close(zinb_prob, neg_bin_prob)
def test_zinb_mean_variance(gate, total_count, logits):
    """Compare the ZINB's analytic mean/stddev with Monte Carlo estimates.

    Draws ``num_samples`` values from the distribution and checks that the
    sample mean and sample standard deviation lie within an absolute
    tolerance of 0.1 of ``zinb_.mean`` and ``zinb_.stddev``.

    ``gate``, ``total_count`` and ``logits`` are plain Python scalars,
    presumably supplied by pytest parametrization (decorators not visible).
    """
    num_samples = 1000000
    # Pass ``gate`` by keyword, as the other gate/gate_logits tests in this
    # module do. Passed positionally it may bind to ``total_count`` and
    # collide with ``total_count=`` — TODO confirm against the installed
    # pyro signature.
    zinb_ = ZeroInflatedNegativeBinomial(
        total_count=torch.tensor(total_count),
        gate=torch.tensor(gate),
        logits=torch.tensor(logits),
    )
    s = zinb_.sample((num_samples,))
    expected_mean = zinb_.mean
    estimated_mean = s.mean()
    expected_std = zinb_.stddev
    estimated_std = s.std()
    # A loose absolute tolerance: these are stochastic estimates.
    assert_close(expected_mean, estimated_mean, atol=1e-1)
    assert_close(expected_std, estimated_std, atol=1e-1)
def test_zinb_1_gate(total_count, probs):
    """A ZINB whose gate is saturated at 1 must collapse to Delta(0).

    The saturated gate is expressed two ways: directly as ``gate=1`` and as
    ``gate_logits=+inf`` (sigmoid(+inf) == 1). Both variants must produce
    the same log-density as a point mass at zero, on one point inside (0)
    and one point outside (1) that point mass's support.
    """

    def make_zinb(**gate_spec):
        # Fresh parameter tensors for each distribution instance.
        return ZeroInflatedNegativeBinomial(
            total_count=torch.tensor(total_count),
            probs=torch.tensor(probs),
            **gate_spec,
        )

    variants = (
        make_zinb(gate=torch.ones(1)),
        make_zinb(gate_logits=torch.tensor(math.inf)),
    )
    support = torch.tensor([0.0, 1.0])
    expected = Delta(torch.zeros(1)).log_prob(support)
    for zinb in variants:
        assert_close(zinb.log_prob(support), expected)
def test_zinb_0_gate(total_count, probs):
    """A ZINB whose gate is (effectively) 0 must match a NegativeBinomial.

    The closed gate is expressed two ways: directly as ``gate=0`` and via a
    very negative ``gate_logits``. Both variants are evaluated on 20 draws
    from the reference NegativeBinomial and compared against its
    log-density.
    """

    def make_zinb(**gate_spec):
        # Fresh parameter tensors for each distribution instance.
        return ZeroInflatedNegativeBinomial(
            total_count=torch.tensor(total_count),
            probs=torch.tensor(probs),
            **gate_spec,
        )

    variants = (
        make_zinb(gate=torch.zeros(1)),
        # A large finite negative logit stands in for "gate == 0";
        # presumably -inf is avoided because it can produce NaN in the
        # logits-based log_prob path — TODO confirm.
        make_zinb(gate_logits=torch.tensor(-99.9)),
    )
    reference = NegativeBinomial(torch.tensor(total_count), probs=torch.tensor(probs))
    draws = reference.sample((20,))
    expected = reference.log_prob(draws)
    for zinb in variants:
        assert_close(zinb.log_prob(draws), expected)