Ejemplo n.º 1
0
    def __call__(self, ids, lookup_style=None):
        """Look up an embedding vector for each value in `ids`.

        Ids are expected to lie within [0, vocab_size). This method does not
        check that range; results for out-of-range ids are not meaningful
        (NOTE(review): the previous docstring said they can propagate NaNs —
        confirm for each lookup style).

        Args:
          ids: Array-like of integer dtype; converted with `jnp.asarray`
            before use.
          lookup_style: Overrides the lookup style given in the constructor.
            May be an `EmbedLookupStyle` member or the member's string name
            (e.g. "ARRAY_INDEX" or "ONE_HOT"). A falsy value (such as `None`)
            falls back to the constructor's style.

        Returns:
          Array of shape `ids.shape + (embedding_dim,)`.

        Raises:
          ValueError: If `ids` is not of integer dtype, or if `lookup_style`
            is neither an `EmbedLookupStyle` member nor the name of one.
        """
        lookup_style = lookup_style or self._lookup_style
        # The dispatch below compares against member *names*. Normalise enum
        # members to their names so that passing e.g. EmbedLookupStyle.ONE_HOT
        # works instead of falling through to the ValueError branch.
        if isinstance(lookup_style, EmbedLookupStyle):
            lookup_style = lookup_style.name

        ids = jnp.asarray(ids)
        if not jnp.issubdtype(ids.dtype, jnp.integer):
            raise ValueError(
                "hk.Embed's __call__ method must take an array of "
                "integer dtype but was called with an array of "
                "{dtype}".format(dtype=ids.dtype))

        if lookup_style == EmbedLookupStyle.ARRAY_INDEX.name:
            # Gather the rows of the embedding matrix selected by `ids`.
            return self._embedding[ids]
        elif lookup_style == EmbedLookupStyle.ONE_HOT.name:
            # One-hot encode with a trailing singleton axis so the mask
            # broadcasts against the embedding matrix. Summing out the
            # vocabulary axis (-2) then selects one row per id via a dense
            # multiply-and-sum instead of a gather.
            one_hot_ids = basic.one_hot(ids, self._vocab_size)[..., None]
            return (self._embedding * one_hot_ids).sum(axis=-2)
        else:
            raise ValueError(
                "{s} is not a valid enum in EmbedLookupStyle.".format(
                    s=lookup_style))
Ejemplo n.º 2
0
 def test_onehot_shape(self):
     """`one_hot` appends a trailing axis of size `num_classes`."""
     batch_shape = (2, 3, 4)
     depth = 24
     # 2 * 3 * 4 == 24 consecutive values, laid out in the batch shape.
     ids = jnp.arange(2 * 3 * 4, dtype=jnp.float32).reshape(list(batch_shape))
     encoded = basic.one_hot(ids, num_classes=depth)
     expected_shape = batch_shape + (depth,)
     self.assertEqual(encoded.shape, expected_shape)