Example #1
    def test_non_contiguous_initial_states_handled(self):
        # Check that the encoder is robust to non-contiguous initial states.

        # Case 1: Encoder is not stateful

        # A transposition will make the tensors non-contiguous, so we start them off in the
        # wrong shape and transpose them into the right one.
        encoder_base = _EncoderBase(stateful=False)
        initial_states = (
            torch.randn(5, 6, 7).permute(1, 0, 2),
            torch.randn(5, 6, 7).permute(1, 0, 2),
        )
        assert not initial_states[0].is_contiguous() and not initial_states[1].is_contiguous()
        assert initial_states[0].size() == torch.Size([6, 5, 7])
        assert initial_states[1].size() == torch.Size([6, 5, 7])

        # We'll pass them through an LSTM encoder and a vanilla RNN encoder to make sure it works
        # whether the initial states are a tuple of tensors or just a single tensor.
        encoder_base.sort_and_run_forward(self.lstm, self.tensor, self.mask, initial_states)
        encoder_base.sort_and_run_forward(self.rnn, self.tensor, self.mask, initial_states[0])

        # Case 2: Encoder is stateful

        # For stateful encoders, the initial state may be non-contiguous if its state was
        # previously updated with non-contiguous tensors. As in the non-stateful tests, we check
        # that the encoder still works on initial states for RNNs and LSTMs.
        final_states = initial_states
        # Check LSTM
        encoder_base = _EncoderBase(stateful=True)
        encoder_base._update_states(final_states, self.restoration_indices)
        encoder_base.sort_and_run_forward(self.lstm, self.tensor, self.mask)
        # Check RNN
        encoder_base.reset_states()
        encoder_base._update_states([final_states[0]], self.restoration_indices)
        encoder_base.sort_and_run_forward(self.rnn, self.tensor, self.mask)

    def test_non_stateful_states_are_sorted_correctly(self):
        encoder_base = _EncoderBase(stateful=False)
        initial_states = (Variable(torch.randn(6, 5, 7)),
                          Variable(torch.randn(6, 5, 7)))
        # Check that we sort the state for non-stateful encoders. To test
        # we'll just use a "pass through" encoder, as we aren't actually testing
        # the functionality of the encoder here anyway.
        _, states, restoration_indices = encoder_base.sort_and_run_forward(
            lambda *x: x, self.tensor, self.mask, initial_states)
        # Our input tensor had 2 zero length sequences, so we need
        # to concat a tensor of shape
        # (num_layers * num_directions, batch_size - num_valid, hidden_dim),
        # to the output before unsorting it.
        zeros = Variable(torch.zeros([6, 2, 7]))

        # sort_and_run_forward strips fully-padded instances from the batch;
        # in order to use the restoration_indices we need to add back the two
        # that got stripped. What we get back should match what we started with.
        for state, original in zip(states, initial_states):
            assert list(state.size()) == [6, 3, 7]
            state_with_zeros = torch.cat([state, zeros], 1)
            unsorted_state = state_with_zeros.index_select(
                1, restoration_indices)
            for index in [0, 1, 3]:
                numpy.testing.assert_array_equal(
                    unsorted_state[:, index, :].data.numpy(),
                    original[:, index, :].data.numpy())

    def setUp(self):
        super(TestEncoderBase, self).setUp()
        self.lstm = LSTM(bidirectional=True,
                         num_layers=3,
                         input_size=3,
                         hidden_size=7,
                         batch_first=True)
        # The checks above also run a vanilla RNN encoder, so the fixture needs one
        # alongside the LSTM (same dimensions).
        self.rnn = RNN(bidirectional=True,
                       num_layers=3,
                       input_size=3,
                       hidden_size=7,
                       batch_first=True)
        self.encoder_base = _EncoderBase(stateful=True)

        tensor = Variable(torch.rand([5, 7, 3]))
        tensor[1, 6:, :] = 0
        tensor[3, 2:, :] = 0
        self.tensor = tensor
        mask = Variable(torch.ones(5, 7))
        mask[1, 6:] = 0
        mask[2, :] = 0  # <= completely masked
        mask[3, 2:] = 0
        mask[4, :] = 0  # <= completely masked
        self.mask = mask

        self.batch_size = 5
        self.num_valid = 3
        sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
        _, _, restoration_indices, sorting_indices = sort_batch_by_length(
            tensor, sequence_lengths)
        self.sorting_indices = sorting_indices
        self.restoration_indices = restoration_indices
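
The non-contiguous checks above hinge on a PyTorch detail: permute returns a strided view of the same storage rather than a copy, so the result reports is_contiguous() == False until .contiguous() is called. A minimal standalone sketch of what the encoder has to tolerate (plain PyTorch, independent of the test fixture):

import torch

# permute rearranges strides without copying, so the view is non-contiguous.
state = torch.randn(5, 6, 7).permute(1, 0, 2)
print(state.is_contiguous())      # False
print(state.size())               # torch.Size([6, 5, 7])

# PyTorch's RNN modules can reject non-contiguous hidden states, so code that accepts
# arbitrary initial states typically calls .contiguous() before the forward pass.
fixed = state.contiguous()
print(fixed.is_contiguous())      # True
print(torch.equal(state, fixed))  # True -- same values, different memory layout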
Example #4
    def test_non_stateful_states_are_sorted_correctly(self):
        encoder_base = _EncoderBase(stateful=False)
        initial_states = (torch.randn(6, 5, 7), torch.randn(6, 5, 7))
        # Check that we sort the state for non-stateful encoders. To test
        # we'll just use a "pass through" encoder, as we aren't actually testing
        # the functionality of the encoder here anyway.
        _, states, restoration_indices = encoder_base.sort_and_run_forward(lambda *x: x,
                                                                           self.tensor,
                                                                           self.mask,
                                                                           initial_states)
        # Our input tensor had 2 zero length sequences, so we need
        # to concat a tensor of shape
        # (num_layers * num_directions, batch_size - num_valid, hidden_dim),
        # to the output before unsorting it.
        zeros = torch.zeros([6, 2, 7])

        # sort_and_run_forward strips fully-padded instances from the batch;
        # in order to use the restoration_indices we need to add back the two
        # that got stripped. What we get back should match what we started with.
        for state, original in zip(states, initial_states):
            assert list(state.size()) == [6, 3, 7]
            state_with_zeros = torch.cat([state, zeros], 1)
            unsorted_state = state_with_zeros.index_select(1, restoration_indices)
            for index in [0, 1, 3]:
                numpy.testing.assert_array_equal(unsorted_state[:, index, :].data.numpy(),
                                                 original[:, index, :].data.numpy())
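
For readers unfamiliar with the sort/unsort bookkeeping being tested here: sorting a batch by length and later restoring the original order is just two index_select calls with mutually inverse permutations. A small sketch of the round trip (not part of the test suite; the lengths match the fixture's mask):

import torch

lengths = torch.tensor([7, 6, 0, 2, 0])
# Sort descending by length; sorting_indices maps position-in-sorted -> position-in-original.
sorted_lengths, sorting_indices = lengths.sort(descending=True)
# restoration_indices is the inverse permutation, mapping each original row to its sorted position.
restoration_indices = sorting_indices.argsort()

batch = torch.randn(5, 7, 3)
sorted_batch = batch.index_select(0, sorting_indices)
restored = sorted_batch.index_select(0, restoration_indices)
assert torch.equal(restored, batch)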
Example #5
    def test_non_contiguous_initial_states_handled_on_gpu(self):
        # Some PyTorch operations which produce contiguous tensors on the CPU produce
        # non-contiguous tensors on the GPU (e.g. the forward pass of an RNN when batch_first=True).
        # Accordingly, we perform the same checks as in the previous test on the GPU to ensure
        # the encoder is not affected by which device it is on.

        # Case 1: Encoder is not stateful

        # A transposition will make the tensors non-contiguous, so we start them off in the
        # wrong shape and transpose them into the right one.
        encoder_base = _EncoderBase(stateful=False).cuda()
        initial_states = (
            torch.randn(5, 6, 7).cuda().permute(1, 0, 2),
            torch.randn(5, 6, 7).cuda().permute(1, 0, 2),
        )
        assert not initial_states[0].is_contiguous() and not initial_states[1].is_contiguous()
        assert initial_states[0].size() == torch.Size([6, 5, 7])
        assert initial_states[1].size() == torch.Size([6, 5, 7])

        # We'll pass them through an LSTM encoder and a vanilla RNN encoder to make sure it works
        # whether the initial states are a tuple of tensors or just a single tensor.
        encoder_base.sort_and_run_forward(self.lstm.cuda(), self.tensor.cuda(),
                                          self.mask.cuda(), initial_states)
        encoder_base.sort_and_run_forward(self.rnn.cuda(), self.tensor.cuda(),
                                          self.mask.cuda(), initial_states[0])

        # Case 2: Encoder is stateful

        # For stateful encoders, the initial state may be non-contiguous if its state was
        # previously updated with non-contiguous tensors. As in the non-stateful tests, we check
        # that the encoder still works on initial states for RNNs and LSTMs.
        final_states = initial_states
        # Check LSTM
        encoder_base = _EncoderBase(stateful=True).cuda()
        encoder_base._update_states(final_states,
                                    self.restoration_indices.cuda())
        encoder_base.sort_and_run_forward(self.lstm.cuda(), self.tensor.cuda(),
                                          self.mask.cuda())
        # Check RNN
        encoder_base.reset_states()
        encoder_base._update_states([final_states[0]],
                                    self.restoration_indices.cuda())
        encoder_base.sort_and_run_forward(self.rnn.cuda(), self.tensor.cuda(),
                                          self.mask.cuda())
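
This test calls .cuda() unconditionally, so it can only run on a machine with a CUDA device; in practice it would be guarded or marked. A hypothetical module-level guard using plain pytest (the real project may use its own GPU marker instead):

import pytest
import torch

# Skip the GPU tests in this module when no CUDA device is present.
pytestmark = pytest.mark.skipif(not torch.cuda.is_available(),
                                reason="requires a CUDA-capable GPU")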
Example #6
    def test_non_contiguous_input_states_handled(self):
        # Check that the encoder is robust to non-contiguous input states.

        # A transposition will make the tensors non-contiguous, so we start them off in the
        # wrong shape and transpose them into the right one.
        encoder_base = _EncoderBase(stateful=False)
        initial_states = (torch.randn(5, 6, 7).transpose(0, 1),
                          torch.randn(5, 6, 7).transpose(0, 1))
        assert not initial_states[0].is_contiguous() and not initial_states[1].is_contiguous()
        assert initial_states[0].size() == torch.Size([6, 5, 7])
        assert initial_states[1].size() == torch.Size([6, 5, 7])

        # We'll pass them through an LSTM encoder and a vanilla RNN encoder to make sure it works
        # whether the initial states are a tuple of tensors or just a single tensor.
        encoder_base.sort_and_run_forward(self.lstm, self.tensor, self.mask, initial_states)
        encoder_base.sort_and_run_forward(self.rnn, self.tensor, self.mask, initial_states[0])
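
A note on the transposition used here: torch.Tensor.t() is only defined for tensors with at most two dimensions, so for these 3-D state tensors the first two axes have to be swapped with transpose(0, 1) or permute(1, 0, 2), both of which produce the same non-contiguous view. A quick sketch:

import torch

x = torch.randn(5, 6, 7)

# .t() only supports tensors with <= 2 dimensions; on a 3-D tensor it raises
# an error (a RuntimeError in current PyTorch).
try:
    x.t()
except RuntimeError as err:
    print(err)

# Swapping the first two axes explicitly gives the (6, 5, 7) view the assertions expect.
assert x.transpose(0, 1).size() == torch.Size([6, 5, 7])
assert torch.equal(x.transpose(0, 1), x.permute(1, 0, 2))
assert not x.transpose(0, 1).is_contiguous()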
Example #7
    def test_non_contiguous_input_states_handled(self):
        # Check that the encoder is robust to non-contiguous input states.

        # A transposition will make the tensors non-contiguous, so we start them off in the
        # wrong shape and transpose them into the right one.
        encoder_base = _EncoderBase(stateful=False)
        initial_states = (torch.randn(5, 6, 7).permute(1, 0, 2),
                          torch.randn(5, 6, 7).permute(1, 0, 2))
        assert not initial_states[0].is_contiguous() and not initial_states[1].is_contiguous()
        assert initial_states[0].size() == torch.Size([6, 5, 7])
        assert initial_states[1].size() == torch.Size([6, 5, 7])

        # We'll pass them through an LSTM encoder and a vanilla RNN encoder to make sure it works
        # whether the initial states are a tuple of tensors or just a single tensor.
        encoder_base.sort_and_run_forward(self.lstm, self.tensor, self.mask, initial_states)
        encoder_base.sort_and_run_forward(self.rnn, self.tensor, self.mask, initial_states[0])
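
The comment about "a tuple of tensors or just a single tensor" mirrors the underlying PyTorch convention: torch.nn.LSTM takes its initial state as an (h_0, c_0) pair, while torch.nn.RNN (and GRU) take a single h_0. A minimal sketch with the same dimensions as the fixture, independent of the encoder under test:

import torch
from torch.nn import LSTM, RNN

lstm = LSTM(input_size=3, hidden_size=7, num_layers=3, bidirectional=True, batch_first=True)
rnn = RNN(input_size=3, hidden_size=7, num_layers=3, bidirectional=True, batch_first=True)

inputs = torch.randn(5, 7, 3)   # (batch_size, seq_len, input_size)
h0 = torch.randn(6, 5, 7)       # (num_layers * num_directions, batch_size, hidden_size)
c0 = torch.randn(6, 5, 7)

_, (hn, cn) = lstm(inputs, (h0, c0))  # LSTM state: a (hidden, cell) tuple
_, hn_only = rnn(inputs, h0)          # RNN state: a single hidden tensor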
Example #8
    def setUp(self):
        super(TestEncoderBase, self).setUp()
        self.lstm = LSTM(bidirectional=True, num_layers=3, input_size=3, hidden_size=7, batch_first=True)
        self.encoder_base = _EncoderBase(stateful=True)

        tensor = Variable(torch.rand([5, 7, 3]))
        tensor[1, 6:, :] = 0
        tensor[3, 2:, :] = 0
        self.tensor = tensor
        mask = Variable(torch.ones(5, 7))
        mask[1, 6:] = 0
        mask[2, :] = 0  # <= completely masked
        mask[3, 2:] = 0
        mask[4, :] = 0  # <= completely masked
        self.mask = mask

        self.batch_size = 5
        self.num_valid = 3
        sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
        _, _, restoration_indices, sorting_indices = sort_batch_by_length(tensor, sequence_lengths)
        self.sorting_indices = sorting_indices
        self.restoration_indices = restoration_indices
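
To make the fixture concrete: the mask encodes per-row sequence lengths of 7, 6, 0, 2 and 0, so two of the five rows are fully padded and self.num_valid is 3. A rough equivalent of what get_lengths_from_binary_sequence_mask computes from this mask (the helper's exact implementation may differ):

import torch

mask = torch.ones(5, 7)
mask[1, 6:] = 0
mask[2, :] = 0   # completely masked
mask[3, 2:] = 0
mask[4, :] = 0   # completely masked

lengths = mask.sum(dim=-1).long()
print(lengths.tolist())          # [7, 6, 0, 2, 0]
print(int((lengths > 0).sum()))  # 3 -- rows with at least one unmasked timestep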