def forward(
    self, query, key, value, mask_future_timesteps=False,
    key_padding_mask=None, use_scalar_bias=False,
):
    """Attend from ``query`` over ``key``/``value`` and project the result.

    Input shape: Time x Batch x Channel

    Self-attention can be implemented by passing in the same arguments for
    query, key and value. Future timesteps can be masked with the
    `mask_future_timesteps` argument. Padding elements can be excluded from
    the key by passing a binary mask (`key_padding_mask`) with shape:
    batch x src_len, where padding elements are indicated by 1s.

    Args:
        query: tensor of shape (tgt_len, bsz, channels). It is never
            modified in place.
        key: tensor of shape (src_len, bsz, channels).
        value: tensor with the same shape as ``key``.
        mask_future_timesteps: when true, every target position may only
            attend to strictly earlier key positions (the diagonal is
            masked too). Requires ``query`` and ``key`` to have the same
            shape.
        key_padding_mask: optional (bsz, src_len) byte or bool tensor;
            non-zero entries mark keys that must receive no attention.
        use_scalar_bias: when true, the attention scores and the values
            are passed through ``scalar_bias`` before the softmax.

    Returns:
        A tuple ``(attn, attn_weights)``. ``attn`` is ``self.out_proj``
        applied to the attended values, and ``attn_weights`` holds the
        attention probabilities after softmax and dropout, with the
        batch (times heads, when not downsampling) as leading dimension.

    Raises:
        AssertionError: if the input shapes are inconsistent, or if
            ``mask_future_timesteps`` is used with differently shaped
            ``query`` and ``key``.
    """
    src_len, bsz, out_channels = key.size()
    tgt_len = query.size(0)
    assert list(query.size()) == [tgt_len, bsz, out_channels]
    assert key.size() == value.size()

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    # Leading (batch) dimension used by the batched matmuls below.
    if self.downsample:
        size = bsz
    else:
        size = bsz * self.num_heads

    k = key
    v = value
    q = query
    if self.project_input:
        q = self.in_proj_q(q)
        k = self.in_proj_k(k)
        v = self.in_proj_v(v)
        # Re-read the key length in case the projection changed it.
        src_len = k.size()[0]
    # Scale out of place: without input projection ``q`` is the caller's
    # tensor (and, for self-attention, the same object as ``k`` and ``v``),
    # so an in-place multiply would corrupt the inputs and would fail on
    # leaf tensors that require grad.
    q = q * self.scaling

    if not self.downsample:
        # Split the channel dimension into heads folded into the batch.
        q = q.view(tgt_len, size, self.head_dim)
        k = k.view(src_len, size, self.head_dim)
        v = v.view(src_len, size, self.head_dim)

    # Time x Batch x Dim -> Batch x Time x Dim for torch.bmm.
    q = q.transpose(0, 1)
    k = k.transpose(0, 1)
    v = v.transpose(0, 1)

    attn_weights = torch.bmm(q, k.transpose(1, 2))
    if mask_future_timesteps:
        assert query.size() == key.size(), \
            'mask_future_timesteps only applies to self-attention'
        # Key positions are subsampled with this step when downsampling.
        key_step = self.head_index + 1 if self.downsample else 1
        # Zero the scores at the current and future positions ...
        attn_weights *= torch.tril(
            attn_weights.data.new([1]).expand(tgt_len, tgt_len).clone(),
            diagonal=-1,
        )[:, ::key_step].unsqueeze(0)
        # ... then push those same positions to -inf for the softmax.
        attn_weights += torch.triu(
            attn_weights.data.new([-math.inf]).expand(tgt_len, tgt_len).clone(),
            diagonal=0,
        )[:, ::key_step].unsqueeze(0)

    # The padding mask is applied before the optional scalar bias because
    # the views below rely on the key axis still having length src_len.
    # NOTE(review): assumes scalar_bias widens the given dim by one entry;
    # confirm against its definition.
    if key_padding_mask is not None:
        # don't attend to padding symbols
        if key_padding_mask.max() > 0:
            # NOTE(review): the mask was validated against the key length
            # before projection; if in_proj_k shortens the key axis the
            # broadcast below will not line up. Confirm with callers.
            if self.downsample:
                attn_weights = attn_weights.view(bsz, 1, tgt_len, src_len)
            else:
                # Heads are folded into the batch dimension, so unfold
                # with bsz (not bsz * num_heads) as the leading size.
                attn_weights = attn_weights.view(
                    bsz, self.num_heads, tgt_len, src_len,
                )
            # masked_fill expects a boolean-like mask; `!= 0` converts
            # byte masks without requiring a specific torch dtype API.
            attn_weights = attn_weights.masked_fill(
                (key_padding_mask != 0).unsqueeze(1).unsqueeze(2),
                -math.inf,
            )
            attn_weights = attn_weights.view(size, tgt_len, src_len)

    if use_scalar_bias:
        attn_weights = scalar_bias(attn_weights, 2)
        v = scalar_bias(v, 1)

    attn_weights = F.softmax(attn_weights, dim=-1)
    attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

    attn = torch.bmm(attn_weights, v)
    # Back to Time x Batch x Dim; heads (if any) are merged into channels.
    if self.downsample:
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.head_dim)
    else:
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)

    attn = self.out_proj(attn)

    return attn, attn_weights
def forward(
    self, query, key, value, mask_future_timesteps=False,
    key_padding_mask=None, use_scalar_bias=False,
):
    """Attend from ``query`` over ``key``/``value`` and project the result.

    Input shape: Time x Batch x Channel

    Self-attention can be implemented by passing in the same arguments for
    query, key and value. Future timesteps can be masked with the
    `mask_future_timesteps` argument. Padding elements can be excluded from
    the key by passing a binary mask (`key_padding_mask`) with shape:
    batch x src_len, where padding elements are indicated by 1s.

    Args:
        query: tensor of shape (tgt_len, bsz, channels). It is never
            modified in place.
        key: tensor of shape (src_len, bsz, channels).
        value: tensor with the same shape as ``key``.
        mask_future_timesteps: when true, every target position may only
            attend to strictly earlier key positions (the diagonal is
            masked too). Requires ``query`` and ``key`` to have the same
            shape.
        key_padding_mask: optional (bsz, src_len) byte or bool tensor;
            non-zero entries mark keys that must receive no attention.
        use_scalar_bias: when true, the attention scores and the values
            are passed through ``scalar_bias`` before the softmax.

    Returns:
        A tuple ``(attn, attn_weights)``. ``attn`` is ``self.out_proj``
        applied to the attended values, and ``attn_weights`` holds the
        attention probabilities after softmax and dropout, with the
        batch (times heads, when not downsampling) as leading dimension.

    Raises:
        AssertionError: if the input shapes are inconsistent, or if
            ``mask_future_timesteps`` is used with differently shaped
            ``query`` and ``key``.
    """
    src_len, bsz, out_channels = key.size()
    tgt_len = query.size(0)
    assert list(query.size()) == [tgt_len, bsz, out_channels]
    assert key.size() == value.size()

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    # Leading (batch) dimension used by the batched matmuls below.
    if self.downsample:
        size = bsz
    else:
        size = bsz * self.num_heads

    k = key
    v = value
    q = query
    if self.project_input:
        q = self.in_proj_q(q)
        k = self.in_proj_k(k)
        v = self.in_proj_v(v)
        # Re-read the key length in case the projection changed it.
        src_len = k.size()[0]
    # Scale out of place: without input projection ``q`` is the caller's
    # tensor (and, for self-attention, the same object as ``k`` and ``v``),
    # so an in-place multiply would corrupt the inputs and would fail on
    # leaf tensors that require grad.
    q = q * self.scaling

    if not self.downsample:
        # Split the channel dimension into heads folded into the batch.
        q = q.view(tgt_len, size, self.head_dim)
        k = k.view(src_len, size, self.head_dim)
        v = v.view(src_len, size, self.head_dim)

    # Time x Batch x Dim -> Batch x Time x Dim for torch.bmm.
    q = q.transpose(0, 1)
    k = k.transpose(0, 1)
    v = v.transpose(0, 1)

    attn_weights = torch.bmm(q, k.transpose(1, 2))
    if mask_future_timesteps:
        assert query.size() == key.size(), \
            'mask_future_timesteps only applies to self-attention'
        # Key positions are subsampled with this step when downsampling.
        key_step = self.head_index + 1 if self.downsample else 1
        # Zero the scores at the current and future positions ...
        attn_weights *= torch.tril(
            attn_weights.data.new([1]).expand(tgt_len, tgt_len).clone(),
            diagonal=-1,
        )[:, ::key_step].unsqueeze(0)
        # ... then push those same positions to -inf for the softmax.
        attn_weights += torch.triu(
            attn_weights.data.new([-math.inf]).expand(tgt_len, tgt_len).clone(),
            diagonal=0,
        )[:, ::key_step].unsqueeze(0)

    # The padding mask is applied before the optional scalar bias because
    # the views below rely on the key axis still having length src_len.
    # NOTE(review): assumes scalar_bias widens the given dim by one entry;
    # confirm against its definition.
    if key_padding_mask is not None:
        # don't attend to padding symbols
        if key_padding_mask.max() > 0:
            # NOTE(review): the mask was validated against the key length
            # before projection; if in_proj_k shortens the key axis the
            # broadcast below will not line up. Confirm with callers.
            if self.downsample:
                attn_weights = attn_weights.view(bsz, 1, tgt_len, src_len)
            else:
                # Heads are folded into the batch dimension, so unfold
                # with bsz (not bsz * num_heads) as the leading size.
                attn_weights = attn_weights.view(
                    bsz, self.num_heads, tgt_len, src_len,
                )
            # masked_fill expects a boolean-like mask; `!= 0` converts
            # byte masks without requiring a specific torch dtype API.
            attn_weights = attn_weights.masked_fill(
                (key_padding_mask != 0).unsqueeze(1).unsqueeze(2),
                -math.inf,
            )
            attn_weights = attn_weights.view(size, tgt_len, src_len)

    if use_scalar_bias:
        attn_weights = scalar_bias(attn_weights, 2)
        v = scalar_bias(v, 1)

    attn_weights = F.softmax(attn_weights, dim=-1)
    attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

    attn = torch.bmm(attn_weights, v)
    # Back to Time x Batch x Dim; heads (if any) are merged into channels.
    if self.downsample:
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.head_dim)
    else:
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)

    attn = self.out_proj(attn)

    return attn, attn_weights