예제 #1
0
def test_reduce_catted_sequences(data, batch_sizes, in_dim, hidden_dim,
                                 device):
    """reduce_catted_sequences must agree with running the LSTM over each
    sequence batch packed separately and re-packing the final hidden states.

    Draws one list of variable-length sequences per entry in ``batch_sizes``,
    feeds the reduced pack and the individual packs through the same LSTM,
    and compares outputs and input gradients.
    """
    sequences = [[
        torch.randn((token_size, in_dim), requires_grad=True, device=device)
        for token_size in data.draw(
            token_size_lists(max_token_size=TINY_TOKEN_SIZE,
                             max_batch_size=TINY_BATCH_SIZE))
    ] for _ in batch_sizes]
    # Flat view of every token tensor, used for gradient comparison below.
    inputs = [token for sequence in sequences for token in sequence]
    catted_sequences = [
        cat_sequence(sequence, device=device) for sequence in sequences
    ]
    packed_sequences = [
        pack_sequence(sequence, device=device) for sequence in sequences
    ]

    rnn = nn.LSTM(
        input_size=in_dim,
        hidden_size=hidden_dim,
        bidirectional=True,
        bias=True,
    ).to(device=device)

    reduction_pack = reduce_catted_sequences(catted_sequences, device=device)
    _, (actual, _) = rnn(reduction_pack)
    # Merge the (direction, batch, hidden) final states into (batch, features).
    actual = rearrange(actual, 'd n x -> n (d x)')

    # Fixed: local was misspelled `excepted`; renamed to `expected` to match
    # the naming convention used by the other tests in this file.
    expected = []
    for pack in packed_sequences:
        _, (t, _) = rnn(pack)
        expected.append(rearrange(t, 'd n x -> n (d x)'))
    expected = pack_sequence(expected).data

    assert_close(actual, expected, check_stride=False)
    assert_grad_close(actual, expected, inputs=inputs)
예제 #2
0
def test_pack_sequence(data, token_sizes, dim, device):
    """rua.pack_sequence must match torch's pack_sequence (unsorted input),
    both in the packed structure and in the gradients w.r.t. the inputs."""
    inputs = [
        torch.randn((token_size, dim), device=device, requires_grad=True)
        for token_size in token_sizes
    ]

    actual = rua.pack_sequence(inputs, device=device)
    # Fixed: local was misspelled `excepted`; renamed to `expected` to match
    # the naming convention used by the other tests in this file.
    expected = tgt.pack_sequence(inputs, enforce_sorted=False)

    assert_packed_sequence_close(actual, expected)
    assert_grad_close(actual.data, expected.data, inputs=inputs)
예제 #3
0
def test_pad_sequence(data, token_sizes, dim, batch_first, device):
    """rua.pad_sequence must match torch's pad_sequence, gradients included."""
    sequences = []
    for token_size in token_sizes:
        sequences.append(
            torch.randn((token_size, dim), device=device, requires_grad=True))

    ours = rua.pad_sequence(sequences, batch_first=batch_first)
    reference = tgt.pad_sequence(sequences, batch_first=batch_first)

    assert_close(ours, reference)
    assert_grad_close(ours, reference, inputs=sequences)
예제 #4
0
def test_reverse_packed_sequence(data, token_sizes, dim, device):
    """reverse_packed_sequence must equal packing each sequence flipped
    along the time dimension."""
    sequences = [
        torch.randn((size, dim), device=device, requires_grad=True)
        for size in token_sizes
    ]
    packed = pack_sequence(sequences, enforce_sorted=False)

    reversed_pack = reverse_packed_sequence(sequence=packed)
    flipped = [seq.flip(dims=[0]) for seq in sequences]
    reference = pack_sequence(flipped, enforce_sorted=False)

    assert_packed_sequence_close(reversed_pack, reference)
    assert_grad_close(reversed_pack.data, reference.data, inputs=sequences)
예제 #5
0
def test_cat_packed_sequence(data, token_sizes, dim, device):
    """cat_packed_sequence on a PackedSequence must reproduce cat_sequence
    applied directly to the raw sequences (data and token sizes)."""
    sequences = []
    for size in token_sizes:
        sequences.append(
            torch.randn((size, dim), device=device, requires_grad=True))
    packed = tgt.pack_sequence(sequences, enforce_sorted=False)

    data_from_raw, sizes_from_raw = rua.cat_sequence(sequences, device=device)
    data_from_pack, sizes_from_pack = rua.cat_packed_sequence(
        packed, device=device)

    assert_close(data_from_raw, data_from_pack)
    assert_equal(sizes_from_raw, sizes_from_pack)
    assert_grad_close(data_from_raw, data_from_pack, inputs=sequences)
예제 #6
0
def test_cat_padded_sequence(data, token_sizes, dim, batch_first, device):
    """cat_padded_sequence on a padded tensor (+ its token sizes) must
    reproduce cat_sequence applied directly to the raw sequences."""
    sequences = []
    for size in token_sizes:
        sequences.append(
            torch.randn((size, dim), device=device, requires_grad=True))
    padded = tgt.pad_sequence(sequences, batch_first=batch_first)
    token_sizes = torch.tensor(token_sizes, device=device)

    data_from_raw, sizes_from_raw = rua.cat_sequence(sequences, device=device)
    data_from_pad, sizes_from_pad = rua.cat_padded_sequence(
        padded, token_sizes, batch_first=batch_first, device=device)

    assert_close(data_from_raw, data_from_pad)
    assert_equal(sizes_from_raw, sizes_from_pad)
    assert_grad_close(data_from_raw, data_from_pad, inputs=sequences)
예제 #7
0
def test_select_last(data, token_sizes, dim, unsort, device):
    """select_last must return each sequence's final token, stacked batch-wise."""
    sequences = [
        torch.randn((size, dim), device=device, requires_grad=True)
        for size in token_sizes
    ]
    packed = pack_sequence(sequences, enforce_sorted=False)

    selected = select_last(sequence=packed, unsort=unsort)
    if not unsort:
        # Undo the pack's internal sorted order so the comparison lines up
        # with the original sequence order.
        selected = selected[packed.unsorted_indices]

    last_tokens = torch.stack([seq[-1] for seq in sequences], dim=0)

    assert_close(selected, last_tokens)
    assert_grad_close(selected, last_tokens, inputs=sequences)
예제 #8
0
def test_pad_packed_sequence(data, token_sizes, dim, batch_first, device):
    """rua.pad_packed_sequence must match torch's pad_sequence on the raw
    inputs and report the token sizes (on CPU, as torch convention)."""
    inputs = [
        torch.randn((token_size, dim), device=device, requires_grad=True)
        for token_size in token_sizes
    ]
    packed_sequence = tgt.pack_sequence(inputs, enforce_sorted=False)
    # Fixed: locals were misspelled `excepted*`; renamed to `expected*` to
    # match the naming convention used by the other tests in this file.
    expected_token_sizes = torch.tensor(token_sizes,
                                        device=torch.device('cpu'))

    expected = tgt.pad_sequence(inputs, batch_first=batch_first)
    actual, actual_token_sizes = rua.pad_packed_sequence(
        packed_sequence, batch_first=batch_first)

    assert_close(actual, expected)
    assert_grad_close(actual, expected, inputs=inputs)
    assert_equal(actual_token_sizes, expected_token_sizes)
예제 #9
0
def test_select_tail(data, token_sizes, dim, device):
    """select_tail must drop the first `drop_first_n` tokens of every sequence."""
    drop_first_n = data.draw(
        st.integers(min_value=1, max_value=min(token_sizes)))

    # The +1 guarantees every sequence keeps at least one token after dropping.
    sequences = [
        torch.randn((size + 1, dim), device=device, requires_grad=True)
        for size in token_sizes
    ]
    packed = pack_sequence(sequences, enforce_sorted=False)

    tail = select_tail(sequence=packed, drop_first_n=drop_first_n)
    trimmed = [seq[drop_first_n:] for seq in sequences]
    reference = pack_sequence(trimmed, enforce_sorted=False)

    assert_packed_sequence_close(tail, reference)
    assert_grad_close(tail.data, reference.data, inputs=sequences)
예제 #10
0
def test_tree_reduce_packed_sequence(data, token_sizes, dim, device):
    """Summing a packed sequence via tree_reduce_sequence must equal summing
    the padded tensor over its time dimension."""
    inputs = [
        torch.randn((token_size, dim), device=device, requires_grad=True)
        for token_size in token_sizes
    ]

    # Fixed: local was misspelled `excepted`; renamed to `expected` to match
    # the naming convention used by the other tests in this file.
    expected = pad_sequence(inputs, device=device).sum(dim=0)

    packed_sequence = pack_sequence(inputs, device=device)
    indices = tree_reduce_packed_indices(
        batch_sizes=packed_sequence.batch_sizes)

    actual = tree_reduce_sequence(torch.add)(packed_sequence.data, indices)
    # Restore the original batch order before comparing against the padded sum.
    actual = actual[packed_sequence.unsorted_indices]

    assert_close(actual, expected)
    assert_grad_close(actual, expected, inputs=inputs)
예제 #11
0
def test_roll_packed_sequence(data, token_sizes, dim, device):
    """roll_packed_sequence must match rolling every sequence along time."""
    shifts = data.draw(
        st.integers(min_value=-max(token_sizes), max_value=+max(token_sizes)))

    sequences = []
    for size in token_sizes:
        sequences.append(
            torch.randn((size, dim), device=device, requires_grad=True))
    packed = pack_sequence(sequences, enforce_sorted=False)

    rolled = roll_packed_sequence(sequence=packed, shifts=shifts)
    reference = pack_sequence(
        [seq.roll(shifts, dims=[0]) for seq in sequences],
        enforce_sorted=False)

    assert_packed_sequence_close(rolled, reference)
    assert_grad_close(rolled.data, reference.data, inputs=sequences)