예제 #1
0
def test_cnn():
    """Evaluate a saved CNN checkpoint on the test file and store the predictions.

    Configuration comes from the module-level ``args`` and progress is reported
    through the module-level ``logger``. The function:

    1. Loads the word2vec vocabulary and the test data through ``dh``.
    2. Asks ``dh._option(pattern=1)`` which checkpoint to use: ``'B'`` selects
       the best checkpoint under ``BEST_CPT_DIR``, any other answer selects the
       latest checkpoint under ``CPT_DIR``.
    3. Imports the checkpoint's meta graph into a fresh ``tf.Graph`` and
       restores the variables.
    4. Freezes the ``output/topKPreds`` node and writes it to
       ``graph/graph-cnn-<MODEL>.pb``.
    5. Runs the test data for one unshuffled epoch with
       ``dropout_keep_prob=1.0`` and ``is_training=False``.
    6. Logs loss, accuracy, micro precision/recall/F1 and AUC, and writes
       ``predictions.json`` into ``SAVE_DIR``.

    Returns:
        None.
    """
    # Print parameters used for the model
    dh.tab_printer(args, logger)

    # Load word2vec model (only the word -> index mapping is used in this function)
    word2idx, _ = dh.load_word2vec_matrix(args.word2vec_file)

    # Load data
    logger.info("Loading data...")
    logger.info("Data processing...")
    test_data = dh.load_data_and_labels(args, args.test_file, word2idx)

    # Load cnn model
    OPTION = dh._option(pattern=1)
    if OPTION == 'B':
        logger.info("Loading best model...")
        checkpoint_file = cm.get_best_checkpoint(BEST_CPT_DIR, select_maximum_value=True)
    else:
        logger.info("Loading latest model...")
        checkpoint_file = tf.train.latest_checkpoint(CPT_DIR)
    logger.info(checkpoint_file)

    graph = tf.Graph()
    with graph.as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=args.allow_soft_placement,
            log_device_placement=args.log_device_placement)
        session_conf.gpu_options.allow_growth = args.gpu_options_allow_growth
        # Using the session itself as the context manager installs it as the
        # default session AND closes it on exit; `sess.as_default()` alone
        # never releases the session's resources.
        with tf.Session(config=session_conf) as sess:
            # Load the saved meta graph and restore variables
            saver = tf.train.import_meta_graph("{0}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)

            # Get the placeholders from the graph by name
            input_x_front = graph.get_operation_by_name("input_x_front").outputs[0]
            input_x_behind = graph.get_operation_by_name("input_x_behind").outputs[0]
            input_y = graph.get_operation_by_name("input_y").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]
            is_training = graph.get_operation_by_name("is_training").outputs[0]

            # Tensors we want to evaluate: the first output of the topKPreds op
            # is used as the scores, the second as the predicted labels
            top_k_op = graph.get_operation_by_name("output/topKPreds")
            scores = top_k_op.outputs[0]
            predictions = top_k_op.outputs[1]
            loss = graph.get_operation_by_name("loss/loss").outputs[0]

            # Split the output nodes name by '|' if you have several output nodes
            output_node_names = "output/topKPreds"

            # Save the .pb model file
            output_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def,
                                                                            output_node_names.split("|"))
            tf.train.write_graph(output_graph_def, "graph", "graph-cnn-{0}.pb".format(MODEL), as_text=False)

            # Generate batches for one epoch
            batches_test = dh.batch_iter(list(create_input_data(test_data)), args.batch_size, 1, shuffle=False)

            # Collect the predictions here
            test_counter, test_loss = 0, 0.0
            true_labels = []
            predicted_labels = []
            predicted_scores = []

            for batch_test in batches_test:
                x_f, x_b, y_onehot = zip(*batch_test)
                feed_dict = {
                    input_x_front: x_f,
                    input_x_behind: x_b,
                    input_y: y_onehot,
                    dropout_keep_prob: 1.0,
                    is_training: False
                }

                batch_predicted_scores, batch_predicted_labels, batch_loss \
                    = sess.run([scores, predictions, loss], feed_dict)

                # The true label is the argmax of each label vector; for the
                # predictions only the first (top-1) entry of each row is kept
                true_labels.extend(np.argmax(label_vector) for label_vector in y_onehot)
                predicted_scores.extend(row[0] for row in batch_predicted_scores)
                predicted_labels.extend(row[0] for row in batch_predicted_labels)

                test_loss = test_loss + batch_loss
                test_counter = test_counter + 1

            # Average of the per-batch losses
            test_loss = float(test_loss / test_counter)

            # Convert once and reuse for every metric
            y_true = np.array(true_labels)
            y_pred = np.array(predicted_labels)

            # Calculate Precision & Recall & F1
            test_acc = accuracy_score(y_true=y_true, y_pred=y_pred)
            test_pre = precision_score(y_true=y_true, y_pred=y_pred, average='micro')
            test_rec = recall_score(y_true=y_true, y_pred=y_pred, average='micro')
            test_F1 = f1_score(y_true=y_true, y_pred=y_pred, average='micro')

            # Calculate the average AUC
            test_auc = roc_auc_score(y_true=y_true, y_score=np.array(predicted_scores), average='micro')

            logger.info("All Test Dataset: Loss {0:g} | Acc {1:g} | Precision {2:g} | "
                        "Recall {3:g} | F1 {4:g} | AUC {5:g}"
                        .format(test_loss, test_acc, test_pre, test_rec, test_F1, test_auc))

            # Save the prediction result
            if not os.path.exists(SAVE_DIR):
                os.makedirs(SAVE_DIR)
            dh.create_prediction_file(output_file=SAVE_DIR + "/predictions.json", front_data_id=test_data['f_id'],
                                      behind_data_id=test_data['b_id'], true_labels=true_labels,
                                      predict_labels=predicted_labels, predict_scores=predicted_scores)

    logger.info("All Done.")
def test_fasttext():
    """Evaluate a saved FASTTEXT checkpoint on the test file and store the predictions.

    Configuration comes from the module-level ``args`` and progress is reported
    through the module-level ``logger``. The function:

    1. Loads and pads the test data through ``dh``.
    2. Asks ``dh._option(pattern=1)`` which checkpoint to use: ``'B'`` selects
       the best checkpoint under ``BEST_CPT_DIR``, any other answer selects the
       latest checkpoint under ``CPT_DIR``.
    3. Imports the checkpoint's meta graph into a fresh ``tf.Graph`` and
       restores the variables.
    4. Freezes the ``output/scores`` node and writes it to
       ``graph/graph-fasttext-<MODEL>.pb``.
    5. Runs the test data for one unshuffled epoch with
       ``dropout_keep_prob=1.0`` and ``is_training=False``.
    6. Logs loss, AUC and AUPRC, plus micro precision/recall/F1 for labels
       chosen by ``args.threshold`` and for each top-k with k in
       ``1..args.topK``, and writes ``predictions.json`` into ``SAVE_DIR``.

    Returns:
        None.
    """
    # Print parameters used for the model
    dh.tab_printer(args, logger)

    # Load data
    logger.info("Loading data...")
    logger.info("Data processing...")
    test_data = dh.load_data_and_labels(args.test_file,
                                        args.num_classes,
                                        args.word2vec_file,
                                        data_aug_flag=False)

    logger.info("Data padding...")
    x_test, y_test = dh.pad_data(test_data, args.pad_seq_len)
    y_test_labels = test_data.labels

    # Load fasttext model
    OPTION = dh._option(pattern=1)
    if OPTION == 'B':
        logger.info("Loading best model...")
        checkpoint_file = cm.get_best_checkpoint(BEST_CPT_DIR,
                                                 select_maximum_value=True)
    else:
        logger.info("Loading latest model...")
        checkpoint_file = tf.train.latest_checkpoint(CPT_DIR)
    logger.info(checkpoint_file)

    graph = tf.Graph()
    with graph.as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=args.allow_soft_placement,
            log_device_placement=args.log_device_placement)
        session_conf.gpu_options.allow_growth = args.gpu_options_allow_growth
        # Using the session itself as the context manager installs it as the
        # default session AND closes it on exit; `sess.as_default()` alone
        # never releases the session's resources.
        with tf.Session(config=session_conf) as sess:
            # Load the saved meta graph and restore variables
            saver = tf.train.import_meta_graph(
                "{0}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)

            # Get the placeholders from the graph by name
            input_x = graph.get_operation_by_name("input_x").outputs[0]
            input_y = graph.get_operation_by_name("input_y").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name(
                "dropout_keep_prob").outputs[0]
            is_training = graph.get_operation_by_name("is_training").outputs[0]

            # Tensors we want to evaluate
            scores = graph.get_operation_by_name("output/scores").outputs[0]
            loss = graph.get_operation_by_name("loss/loss").outputs[0]

            # Split the output nodes name by '|' if you have several output nodes
            output_node_names = "output/scores"

            # Save the .pb model file
            output_graph_def = tf.graph_util.convert_variables_to_constants(
                sess, sess.graph_def, output_node_names.split("|"))
            tf.train.write_graph(output_graph_def,
                                 "graph",
                                 "graph-fasttext-{0}.pb".format(MODEL),
                                 as_text=False)

            # Generate batches for one epoch
            batches = dh.batch_iter(list(zip(x_test, y_test, y_test_labels)),
                                    args.batch_size,
                                    1,
                                    shuffle=False)

            test_counter, test_loss = 0, 0.0

            # One slot per k; slot index `top_num` holds the metric for top-(top_num + 1)
            test_pre_tk = [0.0] * args.topK
            test_rec_tk = [0.0] * args.topK
            test_F1_tk = [0.0] * args.topK

            # Collect the predictions here
            true_labels = []
            predicted_labels = []
            predicted_scores = []

            # Collect for calculating metrics
            true_onehot_labels = []
            predicted_onehot_scores = []
            predicted_onehot_labels_ts = []
            predicted_onehot_labels_tk = [[] for _ in range(args.topK)]

            for batch_test in batches:
                x_batch_test, y_batch_test, y_batch_test_labels = zip(
                    *batch_test)
                feed_dict = {
                    input_x: x_batch_test,
                    input_y: y_batch_test,
                    dropout_keep_prob: 1.0,
                    is_training: False
                }
                batch_scores, cur_loss = sess.run([scores, loss], feed_dict)

                # Prepare for calculating metrics
                true_onehot_labels.extend(y_batch_test)
                predicted_onehot_scores.extend(batch_scores)

                # Get the predicted labels by threshold
                batch_predicted_labels_ts, batch_predicted_scores_ts = \
                    dh.get_label_threshold(scores=batch_scores, threshold=args.threshold)

                # Add results to collection
                true_labels.extend(y_batch_test_labels)
                predicted_labels.extend(batch_predicted_labels_ts)
                predicted_scores.extend(batch_predicted_scores_ts)

                # Get onehot predictions by threshold
                predicted_onehot_labels_ts.extend(
                    dh.get_onehot_label_threshold(scores=batch_scores, threshold=args.threshold))

                # Get onehot predictions by topK
                for top_num in range(args.topK):
                    predicted_onehot_labels_tk[top_num].extend(
                        dh.get_onehot_label_topk(scores=batch_scores, top_num=top_num + 1))

                test_loss = test_loss + cur_loss
                test_counter = test_counter + 1

            # Convert once and reuse; the original rebuilt these arrays for every metric call
            y_true = np.array(true_onehot_labels)
            y_pred_ts = np.array(predicted_onehot_labels_ts)

            # Calculate Precision & Recall & F1
            test_pre_ts = precision_score(y_true=y_true, y_pred=y_pred_ts, average='micro')
            test_rec_ts = recall_score(y_true=y_true, y_pred=y_pred_ts, average='micro')
            test_F1_ts = f1_score(y_true=y_true, y_pred=y_pred_ts, average='micro')

            for top_num in range(args.topK):
                y_pred_tk = np.array(predicted_onehot_labels_tk[top_num])
                test_pre_tk[top_num] = precision_score(y_true=y_true, y_pred=y_pred_tk, average='micro')
                test_rec_tk[top_num] = recall_score(y_true=y_true, y_pred=y_pred_tk, average='micro')
                test_F1_tk[top_num] = f1_score(y_true=y_true, y_pred=y_pred_tk, average='micro')

            y_score = np.array(predicted_onehot_scores)

            # Calculate the average AUC
            test_auc = roc_auc_score(y_true=y_true, y_score=y_score, average='micro')

            # Calculate the average PR
            test_prc = average_precision_score(y_true=y_true, y_score=y_score, average="micro")

            # Average of the per-batch losses
            test_loss = float(test_loss / test_counter)

            logger.info(
                "All Test Dataset: Loss {0:g} | AUC {1:g} | AUPRC {2:g}".
                format(test_loss, test_auc, test_prc))

            # Predict by threshold
            logger.info(
                "Predict by threshold: Precision {0:g}, Recall {1:g}, F1 {2:g}"
                .format(test_pre_ts, test_rec_ts, test_F1_ts))

            # Predict by topK
            logger.info("Predict by topK:")
            for top_num in range(args.topK):
                logger.info(
                    "Top{0}: Precision {1:g}, Recall {2:g}, F1 {3:g}".format(
                        top_num + 1, test_pre_tk[top_num],
                        test_rec_tk[top_num], test_F1_tk[top_num]))

            # Save the prediction result
            if not os.path.exists(SAVE_DIR):
                os.makedirs(SAVE_DIR)
            dh.create_prediction_file(output_file=SAVE_DIR +
                                      "/predictions.json",
                                      data_id=test_data.testid,
                                      all_labels=true_labels,
                                      all_predict_labels=predicted_labels,
                                      all_predict_scores=predicted_scores)

    logger.info("All Done.")
def train_han():
    """Train the HAN model with periodic validation and checkpointing.

    Configuration comes from the module-level ``args`` and progress is reported
    through the module-level ``logger``. ``OPTION`` is not assigned inside this
    function, so it is read from the enclosing scope:

    * ``'T'`` initialises all variables, sets up the embedding projector and
      saves an initial ``embedding.ckpt``.
    * ``'R'`` restores the latest checkpoint found in the run's ``checkpoints``
      directory.

    The training loop runs ``args.epochs`` epochs of batches of size
    ``args.batch_size``. Every ``args.evaluate_steps`` steps the validation
    set is evaluated and the validation AUPRC is handed to the best-checkpoint
    saver; every ``args.checkpoint_steps`` steps a regular checkpoint is
    written.

    Returns:
        None.
    """
    # Print parameters used for the model
    dh.tab_printer(args, logger)

    # Load sentences, labels, and training parameters
    logger.info("Loading data...")
    logger.info("Data processing...")
    train_data = dh.load_data_and_labels(args.train_file, args.num_classes, args.word2vec_file, data_aug_flag=False)
    val_data = dh.load_data_and_labels(args.validation_file, args.num_classes, args.word2vec_file, data_aug_flag=False)

    logger.info("Data padding...")
    x_train, y_train = dh.pad_data(train_data, args.pad_seq_len)
    x_val, y_val = dh.pad_data(val_data, args.pad_seq_len)

    # Build vocabulary
    VOCAB_SIZE, EMBEDDING_SIZE, pretrained_word2vec_matrix = dh.load_word2vec_matrix(args.word2vec_file)

    # Build a graph and han object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=args.allow_soft_placement,
            log_device_placement=args.log_device_placement)
        session_conf.gpu_options.allow_growth = args.gpu_options_allow_growth
        # Using the session itself as the context manager installs it as the
        # default session AND closes it on exit; `sess.as_default()` alone
        # never releases the session's resources.
        with tf.Session(config=session_conf) as sess:
            han = TextHAN(
                sequence_length=args.pad_seq_len,
                vocab_size=VOCAB_SIZE,
                embedding_type=args.embedding_type,
                embedding_size=EMBEDDING_SIZE,
                lstm_hidden_size=args.lstm_dim,
                fc_hidden_size=args.fc_dim,
                num_classes=args.num_classes,
                l2_reg_lambda=args.l2_lambda,
                pretrained_embedding=pretrained_word2vec_matrix)

            # Define training procedure
            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                learning_rate = tf.train.exponential_decay(learning_rate=args.learning_rate,
                                                           global_step=han.global_step, decay_steps=args.decay_steps,
                                                           decay_rate=args.decay_rate, staircase=True)
                optimizer = tf.train.AdamOptimizer(learning_rate)
                # `variables` rather than `vars`, to avoid shadowing the builtin
                grads, variables = zip(*optimizer.compute_gradients(han.loss))
                grads, _ = tf.clip_by_global_norm(grads, clip_norm=args.norm_ratio)
                train_op = optimizer.apply_gradients(zip(grads, variables), global_step=han.global_step,
                                                     name="train_op")

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in zip(grads, variables):
                if g is not None:
                    grad_hist_summary = tf.summary.histogram("{0}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar("{0}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            # Output directory for models and summaries
            out_dir = dh.get_out_dir(OPTION, logger)
            checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
            best_checkpoint_dir = os.path.abspath(os.path.join(out_dir, "bestcheckpoints"))

            # Summaries for loss
            loss_summary = tf.summary.scalar("loss", han.loss)

            # Train summaries
            train_summary_op = tf.summary.merge([loss_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

            # Validation summaries
            validation_summary_op = tf.summary.merge([loss_summary])
            validation_summary_dir = os.path.join(out_dir, "summaries", "validation")
            validation_summary_writer = tf.summary.FileWriter(validation_summary_dir, sess.graph)

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=args.num_checkpoints)
            best_saver = cm.BestCheckpointSaver(save_dir=best_checkpoint_dir, num_to_keep=3, maximize=True)

            if OPTION == 'R':
                # Load han model
                logger.info("Loading model...")
                checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
                logger.info(checkpoint_file)

                # The graph has already been built above, so restore the
                # variables into it with the existing saver. Calling
                # tf.train.import_meta_graph here would add a second copy of
                # every node to the current graph.
                saver.restore(sess, checkpoint_file)
            if OPTION == 'T':
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())

                # Embedding visualization config
                config = projector.ProjectorConfig()
                embedding_conf = config.embeddings.add()
                embedding_conf.tensor_name = "embedding"
                embedding_conf.metadata_path = args.metadata_file

                projector.visualize_embeddings(train_summary_writer, config)
                projector.visualize_embeddings(validation_summary_writer, config)

                # Save the embedding visualization
                saver.save(sess, os.path.join(out_dir, "embedding", "embedding.ckpt"))

            current_step = sess.run(han.global_step)

            def train_step(x_batch, y_batch):
                """Run one optimisation step on a batch and record its summaries."""
                feed_dict = {
                    han.input_x: x_batch,
                    han.input_y: y_batch,
                    han.dropout_keep_prob: args.dropout_rate,
                    han.is_training: True
                }
                _, step, summaries, loss = sess.run(
                    [train_op, han.global_step, train_summary_op, han.loss], feed_dict)
                logger.info("step {0}: loss {1:g}".format(step, loss))
                train_summary_writer.add_summary(summaries, step)

            def validation_step(x_val, y_val, writer=None):
                """Evaluate the model on a validation set.

                Args:
                    x_val: Validation inputs, zipped element-wise with ``y_val``.
                    y_val: Validation label vectors.
                    writer: Optional summary writer; when given, the summaries
                        of every validation batch are added to it.

                Returns:
                    A 9-tuple ``(loss, auc, auprc, precision_ts, recall_ts,
                    f1_ts, precision_tk, recall_tk, f1_tk)`` where the ``_ts``
                    values are computed on labels picked by ``args.threshold``
                    and the ``_tk`` values are lists with one entry per k in
                    ``1..args.topK``.
                """
                batches_validation = dh.batch_iter(list(zip(x_val, y_val)), args.batch_size, 1)

                # Predict classes by threshold or topk ('ts': threshold; 'tk': topk)
                eval_counter, eval_loss = 0, 0.0

                eval_pre_tk = [0.0] * args.topK
                eval_rec_tk = [0.0] * args.topK
                eval_F1_tk = [0.0] * args.topK

                true_onehot_labels = []
                predicted_onehot_scores = []
                predicted_onehot_labels_ts = []
                predicted_onehot_labels_tk = [[] for _ in range(args.topK)]

                for batch_validation in batches_validation:
                    x_batch_val, y_batch_val = zip(*batch_validation)
                    feed_dict = {
                        han.input_x: x_batch_val,
                        han.input_y: y_batch_val,
                        han.dropout_keep_prob: 1.0,
                        han.is_training: False
                    }
                    step, summaries, scores, cur_loss = sess.run(
                        [han.global_step, validation_summary_op, han.scores, han.loss], feed_dict)

                    # Prepare for calculating metrics
                    true_onehot_labels.extend(y_batch_val)
                    predicted_onehot_scores.extend(scores)

                    # Predict by threshold
                    predicted_onehot_labels_ts.extend(
                        dh.get_onehot_label_threshold(scores=scores, threshold=args.threshold))

                    # Predict by topK
                    for top_num in range(args.topK):
                        predicted_onehot_labels_tk[top_num].extend(
                            dh.get_onehot_label_topk(scores=scores, top_num=top_num + 1))

                    eval_loss = eval_loss + cur_loss
                    eval_counter = eval_counter + 1

                    if writer:
                        writer.add_summary(summaries, step)

                # Average of the per-batch losses
                eval_loss = float(eval_loss / eval_counter)

                # Convert once and reuse for every metric
                y_true = np.array(true_onehot_labels)
                y_pred_ts = np.array(predicted_onehot_labels_ts)

                # Calculate Precision & Recall & F1
                eval_pre_ts = precision_score(y_true=y_true, y_pred=y_pred_ts, average='micro')
                eval_rec_ts = recall_score(y_true=y_true, y_pred=y_pred_ts, average='micro')
                eval_F1_ts = f1_score(y_true=y_true, y_pred=y_pred_ts, average='micro')

                for top_num in range(args.topK):
                    y_pred_tk = np.array(predicted_onehot_labels_tk[top_num])
                    eval_pre_tk[top_num] = precision_score(y_true=y_true, y_pred=y_pred_tk, average='micro')
                    eval_rec_tk[top_num] = recall_score(y_true=y_true, y_pred=y_pred_tk, average='micro')
                    eval_F1_tk[top_num] = f1_score(y_true=y_true, y_pred=y_pred_tk, average='micro')

                y_score = np.array(predicted_onehot_scores)

                # Calculate the average AUC
                eval_auc = roc_auc_score(y_true=y_true, y_score=y_score, average='micro')
                # Calculate the average PR
                eval_prc = average_precision_score(y_true=y_true, y_score=y_score, average='micro')

                return eval_loss, eval_auc, eval_prc, eval_pre_ts, eval_rec_ts, eval_F1_ts, \
                       eval_pre_tk, eval_rec_tk, eval_F1_tk

            # Generate batches
            batches_train = dh.batch_iter(
                list(zip(x_train, y_train)), args.batch_size, args.epochs)

            # Ceiling division: number of batches needed to cover x_train once
            num_batches_per_epoch = int((len(x_train) - 1) / args.batch_size) + 1

            try:
                # Training loop. For each batch...
                for batch_train in batches_train:
                    x_batch_train, y_batch_train = zip(*batch_train)
                    train_step(x_batch_train, y_batch_train)
                    current_step = tf.train.global_step(sess, han.global_step)

                    if current_step % args.evaluate_steps == 0:
                        logger.info("\nEvaluation:")
                        eval_loss, eval_auc, eval_prc, \
                        eval_pre_ts, eval_rec_ts, eval_F1_ts, eval_pre_tk, eval_rec_tk, eval_F1_tk = \
                            validation_step(x_val, y_val, writer=validation_summary_writer)

                        logger.info("All Validation set: Loss {0:g} | AUC {1:g} | AUPRC {2:g}"
                                    .format(eval_loss, eval_auc, eval_prc))

                        # Predict by threshold
                        logger.info("Predict by threshold: Precision {0:g}, Recall {1:g}, F1 {2:g}"
                                    .format(eval_pre_ts, eval_rec_ts, eval_F1_ts))

                        # Predict by topK
                        logger.info("Predict by topK:")
                        for top_num in range(args.topK):
                            logger.info("Top{0}: Precision {1:g}, Recall {2:g}, F1 {3:g}"
                                        .format(top_num+1, eval_pre_tk[top_num], eval_rec_tk[top_num],
                                                eval_F1_tk[top_num]))
                        # The best-checkpoint saver is driven by the validation AUPRC
                        best_saver.handle(eval_prc, sess, current_step)
                    if current_step % args.checkpoint_steps == 0:
                        checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                        path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                        logger.info("Saved model checkpoint to {0}\n".format(path))
                    if current_step % num_batches_per_epoch == 0:
                        current_epoch = current_step // num_batches_per_epoch
                        logger.info("Epoch {0} has finished!".format(current_epoch))
            finally:
                # Flush and release the event files even when training is
                # interrupted, so buffered summaries are not lost
                train_summary_writer.close()
                validation_summary_writer.close()

    logger.info("All Done.")
def test_rmidp():
    """Evaluate a saved RMIDP checkpoint on the test file and store the predictions.

    Configuration comes from the module-level ``args`` and progress is reported
    through the module-level ``logger``. The function:

    1. Loads the test data through ``dh`` and pads it into content, question
       and option inputs plus the targets.
    2. Asks ``dh.option(pattern=1)`` which checkpoint to use: ``'B'`` selects
       the best checkpoint under ``BEST_CPT_DIR``, any other answer selects the
       latest checkpoint under ``CPT_DIR``.
    3. Imports the checkpoint's meta graph into a fresh ``tf.Graph`` and
       restores the variables.
    4. Freezes the ``output/scores`` node and writes it to
       ``graph/graph-rmidp-<MODEL>.pb``.
    5. Runs the test data for one unshuffled epoch with
       ``dropout_keep_prob=1.0`` and ``is_training=False``.
    6. Logs loss, PCC, DOA, RMSE and R2, and writes ``predictions.json`` into
       ``SAVE_DIR``.

    Returns:
        None.
    """
    # Print parameters used for the model
    dh.tab_printer(args, logger)

    # Load data
    logger.info("Loading data...")
    logger.info("Data processing...")
    test_data = dh.load_data_and_labels(args.test_file, args.word2vec_file, data_aug_flag=False)

    logger.info("Data padding...")
    x_test_content, x_test_question, x_test_option, y_test = dh.pad_data(test_data, args.pad_seq_len)

    # Load rmidp model
    OPTION = dh.option(pattern=1)
    if OPTION == 'B':
        logger.info("Loading best model...")
        checkpoint_file = cm.get_best_checkpoint(BEST_CPT_DIR, select_maximum_value=True)
    else:
        logger.info("Loading latest model...")
        checkpoint_file = tf.train.latest_checkpoint(CPT_DIR)
    logger.info(checkpoint_file)

    graph = tf.Graph()
    with graph.as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=args.allow_soft_placement,
            log_device_placement=args.log_device_placement)
        session_conf.gpu_options.allow_growth = args.gpu_options_allow_growth
        # Using the session itself as the context manager installs it as the
        # default session AND closes it on exit; `sess.as_default()` alone
        # never releases the session's resources.
        with tf.Session(config=session_conf) as sess:
            # Load the saved meta graph and restore variables
            saver = tf.train.import_meta_graph("{0}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)

            # Get the placeholders from the graph by name
            input_x_content = graph.get_operation_by_name("input_x_content").outputs[0]
            input_x_question = graph.get_operation_by_name("input_x_question").outputs[0]
            input_x_option = graph.get_operation_by_name("input_x_option").outputs[0]
            input_y = graph.get_operation_by_name("input_y").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]
            is_training = graph.get_operation_by_name("is_training").outputs[0]

            # Tensors we want to evaluate
            scores = graph.get_operation_by_name("output/scores").outputs[0]
            loss = graph.get_operation_by_name("loss/loss").outputs[0]

            # Split the output nodes name by '|' if you have several output nodes
            output_node_names = "output/scores"

            # Save the .pb model file
            output_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def,
                                                                            output_node_names.split("|"))
            tf.train.write_graph(output_graph_def, "graph", "graph-rmidp-{0}.pb".format(MODEL), as_text=False)

            # Generate batches for one epoch
            batches = dh.batch_iter(list(zip(x_test_content, x_test_question, x_test_option, y_test)),
                                    args.batch_size, 1, shuffle=False)

            test_counter, test_loss = 0, 0.0

            # Collect the predictions here
            true_labels = []
            predicted_scores = []

            for batch_test in batches:
                x_batch_content, x_batch_question, x_batch_option, y_batch = zip(*batch_test)
                feed_dict = {
                    input_x_content: x_batch_content,
                    input_x_question: x_batch_question,
                    input_x_option: x_batch_option,
                    input_y: y_batch,
                    dropout_keep_prob: 1.0,
                    is_training: False
                }
                batch_scores, cur_loss = sess.run([scores, loss], feed_dict)

                # Prepare for calculating metrics
                true_labels.extend(y_batch)
                predicted_scores.extend(batch_scores)

                test_loss = test_loss + cur_loss
                test_counter = test_counter + 1

            # Calculate PCC & DOA
            pcc, doa = dh.evaluation(true_labels, predicted_scores)
            # Calculate RMSE (square root of the mean squared error)
            rmse = mean_squared_error(true_labels, predicted_scores) ** 0.5
            # Calculate R2
            r2 = r2_score(true_labels, predicted_scores)

            # Average of the per-batch losses
            test_loss = float(test_loss / test_counter)

            logger.info("All Test Dataset: Loss {0:g} | PCC {1:g} | DOA {2:g} | RMSE {3:g} | R2 {4:g}"
                        .format(test_loss, pcc, doa, rmse, r2))

            # Save the prediction result
            if not os.path.exists(SAVE_DIR):
                os.makedirs(SAVE_DIR)
            dh.create_prediction_file(output_file=SAVE_DIR + "/predictions.json", all_id=test_data.id,
                                      all_labels=true_labels, all_predict_scores=predicted_scores)

    logger.info("All Done.")
예제 #5
0
def train_sann():
    """Train the SANN model.

    Loads the word2vec vocabulary plus the training and validation data,
    builds a ``TextSANN`` graph optimised with Adam (exponentially decayed
    learning rate, gradients clipped by global norm) and then iterates over
    the batches produced by ``dh.batch_iter``.

    Every ``args.evaluate_steps`` steps the validation set is evaluated and
    its accuracy is handed to the best-checkpoint saver; every
    ``args.checkpoint_steps`` steps a regular checkpoint is written.

    The function takes no arguments and returns nothing. It is driven by the
    module-level ``args``, ``logger`` and ``OPTION``: with ``OPTION == 'T'``
    the variables are freshly initialised, with ``OPTION == 'R'`` the latest
    checkpoint is restored.
    """
    # Print parameters used for the model
    dh.tab_printer(args, logger)

    # Load word2vec model
    word2idx, embedding_matrix = dh.load_word2vec_matrix(args.word2vec_file)

    # Load sentences, labels, and training parameters
    logger.info("Loading data...")
    logger.info("Data processing...")
    train_data = dh.load_data_and_labels(args, args.train_file, word2idx)
    val_data = dh.load_data_and_labels(args, args.validation_file, word2idx)

    # Build a graph and sann object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=args.allow_soft_placement,
            log_device_placement=args.log_device_placement)
        session_conf.gpu_options.allow_growth = args.gpu_options_allow_growth
        # Entering the session itself (rather than `sess.as_default()`) makes
        # it the default session AND closes it when the block is left, so the
        # resources it holds are released even if training raises.
        with tf.Session(config=session_conf) as sess:
            sann = TextSANN(sequence_length=args.pad_seq_len,
                            vocab_size=len(word2idx),
                            embedding_type=args.embedding_type,
                            embedding_size=args.embedding_dim,
                            lstm_hidden_size=args.lstm_dim,
                            attention_unit_size=args.attention_dim,
                            attention_hops_size=args.attention_hops_dim,
                            fc_hidden_size=args.fc_dim,
                            num_classes=args.num_classes,
                            l2_reg_lambda=args.l2_lambda,
                            pretrained_embedding=embedding_matrix)

            # Define training procedure
            with tf.control_dependencies(
                    tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                learning_rate = tf.train.exponential_decay(
                    learning_rate=args.learning_rate,
                    global_step=sann.global_step,
                    decay_steps=args.decay_steps,
                    decay_rate=args.decay_rate,
                    staircase=True)
                optimizer = tf.train.AdamOptimizer(learning_rate)
                # `train_vars` instead of `vars` so the builtin is not shadowed
                grads, train_vars = zip(
                    *optimizer.compute_gradients(sann.loss))
                grads, _ = tf.clip_by_global_norm(grads,
                                                  clip_norm=args.norm_ratio)
                train_op = optimizer.apply_gradients(
                    zip(grads, train_vars),
                    global_step=sann.global_step,
                    name="train_op")

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in zip(grads, train_vars):
                if g is not None:
                    grad_hist_summary = tf.summary.histogram(
                        "{0}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar(
                        "{0}/grad/sparsity".format(v.name),
                        tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            # Output directory for models and summaries
            out_dir = dh.get_out_dir(OPTION, logger)
            checkpoint_dir = os.path.abspath(
                os.path.join(out_dir, "checkpoints"))
            best_checkpoint_dir = os.path.abspath(
                os.path.join(out_dir, "bestcheckpoints"))

            # Summaries for loss
            loss_summary = tf.summary.scalar("loss", sann.loss)

            # Train summaries
            train_summary_op = tf.summary.merge(
                [loss_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(
                train_summary_dir, sess.graph)

            # Validation summaries
            validation_summary_op = tf.summary.merge([loss_summary])
            validation_summary_dir = os.path.join(out_dir, "summaries",
                                                  "validation")
            validation_summary_writer = tf.summary.FileWriter(
                validation_summary_dir, sess.graph)

            saver = tf.train.Saver(tf.global_variables(),
                                   max_to_keep=args.num_checkpoints)
            best_saver = cm.BestCheckpointSaver(save_dir=best_checkpoint_dir,
                                                num_to_keep=3,
                                                maximize=True)

            if OPTION == 'R':
                # Load sann model
                logger.info("Loading model...")
                checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
                logger.info(checkpoint_file)

                # Load the saved meta graph and restore variables
                # NOTE(review): the meta graph is imported into a graph that
                # already holds the model built above - confirm this is the
                # intended way to resume training.
                saver = tf.train.import_meta_graph(
                    "{0}.meta".format(checkpoint_file))
                saver.restore(sess, checkpoint_file)
            if OPTION == 'T':
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())

                # Embedding visualization config
                config = projector.ProjectorConfig()
                embedding_conf = config.embeddings.add()
                embedding_conf.tensor_name = "embedding"
                embedding_conf.metadata_path = args.metadata_file

                projector.visualize_embeddings(train_summary_writer, config)
                projector.visualize_embeddings(validation_summary_writer,
                                               config)

                # Save the embedding visualization
                saver.save(
                    sess, os.path.join(out_dir, "embedding", "embedding.ckpt"))

            current_step = sess.run(sann.global_step)

            def train_step(batch_data):
                """Run one optimisation step on `batch_data` and log its loss."""
                x_f, x_b, y_onehot = zip(*batch_data)

                feed_dict = {
                    sann.input_x_front: x_f,
                    sann.input_x_behind: x_b,
                    sann.input_y: y_onehot,
                    sann.dropout_keep_prob: args.dropout_rate,
                    sann.is_training: True
                }
                _, step, summaries, loss = sess.run(
                    [train_op, sann.global_step, train_summary_op, sann.loss],
                    feed_dict)
                logger.info("step {0}: loss {1:g}".format(step, loss))
                train_summary_writer.add_summary(summaries, step)

            def validation_step(val_loader, writer=None):
                """Evaluate the model on a validation set.

                Returns a tuple ``(loss, accuracy, precision, recall, F1,
                AUC)`` where the loss is averaged over the batches and the
                other metrics are computed over all collected predictions.
                When `writer` is given, the per-batch summaries are added
                to it.
                """
                batches_validation = dh.batch_iter(
                    list(create_input_data(val_loader)), args.batch_size, 1)

                eval_counter, eval_loss = 0, 0.0
                true_labels = []
                predicted_scores = []
                predicted_labels = []

                for batch_validation in batches_validation:
                    x_f, x_b, y_onehot = zip(*batch_validation)
                    feed_dict = {
                        sann.input_x_front: x_f,
                        sann.input_x_behind: x_b,
                        sann.input_y: y_onehot,
                        sann.dropout_keep_prob: 1.0,
                        sann.is_training: False
                    }
                    step, summaries, predictions, cur_loss = sess.run([
                        sann.global_step, validation_summary_op,
                        sann.topKPreds, sann.loss
                    ], feed_dict)

                    # Prepare for calculating metrics: the class index of the
                    # one-hot target and the first entry of each row of the
                    # two `topKPreds` outputs.
                    true_labels.extend(np.argmax(i) for i in y_onehot)
                    predicted_scores.extend(j[0] for j in predictions[0])
                    predicted_labels.extend(k[0] for k in predictions[1])

                    eval_loss = eval_loss + cur_loss
                    eval_counter = eval_counter + 1

                    if writer:
                        writer.add_summary(summaries, step)

                eval_loss = float(eval_loss / eval_counter)

                # Convert once instead of once per metric
                y_true = np.array(true_labels)
                y_pred = np.array(predicted_labels)
                y_score = np.array(predicted_scores)

                # Calculate Precision & Recall & F1
                eval_acc = accuracy_score(y_true=y_true, y_pred=y_pred)
                eval_pre = precision_score(y_true=y_true, y_pred=y_pred,
                                           average='micro')
                eval_rec = recall_score(y_true=y_true, y_pred=y_pred,
                                        average='micro')
                eval_F1 = f1_score(y_true=y_true, y_pred=y_pred,
                                   average='micro')

                # Calculate the average AUC
                eval_auc = roc_auc_score(y_true=y_true, y_score=y_score,
                                         average='micro')

                return eval_loss, eval_acc, eval_pre, eval_rec, eval_F1, eval_auc

            # Generate batches
            batches_train = dh.batch_iter(list(create_input_data(train_data)),
                                          args.batch_size, args.epochs)
            num_batches_per_epoch = int(
                (len(train_data['f_pad_seqs']) - 1) / args.batch_size) + 1

            try:
                # Training loop. For each batch...
                for batch_train in batches_train:
                    train_step(batch_train)
                    current_step = tf.train.global_step(sess, sann.global_step)

                    if current_step % args.evaluate_steps == 0:
                        logger.info("\nEvaluation:")
                        eval_loss, eval_acc, eval_pre, eval_rec, eval_F1, eval_auc = \
                            validation_step(val_data, writer=validation_summary_writer)
                        logger.info(
                            "All Validation set: Loss {0:g} | Acc {1:g} | Precision {2:g} | "
                            "Recall {3:g} | F1 {4:g} | AUC {5:g}".format(
                                eval_loss, eval_acc, eval_pre, eval_rec, eval_F1,
                                eval_auc))
                        best_saver.handle(eval_acc, sess, current_step)
                    if current_step % args.checkpoint_steps == 0:
                        checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                        path = saver.save(sess,
                                          checkpoint_prefix,
                                          global_step=current_step)
                        logger.info("Saved model checkpoint to {0}\n".format(path))
                    if current_step % num_batches_per_epoch == 0:
                        current_epoch = current_step // num_batches_per_epoch
                        logger.info(
                            "Epoch {0} has finished!".format(current_epoch))
            finally:
                # Flush buffered summary events and release the event files
                train_summary_writer.close()
                validation_summary_writer.close()

    logger.info("All Done.")
def test_harnn():
    """Test HARNN model.

    Restores a trained HARNN checkpoint (the best one when ``dh._option``
    returns 'B', otherwise the latest), exports the frozen graph as a ``.pb``
    file and feeds the test set through the network once. It then logs the
    mean batch loss, AUC and AUPRC of the global output together with the
    threshold-based precision / recall / F1 of the global output and of each
    of the four level outputs. The thresholded predictions of the global
    output are finally passed to ``dh.create_prediction_file``.

    Takes no arguments and returns nothing; it reads the module-level
    ``args``, ``logger``, ``CPT_DIR``, ``BEST_CPT_DIR``, ``SAVE_DIR`` and
    ``MODEL``.
    """
    # Print parameters used for the model
    dh.tab_printer(args, logger)

    # Load data
    logger.info("Loading data...")
    logger.info("Data processing...")
    test_data = dh.load_data_and_labels(args.test_file, args.num_classes_list, args.total_classes,
                                        args.word2vec_file, data_aug_flag=False)

    logger.info("Data padding...")
    x_test, y_test, y_test_tuple = dh.pad_data(test_data, args.pad_seq_len)
    y_test_labels = test_data.labels

    # Load harnn model
    OPTION = dh._option(pattern=1)
    if OPTION == 'B':
        logger.info("Loading best model...")
        checkpoint_file = cm.get_best_checkpoint(BEST_CPT_DIR, select_maximum_value=True)
    else:
        logger.info("Loading latest model...")
        checkpoint_file = tf.train.latest_checkpoint(CPT_DIR)
    logger.info(checkpoint_file)

    graph = tf.Graph()
    with graph.as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=args.allow_soft_placement,
            log_device_placement=args.log_device_placement)
        session_conf.gpu_options.allow_growth = args.gpu_options_allow_growth
        # Entering the session itself (rather than `sess.as_default()`) makes
        # it the default session AND closes it when the block is left.
        with tf.Session(config=session_conf) as sess:
            # Load the saved meta graph and restore variables
            saver = tf.train.import_meta_graph("{0}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)

            # Get the placeholders from the graph by name
            input_x = graph.get_operation_by_name("input_x").outputs[0]
            input_y_first = graph.get_operation_by_name("input_y_first").outputs[0]
            input_y_second = graph.get_operation_by_name("input_y_second").outputs[0]
            input_y_third = graph.get_operation_by_name("input_y_third").outputs[0]
            input_y_fourth = graph.get_operation_by_name("input_y_fourth").outputs[0]
            input_y = graph.get_operation_by_name("input_y").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]
            beta = graph.get_operation_by_name("beta").outputs[0]
            is_training = graph.get_operation_by_name("is_training").outputs[0]

            # Tensors we want to evaluate
            first_scores = graph.get_operation_by_name("first-output/scores").outputs[0]
            second_scores = graph.get_operation_by_name("second-output/scores").outputs[0]
            third_scores = graph.get_operation_by_name("third-output/scores").outputs[0]
            fourth_scores = graph.get_operation_by_name("fourth-output/scores").outputs[0]
            scores = graph.get_operation_by_name("output/scores").outputs[0]
            loss = graph.get_operation_by_name("loss/loss").outputs[0]

            # Split the output nodes name by '|' if you have several output nodes
            output_node_names = "first-output/scores|second-output/scores|third-output/scores|fourth-output/scores|output/scores"

            # Save the .pb model file
            output_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def,
                                                                            output_node_names.split("|"))
            tf.train.write_graph(output_graph_def, "graph", "graph-harnn-{0}.pb".format(MODEL), as_text=False)

            # Generate batches for one epoch
            batches = dh.batch_iter(list(zip(x_test, y_test, y_test_tuple, y_test_labels)),
                                    args.batch_size, 1, shuffle=False)

            test_counter, test_loss = 0, 0.0

            # Collect the predictions here
            true_labels = []
            predicted_labels = []
            predicted_scores = []

            # Collect for calculating metrics. Slot 0 belongs to the global
            # output, slots 1-4 to the first..fourth level outputs.
            num_outputs = 5
            true_onehot_labels = [[] for _ in range(num_outputs)]
            predicted_onehot_scores = [[] for _ in range(num_outputs)]
            predicted_onehot_labels_ts = [[] for _ in range(num_outputs)]
            predicted_onehot_labels_tk = [[] for _ in range(args.topK)]

            for batch_test in batches:
                x_batch_test, y_batch_test, y_batch_test_tuple, y_batch_test_labels = zip(*batch_test)

                y_batch_test_first = [i[0] for i in y_batch_test_tuple]
                y_batch_test_second = [j[1] for j in y_batch_test_tuple]
                y_batch_test_third = [k[2] for k in y_batch_test_tuple]
                y_batch_test_fourth = [t[3] for t in y_batch_test_tuple]

                feed_dict = {
                    input_x: x_batch_test,
                    input_y_first: y_batch_test_first,
                    input_y_second: y_batch_test_second,
                    input_y_third: y_batch_test_third,
                    input_y_fourth: y_batch_test_fourth,
                    input_y: y_batch_test,
                    dropout_keep_prob: 1.0,
                    beta: args.beta,
                    is_training: False
                }
                batch_first_scores, batch_second_scores, batch_third_scores, batch_fourth_scores, batch_scores, cur_loss = \
                    sess.run([first_scores, second_scores, third_scores, fourth_scores, scores, loss], feed_dict)

                # Same slot order as the accumulators: global, then levels 1-4
                batch_true_by_output = [y_batch_test, y_batch_test_first, y_batch_test_second,
                                        y_batch_test_third, y_batch_test_fourth]
                batch_scores_by_output = [batch_scores, batch_first_scores, batch_second_scores,
                                          batch_third_scores, batch_fourth_scores]

                # Get the predicted labels by threshold
                batch_predicted_labels_ts, batch_predicted_scores_ts = \
                    dh.get_label_threshold(scores=batch_scores, threshold=args.threshold)

                # Add results to collection
                true_labels.extend(y_batch_test_labels)
                predicted_labels.extend(batch_predicted_labels_ts)
                predicted_scores.extend(batch_predicted_scores_ts)

                # Prepare for calculating metrics (targets, raw scores and
                # one-hot prediction by threshold for every output)
                for index, (onehot_batch, score_batch) in enumerate(
                        zip(batch_true_by_output, batch_scores_by_output)):
                    true_onehot_labels[index].extend(onehot_batch)
                    predicted_onehot_scores[index].extend(score_batch)
                    predicted_onehot_labels_ts[index].extend(
                        dh.get_onehot_label_threshold(scores=score_batch, threshold=args.threshold))

                # Get one-hot prediction by topK
                # NOTE(review): these top-K predictions are collected but not
                # reported anywhere in this function - confirm whether the
                # top-K metrics were dropped on purpose.
                for i in range(args.topK):
                    predicted_onehot_labels_tk[i].extend(
                        dh.get_onehot_label_topk(scores=batch_scores, top_num=i + 1))

                test_loss = test_loss + cur_loss
                test_counter = test_counter + 1

            # Calculate Precision & Recall & F1 & the average PR per output;
            # every tuple is (precision, recall, F1, AUPRC).
            output_metrics = []
            for onehot_true, onehot_pred, onehot_scores in zip(
                    true_onehot_labels, predicted_onehot_labels_ts, predicted_onehot_scores):
                y_true = np.array(onehot_true)
                y_pred = np.array(onehot_pred)
                y_score = np.array(onehot_scores)
                output_metrics.append((
                    precision_score(y_true=y_true, y_pred=y_pred, average='micro'),
                    recall_score(y_true=y_true, y_pred=y_pred, average='micro'),
                    f1_score(y_true=y_true, y_pred=y_pred, average='micro'),
                    average_precision_score(y_true=y_true, y_score=y_score, average="micro")))

            # Calculate the average AUC (global output only)
            test_auc = roc_auc_score(y_true=np.array(true_onehot_labels[0]),
                                     y_score=np.array(predicted_onehot_scores[0]), average='micro')

            test_loss = float(test_loss / test_counter)

            test_pre_ts, test_rec_ts, test_F1_ts, test_prc = output_metrics[0]

            logger.info("All Test Dataset: Loss {0:g} | AUC {1:g} | AUPRC {2:g}"
                        .format(test_loss, test_auc, test_prc))
            # Predict by threshold
            logger.info("Predict by threshold: Precision {0:g}, Recall {1:g}, F1 {2:g}"
                        .format(test_pre_ts, test_rec_ts, test_F1_ts))

            for level, (level_pre, level_rec, level_F1, level_prc) in enumerate(output_metrics[1:], start=1):
                logger.info("Predict by threshold in Level-{0}: Precision {1:g}, Recall {2:g}, F1 {3:g}, AUPRC {4:g}"
                            .format(level, level_pre, level_rec, level_F1, level_prc))

            # Save the prediction result
            if not os.path.exists(SAVE_DIR):
                os.makedirs(SAVE_DIR)
            dh.create_prediction_file(output_file=SAVE_DIR + "/predictions.json", data_id=test_data.patent_id,
                                      all_labels=true_labels, all_predict_labels=predicted_labels,
                                      all_predict_scores=predicted_scores)

    logger.info("All Done.")
예제 #7
0
def test_harnn():
    """Test HARNN model.

    Restores a trained HARNN checkpoint (the best one when ``dh._option``
    returns 'B', otherwise the latest), exports the frozen graph as a ``.pb``
    file and feeds the test set through the network once. For the global
    output and each of the four level outputs it logs the threshold-based
    precision / recall / F1 and the AUPRC (plus the AUC for the global
    output). The thresholded predictions of the global output are finally
    passed to ``dh.create_prediction_file``.

    Takes no arguments and returns nothing; it reads the module-level
    ``args``, ``logger``, ``CPT_DIR``, ``BEST_CPT_DIR``, ``SAVE_DIR`` and
    ``MODEL``.
    """
    # Print parameters used for the model
    dh.tab_printer(args, logger)

    # Load word2vec model (only the vocabulary index is used here)
    word2idx, _ = dh.load_word2vec_matrix(args.word2vec_file)

    # Load data
    logger.info("Loading data...")
    logger.info("Data processing...")
    test_data = dh.load_data_and_labels(args, args.test_file, word2idx)

    # Load harnn model
    OPTION = dh._option(pattern=1)
    if OPTION == 'B':
        logger.info("Loading best model...")
        checkpoint_file = cm.get_best_checkpoint(BEST_CPT_DIR,
                                                 select_maximum_value=True)
    else:
        logger.info("Loading latest model...")
        checkpoint_file = tf.train.latest_checkpoint(CPT_DIR)
    logger.info(checkpoint_file)

    graph = tf.Graph()
    with graph.as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=args.allow_soft_placement,
            log_device_placement=args.log_device_placement)
        session_conf.gpu_options.allow_growth = args.gpu_options_allow_growth
        # Entering the session itself (rather than `sess.as_default()`) makes
        # it the default session AND closes it when the block is left.
        with tf.Session(config=session_conf) as sess:
            # Load the saved meta graph and restore variables
            saver = tf.train.import_meta_graph(
                "{0}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)

            # Get the placeholders from the graph by name
            input_x = graph.get_operation_by_name("input_x").outputs[0]
            input_y_first = graph.get_operation_by_name(
                "input_y_first").outputs[0]
            input_y_second = graph.get_operation_by_name(
                "input_y_second").outputs[0]
            input_y_third = graph.get_operation_by_name(
                "input_y_third").outputs[0]
            input_y_fourth = graph.get_operation_by_name(
                "input_y_fourth").outputs[0]
            input_y = graph.get_operation_by_name("input_y").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name(
                "dropout_keep_prob").outputs[0]
            alpha = graph.get_operation_by_name("alpha").outputs[0]
            is_training = graph.get_operation_by_name("is_training").outputs[0]

            # Tensors we want to evaluate
            first_scores = graph.get_operation_by_name(
                "first-output/scores").outputs[0]
            second_scores = graph.get_operation_by_name(
                "second-output/scores").outputs[0]
            third_scores = graph.get_operation_by_name(
                "third-output/scores").outputs[0]
            fourth_scores = graph.get_operation_by_name(
                "fourth-output/scores").outputs[0]
            scores = graph.get_operation_by_name("output/scores").outputs[0]

            # Split the output nodes name by '|' if you have several output nodes
            output_node_names = "first-output/scores|second-output/scores|third-output/scores|fourth-output/scores|output/scores"

            # Save the .pb model file
            output_graph_def = tf.graph_util.convert_variables_to_constants(
                sess, sess.graph_def, output_node_names.split("|"))
            tf.train.write_graph(output_graph_def,
                                 "graph",
                                 "graph-harnn-{0}.pb".format(MODEL),
                                 as_text=False)

            # Generate batches for one epoch
            batches = dh.batch_iter(list(create_input_data(test_data)),
                                    args.batch_size,
                                    1,
                                    shuffle=False)

            # Collect the predictions here
            true_labels = []
            predicted_labels = []
            predicted_scores = []

            # Collect for calculating metrics. Slot 0 belongs to the global
            # output, slots 1-4 to the first..fourth level outputs.
            true_onehot_labels = [[], [], [], [], []]
            predicted_onehot_scores = [[], [], [], [], []]
            predicted_onehot_labels = [[], [], [], [], []]

            for batch_test in batches:
                x, sec, subsec, group, subgroup, y_onehot, y = zip(*batch_test)

                # Same slot order as the accumulators above
                y_batch_test_list = [y_onehot, sec, subsec, group, subgroup]

                feed_dict = {
                    input_x: x,
                    input_y_first: sec,
                    input_y_second: subsec,
                    input_y_third: group,
                    input_y_fourth: subgroup,
                    input_y: y_onehot,
                    dropout_keep_prob: 1.0,
                    alpha: args.alpha,
                    is_training: False
                }
                batch_global_scores, batch_first_scores, batch_second_scores, batch_third_scores, batch_fourth_scores = \
                    sess.run([scores, first_scores, second_scores, third_scores, fourth_scores], feed_dict)

                batch_scores = [
                    batch_global_scores, batch_first_scores,
                    batch_second_scores, batch_third_scores,
                    batch_fourth_scores
                ]

                # Get the predicted labels by threshold
                batch_predicted_labels_ts, batch_predicted_scores_ts = \
                    dh.get_label_threshold(scores=batch_scores[0], threshold=args.threshold)

                # Add results to collection
                true_labels.extend(y)
                predicted_labels.extend(batch_predicted_labels_ts)
                predicted_scores.extend(batch_predicted_scores_ts)

                for index, (onehot_batch, score_batch) in enumerate(
                        zip(y_batch_test_list, batch_scores)):
                    true_onehot_labels[index].extend(onehot_batch)
                    predicted_onehot_scores[index].extend(score_batch)
                    # Get one-hot prediction by threshold
                    predicted_onehot_labels[index].extend(
                        dh.get_onehot_label_threshold(scores=score_batch, threshold=args.threshold))

            # Calculate Precision & Recall & F1
            for index, (onehot_true, onehot_pred, onehot_scores) in enumerate(
                    zip(true_onehot_labels, predicted_onehot_labels,
                        predicted_onehot_scores)):
                # Convert once instead of once per metric
                y_true = np.array(onehot_true)
                y_pred = np.array(onehot_pred)
                y_score = np.array(onehot_scores)

                test_pre = precision_score(y_true=y_true,
                                           y_pred=y_pred,
                                           average='micro')
                test_rec = recall_score(y_true=y_true,
                                        y_pred=y_pred,
                                        average='micro')
                test_F1 = f1_score(y_true=y_true,
                                   y_pred=y_pred,
                                   average='micro')
                # The AUC is computed for every output but only reported for
                # the global one below.
                test_auc = roc_auc_score(y_true=y_true,
                                         y_score=y_score,
                                         average='micro')
                test_prc = average_precision_score(y_true=y_true,
                                                   y_score=y_score,
                                                   average="micro")
                if index == 0:
                    logger.info(
                        "[Global] Predict by threshold: Precision {0:g}, Recall {1:g}, "
                        "F1 {2:g}, AUC {3:g}, AUPRC {4:g}".format(
                            test_pre, test_rec, test_F1, test_auc, test_prc))
                else:
                    logger.info(
                        "[Local] Predict by threshold in Level-{0}: Precision {1:g}, Recall {2:g}, "
                        "F1 {3:g}, AUPRC {4:g}".format(index, test_pre,
                                                       test_rec, test_F1,
                                                       test_prc))

            # Save the prediction result
            if not os.path.exists(SAVE_DIR):
                os.makedirs(SAVE_DIR)
            dh.create_prediction_file(output_file=SAVE_DIR +
                                      "/predictions.json",
                                      data_id=test_data['uniq_id'],
                                      true_labels=true_labels,
                                      predict_labels=predicted_labels,
                                      predict_scores=predicted_scores)
    logger.info("All Done.")
예제 #8
0
def train_tarnn():
    """Train the TARNN model (TensorFlow 1.x graph/session API).

    Steps performed:

    1. Load and pad the training and validation data via ``dh``.
    2. Build a ``TextTARNN`` graph with an Adam optimizer whose learning rate
       decays exponentially and whose gradients are clipped by global norm.
    3. Either restore the latest checkpoint (``OPTION == 'R'``) or initialize
       all variables and set up the embedding projector (``OPTION == 'T'``).
    4. Run the training loop.  Every ``args.evaluate_steps`` steps the whole
       validation set is evaluated and its RMSE is passed to the
       best-checkpoint saver (created with ``maximize=False``); every
       ``args.checkpoint_steps`` steps a regular checkpoint is saved.

    The function takes no arguments and returns nothing; it relies on
    module-level names defined outside this block (``args``, ``logger``,
    ``dh``, ``cm``, ``OPTION``, ``TextTARNN``, ``projector``, ``tf``).
    """
    # Print parameters used for the model
    dh.tab_printer(args, logger)

    # Load sentences, labels, and training parameters
    logger.info("Loading data...")
    logger.info("Data processing...")
    train_data = dh.load_data_and_labels(args.train_file, args.word2vec_file, data_aug_flag=False)
    val_data = dh.load_data_and_labels(args.validation_file, args.word2vec_file, data_aug_flag=False)

    logger.info("Data padding...")
    x_train_content, x_train_question, x_train_option, y_train = dh.pad_data(train_data, args.pad_seq_len)
    x_val_content, x_val_question, x_val_option, y_val = dh.pad_data(val_data, args.pad_seq_len)

    # Build vocabulary
    VOCAB_SIZE, EMBEDDING_SIZE, pretrained_word2vec_matrix = dh.load_word2vec_matrix(args.word2vec_file)

    # Build a graph and tarnn object
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=args.allow_soft_placement,
            log_device_placement=args.log_device_placement)
        session_conf.gpu_options.allow_growth = args.gpu_options_allow_growth
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            tarnn = TextTARNN(
                sequence_length=args.pad_seq_len,
                vocab_size=VOCAB_SIZE,
                embedding_type=args.embedding_type,
                embedding_size=EMBEDDING_SIZE,
                rnn_hidden_size=args.rnn_dim,
                rnn_type=args.rnn_type,
                rnn_layers=args.rnn_layers,
                attention_type=args.attention_type,
                fc_hidden_size=args.fc_dim,
                l2_reg_lambda=args.l2_lambda,
                pretrained_embedding=pretrained_word2vec_matrix)

            # Define training procedure.  The control dependency makes the
            # ops registered under UPDATE_OPS run together with the train op.
            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                learning_rate = tf.train.exponential_decay(learning_rate=args.learning_rate,
                                                           global_step=tarnn.global_step, decay_steps=args.decay_steps,
                                                           decay_rate=args.decay_rate, staircase=True)
                optimizer = tf.train.AdamOptimizer(learning_rate)
                # `variables` (not `vars`) so the builtin is not shadowed.
                grads, variables = zip(*optimizer.compute_gradients(tarnn.loss))
                grads, _ = tf.clip_by_global_norm(grads, clip_norm=args.norm_ratio)
                train_op = optimizer.apply_gradients(zip(grads, variables), global_step=tarnn.global_step,
                                                     name="train_op")

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in zip(grads, variables):
                if g is not None:
                    grad_hist_summary = tf.summary.histogram("{0}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar("{0}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            # Output directory for models and summaries
            out_dir = dh.get_out_dir(OPTION, logger)
            checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
            best_checkpoint_dir = os.path.abspath(os.path.join(out_dir, "bestcheckpoints"))

            # Summaries for loss
            loss_summary = tf.summary.scalar("loss", tarnn.loss)

            # Train summaries
            train_summary_op = tf.summary.merge([loss_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

            # Validation summaries
            validation_summary_op = tf.summary.merge([loss_summary])
            validation_summary_dir = os.path.join(out_dir, "summaries", "validation")
            validation_summary_writer = tf.summary.FileWriter(validation_summary_dir, sess.graph)

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=args.num_checkpoints)
            # maximize=False: the value handed to the saver below is RMSE, so lower is better.
            best_saver = cm.BestCheckpointSaver(save_dir=best_checkpoint_dir, num_to_keep=3, maximize=False)

            if OPTION == 'R':
                # Load tarnn model
                logger.info("Loading model...")
                checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
                logger.info(checkpoint_file)

                # Load the saved meta graph and restore variables.
                # NOTE(review): import_meta_graph is called on a graph that already
                # contains the model built above -- confirm this does not duplicate
                # ops; restoring with the Saver created above may be sufficient.
                saver = tf.train.import_meta_graph("{0}.meta".format(checkpoint_file))
                saver.restore(sess, checkpoint_file)
            if OPTION == 'T':
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())

                # Embedding visualization config
                config = projector.ProjectorConfig()
                embedding_conf = config.embeddings.add()
                embedding_conf.tensor_name = "embedding"
                embedding_conf.metadata_path = args.metadata_file

                projector.visualize_embeddings(train_summary_writer, config)
                projector.visualize_embeddings(validation_summary_writer, config)

                # Save the embedding visualization
                saver.save(sess, os.path.join(out_dir, "embedding", "embedding.ckpt"))

            current_step = sess.run(tarnn.global_step)

            def train_step(x_batch_content, x_batch_question, x_batch_option, y_batch):
                """Run one optimizer update on a batch and log/record its loss."""
                feed_dict = {
                    tarnn.input_x_content: x_batch_content,
                    tarnn.input_x_question: x_batch_question,
                    tarnn.input_x_option: x_batch_option,
                    tarnn.input_y: y_batch,
                    tarnn.dropout_keep_prob: args.dropout_rate,
                    tarnn.is_training: True
                }
                _, step, summaries, loss = sess.run(
                    [train_op, tarnn.global_step, train_summary_op, tarnn.loss], feed_dict)
                logger.info("step {0}: loss {1:g}".format(step, loss))
                train_summary_writer.add_summary(summaries, step)

            def validation_step(x_val_content, x_val_question, x_val_option, y_val, writer=None):
                """Evaluate the model on a validation set.

                Runs one pass over the given data in batches of
                ``args.batch_size`` with dropout disabled, and returns
                ``(eval_loss, pcc, doa, rmse, r2)`` where ``eval_loss`` is the
                mean of the per-batch losses.  If ``writer`` is given, each
                batch's summary is written to it at the current global step.
                """
                batches_validation = dh.batch_iter(list(zip(x_val_content, x_val_question, x_val_option, y_val)),
                                                   args.batch_size, 1)

                eval_counter, eval_loss = 0, 0.0
                true_labels = []
                predicted_scores = []

                for batch_validation in batches_validation:
                    x_batch_content, x_batch_question, x_batch_option, y_batch = zip(*batch_validation)
                    feed_dict = {
                        tarnn.input_x_content: x_batch_content,
                        tarnn.input_x_question: x_batch_question,
                        tarnn.input_x_option: x_batch_option,
                        tarnn.input_y: y_batch,
                        tarnn.dropout_keep_prob: 1.0,
                        tarnn.is_training: False
                    }
                    step, summaries, scores, cur_loss = sess.run(
                        [tarnn.global_step, validation_summary_op, tarnn.scores, tarnn.loss], feed_dict)

                    # Prepare for calculating metrics
                    true_labels.extend(y_batch)
                    predicted_scores.extend(scores)

                    eval_loss = eval_loss + cur_loss
                    eval_counter = eval_counter + 1

                    if writer:
                        writer.add_summary(summaries, step)

                eval_loss = float(eval_loss / eval_counter)

                # Calculate PCC & DOA
                pcc, doa = dh.evaluation(true_labels, predicted_scores)
                # Calculate RMSE
                rmse = mean_squared_error(true_labels, predicted_scores) ** 0.5
                r2 = r2_score(true_labels, predicted_scores)

                return eval_loss, pcc, doa, rmse, r2

            # Generate batches
            batches_train = dh.batch_iter(list(zip(x_train_content, x_train_question, x_train_option, y_train)),
                                          args.batch_size, args.epochs)

            # Ceiling division: number of batches needed to cover the training set once.
            num_batches_per_epoch = int((len(y_train) - 1) / args.batch_size) + 1

            try:
                # Training loop. For each batch...
                for batch_train in batches_train:
                    x_batch_train_content, x_batch_train_question, x_batch_train_option, y_batch_train = \
                        zip(*batch_train)
                    train_step(x_batch_train_content, x_batch_train_question, x_batch_train_option, y_batch_train)
                    current_step = tf.train.global_step(sess, tarnn.global_step)

                    if current_step % args.evaluate_steps == 0:
                        logger.info("\nEvaluation:")
                        eval_loss, pcc, doa, rmse, r2 = validation_step(x_val_content, x_val_question,
                                                                        x_val_option, y_val,
                                                                        writer=validation_summary_writer)
                        logger.info("All Validation set: Loss {0:g} | PCC {1:g} | DOA {2:g} | RMSE {3:g} | R2 {4:g}"
                                    .format(eval_loss, pcc, doa, rmse, r2))
                        best_saver.handle(rmse, sess, current_step)
                    if current_step % args.checkpoint_steps == 0:
                        checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                        path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                        logger.info("Saved model checkpoint to {0}\n".format(path))
                    if current_step % num_batches_per_epoch == 0:
                        current_epoch = current_step // num_batches_per_epoch
                        logger.info("Epoch {0} has finished!".format(current_epoch))
            finally:
                # Flush and close the event files so buffered summaries are
                # written out even if training stops with an exception.
                train_summary_writer.close()
                validation_summary_writer.close()

    logger.info("All Done.")
def train():
    """Train the RMIDP model, checkpointing on validation RMSE.

    Loads the training/validation data, wraps them in ``DataLoader`` objects,
    builds the ``RMIDP`` network with an Adam optimizer, and trains for
    ``args.epochs`` epochs.  After every epoch the validation set is evaluated
    and the resulting RMSE is passed to ``cm.BestCheckpointSaver`` (created
    with ``maximize=False``).  Losses and metrics are written to a
    ``SummaryWriter`` in the ``summary`` directory.

    The function takes no arguments and returns nothing; it relies on
    module-level names defined outside this block (``args``, ``logger``,
    ``dh``, ``cm``, ``device``, ``OPTION``, ``RMIDP``, ``Loss``).
    """
    dh.tab_printer(args, logger)

    # Load sentences, labels, and training parameters
    logger.info("Loading data...")
    logger.info("Data processing...")
    train_data = dh.load_data_and_labels(args.train_file, args.word2vec_file)
    val_data = dh.load_data_and_labels(args.validation_file, args.word2vec_file)

    logger.info("Data padding...")
    train_dataset = dh.MyData(train_data, args.pad_seq_len, device)
    val_dataset = dh.MyData(val_data, args.pad_seq_len, device)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)

    # Load word2vec model
    VOCAB_SIZE, EMBEDDING_SIZE, pretrained_word2vec_matrix = dh.load_word2vec_matrix(args.word2vec_file)

    # Init network
    logger.info("Init nn...")
    net = RMIDP(args, VOCAB_SIZE, EMBEDDING_SIZE, pretrained_word2vec_matrix).to(device)

    print("Model's state_dict:")
    for param_tensor in net.state_dict():
        print(param_tensor, "\t", net.state_dict()[param_tensor].size())

    criterion = Loss()
    optimizer = torch.optim.Adam(net.parameters(), lr=args.learning_rate, weight_decay=args.l2_lambda)

    # NOTE(review): `saver` is only bound when OPTION is 'T' or 'R'; any other
    # value leads to a NameError at the first `saver.handle` call below.
    if OPTION == 'T':
        timestamp = str(int(time.time()))
        out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
        saver = cm.BestCheckpointSaver(save_dir=out_dir, num_to_keep=args.num_checkpoints, maximize=False)
        logger.info("Writing to {0}\n".format(out_dir))
    elif OPTION == 'R':
        timestamp = input("[Input] Please input the checkpoints model you want to restore: ")
        # A run directory name is expected to be a 10-digit string.
        while not (timestamp.isdigit() and len(timestamp) == 10):
            timestamp = input("[Warning] The format of your input is illegal, please re-input: ")
        out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
        saver = cm.BestCheckpointSaver(save_dir=out_dir, num_to_keep=args.num_checkpoints, maximize=False)
        logger.info("Writing to {0}\n".format(out_dir))
        # NOTE(review): `out_dir` is the run directory, not a checkpoint file;
        # confirm torch.load is meant to receive a file inside it.
        checkpoint = torch.load(out_dir)
        net.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    logger.info("Training...")
    writer = SummaryWriter('summary')

    def eval_model(val_loader, epoch):
        """Evaluate on the validation set and return the RMSE.

        Logs loss, PCC, DOA, RMSE and R2, and writes each of them to the
        summary writer using ``epoch`` as the step.
        """
        net.eval()
        eval_loss = 0.0
        true_labels, predicted_scores = [], []
        # No gradients are needed for evaluation; skip building the autograd graph.
        with torch.no_grad():
            for batch in val_loader:
                x_val_fb_content, x_val_fb_question, x_val_fb_option, \
                    _x_val_fb_clens, _x_val_fb_qlens, _x_val_fb_olens, y_val_fb = batch

                _, scores = net(x_val_fb_content, x_val_fb_question, x_val_fb_option)
                avg_batch_loss = criterion(scores, y_val_fb)
                eval_loss = eval_loss + avg_batch_loss.item()
                # NOTE(review): only index 0 of the label/score tensors is
                # collected -- confirm this matches the layout from dh.MyData.
                true_labels.extend(y_val_fb[0].tolist())
                predicted_scores.extend(scores[0].tolist())

        # Calculate the Metrics
        eval_rmse = mean_squared_error(true_labels, predicted_scores) ** 0.5
        eval_r2 = r2_score(true_labels, predicted_scores)
        eval_pcc, eval_doa = dh.evaluation(true_labels, predicted_scores)
        eval_loss = eval_loss / len(val_loader)
        cur_value = eval_rmse
        logger.info("All Validation set: Loss {0:g} | PCC {1:.4f} | DOA {2:.4f} | RMSE {3:.4f} | R2 {4:.4f}"
                    .format(eval_loss, eval_pcc, eval_doa, eval_rmse, eval_r2))
        writer.add_scalar('validation loss', eval_loss, epoch)
        writer.add_scalar('validation PCC', eval_pcc, epoch)
        writer.add_scalar('validation DOA', eval_doa, epoch)
        writer.add_scalar('validation RMSE', eval_rmse, epoch)
        writer.add_scalar('validation R2', eval_r2, epoch)
        return cur_value

    steps_per_epoch = len(train_loader)
    for epoch in tqdm(range(args.epochs), desc="Epochs:", leave=True):
        # Training step
        batches = trange(steps_per_epoch, desc="Batches", leave=True)
        for batch_cnt, batch in zip(batches, train_loader):
            net.train()
            x_train_fb_content, x_train_fb_question, x_train_fb_option, \
                _x_train_fb_clens, _x_train_fb_qlens, _x_train_fb_olens, y_train_fb = batch

            # Without zeroing, gradients accumulate across backward() calls.
            optimizer.zero_grad()
            _, scores = net(x_train_fb_content, x_train_fb_question, x_train_fb_option)
            avg_batch_loss = criterion(scores, y_train_fb)
            avg_batch_loss.backward()
            optimizer.step()    # Parameter updating
            loss_value = avg_batch_loss.item()
            batches.set_description("Batches (Loss={:.4f})".format(loss_value))
            logger.info('[epoch {0}, batch {1}] loss: {2:.4f}'.format(epoch + 1, batch_cnt, loss_value))
            # Use a step that keeps increasing across epochs; `batch_cnt` alone
            # restarts at 0 every epoch and would overwrite earlier points.
            writer.add_scalar('training loss', loss_value, epoch * steps_per_epoch + batch_cnt)
        # Evaluation step
        cur_value = eval_model(val_loader, epoch)
        saver.handle(cur_value, net, optimizer, epoch)
    writer.close()

    logger.info('Training Finished.')
예제 #10
0
def train():
    """Train the MOOCNet model, checkpointing on validation F1.

    Loads the training/validation data, wraps them in ``DataLoader`` objects
    (batched with ``dh.collate_fn``), builds the ``MOOCNet`` network with an
    Adam optimizer, and trains for ``args.epochs`` epochs.  After every epoch
    the validation set is evaluated and the resulting F1 score is passed to
    ``cm.BestCheckpointSaver``.  Losses and metrics are written to a
    ``SummaryWriter`` in the ``summary`` directory.

    The function takes no arguments and returns nothing; it relies on
    module-level names defined outside this block (``args``, ``logger``,
    ``dh``, ``cm``, ``device``, ``OPTION``, ``MOOCNet``, ``Loss``,
    ``create_input_data``).
    """
    dh.tab_printer(args, logger)

    # Load sentences, labels, and training parameters
    logger.info("Loading data...")
    logger.info("Data processing...")
    train_data = dh.load_data_and_labels(args.train_file)
    val_data = dh.load_data_and_labels(args.validation_file)

    logger.info("Data padding...")
    train_dataset = dh.MyData(train_data.activity, train_data.timestep,
                              train_data.labels)
    val_dataset = dh.MyData(val_data.activity, val_data.timestep,
                            val_data.labels)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              collate_fn=dh.collate_fn)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            collate_fn=dh.collate_fn)

    # Value passed to the network as COURSE_SIZE
    # (presumably the number of courses -- confirm against dh.course2vec).
    COURSE_SIZE = dh.course2vec(args.course2vec_file)

    # Init network
    logger.info("Init nn...")
    net = MOOCNet(args, COURSE_SIZE).to(device)

    print("Model's state_dict:")
    for param_tensor in net.state_dict():
        print(param_tensor, "\t", net.state_dict()[param_tensor].size())

    criterion = Loss()
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=args.learning_rate,
                                 weight_decay=args.l2_lambda)

    # The value handed to the saver is the validation F1, where higher is
    # better, so the saver must keep the checkpoints with the largest values.
    # NOTE(review): `saver` is only bound when OPTION is 'T' or 'R'; any other
    # value leads to a NameError at the first `saver.handle` call below.
    if OPTION == 'T':
        timestamp = str(int(time.time()))
        out_dir = os.path.abspath(
            os.path.join(os.path.curdir, "runs", timestamp))
        saver = cm.BestCheckpointSaver(save_dir=out_dir,
                                       num_to_keep=args.num_checkpoints,
                                       maximize=True)
        logger.info("Writing to {0}\n".format(out_dir))
    elif OPTION == 'R':
        timestamp = input(
            "[Input] Please input the checkpoints model you want to restore: ")
        # A run directory name is expected to be a 10-digit string.
        while not (timestamp.isdigit() and len(timestamp) == 10):
            timestamp = input(
                "[Warning] The format of your input is illegal, please re-input: "
            )
        out_dir = os.path.abspath(
            os.path.join(os.path.curdir, "runs", timestamp))
        saver = cm.BestCheckpointSaver(save_dir=out_dir,
                                       num_to_keep=args.num_checkpoints,
                                       maximize=True)
        logger.info("Writing to {0}\n".format(out_dir))
        # NOTE(review): `out_dir` is the run directory, not a checkpoint file;
        # confirm torch.load is meant to receive a file inside it.
        checkpoint = torch.load(out_dir)
        net.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    logger.info("Training...")
    writer = SummaryWriter('summary')

    def eval_model(val_loader, epoch):
        """Evaluate on the validation set and return the F1 score.

        Scores at or above ``args.threshold`` are counted as label 1, others
        as 0.  Logs loss, accuracy, precision, recall, F1, AUC and average
        precision, and writes each of them to the summary writer using
        ``epoch`` as the step.
        """
        net.eval()
        eval_loss = 0.0
        true_labels, predicted_scores, predicted_labels = [], [], []
        # No gradients are needed for evaluation; skip building the autograd graph.
        with torch.no_grad():
            for batch in val_loader:
                x_val, tsp_val, y_val = create_input_data(batch)
                _, scores = net(x_val, tsp_val)
                avg_batch_loss = criterion(scores, y_val)
                eval_loss = eval_loss + avg_batch_loss.item()
                true_labels.extend(y_val.tolist())
                for score in scores.tolist():
                    predicted_scores.append(score)
                    predicted_labels.append(1 if score >= args.threshold else 0)

        # Calculate the Metrics
        eval_acc = accuracy_score(true_labels, predicted_labels)
        eval_pre = precision_score(true_labels, predicted_labels)
        eval_rec = recall_score(true_labels, predicted_labels)
        eval_F1 = f1_score(true_labels, predicted_labels)
        eval_auc = roc_auc_score(true_labels, predicted_scores)
        eval_prc = average_precision_score(true_labels, predicted_scores)
        eval_loss = eval_loss / len(val_loader)
        cur_value = eval_F1
        logger.info(
            "All Validation set: Loss {0:g} | ACC {1:.4f} | PRE {2:.4f} | REC {3:.4f} | F1 {4:.4f} | AUC {5:.4f} | PRC {6:.4f}"
            .format(eval_loss, eval_acc, eval_pre, eval_rec, eval_F1, eval_auc,
                    eval_prc))
        writer.add_scalar('validation loss', eval_loss, epoch)
        writer.add_scalar('validation ACC', eval_acc, epoch)
        writer.add_scalar('validation PRECISION', eval_pre, epoch)
        writer.add_scalar('validation RECALL', eval_rec, epoch)
        writer.add_scalar('validation F1', eval_F1, epoch)
        writer.add_scalar('validation AUC', eval_auc, epoch)
        writer.add_scalar('validation PRC', eval_prc, epoch)
        return cur_value

    steps_per_epoch = len(train_loader)
    for epoch in tqdm(range(args.epochs), desc="Epochs:", leave=True):
        # Training step
        batches = trange(steps_per_epoch, desc="Batches", leave=True)
        for batch_cnt, batch in zip(batches, train_loader):
            net.train()
            x_train, tsp_train, y_train = create_input_data(batch)
            # Without zeroing, gradients accumulate across backward() calls.
            optimizer.zero_grad()
            _, scores = net(x_train, tsp_train)
            # TODO
            avg_batch_loss = criterion(scores, y_train)
            avg_batch_loss.backward()
            optimizer.step()  # Parameter updating
            loss_value = avg_batch_loss.item()
            batches.set_description(
                "Batches (Loss={:.4f})".format(loss_value))
            logger.info('[epoch {0}, batch {1}] loss: {2:.4f}'.format(
                epoch + 1, batch_cnt, loss_value))
            # Use a step that keeps increasing across epochs; `batch_cnt` alone
            # restarts at 0 every epoch and would overwrite earlier points.
            writer.add_scalar('training loss', loss_value,
                              epoch * steps_per_epoch + batch_cnt)
        # Evaluation step
        cur_value = eval_model(val_loader, epoch)
        saver.handle(cur_value, net, optimizer, epoch)
    writer.close()

    logger.info('Training Finished.')