Example #1
0
def _get_model(domain):
    """Return a single-layer, single-head FlaxLM over `domain`.

    The query/key/value, embedding and MLP widths are all set to 32.

    Args:
      domain: Forwarded to `models.FlaxLM` as the `domain=` keyword.

    Returns:
      The newly constructed `models.FlaxLM`.
    """
    width = 32
    tiny_config = {
        'num_layers': 1,
        'num_heads': 1,
        'qkv_dim': width,
        'emb_dim': width,
        'mlp_dim': width,
    }
    return models.FlaxLM(domain=domain, **tiny_config)
Example #2
0
def _get_lm(domain, use_dropout=True):
    """Return a two-layer, two-head FlaxLM over `domain` with batch size 1.

    Args:
      domain: Passed positionally as the first argument of `models.FlaxLM`.
      use_dropout: If False, `dropout_rate` and `attention_dropout_rate` are
        both passed as 0. If True, neither is passed, so FlaxLM falls back to
        whatever values it would otherwise use.

    Returns:
      The newly constructed `models.FlaxLM`.
    """
    lm_kwargs = {
        'batch_size': 1,
        'num_layers': 2,
        'num_heads': 2,
        'emb_dim': 32,
        'mlp_dim': 32,
        'qkv_dim': 32,
    }
    # Presumably used by callers that need dropout-free (deterministic) outputs.
    dropout_overrides = (
        {} if use_dropout
        else {'dropout_rate': 0., 'attention_dropout_rate': 0.})
    return models.FlaxLM(domain, **lm_kwargs, **dropout_overrides)
Example #3
0
  def test_load_model(self):
    """Round-trips a FlaxLM through a checkpoint directory plus gin config.

    Inside the 'test' gin scope, binds every `lm_cfg` entry onto `FlaxLM` and
    builds a seeded model. It then writes the model's checkpoint and the
    operative gin config into one directory and reloads it with
    `models.load_model`. Finally it checks that the embedding table survived
    unchanged.
    """
    with gin.config_scope('test'):
      for param_name, param_value in lm_cfg.items():
        gin.bind_parameter('FlaxLM.%s' % param_name, param_value)

      original = models.FlaxLM(domain=_test_domain(), random_seed=1)

      checkpoint_dir = self._tmpdir / 'save_ckpt'
      original.save_checkpoint(checkpoint_dir)
      # Per gin's docs, the operative config only reflects configurables that
      # have actually been called, so it is captured after FlaxLM is built.
      operative_config = gin.operative_config_str()
      gin_path = str(checkpoint_dir / 'config.gin')
      with tf.gfile.GFile(gin_path, 'w') as gin_file:
        gin_file.write(operative_config)

      restored = models.load_model(checkpoint_dir, model_cls=models.FlaxLM)
      original_table = original.optimizer.target.params['embed']['embedding']
      restored_table = restored.optimizer.target.params['embed']['embedding']
      self.assertAllEqual(original_table, restored_table)

  def test_embed_with_loaded_model(self):
    """Embeds two proteins from both integer-encoded and raw-string inputs.

    Both embedding paths sum over axis -2 and must yield one `emb_dim`-wide
    vector per input sequence.

    NOTE(review): despite the name, no model is loaded here; a FlaxLM is
    constructed directly.
    """
    seq_a, seq_b = 'ACDEFHIKLNQP', 'ALNQP'
    tokens = data.protein_domain.encode([seq_a, seq_b])
    emb_dim = 32
    model = models.FlaxLM(
        domain=data.protein_domain,
        num_layers=1,
        num_heads=1,
        qkv_dim=32,
        emb_dim=emb_dim,
        mlp_dim=32)
    # Sum over axis -2 (presumably the sequence-position axis) so each input
    # reduces to a single vector.
    sum_over_positions = functools.partial(jax.numpy.sum, axis=-2)
    expected_shape = (2, emb_dim)

    # Integer-encoded sequences through the embedder class directly.
    int_embedder = embed.ProteinLMEmbedder(
        model=model,
        output_head='output_emb',
        length=128,
        reduce_fn=sum_over_positions)
    self.assertEqual(expected_shape, int_embedder(tokens).shape)

    # Raw strings through the convenience factory.
    str_embedder = embed.get_embed_fn(model=model, reduce_fn=sum_over_positions)
    self.assertEqual(expected_shape, str_embedder([seq_a, seq_b]).shape)