def setUp(self):
  # Build a small test domain and model so the embedding fn under test is cheap to run.
  self._domain = data.make_protein_domain(length=12)
  self._model = _get_model(self._domain)
  self._embed_fn = embed.get_embed_fn(
      model=self._model,
      domain=self._domain,
      reduce_fn=embed.masked_reduce_fn)
  super().setUp()
def test_embed_with_loaded_model(self):
  p1 = 'ACDEFHIKLNQP'
  p2 = 'ALNQP'
  encoded = data.protein_domain.encode([p1, p2])
  model = models.FlaxLM(
      domain=data.protein_domain,
      num_layers=1,
      num_heads=1,
      qkv_dim=32,
      emb_dim=32,
      mlp_dim=32)
  reduction = functools.partial(jax.numpy.sum, axis=-2)

  # Check we can embed int-encoded sequences.
  embed_with_sum_fn = embed.ProteinLMEmbedder(
      model=model, output_head='output_emb', length=128, reduce_fn=reduction)
  int_embs = embed_with_sum_fn(encoded)
  self.assertEqual((2, 32), int_embs.shape)

  # Check we can embed strings.
  embed_strings_fn = embed.get_embed_fn(model=model, reduce_fn=reduction)
  str_embs = embed_strings_fn([p1, p2])
  self.assertEqual((2, 32), str_embs.shape)