Example #1
import pytorch_lightning as pl

from nemo.collections.asr.models import EncDecCTCModel, configs
# NOTE: import path for `update_model_config` assumed from NeMo's config utilities
from nemo.utils.config_utils import update_model_config
from nemo.utils.exp_manager import exp_manager


def main(cfg):
    # Generate default asr model config
    asr_model_config = configs.EncDecCTCModelConfig()

    # Merge hydra updates with model config
    # `drop_missing_subconfig=True` is necessary here. Without it, while the data class will instantiate and be added
    # to the config, it contains test_ds.sample_rate = MISSING and test_ds.labels = MISSING.
    # This will raise a OmegaConf MissingMandatoryValue error when processing the dataloaders inside
    # model_utils.resolve_test_dataloaders(model=self) (used for multi data loader support).
    # In general, any operation that tries to use a DictConfig with MISSING in it will fail,
    # other than explicit update operations to change MISSING to some actual value.
    asr_model_config = update_model_config(asr_model_config,
                                           cfg,
                                           drop_missing_subconfigs=True)

    # From here on out, it's a general OmegaConf DictConfig, directly usable by our code.
    trainer = pl.Trainer(**asr_model_config.trainer)
    exp_manager(trainer, asr_model_config.get("exp_manager", None))
    asr_model = EncDecCTCModel(cfg=asr_model_config.model, trainer=trainer)

    trainer.fit(asr_model)
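
The MISSING behaviour described in the comments above is plain OmegaConf semantics and can be reproduced in isolation. A minimal sketch, assuming only `omegaconf` (the `DSConfig` dataclass and the `16000` sample rate are hypothetical stand-ins for a dataset subconfig, not part of the example above):

from dataclasses import dataclass

from omegaconf import MISSING, OmegaConf
from omegaconf.errors import MissingMandatoryValue


@dataclass
class DSConfig:  # hypothetical stand-in for a dataset subconfig
    sample_rate: int = MISSING


ds_cfg = OmegaConf.structured(DSConfig)

try:
    _ = ds_cfg.sample_rate  # reading a MISSING field raises
except MissingMandatoryValue:
    print("config is unusable until MISSING values are set")

ds_cfg.sample_rate = 16000  # an explicit update resolves the MISSING value
print(ds_cfg.sample_rate)   # now readable: 16000

This is why `drop_missing_subconfigs=True` removes such subconfigs entirely rather than leaving MISSING placeholders behind.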
Example #2
    # pytest method excerpt; assumes `from omegaconf import OmegaConf, open_dict`
    # and `update_model_config` from NeMo's config utilities are in scope
    def test_dataclass_instantiation(self, asr_model):
        model_cfg = configs.EncDecCTCModelConfig()

        # Update mandatory values
        vocabulary = asr_model.decoder.vocabulary
        model_cfg.model.labels = vocabulary

        # Update encoder
        model_cfg.model.encoder.activation = 'relu'
        model_cfg.model.encoder.feat_in = 64
        model_cfg.model.encoder.jasper = [
            nemo_asr.modules.conv_asr.JasperEncoderConfig(
                filters=1024,
                repeat=1,
                kernel=[1],
                stride=[1],
                dilation=[1],
                dropout=0.0,
                residual=False,
                se=True,
                se_context_size=-1,
            )
        ]

        # Update decoder
        model_cfg.model.decoder.feat_in = 1024
        model_cfg.model.decoder.num_classes = 28
        model_cfg.model.decoder.vocabulary = vocabulary

        # Construct the model
        asr_cfg = OmegaConf.create({'model': asr_model.cfg})
        model_cfg_v1 = update_model_config(model_cfg, asr_cfg)
        new_model = EncDecCTCModel(cfg=model_cfg_v1.model)

        assert new_model.num_weights == asr_model.num_weights
        # trainer and exp manager should be there
        # assert 'trainer' in model_cfg_v1
        # assert 'exp_manager' in model_cfg_v1
        # datasets and optim/sched should not be there after ModelPT.update_model_dataclass()
        assert 'train_ds' not in model_cfg_v1.model
        assert 'validation_ds' not in model_cfg_v1.model
        assert 'test_ds' not in model_cfg_v1.model
        assert 'optim' not in model_cfg_v1.model

        # Construct the model, without dropping additional keys
        asr_cfg = OmegaConf.create({'model': asr_model.cfg})
        model_cfg_v2 = update_model_config(model_cfg,
                                           asr_cfg,
                                           drop_missing_subconfigs=False)

        # Assert all components are in config
        # assert 'trainer' in model_cfg_v2
        # assert 'exp_manager' in model_cfg_v2
        assert 'train_ds' in model_cfg_v2.model
        assert 'validation_ds' in model_cfg_v2.model
        assert 'test_ds' in model_cfg_v2.model
        assert 'optim' in model_cfg_v2.model

        # Remove extra components (optim and sched can be kept without issue)
        with open_dict(model_cfg_v2.model):
            model_cfg_v2.model.pop('train_ds')
            model_cfg_v2.model.pop('validation_ds')
            model_cfg_v2.model.pop('test_ds')

        new_model = EncDecCTCModel(cfg=model_cfg_v2.model)

        assert new_model.num_weights == asr_model.num_weights
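
The `open_dict` context used above is what makes the `pop` calls legal: NeMo configs are struct-mode `DictConfig`s, which reject key addition or removal unless temporarily opened. A minimal sketch of that behaviour, assuming only `omegaconf`:

from omegaconf import OmegaConf, open_dict

cfg = OmegaConf.create({'model': {'feat_in': 64, 'train_ds': {}}})
OmegaConf.set_struct(cfg, True)  # struct mode: adding/removing keys now raises

with open_dict(cfg.model):
    cfg.model.pop('train_ds')  # permitted only while the dict is open

print('train_ds' in cfg.model)  # False

Without the context manager, the same `pop` raises an OmegaConf error in struct mode.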
Example #3

from dataclasses import asdict

import pytorch_lightning as pl

import nemo.collections.asr as nemo_asr
from nemo.collections.asr.models import EncDecCTCModel, configs
from nemo.utils.exp_manager import exp_manager
"""
python speech_to_text_structured.py
"""

# Generate default asr model config
cfg = configs.EncDecCTCModelConfig()

# Set global values
cfg.model.repeat = 5
cfg.model.separable = True

# fmt: off
LABELS = [
    " ",
    "a",
    "b",
    "c",
    "d",
    "e",
    "f",
    "g",