ModuleDescriptor( name="InstanceNorm", create=lambda: hk.InstanceNorm(True, True), shape=(BATCH_SIZE, 3, 2)), ModuleDescriptor( name="GroupNorm", create=lambda: hk.GroupNorm(5), shape=(BATCH_SIZE, 4, 4, 10)), ModuleDescriptor( name="LayerNorm", create=lambda: hk.LayerNorm(1, True, True), shape=(BATCH_SIZE, 3, 2)), ModuleDescriptor( name="MultiHeadAttention", create=lambda: MultiInput( # pylint: disable=g-long-lambda hk.MultiHeadAttention(num_heads=8, key_size=64, w_init_scale=1.0), num_inputs=3), shape=(BATCH_SIZE, 3, 2)), ModuleDescriptor( name="RMSNorm", create=lambda: hk.RMSNorm(1), shape=(BATCH_SIZE, 3, 2)), ModuleDescriptor( name="SpectralNorm", create=lambda: hk.SpectralNorm(), shape=(BATCH_SIZE, 3, 2)), ModuleDescriptor( name="nets.ResNet", create=lambda: Training(hk.nets.ResNet((3, 4, 6, 3), 1000)), shape=(BATCH_SIZE, 3, 3, 2)), # pylint: disable=g-long-lambda
def __call__(self,
             queries: jnp.ndarray,
             hm_memory: HierarchicalMemory,
             hm_mask: Optional[jnp.ndarray] = None) -> jnp.ndarray:
  """Do hierarchical attention over the stored memories.

  Attention happens in two stages: queries are first scored against the
  per-chunk memory keys to pick the top-k memory chunks for each query
  position, then multi-head attention is run within each selected chunk and
  the per-chunk results are combined using the softmaxed top-k scores.

  Args:
    queries: Tensor [B, Q, E] Query(ies) in, for batch size B, query length
      Q, and embedding dimension E.
    hm_memory: Hierarchical Memory. The shape checks below require
      `contents` to be [B, M, C, E_mem], `keys` to be [B, M, E_mem],
      `accumulator` to be [B, C, E_mem] and `steps_since_last_write` to be
      [B], for M memories of chunk size C.
    hm_mask: Optional boolean mask tensor of shape [B, Q, M]. Where false,
      the corresponding query timepoints cannot attend to the corresponding
      memory chunks. This can be used for enforcing causal attention on the
      learner, not attending to memories from prior episodes, etc.

  Returns:
    Value updates for each query slot: [B, Q, D]

  Raises:
    AssertionError: If the query and memory batch sizes differ, or if the
      feature size is not divisible by the number of heads. The chex
      shape/type checks additionally reject mismatched memory or mask
      arrays.
  """
  # some shape checks
  batch_size, query_length, _ = queries.shape
  (memory_batch_size, num_memories,
   memory_chunk_size, mem_embbedding_size) = hm_memory.contents.shape
  assert batch_size == memory_batch_size
  chex.assert_shape(hm_memory.keys,
                    (batch_size, num_memories, mem_embbedding_size))
  chex.assert_shape(
      hm_memory.accumulator,
      (memory_batch_size, memory_chunk_size, mem_embbedding_size))
  chex.assert_shape(hm_memory.steps_since_last_write,
                    (memory_batch_size,))
  if hm_mask is not None:
    chex.assert_type(hm_mask, bool)
    chex.assert_shape(hm_mask,
                      (batch_size, query_length, num_memories))

  # Project queries and (gradient-stopped) memory keys into a shared space.
  query_head = self._singlehead_linear(queries, self._size, "query")
  key_head = self._singlehead_linear(
      jax.lax.stop_gradient(hm_memory.keys), self._size, "key")

  # What times in the input [t] attend to what times in the memories [T].
  logits = jnp.einsum("btd,bTd->btT", query_head, key_head)
  scaled_logits = logits / np.sqrt(self._size)

  # Mask last dimension, replacing invalid logits with large negative values.
  # This allows e.g. enforcing causal attention on learner, or blocking
  # attention across episodes
  if hm_mask is not None:
    masked_logits = jnp.where(hm_mask, scaled_logits, -1e6)
  else:
    masked_logits = scaled_logits

  # identify the top-k memories and their relevance weights
  top_k_logits, top_k_indices = jax.lax.top_k(masked_logits, self._k)
  weights = jax.nn.softmax(top_k_logits)

  # set up the within-memory attention
  assert self._size % self._num_heads == 0
  mha_key_size = self._size // self._num_heads
  attention_layer = hk.MultiHeadAttention(
      key_size=mha_key_size,
      model_size=self._size,
      num_heads=self._num_heads,
      w_init_scale=self._init_scale,
      name="within_mem_attn")

  # position encodings
  augmented_contents = hm_memory.contents
  if self._memory_position_encoding:
    position_embs = sinusoid_position_encoding(
        memory_chunk_size, mem_embbedding_size)
    # Add out-of-place: `augmented_contents` aliases `hm_memory.contents`, so
    # an in-place `+=` would silently write the position encodings into the
    # caller's memory whenever the contents are a mutable (NumPy) array.
    augmented_contents = augmented_contents + position_embs[None, None, :, :]

  def _within_memory_attention(sub_inputs, sub_memory_contents, sub_weights,
                               sub_top_k_indices):
    # Per-batch-element view: gather the selected chunks for every token.
    top_k_contents = sub_memory_contents[sub_top_k_indices, :, :]

    # Now we go deeper, with another vmap over **tokens**, because each token
    # can each attend to different memories.
    def do_attention(sub_sub_inputs, sub_sub_top_k_contents):
      # Repeat the single token as a length-1 query for each of the k chunks.
      tiled_inputs = jnp.tile(sub_sub_inputs[None, None, :],
                              reps=(self._k, 1, 1))
      sub_attention_results = attention_layer(
          query=tiled_inputs,
          key=sub_sub_top_k_contents,
          value=sub_sub_top_k_contents)
      return sub_attention_results

    do_attention = hk_vmap(do_attention, in_axes=0, split_rng=False)
    attention_results = do_attention(sub_inputs, top_k_contents)
    # Drop the singleton query-length axis introduced by the tiling above.
    attention_results = jnp.squeeze(attention_results, axis=2)
    # Now collapse results across k memories
    attention_results = sub_weights[:, :, None] * attention_results
    attention_results = jnp.sum(attention_results, axis=1)
    return attention_results

  # vmap across batch
  batch_within_memory_attention = hk_vmap(_within_memory_attention,
                                          in_axes=0, split_rng=False)
  outputs = batch_within_memory_attention(
      queries,
      jax.lax.stop_gradient(augmented_contents),
      weights,
      top_k_indices)
  return outputs