def test_mLSTM1900_batch():
    """
    Smoke test for mLSTM1900_batch on a single fake sequence.

    The string "TEST" is embedded with the 1900-dim embedding and run
    through mLSTM1900_batch using the loaded 1900-dim params. Only the
    shape of the hidden-state output is checked: one row per row of the
    embedded input, 1900 columns.
    """
    embedding_1900 = load_embedding_1900()
    embedded = get_embedding("TEST", embedding_1900)
    weights = load_params_1900()

    # Final hidden/cell states are not inspected by this test.
    _, _, hidden = mLSTM1900_batch(weights, embedded)

    expected_shape = (embedded.shape[0], 1900)
    assert hidden.shape == expected_shape
def test_mLSTM(data):
    """
    Property test: mLSTM(output_dim=256) on a random short sequence.

    Draws a sequence of 1-10 characters from the allowed alphabet, embeds
    it, initialises an mLSTM layer and applies it. Checks the declared
    output shape and that the stacked outputs have ``seq_len + 1`` rows.

    NOTE(review): ``data`` is presumably a Hypothesis ``st.data()`` object
    injected by a ``@given`` decorator not visible here; ``rng`` is a
    module-level PRNG key defined elsewhere. Confirm both.
    """
    # Draw order matters for Hypothesis shrinking: length first, then text.
    seq_len = data.draw(st.integers(min_value=1, max_value=10))
    protein = data.draw(
        st.text(
            alphabet="MRHKDESTNQCUGPAVIFYWLOXZBJ",
            min_size=seq_len,
            max_size=seq_len,
        ),
    )
    embedded = get_embedding(protein, load_embedding())

    hidden_dim = 256
    init_fun, apply_fun = mLSTM(output_dim=hidden_dim)
    out_shape, weights = init_fun(rng, (-1, 10))
    _, _, stacked_outputs = apply_fun(params=weights, inputs=embedded)

    assert out_shape == (-1, hidden_dim)
    assert stacked_outputs.shape == (seq_len + 1, hidden_dim)
def test_mLSTM1900(data):
    """
    Property test: mLSTM1900 init/apply on a random short sequence.

    Draws a sequence of 1-10 characters, embeds it with the 1900-dim
    embedding, initialises an mLSTM1900 layer via ``init_fun`` and applies
    it. Checks the declared output shape, validates the initialised params,
    and checks the stacked outputs have ``length + 1`` rows of width 1900.

    NOTE(review): ``data`` is presumably a Hypothesis ``st.data()`` object
    injected by a ``@given`` decorator not visible here; ``rng`` is a
    module-level PRNG key defined elsewhere. Confirm both.
    """
    length = data.draw(st.integers(min_value=1, max_value=10))
    sequence = data.draw(
        st.text(
            alphabet="MRHKDESTNQCUGPAVIFYWLOXZBJ",
            min_size=length,
            max_size=length,
        ),
    )
    embedding = load_embedding_1900()
    x = get_embedding(sequence, embedding)

    output_dim = 1900
    init_fun, apply_fun = mLSTM1900(output_dim=output_dim)
    # The params under test come from init_fun; this test deliberately does
    # not load params from disk (any loaded params would be overwritten here).
    output_shape, params = init_fun(rng, (length, 10))
    # Final hidden/cell states are not inspected by this test.
    _, _, outputs = apply_fun(params=params, inputs=x)

    assert output_shape == (length, output_dim)
    validate_mLSTM1900_params(params)
    assert outputs.shape == (length + 1, output_dim)
def test_mLSTM1900_AvgHidden(data):
    """
    Property test: mLSTM1900 followed by mLSTM1900_AvgHidden via stax.serial.

    Draws a sequence of 1-10 characters, embeds it with the 1900-dim
    embedding, and runs it through the two-layer serial model. Checks that
    the final output is a single vector of width 1900, that the first
    layer's params validate, and that the second layer contributes an
    empty params tuple.

    NOTE(review): ``data`` is presumably a Hypothesis ``st.data()`` object
    injected by a ``@given`` decorator not visible here; ``rng`` is a
    module-level PRNG key defined elsewhere. Confirm both.
    """
    length = data.draw(st.integers(min_value=1, max_value=10))
    sequence = data.draw(
        st.text(
            alphabet="MRHKDESTNQCUGPAVIFYWLOXZBJ",
            min_size=length,
            max_size=length,
        ),
    )
    embedding = load_embedding_1900()
    x = get_embedding(sequence, embedding)

    output_dim = 1900
    init_fun, apply_fun = stax.serial(
        mLSTM1900(output_dim=output_dim),
        mLSTM1900_AvgHidden(output_dim=output_dim),
    )
    # The params under test come from init_fun; this test deliberately does
    # not load params from disk (any loaded params would be overwritten here).
    output_shape, params = init_fun(rng, (length, 10))
    h_avg = apply_fun(params=params, inputs=x)

    assert output_shape == (output_dim,)
    validate_mLSTM1900_params(params[0])
    # The averaging layer is expected to contribute no parameters.
    assert params[1] == ()
    assert h_avg.shape == (output_dim,)
def test_mLSTMFusion(data):
    """
    Property test: mLSTM(output_dim=256) followed by mLSTMFusion.

    Draws a sequence of 1-10 characters, embeds it, and runs it through the
    two-layer serial model. The fused output is checked to be a single
    vector of width ``3 * hidden_dim``. The mLSTM params are validated, and
    the fusion layer is checked to contribute an empty params tuple.

    NOTE(review): the ``* 3`` width presumably reflects three hidden-size
    vectors being combined by mLSTMFusion. Confirm against its definition.
    ``data`` is presumably a Hypothesis ``st.data()`` object from a
    ``@given`` decorator not visible here, and ``rng`` is defined elsewhere.
    """
    seq_len = data.draw(st.integers(min_value=1, max_value=10))
    protein = data.draw(
        st.text(
            alphabet="MRHKDESTNQCUGPAVIFYWLOXZBJ",
            min_size=seq_len,
            max_size=seq_len,
        ),
    )
    embedded = get_embedding(protein, load_embedding())

    hidden_dim = 256
    fused_width = (hidden_dim * 3,)
    init_fun, apply_fun = stax.serial(
        mLSTM(output_dim=hidden_dim),
        mLSTMFusion(),
    )
    out_shape, layer_params = init_fun(rng, (seq_len, 10))
    fused = apply_fun(params=layer_params, inputs=embedded)

    assert out_shape == fused_width
    validate_mLSTM_params(layer_params[0], hidden_dim)
    assert layer_params[1] == ()
    assert fused.shape == fused_width