示例#1
0
def select_times(recording, subset, random_only=True, dual_only=True):
    '''
    Restrict a recording to a subset of epochs, optionally filtered further.

    Parameters
    ----------
    recording : nems.recording.Recording
        The recording object. Epoch bounds are read from
        ``recording['stim'].epochs``.
    subset : Nx2 array
        Epochs representing the selected subset (e.g., from an est/val split).
    random_only : bool
        If True, remove the 'repeating' epochs from the subset so that only
        the random (non-repeating) portion remains.
    dual_only : bool
        If True, keep only the part of the subset that overlaps the 'dual'
        (dual stream) epochs.

    Returns
    -------
    The result of ``recording.select_times`` applied to the filtered subset.
    '''
    epochs = recording['stim'].epochs
    names = epochs['name']

    dual_epochs = epochs.loc[names == 'dual', ['start', 'end']].values
    repeating_epochs = epochs.loc[names == 'repeating', ['start', 'end']].values

    if random_only:
        # Subtracting the repeating epochs leaves only the random portion.
        subset = epoch_difference(subset, repeating_epochs)

    if dual_only:
        subset = epoch_intersection(subset, dual_epochs)

    return recording.select_times(subset)
示例#2
0
def test_difference(epoch_a, epoch_b):
    '''
    Check ``epoch_difference`` against a hand-computed result for the
    ``epoch_a`` / ``epoch_b`` fixtures (presumably pytest fixtures defined
    elsewhere in the test module).
    '''
    expected = np.array([
        [0, 50],
        [77, 77],
        [85, 90],
        [95, 100],
    ])
    result = epoch_difference(epoch_a, epoch_b)
    # Unlike ``np.all(result == expected)``, this also fails on a shape
    # mismatch and reports exactly which entries differ.
    np.testing.assert_array_equal(result, expected)
示例#3
0
def test_difference_float(epoch_a, epoch_b):
    '''
    Check ``epoch_difference`` on non-integer epoch bounds, using the
    ``epoch_a`` / ``epoch_b`` fixtures scaled down by 10.
    '''
    expected = np.array([
        [0, 50],
        [77, 77],
        [85, 90],
        [95, 100],
        [140, 150],
    ]) / 10
    result = epoch_difference(epoch_a / 10, epoch_b / 10)
    # Checks shape as well as values, and tolerates float round-off in the
    # computed bounds instead of requiring bit-exact equality.
    np.testing.assert_allclose(result, expected)
示例#4
0
def get_est_val_times(recording, balance_phase=False):
    '''
    Split a recording into estimation and validation epoch bounds.

    Parameters
    ----------
    recording : nems.recording.Recording
        Recording whose ``epochs`` table is used for the split.
    balance_phase : bool
        If True, the 'repeating' epochs are removed from both the estimation
        and validation sets, and the targets chosen by
        ``select_balanced_targets`` are added back to each.

    Returns
    -------
    est_times, val_times
        Epoch bounds from ``get_est_val_times_by_sequence``, modified as
        described above when ``balance_phase`` is True.
    '''
    # Fixed seed (0): repeated calls draw the same random numbers.
    rng = np.random.RandomState(0)

    epochs = recording.epochs
    est_times, val_times = get_est_val_times_by_sequence(recording, rng)

    if not balance_phase:
        return est_times, val_times

    target_times = select_balanced_targets(epochs, rng)

    # Exact name match; a regex such as '^repeating$' is unnecessary here.
    m_repeating = epochs['name'] == 'repeating'
    repeating_times = epochs.loc[m_repeating, ['start', 'end']].values

    est_times = _swap_repeating_for_targets(est_times, repeating_times,
                                            target_times)
    val_times = _swap_repeating_for_targets(val_times, repeating_times,
                                            target_times)
    return est_times, val_times


def _swap_repeating_for_targets(times, repeating_times, target_times):
    '''Remove ``repeating_times`` from ``times``, then add ``target_times``.'''
    return epoch_union(epoch_difference(times, repeating_times), target_times)
示例#5
0
def test_empty_difference():
    '''Subtracting a set of epochs from itself must emit a RuntimeWarning.'''
    epochs = np.array([[0, 50], [50, 100]])

    with pytest.warns(RuntimeWarning):
        epoch_difference(epochs, epochs)
示例#6
0
def test_empty_difference():
    '''
    Subtracting a set of epochs from itself must raise RuntimeWarning.

    This variant expects ``epoch_difference`` to *raise* RuntimeWarning as an
    exception (hence ``pytest.raises`` rather than ``pytest.warns``).
    '''
    a = b = np.array([[0, 50], [50, 100]])

    # ``pytest.raises`` no longer accepts ``message=`` (removed in pytest 5.0);
    # its default "DID NOT RAISE" failure message covers the same case.
    with pytest.raises(RuntimeWarning):
        epoch_difference(a, b)