# Imports assumed by this snippet; `flow` refers to OneFlow.
import math
from typing import Optional, Tuple

import oneflow as flow
from oneflow import Tensor


def _scaled_dot_product_attention(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    attn_mask: Optional[Tensor] = None,
    dropout_p: float = 0.0,
) -> Tuple[Tensor, Tensor]:
    B, Nt, E = q.shape
    # Scale queries by 1/sqrt(d_k) before the dot product.
    q = q / math.sqrt(E)
    # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
    attn = flow.bmm(q, k.transpose(-2, -1))
    if attn_mask is not None:
        # Additive mask: masked positions hold -inf so they vanish after softmax.
        attn += attn_mask
    attn = flow.softmax(attn, dim=-1)
    if dropout_p > 0.0:
        attn = flow.nn.functional.dropout(attn, p=dropout_p)
    # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
    output = flow.bmm(attn, v)
    return output, attn
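
# A brief usage sketch (not part of the original snippet): the shapes, the causal
# mask, and folding num_heads into the batch dimension B are illustrative
# assumptions. The mask is additive, matching `attn += attn_mask` above: masked
# positions carry -inf before the softmax.
B, Nt, Ns, E = 2 * 4, 5, 5, 16  # 2 sequences x 4 heads folded into B, length 5, head dim 16
q = flow.randn(B, Nt, E)
k = flow.randn(B, Ns, E)
v = flow.randn(B, Ns, E)

# Causal mask: -inf above the diagonal blocks attention to future positions.
mask = flow.triu(flow.full((Nt, Ns), float("-inf")), diagonal=1)

out, weights = _scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.1)
print(out.shape)      # (B, Nt, E)
print(weights.shape)  # (B, Nt, Ns)
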
def forward(
    self,
    query: flow.Tensor,
    key: flow.Tensor,
    value: flow.Tensor,
    attn_mask: Optional[flow.Tensor] = None,
) -> Tuple[flow.Tensor, flow.Tensor]:
    r"""
    Args:
        query: [batch, num_attention_heads, len_query, dim_query]
        key: [batch, num_attention_heads, len_key, dim_key]
        value: [batch, num_attention_heads, len_value, dim_value]
        attn_mask: [batch, num_attention_heads, len_query, len_key]
    """
    # (..., len_query, dim) x (..., dim, len_key) -> (..., len_query, len_key),
    # broadcast over the batch and head dimensions.
    attention = flow.matmul(query, key.transpose(-1, -2))
    attention = attention / math.sqrt(query.size(-1))
    if attn_mask is not None:
        attention = attention + attn_mask
    attention = nn.Softmax(dim=-1)(attention)
    attention = self.dropout(attention)
    # (..., len_query, len_key) x (..., len_key, dim_value) -> (..., len_query, dim_value)
    context = flow.matmul(attention, value)
    return context, attention
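
# For completeness, a self-contained sketch of a module this forward could live in.
# The class name ScaledDotProductAttention, its constructor, and the dropout rate
# are assumptions; the forward body repeats the computation above so the sketch
# runs on its own.
import math
from typing import Optional, Tuple

import oneflow as flow
from oneflow import nn


class ScaledDotProductAttention(nn.Module):
    def __init__(self, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        query: flow.Tensor,
        key: flow.Tensor,
        value: flow.Tensor,
        attn_mask: Optional[flow.Tensor] = None,
    ) -> Tuple[flow.Tensor, flow.Tensor]:
        # Same computation as the forward above, compressed.
        attention = flow.matmul(query, key.transpose(-1, -2)) / math.sqrt(query.size(-1))
        if attn_mask is not None:
            attention = attention + attn_mask
        attention = self.dropout(nn.Softmax(dim=-1)(attention))
        return flow.matmul(attention, value), attention


# Usage with per-head 4-D tensors; unlike _scaled_dot_product_attention above,
# which folds the heads into the batch axis, here the heads keep their own
# dimension and flow.matmul broadcasts over it.
layer = ScaledDotProductAttention(dropout=0.1)
query = flow.randn(2, 8, 10, 64)  # [batch, num_attention_heads, len_query, dim_query]
key = flow.randn(2, 8, 12, 64)    # [batch, num_attention_heads, len_key, dim_key]
value = flow.randn(2, 8, 12, 64)  # [batch, num_attention_heads, len_value, dim_value]
context, weights = layer(query, key, value)
print(context.shape)   # (2, 8, 10, 64)
print(weights.shape)   # (2, 8, 10, 12)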