Ejemplo n.º 1
0
def test_intialise_resume(sampler):
    """Check ``initialise`` when the sampler is being resumed.

    Resuming is signalled by the live points already being set (not None).
    Both proposals should be initialised, the live points must not be
    repopulated, ``initialised`` stays False, and the flow proposal is
    selected with its pool configured.
    """
    # Resume state: live points already exist.
    sampler.live_points = [0.0]
    sampler.initialised = False
    sampler.iteration = 100
    sampler.maximum_uninformed = 10
    sampler.condition = 1.0
    sampler.tolerance = 0.1

    # Neither proposal has been initialised yet.
    sampler._flow_proposal = MagicMock()
    sampler._flow_proposal.initialised = False
    sampler._uninformed_proposal = MagicMock()
    sampler._uninformed_proposal.initialised = False
    sampler.populate_live_points = MagicMock()

    NestedSampler.initialise(sampler)

    for proposal in (sampler._flow_proposal, sampler._uninformed_proposal):
        proposal.initialise.assert_called_once()
    sampler.populate_live_points.assert_not_called()
    assert sampler.initialised is False
    assert sampler.proposal is sampler._flow_proposal
    sampler.proposal.configure_pool.assert_called_once()
Ejemplo n.º 2
0
def test_train_proposal(sampler):
    """Forced training should train the proposal and update bookkeeping.

    After ``train_proposal(force=True)`` the flow-reset check and a periodic
    checkpoint each run once. The iteration is recorded in
    ``training_iterations``, training time is accumulated, and the block
    statistics are reset to zero.
    """
    # Hooks that training is expected to exercise.
    sampler.proposal = MagicMock()
    sampler.proposal.train = MagicMock()
    sampler.check_flow_model_reset = MagicMock()
    sampler.checkpoint = MagicMock()

    # Only 10 iterations since the last update (cooldown is 20), so training
    # happens here only because it is forced.
    sampler.iteration, sampler.last_updated, sampler.cooldown = 100, 90, 20
    sampler.memory = False
    sampler.training_time = datetime.timedelta()
    sampler.training_iterations = []
    sampler.live_points = np.arange(10)
    sampler.checkpoint_on_training = True
    sampler.block_iteration, sampler.block_acceptance = 10, 0.5

    NestedSampler.train_proposal(sampler, force=True)

    sampler.check_flow_model_reset.assert_called_once()
    sampler.proposal.train.assert_called_once()
    sampler.checkpoint.assert_called_once_with(periodic=True)

    assert sampler.training_iterations == [100]
    assert sampler.training_time.total_seconds() > 0
    assert sampler.completed_training is True
    assert (sampler.block_iteration, sampler.block_acceptance) == (0, 0)
Ejemplo n.º 3
0
def test_consume_sample_reject(sampler, live_points):
    """Check the default behaviour of ``consume_sample`` with one rejection.

    The mocked ``yield_sample`` first produces a point with logL=-0.5, which
    is expected to be rejected, and then a point with logL=0.5, which should
    be inserted into the live points.
    """
    sampler.live_points = live_points

    def _point(x, log_l):
        """Build a single-parameter live point with the given logL."""
        point = parameters_to_live_point((x,), ['x'])
        point['logL'] = log_l
        return point

    rejected_point = _point(-0.5, -0.5)
    accepted_point = _point(0.5, 0.5)
    sampler.yield_sample = MagicMock(
        return_value=iter([(1, rejected_point), (1, accepted_point)]))
    sampler.insert_live_point = MagicMock(return_value=0)
    sampler.check_state = MagicMock()

    NestedSampler.consume_sample(sampler)

    sampler.insert_live_point.assert_called_once_with(accepted_point)
    sampler.check_state.assert_called_once()

    # One rejection and one acceptance give 50% acceptance for the block.
    assert sampler.nested_samples == [live_points[0]]
    assert sampler.logLmin == 0.0
    assert sampler.rejected == 1
    assert sampler.accepted == 1
    assert sampler.block_acceptance == 0.5
    assert sampler.acceptance_history == [0.5]
    assert sampler.mean_block_acceptance == 0.5
    assert sampler.insertion_indices == [0]
Ejemplo n.º 4
0
def test_checkpoint(sampler, periodic):
    """Checkpointing dumps the sampler and updates timing bookkeeping.

    The sampler should be written to ``resume_file`` via ``safe_file_dump``,
    the start time advanced and the sampling time accumulated. The current
    iteration is appended to ``checkpoint_iterations`` only for
    non-periodic checkpoints.
    """
    sampler.checkpoint_iterations = [10]
    sampler.iteration = 20
    start = datetime.datetime.now()
    sampler.sampling_start_time = start
    sampler.sampling_time = datetime.timedelta()
    sampler.resume_file = 'test.pkl'

    with patch('nessai.nestedsampler.safe_file_dump') as dump_mock:
        NestedSampler.checkpoint(sampler, periodic=periodic)

    dump_mock.assert_called_once_with(
        sampler, sampler.resume_file, pickle, save_existing=True)

    assert sampler.sampling_start_time > start
    assert sampler.sampling_time.total_seconds() > 0.

    expected_iterations = [10] if periodic else [10, 20]
    assert sampler.checkpoint_iterations == expected_iterations
Ejemplo n.º 5
0
def test_update_state_force(mock_plot, checkpointing, sampler):
    """Test the update that happens if force=True.

    Checks that neither the insertion-index check nor ``plot_indices`` is
    run, even though plotting is enabled. The state and trace plots should
    still be produced, and a periodic checkpoint is made only when
    checkpointing is enabled.
    """
    sampler.iteration = 111
    sampler.proposal._checked_population = True
    sampler.check_insertion_indices = MagicMock()
    sampler.plot = True
    sampler.plot_state = MagicMock()
    sampler.plot_trace = MagicMock()
    sampler.output = './'
    sampler.uninformed_sampling = False
    sampler.checkpointing = checkpointing

    NestedSampler.update_state(sampler, force=True)

    if checkpointing:
        sampler.checkpoint.assert_called_once_with(periodic=True)
    else:
        sampler.checkpoint.assert_not_called()
    assert not mock_plot.called
    # Forced updates skip the insertion-index check. Assert on the dedicated
    # mock: the sampler mock itself is never called, so checking it proves
    # nothing.
    sampler.check_insertion_indices.assert_not_called()
    sampler.plot_trace.assert_called_once()
    sampler.plot_state.assert_called_once_with(filename='.//state.png')

    assert sampler.max_likelihood == [100.0, 150.0]
    assert sampler.population_acceptance == [0.5]
    assert sampler.block_acceptance == 0.5
    assert sampler.block_iteration == 5
Ejemplo n.º 6
0
def test_train_proposal_memory(sampler):
    """Training with ``memory`` also uses the most recent nested samples.

    With ``memory=2`` the last two nested samples are appended after the
    live points in the data passed to ``proposal.train``.
    """
    sampler.proposal = MagicMock()
    sampler.proposal.train = MagicMock()
    sampler.check_flow_model_reset = MagicMock()
    sampler.checkpoint = MagicMock()

    sampler.iteration, sampler.last_updated, sampler.cooldown = 100, 90, 20
    sampler.memory = 2
    sampler.training_time = datetime.timedelta()
    sampler.training_iterations = []
    sampler.nested_samples = np.arange(5)
    sampler.live_points = np.arange(5, 10)
    sampler.checkpoint_on_training = True
    sampler.block_iteration, sampler.block_acceptance = 10, 0.5

    NestedSampler.train_proposal(sampler, force=True)

    sampler.check_flow_model_reset.assert_called_once()
    sampler.checkpoint.assert_called_once_with(periodic=True)
    sampler.proposal.train.assert_called_once()

    # Live points 5..9 followed by the final two nested samples (3, 4).
    positional_args = sampler.proposal.train.call_args[0]
    expected_data = np.array([[5, 6, 7, 8, 9, 3, 4]])
    np.testing.assert_array_equal(positional_args, expected_data)

    assert sampler.training_iterations == [100]
    assert sampler.training_time.total_seconds() > 0
    assert sampler.completed_training is True
    assert (sampler.block_iteration, sampler.block_acceptance) == (0, 0)
Ejemplo n.º 7
0
def test_log_likelihood(sampler):
    """``log_likelihood`` should forward its argument to the model.

    The sampler never calls this method itself; it is provided for users.
    """
    point = [0.1]
    sampler.model.log_likelihood = MagicMock()

    NestedSampler.log_likelihood(sampler, point)

    sampler.model.log_likelihood.assert_called_once_with([0.1])
Ejemplo n.º 8
0
def test_uninformed_analytic_priors(sampler):
    """With analytic priors the uninformed proposal should be an
    ``AnalyticProposal``.
    """
    use_analytic_priors = True
    NestedSampler.configure_uninformed_proposal(
        sampler, None, use_analytic_priors, None, None)

    assert isinstance(sampler._uninformed_proposal, AnalyticProposal)
Ejemplo n.º 9
0
def test_train_proposal_not_training(sampler):
    """Verify the proposal is not trained if it has not 'cooled down'.

    Only 10 iterations have passed since the last update, fewer than the
    cooldown of 20, so an unforced call must not train.
    """
    sampler.last_updated = 90
    sampler.iteration = 100
    sampler.cooldown = 20
    sampler.proposal = MagicMock()
    sampler.proposal.train = MagicMock()

    NestedSampler.train_proposal(sampler, force=False)

    sampler.proposal.train.assert_not_called()
Ejemplo n.º 10
0
def test_insertion_indices_p_none(mock_fn, sampler):
    """No rolling p-value is recorded when the (patched) index test gives
    ``p=None``.
    """
    nlive = sampler.nlive
    sampler.rolling_p = []
    sampler.insertion_indices = np.random.randint(nlive, size=2 * nlive)

    NestedSampler.check_insertion_indices(sampler, rolling=True)

    assert len(sampler.rolling_p) == 0
Ejemplo n.º 11
0
def test_uninformed_threshold_default_(sampler, threshold):
    """An acceptance threshold of at least 0.1 is copied unchanged to
    ``uninformed_acceptance_threshold``.
    """
    sampler.acceptance_threshold = threshold

    NestedSampler.configure_uninformed_proposal(
        sampler, None, False, None, None)

    assert sampler.uninformed_acceptance_threshold == threshold
Ejemplo n.º 12
0
def test_uninformed_threshold_default_below(sampler):
    """An acceptance threshold below 0.1 is scaled by ten for the
    uninformed proposal.
    """
    base_threshold = 0.05
    sampler.acceptance_threshold = base_threshold

    NestedSampler.configure_uninformed_proposal(
        sampler, None, False, None, None)

    assert sampler.uninformed_acceptance_threshold == 0.5
Ejemplo n.º 13
0
def test_init(sampler, model):
    """Test the init method.

    The output, flow-reset and proposal configuration helpers are replaced
    with mocks so the test focuses on ``__init__`` itself.
    """
    sampler.setup_output = MagicMock()
    sampler.configure_flow_reset = MagicMock()
    sampler.configure_flow_proposal = MagicMock()
    # Assign a mock *instance*. Assigning the ``MagicMock`` class would make
    # each call build a fresh mock from the call's positional arguments.
    sampler.configure_uninformed_proposal = MagicMock()
    NestedSampler.__init__(sampler, model, nlive=100, poolsize=100)
    assert sampler.initialised is False
    assert sampler.nlive == 100
Ejemplo n.º 14
0
def test_populate_live_points_nans(sampler):
    """Test populating the live points when one draw has a NaN logL.

    One extra point is supplied so that, after the NaN point is discarded,
    exactly ``nlive`` finite points remain.
    """
    n_candidates = sampler.nlive + 1
    candidates = sampler.model.new_point(n_candidates)
    candidates['logL'][4] = np.nan
    draws = zip(np.ones(n_candidates), candidates)
    sampler.yield_sample = MagicMock(return_value=iter(draws))

    NestedSampler.populate_live_points(sampler)

    assert len(sampler.live_points) == sampler.nlive
    assert not np.isnan(sampler.live_points['logL']).any()
Ejemplo n.º 15
0
def test_check_flow_model_reset_not_trained(sampler):
    """A flow that has never been trained must not have its weights reset."""
    untrained = MagicMock()
    untrained.reset_model_weights = MagicMock()
    untrained.training_count = 0
    sampler.proposal = untrained

    NestedSampler.check_flow_model_reset(sampler)

    sampler.proposal.reset_model_weights.assert_not_called()
Ejemplo n.º 16
0
def test_uninformed_proposal_class(sampler):
    """A user-supplied class is used as the uninformed proposal."""
    from nessai.proposal.base import Proposal

    class CustomProposal(Proposal):
        """Minimal concrete proposal with a no-op ``draw``."""

        def draw(self, point):
            pass

    NestedSampler.configure_uninformed_proposal(
        sampler, CustomProposal, False, None, None)

    assert isinstance(sampler._uninformed_proposal, CustomProposal)
Ejemplo n.º 17
0
def test_flow_class_not_subclass(sampler):
    """A RuntimeError mentioning inheritance is raised if the class does not
    inherit from FlowProposal.
    """
    class NotAFlowProposal:
        pass

    with pytest.raises(RuntimeError, match='inherits'):
        NestedSampler.configure_flow_proposal(
            sampler, NotAFlowProposal, {}, False)
Ejemplo n.º 18
0
def test_check_flow_model_reset_weights(sampler, training_count):
    """Only the flow weights are reset when permutation resets are off."""
    sampler.reset_acceptance = False
    sampler.reset_weights, sampler.reset_permutations = 10, 0

    sampler.proposal = MagicMock()
    sampler.proposal.training_count = training_count
    sampler.proposal.reset_model_weights = MagicMock()

    NestedSampler.check_flow_model_reset(sampler)

    sampler.proposal.reset_model_weights.assert_called_once_with(weights=True)
Ejemplo n.º 19
0
def test_uninformed_maximum(sampler, maximum, result):
    """``maximum_uninformed`` is set from ``maximum``.

    Uninformed sampling is enabled unless ``maximum`` is exactly ``False``.
    """
    NestedSampler.configure_uninformed_proposal(
        sampler, None, False, maximum, None)

    assert sampler.maximum_uninformed == result
    sampling_expected = maximum is not False
    assert sampler.uninformed_sampling is sampling_expected
Ejemplo n.º 20
0
def test_check_resume_no_indices(sampler):
    """Resume check when the flow proposal has no stored indices.

    ``check_resume`` should clear the ``resumed`` flag, but must leave the
    flow proposal marked as not populated.
    """
    sampler.resumed = True
    sampler.uninformed_sampling = True

    sampler._flow_proposal = MagicMock()
    sampler._flow_proposal.indices = []
    sampler._flow_proposal._resume_populated = True
    sampler._flow_proposal.populated = False

    NestedSampler.check_resume(sampler)

    assert sampler.resumed is False
    assert sampler._flow_proposal.populated is False
Ejemplo n.º 21
0
def test_check_flow_model_reset_both(sampler, training_count):
    """Assert both the flow weights and the permutations are reset.

    With ``reset_weights`` and ``reset_permutations`` both set to 10, the
    expected sequence is a weights reset followed by a separate
    permutations-only reset.
    """
    sampler.proposal = MagicMock()
    sampler.proposal.reset_model_weights = MagicMock()
    sampler.reset_acceptance = False
    sampler.reset_weights = 10
    sampler.reset_permutations = 10
    sampler.proposal.training_count = training_count

    NestedSampler.check_flow_model_reset(sampler)

    calls = [call(weights=True), call(weights=False, permutations=True)]
    sampler.proposal.reset_model_weights.assert_has_calls(calls)
Ejemplo n.º 22
0
def test_check_resume(sampler):
    """Resume check when the flow proposal still holds indices.

    The proposal switch should be forced, the ``resumed`` flag cleared, and
    the flow proposal marked as populated.
    """
    sampler.resumed = True
    sampler.uninformed_sampling = False
    sampler.check_proposal_switch = MagicMock()

    sampler._flow_proposal = MagicMock()
    sampler._flow_proposal.indices = [1, 2, 3]
    sampler._flow_proposal._resume_populated = True
    sampler._flow_proposal.populated = False

    NestedSampler.check_resume(sampler)

    sampler.check_proposal_switch.assert_called_once_with(force=True)
    assert sampler.resumed is False
    assert sampler._flow_proposal.populated is True
Ejemplo n.º 23
0
def test_insertion_indices_save(mock_fn, mock_save, filename, sampler):
    """Test saving the insertion indices.

    When a filename is given, the full set of insertion indices should be
    written to ``<output>/<filename>``. Otherwise nothing should be saved.
    """
    sampler.output = './'
    sampler.insertion_indices = \
        np.random.randint(sampler.nlive, size=2 * sampler.nlive)

    NestedSampler.check_insertion_indices(sampler,
                                          rolling=False,
                                          filename=filename)

    if filename:
        mock_save.assert_called_once_with(
            os.path.join(sampler.output, filename),
            sampler.insertion_indices,
            newline='\n',
            delimiter=' ')
    else:
        # Without a filename the indices must not be written anywhere.
        mock_save.assert_not_called()
Ejemplo n.º 24
0
def test_check_flow_model_reset_acceptance(sampler):
    """Assert the flow model is fully reset based on acceptance if
    ``reset_acceptance`` is True.

    The mean block acceptance (0.1) is below the threshold (0.5), so both
    weights and permutations should be reset in a single call.
    """
    sampler.reset_acceptance = True
    sampler.acceptance_threshold = 0.5
    sampler.mean_block_acceptance = 0.1

    sampler.proposal = MagicMock()
    sampler.proposal.training_count = 1
    sampler.proposal.reset_model_weights = MagicMock()

    NestedSampler.check_flow_model_reset(sampler)

    sampler.proposal.reset_model_weights.assert_called_once_with(
        weights=True, permutations=True)
Ejemplo n.º 25
0
def test_setup_output_w_resume(sampler, tmpdir):
    """A user-supplied resume file name is placed inside the output dir."""
    sampler.plot = False
    output_dir = f"{tmpdir.mkdir('outputs')}/tests"

    resume_file = NestedSampler.setup_output(sampler, output_dir, 'resume.pkl')

    assert resume_file == f'{output_dir}/resume.pkl'
Ejemplo n.º 26
0
def test_setup_output(sampler, tmpdir):
    """The output directory is created and the default resume path is used."""
    sampler.plot = False
    output_dir = f"{tmpdir.mkdir('outputs')}/tests"

    resume_file = NestedSampler.setup_output(sampler, output_dir)

    assert os.path.exists(output_dir)
    assert resume_file == f'{output_dir}/nested_sampler_resume.pkl'
Ejemplo n.º 27
0
def test_plot_state(sampler, tmpdir, filename, track_gradients):
    """``plot_state`` saves the figure when given a filename, else returns it.
    """
    history = np.arange(10)
    for attr in ('min_likelihood', 'max_likelihood', 'likelihood_evaluations',
                 'logZ_history', 'dZ_history', 'mean_acceptance_history'):
        setattr(sampler, attr, history)

    sampler.iteration = 1003
    sampler.train_on_empty = False
    sampler.training_iterations = [256, 711]
    sampler.checkpoint_iterations = [600]
    sampler.population_iterations = [256, 500, 711, 800]
    sampler.population_acceptance = [0.5] * 4
    sampler.population_radii = [1.] * 4
    sampler.rolling_p = np.arange(4)

    n_volumes = 1050
    sampler.state = MagicMock()
    sampler.state.log_vols = np.linspace(0, -10, n_volumes)
    sampler.state.track_gradients = track_gradients
    sampler.state.gradients = np.arange(n_volumes)

    if filename is None:
        fig = NestedSampler.plot_state(sampler, filename)
        assert fig is not None
    else:
        sampler.output = tmpdir.mkdir('test_plot_state')
        filename = os.path.join(sampler.output, filename)
        NestedSampler.plot_state(sampler, filename)
        assert os.path.exists(filename)
Ejemplo n.º 28
0
def test_insertion_indices(mock_fn, rolling, sampler):
    """The index test receives the correct slice of insertion indices.

    In rolling mode only the most recent ``nlive`` indices are tested and a
    p-value is recorded. Otherwise all indices are tested.
    """
    nlive = sampler.nlive
    sampler.rolling_p = []
    sampler.insertion_indices = np.random.randint(nlive, size=2 * nlive)

    NestedSampler.check_insertion_indices(sampler, rolling=rolling)

    if not rolling:
        mock_fn.assert_called_once_with(sampler.insertion_indices, nlive)
        return

    assert len(sampler.rolling_p) == 1
    tested_indices = mock_fn.call_args_list[0][0][0]
    np.testing.assert_array_equal(
        tested_indices, sampler.insertion_indices[-nlive:])
Ejemplo n.º 29
0
def test_proposal_no_switch(sampler):
    """The proposal is kept when the acceptance (0.5) is above the uninformed
    threshold (0.1) and the uninformed iteration limit is not reached.
    """
    sampler.iteration, sampler.maximum_uninformed = 10, 100
    sampler.uninformed_acceptance_threshold = 0.1
    sampler.mean_acceptance = 0.5

    switched = NestedSampler.check_proposal_switch(sampler)

    assert switched is False
Ejemplo n.º 30
0
def test_get_state(sampler):
    """``__getstate__``, which is used for pickling, must drop the model."""
    sampler.model = MagicMock()

    pickled_state = NestedSampler.__getstate__(sampler)

    assert 'model' not in pickled_state