import math

import torch
import torch.nn.functional as F


def attention(query, key, value, mask=None, dropout=None):
    """Application of generalised (scaled dot-product) attention.

    [Inputs]
        query   : standard query matrix of size (None, no_query, head_dim)
        key     : standard key matrix of size   (None, no_keys, head_dim)
        value   : standard value matrix of size (None, no_keys=no_values, model_dim)
        mask    : mask matrix of shape          (None, no_query, no_keys)
        dropout : dropout layer (e.g. nn.Dropout), applied to the attention weights
    [Output]
        context_vectors : context results after attention, of size (None, no_query, model_dim)
        p_attn          : matrix of attention probabilities, to help in visualisation,
                          of size (None, no_query, no_keys)
    """
    d_k = query.size(-1)
    # Scaled dot-product scores: (no_query, head_dim) @ (head_dim, no_keys) -> (no_query, no_keys)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Masked positions get a large negative score so softmax drives them to ~0
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn
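
# A minimal usage sketch (not from the original): exercises attention() with random
# tensors and an illustrative causal mask. The batch size, sequence lengths, and
# dimensions below are assumptions chosen only to show the expected shapes.
batch, no_query, no_keys, head_dim = 2, 5, 5, 64
q = torch.randn(batch, no_query, head_dim)
k = torch.randn(batch, no_keys, head_dim)
v = torch.randn(batch, no_keys, head_dim)
# Lower-triangular causal mask: position i may only attend to positions <= i
causal_mask = torch.tril(torch.ones(1, no_query, no_keys))
context, p_attn = attention(q, k, v, mask=causal_mask)
print(context.shape)  # torch.Size([2, 5, 64])
print(p_attn.shape)   # torch.Size([2, 5, 5])
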
def forward(self, x, layer_past=None):
    B, T, C = x.size()  # B -> batch, T -> time steps (sequence length), C -> embedding dim

    # Project to keys, queries, values and split into heads: (B, nh, T, hs)
    k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
    q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
    v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

    # How does tensor multiplication work, i.e. how do we check that two tensors are
    # compatible? Batched matmul broadcasts the leading dims (B, nh) and requires the
    # last two dims to be matmul-compatible: (..., T, hs) @ (..., hs, T) -> (..., T, T)
    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
    att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
    att = F.softmax(att, dim=-1)  # normalise over the key dimension
    att = self.attn_drop(att)
    y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) => (B, nh, T, hs)

    # Re-assemble all head outputs side by side and project back to the residual stream
    y = y.transpose(1, 2).contiguous().view(B, T, C)
    y = self.resid_drop(self.proj(y))
    return y
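
# A minimal sketch (an assumption, not the original class definition) of how the module
# that owns forward() above might be wired up, inferred from the attributes it uses:
# key/query/value projections, n_head, a cached causal mask, two dropouts, and an output
# projection. The class name "CausalSelfAttention" and the hyperparameter defaults
# (n_embd, n_head, block_size, dropout rates) are illustrative, not from the original.
import torch.nn as nn


class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd=128, n_head=4, block_size=64, attn_pdrop=0.1, resid_pdrop=0.1):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        # Fused per-head projections: each maps C -> C, then forward() splits into heads
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)
        self.proj = nn.Linear(n_embd, n_embd)
        # Causal mask of shape (1, 1, block_size, block_size), sliced to (T, T) in forward()
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size),
        )

    # Reuse the forward() defined above as this module's forward pass
    forward = forward


# Illustrative call: (batch, seq_len, embedding_dim) in, same shape out
attn = CausalSelfAttention()
out = attn(torch.randn(2, 16, 128))  # -> (2, 16, 128)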