def ans_match(src_seq, ans_seq, src_lengths, ans_lengths):
    # `src_lengths` / `ans_lengths` are needed to build the padding masks below;
    # the original snippet referenced them without declaring them in the signature.
    import torch.nn.functional as F

    BF_ans_mask = sequence_mask(ans_lengths)   # [batch, ans_seq_len]
    BF_src_mask = sequence_mask(src_lengths)   # [batch, src_seq_len]
    BF_src_outputs = src_seq.transpose(0, 1)   # [batch, src_seq_len, 2*hidden_size]
    BF_ans_outputs = ans_seq.transpose(0, 1)   # [batch, ans_seq_len, 2*hidden_size]

    # compute bi-attention scores
    src_scores = BF_src_outputs.bmm(
        BF_ans_outputs.transpose(2, 1))        # [batch, src_seq_len, ans_seq_len]
    ans_scores = BF_ans_outputs.bmm(
        BF_src_outputs.transpose(2, 1))        # [batch, ans_seq_len, src_seq_len]

    # mask padding
    Expand_BF_ans_mask = BF_ans_mask.unsqueeze(1).expand(
        src_scores.size())                     # [batch, src_seq_len, ans_seq_len]
    src_scores.data.masked_fill_(~(Expand_BF_ans_mask).bool(), -float('inf'))
    Expand_BF_src_mask = BF_src_mask.unsqueeze(1).expand(
        ans_scores.size())                     # [batch, ans_seq_len, src_seq_len]
    ans_scores.data.masked_fill_(~(Expand_BF_src_mask).bool(), -float('inf'))

    # normalize with softmax
    src_alpha = F.softmax(src_scores, dim=2)   # [batch, src_seq_len, ans_seq_len]
    ans_alpha = F.softmax(ans_scores, dim=2)   # [batch, ans_seq_len, src_seq_len]

    # take the weighted average
    BF_src_matched_seq = src_alpha.bmm(BF_ans_outputs)     # [batch, src_seq_len, 2*hidden_size]
    src_matched_seq = BF_src_matched_seq.transpose(0, 1)   # [src_seq_len, batch, 2*hidden_size]
    BF_ans_matched_seq = ans_alpha.bmm(BF_src_outputs)     # [batch, ans_seq_len, 2*hidden_size]
    ans_matched_seq = BF_ans_matched_seq.transpose(0, 1)   # [ans_seq_len, batch, 2*hidden_size]
    return src_matched_seq, ans_matched_seq
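# `sequence_mask` is used by every snippet in this file but never defined here.
# A minimal sketch of the assumed helper (build a [batch, max_len] boolean mask
# from a lengths vector); the actual project helper may differ in detail:
import torch

def sequence_mask(lengths, max_len=None):
    """mask[i, j] is True for j < lengths[i] and False at padding positions."""
    max_len = max_len if max_len is not None else int(lengths.max())
    steps = torch.arange(max_len, device=lengths.device)
    return steps.unsqueeze(0) < lengths.unsqueeze(1)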
def forward(self, src, tgt, src_lengths=None, src_emb=None, tgt_emb=None):
    src_final, src_memory_bank = self.src_encoder(src, src_lengths, emb=src_emb)
    src_length, batch_size, rnn_size = src_memory_bank.size()
    tgt_final, tgt_memory_bank = self.tgt_encoder(tgt, emb=tgt_emb)
    self.q_src_h = src_memory_bank
    self.q_tgt_h = tgt_memory_bank
    src_memory_bank = src_memory_bank.transpose(0, 1)   # batch_size, src_length, rnn_size
    src_memory_bank = src_memory_bank.transpose(1, 2)   # batch_size, rnn_size, src_length
    tgt_memory_bank = self.W(tgt_memory_bank.transpose(0, 1))  # batch_size, tgt_length, rnn_size

    if self.dist_type == "categorical":
        scores = torch.bmm(tgt_memory_bank, src_memory_bank)
        # mask source attention
        assert self.mask_val == -float('inf')
        if src_lengths is not None:
            mask = sequence_mask(src_lengths)
            mask = mask.unsqueeze(1)
            scores.data.masked_fill_(1 - mask, self.mask_val)
        # normalize the scores with softmax / log-softmax
        log_scores = F.log_softmax(scores, dim=-1)
        scores = F.softmax(scores, dim=-1)
        # make scores T x N x S
        scores = scores.transpose(0, 1)
        log_scores = log_scores.transpose(0, 1)
        scores = Params(
            alpha=scores,
            log_alpha=log_scores,
            dist_type=self.dist_type,
        )
    elif self.dist_type == "none":
        scores = torch.bmm(tgt_memory_bank, src_memory_bank)
        # mask source attention
        if src_lengths is not None:
            mask = sequence_mask(src_lengths)
            mask = mask.unsqueeze(1)
            scores.data.masked_fill_(1 - mask, self.mask_val)
        scores = Params(
            alpha=scores.transpose(0, 1),
            dist_type=self.dist_type,
        )
    else:
        raise Exception("Unsupported dist_type")

    # T x N x S
    return scores
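# `Params` is constructed in several snippets but not defined in this file. It
# appears to be a lightweight record of attention-distribution parameters; a
# hedged sketch of an assumed definition, with field names inferred from the
# call sites and defaults added so callers can pass only a subset:
from collections import namedtuple

Params = namedtuple(
    "Params",
    ["alpha", "log_alpha", "dist_type", "samples",
     "sample_log_probs", "sample_log_probs_q",
     "sample_log_probs_p", "sample_p_div_q_log"],
)
# every field optional
Params.__new__.__defaults__ = (None,) * len(Params._fields)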
def ans_match(src_seq, ans_seq, src_lengths, ans_lengths, bm25):
    # `src_lengths`, `ans_lengths`, and `bm25` are used below but were not part
    # of the original signature; they are added here so the snippet is self-contained.
    import torch.nn.functional as F

    BF_ans_mask = sequence_mask(ans_lengths)   # [batch, ans_seq_len]
    BF_src_mask = sequence_mask(src_lengths)   # [batch, src_seq_len]
    BF_src_outputs = src_seq.transpose(0, 1)   # [batch, src_seq_len, 2*hidden_size]
    BF_ans_outputs = ans_seq.transpose(0, 1)   # [batch, ans_seq_len, 2*hidden_size]

    # compute bi-attention scores
    src_scores = BF_src_outputs.bmm(
        BF_ans_outputs.transpose(2, 1))        # [batch, src_seq_len, ans_seq_len]
    # re-weight source-to-answer scores with per-example BM25 weights
    src_scores = bm25.view(bm25.shape[0], 1, -1).expand(
        src_scores.shape) * src_scores
    ans_scores = BF_ans_outputs.bmm(
        BF_src_outputs.transpose(2, 1))        # [batch, ans_seq_len, src_seq_len]

    # mask padding
    Expand_BF_ans_mask = BF_ans_mask.unsqueeze(1).expand(
        src_scores.size())                     # [batch, src_seq_len, ans_seq_len]
    src_scores.data.masked_fill_(~(Expand_BF_ans_mask).bool(), -float('inf'))
    # for i in range(src_scores.size()[0]):
    #     src_scores[i] = bm25[i] * src_scores[i]
    # UNIFORM ATTENTION
    # src_scores = torch.ones(src_scores.shape).to(ans_seq.device)
    Expand_BF_src_mask = BF_src_mask.unsqueeze(1).expand(
        ans_scores.size())                     # [batch, ans_seq_len, src_seq_len]
    ans_scores.data.masked_fill_(~(Expand_BF_src_mask).bool(), -float('inf'))

    # normalize with softmax
    src_alpha = F.softmax(src_scores, dim=2)   # [batch, src_seq_len, ans_seq_len] news2src
    ans_alpha = F.softmax(ans_scores, dim=2)   # [batch, ans_seq_len, src_seq_len]

    # take the weighted average
    BF_src_matched_seq = src_alpha.bmm(BF_ans_outputs)     # [batch, src_seq_len, 2*hidden_size]
    src_matched_seq = BF_src_matched_seq.transpose(0, 1)   # [src_seq_len, batch, 2*hidden_size]
    BF_ans_matched_seq = ans_alpha.bmm(BF_src_outputs)     # [batch, ans_seq_len, 2*hidden_size]
    ans_matched_seq = BF_ans_matched_seq.transpose(0, 1)   # [ans_seq_len, batch, 2*hidden_size]
    return src_matched_seq, ans_matched_seq
def forward(self, input, context, context_lengths=None, coverage=None):
    """
    Args:
      input (`FloatTensor`): query vectors `[batch x tgt_len x hidden_size]`
      context (`FloatTensor`): source vectors `[batch x src_len x hidden_size]`
      context_lengths (`LongTensor`): the source context lengths `[batch]`
      coverage (`FloatTensor`): None (not supported yet)

    Returns:
      (`FloatTensor`, `FloatTensor`):
      * Computed vector `[tgt_len x batch x hidden_size]`
      * Attention distributions for each query `[tgt_len x batch x src_len]`
    """
    batch, sourceL, context_size = context.size()
    batch_, targetL, hidden_size = input.size()
    aeq(batch, batch_)

    # compute attention scores, as in Luong et al.
    align = self.score(input, context)  # batch x tgt_len x src_len
    # pdb.set_trace()

    if context_lengths is not None:
        mask = sequence_mask(context_lengths)
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        align.data.masked_fill_(1 - mask, -float('inf'))

    # Softmax to normalize attention weights
    align_vectors = self.sm(align.view(batch * targetL, sourceL))
    align_vectors = align_vectors.view(batch, targetL, sourceL)

    # each context vector c_t is the weighted average
    # over all the source hidden states
    c = torch.bmm(align_vectors, context)

    # concatenate
    concat_c = torch.cat([c, input], 2).view(batch * targetL, -1)
    attn_h = self.linear_out(concat_c).view(batch, targetL, hidden_size)
    if self.attn_type in ["general", "dot"]:
        attn_h = self.tanh(attn_h)

    attn_h = attn_h.transpose(0, 1).contiguous()
    align_vectors = align_vectors.transpose(0, 1).contiguous()

    # Check output sizes
    targetL_, batch_, dim_ = attn_h.size()
    # aeq(targetL, targetL_)
    # aeq(batch, batch_)
    # aeq(hidden_size, dim_)
    targetL_, batch_, sourceL_ = align_vectors.size()
    # aeq(targetL, targetL_)
    # aeq(batch, batch_)
    # aeq(sourceL, sourceL_)

    return attn_h, align_vectors
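# `aeq` is called for shape checks throughout these snippets but never defined
# here. A minimal sketch of the assumed assert-all-equal helper:
def aeq(*args):
    """Assert that all given values are equal (used for tensor-size checks)."""
    first = args[0]
    assert all(a == first for a in args[1:]), \
        "expected all values to be equal, got: {}".format(args)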
def forward(self, src, tgt, src_lengths=None, memory_bank=None):
    # src_final, src_memory_bank = self.src_encoder(src, src_lengths)
    # src_length, batch_size, rnn_size = src_memory_bank.size()
    src_memory_bank = memory_bank.transpose(0, 1).transpose(1, 2)
    if self.inference_network_type == 'embedding_only':
        tgt_memory_bank = self.tgt_encoder(tgt)
    else:
        tgt_final, tgt_memory_bank = self.tgt_encoder(tgt)
    # src_memory_bank = src_memory_bank.transpose(0, 1)  # batch_size, src_length, rnn_size
    # src_memory_bank = src_memory_bank.contiguous().view(-1, rnn_size)  # batch_size*src_length, rnn_size
    # src_memory_bank = self.W(src_memory_bank) \
    #     .view(batch_size, src_length, rnn_size)
    # src_memory_bank = src_memory_bank.transpose(1, 2)  # batch_size, rnn_size, src_length
    tgt_memory_bank = tgt_memory_bank.transpose(0, 1)  # batch_size, tgt_length, rnn_size

    if self.dist_type == "dirichlet":
        # probably broken
        scores = torch.bmm(tgt_memory_bank, src_memory_bank)
        scores = [scores]
    elif self.dist_type == "normal":
        # log normal
        src_memory_bank = src_memory_bank.transpose(1, 2)
        # assert src_memory_bank.size() == (batch_size, src_length, rnn_size)
        scores = self.get_normal_scores(src_memory_bank, tgt_memory_bank)
    elif self.dist_type == "none":
        scores = [torch.bmm(tgt_memory_bank, src_memory_bank)]
    else:
        raise Exception("Unsupported dist_type")

    nparam = len(scores)
    # mask padding positions using the source lengths
    if src_lengths is not None:
        mask = sequence_mask(src_lengths)
        mask = mask.unsqueeze(1)
        if self.dist_type == 'normal':
            scores[0].data.masked_fill_(1 - mask, -999)
            scores[1].data.masked_fill_(1 - mask, 0.001)
        else:
            for i in range(nparam):
                scores[i].data.masked_fill_(1 - mask, self.mask_val)
    return scores
def forward(self, src, tgt, src_lengths=None):
    src_final, src_memory_bank = self.src_encoder(src, src_lengths)
    src_length, batch_size, rnn_size = src_memory_bank.size()
    tgt_final, tgt_memory_bank = self.tgt_encoder(tgt)
    src_memory_bank = src_memory_bank.transpose(0, 1)  # batch_size, src_length, rnn_size
    src_memory_bank = src_memory_bank.contiguous().view(-1, rnn_size)  # batch_size*src_length, rnn_size
    src_memory_bank = self.W(src_memory_bank) \
        .view(batch_size, src_length, rnn_size)
    src_memory_bank = src_memory_bank.transpose(1, 2)  # batch_size, rnn_size, src_length
    tgt_memory_bank = tgt_memory_bank.transpose(0, 1)  # batch_size, tgt_length, rnn_size

    if self.dist_type == "dirichlet":
        scores = torch.bmm(tgt_memory_bank, src_memory_bank)
        # print("max: {}, min: {}".format(scores.max(), scores.min()))
        # affine shift so the scores are strictly positive
        # (Dirichlet concentration parameters must be > 0)
        scores = scores - scores.min(-1)[0].unsqueeze(-1) + 1e-2
        # exp
        # scores = scores.clamp(-1, 1).exp()
        # scores = scores.clamp(min=1e-2)
        scores = [scores]
    elif self.dist_type == "normal":
        # log normal
        src_memory_bank = src_memory_bank.transpose(1, 2)
        assert src_memory_bank.size() == (batch_size, src_length, rnn_size)
        scores = self.get_normal_scores(src_memory_bank, tgt_memory_bank)
    elif self.dist_type == "none":
        scores = [torch.bmm(tgt_memory_bank, src_memory_bank)]
    else:
        raise Exception("Unsupported dist_type")

    nparam = len(scores)
    if src_lengths is not None:
        mask = sequence_mask(src_lengths)
        mask = mask.unsqueeze(1)
        for i in range(nparam):
            scores[i].data.masked_fill_(1 - mask, self.mask_val)
    return scores
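# The "dirichlet" branch above only prepares strictly positive scores; turning
# them into an attention distribution is assumed to happen elsewhere in the
# model. A hedged sketch of that step (the function name and sampling scheme
# are illustrative, not the project's actual API):
import torch

def dirichlet_attention(concentration, n_samples=1):
    """concentration: [batch, tgt_len, src_len], strictly positive values."""
    dist = torch.distributions.Dirichlet(concentration)
    # reparameterized samples over the source simplex
    samples = dist.rsample((n_samples,))  # [n_samples, batch, tgt_len, src_len]
    return samples, dist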
def encode_seq(self, seq, seq_lengths):
    """
    Encode sequence `seq` using its lengths `seq_lengths` to mask paddings
    and return the average sequence encoding.

    seq (sequence_length, batch_size, feats): Sequence encodings.
    seq_lengths (batch_size): Sequence lengths.
    """
    # mask => [B, n], seq_lengths => [B]
    mask = sequence_mask(seq_lengths)
    # mask => [B, n, 1]
    mask = mask.unsqueeze(2)  # Make it broadcastable.
    # convert to a float variable
    mask = Variable(mask.type(torch.Tensor), requires_grad=False)
    # x => [B, n, d]
    seq = seq.transpose(0, 1)
    seq = seq * mask
    # average: masked sum divided by the number of valid positions
    h = seq.sum(1) / mask.sum(1)  # [B, d] / [B, 1]
    return h
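# A more current equivalent of `encode_seq` without the deprecated `Variable`
# wrapper; a minimal sketch assuming the same [len, batch, dim] input layout
# and reusing the `sequence_mask` helper sketched earlier:
import torch

def masked_mean(seq, seq_lengths):
    """seq: [len, batch, dim]; seq_lengths: [batch] -> mean encoding [batch, dim]."""
    mask = sequence_mask(seq_lengths).unsqueeze(2).float()  # [batch, len, 1]
    seq = seq.transpose(0, 1)                               # [batch, len, dim]
    return (seq * mask).sum(1) / mask.sum(1)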
def forward(self, input, memory_bank, memory_lengths=None, coverage=None): """ Args: input (`FloatTensor`): query vectors `[batch x tgt_len x dim]` memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]` memory_lengths (`LongTensor`): the source context lengths `[batch]` coverage (`FloatTensor`): None (not supported yet) Returns: (`FloatTensor`, `FloatTensor`): * Computed vector `[tgt_len x batch x dim]` * Attention distribtutions for each query `[tgt_len x batch x src_len]` """ # one step input if input.dim() == 2: one_step = True input = input.unsqueeze(1) else: one_step = False batch, sourceL, dim = memory_bank.size() batch_, targetL, dim_ = input.size() aeq(batch, batch_) aeq(dim, dim_) aeq(self.dim, dim) if coverage is not None: batch_, sourceL_ = coverage.size() aeq(batch, batch_) aeq(sourceL, sourceL_) if coverage is not None: cover = coverage.view(-1).unsqueeze(1) memory_bank += self.linear_cover(cover).view_as(memory_bank) memory_bank = self.tanh(memory_bank) # compute attention scores, as in Luong et al. align = self.score(input, memory_bank) if memory_lengths is not None: mask = sequence_mask(memory_lengths) mask = mask.unsqueeze(1) # Make it broadcastable. align.data.masked_fill_(1 - mask, -float('inf')) # Softmax to normalize attention weights align_vectors = self.sm(align.view(batch*targetL, sourceL)) align_vectors = align_vectors.view(batch, targetL, sourceL) # each context vector c_t is the weighted average # over all the source hidden states c = torch.bmm(align_vectors, memory_bank) # concatenate concat_c = torch.cat([c, input], 2).view(batch*targetL, dim*2) attn_h = self.linear_out(concat_c).view(batch, targetL, dim) if self.attn_type in ["general", "dot"]: attn_h = self.tanh(attn_h) if one_step: attn_h = attn_h.squeeze(1) align_vectors = align_vectors.squeeze(1) # Check output sizes batch_, dim_ = attn_h.size() aeq(batch, batch_) aeq(dim, dim_) batch_, sourceL_ = align_vectors.size() aeq(batch, batch_) aeq(sourceL, sourceL_) else: attn_h = attn_h.transpose(0, 1).contiguous() align_vectors = align_vectors.transpose(0, 1).contiguous() # Check output sizes targetL_, batch_, dim_ = attn_h.size() aeq(targetL, targetL_) aeq(batch, batch_) aeq(dim, dim_) targetL_, batch_, sourceL_ = align_vectors.size() aeq(targetL, targetL_) aeq(batch, batch_) aeq(sourceL, sourceL_) return attn_h, align_vectors
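# `self.score` is called throughout these attention modules but its body is not
# included in this file. A hedged sketch of the Luong-style "dot" and "general"
# scores it presumably computes, over batched queries h_t [batch, tgt_len, dim]
# and memory h_s [batch, src_len, dim]; `linear_in` stands in for the module's
# own projection layer:
import torch

def luong_score(h_t, h_s, linear_in=None):
    """Return unnormalized alignment scores [batch, tgt_len, src_len]."""
    if linear_in is not None:  # "general": score = h_t^T W h_s
        batch, tgt_len, dim = h_t.size()
        h_t = linear_in(h_t.view(-1, dim)).view(batch, tgt_len, dim)
    return torch.bmm(h_t, h_s.transpose(1, 2))  # "dot" product scores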
def forward(self, input, memory_bank, memory_lengths=None, coverage=None): """ Args: input (`FloatTensor`): query vectors `[batch x tgt_len x dim]` memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]` memory_lengths (`LongTensor`): the source context lengths `[batch]` coverage (`FloatTensor`): None (not supported yet) Returns: (`FloatTensor`, `FloatTensor`): * Computed vector `[tgt_len x batch x dim]` * Attention distribtutions for each query `[tgt_len x batch x src_len]` """ # one step input if input.dim() == 2: one_step = True input = input.unsqueeze(1) else: one_step = False batch, sourceL, dim = memory_bank.size() batch_, targetL, dim_ = input.size() aeq(batch, batch_) aeq(dim, dim_) aeq(self.dim, dim) if coverage is not None: batch_, sourceL_ = coverage.size() aeq(batch, batch_) aeq(sourceL, sourceL_) if coverage is not None: cover = coverage.view(-1).unsqueeze(1) memory_bank += self.linear_cover(cover).view_as(memory_bank) memory_bank = self.tanh(memory_bank) # compute attention scores, as in Luong et al. align = self.score(input, memory_bank) assert memory_lengths is not None mask = sequence_mask(memory_lengths) mask = mask.unsqueeze(1) # Make it broadcastable. # mask the time step of self mask = mask.repeat(1, sourceL, 1) mask_self_index = list(range(sourceL)) mask[:, mask_self_index, mask_self_index] = 0 if self.attn_type == "fine": mask = mask.unsqueeze(3) align.data.masked_fill_(1 - mask, -float('inf')) # Softmax to normalize attention weights align_vectors = self.sm(align) # each context vector c_t is the weighted average # over all the source hidden states if self.attn_type == "fine": c = memory_bank.unsqueeze(1).mul(align_vectors).sum(dim=2, keepdim=False) else: c = torch.bmm(align_vectors, memory_bank) # concatenate concat_c = torch.cat([c, input], 2) attn_h = self.linear_out(concat_c) if self.attn_type in ["general", "dot"]: # attn_h = F.elu(attn_h, 0.1) # attn_h = F.elu(self.dropout(attn_h) + input, 0.1) # content selection gate if not self.no_gate: attn_h = F.sigmoid(attn_h).mul(input) if one_step: attn_h = attn_h.squeeze(1) align_vectors = align_vectors.squeeze(1) # Check output sizes batch_, dim_ = attn_h.size() aeq(batch, batch_) aeq(dim, dim_) batch_, sourceL_ = align_vectors.size() aeq(batch, batch_) aeq(sourceL, sourceL_) else: attn_h = attn_h.transpose(0, 1).contiguous() align_vectors = align_vectors.transpose(0, 1).contiguous() # Check output sizes targetL_, batch_, dim_ = attn_h.size() aeq(targetL, targetL_) aeq(batch, batch_) aeq(dim, dim_) targetL_, batch_, sourceL_ = align_vectors.size() aeq(targetL, targetL_) aeq(batch, batch_) aeq(sourceL, sourceL_) return attn_h, align_vectors
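# The self-attention variant above zeroes the diagonal of the length mask so a
# record never attends to its own time step. A minimal, self-contained sketch
# of that masking pattern in isolation (illustrative only):
import torch

def self_attention_mask(lengths, src_len):
    """Return a [batch, src_len, src_len] mask with padding and self blocked."""
    mask = sequence_mask(lengths, src_len)           # [batch, src_len]
    mask = mask.unsqueeze(1).repeat(1, src_len, 1)   # [batch, src_len, src_len]
    idx = torch.arange(src_len, device=lengths.device)
    mask[:, idx, idx] = False                        # block attention to self
    return mask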
def forward(self, input, memory_bank, memory_lengths=None, stage1_target=None, plan_attn=None, player_row_indices=None, team_row_indices=None): """ Args: input (`FloatTensor`): query vectors `[batch x tgt_len x dim]` memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]` memory_lengths (`LongTensor`): the source context lengths `[batch]` Returns: (`FloatTensor`, `FloatTensor`): * Computed vector `[tgt_len x batch x dim]` * Attention distribtutions for each query `[tgt_len x batch x src_len]` """ PLAYER_ROWS = 26 TEAM_ROWS = 2 EXTRA_RECORDS = 4 PLAYER_COLS = 22 TEAM_COLS = 15 PLAYER_RECORDS_MAX=EXTRA_RECORDS+PLAYER_ROWS*PLAYER_COLS # one step input if input.dim() == 2: one_step = True input = input.unsqueeze(1) else: one_step = False batch, sourceL, dim = memory_bank.size() #print 'batch, sourceL, dim',batch, sourceL, dim batch_, targetL, dim_ = input.size() #print 'batch_, targetL, dim_',batch_, targetL, dim_ aeq(batch, batch_) aeq(dim, dim_) aeq(self.dim, dim) SOURCEL = sourceL targetL_st1_tgt, batch_st1_tgt,_= stage1_target.size() batch_plan, target_plan = plan_attn.size() aeq(batch_plan, batch) aeq(batch_plan, batch_st1_tgt) aeq(target_plan, targetL_st1_tgt) target_player_indices_L, batch_player_ind, player_rows_len = player_row_indices.size() aeq(target_player_indices_L, targetL_st1_tgt) aeq(batch_player_ind, batch) aeq(player_rows_len, PLAYER_ROWS) target_team_indices_L, batch_team_ind, team_rows_len = team_row_indices.size() aeq(target_team_indices_L, targetL_st1_tgt) aeq(batch_team_ind, batch) aeq(team_rows_len, TEAM_ROWS) # compute attention scores, as in Luong et al. align = self.score(input, memory_bank) if memory_lengths is not None: mask = sequence_mask(memory_lengths) mask = mask.unsqueeze(1) # Make it broadcastable. align.data.masked_fill_(1 - mask, -float('inf')) # Softmax to normalize attention weights align = align.view(batch * targetL, sourceL) align_player_cells = self.sm(align[:,EXTRA_RECORDS:PLAYER_RECORDS_MAX].contiguous().view(-1, PLAYER_COLS)) align_team_cells = self.sm(align[:,PLAYER_RECORDS_MAX:SOURCEL].contiguous().view(-1, TEAM_COLS)) row_indices = (stage1_target.data.squeeze(2)-EXTRA_RECORDS)/PLAYER_COLS prob_prod = plan_attn.t() * Variable(row_indices.lt(PLAYER_ROWS).float(), requires_grad=False) #stores probabilities for player records # (batch, 1, t_len_plan) x (batch, t_len_plan, 26) --> (batch, 1, 26) player_prob = torch.bmm(prob_prod.t().unsqueeze(1), player_row_indices.transpose(0,1).float()).squeeze(1) player_prob = player_prob.unsqueeze(2).expand(-1,-1,PLAYER_COLS).contiguous().view(-1,PLAYER_COLS) player_prob_table = align_player_cells*player_prob prob_prod = plan_attn.t() * Variable(row_indices.ge(PLAYER_ROWS).float(), requires_grad=False) #stores probabilities for team records # (batch, 1, t_len_plan) x (batch, t_len_plan, 2) --> (batch, 1, 2) team_prob = torch.bmm(prob_prod.t().unsqueeze(1), team_row_indices.transpose(0,1).float()).squeeze(1) team_prob = team_prob.unsqueeze(2).expand(-1,-1,TEAM_COLS).contiguous().view(-1,TEAM_COLS) team_prob_table = align_team_cells*team_prob extra_prob_table = Variable(self.tt.FloatTensor(batch, EXTRA_RECORDS).fill_(0), requires_grad=False) align_vectors = torch.cat([extra_prob_table, player_prob_table.view(batch,-1), team_prob_table.view(batch,-1)],1) align_vectors = align_vectors.view(batch, targetL, sourceL) batch_table, sourceL_table, dim_table = memory_bank.size() aeq(batch, batch_table) aeq(dim, dim_table) if one_step: align_vectors = align_vectors.squeeze(1) # Check output sizes batch_, 
sourceL_ = align_vectors.size() aeq(batch, batch_) aeq(sourceL, sourceL_) else: align_vectors = align_vectors.transpose(0, 1).contiguous() aeq(dim, dim_) targetL_, batch_, sourceL_ = align_vectors.size() aeq(targetL, targetL_) aeq(batch, batch_) aeq(sourceL, sourceL_) return align_vectors
def forward(self, input, context, context_lengths=None, coverage=None): """ input (FloatTensor): batch x tgt_len x dim: decoder's rnn's output. context (FloatTensor): batch x src_len x dim: src hidden states context_lengths (LongTensor): the source context lengths. coverage (FloatTensor): None (not supported yet) """ # one step input if input.dim() == 2: one_step = True input = input.unsqueeze(1) else: one_step = False batch, sourceL, dim = context.size() batch_, targetL, dim_ = input.size() aeq(batch, batch_) aeq(dim, dim_) aeq(self.dim, dim) if coverage is not None: batch_, sourceL_ = coverage.size() aeq(batch, batch_) aeq(sourceL, sourceL_) if coverage is not None: cover = coverage.view(-1).unsqueeze(1) context += self.linear_cover(cover).view_as(context) context = self.tanh(context) # compute attention scores, as in Luong et al. align = self.score(input, context) if context_lengths is not None: mask = sequence_mask(context_lengths) mask = mask.unsqueeze(1) # Make it broadcastable. align.data.masked_fill_(1 - mask, -float('inf')) # Softmax to normalize attention weights align_vectors = self.sm(align.view(batch * targetL, sourceL)) align_vectors = align_vectors.view(batch, targetL, sourceL) # each context vector c_t is the weighted average # over all the source hidden states c = torch.bmm(align_vectors, context) # concatenate concat_c = torch.cat([c, input], 2).view(batch * targetL, dim * 2) attn_h = self.linear_out(concat_c).view(batch, targetL, dim) if self.attn_type in ["general", "dot"]: attn_h = self.tanh(attn_h) if one_step: attn_h = attn_h.squeeze(1) align_vectors = align_vectors.squeeze(1) # Check output sizes batch_, dim_ = attn_h.size() aeq(batch, batch_) aeq(dim, dim_) batch_, sourceL_ = align_vectors.size() aeq(batch, batch_) aeq(sourceL, sourceL_) else: attn_h = attn_h.transpose(0, 1).contiguous() align_vectors = align_vectors.transpose(0, 1).contiguous() # Check output sizes targetL_, batch_, dim_ = attn_h.size() aeq(targetL, targetL_) aeq(batch, batch_) aeq(dim, dim_) targetL_, batch_, sourceL_ = align_vectors.size() aeq(targetL, targetL_) aeq(batch, batch_) aeq(sourceL, sourceL_) return attn_h, align_vectors
def forward(self, input, memory_bank, memory_lengths=None, coverage=None): """ Args: input (`FloatTensor`): query vectors `[batch x tgt_len x dim]` memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]` memory_lengths (`LongTensor`): the source context lengths `[batch]` coverage (`FloatTensor`): None (not supported yet) Returns: (`FloatTensor`, `FloatTensor`): * Computed vector `[tgt_len x batch x dim]` * Attention distribtutions for each query `[tgt_len x batch x src_len]` """ # one step input if input.dim() == 2: one_step = True input = input.unsqueeze(1) else: one_step = False batch, sourcel, dim = memory_bank.size() batch_, targetl, dim_ = input.size() aeq(batch, batch_) aeq(dim, dim_) aeq(self.dim, dim) if coverage is not None: batch_, sourcel_ = coverage.size() aeq(batch, batch_) aeq(sourcel, sourcel_) if coverage is not None: cover = coverage.view(-1).unsqueeze(1) memory_bank += self.linear_cover(cover).view_as(memory_bank) memory_bank = self.tanh(memory_bank) # compute attention scores, as in Luong et al. align = self.score(input, memory_bank) if memory_lengths is not None: mask = sequence_mask(memory_lengths) mask = mask.unsqueeze(1) # Make it broadcastable. align.data.masked_fill_(1 - mask, -float('inf')) # Softmax to normalize attention weights align_vectors = self.sm(align.view(batch * targetl, sourcel)) align_vectors = align_vectors.view(batch, targetl, sourcel) # each context vector c_t is the weighted average # over all the source hidden states c = torch.bmm(align_vectors, memory_bank) # concatenate concat_c = torch.cat([c, input], 2).view(batch * targetl, dim * 2) attn_h = self.linear_out(concat_c).view(batch, targetl, dim) if self.attn_type in ["general", "dot"]: attn_h = self.tanh(attn_h) if one_step: attn_h = attn_h.squeeze(1) align_vectors = align_vectors.squeeze(1) # Check output sizes batch_, dim_ = attn_h.size() aeq(batch, batch_) aeq(dim, dim_) batch_, sourcel_ = align_vectors.size() aeq(batch, batch_) aeq(sourcel, sourcel_) else: attn_h = attn_h.transpose(0, 1).contiguous() align_vectors = align_vectors.transpose(0, 1).contiguous() # Check output sizes targetl_, batch_, dim_ = attn_h.size() aeq(targetl, targetl_) aeq(batch, batch_) aeq(dim, dim_) targetl_, batch_, sourcel_ = align_vectors.size() aeq(targetl, targetl_) aeq(batch, batch_) aeq(sourcel, sourcel_) return attn_h, align_vectors
def forward(self, input, context, context_lengths=None, coverage=None, embedding_now=None, embedding_copy=None, word_freq=None): """ Args: input (`FloatTensor`): query vectors `[batch x tgt_len x dim]`, decoder hidden state at each timestep context (`FloatTensor`): source vectors `[batch x src_len x dim]`, encoder hidden state at each timestep context_lengths (`LongTensor`): the source context lengths `[batch]` coverage (`FloatTensor`): None (not supported yet) embedding_copy (`FloatTensor`): the original input sequence embeddings with affect `[batch x src_len x emb_dim]` word_freq (`FloatTensor`): the word frequency `[batch x src_len]` Returns: (`FloatTensor`, `FloatTensor`): * Computed vector `[tgt_len x batch x dim]` * Attention distribtutions for each query `[tgt_len x batch x src_len]` """ # one step input for InputFeedDecoder if input.dim() == 2: one_step = True input = input.unsqueeze(1) else: one_step = False batch, sourceL, dim = context.size() batch_, targetL, dim_ = input.size() aeq(batch, batch_) aeq(dim, dim_) aeq(self.dim, dim) if coverage is not None: batch_, sourceL_ = coverage.size() aeq(batch, batch_) aeq(sourceL, sourceL_) if coverage is not None: cover = coverage.view(-1).unsqueeze(1) context += self.linear_cover(cover).view_as(context) context = self.tanh(context) # compute attention scores, as in Luong et al. # Add affective attention here px align = self.score(input, context, embedding_now, embedding_copy, word_freq) if context_lengths is not None: mask = sequence_mask(context_lengths) mask = mask.unsqueeze(1) # Make it broadcastable. align.data.masked_fill_(1 - mask, -float('inf')) # Softmax to normalize attention weights align_vectors = self.sm(align.view(batch * targetL, sourceL)) align_vectors = align_vectors.view(batch, targetL, sourceL) # each context vector c_t is the weighted average # over all the source hidden states c = torch.bmm(align_vectors, context) # concatenate concat_c = torch.cat([c, input], 2).view(batch * targetL, dim * 2) attn_h = self.linear_out(concat_c).view(batch, targetL, dim) if self.attn_type in ["general", "dot"]: attn_h = self.tanh(attn_h) if one_step: attn_h = attn_h.squeeze(1) align_vectors = align_vectors.squeeze(1) # Check output sizes batch_, dim_ = attn_h.size() aeq(batch, batch_) aeq(dim, dim_) batch_, sourceL_ = align_vectors.size() aeq(batch, batch_) aeq(sourceL, sourceL_) else: attn_h = attn_h.transpose(0, 1).contiguous() align_vectors = align_vectors.transpose(0, 1).contiguous() # Check output sizes targetL_, batch_, dim_ = attn_h.size() aeq(targetL, targetL_) aeq(batch, batch_) aeq(dim, dim_) targetL_, batch_, sourceL_ = align_vectors.size() aeq(targetL, targetL_) aeq(batch, batch_) aeq(sourceL, sourceL_) return attn_h, align_vectors
def forward(self, input, memory_bank, memory_lengths=None, coverage=None,
            emb_weight=None, idf_weights=None):
    """
    Args:
      input (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
      memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
      memory_lengths (`LongTensor`): the source context lengths `[batch]`
      coverage (`FloatTensor`): None (not supported yet)
      # thkim
      emb_weight : maybe intra-attention related ...
      idf_weights : idf values, multiplied into the attention weights

    Returns:
      (`FloatTensor`, `FloatTensor`):
      * Computed vector `[tgt_len x batch x dim]`
      * Attention distributions for each query `[tgt_len x batch x src_len]`
    """
    # one step input
    if input.dim() == 2:
        one_step = True
        input = input.unsqueeze(1)
    else:
        one_step = False

    batch, sourceL, dim = memory_bank.size()
    batch_, targetL, dim_ = input.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    aeq(self.dim, dim)
    if coverage is not None:
        batch_, sourceL_ = coverage.size()
        aeq(batch, batch_)
        aeq(sourceL, sourceL_)

    if coverage is not None:
        cover = coverage.view(-1).unsqueeze(1)
        memory_bank += self.linear_cover(cover).view_as(memory_bank)
        memory_bank = self.tanh(memory_bank)

    # compute attention scores, as in Luong et al.
    align = self.score(input, memory_bank)

    if memory_lengths is not None:
        mask = sequence_mask(memory_lengths)
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        align.data.masked_fill_(1 - mask, -float('inf'))

    ## Intra-temporal attention
    ## assumes training is running on the GPU
    align = torch.exp(align)  # batch * 1(target_length) * input_length
    if len(self.attn_outputs) < 1:
        # t = 1
        align_vectors = self.sm(align.view(batch * targetL, sourceL))
        align_vectors = align_vectors.view(batch, targetL, sourceL)
    else:
        # t > 1: normalize by the sum of the previous steps' unnormalized scores
        temporal_attns = torch.cat(self.attn_outputs, 1)  # batch * (t-1) * input_length
        normalizing_factor = torch.sum(temporal_attns, 1).unsqueeze(1)
        # wrong implementation
        # normalizing_factor = torch.autograd.Variable(torch.cat([torch.ones(align.size()[0], 1, 1).cuda(), torch.cumsum(torch.exp(align), 2).data[:, :, :-1]], 2))
        # align = torch.exp(align) / normalizing_factor
        # align_vectors = align / torch.sum(align, 2).unsqueeze(2)
        align_vectors = align / normalizing_factor
        align_vectors = self.sm(align.view(batch * targetL, sourceL))
        align_vectors = align_vectors.view(batch, targetL, sourceL)

    # Softmax to normalize attention weights
    ## original attention
    # align_vectors = self.sm(align.view(batch*targetL, sourceL))
    # align_vectors = align_vectors.view(batch, targetL, sourceL)
    if idf_weights is not None:
        align_vectors = align_vectors * torch.autograd.Variable(
            idf_weights.t().unsqueeze(1), requires_grad=False)

    # each context vector c_t is the weighted average
    # over all the source hidden states
    c = torch.bmm(align_vectors, memory_bank)

    # for intra-temporal attention
    self.attn_outputs.append(align)

    # ======== intra-decoder attention
    if len(self.decoder_outputs) < 1:
        # TODO: change initial value to zero vector
        # ? what is the size of the zero vector? the decoder attention below
        # also looks a bit off
        # set a zero vector for the first step
        c_dec = input * 0
    else:
        decoder_history = torch.cat(self.decoder_outputs, 1)  # batch * tgt_len(?) * dim
        decoder_align = self.score(input, decoder_history, "dec_attn")
        history_len = len(self.decoder_outputs)
        decoder_align_vectors = self.sm(
            decoder_align.view(batch * targetL, history_len))
        decoder_align_vectors = decoder_align_vectors.view(
            batch, targetL, history_len)
        c_dec = torch.bmm(decoder_align_vectors, decoder_history)
    self.decoder_outputs.append(input)
    # ========

    # if emb_weight is not None:
    #     self.linear_out.weight = self.tanh(emb_weight * self.linear_out.weight)

    # concatenate
    concat_c = torch.cat([c, input, c_dec], 2).view(batch * targetL, dim * 3)
    attn_h = self.linear_out(concat_c).view(batch, targetL, dim)
    if self.attn_type in ["general", "dot"]:
        attn_h = self.tanh(attn_h)

    if one_step:
        attn_h = attn_h.squeeze(1)
        align_vectors = align_vectors.squeeze(1)
        # Check output sizes
        batch_, dim_ = attn_h.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        batch_, sourceL_ = align_vectors.size()
        aeq(batch, batch_)
        aeq(sourceL, sourceL_)
    else:
        attn_h = attn_h.transpose(0, 1).contiguous()
        align_vectors = align_vectors.transpose(0, 1).contiguous()
        # Check output sizes
        targetL_, batch_, dim_ = attn_h.size()
        aeq(targetL, targetL_)
        aeq(batch, batch_)
        aeq(dim, dim_)
        targetL_, batch_, sourceL_ = align_vectors.size()
        aeq(targetL, targetL_)
        aeq(batch, batch_)
        aeq(sourceL, sourceL_)

    return attn_h, align_vectors
def forward(self, input, memory_bank, memory_lengths=None, coverage=None, q_scores=None): """ Args: input (`FloatTensor`): query vectors `[batch x tgt_len x dim]` memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]` memory_lengths (`LongTensor`): the source context lengths `[batch]` coverage (`FloatTensor`): None (not supported yet) q_scores (`FloatTensor`): the attention params from the inference network Returns: (`FloatTensor`, `FloatTensor`): * Weighted context vector `[tgt_len x batch x dim]` * Attention distribtutions for each query `[tgt_len x batch x src_len]` * Unormalized attention scores for each query `[batch x tgt_len x src_len]` """ # one step input if input.dim() == 2: one_step = True input = input.unsqueeze(1) if q_scores is not None: # oh, I guess this is super messy if q_scores.alpha is not None: q_scores = Params( alpha=q_scores.alpha.unsqueeze(1), log_alpha=q_scores.log_alpha.unsqueeze(1), dist_type=q_scores.dist_type, ) else: one_step = False batch, sourceL, dim = memory_bank.size() batch_, targetL, dim_ = input.size() aeq(batch, batch_) # compute attention scores, as in Luong et al. # Params should be T x N x S if self.p_dist_type == "categorical": scores = self.score(input, memory_bank) if memory_lengths is not None: # mask : N x T x S mask = sequence_mask(memory_lengths) mask = mask.unsqueeze(1) # Make it broadcastable. scores.data.masked_fill_(~mask, -float('inf')) if self.k > 0 and self.k < scores.size(-1): topk, idx = scores.data.topk(self.k) new_attn_score = torch.zeros_like(scores.data).fill_( float("-inf")) new_attn_score = new_attn_score.scatter_(2, idx, topk) scores = new_attn_score log_scores = F.log_softmax(scores, dim=-1) scores = log_scores.exp() c_align_vectors = scores p_scores = Params( alpha=scores, log_alpha=log_scores, dist_type=self.p_dist_type, ) # each context vector c_t is the weighted average # over all the source hidden states context_c = torch.bmm(c_align_vectors, memory_bank) if self.mode != 'wsram': concat_c = torch.cat([input, context_c], -1) # N x T x H h_c = self.tanh(self.linear_out(concat_c)) else: h_c = None # sample or enumerate # y_align_vectors: K x N x T x S q_sample, p_sample, sample_log_probs = None, None, None sample_log_probs_q, sample_log_probs_p, sample_p_div_q_log = None, None, None if self.mode == "sample": if q_scores is None or self.use_prior: p_sample, sample_log_probs = self.sample_attn( p_scores, n_samples=self.n_samples, lengths=memory_lengths, mask=mask if memory_lengths is not None else None) y_align_vectors = p_sample else: q_sample, sample_log_probs = self.sample_attn( q_scores, n_samples=self.n_samples, lengths=memory_lengths, mask=mask if memory_lengths is not None else None) y_align_vectors = q_sample elif self.mode == "gumbel": if q_scores is None or self.use_prior: p_sample, _ = self.sample_attn_gumbel( p_scores, self.temperature, n_samples=self.n_samples, lengths=memory_lengths, mask=mask if memory_lengths is not None else None) y_align_vectors = p_sample else: q_sample, _ = self.sample_attn_gumbel( q_scores, self.temperature, n_samples=self.n_samples, lengths=memory_lengths, mask=mask if memory_lengths is not None else None) y_align_vectors = q_sample elif self.mode == "enum" or self.mode == "exact": y_align_vectors = None elif self.mode == "wsram": assert q_scores is not None q_sample, sample_log_probs_q, sample_log_probs_p, sample_p_div_q_log = self.sample_attn_wsram( q_scores, p_scores, n_samples=self.n_samples, lengths=memory_lengths, mask=mask if memory_lengths is not None else None) 
y_align_vectors = q_sample # context_y: K x N x T x H if y_align_vectors is not None: context_y = torch.bmm( y_align_vectors.view(-1, targetL, sourceL), memory_bank.unsqueeze(0).repeat(self.n_samples, 1, 1, 1).view( -1, sourceL, dim)).view(self.n_samples, batch, targetL, dim) else: # For enumerate, K = S. # memory_bank: N x S x H context_y = ( memory_bank.unsqueeze(0).repeat(targetL, 1, 1, 1) # T, N, S, H .permute(2, 1, 0, 3)) # S, N, T, H input = input.unsqueeze(0).repeat(context_y.size(0), 1, 1, 1) concat_y = torch.cat([input, context_y], -1) # K x N x T x H h_y = self.tanh(self.linear_out(concat_y)) if one_step: if h_c is not None: # N x H h_c = h_c.squeeze(1) # N x S c_align_vectors = c_align_vectors.squeeze(1) context_c = context_c.squeeze(1) # K x N x H h_y = h_y.squeeze(2) # K x N x S #y_align_vectors = y_align_vectors.squeeze(2) q_scores = Params( alpha=q_scores.alpha.squeeze(1) if q_scores.alpha is not None else None, dist_type=q_scores.dist_type, samples=q_sample.squeeze(2) if q_sample is not None else None, sample_log_probs=sample_log_probs.squeeze(2) if sample_log_probs is not None else None, sample_log_probs_q=sample_log_probs_q.squeeze(2) if sample_log_probs_q is not None else None, sample_log_probs_p=sample_log_probs_p.squeeze(2) if sample_log_probs_p is not None else None, sample_p_div_q_log=sample_p_div_q_log.squeeze(2) if sample_p_div_q_log is not None else None, ) if q_scores is not None else None p_scores = Params( alpha=p_scores.alpha.squeeze(1), log_alpha=log_scores.squeeze(1), dist_type=p_scores.dist_type, samples=p_sample.squeeze(2) if p_sample is not None else None, ) if h_c is not None: # Check output sizes batch_, dim_ = h_c.size() aeq(batch, batch_) batch_, sourceL_ = c_align_vectors.size() aeq(batch, batch_) aeq(sourceL, sourceL_) else: assert False # Only support input feeding. # T x N x H h_c = h_c.transpose(0, 1).contiguous() # T x N x S c_align_vectors = c_align_vectors.transpose(0, 1).contiguous() # T x K x N x H h_y = h_y.permute(2, 0, 1, 3).contiguous() # T x K x N x S #y_align_vectors = y_align_vectors.permute(2, 0, 1, 3).contiguous() q_scores = Params( alpha=q_scores.alpha.transpose(0, 1).contiguous(), dist_type=q_scores.dist_type, samples=q_sample.permute(2, 0, 1, 3).contiguous(), ) p_scores = Params( alpha=p_scores.alpha.transpose(0, 1).contiguous(), log_alpha=log_alpha.transpose(0, 1).contiguous(), dist_type=p_scores.dist_type, samples=p_sample.permute(2, 0, 1, 3).contiguous(), ) # Check output sizes targetL_, batch_, dim_ = h_c.size() aeq(targetL, targetL_) aeq(batch, batch_) aeq(dim, dim_) targetL_, batch_, sourceL_ = c_align_vectors.size() aeq(targetL, targetL_) aeq(batch, batch_) aeq(sourceL, sourceL_) # For now, don't include samples. dist_info = DistInfo( q=q_scores, p=p_scores, ) # h_y: samples from simplex # either K x N x H, or T x K x N x H # h_c: convex combination of memory_bank for input feeding # either N x H, or T x N x H # align_vectors: convex coefficients / boltzmann dist # either N x S, or T x N x S # raw_scores: unnormalized scores # either N x S, or T x N x S return h_y, h_c, context_c, c_align_vectors, dist_info
def forward2(self, input, memory_bank, memory_lengths=None, coverage=None): """ Args: input (`FloatTensor`): query vectors `[batch x tgt_len x dim]` memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]` memory_lengths (`LongTensor`): the source context lengths `[batch]` coverage (`FloatTensor`): None (not supported yet) Returns: (`FloatTensor`, `FloatTensor`): * Computed vector `[tgt_len x batch x dim]` * Attention distribtutions for each query `[tgt_len x batch x src_len]` """ memory_bank1, memory_bank2 = memory_bank if memory_lengths is not None: memory_lengths1, memory_lengths2 = memory_lengths # one step input if input.dim() == 2: one_step = True input = input.unsqueeze(1) else: one_step = False batch1, sourceL1, dim1 = memory_bank1.size() batch2, sourceL2, dim2 = memory_bank2.size() batch_, targetL, dim_ = input.size() aeq(batch1, batch2) aeq(batch1, batch_) aeq(dim1, dim2) aeq(dim1, dim_) aeq(self.dim, dim1) if coverage is not None: batch_, sourceL_ = coverage.size() aeq(batch1, batch_) aeq(sourceL2, sourceL_) if coverage is not None: # Todo: do not support cover = coverage.view(-1).unsqueeze(1) memory_bank2 += self.linear_cover(cover).view_as(memory_bank2) memory_bank2 = self.tanh(memory_bank2) # compute attention scores, as in Luong et al. align1 = self.score(input, memory_bank1) align2 = self.score(input, memory_bank2) if memory_lengths is not None: mask1 = sequence_mask(memory_lengths1) mask1 = mask1.unsqueeze(1) # Make it broadcastable. align1.data.masked_fill_(~mask1, -float('inf')) mask2 = sequence_mask(memory_lengths2) mask2 = mask2.unsqueeze(1) # Make it broadcastable. align2.data.masked_fill_(~mask2, -float('inf')) # Softmax to normalize attention weights align_vectors1 = self.sm(align1.view(batch1*targetL, sourceL1)) align_vectors1 = align_vectors1.view(batch1, targetL, sourceL1) align_vectors2 = self.sm(align2.view(batch2 * targetL, sourceL2)) align_vectors2 = align_vectors2.view(batch2, targetL, sourceL2) # each context vector c_t is the weighted average # over all the source hidden states c1 = torch.bmm(align_vectors1, memory_bank1) # 64 * 1 * 256 c2 = torch.bmm(align_vectors2, memory_bank2) # 64 * 1 * 256 # concatenate concat_c = torch.cat([c1, c2, input], 2).view(batch1*targetL, dim1*3) attn_h = self.linear_out2(concat_c).view(batch1, targetL, dim1) # decoding output if self.attn_type in ["general", "dot"]: attn_h = self.tanh(attn_h) if one_step: attn_h = attn_h.squeeze(1) align_vectors1 = align_vectors1.squeeze(1) align_vectors2 = align_vectors2.squeeze(1) # Check output sizes batch_, dim_ = attn_h.size() aeq(batch1, batch_) aeq(dim1, dim_) batch_, sourceL_ = align_vectors1.size() aeq(batch1, batch_) aeq(sourceL1, sourceL_) else: attn_h = attn_h.transpose(0, 1).contiguous() align_vectors1 = align_vectors1.transpose(0, 1).contiguous() align_vectors2 = align_vectors2.transpose(0, 1).contiguous() # Check output sizes targetL_, batch_, dim_ = attn_h.size() aeq(targetL, targetL_) aeq(batch1, batch_) aeq(dim1, dim_) targetL_, batch_, sourceL_ = align_vectors1.size() aeq(targetL, targetL_) aeq(batch1, batch_) aeq(sourceL1, sourceL_) return attn_h, (align_vectors1, align_vectors2)
def forward(self, input, memory_bank, entity_representation, memory_lengths=None, coverage=None, count_entities=None, total_entities_list=None): """ Args: input (`FloatTensor`): query vectors `[batch x tgt_len x dim]` memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]` entity_representation (`FloatTensor`): source vectors `[batch x num_entities x eu_k_dim]` memory_lengths (`LongTensor`): the source context lengths `[batch]` coverage (`FloatTensor`): None (not supported yet) count_entities (`LongTensor`): entity lengths `[batch]` total_entities_list (`FloatTensor`): source vectors `[batch x num_entities x src_len]` Returns: (`FloatTensor`, `FloatTensor`): * Computed vector `[tgt_len x batch x dim]` * Attention distribtutions for each query `[tgt_len x batch x src_len]` """ # one step input if input.dim() == 2: one_step = True input = input.unsqueeze(1) else: one_step = False batch, sourceL, dim = memory_bank.size() batch__, sourceL__, dim__ = entity_representation.size() batch_, targetL, dim_ = input.size() batch___, num_entities, src_len = total_entities_list.size() aeq(batch, batch_) aeq(batch, batch__) aeq(self.dim, dim_) aeq(self.entity_dim, dim__) aeq(self.dim, dim) aeq(sourceL, src_len) aeq(num_entities, sourceL__) if coverage is not None: batch_, sourceL_ = coverage.size() aeq(batch, batch_) aeq(sourceL, sourceL_) if coverage is not None: cover = coverage.view(-1).unsqueeze(1) memory_bank += self.linear_cover(cover).view_as(memory_bank) memory_bank = self.tanh(memory_bank) # compute attention scores, as in Luong et al. entity_align = self.score(input, entity_representation, entity_attn=True) if count_entities is not None: count_entities_mask = sequence_mask(count_entities.data) count_entities_mask = count_entities_mask.unsqueeze( 1) # Make it broadcastable. 
entity_align.data.masked_fill_(1 - count_entities_mask, -float('inf')) entity_align_vectors = self.sm( entity_align.view(batch * targetL, sourceL__)) entity_align_vectors = entity_align_vectors.unsqueeze(2).expand( -1, -1, sourceL) align = self.score(input, memory_bank) align = align.unsqueeze(2).expand(-1, -1, sourceL__, -1) total_entities_list = total_entities_list.unsqueeze(1).expand( -1, targetL, -1, -1) align = align * total_entities_list # apply mask of records belonging to entities mask = total_entities_list.eq(0) align.data.masked_fill_(mask.data, -float('inf')) # Softmax to normalize attention weights align_vectors = self.sm( align.view(batch * targetL * sourceL__, sourceL)) count_entities_mask = count_entities_mask.unsqueeze( 3) #.expand(-1, -1, -1, sourceL) align_vectors = align_vectors.view(batch, targetL, sourceL__, sourceL) align_vectors.data.masked_fill_(1 - count_entities_mask, 0) align_vectors = align_vectors.view(batch * targetL, sourceL__, sourceL) align_vectors = (entity_align_vectors * align_vectors).sum(1) align_vectors = align_vectors.view(batch, targetL, sourceL) # each context vector c_t is the weighted average # over all the source hidden states c = torch.bmm(align_vectors, memory_bank) # concatenate concat_c = torch.cat([c, input], 2).view(batch * targetL, dim * 2) attn_h = self.linear_out(concat_c).view(batch, targetL, dim) if self.attn_type in ["general", "dot"]: attn_h = self.tanh(attn_h) if one_step: attn_h = attn_h.squeeze(1) align_vectors = align_vectors.squeeze(1) # Check output sizes batch_, dim_ = attn_h.size() aeq(batch, batch_) aeq(dim, dim_) batch_, sourceL_ = align_vectors.size() aeq(batch, batch_) aeq(sourceL, sourceL_) else: attn_h = attn_h.transpose(0, 1).contiguous() align_vectors = align_vectors.transpose(0, 1).contiguous() # Check output sizes targetL_, batch_, dim_ = attn_h.size() aeq(targetL, targetL_) aeq(batch, batch_) aeq(dim, dim_) targetL_, batch_, sourceL_ = align_vectors.size() aeq(targetL, targetL_) aeq(batch, batch_) aeq(sourceL, sourceL_) return attn_h, align_vectors
def forward(self, input, context, ctl, ctl_iter, context_lengths=None, coverage=None): """ Args: input (`FloatTensor`): query vectors `[batch x tgt_len x dim]` context (`FloatTensor`): source vectors `[batch x src_len x dim]` context_lengths (`LongTensor`): the source context lengths `[batch]` coverage (`FloatTensor`): None (not supported yet) Returns: (`FloatTensor`, `FloatTensor`): * Computed vector `[tgt_len x batch x dim]` * Attention distribtutions for each query `[tgt_len x batch x src_len]` """ # one step input if input.dim() == 2: one_step = True input = input.unsqueeze(1) else: one_step = False batch, sourceL, dim = context.size() batch_, targetL, dim_ = input.size() aeq(batch, batch_) aeq(dim, dim_) aeq(self.dim, dim) if coverage is not None: batch_, sourceL_ = coverage.size() aeq(batch, batch_) aeq(sourceL, sourceL_) if coverage is not None: cover = coverage.view(-1).unsqueeze(1) context += self.linear_cover(cover).view_as(context) context = self.tanh(context) # compute attention scores, as in Luong et al. align = self.score(input, context) if context_lengths is not None: mask = sequence_mask(context_lengths) mask = mask.unsqueeze(1) # Make it broadcastable. align.data.masked_fill_(1 - mask, -float('inf')) # Softmax to normalize attention weights align_vectors = self.sm(align.view(batch * targetL, sourceL)) align_vectors = align_vectors.view(batch, targetL, sourceL) # each context vector c_t is the weighted average # over all the source hidden states c = torch.bmm(align_vectors, context) ctl_diff = ctl.expand(-1, ctl_iter.size()[1]) - ctl_iter weights_attn = self.sigmoid(ctl_diff) weights_lm = 1 - self.sigmoid(ctl_diff) weights_attn = weights_attn.expand(c.size()[2], -1, -1) weights_lm = weights_lm.expand(c.size()[2], -1, -1) weights_attn = torch.transpose(weights_attn, 0, 1) weights_attn = torch.transpose(weights_attn, 1, 2) weights_lm = torch.transpose(weights_lm, 0, 1) weights_lm = torch.transpose(weights_lm, 1, 2) # concatenate concat_c = torch.cat([c * weights_attn, input * weights_lm], 2).view(batch * targetL, dim * 2) attn_h = self.linear_trans(concat_c).view(batch, targetL, dim) if self.attn_type in ["general", "dot"]: attn_h = self.tanh(attn_h) if one_step: attn_h = attn_h.squeeze(1) align_vectors = align_vectors.squeeze(1) # Check output sizes batch_, dim_ = attn_h.size() aeq(batch, batch_) aeq(dim, dim_) batch_, sourceL_ = align_vectors.size() aeq(batch, batch_) aeq(sourceL, sourceL_) else: attn_h = attn_h.transpose(0, 1).contiguous() align_vectors = align_vectors.transpose(0, 1).contiguous() # Check output sizes targetL_, batch_, dim_ = attn_h.size() aeq(targetL, targetL_) aeq(batch, batch_) aeq(dim, dim_) targetL_, batch_, sourceL_ = align_vectors.size() aeq(targetL, targetL_) aeq(batch, batch_) aeq(sourceL, sourceL_) return attn_h, align_vectors
def forward(self, input, memory_bank, memory_lengths=None, coverage=None):
    """
    Args:
      input (`FloatTensor`): query vectors `[batch x tgt_len x dim]`
      memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]`
      memory_lengths (`LongTensor`): the source context lengths `[batch]`
      coverage (`FloatTensor`): None (not supported yet)

    Returns:
      (`FloatTensor`, `FloatTensor`):
      * Computed vector `[tgt_len x batch x dim]`
      * Attention distributions for each query `[tgt_len x batch x src_len]`
    """
    # one step input
    if input.dim() == 2:
        one_step = True
        input = input.unsqueeze(1)
    else:
        one_step = False

    batch, sourceL, dim = memory_bank.size()
    batch_, targetL, dim_ = input.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    aeq(self.dim, dim)
    if coverage is not None:
        batch_, sourceL_ = coverage.size()
        aeq(batch, batch_)
        aeq(sourceL, sourceL_)

    if coverage is not None:
        cover = coverage.view(-1).unsqueeze(1)
        memory_bank += self.linear_cover(cover).view_as(memory_bank)
        memory_bank = self.tanh(memory_bank)

    # compute attention scores, as in Luong et al.
    # Local attention: generate the aligned position p_t
    if self.attn_model == "local-p":
        # predictive alignment model
        p_t = torch.zeros((batch, targetL, 1), device=input.device) + (sourceL - 1)
        p_t = p_t * self.sigmoid(self.v_predictive(self.tanh(
            self.linear_predictive(input.view(-1, dim))))).view(batch, targetL, 1)
    elif self.attn_model == "local-m":
        # monotonic alignment model
        p_t = torch.arange(targetL, device=input.device).repeat(batch, 1).view(batch, targetL, 1)

    # Create a mask that filters out all scores outside the window of size 2D
    indices_of_sources = torch.arange(sourceL, device=input.device).repeat(
        batch, targetL, 1)  # batch x tgt_len x src_len
    mask_local = (indices_of_sources >= p_t - self.D).int() & \
                 (indices_of_sources <= p_t + self.D).int()  # batch x tgt_len x src_len

    # Calculate alignment scores
    align = self.score(input, memory_bank, mask_local)

    if memory_lengths is not None:
        mask = sequence_mask(memory_lengths)
        mask = mask.unsqueeze(1)  # Make it broadcastable.
        align.data.masked_fill_(1 - mask, -float('inf'))

    # Softmax to normalize attention weights
    align_vectors = self.sm(align.view(batch * targetL, sourceL)).view(batch, targetL, sourceL)

    # Local attention
    if self.attn_model == "local-p":
        # predictive alignment model:
        # favor alignment points near p_t with a truncated Gaussian
        gaussian = torch.exp(-1.0 * ((indices_of_sources - p_t) ** 2) /
                             (2 * (self.D / 2.0) ** 2)) * mask_local.float()  # batch x tgt_len x src_len
        align_vectors = align_vectors * gaussian

    # each context vector c_t is the weighted average
    # over all the source hidden states
    c = torch.bmm(align_vectors, memory_bank)

    # concatenate
    concat_c = torch.cat([c, input], 2).view(batch * targetL, dim * 2)
    attn_h = self.linear_out(concat_c).view(batch, targetL, dim)
    if self.attn_score_func in ["general", "dot"]:
        attn_h = self.tanh(attn_h)

    if one_step:
        attn_h = attn_h.squeeze(1)
        align_vectors = align_vectors.squeeze(1)
        # Check output sizes
        batch_, dim_ = attn_h.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        batch_, sourceL_ = align_vectors.size()
        aeq(batch, batch_)
        aeq(sourceL, sourceL_)
    else:
        attn_h = attn_h.transpose(0, 1).contiguous()
        align_vectors = align_vectors.transpose(0, 1).contiguous()
        # Check output sizes
        targetL_, batch_, dim_ = attn_h.size()
        aeq(targetL, targetL_)
        aeq(batch, batch_)
        aeq(dim, dim_)
        targetL_, batch_, sourceL_ = align_vectors.size()
        aeq(targetL, targetL_)
        aeq(batch, batch_)
        aeq(sourceL, sourceL_)

    return attn_h, align_vectors
def _run_forward_pass(self, tgt, memory_bank, state, memory_lengths=None, q_scores_sample=None, q_scores=None): """ See StdRNNDecoder._run_forward_pass() for description of arguments and return values. """ # Additional args check. input_feed = state.input_feed.squeeze(0) input_feed_batch, _ = input_feed.size() tgt_len, tgt_batch, _ = tgt.size() aeq(tgt_batch, input_feed_batch) # END Additional args check. if self.dist_type == "dirichlet": p_a_scores = [[]] elif self.dist_type == "normal": p_a_scores = [[], []] else: p_a_scores = [[]] n_param = len(p_a_scores) # Initialize local and return variables. decoder_outputs = [] attns = {"std": []} if q_scores_sample is not None: attns["q"] = [] if q_scores is not None: attns["q_raw_mean"] = [] attns["q_raw_std"] = [] attns["p_raw_mean"] = [] attns["p_raw_std"] = [] if self._copy: attns["copy"] = [] if self._coverage: attns["coverage"] = [] emb = self.embeddings(tgt) assert emb.dim() == 3 # len x batch x embedding_dim tgt_len, batch_size = emb.size(0), emb.size(1) src_len = memory_bank.size(0) hidden = state.hidden #[item.fill_(0) for item in hidden] coverage = state.coverage.squeeze(0) \ if state.coverage is not None else None # Input feed concatenates hidden state with # input at every time step. if q_scores is not None: q_scores_mean = q_scores[0].view(batch_size, tgt_len, -1).transpose(0, 1) q_scores_std = q_scores[1].view(batch_size, tgt_len, -1).transpose(0, 1) for i, emb_t in enumerate(emb.split(1)): emb_t = emb_t.squeeze(0) decoder_input = torch.cat([emb_t, input_feed], 1) rnn_output, hidden = self.rnn(decoder_input, hidden) if q_scores_sample is not None: q_sample = q_scores_sample[i] else: q_sample = None decoder_output, p_attn, raw_scores = self.attn( rnn_output, memory_bank.transpose(0, 1), memory_lengths=memory_lengths, q_scores_sample=q_sample) if q_sample is not None: attns["q"] += [q_sample] attns["q_raw_mean"] += [q_scores_mean[i]] attns["q_raw_std"] += [q_scores_std[i]] attns["p_raw_mean"] += [raw_scores[0].view(-1, src_len)] attns["p_raw_std"] += [raw_scores[1].view(-1, src_len)] # raw_scores: [batch x tgt_len x src_len] #assert raw_scores.size() == (batch_size, 1, src_len) assert len(raw_scores) == n_param for i in range(n_param): p_a_scores[i] += [raw_scores[i]] if self.context_gate is not None: # TODO: context gate should be employed # instead of second RNN transform. decoder_output = self.context_gate(decoder_input, rnn_output, decoder_output) decoder_output = self.dropout(decoder_output) input_feed = decoder_output decoder_outputs += [decoder_output] attns["std"] += [p_attn] # Update the coverage attention. if self._coverage: coverage = coverage + p_attn \ if coverage is not None else p_attn attns["coverage"] += [coverage] # Run the forward pass of the copy attention layer. if self._copy and not self._reuse_copy_attn: _, copy_attn = self.copy_attn(decoder_output, memory_bank.transpose(0, 1)) attns["copy"] += [copy_attn] elif self._copy: attns["copy"] = attns["std"] for i in range(n_param): p_a_scores[i] = torch.cat(p_a_scores[i], dim=1) if memory_lengths is not None: mask = sequence_mask(memory_lengths) mask = mask.unsqueeze(1) if self.dist_type == 'normal': p_a_scores[0].data.masked_fill_(1 - mask, -999) p_a_scores[1].data.masked_fill_(1 - mask, 0.001) else: for i in range(n_param): p_a_scores[i].data.masked_fill_(1 - mask, 1e-2) return hidden, decoder_outputs, attns, p_a_scores