def test_reset_integration(tmpdir, model):
    """Test the reset method's interaction with other methods.

    A proposal that has been initialised, populated and then reset should
    have the same state (as returned by ``__getstate__``) as an otherwise
    identical proposal that has only been initialised.
    """
    output = str(tmpdir.mkdir('reset_integration'))
    # Both proposals must be configured identically for the state
    # comparison to be meaningful.
    kwargs = dict(
        output=output,
        plot=False,
        poolsize=10,
        latent_prior='truncated_gaussian',
        constant_volume_mode=False,
    )
    proposal = FlowProposal(model, **kwargs)
    modified_proposal = FlowProposal(model, **kwargs)
    proposal.initialise()
    modified_proposal.initialise()

    modified_proposal.populate(model.new_point(), r=1.0)
    modified_proposal.reset()

    d1 = proposal.__getstate__()
    d2 = modified_proposal.__getstate__()
    # These entries are excluded from the comparison. NOTE(review): presumably
    # `_min`/`_max` are not restored by `reset` and `rescale`/`inverse_rescale`
    # are per-instance callables that never compare equal - confirm.
    ignored_keys = ('_min', '_max', 'rescale', 'inverse_rescale')
    for d in (d1, d2):
        for key in ignored_keys:
            del d[key]
    assert d1 == d2
def test_get_state(proposal, populated):
    """Check the ``__getstate__`` method used when pickling the proposal.

    Runs for both a populated and an unpopulated proposal and verifies
    which attributes are reset, kept or dropped from the returned state.
    """
    mock_flow = MagicMock()
    mock_flow.weights_file = 'file'

    proposal.populated = populated
    proposal.indices = [1, 2]
    proposal._reparameterisation = MagicMock()
    proposal.model = MagicMock()
    proposal._flow_config = {}
    proposal.pool = MagicMock()
    proposal.initialised = True
    proposal.flow = mock_flow

    state = FlowProposal.__getstate__(proposal)

    assert state['resume_populated'] is populated
    assert state['pool'] is None
    assert state['initialised'] is False
    assert state['weights_file'] == 'file'
    # These attributes must be stripped from the pickled state.
    for dropped in ('_reparameterisation', 'model', 'flow', '_flow_config'):
        assert dropped not in state
def test_test_draw(tmpdir, model, rescale):
    """Verify that ``test_draw`` leaves the proposal state unchanged.

    ``test_draw`` checks that samples can be drawn from the flow and then
    resets the flow, so the state returned by ``__getstate__`` must be the
    same before and after the call.
    """
    output = tmpdir.mkdir('test')
    flow_proposal = FlowProposal(
        model,
        output=output,
        poolsize=100,
        rescale_parameters=rescale,
    )
    flow_proposal.initialise()
    # These are worked out the first time they are accessed; touch them
    # now so that the lazy evaluation is not mistaken for a state change.
    _ = (flow_proposal.x_dtype, flow_proposal.x_prime_dtype)

    reference_state = flow_proposal.__getstate__()

    checker = TestCase()
    checker.maxDiff = None  # report the full dict diff on failure
    # Sanity check: taking the state twice must give the same result.
    checker.assertDictEqual(flow_proposal.__getstate__(), reference_state)
    flow_proposal.test_draw()
    checker.assertDictEqual(flow_proposal.__getstate__(), reference_state)