import os

import tensorflow as tf

# Project-local dependencies (DataManager, generate_dataset, get_F1score,
# All_requestable_slots_order, RequestableSlotDector, InformableSlotDector)
# are assumed to be importable from the surrounding package.


def train_requestable(data_tmp_path):
    """Train the requestable-slot detectors and save the checkpoints.

    The models must be trained and saved before they can be used.
    Training uses early stopping.

    :param data_tmp_path: path to the data tmp directory
    """
    print('Loading data manager...')
    data_manager = DataManager(data_tmp_path)
    print('Loading training data...')
    informable_slot_datasets, requestable_slot_datasets = generate_dataset(
        data_manager.DialogData)
    print('Loading requestable slot detectors...')
    init_learning_rate = 0.005
    graph = tf.Graph()
    with graph.as_default():
        requestable_slots_models = {}
        for k, v in All_requestable_slots_order.items():
            requestable_slots_models[k] = RequestableSlotDector(
                str(v), learning_rate=init_learning_rate)
        with tf.Session(graph=graph,
                        config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            # saver.restore(sess, "./ckpt/requestable/model.ckpt")
            # Train one detector per requestable slot.
            requestable_slots_F1s = {}
            for slot, model in requestable_slots_models.items():
                average_loss = 0
                learning_rate = init_learning_rate
                best_F1 = 0
                tolerance = 30  # evaluations without improvement before stopping
                tolerance_count = 0
                display_step = 10
                for step in range(1, 5001):
                    batch_data, batch_output = requestable_slot_datasets[
                        slot].next_batch()
                    char_emb_matrix, word_emb_matrix, seqlen = data_manager.sent2num(
                        batch_data)
                    _, training_loss = sess.run(
                        [model.train_op, model.final_loss],
                        feed_dict={
                            model.char_emb_matrix: char_emb_matrix,
                            model.word_emb_matrix: word_emb_matrix,
                            model.output: batch_output
                        })
                    average_loss += training_loss / display_step
                    if step % display_step == 0:
                        # Evaluate on the held-out test set.
                        batch_data, batch_output = requestable_slot_datasets[
                            slot].get_testset()
                        char_emb_matrix, word_emb_matrix, seqlen = data_manager.sent2num(
                            batch_data)
                        pred, accu = sess.run(
                            [model.predict, model.accuracy],
                            feed_dict={
                                model.char_emb_matrix: char_emb_matrix,
                                model.word_emb_matrix: word_emb_matrix,
                                model.output: batch_output
                            })
                        F1 = get_F1score(batch_output, pred.tolist())
                        if best_F1 < F1:
                            # New best F1: reset patience and checkpoint the model.
                            best_F1 = F1
                            tolerance_count = 0
                            if not os.path.exists("./ckpt/requestable/"):
                                os.makedirs("./ckpt/requestable/")
                            saver.save(sess, "./ckpt/requestable/model.ckpt")
                        if tolerance_count == tolerance:
                            break  # early stopping: no improvement for `tolerance` evaluations
                        print("%s, step % 4d, loss %0.4f, F1 %0.4f, accu %0.4f" %
                              (slot, step, average_loss, F1, accu))
                        average_loss = 0
                        tolerance_count += 1
                        # Decay the learning rate, with a floor of 0.001.
                        learning_rate = max(learning_rate * 0.98, 0.001)
                        sess.run(model.update_lr,
                                 feed_dict={model.new_lr: learning_rate})
                print("requestable slot: %s, best F1 %0.4f" % (slot, best_F1))
                requestable_slots_F1s[slot] = best_F1
            print(requestable_slots_F1s)
            # Average F1 across all requestable slots.
            print(sum(requestable_slots_F1s.values()) /
                  len(requestable_slots_F1s))
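
# `get_F1score` is a project helper whose implementation is not part of this
# file. For reference only, a minimal hypothetical equivalent for binary 0/1
# labels might look like the sketch below (the name `_binary_f1_sketch` is
# illustrative, not the project's API):
def _binary_f1_sketch(labels, preds):
    """Plain binary F1: harmonic mean of precision and recall."""
    tp = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 1)
    fp = sum(1 for y, p in zip(labels, preds) if y == 0 and p == 1)
    fn = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return (2 * precision * recall / (precision + recall)
            if precision + recall else 0.0)
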
def train_informable(data_tmp_path):
    """Train the informable-slot detectors and save the checkpoints.

    The models must be trained and saved before they can be used.
    Training uses early stopping.

    :param data_tmp_path: path to the data tmp directory
    """
    print('Loading data manager...')
    data_manager = DataManager(data_tmp_path)
    print('Loading training data...')
    informable_slot_datasets, requestable_slot_datasets = generate_dataset(
        data_manager.DialogData)
    print('Loading informable slot detectors...')
    init_learning_rate = 0.005
    informable_batch_ratios = {  # minibatch ratio per slot
        "通话时长": [2, 8, 8, 8],  # call duration
        "流量": [4, 8, 8, 8],  # data volume
        "功能费": [4, 8, 8, 8]  # feature fee
    }
    graph = tf.Graph()
    with graph.as_default():
        informable_slots_models = {
            "功能费": InformableSlotDector('cost', learning_rate=init_learning_rate),
            "流量": InformableSlotDector('data', learning_rate=init_learning_rate),
            "通话时长": InformableSlotDector('time', learning_rate=init_learning_rate),
        }
        with tf.Session(graph=graph,
                        config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            # saver.restore(sess, "./ckpt/informable/model.ckpt")
            # Train one detector per informable slot.
            informable_slots_accus = []
            for slot, model in informable_slots_models.items():
                learning_rate = init_learning_rate
                average_loss = 0
                best_accu = 0
                tolerance = 20  # evaluations without improvement before stopping
                tolerance_count = 0
                display_step = 10
                for step in range(1, 5001):
                    batch_data, batch_output = informable_slot_datasets[
                        slot].next_batch(informable_batch_ratios[slot])
                    char_emb_matrix, word_emb_matrix, seqlen = data_manager.sent2num(
                        batch_data)
                    _, training_loss = sess.run(
                        [model.train_op, model.final_loss],
                        feed_dict={
                            model.char_emb_matrix: char_emb_matrix,
                            model.word_emb_matrix: word_emb_matrix,
                            model.output: batch_output
                        })
                    average_loss += training_loss / display_step
                    if step % display_step == 0:
                        # Evaluate on the held-out test set.
                        batch_data, batch_output = informable_slot_datasets[
                            slot].get_testset()
                        char_emb_matrix, word_emb_matrix, seqlen = data_manager.sent2num(
                            batch_data)
                        pred, accu = sess.run(
                            [model.predict, model.accuracy],
                            feed_dict={
                                model.char_emb_matrix: char_emb_matrix,
                                model.word_emb_matrix: word_emb_matrix,
                                model.output: batch_output
                            })
                        if best_accu < accu:
                            # New best accuracy: reset patience and checkpoint the model.
                            best_accu = accu
                            tolerance_count = 0
                            if not os.path.exists("./ckpt/informable/"):
                                os.makedirs("./ckpt/informable/")
                            saver.save(sess, "./ckpt/informable/model.ckpt")
                        if tolerance_count == tolerance:
                            break  # early stopping: no improvement for `tolerance` evaluations
                        print("%s, step % 4d, loss %0.4f, accu %0.4f" %
                              (slot, step, average_loss, accu))
                        average_loss = 0
                        tolerance_count += 1
                        # Decay the learning rate, with a floor of 0.0001.
                        learning_rate = max(learning_rate * 0.95, 0.0001)
                        sess.run(model.update_lr,
                                 feed_dict={model.new_lr: learning_rate})
                print("informable slot: %s, best accu %0.4f" % (slot, best_accu))
                informable_slots_accus.append(best_accu)
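
# Hypothetical usage sketch: both trainers take only the data tmp directory.
# The path './data/tmp' below is an assumption for illustration, not a path
# mandated by the project; substitute the actual preprocessed-data location.
if __name__ == '__main__':
    train_requestable('./data/tmp')
    train_informable('./data/tmp')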