def test_train_proposal(sampler):
    """Check that a forced call to ``train_proposal`` trains the proposal.

    The proposal, the flow-reset check and the checkpoint method are
    replaced with mocks so the calls made during training can be inspected.
    After the call the test expects the current iteration to be recorded,
    the training time to have increased, ``completed_training`` to be set
    and the block statistics to be reset to zero.
    """
    # Collaborators whose calls are inspected after training.
    sampler.proposal = MagicMock()
    sampler.proposal.train = MagicMock()
    sampler.check_flow_model_reset = MagicMock()
    sampler.checkpoint = MagicMock()

    # Plain sampler state read (and partly reset) by the method under test.
    initial_state = {
        "iteration": 100,
        "last_updated": 90,
        "cooldown": 20,
        "memory": False,
        "training_time": datetime.timedelta(),
        "training_iterations": [],
        "live_points": np.arange(10),
        "checkpoint_on_training": True,
        "block_iteration": 10,
        "block_acceptance": 0.5,
    }
    for attribute, value in initial_state.items():
        setattr(sampler, attribute, value)

    NestedSampler.train_proposal(sampler, force=True)

    # Each collaborator must have been used exactly once.
    sampler.check_flow_model_reset.assert_called_once()
    sampler.proposal.train.assert_called_once()
    sampler.checkpoint.assert_called_once_with(periodic=True)

    # Book-keeping that the test expects after a training step.
    assert sampler.training_iterations == [100]
    elapsed_seconds = sampler.training_time.total_seconds()
    assert elapsed_seconds > 0
    assert sampler.completed_training is True
    assert sampler.block_iteration == 0
    assert sampler.block_acceptance == 0


def test_train_proposal_memory(sampler):
    """Check the training data used when ``memory`` is enabled.

    With ``memory = 2``, five nested samples (0-4) and five live points
    (5-9), the test expects the proposal to be trained on the live points
    followed by the last two nested samples. The remaining assertions
    mirror the test without memory: one flow-reset check, one periodic
    checkpoint, updated training book-keeping and reset block statistics.
    """
    # Collaborators whose calls are inspected after training.
    sampler.proposal = MagicMock()
    sampler.proposal.train = MagicMock()
    sampler.check_flow_model_reset = MagicMock()
    sampler.checkpoint = MagicMock()

    # Plain sampler state read (and partly reset) by the method under test.
    initial_state = {
        "iteration": 100,
        "last_updated": 90,
        "cooldown": 20,
        "memory": 2,
        "training_time": datetime.timedelta(),
        "training_iterations": [],
        "nested_samples": np.arange(5),
        "live_points": np.arange(5, 10),
        "checkpoint_on_training": True,
        "block_iteration": 10,
        "block_acceptance": 0.5,
    }
    for attribute, value in initial_state.items():
        setattr(sampler, attribute, value)

    NestedSampler.train_proposal(sampler, force=True)

    sampler.check_flow_model_reset.assert_called_once()
    sampler.checkpoint.assert_called_once_with(periodic=True)
    sampler.proposal.train.assert_called_once()

    # ``call_args[0]`` is the tuple of positional arguments, so comparing it
    # with a one-row 2-D array also checks that exactly one positional
    # argument was passed to ``train``.
    positional_args = sampler.proposal.train.call_args[0]
    expected_args = np.array([[5, 6, 7, 8, 9, 3, 4]])
    np.testing.assert_array_equal(positional_args, expected_args)

    # Book-keeping that the test expects after a training step.
    assert sampler.training_iterations == [100]
    elapsed_seconds = sampler.training_time.total_seconds()
    assert elapsed_seconds > 0
    assert sampler.completed_training is True
    assert sampler.block_iteration == 0
    assert sampler.block_acceptance == 0


def test_train_proposal_not_training(sampler):
    """Check the proposal is not trained if it has not 'cooled down'.

    The sampler is at iteration 100, was last updated at iteration 90 and
    has a cooldown of 20. Without ``force`` the test expects
    ``proposal.train`` not to be called.
    """
    sampler.proposal = MagicMock()
    sampler.proposal.train = MagicMock()

    # Only 10 iterations have passed since the last update.
    sampler.iteration, sampler.last_updated, sampler.cooldown = 100, 90, 20

    NestedSampler.train_proposal(sampler, force=False)

    train_mock = sampler.proposal.train
    train_mock.assert_not_called()