import os
import time
import datetime

import tensorflow as tf

# Project-local modules; the names below are assumed from how they are
# referenced in this file (process, process_data, normal_param, cnn).
import process
import process_data
import normal_param
import cnn


def pro_data():
    '''Read the validation set.
    :return dev_data: list of validation samples
    :return dev_label: labels corresponding to dev_data
    :return dev_label_not_deal: raw labels before processing
    '''
    pro = process.process_data()
    pro.split_data_file(normal_param.dev_path)
    dev_data, dev_label, dev_label_not_deal = pro.deal_data(
        part=len(pro.all_text_path), n_part=0, need_label=True)
    return dev_data, dev_label, dev_label_not_deal
def load_corpus():
    process_data_init = process_data.process_data()
    process_data_init.split_data_file(normal_param.train_path)
    # Build (text, label) pairs from every collected file path.
    x_texts_labels = process_data_init.build_datas_and_labels(
        process_data_init.all_text_path)
    # Split each text into its space-separated tokens.
    x_texts = [x.split(" ") for x, label in x_texts_labels]
    labels = [label for x, label in x_texts_labels]
    return x_texts, labels
def pro_data(texts_path):
    '''Read the txt files to be predicted from the paths in texts_path.
    :return contents: list holding the content of every file to test
    '''
    pro = process.process_data()
    contents = []
    for text_path in texts_path:
        content = pro.load_data(text_path)
        contents.append(content)
    return contents
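# Usage sketch for pro_data (illustrative only; the file paths below are
# hypothetical examples, not paths from this repository):
#
#   test_files = ["data/test/doc_0.txt", "data/test/doc_1.txt"]
#   contents = pro_data(test_files)
#   print("loaded {} documents".format(len(contents)))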
def train():
    print("loading data...")
    process_data_init = process_data.process_data()
    process_data_init.split_data_file(normal_param.train_path)

    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=normal_param.allow_soft_placement,
            log_device_placement=normal_param.log_device_placement)
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            # Build the TextCNN model from the hyper-parameters in normal_param.
            cnn_init = cnn.model_cnn(
                sequence_length=normal_param.sequence_length,
                num_classes=normal_param.num_class,
                vocab_size=normal_param.vocal_size_train,
                embedding_size=normal_param.embedding_dim,
                filters_size=list(map(int, normal_param.filter_sizes.split(","))),
                num_filters=normal_param.num_filters)

            # Training procedure.
            global_step = tf.Variable(0, name="global_step", trainable=False)
            optimizer = tf.train.AdamOptimizer(1e-3)
            grads_and_vars = optimizer.compute_gradients(cnn_init.loss)
            train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

            out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", normal_param.model_name))
            print("Writing to {}\n".format(out_dir))

            # Summaries for loss and accuracy.
            loss_summary = tf.summary.scalar("loss", cnn_init.loss)
            acc_summary = tf.summary.scalar("accuracy", cnn_init.accuracy)

            train_summary_op = tf.summary.merge([loss_summary, acc_summary])
            train_summary_dir = os.path.join(out_dir, "summary", "train")
            train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

            dev_summary_op = tf.summary.merge([loss_summary, acc_summary])
            dev_summary_dir = os.path.join(out_dir, "summary", "dev")
            dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph)

            sess.run(tf.global_variables_initializer())

            max_acc = 0  # kept from the original code; not used below

            # Checkpointing: restore from the latest checkpoint if one exists.
            checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoint"))
            checkpoint_prefix = os.path.join(checkpoint_dir, "model")
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
            saver = tf.train.Saver(tf.global_variables())
            ckpt = tf.train.latest_checkpoint(checkpoint_dir)
            if ckpt:
                saver.restore(sess, ckpt)
                print("CNN restore from the checkpoint {0}".format(ckpt))

            def train_step(x_batch, y_batch, writer):
                '''Run a single training step on one batch.'''
                feed_dic = {
                    cnn_init.input_x: x_batch,
                    cnn_init.input_y: y_batch,
                    cnn_init.dropout_keep_prob: normal_param.dropout_keep_prob
                }
                _, step, summaries, loss, accuracy = sess.run(
                    [train_op, global_step, train_summary_op, cnn_init.loss, cnn_init.accuracy],
                    feed_dic)
                time_str = datetime.datetime.now().isoformat()
                print('{}: step {}, loss {:g}, acc {:g}'.format(time_str, step, loss, accuracy))
                if writer:
                    writer.add_summary(summaries, step)

            def dev_step(x_batch, y_batch, writer=None):
                '''Evaluate the model on the dev set (dropout disabled).'''
                feed_dic = {
                    cnn_init.input_x: x_batch,
                    cnn_init.input_y: y_batch,
                    cnn_init.dropout_keep_prob: 1.0
                }
                step, summaries, loss, accuracy = sess.run(
                    [global_step, dev_summary_op, cnn_init.loss, cnn_init.accuracy],
                    feed_dic)
                time_str = datetime.datetime.now().isoformat()
                print('{}: step {}, loss {:g}, acc {:g}'.format(time_str, step, loss, accuracy))
                if writer:
                    writer.add_summary(summaries, step)

            # Training loop: iterate over batches, evaluating and checkpointing
            # at the intervals configured in normal_param.
            batches = process_data_init.batch_iter(normal_param.batch_size, normal_param.num_epochs)
            for batch, dev_data, dev_label, is_save in batches:
                x_batch, y_batch = zip(*batch)
                train_step(x_batch, y_batch, train_summary_writer)
                current_step = tf.train.global_step(sess, global_step)
                if current_step % normal_param.evaluate_every == 0:
                    print("\nEvaluation:")
                    dev_step(dev_data, dev_label, writer=dev_summary_writer)
                if current_step % normal_param.checkpoint_every == 0:
                    path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    print("Saved model checkpoint to {}\n".format(path))
                if is_save:
                    path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    print("Saved model checkpoint to {}\n".format(path))
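# Minimal entry point (a sketch, not part of the original file): run training
# with the hyper-parameters defined in normal_param.
if __name__ == "__main__":
    train()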