def forward(self, key, value, query, mask, aw_prev=None, aw_lower=None,
            cache=False, mode='', trigger_points=None, eps_wait=-1,
            streaming=False):
    """Forward pass.

    Args:
        key (FloatTensor): `[B, klen, kdim]`
        value (FloatTensor): `[B, klen, vdim]`
        query (FloatTensor): `[B, qlen, qdim]`
        mask (ByteTensor): `[B, qlen, klen]`
        aw_prev: dummy interface
        aw_lower: dummy interface
        cache (bool): cache key, value, and mask
        mode: dummy interface for MoChA/MMA
        trigger_points: dummy interface for MoChA/MMA
        eps_wait: dummy interface for MMA
        streaming: dummy interface for streaming attention
    Returns:
        cv (FloatTensor): `[B, qlen, vdim]`
        aw (FloatTensor): `[B, H, qlen, klen]`
        beta: dummy interface for MoChA/MMA
        p_choose: dummy interface for MoChA/MMA

    """
    bs, klen = key.size()[:2]
    qlen = query.size(1)

    # Pre-computation of encoder-side features for computing scores
    if self.key is None or not cache:
        self.key = self.w_key(key).view(bs, -1, self.n_heads, self.d_k)  # `[B, klen, H, d_k]`
        self.value = self.w_value(value).view(bs, -1, self.n_heads, self.d_k)  # `[B, klen, H, d_k]`
        if mask is not None:
            self.mask = mask.unsqueeze(3).repeat([1, 1, 1, self.n_heads])
            mask_size = (bs, qlen, klen, self.n_heads)
            assert self.mask.size() == mask_size, (self.mask.size(), mask_size)
        else:
            self.mask = None

    key = self.key
    query = self.w_query(query).view(bs, -1, self.n_heads, self.d_k)  # `[B, qlen, H, d_k]`

    if self.atype == 'scaled_dot':
        e = torch.einsum("bihd,bjhd->bijh", (query, key)) / self.scale
    elif self.atype == 'add':
        e = self.v(torch.tanh(key[:, None] + query[:, :, None]).view(bs, qlen, klen, -1))
    # e: `[B, qlen, klen, H]`

    # Compute attention weights
    if self.mask is not None:
        NEG_INF = float(np.finfo(torch.tensor(0, dtype=e.dtype).numpy().dtype).min)
        e = e.masked_fill_(self.mask == 0, NEG_INF)  # `[B, qlen, klen, H]`
    aw = torch.softmax(e, dim=2)
    aw = self.dropout_attn(aw)
    aw_masked = aw.clone()

    # mask out each head independently (HeadDrop)
    if self.dropout_head > 0 and self.training:
        aw_masked = aw_masked.permute(0, 3, 1, 2)
        aw_masked = headdrop(aw_masked, self.n_heads, self.dropout_head)  # `[B, H, qlen, klen]`
        aw_masked = aw_masked.permute(0, 2, 3, 1)

    cv = torch.einsum("bijh,bjhd->bihd", (aw_masked, self.value))  # `[B, qlen, H, d_k]`
    cv = cv.contiguous().view(bs, -1, self.n_heads * self.d_k)  # `[B, qlen, H * d_k]`
    cv = self.w_out(cv)
    aw = aw.permute(0, 3, 1, 2)  # `[B, H, qlen, klen]`

    return cv, aw, None, None
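# ---------------------------------------------------------------------------
# A minimal, self-contained sketch of the scaled_dot path above, using plain
# random tensors in place of the module's learned projections (w_key, w_value,
# w_query, w_out). All dimension values below are arbitrary examples, not
# values taken from this repo; the point is the `[B, qlen, klen, H]` score
# layout and the softmax over the key axis (dim=2) used in forward().
import torch

B, klen, qlen, H, d_k = 2, 7, 3, 4, 8
query = torch.randn(B, qlen, H, d_k)   # stand-in for w_query(query).view(...)
key = torch.randn(B, klen, H, d_k)     # stand-in for w_key(key).view(...)
value = torch.randn(B, klen, H, d_k)   # stand-in for w_value(value).view(...)
mask = torch.ones(B, qlen, klen, H)    # 1 = attend, 0 = masked out

e = torch.einsum("bihd,bjhd->bijh", (query, key)) / (d_k ** 0.5)  # scores
e = e.masked_fill(mask == 0, float('-inf'))                       # `[B, qlen, klen, H]`
aw = torch.softmax(e, dim=2)                                      # normalize over keys
cv = torch.einsum("bijh,bjhd->bihd", (aw, value))                 # `[B, qlen, H, d_k]`
cv = cv.contiguous().view(B, qlen, H * d_k)                       # concatenate heads
assert cv.size() == (B, qlen, H * d_k)
# ---------------------------------------------------------------------------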
def forward(self, key, query, pos_embs, mask, u_bias=None, v_bias=None):
    """Forward pass.

    Args:
        key (FloatTensor): `[B, mlen+qlen, kdim]`
        query (FloatTensor): `[B, qlen, qdim]`
        pos_embs (FloatTensor): `[mlen+qlen, 1, d_model]`
        mask (ByteTensor): `[B, qlen, mlen+qlen]`
        u_bias (nn.Parameter): `[H, d_k]`
        v_bias (nn.Parameter): `[H, d_k]`
    Returns:
        cv (FloatTensor): `[B, qlen, vdim]`
        aw (FloatTensor): `[B, H, qlen, mlen+qlen]`

    """
    bs, qlen = query.size()[:2]
    mlen = key.size(1) - qlen
    # NOTE: key already includes the memory, i.e., klen = mlen + qlen

    if mask is not None:
        mask = mask.unsqueeze(3).repeat([1, 1, 1, self.n_heads])
        assert mask.size() == (bs, qlen, mlen + qlen, self.n_heads), \
            (mask.size(), (bs, qlen, mlen + qlen, self.n_heads))

    k = self.w_key(key).view(bs, -1, self.n_heads, self.d_k)  # `[B, mlen+qlen, H, d_k]`
    v = self.w_value(key).view(bs, -1, self.n_heads, self.d_k)  # `[B, mlen+qlen, H, d_k]`
    q = self.w_query(key[:, -qlen:]).view(bs, -1, self.n_heads, self.d_k)  # `[B, qlen, H, d_k]`

    if self.xl_like:
        _pos_embs = self.w_pos(pos_embs)  # dedicated positional projection (TransformerXL-style)
    else:
        _pos_embs = self.w_value(pos_embs)  # NOTE: reuses w_value as the positional projection
    _pos_embs = _pos_embs.view(-1, self.n_heads, self.d_k)  # `[mlen+qlen, H, d_k]`

    # content-based attention term: (a) + (c)
    if u_bias is not None:
        assert self.xl_like
        AC = torch.einsum("bihd,bjhd->bijh", (q + u_bias[None, None], k))  # `[B, qlen, mlen+qlen, H]`
    else:
        AC = torch.einsum("bihd,bjhd->bijh", (q, k))  # term (a) only; `[B, qlen, mlen+qlen, H]`

    # position-based attention term: (b) + (d)
    if v_bias is not None:
        assert self.xl_like
        BD = torch.einsum("bihd,jhd->bijh", (q + v_bias[None, None], _pos_embs))  # `[B, qlen, mlen+qlen, H]`
    else:
        BD = torch.einsum("bihd,jhd->bijh", (q, _pos_embs))  # term (b) only; `[B, qlen, mlen+qlen, H]`

    # Compute positional attention efficiently
    # BD = self._rel_shift_v1(BD)
    BD = self._rel_shift_v2(BD)

    # The attention score is the sum of the content-based and position-based terms
    e = (AC + BD) / self.scale  # `[B, qlen, mlen+qlen, H]`

    # Compute attention weights
    if mask is not None:
        NEG_INF = float(np.finfo(torch.tensor(0, dtype=e.dtype).numpy().dtype).min)
        e = e.masked_fill_(mask == 0, NEG_INF)  # `[B, qlen, mlen+qlen, H]`
    aw = torch.softmax(e, dim=2)
    aw = self.dropout_attn(aw)  # `[B, qlen, mlen+qlen, H]`
    aw_masked = aw.clone()

    # mask out each head independently (HeadDrop)
    if self.dropout_head > 0 and self.training:
        aw_masked = aw_masked.permute(0, 3, 1, 2)
        aw_masked = headdrop(aw_masked, self.n_heads, self.dropout_head)  # `[B, H, qlen, mlen+qlen]`
        aw_masked = aw_masked.permute(0, 2, 3, 1)

    cv = torch.einsum("bijh,bjhd->bihd", (aw_masked, v))  # `[B, qlen, H, d_k]`
    cv = cv.contiguous().view(bs, -1, self.n_heads * self.d_k)  # `[B, qlen, H * d_k]`
    cv = self.w_out(cv)
    aw = aw.permute(0, 3, 1, 2)  # `[B, H, qlen, mlen+qlen]`

    return cv, aw
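# ---------------------------------------------------------------------------
# The position-based term BD above is aligned by `_rel_shift_v2`, whose body is
# not shown here. As an illustration only, the sketch below implements the
# standard Transformer-XL shift trick on a `[B, H, qlen, klen]` tensor (note:
# a different axis order from the `[B, qlen, klen, H]` layout used in
# forward()); the name `rel_shift_sketch` is hypothetical, and this is not the
# repo's `_rel_shift_v1`/`_rel_shift_v2` implementation.
import torch


def rel_shift_sketch(x):
    """Left-shift the scores of query row i by (qlen - 1 - i) along the key axis.

    This maps scores indexed by relative distance onto scores indexed by
    absolute key position; the positions shifted in from the right are junk
    values that the attention mask is expected to remove.
    """
    B, H, qlen, klen = x.size()
    zero_pad = x.new_zeros(B, H, qlen, 1)
    x_padded = torch.cat([zero_pad, x], dim=-1)       # `[B, H, qlen, klen+1]`
    x_padded = x_padded.view(B, H, klen + 1, qlen)    # reinterpret the flat buffer
    return x_padded[:, :, 1:].reshape(B, H, qlen, klen)  # drop the padded row


# Example with qlen == klen == 3: row i ends up shifted left by (qlen - 1 - i).
x = torch.arange(9, dtype=torch.float32).view(1, 1, 3, 3)
print(rel_shift_sketch(x)[0, 0])
# tensor([[2., 0., 3.],
#         [4., 5., 0.],
#         [6., 7., 8.]])
# ---------------------------------------------------------------------------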