text_featurizer=text_featurizer, **vars(config.learning_config.eval_dataset_config)) else: train_dataset = ASRSliceDataset( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, **vars(config.learning_config.train_dataset_config)) eval_dataset = ASRSliceDataset( speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, **vars(config.learning_config.eval_dataset_config)) ctc_trainer = CTCTrainerGA(text_featurizer, config.learning_config.running_config) # Build DS2 model with ctc_trainer.strategy.scope(): jasper = Jasper(**config.model_config, vocabulary_size=text_featurizer.num_classes) jasper._build(speech_featurizer.shape) jasper.summary(line_length=120) # Compile ctc_trainer.compile(jasper, config.learning_config.optimizer_config, max_to_keep=args.max_ckpts) ctc_trainer.fit(train_dataset, eval_dataset, train_bs=args.tbs, eval_bs=args.ebs, train_acs=args.acs)
from tensorflow_asr.configs.user_config import UserConfig from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer from tensorflow_asr.runners.base_runners import BaseTester from tensorflow_asr.models.jasper import Jasper tf.random.set_seed(0) assert args.export config = UserConfig(DEFAULT_YAML, args.config, learning=True) speech_featurizer = TFSpeechFeaturizer(config["speech_config"]) text_featurizer = CharFeaturizer(config["decoder_config"]) # Build DS2 model jasper = Jasper(**config["model_config"], vocabulary_size=text_featurizer.num_classes) jasper._build(speech_featurizer.shape) jasper.load_weights(args.saved, by_name=True) jasper.summary(line_length=120) jasper.add_featurizers(speech_featurizer, text_featurizer) if args.tfrecords: test_dataset = ASRTFRecordDataset( data_paths=config["learning_config"]["dataset_config"]["test_paths"], tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"], speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, stage="test", shuffle=False ) else: test_dataset = ASRSliceDataset(