def prepare_training_datasets(
    config: Config,
    speech_featurizer: SpeechFeaturizer,
    text_featurizer: TextFeaturizer,
    tfrecords: bool = False,
    metadata: str = None,
):
    """Build the training and evaluation datasets and load their metadata.

    Args:
        config: Configuration object. ``learning_config.train_dataset_config``
            and ``learning_config.eval_dataset_config`` are expanded with
            ``vars()`` into keyword arguments for the dataset constructors.
        speech_featurizer: Forwarded to both dataset constructors.
        text_featurizer: Forwarded to both dataset constructors.
        tfrecords: When truthy, ``asr_dataset.ASRTFRecordDataset`` is used;
            otherwise ``asr_dataset.ASRSliceDataset``.
        metadata: Passed unchanged to ``load_metadata`` on both datasets.

    Returns:
        A ``(train_dataset, eval_dataset)`` tuple.
    """
    # Both datasets always share the same class; only the config differs.
    dataset_cls = (
        asr_dataset.ASRTFRecordDataset
        if tfrecords
        else asr_dataset.ASRSliceDataset
    )

    def build(dataset_config):
        # Every dataset is created with ``indefinite=True``.
        return dataset_cls(
            speech_featurizer=speech_featurizer,
            text_featurizer=text_featurizer,
            **vars(dataset_config),
            indefinite=True
        )

    # Train is built before eval, and metadata is loaded in the same order.
    train_dataset = build(config.learning_config.train_dataset_config)
    eval_dataset = build(config.learning_config.eval_dataset_config)
    for dataset in (train_dataset, eval_dataset):
        dataset.load_metadata(metadata)
    return train_dataset, eval_dataset
# Example no. 2
0
# Select the text featurizer: ``args.sentence_piece`` is checked first, then
# ``args.subwords``; character-level featurization is the fallback.
if args.sentence_piece:
    log_message = "Loading SentencePiece model ..."
    featurizer_name = "SentencePieceFeaturizer"
elif args.subwords:
    log_message = "Loading subwords ..."
    featurizer_name = "SubwordFeaturizer"
else:
    log_message = "Use characters ..."
    featurizer_name = "CharFeaturizer"

# Log first, then resolve the class and build it from the decoder config.
logger.info(log_message)
text_featurizer = getattr(text_featurizers, featurizer_name)(
    config.decoder_config)

if args.tfrecords:
    train_dataset = asr_dataset.ASRTFRecordDataset(
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        **vars(config.learning_config.train_dataset_config),
        indefinite=True)
    eval_dataset = asr_dataset.ASRTFRecordDataset(
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        **vars(config.learning_config.eval_dataset_config),
        indefinite=True)
else:
    train_dataset = asr_dataset.ASRSliceDataset(
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        **vars(config.learning_config.train_dataset_config),
        indefinite=True)
    eval_dataset = asr_dataset.ASRSliceDataset(
        speech_featurizer=speech_featurizer,