Example #1
def generate_training_samples():
    training_sample_dict = {}
    docID_dict, _ = data.get_docID_indices(
        data.get_ordered_docID_document(ORDERED_QID_QUESTION_DICT))
    positive_q_docs_pair = data.load_json(POSITIVE_Q_DOCS)
    qtype_docs_range = data.load_json(QTYPE_DOC_RANGE)
    phrase_pairs, _ = data.load_data_MAML(TRAIN_QUESTION_ANSWER_PATH, DIC_PATH,
                                          MAX_TOKENS)
    print("Obtained %d phrase pairs from %s." %
          (len(phrase_pairs), TRAIN_QUESTION_ANSWER_PATH))
    for question in phrase_pairs:
        if len(question) == 2 and 'qid' in question[1]:
            key_weak, _, query_qid = AnalyzeQuestion(question[1])
            query_index = docID_dict[query_qid]
            if key_weak in qtype_docs_range:
                document_range = (qtype_docs_range[key_weak]['start'],
                                  qtype_docs_range[key_weak]['end'])
            else:
                document_range = (0, len(docID_dict))
            positive_document_list = [
                docID_dict[doc] for doc in positive_q_docs_pair[query_qid]
            ]
            training_sample_dict[query_qid] = {
                'query_index': query_index,
                'document_range': document_range,
                'positive_document_list': positive_document_list
            }
    with open('../data/auto_QA_data/retriever_training_samples.json',
              'w',
              encoding="UTF-8") as fw:
        fw.write(
            json.dumps(training_sample_dict, indent=1, ensure_ascii=False))
    print('Writing retriever_training_samples.json is done!')
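The JSON written above maps each qid to a query index, a document index range, and a list of positive document indices. Below is a minimal sketch of how a retriever trainer might consume that file, drawing one positive and one random in-range negative per query. Only the path and field names come from the function above; the triple-sampling consumer and the exclusive-`end` assumption are hypothetical.

import json
import random

def sample_training_triple(sample):
    """Return (query_index, positive_doc, negative_doc) for one sample."""
    start, end = sample['document_range']
    positives = set(sample['positive_document_list'])
    positive_doc = random.choice(sample['positive_document_list'])
    # Rejection-sample an in-range negative; assumes `end` is exclusive and
    # the range contains at least one non-positive document.
    negative_doc = random.randrange(start, end)
    while negative_doc in positives:
        negative_doc = random.randrange(start, end)
    return sample['query_index'], positive_doc, negative_doc

with open('../data/auto_QA_data/retriever_training_samples.json',
          encoding='UTF-8') as f:
    training_samples = json.load(f)
for qid, sample in list(training_samples.items())[:3]:
    print(qid, sample_training_triple(sample))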
Example #2
    parser.add_argument("--weak",
                        type=lambda x:
                        (str(x).lower() in ['true', '1', 'yes']),
                        help="Use weak mode to search for the support set")
    parser.add_argument('--retriever-random',
                        action='store_true',
                        help='randomly get support set for the retriever')
    args = parser.parse_args()
    device = torch.device("cuda" if args.cuda else "cpu")
    log.info("Device info: %s", str(device))

    saves_path = os.path.join(SAVES_DIR, args.name)
    os.makedirs(saves_path, exist_ok=True)

    # TODO: In MAML, all data points in the 944K training dataset will be used, so it is better to use the 944K dictionary and train the model from scratch.
    # # List of (question, {question information and answer}) pairs; the training pairs are in 1:1 format.
    phrase_pairs, emb_dict = data.load_data_MAML(
        QUESTION_PATH=TRAIN_QUESTION_ANSWER_PATH,
        DIC_PATH=DIC_PATH,
        max_tokens=MAX_TOKENS)
    log.info("Obtained %d phrase pairs with %d uniq words from %s.",
             len(phrase_pairs), len(emb_dict), TRAIN_QUESTION_ANSWER_PATH)
    phrase_pairs_944K = data.load_data_MAML(
        QUESTION_PATH=TRAIN_944K_QUESTION_ANSWER_PATH, max_tokens=MAX_TOKENS)
    log.info("Obtained %d phrase pairs from %s.", len(phrase_pairs_944K),
             TRAIN_944K_QUESTION_ANSWER_PATH)
    data.save_emb_dict(saves_path, emb_dict)
    end_token = emb_dict[data.END_TOKEN]
    # Transform token into index in dictionary.
    train_data = data.encode_phrase_pairs_RLTR(phrase_pairs, emb_dict)
    # # list of (seq1, [seq*]) pairs; turn the training pairs into 1:N form.
    # train_data = data.group_train_data(train_data)
    train_data = data.group_train_data_RLTR(train_data)
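The commented-out data.group_train_data call turns 1:1 training pairs into 1:N pairs, i.e. (seq1, [seq*]). A self-contained sketch of that grouping idea, assuming pairs are (token-list, token-list) tuples; this is an illustration, not the project's implementation:

from collections import defaultdict

def group_pairs_one_to_many(pairs):
    """Group 1:1 (question, answer) pairs into 1:N (question, [answers]) pairs."""
    grouped = defaultdict(list)
    for question, answer in pairs:
        grouped[tuple(question)].append(answer)
    return [(list(question), answers) for question, answers in grouped.items()]

pairs = [([1, 2], [3]), ([1, 2], [4]), ([5], [6])]
print(group_pairs_one_to_many(pairs))
# -> [([1, 2], [[3], [4]]), ([5], [[6]])]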
Example #3
def establish_positive_question_documents_pair(MAX_TOKENS):
    # Dict: document ID -> index.
    docID_dict, _ = data.get_docID_indices(
        data.get_ordered_docID_document(ORDERED_QID_QUESTION_DICT))
    # Index -> qid.
    rev_docID_dict = {idx: doc for doc, idx in docID_dict.items()}
    # # List of (question, {question information and answer}) pairs; the training pairs are in 1:1 format.
    phrase_pairs, emb_dict = data.load_data_MAML(TRAIN_QUESTION_ANSWER_PATH,
                                                 DIC_PATH, MAX_TOKENS)
    print("Obtained %d phrase pairs with %d uniq words from %s." %
          (len(phrase_pairs), len(emb_dict), TRAIN_QUESTION_ANSWER_PATH))
    phrase_pairs_944K = data.load_data_MAML(TRAIN_944K_QUESTION_ANSWER_PATH,
                                            max_tokens=MAX_TOKENS)
    print("Obtained %d phrase pairs from %s." %
          (len(phrase_pairs_944K), TRAIN_944K_QUESTION_ANSWER_PATH))

    # Transform token into index in dictionary.
    train_data = data.encode_phrase_pairs_RLTR(phrase_pairs, emb_dict)
    # train_data = data.group_train_data(train_data)
    train_data = data.group_train_data_RLTR(train_data)
    train_data_944K = data.encode_phrase_pairs_RLTR(phrase_pairs_944K,
                                                    emb_dict)
    train_data_944K = data.group_train_data_RLTR_for_support(train_data_944K)

    dict944k = data.get944k(DICT_944K)
    print("Reading dict944k from %s is done. %d pairs in dict944k." %
          (DICT_944K, len(dict944k)))
    dict944k_weak = data.get944k(DICT_944K_WEAK)
    print("Reading dict944k_weak from %s is done. %d pairs in dict944k_weak" %
          (DICT_944K_WEAK, len(dict944k_weak)))

    metaLearner = metalearner.MetaLearner(
        samples=5,
        train_data_support_944K=train_data_944K,
        dict=dict944k,
        dict_weak=dict944k_weak,
        steps=5,
        weak_flag=True)

    question_documents_pair_list = {}
    idx = 0
    for temp_batch in data.iterate_batches(train_data, 1):
        task = temp_batch[0]
        if len(task) == 2 and 'qid' in task[1]:
            # print("Task %s is training..." %(str(task[1]['qid'])))
            # Establish support set.
            support_set = metaLearner.establish_support_set(
                task, metaLearner.steps, metaLearner.weak_flag,
                metaLearner.train_data_support_944K)
            documents = []
            if len(support_set) > 0:
                for support_sample in support_set:
                    if len(support_sample) == 2 and 'qid' in support_sample[1]:
                        documents.append(support_sample[1]['qid'])
            else:
                print('task %s has no support set!' % (str(task[1]['qid'])))
                documents.append(task[1]['qid'])
            question_documents_pair_list[task[1]['qid']] = documents
            if idx % 100 == 0:
                print(idx)
            idx += 1
        else:
            print('task has no qid or len(task)!=2:')
            print(task)
    with open('../data/auto_QA_data/retriever_question_documents_pair.json',
              'w',
              encoding="UTF-8") as fw:
        fw.write(
            json.dumps(question_documents_pair_list, indent=1,
                       ensure_ascii=False))
    print('Writing retriever_question_documents_pair.json is done!')
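The resulting file maps each qid to the qids of its support documents (falling back to the question's own qid when no support set is found), and is presumably the POSITIVE_Q_DOCS input of Example #1. A quick sanity-check sketch over that output:

import json

with open('../data/auto_QA_data/retriever_question_documents_pair.json',
          encoding='UTF-8') as f:
    q_docs = json.load(f)
sizes = [len(docs) for docs in q_docs.values()]
# A question whose only "support" is itself hit the no-support-set fallback.
self_only = sum(1 for qid, docs in q_docs.items() if docs == [qid])
print('questions: %d, avg support size: %.2f, self-only fallbacks: %d'
      % (len(q_docs), sum(sizes) / max(len(sizes), 1), self_only))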
Example #4
    parser.add_argument("--MonteCarlo", action='store_true', default=False,
                        help="using Monte Carlo algorithm for REINFORCE")
    args = parser.parse_args()

    device = torch.device("cuda" if args.cuda else "cpu")
    log.info("Device info: %s", str(device))

    PREDICT_PATH = '../data/saves/' + str(args.name) + '/' + str(args.pred) + '_predict.actions'
    fwPredict = open(PREDICT_PATH, 'w', encoding="UTF-8")

    TEST_QUESTION_PATH = '../data/auto_QA_data/mask_test/' + str(args.pred).upper() + '_test.question'
    log.info("Open: %s", TEST_QUESTION_PATH)

    phrase_pairs, emb_dict = data.load_data_MAML_TEST(QUESTION_PATH=TEST_QUESTION_PATH, DIC_PATH=DIC_PATH)
    log.info("Obtained %d phrase pairs with %d uniq words", len(phrase_pairs), len(emb_dict))
    phrase_pairs_944K = data.load_data_MAML(TRAIN_944K_QUESTION_ANSWER_PATH, max_tokens=MAX_TOKENS)
    log.info("Obtained %d phrase pairs from %s.", len(phrase_pairs_944K), TRAIN_944K_QUESTION_ANSWER_PATH)

    if args.retriever_random:
        log.info("Using random support set for test.")

    end_token = emb_dict[data.END_TOKEN]
    # Transform token into index in dictionary.
    test_data = data.encode_phrase_pairs_RLTR(phrase_pairs, emb_dict)
    # # list of (seq1, [seq*]) pairs; turn the training pairs into 1:N form.
    # train_data = data.group_train_data(train_data)
    test_data = data.group_train_data_RLTR(test_data)

    train_data_944K = data.encode_phrase_pairs_RLTR(phrase_pairs_944K, emb_dict)
    train_data_944K = data.group_train_data_RLTR_for_support(train_data_944K)
    dict944k = data.get944k(DICT_944K)
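For the --retriever-random path above, support samples are presumably drawn uniformly from the 944K pool rather than retrieved by pattern. A minimal sketch of such random selection (hypothetical helper, not the project's code):

import random

def random_support_set(pool, k=5, seed=0):
    """Draw k support samples uniformly from a candidate pool.

    A sketch of what --retriever-random presumably toggles (random selection
    instead of pattern-based retrieval); not the project's implementation.
    """
    rng = random.Random(seed)
    # Guard against pools smaller than k.
    return rng.sample(pool, min(k, len(pool)))

print(random_support_set(list(range(100)), k=5))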
Example #5
                        help='tasks of a batch in outer loop of MAML')
    # If weak is true, then when searching for the support set, questions with the same number of E/R/T but different relations will be retrieved whenever fewer questions than `steps` match the exact pattern.
    parser.add_argument("--weak",
                        type=lambda x:
                        (str(x).lower() in ['true', '1', 'yes']),
                        help="Use weak mode to search for the support set")
    args = parser.parse_args()
    device = torch.device("cuda" if args.cuda else "cpu")
    log.info("Device info: %s", str(device))

    saves_path = os.path.join(SAVES_DIR, args.name)
    os.makedirs(saves_path, exist_ok=True)

    # TODO: In MAML, all data points in the 944K training dataset will be used, so it is better to use the 944K dictionary and train the model from scratch.
    # # List of (question, {question information and answer}) pairs; the training pairs are in 1:1 format.
    phrase_pairs, emb_dict = data.load_data_MAML(TRAIN_QUESTION_ANSWER_PATH,
                                                 DIC_PATH, MAX_TOKENS)
    log.info("Obtained %d phrase pairs with %d uniq words from %s.",
             len(phrase_pairs), len(emb_dict), TRAIN_QUESTION_ANSWER_PATH)
    phrase_pairs_webqsp = data.load_data_MAML(
        TRAIN_WEBQSP_QUESTION_ANSWER_PATH, max_tokens=MAX_TOKENS)
    log.info("Obtained %d phrase pairs from %s.", len(phrase_pairs_webqsp),
             TRAIN_WEBQSP_QUESTION_ANSWER_PATH)
    data.save_emb_dict(saves_path, emb_dict)
    end_token = emb_dict[data.END_TOKEN]
    # Transform token into index in dictionary.
    train_data = data.encode_phrase_pairs_RLTR(phrase_pairs, emb_dict)
    # # list of (seq1, [seq*]) pairs; turn the training pairs into 1:N form.
    # train_data = data.group_train_data(train_data)
    train_data = data.group_train_data_RLTR(train_data)

    train_data_webqsp = data.encode_phrase_pairs_RLTR(phrase_pairs_webqsp,
                                                      emb_dict)
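The --weak comment above describes a two-stage lookup: exact pattern first, then a weak bucket sharing the E/R/T counts but differing in relation. A self-contained sketch of that fallback, with hypothetical key formats; the actual MetaLearner logic may differ:

def select_support(exact_key, weak_key, dict_exact, dict_weak, steps):
    """Pattern-first support selection with a weak fallback.

    Prefer questions whose pattern matches exactly; if fewer than `steps`
    exist, top up with questions from the weak bucket (same E/R/T counts,
    different relation). Hypothetical helper, not the MetaLearner code.
    """
    candidates = list(dict_exact.get(exact_key, []))
    if len(candidates) < steps:
        extra = [q for q in dict_weak.get(weak_key, []) if q not in candidates]
        candidates.extend(extra[:steps - len(candidates)])
    return candidates[:steps]

# Toy buckets with made-up key formats, purely for illustration.
dict_exact = {'E1R1T1|rel_a': ['q1', 'q2']}
dict_weak = {'E1R1T1': ['q1', 'q2', 'q3', 'q4']}
print(select_support('E1R1T1|rel_a', 'E1R1T1', dict_exact, dict_weak, steps=5))
# -> ['q1', 'q2', 'q3', 'q4']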