Example #1
    def forward(self, utterances):
        """Embeds a batch of utterances.

        Args:
            utterances (list[list[unicode]]): a batch of sentences, where each
                sentence is a list of unicode tokens.

        Returns:
            Variable[FloatTensor]: batch x lstm_dim
                (concatenated first and last hidden states)
        """
        # Truncate each utterance to max_words, append EOS, and look up indices
        utterances = [
            utterance[:self._max_words] + [EOS] for utterance in utterances
        ]
        token_indices = SequenceBatch.from_sequences(
            utterances, self._token_embedder.vocab)
        # batch x seq_len x token_embed_dim
        token_embeds = self._token_embedder.embed_seq_batch(token_indices)
        bi_hidden_states = self._bilstm(token_embeds.split())
        final_states = torch.cat(bi_hidden_states.final_states, 1)

        hidden_states = SequenceBatch.cat(bi_hidden_states.combined_states)
        return self._attention(hidden_states, final_states).context
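
A minimal stand-in for the bi-LSTM step above, written against plain torch.nn (the dimension names are illustrative, not from the repo):

    import torch
    import torch.nn as nn

    # Run a bidirectional LSTM and concatenate the final forward and
    # backward hidden states per example, as forward() above does.
    lstm_dim, token_embed_dim, batch, seq_len = 64, 32, 5, 7
    bilstm = nn.LSTM(token_embed_dim, lstm_dim // 2,
                     bidirectional=True, batch_first=True)
    token_embeds = torch.randn(batch, seq_len, token_embed_dim)
    outputs, (h_n, _) = bilstm(token_embeds)           # h_n: (2, batch, lstm_dim // 2)
    final_states = torch.cat([h_n[0], h_n[1]], dim=1)  # (batch, lstm_dim)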
Example #2
    def encoder_generate_edits(self, encoder_input):
        """ Draw uniform random vectors with given norm, and use as edit vector """
        source_words = encoder_input.source_words
        source_word_embeds = self.editor.encoder.token_embedder.embed_seq_batch(source_words)
        insert_embeds = self.editor.encoder.token_embedder.embed_seq_batch(encoder_input.insert_words)
        delete_embeds = self.editor.encoder.token_embedder.embed_seq_batch(encoder_input.delete_words)

        insert_embeds_exact = self.editor.encoder.token_embedder.embed_seq_batch(encoder_input.insert_exact_words)
        delete_embeds_exact = self.editor.encoder.token_embedder.embed_seq_batch(encoder_input.delete_exact_words)

        source_encoder_output = self.editor.encoder.source_encoder(source_word_embeds.split())
        source_embeds_list = source_encoder_output.combined_states
        source_embeds = SequenceBatch.cat(source_embeds_list)
        # the final hidden states in both the forward and backward direction, concatenated
        source_embeds_final = torch.cat(source_encoder_output.final_states, 1)  # (batch_size, hidden_dim)

        edit_encoded = self.editor.encoder.edit_encoder(insert_embeds, insert_embeds_exact, delete_embeds,
                                                        delete_embeds_exact)

        # the random vector is computed as in rand_p_noise (see edit_encoder)
        torch.manual_seed(7)
        batch_size, edit_dim = edit_encoded.size()
        rand_draw = GPUVariable(torch.randn(batch_size, edit_dim))
        rand_draw = rand_draw / torch.norm(rand_draw, p=2, dim=1).expand(batch_size, edit_dim)
        rand_norms = (torch.rand(batch_size, 1) * self.editor.encoder.edit_encoder.norm_max).expand(batch_size,
                                                                                                    edit_dim)
        edit_embed = rand_draw * GPUVariable(rand_norms)

        agenda = self.editor.encoder.agenda_maker(source_embeds_final, edit_embed)
        return EncoderOutput(source_embeds, insert_embeds_exact, delete_embeds_exact, agenda)
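
The sampling block above can be read as a standalone helper; a sketch in the modern torch API (the function name is hypothetical):

    import torch

    def sample_edit_vector(batch_size, edit_dim, norm_max):
        # rand_p_noise-style draw: a uniformly random direction, scaled by
        # a norm drawn uniformly from [0, norm_max).
        rand_draw = torch.randn(batch_size, edit_dim)
        rand_draw = rand_draw / rand_draw.norm(p=2, dim=1, keepdim=True)
        rand_norms = torch.rand(batch_size, 1) * norm_max
        return rand_draw * rand_norms  # (batch_size, edit_dim)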
Example #3
    def generate_edits(self, encoder_input, norm):
        """ Draw uniform random vectors with given norm, and use as edit vector """
        source_words = encoder_input.source_words
        source_word_embeds = self.token_embedder.embed_seq_batch(source_words)
        insert_embeds = self.token_embedder.embed_seq_batch(
            encoder_input.insert_words)
        delete_embeds = self.token_embedder.embed_seq_batch(
            encoder_input.delete_words)

        insert_embeds_exact = self.token_embedder.embed_seq_batch(
            encoder_input.insert_exact_words)
        delete_embeds_exact = self.token_embedder.embed_seq_batch(
            encoder_input.delete_exact_words)

        source_encoder_output = self.source_encoder(source_word_embeds.split())
        source_embeds_list = source_encoder_output.combined_states
        source_embeds = SequenceBatch.cat(source_embeds_list)
        # the final hidden states in both the forward and backward direction, concatenated
        source_embeds_final = torch.cat(source_encoder_output.final_states,
                                        1)  # (batch_size, hidden_dim)

        edit_encoded = self.edit_encoder(insert_embeds, delete_embeds)

        rand_vec = torch.randn(edit_encoded.size())
        edit_embed = GPUVariable(
            rand_vec / torch.norm(rand_vec, 2, dim=1).expand_as(rand_vec) *
            norm)
        agenda = self.agenda_maker(source_embeds_final, edit_embed)
        return EncoderOutput(source_embeds, insert_embeds_exact,
                             delete_embeds_exact, agenda)
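
Unlike Example #2, every row here gets exactly the requested norm. A self-contained sketch of that rescaling step (names are illustrative):

    import torch

    def fixed_norm_vector(batch_size, edit_dim, norm):
        # Random direction, rescaled so every row has L2 length `norm`.
        v = torch.randn(batch_size, edit_dim)
        return v / v.norm(p=2, dim=1, keepdim=True) * norm

    v = fixed_norm_vector(4, 128, norm=3.0)
    print(v.norm(p=2, dim=1))  # every entry is ~3.0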
Example #4
    def make_embedding(self, encoder_input, words_list, encoder):
        """Encoder for a single `channel'
        """
        channel_word_embeds = encoder_input.token_embedder.embed_seq_batch(words_list)
        source_encoder_output = encoder(channel_word_embeds.split())

        channel_embeds_list = source_encoder_output.combined_states
        channel_embeds = SequenceBatch.cat(channel_embeds_list)

        # the final hidden states in both the forward and backward direction, concatenated
        channel_embeds_final = torch.cat(source_encoder_output.final_states, 1)  # (batch_size, hidden_dim)
        return channel_embeds, channel_embeds_final
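
A generic, runnable version of this helper, assuming plain nn.Modules in place of the repo's embedder and encoder:

    import torch
    import torch.nn as nn

    def make_channel_embedding(embedder, encoder, words):
        # Embed one channel's token indices, run its bidirectional encoder,
        # and return the per-step states plus the concatenated final states.
        embeds = embedder(words)                    # (batch, seq_len, embed_dim)
        states, (h_n, _) = encoder(embeds)
        final = torch.cat([h_n[0], h_n[1]], dim=1)  # (batch, hidden_dim)
        return states, final

    embedder = nn.Embedding(100, 16)
    encoder = nn.LSTM(16, 8, bidirectional=True, batch_first=True)
    states, final = make_channel_embedding(
        embedder, encoder, torch.randint(0, 100, (4, 6)))
    print(states.shape, final.shape)  # (4, 6, 16) and (4, 16)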
Example #5
    def forward(self, encoder_input, draw_samples=False, draw_p=False):
        """Encode.

        Args:
            encoder_input (EncoderInput)
            draw_samples (bool): whether to add noise for the variational approximation; disable at test time.
            draw_p (bool): forwarded to the edit encoder.

        Returns:
            EncoderOutput
        """
        source_words = encoder_input.source_words
        source_word_embeds = self.token_embedder.embed_seq_batch(source_words)
        source_encoder_output = self.source_encoder(source_word_embeds.split())
        source_embeds_list = source_encoder_output.combined_states
        source_embeds = SequenceBatch.cat(source_embeds_list)
        # the final hidden states in both the forward and backward direction, concatenated
        source_embeds_final = torch.cat(source_encoder_output.final_states,
                                        1)  # (batch_size, hidden_dim)

        insert_embeds = self.token_embedder.embed_seq_batch(
            encoder_input.insert_words)
        delete_embeds = self.token_embedder.embed_seq_batch(
            encoder_input.delete_words)

        insert_embeds_exact = self.token_embedder.embed_seq_batch(
            encoder_input.insert_exact_words)
        delete_embeds_exact = self.token_embedder.embed_seq_batch(
            encoder_input.delete_exact_words)

        insert_noisy_exact = self.edit_encoder.seq_batch_noise(
            insert_embeds_exact, draw_samples)
        delete_noisy_exact = self.edit_encoder.seq_batch_noise(
            delete_embeds_exact, draw_samples)

        batch_size, _ = source_embeds_final.size()

        if self.kill_edit:
            edit_embed = GPUVariable(torch.zeros(batch_size, self.edit_dim))
        else:
            if encoder_input.edit_embed is None:
                edit_embed = self.edit_encoder(insert_embeds,
                                               insert_embeds_exact,
                                               delete_embeds,
                                               delete_embeds_exact,
                                               draw_samples, draw_p)
            else:
                # bypass the edit_encoder
                edit_embed = encoder_input.edit_embed

        agenda = self.agenda_maker(source_embeds_final, edit_embed)
        return EncoderOutput(source_embeds, insert_noisy_exact,
                             delete_noisy_exact, agenda)
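
The edit-vector branching above reduces to a small decision function; a sketch with illustrative names:

    import torch

    def choose_edit_embed(kill_edit, precomputed_embed, edit_encoder_fn,
                          batch_size, edit_dim):
        if kill_edit:
            # edits disabled: a zero vector of the right shape
            return torch.zeros(batch_size, edit_dim)
        if precomputed_embed is not None:
            # bypass the edit encoder, as in the else-branch above
            return precomputed_embed
        return edit_encoder_fn()

    edit = choose_edit_embed(False, None, lambda: torch.randn(4, 128), 4, 128)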
Example #6
    def forward(self, encoder_output, train_decoder_input):
        """Run the decoder over a training batch and compute per-token losses.

        Args:
            encoder_output (EncoderOutput)
            train_decoder_input (TrainDecoderInput)

        Returns:
            vocab_probs (list[Variable]): per-step vocabulary distributions
            rnn_states (list[RNNState])
            losses (SequenceBatch): per-token losses, (batch_size, target_seq_length)
        """
        batch_size, _ = train_decoder_input.input_words.mask.size()
        rnn_state = self.decoder_cell.initialize(batch_size)

        input_word_embeds = encoder_output.token_embedder.embed_seq_batch(
            train_decoder_input.input_words)

        input_embed_list = input_word_embeds.split()
        target_word_list = train_decoder_input.target_words.split()

        loss_list = []
        rnn_states = []
        vocab_probs = []
        for t, (x, target_word) in enumerate(
                izip(input_embed_list, target_word_list)):
            # x is a (batch_size, word_dim) SequenceBatchElement, target_word is a (batch_size,) Variable

            # update rnn state
            rnn_input = self.rnn_context_combiner(encoder_output, x.values)
            decoder_cell_output = self.decoder_cell(rnn_state, rnn_input,
                                                    x.mask)
            rnn_state = decoder_cell_output.rnn_state
            rnn_states.append(rnn_state)
            vocab_pr = decoder_cell_output.vocab_probs
            vocab_probs.append(vocab_pr)

            # compute loss
            loss = decoder_cell_output.loss(
                target_word.values)  # (batch_size,)
            loss_list.append(SequenceBatchElement(loss, x.mask))

        losses = SequenceBatch.cat(
            loss_list)  # (batch_size, target_seq_length)

        return vocab_probs, rnn_states, losses
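
The returned losses are per token and masked by sequence length; a sketch of reducing them to a single scalar loss (shapes follow the comments above):

    import torch

    losses = torch.tensor([[0.5, 1.0, 2.0],
                           [0.3, 0.7, 0.0]])  # (batch, target_seq_length)
    mask = torch.tensor([[1., 1., 1.],
                         [1., 1., 0.]])
    per_example = (losses * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
    total_loss = per_example.mean()  # scalar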
Example #7
    def test_cat(self):
        x1 = SequenceBatchElement(
            GPUVariable(torch.FloatTensor([
                [[1, 2], [3, 4]],
                [[8, 2], [9, 0]]])),
            GPUVariable(torch.FloatTensor([
                [1],
                [1]
            ])))
        x2 = SequenceBatchElement(
            GPUVariable(torch.FloatTensor([
                [[-1, 20], [3, 40]],
                [[-8, 2], [9, 10]]])),
            GPUVariable(torch.FloatTensor([
                [1],
                [0]
            ])))
        x3 = SequenceBatchElement(
            GPUVariable(torch.FloatTensor([
                [[-1, 20], [3, 40]],
                [[-8, 2], [9, 10]]])),
            GPUVariable(torch.FloatTensor([
                [0],
                [0]
            ])))

        result = SequenceBatch.cat([x1, x2, x3])

        assert_tensor_equal(result.values,
                            [
                                [[[1, 2], [3, 4]], [[-1, 20], [3, 40]], [[-1, 20], [3, 40]]],
                                [[[8, 2], [9, 0]], [[-8, 2], [9, 10]], [[-8, 2], [9, 10]]],
                            ])

        assert_tensor_equal(result.mask,
                            [
                                [1, 1, 0],
                                [1, 0, 0]
                            ])
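
From the expected values and mask, SequenceBatch.cat evidently stacks each element's values along a new time dimension and concatenates the per-step masks; a minimal reimplementation sketch (not the repo's code):

    from collections import namedtuple
    import torch

    Element = namedtuple('Element', ['values', 'mask'])  # stand-in for SequenceBatchElement

    def cat_elements(elements):
        values = torch.stack([e.values for e in elements], dim=1)  # (batch, T, ...)
        mask = torch.cat([e.mask for e in elements], dim=1)        # (batch, T)
        return values, mask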
Example #8
    def warp_edit_vec(self, edit_embed, encoder_input):
        """ Wrap a given edit vector and generate encoder outputs """
        source_words = encoder_input.source_words
        source_word_embeds = self.token_embedder.embed_seq_batch(source_words)
        insert_embeds = self.token_embedder.embed_seq_batch(encoder_input.insert_words)
        delete_embeds = self.token_embedder.embed_seq_batch(encoder_input.delete_words)

        insert_embeds_exact = self.token_embedder.embed_seq_batch(encoder_input.insert_exact_words)
        delete_embeds_exact = self.token_embedder.embed_seq_batch(encoder_input.delete_exact_words)

        source_encoder_output = self.source_encoder(source_word_embeds.split())
        source_embeds_list = source_encoder_output.combined_states
        source_embeds = SequenceBatch.cat(source_embeds_list)
        # the final hidden states in both the forward and backward direction, concatenated
        source_embeds_final = torch.cat(source_encoder_output.final_states, 1)  # (batch_size, hidden_dim)

        agenda = self.agenda_maker(source_embeds_final, edit_embed)

        # the agenda is run through 2 different linear transformations to get lambda and v
        agenda_l = self.agenda_lin1(agenda)
        agenda_v = self.agenda_lin2(agenda)  # assumes a distinct agenda_lin2 layer, per the comment above

        return EncoderOutput(source_embeds, insert_embeds_exact, delete_embeds_exact, (agenda_l, agenda_v))
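
The two-head projection at the end is just two independent linear layers over the same agenda; a self-contained sketch (layer names are illustrative, matching the assumption above):

    import torch
    import torch.nn as nn

    agenda_dim = 256
    agenda_lin1 = nn.Linear(agenda_dim, agenda_dim)  # produces lambda
    agenda_lin2 = nn.Linear(agenda_dim, agenda_dim)  # produces v
    agenda = torch.randn(4, agenda_dim)
    agenda_l, agenda_v = agenda_lin1(agenda), agenda_lin2(agenda)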