Exemple #1
0
        text_featurizer=text_featurizer,
        **vars(config.learning_config.eval_dataset_config))
else:
    train_dataset = ASRSliceDataset(
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        **vars(config.learning_config.train_dataset_config))
    eval_dataset = ASRSliceDataset(
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        **vars(config.learning_config.eval_dataset_config))

# Create the CTC trainer from the text featurizer and the running config.
ctc_trainer = CTCTrainerGA(text_featurizer,
                           config.learning_config.running_config)
# Build the Jasper model inside the trainer's distribution strategy scope,
# so the model is created under the same strategy the trainer exposes.
with ctc_trainer.strategy.scope():
    jasper = Jasper(**config.model_config,
                    vocabulary_size=text_featurizer.num_classes)
    # Build with the speech featurizer's shape before printing the summary.
    jasper._build(speech_featurizer.shape)
    jasper.summary(line_length=120)
# Compile the trainer with the model and the optimizer config.
# NOTE(review): max_to_keep presumably caps the number of retained
# checkpoints (fed from args.max_ckpts) -- confirm against the trainer.
ctc_trainer.compile(jasper,
                    config.learning_config.optimizer_config,
                    max_to_keep=args.max_ckpts)

# Run training with the train/eval datasets and the batch sizes given on
# the command line.
# NOTE(review): train_acs presumably is the number of gradient accumulation
# steps (args.acs) -- confirm against CTCTrainerGA.fit.
ctc_trainer.fit(train_dataset,
                eval_dataset,
                train_bs=args.tbs,
                eval_bs=args.ebs,
                train_acs=args.acs)
Exemple #2
0
from tensorflow_asr.configs.user_config import UserConfig
from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
from tensorflow_asr.runners.base_runners import BaseTester
from tensorflow_asr.models.jasper import Jasper

# Seed TensorFlow's global random generator with a fixed value.
tf.random.set_seed(0)
# args.export is required. Validate it with an explicit check rather than
# `assert`, because assert statements are removed when Python runs with -O
# and the check would then be silently skipped.
if not args.export:
    raise ValueError("args.export must be set")

# Load the user configuration on top of the default YAML.
config = UserConfig(DEFAULT_YAML, args.config, learning=True)
speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
text_featurizer = CharFeaturizer(config["decoder_config"])
# Build the Jasper model and restore its weights from args.saved.
jasper = Jasper(**config["model_config"], vocabulary_size=text_featurizer.num_classes)
# The model must be built (with the speech featurizer's shape) before
# weights can be loaded into it by layer name.
jasper._build(speech_featurizer.shape)
jasper.load_weights(args.saved, by_name=True)
jasper.summary(line_length=120)
# Attach both featurizers to the model.
jasper.add_featurizers(speech_featurizer, text_featurizer)

if args.tfrecords:
    test_dataset = ASRTFRecordDataset(
        data_paths=config["learning_config"]["dataset_config"]["test_paths"],
        tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"],
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        stage="test", shuffle=False
    )
else:
    test_dataset = ASRSliceDataset(