Пример #1
0
def test_conformer():
    """Build a Conformer from the default config and convert it to TFLite.

    The TFLite function is exported and converted twice, first with
    ``timestamp=False`` and then with ``timestamp=True``; a message is
    printed after each successful conversion.
    """
    cfg = Config(DEFAULT_YAML, learning=False)

    text_feat = CharFeaturizer(cfg.decoder_config)
    speech_feat = TFSpeechFeaturizer(cfg.speech_config)

    model = Conformer(vocabulary_size=text_feat.num_classes, **cfg.model_config)
    model._build(speech_feat.shape)
    model.summary(line_length=150)
    model.add_featurizers(speech_featurizer=speech_feat, text_featurizer=text_feat)

    # Same converter settings for both runs; only the timestamp flag differs.
    runs = (
        (False, "Converted successfully with no timestamp"),
        (True, "Converted successfully with timestamp"),
    )
    for with_timestamp, message in runs:
        func = model.make_tflite_function(timestamp=with_timestamp).get_concrete_function()
        converter = tf.lite.TFLiteConverter.from_concrete_functions([func])
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.experimental_new_converter = True
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
        converter.convert()

        print(message)
Пример #2
0
    def __init__(self, path='ConformerS.h5'):
        """Create the Conformer model and load its weights from ``path``.

        If ``path`` does not exist, the weights are first downloaded from
        Google Drive using the ``file_id`` stored in the config.
        """
        # model configuration
        cfg = Config('tamil_tech/configs/conformer_new_config.yml', learning=True)

        # speech and text featurizers built from that configuration
        speech_feat = TFSpeechFeaturizer(cfg.speech_config)
        text_feat = CharFeaturizer(cfg.decoder_config)

        # fetch the weights only when they are not already on disk
        if not os.path.exists(path):
            print("Downloading Model...")
            download_file_from_google_drive(cfg.file_id, path)
            print("Downloaded Model Successfully...")

        # instantiate, build, load weights, summarise, attach featurizers
        self.model = net = Conformer(**cfg.model_config, vocabulary_size=text_feat.num_classes)
        net._build(speech_feat.shape)
        net.load_weights(path, by_name=True)
        net.summary(line_length=120)
        net.add_featurizers(speech_feat, text_feat)

        print("Loaded Model...!")
Пример #3
0
 def build_am(self, config_path, model_path):
     """Build a Conformer from ``config_path`` and load weights from ``model_path``.

     The model is built with ``self.speech_featurizer.shape`` and a fixed
     vocabulary size of 1031, then returned.
     """
     model_kwargs = Config(config_path, learning=False).model_config
     acoustic_model = Conformer(**model_kwargs, vocabulary_size=1031)
     acoustic_model._build(self.speech_featurizer.shape)
     print('loading am...')
     acoustic_model.load_weights(model_path, by_name=True)
     return acoustic_model
def main():
    """Command-line entry point for Conformer training with regularising data.

    Steps, in order:
      1. Parse the command-line options.
      2. Configure mixed precision and the distribution strategy.
      3. Load the configuration and build the speech/subword featurizers.
      4. Create the train/eval datasets and their regularising counterparts.
      5. Build the Conformer and optimizer inside the strategy scope.
      6. Compile and fit with ``MultiReaderTransducerTrainer``.
    """
    parser = argparse.ArgumentParser(prog="Conformer Training")

    parser.add_argument("--config",
                        type=str,
                        default=DEFAULT_YAML,
                        help="The file path of model configuration file")

    parser.add_argument("--max_ckpts",
                        type=int,
                        default=10,
                        help="Max number of checkpoints to keep")

    parser.add_argument("--tbs",
                        type=int,
                        default=None,
                        help="Train batch size per replica")

    parser.add_argument("--ebs",
                        type=int,
                        default=None,
                        help="Evaluation batch size per replica")

    parser.add_argument("--acs",
                        type=int,
                        default=None,
                        help="Train accumulation steps")

    parser.add_argument("--devices",
                        type=int,
                        nargs="*",
                        default=[0],
                        help="Devices' ids to apply distributed training")

    parser.add_argument("--mxp",
                        default=False,
                        action="store_true",
                        help="Enable mixed precision")

    parser.add_argument("--subwords",
                        type=str,
                        default=None,
                        help="Path to file that stores generated subwords")

    parser.add_argument("--subwords_corpus",
                        nargs="*",
                        type=str,
                        default=[],
                        help="Transcript files for generating subwords")

    parser.add_argument(
        "--train-dir",
        '-td',
        nargs='*',
        default=["en_ng_male_train.tsv", "en_ng_female_train.tsv"])
    parser.add_argument("--train-reg-dir",
                        '-trd',
                        nargs='*',
                        default=[
                            "libritts_train-clean-100.tsv",
                            "libritts_train-clean-360.tsv",
                            "libritts_train-other-500.tsv"
                        ])
    parser.add_argument(
        "--dev-dir",
        '-dd',
        nargs='*',
        default=["en_ng_male_eval.tsv", "en_ng_female_eval.tsv"])
    parser.add_argument("--dev-reg-dir",
                        '-drd',
                        nargs='*',
                        default=["libritts_test-other.tsv"])

    args = parser.parse_args()

    tf.config.optimizer.set_experimental_options(
        {"auto_mixed_precision": args.mxp})

    strategy = setup_strategy(args.devices)

    # Command-line dataset lists are stored on the config object.
    config = Config(args.config, learning=True)
    config.train_dir = args.train_dir
    config.dev_dir = args.dev_dir
    config.train_reg_dir = args.train_reg_dir
    config.dev_reg_dir = args.dev_reg_dir

    # Here config.speech_config is opened as a path to a YAML file.
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
    # point this at trusted configuration files (or switch to a safe loader).
    with open(config.speech_config) as f:
        speech_config = yaml.load(f, Loader=yaml.Loader)
    speech_featurizer = TFSpeechFeaturizer(speech_config)

    if args.subwords and os.path.exists(args.subwords):
        print("Loading subwords ...")
        text_featurizer = SubwordFeaturizer.load_from_file(
            config.decoder_config, args.subwords)
    else:
        print("Generating subwords ...")
        text_featurizer = SubwordFeaturizer.build_from_corpus(
            config.decoder_config, corpus_files=args.subwords_corpus)
        # --subwords defaults to None; only save when a target path was given
        # instead of passing None to save_to_file after the generation work.
        if args.subwords:
            text_featurizer.save_to_file(args.subwords)
        else:
            print("No --subwords path given, generated subwords are not saved")

    train_dataset = Dataset(data_paths=config.train_dir,
                            speech_featurizer=speech_featurizer,
                            text_featurizer=text_featurizer,
                            augmentations=config.learning_config.augmentations,
                            stage="train",
                            cache=False,
                            shuffle=False)
    train_reg_dataset = DatasetInf(
        data_paths=config.train_reg_dir,
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        augmentations=config.learning_config.augmentations,
        stage="train",
        cache=False,
        shuffle=False)
    eval_dataset = Dataset(data_paths=config.dev_dir,
                           speech_featurizer=speech_featurizer,
                           text_featurizer=text_featurizer,
                           stage="eval",
                           cache=False,
                           shuffle=False)
    eval_reg_dataset = DatasetInf(
        data_paths=config.dev_reg_dir,
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        augmentations=config.learning_config.augmentations,
        stage="eval",
        cache=False,
        shuffle=False)

    conformer_trainer = MultiReaderTransducerTrainer(
        config=config.learning_config.running_config,
        text_featurizer=text_featurizer,
        strategy=strategy)

    # Model and optimizer are created inside the trainer's strategy scope.
    with conformer_trainer.strategy.scope():
        # build model
        conformer = Conformer(**config.model_config,
                              vocabulary_size=text_featurizer.num_classes)
        conformer._build(speech_featurizer.shape)
        conformer.summary(line_length=120)

        optimizer_config = config.learning_config.optimizer_config
        optimizer = tf.keras.optimizers.Adam(
            TransformerSchedule(d_model=conformer.dmodel,
                                warmup_steps=optimizer_config["warmup_steps"],
                                max_lr=(0.05 / math.sqrt(conformer.dmodel))),
            beta_1=optimizer_config["beta1"],
            beta_2=optimizer_config["beta2"],
            epsilon=optimizer_config["epsilon"])

    conformer_trainer.compile(model=conformer,
                              optimizer=optimizer,
                              max_to_keep=args.max_ckpts)
    conformer_trainer.fit(
        train_dataset,
        train_reg_dataset,
        # alpha for regularising dataset; alpha = 1 for training dataset
        1.,
        eval_dataset,
        eval_reg_dataset,
        train_bs=args.tbs,
        eval_bs=args.ebs,
        train_acs=args.acs)
Пример #5
0
    "feature_type": "log_mel_spectrogram",
    "preemphasis": 0.97,
    "normalize_signal": True,
    "normalize_feature": True,
    "normalize_per_feature": False
})

# i = tf.keras.Input(shape=[None, 80, 1])
# o = Conv2dSubsampling(144)(i)

# encoder = tf.keras.Model(inputs=i, outputs=o)
# model = Transducer(encoder=encoder, vocabulary_size=text_featurizer.num_classes)

# Single-block Conformer with a conv2d subsampling configuration; the
# vocabulary size comes from the text featurizer.
model = Conformer(
    subsampling={"type": "conv2d", "filters": 144, "kernel_size": 3,
                 "strides": 2},
    num_blocks=1,
    vocabulary_size=text_featurizer.num_classes)

# Build with the speech featurizer's shape, then print the layer summary.
model._build(speech_featurizer.shape)
model.summary(line_length=150)

# Write the weights to /tmp/transducer.h5.
model.save_weights("/tmp/transducer.h5")

# Attach both featurizers to the model.
model.add_featurizers(
    speech_featurizer=speech_featurizer,
    text_featurizer=text_featurizer
)

# features = tf.zeros(shape=[5, 50, 80, 1], dtype=tf.float32)
# pred = model.recognize(features)
Пример #6
0
from tensorflow_asr.models.conformer import Conformer

# Load the configuration; here config.speech_config is opened as a path to a
# YAML file holding the speech featurizer settings.
config = Config(args.config, learning=False)
with open(config.speech_config) as f:
  # NOTE(review): yaml.Loader can construct arbitrary Python objects; only use
  # it with trusted configuration files.
  speech_config = yaml.load(f, Loader=yaml.Loader)

speech_featurizer = TFSpeechFeaturizer(speech_config)
# Use the subword featurizer when a subwords file exists, otherwise fall back
# to the character featurizer.
if args.subwords and os.path.exists(args.subwords):
  print("Loading subwords ...")
  text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords)
else:
  text_featurizer = CharFeaturizer(config.decoder_config)
# Override the configured beam width with the command-line value.
text_featurizer.decoder_config.beam_width = args.beam_width

# build model, load the saved weights and attach the featurizers
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
conformer._build(speech_featurizer.shape)
conformer.load_weights(args.saved, by_name=True)
conformer.summary(line_length=120)
conformer.add_featurizers(speech_featurizer, text_featurizer)

import numpy as np
# Fix the NumPy and TensorFlow random seeds.
np.random.seed(0)
tf.random.set_seed(0)
# A .wav input is read and run through the speech featurizer; any other file
# is loaded with np.load and reshaped to [-1, 80, 1].
if args.filename.endswith('.wav'):
  signal = read_raw_audio(args.filename)
  # features = speech_featurizer.tf_extract(signal)
  features = speech_featurizer.extract(signal)
  features = tf.constant(features)
else:
  features = np.load(args.filename).reshape([-1, 80, 1])
Пример #7
0
args = parser.parse_args()

from tensorflow_asr.configs.config import Config
from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer, SubwordFeaturizer
from tensorflow_asr.models.conformer import Conformer

# Load the configuration; here config.speech_config is opened as a path to a
# YAML file holding the speech featurizer settings.
config = Config(args.config, learning=False)
with open(config.speech_config) as f:
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
    # use it with trusted configuration files.
    speech_config = yaml.load(f, Loader=yaml.Loader)

# Featurizer and model are created under an explicit CPU device scope.
with tf.device('/cpu:0'):
    speech_featurizer = TFSpeechFeaturizer(speech_config)
    # build model with a hard-coded vocabulary size of 1031
    # NOTE(review): presumably this must match the saved checkpoint - confirm
    conformer = Conformer(**config.model_config, vocabulary_size=1031)
    conformer._build(speech_featurizer.shape)
    conformer.load_weights(args.saved, by_name=True)
    # Keep a handle on the encoder sub-model only.
    encoder = conformer.encoder
# encoder.summary(line_length=120)


@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32, name="signal")])
def extract_from_audio(signal):
    """Featurize ``signal`` on the CPU and hand the result to ``extract_from_mel``.

    The input signature fixes ``signal`` to a rank-1 float32 tensor.
    """
    # Feature extraction is pinned to the CPU device.
    with tf.device('/cpu:0'):
        mel = speech_featurizer.tf_extract(signal)
    return extract_from_mel(mel)

Пример #8
0
args = parser.parse_args()

# Both the saved-weights path and the output path must be provided.
assert args.saved and args.output

config = Config(args.config, learning=True)
speech_featurizer = TFSpeechFeaturizer(config.speech_config)

# This script only supports a subword featurizer loaded from an existing file.
if args.subwords and os.path.exists(args.subwords):
    print("Loading subwords ...")
    text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config,
                                                       args.subwords)
else:
    raise ValueError("subwords must be set")

# build model, load the saved weights and attach the featurizers
conformer = Conformer(**config.model_config,
                      vocabulary_size=text_featurizer.num_classes)
conformer._build(speech_featurizer.shape)
conformer.load_weights(args.saved)
conformer.summary(line_length=150)
conformer.add_featurizers(speech_featurizer, text_featurizer)

# Export the greedy TFLite function as a concrete function and convert it.
concrete_func = conformer.make_tflite_function(
    greedy=True).get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
#converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Allow both TFLite builtin ops and select TensorFlow ops in the result.
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
]
tflite_model = converter.convert()

if not os.path.exists(os.path.dirname(args.output)):
Пример #9
0
                    help="Whether to only use cpu")

args = parser.parse_args()

setup_devices([args.device], cpu=args.cpu)

from tensorflow_asr.configs.config import Config
from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
from tensorflow_asr.models.conformer import Conformer

# Load the configuration and build the speech and character featurizers.
config = Config(args.config, learning=False)
speech_featurizer = TFSpeechFeaturizer(config.speech_config)
text_featurizer = CharFeaturizer(config.decoder_config)

# build model, load the saved weights and attach the featurizers
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
conformer._build(speech_featurizer.shape)
conformer.load_weights(args.saved, by_name=True)
conformer.summary(line_length=120)
conformer.add_featurizers(speech_featurizer, text_featurizer)

# Inputs for recognize_tflite: the audio signal, the blank index as an int32
# scalar, and an all-zero float32 state tensor sized from the CLI arguments.
signal = read_raw_audio(args.filename)
predicted = tf.constant(args.blank, dtype=tf.int32)
states = tf.zeros([args.num_rnns, args.nstates, 1, args.statesize], dtype=tf.float32)

# Only the first of the three returned values is used.
hyp, _, _ = conformer.recognize_tflite(signal, predicted, states)

# Each element of hyp is mapped through chr() and the characters are joined
# (presumably hyp holds Unicode code points - confirm against the model).
print("".join([chr(u) for u in hyp]))
Пример #10
0
class ConformerTamilASR(object):
    """
    Conformer S based ASR model

    Wraps a ``Conformer`` model together with its speech and text featurizers
    and exposes ``infer`` for transcribing a single audio input.
    """
    def __init__(self, path='ConformerS.h5'):
        """Create the model and load its weights from ``path``.

        If ``path`` does not exist, the weights are first downloaded from
        Google Drive using the ``file_id`` stored in the config.
        """
        # fetch and load the config of the model
        config = Config('tamil_tech/configs/conformer_new_config.yml', learning=True)

        # load speech and text featurizers
        speech_featurizer = TFSpeechFeaturizer(config.speech_config)
        text_featurizer = CharFeaturizer(config.decoder_config)

        # download the weights only when they are not already at the given path
        if not os.path.exists(path):
            print("Downloading Model...")
            file_id = config.file_id
            download_file_from_google_drive(file_id, path)
            print("Downloaded Model Successfully...")

        # load model using config
        self.model = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
        # set shape of the featurizer and build the model
        self.model._build(speech_featurizer.shape)
        # load weights of the model
        self.model.load_weights(path, by_name=True)
        # display model summary
        self.model.summary(line_length=120)
        # set featurizers for the model
        self.model.add_featurizers(speech_featurizer, text_featurizer)

        print("Loaded Model...!")

    def read_raw_audio(self, audio, sample_rate=16000):
        """Return the waveform for ``audio``.

        Args:
            audio: a file path (``str``), encoded audio (``bytes``) or an
                already decoded ``np.ndarray`` (returned unchanged).
            sample_rate: sample rate requested from librosa for path input and
                used as the resampling target for bytes input.

        Returns:
            The waveform as a numpy array.

        Raises:
            ValueError: if ``audio`` is none of the supported types.
        """
        # if audio path is given, load audio using librosa
        if isinstance(audio, str):
            wave, _ = librosa.load(os.path.expanduser(audio), sr=sample_rate)
            return wave

        # if audio file is in bytes, use soundfile to read audio
        if isinstance(audio, bytes):
            wave, sr = sf.read(io.BytesIO(audio))

            # multi-channel audio: keep only the first channel; an explicit
            # rank check replaces the former bare ``except`` around indexing
            if wave.ndim > 1:
                wave = wave[:, 0]

            # get loaded audio as a contiguous numpy array
            wave = np.asfortranarray(wave)

            # resample to the requested sample rate; keyword arguments work
            # with both old and new librosa signatures
            if sr != sample_rate:
                wave = librosa.resample(wave, orig_sr=sr, target_sr=sample_rate)
            return wave

        # if numpy array, return audio as is
        if isinstance(audio, np.ndarray):
            return audio

        raise ValueError("input audio must be either a path, bytes or a numpy array")

    def bytes_to_string(self, array: np.ndarray, encoding: str = "utf-8"):
        """Decode every bytes item of ``array`` with ``encoding`` and return a list of str."""
        return [transcript.decode(encoding) for transcript in array]

    def infer(self, path, greedy=True, return_text=False):
        """Transcribe one audio input.

        Args:
            path: any input accepted by ``read_raw_audio``.
            greedy: use ``recognize`` when true, otherwise ``recognize_beam``
                with ``lm=True``.
            return_text: when true the transcript is returned; otherwise it is
                printed (followed by a space) and ``None`` is returned.
        """
        # read the audio
        signal = self.read_raw_audio(path)
        # extract features and add a leading axis for a single prediction
        signal = tf.expand_dims(self.model.speech_featurizer.tf_extract(signal), axis=0)

        if greedy:
            # predict greedy
            pred = self.model.recognize(features=signal)
        else:
            # predict using beam search and language model
            pred = self.model.recognize_beam(features=signal, lm=True)

        transcript = self.bytes_to_string(pred.numpy())[0]
        if return_text:
            # return predicted transcription
            return transcript

        # print predicted transcription
        print(transcript, end=' ')
                    help="Output to save whole model")

args = parser.parse_args()

tf.config.optimizer.set_experimental_options(
    {"auto_mixed_precision": args.mxp})

setup_devices([args.device], cpu=args.cpu)

from tensorflow_asr.configs.user_config import UserConfig
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
from tensorflow_asr.models.conformer import Conformer

# Build the configuration from the default YAML and the user-supplied one,
# then create the speech and character featurizers from it.
config = UserConfig(DEFAULT_YAML, args.config, learning=True)
speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
text_featurizer = CharFeaturizer(config["decoder_config"])

tf.random.set_seed(0)
# A path to saved weights is required.
assert args.saved

# build model, load the saved weights, then save the whole model
conformer = Conformer(vocabulary_size=text_featurizer.num_classes,
                      **config["model_config"])
conformer._build(speech_featurizer.shape)
conformer.load_weights(args.saved, by_name=True)
conformer.summary(line_length=150)
conformer.save(args.output)

print(f"Saved whole model to {args.output}")
        **vars(config.learning_config.train_dataset_config))
    eval_dataset = ASRSliceDataset(
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        **vars(config.learning_config.eval_dataset_config))

conformer_trainer = TransducerTrainer(
    config=config.learning_config.running_config,
    text_featurizer=text_featurizer,
    strategy=strategy)

with conformer_trainer.strategy.scope():
    # build model
    if args.pretrained_model is None:
        print("Training from scratch...")
        conformer = Conformer(**config.model_config,
                              vocabulary_size=text_featurizer.num_classes)
        conformer._build(speech_featurizer.shape)
        conformer.summary(line_length=120)
    else:
        print("Training from provided checkpoint...")
        conformer = Conformer(**config.model_config,
                              vocabulary_size=text_featurizer.num_classes)
        conformer._build(speech_featurizer.shape)
        conformer.load_weights(args.pretrained_model)
        conformer.summary(line_length=120)
        conformer.add_featurizers(speech_featurizer,
                                  text_featurizer)  # TODO: Do we need this?

    optimizer = tf.keras.optimizers.Adam(
        TransformerSchedule(d_model=conformer.dmodel,
                            warmup_steps=config.learning_config.
Пример #13
0
def test_conformer():
    """Convert a Conformer to TFLite and run the converted model once.

    The model is converted with ``timestamp=False`` and ``timestamp=True``.
    The first converted model is then loaded into a TFLite interpreter, fed a
    random 4000-sample signal, the blank index and zero states, and the first
    output tensor is printed.
    """
    cfg = Config(DEFAULT_YAML)

    text_feat = CharFeaturizer(cfg.decoder_config)
    speech_feat = TFSpeechFeaturizer(cfg.speech_config)

    model = Conformer(vocabulary_size=text_feat.num_classes, **cfg.model_config)
    model._build(speech_feat.shape)
    model.summary(line_length=150)
    model.add_featurizers(speech_featurizer=speech_feat, text_featurizer=text_feat)

    def convert(with_timestamp):
        # Export the TFLite function and convert it with the shared settings.
        func = model.make_tflite_function(timestamp=with_timestamp).get_concrete_function()
        converter = tf.lite.TFLiteConverter.from_concrete_functions([func])
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.experimental_new_converter = True
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
        ]
        return converter.convert()

    tflite = convert(False)
    print("Converted successfully with no timestamp")

    # The timestamp variant is only checked for successful conversion.
    convert(True)
    print("Converted successfully with timestamp")

    interpreter = tf.lite.Interpreter(model_content=tflite)
    signal = tf.random.normal([4000])

    inputs = interpreter.get_input_details()
    outputs = interpreter.get_output_details()
    interpreter.resize_tensor_input(inputs[0]["index"], [4000])
    interpreter.allocate_tensors()

    # Inputs in order: signal, blank index, zero-initialised states.
    interpreter.set_tensor(inputs[0]["index"], signal)
    interpreter.set_tensor(inputs[1]["index"],
                           tf.constant(text_feat.blank, dtype=tf.int32))
    state_shape = [
        cfg.model_config["prediction_num_rnns"], 2, 1,
        cfg.model_config["prediction_rnn_units"]
    ]
    interpreter.set_tensor(inputs[2]["index"],
                           tf.zeros(state_shape, dtype=tf.float32))

    interpreter.invoke()
    hyp = interpreter.get_tensor(outputs[0]["index"])

    print(hyp)
Пример #14
0
from tensorflow_asr.configs.config import Config
from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer, SubwordFeaturizer
from tensorflow_asr.models.conformer import Conformer

# Load the configuration and build the speech featurizer.
config = Config(args.config, learning=False)
speech_featurizer = TFSpeechFeaturizer(config.speech_config)
# Use the subword featurizer when a subwords file exists, otherwise fall back
# to the character featurizer.
if args.subwords and os.path.exists(args.subwords):
    print("Loading subwords ...")
    text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords)
else:
    text_featurizer = CharFeaturizer(config.decoder_config)
# Override the configured beam width with the command-line value.
text_featurizer.decoder_config.beam_width = args.beam_width

# build model, load the saved weights and attach the featurizers
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
conformer._build(speech_featurizer.shape)
conformer.load_weights(args.saved, by_name=True)
conformer.summary(line_length=120)
conformer.add_featurizers(speech_featurizer, text_featurizer)

signal = read_raw_audio(args.filename)

# A truthy beam width selects beam search, otherwise plain recognize is used;
# signal[None, ...] adds a leading axis to the signal.
if (args.beam_width):
    transcript = conformer.recognize_beam(signal[None, ...])
else:
    transcript = conformer.recognize(signal[None, ...])

# Print the first entry of the result.
tf.print("Transcript:", transcript[0])
Пример #15
0
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        **vars(config.learning_config.train_dataset_config))
    eval_dataset = ASRSliceDataset(
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        **vars(config.learning_config.eval_dataset_config))

# Trainer configured from the running config, text featurizer and strategy.
conformer_trainer = TransducerTrainer(
    config=config.learning_config.running_config,
    text_featurizer=text_featurizer,
    strategy=strategy)

# Model and optimizer are created inside the trainer's strategy scope.
with conformer_trainer.strategy.scope():
    # build model
    conformer = Conformer(**config.model_config,
                          vocabulary_size=text_featurizer.num_classes)
    conformer._build(speech_featurizer.shape)
    conformer.summary(line_length=120)

    # Adam optimizer whose first argument is a TransformerSchedule; the
    # remaining hyper-parameters come from the optimizer config.
    optimizer_config = config.learning_config.optimizer_config
    optimizer = tf.keras.optimizers.Adam(TransformerSchedule(
        d_model=conformer.dmodel,
        warmup_steps=optimizer_config["warmup_steps"],
        max_lr=(0.05 / math.sqrt(conformer.dmodel))),
                                         beta_1=optimizer_config["beta1"],
                                         beta_2=optimizer_config["beta2"],
                                         epsilon=optimizer_config["epsilon"])

# Hand model and optimizer to the trainer; max_to_keep comes from --max_ckpts.
conformer_trainer.compile(model=conformer,
                          optimizer=optimizer,
                          max_to_keep=args.max_ckpts)
# A path to saved weights is required.
assert args.saved

# Build the test dataset from the configured test paths, either as a TFRecord
# dataset or as a slice dataset depending on args.tfrecords.
if args.tfrecords:
    test_dataset = ASRTFRecordDataset(
        data_paths=config.learning_config.dataset_config.test_paths,
        tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir,
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        stage="test",
        shuffle=False)
else:
    test_dataset = ASRSliceDataset(
        data_paths=config.learning_config.dataset_config.test_paths,
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        stage="test",
        shuffle=False)

# build model, load the saved weights and attach the featurizers
conformer = Conformer(**config.model_config,
                      vocabulary_size=text_featurizer.num_classes)
conformer._build(speech_featurizer.shape)
conformer.load_weights(args.saved, by_name=True)
conformer.summary(line_length=120)
conformer.add_featurizers(speech_featurizer, text_featurizer)

# Run the tester over the test dataset.
conformer_tester = BaseTester(config=config.learning_config.running_config,
                              output_name=args.output_name)
conformer_tester.compile(conformer)
conformer_tester.run(test_dataset)
Пример #17
0
                    help="Output to save whole model")

args = parser.parse_args()

tf.config.optimizer.set_experimental_options(
    {"auto_mixed_precision": args.mxp})

setup_devices([args.device], cpu=args.cpu)

from tensorflow_asr.configs.config import Config
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
from tensorflow_asr.models.conformer import Conformer

# Load the configuration and build the speech and character featurizers.
config = Config(args.config)
speech_featurizer = TFSpeechFeaturizer(config.speech_config)
text_featurizer = CharFeaturizer(config.decoder_config)

tf.random.set_seed(0)
# A path to saved weights is required.
assert args.saved

# build model, load the saved weights, then save the whole model
conformer = Conformer(**config.model_config,
                      vocabulary_size=text_featurizer.num_classes)
conformer._build(speech_featurizer.shape)
conformer.load_weights(args.saved)
conformer.summary(line_length=150)
conformer.save(args.output)

print(f"Saved whole model to {args.output}")