Example 1
  def __call__(self, x):
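    # Spatial self-attention over an NHWC feature map, with a residual connection.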
    B, H, W, C = x.shape  # pylint: disable=invalid-name,unused-variable
    assert C % self.num_heads == 0
    head_dim = C // self.num_heads

    h = Normalize(name='norm')(x)

    assert h.shape == (B, H, W, C)
    h = h.reshape(B, H * W, C)
    q = nn.DenseGeneral(features=(self.num_heads, head_dim), name='q')(h)
    k = nn.DenseGeneral(features=(self.num_heads, head_dim), name='k')(h)
    v = nn.DenseGeneral(features=(self.num_heads, head_dim), name='v')(h)
    assert q.shape == k.shape == v.shape == (B, H * W, self.num_heads, head_dim)
    h = nn.dot_product_attention(query=q, key=k, value=v)
    assert h.shape == (B, H * W, self.num_heads, head_dim)
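    # Project the per-head outputs back to C channels; the zero-initialized kernel
    # makes this block an identity mapping (x + 0) at initialization.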
    h = nn.DenseGeneral(
        features=C,
        axis=(-2, -1),
        kernel_init=nn.initializers.zeros,
        name='proj_out')(h)
    assert h.shape == (B, H * W, C)
    h = h.reshape(B, H, W, C)
    assert h.shape == x.shape
    return x + h
Example 2
    def __call__(self,
                 hidden_states,
                 attention_mask,
                 deterministic=True,
                 output_attentions: bool = False):
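        # Multi-head self-attention: separate query/key/value projections,
        # then scaled dot-product attention.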
        head_dim = self.config.hidden_size // self.config.num_attention_heads

        query_states = self.query(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim))
        value_states = self.value(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim))
        key_states = self.key(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim))

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
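            # New axes at (-3, -2) broadcast the mask over attention heads and query positions.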
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_output = dot_product_attention(
            query_states,
            key_states,
            value_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_probs_dropout_prob,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        # Merge the attention heads back into a single hidden dimension.
        outputs = (attn_output.reshape(attn_output.shape[:2] + (-1,)),)

        # TODO: at the moment it's not possible to retrieve attn_weights from
        # dot_product_attention, but should be in the future -> add functionality then

        return outputs
Example 3
    def __call__(self, language, vision, hidden):
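        # Cross-modal attention: queries come from the concatenated language and hidden
        # features, keys from the language features, and values from the vision features.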
        input_q = jnp.concatenate([language, hidden], axis=-1)

        query = nn.DenseGeneral(features=(self.num_heads, self.head_dim),
                                name="query")(input_q)
        key = nn.DenseGeneral(features=(self.num_heads, self.head_dim),
                              name="key")(language)
        value = nn.DenseGeneral(features=(self.num_heads, self.head_dim),
                                name="memory_value")(vision)

        x = nn.dot_product_attention(query, key, value)

        out = nn.DenseGeneral(features=self.out_features,
                              axis=(-2, -1),
                              name="out")(x)
        return out
Example 4
    def __call__(self, hidden_states, attention_mask, deterministic=True):
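        # Per-head q/k/v projections, a boolean mask converted to an additive bias,
        # then scaled dot-product attention.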
        head_dim = self.config.hidden_size // self.config.num_attention_heads

        query_states = self.query(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )
        value_states = self.value(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )
        key_states = self.key(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.dropout_rate > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_output = dot_product_attention(
            query_states,
            key_states,
            value_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_probs_dropout_prob,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        return attn_output.reshape(attn_output.shape[:2] + (-1,))
Example 5
  def __call__(self, x):
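    # Variant of the attention block in Example 1: either num_heads or head_dim is
    # configured and the other is derived from the channel count C.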
    B, H, W, C = x.shape  # pylint: disable=invalid-name,unused-variable

    if self.head_dim is None:
      assert self.num_heads is not None
      assert C % self.num_heads == 0
      num_heads = self.num_heads
      head_dim = C // num_heads
    else:
      assert self.num_heads is None
      assert C % self.head_dim == 0
      head_dim = self.head_dim
      num_heads = C // head_dim

    h = Normalize(name='norm')(x)

    assert h.shape == (B, H, W, C)
    h = h.reshape(B, H * W, C)
    q = nn.DenseGeneral(features=(num_heads, head_dim), name='q')(h)
    k = nn.DenseGeneral(features=(num_heads, head_dim), name='k')(h)
    v = nn.DenseGeneral(features=(num_heads, head_dim), name='v')(h)
    assert q.shape == k.shape == v.shape == (B, H * W, num_heads, head_dim)
    h = nn.dot_product_attention(query=q, key=k, value=v)
    assert h.shape == (B, H * W, num_heads, head_dim)
    h = nn.DenseGeneral(
        features=C,
        axis=(-2, -1),
        kernel_init=nn.initializers.zeros,
        name='proj_out')(h)
    assert h.shape == (B, H * W, C)
    h = h.reshape(B, H, W, C)
    assert h.shape == x.shape
    logging.info(
        '%s: x=%r num_heads=%d head_dim=%d',
        self.name, x.shape, num_heads, head_dim)
    return x + h
Example 6
    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):
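        # GPT-2-style causal self-attention: a fused qkv projection, a causal mask
        # combined with an optional padding mask, and a key/value cache for
        # incremental decoding.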
        qkv_out = self.c_attn(hidden_states)
        query, key, value = jnp.split(qkv_out, 3, axis=2)

        query = self._split_heads(query)
        key = self._split_heads(key)
        value = self._split_heads(value)

        query_length, key_length = query.shape[1], key.shape[1]

        if self.has_variable("cache", "cached_key"):
            mask_shift = self.variables["cache"]["cache_index"]
            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
            causal_mask = lax.dynamic_slice(
                self.causal_mask, (0, 0, mask_shift, 0),
                (1, 1, query_length, max_decoder_length))
        else:
            causal_mask = self.causal_mask[:, :, :query_length, :key_length]

        batch_size = hidden_states.shape[0]
        causal_mask = jnp.broadcast_to(causal_mask,
                                       (batch_size, ) + causal_mask.shape[1:])

        # The padding mask is optional; when it is None, only the causal mask applies.
        if attention_mask is not None:
            attention_mask = jnp.broadcast_to(
                jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
            attention_mask = combine_masks(attention_mask, causal_mask)
        else:
            attention_mask = causal_mask

        dropout_rng = None
        if not deterministic and self.config.attn_pdrop > 0.0:
            dropout_rng = self.make_rng("dropout")

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.has_variable("cache", "cached_key") or init_cache:
            key, value, attention_mask = self._concatenate_to_cache(
                key, value, query, attention_mask)

        # transform boolean mask into float mask
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
        )

        # usual dot product attention
        attn_output = dot_product_attention(
            query,
            key,
            value,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attn_pdrop,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        # Merge the heads, then apply the output projection and residual dropout.
        attn_output = self._merge_heads(attn_output)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output,
                                         deterministic=deterministic)

        # TODO: at the moment it's not possible to retrieve attn_weights from
        # dot_product_attention, but should be in the future -> add functionality then

        return (attn_output, )