Example #1
    def __init__(self, execute_type: str, batch_size: int, embedding_dim: int,
                 units: int, dropout: float, checkpoint_dir: str,
                 beam_size: int, vocab_size: int, dict_fn: str,
                 max_length: int, start_sign: str, end_sign: str):
        """
        Initializes the Seq2Seq chatter; loads the model.
        """
        super().__init__(checkpoint_dir, beam_size, max_length)
        self.dict_fn = dict_fn
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.epoch_iterator = 0
        self.start_sign = start_sign
        self.end_sign = end_sign
        self.encoder = seq2seq.Encoder(vocab_size, embedding_dim, units, units,
                                       dropout)
        attention = seq2seq.BahdanauAttention(units)
        self.decoder = seq2seq.Decoder(vocab_size, embedding_dim, units, units,
                                       dropout, attention)
        self.optimizer = optim.Adam([{
            'params': self.encoder.parameters(),
            'lr': 1e-3
        }, {
            'params': self.decoder.parameters(),
            'lr': 1e-3
        }])
        self.criterion = nn.CrossEntropyLoss(ignore_index=0)

        if execute_type == "chat":
            print('Loading dictionary from "{}"...'.format(dict_fn))
            self.token = data_utils.load_token_dict(dict_fn=dict_fn)
        print('Checking for an existing checkpoint...')
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint.txt')

        if not os.path.exists(checkpoint_path):
            # Create an empty checkpoint record file
            with open(checkpoint_path, 'w', encoding='utf-8') as file:
                print("No checkpoint found; a checkpoint record file has been created")
            if execute_type == "train":
                print('Entering train mode...')
            else:
                print('Please run train mode first, then switch to chat mode')
                exit(0)
        else:
            with open(checkpoint_path, 'r', encoding='utf-8') as file:
                lines = file.read().strip().split('\n')
                version = lines[0]
                if version != "":
                    version = int(version)
                    checkpoint = torch.load(
                        os.path.join(checkpoint_dir,
                                     'seq2seq-{}.pth'.format(version)))
                    self.encoder.load_state_dict(
                        checkpoint['encoder_state_dict'])
                    self.decoder.load_state_dict(
                        checkpoint['decoder_state_dict'])
                    self.optimizer.load_state_dict(
                        checkpoint['optimizer_state_dict'])
                    self.epoch_iterator = checkpoint['epoch']
                    print("检测到检查点,已成功加载检查点")
Example #2
    def respond(self, req, input_dict_fn, target_dict_fn):
        # Preliminary processing of req
        input_token, target_token = _data.load_token_dict(
            input_dict_fn=input_dict_fn, target_dict_fn=target_dict_fn)
        inputs, dec_input = self._pre_treat_inputs(req, input_token,
                                                   target_token)
        self.beam_search_container.init_variables(inputs=inputs,
                                                  dec_input=dec_input)
        inputs, dec_input = self.beam_search_container.get_variables()
        for t in range(_config.max_length_tar):
            predictions = self._create_predictions(inputs, dec_input, t)
            self.beam_search_container.add(predictions=predictions,
                                           end_sign=target_token.get('end'))
            if self.beam_search_container.beam_size == 0:
                break

            inputs, dec_input = self.beam_search_container.get_variables()
        beam_search_result = self.beam_search_container.get_result()
        result = ''
        # Extract the sequences from the container and build the final result
        for i in range(len(beam_search_result)):
            temp = beam_search_result[i].numpy()
            text = _data.sequences_to_texts(temp, target_token)
            text[0] = text[0].replace('start',
                                      '').replace('end', '').replace(' ', '')
            result = '<' + text[0] + '>' + result
        return result
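
The loop above repeatedly asks the model for next-token predictions, lets the BeamSearch container expand and prune its hypotheses, and stops once every beam has emitted the end token (beam_size reaches 0). For reference, here is a self-contained, simplified beam search with the same expand-and-prune structure; the step_fn interface and the toy example are assumptions for illustration and not this project's BeamSearch API:

    import math

    def simple_beam_search(step_fn, start_id, end_id, beam_size, max_length):
        # Each hypothesis is a (token_list, log_prob) pair; start with one beam.
        beams = [([start_id], 0.0)]
        finished = []
        for _ in range(max_length):
            candidates = []
            for tokens, score in beams:
                probs = step_fn(tokens)  # probability distribution over the next token
                for token_id, p in enumerate(probs):
                    if p > 0.0:
                        candidates.append((tokens + [token_id], score + math.log(p)))
            # Keep the beam_size best candidates; set finished hypotheses aside.
            candidates.sort(key=lambda c: c[1], reverse=True)
            beams = []
            for tokens, score in candidates[:beam_size]:
                if tokens[-1] == end_id:
                    finished.append((tokens, score))
                else:
                    beams.append((tokens, score))
            if not beams:
                break
        return finished or beams

    # Toy step function that always prefers token 2 and then the end token 3.
    print(simple_beam_search(lambda t: [0.0, 0.1, 0.6, 0.3],
                             start_id=1, end_id=3, beam_size=2, max_length=5))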
Example #3
    def __init__(self, execute_type, checkpoint_dir, beam_size, vocab_size,
                 dict_fn, max_length):
        """
        Initializes the Seq2Seq chatter; loads the model.
        """
        super().__init__(checkpoint_dir, beam_size, max_length)
        self.encoder = seq2seq.Encoder(vocab_size,
                                       _config.seq2seq_embedding_dim,
                                       _config.seq2seq_units,
                                       _config.BATCH_SIZE)
        self.decoder = seq2seq.Decoder(vocab_size,
                                       _config.seq2seq_embedding_dim,
                                       _config.seq2seq_units,
                                       _config.BATCH_SIZE)
        self.optimizer = tf.keras.optimizers.Adam()
        self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction='none')
        self.checkpoint = tf.train.Checkpoint(optimizer=self.optimizer,
                                              encoder=self.encoder,
                                              decoder=self.decoder)

        if execute_type == "chat":
            print('Loading dictionary from "{}"...'.format(dict_fn))
            self.token = _data.load_token_dict(dict_fn=dict_fn)
        print('Checking for an existing checkpoint...')
        if self.ckpt:
            print('Checkpoint found; loading it from "{}"...'.format(checkpoint_dir))
            self.checkpoint.restore(
                tf.train.latest_checkpoint(checkpoint_dir)).expect_partial()
        else:
            if execute_type == "train":
                print('No checkpoint found; entering train mode...')
            else:
                print('No checkpoint found; please run train mode first, then switch to chat mode')
                exit(0)
Example #4
 def __init__(self, model, checkpoint_dir, beam_size, dict_fn):
     """
     Initializes the Transformer chatter; loads the model.
     """
     self.checkpoint_dir = checkpoint_dir
     if model == "chat":
         print('Loading dictionary from "{}"...'.format(dict_fn))
         self.token = _data.load_token_dict(dict_fn=dict_fn)
     self.beam_search_container = BeamSearch(
         beam_size=beam_size,
         max_length=_config.max_length_tar,
         worst_score=0)
     is_exist = Path(checkpoint_dir)
     if not is_exist.exists():
         os.makedirs(checkpoint_dir, exist_ok=True)
     self.ckpt = tf.io.gfile.listdir(checkpoint_dir)
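
The self.ckpt attribute set here is what Examples #3, #6, #9, and #10 test with `if self.ckpt:`. tf.io.gfile.listdir returns a (possibly empty) list of file names, so the truthiness check simply asks whether the checkpoint directory contains anything. A minimal sketch of the same pattern, assuming only that TensorFlow is installed (the directory path is illustrative):

    import os
    import tensorflow as tf

    checkpoint_dir = "checkpoints/demo"           # illustrative path
    os.makedirs(checkpoint_dir, exist_ok=True)

    ckpt_files = tf.io.gfile.listdir(checkpoint_dir)
    if ckpt_files:
        # Resolve the most recent checkpoint written into this directory.
        latest = tf.train.latest_checkpoint(checkpoint_dir)
        print("would restore from:", latest)
    else:
        print("no checkpoint files found; train first")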
Example #5
    def evaluate(self,
                 valid_fn: str,
                 dict_fn: str = "",
                 tokenizer: tf.keras.preprocessing.text.Tokenizer = None,
                 max_turn_utterances_num: int = 10,
                 max_valid_data_size: int = 0):
        """
        Evaluation routine. Note that exactly one of dict_fn and tokenizer must be provided.
        Args:
            valid_fn: path to the validation dataset
            dict_fn: path to the dictionary file
            tokenizer: tokenizer
            max_turn_utterances_num: maximum number of utterances per turn
            max_valid_data_size: maximum amount of validation data
        Returns:
            r2_1, r10_1
        """
        token_dict = None
        if max_valid_data_size == 0:
            return None
        step = max_valid_data_size // max_turn_utterances_num
        if dict_fn != "":
            token_dict = data_utils.load_token_dict(dict_fn)
        # Process and load the evaluation data. Note that if max_valid_data_size
        # is 0, loading is skipped entirely, i.e. the model is trained without evaluation.
        valid_dataset = data_utils.load_smn_valid_data(
            data_fn=valid_fn,
            max_sentence=self.max_sentence,
            max_utterance=self.max_utterance,
            token_dict=token_dict,
            tokenizer=tokenizer,
            max_turn_utterances_num=max_turn_utterances_num,
            max_valid_data_size=max_valid_data_size)

        scores = tf.constant([], dtype=tf.float32)
        labels = tf.constant([], dtype=tf.int32)
        for (batch, (utterances, response,
                     label)) in enumerate(valid_dataset.take(step)):
            score = self.model(inputs=[utterances, response])
            score = tf.nn.softmax(score, axis=-1)
            labels = tf.concat([labels, label], axis=0)
            scores = tf.concat([scores, score[:, 1]], axis=0)

        r10_1 = self._metrics_rn_1(scores, labels, num=10)
        r2_1 = self._metrics_rn_1(scores, labels, num=2)
        return r2_1, r10_1
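
The _metrics_rn_1 helper is not shown above. A plausible sketch of an R_n@1 recall metric, under the assumptions that the flat score/label tensors are arranged in consecutive groups of 10 candidate responses per context and that the positive response (label 1) sits among the first `num` candidates of each group:

    import tensorflow as tf

    def metrics_rn_1(scores: tf.Tensor, labels: tf.Tensor, num: int) -> tf.Tensor:
        # Group the flat tensors into blocks of 10 candidates per context,
        # keep the first `num` candidates of each block, and measure how
        # often the top-scoring candidate is the one labeled 1.
        scores = tf.reshape(scores, (-1, 10))[:, :num]
        labels = tf.reshape(labels, (-1, 10))[:, :num]
        best = tf.argmax(scores, axis=-1)
        hit = tf.reduce_sum(
            tf.cast(labels, tf.float32) * tf.one_hot(best, depth=num), axis=-1)
        return tf.reduce_mean(hit)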
Example #6
    def __init__(self, execute_type, checkpoint_dir, beam_size, vocab_size,
                 dict_fn, max_length):
        """
        Initializes the Transformer chatter; loads the model.
        """
        super().__init__(checkpoint_dir, beam_size, max_length)

        self.model = transformer.transformer(
            vocab_size=vocab_size,
            num_layers=_config.transformer_num_layers,
            units=_config.transformer_units,
            d_model=_config.transformer_d_model,
            num_heads=_config.transformer_num_heads,
            dropout=_config.transformer_dropout)

        self.learning_rate = CustomSchedule(_config.transformer_d_model)
        self.optimizer = tf.keras.optimizers.Adam(self.learning_rate,
                                                  beta_1=0.9,
                                                  beta_2=0.98,
                                                  epsilon=1e-9)
        self.train_loss = tf.keras.metrics.Mean(name='train_loss')
        self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            name='train_accuracy')

        self.checkpoint = tf.train.Checkpoint(transformer=self.model,
                                              optimizer=self.optimizer)

        if execute_type == "chat":
            print('Loading dictionary from "{}"...'.format(dict_fn))
            self.token = _data.load_token_dict(dict_fn=dict_fn)
        print('Checking for an existing checkpoint...')
        if self.ckpt:
            print('Checkpoint found; loading it from "{}"...'.format(checkpoint_dir))
            self.checkpoint.restore(
                tf.train.latest_checkpoint(checkpoint_dir)).expect_partial()
        else:
            if execute_type == "train":
                print('No checkpoint found; entering train mode...')
            else:
                print('No checkpoint found; please run train mode first, then switch to chat mode')
                exit(0)
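
CustomSchedule itself is not shown in these examples. A common implementation is the warmup schedule from "Attention Is All You Need" (lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)), as popularized by the TensorFlow Transformer tutorial; whether this project's CustomSchedule matches it exactly is an assumption:

    import tensorflow as tf

    class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
        """Warmup schedule: lr = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)."""

        def __init__(self, d_model: int, warmup_steps: int = 4000):
            super().__init__()
            self.d_model = tf.cast(d_model, tf.float32)
            self.warmup_steps = warmup_steps

        def __call__(self, step):
            step = tf.cast(step, tf.float32)
            arg1 = tf.math.rsqrt(step)
            arg2 = step * (self.warmup_steps ** -1.5)
            return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

    # Used exactly as in the example above:
    optimizer = tf.keras.optimizers.Adam(CustomSchedule(512),
                                         beta_1=0.9, beta_2=0.98, epsilon=1e-9)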
Example #7
 def response(self, req):
     print('Loading dictionary from "{}"...'.format(self.dict_fn))
     token = _data.load_token_dict(dict_fn=self.dict_fn)
     print('This feature is still under development...')
Example #8
    def __init__(self, units: int, vocab_size: int, execute_type: str,
                 dict_fn: str, embedding_dim: int, checkpoint_dir: str,
                 max_utterance: int, max_sentence: int, learning_rate: float,
                 database_fn: str, solr_server: str):
        """
        Initializes the SMN chatter; loads the model.
        Args:
            units: number of units
            vocab_size: vocabulary size
            execute_type: execution mode (train or chat)
            dict_fn: path where the dictionary is saved
            embedding_dim: embedding dimension
            checkpoint_dir: directory in which checkpoints are saved
            max_utterance: number of utterances per turn
            max_sentence: maximum length of a single sentence
            learning_rate: learning rate
            database_fn: path to the candidate database
            solr_server: Solr server URL
        Returns:
        """
        self.dict_fn = dict_fn
        self.checkpoint_dir = checkpoint_dir
        self.max_utterance = max_utterance
        self.max_sentence = max_sentence
        self.database_fn = database_fn
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
        self.solr = pysolr.Solr(url=solr_server,
                                always_commit=True,
                                timeout=10)
        self.train_loss = tf.keras.metrics.Mean()

        self.model = smn.smn(units=units,
                             vocab_size=vocab_size,
                             embedding_dim=embedding_dim,
                             max_utterance=self.max_utterance,
                             max_sentence=self.max_sentence)

        self.checkpoint = tf.train.Checkpoint(
            model=self.model,
            optimizer=self.optimizer,
        )

        ckpt = os.path.exists(checkpoint_dir)
        if not ckpt:
            os.makedirs(checkpoint_dir)

        if execute_type == "chat":
            print('Loading dictionary from "{}"...'.format(self.dict_fn))
            self.token = data_utils.load_token_dict(dict_fn=self.dict_fn)
        print('Checking for an existing checkpoint...')
        if ckpt:
            print('Checkpoint found; loading it from "{}"...'.format(checkpoint_dir))
            self.checkpoint.restore(
                tf.train.latest_checkpoint(checkpoint_dir)).expect_partial()
        else:
            if execute_type == "train":
                print('No checkpoint found; entering train mode...')
            else:
                print('No checkpoint found; please run train mode first, then switch to chat mode')
                exit(0)

        logger = utils.log_operator(level=10)
        logger.info("启动SMN聊天器,执行类别为:{},模型参数配置为:embedding_dim:{},"
                    "max_sentence:{},max_utterance:{},units:{},vocab_size:{},"
                    "learning_rate:{}".format(execute_type, embedding_dim,
                                              max_sentence, max_utterance,
                                              units, vocab_size,
                                              learning_rate))
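
The Solr client created above is presumably used at chat time to retrieve candidate responses that the SMN model then scores. A minimal retrieval sketch using pysolr's search API; the core name, the `utterance` field, and the query shape are assumptions for illustration and are not taken from this project:

    import pysolr

    solr = pysolr.Solr(url="http://localhost:8983/solr/smn",
                       always_commit=True, timeout=10)

    def retrieve_candidates(query: str, rows: int = 10):
        # pysolr returns an iterable of result documents (dicts).
        results = solr.search(q="utterance:({})".format(query), rows=rows)
        return [doc.get("utterance", "") for doc in results]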
Example #9
    def __init__(self,
                 execute_type: str,
                 checkpoint_dir: str,
                 units: int,
                 embedding_dim: int,
                 batch_size: int,
                 start_sign: str,
                 end_sign: str,
                 beam_size: int,
                 vocab_size: int,
                 dict_fn: str,
                 max_length: int,
                 encoder_layers: int,
                 decoder_layers: int,
                 cell_type: str,
                 if_bidirectional: bool = True):
        """
        Initializes the Seq2Seq chatter; loads the model.
        Args:
            execute_type: execution mode (train or chat)
            checkpoint_dir: directory in which checkpoints are saved
            units: number of units
            embedding_dim: embedding dimension
            batch_size: batch size
            start_sign: start token
            end_sign: end token
            beam_size: beam size
            vocab_size: vocabulary size
            dict_fn: path where the dictionary is saved
            max_length: maximum length of a single sentence
            encoder_layers: number of RNN layers inside the encoder
            decoder_layers: number of RNN layers inside the decoder
            cell_type: cell type, lstm/gru, default lstm
            if_bidirectional: whether the encoder is bidirectional
        Returns:
        """
        super().__init__(checkpoint_dir, beam_size, max_length)
        self.units = units
        self.start_sign = start_sign
        self.end_sign = end_sign
        self.batch_size = batch_size
        self.enc_units = units

        self.encoder = seq2seq.encoder(vocab_size=vocab_size,
                                       embedding_dim=embedding_dim,
                                       enc_units=int(units / 2),
                                       layer_size=encoder_layers,
                                       cell_type=cell_type,
                                       if_bidirectional=if_bidirectional)
        self.decoder = seq2seq.decoder(vocab_size=vocab_size,
                                       embedding_dim=embedding_dim,
                                       enc_units=units,
                                       dec_units=units,
                                       layer_size=decoder_layers,
                                       cell_type=cell_type)

        self.optimizer = tf.keras.optimizers.Adam()
        self.train_loss = tf.keras.metrics.Mean()
        self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
        self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction='none')
        self.checkpoint = tf.train.Checkpoint(optimizer=self.optimizer,
                                              encoder=self.encoder,
                                              decoder=self.decoder)

        if execute_type == "chat":
            print('Loading dictionary from "{}"...'.format(dict_fn))
            self.token = data_utils.load_token_dict(dict_fn=dict_fn)
        print('Checking for an existing checkpoint...')
        if self.ckpt:
            print('Checkpoint found; loading it from "{}"...'.format(checkpoint_dir))
            self.checkpoint.restore(
                tf.train.latest_checkpoint(checkpoint_dir)).expect_partial()
        else:
            if execute_type == "train":
                print('No checkpoint found; entering train mode...')
            else:
                print('No checkpoint found; please run train mode first, then switch to chat mode')
                exit(0)

        utils.log_operator(level=10).info(
            "Seq2Seq chatter started. execute_type: {}; model parameters: "
            "vocab_size: {}, embedding_dim: {}, units: {}, max_length: {}".format(
                execute_type, vocab_size, embedding_dim, units, max_length))
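
The loss object above is built with reduction='none', which in seq2seq training is normally paired with a per-position padding mask before averaging. The project's own loss function is not shown; a common sketch of that pattern, assuming padding id 0:

    import tensorflow as tf

    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none')

    def loss_function(real: tf.Tensor, pred: tf.Tensor) -> tf.Tensor:
        # real: (batch, seq_len) integer targets, 0 = padding
        # pred: (batch, seq_len, vocab_size) logits
        mask = tf.math.logical_not(tf.math.equal(real, 0))
        loss = loss_object(real, pred)        # per-token loss, shape (batch, seq_len)
        mask = tf.cast(mask, dtype=loss.dtype)
        loss *= mask                          # zero out the padded positions
        return tf.reduce_sum(loss) / tf.reduce_sum(mask)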
Example #10
    def __init__(self, execute_type: str, checkpoint_dir: str, num_layers: int,
                 units: int, d_model: int, num_heads: int, dropout: float,
                 start_sign: str, end_sign: str, beam_size: int,
                 vocab_size: int, dict_fn: str, max_length: int):
        """
        Initializes the Transformer chatter; loads the model.
        Args:
            execute_type: execution mode (train or chat)
            checkpoint_dir: directory in which checkpoints are saved
            num_layers: number of layers inside the transformer
            units: number of units
            d_model: embedding dimension
            num_heads: number of attention heads
            dropout: dropout rate
            start_sign: start token
            end_sign: end token
            beam_size: beam size
            vocab_size: vocabulary size
            dict_fn: path where the dictionary is saved
            max_length: maximum length of a single sentence
        Returns:
        """
        super().__init__(checkpoint_dir, beam_size, max_length)
        self.start_sign = start_sign
        self.end_sign = end_sign

        self.model = transformer.transformer(vocab_size=vocab_size,
                                             num_layers=num_layers,
                                             units=units,
                                             d_model=d_model,
                                             num_heads=num_heads,
                                             dropout=dropout)

        self.learning_rate = optimizers.CustomSchedule(d_model)
        self.optimizer = tf.keras.optimizers.Adam(self.learning_rate,
                                                  beta_1=0.9,
                                                  beta_2=0.98,
                                                  epsilon=1e-9)
        self.train_loss = tf.keras.metrics.Mean(name='train_loss')
        self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            name='train_accuracy')

        self.checkpoint = tf.train.Checkpoint(transformer=self.model,
                                              optimizer=self.optimizer)

        if execute_type == "chat":
            print('Loading dictionary from "{}"...'.format(dict_fn))
            self.token = data_utils.load_token_dict(dict_fn=dict_fn)
        print('Checking for an existing checkpoint...')
        if self.ckpt:
            print('Checkpoint found; loading it from "{}"...'.format(checkpoint_dir))
            self.checkpoint.restore(
                tf.train.latest_checkpoint(checkpoint_dir)).expect_partial()
        else:
            if execute_type == "train":
                print('No checkpoint found; entering train mode...')
            else:
                print('No checkpoint found; please run train mode first, then switch to chat mode')
                exit(0)

        utils.log_operator(level=10).info(
            "Transformer chatter started. execute_type: {}; model parameters: "
            "num_layers: {}, d_model: {}, num_heads: {}, units: {}, "
            "dropout: {}, vocab_size: {}, max_length: {}".format(
                execute_type, num_layers, d_model, num_heads, units,
                dropout, vocab_size, max_length))
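
A minimal usage sketch for this constructor. The enclosing class name (here TransformerChatter) is assumed; only the parameter names come from the example above, and the argument values are illustrative:

    # Hypothetical instantiation -- class name and values are illustrative only.
    chatter = TransformerChatter(execute_type="train",
                                 checkpoint_dir="checkpoints/transformer",
                                 num_layers=4, units=1024, d_model=256,
                                 num_heads=8, dropout=0.1,
                                 start_sign="start", end_sign="end",
                                 beam_size=3, vocab_size=20000,
                                 dict_fn="data/transformer_dict.json",
                                 max_length=40)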