def _build_satt_ds2_model(speech_featurizer, text_featurizer, config):
    """Create a SelfAttentionDS2 model sized for the given featurizers.

    The model's input shape is ``[None, f, c]`` where ``(f, c)`` comes from
    ``speech_featurizer.compute_feature_dim()``. ``_build`` is then called with
    a ``[1, 50, f, c]`` shape (presumably to create the model's variables up
    front).

    Args:
        speech_featurizer: featurizer providing ``compute_feature_dim()``.
        text_featurizer: featurizer providing ``num_classes``.
        config: user config; ``config["model_config"]`` is the architecture.

    Returns:
        The built SelfAttentionDS2 model.
    """
    f, c = speech_featurizer.compute_feature_dim()
    model = SelfAttentionDS2(input_shape=[None, f, c],
                             arch_config=config["model_config"],
                             num_classes=text_featurizer.num_classes)
    model._build([1, 50, f, c])
    return model


def _create_satt_optimizer(config):
    """Create the optimizer described by ``learning_config.optimizer_config``.

    ``d_model`` is taken from ``model_config.att.head_size``; the remaining
    keyword arguments come from ``optimizer_config["config"]``.
    """
    optimizer_config = config["learning_config"]["optimizer_config"]
    return create_optimizer(
        name=optimizer_config["name"],
        d_model=config["model_config"]["att"]["head_size"],
        **optimizer_config["config"]
    )


def run(args):
    """Train, test, or export a SelfAttentionDS2 CTC model according to ``args.mode``.

    Args:
        args: parsed command-line namespace. Fields read: ``mode``, ``config``,
            ``mixed_precision``, ``tfrecords``, ``max_ckpts``,
            ``eval_train_ratio``, ``export``, ``from_weights``.

    Modes:
        * ``"train"``: fit on the train/eval datasets, then, if ``args.export``
          is set, save the whole model (or only its weights when
          ``args.from_weights``) to that path.
        * ``"test"``: run ``BaseTester`` on the test dataset, with
          ``saved_path=args.export`` (presumably the model/weights to load).
        * any other mode: call ``save_from_checkpoint`` on
          ``running_config["outdir"]`` and save the model (or its weights) to
          ``args.export``.

    Raises:
        AssertionError: if ``args.mode`` is not in ``modes``, or if
            ``args.export`` is empty in test/export mode.
    """
    assert args.mode in modes, f"Mode must be in {modes}"

    config = UserConfig(DEFAULT_YAML, args.config, learning=True)
    speech_featurizer = SpeechFeaturizer(config["speech_config"])
    text_featurizer = TextFeaturizer(config["decoder_config"])

    learning_config = config["learning_config"]
    running_config = learning_config["running_config"]

    if args.mode == "train":
        tf.random.set_seed(2020)

        if args.mixed_precision:
            policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
            tf.keras.mixed_precision.experimental.set_policy(policy)
            print("Enabled mixed precision training")

        ctc_trainer = CTCTrainer(speech_featurizer, text_featurizer,
                                 running_config, args.mixed_precision)

        dataset_config = learning_config["dataset_config"]
        # Only the training set is augmented and shuffled.
        if args.tfrecords:
            train_dataset = ASRTFRecordDataset(
                dataset_config["train_paths"],
                dataset_config["tfrecords_dir"],
                speech_featurizer, text_featurizer, "train",
                augmentations=learning_config["augmentations"], shuffle=True,
            )
            eval_dataset = ASRTFRecordDataset(
                dataset_config["eval_paths"],
                dataset_config["tfrecords_dir"],
                speech_featurizer, text_featurizer, "eval", shuffle=False
            )
        else:
            train_dataset = ASRSliceDataset(
                stage="train", speech_featurizer=speech_featurizer,
                text_featurizer=text_featurizer,
                data_paths=dataset_config["train_paths"],
                augmentations=learning_config["augmentations"], shuffle=True,
            )
            eval_dataset = ASRSliceDataset(
                stage="eval", speech_featurizer=speech_featurizer,
                text_featurizer=text_featurizer,
                data_paths=dataset_config["eval_paths"],
                shuffle=False
            )

        # Model and optimizer are created inside the trainer's strategy scope.
        with ctc_trainer.strategy.scope():
            satt_ds2_model = _build_satt_ds2_model(speech_featurizer, text_featurizer, config)
            optimizer = _create_satt_optimizer(config)
        ctc_trainer.compile(satt_ds2_model, optimizer, max_to_keep=args.max_ckpts)

        ctc_trainer.fit(train_dataset, eval_dataset, args.eval_train_ratio)

        if args.export:
            if args.from_weights:
                ctc_trainer.model.save_weights(args.export)
            else:
                ctc_trainer.model.save(args.export)

    elif args.mode == "test":
        tf.random.set_seed(0)
        assert args.export, "an export path is required in test mode"

        text_featurizer.add_scorer(
            Scorer(**text_featurizer.decoder_config["lm_config"],
                   vocabulary=text_featurizer.vocab_array))

        # No optimizer is needed here: the tester only uses the model.
        satt_ds2_model = _build_satt_ds2_model(speech_featurizer, text_featurizer, config)
        satt_ds2_model.summary(line_length=100)

        dataset_config = learning_config["dataset_config"]
        batch_size = running_config["batch_size"]
        # Test data is evaluated as-is: no augmentations, matching the eval
        # dataset used during training.
        if args.tfrecords:
            test_dataset = ASRTFRecordDataset(
                dataset_config["test_paths"],
                dataset_config["tfrecords_dir"],
                speech_featurizer, text_featurizer, "test", shuffle=False
            ).create(batch_size * args.eval_train_ratio)
        else:
            test_dataset = ASRSliceDataset(
                stage="test", speech_featurizer=speech_featurizer,
                text_featurizer=text_featurizer,
                data_paths=dataset_config["test_paths"], shuffle=False
            ).create(batch_size * args.eval_train_ratio)

        ctc_tester = BaseTester(
            config=running_config,
            saved_path=args.export, from_weights=args.from_weights
        )
        ctc_tester.compile(satt_ds2_model, speech_featurizer, text_featurizer)
        ctc_tester.run(test_dataset)

    else:
        assert args.export, "an export path is required in export mode"

        satt_ds2_model = _build_satt_ds2_model(speech_featurizer, text_featurizer, config)
        optimizer = _create_satt_optimizer(config)

        def save_func(**kwargs):
            # Called back by save_from_checkpoint with the model as a kwarg.
            if args.from_weights:
                kwargs["model"].save_weights(args.export)
            else:
                kwargs["model"].save(args.export)

        save_from_checkpoint(func=save_func,
                             outdir=running_config["outdir"],
                             model=satt_ds2_model, optimizer=optimizer)
Example #2
0
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        data_paths=config["learning_config"]["dataset_config"]["train_paths"],
        augmentations=config["learning_config"]["augmentations"],
        stage="train",
        cache=args.cache,
        shuffle=True)
    # Evaluation dataset read from dataset_config["eval_paths"]; unlike the
    # training set, no augmentations are passed.
    # NOTE(review): shuffle=True on the evaluation set is unusual (evaluation
    # order normally does not matter) — confirm this is intended.
    eval_dataset = ASRSliceDataset(
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        data_paths=config["learning_config"]["dataset_config"]["eval_paths"],
        stage="eval",
        cache=args.cache,
        shuffle=True)

# CTC trainer configured from learning_config.running_config.
ctc_trainer = CTCTrainer(text_featurizer,
                         config["learning_config"]["running_config"])
# Build the DeepSpeech2 model inside ctc_trainer.strategy.scope().
with ctc_trainer.strategy.scope():
    ds2_model = DeepSpeech2(input_shape=speech_featurizer.shape,
                            arch_config=config["model_config"],
                            num_classes=text_featurizer.num_classes,
                            name="deepspeech2")
    # NOTE(review): _build receives speech_featurizer.shape directly;
    # presumably this creates the model's variables — confirm the shape it
    # expects (e.g. whether a batch dimension is required).
    ds2_model._build(speech_featurizer.shape)
    ds2_model.summary(line_length=150)
# Compile: hand the model and the raw optimizer config to the trainer;
# max_to_keep presumably caps the number of retained checkpoints.
ctc_trainer.compile(ds2_model,
                    config["learning_config"]["optimizer_config"],
                    max_to_keep=args.max_ckpts)

ctc_trainer.fit(train_dataset,
                eval_dataset,
Example #3
0
def str2bool(value):
    """Parse a command-line boolean value.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this converter accepts the usual spellings.

    Args:
        value: a ``bool`` (returned unchanged) or a string such as
            ``"true"``/``"false"``, ``"yes"``/``"no"``, ``"y"``/``"n"``,
            ``"t"``/``"f"``, ``"1"``/``"0"`` (case-insensitive, whitespace
            ignored).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def _parse_ds2_args():
    """Define the DeepSpeech2 training command line and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(prog="Deep Speech 2 Training")

    parser.add_argument("--config", "-c", type=str, default=DEFAULT_YAML,
                        help="The file path of model configuration file")

    parser.add_argument("--export", "-e", type=str, default=None,
                        help="Path to the model file to be exported")

    # Boolean flags accept `--flag`, `--flag True` and `--flag False`
    # (with `type=bool`, "False" would have been parsed as True).
    parser.add_argument("--mixed_precision", type=str2bool, nargs="?",
                        const=True, default=False,
                        help="Whether to use mixed precision training")

    parser.add_argument("--save_weights", type=str2bool, nargs="?",
                        const=True, default=False,
                        help="Whether to save or load only weights")

    parser.add_argument("--max_ckpts", type=int, default=10,
                        help="Max number of checkpoints to keep")

    parser.add_argument(
        "--eval_train_ratio",
        type=int,
        default=1,
        help="ratio between train batch size and eval batch size")

    parser.add_argument("--tfrecords", type=str2bool, nargs="?",
                        const=True, default=False,
                        help="Whether to use tfrecords dataset")

    return parser.parse_args()


def _create_ds2_datasets(args, config, speech_featurizer, text_featurizer):
    """Create the ``(train_dataset, eval_dataset)`` pair.

    Datasets come from TFRecords when ``args.tfrecords`` is set, otherwise
    from the raw path lists. Only the training set gets
    ``learning_config.augmentations`` and shuffling.
    """
    learning_config = config["learning_config"]
    dataset_config = learning_config["dataset_config"]

    if args.tfrecords:
        train_dataset = ASRTFRecordDataset(
            dataset_config["train_paths"],
            dataset_config["tfrecords_dir"],
            speech_featurizer,
            text_featurizer,
            "train",
            augmentations=learning_config["augmentations"],
            shuffle=True,
        )
        eval_dataset = ASRTFRecordDataset(
            dataset_config["eval_paths"],
            dataset_config["tfrecords_dir"],
            speech_featurizer,
            text_featurizer,
            "eval",
            shuffle=False)
    else:
        # Train on train_paths (not eval_paths) and evaluate with stage "eval",
        # unshuffled, mirroring the tfrecords branch above.
        train_dataset = ASRSliceDataset(
            stage="train",
            speech_featurizer=speech_featurizer,
            text_featurizer=text_featurizer,
            data_paths=dataset_config["train_paths"],
            augmentations=learning_config["augmentations"],
            shuffle=True)
        eval_dataset = ASRSliceDataset(
            stage="eval",
            speech_featurizer=speech_featurizer,
            text_featurizer=text_featurizer,
            data_paths=dataset_config["eval_paths"],
            shuffle=False)

    return train_dataset, eval_dataset


def main():
    """Train a DeepSpeech2 CTC model from the command line.

    Steps: parse arguments, load the config, build the featurizers and
    datasets, train, then (if ``--export`` is given) save the whole model or,
    with ``--save_weights``, only its weights to that path.
    """
    tf.keras.backend.clear_session()

    args = _parse_ds2_args()

    config = UserConfig(DEFAULT_YAML, args.config, learning=True)
    speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
    text_featurizer = TextFeaturizer(config["decoder_config"])

    tf.random.set_seed(2020)

    if args.mixed_precision:
        policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16")
        tf.keras.mixed_precision.experimental.set_policy(policy)
        print("Enabled mixed precision training")

    train_dataset, eval_dataset = _create_ds2_datasets(
        args, config, speech_featurizer, text_featurizer)

    ctc_trainer = CTCTrainer(speech_featurizer, text_featurizer,
                             config["learning_config"]["running_config"],
                             args.mixed_precision)
    # Build the DS2 model inside the trainer's strategy scope.
    f, c = speech_featurizer.compute_feature_dim()
    with ctc_trainer.strategy.scope():
        ds2_model = DeepSpeech2(input_shape=[None, f, c],
                                arch_config=config["model_config"],
                                num_classes=text_featurizer.num_classes,
                                name="deepspeech2")
        ds2_model._build([1, 50, f, c])
    # Compile with the raw optimizer config from learning_config.
    ctc_trainer.compile(ds2_model,
                        config["learning_config"]["optimizer_config"],
                        max_to_keep=args.max_ckpts)

    ctc_trainer.fit(train_dataset, eval_dataset, args.eval_train_ratio)

    if args.export:
        if args.save_weights:
            ctc_trainer.model.save_weights(args.export)
        else:
            ctc_trainer.model.save(args.export)