def test_check_flow_model_reset_not_trained(sampler):
    """Check that an untrained flow model is left untouched.

    With ``proposal.training_count`` set to zero,
    ``NestedSampler.check_flow_model_reset`` must not call
    ``proposal.reset_model_weights``.
    """
    proposal = MagicMock(
        reset_model_weights=MagicMock(),
        training_count=0,
    )
    sampler.proposal = proposal

    NestedSampler.check_flow_model_reset(sampler)

    proposal.reset_model_weights.assert_not_called()
def test_check_flow_model_reset_weights(sampler, training_count):
    """Check that only the weights are reset when permutation resets are off.

    Acceptance-based resets are disabled, ``reset_weights`` is 10 and
    ``reset_permutations`` is 0, so the single expected call is
    ``reset_model_weights(weights=True)``.

    ``training_count`` comes from a fixture defined outside this chunk;
    presumably it supplies counts for which a weights reset is due --
    confirm against the fixture.
    """
    proposal = MagicMock(
        reset_model_weights=MagicMock(),
        training_count=training_count,
    )
    sampler.proposal = proposal

    # Reset configuration: weights every 10 trainings, permutations never.
    sampler.reset_permutations = 0
    sampler.reset_weights = 10
    sampler.reset_acceptance = False

    NestedSampler.check_flow_model_reset(sampler)

    proposal.reset_model_weights.assert_called_once_with(weights=True)
def test_check_flow_model_reset_both(sampler, training_count):
    """Check that both the weights and the permutations are reset.

    Acceptance-based resets are disabled and both ``reset_weights`` and
    ``reset_permutations`` are 10. The expected calls are
    ``reset_model_weights(weights=True)`` followed by
    ``reset_model_weights(weights=False, permutations=True)``.

    ``training_count`` comes from a fixture defined outside this chunk;
    presumably it supplies counts for which both resets are due --
    confirm against the fixture.
    """
    proposal = MagicMock(
        reset_model_weights=MagicMock(),
        training_count=training_count,
    )
    sampler.proposal = proposal

    # Reset configuration: weights and permutations share the same interval.
    sampler.reset_permutations = 10
    sampler.reset_weights = 10
    sampler.reset_acceptance = False

    NestedSampler.check_flow_model_reset(sampler)

    # assert_has_calls checks that these calls appear in this order; it does
    # not fail if further calls were recorded before or after them.
    expected_calls = [
        call(weights=True),
        call(weights=False, permutations=True),
    ]
    proposal.reset_model_weights.assert_has_calls(expected_calls)
def test_check_flow_model_reset_acceptance(sampler):
    """Check the acceptance-based reset when ``reset_acceptance`` is True.

    The mean block acceptance (0.1) is below the acceptance threshold (0.5)
    and the flow has been trained once, so the single expected call is
    ``reset_model_weights(weights=True, permutations=True)``.
    """
    proposal = MagicMock(
        reset_model_weights=MagicMock(),
        training_count=1,
    )
    sampler.proposal = proposal

    # Acceptance sits below the threshold.
    sampler.acceptance_threshold = 0.5
    sampler.mean_block_acceptance = 0.1
    sampler.reset_acceptance = True

    NestedSampler.check_flow_model_reset(sampler)

    proposal.reset_model_weights.assert_called_once_with(
        weights=True,
        permutations=True,
    )