Exemplo n.º 1
0
  def test_overfit(self):
    """Checks that the model can memorize a single three-sequence batch.

    Encodes three short strings, verifies the padded encoding, fits the model
    on that same batch 100 times, and then asserts on the metrics of the last
    step that all but two predictions are correct.
    """
    vocab = domains.Vocabulary(
        tokens=['a', 'b', 'c'], include_bos=True, include_eos=True)
    domain = domains.VariableLengthDiscreteDomain(vocab=vocab, length=9)
    seqs = [list(text) for text in ('abcabcab', 'bbbbbb', 'cbacbacb')]

    expected_encoding = [
        [0, 1, 2, 0, 1, 2, 0, 1, 4],
        [1, 1, 1, 1, 1, 1, 4, 4, 4],
        [2, 1, 0, 2, 1, 0, 2, 1, 4],
    ]
    encoded = domain.encode(seqs, pad=True)
    self.assertAllEqual(expected_encoding, encoded)

    batch = np.array(encoded)
    model = lm_cls(
        domain=domain,
        learning_rate=0.01,
        dropout_rate=0.0,
        attention_dropout_rate=0.0)
    num_steps = 100
    for _ in range(num_steps):
      metrics = model.fit_batch(batch)

    # The three examples each start with a different token, and the first
    # token cannot be predicted from <BOS> alone, so the best achievable
    # count is two short of perfect.
    total = metrics['denominator'][0]
    num_correct = metrics['accuracy'][0]
    expected_accuracy = (total - 2) / total
    self.assertEqual(expected_accuracy, num_correct / total)
Exemplo n.º 2
0
 def _make_pretrained_transformer(self, **kwargs):
     """Fits a language model on alternating 'abab' / 'baba' strings.

     Args:
       **kwargs: Extra keyword arguments forwarded to `lm_cls`.

     Returns:
       A `(lm, domain)` tuple holding the fitted model and the domain it was
       constructed with.
     """
     training_strings = ['abab', 'baba'] * 64
     vocab = domains.Vocabulary(
         tokens=['a', 'b'], include_bos=True, include_eos=True)
     domain = domains.VariableLengthDiscreteDomain(
         vocab=vocab, length=len(training_strings[0]))
     encoded = np.array(domain.encode(training_strings, pad=False))

     lm = lm_cls(domain=domain, learning_rate=0.001, **kwargs)
     # Full-batch training: the batch size equals the number of sequences.
     lm.fit(encoded, batch_size=len(encoded), epochs=20)
     return lm, domain
Exemplo n.º 3
0
 def test_bos_does_not_appear_in_var_len_output(self):
   """Checks that no sample drawn from a var-len domain contains BOS.

   Guards against BOS being used as padding in sampled sequences.
   """
   vocab = domains.Vocabulary(tokens=[0, 1], include_eos=True)
   domain = domains.VariableLengthDiscreteDomain(vocab=vocab, length=10)
   model = lm_cls(domain=domain)
   for drawn in model.sample(10):
     self.assertNotIn(model.bos_token, drawn)
Exemplo n.º 4
0
 def test_only_eos_after_eos(self):
   """Checks that once EOS shows up in a sample, only EOS follows it."""
   vocab = domains.Vocabulary(tokens=[0, 1], include_eos=True)
   domain = domains.VariableLengthDiscreteDomain(vocab=vocab, length=10)
   model = lm_cls(domain=domain)
   for drawn in model.sample(10):
     if model.eos_token not in drawn:
       continue  # Samples without any EOS have nothing to verify.
     # Index of the first EOS occurrence in the sample.
     first_eos = np.argwhere(drawn == model.eos_token)[0][0]
     tail_length = len(drawn) - first_eos
     self.assertAllEqual(drawn[first_eos:],
                         [model.eos_token] * tail_length)
Exemplo n.º 5
0
def make_protein_domain(include_anomalous_amino_acids=True,
                        include_bos=True,
                        include_eos=True,
                        include_pad=True,
                        include_mask=True,
                        length=1024):
    """Builds a variable-length discrete domain over a protein vocabulary.

    Args:
      include_anomalous_amino_acids: Forwarded to `domains.ProteinVocab`.
      include_bos: Forwarded to `domains.ProteinVocab`.
      include_eos: Forwarded to `domains.ProteinVocab`.
      include_pad: Forwarded to `domains.ProteinVocab`.
      include_mask: Forwarded to `domains.ProteinVocab`.
      length: Passed as `length` to the domain.

    Returns:
      A `domains.VariableLengthDiscreteDomain` built on the configured vocab.
    """
    vocab_options = dict(
        include_anomalous_amino_acids=include_anomalous_amino_acids,
        include_bos=include_bos,
        include_eos=include_eos,
        include_pad=include_pad,
        include_mask=include_mask,
    )
    vocab = domains.ProteinVocab(**vocab_options)
    return domains.VariableLengthDiscreteDomain(vocab=vocab, length=length)
Exemplo n.º 6
0
  def setUp(self):
    """Creates the domain, the model, and a one-row input batch for tests."""
    make_lm = functools.partial(models.FlaxBERT, **lm_cfg)
    vocab = domains.ProteinVocab(
        include_anomalous_amino_acids=True,
        include_bos=True,
        include_eos=True,
        include_pad=True,
        include_mask=True)
    self._domain = domains.VariableLengthDiscreteDomain(vocab=vocab, length=3)

    self.lm = make_lm(domain=self._domain, grad_clip=1.0)
    # Single input row; its three entries match the domain length above.
    self.xs = np.array([[1, 1, 0]])
    super().setUp()
Exemplo n.º 7
0
"""Dataset preprocessing and pipeline.

Built for the TrEMBL dataset.
"""
import os
import types
from absl import logging
import gin
import numpy as np
import tensorflow.compat.v1 as tf

from protein_lm import domains

# Module-level domain for protein sequences: a ProteinVocab that includes the
# anomalous amino acids plus BOS and EOS tokens, with `length` set to 512.
protein_domain = domains.VariableLengthDiscreteDomain(
    vocab=domains.ProteinVocab(include_anomalous_amino_acids=True,
                               include_bos=True,
                               include_eos=True),
    length=512)


def dataset_from_tensors(tensors):
    """Builds a `tf.data.Dataset` by slicing `tensors`.

    Generators and lists are first converted to a tuple; any other input is
    handed to `tf.data.Dataset.from_tensor_slices` unchanged.

    Args:
      tensors: Nested tf.Tensors or np.ndarrays, or a generator/list of them.

    Returns:
      The result of `tf.data.Dataset.from_tensor_slices`.
    """
    needs_tuple = isinstance(tensors, (types.GeneratorType, list))
    slices = tuple(tensors) if needs_tuple else tensors
    return tf.data.Dataset.from_tensor_slices(slices)


def _parse_example(value):
    parsed = tf.parse_single_example(
        value, features={'sequence': tf.io.VarLenFeature(tf.int64)})
    sequence = tf.sparse.to_dense(parsed['sequence'])
Exemplo n.º 8
0
Built for the TrEMBL dataset.
"""
import os
import types
from absl import logging
import gin
import numpy as np
import tensorflow.compat.v1 as tf

from protein_lm import domains

# Module-level domain for protein sequences: a ProteinVocab that includes the
# anomalous amino acids plus BOS, EOS, PAD and MASK tokens, with `length` set
# to 1024.
protein_domain = domains.VariableLengthDiscreteDomain(
    vocab=domains.ProteinVocab(include_anomalous_amino_acids=True,
                               include_bos=True,
                               include_eos=True,
                               include_pad=True,
                               include_mask=True),
    length=1024)  # TODO(ddohan): Make a `make_protein_domain` fn.


def dataset_from_tensors(tensors):
    """Builds a `tf.data.Dataset` by slicing `tensors`.

    Generators and lists are first converted to a tuple; any other input is
    handed to `tf.data.Dataset.from_tensor_slices` unchanged.

    Args:
      tensors: Nested tf.Tensors or np.ndarrays, or a generator/list of them.

    Returns:
      The result of `tf.data.Dataset.from_tensor_slices`.
    """
    needs_tuple = isinstance(tensors, (types.GeneratorType, list))
    slices = tuple(tensors) if needs_tuple else tensors
    return tf.data.Dataset.from_tensor_slices(slices)


def _parse_example(value):
    parsed = tf.parse_single_example(
        value, features={'sequence': tf.io.VarLenFeature(tf.int64)})