def test_reconstruct():
    """Test Tracker objects reconstruction.

    Each scenario drives trackers through an ``ExperimentController`` so the
    tracker / dealer events are written to ``ec.data_fname``.  The file is
    read back with ``reconstruct_tracker`` / ``reconstruct_dealer`` after the
    controller context has exited, so the data file has been closed.
    """
    # Seeded so the simulated responses (and hence the trials written to
    # disk) are the same on every run.
    rng = np.random.RandomState(0)

    # test with one TrackerUD
    with ExperimentController(*std_args, **std_kwargs) as ec:
        tr = TrackerUD(ec, 1, 1, 3, 1, 5, np.inf, 3)
        while not tr.stopped:
            tr.respond(rng.rand() < tr.x_current)
    tracker = reconstruct_tracker(ec.data_fname)[0]
    assert_true(tracker.stopped)
    tracker.x_current  # property access on a reconstructed tracker must work

    # test with one TrackerBinom
    with ExperimentController(*std_args, **std_kwargs) as ec:
        tr = TrackerBinom(ec, .05, .5, 10)
        while not tr.stopped:
            tr.respond(True)
    tracker = reconstruct_tracker(ec.data_fname)[0]
    assert_true(tracker.stopped)
    tracker.x_current

    # tracker not stopped: reconstruction must refuse it
    with ExperimentController(*std_args, **std_kwargs) as ec:
        tr = TrackerUD(ec, 1, 1, 3, 1, 5, np.inf, 3)
        tr.respond(rng.rand() < tr.x_current)
        assert_true(not tr.stopped)
    assert_raises(ValueError, reconstruct_tracker, ec.data_fname)

    # test with dealer
    with ExperimentController(*std_args, **std_kwargs) as ec:
        trackers = [TrackerUD(ec, 1, 1, 3, 1, 5, np.inf, 3)
                    for _ in range(3)]
        td = TrackerDealer(ec, trackers)
        for _, x_current in td:
            td.respond(rng.rand() < x_current)
    dealer = reconstruct_dealer(ec.data_fname)[0]
    # assert_array_equal also reports shape mismatches clearly, whereas
    # ``all(a == b)`` would iterate a scalar or raise on a length mismatch.
    np.testing.assert_array_equal(td._x_history, dealer._x_history)
    np.testing.assert_array_equal(td._tracker_history,
                                  dealer._tracker_history)
    np.testing.assert_array_equal(td._response_history,
                                  dealer._response_history)
    assert_true(td.shape == dealer.shape)
    assert_true(td.trackers.shape == dealer.trackers.shape)

    # no tracker/dealer in file
    with ExperimentController(*std_args, **std_kwargs) as ec:
        ec.identify_trial(ec_id='one', ttl_id=[0])
        ec.start_stimulus()
        ec.write_data_line('misc', 'trial one')
        ec.stop()
        ec.trial_ok()
        ec.write_data_line('misc', 'end')
    assert_raises(ValueError, reconstruct_tracker, ec.data_fname)
    assert_raises(ValueError, reconstruct_dealer, ec.data_fname)
def test_tracker_ud():
    """Test TrackerUD.

    Covers the following:

    - construction with different callback arguments;
    - running trackers to completion and responding after they stop;
    - property access and plotting;
    - threshold / validity checks;
    - behaviour at the ``x_min`` / ``x_max`` limits;
    - dynamic step sizes;
    - constructor argument validation.
    """
    import matplotlib.pyplot as plt  # local import: only needed for plotting

    # construction with each kind of callback argument
    tr = TrackerUD(callback, 3, 1, 1, 1, np.inf, 10, 1)
    with ExperimentController('test', **std_kwargs) as ec:
        tr = TrackerUD(ec, 3, 1, 1, 1, np.inf, 10, 1)
    tr = TrackerUD(None, 3, 1, 1, 1, np.inf, 10, 1)
    rand = np.random.RandomState(0)
    while not tr.stopped:
        tr.respond(rand.rand() < tr.x_current)

    # threshold() on a fresh tracker must not fail; then run to completion
    tr = TrackerUD(None, 3, 1, 1, 1, np.inf, 10, 1)
    tr.threshold()
    rand = np.random.RandomState(0)
    while not tr.stopped:
        tr.respond(rand.rand() < tr.x_current)

    # test responding after stopped
    assert_raises(RuntimeError, tr.respond, 0)

    # all the properties better work
    tr.up
    tr.down
    tr.step_size_up
    tr.step_size_down
    tr.stop_reversals
    tr.stop_trials
    tr.start_value
    tr.x_min
    tr.x_max
    tr.stopped
    tr.x
    tr.responses
    tr.n_trials
    tr.n_reversals
    tr.reversals
    tr.reversal_inds

    # plotting; close every figure (not just ``fig``) in case plot_thresh()
    # without ``ax`` opened a new one
    fig, ax, lines = tr.plot()
    tr.plot_thresh(ax=ax)
    tr.plot_thresh()
    plt.close('all')
    ax = plt.axes()
    fig, ax, lines = tr.plot(ax)
    plt.close('all')
    tr.threshold()
    tr.check_valid(2)

    # bad callback type
    assert_raises(TypeError, TrackerUD, 'foo', 3, 1, 1, 1, 10, np.inf, 1)

    # test dynamic step size and error conditions
    tr = TrackerUD(None, 3, 1, [1, 0.5], [1, 0.5], 10, np.inf, 1,
                   change_indices=[2])
    tr.respond(True)

    # Start at 1 with 0.75 steps and x_min=0 / x_max=2: these responses
    # drive the level into both limits (presumably a True response steps
    # down); exactly one warning is expected.
    with warnings.catch_warnings(record=True) as w:
        tr = TrackerUD(None, 1, 1, 0.75, 0.75, np.inf, 8, 1,
                       x_min=0, x_max=2)
        responses = [True, True, True, False, False, False, False, True,
                     False]
        for r in responses:
            tr.respond(r)
        assert_equal(len(w), 1)
    assert tr.check_valid(1)  # make sure checking validity is good
    assert not tr.check_valid(3)
    assert_raises(ValueError, tr.threshold, 1)
    tr.threshold(3)

    # run tests with ignore too--should generate warnings, but no error
    with warnings.catch_warnings(record=True) as w:
        tr = TrackerUD(None, 1, 1, 0.75, 0.25, np.inf, 7, 1,
                       x_min=0, x_max=2, repeat_limit='ignore')
        responses = [False, True, False, False, True, True, False, True]
        for r in responses:  # run long enough to reach the x_max limit
            tr.respond(r)
        assert_equal(len(w), 1)
    tr.threshold(0)

    # bad stop_trials
    assert_raises(ValueError, TrackerUD, None, 3, 1, 1, 1, 10, 'foo', 1)

    # bad stop_reversals
    assert_raises(ValueError, TrackerUD, None, 3, 1, 1, 1, 'foo', 10, 1)

    # change_indices too long
    assert_raises(ValueError, TrackerUD, None, 3, 1, [1, 0.5], [1, 0.5], 10,
                  np.inf, 1, change_indices=[1, 2])
    # step_size_up length mismatch
    assert_raises(ValueError, TrackerUD, None, 3, 1, [1], [1, 0.5], 10,
                  np.inf, 1, change_indices=[2])
    # step_size_down length mismatch
    assert_raises(ValueError, TrackerUD, None, 3, 1, [1, 0.5], [1], 10,
                  np.inf, 1, change_indices=[2])
    # bad change_rule
    assert_raises(ValueError, TrackerUD, None, 3, 1, [1, 0.5], [1, 0.5], 10,
                  np.inf, 1, change_indices=[2], change_rule='foo')
    # no change_indices (i.e. change_indices=None)
    assert_raises(ValueError, TrackerUD, None, 3, 1, [1, 0.5], [1, 0.5], 10,
                  np.inf, 1)

    # start_value scalar type checking
    assert_raises(TypeError, TrackerUD, None, 3, 1, [1, 0.5], [1, 0.5], 10,
                  np.inf, [9, 5], change_indices=[2])
    assert_raises(TypeError, TrackerUD, None, 3, 1, [1, 0.5], [1, 0.5], 10,
                  np.inf, None, change_indices=[2])

    # test with multiple change_indices
    tr = TrackerUD(None, 3, 1, [3, 2, 1], [3, 2, 1], 10, np.inf, 1,
                   change_indices=[2, 4], change_rule='reversals')
# Dealer settings; these are passed to TrackerDealer below.
max_lag = 2
pace_rule = 'reversals'
rng_dealer = np.random.RandomState(4)  # random seed for selecting trial type

##############################################################################
# Initializing and Running Trackers
# ---------------------------------
# The two trackers in this example share every parameter except the start
# value, and both are then passed to the dealer. Once the dealer has been
# created, iterating over it yields, for each trial, the index of the tracker
# to use (i.e., whether this trial belongs to the track whose start value is
# above or below the true threshold) and the stimulus level for that trial.

# initialize two tracker objects--one for each start value
tr_ud = [TrackerUD(callback, up, down, step_size_up, step_size_down,
                   stop_reversals, stop_trials, sv, change_indices,
                   change_rule, x_min, x_max) for sv in start_value]

# initialize TrackerDealer object
td = TrackerDealer(callback, tr_ud, max_lag, pace_rule, rng_dealer)

# Initialize human state
rng_human = np.random.RandomState(1)  # random seed for modeled subject

for _, level in td:
    # Each iteration yields (tracker index, level) from the dealer; the index
    # is unused here. The modeled subject responds True with probability
    # given by ``sigmoid`` evaluated at ``level - true_thresh`` (presumably
    # the psychometric function defined earlier in this example — confirm).
    td.respond(rng_human.rand() < sigmoid(level - true_thresh,
                                          lower=chance, slope=slope))