# Imports reconstructed for this excerpt. The torchrua public names are taken
# from the calls below; the test-helper module paths (`tests.assertion`,
# `tests.strategy`) are assumptions, not confirmed by the source. The @given
# decorators that bind `data`, `token_sizes`, `dim`, `device`, etc. are
# omitted from this excerpt.
import torch
from einops import rearrange
from hypothesis import strategies as st
from torch import nn
from torch.nn.utils import rnn as tgt

import torchrua as rua
from torchrua import (reduce_catted_sequences, reverse_packed_sequence,
                      roll_packed_sequence, select_last, select_tail,
                      tree_reduce_packed_indices, tree_reduce_sequence)
from tests.assertion import (assert_close, assert_equal, assert_grad_close,
                             assert_packed_sequence_close)
from tests.strategy import TINY_BATCH_SIZE, TINY_TOKEN_SIZE, token_size_lists


def test_reduce_catted_sequences(data, batch_sizes, in_dim, hidden_dim, device):
    sequences = [
        [
            torch.randn((token_size, in_dim), requires_grad=True, device=device)
            for token_size in data.draw(token_size_lists(
                max_token_size=TINY_TOKEN_SIZE, max_batch_size=TINY_BATCH_SIZE))
        ]
        for _ in batch_sizes
    ]
    inputs = [token for sequence in sequences for token in sequence]

    catted_sequences = [rua.cat_sequence(sequence, device=device) for sequence in sequences]
    packed_sequences = [rua.pack_sequence(sequence, device=device) for sequence in sequences]

    rnn = nn.LSTM(
        input_size=in_dim,
        hidden_size=hidden_dim,
        bidirectional=True,
        bias=True,
    ).to(device=device)

    # Run the LSTM once over all groups merged into a single PackedSequence ...
    reduction_pack = reduce_catted_sequences(catted_sequences, device=device)
    _, (actual, _) = rnn(reduction_pack)
    actual = rearrange(actual, 'd n x -> n (d x)')

    # ... and compare against running it on each group separately.
    expected = []
    for pack in packed_sequences:
        _, (t, _) = rnn(pack)
        expected.append(rearrange(t, 'd n x -> n (d x)'))
    expected = rua.pack_sequence(expected).data

    assert_close(actual, expected, check_stride=False)
    assert_grad_close(actual, expected, inputs=inputs)

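# Why catted and packed forms are interchangeable above: a catted sequence
# stores tokens batch-major alongside per-sequence `token_sizes`, while a
# packed sequence stores them time-major alongside per-step `batch_sizes`.
# A minimal sketch of the two layouts using only core PyTorch
# (`_layout_sketch` is an illustrative helper, not part of the library):
def _layout_sketch():
    a = torch.arange(3.).unsqueeze(-1)  # sequence [a0, a1, a2]
    b = torch.arange(2.).unsqueeze(-1)  # sequence [b0, b1]
    catted_data = torch.cat([a, b], dim=0)  # [a0, a1, a2, b0, b1]
    token_sizes = torch.tensor([3, 2])
    packed = tgt.pack_sequence([a, b], enforce_sorted=False)
    # packed.data == [a0, b0, a1, b1, a2]; packed.batch_sizes == [2, 2, 1]
    return catted_data, token_sizes, packed
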
def test_pack_sequence(data, token_sizes, dim, device):
    inputs = [
        torch.randn((token_size, dim), device=device, requires_grad=True)
        for token_size in token_sizes
    ]

    actual = rua.pack_sequence(inputs, device=device)
    expected = tgt.pack_sequence(inputs, enforce_sorted=False)

    assert_packed_sequence_close(actual, expected)
    assert_grad_close(actual.data, expected.data, inputs=inputs)

def test_pad_sequence(data, token_sizes, dim, batch_first, device):
    inputs = [
        torch.randn((token_size, dim), device=device, requires_grad=True)
        for token_size in token_sizes
    ]

    actual = rua.pad_sequence(inputs, batch_first=batch_first)
    expected = tgt.pad_sequence(inputs, batch_first=batch_first)

    assert_close(actual, expected)
    assert_grad_close(actual, expected, inputs=inputs)

def test_reverse_packed_sequence(data, token_sizes, dim, device):
    inputs = [
        torch.randn((token_size, dim), device=device, requires_grad=True)
        for token_size in token_sizes
    ]
    packed_sequence = tgt.pack_sequence(inputs, enforce_sorted=False)

    actual = reverse_packed_sequence(sequence=packed_sequence)
    # Reference: flip each raw sequence along its token dimension, then repack.
    expected = tgt.pack_sequence(
        [sequence.flip(dims=[0]) for sequence in inputs],
        enforce_sorted=False,
    )

    assert_packed_sequence_close(actual, expected)
    assert_grad_close(actual.data, expected.data, inputs=inputs)

def test_cat_packed_sequence(data, token_sizes, dim, device):
    inputs = [
        torch.randn((token_size, dim), device=device, requires_grad=True)
        for token_size in token_sizes
    ]
    packed_sequence = tgt.pack_sequence(inputs, enforce_sorted=False)

    # The function under test converts packed -> catted; the reference catts
    # the raw sequences directly.
    actual_data, actual_token_sizes = rua.cat_packed_sequence(packed_sequence, device=device)
    expected_data, expected_token_sizes = rua.cat_sequence(inputs, device=device)

    assert_close(actual_data, expected_data)
    assert_equal(actual_token_sizes, expected_token_sizes)
    assert_grad_close(actual_data, expected_data, inputs=inputs)

def test_cat_padded_sequence(data, token_sizes, dim, batch_first, device):
    inputs = [
        torch.randn((token_size, dim), device=device, requires_grad=True)
        for token_size in token_sizes
    ]
    padded_sequence = tgt.pad_sequence(inputs, batch_first=batch_first)
    token_sizes = torch.tensor(token_sizes, device=device)

    # The function under test converts padded -> catted; the reference catts
    # the raw sequences directly.
    actual_data, actual_token_sizes = rua.cat_padded_sequence(
        padded_sequence, token_sizes, batch_first=batch_first, device=device)
    expected_data, expected_token_sizes = rua.cat_sequence(inputs, device=device)

    assert_close(actual_data, expected_data)
    assert_equal(actual_token_sizes, expected_token_sizes)
    assert_grad_close(actual_data, expected_data, inputs=inputs)

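# A stand-alone sketch of the padded -> catted conversion exercised above,
# using only core PyTorch (`_cat_padded` is an illustrative helper, not the
# library implementation):
def _cat_padded(padded, token_sizes, batch_first=True):
    if not batch_first:
        padded = padded.transpose(0, 1)
    # Keep only the first `n` valid tokens of each row, then concatenate.
    return torch.cat(
        [padded[i, :n] for i, n in enumerate(token_sizes.tolist())],
        dim=0,
    )
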
def test_select_last(data, token_sizes, dim, unsort, device):
    inputs = [
        torch.randn((token_size, dim), device=device, requires_grad=True)
        for token_size in token_sizes
    ]
    packed_sequence = tgt.pack_sequence(inputs, enforce_sorted=False)

    actual = select_last(sequence=packed_sequence, unsort=unsort)
    if not unsort:
        # Without unsorting, the result follows the packing order; restore the
        # original batch order before comparing.
        actual = actual[packed_sequence.unsorted_indices]
    expected = torch.stack([sequence[-1] for sequence in inputs], dim=0)

    assert_close(actual, expected)
    assert_grad_close(actual, expected, inputs=inputs)

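# The same last-token selection on a padded, batch-first tensor is a single
# gather at index `length - 1` per row (sketch only; `_select_last_padded` is
# an illustrative helper, not part of the library):
def _select_last_padded(padded, token_sizes):
    # padded: (batch, time, dim); token_sizes: (batch,) of lengths
    index = (token_sizes - 1).view(-1, 1, 1).expand(-1, 1, padded.size(-1))
    return padded.gather(dim=1, index=index).squeeze(dim=1)
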
def test_pad_packed_sequence(data, token_sizes, dim, batch_first, device):
    inputs = [
        torch.randn((token_size, dim), device=device, requires_grad=True)
        for token_size in token_sizes
    ]
    packed_sequence = tgt.pack_sequence(inputs, enforce_sorted=False)

    # The returned lengths are expected on CPU regardless of `device`.
    expected_token_sizes = torch.tensor(token_sizes, device=torch.device('cpu'))
    expected = tgt.pad_sequence(inputs, batch_first=batch_first)

    actual, actual_token_sizes = rua.pad_packed_sequence(
        packed_sequence, batch_first=batch_first)

    assert_close(actual, expected)
    assert_grad_close(actual, expected, inputs=inputs)
    assert_equal(actual_token_sizes, expected_token_sizes)

def test_select_tail(data, token_sizes, dim, device):
    drop_first_n = data.draw(st.integers(min_value=1, max_value=min(token_sizes)))

    # Each sequence gets one extra token so that even dropping
    # `min(token_sizes)` tokens leaves it non-empty.
    inputs = [
        torch.randn((token_size + 1, dim), device=device, requires_grad=True)
        for token_size in token_sizes
    ]
    packed_sequence = tgt.pack_sequence(inputs, enforce_sorted=False)

    actual = select_tail(sequence=packed_sequence, drop_first_n=drop_first_n)
    expected = tgt.pack_sequence(
        [sequence[drop_first_n:] for sequence in inputs],
        enforce_sorted=False,
    )

    assert_packed_sequence_close(actual, expected)
    assert_grad_close(actual.data, expected.data, inputs=inputs)

def test_tree_reduce_packed_sequence(data, token_sizes, dim, device):
    inputs = [
        torch.randn((token_size, dim), device=device, requires_grad=True)
        for token_size in token_sizes
    ]
    # Reference: sum every sequence over its token dimension via padding.
    expected = rua.pad_sequence(inputs, device=device).sum(dim=0)

    packed_sequence = rua.pack_sequence(inputs, device=device)
    indices = tree_reduce_packed_indices(batch_sizes=packed_sequence.batch_sizes)
    actual = tree_reduce_sequence(torch.add)(packed_sequence.data, indices)
    actual = actual[packed_sequence.unsorted_indices]

    assert_close(actual, expected)
    assert_grad_close(actual, expected, inputs=inputs)

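# The idea behind the tree reduction above: fold a binary op pairwise in
# O(log T) layers instead of scanning linearly. A minimal dense sketch of the
# same semantics (an assumption about what is being tested, not torchrua's
# index-based implementation):
def _tree_sum(x):
    # x: (time, batch, dim); zero-padding each layer is safe because 0 is the
    # identity of torch.add.
    while x.size(0) > 1:
        if x.size(0) % 2 == 1:
            x = torch.cat([x, torch.zeros_like(x[:1])], dim=0)
        x = x[0::2] + x[1::2]
    return x[0]
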
def test_roll_packed_sequence(data, token_sizes, dim, device):
    offset = data.draw(st.integers(min_value=-max(token_sizes), max_value=+max(token_sizes)))
    inputs = [
        torch.randn((token_size, dim), device=device, requires_grad=True)
        for token_size in token_sizes
    ]
    packed_sequence = tgt.pack_sequence(inputs, enforce_sorted=False)

    actual = roll_packed_sequence(sequence=packed_sequence, shifts=offset)
    # Reference: roll each raw sequence along its token dimension, then repack.
    expected = tgt.pack_sequence(
        [sequence.roll(offset, dims=[0]) for sequence in inputs],
        enforce_sorted=False,
    )

    assert_packed_sequence_close(actual, expected)
    assert_grad_close(actual.data, expected.data, inputs=inputs)