Example #1
def predict(task,
            encoding_scheme,
            embedding_type,
            tf_session,
            batch_size,
            sentence_file,
            mention_idx_file,
            feature_file,
            feature_meta_file,
            scores_file=None,
            log=None):
    global CLASSES_CARD, CLASSES_VISUAL

    classes = None
    if task == 'nonvis':
        classes = CLASSES_VISUAL
    elif task == 'card':
        classes = CLASSES_CARD
    n_classes = len(classes)

    # Load the data
    log.info("Loading data from " + sentence_file + " and " + mention_idx_file)
    data_dict = nn_data.load_sentences(sentence_file, embedding_type)
    data_dict.update(
        nn_data.load_mentions(mention_idx_file, task, feature_file,
                              feature_meta_file, n_classes))

    # Get the predicted scores, given our arguments
    mentions = data_dict['mention_indices'].keys()
    pred_scores, gold_label_dict = \
        nn_util.get_pred_scores_mcc(task, encoding_scheme, tf_session,
                                    batch_size, mentions, data_dict,
                                    n_classes, log)

    # If we do an argmax on the scores, we get the predicted labels
    pred_labels = list()
    gold_labels = list()
    for m in mentions:
        pred_labels.append(np.argmax(pred_scores[m]))
        gold_labels.append(np.argmax(gold_label_dict[m]))
    #endfor

    # Evaluate the predictions
    nn_eval.evaluate_multiclass(gold_labels, pred_labels, classes, log)

    # If a scores file was specified, write the scores
    if scores_file is not None:
        log.info("Writing scores file to " + scores_file)
        with open(scores_file, 'w') as f:
            for pair_id in pred_scores.keys():
                score_line = list()
                score_line.append(pair_id)
                for score in pred_scores[pair_id]:
                    if score == 0:
                        score = np.nextafter(0, 1)
                    score_line.append(str(np.log(score)))
                f.write(",".join(score_line) + "\n")
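The CSV-writing block above nudges zero probabilities to the smallest positive float with np.nextafter so that np.log never returns -inf; the same block recurs verbatim in the wrappers below. A small shared helper along the lines of this sketch could replace it (the helper name is an assumption, not part of the original module):

import numpy as np


def write_log_score_file(scores_file, pred_scores):
    """Hypothetical helper: write one CSV line per ID, namely the ID followed
    by the log of each predicted score, clamping zeros to the smallest
    positive float first."""
    with open(scores_file, 'w') as f:
        for pair_id, scores in pred_scores.items():
            log_scores = [str(np.log(s if s > 0 else np.nextafter(0, 1)))
                          for s in scores]
            f.write(",".join([pair_id] + log_scores) + "\n")

Example #2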
def predict(encoding_scheme,
            embedding_type,
            tf_session,
            batch_size,
            sentence_file,
            mention_idx_file,
            feature_file,
            feature_meta_file,
            box_dir,
            mention_box_label_file,
            box_category_file=None,
            scores_file=None,
            log=None):
    """

    :param encoding_scheme:
    :param embedding_type:
    :param tf_session:
    :param batch_size:
    :param sentence_file:
    :param mention_idx_file:
    :param feature_file:
    :param feature_meta_file:
    :param box_dir:
    :param mention_box_label_file:
    :param box_category_file:
    :param scores_file:
    :param log:
    :return:
    """
    global CLASSES, task
    n_classes = len(CLASSES)

    # Load the data
    log.info("Loading data")
    data_dict = nn_data.load_sentences(sentence_file, embedding_type)
    data_dict.update(
        nn_data.load_mentions(mention_idx_file, task, feature_file,
                              feature_meta_file, n_classes))
    data_dict.update(
        nn_data.load_boxes(mention_box_label_file, box_dir, box_category_file))

    # Get the predicted scores, given our arguments
    mention_box_pairs = get_valid_mention_box_pairs(data_dict)
    pred_scores, gold_label_dict = \
        nn_util.get_pred_scores_mcc(task, encoding_scheme, tf_session,
                                    batch_size, mention_box_pairs,
                                    data_dict, n_classes, log)

    # If we do an argmax on the scores, we get the predicted labels
    mentions = list(pred_scores.keys())
    pred_labels = list()
    gold_labels = list()
    for m in mentions:
        pred_labels.append(np.argmax(pred_scores[m]))
        gold_labels.append(np.argmax(gold_label_dict[m]))
    #endfor

    # Evaluate the predictions
    nn_eval.evaluate_multiclass(gold_labels, pred_labels, CLASSES, log)

    # If a scores file was specified, write the scores
    log.info("Writing scores file")
    if scores_file is not None:
        with open(scores_file, 'w') as f:
            for pair_id in pred_scores.keys():
                score_line = list()
                score_line.append(pair_id)
                for score in pred_scores[pair_id]:
                    if score == 0:
                        score = np.nextafter(0, 1)
                    score_line.append(str(np.log(score)))
                f.write(",".join(score_line) + "\n")
Example #3
def train(task,
          encoding_scheme,
          embedding_type,
          sentence_file,
          mention_idx_file,
          feature_file,
          feature_meta_file,
          epochs,
          batch_size,
          lstm_hidden_width,
          start_hidden_width,
          hidden_depth,
          weighted_classes,
          lstm_input_dropout,
          dropout,
          lrn_rate,
          adam_epsilon,
          clip_norm,
          data_norm,
          activation,
          model_file=None,
          eval_sentence_file=None,
          eval_mention_idx_file=None,
          eval_feature_file=None,
          eval_feature_meta_file=None,
          early_stopping=False,
          log=None):
    """
    Trains a nonvis or cardinality model

    :param task: {nonvis, card}
    :param encoding_scheme: {first_last_sentence, first_last_mention}
    :param embedding_type: {w2v, glove}
    :param sentence_file: File with captions
    :param mention_idx_file: File with mention pair word indices
    :param feature_file: File with sparse mention pair features
    :param feature_meta_file: File associating sparse indices with feature names
    :param epochs: Number of times to run over the data
    :param batch_size: Number of mention pairs to run each batch
    :param lstm_hidden_width: Number of hidden units in the lstm cells
    :param start_hidden_width: Number of hidden units to which the mention pairs'
                               representation is passed
    :param hidden_depth: Number of hidden layers after the lstm
    :param weighted_classes: Whether to weight the examples by their
                             class inversely with the frequency of
                             that class
    :param lstm_input_dropout: Probability to keep for lstm inputs
    :param dropout: Probability to keep for all other nodes
    :param lrn_rate: Learning rate of the optimizer
    :param clip_norm: Global gradient clipping norm
    :param adam_epsilon: Adam optimizer epsilon value
    :param activation: Nonlinear activation function (sigmoid,tanh,relu)
    :param model_file: File to which the model is periodically saved
    :param eval_sentence_file: Sentence file against which the model
                               should be evaluated
    :param eval_mention_idx_file: Mention index file against which
                                  the model should be evaluated
    :return:
    """
    global CLASSES_CARD, CLASSES_VISUAL

    # Retrieve the correct set of classes
    classes = None
    if task == 'nonvis':
        classes = CLASSES_VISUAL
    elif task == 'card':
        classes = CLASSES_CARD
    n_classes = len(classes)

    log.info("Loading data from " + sentence_file + " and " + mention_idx_file)
    data_dict = nn_data.load_sentences(sentence_file, embedding_type)
    data_dict.update(
        nn_data.load_mentions(mention_idx_file, task, feature_file,
                              feature_meta_file, n_classes))

    log.info("Loading data from " + eval_sentence_file + " and " +
             eval_mention_idx_file)
    eval_data_dict = nn_data.load_sentences(eval_sentence_file, embedding_type)
    eval_data_dict.update(
        nn_data.load_mentions(eval_mention_idx_file, task, eval_feature_file,
                              eval_feature_meta_file, n_classes))

    mentions = list(data_dict['mention_indices'].keys())
    n_pairs = len(mentions)

    log.info("Setting up network architecture")
    with tf.variable_scope('bidirectional_lstm'):
        nn_util.setup_bidirectional_lstm(lstm_hidden_width, data_norm)
    nn_util.setup_core_architecture(task, encoding_scheme, batch_size,
                                    start_hidden_width, hidden_depth,
                                    weighted_classes, activation, n_classes,
                                    data_dict['n_mention_feats'])
    loss = tf.get_collection('loss')[0]
    accuracy = tf.get_collection('accuracy')[0]
    nn_util.add_train_op(loss, lrn_rate, adam_epsilon, clip_norm)
    train_op = tf.get_collection('train_op')[0]
    nn_util.dump_tf_vars()

    # We want to keep track of the best scores with
    # the epoch that they originated from
    best_avg_score = -1
    best_epoch = -1

    log.info("Training")
    saver = tf.train.Saver(max_to_keep=100)
    with tf.Session() as sess:
        # Initialize all our variables
        sess.run(tf.global_variables_initializer())

        # Iterate through the data [epochs] number of times
        for i in range(0, epochs):
            log.info(None, "--- Epoch %d ----", i + 1)
            losses = list()
            accuracies = list()

            # Shuffle the data once for this epoch
            np.random.shuffle(mentions)

            # Iterate through the entirety of the data
            start_idx = 0
            end_idx = start_idx + batch_size
            n_iter = n_pairs // batch_size
            for j in range(0, n_iter):
                log.log_status('info', None,
                               'Training; %d (%.2f%%) batches complete', j,
                               100.0 * j / n_iter)

                # Retrieve this batch
                batch_mentions = mentions[start_idx:end_idx]
                batch_tensors = nn_data.load_batch(batch_mentions, data_dict,
                                                   task, n_classes)

                # Train
                nn_util.run_op(sess, train_op, [batch_tensors],
                               lstm_input_dropout, dropout, encoding_scheme,
                               [task], [""], True)

                # Store the losses and accuracies every 100 batches
                if (j + 1) % 100 == 0:
                    losses.append(
                        nn_util.run_op(sess, loss, [batch_tensors],
                                       lstm_input_dropout, dropout,
                                       encoding_scheme, [task], [""], True))
                    accuracies.append(
                        nn_util.run_op(sess, accuracy, [batch_tensors],
                                       lstm_input_dropout, dropout,
                                       encoding_scheme, [task], [""], True))
                #endif
                start_idx = end_idx
                end_idx = start_idx + batch_size
            #endfor

            # Every epoch, evaluate and save the model
            log.info(None, "Saving model; Average Loss: %.2f; Acc: %.2f%%",
                     sum(losses) / float(len(losses)),
                     100.0 * sum(accuracies) / float(len(accuracies)))
            saver.save(sess, model_file)
            if (i + 1) % 10 == 0 and eval_sentence_file is not None \
                    and eval_mention_idx_file is not None:
                eval_mentions = eval_data_dict['mention_indices'].keys()
                pred_scores, gold_label_dict = \
                    nn_util.get_pred_scores_mcc(task, encoding_scheme,
                                                sess, batch_size, eval_mentions,
                                                eval_data_dict, n_classes, log)

                # If we do an argmax on the scores, we get the predicted labels
                eval_mentions = list(pred_scores.keys())
                pred_labels = list()
                gold_labels = list()
                for m in eval_mentions:
                    pred_labels.append(np.argmax(pred_scores[m]))
                    gold_labels.append(np.argmax(gold_label_dict[m]))
                #endfor

                # Evaluate the predictions
                score_dict = nn_eval.evaluate_multiclass(
                    gold_labels, pred_labels, classes, log)

                # Get the current scores; if their average is within half a
                # point of our previous best, treat this epoch as the new best
                avg = score_dict.get_score(0).f1 + score_dict.get_score(1).f1
                avg /= 2.0
                if avg >= best_avg_score - 0.005:
                    log.info(
                        None,
                        "Previous best score average F1 of %.2f%% after %d epochs",
                        100.0 * best_avg_score, best_epoch)
                    best_avg_score = avg
                    best_epoch = i
                    log.info(None, "New best at current epoch (%.2f%%)",
                             100.0 * best_avg_score)
                #endif

                # Implement early stopping; if it's been 10 epochs since our best, stop
                if early_stopping and i >= (best_epoch + 10):
                    log.info(None, "Stopping early; best scores at %d epochs",
                             best_epoch)
                    break
                #endif
            #endif
        #endfor
        log.info("Saving final model")
        saver.save(sess, model_file)
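A hypothetical invocation of the trainer above, mainly to make the long argument list concrete. Every path, the hyperparameter values, and the log object are illustrative assumptions, not the original configuration:

train(task="nonvis",
      encoding_scheme="first_last_sentence",
      embedding_type="w2v",
      sentence_file="data/train_captions.txt",
      mention_idx_file="data/train_mention_idx.txt",
      feature_file="data/train_mention_feats.feats",
      feature_meta_file="data/train_mention_feats_meta.json",
      epochs=100,
      batch_size=512,
      lstm_hidden_width=200,
      start_hidden_width=150,
      hidden_depth=1,
      weighted_classes=False,
      lstm_input_dropout=0.5,
      dropout=0.5,
      lrn_rate=0.001,
      adam_epsilon=1e-08,
      clip_norm=5.0,
      data_norm=True,
      activation="relu",
      model_file="models/nonvis.ckpt",
      eval_sentence_file="data/dev_captions.txt",
      eval_mention_idx_file="data/dev_mention_idx.txt",
      eval_feature_file="data/dev_mention_feats.feats",
      eval_feature_meta_file="data/dev_mention_feats_meta.json",
      early_stopping=True,
      log=log)  # log: the project's logger utility instance (assumed)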
def predict(rel_type,
            encoding_scheme,
            embedding_type,
            tf_session,
            batch_size,
            sentence_file,
            mention_idx_file,
            feature_file,
            feature_meta_file,
            label_file,
            scores_file=None,
            ordered_pairs=False,
            log=None):
    """
    Wrapper for making predictions on a pre-trained model, already loaded into
    the session
    :param rel_type:
    :param encoding_scheme:
    :param embedding_type:
    :param tf_session:
    :param batch_size:
    :param sentence_file:
    :param mention_idx_file:
    :param feature_file:
    :param feature_meta_file:
    :param label_file:
    :param scores_file:
    :return:
    """
    global CLASSES
    n_classes = len(CLASSES)
    task = "rel_" + rel_type

    # Load the data
    log.info("Loading data from " + sentence_file + " and " + mention_idx_file)
    data_dict = nn_data.load_sentences(sentence_file, embedding_type)
    data_dict.update(
        nn_data.load_mentions(mention_idx_file, task, feature_file,
                              feature_meta_file, n_classes))

    # Get the predicted scores, given our arguments
    log.info("Predictiong scores")
    mention_pairs = data_dict['mention_indices'].keys()
    pred_scores, _ = nn_util.get_pred_scores_mcc(task, encoding_scheme,
                                                 tf_session, batch_size,
                                                 mention_pairs, data_dict,
                                                 n_classes, log)
    if ordered_pairs:
        pred_scores = induce_ji_predictions(pred_scores)

    log.info("Loading data from " + label_file)
    gold_label_dict = nn_data.load_relation_labels(label_file)

    # If we do an argmax on the scores, we get the predicted labels
    log.info("Getting labels from scores")
    pred_labels = list()
    for pair in pred_scores.keys():
        pred_labels.append(np.argmax(pred_scores[pair]))

    # Evaluate the predictions
    log.info("Evaluating against the gold")
    nn_eval.evaluate_relations(pred_scores.keys(), pred_labels,
                               gold_label_dict, log)

    # If a scores file was specified, write the scores
    log.info("Writing scores file " + scores_file)
    if scores_file is not None:
        with open(scores_file, 'w') as f:
            for pair_id in pred_scores.keys():
                score_line = list()
                score_line.append(pair_id)
                for score in pred_scores[pair_id]:
                    if score == 0:
                        score = np.nextafter(0, 1)
                    score_line.append(str(np.log(score)))
                f.write(",".join(score_line) + "\n")
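A minimal sketch of how this wrapper might be driven: restore the graph and weights saved by the relation trainer below, then hand the live session to predict(). The rel_type value, checkpoint, and data paths are illustrative assumptions; tf.train.import_meta_graph is just one common way to restore a TF 1.x model:

import tensorflow as tf

with tf.Session() as sess:
    # Rebuild the saved graph and restore its weights (hypothetical paths)
    saver = tf.train.import_meta_graph("models/rel_intra.ckpt.meta")
    saver.restore(sess, "models/rel_intra.ckpt")

    predict(rel_type="intra",
            encoding_scheme="first_last_sentence",
            embedding_type="w2v",
            tf_session=sess,
            batch_size=512,
            sentence_file="data/test_captions.txt",
            mention_idx_file="data/test_mention_pair_idx.txt",
            feature_file="data/test_mention_pair_feats.feats",
            feature_meta_file="data/test_mention_pair_feats_meta.json",
            label_file="data/test_relation_labels.txt",
            scores_file="out/rel_intra_scores.csv",
            log=log)  # log: the project's logger utility instance (assumed)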
def train(rel_type,
          encoding_scheme,
          embedding_type,
          sentence_file,
          mention_idx_file,
          feature_file,
          feature_meta_file,
          epochs,
          batch_size,
          lstm_hidden_width,
          start_hidden_width,
          hidden_depth,
          weighted_classes,
          lstm_input_dropout,
          dropout,
          lrn_rate,
          adam_epsilon,
          clip_norm,
          data_norm,
          activation,
          model_file=None,
          eval_sentence_file=None,
          eval_mention_idx_file=None,
          eval_feature_file=None,
          eval_feature_meta_file=None,
          eval_label_file=None,
          early_stopping=False,
          ordered_pairs=False,
          log=None):
    """
    Trains a relation model

    :param rel_type: Relation type; "rel_" + rel_type is used as the task name
    :param encoding_scheme: {first_last_sentence, first_last_mention}
    :param embedding_type: {w2v, glove}
    :param sentence_file: File with captions
    :param mention_idx_file: File with mention pair word indices
    :param feature_file: File with sparse mention pair features
    :param feature_meta_file: File associating sparse indices with feature names
    :param epochs: Number of times to run over the data
    :param batch_size: Number of mention pairs to run each batch
    :param lstm_hidden_width: Number of hidden units in the lstm cells
    :param start_hidden_width: Number of hidden units to which the mention pairs'
                               representation is passed
    :param hidden_depth: Number of hidden layers after the lstm
    :param weighted_classes: Whether to weight the examples by their
                             class inversely with the frequency of
                             that class
    :param lstm_input_dropout: Probability to keep for lstm inputs
    :param dropout: Probability to keep for all other nodes
    :param lrn_rate: Learning rate of the optimizer
    :param clip_norm: Global gradient clipping norm
    :param adam_epsilon: Adam optimizer epsilon value
    :param activation: Nonlinear activation function (sigmoid,tanh,relu)
    :param model_file: File to which the model is periodically saved
    :param eval_sentence_file: Sentence file against which the model
                               should be evaluated
    :param eval_mention_idx_file: Mention index file against which
                                  the model should be evaluated
    :param eval_label_file: Relation label file for eval data
    :return:
    """
    global CLASSES

    task = "rel_" + rel_type
    n_classes = len(CLASSES)

    log.info("Loading data from " + sentence_file + " and " + mention_idx_file)
    data_dict = nn_data.load_sentences(sentence_file, embedding_type)
    data_dict.update(
        nn_data.load_mentions(mention_idx_file, task, feature_file,
                              feature_meta_file, n_classes))
    log.info("Loading data from " + eval_sentence_file + " and " +
             eval_mention_idx_file)
    eval_data_dict = nn_data.load_sentences(eval_sentence_file, embedding_type)
    eval_data_dict.update(
        nn_data.load_mentions(eval_mention_idx_file, task, eval_feature_file,
                              eval_feature_meta_file, n_classes))
    mentions = list(data_dict['mention_indices'].keys())
    n_pairs = len(mentions)

    # Load the gold labels from the label file once so we can reuse them every epoch
    gold_label_dict = nn_data.load_relation_labels(eval_label_file)

    # We want to keep track of the best coref and subset scores, along
    # with the epoch that they originated from
    best_coref_subset_avg = -1
    best_coref_subset_epoch = -1

    log.info("Setting up network architecture")

    # Set up the bidirectional LSTM
    with tf.variable_scope('bidirectional_lstm'):
        nn_util.setup_bidirectional_lstm(lstm_hidden_width, data_norm)
    nn_util.setup_core_architecture(task, encoding_scheme, batch_size,
                                    start_hidden_width, hidden_depth,
                                    weighted_classes, activation, n_classes,
                                    data_dict['n_mention_feats'])
    loss = tf.get_collection('loss')[0]
    accuracy = tf.get_collection('accuracy')[0]
    nn_util.add_train_op(loss, lrn_rate, adam_epsilon, clip_norm)
    train_op = tf.get_collection('train_op')[0]
    nn_util.dump_tf_vars()

    log.info("Training")
    saver = tf.train.Saver(max_to_keep=100)
    with tf.Session() as sess:
        # Initialize all our variables
        sess.run(tf.global_variables_initializer())

        # Iterate through the data [epochs] number of times
        for i in range(0, epochs):
            log.info(None, "--- Epoch %d ----", i + 1)
            losses = list()
            accuracies = list()

            # Shuffle the data once for this epoch
            np.random.shuffle(mentions)

            # Iterate through the entirety of the data
            start_idx = 0
            end_idx = start_idx + batch_size
            n_iter = n_pairs // batch_size
            for j in range(0, n_iter):
                log.log_status('info', None,
                               'Training; %d (%.2f%%) batches complete', j,
                               100.0 * j / n_iter)

                # Retrieve this batch
                batch_mentions = mentions[start_idx:end_idx]
                batch_tensors = nn_data.load_batch(batch_mentions, data_dict,
                                                   task, n_classes)

                # Train
                nn_util.run_op(sess, train_op, [batch_tensors],
                               lstm_input_dropout, dropout, encoding_scheme,
                               [task], [""], True)

                # Store the losses and accuracies every 100 batches
                if (j + 1) % 100 == 0:
                    losses.append(
                        nn_util.run_op(sess, loss, [batch_tensors],
                                       lstm_input_dropout, dropout,
                                       encoding_scheme, [task], [""], True))
                    accuracies.append(
                        nn_util.run_op(sess, accuracy, [batch_tensors],
                                       lstm_input_dropout, dropout,
                                       encoding_scheme, [task], [""], True))
                #endif
                start_idx = end_idx
                end_idx = start_idx + batch_size
            #endfor

            # Every epoch, evaluate and save the model
            log.info(None, "Saving model; Average Loss: %.2f; Acc: %.2f%%",
                     sum(losses) / float(len(losses)),
                     100.0 * sum(accuracies) / float(len(accuracies)))
            saver.save(sess, model_file)
            if (i + 1) % 10 == 0 and eval_sentence_file is not None \
                    and eval_mention_idx_file is not None:
                # We want to predict over all mentions unless this is our weird
                # ij intra caption case, in which case we want predictions
                # only for the ij pairs
                eval_mention_pairs = eval_data_dict['mention_indices'].keys()
                if ordered_pairs:
                    eval_mention_pairs = get_ij_pairs(eval_mention_pairs)

                # Predict scores
                pred_scores, _ = nn_util.get_pred_scores_mcc(
                    task, encoding_scheme, sess, batch_size,
                    eval_mention_pairs, eval_data_dict, n_classes, log)

                # If this is our ij intra case, we need to induce scores
                # for ji pairs and reset what we consider as the complete set
                # of mention pairs
                if ordered_pairs:
                    pred_scores = induce_ji_predictions(pred_scores)
                    eval_mention_pairs = eval_data_dict[
                        'mention_indices'].keys()

                pred_labels = list()
                for pair in eval_mention_pairs:
                    pred_labels.append(np.argmax(pred_scores[pair]))

                # Evaluate the predictions
                score_dict = \
                    nn_eval.evaluate_relations(eval_mention_pairs, pred_labels,
                                               gold_label_dict)

                # Get the current coref / subset scores; if their average is
                # within half a point of our previous best, treat this epoch as the new best
                coref_subset_avg = score_dict.get_score('coref').f1 + \
                                   score_dict.get_score('subset').f1
                coref_subset_avg /= 2.0
                if coref_subset_avg >= best_coref_subset_avg - 0.005:
                    log.info(
                        None,
                        "Previous best coref/subset average F1 of %.2f%% after %d epochs",
                        100.0 * best_coref_subset_avg, best_coref_subset_epoch)
                    best_coref_subset_avg = coref_subset_avg
                    best_coref_subset_epoch = i
                    log.info(None, "New best at current epoch (%.2f%%)",
                             100.0 * best_coref_subset_avg)
                #endif

                # Implement early stopping; if it's been 10 epochs since our best, stop
                if early_stopping and i >= (best_coref_subset_epoch + 10):
                    log.info(None, "Stopping early; best scores at %d epochs",
                             best_coref_subset_epoch)
                    break
                #endif
            #endif
        #endfor
        log.info("Saving final model")
        saver.save(sess, model_file)
def predict(sess,
            multitask_scheme,
            encoding_scheme,
            task_data_dicts,
            task_ids,
            batch_size=None,
            task_batch_sizes=None,
            log=None):
    """

    :param sess:
    :param multitask_scheme:
    :param encoding_scheme:
    :param task_data_dicts:
    :param task_ids:
    :param batch_size:
    :param task_batch_sizes:
    :param log:
    :return:
    """
    global TASKS, TASK_CLASS_DICT

    for task in TASKS:
        eval_ids = task_ids[task]
        task_batch_size = batch_size
        if multitask_scheme == 'alternate':
            task_batch_size = task_batch_sizes[task]
        with tf.variable_scope(task):
            pred_scores, gold_label_dict = \
                nn_util.get_pred_scores_mcc(task, encoding_scheme, sess,
                                            task_batch_size, eval_ids,
                                            task_data_dicts[task],
                                            len(TASK_CLASS_DICT[task]), log)

            # If we do an argmax on the scores, we get the predicted labels
            mentions = list(pred_scores.keys())
            pred_labels = list()
            gold_labels = list()
            for m in mentions:
                pred_labels.append(np.argmax(pred_scores[m]))
                if 'rel' not in task:
                    gold_labels.append(np.argmax(gold_label_dict[m]))
            #endfor

            # Evaluate the predictions
            if 'rel' in task:
                nn_eval.evaluate_relations(
                    eval_ids, pred_labels,
                    task_data_dicts[task]['gold_label_dict'], log)
            else:
                nn_eval.evaluate_multiclass(gold_labels, pred_labels,
                                            TASK_CLASS_DICT[task], log)
            #endif

            # Write the scores file for this task
            log.info('Writing scores file for ' + task)
            with open(task_data_dicts[task]['scores_file'], 'w') as f:
                for id in pred_scores.keys():
                    score_line = list()
                    score_line.append(id)
                    for score in pred_scores[id]:
                        if score == 0:
                            score = np.nextafter(0, 1)
                        score_line.append(str(np.log(score)))
                    f.write(",".join(score_line) + "\n")
def train_alternately(epochs,
                      task_batch_sizes,
                      lstm_input_dropout,
                      dropout,
                      lrn_rate,
                      adam_epsilon,
                      clip_norm,
                      encoding_scheme,
                      task_vars,
                      task_data_dicts,
                      eval_task_data_dicts,
                      task_ids,
                      eval_task_ids,
                      model_file,
                      log=None):
    """

    :param epochs:
    :param task_batch_sizes:
    :param lstm_input_dropout:
    :param dropout:
    :param lrn_rate:
    :param adam_epsilon:
    :param clip_norm:
    :param encoding_scheme:
    :param task_vars:
    :param task_data_dicts:
    :param eval_task_data_dicts:
    :param task_ids:
    :param eval_task_ids:
    :param model_file:
    :param log:
    :return:
    """
    global TASKS

    # We're going to retrieve and optimize each loss individually
    train_ops = dict()
    for task in TASKS:
        loss = task_vars[task]['loss']
        with tf.variable_scope(task):
            nn_util.add_train_op(loss, lrn_rate, adam_epsilon, clip_norm)
        train_ops[task] = tf.get_collection(task + '/train_op')[0]
    #endfor

    # TODO: Implement early stopping under this framework
    log.info("Training")
    saver = tf.train.Saver(max_to_keep=100)
    with tf.Session() as sess:
        # Initialize all our variables
        sess.run(tf.global_variables_initializer())

        # Iterate through the data [epochs] number of times
        for i in range(0, epochs):
            log.info(None, "--- Epoch %d ----", i + 1)

            # Create a list of (<task>, <batch_id_arr>)
            # tuples for this epoch
            batch_ids = list()
            for task in TASKS:
                ids = task_ids[task]
                if task == 'affinity':
                    ids = shuffle_mention_box_pairs(ids)
                pad_length = task_batch_sizes[task] * \
                             (len(ids) // task_batch_sizes[task] + 1) - len(ids)
                id_arr = np.pad(ids, (0, pad_length), 'wrap')
                id_matrix = np.reshape(id_arr, [-1, task_batch_sizes[task]])
                for row_idx in range(0, id_matrix.shape[0]):
                    batch_ids.append((task, id_matrix[row_idx]))
            #endfor

            # Shuffle that list, feeding examples for whatever
            # task and whatever ids we have
            np.random.shuffle(batch_ids)
            for j in range(0, len(batch_ids)):
                log.log_status('info', None,
                               "Completed %d (%.2f%%) iterations", j,
                               100.0 * j / len(batch_ids))
                task, ids = batch_ids[j]
                batch_tensor = \
                    nn_data.load_batch(ids, task_data_dicts[task],
                                       task, len(TASK_CLASS_DICT[task]))

                # Run the operation for this task
                nn_util.run_op(sess, train_ops[task], [batch_tensor],
                               lstm_input_dropout, dropout, encoding_scheme,
                               [task], [task], True)
            #endfor

            # Every epoch, evaluate and save the model
            log.info(None, "Saving model")
            saver.save(sess, model_file)
            for task in TASKS:
                eval_ids = eval_task_ids[task]
                with tf.variable_scope(task):
                    pred_scores, gold_label_dict = \
                        nn_util.get_pred_scores_mcc(task, encoding_scheme, sess,
                                                    task_batch_sizes[task], eval_ids,
                                                    eval_task_data_dicts[task],
                                                    len(TASK_CLASS_DICT[task]), log)

                # If we do an argmax on the scores, we get the predicted labels
                pred_labels = list()
                gold_labels = list()
                for m in eval_ids:
                    pred_labels.append(np.argmax(pred_scores[m]))
                    if 'rel' not in task:
                        gold_labels.append(np.argmax(gold_label_dict[m]))
                #endfor

                # Evaluate the predictions
                if 'rel' in task:
                    nn_eval.evaluate_relations(
                        eval_ids, pred_labels,
                        eval_task_data_dicts[task]['gold_label_dict'], log)
                else:
                    nn_eval.evaluate_multiclass(gold_labels, pred_labels,
                                                TASK_CLASS_DICT[task], log)
                #endif
            #endfor
        #endfor

        log.info("Saving final model")
        saver.save(sess, model_file)
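To make the alternating scheme above concrete: each task's ID list is wrap-padded to a multiple of its batch size and reshaped into rows of IDs, so every task contributes only full batches before the batches of all tasks are shuffled together. A toy numpy illustration with hypothetical data:

import numpy as np

ids = np.arange(5)   # five example IDs
batch_size = 2

# Wrap-pad so the length becomes a multiple of batch_size (the formula above
# adds one full wrapped batch when the length already divides evenly)
pad_length = batch_size * (len(ids) // batch_size + 1) - len(ids)
id_arr = np.pad(ids, (0, pad_length), 'wrap')       # [0 1 2 3 4 0]
id_matrix = np.reshape(id_arr, [-1, batch_size])    # three batches of two IDs
print(id_matrix)                                    # [[0 1] [2 3] [4 0]]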
def train_jointly(multitask_scheme,
                  epochs,
                  batch_size,
                  lstm_input_dropout,
                  dropout,
                  lrn_rate,
                  adam_epsilon,
                  clip_norm,
                  encoding_scheme,
                  task_vars,
                  task_data_dicts,
                  eval_task_data_dicts,
                  task_ids,
                  eval_task_ids,
                  model_file,
                  log=None):
    """
    Trains a joint model, either using a simple sum of losses
    or with weights over losses

    :param multitask_scheme:
    :param epochs:
    :param batch_size:
    :param lstm_input_dropout:
    :param dropout:
    :param lrn_rate:
    :param adam_epsilon:
    :param clip_norm:
    :param encoding_scheme:
    :param task_vars:
    :param task_data_dicts:
    :param eval_task_data_dicts:
    :param task_ids:
    :param eval_task_ids:
    :param model_file:
    :param log:
    :return:
    """
    global TASKS

    # We either have a simple sum-of-losses model or we learn
    # weights over those losses.
    # tf.stack vectorizes the per-task scalar losses, and since the
    # ffw function expects a [batch_size, units] shape, we expand the
    # first dimension to 1. We build this tensor up front because
    # reduce_sum doesn't care about the extra dimension, so the same
    # tensor works in both branches.
    losses = list()
    for task in TASKS:
        losses.append(task_vars[task]['loss'])
    tf_losses = tf.expand_dims(tf.stack(losses), 0)
    if multitask_scheme == "simple_joint":
        joint_loss = tf.reduce_sum(tf_losses)
        nn_util.add_train_op(joint_loss, lrn_rate, adam_epsilon, clip_norm)
    elif multitask_scheme == "weighted_joint":
        joint_loss = tf.reduce_sum(nn_util.setup_ffw(tf_losses, [5]))
        nn_util.add_train_op(joint_loss, lrn_rate, adam_epsilon, clip_norm)
    #endif
    train_op = tf.get_collection('train_op')[0]

    # TODO: Implement early stopping under this framework
    log.info("Training")
    saver = tf.train.Saver(max_to_keep=100)
    with tf.Session() as sess:
        # Initialize all our variables
        sess.run(tf.global_variables_initializer())

        # Iterate through the data [epochs] number of times
        for i in range(0, epochs):
            log.info(None, "--- Epoch %d ----", i + 1)

            # Shuffle everyone's IDs for this epoch
            max_samples = 0
            sample_indices = dict()
            for task in TASKS:
                if task == 'affinity':
                    task_ids[task] = shuffle_mention_box_pairs(task_ids[task])
                else:
                    np.random.shuffle(task_ids[task])
                max_samples = max(max_samples, len(task_ids[task]))
                sample_indices[task] = 0
            #endfor

            # We iterate until we've seen _every_ task's samples
            # at least once
            n_iter = max_samples // batch_size
            for j in range(0, n_iter):
                log.log_status('info', None,
                               "Completed %d (%.2f%%) iterations", j,
                               100.0 * j / n_iter)
                # For each task, get the next [batch_size] samples,
                # wrapping around to the beginning of the list if
                # necessary
                batch_tensor_dicts = list()
                for task in TASKS:
                    ids = task_ids[task]
                    n_ids = len(ids)
                    start_idx = sample_indices[task]
                    if start_idx + batch_size < n_ids:
                        batch_ids = ids[start_idx:start_idx + batch_size]
                        sample_indices[task] += batch_size
                    else:
                        remainder = start_idx + batch_size - n_ids + 1
                        batch_ids = ids[start_idx:n_ids - 1]
                        batch_ids.extend(ids[0:remainder])
                        sample_indices[task] = remainder
                    #endif
                    batch_tensor_dicts.append(
                        nn_data.load_batch(batch_ids, task_data_dicts[task],
                                           task, len(TASK_CLASS_DICT[task])))
                #endfor

                # It so happens that I'm using task names as variable
                # namespaces, which is why I'm passing them twice in the
                # operations, below
                nn_util.run_op(sess, train_op, batch_tensor_dicts,
                               lstm_input_dropout, dropout, encoding_scheme,
                               TASKS, TASKS, True)
            #endfor

            # Every epoch, evaluate and save the model
            log.info(None, "Saving model")
            saver.save(sess, model_file)

            for task in TASKS:
                eval_ids = eval_task_ids[task]
                with tf.variable_scope(task):
                    pred_scores, gold_label_dict = \
                        nn_util.get_pred_scores_mcc(task, encoding_scheme, sess,
                                                    batch_size, eval_ids,
                                                    eval_task_data_dicts[task],
                                                    len(TASK_CLASS_DICT[task]), log)

                # If we do an argmax on the scores, we get the predicted labels
                pred_labels = list()
                gold_labels = list()
                for m in eval_ids:
                    pred_labels.append(np.argmax(pred_scores[m]))
                    if 'rel' not in task:
                        gold_labels.append(np.argmax(gold_label_dict[m]))
                #endfor

                # Evaluate the predictions
                if 'rel' in task:
                    nn_eval.evaluate_relations(
                        eval_ids, pred_labels,
                        eval_task_data_dicts[task]['gold_label_dict'], log)
                else:
                    nn_eval.evaluate_multiclass(gold_labels, pred_labels,
                                                TASK_CLASS_DICT[task], log)
                #endif

            #endfor
        #endfor

        log.info("Saving final model")
        saver.save(sess, model_file)
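The joint-loss construction in train_jointly() stacks the per-task scalar losses into a [1, n_tasks] tensor so the same tensor can feed either a plain sum (simple_joint) or a learned mixing layer (weighted_joint). The TF 1.x toy below illustrates the shape handling; tf.layers.dense stands in for the project-specific nn_util.setup_ffw, and the loss values are dummies:

import tensorflow as tf

loss_a = tf.constant(0.7)   # per-task scalar losses (dummy values)
loss_b = tf.constant(1.3)

# Stack the scalars into a vector, then add a leading batch dimension of 1
# so the tensor has the [batch_size, units] shape a dense layer expects
tf_losses = tf.expand_dims(tf.stack([loss_a, loss_b]), 0)         # shape [1, 2]

simple_joint = tf.reduce_sum(tf_losses)                           # sum of losses
weighted_joint = tf.reduce_sum(tf.layers.dense(tf_losses, 5))     # learned mixing

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run([simple_joint, weighted_joint]))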