Esempio n. 1
0
    def reset(self, random_seed=None):
        """
        Sample a new user goal and return the initial dialogue state.

        Goals are resampled until one with a non-empty ``domain_ordering`` is
        obtained. When ``random_seed`` is given it is advanced by ``1 << 10``
        between attempts, so each retry asks the generator for a different
        seed.

        Args:
            random_seed: optional seed forwarded to
                ``self.goal_gen.get_user_goal``; ``None`` leaves seeding to the
                generator.

        Returns:
            The state returned by ``self.update_belief_usr`` after applying the
            simulated user's opening action to a freshly initialised session.
        """
        self.time_step = -1
        self.topic = 'NONE'
        while True:
            self.goal = self.goal_gen.get_user_goal(random_seed)
            self._mask_user_goal(self.goal)
            if self.goal['domain_ordering']:
                break
            # Compare against None rather than relying on truthiness: 0 is a
            # valid seed. If the generator is deterministic for a fixed seed,
            # leaving seed 0 unchanged would resample the same (empty) goal
            # forever.
            if random_seed is not None:
                random_seed += 1 << 10

        dummy_state, dummy_goal = init_session(-1, self.cfg)
        init_goal(dummy_goal, self.goal, self.cfg)
        dummy_state['user_goal'] = dummy_goal
        dummy_state['last_user_action'] = dict()
        self.sys_da_stack = []  # history of system dialog acts for this session

        # The user speaks first: encode the goal and pair it with a dummy
        # system dialog act (id 0, length 1) because no system turn exists yet.
        goal_input = torch.LongTensor(
            self.manager.get_goal_id(self.manager.usrgoal2seq(self.goal)))
        goal_len_input = torch.LongTensor([len(goal_input)]).squeeze()
        dummy_sys_da = torch.LongTensor([[0]])
        dummy_sys_da_len = torch.LongTensor([1])
        usr_a, terminal = self.user.select_action(
            goal_input, goal_len_input, dummy_sys_da, dummy_sys_da_len)
        usr_a = self._dict_to_vec(
            self.manager.usrseq2da(self.manager.id2sentence(usr_a), self.goal))
        init_state = self.update_belief_usr(dummy_state, usr_a, terminal)

        return init_state
Esempio n. 2
0
 def reset(self, random_seed=None, saved_goal=None):
     """
     Start a fresh session and return the state after the user's first turn.

     A new ``Goal`` (optionally restored from ``saved_goal``) and its
     ``Agenda`` are created, a dummy session state is initialised from the
     goal's domain goals, and the simulated user's opening action is applied.

     Args:
         random_seed: seed forwarded to ``Goal``.
         saved_goal: optional goal forwarded to ``Goal`` as ``saved_goal``.
     """
     self.time_step = -1
     self.topic = 'NONE'
     self.goal = Goal(self.goal_generator, self._mask_user_goal,
                      seed=random_seed, saved_goal=saved_goal)
     self.agenda = Agenda(self.goal)

     state, session_goal = init_session(-1, self.cfg)
     init_goal(session_goal, self.goal.domain_goals, self.cfg)
     state['user_goal'] = session_goal
     state['last_user_action'] = {}

     # The user opens the dialogue, so there is no system action to react to.
     action_dict, is_terminal = self.predict(None, {})
     action_vec = self._dict_to_vec(action_dict)
     return self.update_belief_usr(state, action_vec, is_terminal)
Esempio n. 3
0
    def reset(self, random_seed=None):
        """
        Reset the simulator for a new session and return the initial state.

        A user goal is sampled. The first entry of its ``domain_ordering``
        becomes the next available domain and the remaining entries become
        invisible. The goal is registered with the evaluator.
        """
        self.last_state = init_belief_state()
        self.time_step = 0
        self.topic = ''
        self.goal = self.goal_gen.get_user_goal(random_seed)

        state, session_goal = init_session(-1, self.cfg)
        init_goal(session_goal, state['goal_state'], self.goal, self.cfg)

        ordering = self.goal['domain_ordering']
        state['next_available_domain'] = ordering[0]
        state['invisible_domains'] = ordering[1:]
        state['user_goal'] = session_goal

        self.evaluator.add_goal(session_goal)
        return state
Esempio n. 4
0
    def reset(self, random_seed=None):
        """
        Create the user goal and agenda for the next dialogue session.

        Args:
            random_seed: seed forwarded to ``Goal`` for goal generation.

        Returns:
            The initial state after it has been updated with the simulated
            user's opening action.
        """
        self.time_step = 0
        # Topic of the current dialogue, i.e. the active domain.
        self.topic = ''

        # Build the user goal with the goal generator, then the user agenda.
        self.goal = Goal(self.goal_generator, seed=random_seed)
        self.agenda = Agenda(self.goal)

        # Initialise the state and the goal. The state is a dict (as opposed
        # to a vectorised state) holding the next domain, all later domains,
        # the user goal, the goal state, etc.
        dummy_state, dummy_goal = init_session(-1, self.cfg)
        init_goal(dummy_goal, dummy_state['goal_state'],
                  self.goal.domain_goals, self.cfg)

        domain_ordering = self.goal.domains
        dummy_state['next_available_domain'] = domain_ordering[0]
        dummy_state['invisible_domains'] = domain_ordering[1:]
        dummy_state['user_goal'] = dummy_goal

        # Register the initial user goal with the evaluator so goal
        # completion can be assessed.
        self.evaluator.add_goal(dummy_goal)

        # The user speaks first, so generate the user's action before any
        # system turn. The last element of the action vector carries the
        # terminal flag.
        usr_a, terminal = self.predict(None, {})
        usr_a = self._dict_to_vec(usr_a)
        usr_a[-1] = 1 if terminal else 0
        # Update the initial state with the user's action.
        init_state = self.update_belief_usr(dummy_state, usr_a)
        return init_state
Esempio n. 5
0
    def reset(self, random_seed=None):
        """
        Build a new Goal and Agenda for the next session.

        Returns the state after the simulated user's opening action has been
        applied. The first goal domain is made available and the remaining
        domains are hidden.
        """
        self.time_step = 0
        self.topic = ''
        self.goal = Goal(self.goal_generator, seed=random_seed)
        self.agenda = Agenda(self.goal)

        state, session_goal = init_session(-1, self.cfg)
        init_goal(session_goal, state['goal_state'],
                  self.goal.domain_goals, self.cfg)

        domains = self.goal.domains
        state['next_available_domain'] = domains[0]
        state['invisible_domains'] = domains[1:]
        state['user_goal'] = session_goal
        self.evaluator.add_goal(session_goal)

        # The user opens the conversation. The terminal flag is stored in the
        # final element of the action vector.
        action, done = self.predict(None, {})
        action_vec = self._dict_to_vec(action)
        action_vec[-1] = 1 if done else 0
        return self.update_belief_usr(state, action_vec)
Esempio n. 6
0
    def _build_data(self, data_dir, data_dir_new, cfg, db):
        """
        Build per-turn training data from the raw dialogue file and dump it.

        Split assignment:
            A session goes to 'valid' when its id is listed in
            ``cfg.val_file``, else to 'test' when listed in
            ``cfg.test_file``, else to 'train'.

        Per turn, a deep copy of ``turn_data`` is appended to
        ``self.data[part]``. It holds the belief and goal state from before
        the turn, the user/system actions, their targets, and ``others``.

        Per session, the full goal is stored in
        ``self.goal[part][session_id]``.

        Output files, per split, in ``data_dir_new``:
            ``<part>.json`` and ``<part>_goal.json``.

        Args:
            data_dir: directory holding ``cfg.data_file`` and the split lists.
            data_dir_new: output directory, created with ``os.makedirs``
                (which raises if it already exists).
            cfg: config providing file names, dialog-act vocabularies and
                ``belief_domains``.
            db: database queried for booking reference numbers.
        """
        data_filename = data_dir + '/' + cfg.data_file
        with open(data_filename, 'r') as f:
            origin_data = json.load(f)

        for part in ['train', 'valid', 'test']:
            self.data[part] = []
            self.goal[part] = {}

        # Each line of a split file yields the text before its first '.'
        # (the session id). Sets make the per-session membership tests below
        # O(1); with lists the session loop was quadratic.
        with open(data_dir + '/' + cfg.val_file) as f:
            val_ids = {line.split('.')[0] for line in f}
        with open(data_dir + '/' + cfg.test_file) as f:
            test_ids = {line.split('.')[0] for line in f}

        for k_sess in origin_data:
            sess = origin_data[k_sess]
            if k_sess in val_ids:
                part = 'valid'
            elif k_sess in test_ids:
                part = 'test'
            else:
                part = 'train'
            turn_data, session_data = init_session(k_sess, cfg)
            belief_state = turn_data['belief_state']
            goal_state = turn_data['goal_state']
            init_goal(session_data, goal_state, sess['goal'], cfg)
            self.goal[part][k_sess] = deepcopy(session_data)
            current_domain = ''
            book_domain = ''
            turn_data['trg_user_action'] = {}
            turn_data['trg_sys_action'] = {}

            for i, turn in enumerate(sess['log']):
                turn_data['others']['turn'] = i
                turn_data['others']['terminal'] = i + 2 >= len(sess['log'])
                da_origin = turn['dialog_act']
                expand_da(da_origin)
                # Snapshot the states as left by the previous turn.
                turn_data['belief_state'] = deepcopy(belief_state)
                turn_data['goal_state'] = deepcopy(goal_state)

                if i % 2 == 0:  # user turn
                    # The previous system target becomes this turn's input.
                    turn_data['sys_action'] = deepcopy(turn_data['trg_sys_action'])
                    del turn_data['trg_sys_action']
                    turn_data['trg_user_action'] = dict()
                    for domint in da_origin:
                        domain_intent = da_origin[domint]
                        _domint = domint.lower()
                        _domain, _intent = _domint.split('-')
                        # Only belief domains update the tracked domain.
                        if _domain in cfg.belief_domains:
                            current_domain = _domain
                        for slot, p, value in domain_intent:
                            _slot = slot.lower()
                            _value = value.strip()
                            _da = '-'.join((_domint, _slot))
                            if _da in cfg.da_usr:
                                turn_data['trg_user_action'][_da] = _value
                                if _intent == 'inform':
                                    inform_da = _domain + '-' + _slot
                                    if inform_da in cfg.inform_da:
                                        belief_state[_domain][_slot] = _value
                                    # A non-requested goal slot the user has
                                    # informed is removed from the goal state.
                                    if inform_da in cfg.inform_da_usr and _slot in session_data[_domain] \
                                            and session_data[_domain][_slot] != '?':
                                        discard(goal_state[_domain], _slot)
                                elif _intent == 'request':
                                    request_da = _domain + '-' + _slot
                                    if request_da in cfg.request_da:
                                        belief_state[_domain][_slot] = '?'

                else:  # system turn
                    # Record booking references found in the turn metadata.
                    book_status = turn['metadata']
                    for domain in cfg.belief_domains:
                        if book_status[domain]['book']['booked']:
                            entity = book_status[domain]['book']['booked'][0]
                            if 'booked' in belief_state[domain]:
                                continue
                            book_domain = domain
                            if domain in ['taxi', 'hospital', 'police']:
                                belief_state[domain]['booked'] = f'{domain}-booked'
                            elif domain == 'train':
                                found = db.query(domain, [('trainID', entity['trainID'])])
                                belief_state[domain]['booked'] = found[0]['ref']
                            else:
                                found = db.query(domain, [('name', entity['name'])])
                                belief_state[domain]['booked'] = found[0]['ref']

                    # The previous user target becomes this turn's input.
                    turn_data['user_action'] = deepcopy(turn_data['trg_user_action'])
                    del turn_data['trg_user_action']
                    turn_data['others']['change'] = False
                    turn_data['trg_sys_action'] = dict()
                    for domint in da_origin:
                        domain_intent = da_origin[domint]
                        _domint = domint.lower()
                        _domain, _intent = _domint.split('-')
                        for slot, p, value in domain_intent:
                            _slot = slot.lower()
                            _value = value.strip()
                            _da = '-'.join((_domint, _slot, p))
                            if _da in cfg.da and current_domain:
                                if _slot == 'ref':
                                    # NOTE(review): presumably raises KeyError
                                    # if 'ref' appears before any booking was
                                    # recorded (book_domain == '').
                                    turn_data['trg_sys_action'][_da] = belief_state[book_domain]['booked']
                                else:
                                    turn_data['trg_sys_action'][_da] = _value
                                if _intent in ['inform', 'recommend', 'offerbook', 'offerbooked', 'book']:
                                    inform_da = current_domain + '-' + _slot
                                    if inform_da in cfg.request_da:
                                        discard(belief_state[current_domain], _slot, '?')
                                    if inform_da in cfg.request_da_usr and _slot in session_data[current_domain] \
                                            and session_data[current_domain][_slot] == '?':
                                        goal_state[current_domain][_slot] = _value
                                elif _intent in ['nooffer', 'nobook']:
                                    # TODO: better transition
                                    # Drop the user's informs for this domain
                                    # and reload its goal state.
                                    for da in turn_data['user_action']:
                                        __domain, __intent, __slot = da.split('-')
                                        if __intent == 'inform' and __domain == current_domain:
                                            discard(belief_state[current_domain], __slot)
                                    turn_data['others']['change'] = True
                                    reload(goal_state, session_data, current_domain)

                if i + 1 == len(sess['log']):
                    turn_data['final_belief_state'] = belief_state
                    turn_data['final_goal_state'] = goal_state

                self.data[part].append(deepcopy(turn_data))

        add_domain_mask(self.data)

        def _set_default(obj):
            # json.dumps hook: sets become lists; anything else is rejected.
            if isinstance(obj, set):
                return list(obj)
            raise TypeError(
                f'Object of type {type(obj).__name__} is not JSON serializable')

        os.makedirs(data_dir_new)
        for part in ['train', 'valid', 'test']:
            # Round-trip through JSON so the in-memory data matches the file
            # exactly (sets become lists).
            with open(data_dir_new + '/' + part + '.json', 'w') as f:
                self.data[part] = json.dumps(self.data[part], default=_set_default)
                f.write(self.data[part])
                self.data[part] = json.loads(self.data[part])
            with open(data_dir_new + '/' + part + '_goal.json', 'w') as f:
                self.goal[part] = json.dumps(self.goal[part], default=_set_default)
                f.write(self.goal[part])
                self.goal[part] = json.loads(self.goal[part])
Esempio n. 7
0
    def _build_data(self, data_dir, data_dir_new, cfg, db):
        """
        Build per-turn training data from the raw dialogue file and dump it.

        Split assignment:
            A session goes to 'valid' when its id is listed in
            ``cfg.val_file``, else to 'test' when listed in
            ``cfg.test_file``, else to 'train'.

        Per turn, a deep copy of ``turn_data`` is appended to
        ``self.data[part]``. The belief state used here has 'inform',
        'request' and 'booked' parts. Previous actions are folded into
        ``turn_data['history']``.

        Per session, the goal is stored in ``self.goal[part][session_id]``.

        Output files, per split, in ``data_dir_new``:
            ``<part>.json`` and ``<part>_goal.json``.

        Args:
            data_dir: directory holding ``cfg.data_file`` and the split lists.
            data_dir_new: output directory, created with ``os.makedirs``
                (which raises if it already exists).
            cfg: config providing file names, dialog-act vocabularies and
                ``belief_domains``.
            db: database queried for booking reference numbers.
        """
        data_filename = data_dir + '/' + cfg.data_file
        with open(data_filename, 'r') as f:
            origin_data = json.load(f)

        for part in ['train', 'valid', 'test']:
            self.data[part] = []
            self.goal[part] = {}

        # Each line of a split file yields the text before its first '.'
        # (the session id). Sets make the per-session membership tests below
        # O(1); with lists the session loop was quadratic.
        with open(data_dir + '/' + cfg.val_file) as f:
            val_ids = {line.split('.')[0] for line in f}
        with open(data_dir + '/' + cfg.test_file) as f:
            test_ids = {line.split('.')[0] for line in f}

        for k_sess in origin_data:
            sess = origin_data[k_sess]
            if k_sess in val_ids:
                part = 'valid'
            elif k_sess in test_ids:
                part = 'test'
            else:
                part = 'train'
            turn_data, session_data = init_session(k_sess, cfg)
            init_goal(session_data, sess['goal'], cfg)
            self.goal[part][k_sess] = session_data
            belief_state = turn_data['belief_state']

            for i, turn in enumerate(sess['log']):
                turn_data['others']['turn'] = i
                turn_data['others']['terminal'] = i + 2 >= len(sess['log'])
                da_origin = turn['dialog_act']
                expand_da(da_origin)
                # Snapshot of the belief state as left by the previous turn.
                turn_data['belief_state'] = deepcopy(belief_state)

                if i % 2 == 0:  # user turn
                    # Fold the previous system action into the history.
                    if 'last_sys_action' in turn_data:
                        turn_data['history']['sys'] = dict(
                            turn_data['history']['sys'],
                            **turn_data['last_sys_action'])
                        del turn_data['last_sys_action']
                    turn_data['last_user_action'] = deepcopy(
                        turn_data['user_action'])
                    turn_data['user_action'] = dict()
                    for domint in da_origin:
                        domain_intent = da_origin[domint]
                        _domint = domint.lower()
                        _domain, _intent = _domint.split('-')
                        # User 'thank' acts are relabelled as 'welcome'.
                        if _intent == 'thank':
                            _intent = 'welcome'
                            _domint = _domain + '-' + _intent
                        for slot, p, value in domain_intent:
                            _slot = slot.lower()
                            _value = value.strip()
                            _da = '-'.join((_domint, _slot, p))
                            if _da in cfg.da_usr:
                                turn_data['user_action'][_da] = _value
                                if _intent == 'inform':
                                    inform_da = _domain + '-' + _slot + '-1'
                                    if inform_da in cfg.inform_da:
                                        belief_state['inform'][_domain][
                                            _slot] = _value
                                elif _intent == 'request':
                                    request_da = _domain + '-' + _slot
                                    if request_da in cfg.request_da:
                                        belief_state['request'][_domain].add(
                                            _slot)

                else:  # system turn
                    # Fold the previous user action into the history.
                    if 'last_user_action' in turn_data:
                        turn_data['history']['user'] = dict(
                            turn_data['history']['user'],
                            **turn_data['last_user_action'])
                        del turn_data['last_user_action']
                    turn_data['last_sys_action'] = deepcopy(
                        turn_data['sys_action'])
                    turn_data['sys_action'] = dict()
                    for domint in da_origin:
                        domain_intent = da_origin[domint]
                        _domint = domint.lower()
                        _domain, _intent = _domint.split('-')
                        for slot, p, value in domain_intent:
                            _slot = slot.lower()
                            _value = value.strip()
                            _da = '-'.join((_domint, _slot, p))
                            if _da in cfg.da:
                                turn_data['sys_action'][_da] = _value
                                # An informed slot is no longer requested.
                                if _intent == 'inform' and _domain in belief_state[
                                        'request']:
                                    belief_state['request'][_domain].discard(
                                        _slot)
                                # A booking reference answers the first domain
                                # that requested 'ref'.
                                elif _intent == 'book' and _slot == 'ref':
                                    for domain in belief_state['request']:
                                        if _slot in belief_state['request'][
                                                domain]:
                                            belief_state['request'][
                                                domain].remove(_slot)
                                            break

                    # Record booking references found in the turn metadata.
                    book_status = turn['metadata']
                    for domain in cfg.belief_domains:
                        if book_status[domain]['book']['booked']:
                            entity = book_status[domain]['book']['booked'][0]
                            if domain == 'taxi':
                                belief_state['booked'][domain] = 'booked'
                            elif domain == 'train':
                                found = db.query(
                                    domain, [('trainID', entity['trainID'])])
                                belief_state['booked'][domain] = found[0][
                                    'ref']
                            else:
                                found = db.query(domain,
                                                 [('name', entity['name'])])
                                belief_state['booked'][domain] = found[0][
                                    'ref']

                if i + 1 == len(sess['log']):
                    turn_data['next_belief_state'] = belief_state

                self.data[part].append(deepcopy(turn_data))

        def _set_default(obj):
            # json.dumps hook: sets become lists; anything else is rejected.
            if isinstance(obj, set):
                return list(obj)
            raise TypeError(
                f'Object of type {type(obj).__name__} is not JSON serializable')

        os.makedirs(data_dir_new)
        for part in ['train', 'valid', 'test']:
            # Round-trip through JSON so the in-memory data matches the file
            # exactly (sets become lists).
            with open(data_dir_new + '/' + part + '.json', 'w') as f:
                self.data[part] = json.dumps(self.data[part],
                                             default=_set_default)
                f.write(self.data[part])
                self.data[part] = json.loads(self.data[part])
            with open(data_dir_new + '/' + part + '_goal.json', 'w') as f:
                self.goal[part] = json.dumps(self.goal[part],
                                             default=_set_default)
                f.write(self.goal[part])
                self.goal[part] = json.loads(self.goal[part])
Esempio n. 8
0
    def _build_data(self, data_dir, data_dir_new, cfg, db):
        """
        Build per-turn session data for the train/valid/test splits.

        Only single-domain sessions are kept: the id must contain 'SNG', and
        the goal must have content in at least one of ``cfg.belief_domains``.

        Each turn record holds:
            belief_state, goal_state, others, sys_action, trg_sys_action,
            user_action, trg_user_action.

        The last turn of a session additionally stores:
            final_belief_state, final_goal_state.

        The docstring this replaces also listed next_available_domain and
        invisible_domains. They are not set in this method; presumably
        ``add_domain_mask`` adds them — verify.

        Output files, per split, in ``data_dir_new``:
            ``<part>.json`` and ``<part>_goal.json``.

        Args:
            data_dir: directory holding ``cfg.data_file`` and the split lists.
            data_dir_new: output directory, created with ``os.makedirs``
                (which raises if it already exists).
            cfg: config providing file names, dialog-act vocabularies and
                ``belief_domains``.
            db: database queried for booking reference numbers.
        """
        data_filename = data_dir + '/' + cfg.data_file
        with open(data_filename, 'r') as f:
            origin_data = json.load(f)

        for part in ['train', 'valid', 'test']:
            self.data[part] = []
            self.goal[part] = {}

        # Each line of a split file yields the text before its first '.'
        # (the session id). Sets make the per-session membership tests below
        # O(1); with lists the session loop was quadratic.
        with open(data_dir + '/' + cfg.val_file) as f:
            val_ids = {line.split('.')[0] for line in f}
        with open(data_dir + '/' + cfg.test_file) as f:
            test_ids = {line.split('.')[0] for line in f}

        num_sess = 0
        for k_sess in origin_data:
            # Skip multi-domain sessions. The check depends only on the id, so
            # do it before building any per-session state.
            if "SNG" not in k_sess:
                continue

            sess = origin_data[k_sess]
            if k_sess in val_ids:
                part = 'valid'
            elif k_sess in test_ids:
                part = 'test'
            else:
                part = 'train'
            turn_data, session_data = init_session(k_sess, cfg)
            belief_state = turn_data['belief_state']
            goal_state = turn_data['goal_state']
            init_goal(session_data, goal_state, sess['goal'], cfg)

            # Skip sessions whose goal has no content in any belief domain.
            if not any(domain in cfg.belief_domains and len(session_data[domain]) > 0
                       for domain in session_data):
                continue

            num_sess += 1

            # The complete user goal. The goal is the final target of the
            # task, whereas goal_state tracks the user goal at a given moment.
            self.goal[part][k_sess] = deepcopy(session_data)
            current_domain = ''
            book_domain = ''
            turn_data['trg_user_action'] = {}
            turn_data['trg_sys_action'] = {}

            for i, turn in enumerate(sess['log']):
                # turn_data is the processed session record; turn is the raw
                # log entry.
                turn_data['others']['turn'] = i
                turn_data['others']['terminal'] = i + 2 >= len(sess['log'])
                # Dialog acts of this turn.
                da_origin = turn['dialog_act']
                expand_da(da_origin)
                turn_data['belief_state'] = deepcopy(
                    belief_state)  # from previous turn
                turn_data['goal_state'] = deepcopy(goal_state)

                if i % 2 == 0:  # user turn
                    # sys_action is the previous system action and
                    # trg_sys_action the current one, so move trg_sys_action
                    # into sys_action and delete trg_sys_action.
                    turn_data['sys_action'] = deepcopy(
                        turn_data['trg_sys_action'])
                    del turn_data['trg_sys_action']

                    turn_data['trg_user_action'] = dict()
                    for domint in da_origin:
                        domain_intent = da_origin[domint]
                        _domint = domint.lower()
                        _domain, _intent = _domint.split('-')
                        # Only belief domains update current_domain, so a
                        # non-belief domain (e.g. 'booking') cannot replace it.
                        if _domain in cfg.belief_domains:
                            current_domain = _domain
                        for slot, p, value in domain_intent:
                            _slot = slot.lower()
                            _value = value.strip()
                            _da = '-'.join((_domint, _slot))
                            # Keep only user acts listed in cfg.da_usr; edit
                            # that list to restrict the user action space.
                            if _da in cfg.da_usr:
                                turn_data['trg_user_action'][_da] = _value
                                if _intent == 'inform':
                                    inform_da = _domain + '-' + _slot
                                    # Update belief_state and goal_state
                                    # according to the act.
                                    if inform_da in cfg.inform_da:
                                        belief_state[_domain][_slot] = _value
                                    if inform_da in cfg.inform_da_usr and _slot in session_data[_domain] \
                                            and session_data[_domain][_slot] != '?':
                                        discard(goal_state[_domain], _slot)
                                elif _intent == 'request':
                                    request_da = _domain + '-' + _slot
                                    if request_da in cfg.request_da:
                                        belief_state[_domain][_slot] = '?'

                else:  # system turn
                    # The turn metadata records the booking status.
                    book_status = turn['metadata']
                    for domain in cfg.belief_domains:
                        if book_status[domain]['book']['booked']:
                            entity = book_status[domain]['book']['booked'][0]
                            # Already booked: the booking state need not be
                            # updated again.
                            if 'booked' in belief_state[domain]:
                                continue

                            book_domain = domain
                            if domain in ['taxi', 'hospital', 'police']:
                                belief_state[domain][
                                    'booked'] = f'{domain}-booked'
                            elif domain == 'train':
                                found = db.query(
                                    domain, [('trainID', entity['trainID'])])
                                belief_state[domain]['booked'] = found[0][
                                    'ref']
                            else:
                                found = db.query(domain,
                                                 [('name', entity['name'])])
                                belief_state[domain]['booked'] = found[0][
                                    'ref']

                    # Same shift for the user action: target becomes input.
                    turn_data['user_action'] = deepcopy(
                        turn_data['trg_user_action'])
                    del turn_data['trg_user_action']

                    # Reset the flag so it only affects the next user action.
                    turn_data['others']['change'] = False
                    turn_data['trg_sys_action'] = dict()
                    for domint in da_origin:
                        domain_intent = da_origin[domint]
                        _domint = domint.lower()
                        _domain, _intent = _domint.split('-')
                        for slot, p, value in domain_intent:
                            _slot = slot.lower()
                            _value = value.strip()
                            _da = '-'.join((_domint, _slot, p))
                            if _da in cfg.da and current_domain:
                                if _slot == 'ref':
                                    turn_data['trg_sys_action'][
                                        _da] = belief_state[book_domain][
                                            'booked']
                                else:
                                    turn_data['trg_sys_action'][_da] = _value
                                if _intent in [
                                        'inform', 'recommend', 'offerbook',
                                        'offerbooked', 'book'
                                ]:
                                    inform_da = current_domain + '-' + _slot
                                    if inform_da in cfg.request_da:
                                        discard(belief_state[current_domain],
                                                _slot, '?')
                                    if inform_da in cfg.request_da_usr and _slot in session_data[current_domain] \
                                            and session_data[current_domain][_slot] == '?':
                                        goal_state[current_domain][
                                            _slot] = _value
                                elif _intent in ['nooffer', 'nobook']:
                                    # TODO: better transition
                                    for da in turn_data['user_action']:
                                        __domain, __intent, __slot = da.split(
                                            '-')
                                        # The system cannot serve the request:
                                        # drop the user's previous informs for
                                        # this domain.
                                        if __intent == 'inform' and __domain == current_domain:
                                            discard(
                                                belief_state[current_domain],
                                                __slot)
                                    # Set the change flag: the user goal has
                                    # changed.
                                    turn_data['others']['change'] = True
                                    reload(goal_state, session_data,
                                           current_domain)
                # Last turn: record the final states.
                if i + 1 == len(sess['log']):
                    turn_data['final_belief_state'] = belief_state
                    turn_data['final_goal_state'] = goal_state

                self.data[part].append(deepcopy(turn_data))
        print("session number: ", num_sess)
        add_domain_mask(self.data)

        def _set_default(obj):
            # json.dumps hook: sets become lists; anything else is rejected.
            if isinstance(obj, set):
                return list(obj)
            raise TypeError(
                f'Object of type {type(obj).__name__} is not JSON serializable')

        # Experiment 1: split by single domain.
        os.makedirs(data_dir_new)
        for part in ['train', 'valid', 'test']:
            # Round-trip through JSON so the in-memory data matches the file
            # exactly (sets become lists).
            with open(data_dir_new + '/' + part + '.json', 'w') as f:
                self.data[part] = json.dumps(self.data[part],
                                             default=_set_default)
                f.write(self.data[part])
                self.data[part] = json.loads(self.data[part])
            with open(data_dir_new + '/' + part + '_goal.json', 'w') as f:
                self.goal[part] = json.dumps(self.goal[part],
                                             default=_set_default)
                f.write(self.goal[part])
                self.goal[part] = json.loads(self.goal[part])