def test_reset_integration(tmpdir, model):
    """Test that `reset` interacts correctly with other methods.

    Two identically-configured proposals are initialised. One of them is
    then populated and reset. After the reset, its `__getstate__` output
    must equal that of the untouched proposal, ignoring the ``_min``,
    ``_max``, ``rescale`` and ``inverse_rescale`` entries.
    """
    output = str(tmpdir.mkdir('reset_integration'))
    # Share one set of settings so both proposals are guaranteed to be
    # configured identically.
    proposal_kwargs = dict(
        output=output,
        plot=False,
        poolsize=10,
        latent_prior='truncated_gaussian',
        constant_volume_mode=False,
    )
    proposal = FlowProposal(model, **proposal_kwargs)
    modified_proposal = FlowProposal(model, **proposal_kwargs)

    proposal.initialise()
    modified_proposal.initialise()

    # Populate one proposal, then reset it: it should end up in the same
    # state as the proposal that was never populated.
    modified_proposal.populate(model.new_point(), r=1.0)
    modified_proposal.reset()

    d1 = proposal.__getstate__()
    d2 = modified_proposal.__getstate__()
    # NOTE(review): these entries are excluded from the comparison,
    # presumably because they hold values/callables that do not compare
    # equal between separate instances -- confirm.
    for d in (d1, d2):
        for key in ('_min', '_max', 'rescale', 'inverse_rescale'):
            del d[key]

    assert d1 == d2


def test_get_state(proposal, populated):
    """Check the state produced by `FlowProposal.__getstate__`.

    Run for both a populated and an unpopulated proposal. The returned
    state must:

    - record whether the proposal was populated (``resume_populated``);
    - clear the pool;
    - mark the proposal as not initialised;
    - keep the flow's weights file;
    - omit the reparameterisation, model, flow and flow config.
    """
    # Flow mock exposing the weights file that should be carried over.
    flow_mock = MagicMock()
    flow_mock.weights_file = 'file'

    # Attributes read by __getstate__.
    proposal.populated = populated
    proposal.indices = [1, 2]
    proposal._reparameterisation = MagicMock()
    proposal.model = MagicMock()
    proposal._flow_config = {}
    proposal.pool = MagicMock()
    proposal.initialised = True
    proposal.flow = flow_mock

    # Call the real method on the (mocked) proposal instance.
    state = FlowProposal.__getstate__(proposal)

    assert state['resume_populated'] is populated
    assert state['pool'] is None
    assert state['initialised'] is False
    assert state['weights_file'] == 'file'
    # None of these members may appear in the saved state.
    for dropped_key in ('_reparameterisation', 'model', 'flow',
                        '_flow_config'):
        assert dropped_key not in state, dropped_key


def test_test_draw(tmpdir, model, rescale):
    """Verify that the `test_draw` method leaves the proposal unchanged.

    `test_draw` checks that samples can be drawn from the flow and then
    resets the flow. This test makes sure that, afterwards, the proposal's
    `__getstate__` output is identical to what it was before the call.
    """
    # Use a plain string path, consistent with the other tests.
    output = str(tmpdir.mkdir('test'))
    fp = FlowProposal(model,
                      output=output,
                      poolsize=100,
                      rescale_parameters=rescale)
    fp.initialise()
    # These dtypes are worked out the first time they are accessed, so
    # touch them now so the reference state already includes them.
    _ = fp.x_dtype
    _ = fp.x_prime_dtype
    orig_state = fp.__getstate__()

    t = TestCase()
    # Show the full diff if the state dictionaries differ.
    t.maxDiff = None
    # Sanity check: calling __getstate__ again gives the same state.
    t.assertDictEqual(fp.__getstate__(), orig_state)
    fp.test_draw()
    t.assertDictEqual(fp.__getstate__(), orig_state)