def on_train_end(self, logs=None):
    """Keras callback hook run once at the end of training.

    Computes the word error rate (WER) on the test generator, prints a
    handful of prediction samples for visual inspection, and persists the
    training log via self.save_log().

    Args:
        logs: metrics dict supplied by Keras at the end of training
            (unused here). Default changed from the mutable ``{}`` to
            ``None`` to avoid the shared-mutable-default pitfall.
    """
    try:
        test_wer = calc_wer(self.test_func, self.test_gen)
        # Python 3 print function for consistency with the rest of the file
        # (the original used Python 2 print statements).
        print("\n - Training ended, test wer: ", test_wer[1], " -")
    except Exception as e:
        # StandardError was Python-2-only and redundant next to Exception;
        # WER calculation is best-effort, so report the error and continue.
        template = "An exception of type {0} occurred. Arguments:\n{1!r}"
        message = template.format(type(e).__name__, e.args)
        print(message)

    # Print a sample of predictions, for visualisation
    print("\nPrediction samples:\n")
    predictions = predict_on_batch(self.validation_gen, self.test_func, 6)
    for i in predictions:
        print("Original: ", i[0])
        print("Predicted: ", i[1], "\n")

    self.save_log()
def main(args):
    """Load a trained ASR model and evaluate it on a test set.

    Reads transcripts/audio paths from ``args.audio_dir``, loads a saved
    Keras model (optionally a multi-GPU-saved one), optionally computes the
    average WER over the whole set, and prints sample predictions for one
    batch. Always clears the Keras session on exit.

    Args:
        args: parsed command-line namespace providing audio_dir, batch_size,
            batch_index, mfccs, mels, model_load, load_multi, feature_type
            and calc_wer.

    Raises:
        Nothing to the caller: errors are caught, reported and swallowed
        (KeyboardInterrupt/SystemExit are allowed to propagate).
    """
    try:
        if not args.model_load:
            raise ValueError("Error within model arguments")

        audio_dir = args.audio_dir
        print("\nReading test data: ")
        _, df = combine_all_wavs_and_trans_from_csvs(audio_dir)

        batch_size = args.batch_size
        batch_index = args.batch_index

        mfcc_features = args.mfccs
        n_mels = args.mels
        # Sampling rate of data in khz (LibriSpeech is 16khz)
        frequency = 16

        # Training data_params:
        model_load = args.model_load
        load_multi = args.load_multi

        # Sets the full dataset in audio_dir to be available through
        # data_generator. The data_generator doesn't actually load the audio
        # files until they are requested through __get_item__()
        epoch_length = 0

        # Load trained model.
        # When loading custom objects, Keras needs to know where to find them.
        # The CTC lambda is a dummy function.
        custom_objects = {
            'clipped_relu': models.clipped_relu,
            '<lambda>': lambda y_true, y_pred: y_pred,
        }

        # Both branches load the same file; hoisted out of the if/else.
        model = models.load_model(model_load, custom_objects=custom_objects)
        # When loading a parallel model saved *while* running on GPU, unwrap
        # the inner single-device model; otherwise use the loaded model as-is.
        if load_multi:
            model = model.layers[-2]
        print("\nLoaded existing model: ", model_load)

        # Dummy loss-function to compile the model; the actual CTC loss is
        # defined as a lambda layer inside the model itself.
        loss = {'ctc': lambda y_true, y_pred: y_pred}
        model.compile(loss=loss, optimizer='Adam')

        feature_shape = model.input_shape[0][2]

        # Model feature type: infer from the input width when not given
        # (26 features corresponds to MFCC input in this project).
        if not args.feature_type:
            feature_type = 'mfcc' if feature_shape == 26 else 'spectrogram'
        else:
            feature_type = args.feature_type
        print("Feature type: ", feature_type)

        # Data generation parameters
        data_params = {
            'feature_type': feature_type,
            'batch_size': batch_size,
            'frame_length': 20 * frequency,
            'hop_length': 10 * frequency,
            'mfcc_features': mfcc_features,
            'n_mels': n_mels,
            'epoch_length': epoch_length,
            'shuffle': False,
        }

        # Data generator over the test data
        data_generator = DataGenerator(df, **data_params)

        model.summary()

        # Creates a test function that takes preprocessed sound input and
        # outputs predictions. Used to calculate WER.
        input_data = model.get_layer('the_input').input
        y_pred = model.get_layer('ctc').input[0]
        test_func = K.function([input_data], [y_pred])

        if args.calc_wer:
            # Typo fixed: was "Calculation WER".
            print("\n - Calculating WER on ", audio_dir)
            wer = calc_wer(test_func, data_generator)
            print("Average WER: ", wer[1])

        predictions = predict_on_batch(data_generator, test_func, batch_index)
        print("\n - Predictions from batch index: ", batch_index,
              "\nFrom: ", audio_dir, "\n")
        for i in predictions:
            print("Original: ", i[0])
            print("Predicted: ", i[1], "\n")

    except Exception as e:
        # Narrowed from (Exception, BaseException, GeneratorExit, SystemExit):
        # BaseException subsumed the others and silently swallowed
        # KeyboardInterrupt/SystemExit, hiding intentional exits.
        template = "An exception of type {0} occurred. Arguments:\n{1!r}"
        message = template.format(type(e).__name__, e.args)
        print("e.args: ", e.args)
        print(message)
    finally:
        # Clear memory
        K.clear_session()