def prediction(args, number, data_dir='../samples/cha/stimuli3/0', model_path="", mode='train', data_suffix='npy'):
    """Print the ground-truth and CTC-decoded predicted sequence for a batch.

    Args:
        args: configuration namespace passed through to the model and batch
            generator (project-defined; exact fields used by the callees).
        number: how many batches drawn from the generator to decode and print.
        data_dir: directory holding the feature data (mfcc) files.
        model_path: path to the model checkpoint (HDF5 weights file).
        mode: 'train', 'validation' or 'test' — selects which data folder
            ``data_specification`` opens.
        data_suffix: filename suffix of the audio feature files in data_dir.
    """
    # Load data for the requested split; the per-file path list is unused here.
    data, label, _ = data_specification(mode, data_dir, data_suffix)

    # Longest time dimension across all samples; the model is built to this size.
    # NOTE(review): assumes each sample is at least 2-D with time on axis 1 — confirm.
    max_sequence_length = max((sample.shape[1] for sample in data), default=0)

    model = brnn_keras(args, max_sequence_length)
    model.model_1.load_weights(model_path)

    batch = create_batch(args, data, max_sequence_length, label)

    # These are loop-invariant: the checkpoint prefix strip is idempotent and
    # the train/test decoder choice depends only on the (stripped) path.
    model_path = model_path.replace("../CHECKPOINT/char/save/", "")
    use_test_decoder = "test" in model_path
    to_text = int_sequence_to_text_test if use_test_decoder else int_sequence_to_text

    for _ in range(number):
        inputs, _ = next(batch)
        x = inputs['the_input']
        y = inputs['the_labels'][0]

        # `probs` is the raw network output (was previously named `prediction`,
        # shadowing this function).
        probs = model.model_1.predict(x, steps=1)
        pred_ints = decode_ctc(args, probs)
        print(np.shape(probs))
        print('model ' + model_path)
        print('-' * 80)

        y = y.astype(int)
        print('GroundTruth speech:\n' + '\n' + to_text(y))
        print('-' * 80)
        print(pred_ints)
        print('Predicted speech:\n' + '\n' + ''.join(to_text(pred_ints)))

        output_matrix(probs[0])
def run_session(self, args):
    """Train the Deepspeech2 model and persist its loss history.

    Loads training data for ``args.mode`` and validation data from the
    'test' split, fits ``model.model_2`` with a batch generator, saves
    checkpoints via ``ModelCheckpoint``, and pickles ``history.history``.
    """
    # Training split.
    data, label, _ = data_specification(args.mode, args.data_dir, 'npy')
    # Temporarily set to 'test'; will be revised once the data changes.
    dev_data, dev_label, _ = data_specification('test', args.data_dir, 'npy')

    batch_num = len(data) // args.batch_size
    dev_batch_num = len(dev_data) // args.batch_size

    # Longest time dimension seen across both splits — the model is sized to it.
    max_sequence_length = max((sample.shape[1] for sample in data), default=0)
    max_sequence_length = max(
        (sample.shape[1] for sample in dev_data), default=max_sequence_length)

    # Save weights after every epoch.
    checkpointer = ModelCheckpoint(
        filepath=savedir + '/' + args.model_name, verbose=1)

    model = Deepspeech2(args, max_sequence_length)
    history = model.model_2.fit_generator(
        generator=create_batch(args, data, max_sequence_length, label),
        validation_data=create_batch(
            args, dev_data, max_sequence_length, dev_label),
        steps_per_epoch=batch_num,
        validation_steps=dev_batch_num,
        epochs=args.epochs,
        verbose=1,
        callbacks=[checkpointer])

    # Persist per-epoch loss/metric curves for later plotting.
    with open(savedir + '/' + args.pickle_name, 'wb') as f:
        pickle.dump(history.history, f)
def get_data(mode='train', data_dir='./PreferredStimuliData', data_suffix='npy', args=None):
    """Return one batch of data used for the visualization test.

    Reuses ``data_specification``/``create_batch`` from the data helper and
    returns a single batch.

    Args:
        mode: 'train', 'validation' or 'test' — selects the data folder.
        data_dir: directory holding the feature files.
        data_suffix: filename suffix of the feature files.
        args: configuration namespace forwarded to ``create_batch``. Defaults
            to the module-level ``args`` global, preserving the original
            behavior (the original body referenced a bare ``args`` that was
            never a parameter — a NameError unless such a global existed).

    Returns:
        Tuple ``(inputs, outputs, max_sequence_length)`` — the first batch's
        input/output dicts and the longest time dimension over the data set.
    """
    if args is None:
        # Backward-compatible fallback to the module-level global the
        # original code implicitly relied on.
        args = globals().get('args')

    data, label, _ = data_specification(mode, data_dir, data_suffix)

    # Longest time dimension across samples; sizes the batch padding.
    max_sequence_length = max((sample.shape[1] for sample in data), default=0)

    batch = create_batch(args, data, max_sequence_length, label)
    inputs, outputs = next(batch)
    return inputs, outputs, max_sequence_length