Code example #1
        token2id = index["token2id"]
        id2token = index["id2token"]
        label2id = index["label2id"]
        id2label = index["id2label"]
        num_tokens = len(token2id)
        num_labels = len(label2id)

    print("-- Loading test set")
    test_labels, test_padded_premises, test_padded_hypotheses, test_img_names, test_original_premises, test_original_hypotheses = \
        load_vte_dataset(
            args.test_filename,
            token2id,
            label2id
        )

    print("-- Loading images")
    image_reader = ImageReader(args.img_names_filename, args.img_features_filename)
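    # ImageReader is assumed (from its use here and in the later examples) to
    # index precomputed image features by image name; get_features() returns an
    # array shaped (batch, num_img_features, img_features_size), matching the
    # img_features_input placeholder below.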

    print("-- Restoring model")
    premise_input = tf.placeholder(tf.int32, (None, None), name="premise_input")
    hypothesis_input = tf.placeholder(tf.int32, (None, None), name="hypothesis_input")
    img_features_input = tf.placeholder(tf.float32, (None, params["num_img_features"], params["img_features_size"]),
                                        name="img_features_input")
    label_input = tf.placeholder(tf.int32, (None,), name="label_input")
    dropout_input = tf.placeholder(tf.float32, name="dropout_input")
    logits = build_bottom_up_top_down_vte_model_hi(
        premise_input,
        hypothesis_input,
        img_features_input,
        dropout_input,
        num_tokens,
        num_labels,
Code example #2
File: train_mt_model.py  Project: hoavt-54/nli-images
        load_vte_dataset(
            args.vte_train_filename,
            token2id,
            vte_label2id
        )

    print("-- Loading development set")
    vte_dev_labels, vte_dev_premises, vte_dev_hypotheses, vte_dev_img_names, _, _ =\
        load_vte_dataset(
            args.vte_dev_filename,
            token2id,
            vte_label2id
        )

    print("-- Loading images")
    ic_image_reader = ImageReader(args.ic_img_names_filename,
                                  args.ic_img_features_filename)

    print("-- Loading images")
    vte_image_reader = ImageReader(args.vte_img_names_filename,
                                   args.vte_img_features_filename)

    sentence_input = tf.placeholder(tf.int32, (None, None),
                                    name="sentence_input")
    premise_input = tf.placeholder(tf.int32, (None, None),
                                   name="premise_input")
    hypothesis_input = tf.placeholder(tf.int32, (None, None),
                                      name="hypothesis_input")
    img_features_input = tf.placeholder(
        tf.float32, (None, args.num_img_features, args.img_features_size),
        name="img_features_input")
    ic_label_input = tf.placeholder(tf.int32, (None, ), name="label_input")
Code example #3
def main(_):

    random_seed = 12345
    os.environ["PYTHONHASHSEED"] = str(random_seed)
    random.seed(random_seed)
    np.random.seed(random_seed)
    tf.set_random_seed(random_seed)

    start_logger(FLAGS.model_save_filename + ".train_log")
    atexit.register(stop_logger)

    print("-- Building vocabulary")
    #embeddings, token2id, id2token = load_glove(args.vectors_filename, args.max_vocab, args.embeddings_size)

    label2id = {"neutral": 0, "entailment": 1, "contradiction": 2}
    id2label = {v: k for k, v in label2id.items()}

    #num_tokens = len(token2id)

    num_labels = len(label2id)

    #print("Number of tokens: {}".format(num_tokens))
    print("Number of labels: {}".format(num_labels))

    # Load e_vsnli
    # Explanations are already encoded/padded; we ignore the original explanation strings
    print("-- Loading training set")

    train_labels, train_explanations, train_premises, train_hypotheses, train_img_names, _, _, _, train_max_length, embeddings, token2id, id2token, _ = \
        load_e_vsnli_dataset_and_glove(
            FLAGS.train_filename,
            label2id,
            FLAGS.vectors_filename,
            FLAGS.max_vocab,
            model_config.embedding_size,
            buffer_size=FLAGS.buffer_size,
            min_threshold=FLAGS.min_threshold,
        )

    num_tokens = len(token2id)
    print("Number of tokens after filtering: ", num_tokens)

    print("-- Loading development set")
    dev_labels, dev_explanations, dev_premises, dev_hypotheses, dev_img_names, dev_original_explanations, _, _, dev_max_length, _ = \
        load_e_vsnli_dataset(
            FLAGS.dev_filename,
            token2id,
            label2id,
            buffer_size=FLAGS.buffer_size,
            padding_length=train_max_length,
        )

    if FLAGS.imbalance:
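        # inverse-frequency weights: a class with frequency f gets weight
        # 1 / (f * num_labels), so the weighted accuracy computed during
        # validation equals balanced accuracy (mean per-class recall)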
        dev_num_examples = dev_labels.shape[0]
        class_freqs = np.bincount(dev_labels) / dev_num_examples
        class_weights = 1 / (class_freqs * num_labels)
        print("Class frequencies: ", class_freqs)
        print("Weights: ", class_weights)
        np.save(FLAGS.model_save_filename + '_class_freqs.npy', class_freqs)

    print("-- Loading images")
    image_reader = ImageReader(FLAGS.img_names_filename,
                               FLAGS.img_features_filename)

    print("-- Saving parameters")
    with open(FLAGS.model_save_filename + ".params", mode="w") as out_file:
        json.dump(vars(FLAGS), out_file)
        print("Params saved to: {}".format(FLAGS.model_save_filename +
                                           ".params"))

        with open(FLAGS.model_save_filename + ".index", mode="wb") as out_file:
            pickle.dump(
                {
                    "token2id": token2id,
                    "id2token": id2token,
                    "label2id": label2id,
                    "id2label": id2label
                }, out_file)
            print("Index saved to: {}".format(FLAGS.model_save_filename +
                                              ".index"))

    model_config.set_vocab_size(num_tokens)
    print("Vocab size, set to %d" % model_config.vocab_size)
    model_config.set_alpha(FLAGS.alpha)
    print("alpha = %f, set!" % model_config.alpha)

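    # map each label id to the vocabulary id of its label word (falling back to
    # "#unk#"); the resulting token is what gets prepended to each generated
    # explanation (see the decoding step in the inference examples below)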
    ilabel2itoken = {}
    for i, label in id2label.items():
        ilabel2itoken[i] = token2id.get(label, token2id["#unk#"])

    print("label_id --> token_id: constructed!")

    num_examples = train_labels.shape[0]
    num_batches = num_examples // FLAGS.batch_size

    dev_num_examples = dev_labels.shape[0]
    dev_batches_indexes = np.arange(dev_num_examples)
    num_batches_dev = dev_num_examples // FLAGS.dev_batch_size
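    # batch() is assumed to be a small helper that yields consecutive slices of
    # an index array, e.g. roughly:
    #   def batch(indexes, size):
    #       for i in range(0, len(indexes), size):
    #           yield indexes[i:i + size]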

    tf.reset_default_graph()

    # Build the TensorFlow graph and train it
    g = tf.Graph()
    with g.as_default():

        model = build_model(model_config,
                            embeddings,
                            ilabel2itoken=ilabel2itoken,
                            mode=mode)

        # Set up the learning rate.
        learning_rate_decay_fn = None
        learning_rate = tf.constant(training_config.initial_learning_rate)
        if training_config.learning_rate_decay_factor > 0:
            num_batches_per_epoch = (num_examples / FLAGS.batch_size)
            decay_steps = int(num_batches_per_epoch *
                              training_config.num_epochs_per_decay)
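            # exponential_decay with staircase=True computes
            # lr = initial_lr * decay_factor ** floor(global_step / decay_steps)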

            def _learning_rate_decay_fn(learning_rate, global_step):
                return tf.train.exponential_decay(
                    learning_rate,
                    global_step,
                    decay_steps=decay_steps,
                    decay_rate=training_config.learning_rate_decay_factor,
                    staircase=True)

            learning_rate_decay_fn = _learning_rate_decay_fn

        # Set up the training ops.
        train_op = tf.contrib.layers.optimize_loss(
            loss=model['total_loss'],
            global_step=model['global_step'],
            learning_rate=learning_rate,
            optimizer=training_config.optimizer,
            clip_gradients=training_config.clip_gradients,
            learning_rate_decay_fn=learning_rate_decay_fn)

        dev_best_accuracy = -1
        stopping_step = 0
        best_epoch = None
        should_stop = False

        # initialize all variables
        init = tf.global_variables_initializer()

        with tf.Session() as session:
            session.run(init)
            #session.run(tf.initializers.tables_initializer(name='init_all_tables'))

            t = 0  # counting iterations

            time_now = datetime.now()

            for epoch in range(training_config.total_num_epochs):
                if should_stop:
                    break

                print("\n==> Online epoch # {0}".format(epoch + 1))
                progress = Progbar(num_batches)
                batches_indexes = np.arange(num_examples)
                np.random.shuffle(batches_indexes)

                batch_index = 1
                loss_history = []
                epoch_loss = 0

                for indexes in batch(batches_indexes, FLAGS.batch_size):

                    t += 1
                    batch_hypotheses = train_hypotheses[indexes]
                    batch_labels = train_labels[indexes]

                    # explanations have been encoded / padded when loaded
                    batch_explanations = train_explanations[indexes]
                    batch_explanation_lengths = [
                        len(expl) for expl in batch_explanations
                    ]

                    batch_img_names = [train_img_names[i] for i in indexes]
                    batch_img_features = image_reader.get_features(
                        batch_img_names)

                    total_loss_value = _step(
                        session, batch_hypotheses, batch_labels,
                        batch_explanations, batch_img_features, train_op,
                        model, model_config.lstm_dropout_keep_prob
                    )  # run each training step

                    progress.update(batch_index, [("Loss", total_loss_value)])
                    loss_history.append(total_loss_value)
                    epoch_loss += total_loss_value
                    batch_index += 1

                    if FLAGS.print_every > 0 and t % FLAGS.print_every == 0:
                        print(
                            '(Iteration %d) loss: %f, and time elapsed: %.2f minutes'
                            % (t, float(loss_history[-1]),
                               (datetime.now() - time_now).total_seconds() / 60.0))

                print("Current mean training loss: {}\n".format(epoch_loss /
                                                                num_batches))

                print("-- Validating model")

                progress = Progbar(num_batches_dev)

                dev_num_correct = 0
                dev_batch_index = 0

                for indexes in batch(dev_batches_indexes,
                                     FLAGS.dev_batch_size):

                    t += 1

                    dev_batch_num_correct = 0

                    dev_batch_index += 1
                    dev_batch_hypotheses = dev_hypotheses[indexes]
                    dev_batch_labels = dev_labels[indexes]

                    # explanations have been encoded / padded when loaded
                    dev_batch_explanations = dev_explanations[indexes]
                    dev_batch_img_names = [dev_img_names[i] for i in indexes]
                    dev_batch_img_features = image_reader.get_features(
                        dev_batch_img_names)

                    pred_explanations, pred_labels = _run_validation(
                        session, dev_batch_hypotheses, dev_batch_labels,
                        dev_batch_explanations, dev_batch_img_features,
                        len(indexes), ilabel2itoken, model, 1.0)

                    if FLAGS.imbalance:
                        dev_batch_num_correct += np.dot(
                            pred_labels == dev_batch_labels,
                            class_weights[dev_batch_labels])
                    else:
                        dev_batch_num_correct += (
                            pred_labels == dev_batch_labels).sum()
                    dev_num_correct += dev_batch_num_correct

                    progress.update(
                        dev_batch_index,
                        [("Proportion of correct labels",
                          float(dev_batch_num_correct) / len(indexes))])
                    if FLAGS.sample_every > 0 and (
                            t + 1) % FLAGS.sample_every == 0:
                        pred_explanations = [
                            unpack.reshape(-1, 1)
                            for unpack in pred_explanations
                        ]
                        pred_explanations = np.concatenate(
                            pred_explanations, 1)
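                        # decode() is assumed to map a sequence of token ids
                        # back to a readable string via id2token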
                        pred_explanations_decoded = [
                            decode(pred_explanations[i], id2token)
                            for i in range(len(indexes))
                        ]
                        print("\nExample generated explanation: ",
                              pred_explanations_decoded[0])  #TODO: decode it
                        #print("Original explanation: ", dev_original_explanations[indexes][0])

                dev_accuracy = float(dev_num_correct) / dev_num_examples
                print("Current mean validation accuracy: {}".format(
                    dev_accuracy))

                #if True:
                if dev_accuracy > dev_best_accuracy:
                    stopping_step = 0
                    best_epoch = epoch + 1
                    dev_best_accuracy = dev_accuracy
                    model['saver'].save(session,
                                        FLAGS.model_save_filename + ".ckpt")
                    print(
                        "Best mean validation accuracy: {} (reached at epoch {})"
                        .format(dev_best_accuracy, best_epoch))
                    print("Best model saved to: {}".format(
                        FLAGS.model_save_filename))
                else:
                    stopping_step += 1
                    print("Current stopping step: {}".format(stopping_step))
                if stopping_step >= FLAGS.patience:
                    print("Early stopping at epoch {}!".format(epoch + 1))
                    print(
                        "Best mean validation accuracy: {} (reached at epoch {})"
                        .format(dev_best_accuracy, best_epoch))
                    should_stop = True
                if epoch + 1 >= training_config.total_num_epochs:
                    print("Stopping at epoch {}!".format(epoch + 1))
                    print(
                        "Best mean validation accuracy: {} (reached at epoch {})"
                        .format(dev_best_accuracy, best_epoch))
Code example #4
def main(_):

    BATCH_SIZE_INFERENCE = 1

    random_seed = 12345
    os.environ["PYTHONHASHSEED"] = str(random_seed)
    random.seed(random_seed)
    np.random.seed(random_seed)
    tf.set_random_seed(random_seed)

    start_logger(FLAGS.result_filename + ".log")
    atexit.register(stop_logger)

    print("-- Loading params")
    with open(FLAGS.model_filename + ".params", mode="r") as in_file:
        params = json.load(in_file)

    print("-- Loading index")
    with open(FLAGS.model_filename + ".index", mode="rb") as in_file:
        index = pickle.load(in_file)
        token2id = index["token2id"]
        id2token = index["id2token"]
        label2id = index["label2id"]
        id2label = index["id2label"]
        num_tokens = len(token2id)
        num_labels = len(label2id)

    print("Number of tokens: {}".format(num_tokens))
    print("Number of labels: {}".format(num_labels))

    model_config.set_vocab_size(num_tokens)
    print("Vocab size set!")

    print("-- Loading test set")
    test_labels, test_padded_explanations, test_padded_premises, test_padded_hypotheses, test_img_names, test_original_explanations, test_original_premises, test_original_hypotheses, test_max_length, test_pairIDs = \
        load_e_vsnli_dataset(
            FLAGS.test_filename,
            token2id,
            label2id,
            buffer_size=FLAGS.buffer_size,
        )

    if FLAGS.imbalance:
        #class_freqs = np.load(FLAGS.model_filename + '_class_freqs.npy')

        test_num_examples = test_labels.shape[0]
        class_freqs = np.bincount(test_labels) / test_num_examples
        class_weights = 1 / (class_freqs * num_labels)
        print("Class frequencies: ", class_freqs)
        print("Weights: ", class_weights)

    test_original_premises = np.array(test_original_premises)
    test_original_hypotheses = np.array(test_original_hypotheses)
    test_original_explanations = np.array(test_original_explanations)

    print("-- Loading images")
    image_reader = ImageReader(FLAGS.img_names_filename,
                               FLAGS.img_features_filename)

    ilabel2itoken = {}
    for i, label in id2label.items():
        ilabel2itoken[i] = token2id.get(label, token2id["#unk#"])
    print("label_id --> token_id: constructed!")

    model_config.set_alpha(params['alpha'])

    # Build the TensorFlow graph and run inference
    g = tf.Graph()
    with g.as_default():
        # Build the model.

        model = build_model(model_config,
                            embeddings=None,
                            mode=mode,
                            inference_batch=BATCH_SIZE_INFERENCE)

        generator = LabelExplanationGenerator(
            model,
            vocab=token2id,
            ilabel2itoken=ilabel2itoken,
            max_explanation_length=model_config.padded_length - 1)
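        # LabelExplanationGenerator is assumed to decode step by step: it first
        # emits the predicted label token (via ilabel2itoken) and then generates
        # the explanation, up to padded_length - 1 tokens; this is why the first
        # and last generated tokens are stripped before decoding below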

        # run inference
        init = tf.global_variables_initializer()
        with tf.Session() as session:

            session.run(init)

            model['saver'].restore(session, FLAGS.model_filename + ".ckpt")

            print("Model restored! Last step run: ",
                  session.run(model['global_step']))

            print("-- Evaluating model")
            test_num_examples = test_labels.shape[0]
            test_batches_indexes = np.arange(test_num_examples)
            test_num_correct = 0
            y_true = []
            y_pred = []

            with open(FLAGS.result_filename + ".predictions",
                      mode="w") as out_file:
                writer = csv.writer(out_file, delimiter="\t")
                for indexes in batch(test_batches_indexes, FLAGS.batch_size):

                    test_batch_pairIDs = test_pairIDs[indexes]

                    test_batch_premises = test_padded_premises[indexes]
                    test_batch_hypotheses = test_padded_hypotheses[indexes]
                    test_batch_labels = test_labels[indexes]
                    test_batch_explanations = test_padded_explanations[indexes]
                    batch_img_names = [test_img_names[i] for i in indexes]
                    batch_img_features = image_reader.get_features(
                        batch_img_names)

                    test_batch_original_premises = test_original_premises[
                        indexes]
                    test_batch_original_hypotheses = test_original_hypotheses[
                        indexes]
                    test_batch_original_explanations = test_original_explanations[
                        indexes]

                    #                     pred_explanations, pred_labels = _step_test(session, test_batch_hypotheses, test_batch_labels, test_batch_explanations, test_batch_img_features, BATCH_SIZE_INFERENCE, model, 1.0) # the output is size (32, 16)
                    #pred_explanations = [unpack.reshape(-1, 1) for unpack in pred_explanations]
                    #pred_explanations = np.concatenate(pred_explanations, 1)

                    pred_labels, pred_explanations = run_inference(
                        session, test_batch_hypotheses, batch_img_features,
                        generator, 1.0)

                    # don't decode the first token which corresponds to the prepended label
                    # nor the last because it is <end>
                    pred_explanations_decoded = [
                        decode(pred_explanations[i][1:-1], id2token)
                        for i in range(len(indexes))
                    ]

                    #batch_bleu = corpus_bleu(test_batch_original_explanations, pred_explanations_decoded)
                    #print("Current BLEU score: ", batch_bleu)

                    if FLAGS.imbalance:
                        # weight correct predictions by the gold class (as in
                        # training) so the weighted score corresponds to
                        # balanced accuracy
                        test_num_correct += np.dot(
                            (pred_labels == test_batch_labels),
                            class_weights[test_batch_labels])
                    else:
                        test_num_correct += (
                            pred_labels == test_batch_labels).sum()

                    # add explanations in result file
                    for i in range(len(indexes)):

                        writer.writerow([
                            id2label[test_batch_labels[i]],
                            id2label[pred_labels[i]],
                            " ".join([
                                id2token[id] for id in test_batch_premises[i]
                                if id != token2id["#pad#"]
                            ]),
                            " ".join([
                                id2token[id] for id in test_batch_hypotheses[i]
                                if id != token2id["#pad#"]
                            ]),
                            batch_img_names[i],
                            test_batch_original_premises[i],
                            test_batch_original_hypotheses[i],
                            #test_batch_original_explanations[i],
                            " ".join([
                                id2token[id]
                                for id in test_batch_explanations[i]
                                if id != token2id["#pad#"]
                            ]),
                            pred_explanations_decoded[i],
                            [],
                            test_batch_pairIDs[i]
                            #list(np.where(pred_atts[i]>0.1)[0])
                        ])
                        y_true.append(id2label[test_batch_labels[i]])

                        y_pred.append(id2label[pred_labels[i]])
            test_accuracy = float(test_num_correct) / test_num_examples
            print("Mean test accuracy: {}".format(test_accuracy))
            y_true = pd.Series(y_true, name="Actual")
            y_pred = pd.Series(y_pred, name="Predicted")
            confusion_matrix = pd.crosstab(y_true, y_pred, margins=True)
            confusion_matrix.to_csv(FLAGS.result_filename +
                                    ".confusion_matrix")

            # TODO: evaluation for explanations

            data = pd.read_csv(FLAGS.result_filename + ".predictions",
                               sep="\t",
                               header=None,
                               names=[
                                   "gold_label", "predicted_label",
                                   "premise_toks", "hypothesis_toks", "jpg",
                                   "premise", "hypothesis",
                                   "original_explanation",
                                   "generated_explanation", "top_rois",
                                   "pairID"
                               ])

            print("Overall accuracy: {}".format(
                accuracy_score(data["gold_label"], data["predicted_label"])))

            data_entailment = data.loc[data["gold_label"] == "entailment"]
            print("Accuracy for 'entailment': {}".format(
                accuracy_score(data_entailment["gold_label"],
                               data_entailment["predicted_label"])))

            data_contradiction = data.loc[data["gold_label"] ==
                                          "contradiction"]
            print("Accuracy for 'contradiction': {}".format(
                accuracy_score(data_contradiction["gold_label"],
                               data_contradiction["predicted_label"])))

            data_neutral = data.loc[data["gold_label"] == "neutral"]
            print("Accuracy for 'neutral': {}".format(
                accuracy_score(data_neutral["gold_label"],
                               data_neutral["predicted_label"])))
Code example #5
def main(_):

    BATCH_SIZE_INFERENCE = 1

    random_seed = 12345
    os.environ["PYTHONHASHSEED"] = str(random_seed)
    random.seed(random_seed)
    np.random.seed(random_seed)
    tf.set_random_seed(random_seed)

    start_logger(FLAGS.result_filename + ".log")
    atexit.register(stop_logger)

    print("-- Loading params")
    with open(FLAGS.model_filename + ".params", mode="r") as in_file:
        params = json.load(in_file)

    print("-- Loading index")
    with open(FLAGS.model_filename + ".index", mode="rb") as in_file:
        index = pickle.load(in_file)
        token2id = index["token2id"]
        id2token = index["id2token"]
        label2id = index["label2id"]
        id2label = index["id2label"]
        num_tokens = len(token2id)
        num_labels = len(label2id)

    print("Number of tokens: {}".format(num_tokens))
    print("Number of labels: {}".format(num_labels))

    model_config.set_vocab_size(num_tokens)
    print("Vocab size set!")

    print("-- Loading test set")
    test_labels, test_padded_explanations, test_padded_premises, test_padded_hypotheses, test_img_names, test_original_explanations, test_original_premises, test_original_hypotheses, test_max_length, test_pairIDs = \
        load_e_vsnli_dataset(
            FLAGS.test_filename,
            token2id,
            label2id,
            buffer_size=FLAGS.buffer_size,
        )

    if FLAGS.imbalance:
        #class_freqs = np.load(FLAGS.model_filename + '_class_freqs.npy')

        test_num_examples = test_labels.shape[0]
        class_freqs = np.bincount(test_labels) / test_num_examples
        class_weights = 1 / (class_freqs * num_labels)
        print("Class frequencies: ", class_freqs)
        print("Weights: ", class_weights)

    test_original_premises = np.array(test_original_premises)
    test_original_hypotheses = np.array(test_original_hypotheses)
    test_original_explanations = np.array(test_original_explanations)

    print("-- Loading images")
    image_reader = ImageReader(FLAGS.img_names_filename,
                               FLAGS.img_features_filename)

    model_config.set_alpha(params['alpha'])

    # Build the TensorFlow graph and run inference
    g = tf.Graph()
    with g.as_default():
        # Build the model.

        model = build_model(model_config,
                            embeddings=None,
                            mode=mode,
                            inference_batch=BATCH_SIZE_INFERENCE)

        generator = AttentionExplanationGenerator(
            model,
            vocab=token2id,
            max_explanation_length=model_config.padded_length - 1)
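        # AttentionExplanationGenerator is assumed to additionally return the
        # per-region attention weights (pred_attns below); regions with weight
        # above 0.05 are written out as "top_rois"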

        # run inference
        init = tf.global_variables_initializer()
        with tf.Session() as session:

            session.run(init)

            model['saver'].restore(session, FLAGS.model_filename + ".ckpt")

            print("Model restored! Last step run: ",
                  session.run(model['global_step']))

            print("-- Evaluating model")
            test_num_examples = test_labels.shape[0]
            test_batches_indexes = np.arange(test_num_examples)
            test_num_correct = 0
            y_true = []
            y_pred = []

            with open(FLAGS.result_filename + ".predictions",
                      mode="w") as out_file:
                writer = csv.writer(out_file, delimiter="\t")
                for indexes in batch(test_batches_indexes, FLAGS.batch_size):

                    test_batch_pairIDs = test_pairIDs[indexes]

                    test_batch_premises = test_padded_premises[indexes]
                    test_batch_hypotheses = test_padded_hypotheses[indexes]
                    test_batch_labels = test_labels[indexes]
                    test_batch_explanations = test_padded_explanations[indexes]
                    batch_img_names = [test_img_names[i] for i in indexes]
                    batch_img_features = image_reader.get_features(
                        batch_img_names)

                    test_batch_original_premises = test_original_premises[
                        indexes]
                    test_batch_original_hypotheses = test_original_hypotheses[
                        indexes]
                    test_batch_original_explanations = test_original_explanations[
                        indexes]

                    pred_attns, pred_explanations = run_inference_attn(
                        session, test_batch_hypotheses, batch_img_features,
                        generator, 1.0)

                    # don't decode the first token which corresponds to the prepended label
                    # nor the last because it is <end>
                    pred_explanations_decoded = [
                        decode(pred_explanations[i][1:-1], id2token)
                        for i in range(len(indexes))
                    ]

                    #batch_bleu = corpus_bleu(test_batch_original_explanations, pred_explanations_decoded)
                    #print("Current BLEU score: ", batch_bleu)

                    # add explanations in result file
                    for i in range(len(indexes)):

                        writer.writerow([
                            id2label[test_batch_labels[i]],
                            " ".join([
                                id2token[id] for id in test_batch_premises[i]
                                if id != token2id["#pad#"]
                            ]),
                            " ".join([
                                id2token[id] for id in test_batch_hypotheses[i]
                                if id != token2id["#pad#"]
                            ]),
                            batch_img_names[i],
                            test_batch_original_premises[i],
                            test_batch_original_hypotheses[i],
                            " ".join([
                                id2token[id]
                                for id in test_batch_explanations[i]
                                if id != token2id["#pad#"]
                            ]),
                            pred_explanations_decoded[i],
                            list(np.where(pred_attns[i] > 0.05)[0]),
                            #pred_attns[i][0],
                            test_batch_pairIDs[i]
                        ])

            data = pd.read_csv(FLAGS.result_filename + ".predictions",
                               sep="\t",
                               header=None,
                               names=[
                                   "gold_label", "premise_toks",
                                   "hypothesis_toks", "jpg", "premise",
                                   "hypothesis", "original_explanation",
                                   "generated_explanation", "top_rois",
                                   "pairID"
                               ])