    def forward(self, input_embeds_list):
        """Runs the RNN cell over a sequence of input embeddings.

        Args:
            input_embeds_list (list[SequenceBatchElement]): each element's
                values has shape (batch_size, input_dim) and its mask has
                shape (batch_size, 1).

        Returns:
            hidden_states_list (list[SequenceBatchElement]): each element's
                values has shape (batch_size, hidden_dim).
        """
        batch_size = input_embeds_list[0].values.size()[0]

        h = tile_state(self.h0, batch_size)  # (batch_size, hidden_dim)
        c = tile_state(self.c0, batch_size)  # (batch_size, hidden_dim)

        hidden_states_list = []

        for t, x in enumerate(input_embeds_list):
            # x.values has shape (batch_size, input_dim)
            # x.mask has shape (batch_size, 1)
            h_new, c_new = self.rnn_cell(x.values, (h, c))
            h = gated_update(h, h_new, x.mask)
            c = gated_update(c, c_new, x.mask)
            hidden_states_list.append(
                SequenceBatchElement(self.dropout(h), x.mask))

        return hidden_states_list
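The loop above relies on two helpers, tile_state and gated_update, that are not shown here. A minimal sketch consistent with how they are used (broadcasting the learned initial state over the batch, and freezing the state wherever the mask is 0); this is an assumption about their behavior, not the project's actual implementation:

import torch

def tile_state(state, batch_size):
    # assumed behavior: (hidden_dim,) -> (batch_size, hidden_dim)
    return state.unsqueeze(0).expand(batch_size, state.size(0))

def gated_update(old, new, mask):
    # assumed behavior: mask (batch_size, 1) of 1s/0s selects between
    # the new state (real token) and the old state (padded position)
    return mask * new + (1 - mask) * old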
Example #2
    def clear_cache(self):
        # Keep empty tuple cached, for SequenceBatch
        self._cache.clear()
        self._cache.cache(
            [tuple()], [
                (GPUVariable(torch.zeros(self._embed_dim)),
                 SequenceBatchElement(
                     GPUVariable(torch.zeros(1, self._embed_dim)),
                     GPUVariable(torch.zeros(1)))
                 )])

    def combined_states(self):
        """Concatenates forward and backward hidden states: [forward; backward].

        Returns:
            combined_states (list[SequenceBatchElement]): ordered left to right
        """
        combined_states = [
            SequenceBatchElement(torch.cat([f.values, b.values], 1), f.mask)
            for f, b in izip(self.forward_states, self.backward_states)
        ]
        return combined_states
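Throughout these examples SequenceBatchElement is used as a plain (values, mask) pair. A minimal stand-in consistent with that usage (assumed, not necessarily the project's exact definition):

from collections import namedtuple

# values: a tensor whose leading dimension is batch (or time, depending on the call site)
# mask:   a parallel tensor of 1s (real entries) and 0s (padding)
SequenceBatchElement = namedtuple('SequenceBatchElement', ['values', 'mask'])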
Example #4
    def forward(self, utterance):
        """Embeds a batch of utterances.

        Args:
            utterance (list[list[unicode]]): list[unicode] is a list of tokens
                forming a sentence; list[list[unicode]] is a batch of sentences.

        Returns:
            Variable[FloatTensor]: batch x lstm_dim
                (concatenated first and last hidden states)
            list[SequenceBatchElement]: list of length batch, where each
                element's values is seq_len x embed_dim and mask is seq_len,
                representing the hidden states of each token.
        """
        # Make keys hashable
        utterance = [tuple(utt) for utt in utterance]

        uncached_utterances = self._cache.uncached_keys(utterance)

        # Cache the uncached utterances
        if len(uncached_utterances) > 0:
            token_indices = SequenceBatch.from_sequences(
                    uncached_utterances, self._token_embedder.vocab)
            # batch x seq_len x token_embed_dim
            token_embeds = self._token_embedder.embed_seq_batch(token_indices)

            bi_hidden_states = self._bilstm(token_embeds.split())
            final_states = torch.cat(bi_hidden_states.final_states, 1)

            # Store the combined states in batch x stuff order for caching.
            combined_states = bi_hidden_states.combined_states
            # batch x seq_len x embed_dim
            combined_values = torch.stack(
                    [state.values for state in combined_states], 1)
            # batch x seq_len
            combined_masks = torch.stack(
                    [state.mask for state in combined_states], 1)
            assert len(combined_values) == len(combined_masks)
            combined_states_by_batch = [SequenceBatchElement(
                value, mask) for value, mask in zip(
                    combined_values, combined_masks)]

            assert len(final_states) == len(combined_states_by_batch)
            # self._cache.cache(
            #     uncached_utterances,
            #     zip(final_states, combined_states_by_batch))
            self._cache.cache(
                list(uncached_utterances),
                list(zip(final_states, combined_states_by_batch)))

        final_states, combined_states = zip(*self._cache.get(utterance))
        return torch.stack(final_states, 0), combined_states
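A hypothetical usage of this embedder, with illustrative tokenization and shapes taken from the docstring above (the variable embedder and the dimensions are assumptions):

utterances = [
    [u'how', u'are', u'you'],
    [u'hello'],
]
final_states, token_states = embedder(utterances)
# final_states: (2, lstm_dim) tensor of concatenated first/last hidden states
# token_states: list of 2 SequenceBatchElements, each with values of shape
#               (seq_len, embed_dim) and mask of shape (seq_len,)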
Example #5
    def test_cat(self):
        x1 = SequenceBatchElement(
            GPUVariable(torch.FloatTensor([
                [[1, 2], [3, 4]],
                [[8, 2], [9, 0]]])),
            GPUVariable(torch.FloatTensor([
                [1],
                [1]
            ])))
        x2 = SequenceBatchElement(
            GPUVariable(torch.FloatTensor([
                [[-1, 20], [3, 40]],
                [[-8, 2], [9, 10]]])),
            GPUVariable(torch.FloatTensor([
                [1],
                [0]
            ])))
        x3 = SequenceBatchElement(
            GPUVariable(torch.FloatTensor([
                [[-1, 20], [3, 40]],
                [[-8, 2], [9, 10]]])),
            GPUVariable(torch.FloatTensor([
                [0],
                [0]
            ])))

        result = SequenceBatch.cat([x1, x2, x3])

        assert_tensor_equal(result.values,
                            [
                                [[[1, 2], [3, 4]], [[-1, 20], [3, 40]], [[-1, 20], [3, 40]]],
                                [[[8, 2], [9, 0]], [[-8, 2], [9, 10]], [[-8, 2], [9, 10]]],
                            ])

        assert_tensor_equal(result.mask,
                            [
                                [1, 1, 0],
                                [1, 0, 0]
                            ])
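Judging from the expected values and masks in this test, SequenceBatch.cat stacks the per-element values along a new dimension 1 and concatenates the (batch, 1) masks into a (batch, num_elements) mask. A sketch of that behavior (the helper name is hypothetical and it returns a plain tuple rather than a SequenceBatch):

import torch

def cat_sketch(elements):
    # elements: list of SequenceBatchElement with values (batch, ...) and mask (batch, 1)
    values = torch.stack([e.values for e in elements], 1)  # (batch, T, ...)
    mask = torch.cat([e.mask for e in elements], 1)         # (batch, T)
    return values, mask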
    def forward(self, input_embeds_list):
        """

        Args:
            input_embeds_list (list[SequenceBatchElement]): where each element is of shape (batch_size, input_dim)

        Returns:
            BidirectionalEncoderOutput
        """
        for i, layer in enumerate(self.layers):
            if i == 0:
                prev_hidden_states = input_embeds_list
            else:
                prev_hidden_states = [
                    SequenceBatchElement(torch.cat([f.values, b.values], 1),
                                         f.mask)
                    for f, b in izip(forward_states, backward_states)
                ]

            new_forward_states, new_backward_states = layer(prev_hidden_states)

            if i == 0:
                # no skip connections here, because dimensions don't match
                forward_states, backward_states = new_forward_states, new_backward_states
            else:
                # add residuals to previous hidden states
                add_residuals = lambda a_list, b_list: [
                    SequenceBatchElement(a.values + b.values, a.mask)
                    for a, b in izip(a_list, b_list)
                ]

                forward_states = add_residuals(forward_states,
                                               new_forward_states)
                backward_states = add_residuals(backward_states,
                                                new_backward_states)

        return BidirectionalEncoderOutput(forward_states, backward_states)
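A standalone sketch of the residual update applied at layers i > 0 above, using a namedtuple stand-in for SequenceBatchElement and dummy tensors (illustrative shapes, not taken from the model):

import torch
from collections import namedtuple

SequenceBatchElement = namedtuple('SequenceBatchElement', ['values', 'mask'])

def add_residuals(a_list, b_list):
    # elementwise sum of previous and new hidden states, keeping the old mask
    return [SequenceBatchElement(a.values + b.values, a.mask)
            for a, b in zip(a_list, b_list)]

# two timesteps, batch of 1, hidden_dim of 3
prev = [SequenceBatchElement(torch.ones(1, 3), torch.ones(1, 1)) for _ in range(2)]
new = [SequenceBatchElement(torch.full((1, 3), 2.0), torch.ones(1, 1)) for _ in range(2)]
out = add_residuals(prev, new)  # each out[t].values is a tensor of 3.0s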
Example #7
    def forward(self, encoder_output, train_decoder_input):
        """

        Args:
            encoder_output (EncoderOutput)
            train_decoder_input (TrainDecoderInput)

        Returns:
            rnn_states (list[RNNState])
            total_loss (Variable): a scalar loss
        """
        batch_size, _ = train_decoder_input.input_words.mask.size()
        rnn_state = self.decoder_cell.initialize(batch_size)

        input_word_embeds = encoder_output.token_embedder.embed_seq_batch(
            train_decoder_input.input_words)

        input_embed_list = input_word_embeds.split()
        target_word_list = train_decoder_input.target_words.split()

        loss_list = []
        rnn_states = []
        vocab_probs = []
        for t, (x, target_word) in enumerate(
                izip(input_embed_list, target_word_list)):
            # x is a (batch_size, word_dim) SequenceBatchElement, target_word is a (batch_size,) Variable

            # update rnn state
            rnn_input = self.rnn_context_combiner(encoder_output, x.values)
            decoder_cell_output = self.decoder_cell(rnn_state, rnn_input,
                                                    x.mask)
            rnn_state = decoder_cell_output.rnn_state
            rnn_states.append(rnn_state)
            vocab_pr = decoder_cell_output.vocab_probs
            vocab_probs.append(vocab_pr)

            # compute loss
            loss = decoder_cell_output.loss(
                target_word.values)  # (batch_size,)
            loss_list.append(SequenceBatchElement(loss, x.mask))

        losses = SequenceBatch.cat(
            loss_list)  # (batch_size, target_seq_length)

        return vocab_probs, rnn_states, losses
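The returned losses are still per token and masked. One common way to reduce them to a scalar training loss, sketched under the assumption that losses exposes values of shape (batch_size, target_seq_length) and a matching mask (the helper name is hypothetical, not this codebase's API):

def masked_mean_loss(losses):
    # zero out losses at padded target positions, then average over real tokens
    masked = losses.values * losses.mask
    num_real_tokens = losses.mask.sum().clamp(min=1.0)
    return masked.sum() / num_real_tokens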