Code Example #1
import tensorflow as tf
from contextlib import suppress

from tape.models import ModelBuilder
from tape.data_utils import deserialize_fasta_sequence


def embed_from_tfrecord(tfrecord_file,
                        model: str,
                        load_from=None,
                        use_global=False,
                        deserialization_func=deserialize_fasta_sequence):
    """Embed every sequence in a TFRecord file with the given model."""
    sess = tf.Session()
    embedding_model = ModelBuilder.build_model(model)

    # Call the model once on placeholders so its variables exist before
    # they are initialized and (optionally) loaded from disk.
    primary = tf.placeholder(tf.int32, [None, None])
    protein_length = tf.placeholder(tf.int32, [None])
    embedding_model({'primary': primary, 'protein_length': protein_length})

    sess.run(tf.global_variables_initializer())
    if load_from is not None:
        embedding_model.load_weights(load_from, by_name=True)

    data = tf.data.TFRecordDataset(tfrecord_file).map(deserialization_func)
    data = data.batch(1)
    iterator = data.make_one_shot_iterator()
    batch = iterator.get_next()
    output = embedding_model(batch)

    # Pick the sequence-level or per-residue tensor once, then drain the
    # one-shot iterator until it raises OutOfRangeError.
    out = output['global_emb'] if use_global else output['encoder_output']
    embeddings = []
    with suppress(tf.errors.OutOfRangeError):
        while True:
            encoder_output_batch = sess.run(out)
            for encoder_output in encoder_output_batch:
                embeddings.append(encoder_output)
    return embeddings
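
A minimal usage sketch for embed_from_tfrecord. The file path, weight path, and the 'transformer' model name are illustrative assumptions, not values guaranteed by this code:

# Hypothetical paths and model name -- adjust to your setup.
embeddings = embed_from_tfrecord('pfam_holdout.tfrecord',
                                 model='transformer',
                                 load_from='pretrained_weights.h5')
# With use_global=False each element is a (sequence_length, hidden_size)
# array; with use_global=True it is a single fixed-size vector.
print(len(embeddings), embeddings[0].shape)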
Code Example #2
import numpy as np
import tensorflow as tf
from Bio import SeqIO

from tape.models import ModelBuilder
from tape.data_utils import PFAM_VOCAB


def embed_from_fasta(fasta_file,
                     model: str,
                     load_from=None,
                     use_global=False,
                     vocab=PFAM_VOCAB):
    """Embed every sequence in a FASTA file with the given model."""
    sess = tf.Session()
    embedding_model = ModelBuilder.build_model(model)

    primary = tf.placeholder(tf.int32, [None, None])
    protein_length = tf.placeholder(tf.int32, [None])
    output = embedding_model({'primary': primary, 'protein_length': protein_length})

    sess.run(tf.global_variables_initializer())
    if load_from is not None:
        embedding_model.load_weights(load_from, by_name=True)

    # The tensor to fetch does not change between iterations, so pick it once.
    out = output['global_emb'] if use_global else output['encoder_output']

    embeddings = []
    for record in SeqIO.parse(fasta_file, 'fasta'):
        # Map the amino-acid string to a batch of one integer sequence.
        int_sequence = np.array([vocab[aa] for aa in record.seq], ndmin=2)
        encoder_output = sess.run(out,
                                  feed_dict={primary: int_sequence,
                                             protein_length: [int_sequence.shape[1]]})
        embeddings.append(encoder_output)
    return embeddings
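
A usage sketch for embed_from_fasta that saves the embeddings to disk. The paths and model name are illustrative assumptions:

import pickle

# Hypothetical paths and model name -- adjust to your setup.
embeddings = embed_from_fasta('proteins.fasta',
                              model='transformer',
                              load_from='pretrained_weights.h5',
                              use_global=True)
with open('embeddings.pkl', 'wb') as f:
    pickle.dump(embeddings, f)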
Code Example #3
def eval(_run, _config, tasks: Union[str, List[str]], model: str):
    """Evaluate a trained task model. `_run` and `_config` are special
    values injected by the Sacred experiment framework."""
    assert _config['load_task_from'] is not None
    outdir = _run.observers[0].basedir
    atexit.register(cleanup_folders, outdir, debug=True)

    sess = setup_tensorflow()

    if isinstance(tasks, str):
        tasks = [tasks]

    embedding_model = ModelBuilder.build_model(model)
    task_list = TaskBuilder.build_tasks(tasks)
    task_model = TaskBuilder.build_task_model(
        embedding_model, task_list, _config['freeze_embedding_weights'])

    experiment = ProteinExperiment(task_model, task_list)

    if not _config['datafile']:
        _, valid_data = get_data(task_list,
                                 embedding_model.get_optimal_batch_sizes())
    else:
        # The datafile config value may be a comma-separated list of files.
        datafile = _config['datafile']
        if ',' in datafile:
            datafile = datafile.split(',')
        valid_data = task_list[0].get_test_data(
            embedding_model.get_optimal_batch_sizes(), datafile)

    test_graph = rk.train.TestGraph.from_experiment(experiment, valid_data)

    sess.run(tf.global_variables_initializer())

    print('Model Parameters: {}'.format(embedding_model.count_params()))
    print('Loading task weights from {}'.format(_config['load_task_from']))

    rk.utils.load_distributed(experiment.distribution_strategy, task_model,
                              _config['load_task_from'])

    task_dir = os.path.dirname(_config['load_task_from'])
    outfile = os.path.join(task_dir, 'outputs.pkl')
    print('Saving outputs to {}'.format(outfile))
    test_metrics = test_graph.run_epoch(save_outputs=outfile)
    print(test_metrics.get_average())
    consolidate_data(outfile, include_hidden=True)
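
The underscore-prefixed _run and _config arguments follow the Sacred experiment framework's convention for injected special values, so eval is presumably registered on a Sacred Experiment. A hedged sketch of that wiring, with a hypothetical experiment name; the repository's actual registration may differ:

from sacred import Experiment

ex = Experiment('tape_eval')  # hypothetical experiment name

@ex.main
def eval(_run, _config, tasks, model):
    ...  # body as above; tasks and model are filled from the config

if __name__ == '__main__':
    ex.run_commandline()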
Code Example #4
def main(_run, _config, tasks: Union[str, List[str]], model: str):
    """Train a task model. `_run` and `_config` are special values
    injected by the Sacred experiment framework."""
    outdir = _run.observers[0].basedir
    atexit.register(cleanup_folders, outdir)

    sess = setup_tensorflow()

    if isinstance(tasks, str):
        tasks = [tasks]

    embedding_model = ModelBuilder.build_model(model)
    task_list = TaskBuilder.build_tasks(tasks)
    task_model = TaskBuilder.build_task_model(
        embedding_model, task_list, _config['freeze_embedding_weights'])

    experiment = ProteinExperiment(task_model, task_list)

    # Divide the per-bucket batch sizes evenly across tasks, flooring at 1.
    bounds, batch_sizes = embedding_model.get_optimal_batch_sizes()
    batch_sizes = np.asarray(batch_sizes / len(tasks), np.int32)
    batch_sizes[batch_sizes <= 0] = 1
    train_data, valid_data = get_data(task_list, (bounds, batch_sizes))

    if _config['steps_per_epoch'] != -1:
        train_data = train_data.repeat()

    train_graph = rk.train.TrainGraph.from_experiment(experiment, train_data)
    test_graph = rk.train.TestGraph.from_experiment(experiment, valid_data)

    sess.run(tf.global_variables_initializer())

    print('Model Parameters: {}'.format(embedding_model.count_params()))

    if _config['load_from'] is not None:
        print('Loading weights from {}'.format(_config['load_from']))
        rk.utils.load_distributed(experiment.distribution_strategy,
                                  embedding_model, _config['load_from'])

    if _config['load_task_from'] is not None:
        print('Loading task weights from {}'.format(_config['load_task_from']))
        rk.utils.load_distributed(experiment.distribution_strategy, task_model,
                                  _config['load_task_from'])

    evaluator = MetricEvaluator(task_list[0].key_metric)

    train_graph.initialize()
    for epoch in range(_config['num_epochs']):
        train_metrics = train_graph.run_for_n_steps(_config['steps_per_epoch'],
                                                    epoch_num=epoch)
        outfile = (os.path.join(outdir, 'outputs.pkl')
                   if _config['save_outputs'] else None)
        test_metrics = test_graph.run_epoch(epoch_num=epoch,
                                            save_outputs=outfile)

        # When every task is a language modeling task (i.e. pretraining),
        # checkpoint the embedding weights after each epoch.
        if all(isinstance(task, AbstractLanguageModelingTask)
               for task in task_list):
            with experiment.distribution_strategy.scope():
                embedding_model.save_weights(
                    '{}/epoch_{}.h5'.format(outdir, epoch), overwrite=True)

        evaluator.check_and_log_metric(train_metrics, test_metrics)

        for name, value in train_metrics.items():
            _run.log_scalar('train.{}'.format(name), value)
        for name, value in test_metrics.items():
            _run.log_scalar('valid.{}'.format(name), value)
        _run.log_scalar('runtime',
                        round(train_metrics.runtime + test_metrics.runtime))

        if evaluator.was_improvement:
            with experiment.distribution_strategy.scope():
                embedding_model.save_weights(
                    '{}/best_weights.h5'.format(outdir), overwrite=True)
                task_model.save_weights(
                    '{}/task_weights.h5'.format(outdir), overwrite=True)
        elif evaluator.n_epochs_no_improvement >= _config['patience']:
            print("Early stopping because no improvement in validation loss "
                  "for {} epochs\n".format(_config['patience']))
            break
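
The loop above delegates early stopping to a MetricEvaluator exposing check_and_log_metric, was_improvement, and n_epochs_no_improvement. A hypothetical reconstruction inferred only from these call sites (the real implementation, and whether higher or lower metric values count as improvement, may differ):

class MetricEvaluator:
    """Hypothetical sketch inferred from the call sites above.
    Assumes lower values of the key metric are better."""

    def __init__(self, key_metric: str):
        self.key_metric = key_metric
        self.best = None
        self.was_improvement = False
        self.n_epochs_no_improvement = 0

    def check_and_log_metric(self, train_metrics, test_metrics):
        # Assumes the metrics object supports dict-style lookup, consistent
        # with its .items() usage in the loop above.
        value = test_metrics[self.key_metric]
        if self.best is None or value < self.best:
            self.best = value
            self.was_improvement = True
            self.n_epochs_no_improvement = 0
        else:
            self.was_improvement = False
            self.n_epochs_no_improvement += 1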
Code Example #5
def run_embed(datafile: str,
              model_name: str,
              load_from: Optional[str] = None,
              task_name: Optional[str] = None):

    datapath = Path(datafile)
    if not datapath.exists():
        raise FileNotFoundError(datapath)
    elif datapath.suffix not in ['.fasta', '.tfrecord', '.tfrecords']:
        raise ValueError(
            f"Unknown file type: {datapath.suffix}, must be .fasta or .tfrecord"
        )

    load_path: Optional[Path] = None
    if load_from is not None:
        load_path = Path(load_from)
        if not load_path.exists():
            raise FileNotFoundError(load_path)

    import tensorflow as tf
    import tensorflow.keras.backend as K
    import numpy as np

    from tape.models import ModelBuilder

    sess = tf.InteractiveSession()
    K.set_learning_phase(0)  # inference mode: disables dropout etc.
    embedding_model = ModelBuilder.build_model(model_name)

    if datapath.suffix == '.fasta':
        from Bio import SeqIO
        from tape.data_utils import PFAM_VOCAB
        primary = tf.placeholder(tf.int32, [None, None])
        protein_length = tf.placeholder(tf.int32, [None])
        output = embedding_model({
            'primary': primary,
            'protein_length': protein_length
        })
        sess.run(tf.global_variables_initializer())
        if load_path is not None:
            embedding_model.load_weights(str(load_path))

        embeddings = []
        for record in SeqIO.parse(str(datapath), 'fasta'):
            int_sequence = np.array([PFAM_VOCAB[aa] for aa in record.seq],
                                    ndmin=2)
            encoder_output = sess.run(output['encoder_output'],
                                      feed_dict={
                                          primary: int_sequence,
                                          protein_length:
                                          [int_sequence.shape[1]]
                                      })
            embeddings.append(encoder_output)
    else:
        import contextlib
        if task_name is not None:
            from tape.tasks import TaskBuilder
            task = TaskBuilder.build_task(task_name)
            deserialization_func = task.deserialization_func
        else:
            from tape.data_utils import deserialize_fasta_sequence
            deserialization_func = deserialize_fasta_sequence

        data = tf.data.TFRecordDataset(str(datapath)).map(deserialization_func)
        data = data.batch(1)
        iterator = data.make_one_shot_iterator()
        batch = iterator.get_next()
        output = embedding_model(batch)
        # Initialize variables before optionally overwriting them with
        # loaded weights (the FASTA branch above does the same).
        sess.run(tf.global_variables_initializer())
        if load_path is not None:
            embedding_model.load_weights(str(load_path))

        embeddings = []
        with contextlib.suppress(tf.errors.OutOfRangeError):
            while True:
                output_batch = sess.run(output['encoder_output'])
                for encoder_output in output_batch:
                    embeddings.append(encoder_output)

    return embeddings
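
A usage sketch covering both input types run_embed accepts. The paths, the model name, and the 'secondary_structure' task name are illustrative assumptions:

# Hypothetical paths, model name, and task name -- adjust to your setup.
fasta_embeddings = run_embed('sequences.fasta', 'transformer',
                             load_from='pretrained_weights.h5')

tfrecord_embeddings = run_embed('data.tfrecords', 'transformer',
                                load_from='pretrained_weights.h5',
                                task_name='secondary_structure')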