def __init__(self, config, load_dataset):
        """Build the FM model, its optimizer and loss from ``config``.

        Args:
            config: object exposing the FM paths and hyper-parameters read below.
            load_dataset: when true, read the JSON file at ``config.data_dir``
                into ``self.FM_all_data``.
        """
        self.FM_model_path = config.FM_model_path
        self.data_dir = config.data_dir
        self.user_num = config.user_num
        self.business_num = config.business_num
        self.BeliefTrackerModule = BeliefTrackerModule(BeliefTrackerConfig(), False)
        self.BeliefTrackerModule.load_all_model()
        self.tracker_idx_list = config.tracker_idx_list
        self.FM = TorchFM(config)
        self.batch_size = config.batch_size
        self.use_gpu = config.use_gpu
        self.epoch_num = config.epoch_num
        self.save_step = config.save_step
        self.lr = config.lr
        self.weight_decay = config.weight_decay
        # weight_decay is read from the config (and reported as a training
        # hyper-parameter) but was never handed to the optimizer; pass it on.
        self.optimizer = optim.Adam(self.FM.parameters(), lr=self.lr,
                                    weight_decay=self.weight_decay)
        self.train_log_path = config.train_log_path
        # reduction='mean' is the non-deprecated spelling of
        # reduce=True, size_average=True.
        self.criterion = nn.MSELoss(reduction='mean')

        # Dataset splits are filled in later; the raw data only when requested.
        self.FM_all_data = None
        self.train_data_list = None
        self.eva_data_list = None
        self.test_data_list = None
        if load_dataset:
            with open(config.data_dir, 'r') as f:
                self.FM_all_data = json.load(f)
Example #2
0
 def __init__(self, config):
     """Set up the dialogue environment described by ``config``.

     Copies the reward and episode settings from ``config``, then builds the
     recommender, the user simulator and the belief tracker (whose models are
     loaded immediately). Per-episode fields start out as ``None`` and
     console output starts silenced.
     """
     # Plain settings, copied attribute-for-attribute from the config.
     for field_name in ('K', 'C', 'rc', 'rq', 'turn_limit',
                        'tracker_idx_list', 'rec_action_facet'):
         setattr(self, field_name, getattr(config, field_name))
     # Collaborating components, constructed in a fixed order.
     self.rec = RecModule(RecModuleConfig())
     self.user = UserSim(UserSimConfig())
     self.bftracker = BeliefTrackerModule(BeliefTrackerConfig(), True)
     self.bftracker.load_all_model()
     # Episode state placeholders.
     self.turn_count = self.user_name = self.business_name = None
     self.user_utt_list = self.dialogue_state = None
     self.silence = True
Example #3
0
from recommendersystem.RecModuleConfig import RecModuleConfig
from agents.AgentRule import AgentRule
from agents.AgentRuleConfig import AgentRuleConfig
from agents.AgentRL import AgentRL
from agents.AgentRLConfig import AgentRLConfig
from usersim.UserSim import UserSim
from usersim.UserSimConfig import UserSimConfig
from belieftracker.BeliefTrackerModule import BeliefTrackerModule
from belieftracker.BeliefTrackerConfig import BeliefTrackerConfig

# Build every component of the simulated dialogue loop and wire them into a
# DialogueManager.
# NOTE(review): RecModule, DialogueManager, DialogueManagerConfig and pickle
# are used here but not imported in the lines above -- confirm they are
# imported elsewhere in this file.
rec = RecModule(RecModuleConfig())
# Rule-based agent kept as a disabled alternative to the RL agent below.
#agent = AgentRule(AgentRuleConfig())
agent = AgentRL(AgentRLConfig())
agent.load_model(False)
user = UserSim(UserSimConfig())
bftracker = BeliefTrackerModule(BeliefTrackerConfig(), True)
bftracker.load_all_model()
DM = DialogueManager(DialogueManagerConfig(), rec, agent, user, bftracker)

# Load the pickled (train, dev, test) splits; only the test split is used.
# The path is resolved against the current working directory, not this file.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open('../data/RL_data/RL_data.pkl', 'rb') as f:
    train_data, dev_data, test_data = pickle.load(f)
data_list = test_data

# NOTE(review): the disabled driver loop below appears incomplete -- it stops
# right after IsOver is initialised, before any turn-taking loop.
# TestNum = 10
# for Test in range(TestNum):
#     print("---------------------------")
#     print("Simulated Conversation {}:".format(Test))
#     data = random.choice(data_list)
#     print("UserName: {}, BussinessName: {}".format(data[0], data[1]))
#     DM.initialize_episode(data[0], data[1])
#     IsOver = False
class FMModule():
    """Train, evaluate and apply a factorization-machine (FM) scorer.

    During training and evaluation, each batch's feature vector is the
    ``BatchData.user_business_feature`` concatenated with the belief-tracker
    output for the batch's sentences. For inference (``use_FM``) it is a
    one-hot user vector, a one-hot business vector and a dialogue-state tensor.
    """

    def __init__(self, config, load_dataset):
        """Build the FM model, its optimizer and loss from ``config``.

        Args:
            config: object exposing the FM paths and hyper-parameters read below.
            load_dataset: when true, read the JSON file at ``config.data_dir``
                into ``self.FM_all_data``.
        """
        self.FM_model_path = config.FM_model_path
        self.data_dir = config.data_dir
        self.user_num = config.user_num
        self.business_num = config.business_num
        self.BeliefTrackerModule = BeliefTrackerModule(BeliefTrackerConfig(), False)
        self.BeliefTrackerModule.load_all_model()
        self.tracker_idx_list = config.tracker_idx_list
        self.FM = TorchFM(config)
        self.batch_size = config.batch_size
        self.use_gpu = config.use_gpu
        self.epoch_num = config.epoch_num
        self.save_step = config.save_step
        self.lr = config.lr
        self.weight_decay = config.weight_decay
        # weight_decay is read from the config and written to the training log,
        # but was never handed to the optimizer; pass it on so it takes effect.
        self.optimizer = optim.Adam(self.FM.parameters(), lr=self.lr,
                                    weight_decay=self.weight_decay)
        self.train_log_path = config.train_log_path
        # reduction='mean' is the non-deprecated spelling of
        # reduce=True, size_average=True.
        self.criterion = nn.MSELoss(reduction='mean')

        # Dataset splits are filled in by prepare_data().
        self.FM_all_data = None
        self.train_data_list = None
        self.eva_data_list = None
        self.test_data_list = None
        if load_dataset:
            with open(config.data_dir, 'r') as f:
                self.FM_all_data = json.load(f)

    def save_model(self):
        """Write the FM weights to ``<FM_model_path>/FM_model``."""
        torch.save(self.FM.state_dict(), "/".join([self.FM_model_path, "FM_model"]))

    def load_model(self):
        """Load FM weights previously written by ``save_model``."""
        self.FM.load_state_dict(torch.load("/".join([self.FM_model_path, "FM_model"])))

    def prepare_data(self):
        """Shuffle the loaded dataset in place and split it 10% / 10% / 80%.

        The first tenth becomes the evaluation split, the second tenth the test
        split and the remainder the training split.

        Raises:
            ValueError: if no dataset was loaded (``load_dataset`` was false).
        """
        if self.FM_all_data is None:
            raise ValueError(
                "FM dataset not loaded; construct FMModule with load_dataset=True")
        random.shuffle(self.FM_all_data)
        data_num = len(self.FM_all_data) // 10
        self.eva_data_list = self.FM_all_data[:data_num]
        self.test_data_list = self.FM_all_data[data_num:2 * data_num]
        self.train_data_list = self.FM_all_data[2 * data_num:]

    def _iter_batches(self, data_list):
        """Yield a BatchData for each full ``batch_size`` slice of ``data_list`` (tail dropped)."""
        for index in range(len(data_list) // self.batch_size):
            left_index = index * self.batch_size
            right_index = left_index + self.batch_size
            yield BatchData(data_list[left_index:right_index],
                            self.user_num, self.business_num, self.use_gpu)

    def _squared_error_sum(self, data_list):
        """Return the FM's summed squared error over the full batches of ``data_list``."""
        total = 0.
        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            for batch_data in self._iter_batches(data_list):
                output = self.batch_data_input(batch_data, False)
                pre_score_list = output.view(-1).tolist()
                gt_score_list = batch_data.score_list.tolist()
                for pre, gt in zip(pre_score_list, gt_score_list):
                    total += (pre - gt) ** 2
        return total

    def batch_data_input(self, batchdata, is_train):
        """Run the FM on one batch and return its raw output.

        Switches the FM to train or eval mode according to ``is_train``, then
        appends the belief-tracker output for ``batchdata.sen_list`` to the
        batch's user/business features along the last dimension.
        """
        if is_train:
            self.FM.train()
        else:
            self.FM.eval()

        tracker_output = self.BeliefTrackerModule.use_tracker(batchdata.sen_list, self.tracker_idx_list)
        input_feature = torch.cat([batchdata.user_business_feature, tracker_output], -1)
        return self.FM(input_feature)

    def train_model(self):
        """Train for ``epoch_num`` epochs, checkpointing on the best evaluation loss.

        Every ``save_step`` optimizer steps, the summed squared error on the
        evaluation split is computed. Whenever it beats the best value so far,
        the weights are saved with ``save_model``. Progress is printed and
        appended to ``<train_log_path>/FM_<timestamp>.txt``.
        """
        print("prepare data...")
        self.prepare_data()
        time_str = datetime.datetime.now().isoformat()
        print("{} start training FM ...".format(time_str))
        print("lr: {:g}, batch_size: {}, weight_decay: {:g}"
              .format(self.lr, self.batch_size, self.weight_decay))
        step = 0
        min_loss = 1e10
        train_log_name = "FM_" + time_str + ".txt"
        with open("/".join([self.train_log_path, train_log_name]), "a") as f:
            f.write("{} start training FM ...\n".format(time_str))
            f.write("lr: {:g}, batch_size: {}, weight_decay: {:g}\n"
                    .format(self.lr, self.batch_size, self.weight_decay))
            for epoch in range(self.epoch_num):
                print("epoch: ", epoch)
                f.write("epoch: " + str(epoch) + '\n')
                random.shuffle(self.train_data_list)
                for t_batch_data in self._iter_batches(self.train_data_list):
                    output = self.batch_data_input(t_batch_data, True)
                    loss = self.criterion(output, t_batch_data.score_list)
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                    time_str = datetime.datetime.now().isoformat()
                    loss_value = loss.item()
                    print("{}: step {}, loss {:g}".format(time_str, step, loss_value))
                    f.write("{}: step {}, loss {:g}\n".format(time_str, step, loss_value))

                    step += 1
                    if step % self.save_step == 0:
                        print("start evaluation")
                        cur_loss = self._squared_error_sum(self.eva_data_list)
                        print("{}: step {}, score loss {:g}".format(time_str, step, cur_loss))
                        f.write("----------\n")
                        f.write("Eva {}: step {}, score loss {:g}\n".format(time_str, step, cur_loss))
                        f.write("----------\n")
                        if cur_loss < min_loss:
                            min_loss = cur_loss
                            self.save_model()
            f.write("----------\n")
            f.write("train over\n")
            f.write("min loss: {:g}\n".format(min_loss))

    def eva_model(self):
        """Load the saved FM and log its summed squared error on the test split.

        The result is appended to ``<train_log_path>/FM_test_<timestamp>.txt``.
        """
        print("prepare data...")
        # NOTE(review): prepare_data() reshuffles before splitting, so this
        # test split is generally not the one held out during train_model()
        # and may overlap the training data. Persist or seed the split to fix.
        self.prepare_data()
        self.load_model()
        time_str = datetime.datetime.now().isoformat()
        print("{} start test FM ...".format(time_str))
        test_loss = self._squared_error_sum(self.test_data_list)
        test_log_name = "FM_test_" + time_str + ".txt"
        with open("/".join([self.train_log_path, test_log_name]), "a") as f:
            f.write("test_loss: {:g}\n".format(test_loss))

    def use_FM(self, user_id, business_list, current_state):
        """Rank candidate businesses for a user by FM score, highest first.

        Args:
            user_id: integer user index in ``[0, user_num)``.
            business_list: sequence of integer business indices in
                ``[0, business_num)``; an empty sequence yields ``[]``.
            current_state: tensor repeated for every candidate (stacked along
                a new leading dimension) and appended to its features.

        Returns:
            The entries of ``business_list`` as a list, reordered by
            descending FM score.
        """
        candidate_num = len(business_list)
        if candidate_num == 0:
            return []
        # as_tensor avoids re-wrapping (and warning) when given a tensor.
        business_ids = torch.as_tensor(business_list)
        user_ids = torch.tensor([user_id] * candidate_num)
        user_one_hot = nn.functional.one_hot(user_ids, self.user_num).float()
        business_one_hot = nn.functional.one_hot(business_ids, self.business_num).float()
        current_state_list = torch.stack([current_state] * candidate_num, 0)
        input_list = torch.cat([user_one_hot, business_one_hot, current_state_list], -1)
        # Inference: use eval mode and skip autograd bookkeeping.
        self.FM.eval()
        with torch.no_grad():
            score_list = self.FM(input_list)
        # Flatten so an (N, 1) output is ranked across candidates rather than
        # "sorted" along its size-1 last dimension.
        _, sorted_indices = score_list.reshape(-1).sort(descending=True)
        return business_ids[sorted_indices].tolist()
Example #5
0
class RLEnv:
    """Simulated conversational-recommendation environment for an RL agent.

    An episode pairs a user with a target business. Each ``step`` either asks
    the simulated user about a facet, which costs ``rc`` and updates the
    dialogue state, or issues a recommendation, which ends the episode.
    """

    def __init__(self, config):
        """Copy settings from ``config`` and build the recommender, user simulator and tracker."""
        # Plain settings, copied attribute-for-attribute from the config.
        for field_name in ('K', 'C', 'rc', 'rq', 'turn_limit',
                           'tracker_idx_list', 'rec_action_facet'):
            setattr(self, field_name, getattr(config, field_name))
        # Collaborating components, constructed in a fixed order.
        self.rec = RecModule(RecModuleConfig())
        self.user = UserSim(UserSimConfig())
        self.bftracker = BeliefTrackerModule(BeliefTrackerConfig(), True)
        self.bftracker.load_all_model()
        # Episode state placeholders, set by initialize_episode().
        self.turn_count = self.user_name = self.business_name = None
        self.user_utt_list = self.dialogue_state = None
        self.silence = True

    def initialize_episode(self, user_name, business_name, silence):
        """Reset per-episode state for a new (user, target business) pair.

        Returns:
            The initial dialogue state, which is always ``None``.
        """
        if not silence:
            print("---------------------------",
                  "Simulated Conversation Start",
                  "UserName: {}, BussinessName: {}".format(user_name, business_name),
                  sep="\n")
        self.user_name, self.business_name = user_name, business_name
        self.turn_count = 0
        self.user.init_episode(user_name, business_name)
        self.user_utt_list = []
        self.dialogue_state = None
        self.silence = silence
        return self.dialogue_state

    def step(self, request_facet, unknown_facet):
        """Play one agent action and return ``(done, next_state, reward)``.

        Requesting ``rec_action_facet`` recommends and ends the episode with
        the recommendation reward. Reaching ``turn_limit`` ends it with
        ``rq``. Otherwise the user answers, the dialogue state is refreshed,
        and the turn reward is ``rc``.
        """
        self.turn_count += 1
        verbose = not self.silence
        if verbose:
            print("Turn {} agent: request {}".format(self.turn_count, request_facet))

        if request_facet == self.rec_action_facet:
            reward, rank, shown = self.recommend(unknown_facet)
            if verbose:
                print("Simulated Conversation Over: Success, Target {}/{}".format(rank, len(shown)))
            return True, None, reward

        if self.turn_count == self.turn_limit:
            if verbose:
                print("Simulated Conversation Over: Failed")
            return True, None, self.rq

        self.user_turn(request_facet)
        self.get_dialogue_state()
        return False, self.dialogue_state, self.rc

    def user_turn(self, request_facet):
        """Ask the simulated user about ``request_facet``; record and return the reply."""
        reply = self.user.next_turn(request_facet)
        self.user_utt_list.append(reply)
        if not self.silence:
            print("Turn {} user: {}".format(self.turn_count, reply))
        return reply

    def get_dialogue_state(self):
        """Recompute ``dialogue_state`` from every user utterance so far."""
        tracker = self.bftracker
        self.dialogue_state = tracker.use_tracker_from_nl(self.user_utt_list, self.tracker_idx_list)

    def recommend(self, unknown_facet):
        """Recommend the top ``K`` businesses and score the outcome.

        Returns:
            ``(reward, rank, shown)`` where ``shown`` is the truncated list.
            If the target is at 1-based ``rank``, reward is
            ``C * (K - rank + 1) / K``. Otherwise ``rank`` is ``-1`` and
            the reward is ``rq``.
        """
        shown = self.rec.recommend_bussiness(self.user_name, self.dialogue_state, unknown_facet)[:self.K]
        for rank, candidate in enumerate(shown, start=1):
            if candidate == self.business_name:
                return self.C * (self.K - rank + 1) / self.K, rank, shown
        return self.rq, -1, shown