def forward(self, source, memory_bank, memory_lengths=None,
            memory_turns=None, coverage=None):
    # Here we implement a hierarchical attention.
    if source.dim() == 2:
        source = source.unsqueeze(1)

    batch, source_tl, source_wl, dim = memory_bank.size()
    batch_, target_l, dim_ = source.size()
    aeq(batch, batch_)

    # word level attention
    word_align = self.word_score(source,
                                 memory_bank.contiguous().view(batch, -1, dim))
    # transform align (b, 1, tl * wl) -> (b * tl, 1, wl)
    word_align = word_align.view(batch * source_tl, 1, source_wl)

    if memory_lengths is not None:
        word_mask = sequence_mask_herd(memory_lengths.view(-1),
                                       max_len=word_align.size(-1))
        word_mask = word_mask.unsqueeze(1)  # Make it broadcastable.
        word_align.masked_fill_(1 - word_mask, -float('inf'))

    # Softmax or sparsemax to normalize attention weights
    if self.attn_func == "softmax":
        word_align_vectors = F.softmax(
            word_align.view(batch * source_tl, source_wl), -1)
    else:
        word_align_vectors = sparsemax(
            word_align.view(batch * source_tl, source_wl), -1)

    # mask out sentences that consist entirely of padding
    sent_pad_mask = memory_lengths.view(-1).eq(0).unsqueeze(1)
    word_align_vectors = torch.mul(
        word_align_vectors, (1.0 - sent_pad_mask).type_as(word_align_vectors))
    word_align_vectors = word_align_vectors.view(
        batch * source_tl, target_l, source_wl)

    # each context vector c_t is the weighted average
    # over all the source hidden states
    cw = torch.bmm(word_align_vectors,
                   memory_bank.view(batch * source_tl, source_wl, -1))
    cw = cw.view(batch, source_tl, -1)
    # concat_cw = torch.cat([cw, source.repeat(1, source_tl, 1)], 2).view(batch*source_tl, -1)
    # attn_hw = self.word_linear_out(concat_cw).view(batch, source_tl, -1)
    # attn_hw = torch.tanh(attn_hw)

    # turn level attention
    turn_align = self.turn_score(source, cw)

    if memory_turns is not None:
        turn_mask = sequence_mask(memory_turns, max_len=turn_align.size(-1))
        turn_mask = turn_mask.unsqueeze(1)  # Make it broadcastable.
        turn_align.masked_fill_(1 - turn_mask, -float('inf'))

    # Softmax or sparsemax to normalize attention weights
    if self.attn_func == "softmax":
        turn_align_vectors = F.softmax(
            turn_align.view(batch * target_l, source_tl), -1)
    else:
        turn_align_vectors = sparsemax(
            turn_align.view(batch * target_l, source_tl), -1)
    turn_align_vectors = turn_align_vectors.view(batch, target_l, source_tl)

    # each context vector c_t is the weighted average
    # over all the source hidden states
    ct = torch.bmm(turn_align_vectors, cw)

    return ct.squeeze(1), None
def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
    # Here we do not need to recompute the alignment for the answer,
    # because the answer vector is already an averaged representation.
    if source.dim() == 2:
        source = source.unsqueeze(1)

    batch, source_l, dim = memory_bank.size()
    batch_, target_l, dim_ = source.size()
    aeq(batch, batch_)

    # compute attention scores, as in Luong et al.
    align = self.score(source, memory_bank)

    if memory_lengths is not None:
        mask = sequence_mask(memory_lengths, max_len=align.size(-1))
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        align.masked_fill_(1 - mask, -float('inf'))

    # Softmax or sparsemax to normalize attention weights
    if self.attn_func == "softmax":
        align_vectors = F.softmax(align.view(batch * target_l, source_l), -1)
    else:
        align_vectors = sparsemax(align.view(batch * target_l, source_l), -1)
    align_vectors = align_vectors.view(batch, target_l, source_l)

    # each context vector c_t is the weighted average
    # over all the source hidden states
    c = torch.bmm(align_vectors, memory_bank)

    return c.squeeze(1), align_vectors
def forward(self, memory_bank, memory_lengths=None, coverage=None):
    """
    Args:
      memory_bank (FloatTensor): source vectors ``(batch, src_len, dim)``
      memory_lengths (LongTensor): the source context lengths ``(batch,)``
      coverage (FloatTensor): None (not supported yet)

    Returns:
      FloatTensor: computed context vector ``(batch, dim)``, obtained by
      attending with the learned query ``self.source``
    """
    batch, source_l, dim = memory_bank.size()

    # the query is a learned parameter shared across the batch
    source = self.source
    source = source.expand(batch, -1)
    source = source.unsqueeze(1)

    batch_, target_l, dim_ = source.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    aeq(self.dim, dim)
    if coverage is not None:
        batch_, source_l_ = coverage.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    if coverage is not None:
        cover = coverage.view(-1).unsqueeze(1)
        memory_bank += self.linear_cover(cover).view_as(memory_bank)
        memory_bank = torch.tanh(memory_bank)

    # compute attention scores, as in Luong et al.
    align = self.score(source, memory_bank)

    if memory_lengths is not None:
        mask = sequence_mask(memory_lengths, max_len=align.size(-1))
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        align.masked_fill_(~mask, -float('inf'))

    # Softmax or sparsemax to normalize attention weights
    if self.attn_func == "softmax":
        align_vectors = F.softmax(align.view(batch * target_l, source_l), -1)
    else:
        align_vectors = sparsemax(align.view(batch * target_l, source_l), -1)
    align_vectors = align_vectors.view(batch, target_l, source_l)

    # each context vector c_t is the weighted average
    # over all the source hidden states
    c = torch.bmm(align_vectors, memory_bank)  # (batch, target_l, dim)
    c = c.mean(dim=1)                          # (batch, dim)

    # Check output sizes
    batch_, dim_ = c.size()
    aeq(batch, batch_)
    aeq(dim, dim_)

    return c
def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
    """
    Args:
      source (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
      memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
      memory_lengths (`LongTensor`): the source context lengths `[batch]`
      coverage (`FloatTensor`): None (not supported yet)

    Returns:
      (`FloatTensor`, `FloatTensor`):

      * Computed vector `[tgt_len x batch x dim]`
      * Attention distributions for each query `[tgt_len x batch x src_len]`
    """
    if source.dim() == 2:
        source = source.unsqueeze(1)

    batch, source_l, dim = memory_bank.size()
    batch_, target_l, dim_ = source.size()
    aeq(batch, batch_)

    # compute attention scores, as in Luong et al.
    align = self.score(source, memory_bank)

    if memory_lengths is not None:
        mask = sequence_mask(memory_lengths, max_len=align.size(-1))
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        align.masked_fill_(1 - mask, -float('inf'))

    # Softmax or sparsemax to normalize attention weights
    if self.attn_func == "softmax":
        align_vectors = F.softmax(align.view(batch * target_l, source_l), -1)
    else:
        align_vectors = sparsemax(align.view(batch * target_l, source_l), -1)
    align_vectors = align_vectors.view(batch, target_l, source_l)

    # each context vector c_t is the weighted average
    # over all the source hidden states
    c = torch.bmm(align_vectors, memory_bank)

    # concatenate
    concat_c = torch.cat([c, source], 2).view(batch * target_l,
                                              self.indim + self.outdim)
    attn_h = self.linear_out(concat_c).view(batch, target_l, self.outdim)

    return attn_h.squeeze(1), align_vectors.squeeze(1)
def forward(cls, ctx, input, target):
    """
    input (FloatTensor): n x num_classes
    target (LongTensor): n, the indices of the target classes
    """
    assert_equal(input.shape[0], target.shape[0])

    p_star = sparsemax(input, 1)
    cls.p_star = p_star.clone().detach()
    loss = _omega_sparsemax(p_star)

    p_star.scatter_add_(1, target.unsqueeze(1), torch.full_like(p_star, -1))
    loss += torch.einsum("ij,ij->i", p_star, input)

    ctx.save_for_backward(p_star)

    return loss
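# The forward above is one half of a custom autograd Function: the tensor it
# saves, p_star - onehot(target), is exactly d(loss_i)/d(input_i). A hedged
# sketch of the matching backward, assuming the enclosing class follows the
# usual torch.autograd.Function pattern (the class itself is not shown here):
@classmethod
def backward(cls, ctx, grad_output):
    p_star, = ctx.saved_tensors                # holds p_star - onehot(target)
    grad = grad_output.unsqueeze(1) * p_star   # chain rule over the batch dim
    return grad, None                          # no gradient for integer targets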
def attn_map(self, Z):
    if self.attn_func == "softmax":
        return F.softmax(Z, -1)
    elif self.attn_func == "esoftmax":
        return esoftmax(Z, -1)
    elif self.attn_func == "sparsemax":
        return sparsemax(Z, -1)
    elif self.attn_func == "tsallis15":
        return tsallis15(Z, -1)
    elif self.attn_func == "tsallis":
        if self.attn_alpha == 2:
            # slightly faster specialized impl
            return sparsemax_bisect(Z, self.bisect_iter)
        else:
            return tsallis_bisect(Z, self.attn_alpha, self.bisect_iter)
    raise ValueError("invalid combination of arguments")
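# A quick way to see how the branches above differ: map the same scores with
# softmax and sparsemax. Sparsemax puts exact zeros on low-scoring positions,
# which is why the masked -inf trick used below behaves the same for both.
# Hedged sketch; it assumes `sparsemax` is the same mapping used throughout
# these snippets (e.g. the one exposed by the pip-installable `entmax`).
def demo_softmax_vs_sparsemax():
    import torch
    import torch.nn.functional as F
    from entmax import sparsemax  # assumption: any sparsemax(Z, dim) works here

    z = torch.tensor([[3.0, 1.0, 0.2, -1.0]])
    print(F.softmax(z, -1))   # dense: every entry strictly positive
    print(sparsemax(z, -1))   # sparse: trailing entries are exactly 0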
def forward(self, source, memory_bank, memory_lengths=None, coverage=None,
            sent_align_vectors=None, sent_nums=None):
    """
    Only one-step attention is supported now.

    Args:
      source (`FloatTensor`): query vectors `[batch x dim]`
      memory_bank (`FloatTensor`): word_memory_bank is `FloatTensor` with
        shape `[batch x s_num x s_len x dim]`
      memory_lengths (`LongTensor`): sentence lengths for word_memory_bank,
        `[batch x s_num]`
      coverage (`FloatTensor`): None (not supported yet)
      sent_align_vectors (`FloatTensor`): the computed sentence align
        distribution, `[batch x s_num]`
      sent_nums (`LongTensor`): the sentence numbers of inputs, `[batch]`

    Returns:
      (`FloatTensor`, `FloatTensor`):

      * Computed word attentional vector `[batch x dim]`
      * Word attention distributions for the query of word `[batch x s_num x s_len]`
    """
    # only one step input is supported
    assert source.dim() == 2, \
        "Only one step input is supported for current attention."
    one_step = True
    # [batch, 1, dim]
    source = source.unsqueeze(1)
    batch, tgt_l, dim = source.size()

    # check the specification for word level attention
    assert sent_align_vectors is not None, \
        "For word level attention, the 'sent_align' must be specified."
    assert sent_nums is not None, \
        "For word level attention, the 'sent_nums' must be specified."
    assert memory_lengths is not None, \
        "The lengths for the word memory bank are required."
    sent_lens = memory_lengths

    batch_1, s_num, s_len, dim_ = memory_bank.size()
    batch_2, s_num_ = sent_align_vectors.size()
    batch_3 = sent_nums.size(0)
    aeq(batch, batch_1, batch_2, batch_3)
    aeq(dim, dim_, self.dim)
    aeq(s_num, s_num_)

    # if coverage is not None:
    #     batch_, source_l_ = coverage.size()
    #     aeq(batch, batch_)
    #     aeq(source_l, source_l_)
    #
    # if coverage is not None:
    #     cover = coverage.view(-1).unsqueeze(1)
    #     memory_bank += self.linear_cover(cover).view_as(memory_bank)
    #     memory_bank = torch.tanh(memory_bank)

    # compute word attention scores, as in Luong et al.
    # [batch, s_num, s_len, dim] -> [batch, s_num * s_len, dim]
    memory_bank = memory_bank.view(batch, s_num * s_len, dim)
    # [batch, 1, s_num * s_len]
    word_align = self.score(source, memory_bank)
    # [batch, s_num * s_len]
    word_align = word_align.squeeze(1)
    # [batch, s_num, s_len]
    word_align = word_align.view(batch, s_num, s_len)

    # remove the empty sentences
    # [s_total, s_len], [s_total]
    valid_word_align, valid_sent_lens = valid_src_compress(
        word_align, sent_nums=sent_nums, sent_lens=sent_lens)
    # [s_total, s_len]
    word_mask = sequence_mask(valid_sent_lens,
                              max_len=valid_word_align.size(-1))
    # word_mask = word_mask.view(batch, s_num, s_len)
    # # [batch, s_num]
    # sent_mask = sequence_mask(sent_nums, max_len=s_num)
    # # [batch, s_num, 1]
    # sent_mask = sent_mask.unsqueeze(2)
    # # [batch, s_num, s_len]
    # align_vectors.masked_fill_(1 - sent_mask, 0.0)

    # [s_total, s_len]
    valid_word_align.masked_fill_(1 - word_mask, -float('inf'))

    # Softmax or sparsemax to normalize attention weights
    if self.attn_func == "softmax":
        align_vectors = F.softmax(valid_word_align, -1)
    else:
        align_vectors = sparsemax(valid_word_align, -1)

    # Recover the original shape by padding 0s for the empty sentences
    # [batch, s_num, s_len]
    align_vectors = recover_src(align_vectors, sent_nums)
    # For a sentence that is entirely padding, all of its word aligns stay 0.
    # (The commented-out alternative below masks them explicitly.)
    # # [batch, s_num]
    # sent_mask = sequence_mask(sent_nums, max_len=s_num)
    # # [batch, s_num, 1]
    # sent_mask = sent_mask.unsqueeze(2)
    # # [batch, s_num, s_len]
    # align_vectors.masked_fill_(1 - sent_mask, 0.0)

    # [batch, s_num, 1]
    sent_align_vectors = sent_align_vectors.unsqueeze(-1)
    # [batch, s_num, s_len]
    align_vectors = align_vectors * sent_align_vectors

    # each context vector c_t is the weighted average
    # over all the source hidden states
    # [batch, 1, s_num * s_len]
    align_vectors = align_vectors.view(batch, -1).unsqueeze(1)
    # [batch, 1, s_num * s_len] x [batch, s_num * s_len, dim] -> [batch, 1, dim]
    c = torch.bmm(align_vectors, memory_bank)
    # [batch, dim]
    c = c.squeeze(1)
    returned_vec = c

    # If output_attn_h == False, we put the linear out layer into the decoder part
    if self.output_attn_h:
        # concatenate
        # [batch, dim]
        source = source.squeeze(1)
        # [batch, 2*dim]
        concat_c = torch.cat([c, source], 1)
        # [batch, dim]
        attn_h = self.linear_out(concat_c)
        if self.attn_type in ["general", "dot"]:
            attn_h = torch.tanh(attn_h)
        returned_vec = attn_h

    align_vectors = align_vectors.squeeze(1).view(batch, s_num, s_len)

    # Check output sizes
    batch_, dim_ = returned_vec.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    # check
    batch_, s_num_, s_len_ = align_vectors.size()
    aeq(batch, batch_)
    aeq(s_num, s_num_)

    return returned_vec, align_vectors
def forward(self, source, memory_bank, memory_lengths=None):
    """
    Args:
      source (`FloatTensor`): query vectors `[batch x tgt_len x tgt_enc_dim]`
      memory_bank (`FloatTensor`): source vectors `[batch x src_len x src_enc_dim]`
      memory_lengths (`LongTensor`): the source context lengths `[batch]`

    Returns:
      (`FloatTensor`):

      * Attention distributions for each query `[batch x tgt_len x src_len]`
    """
    # one step input
    if source.dim() == 2:
        one_step = True
        source = source.unsqueeze(1)
    else:
        one_step = False

    batch, source_l, src_enc_dim = memory_bank.size()
    batch_, target_l, tgt_enc_dim = source.size()
    aeq(batch, batch_)
    aeq(self.src_enc_dim, src_enc_dim)
    aeq(self.tgt_enc_dim, tgt_enc_dim)

    # compute attention scores, as in Luong et al.
    # (batch, t_len, s_len)
    align = self.score(source, memory_bank)

    if memory_lengths is not None:
        mask = sequence_mask(memory_lengths, max_len=align.size(-1))
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        align.masked_fill_(1 - mask, -float('inf'))

    # Softmax or sparsemax to normalize attention weights
    if self.attn_func == "softmax":
        align_vectors = F.softmax(align.view(batch * target_l, source_l), -1)
    else:
        align_vectors = sparsemax(align.view(batch * target_l, source_l), -1)
    align_vectors = align_vectors.view(batch, target_l, source_l)

    # # each context vector c_t is the weighted average
    # # over all the source hidden states
    # c = torch.bmm(align_vectors, memory_bank)
    #
    # # concatenate
    # concat_c = torch.cat([c, source], 2).view(batch*target_l, dim*2)
    # attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
    # if self.attn_type in ["general", "dot"]:
    #     attn_h = torch.tanh(attn_h)

    if one_step:
        # attn_h = attn_h.squeeze(1)
        align_vectors = align_vectors.squeeze(1)
        # Check output sizes
        # batch_, dim_ = attn_h.size()
        aeq(batch, batch_)
        # aeq(dim, dim_)
        batch_, source_l_ = align_vectors.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)
    else:
        # attn_h = attn_h.transpose(0, 1).contiguous()
        align_vectors = align_vectors.contiguous()
        # Check output sizes
        # target_l_, batch_, dim_ = attn_h.size()
        # aeq(target_l, target_l_)
        aeq(batch, batch_)
        # aeq(dim, dim_)
        batch_, target_l_, source_l_ = align_vectors.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    return align_vectors
def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
    """
    Args:
      source (`FloatTensor`): query vectors `[batch x tgt_len x dim]`, rnn output
      memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`, encoder output
      memory_lengths (`LongTensor`): the source context lengths `[batch]`, encoder output lengths
      coverage (`FloatTensor`): None (not supported yet)

    Returns:
      (`FloatTensor`, `FloatTensor`):

      * Computed vector `[tgt_len x batch x dim]`
      * Attention distributions for each query `[tgt_len x batch x src_len]`
    """
    # one step input
    if source.dim() == 2:
        one_step = True
        source = source.unsqueeze(1)
    else:
        one_step = False

    batch, source_l, dim = memory_bank.size()
    batch_, target_l, dim_ = source.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    aeq(self.dim, dim)
    if coverage is not None:
        batch_, source_l_ = coverage.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    if coverage is not None:
        # When coverage is shared with attn, execution never reaches this branch.
        cover = coverage.view(-1).unsqueeze(1)
        memory_bank += self.linear_cover(cover).view_as(memory_bank)
        memory_bank = torch.tanh(memory_bank)

    # compute attention scores, as in Luong et al.
    align = self.score(source, memory_bank)

    if memory_lengths is not None:
        mask = sequence_mask(memory_lengths, max_len=align.size(-1))
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        align.masked_fill_(1 - mask, -float('inf'))

    # Softmax or sparsemax to normalize attention weights
    if self.attn_func == "softmax":
        align_vectors = F.softmax(align.view(batch * target_l, source_l), -1)
    else:
        align_vectors = sparsemax(align.view(batch * target_l, source_l), -1)
    align_vectors = align_vectors.view(batch, target_l, source_l)

    # each context vector c_t is the weighted average
    # over all the source hidden states
    c = torch.bmm(align_vectors, memory_bank)

    # concatenate
    concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2)
    attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
    if self.attn_type in ["general", "dot"]:
        attn_h = torch.tanh(attn_h)

    if one_step:
        attn_h = attn_h.squeeze(1)
        align_vectors = align_vectors.squeeze(1)
        # Check output sizes
        batch_, dim_ = attn_h.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        batch_, source_l_ = align_vectors.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)
    else:
        attn_h = attn_h.transpose(0, 1).contiguous()
        align_vectors = align_vectors.transpose(0, 1).contiguous()
        # Check output sizes
        target_l_, batch_, dim_ = attn_h.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(dim, dim_)
        target_l_, batch_, source_l_ = align_vectors.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    return attn_h, align_vectors
def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
    """
    Args:
      source (FloatTensor): query vectors ``(batch, tgt_len, dim)``
      memory_bank (FloatTensor): source vectors ``(batch, src_len, dim)``
      memory_lengths (LongTensor): the source context lengths ``(batch,)``
      coverage (FloatTensor): None (not supported yet)

    Returns:
      (FloatTensor, FloatTensor):

      * Computed vector ``(tgt_len, batch, dim)``
      * Attention distributions for each query ``(tgt_len, batch, src_len)``
    """
    # one step input
    if source.dim() == 2:
        one_step = True
        source = source.unsqueeze(1)
    else:
        one_step = False

    batch, source_l, dim = memory_bank.size()
    batch_, target_l, dim_ = source.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    aeq(self.dim, dim)
    if coverage is not None:
        batch_, source_l_ = coverage.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    if coverage is not None:
        cover = coverage.view(-1).unsqueeze(1)
        memory_bank += self.linear_cover(cover).view_as(memory_bank)
        memory_bank = torch.tanh(memory_bank)

    # compute attention scores, as in Luong et al.
    # print('source', source.size())
    # print('memory_bank', memory_bank.size())
    align = self.score(source, memory_bank)  # the attention-weight formula
    # print('align', align.size())  # align torch.Size([bz, 1, WxH])

    if memory_lengths is not None:
        mask = sequence_mask(memory_lengths, max_len=align.size(-1))
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        align.masked_fill_(~mask, -float('inf'))

    # Softmax or sparsemax to normalize attention weights
    if self.attn_func == "softmax":
        align_vectors = F.softmax(align.view(batch * target_l, source_l), -1)
    else:
        align_vectors = sparsemax(align.view(batch * target_l, source_l), -1)
    align_vectors = align_vectors.view(batch, target_l, source_l)

    # each context vector c_t is the weighted average
    # over all the source hidden states
    c = torch.bmm(align_vectors, memory_bank)  # c (5, 1, 512)

    # concatenate: o_t = tanh(W_c [h_t; c_t])
    concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2)
    attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
    if self.attn_type in ["general", "dot"]:
        attn_h = torch.tanh(attn_h)

    if one_step:
        attn_h = attn_h.squeeze(1)
        align_vectors = align_vectors.squeeze(1)
        # Check output sizes
        batch_, dim_ = attn_h.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        batch_, source_l_ = align_vectors.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)
    else:
        attn_h = attn_h.transpose(0, 1).contiguous()
        align_vectors = align_vectors.transpose(0, 1).contiguous()
        # Check output sizes
        target_l_, batch_, dim_ = attn_h.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(dim, dim_)
        target_l_, batch_, source_l_ = align_vectors.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    return attn_h, align_vectors
def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
    """
    Args:
      source (FloatTensor): query vectors ``(batch, tgt_len, dim)``
      memory_bank (FloatTensor): source vectors ``(batch, src_len, dim)``
      memory_lengths (LongTensor): the source context lengths ``(batch,)``
      coverage (FloatTensor): None (not supported yet)

    Returns:
      FloatTensor: attention distributions for each query
      ``(tgt_len, batch, src_len)`` (log-probabilities for the softmax branch)
    """
    # one step input
    if source.dim() == 2:
        one_step = True
        source = source.unsqueeze(1)
    else:
        one_step = False

    batch, source_l, dim = memory_bank.size()
    batch_, target_l, dim_ = source.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    aeq(self.dim, dim)
    if coverage is not None:
        batch_, source_l_ = coverage.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    if coverage is not None:
        cover = coverage.view(-1).unsqueeze(1)
        memory_bank += self.linear_cover(cover).view_as(memory_bank)
        memory_bank = torch.tanh(memory_bank)

    # compute attention scores, as in Luong et al.
    align = self.score(source, memory_bank)

    if memory_lengths is not None:
        mask = sequence_mask(memory_lengths, max_len=align.size(-1))
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        align.masked_fill_(~mask, -float('inf'))

    # Log-softmax or sparsemax to normalize attention weights
    if self.attn_func == "softmax":
        align_vectors = torch.log_softmax(
            align.view(batch * target_l, source_l), -1)
    else:
        align_vectors = sparsemax(align.view(batch * target_l, source_l), -1)
    align_vectors = align_vectors.view(batch, target_l, source_l)

    if one_step:
        align_vectors = align_vectors.squeeze(1)
        # Check output sizes
        batch_, source_l_ = align_vectors.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)
    else:
        align_vectors = align_vectors.transpose(0, 1).contiguous()
        # Check output sizes
        target_l_, batch_, source_l_ = align_vectors.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    return align_vectors
def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
    """
    Args:
      source (FloatTensor): query vectors ``(batch, tgt_len, dim)``
      memory_bank (FloatTensor): source vectors ``(batch, src_len, dim)``
      memory_lengths (LongTensor): the source context lengths ``(batch,)``
      coverage (FloatTensor): None (not supported yet)

    Returns:
      (FloatTensor, FloatTensor):

      * Computed vector ``(tgt_len, batch, dim)``
      * Attention distributions for each query ``(tgt_len, batch, src_len)``
    """
    # one step input
    if source.dim() == 2:
        one_step = True
        source = source.unsqueeze(1)
    else:
        one_step = False

    batch, source_l, dim = memory_bank.size()
    batch_, target_l, dim_ = source.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    aeq(self.dim, dim)
    if coverage is not None:
        batch_, source_l_ = coverage.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    if coverage is not None:
        cover = coverage.view(-1).unsqueeze(1)
        memory_bank += self.linear_cover(cover).view_as(memory_bank)
        memory_bank = torch.tanh(memory_bank)

    # compute attention scores, as in Luong et al.
    align = self.score(source, memory_bank)

    if memory_lengths is not None:
        mask = sequence_mask(memory_lengths, max_len=align.size(-1))
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        align.masked_fill_(1 - mask, -float('inf'))

    # Softmax or sparsemax to normalize attention weights
    if self.attn_func == "softmax":
        align_vectors = F.softmax(align.view(batch * target_l, source_l), -1)
    else:
        align_vectors = sparsemax(align.view(batch * target_l, source_l), -1)
    align_vectors = align_vectors.view(batch, target_l, source_l)

    # each context vector c_t is the weighted average
    # over all the source hidden states
    c = torch.bmm(align_vectors, memory_bank)

    # concatenate
    concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2)
    attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
    if self.attn_type in ["general", "dot"]:
        attn_h = torch.tanh(attn_h)

    if one_step:
        attn_h = attn_h.squeeze(1)
        align_vectors = align_vectors.squeeze(1)
        # Check output sizes
        batch_, dim_ = attn_h.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        batch_, source_l_ = align_vectors.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)
    else:
        attn_h = attn_h.transpose(0, 1).contiguous()
        align_vectors = align_vectors.transpose(0, 1).contiguous()
        # Check output sizes
        target_l_, batch_, dim_ = attn_h.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(dim, dim_)
        target_l_, batch_, source_l_ = align_vectors.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    return attn_h, align_vectors
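# Hedged usage sketch for the Luong-style forward above: a single decoding
# step against OpenNMT-py's GlobalAttention. The import path and constructor
# arguments follow the upstream library and may differ in the forks collected
# in this file.
def demo_one_step_global_attention():
    import torch
    from onmt.modules import GlobalAttention  # assumption: standard OpenNMT-py layout

    batch, src_len, dim = 4, 7, 64
    attn = GlobalAttention(dim, attn_type="general", attn_func="softmax")

    decoder_state = torch.randn(batch, dim)         # one decoding step (2-D input)
    memory_bank = torch.randn(batch, src_len, dim)  # encoder outputs
    lengths = torch.tensor([7, 5, 3, 7])            # true source lengths

    attn_h, align = attn(decoder_state, memory_bank, memory_lengths=lengths)
    # attn_h: (batch, dim) attentional hidden state
    # align:  (batch, src_len); padded positions receive zero weight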
def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
    """
    Args:
      source (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
      memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
      memory_lengths (`LongTensor`): the source context lengths `[batch]`
      coverage (`FloatTensor`): None (not supported yet)

    Returns:
      (`FloatTensor`, `FloatTensor`):

      * Computed vector `[tgt_len x batch x dim]`
      * Attention distributions for each query `[tgt_len x batch x src_len]`
    """
    # print('Source..', source.size())            # torch.Size([16, 512])
    # print('memory_bank..', memory_bank.size())  # torch.Size([16, 400, 512])

    # one step input
    if source.dim() == 2:
        one_step = True
        source = source.unsqueeze(1)
    else:
        one_step = False

    batch, source_l, dim = memory_bank.size()
    batch_, target_l, dim_ = source.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    aeq(self.dim, dim)
    if coverage is not None:
        batch_, source_l_ = coverage.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    if coverage is not None:
        cover = coverage.view(-1).unsqueeze(1)
        memory_bank += self.linear_cover(cover).view_as(memory_bank)
        memory_bank = torch.tanh(memory_bank)

    # compute attention scores, as in Luong et al.
    align = self.score(source, memory_bank)

    if memory_lengths is not None:
        mask = sequence_mask(memory_lengths, max_len=align.size(-1))
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        # fix for torch versions > 1.2, see
        # https://github.com/OpenNMT/OpenNMT-py/pull/1527/commits/234f9a5f6fca989fe6804e44ea68b58786ed58b8
        align.masked_fill_(~mask, -float('inf'))

    # Softmax or sparsemax to normalize attention weights
    if self.attn_func == "softmax":
        align_vectors = F.softmax(align.view(batch * target_l, source_l), -1)
    else:
        align_vectors = sparsemax(align.view(batch * target_l, source_l), -1)
    align_vectors = align_vectors.view(batch, target_l, source_l)

    # each context vector c_t is the weighted average
    # over all the source hidden states
    c = torch.bmm(align_vectors, memory_bank)

    # concatenate
    concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2)
    attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
    if self.attn_type in ["general", "dot"]:
        attn_h = torch.tanh(attn_h)

    if one_step:
        attn_h = attn_h.squeeze(1)
        align_vectors = align_vectors.squeeze(1)
        # Check output sizes
        batch_, dim_ = attn_h.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        batch_, source_l_ = align_vectors.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)
    else:
        attn_h = attn_h.transpose(0, 1).contiguous()
        align_vectors = align_vectors.transpose(0, 1).contiguous()
        # Check output sizes
        target_l_, batch_, dim_ = attn_h.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(dim, dim_)
        target_l_, batch_, source_l_ = align_vectors.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    # print('Atten Hidden...', attn_h.size())     # torch.Size([16, 512])
    # print('Align...', align_vectors.size())     # torch.Size([16, 400])
    return attn_h, align_vectors
def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
    # shapes:    B x T' x d    B x T x d         B
    """
    Args:
      source (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
      memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
      memory_lengths (`LongTensor`): the source context lengths `[batch]`
      coverage (`FloatTensor`): None (not supported yet)

    Returns:
      (`FloatTensor`, `FloatTensor`):

      * Computed vector `[tgt_len x batch x dim]`
      * Attention distributions for each query `[tgt_len x batch x src_len]`
    """
    # one step input
    if source.dim() == 2:
        one_step = True
        source = source.unsqueeze(1)
    else:
        one_step = False

    batch, source_l, dim = memory_bank.size()
    batch_, target_l, dim_ = source.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    aeq(self.dim, dim)
    if coverage is not None:
        batch_, source_l_ = coverage.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    if coverage is not None:
        cover = coverage.view(-1).unsqueeze(1)
        memory_bank += self.linear_cover(cover).view_as(memory_bank)
        memory_bank = torch.tanh(memory_bank)

    # compute attention scores, as in Luong et al.
    align = self.score(source, memory_bank)
    # (B x T' x T): align[b, t', t] = score(w_t, w_t') in batch b

    if memory_lengths is not None:
        mask = sequence_mask(memory_lengths, max_len=align.size(-1))
        # {0,1}^{B x T}
        mask = mask.unsqueeze(1)  # Make it broadcastable: {0,1}^{B x 1 x T}
        align.masked_fill_(1 - mask, -float('inf'))
        # align[b, t', t] = -inf if t > len(b)

    # Softmax or sparsemax to normalize attention weights
    if self.attn_func == "softmax":
        # (BT' x T): apply softmax over each row (the source dimension)
        align_vectors = F.softmax(align.view(batch * target_l, source_l), -1)
    else:
        align_vectors = sparsemax(align.view(batch * target_l, source_l), -1)
    align_vectors = align_vectors.view(batch, target_l, source_l)
    # B x T' x T: [b, t', :] is a distribution like [0.02, 0.02, ..., 0.0, 0.0]

    # each context vector c_t is the weighted average
    # over all the source hidden states
    #
    # (B x T' x T) x (B x T x d) -> (B x T' x d)
    c = torch.bmm(align_vectors, memory_bank)
    # c[b, t', :] = context vector for w_t' in batch b

    # context + query
    concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2)
    # shrink it: (BT' x 2d) -> (B x T' x d)
    attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
    if self.attn_type in ["general", "dot"]:
        attn_h = torch.tanh(attn_h)  # Eq. (5) in Luong (2015)

    # T' = 1 (e.g., input feeding)
    if one_step:
        attn_h = attn_h.squeeze(1)
        align_vectors = align_vectors.squeeze(1)
        # Check output sizes
        batch_, dim_ = attn_h.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        batch_, source_l_ = align_vectors.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)
    else:
        attn_h = attn_h.transpose(0, 1).contiguous()             # transposed
        align_vectors = align_vectors.transpose(0, 1).contiguous()  # likewise
        # Check output sizes
        target_l_, batch_, dim_ = attn_h.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(dim, dim_)
        target_l_, batch_, source_l_ = align_vectors.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    # (T' x B x d), (T' x B x T)
    return attn_h, align_vectors
def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
    """
    Args:
      source (`FloatTensor`): query vectors `[batch x dim]`
      memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
      memory_lengths (`LongTensor`): the source context lengths `[batch]`
      coverage (`FloatTensor`): None (not supported yet)

    Returns:
      (`FloatTensor`, `FloatTensor`):

      * Computed vector `[batch x dim]`
      * Attention distributions for each query `[batch x src_len]`
    """
    # one step input
    assert source.dim() == 2, "Only one step input is supported"
    # one_step = True
    source = source.unsqueeze(1)

    batch, source_l, dim = memory_bank.size()
    batch_, target_l, dim_ = source.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    aeq(self.dim, dim)

    # compute attention scores, as in Luong et al.
    # [batch x 1 x src_len]
    align = self.score(source, memory_bank)

    if memory_lengths is not None:
        mask = sequence_mask(memory_lengths, max_len=align.size(-1))
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        align.masked_fill_(1 - mask, -float('inf'))

    # Softmax or sparsemax to normalize attention weights
    if self.attn_func == "softmax":
        align_vectors = F.softmax(align.view(batch * target_l, source_l), -1)
    else:
        align_vectors = sparsemax(align.view(batch * target_l, source_l), -1)
    align_vectors = align_vectors.view(batch, target_l, source_l)

    if coverage is not None:
        batch_, source_l_ = coverage.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    if self.coverage_attn and coverage is not None:
        # [batch, src_len]
        mask = sequence_mask(memory_lengths, max_len=align.size(-1))
        # [batch, src_len]
        coverage_reversed = -1 * coverage
        coverage_reversed.masked_fill_(1 - mask, -float('inf'))
        coverage_reversed = F.softmax(coverage_reversed, -1)
        coverage_reversed = coverage_reversed.unsqueeze(1)
        # The reversed coverage is only used to rescale the current sentence
        # attention; we do not backpropagate gradients through it.
        coverage_reversed = coverage_reversed.detach()

        align_vectors = align_vectors * coverage_reversed
        norm_term = align_vectors.sum(dim=2, keepdim=True)
        align_vectors = align_vectors / norm_term

    # each context vector c_t is the weighted average
    # over all the source hidden states
    # [batch, target_l, dim]
    c = torch.bmm(align_vectors, memory_bank)
    # [batch, dim]
    returned_vec = c.squeeze(1)

    # concatenate
    if self.output_attn_h:
        concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2)
        attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = torch.tanh(attn_h)
        attn_h = attn_h.squeeze(1)
        returned_vec = attn_h

    align_vectors = align_vectors.squeeze(1)

    # Check output sizes
    batch_, dim_ = returned_vec.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    # Check output sizes
    batch_, source_l_ = align_vectors.size()
    aeq(batch, batch_)
    aeq(source_l, source_l_)

    return returned_vec, align_vectors
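# The coverage branch above down-weights positions that were already attended:
# it softmaxes the negated (and detached) coverage vector, multiplies it into
# the word attention, and renormalizes. A small numeric sketch of just that
# rescaling step, outside the module:
def demo_coverage_rescaling():
    import torch
    import torch.nn.functional as F

    align = torch.tensor([[0.6, 0.2, 0.1, 0.1]])     # current attention
    coverage = torch.tensor([[2.0, 0.5, 0.0, 0.0]])  # accumulated past attention

    cov_rev = F.softmax(-coverage, dim=-1)           # high coverage -> small weight
    rescaled = align * cov_rev
    rescaled = rescaled / rescaled.sum(dim=-1, keepdim=True)
    print(rescaled)  # mass shifts away from position 0 toward uncovered positions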
def forward(self, source, memory_bank, memory_lengths=None, coverage=None,
            sent_align_vectors=None, sent_position_tuple=None,
            src_word_sent_ids=None):
    """
    Only one-step attention is supported now.

    Args:
      source (`FloatTensor`): query vectors `[batch x dim]`
      memory_bank (`FloatTensor`): word_memory_bank is `FloatTensor` with
        shape `[batch x src_len x dim]`
      memory_lengths (`LongTensor`): lengths for word_memory_bank, `[batch]`
      coverage (`FloatTensor`): None (not supported yet)
      sent_align_vectors (`FloatTensor`): the computed sentence align
        distribution, `[batch x s_num]`
      sent_position_tuple (:obj:`tuple`): only used for seqhr_enc,
        (sent_p, sent_nums) with size `([batch_size, s_num, 2], [batch])`.
      src_word_sent_ids (:obj:`tuple`): (word_sent_ids, src_lengths) with
        size `([batch, src_len], [batch])`

    Returns:
      (`FloatTensor`, `FloatTensor`):

      * Computed word attentional vector `[batch x dim]`
      * Word attention distributions for the query of word `[batch x src_len]`
    """
    # only one step input is supported
    assert source.dim() == 2, \
        "Only one step input is supported for current attention."
    assert isinstance(sent_position_tuple, tuple)
    sent_position, sent_nums = sent_position_tuple
    one_step = True
    # [batch, 1, dim]
    source = source.unsqueeze(1)
    batch, tgt_l, dim = source.size()

    # check the specification for word level attention
    assert sent_align_vectors is not None, \
        "For word level attention, the 'sent_align' must be specified."
    assert sent_position is not None, \
        "For word level attention, the 'sent_position' must be specified."
    assert sent_nums is not None, \
        "For word level attention, the 'sent_nums' must be specified."
    assert memory_lengths is not None, \
        "The lengths for the word memory bank are required."
    sent_lens = memory_lengths

    batch_1, src_len, dim_ = memory_bank.size()
    batch_2, sent_num = sent_align_vectors.size()
    batch_3 = sent_nums.size(0)
    aeq(batch, batch_1, batch_2, batch_3)
    aeq(dim, dim_, self.dim)

    # if coverage is not None:
    #     batch_, source_l_ = coverage.size()
    #     aeq(batch, batch_)
    #     aeq(source_l, source_l_)
    #
    # if coverage is not None:
    #     cover = coverage.view(-1).unsqueeze(1)
    #     memory_bank += self.linear_cover(cover).view_as(memory_bank)
    #     memory_bank = torch.tanh(memory_bank)

    # compute word attention scores, as in Luong et al.
    # [batch, 1, src_len]
    word_align = self.score(source, memory_bank)
    # [batch, src_len]
    word_align = word_align.squeeze(1)
    # [batch, src_len]
    word_mask = sequence_mask(memory_lengths, max_len=word_align.size(-1))
    # [batch, src_len]
    word_align.masked_fill_(1 - word_mask, -float('inf'))

    # Softmax or sparsemax to normalize attention weights
    if self.attn_func == "softmax":
        align_vectors = F.softmax(word_align, -1)
    else:
        align_vectors = sparsemax(word_align, -1)

    if self.seqHRE_attn_rescale:
        word_sent_ids, memory_lengths_ = src_word_sent_ids
        assert memory_lengths.eq(memory_lengths_).all(), \
            "The src lengths in src_word_sent_ids should be the same " \
            "as the memory_lengths"

        # # attention score reweighting method 1
        # # broadcast the sent_align_vectors from [batch, sent_num] to [batch, src_len]
        # # according to sent_position [batch, sent_num, 2] and sent_nums [batch]
        # expand_sent_align_vectors = []
        # for b_idx in range(batch):
        #     one_ex_expand = []
        #     for sent_idx in range(sent_num):
        #         sent_token_num = sent_position[b_idx][sent_idx][0] - sent_position[b_idx][sent_idx][1] + 1
        #         if sent_token_num != 1:
        #             one_ex_expand.append(sent_align_vectors[b_idx][sent_idx].expand(sent_token_num))
        #         else:
        #             break
        #     one_ex_expand = torch.cat(one_ex_expand, dim=0)
        #     if one_ex_expand.size(0) < src_len:
        #         pad_vector = torch.zeros([src_len - one_ex_expand.size(0)],
        #                                  dtype=one_ex_expand.dtype, device=one_ex_expand.device)
        #         one_ex_expand = torch.cat([one_ex_expand, pad_vector], dim=0).contiguous()
        #     expand_sent_align_vectors.append(one_ex_expand)
        #
        # # [batch, src_len]
        # expand_sent_align_vectors = torch.stack(expand_sent_align_vectors, dim=0).contiguous()
        # # reweight and renormalize the word align_vectors
        # align_vectors = align_vectors * expand_sent_align_vectors
        # norm_term = align_vectors.sum(dim=1, keepdim=True)
        # align_vectors = align_vectors / norm_term

        # attention score reweighting method 2
        # word_sent_ids: [batch, src_len]
        # sent_align_vectors: [batch, sent_num]
        # expand_sent_align_vectors: [batch, src_len]
        expand_sent_align_vectors = sent_align_vectors.gather(
            dim=1, index=word_sent_ids)

        # Reweight and renormalize the word align_vectors.
        # Although word_sent_ids is padded with 0s, which gathers the
        # attention score of sentence 0, align_vectors is 0.0 at those
        # padded positions.
        align_vectors = align_vectors * expand_sent_align_vectors
        norm_term = align_vectors.sum(dim=1, keepdim=True)
        align_vectors = align_vectors / norm_term

    # each context vector c_t is the weighted average
    # over all the source hidden states
    # [batch, 1, src_len]
    align_vectors = align_vectors.unsqueeze(1)
    # [batch, 1, src_len] x [batch, src_len, dim] -> [batch, 1, dim]
    c = torch.bmm(align_vectors, memory_bank)
    # [batch, dim]
    c = c.squeeze(1)
    returned_vec = c

    # If output_attn_h == False, we put the linear out layer on the decoder part
    if self.output_attn_h:
        # concatenate
        # [batch, dim]
        source = source.squeeze(1)
        # [batch, 2*dim]
        concat_c = torch.cat([c, source], 1)
        # [batch, dim]
        attn_h = self.linear_out(concat_c)
        if self.attn_type in ["general", "dot"]:
            attn_h = torch.tanh(attn_h)
        returned_vec = attn_h

    # [batch, src_len]
    align_vectors = align_vectors.squeeze(1)

    # Check output sizes
    batch_1, dim_ = returned_vec.size()
    batch_2, _ = align_vectors.size()
    aeq(batch, batch_1, batch_2)
    aeq(dim, dim_)

    return returned_vec, align_vectors
    rel2id[id2emotion[e]] for e in tgt_emotions[s_id:e_id]
])  # (batch,)
# print(batch_rel_ids)

# batch_score = -torch.norm(ent_emb[batch_concept_ids].unsqueeze(2) +
#                           rel_emb[batch_rel_ids].unsqueeze(1).unsqueeze(2) -
#                           ent_emb.unsqueeze(0).unsqueeze(0), dim=-1)  # (batch, max_num, vocab), costly
batch_score = torch.mm(
    (ent_emb[batch_concept_ids] + rel_emb[batch_rel_ids].unsqueeze(1)).view(-1, ent_emb_dim),
    ent_emb.transpose(1, 0)).view(len(batch_concept_ids), max_num_concepts, -1)  # (batch, max_num, vocab)
# batch_score = -torch.norm(ent_emb[batch_concept_ids].unsqueeze(2) *
#                           rel_emb[batch_rel_ids].unsqueeze(1).unsqueeze(2) -
#                           ent_emb.unsqueeze(0).unsqueeze(0), p=1, dim=-1)  # (batch, max_num, vocab), costly

if s_id == 0:
    print(src_concepts[:3])
    print(tgt_emotions[:3])

if use_sparsemax:
    # (batch, max_num, vocab)
    batch_attn = sparsemax(sparsemax_temp * batch_score, -1)
    combined_emb = torch.mm(
        batch_attn.view(-1, batch_attn.shape[-1]),
        concept_embedding).view(-1, max_num_concepts, emb_dim)
elif top_k != 0:
    # (batch, max_num, top_k), (batch, max_num, top_k)
    top_k_scores, top_k_indices = batch_score.topk(top_k, dim=2)
    # (batch, max_num, top_k)
    top_k_attn = torch.softmax(top_k_scores, dim=-1)
    if s_id == 0:
        print("Top k probs: ", top_k_attn[:3, 0])
    # augment VAD
    if concept_VAD_strength_embedding is not None:
        VAD_attn = torch.softmax(
            concept_VAD_strength_temp *
def forward(self, source, memory_bank, memory_lengths=None, coverage=None,
            modification_method=None):
    """
    Args:
      source (FloatTensor): query vectors ``(batch, tgt_len, dim)``
      memory_bank (FloatTensor): source vectors ``(batch, src_len, dim)``
      memory_lengths (LongTensor): the source context lengths ``(batch,)``
      coverage (FloatTensor): None (not supported yet)

    Returns:
      (FloatTensor, FloatTensor):

      * Computed vector ``(tgt_len, batch, dim)``
      * Attention distributions for each query ``(tgt_len, batch, src_len)``
    """
    # one step input
    if source.dim() == 2:
        one_step = True
        source = source.unsqueeze(1)
    else:
        one_step = False

    batch, source_l, dim = memory_bank.size()
    batch_, target_l, dim_ = source.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    aeq(self.dim, dim)
    if coverage is not None:
        batch_, source_l_ = coverage.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    if coverage is not None:
        cover = coverage.view(-1).unsqueeze(1)
        memory_bank += self.linear_cover(cover).view_as(memory_bank)
        memory_bank = torch.tanh(memory_bank)

    # compute attention scores, as in Luong et al.
    # dimension: batch x target len x source len
    align = self.score(source, memory_bank)

    if modification_method is not None:
        align = align.detach()  # Is this OK?
        top_indices = torch.argsort(align, descending=True)
        memory_lengths_vector = memory_lengths.cpu()
        for i in range(align.shape[0]):
            true_length = memory_lengths_vector[i]  # true_length_vector[i]  # memory_lengths[i]
            if modification_method == 'uniform':
                align[i, :, :true_length] = 1
                continue
            for j in range(align.shape[1]):
                if modification_method == 'zero_out_max':
                    max_index = align[i][j][0:true_length].argmax()
                    align[i][j][max_index] = -float('inf')
                elif modification_method == 'random_permute':
                    rand_indices = torch.randperm(true_length,
                                                  requires_grad=False)
                    # cloned = align[i, j, rand_indices].clone()
                    align[i, j, 0:true_length] = align[i, j, rand_indices]  # .clone()
                    # align[i, j, 0:true_length] = cloned
                elif modification_method == 'second_max':
                    # top_indices = torch.argsort(align[i][j][:true_length], descending=True)
                    if true_length <= 2:
                        continue
                    first_max = None
                    second_max = None
                    third_max = None
                    for el in top_indices[i][j].cpu():
                        if el >= true_length:
                            continue
                        if first_max is None:
                            first_max = el
                        elif second_max is None:
                            second_max = el
                        elif third_max is None:
                            third_max = el
                        else:
                            break
                    align[i][j][first_max] = align[i][j][third_max]
                else:
                    print(">>> Nothing was selected for attention modification <<<")

    if memory_lengths is not None:
        mask = sequence_mask(memory_lengths, max_len=align.size(-1))
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        align.masked_fill_(~mask, -float('inf'))

    # Softmax or sparsemax to normalize attention weights
    if self.attn_func == "softmax":
        align_vectors = F.softmax(align.view(batch * target_l, source_l), -1)
    else:
        align_vectors = sparsemax(align.view(batch * target_l, source_l), -1)
    align_vectors = align_vectors.view(batch, target_l, source_l)

    # each context vector c_t is the weighted average
    # over all the source hidden states
    c = torch.bmm(align_vectors, memory_bank)

    # concatenate
    concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2)
    attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
    if self.attn_type in ["general", "dot"]:
        attn_h = torch.tanh(attn_h)

    if one_step:
        attn_h = attn_h.squeeze(1)
        align_vectors = align_vectors.squeeze(1)
        # Check output sizes
        batch_, dim_ = attn_h.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        batch_, source_l_ = align_vectors.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)
    else:
        attn_h = attn_h.transpose(0, 1).contiguous()
        align_vectors = align_vectors.transpose(0, 1).contiguous()
        # Check output sizes
        target_l_, batch_, dim_ = attn_h.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(dim, dim_)
        target_l_, batch_, source_l_ = align_vectors.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    return attn_h, align_vectors
def get_sparse_attention(src_ent, rel):
    src_ent_emb_ = ent_emb[ent2id[src_ent]]
    rel_emb_ = rel_emb[rel2id[rel]]
    return sparsemax(-torch.norm(src_ent_emb_ + rel_emb_ - ent_emb, dim=1), -1)
def test_sparsemax():
    for _ in range(10):
        x = 0.5 * torch.randn(10, 30000, dtype=torch.float32)
        p1 = sparsemax(x, 1)
        p2 = sparsemax_bisect(x)
        assert torch.sum((p1 - p2) ** 2) < 1e-7
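# A couple of cheap invariants that complement the exact-vs-bisection check
# above; a hedged sketch in the same style, assuming the same `sparsemax`
# import as the test above.
def test_sparsemax_is_a_sparse_distribution():
    for _ in range(10):
        x = 0.5 * torch.randn(10, 30000, dtype=torch.float32)
        p = sparsemax(x, 1)
        # each row is a valid probability distribution
        assert torch.allclose(p.sum(dim=1), torch.ones(10), atol=1e-5)
        assert (p >= 0).all()
        # unlike softmax, sparsemax assigns exact zeros to low-scoring entries
        assert (p == 0).any()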
def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
    '''
    Added by zhengquan. At the call site:
        # rnn_output.size() = [3, 100] = [batch_size, hidden_size]
        if self.attentional:
            # What is p_attn, and why is its size different every time?
            # p_attn.size() = [3, 32] = [batch_size, 32]; what is 32?
            # Worst of all, this dimension changes from example to example.
            # memory_bank.size() = [32, 3, 100] = [src_len, batch_size, rnn_size],
            # so the "src" above is probably src_len.
            # memory_lengths.size() = [3], memory_lengths = [32, 32, 31]
            decoder_output, p_attn = self.attn(
                rnn_output,
                memory_bank.transpose(0, 1),
                memory_lengths=memory_lengths)
            attns["std"].append(p_attn)
    '''
    """
    Args:
      source (FloatTensor): query vectors ``(batch, tgt_len, dim)``
      memory_bank (FloatTensor): source vectors ``(batch, src_len, dim)``
      memory_lengths (LongTensor): the source context lengths ``(batch,)``
      coverage (FloatTensor): None (not supported yet)

    Returns:
      (FloatTensor, FloatTensor):

      * Computed vector ``(tgt_len, batch, dim)``
      * Attention distributions for each query ``(tgt_len, batch, src_len)``
    """
    # one step input
    if source.dim() == 2:
        one_step = True
        source = source.unsqueeze(1)  # [batch_size, 1, hidden_size]
    else:
        one_step = False

    batch, source_l, dim = memory_bank.size()
    batch_, target_l, dim_ = source.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    aeq(self.dim, dim)
    if coverage is not None:
        batch_, source_l_ = coverage.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    if coverage is not None:
        cover = coverage.view(-1).unsqueeze(1)
        memory_bank += self.linear_cover(cover).view_as(memory_bank)
        memory_bank = torch.tanh(memory_bank)

    # compute attention scores, as in Luong et al.
    align = self.score(source, memory_bank)  # [batch_size, tgt_l(=1), src_l]
    # print("I love you")
    # import pdb
    # pdb.set_trace()

    if memory_lengths is not None:
        mask = sequence_mask(memory_lengths, max_len=align.size(-1))
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        align.masked_fill_(1 - mask, -float('inf'))

    # Softmax or sparsemax to normalize attention weights
    if self.attn_func == "softmax":
        align_vectors = F.softmax(align.view(batch * target_l, source_l), -1)
    else:
        align_vectors = sparsemax(align.view(batch * target_l, source_l), -1)
    align_vectors = align_vectors.view(batch, target_l, source_l)

    # each context vector c_t is the weighted average
    # over all the source hidden states
    c = torch.bmm(align_vectors, memory_bank)

    # concatenate
    concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2)
    attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
    if self.attn_type in ["general", "dot"]:
        attn_h = torch.tanh(attn_h)

    if one_step:
        attn_h = attn_h.squeeze(1)
        align_vectors = align_vectors.squeeze(1)
        # Check output sizes
        batch_, dim_ = attn_h.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        batch_, source_l_ = align_vectors.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)
    else:
        attn_h = attn_h.transpose(0, 1).contiguous()
        align_vectors = align_vectors.transpose(0, 1).contiguous()
        # Check output sizes
        target_l_, batch_, dim_ = attn_h.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(dim, dim_)
        target_l_, batch_, source_l_ = align_vectors.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    return attn_h, align_vectors
def forward(self, source, memory_bank, memory_lengths=None, coverage=None,
            experiment_type=None):
    """
    Args:
      source (FloatTensor): query vectors ``(batch, tgt_len, dim)``
      memory_bank (FloatTensor): source vectors ``(batch, src_len, dim)``
      memory_lengths (LongTensor): the source context lengths ``(batch,)``
      coverage (FloatTensor): None (not supported yet)
      experiment_type: type of experiment. Possible values: permute,
        zero_out, equal_weight, last_state

    Returns:
      (FloatTensor, FloatTensor):

      * Computed vector ``(tgt_len, batch, dim)``
      * Attention distributions for each query ``(tgt_len, batch, src_len)``
    """
    # one step input
    if source.dim() == 2:
        one_step = True
        source = source.unsqueeze(1)
    else:
        one_step = False

    batch, source_l, dim = memory_bank.size()
    batch_, target_l, dim_ = source.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    aeq(self.dim, dim)
    if coverage is not None:
        batch_, source_l_ = coverage.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    if coverage is not None:
        cover = coverage.view(-1).unsqueeze(1)
        memory_bank += self.linear_cover(cover).view_as(memory_bank)
        memory_bank = torch.tanh(memory_bank)

    # compute attention scores, as in Luong et al.
    align = self.score(source, memory_bank)

    # import math
    # print("size of align: ")
    # print(align.size())
    # max_attention = -math.inf
    # max_index = 0
    # for i in range(align.size()[2]):
    #     if align[0][0][i] >= max_attention:
    #         max_attention = align[0][0][i]
    #         max_index = i
    #         print(align[0][0][i])
    # print("Attended mostly to source position %d with attention value: %f"
    #       % (max_index, max_attention))
    # print("#" * 20)

    # i is the batch index, j is the target token index,
    # the third dimension indexes source tokens
    if experiment_type is not None and \
            experiment_type not in ['zero_out', 'keep_max_uniform_others']:
        assert align.size()[1] == 1
        new_align = align.clone().cpu().numpy()
        for i in range(align.size()[0]):
            # fall back to the full source length when no lengths are given
            length = memory_lengths[i] if memory_lengths is not None else new_align.shape[2]
            for j in range(align.size()[1]):
                if experiment_type == 'permute':
                    max_index = new_align[i][j][0:length].argmax()
                    succeed = False
                    for _ in range(10):
                        random.shuffle(new_align[i][j][0:length])
                        if new_align[i][j][0:length].argmax() != max_index:
                            succeed = True
                            break
                    if succeed is False:
                        print("Couldn't permute properly! Be careful")
                elif experiment_type == 'uniform':
                    new_align[i][j][0:length] = 1
                elif experiment_type == 'last_state':
                    new_align[i][j][0:length] = -float('inf')
                    new_align[i][j][length - 1] = 1
                elif experiment_type == 'only_max':
                    keep_k = 1
                    indices = new_align[i][j][0:length].argsort()[-keep_k:][::-1]
                    backup = np.copy(new_align[i][j][indices])
                    new_align[i][j][0:length] = -float('inf')
                    for k in range(keep_k):
                        new_align[i][j][indices[k]] = backup[k]
                elif experiment_type == 'zero_out_max':
                    max_index = new_align[i][j][0:length].argmax()
                    new_align[i][j][max_index] = -float('inf')
                else:
                    print(">>> None of the modifications matched <<<")
                    assert False
        align = torch.from_numpy(new_align).cuda()

    if memory_lengths is not None:
        mask = sequence_mask(memory_lengths, max_len=align.size(-1))
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        align.masked_fill_(1 - mask, -float('inf'))

    # Softmax or sparsemax to normalize attention weights
    if self.attn_func == "softmax":
        align_vectors = F.softmax(align.view(batch * target_l, source_l), -1)
    else:
        align_vectors = sparsemax(align.view(batch * target_l, source_l), -1)
    align_vectors = align_vectors.view(batch, target_l, source_l)

    if experiment_type == 'keep_max_uniform_others':
        for i in range(align_vectors.size()[0]):  # batch
            assert align_vectors.size()[1] == 1
            length = memory_lengths[i] if memory_lengths is not None else align_vectors[i][0].size()[0]
            if length == 1:
                print("length of source is 1!")
                continue
            max_index = align_vectors[i][0][0:length].argmax()
            max_val = align_vectors[i][0][max_index].item()
            assert max_val > 0
            assert max_val <= 1
            align_vectors[i][0][0:length] = (1 - max_val) * 1.0 / float(length - 1)
            align_vectors[i][0][max_index] = max_val

    # each context vector c_t is the weighted average
    # over all the source hidden states
    c = torch.bmm(align_vectors, memory_bank)

    # concatenate
    concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2)
    attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
    if self.attn_type in ["general", "dot"]:
        attn_h = torch.tanh(attn_h)

    if one_step:
        attn_h = attn_h.squeeze(1)
        align_vectors = align_vectors.squeeze(1)
        # Check output sizes
        batch_, dim_ = attn_h.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        batch_, source_l_ = align_vectors.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)
    else:
        attn_h = attn_h.transpose(0, 1).contiguous()
        align_vectors = align_vectors.transpose(0, 1).contiguous()
        # Check output sizes
        target_l_, batch_, dim_ = attn_h.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(dim, dim_)
        target_l_, batch_, source_l_ = align_vectors.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    if experiment_type == 'zero_out':
        attn_h = torch.zeros(attn_h.size()).cuda()

    return attn_h, align_vectors
def forward(self, source, memory_bank, n, latt=False, memory_lengths=None,
            coverage=None):
    """
    Args:
      source (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
      memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
      memory_lengths (`LongTensor`): the source context lengths `[batch]`
      coverage (`FloatTensor`): None (not supported yet)

    Returns:
      (`FloatTensor`, `FloatTensor`):

      * Computed vector `[tgt_len x batch x dim]`
      * Attention distributions for each query `[tgt_len x batch x src_len]`
    """
    # one step input
    if source.dim() == 2:
        one_step = True
        source = source.unsqueeze(1)
    else:
        one_step = False
        # one_step = True  # latt: for lattice use only

    batch, source_l, dim = memory_bank.size()
    batch_, target_l, dim_ = source.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    aeq(self.dim, dim)
    if coverage is not None:
        batch_, source_l_ = coverage.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    if coverage is not None:
        cover = coverage.view(-1).unsqueeze(1)
        memory_bank += self.linear_cover(cover).view_as(memory_bank)
        memory_bank = F.tanh(memory_bank)

    # compute attention scores, as in Luong et al.
    align = self.score(source, memory_bank)

    if memory_lengths is not None:
        mask = sequence_mask(memory_lengths, max_len=align.size(-1))
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        align.masked_fill_(1 - mask, -float('inf'))
    # print('align', align)

    # Softmax or sparsemax to normalize attention weights
    if self.attn_func == "softmax":
        # latt: comment align vectors
        align_vectors = F.softmax(align.view(batch * target_l, source_l), -1)
    else:
        align_vectors = sparsemax(align.view(batch * target_l, source_l), -1)
    align_vectors = align_vectors.view(batch, target_l, source_l)

    # each context vector c_t is the weighted average
    # over all the source hidden states
    c = torch.bmm(align_vectors, memory_bank)

    # concatenate context and source
    concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2)
    attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
    if self.attn_type in ["general", "dot"]:
        attn_h = F.tanh(attn_h)

    if one_step:
        attn_h = attn_h.squeeze(1)
        align_vectors = align_vectors.squeeze(1)
        # Check output sizes
        batch_, dim_ = attn_h.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        batch_, source_l_ = align_vectors.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)
    else:
        # print('test else in global attn', attn_h.size())  # test
        # attn_h = attn_h.transpose(0, 1).contiguous()  # comment out for lattice
        align_vectors = align_vectors.transpose(0, 1).contiguous()  # latt
        # Check output sizes
        # target_l_, batch_, dim_ = attn_h.size()
        batch_, target_l_, dim_ = attn_h.size()  # latt use only
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(dim, dim_)
        target_l_, batch_, source_l_ = align_vectors.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    # print('global attn attn_h, align_vectors', attn_h.size(), align_vectors.size())  # latt
    # global attn attn_h, align_vectors torch.Size([1, 9, 256]) torch.Size([9, 1, 27])
    #   -> no. of sentences x max length x dim
    # global attn attn_h, align_vectors torch.Size([2, 17, 256]) torch.Size([17, 2, 51])
    # print('attn_h', attn_h)  # latt

    torch.set_printoptions(profile="full")
    # print('align_vectors', align_vectors)  # latt
    logger.info('align_vectors')  # latt
    logger.info(align_vectors)  # latt
    align_vectors[align_vectors == 0.5] = 0
    align_sum = torch.sum(align_vectors, 1)
    # logger.info(align_sum)  # latt
    torch.set_printoptions(profile="default")

    return attn_h, align_vectors, align_sum  # latt: ignore
def forward(
    self,
    source: torch.FloatTensor,  # [batch, tgt_len, dim]
    memory_bank_list: List[torch.FloatTensor],  # [num_srcs] x [batch, src_len, dim]
    memory_lengths_list: List[torch.FloatTensor] = None,  # [num_srcs] x [batch]
    coverage=None,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    assert coverage is None

    # one step input
    if source.dim() == 2:
        one_step = True
        source = source.unsqueeze(1)
    else:
        one_step = False
    # end if

    # Join memory bank
    memory_bank = torch.cat(memory_bank_list, dim=1)

    batch, source_l, dim = memory_bank.size()
    batch_, target_l, dim_ = source.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    aeq(self.dim, dim)
    if coverage is not None:
        batch_, source_l_ = coverage.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    if coverage is not None:
        cover = coverage.view(-1).unsqueeze(1)
        memory_bank += self.linear_cover(cover).view_as(memory_bank)
        memory_bank = torch.tanh(memory_bank)

    # compute attention scores, as in Luong et al.
    align = self.score(source, memory_bank)

    if memory_lengths_list is not None:
        mask = torch.cat([
            sequence_mask(memory_lengths,
                          max_len=memory_bank_list[src_i].size(1))
            for src_i, memory_lengths in enumerate(memory_lengths_list)
        ], dim=1)
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        align.masked_fill_(1 - mask, -float('inf'))
    # end if

    # Softmax or sparsemax to normalize attention weights
    if self.attn_func == "softmax":
        align_vectors = F.softmax(align.view(batch * target_l, source_l), -1)
    else:
        align_vectors = sparsemax(align.view(batch * target_l, source_l), -1)
    align_vectors = align_vectors.view(batch, target_l, source_l)

    # each context vector c_t is the weighted average
    # over all the source hidden states
    c = torch.bmm(align_vectors, memory_bank)

    # concatenate
    concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2)
    attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
    if self.attn_type in ["general", "dot"]:
        attn_h = torch.tanh(attn_h)
    # end if

    if one_step:
        attn_h = attn_h.squeeze(1)
        align_vectors = align_vectors.squeeze(1)
        # Check output sizes
        batch_, dim_ = attn_h.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        batch_, source_l_ = align_vectors.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)
    else:
        attn_h = attn_h.transpose(0, 1).contiguous()
        align_vectors = align_vectors.transpose(0, 1).contiguous()
        # Check output sizes
        target_l_, batch_, dim_ = attn_h.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(dim, dim_)
        target_l_, batch_, source_l_ = align_vectors.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(source_l, source_l_)
    # end if

    return attn_h, align_vectors
def forward(self, source, memory_bank, memory_lengths=None, coverage=None,
            sent_align_vectors=None):
    """
    Args:
      source (`FloatTensor`): query vectors `[batch x dim]`
      memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
      memory_lengths (`LongTensor`): the source context lengths `[batch]`
      coverage (`FloatTensor`): None (not supported yet)
      sent_align_vectors (`FloatTensor`): sentence level attention scores `[batch x src_len]`

    Returns:
      (`FloatTensor`, `FloatTensor`):

      * Computed vector `[batch x dim]`
      * Attention distributions for each query `[batch x src_len]`
    """
    # one step input
    assert source.dim() == 2, "Only one step input is supported"
    # one_step = True
    source = source.unsqueeze(1)
    sent_align_vectors = sent_align_vectors.unsqueeze(1)

    batch, src_len, dim = memory_bank.size()
    batch1, tgt_len, dim1 = source.size()
    batch2, tgt_len2, src_len2 = sent_align_vectors.size()
    aeq(batch, batch1, batch2)
    aeq(self.dim, dim, dim1)
    aeq(src_len, src_len2)
    aeq(tgt_len, tgt_len2)

    # if coverage is not None:
    #     batch_, source_l_ = coverage.size()
    #     aeq(batch, batch_)
    #     aeq(source_l, source_l_)
    # if coverage is not None:
    #     cover = coverage.view(-1).unsqueeze(1)
    #     memory_bank += self.linear_cover(cover).view_as(memory_bank)
    #     memory_bank = torch.tanh(memory_bank)

    # compute attention scores, as in Luong et al.
    # [batch, tgt_len, src_len]
    align = self.score(source, memory_bank)

    if memory_lengths is not None:
        mask = sequence_mask(memory_lengths, max_len=align.size(-1))
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        align.masked_fill_(1 - mask, -float('inf'))

    # Softmax or sparsemax to normalize attention weights
    if self.attn_func == "softmax":
        align_vectors = F.softmax(align.view(batch * tgt_len, src_len), -1)
    else:
        align_vectors = sparsemax(align.view(batch * tgt_len, src_len), -1)
    align_vectors = align_vectors.view(batch, tgt_len, src_len)

    # rescale the word attention scores using the sent_align_vec
    align_vectors = align_vectors * sent_align_vectors
    norm_vec = align_vectors.sum(dim=-1, keepdim=True)
    align_vectors = align_vectors / norm_vec

    # each context vector c_t is the weighted average
    # over all the source hidden states
    # [batch, tgt_len, dim]
    c = torch.bmm(align_vectors, memory_bank)
    # [batch, dim]
    returned_vec = c.squeeze(1)

    # concatenate
    if self.output_attn_h:
        concat_c = torch.cat([c, source], 2).view(batch * tgt_len, dim * 2)
        attn_h = self.linear_out(concat_c).view(batch, tgt_len, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = torch.tanh(attn_h)
        attn_h = attn_h.squeeze(1)
        returned_vec = attn_h

    align_vectors = align_vectors.squeeze(1)

    # Check output sizes
    batch_, dim_ = returned_vec.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    # Check output sizes
    batch_, src_len_ = align_vectors.size()
    aeq(batch, batch_)
    aeq(src_len, src_len_)

    return returned_vec, align_vectors