Пример #1
0
    def __init__(self,
                 name: str = 'quartznet_15x5',
                 encoder_cfg_func: Optional[Callable[[], List[Any]]] = None):
        """Create a CTC encoder/decoder model config from a named preset.

        Args:
            name: Preset identifier; must be listed in ``VALID_CONFIGS``.
                A name containing ``'quartznet_15x5'`` yields a
                ``QuartzNetModelConfig`` (separable convolutions); otherwise a
                name containing ``'jasper_10x5'`` yields a ``JasperModelConfig``
                (non-separable). If the name also contains ``'zh'``, dataset
                normalization is switched off after construction.
            encoder_cfg_func: Optional zero-argument callable whose result is
                passed as ``ConvASREncoderConfig(jasper=...)``. When omitted,
                ``qn_15x5`` (QuartzNet) or ``jasper_10x5_dr`` (Jasper) is used.

        Raises:
            ValueError: If ``name`` is not in ``VALID_CONFIGS`` or contains
                neither known family substring.
        """
        allowed = EncDecCTCModelConfigBuilder.VALID_CONFIGS
        if name not in allowed:
            raise ValueError("`name` must be one of : \n"
                             f"{allowed}")

        self.name = name

        # Guard clause: reject names that match neither model family.
        is_quartznet = 'quartznet_15x5' in name
        if not is_quartznet and 'jasper_10x5' not in name:
            raise ValueError(
                f"Invalid config name submitted to {self.__class__.__name__}")

        # Pick the per-family pieces; everything else is shared below.
        if is_quartznet:
            layers_factory = qn_15x5 if encoder_cfg_func is None else encoder_cfg_func
            config_cls, use_separable = QuartzNetModelConfig, True
        else:
            layers_factory = jasper_10x5_dr if encoder_cfg_func is None else encoder_cfg_func
            config_cls, use_separable = JasperModelConfig, False

        # Both families use identical rect-mask SpecAugment overrides.
        augment = SpectrogramAugmentationConfig(rect_masks=5,
                                                rect_freq=50,
                                                rect_time=120)
        encoder = ConvASREncoderConfig(jasper=layers_factory(),
                                       activation="relu")
        built_cfg = config_cls(
            repeat=5,
            separable=use_separable,
            spec_augment=augment,
            encoder=encoder,
            decoder=ConvASRDecoderConfig(),
        )

        super(EncDecCTCModelConfigBuilder, self).__init__(built_cfg)
        self.model_cfg: ctc_cfg.EncDecCTCConfig = built_cfg  # enable type hinting

        # Presets whose name contains 'zh' turn off dataset normalization.
        if 'zh' in name:
            self.set_dataset_normalize(normalize=False)
Пример #2
0
class JasperModelConfig(ctc_cfg.EncDecCTCConfig):
    """Preset configuration for a Jasper encoder-decoder CTC model.

    Specializes ``ctc_cfg.EncDecCTCConfig`` with Jasper defaults:
    ``separable=False``, an ``AudioToMelSpectrogramPreprocessorConfig``
    preprocessor, and a ``ConvASREncoderConfig`` with ReLU activation.
    ``labels`` defaults to ``MISSING`` (presumably omegaconf's marker for a
    mandatory value), so it has to be supplied before the config is used.

    NOTE(review): every non-scalar default below is a config instance that is
    built once, when the class body executes. All objects that keep a default
    therefore share that same instance, and mutating it affects them all. If
    this class is a ``@dataclass`` (no decorator is visible here), Python 3.11+
    rejects unhashable instance defaults. The usual fix is
    ``field(default_factory=...)``; TODO confirm the target Python version.
    """
    # Model global arguments
    sample_rate: int = 16000
    repeat: int = 1
    dropout: float = 0.0
    # Non-separable convolutions by default for Jasper.
    separable: bool = False
    labels: List[str] = MISSING

    # Dataset configs
    # Manifest paths default to None. Only the training split shuffles and
    # trims silence.
    train_ds: ctc_cfg.ASRDatasetConfig = ctc_cfg.ASRDatasetConfig(
        manifest_filepath=None, shuffle=True, trim_silence=True)
    validation_ds: ctc_cfg.ASRDatasetConfig = ctc_cfg.ASRDatasetConfig(
        manifest_filepath=None, shuffle=False)
    test_ds: ctc_cfg.ASRDatasetConfig = ctc_cfg.ASRDatasetConfig(
        manifest_filepath=None, shuffle=False)

    # Optimizer / Scheduler config
    # The optimizer default carries a default scheduler config.
    optim: Optional[model_cfg.OptimConfig] = model_cfg.OptimConfig(
        sched=model_cfg.SchedConfig())

    # Model general component configs
    preprocessor: AudioToMelSpectrogramPreprocessorConfig = AudioToMelSpectrogramPreprocessorConfig(
    )
    # SpecAugment uses its own defaults here. The builder may replace it with
    # rect-mask overrides.
    spec_augment: Optional[
        SpectrogramAugmentationConfig] = SpectrogramAugmentationConfig()
    encoder: ConvASREncoderConfig = ConvASREncoderConfig(activation="relu")
    decoder: ConvASRDecoderConfig = ConvASRDecoderConfig()
class EncDecClassificationConfig(model_cfg.ModelConfig):
    """Base configuration for an encoder-decoder classification model.

    Declares global hyperparameters, train/validation/test dataset configs,
    the optimizer, and the component configs: MFCC preprocessor,
    SpecAugment, crop-or-pad augmentation, convolutional encoder, and
    classification decoder. ``labels`` and ``timesteps`` default to
    ``MISSING`` (presumably omegaconf's marker for a mandatory value).

    NOTE(review): the component/dataset defaults are config instances that
    are created once at class-definition time and shared by every object that
    keeps the default. If this class is a ``@dataclass`` (no decorator is
    visible here), Python 3.11+ rejects unhashable instance defaults. Consider
    ``field(default_factory=...)``; TODO confirm the target Python version.
    """
    # Model global arguments
    sample_rate: int = 16000
    repeat: int = 1
    dropout: float = 0.0
    separable: bool = True
    kernel_size_factor: float = 1.0
    labels: List[str] = MISSING
    timesteps: int = MISSING

    # Dataset configs
    # Manifest paths default to None. Only the training split shuffles, and
    # silence trimming is explicitly off for training.
    train_ds: EncDecClassificationDatasetConfig = EncDecClassificationDatasetConfig(
        manifest_filepath=None, shuffle=True, trim_silence=False)
    validation_ds: EncDecClassificationDatasetConfig = EncDecClassificationDatasetConfig(
        manifest_filepath=None, shuffle=False)
    test_ds: EncDecClassificationDatasetConfig = EncDecClassificationDatasetConfig(
        manifest_filepath=None, shuffle=False)

    # Optimizer / Scheduler config
    optim: Optional[model_cfg.OptimConfig] = model_cfg.OptimConfig(
        sched=model_cfg.SchedConfig())

    # Model component configs
    preprocessor: AudioToMFCCPreprocessorConfig = AudioToMFCCPreprocessorConfig(
    )
    spec_augment: Optional[
        SpectrogramAugmentationConfig] = SpectrogramAugmentationConfig()
    # `timesteps` here is the class-body name bound just above, so
    # audio_length is fixed to that default value (MISSING) when the class is
    # defined. Overriding `timesteps` later, in a subclass or an instance, does
    # NOT update this default; such callers must also set crop_or_pad_augment.
    crop_or_pad_augment: Optional[
        CropOrPadSpectrogramAugmentationConfig] = CropOrPadSpectrogramAugmentationConfig(
            audio_length=timesteps)

    encoder: ConvASREncoderConfig = ConvASREncoderConfig()
    decoder: ConvASRDecoderClassificationConfig = ConvASRDecoderClassificationConfig(
    )
Пример #4
0
    def __init__(self,
                 name: str = 'matchboxnet_3x1x64',
                 encoder_cfg_func: Optional[Callable[[], List[Any]]] = None):
        """Create a MatchboxNet classification model config from a named preset.

        Args:
            name: Preset identifier; must be listed in ``VALID_CONFIGS``.
                A name containing ``'matchboxnet_3x1x64_vad'`` yields a
                ``MatchboxNetVADModelConfig``. Any other name containing
                ``'matchboxnet_3x1x64'`` yields a ``MatchboxNetModelConfig``
                with rect-mask SpecAugment overrides.
            encoder_cfg_func: Optional zero-argument callable whose result is
                passed as ``ConvASREncoderConfig(jasper=...)``. When omitted,
                ``matchboxnet_3x1x64_vad`` or ``matchboxnet_3x1x64`` is used,
                matching the selected preset.

        Raises:
            ValueError: If ``name`` is not in ``VALID_CONFIGS`` or contains
                neither known preset substring.
        """
        allowed = EncDecClassificationModelConfigBuilder.VALID_CONFIGS
        if name not in allowed:
            raise ValueError(
                "`name` must be one of : \n"
                f"{allowed}")

        self.name = name

        # Test the VAD preset first: 'matchboxnet_3x1x64' is a substring of
        # 'matchboxnet_3x1x64_vad', so the reverse order would misroute it.
        is_vad = 'matchboxnet_3x1x64_vad' in name
        if not is_vad and 'matchboxnet_3x1x64' not in name:
            raise ValueError(
                f"Invalid config name submitted to {self.__class__.__name__}")

        if is_vad:
            layers_factory = (matchboxnet_3x1x64_vad
                              if encoder_cfg_func is None else encoder_cfg_func)
            built_cfg = MatchboxNetVADModelConfig(
                repeat=1,
                separable=True,
                encoder=ConvASREncoderConfig(jasper=layers_factory(),
                                             activation="relu"),
                decoder=ConvASRDecoderClassificationConfig(),
            )
        else:
            layers_factory = (matchboxnet_3x1x64
                              if encoder_cfg_func is None else encoder_cfg_func)
            # NOTE(review): separable=False and these rect-mask values differ
            # from MatchboxNetModelConfig's own defaults (separable=True,
            # rect_freq=15, rect_time=25). Kept as-is; confirm intent.
            augment = SpectrogramAugmentationConfig(rect_masks=5,
                                                    rect_freq=50,
                                                    rect_time=120)
            encoder = ConvASREncoderConfig(jasper=layers_factory(),
                                           activation="relu")
            built_cfg = MatchboxNetModelConfig(
                repeat=1,
                separable=False,
                spec_augment=augment,
                encoder=encoder,
                decoder=ConvASRDecoderClassificationConfig(),
            )

        super(EncDecClassificationModelConfigBuilder, self).__init__(built_cfg)
        self.model_cfg: clf_cfg.EncDecClassificationConfig = built_cfg  # enable type hinting
Пример #5
0
class MatchboxNetModelConfig(clf_cfg.EncDecClassificationConfig):
    """Preset configuration for a MatchboxNet classification model.

    Specializes ``clf_cfg.EncDecClassificationConfig`` with concrete
    defaults:
    - ``timesteps=128``;
    - an MFCC preprocessor with ``window_size=0.025``;
    - explicit SpecAugment frequency/time/rect mask settings;
    - crop-or-pad to 128 frames;
    - a ReLU ``ConvASREncoderConfig``.

    ``labels`` stays ``MISSING`` (presumably omegaconf's marker for a
    mandatory value).

    NOTE(review): the component/dataset defaults are config instances created
    once at class-definition time and shared by every object that keeps the
    default. If this class is a ``@dataclass`` (no decorator is visible here),
    Python 3.11+ rejects unhashable instance defaults. Consider
    ``field(default_factory=...)``; TODO confirm the target Python version.
    """
    # Model global arguments
    sample_rate: int = 16000
    repeat: int = 1
    dropout: float = 0.0
    separable: bool = True
    kernel_size_factor: float = 1.0
    # Must stay in sync with crop_or_pad_augment.audio_length below. The two
    # values are written separately and are not linked.
    timesteps: int = 128
    labels: List[str] = MISSING

    # Dataset configs
    # Manifest paths default to None. Only the training split shuffles, and
    # silence trimming is explicitly off for training.
    train_ds: clf_cfg.EncDecClassificationDatasetConfig = clf_cfg.EncDecClassificationDatasetConfig(
        manifest_filepath=None, shuffle=True, trim_silence=False)
    validation_ds: clf_cfg.EncDecClassificationDatasetConfig = clf_cfg.EncDecClassificationDatasetConfig(
        manifest_filepath=None, shuffle=False)
    test_ds: clf_cfg.EncDecClassificationDatasetConfig = clf_cfg.EncDecClassificationDatasetConfig(
        manifest_filepath=None, shuffle=False)

    # Optimizer / Scheduler config
    optim: Optional[model_cfg.OptimConfig] = model_cfg.OptimConfig(
        sched=model_cfg.SchedConfig())

    # Model general component configs
    # window_size is presumably in seconds (0.025 -> 25 ms); TODO confirm.
    preprocessor: AudioToMFCCPreprocessorConfig = AudioToMFCCPreprocessorConfig(
        window_size=0.025)
    spec_augment: Optional[
        SpectrogramAugmentationConfig] = SpectrogramAugmentationConfig(
            freq_masks=2,
            time_masks=2,
            freq_width=15,
            time_width=25,
            rect_masks=5,
            rect_time=25,
            rect_freq=15)
    # Hard-coded to match `timesteps` above (128).
    crop_or_pad_augment: Optional[
        CropOrPadSpectrogramAugmentationConfig] = CropOrPadSpectrogramAugmentationConfig(
            audio_length=128)

    encoder: ConvASREncoderConfig = ConvASREncoderConfig(activation="relu")
    decoder: ConvASRDecoderClassificationConfig = ConvASRDecoderClassificationConfig(
    )