Ejemplo n.º 1
0
def dev_step(sess, model, target_loss_weight):
    """Score `model` on `valid_dataset` and print evaluation metrics.

    Makes a single pass (one epoch) over the batches produced by
    `data_helpers.batch_iter`, feeding each batch with a dropout keep
    probability of 1.0. Every prediction is grouped by question id as an
    `(answer_id, label, prob_score)` tuple; the grouped results are then
    handed to the `metrics` helpers and the values are printed.

    Args:
        sess: Session-like object; `sess.run` is called with
            `[model.accuracy, model.probs]` and the feed dictionary.
        model: Object exposing the placeholders fed below and the
            `accuracy` / `probs` fetches.
        target_loss_weight: Forwarded unchanged to `batch_iter`.

    Returns:
        Whatever `metrics.mean_reciprocal_rank` returns for the grouped
        results.
    """
    scored_by_question = defaultdict(list)
    seen = 0
    correct = 0.0

    eval_batches = data_helpers.batch_iter(
        valid_dataset,
        FLAGS.batch_size,
        1,  # a single epoch over the evaluation data
        target_loss_weight,
        FLAGS.max_sequence_length,
        charVocab,
        FLAGS.max_word_length,
        shuffle=True)

    # Unpacking in the loop header keeps the strict 11-field arity check.
    for (questions, answers, question_lens, answer_lens, targets,
         target_weights, pair_ids, q_chars, q_char_lens, a_chars,
         a_char_lens) in eval_batches:
        feeds = {
            model.question: questions,
            model.answer: answers,
            model.question_len: question_lens,
            model.answer_len: answer_lens,
            model.target: targets,
            model.target_loss_weight: target_weights,
            model.dropout_keep_prob: 1.0,
            model.q_charVec: q_chars,
            model.q_charLen: q_char_lens,
            model.a_charVec: a_chars,
            model.a_charLen: a_char_lens
        }
        accuracy, probs = sess.run([model.accuracy, model.probs], feeds)

        batch_size = len(probs)
        seen += batch_size
        # Progress is only echoed when the running count lands exactly on a
        # multiple of 1000.
        if seen % 1000 == 0:
            print(seen)

        # batch accuracy is a mean, so weight it by the batch size.
        correct += batch_size * accuracy
        for row in range(batch_size):
            qid, aid, gold = pair_ids[row]
            scored_by_question[qid].append((aid, gold, probs[row]))

    # Sample-level accuracy accumulated from the per-batch accuracies.
    sample_accuracy = correct / seen
    print('num_test_samples: {}  test_accuracy: {}'.format(
        seen, sample_accuracy))

    accu, precision, recall, f1, loss = metrics.classification_metrics(
        scored_by_question)
    print('Accuracy: {}, Precision: {}  Recall: {}  F1: {} Loss: {}'.format(
        accu, precision, recall, f1, loss))

    # Ranking metrics; the cut-offs are evaluated in the order 1, 2, 5, 10.
    mrr = metrics.mean_reciprocal_rank(scored_by_question)
    precision_at = {
        k: metrics.top_k_precision(scored_by_question, k=k)
        for k in (1, 2, 5, 10)
    }
    total_valid_query = metrics.get_num_valid_query(scored_by_question)

    print(
        'MRR (mean reciprocal rank): {}\tTop-1 precision: {}\tNum_query: {}\n'.
        format(mrr, precision_at[1], total_valid_query))
    print('Top-2 precision: {}'.format(precision_at[2]))
    print('Top-5 precision: {}'.format(precision_at[5]))
    print('Top-10 precision: {}'.format(precision_at[10]))

    return mrr
Ejemplo n.º 2
0
            def dev_step():
                """Score `esim` on `test_dataset` and print evaluation metrics.

                Makes one pass (a single epoch) over the batches from
                `data_helpers.batch_iter`, feeding each with a dropout keep
                probability of 1.0, and groups every prediction by question
                id as `(answer_id, label, prob_score)`. The grouped results
                are passed to the `metrics` helpers and the values printed.

                Relies on names from the enclosing scope (`sess`, `esim`,
                `target_loss_weight`, `idf`, `SEQ_LEN`, `charVocab`, ...).

                Returns:
                    Whatever `metrics.mean_reciprocal_rank` returns for the
                    grouped results.
                """
                scored_by_question = defaultdict(list)
                seen = 0
                correct = 0.0

                eval_batches = data_helpers.batch_iter(
                    test_dataset,
                    FLAGS.batch_size,
                    1,  # a single epoch over the evaluation data
                    target_loss_weight,
                    idf,
                    SEQ_LEN,
                    charVocab,
                    FLAGS.max_word_length,
                    shuffle=True)

                # Unpacking in the loop header keeps the strict 14-field
                # arity check on every batch.
                for (questions, answers, question_lens, answer_lens, targets,
                     target_weights, pair_ids, pair_feature, q_word_feature,
                     a_word_feature, q_chars, q_char_lens, a_chars,
                     a_char_lens) in eval_batches:
                    feeds = {
                        esim.question: questions,
                        esim.answer: answers,
                        esim.question_len: question_lens,
                        esim.answer_len: answer_lens,
                        esim.target: targets,
                        esim.target_loss_weight: target_weights,
                        esim.dropout_keep_prob: 1.0,
                        esim.extra_feature: pair_feature,
                        esim.q_word_feature: q_word_feature,
                        esim.a_word_feature: a_word_feature,
                        esim.q_charVec: q_chars,
                        esim.q_charLen: q_char_lens,
                        esim.a_charVec: a_chars,
                        esim.a_charLen: a_char_lens
                    }
                    accuracy, probs = sess.run([esim.accuracy, esim.probs],
                                               feeds)

                    batch_size = len(probs)
                    seen += batch_size
                    # Progress is only echoed when the running count lands
                    # exactly on a multiple of 1000.
                    if seen % 1000 == 0:
                        print(seen)

                    # batch accuracy is a mean, so weight it by batch size.
                    correct += batch_size * accuracy
                    for row in range(batch_size):
                        qid, aid, gold = pair_ids[row]
                        scored_by_question[qid].append(
                            (aid, gold, probs[row]))

                # Sample-level accuracy from the per-batch accuracies.
                sample_accuracy = correct / seen
                print('num_test_samples: {}  test_accuracy: {}'.format(
                    seen, sample_accuracy))

                accu, precision, recall, f1, loss = (
                    metrics.classification_metrics(scored_by_question))
                print(
                    'Accuracy: {}, Precision: {}  Recall: {}  F1: {} Loss: {}'.
                    format(accu, precision, recall, f1, loss))

                # Ranking metrics over the per-question groups.
                mean_avg_precision = metrics.mean_average_precision(
                    scored_by_question)
                mrr = metrics.mean_reciprocal_rank(scored_by_question)
                top_1 = metrics.top_1_precision(scored_by_question)
                valid_queries = metrics.get_num_valid_query(
                    scored_by_question)
                print(
                    'MAP (mean average precision: {}\tMRR (mean reciprocal rank): {}\tTop-1 precision: {}\tNum_query: {}'
                    .format(mean_avg_precision, mrr, top_1, valid_queries))

                return mrr
Ejemplo n.º 3
0
                mvp = metrics.mean_average_precision(results)
                mrr = metrics.mean_reciprocal_rank(results)
                top_1_precision = metrics.top_1_precision(results)
                total_valid_query = metrics.get_num_valid_query(results)
                print(
                    'MAP (mean average precision: {}\tMRR (mean reciprocal rank): {}\tTop-1 precision: {}\tNum_query: {}'
                    .format(mvp, mrr, top_1_precision, total_valid_query))

                return mrr

            best_mrr = 0.0
            batches = data_helpers.batch_iter(train_dataset,
                                              FLAGS.batch_size,
                                              FLAGS.num_epochs,
                                              target_loss_weight,
                                              idf,
                                              SEQ_LEN,
                                              charVocab,
                                              FLAGS.max_word_length,
                                              shuffle=True)
            for batch in batches:
                x_question, x_answer, x_question_len, x_answer_len, x_target, x_target_weight, id_pairs, extra_feature, q_feature, a_feature, x_q_char, x_q_len, x_a_char, x_a_len = batch
                train_step(x_question, x_answer, x_question_len, x_answer_len,
                           x_target, x_target_weight, id_pairs, extra_feature,
                           q_feature, a_feature, x_q_char, x_q_len, x_a_char,
                           x_a_len)
                current_step = tf.train.global_step(sess, global_step)
                if current_step % FLAGS.evaluate_every == 0:
                    print("\nEvaluation:")
                    valid_mrr = dev_step()
                    if valid_mrr > best_mrr:
        q_char_len = graph.get_operation_by_name(
            "question_char_len").outputs[0]

        a_char_feature = graph.get_operation_by_name("answer_char").outputs[0]
        a_char_len = graph.get_operation_by_name("answer_char_len").outputs[0]

        # Tensors we want to evaluate
        prob = graph.get_operation_by_name("convolution-1/prob").outputs[0]

        results = defaultdict(list)
        num_test = 0
        test_batches = data_helpers.batch_iter(test_dataset,
                                               FLAGS.batch_size,
                                               1,
                                               target_loss_weight,
                                               idf,
                                               SEQ_LEN,
                                               charVocab,
                                               FLAGS.max_word_length,
                                               shuffle=False)
        for test_batch in test_batches:
            batch_question, batch_answer, batch_question_len, batch_answer_len, batch_target, batch_target_weight, batch_id_pairs, extra_feature, q_feature, a_feature, x_q_char, x_q_len, x_a_char, x_a_len = test_batch
            feed_dict = {
                question_x: batch_question,
                answer_x: batch_answer,
                question_len: batch_question_len,
                answer_len: batch_answer_len,
                dropout_keep_prob: 1.0,
                model_extra_feature: extra_feature,
                question_word_feature: q_feature,
                answer_word_feature: a_feature,
Ejemplo n.º 5
0
def train():
    """Build the DualEncoder graph, train it, and checkpoint on improvement.

    Creates a fresh graph and session, builds a `DualEncoder` model from
    `FLAGS`, `vocab` and `charVocab`, and optimises `model.mean_loss` with
    Adam under a staircase exponential learning-rate decay. Training batches
    come from `data_helpers.batch_iter` over `train_dataset`. Whenever the
    global step is past 10 and a multiple of `FLAGS.evaluate_every`,
    `dev_step` is run and a checkpoint is written if its return value beats
    the best value seen so far.

    Checkpoints are written under `./runs/<unix timestamp>/checkpoints/`.

    NOTE(review): the session is created here but never closed;
    `sess.as_default()` only installs it as the default session. Confirm
    whether the caller relies on process exit for cleanup.
    """
    graph = tf.Graph()
    with graph.as_default():
        # Requests the first GPU for everything built below; device logging
        # and soft placement are both controlled through FLAGS.
        with tf.device("/gpu:0"):
            session_conf = tf.ConfigProto(
                allow_soft_placement=FLAGS.allow_soft_placement,
                log_device_placement=FLAGS.log_device_placement)
            sess = tf.Session(config=session_conf)
            with sess.as_default():
                model = DualEncoder(sequence_length=FLAGS.max_sequence_length,
                                    vocab_size=len(vocab),
                                    embedding_size=FLAGS.embedding_dim,
                                    vocab=vocab,
                                    rnn_size=FLAGS.rnn_size,
                                    maxWordLength=FLAGS.max_word_length,
                                    charVocab=charVocab,
                                    l2_reg_lambda=FLAGS.l2_reg_lambda)
                # Define the training procedure. global_step is a
                # non-trainable counter incremented by apply_gradients below.
                global_step = tf.Variable(0,
                                          name="global_step",
                                          trainable=False)
                starter_learning_rate = 0.001
                # Passed through to batch_iter and dev_step; presumably
                # per-class loss weights -- TODO confirm against batch_iter.
                target_loss_weight = [1.0, 1.0]
                # Staircase decay: the rate is multiplied by 0.96 once every
                # 5000 global steps, starting from starter_learning_rate.
                learning_rate = tf.train.exponential_decay(
                    starter_learning_rate,
                    global_step,
                    5000,
                    0.96,
                    staircase=True)
                optimizer = tf.train.AdamOptimizer(learning_rate)
                grads_and_vars = optimizer.compute_gradients(model.mean_loss)
                train_op = optimizer.apply_gradients(grads_and_vars,
                                                     global_step=global_step)
                # Output directory for this run, named after the current
                # Unix timestamp so successive runs do not collide.
                timestamp = str(int(time.time()))
                out_dir = os.path.abspath(
                    os.path.join(os.path.curdir, "runs", timestamp))
                print("Writing to {}\n".format(out_dir))

                # Checkpoint directory. TensorFlow assumes this directory
                # already exists, so we need to create it.
                checkpoint_dir = os.path.abspath(
                    os.path.join(out_dir, "checkpoints"))
                checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                # Created after the model and optimizer so that every global
                # variable defined above is included in the checkpoints.
                saver = tf.train.Saver(tf.global_variables())

                # Initialize all variables
                sess.run(tf.global_variables_initializer())

                # Best value returned by dev_step so far; starting at 0.0
                # means a checkpoint is only written for a strictly positive
                # score.
                best_mrr = 0.0
                batches = data_helpers.batch_iter(train_dataset,
                                                  FLAGS.batch_size,
                                                  FLAGS.num_epochs,
                                                  target_loss_weight,
                                                  FLAGS.max_sequence_length,
                                                  charVocab,
                                                  FLAGS.max_word_length,
                                                  shuffle=True)
                for batch in batches:
                    # Each batch is an 11-tuple; id_pairs is unpacked but not
                    # fed to the model during training.
                    x_question, x_answer, x_question_len, x_answer_len, x_target, x_target_weight, id_pairs, x_q_char, x_q_len, x_a_char, x_a_len = batch
                    feed_dict = {
                        model.question: x_question,
                        model.answer: x_answer,
                        model.question_len: x_question_len,
                        model.answer_len: x_answer_len,
                        model.target: x_target,
                        model.target_loss_weight: x_target_weight,
                        model.dropout_keep_prob: FLAGS.dropout_keep_prob,
                        model.q_charVec: x_q_char,
                        model.q_charLen: x_q_len,
                        model.a_charVec: x_a_char,
                        model.a_charLen: x_a_len
                    }
                    train_step(sess, train_op, global_step, model, feed_dict)
                    current_step = tf.train.global_step(sess, global_step)
                    # Evaluate every FLAGS.evaluate_every steps, but never
                    # within the first 10 steps.
                    if current_step > 10 and current_step % FLAGS.evaluate_every == 0:
                        print("\nEvaluation:")
                        valid_mrr = dev_step(sess, model, target_loss_weight)
                        # Only checkpoint on a strict improvement.
                        if valid_mrr > best_mrr:
                            best_mrr = valid_mrr
                            # Passing global_step tags the checkpoint file
                            # name with the step it was saved at.
                            path = saver.save(sess,
                                              checkpoint_prefix,
                                              global_step=current_step)
                            print(
                                "Saved model checkpoint to {}\n".format(path))