Ejemplo n.º 1
0
    def __init__(self,
                 name: str = 'quartznet_15x5',
                 encoder_cfg_func: Optional[Callable[[], List[Any]]] = None):
        """Assemble the model config for one of the predefined CTC architectures.

        Args:
            name: Configuration name; must appear in
                ``EncDecCTCModelConfigBuilder.VALID_CONFIGS``. The model family
                is picked by substring: ``'quartznet_15x5'`` selects
                ``QuartzNetModelConfig`` with ``separable=True``; otherwise
                ``'jasper_10x5'`` selects ``JasperModelConfig`` with
                ``separable=False``. A name that also contains ``'zh'``
                additionally calls ``set_dataset_normalize(normalize=False)``.
            encoder_cfg_func: Zero-argument callable whose return value is
                passed as ``jasper=`` to ``ConvASREncoderConfig``. When
                ``None``, ``qn_15x5`` (QuartzNet) or ``jasper_10x5_dr``
                (Jasper) is used instead.

        Raises:
            ValueError: If ``name`` is not in ``VALID_CONFIGS``, or if it
                contains neither the QuartzNet nor the Jasper marker.
        """
        allowed = EncDecCTCModelConfigBuilder.VALID_CONFIGS
        if name not in allowed:
            raise ValueError("`name` must be one of : \n"
                             f"{allowed}")

        self.name = name

        # The QuartzNet marker takes precedence over the Jasper marker; a name
        # carrying neither is rejected before anything is constructed.
        is_quartznet = 'quartznet_15x5' in name
        if not (is_quartznet or 'jasper_10x5' in name):
            raise ValueError(
                f"Invalid config name submitted to {self.__class__.__name__}")

        # Fall back to the family's default encoder layout only when the
        # caller did not supply one.
        if encoder_cfg_func is None:
            encoder_cfg_func = qn_15x5 if is_quartznet else jasper_10x5_dr

        config_cls = QuartzNetModelConfig if is_quartznet else JasperModelConfig

        # Both families share every setting except the config class and the
        # convolution type: QuartzNet is separable, Jasper is not.
        model_cfg = config_cls(
            repeat=5,
            separable=is_quartznet,
            spec_augment=SpectrogramAugmentationConfig(
                rect_masks=5, rect_freq=50, rect_time=120),
            encoder=ConvASREncoderConfig(
                jasper=encoder_cfg_func(), activation="relu"),
            decoder=ConvASRDecoderConfig(),
        )

        super(EncDecCTCModelConfigBuilder, self).__init__(model_cfg)
        # Re-bind with the concrete annotation so IDEs / type checkers see it.
        self.model_cfg: ctc_cfg.EncDecCTCConfig = model_cfg

        if 'zh' in name:
            # Configs whose name contains 'zh' turn dataset normalization off.
            self.set_dataset_normalize(normalize=False)
Ejemplo n.º 2
0
class JasperModelConfig(ctc_cfg.EncDecCTCConfig):
    """Default configuration values for a Jasper-family CTC model.

    Extends ``ctc_cfg.EncDecCTCConfig`` with concrete defaults for the global
    model arguments, the train/validation/test dataset sections, the
    optimizer/scheduler section, and the preprocessor, spec-augment, encoder
    and decoder component sections.

    NOTE(review): each config-object default below (datasets, optim,
    preprocessor, spec_augment, encoder, decoder) is a single instance built
    once when this class body executes, so every user of the class-level
    default shares that same object. If this class is processed as a
    dataclass, consider ``dataclasses.field(default_factory=...)`` -- on
    Python 3.11+ dataclasses reject unhashable instance defaults.
    """

    # ----- Model global arguments -----
    sample_rate: int = 16000
    repeat: int = 1
    dropout: float = 0.0
    separable: bool = False
    # MISSING sentinel: presumably a mandatory value the user must supply
    # (OmegaConf convention) -- confirm against the config loader.
    labels: List[str] = MISSING

    # ----- Dataset configs -----
    # Training split: shuffle on, trim_silence on.
    train_ds: ctc_cfg.ASRDatasetConfig = ctc_cfg.ASRDatasetConfig(
        manifest_filepath=None,
        shuffle=True,
        trim_silence=True,
    )
    # Evaluation splits: shuffle off.
    validation_ds: ctc_cfg.ASRDatasetConfig = ctc_cfg.ASRDatasetConfig(
        manifest_filepath=None,
        shuffle=False,
    )
    test_ds: ctc_cfg.ASRDatasetConfig = ctc_cfg.ASRDatasetConfig(
        manifest_filepath=None,
        shuffle=False,
    )

    # ----- Optimizer / scheduler config -----
    optim: Optional[model_cfg.OptimConfig] = model_cfg.OptimConfig(
        sched=model_cfg.SchedConfig(),
    )

    # ----- Model component configs -----
    preprocessor: AudioToMelSpectrogramPreprocessorConfig = (
        AudioToMelSpectrogramPreprocessorConfig())
    spec_augment: Optional[SpectrogramAugmentationConfig] = (
        SpectrogramAugmentationConfig())
    encoder: ConvASREncoderConfig = ConvASREncoderConfig(activation="relu")
    decoder: ConvASRDecoderConfig = ConvASRDecoderConfig()