Exemplo n.º 1
0
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        augmentations=config.learning_config.augmentations,
        stage="train",
        cache=args.cache,
        shuffle=True)
    # Evaluation dataset: built from the configured eval paths with the same
    # speech/text featurizers; unlike the training call above, no
    # `augmentations` argument is passed.
    # NOTE(review): shuffle=True is kept as-is, but evaluation results should
    # not depend on ordering — consider shuffle=False if the shuffle is
    # unnecessary here.
    eval_dataset = ASRSliceDataset(
        data_paths=config.learning_config.dataset_config.eval_paths,
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        stage="eval",
        cache=args.cache,
        shuffle=True,
    )

# Trainer for the Conformer model: takes the running section of the learning
# config, the text featurizer, and the distribution strategy.
# (The "GA" suffix presumably means gradient accumulation — confirm.)
conformer_trainer = TransducerTrainerGA(
    config=config.learning_config.running_config,
    text_featurizer=text_featurizer,
    strategy=strategy,
)

with conformer_trainer.strategy.scope():
    # build model
    conformer = Conformer(**config.model_config,
                          vocabulary_size=text_featurizer.num_classes)
    conformer._build(speech_featurizer.shape)
    conformer.summary(line_length=120)

    optimizer_config = config.learning_config.optimizer_config
    optimizer = tf.keras.optimizers.Adam(TransformerSchedule(
        d_model=config.model_config["dmodel"],
        warmup_steps=optimizer_config["warmup_steps"],
        max_lr=(0.05 / math.sqrt(config.model_config["dmodel"]))),
                                         beta_1=optimizer_config["beta1"],
Exemplo n.º 2
0
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        augmentations=config["learning_config"]["augmentations"],
        stage="train",
        cache=args.cache,
        shuffle=True)
    # Evaluation dataset: built from the configured eval paths with the same
    # speech/text featurizers; unlike the training call above, no
    # `augmentations` argument is passed.
    # NOTE(review): shuffle=True is kept as-is, but evaluation results should
    # not depend on ordering — consider shuffle=False if the shuffle is
    # unnecessary here.
    eval_dataset = ASRSliceDataset(
        data_paths=config["learning_config"]["dataset_config"]["eval_paths"],
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        stage="eval",
        cache=args.cache,
        shuffle=True,
    )

# Trainer for the streaming transducer: takes the running section of the
# learning config, the text featurizer, and the distribution strategy.
# (The "GA" suffix presumably means gradient accumulation — confirm.)
streaming_transducer_trainer = TransducerTrainerGA(
    config=config["learning_config"]["running_config"],
    text_featurizer=text_featurizer,
    strategy=strategy,
)

# Model and optimizer are created inside the trainer's strategy scope
# (presumably so their variables are owned by that distribution strategy).
with streaming_transducer_trainer.strategy.scope():
    # Hyper-parameters come from config["model_config"]; the output
    # vocabulary size is taken from the text featurizer.
    streaming_transducer = StreamingTransducer(
        **config["model_config"],
        vocabulary_size=text_featurizer.num_classes,
    )
    # Build the model for the speech featurizer's shape, then print the
    # layer summary at a width of 150 characters.
    streaming_transducer._build(speech_featurizer.shape)
    streaming_transducer.summary(line_length=150)

    # NOTE(review): Adam is constructed with no arguments (library defaults);
    # nothing from `config` is passed to it — confirm this is intended.
    optimizer = tf.keras.optimizers.Adam()

# Attach the model and optimizer to the trainer; `max_to_keep` is taken from
# args.max_ckpts (presumably the number of checkpoints retained — confirm).
streaming_transducer_trainer.compile(
    model=streaming_transducer,
    optimizer=optimizer,
    max_to_keep=args.max_ckpts,
)