コード例 #1
0
 def forward(self, inputs, subword_select_mask, subword_lens, lens):
     '''
     Pool subword-level features into word-level features.
     @args:
         inputs: subword-level encoder output; last dim is self.hidden_size
         subword_select_mask: bool mask over inputs marking which subword
             positions to keep -- assumed to select exactly sum(subword_lens)
             positions (TODO confirm against caller)
         subword_lens: number of subwords per word, flattened over the batch
         lens: number of words for each batch instance
     @return:
         outputs: lens.size(0) x max(lens) x hidden_size word-level features
     '''
     # Flatten all selected subword vectors into one 1-D buffer.
     selected_inputs = inputs.masked_select(subword_select_mask.unsqueeze(-1))
     # Scatter the flat buffer into a padded (num_words, max_subword_len, hidden) layout.
     reshaped_inputs = selected_inputs.new_zeros(subword_lens.size(0), max(subword_lens.tolist()), self.hidden_size)
     subword_mask = lens2mask(subword_lens)
     reshaped_inputs = reshaped_inputs.masked_scatter_(subword_mask.unsqueeze(-1), selected_inputs)
     # aggregate subword into word feats
     reshaped_inputs = self.aggregation(reshaped_inputs, mask=subword_mask)
     # Re-pack aggregated word vectors into per-instance rows (bsize, max_word_len, hidden).
     outputs = inputs.new_zeros(lens.size(0), max(lens.tolist()), self.hidden_size)
     outputs.masked_scatter_(lens2mask(lens).unsqueeze(-1), reshaped_inputs)
     return outputs
コード例 #2
0
 def forward(self, *inputs):
     '''
     Intent detection head: pool encoder results into one sentence vector
     `sent`, then classify it with a log-softmax decoder.
     @args:
         hidden_states: RNN hidden states
         encoder_output: bsize x seqlen x hsize*2
         lens: length for each instance
     @return:
         log-probabilities over intent labels, bsize x intent_num
     '''
     if self.method == 'head+tail':
         # Sentence vector = concat of the last two hidden-state rows
         # (presumably the two directions of a BiRNN -- TODO confirm).
         ht = inputs[0][0] if self.h_c else inputs[0]  # LSTM passes (h, c); keep h
         last = ht.size(0)
         index = [last - 2, last - 1]
         index = torch.tensor(index, dtype=torch.long, device=ht.device)
         ht = torch.index_select(ht, 0, index)
         sent = ht.transpose(0,1).contiguous().view(-1, 2 * self.hidden_size)
     elif self.method == 'hiddenAttn':
         # Additive attention: final hidden state attends over all timesteps.
         hc, lstm_out, lens = inputs[0], inputs[1], inputs[2]
         ht = hc[0][-1] if self.h_c else hc[-1]  # last layer/direction hidden state
         hiddens = lstm_out.transpose(1, 2)  # bsize x hsize*2 x seqlen
         c1 = self.Wa(self.dropout_layer(ht))
         c2 = self.Ua(self.dropout_layer(hiddens))
         c3 = c1.unsqueeze(2).repeat(1, 1, lstm_out.size(1))  # broadcast query over timesteps
         c4 = torch.tanh(c3 + c2)
         e = self.Va(c4).squeeze(1)  # unnormalized attention energies
         # Padding positions get -inf so softmax gives them zero weight.
         e.masked_fill_(lens2mask(lens) == 0, -float('inf'))
         a = F.softmax(e, dim=1)
         sent = torch.bmm(hiddens, a.unsqueeze(2)).squeeze(2)  # attention-weighted sum
     else:
         raise ValueError('Unrecognized intent detection method!')
     return F.log_softmax(self.decoder(self.dropout_layer(sent)), dim=1)
コード例 #3
0
 def forward(self, inputs, lens, dec_inputs=None):
     '''
     Joint training pass: CRF slot-tagging loss plus intent scores.
     @args:
         inputs: token id tensor for the encoder
         lens: length for each instance
         dec_inputs: gold tag sequence; position 0 (start symbol) is skipped
     @return:
         (slot CRF negative log-likelihood, intent decoder output)
     '''
     enc_out, h_states = self.encoder(self.word_embed(inputs), lens)
     mask = lens2mask(lens)
     emissions = self.slot_decoder(enc_out)
     # CRF loss against the gold tags, dropping the leading start symbol.
     slot_loss = self.crf_layer.neg_log_likelihood_loss(emissions, mask, dec_inputs[:, 1:])
     intent_scores = self.intent_decoder(h_states, enc_out, lens)
     return slot_loss, intent_scores
コード例 #4
0
 def decode_batch(self,
                  src_inputs,
                  src_lens,
                  vocab,
                  copy_tokens=None,
                  beam_size=5,
                  n_best=1,
                  alpha=0.6,
                  length_pen='avg'):
     '''Encode the source batch, then decode greedily (beam_size == 1)
     or with beam search (beam_size > 1).'''
     enc_out, hidden_states = self.encoder(self.src_embed(src_inputs), src_lens)
     hidden_states = self.enc2dec(hidden_states)  # bridge encoder state to decoder
     src_mask = lens2mask(src_lens)
     if beam_size == 1:
         return self.decode_greed(hidden_states, enc_out, src_mask, vocab, copy_tokens)
     return self.decode_beam_search(
         hidden_states, enc_out, src_mask, vocab, copy_tokens,
         beam_size=beam_size, n_best=n_best, alpha=alpha, length_pen=length_pen)
コード例 #5
0
 def forward(self, slot_seq, slot_lens, intent_emb):
     """
     Build the decoder's initial state by attending over slot features
     with the intent embedding as query.
     @args:
         slot_seq: bsize x max_slot_num x slot_dim
         slot_lens: tensor of # slot=value pair, bsize
         intent_emb: bsize x intent_dim
     @return:
         num_layers x bsize x hidden_size (pair of such tensors for LSTM)
     """
     if torch.sum(slot_lens).item() == 0:
         # No slot=value pairs anywhere in the batch: zero-initialize the state.
         zeros = slot_lens.new_zeros(self.num_layers, slot_lens.size(0), self.hidden_size).float().contiguous()
         if self.cell == 'LSTM':
             return (zeros, zeros.new_zeros(zeros.size()).contiguous())
         return zeros
     dropped = self.dropout_layer(slot_seq)
     # Attention energies of each slot against the intent embedding.
     scores = torch.bmm(self.attn(dropped), intent_emb.unsqueeze(dim=-1)).squeeze(dim=-1)
     scores.masked_fill_(lens2mask(slot_lens) == 0, -1e8)  # mask padded slots
     attn_weights = F.softmax(scores, dim=-1)
     context = torch.bmm(attn_weights.unsqueeze(dim=1), dropped).squeeze(dim=1)
     # Same context initializes every layer; LSTM cell state starts at zero.
     init_h = torch.tanh(self.affine(context)).unsqueeze(dim=0).repeat(self.num_layers, 1, 1).contiguous()
     if self.cell == 'LSTM':
         return (init_h, init_h.new_zeros(init_h.size()).contiguous())
     return init_h
コード例 #6
0
 def forward(self, hiddens, decoder_state, slot_lens):
     '''
     Attention of the decoder state over slot memory.
     @args:
         hiddens : bsize x max_slot_num x enc_dim
         decoder_state : bsize x dec_dim
         slot_lens : slot number for each batch, bsize
     @return:
         context : bsize x 1 x enc_dim
         a : normalized coefficient, bsize x max_slot_num
     '''
     decoder_state = self.dropout_layer(decoder_state)
     if self.method == 'dot':
         # Bilinear/dot scoring: project memory, then batched dot product.
         keys = self.Wa(self.dropout_layer(hiddens)).transpose(-1, -2)
         scores = torch.bmm(decoder_state.unsqueeze(1), keys).squeeze(dim=1)
     else:
         # Additive (concat) scoring.
         expanded = decoder_state.unsqueeze(dim=1).repeat(1, hiddens.size(1), 1)
         feats = self.Wa(torch.cat([expanded, self.dropout_layer(hiddens)], dim=-1))
         scores = self.Va(torch.tanh(feats)).squeeze(dim=-1)
     masks = lens2mask(slot_lens)
     if masks.size(1) < scores.size(1):
         # Memory is padded past the longest real slot list; extend mask with zeros.
         pad = torch.zeros(masks.size(0),
                           scores.size(1) - masks.size(1)).type_as(masks).to(masks.device)
         masks = torch.cat([masks, pad], dim=1)
     scores.masked_fill_(masks == 0, -1e8)  # padded slots get ~zero weight
     attn = F.softmax(scores, dim=1)
     context = torch.bmm(attn.unsqueeze(1), hiddens)
     return context, attn
コード例 #7
0
 def forward(self, src_inputs, src_lens, tgt_inputs, copy_tokens=None):
     """
         Teacher-forced training pass: encode the source, bridge the
         hidden state to the decoder, decode the target prefix, and
         score the output vocabulary.
     """
     src_emb = self.src_embed(src_inputs)
     enc_out, hidden_states = self.encoder(src_emb, src_lens)
     hidden_states = self.enc2dec(hidden_states)
     src_mask = lens2mask(src_lens)
     tgt_emb = self.tgt_embed(tgt_inputs)
     dec_out, _ = self.decoder(tgt_emb, hidden_states, enc_out, src_mask, copy_tokens)
     return self.generator(dec_out)
コード例 #8
0
 def reconstruction_reward(self, logscores, references, lens):
     """
         Per-sample sum of gold-token log-scores, ignoring padding.
         logscores: bsize x max_out_len x vocab_size[ + MAX_OOV_NUM]
         references: bsize x max_out_len
         lens: len for each sample
     """
     valid = lens2mask(lens).float()
     # Pick the score of each reference token along the vocab dim.
     gold = references.unsqueeze(dim=-1)
     picked = torch.gather(logscores, dim=-1, index=gold).squeeze(dim=-1)
     return (valid * picked).sum(dim=1)
コード例 #9
0
 def sent_logprobability(self, input_feats, lens):
     '''
         Given sentences, calculate its length-normalized log-probability
         Sequence must contain <s> and </s> symbol
         lens: length tensor
     '''
     # One prediction per adjacent token pair: drop </s> from the inputs,
     # <s> from the targets, so the effective length is lens - 1.
     lens = lens - 1
     src, tgt = input_feats[:, :-1], input_feats[:, 1:]
     emb = self.dropout_layer(self.encoder(src))  # bsize, seq_len, emb_size
     hidden, _ = rnn_wrapper(self.rnn, emb, lens, self.cell)
     logits = self.decoder(self.affine(self.dropout_layer(hidden)))
     logdist = F.log_softmax(logits, dim=-1)
     # Gather the log-prob assigned to each gold next token.
     token_logprob = torch.gather(logdist, 2, tgt.unsqueeze(-1)).contiguous().view(
         hidden.size(0), hidden.size(1))
     total = torch.sum(token_logprob * lens2mask(lens).float(), dim=-1)
     return total / lens.float()
コード例 #10
0
 def sent_logprob(self, inputs, lens, length_norm=False):
     ''' Given sentences, calculate the log-probability for each sentence
     @args:
         inputs(torch.LongTensor): sequence must contain <s> and </s> symbol
         lens(torch.LongTensor): length tensor
     @return:
         sent_logprob(torch.FloatTensor): logprob for each sent in the batch
     '''
     lens = lens - 1  # one prediction per adjacent token pair
     src, tgt = inputs[:, :-1], inputs[:, 1:]  # shift: predict token t+1 from <= t
     emb = self.dropout_layer(self.word_embed(src))  # bsize, seq_len, emb_size
     hidden, _ = rnn_wrapper(self.encoder, emb, lens, self.cell)
     logdist = F.log_softmax(self.decoder(self.dropout_layer(hidden)), dim=-1)
     # Log-prob of each gold next token, masked over padding.
     token_logprob = torch.gather(logdist, 2, tgt.unsqueeze(-1)).contiguous().view(
         hidden.size(0), hidden.size(1))
     total = torch.sum(token_logprob * lens2mask(lens).float(), dim=-1)
     return total / lens.float() if length_norm else total
コード例 #11
0
 def forward(self, slot_emb, slot_lens, lens):
     """
     Encode each ${slot}=value sequence and regroup the results per batch
     instance, padding to the largest slot count.
     @args:
         slot_emb: [total_slot_num, max_slot_word_len, emb_size]
         slot_lens: slot_num for each training sample, [bsize]
         lens: seq_len for each ${slot}=value sequence, [total_slot_num]
     @return:
         slot_feats: bsize, max_slot_num, hidden_size * 2
     """
     if slot_emb is None or torch.sum(slot_lens).item() == 0:
         # set seq_len dim to 1 due to decoder attention computation
         return torch.zeros(slot_lens.size(0), 1, self.hidden_size * 2, dtype=torch.float).to(slot_lens.device)
     encoded, _ = rnn_wrapper(self.slot_encoder, slot_emb, lens, self.cell)
     encoded = self.slot_aggregation(encoded, lens2mask(lens))
     # Regroup the flat slot features by batch instance.
     per_sample = encoded.split(slot_lens.tolist(), dim=0)  # list of [slot_num x hidden_size]
     max_slot_num = torch.max(slot_lens).item()
     padded = []
     for feats in per_sample:
         filler = feats.new_zeros(max_slot_num - feats.size(0), feats.size(1))
         padded.append(torch.cat([feats, filler], dim=0))
     # bsize x max_slot_num x hidden_size
     return torch.stack(padded, dim=0)
コード例 #12
0
 def decode_greed(self, bios, intents, lens):
     """
     Greedy joint decoding: Viterbi path for slots, argmax for intents.
     @args:
         bios(torch.FloatTensor): bsize x seqlen x slot_num
         intents(torch.FloatTensor): bsize x intent_num
         lens(torch.LongTensor): bsize
     @return:
         dict:
             slot: tuple of
                 slot_score(torch.FloatTensor): bsize x 1
                 slot_idx(list): such as [ [[1, 2, 4, 9, 11]] , [[1, 2, 5, 8, 10]] , ... ]
             intent: tuple of
                 intent_score(torch.FloatTensor): bsize x 1
                 intent_idx(list): bsize x 1
     """
     best_slot_score, best_slot_path = self.crf_layer._viterbi_decode(bios, lens2mask(lens))
     best_int_score, best_int_idx = torch.max(intents, dim=1, keepdim=True)
     return {
         "intent": (best_int_score, best_int_idx.tolist()),
         "slot": (best_slot_score, best_slot_path.tolist())
     }
コード例 #13
0
    def decode_beam_search(self, bios, intents, lens, n_best=1, **kargs):
        """
        Joint n-best decoding: combine the CRF slot n-best list with the
        top intent candidates and keep the n_best best joint scores.
        @args:
            bios(torch.FloatTensor): bsize x seqlen x slot_num emissions
            intents(torch.FloatTensor): bsize x intent_num scores
            lens(torch.LongTensor): bsize
            n_best(int): number of predictions to return
        @return:
            dict:
            (key)intent: (value) tuple of (n_best most likely intent score, n_best most likely intent idx)
                intent_score(torch.FloatTensor): bsize x n_best
                intent_idx(list): bsize x n_best
            (key)slot: (value) tuple of (n_best most likely seq score, n_best most likely seq idx)
                slot_score(torch.FloatTensor): bsize x n_best
                slot_idx(list): n_best=2, such as [ [[1, 2, 4, 9], [1, 2, 3, 5]] , [[1, 2, 5], [1, 2, 3]] , ... ]
        """
        slot_scores, slot_idxs = self.crf_layer._viterbi_decode_nbest(
            bios, lens2mask(lens), n_best)
        # Keep fewer intent candidates when the label set is small.
        # BUG FIX: `int(n_best / 2)` collapsed to 0 for n_best == 1 with a
        # small intent set, so topk(0) produced empty tensors and the joint
        # topk below failed; clamp to at least one candidate. Behavior is
        # unchanged for every previously-working configuration.
        threshold = n_best if intents.size(-1) > 2 * n_best else max(1, n_best // 2)
        intent_scores, intent_idxs = intents.topk(threshold, dim=1)

        # Cross every slot path with every intent candidate; keep n_best
        # joint (sum) scores.
        comb_scores = slot_scores.unsqueeze(-1) + intent_scores.unsqueeze(1)
        flat_comb_scores = comb_scores.contiguous().view(
            comb_scores.size(0), -1)
        _, best_score_id = flat_comb_scores.topk(n_best, 1, True, True)
        # Recover (slot, intent) coordinates from the flat index.
        pick_slots = best_score_id // intent_scores.size(-1)  # bsize x n_best
        pick_intents = best_score_id - pick_slots * intent_scores.size(-1)

        slot_scores = torch.gather(slot_scores, 1, pick_slots)
        slot_idxs = [(torch.gather(
            b, 0,
            pick_slots[idx].unsqueeze(dim=1).repeat(1, b.size(1)))).tolist()
                     for idx, b in enumerate(slot_idxs)]
        intent_scores = torch.gather(intent_scores, 1, pick_intents)
        intent_idxs = torch.gather(intent_idxs, 1, pick_intents).tolist()
        return {
            "intent": (intent_scores, intent_idxs),
            "slot": (slot_scores, slot_idxs)
        }