Example #1
0
def main(
    kind: DataType,
    num_sequences: Optional[int] = None,
    length: Optional[int] = None,
    validation_fraction: Optional[float] = None,
    backend: str = "gpu",
    n_epochs: int = 20,
    learning_rate_power: int = -3,
    mlstm_size: int = 256,
    batch_size: int = 100,
    batch_method: BatchMethod = "random",
):
    """Evotune an mLSTM on a chosen sequence dataset and dump the weights.

    :param kind: Which dataset to load; "enzymes" or "random".
    :param num_sequences: Number of sequences to generate (kind="random" only).
    :param length: Length of each generated sequence (kind="random" only).
    :param validation_fraction: NOTE: despite the name, this is the fraction
        of sequences used for *training*; the remainder is held out for
        validation. Required — there is no sensible default.
    :param backend: Passed through to ``fit``; "cpu", or "gpu" with JAX-GPU.
    :param n_epochs: Number of training epochs.
    :param learning_rate_power: Learning rate is ``10 ** learning_rate_power``.
    :param mlstm_size: Hidden size of the mLSTM.
    :param batch_size: Training batch size.
    :param batch_method: Batching strategy passed through to ``fit``.
    :raises ValueError: If ``kind`` is unrecognized or ``validation_fraction``
        was not provided.
    """
    # Fail fast with a clear message; previously a missing value crashed
    # later with `TypeError: unsupported operand` at the split computation.
    if validation_fraction is None:
        raise ValueError("validation_fraction must be provided.")

    if kind == "enzymes":
        seqs = load_enzymes()
    elif kind == "random":
        seqs = load_random(num_sequences, length)
    else:
        # Previously an unrecognized kind fell through both (non-exclusive)
        # `if`s and died with UnboundLocalError on `seqs`.
        raise ValueError(f"Unknown kind: {kind!r}")

    shuffle(seqs)
    break_point = int(len(seqs) * validation_fraction)
    sequences = seqs[:break_point]
    holdout_sequences = seqs[break_point:]

    logger = logging.getLogger("fitting.py")
    logger.setLevel(logging.DEBUG)
    logger.info(f"There are {len(sequences)} sequences.")

    LEARN_RATE = 10**learning_rate_power
    PROJECT_NAME = "temp"

    evotuned_params = fit(
        mlstm_size=mlstm_size,
        rng=PRNGKey(42),
        params=None,
        sequences=sequences,
        n_epochs=n_epochs,
        step_size=LEARN_RATE,
        holdout_seqs=holdout_sequences,
        batch_size=batch_size,
        batch_method=batch_method,
        proj_name=PROJECT_NAME,
        epochs_per_print=1,
        backend=backend,  # default is "cpu", can be "gpu" if you have JAX-GPU installed.
    )
    dump_params(evotuned_params, PROJECT_NAME)
    print("Evotuning done! Find output weights in", PROJECT_NAME)
Example #2
0
# Keep only the first 50 sequences so the demo finishes quickly.
seqs = seqs[:50]
shuffle(seqs)

# 70/30 train/holdout split; the holdout set is used to watch for overfitting.
break_point = int(0.7 * len(seqs))
sequences, holdout_sequences = seqs[:break_point], seqs[break_point:]

logger = logging.getLogger("fitting.py")
logger.setLevel(logging.DEBUG)
logger.info(f"There are {len(sequences)} sequences.")

N_EPOCHS = 20
LEARN_RATE = 1e-5
PROJECT_NAME = "temp"

# Start from randomly-initialized evotuning parameters.
params = load_random_evotuning_params()

evotuned_params = fit(
    params=params,
    sequences=sequences,
    n_epochs=N_EPOCHS,
    step_size=LEARN_RATE,
    holdout_seqs=holdout_sequences,
    batch_method="random",
    proj_name=PROJECT_NAME,
    epochs_per_print=1,
    backend="cpu",  # default is "cpu"
)

dump_params(evotuned_params, PROJECT_NAME)
print("Evotuning done! Find output weights in", PROJECT_NAME)
Example #3
0
    out_dom_seqs=holdout_sequences,
    n_trials=2,
    n_splits=2,
    n_epochs_config=n_epochs_config,
    learning_rate_config=lr_config,
    epochs_per_print=1,
)

dump_params(evotuned_params, PROJECT_NAME)
print("Evotuning done! Find output weights in", PROJECT_NAME)
print(study.trials_dataframe())

## 2. Evotuning without Optuna

# Fixed hyperparameters for a short, direct fit.
N_EPOCHS = 3
LEARN_RATE = 1e-4
PROJECT_NAME = "temp"

evotuned_params = fit(
    params=None,  # presumably lets `fit` supply initial weights — verify against its docs
    sequences=sequences,
    n_epochs=N_EPOCHS,
    step_size=LEARN_RATE,
    holdout_seqs=holdout_sequences,
    proj_name=PROJECT_NAME,
    epochs_per_print=1,
)

dump_params(evotuned_params, PROJECT_NAME)
print(f"Evotuning done! Find output weights in {PROJECT_NAME}")
Example #4
0
    for record in SeqIO.parse(f, "fasta"):
        seqs.append(str(record.seq))

# Truncate to 50 sequences so this example runs fast.
seqs = seqs[:50]
shuffle(seqs)

# Split 70% for training, 30% held out for validation.
break_point = int(0.7 * len(seqs))
sequences, holdout_sequences = seqs[:break_point], seqs[break_point:]

logger = logging.getLogger("fitting.py")
logger.setLevel(logging.DEBUG)
logger.info(f"There are {len(sequences)} sequences.")

N_EPOCHS = 20
LEARN_RATE = 1e-5
PROJECT_NAME = "temp"

evotuned_params = fit(
    params=None,
    sequences=sequences,
    n_epochs=N_EPOCHS,
    step_size=LEARN_RATE,
    holdout_seqs=holdout_sequences,
    batch_method="random",
    proj_name=PROJECT_NAME,
    # NOTE(review): sibling examples pass `epochs_per_print=1` here —
    # confirm `steps_per_print` is the intended keyword for this `fit`.
    steps_per_print=None,
)

dump_params(evotuned_params, PROJECT_NAME)
print(f"Evotuning done! Find output weights in {PROJECT_NAME}")
Example #5
0
# Set aside sequences for determining overfitting
if args.validation:
    # args.validation acts as the holdout fraction: keep
    # (1 - args.validation) of the data for training.
    # Split sequences into training set and validation set
    break_point = int(len(sequences) * (1 - args.validation))
    train_sequences = sequences[0:break_point]
    holdout_sequences = sequences[break_point:]
    # Save validation sequences
    with open("validation_sequences.txt", "w") as holdout_file:
        for seq in holdout_sequences:
            holdout_file.write(seq + "\n")
else:
    # No validation requested: train on everything, skip holdout monitoring.
    train_sequences = sequences
    holdout_sequences = None

# Perform evotuning
evotuned_params = ju.fit(
    params=None,
    sequences=train_sequences,
    n_epochs=args.epochs,
    batch_size=args.batch,
    step_size=args.step,
    holdout_seqs=holdout_sequences,
    batch_method=args.method,
    proj_name=args.outdir,
    epochs_per_print=args.dumps,
    # NOTE(review): an arg named `cpu` is passed as `backend` — verify it
    # actually carries a backend string like "cpu"/"gpu", not a boolean flag.
    backend=args.cpu,
)

ju.utils.dump_params(evotuned_params, args.outdir, step='final')