writer = SummaryWriter(comment="-" + args.name)
# BEGIN token, placed on the configured device only. An additional `.cuda()`
# would force it onto the default GPU regardless of `device`: that crashes on
# CPU-only runs and mismatches `net` when `device` is a non-default GPU.
beg_token = torch.LongTensor([emb_dict[data.BEGIN_TOKEN]]).to(device)
# NOTE(review): `end_token` is defined outside this view -- confirm it is placed
# on the same `device` as `beg_token`.
metaLearner = metalearner.MetaLearner(
    net=net,
    retriever_net=retriever_net,
    device=device,
    beg_token=beg_token,
    end_token=end_token,
    adaptive=args.adaptive,
    samples=args.samples,
    train_data_support_944K=train_data_944K,
    rev_emb_dict=rev_emb_dict,
    first_order=args.first_order,
    fast_lr=args.fast_lr,
    meta_optimizer_lr=args.meta_lr,
    dial_shown=False,
    dict=dict944k,
    dict_weak=dict944k_weak,
    steps=args.steps,
    weak_flag=args.weak,
    query_embed=args.query_embed)
log.info(
    "Meta-learner: %d inner steps, %f inner learning rate, "
    "%d outer steps, %f outer learning rate, using weak mode:%s, retriever random model:%s"
    % (args.steps, args.fast_lr, args.batches, args.meta_lr, str(args.weak),
       str(args.retriever_random)))
def establish_positive_question_documents_pair(
        MAX_TOKENS,
        output_path='../data/auto_QA_data/retriever_question_documents_pair.json'):
    """Map each training question's qid to the qids of its positive documents.

    The training question/answer pairs and the 944K support corpus are loaded
    and encoded. A ``MetaLearner`` is then built only so that its
    ``establish_support_set`` can be used. For every task of the form
    ``(question, info)`` whose ``info`` contains ``'qid'``, the qids of the
    support samples become that question's positive documents. When the support
    set is empty, the question's own qid is used as its sole positive document.
    Tasks that do not have this shape are printed and skipped.

    The resulting mapping is written as JSON (``indent=1``, non-ASCII kept) to
    ``output_path``.

    Args:
        MAX_TOKENS: Token limit forwarded to both ``data.load_data_MAML`` calls.
        output_path: Destination JSON file. Defaults to the historical location.

    Returns:
        The ``{qid: [document qid, ...]}`` dict that was written.
    """
    # NOTE(review): the docID index built from ORDERED_QID_QUESTION_DICT is not
    # used below; the call is kept only in case it has side effects -- confirm
    # and remove it.
    data.get_docID_indices(
        data.get_ordered_docID_document(ORDERED_QID_QUESTION_DICT))

    # List of (question, {question information and answer}) pairs; the training
    # pairs are in 1:1 format.
    phrase_pairs, emb_dict = data.load_data_MAML(
        TRAIN_QUESTION_ANSWER_PATH, DIC_PATH, MAX_TOKENS)
    print("Obtained %d phrase pairs with %d uniq words from %s." % (
        len(phrase_pairs), len(emb_dict), TRAIN_QUESTION_ANSWER_PATH))
    phrase_pairs_944K = data.load_data_MAML(
        TRAIN_944K_QUESTION_ANSWER_PATH, max_tokens=MAX_TOKENS)
    print("Obtained %d phrase pairs from %s." % (
        len(phrase_pairs_944K), TRAIN_944K_QUESTION_ANSWER_PATH))

    # Transform tokens into their indices in the dictionary.
    train_data = data.encode_phrase_pairs_RLTR(phrase_pairs, emb_dict)
    train_data = data.group_train_data_RLTR(train_data)
    train_data_944K = data.encode_phrase_pairs_RLTR(phrase_pairs_944K, emb_dict)
    train_data_944K = data.group_train_data_RLTR_for_support(train_data_944K)

    dict944k = data.get944k(DICT_944K)
    print("Reading dict944k from %s is done. %d pairs in dict944k." % (
        DICT_944K, len(dict944k)))
    dict944k_weak = data.get944k(DICT_944K_WEAK)
    print("Reading dict944k_weak from %s is done. %d pairs in dict944k_weak" % (
        DICT_944K_WEAK, len(dict944k_weak)))

    # Only the meta-learner's support-set construction is needed here.
    meta_learner = metalearner.MetaLearner(
        samples=5, train_data_support_944K=train_data_944K, dict=dict944k,
        dict_weak=dict944k_weak, steps=5, weak_flag=True)

    question_documents_pairs = {}
    idx = 0  # counts only well-formed tasks; used for progress output
    for temp_batch in data.iterate_batches(train_data, 1):
        task = temp_batch[0]
        if not (len(task) == 2 and 'qid' in task[1]):
            print('task has no qid or len(task)!=2:')
            print(task)
            continue

        qid = task[1]['qid']
        support_set = meta_learner.establish_support_set(
            task, meta_learner.steps, meta_learner.weak_flag,
            meta_learner.train_data_support_944K)
        if len(support_set) > 0:
            documents = [sample[1]['qid'] for sample in support_set
                         if len(sample) == 2 and 'qid' in sample[1]]
        else:
            print('task %s has no support set!' % (str(qid)))
            # Fall back to the question itself as its only positive document.
            documents = [qid]
        question_documents_pairs[qid] = documents

        if idx % 100 == 0:
            print(idx)
        idx += 1

    # `with` guarantees the file is closed even if serialization raises;
    # json.dump writes the same text json.dumps would produce.
    with open(output_path, 'w', encoding="UTF-8") as fw:
        json.dump(question_documents_pairs, fw, indent=1, ensure_ascii=False)
    print('Writing retriever_question_documents_pair.json is done!')
    return question_documents_pairs