  def setUp(self):
    # Small shared fixtures: a short protein domain, a model over it, and an
    # embedding function that consumes raw strings.
    self._domain = data.make_protein_domain(length=12)
    self._model = _get_model(self._domain)
    self._embed_fn = embed.get_embed_fn(
        model=self._model,
        domain=self._domain,
        reduce_fn=embed.masked_reduce_fn)
    super().setUp()

  def test_embed_with_loaded_model(self):
    p1 = 'ACDEFHIKLNQP'
    p2 = 'ALNQP'
    encoded = data.protein_domain.encode([p1, p2])
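    # Build a tiny single-layer transformer LM so the test runs quickly.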
    model = models.FlaxLM(
        domain=data.protein_domain,
        num_layers=1,
        num_heads=1,
        qkv_dim=32,
        emb_dim=32,
        mlp_dim=32)
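    # Sum over the sequence axis so each protein maps to a single vector.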
    reduction = functools.partial(jax.numpy.sum, axis=-2)

    # Check that we can embed int-encoded sequences.
    embed_with_sum_fn = embed.ProteinLMEmbedder(
        model=model, output_head='output_emb', length=128, reduce_fn=reduction)
    int_embs = embed_with_sum_fn(encoded)
    self.assertEqual((2, 32), int_embs.shape)

    # Check that we can embed strings.
    embed_strings_fn = embed.get_embed_fn(model=model, reduce_fn=reduction)
    str_embs = embed_strings_fn([p1, p2])
    self.assertEqual((2, 32), str_embs.shape)