# Example 1
def test_reconstruct():
    """Test Tracker objects reconstruction.

    Each section runs a tracker (or a dealer) inside an
    ``ExperimentController``, then rebuilds it from the written data file
    with ``reconstruct_tracker`` / ``reconstruct_dealer`` and checks the
    result. A file holding an unfinished tracker, or no tracker/dealer at
    all, must make reconstruction raise ``ValueError``.
    """
    # A local, seeded generator keeps the simulated responses (and hence the
    # written data files) reproducible, instead of depending on whatever
    # state the global NumPy RNG happens to be in.
    rng = np.random.RandomState(0)

    # test with one TrackerUD
    with ExperimentController(*std_args, **std_kwargs) as ec:
        tr = TrackerUD(ec, 1, 1, 3, 1, 5, np.inf, 3)
        while not tr.stopped:
            tr.respond(rng.rand() < tr.x_current)

    tracker = reconstruct_tracker(ec.data_fname)[0]
    assert_true(tracker.stopped)
    tracker.x_current  # smoke-check: property readable after reconstruction

    # test with one TrackerBinom
    with ExperimentController(*std_args, **std_kwargs) as ec:
        tr = TrackerBinom(ec, .05, .5, 10)
        while not tr.stopped:
            tr.respond(True)

    tracker = reconstruct_tracker(ec.data_fname)[0]
    assert_true(tracker.stopped)
    tracker.x_current  # smoke-check: property readable after reconstruction

    # tracker not stopped: reconstruction must refuse it
    with ExperimentController(*std_args, **std_kwargs) as ec:
        tr = TrackerUD(ec, 1, 1, 3, 1, 5, np.inf, 3)
        tr.respond(rng.rand() < tr.x_current)
        assert_true(not tr.stopped)
    assert_raises(ValueError, reconstruct_tracker, ec.data_fname)

    # test with dealer
    with ExperimentController(*std_args, **std_kwargs) as ec:
        tr = [TrackerUD(ec, 1, 1, 3, 1, 5, np.inf, 3) for _ in range(3)]
        td = TrackerDealer(ec, tr)

        for _, x_current in td:
            td.respond(rng.rand() < x_current)

    dealer = reconstruct_dealer(ec.data_fname)[0]
    # np.array_equal compares shapes as well as values, so histories of
    # different lengths fail the assertion cleanly instead of erroring (or
    # mis-reporting) inside an elementwise ``all(a == b)``.
    assert_true(np.array_equal(td._x_history, dealer._x_history))
    assert_true(np.array_equal(td._tracker_history,
                               dealer._tracker_history))
    assert_true(np.array_equal(td._response_history,
                               dealer._response_history))
    assert_true(td.shape == dealer.shape)
    assert_true(td.trackers.shape == dealer.trackers.shape)

    # no tracker/dealer in file
    with ExperimentController(*std_args, **std_kwargs) as ec:
        ec.identify_trial(ec_id='one', ttl_id=[0])
        ec.start_stimulus()
        ec.write_data_line('misc', 'trial one')
        ec.stop()
        ec.trial_ok()
        ec.write_data_line('misc', 'end')

    assert_raises(ValueError, reconstruct_tracker, ec.data_fname)
    assert_raises(ValueError, reconstruct_dealer, ec.data_fname)
# Example 2
def test_tracker_ud():
    """Test TrackerUD.

    Covers construction with a callable, an ``ExperimentController`` or
    ``None`` as callback; running trackers until they stop; the public
    properties; plotting; ``threshold``/``check_valid``; the warnings
    emitted while responding with ``x_min``/``x_max`` set; and the
    argument-validation errors.
    """
    import matplotlib.pyplot as plt
    # construction must accept a callable, an ExperimentController or None
    tr = TrackerUD(callback, 3, 1, 1, 1, np.inf, 10, 1)
    with ExperimentController('test', **std_kwargs) as ec:
        tr = TrackerUD(ec, 3, 1, 1, 1, np.inf, 10, 1)
    tr = TrackerUD(None, 3, 1, 1, 1, np.inf, 10, 1)
    rand = np.random.RandomState(0)
    while not tr.stopped:
        tr.respond(rand.rand() < tr.x_current)

    # threshold() must also work before any response has been given
    tr = TrackerUD(None, 3, 1, 1, 1, np.inf, 10, 1)
    tr.threshold()
    rand = np.random.RandomState(0)
    while not tr.stopped:
        tr.respond(rand.rand() < tr.x_current)
    # test responding after stopped
    assert_raises(RuntimeError, tr.respond, 0)

    # all the properties better work
    tr.up
    tr.down
    tr.step_size_up
    tr.step_size_down
    tr.stop_reversals
    tr.stop_trials
    tr.start_value
    tr.x_min
    tr.x_max
    tr.stopped
    tr.x
    tr.responses
    tr.n_trials
    tr.n_reversals
    tr.reversals
    tr.reversal_inds
    fig, ax, lines = tr.plot()
    tr.plot_thresh(ax=ax)
    tr.plot_thresh()
    plt.close(fig)
    # plot() must also accept an existing Axes
    ax = plt.axes()
    fig, ax, lines = tr.plot(ax)
    plt.close(fig)
    tr.threshold()
    tr.check_valid(2)

    # bad callback type
    assert_raises(TypeError, TrackerUD, 'foo', 3, 1, 1, 1, 10, np.inf, 1)

    # test dynamic step size and error conditions
    tr = TrackerUD(None, 3, 1, [1, 0.5], [1, 0.5], 10, np.inf, 1,
                   change_indices=[2])
    tr.respond(True)

    # These responses push the level against the x_min/x_max bounds; exactly
    # one warning is expected (presumably about the bounds being hit).
    # simplefilter('always') ensures the warning is recorded even when the
    # active filters would otherwise ignore or deduplicate it.
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        tr = TrackerUD(None, 1, 1, 0.75, 0.75, np.inf, 8, 1,
                       x_min=0, x_max=2)
        responses = [True, True, True, False, False, False, False, True, False]
        for r in responses:  # run long enough to reach the x_min/x_max bounds
            tr.respond(r)
        assert_equal(len(w), 1)
    assert tr.check_valid(1)  # make sure checking validity is good
    assert not tr.check_valid(3)
    assert_raises(ValueError, tr.threshold, 1)
    tr.threshold(3)

    # run tests with ignore too--should generate warnings, but no error
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        tr = TrackerUD(None, 1, 1, 0.75, 0.25, np.inf, 7, 1,
                       x_min=0, x_max=2, repeat_limit='ignore')
        responses = [False, True, False, False, True, True, False, True]
        for r in responses:  # run long enough to reach the x_max bound
            tr.respond(r)
        assert_equal(len(w), 1)
    tr.threshold(0)

    # bad stop_trials
    assert_raises(ValueError, TrackerUD, None, 3, 1, 1, 1, 10, 'foo', 1)

    # bad stop_reversals
    assert_raises(ValueError, TrackerUD, None, 3, 1, 1, 1, 'foo', 10, 1)

    # change_indices too long
    assert_raises(ValueError, TrackerUD, None, 3, 1, [1, 0.5], [1, 0.5], 10,
                  np.inf, 1, change_indices=[1, 2])
    # step_size_up length mismatch
    assert_raises(ValueError, TrackerUD, None, 3, 1, [1], [1, 0.5], 10,
                  np.inf, 1, change_indices=[2])
    # step_size_down length mismatch
    assert_raises(ValueError, TrackerUD, None, 3, 1, [1, 0.5], [1], 10,
                  np.inf, 1, change_indices=[2])
    # bad change_rule
    assert_raises(ValueError, TrackerUD, None, 3, 1, [1, 0.5], [1, 0.5], 10,
                  np.inf, 1, change_indices=[2], change_rule='foo')
    # no change_indices (i.e. change_indices=None)
    assert_raises(ValueError, TrackerUD, None, 3, 1, [1, 0.5], [1, 0.5], 10,
                  np.inf, 1)

    # start_value scalar type checking
    assert_raises(TypeError, TrackerUD, None, 3, 1, [1, 0.5], [1, 0.5], 10,
                  np.inf, [9, 5], change_indices=[2])
    assert_raises(TypeError, TrackerUD, None, 3, 1, [1, 0.5], [1, 0.5], 10,
                  np.inf, None, change_indices=[2])

    # test with multiple change_indices (construction must succeed)
    tr = TrackerUD(None, 3, 1, [3, 2, 1], [3, 2, 1], 10, np.inf, 1,
                   change_indices=[2, 4], change_rule='reversals')
max_lag = 2
pace_rule = 'reversals'
# Seeded generator handed to the dealer for choosing each trial's type
rng_dealer = np.random.RandomState(4)

##############################################################################
# Initializing and Running Trackers
# ---------------------------------
# Both trackers below share every parameter except their start value, and
# are then handed to a ``TrackerDealer``. Iterating over the dealer yields,
# for each trial, an index telling which trial type comes next (the track
# that started above or below the true threshold) together with the level
# to present on that trial.

# One tracker per entry of ``start_value``; nothing else differs.
tr_ud = [
    TrackerUD(callback, up, down, step_size_up, step_size_down,
              stop_reversals, stop_trials, sv,
              change_indices=change_indices, change_rule=change_rule,
              x_min=x_min, x_max=x_max)
    for sv in start_value
]

# The dealer decides which tracker supplies each trial.
td = TrackerDealer(callback, tr_ud, max_lag, pace_rule, rng_dealer)

# Seeded generator driving the simulated (modeled) subject's responses
rng_human = np.random.RandomState(1)

for _, level in td:
    # The dealer supplies the next trial's level. The simulated response is
    # True with probability equal to the sigmoid value at that level
    # (measured relative to the true threshold), since rand() is uniform on
    # [0, 1).
    td.respond(rng_human.rand()
               < sigmoid(level - true_thresh, lower=chance, slope=slope))