def test_dump_params(): """ Make sure that the parameter dumping function used in evotuning conserves all parameter shapes correctly. """ params = load_params() dump_params(params, "tmp") dumped_params = load_params("tmp/iter_0") rmtree("tmp") validate_params(dumped_params)
def test_dump_params(model):
    """
    Make sure that the parameter dumping function used in evotuning
    conserves all parameter shapes correctly.
    """
    init_fun, apply_fun = model
    _, params = init_fun(PRNGKey(42), input_shape=(-1, 26))

    dump_params(params, "tmp")
    with open("tmp/iter_0/model_weights.pkl", "rb") as f:
        dumped_params = pkl.load(f)
    rmtree("tmp")

    validate_params(model_func=apply_fun, params=dumped_params)
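# A minimal sketch of the pytest fixture the test above assumes: `model`
# should provide the (init_fun, apply_fun) pair of the mLSTM being evotuned.
# The import path for mlstm64 is an assumption and may differ in the package.
import pytest


@pytest.fixture
def model():
    from jax_unirep.evotuning_models import mlstm64  # assumed location

    return mlstm64()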
def main(
    kind: DataType,
    num_sequences: Optional[int] = None,
    length: Optional[int] = None,
    validation_fraction: Optional[float] = None,
    backend: str = "gpu",
    n_epochs: int = 20,
    learning_rate_power: int = -3,
    mlstm_size: int = 256,
    batch_size: int = 100,
    batch_method: BatchMethod = "random",
):
    if kind == "enzymes":
        seqs = load_enzymes()
    if kind == "random":
        seqs = load_random(num_sequences, length)

    shuffle(seqs)
    break_point = int(len(seqs) * validation_fraction)
    sequences = seqs[0:break_point]
    holdout_sequences = seqs[break_point:]

    logger = logging.getLogger("fitting.py")
    logger.setLevel(logging.DEBUG)
    logger.info(f"There are {len(sequences)} sequences.")

    LEARN_RATE = 10 ** learning_rate_power
    PROJECT_NAME = "temp"

    evotuned_params = fit(
        mlstm_size=mlstm_size,
        rng=PRNGKey(42),
        params=None,
        sequences=sequences,
        n_epochs=n_epochs,
        step_size=LEARN_RATE,
        holdout_seqs=holdout_sequences,
        batch_size=batch_size,
        batch_method=batch_method,
        proj_name=PROJECT_NAME,
        epochs_per_print=1,
        backend=backend,  # "gpu" requires a GPU-enabled JAX install; pass "cpu" otherwise.
    )

    dump_params(evotuned_params, PROJECT_NAME)
    print("Evotuning done! Find output weights in", PROJECT_NAME)
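# A hedged sketch of a CLI entry point for `main`, assuming the script uses
# Typer (suggested by the enum-typed keyword arguments with defaults); if the
# script wires up its CLI differently, adapt this accordingly.
if __name__ == "__main__":
    import typer

    typer.run(main)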
seqs = seqs[:50]
shuffle(seqs)

break_point = int(len(seqs) * 0.7)
sequences = seqs[0:break_point]
holdout_sequences = seqs[break_point:]

logger = logging.getLogger("fitting.py")
logger.setLevel(logging.DEBUG)
logger.info(f"There are {len(sequences)} sequences.")

N_EPOCHS = 20
LEARN_RATE = 1e-5
PROJECT_NAME = "temp"

params = load_random_evotuning_params()

evotuned_params = fit(
    params=params,
    sequences=sequences,
    n_epochs=N_EPOCHS,
    step_size=LEARN_RATE,
    holdout_seqs=holdout_sequences,
    batch_method="random",
    proj_name=PROJECT_NAME,
    epochs_per_print=1,
    backend="cpu",  # default is "cpu"
)

dump_params(evotuned_params, PROJECT_NAME)
print("Evotuning done! Find output weights in", PROJECT_NAME)
"ALAVA", "LIMED", "HAST", "HASVALTA", ] * 5 PROJECT_NAME = "evotuning_temp" init_fun, apply_fun = mlstm64() # The input_shape is always going to be (-1, 26), # because that is the number of unique AA, one-hot encoded. _, inital_params = init_fun(PRNGKey(42), input_shape=(-1, 26)) # 1. Evotuning with Optuna n_epochs_config = {"low": 1, "high": 1} lr_config = {"low": 1e-5, "high": 1e-3} study, evotuned_params = evotune( sequences=sequences, model_func=apply_fun, params=inital_params, out_dom_seqs=holdout_sequences, n_trials=2, n_splits=2, n_epochs_config=n_epochs_config, learning_rate_config=lr_config, ) dump_params(evotuned_params, Path(PROJECT_NAME)) print("Evotuning done! Find output weights in", PROJECT_NAME) print(study.trials_dataframe())