def test_forward_pulls_out_correct_tensor_with_unsorted_batches(self):
    """The Seq2Vec wrapper must return, in the original (unsorted) batch order, the
    concatenation of the top layer's final forward and backward LSTM states."""
    lstm = LSTM(bidirectional=True, num_layers=3, input_size=3, hidden_size=7, batch_first=True)
    encoder = PytorchSeq2VecWrapper(lstm)

    # Five sequences of length 7; rows 0-3 are padded from these timesteps onward
    # (row 4 is full length), so the batch is deliberately not sorted by length.
    input_tensor = torch.rand([5, 7, 3])
    mask = torch.ones(5, 7)
    for row, pad_start in enumerate([3, 4, 2, 6]):
        input_tensor[row, pad_start:, :] = 0
        mask[row, pad_start:] = 0

    sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
    sorted_inputs, sorted_sequence_lengths, restoration_indices, _ = sort_batch_by_length(
        input_tensor, sequence_lengths)
    packed_sequence = pack_padded_sequence(
        sorted_inputs, sorted_sequence_lengths.tolist(), batch_first=True)
    _, (final_hidden, _) = lstm(packed_sequence)

    # final_hidden is (num_layers * num_directions, batch, hidden): move the batch
    # dimension first and undo the length sort so rows line up with input_tensor.
    per_example_states = final_hidden.transpose(0, 1).index_select(0, restoration_indices)
    # The last two layer/direction slots hold the top layer's forward and backward states.
    top_layer = per_example_states[:, -2:, :].contiguous()
    expected = torch.cat([top_layer[:, 0, :], top_layer[:, 1, :]], -1)

    actual = encoder(input_tensor, mask)
    assert_almost_equal(actual.data.numpy(), expected.data.numpy())
Example #2
0
    def test_forward_pulls_out_correct_tensor_for_unsorted_batches(self):
        """The Seq2Seq wrapper's output must equal running the LSTM on the length-sorted,
        packed batch and then restoring the original batch order."""
        lstm = LSTM(bidirectional=True, num_layers=3, input_size=3, hidden_size=7, batch_first=True)
        encoder = PytorchSeq2SeqWrapper(lstm)

        input_tensor = torch.rand([5, 7, 3])
        mask = torch.ones(5, 7)
        # Rows 0-3 are padded from these timesteps onward; row 4 is full length.
        for row, pad_start in enumerate([3, 4, 2, 6]):
            input_tensor[row, pad_start:, :] = 0
            mask[row, pad_start:] = 0

        sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
        sorted_inputs, sorted_sequence_lengths, restoration_indices, _ = sort_batch_by_length(
            input_tensor, sequence_lengths)
        packed_sequence = pack_padded_sequence(
            sorted_inputs, sorted_sequence_lengths.data.tolist(), batch_first=True)

        packed_output, _ = lstm(packed_sequence)
        encoder_output = encoder(input_tensor, mask)
        padded_output, _ = pad_packed_sequence(packed_output, batch_first=True)
        # Put the reference output back into the caller's (unsorted) batch order.
        unsorted_output = padded_output.index_select(0, restoration_indices)

        assert_almost_equal(encoder_output.data.numpy(), unsorted_output.data.numpy())
Example #3
0
    def test_sort_tensor_by_length(self):
        """``sort_batch_by_length`` must sort rows by descending length, keep each row's
        padding intact, and return index tensors mapping between both orders."""
        tensor = torch.rand([5, 7, 9])
        tensor[0, 3:, :] = 0
        tensor[1, 4:, :] = 0
        tensor[2, 1:, :] = 0
        tensor[3, 5:, :] = 0

        sequence_lengths = torch.LongTensor([3, 4, 1, 5, 7])
        sorted_tensor, sorted_lengths, reverse_indices, sorting_indices = util.sort_batch_by_length(
            tensor, sequence_lengths)

        # Test sorted indices are padded correctly. Sorted rows 1-4 come from original
        # rows 3, 1, 0 and 2, whose padding starts at timesteps 5, 4, 3 and 1.
        numpy.testing.assert_array_equal(sorted_tensor[1, 5:, :].numpy(), 0.0)
        numpy.testing.assert_array_equal(sorted_tensor[2, 4:, :].numpy(), 0.0)
        numpy.testing.assert_array_equal(sorted_tensor[3, 3:, :].numpy(), 0.0)
        numpy.testing.assert_array_equal(sorted_tensor[4, 1:, :].numpy(), 0.0)

        assert sorted_lengths.equal(torch.LongTensor([7, 5, 4, 3, 1]))

        # The sorting indices must be exactly the permutation that produced sorted_tensor.
        assert tensor.index_select(0, sorting_indices).equal(sorted_tensor)

        # Test restoration indices correctly recover the original tensor.
        assert sorted_tensor.index_select(0, reverse_indices).equal(tensor)
Example #4
0
 def test_augmented_lstm_works_with_highway_connections(self):
     """Smoke test: an ``AugmentedLstm`` with highway connections runs on a packed,
     length-sorted batch and returns a well-formed output."""
     from torch.nn.utils.rnn import pad_packed_sequence

     augmented_lstm = AugmentedLstm(10, 11, use_highway=True)
     sorted_tensor, sorted_sequence, _, _ = sort_batch_by_length(
         self.random_tensor, self.sequence_lengths)
     lstm_input = pack_padded_sequence(sorted_tensor,
                                       sorted_sequence.data.tolist(),
                                       batch_first=True)
     output, _ = augmented_lstm(lstm_input)

     # Previously the output was discarded; check it unpacks to one row per input
     # sequence, padded to the longest (first, since sorted descending) length, with
     # the hidden size (11) as the feature dimension.
     output_sequence, _ = pad_packed_sequence(output, batch_first=True)
     assert output_sequence.size() == (sorted_tensor.size(0), int(sorted_sequence[0]), 11)
Example #5
0
    def test_variable_length_sequences_return_correctly_padded_outputs(self):
        """Timesteps past each sequence's length must come back from ``AugmentedLstm``
        as exact zeros once the output is unpacked."""
        sorted_tensor, sorted_sequence, _, _ = sort_batch_by_length(
            self.random_tensor, self.sequence_lengths)
        packed = pack_padded_sequence(sorted_tensor,
                                      sorted_sequence.data.tolist(),
                                      batch_first=True)
        lstm = AugmentedLstm(10, 11)
        packed_output, _ = lstm(packed)
        padded_output, _ = pad_packed_sequence(packed_output, batch_first=True)

        # (row in sorted order, first padded timestep); presumably these mirror the
        # lengths in the self.sequence_lengths fixture -- keep the two in sync.
        for row, pad_start in ((1, 6), (2, 4), (3, 3), (4, 2)):
            numpy.testing.assert_array_equal(
                padded_output.data[row, pad_start:, :].numpy(), 0.0)
Example #6
0
    def test_augmented_lstm_computes_same_function_as_pytorch_lstm(self):
        """With identical (all-ones) weights and a zero initial state, ``AugmentedLstm``
        must produce the same outputs and final states as ``torch.nn.LSTM``."""
        augmented_lstm = AugmentedLstm(10, 11)
        pytorch_lstm = LSTM(10, 11, num_layers=1, batch_first=True)
        # Initialize all weights to be == 1.
        initializer = InitializerApplicator([
            (".*", lambda tensor: torch.nn.init.constant_(tensor, 1.))
        ])
        initializer(augmented_lstm)
        initializer(pytorch_lstm)

        # Use bigger numbers to avoid floating point instability.
        sorted_tensor, sorted_sequence, _, _ = sort_batch_by_length(
            self.random_tensor * 5., self.sequence_lengths)
        lstm_input = pack_padded_sequence(sorted_tensor,
                                          sorted_sequence.data.tolist(),
                                          batch_first=True)

        # Zero initial state and memory, shaped (num_layers=1, batch_size, hidden_size=11).
        # The batch size comes from the data instead of being hard-coded, so the test does
        # not silently depend on the fixture's batch size.
        batch_size = sorted_tensor.size(0)
        initial_state = torch.zeros([1, batch_size, 11])
        initial_memory = torch.zeros([1, batch_size, 11])

        augmented_output, augmented_state = augmented_lstm(
            lstm_input, (initial_state, initial_memory))
        pytorch_output, pytorch_state = pytorch_lstm(
            lstm_input, (initial_state, initial_memory))
        pytorch_output_sequence, _ = pad_packed_sequence(pytorch_output,
                                                         batch_first=True)
        augmented_output_sequence, _ = pad_packed_sequence(augmented_output,
                                                           batch_first=True)

        numpy.testing.assert_array_almost_equal(
            pytorch_output_sequence.data.numpy(),
            augmented_output_sequence.data.numpy(),
            decimal=4)
        # Final hidden state.
        numpy.testing.assert_array_almost_equal(
            pytorch_state[0].data.numpy(),
            augmented_state[0].data.numpy(),
            decimal=4)
        # Final memory (cell) state.
        numpy.testing.assert_array_almost_equal(
            pytorch_state[1].data.numpy(),
            augmented_state[1].data.numpy(),
            decimal=4)
    def setUp(self):
        """Build shared RNN modules, a stateful ``_EncoderBase`` and a 5x7 batch whose
        mask leaves rows 2 and 4 completely empty."""
        super(TestEncoderBase, self).setUp()
        rnn_kwargs = dict(bidirectional=True, num_layers=3, input_size=3, hidden_size=7, batch_first=True)
        self.lstm = LSTM(**rnn_kwargs)
        self.rnn = RNN(**rnn_kwargs)
        self.encoder_base = _EncoderBase(stateful=True)

        self.tensor = tensor = torch.rand([5, 7, 3])
        tensor[1, 6:, :] = 0
        tensor[3, 2:, :] = 0

        mask = torch.ones(5, 7)
        mask[1, 6:] = 0
        mask[2, :] = 0  # <= completely masked
        mask[3, 2:] = 0
        mask[4, :] = 0  # <= completely masked
        self.mask = mask

        # Expected sizes: 5 rows in the batch, of which 3 (rows 0, 1 and 3) are non-empty.
        self.batch_size = 5
        self.num_valid = 3

        _, _, self.restoration_indices, self.sorting_indices = sort_batch_by_length(
            tensor, get_lengths_from_binary_sequence_mask(mask))
Example #8
0
    def sort_and_run_forward(self,
                             module: Callable[[PackedSequence, Optional[RnnState]],
                                              Tuple[Union[PackedSequence, torch.Tensor], RnnState]],
                             inputs: torch.Tensor,
                             mask: torch.Tensor,
                             hidden_state: Optional[RnnState] = None):
        """
        Sort a batch by sequence length, drop fully-masked sequences and run ``module``
        on the resulting ``PackedSequence``.

        This function exists because Pytorch RNNs require that their inputs be sorted
        before being passed as input. As all of our Seq2xxxEncoders use this functionality,
        it is provided in a base class. This method can be called on any module which
        takes as input a ``PackedSequence`` and some ``hidden_state``, which can either be a
        tuple of tensors or a tensor.

        As all of our Seq2xxxEncoders have different return types, we return `sorted`
        outputs from the module, which is called directly, together with the indices into
        the batch dimension required to restore the tensor to its correct, unsorted order.
        Sequences of zero length are removed before the module is called, so the module
        only sees the first ``num_valid`` sorted sequences, where ``num_valid`` is the
        number of batch elements that are not completely masked. Un-sorting the outputs
        and re-padding them back to ``batch_size`` is left to the subclasses because their
        outputs have different types and handling them smoothly here is difficult.

        Parameters
        ----------
        module : ``Callable[[PackedSequence, Optional[RnnState]],
                            Tuple[Union[PackedSequence, torch.Tensor], RnnState]]``, required.
            A function to run on the inputs. In most cases, this is a ``torch.nn.Module``.
        inputs : ``torch.Tensor``, required.
            A tensor of shape ``(batch_size, sequence_length, embedding_size)`` representing
            the inputs to the Encoder.
        mask : ``torch.Tensor``, required.
            A tensor of shape ``(batch_size, sequence_length)``, representing masked and
            non-masked elements of the sequence for each element in the batch.
        hidden_state : ``Optional[RnnState]``, (default = None).
            A single tensor of shape (num_layers, batch_size, hidden_size) representing the
            state of a single-state RNN (e.g. a GRU), or a tuple of tensors of shapes
            (num_layers, batch_size, hidden_size) and (num_layers, batch_size, memory_size),
            representing the hidden state and memory state of an LSTM-like RNN. The states
            are given in the original (unsorted) batch order and are reordered here to
            match the sorted inputs. Ignored when ``self.stateful`` is True, in which case
            the initial states come from ``self._get_initial_states``.

        Returns
        -------
        module_output : ``Union[torch.Tensor, PackedSequence]``.
            A Tensor or PackedSequence representing the output of the Pytorch Module.
            The batch size dimension will be equal to ``num_valid``, as sequences of zero
            length are clipped off before the module is called, as Pytorch cannot handle
            zero length sequences.
        final_states : ``Optional[RnnState]``
            A Tensor representing the hidden state of the Pytorch Module, in sorted batch
            order. This can either be a single tensor of shape
            (num_layers, num_valid, hidden_size), for instance in the case of a GRU, or a
            tuple of tensors, such as those required for an LSTM.
        restoration_indices : ``torch.LongTensor``
            A tensor of shape ``(batch_size,)``, describing the re-indexing required to transform
            the outputs back to their original batch order.

        Notes
        -----
        Nothing here guards the case where every sequence is fully masked
        (``num_valid == 0``). Presumably ``pack_padded_sequence`` rejects the resulting
        empty batch, so callers should ensure at least one sequence is non-empty.
        """
        batch_size = mask.size(0)

        sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
        sorted_inputs, sorted_sequence_lengths, restoration_indices, sorting_indices =\
            sort_batch_by_length(inputs, sequence_lengths)

        # ``pack_padded_sequence`` requires all sequence lengths to be > 0, so zero-length
        # sequences must be excluded. The batch is sorted by descending length, so every
        # non-empty sequence precedes every empty one. Counting the non-zero lengths (rather
        # than only looking at the first timestep of the mask) keeps this count consistent
        # with the sort that the ``[:num_valid]`` slices below rely on.
        num_valid = int((sequence_lengths > 0).sum().item())

        # Now create a PackedSequence with only the non-empty, sorted sequences.
        packed_sequence_input = pack_padded_sequence(sorted_inputs[:num_valid, :, :],
                                                     sorted_sequence_lengths[:num_valid].tolist(),
                                                     batch_first=True)

        def sort_and_trim(state: torch.Tensor) -> torch.Tensor:
            # States are (num_layers, batch_size, ...): reorder the batch dimension to
            # match the sorted inputs, then drop the empty sequences.
            return state.index_select(1, sorting_indices)[:, :num_valid, :].contiguous()

        # Prepare the initial states.
        if self.stateful:
            initial_states = self._get_initial_states(batch_size, num_valid, sorting_indices)
        elif hidden_state is None:
            initial_states = None
        elif isinstance(hidden_state, tuple):
            # Keep a tuple (not a list) so the result matches the ``RnnState`` type, e.g. an
            # LSTM's ``(h_0, c_0)``.
            initial_states = tuple(sort_and_trim(state) for state in hidden_state)
        else:
            initial_states = sort_and_trim(hidden_state)

        # Actually call the module on the sorted PackedSequence.
        module_output, final_states = module(packed_sequence_input, initial_states)

        return module_output, final_states, restoration_indices