Code example #1
    # Fix random seeds and number of threads
    np.random.seed(42)
    tf.random.set_seed(42)
    tf.config.threading.set_inter_op_parallelism_threads(args.threads)
    tf.config.threading.set_intra_op_parallelism_threads(args.threads)

    # Create logdir name
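    # (Each argument name is abbreviated to the first letter of every
    #  underscore-separated word, e.g. "threads" becomes "t" and a name such
    #  as "batch_size" would become "bs".)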
    args.logdir = os.path.join(
        "logs", "{}-{}-{}".format(
            os.path.basename(__file__),
            datetime.datetime.now().strftime("%Y-%m-%d_%H%M%S"), ",".join(
                ("{}={}".format(re.sub("(.)[^_]*_?", r"\1", key), value)
                 for key, value in sorted(vars(args).items())))))

    # Load the data
    timit = TimitMFCC()

    # Create the network and train
    network = Network(args)
    for epoch in range(args.epochs):
        network.train_epoch(timit.train, args)
        network.evaluate(timit.dev, "dev", args)
        print(f"Epoch {epoch+1} done.")

    # Generate the test set annotations; to allow parallel execution, create
    # the file in args.logdir if that directory exists.
    out_path = "speech_recognition_test.txt"
    if os.path.isdir(args.logdir):
        out_path = os.path.join(args.logdir, out_path)
    with open(out_path, "w", encoding="utf-8") as out_file:
        for sentence in network.predict(timit.test, args):
            # Write each prediction as a space-separated letter sequence
            # (assuming `predict` yields sequences of letter indices).
            print(" ".join(timit.LETTERS[letter] for letter in sentence), file=out_file)
Code example #2
import argparse

# TimitMFCC is assumed to be importable from the assignment's dataset module
# (its import is not part of this excerpt).


def edit_distance(x, y):
    # Levenshtein edit distance between sequences x and y, computed by
    # dynamic programming over a (len(x) + 1) x (len(y) + 1) table.
    a = [[0] * (len(y) + 1) for _ in range(len(x) + 1)]
    for i in range(len(x) + 1):
        a[i][0] = i
    for j in range(len(y) + 1):
        a[0][j] = j
    for i in range(1, len(x) + 1):
        for j in range(1, len(y) + 1):
            a[i][j] = min(a[i][j - 1] + 1,
                          a[i - 1][j] + 1,
                          a[i - 1][j - 1] + (x[i - 1] != y[j - 1]))
    return a[-1][-1]


if __name__ == "__main__":
    # Parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("predictions",
                        type=str,
                        help="Path to predicted output.")
    parser.add_argument("dataset",
                        type=str,
                        help="Which dataset to evaluate ('dev', 'test').")
    args = parser.parse_args([] if "__file__" not in globals() else None)

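    # Select the gold letter sequences of the requested dataset split.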
    gold = getattr(TimitMFCC(), args.dataset).data["letters"]

    with open(args.predictions, "r", encoding="utf-8") as predictions_file:
        predictions = [line.rstrip("\n") for line in predictions_file]

    if len(predictions) < len(gold):
        raise RuntimeError(
            "The predictions are shorter than gold data: {} vs {}.".format(
                len(predictions), len(gold)))

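    # Accumulate the edit distance of each predicted sentence against the gold
    # letter sequence, normalized by the gold sentence length.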
    score = 0
    for i in range(len(gold)):
        gold_sentence = [TimitMFCC.LETTERS[letter] for letter in gold[i]]
        predicted_sentence = predictions[i].split(" ")
        score += edit_distance(gold_sentence,
                               predicted_sentence) / len(gold_sentence)
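
As a quick sanity check of the normalized metric above (not part of the original script, and reusing the edit_distance function defined in this file), a prediction that drops one letter of a five-letter gold sentence contributes 1/5 = 0.2 to `score`:

    gold_sentence = ["h", "e", "l", "l", "o"]
    predicted_sentence = ["h", "e", "l", "o"]
    assert edit_distance(gold_sentence, predicted_sentence) == 1
    assert edit_distance(gold_sentence, predicted_sentence) / len(gold_sentence) == 0.2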