예제 #1
0
def main(cfg):
    """Train an RNNT ASR model described by ``cfg`` and optionally test it.

    Args:
        cfg: Config object exposing ``trainer`` (unpacked as keyword
            arguments into ``pl.Trainer``), ``model`` (passed to
            ``EncDecRNNTModel``) and, optionally, ``exp_manager``.

    After fitting, the model is tested only when ``cfg.model`` has a
    ``test_ds`` entry whose ``manifest_filepath`` is not None.
    """
    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))
    asr_model = EncDecRNNTModel(cfg=cfg.model, trainer=trainer)

    trainer.fit(asr_model)

    if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
        # Testing runs on a fresh trainer limited to at most one GPU.
        # ``gpus`` may be missing or None (Lightning documents both as a CPU
        # run); the previous ``!= 0`` check mapped None to one GPU, so read
        # the key defensively and treat None exactly like 0.
        requested_gpus = cfg.trainer.get("gpus", None)
        gpu = 0 if requested_gpus is None or requested_gpus == 0 else 1
        trainer = pl.Trainer(gpus=gpu, precision=cfg.trainer.precision)
        if asr_model.prepare_test(trainer):
            trainer.test(asr_model)
예제 #2
0
 def test_constructor(self, asr_model):
     """Check that the model can be rebuilt from its own config dict.

     ``asr_model`` is put into training mode, its config dict is exported,
     and a second ``EncDecRNNTModel`` is built from that dict.
     """
     asr_model.train()
     # TODO: make proper config and assert correct number of weights
     # Round-trip: export the config dict, then construct a new model from it.
     exported_cfg = asr_model.to_config_dict()
     rebuilt = EncDecRNNTModel.from_config_dict(exported_cfg)
     assert isinstance(rebuilt, EncDecRNNTModel)
예제 #3
0
def main(cfg):
    """Fit an RNNT ASR model from ``cfg``; test it when a test manifest is set.

    Args:
        cfg: Config object exposing ``trainer`` (unpacked as keyword
            arguments into ``pl.Trainer``), ``model`` (passed to
            ``EncDecRNNTModel``) and, optionally, ``exp_manager``.

    The same trainer instance is used for both fitting and testing.
    """
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))
    asr_model = EncDecRNNTModel(cfg=cfg.model, trainer=trainer)

    # Optionally seed the weights from another model named in the config.
    asr_model.maybe_init_from_pretrained_checkpoint(cfg)

    trainer.fit(asr_model)

    # Testing needs a ``test_ds`` section with a non-None manifest path.
    wants_test = hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None
    if not wants_test:
        return

    if asr_model.prepare_test(trainer):
        trainer.test(asr_model)
예제 #4
0
def main(cfg):
    """Fit an RNNT ASR model from ``cfg`` and test it on a separate trainer.

    Args:
        cfg: Config object exposing ``trainer`` (unpacked as keyword
            arguments into ``pl.Trainer``), ``model`` (passed to
            ``EncDecRNNTModel``) and, optionally, ``exp_manager``.

    After fitting, the model is tested only when ``cfg.model`` has a
    ``test_ds`` entry whose ``manifest_filepath`` is not None. The test
    trainer reuses the precision and AMP settings of the training trainer.
    """
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))
    asr_model = EncDecRNNTModel(cfg=cfg.model, trainer=trainer)

    # Initialize the weights of the model from another model, if provided via config
    asr_model.maybe_init_from_pretrained_checkpoint(cfg)

    trainer.fit(asr_model)

    if hasattr(cfg.model,
               'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
        # Testing runs on a fresh trainer limited to at most one GPU.
        # ``gpus`` may be missing or None (Lightning documents both as a CPU
        # run); the previous ``!= 0`` check mapped None to one GPU, so read
        # the key defensively and treat None exactly like 0.
        requested_gpus = cfg.trainer.get("gpus", None)
        gpu = 0 if requested_gpus is None or requested_gpus == 0 else 1
        test_trainer = pl.Trainer(
            gpus=gpu,
            precision=trainer.precision,
            amp_level=trainer.accelerator_connector.amp_level,
            amp_backend=cfg.trainer.get("amp_backend", "native"),
        )
        if asr_model.prepare_test(test_trainer):
            test_trainer.test(asr_model)
예제 #5
0
def citrinet_rnnt_model():
    """Build an ``EncDecRNNTModel`` from a config assembled in code.

    The config combines an ``AudioToMelSpectrogramPreprocessor``, a
    ``ConvASREncoder`` made of four jasper blocks, an ``RNNTDecoder``, an
    ``RNNTJoint`` and ``greedy_batch`` decoding. The label list holds 1024
    entries, ``chr(i % 28)`` for each index ``i``.

    Returns:
        The ``EncDecRNNTModel`` constructed from that config (no trainer is
        passed to the constructor).
    """

    def jasper_block(filters, repeat, kernel, stride, dropout, residual, **extra):
        # Every block shares dilation / separable / SE settings; block-specific
        # extras are appended last so the key order matches a literal dict.
        block = {
            'filters': filters,
            'repeat': repeat,
            'kernel': [kernel],
            'stride': [stride],
            'dilation': [1],
            'dropout': dropout,
            'residual': residual,
            'separable': True,
            'se': True,
            'se_context_size': -1,
        }
        block.update(extra)
        return block

    vocabulary = [chr(idx % 28) for idx in range(1024)]
    hidden_sizes = {'enc_hidden': 640, 'pred_hidden': 256, 'joint_hidden': 320}

    preprocessor_cfg = {
        'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor',
        'params': {},
    }

    encoder_cfg = {
        '_target_': 'nemo.collections.asr.modules.ConvASREncoder',
        'feat_in': 80,
        'activation': 'relu',
        'conv_mask': True,
        'jasper': [
            jasper_block(512, 1, 5, 1, 0.0, False),
            # The only strided block; it also carries the two extra keys.
            jasper_block(512, 5, 11, 2, 0.1, True, stride_last=True, residual_mode='stride_add'),
            jasper_block(512, 5, 13, 1, 0.1, True),
            jasper_block(640, 1, 41, 1, 0.0, True),
        ],
    }

    decoder_cfg = {
        '_target_': 'nemo.collections.asr.modules.RNNTDecoder',
        'prednet': {'pred_hidden': 256, 'pred_rnn_layers': 1, 'dropout': 0.0},
    }

    joint_cfg = {
        '_target_': 'nemo.collections.asr.modules.RNNTJoint',
        'fuse_loss_wer': False,
        'jointnet': {'joint_hidden': 320, 'activation': 'relu', 'dropout': 0.0},
    }

    decoding_cfg = {'strategy': 'greedy_batch', 'greedy': {'max_symbols': 5}}

    full_cfg = DictConfig(
        {
            'preprocessor': DictConfig(preprocessor_cfg),
            'labels': vocabulary,
            'model_defaults': DictConfig(hidden_sizes),
            'encoder': DictConfig(encoder_cfg),
            'decoder': DictConfig(decoder_cfg),
            'joint': DictConfig(joint_cfg),
            'decoding': DictConfig(decoding_cfg),
        }
    )
    return EncDecRNNTModel(cfg=full_cfg)
예제 #6
0
def asr_model():
    """Build a small ``EncDecRNNTModel`` from a config assembled in code.

    The config uses a 28-symbol label list (space, ``a``-``z``, apostrophe),
    a ``ConvASREncoder`` with a single jasper block, an ``RNNTDecoder``, an
    ``RNNTJoint`` and ``greedy_batch`` decoding.

    Returns:
        The ``EncDecRNNTModel`` constructed from that config (no trainer is
        passed to the constructor).
    """
    # Hidden sizes shared between ``model_defaults`` and the module configs.
    enc_hidden = 1024
    pred_hidden = 64

    # fmt: off
    vocabulary = [
        ' ', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
        'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', "'",
    ]
    # fmt: on

    preprocessor_cfg = {
        'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor',
        'params': {},
    }

    defaults_cfg = {'enc_hidden': enc_hidden, 'pred_hidden': pred_hidden}

    # Single jasper block whose filter count equals the encoder hidden size.
    only_block = {
        'filters': enc_hidden,
        'repeat': 1,
        'kernel': [1],
        'stride': [1],
        'dilation': [1],
        'dropout': 0.0,
        'residual': False,
        'separable': True,
        'se': True,
        'se_context_size': -1,
    }
    encoder_cfg = {
        'cls': 'nemo.collections.asr.modules.ConvASREncoder',
        'params': {
            'feat_in': 64,
            'activation': 'relu',
            'conv_mask': True,
            'jasper': [only_block],
        },
    }

    decoder_cfg = {
        '_target_': 'nemo.collections.asr.modules.RNNTDecoder',
        'prednet': {'pred_hidden': pred_hidden, 'pred_rnn_layers': 1},
    }

    joint_cfg = {
        '_target_': 'nemo.collections.asr.modules.RNNTJoint',
        'jointnet': {'joint_hidden': 32, 'activation': 'relu'},
    }

    decoding_cfg = {'strategy': 'greedy_batch', 'greedy': {'max_symbols': 30}}

    full_cfg = DictConfig(
        {
            'labels': ListConfig(vocabulary),
            'preprocessor': DictConfig(preprocessor_cfg),
            'model_defaults': DictConfig(defaults_cfg),
            'encoder': DictConfig(encoder_cfg),
            'decoder': DictConfig(decoder_cfg),
            'joint': DictConfig(joint_cfg),
            'decoding': DictConfig(decoding_cfg),
        }
    )
    return EncDecRNNTModel(cfg=full_cfg)