# Assumed imports for this example; `dh`, `cm`, `TextHAN`, `args`, `logger`
# and `OPTION` come from the project's own helper modules and CLI parsing.
import os

import numpy as np
import tensorflow as tf
from tensorflow.contrib.tensorboard.plugins import projector
from sklearn.metrics import (precision_score, recall_score, f1_score,
                             roc_auc_score, average_precision_score)


def train_han():
    """Training HAN model."""
    # Print parameters used for the model
    dh.tab_printer(args, logger)

    # Load sentences, labels, and training parameters
    logger.info("Loading data...")
    logger.info("Data processing...")
    train_data = dh.load_data_and_labels(args.train_file, args.num_classes, args.word2vec_file, data_aug_flag=False)
    val_data = dh.load_data_and_labels(args.validation_file, args.num_classes, args.word2vec_file, data_aug_flag=False)

    logger.info("Data padding...")
    x_train, y_train = dh.pad_data(train_data, args.pad_seq_len)
    x_val, y_val = dh.pad_data(val_data, args.pad_seq_len)

    # Build vocabulary
    VOCAB_SIZE, EMBEDDING_SIZE, pretrained_word2vec_matrix = dh.load_word2vec_matrix(args.word2vec_file)

    # Build a graph and han object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=args.allow_soft_placement,
            log_device_placement=args.log_device_placement)
        session_conf.gpu_options.allow_growth = args.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            han = TextHAN(
                sequence_length=args.pad_seq_len,
                vocab_size=VOCAB_SIZE,
                embedding_type=args.embedding_type,
                embedding_size=EMBEDDING_SIZE,
                lstm_hidden_size=args.lstm_dim,
                fc_hidden_size=args.fc_dim,
                num_classes=args.num_classes,
                l2_reg_lambda=args.l2_lambda,
                pretrained_embedding=pretrained_word2vec_matrix)

            # Define training procedure
            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
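                # Ops created under this dependency run only after the
                # UPDATE_OPS collection (e.g. batch-norm moving averages)
                # has been applied for the current step.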
                learning_rate = tf.train.exponential_decay(learning_rate=args.learning_rate,
                                                           global_step=han.global_step, decay_steps=args.decay_steps,
                                                           decay_rate=args.decay_rate, staircase=True)
                optimizer = tf.train.AdamOptimizer(learning_rate)
                grads, vars = zip(*optimizer.compute_gradients(han.loss))
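                # Jointly rescale all gradients so their global L2 norm is at
                # most args.norm_ratio (guards against exploding gradients).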
                grads, _ = tf.clip_by_global_norm(grads, clip_norm=args.norm_ratio)
                train_op = optimizer.apply_gradients(zip(grads, vars), global_step=han.global_step, name="train_op")

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in zip(grads, vars):
                if g is not None:
                    grad_hist_summary = tf.summary.histogram("{0}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar("{0}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            # Output directory for models and summaries
            out_dir = dh.get_out_dir(OPTION, logger)
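            # OPTION selects the run mode: 'T' trains from scratch, 'R'
            # restores the latest checkpoint (see the branches below).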
            checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
            best_checkpoint_dir = os.path.abspath(os.path.join(out_dir, "bestcheckpoints"))

            # Summaries for loss
            loss_summary = tf.summary.scalar("loss", han.loss)

            # Train summaries
            train_summary_op = tf.summary.merge([loss_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

            # Validation summaries
            validation_summary_op = tf.summary.merge([loss_summary])
            validation_summary_dir = os.path.join(out_dir, "summaries", "validation")
            validation_summary_writer = tf.summary.FileWriter(validation_summary_dir, sess.graph)

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=args.num_checkpoints)
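            # cm.BestCheckpointSaver keeps the 3 best checkpoints ranked by the
            # score passed to handle() (AUPRC below); maximize=True means
            # higher is better.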
            best_saver = cm.BestCheckpointSaver(save_dir=best_checkpoint_dir, num_to_keep=3, maximize=True)

            if OPTION == 'R':
                # Load han model
                logger.info("Loading model...")
                checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
                logger.info(checkpoint_file)

                # Load the saved meta graph and restore variables
                saver = tf.train.import_meta_graph("{0}.meta".format(checkpoint_file))
                saver.restore(sess, checkpoint_file)
            if OPTION == 'T':
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())

                # Embedding visualization config
                config = projector.ProjectorConfig()
                embedding_conf = config.embeddings.add()
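                # tensor_name must match the embedding variable's name in the
                # graph for TensorBoard's projector to locate it.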
                embedding_conf.tensor_name = "embedding"
                embedding_conf.metadata_path = args.metadata_file

                projector.visualize_embeddings(train_summary_writer, config)
                projector.visualize_embeddings(validation_summary_writer, config)

                # Save the embedding visualization
                saver.save(sess, os.path.join(out_dir, "embedding", "embedding.ckpt"))

            current_step = sess.run(han.global_step)

            def train_step(x_batch, y_batch):
                """A single training step"""
                feed_dict = {
                    han.input_x: x_batch,
                    han.input_y: y_batch,
                    han.dropout_keep_prob: args.dropout_rate,
                    han.is_training: True
                }
                _, step, summaries, loss = sess.run(
                    [train_op, han.global_step, train_summary_op, han.loss], feed_dict)
                logger.info("step {0}: loss {1:g}".format(step, loss))
                train_summary_writer.add_summary(summaries, step)

            def validation_step(x_val, y_val, writer=None):
                """Evaluates model on a validation set"""
                batches_validation = dh.batch_iter(list(zip(x_val, y_val)), args.batch_size, 1)

                # Predict classes by threshold or topk ('ts': threshold; 'tk': topk)
                eval_counter, eval_loss = 0, 0.0

                eval_pre_tk = [0.0] * args.topK
                eval_rec_tk = [0.0] * args.topK
                eval_F1_tk = [0.0] * args.topK

                true_onehot_labels = []
                predicted_onehot_scores = []
                predicted_onehot_labels_ts = []
                predicted_onehot_labels_tk = [[] for _ in range(args.topK)]
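                # One list of predicted one-hot labels per k in 1..topK.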

                for batch_validation in batches_validation:
                    x_batch_val, y_batch_val = zip(*batch_validation)
                    feed_dict = {
                        han.input_x: x_batch_val,
                        han.input_y: y_batch_val,
                        han.dropout_keep_prob: 1.0,
                        han.is_training: False
                    }
                    step, summaries, scores, cur_loss = sess.run(
                        [han.global_step, validation_summary_op, han.scores, han.loss], feed_dict)

                    # Prepare for calculating metrics
                    for i in y_batch_val:
                        true_onehot_labels.append(i)
                    for j in scores:
                        predicted_onehot_scores.append(j)

                    # Predict by threshold
                    batch_predicted_onehot_labels_ts = \
                        dh.get_onehot_label_threshold(scores=scores, threshold=args.threshold)

                    for k in batch_predicted_onehot_labels_ts:
                        predicted_onehot_labels_ts.append(k)

                    # Predict by topK
                    for top_num in range(args.topK):
                        batch_predicted_onehot_labels_tk = dh.get_onehot_label_topk(scores=scores, top_num=top_num+1)

                        for i in batch_predicted_onehot_labels_tk:
                            predicted_onehot_labels_tk[top_num].append(i)

                    eval_loss = eval_loss + cur_loss
                    eval_counter = eval_counter + 1

                    if writer:
                        writer.add_summary(summaries, step)

                eval_loss = float(eval_loss / eval_counter)

                # Calculate Precision & Recall & F1
                eval_pre_ts = precision_score(y_true=np.array(true_onehot_labels),
                                              y_pred=np.array(predicted_onehot_labels_ts), average='micro')
                eval_rec_ts = recall_score(y_true=np.array(true_onehot_labels),
                                           y_pred=np.array(predicted_onehot_labels_ts), average='micro')
                eval_F1_ts = f1_score(y_true=np.array(true_onehot_labels),
                                      y_pred=np.array(predicted_onehot_labels_ts), average='micro')

                for top_num in range(args.topK):
                    eval_pre_tk[top_num] = precision_score(y_true=np.array(true_onehot_labels),
                                                           y_pred=np.array(predicted_onehot_labels_tk[top_num]),
                                                           average='micro')
                    eval_rec_tk[top_num] = recall_score(y_true=np.array(true_onehot_labels),
                                                        y_pred=np.array(predicted_onehot_labels_tk[top_num]),
                                                        average='micro')
                    eval_F1_tk[top_num] = f1_score(y_true=np.array(true_onehot_labels),
                                                   y_pred=np.array(predicted_onehot_labels_tk[top_num]),
                                                   average='micro')

                # Calculate the average AUC
                eval_auc = roc_auc_score(y_true=np.array(true_onehot_labels),
                                         y_score=np.array(predicted_onehot_scores), average='micro')
                # Calculate the average PR
                eval_prc = average_precision_score(y_true=np.array(true_onehot_labels),
                                                   y_score=np.array(predicted_onehot_scores), average='micro')
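                # The micro-averaged AP is what the training loop logs as
                # AUPRC and ranks best checkpoints by (see best_saver.handle).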

                return eval_loss, eval_auc, eval_prc, eval_pre_ts, eval_rec_ts, eval_F1_ts, \
                       eval_pre_tk, eval_rec_tk, eval_F1_tk

            # Generate batches
            batches_train = dh.batch_iter(
                list(zip(x_train, y_train)), args.batch_size, args.epochs)

            num_batches_per_epoch = int((len(x_train) - 1) / args.batch_size) + 1
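            # The expression above computes ceil(len(x_train) / batch_size),
            # so a final partial batch still counts as one step.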

            # Training loop. For each batch...
            for batch_train in batches_train:
                x_batch_train, y_batch_train = zip(*batch_train)
                train_step(x_batch_train, y_batch_train)
                current_step = tf.train.global_step(sess, han.global_step)

                if current_step % args.evaluate_steps == 0:
                    logger.info("\nEvaluation:")
                    eval_loss, eval_auc, eval_prc, \
                    eval_pre_ts, eval_rec_ts, eval_F1_ts, eval_pre_tk, eval_rec_tk, eval_F1_tk = \
                        validation_step(x_val, y_val, writer=validation_summary_writer)

                    logger.info("All Validation set: Loss {0:g} | AUC {1:g} | AUPRC {2:g}"
                                .format(eval_loss, eval_auc, eval_prc))

                    # Predict by threshold
                    logger.info("Predict by threshold: Precision {0:g}, Recall {1:g}, F1 {2:g}"
                                .format(eval_pre_ts, eval_rec_ts, eval_F1_ts))

                    # Predict by topK
                    logger.info("Predict by topK:")
                    for top_num in range(args.topK):
                        logger.info("Top{0}: Precision {1:g}, Recall {2:g}, F1 {3:g}"
                                    .format(top_num+1, eval_pre_tk[top_num], eval_rec_tk[top_num], eval_F1_tk[top_num]))
                    best_saver.handle(eval_prc, sess, current_step)
                if current_step % args.checkpoint_steps == 0:
                    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                    path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    logger.info("Saved model checkpoint to {0}\n".format(path))
                if current_step % num_batches_per_epoch == 0:
                    current_epoch = current_step // num_batches_per_epoch
                    logger.info("Epoch {0} has finished!".format(current_epoch))

    logger.info("All Done.")
Example #2
def train_han():
    """Training HAN model."""

    # Load sentences, labels, and training parameters
    logger.info('✔︎ Loading data...')

    logger.info('✔︎ Training data processing...')
    train_data = dh.load_data_and_labels(FLAGS.training_data_file, FLAGS.num_classes,
                                         FLAGS.embedding_dim, data_aug_flag=False)

    logger.info('✔︎ Validation data processing...')
    validation_data = dh.load_data_and_labels(FLAGS.validation_data_file, FLAGS.num_classes,
                                              FLAGS.embedding_dim, data_aug_flag=False)

    logger.info('Recommended padding sequence length is: {0}'.format(FLAGS.pad_seq_len))

    logger.info('✔︎ Training data padding...')
    x_train, y_train = dh.pad_data(train_data, FLAGS.pad_seq_len)

    logger.info('✔︎ Validation data padding...')
    x_validation, y_validation = dh.pad_data(validation_data, FLAGS.pad_seq_len)

    # Build vocabulary
    VOCAB_SIZE = dh.load_vocab_size(FLAGS.embedding_dim)
    pretrained_word2vec_matrix = dh.load_word2vec_matrix(VOCAB_SIZE, FLAGS.embedding_dim)

    # Build a graph and han object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            han = TextHAN(
                sequence_length=FLAGS.pad_seq_len,
                num_classes=FLAGS.num_classes,
                vocab_size=VOCAB_SIZE,
                lstm_hidden_size=FLAGS.lstm_hidden_size,
                fc_hidden_size=FLAGS.fc_hidden_size,
                embedding_size=FLAGS.embedding_dim,
                embedding_type=FLAGS.embedding_type,
                l2_reg_lambda=FLAGS.l2_reg_lambda,
                pretrained_embedding=pretrained_word2vec_matrix)

            # Define training procedure
            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                learning_rate = tf.train.exponential_decay(learning_rate=FLAGS.learning_rate,
                                                           global_step=han.global_step, decay_steps=FLAGS.decay_steps,
                                                           decay_rate=FLAGS.decay_rate, staircase=True)
                optimizer = tf.train.AdamOptimizer(learning_rate)
                grads, vars = zip(*optimizer.compute_gradients(han.loss))
                grads, _ = tf.clip_by_global_norm(grads, clip_norm=FLAGS.norm_ratio)
                train_op = optimizer.apply_gradients(zip(grads, vars), global_step=han.global_step, name="train_op")

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in zip(grads, vars):
                if g is not None:
                    grad_hist_summary = tf.summary.histogram("{0}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar("{0}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            # Output directory for models and summaries
            if FLAGS.train_or_restore == 'R':
                # Run directories are named with int(time.time()), i.e. a
                # 10-digit Unix timestamp, hence the format check below.
                MODEL = input("☛ Please input the checkpoints model you want to restore, "
                              "it should look like 1490175368: ")

                while not (MODEL.isdigit() and len(MODEL) == 10):
                    MODEL = input('✘ Invalid format, please re-enter: ')
                logger.info('✔︎ Input format is valid, proceeding to the next step...')

                checkpoint_dir = 'runs/' + MODEL + '/checkpoints/'

                out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", MODEL))
                logger.info("✔︎ Writing to {0}\n".format(out_dir))
            else:
                timestamp = str(int(time.time()))
                out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
                logger.info("✔︎ Writing to {0}\n".format(out_dir))

            # Summaries for loss
            loss_summary = tf.summary.scalar("loss", han.loss)

            # Train summaries
            train_summary_op = tf.summary.merge([loss_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

            # Validation summaries
            validation_summary_op = tf.summary.merge([loss_summary])
            validation_summary_dir = os.path.join(out_dir, "summaries", "validation")
            validation_summary_writer = tf.summary.FileWriter(validation_summary_dir, sess.graph)

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)

            if FLAGS.train_or_restore == 'R':
                # Load han model
                logger.info("✔ Loading model...")
                checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
                logger.info(checkpoint_file)

                # Load the saved meta graph and restore variables
                saver = tf.train.import_meta_graph("{0}.meta".format(checkpoint_file))
                saver.restore(sess, checkpoint_file)
            else:
                checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())

                # Embedding visualization config
                config = projector.ProjectorConfig()
                embedding_conf = config.embeddings.add()
                embedding_conf.tensor_name = 'embedding'
                embedding_conf.metadata_path = FLAGS.metadata_file

                projector.visualize_embeddings(train_summary_writer, config)
                projector.visualize_embeddings(validation_summary_writer, config)

                # Save the embedding visualization
                saver.save(sess, os.path.join(out_dir, 'embedding', 'embedding.ckpt'))

            current_step = sess.run(han.global_step)

            def train_step(x_batch, y_batch):
                """A single training step"""
                feed_dict = {
                    han.input_x: x_batch,
                    han.input_y: y_batch,
                    han.dropout_keep_prob: FLAGS.dropout_keep_prob,
                    han.is_training: True
                }
                _, step, summaries, loss = sess.run(
                    [train_op, han.global_step, train_summary_op, han.loss], feed_dict)
                logger.info("step {0}: loss {1:g}".format(step, loss))
                train_summary_writer.add_summary(summaries, step)

            def validation_step(x_validation, y_validation, writer=None):
                """Evaluates model on a validation set"""
                batches_validation = dh.batch_iter(
                    list(zip(x_validation, y_validation)), FLAGS.batch_size, 1)

                # Predict classes by threshold or topk ('ts': threshold; 'tk': topk)
                eval_counter, eval_loss, eval_rec_ts, eval_pre_ts, eval_F_ts = 0, 0.0, 0.0, 0.0, 0.0
                eval_rec_tk = [0.0] * FLAGS.top_num
                eval_pre_tk = [0.0] * FLAGS.top_num
                eval_F_tk = [0.0] * FLAGS.top_num
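                # Per-batch sums; averaged over eval_counter batches after the
                # loop.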

                for batch_validation in batches_validation:
                    x_batch_validation, y_batch_validation = zip(*batch_validation)
                    feed_dict = {
                        han.input_x: x_batch_validation,
                        han.input_y: y_batch_validation,
                        han.dropout_keep_prob: 1.0,
                        han.is_training: False
                    }
                    step, summaries, scores, cur_loss = sess.run(
                        [han.global_step, validation_summary_op, han.scores, han.loss], feed_dict)

                    # Predict by threshold
                    predicted_labels_threshold, predicted_values_threshold = \
                        dh.get_label_using_scores_by_threshold(scores=scores, threshold=FLAGS.threshold)

                    cur_rec_ts, cur_pre_ts, cur_F_ts = 0.0, 0.0, 0.0

                    for index, predicted_label_threshold in enumerate(predicted_labels_threshold):
                        rec_inc_ts, pre_inc_ts = dh.cal_metric(predicted_label_threshold, y_batch_validation[index])
                        cur_rec_ts, cur_pre_ts = cur_rec_ts + rec_inc_ts, cur_pre_ts + pre_inc_ts

                    cur_rec_ts = cur_rec_ts / len(y_batch_validation)
                    cur_pre_ts = cur_pre_ts / len(y_batch_validation)

                    cur_F_ts = dh.cal_F(cur_rec_ts, cur_pre_ts)
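                    # dh.cal_F presumably combines recall and precision into an
                    # F1-style measure for the batch.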

                    eval_rec_ts, eval_pre_ts = eval_rec_ts + cur_rec_ts, eval_pre_ts + cur_pre_ts

                    # Predict by topK
                    topK_predicted_labels = []
                    for top_num in range(FLAGS.top_num):
                        predicted_labels_topk, predicted_values_topk = \
                            dh.get_label_using_scores_by_topk(scores=scores, top_num=top_num+1)
                        topK_predicted_labels.append(predicted_labels_topk)

                    cur_rec_tk = [0.0] * FLAGS.top_num
                    cur_pre_tk = [0.0] * FLAGS.top_num
                    cur_F_tk = [0.0] * FLAGS.top_num

                    for top_num, predicted_labels_topK in enumerate(topK_predicted_labels):
                        for index, predicted_label_topK in enumerate(predicted_labels_topK):
                            rec_inc_tk, pre_inc_tk = dh.cal_metric(predicted_label_topK, y_batch_validation[index])
                            cur_rec_tk[top_num], cur_pre_tk[top_num] = \
                                cur_rec_tk[top_num] + rec_inc_tk, cur_pre_tk[top_num] + pre_inc_tk

                        cur_rec_tk[top_num] = cur_rec_tk[top_num] / len(y_batch_validation)
                        cur_pre_tk[top_num] = cur_pre_tk[top_num] / len(y_batch_validation)

                        cur_F_tk[top_num] = dh.cal_F(cur_rec_tk[top_num], cur_pre_tk[top_num])

                        eval_rec_tk[top_num], eval_pre_tk[top_num] = \
                            eval_rec_tk[top_num] + cur_rec_tk[top_num], eval_pre_tk[top_num] + cur_pre_tk[top_num]

                    eval_loss = eval_loss + cur_loss
                    eval_counter = eval_counter + 1

                    logger.info("✔︎ validation batch {0}: loss {1:g}".format(eval_counter, cur_loss))
                    logger.info("☛ Predict by threshold: recall {0:g}, precision {1:g}, F {2:g}"
                                .format(cur_rec_ts, cur_pre_ts, cur_F_ts))

                    logger.info("☛ Predict by topK:")
                    for top_num in range(FLAGS.top_num):
                        logger.info("Top{0}: recall {1:g}, precision {2:g}, F {3:g}"
                                    .format(top_num + 1, cur_rec_tk[top_num], cur_pre_tk[top_num], cur_F_tk[top_num]))

                    if writer:
                        writer.add_summary(summaries, step)

                eval_loss = float(eval_loss / eval_counter)
                eval_rec_ts = float(eval_rec_ts / eval_counter)
                eval_pre_ts = float(eval_pre_ts / eval_counter)
                eval_F_ts = dh.cal_F(eval_rec_ts, eval_pre_ts)

                for top_num in range(FLAGS.top_num):
                    eval_rec_tk[top_num] = float(eval_rec_tk[top_num] / eval_counter)
                    eval_pre_tk[top_num] = float(eval_pre_tk[top_num] / eval_counter)
                    eval_F_tk[top_num] = dh.cal_F(eval_rec_tk[top_num], eval_pre_tk[top_num])

                return eval_loss, eval_rec_ts, eval_pre_ts, eval_F_ts, eval_rec_tk, eval_pre_tk, eval_F_tk

            # Generate batches
            batches_train = dh.batch_iter(
                list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs)

            num_batches_per_epoch = int((len(x_train) - 1) / FLAGS.batch_size) + 1

            # Training loop. For each batch...
            for batch_train in batches_train:
                x_batch_train, y_batch_train = zip(*batch_train)
                train_step(x_batch_train, y_batch_train)
                current_step = tf.train.global_step(sess, han.global_step)

                if current_step % FLAGS.evaluate_every == 0:
                    logger.info("\nEvaluation:")
                    eval_loss, eval_rec_ts, eval_pre_ts, eval_F_ts, eval_rec_tk, eval_pre_tk, eval_F_tk = \
                        validation_step(x_validation, y_validation, writer=validation_summary_writer)

                    logger.info("All Validation set: Loss {0:g}".format(eval_loss))

                    # Predict by threshold
                    logger.info("☛ Predict by threshold: Recall {0:g}, Precision {1:g}, F {2:g}"
                                .format(eval_rec_ts, eval_pre_ts, eval_F_ts))

                    # Predict by topK
                    logger.info("☛ Predict by topK:")
                    for top_num in range(FLAGS.top_num):
                        logger.info("Top{0}: Recall {1:g}, Precision {2:g}, F {3:g}"
                                    .format(top_num+1, eval_rec_tk[top_num], eval_pre_tk[top_num], eval_F_tk[top_num]))
                if current_step % FLAGS.checkpoint_every == 0:
                    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                    path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    logger.info("✔︎ Saved model checkpoint to {0}\n".format(path))
                if current_step % num_batches_per_epoch == 0:
                    current_epoch = current_step // num_batches_per_epoch
                    logger.info("✔︎ Epoch {0} has finished!".format(current_epoch))

    logger.info("✔︎ Done.")
Example #3
                # batch_size=FLAGS.batch_size,
                l2_reg_lambda=FLAGS.l2_reg_lambda)
        elif FLAGS.using_nn_type == 'textrcnn':
            nn = TextRCNN(model_type=FLAGS.model_type,
                          sequence_length=x_train.shape[1],
                          num_classes=y_train.shape[1],
                          vocab_size=len(vocab_processor.vocabulary_),
                          embedding_size=embedding_dimension,
                          batch_size=FLAGS.batch_size,
                          l2_reg_lambda=FLAGS.l2_reg_lambda)
        elif FLAGS.using_nn_type == 'texthan':
            nn = TextHAN(model_type=FLAGS.model_type,
                         sequence_length=x_train.shape[1],
                         num_sentences=3,
                         num_classes=y_train.shape[1],
                         vocab_size=len(vocab_processor.vocabulary_),
                         embedding_size=embedding_dimension,
                         hidden_size=FLAGS.rnn_size,
                         batch_size=FLAGS.batch_size,
                         l2_reg_lambda=FLAGS.l2_reg_lambda)

        # Define Training procedure
        global_step = tf.Variable(0, name="global_step", trainable=False)
        optimizer = tf.train.AdamOptimizer(nn.learning_rate)
        # Clip the gradient to avoid larger ones
        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(nn.loss, tvars),
                                          FLAGS.grad_clip)
        # grads_and_vars = optimizer.compute_gradients(nn.loss)
        grads_and_vars = tuple(zip(grads, tvars))
        train_op = optimizer.apply_gradients(grads_and_vars,
                                             global_step=global_step,
                                             name="train_op")
Example #4
def train_han():
    """Training FASTTEXT model."""
    # Load sentences, labels, and training parameters
    logger.info('✔︎ Loading data...')

    logger.info('✔︎ Training data processing...')
    train_data = data_helpers.load_data_and_labels(FLAGS.training_data_file,
                                                   FLAGS.num_classes,
                                                   FLAGS.embedding_dim)

    logger.info('✔︎ Validation data processing...')
    validation_data = \
        data_helpers.load_data_and_labels(FLAGS.validation_data_file, FLAGS.num_classes, FLAGS.embedding_dim)

    logger.info('Recommended padding sequence length is: {}'.format(
        FLAGS.pad_seq_len))

    logger.info('✔︎ Training data padding...')
    x_train, y_train = data_helpers.pad_data(train_data, FLAGS.pad_seq_len)

    logger.info('✔︎ Validation data padding...')
    x_validation, y_validation = data_helpers.pad_data(validation_data,
                                                       FLAGS.pad_seq_len)

    y_validation_bind = validation_data.labels_bind

    # Build vocabulary
    VOCAB_SIZE = data_helpers.load_vocab_size(FLAGS.embedding_dim)
    pretrained_word2vec_matrix = data_helpers.load_word2vec_matrix(
        VOCAB_SIZE, FLAGS.embedding_dim)

    # Build a graph and han object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            han = TextHAN(sequence_length=FLAGS.pad_seq_len,
                          num_classes=FLAGS.num_classes,
                          batch_size=FLAGS.batch_size,
                          vocab_size=VOCAB_SIZE,
                          hidden_size=FLAGS.embedding_dim,
                          embedding_size=FLAGS.embedding_dim,
                          embedding_type=FLAGS.embedding_type,
                          l2_reg_lambda=FLAGS.l2_reg_lambda,
                          pretrained_embedding=pretrained_word2vec_matrix)

            # Define Training procedure
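            # Unlike the earlier examples, this variant hard-codes the Adam
            # learning rate to 1e-3 and applies no decay or gradient clipping.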
            optimizer = tf.train.AdamOptimizer(1e-3)
            grads_and_vars = optimizer.compute_gradients(han.loss)
            train_op = optimizer.apply_gradients(grads_and_vars,
                                                 global_step=han.global_step,
                                                 name="train_op")

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in grads_and_vars:
                if g is not None:
                    grad_hist_summary = tf.summary.histogram(
                        "{}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar(
                        "{}/grad/sparsity".format(v.name),
                        tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            # Output directory for models and summaries
            if FLAGS.train_or_restore == 'R':
                MODEL = input(
                    "☛ Please input the checkpoints model you want to restore: "
                )  # the checkpoint run (timestamp) to restore

                while not (MODEL.isdigit() and len(MODEL) == 10):
                    MODEL = input(
                        '✘ Invalid format, please re-enter: ')
                logger.info(
                    '✔︎ Input format is valid, proceeding to the next step...')

                checkpoint_dir = 'runs/' + MODEL + '/checkpoints/'

                out_dir = os.path.abspath(
                    os.path.join(os.path.curdir, "runs", MODEL))
                logger.info("✔︎ Writing to {}\n".format(out_dir))
            else:
                timestamp = str(int(time.time()))
                out_dir = os.path.abspath(
                    os.path.join(os.path.curdir, "runs", timestamp))
                logger.info("✔︎ Writing to {}\n".format(out_dir))

            # Summaries for loss and accuracy
            loss_summary = tf.summary.scalar("loss", han.loss)
            # acc_summary = tf.summary.scalar("accuracy", han.accuracy)

            # Train Summaries
            train_summary_op = tf.summary.merge(
                [loss_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(
                train_summary_dir, sess.graph)

            # Validation summaries
            validation_summary_op = tf.summary.merge([loss_summary])
            validation_summary_dir = os.path.join(out_dir, "summaries",
                                                  "validation")
            validation_summary_writer = tf.summary.FileWriter(
                validation_summary_dir, sess.graph)

            saver = tf.train.Saver(tf.global_variables(),
                                   max_to_keep=FLAGS.num_checkpoints)

            if FLAGS.train_or_restore == 'R':
                # Load han model
                logger.info("✔ Loading model...")
                checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
                logger.info(checkpoint_file)

                # Load the saved meta graph and restore variables
                saver = tf.train.import_meta_graph(
                    "{}.meta".format(checkpoint_file))
                saver.restore(sess, checkpoint_file)
            else:
                checkpoint_dir = os.path.abspath(
                    os.path.join(out_dir, "checkpoints"))
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())

            current_step = sess.run(han.global_step)

            def train_step(x_batch, y_batch):
                """A single training step"""
                feed_dict = {
                    han.input_x: x_batch,
                    han.input_y: y_batch,
                    han.dropout_keep_prob: FLAGS.dropout_keep_prob
                }
                _, step, summaries, loss = sess.run(
                    [train_op, han.global_step, train_summary_op, han.loss],
                    feed_dict)
                time_str = datetime.datetime.now().isoformat()
                logger.info("{}: step {}, loss {:g}".format(
                    time_str, step, loss))
                train_summary_writer.add_summary(summaries, step)

            def validation_step(x_validation,
                                y_validation,
                                y_validation_bind,
                                writer=None):
                """Evaluates model on a validation set"""
                batches_validation = data_helpers.batch_iter(
                    list(zip(x_validation, y_validation, y_validation_bind)),
                    8 * FLAGS.batch_size, 1)  # a single pass over the validation set
                eval_loss, eval_rec, eval_acc, eval_counter = 0.0, 0.0, 0.0, 0
                for batch_validation in batches_validation:
                    x_batch_validation, y_batch_validation, y_batch_validation_bind = zip(
                        *batch_validation)
                    feed_dict = {
                        han.input_x: x_batch_validation,
                        han.input_y: y_batch_validation,
                        han.dropout_keep_prob: 1.0
                    }
                    step, summaries, logits, cur_loss = sess.run([
                        han.global_step, validation_summary_op, han.logits,
                        han.loss
                    ], feed_dict)

                    predicted_labels = data_helpers.get_label_using_logits(
                        logits,
                        y_batch_validation_bind,
                        top_number=FLAGS.top_num)
                    cur_rec, cur_acc = 0.0, 0.0
                    for index, predicted_label in enumerate(predicted_labels):
                        rec_inc, acc_inc = data_helpers.cal_rec_and_acc(
                            predicted_label, y_batch_validation[index])
                        cur_rec, cur_acc = cur_rec + rec_inc, cur_acc + acc_inc

                    cur_rec = cur_rec / len(y_batch_validation)
                    cur_acc = cur_acc / len(y_batch_validation)

                    eval_loss, eval_rec, eval_acc, eval_counter = eval_loss + cur_loss, eval_rec + cur_rec, \
                                                                  eval_acc + cur_acc, eval_counter + 1
                    logger.info("✔︎ validation batch {} finished.".format(
                        eval_counter))

                    if writer:
                        writer.add_summary(summaries, step)

                eval_loss = float(eval_loss / eval_counter)
                eval_rec = float(eval_rec / eval_counter)
                eval_acc = float(eval_acc / eval_counter)

                return eval_loss, eval_rec, eval_acc

            # Generate batches
            batches_train = data_helpers.batch_iter(
                list(zip(x_train, y_train)), FLAGS.batch_size,
                FLAGS.num_epochs)

            # Training loop. For each batch...
            for batch_train in batches_train:
                x_batch_train, y_batch_train = zip(*batch_train)
                train_step(x_batch_train, y_batch_train)
                current_step = tf.train.global_step(sess, han.global_step)

                if current_step % FLAGS.evaluate_every == 0:
                    logger.info("\nEvaluation:")
                    eval_loss, eval_rec, eval_acc = validation_step(
                        x_validation,
                        y_validation,
                        y_validation_bind,
                        writer=validation_summary_writer)
                    time_str = datetime.datetime.now().isoformat()
                    logger.info(
                        "{}: step {}, loss {:g}, rec {:g}, acc {:g}".format(
                            time_str, current_step, eval_loss, eval_rec,
                            eval_acc))

                if current_step % FLAGS.checkpoint_every == 0:
                    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                    path = saver.save(sess,
                                      checkpoint_prefix,
                                      global_step=current_step)
                    logger.info(
                        "✔︎ Saved model checkpoint to {}\n".format(path))

    logger.info("✔︎ Done.")
Example #5
def train_han():
    """Training HAN model."""

    # Load sentences, labels, and training parameters
    logger.info("✔︎ Loading data...")

    logger.info("✔︎ Training data processing...")
    train_data = dh.load_data_and_labels(FLAGS.training_data_file, FLAGS.embedding_dim)

    logger.info("✔︎ Validation data processing...")
    validation_data = dh.load_data_and_labels(FLAGS.validation_data_file, FLAGS.embedding_dim)

    logger.info("Recommended padding Sequence length is: {0}".format(FLAGS.pad_seq_len))

    logger.info("✔︎ Training data padding...")
    x_train_front, x_train_behind, y_train = dh.pad_data(train_data, FLAGS.pad_seq_len)

    logger.info("✔︎ Validation data padding...")
    x_validation_front, x_validation_behind, y_validation = dh.pad_data(validation_data, FLAGS.pad_seq_len)

    # Build vocabulary
    VOCAB_SIZE, pretrained_word2vec_matrix = dh.load_word2vec_matrix(FLAGS.embedding_dim)

    # Build a graph and han object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = FLAGS.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            han = TextHAN(
                sequence_length=FLAGS.pad_seq_len,
                num_classes=y_train.shape[1],
                vocab_size=VOCAB_SIZE,
                lstm_hidden_size=FLAGS.lstm_hidden_size,
                fc_hidden_size=FLAGS.fc_hidden_size,
                embedding_size=FLAGS.embedding_dim,
                embedding_type=FLAGS.embedding_type,
                l2_reg_lambda=FLAGS.l2_reg_lambda,
                pretrained_embedding=pretrained_word2vec_matrix)

            # Define training procedure
            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                learning_rate = tf.train.exponential_decay(learning_rate=FLAGS.learning_rate,
                                                           global_step=han.global_step, decay_steps=FLAGS.decay_steps,
                                                           decay_rate=FLAGS.decay_rate, staircase=True)
                optimizer = tf.train.AdamOptimizer(learning_rate)
                grads, vars = zip(*optimizer.compute_gradients(han.loss))
                grads, _ = tf.clip_by_global_norm(grads, clip_norm=FLAGS.norm_ratio)
                train_op = optimizer.apply_gradients(zip(grads, vars), global_step=han.global_step, name="train_op")

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in zip(grads, vars):
                if g is not None:
                    grad_hist_summary = tf.summary.histogram("{0}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar("{0}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            # Output directory for models and summaries
            if FLAGS.train_or_restore == 'R':
                MODEL = input("☛ Please input the checkpoints model you want to restore, "
                              "it should look like 1490175368: ")  # the run timestamp to restore

                while not (MODEL.isdigit() and len(MODEL) == 10):
                    MODEL = input("✘ Invalid format, please re-enter: ")
                logger.info("✔︎ Input format is valid, proceeding to the next step...")
                out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", MODEL))
                logger.info("✔︎ Writing to {0}\n".format(out_dir))
            else:
                timestamp = str(int(time.time()))
                out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
                logger.info("✔︎ Writing to {0}\n".format(out_dir))

            checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
            best_checkpoint_dir = os.path.abspath(os.path.join(out_dir, "bestcheckpoints"))

            # Summaries for loss and accuracy
            loss_summary = tf.summary.scalar("loss", han.loss)
            acc_summary = tf.summary.scalar("accuracy", han.accuracy)

            # Train summaries
            train_summary_op = tf.summary.merge([loss_summary, acc_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

            # Validation summaries
            validation_summary_op = tf.summary.merge([loss_summary, acc_summary])
            validation_summary_dir = os.path.join(out_dir, "summaries", "validation")
            validation_summary_writer = tf.summary.FileWriter(validation_summary_dir, sess.graph)

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)
            best_saver = cm.BestCheckpointSaver(save_dir=best_checkpoint_dir, num_to_keep=3, maximize=True)

            if FLAGS.train_or_restore == 'R':
                # Load han model
                logger.info("✔︎ Loading model...")
                checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
                logger.info(checkpoint_file)

                # Load the saved meta graph and restore variables
                saver = tf.train.import_meta_graph("{0}.meta".format(checkpoint_file))
                saver.restore(sess, checkpoint_file)
            else:
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())

                # Embedding visualization config
                config = projector.ProjectorConfig()
                embedding_conf = config.embeddings.add()
                embedding_conf.tensor_name = "embedding"
                embedding_conf.metadata_path = FLAGS.metadata_file

                projector.visualize_embeddings(train_summary_writer, config)
                projector.visualize_embeddings(validation_summary_writer, config)

                # Save the embedding visualization
                saver.save(sess, os.path.join(out_dir, "embedding", "embedding.ckpt"))

            current_step = sess.run(han.global_step)

            def train_step(x_batch_front, x_batch_behind, y_batch):
                """A single training step"""
                feed_dict = {
                    han.input_x_front: x_batch_front,
                    han.input_x_behind: x_batch_behind,
                    han.input_y: y_batch,
                    han.dropout_keep_prob: FLAGS.dropout_keep_prob,
                    han.is_training: True
                }
                _, step, summaries, loss, accuracy = sess.run(
                    [train_op, han.global_step, train_summary_op, han.loss, han.accuracy], feed_dict)
                logger.info("step {0}: loss {1:g}, acc {2:g}".format(step, loss, accuracy))
                train_summary_writer.add_summary(summaries, step)

            def validation_step(x_batch_front, x_batch_behind, y_batch, writer=None):
                """Evaluates model on a validation set"""
                feed_dict = {
                    han.input_x_front: x_batch_front,
                    han.input_x_behind: x_batch_behind,
                    han.input_y: y_batch,
                    han.dropout_keep_prob: 1.0,
                    han.is_training: False
                }
                step, summaries, loss, accuracy, recall, precision, f1, auc = sess.run(
                    [han.global_step, validation_summary_op, han.loss, han.accuracy,
                     han.recall, han.precision, han.F1, han.AUC], feed_dict)
                logger.info("step {0}: loss {1:g}, acc {2:g}, recall {3:g}, precision {4:g}, f1 {5:g}, AUC {6}"
                            .format(step, loss, accuracy, recall, precision, f1, auc))
                if writer:
                    writer.add_summary(summaries, step)

                return accuracy

            # Generate batches
            batches = dh.batch_iter(
                list(zip(x_train_front, x_train_behind, y_train)), FLAGS.batch_size, FLAGS.num_epochs)

            num_batches_per_epoch = int((len(x_train_front) - 1) / FLAGS.batch_size) + 1

            # Training loop. For each batch...
            for batch in batches:
                x_batch_front, x_batch_behind, y_batch = zip(*batch)
                train_step(x_batch_front, x_batch_behind, y_batch)
                current_step = tf.train.global_step(sess, han.global_step)

                if current_step % FLAGS.evaluate_every == 0:
                    logger.info("\nEvaluation:")
                    accuracy = validation_step(x_validation_front, x_validation_behind, y_validation,
                                               writer=validation_summary_writer)
                    best_saver.handle(accuracy, sess, current_step)
                if current_step % FLAGS.checkpoint_every == 0:
                    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                    path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    logger.info("✔︎ Saved model checkpoint to {0}\n".format(path))
                if current_step % num_batches_per_epoch == 0:
                    current_epoch = current_step // num_batches_per_epoch
                    logger.info("✔︎ Epoch {0} has finished!".format(current_epoch))

    logger.info("✔︎ Done.")
Example #6
def main(_):
    print('Loading word2vec model: %s' % (FLAGS.word_embedding_file))
    #w2v_model, word2id = load_w2v(FLAGS.word_embedding_file, FLAGS.embedding_size)
    w2v_model, word2id = load_w2v(FLAGS.word_embedding_file, 256)
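    # NOTE: the embedding size is hard-coded to 256 here, overriding
    # FLAGS.embedding_size (cf. the commented-out call above).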
    print('Word2vec model loaded')
    print('Loading train/valid samples:%s' % (FLAGS.training_data))
    train_x, train_y, valid_x, valid_y = loadSamples(
        FLAGS.training_data, FLAGS.label_file, FLAGS.label_map,
        FLAGS.eval_data_file, word2id, FLAGS.valid_rate, FLAGS.num_classes,
        FLAGS.sent_len, FLAGS.doc_len)
    print('Train/valid samples loaded')
    labelNumStats(valid_y)

    train_sample_size = len(train_x)
    dev_sample_size = len(valid_x)
    print('Training sample size:%d' % (train_sample_size))
    print('Valid sample size:%d' % (dev_sample_size))

    timestamp = str(int(time.time()))
    runs_dir = os.path.abspath(os.path.join(os.path.curdir, 'runs'))
    if not os.path.exists(runs_dir):
        os.makedirs(runs_dir)
    out_dir = os.path.abspath(os.path.join(runs_dir, timestamp))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    checkpoint_dir = os.path.abspath(os.path.join(out_dir, 'checkpoints'))
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    checkpoint_prefix = os.path.join(checkpoint_dir, 'model')

    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        session_conf.gpu_options.allow_growth = True
        sess = tf.Session(config=session_conf)
        #sess = tf.Session()
        with sess.as_default(), tf.device('/gpu:0'):
            text_han = TextHAN(num_classes=FLAGS.num_classes,
                               learning_rate=FLAGS.learning_rate,
                               decay_steps=FLAGS.decay_steps,
                               decay_rate=FLAGS.decay_rate,
                               l2_reg_lambda=FLAGS.l2_reg_lambda,
                               embedding_size=FLAGS.embedding_size,
                               doc_len=FLAGS.doc_len,
                               sent_len=FLAGS.sent_len,
                               w2v_model=w2v_model,
                               rnn_hidden_size=FLAGS.rnn_hidden_size,
                               fc_layer_size=FLAGS.fc_layer_size)

            # Release the raw embedding data so it can be garbage-collected
            print('Releasing word2id')
            word2id = {}
            print('Releasing w2v_model')
            w2v_model = []

            saver = tf.train.Saver(tf.global_variables(),
                                   max_to_keep=FLAGS.num_checkpoints)
            train_summary_dir = os.path.join(out_dir, 'summaries', 'train')
            dev_summary_dir = os.path.join(out_dir, 'summaries', 'dev')
            loss_summary = tf.summary.scalar('loss', text_han.loss_val)
            acc_summary = tf.summary.scalar('accuracy', text_han.accuracy)
            train_summary_op = tf.summary.merge([loss_summary, acc_summary])
            train_summary_writer = tf.summary.FileWriter(
                train_summary_dir, sess.graph)
            dev_summary_op = tf.summary.merge([loss_summary, acc_summary])
            dev_summary_writer = tf.summary.FileWriter(dev_summary_dir,
                                                       sess.graph)

            sess.run(tf.global_variables_initializer())
            total_loss = 0.
            total_acc = 0.
            total_step = 0.
            best_valid_acc = 0.
            best_valid_loss = 1000.
            best_valid_zhihu_score = 0.
            this_step_valid_acc = 0.
            this_step_valid_loss = 0.
            this_step_zhihu_score = 0.
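            # NOTE: tf.summary.scalar captures these Python floats as graph
            # constants, so valid_summary_op would only ever report the initial
            # values; it is also never evaluated in the loop below.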
            valid_loss_summary = tf.summary.scalar('loss',
                                                   this_step_valid_loss)
            valid_acc_summary = tf.summary.scalar('accuracy',
                                                  this_step_valid_acc)
            valid_zhihu_score_summary = tf.summary.scalar(
                'zhihu_score', this_step_zhihu_score)
            valid_summary_op = tf.summary.merge([
                valid_loss_summary, valid_acc_summary,
                valid_zhihu_score_summary
            ])
            for epoch in range(0, FLAGS.num_epochs):
                print('epoch:' + str(epoch))
                if FLAGS.shuffle:
                    shuffle_indices = np.random.permutation(
                        np.arange(train_sample_size))
                    train_x = train_x[shuffle_indices]
                    train_y = train_y[shuffle_indices]
                batch_step = 0
                batch_loss = 0.
                batch_acc = 0.
                for start, end in zip(
                        range(0, train_sample_size, FLAGS.batch_size),
                        range(FLAGS.batch_size, train_sample_size,
                              FLAGS.batch_size)):
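                    # zip() stops at the shorter range, so a trailing partial
                    # batch (when train_sample_size is not a multiple of
                    # batch_size) is silently skipped.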
                    batch_input_x = train_x[start:end]
                    batch_input_y = train_y[start:end]
                    batch_input_x, mask = paddingX(batch_input_x,
                                                   FLAGS.sent_len,
                                                   FLAGS.doc_len)
                    batch_input_y = paddingY(batch_input_y, FLAGS.num_classes)

                    feed_dict = {
                        text_han.input_x: batch_input_x,
                        text_han.input_y: batch_input_y,
                        text_han.mask: mask,
                        text_han.l1_dropout_keep_prob:
                        FLAGS.l1_dropout_keep_prob,
                        text_han.l2_dropout_keep_prob:
                        FLAGS.l2_dropout_keep_prob
                    }
                    loss, acc, step, summaries, _ = sess.run([
                        text_han.loss_val, text_han.accuracy,
                        text_han.global_step, train_summary_op,
                        text_han.train_op
                    ], feed_dict)
                    train_summary_writer.add_summary(summaries, step)
                    total_loss += loss
                    total_acc += acc
                    batch_loss += loss
                    batch_acc += acc
                    batch_step += 1
                    total_step += 1.
                    if batch_step % FLAGS.print_stats_every == 0:
                        time_str = datetime.datetime.now().isoformat()
                        print(
                            '[%s]Epoch:%d\tBatch_Step:%d\tTrain_Loss:%.4f/%.4f/%.4f\tTrain_Accuracy:%.4f/%.4f/%.4f'
                            % (time_str, epoch, batch_step, loss, batch_loss /
                               batch_step, total_loss / total_step, acc,
                               batch_acc / batch_step, total_acc / total_step))
                    if batch_step % FLAGS.evaluate_every == 0 and total_step > 0:
                        eval_loss = 0.
                        eval_acc = 0.
                        eval_step = 0
                        for start, end in zip(
                                range(0, dev_sample_size, FLAGS.batch_size),
                                range(FLAGS.batch_size, dev_sample_size,
                                      FLAGS.batch_size)):
                            batch_input_x = valid_x[start:end]
                            batch_input_x, mask = paddingX(
                                batch_input_x, FLAGS.sent_len, FLAGS.doc_len)
                            batch_input_y = valid_y[start:end]
                            batch_input_y = paddingY(batch_input_y,
                                                     FLAGS.num_classes)
                            feed_dict = {
                                text_han.input_x:
                                batch_input_x,
                                text_han.input_y:
                                batch_input_y,
                                text_han.mask:
                                mask,
                                # Disable dropout during evaluation
                                text_han.l1_dropout_keep_prob:
                                1.0,
                                text_han.l2_dropout_keep_prob:
                                1.0
                            }
                            step, summaries, loss, acc, logits = sess.run([
                                text_han.global_step, dev_summary_op,
                                text_han.loss_val, text_han.accuracy,
                                text_han.logits
                            ], feed_dict)
                            dev_summary_writer.add_summary(summaries, step)
                            zhihuStats(logits, batch_input_y)  # was: valid_y[start:end]
                            eval_loss += loss
                            eval_acc += acc
                            eval_step += 1
                        this_step_zhihu_score = calZhihuScore()
                        time_str = datetime.datetime.now().isoformat()
                        print(
                            '[%s]Eval_Loss:%.4f\tEval_Accuracy:%.4f\tZhihu_Score:%.4f'
                            % (time_str, eval_loss / eval_step,
                               eval_acc / eval_step, this_step_zhihu_score))
                        this_step_valid_acc = eval_acc / eval_step
                        this_step_valid_loss = eval_loss / eval_step
                        #dev_summary_writer.add_summary(summaries, step)
                    if batch_step % FLAGS.checkpoint_every == 0 and total_step > 0:
                        if not FLAGS.save_best_model:
                            path = saver.save(sess,
                                              checkpoint_prefix,
                                              global_step=step)
                            print('Saved model checkpoint to %s' % path)
                        elif this_step_zhihu_score > best_valid_zhihu_score:
                            path = saver.save(sess,
                                              checkpoint_prefix,
                                              global_step=step)
                            print(
                                'Saved best zhihu_score model checkpoint to %s[%.4f,%.4f,%.4f]'
                                % (path, this_step_valid_loss,
                                   this_step_valid_acc, this_step_zhihu_score))
                            best_valid_zhihu_score = this_step_zhihu_score
                        elif this_step_valid_acc > best_valid_acc:
                            path = saver.save(sess,
                                              checkpoint_prefix,
                                              global_step=step)
                            print(
                                'Saved best acc model checkpoint to %s[%.4f,%.4f,%.4f]'
                                % (path, this_step_valid_loss,
                                   this_step_valid_acc, this_step_zhihu_score))
                            best_valid_acc = this_step_valid_acc
                        elif this_step_valid_loss < best_valid_loss:
                            path = saver.save(sess,
                                              checkpoint_prefix,
                                              global_step=step)
                            print(
                                'Saved best loss model checkpoint to %s[%.4f,%.4f,%.4f]'
                                % (path, this_step_valid_loss,
                                   this_step_valid_acc, this_step_zhihu_score))
                            best_valid_loss = this_step_valid_loss
                        elif total_step % 22000 == 0:
                            path = saver.save(sess,
                                              checkpoint_prefix,
                                              global_step=step)
                            print(
                                'Saved model checkpoint to %s[%.4f,%.4f,%.4f]'
                                % (path, this_step_valid_loss,
                                   this_step_valid_acc, this_step_zhihu_score))