def test_log_prob_scale(num_leaves, num_steps, batch_shape, sample_shape):
    """Check that dividing all times by ``rate`` is equivalent to passing ``rate``.

    Coalescent times are sampled from a unit-rate ``CoalescentTimes``. Both the
    leaf times and the sampled coalescent times are then divided by a random
    positive per-batch ``rate``. The log density of the rescaled times under
    ``CoalescentTimes(leaf_times / rate, rate)`` must equal the unit-rate log
    density corrected by the log-Jacobian of that rescaling.
    """
    # exp of a standard normal draw gives a strictly positive rate per batch element.
    rate = torch.randn(batch_shape).exp()
    # Trailing singleton axis so the per-batch rate broadcasts over the time axis.
    per_time_rate = rate.unsqueeze(-1)

    unit_leaf_times = torch.rand(batch_shape + (num_leaves,)).pow(0.5) * num_steps
    unit_model = CoalescentTimes(unit_leaf_times)
    unit_coal_times = unit_model.sample(sample_shape)
    unit_log_prob = unit_model.log_prob(unit_coal_times)

    scaled_leaf_times = unit_leaf_times / per_time_rate
    scaled_coal_times = unit_coal_times / per_time_rate
    scaled_model = CoalescentTimes(scaled_leaf_times, rate)
    scaled_log_prob = scaled_model.log_prob(scaled_coal_times)

    # The map x -> x / rate, applied to each of the trailing-axis coordinates,
    # has log|det J| = -(number of coordinates) * log(rate).
    num_coal_coords = scaled_coal_times.size(-1)
    log_abs_det_jacobian = -num_coal_coords * rate.log()
    assert_close(unit_log_prob - log_abs_det_jacobian, scaled_log_prob)
def test_simple_smoke(num_leaves, num_steps, batch_shape, sample_shape):
    """Smoke-test the shapes returned by ``CoalescentTimes.sample`` and ``log_prob``.

    Samples must have one fewer trailing entry than there are leaves. Their log
    densities must drop that trailing axis, leaving ``sample_shape + batch_shape``.
    """
    leaf_shape = batch_shape + (num_leaves,)
    leaf_times = torch.rand(leaf_shape).pow(0.5) * num_steps
    coalescent = CoalescentTimes(leaf_times)

    samples = coalescent.sample(sample_shape)
    expected_sample_shape = sample_shape + batch_shape + (num_leaves - 1,)
    assert samples.shape == expected_sample_shape

    log_density = coalescent.log_prob(samples)
    expected_log_prob_shape = sample_shape + batch_shape
    assert log_density.shape == expected_log_prob_shape
def test_log_prob_unit_rate(num_leaves, num_steps, batch_shape, sample_shape):
    """An all-ones rate grid must reproduce the default-rate ``CoalescentTimes``.

    Times are sampled from ``CoalescentTimes(leaf_times)`` and scored under both
    that distribution and ``CoalescentTimesWithRate`` with a rate grid of ones
    (one entry per step). The two log densities must agree.
    """
    leaf_times = torch.rand(batch_shape + (num_leaves,)).pow(0.5) * num_steps
    reference = CoalescentTimes(leaf_times)
    ones_grid = torch.ones(batch_shape + (num_steps,))
    gridded = CoalescentTimesWithRate(leaf_times, ones_grid)

    samples = reference.sample(sample_shape)
    expected = reference.log_prob(samples)
    actual = gridded.log_prob(samples)
    assert_close(expected, actual)
def test_log_prob_constant_rate_2(num_leaves, num_steps, batch_shape, sample_shape):
    """A scalar per-batch ``rate`` and a constant rate grid must give equal log densities.

    Times are sampled from ``CoalescentTimes(leaf_times, rate)``. They are then
    scored again under ``CoalescentTimesWithRate``, whose grid repeats each batch
    element's ``rate`` across all ``num_steps`` entries.
    """
    # exp of a standard normal draw gives a strictly positive rate per batch element.
    rate = torch.randn(batch_shape).exp()
    leaf_times = torch.rand(batch_shape + (num_leaves,)).pow(0.5) * num_steps

    scalar_rate_model = CoalescentTimes(leaf_times, rate)
    samples = scalar_rate_model.sample(sample_shape)
    expected = scalar_rate_model.log_prob(samples)

    # Broadcast (view, no copy) each batch element's rate along a new step axis.
    constant_grid = rate.unsqueeze(-1).expand(batch_shape + (num_steps,))
    grid_rate_model = CoalescentTimesWithRate(leaf_times, constant_grid)
    actual = grid_rate_model.log_prob(samples)
    assert_close(expected, actual)