def _create_cluster(self, cluster_spec, index, start_token_id):
        """Creates a cluster given its spec."""
        token_count = cluster_spec["token_count"]
        embedding_size = cluster_spec.get("embedding_size",
                                          self._output_dim.size)
        length_projection_factor = cluster_spec.get("length_projection_factor",
                                                    1)
        if length_projection_factor <= 0 or length_projection_factor > 1:
            raise ValueError(
                "Invalid length_projection_factor of {}. Must be in range (0, 1]"
                .format(length_projection_factor))

        if index == 0:
            # Include the entries for the tail clusters in the head cluster "vocab".
            cluster_vocab_dim = mtf.Dimension(
                self._vocab_dim.name, token_count + self._num_clusters - 1)
        else:
            cluster_vocab_dim = mtf.Dimension(self._vocab_dim.name,
                                              token_count)

        if embedding_size == self._output_dim.size:
            # In this case we don't need to project up from the embedding space to
            # the model state space.
            cluster_embedding = transformer.VocabEmbedding(
                mesh=self._mesh,
                vocab_dim=cluster_vocab_dim,
                output_dim=self._output_dim,
                variable_dtype=self._variable_dtype,
                name="{}_{}".format(self._name, index),
                ensemble_dim=self._ensemble_dim)
        else:
            cluster_embedding = vocab_embeddings.FactorizedVocabEmbedding(
                mesh=self._mesh,
                vocab_dim=cluster_vocab_dim,
                output_dim=self._output_dim,
                variable_dtype=self._variable_dtype,
                name="{}_{}".format(self._name, index),
                ensemble_dim=self._ensemble_dim,
                inner_dimension_size=embedding_size)
        return _Cluster(embedding=cluster_embedding,
                        start_token_id=start_token_id,
                        end_token_id=start_token_id + token_count,
                        length_projection_factor=length_projection_factor,
                        vocab_dim=cluster_vocab_dim)
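# A minimal sketch (not from the source) of how a caller might drive the
# _create_cluster method above. The helper name _build_clusters and the
# cluster_specs argument are assumptions for illustration: a running
# start_token_id is threaded through consecutive clusters so that their
# token id ranges tile the vocabulary without gaps.
def _build_clusters(self, cluster_specs):
    """Hypothetical helper: builds one _Cluster per spec, in token id order."""
    self._clusters = []
    start_token_id = 0
    for index, cluster_spec in enumerate(cluster_specs):
        self._clusters.append(
            self._create_cluster(cluster_spec, index, start_token_id))
        start_token_id += cluster_spec["token_count"]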
Example #2
  def __init__(self,
               mesh,
               vocab_dim,
               output_dim,
               variable_dtype,
               name,
               ensemble_dim,
               clusters=gin.REQUIRED):
    """Configurable embedding for the vocabulary.

    Most of the arguments get passed to `mtf.layers.embedding_weights`.

    The clustering parameters are specified by the `clusters` argument. It is a
    list of dicts with keys "token_count" and "embedding_size". Token count
    specifies the number of tokens in the cluster, and embedding size specifies
    the hidden dimension size of its embedding.

    For example, let's say we have a vocab size of 500k and pass as clusters:
      [
        {"token_count": 50000,  "embedding_size": 1024},
        {"token_count": 100000, "embedding_size": 256},
        {"token_count": 350000, "embedding_size": 64},
      ]
    Then tokens with ids 0 (inclusive) to 50k (exclusive) will be in the first
    cluster with embedding size of 1024, tokens with ids 50k to 150k will be in
    the second cluster with embedding size of 256, and tokens with ids 150k to
    500k will be in the third cluster with embedding size of 64.

    Args:
      mesh: a mtf.Mesh
      vocab_dim: a mtf.Dimension
      output_dim: a mtf.Dimension
      variable_dtype: a mtf.VariableDType
      name: a string
      ensemble_dim: a mtf.Dimension
      clusters: a list(dict), specification of the clusters

    Raises:
      ValueError: The sum of the token counts across the clusters does not equal
        the vocabulary size.
    """
    self._vocab_dim = vocab_dim
    self._output_dim = output_dim

    token_counts = [cluster["token_count"] for cluster in clusters]
    if sum(token_counts) != vocab_dim.size:
      raise ValueError(
          "The cluster token counts {} do not sum to the vocab size {}.".format(
              token_counts, vocab_dim.size))

    self._clusters = []
    start_token_id = 0
    for i, cluster in enumerate(clusters):
      token_count = cluster["token_count"]
      embedding_size = cluster["embedding_size"]
      cluster_vocab_dim = mtf.Dimension(vocab_dim.name, token_count)

      if embedding_size == self._output_dim.size:
        # In this case we don't need to project up from the embedding space to
        # the model state space.
        cluster_embedding = transformer.VocabEmbedding(
            mesh=mesh,
            vocab_dim=cluster_vocab_dim,
            output_dim=output_dim,
            variable_dtype=variable_dtype,
            name="{}_{}".format(name, i),
            ensemble_dim=ensemble_dim)
      else:
        cluster_embedding = FactorizedVocabEmbedding(
            mesh=mesh,
            vocab_dim=cluster_vocab_dim,
            output_dim=output_dim,
            variable_dtype=variable_dtype,
            name="{}_{}".format(name, i),
            ensemble_dim=ensemble_dim,
            inner_dimension_size=embedding_size)
      self._clusters.append(
          _Cluster(
              embedding=cluster_embedding,
              start_token_id=start_token_id,
              end_token_id=start_token_id + token_count))
      start_token_id += token_count
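# A minimal sketch (not from the source) of a cluster specification matching
# the docstring example above: three clusters whose token counts sum to the
# 500k vocab size, with progressively smaller embedding sizes for rarer tokens.
clusters = [
    {"token_count": 50000, "embedding_size": 1024},   # ids [0, 50k): full model width
    {"token_count": 100000, "embedding_size": 256},   # ids [50k, 150k): factorized
    {"token_count": 350000, "embedding_size": 64},    # ids [150k, 500k): heavily factorized
]
assert sum(c["token_count"] for c in clusters) == 500000  # must equal vocab_dim.size

# Because the constructor declares `clusters=gin.REQUIRED`, this list would
# normally be supplied through a gin binding on the enclosing class; the class
# name is not shown in this snippet, so any binding target here would be a
# placeholder.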