Example #1
 def _dropout_graph(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
     node_key, edge_key = hk.next_rng_keys(2)
     nodes = hk.dropout(node_key, self._dropout_rate, graph.nodes)
     edges = graph.edges
     if not self._disable_edge_updates:
         edges = hk.dropout(edge_key, self._dropout_rate, edges)
     return graph._replace(nodes=nodes, edges=edges)
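This helper relies on `hk.next_rng_keys`, which only works inside an `hk.transform`-ed function that receives an RNG key at `init`/`apply` time. A minimal, self-contained sketch of the same per-component dropout idea (illustrative shapes and rate, not the original module):

import haiku as hk
import jax
import jax.numpy as jnp
import jraph

def dropout_graph(graph: jraph.GraphsTuple, rate: float = 0.1) -> jraph.GraphsTuple:
    # Draw one key per graph component so node and edge masks are independent.
    node_key, edge_key = hk.next_rng_keys(2)
    return graph._replace(
        nodes=hk.dropout(node_key, rate, graph.nodes),
        edges=hk.dropout(edge_key, rate, graph.edges))

graph = jraph.GraphsTuple(
    nodes=jnp.ones((3, 4)), edges=jnp.ones((2, 4)),
    senders=jnp.array([0, 1]), receivers=jnp.array([1, 2]),
    n_node=jnp.array([3]), n_edge=jnp.array([2]), globals=None)

fwd = hk.transform(dropout_graph)
params = fwd.init(jax.random.PRNGKey(0), graph)   # no parameters, but sets up the RNG
out = fwd.apply(params, jax.random.PRNGKey(1), graph)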
Example #2
    def __call__(self, h: jnp.ndarray, mask: Optional[jnp.ndarray],
                 is_training: bool) -> jnp.ndarray:
        """Connects the transformer.

    Args:
      h: Inputs, [B, T, H].
      mask: Padding mask, [B, T].
      is_training: Whether we're training or not.

    Returns:
      Array of shape [B, T, H].
    """

        init_scale = 2. / np.sqrt(self._num_layers)
        dropout_rate = self._dropout_rate if is_training else 0.
        if mask is not None:
            mask = mask[:, None, None, :]

        # Note: names chosen to approximately match those used in the GPT-2 code;
        # see https://github.com/openai/gpt-2/blob/master/src/model.py.
        for i in range(self._num_layers):
            h_norm = layer_norm(h, name=f'h{i}_ln_1')
            h_attn = CausalSelfAttention(self._num_heads,
                                         init_scale,
                                         name=f'h{i}_attn')(h_norm, mask)
            h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
            h = h + h_attn
            h_norm = layer_norm(h, name=f'h{i}_ln_2')
            h_dense = DenseBlock(init_scale, name=f'h{i}_mlp')(h_norm)
            h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
            h = h + h_dense
        h = layer_norm(h, name='ln_f')

        return h
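The `layer_norm` helper used in this and the other transformer examples is not shown; a common definition, assuming the usual Haiku convention of a per-call `hk.LayerNorm` over the last axis, is:

from typing import Optional

import haiku as hk
import jax.numpy as jnp


def layer_norm(x: jnp.ndarray, name: Optional[str] = None) -> jnp.ndarray:
    """Applies LayerNorm over the last axis with learned scale and offset."""
    return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True,
                        name=name)(x)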
Example #3
    def __call__(self, h: jnp.ndarray, mask: Optional[jnp.ndarray],
                 is_training: bool) -> jnp.ndarray:
        """Connects the transformer.
        Args:
          h: Inputs, [B, T, H].
          mask: Padding mask, [B, T].
          is_training: Whether we're training or not.
        Returns:
          Array of shape [B, T, H].
        """

        init_scale = 2. / self._num_layers
        dropout_rate = self._dropout_rate if is_training else 0.
        if mask is not None:
            mask = mask[:, None, None, :]

        for i in range(self._num_layers):
            h_norm = layer_norm(h, name=f'h{i}_ln_1')
            h_attn = SelfAttention(num_heads=self._num_heads,
                                   key_size=64,
                                   w_init_scale=init_scale,
                                   name=f'h{i}_attn')(h_norm, mask=mask)
            h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
            h = h + h_attn
            h_norm = layer_norm(h, name=f'h{i}_ln_2')
            h_dense = DenseBlock(init_scale, name=f'h{i}_mlp')(h_norm)
            h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
            h = h + h_dense
        h = layer_norm(h, name='ln_f')

        return h
Example #4
 def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
     hiddens = x.shape[-1]
     x = conv1d(x, num_units=self._dense_dim, init_scale=self._init_scale)
     x = jax.nn.relu(x)
     x = hk.dropout(hk.next_rng_key(), self._dropout_prob, x)
     x = conv1d(x, num_units=hiddens, init_scale=self._init_scale)
     return hk.dropout(hk.next_rng_key(), self._dropout_prob, x)
Example #5
 def __call__(self, x, lengths):
     x = self.embed(x)
     x = jax.nn.relu(self.bn1(self.conv1(x), is_training=self.is_training))
     x = hk.dropout(hk.next_rng_key(), self.dropout_rate,
                    x) if self.is_training else x
     x = jax.nn.relu(self.bn2(self.conv2(x), is_training=self.is_training))
     x = hk.dropout(hk.next_rng_key(), self.dropout_rate,
                    x) if self.is_training else x
     x = jax.nn.relu(self.bn3(self.conv3(x), is_training=self.is_training))
     x = hk.dropout(hk.next_rng_key(), self.dropout_rate,
                    x) if self.is_training else x
     B, L, D = x.shape
     mask = jnp.arange(0, L)[None, :] >= (lengths[:, None] - 1)
     h0c0_fwd = self.lstm_fwd.initial_state(B)
     new_hx_fwd, new_hxcx_fwd = hk.dynamic_unroll(self.lstm_fwd,
                                                  x,
                                                  h0c0_fwd,
                                                  time_major=False)
     x_bwd, mask_bwd = jax.tree_map(lambda x: jnp.flip(x, axis=1),
                                    (x, mask))
     h0c0_bwd = self.lstm_bwd.initial_state(B)
     new_hx_bwd, new_hxcx_bwd = hk.dynamic_unroll(self.lstm_bwd,
                                                  (x_bwd, mask_bwd),
                                                  h0c0_bwd,
                                                  time_major=False)
     x = jnp.concatenate((new_hx_fwd, jnp.flip(new_hx_bwd, axis=1)),
                         axis=-1)
     return x
Example #6
    def __call__(self, h: jnp.ndarray, mask: Optional[jnp.ndarray],
                 is_training: bool) -> jnp.ndarray:
        """Connects the transformer.

    Args:
      h: Inputs, [B, T, H].
      mask: Padding mask, [B, T].
      is_training: Whether we're training or not.

    Returns:
      Array of shape [B, T, H].
    """

        init_scale = 2. / np.sqrt(self._num_layers)
        dropout_rate = self._dropout_rate if is_training else 0.
        if mask is not None:
            mask = mask[:, None, None, :]

        h = layer_norm(h)
        for _ in range(self._num_layers):
            h_attn = CausalSelfAttention(self._num_heads, init_scale)(h, mask)
            h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
            h = layer_norm(h + h_attn)
            h_dense = DenseBlock(init_scale)(h)
            h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
            h = layer_norm(h + h_dense)

        return h
Example #7
 def forward(self, x, is_training):
     # Block 1
     x = jax.nn.relu(self.conv1_1(x))
     x = self.bn1_1(x, is_training)
     x = jax.nn.relu(self.conv1_2(x))
     x = self.bn1_2(x, is_training)
     x = hk.max_pool(x, 2, 2, "SAME")
     if is_training:
         x = hk.dropout(hk.next_rng_key(), 0.2, x)
     # Block 2
     x = jax.nn.relu(self.conv2_1(x))
     x = self.bn2_1(x, is_training)
     x = jax.nn.relu(self.conv2_2(x))
     x = self.bn2_2(x, is_training)
     x = hk.max_pool(x, 2, 2, "SAME")
     if is_training:
         x = hk.dropout(hk.next_rng_key(), 0.3, x)
     # Block 3
     x = jax.nn.relu(self.conv3_1(x))
     x = self.bn3_1(x, is_training)
     x = jax.nn.relu(self.conv3_2(x))
     x = self.bn3_2(x, is_training)
     x = hk.max_pool(x, 2, 2, "SAME")
     if is_training:
         x = hk.dropout(hk.next_rng_key(), 0.4, x)
     # Linear part
     x = hk.Flatten()(x)
     x = jax.nn.relu(self.lin1(x))
     x = self.bn4(x, is_training)
     if is_training:
         x = hk.dropout(hk.next_rng_key(), 0.5, x)
     x = self.lin2(x)
     return x  # logits
Example #8
    def __call__(self, node_feats: jnp.ndarray, adj: jnp.ndarray,
                 graph_idx: jnp.array, is_training: bool) -> jnp.ndarray:
        """Predict logits or values

        Parameters
        ----------
        node_feats : ndarray of shape (N, in_feats)
            Batch input node features.
            N is the total number of nodes in the batch
        adj : ndarray of shape (2, E)
            Batch adjacency list.
            E is the total number of edges in the batch
        graph_idx : ndarray of shape (N,)
            This index indicates the graph number for each node in node_feats.
            When two nodes share the same graph index, they belong to the same graph.
        is_training : bool
            Whether the model is training or not.

        Returns
        -------
        out : ndarray of shape (batch_size, n_out)
            Predictor output.
        """
        predicator_dropout = self.predicator_dropout if is_training is True else 0.0
        node_feats = self.gcn(node_feats, adj, is_training)
        # pooling
        graph_feat = self.pooling(node_feats, graph_idx)
        if predicator_dropout != 0.0:
            graph_feat = hk.dropout(hk.next_rng_key(), predicator_dropout,
                                    graph_feat)
        graph_feat = self.fc(graph_feat)
        graph_feat = self.activation(graph_feat)
        out = self.out(graph_feat)
        return out
Example #9
        def __call__(self, x: jnp.ndarray, is_training: bool) -> jnp.ndarray:

            dropout_rate = self._dropout_rate if is_training else 0.
            h = hk.Linear(self.output_size,
                          w_init=hk.initializers.Constant(1),
                          b_init=hk.initializers.Constant(1))(x)
            return hk.dropout(hk.next_rng_key(), dropout_rate, h)
Example #10
  def __call__(
      self,
      inputs: jnp.ndarray,
      dropout_rate: Optional[float] = None,
      rng=None,
  ) -> jnp.ndarray:
    """Connects the module to some inputs.
    Args:
      inputs: A Tensor of shape `[batch_size, input_size]`.
      dropout_rate: Optional dropout rate.
      rng: Optional RNG key. Required when using dropout.
    Returns:
      output: The output of the model of size `[batch_size, output_size]`.
    """
    if dropout_rate is not None and rng is None:
      raise ValueError("When using dropout an rng key must be passed.")
    elif dropout_rate is None and rng is not None:
      raise ValueError("RNG should only be passed when using dropout.")

    rng = hk.PRNGSequence(rng) if rng is not None else None
    num_layers = len(self.layers)

    out = inputs
    for i, layer in enumerate(self.layers):
      out = layer(out)
      if i < (num_layers - 1) or self.activate_final:
        # Only perform dropout if we are activating the output.
        if dropout_rate is not None:
          out = hk.dropout(next(rng), dropout_rate, out)
        out = self.activation(out)

    return out
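This `__call__` follows the same pattern as `hk.nets.MLP`, where dropout is applied only when both a rate and an explicit key are passed. Assuming that module, a usage sketch:

import haiku as hk
import jax
import jax.numpy as jnp

# The dropout key is supplied explicitly, so no apply-time rng is needed on the
# transform itself.
forward = hk.without_apply_rng(hk.transform(
    lambda x, **kwargs: hk.nets.MLP([512, 512, 10])(x, **kwargs)))

x = jnp.ones([8, 64])
params = forward.init(jax.random.PRNGKey(0), x)        # dropout off during init
train_out = forward.apply(params, x, dropout_rate=0.1,
                          rng=jax.random.PRNGKey(1))   # dropout on
eval_out = forward.apply(params, x)                    # dropout off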
Example #11
  def __call__(self, x, rng, is_training=True, **kwargs):
    # This function assumes that the input is batched!
    batch_size, H, W, C = x.shape

    if rng.ndim > 1:
      # In case we did the split in ResNet or CNN
      assert rng.ndim == 2
      assert rng.shape[0] == len(self.channel_sizes)
      rngs = rng
    else:
      rngs = random.split(rng, len(self.channel_sizes))

    for i, (rng, out_channel, kernel_shape) in enumerate(zip(rngs, self.channel_sizes, self.kernel_shapes)):

      if i == len(self.channel_sizes) - 1 and self.gate:
        ab = Conv(2*out_channel, kernel_shape, name=f"conv_{i}", **self.conv_kwargs)(x, is_training=is_training)
        a, b = jnp.split(ab, 2, axis=-1)
        x = a*jax.nn.sigmoid(b)
      else:
        x = Conv(out_channel, kernel_shape, name=f"conv_{i}", **self.conv_kwargs)(x, is_training=is_training)

      if self.norm is not None:
        x = self.norm(f"norm_{i}")(x, is_training=is_training)

      if i < len(self.channel_sizes) - 1:
        x = self.nonlinearity(x)

        if self.dropout_rate is not None:
          rate = self.dropout_rate if is_training else 0.0
          x = hk.dropout(rng, rate, x)

    return x
Example #12
 def __call__(self, inputs, is_training):
     dropout_rate = self._dropout_rate if is_training else 0.0
     h = hk.dropout(hk.next_rng_key(), dropout_rate, inputs)
     h = hk.Linear(self._vocab_size, with_bias=False)(h)
     return hk.BatchNorm(create_scale=False,
                         create_offset=False,
                         decay_rate=0.9)(h, is_training)
Example #13
 def __call__(self, x, is_training=True, return_metrics=False):
     """Return the output of the final layer without any [log-]softmax."""
     # Stem
     outputs = {}
     out = self.initial_conv(x)
     out = hk.max_pool(out,
                       window_shape=(1, 3, 3, 1),
                       strides=(1, 2, 2, 1),
                       padding='SAME')
     if return_metrics:
         outputs.update(base.signal_metrics(out, 0))
     # Blocks
     for i, block in enumerate(self.blocks):
         out, res_avg_var = block(out, is_training=is_training)
         if return_metrics:
             outputs.update(base.signal_metrics(out, i + 1))
             outputs[f'res_avg_var_{i}'] = res_avg_var
     # Final-conv->activation, pool, dropout, classify
     pool = jnp.mean(self.activation(out), [1, 2])
     outputs['pool'] = pool
     # Optionally apply dropout
     if self.drop_rate > 0.0 and is_training:
         pool = hk.dropout(hk.next_rng_key(), self.drop_rate, pool)
     outputs['logits'] = self.fc(pool)
     return outputs
Example #14
    def __call__(self, node_feats: jnp.ndarray, adj: jnp.ndarray,
                 is_training: bool) -> jnp.ndarray:
        """Predict logits or values

        Parameters
        ----------
        node_feats : ndarray of shape (batch_size, N, in_feats)
            Batch input node features.
            N is the total number of nodes in the batch of graphs.
        adj : ndarray of shape (batch_size, N, N)
            Batch adjacency matrix.
        is_training : bool
            Whether the model is training or not.

        Returns
        -------
        out : ndarray of shape (batch_size, n_out)
            Predictor output.
        """
        predicator_dropout = self.predicator_dropout if is_training is True else 0.0
        node_feats = self.gcn(node_feats, adj, is_training)
        # pooling
        graph_feat = self.pooling(node_feats)
        if predicator_dropout != 0.0:
            graph_feat = hk.dropout(hk.next_rng_key(), predicator_dropout,
                                    graph_feat)
        graph_feat = self.fc(graph_feat)
        graph_feat = self.activation(graph_feat)
        out = self.out(graph_feat)
        return out
Example #15
    def __call__(self, node_feats: jnp.ndarray, adj: jnp.ndarray,
                 is_training: bool) -> jnp.ndarray:
        """Update node features.

        Parameters
        ----------
        node_feats : ndarray of shape (batch_size, N, in_feats)
            Batch input node features.
            N is the total number of nodes in the batch of graphs.
        adj : ndarray of shape (batch_size, N, N)
            Batch adjacency matrix.
        is_training : bool
            Whether the model is training or not.

        Returns
        -------
        new_node_feats : ndarray of shape (batch_size, N, out_feats)
            Batch new node features.
        """
        dropout = self.dropout if is_training is True else 0.0

        # for batch data
        new_node_feats = jax.vmap(self._update_nodes)(node_feats, adj)
        if self.bias:
            new_node_feats += self.b
        new_node_feats = self.activation(new_node_feats)

        if dropout != 0.0:
            new_node_feats = hk.dropout(hk.next_rng_key(), dropout,
                                        new_node_feats)
        if self.batch_norm:
            new_node_feats = hk.BatchNorm(True, True, 0.9)(new_node_feats,
                                                           is_training)

        return new_node_feats
Example #16
    def __call__(self, q: jnp.ndarray, k: jnp.ndarray) -> jnp.ndarray:
        """Computes the relative position embedding.

    Args:
      q: The query.
      k: The key.

    Returns:
      Relative position embedding.
    """
        # Use key instead of query to obtain the length.
        batch_size, key_length, num_heads, head_dim = list(k.shape)
        # Content based addressing and global content bias
        content_score = jnp.einsum('bthd,bThd->bhtT', q + self._r_w_bias, k)

        # Relative position encoding
        positional_encodings = self._sinusoidal_pos_emb(key_length, batch_size)
        positional_encodings = hk.dropout(hk.next_rng_key(),
                                          self._dropout_rate,
                                          positional_encodings)
        rel_pos_emb = hk.Conv1D(
            output_channels=self._dim,
            kernel_shape=1,
            with_bias=False,
            w_init=init.RandomNormal(
                stddev=self._init_scale))(positional_encodings)
        rel_pos_emb = jnp.reshape(
            rel_pos_emb, [batch_size, key_length, num_heads, head_dim])

        # Content dependent positional bias and global positional bias
        rel_pos_score = jnp.einsum('bthd,bThd->bhtT', q + self._r_r_bias,
                                   rel_pos_emb)
        rel_pos_score = relative_shift(rel_pos_score)
        assert content_score.shape == rel_pos_score.shape
        return content_score + rel_pos_score
Example #17
 def wrapped(*args):
     out = fn(*args)
     if is_training:
         mask = hk.dropout(hk.next_rng_key(), dropout_rate,
                           jnp.ones([out.shape[0], 1]))
         out = out * mask
     return out
Example #18
 def fn(x):
     if dropout:
         x = hk.dropout(hk.next_rng_key(), 0.5, x)
     if batchnorm:
         x = hk.BatchNorm(create_offset=True, create_scale=True, decay_rate=0.001)(
             x, is_training=True
         )
     return x
Example #19
 def __call__(self, x, *, is_training):
     dropout_prob = self._dropout_prob if is_training else 0.0
     output_channels = x.shape[-1]
     x = conv_1d(output_channels=self._widening_factor * output_channels,
                 init_scale=self._init_scale)(x)
     x = jax.nn.gelu(x)
     x = conv_1d(output_channels=output_channels,
                 init_scale=self._init_scale)(x)
     return hk.dropout(hk.next_rng_key(), dropout_prob, x)
Example #20
 def __call__(self, X: jnp.ndarray, dropout: float,
              train: bool) -> Tuple[jnp.ndarray, jnp.ndarray]:
     X = l2_normalize(X)
     if train:
         X = hk.dropout(hk.next_rng_key(), dropout, X)
     h = self.mlp(X)
     mu = h[:, :self.latent_dim]
     log_var = h[:, self.latent_dim:]
     return mu, log_var
Example #21
 def maybe_dropedge(x):
     """Dropout on edge messages."""
     if not is_training:
         return x
     return x * hk.dropout(
         hk.next_rng_key(),
         dropedge_rate,
         jnp.ones([x.shape[0], 1]),
     )
Example #22
 def postnet(self, mel: ndarray) -> ndarray:
   x = mel
   for conv, bn in zip(self.postnet_convs, self.postnet_bns):
     x = conv(x)
     if bn is not None:
       x = bn(x, is_training=self.is_training)
       x = jnp.tanh(x)
     x = hk.dropout(hk.next_rng_key(), 0.5, x) if self.is_training else x
   return x
Example #23
 def mlp_function(X: jnp.ndarray, training: bool) -> Any:
     layers: List[Any] = []
     for d_o in config.intermediate_dims:
         if training:
             layers.append(
                 lambda x: hk.dropout(hk.next_rng_key(), config.dropout, x))
         layers.append(hk.Linear(d_o))
         layers.append(config.activation)
     layers.append(hk.Linear(dim_out))
     return hk.Sequential(layers)(X)
Example #24
    def __call__(self,
                 x: jnp.ndarray,
                 mask: Optional[jnp.ndarray] = None,
                 should_reset: Optional[jnp.ndarray] = None,
                 cache_steps: int = 0,
                 extra: Optional[jnp.ndarray] = None,
                 extra_mask: Optional[jnp.ndarray] = None) -> jnp.ndarray:
        """Computes the outputs of the self attention block.

    Args:
      x: query input [batch, x_timesteps, in_dim].
      mask: attention mask [batch, 1, 1, x_timesteps].
      should_reset: reset marker [batch, timesteps].
      cache_steps: number of timesteps in the cache.
      extra: if provided should be extra key-value input
        [batch, extra_timesteps, in_dim'].
      extra_mask: if provided should be the mask for extra key-value input,
        [batch, extra_timesteps].

    Returns:
      output: block output [batch, x_timesteps, in_dim].
    """
        if self._causal:
            timesteps = x.shape[1]
            batch_size = x.shape[0]
            t = jnp.arange(timesteps, dtype=jnp.int32)
            causal_mask = (t[:, None] >= t[None, :])[None, None, :, :]
            causal_mask = causal_mask.astype(x.dtype)
            if mask is None:
                mask = jnp.broadcast_to(causal_mask,
                                        (batch_size, 1, timesteps, timesteps))
            else:
                mask *= causal_mask
            x = Attention(self._r_w_bias,
                          self._r_r_bias,
                          num_heads=self._num_heads,
                          init_scale=self._init_scale,
                          relative_pos_clamp_len=self._relative_pos_clamp_len,
                          dropout_prob=self._dropout_attn_prob)(
                              x,
                              mask=mask,
                              should_reset=should_reset,
                              cache_steps=cache_steps,
                              extra=extra,
                              extra_mask=extra_mask)
        else:
            x = Attention(self._r_w_bias,
                          self._r_r_bias,
                          num_heads=self._num_heads,
                          init_scale=self._init_scale,
                          dropout_prob=self._dropout_attn_prob)(
                              x, mask=mask, extra=extra, extra_mask=extra_mask)
        return hk.dropout(hk.next_rng_key(), self._dropout_prob, x)
Example #25
    def __call__(self, inputs, is_training):
        dropout_rate = self._dropout_rate if is_training else 0.0

        h = jax.nn.softplus(hk.Linear(self._hidden)(inputs))
        h = jax.nn.softplus(hk.Linear(self._hidden)(h))
        h = hk.dropout(hk.next_rng_key(), dropout_rate, h)
        h = hk.Linear(self._num_topics)(h)

        # NB: here we set `create_scale=False` and `create_offset=False` to reduce
        # the number of learning parameters
        log_concentration = hk.BatchNorm(create_scale=False,
                                         create_offset=False,
                                         decay_rate=0.9)(h, is_training)
        return jnp.exp(log_concentration)
Example #26
    def __call__(self, inputs, rng_key, stochastic, is_training,
                 test_local_stats):
        out = shortcut = inputs

        if self.use_projection:
            shortcut = self.proj_conv(shortcut, rng_key, stochastic)
            shortcut = self.proj_batchnorm(shortcut, is_training,
                                           test_local_stats)
            # DROPOUT
            if self.dropout and is_training:
                shortcut = hk.dropout(rng_key, self.dropout_rate, shortcut)

        for i, (conv_i, bn_i) in enumerate(self.layers):
            out = conv_i(out, rng_key, stochastic)
            out = bn_i(out, is_training, test_local_stats)
            if i < len(self.layers
                       ) - 1:  # Don't apply relu or dropout on last layer
                out = jax.nn.relu(out)
                # DROPOUT
                if self.dropout and is_training:
                    out = hk.dropout(rng_key, self.dropout_rate, out)

        return jax.nn.relu(out + shortcut)
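Note that the same `rng_key` is reused for every `hk.dropout` call in this block, so calls on tensors of the same shape draw identical masks. If independent masks are wanted, the key can be split first; a standalone sketch of the difference (illustrative shapes and rate):

import haiku as hk
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jnp.ones([4, 8])

# Same key and same shape: the two dropout masks are identical.
a = hk.dropout(key, 0.5, x)
b = hk.dropout(key, 0.5, x)
assert bool(jnp.all(a == b))

# Splitting the key first gives independent masks.
k1, k2 = jax.random.split(key)
c = hk.dropout(k1, 0.5, x)
d = hk.dropout(k2, 0.5, x)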
Example #27
    def __call__(self, h: jnp.ndarray, input_embs: jnp.ndarray,
                 mask: Optional[jnp.ndarray],
                 is_training: bool) -> jnp.ndarray:
        """Connects the transformer.
        Args:
          h: Hidden state, [B, T, H].
          input_embs: Input embeddings, [B, T, H].
          mask: Padding mask, [B, T].
          is_training: Whether we're training or not.
        Returns:
          Array of shape [B, T, H].
        """

        init_scale = 2. / np.sqrt(self._num_layers)
        dropout_rate = self._dropout_rate if is_training else 0.
        if mask is not None:
            mask = mask[:, None, None, :]

        for i in range(self._num_layers):
            # input injections
            h = h + input_embs

            # regular transformer block
            h_norm = layer_norm(h, name=f'h{i}_ln_1')
            h_attn = CausalSelfAttention(self._num_heads,
                                         init_scale,
                                         name=f'h{i}_attn')(h_norm, mask)
            h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
            h = h + h_attn
            h_norm = layer_norm(h, name=f'h{i}_ln_2')
            h_dense = DenseBlock(init_scale, name=f'h{i}_mlp')(h_norm)
            h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
            h = h + h_dense
        h = layer_norm(h, name='ln_f')

        return h
Example #28
    def __call__(self,
                 inputs_q,
                 inputs_kv,
                 *,
                 attention_mask=None,
                 is_training):
        dropout_prob = self._dropout_prob if is_training else 0.0
        dropout_attn_prob = self._dropout_attn_prob if is_training else 0.0

        output_channels = inputs_q.shape[-1]
        if self._shape_for_attn == 'q':
            qk_channels = inputs_q.shape[-1]
        elif self._shape_for_attn == 'kv':
            qk_channels = inputs_kv.shape[-1]
        else:
            raise ValueError(f'Unknown value {self._shape_for_attn} for '
                             'shape_for_attention.')

        v_channels = None
        if self._qk_channels is not None:
            qk_channels = self._qk_channels
        if self._v_channels is not None:
            v_channels = self._v_channels

        attention = Attention(num_heads=self._num_heads,
                              init_scale=self._att_init_scale,
                              dropout_prob=dropout_attn_prob,
                              qk_channels=qk_channels,
                              v_channels=v_channels,
                              output_channels=output_channels)(
                                  layer_norm(inputs_q),
                                  layer_norm(inputs_kv),
                                  attention_mask=attention_mask)
        attention = hk.dropout(hk.next_rng_key(), dropout_prob, attention)

        # Optionally include a residual to the query.
        # Consider omitting the residual if the semantics of query and output
        # are different, e.g. if queries are positions and outputs are pixels.
        if self._use_query_residual:
            x = inputs_q + attention
        else:
            x = attention

        x += MLP(widening_factor=self._widening_factor,
                 dropout_prob=dropout_prob,
                 init_scale=self._dense_init_scale)(layer_norm(x),
                                                    is_training=is_training)
        return x
Example #29
def attend(q, k, v, dropout_prob=0.0, attention_mask=None):
    """Computes multi-head attention using a query, key and value.

  Args:
    q: Query with shape [batch, q_indices, num_heads, head_dim].
    k: Key with shape [batch, kv_indices, num_heads, head_dim].
    v: Value with shape [batch, kv_indices, num_heads, head_dim].
    dropout_prob: dropout probability on the attention weights.
    attention_mask: Array of shape [batch, q_indices, kv_indices] indicating
      which attentions are valid
  Returns:
    Output of the attention with shape [batch, q_indices, hiddens]
  """
    batch, q_indices, num_heads, q_head_dim = q.shape
    _, _, _, v_head_dim = v.shape
    hiddens = num_heads * v_head_dim

    attention = jnp.einsum('bthd,bThd->bhtT', q, k)

    scale = 1. / math.sqrt(q_head_dim)
    attention *= scale

    if attention_mask is not None:
        # Use large_k instead of np.NINF because np.NINF breaks for causal-masked
        # left-padded sampling.
        large_k = jnp.array(1e4 if attention.dtype == jnp.float16 else 1e30,
                            dtype=attention.dtype)

        attention = jnp.where(attention_mask[:, None, :, :], attention,
                              -large_k)

    normalized = jax.nn.softmax(attention)
    if dropout_prob > 0:
        normalized = hk.dropout(hk.next_rng_key(), dropout_prob, normalized)
    summed = jnp.einsum('bhtT,bThd->bthd', normalized, v)
    summed = jnp.reshape(summed, [batch, q_indices, hiddens])

    if attention_mask is not None:
        # If all attended tokens are masked, the corresponding rows of the
        # logits are fully masked out; the softmax then yields a uniform row
        # and produces non-zero outputs where they should be zero, so we
        # force those outputs to zero.
        wipe_attn = jnp.all(attention_mask == 0, axis=2,
                            keepdims=True)  # shape (B, T, 1)
        summed = jnp.where(wipe_attn, jnp.zeros_like(summed), summed)
    return summed
Example #30
    def __call__(self, x, rng, is_training=True, update_params=True, **kwargs):
        # This function assumes that the input is batched!
        batch_size, in_dim = x.shape

        rngs = random.split(rng, len(self.layer_sizes))

        for i, (rng, out_dim) in enumerate(zip(rngs, self.layer_sizes)):
            if self.zero_init and i == len(self.layer_sizes) - 1:
                w, b = data_dependent_param_init(
                    x,
                    out_dim,
                    name_suffix=f"{i}",
                    w_init=hk.initializers.RandomNormal(stddev=0.01),
                    b_init=jnp.zeros,
                    is_training=is_training,
                    update_params=update_params,
                    parameter_norm=None)
            else:
                w, b = data_dependent_param_init(
                    x,
                    out_dim,
                    name_suffix=f"{i}",
                    w_init=self.w_init,
                    b_init=self.b_init,
                    is_training=is_training,
                    update_params=update_params,
                    parameter_norm=self.parameter_norm)
            z = jnp.dot(x, w.T) + b

            if i < len(self.layer_sizes) - 1:
                z = self.nonlinearity(z)

            # Residual connection
            if self.skip_connection and x.shape[-1] == z.shape[-1]:
                x += z
            else:
                x = z

            if i < len(self.layer_sizes) - 1:
                if self.dropout_rate is not None:
                    rate = self.dropout_rate if is_training else 0.0
                    x = hk.dropout(rng, rate, x)

        return x