Ejemplo n.º 1
0
def test_train_proposal(sampler):
    """Check that a forced ``train_proposal`` call trains the proposal.

    The proposal, the flow-reset check and the checkpoint method are
    replaced with mocks so the calls made by ``train_proposal`` can be
    inspected. Afterwards each of them must have been used exactly once
    and the training bookkeeping on the sampler must have been updated.
    """
    proposal = MagicMock()
    proposal.train = MagicMock()

    # Full sampler state for this scenario, applied in one pass below.
    initial_state = {
        "proposal": proposal,
        "check_flow_model_reset": MagicMock(),
        "checkpoint": MagicMock(),
        "iteration": 100,
        "last_updated": 90,
        "cooldown": 20,
        "memory": False,
        "training_time": datetime.timedelta(),
        "training_iterations": [],
        "live_points": np.arange(10),
        "checkpoint_on_training": True,
        "block_iteration": 10,
        "block_acceptance": 0.5,
    }
    for attribute, value in initial_state.items():
        setattr(sampler, attribute, value)

    # Call the unbound method so the real implementation runs against
    # the prepared ``sampler`` object.
    NestedSampler.train_proposal(sampler, force=True)

    sampler.check_flow_model_reset.assert_called_once()
    sampler.proposal.train.assert_called_once()
    sampler.checkpoint.assert_called_once_with(periodic=True)

    # The current iteration is recorded as a training iteration.
    assert sampler.training_iterations == [100]
    # ``training_time`` started at zero, so it must have been increased.
    assert sampler.training_time.total_seconds() > 0
    assert sampler.completed_training is True
    # The per-block statistics are reset to zero.
    assert sampler.block_iteration == 0
    assert sampler.block_acceptance == 0
Ejemplo n.º 2
0
def test_train_proposal_memory(sampler):
    """Check that a forced ``train_proposal`` call trains with memory.

    With ``memory`` set to 2, the data passed to ``proposal.train`` is
    expected to be the five live points followed by the values 3 and 4,
    which are the last two entries of ``nested_samples``.
    """
    proposal = MagicMock()
    proposal.train = MagicMock()

    # Full sampler state for this scenario, applied in one pass below.
    initial_state = {
        "proposal": proposal,
        "check_flow_model_reset": MagicMock(),
        "checkpoint": MagicMock(),
        "iteration": 100,
        "last_updated": 90,
        "cooldown": 20,
        "memory": 2,
        "training_time": datetime.timedelta(),
        "training_iterations": [],
        "nested_samples": np.arange(5),
        "live_points": np.arange(5, 10),
        "checkpoint_on_training": True,
        "block_iteration": 10,
        "block_acceptance": 0.5,
    }
    for attribute, value in initial_state.items():
        setattr(sampler, attribute, value)

    # Call the unbound method so the real implementation runs against
    # the prepared ``sampler`` object.
    NestedSampler.train_proposal(sampler, force=True)

    sampler.check_flow_model_reset.assert_called_once()
    sampler.checkpoint.assert_called_once_with(periodic=True)
    sampler.proposal.train.assert_called_once()

    # ``call_args[0]`` is the tuple of positional arguments, hence the
    # extra leading axis on the expected array.
    positional_args = sampler.proposal.train.call_args[0]
    expected_args = np.array([[5, 6, 7, 8, 9, 3, 4]])
    np.testing.assert_array_equal(positional_args, expected_args)

    # The current iteration is recorded as a training iteration.
    assert sampler.training_iterations == [100]
    # ``training_time`` started at zero, so it must have been increased.
    assert sampler.training_time.total_seconds() > 0
    assert sampler.completed_training is True
    # The per-block statistics are reset to zero.
    assert sampler.block_iteration == 0
    assert sampler.block_acceptance == 0
Ejemplo n.º 3
0
def test_train_proposal_not_training(sampler):
    """Check the proposal is not trained if it has not 'cooled down'.

    Only 10 iterations separate ``iteration`` from ``last_updated``,
    fewer than the ``cooldown`` of 20, and training is not forced.
    """
    proposal = MagicMock()
    proposal.train = MagicMock()
    sampler.proposal = proposal

    for attribute, value in (
        ("iteration", 100),
        ("last_updated", 90),
        ("cooldown", 20),
    ):
        setattr(sampler, attribute, value)

    NestedSampler.train_proposal(sampler, force=False)

    sampler.proposal.train.assert_not_called()