def main(cfg):
    """Build and train an RNN-T ASR model from a Hydra config, then optionally test it."""
    training_trainer = pl.Trainer(**cfg.trainer)
    exp_manager(training_trainer, cfg.get("exp_manager", None))
    model = EncDecRNNTModel(cfg=cfg.model, trainer=training_trainer)

    training_trainer.fit(model)

    # Evaluation only happens when the config declares a test manifest.
    if not hasattr(cfg.model, 'test_ds') or cfg.model.test_ds.manifest_filepath is None:
        return

    # Testing uses a fresh trainer limited to at most one GPU, keeping the
    # training precision setting.
    num_gpus = 0 if cfg.trainer.gpus == 0 else 1
    eval_trainer = pl.Trainer(gpus=num_gpus, precision=cfg.trainer.precision)
    if model.prepare_test(eval_trainer):
        eval_trainer.test(model)
def test_constructor(self, asr_model):
    """Round-trip the model through its config dict and verify the rebuilt type."""
    asr_model.train()
    # TODO: make proper config and assert correct number of weights
    # Serialize the model to a config dict, then rebuild a new instance from it.
    cfg_dict = asr_model.to_config_dict()
    rebuilt = EncDecRNNTModel.from_config_dict(cfg_dict)
    assert isinstance(rebuilt, EncDecRNNTModel)
def main(cfg):
    """Train an RNN-T ASR model and, when a test manifest is configured, evaluate it.

    Reuses the training trainer for the test pass.
    """
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))
    model = EncDecRNNTModel(cfg=cfg.model, trainer=trainer)

    # Initialize the weights of the model from another model, if provided via config.
    model.maybe_init_from_pretrained_checkpoint(cfg)

    trainer.fit(model)

    # Short-circuit keeps prepare_test from running when no test manifest exists.
    has_test_manifest = (
        hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None
    )
    if has_test_manifest and model.prepare_test(trainer):
        trainer.test(model)
def main(cfg):
    """Train an RNN-T ASR model, then test it on a dedicated single-GPU trainer."""
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))
    model = EncDecRNNTModel(cfg=cfg.model, trainer=trainer)

    # Initialize the weights of the model from another model, if provided via config.
    model.maybe_init_from_pretrained_checkpoint(cfg)

    trainer.fit(model)

    # Bail out unless the config names a test manifest.
    if not (hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None):
        return

    # Build a separate trainer for testing: at most one GPU, mirroring the
    # training run's precision/AMP configuration.
    # NOTE(review): reads trainer.accelerator_connector.amp_level — this attribute
    # path is Lightning-version dependent; verify against the pinned PL version.
    n_gpus = 1 if cfg.trainer.gpus != 0 else 0
    test_trainer = pl.Trainer(
        gpus=n_gpus,
        precision=trainer.precision,
        amp_level=trainer.accelerator_connector.amp_level,
        amp_backend=cfg.trainer.get("amp_backend", "native"),
    )
    if model.prepare_test(test_trainer):
        test_trainer.test(model)
def citrinet_rnnt_model():
    """Build a small Citrinet-style RNN-T model for testing.

    The label set is 1024 entries cycling over 28 distinct characters; the
    encoder is a 4-block separable-conv stack with squeeze-excite enabled.
    """
    # 1024 labels drawn cyclically from chr(0)..chr(27).
    labels = [chr(i % 28) for i in range(1024)]

    model_defaults = {'enc_hidden': 640, 'pred_hidden': 256, 'joint_hidden': 320}

    preprocessor = {
        'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor',
        'params': {},
    }

    # Four Citrinet-style conv blocks; the second block downsamples (stride 2,
    # stride_last + stride_add residual), the last widens to enc_hidden.
    conv_blocks = [
        {
            'filters': 512,
            'repeat': 1,
            'kernel': [5],
            'stride': [1],
            'dilation': [1],
            'dropout': 0.0,
            'residual': False,
            'separable': True,
            'se': True,
            'se_context_size': -1,
        },
        {
            'filters': 512,
            'repeat': 5,
            'kernel': [11],
            'stride': [2],
            'dilation': [1],
            'dropout': 0.1,
            'residual': True,
            'separable': True,
            'se': True,
            'se_context_size': -1,
            'stride_last': True,
            'residual_mode': 'stride_add',
        },
        {
            'filters': 512,
            'repeat': 5,
            'kernel': [13],
            'stride': [1],
            'dilation': [1],
            'dropout': 0.1,
            'residual': True,
            'separable': True,
            'se': True,
            'se_context_size': -1,
        },
        {
            'filters': 640,
            'repeat': 1,
            'kernel': [41],
            'stride': [1],
            'dilation': [1],
            'dropout': 0.0,
            'residual': True,
            'separable': True,
            'se': True,
            'se_context_size': -1,
        },
    ]

    encoder = {
        '_target_': 'nemo.collections.asr.modules.ConvASREncoder',
        'feat_in': 80,
        'activation': 'relu',
        'conv_mask': True,
        'jasper': conv_blocks,
    }

    decoder = {
        '_target_': 'nemo.collections.asr.modules.RNNTDecoder',
        'prednet': {'pred_hidden': 256, 'pred_rnn_layers': 1, 'dropout': 0.0},
    }

    joint = {
        '_target_': 'nemo.collections.asr.modules.RNNTJoint',
        'fuse_loss_wer': False,
        'jointnet': {'joint_hidden': 320, 'activation': 'relu', 'dropout': 0.0},
    }

    decoding = {'strategy': 'greedy_batch', 'greedy': {'max_symbols': 5}}

    model_cfg = DictConfig(
        {
            'preprocessor': DictConfig(preprocessor),
            'labels': labels,
            'model_defaults': DictConfig(model_defaults),
            'encoder': DictConfig(encoder),
            'decoder': DictConfig(decoder),
            'joint': DictConfig(joint),
            'decoding': DictConfig(decoding),
        }
    )
    return EncDecRNNTModel(cfg=model_cfg)
def asr_model():
    """Build a minimal single-block RNN-T model for testing.

    Uses a 28-symbol English character vocabulary (space, a-z, apostrophe),
    a one-block separable-conv encoder, and greedy-batch decoding.
    """
    preprocessor = {
        'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor',
        'params': {},
    }

    # Space, a-z, apostrophe — 28 symbols total.
    labels = list(" abcdefghijklmnopqrstuvwxyz'")

    model_defaults = {'enc_hidden': 1024, 'pred_hidden': 64}

    encoder = {
        'cls': 'nemo.collections.asr.modules.ConvASREncoder',
        'params': {
            'feat_in': 64,
            'activation': 'relu',
            'conv_mask': True,
            'jasper': [
                {
                    'filters': model_defaults['enc_hidden'],
                    'repeat': 1,
                    'kernel': [1],
                    'stride': [1],
                    'dilation': [1],
                    'dropout': 0.0,
                    'residual': False,
                    'separable': True,
                    'se': True,
                    'se_context_size': -1,
                }
            ],
        },
    }

    decoder = {
        '_target_': 'nemo.collections.asr.modules.RNNTDecoder',
        'prednet': {'pred_hidden': model_defaults['pred_hidden'], 'pred_rnn_layers': 1},
    }

    joint = {
        '_target_': 'nemo.collections.asr.modules.RNNTJoint',
        'jointnet': {'joint_hidden': 32, 'activation': 'relu'},
    }

    decoding = {'strategy': 'greedy_batch', 'greedy': {'max_symbols': 30}}

    model_cfg = DictConfig(
        {
            'labels': ListConfig(labels),
            'preprocessor': DictConfig(preprocessor),
            'model_defaults': DictConfig(model_defaults),
            'encoder': DictConfig(encoder),
            'decoder': DictConfig(decoder),
            'joint': DictConfig(joint),
            'decoding': DictConfig(decoding),
        }
    )
    return EncDecRNNTModel(cfg=model_cfg)