Example #1
parser.add_argument("transcripts",
                    nargs="+",
                    type=str,
                    default=None,
                    help="Paths to transcript files")

args = parser.parse_args()

transcripts = preprocess_paths(args.transcripts)
tfrecords_dir = preprocess_paths(args.tfrecords_dir)

config = Config(args.config)

if args.sentence_piece:
    print("Loading SentencePiece model ...")
    text_featurizer = SentencePieceFeaturizer.load_from_file(
        config.decoder_config, args.subwords)
elif args.subwords and os.path.exists(args.subwords):
    print("Loading subwords ...")
    text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config,
                                                       args.subwords)
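# NOTE: if neither args.sentence_piece nor args.subwords is set, text_featurizer
# is never assigned here and the ASRTFRecordDataset call below raises a NameError.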

ASRTFRecordDataset(data_paths=transcripts,
                   tfrecords_dir=tfrecords_dir,
                   speech_featurizer=None,
                   text_featurizer=text_featurizer,
                   stage=args.mode,
                   shuffle=args.shuffle,
                   tfrecords_shards=args.tfrecords_shards).create_tfrecords()
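
The fragment above reads several command-line attributes (args.config, args.mode,
args.tfrecords_dir, args.shuffle, args.tfrecords_shards, args.sentence_piece,
args.subwords) that are defined earlier in the script. A minimal sketch of the
argparse setup it assumes, with flag names and defaults inferred from those
attributes rather than taken from the original source:

import argparse

parser = argparse.ArgumentParser(prog="create_tfrecords")
parser.add_argument("--config", type=str, default=None, help="Path to the YAML config file")
parser.add_argument("--mode", type=str, default="train", help="Dataset stage: train, eval or test")
parser.add_argument("--tfrecords_dir", type=str, default=None, help="Directory to store the tfrecords")
parser.add_argument("--tfrecords_shards", type=int, default=16, help="Number of tfrecord shards")
parser.add_argument("--shuffle", default=False, action="store_true", help="Shuffle data before writing")
parser.add_argument("--sentence_piece", default=False, action="store_true", help="Use a SentencePiece model")
parser.add_argument("--subwords", type=str, default=None, help="Path to a saved subwords file")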
Example #2
if args.subwords and os.path.exists(args.subwords):
    print("Loading subwords ...")
    text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config,
                                                       args.subwords)
else:
    raise ValueError("subwords must be set")

tf.random.set_seed(0)
assert args.saved

if args.tfrecords:
    test_dataset = ASRTFRecordDataset(
        data_paths=config.learning_config.dataset_config.test_paths,
        tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir,
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        stage="test",
        shuffle=False)
else:
    test_dataset = ASRSliceDataset(
        data_paths=config.learning_config.dataset_config.test_paths,
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        stage="test",
        shuffle=False)

# build model
conformer = Conformer(**config.model_config,
                      vocabulary_size=text_featurizer.num_classes)
conformer._build(speech_featurizer.shape)
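
This fragment relies on a config object and a speech_featurizer that are created
earlier in the script. A minimal sketch of that setup, assuming the same Config and
TFSpeechFeaturizer construction used in the other examples on this page:

from tensorflow_asr.configs.config import Config
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer

config = Config(args.config)
speech_featurizer = TFSpeechFeaturizer(config.speech_config)

After _build, the script presumably restores the saved checkpoint, e.g.
conformer.load_weights(args.saved), as the Jasper example below does.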
Example #3
if args.sentence_piece:
    print("Use SentencePiece ...")
    text_featurizer = SentencePieceFeaturizer(config.decoder_config)
elif args.subwords:
    print("Use subwords ...")
    text_featurizer = SubwordFeaturizer(config.decoder_config)
else:
    print("Use characters ...")
    text_featurizer = CharFeaturizer(config.decoder_config)

tf.random.set_seed(0)

if args.tfrecords:
    test_dataset = ASRTFRecordDataset(
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        **vars(config.learning_config.test_dataset_config))
else:
    test_dataset = ASRSliceDataset(
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        **vars(config.learning_config.test_dataset_config))

# build model
jasper = Jasper(**config.model_config,
                vocabulary_size=text_featurizer.num_classes)
jasper.make(speech_featurizer.shape)
jasper.load_weights(args.saved)
jasper.summary(line_length=100)
jasper.add_featurizers(speech_featurizer, text_featurizer)
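
Here the dataset arguments are no longer passed one by one; they are expanded from
config.learning_config.test_dataset_config. A rough sketch of what that expansion
corresponds to, with field names assumed to match the explicit keyword arguments
used in Examples #2 and #5 and with hypothetical paths:

test_dataset = ASRTFRecordDataset(
    speech_featurizer=speech_featurizer,
    text_featurizer=text_featurizer,
    data_paths=["/data/test.tsv"],    # hypothetical transcript path
    tfrecords_dir="/data/tfrecords",  # hypothetical; only the tfrecords variant takes this
    stage="test",
    shuffle=False)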
Example #4
strategy = setup_strategy(args.devices)

from tensorflow_asr.configs.config import Config
from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
from tensorflow_asr.runners.ctc_runners import CTCTrainer
from tensorflow_asr.models.deepspeech2 import DeepSpeech2

config = Config(args.config)
speech_featurizer = TFSpeechFeaturizer(config.speech_config)
text_featurizer = CharFeaturizer(config.decoder_config)

if args.tfrecords:
    train_dataset = ASRTFRecordDataset(
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        **vars(config.learning_config.train_dataset_config))
    eval_dataset = ASRTFRecordDataset(
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        **vars(config.learning_config.eval_dataset_config))
else:
    train_dataset = ASRSliceDataset(
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        **vars(config.learning_config.train_dataset_config))
    eval_dataset = ASRSliceDataset(
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        **vars(config.learning_config.eval_dataset_config))
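
The imports above bring in DeepSpeech2 and CTCTrainer, but the fragment stops after
the dataset setup. The model itself would be constructed next; a minimal sketch,
mirroring the constructor arguments shown in Example #5 below (the CTCTrainer wiring
that follows is not reproduced here, and the newer API may differ):

# in the full script this step would typically run inside strategy.scope()
ds2_model = DeepSpeech2(input_shape=speech_featurizer.shape,
                        arch_config=config.model_config,
                        num_classes=text_featurizer.num_classes,
                        name="deepspeech2")
ds2_model._build(speech_featurizer.shape)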
Example #5
text_featurizer = CharFeaturizer(config["decoder_config"])
# Build DS2 model
ds2_model = DeepSpeech2(input_shape=speech_featurizer.shape,
                        arch_config=config["model_config"],
                        num_classes=text_featurizer.num_classes,
                        name="deepspeech2")
ds2_model._build(speech_featurizer.shape)
ds2_model.load_weights(args.saved, by_name=True)
ds2_model.summary(line_length=150)
ds2_model.add_featurizers(speech_featurizer, text_featurizer)

if args.tfrecords:
    test_dataset = ASRTFRecordDataset(
        data_paths=config["learning_config"]["dataset_config"]["test_paths"],
        tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"],
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        stage="test", shuffle=False
    )
else:
    test_dataset = ASRSliceDataset(
        data_paths=config["learning_config"]["dataset_config"]["test_paths"],
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        stage="test", shuffle=False
    )

ctc_tester = BaseTester(
    config=config["learning_config"]["running_config"],
    output_name=args.output_name
)
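
This fragment indexes config as a dictionary, so it comes from the older
UserConfig-style setup rather than the Config object used above. A minimal sketch of
the missing preamble, following the pattern of Example #7 below (the UserConfig and
DEFAULT_YAML definitions are not shown there either, and learning=False is an
assumption for a test script):

from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer

config = UserConfig(DEFAULT_YAML, args.config, learning=False)   # learning=False assumed
speech_featurizer = TFSpeechFeaturizer(config["speech_config"])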
Example #6
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
from tensorflow_asr.runners.transducer_runners import TransducerTrainerGA
from tensorflow_asr.models.conformer import Conformer
from tensorflow_asr.optimizers.schedules import TransformerSchedule

config = Config(args.config, learning=True)
speech_featurizer = TFSpeechFeaturizer(config.speech_config)
text_featurizer = CharFeaturizer(config.decoder_config)

if args.tfrecords:
    train_dataset = ASRTFRecordDataset(
        data_paths=config.learning_config.dataset_config.train_paths,
        tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir,
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        augmentations=config.learning_config.augmentations,
        stage="train",
        cache=args.cache,
        shuffle=True)
    eval_dataset = ASRTFRecordDataset(
        data_paths=config.learning_config.dataset_config.eval_paths,
        tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir,
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        stage="eval",
        cache=args.cache,
        shuffle=True)
else:
    train_dataset = ASRSliceDataset(
        data_paths=config.learning_config.dataset_config.train_paths,
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        augmentations=config.learning_config.augmentations,
        stage="train",
        cache=args.cache,
        shuffle=True)
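
The imports at the top of this example anticipate the model and optimizer setup that
follows the dataset definitions. A minimal sketch of the model-building step,
mirroring the Conformer construction in Example #2 (the TransducerTrainerGA and
TransformerSchedule wiring is not reproduced here):

conformer = Conformer(**config.model_config,
                      vocabulary_size=text_featurizer.num_classes)
conformer._build(speech_featurizer.shape)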
Example #7
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
from tensorflow_asr.runners.transducer_runners import TransducerTrainer
from tensorflow_asr.models.conformer import Conformer
from tensorflow_asr.optimizers.schedules import TransformerSchedule

config = UserConfig(DEFAULT_YAML, args.config, learning=True)
speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
text_featurizer = CharFeaturizer(config["decoder_config"])

if args.tfrecords:
    train_dataset = ASRTFRecordDataset(
        data_paths=config["learning_config"]["dataset_config"]["train_paths"],
        tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"],
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        augmentations=config["learning_config"]["augmentations"],
        stage="train",
        cache=args.cache,
        shuffle=True)
    eval_dataset = ASRTFRecordDataset(
        data_paths=config["learning_config"]["dataset_config"]["eval_paths"],
        tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"],
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        stage="eval",
        cache=args.cache,
        shuffle=True)
else:
    train_dataset = ASRSliceDataset(
        data_paths=config["learning_config"]["dataset_config"]["train_paths"],
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        augmentations=config["learning_config"]["augmentations"],
        stage="train",
        cache=args.cache,
        shuffle=True)
Example #8
parser.add_argument("--tfrecords_dir",
                    default=None,
                    help="Directory to tfrecords")

parser.add_argument("transcripts",
                    nargs="+",
                    type=str,
                    default=None,
                    help="Paths to transcript files")

args = parser.parse_args()

assert args.mode in modes, f"Mode must be in {modes}"

transcripts = preprocess_paths(args.transcripts)
tfrecords_dir = preprocess_paths(args.tfrecords_dir)

if args.mode == "train":
    ASRTFRecordDataset(transcripts,
                       tfrecords_dir,
                       None,  # speech_featurizer (cf. the keyword form in Example #1)
                       None,  # text_featurizer
                       args.mode,
                       shuffle=True).create_tfrecords()
else:
    ASRTFRecordDataset(transcripts,
                       tfrecords_dir,
                       None,
                       None,
                       args.mode,
                       shuffle=False).create_tfrecords()
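
The assert near the top of this fragment checks args.mode against a modes collection
defined earlier in the script. A plausible definition, assuming the three dataset
stages used throughout these examples:

modes = ["train", "eval", "test"]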