# Entry script: train a SentencePiece subword vocabulary from the corpora named in the config.
# NOTE(review): relies on `os`, `tf`, `argparse`, `setup_strategy`, and `logger` being imported
# earlier in this file (not visible in this chunk) — confirm before moving this block.

# Default config path: config.yml sitting next to this script.
DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")

# Start from a clean TF graph/session state.
tf.keras.backend.clear_session()

parser = argparse.ArgumentParser(prog="Vocab Training with SentencePiece")
parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file")
parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training")
args = parser.parse_args()

# Distribution strategy is set up before the heavy tensorflow_asr imports below
# (presumably so device placement is fixed first — TODO confirm that ordering matters).
strategy = setup_strategy(args.devices)

# Deferred imports: only pulled in after argument parsing / strategy setup.
from tensorflow_asr.configs.config import Config
from tensorflow_asr.featurizers.text_featurizers import SentencePieceFeaturizer

config = Config(args.config)
logger.info("Generating subwords ...")
# Trains the SentencePiece model from the corpus files referenced by decoder_config.
text_featurizer = SentencePieceFeaturizer.build_from_corpus(config.decoder_config)
# Fragment of a model-training entry script: remaining CLI flags, mixed-precision setup,
# distribution strategy, and featurizer selection.
# NOTE(review): this chunk is truncated — the `elif args.subwords:` branch has no body here,
# and `parser`, `tf`, `env_util`, `logger`, and earlier flags (--config, --devices,
# --sentence_piece, --subwords) are defined outside this view.

parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision")
parser.add_argument("--pretrained", type=str, default=None, help="Path to pretrained model")
args = parser.parse_args()

# Toggle TF's automatic mixed-precision graph rewrite from the --mxp flag.
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp})
strategy = env_util.setup_strategy(args.devices)

# Deferred imports: only pulled in after argument parsing / strategy setup.
from tensorflow_asr.configs.config import Config
from tensorflow_asr.datasets import asr_dataset
from tensorflow_asr.featurizers import speech_featurizers, text_featurizers
from tensorflow_asr.models.transducer.conformer import Conformer
from tensorflow_asr.optimizers.schedules import TransformerSchedule

config = Config(args.config)
speech_featurizer = speech_featurizers.TFSpeechFeaturizer(config.speech_config)

# Choose the text featurizer backend from the CLI flags.
if args.sentence_piece:
    logger.info("Loading SentencePiece model ...")
    text_featurizer = text_featurizers.SentencePieceFeaturizer(config.decoder_config)
elif args.subwords:
def main(
    config: str = DEFAULT_YAML,
    tfrecords: bool = False,
    sentence_piece: bool = False,
    subwords: bool = False,
    bs: int = None,
    spx: int = 1,
    metadata: str = None,
    static_length: bool = False,
    devices: list = None,
    mxp: bool = False,
    pretrained: str = None,
):
    """Train a DeepSpeech2 ASR model end to end.

    Args:
        config: Path to the YAML model/learning configuration file.
        tfrecords: Whether the training data is stored as TFRecords.
        sentence_piece: Use a SentencePiece text featurizer.
        subwords: Use a subword text featurizer.
        bs: Per-replica batch size override; ``None`` defers to the config.
        spx: ``experimental_steps_per_execution`` passed to ``compile``.
        metadata: Optional dataset metadata path forwarded to dataset helpers.
        static_length: Keep precomputed featurizer lengths (static input shapes);
            when False, lengths are reset so shapes stay dynamic.
        devices: Device ids for distributed training; defaults to ``[0]``.
        mxp: Enable TF auto mixed precision.
        pretrained: Optional path to pretrained weights, loaded by name with
            mismatched layers skipped.
    """
    # Avoid the shared-mutable-default pitfall: build the default device list per call.
    devices = [0] if devices is None else devices

    tf.keras.backend.clear_session()
    tf.config.optimizer.set_experimental_options({"auto_mixed_precision": mxp})
    strategy = env_util.setup_strategy(devices)

    config = Config(config)

    speech_featurizer, text_featurizer = featurizer_helpers.prepare_featurizers(
        config=config,
        subwords=subwords,
        sentence_piece=sentence_piece,
    )

    train_dataset, eval_dataset = dataset_helpers.prepare_training_datasets(
        config=config,
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        tfrecords=tfrecords,
        metadata=metadata,
    )

    if not static_length:
        # Drop any precomputed max lengths so input shapes remain dynamic
        # (presumably required when metadata-derived lengths should not be baked in — confirm).
        speech_featurizer.reset_length()
        text_featurizer.reset_length()

    train_data_loader, eval_data_loader, global_batch_size = dataset_helpers.prepare_training_data_loaders(
        config=config,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        strategy=strategy,
        batch_size=bs,
    )

    # Model creation, weight loading, and compilation must all happen under the
    # distribution strategy scope so variables are placed/mirrored correctly.
    with strategy.scope():
        deepspeech2 = DeepSpeech2(**config.model_config, vocabulary_size=text_featurizer.num_classes)
        deepspeech2.make(speech_featurizer.shape, batch_size=global_batch_size)
        if pretrained:
            # by_name + skip_mismatch tolerates architecture differences in the checkpoint.
            deepspeech2.load_weights(pretrained, by_name=True, skip_mismatch=True)
        deepspeech2.summary(line_length=100)
        deepspeech2.compile(
            optimizer=config.learning_config.optimizer_config,
            experimental_steps_per_execution=spx,
            global_batch_size=global_batch_size,
            blank=text_featurizer.blank,
        )

    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint),
        # BackupAndRestore enables resuming training after interruption from states_dir.
        tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir),
        tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard),
    ]

    deepspeech2.fit(
        train_data_loader,
        epochs=config.learning_config.running_config.num_epochs,
        validation_data=eval_data_loader,
        callbacks=callbacks,
        steps_per_epoch=train_dataset.total_steps,
        # Only meaningful when an eval loader exists; Keras ignores it otherwise.
        validation_steps=eval_dataset.total_steps if eval_data_loader else None,
    )