def __init__(self, name: str = 'quartznet_15x5', encoder_cfg_func: Optional[Callable[[], List[Any]]] = None):
    """Assemble the model config for a named CTC preset.

    Args:
        name: Preset identifier; must appear in
            ``EncDecCTCModelConfigBuilder.VALID_CONFIGS``. A name containing
            ``'quartznet_15x5'`` selects ``QuartzNetModelConfig`` with
            ``separable=True``; a name containing ``'jasper_10x5'`` selects
            ``JasperModelConfig`` with ``separable=False``. Names containing
            ``'zh'`` additionally get ``set_dataset_normalize(normalize=False)``.
        encoder_cfg_func: Zero-argument callable whose return value is passed
            as ``jasper=`` to ``ConvASREncoderConfig``. When None, ``qn_15x5``
            is used for QuartzNet presets and ``jasper_10x5_dr`` for Jasper
            presets.

    Raises:
        ValueError: If ``name`` is not in ``VALID_CONFIGS``, or is listed there
            but matches neither the QuartzNet nor the Jasper family.
    """
    if name not in EncDecCTCModelConfigBuilder.VALID_CONFIGS:
        raise ValueError("`name` must be one of : \n" f"{EncDecCTCModelConfigBuilder.VALID_CONFIGS}")

    self.name = name

    # Only the config class, the fallback encoder factory and `separable`
    # differ between families; every other hyper-parameter below is shared.
    if 'quartznet_15x5' in name:
        config_cls, fallback_encoder, use_separable = QuartzNetModelConfig, qn_15x5, True
    elif 'jasper_10x5' in name:
        config_cls, fallback_encoder, use_separable = JasperModelConfig, jasper_10x5_dr, False
    else:
        raise ValueError(
            f"Invalid config name submitted to {self.__class__.__name__}")

    block_factory = fallback_encoder if encoder_cfg_func is None else encoder_cfg_func

    # Named `preset_cfg` (not `model_cfg`) so it does not shadow the
    # module-level `model_cfg` name used elsewhere in this file.
    preset_cfg = config_cls(
        repeat=5,
        separable=use_separable,
        spec_augment=SpectrogramAugmentationConfig(rect_masks=5, rect_freq=50, rect_time=120),
        encoder=ConvASREncoderConfig(jasper=block_factory(), activation="relu"),
        decoder=ConvASRDecoderConfig(),
    )

    super(EncDecCTCModelConfigBuilder, self).__init__(preset_cfg)
    self.model_cfg: ctc_cfg.EncDecCTCConfig = preset_cfg  # annotation enables type hinting

    # Presets whose name contains 'zh' switch dataset normalization off.
    if 'zh' in name:
        self.set_dataset_normalize(normalize=False)
class JasperModelConfig(ctc_cfg.EncDecCTCConfig):
    """Default hyper-parameters for a Jasper encoder-decoder CTC model config.

    Nested sub-config defaults are produced via ``default_factory`` so each
    instance gets its own sub-config objects. Using a config *instance* as a
    dataclass default would share one mutable object across every instance
    (mutating ``cfg.train_ds`` on one would alter the default for all), and
    Python 3.11+ ``dataclasses`` rejects such unhashable defaults outright.

    NOTE(review): assumes this class is decorated with ``@dataclass`` where it
    is defined (the builder constructs it with ``repeat=``/``separable=``
    keyword arguments) — confirm the decorator is present.
    """

    # Imported inside the class body and deleted at the end so this block does
    # not depend on the module importing `field`; the name is removed before
    # the class object is created, so it is never a class attribute or field.
    from dataclasses import field as _field

    # Model global arguments
    sample_rate: int = 16000
    repeat: int = 1
    dropout: float = 0.0
    separable: bool = False
    labels: List[str] = MISSING

    # Dataset configs (fresh object per instance)
    train_ds: ctc_cfg.ASRDatasetConfig = _field(
        default_factory=lambda: ctc_cfg.ASRDatasetConfig(manifest_filepath=None, shuffle=True, trim_silence=True)
    )
    validation_ds: ctc_cfg.ASRDatasetConfig = _field(
        default_factory=lambda: ctc_cfg.ASRDatasetConfig(manifest_filepath=None, shuffle=False)
    )
    test_ds: ctc_cfg.ASRDatasetConfig = _field(
        default_factory=lambda: ctc_cfg.ASRDatasetConfig(manifest_filepath=None, shuffle=False)
    )

    # Optimizer / Scheduler config
    optim: Optional[model_cfg.OptimConfig] = _field(
        default_factory=lambda: model_cfg.OptimConfig(sched=model_cfg.SchedConfig())
    )

    # Model general component configs
    preprocessor: AudioToMelSpectrogramPreprocessorConfig = _field(
        default_factory=AudioToMelSpectrogramPreprocessorConfig
    )
    spec_augment: Optional[SpectrogramAugmentationConfig] = _field(
        default_factory=SpectrogramAugmentationConfig
    )
    encoder: ConvASREncoderConfig = _field(default_factory=lambda: ConvASREncoderConfig(activation="relu"))
    decoder: ConvASRDecoderConfig = _field(default_factory=ConvASRDecoderConfig)

    del _field