Example #1
 def __init__(self, data_dir, config):
     self.time_step = 0
     self.cfg = config
     self.db = DBQuery(data_dir, config)
     self.topic = ''
     self.evaluator = MultiWozEvaluator(data_dir)
     self.lock_evalutor = False
Example #2
    def create_dataset_global(self, part, file_dir, data_dir, cfg, db):
        datas = self.data[part]
        goals = self.goal[part]
        s_usr, s_sys, r_g, next_s_usr, next_s_sys, t = [], [], [], [], [], []
        evaluator = MultiWozEvaluator(data_dir)
        for idx, turn_data in enumerate(datas):
            if turn_data['others']['turn'] % 2 == 0:
                if turn_data['others']['turn'] == 0:
                    current_goal = goals[turn_data['others']['session_id']]
                    evaluator.add_goal(current_goal)
                else:
                    next_s_usr.append(s_usr[-1])
                
                if turn_data['others']['change'] and evaluator.cur_domain:
                    if 'final' in current_goal[evaluator.cur_domain]:
                        for key in current_goal[evaluator.cur_domain]['final']:
                            current_goal[evaluator.cur_domain][key] = current_goal[evaluator.cur_domain]['final'][key]
                        del(current_goal[evaluator.cur_domain]['final'])
                
                turn_data['user_goal'] = deepcopy(current_goal)
                s_usr.append(torch.Tensor(state_vectorize_user(turn_data, cfg, evaluator.cur_domain)))
                evaluator.add_usr_da(turn_data['trg_user_action'])
                    
                if turn_data['others']['terminal']:
                    next_turn_data = deepcopy(turn_data)
                    next_turn_data['others']['turn'] = -1
                    next_turn_data['user_action'] = turn_data['trg_user_action']
                    next_turn_data['sys_action'] = datas[idx+1]['trg_sys_action']
                    next_turn_data['trg_user_action'] = {}
                    next_turn_data['goal_state'] = datas[idx+1]['final_goal_state']
                    next_s_usr.append(torch.Tensor(state_vectorize_user(next_turn_data, cfg, evaluator.cur_domain)))
            
            else:
                if turn_data['others']['turn'] != 1:
                    next_s_sys.append(s_sys[-1])

                s_sys.append(torch.Tensor(state_vectorize(turn_data, cfg, db, True)))
                evaluator.add_sys_da(turn_data['trg_sys_action'])
            
                if turn_data['others']['terminal']:
                    next_turn_data = deepcopy(turn_data)
                    next_turn_data['others']['turn'] = -1
                    next_turn_data['user_action'] = {}
                    next_turn_data['sys_action'] = turn_data['trg_sys_action']
                    next_turn_data['trg_sys_action'] = {}
                    next_turn_data['belief_state'] = turn_data['final_belief_state']
                    next_s_sys.append(torch.Tensor(state_vectorize(next_turn_data, cfg, db, True)))
                    reward_g = 20 if evaluator.task_success() else -5
                    r_g.append(reward_g)
                    t.append(1)
                else:
                    reward_g = 5 if evaluator.cur_domain and evaluator.domain_success(evaluator.cur_domain) else -1
                    r_g.append(reward_g)
                    t.append(0)
                
        torch.save((s_usr, s_sys, r_g, next_s_usr, next_s_sys, t), file_dir)
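The six lists saved above can be reloaded directly with torch.load. A minimal sketch, assuming the dataset was written to a hypothetical path 'train_global.pt':

# Minimal sketch: reload the global dataset written by create_dataset_global.
# 'train_global.pt' is a hypothetical path; in practice the file_dir argument is used.
import torch

s_usr, s_sys, r_g, next_s_usr, next_s_sys, t = torch.load('train_global.pt')
print(len(s_sys), len(r_g), len(t))  # one entry per system turn
assert len(s_sys) == len(r_g) == len(t)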
Example #3
    def create_dataset_sys(self, part, file_dir, data_dir, cfg, db):
        datas = self.data[part]
        goals = self.goal[part]
        s, a, r, next_s, t = [], [], [], [], []
        evaluator = MultiWozEvaluator(data_dir)
        for idx, turn_data in enumerate(datas):
            if turn_data['others']['turn'] % 2 == 0:
                if turn_data['others']['turn'] == 0:
                    evaluator.add_goal(
                        goals[turn_data['others']['session_id']])
                evaluator.add_usr_da(turn_data['trg_user_action'])
                continue
            if turn_data['others']['turn'] != 1:
                next_s.append(s[-1])

            s.append(torch.Tensor(state_vectorize(turn_data, cfg, db, True)))
            a.append(
                torch.Tensor(action_vectorize(turn_data['trg_sys_action'],
                                              cfg)))
            evaluator.add_sys_da(turn_data['trg_sys_action'])
            if turn_data['others']['terminal']:
                next_turn_data = deepcopy(turn_data)
                next_turn_data['others']['turn'] = -1
                next_turn_data['user_action'] = {}
                next_turn_data['sys_action'] = turn_data['trg_sys_action']
                next_turn_data['trg_sys_action'] = {}
                next_turn_data['belief_state'] = turn_data[
                    'final_belief_state']
                next_s.append(
                    torch.Tensor(state_vectorize(next_turn_data, cfg, db,
                                                 True)))
                reward = 20 if evaluator.task_success(False) else -5
                r.append(reward)
                t.append(1)
            else:
                reward = 0
                if evaluator.cur_domain:
                    for slot, value in turn_data['belief_state'][
                            evaluator.cur_domain].items():
                        if value == '?':
                            for da in turn_data['trg_sys_action']:
                                d, i, k, p = da.split('-')
                                if i in [
                                        'inform', 'recommend', 'offerbook',
                                        'offerbooked'
                                ] and k == slot:
                                    break
                            else:
                                # the requested slot was not answered
                                reward -= 1
                if not turn_data['trg_sys_action']:
                    reward -= 5
                r.append(reward)
                t.append(0)

        torch.save((s, a, r, next_s, t), file_dir)
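The unanswered-request penalty above relies on Python's for/else: the else branch only runs when the inner loop finishes without hitting break, i.e. when no system act addressed the requested slot. A small standalone sketch of the same idiom, with made-up acts and slot names:

# Hypothetical data illustrating the for/else pattern behind the -1 penalty.
requested_slot = 'phone'
sys_acts = ['restaurant-inform-addr-1', 'restaurant-inform-name-1']

reward = 0
for da in sys_acts:
    d, i, k, p = da.split('-')
    if i in ['inform', 'recommend', 'offerbook', 'offerbooked'] and k == requested_slot:
        break
else:
    # no act answered the requested slot
    reward -= 1
print(reward)  # -1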
Example #4
 def create_dataset_usr(self, part, file_dir, data_dir, cfg, db):
     datas = self.data[part]
     goals = self.goal[part]
     s, a, r, next_s, t = [], [], [], [], []
     evaluator = MultiWozEvaluator(data_dir)
     current_goal = None
     for idx, turn_data in enumerate(datas):
         if turn_data['others']['turn'] % 2 == 1:
             evaluator.add_sys_da(turn_data['trg_sys_action'])
             continue
         
         if turn_data['others']['turn'] == 0:
             current_goal = goals[turn_data['others']['session_id']]
             evaluator.add_goal(current_goal)
         else:
             next_s.append(s[-1])
         if turn_data['others']['change'] and evaluator.cur_domain:
             if 'final' in current_goal[evaluator.cur_domain]:
                 for key in current_goal[evaluator.cur_domain]['final']:
                     current_goal[evaluator.cur_domain][key] = current_goal[evaluator.cur_domain]['final'][key]
                 del(current_goal[evaluator.cur_domain]['final'])
         turn_data['user_goal'] = deepcopy(current_goal)
         
         s.append(torch.Tensor(state_vectorize_user(turn_data, cfg, evaluator.cur_domain)))
         a.append(torch.Tensor(action_vectorize_user(turn_data['trg_user_action'], turn_data['others']['terminal'], cfg)))
         evaluator.add_usr_da(turn_data['trg_user_action'])
         if turn_data['others']['terminal']:
             next_turn_data = deepcopy(turn_data)
             next_turn_data['others']['turn'] = -1
             next_turn_data['user_action'] = turn_data['trg_user_action']
             next_turn_data['sys_action'] = datas[idx+1]['trg_sys_action']
             next_turn_data['trg_user_action'] = {}
             next_turn_data['goal_state'] = datas[idx+1]['final_goal_state']
             next_s.append(torch.Tensor(state_vectorize_user(next_turn_data, cfg, evaluator.cur_domain)))
             reward = 20 if evaluator.inform_F1(ansbysys=False)[1] == 1. else -5
             r.append(reward)
             t.append(1)
         else:
             reward = 0
             if evaluator.cur_domain:
                 for da in turn_data['trg_user_action']:
                     d, i, k = da.split('-')
                     if i == 'request':
                         for slot, value in turn_data['goal_state'][d].items():
                             if value != '?' and slot in turn_data['user_goal'][d]\
                                 and turn_data['user_goal'][d][slot] != '?':
                                 # request before express constraint
                                 reward -= 1
             if not turn_data['trg_user_action']:
                 reward -= 5
             r.append(reward)
             t.append(0)
             
     torch.save((s, a, r, next_s, t), file_dir)
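The saved (s, a, r, next_s, t) lists are ordinary off-policy transitions. A minimal sketch of batching them, assuming fixed-size state/action vectors and a hypothetical file name 'train_usr.pt':

# Minimal sketch: batch the saved user-side transitions for off-policy training.
# 'train_usr.pt' is a hypothetical path; the tensors are assumed to have fixed size.
import torch
from torch.utils.data import DataLoader, TensorDataset

s, a, r, next_s, t = torch.load('train_usr.pt')
dataset = TensorDataset(torch.stack(s), torch.stack(a),
                        torch.tensor(r, dtype=torch.float32),
                        torch.stack(next_s),
                        torch.tensor(t, dtype=torch.float32))
loader = DataLoader(dataset, batch_size=32, shuffle=True)
for batch_s, batch_a, batch_r, batch_next_s, batch_t in loader:
    pass  # feed into the policy/value update of choice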
Example #5
    def __init__(self, env_cls, args, manager, cfg, process_num, character, pre=False, infer=False):
        """
        Dedicated to updating the pretrained model.
        :param env_cls: env class or function, not an instance, as we need to create several instances inside the class.
        :param args:
        :param manager:
        :param cfg:
        :param process_num: process number
        :param character: user or system
        :param pre: set to pretrain mode
        :param infer: set to test mode
        """

        self.process_num = process_num
        self.character = character

        # initialize envs for each process
        self.env_list = []
        for _ in range(process_num):
            self.env_list.append(env_cls())

        # construct policy and value network
        self.policy = MultiDiscretePolicy(cfg, character).to(device=DEVICE)

        if pre:
            self.print_per_batch = args.print_per_batch
            from dbquery import DBQuery
            db = DBQuery(args.data_dir, cfg)
            self.data_train = manager.create_dataset_policy('train', args.batchsz, cfg, db, character)
            self.data_valid = manager.create_dataset_policy('valid', args.batchsz, cfg, db, character)
            self.data_test = manager.create_dataset_policy('test', args.batchsz, cfg, db, character)
            if character == 'sys':
                pos_weight = args.policy_weight_sys * torch.ones([cfg.a_dim]).to(device=DEVICE)
            elif character == 'usr':
                pos_weight = args.policy_weight_usr * torch.ones([cfg.a_dim_usr]).to(device=DEVICE)
            else:
                raise Exception('Unknown character')
            self.multi_entropy_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        else:
            self.evaluator = MultiWozEvaluator(args.data_dir, cfg.d)

        self.save_dir = args.save_dir + '/' + character if pre else args.save_dir
        self.save_per_epoch = args.save_per_epoch
        self.optim_batchsz = args.batchsz
        self.policy.eval()

        self.gamma = args.gamma
        self.policy_optim = optim.RMSprop(self.policy.parameters(), lr=args.lr_policy, weight_decay=args.weight_decay)
        self.writer = SummaryWriter()
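In pretraining mode the multi-label dialogue-act loss up-weights positive labels through BCEWithLogitsLoss(pos_weight=...). A small illustration of that effect; the dimension and weight value below are made up and not the repository's configuration:

# Hypothetical illustration of pos_weight in nn.BCEWithLogitsLoss: positive
# targets are up-weighted, which counteracts the sparsity of active dialogue acts.
import torch
import torch.nn as nn

a_dim = 8                                # hypothetical action dimension
pos_weight = 3.0 * torch.ones([a_dim])   # hypothetical policy_weight value
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

logits = torch.randn(4, a_dim)           # batch of raw policy outputs
target = torch.zeros(4, a_dim)
target[:, 0] = 1.0                       # only one act active per example
loss = criterion(logits, target)         # positive terms weighted 3x
print(loss.item())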
Example #6
class StateTracker(object):
    def __init__(self, data_dir, config):
        self.time_step = 0
        self.cfg = config
        self.db = DBQuery(data_dir, config)
        self.topic = ''
        self.evaluator = MultiWozEvaluator(data_dir)
        self.lock_evalutor = False

    def set_rollout(self, rollout):
        if rollout:
            self.save_time_step = self.time_step
            self.save_topic = self.topic
            self.lock_evalutor = True
        else:
            self.time_step = self.save_time_step
            self.topic = self.save_topic
            self.lock_evalutor = False

    def get_entities(self, s, domain):
        origin = s['belief_state'][domain].items()
        constraint = []
        for k, v in origin:
            if v != '?' and k in self.cfg.mapping[domain]:
                constraint.append((self.cfg.mapping[domain][k], v))
        entities = self.db.query(domain, constraint)
        random.shuffle(entities)
        return entities

    def update_belief_sys(self, old_s, a):
        """
        update belief/goal state with sys action
        """
        s = deepcopy(old_s)
        a_index = torch.nonzero(a)  # get multiple da indices

        self.time_step += 1
        s['others']['turn'] = self.time_step

        # update sys/user dialog act
        s['sys_action'] = dict()

        # update belief part
        das = [self.cfg.idx2da[idx.item()] for idx in a_index]
        das = [da.split('-') for da in das]
        das.sort(key=lambda x: x[0])  # sort by domain

        entities = [] if self.topic == '' else self.get_entities(s, self.topic)
        return_flag = False
        for domain, intent, slot, p in das:
            if domain in self.cfg.belief_domains and domain != self.topic:
                self.topic = domain
                entities = self.get_entities(s, domain)

            da = '-'.join((domain, intent, slot, p))
            if intent == 'request':
                s['sys_action'][da] = '?'
            elif intent in ['nooffer', 'nobook'] and self.topic != '':
                return_flag = True
                if slot in s['belief_state'][self.topic] and s['belief_state'][
                        self.topic][slot] != '?':
                    s['sys_action'][da] = s['belief_state'][self.topic][slot]
                else:
                    s['sys_action'][da] = 'none'
            elif slot == 'choice':
                s['sys_action'][da] = str(len(entities))
            elif slot == 'none':
                s['sys_action'][da] = 'none'
            else:
                num = int(p) - 1
                if self.topic and len(
                        entities) > num and slot in self.cfg.mapping[
                            self.topic]:
                    typ = self.cfg.mapping[self.topic][slot]
                    if typ in entities[num]:
                        s['sys_action'][da] = entities[num][typ]
                    else:
                        s['sys_action'][da] = 'none'
                else:
                    s['sys_action'][da] = 'none'

                if not self.topic:
                    continue
                if intent in [
                        'inform', 'recommend', 'offerbook', 'offerbooked',
                        'book'
                ]:
                    discard(s['belief_state'][self.topic], slot, '?')
                    if slot in s['user_goal'][self.topic] and s['user_goal'][
                            self.topic][slot] == '?':
                        s['goal_state'][self.topic][slot] = s['sys_action'][da]

                # booked
                if intent == 'inform' and slot == 'car':  # taxi
                    if 'booked' not in s['belief_state']['taxi']:
                        s['belief_state']['taxi']['booked'] = 'taxi-booked'
                elif intent in ['offerbooked', 'book'
                                ] and slot == 'ref':  # train
                    if self.topic in ['taxi', 'hospital', 'police']:
                        s['belief_state'][
                            self.topic]['booked'] = f'{self.topic}-booked'
                        s['sys_action'][da] = f'{self.topic}-booked'
                    elif entities:
                        book_domain = entities[0]['ref'].split('-')[0]
                        if 'booked' not in s['belief_state'][
                                book_domain] and entities:
                            s['belief_state'][book_domain][
                                'booked'] = entities[0]['ref']
                            s['sys_action'][da] = entities[0]['ref']

        if return_flag:
            for da in s['user_action']:
                d_usr, i_usr, s_usr = da.split('-')
                if i_usr == 'inform' and d_usr == self.topic:
                    discard(s['belief_state'][d_usr], s_usr)
            reload(s['goal_state'], s['user_goal'], self.topic)

        if not self.lock_evalutor:
            self.evaluator.add_sys_da(s['sys_action'])

        return s

    def update_belief_usr(self, old_s, a):
        """
        update belief/goal state with user action
        """
        s = deepcopy(old_s)
        a_index = torch.nonzero(a)  # get multiple da indices

        self.time_step += 1
        s['others']['turn'] = self.time_step
        s['others']['terminal'] = 1 if (self.cfg.a_dim_usr -
                                        1) in a_index else 0

        # update sys/user dialog act
        s['user_action'] = dict()

        # update belief part
        das = [
            self.cfg.idx2da_u[idx.item()] for idx in a_index
            if idx.item() != self.cfg.a_dim_usr - 1
        ]
        das = [da.split('-') for da in das]
        if s['invisible_domains']:
            for da in das:
                if da[0] == s['next_available_domain']:
                    s['next_available_domain'] = s['invisible_domains'][0]
                    s['invisible_domains'].remove(s['next_available_domain'])
                    break
        das.sort(key=lambda x: x[0])  # sort by domain

        for domain, intent, slot in das:
            if domain in self.cfg.belief_domains and domain != self.topic:
                self.topic = domain

            da = '-'.join((domain, intent, slot))
            if intent == 'request':
                s['user_action'][da] = '?'
                s['belief_state'][self.topic][slot] = '?'
            elif slot == 'none':
                s['user_action'][da] = 'none'
            else:
                if self.topic and slot in s['user_goal'][
                        self.topic] and s['user_goal'][domain][slot] != '?':
                    s['user_action'][da] = s['user_goal'][domain][slot]
                else:
                    s['user_action'][da] = 'dont care'

                if not self.topic:
                    continue
                if intent == 'inform':
                    s['belief_state'][domain][slot] = s['user_action'][da]
                    if slot in s['user_goal'][self.topic] and s['user_goal'][
                            self.topic][slot] != '?':
                        discard(s['goal_state'][self.topic], slot)

        if not self.lock_evalutor:
            self.evaluator.add_usr_da(s['user_action'])

        return s

    def reset(self, random_seed=None):
        """
        Args:
            random_seed (int):
        Returns:
            init_state (dict):
        """
        pass

    def step(self, s, sys_a):
        """
        Args:
            s (dict):
            sys_a (vector):
        Returns:
            next_s (dict):
            terminal (bool):
        """
        pass
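Both update_belief_sys and update_belief_usr decode a multi-hot action vector by taking torch.nonzero and mapping each index through cfg.idx2da / cfg.idx2da_u before sorting by domain. A standalone sketch of that decoding with a toy mapping (hypothetical, not the real config):

# Hypothetical decoding of a multi-hot system action vector, mirroring the
# torch.nonzero + idx2da lookup used in update_belief_sys.
import torch

idx2da = {0: 'hotel-inform-name-1', 1: 'hotel-request-area-1', 2: 'general-bye-none-none'}
a = torch.zeros(len(idx2da))
a[0] = 1
a[2] = 1

a_index = torch.nonzero(a)                       # indices of the active acts
das = [idx2da[idx.item()] for idx in a_index]
das = [da.split('-') for da in das]
das.sort(key=lambda x: x[0])                     # sort by domain
print(das)  # [['general', 'bye', 'none', 'none'], ['hotel', 'inform', 'name', '1']]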
Example #7
    def create_dataset_global(self, part, file_dir, data_dir, cfg, db):
        """
        Create the global dataset, which records all user-side and system-side states together with the global rewards.
        """
        datas = self.data[part]
        goals = self.goal[part]
        s_usr, s_sys, r_g, next_s_usr, next_s_sys, t = [], [], [], [], [], []
        evaluator = MultiWozEvaluator(data_dir, cfg.d)
        for idx, turn_data in enumerate(datas):
            if turn_data['others']['turn'] % 2 == 0:
                if turn_data['others']['turn'] == 0:
                    current_goal = goals[turn_data['others']['session_id']]
                    evaluator.add_goal(current_goal)
                else:
                    next_s_usr.append(s_usr[-1])

                # when the current user goal cannot be satisfied, switch to the fallback ('final') goal
                if turn_data['others']['change'] and evaluator.cur_domain:
                    if 'final' in current_goal[evaluator.cur_domain]:
                        for key in current_goal[evaluator.cur_domain]['final']:
                            current_goal[
                                evaluator.cur_domain][key] = current_goal[
                                    evaluator.cur_domain]['final'][key]
                        del (current_goal[evaluator.cur_domain]['final'])
                turn_data['user_goal'] = deepcopy(current_goal)

                s_usr.append(
                    torch.Tensor(
                        state_vectorize_user(turn_data, cfg,
                                             evaluator.cur_domain)))
                evaluator.add_usr_da(turn_data['trg_user_action'])

                if turn_data['others']['terminal']:
                    next_turn_data = deepcopy(turn_data)
                    next_turn_data['others']['turn'] = -1
                    next_turn_data['user_action'] = turn_data[
                        'trg_user_action']
                    next_turn_data['sys_action'] = datas[idx +
                                                         1]['trg_sys_action']
                    next_turn_data['trg_user_action'] = {}
                    next_turn_data['goal_state'] = datas[idx +
                                                         1]['final_goal_state']
                    next_s_usr.append(
                        torch.Tensor(
                            state_vectorize_user(next_turn_data, cfg,
                                                 evaluator.cur_domain)))

            else:
                if turn_data['others']['turn'] != 1:
                    next_s_sys.append(s_sys[-1])

                s_sys.append(
                    torch.Tensor(state_vectorize(turn_data, cfg, db, True)))
                evaluator.add_sys_da(turn_data['trg_sys_action'])

                if turn_data['others']['terminal']:
                    next_turn_data = deepcopy(turn_data)
                    next_turn_data['others']['turn'] = -1
                    next_turn_data['user_action'] = {}
                    next_turn_data['sys_action'] = turn_data['trg_sys_action']
                    next_turn_data['trg_sys_action'] = {}
                    next_turn_data['belief_state'] = turn_data[
                        'final_belief_state']
                    next_s_sys.append(
                        torch.Tensor(
                            state_vectorize(next_turn_data, cfg, db, True)))
                    # in multi-turn dialogues the system is assumed to speak last, so task success
                    # judged on the system turn serves as the global reward
                    reward_g = 20 if evaluator.task_success() else -5
                    r_g.append(reward_g)
                    t.append(1)
                else:
                    # reward domain_success; otherwise apply a small per-turn penalty to encourage
                    # shorter dialogues. TODO: clarify what domain_success means
                    reward_g = 5 if evaluator.cur_domain and evaluator.domain_success(
                        evaluator.cur_domain) else -1
                    r_g.append(reward_g)
                    t.append(0)

        torch.save((s_usr, s_sys, r_g, next_s_usr, next_s_sys, t), file_dir)
Example #8
    def create_dataset_sys(self, part, file_dir, data_dir, cfg, db):
        """
        Create the system-side training data.
        """
        datas = self.data[part]
        goals = self.goal[part]
        # system state + system action + reward + next system state + terminal flag
        s, a, r, next_s, t = [], [], [], [], []
        # the evaluator keeps track of the entire dialogue
        evaluator = MultiWozEvaluator(data_dir, cfg.d)
        for idx, turn_data in enumerate(datas):
            # user turn
            # no training samples are created on the user side; only the evaluator is updated
            if turn_data['others']['turn'] % 2 == 0:
                # load the user goal on the first turn of the dialogue
                if turn_data['others']['turn'] == 0:
                    evaluator.add_goal(
                        goals[turn_data['others']['session_id']])
                # record the ground-truth user action
                evaluator.add_usr_da(turn_data['trg_user_action'])
                continue

            # offset by one position, but this does represent the next-turn state
            if turn_data['others']['turn'] != 1:
                next_s.append(s[-1])

            # vectorize the current turn into a state vector
            s.append(torch.Tensor(state_vectorize(turn_data, cfg, db, True)))
            # vectorize the current system action into an action vector
            a.append(
                torch.Tensor(action_vectorize(turn_data['trg_sys_action'],
                                              cfg)))
            evaluator.add_sys_da(turn_data['trg_sys_action'])
            if turn_data['others']['terminal']:
                # terminal turn
                next_turn_data = deepcopy(turn_data)
                next_turn_data['others']['turn'] = -1
                next_turn_data['user_action'] = {}
                next_turn_data['sys_action'] = turn_data['trg_sys_action']
                next_turn_data['trg_sys_action'] = {}
                next_turn_data['belief_state'] = turn_data[
                    'final_belief_state']
                # record next_s for the terminal transition
                next_s.append(
                    torch.Tensor(state_vectorize(next_turn_data, cfg, db,
                                                 True)))
                # compute the reward: for the system side, the final reward depends on task success,
                # i.e. whether the system fulfilled the booking requests and answered every question raised by the ground-truth user actions
                reward = 20 if evaluator.task_success(False) else -5
                r.append(reward)
                # terminal flag
                t.append(1)
            else:
                reward = 0
                if evaluator.cur_domain:
                    for slot, value in turn_data['belief_state'][
                            evaluator.cur_domain].items():
                        if value == '?':
                            for da in turn_data['trg_sys_action']:
                                d, i, k, p = da.split('-')
                                if i in [
                                        'inform', 'recommend', 'offerbook',
                                        'offerbooked'
                                ] and k == slot:
                                    break
                            else:
                                # the requested slot in the belief_state was not answered: subtract 1
                                reward -= 1
                if not turn_data['trg_sys_action']:
                    # the system produced no action this turn: subtract 5
                    reward -= 5
                r.append(reward)
                t.append(0)

        torch.save((s, a, r, next_s, t), file_dir)