Ejemplo n.º 1
0
    def test_stack_config_idx(self):
        """Check that stacking two observation sets re-indexes configurations.

        Both inputs share the same ``stimulus_set`` but use different
        ``n_select`` values. Each set on its own numbers its configurations
        0, 0, 1, 1, 2. Once stacked, the second set's configuration indices
        must continue after the first set's.
        """
        stimulus_set = np.array((
            (0, 1, 2, 3, -1, -1, -1, -1, -1),
            (9, 12, 7, 1, -1, -1, -1, -1, -1),
            (3, 4, 5, 6, 7, -1, -1, -1, -1),
            (3, 4, 2, 6, 7, -1, -1, -1, -1),
            (3, 4, 5, 6, 13, 14, 15, 16, 17)))
        per_set_config_idx = np.array((0, 0, 1, 1, 2))

        # First set: every trial selects one reference.
        obs_select_one = trials.RankObservations(
            stimulus_set, n_select=np.array((1, 1, 1, 1, 1)))
        np.testing.assert_array_equal(
            obs_select_one.config_idx, per_set_config_idx)

        # Second set: every trial selects two references, so none of its
        # configurations overlap with those of the first set.
        obs_select_two = trials.RankObservations(
            stimulus_set, n_select=np.array((2, 2, 2, 2, 2)))
        np.testing.assert_array_equal(
            obs_select_two.config_idx, per_set_config_idx)

        # After stacking, the second set's indices are offset past the first's.
        stacked = trials.stack((obs_select_one, obs_select_two))
        np.testing.assert_array_equal(
            stacked.config_idx,
            np.array((0, 0, 1, 1, 2, 3, 3, 4, 4, 5)))
Ejemplo n.º 2
0
    def test_invalid_stimulus_set(self):
        """Test handling of invalid `stimulus_set` argument.

        Each malformed ``stimulus_set`` below must make the
        ``trials.RankObservations`` constructor raise. Rows are padded with
        -1, which the cases below treat as the only permitted negative value.
        """
        # NOTE(review): ``pytest.raises(Exception)`` accepts any error,
        # including unrelated ones such as a TypeError from a bad call.
        # Narrow it once the exception type raised by RankObservations is
        # confirmed.

        # Non-integer input: the float literal ``0.`` makes the whole
        # array a float array.
        stimulus_set = np.array((
            (0., 1, 2, -1, -1, -1, -1, -1, -1),
            (9, 12, 7, -1, -1, -1, -1, -1, -1),
            (3, 4, 5, 6, 7, -1, -1, -1, -1),
            (3, 4, 5, 6, 13, 14, 15, 16, 17)))
        with pytest.raises(Exception):
            trials.RankObservations(stimulus_set)

        # Contains integers below -1 (the -2 in the first row).
        stimulus_set = np.array((
            (0, 1, -2, -1, -1, -1, -1, -1, -1),
            (9, 12, 7, -1, -1, -1, -1, -1, -1),
            (3, 4, 5, 6, 7, -1, -1, -1, -1),
            (3, 4, 5, 6, 13, 14, 15, 16, 17)))
        with pytest.raises(Exception):
            trials.RankObservations(stimulus_set)

        # Does not contain enough references for each trial: the third row
        # holds only two stimuli before the -1 padding.
        stimulus_set = np.array((
            (0, 1, 2, -1, -1, -1, -1, -1, -1),
            (9, 12, 7, -1, -1, -1, -1, -1, -1),
            (3, 4, -1, -1, -1, -1, -1, -1, -1),
            (3, 4, 5, 6, 13, 14, 15, 16, 17)))
        with pytest.raises(Exception):
            trials.RankObservations(stimulus_set)
Ejemplo n.º 3
0
def setup_obs_1():
    """Build a small ``RankObservations`` fixture with hand-written values.

    The fixture holds four trials split across two groups (``group_id``
    0 and 1). The constructed observations object is returned together with
    hand-written reference values. These are presumably the attributes that
    tests compare ``obs`` against.

    Returns:
        dict: Keys ``'n_trial'``, ``'stimulus_set'``, ``'n_reference'``,
        ``'n_select'``, ``'is_ranked'``, ``'group_id'``, ``'obs'``,
        ``'configurations'`` and ``'configuration_id'``.
    """
    int32 = np.int32

    # Rows are padded with -1 to a common width of nine columns.
    stimulus_set = np.array(
        (
            (0, 1, 2, -1, -1, -1, -1, -1, -1),
            (9, 12, 7, -1, -1, -1, -1, -1, -1),
            (3, 4, 5, 6, 7, -1, -1, -1, -1),
            (3, 4, 5, 6, 13, 14, 15, 16, 17),
        ),
        dtype=int32,
    )
    n_trial = 4
    n_select = np.array((1, 1, 1, 2), dtype=int32)
    n_reference = np.array((2, 2, 4, 8), dtype=int32)
    is_ranked = np.array((True, True, True, True))
    group_id = np.array((0, 0, 1, 1), dtype=int32)

    # One row per distinct configuration. The index values (0, 2, 3)
    # coincide with the first trial exhibiting each configuration.
    # n_outcome appears to equal the number of ordered selections of
    # n_select out of n_reference (2, 4, 8 * 7 = 56). TODO confirm.
    configuration_columns = {
        'n_reference': np.array([2, 4, 8], dtype=int32),
        'n_select': np.array([1, 1, 2], dtype=int32),
        'is_ranked': [True, True, True],
        'group_id': np.array([0, 1, 1], dtype=int32),
        'session_id': np.array([0, 0, 0], dtype=int32),
        'n_outcome': np.array([2, 4, 56], dtype=int32),
    }
    configurations = pd.DataFrame(configuration_columns, index=[0, 2, 3])
    # Per-trial row label into ``configurations`` (trials 0 and 1 share one).
    configuration_id = np.array((0, 0, 1, 2), dtype=int32)

    obs = trials.RankObservations(
        stimulus_set, n_select=n_select, group_id=group_id)

    fixture = {}
    fixture['n_trial'] = n_trial
    fixture['stimulus_set'] = stimulus_set
    fixture['n_reference'] = n_reference
    fixture['n_select'] = n_select
    fixture['is_ranked'] = is_ranked
    fixture['group_id'] = group_id
    fixture['obs'] = obs
    fixture['configurations'] = configurations
    fixture['configuration_id'] = configuration_id
    return fixture
Ejemplo n.º 4
0
    def test_invalid_group_id(self):
        """Test handling of invalid `group_id` argument.

        Each malformed ``group_id`` below, paired with a valid four-trial
        ``stimulus_set``, must make the ``trials.RankObservations``
        constructor raise.
        """
        stimulus_set = np.array((
            (0, 1, 2, -1, -1, -1, -1, -1, -1),
            (9, 12, 7, -1, -1, -1, -1, -1, -1),
            (3, 4, 5, 6, 7, -1, -1, -1, -1),
            (3, 4, 5, 6, 13, 14, 15, 16, 17)))

        # NOTE(review): ``pytest.raises(Exception)`` accepts any error.
        # Narrow it once the exception type raised by RankObservations is
        # confirmed.

        # Mismatch in number of trials: three ids for four trials.
        group_id = np.array((0, 0, 1))
        with pytest.raises(Exception):
            trials.RankObservations(stimulus_set, group_id=group_id)

        # Below support: a negative group id.
        group_id = np.array((0, -1, 1, 0))
        with pytest.raises(Exception):
            trials.RankObservations(stimulus_set, group_id=group_id)
Ejemplo n.º 5
0
    def test_subset_config_idx(self):
        """Check that subsetting re-bases configuration indices at zero.

        Trials 2-4 of the full set carry configuration indices 1, 1, 2.
        Once they are selected on their own, those indices must be
        renumbered to 0, 0, 1.
        """
        stimulus_set = np.array((
            (0, 1, 2, -1, -1, -1, -1, -1, -1),
            (9, 12, 7, -1, -1, -1, -1, -1, -1),
            (3, 4, 5, 6, 7, -1, -1, -1, -1),
            (3, 4, 2, 6, 7, -1, -1, -1, -1),
            (3, 4, 5, 6, 13, 14, 15, 16, 17)))

        # Full set: the last trial selects two references, the rest one.
        full_obs = trials.RankObservations(
            stimulus_set, n_select=np.array((1, 1, 1, 1, 2)))
        np.testing.assert_array_equal(
            full_obs.config_idx, np.array((0, 0, 1, 1, 2)))

        # Keep only the last three trials; indices must restart at 0.
        subset_obs = full_obs.subset(np.array((2, 3, 4)))
        np.testing.assert_array_equal(
            subset_obs.config_idx, np.array((0, 0, 1)))