class EncDecClassificationConfig(model_cfg.ModelConfig):
    """Structured config for encoder-decoder classification models.

    Groups the model-wide hyperparameters, the train/validation/test dataset configs,
    the optimizer/scheduler config and the per-component configs (preprocessor,
    augmentations, encoder, decoder). ``labels`` and ``timesteps`` default to
    ``MISSING`` and are expected to be supplied by the user.

    Nested config defaults are built with ``field(default_factory=...)`` so that every
    instance owns its own objects. A dataclass instance used directly as a field
    default is shared by all instances (mutating ``cfg.train_ds`` on one config would
    change it for every other config) and is rejected by ``dataclasses`` on
    Python 3.11+ as an unhashable mutable default.
    """

    # Local import keeps this block self-contained (it can be dropped if the module
    # already imports ``field``); it is deleted at the end of the class body so it does
    # not remain as a class attribute. ``field`` only takes effect when the class is
    # processed by ``@dataclass``.
    from dataclasses import field as _field

    # Model global arguments
    sample_rate: int = 16000
    repeat: int = 1
    dropout: float = 0.0
    separable: bool = True
    kernel_size_factor: float = 1.0
    labels: List[str] = MISSING
    timesteps: int = MISSING

    # Dataset configs
    train_ds: EncDecClassificationDatasetConfig = _field(
        default_factory=lambda: EncDecClassificationDatasetConfig(
            manifest_filepath=None, shuffle=True, trim_silence=False
        )
    )
    validation_ds: EncDecClassificationDatasetConfig = _field(
        default_factory=lambda: EncDecClassificationDatasetConfig(manifest_filepath=None, shuffle=False)
    )
    test_ds: EncDecClassificationDatasetConfig = _field(
        default_factory=lambda: EncDecClassificationDatasetConfig(manifest_filepath=None, shuffle=False)
    )

    # Optimizer / Scheduler config
    optim: Optional[model_cfg.OptimConfig] = _field(
        default_factory=lambda: model_cfg.OptimConfig(sched=model_cfg.SchedConfig())
    )

    # Model component configs
    preprocessor: AudioToMFCCPreprocessorConfig = _field(default_factory=AudioToMFCCPreprocessorConfig)
    spec_augment: Optional[SpectrogramAugmentationConfig] = _field(default_factory=SpectrogramAugmentationConfig)
    # ``audio_length`` defaults to ``MISSING`` -- the class-level default of ``timesteps``.
    # It is not linked to an instance's ``timesteps`` value. It is written as ``MISSING``
    # rather than ``timesteps`` because the factory lambda cannot see class-body names.
    crop_or_pad_augment: Optional[CropOrPadSpectrogramAugmentationConfig] = _field(
        default_factory=lambda: CropOrPadSpectrogramAugmentationConfig(audio_length=MISSING)
    )
    encoder: ConvASREncoderConfig = _field(default_factory=ConvASREncoderConfig)
    decoder: ConvASRDecoderClassificationConfig = _field(default_factory=ConvASRDecoderClassificationConfig)

    del _field
def __init__(self, name: str = 'matchboxnet_3x1x64', encoder_cfg_func: Optional[Callable[[], List[Any]]] = None):
    """Select and build a MatchboxNet classification model config by name.

    Args:
        name: Config name; must be listed in
            ``EncDecClassificationModelConfigBuilder.VALID_CONFIGS``. Names containing
            ``'matchboxnet_3x1x64_vad'`` build a ``MatchboxNetVADModelConfig``; other
            names containing ``'matchboxnet_3x1x64'`` build a ``MatchboxNetModelConfig``.
        encoder_cfg_func: Optional zero-argument callable whose result is passed as
            ``jasper=`` to ``ConvASREncoderConfig``. When omitted,
            ``matchboxnet_3x1x64_vad`` or ``matchboxnet_3x1x64`` is used to match ``name``.

    Raises:
        ValueError: If ``name`` is not in ``VALID_CONFIGS``, or contains neither
            MatchboxNet variant string.
    """
    valid_configs = EncDecClassificationModelConfigBuilder.VALID_CONFIGS
    if name not in valid_configs:
        raise ValueError(f"`name` must be one of : \n{valid_configs}")

    self.name = name

    # Test the VAD variant first: 'matchboxnet_3x1x64' is a substring of
    # 'matchboxnet_3x1x64_vad', so the generic test alone would match both names.
    use_vad = 'matchboxnet_3x1x64_vad' in name
    if not (use_vad or 'matchboxnet_3x1x64' in name):
        raise ValueError(f"Invalid config name submitted to {self.__class__.__name__}")

    if use_vad:
        make_jasper = encoder_cfg_func if encoder_cfg_func is not None else matchboxnet_3x1x64_vad
        cfg = MatchboxNetVADModelConfig(
            repeat=1,
            separable=True,
            encoder=ConvASREncoderConfig(jasper=make_jasper(), activation="relu"),
            decoder=ConvASRDecoderClassificationConfig(),
        )
    else:
        make_jasper = encoder_cfg_func if encoder_cfg_func is not None else matchboxnet_3x1x64
        cfg = MatchboxNetModelConfig(
            repeat=1,
            separable=False,
            # Overrides only the rectangle-mask settings of the spec augmentation.
            spec_augment=SpectrogramAugmentationConfig(rect_masks=5, rect_freq=50, rect_time=120),
            encoder=ConvASREncoderConfig(jasper=make_jasper(), activation="relu"),
            decoder=ConvASRDecoderClassificationConfig(),
        )

    super(EncDecClassificationModelConfigBuilder, self).__init__(cfg)
    self.model_cfg: clf_cfg.EncDecClassificationConfig = cfg  # enable type hinting
class MatchboxNetModelConfig(clf_cfg.EncDecClassificationConfig):
    """MatchboxNet defaults layered on ``clf_cfg.EncDecClassificationConfig``.

    Compared with the base config this sets ``timesteps=128`` (and a matching
    ``crop_or_pad_augment.audio_length=128``), an MFCC preprocessor with
    ``window_size=0.025``, explicit spec-augmentation mask settings and a ReLU encoder.
    ``labels`` remains ``MISSING`` and is expected to be supplied by the user.

    Nested config defaults are built with ``field(default_factory=...)`` so every
    instance gets its own objects; a dataclass instance used directly as a default
    would be shared across all instances and is rejected by ``dataclasses`` on
    Python 3.11+ as an unhashable mutable default.
    """

    # Local import keeps this block self-contained (it can be dropped if the module
    # already imports ``field``); deleted at the end of the class body so it does not
    # remain as a class attribute. ``field`` only takes effect under ``@dataclass``.
    from dataclasses import field as _field

    # Model global arguments
    sample_rate: int = 16000
    repeat: int = 1
    dropout: float = 0.0
    separable: bool = True
    kernel_size_factor: float = 1.0
    timesteps: int = 128
    labels: List[str] = MISSING

    # Dataset configs
    train_ds: clf_cfg.EncDecClassificationDatasetConfig = _field(
        default_factory=lambda: clf_cfg.EncDecClassificationDatasetConfig(
            manifest_filepath=None, shuffle=True, trim_silence=False
        )
    )
    validation_ds: clf_cfg.EncDecClassificationDatasetConfig = _field(
        default_factory=lambda: clf_cfg.EncDecClassificationDatasetConfig(manifest_filepath=None, shuffle=False)
    )
    test_ds: clf_cfg.EncDecClassificationDatasetConfig = _field(
        default_factory=lambda: clf_cfg.EncDecClassificationDatasetConfig(manifest_filepath=None, shuffle=False)
    )

    # Optimizer / Scheduler config
    optim: Optional[model_cfg.OptimConfig] = _field(
        default_factory=lambda: model_cfg.OptimConfig(sched=model_cfg.SchedConfig())
    )

    # Model general component configs
    preprocessor: AudioToMFCCPreprocessorConfig = _field(
        default_factory=lambda: AudioToMFCCPreprocessorConfig(window_size=0.025)
    )
    spec_augment: Optional[SpectrogramAugmentationConfig] = _field(
        default_factory=lambda: SpectrogramAugmentationConfig(
            freq_masks=2, time_masks=2, freq_width=15, time_width=25, rect_masks=5, rect_time=25, rect_freq=15
        )
    )
    # 128 matches the ``timesteps`` default above; the two values are not linked, so
    # overriding one does not update the other.
    crop_or_pad_augment: Optional[CropOrPadSpectrogramAugmentationConfig] = _field(
        default_factory=lambda: CropOrPadSpectrogramAugmentationConfig(audio_length=128)
    )
    encoder: ConvASREncoderConfig = _field(default_factory=lambda: ConvASREncoderConfig(activation="relu"))
    decoder: ConvASRDecoderClassificationConfig = _field(default_factory=ConvASRDecoderClassificationConfig)

    del _field