def test_check_state_train(sampler, force, train):
    """Test ``check_state`` with ``force=False`` for each ``check_training`` result.

    ``check_training`` is mocked to return the tuple ``(train, force)``. The
    combinations exercised are ``(True, False)``, ``(True, True)``,
    ``(False, True)`` and ``(False, False)``.

    Expectations:

    * ``check_training`` is called exactly once, with no arguments, whatever
      it returns.
    * If either element of the returned tuple is True, ``train_proposal`` is
      called once and receives the returned ``force`` flag. This is not the
      ``force=False`` passed to ``check_state``.
    * Otherwise ``train_proposal`` is not called.
    """
    sampler.uninformed_sampling = False
    sampler.check_proposal_switch = MagicMock()
    sampler.check_training = MagicMock(return_value=(train, force))
    sampler.train_proposal = MagicMock()

    NestedSampler.check_state(sampler, force=False)

    # Asserted once, outside the branch: the call happens in both cases.
    sampler.check_training.assert_called_once_with()
    if train or force:
        sampler.train_proposal.assert_called_once_with(force=force)
    else:
        sampler.train_proposal.assert_not_called()
def test_check_state_force(sampler, switch, uninformed):
    """Check that ``check_state(force=True)`` bypasses ``check_training``.

    When forced, ``train_proposal`` is expected to run once with
    ``force=True``. The one exception is when ``uninformed_sampling`` is set
    and the mocked ``check_proposal_switch`` returns False; then no training
    is expected. In every case ``check_training`` must not be called.
    """
    sampler.uninformed_sampling = uninformed
    sampler.check_proposal_switch = MagicMock(return_value=switch)
    sampler.check_training = MagicMock()
    sampler.train_proposal = MagicMock()

    NestedSampler.check_state(sampler, force=True)

    # Equivalent to "not (uninformed and not switch)". Training is skipped
    # only while uninformed sampling is on and the switch check returned False.
    should_train = not uninformed or switch
    if should_train:
        sampler.train_proposal.assert_called_once_with(force=True)
    else:
        sampler.train_proposal.assert_not_called()
    sampler.check_training.assert_not_called()