def forward(self, context):
    batch_size, _ = context.size()
    context = self.fc(context)
    # Mixture weights: hard Gumbel-softmax yields a one-hot component selection.
    pi = self.pi_net(context)
    pi = F.gumbel_softmax(pi, tau=self.gumbel_temp, hard=True, eps=1e-10)
    pi = pi.unsqueeze(1)
    mus = self.context_to_mu(context)
    logsigmas = self.context_to_logsigma(context)
    # mus = torch.clamp(mus, -30, 30)
    logsigmas = torch.clamp(logsigmas, -20, 20)
    stds = torch.exp(0.5 * logsigmas)
    # Reparameterized samples for all components; pi then selects one per example.
    epsilons = gVar(
        torch.randn([batch_size, self.n_components * self.z_size]),
        self.use_cuda)
    zi = (epsilons * stds + mus).view(batch_size, self.n_components,
                                      self.z_size)
    z = torch.bmm(pi, zi).squeeze(1)  # [batch_sz x z_sz]
    mu = torch.bmm(pi, mus.view(batch_size, self.n_components, self.z_size))
    logsigma = torch.bmm(
        pi, logsigmas.view(batch_size, self.n_components, self.z_size))
    return z, mu, logsigma
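# Illustrative sketch (not part of the model code): how a hard Gumbel-softmax
# one-hot over K mixture components selects a single component's sample via
# torch.bmm, mirroring the forward above. All shapes and values are made up.
import torch
import torch.nn.functional as F

batch_size, n_components, z_size = 2, 3, 4
logits = torch.randn(batch_size, n_components)
pi = F.gumbel_softmax(logits, tau=1.0, hard=True)   # [B, K], exactly one 1 per row
zi = torch.randn(batch_size, n_components, z_size)  # per-component samples
z = torch.bmm(pi.unsqueeze(1), zi).squeeze(1)       # [B, Z]
chosen = pi.argmax(dim=1)
# z equals the row of zi picked by the one-hot weights.
assert torch.allclose(z, zi[torch.arange(batch_size), chosen])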
def forward(self, context):
    batch_size, _ = context.size()
    context = self.fc(context)
    mu = self.context_to_mu(context)
    logsigma = self.context_to_logsigma(context)
    # mu = torch.clamp(mu, -30, 30)
    logsigma = torch.clamp(logsigma, -20, 20)
    std = torch.exp(0.5 * logsigma)
    # Reparameterization trick: z = mu + std * eps with eps ~ N(0, I).
    epsilon = gVar(torch.randn([batch_size, self.z_size]),
                   use_cuda=self.use_cuda)
    z = epsilon * std + mu
    return z, mu, logsigma
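# Illustrative sketch (not part of the model code): the reparameterization
# trick used above. Sampling z = mu + std * eps keeps the sample differentiable
# with respect to mu and logsigma, so gradients reach both heads.
import torch

mu = torch.zeros(5, requires_grad=True)
logsigma = torch.zeros(5, requires_grad=True)
std = torch.exp(0.5 * logsigma)
eps = torch.randn_like(std)
z = eps * std + mu        # differentiable w.r.t. mu and logsigma
z.sum().backward()
assert mu.grad is not None and logsigma.grad is not None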
def forward(self, inputs, input_lens=None, noise=False):
    inputs = self.input_dropout(inputs)
    attn_mask = inputs.ne(0)
    if self.embedding is not None:
        inputs = self.embedding(inputs)
    batch_size, seq_len, emb_size = inputs.size()
    inputs = self.dropout(inputs)
    self.rnn.flatten_parameters()
    encoder_output, hidden = self.rnn(inputs)
    h_n = hidden[0] if self.rnn_class == 'lstm' else hidden
    h_n = h_n.view(self.n_layers, self.dirs, batch_size, self.hidden_size)
    # Concatenate the last layer's forward/backward final states.
    enc = h_n[-1].transpose(1, 0).contiguous().view(
        batch_size, -1)  # bsz, num_dirs*hidden_size
    if isinstance(self.rnn, nn.LSTM):
        hidden = (
            hidden[0].view(-1, self.dirs, batch_size, self.hidden_size).sum(1),
            hidden[1].view(-1, self.dirs, batch_size, self.hidden_size).sum(1),
        )
    else:
        hidden = hidden.view(-1, self.dirs, batch_size,
                             self.hidden_size).sum(1)
    hidden = _transpose_hidden_state(hidden)
    if noise and self.noise_radius > 0:
        gauss_noise = gVar(
            torch.normal(mean=torch.zeros(enc.size()), std=self.noise_radius),
            self.use_cuda)
        enc = enc + gauss_noise
    utt_encoder_states = (encoder_output, hidden, attn_mask)
    return enc, utt_encoder_states
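# Illustrative sketch (not part of the model code): recovering the last layer's
# forward and backward final states from a bidirectional GRU, matching the h_n
# reshaping done above. Sizes here are arbitrary.
import torch
import torch.nn as nn

n_layers, dirs, batch_size, hidden_size = 2, 2, 3, 5
rnn = nn.GRU(8, hidden_size, num_layers=n_layers, bidirectional=True,
             batch_first=True)
inputs = torch.randn(batch_size, 7, 8)
_, h_n = rnn(inputs)                                   # [n_layers * dirs, B, H]
h_n = h_n.view(n_layers, dirs, batch_size, hidden_size)
enc = h_n[-1].transpose(1, 0).contiguous().view(batch_size, -1)  # [B, dirs * H]
assert enc.shape == (batch_size, dirs * hidden_size)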
def sampling(self, init_hidden, context, maxlen, SOS_tok, EOS_tok,
             mode='greedy', context_encoder_states=None):
    batch_size = init_hidden.size(0)
    decoded_words = np.zeros((batch_size, maxlen), dtype=np.int64)
    sample_lens = np.zeros(batch_size, dtype=np.int64)
    # noinspection PyArgumentList
    decoder_input = gVar(
        torch.LongTensor([[SOS_tok] * batch_size]).view(batch_size, 1),
        self.use_cuda)
    decoder_input = self.embedding(
        decoder_input) if self.embedding is not None else decoder_input
    decoder_input = torch.cat(
        [decoder_input, context.unsqueeze(1)],
        2) if context is not None else decoder_input
    if context_encoder_states is not None:
        context_enc_state, context_enc_hidden, context_attn_mask = \
            context_encoder_states
        context_attn_params = (context_enc_state, context_attn_mask)
        context_hidden = _transpose_hidden_state(context_enc_hidden)
        if isinstance(context_hidden, tuple):
            context_hidden = tuple(x.contiguous() for x in context_hidden)
        else:
            context_hidden = context_hidden.contiguous()
        decoder_hidden = context_hidden
    else:
        decoder_hidden = init_hidden.view(batch_size, self.n_layers,
                                          self.hidden_size)
        decoder_hidden = decoder_hidden.transpose(0, 1).contiguous()
        if self.rnn_class == 'lstm':
            decoder_hidden = (decoder_hidden, decoder_hidden)
    for di in range(maxlen):
        decoder_output, decoder_hidden = self.rnn(decoder_input,
                                                  decoder_hidden)
        if context_encoder_states is not None:
            # apply attention
            decoder_output, _ = self.context_attention(
                decoder_output, decoder_hidden, context_attn_params)
        decoder_output = self.out(decoder_output)
        if mode == 'greedy':
            topi = decoder_output[:, -1].max(1, keepdim=True)[1]
        elif mode == 'nucleus':
            # Nucleus, aka top-p sampling (Holtzman et al., 2019).
            logprobs = decoder_output[:, -1]
            probs = torch.softmax(logprobs, dim=-1)
            sprobs, sinds = probs.sort(dim=-1, descending=True)
            mask = (sprobs.cumsum(dim=-1) - sprobs[:, :1]) >= self.topp
            sprobs[mask] = 0
            sprobs.div_(sprobs.sum(dim=-1).unsqueeze(1))
            choices = torch.multinomial(sprobs, 1)[:, 0]
            hyp_ids = torch.arange(logprobs.size(0)).to(logprobs.device)
            topi = sinds[hyp_ids, choices].unsqueeze(dim=1)
        else:
            raise RuntimeError(
                'inference method: {} not supported yet!'.format(mode))
        decoder_input = self.embedding(
            topi) if self.embedding is not None else topi
        decoder_input = torch.cat(
            [decoder_input, context.unsqueeze(1)],
            2) if context is not None else decoder_input
        ni = topi.squeeze().data.cpu().numpy()
        decoded_words[:, di] = ni
    # Count tokens per example up to (excluding) the first EOS.
    for i in range(batch_size):
        for word in decoded_words[i]:
            if word == EOS_tok:
                break
            sample_lens[i] += 1
    return decoded_words, sample_lens
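# Illustrative sketch (not part of the model code): standalone nucleus (top-p)
# sampling over one step of logits, mirroring the 'nucleus' branch above
# (Holtzman et al., 2019). The function name and topp value are made up.
import torch

def nucleus_sample(logits, topp=0.9):
    probs = torch.softmax(logits, dim=-1)
    sprobs, sinds = probs.sort(dim=-1, descending=True)
    # Keep the smallest prefix of tokens whose cumulative mass reaches topp;
    # subtracting sprobs[:, :1] guarantees the top token is always kept.
    mask = (sprobs.cumsum(dim=-1) - sprobs[:, :1]) >= topp
    sprobs = sprobs.masked_fill(mask, 0.0)
    sprobs = sprobs / sprobs.sum(dim=-1, keepdim=True)
    choices = torch.multinomial(sprobs, 1)[:, 0]
    hyp_ids = torch.arange(logits.size(0), device=logits.device)
    return sinds[hyp_ids, choices].unsqueeze(1)   # [B, 1] token ids

tokens = nucleus_sample(torch.randn(4, 100))
assert tokens.shape == (4, 1)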
def forward(self, context, context_lens, utt_lens, floors, noise=False):
    batch_size, max_context_len, max_utt_len = context.size()
    utts = context.view(-1, max_utt_len)
    batch_max_lens = torch.arange(max_context_len).expand(
        batch_size, max_context_len)
    if self.use_cuda:
        batch_max_lens = batch_max_lens.cuda()
    context_mask = batch_max_lens < context_lens.unsqueeze(1)
    utt_lens = utt_lens.view(-1)
    # Encode every utterance in the context with the utterance encoder.
    utt_encs, utt_encoder_states = self.utt_encoder(utts, utt_lens)
    utt_encs = utt_encs.view(batch_size, max_context_len, -1)
    utt_encoder_output, utt_hidden, utt_attn_mask = utt_encoder_states
    utt_encoder_output = utt_encoder_output.view(
        batch_size, max_context_len, max_utt_len,
        self.utt_encoder.dirs * self.utt_encoder.hidden_size)
    utt_hidden = _transpose_hidden_state(utt_hidden)
    if isinstance(utt_hidden, tuple):
        utt_hidden = tuple(
            x.view(self.utt_encoder.n_layers, batch_size, max_context_len,
                   self.utt_encoder.hidden_size).contiguous()
            for x in utt_hidden)
    else:
        utt_hidden = utt_hidden.view(
            self.utt_encoder.n_layers, batch_size, max_context_len,
            self.utt_encoder.hidden_size).contiguous()
    utt_attn_mask = utt_attn_mask.view(batch_size, max_context_len,
                                       max_utt_len)
    # Speaker-floor features as a one-hot per utterance.
    floor_one_hot = gVar(torch.zeros(floors.numel(), 2), self.use_cuda)
    floor_one_hot.data.scatter_(1, floors.view(-1, 1), 1)
    floor_one_hot = floor_one_hot.view(-1, max_context_len, 2)
    utt_floor_encs = torch.cat([utt_encs, floor_one_hot], 2)
    utt_floor_encs = self.dropout(utt_floor_encs)
    self.rnn.flatten_parameters()
    if self.rnn_class == 'lstm':
        new_hidden = tuple(x[:, :, -1, :].contiguous() for x in utt_hidden)
    else:
        new_hidden = utt_hidden[:, :, -1, :].contiguous()
    if self.attn_type != 'none':
        # Step through the context one utterance at a time, attending over
        # the words of the corresponding utterance.
        output = []
        for i in range(max_context_len):
            o, new_hidden = self.rnn(utt_floor_encs[:, i, :].unsqueeze(1),
                                     new_hidden)
            o, _ = self.word_attention(
                o, new_hidden,
                (utt_encoder_output[:, i, :, :], utt_attn_mask[:, i, :]))
            output.append(o)
        context_encoder_output = torch.cat(output,
                                           dim=1).to(utt_floor_encs.device)
    else:
        utt_floor_encs = pack_padded_sequence(utt_floor_encs,
                                              context_lens,
                                              batch_first=True,
                                              enforce_sorted=False)
        context_encoder_output, new_hidden = self.rnn(
            utt_floor_encs, new_hidden)
        context_encoder_output, _ = pad_packed_sequence(
            context_encoder_output,
            batch_first=True,
            total_length=max_context_len)
    new_hidden = _transpose_hidden_state(new_hidden)
    if self.rnn_class == 'lstm':
        enc = new_hidden[0]
    else:
        enc = new_hidden
    enc = enc.contiguous().view(batch_size, -1)
    if noise and self.noise_radius > 0:
        gauss_noise = gVar(
            torch.normal(mean=torch.zeros(enc.size()), std=self.noise_radius),
            self.use_cuda)
        enc = enc + gauss_noise
    return enc, (context_encoder_output, new_hidden, context_mask)
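# Illustrative sketch (not part of the model code): building the speaker-floor
# one-hot features with scatter_, as in the context encoder above. Shapes and
# floor values here are made up.
import torch

batch_size, max_context_len = 2, 3
floors = torch.tensor([[0, 1, 0], [1, 1, 0]])   # speaker id per utterance
floor_one_hot = torch.zeros(floors.numel(), 2)
floor_one_hot.scatter_(1, floors.view(-1, 1), 1)
floor_one_hot = floor_one_hot.view(-1, max_context_len, 2)
assert torch.equal(floor_one_hot[0, 1], torch.tensor([0.0, 1.0]))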