Ejemplo n.º 1
0
def test_mlstm64():
    """Check that the pre-built mlstm64 model runs a forward pass.

    The model is initialised with a fixed PRNG key (42) and an input width of
    26, then applied to the one-hot encoding of the sequence "HASTA". The
    output is expected to have shape (7, 25).
    """
    initialise, forward = mlstm64()
    key = PRNGKey(42)
    _, weights = initialise(key, input_shape=(-1, 26))

    encoded = seq_to_oh("HASTA")
    # NOTE(review): 7 rows for a 5-residue input suggests seq_to_oh adds
    # extra tokens around the sequence; confirm against seq_to_oh.
    expected_shape = (7, 25)
    assert forward(weights, encoded).shape == expected_shape
Ejemplo n.º 2
0
def model():
    """Build an mLSTM64 model and return it with freshly initialised weights.

    The weights are drawn from ``PRNGKey(0)`` for a one-hot input width of 26,
    so repeated calls yield the same parameters.

    Returns:
        tuple: ``(apply_fun, params)`` of the mLSTM64 model.
    """
    builder, forward = mlstm64()
    seed = PRNGKey(0)
    _, weights = builder(rng=seed, input_shape=(-1, 26))
    return forward, weights
Ejemplo n.º 3
0
from jax_unirep.evotuning_models import mlstm64
from jax_unirep.utils import dump_params

# Training sequences: seven short toy peptides, repeated five times.
sequences = 5 * ["HASTA", "VISTA", "ALAVA", "LIMED", "HAST", "HAS", "HASVASTA"]
# Held-out sequences (passed to evotune as ``out_dom_seqs``), repeated 5 times.
holdout_sequences = 5 * ["HASTA", "VISTA", "ALAVA", "LIMED", "HAST", "HASVALTA"]
PROJECT_NAME = "evotuning_temp"

init_fun, apply_fun = mlstm64()

# Inputs are one-hot encoded with 26 columns, one per unique amino-acid
# token, hence an input_shape of (-1, 26).
# NOTE: the spelling ``inital_params`` is kept because evotune below uses it.
_, inital_params = init_fun(PRNGKey(42), input_shape=(-1, 26))

# 1. Evotuning with Optuna
# Low/high ranges for the number of epochs and the learning rate; presumably
# consumed by evotune as its search ranges (call is truncated here).
n_epochs_config = dict(low=1, high=1)
lr_config = dict(low=1e-5, high=1e-3)
study, evotuned_params = evotune(
    sequences=sequences,
    model_func=apply_fun,
    params=inital_params,
    out_dom_seqs=holdout_sequences,
    n_trials=2,
    n_splits=2,
Ejemplo n.º 4
0
def model():
    """Return the ``(init_fun, apply_fun)`` pair of an mLSTM64 model.

    No parameters are created here; callers initialise them via ``init_fun``.
    """
    builder, forward = mlstm64()
    return (builder, forward)