Пример #1
0
def generate_training_samples():
    """Build the retriever training samples and write them to a JSON file.

    For every training pair that has exactly two elements and carries a
    ``'qid'`` in its second element, one entry is stored under that qid with:

    * ``query_index``: the index of the query's qid in ``docID_dict``;
    * ``document_range``: the ``(start, end)`` pair registered for the
      question's weak key in ``QTYPE_DOC_RANGE``, or ``(0, len(docID_dict))``
      when the key has no registered range;
    * ``positive_document_list``: ``docID_dict`` indices of the documents
      listed for that qid in ``POSITIVE_Q_DOCS``.

    The mapping is written to
    ``../data/auto_QA_data/retriever_training_samples.json``.
    """
    docID_dict, _ = data.get_docID_indices(
        data.get_ordered_docID_document(ORDERED_QID_QUESTION_DICT))
    positive_q_docs_pair = data.load_json(POSITIVE_Q_DOCS)
    qtype_docs_range = data.load_json(QTYPE_DOC_RANGE)
    phrase_pairs, _ = data.load_data_MAML(TRAIN_QUESTION_ANSWER_PATH, DIC_PATH,
                                          MAX_TOKENS)
    print("Obtained %d phrase pairs from %s." %
          (len(phrase_pairs), TRAIN_QUESTION_ANSWER_PATH))

    training_sample_dict = {}
    for question in phrase_pairs:
        # Only well-formed (question, info) pairs that carry a qid are usable.
        if len(question) != 2 or 'qid' not in question[1]:
            continue
        key_weak, _, query_qid = AnalyzeQuestion(question[1])
        if key_weak in qtype_docs_range:
            type_range = qtype_docs_range[key_weak]
            document_range = (type_range['start'], type_range['end'])
        else:
            # Unknown question type: fall back to the whole document index.
            document_range = (0, len(docID_dict))
        training_sample_dict[query_qid] = {
            'query_index': docID_dict[query_qid],
            'document_range': document_range,
            'positive_document_list': [
                docID_dict[doc] for doc in positive_q_docs_pair[query_qid]
            ],
        }

    # The context manager guarantees the file is closed even if serialization
    # or writing fails; write() emits the JSON text in a single call.
    with open('../data/auto_QA_data/retriever_training_samples.json',
              'w',
              encoding="UTF-8") as fw:
        fw.write(json.dumps(training_sample_dict, indent=1,
                            ensure_ascii=False))
    print('Writing retriever_training_samples.json is done!')
Пример #2
0
def initialize_document_embedding(int_flag=True, w2v=300, file_path=''):
    """Initialise the retriever's document embeddings and save a checkpoint.

    A ``PhraseModel`` is restored from ``file_path`` and passed, together with
    the document list and the token dictionary, to ``get_document_embedding``.
    The resulting vectors, plus one trailing all-zero padding vector, become
    the weights of ``RetrieverModel.document_emb``. The retriever's state dict
    is then saved in ``SAVES_DIR`` as ``initial_epoch_000_1.000.dat``.

    Args:
        int_flag: When true the token dictionary is loaded from
            ``DIC_PATH_INT``; otherwise from ``DIC_PATH``.
        w2v: ``emb_size`` passed to ``RetrieverModel``.
        file_path: Path of the saved ``PhraseModel`` state dict.
    """
    device = 'cuda'
    # Dict: word token -> ID.
    emb_dict = data.load_dict(DIC_PATH=DIC_PATH_INT if int_flag else DIC_PATH)
    ordered_docID_doc_list = data.get_ordered_docID_document(
        ORDERED_QID_QUESTION_DICT)
    docID_dict, doc_list = data.get_docID_indices(ordered_docID_doc_list)

    # .to(device) already moves the model, so no extra .cuda() call is needed.
    net = retriever_module.RetrieverModel(emb_size=w2v,
                                          dict_size=len(docID_dict),
                                          EMBED_FLAG=False,
                                          device=device).to(device)
    net.zero_grad()

    # Restore the trained phrase model that provides the embeddings.
    net1 = model.PhraseModel(emb_size=model.EMBEDDING_DIM,
                             dict_size=len(emb_dict),
                             hid_size=model.HIDDEN_STATE_SIZE,
                             LSTM_FLAG=True,
                             ATT_FLAG=False,
                             EMBED_FLAG=False).to(device)
    net1.load_state_dict(torch.load(file_path))
    doc_embedding_list = get_document_embedding(doc_list, emb_dict, net1)
    # Add padding vector.
    doc_embedding_list.append([0.0] * model.EMBEDDING_DIM)
    # A tensor freshly built from Python lists has no autograd history and no
    # other owner, so it can be installed directly without clone()/detach().
    net.document_emb.weight.data = torch.tensor(doc_embedding_list).to(device)

    MAP_for_queries = 1.0
    epoch = 0
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(SAVES_DIR, exist_ok=True)
    torch.save(
        net.state_dict(),
        os.path.join(SAVES_DIR,
                     "initial_epoch_%03d_%.3f.dat" % (epoch, MAP_for_queries)))
Пример #3
0
def initialize_document_embedding():
    """Initialise the retriever's document embeddings and save a checkpoint.

    A ``PhraseModel`` is restored from a fixed MAML checkpoint path and
    passed, together with the document list and the ``DIC_PATH`` token
    dictionary, to ``get_document_embedding``. The resulting vectors, plus one
    trailing all-zero padding vector, become the weights of
    ``RetrieverModel.document_emb``. The retriever's state dict is then saved
    in ``SAVES_DIR`` as ``epoch_000_1.000.dat``.
    """
    device = 'cuda'
    # Dict: word token -> ID.
    emb_dict = data.load_dict(DIC_PATH=DIC_PATH)
    ordered_docID_doc_list = data.get_ordered_docID_document(
        ORDERED_QID_QUESTION_DICT)
    docID_dict, doc_list = data.get_docID_indices(ordered_docID_doc_list)

    # .to(device) already moves the model, so no extra .cuda() call is needed.
    net = retriever_module.RetrieverModel(emb_size=50,
                                          dict_size=len(docID_dict),
                                          EMBED_FLAG=False,
                                          device=device).to(device)
    net.zero_grad()

    # Restore the trained phrase model that provides the embeddings.
    path = '../data/saves/maml_batch8_att=0_newdata2k_1storder_1task/epoch_002_0.394_0.796.dat'
    net1 = model.PhraseModel(emb_size=model.EMBEDDING_DIM,
                             dict_size=len(emb_dict),
                             hid_size=model.HIDDEN_STATE_SIZE,
                             LSTM_FLAG=True,
                             ATT_FLAG=False,
                             EMBED_FLAG=False).to(device)
    net1.load_state_dict(torch.load(path))
    doc_embedding_list = get_document_embedding(doc_list, emb_dict, net1)
    # Add padding vector.
    doc_embedding_list.append([0.0] * model.EMBEDDING_DIM)
    # A tensor freshly built from Python lists has no autograd history and no
    # other owner, so it can be installed directly without clone()/detach().
    net.document_emb.weight.data = torch.tensor(doc_embedding_list).to(device)

    MAP_for_queries = 1.0
    epoch = 0
    os.makedirs(SAVES_DIR, exist_ok=True)
    torch.save(
        net.state_dict(),
        os.path.join(SAVES_DIR,
                     "epoch_%03d_%.3f.dat" % (epoch, MAP_for_queries)))
Пример #4
0
    net.cuda()
    log.info("Model: %s", net)

    # Load the pre-trained seq2seq model.
    net.load_state_dict(torch.load(args.load))
    # print("Pre-trained network params")
    # for name, param in net.named_parameters():
    #     print(name, param.shape)
    log.info("Model loaded from %s, continue training in MAML-Reptile mode...",
             args.load)
    if (args.adaptive):
        log.info("Using adaptive reward to train the REINFORCE model...")
    else:
        log.info("Using 0-1 sparse reward to train the REINFORCE model...")

    docID_dict, _ = data.get_docID_indices(
        data.get_ordered_docID_document(ORDERED_QID_QUESTION_DICT))
    # Index -> qid.
    rev_docID_dict = {id: doc for doc, id in docID_dict.items()}

    qtype_docs_range = data.load_json(QTYPE_DOC_RANGE)

    retriever_net = retriever_module.RetrieverModel(
        emb_size=50,
        dict_size=len(docID_dict),
        EMBED_FLAG=args.docembed_grad,
        device=device).to(device)
    retriever_net.cuda()
    log.info("Retriever model: %s", retriever_net)
    retriever_net.load_state_dict(torch.load(args.retrieverload))
    log.info("Retriever model loaded from %s, continue training in RL mode...",
             args.retrieverload)
Пример #5
0
def establish_positive_question_documents_pair(MAX_TOKENS):
    """Map every training question to its support-set documents and dump JSON.

    For each training task that has exactly two elements and carries a
    ``'qid'``, ``MetaLearner.establish_support_set`` is queried and the qids
    of the returned support samples are recorded as that question's
    documents. When the support set is empty, the question's own qid is used
    as its only document. The mapping is written to
    ``../data/auto_QA_data/retriever_question_documents_pair.json``.

    Args:
        MAX_TOKENS: Token limit forwarded to ``data.load_data_MAML``.
    """
    # List of (question, {question information and answer}) pairs, the
    # training pairs are in format of 1:1.
    phrase_pairs, emb_dict = data.load_data_MAML(TRAIN_QUESTION_ANSWER_PATH,
                                                 DIC_PATH, MAX_TOKENS)
    print("Obtained %d phrase pairs with %d uniq words from %s." %
          (len(phrase_pairs), len(emb_dict), TRAIN_QUESTION_ANSWER_PATH))
    # NOTE(review): unlike the call above, this result is not unpacked;
    # presumably load_data_MAML returns only the pairs when no dictionary
    # path is given - confirm against data.load_data_MAML.
    phrase_pairs_944K = data.load_data_MAML(TRAIN_944K_QUESTION_ANSWER_PATH,
                                            max_tokens=MAX_TOKENS)
    print("Obtained %d phrase pairs from %s." %
          (len(phrase_pairs_944K), TRAIN_944K_QUESTION_ANSWER_PATH))

    # Transform token into index in dictionary.
    train_data = data.encode_phrase_pairs_RLTR(phrase_pairs, emb_dict)
    train_data = data.group_train_data_RLTR(train_data)
    train_data_944K = data.encode_phrase_pairs_RLTR(phrase_pairs_944K,
                                                    emb_dict)
    train_data_944K = data.group_train_data_RLTR_for_support(train_data_944K)

    dict944k = data.get944k(DICT_944K)
    print("Reading dict944k from %s is done. %d pairs in dict944k." %
          (DICT_944K, len(dict944k)))
    dict944k_weak = data.get944k(DICT_944K_WEAK)
    print("Reading dict944k_weak from %s is done. %d pairs in dict944k_weak" %
          (DICT_944K_WEAK, len(dict944k_weak)))

    metaLearner = metalearner.MetaLearner(
        samples=5,
        train_data_support_944K=train_data_944K,
        dict=dict944k,
        dict_weak=dict944k_weak,
        steps=5,
        weak_flag=True)

    question_doctments_pair_list = {}
    idx = 0
    for temp_batch in data.iterate_batches(train_data, 1):
        task = temp_batch[0]
        if len(task) != 2 or 'qid' not in task[1]:
            print('task has no qid or len(task)!=2:')
            print(task)
            continue
        qid = task[1]['qid']
        # Establish support set.
        support_set = metaLearner.establish_support_set(
            task, metaLearner.steps, metaLearner.weak_flag,
            metaLearner.train_data_support_944K)
        if len(support_set) > 0:
            documents = [
                support_sample[1]['qid'] for support_sample in support_set
                if len(support_sample) == 2 and 'qid' in support_sample[1]
            ]
        else:
            # No support samples: the question is its own only document.
            print('task %s has no support set!' % (str(qid)))
            documents = [qid]
        question_doctments_pair_list[qid] = documents
        # Progress indicator, one line per 100 processed tasks.
        if idx % 100 == 0:
            print(idx)
        idx += 1

    # The context manager guarantees the file is closed even if serialization
    # or writing fails; write() emits the JSON text in a single call.
    with open('../data/auto_QA_data/retriever_question_documents_pair.json',
              'w',
              encoding="UTF-8") as fw:
        fw.write(json.dumps(question_doctments_pair_list, indent=1,
                            ensure_ascii=False))
    print('Writing retriever_question_documents_pair.json is done!')
Пример #6
0
def retriever_training(epochs,
                       RETRIEVER_EMBED_FLAG=True,
                       query_embedding=True):
    ''' Train the retriever on the stored training samples.

        The retriever is restored from the initial checkpoint and optimised
        with AdaBound to minimise the negative mean log-softmax score of each
        query's positive documents. After every epoch a random subset of
        len(training_samples) // 40 samples is scored: the ranks that
        net.calculate_rank assigns to the positive documents are averaged
        (lower is treated as better). A checkpoint is written to SAVES_DIR
        when that average improves on the best value so far and is below 500;
        training stops early once it drops below 10.

        Args:
            epochs: maximum number of passes over the training samples.
            RETRIEVER_EMBED_FLAG: forwarded to RetrieverModel as EMBED_FLAG
                and reflected in the checkpoint file name.
            query_embedding: when true, queries are embedded from their
                question tokens with the pre-trained PhraseModel; otherwise
                the query tensor comes from net.pack_input(query_index).

        One instance of the retriever training samples:
        query_index = [800000, 0, 2, 100000, 400000, 600000]
        document_range = [(700000, 944000),
                          (1, 10),
                          (10, 300000),
                          (10, 300000),
                          (300000, 500000),
                          (500000, 700000)]
        positive_document_list =
        [[700001-700000, 700002-700000, 900000-700000, 910000-700000, 944000-2-700000],
        [2, 3],
        [13009-10, 34555-10, 234-10, 6789-10, 300000-1-10],
        [11-10, 16-10, 111111-10, 222222-10, 222223-10],
        [320000-300000, 330000-300000, 340000-300000, 350000-300000, 360000-300000],
        [600007-500000, 610007-500000, 620007-500000, 630007-500000, 690007-500000]]'''

    retriever_path = '../data/saves/retriever/initial_epoch_000_1.000.dat'
    device = 'cuda'
    docID_dict, _ = data.get_docID_indices(
        data.get_ordered_docID_document(ORDERED_QID_QUESTION_DICT))
    training_samples = data.load_json(TRAINING_SAMPLE_DICT)

    net = retriever_module.RetrieverModel(emb_size=50,
                                          dict_size=len(docID_dict),
                                          EMBED_FLAG=RETRIEVER_EMBED_FLAG,
                                          device=device).to(device)
    net.load_state_dict(torch.load(retriever_path))
    net.zero_grad()
    retriever_optimizer = adabound.AdaBound(filter(lambda p: p.requires_grad,
                                                   net.parameters()),
                                            lr=1e-3,
                                            final_lr=0.1)

    emb_dict = None
    net1 = None
    qid_question_pair = {}
    if query_embedding:
        emb_dict = data.load_dict(DIC_PATH=DIC_PATH)
        # Get trained wording embeddings.
        net1 = model.PhraseModel(emb_size=model.EMBEDDING_DIM,
                                 dict_size=len(emb_dict),
                                 hid_size=model.HIDDEN_STATE_SIZE,
                                 LSTM_FLAG=True,
                                 ATT_FLAG=False,
                                 EMBED_FLAG=False).to(device)
        net1.load_state_dict(torch.load(RETRIEVER_PATH))
        qid_question_pair = data.get_qid_question_pairs(
            ORDERED_QID_QUESTION_DICT)

    def build_query_tensor(key, value):
        """Return the query tensor for a sample, or None if its question is unknown."""
        if query_embedding:
            if key not in qid_question_pair:
                return None
            return data.get_question_embedding(qid_question_pair[key],
                                               emb_dict, net1)
        return torch.tensor(net.pack_input(value['query_index']).tolist(),
                            requires_grad=False).cuda()

    def positive_offsets(value):
        """Return the positive document indices relative to the range start."""
        range_start = value['document_range'][0]
        return [k - range_start for k in value['positive_document_list']]

    # Materialised once; the evaluation step draws random samples from it.
    sample_items = list(training_samples.items())
    eval_count = len(sample_items) // 40

    best_value = MAX_MAP
    MAP_for_queries = MAX_MAP

    for i in range(epochs):
        print('Epoch %d is training......' % (i))
        for key, value in training_samples.items():
            retriever_optimizer.zero_grad()
            net.zero_grad()
            query_tensor = build_query_tensor(key, value)
            if query_tensor is None:
                print("ERROR! NO SUCH QUESTION: %s!" % (str(key)))
                continue
            document_range = (value['document_range'][0],
                              value['document_range'][1])
            logsoftmax_output = net(query_tensor, document_range)[0]
            logsoftmax_output = logsoftmax_output.cuda()
            possitive_logsoftmax_output = torch.stack(
                [logsoftmax_output[k] for k in positive_offsets(value)])
            # Maximise the log-probability of the positive documents.
            loss_policy_v = -possitive_logsoftmax_output.mean()
            loss_policy_v.backward()
            retriever_optimizer.step()

        # Score a random subset of the samples. The module-level RNG is used
        # as is: re-seeding it with a datetime object on every draw adds
        # nothing and raises TypeError on Python 3.11+.
        MAP_list = []
        with torch.no_grad():  # No gradients are needed for scoring.
            for _ in range(eval_count):
                key, value = random.choice(sample_items)
                query_tensor = build_query_tensor(key, value)
                if query_tensor is None:
                    # Same policy as the training loop: skip unknown questions.
                    continue
                document_range = (value['document_range'][0],
                                  value['document_range'][1])
                logsoftmax_output = net(query_tensor, document_range)[0]
                order = net.calculate_rank(logsoftmax_output.tolist())
                orders = [order[k] for k in positive_offsets(value)]
                MAP_list.append(mean(orders))

        if MAP_list:
            MAP_for_queries = mean(MAP_list)
            print('------------------------------------------------------')
            print('Epoch %d, MAP_for_queries: %f' % (i, MAP_for_queries))
            print('------------------------------------------------------')
            # Record trained parameters when the metric improves.
            if MAP_for_queries < best_value:
                best_value = MAP_for_queries
                if MAP_for_queries < 500:
                    output_str = "AdaBound"
                    if RETRIEVER_EMBED_FLAG:
                        output_str += "_DocEmbed"
                    if query_embedding:
                        output_str += "_QueryEmbed"
                    torch.save(
                        net.state_dict(),
                        os.path.join(
                            SAVES_DIR, output_str + "_epoch_%03d_%.3f.dat" %
                            (i, MAP_for_queries)))
                    print('Save the state_dict: %s' %
                          (str(i) + ' ' + str(MAP_for_queries)))
        else:
            # Too few (or no usable) samples to evaluate; keep the last metric.
            print('Epoch %d: no samples were evaluated.' % (i))
        if MAP_for_queries < 10:
            break