예제 #1
0
def test_fit_shapes():
    """Check shapes and normalization of fitted and inferred segmentations.

    Fits an EventSegment on random data, then checks two segmentations:
    the training one (``segments_[0]``) and the one returned by
    ``find_events`` on new data. Each must have one row per timepoint and
    one column per event, and each row must sum to one. Finally checks that
    inconsistent sizes raise ``ValueError``.
    """
    K = 5   # number of events
    V = 3   # feature dimension of the data
    T = 10  # timepoints used for fitting
    es = EventSegment(K, n_iter=2)
    sample_data = np.random.rand(V, T)
    # Data is passed transposed, i.e. with shape (T, V).
    es.fit(sample_data.T)

    assert es.segments_[0].shape == (T, K), "Segmentation from fit " \
                                            "has incorrect shape"
    assert np.isclose(np.sum(es.segments_[0], axis=1), np.ones(T)).all(), \
        "Segmentation from learn_events not correctly normalized"

    # Apply the fitted model to new data with a different number of timepoints.
    T2 = 15
    sample_data2 = np.random.rand(V, T2)
    test_segments, _ = es.find_events(sample_data2.T)

    assert test_segments.shape == (T2, K), "Segmentation from find_events " \
                                           "has incorrect shape"
    assert np.isclose(np.sum(test_segments, axis=1), np.ones(T2)).all(), \
        "Segmentation from find_events not correctly normalized"

    # pytest.raises() no longer accepts ``message=`` (deprecated in pytest
    # 4.1, removed in 5.0). Calling pytest.fail() after the statement that
    # should raise keeps the intended failure message. It only runs if no
    # ValueError was raised, and its exception is not caught by raises().
    es_invalid = EventSegment(K)
    with pytest.raises(ValueError):
        es_invalid.model_prior(K - 1)
        pytest.fail("T < K should cause error")
    with pytest.raises(ValueError):
        es_invalid.set_event_patterns(np.zeros((V, K - 1)))
        pytest.fail("#Events < K should cause error")
예제 #2
0
def test_fit_shapes():
    """Check shapes and normalization of fitted and inferred segmentations.

    Fits an EventSegment on random data, then checks two segmentations:
    the training one (``segments_[0]``) and the one returned by
    ``find_events`` on new data. Each must be (timepoints, events) with
    rows summing to one. Also checks that inconsistent sizes raise
    ``ValueError``.
    """
    K = 5   # number of events
    V = 3   # feature dimension of the data
    T = 10  # timepoints used for fitting
    es = EventSegment(K, n_iter=2)
    sample_data = np.random.rand(V, T)
    # Data is passed transposed, i.e. with shape (T, V).
    es.fit(sample_data.T)

    assert es.segments_[0].shape == (T, K), "Segmentation from fit " \
                                            "has incorrect shape"
    assert np.isclose(np.sum(es.segments_[0], axis=1), np.ones(T)).all(), \
        "Segmentation from learn_events not correctly normalized"

    # Apply the fitted model to new data with a different number of timepoints.
    T2 = 15
    sample_data2 = np.random.rand(V, T2)
    test_segments, _ = es.find_events(sample_data2.T)

    assert test_segments.shape == (T2, K), "Segmentation from find_events " \
                                           "has incorrect shape"
    assert np.isclose(np.sum(test_segments, axis=1), np.ones(T2)).all(), \
        "Segmentation from find_events not correctly normalized"

    # ``message=`` was removed from pytest.raises() in pytest 5.0. Use
    # pytest.fail() inside the block instead: it runs only if the expected
    # ValueError was not raised, and raises() does not swallow it.
    es_invalid = EventSegment(K)
    with pytest.raises(ValueError):
        es_invalid.model_prior(K - 1)
        pytest.fail("T < K should cause error")
    with pytest.raises(ValueError):
        es_invalid.set_event_patterns(np.zeros((V, K - 1)))
        pytest.fail("#Events < K should cause error")
예제 #3
0
def test_chains():
    """find_events() should respect event chains when assigning timepoints.

    Events 0-1 form chain 'A' and events 2-4 form chain 'B'. Every
    timepoint matches chain B's pattern, so the three timepoints should be
    confidently assigned to events 2, 3 and 4.
    """
    chain_labels = np.array(['A', 'A', 'B', 'B', 'B'])
    es = EventSegment(5, event_chains=chain_labels)

    # Rows are features, columns are events: chain A loads on feature 0,
    # chain B loads on feature 1.
    patterns = np.array([[1, 1, 0, 0, 0], [0, 0, 1, 1, 1]])
    es.set_event_patterns(patterns)

    # Three timepoints, each equal to [0, 1] after transposing.
    observations = np.array([[0, 0, 0], [1, 1, 1]]).T
    posterior = es.find_events(observations, 0.1)[0]

    # Column index of every (timepoint, event) cell with near-certain mass.
    confident_events = np.nonzero(posterior > 0.99)[1]
    assert np.array_equal(confident_events, [2, 3, 4]),\
        "Failed to fit with multiple chains"
예제 #4
0
def test_chains():
    """Segment data that matches only the second of two event chains.

    With chains A (events 0-1) and B (events 2-4), data matching B's
    pattern at every timepoint should map timepoints to events 2, 3, 4.
    """
    es = EventSegment(
        5,
        event_chains=np.array(['A', 'A', 'B', 'B', 'B']),
    )

    chain_a_row = [1, 1, 0, 0, 0]
    chain_b_row = [0, 0, 1, 1, 1]
    es.set_event_patterns(np.array([chain_a_row, chain_b_row]))

    raw = np.array([[0, 0, 0], [1, 1, 1]])
    segmentation = es.find_events(raw.T, 0.1)[0]

    _, assigned = np.nonzero(segmentation > 0.99)
    assert np.array_equal(assigned, [2, 3, 4]),\
        "Failed to fit with multiple chains"
예제 #5
0
def test_sym():
    """Segmenting time-invariant data should be symmetric in time.

    All four event patterns are the same 0..9 feature ramp, and every one of
    the 20 timepoints equals that ramp. Reversing both time and event order
    should therefore map the posterior onto itself.
    """
    n_events = 4
    es = EventSegment(n_events)

    feature_ramp = np.arange(10)
    # (10 features, 4 events), every column identical.
    evpat = np.tile(feature_ramp[:, np.newaxis], (1, n_events))
    es.set_event_patterns(evpat)

    # (20 timepoints, 10 features), every row identical.
    D = np.tile(feature_ramp, (20, 1))
    ev = es.find_events(D, var=1)[0]

    # Check that events 1-4 and 2-3 are symmetric
    mirrored_tail = np.flip(ev[:, 2:], axis=(0, 1))
    assert np.allclose(ev[:, :2], mirrored_tail),\
        "Fit with constant data is not symmetric"
예제 #6
0
def test_sym():
    """Posterior on constant data is invariant to reversing time and events.

    Every event has the same pattern and every timepoint has the same
    values, so the first two event columns must equal the last two flipped
    in both time and event order.
    """
    es = EventSegment(4)

    ramp_column = np.arange(10).reshape(-1, 1)
    es.set_event_patterns(np.hstack([ramp_column] * 4))

    ramp_row = np.arange(10).reshape(1, -1)
    constant_data = np.vstack([ramp_row] * 20)
    posterior = es.find_events(constant_data, var=1)[0]

    # Events 1<->4 and 2<->3 (1-indexed) should mirror each other in time.
    head = posterior[:, :2]
    tail_reversed = np.flipud(np.fliplr(posterior[:, 2:]))
    assert np.allclose(head, tail_reversed),\
        "Fit with constant data is not symmetric"
예제 #7
0
def test_chains():
    """fit() must refuse event chains; find_events() must support them.

    Events 0-1 form chain 'A' and events 2-4 form chain 'B'. fit() is
    expected to raise RuntimeError. After setting patterns manually, data
    matching chain B should be assigned to events 2, 3 and 4.
    """
    chain_labels = np.array(['A', 'A', 'B', 'B', 'B'])
    es = EventSegment(5, event_chains=chain_labels)
    sample_data = np.array([[0, 0, 0], [1, 1, 1]])

    # pytest.fail() only runs if fit() unexpectedly succeeds.
    with pytest.raises(RuntimeError):
        es.fit(sample_data.T)[0]
        pytest.fail("Can't use fit() with event chains")

    patterns = np.array([[1, 1, 0, 0, 0], [0, 0, 1, 1, 1]])
    es.set_event_patterns(patterns)
    posterior = es.find_events(sample_data.T, 0.1)[0]

    confident_events = np.nonzero(posterior > 0.99)[1]
    assert np.array_equal(confident_events, [2, 3, 4]),\
        "Failed to fit with multiple chains"
예제 #8
0
def test_event_transfer():
    """Check find_events() preconditions and transfer of preset patterns.

    An unfitted EventSegment must raise NotFittedError from find_events()
    in two cases: when called without a variance, and when called with a
    variance but before event patterns are set. Once patterns are set, the
    two events should be recovered on the sample data.
    """
    es = EventSegment(2)
    sample_data = np.asarray([[1, 1, 1, 0, 0, 0, 0], [0, 0, 0, 1, 1, 1, 1]])

    # ``message=`` was removed from pytest.raises() in pytest 5.0. Use
    # pytest.fail() inside the block instead: it only runs when the
    # expected exception was not raised.
    with pytest.raises(NotFittedError):
        es.find_events(sample_data.T)
        pytest.fail("Should need to set variance")

    with pytest.raises(NotFittedError):
        es.find_events(sample_data.T, np.asarray([1, 1]))
        pytest.fail("Should need to set patterns")

    # Rows are features, columns are events.
    es.set_event_patterns(np.asarray([[1, 0], [0, 1]]))
    seg = es.find_events(sample_data.T, np.asarray([1, 1]))[0]

    events = np.argmax(seg, axis=1)
    assert np.array_equal(events, [0, 0, 0, 1, 1, 1, 1]),\
        "Failed to correctly transfer two events to new data"
예제 #9
0
def test_event_transfer():
    """Check find_events() preconditions and transfer of preset patterns.

    find_events() must raise NotFittedError without a variance, and again
    with a variance but no event patterns. After set_event_patterns(), each
    timepoint's most probable event should match the sample data's layout.
    """
    es = EventSegment(2)
    sample_data = np.asarray([[1, 1, 1, 0, 0, 0, 0], [0, 0, 0, 1, 1, 1, 1]])

    # pytest.raises(message=...) no longer exists (removed in pytest 5.0).
    # pytest.fail() after the call reports the intended message if the
    # NotFittedError is missing.
    with pytest.raises(NotFittedError):
        es.find_events(sample_data.T)
        pytest.fail("Should need to set variance")

    with pytest.raises(NotFittedError):
        es.find_events(sample_data.T, np.asarray([1, 1]))
        pytest.fail("Should need to set patterns")

    es.set_event_patterns(np.asarray([[1, 0], [0, 1]]))
    seg = es.find_events(sample_data.T, np.asarray([1, 1]))[0]

    events = np.argmax(seg, axis=1)
    assert np.array_equal(events, [0, 0, 0, 1, 1, 1, 1]),\
        "Failed to correctly transfer two events to new data"
예제 #10
0
def test_sym_ll():
    """Log-likelihood should be unchanged by reversing time and event order.

    Builds noisy data from three event patterns, then compares two fits:
    the forward data with the forward patterns, and the time-reversed data
    with the event-reversed patterns. The two log-likelihoods are accumulated
    in different orders, so they are compared with a floating-point
    tolerance rather than exact equality.
    """
    ev = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2])
    random_state = np.random.RandomState(0)
    ev_pat = random_state.rand(3, 10)  # (events, features)
    D_forward = np.zeros((len(ev), 10))
    # Draw noise one timepoint at a time (same RNG order as a per-row loop).
    for t, event in enumerate(ev):
        D_forward[t, :] = ev_pat[event, :] + 0.1 * random_state.rand(10)
    D_backward = np.flip(D_forward, axis=0)

    hmm_forward = EventSegment(3)
    hmm_forward.set_event_patterns(ev_pat.T)
    _, ll_forward = hmm_forward.find_events(D_forward, var=1)

    # Reverse event order to match the time-reversed data.
    hmm_backward = EventSegment(3)
    hmm_backward.set_event_patterns(np.flip(ev_pat.T, axis=1))
    _, ll_backward = hmm_backward.find_events(D_backward, var=1)

    assert np.allclose(ll_forward, ll_backward),\
        "Log-likelihood not symmetric forward/backward"
#### 3.2 Define events from movie, find in recall

Alternatively, rather than finding events simultaneously in the movie and recall, we can first identify events using only the movie data and then search for this fixed set of events in the recall. This lets us use other methods, such as GSBS, to identify the events in the movie, and then use the HMM to transfer those events to the recall.

Below we do this for both the 40-event and the 109-event GSBS segmentations. Note that when the HMM is applied this way, we must specify a recall event variance. This value controls how confident the model is about which event is being recalled, given a spatial pattern. This parameter could be cross-validated against some other measure of performance, but here we simply set the recall variance of each event equal to its movie variance.

# Use 40 GSBS events to set the HMM event patterns
# NOTE(review): GSBS_states, nTRs, movie_group and recall are defined in
# earlier cells; GSBS_states is presumably a fitted GSBS model.
n_events = 40
# Transposed so that columns are events, which set_event_patterns() expects
# (presumably get_state_patterns returns events x features; TODO confirm).
event_pat = GSBS_states.get_state_patterns(n_events).T
state_labels = GSBS_states.get_states(n_events)
# One-hot (timepoint x state) matrix of the movie segmentation; labels are
# shifted by one to index columns (assumes GSBS labels start at 1).
# NOTE(review): width comes from np.max(state_labels), not n_events; these
# should agree for a 40-state solution - confirm.
movie_events40 = np.zeros((nTRs, np.max(state_labels)))
movie_events40[np.arange(nTRs), state_labels-1] = 1

# Fix the HMM's event patterns to the GSBS states rather than fitting them.
HMM = EventSegment(n_events)
HMM.set_event_patterns(event_pat)
# Per-event variance measured on the movie, reused as the recall variance.
ev_var = HMM.calc_weighted_event_var(movie_group, movie_events40, HMM.event_pat_)
# [0] keeps the segmentation and discards the log-likelihood.
recall_events40 = HMM.find_events(recall, var = ev_var)[0]


# Use 109 GSBS events to set the HMM event patterns
# Uses the GSBS model's own number of states (109 per the heading above)
# instead of a fixed count.
n_events = GSBS_states.nstates
# Transposed so that columns are events, matching set_event_patterns();
# presumably state_patterns is events x features - TODO confirm.
event_pat = GSBS_states.state_patterns.T
state_labels = GSBS_states.get_states()
# One-hot (timepoint x state) movie segmentation; labels are shifted by one
# to index columns (assumes GSBS labels start at 1).
movie_events109 = np.zeros((nTRs, np.max(state_labels)))
movie_events109[np.arange(nTRs), state_labels-1] = 1

# Same transfer procedure as the 40-event case, with the full GSBS solution.
HMM = EventSegment(n_events)
HMM.set_event_patterns(event_pat)
# Per-event movie variance, reused as the recall variance.
ev_var = HMM.calc_weighted_event_var(movie_group, movie_events109, HMM.event_pat_)
recall_events109 = HMM.find_events(recall, var = ev_var)[0]