nargs="*",
                    type=str,
                    default=[],
                    help="Transcript files for generating subwords")

# --saved: optional string path (see help text); None when the flag is omitted.
parser.add_argument(
    "--saved",
    default=None,
    type=str,
    help="Path to saved model",
)

# Parse the command-line flags registered on `parser`.
args = parser.parse_args()

# Enable/disable TensorFlow's experimental "auto_mixed_precision" graph
# optimizer according to args.mxp.
tf.config.optimizer.set_experimental_options(
    {"auto_mixed_precision": args.mxp})

# Build the distribution strategy through the project helper `setup_tpu`.
# NOTE(review): presumably connects to the TPU at args.tpu_address and returns
# a tf.distribute strategy -- confirm against setup_tpu's definition.
strategy = setup_tpu(args.tpu_address)

# NOTE(review): these project imports come after the device/TPU setup call
# instead of at the top of the file. The placement looks deliberate, so keep
# this ordering unless it is confirmed safe to hoist them.
from tensorflow_asr.configs.config import Config
from tensorflow_asr.datasets.keras import ASRTFRecordDatasetKeras
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer, SentencePieceFeaturizer
from tensorflow_asr.models.keras.contextnet import ContextNet
from tensorflow_asr.optimizers.schedules import TransformerSchedule

# Load the experiment configuration from the path given in args.config.
config = Config(args.config)
# Speech feature extractor built from the `speech_config` section.
speech_featurizer = TFSpeechFeaturizer(config.speech_config)

# When args.sentence_piece is truthy, build the text featurizer by loading a
# SentencePiece model from args.subwords, using the `decoder_config` section.
# NOTE(review): only this branch is visible in this excerpt. If no other
# branch follows, `text_featurizer` is unbound when args.sentence_piece is falsy.
if args.sentence_piece:
    print("Loading SentencePiece model ...")
    text_featurizer = SentencePieceFeaturizer.load_from_file(
        config.decoder_config, args.subwords)
예제 #2
0
                    type=int,
                    nargs="*",
                    default=[0],
                    help="Devices' ids to apply distributed training")

# --mxp: boolean switch -- False when absent, True when passed on the CLI.
parser.add_argument(
    "--mxp",
    action="store_true",
    default=False,
    help="Enable mixed precision",
)

# Parse the command-line flags registered on `parser`.
args = parser.parse_args()

# Enable/disable TensorFlow's experimental "auto_mixed_precision" graph
# optimizer according to the --mxp flag (store_true, default False).
tf.config.optimizer.set_experimental_options(
    {"auto_mixed_precision": args.mxp})

# Build the distribution strategy through the project helper `setup_tpu`,
# passing no explicit TPU address.
# NOTE(review): presumably None lets setup_tpu resolve the TPU on its own --
# confirm. The device-ids argument declared earlier is not consumed in the
# lines shown here.
strategy = setup_tpu(None)

# NOTE(review): these project imports come after the device/TPU setup call
# instead of at the top of the file. The placement looks deliberate, so keep
# this ordering unless it is confirmed safe to hoist them.
from tensorflow_asr.configs.config import Config
from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
from tensorflow_asr.runners.transducer_runners import TransducerTrainer
from tensorflow_asr.models.conformer import Conformer
from tensorflow_asr.optimizers.schedules import TransformerSchedule

# Load the experiment configuration from the path given in args.config.
config = Config(args.config)
# Speech feature extractor built from the `speech_config` section.
speech_featurizer = TFSpeechFeaturizer(config.speech_config)
# Text featurizer built from the `decoder_config` section
# (presumably character-level, going by the class name -- confirm).
text_featurizer = CharFeaturizer(config.decoder_config)

if args.tfrecords:
    train_dataset = ASRTFRecordDataset(