def reorder_encoder_states(self, encoder_states, indices):
    """
    Reorder encoder states according to a new set of indices.
    """
    enc_out, hidden, attn_mask, context_vec = encoder_states

    # undo the multigpu-friendly layout before indexing the batch dim
    hidden = _transpose_hidden_state(hidden)

    # normalize LSTM (tuple) vs GRU/RNN (single tensor) hidden state
    if isinstance(hidden, torch.Tensor):
        hid, cell = hidden, None
    else:
        hid, cell = hidden

    # cast indices to a tensor on the right device if needed
    if not torch.is_tensor(indices):
        indices = torch.LongTensor(indices).to(hid.device)

    hid = hid.index_select(1, indices)
    if cell is None:
        hidden = hid
    else:
        hidden = (hid, cell.index_select(1, indices))

    # restore the multigpu-friendly layout
    hidden = _transpose_hidden_state(hidden)
    context_vec = context_vec.index_select(0, indices)
    return enc_out, hidden, attn_mask, context_vec
def forward(self, xs, encoder_output, incremental_state=None):
    """
    Decode from input tokens.

    :param xs: (bsz x seqlen) LongTensor of input token indices
    :param encoder_output: output from HredEncoder. Tuple containing
        (enc_out, enc_hidden, attn_mask, context_hidden) tuple.
    :param incremental_state: most recent hidden state to the decoder.

    :returns: (output, hidden_state) pair from the RNN.
        - output is a bsz x time x latentdim matrix. This value must be passed
          to the model's OutputLayer for a final softmax.
        - hidden_state depends on the choice of RNN
    """
    enc_state, (hidden_state, cell_state), attn_mask, context_hidden = encoder_output

    # embed the input tokens and apply dropout
    seqlen = xs.size(1)
    token_embs = self.dropout(self.lt(xs))

    # broadcast the final-layer context hidden state across every timestep,
    # then concatenate it onto each token embedding
    final_ctx = context_hidden[:, -1, :].unsqueeze(1)
    tiled_ctx = final_ctx.expand(-1, seqlen, -1)
    rnn_input = torch.cat((token_embs, tiled_ctx), dim=-1).to(token_embs.device)

    # decode with a zeroed (None) initial decoder state, following
    # http://www.cs.toronto.edu/~lcharlin/papers/vhred_aaai17.pdf
    output, new_hidden = self.rnn(rnn_input, None)
    return output, _transpose_hidden_state(new_hidden)
def forward(self, init_hidden, context=None, inputs=None, lens=None,
            context_encoder_states=None):
    """
    Decode a full (teacher-forced) token sequence.

    :param init_hidden: flat (bsz x n_layers*hidden_size) initial state,
        used when no context encoder states are provided.
    :param context: optional per-example context vector, tiled across
        timesteps and concatenated onto the embeddings.
    :param inputs: (bsz x maxlen) LongTensor of input token indices.
    :param lens: unused here; kept for interface compatibility.
    :param context_encoder_states: optional (enc_state, enc_hidden,
        attn_mask) triple; when given, decoding steps one token at a time
        with attention over the context encoder outputs.

    :returns: (bsz x maxlen x vocab_size) decoder scores.
    """
    batch_size, maxlen = inputs.size()
    embedded = self.embedding(inputs) if self.embedding is not None else inputs
    if context is not None:
        tiled = context.unsqueeze(1).repeat(1, maxlen, 1)
        embedded = torch.cat([embedded, tiled], 2)
    embedded = self.dropout(embedded)
    self.rnn.flatten_parameters()

    if context_encoder_states is None:
        # reshape flat init hidden into (layers, bsz, hidden) and run the
        # whole sequence through the RNN in one shot
        state = init_hidden.view(batch_size, self.n_layers, self.hidden_size)
        state = state.transpose(0, 1).contiguous()
        if self.rnn_class == 'lstm':
            state = (state, state)
        output, _ = self.rnn(embedded, state)
    else:
        # attention over the context encoder outputs: step one token at a
        # time so each step can attend with the freshest hidden state
        enc_state, enc_hidden, attn_mask = context_encoder_states
        attn_params = (enc_state, attn_mask)
        state = _transpose_hidden_state(enc_hidden)
        if isinstance(state, tuple):
            state = tuple(x.contiguous() for x in state)
        else:
            state = state.contiguous()
        steps = []
        for t in range(maxlen):
            step_out, state = self.rnn(embedded[:, t, :].unsqueeze(1), state)
            step_out, _ = self.context_attention(step_out, state, attn_params)
            steps.append(step_out)
        output = torch.cat(steps, dim=1).to(embedded.device)

    decoded = self.out(output)
    return decoded.view(batch_size, maxlen, self.vocab_size)
def forward(self, xs, context_vec, hist_lens):
    """
    Encode the current input sequence plus the dialogue history.

    :param xs: (bsz x seqlen) LongTensor of current input token indices.
    :param context_vec: history utterances, one per row (0 is treated as
        padding); if 1-D it is unsqueezed so each row is one utterance.
    :param hist_lens: number of history utterances belonging to each batch
        row, used to re-batch the per-utterance encodings.

    :returns: (enc_state, (hidden_state, cell_state), attn_mask,
        batch-major hidden state of the context LSTM).
    """
    # encode current utterance with the base (utterance-level) encoder
    (enc_state, (hidden_state, cell_state), attn_mask) = super().forward(xs)
    # if all utterances in context vec length 1, unsqueeze to prevent loss of dimensionality
    if len(context_vec.shape) < 2:
        context_vec = context_vec.unsqueeze(1)
    # get utt lengths of each utt in context vector (0 == padding)
    utt_lens = torch.sum(context_vec.ne(0).int(), dim=1)
    # sort by lengths descending for utterance encoder
    sorted_lens, sorted_idx = utt_lens.sort(descending=True)
    sorted_context_vec = context_vec[sorted_idx]
    (_, (sorted_hidden_state, _), _) = super().forward(sorted_context_vec)
    # final layer's hidden state for each (sorted) history utterance.
    # NOTE(review): indexing dim 1 with -1 implies the hidden state here is
    # batch-major, i.e. (utts, layers, hidden) — confirm against the base
    # encoder's return layout.
    sorted_final_hidden_states = sorted_hidden_state[:, -1, :]
    ### reshape and pad hidden states to bsz x max_hist_len x hidden_size using hist_lens
    # scatter_ writes row i of the sorted states back to position
    # sorted_idx[i], undoing the length sort above
    original_order_final_hidden = torch.zeros_like(
        sorted_final_hidden_states).scatter_(
            0,
            sorted_idx.unsqueeze(1).expand(
                -1, sorted_final_hidden_states.shape[1]),
            sorted_final_hidden_states,
        )
    # pad to max hist_len
    original_size_final_hidden = self.sequence_to_padding(
        original_order_final_hidden, hist_lens)
    # pack padded sequence so that we ignore padding; lengths must live on
    # CPU for pack_padded_sequence
    original_size_final_hidden_packed = nn.utils.rnn.pack_padded_sequence(
        original_size_final_hidden,
        hist_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )
    # pass through context lstm; keep only its final hidden state
    _, (context_h_n, _) = self.context_lstm(original_size_final_hidden_packed)
    return (
        enc_state,
        (hidden_state, cell_state),
        attn_mask,
        _transpose_hidden_state(context_h_n),
    )
def forward(self, inputs, input_lens=None, noise=False):
    """
    Encode a batch of padded token sequences with the utterance RNN.

    :param inputs: (bsz x seqlen) LongTensor of token indices; 0 is padding.
    :param input_lens: unused here; kept for interface compatibility.
    :param noise: if True and self.noise_radius > 0, add Gaussian noise to
        the pooled encoding (used for the WAE-style training).

    :returns: (enc, (encoder_output, hidden, attn_mask)) where enc is the
        bsz x (num_dirs * hidden_size) last-layer summary and hidden is the
        direction-summed, batch-major RNN state.
    """
    inputs = self.input_dropout(inputs)
    # padding mask computed on the raw token ids (0 == pad)
    attn_mask = inputs.ne(0)
    if self.embedding is not None:
        inputs = self.embedding(inputs)
    batch_size, seq_len, emb_size = inputs.size()
    inputs = self.dropout(inputs)
    self.rnn.flatten_parameters()
    encoder_output, hidden = self.rnn(inputs)

    # final hidden state split out by direction
    h_n = hidden[0] if self.rnn_class == 'lstm' else hidden
    h_n = h_n.view(self.n_layers, self.dirs, batch_size, self.hidden_size)
    # last layer, directions concatenated: bsz x (num_dirs * hidden_size)
    enc = h_n[-1].transpose(1, 0).contiguous().view(batch_size, -1)

    # sum over directions so downstream consumers always see a
    # (layers, bsz, hidden) state regardless of bidirectionality
    if isinstance(self.rnn, nn.LSTM):
        hidden = (
            hidden[0].view(-1, self.dirs, batch_size, self.hidden_size).sum(1),
            hidden[1].view(-1, self.dirs, batch_size, self.hidden_size).sum(1),
        )
    else:
        hidden = hidden.view(-1, self.dirs, batch_size, self.hidden_size).sum(1)
    hidden = _transpose_hidden_state(hidden)

    if noise and self.noise_radius > 0:
        # BUGFIX: torch.normal takes `mean=`, not `means=` (the pre-0.4
        # keyword); `means=` raises a TypeError on modern PyTorch.
        gauss_noise = gVar(
            torch.normal(mean=torch.zeros(enc.size()), std=self.noise_radius),
            self.use_cuda,
        )
        enc = enc + gauss_noise

    utt_encoder_states = (encoder_output, hidden, attn_mask)
    return enc, utt_encoder_states
def sampling(self, init_hidden, context, maxlen, SOS_tok, EOS_tok,
             mode='greedy', context_encoder_states=None):
    """
    Autoregressively sample up to ``maxlen`` tokens from the decoder.

    :param init_hidden: flat (bsz x n_layers*hidden_size) initial state,
        used when no context encoder states are provided.
    :param context: optional per-example context vector concatenated onto
        each step's input embedding.
    :param maxlen: maximum number of tokens to generate.
    :param SOS_tok: token id used to start every sequence.
    :param EOS_tok: token id that terminates a sample when counting lengths.
    :param mode: 'greedy' (argmax) or 'nucleus' (top-p sampling).
    :param context_encoder_states: optional (enc_state, enc_hidden,
        attn_mask) triple; when given, each step attends over the context
        encoder outputs.

    :returns: (decoded_words, sample_lens) numpy arrays of shape
        (bsz x maxlen) and (bsz,).

    :raises RuntimeError: if ``mode`` is not a supported inference method.
    """
    batch_size = init_hidden.size(0)
    # BUGFIX: np.int was removed in NumPy 1.24; use the builtin int.
    decoded_words = np.zeros((batch_size, maxlen), dtype=int)
    sample_lens = np.zeros(batch_size, dtype=int)

    # noinspection PyArgumentList
    decoder_input = gVar(
        torch.LongTensor([[SOS_tok] * batch_size]).view(batch_size, 1),
        self.use_cuda)
    decoder_input = self.embedding(
        decoder_input) if self.embedding is not None else decoder_input
    decoder_input = torch.cat(
        [decoder_input, context.unsqueeze(1)],
        2) if context is not None else decoder_input

    if context_encoder_states is not None:
        # initialize from the (batch-major) context encoder hidden state
        context_enc_state, context_enc_hidden, context_attn_mask = context_encoder_states
        context_attn_params = (context_enc_state, context_attn_mask)
        context_hidden = _transpose_hidden_state(context_enc_hidden)
        if isinstance(context_hidden, tuple):
            context_hidden = tuple(x.contiguous() for x in context_hidden)
        else:
            context_hidden = context_hidden.contiguous()
        decoder_hidden = context_hidden
    else:
        decoder_hidden = init_hidden.view(batch_size, self.n_layers,
                                          self.hidden_size)
        decoder_hidden = decoder_hidden.transpose(0, 1).contiguous()
        if self.rnn_class == 'lstm':
            decoder_hidden = (decoder_hidden, decoder_hidden)

    for di in range(maxlen):
        decoder_output, decoder_hidden = self.rnn(decoder_input, decoder_hidden)
        if context_encoder_states is not None:
            # apply attention over the context encoder outputs
            decoder_output, _ = self.context_attention(
                decoder_output, decoder_hidden, context_attn_params)
        decoder_output = self.out(decoder_output)
        if mode == 'greedy':
            topi = decoder_output[:, -1].max(1, keepdim=True)[1]
        elif mode == 'nucleus':
            # Nucleus, aka top-p sampling (Holtzman et al., 2019).
            logprobs = decoder_output[:, -1]
            probs = torch.softmax(logprobs, dim=-1)
            sprobs, sinds = probs.sort(dim=-1, descending=True)
            # zero everything past the top-p mass (the single most probable
            # token is always kept), then renormalize and sample
            mask = (sprobs.cumsum(dim=-1) - sprobs[:, :1]) >= self.topp
            sprobs[mask] = 0
            sprobs.div_(sprobs.sum(dim=-1).unsqueeze(1))
            choices = torch.multinomial(sprobs, 1)[:, 0]
            hyp_ids = torch.arange(logprobs.size(0)).to(logprobs.device)
            topi = sinds[hyp_ids, choices].unsqueeze(dim=1)
        else:
            # BUGFIX: the '{}' placeholder was never filled in; include the
            # offending mode in the error message.
            raise RuntimeError(
                'inference method: {} not supported yet!'.format(mode))
        decoder_input = self.embedding(
            topi) if self.embedding is not None else topi
        decoder_input = torch.cat(
            [decoder_input, context.unsqueeze(1)],
            2) if context is not None else decoder_input
        ni = topi.squeeze().data.cpu().numpy()
        decoded_words[:, di] = ni

    # sample length = number of tokens before the first EOS
    for i in range(batch_size):
        for word in decoded_words[i]:
            if word == EOS_tok:
                break
            sample_lens[i] += 1
    return decoded_words, sample_lens
def forward(self, context, context_lens, utt_lens, floors, noise=False):
    """
    Encode a dialogue context: encode every utterance, then run a
    context-level RNN over the per-utterance encodings.

    :param context: (bsz x max_context_len x max_utt_len) LongTensor of
        token ids; 0 is padding.
    :param context_lens: number of real utterances per batch row.
    :param utt_lens: (bsz x max_context_len) lengths of each utterance.
    :param floors: speaker id (0/1) of each utterance, one-hot encoded and
        concatenated onto the utterance encodings.
    :param noise: if True and self.noise_radius > 0, add Gaussian noise to
        the pooled encoding.

    :returns: (enc, (context_encoder_output, new_hidden, context_mask)).
    """
    batch_size, max_context_len, max_utt_len = context.size()
    # flatten to (bsz * max_context_len) x max_utt_len for the utt encoder
    utts = context.view(-1, max_utt_len)
    batch_max_lens = torch.arange(max_context_len).expand(
        batch_size, max_context_len)
    if self.use_cuda:
        batch_max_lens = batch_max_lens.cuda()
    # mask of the real (non-padding) utterance slots per batch row
    context_mask = batch_max_lens < context_lens.unsqueeze(1)

    utt_lens = utt_lens.view(-1)
    utt_encs, utt_encoder_states = self.utt_encoder(utts, utt_lens)
    utt_encs = utt_encs.view(batch_size, max_context_len, -1)
    utt_encoder_output, utt_hidden, utt_attn_mask = utt_encoder_states
    utt_encoder_output = utt_encoder_output.view(
        batch_size, max_context_len, max_utt_len,
        self.utt_encoder.dirs * self.utt_encoder.hidden_size)
    # un-flatten the hidden state back to per-(batch, utterance) slots
    utt_hidden = _transpose_hidden_state(utt_hidden)
    if isinstance(utt_hidden, tuple):
        utt_hidden = tuple(
            x.view(self.utt_encoder.n_layers, batch_size, max_context_len,
                   self.utt_encoder.hidden_size).contiguous()
            for x in utt_hidden)
    else:
        utt_hidden = utt_hidden.view(
            self.utt_encoder.n_layers, batch_size, max_context_len,
            self.utt_encoder.hidden_size).contiguous()
    utt_attn_mask = utt_attn_mask.view(batch_size, max_context_len,
                                       max_utt_len)

    # one-hot encode the speaker (floor) of every utterance and append it
    floor_one_hot = gVar(torch.zeros(floors.numel(), 2), self.use_cuda)
    floor_one_hot.data.scatter_(1, floors.view(-1, 1), 1)
    floor_one_hot = floor_one_hot.view(-1, max_context_len, 2)
    utt_floor_encs = torch.cat([utt_encs, floor_one_hot], 2)
    utt_floor_encs = self.dropout(utt_floor_encs)
    self.rnn.flatten_parameters()

    # initial context-RNN state: the last utterance slot's hidden state
    if self.rnn_class == 'lstm':
        new_hidden = tuple(x[:, :, -1, :].contiguous() for x in utt_hidden)
    else:
        new_hidden = utt_hidden[:, :, -1, :].contiguous()

    if self.attn_type != 'none':
        # step the context RNN one utterance at a time so each step can
        # attend over the words of that utterance
        output = []
        for i in range(max_context_len):
            o, new_hidden = self.rnn(utt_floor_encs[:, i, :].unsqueeze(1),
                                     new_hidden)
            o, _ = self.word_attention(
                o, new_hidden,
                (utt_encoder_output[:, i, :, :], utt_attn_mask[:, i, :]))
            output.append(o)
        context_encoder_output = torch.cat(output,
                                           dim=1).to(utt_floor_encs.device)
    else:
        # BUGFIX (consistency): pack_padded_sequence requires CPU lengths on
        # modern PyTorch; matches the .cpu() usage elsewhere in this file.
        utt_floor_encs = pack_padded_sequence(utt_floor_encs,
                                              context_lens.cpu(),
                                              batch_first=True,
                                              enforce_sorted=False)
        context_encoder_output, new_hidden = self.rnn(utt_floor_encs,
                                                      new_hidden)
        context_encoder_output, _ = pad_packed_sequence(
            context_encoder_output,
            batch_first=True,
            total_length=max_context_len)

    new_hidden = _transpose_hidden_state(new_hidden)
    if self.rnn_class == 'lstm':
        enc = new_hidden[0]
    else:
        enc = new_hidden
    enc = enc.contiguous().view(batch_size, -1)
    if noise and self.noise_radius > 0:
        # BUGFIX: torch.normal takes `mean=`, not `means=` (the pre-0.4
        # keyword); `means=` raises a TypeError on modern PyTorch.
        gauss_noise = gVar(
            torch.normal(mean=torch.zeros(enc.size()),
                         std=self.noise_radius), self.use_cuda)
        enc = enc + gauss_noise
    return enc, (context_encoder_output, new_hidden, context_mask)