Example 1
 ModuleDescriptor(
     name="InstanceNorm",
     create=lambda: hk.InstanceNorm(True, True),
     shape=(BATCH_SIZE, 3, 2)),
 ModuleDescriptor(
     name="GroupNorm",
     create=lambda: hk.GroupNorm(5),
     shape=(BATCH_SIZE, 4, 4, 10)),
 ModuleDescriptor(
     name="LayerNorm",
     create=lambda: hk.LayerNorm(1, True, True),
     shape=(BATCH_SIZE, 3, 2)),
 ModuleDescriptor(
     name="MultiHeadAttention",
     create=lambda: MultiInput(  # pylint: disable=g-long-lambda
         hk.MultiHeadAttention(num_heads=8, key_size=64, w_init_scale=1.0),
         num_inputs=3),
     shape=(BATCH_SIZE, 3, 2)),
 ModuleDescriptor(
     name="RMSNorm",
     create=lambda: hk.RMSNorm(1),
     shape=(BATCH_SIZE, 3, 2)),
 ModuleDescriptor(
     name="SpectralNorm",
     create=lambda: hk.SpectralNorm(),
     shape=(BATCH_SIZE, 3, 2)),
 ModuleDescriptor(
     name="nets.ResNet",
     create=lambda: Training(hk.nets.ResNet((3, 4, 6, 3), 1000)),
     shape=(BATCH_SIZE, 3, 3, 2)),
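
Each ModuleDescriptor above pairs a module name with a zero-argument constructor and an input shape, which makes it easy to smoke-test every module on dummy data. Below is a minimal sketch of such a check; the smoke_test helper and the BATCH_SIZE value are illustrative assumptions, and ModuleDescriptor is taken to be a NamedTuple with name, create, and shape fields as used in the listing.

import haiku as hk
import jax
import jax.numpy as jnp

BATCH_SIZE = 8  # assumed value; any small batch works for a shape check


def smoke_test(descriptor):
    """Initialises and applies the described module on zero-valued inputs."""

    def forward(x):
        # Modules must be constructed inside the transformed function.
        module = descriptor.create()
        return module(x)

    # transform_with_state covers modules that keep state (e.g. nets.ResNet
    # with batch norm); stateless modules simply return an empty state.
    transformed = hk.transform_with_state(forward)
    x = jnp.zeros(descriptor.shape, dtype=jnp.float32)
    rng = jax.random.PRNGKey(42)
    params, state = transformed.init(rng, x)
    out, _ = transformed.apply(params, state, rng, x)
    return out
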
    def __call__(self,
                 queries: jnp.ndarray,
                 hm_memory: HierarchicalMemory,
                 hm_mask: Optional[jnp.ndarray] = None) -> jnp.ndarray:
        """Do hierarchical attention over the stored memories.

        Args:
          queries: Query tensor of shape [B, Q, E], for batch size B, query
            length Q, and embedding dimension E.
          hm_memory: Hierarchical memory.
          hm_mask: Optional boolean mask tensor of shape [B, Q, M]. Where
            False, the corresponding query timepoints cannot attend to the
            corresponding memory chunks. This can be used for enforcing causal
            attention on the learner, not attending to memories from prior
            episodes, etc.

        Returns:
          Value updates for each query slot: [B, Q, D].
        """
        # some shape checks
        batch_size, query_length, _ = queries.shape
        (memory_batch_size, num_memories, memory_chunk_size,
         mem_embedding_size) = hm_memory.contents.shape
        assert batch_size == memory_batch_size
        chex.assert_shape(hm_memory.keys,
                          (batch_size, num_memories, mem_embedding_size))
        chex.assert_shape(
            hm_memory.accumulator,
            (memory_batch_size, memory_chunk_size, mem_embedding_size))
        chex.assert_shape(hm_memory.steps_since_last_write,
                          (memory_batch_size, ))
        if hm_mask is not None:
            chex.assert_type(hm_mask, bool)
            chex.assert_shape(hm_mask,
                              (batch_size, query_length, num_memories))
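        # Project the queries and the (gradient-stopped) memory keys into a
        # shared space; their dot products give chunk-level attention logits.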
        query_head = self._singlehead_linear(queries, self._size, "query")
        key_head = self._singlehead_linear(
            jax.lax.stop_gradient(hm_memory.keys), self._size, "key")

        # What times in the input [t] attend to what times in the memories [T].
        logits = jnp.einsum("btd,bTd->btT", query_head, key_head)

        scaled_logits = logits / np.sqrt(self._size)

        # Mask last dimension, replacing invalid logits with large negative values.
        # This allows e.g. enforcing causal attention on the learner, or
        # blocking attention across episodes.
        if hm_mask is not None:
            masked_logits = jnp.where(hm_mask, scaled_logits, -1e6)
        else:
            masked_logits = scaled_logits

        # identify the top-k memories and their relevance weights
        top_k_logits, top_k_indices = jax.lax.top_k(masked_logits, self._k)
        weights = jax.nn.softmax(top_k_logits)

        # set up the within-memory attention
        assert self._size % self._num_heads == 0
        mha_key_size = self._size // self._num_heads
        attention_layer = hk.MultiHeadAttention(key_size=mha_key_size,
                                                model_size=self._size,
                                                num_heads=self._num_heads,
                                                w_init_scale=self._init_scale,
                                                name="within_mem_attn")

        # position encodings
        augmented_contents = hm_memory.contents
        if self._memory_position_encoding:
            position_embs = sinusoid_position_encoding(
                memory_chunk_size, mem_embedding_size)
            augmented_contents += position_embs[None, None, :, :]

        def _within_memory_attention(sub_inputs, sub_memory_contents,
                                     sub_weights, sub_top_k_indices):
            top_k_contents = sub_memory_contents[sub_top_k_indices, :, :]

            # Now we go deeper, with another vmap over **tokens**, because each
            # token can attend to different memories.
            def do_attention(sub_sub_inputs, sub_sub_top_k_contents):
                tiled_inputs = jnp.tile(sub_sub_inputs[None, None, :],
                                        reps=(self._k, 1, 1))
                sub_attention_results = attention_layer(
                    query=tiled_inputs,
                    key=sub_sub_top_k_contents,
                    value=sub_sub_top_k_contents)
                return sub_attention_results

            do_attention = hk_vmap(do_attention, in_axes=0, split_rng=False)
            attention_results = do_attention(sub_inputs, top_k_contents)
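            # Each tiled per-token query has length 1, so drop that singleton
            # axis, leaving [query_length, k, size].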
            attention_results = jnp.squeeze(attention_results, axis=2)
            # Now collapse results across k memories
            attention_results = sub_weights[:, :, None] * attention_results
            attention_results = jnp.sum(attention_results, axis=1)
            return attention_results

        # vmap across batch
        batch_within_memory_attention = hk_vmap(_within_memory_attention,
                                                in_axes=0,
                                                split_rng=False)
        outputs = batch_within_memory_attention(
            queries, jax.lax.stop_gradient(augmented_contents), weights,
            top_k_indices)

        return outputs
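
For orientation, here is a hedged sketch of how this attention step might be driven end to end. The HierarchicalMemory fields and their shapes follow the chex shape checks at the top of __call__; the class name HierarchicalMemoryAttention, its constructor arguments, and the toy sizes are placeholders for whatever the surrounding (not shown) class actually defines, so treat this as an illustration rather than the library's API.

import collections

import haiku as hk
import jax
import jax.numpy as jnp

# Assumed container; the real HierarchicalMemory may differ, but the field
# names and shapes below mirror the shape checks in __call__ above.
HierarchicalMemory = collections.namedtuple(
    "HierarchicalMemory",
    ["keys", "contents", "steps_since_last_write", "accumulator"])

B, Q, E = 2, 4, 32   # batch size, query length, embedding size
M, C = 16, 8         # number of stored memories, chunk size per memory


def forward(queries, memory, mask):
    # Constructor arguments are assumptions; the method above only reveals
    # attributes such as self._size, self._num_heads and self._k.
    attn = HierarchicalMemoryAttention(feature_size=E, k=4, num_heads=4)
    return attn(queries, memory, mask)


memory = HierarchicalMemory(
    keys=jnp.zeros((B, M, E)),
    contents=jnp.zeros((B, M, C, E)),
    steps_since_last_write=jnp.zeros((B,), dtype=jnp.int32),
    accumulator=jnp.zeros((B, C, E)))
queries = jnp.zeros((B, Q, E))
mask = jnp.ones((B, Q, M), dtype=bool)

transformed = hk.transform(forward)
rng = jax.random.PRNGKey(0)
params = transformed.init(rng, queries, memory, mask)
outputs = transformed.apply(params, rng, queries, memory, mask)  # [B, Q, D]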