예제 #1
0
def test_feed_batch(preprocessed_corpus):
    """Decode the first training batch with a saved checkpoint and print it.

    Relies on a checkpoint already existing at the hard-coded
    ``model_path_prefix``. Nothing is asserted: the dense decoding and its
    human-readable form are only printed for manual inspection.
    """
    logging.debug("test_feed_batch()")
    model_path_prefix = "testing/exp/41/0/model/model_best.ckpt"
    logging.debug("model_path_prefix: %s", model_path_prefix)
    bkw_reader = CorpusReader(preprocessed_corpus, batch_size=1)
    # Only the first batch yielded by the training generator is decoded.
    batch = next(bkw_reader.train_batch_gen())
    # TensorFlow is imported locally, so the dependency is only needed once
    # this test actually runs.
    import tensorflow as tf
    with tf.device("/cpu:0"):
        dense_decoded = model.decode(model_path_prefix, batch)
        print(dense_decoded)
    hyps = bkw_reader.human_readable(dense_decoded)
    print(hyps)
예제 #2
0
def create_RNN_CTC_model(model_db: TranscriptionModel,
                         corpus_storage_path: Path,
                         models_storage_path: Path) -> rnn_ctc.Model:
    """Create a persephone RNN CTC model from its database entry.

    :param model_db: The database entry containing the information about the
        model to be created: its filesystem path, its corpus entry,
        ``num_layers``, ``hidden_size``, ``beam_width`` and
        ``decoding_merge_repeated``.
    :param corpus_storage_path: Directory under which corpora are stored.
    :param models_storage_path: Directory under which models are stored.
    :returns: An ``rnn_ctc.Model`` constructed with the experiment directory
        returned by ``experiment.prep_exp_dir`` for
        ``models_storage_path / model_db.filesystem_path``.
    """
    model_path = models_storage_path / model_db.filesystem_path
    exp_dir = experiment.prep_exp_dir(directory=str(model_path))

    corpus_db_entry = model_db.corpus
    pickled_corpus_path = (corpus_storage_path
                           / corpus_db_entry.filesystem_path
                           / "corpus.p")
    # SECURITY NOTE: pickle.load can execute arbitrary code embedded in the
    # file. Only load corpus.p files that this application wrote itself.
    with pickled_corpus_path.open('rb') as pickle_file:
        corpus = pickle.load(pickle_file)

    # The batch size is derived from the number of training prefixes.
    corpus_reader = CorpusReader(corpus,
                                 batch_size=decide_batch_size(
                                     len(corpus.train_prefixes)))
    return rnn_ctc.Model(
        exp_dir,
        corpus_reader,
        num_layers=model_db.num_layers,
        hidden_size=model_db.hidden_size,
        beam_width=model_db.beam_width,
        decoding_merge_repeated=model_db.decoding_merge_repeated)
예제 #3
0
    def test_multispeaker(self, preprocessed_corpus):
        """Train a multi-speaker BKW model with the default training settings.

        Builds a 2-layer RNN CTC model with hidden size 250 over the whole
        preprocessed corpus and trains it for at least 30 epochs.
        """
        # TODO bkw.Corpus and elan.Corpus should take an org_dir argument.
        experiment_dir = prep_exp_dir(directory=config.TEST_EXP_PATH)
        reader = CorpusReader(preprocessed_corpus)
        rnn_model = rnn_ctc.Model(experiment_dir, reader,
                                  num_layers=2, hidden_size=250)
        rnn_model.train(min_epochs=30)
예제 #4
0
 def train_bkw(num_layers: int, hidden_size: int = 250,
               min_epochs: int = 40) -> None:
     """Create the BKW corpus and train an RNN CTC model on it.

     :param num_layers: Passed to ``rnn_ctc.Model`` as ``num_layers``.
     :param hidden_size: Passed to ``rnn_ctc.Model`` as ``hidden_size``
         (default 250).
     :param min_epochs: Passed to ``Model.train`` as ``min_epochs``
         (default 40).
     """
     exp_dir = prep_exp_dir(directory=config.TEST_EXP_PATH)
     corp = bkw.create_corpus(tgt_dir=Path(config.TEST_DATA_PATH) / "bkw")
     cr = CorpusReader(corp)
     model = rnn_ctc.Model(exp_dir,
                           cr,
                           num_layers=num_layers,
                           hidden_size=hidden_size)
     model.train(min_epochs=min_epochs)
예제 #5
0
def test_model_creation(create_test_corpus):
    """Check that an rnn_ctc.Model can be built over the test corpus fixture."""
    from persephone.corpus_reader import CorpusReader
    from persephone.rnn_ctc import Model

    test_corpus = create_test_corpus()
    reader = CorpusReader(test_corpus, num_train=1, batch_size=1)
    assert reader

    built_model = Model(test_corpus.tgt_dir, reader)
    assert built_model
예제 #6
0
def test_model_train_and_decode(tmpdir, create_sine, make_wav, create_test_corpus):
    """Build a model, train it briefly, then decode a freshly generated WAV.

    The decode call only has to complete without raising; its output is not
    inspected.
    """
    from pathlib import Path

    from persephone.corpus_reader import CorpusReader
    from persephone.model import decode
    from persephone.rnn_ctc import Model

    test_corpus = create_test_corpus()

    # `tgt_dir` serves as the model's base directory. If it turns out not to
    # be part of the Corpus's public interface, take the base directory from
    # the fixture that created the corpus instead.
    base_dir = test_corpus.tgt_dir
    print("base_directory", base_dir)

    reader = CorpusReader(test_corpus, batch_size=1)
    assert reader

    trained = Model(base_dir, reader, num_layers=3, hidden_size=100)
    assert trained

    trained.train(early_stopping_steps=1, min_epochs=1, max_epochs=10)

    # Generate a sine tone (note C) and write it out as the WAV to decode.
    target_wav = str(tmpdir.join("wav").join("to_decode.wav"))
    make_wav(create_sine(note="C"), target_wav)

    checkpoint = base_dir / "model" / "model_best.ckpt"
    decode(
        checkpoint,
        [Path(target_wav)],
        label_set={"A", "B", "C"},
        feature_type="fbank",
        batch_x_name=trained.batch_x.name,
        batch_x_lens_name=trained.batch_x_lens.name,
        output_name=trained.dense_decoded.name,
    )
예제 #7
0
def test_decode(preprocessed_corpus):
    """Decode the BKW test set's feature files with a saved checkpoint.

    Requires a checkpoint at the hard-coded ``model_path_prefix`` and fbank
    feature files under ``testing/data/bkw/feat``. The resulting transcripts
    are only logged at DEBUG level, not checked.
    """
    model_path_prefix = "testing/exp/41/0/model/model_best.ckpt"
    bkw_reader = CorpusReader(preprocessed_corpus, batch_size=1)
    labels = bkw_reader.corpus.labels
    logging.debug("labels: %s", labels)
    test_prefixes = bkw_reader.corpus.test_prefixes
    # Log only the first 10 entries to keep the output readable.
    logging.debug("test_prefixes: %s", test_prefixes[:10])
    test_feat_paths = ["testing/data/bkw/feat/{}.fbank.npy".format(prefix)
                       for prefix in test_prefixes]
    logging.debug("test_feat_paths: %s", test_feat_paths[:10])
    transcripts = model.decode(model_path_prefix,
                               test_feat_paths,
                               labels)
    logging.debug("transcripts: %s", pprint.pformat(
        [" ".join(transcript) for transcript in transcripts]))
예제 #8
0
def test_model_train_callback(tmpdir, create_sine, make_wav, create_test_corpus):
    """Train a model and check the per-epoch callback passed to train() is invoked.

    ``tmpdir``, ``create_sine`` and ``make_wav`` are requested but not used
    here. They are kept so the test's fixture signature is unchanged.
    """
    from unittest.mock import Mock

    from persephone.corpus_reader import CorpusReader
    from persephone.rnn_ctc import Model

    corpus = create_test_corpus()

    # If it turns out that `tgt_dir` is not in the public interface of the Corpus
    # this test should change and get the base directory from the fixture that created it.
    base_directory = corpus.tgt_dir
    print("base_directory", base_directory)

    corpus_r = CorpusReader(
        corpus,
        batch_size=1
    )
    assert corpus_r

    test_model = Model(
        base_directory,
        corpus_r,
        num_layers=3,
        hidden_size=50
    )
    assert test_model

    # The Mock records every call made to it during training.
    mock_callback = Mock(return_value=None)

    test_model.train(
        early_stopping_steps=1,
        min_epochs=1,
        max_epochs=10,
        epoch_callback=mock_callback
    )

    # NOTE(review): this expects exactly max_epochs callback calls. With
    # early_stopping_steps=1 and min_epochs=1, confirm that training cannot
    # stop before max_epochs; otherwise this assertion may be flaky.
    assert mock_callback.call_count == 10
예제 #9
0
 def test_corpus_duration(self, preprocessed_corpus):
     """Run CorpusReader.calc_time, then print len(get_train_fns()[0]) as the utterance count."""
     reader = CorpusReader(preprocessed_corpus, batch_size=1)
     reader.calc_time()
     train_fns = preprocessed_corpus.get_train_fns()
     utterance_count = len(train_fns[0])
     print("Number of corpus utterances: {}".format(utterance_count))
예제 #10
0
def test_corpus_reader(create_test_corpus):
    """Check that a CorpusReader can be built over the test corpus fixture."""
    from persephone.corpus_reader import CorpusReader

    reader = CorpusReader(create_test_corpus(), num_train=2, batch_size=1)
    assert reader