Example #1
 def forward(self, inputs, lens):
     # drop the last token: no prediction is made after the final symbol
     inputs, lens = inputs[:, :-1], lens - 1
     emb = self.dropout_layer(self.word_embed(inputs)) # bsize, seq_length, emb_size
     outputs, _ = rnn_wrapper(self.encoder, emb, lens, self.cell)
     decoded = self.decoder(self.dropout_layer(outputs))
     scores = F.log_softmax(decoded, dim=-1) # per-token log-probabilities over the vocabulary
     return scores
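
Every example on this page delegates the actual RNN call to an rnn_wrapper helper that is not shown. The snippet below is only a rough sketch of what such a wrapper typically does (an assumption, not the project's actual implementation): pack the padded batch, run the encoder, unpack the outputs, and restore the original batch order. The cell argument is assumed to exist only to handle the (h_n, c_n) tuple that LSTMs return in place of a single hidden state.

 import torch
 from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

 def rnn_wrapper(encoder, inputs, lens, cell='lstm'):
     """Sketch: run a padded batch of variable-length sequences through an RNN."""
     # sort by length (descending) so the batch can be packed
     sorted_lens, sort_idx = torch.sort(lens, descending=True)
     packed = pack_padded_sequence(inputs[sort_idx], sorted_lens.cpu(), batch_first=True)
     packed_out, hiddens = encoder(packed)
     out, _ = pad_packed_sequence(packed_out, batch_first=True)  # bsize x max_len x dim
     # undo the sort so outputs line up with the original batch order
     _, unsort_idx = torch.sort(sort_idx)
     out = out[unsort_idx]
     if cell.lower() == 'lstm':  # LSTMs return a (h_n, c_n) tuple
         hiddens = (hiddens[0][:, unsort_idx], hiddens[1][:, unsort_idx])
     else:
         hiddens = hiddens[:, unsort_idx]
     return out, hiddens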
Example #2
 def forward(self, x, lens):
     """
         Pass the x and lens through each RNN layer.
     """
     out, hidden_states = rnn_wrapper(
         self.rnn_encoder, x, lens, cell=self.cell)  # bsize x srclen x dim
     return out, hidden_states
Example #3
 def forward(self, x, src_lens):
     """
         Pass the input (and src_lens) through each RNN layer in turn.
     """
     out, hidden_states = rnn_wrapper(self.rnn_encoder, x, src_lens,
                                      self.cell)
     return out, hidden_states
Example #4
 def forward(self, slot_emb, slot_lens, lens):
     """
     @args:
         slot_emb: [total_slot_num, max_slot_word_len, emb_size]
         slot_lens: slot_num for each training sample, [bsize]
         lens: seq_len for each ${slot}=value sequence, [total_slot_num]
     @return:
         slot_feats: bsize, max_slot_num, hidden_size * 2
     """
     if slot_emb is None or torch.sum(slot_lens).item() == 0:
         # set seq_len dim to 1 due to decoder attention computation
         return torch.zeros(slot_lens.size(0), 1, self.hidden_size * 2, dtype=torch.float).to(slot_lens.device)
     else:
         slot_feats = self.slot_encoder(slot_emb, slot_lens, lens)
         slot_outputs, _ = rnn_wrapper(self.rnn_encoder, self.dropout_layer(slot_feats), slot_lens, self.cell)
         return slot_outputs
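
To make the shapes in the docstring concrete, here is a hypothetical input layout; the sizes below are invented purely for illustration and do not come from the source:

 # bsize = 2 samples with 3 and 1 slots respectively, so total_slot_num = 4;
 # each ${slot}=value sequence is padded to max_slot_word_len = 5, emb_size = 100
 slot_emb = torch.randn(4, 5, 100)   # [total_slot_num, max_slot_word_len, emb_size]
 slot_lens = torch.tensor([3, 1])    # slot_num for each training sample, [bsize]
 lens = torch.tensor([5, 4, 2, 3])   # seq_len of each ${slot}=value sequence, [total_slot_num]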
Example #5
 def sent_logprobability(self, input_feats, lens):
     '''
         Given sentences, calculate their length-normalized log-probabilities.
         Each sequence must contain the <s> and </s> symbols.
         lens: length tensor
     '''
     lens = lens - 1
     input_feats, output_feats = input_feats[:, :-1], input_feats[:, 1:]
     emb = self.dropout_layer(
         self.encoder(input_feats))  # bsize, seq_len, emb_size
     output, _ = rnn_wrapper(self.rnn, emb, lens, self.cell)
     decoded = self.decoder(self.affine(self.dropout_layer(output)))
     scores = F.log_softmax(decoded, dim=-1)
     # gather the log-probability of each gold next token: bsize, seq_len
     log_prob = torch.gather(scores, 2,
                             output_feats.unsqueeze(-1)).contiguous().view(
                                 output.size(0), output.size(1))
     sent_log_prob = torch.sum(log_prob * lens2mask(lens).float(), dim=-1)
     return sent_log_prob / lens.float()
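
The lens2mask helper used here (and in the later examples) is also not shown. A minimal sketch, assuming it simply turns a length vector into a boolean padding mask:

 import torch

 def lens2mask(lens):
     """Sketch: boolean mask of shape [bsize, max_len], True at valid positions."""
     max_len = lens.max().item()
     positions = torch.arange(max_len, device=lens.device).unsqueeze(0)  # 1 x max_len
     return positions < lens.unsqueeze(1)                                # bsize x max_len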
Example #6
 def sent_logprob(self, inputs, lens, length_norm=False):
     ''' Given sentences, calculate the log-probability of each sentence
     @args:
         inputs(torch.LongTensor): each sequence must contain the <s> and </s> symbols
         lens(torch.LongTensor): length tensor
     @return:
         sent_logprob(torch.FloatTensor): logprob for each sentence in the batch
     '''
     lens = lens - 1
     inputs, outputs = inputs[:, :-1], inputs[:, 1:]
     emb = self.dropout_layer(self.word_embed(inputs)) # bsize, seq_len, emb_size
     output, _ = rnn_wrapper(self.encoder, emb, lens, self.cell)
     decoded = self.decoder(self.dropout_layer(output))
     scores = F.log_softmax(decoded, dim=-1)
     logprob = torch.gather(scores, 2, outputs.unsqueeze(-1)).contiguous().view(output.size(0), output.size(1))
     sent_logprob = torch.sum(logprob * lens2mask(lens).float(), dim=-1)
     if length_norm:
         return sent_logprob / lens.float()
     else:
         return sent_logprob
Example #7
 def forward(self, slot_emb, slot_lens, lens):
     """
     @args:
         slot_emb: [total_slot_num, max_slot_word_len, emb_size]
         slot_lens: slot_num for each training sample, [bsize]
         lens: seq_len for each ${slot}=value sequence, [total_slot_num]
     @return:
         slot_feats: bsize, max_slot_num, hidden_size * 2
     """
     if slot_emb is None or torch.sum(slot_lens).item() == 0:
         # set seq_len dim to 1 due to decoder attention computation
         return torch.zeros(slot_lens.size(0), 1, self.hidden_size * 2, dtype=torch.float).to(slot_lens.device)
     # encode each ${slot}=value word sequence with the slot-level RNN
     slot_outputs, _ = rnn_wrapper(self.slot_encoder, slot_emb, lens, self.cell)
     # aggregate the word-level outputs into a single vector per slot
     slot_outputs = self.slot_aggregation(slot_outputs, lens2mask(lens))
     chunks = slot_outputs.split(slot_lens.tolist(), dim=0) # list of [slot_num x hidden_size]
     max_slot_num = torch.max(slot_lens).item()
     padded_chunks = [torch.cat([each, each.new_zeros(max_slot_num - each.size(0), each.size(1))], dim=0) for each in chunks]
     # bsize x max_slot_num x hidden_size
     slot_feats = torch.stack(padded_chunks, dim=0)
     return slot_feats
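
As a side note, the manual pad-and-stack over chunks above could also be written with torch.nn.utils.rnn.pad_sequence, which zero-pads a list of tensors along their first dimension and stacks them into a batch; this equivalent sketch is not the author's code:

 from torch.nn.utils.rnn import pad_sequence

 # chunks is a list of [slot_num x hidden_size] tensors with varying slot_num
 slot_feats = pad_sequence(chunks, batch_first=True)  # bsize x max_slot_num x hidden_size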