def test_token_cooccurrence_vectorizer_orientation():
    """Exercise the window orientation options of TokenCooccurrenceVectorizer.

    Verifies, on ``text_token_data``:

    * a radius-1 "directional" fit yields a (4, 8) matrix in which the
      ("pok", "pre_wer") entry equals 1;
    * the "after" matrix is the transpose of the "before" matrix;
    * the "symmetric" matrix is the sum of the "before" and "after" matrices.
    """

    def fit_with_orientation(orientation):
        # Only the orientation is set; every other parameter stays at the
        # vectorizer's default.
        model = TokenCooccurrenceVectorizer(window_orientation=orientation)
        return model.fit_transform(text_token_data)

    directional_model = TokenCooccurrenceVectorizer(
        window_radius=1, window_orientation="directional"
    )
    directional_counts = directional_model.fit_transform(text_token_data)
    assert directional_counts.shape == (4, 8)

    # The count of "pok" preceded by "wer" must be exactly 1.
    pok_row = directional_model.token_label_dictionary_["pok"]
    pre_wer_col = directional_model.column_label_dictionary_["pre_wer"]
    assert directional_counts[pok_row, pre_wer_col] == 1

    before_counts = fit_with_orientation("before")
    after_counts = fit_with_orientation("after")
    transposed_before = before_counts.transpose()
    assert np.all(after_counts.toarray() == transposed_before.toarray())

    symmetric_counts = fit_with_orientation("symmetric")
    combined_counts = before_counts + after_counts
    assert np.all(symmetric_counts.toarray() == combined_counts.toarray())
# Example #2
0
def test_token_cooccurrence_vectorizer_orientation():
    """Exercise the window orientation options of TokenCooccurrenceVectorizer.

    All vectorizers use ``window_radii=1`` and ``normalize_windows=False``.
    Verifies, on ``text_token_data``:

    * the "directional" fit yields a (4, 8) matrix in which the
      ("pok", "pre_0_wer") entry equals 1;
    * the "after" matrix is the transpose of the "before" matrix;
    * the "directional" matrix is the "before" and "after" matrices stacked
      side by side, in that order.
    """

    def make_vectorizer(orientation):
        # Every vectorizer in this test differs only in its orientation.
        return TokenCooccurrenceVectorizer(
            window_radii=1,
            window_orientations=orientation,
            normalize_windows=False,
        )

    directional_model = make_vectorizer("directional")
    directional_counts = directional_model.fit_transform(text_token_data)
    assert directional_counts.shape == (4, 8)

    # The count of "pok" preceded by "wer" must be exactly 1.
    pok_row = directional_model.token_label_dictionary_["pok"]
    pre_wer_col = directional_model.column_label_dictionary_["pre_0_wer"]
    assert directional_counts[pok_row, pre_wer_col] == 1

    before_counts = make_vectorizer("before").fit_transform(text_token_data)
    after_counts = make_vectorizer("after").fit_transform(text_token_data)

    transposed_before = before_counts.transpose()
    assert np.all(after_counts.toarray() == transposed_before.toarray())

    # Directional columns are the "before" block followed by the "after" block.
    stacked_counts = np.hstack([before_counts.toarray(), after_counts.toarray()])
    assert np.all(directional_counts.toarray() == stacked_counts)