def forward(self, x, mask):
    """
    x, q(query), k(key), v(value) : (B(batch_size), S(seq_len), D(dim))
    mask : (B(batch_size) x S(seq_len))
    * split D(dim) into (H(n_heads), W(width of head)) ; D = H * W
    """
    # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
    q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x)
    # q, k, v = torch.squeeze(self.proj_q(torch.unsqueeze(x, dim=1)), dim=1), \
    #           torch.squeeze(self.proj_k(torch.unsqueeze(x, dim=1)), dim=1), \
    #           self.proj_v(x)
    q, k, v = (split_last(x, (self.n_heads, -1)).transpose(1, 2) for x in [q, k, v])
    # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S)
    scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1))
    if mask is not None:
        mask = mask[:, None, None, :].float()
        scores -= 10000.0 * (1.0 - mask)
    scores = self.drop(F.softmax(scores, dim=-1))
    # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W)
    h = (scores @ v).transpose(1, 2).contiguous()
    # -merge-> (B, S, D)
    h = merge_last(h, 2)
    self.scores = scores
    return h
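All three variants rely on `split_last` and `merge_last` helpers that are not shown in this section. Below is a minimal sketch, under the assumption that they simply reshape the last dimension into several dimensions and flatten trailing dimensions back; the exact implementations in the original codebases may differ.

import torch

def split_last(x, shape):
    # Split the last dimension of x into the given shape;
    # a single -1 entry is inferred from the remaining sizes.
    shape = list(shape)
    assert shape.count(-1) <= 1
    if -1 in shape:
        rest = 1
        for s in shape:
            if s != -1:
                rest *= s
        shape[shape.index(-1)] = x.size(-1) // rest
    return x.view(*x.size()[:-1], *shape)

def merge_last(x, n_dims):
    # Flatten the last n_dims dimensions of x into one.
    s = x.size()
    assert 1 < n_dims < len(s)
    return x.view(*s[:-n_dims], -1)

# illustrative usage (sizes are hypothetical):
q = torch.randn(2, 16, 768)            # (B, S, D)
qs = split_last(q, (12, -1))           # (B, S, H, W) = (2, 16, 12, 64)
back = merge_last(qs, 2)               # (B, S, D)   = (2, 16, 768)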
def forward(self, x, mask):
    """
    :param x, q, k, v: (Batch_size, Seq_len, Dim)
    :param mask: (Batch_size, Seq_len)
    * split Dim into (H(n_heads), W(width of head)) ; Dim = H * W
    :return: (Batch_size, Seq_len, Dim)
    """
    # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
    q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x)
    # split into n_heads pieces: q, k, v become (batch, n_head, seq_length, head_features)
    q, k, v = (split_last(x, (self.n_heads, -1)).transpose(1, 2) for x in [q, k, v])
    # scaled dot-product attention (handles the multi-head case)
    # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S)
    scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1))
    if mask is not None:
        mask = mask[:, None, None, :].float()
        scores -= 10000.0 * (1.0 - mask)
    scores = self.drop(F.softmax(scores, dim=-1))
    # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W)
    h = (scores @ v).transpose(1, 2).contiguous()
    # -merge-> (B, S, D=H*W)
    h = merge_last(h, 2)
    self.scores = scores
    return h
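The additive mask in the two snippets above broadcasts the (B, S) padding mask to (B, 1, 1, S) and subtracts a large constant from the logits at padded key positions, so their softmax weights vanish. A small standalone sketch (tensor sizes are illustrative, not from the source) showing the effect, and that it matches the more common `masked_fill` formulation:

import torch
import torch.nn.functional as F

B, H, S = 2, 4, 5
scores = torch.randn(B, H, S, S)            # raw attention logits
mask = torch.tensor([[1, 1, 1, 0, 0],       # 1 = real token, 0 = padding
                     [1, 1, 1, 1, 0]])

# broadcast (B, S) -> (B, 1, 1, S) and push padded key logits far below the rest
m = mask[:, None, None, :].float()
weights = F.softmax(scores - 10000.0 * (1.0 - m), dim=-1)

# same effect with masked_fill: exp() underflows to 0 at masked positions either way
weights_alt = F.softmax(scores.masked_fill(mask[:, None, None, :] == 0, -10000.0), dim=-1)
print(torch.allclose(weights, weights_alt))  # True for logits of ordinary magnitude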
def forward(self, x):  # mask = None
    """
    x, q(query), k(key), v(value) : (B(batch_size), S(seq_len), D(hidden_dim))
    mask : (B(batch_size) x S(seq_len))
    * split D(hidden_dim) into (H(n_heads), W(width of head)) ; D = H * W
    """
    # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
    q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x)
    q, k, v = (split_last(x, (self.n_heads, -1)).transpose(1, 2) for x in [q, k, v])
    # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S)
    scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1))
    scores = self.drop(F.softmax(scores, dim=-1))
    # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W)
    h = (scores @ v).transpose(1, 2).contiguous()
    # -merge-> (B, S, D)
    h = merge_last(h, 2)
    self.scores = scores
    return h
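None of the snippets show where `proj_q`, `proj_k`, `proj_v`, `drop`, and `n_heads` are defined. Below is a hypothetical module skeleton that would support all three `forward` variants; the `nn.Linear` projections are an assumption (the commented-out unsqueeze/squeeze lines in the first snippet hint that that codebase may use convolutional projections instead).

import torch
import torch.nn as nn

class MultiHeadedSelfAttention(nn.Module):
    # Hypothetical skeleton; attribute names chosen to match the forward() snippets above.
    def __init__(self, dim, n_heads, dropout=0.1):
        super().__init__()
        assert dim % n_heads == 0, "D must be divisible by H so that D = H * W"
        self.proj_q = nn.Linear(dim, dim)   # query projection, (B, S, D) -> (B, S, D)
        self.proj_k = nn.Linear(dim, dim)   # key projection
        self.proj_v = nn.Linear(dim, dim)   # value projection
        self.drop = nn.Dropout(dropout)     # dropout applied to the attention weights
        self.n_heads = n_heads
        self.scores = None                  # last attention map, kept for visualization

# illustrative usage once one of the forward() variants above is attached:
#   attn = MultiHeadedSelfAttention(dim=768, n_heads=12)
#   out = attn(torch.randn(2, 16, 768))    # (B, S, D) in, (B, S, D) out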