Ejemplo n.º 1
0
    def update(self, labels, preds):
        """Accumulate character-error-rate (CER) statistics for one batch.

        Every sample is decoded greedily: the arg-max of each sequence step
        is taken and the resulting index sequence is passed through
        ``pred_best``. The decoded sequence is compared with the
        blank-stripped label using ``levenshtein_distance``.

        Updates ``total_n_label``, ``total_l_dist``, ``num_inst`` and
        ``sum_metric``, and resets ``batch_loss`` to ``0.``.
        ``total_ctc_loss`` is left numerically unchanged because the CTC loss
        computation is disabled in this variant.

        Args:
            labels: sequence of label arrays, each exposing ``asnumpy()``.
                Presumably one entry per device -- confirm against caller.
            preds: sequence of network outputs, each exposing ``asnumpy()``.
                The row for sequence step ``k`` of sample ``i`` is read from
                index ``k * samples_per_pair + i``.

        Raises:
            ZeroDivisionError: if ``batch_size // num_gpu`` is zero.
        """
        check_label_shapes(labels, preds)
        if self.is_logging:
            # Logging helpers are only needed (and only built) when enabled.
            log = LogUtil().getlogger()
            labelUtil = LabelUtil()
            host_name = socket.gethostname()
        self.batch_loss = 0.

        # Number of samples carried by each (label, pred) pair; loop-invariant.
        samples_per_pair = int(self.batch_size) // int(self.num_gpu)

        for label, pred in zip(labels, preds):
            label = label.asnumpy()
            pred = pred.asnumpy()
            seq_length = len(pred) // samples_per_pair

            for i in range(samples_per_pair):
                l = remove_blank(label[i])
                # Greedy decoding: best class index at every sequence step.
                p = pred_best([
                    np.argmax(pred[k * samples_per_pair + i])
                    for k in range(seq_length)
                ])

                l_distance = levenshtein_distance(l, p)
                n_label = len(l)
                self.total_n_label += n_label
                self.total_l_dist += l_distance
                if n_label:
                    this_cer = float(l_distance) / float(n_label)
                else:
                    # An empty reference has no length to normalise by.
                    # Count it as fully wrong if anything was predicted and
                    # as correct otherwise, instead of dividing by zero.
                    this_cer = 1.0 if l_distance else 0.0

                # Only report samples that are noticeably wrong.
                if self.is_logging and this_cer > 0.4:
                    log.info("%s label: %s " %
                             (host_name, labelUtil.convert_num_to_word(l)))
                    log.info(
                        "%s pred : %s , cer: %f (distance: %d/ label length: %d)"
                        % (host_name, labelUtil.convert_num_to_word(p),
                           this_cer, l_distance, n_label))
                self.num_inst += 1
                self.sum_metric += this_cer
        # CTC loss accumulation is disabled; adding zero keeps the attribute
        # access (and therefore its existence check) from the original code.
        self.total_ctc_loss += 0
Ejemplo n.º 2
0
    def update(self, labels, preds):
        """Accumulate error-rate statistics using greedy and beam decoding.

        For every sample the per-step output rows are gathered into one
        array, which is decoded twice: with ``ctc_beam_decode`` (beam size
        20, using ``self.scorer``) and with ``ctc_greedy_decode``. Both
        results are compared against the label's word sequence with
        ``editdistance.eval``.

        Updates ``total_n_label``, ``total_l_dist`` (greedy),
        ``total_l_dist_beam_cpp`` (beam), ``total_ctc_loss`` and
        ``placeholder`` (greedy result followed by all beam results, one per
        line), and resets ``batch_loss`` to ``0.``. Log output is produced
        only when ``self.is_logging`` is true.

        Args:
            labels: sequence of label arrays, each exposing ``asnumpy()``.
                Presumably one entry per device -- confirm against caller.
            preds: sequence of network outputs, each exposing ``asnumpy()``.
                The row for sequence step ``k`` of sample ``i`` is read from
                index ``k * samples_per_pair + i``.

        Raises:
            ZeroDivisionError: if ``batch_size // num_gpu`` is zero.
        """

        def _error_rate(distance, length):
            # An empty reference has no length to normalise by: count it as
            # fully wrong if anything was predicted and as correct otherwise,
            # instead of dividing by zero.
            if length:
                return float(distance) / length
            return 1.0 if distance else 0.0

        check_label_shapes(labels, preds)
        # The label helper is needed for decoding, not just for logging, so
        # it must exist even when logging is disabled.
        labelUtil = LabelUtil()
        if self.is_logging:
            log = LogUtil().getlogger()
            host_name = socket.gethostname()
        self.batch_loss = 0.

        # Number of samples carried by each (label, pred) pair; loop-invariant.
        samples_per_pair = int(self.batch_size) // int(self.num_gpu)
        beam_size = 20

        for label, pred in zip(labels, preds):
            label = label.asnumpy()
            pred = pred.asnumpy()
            seq_length = len(pred) // samples_per_pair

            for i in range(samples_per_pair):
                l = remove_blank(label[i])
                # Collect this sample's output row for every sequence step.
                probs = np.array([
                    pred[k * samples_per_pair + i] for k in range(seq_length)
                ])

                st = time.time()
                results = ctc_beam_decode(self.scorer, beam_size,
                                          labelUtil.byList, probs)
                if self.is_logging:
                    log.info("decode by ctc_beam cost %.2f result: %s" %
                             (time.time() - st, "\n".join(results)))

                res_str1 = ctc_greedy_decode(probs, labelUtil.byList)
                if self.is_logging:
                    log.info("decode by pred_best: %s" % res_str1)

                # Convert the label once and reuse it for both comparisons
                # and for the log line.
                label_text = labelUtil.convert_num_to_word(l)
                label_words = label_text.split(" ")
                # NOTE(review): the reference is a word list while the
                # hypotheses are passed as returned by the decoders --
                # confirm both sides use the same unit.
                l_distance = editdistance.eval(label_words, res_str1)
                l_distance_beam_cpp = editdistance.eval(label_words,
                                                        results[0])

                n_label = len(l)
                self.total_n_label += n_label
                self.total_l_dist_beam_cpp += l_distance_beam_cpp
                self.total_l_dist += l_distance
                this_cer = _error_rate(l_distance, n_label)
                if self.is_logging:
                    log.info("%s label: %s " % (host_name, label_text))
                    log.info(
                        "%s pred : %s , cer: %f (distance: %d/ label length: %d)"
                        % (host_name, res_str1, this_cer, l_distance, n_label))
                    log.info(
                        "%s predc: %s , cer: %f (distance: %d/ label length: %d)"
                        % (host_name, " ".join(results[0]),
                           _error_rate(l_distance_beam_cpp, n_label),
                           l_distance_beam_cpp, n_label))
                self.total_ctc_loss += self.batch_loss
                self.placeholder = res_str1 + "\n" + "\n".join(results)