from contextlib import suppress

import numpy as np
import tensorflow as tf
from Bio import SeqIO

from tape.data_utils import PFAM_VOCAB, deserialize_fasta_sequence
from tape.models import ModelBuilder


def embed_from_tfrecord(tfrecord_file,
                        model: str,
                        load_from=None,
                        use_global=False,
                        deserialization_func=deserialize_fasta_sequence):
    sess = tf.Session()
    embedding_model = ModelBuilder.build_model(model)

    # Run the model once on placeholder inputs so that its variables exist
    # before they are initialized and (optionally) loaded from disk.
    primary = tf.placeholder(tf.int32, [None, None])
    protein_length = tf.placeholder(tf.int32, [None])
    embedding_model({'primary': primary, 'protein_length': protein_length})
    sess.run(tf.global_variables_initializer())

    if load_from is not None:
        embedding_model.load_weights(load_from, by_name=True)

    data = tf.data.TFRecordDataset(tfrecord_file).map(deserialization_func)
    data = data.batch(1)
    iterator = data.make_one_shot_iterator()
    batch = iterator.get_next()
    output = embedding_model(batch)

    # Pooled per-protein embedding vs. the full per-residue encoder output.
    out = output['global_emb'] if use_global else output['encoder_output']

    embeddings = []
    # The one-shot iterator raises OutOfRangeError once the dataset is exhausted.
    with suppress(tf.errors.OutOfRangeError):
        while True:
            encoder_output_batch = sess.run(out)
            for encoder_output in encoder_output_batch:
                embeddings.append(encoder_output)
    return embeddings
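# Usage sketch for embed_from_tfrecord (the file, model name, and weight path
# below are hypothetical placeholders; any model registered with ModelBuilder
# should work):
#
#     embeddings = embed_from_tfrecord('proteins.tfrecord', 'transformer',
#                                      load_from='pretrained_weights.h5')
#     # With use_global=False, each entry is the per-residue encoder output for
#     # one protein; the batch dimension is unpacked by the inner for-loop.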
def embed_from_fasta(fasta_file,
                     model: str,
                     load_from=None,
                     use_global=False,
                     vocab=PFAM_VOCAB):
    sess = tf.Session()
    embedding_model = ModelBuilder.build_model(model)
    primary = tf.placeholder(tf.int32, [None, None])
    protein_length = tf.placeholder(tf.int32, [None])
    output = embedding_model({'primary': primary, 'protein_length': protein_length})
    sess.run(tf.global_variables_initializer())

    if load_from is not None:
        embedding_model.load_weights(load_from, by_name=True)

    # Pooled per-protein embedding vs. the full per-residue encoder output.
    out = output['global_emb'] if use_global else output['encoder_output']

    embeddings = []
    for record in SeqIO.parse(fasta_file, 'fasta'):
        # Map each amino acid to its integer token; shape (1, sequence_length).
        int_sequence = np.array([vocab[aa] for aa in record.seq], ndmin=2)
        encoder_output = sess.run(
            out,
            feed_dict={primary: int_sequence,
                       protein_length: [int_sequence.shape[1]]})
        embeddings.append(encoder_output)
    return embeddings
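# Usage sketch for embed_from_fasta (file and weight paths are hypothetical):
#
#     per_residue = embed_from_fasta('sequences.fasta', 'transformer',
#                                    load_from='pretrained_weights.h5')
#     pooled = embed_from_fasta('sequences.fasta', 'transformer',
#                               load_from='pretrained_weights.h5',
#                               use_global=True)
#
# With use_global=False each element has shape (1, sequence_length, hidden_dim);
# with use_global=True the model's pooled per-protein embedding is returned
# instead.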
def eval(_run, _config, tasks: Union[str, List[str]], model: str):
    # Named after the CLI command; note this shadows the builtin eval here.
    assert _config['load_task_from'] is not None
    outdir = _run.observers[0].basedir
    atexit.register(cleanup_folders, outdir, debug=True)
    sess = setup_tensorflow()

    if isinstance(tasks, str):
        tasks = [tasks]

    embedding_model = ModelBuilder.build_model(model)
    task_list = TaskBuilder.build_tasks(tasks)
    task_model = TaskBuilder.build_task_model(
        embedding_model, task_list, _config['freeze_embedding_weights'])
    experiment = ProteinExperiment(task_model, task_list)

    if not _config['datafile']:
        _, valid_data = get_data(task_list, embedding_model.get_optimal_batch_sizes())
    else:
        # A comma-separated datafile argument is split into a list of files.
        datafile = _config['datafile']
        if ',' in datafile:
            datafile = datafile.split(',')
        valid_data = task_list[0].get_test_data(
            embedding_model.get_optimal_batch_sizes(), datafile)

    test_graph = rk.train.TestGraph.from_experiment(experiment, valid_data)
    sess.run(tf.global_variables_initializer())
    print('Model Parameters: {}'.format(embedding_model.count_params()))

    print('Loading task weights from {}'.format(_config['load_task_from']))
    rk.utils.load_distributed(experiment.distribution_strategy, task_model,
                              _config['load_task_from'])

    # Save the model outputs next to the task checkpoint being evaluated.
    task_dir = os.path.dirname(_config['load_task_from'])
    outfile = os.path.join(task_dir, 'outputs.pkl')
    print('Saving outputs to {}'.format(outfile))
    test_metrics = test_graph.run_epoch(save_outputs=outfile)
    print(test_metrics.get_average())
    consolidate_data(outfile, include_hidden=True)
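# eval() accepts either a single file or a comma-separated list in
# _config['datafile']. Hypothetical examples of the two accepted forms:
#
#     _config['datafile'] = 'data/test.tfrecord'
#     _config['datafile'] = 'data/fold1.tfrecord,data/fold2.tfrecord'
#
# In the second form the string is split on ',' before being handed to the
# task's get_test_data.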
def main(_run, _config, tasks: Union[str, List[str]], model: str):
    outdir = _run.observers[0].basedir
    atexit.register(cleanup_folders, outdir)
    sess = setup_tensorflow()

    if isinstance(tasks, str):
        tasks = [tasks]

    embedding_model = ModelBuilder.build_model(model)
    task_list = TaskBuilder.build_tasks(tasks)
    task_model = TaskBuilder.build_task_model(
        embedding_model, task_list, _config['freeze_embedding_weights'])
    experiment = ProteinExperiment(task_model, task_list)

    # Split the model's optimal per-bucket batch sizes evenly across tasks,
    # keeping a minimum batch size of 1.
    bounds, batch_sizes = embedding_model.get_optimal_batch_sizes()
    batch_sizes = np.asarray(batch_sizes / len(tasks), np.int32)
    batch_sizes[batch_sizes <= 0] = 1
    train_data, valid_data = get_data(task_list, (bounds, batch_sizes))

    if _config['steps_per_epoch'] != -1:
        train_data = train_data.repeat()

    train_graph = rk.train.TrainGraph.from_experiment(experiment, train_data)
    test_graph = rk.train.TestGraph.from_experiment(experiment, valid_data)

    sess.run(tf.global_variables_initializer())
    print('Model Parameters: {}'.format(embedding_model.count_params()))

    if _config['load_from'] is not None:
        print('Loading weights from {}'.format(_config['load_from']))
        rk.utils.load_distributed(experiment.distribution_strategy,
                                  embedding_model, _config['load_from'])

    if _config['load_task_from'] is not None:
        print('Loading task weights from {}'.format(_config['load_task_from']))
        rk.utils.load_distributed(experiment.distribution_strategy,
                                  task_model, _config['load_task_from'])

    evaluator = MetricEvaluator(task_list[0].key_metric)
    train_graph.initialize()

    for epoch in range(_config['num_epochs']):
        train_metrics = train_graph.run_for_n_steps(
            _config['steps_per_epoch'], epoch_num=epoch)
        outfile = os.path.join(outdir, 'outputs.pkl') \
            if _config['save_outputs'] else None
        test_metrics = test_graph.run_epoch(epoch_num=epoch, save_outputs=outfile)

        # When every task is a language modeling task, also checkpoint the
        # embedding weights every epoch.
        if all(isinstance(task, AbstractLanguageModelingTask) for task in task_list):
            with experiment.distribution_strategy.scope():
                embedding_model.save_weights(
                    '{}/epoch_{}.h5'.format(outdir, epoch), overwrite=True)

        evaluator.check_and_log_metric(train_metrics, test_metrics)
        for name, value in train_metrics.items():
            _run.log_scalar('train.{}'.format(name), value)
        for name, value in test_metrics.items():
            _run.log_scalar('valid.{}'.format(name), value)
        _run.log_scalar('runtime',
                        round(train_metrics.runtime + test_metrics.runtime))

        if evaluator.was_improvement:
            # Note: overwrite=True belongs to save_weights, not to str.format.
            with experiment.distribution_strategy.scope():
                embedding_model.save_weights(
                    '{}/best_weights.h5'.format(outdir), overwrite=True)
                task_model.save_weights(
                    '{}/task_weights.h5'.format(outdir), overwrite=True)
        elif evaluator.n_epochs_no_improvement >= _config['patience']:
            print("Early stopping because no improvement in validation loss "
                  "for {} epochs\n".format(_config['patience']))
            break
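# main() relies on MetricEvaluator for best-weight checkpointing and early
# stopping. A minimal sketch of the contract it uses (not the actual tape
# implementation; assumes larger values of the key metric are better):
#
#     class MetricEvaluator:
#         def __init__(self, key_metric):
#             self.key_metric = key_metric
#             self.best = None
#             self.was_improvement = False
#             self.n_epochs_no_improvement = 0
#
#         def check_and_log_metric(self, train_metrics, test_metrics):
#             value = test_metrics[self.key_metric]
#             self.was_improvement = self.best is None or value > self.best
#             if self.was_improvement:
#                 self.best = value
#                 self.n_epochs_no_improvement = 0
#             else:
#                 self.n_epochs_no_improvement += 1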
def run_embed(datafile: str,
              model_name: str,
              load_from: Optional[str] = None,
              task_name: Optional[str] = None):
    datapath = Path(datafile)
    if not datapath.exists():
        raise FileNotFoundError(datapath)
    elif datapath.suffix not in ['.fasta', '.tfrecord', '.tfrecords']:
        raise Exception(
            f"Unknown file type: {datapath.suffix}, "
            f"must be .fasta, .tfrecord, or .tfrecords")

    load_path: Optional[Path] = None
    if load_from is not None:
        load_path = Path(load_from)
        if not load_path.exists():
            raise FileNotFoundError(load_path)

    # Imports are deferred until after argument validation.
    import tensorflow as tf
    import tensorflow.keras.backend as K
    import numpy as np
    from tape.models import ModelBuilder

    sess = tf.InteractiveSession()
    K.set_learning_phase(0)  # inference mode: disables dropout etc.
    embedding_model = ModelBuilder.build_model(model_name)

    if datapath.suffix == '.fasta':
        from Bio import SeqIO
        from tape.data_utils import PFAM_VOCAB

        primary = tf.placeholder(tf.int32, [None, None])
        protein_length = tf.placeholder(tf.int32, [None])
        output = embedding_model({'primary': primary,
                                  'protein_length': protein_length})
        sess.run(tf.global_variables_initializer())

        if load_path is not None:
            embedding_model.load_weights(str(load_path))

        embeddings = []
        for record in SeqIO.parse(str(datapath), 'fasta'):
            int_sequence = np.array([PFAM_VOCAB[aa] for aa in record.seq],
                                    ndmin=2)
            encoder_output = sess.run(
                output['encoder_output'],
                feed_dict={primary: int_sequence,
                           protein_length: [int_sequence.shape[1]]})
            embeddings.append(encoder_output)
    else:
        import contextlib

        # Tasks may define their own deserialization; fall back to plain
        # fasta-style records otherwise.
        if task_name is not None:
            from tape.tasks import TaskBuilder
            task = TaskBuilder.build_task(task_name)
            deserialization_func = task.deserialization_func
        else:
            from tape.data_utils import deserialize_fasta_sequence
            deserialization_func = deserialize_fasta_sequence

        data = tf.data.TFRecordDataset(str(datapath)).map(deserialization_func)
        data = data.batch(1)
        iterator = data.make_one_shot_iterator()
        batch = iterator.get_next()
        output = embedding_model(batch)
        # Initialize variables so embeddings can be computed even when no
        # pretrained weights are loaded.
        sess.run(tf.global_variables_initializer())

        if load_path is not None:
            embedding_model.load_weights(str(load_path))

        embeddings = []
        with contextlib.suppress(tf.errors.OutOfRangeError):
            while True:
                output_batch = sess.run(output['encoder_output'])
                for encoder_output in output_batch:
                    embeddings.append(encoder_output)

    return embeddings
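# Usage sketch for run_embed (paths, model name, and task name below are
# hypothetical placeholders):
#
#     # fasta input: sequences are tokenized with PFAM_VOCAB, one embedding
#     # per record
#     embeddings = run_embed('sequences.fasta', 'transformer',
#                            load_from='pretrained_weights.h5')
#
#     # tfrecord input with a task-specific deserializer
#     embeddings = run_embed('stability.tfrecord', 'transformer',
#                            load_from='pretrained_weights.h5',
#                            task_name='stability')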