Example #1
def main(unused_argv):
    """Calls train and test routines for the dictionary model.

  If restore FLAG is true, loads an existing model and runs test
  routine. If restore FLAG is false, builds a model and trains it.
  """

    logger.info("train starts, params:" + str(sys.argv))

    if FLAGS.vocab_file is None:
        vocab_file = os.path.join(FLAGS.data_dir,
                                  "definitions_%s.vocab" % FLAGS.vocab_size)
    else:
        vocab_file = FLAGS.vocab_file

    # Build and train a dictionary model.
    if not FLAGS.restore:
        emb_size = FLAGS.embedding_size
        # Load any pre-trained word embeddings.
        if FLAGS.pretrained_input or FLAGS.pretrained_target:
            # embs_dict is a dictionary from words to vectors.
            embs_dict, pre_emb_dim = load_pretrained_embeddings(
                FLAGS.embeddings_path)
            if FLAGS.pretrained_input:
                emb_size = pre_emb_dim
        else:
            pre_embs, embs_dict = None, None

        # Create vocab file, process definitions (if necessary).
        data_utils.prepare_dict_data(FLAGS.data_dir,
                                     FLAGS.train_file,
                                     FLAGS.dev_file,
                                     vocabulary_size=FLAGS.vocab_size,
                                     max_seq_len=FLAGS.max_seq_len)
        # vocab is a dictionary from strings to integers.
        vocab, _ = data_utils.initialize_vocabulary(vocab_file)
        pre_embs = None
        if FLAGS.pretrained_input or FLAGS.pretrained_target:
            # pre_embs is a numpy array with row vectors for words in vocab.
            # for vocab words not in embs_dict, vector is all zeros.
            pre_embs = get_embedding_matrix(embs_dict, vocab, pre_emb_dim)

        # Build the TF graph for the dictionary model.
        model = build_model(max_seq_len=FLAGS.max_seq_len,
                            vocab_size=FLAGS.vocab_size,
                            emb_size=emb_size,
                            learning_rate=FLAGS.learning_rate,
                            encoder_type=FLAGS.encoder_type,
                            pretrained_target=FLAGS.pretrained_target,
                            pretrained_input=FLAGS.pretrained_input,
                            pre_embs=pre_embs)

        # Run the training for specified number of epochs.
        save_path, saver = train_network(model,
                                         FLAGS.num_epochs,
                                         FLAGS.batch_size,
                                         FLAGS.data_dir,
                                         FLAGS.save_dir,
                                         FLAGS.vocab_size,
                                         pre_embs,
                                         name=FLAGS.model_name)

    # Load an existing model.
    else:
        # Note: the cosine-loss output form is hard-coded here. For softmax
        # output, change "cosine" to "softmax".
        pre_embs = None
        if FLAGS.pretrained_input or FLAGS.pretrained_target:
            embs_dict, pre_emb_dim = load_pretrained_embeddings(
                FLAGS.embeddings_path)
            vocab, _ = data_utils.initialize_vocabulary(vocab_file)
            pre_embs = get_embedding_matrix(embs_dict, vocab, pre_emb_dim)

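        # Restore the saved graph on the CPU, then run evaluation and/or querying.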
        with tf.device("/cpu:0"):
            with tf.Session() as sess:
                (input_node, target_node, predictions, loss, vocab, rev_vocab,
                 dropout_keep_prob) = restore_model(sess,
                                                    FLAGS.save_dir,
                                                    vocab_file,
                                                    out_form="cosine")

                if FLAGS.evaluate:
                    evaluate_model(sess,
                                   FLAGS.data_dir,
                                   input_node,
                                   target_node,
                                   predictions,
                                   loss,
                                   pre_embs,
                                   dropout_keep_prob,
                                   out_form="cosine")

                # Load the final saved model and run querying routine.
                if FLAGS.query:
                    query_model(sess,
                                input_node,
                                predictions,
                                vocab,
                                rev_vocab,
                                FLAGS.max_seq_len,
                                dropout_keep_prob,
                                embs=pre_embs,
                                out_form="cosine")

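    # Write the default graph to ./graph for TensorBoard visualisation.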
    writer = tf.summary.FileWriter('./graph', tf.get_default_graph())
    writer.close()
Example #2
def main(unused_argv):
    """Calls train and test routines for the dictionary model.
  If restore FLAG is true, loads an existing model and runs test
  routine. If restore FLAG is false, builds a model and trains it.
  """
    if FLAGS.vocab_file is None:
        vocab_file = os.path.join(FLAGS.data_dir,
                                  "definitions_%s.vocab" % FLAGS.vocab_size)
    else:
        vocab_file = FLAGS.vocab_file

    # Build and train a dictionary model.
    if not FLAGS.restore:
        emb_size = FLAGS.embedding_size
        # Load any pre-trained word embeddings.
        if FLAGS.pretrained_input or FLAGS.pretrained_target:
            # embs_dict is a dictionary from words to vectors.
            embs_dict, pre_emb_dim = load_pretrained_embeddings(
                FLAGS.embeddings_path)
            if FLAGS.pretrained_input:
                emb_size = pre_emb_dim
        else:
            pre_embs, embs_dict = None, None

        # Create vocab file, process definitions (if necessary).
        data_utils.prepare_dict_data(FLAGS.data_dir,
                                     FLAGS.train_file,
                                     FLAGS.dev_file,
                                     vocabulary_size=FLAGS.vocab_size,
                                     max_seq_len=FLAGS.max_seq_len)
        # vocab is a dictionary from strings to integers.
        vocab, _ = data_utils.initialize_vocabulary(vocab_file)
        pre_embs = None
        if FLAGS.pretrained_input or FLAGS.pretrained_target:
            # pre_embs is a numpy array with row vectors for words in vocab.
            # for vocab words not in embs_dict, vector is all zeros.
            pre_embs = get_embedding_matrix(embs_dict, vocab, pre_emb_dim)

        # Build the TF graph for the dictionary model.
        model = build_model(max_seq_len=FLAGS.max_seq_len,
                            vocab_size=FLAGS.vocab_size,
                            emb_size=emb_size,
                            learning_rate=FLAGS.learning_rate,
                            encoder_type=FLAGS.encoder_type,
                            pretrained_target=FLAGS.pretrained_target,
                            pretrained_input=FLAGS.pretrained_input,
                            pre_embs=pre_embs)

        # Run the training for specified number of epochs.
        save_path, saver = train_network(model,
                                         FLAGS.num_epochs,
                                         FLAGS.batch_size,
                                         FLAGS.data_dir,
                                         FLAGS.save_dir,
                                         FLAGS.vocab_size,
                                         name=FLAGS.model_name)

    # Load an existing model.
    else:
        if FLAGS.restore_and_query:
            # Note: the cosine-loss output form is hard-coded here. For softmax
            # output, change "cosine" to "softmax".
            pre_embs = None
            if FLAGS.pretrained_input or FLAGS.pretrained_target:
                embs_dict, pre_emb_dim = load_pretrained_embeddings(
                    FLAGS.embeddings_path)
                vocab, _ = data_utils.initialize_vocabulary(vocab_file)
                pre_embs = get_embedding_matrix(embs_dict, vocab, pre_emb_dim)

            with tf.device("/cpu:0"):
                with tf.Session() as sess:
                    (input_node, target_node, predictions, loss, vocab,
                     rev_vocab) = restore_model(sess,
                                                FLAGS.save_dir,
                                                vocab_file,
                                                out_form="cosine")

                    if FLAGS.evaluate:
                        evaluate_model(sess,
                                       FLAGS.data_dir,
                                       input_node,
                                       target_node,
                                       predictions,
                                       loss,
                                       embs=pre_embs,
                                       out_form="cosine")

                    # Load the final saved model and run querying routine.
                    query_model(sess,
                                input_node,
                                predictions,
                                vocab,
                                rev_vocab,
                                FLAGS.max_seq_len,
                                embs=pre_embs,
                                out_form="cosine")
        else:
            embs_dict, pre_emb_dim = load_pretrained_embeddings(
                FLAGS.embeddings_path)
            vocab, _ = data_utils.initialize_vocabulary(vocab_file)
            pre_embs = get_embedding_matrix(embs_dict, vocab, pre_emb_dim)

            sess = tf.Session()
            model_path = tf.train.latest_checkpoint(FLAGS.save_dir)
            # restore the model from the meta graph
            saver = tf.train.import_meta_graph(model_path + ".meta")
            saver.restore(sess, model_path)
            graph = tf.get_default_graph()

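            # Recover the input/label placeholders, the total-loss tensor and
            # the Adam train op from the restored graph by name.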
            gloss_in = graph.get_tensor_by_name("input_placeholder:0")
            head_in = graph.get_tensor_by_name("labels_placeholder:0")
            total_loss = graph.get_tensor_by_name('total_loss:0')
            train_step = graph.get_operation_by_name('Adam')

            num_training = 0

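            # Train the restored model for FLAGS.num_epochs further epochs,
            # checkpointing periodically.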
            with sess:
                training_losses = []
                # epoch is a generator of batches which passes over the data once.
                for idx, epoch in enumerate(
                        gen_epochs(FLAGS.data_dir,
                                   FLAGS.num_epochs,
                                   FLAGS.batch_size,
                                   FLAGS.vocab_size,
                                   phase="train")):
                    # Running total for training loss reset every 500 steps.
                    training_loss = 0
                    print("\nEPOCH", idx)
                    for step, (gloss, head) in enumerate(epoch):
                        num_training += len(gloss)
                        training_loss_, _ = sess.run([total_loss, train_step],
                                                     feed_dict={
                                                         gloss_in: gloss,
                                                         head_in: head
                                                     })
                        training_loss += training_loss_
                        if step % 500 == 0 and step > 0:
                            loss_ = training_loss / 500
                            print(
                                "Average loss step %s, for last 500 steps: %s"
                                % (step, loss_))
                            training_losses.append(loss_)
                            training_loss = 0
                    # Checkpoint the model every 50 epochs.
                    if idx % 50 == 0:
                        save_path = os.path.join(
                            FLAGS.save_dir,
                            "%s_%s.ckpt" % (FLAGS.model_name, idx))
                        save_path = saver.save(sess, save_path)
                        print("Model saved in file: %s after epoch: %s" %
                              (save_path, idx))
                print("Total data points seen during training: %s" %
                      num_training)
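
Both examples are the main() routine of a larger TensorFlow 1.x module: helpers such as build_model, train_network, restore_model, query_model, gen_epochs and data_utils, and the module-level logger, are defined elsewhere in that file. The FLAGS object is assumed to be created at module level with tf.app.flags; the sketch below lists the flag definitions the code above reads. The defaults here are placeholders, not the original settings.

import tensorflow as tf

flags = tf.app.flags
# Placeholder defaults; the original module's values may differ.
flags.DEFINE_string("data_dir", "data/", "Directory with the definition data.")
flags.DEFINE_string("train_file", "train.definitions", "Training data file.")
flags.DEFINE_string("dev_file", "dev.definitions", "Development data file.")
flags.DEFINE_string("save_dir", "models/", "Directory for saving checkpoints.")
flags.DEFINE_string("vocab_file", None, "Existing vocab file; built under data_dir if None.")
flags.DEFINE_string("embeddings_path", "embeddings/glove.txt", "Pre-trained word embeddings file.")
flags.DEFINE_string("model_name", "dict_model", "Name used for checkpoint files.")
flags.DEFINE_string("encoder_type", "recurrent", "Type of gloss encoder.")
flags.DEFINE_integer("vocab_size", 100000, "Vocabulary size.")
flags.DEFINE_integer("max_seq_len", 20, "Maximum definition (gloss) length.")
flags.DEFINE_integer("embedding_size", 500, "Embedding size when no pre-trained input embeddings are used.")
flags.DEFINE_integer("num_epochs", 1000, "Number of passes over the training data.")
flags.DEFINE_integer("batch_size", 128, "Training batch size.")
flags.DEFINE_float("learning_rate", 0.001, "Learning rate.")
flags.DEFINE_boolean("restore", False, "Restore a saved model instead of training a new one.")
flags.DEFINE_boolean("restore_and_query", False, "After restoring, run evaluation/querying (Example #2 only).")
flags.DEFINE_boolean("evaluate", False, "Evaluate the restored model.")
flags.DEFINE_boolean("query", False, "Run the interactive query routine (Example #1 only).")
flags.DEFINE_boolean("pretrained_input", False, "Use pre-trained embeddings for the gloss words.")
flags.DEFINE_boolean("pretrained_target", False, "Use pre-trained embeddings for the target (head) words.")
FLAGS = flags.FLAGS

if __name__ == "__main__":
    # tf.app.run() parses the flags and calls main(unused_argv).
    tf.app.run()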