def _create_cluster(self, cluster_spec, index, start_token_id):
  """Creates a cluster given its spec."""
  token_count = cluster_spec["token_count"]
  embedding_size = cluster_spec.get("embedding_size", self._output_dim.size)
  length_projection_factor = cluster_spec.get("length_projection_factor", 1)
  if length_projection_factor <= 0 or length_projection_factor > 1:
    raise ValueError(
        "Invalid length_projection_factor of {}. Must be in range (0, 1]"
        .format(length_projection_factor))

  if index == 0:
    # Include the entries for the tail clusters in the head cluster "vocab".
    cluster_vocab_dim = mtf.Dimension(
        self._vocab_dim.name, token_count + self._num_clusters - 1)
  else:
    cluster_vocab_dim = mtf.Dimension(self._vocab_dim.name, token_count)

  if embedding_size == self._output_dim.size:
    # In this case we don't need to up project from the embedding space to
    # the model state space.
    cluster_embedding = transformer.VocabEmbedding(
        mesh=self._mesh,
        vocab_dim=cluster_vocab_dim,
        output_dim=self._output_dim,
        variable_dtype=self._variable_dtype,
        name="{}_{}".format(self._name, index),
        ensemble_dim=self._ensemble_dim)
  else:
    cluster_embedding = vocab_embeddings.FactorizedVocabEmbedding(
        mesh=self._mesh,
        vocab_dim=cluster_vocab_dim,
        output_dim=self._output_dim,
        variable_dtype=self._variable_dtype,
        name="{}_{}".format(self._name, index),
        ensemble_dim=self._ensemble_dim,
        inner_dimension_size=embedding_size)
  return _Cluster(
      embedding=cluster_embedding,
      start_token_id=start_token_id,
      end_token_id=start_token_id + token_count,
      length_projection_factor=length_projection_factor,
      vocab_dim=cluster_vocab_dim)
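
# Illustrative sketch (an assumption added for this excerpt, not original
# source): how the head-cluster sizing above works out. With three clusters
# in total (self._num_clusters == 3) and a head spec of token_count=50000,
# the head cluster's "vocab" dimension gets 50000 + 3 - 1 = 50002 entries;
# the two extra entries stand in for the two tail clusters, in the style of
# an adaptive-softmax head. Hypothetical call with made-up numbers:
#
#   head = self._create_cluster(
#       cluster_spec={"token_count": 50000, "embedding_size": 1024},
#       index=0,
#       start_token_id=0)
#   assert head.vocab_dim.size == 50002
#   assert head.end_token_id == 50000
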
def __init__(self,
             mesh,
             vocab_dim,
             output_dim,
             variable_dtype,
             name,
             ensemble_dim,
             clusters=gin.REQUIRED):
  """Configurable embedding for the vocabulary.

  Most of the arguments get passed to `mtf.layers.embedding_weights`.

  The clustering parameters are specified by the `clusters` argument. It is
  a list of dicts with keys "token_count" and "embedding_size". Token count
  specifies the number of tokens in the cluster, and embedding size
  specifies the hidden dimension size of its embedding.

  For example, let's say we have a vocab size of 500k and pass as clusters:
    [
        {"token_count": 50000, "embedding_size": 1024},
        {"token_count": 100000, "embedding_size": 256},
        {"token_count": 350000, "embedding_size": 64},
    ]
  Then tokens with ids 0 (inclusive) to 50k (exclusive) will be in the first
  cluster with an embedding size of 1024, tokens with ids 50k to 150k will be
  in the second cluster with an embedding size of 256, and tokens with ids
  150k to 500k will be in the third cluster with an embedding size of 64.

  Args:
    mesh: a mtf.Mesh
    vocab_dim: a mtf.Dimension
    output_dim: a mtf.Dimension
    variable_dtype: a mtf.VariableDType
    name: a string
    ensemble_dim: a mtf.Dimension
    clusters: a list(dict), specification of the clusters

  Raises:
    ValueError: The sum of the token counts across the clusters does not
      equal the vocabulary size.
  """
  self._vocab_dim = vocab_dim
  self._output_dim = output_dim

  token_counts = [cluster["token_count"] for cluster in clusters]
  if sum(token_counts) != vocab_dim.size:
    raise ValueError(
        "The cluster token counts {} do not sum to the vocab size {}.".format(
            token_counts, vocab_dim.size))

  self._clusters = []
  start_token_id = 0
  for i, cluster in enumerate(clusters):
    token_count = cluster["token_count"]
    embedding_size = cluster["embedding_size"]
    cluster_vocab_dim = mtf.Dimension(vocab_dim.name, token_count)

    if embedding_size == self._output_dim.size:
      # In this case we don't need to up project from the embedding space to
      # the model state space.
      cluster_embedding = transformer.VocabEmbedding(
          mesh=mesh,
          vocab_dim=cluster_vocab_dim,
          output_dim=output_dim,
          variable_dtype=variable_dtype,
          name="{}_{}".format(name, i),
          ensemble_dim=ensemble_dim)
    else:
      cluster_embedding = FactorizedVocabEmbedding(
          mesh=mesh,
          vocab_dim=cluster_vocab_dim,
          output_dim=output_dim,
          variable_dtype=variable_dtype,
          name="{}_{}".format(name, i),
          ensemble_dim=ensemble_dim,
          inner_dimension_size=embedding_size)
    self._clusters.append(
        _Cluster(
            embedding=cluster_embedding,
            start_token_id=start_token_id,
            end_token_id=start_token_id + token_count))
    start_token_id += token_count
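
# Usage sketch (an assumption; the enclosing class name and its gin binding
# are not shown in this excerpt, so "ClusteredVocabEmbedding" below is only a
# placeholder). The cluster token counts must sum to vocab_dim.size, matching
# the check in __init__ above:
#
#   import mesh_tensorflow as mtf
#
#   vocab_dim = mtf.Dimension("vocab", 500000)
#   output_dim = mtf.Dimension("d_model", 1024)
#   embedding = ClusteredVocabEmbedding(  # placeholder class name
#       mesh=mesh,
#       vocab_dim=vocab_dim,
#       output_dim=output_dim,
#       variable_dtype=variable_dtype,
#       name="vocab_embedding",
#       ensemble_dim=None,
#       clusters=[
#           {"token_count": 50000, "embedding_size": 1024},
#           {"token_count": 100000, "embedding_size": 256},
#           {"token_count": 350000, "embedding_size": 64},
#       ])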