class FMModule:
    """Factorization-machine recommender.

    Trains a TorchFM model on the concatenation of (user one-hot,
    business one-hot) features and the belief-tracker encoding of the
    dialogue so far, regressing the ground-truth score with MSE loss.
    """

    def __init__(self, config, load_dataset):
        # Paths and dataset geometry.
        self.FM_model_path = config.FM_model_path
        self.data_dir = config.data_dir
        self.user_num = config.user_num
        self.business_num = config.business_num
        # Belief tracker is loaded pre-trained (train flag False) and is
        # only used to featurize dialogue sentences.
        self.BeliefTrackerModule = BeliefTrackerModule(BeliefTrackerConfig(), False)
        self.BeliefTrackerModule.load_all_model()
        self.tracker_idx_list = config.tracker_idx_list
        self.FM = TorchFM(config)
        self.batch_size = config.batch_size
        self.use_gpu = config.use_gpu
        self.epoch_num = config.epoch_num
        self.save_step = config.save_step
        self.lr = config.lr
        self.weight_decay = config.weight_decay
        # FIX: weight_decay was read from config and logged, but never
        # handed to the optimizer; pass it through.
        self.optimizer = optim.Adam(self.FM.parameters(), lr=self.lr,
                                    weight_decay=self.weight_decay)
        self.train_log_path = config.train_log_path
        # FIX: reduce=True/size_average=True are deprecated aliases of
        # reduction='mean'; behavior is identical.
        self.criterion = nn.MSELoss(reduction='mean')
        self.FM_all_data = None
        self.train_data_list = None
        self.eva_data_list = None
        self.test_data_list = None
        if load_dataset:
            with open(config.data_dir, 'r') as f:
                self.FM_all_data = json.load(f)

    def save_model(self):
        """Persist FM weights under <FM_model_path>/FM_model."""
        torch.save(self.FM.state_dict(), "/".join([self.FM_model_path, "FM_model"]))

    def load_model(self):
        """Restore FM weights saved by save_model()."""
        self.FM.load_state_dict(torch.load("/".join([self.FM_model_path, "FM_model"])))

    def prepare_data(self):
        """Shuffle the raw dataset and split it 10% / 10% / 80% into
        eval / test / train lists (batching happens lazily per epoch)."""
        random.shuffle(self.FM_all_data)
        data_num = len(self.FM_all_data) // 10
        self.eva_data_list = self.FM_all_data[:data_num]
        self.test_data_list = self.FM_all_data[data_num:2 * data_num]
        self.train_data_list = self.FM_all_data[2 * data_num:]

    def batch_data_input(self, batchdata, is_train):
        """Run one forward pass of the FM on a BatchData.

        Sets the model to train/eval mode, featurizes the batch's
        sentences with the belief tracker and concatenates them with the
        user/business one-hot features. Returns the raw FM output.
        """
        if is_train:
            self.FM.train()
        else:
            self.FM.eval()
        tracker_output = self.BeliefTrackerModule.use_tracker(batchdata.sen_list,
                                                              self.tracker_idx_list)
        input_feature = torch.cat([batchdata.user_business_feature, tracker_output], -1)
        return self.FM(input_feature)

    def train_model(self):
        """Train the FM with Adam + MSE, evaluating every save_step
        batches and checkpointing whenever the eval squared error improves."""
        print("prepare data...")
        self.prepare_data()
        time_str = datetime.datetime.now().isoformat()
        print("{} start training FM ...".format(time_str))
        print("lr: {:g}, batch_size: {}, weight_decay: {:g}"
              .format(self.lr, self.batch_size, self.weight_decay))
        step = 0
        min_loss = 1e10
        train_log_name = "FM_" + time_str + ".txt"
        with open("/".join([self.train_log_path, train_log_name]), "a") as f:
            f.write("{} start training FM ...\n".format(time_str))
            f.write("lr: {:g}, batch_size: {}, weight_decay: {:g}\n"
                    .format(self.lr, self.batch_size, self.weight_decay))
            for epoch in range(self.epoch_num):
                print("epoch: ", epoch)
                f.write("epoch: " + str(epoch) + '\n')
                random.shuffle(self.train_data_list)
                for index in range(len(self.train_data_list) // self.batch_size):
                    left_index = index * self.batch_size
                    right_index = (index + 1) * self.batch_size
                    t_batch_data = BatchData(self.train_data_list[left_index:right_index],
                                             self.user_num, self.business_num, self.use_gpu)
                    output = self.batch_data_input(t_batch_data, True)
                    # NOTE(review): if the FM returns shape (B, 1) while
                    # score_list is (B,), MSELoss broadcasts — confirm shapes.
                    loss = self.criterion(output, t_batch_data.score_list)
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    time_str = datetime.datetime.now().isoformat()
                    # FIX: format the Python scalar, not the Tensor
                    # ("{:g}" on a Tensor fails on older torch versions).
                    loss_val = loss.item()
                    print("{}: step {}, loss {:g}".format(time_str, step, loss_val))
                    f.write("{}: step {}, loss {:g}\n".format(time_str, step, loss_val))
                    step += 1
                    if step % self.save_step == 0:
                        print("start evaluation")
                        cur_loss = 0.
                        for index_ in range(len(self.eva_data_list) // self.batch_size):
                            left_index = index_ * self.batch_size
                            right_index = (index_ + 1) * self.batch_size
                            e_batch_data = BatchData(self.eva_data_list[left_index:right_index],
                                                     self.user_num, self.business_num, self.use_gpu)
                            output = self.batch_data_input(e_batch_data, False)
                            pre_score_list = output.view(-1).tolist()
                            gt_score_list = e_batch_data.score_list.tolist()
                            # Accumulate summed squared error over the eval split.
                            for pre, gt in zip(pre_score_list, gt_score_list):
                                cur_loss += (pre - gt) ** 2
                        print("{}: step {}, score loss {:g}".format(time_str, step, cur_loss))
                        f.write("----------\n")
                        f.write("Eva {}: step {}, score loss {:g}\n".format(time_str, step, cur_loss))
                        f.write("----------\n")
                        # Keep only the best checkpoint by eval SSE.
                        if cur_loss < min_loss:
                            min_loss = cur_loss
                            self.save_model()
            f.write("----------\n")
            f.write("train over\n")
            f.write("min loss: {:g}\n".format(min_loss))

    def eva_model(self):
        """Evaluate the saved checkpoint on the test split; appends the
        summed squared error to a timestamped log file.

        NOTE(review): prepare_data() reshuffles here, so this test split
        is NOT the same one held out during training — confirm intent.
        """
        print("prepare data...")
        self.prepare_data()
        self.load_model()
        time_str = datetime.datetime.now().isoformat()
        print("{} start test FM ...".format(time_str))
        test_loss = 0.
        for index in range(len(self.test_data_list) // self.batch_size):
            left_index = index * self.batch_size
            right_index = (index + 1) * self.batch_size
            t_batch_data = BatchData(self.test_data_list[left_index:right_index],
                                     self.user_num, self.business_num, self.use_gpu)
            output = self.batch_data_input(t_batch_data, False)
            pre_score_list = output.view(-1).tolist()
            gt_score_list = t_batch_data.score_list.tolist()
            for pre, gt in zip(pre_score_list, gt_score_list):
                test_loss += (pre - gt) ** 2
        test_log_name = "FM_test_" + time_str + ".txt"
        with open("/".join([self.train_log_path, test_log_name]), "a") as f:
            f.write("test_loss: {:g}\n".format(test_loss))

    def use_FM(self, user_id, business_list, current_state):
        """Rank candidate businesses for a user given the tracked dialogue state.

        Returns the candidate ids reordered by predicted score (descending).
        """
        business_ids = list(business_list)
        user_onehot = nn.functional.one_hot(
            torch.tensor([user_id] * len(business_ids)), self.user_num).float()
        # FIX: build both the one-hot input and the id tensor from the raw
        # id list — the original called torch.tensor() on an existing tensor.
        id_tensor = torch.tensor(business_ids)
        business_onehot = nn.functional.one_hot(id_tensor, self.business_num).float()
        current_state_list = torch.stack(
            [current_state for _ in range(len(business_ids))], 0)
        input_list = torch.cat([user_onehot, business_onehot, current_state_list], -1)
        with torch.no_grad():  # pure inference, no gradients needed
            # FIX: flatten before sorting — a (N, 1) FM output would
            # otherwise be sorted along the size-1 dimension.
            score_list = self.FM(input_list).view(-1)
        _, sorted_indices = score_list.sort(descending=True)
        return id_tensor[sorted_indices].tolist()
from agents.AgentRule import AgentRule
from agents.AgentRuleConfig import AgentRuleConfig
from agents.AgentRL import AgentRL
from agents.AgentRLConfig import AgentRLConfig
from usersim.UserSim import UserSim
from usersim.UserSimConfig import UserSimConfig
from belieftracker.BeliefTrackerModule import BeliefTrackerModule
from belieftracker.BeliefTrackerConfig import BeliefTrackerConfig

# Evaluation script: wire up recommender, RL agent, user simulator and
# belief tracker, then replay simulated conversations over the test split.
# NOTE(review): RecModule, RecModuleConfig, DialogueManager,
# DialogueManagerConfig, pickle and random are used below but not imported
# in this view — presumably imported elsewhere in the file; confirm.
rec = RecModule(RecModuleConfig())
#agent = AgentRule(AgentRuleConfig())
agent = AgentRL(AgentRLConfig())
agent.load_model(False)  # False: load for inference, not resume-training
user = UserSim(UserSimConfig())
bftracker = BeliefTrackerModule(BeliefTrackerConfig(), True)
bftracker.load_all_model()
DM = DialogueManager(DialogueManagerConfig(), rec, agent, user, bftracker)
# Pre-split RL data: (train, dev, test) tuples of (user, business) pairs.
with open('../data/RL_data/RL_data.pkl', 'rb') as f:
    train_data, dev_data, test_data = pickle.load(f)
data_list = test_data
# TestNum = 10
# for Test in range(TestNum):
#     print("---------------------------")
#     print("Simulated Conversation {}:".format(Test))
#     data = random.choice(data_list)
#     print("UserName: {}, BussinessName: {}".format(data[0], data[1]))
#     DM.initialize_episode(data[0], data[1])
#     IsOver = False
#     while not IsOver:
class RLEnv:
    """Dialogue environment for RL agent training.

    Each step, the agent either requests a facet (small per-turn reward
    ``rc``) or triggers a recommendation (the ``rec_action_facet``
    sentinel). An episode ends on a recommendation, or with failure
    reward ``rq`` once ``turn_limit`` turns have elapsed.
    """

    def __init__(self, config):
        self.K = config.K                    # length of the recommendation list
        self.C = config.C                    # success-reward scale
        self.rc = config.rc                  # per-turn (request) reward
        self.rq = config.rq                  # failure reward
        self.turn_limit = config.turn_limit
        self.tracker_idx_list = config.tracker_idx_list
        # Sentinel facet meaning "recommend now" instead of a request.
        self.rec_action_facet = config.rec_action_facet
        self.rec = RecModule(RecModuleConfig())
        self.user = UserSim(UserSimConfig())
        self.bftracker = BeliefTrackerModule(BeliefTrackerConfig(), True)
        self.bftracker.load_all_model()
        # Per-episode state; populated by initialize_episode().
        self.turn_count = None
        self.user_name = None
        self.business_name = None
        self.user_utt_list = None
        self.dialogue_state = None
        self.silence = True

    def initialize_episode(self, user_name, business_name, silence):
        """Reset all per-episode state for a (user, target business) pair.

        Returns the initial dialogue state (None before any user turn).
        """
        if not silence:
            print("---------------------------")
            print("Simulated Conversation Start")
            print("UserName: {}, BussinessName: {}".format(user_name, business_name))
        self.user_name = user_name
        self.business_name = business_name
        self.turn_count = 0
        self.user.init_episode(user_name, business_name)
        self.user_utt_list = []
        self.dialogue_state = None
        self.silence = silence
        return self.dialogue_state

    def step(self, request_facet, unknown_facet):
        """Advance one agent turn.

        Returns (is_over, dialogue_state, reward). The state is None on
        terminal transitions (recommendation made or turn limit hit).
        """
        self.turn_count += 1
        if not self.silence:
            print("Turn %d agent: request %s" % (self.turn_count, request_facet))
        if request_facet == self.rec_action_facet:
            rec_reward, rec_rank, rec_list = self.recommend(unknown_facet)
            if not self.silence:
                print("Simulated Conversation Over: Success, Target {}/{}"
                      .format(rec_rank, len(rec_list)))
            return True, None, rec_reward
        if self.turn_count == self.turn_limit:
            if not self.silence:
                print("Simulated Conversation Over: Failed")
            return True, None, self.rq
        # FIX: the returned utterance was assigned but never used; the
        # tracker reads it from self.user_utt_list instead.
        self.user_turn(request_facet)
        self.get_dialogue_state()
        return False, self.dialogue_state, self.rc

    def user_turn(self, request_facet):
        """Ask the user simulator about one facet; record and return its reply."""
        user_nl = self.user.next_turn(request_facet)
        self.user_utt_list.append(user_nl)
        if not self.silence:
            print("Turn %d user: %s" % (self.turn_count, user_nl))
        return user_nl

    def get_dialogue_state(self):
        """Re-encode the whole utterance history into self.dialogue_state."""
        self.dialogue_state = self.bftracker.use_tracker_from_nl(self.user_utt_list,
                                                                 self.tracker_idx_list)

    def recommend(self, unknown_facet):
        """Recommend the top-K businesses and score the attempt.

        Returns (reward, rank, top_k_list): reward scales linearly from C
        (rank 1) down to C/K (rank K) when the target business appears in
        the list, and is the failure reward rq (rank -1) otherwise.
        """
        business_list = self.rec.recommend_bussiness(self.user_name,
                                                     self.dialogue_state,
                                                     unknown_facet)
        business_list = business_list[:self.K]
        for rank_index, business_name in enumerate(business_list):
            rank_id = rank_index + 1
            if business_name == self.business_name:
                rec_reward = self.C * (self.K - rank_id + 1) / self.K
                return rec_reward, rank_id, business_list
        return self.rq, -1, business_list