def restore_model(sess, save_dir, vocab_file, out_form):
    """Restore the latest checkpoint in `save_dir` and look up its tensors.

    The graph is rebuilt from the checkpoint's ``.meta`` file, the variable
    values are restored into `sess`, and the placeholders / output tensors
    are fetched from the default graph by name.

    Args:
        sess: session the checkpoint variables are restored into.
        save_dir: directory searched for the most recent checkpoint.
        vocab_file: path handed to `data_utils_BPE.initialize_vocabulary`.
        out_form: "softmax" selects the "predictions:0" tensor; any other
            value selects "fully_connected/Tanh:0".

    Returns:
        Tuple `(input_node_fw, input_node_bw, target_node, predictions,
        loss, vocab, rev_vocab)`.

    Raises:
        ValueError: if `save_dir` contains no checkpoint.
    """
    model_path = tf.train.latest_checkpoint(save_dir)
    # latest_checkpoint returns None when nothing is found; fail with a clear
    # message instead of a TypeError on `None + ".meta"`.
    if model_path is None:
        raise ValueError("No checkpoint found in directory: %r" % (save_dir,))

    # Rebuild the graph from the meta file, then load the variable values.
    saver = tf.train.import_meta_graph(model_path + ".meta")
    saver.restore(sess, model_path)
    graph = tf.get_default_graph()

    # Fetch input and output tensors by the names used in the saved graph.
    input_node_fw = graph.get_tensor_by_name("fw_input_placeholder:0")
    input_node_bw = graph.get_tensor_by_name("bw_input_placeholder:0")
    target_node = graph.get_tensor_by_name("labels_placeholder:0")
    if out_form == "softmax":
        predictions = graph.get_tensor_by_name("predictions:0")
    else:
        predictions = graph.get_tensor_by_name("fully_connected/Tanh:0")
    loss = graph.get_tensor_by_name("total_loss:0")

    # vocab maps words to ids, rev_vocab is the reverse mapping.
    vocab, rev_vocab = data_utils_BPE.initialize_vocabulary(vocab_file)
    return (input_node_fw, input_node_bw, target_node, predictions, loss,
            vocab, rev_vocab)
def main(unused_argv):
    """Train a new dictionary model, or restore a saved one and evaluate it.

    If ``FLAGS.restore`` is false, the data is prepared, the graph is built
    and `train_network` is run. If it is true, the latest checkpoint in
    ``FLAGS.save_dir`` is restored and, when ``FLAGS.evaluate`` is set,
    `evaluate_model` is run on it.

    Args:
        unused_argv: not used.

    Raises:
        ValueError: when restoring with ``FLAGS.evaluate`` set while neither
            ``FLAGS.pretrained_input`` nor ``FLAGS.pretrained_target`` is
            set; the evaluation compares predictions against the
            pre-trained embedding matrix, which would not exist.
    """
    if FLAGS.vocab_file is None:
        vocab_file = os.path.join(FLAGS.data_dir,
                                  "definitions_%s.vocab" % FLAGS.vocab_size)
    else:
        vocab_file = FLAGS.vocab_file

    use_pretrained = FLAGS.pretrained_input or FLAGS.pretrained_target

    # Build and train a dictionary model.
    if not FLAGS.restore:
        emb_size = FLAGS.embedding_size
        embs_dict, pre_emb_dim = None, None
        # Load any pre-trained word embeddings.
        if use_pretrained:
            # embs_dict is a dictionary from words to vectors.
            embs_dict, pre_emb_dim = load_pretrained_embeddings(
                FLAGS.embeddings_path)
            if FLAGS.pretrained_input:
                emb_size = pre_emb_dim

        # Create vocab file, process definitions (if necessary).
        data_utils_BPE.prepare_dict_data(FLAGS.data_dir,
                                         FLAGS.train_file,
                                         FLAGS.dev_file,
                                         FLAGS.test_file,
                                         vocabulary_size=FLAGS.vocab_size,
                                         max_seq_len=FLAGS.max_seq_len)
        # vocab is a dictionary from strings to integers.
        vocab, _ = data_utils_BPE.initialize_vocabulary(vocab_file)
        pre_embs = None
        if use_pretrained:
            # pre_embs is a numpy array with row vectors for words in vocab.
            # for vocab words not in embs_dict, vector is all zeros.
            pre_embs = get_embedding_matrix(embs_dict, vocab, pre_emb_dim)

        # Build the TF graph for the dictionary model.
        model = build_model(max_seq_len=FLAGS.max_seq_len,
                            vocab_size=FLAGS.vocab_size,
                            emb_size=emb_size,
                            learning_rate=FLAGS.learning_rate,
                            encoder_type=FLAGS.encoder_type,
                            pretrained_target=FLAGS.pretrained_target,
                            pretrained_input=FLAGS.pretrained_input,
                            pre_embs=pre_embs)

        # Run the training for specified number of epochs.
        train_network(model,
                      FLAGS.num_epochs,
                      FLAGS.batch_size,
                      FLAGS.data_dir,
                      FLAGS.save_dir,
                      FLAGS.vocab_size,
                      name=FLAGS.model_name)
        return

    # Load an existing model. The output form follows the embedding flags:
    # "cosine" when pre-trained embeddings are used, "softmax" otherwise.
    pre_embs = None
    if use_pretrained:
        embs_dict, pre_emb_dim = load_pretrained_embeddings(
            FLAGS.embeddings_path)
        vocab, _ = data_utils_BPE.initialize_vocabulary(vocab_file)
        pre_embs = get_embedding_matrix(embs_dict, vocab, pre_emb_dim)
        out_form = "cosine"
        print('out form is cosine')
    else:
        out_form = "softmax"

    # Previously this case crashed with a NameError on `pre_embs` only after
    # the model had been restored; report the real cause up front instead.
    if FLAGS.evaluate and pre_embs is None:
        raise ValueError(
            "Evaluation requires pre-trained embeddings: set pretrained_input "
            "or pretrained_target (and embeddings_path), or disable evaluate.")

    with tf.device("/cpu:0"):
        with tf.Session() as sess:
            (input_node_fw, input_node_bw, target_node, predictions, loss,
             vocab, rev_vocab) = restore_model(sess,
                                               FLAGS.save_dir,
                                               vocab_file,
                                               out_form=out_form)

            if FLAGS.evaluate:
                evaluate_model(sess,
                               FLAGS.data_dir,
                               input_node_fw,
                               input_node_bw,
                               target_node,
                               predictions,
                               loss,
                               embs=pre_embs,
                               out_form=out_form)
def _format_median_ranks(ranks):
    """Return the report of median ranks for each fixed slice of the test set."""
    return (
        'guardian_long median: {}\nguardian_short median: {}\nNYT_long median: {}\nNYT_short median: {}\nDefinitions median: {}'
        .format(np.median(ranks[:100]), np.median(ranks[100:200]),
                np.median(ranks[200:300]), np.median(ranks[300:400]),
                np.median(ranks[400:])))


def evaluate_model(sess,
                   data_dir,
                   input_node_fw,
                   input_node_bw,
                   target_node,
                   prediction,
                   loss,
                   embs,
                   out_form="cosine"):
    """Rank the correct head word for every test example and report medians.

    Each test example is run through the model with a batch size of 1, and
    the predicted vector is compared with every row of `embs` by cosine
    similarity. The rank of the correct word is its position in the
    similarity-sorted vocabulary (0 = most similar).

    The first 400 test examples are treated as crossword clues: only
    vocabulary words with the same length as the correct word are ranked.
    The remaining examples are ranked against the whole vocabulary.

    Test set composition assumed by the slicing below:
        guardian_long[:100], guardian_short[100:200], NYT_long[200:300],
        NYT_short[300:400], definitions[400:]

    Args:
        sess: session holding the model to evaluate.
        data_dir: directory the test examples are read from (via gen_epochs).
        input_node_fw: forward input placeholder.
        input_node_bw: backward input placeholder; only fed when
            config.LSTM_type is 'bidirectional'.
        target_node: target placeholder.
        prediction: tensor that is run to obtain the output vectors.
        loss: accepted for interface compatibility; not used.
        embs: 2-D embedding matrix, one row per vocabulary id.
        out_form: accepted for interface compatibility; not used.

    Returns:
        None.

    Side effects:
        If FLAGS.restore is true, writes the per-example ranks to two CSV
        files and prints the median ranks. Otherwise appends the median
        ranks to the module-level `outfile`.

    Raises:
        ValueError: if `embs` is None or config.LSTM_type is not one of
            'average' / 'bidirectional'.
    """
    if embs is None:
        raise ValueError("evaluate_model requires an embedding matrix; "
                         "embs is None")
    lstm_type = config.LSTM_type
    if lstm_type not in ('average', 'bidirectional'):
        # Previously this left `prediction_` unbound on the first batch.
        raise ValueError("Unsupported config.LSTM_type: %r" % (lstm_type,))

    num_epochs = 1
    batch_size = 1
    num_xword = 400  # leading test examples that are crossword clues
    check = False  # debugging switch: collect top candidates for poor ranks
    check_words = []
    print('evaluating model on dev set')

    embs = np.asarray(embs)
    # Seed the batch lists with empty arrays so dtypes and the empty-test-set
    # case match repeated np.append onto int / float arrays. The width
    # comes from the embeddings instead of a hard-coded 300.
    head_batches = [np.empty((0,), dtype=int)]
    pred_batches = [np.empty((0, embs.shape[1]), dtype=float)]

    # Read the test data using gen_epochs and collect one prediction each.
    for epoch in gen_epochs(data_dir,
                            num_epochs,
                            batch_size,
                            FLAGS.vocab_size,
                            phase="test"):
        for (gloss_fw, gloss_bw, head, _) in epoch:
            feed = {
                input_node_fw: np.array(list(gloss_fw), dtype=np.int32),
                target_node: head,
            }
            if lstm_type == 'bidirectional':
                feed[input_node_bw] = np.array(list(gloss_bw),
                                               dtype=np.int32)
            pred_batches.append(sess.run(prediction, feed_dict=feed))
            head_batches.append(head)

    correct_word = np.concatenate(head_batches, axis=0)
    predictions = np.concatenate(pred_batches, axis=0)

    # No np.squeeze here: it collapsed the matrix to 1-D when there was a
    # single test example, which broke the row-wise loops below.
    sims = 1 - dist.cdist(predictions, embs, metric="cosine")
    sims = np.nan_to_num(sims)

    # Fall back to the default vocab path the same way main() and
    # train_network() do when FLAGS.vocab_file is not set.
    vocab_file = FLAGS.vocab_file
    if vocab_file is None:
        vocab_file = os.path.join(FLAGS.data_dir,
                                  "definitions_%s.vocab" % FLAGS.vocab_size)
    vocab, rev_vocab = data_utils_BPE.initialize_vocabulary(vocab_file)

    # Group vocabulary ids by word length once, instead of rescanning the
    # whole vocabulary for every crossword example.
    ids_by_length = {}
    for word in vocab:
        ids_by_length.setdefault(len(word), []).append(vocab[word])

    ranks = []

    # Crossword clues: rank only among words of the correct length.
    for idx, pred_array in enumerate(sims[:num_xword]):
        ranked_wids = pred_array.argsort()[::-1]
        target_len = len(rev_vocab[correct_word[idx]])
        same_length = ranked_wids[np.isin(ranked_wids,
                                          ids_by_length.get(target_len, []))]
        # NOTE(review): if the correct id is absent from the filtered list
        # nothing is appended and later ranks shift by one (as before).
        ranks.extend(np.where(same_length == correct_word[idx])[0].tolist())

    # Definitions (non-crossword): rank against the whole vocabulary.
    for idx, pred_array in enumerate(sims[num_xword:], start=num_xword):
        ranked_wids = pred_array.argsort()[::-1]
        rank = np.where(ranked_wids == correct_word[idx])[0]
        ranks.extend(rank.tolist())

        # Check underperforming words to see their top candidates. `idx` is
        # the global example index here, so correct_word[idx] is the right
        # head word (the old code indexed with the slice-local position).
        if check and idx - num_xword > 250 and rank.size and rank[0] > 90000:
            check_words.append('HEAD WORD: {}'.format(
                rev_vocab[correct_word[idx]]))
            for candidate in ranked_wids[:10]:
                check_words.append(rev_vocab[candidate])

    if check:
        print(check_words)

    ranks = np.array(ranks, dtype=int)
    report = _format_median_ranks(ranks)

    if FLAGS.restore:
        guardian_long = pd.DataFrame({
            'head WID': correct_word[:100],
            'rank': ranks[:100]
        })
        guardian_short = pd.DataFrame({
            'head WID': correct_word[100:200],
            'rank': ranks[100:200]
        })
        NYT_long = pd.DataFrame({
            'head WID': correct_word[200:300],
            'rank': ranks[200:300]
        })
        NYT_short = pd.DataFrame({
            'head WID': correct_word[300:400],
            'rank': ranks[300:400]
        })
        definitions_frame = pd.DataFrame({
            'head WID': correct_word[400:],
            'rank': ranks[400:]
        })
        xword_frame = pd.concat(
            [guardian_long, guardian_short, NYT_long, NYT_short],
            axis=0,
            join='inner')
        # NOTE(review): the file name is built from the third '/'-separated
        # component of FLAGS.save_dir with no separators added; this needs
        # save_dir to have at least three components - confirm intended.
        xword_frame.to_csv('final_csvs' + FLAGS.save_dir.split('/')[2] +
                           'x_word.csv',
                           sep=',')
        definitions_frame.to_csv('final_csvs' + FLAGS.save_dir.split('/')[2] +
                                 'definitions.csv',
                                 sep=',')
        print(report)
    else:
        # Evaluating during training: append the medians to the log file.
        with open(outfile, 'a') as f:
            print(report, file=f)
def train_network(model,
                  num_epochs,
                  batch_size,
                  data_dir,
                  save_dir,
                  vocab_size,
                  name="model",
                  verbose=True):
    """Train the model, checkpointing (and optionally evaluating) each epoch.

    After every epoch the model is saved to `save_dir`, `evaluate_model` is
    run when ``FLAGS.evaluate`` is set, and the previous epoch's checkpoint
    files are deleted to limit disk usage. Progress is appended to the
    module-level `outfile`.

    Args:
        model: 7-tuple unpacked as (gloss_in_fw, gloss_in_bw, head_in,
            incorrect_in, total_loss, train_step, _).
        num_epochs: number of epochs, forwarded to gen_epochs.
        batch_size: batch size, forwarded to gen_epochs.
        data_dir: directory the training data is read from (via gen_epochs).
        save_dir: directory checkpoints are written to.
        vocab_size: vocabulary size, forwarded to gen_epochs.
        name: checkpoint file name prefix.
        verbose: when true, epoch headers and running average losses are
            appended to `outfile`.

    Returns:
        Tuple `(save_dir, saver)`; note this is the checkpoint directory,
        not the path of the last checkpoint.

    Raises:
        ValueError: if config.LSTM_type is not 'average' or 'bidirectional',
            or if ``FLAGS.evaluate`` is set without pre-trained embeddings
            (the per-epoch evaluation needs the embedding matrix).
    """
    lstm_type = config.LSTM_type
    if lstm_type not in ('average', 'bidirectional'):
        # Previously this left `training_loss_` unbound on the first step.
        raise ValueError("Unsupported config.LSTM_type: %r" % (lstm_type,))
    bidirectional = lstm_type == 'bidirectional'

    # Load everything the per-epoch evaluation needs once, up front. Before,
    # the embeddings were reloaded from disk every epoch, and without
    # pre-trained embeddings the run died with a NameError on `pre_embs`
    # only after the first epoch had been trained.
    pre_embs = None
    if FLAGS.evaluate:
        if not (FLAGS.pretrained_input or FLAGS.pretrained_target):
            raise ValueError(
                "Evaluation during training requires pre-trained embeddings: "
                "set pretrained_input or pretrained_target, or disable "
                "evaluate.")
        if FLAGS.vocab_file is None:
            vocab_file = os.path.join(
                FLAGS.data_dir, "definitions_%s.vocab" % FLAGS.vocab_size)
        else:
            vocab_file = FLAGS.vocab_file
        embs_dict, pre_emb_dim = load_pretrained_embeddings(
            FLAGS.embeddings_path)
        vocab, _ = data_utils_BPE.initialize_vocabulary(vocab_file)
        pre_embs = get_embedding_matrix(embs_dict, vocab, pre_emb_dim)

    # Running count of the number of training instances.
    num_training = 0
    # saver object for saving the model after each epoch.
    saver = tf.train.Saver()
    with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
        (gloss_in_fw, gloss_in_bw, head_in, incorrect_in, total_loss,
         train_step, _) = model
        # Initialize the model parameters.
        sess.run(tf.global_variables_initializer())

        if FLAGS.evaluate:
            # Tensors used by the evaluation routine, looked up by name once.
            graph = tf.get_default_graph()
            input_node_fw = graph.get_tensor_by_name("fw_input_placeholder:0")
            input_node_bw = graph.get_tensor_by_name("bw_input_placeholder:0")
            target_node = graph.get_tensor_by_name("labels_placeholder:0")
            # Evaluation during training always uses the cosine output form.
            predictions = graph.get_tensor_by_name("fully_connected/Tanh:0")
            loss = graph.get_tensor_by_name("total_loss:0")  # unused by eval

        # Record all training losses for potential reporting.
        training_losses = []
        # epoch is a generator of batches which passes over the data once.
        for idx, epoch in enumerate(
                gen_epochs(data_dir,
                           num_epochs,
                           batch_size,
                           vocab_size,
                           phase="train")):
            # Running total for training loss, reset every 500 steps.
            training_loss = 0
            if verbose:
                with open(outfile, 'a') as f:
                    print("\nEPOCH", idx, file=f)

            for step, (gloss_fw, gloss_bw, head,
                       incorrect) in enumerate(epoch):
                # Glosses arrive as a list; convert them to an int32 array.
                feed = {
                    gloss_in_fw: np.array(list(gloss_fw), dtype=np.int32),
                    head_in: head,
                    incorrect_in: incorrect,
                }
                if bidirectional:
                    feed[gloss_in_bw] = np.array(list(gloss_bw),
                                                 dtype=np.int32)
                num_training += len(gloss_fw)
                training_loss_, _ = sess.run([total_loss, train_step],
                                             feed_dict=feed)
                training_loss += training_loss_
                # NOTE(review): the first window of each epoch covers steps
                # 0..500 (501 losses) but is still divided by 500.
                if step % 500 == 0 and step > 0:
                    average_loss = training_loss / 500
                    if verbose:
                        with open(outfile, 'a') as f:
                            print(
                                "Average loss step %s, for last 500 steps: %s"
                                % (step, average_loss),
                                file=f)
                    training_losses.append(average_loss)
                    training_loss = 0

            # Save current model after another epoch.
            save_path = os.path.join(save_dir, "%s_%s.ckpt" % (name, idx))
            save_path = saver.save(sess, save_path)
            print("Model saved in file: %s after epoch: %s" % (save_path, idx))

            # JP: run the evaluation routine (for each epoch)
            if FLAGS.evaluate:
                with open(outfile, 'a') as f:
                    print('evaluating while training', file=f)
                with tf.device("/cpu:0"):
                    evaluate_model(sess,
                                   FLAGS.data_dir,
                                   input_node_fw,
                                   input_node_bw,
                                   target_node,
                                   predictions,
                                   loss,
                                   embs=pre_embs,
                                   out_form="cosine")

            # Remove the previous epoch's checkpoint files to minimize HDD
            # usage; only the single-shard data file name is handled.
            if idx > 0:
                old_prefix = "%s_%s.ckpt" % (name, idx - 1)
                for suffix in (".data-00000-of-00001", ".index", ".meta"):
                    os.remove(os.path.join(save_dir, old_prefix + suffix))
                with open(outfile, 'a') as f:
                    print('deleting old files ', old_prefix, file=f)

        print("Total data points seen during training: %s" % num_training)
        return save_dir, saver