Пример #1
0
    def predict_ctc(self, inputX):
        """Run both CTC heads of the loaded graph on ``inputX`` and decode them.

        Fetches ``model/softmax1:0`` and ``model/softmax2:0`` (plus
        ``model/nn_outputs:0``, which is fetched but not used), saves a plot of
        every column of ``softmax2`` except column 0 to ``./fig.png``, decodes
        ``softmax1`` with ``ctc_decode_strict`` and ``softmax2`` with
        ``ctc_decode``, and checks both decodings with ``ctc_predict``.

        :param inputX: value fed to ``model/inputX:0``; its expected shape is
            defined by the graph (not visible here).
        :return: ``(result, text)``: ``result`` is the ``|`` of the two
            ``ctc_predict`` results, ``text`` is ``"<output1>---<output2>"``.
        """
        start_time = time.time()
        softmax1, softmax2, _ = self.sess.run(
            ['model/softmax1:0', 'model/softmax2:0', 'model/nn_outputs:0'],
            feed_dict={'model/inputX:0': inputX})
        print('processing time', time.time() - start_time)

        colors = ['r', 'b', 'g', 'm', 'y', 'k', 'r', 'b']
        probs = softmax2
        x = range(len(probs))
        print(probs.shape)
        fig = plt.figure(figsize=(10, 4))  # create the plot figure
        # Column 0 is not plotted. Colors are cycled so that graphs with more
        # classes than len(colors) do not raise IndexError.
        for i in range(1, probs.shape[1]):
            plt.plot(x, probs[:, i], colors[(i - 1) % len(colors)],
                     linewidth=1, label=str(i))
        plt.legend(loc='upper right')
        plt.savefig('./fig.png')
        # Without closing, every call leaves one more open figure behind.
        plt.close(fig)

        output1 = ctc_decode_strict(softmax1, config.num_classes)
        output2 = ctc_decode(softmax2, config.num_classes)
        np.set_printoptions(precision=4, threshold=np.inf, suppress=True)
        result = (ctc_predict(output1, config.label_seqs)
                  | ctc_predict(output2, config.label_seqs))
        return result, str(output1) + '---' + str(output2)
Пример #2
0
    def test(self, name, detected_callback=play_audio_file):
        """Run the mel-spectrogram model once over a whole audio file.

        Loads ``name`` at 16 kHz and computes ``|STFT|`` (``config.fft_size``
        window/FFT, ``config.hop_size`` hop, no centering) projected onto
        ``self.mel_basis``. It feeds that to ``model/inputX:0`` with
        ``self.state`` as the initial RNN state and CTC-decodes the softmax.
        ``detected_callback`` is called when ``ctc_predict`` accepts the
        decoding. Finally every softmax column except column 0 is plotted to
        ``./test.png``. The returned RNN state is discarded (``self.state``
        is not updated).

        :param name: path of the audio file to test.
        :param detected_callback: zero-argument callable invoked on detection.
        """
        # Keyword arguments: newer librosa makes these parameters keyword-only.
        audio, _ = librosa.load(name, sr=16000)
        linearspec = np.abs(
            librosa.stft(audio,
                         n_fft=config.fft_size,
                         hop_length=config.hop_size,
                         win_length=config.fft_size,
                         center=False)).T
        mel = np.dot(linearspec, self.mel_basis)
        softmax, _ = self.sess.run(
            ['model/softmax:0', 'model/rnn_states:0'],
            feed_dict={
                'model/inputX:0': mel,
                'model/rnn_initial_states:0': self.state
            })
        result = ctc_decode(softmax)
        print(result)
        if ctc_predict(result):
            detected_callback()

        colors = ['r', 'b', 'g', 'm', 'y', 'k']
        x = range(len(softmax))
        fig = plt.figure(figsize=(10, 4))  # create the plot figure
        # Colors are cycled so more than len(colors) classes cannot IndexError.
        for i in range(1, softmax.shape[1]):
            plt.plot(x, softmax[:, i], colors[i % len(colors)],
                     linewidth=1, label=str(i))
        plt.legend(loc='upper right')
        plt.savefig('./test.png')
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
Пример #3
0
    def test(self, name, detected_callback=play_audio_file):
        """Run the raw-waveform model once over a whole audio file and plot it.

        Loads ``name`` at 16 kHz and feeds the waveform to ``model/inputX:0``
        with ``self.state`` as the initial RNN state. It CTC-decodes the
        softmax and calls ``detected_callback`` when
        ``ctc_predict(result, '1233')`` is truthy. It then prints the softmax
        shape and values and saves a two-panel plot (softmax on top,
        ``logits[0]`` below) to ``./test.png``. The returned RNN state is
        discarded.

        :param name: path of the audio file to test.
        :param detected_callback: zero-argument callable invoked on detection.
        """
        # Keyword argument: newer librosa makes ``sr`` keyword-only.
        audio, _ = librosa.load(name, sr=16000)
        # This graph is fed the raw waveform, so no spectrogram is computed.
        softmax, logits, _ = self.sess.run(
            ['model/softmax:0', 'model/logit:0', 'model/rnn_states:0'],
            feed_dict={
                'model/inputX:0': audio,
                'model/rnn_initial_states:0': self.state
            })
        result = ctc_decode(softmax)
        print(result)
        if ctc_predict(result, '1233'):
            detected_callback()

        np.set_printoptions(precision=4, threshold=np.inf, suppress=True)
        print(softmax.shape)
        print(softmax.flatten().tolist())

        colors = ['r', 'b', 'g', 'm', 'y', 'k']
        x = range(len(softmax))
        fig = plt.figure(figsize=(20, 15), dpi=120)  # create the plot figure
        p1 = plt.subplot(211)
        p2 = plt.subplot(212)
        # Apply the large tick font to the actual subplots; calling
        # plt.xticks/yticks before plt.subplot targets a separate implicit axes.
        for ax in (p1, p2):
            ax.tick_params(labelsize=40)

        # Colors are cycled so more than len(colors) classes cannot IndexError.
        for i in range(softmax.shape[1]):
            p1.plot(x, softmax[:, i], colors[i % len(colors)],
                    linewidth=2, label=str(i))
        p1.legend(loc='upper right', fontsize=20)

        for i in range(logits.shape[2]):
            p2.plot(x, logits[0][:, i], colors[i % len(colors)],
                    linewidth=2, label=str(i))
        p2.legend(loc='upper right', fontsize=20)
        plt.savefig('./test.png')
        plt.close(fig)
Пример #4
0
    def run(self, TrainingModel):
        """Build the model graph, restore or initialize it, then train or evaluate.

        Data comes from ``read_dataset(self.config)``, and models are built
        inside the ``"model"`` variable scope.

        * ``self.config.mode == 'train'``:
          - Builds a training model and a weight-sharing validation model.
          - Trains until ``self.epoch`` reaches ``self.config.max_epoch``.
          - Validates every ``config.valid_step`` global steps.
          - Saves ``latest.ckpt`` on each new epoch and at the end.
          - Saves ``best.ckpt`` and ``best.pkl`` whenever miss + false-accept
            rate improves.
          - Saves ``best<N>.ckpt`` whenever that sum is below ``best_threshold``.
          - On SIGINT/SIGTERM, saves ``latest.ckpt`` (unless ``DEBUG``) and
            exits.
        * any other mode:
          - Runs one pass over the validation set.
          - Prints an ``scp`` line for every misclassified sample.
          - Prints the overall miss / false-accept counts.

        :param TrainingModel: model class, constructed as
            ``TrainingModel(config, input_queue, is_train=...)``.
        """
        graph = tf.Graph()
        with graph.as_default(), tf.Session() as sess:

            self.data = read_dataset(self.config)

            # Select the mode through self.config, exactly like the branch
            # below, so the models built here match the code that uses them
            # (self.train_model only exists in train mode).
            if self.config.mode == 'train':
                print('building training model....')
                with tf.variable_scope("model"):
                    self.train_model = TrainingModel(self.config,
                                                     self.data.batch_input_queue(),
                                                     is_train=True)
                    self.train_model.config.show()
                print('building valid model....')
                with tf.variable_scope("model", reuse=True):
                    self.valid_model = TrainingModel(self.config,
                                                     self.data.valid_queue(),
                                                     is_train=False)
            else:
                with tf.variable_scope("model", reuse=False):
                    self.valid_model = TrainingModel(self.config,
                                                     self.data.valid_queue(),
                                                     is_train=False)
            saver = tf.train.Saver()

            # Restore from a stored model when any checkpoint file exists.
            files = glob(path_join(self.config.model_path, '*.ckpt.*'))
            if files:
                saver.restore(sess, path_join(self.config.model_path,
                                              self.config.model_name))
                print('Model restored from:' + self.config.model_path)
            else:
                print("Model doesn't exist.\nInitializing........")
                sess.run(tf.global_variables_initializer())

            sess.run(tf.local_variables_initializer())
            graph.finalize()

            st_time = time.time()
            # Best validation (miss, false-accept) rates so far; 1 means "no
            # record". A stored best.pkl is honoured so that resuming training
            # does not overwrite best.ckpt with a worse model.
            best_miss = 1
            best_false = 1
            best_pkl_path = path_join(self.config.model_path, 'best.pkl')
            if os.path.exists(best_pkl_path):
                with open(best_pkl_path, 'rb') as f:
                    best_miss, best_false = pickle.load(f)
                    print('best miss', best_miss, 'best false', best_false)
            else:
                print('best not exist')

            check_dir(self.config.model_path)

            if self.config.mode == 'train':
                accu_loss = 0
                # Bound before the signal handler is installed: the handler
                # reports `step`, and a signal can arrive before the first
                # training step assigns it.
                step = 0
                # At least 1, so `step // epoch_step` cannot divide by zero.
                epoch_step = max(
                    config.tfrecord_size * self.data.train_file_size // config.batch_size,
                    1)

                if self.config.reset_global:
                    sess.run(self.train_model.reset_global_step)

                def handler_stop_signals(signum, frame):
                    global run
                    run = False
                    if not DEBUG:
                        print(
                            'training shut down, total setp %s, the model will be save in %s' % (
                                step, self.config.model_path))
                        saver.save(sess, save_path=(
                            path_join(self.config.model_path, 'latest.ckpt')))
                        print('best miss rate:%f\tbest false rate %f' % (
                            best_miss, best_false))
                    sys.exit(0)

                signal.signal(signal.SIGINT, handler_stop_signals)
                signal.signal(signal.SIGTERM, handler_stop_signals)

                best_list = []  # (miss, false, step, best_count) tuples
                best_threshold = 0.08
                best_count = 0

                last_time = time.time()

                try:
                    # Prime the staging areas and file queues before training.
                    sess.run([self.data.noise_stage_op,
                              self.data.noise_filequeue_enqueue_op,
                              self.train_model.stage_op,
                              self.train_model.input_filequeue_enqueue_op,
                              self.valid_model.stage_op,
                              self.valid_model.input_filequeue_enqueue_op])

                    for var in tf.trainable_variables():
                        print(var.name)
                    while self.epoch < self.config.max_epoch:

                        _, _, _, _, _, loss, lr, step, grads = sess.run(
                            [self.train_model.train_op,
                             self.data.noise_stage_op,
                             self.data.noise_filequeue_enqueue_op,
                             self.train_model.stage_op,
                             self.train_model.input_filequeue_enqueue_op,
                             self.train_model.loss,
                             self.train_model.learning_rate,
                             self.train_model.global_step,
                             self.train_model.grads
                             ])
                        epoch = step // epoch_step
                        accu_loss += loss
                        if epoch > self.epoch:
                            self.epoch = epoch
                            print('accumulated loss', accu_loss)
                            saver.save(sess, save_path=(
                                path_join(self.config.model_path,
                                          'latest.ckpt')))
                            print('latest.ckpt save in %s' % (
                                path_join(self.config.model_path,
                                          'latest.ckpt')))
                            accu_loss = 0
                        if step % config.valid_step == 0:
                            print('epoch time ', (time.time() - last_time) / 60)
                            last_time = time.time()

                            miss_count = 0
                            false_count = 0
                            target_count = 0
                            wer = 0
                            valid_batch = self.data.valid_file_size * config.tfrecord_size // config.batch_size
                            text_parts = []
                            np.set_printoptions(precision=4,
                                                threshold=np.inf,
                                                suppress=True)
                            for _batch in range(valid_batch):
                                softmax, correctness, labels, _, _ = sess.run(
                                    [self.valid_model.softmax,
                                     self.valid_model.correctness,
                                     self.valid_model.labels,
                                     self.valid_model.stage_op,
                                     self.valid_model.input_filequeue_enqueue_op])

                                decode_output = [ctc_decode(s) for s in softmax]
                                for seq in decode_output:
                                    text_parts.append(str(seq) + '\n')
                                    text_parts.append(str(labels) + '\n')
                                    text_parts.append('=' * 20 + '\n')
                                result = [ctc_predict(seq, config.label_seqs)
                                          for seq in decode_output]
                                miss, target, false_accept = evaluate(
                                    result, correctness.tolist())

                                miss_count += miss
                                target_count += target
                                false_count += false_accept

                                wer += self.wer_cal.cal_batch_wer(labels,
                                                                  decode_output).sum()
                            with open('./valid.txt', 'w') as f:
                                f.write(''.join(text_parts))

                            # Guard the denominators: when they are 0 the
                            # numerators are 0 as well, so the rate is 0.
                            miss_rate = miss_count / max(target_count, 1)
                            non_target_count = self.data.validation_size - target_count
                            false_accept_rate = false_count / max(non_target_count, 1)
                            print('--------------------------------')
                            print('epoch %d' % self.epoch)
                            print('training loss:' + str(loss))
                            print('learning rate:', lr, 'global step', step)
                            print('miss rate:' + str(miss_rate))
                            print('flase_accept_rate:' + str(false_accept_rate))
                            print(miss_count, '/', target_count)
                            print('wer', wer / self.data.validation_size)

                            if miss_rate + false_accept_rate < best_miss + best_false:
                                best_miss = miss_rate
                                best_false = false_accept_rate
                                saver.save(sess,
                                           save_path=(path_join(
                                               self.config.model_path,
                                               'best.ckpt')))
                                with open(best_pkl_path, 'wb') as f:
                                    pickle.dump((best_miss, best_false), f)
                            if miss_rate + false_accept_rate < best_threshold:
                                best_count += 1
                                print('best_count', best_count)
                                best_list.append((miss_rate,
                                                  false_accept_rate, step,
                                                  best_count))
                                saver.save(sess,
                                           save_path=(path_join(
                                               self.config.model_path,
                                               'best' + str(
                                                   best_count) + '.ckpt')))

                    print(
                        'training finished, total epoch %d, the model will be save in %s' % (
                            self.epoch, self.config.model_path))
                    saver.save(sess, save_path=(
                        path_join(self.config.model_path, 'latest.ckpt')))
                    print('best miss rate:%f\tbest false rate %f' % (
                        best_miss, best_false))

                except tf.errors.OutOfRangeError:
                    print('Done training -- epoch limit reached')
                except Exception as e:
                    print(e)
                    traceback.print_exc()
                finally:
                    with open('best_list.pkl', 'wb') as f:
                        pickle.dump(best_list, f)
                    print('total time:%f hours' % (
                        (time.time() - st_time) / 3600))

            else:
                with open(config.rawdata_path + 'valid/' + "ctc_valid.pkl.sorted",
                          'rb') as f:
                    pkl = pickle.load(f)
                miss_count = 0
                false_count = 0
                target_count = 0

                valid_batch = self.data.valid_file_size * config.tfrecord_size // config.batch_size
                np.set_printoptions(precision=4,
                                    threshold=np.inf,
                                    suppress=True)

                for i in range(valid_batch):
                    softmax, ctc_input, correctness, _labels, _, _ = sess.run(
                        [self.valid_model.softmax,
                         self.valid_model.nn_outputs,
                         self.valid_model.correctness,
                         self.valid_model.labels,
                         self.valid_model.stage_op,
                         self.valid_model.input_filequeue_enqueue_op])

                    correctness = correctness.tolist()
                    decode_output = [ctc_decode(s) for s in softmax]
                    result = [ctc_predict(seq, config.label_seqs)
                              for seq in decode_output]
                    for k, r in enumerate(result):
                        if r != correctness[k]:
                            # Assumes pkl rows follow the validation queue
                            # order; column 0 is used as the file name.
                            name = pkl[i * config.batch_size + k][0]
                            print("scp [email protected]:/ssd/keyword_raw/valid/%s ./"%name)
                            # NOTE: opened with 'w', so only the last
                            # misclassified sample's nn output survives.
                            with open('logits.txt', 'w') as f:
                                f.write(str(ctc_input[k]))

                    miss, target, false_accept = evaluate(result, correctness)

                    miss_count += miss
                    target_count += target
                    false_count += false_accept

                print('--------------------------------')
                print('miss rate: %d/%d' % (miss_count, target_count))
                print('flase_accept_rate: %d/%d' % (
                    false_count, self.data.validation_size - target_count))
Пример #5
0
    def start(self,
              detected_callback=play_audio_file,
              interrupt_check=lambda: False,
              sleep_time=0.3):
        """
        Start the streaming keyword detector (mel-spectrogram model).

        The loop proceeds as follows:
        - Each iteration pulls audio from ``self.ring_buffer``. If nothing is
          buffered, it sleeps ``sleep_time`` seconds.
        - Chunks rejected by ``vad(data, 20)`` are skipped. When speech has
          just ended, the RNN state and probability queue are reset.
        - Speech chunks are turned into a mel spectrogram and run through the
          model, carrying the RNN state across chunks.
        - All queued softmax frames are CTC-decoded. On a detection:
          ``detected_callback`` is called, the audio captured since the last
          detection is written to ``./trigger.wav`` and the probabilities are
          plotted to ``trigger.png``.
        - Every iteration also calls ``interrupt_check``. If it returns True,
          the loop stops and ``self.terminate()`` is called.

        :param detected_callback: a single zero-argument callable invoked on
                                  each detection (a list is not supported).
        :param interrupt_check: a function that returns True if the main loop
                                needs to stop.
        :param float sleep_time: seconds to wait when no audio is buffered.
        :return: None
        """
        if interrupt_check():
            logger.debug("detect voice return")
            return

        logger.debug("detecting...")
        with self.sess as sess:
            while True:
                if interrupt_check():
                    logger.debug("detect voice break")
                    break
                data = self.ring_buffer.get()
                # Every chunk is kept so trigger.wav holds all audio captured
                # since the last detection.
                self.npdata.append(data)
                if len(data) == 0:
                    time.sleep(sleep_time)
                    continue

                if vad(data, 20):
                    self.prev_speech = True
                else:
                    # Speech just ended: forget the RNN state and the history.
                    if self.prev_speech:
                        self.prev_speech = False
                        self.clean_state()
                        self.prob_queue.clear()
                    continue

                # Prepend the samples left over from the previous chunk so STFT
                # frames stay continuous across chunks.
                data = np.concatenate((self.res, data), 0)

                res = (len(data) - config.fft_size) % config.hop_size + (
                    config.fft_size - config.hop_size)
                # Slice from an explicit start: data[-0:] would keep the whole
                # buffer instead of nothing when res == 0.
                self.res = data[max(len(data) - res, 0):]

                # Keyword arguments: newer librosa makes these keyword-only.
                linearspec = np.abs(
                    librosa.stft(data,
                                 n_fft=config.fft_size,
                                 hop_length=config.hop_size,
                                 win_length=config.fft_size,
                                 center=False)).T
                mel = np.dot(linearspec, self.mel_basis)

                softmax, state = sess.run(
                    ['model/softmax:0', 'model/rnn_states:0'],
                    feed_dict={
                        'model/inputX:0': mel,
                        'model/rnn_initial_states:0': self.state
                    })

                self.prob_queue.add(softmax)
                self.state = state
                concated_soft = np.concatenate(self.prob_queue.get_all(), 0)
                print(concated_soft.shape)

                result = ctc_decode(concated_soft)
                if ctc_predict(result):
                    detected_callback()
                    self.prob_queue.clear()
                    # NOTE(review): librosa.output was removed in librosa 0.8;
                    # this call requires an older librosa.
                    librosa.output.write_wav('./trigger.wav',
                                             np.concatenate(self.npdata, 0),
                                             16000)
                    self.npdata = []
                    self.clean_state()
                    self.plot(concated_soft, 'trigger.png')

        logger.debug("finished.")
        self.terminate()
Пример #6
0
    def start(self,
              detected_callback=play_audio_file,
              interrupt_check=lambda: False,
              sleep_time=0.3):
        """
        Start the streaming keyword detector (raw-waveform model).

        The loop proceeds as follows:
        - Each iteration pulls audio from ``self.ring_buffer``. If nothing is
          buffered, it sleeps ``sleep_time`` seconds.
        - Chunks rejected by ``vad(data, 30)`` reset the RNN state and the
          probability queue, but are still fed to the model.
        - The waveform, prefixed by the leftover samples of the previous
          chunk, is fed to the model with the carried RNN state.
        - The queued softmax frames are decoded with ``ctc_decode2``. When
          ``ctc_predict(result, '1233')`` is truthy: ``detected_callback`` is
          called, the audio captured since the last detection is written to
          ``./trigger.wav`` and the probabilities are plotted to
          ``trigger.png``.
        - Every iteration also calls ``interrupt_check``. If it returns True,
          the loop stops and ``self.terminate()`` is called.

        :param detected_callback: zero-argument callable invoked on detection.
        :param interrupt_check: a function that returns True if the main loop
                                needs to stop.
        :param float sleep_time: seconds to wait when no audio is buffered.
        :return: None
        """
        if interrupt_check():
            logger.debug("detect voice return")
            return

        logger.debug("detecting...")
        with self.sess as sess:
            while True:
                if interrupt_check():
                    logger.debug("detect voice break")
                    break
                data = self.ring_buffer.get()
                # Every chunk is kept so trigger.wav holds all audio captured
                # since the last detection.
                self.npdata.append(data)
                if len(data) == 0:
                    time.sleep(sleep_time)
                    continue

                if vad(data, 30):
                    self.prev_speech = True
                else:
                    # Non-speech: reset state on every such chunk (not only at
                    # the speech->silence transition), then keep processing.
                    self.clean_state()
                    self.prob_queue.clear()

                data = np.concatenate((self.res, data), 0)

                res = (len(data) - config.fft_size) % config.hop_size + (
                    config.fft_size - config.hop_size)
                # Slice from an explicit start: data[-0:] would keep the whole
                # buffer instead of nothing when res == 0.
                self.res = data[max(len(data) - res, 0):]

                # This graph takes the raw waveform directly (no spectrogram).
                softmax, state = sess.run(
                    ['model/softmax:0', 'model/rnn_states:0'],
                    feed_dict={
                        'model/inputX:0': data,
                        'model/rnn_initial_states:0': self.state
                    })

                self.prob_queue.add(softmax)
                self.state = state
                concated_soft = np.concatenate(self.prob_queue.get_all(), 0)
                print(concated_soft.shape)

                result = ctc_decode2(concated_soft, config.num_classes)
                if ctc_predict(result, '1233'):
                    detected_callback()
                    self.prob_queue.clear()
                    # NOTE(review): librosa.output was removed in librosa 0.8;
                    # this call requires an older librosa.
                    librosa.output.write_wav('./trigger.wav',
                                             np.concatenate(self.npdata, 0),
                                             16000)
                    self.npdata = []
                    self.clean_state()
                    self.plot(concated_soft, 'trigger.png')

        logger.debug("finished.")
        self.terminate()