import torch
import torch.nn.functional as F


def forward(self, inputs, lens):
    # drop the final token (</s>) so each remaining position predicts the next token
    inputs, lens = inputs[:, :-1], lens - 1
    emb = self.dropout_layer(self.word_embed(inputs))  # bsize, seq_length, emb_size
    outputs, _ = rnn_wrapper(self.encoder, emb, lens, self.cell)
    decoded = self.decoder(self.dropout_layer(outputs))
    scores = F.log_softmax(decoded, dim=-1)
    return scores
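# The forward passes in this section all call a helper `rnn_wrapper` that is not shown here.
# Below is a minimal sketch, assuming it simply packs the padded batch, runs the given RNN,
# and re-pads the outputs; the actual helper may differ (e.g. in how it handles zero-length
# sequences or reshapes bidirectional hidden states). `cell` is accepted only to match the
# call sites; packing works the same way for LSTM/GRU/RNN modules.
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


def rnn_wrapper(encoder, inputs, lens, cell='lstm'):
    """ Run `encoder` (an nn.LSTM/nn.GRU/nn.RNN with batch_first=True) over a padded batch.
    @args:
        inputs: bsize x max_len x input_size, zero-padded
        lens: bsize, true sequence lengths
    @return:
        out: bsize x max_len x hidden_size (* 2 if bidirectional)
        hidden_states: final hidden state(s), a tuple for LSTM """
    packed = pack_padded_sequence(inputs, lens.cpu(), batch_first=True, enforce_sorted=False)
    packed_out, hidden_states = encoder(packed)
    out, _ = pad_packed_sequence(packed_out, batch_first=True, total_length=inputs.size(1))
    return out, hidden_states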
def forward(self, x, lens):
    """ Pass x and lens through each RNN layer. """
    out, hidden_states = rnn_wrapper(self.rnn_encoder, x, lens, cell=self.cell)  # bsize x srclen x dim
    return out, hidden_states
def forward(self, x, src_lens):
    """ Pass the input (and src_lens) through each RNN layer in turn. """
    out, hidden_states = rnn_wrapper(self.rnn_encoder, x, src_lens, self.cell)
    return out, hidden_states
def forward(self, slot_emb, slot_lens, lens):
    """
    @args:
        slot_emb: [total_slot_num, max_slot_word_len, emb_size]
        slot_lens: slot_num for each training sample, [bsize]
        lens: seq_len for each ${slot}=value sequence, [total_slot_num]
    @return:
        slot_feats: bsize, max_slot_num, hidden_size * 2
    """
    if slot_emb is None or torch.sum(slot_lens).item() == 0:
        # set seq_len dim to 1 due to decoder attention computation
        return torch.zeros(slot_lens.size(0), 1, self.hidden_size * 2,
                           dtype=torch.float).to(slot_lens.device)
    else:
        slot_feats = self.slot_encoder(slot_emb, slot_lens, lens)
        slot_outputs, _ = rnn_wrapper(self.rnn_encoder, self.dropout_layer(slot_feats), slot_lens, self.cell)
        return slot_outputs
def sent_logprobability(self, input_feats, lens):
    '''
    Given sentences, calculate their length-normalized log-probabilities.
    Each sequence must contain the <s> and </s> symbols.
    lens: length tensor
    '''
    lens = lens - 1
    input_feats, output_feats = input_feats[:, :-1], input_feats[:, 1:]
    emb = self.dropout_layer(self.encoder(input_feats))  # bsize, seq_len, emb_size
    output, _ = rnn_wrapper(self.rnn, emb, lens, self.cell)
    decoded = self.decoder(self.affine(self.dropout_layer(output)))
    scores = F.log_softmax(decoded, dim=-1)
    log_prob = torch.gather(scores, 2, output_feats.unsqueeze(-1)).contiguous().view(output.size(0), output.size(1))
    sent_log_prob = torch.sum(log_prob * lens2mask(lens).float(), dim=-1)
    return sent_log_prob / lens.float()
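# `lens2mask` is another helper assumed by the snippets above and below. A minimal sketch,
# under the assumption that it turns a length tensor into a boolean padding mask:
def lens2mask(lens):
    """ lens: bsize -> mask: bsize x max_len, True at valid (non-padded) positions. """
    max_len = lens.max().item()
    return torch.arange(max_len, device=lens.device).unsqueeze(0) < lens.unsqueeze(1)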
def sent_logprob(self, inputs, lens, length_norm=False):
    '''
    Given sentences, calculate the log-probability of each sentence.
    @args:
        inputs(torch.LongTensor): each sequence must contain the <s> and </s> symbols
        lens(torch.LongTensor): length tensor
    @return:
        sent_logprob(torch.FloatTensor): logprob for each sent in the batch
    '''
    lens = lens - 1
    inputs, outputs = inputs[:, :-1], inputs[:, 1:]
    emb = self.dropout_layer(self.word_embed(inputs))  # bsize, seq_len, emb_size
    output, _ = rnn_wrapper(self.encoder, emb, lens, self.cell)
    decoded = self.decoder(self.dropout_layer(output))
    scores = F.log_softmax(decoded, dim=-1)
    logprob = torch.gather(scores, 2, outputs.unsqueeze(-1)).contiguous().view(output.size(0), output.size(1))
    sent_logprob = torch.sum(logprob * lens2mask(lens).float(), dim=-1)
    if length_norm:
        return sent_logprob / lens.float()
    else:
        return sent_logprob
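# Hypothetical usage of the scorer above: `lm`, `batch_inputs` (token ids including <s>/</s>)
# and `batch_lens` are assumed names, not defined in this section; the call matches
# sent_logprob as written above.
logprob = lm.sent_logprob(batch_inputs, batch_lens, length_norm=True)
ppl = torch.exp(-logprob)  # per-sentence perplexity: exp of negative average token logprob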
def forward(self, slot_emb, slot_lens, lens):
    """
    @args:
        slot_emb: [total_slot_num, max_slot_word_len, emb_size]
        slot_lens: slot_num for each training sample, [bsize]
        lens: seq_len for each ${slot}=value sequence, [total_slot_num]
    @return:
        slot_feats: bsize, max_slot_num, hidden_size * 2
    """
    if slot_emb is None or torch.sum(slot_lens).item() == 0:
        # set seq_len dim to 1 due to decoder attention computation
        return torch.zeros(slot_lens.size(0), 1, self.hidden_size * 2,
                           dtype=torch.float).to(slot_lens.device)
    slot_outputs, _ = rnn_wrapper(self.slot_encoder, slot_emb, lens, self.cell)
    slot_outputs = self.slot_aggregation(slot_outputs, lens2mask(lens))
    chunks = slot_outputs.split(slot_lens.tolist(), dim=0)  # list of [slot_num x hidden_size]
    max_slot_num = torch.max(slot_lens).item()
    padded_chunks = [torch.cat([each, each.new_zeros(max_slot_num - each.size(0), each.size(1))], dim=0)
                     for each in chunks]
    # bsize x max_slot_num x hidden_size
    slot_feats = torch.stack(padded_chunks, dim=0)
    return slot_feats
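# Design note: the split/pad/stack sequence at the end of the slot encoder can be written
# more compactly with torch.nn.utils.rnn.pad_sequence, which zero-pads a list of
# variable-length tensors along a new batch dimension. A sketch of the equivalent lines,
# reusing the `slot_outputs` and `slot_lens` names from the function above:
from torch.nn.utils.rnn import pad_sequence

chunks = slot_outputs.split(slot_lens.tolist(), dim=0)  # list of [slot_num x hidden_size]
slot_feats = pad_sequence(chunks, batch_first=True)     # bsize x max_slot_num x hidden_size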