Example #1
@st.composite
def tensor_sequences(draw) -> st.SearchStrategy[List[torch.Tensor]]:
    """Returns a search strategy for lists of Tensors with different seq len.

    The Tensors will have the same dtype and all dimensions except the last
    are guaranteed to be the same.
    """
    dtype = draw(torch_np_dtypes())

    # create at least one tensor; all other tensors in the sequence will have
    # the same dtype and size apart from the sequence length
    base_tensor = draw(tensors(min_n_dims=1, dtype=dtype))
    max_seq_len = base_tensor.size(-1)

    sequences = [base_tensor]
    for _ in range(draw(st.integers(min_value=1, max_value=32))):
        # update size to have different sequence length
        size = list(base_tensor.size())
        size[-1] = draw(st.integers(1, max_seq_len))

        # create new tensor and append to sequences
        seq = base_tensor.new().resize_(size)
        # fill with random values
        if dtype == np.float16:
            seq = seq.to(torch.float32).random_().to(torch.float16)
        else:
            seq.random_()

        sequences.append(seq)

    # shuffle to ensure tensor with maximum sequence length is not always first
    random.shuffle(sequences)

    return sequences
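A minimal property test built on this strategy (a sketch; it assumes the `tensor_sequences` strategy above is in scope and checks only the invariants documented in its docstring):

from typing import List

import torch
from hypothesis import given


@given(seqs=tensor_sequences())
def test_tensor_sequences_share_dtype_and_leading_dims(
        seqs: List[torch.Tensor]) -> None:
    """Every drawn tensor shares dtype and all dims except the last."""
    base = seqs[0]
    for seq in seqs[1:]:
        assert seq.dtype == base.dtype
        # all dimensions except the last (sequence length) must match
        assert seq.size()[:-1] == base.size()[:-1]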
Example #2
@st.composite
def ctc_loss_arguments(draw) -> st.SearchStrategy[Dict]:
    """Generates args for class constructor and forward."""
    ret_args = {}

    # generate input tensor
    ret_args["inputs"] = draw(
        tensors(min_n_dims=3, max_n_dims=3, min_dim_size=2, max_dim_size=32))
    # get shapes, convert to Python ints for Hypothesis
    max_seq_len, batch, features = ret_args["inputs"].size()
    max_seq_len = int(max_seq_len)
    batch = int(batch)
    features = int(features)

    # generate CTCLoss arguments
    ret_args["blank"] = draw(st.integers(min_value=0, max_value=features - 1))
    ret_args["reduction"] = draw(st.sampled_from(["none", "mean", "sum"]))
    ret_args["zero_infinity"] = draw(st.booleans())

    # generate remaining CTCLoss.forward arguments
    ret_args["targets"] = torch.tensor(
        draw(
            arrays(
                shape=(batch, max_seq_len),
                dtype=np.int32,
                elements=st.integers(
                    min_value=0, max_value=features - 1
                ).filter(lambda x: x != ret_args["blank"]),
            )),
        requires_grad=False,
    )

    ret_args["input_lengths"] = torch.tensor(
        draw(
            arrays(
                shape=(batch, ),
                dtype=np.int32,
                elements=st.integers(min_value=1, max_value=max_seq_len),
            )),
        requires_grad=False,
    )

    target_lengths = []
    for length in ret_args["input_lengths"]:
        # ensure CTC requirement that target length <= input length
        target_lengths.append(draw(st.integers(1, int(length))))
    ret_args["target_lengths"] = torch.tensor(target_lengths,
                                              dtype=torch.int32,
                                              requires_grad=False)

    return ret_args
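One way to consume these arguments (a sketch, not taken from the original tests; it assumes the `ctc_loss_arguments` strategy above is in scope and splits the dict into constructor and forward keys as generated above):

import torch
from hypothesis import given


@given(args=ctc_loss_arguments())
def test_ctc_loss_accepts_generated_arguments(args) -> None:
    """Smoke test: torch.nn.CTCLoss runs on the generated arguments."""
    ctc_loss = torch.nn.CTCLoss(
        blank=args["blank"],
        reduction=args["reduction"],
        zero_infinity=args["zero_infinity"],
    )
    # CTCLoss expects log-probabilities of shape (max_seq_len, batch, features)
    log_probs = torch.nn.functional.log_softmax(args["inputs"].float(), dim=-1)
    loss = ctc_loss(
        log_probs,
        args["targets"],
        args["input_lengths"],
        args["target_lengths"],
    )
    if args["reduction"] == "none":
        assert loss.size() == args["input_lengths"].size()
    else:
        assert loss.dim() == 0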
Example #3
@st.composite
def rnns_and_valid_inputs(draw) -> st.SearchStrategy[Tuple]:
    """Returns an RNN plus valid inputs, seq lengths, hidden state, and kwargs."""

    inp = draw(tensors(min_n_dims=3, max_n_dims=3))
    max_seq_len, batch_size, input_size = inp.size()

    rnn, kwargs = draw(rnns(return_kwargs=True, input_size=input_size))
    hard_lstm = isinstance(rnn, HardLSTM)
    if kwargs["batch_first"]:
        inp = inp.transpose(1, 0)

    hidden_state_setting = draw(rnn_hidden_settings())
    if hidden_state_setting == RNNHidStatus.HID_NONE:
        hid = None
    elif hidden_state_setting == RNNHidStatus.HID_NOT_NONE:
        num_directions = 1 + int(kwargs["bidirectional"])
        hidden_size = kwargs["hidden_size"]
        num_layers = kwargs["num_layers"]
        h_0 = torch.empty(
            [num_layers * num_directions, batch_size, hidden_size],
            requires_grad=False,
        ).normal_()
        if kwargs.get("rnn_type") in [RNNType.BASIC_RNN, RNNType.GRU]:
            hid = h_0
        elif kwargs.get("rnn_type") == RNNType.LSTM or hard_lstm:
            c_0 = h_0  # i.e. same dimensions
            hid = h_0, c_0
    else:
        raise ValueError(
            f"hidden_state_setting == {hidden_state_setting} not recognized."
        )

    seq_lens = torch.randint(
        low=1,
        high=max_seq_len + 1,
        size=[batch_size],
        dtype=torch.int32,
        requires_grad=False,
    )

    # sort lengths since we require enforce_sorted=True
    seq_lens = seq_lens.sort(descending=True)[0]

    return rnn, inp, seq_lens, hid, kwargs
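The descending sort matters because `torch.nn.utils.rnn.pack_padded_sequence` rejects unsorted lengths when `enforce_sorted=True` (the default); a minimal, self-contained illustration of that requirement:

import torch
from torch.nn.utils.rnn import pack_padded_sequence

# toy batch: max_seq_len=5, batch_size=3, input_size=4
inp = torch.randn(5, 3, 4)
seq_lens = torch.tensor([2, 5, 3], dtype=torch.int32)

# sort lengths (and reorder the batch to match) before packing
sorted_lens, perm = seq_lens.sort(descending=True)
packed = pack_padded_sequence(inp[:, perm], sorted_lens, enforce_sorted=True)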
Example #4
def test_fully_connected_raises_value_error_hidden_activation_fn_not_none(
    fully_connected_kwargs: Tuple[FullyConnected, Dict],
    hidden_activation_fn: torch.nn.Module,
) -> None:
    """Ensures ValueError raised when no hidden layers and no act fn."""
    _, kwargs = fully_connected_kwargs
    kwargs["num_hidden_layers"] = 0
    kwargs["hidden_activation_fn"] = hidden_activation_fn
    with pytest.raises(ValueError):
        FullyConnected(**kwargs)
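The guard this test exercises presumably lives in the FullyConnected constructor; a hedged sketch of such validation (an assumption mirroring the behaviour the test expects, not the project's actual code):

from typing import Optional

import torch


def _validate_fully_connected_args(
    num_hidden_layers: int,
    hidden_activation_fn: Optional[torch.nn.Module],
) -> None:
    """Hypothetical check: an activation fn without hidden layers is invalid."""
    if num_hidden_layers == 0 and hidden_activation_fn is not None:
        raise ValueError(
            "hidden_activation_fn must be None when num_hidden_layers == 0"
        )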


@given(
    fully_connected_kwargs=fully_connecteds(return_kwargs=True),
    tensor=tensors(min_n_dims=3, max_n_dims=3),
)
def test_fully_connected_forward_returns_correct_size(
    fully_connected_kwargs: Tuple[FullyConnected, Dict], tensor: torch.Tensor
) -> None:
    # create new FullyConnected that accepts in_features sized input
    _, kwargs = fully_connected_kwargs
    kwargs["in_features"] = tensor.size()[-1]
    fully_connected = FullyConnected(**kwargs)

    max_seq_len, batch_size, *_ = tensor.size()
    in_seq_lens = torch.randint(
        low=1,
        high=max_seq_len + 1,
        size=[batch_size],
        dtype=torch.int32,
Example #5
@st.composite
def spec_augments(draw) -> st.SearchStrategy[SpecAugment]:
    """Returns a SearchStrategy for SpecAugment."""
    kwargs: Dict = {}
    kwargs["feature_mask"] = draw(st.integers(0, 30))
    kwargs["time_mask"] = draw(st.integers(0, 30))
    kwargs["n_feature_masks"] = draw(st.integers(0, 3))
    kwargs["n_time_masks"] = draw(st.integers(0, 3))
    spec_augment = SpecAugment(**kwargs)
    return spec_augment
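A hedged usage sketch for this strategy (assumptions: `tensors` is imported from tests.utils.utils as in the later examples, SpecAugment is applied like a Module, and masking leaves the input shape unchanged):

import torch
from hypothesis import given

from tests.utils.utils import tensors


@given(spec_augment=spec_augments(),
       tensor=tensors(min_n_dims=3, max_n_dims=3))
def test_spec_augment_preserves_size(
        spec_augment: SpecAugment, tensor: torch.Tensor) -> None:
    """Masking should not change the shape of the input."""
    out = spec_augment(tensor)  # assumed call signature
    assert out.size() == tensor.size()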


# Tests -----------------------------------------------------------------------


@given(data=st.data(), tensor=tensors(min_n_dims=1))
def test_add_sequence_length_returns_correct_seq_len(
        data, tensor: torch.Tensor) -> None:
    """Ensures AddSequenceLength returns correct sequence length."""
    length_dim = data.draw(
        st.integers(min_value=0, max_value=len(tensor.size()) - 1))

    add_seq_len = AddSequenceLength(length_dim=length_dim)

    out, seq_len = add_seq_len(tensor)

    assert torch.all(out == tensor)
    assert seq_len == torch.tensor([tensor.size(length_dim)])


# SpecAugment ---------------------------
Example #6
import hypothesis.strategies as st
import torch
from hypothesis import given
from myrtlespeech.data.preprocess import AddSequenceLength

from tests.utils.utils import tensors

# Fixtures and Strategies -----------------------------------------------------

# Tests -----------------------------------------------------------------------


@given(data=st.data(), tensor=tensors(min_n_dims=1))
def test_add_sequence_length_returns_correct_seq_len(
        data, tensor: torch.Tensor) -> None:
    """Ensures AddSequenceLength returns correct sequence length."""
    length_dim = data.draw(
        st.integers(min_value=0, max_value=len(tensor.size()) - 1))

    add_seq_len = AddSequenceLength(length_dim=length_dim)

    out, seq_len = add_seq_len(tensor)

    assert torch.all(out == tensor)
    assert seq_len == torch.tensor([tensor.size(length_dim)])
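Judging by the assertions above, AddSequenceLength returns the input unchanged together with its size along `length_dim`; a hedged stand-in illustrating that behaviour (an assumption, not the myrtlespeech implementation):

import torch


class AddSequenceLengthSketch:
    """Hypothetical transform mirroring what the test asserts."""

    def __init__(self, length_dim: int = 0):
        self.length_dim = length_dim

    def __call__(self, tensor: torch.Tensor):
        seq_len = torch.tensor([tensor.size(self.length_dim)])
        return tensor, seq_len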
Example #7
from typing import Callable
from typing import Tuple

import pytest
import torch
from hypothesis import given
from myrtlespeech.model.utils import Lambda

from tests.utils.utils import tensors

# Fixtures and Strategies -----------------------------------------------------


@pytest.fixture(params=[lambda x: x * 2, lambda x: x + 1])
def lambda_fn(request) -> Callable:
    return request.param


# Tests -----------------------------------------------------------------------


@given(tensor=tensors())
def test_lambda_module_applies_lambda_fn(lambda_fn: Callable,
                                         tensor: torch.Tensor) -> None:
    """Ensures Lambda Module applies given lambda_fn to input."""
    lambda_module = Lambda(lambda_fn)
    assert torch.all(lambda_module(tensor) == lambda_fn(tensor))
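The Lambda module under test is presumably a thin torch.nn.Module wrapper; a hedged sketch of the usual pattern (an assumption, not necessarily the myrtlespeech implementation):

from typing import Callable

import torch


class LambdaSketch(torch.nn.Module):
    """Hypothetical wrapper that applies `lambda_fn` in forward."""

    def __init__(self, lambda_fn: Callable[[torch.Tensor], torch.Tensor]):
        super().__init__()
        self.lambda_fn = lambda_fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lambda_fn(x)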