def test_nested_sampling_loop(sampler, config):
    """Test the main nested sampling loop.

    The loop is driven by a ``while`` statement, so the test configures
    the (mocked) sampler's state up front and then only checks which
    collaborators were called.

    Parameters
    ----------
    sampler
        Sampler test double passed as ``self`` to
        ``NestedSampler.nested_sampling_loop``.
    config : dict
        Scenario description. Keys read by this test:

        * ``condition`` (default 0.5) and ``tolerance`` (default 0.1):
          values assigned to the sampler before the loop is run.
        * ``max_iteration`` (default ``None``) and ``iteration``
          (default 0): iteration bookkeeping assigned to the sampler.
        * ``call_while`` (default ``True``): whether the body of the
          while loop is expected to have executed.
        * ``call_finalise`` (default falsy): whether ``finalise`` is
          expected to have been called.
    """
    # Sampler state read by the loop.
    sampler.prior_sampling = False
    sampler.initialised = False
    sampler.condition = config.get('condition', 0.5)
    sampler.tolerance = config.get('tolerance', 0.1)
    sampler.max_iteration = config.get('max_iteration')
    sampler.iteration = config.get('iteration', 0)
    sampler.sampling_time = 0.
    sampler.training_time = 0.
    sampler.proposal_population_time = 0.
    sampler.likelihood_calls = 1
    sampler.nested_samples = [1, 2]

    # Replace collaborators with mocks so their calls can be inspected.
    sampler.proposal = MagicMock()
    sampler.proposal.close_pool = MagicMock()
    sampler.proposal.pool = True
    sampler.proposal.logl_eval_time.total_seconds = MagicMock()
    sampler.finalised = False
    sampler.finalise = MagicMock()
    sampler.check_resume = MagicMock()
    sampler.check_insertion_indices = MagicMock()
    sampler.checkpoint = MagicMock()

    logZ, samples = NestedSampler.nested_sampling_loop(sampler)

    assert logZ == sampler.state.logZ
    assert samples.tolist() == [1, 2]

    sampler.initialise.assert_called_once_with(live_points=True)
    sampler.check_resume.assert_called_once()

    if config.get('call_while', True):
        if config.get('iteration', 0):
            # This used to be a bare comparison, which is evaluated and
            # discarded, so the check never ran. It must be asserted.
            # NOTE(review): presumably one call happens before the loop
            # when starting from a non-zero iteration and one inside the
            # loop body - confirm against nested_sampling_loop.
            assert sampler.update_state.call_count == 2
        else:
            sampler.update_state.assert_called_once()
        sampler.consume_sample.assert_called_once()
        sampler.check_state.assert_called_once()
    else:
        sampler.check_state.assert_not_called()
        sampler.consume_sample.assert_not_called()
        sampler.update_state.assert_not_called()

    sampler.proposal.close_pool.assert_called_once()

    if config.get('call_finalise'):
        sampler.finalise.assert_called_once()
    else:
        sampler.finalise.assert_not_called()

    sampler.check_insertion_indices.assert_called_once_with(rolling=False)
    sampler.checkpoint.assert_called_once_with(periodic=True)

def test_nested_sampling_loop_prior_sampling(sampler):
    """Check the nested sampling loop when ``prior_sampling`` is enabled.

    With an uninitialised sampler flagged for prior sampling, the loop
    should call ``initialise(live_points=True)`` once and return an array
    equal to the sampler's live points.
    """
    # Arrange: uninitialised sampler holding ten freshly drawn points.
    sampler.prior_sampling = True
    sampler.initialised = False
    sampler.live_points = sampler.model.new_point(10)

    # Act
    returned = NestedSampler.nested_sampling_loop(sampler)

    # Assert
    sampler.initialise.assert_called_once_with(live_points=True)
    expected = sampler.live_points
    np.testing.assert_array_equal(returned, expected)