예제 #1
0
    # Per-predicate masked test files. The loaders that read these two paths
    # are disabled (see the NOTE below); the variables are kept because code
    # further down this function may still refer to them — TODO confirm.
    test_file_prefix = '../data/auto_QA_data/mask_test/' + str(
        args.pred).upper()
    TEST_QUESTION_PATH = test_file_prefix + '_test.question'
    log.info("Open: %s", TEST_QUESTION_PATH)
    TEST_ACTION_PATH = test_file_prefix + '_test.action'
    log.info("Open: %s", TEST_ACTION_PATH)
    # NOTE: earlier versions picked the loader by args.pred
    # (data.load_data_from_existing_data for 'pt' / '*final*',
    # data.load_RL_data for 'rl') from the two paths above. Every predicate
    # now loads the combined question/answer file instead, so log that file
    # too — it is the one actually read.
    log.info("Open: %s", TEST_QUESTION_ANSWER_PATH)
    phrase_pairs, emb_dict = data.load_RL_data_TR(TEST_QUESTION_ANSWER_PATH,
                                                  DIC_PATH, MAX_TOKENS)

    log.info("Obtained %d phrase pairs with %d uniq words", len(phrase_pairs),
             len(emb_dict))
    # Encode the token sequences against the vocabulary in emb_dict.
    train_data = data.encode_phrase_pairs_RLTR(phrase_pairs, emb_dict)
    # 'rl' uses the RLTR grouping; every other predicate is grouped 1:1.
    if args.pred == 'rl':
        train_data = data.group_train_data_RLTR(train_data)
    else:
        train_data = data.group_train_data_one_to_one(train_data)
    # Reverse vocabulary (index -> word), presumably for turning predicted
    # indices back into tokens later in this function.
    rev_emb_dict = {idx: word for word, idx in emb_dict.items()}

    net = model.PhraseModel(emb_size=model.EMBEDDING_DIM,
                            dict_size=len(emb_dict),
                            hid_size=model.HIDDEN_STATE_SIZE,
예제 #2
0
    device = torch.device("cuda" if args.cuda else "cpu")
    log.info("Device info: %s", str(device))

    # Output directory for this run (SAVES_DIR/<args.name>), created if missing.
    saves_path = os.path.join(SAVES_DIR, args.name)
    os.makedirs(saves_path, exist_ok=True)

    # List of (question, {question information and answer}) pairs; at this
    # point the training pairs are 1:1. args.int selects the INT-mask variant
    # of the data file, dictionary and max-token settings.
    if args.int:
        phrase_pairs, emb_dict = data.load_RL_data_TR_INT(
            TRAIN_QUESTION_ANSWER_PATH_INT, DIC_PATH_INT, MAX_TOKENS_INT,
            bool(args.NSM))
        log.info(
            "Obtained %d phrase pairs with %d uniq words from %s with INT mask information.",
            len(phrase_pairs), len(emb_dict), TRAIN_QUESTION_ANSWER_PATH_INT)
    else:
        phrase_pairs, emb_dict = data.load_RL_data_TR(
            TRAIN_QUESTION_ANSWER_PATH, DIC_PATH, MAX_TOKENS, bool(args.NSM))
        log.info(
            "Obtained %d phrase pairs with %d uniq words from %s without INT mask information.",
            len(phrase_pairs), len(emb_dict), TRAIN_QUESTION_ANSWER_PATH)

    # Save the vocabulary into the run directory.
    data.save_emb_dict(saves_path, emb_dict)
    end_token = emb_dict[data.END_TOKEN]
    train_data = data.encode_phrase_pairs_RLTR(phrase_pairs, emb_dict)
    # Turn the 1:1 training pairs into 1:N form, i.e. (seq1, [seq*]) pairs.
    train_data = data.group_train_data_RLTR(train_data)
    # Shuffle with a fixed seed so the train/test split below is reproducible.
    rand = np.random.RandomState(data.SHUFFLE_SEED)
    rand.shuffle(train_data)
    # Report the total sample count BEFORE splitting; logged after the split
    # it only showed the training portion, duplicating the next message.
    log.info("Training data converted, got %d samples", len(train_data))
    train_data, test_data = data.split_train_test(train_data, TRAIN_RATIO)
    log.info("Train set has %d phrases, test %d", len(train_data),
예제 #3
0
        # Choose the question/answer file and vocabulary for the requested
        # dataset ("csqa", otherwise the *_WEBQSP files) and for whether INT
        # mask information is wanted. The chosen path is then both loaded and
        # logged, so the log always names the file that was actually read
        # (previously the CSQA path was logged even for WebQSP).
        if args.int:
            if args.dataset == "csqa":
                qa_path = TRAIN_QUESTION_ANSWER_PATH_INT
                dic_path = DIC_PATH_INT
            else:
                qa_path = TRAIN_QUESTION_ANSWER_PATH_INT_WEBQSP
                dic_path = DIC_PATH_INT_WEBQSP
            phrase_pairs, emb_dict = data.load_RL_data_TR_INT(
                qa_path, dic_path, MAX_TOKENS_INT)
            log.info(
                "Obtained %d phrase pairs with %d uniq words from %s with INT mask information.",
                len(phrase_pairs), len(emb_dict), qa_path)
        else:
            if args.dataset == "csqa":
                qa_path = TRAIN_QUESTION_ANSWER_PATH
                dic_path = DIC_PATH
            else:
                qa_path = TRAIN_QUESTION_ANSWER_PATH_WEBQSP
                dic_path = DIC_PATH_WEBQSP
            phrase_pairs, emb_dict = data.load_RL_data_TR(
                qa_path, dic_path, MAX_TOKENS)
            log.info(
                "Obtained %d phrase pairs with %d uniq words from %s without INT mask information.",
                len(phrase_pairs), len(emb_dict), qa_path)

        # Index -> word.
        rev_emb_dict = {idx: word for word, idx in emb_dict.items()}
        end_token = emb_dict[data.END_TOKEN]
        # Convert tokens into their indices in emb_dict.
        test_data = data.encode_phrase_pairs_RLTR(phrase_pairs, emb_dict)

        net = model.PhraseModel(emb_size=model.EMBEDDING_DIM,