def forward(self, inputs, subword_select_mask, subword_lens, lens):
    """
    Gather the selected subword vectors, pool each group with
    ``self.aggregation`` and scatter the pooled vectors into a padded tensor.
    @args:
        inputs: feature tensor; last dim is presumably self.hidden_size -- TODO confirm
        subword_select_mask: mask over inputs (without the last dim) marking entries to keep
        subword_lens: number of selected subwords in each group
        lens: number of groups for each sample
    @return:
        outputs: lens.size(0) x max(lens) x hidden_size; positions not selected
            by lens2mask(lens) stay zero
    """
    picked = inputs.masked_select(subword_select_mask.unsqueeze(-1))
    num_groups = subword_lens.size(0)
    longest_group = max(subword_lens.tolist())
    grouped = picked.new_zeros(num_groups, longest_group, self.hidden_size)
    group_mask = lens2mask(subword_lens)
    # masked_scatter_ works in place, so no rebinding is needed
    grouped.masked_scatter_(group_mask.unsqueeze(-1), picked)
    # aggregate subword into word feats
    pooled = self.aggregation(grouped, mask=group_mask)
    outputs = inputs.new_zeros(lens.size(0), max(lens.tolist()), self.hidden_size)
    outputs.masked_scatter_(lens2mask(lens).unsqueeze(-1), pooled)
    return outputs
def forward(self, *inputs):
    """
    Build a sentence vector from the encoder results and return intent
    log-probabilities.
    @args:
        hidden_states: RNN hidden states
        encoder_output: bsize x seqlen x hsize*2
        lens: length for each instance
        ('head+tail' reads only inputs[0]; 'hiddenAttn' reads inputs[0..2])
    @return:
        log_softmax over dim 1 of self.decoder applied to the sentence vector
    @raises:
        ValueError if self.method is neither 'head+tail' nor 'hiddenAttn'
    """
    if self.method == 'head+tail':
        states = inputs[0]
        if self.h_c:
            # keep the first element only -- presumably h of an (h, c) pair, TODO confirm
            states = states[0]
        depth = states.size(0)
        # select the last two entries along dim 0
        tail_index = torch.tensor([depth - 2, depth - 1], dtype=torch.long, device=states.device)
        tail_states = torch.index_select(states, 0, tail_index)
        sent = tail_states.transpose(0, 1).contiguous().view(-1, 2 * self.hidden_size)
    elif self.method == 'hiddenAttn':
        hc = inputs[0]
        lstm_out = inputs[1]
        lens = inputs[2]
        query = hc[0][-1] if self.h_c else hc[-1]
        hiddens = lstm_out.transpose(1, 2)
        query_proj = self.Wa(self.dropout_layer(query))
        keys_proj = self.Ua(self.dropout_layer(hiddens))
        # tile the projected query along the sequence dimension
        tiled_query = query_proj.unsqueeze(2).repeat(1, 1, lstm_out.size(1))
        energy = self.Va(torch.tanh(tiled_query + keys_proj)).squeeze(1)
        # positions outside lens2mask(lens) get zero weight after softmax
        energy.masked_fill_(lens2mask(lens) == 0, -float('inf'))
        attn = F.softmax(energy, dim=1)
        sent = torch.bmm(hiddens, attn.unsqueeze(2)).squeeze(2)
    else:
        raise ValueError('Unrecognized intent detection method!')
    logits = self.decoder(self.dropout_layer(sent))
    return F.log_softmax(logits, dim=1)
def forward(self, inputs, lens, dec_inputs=None):
    """
    Encode the inputs, then compute the CRF loss for slots and the intent
    decoder output.
    @args:
        inputs: passed to self.word_embed (presumably token ids -- TODO confirm)
        lens: length for each instance
        dec_inputs: reference sequence; it is sliced as dec_inputs[:, 1:], so
            it must be supplied even though the default is None
    @return:
        bios: result of self.crf_layer.neg_log_likelihood_loss
        intents: result of self.intent_decoder
    """
    embedded = self.word_embed(inputs)
    outputs, hidden_states = self.encoder(embedded, lens)
    emissions = self.slot_decoder(outputs)
    # the first position of dec_inputs is dropped before scoring
    gold = dec_inputs[:, 1:]
    bios = self.crf_layer.neg_log_likelihood_loss(emissions, lens2mask(lens), gold)
    intents = self.intent_decoder(hidden_states, outputs, lens)
    return bios, intents
def decode_batch(self, src_inputs, src_lens, vocab, copy_tokens=None, beam_size=5, n_best=1, alpha=0.6, length_pen='avg'):
    """
    Encode a source batch and decode it, greedily when beam_size == 1 and with
    beam search otherwise.
    @args:
        src_inputs: passed to self.src_embed
        src_lens: length for each source instance
        vocab, copy_tokens: forwarded unchanged to the decoding routine
        beam_size, n_best, alpha, length_pen: forwarded to self.decode_beam_search
            (ignored on the greedy path)
    @return:
        whatever self.decode_greed / self.decode_beam_search returns
    """
    embedded = self.src_embed(src_inputs)
    enc_out, hidden_states = self.encoder(embedded, src_lens)
    hidden_states = self.enc2dec(hidden_states)
    src_mask = lens2mask(src_lens)
    if beam_size == 1:
        return self.decode_greed(hidden_states, enc_out, src_mask, vocab, copy_tokens)
    return self.decode_beam_search(
        hidden_states, enc_out, src_mask, vocab, copy_tokens,
        beam_size=beam_size, n_best=n_best, alpha=alpha, length_pen=length_pen,
    )
def forward(self, slot_seq, slot_lens, intent_emb):
    """
    Build the initial hidden state from an intent-guided attention over slots.
    @args:
        slot_seq: bsize x max_slot_num x slot_dim
        slot_lens: tensor of # slot=value pair, bsize
        intent_emb: bsize x intent_dim
    @return:
        num_layers x bsize x hidden_size; a (h, c) tuple with c all zeros
        when self.cell == 'LSTM'
    """
    if torch.sum(slot_lens).item() == 0:
        # no slot in the whole batch: fall back to an all-zero state
        h = slot_lens.new_zeros(self.num_layers, slot_lens.size(0), self.hidden_size).float().contiguous()
    else:
        slot_seq = self.dropout_layer(slot_seq)
        projected = self.attn(slot_seq)
        weights = torch.bmm(projected, intent_emb.unsqueeze(dim=-1)).squeeze(dim=-1)
        weights.masked_fill_(lens2mask(slot_lens) == 0, -1e8)
        a = F.softmax(weights, dim=-1)
        conxt = torch.bmm(a.unsqueeze(dim=1), slot_seq).squeeze(dim=1)
        single_layer = torch.tanh(self.affine(conxt)).unsqueeze(dim=0)
        # the same state is copied to every layer
        h = single_layer.repeat(self.num_layers, 1, 1).contiguous()
    if self.cell == 'LSTM':
        c = h.new_zeros(h.size()).contiguous()
        return (h, c)
    return h
def forward(self, hiddens, decoder_state, slot_lens):
    """
    Attend over the slot memory with the current decoder state.
    @args:
        hiddens : bsize x max_slot_num x enc_dim
        decoder_state : bsize x dec_dim
        slot_lens : slot number for each batch, bsize
    @return:
        context : bsize x 1 x enc_dim
        a : normalized coefficient, bsize x max_slot_num
    """
    decoder_state = self.dropout_layer(decoder_state)
    if self.method == 'dot':
        keys = self.Wa(self.dropout_layer(hiddens)).transpose(-1, -2)
        e = torch.bmm(decoder_state.unsqueeze(1), keys).squeeze(dim=1)
    else:
        # any other method: score the concatenation [decoder_state; hidden] through Wa and Va
        tiled_state = decoder_state.unsqueeze(dim=1).repeat(1, hiddens.size(1), 1)
        joint = torch.cat([tiled_state, self.dropout_layer(hiddens)], dim=-1)
        e = self.Va(torch.tanh(self.Wa(joint))).squeeze(dim=-1)
    masks = lens2mask(slot_lens)
    missing = e.size(1) - masks.size(1)
    if missing > 0:
        # the mask can be narrower than the scores; extend it with zeros
        filler = torch.zeros(masks.size(0), missing).type_as(masks).to(masks.device)
        masks = torch.cat([masks, filler], dim=1)
    e.masked_fill_(masks == 0, -1e8)
    a = F.softmax(e, dim=1)
    context = torch.bmm(a.unsqueeze(1), hiddens)
    return context, a
def forward(self, src_inputs, src_lens, tgt_inputs, copy_tokens=None):
    """
    Used during training time.
    @args:
        src_inputs: passed to self.src_embed
        src_lens: length for each source instance
        tgt_inputs: passed to self.tgt_embed
        copy_tokens: forwarded unchanged to self.decoder
    @return:
        self.generator applied to the decoder outputs
    """
    src_emb = self.src_embed(src_inputs)
    enc_out, hidden_states = self.encoder(src_emb, src_lens)
    hidden_states = self.enc2dec(hidden_states)
    src_mask = lens2mask(src_lens)
    tgt_emb = self.tgt_embed(tgt_inputs)
    # the second return value of the decoder is not needed here
    dec_out = self.decoder(tgt_emb, hidden_states, enc_out, src_mask, copy_tokens)[0]
    return self.generator(dec_out)
def reconstruction_reward(self, logscores, references, lens):
    """
    Sum, per sample, the scores of the reference tokens at the positions kept
    by lens2mask(lens).
    @args:
        logscores: bsize x max_out_len x vocab_size[ + MAX_OOV_NUM]
        references: bsize x max_out_len
        lens: len for each sample
    @return:
        reward: one value per sample (sum over dim 1)
    """
    token_index = references.unsqueeze(dim=-1)
    pick_score = torch.gather(logscores, dim=-1, index=token_index).squeeze(dim=-1)
    keep = lens2mask(lens).float()
    return (keep * pick_score).sum(dim=1)
def sent_logprobability(self, input_feats, lens):
    """
    Given sentences, calculate its length-normalized log-probability
    Sequence must contain <s> and </s> symbol
    @args:
        input_feats: 2-dim tensor, sliced as [:, :-1] (inputs) and [:, 1:] (targets)
        lens: length tensor
    @return:
        per-sentence log-probability divided by (lens - 1)
    """
    # one prediction fewer than tokens, since each position predicts the next token
    lens = lens - 1
    output_feats = input_feats[:, 1:]
    input_feats = input_feats[:, :-1]
    emb = self.dropout_layer(self.encoder(input_feats))  # bsize, seq_len, emb_size
    output, _ = rnn_wrapper(self.rnn, emb, lens, self.cell)
    decoded = self.decoder(self.affine(self.dropout_layer(output)))
    scores = F.log_softmax(decoded, dim=-1)
    picked = torch.gather(scores, 2, output_feats.unsqueeze(-1)).contiguous()
    log_prob = picked.view(output.size(0), output.size(1))
    valid = lens2mask(lens).float()
    sent_log_prob = torch.sum(log_prob * valid, dim=-1)
    return sent_log_prob / lens.float()
def sent_logprob(self, inputs, lens, length_norm=False):
    """
    Given sentences, calculate the log-probability for each sentence
    @args:
        inputs(torch.LongTensor): sequence must contain <s> and </s> symbol
        lens(torch.LongTensor): length tensor
        length_norm(bool): divide each sentence score by (lens - 1) when True
    @return:
        sent_logprob(torch.FloatTensor): logprob for each sent in the batch
    """
    # one prediction fewer than tokens, since each position predicts the next token
    lens = lens - 1
    outputs = inputs[:, 1:]
    inputs = inputs[:, :-1]
    emb = self.dropout_layer(self.word_embed(inputs))  # bsize, seq_len, emb_size
    output, _ = rnn_wrapper(self.encoder, emb, lens, self.cell)
    decoded = self.decoder(self.dropout_layer(output))
    scores = F.log_softmax(decoded, dim=-1)
    picked = torch.gather(scores, 2, outputs.unsqueeze(-1)).contiguous()
    logprob = picked.view(output.size(0), output.size(1))
    valid = lens2mask(lens).float()
    sent_logprob = torch.sum(logprob * valid, dim=-1)
    if not length_norm:
        return sent_logprob
    return sent_logprob / lens.float()
def forward(self, slot_emb, slot_lens, lens):
    """
    Encode every slot=value sequence, pool it to one vector and regroup the
    vectors per training sample with zero padding.
    @args:
        slot_emb: [total_slot_num, max_slot_word_len, emb_size]
        slot_lens: slot_num for each training sample, [bsize]
        lens: seq_len for each ${slot}=value sequence, [total_slot_num]
    @return:
        slot_feats: bsize, max_slot_num, hidden_size * 2
    """
    if slot_emb is None or torch.sum(slot_lens).item() == 0:
        # set seq_len dim to 1 due to decoder attention computation
        empty = torch.zeros(slot_lens.size(0), 1, self.hidden_size * 2, dtype=torch.float)
        return empty.to(slot_lens.device)
    encoded, _ = rnn_wrapper(self.slot_encoder, slot_emb, lens, self.cell)
    pooled = self.slot_aggregation(encoded, lens2mask(lens))
    # list of [slot_num x hidden_size], one entry per sample
    per_sample = pooled.split(slot_lens.tolist(), dim=0)
    max_slot_num = torch.max(slot_lens).item()
    padded_chunks = []
    for chunk in per_sample:
        filler = chunk.new_zeros(max_slot_num - chunk.size(0), chunk.size(1))
        padded_chunks.append(torch.cat([chunk, filler], dim=0))
    # bsize x max_slot_num x hidden_size
    return torch.stack(padded_chunks, dim=0)
def decode_greed(self, bios, intents, lens):
    """
    Pick the single best slot sequence (via the CRF layer) and the single
    best intent (arg-max over dim 1).
    @args:
        bios(torch.FloatTensor): bsize x seqlen x slot_num
        intents(torch.FloatTensor): bsize x intent_num
        lens(torch.LongTensor): bsize
    @return:
        dict:
            slot: tuple of
                slot_score(torch.FloatTensor): bsize x 1
                slot_idx(list): such as [ [[1, 2, 4, 9, 11]] , [[1, 2, 5, 8, 10]] , ... ]
            intent: tuple of
                intent_score(torch.FloatTensor): bsize x 1
                intent_idx(list): bsize x 1
    """
    mask = lens2mask(lens)
    slot_score, slot_idx = self.crf_layer._viterbi_decode(bios, mask)
    int_score, int_idx = torch.max(intents, dim=1, keepdim=True)
    intent_result = (int_score, int_idx.tolist())
    slot_result = (slot_score, slot_idx.tolist())
    return {"intent": intent_result, "slot": slot_result}
def decode_beam_search(self, bios, intents, lens, n_best=1, **kargs):
    """
    Jointly rank slot sequences and intents: take the n-best slot sequences
    from the CRF layer, a handful of top intents, add their scores pairwise
    and keep the n_best highest-scoring (slot sequence, intent) pairs.
    @args:
        bios: emission scores passed to self.crf_layer._viterbi_decode_nbest
        intents: intent scores, last dim is the number of intent classes
        lens: length for each instance
        n_best(int): number of predictions to return
        **kargs: accepted and ignored
    @return:
        dict:
            (key)intent: (value) tuple of (n_best most likely intent score, n_best most likely intent idx)
                intent_score(torch.FloatTensor): bsize x n_best
                intent_idx(list): bsize x n_best
            (key)slot: (value) tuple of (n_best most likely seq score, n_best most likely seq idx)
                slot_score(torch.FloatTensor): bsize x n_best
                slot_idx(list): n_best=2, such as [ [[1, 2, 4, 9], [1, 2, 3, 5]] , [[1, 2, 5], [1, 2, 3]] , ... ]
    """
    slot_scores, slot_idxs = self.crf_layer._viterbi_decode_nbest(
        bios, lens2mask(lens), n_best)

    num_intents = intents.size(-1)
    threshold = n_best if num_intents > 2 * n_best else n_best // 2
    # n_best // 2 is 0 for n_best == 1 and can exceed the number of intent
    # classes; either makes topk fail, so clamp to [1, num_intents].
    threshold = max(1, min(threshold, num_intents))
    intent_scores, intent_idxs = intents.topk(threshold, dim=1)

    # pairwise sum: (slot candidate, intent candidate) -> combined score
    comb_scores = slot_scores.unsqueeze(-1) + intent_scores.unsqueeze(1)
    flat_comb_scores = comb_scores.contiguous().view(comb_scores.size(0), -1)
    _, best_score_id = flat_comb_scores.topk(n_best, 1, True, True)

    # recover (slot candidate, intent candidate) from the flattened position
    num_intent_cands = intent_scores.size(-1)
    pick_slots = best_score_id // num_intent_cands  # bsize x n_best
    pick_intents = best_score_id - pick_slots * num_intent_cands

    slot_scores = torch.gather(slot_scores, 1, pick_slots)
    slot_idxs = [
        torch.gather(b, 0, pick_slots[idx].unsqueeze(dim=1).repeat(1, b.size(1))).tolist()
        for idx, b in enumerate(slot_idxs)
    ]
    intent_scores = torch.gather(intent_scores, 1, pick_intents)
    intent_idxs = torch.gather(intent_idxs, 1, pick_intents).tolist()
    return {
        "intent": (intent_scores, intent_idxs),
        "slot": (slot_scores, slot_idxs)
    }