Example #1
File: utils.py Project: boyshen/ASR
    def shuffle(*args):
        """
        Shuffle.
        :param args: (list, tuple) datasets to be shuffled
        :return:
        """
        # Check that every input argument is a list or tuple
        for arg in args:
            if not isinstance(arg, (list, tuple)):
                raise ParameterError(
                    "args must be a list or tuple, but actually got {}".
                    format(type(arg)))

        # Check that all input arguments have the same length
        if len(args) > 1:
            for args_a, args_b in zip(args[:-1], args[1:]):
                if len(args_a) != len(args_b):
                    raise ParameterError(
                        "args sizes must be equal. expected size: {}, actual size: {}"
                        .format(len(args_a), len(args_b)))

        # Randomly permuted indices
        random_index = np.random.permutation(len(args[0]))

        # Shuffle
        result = list()
        for arg in args:
            arg_result = list()
            for index in random_index:
                arg_result.append(arg[index])
            result.append(arg_result)

        return result
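The method builds a single random index permutation with np.random.permutation and applies it to every argument, so parallel datasets (for example audio paths and their labels) stay aligned. A minimal standalone sketch of the same idea; the helper name parallel_shuffle and the data are just for illustration:

    import numpy as np

    def parallel_shuffle(*seqs):
        # One shared permutation keeps parallel sequences aligned.
        index = np.random.permutation(len(seqs[0]))
        return [[seq[i] for i in index] for seq in seqs]

    audio = ["a.wav", "b.wav", "c.wav"]   # illustrative data only
    labels = [0, 1, 2]
    shuffled_audio, shuffled_labels = parallel_shuffle(audio, labels)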
Example #2
    def get_n_word_freq(self, words):
        """
        Get the n-gram word frequency.
        :param words: (list or tuple, mandatory) must be a list or tuple, and the number of words must equal n
        (n-gram model, n=2 or n=3). By convention the first word in the list is the current word, the second is
        the word immediately before it, and so on. For example, with the three words "A", "B", "C": the current
        word is 'B', and the word before 'B' is "A".
        :return: (int) word frequency
        """
        if not isinstance(words, (list, tuple)):
            raise ParameterError(
                "input words type must be [list, tuple], actually got {}".
                format(type(words)))
        if len(words) != self.__n:
            raise ParameterError(
                "input number of words must equal {}, actually got {}".format(
                    self.__n, len(words)))

        n_word_freq = 0
        if self.__n == TWO_TAG:
            n_word_freq = self.__n_word_freq[self.word_to_token(
                words[0])][self.word_to_token(words[1])]

        elif self.__n == THREE_TAG:
            w_k = self.word_to_token(words[0])
            one_gram_w_k = self.word_to_token(words[1])
            two_gram_w_k = self.word_to_token(words[2])
            n_word_freq = self.__n_word_freq[w_k][one_gram_w_k][two_gram_w_k]

        return n_word_freq
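The frequencies are stored in nested count tables indexed by word tokens: self.__n_word_freq[current][previous] for bigrams, with one more level for trigrams. A small standalone sketch of that layout using plain dictionaries; the table and word pairs below are made up for illustration:

    from collections import defaultdict

    # Hypothetical bigram table keyed as [current word][previous word],
    # following the word-order convention described in the docstring above.
    bigram_freq = defaultdict(lambda: defaultdict(int))
    for prev, cur in [("A", "B"), ("B", "C"), ("A", "B")]:
        bigram_freq[cur][prev] += 1

    print(bigram_freq["B"]["A"])  # 2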
Example #3
    def __check(self, texts, disambiguate_type):
        """
        Check function. Verify that the input parameters meet the requirements.
        :param texts: (list or tuple, mandatory) texts
        :param disambiguate_type: (str, mandatory) disambiguation type
        :return:
        """
        if disambiguate_type not in self.__disambiguate_type:
            raise ParameterError(
                "class:{}, func:{}".format(Disambiguate.__name__,
                                           self.__check.__name__),
                "disambiguate_type must be {}, but actually got {}".format(
                    self.__disambiguate_type, disambiguate_type))

        if not isinstance(texts, (list, tuple)):
            raise ParameterError(
                "class:{}, func:{}".format(Disambiguate.__name__,
                                           self.__check.__name__),
                "texts parameter must be [list, tuple], but actually got {}".
                format(type(texts)))

        if len(texts) == 0:
            raise ParameterError(
                "class:{}, func:{}".format(Disambiguate.__name__,
                                           self.__check.__name__),
                "texts cannot be empty")

        for text in texts:
            if not isinstance(text, (list, tuple)):
                raise ParameterError(
                    "class:{}, func:{}".format(Disambiguate.__name__,
                                               self.__check.__name__),
                    "texts element type must be [list, tuple], but actually got {}"
                    .format(type(text)))
Example #4
    def disambiguate(self, texts, print_prob=False, need_score=False):
        """
        N-gram disambiguation.
        :param texts: (list or tuple, mandatory) texts. Must be a list or tuple: several segmentations of the
        same text. The elements may be (list, tuple, str). If an element is a str, split_labs must be provided,
        e.g. ["hello world", ...] or [["hello", "world"]]
        :param print_prob: (bool, optional, default=False) print probabilities, i.e. each word frequency and the
        n-gram frequency probabilities
        :param need_score: (bool, optional, default=False) whether to return the score
        :return: (str and float). If need_score=True, returns the text and its score; otherwise returns the text only
        """

        if not isinstance(texts, (list, tuple)):
            raise ParameterError(
                "Input texts parameter type must be list or tuple, but actually got {}".format(type(texts)))
        for text in texts:
            if not isinstance(text, (list, tuple)):
                raise ParameterError(
                    "Input texts element type must be [str, list, tuple], but actually got {}".format(type(text)))

        score = self.__get_score(texts, print_prob)
        # Select the text with the highest score
        max_score_arg = np.array(score).argmax()
        if need_score:
            return texts[max_score_arg], score[max_score_arg]

        return texts[max_score_arg]
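After scoring, the candidate with the highest score is picked with np.argmax. A toy illustration of that selection step; the candidate segmentations and scores below are invented:

    import numpy as np

    # Hypothetical candidate segmentations of one sentence with made-up scores.
    candidates = [["研究", "生命"], ["研究生", "命"], ["研", "究", "生命"]]
    scores = [-3.2, -7.9, -11.4]

    best = candidates[int(np.argmax(scores))]
    print(best)  # ['研究', '生命']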
Example #5
    def add_word(self,
                 words,
                 is_save=True,
                 model_file='DictSegmentation.pickle'):
        """
        添加词汇到训练词典中
        :param words: (str or list or tuple, mandatory) 词汇,可以是字符、列表、元祖
        :param is_save: (bool, optional, default=True) 对于新加的词汇,是否保存词典
        :param model_file: (str, optional, default=DictSegmentation.pickle) 需要保存的模型文件名
        :return:
        """
        if not isinstance(words, (str, list, tuple)):
            raise ParameterError(
                "words parameter type must be is [str, list, tuple]")
        if isinstance(words, (list, tuple)):
            if not isinstance(words[0], str):
                raise ParameterError("words elements type must be is str")

        if isinstance(words, str):
            self.word_dictionary.add_words(words)
        elif isinstance(words, (list, tuple)):
            for word in words:
                self.word_dictionary.add_words(word)

        words = self.word_dictionary.get_dictionary()
        self.matching.update_words(words)

        if is_save:
            self.word_dictionary.save(model_file)
Example #6
    def writer_file(file, results, mode='a', encoding='utf-8'):
        """
        Write a string, or a list of strings, to a file.
        :param file: (str, mandatory) file name, or path + file name
        :param results: (str or list or tuple, mandatory) the strings, or single string, to write to the file,
        e.g. "hello world" or ['hello world']
        :param mode: (str, optional, default='a') file mode. Defaults to 'a', append mode
        :param encoding: (str, optional, default='utf-8') encoding. Defaults to UTF-8
        :return:
        """
        if not isinstance(results, (str, list, tuple)):
            raise ParameterError(
                "results parameter must be {}, but actually got {}".format((str, list, tuple), type(results)))
        if isinstance(results, (list, tuple)):
            for result in results:
                if not isinstance(result, str):
                    raise ParameterError(
                        "results parameter elements must be str, but actually got {}, element: {}".format(
                            type(result), result))

        Writer.check_path(file)

        with open(file, mode, encoding=encoding) as f_write:
            if isinstance(results, str):
                f_write.write(results + '\n')
            else:
                for result in results:
                    f_write.write(result + '\n')

        print("\n" + "Done! File: {}, encoding: {}".format(Color.red(file), Color.red(encoding)))
Example #7
    def __update(self, text, matching_length, matching_type):
        """
        Iteratively update the text and the matching length inside the matching algorithm.
        :param text: (str, mandatory) current text
        :param matching_length: (int, mandatory) current matching length
        :param matching_type: (str, mandatory) type of matching algorithm, used to decide how to slice the new text
        :return: (str and int) the new text and matching length
        """
        if matching_type not in self.__matching_type:
            raise ParameterError("matching_type must be one of {}".format(
                self.__matching_type))

        new_text = None
        # Slice off the matched part; the remaining text is matched again
        if matching_type == self.__forward_type:
            new_text = text[matching_length:]

        elif matching_type == self.__reverse_type:
            new_text = text[:-matching_length]

        # Strip any surrounding whitespace from the new text
        new_text = new_text.strip()

        # Update the maximum matching length:
        # if the remaining text is shorter than the default setting, use its length,
        # otherwise use the maximum matching length
        new_matching_length = len(new_text) if len(
            new_text) < self.__max_matching else self.__max_matching

        return new_text, new_matching_length
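In the forward case this simply drops the matched prefix and shrinks the window to whatever text remains, capped by the maximum. A standalone sketch of that update, with an assumed maximum window of 5 characters (the name update_forward is hypothetical):

    MAX_MATCHING = 5  # assumed maximum matching window

    def update_forward(text, matched_length):
        remaining = text[matched_length:].strip()
        window = min(len(remaining), MAX_MATCHING)
        return remaining, window

    print(update_forward("中国人民银行", 2))  # ('人民银行', 4)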
Example #8
    def __init__(self,
                 f_type=SPECTROGRAM,
                 frame_length=256,
                 frame_shift=128,
                 mfcc_dim=13):
        super(AudioFeatures, self).__init__()
        # Frame length. Each time step contains this many data frames
        self.frame_length = frame_length

        # Frame shift. How many data frames the window moves per time step
        self.frame_shift = frame_shift

        # Spectrogram feature dimension
        self.spectrogram_dim = self.frame_length // 2

        assert f_type in AudioFeatures.feature_type, \
            ParameterError("{} not in {}".format(f_type, AudioFeatures.feature_type))
        self.f_type = f_type

        # MFCC feature dimension
        self.mfcc_dim = mfcc_dim

        if self.f_type == AudioFeatures.SPECTROGRAM:
            self.mean = np.zeros(self.spectrogram_dim)
            self.std = np.ones(self.spectrogram_dim)
        elif self.f_type == AudioFeatures.MFCC:
            self.mean = np.zeros(self.mfcc_dim)
            self.std = np.ones(self.mfcc_dim)

        # Maximum sequence length seen in the fitted corpus
        self.max_length = 0
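The mean and std buffers are sized to the feature dimension (frame_length // 2 for spectrograms, mfcc_dim for MFCC) and are presumably used later to normalize feature matrices. A rough sketch of such a normalization step, using random values as a stand-in for a real spectrogram; this is an assumption about usage, not the class's actual fitting code:

    import numpy as np

    frame_length = 256
    spectrogram_dim = frame_length // 2

    features = np.random.rand(100, spectrogram_dim)   # fake [time_steps, feature_dim] spectrogram
    mean = features.mean(axis=0)
    std = features.std(axis=0) + 1e-8                 # small epsilon to avoid division by zero
    normalized = (features - mean) / std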
Example #9
    def decoding(text, label):
        """
        Decode. Convert the character list ['小','明','是','中','国','人'] and the labels ['B','E','S','B','M','E']
        into ["小明","是","中国人"]
        :param text: (list, mandatory) character list
        :param label: (list, mandatory) labels for the character list; text and label must have the same length
        :return: (list) decoded word list
        """

        if len(text) != len(label):
            raise ParameterError(
                "Parameter text length {} and label length {} must be equal".format(
                    len(text), len(label)))

        words = list()
        string = ''
        for word, lab in zip(text, label):
            if lab == HmmDictionary.STATE_S:
                words.append(word)
            elif lab == HmmDictionary.STATE_B and string == "":
                string = word
            elif lab == HmmDictionary.STATE_M and string != "":
                string += word
            elif lab == HmmDictionary.STATE_E and string != "":
                string += word
                words.append(string)
                string = ""

        return words
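The loop walks the characters and their BMES tags in parallel, emitting a word on every S tag and on every E tag that closes an open B/M run. The same logic as a standalone function with plain string tags (the method above reads its tags from HmmDictionary constants):

    def bmes_decode(chars, tags):
        words, buf = [], ""
        for ch, tag in zip(chars, tags):
            if tag == "S":                 # single-character word
                words.append(ch)
            elif tag == "B":               # begin a new word
                buf = ch
            elif tag == "M" and buf:       # middle of the current word
                buf += ch
            elif tag == "E" and buf:       # end of the current word
                words.append(buf + ch)
                buf = ""
        return words

    print(bmes_decode(list("小明是中国人"), ["B", "E", "S", "B", "M", "E"]))
    # ['小明', '是', '中国人']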
Example #10
    def get_state_transition_prob(self, state, is_last_state=False):
        """
        Get the state transition probabilities: the probability of moving from the current state to each next state.
        :param state: (str, mandatory) current state value
        :param is_last_state: (bool, optional, default=False) whether this is the last state value. If True,
        return the probability of transitioning from the current state to (End)
        :return: (dict) next states and their probabilities, e.g. {'B': 0.12, 'E': 0.28, 'S': 0.6, 'M': 0.0}
        """
        if state == self.END_TAG:
            raise ParameterError(
                "Current state value cannot be {}".format(state),
                level=ParameterError.warning)

        result = dict()

        if is_last_state:
            value = self.__state_transition_matrix[self.state_token[state]][
                self.state_token[self.END_TAG]]
            if value != 0.0:
                result[self.END_TAG] = value
        else:
            state_seq = HmmDictionary.STATE_SEQUENCE
            for next_state in state_seq:
                value = self.__state_transition_matrix[
                    self.state_token[state]][self.state_token[next_state]]
                if value == 0.0:
                    continue
                result[next_state] = value

        return result
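The lookup indexes a row of the transition matrix by the current state's token and keeps only the non-zero columns. A toy, self-contained version of that lookup; the states, tokens and probabilities below are invented for illustration:

    import numpy as np

    states = ["B", "M", "E", "S", "End"]
    token = {s: i for i, s in enumerate(states)}
    # Made-up transition matrix: row = current state, column = next state.
    matrix = np.array([
        [0.0, 0.4, 0.6, 0.0, 0.0],   # from B
        [0.0, 0.3, 0.7, 0.0, 0.0],   # from M
        [0.5, 0.0, 0.0, 0.4, 0.1],   # from E
        [0.6, 0.0, 0.0, 0.3, 0.1],   # from S
        [0.0, 0.0, 0.0, 0.0, 0.0],   # from End
    ])

    # Non-zero transitions out of state "E", mirroring the method above.
    result = {s: float(matrix[token["E"], token[s]])
              for s in ["B", "M", "E", "S"] if matrix[token["E"], token[s]] != 0.0}
    print(result)  # {'B': 0.5, 'S': 0.4}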
Example #11
    def decoding(sent, label):
        """
        Decode. Convert text or a sentence segmented with "BMES" tags into a word list.
        Example:
            input: ['小','明','是','中','国','人'] and ['B','E','S','B','M','E']
            output: ["小明","是","中国人"]
        :param sent: (list, mandatory) character list
        :param label: (list, mandatory) label list
        :return: (list) word list
        """
        if len(sent) != len(label):
            raise ParameterError(
                "sent length:{} and label length:{} must be equal".format(
                    len(sent), len(label)))

        words = []
        string = ""
        for word, lab in zip(sent, label):
            if lab == CRFSegmentation.S_TAG:
                words.append(word)
            elif lab == CRFSegmentation.B_TAG and string == "":
                string += word
            elif lab == CRFSegmentation.M_TAG and string != "":
                string += word
            elif lab == CRFSegmentation.E_TAG and string != "":
                string += word
                words.append(string)
                string = ""
        return words
Example #12
    def __matching(self, text, matching_type):
        """
        Matching algorithm.
        :param text: (str, mandatory) text to match
        :param matching_type: (str, mandatory) type, either forward or reverse
        :return: (tuple) tuple of segmented words
        """
        # Check that the input parameters are valid
        if matching_type not in self.__matching_type:
            raise ParameterError("matching_type must be one of {}".format(
                self.__matching_type))

        # Check that the input text is not empty; raise an exception if it is
        if len(text) == 0 or text == '' or text == " ":
            raise NullCharacterException("input text cannot be empty")

        new_text = text
        new_matching_length = len(
            text) if len(text) < self.__max_matching else self.__max_matching

        words = list()

        while True:
            # Stop when the remaining text is empty or only whitespace
            if new_text == '' or new_text == " " or len(new_text) == 0:
                break

            # Take a candidate word according to the current maximum matching length
            word = ''
            if matching_type == self.__forward_type:
                word = new_text[:new_matching_length]

            elif matching_type == self.__reverse_type:
                word = new_text[-new_matching_length:]

            # Word matching: check whether the candidate is in the dictionary
            if word in self.__dictionary:
                # Add the matched word to the result
                words.append(word)
                # Update the text and the matching length
                new_text, new_matching_length = self.__update(
                    new_text, new_matching_length, matching_type)

            # If only a single character is left, treat it as a word
            elif len(word) == 1:
                words.append(word)
                new_text, new_matching_length = self.__update(
                    new_text, new_matching_length, matching_type)

            # Otherwise shorten the matching length
            else:
                new_matching_length = new_matching_length - 1

        return tuple(words)
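A standalone sketch of forward maximum matching with a toy dictionary, to show the shrink-the-window loop in isolation (the class above also supports the reverse direction; the helper name and dictionary are hypothetical):

    def forward_max_matching(text, dictionary, max_len=5):
        words = []
        while text:
            window = min(len(text), max_len)
            # Shrink the window until the candidate is in the dictionary or one character remains.
            while window > 1 and text[:window] not in dictionary:
                window -= 1
            words.append(text[:window])
            text = text[window:].strip()
        return tuple(words)

    print(forward_max_matching("小明是中国人", {"小明", "中国人", "是"}))
    # ('小明', '是', '中国人')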
Example #13
    def append(self,
               mid,
               state,
               word=None,
               launch_prob=None,
               transition_prob=None,
               prob_product=None):
        """添加节点"""
        if mid in self.markov_link.keys():
            raise ParameterError("mid={} already exists !".format(mid))

        self.markov_link[mid] = Markov.StateNode(state, word, launch_prob,
                                                 transition_prob, prob_product)
Example #14
    def get_word_error_rate(self, y_true, y_pred, label_length, black_index=0):
        """
        Word error rate. Compute the word error rate of a single sample.
        :param y_true: (tensor, mandatory) ground-truth labels of the sample. shape: [label, ]
        :param y_pred: (tensor, mandatory) predicted labels of the sample. shape: [pred, ]
        :param label_length: (int, mandatory) label length
        :param black_index: (int, optional, default=0) blank token, used for padding
        :return:
        """
        assert len(y_true) >= label_length, \
            ParameterError("The actual label sequence length {} is less than {}".format(len(y_true), label_length))

        y_true = tf.convert_to_tensor(y_true, dtype=tf.int32)
        y_pred = tf.convert_to_tensor(y_pred, dtype=tf.int32)

        # Keep only the label sequence length
        y_true = y_true[:label_length]

        # If the predicted sequence length is 0, set the error rate to
        # len(y_true) / len(y_true), i.e. an error rate of 100%
        if len(y_pred) == 0:
            error_word_num = len(y_true)
        else:
            # Truncate: if the predicted labels are longer than the ground truth,
            # truncate them and count the surplus predictions as errors
            if len(y_pred) > len(y_true):
                new_y_p = y_pred[:label_length]
                error_word_num = len(y_pred[label_length:])
            # Pad: if the predicted labels are shorter than the ground truth, pad them with the blank token
            elif len(y_pred) < len(y_true):
                black_count = len(y_true) - len(y_pred)
                new_y_p = tf.concat([y_pred, [black_index] * black_count],
                                    axis=0)
                error_word_num = 0
            else:
                new_y_p = y_pred
                error_word_num = 0

            # Count the number of correct words
            true_count = tf.math.reduce_sum(
                tf.cast(tf.equal(tf.cast(new_y_p, dtype=tf.int32),
                                 tf.cast(y_true, dtype=tf.int32)),
                        dtype=tf.int32))
            error_word_num += len(new_y_p) - true_count

        value = tf.squeeze(
            tf.cast(error_word_num / len(y_true), dtype=tf.float32))

        self.word_error_rate.assign_add(value)
        self.count += 1
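Stripped of the TensorFlow bookkeeping, the rule is: surplus predictions count as errors, a too-short prediction is padded with the blank token, and the sequences are then compared position by position. A NumPy sketch of that calculation; the function name, the blank_index parameter and the data at the end are invented:

    import numpy as np

    def word_error_rate(y_true, y_pred, blank_index=0):
        y_true = np.asarray(y_true)
        if len(y_pred) == 0:
            return 1.0
        extra = max(len(y_pred) - len(y_true), 0)     # surplus predictions count as errors
        padded = list(y_pred[:len(y_true)]) + [blank_index] * max(len(y_true) - len(y_pred), 0)
        errors = extra + int(np.sum(np.asarray(padded) != y_true))
        return errors / len(y_true)

    print(word_error_rate([3, 5, 7, 9], [3, 5, 8]))   # 0.5: one wrong word plus one padded blank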
Example #15
    def __init__(self, n=2):
        super(NgramDictionary, self).__init__()
        if n not in N_GRAM:
            raise ParameterError(
                "parameter n must be within {}".format(N_GRAM))
        self.__n = n

        # Vocabulary: every word that appears in the text
        self.__words = [self.UNK_TAG]
        # Dictionary: every word that appears in the text, each assigned a token.
        # The token can be seen as the unique identifier of the word.
        self.__word_token = {self.UNK_TAG: self.UNK}
        # Number of distinct words in the text (duplicates not counted)
        self.__word_num = len(self.__words)
        # Total number of words in the text, including duplicates
        self.__total_words = len(self.__words)
        # Word frequencies: frequency of each word in the text
        self.__word_freq = Counter()
        # n-gram word frequencies
        self.__n_word_freq = None
Example #16
File: utils.py Project: boyshen/ASR
    def split_valid_dataset(*args, ratio=0.2):
        """
        Split off a validation dataset.
        :param args: (list, mandatory) datasets, in list form
        :param ratio: (float, optional, default=0.2) validation set ratio, in the range 0~1
        :return: (list) training dataset and validation dataset
        """
        if ratio > 1 or ratio < 0:
            raise ParameterError(
                "dataset ratio must be in the 0 ~ 1 range. actually got: {}".
                format(ratio))

        dataset = Generator.shuffle(*args)

        sample_num = int(len(dataset[0]) * ratio)
        if sample_num == 0:
            sample_num = 1

        train_dataset, valid_dataset = list(), list()
        for data in dataset:
            train_dataset.append(data[sample_num:])
            valid_dataset.append(data[:sample_num])

        return train_dataset, valid_dataset
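The split shuffles all datasets via Generator.shuffle (Example #1) and slices off the first ratio share as validation data, keeping at least one sample. A standalone sketch of the same shuffle-then-slice split; the helper name split_valid and the toy data are illustrative only:

    import numpy as np

    def split_valid(features, labels, ratio=0.2):
        index = np.random.permutation(len(features))
        n_valid = max(int(len(features) * ratio), 1)   # keep at least one validation sample
        valid_idx, train_idx = index[:n_valid], index[n_valid:]
        train = ([features[i] for i in train_idx], [labels[i] for i in train_idx])
        valid = ([features[i] for i in valid_idx], [labels[i] for i in valid_idx])
        return train, valid

    train, valid = split_valid(list(range(10)), list("abcdefghij"))
    print(len(train[0]), len(valid[0]))   # 8 2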
Example #17
    def get(self, mid):
        """获取节点对象"""
        if mid not in self.markov_link.keys():
            raise ParameterError("mid={} not found !".format(mid))

        return self.markov_link[mid]
Example #18
    def delete(self, mid):
        """删除节点"""
        if mid not in self.markov_link.keys():
            raise ParameterError("mid={} not found !".format(mid))

        self.markov_link.pop(mid)