import numpy
import pytest
import torch
from torch.nn import LSTM, RNN

from allennlp.common.testing import AllenNlpTestCase
from allennlp.modules.encoder_base import _EncoderBase
from allennlp.nn.util import get_lengths_from_binary_sequence_mask, sort_batch_by_length


class TestEncoderBase(AllenNlpTestCase):

    def test_non_contiguous_initial_states_handled(self):
        # Check that the encoder is robust to non-contiguous initial states.

        # Case 1: Encoder is not stateful.
        # A transposition will make the tensors non-contiguous, so start them off at the
        # wrong shape and transpose them into the right shape.
        encoder_base = _EncoderBase(stateful=False)
        initial_states = (
            torch.randn(5, 6, 7).permute(1, 0, 2),
            torch.randn(5, 6, 7).permute(1, 0, 2),
        )
        assert not initial_states[0].is_contiguous() and not initial_states[1].is_contiguous()
        assert initial_states[0].size() == torch.Size([6, 5, 7])
        assert initial_states[1].size() == torch.Size([6, 5, 7])

        # We'll pass them through an LSTM encoder and a vanilla RNN encoder to make sure it works
        # whether the initial states are a tuple of tensors or just a single tensor.
        encoder_base.sort_and_run_forward(self.lstm, self.tensor, self.mask, initial_states)
        encoder_base.sort_and_run_forward(self.rnn, self.tensor, self.mask, initial_states[0])

        # Case 2: Encoder is stateful.
        # For stateful encoders, the initial state may be non-contiguous if its state was
        # previously updated with non-contiguous tensors. As in the non-stateful tests, we check
        # that the encoder still works on such initial states for both RNNs and LSTMs.
        final_states = initial_states
        # Check LSTM.
        encoder_base = _EncoderBase(stateful=True)
        encoder_base._update_states(final_states, self.restoration_indices)
        encoder_base.sort_and_run_forward(self.lstm, self.tensor, self.mask)
        # Check RNN.
        encoder_base.reset_states()
        encoder_base._update_states([final_states[0]], self.restoration_indices)
        encoder_base.sort_and_run_forward(self.rnn, self.tensor, self.mask)
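
    # (A note on the stateful case above: _update_states is the private hook a
    # stateful encoder uses to cache final states between batches; the restoration
    # indices passed alongside un-sort those states back into the original batch
    # order before they are cached, so the next call to sort_and_run_forward can
    # re-sort them consistently with its new batch.)
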
    def setUp(self):
        super(TestEncoderBase, self).setUp()
        self.lstm = LSTM(bidirectional=True, num_layers=3, input_size=3, hidden_size=7, batch_first=True)
        # The tests also exercise a vanilla RNN via self.rnn; it mirrors the LSTM configuration.
        self.rnn = RNN(bidirectional=True, num_layers=3, input_size=3, hidden_size=7, batch_first=True)
        self.encoder_base = _EncoderBase(stateful=True)

        tensor = torch.rand([5, 7, 3])
        tensor[1, 6:, :] = 0
        tensor[3, 2:, :] = 0
        self.tensor = tensor
        mask = torch.ones(5, 7)
        mask[1, 6:] = 0
        mask[2, :] = 0  # <= completely masked
        mask[3, 2:] = 0
        mask[4, :] = 0  # <= completely masked
        self.mask = mask

        self.batch_size = 5
        self.num_valid = 3
        sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
        _, _, restoration_indices, sorting_indices = sort_batch_by_length(tensor, sequence_lengths)
        self.sorting_indices = sorting_indices
        self.restoration_indices = restoration_indices
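
    # For reference: the mask above corresponds to sequence lengths [7, 6, 0, 2, 0],
    # so rows 2 and 4 are completely masked and only num_valid = 3 rows are non-empty.
    # sort_batch_by_length orders the batch by decreasing length, and
    # restoration_indices is the inverse permutation. A sketch of the invariant the
    # tests below rely on:
    #
    #     sorted_tensor = self.tensor.index_select(0, self.sorting_indices)
    #     restored = sorted_tensor.index_select(0, self.restoration_indices)
    #     # `restored` is element-for-element equal to `self.tensor`
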
    def test_non_stateful_states_are_sorted_correctly(self):
        encoder_base = _EncoderBase(stateful=False)
        initial_states = (torch.randn(6, 5, 7), torch.randn(6, 5, 7))
        # Check that we sort the state for non-stateful encoders. To test this,
        # we'll just use a "pass through" encoder, as we aren't actually testing
        # the functionality of the encoder here anyway.
        _, states, restoration_indices = encoder_base.sort_and_run_forward(
            lambda *x: x, self.tensor, self.mask, initial_states
        )
        # Our input tensor had 2 zero-length sequences, so we need to concat a
        # tensor of shape (num_layers * num_directions, batch_size - num_valid, hidden_dim)
        # to the output before unsorting it.
        zeros = torch.zeros([6, 2, 7])
        # sort_and_run_forward strips fully-padded instances from the batch;
        # in order to use the restoration_indices we need to add back the two
        # that got stripped. What we get back should match what we started with.
        for state, original in zip(states, initial_states):
            assert list(state.size()) == [6, 3, 7]
            state_with_zeros = torch.cat([state, zeros], 1)
            unsorted_state = state_with_zeros.index_select(1, restoration_indices)
            for index in [0, 1, 3]:
                numpy.testing.assert_array_equal(
                    unsorted_state[:, index, :].data.numpy(),
                    original[:, index, :].data.numpy(),
                )
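
    # The unsorting step above is the same index_select trick applied to dimension 1,
    # since RNN states are shaped (num_layers * num_directions, batch_size, hidden_dim).
    # A sketch of the strip / pad / restore round trip, with toy values mirroring this
    # test (`restoration_indices` is assumed to be a valid permutation of range(5)):
    #
    #     states = torch.randn(6, 3, 7)                           # 3 valid rows survive sorting
    #     padded = torch.cat([states, torch.zeros(6, 2, 7)], 1)   # pad back up to batch_size 5
    #     restored = padded.index_select(1, restoration_indices)  # back to original row order
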
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device registered.")
    def test_non_contiguous_initial_states_handled_on_gpu(self):
        # Some PyTorch operations which produce contiguous tensors on the CPU produce
        # non-contiguous tensors on the GPU (e.g. the forward pass of an RNN when
        # batch_first=True). Accordingly, we repeat the checks from the previous test
        # on the GPU, to ensure the encoder is not affected by which device it is on.

        # Case 1: Encoder is not stateful.
        # A transposition will make the tensors non-contiguous, so start them off at the
        # wrong shape and transpose them into the right shape.
        encoder_base = _EncoderBase(stateful=False).cuda()
        initial_states = (
            torch.randn(5, 6, 7).cuda().permute(1, 0, 2),
            torch.randn(5, 6, 7).cuda().permute(1, 0, 2),
        )
        assert not initial_states[0].is_contiguous() and not initial_states[1].is_contiguous()
        assert initial_states[0].size() == torch.Size([6, 5, 7])
        assert initial_states[1].size() == torch.Size([6, 5, 7])

        # We'll pass them through an LSTM encoder and a vanilla RNN encoder to make sure it works
        # whether the initial states are a tuple of tensors or just a single tensor.
        encoder_base.sort_and_run_forward(self.lstm.cuda(), self.tensor.cuda(), self.mask.cuda(), initial_states)
        encoder_base.sort_and_run_forward(self.rnn.cuda(), self.tensor.cuda(), self.mask.cuda(), initial_states[0])

        # Case 2: Encoder is stateful.
        # For stateful encoders, the initial state may be non-contiguous if its state was
        # previously updated with non-contiguous tensors. As in the non-stateful tests, we check
        # that the encoder still works on such initial states for both RNNs and LSTMs.
        final_states = initial_states
        # Check LSTM.
        encoder_base = _EncoderBase(stateful=True).cuda()
        encoder_base._update_states(final_states, self.restoration_indices.cuda())
        encoder_base.sort_and_run_forward(self.lstm.cuda(), self.tensor.cuda(), self.mask.cuda())
        # Check RNN.
        encoder_base.reset_states()
        encoder_base._update_states([final_states[0]], self.restoration_indices.cuda())
        encoder_base.sort_and_run_forward(self.rnn.cuda(), self.tensor.cuda(), self.mask.cuda())
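
    # (The batch_first caveat, concretely: on the GPU the cuDNN-backed RNN forward
    # can hand back transposed, non-contiguous views, while the same call on the CPU
    # returns contiguous tensors - so the CPU tests alone would not catch a missing
    # .contiguous() call inside the encoder.)
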
    def test_non_contiguous_input_states_handled(self):
        # Check that the encoder is robust to non-contiguous input states.
        # A transposition will make the tensors non-contiguous, so start them off at the
        # wrong shape and transpose them into the right shape.
        encoder_base = _EncoderBase(stateful=False)
        initial_states = (
            torch.randn(5, 6, 7).permute(1, 0, 2),
            torch.randn(5, 6, 7).permute(1, 0, 2),
        )
        assert not initial_states[0].is_contiguous() and not initial_states[1].is_contiguous()
        assert initial_states[0].size() == torch.Size([6, 5, 7])
        assert initial_states[1].size() == torch.Size([6, 5, 7])

        # We'll pass them through an LSTM encoder and a vanilla RNN encoder to make sure it works
        # whether the initial states are a tuple of tensors or just a single tensor.
        encoder_base.sort_and_run_forward(self.lstm, self.tensor, self.mask, initial_states)
        encoder_base.sort_and_run_forward(self.rnn, self.tensor, self.mask, initial_states[0])
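
    # Why contiguity matters at all: PyTorch's RNN modules ultimately .view() their
    # hidden states (or hand them to cuDNN), and .view() requires contiguous memory.
    # For example:
    #
    #     h = torch.randn(5, 6, 7).permute(1, 0, 2)
    #     h.is_contiguous()   # False - permute changes strides, not storage
    #     h.view(-1)          # raises RuntimeError on a non-contiguous tensor
    #     h.contiguous()      # copies into fresh contiguous storage, fixing this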