def create_RNN_CTC_model(model_db: TranscriptionModel,
                         corpus_storage_path: Path,
                         models_storage_path: Path) -> rnn_ctc.Model:
    """Build a persephone RNN CTC model from its database entry.

    :model_db: The database entry containing the information about the model
               being created here.
    :corpus_storage_path: The path the corpuses are stored at.
    :models_storage_path: The path the models are stored at.

    Returns the constructed ``rnn_ctc.Model``.
    """
    # Prepare the experiment directory under the model storage location.
    exp_dir = experiment.prep_exp_dir(
        directory=str(models_storage_path / model_db.filesystem_path))

    # The corpus linked to this model is read from "corpus.p" inside the
    # corpus' own directory.
    corpus_dir = corpus_storage_path / model_db.corpus.filesystem_path
    # NOTE(review): pickle.load can execute arbitrary code contained in the
    # file, so corpus.p must only ever come from a trusted location.
    with (corpus_dir / "corpus.p").open('rb') as corpus_file:
        corpus = pickle.load(corpus_file)

    # The batch size is chosen from the number of training prefixes.
    batch_size = decide_batch_size(len(corpus.train_prefixes))
    reader = CorpusReader(corpus, batch_size=batch_size)

    # Hyperparameters are taken straight from the database entry.
    hyperparameters = {
        "num_layers": model_db.num_layers,
        "hidden_size": model_db.hidden_size,
        "beam_width": model_db.beam_width,
        "decoding_merge_repeated": model_db.decoding_merge_repeated,
    }
    return rnn_ctc.Model(exp_dir, reader, **hyperparameters)
def test_full_na():
    """A full Na integration test.

    Fetches the Na wav data, clones the Na transcription repository, prepares
    labels and features, trains an RNN CTC model on the narrative texts and
    asserts that the resulting test label error rate is below 0.2.
    """
    # Imported locally so this function does not depend on a module-level
    # shutil import being present.
    import shutil

    # Pulls Na wavs from cloudstor.
    NA_WAVS_LINK = "https://cloudstor.aarnet.edu.au/plus/s/LnNyNa20GQ8qsPC/download"
    download_example_data(NA_WAVS_LINK)

    na_dir = join(DATA_BASE_DIR, "na/")
    # Start from a clean target directory. The os module has no `rm_dir`
    # function, and os.rmdir only removes empty directories (and fails when
    # the directory is missing), so remove the whole tree if it is present.
    if os.path.isdir(na_dir):
        shutil.rmtree(na_dir)
    os.makedirs(na_dir)

    # Move the downloaded wavs into the Na directory.
    org_wav_dir = join(na_dir, "org_wav/")
    os.rename(join(DATA_BASE_DIR, "na_wav/"), org_wav_dir)
    tgt_wav_dir = join(na_dir, "wav/")

    NA_REPO_URL = "https://github.com/alexis-michaud/na-data.git"
    with cd(DATA_BASE_DIR):
        # check=True makes a failed clone abort the test immediately.
        subprocess.run(["git", "clone", NA_REPO_URL, "na/xml/"], check=True)

    # Note also that this subdirectory only contains TEXTs, so this
    # integration test will include only Na narratives, not wordlists.
    na_xml_dir = join(DATA_BASE_DIR, "na/xml/TEXT/F4")

    label_dir = join(DATA_BASE_DIR, "na/label")
    label_type = "phonemes_and_tones"
    na.prepare_labels(label_type,
                      org_xml_dir=na_xml_dir, label_dir=label_dir)

    tgt_feat_dir = join(DATA_BASE_DIR, "na/feat")
    # TODO Make this fbank_and_pitch, but then I need to install kaldi on ray
    # or run the tests on GPUs on slug or doe.
    feat_type = "fbank"
    na.prepare_feats(feat_type,
                     org_wav_dir=org_wav_dir,
                     tgt_wav_dir=tgt_wav_dir,
                     feat_dir=tgt_feat_dir,
                     org_xml_dir=na_xml_dir,
                     label_dir=label_dir)

    # Use the fixed validation/test prefix lists shipped with the test suite.
    shutil.copyfile("persephone/tests/test_sets/valid_prefixes.txt",
                    join(na_dir, "valid_prefixes.txt"))
    shutil.copyfile("persephone/tests/test_sets/test_prefixes.txt",
                    join(na_dir, "test_prefixes.txt"))
    na.make_data_splits(label_type, train_rec_type="text", tgt_dir=na_dir)

    # Training with texts
    exp_dir = experiment.prep_exp_dir(directory=EXP_BASE_DIR)
    na_corpus = na.Corpus(feat_type, label_type,
                          train_rec_type="text",
                          tgt_dir=na_dir)
    na_corpus_reader = corpus_reader.CorpusReader(na_corpus)
    model = rnn_ctc.Model(exp_dir, na_corpus_reader,
                          num_layers=3, hidden_size=400)

    model.train(min_epochs=30)

    # Ensure LER < 0.20
    ler = get_test_ler(exp_dir)
    assert ler < 0.2
def test_multispeaker(self, preprocessed_corpus):
    """Train a multispeaker BKW system using default settings.

    :preprocessed_corpus: The corpus object handed to the corpus reader
                          (presumably supplied by a test fixture).
    """
    # TODO bkw.Corpus and elan.Corpus should take an org_dir argument.
    experiment_dir = prep_exp_dir(directory=config.TEST_EXP_PATH)
    reader = CorpusReader(preprocessed_corpus)
    network = rnn_ctc.Model(
        experiment_dir,
        reader,
        num_layers=2,
        hidden_size=250,
    )
    network.train(min_epochs=30)
def train_bkw(num_layers: int) -> None:
    """Train an RNN CTC model on the BKW corpus.

    :num_layers: The number of layers passed to the model.

    The corpus is created under the "bkw" subdirectory of the test data path
    and the model is trained for a minimum of 40 epochs.
    """
    experiment_dir = prep_exp_dir(directory=config.TEST_EXP_PATH)

    bkw_dir = Path(config.TEST_DATA_PATH) / "bkw"
    reader = CorpusReader(bkw.create_corpus(tgt_dir=bkw_dir))

    network = rnn_ctc.Model(
        experiment_dir,
        reader,
        num_layers=num_layers,
        hidden_size=250,
    )
    network.train(min_epochs=40)
def test_reuse_model(preprocess_na):
    """Transcribe with a previously saved Na model checkpoint.

    Builds the Na corpus and its reader from the test data directory, logs
    the corpus and its untranscribed filenames, then constructs a model and
    calls ``transcribe`` with the path of a saved checkpoint.

    :preprocess_na: Not used in the body; presumably a fixture requested for
                    its side effects.
    """
    # TODO Currently assumes we're on slug. Need to package up the model and
    # put it on cloudstor, then create a fixture to download it.
    checkpoint_path = "/home/oadams/code/mam/exp/252/model/model_best.ckpt"

    data_dir = Path(config.TEST_DATA_PATH) / "na"
    corpus = na.Corpus("fbank_and_pitch", "phonemes_and_tones",
                       tgt_dir=data_dir)
    reader = corpus_reader.CorpusReader(corpus)

    logging.info("na_corpus {}".format(corpus))
    logging.info("na_corpus.get_untranscribed_fns():")
    untranscribed = corpus.get_untranscribed_fns()
    logging.info(pprint.pformat(untranscribed))

    experiment_dir = prep_exp_dir(directory=config.TEST_EXP_PATH)
    network = rnn_ctc.Model(
        experiment_dir,
        reader,
        num_layers=3,
        hidden_size=400,
    )
    network.transcribe(restore_model_path=checkpoint_path)