def initialize_document_embedding(int_flag=True, w2v=300, file_path=''):
    """Create the initial retriever checkpoint from a trained phrase model.

    Builds a ``RetrieverModel``, fills its ``document_emb`` weights with the
    document embeddings computed by ``get_document_embedding`` from a saved
    ``PhraseModel`` (plus one all-zero padding row), and saves the resulting
    state dict to ``SAVES_DIR/initial_epoch_000_1.000.dat``.

    Args:
        int_flag: When true the token dictionary is loaded from
            ``DIC_PATH_INT``; otherwise from ``DIC_PATH``.
        w2v: Embedding size passed to ``RetrieverModel``.
        file_path: Path of the saved ``PhraseModel`` state dict to load.
    """
    device = 'cuda'
    # Dict: word token -> ID.
    emb_dict = data.load_dict(DIC_PATH=DIC_PATH_INT if int_flag else DIC_PATH)
    ordered_docID_doc_list = data.get_ordered_docID_document(
        ORDERED_QID_QUESTION_DICT)
    docID_dict, doc_list = data.get_docID_indices(ordered_docID_doc_list)

    # `.to(device)` already moves the module, so no extra `.cuda()` is needed.
    net = retriever_module.RetrieverModel(emb_size=w2v,
                                          dict_size=len(docID_dict),
                                          EMBED_FLAG=False,
                                          device=device).to(device)
    net.zero_grad()

    # Get trained wording embeddings.
    net1 = model.PhraseModel(emb_size=model.EMBEDDING_DIM,
                             dict_size=len(emb_dict),
                             hid_size=model.HIDDEN_STATE_SIZE,
                             LSTM_FLAG=True,
                             ATT_FLAG=False,
                             EMBED_FLAG=False).to(device)
    net1.load_state_dict(torch.load(file_path))
    doc_embedding_list = get_document_embedding(doc_list, emb_dict, net1)

    # Add padding vector.
    # NOTE(review): the padding row is model.EMBEDDING_DIM wide while the
    # retriever is built with emb_size=w2v -- confirm the two always agree.
    doc_embedding_list.append([0.0] * model.EMBEDDING_DIM)
    doc_embedding_tensor = torch.tensor(doc_embedding_list).cuda()
    net.document_emb.weight.data = doc_embedding_tensor.clone().detach()

    # Placeholder metric/epoch values; they only shape the checkpoint name.
    MAP_for_queries = 1.0
    epoch = 0
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(SAVES_DIR, exist_ok=True)
    torch.save(
        net.state_dict(),
        os.path.join(SAVES_DIR,
                     "initial_epoch_%03d_%.3f.dat" % (epoch, MAP_for_queries)))
def initialize_document_embedding(
        *,
        w2v=50,
        file_path='../data/saves/maml_batch8_att=0_newdata2k_1storder_1task/epoch_002_0.394_0.796.dat'):
    """Create the initial retriever checkpoint from a trained phrase model.

    Builds a ``RetrieverModel``, fills its ``document_emb`` weights with the
    document embeddings computed by ``get_document_embedding`` from a saved
    ``PhraseModel`` (plus one all-zero padding row), and saves the resulting
    state dict to ``SAVES_DIR/epoch_000_1.000.dat``.

    Calling it with no arguments behaves exactly as before; the keyword-only
    parameters merely expose what used to be hard-coded.

    Args:
        w2v: Embedding size passed to ``RetrieverModel``.
        file_path: Path of the saved ``PhraseModel`` state dict to load.
    """
    device = 'cuda'
    # Dict: word token -> ID.
    emb_dict = data.load_dict(DIC_PATH=DIC_PATH)
    ordered_docID_doc_list = data.get_ordered_docID_document(
        ORDERED_QID_QUESTION_DICT)
    docID_dict, doc_list = data.get_docID_indices(ordered_docID_doc_list)

    # `.to(device)` already moves the module, so no extra `.cuda()` is needed.
    net = retriever_module.RetrieverModel(emb_size=w2v,
                                          dict_size=len(docID_dict),
                                          EMBED_FLAG=False,
                                          device=device).to(device)
    net.zero_grad()

    # Get trained wording embeddings.
    net1 = model.PhraseModel(emb_size=model.EMBEDDING_DIM,
                             dict_size=len(emb_dict),
                             hid_size=model.HIDDEN_STATE_SIZE,
                             LSTM_FLAG=True,
                             ATT_FLAG=False,
                             EMBED_FLAG=False).to(device)
    net1.load_state_dict(torch.load(file_path))
    doc_embedding_list = get_document_embedding(doc_list, emb_dict, net1)

    # Add padding vector.
    # NOTE(review): the padding row is model.EMBEDDING_DIM wide while the
    # retriever is built with emb_size=w2v -- confirm the two always agree.
    doc_embedding_list.append([0.0] * model.EMBEDDING_DIM)
    doc_embedding_tensor = torch.tensor(doc_embedding_list).cuda()
    net.document_emb.weight.data = doc_embedding_tensor.clone().detach()

    # Placeholder metric/epoch values; they only shape the checkpoint name.
    MAP_for_queries = 1.0
    epoch = 0
    os.makedirs(SAVES_DIR, exist_ok=True)
    torch.save(
        net.state_dict(),
        os.path.join(SAVES_DIR,
                     "epoch_%03d_%.3f.dat" % (epoch, MAP_for_queries)))
action='store_true', help="Enable sampling mode") prog_args = parser.parse_args() conf = configparser.ConfigParser() if not conf.read(os.path.expanduser(prog_args.config)): log.error("Configuration file %s not found", prog_args.config) sys.exit() emb_dict = data.load_emb_dict(os.path.dirname(prog_args.model)) log.info("Loaded embedded dict with %d entries", len(emb_dict)) rev_emb_dict = {idx: word for word, idx in emb_dict.items()} end_token = emb_dict[data.END_TOKEN] net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict), hid_size=model.HIDDEN_STATE_SIZE) net.load_state_dict(torch.load(prog_args.model)) def bot_func(bot, update, args): text = " ".join(args) words = utils.tokenize(text) seq_1 = data.encode_words(words, emb_dict) input_seq = model.pack_input(seq_1, net.emb) enc = net.encode(input_seq) if prog_args.sample: _, tokens = net.decode_chain_sampling(enc, input_seq.data[0:1], seq_len=data.MAX_TOKENS, stop_at_token=end_token) else:
) if (args.query_embed): log.info( "Using the sum of word embedding to represent the questions during the training..." ) else: log.info( "Using the document_emb which is stored in the retriever model to represent the questions..." ) # Index -> word. rev_emb_dict = {idx: word for word, idx in emb_dict.items()} # PhraseModel.__init__() to establish a LSTM model. net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict), hid_size=model.HIDDEN_STATE_SIZE, LSTM_FLAG=args.lstm, ATT_FLAG=args.att, EMBED_FLAG=args.embed_grad).to(device) # Using cuda. net.cuda() log.info("Model: %s", net) # Load the pre-trained seq2seq model. net.load_state_dict(torch.load(args.load)) # print("Pre-trained network params") # for name, param in net.named_parameters(): # print(name, param.shape) log.info("Model loaded from %s, continue training in MAML-Reptile mode...", args.load) if (args.adaptive): log.info("Using adaptive reward to train the REINFORCE model...")
log.info("Train the SEQ2SEQ model without attention mechanism...") if args.lstm: log.info("Using LSTM mechanism to train the SEQ2SEQ model...") else: log.info("Using RNN mechanism to train the SEQ2SEQ model...") if args.MonteCarlo: log.info("Using Monte Carlo algorithm for Policy Gradient...") if args.NSM: log.info("Using Neural Symbolic Machine algorithm for RL...") # Index -> word. rev_emb_dict = {idx: word for word, idx in emb_dict.items()} # PhraseModel.__init__() to establish a LSTM model. net = model.PhraseModel(emb_size=args.word_dimension, dict_size=len(emb_dict), hid_size=model.HIDDEN_STATE_SIZE, LSTM_FLAG=args.lstm, ATT_FLAG=args.att).to(device) # Using cuda. net.cuda() log.info("Model: %s", net) writer = SummaryWriter(comment="-" + args.name) # Load the pre-trained seq2seq model. net.load_state_dict(torch.load(args.load)) log.info("Model loaded from %s, continue training in RL mode...", args.load) if (args.adaptive): log.info("Using adaptive reward to train the REINFORCE model...") else: log.info("Using 0-1 sparse reward to train the REINFORCE model...")
def retriever_training(epochs, RETRIEVER_EMBED_FLAG=True, query_embedding=True):
    '''Train the retriever and checkpoint it whenever the sampled metric improves.

    One instance of the retriever training samples:
    query_index = [800000, 0, 2, 100000, 400000, 600000]
    document_range = [(700000, 944000), (1, 10), (10, 300000), (10, 300000), (300000, 500000), (500000, 700000)]
    positive_document_list =
    [[700001-700000, 700002-700000, 900000-700000, 910000-700000, 944000-2-700000],
    [2, 3],
    [13009-10, 34555-10, 234-10, 6789-10, 300000-1-10],
    [11-10, 16-10, 111111-10, 222222-10, 222223-10],
    [320000-300000, 330000-300000, 340000-300000, 350000-300000, 360000-300000],
    [600007-500000, 610007-500000, 620007-500000, 630007-500000, 690007-500000]]

    Each epoch runs one optimisation step per training sample, minimising the
    negative mean log-softmax score of the sample's positive documents. After
    the epoch, ``len(training_samples) // 40`` randomly drawn samples are
    scored with ``net.calculate_rank``; the mean over those samples is stored
    in ``MAP_for_queries``. Lower values are treated as better by the
    comparisons below (presumably these are rank positions -- confirm against
    ``RetrieverModel.calculate_rank``).

    Args:
        epochs: Maximum number of passes over the training samples.
        RETRIEVER_EMBED_FLAG: Forwarded to ``RetrieverModel`` as ``EMBED_FLAG``;
            also adds "_DocEmbed" to the checkpoint file name.
        query_embedding: When true, queries are embedded from their question
            tokens with a loaded ``PhraseModel``; otherwise the retriever's own
            ``pack_input`` output for ``query_index`` is used. Also adds
            "_QueryEmbed" to the checkpoint file name.
    '''
    retriever_path = '../data/saves/retriever/initial_epoch_000_1.000.dat'
    device = 'cuda'
    docID_dict, _ = data.get_docID_indices(
        data.get_ordered_docID_document(ORDERED_QID_QUESTION_DICT))
    training_samples = data.load_json(TRAINING_SAMPLE_DICT)

    net = retriever_module.RetrieverModel(emb_size=50,
                                          dict_size=len(docID_dict),
                                          EMBED_FLAG=RETRIEVER_EMBED_FLAG,
                                          device=device).to(device)
    net.load_state_dict(torch.load(retriever_path))
    net.zero_grad()
    retriever_optimizer = adabound.AdaBound(
        filter(lambda p: p.requires_grad, net.parameters()),
        lr=1e-3,
        final_lr=0.1)

    emb_dict = None
    net1 = None
    qid_question_pair = {}
    if query_embedding:
        emb_dict = data.load_dict(DIC_PATH=DIC_PATH)
        # Get trained wording embeddings.
        path = RETRIEVER_PATH
        net1 = model.PhraseModel(emb_size=model.EMBEDDING_DIM,
                                 dict_size=len(emb_dict),
                                 hid_size=model.HIDDEN_STATE_SIZE,
                                 LSTM_FLAG=True,
                                 ATT_FLAG=False,
                                 EMBED_FLAG=False).to(device)
        net1.load_state_dict(torch.load(path))
        qid_question_pair = data.get_qid_question_pairs(
            ORDERED_QID_QUESTION_DICT)

    def question_missing(key):
        """True when question embeddings are requested but `key` has no question."""
        return query_embedding and key not in qid_question_pair

    def build_query_tensor(key, value):
        """Return the query representation used as the retriever input."""
        if query_embedding:
            return data.get_question_embedding(qid_question_pair[key],
                                               emb_dict, net1)
        return torch.tensor(net.pack_input(value['query_index']).tolist(),
                            requires_grad=False).cuda()

    def score_documents(query_tensor, value):
        """Score the sample's document range; return (scores, positive offsets)."""
        range_start = value['document_range'][0]
        document_range = (range_start, value['document_range'][1])
        logsoftmax_output = net(query_tensor, document_range)[0]
        # Positive ids are absolute; shift them to offsets inside the range.
        positive_document_list = [
            k - range_start for k in value['positive_document_list']
        ]
        return logsoftmax_output, positive_document_list

    max_value = MAX_MAP
    # Materialise once: reused for training and for random evaluation draws.
    sample_items = list(training_samples.items())

    for i in range(epochs):
        print('Epoch %d is training......' % (i))
        for key, value in sample_items:
            retriever_optimizer.zero_grad()
            net.zero_grad()
            if question_missing(key):
                print("ERROR! NO SUCH QUESTION: %s!" % (str(key)))
                continue
            query_tensor = build_query_tensor(key, value)
            logsoftmax_output, positive_document_list = score_documents(
                query_tensor, value)
            logsoftmax_output = logsoftmax_output.cuda()
            positive_logsoftmax_output = torch.stack(
                [logsoftmax_output[k] for k in positive_document_list])
            loss_policy_v = -positive_logsoftmax_output.mean()
            loss_policy_v.backward()
            retriever_optimizer.step()

        # Evaluate on a random 1/40 subset of the training samples.
        # random.seed() reseeds from OS entropy / time; passing a datetime
        # object is rejected with TypeError on Python >= 3.11.
        random.seed()
        MAP_list = []
        with torch.no_grad():  # evaluation only; no autograd graph needed
            for _ in range(len(sample_items) // 40):
                key, value = random.choice(sample_items)
                if question_missing(key):
                    # Same samples the training loop skips.
                    continue
                query_tensor = build_query_tensor(key, value)
                logsoftmax_output, positive_document_list = score_documents(
                    query_tensor, value)
                order = net.calculate_rank(logsoftmax_output.tolist())
                MAP_list.append(
                    mean([order[k] for k in positive_document_list]))

        if not MAP_list:
            # Fewer than 40 samples (or none usable): nothing to average.
            print('Epoch %d, no samples evaluated; skipping checkpoint.' % (i))
            continue

        MAP_for_queries = mean(MAP_list)
        print('------------------------------------------------------')
        print('Epoch %d, MAP_for_queries: %f' % (i, MAP_for_queries))
        print('------------------------------------------------------')

        # Record trained parameters when the metric improves.
        if MAP_for_queries < max_value:
            max_value = MAP_for_queries
            if MAP_for_queries < 500:
                output_str = "AdaBound"
                if RETRIEVER_EMBED_FLAG:
                    output_str += "_DocEmbed"
                if query_embedding:
                    output_str += "_QueryEmbed"
                torch.save(
                    net.state_dict(),
                    os.path.join(
                        SAVES_DIR, output_str + "_epoch_%03d_%.3f.dat" %
                        (i, MAP_for_queries)))
                print('Save the state_dict: %s' %
                      (str(i) + ' ' + str(MAP_for_queries)))
        # Stop early once the metric is good enough.
        if MAP_for_queries < 10:
            break