コード例 #1
0
def train_sann():
    """Train the SANN model.

    Loads the word2vec matrix and the train/validation data, builds a
    ``TextSANN`` graph with an Adam optimizer (exponentially decayed learning
    rate, global-norm gradient clipping) and runs the training loop. Every
    ``args.evaluate_steps`` steps the validation set is evaluated and its
    accuracy is handed to the best-checkpoint saver; every
    ``args.checkpoint_steps`` steps a regular checkpoint is written.

    Configuration comes from the module-level ``args`` and ``OPTION``:
    ``OPTION == 'T'`` initializes the variables from scratch, while
    ``OPTION == 'R'`` restores them from the latest checkpoint found in the
    output directory.

    Raises:
        IOError: If ``OPTION == 'R'`` but there is no checkpoint to resume.
        ValueError: If the validation data yields no batches.
    """
    # Print parameters used for the model
    dh.tab_printer(args, logger)

    # Load word2vec model
    word2idx, embedding_matrix = dh.load_word2vec_matrix(args.word2vec_file)

    # Load sentences, labels, and training parameters
    logger.info("Loading data...")
    logger.info("Data processing...")
    train_data = dh.load_data_and_labels(args, args.train_file, word2idx)
    val_data = dh.load_data_and_labels(args, args.validation_file, word2idx)

    # Build a graph and sann object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=args.allow_soft_placement,
            log_device_placement=args.log_device_placement)
        session_conf.gpu_options.allow_growth = args.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        # Summary writers are registered here as they are opened so that the
        # `finally` clause can flush and close them even if training aborts.
        summary_writers = []
        try:
            with sess.as_default():
                sann = TextSANN(sequence_length=args.pad_seq_len,
                                vocab_size=len(word2idx),
                                embedding_type=args.embedding_type,
                                embedding_size=args.embedding_dim,
                                lstm_hidden_size=args.lstm_dim,
                                attention_unit_size=args.attention_dim,
                                attention_hops_size=args.attention_hops_dim,
                                fc_hidden_size=args.fc_dim,
                                num_classes=args.num_classes,
                                l2_reg_lambda=args.l2_lambda,
                                pretrained_embedding=embedding_matrix)

                # Define training procedure
                with tf.control_dependencies(
                        tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                    learning_rate = tf.train.exponential_decay(
                        learning_rate=args.learning_rate,
                        global_step=sann.global_step,
                        decay_steps=args.decay_steps,
                        decay_rate=args.decay_rate,
                        staircase=True)
                    optimizer = tf.train.AdamOptimizer(learning_rate)
                    # Named `train_vars` so the builtin `vars` is not shadowed
                    grads, train_vars = zip(
                        *optimizer.compute_gradients(sann.loss))
                    grads, _ = tf.clip_by_global_norm(
                        grads, clip_norm=args.norm_ratio)
                    train_op = optimizer.apply_gradients(
                        zip(grads, train_vars),
                        global_step=sann.global_step,
                        name="train_op")

                # Keep track of gradient values and sparsity (optional)
                grad_summaries = []
                for g, v in zip(grads, train_vars):
                    if g is not None:
                        grad_summaries.append(tf.summary.histogram(
                            "{0}/grad/hist".format(v.name), g))
                        grad_summaries.append(tf.summary.scalar(
                            "{0}/grad/sparsity".format(v.name),
                            tf.nn.zero_fraction(g)))
                grad_summaries_merged = tf.summary.merge(grad_summaries)

                # Output directory for models and summaries
                out_dir = dh.get_out_dir(OPTION, logger)
                checkpoint_dir = os.path.abspath(
                    os.path.join(out_dir, "checkpoints"))
                best_checkpoint_dir = os.path.abspath(
                    os.path.join(out_dir, "bestcheckpoints"))

                # Summaries for loss
                loss_summary = tf.summary.scalar("loss", sann.loss)

                # Train summaries
                train_summary_op = tf.summary.merge(
                    [loss_summary, grad_summaries_merged])
                train_summary_dir = os.path.join(
                    out_dir, "summaries", "train")
                train_summary_writer = tf.summary.FileWriter(
                    train_summary_dir, sess.graph)
                summary_writers.append(train_summary_writer)

                # Validation summaries
                validation_summary_op = tf.summary.merge([loss_summary])
                validation_summary_dir = os.path.join(
                    out_dir, "summaries", "validation")
                validation_summary_writer = tf.summary.FileWriter(
                    validation_summary_dir, sess.graph)
                summary_writers.append(validation_summary_writer)

                saver = tf.train.Saver(tf.global_variables(),
                                       max_to_keep=args.num_checkpoints)
                best_saver = cm.BestCheckpointSaver(
                    save_dir=best_checkpoint_dir,
                    num_to_keep=3,
                    maximize=True)

                if OPTION == 'R':
                    # Load sann model
                    logger.info("Loading model...")
                    checkpoint_file = tf.train.latest_checkpoint(
                        checkpoint_dir)
                    logger.info(checkpoint_file)
                    if checkpoint_file is None:
                        # Fail with a clear message instead of trying to
                        # open the bogus path "None.meta".
                        raise IOError(
                            "No checkpoint found in {0}; cannot resume "
                            "training.".format(checkpoint_dir))

                    # Load the saved meta graph and restore variables.
                    # NOTE(review): this imports the checkpoint's meta graph
                    # into a graph that already holds the model built above;
                    # confirm the restore targets the variables `sann` uses.
                    saver = tf.train.import_meta_graph(
                        "{0}.meta".format(checkpoint_file))
                    saver.restore(sess, checkpoint_file)
                if OPTION == 'T':
                    if not os.path.exists(checkpoint_dir):
                        os.makedirs(checkpoint_dir)
                    sess.run(tf.global_variables_initializer())
                    sess.run(tf.local_variables_initializer())

                    # Embedding visualization config
                    config = projector.ProjectorConfig()
                    embedding_conf = config.embeddings.add()
                    embedding_conf.tensor_name = "embedding"
                    embedding_conf.metadata_path = args.metadata_file

                    projector.visualize_embeddings(
                        train_summary_writer, config)
                    projector.visualize_embeddings(
                        validation_summary_writer, config)

                    # Save the embedding visualization
                    saver.save(
                        sess,
                        os.path.join(out_dir, "embedding", "embedding.ckpt"))

                current_step = sess.run(sann.global_step)

                def train_step(batch_data):
                    """Run one optimizer step on a batch and log its loss."""
                    x_f, x_b, y_onehot = zip(*batch_data)

                    feed_dict = {
                        sann.input_x_front: x_f,
                        sann.input_x_behind: x_b,
                        sann.input_y: y_onehot,
                        sann.dropout_keep_prob: args.dropout_rate,
                        sann.is_training: True
                    }
                    _, step, summaries, loss = sess.run(
                        [train_op, sann.global_step, train_summary_op,
                         sann.loss],
                        feed_dict)
                    logger.info("step {0}: loss {1:g}".format(step, loss))
                    train_summary_writer.add_summary(summaries, step)

                def validation_step(val_loader, writer=None):
                    """Evaluate the model on a validation set.

                    Returns:
                        Tuple ``(loss, accuracy, precision, recall, F1,
                        AUC)``; the loss is the mean over validation batches
                        and the remaining metrics are micro-averaged.

                    Raises:
                        ValueError: If ``val_loader`` yields no batches.
                    """
                    batches_validation = dh.batch_iter(
                        list(create_input_data(val_loader)),
                        args.batch_size, 1)

                    eval_counter, eval_loss = 0, 0.0
                    true_labels = []
                    predicted_scores = []
                    predicted_labels = []

                    for batch_validation in batches_validation:
                        x_f, x_b, y_onehot = zip(*batch_validation)
                        feed_dict = {
                            sann.input_x_front: x_f,
                            sann.input_x_behind: x_b,
                            sann.input_y: y_onehot,
                            sann.dropout_keep_prob: 1.0,
                            sann.is_training: False
                        }
                        step, summaries, predictions, cur_loss = sess.run([
                            sann.global_step, validation_summary_op,
                            sann.topKPreds, sann.loss
                        ], feed_dict)

                        # Prepare for calculating metrics. Only the first
                        # entry per example is kept from `topKPreds`
                        # (presumably the top-1 score and label; confirm
                        # against the TextSANN definition).
                        true_labels.extend(
                            np.argmax(onehot) for onehot in y_onehot)
                        predicted_scores.extend(
                            top[0] for top in predictions[0])
                        predicted_labels.extend(
                            top[0] for top in predictions[1])

                        eval_loss = eval_loss + cur_loss
                        eval_counter = eval_counter + 1

                        if writer:
                            writer.add_summary(summaries, step)

                    if eval_counter == 0:
                        raise ValueError(
                            "Validation set produced no batches; cannot "
                            "compute metrics.")
                    eval_loss = float(eval_loss / eval_counter)

                    # Convert once and reuse for every metric
                    y_true = np.array(true_labels)
                    y_pred = np.array(predicted_labels)
                    y_score = np.array(predicted_scores)

                    # Calculate Precision & Recall & F1
                    eval_acc = accuracy_score(y_true=y_true, y_pred=y_pred)
                    eval_pre = precision_score(y_true=y_true, y_pred=y_pred,
                                               average='micro')
                    eval_rec = recall_score(y_true=y_true, y_pred=y_pred,
                                            average='micro')
                    eval_F1 = f1_score(y_true=y_true, y_pred=y_pred,
                                       average='micro')

                    # Calculate the average AUC
                    eval_auc = roc_auc_score(y_true=y_true, y_score=y_score,
                                             average='micro')

                    return (eval_loss, eval_acc, eval_pre, eval_rec, eval_F1,
                            eval_auc)

                # Generate batches
                batches_train = dh.batch_iter(
                    list(create_input_data(train_data)),
                    args.batch_size, args.epochs)
                num_batches_per_epoch = int(
                    (len(train_data['f_pad_seqs']) - 1) / args.batch_size) + 1

                # Training loop. For each batch...
                for batch_train in batches_train:
                    train_step(batch_train)
                    current_step = tf.train.global_step(
                        sess, sann.global_step)

                    if current_step % args.evaluate_steps == 0:
                        logger.info("\nEvaluation:")
                        (eval_loss, eval_acc, eval_pre, eval_rec, eval_F1,
                         eval_auc) = validation_step(
                             val_data, writer=validation_summary_writer)
                        logger.info(
                            "All Validation set: Loss {0:g} | Acc {1:g} | Precision {2:g} | "
                            "Recall {3:g} | F1 {4:g} | AUC {5:g}".format(
                                eval_loss, eval_acc, eval_pre, eval_rec,
                                eval_F1, eval_auc))
                        best_saver.handle(eval_acc, sess, current_step)
                    if current_step % args.checkpoint_steps == 0:
                        checkpoint_prefix = os.path.join(
                            checkpoint_dir, "model")
                        path = saver.save(sess,
                                          checkpoint_prefix,
                                          global_step=current_step)
                        logger.info(
                            "Saved model checkpoint to {0}\n".format(path))
                    if current_step % num_batches_per_epoch == 0:
                        current_epoch = current_step // num_batches_per_epoch
                        logger.info(
                            "Epoch {0} has finished!".format(current_epoch))
        finally:
            # `sess.as_default()` does not close the session, and unflushed
            # summaries are lost unless the writers are closed explicitly.
            for summary_writer in summary_writers:
                summary_writer.close()
            sess.close()

    logger.info("All Done.")
コード例 #2
0
def train_han():
    """Train the HAN model.

    Loads and pads the train/validation data, builds a ``TextHAN`` graph with
    an Adam optimizer (exponentially decayed learning rate, global-norm
    gradient clipping) and runs the training loop. Every
    ``args.evaluate_steps`` steps the validation set is evaluated and its
    average-precision score is handed to the best-checkpoint saver; every
    ``args.checkpoint_steps`` steps a regular checkpoint is written.

    Configuration comes from the module-level ``args`` and ``OPTION``:
    ``OPTION == 'T'`` initializes the variables from scratch, while
    ``OPTION == 'R'`` restores them from the latest checkpoint found in the
    output directory.

    Raises:
        IOError: If ``OPTION == 'R'`` but there is no checkpoint to resume.
        ValueError: If the validation data yields no batches.
    """
    # Print parameters used for the model
    dh.tab_printer(args, logger)

    # Load sentences, labels, and training parameters
    logger.info("Loading data...")
    logger.info("Data processing...")
    train_data = dh.load_data_and_labels(
        args.train_file, args.num_classes, args.word2vec_file,
        data_aug_flag=False)
    val_data = dh.load_data_and_labels(
        args.validation_file, args.num_classes, args.word2vec_file,
        data_aug_flag=False)

    logger.info("Data padding...")
    x_train, y_train = dh.pad_data(train_data, args.pad_seq_len)
    x_val, y_val = dh.pad_data(val_data, args.pad_seq_len)

    # Build vocabulary
    VOCAB_SIZE, EMBEDDING_SIZE, pretrained_word2vec_matrix = \
        dh.load_word2vec_matrix(args.word2vec_file)

    # Build a graph and han object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=args.allow_soft_placement,
            log_device_placement=args.log_device_placement)
        session_conf.gpu_options.allow_growth = args.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        # Summary writers are registered here as they are opened so that the
        # `finally` clause can flush and close them even if training aborts.
        summary_writers = []
        try:
            with sess.as_default():
                han = TextHAN(
                    sequence_length=args.pad_seq_len,
                    vocab_size=VOCAB_SIZE,
                    embedding_type=args.embedding_type,
                    embedding_size=EMBEDDING_SIZE,
                    lstm_hidden_size=args.lstm_dim,
                    fc_hidden_size=args.fc_dim,
                    num_classes=args.num_classes,
                    l2_reg_lambda=args.l2_lambda,
                    pretrained_embedding=pretrained_word2vec_matrix)

                # Define training procedure
                with tf.control_dependencies(
                        tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                    learning_rate = tf.train.exponential_decay(
                        learning_rate=args.learning_rate,
                        global_step=han.global_step,
                        decay_steps=args.decay_steps,
                        decay_rate=args.decay_rate,
                        staircase=True)
                    optimizer = tf.train.AdamOptimizer(learning_rate)
                    # Named `train_vars` so the builtin `vars` is not shadowed
                    grads, train_vars = zip(
                        *optimizer.compute_gradients(han.loss))
                    grads, _ = tf.clip_by_global_norm(
                        grads, clip_norm=args.norm_ratio)
                    train_op = optimizer.apply_gradients(
                        zip(grads, train_vars),
                        global_step=han.global_step,
                        name="train_op")

                # Keep track of gradient values and sparsity (optional)
                grad_summaries = []
                for g, v in zip(grads, train_vars):
                    if g is not None:
                        grad_summaries.append(tf.summary.histogram(
                            "{0}/grad/hist".format(v.name), g))
                        grad_summaries.append(tf.summary.scalar(
                            "{0}/grad/sparsity".format(v.name),
                            tf.nn.zero_fraction(g)))
                grad_summaries_merged = tf.summary.merge(grad_summaries)

                # Output directory for models and summaries
                out_dir = dh.get_out_dir(OPTION, logger)
                checkpoint_dir = os.path.abspath(
                    os.path.join(out_dir, "checkpoints"))
                best_checkpoint_dir = os.path.abspath(
                    os.path.join(out_dir, "bestcheckpoints"))

                # Summaries for loss
                loss_summary = tf.summary.scalar("loss", han.loss)

                # Train summaries
                train_summary_op = tf.summary.merge(
                    [loss_summary, grad_summaries_merged])
                train_summary_dir = os.path.join(
                    out_dir, "summaries", "train")
                train_summary_writer = tf.summary.FileWriter(
                    train_summary_dir, sess.graph)
                summary_writers.append(train_summary_writer)

                # Validation summaries
                validation_summary_op = tf.summary.merge([loss_summary])
                validation_summary_dir = os.path.join(
                    out_dir, "summaries", "validation")
                validation_summary_writer = tf.summary.FileWriter(
                    validation_summary_dir, sess.graph)
                summary_writers.append(validation_summary_writer)

                saver = tf.train.Saver(tf.global_variables(),
                                       max_to_keep=args.num_checkpoints)
                best_saver = cm.BestCheckpointSaver(
                    save_dir=best_checkpoint_dir,
                    num_to_keep=3,
                    maximize=True)

                if OPTION == 'R':
                    # Load han model
                    logger.info("Loading model...")
                    checkpoint_file = tf.train.latest_checkpoint(
                        checkpoint_dir)
                    logger.info(checkpoint_file)
                    if checkpoint_file is None:
                        # Fail with a clear message instead of trying to
                        # open the bogus path "None.meta".
                        raise IOError(
                            "No checkpoint found in {0}; cannot resume "
                            "training.".format(checkpoint_dir))

                    # Load the saved meta graph and restore variables.
                    # NOTE(review): this imports the checkpoint's meta graph
                    # into a graph that already holds the model built above;
                    # confirm the restore targets the variables `han` uses.
                    saver = tf.train.import_meta_graph(
                        "{0}.meta".format(checkpoint_file))
                    saver.restore(sess, checkpoint_file)
                if OPTION == 'T':
                    if not os.path.exists(checkpoint_dir):
                        os.makedirs(checkpoint_dir)
                    sess.run(tf.global_variables_initializer())
                    sess.run(tf.local_variables_initializer())

                    # Embedding visualization config
                    config = projector.ProjectorConfig()
                    embedding_conf = config.embeddings.add()
                    embedding_conf.tensor_name = "embedding"
                    embedding_conf.metadata_path = args.metadata_file

                    projector.visualize_embeddings(
                        train_summary_writer, config)
                    projector.visualize_embeddings(
                        validation_summary_writer, config)

                    # Save the embedding visualization
                    saver.save(
                        sess,
                        os.path.join(out_dir, "embedding", "embedding.ckpt"))

                current_step = sess.run(han.global_step)

                def train_step(x_batch, y_batch):
                    """Run one optimizer step on a batch and log its loss."""
                    feed_dict = {
                        han.input_x: x_batch,
                        han.input_y: y_batch,
                        han.dropout_keep_prob: args.dropout_rate,
                        han.is_training: True
                    }
                    _, step, summaries, loss = sess.run(
                        [train_op, han.global_step, train_summary_op,
                         han.loss],
                        feed_dict)
                    logger.info("step {0}: loss {1:g}".format(step, loss))
                    train_summary_writer.add_summary(summaries, step)

                def validation_step(x_val, y_val, writer=None):
                    """Evaluate the model on a validation set.

                    Labels are predicted twice: by thresholding the scores at
                    ``args.threshold`` ('ts') and by taking the top-k scores
                    for every k in ``1..args.topK`` ('tk').

                    Returns:
                        Tuple ``(loss, AUC, AUPRC, precision_ts, recall_ts,
                        F1_ts, precision_tk, recall_tk, F1_tk)`` where the
                        ``*_tk`` entries are lists of length ``args.topK``
                        and every metric is micro-averaged.

                    Raises:
                        ValueError: If the validation data yields no batches.
                    """
                    batches_validation = dh.batch_iter(
                        list(zip(x_val, y_val)), args.batch_size, 1)

                    # Predict classes by threshold or topk
                    # ('ts': threshold; 'tk': topk)
                    eval_counter, eval_loss = 0, 0.0

                    true_onehot_labels = []
                    predicted_onehot_scores = []
                    predicted_onehot_labels_ts = []
                    predicted_onehot_labels_tk = [
                        [] for _ in range(args.topK)]

                    for batch_validation in batches_validation:
                        x_batch_val, y_batch_val = zip(*batch_validation)
                        feed_dict = {
                            han.input_x: x_batch_val,
                            han.input_y: y_batch_val,
                            han.dropout_keep_prob: 1.0,
                            han.is_training: False
                        }
                        step, summaries, scores, cur_loss = sess.run(
                            [han.global_step, validation_summary_op,
                             han.scores, han.loss],
                            feed_dict)

                        # Prepare for calculating metrics
                        true_onehot_labels.extend(y_batch_val)
                        predicted_onehot_scores.extend(scores)

                        # Predict by threshold
                        predicted_onehot_labels_ts.extend(
                            dh.get_onehot_label_threshold(
                                scores=scores, threshold=args.threshold))

                        # Predict by topK
                        for top_num in range(args.topK):
                            predicted_onehot_labels_tk[top_num].extend(
                                dh.get_onehot_label_topk(
                                    scores=scores, top_num=top_num + 1))

                        eval_loss = eval_loss + cur_loss
                        eval_counter = eval_counter + 1

                        if writer:
                            writer.add_summary(summaries, step)

                    if eval_counter == 0:
                        raise ValueError(
                            "Validation set produced no batches; cannot "
                            "compute metrics.")
                    eval_loss = float(eval_loss / eval_counter)

                    # Convert once and reuse for every metric
                    y_true = np.array(true_onehot_labels)
                    y_score = np.array(predicted_onehot_scores)
                    y_pred_ts = np.array(predicted_onehot_labels_ts)

                    # Calculate Precision & Recall & F1
                    eval_pre_ts = precision_score(
                        y_true=y_true, y_pred=y_pred_ts, average='micro')
                    eval_rec_ts = recall_score(
                        y_true=y_true, y_pred=y_pred_ts, average='micro')
                    eval_F1_ts = f1_score(
                        y_true=y_true, y_pred=y_pred_ts, average='micro')

                    eval_pre_tk = [0.0] * args.topK
                    eval_rec_tk = [0.0] * args.topK
                    eval_F1_tk = [0.0] * args.topK
                    for top_num in range(args.topK):
                        y_pred_tk = np.array(
                            predicted_onehot_labels_tk[top_num])
                        eval_pre_tk[top_num] = precision_score(
                            y_true=y_true, y_pred=y_pred_tk, average='micro')
                        eval_rec_tk[top_num] = recall_score(
                            y_true=y_true, y_pred=y_pred_tk, average='micro')
                        eval_F1_tk[top_num] = f1_score(
                            y_true=y_true, y_pred=y_pred_tk, average='micro')

                    # Calculate the average AUC
                    eval_auc = roc_auc_score(
                        y_true=y_true, y_score=y_score, average='micro')
                    # Calculate the average PR
                    eval_prc = average_precision_score(
                        y_true=y_true, y_score=y_score, average='micro')

                    return (eval_loss, eval_auc, eval_prc,
                            eval_pre_ts, eval_rec_ts, eval_F1_ts,
                            eval_pre_tk, eval_rec_tk, eval_F1_tk)

                # Generate batches
                batches_train = dh.batch_iter(
                    list(zip(x_train, y_train)), args.batch_size, args.epochs)

                num_batches_per_epoch = int(
                    (len(x_train) - 1) / args.batch_size) + 1

                # Training loop. For each batch...
                for batch_train in batches_train:
                    x_batch_train, y_batch_train = zip(*batch_train)
                    train_step(x_batch_train, y_batch_train)
                    current_step = tf.train.global_step(sess, han.global_step)

                    if current_step % args.evaluate_steps == 0:
                        logger.info("\nEvaluation:")
                        (eval_loss, eval_auc, eval_prc,
                         eval_pre_ts, eval_rec_ts, eval_F1_ts,
                         eval_pre_tk, eval_rec_tk, eval_F1_tk) = \
                            validation_step(
                                x_val, y_val,
                                writer=validation_summary_writer)

                        logger.info(
                            "All Validation set: Loss {0:g} | AUC {1:g} | AUPRC {2:g}"
                            .format(eval_loss, eval_auc, eval_prc))

                        # Predict by threshold
                        logger.info(
                            "Predict by threshold: Precision {0:g}, Recall {1:g}, F1 {2:g}"
                            .format(eval_pre_ts, eval_rec_ts, eval_F1_ts))

                        # Predict by topK
                        logger.info("Predict by topK:")
                        for top_num in range(args.topK):
                            logger.info(
                                "Top{0}: Precision {1:g}, Recall {2:g}, F1 {3:g}"
                                .format(top_num + 1,
                                        eval_pre_tk[top_num],
                                        eval_rec_tk[top_num],
                                        eval_F1_tk[top_num]))
                        best_saver.handle(eval_prc, sess, current_step)
                    if current_step % args.checkpoint_steps == 0:
                        checkpoint_prefix = os.path.join(
                            checkpoint_dir, "model")
                        path = saver.save(sess, checkpoint_prefix,
                                          global_step=current_step)
                        logger.info(
                            "Saved model checkpoint to {0}\n".format(path))
                    if current_step % num_batches_per_epoch == 0:
                        current_epoch = current_step // num_batches_per_epoch
                        logger.info(
                            "Epoch {0} has finished!".format(current_epoch))
        finally:
            # `sess.as_default()` does not close the session, and unflushed
            # summaries are lost unless the writers are closed explicitly.
            for summary_writer in summary_writers:
                summary_writer.close()
            sess.close()

    logger.info("All Done.")
コード例 #3
0
def train_tarnn():
    """Train the TARNN model.

    Loads and pads the train/validation data (content, question and option
    inputs plus a target), builds a ``TextTARNN`` graph with an Adam
    optimizer (exponentially decayed learning rate, global-norm gradient
    clipping) and runs the training loop. Every ``args.evaluate_steps`` steps
    the validation set is evaluated and its RMSE is handed to the
    best-checkpoint saver (lower is better); every ``args.checkpoint_steps``
    steps a regular checkpoint is written.

    Configuration comes from the module-level ``args`` and ``OPTION``:
    ``OPTION == 'T'`` initializes the variables from scratch, while
    ``OPTION == 'R'`` restores them from the latest checkpoint found in the
    output directory.

    Raises:
        IOError: If ``OPTION == 'R'`` but there is no checkpoint to resume.
        ValueError: If the validation data yields no batches.
    """
    # Print parameters used for the model
    dh.tab_printer(args, logger)

    # Load sentences, labels, and training parameters
    logger.info("Loading data...")
    logger.info("Data processing...")
    train_data = dh.load_data_and_labels(
        args.train_file, args.word2vec_file, data_aug_flag=False)
    val_data = dh.load_data_and_labels(
        args.validation_file, args.word2vec_file, data_aug_flag=False)

    logger.info("Data padding...")
    x_train_content, x_train_question, x_train_option, y_train = \
        dh.pad_data(train_data, args.pad_seq_len)
    x_val_content, x_val_question, x_val_option, y_val = \
        dh.pad_data(val_data, args.pad_seq_len)

    # Build vocabulary
    VOCAB_SIZE, EMBEDDING_SIZE, pretrained_word2vec_matrix = \
        dh.load_word2vec_matrix(args.word2vec_file)

    # Build a graph and tarnn object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=args.allow_soft_placement,
            log_device_placement=args.log_device_placement)
        session_conf.gpu_options.allow_growth = args.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        # Summary writers are registered here as they are opened so that the
        # `finally` clause can flush and close them even if training aborts.
        summary_writers = []
        try:
            with sess.as_default():
                tarnn = TextTARNN(
                    sequence_length=args.pad_seq_len,
                    vocab_size=VOCAB_SIZE,
                    embedding_type=args.embedding_type,
                    embedding_size=EMBEDDING_SIZE,
                    rnn_hidden_size=args.rnn_dim,
                    rnn_type=args.rnn_type,
                    rnn_layers=args.rnn_layers,
                    attention_type=args.attention_type,
                    fc_hidden_size=args.fc_dim,
                    l2_reg_lambda=args.l2_lambda,
                    pretrained_embedding=pretrained_word2vec_matrix)

                # Define training procedure
                with tf.control_dependencies(
                        tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                    learning_rate = tf.train.exponential_decay(
                        learning_rate=args.learning_rate,
                        global_step=tarnn.global_step,
                        decay_steps=args.decay_steps,
                        decay_rate=args.decay_rate,
                        staircase=True)
                    optimizer = tf.train.AdamOptimizer(learning_rate)
                    # Named `train_vars` so the builtin `vars` is not shadowed
                    grads, train_vars = zip(
                        *optimizer.compute_gradients(tarnn.loss))
                    grads, _ = tf.clip_by_global_norm(
                        grads, clip_norm=args.norm_ratio)
                    train_op = optimizer.apply_gradients(
                        zip(grads, train_vars),
                        global_step=tarnn.global_step,
                        name="train_op")

                # Keep track of gradient values and sparsity (optional)
                grad_summaries = []
                for g, v in zip(grads, train_vars):
                    if g is not None:
                        grad_summaries.append(tf.summary.histogram(
                            "{0}/grad/hist".format(v.name), g))
                        grad_summaries.append(tf.summary.scalar(
                            "{0}/grad/sparsity".format(v.name),
                            tf.nn.zero_fraction(g)))
                grad_summaries_merged = tf.summary.merge(grad_summaries)

                # Output directory for models and summaries
                out_dir = dh.get_out_dir(OPTION, logger)
                checkpoint_dir = os.path.abspath(
                    os.path.join(out_dir, "checkpoints"))
                best_checkpoint_dir = os.path.abspath(
                    os.path.join(out_dir, "bestcheckpoints"))

                # Summaries for loss
                loss_summary = tf.summary.scalar("loss", tarnn.loss)

                # Train summaries
                train_summary_op = tf.summary.merge(
                    [loss_summary, grad_summaries_merged])
                train_summary_dir = os.path.join(
                    out_dir, "summaries", "train")
                train_summary_writer = tf.summary.FileWriter(
                    train_summary_dir, sess.graph)
                summary_writers.append(train_summary_writer)

                # Validation summaries
                validation_summary_op = tf.summary.merge([loss_summary])
                validation_summary_dir = os.path.join(
                    out_dir, "summaries", "validation")
                validation_summary_writer = tf.summary.FileWriter(
                    validation_summary_dir, sess.graph)
                summary_writers.append(validation_summary_writer)

                saver = tf.train.Saver(tf.global_variables(),
                                       max_to_keep=args.num_checkpoints)
                # maximize=False: the tracked metric is RMSE, lower is better
                best_saver = cm.BestCheckpointSaver(
                    save_dir=best_checkpoint_dir,
                    num_to_keep=3,
                    maximize=False)

                if OPTION == 'R':
                    # Load tarnn model
                    logger.info("Loading model...")
                    checkpoint_file = tf.train.latest_checkpoint(
                        checkpoint_dir)
                    logger.info(checkpoint_file)
                    if checkpoint_file is None:
                        # Fail with a clear message instead of trying to
                        # open the bogus path "None.meta".
                        raise IOError(
                            "No checkpoint found in {0}; cannot resume "
                            "training.".format(checkpoint_dir))

                    # Load the saved meta graph and restore variables.
                    # NOTE(review): this imports the checkpoint's meta graph
                    # into a graph that already holds the model built above;
                    # confirm the restore targets the variables `tarnn` uses.
                    saver = tf.train.import_meta_graph(
                        "{0}.meta".format(checkpoint_file))
                    saver.restore(sess, checkpoint_file)
                if OPTION == 'T':
                    if not os.path.exists(checkpoint_dir):
                        os.makedirs(checkpoint_dir)
                    sess.run(tf.global_variables_initializer())
                    sess.run(tf.local_variables_initializer())

                    # Embedding visualization config
                    config = projector.ProjectorConfig()
                    embedding_conf = config.embeddings.add()
                    embedding_conf.tensor_name = "embedding"
                    embedding_conf.metadata_path = args.metadata_file

                    projector.visualize_embeddings(
                        train_summary_writer, config)
                    projector.visualize_embeddings(
                        validation_summary_writer, config)

                    # Save the embedding visualization
                    saver.save(
                        sess,
                        os.path.join(out_dir, "embedding", "embedding.ckpt"))

                current_step = sess.run(tarnn.global_step)

                def train_step(x_batch_content, x_batch_question,
                               x_batch_option, y_batch):
                    """Run one optimizer step on a batch and log its loss."""
                    feed_dict = {
                        tarnn.input_x_content: x_batch_content,
                        tarnn.input_x_question: x_batch_question,
                        tarnn.input_x_option: x_batch_option,
                        tarnn.input_y: y_batch,
                        tarnn.dropout_keep_prob: args.dropout_rate,
                        tarnn.is_training: True
                    }
                    _, step, summaries, loss = sess.run(
                        [train_op, tarnn.global_step, train_summary_op,
                         tarnn.loss],
                        feed_dict)
                    logger.info("step {0}: loss {1:g}".format(step, loss))
                    train_summary_writer.add_summary(summaries, step)

                def validation_step(x_val_content, x_val_question,
                                    x_val_option, y_val, writer=None):
                    """Evaluate the model on a validation set.

                    Returns:
                        Tuple ``(loss, pcc, doa, rmse, r2)``; the loss is the
                        mean over validation batches, ``pcc`` and ``doa``
                        come from ``dh.evaluation``.

                    Raises:
                        ValueError: If the validation data yields no batches.
                    """
                    batches_validation = dh.batch_iter(
                        list(zip(x_val_content, x_val_question,
                                 x_val_option, y_val)),
                        args.batch_size, 1)

                    eval_counter, eval_loss = 0, 0.0
                    true_labels = []
                    predicted_scores = []

                    for batch_validation in batches_validation:
                        (x_batch_content, x_batch_question, x_batch_option,
                         y_batch) = zip(*batch_validation)
                        feed_dict = {
                            tarnn.input_x_content: x_batch_content,
                            tarnn.input_x_question: x_batch_question,
                            tarnn.input_x_option: x_batch_option,
                            tarnn.input_y: y_batch,
                            tarnn.dropout_keep_prob: 1.0,
                            tarnn.is_training: False
                        }
                        step, summaries, scores, cur_loss = sess.run(
                            [tarnn.global_step, validation_summary_op,
                             tarnn.scores, tarnn.loss],
                            feed_dict)

                        # Prepare for calculating metrics
                        true_labels.extend(y_batch)
                        predicted_scores.extend(scores)

                        eval_loss = eval_loss + cur_loss
                        eval_counter = eval_counter + 1

                        if writer:
                            writer.add_summary(summaries, step)

                    if eval_counter == 0:
                        raise ValueError(
                            "Validation set produced no batches; cannot "
                            "compute metrics.")
                    eval_loss = float(eval_loss / eval_counter)

                    # Calculate PCC & DOA
                    pcc, doa = dh.evaluation(true_labels, predicted_scores)
                    # Calculate RMSE
                    rmse = mean_squared_error(
                        true_labels, predicted_scores) ** 0.5
                    r2 = r2_score(true_labels, predicted_scores)

                    return eval_loss, pcc, doa, rmse, r2

                # Generate batches
                batches_train = dh.batch_iter(
                    list(zip(x_train_content, x_train_question,
                             x_train_option, y_train)),
                    args.batch_size, args.epochs)

                num_batches_per_epoch = int(
                    (len(y_train) - 1) / args.batch_size) + 1

                # Training loop. For each batch...
                for batch_train in batches_train:
                    (x_batch_train_content, x_batch_train_question,
                     x_batch_train_option, y_batch_train) = zip(*batch_train)
                    train_step(x_batch_train_content, x_batch_train_question,
                               x_batch_train_option, y_batch_train)
                    current_step = tf.train.global_step(
                        sess, tarnn.global_step)

                    if current_step % args.evaluate_steps == 0:
                        logger.info("\nEvaluation:")
                        eval_loss, pcc, doa, rmse, r2 = validation_step(
                            x_val_content, x_val_question, x_val_option,
                            y_val, writer=validation_summary_writer)
                        logger.info(
                            "All Validation set: Loss {0:g} | PCC {1:g} | DOA {2:g} | RMSE {3:g} | R2 {4:g}"
                            .format(eval_loss, pcc, doa, rmse, r2))
                        best_saver.handle(rmse, sess, current_step)
                    if current_step % args.checkpoint_steps == 0:
                        checkpoint_prefix = os.path.join(
                            checkpoint_dir, "model")
                        path = saver.save(sess, checkpoint_prefix,
                                          global_step=current_step)
                        logger.info(
                            "Saved model checkpoint to {0}\n".format(path))
                    if current_step % num_batches_per_epoch == 0:
                        current_epoch = current_step // num_batches_per_epoch
                        logger.info(
                            "Epoch {0} has finished!".format(current_epoch))
        finally:
            # `sess.as_default()` does not close the session, and unflushed
            # summaries are lost unless the writers are closed explicitly.
            for summary_writer in summary_writers:
                summary_writer.close()
            sess.close()

    logger.info("All Done.")