예제 #1
0
    def test_EncDecCTCDatasetConfig_for_AudioToBPEDataset(self):
        """Check that EncDecCTCDatasetConfig agrees with the AudioToBPEDataset signature.

        The config dataclass is shared by several dataset classes, so arguments
        that do not apply to AudioToBPEDataset are ignored, and config fields whose
        names differ from the constructor parameters are remapped.
        """
        # ignore some additional arguments as dataclass is generic
        IGNORE_ARGS = [
            'is_tarred',
            'num_workers',
            'batch_size',
            'tarred_audio_filepaths',
            'shuffle',
            'pin_memory',
            'drop_last',
            'tarred_shard_strategy',
            'shuffle_n',
            'parser',
            'normalize',
            'unk_index',
            'pad_id',
            'bos_id',
            'eos_id',
            'blank_index',
        ]

        # Dataclass field name -> dataset constructor parameter name.
        REMAP_ARGS = {'trim_silence': 'trim', 'labels': 'tokenizer'}

        signatures_match, cls_subset, dataclass_subset = assert_dataclass_signature_match(
            audio_to_text.AudioToBPEDataset,
            configs.EncDecCTCDatasetConfig,
            ignore_args=IGNORE_ARGS,
            remap_args=REMAP_ARGS,
        )

        # Report the mismatching names on failure instead of a bare `False`.
        mismatch_msg = f"cls_subset={cls_subset}, dataclass_subset={dataclass_subset}"
        assert signatures_match, mismatch_msg
        assert cls_subset is None, mismatch_msg
        assert dataclass_subset is None, mismatch_msg
예제 #2
0
    def test_EncDecCTCDatasetConfig_for_TarredAudioToCharDataset(self):
        """Check that EncDecCTCDatasetConfig agrees with the TarredAudioToCharDataset signature.

        Generic config fields that the tarred dataset does not take are ignored,
        and the tarred-specific fields are remapped to the constructor's names.
        """
        # ignore some additional arguments as dataclass is generic
        IGNORE_ARGS = [
            'is_tarred',
            'num_workers',
            'batch_size',
            'shuffle',
            'pin_memory',
            'drop_last',
            'global_rank',
            'world_size',
            'use_start_end_token',
        ]

        # Dataclass field name -> dataset constructor parameter name.
        REMAP_ARGS = {
            'trim_silence': 'trim',
            'tarred_audio_filepaths': 'audio_tar_filepaths',
            'tarred_shard_strategy': 'shard_strategy',
            'shuffle_n': 'shuffle',
        }

        signatures_match, cls_subset, dataclass_subset = assert_dataclass_signature_match(
            audio_to_text.TarredAudioToCharDataset,
            configs.EncDecCTCDatasetConfig,
            ignore_args=IGNORE_ARGS,
            remap_args=REMAP_ARGS,
        )

        # Report the mismatching names on failure instead of a bare `False`.
        mismatch_msg = f"cls_subset={cls_subset}, dataclass_subset={dataclass_subset}"
        assert signatures_match, mismatch_msg
        assert cls_subset is None, mismatch_msg
        assert dataclass_subset is None, mismatch_msg
예제 #3
0
    def test_ptl_config(self):
        """Check that TrainerConfig agrees with the ptl.Trainer constructor signature.

        Trainer arguments listed in PTL_DEPRECATED are excluded from the comparison.
        """
        PTL_DEPRECATED = ['distributed_backend', 'automatic_optimization', 'gpus', 'num_processes']

        signatures_match, cls_subset, dataclass_subset = config_utils.assert_dataclass_signature_match(
            ptl.Trainer, TrainerConfig, ignore_args=PTL_DEPRECATED
        )

        # Report the mismatching names on failure instead of a bare `False`.
        mismatch_msg = f"cls_subset={cls_subset}, dataclass_subset={dataclass_subset}"
        assert signatures_match, mismatch_msg
        assert cls_subset is None, mismatch_msg
        assert dataclass_subset is None, mismatch_msg
예제 #4
0
    def test_MaskedPatchAugmentation_config(self):
        """Check that MaskedPatchAugmentationConfig agrees with the MaskedPatchAugmentation signature."""
        # Test that dataclass matches signature of module
        signatures_match, cls_subset, dataclass_subset = config_utils.assert_dataclass_signature_match(
            modules.MaskedPatchAugmentation, modules.audio_preprocessing.MaskedPatchAugmentationConfig,
        )

        # Report the mismatching names on failure instead of a bare `False`.
        mismatch_msg = f"cls_subset={cls_subset}, dataclass_subset={dataclass_subset}"
        assert signatures_match, mismatch_msg
        assert cls_subset is None, mismatch_msg
        assert dataclass_subset is None, mismatch_msg
예제 #5
0
    def test_AudioToMelSpectrogramPreprocessor_config(self):
        """Check that AudioToMelSpectrogramPreprocessorConfig agrees with the module signature."""
        # Test that dataclass matches signature of module
        signatures_match, cls_subset, dataclass_subset = config_utils.assert_dataclass_signature_match(
            modules.AudioToMelSpectrogramPreprocessor,
            modules.audio_preprocessing.AudioToMelSpectrogramPreprocessorConfig,
        )

        # Report the mismatching names on failure instead of a bare `False`.
        mismatch_msg = f"cls_subset={cls_subset}, dataclass_subset={dataclass_subset}"
        assert signatures_match, mismatch_msg
        assert cls_subset is None, mismatch_msg
        assert dataclass_subset is None, mismatch_msg
예제 #6
0
    def test_linear_adapter_config(self):
        """Check that LinearAdapterConfig agrees with the LinearAdapter constructor signature.

        `_target_` is a config-only field, so it is excluded from the comparison.
        """
        IGNORED_ARGS = ['_target_']

        signatures_match, cls_subset, dataclass_subset = config_utils.assert_dataclass_signature_match(
            adapter_modules.LinearAdapter, adapter_modules.LinearAdapterConfig, ignore_args=IGNORED_ARGS
        )

        # Report the mismatching names on failure instead of a bare `False`.
        mismatch_msg = f"cls_subset={cls_subset}, dataclass_subset={dataclass_subset}"
        assert signatures_match, mismatch_msg
        assert cls_subset is None, mismatch_msg
        assert dataclass_subset is None, mismatch_msg
예제 #7
0
    def test_BeamRNNTInferConfig(self):
        """Check that BeamRNNTInferConfig agrees with the BeamRNNTInfer constructor signature.

        Constructor arguments listed in IGNORE_ARGS are excluded from the comparison.
        """
        IGNORE_ARGS = ['decoder_model', 'joint_model', 'blank_index']

        signatures_match, cls_subset, dataclass_subset = assert_dataclass_signature_match(
            beam_decode.BeamRNNTInfer, beam_decode.BeamRNNTInferConfig, ignore_args=IGNORE_ARGS
        )

        # Report the mismatching names on failure instead of a bare `False`.
        mismatch_msg = f"cls_subset={cls_subset}, dataclass_subset={dataclass_subset}"
        assert signatures_match, mismatch_msg
        assert cls_subset is None, mismatch_msg
        assert dataclass_subset is None, mismatch_msg
예제 #8
0
    def test_all_args_dont_exist(self, cls):
        """A dataclass lacking one of `cls`'s arguments must be reported as a mismatch."""

        @dataclass
        class DummyDataClass:
            # Covers only part of the signature of `cls`; the check below
            # expects at least one name to be reported on the class side.
            a: int = -1
            b: int = 5
            c: int = 0

        signatures_match, cls_subset, dataclass_subset = config_utils.assert_dataclass_signature_match(
            cls, DummyDataClass
        )

        # Mismatch expected: names appear on the class side only.
        assert not signatures_match
        assert len(cls_subset) > 0
        assert len(dataclass_subset) == 0
예제 #9
0
    def test_all_args_exist(self, cls):
        """A dataclass whose fields line up with `cls`'s arguments must be reported as a match."""

        @dataclass
        class DummyDataClass:
            # Expected to cover the full signature of `cls`.
            a: int = -1
            b: int = 5
            c: int = 0
            d: Any = None

        signatures_match, cls_subset, dataclass_subset = config_utils.assert_dataclass_signature_match(
            cls, DummyDataClass
        )

        # On a match, both difference sets are reported as None.
        assert signatures_match
        assert cls_subset is None
        assert dataclass_subset is None
예제 #10
0
    def test_args_exist_but_is_remapped(self, cls):
        """A differently named dataclass field still matches once it is remapped onto `cls`'s name."""

        @dataclass
        class DummyDataClass:
            a: int = -1
            b: int = 5
            c: int = 0
            e: Any = None  # stands in for `d` via the remap below

        field_to_arg = {'e': 'd'}
        signatures_match, cls_subset, dataclass_subset = config_utils.assert_dataclass_signature_match(
            cls, DummyDataClass, remap_args=field_to_arg
        )

        # The remap should fully reconcile the two signatures.
        assert signatures_match
        assert cls_subset is None
        assert dataclass_subset is None
예제 #11
0
    def test_extra_args_exist_but_is_ignored(self, cls):
        """An extra dataclass field does not break the match when it is listed in `ignore_args`."""

        @dataclass
        class DummyDataClass:
            a: int = -1
            b: int = 5
            c: int = 0
            d: Any = None
            e: float = 0.0  # extra field, excluded via ignore_args

        skipped = ['e']
        signatures_match, cls_subset, dataclass_subset = config_utils.assert_dataclass_signature_match(
            cls, DummyDataClass, ignore_args=skipped
        )

        # With `e` ignored, the signatures should line up.
        assert signatures_match
        assert cls_subset is None
        assert dataclass_subset is None
예제 #12
0
    def test_extra_args_exist(self, cls):
        """An extra dataclass field that is not ignored must be reported as a mismatch."""

        @dataclass
        class DummyDataClass:
            a: int = -1
            b: int = 5
            c: int = 0
            d: Any = None
            e: float = 0.0  # extra field with no counterpart and no ignore entry

        signatures_match, cls_subset, dataclass_subset = config_utils.assert_dataclass_signature_match(
            cls, DummyDataClass
        )

        # Mismatch expected: names appear on the dataclass side only.
        assert not signatures_match
        assert len(cls_subset) == 0
        assert len(dataclass_subset) > 0