Code example #1
import math

import torch
import torch.nn.functional as F


def attention(query, key, value, mask=None, dropout=None):
	"""Application of generalised attention

	[Inputs]
	query : standard query matrix of size:(None, no_query, head_dim)
	key : standary key matrix of size : (None, no_keys, head_dim
	values  : standardn value matrix of size : (None, no_keys=no_values, model_dim)
	mask : mask matrix of shape (None, no_query, no_keys)
	dropout : dropout rate

	[Output]
	context_vectors : context results after attention of size : (None, no_query, model_dim)
	p_attn : matrix of attention probabilities to help in visualisation of size : (None, no_query, no_keys)"""

	d_k = query.size(-1)
	scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

	if mask is not None:
		scores = scores.masked_fill(mask == 0, -1e9)

	p_attn = F.softmax(scores, dim=-1)
	if dropout is not None:
		p_attn = dropout(p_attn)

	return torch.matmul(p_attn, value), p_attn
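
A minimal usage sketch for the function above (the shapes and the nn.Dropout layer here are illustrative assumptions, not part of the original snippet):

import torch
import torch.nn as nn

# toy shapes: batch=2, 5 queries, 7 keys/values, head_dim=16
query = torch.randn(2, 5, 16)
key = torch.randn(2, 7, 16)
value = torch.randn(2, 7, 16)
mask = torch.ones(2, 5, 7)   # 1 = attend, 0 = masked out
drop = nn.Dropout(p=0.1)

context, p_attn = attention(query, key, value, mask=mask, dropout=drop)
print(context.shape)  # torch.Size([2, 5, 16])
print(p_attn.shape)   # torch.Size([2, 5, 7])
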
Code example #2
File: model.py  Project: SiddeshSambasivam/minGPT
    def forward(self, x, layer_past=None):
        B, T, C = x.size()
        # | B -> Batch 
        # | T -> Time step (sequence len) 
        # | C -> Embedding Dim

        # project inputs and split into heads: (B, T, C) -> (B, nh, T, hs)
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        # batched matrix multiply: (B, nh, T, hs) @ (B, nh, hs, T) -> (B, nh, T, T);
        # the last dim of q must match the second-to-last dim of k.transpose(-2, -1),
        # and the leading (B, nh) dims broadcast elementwise.
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # causal mask: positions where mask == 0 (future tokens) get -inf before softmax
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)  # normalise over the key dimension
        att = self.attn_drop(att)
        y = att @ v # (B, nh, T, T) x (B,nh,T,hs) => (B, nh, T, hs)
        y = y.transpose(1,2).contiguous().view(B,T,C)

        y = self.resid_drop(self.proj(y))
        return y
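
The forward method above relies on attributes defined elsewhere in the class (self.key, self.query, self.value, self.mask, self.attn_drop, self.resid_drop, self.proj, self.n_head). A sketch of how such a causal self-attention module is typically initialised in minGPT-style code; the constructor arguments and default dropout rates here are assumptions for illustration, not taken from the snippet:

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd, n_head, block_size, attn_pdrop=0.1, resid_pdrop=0.1):
        super().__init__()
        assert n_embd % n_head == 0
        # key, query, value projections for all heads (argument names are hypothetical)
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        # dropout on attention weights and on the residual output
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)
        # output projection back to the embedding dimension
        self.proj = nn.Linear(n_embd, n_embd)
        # lower-triangular causal mask, shaped (1, 1, block_size, block_size)
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size),
        )
        self.n_head = n_head

    # the forward(self, x, layer_past=None) shown above would go here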