def test_error_raised_when_label_len_lower_greater_than_upper(data, kwargs) -> None:
    """Ensures ``ValueError`` raised when ``label_len[0] > label_len[1]``."""
    upper_bound = kwargs["label_len"][1]
    # Any lower bound strictly greater than the upper bound is invalid.
    bad_lower = data.draw(st.integers(min_value=upper_bound + 1))
    kwargs["label_len"] = (bad_lower, upper_bound)
    with pytest.raises(ValueError):
        speech_to_text(**kwargs)
def random_speech_to_text(
    draw,
) -> st.SearchStrategy[Tuple[SpeechToTextGen, Dict]]:
    """Returns a strategy yielding ``speech_to_text`` generators plus kwargs.

    The kwargs used to build each generator are returned alongside it so
    tests can inspect or mutate them.
    """
    drawn_kwargs = draw(random_speech_to_text_kwargs())
    generator = speech_to_text(**drawn_kwargs)
    return generator, drawn_kwargs
def test_error_raised_when_dtype_invalid(data, kwargs) -> None:
    """Ensures ``ValueError`` raised when ``audio_dtype`` invalid."""
    # These dtypes are rejected by speech_to_text.
    unsupported = [torch.float16, torch.uint8, torch.int8]
    kwargs["audio_dtype"] = data.draw(st.sampled_from(unsupported))
    with pytest.raises(ValueError):
        speech_to_text(**kwargs)
def test_error_raised_when_audio_channels_less_than_one(data, kwargs) -> None:
    """Ensures ``ValueError`` raised when ``audio_channels < 1``."""
    non_positive = st.integers(min_value=-1000, max_value=0)
    kwargs["audio_channels"] = data.draw(non_positive)
    with pytest.raises(ValueError):
        speech_to_text(**kwargs)
def test_error_raised_when_label_len_less_than_zero(data, kwargs) -> None:
    """Ensures ``ValueError`` raised when ``label_len[0] < 0``."""
    negative_lower = data.draw(st.integers(min_value=-1000, max_value=-1))
    # Keep the upper bound; only the lower bound is made invalid.
    kwargs["label_len"] = (negative_lower, kwargs["label_len"][1])
    with pytest.raises(ValueError):
        speech_to_text(**kwargs)
def test_error_raised_when_audio_ms_less_than_one(data, kwargs) -> None:
    """Ensures ``ValueError`` raised when ``audio_ms[0] <= 0``."""
    non_positive_lower = data.draw(st.integers(min_value=-1000, max_value=0))
    # Keep the upper bound; only the lower bound is made invalid.
    kwargs["audio_ms"] = (non_positive_lower, kwargs["audio_ms"][1])
    with pytest.raises(ValueError):
        speech_to_text(**kwargs)
def build(
    dataset: dataset_pb2.Dataset,
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
    add_seq_len_to_transforms: bool = False,
    download: bool = False,
) -> torch.utils.data.Dataset:
    """Returns a :py:class:`torch.utils.data.Dataset` based on the config.

    Args:
        dataset: A :py:class:`myrtlespeech.protos.dataset_pb2.Dataset`
            protobuf object containing the config for the desired
            :py:class:`torch.utils.data.Dataset`.

        transform: Transform to pass to the
            :py:class:`torch.utils.data.Dataset`.

        target_transform: Target transform to pass to the
            :py:class:`torch.utils.data.Dataset`.

        add_seq_len_to_transforms: If :py:data:`True`, an additional function
            is applied after ``transform`` and ``target_transform`` that
            takes a value and returns a tuple of
            ``(value, torch.tensor(len(value)))``.

        download: If :py:data:`True` and dataset does not exist, download it
            if possible.

    Returns:
        A :py:class:`torch.utils.data.Dataset` based on the config.

    Raises:
        ValueError: If the configured dataset type is not supported.

    Example:
        >>> from google.protobuf import text_format
        >>> dataset_cfg = text_format.Merge('''
        ... fake_speech_to_text {
        ...   dataset_len: 2;
        ...   audio_ms {
        ...     lower: 10;
        ...     upper: 100;
        ...   }
        ...   label_symbols: "abcde";
        ...   label_len {
        ...     lower: 1;
        ...     upper: 10;
        ...   }
        ... }
        ... ''', dataset_pb2.Dataset())
        >>> dataset = build(dataset_cfg, add_seq_len_to_transforms=True)
        >>> len(dataset)
        2
        >>> (audio, audio_len), (label, label_len) = dataset[0]
        >>> type(audio)
        <class 'torch.Tensor'>
        >>> bool(audio.size(-1) == audio_len)
        True
        >>> type(label)
        <class 'str'>
        >>> bool(len(label) == label_len)
        True
    """
    supported_dataset = dataset.WhichOneof("supported_datasets")

    if add_seq_len_to_transforms:
        # Wrap each transform so it also returns the length of its output:
        # the last dim for audio tensors, len() for string labels.
        transform = _add_seq_len(transform, len_fn=lambda x: x.size(-1))
        target_transform = _add_seq_len(target_transform, len_fn=len)

    if supported_dataset == "fake_speech_to_text":
        cfg = dataset.fake_speech_to_text
        return FakeDataset(
            generator=speech_to_text(
                audio_ms=(cfg.audio_ms.lower, cfg.audio_ms.upper),
                label_symbols=cfg.label_symbols,
                label_len=(cfg.label_len.lower, cfg.label_len.upper),
                audio_transform=transform,
                label_transform=target_transform,
            ),
            dataset_len=cfg.dataset_len,
        )

    if supported_dataset == "librispeech":
        cfg = dataset.librispeech
        max_duration = cfg.max_secs.value if cfg.HasField("max_secs") else None
        # Map proto enum values (e.g. TRAIN_CLEAN_100) to LibriSpeech subset
        # names (e.g. "train-clean-100").
        subset_names = [
            cfg.SUBSET.DESCRIPTOR.values_by_number[subset_idx]
            .name.lower()
            .replace("_", "-")
            for subset_idx in cfg.subset
        ]
        return LibriSpeech(
            root=cfg.root,
            subsets=subset_names,
            audio_transform=transform,
            label_transform=target_transform,
            download=download,
            max_duration=max_duration,
        )

    raise ValueError(f"{supported_dataset} not supported")