Example #1
0
def get_embed_fn(model=None,
                 ckpt_dir=None,
                 output_head='output_emb',
                 reduce_fn=None,
                 length=128):
    """Get a function which maps lists of strings to embeddings.

    Args:
      model: Pretrained model. Exactly one of `model` and `ckpt_dir` must be
        given.
      ckpt_dir: Directory to load the model from if `model` is None.
      output_head: Which model output to return. See embed.FlaxLM
      reduce_fn: Postprocessing function to apply on top of embeddings, such as
        `partial(jax.numpy.sum, axis=-2)`. If None, no reduction is made.
      length: Fixed length of the batches built by `make_batch` (default 128).
        Pass None to use the length of the longest string in the batch.

    Returns:
      Function which accepts a list of strings, and returns batched embeddings
      with the per-device leading dimensions merged back into one batch axis.

    Raises:
      ValueError: If neither or both of `model` and `ckpt_dir` are given.
    """
    if model is None:
        if ckpt_dir is None:
            raise ValueError(
                'Must provide a loaded model or checkpoint directory.')
        # Pass the directory positionally so the call does not depend on the
        # keyword name of load_model's directory parameter (other callers use
        # `checkpoint_dir=` or a positional path).
        model = models.load_model(ckpt_dir)
    elif ckpt_dir is not None:
        raise ValueError(
            'Provide only one of `model` or checkpoint directory.')

    def predict_fn(model_target, inputs):
        """Run the model on one (per-device) batch and optionally reduce."""
        emb = models.predict_step(model_target,
                                  inputs,
                                  preprocess_fn=model.preprocess,
                                  output_head=output_head)
        if reduce_fn is not None:
            emb = reduce_fn(emb)
        return emb

    p_predict_step = jax.pmap(predict_fn, axis_name='batch')

    def _embed(protein_sequences):
        """Encode proteins into a batch, embed, and run reduce_fn on output."""
        batch = make_batch(protein_sequences, length=length)
        batch = common_utils.shard(batch)
        # The target is read on every call rather than captured once, so the
        # current weights of `model` are always used.
        result = p_predict_step(model.optimizer.target, batch)

        # Combine leading two dimensions (ndevices, batch_size / n_devices)
        result = jax.numpy.reshape(result, [-1] + list(result.shape[2:]))
        return result

    return _embed
Example #2
0
  def test_load_model(self):
    """A saved FlaxLM checkpoint reloads with identical embedding weights."""
    with gin.config_scope('test'):
      # Bind each entry of lm_cfg as a FlaxLM gin parameter.
      for name, value in lm_cfg.items():
        gin.bind_parameter('FlaxLM.%s' % name, value)

      original = models.FlaxLM(domain=_test_domain(), random_seed=1)

      # Save the checkpoint, then write the operative gin config next to it
      # (presumably read back by load_model — confirm in models.load_model).
      ckpt_dir = self._tmpdir / 'save_ckpt'
      original.save_checkpoint(ckpt_dir)
      operative_config = gin.operative_config_str()
      config_path = str(ckpt_dir / 'config.gin')
      with tf.gfile.GFile(config_path, 'w') as config_file:
        config_file.write(operative_config)

      restored = models.load_model(ckpt_dir, model_cls=models.FlaxLM)
      expected = original.optimizer.target.params['embed']['embedding']
      actual = restored.optimizer.target.params['embed']['embedding']
      self.assertAllEqual(expected, actual)
Example #3
0
def get_embed_fn(model=None,
                 checkpoint_dir=None,
                 domain=None,
                 output_head='output_emb',
                 reduce_fn=None,
                 length=None):
  """Get a function that maps sequences to fixed-length embedding vectors.

  Args:
    model: A FlaxModel (e.g. FlaxLM or FlaxBERT).
    checkpoint_dir: A string directory where the model checkpoint is stored.
      Exactly one of `model` and `checkpoint_dir` must be given.
    domain: An instance of VariableLengthDiscreteDomain. Defaults to
      `data.protein_domain`.
    output_head: Which model output to return. See embed.FlaxModel.
    reduce_fn: Postprocessing function to apply on top of embeddings, such as
      `masked_reduce_fn`. The reduce_fn takes as input the padded embeddings
      and the padded inputs (to allow masking the pad positions). If None, no
      reduction is made.
    length: Input sequences will be cropped and padded to have length
      N = min(max_len, length), where max_len is the length of the longest
      sequence in the input data. If length is None, domain.length is used when
      computing N.

  Returns:
    Function which accepts a non-empty batch of sequences and returns batched
    embeddings. If the sequences are strings, they are first encoded into the
    domain. Otherwise, they are assumed to be already encoded, and every one of
    them must be a valid member of the domain.

  Raises:
    ValueError: If neither or both of `model` and `checkpoint_dir` are given.
      The returned function raises ValueError for an empty batch, or when the
      int-encoded sequences are not all valid members of `domain`.
  """
  if model is None:
    if checkpoint_dir is None:
      raise ValueError('Must provide a loaded model or checkpoint directory.')
    # Note that this assumes that the model_cls is stored in the config dict.
    model = models.load_model(checkpoint_dir=checkpoint_dir)
  elif checkpoint_dir is not None:
    raise ValueError('Provide only one of `model` or checkpoint directory.')

  if domain is None:
    domain = data.protein_domain

  # Decide once whether to run data-parallel. The predict function chosen here
  # and the shard/unshard steps in _embed must always agree.
  use_pmap = model.pmap

  def predict_fn(model_target, inputs):
    """Run the model on one batch and optionally apply reduce_fn."""
    emb = models.predict_step(
        model_target,
        inputs,
        preprocess_fn=model.preprocess,
        output_head=output_head)

    if reduce_fn is not None:
      # Pass the inputs to allow padding-aware aggregation.
      emb = reduce_fn(emb, inputs)
    return emb

  if use_pmap:
    p_predict_step = jax.pmap(predict_fn, axis_name='batch')
  else:
    p_predict_step = predict_fn

  def _embed(protein_sequences):
    """Encode proteins into a batch, embed, and run reduce_fn on output."""
    # len() rather than truthiness: `not array` is ambiguous for arrays.
    if len(protein_sequences) == 0:
      raise ValueError('Cannot embed an empty batch of sequences.')

    if isinstance(protein_sequences[0], str):
      batch = _encode_string_sequences(protein_sequences,
                                       domain=domain, length=length)
    else:
      # Reject the batch if *any* sequence is invalid. Checking `.any()` would
      # let a batch through as long as a single sequence was valid.
      if not domain.are_valid(protein_sequences).all():
        raise ValueError('Input int-encoded sequences are not valid members '
                         'of input domain.')
      batch = protein_sequences

    if use_pmap:
      batch = common_utils.shard(batch)
    result = p_predict_step(model.optimizer.target, batch)

    if use_pmap:
      # Combine the leading two dimensions (ndevices, batch_size / n_devices)
      result = jax.numpy.reshape(result, [-1] + list(result.shape[2:]))
    return result

  return _embed