Пример #1
0
def inference(sess, data_iterator, probs_op, placeholders):
    """Run ``probs_op`` over every example of ``data_iterator``.

    The outputs are presumably the predicted probabilities for the {0, 1}
    class of each sentence pair.

    Args:
        sess: Active ``tf.Session`` used to evaluate ``probs_op``.
        data_iterator: Object exposing ``size`` and
            ``next_batch(batch_size)`` returning ``(source, target, label)``.
        probs_op: Tensor evaluated per batch; its result must support
            ``tolist()``.
        placeholders: 5-tuple ``(x_source, source_seq_length, x_target,
            target_seq_length, labels)`` of feed placeholders.

    Returns:
        ``np.ndarray`` of the collected outputs. It is truncated to
        ``data_iterator.size`` rows, dropping any surplus rows the final
        batch may contribute.
    """
    (x_source, source_seq_length,
     x_target, target_seq_length,
     labels) = placeholders

    batch_size = FLAGS.batch_size
    # Ceiling division that is exact whether or not true division is in
    # effect. Under Python 2, ``int / int`` floors before ``np.ceil`` sees it,
    # which silently dropped the final partial batch.
    num_iter = (data_iterator.size + batch_size - 1) // batch_size
    probs = []
    for _ in range(num_iter):
        source, target, label = data_iterator.next_batch(batch_size)
        source_len = utils.sequence_length(source)
        target_len = utils.sequence_length(target)

        feed_dict = {
            x_source: source,
            x_target: target,
            labels: label,
            source_seq_length: source_len,
            target_seq_length: target_len
        }

        batch_probs = sess.run(probs_op, feed_dict=feed_dict)
        probs.extend(batch_probs.tolist())
    return np.array(probs[:data_iterator.size])
Пример #2
0
def eval_epoch(sess, model, data_iterator, summary_writer):
    """Evaluate the model for one pass over ``data_iterator`` and print metrics.

    Args:
        sess: Active ``tf.Session``.
        model: Built model exposing the input placeholders,
            ``decision_threshold``, ``mean_loss``, the ``accuracy`` /
            ``precision`` / ``recall`` pairs and ``summaries`` used below.
        data_iterator: Evaluation iterator with ``size``,
            ``next_batch(batch_size)`` and ``global_step`` attributes.
        summary_writer: Writer receiving ``model.summaries`` every
            ``FLAGS.steps_per_checkpoint`` batches.
    """
    # Reset local variables so that the running metrics (presumably
    # tf.metrics-style ops, whose state lives in local variables) cover only
    # this evaluation pass.
    sess.run(tf.local_variables_initializer())

    batch_size = FLAGS.batch_size
    # Ceiling division that is exact whether or not true division is in
    # effect. Python 2 ``int / int`` would floor and drop the last batch.
    num_iter = (data_iterator.size + batch_size - 1) // batch_size
    if num_iter <= 0:
        # Without this guard the metric variables printed below would be
        # unbound (NameError).
        print("  Testing:  no evaluation data.")
        return

    epoch_loss = 0.0
    for step in range(num_iter):
        source, target, label = data_iterator.next_batch(batch_size)
        source_len = utils.sequence_length(source)
        target_len = utils.sequence_length(target)
        feed_dict = {model.x_source: source,
                     model.x_target: target,
                     model.labels: label,
                     model.source_seq_length: source_len,
                     model.target_seq_length: target_len,
                     model.decision_threshold: FLAGS.decision_threshold}
        # The ``[1]`` entries are presumably the metric update ops, so the
        # values after the last batch are the metrics for the whole pass.
        loss_value, epoch_accuracy, epoch_precision, epoch_recall = sess.run(
            [model.mean_loss,
             model.accuracy[1],
             model.precision[1],
             model.recall[1]],
            feed_dict=feed_dict)
        epoch_loss += loss_value
        if step % FLAGS.steps_per_checkpoint == 0:
            summary = sess.run(model.summaries, feed_dict=feed_dict)
            summary_writer.add_summary(summary, global_step=data_iterator.global_step)

    # Mean per-batch loss. Dividing by the last step index (num_iter - 1) was
    # off by one and raised ZeroDivisionError for a single batch.
    epoch_loss /= num_iter
    epoch_f1 = utils.f1_score(epoch_precision, epoch_recall)
    print("  Testing:  Loss = {:.6f}, Accuracy = {:.4f}, "
          "Precision = {:.4f}, Recall = {:.4f}, F1 = {:.4f}"
          .format(epoch_loss, epoch_accuracy,
                  epoch_precision, epoch_recall, epoch_f1))
Пример #3
0
def inference(sess, data_iterator, probs_op, placeholders, scores_path='y_scores'):
    """Run ``probs_op`` over ``data_iterator`` and dump scores with their labels.

    The outputs are presumably the predicted probabilities for the {0, 1}
    class of each sentence pair.

    Args:
        sess: Active ``tf.Session`` used to evaluate ``probs_op``.
        data_iterator: Object exposing ``size`` and
            ``next_batch(batch_size)`` returning ``(source, target, label)``.
            ``label`` must support ``tolist()``.
        probs_op: Tensor evaluated per batch; its result must support
            ``tolist()``.
        placeholders: 5-tuple ``(x_source, source_seq_length, x_target,
            target_seq_length, labels)`` of feed placeholders.
        scores_path: File that receives one ``"<score>\\t<label>"`` line per
            example. It is overwritten. Defaults to ``'y_scores'`` in the
            current working directory, which was the previously hard-coded
            path.

    Returns:
        ``np.ndarray`` of the collected outputs, truncated to
        ``data_iterator.size`` rows.
    """
    (x_source, source_seq_length,
     x_target, target_seq_length,
     labels) = placeholders

    batch_size = FLAGS.batch_size
    # Ceiling division that is exact whether or not true division is in
    # effect. Python 2 ``int / int`` would floor and drop the last batch.
    num_iter = (data_iterator.size + batch_size - 1) // batch_size
    probs = []
    all_labels = []
    for _ in range(num_iter):
        source, target, label = data_iterator.next_batch(batch_size)
        source_len = utils.sequence_length(source)
        target_len = utils.sequence_length(target)

        feed_dict = {
            x_source: source,
            x_target: target,
            labels: label,
            source_seq_length: source_len,
            target_seq_length: target_len
        }

        batch_probs = sess.run(probs_op, feed_dict=feed_dict)
        probs.extend(batch_probs.tolist())
        all_labels.extend(label.tolist())

    probs = np.array(probs[:data_iterator.size])
    all_labels = np.array(all_labels[:data_iterator.size])
    with open(scores_path, 'w') as f:
        for s, l in zip(probs, all_labels):
            f.write(str(s) + '\t' + str(l) + '\n')

    return probs
Пример #4
0
def main(_):
    """Train the BiRNN sentence-pair classifier configured by FLAGS.

    Steps:
      1. Create and read the source/target vocabularies. Read the parallel
         training data, plus validation data when both
         ``--source_valid_path`` and ``--target_valid_path`` are set.
      2. Build the BiRNN graph from a ``Config`` filled from FLAGS.
      3. Train for ``FLAGS.num_epochs`` epochs. At the end of each epoch:
         print training metrics, save a checkpoint under
         ``FLAGS.checkpoint_dir`` and, when validation data exists, run
         ``eval_epoch`` on it.

    Args:
        _: Unused positional argument (presumably argv from ``tf.app.run``).

    Raises:
        ValueError: If ``--source_train_path`` or ``--target_train_path`` is
            not given.
    """
    # Explicit checks rather than ``assert``, so validation survives ``-O``.
    if not FLAGS.source_train_path:
        raise ValueError("--source_train_path is required.")
    if not FLAGS.target_train_path:
        raise ValueError("--target_train_path is required.")

    has_valid = bool(FLAGS.source_valid_path and FLAGS.target_valid_path)

    # Create vocabularies.
    # NOTE(review): both vocabulary files are written next to the *source*
    # training file. Kept as-is because other scripts may look for them
    # there; confirm this is intended.
    vocab_dir = os.path.dirname(FLAGS.source_train_path)
    source_vocab_path = os.path.join(vocab_dir, "vocabulary.source")
    target_vocab_path = os.path.join(vocab_dir, "vocabulary.target")
    utils.create_vocabulary(source_vocab_path, FLAGS.source_train_path, FLAGS.source_vocab_size)
    utils.create_vocabulary(target_vocab_path, FLAGS.target_train_path, FLAGS.target_vocab_size)

    # Read vocabularies (the reverse mappings are not needed for training).
    source_vocab, _ = utils.initialize_vocabulary(source_vocab_path)
    target_vocab, _ = utils.initialize_vocabulary(target_vocab_path)

    # Read parallel sentences.
    parallel_data = utils.read_data(FLAGS.source_train_path, FLAGS.target_train_path,
                                    source_vocab, target_vocab)

    # Read validation data set.
    valid_data = None
    if has_valid:
        valid_data = utils.read_data(FLAGS.source_valid_path, FLAGS.target_valid_path,
                                     source_vocab, target_vocab)

    # Initialize BiRNN.
    config = Config(len(source_vocab),
                    len(target_vocab),
                    FLAGS.embedding_size,
                    FLAGS.state_size,
                    FLAGS.hidden_size,
                    FLAGS.num_layers,
                    FLAGS.learning_rate,
                    FLAGS.max_gradient_norm,
                    FLAGS.use_lstm,
                    FLAGS.use_mean_pooling,
                    FLAGS.use_max_pooling,
                    FLAGS.source_embeddings_path,
                    FLAGS.target_embeddings_path,
                    FLAGS.fix_pretrained)

    model = BiRNN(config)

    # Build graph.
    model.build_graph()

    # Train model.
    with tf.Session() as sess:

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        train_iterator = utils.TrainingIteratorRandom(parallel_data, FLAGS.num_negative)
        train_summary_writer = tf.summary.FileWriter(os.path.join(FLAGS.checkpoint_dir, "train"), sess.graph)

        # Validation objects exist only when validation paths were given;
        # they stay None otherwise and must not be used.
        valid_iterator = None
        valid_summary_writer = None
        if has_valid:
            valid_iterator = utils.EvalIterator(valid_data)
            valid_summary_writer = tf.summary.FileWriter(os.path.join(FLAGS.checkpoint_dir, "valid"), sess.graph)

        try:
            epoch_loss = 0
            epoch_completed = 0
            batch_completed = 0

            # Multiply before dividing (in float) so Python 2 integer
            # division cannot floor the per-epoch batch count.
            num_iter = int(np.ceil(float(train_iterator.size) * FLAGS.num_epochs
                                   / FLAGS.batch_size))
            checkpoint_path = os.path.join(FLAGS.checkpoint_dir, "model.ckpt")
            start_time = time.time()
            # Only the training size is reported. The validation iterator may
            # not exist here, and referencing it raised NameError.
            print("Training model on {} sentence pairs per epoch:".
                  format(train_iterator.size))

            for step in range(num_iter):
                source, target, label = train_iterator.next_batch(FLAGS.batch_size)
                source_len = utils.sequence_length(source)
                target_len = utils.sequence_length(target)
                feed_dict = {model.x_source: source,
                             model.x_target: target,
                             model.labels: label,
                             model.source_seq_length: source_len,
                             model.target_seq_length: target_len,
                             model.input_dropout: FLAGS.keep_prob_input,
                             model.output_dropout: FLAGS.keep_prob_output,
                             model.decision_threshold: FLAGS.decision_threshold}

                # The ``[1]`` metric entries are presumably update ops whose
                # values accumulate until local variables are reset below.
                _, loss_value, epoch_accuracy, epoch_precision, epoch_recall = sess.run(
                    [model.train_op,
                     model.mean_loss,
                     model.accuracy[1],
                     model.precision[1],
                     model.recall[1]],
                    feed_dict=feed_dict)
                epoch_loss += loss_value
                batch_completed += 1
                # Write the model's training summaries.
                if step % FLAGS.steps_per_checkpoint == 0:
                    summary = sess.run(model.summaries, feed_dict=feed_dict)
                    train_summary_writer.add_summary(summary, global_step=step)
                # End of current epoch.
                if train_iterator.epoch_completed > epoch_completed:
                    epoch_time = time.time() - start_time
                    epoch_loss /= batch_completed
                    epoch_f1 = utils.f1_score(epoch_precision, epoch_recall)
                    epoch_completed += 1
                    print("Epoch {} in {:.0f} sec\n"
                          "  Training: Loss = {:.6f}, Accuracy = {:.4f}, "
                          "Precision = {:.4f}, Recall = {:.4f}, F1 = {:.4f}"
                          .format(epoch_completed, epoch_time,
                                  epoch_loss, epoch_accuracy,
                                  epoch_precision, epoch_recall, epoch_f1))
                    # Save a model checkpoint.
                    model.saver.save(sess, checkpoint_path, global_step=step)
                    # Evaluate model on the validation set.
                    if has_valid:
                        eval_epoch(sess, model, valid_iterator, valid_summary_writer)
                    # Initialize local variables for new epoch.
                    batch_completed = 0
                    epoch_loss = 0
                    sess.run(tf.local_variables_initializer())
                    start_time = time.time()

            print("Training done with {} steps.".format(num_iter))
        finally:
            # Close writers even if training fails. The validation writer
            # exists only when validation data was given.
            train_summary_writer.close()
            if valid_summary_writer is not None:
                valid_summary_writer.close()