def train_han():
    """Training HAN model."""
    # Print parameters used for the model
    dh.tab_printer(args, logger)

    # Load sentences, labels, and training parameters
    logger.info("Loading data...")
    logger.info("Data processing...")
    train_data = dh.load_data_and_labels(args.train_file, args.num_classes,
                                         args.word2vec_file, data_aug_flag=False)
    val_data = dh.load_data_and_labels(args.validation_file, args.num_classes,
                                       args.word2vec_file, data_aug_flag=False)

    logger.info("Data padding...")
    x_train, y_train = dh.pad_data(train_data, args.pad_seq_len)
    x_val, y_val = dh.pad_data(val_data, args.pad_seq_len)

    # Build vocabulary
    VOCAB_SIZE, EMBEDDING_SIZE, pretrained_word2vec_matrix = dh.load_word2vec_matrix(args.word2vec_file)

    # Build a graph and han object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=args.allow_soft_placement,
            log_device_placement=args.log_device_placement)
        session_conf.gpu_options.allow_growth = args.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            han = TextHAN(
                sequence_length=args.pad_seq_len,
                vocab_size=VOCAB_SIZE,
                embedding_type=args.embedding_type,
                embedding_size=EMBEDDING_SIZE,
                lstm_hidden_size=args.lstm_dim,
                fc_hidden_size=args.fc_dim,
                num_classes=args.num_classes,
                l2_reg_lambda=args.l2_lambda,
                pretrained_embedding=pretrained_word2vec_matrix)

            # Define training procedure
            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                learning_rate = tf.train.exponential_decay(learning_rate=args.learning_rate,
                                                           global_step=han.global_step,
                                                           decay_steps=args.decay_steps,
                                                           decay_rate=args.decay_rate,
                                                           staircase=True)
                optimizer = tf.train.AdamOptimizer(learning_rate)
                grads, vars = zip(*optimizer.compute_gradients(han.loss))
                grads, _ = tf.clip_by_global_norm(grads, clip_norm=args.norm_ratio)
                train_op = optimizer.apply_gradients(zip(grads, vars), global_step=han.global_step,
                                                     name="train_op")

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in zip(grads, vars):
                if g is not None:
                    grad_hist_summary = tf.summary.histogram("{0}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar("{0}/grad/sparsity".format(v.name),
                                                         tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            # Output directory for models and summaries
            out_dir = dh.get_out_dir(OPTION, logger)
            checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
            best_checkpoint_dir = os.path.abspath(os.path.join(out_dir, "bestcheckpoints"))

            # Summaries for loss
            loss_summary = tf.summary.scalar("loss", han.loss)

            # Train summaries
            train_summary_op = tf.summary.merge([loss_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

            # Validation summaries
            validation_summary_op = tf.summary.merge([loss_summary])
            validation_summary_dir = os.path.join(out_dir, "summaries", "validation")
            validation_summary_writer = tf.summary.FileWriter(validation_summary_dir, sess.graph)

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=args.num_checkpoints)
            best_saver = cm.BestCheckpointSaver(save_dir=best_checkpoint_dir, num_to_keep=3, maximize=True)

            if OPTION == 'R':
                # Load han model
                logger.info("Loading model...")
                checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
                logger.info(checkpoint_file)

                # Load the saved meta graph and restore variables
                saver = tf.train.import_meta_graph("{0}.meta".format(checkpoint_file))
                saver.restore(sess, checkpoint_file)
            if OPTION == 'T':
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())

                # Embedding visualization config
                config = projector.ProjectorConfig()
                embedding_conf = config.embeddings.add()
                embedding_conf.tensor_name = "embedding"
                embedding_conf.metadata_path = args.metadata_file
                projector.visualize_embeddings(train_summary_writer, config)
                projector.visualize_embeddings(validation_summary_writer, config)

                # Save the embedding visualization
                saver.save(sess, os.path.join(out_dir, "embedding", "embedding.ckpt"))

            current_step = sess.run(han.global_step)

            def train_step(x_batch, y_batch):
                """A single training step"""
                feed_dict = {
                    han.input_x: x_batch,
                    han.input_y: y_batch,
                    han.dropout_keep_prob: args.dropout_rate,
                    han.is_training: True
                }
                _, step, summaries, loss = sess.run(
                    [train_op, han.global_step, train_summary_op, han.loss], feed_dict)
                logger.info("step {0}: loss {1:g}".format(step, loss))
                train_summary_writer.add_summary(summaries, step)

            def validation_step(x_val, y_val, writer=None):
                """Evaluates model on a validation set"""
                batches_validation = dh.batch_iter(list(zip(x_val, y_val)), args.batch_size, 1)

                # Predict classes by threshold or topk ('ts': threshold; 'tk': topk)
                eval_counter, eval_loss = 0, 0.0
                eval_pre_tk = [0.0] * args.topK
                eval_rec_tk = [0.0] * args.topK
                eval_F1_tk = [0.0] * args.topK

                true_onehot_labels = []
                predicted_onehot_scores = []
                predicted_onehot_labels_ts = []
                predicted_onehot_labels_tk = [[] for _ in range(args.topK)]

                for batch_validation in batches_validation:
                    x_batch_val, y_batch_val = zip(*batch_validation)
                    feed_dict = {
                        han.input_x: x_batch_val,
                        han.input_y: y_batch_val,
                        han.dropout_keep_prob: 1.0,
                        han.is_training: False
                    }
                    step, summaries, scores, cur_loss = sess.run(
                        [han.global_step, validation_summary_op, han.scores, han.loss], feed_dict)

                    # Prepare for calculating metrics
                    for i in y_batch_val:
                        true_onehot_labels.append(i)
                    for j in scores:
                        predicted_onehot_scores.append(j)

                    # Predict by threshold
                    batch_predicted_onehot_labels_ts = \
                        dh.get_onehot_label_threshold(scores=scores, threshold=args.threshold)
                    for k in batch_predicted_onehot_labels_ts:
                        predicted_onehot_labels_ts.append(k)

                    # Predict by topK
                    for top_num in range(args.topK):
                        batch_predicted_onehot_labels_tk = dh.get_onehot_label_topk(scores=scores,
                                                                                    top_num=top_num+1)
                        for i in batch_predicted_onehot_labels_tk:
                            predicted_onehot_labels_tk[top_num].append(i)

                    eval_loss = eval_loss + cur_loss
                    eval_counter = eval_counter + 1

                    if writer:
                        writer.add_summary(summaries, step)

                eval_loss = float(eval_loss / eval_counter)

                # Calculate Precision & Recall & F1
                eval_pre_ts = precision_score(y_true=np.array(true_onehot_labels),
                                              y_pred=np.array(predicted_onehot_labels_ts), average='micro')
                eval_rec_ts = recall_score(y_true=np.array(true_onehot_labels),
                                           y_pred=np.array(predicted_onehot_labels_ts), average='micro')
                eval_F1_ts = f1_score(y_true=np.array(true_onehot_labels),
                                      y_pred=np.array(predicted_onehot_labels_ts), average='micro')

                for top_num in range(args.topK):
                    eval_pre_tk[top_num] = precision_score(y_true=np.array(true_onehot_labels),
                                                           y_pred=np.array(predicted_onehot_labels_tk[top_num]),
                                                           average='micro')
                    eval_rec_tk[top_num] = recall_score(y_true=np.array(true_onehot_labels),
                                                        y_pred=np.array(predicted_onehot_labels_tk[top_num]),
                                                        average='micro')
                    eval_F1_tk[top_num] = f1_score(y_true=np.array(true_onehot_labels),
                                                   y_pred=np.array(predicted_onehot_labels_tk[top_num]),
                                                   average='micro')

                # Calculate the average AUC
                eval_auc = roc_auc_score(y_true=np.array(true_onehot_labels),
                                         y_score=np.array(predicted_onehot_scores), average='micro')

                # Calculate the average PR
                eval_prc = average_precision_score(y_true=np.array(true_onehot_labels),
                                                   y_score=np.array(predicted_onehot_scores), average='micro')

                return eval_loss, eval_auc, eval_prc, eval_pre_ts, eval_rec_ts, eval_F1_ts, \
                    eval_pre_tk, eval_rec_tk, eval_F1_tk

            # Generate batches
            batches_train = dh.batch_iter(list(zip(x_train, y_train)), args.batch_size, args.epochs)
            num_batches_per_epoch = int((len(x_train) - 1) / args.batch_size) + 1

            # Training loop. For each batch...
            for batch_train in batches_train:
                x_batch_train, y_batch_train = zip(*batch_train)
                train_step(x_batch_train, y_batch_train)
                current_step = tf.train.global_step(sess, han.global_step)

                if current_step % args.evaluate_steps == 0:
                    logger.info("\nEvaluation:")
                    eval_loss, eval_auc, eval_prc, \
                        eval_pre_ts, eval_rec_ts, eval_F1_ts, eval_pre_tk, eval_rec_tk, eval_F1_tk = \
                        validation_step(x_val, y_val, writer=validation_summary_writer)
                    logger.info("All Validation set: Loss {0:g} | AUC {1:g} | AUPRC {2:g}"
                                .format(eval_loss, eval_auc, eval_prc))

                    # Predict by threshold
                    logger.info("Predict by threshold: Precision {0:g}, Recall {1:g}, F1 {2:g}"
                                .format(eval_pre_ts, eval_rec_ts, eval_F1_ts))

                    # Predict by topK
                    logger.info("Predict by topK:")
                    for top_num in range(args.topK):
                        logger.info("Top{0}: Precision {1:g}, Recall {2:g}, F1 {3:g}"
                                    .format(top_num+1, eval_pre_tk[top_num], eval_rec_tk[top_num],
                                            eval_F1_tk[top_num]))
                    best_saver.handle(eval_prc, sess, current_step)
                if current_step % args.checkpoint_steps == 0:
                    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                    path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    logger.info("Saved model checkpoint to {0}\n".format(path))
                if current_step % num_batches_per_epoch == 0:
                    current_epoch = current_step // num_batches_per_epoch
                    logger.info("Epoch {0} has finished!".format(current_epoch))

    logger.info("All Done.")
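# The validation step above hands multi-hot label matrices to the sklearn metrics via two
# data_helpers utilities, dh.get_onehot_label_threshold and dh.get_onehot_label_topk. Those
# helpers are not part of this file; the sketch below only illustrates the behavior assumed
# here (threshold cut and per-sample top-k selection) and may differ from the real code.
import numpy as np

def get_onehot_label_threshold_sketch(scores, threshold=0.5):
    """Mark every class whose score reaches the threshold (assumed behavior)."""
    scores = np.asarray(scores)
    return (scores >= threshold).astype(int)

def get_onehot_label_topk_sketch(scores, top_num=1):
    """Mark the top_num highest-scoring classes per sample (assumed behavior)."""
    scores = np.asarray(scores)
    onehot = np.zeros_like(scores, dtype=int)
    topk_indices = np.argsort(-scores, axis=1)[:, :top_num]
    for row, cols in enumerate(topk_indices):
        onehot[row, cols] = 1
    return onehot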
def train_han():
    """Training HAN model."""
    # Load sentences, labels, and training parameters
    logger.info('✔︎ Loading data...')
    logger.info('✔︎ Training data processing...')
    train_data = dh.load_data_and_labels(FLAGS.training_data_file, FLAGS.num_classes,
                                         FLAGS.embedding_dim, data_aug_flag=False)

    logger.info('✔︎ Validation data processing...')
    validation_data = dh.load_data_and_labels(FLAGS.validation_data_file, FLAGS.num_classes,
                                              FLAGS.embedding_dim, data_aug_flag=False)

    logger.info('Recommended padding Sequence length is: {0}'.format(FLAGS.pad_seq_len))

    logger.info('✔︎ Training data padding...')
    x_train, y_train = dh.pad_data(train_data, FLAGS.pad_seq_len)

    logger.info('✔︎ Validation data padding...')
    x_validation, y_validation = dh.pad_data(validation_data, FLAGS.pad_seq_len)

    # Build vocabulary
    VOCAB_SIZE = dh.load_vocab_size(FLAGS.embedding_dim)
    pretrained_word2vec_matrix = dh.load_word2vec_matrix(VOCAB_SIZE, FLAGS.embedding_dim)

    # Build a graph and han object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            han = TextHAN(
                sequence_length=FLAGS.pad_seq_len,
                num_classes=FLAGS.num_classes,
                vocab_size=VOCAB_SIZE,
                lstm_hidden_size=FLAGS.lstm_hidden_size,
                fc_hidden_size=FLAGS.fc_hidden_size,
                embedding_size=FLAGS.embedding_dim,
                embedding_type=FLAGS.embedding_type,
                l2_reg_lambda=FLAGS.l2_reg_lambda,
                pretrained_embedding=pretrained_word2vec_matrix)

            # Define training procedure
            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                learning_rate = tf.train.exponential_decay(learning_rate=FLAGS.learning_rate,
                                                           global_step=han.global_step,
                                                           decay_steps=FLAGS.decay_steps,
                                                           decay_rate=FLAGS.decay_rate,
                                                           staircase=True)
                optimizer = tf.train.AdamOptimizer(learning_rate)
                grads, vars = zip(*optimizer.compute_gradients(han.loss))
                grads, _ = tf.clip_by_global_norm(grads, clip_norm=FLAGS.norm_ratio)
                train_op = optimizer.apply_gradients(zip(grads, vars), global_step=han.global_step,
                                                     name="train_op")

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in zip(grads, vars):
                if g is not None:
                    grad_hist_summary = tf.summary.histogram("{0}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar("{0}/grad/sparsity".format(v.name),
                                                         tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            # Output directory for models and summaries
            if FLAGS.train_or_restore == 'R':
                MODEL = input("☛ Please input the checkpoints model you want to restore, "
                              "it should be like(1490175368): ")  # The model you want to restore
                while not (MODEL.isdigit() and len(MODEL) == 10):
                    MODEL = input('✘ The format of your input is illegal, please re-input: ')
                logger.info('✔︎ The format of your input is legal, now loading to next step...')

                checkpoint_dir = 'runs/' + MODEL + '/checkpoints/'
                out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", MODEL))
                logger.info("✔︎ Writing to {0}\n".format(out_dir))
            else:
                timestamp = str(int(time.time()))
                out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
                logger.info("✔︎ Writing to {0}\n".format(out_dir))

            # Summaries for loss
            loss_summary = tf.summary.scalar("loss", han.loss)

            # Train summaries
            train_summary_op = tf.summary.merge([loss_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

            # Validation summaries
            validation_summary_op = tf.summary.merge([loss_summary])
            validation_summary_dir = os.path.join(out_dir, "summaries", "validation")
            validation_summary_writer = tf.summary.FileWriter(validation_summary_dir, sess.graph)

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)

            if FLAGS.train_or_restore == 'R':
                # Load han model
                logger.info("✔ Loading model...")
                checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
                logger.info(checkpoint_file)

                # Load the saved meta graph and restore variables
                saver = tf.train.import_meta_graph("{0}.meta".format(checkpoint_file))
                saver.restore(sess, checkpoint_file)
            else:
                checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())

                # Embedding visualization config
                config = projector.ProjectorConfig()
                embedding_conf = config.embeddings.add()
                embedding_conf.tensor_name = 'embedding'
                embedding_conf.metadata_path = FLAGS.metadata_file
                projector.visualize_embeddings(train_summary_writer, config)
                projector.visualize_embeddings(validation_summary_writer, config)

                # Save the embedding visualization
                saver.save(sess, os.path.join(out_dir, 'embedding', 'embedding.ckpt'))

            current_step = sess.run(han.global_step)

            def train_step(x_batch, y_batch):
                """A single training step"""
                feed_dict = {
                    han.input_x: x_batch,
                    han.input_y: y_batch,
                    han.dropout_keep_prob: FLAGS.dropout_keep_prob,
                    han.is_training: True
                }
                _, step, summaries, loss = sess.run(
                    [train_op, han.global_step, train_summary_op, han.loss], feed_dict)
                logger.info("step {0}: loss {1:g}".format(step, loss))
                train_summary_writer.add_summary(summaries, step)

            def validation_step(x_validation, y_validation, writer=None):
                """Evaluates model on a validation set"""
                batches_validation = dh.batch_iter(list(zip(x_validation, y_validation)), FLAGS.batch_size, 1)

                # Predict classes by threshold or topk ('ts': threshold; 'tk': topk)
                eval_counter, eval_loss, eval_rec_ts, eval_pre_ts, eval_F_ts = 0, 0.0, 0.0, 0.0, 0.0
                eval_rec_tk = [0.0] * FLAGS.top_num
                eval_pre_tk = [0.0] * FLAGS.top_num
                eval_F_tk = [0.0] * FLAGS.top_num

                for batch_validation in batches_validation:
                    x_batch_validation, y_batch_validation = zip(*batch_validation)
                    feed_dict = {
                        han.input_x: x_batch_validation,
                        han.input_y: y_batch_validation,
                        han.dropout_keep_prob: 1.0,
                        han.is_training: False
                    }
                    step, summaries, scores, cur_loss = sess.run(
                        [han.global_step, validation_summary_op, han.scores, han.loss], feed_dict)

                    # Predict by threshold
                    predicted_labels_threshold, predicted_values_threshold = \
                        dh.get_label_using_scores_by_threshold(scores=scores, threshold=FLAGS.threshold)

                    cur_rec_ts, cur_pre_ts, cur_F_ts = 0.0, 0.0, 0.0
                    for index, predicted_label_threshold in enumerate(predicted_labels_threshold):
                        rec_inc_ts, pre_inc_ts = dh.cal_metric(predicted_label_threshold,
                                                               y_batch_validation[index])
                        cur_rec_ts, cur_pre_ts = cur_rec_ts + rec_inc_ts, cur_pre_ts + pre_inc_ts

                    cur_rec_ts = cur_rec_ts / len(y_batch_validation)
                    cur_pre_ts = cur_pre_ts / len(y_batch_validation)
                    cur_F_ts = dh.cal_F(cur_rec_ts, cur_pre_ts)

                    eval_rec_ts, eval_pre_ts = eval_rec_ts + cur_rec_ts, eval_pre_ts + cur_pre_ts

                    # Predict by topK
                    topK_predicted_labels = []
                    for top_num in range(FLAGS.top_num):
                        predicted_labels_topk, predicted_values_topk = \
                            dh.get_label_using_scores_by_topk(scores=scores, top_num=top_num+1)
                        topK_predicted_labels.append(predicted_labels_topk)

                    cur_rec_tk = [0.0] * FLAGS.top_num
                    cur_pre_tk = [0.0] * FLAGS.top_num
                    cur_F_tk = [0.0] * FLAGS.top_num

                    for top_num, predicted_labels_topK in enumerate(topK_predicted_labels):
                        for index, predicted_label_topK in enumerate(predicted_labels_topK):
                            rec_inc_tk, pre_inc_tk = dh.cal_metric(predicted_label_topK,
                                                                   y_batch_validation[index])
                            cur_rec_tk[top_num], cur_pre_tk[top_num] = \
                                cur_rec_tk[top_num] + rec_inc_tk, cur_pre_tk[top_num] + pre_inc_tk

                        cur_rec_tk[top_num] = cur_rec_tk[top_num] / len(y_batch_validation)
                        cur_pre_tk[top_num] = cur_pre_tk[top_num] / len(y_batch_validation)
                        cur_F_tk[top_num] = dh.cal_F(cur_rec_tk[top_num], cur_pre_tk[top_num])

                        eval_rec_tk[top_num], eval_pre_tk[top_num] = \
                            eval_rec_tk[top_num] + cur_rec_tk[top_num], eval_pre_tk[top_num] + cur_pre_tk[top_num]

                    eval_loss = eval_loss + cur_loss
                    eval_counter = eval_counter + 1

                    logger.info("✔︎ validation batch {0}: loss {1:g}".format(eval_counter, cur_loss))
                    logger.info("︎☛ Predict by threshold: recall {0:g}, precision {1:g}, F {2:g}"
                                .format(cur_rec_ts, cur_pre_ts, cur_F_ts))

                    logger.info("︎☛ Predict by topK:")
                    for top_num in range(FLAGS.top_num):
                        logger.info("Top{0}: recall {1:g}, precision {2:g}, F {3:g}"
                                    .format(top_num + 1, cur_rec_tk[top_num], cur_pre_tk[top_num],
                                            cur_F_tk[top_num]))

                    if writer:
                        writer.add_summary(summaries, step)

                eval_loss = float(eval_loss / eval_counter)
                eval_rec_ts = float(eval_rec_ts / eval_counter)
                eval_pre_ts = float(eval_pre_ts / eval_counter)
                eval_F_ts = dh.cal_F(eval_rec_ts, eval_pre_ts)

                for top_num in range(FLAGS.top_num):
                    eval_rec_tk[top_num] = float(eval_rec_tk[top_num] / eval_counter)
                    eval_pre_tk[top_num] = float(eval_pre_tk[top_num] / eval_counter)
                    eval_F_tk[top_num] = dh.cal_F(eval_rec_tk[top_num], eval_pre_tk[top_num])

                return eval_loss, eval_rec_ts, eval_pre_ts, eval_F_ts, eval_rec_tk, eval_pre_tk, eval_F_tk

            # Generate batches
            batches_train = dh.batch_iter(list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs)
            num_batches_per_epoch = int((len(x_train) - 1) / FLAGS.batch_size) + 1

            # Training loop. For each batch...
            for batch_train in batches_train:
                x_batch_train, y_batch_train = zip(*batch_train)
                train_step(x_batch_train, y_batch_train)
                current_step = tf.train.global_step(sess, han.global_step)

                if current_step % FLAGS.evaluate_every == 0:
                    logger.info("\nEvaluation:")
                    eval_loss, eval_rec_ts, eval_pre_ts, eval_F_ts, eval_rec_tk, eval_pre_tk, eval_F_tk = \
                        validation_step(x_validation, y_validation, writer=validation_summary_writer)

                    logger.info("All Validation set: Loss {0:g}".format(eval_loss))

                    # Predict by threshold
                    logger.info("︎☛ Predict by threshold: Recall {0:g}, Precision {1:g}, F {2:g}"
                                .format(eval_rec_ts, eval_pre_ts, eval_F_ts))

                    # Predict by topK
                    logger.info("︎☛ Predict by topK:")
                    for top_num in range(FLAGS.top_num):
                        logger.info("Top{0}: Recall {1:g}, Precision {2:g}, F {3:g}"
                                    .format(top_num+1, eval_rec_tk[top_num], eval_pre_tk[top_num],
                                            eval_F_tk[top_num]))
                if current_step % FLAGS.checkpoint_every == 0:
                    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                    path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    logger.info("✔︎ Saved model checkpoint to {0}\n".format(path))
                if current_step % num_batches_per_epoch == 0:
                    current_epoch = current_step // num_batches_per_epoch
                    logger.info("✔︎ Epoch {0} has finished!".format(current_epoch))

    logger.info("✔︎ Done.")
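# This older variant accumulates recall/precision per sample through dh.cal_metric and combines
# them with dh.cal_F. Neither helper appears in this file; the sketch below captures the behavior
# assumed by the calling code (per-sample recall/precision over predicted label indices, and the
# harmonic-mean F-measure) and may differ from the real data_helpers implementation.
def cal_metric_sketch(predicted_labels, true_onehot_label):
    """Per-sample recall and precision, assuming predicted_labels is a list of class indices."""
    true_indices = {i for i, v in enumerate(true_onehot_label) if v == 1}
    predicted_indices = set(predicted_labels)
    hits = len(true_indices & predicted_indices)
    rec = hits / len(true_indices) if true_indices else 0.0
    pre = hits / len(predicted_indices) if predicted_indices else 0.0
    return rec, pre

def cal_F_sketch(rec, pre):
    """Harmonic mean of recall and precision."""
    return 2 * rec * pre / (rec + pre) if (rec + pre) > 0 else 0.0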
#                   batch_size=FLAGS.batch_size, l2_reg_lambda=FLAGS.l2_reg_lambda)
elif FLAGS.using_nn_type == 'textrcnn':
    nn = TextRCNN(model_type=FLAGS.model_type,
                  sequence_length=x_train.shape[1],
                  num_classes=y_train.shape[1],
                  vocab_size=len(vocab_processor.vocabulary_),
                  embedding_size=embedding_dimension,
                  batch_size=FLAGS.batch_size,
                  l2_reg_lambda=FLAGS.l2_reg_lambda)
elif FLAGS.using_nn_type == 'texthan':
    nn = TextHAN(model_type=FLAGS.model_type,
                 sequence_length=x_train.shape[1],
                 num_sentences=3,
                 num_classes=y_train.shape[1],
                 vocab_size=len(vocab_processor.vocabulary_),
                 embedding_size=embedding_dimension,
                 hidden_size=FLAGS.rnn_size,
                 batch_size=FLAGS.batch_size,
                 l2_reg_lambda=FLAGS.l2_reg_lambda)

# Define Training procedure
global_step = tf.Variable(0, name="global_step", trainable=False)
optimizer = tf.train.AdamOptimizer(nn.learning_rate)
# Clip the gradient to avoid larger ones
tvars = tf.trainable_variables()
grads, _ = tf.clip_by_global_norm(tf.gradients(nn.loss, tvars), FLAGS.grad_clip)
# grads_and_vars = optimizer.compute_gradients(nn.loss)
grads_and_vars = tuple(zip(grads, tvars))
train_op = optimizer.apply_gradients(grads_and_vars,
def train_han():
    """Training HAN model."""
    # Load sentences, labels, and training parameters
    logger.info('✔︎ Loading data...')
    logger.info('✔︎ Training data processing...')
    train_data = data_helpers.load_data_and_labels(FLAGS.training_data_file,
                                                   FLAGS.num_classes, FLAGS.embedding_dim)

    logger.info('✔︎ Validation data processing...')
    validation_data = \
        data_helpers.load_data_and_labels(FLAGS.validation_data_file,
                                          FLAGS.num_classes, FLAGS.embedding_dim)

    logger.info('Recommended padding Sequence length is: {}'.format(FLAGS.pad_seq_len))

    logger.info('✔︎ Training data padding...')
    x_train, y_train = data_helpers.pad_data(train_data, FLAGS.pad_seq_len)

    logger.info('✔︎ Validation data padding...')
    x_validation, y_validation = data_helpers.pad_data(validation_data, FLAGS.pad_seq_len)
    y_validation_bind = validation_data.labels_bind

    # Build vocabulary
    VOCAB_SIZE = data_helpers.load_vocab_size(FLAGS.embedding_dim)
    pretrained_word2vec_matrix = data_helpers.load_word2vec_matrix(VOCAB_SIZE, FLAGS.embedding_dim)

    # Build a graph and han object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            han = TextHAN(sequence_length=FLAGS.pad_seq_len,
                          num_classes=FLAGS.num_classes,
                          batch_size=FLAGS.batch_size,
                          vocab_size=VOCAB_SIZE,
                          hidden_size=FLAGS.embedding_dim,
                          embedding_size=FLAGS.embedding_dim,
                          embedding_type=FLAGS.embedding_type,
                          l2_reg_lambda=FLAGS.l2_reg_lambda,
                          pretrained_embedding=pretrained_word2vec_matrix)

            # Define Training procedure
            optimizer = tf.train.AdamOptimizer(1e-3)
            grads_and_vars = optimizer.compute_gradients(han.loss)
            train_op = optimizer.apply_gradients(grads_and_vars, global_step=han.global_step,
                                                 name="train_op")

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in grads_and_vars:
                if g is not None:
                    grad_hist_summary = tf.summary.histogram("{}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar("{}/grad/sparsity".format(v.name),
                                                         tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            # Output directory for models and summaries
            if FLAGS.train_or_restore == 'R':
                MODEL = input("☛ Please input the checkpoints model you want to restore: ")
                # The network model to restore
                while not (MODEL.isdigit() and len(MODEL) == 10):
                    MODEL = input('✘ The format of your input is illegal, please re-input: ')
                logger.info('✔︎ The format of your input is legal, now loading to next step...')

                checkpoint_dir = 'runs/' + MODEL + '/checkpoints/'
                out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", MODEL))
                logger.info("✔︎ Writing to {}\n".format(out_dir))
            else:
                timestamp = str(int(time.time()))
                out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
                logger.info("✔︎ Writing to {}\n".format(out_dir))

            # Summaries for loss and accuracy
            loss_summary = tf.summary.scalar("loss", han.loss)
            # acc_summary = tf.summary.scalar("accuracy", han.accuracy)

            # Train Summaries
            train_summary_op = tf.summary.merge([loss_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

            # Validation summaries
            validation_summary_op = tf.summary.merge([loss_summary])
            validation_summary_dir = os.path.join(out_dir, "summaries", "validation")
            validation_summary_writer = tf.summary.FileWriter(validation_summary_dir, sess.graph)

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)

            if FLAGS.train_or_restore == 'R':
                # Load han model
                logger.info("✔ Loading model...")
                checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
                logger.info(checkpoint_file)

                # Load the saved meta graph and restore variables
                saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
                saver.restore(sess, checkpoint_file)
            else:
                checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())

            current_step = sess.run(han.global_step)

            def train_step(x_batch, y_batch):
                """A single training step"""
                feed_dict = {
                    han.input_x: x_batch,
                    han.input_y: y_batch,
                    han.dropout_keep_prob: FLAGS.dropout_keep_prob
                }
                _, step, summaries, loss = sess.run(
                    [train_op, han.global_step, train_summary_op, han.loss], feed_dict)
                time_str = datetime.datetime.now().isoformat()
                logger.info("{}: step {}, loss {:g}".format(time_str, step, loss))
                train_summary_writer.add_summary(summaries, step)

            def validation_step(x_validation, y_validation, y_validation_bind, writer=None):
                """Evaluates model on a validation set"""
                batches_validation = data_helpers.batch_iter(
                    list(zip(x_validation, y_validation, y_validation_bind)),
                    8 * FLAGS.batch_size, FLAGS.num_epochs)
                eval_loss, eval_rec, eval_acc, eval_counter = 0.0, 0.0, 0.0, 0
                for batch_validation in batches_validation:
                    x_batch_validation, y_batch_validation, y_batch_validation_bind = zip(*batch_validation)
                    feed_dict = {
                        han.input_x: x_batch_validation,
                        han.input_y: y_batch_validation,
                        han.dropout_keep_prob: 1.0
                    }
                    step, summaries, logits, cur_loss = sess.run(
                        [han.global_step, validation_summary_op, han.logits, han.loss], feed_dict)
                    predicted_labels = data_helpers.get_label_using_logits(
                        logits, y_batch_validation_bind, top_number=FLAGS.top_num)
                    cur_rec, cur_acc = 0.0, 0.0
                    for index, predicted_label in enumerate(predicted_labels):
                        rec_inc, acc_inc = data_helpers.cal_rec_and_acc(predicted_label,
                                                                        y_batch_validation[index])
                        cur_rec, cur_acc = cur_rec + rec_inc, cur_acc + acc_inc
                    cur_rec = cur_rec / len(y_batch_validation)
                    cur_acc = cur_acc / len(y_batch_validation)
                    eval_loss, eval_rec, eval_acc, eval_counter = eval_loss + cur_loss, eval_rec + cur_rec, \
                        eval_acc + cur_acc, eval_counter + 1
                    logger.info("✔︎ validation batch {} finished.".format(eval_counter))
                    if writer:
                        writer.add_summary(summaries, step)
                eval_loss = float(eval_loss / eval_counter)
                eval_rec = float(eval_rec / eval_counter)
                eval_acc = float(eval_acc / eval_counter)
                return eval_loss, eval_rec, eval_acc

            # Generate batches
            batches_train = data_helpers.batch_iter(
                list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs)

            # Training loop. For each batch...
            for batch_train in batches_train:
                x_batch_train, y_batch_train = zip(*batch_train)
                train_step(x_batch_train, y_batch_train)
                current_step = tf.train.global_step(sess, han.global_step)

                if current_step % FLAGS.evaluate_every == 0:
                    logger.info("\nEvaluation:")
                    eval_loss, eval_rec, eval_acc = validation_step(
                        x_validation, y_validation, y_validation_bind,
                        writer=validation_summary_writer)
                    time_str = datetime.datetime.now().isoformat()
                    logger.info("{}: step {}, loss {:g}, rec {:g}, acc {:g}".format(
                        time_str, current_step, eval_loss, eval_rec, eval_acc))
                if current_step % FLAGS.checkpoint_every == 0:
                    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                    path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    logger.info("✔︎ Saved model checkpoint to {}\n".format(path))

    logger.info("✔︎ Done.")
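# All of these variants drive training through a batch_iter generator from the data helpers.
# A common implementation, sketched here under the assumption that it shuffles the data each
# epoch and yields slices of (x, y) tuples, matches the num_batches_per_epoch formula used above;
# the actual helper may differ in details.
import numpy as np

def batch_iter_sketch(data, batch_size, num_epochs, shuffle=True):
    """Yield successive batches of `data` for `num_epochs` epochs (assumed behavior)."""
    data = np.array(data, dtype=object)
    data_size = len(data)
    num_batches_per_epoch = int((data_size - 1) / batch_size) + 1
    for _ in range(num_epochs):
        indices = np.random.permutation(data_size) if shuffle else np.arange(data_size)
        for batch_num in range(num_batches_per_epoch):
            start = batch_num * batch_size
            end = min(start + batch_size, data_size)
            yield data[indices[start:end]]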
def train_han():
    """Training HAN model."""
    # Load sentences, labels, and training parameters
    logger.info("✔︎ Loading data...")
    logger.info("✔︎ Training data processing...")
    train_data = dh.load_data_and_labels(FLAGS.training_data_file, FLAGS.embedding_dim)

    logger.info("✔︎ Validation data processing...")
    validation_data = dh.load_data_and_labels(FLAGS.validation_data_file, FLAGS.embedding_dim)

    logger.info("Recommended padding Sequence length is: {0}".format(FLAGS.pad_seq_len))

    logger.info("✔︎ Training data padding...")
    x_train_front, x_train_behind, y_train = dh.pad_data(train_data, FLAGS.pad_seq_len)

    logger.info("✔︎ Validation data padding...")
    x_validation_front, x_validation_behind, y_validation = dh.pad_data(validation_data, FLAGS.pad_seq_len)

    # Build vocabulary
    VOCAB_SIZE, pretrained_word2vec_matrix = dh.load_word2vec_matrix(FLAGS.embedding_dim)

    # Build a graph and han object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            han = TextHAN(
                sequence_length=FLAGS.pad_seq_len,
                num_classes=y_train.shape[1],
                vocab_size=VOCAB_SIZE,
                lstm_hidden_size=FLAGS.lstm_hidden_size,
                fc_hidden_size=FLAGS.fc_hidden_size,
                embedding_size=FLAGS.embedding_dim,
                embedding_type=FLAGS.embedding_type,
                l2_reg_lambda=FLAGS.l2_reg_lambda,
                pretrained_embedding=pretrained_word2vec_matrix)

            # Define training procedure
            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                learning_rate = tf.train.exponential_decay(learning_rate=FLAGS.learning_rate,
                                                           global_step=han.global_step,
                                                           decay_steps=FLAGS.decay_steps,
                                                           decay_rate=FLAGS.decay_rate,
                                                           staircase=True)
                optimizer = tf.train.AdamOptimizer(learning_rate)
                grads, vars = zip(*optimizer.compute_gradients(han.loss))
                grads, _ = tf.clip_by_global_norm(grads, clip_norm=FLAGS.norm_ratio)
                train_op = optimizer.apply_gradients(zip(grads, vars), global_step=han.global_step,
                                                     name="train_op")

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in zip(grads, vars):
                if g is not None:
                    grad_hist_summary = tf.summary.histogram("{0}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar("{0}/grad/sparsity".format(v.name),
                                                         tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            # Output directory for models and summaries
            if FLAGS.train_or_restore == 'R':
                MODEL = input("☛ Please input the checkpoints model you want to restore, "
                              "it should be like(1490175368): ")  # The model you want to restore
                while not (MODEL.isdigit() and len(MODEL) == 10):
                    MODEL = input("✘ The format of your input is illegal, please re-input: ")
                logger.info("✔︎ The format of your input is legal, now loading to next step...")

                out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", MODEL))
                logger.info("✔︎ Writing to {0}\n".format(out_dir))
            else:
                timestamp = str(int(time.time()))
                out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
                logger.info("✔︎ Writing to {0}\n".format(out_dir))

            checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
            best_checkpoint_dir = os.path.abspath(os.path.join(out_dir, "bestcheckpoints"))

            # Summaries for loss and accuracy
            loss_summary = tf.summary.scalar("loss", han.loss)
            acc_summary = tf.summary.scalar("accuracy", han.accuracy)

            # Train summaries
            train_summary_op = tf.summary.merge([loss_summary, acc_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

            # Validation summaries
            validation_summary_op = tf.summary.merge([loss_summary, acc_summary])
            validation_summary_dir = os.path.join(out_dir, "summaries", "validation")
            validation_summary_writer = tf.summary.FileWriter(validation_summary_dir, sess.graph)

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)
            best_saver = cm.BestCheckpointSaver(save_dir=best_checkpoint_dir, num_to_keep=3, maximize=True)

            if FLAGS.train_or_restore == 'R':
                # Load han model
                logger.info("✔︎ Loading model...")
                checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
                logger.info(checkpoint_file)

                # Load the saved meta graph and restore variables
                saver = tf.train.import_meta_graph("{0}.meta".format(checkpoint_file))
                saver.restore(sess, checkpoint_file)
            else:
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())

                # Embedding visualization config
                config = projector.ProjectorConfig()
                embedding_conf = config.embeddings.add()
                embedding_conf.tensor_name = "embedding"
                embedding_conf.metadata_path = FLAGS.metadata_file
                projector.visualize_embeddings(train_summary_writer, config)
                projector.visualize_embeddings(validation_summary_writer, config)

                # Save the embedding visualization
                saver.save(sess, os.path.join(out_dir, "embedding", "embedding.ckpt"))

            current_step = sess.run(han.global_step)

            def train_step(x_batch_front, x_batch_behind, y_batch):
                """A single training step"""
                feed_dict = {
                    han.input_x_front: x_batch_front,
                    han.input_x_behind: x_batch_behind,
                    han.input_y: y_batch,
                    han.dropout_keep_prob: FLAGS.dropout_keep_prob,
                    han.is_training: True
                }
                _, step, summaries, loss, accuracy = sess.run(
                    [train_op, han.global_step, train_summary_op, han.loss, han.accuracy], feed_dict)
                logger.info("step {0}: loss {1:g}, acc {2:g}".format(step, loss, accuracy))
                train_summary_writer.add_summary(summaries, step)

            def validation_step(x_batch_front, x_batch_behind, y_batch, writer=None):
                """Evaluates model on a validation set"""
                feed_dict = {
                    han.input_x_front: x_batch_front,
                    han.input_x_behind: x_batch_behind,
                    han.input_y: y_batch,
                    han.dropout_keep_prob: 1.0,
                    han.is_training: False
                }
                step, summaries, loss, accuracy, recall, precision, f1, auc = sess.run(
                    [han.global_step, validation_summary_op, han.loss, han.accuracy,
                     han.recall, han.precision, han.F1, han.AUC], feed_dict)
                logger.info("step {0}: loss {1:g}, acc {2:g}, recall {3:g}, precision {4:g}, f1 {5:g}, AUC {6}"
                            .format(step, loss, accuracy, recall, precision, f1, auc))
                if writer:
                    writer.add_summary(summaries, step)
                return accuracy

            # Generate batches
            batches = dh.batch_iter(
                list(zip(x_train_front, x_train_behind, y_train)), FLAGS.batch_size, FLAGS.num_epochs)
            num_batches_per_epoch = int((len(x_train_front) - 1) / FLAGS.batch_size) + 1

            # Training loop. For each batch...
            for batch in batches:
                x_batch_front, x_batch_behind, y_batch = zip(*batch)
                train_step(x_batch_front, x_batch_behind, y_batch)
                current_step = tf.train.global_step(sess, han.global_step)

                if current_step % FLAGS.evaluate_every == 0:
                    logger.info("\nEvaluation:")
                    accuracy = validation_step(x_validation_front, x_validation_behind, y_validation,
                                               writer=validation_summary_writer)
                    best_saver.handle(accuracy, sess, current_step)
                if current_step % FLAGS.checkpoint_every == 0:
                    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                    path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    logger.info("✔︎ Saved model checkpoint to {0}\n".format(path))
                if current_step % num_batches_per_epoch == 0:
                    current_epoch = current_step // num_batches_per_epoch
                    logger.info("✔︎ Epoch {0} has finished!".format(current_epoch))

    logger.info("✔︎ Done.")
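# The best_saver used above (cm.BestCheckpointSaver) keeps only the checkpoints with the best
# validation metric. The real helper is an external dependency; the class below is only a
# stripped-down sketch of the assumed interface (constructor arguments and handle()), not the
# actual implementation, which also persists its best-value bookkeeping to disk.
import os
import tensorflow as tf

class BestCheckpointSaverSketch(object):
    def __init__(self, save_dir, num_to_keep=3, maximize=True):
        self.save_dir = save_dir
        self.maximize = maximize
        self.best_value = None
        self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=num_to_keep)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

    def handle(self, value, sess, global_step):
        """Save a checkpoint only when `value` improves on the best value seen so far."""
        improved = (self.best_value is None or
                    (value > self.best_value if self.maximize else value < self.best_value))
        if improved:
            self.best_value = value
            self.saver.save(sess, os.path.join(self.save_dir, "best_model"), global_step=global_step)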
def main(_):
    print('Loading word2vec model:%s' % (FLAGS.word_embedding_file))
    #w2v_model, word2id = load_w2v(FLAGS.word_embedding_file, FLAGS.embedding_size)
    w2v_model, word2id = load_w2v(FLAGS.word_embedding_file, 256)
    print('Load word2vec model finished')

    print('Loading train/valid samples:%s' % (FLAGS.training_data))
    train_x, train_y, valid_x, valid_y = loadSamples(
        FLAGS.training_data, FLAGS.label_file, FLAGS.label_map, FLAGS.eval_data_file,
        word2id, FLAGS.valid_rate, FLAGS.num_classes, FLAGS.sent_len, FLAGS.doc_len)
    print('Load train/valid samples finished')
    labelNumStats(valid_y)

    train_sample_size = len(train_x)
    dev_sample_size = len(valid_x)
    print('Training sample size:%d' % (train_sample_size))
    print('Valid sample size:%d' % (dev_sample_size))

    timestamp = str(int(time.time()))
    runs_dir = os.path.abspath(os.path.join(os.path.curdir, 'runs'))
    if not os.path.exists(runs_dir):
        os.makedirs(runs_dir)
    out_dir = os.path.abspath(os.path.join(runs_dir, timestamp))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    checkpoint_dir = os.path.abspath(os.path.join(out_dir, 'checkpoints'))
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    checkpoint_prefix = os.path.join(checkpoint_dir, 'model')

    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = True
        sess = tf.Session(config=session_conf)
        #sess = tf.Session()
        with sess.as_default(), tf.device('/gpu:0'):
            text_han = TextHAN(num_classes=FLAGS.num_classes,
                               learning_rate=FLAGS.learning_rate,
                               decay_steps=FLAGS.decay_steps,
                               decay_rate=FLAGS.decay_rate,
                               l2_reg_lambda=FLAGS.l2_reg_lambda,
                               embedding_size=FLAGS.embedding_size,
                               doc_len=FLAGS.doc_len,
                               sent_len=FLAGS.sent_len,
                               w2v_model=w2v_model,
                               rnn_hidden_size=FLAGS.rnn_hidden_size,
                               fc_layer_size=FLAGS.fc_layer_size)
            print('delete word2id')
            word2id = {}
            print('delete w2v_model')
            w2v_model = []

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)

            train_summary_dir = os.path.join(out_dir, 'summaries', 'train')
            dev_summary_dir = os.path.join(out_dir, 'summaries', 'dev')
            loss_summary = tf.summary.scalar('loss', text_han.loss_val)
            acc_summary = tf.summary.scalar('accuracy', text_han.accuracy)
            train_summary_op = tf.summary.merge([loss_summary, acc_summary])
            train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)
            dev_summary_op = tf.summary.merge([loss_summary, acc_summary])
            dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph)

            sess.run(tf.global_variables_initializer())

            total_loss = 0.
            total_acc = 0.
            total_step = 0.
            best_valid_acc = 0.
            best_valid_loss = 1000.
            best_valid_zhihu_score = 0.
            this_step_valid_acc = 0.
            this_step_valid_loss = 0.
            this_step_zhihu_score = 0.
            valid_loss_summary = tf.summary.scalar('loss', this_step_valid_loss)
            valid_acc_summary = tf.summary.scalar('accuracy', this_step_valid_acc)
            valid_zhihu_score_summary = tf.summary.scalar('zhihu_score', this_step_zhihu_score)
            valid_summary_op = tf.summary.merge([
                valid_loss_summary, valid_acc_summary, valid_zhihu_score_summary])

            for epoch in range(0, FLAGS.num_epochs):
                print('epoch:' + str(epoch))
                if FLAGS.shuffle:
                    shuffle_indices = np.random.permutation(np.arange(train_sample_size))
                    train_x = train_x[shuffle_indices]
                    train_y = train_y[shuffle_indices]

                batch_step = 0
                batch_loss = 0.
                batch_acc = 0.
                for start, end in zip(
                        range(0, train_sample_size, FLAGS.batch_size),
                        range(FLAGS.batch_size, train_sample_size, FLAGS.batch_size)):
                    batch_input_x = train_x[start:end]
                    batch_input_y = train_y[start:end]
                    batch_input_x, mask = paddingX(batch_input_x, FLAGS.sent_len, FLAGS.doc_len)
                    batch_input_y = paddingY(batch_input_y, FLAGS.num_classes)
                    feed_dict = {
                        text_han.input_x: batch_input_x,
                        text_han.input_y: batch_input_y,
                        text_han.mask: mask,
                        text_han.l1_dropout_keep_prob: FLAGS.l1_dropout_keep_prob,
                        text_han.l2_dropout_keep_prob: FLAGS.l2_dropout_keep_prob
                    }
                    loss, acc, step, summaries, _ = sess.run([
                        text_han.loss_val, text_han.accuracy, text_han.global_step,
                        train_summary_op, text_han.train_op], feed_dict)
                    train_summary_writer.add_summary(summaries, step)

                    total_loss += loss
                    total_acc += acc
                    batch_loss += loss
                    batch_acc += acc
                    batch_step += 1
                    total_step += 1.

                    if batch_step % FLAGS.print_stats_every == 0:
                        time_str = datetime.datetime.now().isoformat()
                        print('[%s]Epoch:%d\tBatch_Step:%d\tTrain_Loss:%.4f/%.4f/%.4f\tTrain_Accuracy:%.4f/%.4f/%.4f'
                              % (time_str, epoch, batch_step, loss, batch_loss / batch_step,
                                 total_loss / total_step, acc, batch_acc / batch_step,
                                 total_acc / total_step))

                    if batch_step % FLAGS.evaluate_every == 0 and total_step > 0:
                        eval_loss = 0.
                        eval_acc = 0.
                        eval_step = 0
                        for start, end in zip(
                                range(0, dev_sample_size, FLAGS.batch_size),
                                range(FLAGS.batch_size, dev_sample_size, FLAGS.batch_size)):
                            batch_input_x = valid_x[start:end]
                            batch_input_x, mask = paddingX(batch_input_x, FLAGS.sent_len, FLAGS.doc_len)
                            batch_input_y = valid_y[start:end]
                            batch_input_y = paddingY(batch_input_y, FLAGS.num_classes)
                            feed_dict = {
                                text_han.input_x: batch_input_x,
                                text_han.input_y: batch_input_y,
                                text_han.mask: mask,
                                text_han.l1_dropout_keep_prob: FLAGS.l1_dropout_keep_prob,
                                text_han.l2_dropout_keep_prob: FLAGS.l2_dropout_keep_prob
                            }
                            step, summaries, loss, acc, logits = sess.run([
                                text_han.global_step, dev_summary_op, text_han.loss_val,
                                text_han.accuracy, text_han.logits], feed_dict)
                            dev_summary_writer.add_summary(summaries, step)
                            zhihuStats(logits, batch_input_y)  #valid_y[start:end])
                            eval_loss += loss
                            eval_acc += acc
                            eval_step += 1
                        this_step_zhihu_score = calZhihuScore()
                        time_str = datetime.datetime.now().isoformat()
                        print('[%s]Eval_Loss:%.4f\tEval_Accuracy:%.4f\tZhihu_Score:%.4f'
                              % (time_str, eval_loss / eval_step, eval_acc / eval_step,
                                 this_step_zhihu_score))
                        this_step_valid_acc = eval_acc / eval_step
                        this_step_valid_loss = eval_loss / eval_step
                        #dev_summary_writer.add_summary(summaries, step)

                    if batch_step % FLAGS.checkpoint_every == 0 and total_step > 0:
                        if not FLAGS.save_best_model:
                            path = saver.save(sess, checkpoint_prefix, global_step=step)
                            print('Saved model checkpoint to %s' % path)
                        elif this_step_zhihu_score > best_valid_zhihu_score:
                            path = saver.save(sess, checkpoint_prefix, global_step=step)
                            print('Saved best zhihu_score model checkpoint to %s[%.4f,%.4f,%.4f]'
                                  % (path, this_step_valid_loss, this_step_valid_acc,
                                     this_step_zhihu_score))
                            best_valid_zhihu_score = this_step_zhihu_score
                        elif this_step_valid_acc > best_valid_acc:
                            path = saver.save(sess, checkpoint_prefix, global_step=step)
                            print('Saved best acc model checkpoint to %s[%.4f,%.4f,%.4f]'
                                  % (path, this_step_valid_loss, this_step_valid_acc,
                                     this_step_zhihu_score))
                            best_valid_acc = this_step_valid_acc
                        elif this_step_valid_loss < best_valid_loss:
                            path = saver.save(sess, checkpoint_prefix, global_step=step)
                            print('Saved best loss model checkpoint to %s[%.4f,%.4f,%.4f]'
                                  % (path, this_step_valid_loss, this_step_valid_acc,
                                     this_step_zhihu_score))
                            best_valid_loss = this_step_valid_loss
                        elif total_step % 22000 == 0:
                            path = saver.save(sess, checkpoint_prefix, global_step=step)
                            print('Saved model checkpoint to %s[%.4f,%.4f,%.4f]'
                                  % (path, this_step_valid_loss, this_step_valid_acc,
                                     this_step_zhihu_score))
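# This last variant pads each document to a fixed (doc_len, sent_len) grid of word ids and
# multi-hot encodes the label list before feeding the HAN. paddingX and paddingY are not shown
# in this file; the functions below are only a sketch of the assumed behavior (zero padding plus
# a sentence-level mask, and multi-hot labels) and may differ from the original helpers.
import numpy as np

def paddingX_sketch(batch_docs, sent_len, doc_len):
    """Pad a batch of documents (lists of word-id lists) to shape [batch, doc_len, sent_len]."""
    batch_size = len(batch_docs)
    padded = np.zeros((batch_size, doc_len, sent_len), dtype=np.int32)
    mask = np.zeros((batch_size, doc_len), dtype=np.float32)
    for d, doc in enumerate(batch_docs):
        for s, sent in enumerate(doc[:doc_len]):
            words = sent[:sent_len]
            padded[d, s, :len(words)] = words
            mask[d, s] = 1.0
    return padded, mask

def paddingY_sketch(batch_labels, num_classes):
    """Turn lists of label indices into multi-hot vectors of length num_classes."""
    onehot = np.zeros((len(batch_labels), num_classes), dtype=np.float32)
    for i, labels in enumerate(batch_labels):
        onehot[i, list(labels)] = 1.0
    return onehot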