def test_intialise_resume(sampler):
    """Check ``initialise`` when resuming, i.e. with existing live points.

    Both proposals are marked as not initialised and ``live_points`` is
    already set, so each proposal is expected to be initialised once while
    the live points must not be populated again.
    """
    sampler._flow_proposal = MagicMock(initialised=False)
    sampler._uninformed_proposal = MagicMock(initialised=False)
    sampler.populate_live_points = MagicMock()

    resumed_state = {
        'live_points': [0.0],
        'iteration': 100,
        'maximum_uninformed': 10,
        'condition': 1.0,
        'tolerance': 0.1,
        'initialised': False,
    }
    for attribute, value in resumed_state.items():
        setattr(sampler, attribute, value)

    NestedSampler.initialise(sampler)

    # Proposals are initialised, the live points are left untouched.
    sampler._flow_proposal.initialise.assert_called_once()
    sampler._uninformed_proposal.initialise.assert_called_once()
    sampler.populate_live_points.assert_not_called()

    assert sampler.initialised is False
    assert sampler.proposal is sampler._flow_proposal
    sampler.proposal.configure_pool.assert_called_once()
def test_train_proposal(sampler):
    """Check that a forced call to ``train_proposal`` trains the proposal.

    The test also expects the flow-reset check and a periodic checkpoint to
    be triggered, the training iteration to be recorded and the block
    statistics to be zeroed afterwards.
    """
    proposal = MagicMock()
    proposal.train = MagicMock()
    sampler.proposal = proposal
    sampler.check_flow_model_reset = MagicMock()
    sampler.checkpoint = MagicMock()

    initial_state = {
        'iteration': 100,
        'last_updated': 90,
        'cooldown': 20,
        'memory': False,
        'training_time': datetime.timedelta(),
        'training_iterations': [],
        'live_points': np.arange(10),
        'checkpoint_on_training': True,
        'block_iteration': 10,
        'block_acceptance': 0.5,
    }
    for attribute, value in initial_state.items():
        setattr(sampler, attribute, value)

    NestedSampler.train_proposal(sampler, force=True)

    sampler.check_flow_model_reset.assert_called_once()
    sampler.proposal.train.assert_called_once()
    sampler.checkpoint.assert_called_once_with(periodic=True)

    assert sampler.training_iterations == [100]
    assert sampler.training_time.total_seconds() > 0
    assert sampler.completed_training is True
    assert sampler.block_iteration == 0
    assert sampler.block_acceptance == 0
def test_consume_sample_reject(sampler, live_points):
    """Check ``consume_sample`` when one candidate is rejected first.

    ``yield_sample`` is mocked to produce two candidates. The test expects
    one rejection and one acceptance, with only the second candidate being
    inserted into the live points.
    """
    sampler.live_points = live_points

    rejected_candidate = parameters_to_live_point((-0.5,), ['x'])
    rejected_candidate['logL'] = -0.5
    accepted_candidate = parameters_to_live_point((0.5,), ['x'])
    accepted_candidate['logL'] = 0.5

    candidates = iter([(1, rejected_candidate), (1, accepted_candidate)])
    sampler.yield_sample = MagicMock(return_value=candidates)
    sampler.insert_live_point = MagicMock(return_value=0)
    sampler.check_state = MagicMock()

    NestedSampler.consume_sample(sampler)

    sampler.insert_live_point.assert_called_once_with(accepted_candidate)
    sampler.check_state.assert_called_once()

    assert sampler.nested_samples == [live_points[0]]
    assert sampler.logLmin == 0.0
    # One rejected and one accepted candidate -> 50% acceptance.
    assert sampler.rejected == 1
    assert sampler.accepted == 1
    assert sampler.block_acceptance == 0.5
    assert sampler.acceptance_history == [0.5]
    assert sampler.mean_block_acceptance == 0.5
    assert sampler.insertion_indices == [0]
def test_checkpoint(sampler, periodic):
    """Check the ``checkpoint`` method.

    The sampler should be dumped to the resume file and the sampling time
    and start time should be updated. The current iteration is only
    expected to be recorded when ``periodic`` is false.
    """
    sampler.checkpoint_iterations = [10]
    sampler.iteration = 20
    start = datetime.datetime.now()
    sampler.sampling_start_time = start
    sampler.sampling_time = datetime.timedelta()
    sampler.resume_file = 'test.pkl'

    with patch('nessai.nestedsampler.safe_file_dump') as dump_mock:
        NestedSampler.checkpoint(sampler, periodic=periodic)

    dump_mock.assert_called_once_with(
        sampler, sampler.resume_file, pickle, save_existing=True
    )

    assert sampler.sampling_start_time > start
    assert sampler.sampling_time.total_seconds() > 0.

    expected_iterations = [10] if periodic else [10, 20]
    assert sampler.checkpoint_iterations == expected_iterations
def test_update_state_force(mock_plot, checkpointing, sampler):
    """Check ``update_state`` when called with ``force=True``.

    ``plot_indices`` (``mock_plot``) must not be called even though
    plotting is enabled, whereas the state and trace plots are expected.
    Checkpointing should only happen when it is enabled.
    """
    sampler.iteration = 111
    sampler.proposal._checked_population = True
    sampler.check_insertion_indices = MagicMock()
    sampler.plot_state = MagicMock()
    sampler.plot_trace = MagicMock()

    settings = {
        'plot': True,
        'output': './',
        'uninformed_sampling': False,
        'checkpointing': checkpointing,
    }
    for attribute, value in settings.items():
        setattr(sampler, attribute, value)

    NestedSampler.update_state(sampler, force=True)

    if not checkpointing:
        sampler.checkpoint.assert_not_called()
    else:
        sampler.checkpoint.assert_called_once_with(periodic=True)

    assert not mock_plot.called
    # NOTE(review): the sampler object itself is never called in this test,
    # so this assertion looks vacuous. ``check_insertion_indices`` is mocked
    # above but never asserted on - confirm whether that was the intended
    # target of this check.
    assert not sampler.called

    sampler.plot_trace.assert_called_once()
    sampler.plot_state.assert_called_once_with(filename='.//state.png')

    assert sampler.max_likelihood == [100.0, 150.0]
    assert sampler.population_acceptance == [0.5]
    assert sampler.block_acceptance == 0.5
    assert sampler.block_iteration == 5
def test_train_proposal_memory(sampler):
    """Check that training with ``memory`` set uses extra nested samples.

    With ``memory = 2`` the training data passed to ``proposal.train`` is
    expected to be the live points followed by the last two nested samples.
    """
    proposal = MagicMock()
    proposal.train = MagicMock()
    sampler.proposal = proposal
    sampler.check_flow_model_reset = MagicMock()
    sampler.checkpoint = MagicMock()

    initial_state = {
        'iteration': 100,
        'last_updated': 90,
        'cooldown': 20,
        'memory': 2,
        'training_time': datetime.timedelta(),
        'training_iterations': [],
        'nested_samples': np.arange(5),
        'live_points': np.arange(5, 10),
        'checkpoint_on_training': True,
        'block_iteration': 10,
        'block_acceptance': 0.5,
    }
    for attribute, value in initial_state.items():
        setattr(sampler, attribute, value)

    NestedSampler.train_proposal(sampler, force=True)

    sampler.check_flow_model_reset.assert_called_once()
    sampler.checkpoint.assert_called_once_with(periodic=True)
    sampler.proposal.train.assert_called_once()

    # ``call_args[0]`` is the tuple of positional arguments, hence the
    # extra leading dimension in the expected array.
    expected_training_data = np.array([[5, 6, 7, 8, 9, 3, 4]])
    np.testing.assert_array_equal(
        sampler.proposal.train.call_args[0], expected_training_data
    )

    assert sampler.training_iterations == [100]
    assert sampler.training_time.total_seconds() > 0
    assert sampler.completed_training is True
    assert sampler.block_iteration == 0
    assert sampler.block_acceptance == 0
def test_log_likelihood(sampler):
    """Check that ``log_likelihood`` forwards to the model.

    The method is not used by the sampler itself and is only there for the
    user.
    """
    likelihood_mock = MagicMock()
    sampler.model.log_likelihood = likelihood_mock
    point = [0.1]

    NestedSampler.log_likelihood(sampler, point)

    sampler.model.log_likelihood.assert_called_once_with([0.1])
def test_uninformed_analytic_priors(sampler):
    """
    Check that an ``AnalyticProposal`` is used as the uninformed proposal
    when analytic priors are requested.
    """
    # Second positional argument (True) is the flag under test.
    arguments = (None, True, None, None)
    NestedSampler.configure_uninformed_proposal(sampler, *arguments)

    proposal = sampler._uninformed_proposal
    assert isinstance(proposal, AnalyticProposal)
def test_train_proposal_not_training(sampler):
    """Verify the proposal is not trained if it has not 'cooled down'.

    Only 10 iterations have passed since the last update, which is fewer
    than the cooldown of 20, and training is not forced.
    """
    proposal = MagicMock()
    proposal.train = MagicMock()
    sampler.proposal = proposal

    sampler.cooldown = 20
    sampler.last_updated = 90
    sampler.iteration = 100

    NestedSampler.train_proposal(sampler, force=False)

    sampler.proposal.train.assert_not_called()
def test_insertion_indices_p_none(mock_fn, sampler):
    """Check the insertion-index test when p is None.

    In that case nothing should be appended to ``rolling_p``.
    """
    n_indices = 2 * sampler.nlive
    sampler.insertion_indices = np.random.randint(
        sampler.nlive, size=n_indices
    )
    sampler.rolling_p = []

    NestedSampler.check_insertion_indices(sampler, rolling=True)

    assert len(sampler.rolling_p) == 0
def test_uninformed_threshold_default_(sampler, threshold):
    """
    Check that the uninformed acceptance threshold defaults to the
    acceptance threshold when the latter is 0.1 or above.
    """
    sampler.acceptance_threshold = threshold

    arguments = (None, False, None, None)
    NestedSampler.configure_uninformed_proposal(sampler, *arguments)

    assert sampler.uninformed_acceptance_threshold == threshold
def test_uninformed_threshold_default_below(sampler):
    """
    Check that the uninformed acceptance threshold defaults to 10 times
    the acceptance threshold when the latter is below 0.1.
    """
    sampler.acceptance_threshold = 0.05

    arguments = (None, False, None, None)
    NestedSampler.configure_uninformed_proposal(sampler, *arguments)

    # 10 * 0.05
    assert sampler.uninformed_acceptance_threshold == 0.5
def test_init(sampler, model):
    """Test the init method.

    The output and configuration helpers are replaced with mocks before
    ``NestedSampler.__init__`` is called. The test then checks that the
    sampler is flagged as not initialised and that ``nlive`` is stored.
    """
    sampler.setup_output = MagicMock()
    sampler.configure_flow_reset = MagicMock()
    sampler.configure_flow_proposal = MagicMock()
    # This must be a mock *instance*. Assigning the bare ``MagicMock`` class
    # would make every call construct a new mock, with the call arguments
    # interpreted as ``spec``, ``side_effect``, ... instead of being
    # recorded as a call.
    sampler.configure_uninformed_proposal = MagicMock()

    NestedSampler.__init__(sampler, model, nlive=100, poolsize=100)

    assert sampler.initialised is False
    assert sampler.nlive == 100
def test_populate_live_points_nans(sampler):
    """Check populating the live points when a sample has a NaN logL.

    One more point than ``nlive`` is supplied and one of them has a NaN
    log-likelihood; the resulting live points should contain ``nlive``
    points and no NaNs.
    """
    n_draws = sampler.nlive + 1
    candidates = sampler.model.new_point(n_draws)
    candidates['logL'][4] = np.nan

    draws = iter(zip(np.ones(n_draws), candidates))
    sampler.yield_sample = MagicMock(return_value=draws)

    NestedSampler.populate_live_points(sampler)

    assert len(sampler.live_points) == sampler.nlive
    assert not np.isnan(sampler.live_points['logL']).any()
def test_check_flow_model_reset_not_trained(sampler):
    """
    Check that the flow model is left alone when the proposal has a
    training count of zero, i.e. it has never been trained.
    """
    proposal = MagicMock(training_count=0)
    proposal.reset_model_weights = MagicMock()
    sampler.proposal = proposal

    NestedSampler.check_flow_model_reset(sampler)

    sampler.proposal.reset_model_weights.assert_not_called()
def test_uninformed_proposal_class(sampler):
    """Check that a user-supplied proposal class is instantiated and used."""
    from nessai.proposal.base import Proposal

    class CustomProposal(Proposal):
        """Minimal ``Proposal`` subclass used only for this test."""

        def draw(self, point):
            pass

    arguments = (CustomProposal, False, None, None)
    NestedSampler.configure_uninformed_proposal(sampler, *arguments)

    assert isinstance(sampler._uninformed_proposal, CustomProposal)
def test_flow_class_not_subclass(sampler):
    """
    Check that a ``RuntimeError`` is raised if the supplied flow class does
    not inherit from FlowProposal.
    """

    class FakeProposal:
        """Plain class with no base class."""

    with pytest.raises(RuntimeError) as excinfo:
        NestedSampler.configure_flow_proposal(
            sampler, FakeProposal, {}, False
        )

    # Checked after the ``with`` block so the assertion is actually reached.
    message = str(excinfo.value)
    assert 'inherits' in message
def test_check_flow_model_reset_weights(sampler, training_count):
    """Check that only the flow weights are reset.

    Weight resets are enabled (``reset_weights = 10``) while permutation
    resets are disabled (``reset_permutations = 0``).
    """
    proposal = MagicMock(training_count=training_count)
    proposal.reset_model_weights = MagicMock()
    sampler.proposal = proposal

    sampler.reset_acceptance = False
    sampler.reset_weights = 10
    sampler.reset_permutations = 0

    NestedSampler.check_flow_model_reset(sampler)

    reset = sampler.proposal.reset_model_weights
    reset.assert_called_once_with(weights=True)
def test_uninformed_maximum(sampler, maximum, result):
    """
    Check how the uninformed proposal is configured for a given maximum
    number of uninformed iterations.

    Uninformed sampling is expected to be disabled only when ``maximum``
    is ``False``.
    """
    arguments = (None, False, maximum, None)
    NestedSampler.configure_uninformed_proposal(sampler, *arguments)

    assert sampler.maximum_uninformed == result

    expect_uninformed = maximum is not False
    assert sampler.uninformed_sampling is expect_uninformed
def test_check_resume_no_indices(sampler):
    """Check ``check_resume`` when the flow proposal has no indices.

    The ``resumed`` flag should be cleared and the proposal should remain
    unpopulated.
    """
    flow_proposal = MagicMock(
        populated=False,
        _resume_populated=True,
        indices=[],
    )
    sampler._flow_proposal = flow_proposal
    sampler.uninformed_sampling = True
    sampler.resumed = True

    NestedSampler.check_resume(sampler)

    assert sampler.resumed is False
    assert sampler._flow_proposal.populated is False
def test_check_flow_model_reset_both(sampler, training_count):
    """Check that both the weights and the permutations are reset.

    Both ``reset_weights`` and ``reset_permutations`` are set to 10, so two
    calls to ``reset_model_weights`` are expected.
    """
    proposal = MagicMock(training_count=training_count)
    proposal.reset_model_weights = MagicMock()
    sampler.proposal = proposal

    sampler.reset_acceptance = False
    sampler.reset_weights = 10
    sampler.reset_permutations = 10

    NestedSampler.check_flow_model_reset(sampler)

    expected_calls = [
        call(weights=True),
        call(weights=False, permutations=True),
    ]
    sampler.proposal.reset_model_weights.assert_has_calls(expected_calls)
def test_check_resume(sampler):
    """Check ``check_resume`` when the flow proposal has stored indices.

    A forced proposal switch is expected, the ``resumed`` flag should be
    cleared and the flow proposal should be marked as populated.
    """
    flow_proposal = MagicMock(
        populated=False,
        _resume_populated=True,
        indices=[1, 2, 3],
    )
    sampler._flow_proposal = flow_proposal
    sampler.check_proposal_switch = MagicMock()
    sampler.uninformed_sampling = False
    sampler.resumed = True

    NestedSampler.check_resume(sampler)

    sampler.check_proposal_switch.assert_called_once_with(force=True)
    assert sampler.resumed is False
    assert sampler._flow_proposal.populated is True
def test_insertion_indices_save(mock_fn, mock_save, filename, sampler):
    """Check that the insertion indices are saved when a filename is given."""
    n_indices = 2 * sampler.nlive
    sampler.insertion_indices = np.random.randint(
        sampler.nlive, size=n_indices
    )
    sampler.output = './'

    NestedSampler.check_insertion_indices(
        sampler, rolling=False, filename=filename
    )

    if not filename:
        # Nothing is asserted when no filename was supplied.
        return

    mock_save.assert_called_once_with(
        './file.txt',
        sampler.insertion_indices,
        newline='\n',
        delimiter=' ',
    )
def test_check_flow_model_reset_acceptance(sampler):
    """
    Check that the flow model is reset based on the acceptance if
    ``reset_acceptance`` is True.

    The mean block acceptance (0.1) is below the acceptance threshold
    (0.5), and both weights and permutations are expected to be reset.
    """
    proposal = MagicMock(training_count=1)
    proposal.reset_model_weights = MagicMock()
    sampler.proposal = proposal

    sampler.reset_acceptance = True
    sampler.acceptance_threshold = 0.5
    sampler.mean_block_acceptance = 0.1

    NestedSampler.check_flow_model_reset(sampler)

    reset = sampler.proposal.reset_model_weights
    reset.assert_called_once_with(weights=True, permutations=True)
def test_setup_output_w_resume(sampler, tmpdir):
    """Check the returned resume path when a resume file name is given."""
    p = tmpdir.mkdir('outputs')
    sampler.plot = False
    output_directory = f'{p}/tests'

    resume_file = NestedSampler.setup_output(
        sampler, output_directory, 'resume.pkl'
    )

    assert resume_file == f'{p}/tests/resume.pkl'
def test_setup_output(sampler, tmpdir):
    """Check that the output directory is created.

    The default resume file inside that directory should be returned.
    """
    p = tmpdir.mkdir('outputs')
    sampler.plot = False
    output_directory = f'{p}/tests'

    resume_file = NestedSampler.setup_output(sampler, output_directory)

    assert os.path.exists(f'{p}/tests')
    assert resume_file == f'{p}/tests/nested_sampler_resume.pkl'
def test_plot_state(sampler, tmpdir, filename, track_gradients):
    """Check that the state plot can be produced.

    When a filename is given the file should exist afterwards, otherwise
    the figure should be returned.
    """
    x = np.arange(10)
    # Histories that all share the same dummy data.
    shared_histories = (
        'min_likelihood',
        'max_likelihood',
        'likelihood_evaluations',
        'logZ_history',
        'dZ_history',
        'mean_acceptance_history',
    )
    for attribute in shared_histories:
        setattr(sampler, attribute, x)

    sampler.iteration = 1003
    sampler.training_iterations = [256, 711]
    sampler.train_on_empty = False
    sampler.population_iterations = [256, 500, 711, 800]
    sampler.population_acceptance = 4 * [0.5]
    sampler.population_radii = 4 * [1.]
    sampler.checkpoint_iterations = [600]
    sampler.rolling_p = np.arange(4)

    sampler.state = MagicMock(
        log_vols=np.linspace(0, -10, 1050),
        track_gradients=track_gradients,
        gradients=np.arange(1050),
    )

    save_to_file = filename is not None
    if save_to_file:
        sampler.output = tmpdir.mkdir('test_plot_state')
        filename = os.path.join(sampler.output, filename)

    fig = NestedSampler.plot_state(sampler, filename)

    if save_to_file:
        assert os.path.exists(filename)
    else:
        assert fig is not None
def test_insertion_indices(mock_fn, rolling, sampler):
    """Check which insertion indices are passed to the mocked test.

    With ``rolling`` only the last ``nlive`` indices are expected and one
    value should be added to ``rolling_p``; otherwise all the indices are
    expected.
    """
    n_indices = 2 * sampler.nlive
    sampler.insertion_indices = np.random.randint(
        sampler.nlive, size=n_indices
    )
    sampler.rolling_p = []

    NestedSampler.check_insertion_indices(sampler, rolling=rolling)

    if not rolling:
        mock_fn.assert_called_once_with(
            sampler.insertion_indices, sampler.nlive
        )
        return

    assert len(sampler.rolling_p) == 1
    # First positional argument of the first call to the mocked function.
    indices_used = mock_fn.call_args_list[0][0][0]
    np.testing.assert_array_equal(
        indices_used, sampler.insertion_indices[-sampler.nlive:]
    )
def test_proposal_no_switch(sampler):
    """Check that the proposal is not switched.

    The mean acceptance (0.5) is above the uninformed threshold (0.1) and
    the iteration (10) is below the maximum uninformed iterations (100).
    """
    sampler.iteration = 10
    sampler.maximum_uninformed = 100
    sampler.mean_acceptance = 0.5
    sampler.uninformed_acceptance_threshold = 0.1

    switched = NestedSampler.check_proposal_switch(sampler)

    assert switched is False
def test_get_state(sampler):
    """Check the ``__getstate__`` method used for pickling.

    The returned state should not contain the model.
    """
    sampler.model = MagicMock()

    pickled_state = NestedSampler.__getstate__(sampler)

    assert 'model' not in pickled_state