def _get_model(domain):
  return models.FlaxLM(
      domain=domain,
      num_layers=1,
      num_heads=1,
      qkv_dim=32,
      emb_dim=32,
      mlp_dim=32)
def _get_lm(domain, use_dropout=True):
  cfg = dict(
      batch_size=1,
      num_layers=2,
      num_heads=2,
      emb_dim=32,
      mlp_dim=32,
      qkv_dim=32)
  if not use_dropout:
    cfg.update(dict(dropout_rate=0., attention_dropout_rate=0.))
  return models.FlaxLM(domain, **cfg)
def test_load_model(self):
  with gin.config_scope('test'):
    # Bind the test config, build a model, and save a checkpoint
    # alongside its operative gin config.
    for k, v in lm_cfg.items():
      gin.bind_parameter('FlaxLM.%s' % k, v)
    lm = models.FlaxLM(domain=_test_domain(), random_seed=1)
    save_dir = self._tmpdir / 'save_ckpt'
    lm.save_checkpoint(save_dir)
    config_str = gin.operative_config_str()
    with tf.gfile.GFile(str(save_dir / 'config.gin'), 'w') as f:
      f.write(config_str)

    # Reload from the checkpoint directory and check the embedding
    # parameters match the original model.
    loaded_model = models.load_model(save_dir, model_cls=models.FlaxLM)
    self.assertAllEqual(
        lm.optimizer.target.params['embed']['embedding'],
        loaded_model.optimizer.target.params['embed']['embedding'])
def test_embed_with_loaded_model(self):
  p1 = 'ACDEFHIKLNQP'
  p2 = 'ALNQP'
  encoded = data.protein_domain.encode([p1, p2])
  model = models.FlaxLM(
      domain=data.protein_domain,
      num_layers=1,
      num_heads=1,
      qkv_dim=32,
      emb_dim=32,
      mlp_dim=32)
  reduction = functools.partial(jax.numpy.sum, axis=-2)

  # Check we can embed int-encoded sequences.
  embed_with_sum_fn = embed.ProteinLMEmbedder(
      model=model, output_head='output_emb', length=128, reduce_fn=reduction)
  int_embs = embed_with_sum_fn(encoded)
  self.assertEqual((2, 32), int_embs.shape)

  # Check we can embed strings.
  embed_strings_fn = embed.get_embed_fn(model=model, reduce_fn=reduction)
  str_embs = embed_strings_fn([p1, p2])
  self.assertEqual((2, 32), str_embs.shape)