def score(self, h_t, h_s): """ Args: h_t (`FloatTensor`): sequence of queries `[batch x tgt_len x dim]` h_s (`FloatTensor`): sequence of sources `[batch x src_len x dim]` Returns: :obj:`FloatTensor`: raw attention scores (unnormalized) for each src index `[batch x tgt_len x src_len]` """ # Check input sizes src_batch, src_len, src_dim = h_s.size() tgt_batch, tgt_len, tgt_dim = h_t.size() aeq(src_batch, tgt_batch) aeq(src_dim, tgt_dim) aeq(self.dim, src_dim) if self.attn_type in ["general", "dot"]: if self.attn_type == "general": h_t_ = h_t.view(tgt_batch * tgt_len, tgt_dim) h_t_ = self.linear_in(h_t_) h_t_ = self.dropout(h_t_) h_t = h_t_.view(tgt_batch, tgt_len, tgt_dim) h_s_ = h_s.transpose(1, 2) # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len) return torch.bmm(h_t, h_s_) else: dim = self.dim wq = self.linear_query(h_t.view(-1, dim)) wq = wq.view(tgt_batch, tgt_len, 1, dim) wq = wq.expand(tgt_batch, tgt_len, src_len, dim) uh = self.linear_context(h_s.contiguous().view(-1, dim)) uh = uh.view(src_batch, 1, src_len, dim) uh = uh.expand(src_batch, tgt_len, src_len, dim) # (batch, t_len, s_len, d) wquh = torch.tanh(wq + uh) return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len)
def _check_args(self, src, lengths=None, hidden=None): _, n_batch, _ = src.size() if lengths is not None: n_batch_, = lengths.size() aeq(n_batch, n_batch_)
def forward(self, source, memory_bank, memory_lengths=None, coverage=None, is_mask=False): """ Args: source (`FloatTensor`): query vectors `[batch x tgt_len x dim]` memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]` memory_lengths (`LongTensor`): the source context lengths `[batch]` coverage (`FloatTensor`): None (not supported yet) Returns: (`FloatTensor`, `FloatTensor`): * Computed vector `[tgt_len x batch x dim]` * Attention distribtutions for each query `[tgt_len x batch x src_len]` """ # one step input if source.dim() == 2: one_step = True source = source.unsqueeze(1) else: one_step = False batch, source_l, dim = memory_bank.size() batch_, target_l, dim_ = source.size() aeq(batch, batch_) aeq(dim, dim_) aeq(self.dim, dim) if coverage is not None: batch_, source_l_ = coverage.size() aeq(batch, batch_) aeq(source_l, source_l_) if coverage is not None: cover = coverage.view(-1).unsqueeze(1) memory_bank += self.linear_cover(cover).view_as(memory_bank) memory_bank = torch.tanh(memory_bank) # compute attention scores, as in Luong et al. align = self.score(source, memory_bank) if memory_lengths is not None: if not is_mask: mask = sequence_mask(memory_lengths, max_len=align.size(-1)) else: mask = memory_lengths.byte() mask = mask.unsqueeze(1) # Make it broadcastable. align.masked_fill_(~mask, -float('inf')) # Softmax or sparsemax to normalize attention weights if self.attn_func == "softmax": align_vectors = F.softmax(align.view(batch*target_l, source_l), -1) else: align_vectors = sparsemax(align.view(batch*target_l, source_l), -1) align_vectors = align_vectors.view(batch, target_l, source_l) # each context vector c_t is the weighted average # over all the source hidden states c = torch.bmm(align_vectors, memory_bank) # concatenate concat_c = torch.cat([c, source], 2).view(batch*target_l, dim*2) attn_h = self.linear_out(concat_c).view(batch, target_l, dim) # attn_h = self.dropout(attn_h) # if self.attn_type in ["general", "dot"]: # attn_h = torch.tanh(attn_h) if one_step: attn_h = attn_h.squeeze(1) align_vectors = align_vectors.squeeze(1) # Check output sizes batch_, dim_ = attn_h.size() aeq(batch, batch_) aeq(dim, dim_) batch_, source_l_ = align_vectors.size() aeq(batch, batch_) aeq(source_l, source_l_) else: attn_h = attn_h.transpose(0, 1).contiguous() align_vectors = align_vectors.transpose(0, 1).contiguous() # Check output sizes target_l_, batch_, dim_ = attn_h.size() aeq(target_l, target_l_) aeq(batch, batch_) aeq(dim, dim_) target_l_, batch_, source_l_ = align_vectors.size() aeq(target_l, target_l_) aeq(batch, batch_) aeq(source_l, source_l_) return attn_h, align_vectors