def get_embed_fn(model=None,
                 ckpt_dir=None,
                 output_head='output_emb',
                 reduce_fn=None,
                 length=128):
  """Get a function which maps lists of strings to embeddings.

  Args:
    model: Pretrained model.
    ckpt_dir: Directory to load if model is None.
    output_head: Which model output to return. See embed.FlaxLM.
    reduce_fn: Postprocessing function to apply on top of embeddings, such as
      `partial(jax.numpy.sum, axis=-2)`.
    length: If given, use fixed-length batches; otherwise, batches use the
      length of the longest string in the batch.

  Returns:
    Function which accepts a list of strings and returns batched embeddings.
  """
  if model is None:
    if ckpt_dir is None:
      raise ValueError('Must provide a loaded model or checkpoint directory.')
    model = models.load_model(ckpt_dir=ckpt_dir)
  else:
    if ckpt_dir is not None:
      raise ValueError('Provide only one of `model` or checkpoint directory.')

  def predict_fn(model_target, inputs):
    emb = models.predict_step(
        model_target,
        inputs,
        preprocess_fn=model.preprocess,
        output_head=output_head)
    if reduce_fn:
      emb = reduce_fn(emb)
    return emb

  p_predict_step = jax.pmap(predict_fn, axis_name='batch')

  def _embed(protein_sequences):
    """Encode proteins into a batch, embed, and run reduce_fn on output."""
    batch = make_batch(protein_sequences, length=length)
    batch = common_utils.shard(batch)
    result = p_predict_step(model.optimizer.target, batch)
    # Combine the leading two dimensions (n_devices, batch_size // n_devices).
    result = jax.numpy.reshape(result, [-1] + list(result.shape[2:]))
    return result

  return _embed
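# A minimal usage sketch for get_embed_fn above, assuming a hypothetical
# checkpoint directory ('/tmp/flaxlm_ckpt' is not a real path). It uses
# `functools.partial(jax.numpy.sum, axis=-2)` as the reduce_fn, as the
# docstring suggests, to collapse the length dimension into one vector per
# sequence.
import functools

import jax

embed_fn = get_embed_fn(
    ckpt_dir='/tmp/flaxlm_ckpt',  # Hypothetical checkpoint location.
    reduce_fn=functools.partial(jax.numpy.sum, axis=-2),
    length=128)
# Each row of `embeddings` is a fixed-size vector for one input sequence.
embeddings = embed_fn(['MKTAYIAKQR', 'MLEIRT'])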
def test_load_model(self):
  with gin.config_scope('test'):
    for k, v in lm_cfg.items():
      gin.bind_parameter('FlaxLM.%s' % k, v)
    lm = models.FlaxLM(domain=_test_domain(), random_seed=1)
    save_dir = self._tmpdir / 'save_ckpt'
    lm.save_checkpoint(save_dir)
    config_str = gin.operative_config_str()
    with tf.gfile.GFile(str(save_dir / 'config.gin'), 'w') as f:
      f.write(config_str)
    loaded_model = models.load_model(save_dir, model_cls=models.FlaxLM)
    self.assertAllEqual(
        lm.optimizer.target.params['embed']['embedding'],
        loaded_model.optimizer.target.params['embed']['embedding'])
def get_embed_fn(model=None,
                 checkpoint_dir=None,
                 domain=None,
                 output_head='output_emb',
                 reduce_fn=None,
                 length=None):
  """Get a function that maps sequences to fixed-length embedding vectors.

  Args:
    model: A FlaxModel (e.g. FlaxLM or FlaxBERT).
    checkpoint_dir: A string directory where the model checkpoint is stored.
    domain: An instance of VariableLengthDiscreteDomain.
    output_head: Which model output to return. See embed.FlaxModel.
    reduce_fn: Postprocessing function to apply on top of embeddings, such as
      `masked_reduce_fn`. The reduce_fn takes as input padded embeddings and
      padded inputs (to allow masking the pad dimensions). If None, no
      reduction is made.
    length: Input sequences will be cropped and padded to have length
      N = min(max_len, length), where max_len is the length of the longest
      sequence in the input data. If length is None, domain.length is used
      when computing N.

  Returns:
    Function which accepts sequences and returns batched embeddings. If the
    sequences are strings, we first encode them into the domain. Otherwise,
    we assume that they are already encoded.
  """
  if model is None:
    if checkpoint_dir is None:
      raise ValueError('Must provide a loaded model or checkpoint directory.')
    # Note that this assumes that the model_cls is stored in the config dict.
    model = models.load_model(checkpoint_dir=checkpoint_dir)
  else:
    if checkpoint_dir is not None:
      raise ValueError('Provide only one of `model` or checkpoint directory.')

  if domain is None:
    domain = data.protein_domain

  def predict_fn(model_target, inputs):
    emb = models.predict_step(
        model_target,
        inputs,
        preprocess_fn=model.preprocess,
        output_head=output_head)
    if reduce_fn:
      # Pass the inputs to allow padding-aware aggregation.
      emb = reduce_fn(emb, inputs)
    return emb

  if model.pmap:
    p_predict_step = jax.pmap(predict_fn, axis_name='batch')
  else:
    p_predict_step = predict_fn

  def _embed(protein_sequences):
    """Encode proteins into a batch, embed, and run reduce_fn on output."""
    if isinstance(protein_sequences[0], str):
      batch = _encode_string_sequences(
          protein_sequences, domain=domain, length=length)
    else:
      # Require that every int-encoded sequence is a valid member of the
      # domain (`.all()`, not `.any()`, which would only reject the batch
      # when no sequence at all was valid).
      if not domain.are_valid(protein_sequences).all():
        raise ValueError('Input int-encoded sequences are not valid members '
                         'of input domain.')
      batch = protein_sequences

    if model.pmap:
      batch = common_utils.shard(batch)
    result = p_predict_step(model.optimizer.target, batch)
    if model.pmap:
      # Combine the leading two dimensions (n_devices, batch_size // n_devices).
      result = jax.numpy.reshape(result, [-1] + list(result.shape[2:]))
    return result

  return _embed
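# The docstring above mentions `masked_reduce_fn`, whose body is not shown in
# this snippet. Below is one plausible, self-contained sketch of such a
# padding-aware reduce_fn: it averages embeddings over non-pad positions only.
# The pad token id is taken as an explicit argument; `pad_id=0` is an
# assumption for illustration, not the library's actual value, and
# `mean_masked_reduce_fn` is a hypothetical name.
import functools

import jax.numpy as jnp


def mean_masked_reduce_fn(embeddings, inputs, pad_id=0):
  # embeddings: [batch, length, dim] floats; inputs: [batch, length] int ids.
  mask = (inputs != pad_id).astype(embeddings.dtype)  # 1.0 at real tokens.
  summed = jnp.sum(embeddings * mask[..., None], axis=-2)  # [batch, dim]
  # Guard against division by zero for all-pad rows.
  counts = jnp.maximum(jnp.sum(mask, axis=-1, keepdims=True), 1.0)
  return summed / counts  # Mean over unpadded positions only.


# Usage with get_embed_fn above (the checkpoint path is hypothetical). The
# partial binds pad_id, so predict_fn's call `reduce_fn(emb, inputs)` matches
# the (embeddings, inputs) signature expected here.
embed_fn = get_embed_fn(
    checkpoint_dir='/tmp/flaxbert_ckpt',
    reduce_fn=functools.partial(mean_masked_reduce_fn, pad_id=0))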