Example #1
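    # Restore a previously trained IC model: hyperparameters come from the
    # ".params" JSON file written at training time, and the token/label
    # vocabularies from the ".index" pickle.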
    with open(args.model_filename + ".params", mode="r") as in_file:
        params = json.load(in_file)

    print("-- Loading index")
    with open(args.model_filename + ".index", mode="rb") as in_file:
        index = pickle.load(in_file)
        token2id = index["token2id"]
        id2token = index["id2token"]
        label2id = index["label2id"]
        id2label = index["id2label"]
        num_tokens = len(token2id)
        num_labels = len(label2id)
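        # Optional sanity check (not part of the original snippet): the forward
        # and inverse maps should have matching sizes.
        assert num_tokens == len(id2token)
        assert num_labels == len(id2label)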

    print("-- Loading test set")
    print("-- Loading test set")
    test_labels, test_padded_sentences, test_img_names, test_original_sentences = load_ic_dataset(
        args.test_filename, token2id, label2id)

    print("-- Loading images")
    image_reader = ImageReader(args.img_names_filename,
                               args.img_features_filename)

    print("-- Restoring model")
    sentence_input = tf.placeholder(tf.int32, (None, None),
                                    name="sentence_input")
    img_features_input = tf.placeholder(
        tf.float32,
        (None, params["num_img_features"], params["img_features_size"]),
        name="img_features_input")
    label_input = tf.placeholder(tf.int32, (None, ), name="label_input")
    dropout_input = tf.placeholder(tf.float32, name="dropout_input")
    logits = build_bottom_up_top_down_ic_model(
Example #2
        print("Params saved to: {}".format(args.model_save_filename +
                                           ".params"))

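        # Pickle the vocabulary and label mappings under the model save name so
        # that evaluation code (as in Example #1) can rebuild the same id
        # spaces.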
        with open(args.model_save_filename + ".index", mode="wb") as out_file:
            pickle.dump(
                {
                    "token2id": token2id,
                    "id2token": id2token,
                    "label2id": label2id,
                    "id2label": id2label
                }, out_file)
            print("Index saved to: {}".format(args.model_save_filename +
                                              ".index"))

    print("-- Loading training set")
    train_labels, train_sentences, _, _ = load_ic_dataset(
        args.train_filename, token2id, label2id)

    print("-- Loading development set")
    dev_labels, dev_sentences, _, _ = load_ic_dataset(args.dev_filename,
                                                      token2id, label2id)

    print("-- Building model")
    sentence_input = tf.placeholder(tf.int32, (None, None),
                                    name="sentence_input")
    label_input = tf.placeholder(tf.int32, (None, ), name="label_input")
    dropout_input = tf.placeholder(tf.float32, name="dropout_input")
    logits = build_simple_blind_ic_model(sentence_input, dropout_input,
                                         num_tokens, num_labels, embeddings,
                                         args.embeddings_size,
                                         args.train_embeddings,
                                         args.rnn_hidden_size,
Example #3
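        # Multi-task setup: a single shared token vocabulary (token2id/id2token)
        # with separate label maps for the VTE and IC tasks.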
        with open(args.model_save_filename + ".index", mode="wb") as out_file:
            pickle.dump(
                {
                    "token2id": token2id,
                    "id2token": id2token,
                    "vte_label2id": vte_label2id,
                    "vte_id2label": vte_id2label,
                    "ic_label2id": ic_label2id,
                    "ic_id2label": ic_id2label
                }, out_file)
            print("Index saved to: {}".format(args.model_save_filename +
                                              ".index"))

    print("-- Loading training set")
    ic_train_labels, ic_train_sentences, ic_train_img_names, _ = load_ic_dataset(
        args.ic_train_filename, token2id, ic_label2id)

    print("-- Loading training set")
    vte_train_labels, vte_train_premises, vte_train_hypotheses, vte_train_img_names, _, _ =\
        load_vte_dataset(
            args.vte_train_filename,
            token2id,
            vte_label2id
        )

    print("-- Loading development set")
    vte_dev_labels, vte_dev_premises, vte_dev_hypotheses, vte_dev_img_names, _, _ =\
        load_vte_dataset(
            args.vte_dev_filename,
            token2id,
            vte_label2id