예제 #1
0
def test_mLSTM1900_batch():
    """
    Smoke-test mLSTM1900_batch on a single fake sequence ("TEST").

    The embedded sequence is run through mLSTM1900_batch with the params
    returned by load_params_1900(). The hidden-state output must have one
    1900-wide row per row of the embedded input.
    """
    embedded = get_embedding("TEST", load_embedding_1900())

    weights = load_params_1900()

    # Only the per-position hidden states (third element) are checked here.
    _, _, hidden = mLSTM1900_batch(weights, embedded)
    expected_shape = (embedded.shape[0], 1900)
    assert hidden.shape == expected_shape
예제 #2
0
def test_mLSTM(data):
    """
    Property test: mLSTM produces correctly shaped outputs for random sequences.

    ``data`` is used via ``data.draw(...)`` to pull a sequence length in
    [1, 10] and a random string of exactly that length. It is presumably
    a Hypothesis ``st.data()`` object injected by a ``@given`` decorator
    outside this view; confirm at the decorator site. ``rng`` is a
    module-level name not shown here (presumably a PRNG key).
    """
    seq_len = data.draw(st.integers(min_value=1, max_value=10))
    seq_strategy = st.text(
        alphabet="MRHKDESTNQCUGPAVIFYWLOXZBJ",
        min_size=seq_len,
        max_size=seq_len,
    )
    seq = data.draw(seq_strategy)

    x = get_embedding(seq, load_embedding())

    output_dim = 256
    init_fun, apply_fun = mLSTM(output_dim=output_dim)
    output_shape, params = init_fun(rng, (-1, 10))
    _, _, outputs = apply_fun(params=params, inputs=x)

    assert output_shape == (-1, output_dim)
    # The output has one more row than the drawn sequence length.
    # Presumably an extra position is added during embedding or by the
    # layer; confirm against get_embedding / mLSTM.
    assert outputs.shape == (seq_len + 1, output_dim)
예제 #3
0
def test_mLSTM1900(data):
    """
    Property test: mLSTM1900 init/apply produce correctly shaped results.

    A random sequence of length 1-10 is drawn via ``data.draw(...)``,
    embedded with the 1900-dim embedding, and passed through a freshly
    initialised mLSTM1900 layer. The test checks the declared output
    shape, the structure of the initialised params, and the shape of the
    per-position outputs.

    ``data`` is presumably a Hypothesis ``st.data()`` object supplied by a
    ``@given`` decorator outside this view. ``rng`` is a module-level name
    not shown here.
    """
    length = data.draw(st.integers(min_value=1, max_value=10))
    sequence = data.draw(
        st.text(
            alphabet="MRHKDESTNQCUGPAVIFYWLOXZBJ",
            min_size=length,
            max_size=length,
        ),
    )
    embedding = load_embedding_1900()
    x = get_embedding(sequence, embedding)

    # Params come solely from init_fun. Loading pretrained weights here
    # would be wasted work: they are never used, and reloading them on
    # every Hypothesis example only slows the test down.
    init_fun, apply_fun = mLSTM1900(output_dim=1900)
    output_shape, params = init_fun(rng, (length, 10))
    _, _, outputs = apply_fun(params=params, inputs=x)

    assert output_shape == (length, 1900)
    validate_mLSTM1900_params(params)
    # One more output row than sequence characters (see assertion).
    assert outputs.shape == (length + 1, 1900)
예제 #4
0
def test_mLSTM1900_AvgHidden(data):
    """
    Property test: mLSTM1900 followed by mLSTM1900_AvgHidden yields a 1900-vector.

    The two layers are composed with ``stax.serial`` and initialised from
    ``rng``. The test checks:

    - the composed output shape is ``(1900,)``;
    - the first layer's params pass ``validate_mLSTM1900_params``;
    - the averaging layer has no params (``()``);
    - applying the stack to an embedded random sequence gives a
      ``(1900,)`` array.

    ``data`` is presumably a Hypothesis ``st.data()`` object supplied by a
    ``@given`` decorator outside this view.
    """
    length = data.draw(st.integers(min_value=1, max_value=10))
    sequence = data.draw(
        st.text(
            alphabet="MRHKDESTNQCUGPAVIFYWLOXZBJ",
            min_size=length,
            max_size=length,
        ),
    )
    embedding = load_embedding_1900()
    x = get_embedding(sequence, embedding)

    # Params come solely from init_fun below. Pretrained weights are not
    # needed for these shape/structure checks, so none are loaded.
    init_fun, apply_fun = stax.serial(
        mLSTM1900(output_dim=1900), mLSTM1900_AvgHidden(output_dim=1900),
    )
    output_shape, params = init_fun(rng, (length, 10))
    h_avg = apply_fun(params=params, inputs=x)

    assert output_shape == (1900,)
    validate_mLSTM1900_params(params[0])
    assert params[1] == ()
    assert h_avg.shape == (1900,)
예제 #5
0
def test_mLSTMFusion(data):
    """
    Property test: mLSTM followed by mLSTMFusion yields a ``3 * output_dim`` vector.

    A random sequence of length 1-10 is drawn via ``data.draw(...)``,
    embedded, and passed through ``stax.serial(mLSTM, mLSTMFusion)``. The
    test checks:

    - the declared and actual output shapes;
    - that the mLSTM params pass ``validate_mLSTM_params``;
    - that mLSTMFusion carries no params.

    ``data`` is presumably a Hypothesis ``st.data()`` object supplied by a
    ``@given`` decorator outside this view. ``rng`` is a module-level name
    not shown here.
    """
    n_chars = data.draw(st.integers(min_value=1, max_value=10))
    seq_strategy = st.text(
        alphabet="MRHKDESTNQCUGPAVIFYWLOXZBJ",
        min_size=n_chars,
        max_size=n_chars,
    )
    seq = data.draw(seq_strategy)

    x = get_embedding(seq, load_embedding())

    output_dim = 256
    fused_dim = output_dim * 3
    init_fun, apply_fun = stax.serial(
        mLSTM(output_dim=output_dim),
        mLSTMFusion(),
    )
    output_shape, params = init_fun(rng, (n_chars, 10))
    fused = apply_fun(params=params, inputs=x)

    assert output_shape == (fused_dim,)
    mlstm_params, fusion_params = params[0], params[1]
    validate_mLSTM_params(mlstm_params, output_dim)
    assert fusion_params == ()
    assert fused.shape == (fused_dim,)