Example No. 1
# Assumed imports for this excerpt; helpers such as download_example_data,
# cd, get_test_ler, DATA_BASE_DIR and EXP_BASE_DIR are defined elsewhere in
# the persephone test suite.
import os
import shutil
import subprocess
from os.path import join

from persephone import corpus_reader, experiment, rnn_ctc
from persephone.datasets import na


def test_full_na():
    """ A full Na integration test. """

    # Pulls Na wavs from cloudstor.
    NA_WAVS_LINK = "https://cloudstor.aarnet.edu.au/plus/s/LnNyNa20GQ8qsPC/download"
    download_example_data(NA_WAVS_LINK)

    na_dir = join(DATA_BASE_DIR, "na/")
    # Clear out any data left over from a previous run.
    if os.path.isdir(na_dir):
        shutil.rmtree(na_dir)
    os.makedirs(na_dir)
    org_wav_dir = join(na_dir, "org_wav/")
    os.rename(join(DATA_BASE_DIR, "na_wav/"), org_wav_dir)
    tgt_wav_dir = join(na_dir, "wav/")

    NA_REPO_URL = "https://github.com/alexis-michaud/na-data.git"
    with cd(DATA_BASE_DIR):
        subprocess.run(["git", "clone", NA_REPO_URL, "na/xml/"], check=True)
    # Note also that this subdirectory only contains TEXTs, so this integration
    # test will include only Na narratives, not wordlists.
    na_xml_dir = join(DATA_BASE_DIR, "na/xml/TEXT/F4")

    label_dir = join(DATA_BASE_DIR, "na/label")
    label_type = "phonemes_and_tones"
    na.prepare_labels(label_type, org_xml_dir=na_xml_dir, label_dir=label_dir)

    tgt_feat_dir = join(DATA_BASE_DIR, "na/feat")
    # TODO Make this fbank_and_pitch, but then I need to install kaldi on ray
    # or run the tests on GPUs on slug or doe.
    feat_type = "fbank"
    na.prepare_feats(feat_type,
                     org_wav_dir=org_wav_dir,
                     tgt_wav_dir=tgt_wav_dir,
                     feat_dir=tgt_feat_dir,
                     org_xml_dir=na_xml_dir,
                     label_dir=label_dir)

    shutil.copyfile("persephone/tests/test_sets/valid_prefixes.txt",
                    join(na_dir, "valid_prefixes.txt"))
    shutil.copyfile("persephone/tests/test_sets/test_prefixes.txt",
                    join(na_dir, "test_prefixes.txt"))
    na.make_data_splits(label_type, train_rec_type="text", tgt_dir=na_dir)

    # Training with texts
    exp_dir = experiment.prep_exp_dir(directory=EXP_BASE_DIR)
    na_corpus = na.Corpus(feat_type,
                          label_type,
                          train_rec_type="text",
                          tgt_dir=na_dir)
    na_corpus_reader = corpus_reader.CorpusReader(na_corpus)
    model = rnn_ctc.Model(exp_dir,
                          na_corpus_reader,
                          num_layers=3,
                          hidden_size=400)

    model.train(min_epochs=30)

    # Ensure LER < 0.20
    ler = get_test_ler(exp_dir)
    assert ler < 0.2
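
The cd(...) helper used above is not shown in this excerpt. A minimal sketch
of such a working-directory context manager, assuming it only needs to switch
into the directory and restore the previous one afterwards:

import contextlib
import os

@contextlib.contextmanager
def cd(path):
    """Temporarily change the working directory, restoring it on exit."""
    prev = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(prev)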
Example No. 2
# Assumed imports (persephone project layout):
from pathlib import Path

import tensorflow as tf

from persephone import config, corpus_reader
from persephone.datasets import na


def test_load_saver():
    tgt_dir = Path(config.TEST_DATA_PATH) / "na"
    na_corpus = na.Corpus("fbank_and_pitch",
                          "phonemes_and_tones",
                          tgt_dir=tgt_dir)
    na_reader = corpus_reader.CorpusReader(na_corpus)
    model_prefix_path = "/home/oadams/code/mam/exp/252/model/model_best.ckpt"
    # Note: a bare Saver() requires variables in the default graph; see the
    # completion sketch below.
    saver = tf.train.Saver()
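
As written, test_load_saver stops short of restoring anything, and a bare
tf.train.Saver() raises a ValueError when the default graph holds no
variables. A minimal TF1-style completion sketch, assuming the checkpoint's
.meta file sits beside the weights:

import tensorflow as tf

tf.reset_default_graph()
# import_meta_graph rebuilds the saved graph and returns a Saver bound to it.
saver = tf.train.import_meta_graph(model_prefix_path + ".meta")
with tf.Session() as sess:
    saver.restore(sess, model_prefix_path)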
Example No. 3
# Assumed imports (persephone project layout):
from pathlib import Path

from persephone import config, corpus_reader, model
from persephone.datasets import na


def test_feed_batch():
    tgt_dir = Path(config.TEST_DATA_PATH) / "na"
    na_corpus = na.Corpus("fbank_and_pitch",
                          "phonemes_and_tones",
                          tgt_dir=tgt_dir)
    na_reader = corpus_reader.CorpusReader(na_corpus)
    model_path = "/home/oadams/code/mam/exp/252/model/model_best.ckpt"
    graph = model.load_graph(model_path)
    batch = next(na_reader.untranscribed_batch_gen())
    print(model.decode(graph, batch))
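
model.load_graph above is the project's own helper. A plausible TF1-style
sketch of what such a helper can look like, assuming it imports the saved
metagraph into a fresh graph (hypothetical, not necessarily persephone's
actual implementation):

import tensorflow as tf

def load_graph(model_prefix_path):
    # Import the checkpoint's metagraph into its own Graph object so it
    # cannot collide with ops already in the default graph.
    graph = tf.Graph()
    with graph.as_default():
        tf.train.import_meta_graph(model_prefix_path + ".meta")
    return graph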
Example No. 4
# Assumed imports (persephone project layout); preprocess_na is a pytest
# fixture defined elsewhere in the test suite.
import logging
import pprint
from pathlib import Path

from persephone import config, corpus_reader, rnn_ctc
from persephone.datasets import na
from persephone.experiment import prep_exp_dir


def test_reuse_model(preprocess_na):
    tgt_dir = Path(config.TEST_DATA_PATH) / "na"
    na_corpus = na.Corpus("fbank_and_pitch",
                          "phonemes_and_tones",
                          tgt_dir=tgt_dir)
    na_reader = corpus_reader.CorpusReader(na_corpus)
    logging.info("na_corpus {}".format(na_corpus))
    logging.info("na_corpus.get_untranscribed_fns():")
    logging.info(pprint.pformat(na_corpus.get_untranscribed_fns()))
    # TODO Currently assumes we're on slug. Need to package up the model and
    # put it on cloudstor, then create a fixture to download it.
    exp_dir = prep_exp_dir(directory=config.TEST_EXP_PATH)
    model = rnn_ctc.Model(exp_dir, na_reader, num_layers=3, hidden_size=400)
    model.transcribe(
        restore_model_path="/home/oadams/code/mam/exp/252/model/model_best.ckpt"
    )
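
The TODO above points out that the test assumes a checkpoint on one specific
machine. Until the model is packaged and downloadable via a fixture, a
stopgap is to skip rather than fail when the checkpoint is absent; a sketch,
where pretrained_model_path is a hypothetical fixture name:

import os

import pytest

MODEL_PREFIX = "/home/oadams/code/mam/exp/252/model/model_best.ckpt"

@pytest.fixture
def pretrained_model_path():
    # TF checkpoints span several files sharing a prefix; the .index file
    # is a cheap existence check.
    if not os.path.exists(MODEL_PREFIX + ".index"):
        pytest.skip("pretrained Na model is not available on this machine")
    return MODEL_PREFIX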
Example No. 5
# Assumed imports (persephone project layout):
import pprint
from pathlib import Path

import tensorflow as tf

from persephone import config, corpus_reader, model, results
from persephone.datasets import na


def test_load_meta():
    tgt_dir = Path(config.TEST_DATA_PATH) / "na"
    na_corpus = na.Corpus("fbank_and_pitch",
                          "phonemes_and_tones",
                          tgt_dir=tgt_dir)
    na_reader = corpus_reader.CorpusReader(na_corpus)

    tf.reset_default_graph()
    model_prefix_path = "/home/oadams/code/mam/exp/252/model/model_best.ckpt"

    # Import the saved metagraph; restoring the checkpoint into it below
    # recovers the trained weights.
    metagraph = model.load_metagraph(model_prefix_path)

    with tf.Session() as sess:
        with tf.device("/cpu:0"):
            metagraph.restore(sess, model_prefix_path)
        # List the placeholder ops so the feed_dict keys used below can be
        # checked against the restored graph.
        print([
            x for x in sess.graph.get_operations() if "Placeholder" in x.type
        ])
        # Dump the graph collections to see which variables and ops survived
        # the round trip through the checkpoint.
        for v in tf.get_default_graph().get_collection("variables"):
            print(v)
        for v in tf.get_default_graph().get_collection("trainable_variables"):
            print(v)
        for v in tf.get_default_graph().get_collection("train_op"):
            print(v)
        print(tf.get_default_graph().get_all_collection_keys())
        pprint.pprint([
            repr(op) for op in tf.get_default_graph().get_operations()
            if "hyp" in op.name
        ])
        pprint.pprint([
            repr(op) for op in tf.get_default_graph().get_operations()
            if "Placeholder" in op.type
        ])
        pprint.pprint([
            repr(op) for op in tf.get_default_graph().get_operations()
            if "SparseToDense" in op.type
        ])

        all_prefixes = []
        all_hyps = []
        for batch_i, batch in enumerate(na_reader.untranscribed_batch_gen()):

            batch_x, batch_x_lens, feat_fn_batch = batch
            prefixes = [
                fn.split("/")[-1].split(".")[:2] for fn in feat_fn_batch
            ]

            # The restored graph's placeholders were never given explicit
            # names, so the autogenerated ones must be used; see the note
            # after this example.
            feed_dict = {
                "Placeholder:0": batch_x,
                "Placeholder_1:0": batch_x_lens
            }

            # "SparseToDense:0" is the densified output of the CTC decoder.
            [dense_decoded] = sess.run(["SparseToDense:0"],
                                       feed_dict=feed_dict)
            print(dense_decoded)
            hyps = na_reader.human_readable(dense_decoded)
            print(hyps)
            print(na_reader.corpus.INDEX_TO_LABEL)
            all_hyps.extend(["".join(hyp) for hyp in hyps])
            all_prefixes.extend([".".join(prefix) for prefix in prefixes])
        print(
            results.fmt_latex_untranscribed(
                all_hyps, all_prefixes,
                Path("benevolence_and_funeral_custom.tex")))