def _run_forward_pass(self, tgt, memory_bank, memory_lengths=None):
    """Run the RNN over the whole target sequence, then attend once.

    Private helper; every subclass must override it.

    Args:
        tgt (LongTensor): input token indices ``[tgt_len x batch x nfeats]``.
        memory_bank (FloatTensor): encoder outputs
            ``[src_len x batch x hidden_size]``.
        memory_lengths (LongTensor, optional): source lengths, forwarded
            unchanged to ``self.attn``.

    Returns:
        tuple: ``(dec_state, dec_outs, attns)`` where ``dec_state`` is the
        final RNN hidden state, ``dec_outs`` is the attention output (passed
        through the context gate when one is configured, then dropout), and
        ``attns`` is a dict with the ``"std"`` attention distributions.
    """
    # Copy and coverage attention are not implemented for this decoder.
    assert not self._copy  # TODO, no support yet.
    assert not self._coverage  # TODO, no support yet.

    emb = self.embeddings(tgt)

    # For nn.GRU only the first element of the stored hidden state is
    # passed; any other RNN receives the stored state as-is.
    init_hidden = self.state["hidden"]
    if isinstance(self.rnn, nn.GRU):
        init_hidden = init_hidden[0]
    rnn_output, dec_state = self.rnn(emb, init_hidden)

    # The RNN must preserve sequence length and batch size.
    tgt_len, tgt_batch, _ = tgt.size()
    out_len, out_batch, _ = rnn_output.size()
    aeq(tgt_len, out_len)
    aeq(tgt_batch, out_batch)

    # Attention consumes batch-first queries and memory.
    dec_outs, p_attn = self.attn(
        rnn_output.transpose(0, 1).contiguous(),
        memory_bank.transpose(0, 1),
        memory_lengths=memory_lengths,
    )
    attns = {"std": p_attn}

    if self.context_gate is not None:
        # The gate is fed flattened (len * batch, features) matrices and its
        # result is reshaped back to (tgt_len, batch, hidden_size).
        gated = self.context_gate(
            emb.view(-1, emb.size(2)),
            rnn_output.view(-1, rnn_output.size(2)),
            dec_outs.view(-1, dec_outs.size(2)),
        )
        dec_outs = gated.view(tgt_len, tgt_batch, self.hidden_size)

    dec_outs = self.dropout(dec_outs)
    return dec_state, dec_outs, attns
def forward(self, base_target_emb, input_from_dec, encoder_out_top,
            encoder_out_combine):
    """Multi-step attention for the convolutional decoder.

    Args:
        base_target_emb: target embedding tensor, 4-D; its batch (dim 0)
            and height (dim 2) must match ``input_from_dec``.
        input_from_dec: output of the decoder convolution, 4-D.
        encoder_out_top: key matrix used to compute the attention weights
            (the top output of the encoder convolution), 3-D.
        encoder_out_combine: value matrix for the attention-weighted sum
            (the combination of the base embedding and the top encoder
            output), 3-D; dims 0 and 2 must match ``encoder_out_top``.

    Returns:
        tuple: ``(context_output, attn)`` - the attended context with an
        extra trailing axis re-added and dims 1/2 swapped, and the softmax
        attention weights (normalized over dim 2).
    """
    # Decoder-side shape checks: batch and height must agree.
    batch, _, height, _ = base_target_emb.size()
    dec_batch, _, dec_height, _ = input_from_dec.size()
    aeq(batch, dec_batch)
    aeq(height, dec_height)

    # Encoder-side shape checks: keys and values must agree on dims 0 and 2.
    key_batch, _, key_len = encoder_out_top.size()
    val_batch, _, val_len = encoder_out_combine.size()
    aeq(key_batch, val_batch)
    aeq(key_len, val_len)

    # Query = (target embedding + projected decoder output), rescaled.
    query = (base_target_emb + seq_linear(self.linear_in, input_from_dec)) \
        * SCALE_WEIGHT
    # Drop the trailing axis (presumably width 1 - TODO confirm) and move
    # channels last so bmm can contract them against the keys.
    query = torch.transpose(torch.squeeze(query, 3), 1, 2)

    scores = torch.bmm(query, encoder_out_top)
    if self.mask is not None:
        scores.data.masked_fill_(self.mask, -float('inf'))
    attn = F.softmax(scores, dim=2)

    context = torch.bmm(attn, torch.transpose(encoder_out_combine, 1, 2))
    context_output = torch.transpose(torch.unsqueeze(context, 3), 1, 2)
    return context_output, attn
def score(self, h_t, h_s):
    """Compute raw (unnormalized) attention scores.

    Args:
        h_t (FloatTensor): sequence of queries ``[batch x tgt_len x dim]``.
        h_s (FloatTensor): sequence of sources ``[batch x src_len x dim]``.

    Returns:
        FloatTensor: scores ``[batch x tgt_len x src_len]`` for the
        "general", "dot" and "mlp" attention types, and
        ``[batch x tgt_len x src_len x dim]`` (one score per feature) for
        "fine".

    Raises:
        ValueError: if ``self.attn_type`` is not one of "general", "dot",
            "mlp" or "fine" (previously this silently returned None).
    """
    # Check input sizes
    src_batch, src_len, src_dim = h_s.size()
    tgt_batch, tgt_len, tgt_dim = h_t.size()
    aeq(src_batch, tgt_batch)
    aeq(src_dim, tgt_dim)
    aeq(self.dim, src_dim)

    if self.attn_type in ["general", "dot"]:
        if self.attn_hidden > 0:
            # Optional projection shared by queries and sources.
            h_t = self.transform_in(h_t)
            h_s = self.transform_in(h_s)
        if self.attn_type == "general":
            h_t = self.linear_in(h_t)
        # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
        return torch.bmm(h_t, h_s.transpose(1, 2))

    if self.attn_type not in ("mlp", "fine"):
        raise ValueError("Unsupported attn_type: {!r}".format(self.attn_type))

    dim = self.dim
    d = self.attn_hidden if self.attn_hidden > 0 else dim

    # contiguous() on both inputs so view() also accepts transposed/sliced
    # tensors (previously only h_s was made contiguous).
    wq = self.linear_query(h_t.contiguous().view(-1, dim))
    wq = wq.view(tgt_batch, tgt_len, 1, d)
    wq = wq.expand(tgt_batch, tgt_len, src_len, d)

    uh = self.linear_context(h_s.contiguous().view(-1, dim))
    uh = uh.view(src_batch, 1, src_len, d)
    uh = uh.expand(src_batch, tgt_len, src_len, d)

    # (batch, t_len, s_len, d)
    wquh = self.tanh(wq + uh)
    scores = self.v(wquh.view(-1, d))
    if self.attn_type == "mlp":
        return scores.view(tgt_batch, tgt_len, src_len)
    # "fine": self.v yields one score per feature.
    return scores.view(tgt_batch, tgt_len, src_len, dim)
def forward(self, input, memory_bank, memory_lengths=None, coverage=None):
    """Self-attention over ``memory_bank`` in which no position attends to itself.

    Args:
        input (FloatTensor): query vectors ``[batch x tgt_len x dim]``, or
            ``[batch x dim]`` for a single step.
        memory_bank (FloatTensor): source vectors ``[batch x src_len x dim]``.
            The caller's tensor is never modified in place.
        memory_lengths (LongTensor): source lengths ``[batch]``; required.
        coverage (FloatTensor, optional): ``[batch x src_len]`` coverage
            vector projected and added to the memory bank before scoring.

    Returns:
        (FloatTensor, FloatTensor):
        * computed vector ``[tgt_len x batch x dim]``
          (``[batch x dim]`` for a single step)
        * attention distributions ``[tgt_len x batch x src_len]``
          (``[batch x src_len]`` for a single step); for ``attn_type ==
          "fine"`` there is an additional trailing ``dim`` axis.

    Raises:
        ValueError: if ``memory_lengths`` is None.
    """
    if memory_lengths is None:
        raise ValueError("memory_lengths is required for self-attention masking")

    # A 2-D input is a single step: add a length-1 time axis.
    one_step = input.dim() == 2
    if one_step:
        input = input.unsqueeze(1)

    batch, sourceL, dim = memory_bank.size()
    batch_, targetL, dim_ = input.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    aeq(self.dim, dim)

    if coverage is not None:
        batch_, sourceL_ = coverage.size()
        aeq(batch, batch_)
        aeq(sourceL, sourceL_)
        cover = coverage.view(-1).unsqueeze(1)
        # Out-of-place add: `+=` would overwrite the caller's tensor.
        memory_bank = memory_bank + self.linear_cover(cover).view_as(memory_bank)
        memory_bank = self.tanh(memory_bank)

    # compute attention scores, as in Luong et al.
    align = self.score(input, memory_bank)

    # Valid-position mask, presumably [batch x sourceL] (TODO confirm
    # sequence_mask semantics); max_len pins it to the memory length.
    mask = sequence_mask(memory_lengths, max_len=sourceL)
    # Repeat per query row -> [batch x sourceL x sourceL]. NOTE(review):
    # this assumes tgt_len == src_len (true self-attention); otherwise the
    # masked_fill_ below cannot broadcast.
    mask = mask.unsqueeze(1).repeat(1, sourceL, 1)
    # Forbid each position from attending to itself (zero the diagonal).
    self_index = list(range(sourceL))
    mask[:, self_index, self_index] = 0
    if self.attn_type == "fine":
        mask = mask.unsqueeze(3)  # broadcast over the per-feature axis
    # mask.eq(0) inverts the mask for both bool and uint8 masks; `1 - mask`
    # raises for bool tensors in recent PyTorch.
    # NOTE(review): a row left with no valid position (e.g. a length-1
    # source, whose only position is "self") becomes all -inf and the
    # softmax yields NaN there - confirm callers never pass such input.
    align.masked_fill_(mask.eq(0), -float('inf'))

    # Softmax to normalize attention weights
    align_vectors = self.sm(align)

    # Each context vector is the attention-weighted sum of memory states.
    if self.attn_type == "fine":
        # Per-feature weights: [batch, 1, src, dim] * [batch, tgt, src, dim]
        c = memory_bank.unsqueeze(1).mul(align_vectors).sum(dim=2, keepdim=False)
    else:
        c = torch.bmm(align_vectors, memory_bank)

    concat_c = torch.cat([c, input], 2)
    attn_h = self.linear_out(concat_c)
    if self.attn_type in ["general", "dot"]:
        # Content-selection gate: sigmoid(W[c; q]) rescales the query.
        attn_h = torch.sigmoid(attn_h).mul(input)

    if one_step:
        attn_h = attn_h.squeeze(1)
        align_vectors = align_vectors.squeeze(1)
        # Check output sizes ([:2] so "fine" 3-D weights also pass).
        batch_, dim_ = attn_h.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        batch_, sourceL_ = align_vectors.size()[:2]
        aeq(batch, batch_)
        aeq(sourceL, sourceL_)
    else:
        attn_h = attn_h.transpose(0, 1).contiguous()
        align_vectors = align_vectors.transpose(0, 1).contiguous()
        # Check output sizes ([:3] so "fine" 4-D weights also pass).
        targetL_, batch_, dim_ = attn_h.size()
        aeq(targetL, targetL_)
        aeq(batch, batch_)
        aeq(dim, dim_)
        targetL_, batch_, sourceL_ = align_vectors.size()[:3]
        aeq(targetL, targetL_)
        aeq(batch, batch_)
        aeq(sourceL, sourceL_)

    return attn_h, align_vectors
def _check_args(self, src, lengths=None, hidden=None):
    """Check that ``lengths`` (when given) agrees with the batch size of ``src``.

    ``src`` must unpack into exactly three sizes, with the batch on dim 1.
    ``hidden`` is accepted for interface compatibility but not inspected.
    """
    (_, n_batch, _) = src.size()
    if lengths is None:
        return
    (n_batch_,) = lengths.size()
    aeq(n_batch, n_batch_)
def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
    """Global attention of query vectors over a memory bank.

    Args:
        source (FloatTensor): query vectors ``[batch x tgt_len x dim]``, or
            ``[batch x dim]`` for a single step.
        memory_bank (FloatTensor): source vectors ``[batch x src_len x dim]``.
            The caller's tensor is never modified in place.
        memory_lengths (LongTensor, optional): source lengths ``[batch]``;
            when given, positions where ``sequence_mask`` is false receive
            ``-inf`` scores before normalization.
        coverage (FloatTensor, optional): ``[batch x src_len]`` coverage
            vector projected and added to the memory bank before scoring.

    Returns:
        (FloatTensor, FloatTensor):
        * computed vector ``[tgt_len x batch x dim]``
          (``[batch x dim]`` for a single step)
        * attention distributions ``[tgt_len x batch x src_len]``
          (``[batch x src_len]`` for a single step)
    """
    # A 2-D source is a single step: add a length-1 time axis.
    one_step = source.dim() == 2
    if one_step:
        source = source.unsqueeze(1)

    batch, source_l, dim = memory_bank.size()
    batch_, target_l, dim_ = source.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    aeq(self.dim, dim)

    if coverage is not None:
        batch_, source_l_ = coverage.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)
        cover = coverage.view(-1).unsqueeze(1)
        # Out-of-place add: `+=` would overwrite the caller's tensor, which
        # decoders pass as a transposed view of the encoder output.
        memory_bank = memory_bank + self.linear_cover(cover).view_as(memory_bank)
        memory_bank = torch.tanh(memory_bank)

    # compute attention scores, as in Luong et al.
    align = self.score(source, memory_bank)

    if memory_lengths is not None:
        mask = sequence_mask(memory_lengths, max_len=align.size(-1))
        mask = mask.unsqueeze(1)  # Make it broadcastable over target steps.
        # mask.eq(0) inverts the mask for both bool and uint8 masks;
        # `1 - mask` raises for bool tensors in recent PyTorch.
        align.masked_fill_(mask.eq(0), -float('inf'))

    # Softmax or sparsemax to normalize attention weights
    flat_align = align.view(batch * target_l, source_l)
    if self.attn_func == "softmax":
        align_vectors = F.softmax(flat_align, -1)
    else:
        align_vectors = sparsemax(flat_align, -1)
    align_vectors = align_vectors.view(batch, target_l, source_l)

    # each context vector c_t is the weighted average
    # over all the source hidden states
    c = torch.bmm(align_vectors, memory_bank)

    # concatenate
    concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2)
    attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
    if self.attn_type in ["general", "dot"]:
        attn_h = torch.tanh(attn_h)

    if one_step:
        attn_h = attn_h.squeeze(1)
        align_vectors = align_vectors.squeeze(1)
        # Check output sizes
        batch_, dim_ = attn_h.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        batch_, source_l_ = align_vectors.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)
    else:
        attn_h = attn_h.transpose(0, 1).contiguous()
        align_vectors = align_vectors.transpose(0, 1).contiguous()
        # Check output sizes
        target_l_, batch_, dim_ = attn_h.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(dim, dim_)
        target_l_, batch_, source_l_ = align_vectors.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    return attn_h, align_vectors
def _run_forward_pass(self, tgt, memory_bank1, memory_bank2,
                      memory_lengths1=None, memory_lengths2=None):
    """Input-feed decoding that attends to two memory banks at every step.

    See StdRNNDecoder._run_forward_pass() for the general contract.

    Args:
        tgt (LongTensor): sequences of padded tokens
            ``[tgt_len x batch x nfeats]``.
        memory_bank1 (FloatTensor): vectors from encoder1
            ``[src_len x batch x hidden]``.
        memory_bank2 (FloatTensor): vectors from encoder2
            ``[tmpl_len x batch x hidden]``.
        memory_lengths1 (LongTensor): source lengths for memory_bank1 ``[batch]``.
        memory_lengths2 (LongTensor): source lengths for memory_bank2 ``[batch]``.

    Returns:
        tuple: ``(dec_state, dec_outs, attns)``. ``attns`` holds ``"std"``
        (attention over memory_bank1), ``"std2"`` (over memory_bank2) and,
        when enabled, ``"copy"`` / ``"coverage"`` (memory_bank1 only).

    Raises:
        ValueError: if ``self.context_gate`` is None. The context gate is
            the only place the two attention outputs are merged; without it
            the original code failed with UnboundLocalError.
    """
    if self.context_gate is None:
        raise ValueError(
            "A context gate is required to merge the two attention outputs")

    # Additional args check.
    input_feed = self.state["input_feed"].squeeze(0)
    input_feed_batch, _ = input_feed.size()
    _, tgt_batch, _ = tgt.size()
    aeq(tgt_batch, input_feed_batch)

    # Initialize local and return variables.
    dec_outs = []
    attns = {"std": [], "std2": []}  # attention over memory_bank1 / 2
    if self._copy:
        attns["copy"] = []  # copy attention over memory_bank1
    if self._coverage:  # TODO: necessary for memory_bank2?
        attns["coverage"] = []  # coverage over memory_bank1

    emb = self.embeddings(tgt)
    assert emb.dim() == 3  # len x batch x embedding_dim

    dec_state = self.state["hidden"]
    coverage = self.state["coverage"].squeeze(0) \
        if self.state["coverage"] is not None else None

    # Batch-first views of both memory banks; loop-invariant, so built once.
    bank1 = memory_bank1.transpose(0, 1)
    bank2 = memory_bank2.transpose(0, 1)

    # Input feeding: each step's RNN input is the token embedding
    # concatenated with the previous step's decoder output.
    for emb_t in emb.split(1):
        emb_t = emb_t.squeeze(0)
        decoder_input = torch.cat([emb_t, input_feed], 1)
        rnn_output, dec_state = self.rnn(decoder_input, dec_state)
        decoder_output1, p_attn1 = self.attn(
            rnn_output, bank1, memory_lengths=memory_lengths1)
        decoder_output2, p_attn2 = self.attn(
            rnn_output, bank2, memory_lengths=memory_lengths2)
        # TODO: context gate should be employed
        # instead of second RNN transform.
        decoder_output = self.context_gate(
            decoder_input, rnn_output, decoder_output1, decoder_output2)
        decoder_output = self.dropout(decoder_output)
        input_feed = decoder_output

        dec_outs += [decoder_output]
        attns["std"] += [p_attn1]
        attns["std2"] += [p_attn2]

        # Update the coverage attention (memory_bank1 only).
        if self._coverage:
            coverage = coverage + p_attn1 if coverage is not None else p_attn1
            attns["coverage"] += [coverage]

        # Run the forward pass of the copy attention layer.
        if self._copy and not self._reuse_copy_attn:
            _, copy_attn1 = self.copy_attn(decoder_output, bank1)
            attns["copy"] += [copy_attn1]
        elif self._copy:
            attns["copy"] = attns["std"]

    # Return result.
    return dec_state, dec_outs, attns
def _run_forward_pass(self, tgt, memory_bank, memory_lengths=None):
    """Input-feed RNN decoding, one target step at a time.

    See StdRNNDecoder._run_forward_pass() for description of arguments and
    return values. Each step's RNN input is the token embedding concatenated
    with the previous step's decoder output ("input feeding").
    """
    # The stored input feed must have the same batch size as the target.
    input_feed = self.state["input_feed"].squeeze(0)
    feed_batch, _ = input_feed.size()
    _, tgt_batch, _ = tgt.size()
    aeq(tgt_batch, feed_batch)

    dec_outs = []
    attns = {"std": []}
    if self._copy:
        attns["copy"] = []
    if self._coverage:
        attns["coverage"] = []

    emb = self.embeddings(tgt)
    assert emb.dim() == 3  # len x batch x embedding_dim

    dec_state = self.state["hidden"]
    coverage = None
    if self.state["coverage"] is not None:
        coverage = self.state["coverage"].squeeze(0)

    for step_emb in emb.split(1):
        step_input = torch.cat([step_emb.squeeze(0), input_feed], 1)
        rnn_output, dec_state = self.rnn(step_input, dec_state)
        step_out, p_attn = self.attn(
            rnn_output,
            memory_bank.transpose(0, 1),
            memory_lengths=memory_lengths,
        )
        if self.context_gate is not None:
            # TODO: the context gate should replace the second RNN transform.
            step_out = self.context_gate(step_input, rnn_output, step_out)
        step_out = self.dropout(step_out)
        input_feed = step_out

        dec_outs.append(step_out)
        attns["std"].append(p_attn)

        # Coverage is the running sum of attention distributions.
        if self._coverage:
            coverage = p_attn if coverage is None else coverage + p_attn
            attns["coverage"].append(coverage)

        if self._copy:
            if self._reuse_copy_attn:
                # Copy attention aliases the standard attention list.
                attns["copy"] = attns["std"]
            else:
                _, copy_attn = self.copy_attn(
                    step_out, memory_bank.transpose(0, 1))
                attns["copy"].append(copy_attn)

    return dec_state, dec_outs, attns