Example 1
    def create_dataset_global(self, part, file_dir, data_dir, cfg, db):
        datas = self.data[part]
        goals = self.goal[part]
        s_usr, s_sys, r_g, next_s_usr, next_s_sys, t = [], [], [], [], [], []
        evaluator = MultiWozEvaluator(data_dir)
        for idx, turn_data in enumerate(datas):
            # Even turn indices are user turns; odd indices are system turns.
            if turn_data['others']['turn'] % 2 == 0:
                if turn_data['others']['turn'] == 0:
                    current_goal = goals[turn_data['others']['session_id']]
                    evaluator.add_goal(current_goal)
                else:
                    next_s_usr.append(s_usr[-1])
                
                # The current goal cannot be satisfied: swap in the 'final' fallback goal.
                if turn_data['others']['change'] and evaluator.cur_domain:
                    if 'final' in current_goal[evaluator.cur_domain]:
                        for key in current_goal[evaluator.cur_domain]['final']:
                            current_goal[evaluator.cur_domain][key] = current_goal[evaluator.cur_domain]['final'][key]
                        del current_goal[evaluator.cur_domain]['final']
                
                turn_data['user_goal'] = deepcopy(current_goal)
                s_usr.append(torch.Tensor(state_vectorize_user(turn_data, cfg, evaluator.cur_domain)))
                evaluator.add_usr_da(turn_data['trg_user_action'])
                    
                if turn_data['others']['terminal']:
                    next_turn_data = deepcopy(turn_data)
                    next_turn_data['others']['turn'] = -1
                    next_turn_data['user_action'] = turn_data['trg_user_action']
                    next_turn_data['sys_action'] = datas[idx+1]['trg_sys_action']
                    next_turn_data['trg_user_action'] = {}
                    next_turn_data['goal_state'] = datas[idx+1]['final_goal_state']
                    next_s_usr.append(torch.Tensor(state_vectorize_user(next_turn_data, cfg, evaluator.cur_domain)))
            
            else:
                if turn_data['others']['turn'] != 1:
                    next_s_sys.append(s_sys[-1])

                s_sys.append(torch.Tensor(state_vectorize(turn_data, cfg, db, True)))
                evaluator.add_sys_da(turn_data['trg_sys_action'])
            
                if turn_data['others']['terminal']:
                    next_turn_data = deepcopy(turn_data)
                    next_turn_data['others']['turn'] = -1
                    next_turn_data['user_action'] = {}
                    next_turn_data['sys_action'] = turn_data['trg_sys_action']
                    next_turn_data['trg_sys_action'] = {}
                    next_turn_data['belief_state'] = turn_data['final_belief_state']
                    next_s_sys.append(torch.Tensor(state_vectorize(next_turn_data, cfg, db, True)))
                    # Terminal reward: +20 if the whole task succeeded, -5 otherwise.
                    reward_g = 20 if evaluator.task_success() else -5
                    r_g.append(reward_g)
                    t.append(1)
                else:
                    # Per-turn reward: +5 when the current domain succeeded, -1 otherwise.
                    reward_g = 5 if evaluator.cur_domain and evaluator.domain_success(evaluator.cur_domain) else -1
                    r_g.append(reward_g)
                    t.append(0)
                
        torch.save((s_usr, s_sys, r_g, next_s_usr, next_s_sys, t), file_dir)
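The six parallel lists saved above line up index-by-index into global transitions. A minimal sketch of reading them back, assuming `file_dir` points at a file written by `create_dataset_global` (the path below is hypothetical):

    import torch

    file_dir = 'processed_data/train_global.pt'  # hypothetical path

    # Six parallel lists: one entry per system-side transition.
    s_usr, s_sys, r_g, next_s_usr, next_s_sys, t = torch.load(file_dir)

    # Terminal transitions carry t == 1 and the +20/-5 task-level reward.
    transitions = list(zip(s_usr, s_sys, r_g, next_s_usr, next_s_sys, t))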
Example 2
    def create_dataset_sys(self, part, file_dir, data_dir, cfg, db):
        datas = self.data[part]
        goals = self.goal[part]
        s, a, r, next_s, t = [], [], [], [], []
        evaluator = MultiWozEvaluator(data_dir)
        for idx, turn_data in enumerate(datas):
            # User turns only feed the evaluator; no samples are collected for them.
            if turn_data['others']['turn'] % 2 == 0:
                if turn_data['others']['turn'] == 0:
                    evaluator.add_goal(
                        goals[turn_data['others']['session_id']])
                evaluator.add_usr_da(turn_data['trg_user_action'])
                continue
            if turn_data['others']['turn'] != 1:
                next_s.append(s[-1])

            s.append(torch.Tensor(state_vectorize(turn_data, cfg, db, True)))
            a.append(
                torch.Tensor(action_vectorize(turn_data['trg_sys_action'],
                                              cfg)))
            evaluator.add_sys_da(turn_data['trg_sys_action'])
            if turn_data['others']['terminal']:
                next_turn_data = deepcopy(turn_data)
                next_turn_data['others']['turn'] = -1
                next_turn_data['user_action'] = {}
                next_turn_data['sys_action'] = turn_data['trg_sys_action']
                next_turn_data['trg_sys_action'] = {}
                next_turn_data['belief_state'] = turn_data[
                    'final_belief_state']
                next_s.append(
                    torch.Tensor(state_vectorize(next_turn_data, cfg, db,
                                                 True)))
                # Terminal reward: +20 if the system completed the task, -5 otherwise.
                reward = 20 if evaluator.task_success(False) else -5
                r.append(reward)
                t.append(1)
            else:
                reward = 0
                if evaluator.cur_domain:
                    for slot, value in turn_data['belief_state'][
                            evaluator.cur_domain].items():
                        if value == '?':
                            for da in turn_data['trg_sys_action']:
                                d, i, k, p = da.split('-')
                                if i in [
                                        'inform', 'recommend', 'offerbook',
                                        'offerbooked'
                                ] and k == slot:
                                    break
                            else:
                                # the requested slot was never answered
                                reward -= 1
                if not turn_data['trg_sys_action']:
                    reward -= 5
                r.append(reward)
                t.append(0)

        torch.save((s, a, r, next_s, t), file_dir)
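The penalty for unanswered requests above hinges on Python's for/else: the else branch runs only when the loop finishes without a break. A self-contained sketch of the same check, with invented action and slot names:

    # Invented system actions in 'domain-intent-slot-position' form.
    trg_sys_action = {'hotel-inform-area-1': 'centre'}
    requested_slot = 'phone'

    reward = 0
    for da in trg_sys_action:
        d, i, k, p = da.split('-')
        if i in ['inform', 'recommend', 'offerbook', 'offerbooked'] \
                and k == requested_slot:
            break
    else:
        # No action answered the requested slot, so the turn is penalized.
        reward -= 1

    print(reward)  # -1: 'phone' was requested but never informed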
Example 3
    def create_dataset_global(self, part, file_dir, data_dir, cfg, db):
        """
        创建global数据,这个数据记录了用户侧和系统侧的所有状态以及奖励
        """
        datas = self.data[part]
        goals = self.goal[part]
        s_usr, s_sys, r_g, next_s_usr, next_s_sys, t = [], [], [], [], [], []
        evaluator = MultiWozEvaluator(data_dir, cfg.d)
        for idx, turn_data in enumerate(datas):
            if turn_data['others']['turn'] % 2 == 0:
                if turn_data['others']['turn'] == 0:
                    current_goal = goals[turn_data['others']['session_id']]
                    evaluator.add_goal(current_goal)
                else:
                    next_s_usr.append(s_usr[-1])

                # When the user goal cannot be satisfied, switch to the fallback goal.
                if turn_data['others']['change'] and evaluator.cur_domain:
                    if 'final' in current_goal[evaluator.cur_domain]:
                        for key in current_goal[evaluator.cur_domain]['final']:
                            current_goal[
                                evaluator.cur_domain][key] = current_goal[
                                    evaluator.cur_domain]['final'][key]
                        del current_goal[evaluator.cur_domain]['final']
                turn_data['user_goal'] = deepcopy(current_goal)

                s_usr.append(
                    torch.Tensor(
                        state_vectorize_user(turn_data, cfg,
                                             evaluator.cur_domain)))
                evaluator.add_usr_da(turn_data['trg_user_action'])

                if turn_data['others']['terminal']:
                    next_turn_data = deepcopy(turn_data)
                    next_turn_data['others']['turn'] = -1
                    next_turn_data['user_action'] = turn_data[
                        'trg_user_action']
                    next_turn_data['sys_action'] = datas[idx +
                                                         1]['trg_sys_action']
                    next_turn_data['trg_user_action'] = {}
                    next_turn_data['goal_state'] = datas[idx +
                                                         1]['final_goal_state']
                    next_s_usr.append(
                        torch.Tensor(
                            state_vectorize_user(next_turn_data, cfg,
                                                 evaluator.cur_domain)))

            else:
                if turn_data['others']['turn'] != 1:
                    next_s_sys.append(s_sys[-1])

                s_sys.append(
                    torch.Tensor(state_vectorize(turn_data, cfg, db, True)))
                evaluator.add_sys_da(turn_data['trg_sys_action'])

                if turn_data['others']['terminal']:
                    next_turn_data = deepcopy(turn_data)
                    next_turn_data['others']['turn'] = -1
                    next_turn_data['user_action'] = {}
                    next_turn_data['sys_action'] = turn_data['trg_sys_action']
                    next_turn_data['trg_sys_action'] = {}
                    next_turn_data['belief_state'] = turn_data[
                        'final_belief_state']
                    next_s_sys.append(
                        torch.Tensor(
                            state_vectorize(next_turn_data, cfg, db, True)))
                    # In a multi-turn dialogue the system is assumed to speak the closing
                    # turn, so system-side task success serves as the global reward.
                    reward_g = 20 if evaluator.task_success() else -5
                    r_g.append(reward_g)
                    t.append(1)
                else:
                    # Reward domain success; every additional turn costs one point,
                    # which pushes toward shorter dialogues.
                    # TODO: clarify what exactly counts as domain_success.
                    reward_g = 5 if evaluator.cur_domain and evaluator.domain_success(
                        evaluator.cur_domain) else -1
                    r_g.append(reward_g)
                    t.append(0)

        torch.save((s_usr, s_sys, r_g, next_s_usr, next_s_sys, t), file_dir)
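The goal-switch branch above overwrites the current domain's constraints with its 'final' fallback once the original goal proves unsatisfiable. A toy illustration of that dictionary update (the goal contents are invented):

    # Invented goal: the user wants a cheap hotel but will settle for moderate.
    current_goal = {'hotel': {'pricerange': 'cheap',
                              'final': {'pricerange': 'moderate'}}}
    cur_domain = 'hotel'

    if 'final' in current_goal[cur_domain]:
        for key in current_goal[cur_domain]['final']:
            current_goal[cur_domain][key] = current_goal[cur_domain]['final'][key]
        del current_goal[cur_domain]['final']

    print(current_goal)  # {'hotel': {'pricerange': 'moderate'}}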
Example 4
    def create_dataset_sys(self, part, file_dir, data_dir, cfg, db):
        """
        创建sys的训练数据
        """
        datas = self.data[part]
        goals = self.goal[part]
        # system state + system action + reward + next system state + terminal flag
        s, a, r, next_s, t = [], [], [], [], []
        # The evaluator tracks the whole dialogue trajectory.
        evaluator = MultiWozEvaluator(data_dir, cfg.d)
        for idx, turn_data in enumerate(datas):
            # User turn: the user side does not generate training samples here.
            if turn_data['others']['turn'] % 2 == 0:
                # On the first turn, load the user goal.
                if turn_data['others']['turn'] == 0:
                    evaluator.add_goal(
                        goals[turn_data['others']['session_id']])
                evaluator.add_usr_da(turn_data['trg_user_action'])
                continue

            # Staggered by one iteration, but this really does record the next-turn state.
            if turn_data['others']['turn'] != 1:
                next_s.append(s[-1])

            # Vectorize the current turn into a state vector.
            s.append(torch.Tensor(state_vectorize(turn_data, cfg, db, True)))
            # Vectorize the target system action into an action vector.
            a.append(
                torch.Tensor(action_vectorize(turn_data['trg_sys_action'],
                                              cfg)))
            evaluator.add_sys_da(turn_data['trg_sys_action'])
            if turn_data['others']['terminal']:
                # Terminal turn.
                next_turn_data = deepcopy(turn_data)
                next_turn_data['others']['turn'] = -1
                next_turn_data['user_action'] = {}
                next_turn_data['sys_action'] = turn_data['trg_sys_action']
                next_turn_data['trg_sys_action'] = {}
                next_turn_data['belief_state'] = turn_data[
                    'final_belief_state']
                # Record next_s for the terminal transition.
                next_s.append(
                    torch.Tensor(state_vectorize(next_turn_data, cfg, db,
                                                 True)))
                # Final reward: whether the task is judged complete, i.e. whether the
                # system fulfilled the booking requests raised by the real user actions
                # and answered every question the real user asked.
                reward = 20 if evaluator.task_success(False) else -5
                r.append(reward)
                # Terminal flag.
                t.append(1)
            else:
                reward = 0
                if evaluator.cur_domain:
                    for slot, value in turn_data['belief_state'][
                            evaluator.cur_domain].items():
                        if value == '?':
                            for da in turn_data['trg_sys_action']:
                                d, i, k, p = da.split('-')
                                if i in [
                                        'inform', 'recommend', 'offerbook',
                                        'offerbooked'
                                ] and k == slot:
                                    break
                            else:
                                # The question recorded in the belief_state was
                                # never answered: reward minus one.
                                reward -= 1
                if not turn_data['trg_sys_action']:
                    # No system reply this turn: reward minus five.
                    reward -= 5
                r.append(reward)
                t.append(0)

        torch.save((s, a, r, next_s, t), file_dir)
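For training, the flat lists stack into batched tensors. A minimal sketch, assuming every state vector has the same width and `file_dir` names the file written above (the path is hypothetical):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    file_dir = 'processed_data/train_sys.pt'  # hypothetical path
    s, a, r, next_s, t = torch.load(file_dir)

    dataset = TensorDataset(torch.stack(s),
                            torch.stack(a),
                            torch.tensor(r, dtype=torch.float),
                            torch.stack(next_s),
                            torch.tensor(t, dtype=torch.float))
    loader = DataLoader(dataset, batch_size=32, shuffle=True)

    for s_b, a_b, r_b, next_s_b, t_b in loader:
        break  # one batch of (state, action, reward, next state, terminal)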