def __init__(self, execute_type: str, batch_size: int, embedding_dim: int, units: int,
             dropout: float, checkpoint_dir: str, beam_size: int, vocab_size: int,
             dict_fn: str, max_length: int, start_sign: str, end_sign: str):
    """ Seq2Seq chatter initialization; loads the model. """
    super().__init__(checkpoint_dir, beam_size, max_length)
    self.dict_fn = dict_fn
    self.batch_size = batch_size
    self.vocab_size = vocab_size
    self.epoch_iterator = 0
    self.start_sign = start_sign
    self.end_sign = end_sign
    self.encoder = seq2seq.Encoder(vocab_size, embedding_dim, units, units, dropout)
    attention = seq2seq.BahdanauAttention(units)
    self.decoder = seq2seq.Decoder(vocab_size, embedding_dim, units, units,
                                   dropout, attention)
    self.optimizer = optim.Adam([
        {'params': self.encoder.parameters(), 'lr': 1e-3},
        {'params': self.decoder.parameters(), 'lr': 1e-3},
    ])
    self.criterion = nn.CrossEntropyLoss(ignore_index=0)

    if execute_type == "chat":
        print('Loading the dictionary from "{}"...'.format(dict_fn))
        self.token = data_utils.load_token_dict(dict_fn=dict_fn)
    print('Checking for an existing checkpoint...')
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint.txt')
    if not os.path.exists(checkpoint_path):
        # Create an empty checkpoint record file on first run.
        with open(checkpoint_path, 'w', encoding='utf-8'):
            print('No checkpoint found; created the checkpoint record file')
        if execute_type == "train":
            print('Entering train mode...')
        else:
            print('Please run train mode first, then enter chat mode')
            exit(0)
    else:
        with open(checkpoint_path, 'r', encoding='utf-8') as file:
            lines = file.read().strip().split('\n')
            version = lines[0]
            if version != "":
                version = int(version)
                checkpoint = torch.load(
                    os.path.join(checkpoint_dir, 'seq2seq-{}.pth'.format(version)))
                self.encoder.load_state_dict(checkpoint['encoder_state_dict'])
                self.decoder.load_state_dict(checkpoint['decoder_state_dict'])
                self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                self.epoch_iterator = checkpoint['epoch']
                print('Checkpoint found and loaded successfully')
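# For reference, a minimal sketch of the save side that the loader above expects:
# a hypothetical helper (not part of the original code) that writes
# seq2seq-{epoch}.pth with the same state-dict keys and records the version
# number as the first line of checkpoint.txt.
import os
import torch


def save_checkpoint_sketch(encoder, decoder, optimizer, epoch, checkpoint_dir):
    """Hypothetical helper: persist exactly the keys the loader above reads."""
    torch.save({
        'encoder_state_dict': encoder.state_dict(),
        'decoder_state_dict': decoder.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
    }, os.path.join(checkpoint_dir, 'seq2seq-{}.pth'.format(epoch)))
    # The loader parses the first line of checkpoint.txt as the version number.
    with open(os.path.join(checkpoint_dir, 'checkpoint.txt'), 'w',
              encoding='utf-8') as file:
        file.write(str(epoch))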
def respond(self, req, input_dict_fn, target_dict_fn):
    # Pre-process the request.
    input_token, target_token = _data.load_token_dict(
        input_dict_fn=input_dict_fn, target_dict_fn=target_dict_fn)
    inputs, dec_input = self._pre_treat_inputs(req, input_token, target_token)
    self.beam_search_container.init_variables(inputs=inputs, dec_input=dec_input)
    inputs, dec_input = self.beam_search_container.get_variables()
    for t in range(_config.max_length_tar):
        predictions = self._create_predictions(inputs, dec_input, t)
        self.beam_search_container.add(predictions=predictions,
                                       end_sign=target_token.get('end'))
        if self.beam_search_container.beam_size == 0:
            break
        inputs, dec_input = self.beam_search_container.get_variables()
    beam_search_result = self.beam_search_container.get_result()
    result = ''
    # Extract the sequences from the container to build the final answer.
    for i in range(len(beam_search_result)):
        temp = beam_search_result[i].numpy()
        text = _data.sequences_to_texts(temp, target_token)
        text[0] = text[0].replace('start', '').replace('end', '').replace(' ', '')
        result = '<' + text[0] + '>' + result
    return result
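# The container above hides the core beam-search step. Below is a minimal
# standalone sketch of that step (illustrative only, using NumPy; not the
# container's actual implementation): every live beam is extended by every
# vocabulary token, and only the top beam_size cumulative log-probabilities
# survive to the next decoding step.
import numpy as np


def beam_step_sketch(beam_log_probs, beam_seqs, token_log_probs, beam_size):
    """beam_log_probs: (beams,); beam_seqs: list of token lists;
    token_log_probs: (beams, vocab) log-probs for the next token."""
    # Cumulative score of every (beam, next-token) candidate.
    scores = beam_log_probs[:, None] + token_log_probs       # (beams, vocab)
    flat = scores.ravel()
    top = np.argsort(flat)[::-1][:beam_size]                 # best candidates overall
    vocab = token_log_probs.shape[1]
    new_seqs = [beam_seqs[i // vocab] + [i % vocab] for i in top]
    return flat[top], new_seqs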
def __init__(self, execute_type, checkpoint_dir, beam_size, vocab_size, dict_fn, max_length):
    """ Seq2Seq chatter initialization; loads the model. """
    super().__init__(checkpoint_dir, beam_size, max_length)
    self.encoder = seq2seq.Encoder(vocab_size, _config.seq2seq_embedding_dim,
                                   _config.seq2seq_units, _config.BATCH_SIZE)
    self.decoder = seq2seq.Decoder(vocab_size, _config.seq2seq_embedding_dim,
                                   _config.seq2seq_units, _config.BATCH_SIZE)
    self.optimizer = tf.keras.optimizers.Adam()
    self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none')
    self.checkpoint = tf.train.Checkpoint(optimizer=self.optimizer,
                                          encoder=self.encoder,
                                          decoder=self.decoder)

    if execute_type == "chat":
        print('Loading the dictionary from "{}"...'.format(dict_fn))
        self.token = _data.load_token_dict(dict_fn=dict_fn)
    print('Checking for an existing checkpoint...')
    if self.ckpt:
        print('Checkpoint found, restoring from "{}"...'.format(checkpoint_dir))
        self.checkpoint.restore(
            tf.train.latest_checkpoint(checkpoint_dir)).expect_partial()
    else:
        if execute_type == "train":
            print('No checkpoint found, entering train mode...')
        else:
            print('No checkpoint found; please run train mode first, then enter chat mode')
            exit(0)
def __init__(self, execute_type, checkpoint_dir, beam_size, dict_fn):
    """ Transformer chatter initialization; loads the model. """
    # Renamed the first parameter from `model` to `execute_type`: it carries
    # the execution mode ("chat"/"train"), matching the other chatter classes.
    self.checkpoint_dir = checkpoint_dir
    if execute_type == "chat":
        print('Loading the dictionary from "{}"...'.format(dict_fn))
        self.token = _data.load_token_dict(dict_fn=dict_fn)
    self.beam_search_container = BeamSearch(
        beam_size=beam_size, max_length=_config.max_length_tar, worst_score=0)
    if not Path(checkpoint_dir).exists():
        os.makedirs(checkpoint_dir, exist_ok=True)
    self.ckpt = tf.io.gfile.listdir(checkpoint_dir)
def evaluate(self, valid_fn: str, dict_fn: str = "",
             tokenizer: tf.keras.preprocessing.text.Tokenizer = None,
             max_turn_utterances_num: int = 10, max_valid_data_size: int = 0):
    """ Evaluation; note that exactly one of dict_fn and tokenizer must be passed.
    Args:
        valid_fn: path of the validation dataset
        dict_fn: dictionary path
        tokenizer: tokenizer
        max_turn_utterances_num: number of utterances per evaluation batch
        max_valid_data_size: maximum amount of validation data
    Returns: r2_1, r10_1
    """
    # If max_valid_data_size is 0, skip loading the evaluation data entirely,
    # i.e. train without evaluating.
    if max_valid_data_size == 0:
        return None
    token_dict = None
    step = max_valid_data_size // max_turn_utterances_num
    if dict_fn != "":
        token_dict = data_utils.load_token_dict(dict_fn)
    valid_dataset = data_utils.load_smn_valid_data(
        data_fn=valid_fn, max_sentence=self.max_sentence,
        max_utterance=self.max_utterance, token_dict=token_dict,
        tokenizer=tokenizer,
        max_turn_utterances_num=max_turn_utterances_num,
        max_valid_data_size=max_valid_data_size)
    scores = tf.constant([], dtype=tf.float32)
    labels = tf.constant([], dtype=tf.int32)
    for (batch, (utterances, response, label)) in enumerate(valid_dataset.take(step)):
        score = self.model(inputs=[utterances, response])
        score = tf.nn.softmax(score, axis=-1)
        labels = tf.concat([labels, label], axis=0)
        scores = tf.concat([scores, score[:, 1]], axis=0)
    r10_1 = self._metrics_rn_1(scores, labels, num=10)
    r2_1 = self._metrics_rn_1(scores, labels, num=2)
    return r2_1, r10_1
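# A minimal standalone sketch of the Rn@1 metric computed by _metrics_rn_1
# above (illustrative; the repo's own implementation may differ): candidates
# are grouped n at a time, and a group counts as a hit when the positive
# response (label 1) receives the highest score in its group.
import numpy as np


def rn_at_1_sketch(scores, labels, num):
    scores = np.asarray(scores).reshape(-1, num)
    labels = np.asarray(labels).reshape(-1, num)
    best = scores.argmax(axis=-1)                    # top-ranked candidate per group
    hits = labels[np.arange(len(best)), best] == 1   # was it the positive one?
    return hits.mean()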
def __init__(self, execute_type, checkpoint_dir, beam_size, vocab_size, dict_fn, max_length):
    """ Transformer chatter initialization; loads the model. """
    super().__init__(checkpoint_dir, beam_size, max_length)
    self.model = transformer.transformer(
        vocab_size=vocab_size,
        num_layers=_config.transformer_num_layers,
        units=_config.transformer_units,
        d_model=_config.transformer_d_model,
        num_heads=_config.transformer_num_heads,
        dropout=_config.transformer_dropout)
    self.learning_rate = CustomSchedule(_config.transformer_d_model)
    self.optimizer = tf.keras.optimizers.Adam(self.learning_rate, beta_1=0.9,
                                              beta_2=0.98, epsilon=1e-9)
    self.train_loss = tf.keras.metrics.Mean(name='train_loss')
    self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name='train_accuracy')
    self.checkpoint = tf.train.Checkpoint(transformer=self.model,
                                          optimizer=self.optimizer)

    if execute_type == "chat":
        print('Loading the dictionary from "{}"...'.format(dict_fn))
        self.token = _data.load_token_dict(dict_fn=dict_fn)
    print('Checking for an existing checkpoint...')
    if self.ckpt:
        print('Checkpoint found, restoring from "{}"...'.format(checkpoint_dir))
        self.checkpoint.restore(
            tf.train.latest_checkpoint(checkpoint_dir)).expect_partial()
    else:
        if execute_type == "train":
            print('No checkpoint found, entering train mode...')
        else:
            print('No checkpoint found; please run train mode first, then enter chat mode')
            exit(0)
def response(self, req):
    print('Loading the dictionary from "{}"...'.format(self.dict_fn))
    token = _data.load_token_dict(dict_fn=self.dict_fn)
    print('This feature is still a work in progress...')
def __init__(self, units: int, vocab_size: int, execute_type: str, dict_fn: str,
             embedding_dim: int, checkpoint_dir: str, max_utterance: int,
             max_sentence: int, learning_rate: float, database_fn: str,
             solr_server: str):
    """ SMN chatter initialization; loads the model.
    Args:
        units: number of units
        vocab_size: vocabulary size
        execute_type: dialogue execution mode
        dict_fn: dictionary save path
        embedding_dim: embedding dimension
        checkpoint_dir: checkpoint save directory path
        max_utterance: number of utterances per turn
        max_sentence: maximum length of a single sentence
        learning_rate: learning rate
        database_fn: candidate database path
        solr_server: Solr server URL
    Returns:
    """
    self.dict_fn = dict_fn
    self.checkpoint_dir = checkpoint_dir
    self.max_utterance = max_utterance
    self.max_sentence = max_sentence
    self.database_fn = database_fn
    self.optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    self.solr = pysolr.Solr(url=solr_server, always_commit=True, timeout=10)
    self.train_loss = tf.keras.metrics.Mean()
    self.model = smn.smn(units=units, vocab_size=vocab_size,
                         embedding_dim=embedding_dim,
                         max_utterance=self.max_utterance,
                         max_sentence=self.max_sentence)
    self.checkpoint = tf.train.Checkpoint(model=self.model, optimizer=self.optimizer)

    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    # Check for checkpoint files rather than the bare directory: a freshly
    # created directory contains nothing restorable.
    ckpt = tf.io.gfile.listdir(checkpoint_dir)
    if execute_type == "chat":
        print('Loading the dictionary from "{}"...'.format(self.dict_fn))
        self.token = data_utils.load_token_dict(dict_fn=self.dict_fn)
    print('Checking for an existing checkpoint...')
    if ckpt:
        print('Checkpoint found, restoring from "{}"...'.format(checkpoint_dir))
        self.checkpoint.restore(
            tf.train.latest_checkpoint(checkpoint_dir)).expect_partial()
    else:
        if execute_type == "train":
            print('No checkpoint found, entering train mode...')
        else:
            print('No checkpoint found; please run train mode first, then enter chat mode')
            exit(0)

    logger = utils.log_operator(level=10)
    logger.info("Starting the SMN chatter, execute type: {}, model configuration: "
                "embedding_dim: {}, max_sentence: {}, max_utterance: {}, "
                "units: {}, vocab_size: {}, learning_rate: {}".format(
                    execute_type, embedding_dim, max_sentence, max_utterance,
                    units, vocab_size, learning_rate))
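# The SMN chatter retrieves candidate responses from Solr before ranking them
# with the model. A minimal sketch of that retrieval (illustrative; the URL
# and the "utterance" field name are assumptions, not the repo's schema):
import pysolr

solr = pysolr.Solr(url='http://localhost:8983/solr/smn', always_commit=True,
                   timeout=10)
# Query the index for candidates matching the latest user utterance; the
# model then scores each (history, candidate) pair to pick the response.
candidates = [doc for doc in solr.search('utterance:你好', rows=10)]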
def __init__(self, execute_type: str, checkpoint_dir: str, units: int,
             embedding_dim: int, batch_size: int, start_sign: str, end_sign: str,
             beam_size: int, vocab_size: int, dict_fn: str, max_length: int,
             encoder_layers: int, decoder_layers: int, cell_type: str,
             if_bidirectional: bool = True):
    """ Seq2Seq chatter initialization; loads the model.
    Args:
        execute_type: dialogue execution mode
        checkpoint_dir: checkpoint save directory path
        units: number of units
        embedding_dim: embedding dimension
        batch_size: batch size
        start_sign: start token
        end_sign: end token
        beam_size: beam size
        vocab_size: vocabulary size
        dict_fn: dictionary save path
        max_length: maximum length of a single sentence
        encoder_layers: number of RNN layers inside the encoder
        decoder_layers: number of RNN layers inside the decoder
        cell_type: cell type, lstm/gru, defaults to lstm
        if_bidirectional: whether the encoder is bidirectional
    Returns:
    """
    super().__init__(checkpoint_dir, beam_size, max_length)
    self.units = units
    self.start_sign = start_sign
    self.end_sign = end_sign
    self.batch_size = batch_size
    self.enc_units = units
    # The encoder gets half the units: a bidirectional RNN concatenates the
    # forward and backward states, so its output width matches the decoder's.
    self.encoder = seq2seq.encoder(vocab_size=vocab_size, embedding_dim=embedding_dim,
                                   enc_units=int(units / 2), layer_size=encoder_layers,
                                   cell_type=cell_type, if_bidirectional=if_bidirectional)
    self.decoder = seq2seq.decoder(vocab_size=vocab_size, embedding_dim=embedding_dim,
                                   enc_units=units, dec_units=units,
                                   layer_size=decoder_layers, cell_type=cell_type)
    self.optimizer = tf.keras.optimizers.Adam()
    self.train_loss = tf.keras.metrics.Mean()
    self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none')
    self.checkpoint = tf.train.Checkpoint(optimizer=self.optimizer,
                                          encoder=self.encoder,
                                          decoder=self.decoder)

    if execute_type == "chat":
        print('Loading the dictionary from "{}"...'.format(dict_fn))
        self.token = data_utils.load_token_dict(dict_fn=dict_fn)
    print('Checking for an existing checkpoint...')
    if self.ckpt:
        print('Checkpoint found, restoring from "{}"...'.format(checkpoint_dir))
        self.checkpoint.restore(
            tf.train.latest_checkpoint(checkpoint_dir)).expect_partial()
    else:
        if execute_type == "train":
            print('No checkpoint found, entering train mode...')
        else:
            print('No checkpoint found; please run train mode first, then enter chat mode')
            exit(0)

    utils.log_operator(level=10).info(
        "Starting the Seq2Seq chatter, execute type: {}, model configuration: "
        "vocab_size: {}, embedding_dim: {}, units: {}, max_length: {}".format(
            execute_type, vocab_size, embedding_dim, units, max_length))
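# To see why the encoder above is built with units / 2, here is a minimal
# standalone check (illustrative): wrapping a GRU in Bidirectional with the
# default merge_mode='concat' concatenates the forward and backward outputs,
# doubling the feature dimension.
import tensorflow as tf

units = 512
bi_gru = tf.keras.layers.Bidirectional(
    tf.keras.layers.GRU(units // 2, return_sequences=True))
out = bi_gru(tf.random.normal((1, 10, 64)))   # input: (batch, time, features)
print(out.shape)                              # (1, 10, 512): 2 * (units // 2)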
def __init__(self, execute_type: str, checkpoint_dir: str, num_layers: int,
             units: int, d_model: int, num_heads: int, dropout: float,
             start_sign: str, end_sign: str, beam_size: int, vocab_size: int,
             dict_fn: str, max_length: int):
    """ Transformer chatter initialization; loads the model.
    Args:
        execute_type: dialogue execution mode
        checkpoint_dir: checkpoint save directory path
        num_layers: number of layers inside the transformer
        units: number of units
        d_model: embedding dimension
        num_heads: number of attention heads
        dropout: dropout rate
        start_sign: start token
        end_sign: end token
        beam_size: beam size
        vocab_size: vocabulary size
        dict_fn: dictionary save path
        max_length: maximum length of a single sentence
    Returns:
    """
    super().__init__(checkpoint_dir, beam_size, max_length)
    self.start_sign = start_sign
    self.end_sign = end_sign
    self.model = transformer.transformer(vocab_size=vocab_size, num_layers=num_layers,
                                         units=units, d_model=d_model,
                                         num_heads=num_heads, dropout=dropout)
    self.learning_rate = optimizers.CustomSchedule(d_model)
    self.optimizer = tf.keras.optimizers.Adam(self.learning_rate, beta_1=0.9,
                                              beta_2=0.98, epsilon=1e-9)
    self.train_loss = tf.keras.metrics.Mean(name='train_loss')
    self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name='train_accuracy')
    self.checkpoint = tf.train.Checkpoint(transformer=self.model,
                                          optimizer=self.optimizer)

    if execute_type == "chat":
        print('Loading the dictionary from "{}"...'.format(dict_fn))
        self.token = data_utils.load_token_dict(dict_fn=dict_fn)
    print('Checking for an existing checkpoint...')
    if self.ckpt:
        print('Checkpoint found, restoring from "{}"...'.format(checkpoint_dir))
        self.checkpoint.restore(
            tf.train.latest_checkpoint(checkpoint_dir)).expect_partial()
    else:
        if execute_type == "train":
            print('No checkpoint found, entering train mode...')
        else:
            print('No checkpoint found; please run train mode first, then enter chat mode')
            exit(0)

    utils.log_operator(level=10).info(
        "Starting the Transformer chatter, execute type: {}, model configuration: "
        "num_layers: {}, d_model: {}, num_heads: {}, units: {}, dropout: {}, "
        "vocab_size: {}, max_length: {}".format(
            execute_type, num_layers, d_model, num_heads, units,
            dropout, vocab_size, max_length))
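# The CustomSchedule above is presumably the warmup schedule from
# "Attention Is All You Need"; a minimal sketch of that standard schedule
# (an assumption about the repo's implementation, shown for reference):
#   lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
import tensorflow as tf


class CustomScheduleSketch(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, d_model, warmup_steps=4000):
        super().__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        # Linear warmup for warmup_steps, then inverse-sqrt decay.
        step = tf.cast(step, tf.float32)
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(
            tf.math.rsqrt(step), step * self.warmup_steps ** -1.5)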