def get_task_parameters(self) -> dict:
    """Return the task parameters stored in the behavior stimulus (pkl) file.

    Returns
    -------
    dict
        Parameters that define the runtime behavior of the task, as
        produced by the module-level ``get_task_parameters`` parser.
    """
    # The bare name below resolves to the module-level parser, not to this
    # method: class scope is not part of a method body's name lookup.
    return get_task_parameters(self._behavior_stimulus_file())
def test_get_task_parameters(data, expected):
    """Check that ``get_task_parameters(data)`` returns exactly ``expected``.

    Values are compared with NaN awareness: NaN never equals itself, so a
    NaN-valued parameter is accepted when the expected value is also NaN.
    """

    def _is_nan(value) -> bool:
        # np.isnan raises TypeError for non-numeric inputs (str, None, ...),
        # and bool() of a multi-element result raises ValueError; treat both
        # as "not NaN" so the caller falls back to a plain equality check.
        try:
            return bool(np.isnan(value))
        except (TypeError, ValueError):
            return False

    actual = get_task_parameters(data)

    # Compare key sets first so a missing or extra key produces a readable
    # set diff instead of a KeyError from the per-key loop below.
    assert set(actual) == set(expected)

    for key, actual_value in actual.items():
        expected_value = expected[key]
        if _is_nan(actual_value):
            assert _is_nan(expected_value), (
                f"{key!r}: expected {expected_value!r}, got NaN"
            )
        else:
            assert expected_value == actual_value, key
def test_get_task_parameters_flash_duration_exception():
    """
    Test that, when neither 'images' nor 'grating' is present in 'stimuli',
    get_task_parameters raises a RuntimeError with the expected message.
    """
    input_data = {
        "items": {
            "behavior": {
                "config": {
                    "DoC": {
                        "blank_duration_range": (0.5, 0.6),
                        "response_window": [0.15, 0.75],
                        "change_time_dist": "geometric",
                        "auto_reward_volume": 0.002,
                    },
                    "reward": {
                        "reward_volume": 0.007,
                    },
                    "behavior": {
                        "task_id": "DoC",
                    },
                },
                "params": {
                    "stage": "TRAINING_3_images_A",
                    "flash_omit_probability": 0.05,
                },
                # Only an unrecognized stimulus key is provided.
                "stimuli": {
                    "junk": {
                        "draw_log": [1] * 10,
                        "flash_interval_sec": [0.32, -1.0],
                    }
                },
            }
        }
    }

    with pytest.raises(RuntimeError) as excinfo:
        get_task_parameters(input_data)

    # str() of the exception works even if it was raised without args,
    # unlike indexing ``excinfo.value.args[0]``.
    expected_msg = "'images' and/or 'grating' not a valid key"
    assert expected_msg in str(excinfo.value)
def test_get_task_parameters_task_id_exception():
    """
    Test that, when task_id has an unexpected value, get_task_parameters
    raises a RuntimeError with the expected message.
    """
    input_data = {
        "items": {
            "behavior": {
                "config": {
                    "DoC": {
                        "blank_duration_range": (0.5, 0.6),
                        "response_window": [0.15, 0.75],
                        "change_time_dist": "geometric",
                        "auto_reward_volume": 0.002,
                    },
                    "reward": {
                        "reward_volume": 0.007,
                    },
                    # Unrecognized task_id is what should trigger the error.
                    "behavior": {
                        "task_id": "junk",
                    },
                },
                "params": {
                    "stage": "TRAINING_3_images_A",
                    "flash_omit_probability": 0.05,
                },
                "stimuli": {
                    "images": {
                        "draw_log": [1] * 10,
                        "flash_interval_sec": [0.32, -1.0],
                    }
                },
            }
        }
    }

    with pytest.raises(RuntimeError) as excinfo:
        get_task_parameters(input_data)

    # str() of the exception works even if it was raised without args,
    # unlike indexing ``excinfo.value.args[0]``.
    assert "does not know how to parse 'task_id'" in str(excinfo.value)