def predict_ctc(self, inputX):
    """Run both CTC heads on ``inputX``, plot ``softmax2`` and decide on the keyword.

    Fetches ``model/softmax1:0``, ``model/softmax2:0`` and ``model/nn_outputs:0``
    from ``self.sess`` with ``inputX`` fed to ``model/inputX:0``, prints the
    inference time, and saves a per-class plot of ``softmax2`` (class 0
    skipped) to ``./fig.png``. ``softmax1`` is decoded with
    ``ctc_decode_strict`` and ``softmax2`` with ``ctc_decode``.

    :param inputX: value fed to ``model/inputX:0`` (format defined by the graph).
    :return: ``(result, text)``. ``result`` is the bitwise OR of
        ``ctc_predict`` on both decodings against ``config.label_seqs``.
        ``text`` is ``str(output1) + '---' + str(output2)``.
    """
    start_time = time.time()
    # nn_outputs is still fetched (as before) but not used here.
    softmax1, softmax2, _nn_output = self.sess.run(
        ['model/softmax1:0', 'model/softmax2:0', 'model/nn_outputs:0'],
        feed_dict={'model/inputX:0': inputX})
    end_time = time.time()
    print('processing time', end_time - start_time)

    # softmax2 is indexed as [:, class] below, so it is presumably
    # (frames, num_classes) -- TODO confirm against the graph.
    colors = ['r', 'b', 'g', 'm', 'y', 'k', 'r', 'b']
    y = softmax2
    x = range(len(y))
    print(y.shape)
    fig = plt.figure(figsize=(10, 4))  # create the figure object
    for i in range(1, y.shape[1]):
        # Cycle through the palette so more than 8 plotted classes do not
        # raise IndexError.
        plt.plot(x, y[:, i], colors[(i - 1) % len(colors)], linewidth=1,
                 label=str(i))
    plt.legend(loc='upper right')
    plt.savefig('./fig.png')
    # Release the figure; otherwise every call leaves one more open pyplot figure.
    plt.close(fig)

    output1 = ctc_decode_strict(softmax1, config.num_classes)
    output2 = ctc_decode(softmax2, config.num_classes)
    # Global numpy print setting; it also shapes str(output*) below if the
    # decoders return numpy arrays.
    np.set_printoptions(precision=4, threshold=np.inf, suppress=True)
    result = (ctc_predict(output1, config.label_seqs)
              | ctc_predict(output2, config.label_seqs))
    return result, str(output1) + '---' + str(output2)
def test(self, name, detected_callback=play_audio_file):
    """Run the mel-spectrogram model on one audio file and plot its softmax.

    NOTE(review): a second ``test`` method with the same signature follows
    this one in the file. If both live in the same class, the later
    definition replaces this one.

    Loads ``name`` at 16 kHz and computes ``|STFT|`` (window
    ``config.fft_size``, hop ``config.hop_size``, ``center=False``). It
    projects that onto ``self.mel_basis`` and runs ``model/softmax:0`` with
    ``self.state`` as the initial RNN state. It prints the CTC decoding and
    calls ``detected_callback()`` if ``ctc_predict`` accepts it. The
    per-class probabilities (class 0 skipped) are saved to ``./test.png``.

    :param name: path of the audio file to load.
    :param detected_callback: zero-argument callable invoked on detection.
    """
    audio, _sr = librosa.load(name, sr=16000)
    # Keyword arguments work on older librosa and on releases where these
    # parameters are keyword-only.
    linearspec = np.abs(
        librosa.stft(audio, n_fft=config.fft_size, hop_length=config.hop_size,
                     win_length=config.fft_size, center=False)).T
    mel = np.dot(linearspec, self.mel_basis)
    softmax, _state = self.sess.run(
        ['model/softmax:0', 'model/rnn_states:0'],
        feed_dict={
            'model/inputX:0': mel,
            'model/rnn_initial_states:0': self.state
        })
    result = ctc_decode(softmax)
    print(result)
    if ctc_predict(result):
        detected_callback()

    colors = ['r', 'b', 'g', 'm', 'y', 'k']
    frames = range(len(softmax))
    fig = plt.figure(figsize=(10, 4))  # create the figure object
    for i in range(1, softmax.shape[1]):
        # Cycle the palette so more than 6 classes do not raise IndexError.
        plt.plot(frames, softmax[:, i], colors[i % len(colors)], linewidth=1,
                 label=str(i))
    plt.legend(loc='upper right')
    plt.savefig('./test.png')
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
def test(self, name, detected_callback=play_audio_file):
    """Run the raw-waveform model on one audio file and plot softmax and logits.

    Loads ``name`` at 16 kHz and feeds the samples directly to
    ``model/inputX:0``, with ``self.state`` as the initial RNN state. It
    fetches ``model/softmax:0`` and ``model/logit:0``. It prints the CTC
    decoding and calls ``detected_callback()`` when ``ctc_predict`` accepts
    it for the sequence ``'1233'``. It then dumps the softmax to stdout and
    saves a two-panel plot to ``./test.png``: softmax on top and
    ``logits[0]`` below.

    :param name: path of the audio file to load.
    :param detected_callback: zero-argument callable invoked on detection.
    """
    audio, _sr = librosa.load(name, sr=16000)
    # The model input is the raw waveform, so no spectrogram is computed here.
    softmax, logits, _state = self.sess.run(
        ['model/softmax:0', 'model/logit:0', 'model/rnn_states:0'],
        feed_dict={
            'model/inputX:0': audio,
            'model/rnn_initial_states:0': self.state
        })
    result = ctc_decode(softmax)
    print(result)
    # NOTE(review): keyword sequence is hard-coded; other code uses
    # config.label_seqs -- confirm they agree.
    if ctc_predict(result, '1233'):
        detected_callback()

    np.set_printoptions(precision=4, threshold=np.inf, suppress=True)
    print(softmax.shape)
    print(softmax.flatten().tolist())

    colors = ['r', 'b', 'g', 'm', 'y', 'k']
    frames = range(len(softmax))
    fig = plt.figure(figsize=(20, 15), dpi=120)  # create the figure object
    p1 = plt.subplot(211)
    p2 = plt.subplot(212)
    # plt.xticks/yticks called before the subplots exist would style a
    # throwaway axes; set the tick size on the real axes instead.
    for ax in (p1, p2):
        ax.tick_params(labelsize=40)
    # softmax is indexed [:, class] and logits [0][:, class], so presumably
    # (frames, classes) and (batch, frames, classes) -- TODO confirm.
    for i in range(softmax.shape[1]):
        p1.plot(frames, softmax[:, i], colors[i % len(colors)], linewidth=2,
                label=str(i))
    p1.legend(loc='upper right', fontsize=20)
    for i in range(logits.shape[2]):
        p2.plot(frames, logits[0][:, i], colors[i % len(colors)], linewidth=2,
                label=str(i))
    p2.legend(loc='upper right', fontsize=20)
    plt.savefig('./test.png')
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
def run(self, TrainingModel):
    """Build the model graph, restore or initialise it, then train or evaluate.

    A fresh ``tf.Graph``/``tf.Session`` is created. The mode is read from
    ``self.config.mode``:

    * In ``'train'`` mode, a training model and a weight-sharing validation
      model are built under the ``"model"`` variable scope.
    * In any other mode, only the validation model is built.

    Training runs until ``self.epoch`` reaches ``self.config.max_epoch``:

    * ``latest.ckpt`` is saved at every epoch boundary.
    * Validation runs every ``config.valid_step`` steps.
    * ``best.ckpt``/``best.pkl`` are saved whenever miss rate + false-accept
      rate improves on the best so far (resumed from ``best.pkl`` if present).
    * ``best<N>.ckpt`` is saved each time that sum is below 0.08.
    * SIGINT/SIGTERM save ``latest.ckpt`` (unless ``DEBUG``) and exit.

    Other modes make one pass over the validation set. They print an ``scp``
    line for each misclassified file and then the miss / false-accept counts.

    :param TrainingModel: model class, called as
        ``TrainingModel(config, input_queue, is_train=bool)``.
    """
    graph = tf.Graph()
    with graph.as_default(), tf.Session() as sess:
        self.data = read_dataset(self.config)
        # Read the mode from self.config everywhere so the graph that gets
        # built always matches the branch taken further down.
        is_train = self.config.mode == 'train'
        if is_train:
            print('building training model....')
            with tf.variable_scope("model"):
                self.train_model = TrainingModel(
                    self.config, self.data.batch_input_queue(), is_train=True)
                self.train_model.config.show()
            print('building valid model....')
            with tf.variable_scope("model", reuse=True):
                self.valid_model = TrainingModel(
                    self.config, self.data.valid_queue(), is_train=False)
        else:
            with tf.variable_scope("model", reuse=False):
                self.valid_model = TrainingModel(
                    self.config, self.data.valid_queue(), is_train=False)

        saver = tf.train.Saver()
        model_path = self.config.model_path
        # Restore from stored model files if any checkpoint files exist.
        if glob(path_join(model_path, '*.ckpt.*')):
            saver.restore(sess, path_join(model_path, self.config.model_name))
            print(('Model restored from:' + model_path))
        else:
            print("Model doesn't exist.\nInitializing........")
            sess.run(tf.global_variables_initializer())
        # A default tf.train.Saver only covers global variables, so local
        # variables always need initialising.
        sess.run(tf.local_variables_initializer())
        graph.finalize()
        st_time = time.time()

        # Best (miss rate, false-accept rate) so far. Defaults to the worst
        # score and is overridden by best.pkl, so a resumed run only replaces
        # best.ckpt with a model that actually beats the recorded best.
        best_miss = 1
        best_false = 1
        best_pkl_path = path_join(model_path, 'best.pkl')
        if os.path.exists(best_pkl_path):
            with open(best_pkl_path, 'rb') as f:
                best_miss, best_false = pickle.load(f)
            print('best miss', best_miss, 'best false', best_false)
        else:
            print('best not exist')
        check_dir(model_path)

        if is_train:
            accu_loss = 0
            epoch_step = (config.tfrecord_size * self.data.train_file_size
                          // config.batch_size)
            if self.config.reset_global:
                sess.run(self.train_model.reset_global_step)
            # Bound before the handler is installed, so a signal arriving
            # before the first training step does not hit an unbound name.
            step = 0

            def handler_stop_signals(signum, frame):
                # Module-level flag kept for compatibility; this method never
                # reads it.
                global run
                run = False
                if not DEBUG:
                    print('training shut down, total setp %s, the model will be save in %s'
                          % (step, model_path))
                    saver.save(sess, save_path=path_join(model_path, 'latest.ckpt'))
                    print('best miss rate:%f\tbest false rate %f' % (best_miss, best_false))
                sys.exit(0)

            signal.signal(signal.SIGINT, handler_stop_signals)
            signal.signal(signal.SIGTERM, handler_stop_signals)
            best_list = []  # entries: (miss, false, step, best_count)
            best_threshold = 0.08
            best_count = 0
            last_time = time.time()
            try:
                # Run every stage/enqueue op once before the loop (presumably
                # to fill the input pipelines one step ahead).
                sess.run([self.data.noise_stage_op,
                          self.data.noise_filequeue_enqueue_op,
                          self.train_model.stage_op,
                          self.train_model.input_filequeue_enqueue_op,
                          self.valid_model.stage_op,
                          self.valid_model.input_filequeue_enqueue_op])
                for var in tf.trainable_variables():
                    print(var.name)
                while self.epoch < self.config.max_epoch:
                    _, _, _, _, _, loss, lr, step, _grads = sess.run(
                        [self.train_model.train_op,
                         self.data.noise_stage_op,
                         self.data.noise_filequeue_enqueue_op,
                         self.train_model.stage_op,
                         self.train_model.input_filequeue_enqueue_op,
                         self.train_model.loss,
                         self.train_model.learning_rate,
                         self.train_model.global_step,
                         self.train_model.grads])
                    epoch = step // epoch_step
                    accu_loss += loss
                    if epoch > self.epoch:
                        self.epoch = epoch
                        print('accumulated loss', accu_loss)
                        latest_path = path_join(model_path, 'latest.ckpt')
                        saver.save(sess, save_path=latest_path)
                        print('latest.ckpt save in %s' % latest_path)
                        accu_loss = 0

                    if step % config.valid_step == 0:
                        print('epoch time ', (time.time() - last_time) / 60)
                        last_time = time.time()
                        np.set_printoptions(precision=4, threshold=np.inf, suppress=True)
                        miss_count = 0
                        false_count = 0
                        target_count = 0
                        wer = 0
                        valid_batch = (self.data.valid_file_size * config.tfrecord_size
                                       // config.batch_size)
                        text_parts = []
                        for _batch in range(valid_batch):
                            softmax, correctness, labels, _, _ = sess.run(
                                [self.valid_model.softmax,
                                 self.valid_model.correctness,
                                 self.valid_model.labels,
                                 self.valid_model.stage_op,
                                 self.valid_model.input_filequeue_enqueue_op])
                            decode_output = [ctc_decode(s) for s in softmax]
                            text_parts.extend(str(seq) + '\n' for seq in decode_output)
                            text_parts.append(str(labels) + '\n')
                            text_parts.append('=' * 20 + '\n')
                            result = [ctc_predict(seq, config.label_seqs)
                                      for seq in decode_output]
                            miss, target, false_accept = evaluate(
                                result, correctness.tolist())
                            miss_count += miss
                            target_count += target
                            false_count += false_accept
                            wer += self.wer_cal.cal_batch_wer(labels, decode_output).sum()
                        with open('./valid.txt', 'w') as f:
                            f.write(''.join(text_parts))

                        # Guard against a validation set with no positives or
                        # no negatives instead of aborting training.
                        miss_rate = miss_count / target_count if target_count else 0.0
                        negative_count = self.data.validation_size - target_count
                        false_accept_rate = (false_count / negative_count
                                             if negative_count else 0.0)
                        print('--------------------------------')
                        print('epoch %d' % self.epoch)
                        print('training loss:' + str(loss))
                        print('learning rate:', lr, 'global step', step)
                        print('miss rate:' + str(miss_rate))
                        print('flase_accept_rate:' + str(false_accept_rate))
                        print(miss_count, '/', target_count)
                        print('wer', wer / self.data.validation_size)

                        score = miss_rate + false_accept_rate
                        if score < best_miss + best_false:
                            best_miss = miss_rate
                            best_false = false_accept_rate
                            saver.save(sess, save_path=path_join(model_path, 'best.ckpt'))
                            with open(best_pkl_path, 'wb') as f:
                                pickle.dump((best_miss, best_false), f)
                        if score < best_threshold:
                            best_count += 1
                            print('best_count', best_count)
                            best_list.append(
                                (miss_rate, false_accept_rate, step, best_count))
                            saver.save(sess, save_path=path_join(
                                model_path, 'best' + str(best_count) + '.ckpt'))

                print('training finished, total epoch %d, the model will be save in %s'
                      % (self.epoch, model_path))
                saver.save(sess, save_path=path_join(model_path, 'latest.ckpt'))
                print('best miss rate:%f\tbest false rate %f' % (best_miss, best_false))
            except tf.errors.OutOfRangeError:
                print('Done training -- epoch limit reached')
            except Exception as e:
                # Top-level boundary: report and fall through to the finally block.
                print(e)
                traceback.print_exc()
            finally:
                with open('best_list.pkl', 'wb') as f:
                    pickle.dump(best_list, f)
                print('total time:%f hours' % ((time.time() - st_time) / 3600))
        else:
            with open(config.rawdata_path + 'valid/' + "ctc_valid.pkl.sorted", 'rb') as f:
                pkl = pickle.load(f)
            np.set_printoptions(precision=4, threshold=np.inf, suppress=True)
            miss_count = 0
            false_count = 0
            target_count = 0
            valid_batch = (self.data.valid_file_size * config.tfrecord_size
                           // config.batch_size)
            for i in range(valid_batch):
                softmax, ctc_input, correctness, _labels, _, _ = sess.run(
                    [self.valid_model.softmax,
                     self.valid_model.nn_outputs,
                     self.valid_model.correctness,
                     self.valid_model.labels,
                     self.valid_model.stage_op,
                     self.valid_model.input_filequeue_enqueue_op])
                correctness = correctness.tolist()
                decode_output = [ctc_decode(s) for s in softmax]
                result = [ctc_predict(seq, config.label_seqs) for seq in decode_output]
                for k, r in enumerate(result):
                    if r != correctness[k]:
                        # Assumes pkl rows are in the same order as the
                        # validation queue -- TODO confirm.
                        name = pkl[i * config.batch_size + k][0]
                        print("scp [email protected]:/ssd/keyword_raw/valid/%s ./" % name)
                        # NOTE(review): opened with 'w', so logits.txt only keeps
                        # the nn_outputs of the last misclassified sample.
                        with open('logits.txt', 'w') as f:
                            f.write(str(ctc_input[k]))
                miss, target, false_accept = evaluate(result, correctness)
                miss_count += miss
                target_count += target
                false_count += false_accept
            print('--------------------------------')
            print('miss rate: %d/%d' % (miss_count, target_count))
            print('flase_accept_rate: %d/%d' % (
                false_count, self.data.validation_size - target_count))
def start(self, detected_callback=play_audio_file,
          interrupt_check=lambda: False,
          sleep_time=0.3):
    """
    Start the voice detector (mel-spectrogram model).

    Each loop reads a chunk from ``self.ring_buffer`` and handles it as follows:

    * An empty chunk makes the loop sleep ``sleep_time`` seconds.
    * Chunks rejected by ``vad(data, 20)`` are skipped. The first rejected
      chunk after speech resets the RNN state, the softmax queue and the
      buffered audio.
    * Speech chunks are turned into mel features and run through
      ``model/softmax:0`` with the carried RNN state. The queued softmax
      outputs are then CTC-decoded.

    When ``ctc_predict`` accepts the decoding:

    * ``detected_callback()`` is called.
    * The buffered speech audio is written to ``./trigger.wav``.
    * The softmax is passed to ``self.plot(..., 'trigger.png')``.

    Every loop also calls ``interrupt_check``; if it returns True the loop
    breaks.

    :param detected_callback: zero-argument callable invoked on a detection.
    :param interrupt_check: a function that returns True if the main loop
                            needs to stop.
    :param float sleep_time: seconds to wait when no audio is available.
    :return: None
    """
    if interrupt_check():
        logger.debug("detect voice return")
        return

    logger.debug("detecting...")
    with self.sess as sess:
        while True:
            if interrupt_check():
                logger.debug("detect voice break")
                break
            data = self.ring_buffer.get()
            if len(data) == 0:
                time.sleep(sleep_time)
                continue

            if vad(data, 20):
                self.prev_speech = True
            else:
                if self.prev_speech:
                    # Speech just ended: restart decoding from scratch.
                    self.prev_speech = False
                    self.clean_state()
                    self.prob_queue.clear()
                    # Drop buffered audio together with the model state so
                    # npdata cannot grow without bound between detections.
                    self.npdata = []
                continue
            # Only chunks that reach the model are buffered, so trigger.wav
            # holds the audio behind the queued softmax outputs.
            self.npdata.append(data)

            data = np.concatenate((self.res, data), 0)
            # With center=False, res is the number of trailing samples from
            # where the next STFT frame would start; they are carried over to
            # the next chunk.
            res = (len(data) - config.fft_size) % config.hop_size + (
                config.fft_size - config.hop_size)
            # data[-0:] would keep everything; slice explicitly so res == 0
            # carries nothing over.
            self.res = data[max(len(data) - res, 0):]
            linearspec = np.abs(
                librosa.stft(data, n_fft=config.fft_size,
                             hop_length=config.hop_size,
                             win_length=config.fft_size, center=False)).T
            mel = np.dot(linearspec, self.mel_basis)
            softmax, state = sess.run(
                ['model/softmax:0', 'model/rnn_states:0'],
                feed_dict={
                    'model/inputX:0': mel,
                    'model/rnn_initial_states:0': self.state
                })
            self.prob_queue.add(softmax)
            self.state = state

            concated_soft = np.concatenate(self.prob_queue.get_all(), 0)
            print(concated_soft.shape)
            result = ctc_decode(concated_soft)
            if ctc_predict(result):
                detected_callback()
                self.prob_queue.clear()
                librosa.output.write_wav('./trigger.wav',
                                         np.concatenate(self.npdata, 0), 16000)
                self.npdata = []
                self.clean_state()
                self.plot(concated_soft, 'trigger.png')

    logger.debug("finished.")
    self.terminate()
def start(self, detected_callback=play_audio_file,
          interrupt_check=lambda: False,
          sleep_time=0.3):
    """
    Start the voice detector (raw-waveform model).

    Each loop reads a chunk from ``self.ring_buffer`` and handles it as follows:

    * An empty chunk makes the loop sleep ``sleep_time`` seconds.
    * A chunk rejected by ``vad(data, 30)`` resets the RNN state, the softmax
      queue and the buffered audio, but is still fed to the model.
    * Samples (prefixed with the carry-over from the previous chunk) go
      directly into ``model/inputX:0`` with the carried RNN state.
    * The queued softmax outputs are decoded with ``ctc_decode2``.

    When ``ctc_predict`` accepts the decoding for ``'1233'``:

    * ``detected_callback()`` is called.
    * The buffered audio is written to ``./trigger.wav``.
    * The softmax is passed to ``self.plot(..., 'trigger.png')``.

    Every loop also calls ``interrupt_check``; if it returns True the loop
    breaks.

    :param detected_callback: zero-argument callable invoked on a detection.
    :param interrupt_check: a function that returns True if the main loop
                            needs to stop.
    :param float sleep_time: seconds to wait when no audio is available.
    :return: None
    """
    if interrupt_check():
        logger.debug("detect voice return")
        return

    logger.debug("detecting...")
    with self.sess as sess:
        while True:
            if interrupt_check():
                logger.debug("detect voice break")
                break
            data = self.ring_buffer.get()
            if len(data) == 0:
                time.sleep(sleep_time)
                continue

            if vad(data, 30):
                self.prev_speech = True
            else:
                # Silence: restart decoding from this chunk. The audio buffer
                # is reset together with the softmax queue so it stays
                # bounded and matches what the queue was computed from.
                self.clean_state()
                self.prob_queue.clear()
                self.npdata = []
            self.npdata.append(data)

            data = np.concatenate((self.res, data), 0)
            res = (len(data) - config.fft_size) % config.hop_size + (
                config.fft_size - config.hop_size)
            # data[-0:] would keep everything; slice explicitly so res == 0
            # carries nothing over.
            self.res = data[max(len(data) - res, 0):]
            softmax, state = sess.run(
                ['model/softmax:0', 'model/rnn_states:0'],
                feed_dict={
                    'model/inputX:0': data,
                    'model/rnn_initial_states:0': self.state
                })
            self.prob_queue.add(softmax)
            self.state = state

            concated_soft = np.concatenate(self.prob_queue.get_all(), 0)
            print(concated_soft.shape)
            result = ctc_decode2(concated_soft, config.num_classes)
            # NOTE(review): keyword sequence is hard-coded; validation code
            # uses config.label_seqs -- confirm they agree.
            if ctc_predict(result, '1233'):
                detected_callback()
                self.prob_queue.clear()
                librosa.output.write_wav('./trigger.wav',
                                         np.concatenate(self.npdata, 0), 16000)
                self.npdata = []
                self.clean_state()
                self.plot(concated_soft, 'trigger.png')

    logger.debug("finished.")
    self.terminate()