import argparse
import logging

import torch

# "data" and "model" are local modules from the surrounding project; they
# provide the dataset helpers (load_data, encode_phrase_pairs, ...) and the
# PhraseModel seq2seq network used below.
import data
import model

log = logging.getLogger("data_test")

if __name__ == "__main__":
    logging.basicConfig(format="%(asctime)-15s %(levelname)s %(message)s", level=logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", required=True,
                        help="Category to use for training. Empty string to train on full dataset")
    parser.add_argument("-m", "--model", required=True, help="Model name to load")
    args = parser.parse_args()

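    # Load the phrase pairs for the chosen category and build the
    # word -> id dictionary used for embedding lookup.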
    phrase_pairs, emb_dict = data.load_data(args.data)
    log.info("Obtained %d phrase pairs with %d uniq words", len(phrase_pairs), len(emb_dict))
    train_data = data.encode_phrase_pairs(phrase_pairs, emb_dict)
    train_data = data.group_train_data(train_data)
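    # Reverse mapping (token id -> word), handy for turning decoded id
    # sequences back into readable text.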
    rev_emb_dict = {idx: word for word, idx in emb_dict.items()}

    net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict), hid_size=model.HIDDEN_STATE_SIZE)
    net.load_state_dict(torch.load(args.model))

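    # Id of the end-of-sequence token; decoding stops once it is emitted.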
    end_token = emb_dict[data.END_TOKEN]

    seq_count = 0
    sum_bleu = 0.0

    for seq_1, targets in train_data:
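        # Embed and pack the input token ids, then run the encoder to obtain
        # the hidden state that seeds the decoder.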
        input_seq = model.pack_input(seq_1, net.emb)
        enc = net.encode(input_seq)
        # The call is truncated in the source; the remaining arguments below
        # (a maximum decode length and the stop token) are an assumption based
        # on end_token being prepared above, with data.MAX_TOKENS assumed to
        # be defined in the local data module.
        _, tokens = net.decode_chain_argmax(enc, input_seq.data[0:1],
                                            seq_len=data.MAX_TOKENS,
                                            stop_at_token=end_token)
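        # The rest of the loop is cut off in the source. A minimal sketch of
        # the BLEU accumulation it presumably performs (sum_bleu and seq_count
        # are initialised above); calc_bleu_many is a hypothetical helper:
        # references = [seq[1:] for seq in targets]
        # sum_bleu += utils.calc_bleu_many(tokens, references)
        # seq_count += 1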
Example #2
        "Open: %s", '../data/auto_QA_data/nomask_test/' +
        str(args.pred).upper() + '_test.question')
    TEST_ACTION_PATH = '../data/auto_QA_data/nomask_test/' + str(
        args.pred).upper() + '_test.action'
    log.info(
        "Open: %s", '../data/auto_QA_data/nomask_test/' +
        str(args.pred).upper() + '_test.action')
    if args.pred == 'pt' or 'final' in args.pred:
        phrase_pairs, emb_dict = data.load_data_from_existing_data(
            TEST_QUESTION_PATH, TEST_ACTION_PATH, DIC_PATH)
    elif args.pred == 'rl':
        phrase_pairs, emb_dict = data.load_RL_data(TEST_QUESTION_PATH,
                                                   TEST_ACTION_PATH, DIC_PATH)
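    # Note: phrase_pairs and emb_dict are only assigned in the two branches
    # above; any other --pred value would raise a NameError here.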
    log.info("Obtained %d phrase pairs with %d uniq words", len(phrase_pairs),
             len(emb_dict))
    train_data = data.encode_phrase_pairs(phrase_pairs, emb_dict)
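    # For the RL model the pairs are grouped so each question carries all of
    # its reference actions; otherwise they are kept one-to-one.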
    if args.pred == 'rl':
        train_data = data.group_train_data(train_data)
    else:
        train_data = data.group_train_data_one_to_one(train_data)
    rev_emb_dict = {idx: word for word, idx in emb_dict.items()}

    net = model.PhraseModel(emb_size=model.EMBEDDING_DIM,
                            dict_size=len(emb_dict),
                            hid_size=model.HIDDEN_STATE_SIZE)
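    # The script assumes a CUDA-capable GPU; .cuda() moves all parameters to
    # the GPU and raises an error on CPU-only machines.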
    net = net.cuda()
    model_path = '../data/saves/' + str(args.name) + '/' + str(args.model)
    net.load_state_dict(torch.load(model_path))
    end_token = emb_dict[data.END_TOKEN]

    seq_count = 0