Example #1
File: module.py Project: DengBoCong/hlp
def evaluate(model: tf.keras.Model,
             data_path: str,
             batch_size: int,
             buffer_size: int,
             dict_path: str = "",
             length_path: str = "",
             max_data_size: int = 0):
    """
    评估模块
    :param model: 模型
    :param data_path: 文本数据路径
    :param buffer_size: Dataset加载缓存大小
    :param batch_size: Dataset加载批大小
    :param dict_path: 字典路径,若使用phoneme则不用传
    :param max_data_size: 最大训练数据量
    :param length_path: 训练样本长度保存路径
    :return: 返回历史指标数据
    """
    valid_dataset, _, valid_steps_per_epoch, _ = \
        load_data(train_data_path=data_path, batch_size=batch_size, buffer_size=buffer_size,
                  valid_data_split=0.0, valid_data_path="", train_length_path=length_path,
                  valid_length_path="", max_train_data_size=max_data_size, max_valid_data_size=0)

    tokenizer = load_tokenizer(dict_path=dict_path)

    _, _, _ = _valid_step(model=model,
                          dataset=valid_dataset,
                          steps_per_epoch=valid_steps_per_epoch,
                          tokenizer=tokenizer)
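
For reference, a minimal sketch of the tokenizer object these examples revolve around, assuming load_tokenizer returns a standard Keras Tokenizer restored from dict_path (the sample text below is made up):

import tensorflow as tf

# Assumption: load_tokenizer yields a tf.keras.preprocessing.text.Tokenizer-like object.
tokenizer = tf.keras.preprocessing.text.Tokenizer(oov_token="<unk>", filters='')
tokenizer.fit_on_texts(["<start> ni hao <end>"])
print(tokenizer.word_index)                       # token -> id mapping, used via word_index.get(...)
print(tokenizer.texts_to_sequences(["ni hao"]))   # text -> id sequences fed to the model
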
Example #2
File: module.py Project: DengBoCong/hlp
def evaluate(model: tf.keras.Model,
             data_path: str,
             batch_size: int,
             buffer_size: int,
             dict_path: str = "",
             length_path: str = "",
             max_data_size: int = 0):
    """
    评估模块
    :param model: 模型
    :param data_path: 文本数据路径
    :param buffer_size: Dataset加载缓存大小
    :param batch_size: Dataset加载批大小
    :param dict_path: 字典路径,若使用phoneme则不用传
    :param max_data_size: 最大训练数据量
    :param length_path: 训练样本长度保存路径
    :return: 返回历史指标数据
    """
    valid_dataset, _, valid_steps_per_epoch, _ = \
        load_data(train_data_path=data_path, batch_size=batch_size, buffer_size=buffer_size,
                  valid_data_split=0.0, valid_data_path="", train_length_path=length_path,
                  valid_length_path="", max_train_data_size=max_data_size, max_valid_data_size=0)

    tokenizer = load_tokenizer(dict_path=dict_path)
    enc_hidden = model.initialize_hidden_state()
    dec_input = tf.cast(tf.expand_dims([tokenizer.word_index.get('<start>')] *
                                       batch_size, 1),
                        dtype=tf.int64)

    _, _, _ = _valid_step(model=model,
                          dataset=valid_dataset,
                          steps_per_epoch=valid_steps_per_epoch,
                          tokenizer=tokenizer,
                          enc_hidden=enc_hidden,
                          dec_input=dec_input)
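
The only difference from Example #1 is the explicit encoder state and decoder start input. A small sketch of the start-token batch built above (the start id and batch size are illustrative):

import tensorflow as tf

batch_size, start_id = 4, 2                       # illustrative values
dec_input = tf.cast(tf.expand_dims([start_id] * batch_size, 1), dtype=tf.int64)
print(dec_input.shape)                            # (4, 1): one <start> id per batch element
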
Example #3
File: chatter.py Project: DengBoCong/hlp
    def respond(self, req: str):
        """
        对外部聊天请求进行回复
        子类需要利用模型进行推断和搜索以产生回复。
        :param req: 输入的语句
        :return: 系统回复字符串
        """
        # 对req进行初步处理
        tokenizer = load_tokenizer(self.dict_fn)
        inputs, dec_input = data_utils.preprocess_request(sentence=req, tokenizer=tokenizer,
                                                          max_length=self.max_length, start_sign=self.start_sign)
        self.beam_search_container.reset(inputs=inputs, dec_input=dec_input)
        inputs, dec_input = self.beam_search_container.get_search_inputs()

        for t in range(self.max_length):
            predictions = self._create_predictions(inputs, dec_input, t)
            self.beam_search_container.expand(
                predictions=predictions,
                end_sign=tokenizer.word_index.get(self.end_sign))
            # Note: if the container's beam_size has dropped to 0, the required number of
            # results has already been found, so break out of the loop immediately
            if self.beam_search_container.beam_size == 0:
                break

            inputs, dec_input = self.beam_search_container.get_search_inputs()

        beam_search_result = self.beam_search_container.get_result(top_k=3)
        result = ''
        # Extract the sequences from the container and build the final result
        for i in range(len(beam_search_result)):
            temp = beam_search_result[i].numpy()
            text = tokenizer.sequences_to_texts(temp)
            text[0] = text[0].replace(self.start_sign, '').replace(self.end_sign, '').replace(' ', '')
            result = '<' + text[0] + '>' + result
        return result
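
A minimal sketch of the id-to-text post-processing used to assemble the reply, assuming a standard Keras Tokenizer and made-up start/end signs:

import tensorflow as tf

tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')
tokenizer.fit_on_texts(["<start> hello there <end>"])
sequence = tokenizer.texts_to_sequences(["<start> hello there <end>"])
text = tokenizer.sequences_to_texts(sequence)[0]
text = text.replace("<start>", "").replace("<end>", "").replace(" ", "")
print('<' + text + '>')                           # "<hellothere>"
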
Example #4
File: module.py Project: DengBoCong/hlp
def recognize(model: tf.keras.Model, audio_feature_type: str, start_sign: str,
              unk_sign: str, end_sign: str, record_path: str, max_length: int,
              dict_path: str):
    """
    语音识别模块
    :param model: 模型
    :param audio_feature_type: 特征类型
    :param start_sign: 开始标记
    :param end_sign: 结束标记
    :param unk_sign: 未登录词
    :param record_path: 录音保存路径
    :param max_length: 最大音频补齐长度
    :param dict_path: 字典保存路径
    :return: 无返回值
    """
    while True:
        try:
            record_duration = int(input("Set the recording duration in seconds (negative to exit, 0 to enter an audio file path instead): "))
        except BaseException:
            print("The recording duration must be an integer")
        else:
            if record_duration < 0:
                break
            if not os.path.exists(record_path):
                os.makedirs(record_path)
            # Record audio (use a separate variable so record_path keeps pointing at the directory)
            if record_duration == 0:
                record_file = input("Please enter the audio file path: ")
            else:
                record_file = record_path + time.strftime(
                    "%Y_%m_%d_%H_%M_%S_", time.localtime(time.time())) + ".wav"
                record(record_file, record_duration)

            # Load the recording and run prediction
            audio_feature = wav_to_feature(record_file, audio_feature_type)
            audio_feature = audio_feature[:max_length, :]
            input_tensor = tf.keras.preprocessing.sequence.pad_sequences(
                [audio_feature],
                padding='post',
                maxlen=max_length,
                dtype='float32')
            predictions = model(input_tensor)
            ctc_input_length = compute_ctc_input_length(
                input_tensor.shape[1], predictions.shape[1],
                tf.convert_to_tensor([[len(audio_feature)]]))

            output = tf.keras.backend.ctc_decode(
                y_pred=predictions,
                input_length=tf.reshape(ctc_input_length,
                                        [ctc_input_length.shape[0]]),
                greedy=True)

            tokenizer = load_tokenizer(dict_path=dict_path)

            sentence = tokenizer.sequences_to_texts(output[0][0].numpy())
            sentence = sentence[0].replace(start_sign,
                                           '').replace(end_sign,
                                                       '').replace(' ', '')
            print("Output:", sentence)
Example #5
    def respond(self, req: str):
        """
        对外部聊天请求进行回复
        子类需要利用模型进行推断和搜索以产生回复。
        :param req: 输入的语句
        :return: 系统回复字符串
        """
        self.solr.ping()
        history = req[-self.max_utterance:]
        pad_sequences = [0] * self.max_sentence
        tokenizer = load_tokenizer(self.dict_fn)
        utterance = tokenizer.texts_to_sequences(history)
        utterance_len = len(utterance)

        # If the current turn has fewer than max_utterance history sentences, pad at the tail
        if utterance_len != self.max_utterance:
            utterance += [pad_sequences] * (self.max_utterance - utterance_len)
        utterance = tf.keras.preprocessing.sequence.pad_sequences(
            utterance, maxlen=self.max_sentence, padding="post").tolist()

        tf_idf = data_utils.get_tf_idf_top_k(history)
        query = "{!func}sum("
        for key in tf_idf:
            query += "product(idf(utterance," + key + "),tf(utterance," + key + ")),"
        query += ")"
        candidates = self.solr.search(q=query, start=0, rows=10).docs
        candidates = [candidate['utterance'][0] for candidate in candidates]

        if not candidates:
            return "Sorry! I didn't hear clearly, can you say it again?"
        else:
            utterances = [utterance] * len(candidates)
            responses = tokenizer.texts_to_sequences(candidates)
            responses = tf.keras.preprocessing.sequence.pad_sequences(
                responses, maxlen=self.max_sentence, padding="post")
            utterances = tf.convert_to_tensor(utterances)
            responses = tf.convert_to_tensor(responses)
            scores = self.model(inputs=[utterances, responses])
            index = tf.argmax(scores[:, 0])

            return candidates[index]
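
A small sketch of the utterance padding performed above, with illustrative sizes in place of self.max_utterance and self.max_sentence:

import tensorflow as tf

max_utterance, max_sentence = 4, 6                # illustrative values
utterance = [[3, 1, 4], [1, 5]]                   # two tokenized history sentences
utterance += [[0] * max_sentence] * (max_utterance - len(utterance))  # pad missing turns at the tail
utterance = tf.keras.preprocessing.sequence.pad_sequences(
    utterance, maxlen=max_sentence, padding="post").tolist()
print(utterance)                                  # 4 rows, each of length 6
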
Example #6
    def evaluate(self,
                 valid_fn: str,
                 dict_fn: str = "",
                 max_turn_utterances_num: int = 10,
                 max_valid_data_size: int = 0):
        """
        验证功能,注意了dict_fn和tokenizer两个比传其中一个
        :param valid_fn: 验证数据集路径
        :param dict_fn: 字典路径
        :param max_turn_utterances_num: 最大训练数据量
        :param max_valid_data_size: 最大验证数据量
        :return: r2_1, r10_1指标
        """
        # Process and load the evaluation data. Note: when max_valid_data_size is 0,
        # loading is skipped entirely, i.e. the model is trained without evaluation.
        if max_valid_data_size == 0:
            return None
        step = max_valid_data_size // max_turn_utterances_num
        tokenizer = load_tokenizer(dict_path=dict_fn)
        valid_dataset = data_utils.load_smn_valid_data(
            data_fn=valid_fn,
            max_sentence=self.max_sentence,
            max_utterance=self.max_utterance,
            tokenizer=tokenizer,
            max_turn_utterances_num=max_turn_utterances_num,
            max_valid_data_size=max_valid_data_size)

        scores = tf.constant([], dtype=tf.float32)
        labels = tf.constant([], dtype=tf.int32)
        for (batch, (utterances, response,
                     label)) in enumerate(valid_dataset.take(step)):
            score = self.model(inputs=[utterances, response])
            score = tf.nn.softmax(score, axis=-1)
            labels = tf.concat([labels, label], axis=0)
            scores = tf.concat([scores, score[:, 1]], axis=0)

        r10_1 = self._metrics_rn_1(scores, labels, num=10)
        r2_1 = self._metrics_rn_1(scores, labels, num=2)
        return r2_1, r10_1
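
The project's _metrics_rn_1 is not shown here; below is a minimal sketch of how an Rn@1 recall metric is commonly computed for this setup, assuming the validation data comes in blocks of num candidates with exactly one positive label per block (the actual implementation may differ):

import tensorflow as tf

def rn_at_1(scores: tf.Tensor, labels: tf.Tensor, num: int) -> float:
    scores = tf.reshape(scores, (-1, num))        # (blocks, num) match scores
    labels = tf.cast(tf.reshape(labels, (-1, num)), tf.float32)
    best = tf.argmax(scores, axis=1)              # highest-scoring candidate in each block
    hit = tf.reduce_sum(labels * tf.one_hot(best, num), axis=1)
    return float(tf.reduce_mean(hit))             # fraction of blocks where the positive wins

print(rn_at_1(tf.constant([0.9, 0.1, 0.7, 0.3]), tf.constant([1, 0, 1, 0]), num=2))  # 1.0
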
Example #7
def recognize(encoder: tf.keras.Model, decoder: tf.keras.Model, beam_size: int,
              start_sign: str, unk_sign: str, end_sign: str,
              audio_feature_type: str, max_length: int,
              max_sentence_length: int, dict_path: str):
    """
    语音识别模块
    :param encoder: 模型的encoder
    :param decoder: 模型的decoder
    :param beam_size: beam_size
    :param start_sign: 开始标记
    :param end_sign: 结束标记
    :param unk_sign: 未登录词
    :param audio_feature_type: 特征类型
    :param max_length: 最大音频补齐长度
    :param max_sentence_length: 最大音频补齐长度
    :param dict_path: 字典保存路径
    """
    beam_search_container = BeamSearch(beam_size=beam_size,
                                       max_length=max_sentence_length,
                                       worst_score=0)

    print("Agent: 你好!结束识别请输入ESC。")
    while True:
        path = input("Path: ")
        if path == "ESC":
            print("Agent: 再见!")
            exit(0)

        if not os.path.exists(path):
            print("音频文件不存在,请重新输入")
            continue

        audio_feature = wav_to_feature(path, audio_feature_type)
        audio_feature = tf.keras.preprocessing.sequence.pad_sequences(
            [audio_feature],
            maxlen=max_length,
            dtype="float32",
            padding="post")

        tokenizer = load_tokenizer(dict_path=dict_path)
        dec_input = tf.expand_dims(
            [tokenizer.word_index.get(start_sign, tokenizer.word_index.get(unk_sign))], 0)

        beam_search_container.reset(inputs=audio_feature, dec_input=dec_input)

        for i in range(max_sentence_length):
            enc_outputs, padding_mask = encoder(audio_feature)
            sentence_predictions = decoder(
                inputs=[dec_input, enc_outputs, padding_mask])
            sentence_predictions = tf.nn.softmax(sentence_predictions)
            sentence_predictions = sentence_predictions[:, -1, :]

            beam_search_container.expand(
                predictions=sentence_predictions,
                end_sign=tokenizer.word_index.get(end_sign))
            if beam_search_container.beam_size == 0:
                break

            audio_feature, dec_input = beam_search_container.get_search_inputs()

        beam_search_result = beam_search_container.get_result(top_k=3)
        result = ''
        # Extract the sequences from the container and build the final result
        for i in range(len(beam_search_result)):
            temp = beam_search_result[i].numpy()
            text = tokenizer.sequences_to_texts(temp)[0]
            text = text.replace(start_sign, '').replace(end_sign,
                                                        '').replace(' ', '')
            result = '<' + text + '>' + result

        print("识别句子为:{}".format(result))

    print("识别结束")
Example #8
File: module.py Project: DengBoCong/hlp
def recognize(model: tf.keras.Model, audio_feature_type: str, start_sign: str,
              unk_sign: str, end_sign: str, w: int, beam_size: int,
              record_path: str, max_length: int, max_sentence_length: int,
              dict_path: str):
    """
    语音识别模块
    :param model: 模型
    :param audio_feature_type: 特征类型
    :param start_sign: 开始标记
    :param end_sign: 结束标记
    :param unk_sign: 未登录词
    :param w: BiLSTM单元数
    :param beam_size: Beam Size
    :param record_path: 录音保存路径
    :param max_length: 最大音频补齐长度
    :param max_sentence_length: 最大句子长度
    :param dict_path: 字典保存路径
    :return: 无返回值
    """
    tokenizer = load_tokenizer(dict_path=dict_path)
    enc_hidden = tf.zeros((1, w))
    dec_input = tf.expand_dims([tokenizer.word_index.get('<start>')], 1)
    beam_search = BeamSearch(beam_size=beam_size,
                             max_length=max_sentence_length,
                             worst_score=0)

    while True:
        try:
            record_duration = int(input("Set the recording duration in seconds (negative to exit, 0 to enter an audio file path instead): "))
        except BaseException:
            print("The recording duration must be an integer")
        else:
            if record_duration < 0:
                break
            if not os.path.exists(record_path):
                os.makedirs(record_path)
            # Record audio (use a separate variable so record_path keeps pointing at the directory)
            if record_duration == 0:
                record_file = input("Please enter the audio file path: ")
            else:
                record_file = record_path + time.strftime(
                    "%Y_%m_%d_%H_%M_%S_", time.localtime(time.time())) + ".wav"
                record(record_file, record_duration)

            # Load the recording and run prediction
            audio_feature = wav_to_feature(record_file, audio_feature_type)
            audio_feature = audio_feature[:max_length, :]
            input_tensor = tf.keras.preprocessing.sequence.pad_sequences(
                [audio_feature],
                padding='post',
                maxlen=max_length,
                dtype='float32')

            beam_search.reset(inputs=input_tensor, dec_input=dec_input)
            decoder_input = dec_input
            for t in range(1, max_sentence_length):
                decoder_input = decoder_input[:, -1:]
                predictions, _ = model(input_tensor, enc_hidden, decoder_input)
                predictions = tf.nn.softmax(predictions)

                beam_search.expand(predictions=predictions,
                                   end_sign=tokenizer.word_index.get(end_sign))
                if beam_search.beam_size == 0:
                    break

                input_tensor, decoder_input = beam_search.get_search_inputs()

            beam_search_result = beam_search.get_result(top_k=3)
            result = ''
            # Extract the sequences from the container and build the final result
            for i in range(len(beam_search_result)):
                temp = beam_search_result[i].numpy()
                text = tokenizer.sequences_to_texts(temp)[0]
                text = text.replace(start_sign,
                                    '').replace(end_sign, '').replace(' ', '')
                result = '<' + text + '>' + result

            print("识别句子为:{}".format(result))
Example #9
File: module.py Project: DengBoCong/hlp
def train(epochs: int,
          train_data_path: str,
          batch_size: int,
          buffer_size: int,
          checkpoint_save_freq: int,
          checkpoint: tf.train.CheckpointManager,
          model: tf.keras.Model,
          optimizer: tf.keras.optimizers.Adam,
          dict_path: str = "",
          valid_data_split: float = 0.0,
          valid_data_path: str = "",
          train_length_path: str = "",
          valid_length_path: str = "",
          max_train_data_size: int = 0,
          max_valid_data_size: int = 0,
          history_img_path: str = ""):
    """
    训练模块
    :param epochs: 训练周期
    :param train_data_path: 文本数据路径
    :param dict_path: 字典路径,若使用phoneme则不用传
    :param buffer_size: Dataset加载缓存大小
    :param batch_size: Dataset加载批大小
    :param checkpoint: 检查点管理器
    :param model: 模型
    :param optimizer: 优化器
    :param valid_data_split: 用于从训练数据中划分验证数据
    :param valid_data_path: 验证数据文本路径
    :param max_train_data_size: 最大训练数据量
    :param train_length_path: 训练样本长度保存路径
    :param valid_length_path: 验证样本长度保存路径
    :param max_valid_data_size: 最大验证数据量
    :param checkpoint_save_freq: 检查点保存频率
    :param history_img_path: 历史指标数据图表保存路径
    :return:
    """
    train_dataset, valid_dataset, steps_per_epoch, valid_steps_per_epoch = \
        load_data(train_data_path=train_data_path, batch_size=batch_size, buffer_size=buffer_size,
                  valid_data_split=valid_data_split, valid_data_path=valid_data_path,
                  train_length_path=train_length_path, valid_length_path=valid_length_path,
                  max_train_data_size=max_train_data_size, max_valid_data_size=max_valid_data_size)

    tokenizer = load_tokenizer(dict_path=dict_path)
    history = {"loss": [], "wers": [], "norm_lers": []}

    if steps_per_epoch == 0:
        print("训练数据量过小,小于batch_size,请添加数据后重试")
        exit(0)

    for epoch in range(epochs):
        total_loss = 0
        start_time = time.time()
        enc_hidden = model.initialize_hidden_state()
        dec_input = tf.cast(tf.expand_dims(
            [tokenizer.word_index.get('<start>')] * batch_size, 1),
                            dtype=tf.int64)

        print("Epoch {}/{}".format(epoch + 1, epochs))
        for (batch,
             (audio_feature, sentence,
              length)) in enumerate(train_dataset.take(steps_per_epoch)):
            batch_start = time.time()

            batch_loss = _train_step(model, optimizer, audio_feature, sentence,
                                     enc_hidden, dec_input)
            total_loss += batch_loss

            print('\r{}/{} [Batch {} Loss {:.4f} {:.1f}s]'.format(
                (batch + 1), steps_per_epoch, batch + 1, batch_loss.numpy(),
                (time.time() - batch_start)),
                  end="")

        print(' - {:.0f}s/step - loss: {:.4f}'.format(
            (time.time() - start_time) / steps_per_epoch,
            total_loss / steps_per_epoch))

        if (epoch + 1) % checkpoint_save_freq == 0:
            checkpoint.save()

            if valid_steps_per_epoch == 0:
                print("验证数据量过小,小于batch_size,请添加数据后重试")
                exit(0)

            valid_loss, valid_wer, valid_ler = _valid_step(
                model=model,
                dataset=valid_dataset,
                enc_hidden=enc_hidden,
                dec_input=dec_input,
                steps_per_epoch=valid_steps_per_epoch,
                tokenizer=tokenizer)
            history["wers"].append(valid_wer)
            history["norm_lers"].append(valid_ler)

    plot_history(history=history,
                 valid_epoch_freq=checkpoint_save_freq,
                 history_img_path=history_img_path)
    return history
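
A minimal sketch of the periodic checkpointing pattern used above (the model, directory and frequency below are illustrative, not the project's actual configuration):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
checkpoint = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(checkpoint, directory="./checkpoints", max_to_keep=3)

checkpoint_save_freq = 2
for epoch in range(6):
    # ... training steps would run here ...
    if (epoch + 1) % checkpoint_save_freq == 0:
        manager.save()                            # writes ./checkpoints/ckpt-N
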
Example #10
File: module.py Project: DengBoCong/hlp
def train(model: tf.keras.Model,
          optimizer: tf.keras.optimizers.Adam,
          epochs: int,
          checkpoint: tf.train.CheckpointManager,
          train_data_path: str,
          batch_size: int,
          buffer_size: int,
          checkpoint_save_freq: int,
          dict_path: str = "",
          valid_data_split: float = 0.0,
          valid_data_path: str = "",
          train_length_path: str = "",
          valid_length_path: str = "",
          stop_early_limits: int = 0,
          max_train_data_size: int = 0,
          max_valid_data_size: int = 0,
          history_img_path: str = ""):
    """
    训练模块
    :param model: 模型
    :param optimizer: 优化器
    :param checkpoint: 检查点管理器
    :param epochs: 训练周期
    :param train_data_path: 文本数据路径
    :param buffer_size: Dataset加载缓存大小
    :param batch_size: Dataset加载批大小
    :param dict_path: 字典路径,若使用phoneme则不用传
    :param valid_data_split: 用于从训练数据中划分验证数据
    :param valid_data_path: 验证数据文本路径
    :param max_train_data_size: 最大训练数据量
    :param train_length_path: 训练样本长度保存路径
    :param valid_length_path: 验证样本长度保存路径
    :param stop_early_limits: 不增长停止个数
    :param max_valid_data_size: 最大验证数据量
    :param checkpoint_save_freq: 检查点保存频率
    :param history_img_path: 历史指标数据图表保存路径
    :return: 返回历史指标数据
    """
    train_dataset, valid_dataset, steps_per_epoch, valid_steps_per_epoch = \
        load_data(train_data_path=train_data_path, batch_size=batch_size, buffer_size=buffer_size,
                  valid_data_split=valid_data_split, valid_data_path=valid_data_path,
                  train_length_path=train_length_path, valid_length_path=valid_length_path,
                  max_train_data_size=max_train_data_size, max_valid_data_size=max_valid_data_size)

    tokenizer = load_tokenizer(dict_path=dict_path)

    history = {"loss": [], "wers": [], "norm_lers": []}

    if steps_per_epoch == 0:
        print("训练数据量过小,小于batch_size,请添加数据后重试")
        exit(0)

    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch + 1, epochs))
        start_time = time.time()
        total_loss = 0

        for (batch,
             (audio_feature, sentence,
              length)) in enumerate(train_dataset.take(steps_per_epoch)):
            batch_start = time.time()

            batch_loss = _train_step(model, optimizer, sentence, length,
                                     audio_feature)
            total_loss += batch_loss

            print('\r{}/{} [Batch {} Loss {:.4f} {:.1f}s]'.format(
                (batch + 1), steps_per_epoch, batch + 1, batch_loss.numpy(),
                (time.time() - batch_start)),
                  end="")

        print(' - {:.0f}s/step - loss: {:.4f}'.format(
            (time.time() - start_time) / steps_per_epoch,
            total_loss / steps_per_epoch))
        history["loss"].append(total_loss / steps_per_epoch)

        if (epoch + 1) % checkpoint_save_freq == 0:
            checkpoint.save()

            if valid_steps_per_epoch == 0:
                print("验证数据量过小,小于batch_size,请添加数据后重试")
                exit(0)

            valid_loss, valid_wer, valid_ler = _valid_step(
                model=model,
                dataset=valid_dataset,
                steps_per_epoch=valid_steps_per_epoch,
                tokenizer=tokenizer)
            history["wers"].append(valid_wer)
            history["norm_lers"].append(valid_ler)

            if stop_early_limits != 0 and len(
                    history["wers"]) >= stop_early_limits:
                if can_stop(history["wers"][-stop_early_limits:]) \
                        or can_stop(history["norm_lers"][-stop_early_limits:]):
                    print("指标反弹,停止训练!")
                    break
    plot_history(history=history,
                 valid_epoch_freq=checkpoint_save_freq,
                 history_img_path=history_img_path)
    return history
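
The can_stop helper is not shown; below is a minimal sketch of one plausible reading, assuming it checks that a lower-is-better metric (e.g. WER) has stopped improving over the last stop_early_limits values (the project's actual rule may differ):

def can_stop(metrics: list) -> bool:
    # Stop when nothing in the window improves on (is lower than) its first value.
    return all(m >= metrics[0] for m in metrics[1:])

print(can_stop([0.30, 0.31, 0.32]))               # True: WER is rebounding
print(can_stop([0.30, 0.28, 0.27]))               # False: still improving
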
Example #11
def load_data(dict_fn: str,
              data_fn: str,
              buffer_size: int,
              batch_size: int,
              checkpoint_dir: str,
              max_length: int,
              valid_data_split: float = 0.0,
              valid_data_fn: str = "",
              max_train_data_size: int = 0,
              max_valid_data_size: int = 0):
    """
    数据加载方法,含四个元素的元组,包括如下:
    :param dict_fn: 字典路径
    :param data_fn: 文本数据路径
    :param buffer_size: Dataset加载缓存大小
    :param batch_size: Dataset加载批大小
    :param checkpoint_dir: 检查点保存路径
    :param max_length: 单个句子最大长度
    :param valid_data_split: 用于从训练数据中划分验证数据
    :param valid_data_fn: 验证数据文本路径
    :param max_train_data_size: 最大训练数据量
    :param max_valid_data_size: 最大验证数据量
    :return: 训练Dataset、验证Dataset、训练数据总共的步数、验证数据总共的步数和检查点前缀
    """
    print("读取训练对话对...")
    tokenizer = load_tokenizer(dict_path=dict_fn)
    train_input, train_target, sample_weights = \
        _read_data(data_path=data_fn, num_examples=max_train_data_size, max_length=max_length, tokenizer=tokenizer)

    valid_flag = True  # whether validation is enabled
    valid_steps_per_epoch = 0

    if valid_data_fn != "":
        print("读取验证对话对...")
        valid_input, valid_target, _ = _read_data(
            data_path=valid_data_fn,
            num_examples=max_valid_data_size,
            max_length=max_length,
            tokenizer=tokenizer)
    elif valid_data_split != 0.0:
        train_size = int(len(train_input) * (1.0 - valid_data_split))
        valid_input = train_input[train_size:]
        valid_target = train_target[train_size:]
        train_input = train_input[:train_size]
        train_target = train_target[:train_size]
        sample_weights = sample_weights[:train_size]
    else:
        valid_flag = False

    train_dataset = tf.data.Dataset.from_tensor_slices(
        (train_input, train_target,
         sample_weights)).cache().shuffle(buffer_size).prefetch(
             tf.data.experimental.AUTOTUNE)
    train_dataset = train_dataset.batch(batch_size, drop_remainder=True)

    if valid_flag:
        valid_dataset = tf.data.Dataset.from_tensor_slices(
            (valid_input, valid_target)).cache().shuffle(buffer_size).prefetch(
                tf.data.experimental.AUTOTUNE)
        valid_dataset = valid_dataset.batch(batch_size, drop_remainder=True)
        valid_steps_per_epoch = len(valid_input) // batch_size
    else:
        valid_dataset = None

    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    steps_per_epoch = len(train_input) // batch_size

    return train_dataset, valid_dataset, steps_per_epoch, valid_steps_per_epoch, checkpoint_prefix
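
A self-contained sketch of the Dataset pipeline built above, with toy tensors in place of the tokenized inputs, targets and sample weights:

import tensorflow as tf

inputs = tf.random.uniform((10, 6), maxval=100, dtype=tf.int32)
targets = tf.random.uniform((10, 6), maxval=100, dtype=tf.int32)
weights = tf.ones((10,))

dataset = tf.data.Dataset.from_tensor_slices(
    (inputs, targets, weights)).cache().shuffle(10).prefetch(tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(4, drop_remainder=True)

steps_per_epoch = 10 // 4                         # 2: the incomplete last batch is dropped
for batch, (x, y, w) in enumerate(dataset.take(steps_per_epoch)):
    print(batch, x.shape, y.shape, w.shape)
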
Example #12
def smn_load_train_data(dict_fn: str,
                        data_fn: str,
                        checkpoint_dir: str,
                        buffer_size: int,
                        batch_size: int,
                        max_utterance: int,
                        max_sentence: int,
                        max_train_data_size: int = 0):
    """
    用于SMN的训练数据加载
    :param dict_fn: 字典文本路径
    :param data_fn: 数据文本路径
    :param buffer_size: Dataset加载缓存大小
    :param batch_size: Dataset加载批大小
    :param checkpoint_dir: 检查点保存路径
    :param max_utterance: 每轮对话最大对话数
    :param max_sentence: 单个句子最大长度
    :param max_train_data_size: 最大训练数据量
    :return: TensorFlow的数据处理类、分词器、检查点前缀和总的步数
    """
    if not os.path.exists(data_fn):
        print('The training dataset does not exist; please add it and retry')
        exit(0)

    print('Reading the text data...')
    history = []  # history sentences of each dialogue turn
    response = []  # response of each dialogue turn
    label = []  # label of each dialogue turn
    count = 0  # counter for processed samples

    with open(data_fn, 'r', encoding='utf-8') as file:
        odd_flag = True
        for line in file:
            odd_flag = not odd_flag
            if odd_flag:
                continue

            count += 1
            apart = line.split('\t')
            label.append(int(apart[0]))
            response.append(apart[-1])
            del apart[0]
            del apart[-1]
            history.append(apart)

            print('\rRead {} dialogue turns'.format(count), flush=True, end="")
            if max_train_data_size == count:
                break

    tokenizer = load_tokenizer(dict_path=dict_fn)
    response = tokenizer.texts_to_sequences(response)
    response = tf.keras.preprocessing.sequence.pad_sequences(
        response, maxlen=max_sentence, padding="post")

    count = 0
    utterances = []
    for utterance in history:
        count += 1
        pad_sequences = [0] * max_sentence
        # Note: keep only the last max_utterance sentences of each dialogue turn
        utterance_padding = tokenizer.texts_to_sequences(
            utterance)[-max_utterance:]
        utterance_len = len(utterance_padding)
        # If the current turn has fewer than max_utterance history sentences, pad at the tail
        if utterance_len != max_utterance:
            utterance_padding += [pad_sequences
                                  ] * (max_utterance - utterance_len)
        utterances.append(
            tf.keras.preprocessing.sequence.pad_sequences(
                utterance_padding, maxlen=max_sentence,
                padding="post").tolist())
        print('\rGenerated {} turns of training data'.format(count), flush=True, end="")

    print('Data generation finished, converting to a Dataset')
    dataset = tf.data.Dataset.from_tensor_slices(
        (utterances, response, label)).cache().shuffle(buffer_size).prefetch(
            tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
    steps_per_epoch = len(utterances) // batch_size
    print('Training data processing finished, starting training')

    return dataset, tokenizer, checkpoint_prefix, steps_per_epoch
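
A minimal sketch of the tab-separated line format implied by the parsing above: a 0/1 label, then the dialogue history utterances, then the response (the sample line is made up):

line = "1\thello\thow are you\tfine thanks\n"
apart = line.strip('\n').split('\t')
label = int(apart[0])                             # 1
response = apart[-1]                              # "fine thanks"
history = apart[1:-1]                             # ["hello", "how are you"]
print(label, history, response)
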