def get_the_final_result():
    # Hyper-parameter configuration
    batch_size = 512
    seq_length = 20
    embeddings_size = 300
    hidden_size = 256
    num_layers = 2
    num_classes = 9
    learning_rate = 0.003
    dropout = 0.3
    # Data file paths
    word2vec_path = './data/word2vec.bin'
    train_file = './data/train.json'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Define the model
    model = TextRCNN(embeddings_size, num_classes, hidden_size, num_layers, True, dropout)
    model.to(device)
    # Load the trained model parameters
    checkpoints = torch.load('./saved_model/text_rcnn.pth')
    model.load_state_dict(checkpoints['model_state'])
    # Load the data
    data_loader = Dataloader(word2vec_path, batch_size, embeddings_size, seq_length, device)
    # Initialize the data iterator
    texts, labels = data_loader.load_data(train_file, shuffle=True, mode='train')  # load the data
    print('Data load completed...')
    # Evaluate on the test split (last 20% of the data)
    test_texts = texts[int(len(texts) * 0.8):]
    test_labels = labels[int(len(labels) * 0.8):]
    steps = len(test_texts) // batch_size
    loader = data_loader.data_iterator(test_texts, test_labels)
    # Accuracy on the test set
    accuracy = evaluate(model, loader, steps)
    print('The final result (Accuracy in Test) is %.2f' % (accuracy * 100))
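The evaluate() helper called above is not included in this snippet. A minimal sketch follows, assuming the iterator yields (texts, labels) tensors already on the target device and that the model returns per-class logits; it is an illustration, not the original implementation.

# Hypothetical sketch of the evaluate() helper used above: averages accuracy
# over a fixed number of batches drawn from the data iterator.
import torch

def evaluate(model, loader, steps):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for _ in range(steps):
            texts, labels = next(loader)      # assumed: tensors on the same device as the model
            logits = model(texts)             # assumed shape: (batch, num_classes)
            preds = logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    model.train()
    return correct / max(total, 1)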
def train_rcnn(): """Training RCNN model.""" # Load sentences, labels, and training parameters logger.info("✔︎ Loading data...") logger.info("✔︎ Training data processing...") train_data = dh.load_data_and_labels(FLAGS.training_data_file, FLAGS.num_classes, FLAGS.embedding_dim, data_aug_flag=False) logger.info("✔︎ Validation data processing...") val_data = dh.load_data_and_labels(FLAGS.validation_data_file, FLAGS.num_classes, FLAGS.embedding_dim, data_aug_flag=False) logger.info("Recommended padding Sequence length is: {0}".format(FLAGS.pad_seq_len)) logger.info("✔︎ Training data padding...") x_train, y_train = dh.pad_data(train_data, FLAGS.pad_seq_len) logger.info("✔︎ Validation data padding...") x_val, y_val = dh.pad_data(val_data, FLAGS.pad_seq_len) # Build vocabulary VOCAB_SIZE = dh.load_vocab_size(FLAGS.embedding_dim) pretrained_word2vec_matrix = dh.load_word2vec_matrix(VOCAB_SIZE, FLAGS.embedding_dim) # Build a graph and rcnn object with tf.Graph().as_default(): session_conf = tf.ConfigProto( allow_soft_placement=FLAGS.allow_soft_placement, log_device_placement=FLAGS.log_device_placement) session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth sess = tf.Session(config=session_conf) with sess.as_default(): rcnn = TextRCNN( sequence_length=FLAGS.pad_seq_len, num_classes=FLAGS.num_classes, vocab_size=VOCAB_SIZE, lstm_hidden_size=FLAGS.lstm_hidden_size, fc_hidden_size=FLAGS.fc_hidden_size, embedding_size=FLAGS.embedding_dim, embedding_type=FLAGS.embedding_type, filter_sizes=list(map(int, FLAGS.filter_sizes.split(','))), num_filters=FLAGS.num_filters, l2_reg_lambda=FLAGS.l2_reg_lambda, pretrained_embedding=pretrained_word2vec_matrix) # Define training procedure with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): learning_rate = tf.train.exponential_decay(learning_rate=FLAGS.learning_rate, global_step=rcnn.global_step, decay_steps=FLAGS.decay_steps, decay_rate=FLAGS.decay_rate, staircase=True) optimizer = tf.train.AdamOptimizer(learning_rate) grads, vars = zip(*optimizer.compute_gradients(rcnn.loss)) grads, _ = tf.clip_by_global_norm(grads, clip_norm=FLAGS.norm_ratio) train_op = optimizer.apply_gradients(zip(grads, vars), global_step=rcnn.global_step, name="train_op") # Keep track of gradient values and sparsity (optional) grad_summaries = [] for g, v in zip(grads, vars): if g is not None: grad_hist_summary = tf.summary.histogram("{0}/grad/hist".format(v.name), g) sparsity_summary = tf.summary.scalar("{0}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g)) grad_summaries.append(grad_hist_summary) grad_summaries.append(sparsity_summary) grad_summaries_merged = tf.summary.merge(grad_summaries) # Output directory for models and summaries if FLAGS.train_or_restore == 'R': MODEL = input("☛ Please input the checkpoints model you want to restore, " "it should be like(1490175368): ") # The model you want to restore while not (MODEL.isdigit() and len(MODEL) == 10): MODEL = input("✘ The format of your input is illegal, please re-input: ") logger.info("✔︎ The format of your input is legal, now loading to next step...") out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", MODEL)) logger.info("✔︎ Writing to {0}\n".format(out_dir)) else: timestamp = str(int(time.time())) out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp)) logger.info("✔︎ Writing to {0}\n".format(out_dir)) checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints")) best_checkpoint_dir = os.path.abspath(os.path.join(out_dir, "bestcheckpoints")) # Summaries 
for loss loss_summary = tf.summary.scalar("loss", rcnn.loss) # Train summaries train_summary_op = tf.summary.merge([loss_summary, grad_summaries_merged]) train_summary_dir = os.path.join(out_dir, "summaries", "train") train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph) # Validation summaries validation_summary_op = tf.summary.merge([loss_summary]) validation_summary_dir = os.path.join(out_dir, "summaries", "validation") validation_summary_writer = tf.summary.FileWriter(validation_summary_dir, sess.graph) saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints) best_saver = cm.BestCheckpointSaver(save_dir=best_checkpoint_dir, num_to_keep=3, maximize=True) if FLAGS.train_or_restore == 'R': # Load rcnn model logger.info("✔︎ Loading model...") checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir) logger.info(checkpoint_file) # Load the saved meta graph and restore variables saver = tf.train.import_meta_graph("{0}.meta".format(checkpoint_file)) saver.restore(sess, checkpoint_file) else: if not os.path.exists(checkpoint_dir): os.makedirs(checkpoint_dir) sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) # Embedding visualization config config = projector.ProjectorConfig() embedding_conf = config.embeddings.add() embedding_conf.tensor_name = "embedding" embedding_conf.metadata_path = FLAGS.metadata_file projector.visualize_embeddings(train_summary_writer, config) projector.visualize_embeddings(validation_summary_writer, config) # Save the embedding visualization saver.save(sess, os.path.join(out_dir, "embedding", "embedding.ckpt")) current_step = sess.run(rcnn.global_step) def train_step(x_batch, y_batch): """A single training step""" feed_dict = { rcnn.input_x: x_batch, rcnn.input_y: y_batch, rcnn.dropout_keep_prob: FLAGS.dropout_keep_prob, rcnn.is_training: True } _, step, summaries, loss = sess.run( [train_op, rcnn.global_step, train_summary_op, rcnn.loss], feed_dict) logger.info("step {0}: loss {1:g}".format(step, loss)) train_summary_writer.add_summary(summaries, step) def validation_step(x_val, y_val, writer=None): """Evaluates model on a validation set""" batches_validation = dh.batch_iter(list(zip(x_val, y_val)), FLAGS.batch_size, 1) # Predict classes by threshold or topk ('ts': threshold; 'tk': topk) eval_counter, eval_loss = 0, 0.0 eval_pre_tk = [0.0] * FLAGS.top_num eval_rec_tk = [0.0] * FLAGS.top_num eval_F_tk = [0.0] * FLAGS.top_num true_onehot_labels = [] predicted_onehot_scores = [] predicted_onehot_labels_ts = [] predicted_onehot_labels_tk = [[] for _ in range(FLAGS.top_num)] for batch_validation in batches_validation: x_batch_val, y_batch_val = zip(*batch_validation) feed_dict = { rcnn.input_x: x_batch_val, rcnn.input_y: y_batch_val, rcnn.dropout_keep_prob: 1.0, rcnn.is_training: False } step, summaries, scores, cur_loss = sess.run( [rcnn.global_step, validation_summary_op, rcnn.scores, rcnn.loss], feed_dict) # Prepare for calculating metrics for i in y_batch_val: true_onehot_labels.append(i) for j in scores: predicted_onehot_scores.append(j) # Predict by threshold batch_predicted_onehot_labels_ts = \ dh.get_onehot_label_threshold(scores=scores, threshold=FLAGS.threshold) for k in batch_predicted_onehot_labels_ts: predicted_onehot_labels_ts.append(k) # Predict by topK for top_num in range(FLAGS.top_num): batch_predicted_onehot_labels_tk = dh.get_onehot_label_topk(scores=scores, top_num=top_num+1) for i in batch_predicted_onehot_labels_tk: predicted_onehot_labels_tk[top_num].append(i) 
eval_loss = eval_loss + cur_loss eval_counter = eval_counter + 1 if writer: writer.add_summary(summaries, step) eval_loss = float(eval_loss / eval_counter) # Calculate Precision & Recall & F1 (threshold & topK) eval_pre_ts = precision_score(y_true=np.array(true_onehot_labels), y_pred=np.array(predicted_onehot_labels_ts), average='micro') eval_rec_ts = recall_score(y_true=np.array(true_onehot_labels), y_pred=np.array(predicted_onehot_labels_ts), average='micro') eval_F_ts = f1_score(y_true=np.array(true_onehot_labels), y_pred=np.array(predicted_onehot_labels_ts), average='micro') for top_num in range(FLAGS.top_num): eval_pre_tk[top_num] = precision_score(y_true=np.array(true_onehot_labels), y_pred=np.array(predicted_onehot_labels_tk[top_num]), average='micro') eval_rec_tk[top_num] = recall_score(y_true=np.array(true_onehot_labels), y_pred=np.array(predicted_onehot_labels_tk[top_num]), average='micro') eval_F_tk[top_num] = f1_score(y_true=np.array(true_onehot_labels), y_pred=np.array(predicted_onehot_labels_tk[top_num]), average='micro') # Calculate the average AUC eval_auc = roc_auc_score(y_true=np.array(true_onehot_labels), y_score=np.array(predicted_onehot_scores), average='micro') # Calculate the average PR eval_prc = average_precision_score(y_true=np.array(true_onehot_labels), y_score=np.array(predicted_onehot_scores), average='micro') return eval_loss, eval_auc, eval_prc, eval_rec_ts, eval_pre_ts, eval_F_ts, \ eval_rec_tk, eval_pre_tk, eval_F_tk # Generate batches batches_train = dh.batch_iter( list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs) num_batches_per_epoch = int((len(x_train) - 1) / FLAGS.batch_size) + 1 # Training loop. For each batch... for batch_train in batches_train: x_batch_train, y_batch_train = zip(*batch_train) train_step(x_batch_train, y_batch_train) current_step = tf.train.global_step(sess, rcnn.global_step) if current_step % FLAGS.evaluate_every == 0: logger.info("\nEvaluation:") eval_loss, eval_auc, eval_prc, \ eval_rec_ts, eval_pre_ts, eval_F_ts, eval_rec_tk, eval_pre_tk, eval_F_tk = \ validation_step(x_val, y_val, writer=validation_summary_writer) logger.info("All Validation set: Loss {0:g} | AUC {1:g} | AUPRC {2:g}" .format(eval_loss, eval_auc, eval_prc)) # Predict by threshold logger.info("☛ Predict by threshold: Precision {0:g}, Recall {1:g}, F {2:g}" .format(eval_pre_ts, eval_rec_ts, eval_F_ts)) # Predict by topK logger.info("☛ Predict by topK:") for top_num in range(FLAGS.top_num): logger.info("Top{0}: Precision {1:g}, Recall {2:g}, F {3:g}" .format(top_num+1, eval_pre_tk[top_num], eval_rec_tk[top_num], eval_F_tk[top_num])) best_saver.handle(eval_prc, sess, current_step) if current_step % FLAGS.checkpoint_every == 0: checkpoint_prefix = os.path.join(checkpoint_dir, "model") path = saver.save(sess, checkpoint_prefix, global_step=current_step) logger.info("✔︎ Saved model checkpoint to {0}\n".format(path)) if current_step % num_batches_per_epoch == 0: current_epoch = current_step // num_batches_per_epoch logger.info("✔︎ Epoch {0} has finished!".format(current_epoch)) logger.info("✔︎ Done.")
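The validation step above relies on dh.get_onehot_label_threshold and dh.get_onehot_label_topk, which are defined elsewhere. The NumPy sketch below shows what they are assumed to compute (multi-hot predictions from per-class scores); it is a guess at the behaviour, not the project's actual code.

# Assumed behaviour of dh.get_onehot_label_threshold / dh.get_onehot_label_topk (sketch only).
import numpy as np

def get_onehot_label_threshold(scores, threshold=0.5):
    """Mark every class whose score reaches the threshold as predicted (multi-label)."""
    scores = np.asarray(scores)
    return (scores >= threshold).astype(int).tolist()

def get_onehot_label_topk(scores, top_num=1):
    """Mark the top_num highest-scoring classes of each sample as predicted."""
    scores = np.asarray(scores)
    onehot = np.zeros_like(scores, dtype=int)
    topk_indices = np.argsort(-scores, axis=-1)[:, :top_num]
    for row, cols in enumerate(topk_indices):
        onehot[row, cols] = 1
    return onehot.tolist()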
def train_rcnn(): """Training RCNN model.""" # Print parameters used for the model dh.tab_printer(args, logger) # Load sentences, labels, and training parameters logger.info("Loading data...") logger.info("Data processing...") train_data = dh.load_data_and_labels(args.train_file, args.word2vec_file) validation_data = dh.load_data_and_labels(args.validation_file, args.word2vec_file) logger.info("Data padding...") x_train_front, x_train_behind, y_train = dh.pad_data( train_data, args.pad_seq_len) x_validation_front, x_validation_behind, y_validation = dh.pad_data( validation_data, args.pad_seq_len) # Build vocabulary VOCAB_SIZE, EMBEDDING_SIZE, pretrained_word2vec_matrix = dh.load_word2vec_matrix( args.word2vec_file) # Build a graph and rcnn object with tf.Graph().as_default(): session_conf = tf.ConfigProto( allow_soft_placement=args.allow_soft_placement, log_device_placement=args.log_device_placement) session_conf.gpu_options.allow_growth = args.gpu_options_allow_growth sess = tf.Session(config=session_conf) with sess.as_default(): rcnn = TextRCNN(sequence_length=args.pad_seq_len, vocab_size=VOCAB_SIZE, embedding_type=args.embedding_type, embedding_size=EMBEDDING_SIZE, lstm_hidden_size=args.lstm_dim, filter_sizes=args.filter_sizes, num_filters=args.num_filters, fc_hidden_size=args.fc_dim, num_classes=y_train.shape[1], l2_reg_lambda=args.l2_lambda, pretrained_embedding=pretrained_word2vec_matrix) # Define training procedure with tf.control_dependencies( tf.get_collection(tf.GraphKeys.UPDATE_OPS)): learning_rate = tf.train.exponential_decay( learning_rate=args.learning_rate, global_step=rcnn.global_step, decay_steps=args.decay_steps, decay_rate=args.decay_rate, staircase=True) optimizer = tf.train.AdamOptimizer(learning_rate) grads, vars = zip(*optimizer.compute_gradients(rcnn.loss)) grads, _ = tf.clip_by_global_norm(grads, clip_norm=args.norm_ratio) train_op = optimizer.apply_gradients( zip(grads, vars), global_step=rcnn.global_step, name="train_op") # Keep track of gradient values and sparsity (optional) grad_summaries = [] for g, v in zip(grads, vars): if g is not None: grad_hist_summary = tf.summary.histogram( "{0}/grad/hist".format(v.name), g) sparsity_summary = tf.summary.scalar( "{0}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g)) grad_summaries.append(grad_hist_summary) grad_summaries.append(sparsity_summary) grad_summaries_merged = tf.summary.merge(grad_summaries) # Output directory for models and summaries out_dir = dh.get_out_dir(OPTION, logger) checkpoint_dir = os.path.abspath( os.path.join(out_dir, "checkpoints")) best_checkpoint_dir = os.path.abspath( os.path.join(out_dir, "bestcheckpoints")) # Summaries for loss loss_summary = tf.summary.scalar("loss", rcnn.loss) # Train summaries train_summary_op = tf.summary.merge( [loss_summary, grad_summaries_merged]) train_summary_dir = os.path.join(out_dir, "summaries", "train") train_summary_writer = tf.summary.FileWriter( train_summary_dir, sess.graph) # Validation summaries validation_summary_op = tf.summary.merge([loss_summary]) validation_summary_dir = os.path.join(out_dir, "summaries", "validation") validation_summary_writer = tf.summary.FileWriter( validation_summary_dir, sess.graph) saver = tf.train.Saver(tf.global_variables(), max_to_keep=args.num_checkpoints) best_saver = cm.BestCheckpointSaver(save_dir=best_checkpoint_dir, num_to_keep=3, maximize=True) if OPTION == 'R': # Load rcnn model logger.info("Loading model...") checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir) logger.info(checkpoint_file) # Load the 
saved meta graph and restore variables saver = tf.train.import_meta_graph( "{0}.meta".format(checkpoint_file)) saver.restore(sess, checkpoint_file) if OPTION == 'T': if not os.path.exists(checkpoint_dir): os.makedirs(checkpoint_dir) sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) # Embedding visualization config config = projector.ProjectorConfig() embedding_conf = config.embeddings.add() embedding_conf.tensor_name = "embedding" embedding_conf.metadata_path = args.metadata_file projector.visualize_embeddings(train_summary_writer, config) projector.visualize_embeddings(validation_summary_writer, config) # Save the embedding visualization saver.save( sess, os.path.join(out_dir, "embedding", "embedding.ckpt")) current_step = sess.run(rcnn.global_step) def train_step(x_batch_front, x_batch_behind, y_batch): """A single training step""" feed_dict = { rcnn.input_x_front: x_batch_front, rcnn.input_x_behind: x_batch_behind, rcnn.input_y: y_batch, rcnn.dropout_keep_prob: args.dropout_rate, rcnn.is_training: True } _, step, summaries, loss = sess.run( [train_op, rcnn.global_step, train_summary_op, rcnn.loss], feed_dict) logger.info("step {0}: loss {1:g}".format(step, loss)) train_summary_writer.add_summary(summaries, step) def validation_step(x_batch_front, x_batch_behind, y_batch, writer=None): """Evaluates model on a validation set""" batches_validation = dh.batch_iter( list(zip(x_batch_front, x_batch_behind, y_batch)), args.batch_size, 1) eval_counter, eval_loss = 0, 0.0 true_labels = [] predicted_scores = [] predicted_labels = [] for batch_validation in batches_validation: x_batch_val_front, x_batch_val_behind, y_batch_val = zip( *batch_validation) feed_dict = { rcnn.input_x_front: x_batch_val_front, rcnn.input_x_behind: x_batch_val_behind, rcnn.input_y: y_batch_val, rcnn.dropout_keep_prob: 1.0, rcnn.is_training: False } step, summaries, scores, predictions, cur_loss = sess.run([ rcnn.global_step, validation_summary_op, rcnn.topKPreds, rcnn.predictions, rcnn.loss ], feed_dict) # Prepare for calculating metrics for i in y_batch_val: true_labels.append(np.argmax(i)) for j in scores[0]: predicted_scores.append(j[0]) for k in predictions: predicted_labels.append(k) eval_loss = eval_loss + cur_loss eval_counter = eval_counter + 1 if writer: writer.add_summary(summaries, step) eval_loss = float(eval_loss / eval_counter) # Calculate Precision & Recall & F1 eval_acc = accuracy_score(y_true=np.array(true_labels), y_pred=np.array(predicted_labels)) eval_pre = precision_score(y_true=np.array(true_labels), y_pred=np.array(predicted_labels), average='micro') eval_rec = recall_score(y_true=np.array(true_labels), y_pred=np.array(predicted_labels), average='micro') eval_F1 = f1_score(y_true=np.array(true_labels), y_pred=np.array(predicted_labels), average='micro') # Calculate the average AUC eval_auc = roc_auc_score(y_true=np.array(true_labels), y_score=np.array(predicted_scores), average='micro') return eval_loss, eval_acc, eval_pre, eval_rec, eval_F1, eval_auc # Generate batches batches_train = dh.batch_iter( list(zip(x_train_front, x_train_behind, y_train)), args.batch_size, args.epochs) num_batches_per_epoch = int( (len(x_train_front) - 1) / args.batch_size) + 1 # Training loop. For each batch... 
for batch_train in batches_train: x_batch_front, x_batch_behind, y_batch = zip(*batch_train) train_step(x_batch_front, x_batch_behind, y_batch) current_step = tf.train.global_step(sess, rcnn.global_step) if current_step % args.evaluate_steps == 0: logger.info("\nEvaluation:") eval_loss, eval_acc, eval_pre, eval_rec, eval_F1, eval_auc = \ validation_step(x_validation_front, x_validation_behind, y_validation, writer=validation_summary_writer) logger.info( "All Validation set: Loss {0:g} | Acc {1:g} | Precision {2:g} | " "Recall {3:g} | F1 {4:g} | AUC {5:g}".format( eval_loss, eval_acc, eval_pre, eval_rec, eval_F1, eval_auc)) best_saver.handle(eval_acc, sess, current_step) if current_step % args.checkpoint_steps == 0: checkpoint_prefix = os.path.join(checkpoint_dir, "model") path = saver.save(sess, checkpoint_prefix, global_step=current_step) logger.info("Saved model checkpoint to {0}\n".format(path)) if current_step % num_batches_per_epoch == 0: current_epoch = current_step // num_batches_per_epoch logger.info( "Epoch {0} has finished!".format(current_epoch)) logger.info("All Done.")
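This variant delegates run-directory handling to dh.get_out_dir(OPTION, logger). A sketch is given below under the assumption that it mirrors the train-or-restore logic written out explicitly in the other snippets (prompt for a 10-digit run timestamp when restoring, otherwise create a fresh timestamped directory).

# Sketch of dh.get_out_dir(), assumed to mirror the logic in the other snippets.
import os
import time

def get_out_dir(option, logger):
    if option == 'R':
        model = input("Please input the checkpoints model you want to restore, "
                      "it should be like (1490175368): ")
        while not (model.isdigit() and len(model) == 10):
            model = input("The format of your input is illegal, please re-input: ")
        out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", model))
    else:  # option == 'T': start a fresh timestamped run directory
        timestamp = str(int(time.time()))
        out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
    logger.info("Writing to {0}\n".format(out_dir))
    return out_dir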
def train_rcnn(): """Training RCNN model.""" # Load sentences, labels, and training parameters logger.info('✔︎ Loading data...') logger.info('✔︎ Training data processing...') train_data = dh.load_data_and_labels(FLAGS.training_data_file, FLAGS.num_classes, FLAGS.embedding_dim) logger.info('✔︎ Validation data processing...') validation_data = \ dh.load_data_and_labels(FLAGS.validation_data_file, FLAGS.num_classes, FLAGS.embedding_dim) logger.info('Recommended padding Sequence length is: {0}'.format( FLAGS.pad_seq_len)) logger.info('✔︎ Training data padding...') x_train, y_train = dh.pad_data(train_data, FLAGS.pad_seq_len) logger.info('✔︎ Validation data padding...') x_validation, y_validation = dh.pad_data(validation_data, FLAGS.pad_seq_len) # Build vocabulary VOCAB_SIZE = dh.load_vocab_size(FLAGS.embedding_dim) pretrained_word2vec_matrix = dh.load_word2vec_matrix( VOCAB_SIZE, FLAGS.embedding_dim) # Build a graph and rcnn object with tf.Graph().as_default(): session_conf = tf.ConfigProto( allow_soft_placement=FLAGS.allow_soft_placement, log_device_placement=FLAGS.log_device_placement) session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth sess = tf.Session(config=session_conf) with sess.as_default(): rcnn = TextRCNN(sequence_length=FLAGS.pad_seq_len, num_classes=FLAGS.num_classes, batch_size=FLAGS.batch_size, vocab_size=VOCAB_SIZE, embedding_size=FLAGS.embedding_dim, embedding_type=FLAGS.embedding_type, l2_reg_lambda=FLAGS.l2_reg_lambda, pretrained_embedding=pretrained_word2vec_matrix) # Define training procedure with tf.control_dependencies( tf.get_collection(tf.GraphKeys.UPDATE_OPS)): learning_rate = tf.train.exponential_decay( learning_rate=FLAGS.learning_rate, global_step=rcnn.global_step, decay_steps=FLAGS.decay_steps, decay_rate=FLAGS.decay_rate, staircase=True) optimizer = tf.train.AdamOptimizer(learning_rate) grads, vars = zip(*optimizer.compute_gradients(rcnn.loss)) grads, _ = tf.clip_by_global_norm(grads, clip_norm=FLAGS.norm_ratio) train_op = optimizer.apply_gradients( zip(grads, vars), global_step=rcnn.global_step, name="train_op") # Keep track of gradient values and sparsity (optional) grad_summaries = [] for g, v in zip(grads, vars): if g is not None: grad_hist_summary = tf.summary.histogram( "{0}/grad/hist".format(v.name), g) sparsity_summary = tf.summary.scalar( "{0}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g)) grad_summaries.append(grad_hist_summary) grad_summaries.append(sparsity_summary) grad_summaries_merged = tf.summary.merge(grad_summaries) # Output directory for models and summaries if FLAGS.train_or_restore == 'R': MODEL = input( "☛ Please input the checkpoints model you want to restore, " "it should be like(1490175368): " ) # The model you want to restore while not (MODEL.isdigit() and len(MODEL) == 10): MODEL = input( '✘ The format of your input is illegal, please re-input: ' ) logger.info( '✔︎ The format of your input is legal, now loading to next step...' 
) checkpoint_dir = 'runs/' + MODEL + '/checkpoints/' out_dir = os.path.abspath( os.path.join(os.path.curdir, "runs", MODEL)) logger.info("✔︎ Writing to {0}\n".format(out_dir)) else: timestamp = str(int(time.time())) out_dir = os.path.abspath( os.path.join(os.path.curdir, "runs", timestamp)) logger.info("✔︎ Writing to {0}\n".format(out_dir)) # Summaries for loss and accuracy loss_summary = tf.summary.scalar("loss", rcnn.loss) # Train summaries train_summary_op = tf.summary.merge( [loss_summary, grad_summaries_merged]) train_summary_dir = os.path.join(out_dir, "summaries", "train") train_summary_writer = tf.summary.FileWriter( train_summary_dir, sess.graph) # Validation summaries validation_summary_op = tf.summary.merge([loss_summary]) validation_summary_dir = os.path.join(out_dir, "summaries", "validation") validation_summary_writer = tf.summary.FileWriter( validation_summary_dir, sess.graph) saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints) if FLAGS.train_or_restore == 'R': # Load rcnn model logger.info("✔ Loading model...") checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir) logger.info(checkpoint_file) # Load the saved meta graph and restore variables saver = tf.train.import_meta_graph( "{0}.meta".format(checkpoint_file)) saver.restore(sess, checkpoint_file) else: checkpoint_dir = os.path.abspath( os.path.join(out_dir, "checkpoints")) if not os.path.exists(checkpoint_dir): os.makedirs(checkpoint_dir) sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) # Embedding visualization config config = projector.ProjectorConfig() embedding_conf = config.embeddings.add() embedding_conf.tensor_name = 'embedding' embedding_conf.metadata_path = FLAGS.metadata_file projector.visualize_embeddings(train_summary_writer, config) projector.visualize_embeddings(validation_summary_writer, config) # Save the embedding visualization saver.save( sess, os.path.join(out_dir, 'embedding', 'embedding.ckpt')) current_step = sess.run(rcnn.global_step) def batch_iter(data, batch_size, num_epochs, shuffle=True): """Override of <batch_iter> from data_helpers.py: the original version may yield a final batch smaller than batch_size, but rcnn requires every batch to contain exactly batch_size samples (otherwise it raises an error), so the last incomplete batch is dropped here.""" data = np.array(data) data_size = len(data) # The only difference is num_batches_per_epoch: do not add one here, because the last batch must be dropped when it does not contain exactly batch_size samples num_batches_per_epoch = int((data_size - 1) / batch_size) for epoch in range(num_epochs): # Shuffle the data at each epoch if shuffle: shuffle_indices = np.random.permutation( np.arange(data_size)) shuffled_data = data[shuffle_indices] else: shuffled_data = data for batch_num in range(num_batches_per_epoch): start_index = batch_num * batch_size end_index = min((batch_num + 1) * batch_size, data_size) yield shuffled_data[start_index:end_index] def train_step(x_batch, y_batch): """A single training step""" feed_dict = { rcnn.input_x: x_batch, rcnn.input_y: y_batch, rcnn.dropout_keep_prob: FLAGS.dropout_keep_prob, rcnn.is_training: True } _, step, summaries, loss = sess.run( [train_op, rcnn.global_step, train_summary_op, rcnn.loss], feed_dict) logger.info("step {0}: loss {1:g}".format(step, loss)) train_summary_writer.add_summary(summaries, step) def validation_step(x_validation, y_validation, writer=None): """Evaluates model on a validation set""" batches_validation = batch_iter( 
list(zip(x_validation, y_validation)), FLAGS.batch_size, FLAGS.num_epochs) # Predict classes by threshold or topk ('ts': threshold; 'tk': topk) eval_counter, eval_loss, eval_rec_ts, eval_acc_ts, eval_F_ts = 0, 0.0, 0.0, 0.0, 0.0 eval_rec_tk = [0.0] * FLAGS.top_num eval_acc_tk = [0.0] * FLAGS.top_num eval_F_tk = [0.0] * FLAGS.top_num for batch_validation in batches_validation: x_batch_validation, y_batch_validation = zip( *batch_validation) feed_dict = { rcnn.input_x: x_batch_validation, rcnn.input_y: y_batch_validation, rcnn.dropout_keep_prob: 1.0, rcnn.is_training: False } step, summaries, scores, cur_loss = sess.run([ rcnn.global_step, validation_summary_op, rcnn.scores, rcnn.loss ], feed_dict) # Predict by threshold predicted_labels_threshold, predicted_values_threshold = \ dh.get_label_using_scores_by_threshold(scores=scores, threshold=FLAGS.threshold) cur_rec_ts, cur_acc_ts, cur_F_ts = 0.0, 0.0, 0.0 for index, predicted_label_threshold in enumerate( predicted_labels_threshold): rec_inc_ts, acc_inc_ts, F_inc_ts = dh.cal_metric( predicted_label_threshold, y_batch_validation[index]) cur_rec_ts, cur_acc_ts, cur_F_ts = cur_rec_ts + rec_inc_ts, \ cur_acc_ts + acc_inc_ts, \ cur_F_ts + F_inc_ts cur_rec_ts = cur_rec_ts / len(y_batch_validation) cur_acc_ts = cur_acc_ts / len(y_batch_validation) cur_F_ts = cur_F_ts / len(y_batch_validation) eval_rec_ts, eval_acc_ts, eval_F_ts = eval_rec_ts + cur_rec_ts, \ eval_acc_ts + cur_acc_ts, \ eval_F_ts + cur_F_ts # Predict by topK topK_predicted_labels = [] for top_num in range(FLAGS.top_num): predicted_labels_topk, predicted_values_topk = \ dh.get_label_using_scores_by_topk(scores=scores, top_num=top_num+1) topK_predicted_labels.append(predicted_labels_topk) cur_rec_tk = [0.0] * FLAGS.top_num cur_acc_tk = [0.0] * FLAGS.top_num cur_F_tk = [0.0] * FLAGS.top_num for top_num, predicted_labels_topK in enumerate( topK_predicted_labels): for index, predicted_label_topK in enumerate( predicted_labels_topK): rec_inc_tk, acc_inc_tk, F_inc_tk = dh.cal_metric( predicted_label_topK, y_batch_validation[index]) cur_rec_tk[top_num], cur_acc_tk[top_num], cur_F_tk[top_num] = \ cur_rec_tk[top_num] + rec_inc_tk, \ cur_acc_tk[top_num] + acc_inc_tk, \ cur_F_tk[top_num] + F_inc_tk cur_rec_tk[top_num] = cur_rec_tk[top_num] / len( y_batch_validation) cur_acc_tk[top_num] = cur_acc_tk[top_num] / len( y_batch_validation) cur_F_tk[top_num] = cur_F_tk[top_num] / len( y_batch_validation) eval_rec_tk[top_num], eval_acc_tk[top_num], eval_F_tk[top_num] = \ eval_rec_tk[top_num] + cur_rec_tk[top_num], \ eval_acc_tk[top_num] + cur_acc_tk[top_num], \ eval_F_tk[top_num] + cur_F_tk[top_num] eval_loss = eval_loss + cur_loss eval_counter = eval_counter + 1 logger.info("✔︎ validation batch {0}: loss {1:g}".format( eval_counter, cur_loss)) logger.info( "︎☛ Predict by threshold: recall {0:g}, accuracy {1:g}, F {2:g}" .format(cur_rec_ts, cur_acc_ts, cur_F_ts)) logger.info("︎☛ Predict by topK:") for top_num in range(FLAGS.top_num): logger.info( "Top{0}: recall {1:g}, accuracy {2:g}, F {3:g}". 
format(top_num + 1, cur_rec_tk[top_num], cur_acc_tk[top_num], cur_F_tk[top_num])) if writer: writer.add_summary(summaries, step) eval_loss = float(eval_loss / eval_counter) eval_rec_ts = float(eval_rec_ts / eval_counter) eval_acc_ts = float(eval_acc_ts / eval_counter) eval_F_ts = float(eval_F_ts / eval_counter) for top_num in range(FLAGS.top_num): eval_rec_tk[top_num] = float(eval_rec_tk[top_num] / eval_counter) eval_acc_tk[top_num] = float(eval_acc_tk[top_num] / eval_counter) eval_F_tk[top_num] = float(eval_F_tk[top_num] / eval_counter) return eval_loss, eval_rec_ts, eval_acc_ts, eval_F_ts, eval_rec_tk, eval_acc_tk, eval_F_tk # Generate batches batches_train = batch_iter(list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs) num_batches_per_epoch = int( (len(x_train) - 1) / FLAGS.batch_size) + 1 # Training loop. For each batch... for batch_train in batches_train: x_batch_train, y_batch_train = zip(*batch_train) train_step(x_batch_train, y_batch_train) current_step = tf.train.global_step(sess, rcnn.global_step) if current_step % FLAGS.evaluate_every == 0: logger.info("\nEvaluation:") eval_loss, eval_rec_ts, eval_acc_ts, eval_F_ts, eval_rec_tk, eval_acc_tk, eval_F_tk = \ validation_step(x_validation, y_validation, writer=validation_summary_writer) logger.info( "All Validation set: Loss {0:g}".format(eval_loss)) # Predict by threshold logger.info( "︎☛ Predict by threshold: Recall {0:g}, Accuracy {1:g}, F {2:g}" .format(eval_rec_ts, eval_acc_ts, eval_F_ts)) # Predict by topK logger.info("︎☛ Predict by topK:") for top_num in range(FLAGS.top_num): logger.info( "Top{0}: Recall {1:g}, Accuracy {2:g}, F {3:g}". format(top_num + 1, eval_rec_tk[top_num], eval_acc_tk[top_num], eval_F_tk[top_num])) if current_step % FLAGS.checkpoint_every == 0: checkpoint_prefix = os.path.join(checkpoint_dir, "model") path = saver.save(sess, checkpoint_prefix, global_step=current_step) logger.info( "✔︎ Saved model checkpoint to {0}\n".format(path)) if current_step % num_batches_per_epoch == 0: current_epoch = current_step // num_batches_per_epoch logger.info( "✔︎ Epoch {0} has finished!".format(current_epoch)) logger.info("✔︎ Done.")
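The per-sample metric helper dh.cal_metric used in the validation step above is external. The sketch below shows one plausible reading of the (recall, accuracy, F) triple it returns for a single sample (hits over true labels, hits over predicted labels, and their harmonic mean); this is an assumption, not the project's definition.

# Hypothetical per-sample metric in the spirit of dh.cal_metric (sketch only).
def cal_metric(predicted_labels, true_onehot_label):
    """predicted_labels: list of predicted class indices for one sample.
    true_onehot_label: multi-hot ground-truth vector for that sample."""
    true_labels = {i for i, flag in enumerate(true_onehot_label) if int(flag) == 1}
    pred_labels = set(predicted_labels)
    hits = len(true_labels & pred_labels)
    rec = hits / len(true_labels) if true_labels else 0.0
    acc = hits / len(pred_labels) if pred_labels else 0.0
    F = (2 * rec * acc / (rec + acc)) if (rec + acc) > 0 else 0.0
    return rec, acc, F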
def main(_): print('Loading word2vec model finished:%s' % (FLAGS.word_embedding_file)) w2v_model, word2id = load_w2v(FLAGS.word_embedding_file, 256) print('Load word2vec model finished') print('Loading train/valid samples:%s' % (FLAGS.training_data)) train_x, train_y, train_len, valid_x, valid_y, valid_len = loadSamples( FLAGS.training_data, FLAGS.label_file, FLAGS.label_map, word2id, FLAGS.valid_rate, FLAGS.num_classes) print('Load train/valid samples finished') train_x = pad_sequences(train_x, maxlen=FLAGS.sample_len, value=0.) valid_x = pad_sequences(valid_x, maxlen=FLAGS.sample_len, value=0.) train_sample_size = len(train_x) dev_sample_size = len(valid_x) print('Training sample size:%d' % (train_sample_size)) print('Valid sample size:%d' % (dev_sample_size)) timestamp = str(int(time.time())) runs_dir = os.path.abspath(os.path.join(os.path.curdir, 'runs')) if not os.path.exists(runs_dir): os.makedirs(runs_dir) out_dir = os.path.abspath(os.path.join(runs_dir, timestamp)) if not os.path.exists(out_dir): os.makedirs(out_dir) checkpoint_dir = os.path.abspath(os.path.join(out_dir, 'checkpoints')) if not os.path.exists(checkpoint_dir): os.makedirs(checkpoint_dir) checkpoint_prefix = os.path.join(checkpoint_dir, 'model') with tf.Graph().as_default(): session_conf = tf.ConfigProto( allow_soft_placement=FLAGS.allow_soft_placement, log_device_placement=FLAGS.log_device_placement) session_conf.gpu_options.allow_growth = True sess = tf.Session(config=session_conf) with sess.as_default(), tf.device('/gpu:1'): text_rcnn = TextRCNN(embedding_size=FLAGS.embedding_size, sequence_length=FLAGS.sample_len, num_classes=FLAGS.num_classes, w2v_model=w2v_model, rnn_hidden_size=FLAGS.rnn_hidden_size, learning_rate=FLAGS.learning_rate, decay_rate=FLAGS.decay_rate, decay_steps=FLAGS.decay_steps, l2_reg_lambda=FLAGS.l2_reg_lambda) saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints) train_summary_dir = os.path.join(out_dir, 'summaries', 'train') dev_summary_dir = os.path.join(out_dir, 'summaries', 'dev') loss_summary = tf.summary.scalar('loss', text_rcnn.loss_val) acc_summary = tf.summary.scalar('accuracy', text_rcnn.accuracy) train_summary_op = tf.summary.merge([loss_summary, acc_summary]) train_summary_writer = tf.summary.FileWriter( train_summary_dir, sess.graph) dev_summary_op = tf.summary.merge([loss_summary, acc_summary]) dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph) sess.run(tf.global_variables_initializer()) total_loss = 0. total_acc = 0. total_step = 0. best_valid_acc = 0. best_valid_loss = 1000. best_valid_zhihu_score = 0. this_step_valid_acc = 0. this_step_valid_loss = 0. this_step_zhihu_score = 0. valid_loss_summary = tf.summary.scalar('loss', this_step_valid_loss) valid_acc_summary = tf.summary.scalar('accuracy', this_step_valid_acc) valid_zhihu_score_summary = tf.summary.scalar( 'zhihu_score', this_step_zhihu_score) valid_summary_op = tf.summary.merge([ valid_loss_summary, valid_acc_summary, valid_zhihu_score_summary ]) for epoch in range(0, FLAGS.num_epochs): print('epoch:' + str(epoch)) batch_step = 0 batch_loss = 0. batch_acc = 0. 
for start, end in zip( range(0, train_sample_size, FLAGS.batch_size), range(FLAGS.batch_size, train_sample_size, FLAGS.batch_size)): if total_step > 20000: feed_dict = { text_rcnn.input_x: train_x[start:end], text_rcnn.input_y: train_y[start:end], text_rcnn.seq_len: train_len[start:end], text_rcnn.first_stage: False, text_rcnn.dropout_keep_prob: 0.5 } else: feed_dict = { text_rcnn.input_x: train_x[start:end], text_rcnn.input_y: train_y[start:end], text_rcnn.seq_len: train_len[start:end], text_rcnn.first_stage: True, text_rcnn.dropout_keep_prob: 0.5 } loss, acc, step, summaries, _ = sess.run([ text_rcnn.loss_val, text_rcnn.accuracy, text_rcnn.global_step, train_summary_op, text_rcnn.train_op ], feed_dict) train_summary_writer.add_summary(summaries, step) total_loss += loss total_acc += acc batch_loss += loss batch_acc += acc batch_step += 1 total_step += 1. if batch_step % FLAGS.print_stats_every == 0: time_str = datetime.datetime.now().isoformat() print( '[%s]Epoch:%d\tBatch_Step:%d\tTrain_Loss:%.4f/%.4f/%.4f\tTrain_Accuracy:%.4f/%.4f/%.4f' % (time_str, epoch, batch_step, loss, batch_loss / batch_step, total_loss / total_step, acc, batch_acc / batch_step, total_acc / total_step)) this_step_valid_acc = 0. if total_step % FLAGS.evaluate_every == 0 and total_step > 40000: eval_loss = 0. eval_acc = 0. eval_step = 0 for start, end in zip( range(0, dev_sample_size, FLAGS.batch_size), range(FLAGS.batch_size, dev_sample_size, FLAGS.batch_size)): feed_dict = { text_rcnn.input_x: valid_x[start:end], text_rcnn.input_y: valid_y[start:end], text_rcnn.seq_len: valid_len[start:end], text_rcnn.first_stage: False, text_rcnn.dropout_keep_prob: 1. } step, summaries, logits, loss, acc = sess.run([ text_rcnn.global_step, dev_summary_op, text_rcnn.logits, text_rcnn.loss_val, text_rcnn.accuracy ], feed_dict) dev_summary_writer.add_summary(summaries, step) if all_marked_label_num == 0: labelNumStats(valid_y) zhihuStats(logits, valid_y[start:end]) eval_loss += loss eval_acc += acc eval_step += 1 this_step_zhihu_score = calZhihuScore() time_str = datetime.datetime.now().isoformat() this_step_valid_acc = eval_acc / eval_step this_step_valid_loss = eval_loss / eval_step print( '[%s]Eval_Loss:%.4f\tEval_Accuracy:%.4f\tZhihu_Score:%.4f' % (time_str, this_step_valid_loss, this_step_valid_acc, this_step_zhihu_score)) if total_step % FLAGS.checkpoint_every == 0 and total_step > 40000: if not FLAGS.save_best_model: path = saver.save(sess, checkpoint_prefix, global_step=step) print('Saved model checkpoint to %s' % path) elif this_step_zhihu_score > best_valid_zhihu_score: path = saver.save(sess, checkpoint_prefix, global_step=step) print( 'Saved best zhihu_score model checkpoint to %s[%.4f,%.4f,%.4f]' % (path, this_step_valid_loss, this_step_valid_acc, this_step_zhihu_score)) best_valid_zhihu_score = this_step_zhihu_score elif this_step_valid_acc > best_valid_acc: path = saver.save(sess, checkpoint_prefix, global_step=step) print( 'Saved best acc model checkpoint to %s[%.4f,%.4f,%.4f]' % (path, this_step_valid_loss, this_step_valid_acc, this_step_zhihu_score)) best_valid_acc = this_step_valid_acc elif this_step_valid_loss < best_valid_loss: path = saver.save(sess, checkpoint_prefix, global_step=step) print( 'Saved best loss model checkpoint to %s[%.4f,%.4f,%.4f]' % (path, this_step_valid_loss, this_step_valid_acc, this_step_zhihu_score)) best_valid_loss = this_step_valid_loss
elif FLAGS.using_nn_type == 'textbirnn':
    nn = TextBiRNN(
        model_type=FLAGS.model_type,
        sequence_length=x_train.shape[1],
        num_classes=y_train.shape[1],
        vocab_size=len(vocab_processor.vocabulary_),
        embedding_size=embedding_dimension,
        rnn_size=FLAGS.rnn_size,
        num_layers=FLAGS.num_rnn_layers,
        # batch_size=FLAGS.batch_size,
        l2_reg_lambda=FLAGS.l2_reg_lambda)
elif FLAGS.using_nn_type == 'textrcnn':
    nn = TextRCNN(
        model_type=FLAGS.model_type,
        sequence_length=x_train.shape[1],
        num_classes=y_train.shape[1],
        vocab_size=len(vocab_processor.vocabulary_),
        embedding_size=embedding_dimension,
        batch_size=FLAGS.batch_size,
        l2_reg_lambda=FLAGS.l2_reg_lambda)
elif FLAGS.using_nn_type == 'texthan':
    nn = TextHAN(
        model_type=FLAGS.model_type,
        sequence_length=x_train.shape[1],
        num_sentences=3,
        num_classes=y_train.shape[1],
        vocab_size=len(vocab_processor.vocabulary_),
        embedding_size=embedding_dimension,
        hidden_size=FLAGS.rnn_size,
        batch_size=FLAGS.batch_size,
        l2_reg_lambda=FLAGS.l2_reg_lambda)
# Define Training procedure
def train(): with tf.Graph().as_default(): tf.set_random_seed(FLAGS.random_seed) session_conf = tf.ConfigProto( allow_soft_placement=FLAGS.allow_soft_placement, log_device_placement=FLAGS.log_device_placement) sess = tf.Session(config=session_conf) with sess.as_default(): if FLAGS.model_class == 'siamese': model = SiameseNets( model_type=FLAGS.model_type, sequence_length=FLAGS.max_document_length, vocab_size=len(vocab), embedding_size=FLAGS.embedding_dim, word_embedding_type=FLAGS.word_embedding_type, filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))), num_filters=FLAGS.num_filters, rnn_cell=FLAGS.rnn_cell, hidden_units=FLAGS.hidden_units, num_layers=FLAGS.num_layers, use_attention=FLAGS.use_attention, dense_layer=FLAGS.dense_layer, pred_threshold=FLAGS.pred_threshold, l2_reg_lambda=FLAGS.l2_reg_lambda, energy_func=FLAGS.energy_function, loss_func=FLAGS.loss_function, margin=FLAGS.margin, contrasive_loss_pos_weight=FLAGS.scale_pos_weight, weight_sharing=FLAGS.weight_sharing ) print("Initialized SiameseNets model.") else: model = TextRCNN( model_type=FLAGS.model_type, sequence_length=FLAGS.max_document_length, embedding_size=FLAGS.embedding_dim, vocab_size=len(vocab), filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))), num_filters=FLAGS.num_filters, rnn_cell=FLAGS.rnn_cell, hidden_units=FLAGS.hidden_units, num_layers=FLAGS.num_layers, pos_weight=FLAGS.scale_pos_weight, l2_reg_lambda=FLAGS.l2_reg_lambda, weight_sharing=FLAGS.weight_sharing, interaction="multiply", word_embedding_type=FLAGS.word_embedding_type ) print("Initialized TextRCNN model.") # Define Training procedure global_step = tf.Variable(0, name="global_step", trainable=False) learning_rate = tf.train.exponential_decay(FLAGS.lr, global_step, decay_steps=int(40000/FLAGS.batch_size), decay_rate=FLAGS.weight_decay_rate, staircase=True) optimizer = tf.train.AdamOptimizer(learning_rate) # optimizer = tf.train.MomentumOptimizer(learning_rate, 0.9) # optimizer = tf.train.GradientDescentOptimizer(learning_rate) # optimizer = tf.train.RMSPropOptimizer(learning_rate) # optimizer = tf.train.AdadeltaOptimizer(learning_rate, epsilon=1e-6) # for i, (g, v) in enumerate(grads_and_vars): # if g is not None: # grads_and_vars[i] = (tf.clip_by_global_norm(g, 5), v) # clip gradients # train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step) if FLAGS.clip_norm: # improve loss, but small weight cause small score, need to turn threshold for better f1. 
variables = tf.trainable_variables() grads, _ = tf.clip_by_global_norm(tf.gradients(model.loss, variables), FLAGS.clip_norm) train_op = optimizer.apply_gradients(zip(grads, variables), global_step=global_step) grads_and_vars = zip(grads, variables) else: grads_and_vars = optimizer.compute_gradients(model.loss) train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step) # Keep track of gradient values and sparsity (optional) grad_summaries = [] for g, v in grads_and_vars: if g is not None: grad_hist_summary = tf.summary.histogram("{}/grad/hist".format(v.name), g) sparsity_summary = tf.summary.scalar("{}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g)) grad_summaries.append(grad_hist_summary) grad_summaries.append(sparsity_summary) grad_summaries_merged = tf.summary.merge(grad_summaries) print("Defined gradient summaries.") # Summaries for loss and accuracy loss_summary = tf.summary.scalar("loss", model.loss) f1_summary = tf.summary.scalar("F1-score", model.f1) # Train Summaries train_summary_op = tf.summary.merge([loss_summary, f1_summary, grad_summaries_merged]) train_summary_dir = os.path.join(FLAGS.model_dir, "summaries", "train") train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph) # Dev summaries dev_summary_op = tf.summary.merge([loss_summary, f1_summary]) dev_summary_dir = os.path.join(FLAGS.model_dir, "summaries", "dev") dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph) # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it checkpoint_dir = os.path.abspath(os.path.join(FLAGS.model_dir, "checkpoints")) checkpoint_prefix = os.path.join(checkpoint_dir, "model") if not os.path.exists(checkpoint_dir): os.makedirs(checkpoint_dir) graph_def = tf.get_default_graph().as_graph_def() with open(os.path.join(checkpoint_dir, "graphpb.txt"), 'w') as f: f.write(str(graph_def)) saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints) # Initialize all variables sess.run(tf.global_variables_initializer()) if FLAGS.word_embedding_type != 'rand': # initial matrix with random uniform # embedding_init = np.random.uniform(-0.25, 0.25, (len(vocab), FLAGS.embedding_dim)) embedding_init = np.zeros(shape=(len(vocab), FLAGS.embedding_dim)) # load vectors from the word2vec print("Initializing word embedding with pre-trained word2vec.") words, vectors = dataset.load_word2vec() for idx, w in enumerate(vocab): vec = vectors[words.index(w)] embedding_init[idx] = np.asarray(vec).astype(np.float32) print("Initialized word embedding") sess.run(model.W.assign(embedding_init)) # Generate batches data = dataset.process_data(data_file=FLAGS.data_file, sequence_length=FLAGS.max_document_length) # (x1, x2, y) train_data, eval_data = dataset.train_test_split(data, test_size=FLAGS.val_percentage, random_seed=FLAGS.random_seed) train_batches = dataset.batch_iter(train_data, FLAGS.batch_size, FLAGS.num_epochs, shuffle=True) print("Starting training...") F1_best = 0. 
last_improved_step = 0 for batch in train_batches: x1_batch, x2_batch, y_batch, seqlen1, seqlen2 = zip(*batch) # print(x1_batch[:3]) # print(y_batch[:3]) # if random.random() > 0.5: # x1_batch, x2_batch = x2_batch, x1_batch feed_dict = { model.seqlen1: seqlen1, model.seqlen2: seqlen2, model.input_x1: x1_batch, model.input_x2: x2_batch, model.input_y: y_batch, model.dropout_keep_prob: FLAGS.dropout_keep_prob, } _, step, loss, acc, precision, recall, F1, summaries = sess.run( [train_op, global_step, model.loss, model.acc, model.precision, model.recall, model.f1, train_summary_op], feed_dict) time_str = datetime.datetime.now().isoformat() if step % FLAGS.log_every_steps == 0: train_summary_writer.add_summary(summaries, step) print("{} step {} TRAIN loss={:g} acc={:.3f} P={:.3f} R={:.3f} F1={:.6f}".format( time_str, step, loss, acc, precision, recall, F1)) if step % FLAGS.evaluate_every_steps == 0: # eval x1_batch, x2_batch, y_batch, seqlen1, seqlen2 = zip(*eval_data) feed_dict = { model.seqlen1: seqlen1, model.seqlen2: seqlen2, model.input_x1: x1_batch, model.input_x2: x2_batch, model.input_y: y_batch, model.dropout_keep_prob: 1.0, } # x1, out1, out2, sim_euc, sim_cos, sim_ma, e = sess.run( # [model.embedded_1, model.out1, model.out2, model.sim_euc, model.sim_cos, model.sim_ma, model.e], feed_dict) # # print(x1) # sim_euc = [round(s, 2) for s in sim_euc[:30]] # sim_cos = [round(s, 2) for s in sim_cos[:30]] # sim_ma = [round(s, 2) for s in sim_ma[:30]] # e = [round(s, 2) for s in e[:30]] # # print(out1) # out1 = [round(s, 3) for s in out1[0]] # out2 = [round(s, 3) for s in out2[0]] # print(zip(out1, out2)) # for w in zip(y_batch[:30], e, sim_euc, sim_cos, sim_ma): # print(w) loss, acc, cm, precision, recall, F1, summaries = sess.run( [model.loss, model.acc, model.cm, model.precision, model.recall, model.f1, dev_summary_op], feed_dict) dev_summary_writer.add_summary(summaries, step) if F1 > F1_best: F1_best = F1 last_improved_step = step if F1_best > 0.5: path = saver.save(sess, checkpoint_prefix, global_step=step) print("Saved model with F1={} checkpoint to {}\n".format(F1_best, path)) improved_token = '*' else: improved_token = '' print("{} step {} DEV loss={:g} acc={:.3f} cm{} P={:.3f} R={:.3f} F1={:.6f} {}".format( time_str, step, loss, acc, cm, precision, recall, F1, improved_token)) # if step % FLAGS.checkpoint_every_steps == 0: # if F1 >= F1_best: # F1_best = F1 # path = saver.save(sess, checkpoint_prefix, global_step=step) # print("Saved model with F1={} checkpoint to {}\n".format(F1_best, path)) if step - last_improved_step > 4000: # 2000 steps print("No improvement for a long time, early-stopping at best F1={}".format(F1_best)) break
def train_rcnn(): """Training RCNN model.""" # Load sentences, labels, and training parameters logger.info("✔︎ Loading data...") logger.info("✔︎ Training data processing...") train_data = dh.load_data_and_labels(FLAGS.training_data_file, FLAGS.embedding_dim) logger.info("✔︎ Validation data processing...") validation_data = dh.load_data_and_labels(FLAGS.validation_data_file, FLAGS.embedding_dim) logger.info("Recommended padding Sequence length is: {0}".format( FLAGS.pad_seq_len)) logger.info("✔︎ Training data padding...") x_train_front, x_train_behind, y_train = dh.pad_data( train_data, FLAGS.pad_seq_len) logger.info("✔︎ Validation data padding...") x_validation_front, x_validation_behind, y_validation = dh.pad_data( validation_data, FLAGS.pad_seq_len) # Build vocabulary VOCAB_SIZE = dh.load_vocab_size(FLAGS.embedding_dim) pretrained_word2vec_matrix = dh.load_word2vec_matrix( VOCAB_SIZE, FLAGS.embedding_dim) # Build a graph and rcnn object with tf.Graph().as_default(): session_conf = tf.ConfigProto( allow_soft_placement=FLAGS.allow_soft_placement, log_device_placement=FLAGS.log_device_placement) session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth sess = tf.Session(config=session_conf) with sess.as_default(): rcnn = TextRCNN(sequence_length=FLAGS.pad_seq_len, num_classes=y_train.shape[1], vocab_size=VOCAB_SIZE, lstm_hidden_size=FLAGS.lstm_hidden_size, fc_hidden_size=FLAGS.fc_hidden_size, embedding_size=FLAGS.embedding_dim, embedding_type=FLAGS.embedding_type, filter_sizes=list( map(int, FLAGS.filter_sizes.split(','))), num_filters=FLAGS.num_filters, l2_reg_lambda=FLAGS.l2_reg_lambda, pretrained_embedding=pretrained_word2vec_matrix) # Define training procedure with tf.control_dependencies( tf.get_collection(tf.GraphKeys.UPDATE_OPS)): learning_rate = tf.train.exponential_decay( learning_rate=FLAGS.learning_rate, global_step=rcnn.global_step, decay_steps=FLAGS.decay_steps, decay_rate=FLAGS.decay_rate, staircase=True) optimizer = tf.train.AdamOptimizer(learning_rate) grads, vars = zip(*optimizer.compute_gradients(rcnn.loss)) grads, _ = tf.clip_by_global_norm(grads, clip_norm=FLAGS.norm_ratio) train_op = optimizer.apply_gradients( zip(grads, vars), global_step=rcnn.global_step, name="train_op") # Keep track of gradient values and sparsity (optional) grad_summaries = [] for g, v in zip(grads, vars): if g is not None: grad_hist_summary = tf.summary.histogram( "{0}/grad/hist".format(v.name), g) sparsity_summary = tf.summary.scalar( "{0}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g)) grad_summaries.append(grad_hist_summary) grad_summaries.append(sparsity_summary) grad_summaries_merged = tf.summary.merge(grad_summaries) # Output directory for models and summaries if FLAGS.train_or_restore == 'R': MODEL = input( "☛ Please input the checkpoints model you want to restore, " "it should be like(1490175368): " ) # The model you want to restore while not (MODEL.isdigit() and len(MODEL) == 10): MODEL = input( "✘ The format of your input is illegal, please re-input: " ) logger.info( "✔︎ The format of your input is legal, now loading to next step..." 
) out_dir = os.path.abspath( os.path.join(os.path.curdir, "runs", MODEL)) logger.info("✔︎ Writing to {0}\n".format(out_dir)) else: timestamp = str(int(time.time())) out_dir = os.path.abspath( os.path.join(os.path.curdir, "runs", timestamp)) logger.info("✔︎ Writing to {0}\n".format(out_dir)) checkpoint_dir = os.path.abspath( os.path.join(out_dir, "checkpoints")) best_checkpoint_dir = os.path.abspath( os.path.join(out_dir, "bestcheckpoints")) # Summaries for loss and accuracy loss_summary = tf.summary.scalar("loss", rcnn.loss) acc_summary = tf.summary.scalar("accuracy", rcnn.accuracy) # Train summaries train_summary_op = tf.summary.merge( [loss_summary, acc_summary, grad_summaries_merged]) train_summary_dir = os.path.join(out_dir, "summaries", "train") train_summary_writer = tf.summary.FileWriter( train_summary_dir, sess.graph) # Validation summaries validation_summary_op = tf.summary.merge( [loss_summary, acc_summary]) validation_summary_dir = os.path.join(out_dir, "summaries", "validation") validation_summary_writer = tf.summary.FileWriter( validation_summary_dir, sess.graph) saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints) best_saver = cm.BestCheckpointSaver(save_dir=best_checkpoint_dir, num_to_keep=3, maximize=True) if FLAGS.train_or_restore == 'R': # Load rcnn model logger.info("✔︎ Loading model...") checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir) logger.info(checkpoint_file) # Load the saved meta graph and restore variables saver = tf.train.import_meta_graph( "{0}.meta".format(checkpoint_file)) saver.restore(sess, checkpoint_file) else: if not os.path.exists(checkpoint_dir): os.makedirs(checkpoint_dir) sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) # Embedding visualization config config = projector.ProjectorConfig() embedding_conf = config.embeddings.add() embedding_conf.tensor_name = "embedding" embedding_conf.metadata_path = FLAGS.metadata_file projector.visualize_embeddings(train_summary_writer, config) projector.visualize_embeddings(validation_summary_writer, config) # Save the embedding visualization saver.save( sess, os.path.join(out_dir, "embedding", "embedding.ckpt")) current_step = sess.run(rcnn.global_step) def train_step(x_batch_front, x_batch_behind, y_batch): """A single training step""" feed_dict = { rcnn.input_x_front: x_batch_front, rcnn.input_x_behind: x_batch_behind, rcnn.input_y: y_batch, rcnn.dropout_keep_prob: FLAGS.dropout_keep_prob, rcnn.is_training: True } _, step, summaries, loss, accuracy = sess.run([ train_op, rcnn.global_step, train_summary_op, rcnn.loss, rcnn.accuracy ], feed_dict) logger.info("step {0}: loss {1:g}, acc {2:g}".format( step, loss, accuracy)) train_summary_writer.add_summary(summaries, step) def validation_step(x_batch_front, x_batch_behind, y_batch, writer=None): """Evaluates model on a validation set""" feed_dict = { rcnn.input_x_front: x_batch_front, rcnn.input_x_behind: x_batch_behind, rcnn.input_y: y_batch, rcnn.dropout_keep_prob: 1.0, rcnn.is_training: False } step, summaries, loss, accuracy, recall, precision, f1, auc = sess.run( [ rcnn.global_step, validation_summary_op, rcnn.loss, rcnn.accuracy, rcnn.recall, rcnn.precision, rcnn.F1, rcnn.AUC ], feed_dict) logger.info( "step {0}: loss {1:g}, acc {2:g}, recall {3:g}, precision {4:g}, f1 {5:g}, AUC {6}" .format(step, loss, accuracy, recall, precision, f1, auc)) if writer: writer.add_summary(summaries, step) return accuracy # Generate batches batches = dh.batch_iter( list(zip(x_train_front, 
x_train_behind, y_train)), FLAGS.batch_size, FLAGS.num_epochs) num_batches_per_epoch = int( (len(x_train_front) - 1) / FLAGS.batch_size) + 1 # Training loop. For each batch... for batch in batches: x_batch_front, x_batch_behind, y_batch = zip(*batch) train_step(x_batch_front, x_batch_behind, y_batch) current_step = tf.train.global_step(sess, rcnn.global_step) if current_step % FLAGS.evaluate_every == 0: logger.info("\nEvaluation:") accuracy = validation_step( x_validation_front, x_validation_behind, y_validation, writer=validation_summary_writer) best_saver.handle(accuracy, sess, current_step) if current_step % FLAGS.checkpoint_every == 0: checkpoint_prefix = os.path.join(checkpoint_dir, "model") path = saver.save(sess, checkpoint_prefix, global_step=current_step) logger.info( "✔︎ Saved model checkpoint to {0}\n".format(path)) if current_step % num_batches_per_epoch == 0: current_epoch = current_step // num_batches_per_epoch logger.info( "✔︎ Epoch {0} has finished!".format(current_epoch)) logger.info("✔︎ Done.")
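Several of the snippets above keep their best checkpoints through cm.BestCheckpointSaver, whose module is not shown. A minimal sketch of the interface actually used (handle(value, sess, step) saving when the tracked metric improves) is given below; it is an assumption, not the real helper.

# Minimal sketch of a best-checkpoint saver with the interface used above
# (assumption; not the actual cm.BestCheckpointSaver implementation).
import os
import tensorflow as tf

class BestCheckpointSaver(object):
    def __init__(self, save_dir, num_to_keep=3, maximize=True):
        self.save_dir = save_dir
        self.maximize = maximize
        self.saved_values = []  # metric values of checkpoints saved so far
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=num_to_keep)

    def handle(self, value, sess, global_step):
        # Save when the new value beats the worst value among those already kept.
        better = (not self.saved_values
                  or (self.maximize and value > min(self.saved_values))
                  or (not self.maximize and value < max(self.saved_values)))
        if better:
            self.saved_values.append(value)
            self.saver.save(sess, os.path.join(self.save_dir, "best"),
                            global_step=global_step)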
if __name__ == "__main__": # 参数配置 epochs = 50 batch_size = 512 seq_length = 20 embeddings_size = 300 hidden_size = 256 num_layers = 2 num_classes = 9 learning_rate = 0.003 dropout = 0.3 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设置随机种子 random.seed(2020) torch.manual_seed(2020) # 加载文本分类模型 TextRCNN model = TextRCNN(embeddings_size, num_classes, hidden_size, num_layers, True, dropout) model.to(device) # 定义损失函数和优化器 criterian = nn.CrossEntropyLoss() optimizer = Adam(model.parameters(), lr=learning_rate) scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 1/(1 + 0.05 * epoch)) print('-' * 100) train_and_test(model, optimizer, criterian, scheduler, batch_size, embeddings_size, seq_length, './saved_model/text_rcnn.pth', epochs, device)
def train_rcnn(): """Training RCNN model.""" # Load sentences, labels, and training parameters logger.info('✔︎ Loading data...') logger.info('✔︎ Training data processing...') train_data = data_helpers.load_data_and_labels(FLAGS.training_data_file, FLAGS.num_classes, FLAGS.embedding_dim) logger.info('✔︎ Validation data processing...') validation_data = \ data_helpers.load_data_and_labels(FLAGS.validation_data_file, FLAGS.num_classes, FLAGS.embedding_dim) logger.info('Recommand padding Sequence length is: {}'.format(FLAGS.pad_seq_len)) logger.info('✔︎ Training data padding...') x_train, y_train = data_helpers.pad_data(train_data, FLAGS.pad_seq_len) logger.info('✔︎ Validation data padding...') x_validation, y_validation = data_helpers.pad_data(validation_data, FLAGS.pad_seq_len) y_validation_bind = validation_data.labels_bind # Build vocabulary VOCAB_SIZE = data_helpers.load_vocab_size(FLAGS.embedding_dim) pretrained_word2vec_matrix = data_helpers.load_word2vec_matrix(VOCAB_SIZE, FLAGS.embedding_dim) # Build a graph and rcnn object with tf.Graph().as_default(): session_conf = tf.ConfigProto( allow_soft_placement=FLAGS.allow_soft_placement, log_device_placement=FLAGS.log_device_placement) session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth sess = tf.Session(config=session_conf) with sess.as_default(): rcnn = TextRCNN( sequence_length=FLAGS.pad_seq_len, num_classes=FLAGS.num_classes, batch_size=FLAGS.batch_size, vocab_size=VOCAB_SIZE, hidden_size=FLAGS.embedding_dim, embedding_size=FLAGS.embedding_dim, embedding_type=FLAGS.embedding_type, l2_reg_lambda=FLAGS.l2_reg_lambda, pretrained_embedding=pretrained_word2vec_matrix) # Define Training procedure optimizer = tf.train.AdamOptimizer(1e-3) grads_and_vars = optimizer.compute_gradients(rcnn.loss) train_op = optimizer.apply_gradients(grads_and_vars, global_step=rcnn.global_step, name="train_op") # Keep track of gradient values and sparsity (optional) grad_summaries = [] for g, v in grads_and_vars: if g is not None: grad_hist_summary = tf.summary.histogram("{}/grad/hist".format(v.name), g) sparsity_summary = tf.summary.scalar("{}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g)) grad_summaries.append(grad_hist_summary) grad_summaries.append(sparsity_summary) grad_summaries_merged = tf.summary.merge(grad_summaries) # Output directory for models and summaries if FLAGS.train_or_restore == 'R': MODEL = input("☛ Please input the checkpoints model you want to restore: ") # 需要恢复的网络模型 while not (MODEL.isdigit() and len(MODEL) == 10): MODEL = input('✘ The format of your input is illegal, please re-input: ') logger.info('✔︎ The format of your input is legal, now loading to next step...') checkpoint_dir = 'runs/' + MODEL + '/checkpoints/' out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", MODEL)) logger.info("✔︎ Writing to {}\n".format(out_dir)) else: timestamp = str(int(time.time())) out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp)) logger.info("✔︎ Writing to {}\n".format(out_dir)) # Summaries for loss and accuracy loss_summary = tf.summary.scalar("loss", rcnn.loss) # acc_summary = tf.summary.scalar("accuracy", rcnn.accuracy) # Train Summaries train_summary_op = tf.summary.merge([loss_summary, grad_summaries_merged]) train_summary_dir = os.path.join(out_dir, "summaries", "train") train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph) # Validation summaries validation_summary_op = tf.summary.merge([loss_summary]) validation_summary_dir = os.path.join(out_dir, "summaries", 
"validation") validation_summary_writer = tf.summary.FileWriter(validation_summary_dir, sess.graph) saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints) if FLAGS.train_or_restore == 'R': # Load rcnn model logger.info("✔ Loading model...") checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir) logger.info(checkpoint_file) # Load the saved meta graph and restore variables saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file)) saver.restore(sess, checkpoint_file) else: checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints")) if not os.path.exists(checkpoint_dir): os.makedirs(checkpoint_dir) sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) current_step = sess.run(rcnn.global_step) def train_step(x_batch, y_batch): """A single training step""" feed_dict = { rcnn.input_x: x_batch, rcnn.input_y: y_batch, rcnn.dropout_keep_prob: FLAGS.dropout_keep_prob } _, step, summaries, loss = sess.run( [train_op, rcnn.global_step, train_summary_op, rcnn.loss], feed_dict) time_str = datetime.datetime.now().isoformat() logger.info("{}: step {}, loss {:g}" .format(time_str, step, loss)) train_summary_writer.add_summary(summaries, step) def validation_step(x_validation, y_validation, y_validation_bind, writer=None): """Evaluates model on a validation set""" batches_validation = data_helpers.batch_iter( list(zip(x_validation, y_validation, y_validation_bind)), 8 * FLAGS.batch_size, FLAGS.num_epochs) eval_loss, eval_rec, eval_acc, eval_counter = 0.0, 0.0, 0.0, 0 for batch_validation in batches_validation: x_batch_validation, y_batch_validation, y_batch_validation_bind = zip(*batch_validation) feed_dict = { rcnn.input_x: x_batch_validation, rcnn.input_y: y_batch_validation, rcnn.dropout_keep_prob: 1.0 } step, summaries, logits, cur_loss = sess.run( [rcnn.global_step, validation_summary_op, rcnn.logits, rcnn.loss], feed_dict) predicted_labels = data_helpers.get_label_using_logits(logits, y_batch_validation_bind, top_number=FLAGS.top_num) cur_rec, cur_acc = 0.0, 0.0 for index, predicted_label in enumerate(predicted_labels): rec_inc, acc_inc = data_helpers.cal_rec_and_acc(predicted_label, y_batch_validation[index]) cur_rec, cur_acc = cur_rec + rec_inc, cur_acc + acc_inc cur_rec = cur_rec / len(y_batch_validation) cur_acc = cur_acc / len(y_batch_validation) eval_loss, eval_rec, eval_acc, eval_counter = eval_loss + cur_loss, eval_rec + cur_rec, \ eval_acc + cur_acc, eval_counter + 1 logger.info("✔︎ validation batch {} finished.".format(eval_counter)) if writer: writer.add_summary(summaries, step) eval_loss = float(eval_loss / eval_counter) eval_rec = float(eval_rec / eval_counter) eval_acc = float(eval_acc / eval_counter) return eval_loss, eval_rec, eval_acc # Generate batches batches_train = data_helpers.batch_iter( list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs) # Training loop. For each batch... 
for batch_train in batches_train: x_batch_train, y_batch_train = zip(*batch_train) train_step(x_batch_train, y_batch_train) current_step = tf.train.global_step(sess, rcnn.global_step) if current_step % FLAGS.evaluate_every == 0: logger.info("\nEvaluation:") eval_loss, eval_rec, eval_acc = validation_step(x_validation, y_validation, y_validation_bind, writer=validation_summary_writer) time_str = datetime.datetime.now().isoformat() logger.info("{}: step {}, loss {:g}, rec {:g}, acc {:g}" .format(time_str, current_step, eval_loss, eval_rec, eval_acc)) if current_step % FLAGS.checkpoint_every == 0: checkpoint_prefix = os.path.join(checkpoint_dir, "model") path = saver.save(sess, checkpoint_prefix, global_step=current_step) logger.info("✔︎ Saved model checkpoint to {}\n".format(path)) logger.info("✔︎ Done.")
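validation_step relies on two data_helpers utilities that are not shown in this file, get_label_using_logits and cal_rec_and_acc. The sketch below is one plausible reading, assuming multi-hot label vectors, a per-sample candidate set in labels_bind, and hit-ratio style recall and accuracy; it is illustrative only, not the project's actual implementation.

import numpy as np

def get_label_using_logits(logits, labels_bind=None, top_number=1):
    # Assumed behaviour: restrict scores to each sample's candidate labels (if a bind
    # list is given), then return the indices of the top_number highest-scoring labels.
    logits = np.asarray(logits, dtype=np.float32)
    if labels_bind is not None:
        mask = np.full_like(logits, -np.inf)
        for i, candidates in enumerate(labels_bind):
            mask[i, list(candidates)] = 0.0
        logits = logits + mask
    return np.argsort(-logits, axis=1)[:, :top_number]

def cal_rec_and_acc(predicted_label, true_label):
    # Assumed behaviour for one sample: recall = hits / number of true labels,
    # accuracy = hits / number of predicted labels.
    true_indices = set(np.where(np.asarray(true_label) > 0)[0])
    pred_indices = set(np.asarray(predicted_label).tolist())
    hit = len(true_indices & pred_indices)
    rec = hit / len(true_indices) if true_indices else 0.0
    acc = hit / len(pred_indices) if pred_indices else 0.0
    return rec, acc

# Tiny demo: 2 samples, 5 labels, top-2 prediction restricted to the bound candidates.
logits = [[0.1, 0.9, 0.3, 0.8, 0.0], [0.7, 0.2, 0.6, 0.1, 0.5]]
bind = [[0, 1, 3], [0, 2, 4]]
preds = get_label_using_logits(logits, bind, top_number=2)
print(preds, cal_rec_and_acc(preds[0], [0, 1, 0, 1, 0]))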
def main(_):
    # FLAGS._parse_flags()
    print("\nParameters:")
    for attr, value in sorted(FLAGS.__flags.items()):
        print("{}={}".format(attr.upper(), value))
    print("")

    # Data Preparation
    # ==================================================

    # Load data
    print("Loading data...")
    x_text, y = data_helpers.load_data_and_labels(FLAGS.train_file, FLAGS.num_class)

    # Build vocabulary
    max_document_length = max([len(x.split(" ")) for x in x_text])
    vocab_processor = learn.preprocessing.VocabularyProcessor(max_document_length)
    x = np.array(list(vocab_processor.fit_transform(x_text)))

    # Randomly shuffle data
    np.random.seed(10)
    shuffle_indices = np.random.permutation(np.arange(len(y)))
    x_shuffled = x[shuffle_indices]
    y_shuffled = y[shuffle_indices]

    # Split train/test set
    # TODO: This is very crude, should use cross-validation
    dev_sample_index = -1 * int(FLAGS.dev_sample_percentage * float(len(y)))
    x_train, x_dev = x_shuffled[:dev_sample_index], x_shuffled[dev_sample_index:]
    y_train, y_dev = y_shuffled[:dev_sample_index], y_shuffled[dev_sample_index:]
    del x, y, x_shuffled, y_shuffled

    print("Vocabulary Size: {:d}".format(len(vocab_processor.vocabulary_)))
    print("Train/Dev split: {:d}/{:d}".format(len(y_train), len(y_dev)))

    # Training
    # ==================================================
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            cnn = TextRCNN(
                num_classes=y_train.shape[1],
                batch_size=FLAGS.batch_size,
                sequence_length=x_train.shape[1],
                vocab_size=len(vocab_processor.vocabulary_),
                embedding_size=FLAGS.embedding_dim,
                hidden_size=FLAGS.hidden_size,
                is_training=True,
                learning_rate=1e-3,
            )

            # Define training procedure
            global_step = tf.Variable(0, name="global_step", trainable=False)
            optimizer = tf.train.AdamOptimizer(1e-3)
            grads_and_vars = optimizer.compute_gradients(cnn.loss)
            train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in grads_and_vars:
                if g is not None:
                    grad_hist_summary = tf.summary.histogram("{}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar("{}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            # Output directory for models and summaries
            timestamp = str(int(time.time()))
            out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
            print("Writing to {}\n".format(out_dir))

            # Summaries for loss and accuracy
            loss_summary = tf.summary.scalar("loss", cnn.loss)
            acc_summary = tf.summary.scalar("accuracy", cnn.accuracy)

            # Train summaries
            train_summary_op = tf.summary.merge([loss_summary, acc_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

            # Dev summaries
            dev_summary_op = tf.summary.merge([loss_summary, acc_summary])
            dev_summary_dir = os.path.join(out_dir, "summaries", "dev")
            dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph)

            # Checkpoint directory. TensorFlow assumes this directory already exists, so we need to create it
            checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
            checkpoint_prefix = os.path.join(checkpoint_dir, "model")
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)

            # Write vocabulary
            vocab_processor.save(os.path.join(out_dir, "vocab"))

            # Initialize all variables
            sess.run(tf.global_variables_initializer())

            def train_step(x_batch, y_batch):
                """A single training step."""
                feed_dict = {
                    cnn.input_x: x_batch,
                    cnn.input_y: y_batch,
                    cnn.dropout_keep_prob: FLAGS.dropout_keep_prob
                }
                _, step, summaries, loss, accuracy = sess.run(
                    [train_op, global_step, train_summary_op, cnn.loss, cnn.accuracy], feed_dict)
                time_str = datetime.datetime.now().isoformat()
                print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))
                train_summary_writer.add_summary(summaries, step)

            def dev_step(x_batch, y_batch, writer=None):
                """Evaluates the model on a dev set."""
                feed_dict = {
                    cnn.input_x: x_batch,
                    cnn.input_y: y_batch,
                    cnn.dropout_keep_prob: 1.0
                }
                step, summaries, loss, accuracy = sess.run(
                    [global_step, dev_summary_op, cnn.loss, cnn.accuracy], feed_dict)
                time_str = datetime.datetime.now().isoformat()
                print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))
                if writer:
                    writer.add_summary(summaries, step)

            # Generate batches
            batches = data_helpers.batch_iter(
                list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs)

            # Training loop. For each batch...
            for batch in batches:
                x_batch, y_batch = zip(*batch)
                train_step(x_batch, y_batch)
                current_step = tf.train.global_step(sess, global_step)
                if current_step % FLAGS.evaluate_every == 0:
                    print("\nEvaluation:")
                    dev_step(x_dev, y_dev, writer=dev_summary_writer)
                    print("")
                if current_step % FLAGS.checkpoint_every == 0:
                    path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    print("Saved model checkpoint to {}\n".format(path))
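Both training loops feed batches through data_helpers.batch_iter, which is not reproduced in this section. A common implementation of that interface, assumed here rather than taken from the project, is a generator that reshuffles the data each epoch and yields fixed-size slices:

import numpy as np

def batch_iter(data, batch_size, num_epochs, shuffle=True):
    # Generator that yields one batch of (x, y) pairs at a time, reshuffling every epoch.
    data = np.array(data, dtype=object)
    data_size = len(data)
    num_batches_per_epoch = int((data_size - 1) / batch_size) + 1
    for epoch in range(num_epochs):
        if shuffle:
            shuffled_data = data[np.random.permutation(np.arange(data_size))]
        else:
            shuffled_data = data
        for batch_num in range(num_batches_per_epoch):
            start_index = batch_num * batch_size
            end_index = min(start_index + batch_size, data_size)
            yield shuffled_data[start_index:end_index]

# Usage mirrors the training loops above: each yielded batch is unpacked with zip(*batch).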