def test_check_training_not_completed_training(sampler):
    """
    Check that training is both requested and forced when the sampler
    reports that training has not been completed.
    """
    sampler.completed_training = False

    should_train, is_forced = NestedSampler.check_training(sampler)

    assert should_train is True
    assert is_forced is True
def test_check_training_train_on_empty(sampler):
    """
    Check that training is both requested and forced when
    `train_on_empty` is enabled and the proposal is neither populated
    nor in the process of populating.
    """
    sampler.completed_training = True
    sampler.train_on_empty = True
    # Proposal with an empty pool that is not currently being filled.
    sampler.proposal = MagicMock(populated=False, populating=False)

    should_train, is_forced = NestedSampler.check_training(sampler)

    assert should_train is True
    assert is_forced is True
def test_check_training_acceptance(sampler):
    """
    Check that training is requested, but not forced, when retraining
    on acceptance is enabled and the mean block acceptance is below the
    acceptance threshold.
    """
    sampler.completed_training = True
    sampler.train_on_empty = True
    # The proposal is populated, so the train-on-empty case does not apply.
    sampler.proposal = MagicMock(populated=True, populating=False)

    # Mean acceptance (0.01) sits below the threshold (0.1).
    sampler.acceptance_threshold = 0.1
    sampler.mean_block_acceptance = 0.01
    sampler.retrain_acceptance = True

    should_train, is_forced = NestedSampler.check_training(sampler)

    assert should_train is True
    assert is_forced is False
def test_check_training_false(sampler, config):
    """
    Check a range of scenarios in which training should neither be
    requested nor forced.

    Each scenario is described by `config`; any setting missing from it
    falls back to the default used below.
    """
    sampler.completed_training = True
    sampler.train_on_empty = config.get('train_on_empty', False)
    sampler.proposal = MagicMock(
        populated=config.get('populated', False),
        populating=config.get('populating', False),
    )

    # (attribute set on the sampler, key looked up in config, default)
    # Note that the mean block acceptance is read from 'mean_acceptance'.
    settings = (
        ('acceptance_threshold', 'acceptance_threshold', 0.1),
        ('mean_block_acceptance', 'mean_acceptance', 0.2),
        ('retrain_acceptance', 'retrain_acceptance', False),
        ('iteration', 'iteration', 3000),
        ('last_updated', 'last_updated', 2500),
        ('training_frequency', 'training_frequency', 1000),
    )
    for attribute, key, default in settings:
        setattr(sampler, attribute, config.get(key, default))

    should_train, is_forced = NestedSampler.check_training(sampler)

    assert should_train is False
    assert is_forced is False
def test_check_training_iteration(sampler):
    """
    Check that training is requested, but not forced, once a training
    iteration is reached, i.e. `training_frequency` iterations have
    passed since the sampler was last updated.
    """
    sampler.completed_training = True
    sampler.train_on_empty = True
    sampler.proposal = MagicMock(populated=True, populating=False)

    # Acceptance is above the threshold and acceptance-based retraining
    # is disabled.
    sampler.acceptance_threshold = 0.1
    sampler.mean_block_acceptance = 0.2
    sampler.retrain_acceptance = False

    # 3521 - 2521 == 1000 iterations since the last update, which equals
    # the training frequency.
    sampler.training_frequency = 1000
    sampler.last_updated = 2521
    sampler.iteration = 3521

    should_train, is_forced = NestedSampler.check_training(sampler)

    assert should_train is True
    assert is_forced is False