Example #1
    # Dataset preparation, executed on the main process only
    # (run_on_main comes from speechbrain.utils.distributed in the official recipe)
    run_on_main(
        prepare_librispeech,
        kwargs={
            "data_folder": hparams["data_folder"],
            "tr_splits": hparams["train_splits"],
            "dev_splits": hparams["dev_splits"],
            "te_splits": hparams["test_splits"],
            "save_folder": hparams["data_folder"],
            "merge_lst": hparams["train_splits"],
            "merge_name": hparams["train_csv"],
        },
    )
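    # prepare_librispeech writes the train/dev/test CSV manifests into
    # save_folder; dataio_prepare below reads them back (standard recipe flow).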

    # Create experiment directory
    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file,
        overrides=overrides,
    )

    # Here we create the dataset objects as well as the tokenization and encoding
    train_data, valid_data, test_datasets, tokenizer = dataio_prepare(hparams)
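
    # `dataio_prepare` is defined elsewhere in this script; a minimal sketch of
    # what it usually does in the LibriSpeech recipes is shown below
    # (illustrative only -- key names such as "train_csv" are assumptions, the
    # actual helper may differ):
    #
    #   def dataio_prepare(hparams):
    #       train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
    #           csv_path=hparams["train_csv"],
    #           replacements={"data_root": hparams["data_folder"]},
    #       )
    #
    #       @sb.utils.data_pipeline.takes("wav")
    #       @sb.utils.data_pipeline.provides("sig")
    #       def audio_pipeline(wav):
    #           return sb.dataio.dataio.read_audio(wav)
    #
    #       sb.dataio.dataset.add_dynamic_item([train_data], audio_pipeline)
    #       sb.dataio.dataset.set_output_keys([train_data], ["id", "sig"])
    #       ...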

    # Trainer initialization
    asr_brain = ASR(
        modules=hparams["modules"],
        opt_class=hparams["opt_class"],
        hparams=hparams,
        run_opts=run_opts,
        checkpointer=hparams["checkpointer"],
    )


def main(config):
    ### create Experiment Directory ###
    # combine all hyperparameters into a single file
    hparams = load_hparams(config.exp_config)
    hparams["model_config"] = load_hparams(config.model_config)

    # create exp dir
    sb.create_experiment_directory(experiment_directory=config.output_folder,
                                   hyperparams_to_save=config.exp_config,
                                   overrides=None)

    ### Datasets and Tokenizer ###
    train_data, valid_data, test_data, tokenizer = dataio_prepare(hparams)

    # Trainer initialization
    run_opts = {
        "device": "cuda:0"
    }  # certain args from the yaml file are automatically picked up as run_opts
    # see https://github.com/speechbrain/speechbrain/blob/develop/recipes/LibriSpeech/ASR/transformer/train.py#L372
    # see https://github.com/speechbrain/speechbrain/blob/d6adc40e742107c34ae38dc63484171938b4d237/speechbrain/core.py#L124
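    # In the standard SpeechBrain recipes (see the links above), run_opts is
    # produced by parsing the command line instead of being hard-coded, e.g.:
    #
    #   import sys
    #   hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
    #
    # Hard-coding the device here is a simplification of that pattern.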
    asr_brain = ASR(
        modules=hparams["model_config"]["modules"],
        opt_class=hparams["model_config"]["Adam"],
        hparams=hparams["model_config"],
        run_opts=run_opts,
        checkpointer=hparams["model_config"]["checkpointer"],
    )

    # adding objects to trainer:
    asr_brain.tokenizer = tokenizer  # hparams["tokenizer"]
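    # (the Brain subclass typically reads self.tokenizer inside
    # compute_forward/compute_objectives to encode and decode transcripts)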

    # Training
    asr_brain.fit(
        asr_brain.hparams.epoch_counter,
        train_data,
        valid_data,
        train_loader_kwargs=hparams["model_config"]["train_dataloader_opts"],
        valid_loader_kwargs=hparams["model_config"]["valid_dataloader_opts"],
    )
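
    # Evaluation on the test set would normally follow here; a hedged sketch,
    # assuming a "test_dataloader_opts" entry exists in the model config:
    #
    #   asr_brain.evaluate(
    #       test_data,
    #       test_loader_kwargs=hparams["model_config"]["test_dataloader_opts"],
    #   )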

    raise NotImplementedError
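
    # NOTE: the code below is unreachable because of the raise above; it
    # trains a SentencePiece tokenizer from the training transcripts.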

    ### get Train Data ###
    # list of {'audio__file': str, 'transcript_all_file': str, 'transcript_uid': str, 'filter_criteria': str}
    # meaning that <audio__file>'s transcript is the one in the <transcript_all_file> with id <transcript_uid>
    train_corpus = get_utterance_manifest_from_data_config(
        config.train_data_config)
    for x in train_corpus:
        assert os.path.exists(
            x["transcript_all_file"]
        ), "data transcript file {} does not exist! Exiting!".format(
            x["transcript_all_file"])

    ### create json file for SpeechBrain-->SentencePiece ###
    selected_transcripts_json, annotation_read = create_transcripts_json(
        train_corpus)
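    # SpeechBrain's SentencePiece wrapper with annotation_format="json" expects
    # a dict keyed by utterance id, where `annotation_read` names the field that
    # holds the text. A hypothetical illustration (field name assumed):
    #
    #   selected_transcripts_json = {
    #       "utt_0001": {"words": "HELLO WORLD"},
    #       ...
    #   }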

    ### train custom SentencePiece Tokenizer ###
    with tempfile.NamedTemporaryFile(mode="w+", suffix=".json") as f:
        f.write(json.dumps(selected_transcripts_json))
        f.seek(0)

        SentencePiece(model_dir=config.output_folder,
                      vocab_size=config.vocab_size,
                      annotation_train=f.name,
                      annotation_read=annotation_read,
                      annotation_format="json",
                      model_type=config.model_type,
                      character_coverage=config.character_coverage,
                      annotation_list_to_check=config.annotation_list_to_check)
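
        # The trained tokenizer model and vocab are written into
        # config.output_folder (the wrapper names them roughly
        # "<vocab_size>_<model_type>.model" / ".vocab"). A hedged usage sketch
        # of the returned wrapper, assuming it were assigned to `tokenizer`:
        #
        #   tokenizer = SentencePiece(...)
        #   ids = tokenizer.sp.encode_as_ids("SOME TRANSCRIPT")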