def encode(self, input_sequence, length):
    batch_size = input_sequence.size(0)
    sorted_lengths, sorted_idx = torch.sort(length, descending=True)
    input_sequence = input_sequence[sorted_idx]

    # ENCODER
    input_embedding = self.embedding(input_sequence)
    packed_input = rnn_utils.pack_padded_sequence(
        input_embedding, sorted_lengths.data.tolist(), batch_first=True)
    _, hidden = self.encoder_rnn(packed_input)

    if self.bidirectional or self.num_layers > 1:
        # flatten hidden state
        hidden = hidden.view(batch_size, self.hidden_size * self.hidden_factor)
    else:
        hidden = hidden.squeeze()

    # REPARAMETERIZATION
    mean = self.hidden2mean(hidden)
    logv = self.hidden2logv(hidden)
    std = torch.exp(0.5 * logv)

    z = to_var(torch.randn([batch_size, self.latent_size]))
    z = z * std + mean

    return mean, std, z
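# --------------------------------------------------------------------------
# Minimal, standalone sketch (not part of the class above) of the
# sort-then-pack pattern used by encode()/forward(): older PyTorch versions of
# pack_padded_sequence require the batch sorted by decreasing length, so the
# batch is sorted before packing and reversed_idx is kept to restore the
# original order afterwards (newer versions accept enforce_sorted=False
# instead). Names `embedded` and `lengths` are illustrative only.
# --------------------------------------------------------------------------
import torch
import torch.nn.utils.rnn as rnn_utils


def pack_sorted_sketch(embedded, lengths):
    sorted_lengths, sorted_idx = torch.sort(lengths, descending=True)
    packed = rnn_utils.pack_padded_sequence(
        embedded[sorted_idx], sorted_lengths.tolist(), batch_first=True)
    _, reversed_idx = torch.sort(sorted_idx)  # used later to restore batch order
    return packed, reversed_idx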
def encode_condition(self, cond_sequence, cond_length):
    assert self.is_conditional and cond_sequence is not None and cond_length is not None
    batch_size = cond_sequence.size(0)
    sorted_lengths, sorted_idx = torch.sort(cond_length, descending=True)
    cond_sequence = cond_sequence[sorted_idx]
    _, reversed_idx = torch.sort(sorted_idx)

    # -------------------- CONDITIONAL ENCODER ------------------------
    cond_embedding = self.cond_embedding(cond_sequence)
    packed_input = rnn_utils.pack_padded_sequence(
        cond_embedding, sorted_lengths.data.tolist(), batch_first=True)
    _, hidden = self.cond_encoder_rnn(packed_input)
    hidden = self._reshape_hidden_for_bidirection(hidden, batch_size, self.cond_hidden_size)
    hidden = hidden[reversed_idx]  # restore original batch order
    assert hidden.size(0) == batch_size
    assert hidden.size(1) == self.cond_hidden_size

    # REPARAMETERIZATION
    mean = self.cond_hidden2mean(hidden)
    logv = self.cond_hidden2logv(hidden)
    std = torch.exp(0.5 * logv)

    z = to_var(torch.randn([batch_size, self.latent_size]))
    z = z * std + mean

    return hidden, mean, logv, z
def encode(self, input_sequence, length, extra_hidden=None):
    batch_size = input_sequence.size(0)
    sorted_lengths, sorted_idx = torch.sort(length, descending=True)
    input_sequence = input_sequence[sorted_idx]
    _, reversed_idx = torch.sort(sorted_idx)

    # -------------------- ENCODER ------------------------
    input_embedding = self.embedding(input_sequence)
    packed_input = rnn_utils.pack_padded_sequence(
        input_embedding, sorted_lengths.data.tolist(), batch_first=True)
    _, hidden = self.encoder_rnn(packed_input)
    hidden = self._reshape_hidden_for_bidirection(hidden, batch_size, self.hidden_size)
    hidden = hidden[reversed_idx]  # restore original batch order
    assert hidden.size(0) == batch_size
    assert hidden.size(1) == self.hidden_size

    if extra_hidden is not None:
        assert self.is_conditional, 'extra_hidden was passed but is_conditional is disabled'
        hidden = torch.cat([hidden, extra_hidden], dim=1)

    # REPARAMETERIZATION
    mean = self.hidden2mean(hidden)
    logv = self.hidden2logv(hidden)
    std = torch.exp(0.5 * logv)

    z = to_var(torch.randn([batch_size, self.latent_size]))
    z = z * std + mean

    return mean, logv, z
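# --------------------------------------------------------------------------
# Hypothetical sketch of how the two encoders are presumably combined in the
# conditional case, inferred only from the signatures above (this wiring is an
# assumption, not code from this file): the condition encoder's hidden state
# is passed to encode() as extra_hidden, so hidden2mean/hidden2logv operate on
# the concatenation [hidden; cond_hidden].
# --------------------------------------------------------------------------
def conditional_encode_sketch(model, input_sequence, length, cond_sequence, cond_length):
    cond_hidden, cond_mean, cond_logv, _ = model.encode_condition(cond_sequence, cond_length)
    mean, logv, z = model.encode(input_sequence, length, extra_hidden=cond_hidden)
    return mean, logv, z, cond_mean, cond_logv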
def hidden2latent(self, hidden):
    # --------------- REPARAMETERIZATION ------------------
    batch_size = hidden.size(0)
    mean = self.hidden2mean(hidden)
    logv = self.hidden2logv(hidden)
    std = torch.exp(0.5 * logv)

    z = to_var(torch.randn([batch_size, self.latent_size]))
    z = z * std + mean

    return mean, logv, z
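# --------------------------------------------------------------------------
# Minimal, standalone sketch (not part of the class above) of the
# reparameterization trick repeated in the encode*/hidden2latent methods:
# with logv = log(sigma^2), a sample z ~ N(mean, sigma^2 I) is written as
# z = mean + sigma * eps with eps ~ N(0, I), so gradients can flow through
# mean and logv while the randomness stays in eps.
# --------------------------------------------------------------------------
import torch


def reparameterize_sketch(mean, logv):
    std = torch.exp(0.5 * logv)   # sigma = exp(0.5 * log(sigma^2))
    eps = torch.randn_like(mean)  # noise independent of the parameters
    return mean + std * eps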
def inference(self, n=4, z=None):
    if z is None:
        batch_size = n
        z = to_var(torch.randn([batch_size, self.latent_size]))
    else:
        batch_size = z.size(0)

    hidden = self.latent2hidden(z)

    if self.bidirectional or self.num_layers > 1:
        # unflatten hidden state
        hidden = hidden.view(self.hidden_factor, batch_size, self.hidden_size)
    else:
        hidden = hidden.unsqueeze(0)

    # required for dynamic stopping of sentence generation
    sequence_idx = torch.arange(0, batch_size, out=self.tensor()).long()  # all idx of batch
    # all idx of batch which are still generating
    sequence_running = torch.arange(0, batch_size, out=self.tensor()).long()
    sequence_mask = torch.ones(batch_size, out=self.tensor()).bool()
    # idx of still generating sequences with respect to current loop
    running_seqs = torch.arange(0, batch_size, out=self.tensor()).long()

    generations = self.tensor(batch_size, self.max_sequence_length).fill_(self.pad_idx).long()

    t = 0
    while t < self.max_sequence_length and len(running_seqs) > 0:
        if t == 0:
            input_sequence = to_var(torch.Tensor(batch_size).fill_(self.sos_idx).long())

        input_sequence = input_sequence.unsqueeze(1)
        input_embedding = self.embedding(input_sequence)
        output, hidden = self.decoder_rnn(input_embedding, hidden)
        logits = self.outputs2vocab(output)
        input_sequence = self._sample(logits)

        # save next input
        generations = self._save_sample(generations, input_sequence, sequence_running, t)

        # update global running sequences
        sequence_mask[sequence_running] = (input_sequence != self.eos_idx).data
        sequence_running = sequence_idx.masked_select(sequence_mask)

        # update local running sequences
        running_mask = (input_sequence != self.eos_idx).data
        running_seqs = running_seqs.masked_select(running_mask)

        # prune input and hidden state according to local update
        if len(running_seqs) > 0:
            input_sequence = input_sequence.view(-1)[running_seqs]
            hidden = hidden[:, running_seqs]

            running_seqs = torch.arange(0, len(running_seqs), out=self.tensor()).long()

        t += 1

    return generations, z
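# --------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the class above): sampling from the
# prior just means leaving z=None so inference() draws z ~ N(0, I) itself.
# `model` and `idx2word` are assumed to exist elsewhere; the special-token
# filtering here is illustrative only.
# --------------------------------------------------------------------------
def sample_from_prior_sketch(model, idx2word, n=10):
    generations, z = model.inference(n=n)  # (n, max_sequence_length) token ids
    sentences = []
    for row in generations.tolist():
        tokens = [idx2word[idx] for idx in row
                  if idx not in (model.pad_idx, model.sos_idx, model.eos_idx)]
        sentences.append(" ".join(tokens))
    return sentences, z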
def forward(self, input_sequence, length):
    batch_size = input_sequence.size(0)
    sorted_lengths, sorted_idx = torch.sort(length, descending=True)
    input_sequence = input_sequence[sorted_idx]

    # ENCODER
    input_embedding = self.embedding(input_sequence)
    packed_input = rnn_utils.pack_padded_sequence(
        input_embedding, sorted_lengths.data.tolist(), batch_first=True)
    _, hidden = self.encoder_rnn(packed_input)

    if self.bidirectional or self.num_layers > 1:
        # flatten hidden state
        hidden = hidden.view(batch_size, self.hidden_size * self.hidden_factor)
    else:
        hidden = hidden.squeeze()

    # REPARAMETERIZATION
    mean = self.hidden2mean(hidden)
    logv = self.hidden2logv(hidden)
    std = torch.exp(0.5 * logv)

    z = to_var(torch.randn([batch_size, self.latent_size]))
    z = z * std + mean

    # DECODER
    hidden = self.latent2hidden(z)

    if self.bidirectional or self.num_layers > 1:
        # unflatten hidden state
        hidden = hidden.view(self.hidden_factor, batch_size, self.hidden_size)
    else:
        hidden = hidden.unsqueeze(0)

    # decoder input
    if self.word_dropout_rate > 0:
        # randomly replace decoder input with <unk>
        prob = torch.rand(input_sequence.size())
        if torch.cuda.is_available():
            prob = prob.cuda()
        prob[(input_sequence.data - self.sos_idx) * (input_sequence.data - self.pad_idx) == 0] = 1
        decoder_input_sequence = input_sequence.clone()
        decoder_input_sequence[prob < self.word_dropout_rate] = self.unk_idx
        input_embedding = self.embedding(decoder_input_sequence)
    input_embedding = self.embedding_dropout(input_embedding)
    packed_input = rnn_utils.pack_padded_sequence(
        input_embedding, sorted_lengths.data.tolist(), batch_first=True)

    # decoder forward pass
    outputs, _ = self.decoder_rnn(packed_input, hidden)

    # process outputs
    padded_outputs = rnn_utils.pad_packed_sequence(outputs, batch_first=True)[0]
    padded_outputs = padded_outputs.contiguous()
    _, reversed_idx = torch.sort(sorted_idx)
    padded_outputs = padded_outputs[reversed_idx]  # restore original batch order
    b, s, _ = padded_outputs.size()

    # project outputs to vocab
    logp = nn.functional.log_softmax(
        self.outputs2vocab(padded_outputs.view(-1, padded_outputs.size(2))), dim=-1)
    logp = logp.view(b, s, self.embedding.num_embeddings)

    return logp, mean, logv, z
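# --------------------------------------------------------------------------
# Hypothetical sketch (not part of the class above) of how the tensors
# returned by forward() are typically combined into the VAE objective:
# token-level NLL over logp plus the analytic KL of N(mean, exp(logv)) from
# the standard normal prior. The target slicing, reduction, and kl_weight
# (e.g. for KL annealing) are assumptions for illustration, not code from
# this file.
# --------------------------------------------------------------------------
import torch


def vae_loss_sketch(logp, target, mean, logv, pad_idx, kl_weight=1.0):
    # align target length with the padded decoder output, then flatten
    target = target[:, :logp.size(1)].contiguous().view(-1)
    logp = logp.view(-1, logp.size(2))

    nll = torch.nn.functional.nll_loss(
        logp, target, ignore_index=pad_idx, reduction='sum')
    kl = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
    return nll + kl_weight * kl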