Example #1
def test_dump_params():
    """
    Make sure that the parameter dumping function used in evotuning
    conserves all parameter shapes correctly.
    """
    params = load_params()
    dump_params(params, "tmp")
    dumped_params = load_params("tmp/iter_0")
    rmtree("tmp")
    validate_params(dumped_params)
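
What "conserves all parameter shapes" amounts to can also be spelled out by hand. A minimal sketch, continuing with the params and dumped_params pytrees from the test above and assuming both are plain pytrees of arrays:

from jax import tree_util

# Map every array leaf to its shape, then compare the two pytrees structurally.
original_shapes = tree_util.tree_map(lambda leaf: leaf.shape, params)
reloaded_shapes = tree_util.tree_map(lambda leaf: leaf.shape, dumped_params)
assert original_shapes == reloaded_shapes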
Example #2
def test_dump_params(model):
    """
    Make sure that the parameter dumping function used in evotuning
    conserves all parameter shapes correctly.
    """
    init_fun, apply_fun = model
    _, params = init_fun(PRNGKey(42), input_shape=(-1, 26))
    dump_params(params, "tmp")
    with open("tmp/iter_0/model_weights.pkl", "rb") as f:
        dumped_params = pkl.load(f)
    rmtree("tmp")
    validate_params(model_func=apply_fun, params=dumped_params)
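
Example #2 unpickles tmp/iter_0/model_weights.pkl directly, which also reveals the on-disk layout that dump_params produces. Presumably the same weights could be read back with the load_params helper from Example #1 instead of opening the pickle by hand; a one-line sketch under that assumption:

# Assumes load_params accepts the folder that dump_params wrote (see Example #1).
dumped_params = load_params("tmp/iter_0")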
Example #3
def main(
    kind: DataType,
    num_sequences: Optional[int] = None,
    length: Optional[int] = None,
    validation_fraction: Optional[float] = None,
    backend: str = "gpu",
    n_epochs: int = 20,
    learning_rate_power: int = -3,
    mlstm_size: int = 256,
    batch_size: int = 100,
    batch_method: BatchMethod = "random",
):
    if kind == "enzymes":
        seqs = load_enzymes()
    if kind == "random":
        seqs = load_random(num_sequences, length)
    shuffle(seqs)
    break_point = int(len(seqs) * validation_fraction)
    sequences = seqs[0:break_point]
    holdout_sequences = seqs[break_point:]

    logger = logging.getLogger("fitting.py")
    logger.setLevel(logging.DEBUG)
    logger.info(f"There are {len(sequences)} sequences.")

    LEARN_RATE = 10**learning_rate_power
    PROJECT_NAME = "temp"

    evotuned_params = fit(
        mlstm_size=mlstm_size,
        rng=PRNGKey(42),
        params=None,
        sequences=sequences,
        n_epochs=n_epochs,
        step_size=LEARN_RATE,
        holdout_seqs=holdout_sequences,
        batch_size=batch_size,
        batch_method=batch_method,
        proj_name=PROJECT_NAME,
        epochs_per_print=1,
        backend=backend,  # default is "cpu", can be "gpu" if you have JAX-GPU installed.
    )
    dump_params(evotuned_params, PROJECT_NAME)
    print("Evotuning done! Find output weights in", PROJECT_NAME)
Example #4
seqs = seqs[:50]
shuffle(seqs)
break_point = int(len(seqs) * 0.7)
sequences = seqs[0:break_point]
holdout_sequences = seqs[break_point:]

logger = logging.getLogger("fitting.py")
logger.setLevel(logging.DEBUG)
logger.info(f"There are {len(sequences)} sequences.")

N_EPOCHS = 20
LEARN_RATE = 1e-5
PROJECT_NAME = "temp"

params = load_random_evotuning_params()

evotuned_params = fit(
    params=params,
    sequences=sequences,
    n_epochs=N_EPOCHS,
    step_size=LEARN_RATE,
    holdout_seqs=holdout_sequences,
    batch_method="random",
    proj_name=PROJECT_NAME,
    epochs_per_print=1,
    backend="cpu",  # default is "cpu"
)

dump_params(evotuned_params, PROJECT_NAME)
print("Evotuning done! Find output weights in", PROJECT_NAME)
Example #5
    "ALAVA",
    "LIMED",
    "HAST",
    "HASVALTA",
] * 5
PROJECT_NAME = "evotuning_temp"

init_fun, apply_fun = mlstm64()

# The input_shape is always (-1, 26): 26 is the number of unique
# amino acids in the one-hot encoding.
_, initial_params = init_fun(PRNGKey(42), input_shape=(-1, 26))

# 1. Evotuning with Optuna
n_epochs_config = {"low": 1, "high": 1}
lr_config = {"low": 1e-5, "high": 1e-3}
study, evotuned_params = evotune(
    sequences=sequences,
    model_func=apply_fun,
    params=initial_params,
    out_dom_seqs=holdout_sequences,
    n_trials=2,
    n_splits=2,
    n_epochs_config=n_epochs_config,
    learning_rate_config=lr_config,
)

dump_params(evotuned_params, Path(PROJECT_NAME))
print("Evotuning done! Find output weights in", PROJECT_NAME)
print(study.trials_dataframe())
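
Beyond the full trials dataframe, the returned study can be queried for the winning trial directly, assuming evotune hands back a regular optuna.Study:

# Standard Optuna accessors; this assumes study is an optuna.Study instance.
print(study.best_trial.number)  # index of the best trial
print(study.best_params)        # sampled hyperparameters (here: epochs and learning rate)
print(study.best_value)         # best objective value reached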