Example #1
def test_error_raised_when_label_len_lower_greater_than_upper(data,
                                                              kwargs) -> None:
    """Ensures ``ValueError`` raised when ``label_len[0] > label_leb[1]``."""
    upper = kwargs["label_len"][1]
    invalid_lower = data.draw(st.integers(min_value=upper + 1))
    kwargs["label_len"] = (invalid_lower, upper)
    with pytest.raises(ValueError):
        speech_to_text(**kwargs)
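The ``data`` and ``kwargs`` parameters here (and in the tests below) are not
plain pytest fixtures but hypothesis-provided values. A minimal sketch of the
likely wiring, assuming hypothesis's ``@given`` together with the
``random_speech_to_text_kwargs`` strategy referenced in Example #2:

import hypothesis.strategies as st
from hypothesis import given


@given(data=st.data(), kwargs=random_speech_to_text_kwargs())
def test_error_raised_when_label_len_lower_greater_than_upper(
        data, kwargs) -> None:
    ...  # body as in the example above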
Example #2
def random_speech_to_text(
    draw,
) -> st.SearchStrategy[Tuple[SpeechToTextGen, Dict]]:
    """Generates random ``speech_to_text`` functions and their kwargs."""
    kwargs = draw(random_speech_to_text_kwargs())
    return speech_to_text(**kwargs), kwargs
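Given the ``draw`` parameter and the ``st.SearchStrategy`` return annotation,
this function is presumably decorated with hypothesis's ``@st.composite``. A
hedged usage sketch (the test name is hypothetical):

from hypothesis import given


@given(gen_and_kwargs=random_speech_to_text())
def test_generator_matches_kwargs(gen_and_kwargs) -> None:
    # With ``@st.composite`` applied, each draw is a ``(generator, kwargs)``
    # pair as returned above.
    generator, kwargs = gen_and_kwargs
    ...  # exercise ``generator`` against ``kwargs``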
Example #3
def test_error_raised_when_dtype_invalid(data, kwargs) -> None:
    """Ensures ``ValueError`` raised when ``audio_dtype`` invalid."""
    invalid_dtypes = [torch.float16, torch.uint8, torch.int8]
    kwargs["audio_dtype"] = data.draw(st.sampled_from(invalid_dtypes))
    with pytest.raises(ValueError):
        speech_to_text(**kwargs)
Example #4
def test_error_raised_when_audio_channels_less_than_one(data, kwargs) -> None:
    """Ensures ``ValueError`` raised when ``audio_channels < 1``."""
    kwargs["audio_channels"] = data.draw(
        st.integers(min_value=-1000, max_value=0))
    with pytest.raises(ValueError):
        speech_to_text(**kwargs)
Example #5
def test_error_raised_when_label_len_less_than_zero(data, kwargs) -> None:
    """Ensures ``ValueError`` raised when ``label_len[0] < 0``."""
    invalid_lower = data.draw(st.integers(min_value=-1000, max_value=-1))
    kwargs["label_len"] = (invalid_lower, kwargs["label_len"][1])
    with pytest.raises(ValueError):
        speech_to_text(**kwargs)
Example #6
def test_error_raised_when_audio_ms_less_than_one(data, kwargs) -> None:
    """Ensures ``ValueError`` raised when ``audio_ms[0] <= 0``."""
    invalid_lower = data.draw(st.integers(min_value=-1000, max_value=0))
    kwargs["audio_ms"] = (invalid_lower, kwargs["audio_ms"][1])
    with pytest.raises(ValueError):
        speech_to_text(**kwargs)
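Examples #3-#6 jointly pin down the argument validation that
``speech_to_text`` must perform. A hypothetical sketch of checks consistent
with these tests follows; the tests only establish that ``float16``,
``uint8``, and ``int8`` are rejected, so the full set of accepted dtypes is
an assumption:

import torch


def _validate_speech_to_text_args(
    audio_ms, audio_channels, audio_dtype, label_len
) -> None:
    # Not the actual myrtlespeech implementation; inferred from the tests.
    if audio_ms[0] < 1:
        raise ValueError("audio_ms lower bound must be >= 1")
    if audio_channels < 1:
        raise ValueError("audio_channels must be >= 1")
    if audio_dtype in (torch.float16, torch.uint8, torch.int8):
        raise ValueError(f"audio_dtype {audio_dtype} not supported")
    if label_len[0] < 0 or label_len[0] > label_len[1]:
        raise ValueError("label_len must satisfy 0 <= lower <= upper")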
Example #7
def build(
    dataset: dataset_pb2.Dataset,
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
    add_seq_len_to_transforms: bool = False,
    download: bool = False,
) -> torch.utils.data.Dataset:
    """Returns a :py:class:`torch.utils.data.Dataset` based on the config.

    Args:
        dataset: A :py:class:`myrtlespeech.protos.dataset_pb2.Dataset` protobuf
            object containing the config for the desired
            :py:class:`torch.utils.data.Dataset`.

        transform: Transform to pass to the
            :py:class:`torch.utils.data.Dataset`.

        target_transform: Target transform to pass to the
            :py:class:`torch.utils.data.Dataset`.

        add_seq_len_to_transforms: If :py:data:`True`, an additional function
            is applied after ``transform`` and ``target_transform`` that takes
            a value and returns a tuple of ``(value,
            torch.tensor(len(value)))``.

        download: If :py:data:`True` and the dataset does not exist, download
            it if possible.

    Returns:
        A :py:class:`torch.utils.data.Dataset` based on the config.

    Example:
        >>> from google.protobuf import text_format
        >>> dataset_cfg = text_format.Merge('''
        ... fake_speech_to_text {
        ...   dataset_len: 2;
        ...   audio_ms {
        ...     lower: 10;
        ...     upper: 100;
        ...   }
        ...   label_symbols: "abcde";
        ...   label_len {
        ...     lower: 1;
        ...     upper: 10;
        ...   }
        ... }
        ... ''', dataset_pb2.Dataset())
        >>> dataset = build(dataset_cfg, add_seq_len_to_transforms=True)
        >>> len(dataset)
        2
        >>> (audio, audio_len), (label, label_len) = dataset[0]
        >>> type(audio)
        <class 'torch.Tensor'>
        >>> bool(audio.size(-1) == audio_len)
        True
        >>> type(label)
        <class 'str'>
        >>> bool(len(label) == label_len)
        True
    """
    supported_dataset = dataset.WhichOneof("supported_datasets")

    if add_seq_len_to_transforms:
        transform = _add_seq_len(transform, len_fn=lambda x: x.size(-1))
        target_transform = _add_seq_len(target_transform, len_fn=len)

    if supported_dataset == "fake_speech_to_text":
        cfg = dataset.fake_speech_to_text
        dataset = FakeDataset(
            generator=speech_to_text(
                audio_ms=(cfg.audio_ms.lower, cfg.audio_ms.upper),
                label_symbols=cfg.label_symbols,
                label_len=(cfg.label_len.lower, cfg.label_len.upper),
                audio_transform=transform,
                label_transform=target_transform,
            ),
            dataset_len=cfg.dataset_len,
        )
    elif supported_dataset == "librispeech":
        cfg = dataset.librispeech
        max_duration = cfg.max_secs.value if cfg.HasField("max_secs") else None
        dataset = LibriSpeech(
            root=cfg.root,
            subsets=[
                cfg.SUBSET.DESCRIPTOR.values_by_number[subset_idx]
                .name.lower()
                .replace("_", "-")
                for subset_idx in cfg.subset
            ],
            audio_transform=transform,
            label_transform=target_transform,
            download=download,
            max_duration=max_duration,
        )
    else:
        raise ValueError(f"{supported_dataset} not supported")

    return dataset
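``_add_seq_len`` is not part of this excerpt; from the
``add_seq_len_to_transforms`` docstring above, a plausible reconstruction
(an assumption, not the actual myrtlespeech helper):

from typing import Callable, Optional

import torch


def _add_seq_len(transform: Optional[Callable], len_fn: Callable) -> Callable:
    # Wraps an optional transform so the result is returned together with
    # its sequence length, i.e. ``(value, torch.tensor(len_fn(value)))``.
    def with_len(value):
        if transform is not None:
            value = transform(value)
        return value, torch.tensor(len_fn(value))

    return with_len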
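The doctest in ``build`` only covers the ``fake_speech_to_text`` branch. A
hedged sketch of a ``librispeech`` config; the enum value ``TRAIN_CLEAN_100``
is inferred from the ``.name.lower().replace("_", "-")`` conversion above and
may not match the real proto:

from google.protobuf import text_format

# Hypothetical config; field names follow the accesses in the
# ``librispeech`` branch above (root, subset, max_secs).
librispeech_cfg = text_format.Merge('''
librispeech {
  root: "/datasets";
  subset: TRAIN_CLEAN_100;
  max_secs { value: 100.0 }
}
''', dataset_pb2.Dataset())
dataset = build(librispeech_cfg, download=True)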