def __call__(self, ids, lookup_style=None):
    """Lookup embeddings.

    Looks up an embedding vector for each value in `ids`. All ids must be
    within [0, vocab_size). This is not checked here, so out-of-range ids may
    silently produce incorrect results (the original note warned of NaNs).

    Args:
      ids: Integer array of ids to look up.
      lookup_style: Overrides the lookup_style given in the constructor. May be
        an `EmbedLookupStyle` member or the name of one (e.g. "ONE_HOT").

    Returns:
      Array of shape `ids.shape + (embedding_dim,)`. This assumes the embedding
      matrix has shape `[vocab_size, embedding_dim]`, which is what the
      one-hot branch's broadcasting and `sum(axis=-2)` require.

    Raises:
      ValueError: If lookup_style is not an enum in EmbedLookupStyle or if
        `ids` is not an integer array.
    """
    lookup_style = lookup_style or self._lookup_style
    # The dispatch below compares against member *names*. Normalize enum
    # members to their names so that passing `EmbedLookupStyle.ONE_HOT`
    # (rather than "ONE_HOT") is accepted instead of raising ValueError.
    if isinstance(lookup_style, EmbedLookupStyle):
        lookup_style = lookup_style.name

    ids = jnp.asarray(ids)
    if not jnp.issubdtype(ids.dtype, jnp.integer):
        raise ValueError(
            "hk.Embed's __call__ method must take an array of "
            "integer dtype but was called with an array of "
            "{dtype}".format(dtype=ids.dtype))

    if lookup_style == EmbedLookupStyle.ARRAY_INDEX.name:
        # Gather rows of the embedding matrix directly.
        return self._embedding[ids]
    elif lookup_style == EmbedLookupStyle.ONE_HOT.name:
        # One-hot encode ids with a trailing singleton axis so the product
        # broadcasts against the embedding matrix. Summing over the vocab axis
        # (-2) then selects one row per id.
        one_hot_ids = basic.one_hot(ids, self._vocab_size)[..., None]
        return (self._embedding * one_hot_ids).sum(axis=-2)
    else:
        raise ValueError(
            "{s} is not a valid enum in EmbedLookupStyle.".format(
                s=lookup_style))
def test_onehot_shape(self):
    """one_hot appends a trailing axis of size num_classes to the input shape."""
    batch_shape = (2, 3, 4)
    num_classes = 24
    # 24 distinct (float-typed) index values laid out over a 2x3x4 batch.
    flat_indices = jnp.arange(24, dtype=jnp.float32)
    encoded = basic.one_hot(flat_indices.reshape(batch_shape),
                            num_classes=num_classes)
    expected_shape = batch_shape + (num_classes,)
    self.assertEqual(encoded.shape, expected_shape)