def forward(self, utterances):
    """Embeds a batch of utterances.

    Args:
        utterances (list[list[unicode]]): list[unicode] is a list of tokens
            forming a sentence. list[list[unicode]] is a batch of sentences.

    Returns:
        Variable[FloatTensor]: batch x lstm_dim (concatenated first and last
            hidden states)
    """
    # truncate to max_words, append EOS, and look up vocab indices
    utterances = [utterance[:self._max_words] + [EOS] for utterance in utterances]
    token_indices = SequenceBatch.from_sequences(utterances, self._token_embedder.vocab)

    # (batch_size, seq_len, token_embed_dim)
    token_embeds = self._token_embedder.embed_seq_batch(token_indices)

    bi_hidden_states = self._bilstm(token_embeds.split())
    final_states = torch.cat(bi_hidden_states.final_states, 1)
    hidden_states = SequenceBatch.cat(bi_hidden_states.combined_states)
    return self._attention(hidden_states, final_states).context
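# A tiny standalone sketch of the truncate-and-EOS step above, assuming EOS is
# a plain string token ('</s>' below is a hypothetical placeholder for the
# module-level EOS constant):
def _demo_truncate_with_eos():
    EOS = '</s>'  # stand-in for the module-level EOS token
    max_words = 3
    utterances = [['the', 'cat', 'sat', 'down'], ['hi']]
    clipped = [u[:max_words] + [EOS] for u in utterances]
    assert clipped == [['the', 'cat', 'sat', '</s>'], ['hi', '</s>']]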
def encoder_generate_edits(self, encoder_input):
    """Draw uniform random vectors with a given norm and use them as edit vectors."""
    source_words = encoder_input.source_words
    token_embedder = self.editor.encoder.token_embedder
    source_word_embeds = token_embedder.embed_seq_batch(source_words)
    insert_embeds = token_embedder.embed_seq_batch(encoder_input.insert_words)
    delete_embeds = token_embedder.embed_seq_batch(encoder_input.delete_words)
    insert_embeds_exact = token_embedder.embed_seq_batch(encoder_input.insert_exact_words)
    delete_embeds_exact = token_embedder.embed_seq_batch(encoder_input.delete_exact_words)

    source_encoder_output = self.editor.encoder.source_encoder(source_word_embeds.split())
    source_embeds_list = source_encoder_output.combined_states
    source_embeds = SequenceBatch.cat(source_embeds_list)
    # the final hidden states in both the forward and backward direction, concatenated
    source_embeds_final = torch.cat(source_encoder_output.final_states, 1)  # (batch_size, hidden_dim)

    # only needed to recover the edit vector's dimensions below
    edit_encoded = self.editor.encoder.edit_encoder(
        insert_embeds, insert_embeds_exact, delete_embeds, delete_embeds_exact)

    # the random vector is computed as in rand_p_noise (see edit_encoder): a
    # Gaussian draw projected onto the unit sphere, scaled by a radius drawn
    # uniformly from [0, norm_max); the fixed seed makes the draw deterministic
    torch.manual_seed(7)
    batch_size, edit_dim = edit_encoded.size()
    rand_draw = GPUVariable(torch.randn(batch_size, edit_dim))
    rand_draw = rand_draw / torch.norm(rand_draw, p=2, dim=1).expand(batch_size, edit_dim)
    rand_norms = (torch.rand(batch_size, 1) * self.editor.encoder.edit_encoder.norm_max).expand(batch_size, edit_dim)
    edit_embed = rand_draw * GPUVariable(rand_norms)

    agenda = self.editor.encoder.agenda_maker(source_embeds_final, edit_embed)
    return EncoderOutput(source_embeds, insert_embeds_exact, delete_embeds_exact, agenda)
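# A self-contained sketch of the random draw above (the rand_p_noise scheme):
# sample a Gaussian vector, project it onto the unit sphere, then scale by a
# radius drawn uniformly from [0, norm_max). Written against the modern torch
# tensor API rather than the Variable-era API used above.
def _demo_rand_p_noise(batch_size=4, edit_dim=8, norm_max=10.0):
    import torch
    direction = torch.randn(batch_size, edit_dim)
    direction = direction / torch.norm(direction, p=2, dim=1, keepdim=True)
    radius = torch.rand(batch_size, 1) * norm_max  # uniform in [0, norm_max)
    return direction * radius  # (batch_size, edit_dim)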
def generate_edits(self, encoder_input, norm):
    """Draw uniform random vectors with the given norm and use them as edit vectors."""
    source_words = encoder_input.source_words
    source_word_embeds = self.token_embedder.embed_seq_batch(source_words)
    insert_embeds = self.token_embedder.embed_seq_batch(encoder_input.insert_words)
    delete_embeds = self.token_embedder.embed_seq_batch(encoder_input.delete_words)
    insert_embeds_exact = self.token_embedder.embed_seq_batch(encoder_input.insert_exact_words)
    delete_embeds_exact = self.token_embedder.embed_seq_batch(encoder_input.delete_exact_words)

    source_encoder_output = self.source_encoder(source_word_embeds.split())
    source_embeds_list = source_encoder_output.combined_states
    source_embeds = SequenceBatch.cat(source_embeds_list)
    # the final hidden states in both the forward and backward direction, concatenated
    source_embeds_final = torch.cat(source_encoder_output.final_states, 1)  # (batch_size, hidden_dim)

    # only needed to recover the edit vector's dimensions below
    edit_encoded = self.edit_encoder(insert_embeds, delete_embeds)

    # draw a Gaussian vector, project onto the unit sphere, rescale to `norm`
    rand_vec = torch.randn(edit_encoded.size())
    edit_embed = GPUVariable(
        rand_vec / torch.norm(rand_vec, 2, dim=1).expand_as(rand_vec) * norm)

    agenda = self.agenda_maker(source_embeds_final, edit_embed)
    return EncoderOutput(source_embeds, insert_embeds_exact, delete_embeds_exact, agenda)
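# The fixed-norm variant used by generate_edits, as a standalone sketch: the
# direction is uniform on the unit sphere, and every vector is rescaled to
# exactly `norm` rather than a random radius. Modern torch API assumed.
def _demo_fixed_norm_draw(batch_size=4, edit_dim=8, norm=5.0):
    import torch
    v = torch.randn(batch_size, edit_dim)
    return v / torch.norm(v, p=2, dim=1, keepdim=True) * norm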
def make_embedding(self, encoder_input, words_list, encoder):
    """Encode a single 'channel' (e.g. insert words or delete words)."""
    channel_word_embeds = encoder_input.token_embedder.embed_seq_batch(words_list)
    source_encoder_output = encoder(channel_word_embeds.split())

    channel_embeds_list = source_encoder_output.combined_states
    channel_embeds = SequenceBatch.cat(channel_embeds_list)
    # the final hidden states in both the forward and backward direction, concatenated
    channel_embeds_final = torch.cat(source_encoder_output.final_states, 1)  # (batch_size, hidden_dim)
    return channel_embeds, channel_embeds_final
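# Hypothetical usage of make_embedding: each edit "channel" (insert words,
# delete words) is encoded independently by the same helper. The encoder
# argument below is illustrative, not an attribute name from this codebase:
#
#     insert_embeds, insert_final = self.make_embedding(
#         encoder_input, encoder_input.insert_words, channel_encoder)
#     delete_embeds, delete_final = self.make_embedding(
#         encoder_input, encoder_input.delete_words, channel_encoder)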
def forward(self, encoder_input, draw_samples=False, draw_p=False):
    """Encode.

    Args:
        encoder_input (EncoderInput)
        draw_samples (bool): whether to add noise for the variational
            approximation. Disable at test time.
        draw_p (bool): passed through to the edit encoder.

    Returns:
        EncoderOutput
    """
    source_words = encoder_input.source_words
    source_word_embeds = self.token_embedder.embed_seq_batch(source_words)
    source_encoder_output = self.source_encoder(source_word_embeds.split())
    source_embeds_list = source_encoder_output.combined_states
    source_embeds = SequenceBatch.cat(source_embeds_list)
    # the final hidden states in both the forward and backward direction, concatenated
    source_embeds_final = torch.cat(source_encoder_output.final_states, 1)  # (batch_size, hidden_dim)

    insert_embeds = self.token_embedder.embed_seq_batch(encoder_input.insert_words)
    delete_embeds = self.token_embedder.embed_seq_batch(encoder_input.delete_words)
    insert_embeds_exact = self.token_embedder.embed_seq_batch(encoder_input.insert_exact_words)
    delete_embeds_exact = self.token_embedder.embed_seq_batch(encoder_input.delete_exact_words)

    insert_noisy_exact = self.edit_encoder.seq_batch_noise(insert_embeds_exact, draw_samples)
    delete_noisy_exact = self.edit_encoder.seq_batch_noise(delete_embeds_exact, draw_samples)

    batch_size, _ = source_embeds_final.size()
    if self.kill_edit:
        # ablation: zero out the edit vector entirely
        edit_embed = GPUVariable(torch.zeros(batch_size, self.edit_dim))
    elif encoder_input.edit_embed is None:
        edit_embed = self.edit_encoder(insert_embeds, insert_embeds_exact,
                                       delete_embeds, delete_embeds_exact,
                                       draw_samples, draw_p)
    else:
        # bypass the edit_encoder and use the provided edit vector
        edit_embed = encoder_input.edit_embed

    agenda = self.agenda_maker(source_embeds_final, edit_embed)
    return EncoderOutput(source_embeds, insert_noisy_exact, delete_noisy_exact, agenda)
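# A sketch of what an agenda maker with the shape used above could look like:
# a single linear map over the concatenated source summary and edit vector.
# This is an assumption about the architecture, not a copy of the actual
# agenda_maker module.
import torch
import torch.nn as nn

class SimpleAgendaMaker(nn.Module):
    def __init__(self, source_dim, edit_dim, agenda_dim):
        super(SimpleAgendaMaker, self).__init__()
        self.linear = nn.Linear(source_dim + edit_dim, agenda_dim)

    def forward(self, source_embeds_final, edit_embed):
        # (batch_size, source_dim + edit_dim) -> (batch_size, agenda_dim)
        return self.linear(torch.cat([source_embeds_final, edit_embed], 1))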
def forward(self, encoder_output, train_decoder_input):
    """
    Args:
        encoder_output (EncoderOutput)
        train_decoder_input (TrainDecoderInput)

    Returns:
        vocab_probs (list[Variable]): per-timestep vocabulary distributions
        rnn_states (list[RNNState])
        losses (SequenceBatch): per-token losses, (batch_size, target_seq_length)
    """
    batch_size, _ = train_decoder_input.input_words.mask.size()
    rnn_state = self.decoder_cell.initialize(batch_size)

    input_word_embeds = encoder_output.token_embedder.embed_seq_batch(
        train_decoder_input.input_words)
    input_embed_list = input_word_embeds.split()
    target_word_list = train_decoder_input.target_words.split()

    loss_list = []
    rnn_states = []
    vocab_probs = []
    for t, (x, target_word) in enumerate(izip(input_embed_list, target_word_list)):
        # x is a (batch_size, word_dim) SequenceBatchElement,
        # target_word is a (batch_size,) Variable

        # update the RNN state
        rnn_input = self.rnn_context_combiner(encoder_output, x.values)
        decoder_cell_output = self.decoder_cell(rnn_state, rnn_input, x.mask)
        rnn_state = decoder_cell_output.rnn_state
        rnn_states.append(rnn_state)
        vocab_probs.append(decoder_cell_output.vocab_probs)

        # compute the per-token loss at this timestep
        loss = decoder_cell_output.loss(target_word.values)  # (batch_size,)
        loss_list.append(SequenceBatchElement(loss, x.mask))

    losses = SequenceBatch.cat(loss_list)  # (batch_size, target_seq_length)
    return vocab_probs, rnn_states, losses
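# A self-contained sketch of the masked loss accumulation above: per-step
# losses are stacked into a (batch_size, seq_len) matrix and the mask zeroes
# out padding positions before any reduction. Modern torch API assumed.
def _demo_masked_token_loss(step_losses, step_masks):
    """step_losses: list of (batch_size,) tensors; step_masks: list of (batch_size, 1) tensors."""
    import torch
    losses = torch.stack(step_losses, dim=1)  # (batch_size, seq_len)
    mask = torch.cat(step_masks, dim=1)       # (batch_size, seq_len)
    return (losses * mask).sum() / mask.sum().clamp(min=1)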
def test_cat(self):
    x1 = SequenceBatchElement(
        GPUVariable(torch.FloatTensor([
            [[1, 2], [3, 4]],
            [[8, 2], [9, 0]]])),
        GPUVariable(torch.FloatTensor([
            [1],
            [1]])))
    x2 = SequenceBatchElement(
        GPUVariable(torch.FloatTensor([
            [[-1, 20], [3, 40]],
            [[-8, 2], [9, 10]]])),
        GPUVariable(torch.FloatTensor([
            [1],
            [0]])))
    x3 = SequenceBatchElement(
        GPUVariable(torch.FloatTensor([
            [[-1, 20], [3, 40]],
            [[-8, 2], [9, 10]]])),
        GPUVariable(torch.FloatTensor([
            [0],
            [0]])))

    result = SequenceBatch.cat([x1, x2, x3])

    assert_tensor_equal(result.values, [
        [[[1, 2], [3, 4]], [[-1, 20], [3, 40]], [[-1, 20], [3, 40]]],
        [[[8, 2], [9, 0]], [[-8, 2], [9, 10]], [[-8, 2], [9, 10]]],
    ])
    assert_tensor_equal(result.mask, [
        [1, 1, 0],
        [1, 0, 0],
    ])
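# What the assertions above exercise, re-implemented in plain torch under the
# assumption that SequenceBatch.cat stacks per-timestep values along a new
# time axis (dim 1) and concatenates the (batch_size, 1) masks:
def _demo_sequence_batch_cat(elements):
    import torch
    values = torch.stack([e.values for e in elements], dim=1)  # (batch, seq_len, ...)
    mask = torch.cat([e.mask for e in elements], dim=1)        # (batch, seq_len)
    return values, mask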
def warp_edit_vec(self, edit_embed, encoder_input):
    """Wrap a given edit vector and generate encoder outputs."""
    source_words = encoder_input.source_words
    source_word_embeds = self.token_embedder.embed_seq_batch(source_words)
    insert_embeds = self.token_embedder.embed_seq_batch(encoder_input.insert_words)
    delete_embeds = self.token_embedder.embed_seq_batch(encoder_input.delete_words)
    insert_embeds_exact = self.token_embedder.embed_seq_batch(encoder_input.insert_exact_words)
    delete_embeds_exact = self.token_embedder.embed_seq_batch(encoder_input.delete_exact_words)

    source_encoder_output = self.source_encoder(source_word_embeds.split())
    source_embeds_list = source_encoder_output.combined_states
    source_embeds = SequenceBatch.cat(source_embeds_list)
    # the final hidden states in both the forward and backward direction, concatenated
    source_embeds_final = torch.cat(source_encoder_output.final_states, 1)  # (batch_size, hidden_dim)

    agenda = self.agenda_maker(source_embeds_final, edit_embed)

    # the agenda is run through two different linear transformations to get lambda and v
    agenda_l = self.agenda_lin1(agenda)
    agenda_v = self.agenda_lin2(agenda)
    return EncoderOutput(source_embeds, insert_embeds_exact, delete_embeds_exact, (agenda_l, agenda_v))
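# A minimal sketch of the two-head readout assumed above: two independent
# linear layers over the same agenda, producing lambda and v. Dimensions are
# illustrative, not taken from this codebase.
import torch.nn as nn

class AgendaHeads(nn.Module):
    def __init__(self, agenda_dim, out_dim):
        super(AgendaHeads, self).__init__()
        self.agenda_lin1 = nn.Linear(agenda_dim, out_dim)
        self.agenda_lin2 = nn.Linear(agenda_dim, out_dim)

    def forward(self, agenda):
        return self.agenda_lin1(agenda), self.agenda_lin2(agenda)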