# Fix random seeds and number of threads np.random.seed(42) tf.random.set_seed(42) tf.config.threading.set_inter_op_parallelism_threads(args.threads) tf.config.threading.set_intra_op_parallelism_threads(args.threads) # Create logdir name args.logdir = os.path.join( "logs", "{}-{}-{}".format( os.path.basename(__file__), datetime.datetime.now().strftime("%Y-%m-%d_%H%M%S"), ",".join( ("{}={}".format(re.sub("(.)[^_]*_?", r"\1", key), value) for key, value in sorted(vars(args).items()))))) # Load the data timit = TimitMFCC() # Create the network and train network = Network(args) for epoch in range(args.epochs): network.train_epoch(timit.train, args) network.evaluate(timit.dev, "dev", args) print(f"Epoch {epoch+1} done.") # Generate test set annotations, but to allow parallel execution, create it # in in args.logdir if it exists. out_path = "speech_recognition_test.txt" if os.path.isdir(args.logdir): out_path = os.path.join(args.logdir, out_path) with open(out_path, "w", encoding="utf-8") as out_file: for sentence in network.predict(timit.test, args):
# NOTE(review): this chunk was collapsed onto one physical line; formatting
# below is reconstructed with tokens unchanged. It begins mid-way through
# `edit_distance` — the function header and DP-table initialization
# (building `a` over sequences `x`, `y`) are outside the visible source.
# The visible fragment is the substitution term of the Levenshtein
# recurrence (cost 1 when the aligned elements differ) and the final
# return of the bottom-right DP cell.
            a[i - 1][j - 1] + (x[i - 1] != y[j - 1]))
    return a[-1][-1]


if __name__ == "__main__":
    # Parse arguments: path to a predictions file and which dataset
    # split to score it against.
    parser = argparse.ArgumentParser()
    parser.add_argument("predictions", type=str, help="Path to predicted output.")
    parser.add_argument("dataset", type=str, help="Which dataset to evaluate ('dev', 'test').")
    # When run without a __file__ (e.g. interactively), parse an empty
    # argument list instead of sys.argv.
    args = parser.parse_args([] if "__file__" not in globals() else None)

    # Gold letter sequences for the chosen split.
    # NOTE(review): TimitMFCC is a project class — `.data["letters"]` is
    # presumably a list of per-sentence letter-index sequences; confirm
    # against its definition.
    gold = getattr(TimitMFCC(), args.dataset).data["letters"]

    # One predicted sentence per line, stripped of the trailing newline.
    with open(args.predictions, "r", encoding="utf-8") as predictions_file:
        predictions = [line.rstrip("\n") for line in predictions_file]

    # Extra trailing predictions are tolerated; too few is an error.
    if len(predictions) < len(gold):
        raise RuntimeError(
            "The predictions are shorter than gold data: {} vs {}.".format(
                len(predictions), len(gold)))

    # Accumulate per-sentence normalized edit distance (distance divided
    # by gold-sentence length).
    score = 0
    for i in range(len(gold)):
        # Map gold letter indices to letter strings; predictions are
        # space-separated letter tokens.
        gold_sentence = [TimitMFCC.LETTERS[letter] for letter in gold[i]]
        predicted_sentence = predictions[i].split(" ")
        score += edit_distance(gold_sentence, predicted_sentence) / len(gold_sentence)
        # NOTE(review): any final reporting of `score` (e.g. printing the
        # mean) continues past the visible source.