Example #1
import math
from typing import Optional, Tuple

import oneflow as flow
from oneflow import Tensor


def _scaled_dot_product_attention(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    attn_mask: Optional[Tensor] = None,
    dropout_p: float = 0.0,
) -> Tuple[Tensor, Tensor]:
    B, Nt, E = q.shape
    # Scale queries by 1/sqrt(E) so the dot-product logits stay in a stable range.
    q = q / math.sqrt(E)
    # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
    attn = flow.bmm(q, k.transpose(-2, -1))
    if attn_mask is not None:
        attn += attn_mask
    attn = flow.softmax(attn, dim=-1)
    if dropout_p > 0.0:
        attn = flow.nn.functional.dropout(attn, p=dropout_p)
    # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
    output = flow.bmm(attn, v)
    return output, attn
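A minimal smoke test for the helper above, assuming oneflow is importable as flow; the batch size, sequence lengths, embedding width, and dropout probability are arbitrary illustrative choices, not values from the original listing.

import oneflow as flow

B, Nt, Ns, E = 2, 4, 6, 8  # batch, target length, source length, embed dim (illustrative)
q = flow.randn(B, Nt, E)
k = flow.randn(B, Ns, E)
v = flow.randn(B, Ns, E)

output, attn = _scaled_dot_product_attention(q, k, v, dropout_p=0.1)
print(output.shape)  # (2, 4, 8) -> (B, Nt, E)
print(attn.shape)    # (2, 4, 6) -> (B, Nt, Ns)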
Example #2
def forward(
    self,
    query: flow.Tensor,
    key: flow.Tensor,
    value: flow.Tensor,
    attn_mask: Optional[flow.Tensor] = None,
) -> Tuple[flow.Tensor, flow.Tensor]:
    r"""
    Args:
        query: [batch, num_attention_heads, len_query, dim_query]
        key: [batch, num_attention_heads, len_key, dim_key]
        value: [batch, num_attention_heads, len_value, dim_value]
        attn_mask: [batch, num_attention_heads, len_query, len_key]
    """
    attention = flow.matmul(query, key.transpose(-1, -2))
    attention = attention / math.sqrt(query.size(-1))
    if attn_mask is not None:
        attention = attention + attn_mask
    attention = nn.Softmax(dim=-1)(attention)
    attention = self.dropout(attention)
    context = flow.matmul(attention, value)
    return context, attention
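Both examples add attn_mask to the raw scores before the softmax, so blocked positions are expected to carry -inf (or a very negative value) and permitted positions 0. Below is a small sketch of building a causal mask in that convention, assuming oneflow; the batch size, head count, and sequence lengths are illustrative only.

import oneflow as flow

batch, num_heads, len_q, len_k = 2, 8, 5, 5  # illustrative sizes
causal = flow.triu(flow.ones(len_q, len_k), diagonal=1)    # 1 strictly above the diagonal
causal = causal.masked_fill(causal == 1, float("-inf"))    # -inf where attention is blocked, 0 elsewhere
# Broadcast to [batch, num_attention_heads, len_query, len_key], matching the docstring above.
attn_mask = causal.unsqueeze(0).unsqueeze(0).expand(batch, num_heads, len_q, len_k)

Passing such a mask as attn_mask to either forward pass above drives the softmax weight of future positions to zero while leaving permitted positions untouched.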