class Attention(nn.Module):
    ''' Attention mechanism
        please refer to http://www.aclweb.org/anthology/D15-1166 section 3.1 for more details about Attention implementation
        Input : Decoder state                      with shape [batch size, decoder hidden dimension]
                Compressed feature from Encoder    with shape [batch size, T, encoder feature dimension]
        Output: Attention score                    with shape [batch size, num head, T (attention score of each time step)]
                Context vector                     with shape [batch size, encoder feature dimension]
                (i.e. weighted (by attention score) sum of all timesteps T's feature) '''

    def __init__(self, v_dim, q_dim, mode, dim, num_head, temperature, v_proj,
                 loc_kernel_size, loc_kernel_num):
        super(Attention, self).__init__()

        # Setup
        self.v_dim = v_dim
        self.dim = dim
        self.mode = mode.lower()
        self.num_head = num_head

        # Linear proj. before attention
        ## Q, K, V
        self.proj_q = nn.Linear(q_dim, dim*num_head)
        self.proj_k = nn.Linear(v_dim, dim*num_head)
        self.v_proj = v_proj
        if v_proj:
            self.proj_v = nn.Linear(v_dim, v_dim*num_head)

        # Attention
        if self.mode == 'dot':
            self.att_layer = ScaleDotAttention(temperature, self.num_head)
        elif self.mode == 'loc':
            self.att_layer = LocationAwareAttention(
                loc_kernel_size, loc_kernel_num, dim, num_head, temperature)
        else:
            raise NotImplementedError

        # Layer for merging MHA
        if self.num_head > 1:
            self.merge_head = nn.Linear(v_dim*num_head, v_dim)

        # Stored feature
        self.key = None
        self.value = None
        self.mask = None

    def reset_mem(self):
        # Clear cached key/value/mask (call once per new utterance)
        self.key = None
        self.value = None
        self.mask = None
        self.att_layer.reset_mem()

    def set_mem(self, prev_attn):
        self.att_layer.set_mem(prev_attn)

    def forward(self, dec_state, enc_feat, enc_len):
        # Preprocessing
        bs, ts, _ = enc_feat.shape
        query = torch.tanh(self.proj_q(dec_state))
        query = query.view(bs, self.num_head, self.dim).view(
            bs*self.num_head, self.dim)  # BNxD

        if self.key is None:
            # Mask out attention scores for padded states
            self.att_layer.compute_mask(enc_feat, enc_len.to(enc_feat.device))

            # Store encoder states to lower computational cost
            self.key = torch.tanh(self.proj_k(enc_feat))
            self.value = torch.tanh(self.proj_v(
                enc_feat)) if self.v_proj else enc_feat  # BxTxNV if v_proj else BxTxV

            if self.num_head > 1:
                self.key = self.key.view(bs, ts, self.num_head, self.dim).permute(
                    0, 2, 1, 3)  # BxNxTxD
                self.key = self.key.contiguous().view(
                    bs*self.num_head, ts, self.dim)  # BNxTxD
                if self.v_proj:
                    self.value = self.value.view(
                        bs, ts, self.num_head, self.v_dim).permute(0, 2, 1, 3)  # BxNxTxV
                    self.value = self.value.contiguous().view(
                        bs*self.num_head, ts, self.v_dim)  # BNxTxV
                else:
                    self.value = self.value.repeat(self.num_head, 1, 1)

        # Calculate attention
        context, attn = self.att_layer(query, self.key, self.value)
        if self.num_head > 1:
            context = context.view(
                bs, self.num_head*self.v_dim)  # BNxV -> BxNV
            context = self.merge_head(context)  # BxV

        return attn, context
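
# --- Usage sketch (illustrative only, not part of the original module) ---
# A minimal example of how a decoder loop would drive this Attention block,
# assuming ScaleDotAttention is defined elsewhere in this module and exposes the
# compute_mask / reset_mem / (query, key, value) -> (context, attn) interface
# used above. All sizes below are arbitrary and chosen only for illustration.
def _demo_attention():
    bs, ts = 4, 50                  # batch size, number of encoder timesteps
    enc_dim, dec_dim = 256, 512     # encoder feature dim (v_dim), decoder state dim (q_dim)
    num_head = 2
    att = Attention(v_dim=enc_dim, q_dim=dec_dim, mode='dot', dim=128,
                    num_head=num_head, temperature=0.5, v_proj=True,
                    loc_kernel_size=100, loc_kernel_num=10)

    enc_feat = torch.randn(bs, ts, enc_dim)             # BxTxV encoder output
    enc_len = torch.full((bs,), ts, dtype=torch.long)   # true length of each utterance
    dec_state = torch.randn(bs, dec_dim)                # current decoder hidden state

    att.reset_mem()                                     # clear cached key/value/mask per utterance
    attn, context = att(dec_state, enc_feat, enc_len)
    # Per the class docstring: attn is [B, num_head, T], context is [B, v_dim]
    assert attn.shape == (bs, num_head, ts)
    assert context.shape == (bs, enc_dim)
    return attn, context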