Example #1
    def init(self, obj_schema):
        self.schema = obj_schema
        self.goal_generator = GoalGenerator(self.schema)
        self.goal_schema = self.goal_generator.get_goal()
        self.generate_plan()

        print(self.goal_schema)
Example #2
    def __init__(self, args, manager, config, pretrain=False):

        voc_goal_size, voc_usr_size, voc_sys_size = manager.get_voc_size()
        self.user = VHUS(config, voc_goal_size, voc_usr_size,
                         voc_sys_size).to(device=DEVICE)
        self.optim = optim.Adam(self.user.parameters(), lr=args.lr_simu)
        self.goal_gen = GoalGenerator(
            args.data_dir,
            goal_model_path='processed_data/goal_model.pkl',
            corpus_path=config.data_file)
        self.cfg = config
        self.manager = manager
        self.user.eval()

        if pretrain:
            self.print_per_batch = args.print_per_batch
            self.save_dir = args.save_dir
            self.save_per_epoch = args.save_per_epoch
            seq_goals, seq_usr_dass, seq_sys_dass = manager.data_loader_seg()
            train_goals, train_usrdas, train_sysdas, \
            test_goals, test_usrdas, test_sysdas, \
            val_goals, val_usrdas, val_sysdas = manager.train_test_val_split_seg(
                seq_goals, seq_usr_dass, seq_sys_dass)
            self.data_train = (train_goals, train_usrdas, train_sysdas,
                               args.batchsz)
            self.data_valid = (val_goals, val_usrdas, val_sysdas, args.batchsz)
            self.data_test = (test_goals, test_usrdas, test_sysdas,
                              args.batchsz)
            self.nll_loss = nn.NLLLoss(ignore_index=0)  # PAD=0
            self.bce_loss = nn.BCEWithLogitsLoss()
        else:
            from dbquery import DBQuery
            self.db = DBQuery(args.data_dir)
Example #3
 def __init__(self, data_dir, config):
     super(Controller, self).__init__(data_dir, config)
     self.goal_gen = GoalGenerator(
         data_dir,
         config,
         goal_model_path='processed_data/goal_model.pkl',
         corpus_path=config.data_file
     )  # data_file needs to have train, test and dev parts
Example #4
 def __init__(self, data_dir, cfg):
     super(SystemRule, self).__init__(data_dir, cfg)
     self.last_state = {}
     self.goal_gen = GoalGenerator(
         data_dir,
         cfg,
         goal_model_path='processed_data/goal_model.pkl',
         corpus_path=cfg.data_file)
Example #5
 def __init__(self, data_dir, config):
     """
     Builds the goal generator.
     """
     super(Controller, self).__init__(data_dir, config)
     self.goal_gen = GoalGenerator(
         data_dir,
         config,
         goal_model_path='processed_data_' + config.d +
         '/goal_model.pkl',  # regenerate the data distribution
         corpus_path=config.data_file)
Example #6
    def __init__(self,
                 goal_generator: GoalGenerator,
                 mask_user_goal,
                 seed=None):
        """
        create a new Goal at random
        Args:
            goal_generator (GoalGenerator): Goal Generator.
            mask_user_goal: mask invalid domains in the goal
        """
        while True:
            self.domain_goals = goal_generator.get_user_goal(seed)
            mask_user_goal(self.domain_goals)
            if self.domain_goals['domain_ordering']:
                break
            if seed:
                seed += 1 << 10

        self.domains = list(self.domain_goals['domain_ordering'])
        del self.domain_goals['domain_ordering']

        for domain in self.domains:
            if 'reqt' in self.domain_goals[domain].keys():
                self.domain_goals[domain]['reqt'] = {
                    slot: DEF_VAL_UNK
                    for slot in self.domain_goals[domain]['reqt']
                }

            if 'book' in self.domain_goals[domain].keys():
                self.domain_goals[domain]['booked'] = DEF_VAL_UNK
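
For reference, a minimal sketch of the goal dict these constructors consume. The slots and values below are illustrative assumptions about the MultiWOZ-style layout; only 'domain_ordering', 'reqt', and 'book' are actually referenced by the code above.

# Hypothetical output of goal_generator.get_user_goal(seed):
example_goal = {
    'domain_ordering': ['restaurant', 'hotel'],
    'restaurant': {
        'info': {'food': 'italian', 'area': 'centre'},  # constraints
        'reqt': ['phone', 'address'],                   # requested slots
    },
    'hotel': {
        'info': {'stars': '4'},
        'book': {'day': 'monday', 'people': '2'},       # booking fields
    },
}
# After __init__, each 'reqt' list becomes {slot: DEF_VAL_UNK, ...} and
# any domain carrying 'book' gains a 'booked' entry set to DEF_VAL_UNK.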
Example #7
class Controller(StateTracker):
    def __init__(self, data_dir, config):
        """
        Builds the goal generator.
        """
        super(Controller, self).__init__(data_dir, config)
        self.goal_gen = GoalGenerator(
            data_dir,
            config,
            goal_model_path='processed_data_' + config.d +
            '/goal_model.pkl',  # regenerate the data distribution
            corpus_path=config.data_file)

    def reset(self, random_seed=None):
        """
        Randomly generate a user goal, initialize the state,
        and update the evaluator's target goal.
        """
        self.time_step = 0
        self.topic = ''
        self.goal = self.goal_gen.get_user_goal(random_seed)

        dummy_state, dummy_goal = init_session(-1, self.cfg)
        init_goal(dummy_goal, dummy_state['goal_state'], self.goal, self.cfg)

        domain_ordering = self.goal['domain_ordering']
        dummy_state['next_available_domain'] = domain_ordering[0]
        dummy_state['invisible_domains'] = domain_ordering[1:]

        dummy_state['user_goal'] = dummy_goal
        self.evaluator.add_goal(dummy_goal)

        return dummy_state

    def step_sys(self, s, sys_a):
        """
        Update the system state according to the executed system action.
        """
        # update state with sys_act
        current_s = self.update_belief_sys(s, sys_a)

        return current_s

    def step_usr(self, s, usr_a):
        """
        Update the user state according to the executed user action.
        """
        current_s = self.update_belief_usr(s, usr_a)
        terminal = current_s['others']['terminal']
        return current_s, terminal
Example #8
    def __init__(self, data_dir, cfg):
        super(UserAgenda, self).__init__(data_dir, cfg)
        self.max_turn = 40
        self.max_initiative = 4

        # load the standard value set
        with open(data_dir + '/' + cfg.ontology_file) as f:
            self.stand_value_dict = json.load(f)

        self.goal_generator = GoalGenerator(
            data_dir,
            cfg,
            goal_model_path='processed_data/goal_model.pkl',
            corpus_path=cfg.data_file)

        self.goal = None
        self.agenda = None
Example #9
    def __init__(self, data_dir, cfg):
        self.max_turn = 40
        self.max_initiative = 4
        self.cfg = cfg
        self.db = DBQuery(data_dir)
        
        # load the standard value set
        with open(data_dir + '/' + cfg.ontology_file) as f:
            self.stand_value_dict = json.load(f)

        self.goal_generator = GoalGenerator(data_dir,
                                            goal_model_path='processed_data/goal_model.pkl',
                                            corpus_path=cfg.data_file)

        self.time_step = 0
        self.goal = None
        self.agenda = None
Example #10
class Controller(StateTracker):
    def __init__(self, data_dir, config):
        super(Controller, self).__init__(data_dir, config)
        self.goal_gen = GoalGenerator(
            data_dir,
            config,
            goal_model_path='processed_data/goal_model.pkl',
            corpus_path=config.data_file
        )  # data_file needs to have train, test and dev parts

    def reset(self, random_seed=None):
        """
        init a user goal and return init state
        """
        self.time_step = 0
        self.topic = ''
        self.goal = self.goal_gen.get_user_goal(random_seed)

        dummy_state, dummy_goal = init_session(-1, self.cfg)
        init_goal(dummy_goal, dummy_state['goal_state'], self.goal, self.cfg)

        domain_ordering = self.goal['domain_ordering']
        dummy_state['next_available_domain'] = domain_ordering[0]
        dummy_state['invisible_domains'] = domain_ordering[1:]

        dummy_state['user_goal'] = dummy_goal
        self.evaluator.add_goal(dummy_goal)

        return dummy_state

    def step_sys(self, s, sys_a):
        """
        interact with simulator for one sys-user turn
        """
        # update state with sys_act
        current_s = self.update_belief_sys(s, sys_a)

        return current_s

    def step_usr(self, s, usr_a):
        current_s = self.update_belief_usr(s, usr_a)
        terminal = current_s['others']['terminal']
        return current_s, terminal
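
A hedged sketch of how the Controller above might drive one episode; system_policy and user_policy are hypothetical stand-ins for whatever produces dialog-act vectors in the surrounding project.

# Hypothetical episode loop for the Controller above.
controller = Controller(data_dir, config)
s = controller.reset(random_seed=0)      # sample a goal, get the init state
terminal = False
while not terminal:
    sys_a = system_policy(s)             # assumed system action source
    s = controller.step_sys(s, sys_a)    # fold the system act into the state
    usr_a = user_policy(s)               # assumed user action source
    s, terminal = controller.step_usr(s, usr_a)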
Example #11
    def __init__(self, goal_generator: GoalGenerator, seed=None):
        """
        create a new Goal at random
        Args:
            goal_generator (GoalGenerator): Goal Generator.
        """
        self.domain_goals = goal_generator.get_user_goal(seed)
        self.domains = list(self.domain_goals['domain_ordering'])
        del self.domain_goals['domain_ordering']

        for domain in self.domains:
            if 'reqt' in self.domain_goals[domain].keys():
                self.domain_goals[domain]['reqt'] = {
                    slot: DEF_VAL_UNK
                    for slot in self.domain_goals[domain]['reqt']
                }

            if 'book' in self.domain_goals[domain].keys():
                self.domain_goals[domain]['booked'] = DEF_VAL_UNK
Example #12
class Planner:
    def __init__(self, world):
        self.world = world
        self.plan = []

    def init(self, obj_schema):
        self.schema = obj_schema
        self.goal_generator = GoalGenerator(self.schema)
        self.goal_schema = self.goal_generator.get_goal()
        self.generate_plan()

        print(self.goal_schema)

    def get_goal_schema(self):
        return self.goal_schema

    def generate_plan(self):
        move = [
            self.world.find_entity_by_name('Toyota'), "on.p",
            self.world.find_entity_by_name('Table')
        ]
        self.plan = [utils.rel_to_ulf(move)]

    def next(self):
        return self.plan[0]

    def execute(self):
        self.plan.pop(0)

    def update(self):
        pass


# planner = Planner('($ obj-schema \
#  :header (?x BW-row.n) \
#  :types \
#    !t0 (?x row-of.n \'BW-block.n) \
#  :skeletal-prototype \
#    bw-row1.obj \
#    bw-row2.obj \
#    bw-row3.obj)')
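
A minimal usage sketch for the Planner above, assuming a world object that provides find_entity_by_name and an obj-schema string like the commented-out one:

# Hypothetical driver for the Planner above.
planner = Planner(world)           # `world` supplied by the caller
planner.init(obj_schema)           # builds GoalGenerator and the plan
while planner.plan:                # consume plan steps in order
    step = planner.next()          # peek at the next ULF step
    print(step)
    planner.execute()              # pop the step once it has been acted on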
Example #13
    def __init__(self, goal_generator: GoalGenerator, seed=None):
        """
        Randomly create a user goal, including domain information.
        The user goal records request completion and booking status.
        """
        # randomly generate a user goal
        self.domain_goals = goal_generator.get_user_goal(seed)
        # pull the domains out into a separate list
        self.domains = list(self.domain_goals['domain_ordering'])
        del self.domain_goals['domain_ordering']

        for domain in self.domains:
            # convert 'reqt' from a list to a dict that maps
            # each slot to the DEF_VAL_UNK marker
            if 'reqt' in self.domain_goals[domain].keys():
                self.domain_goals[domain]['reqt'] = {
                    slot: DEF_VAL_UNK
                    for slot in self.domain_goals[domain]['reqt']
                }
            # if the goal contains 'book', add a 'booked'
            # attribute set to DEF_VAL_UNK
            if 'book' in self.domain_goals[domain].keys():
                self.domain_goals[domain]['booked'] = DEF_VAL_UNK
Example #14
    def __init__(self, data_dir, cfg):
        super(UserAgenda, self).__init__(data_dir, cfg)
        # maximum number of dialogue turns (apparently unused);
        # a session only ends on nooffer, nobook, or task_complete
        self.max_turn = 40
        # maximum number of consecutive actions
        self.max_initiative = 4

        # ontology_file (value_set.json) is used only here;
        # value_set records the possible values of every slot in every domain
        with open(data_dir + '/' + cfg.ontology_file) as f:
            self.stand_value_dict = json.load(f)

        # build the user-goal generator from goal_model.pkl, which records
        # the empirical probabilities of domains, bookings, slots, and
        # slot values observed in the real data
        self.goal_generator = GoalGenerator(data_dir,
                                            cfg,
                                            goal_model_path='processed_data_' +
                                            cfg.d + '/goal_model.pkl',
                                            corpus_path=cfg.data_file)

        self.goal = None
        self.agenda = None
Example #15
class SystemRule(StateTracker):
    ''' Rule-based bot, implemented for the MultiWOZ dataset.
        Simulates the system-side agent.
    '''

    recommend_flag = -1
    choice = ""

    def __init__(self, data_dir, cfg):
        super(SystemRule, self).__init__(data_dir, cfg)
        self.last_state = {}
        self.goal_gen = GoalGenerator(data_dir, cfg,
                                      goal_model_path='processed_data'+cfg.d+'/goal_model.pkl',
                                      corpus_path=cfg.data_file)

    def reset(self, random_seed=None):
        self.last_state = init_belief_state()
        self.time_step = 0
        self.topic = ''
        # TODO: why does the simulated system still need a goal?
        self.goal = self.goal_gen.get_user_goal(random_seed)

        dummy_state, dummy_goal = init_session(-1, self.cfg)
        init_goal(dummy_goal, dummy_state['goal_state'], self.goal, self.cfg)

        domain_ordering = self.goal['domain_ordering']
        dummy_state['next_available_domain'] = domain_ordering[0]
        dummy_state['invisible_domains'] = domain_ordering[1:]

        dummy_state['user_goal'] = dummy_goal
        self.evaluator.add_goal(dummy_goal)

        return dummy_state

    def _action_to_dict(self, das):
        da_dict = {}
        for da, value in das.items():
            domain, intent, slot = da.split('-')
            if domain != 'general':
                domain = domain.capitalize()
            if intent in ['inform', 'request']:
                intent = intent.capitalize()
            domint = '-'.join((domain, intent))
            if domint not in da_dict:
                da_dict[domint] = []
            da_dict[domint].append([slot.capitalize(), value])
        return da_dict

    def _dict_to_vec(self, das):
        da_vector = torch.zeros(self.cfg.a_dim, dtype=torch.int32)
        expand_da(das)
        for domint in das:
            pairs = das[domint]
            for slot, p, value in pairs:
                da = '-'.join((domint, slot, p)).lower()
                if da in self.cfg.da2idx:
                    idx = self.cfg.da2idx[da]
                    da_vector[idx] = 1
        return da_vector

    def step(self, s, usr_a):
        """
        interact with simulator for one user-sys turn
        """
        # update state with user_act
        # mainly updates user_action and goal_state
        current_s = self.update_belief_usr(s, usr_a)

        # system-side processing
        da_dict = self._action_to_dict(current_s['user_action'])
        state = self._update_state(da_dict)
        sys_a = self.predict(state)
        sys_a = self._dict_to_vec(sys_a)

        # update state with sys_act
        next_s = self.update_belief_sys(current_s, sys_a)
        return next_s

    def predict(self, state):
        """
        Args:
            State, please refer to util/state.py
        Output:
            DA(Dialog Act), in the form of {act_type1: [[slot_name_1, value_1], [slot_name_2, value_2], ...], ...}
        """

        if self.recommend_flag != -1:
            self.recommend_flag += 1

        self.kb_result = {}

        DA = {}

        if 'user_action' in state and (len(state['user_action']) > 0):
            user_action = state['user_action']
        else:
            user_action = check_diff(self.last_state, state)

        # Debug info for check_diff function

        self.last_state = state

        for user_act in user_action:
            # respond to each user action in turn
            domain, intent_type = user_act.split('-')

            # Respond to general greetings
            if domain == 'general':
                # generic reply
                self._update_greeting(user_act, state, DA)

            # Book taxi for user
            elif domain == 'Taxi':
                # taxi booking reply
                self._book_taxi(user_act, state, DA)

            elif domain == 'Booking':
                self._update_booking(user_act, state, DA)

            # User's talking about other domain
            elif domain != "Train":
                self._update_DA(user_act, user_action, state, DA)

            # Info about train
            else:
                self._update_train(user_act, user_action, state, DA)

            # Judge if user want to book
            self._judge_booking(user_act, user_action, DA)

            if 'Booking-Book' in DA:
                # once the booking is made, other acts are redundant
                if random.random() < 0.5:
                    DA['general-reqmore'] = []
                user_acts = []
                for user_act in DA:
                    if user_act != 'Booking-Book':
                        user_acts.append(user_act)
                for user_act in user_acts:
                    del DA[user_act]

        if DA == {}:
            return {'general-greet': [['none', 'none']]}
        return DA

    def _update_state(self, user_act=None):
        """
        Read user_act and update belief_state, request_state, and user_action.
        """
        if not isinstance(user_act, dict):
            raise Exception('Expect user_act to be <class \'dict\'> type but get {}.'.format(type(user_act)))
        previous_state = self.last_state
        new_belief_state = copy.deepcopy(previous_state['belief_state'])
        new_request_state = copy.deepcopy(previous_state['request_state'])
        for domain_type in user_act.keys():
            domain, tpe = domain_type.lower().split('-')
            if domain in ['unk', 'general', 'booking']:
                # skip domains that never update the state
                continue
            if tpe == 'inform':
                # an inform act updates belief_state
                for k, v in user_act[domain_type]:
                    k = REF_SYS_DA[domain.capitalize()].get(k, k)
                    if k is None:
                        continue
                    try:
                        assert domain in new_belief_state
                    except:
                        raise Exception('Error: domain <{}> not in new belief state'.format(domain))
                    domain_dic = new_belief_state[domain]
                    assert 'semi' in domain_dic
                    assert 'book' in domain_dic

                    if k in domain_dic['semi']:
                        nvalue = v
                        new_belief_state[domain]['semi'][k] = nvalue
                    elif k in domain_dic['book']:
                        new_belief_state[domain]['book'][k] = v
                    elif k.lower() in domain_dic['book']:
                        new_belief_state[domain]['book'][k.lower()] = v
                    elif k == 'trainID' and domain == 'train':
                        new_belief_state[domain]['book'][k] = v
                    else:
                        # raise Exception('unknown slot name <{}> of domain <{}>'.format(k, domain))
                        with open('unknown_slot.log', 'a+') as f:
                            f.write('unknown slot name <{}> of domain <{}>\n'.format(k, domain))
            elif tpe == 'request':
                # a request act updates request_state
                for k, v in user_act[domain_type]:
                    k = REF_SYS_DA[domain.capitalize()].get(k, k)
                    if domain not in new_request_state:
                        new_request_state[domain] = {}
                    if k not in new_request_state[domain]:
                        new_request_state[domain][k] = 0

        new_state = copy.deepcopy(previous_state)
        new_state['belief_state'] = new_belief_state
        new_state['request_state'] = new_request_state
        new_state['user_action'] = user_act

        return new_state

    def _update_greeting(self, user_act, state, DA):
        """ General request / inform. """
        _, intent_type = user_act.split('-')

        # Respond to goodbye
        if intent_type == 'bye':
            if 'general-bye' not in DA:
                DA['general-bye'] = []
            if random.random() < 0.3:
                if 'general-welcome' not in DA:
                    DA['general-welcome'] = []
        elif intent_type == 'thank':
            DA['general-welcome'] = []

    def _book_taxi(self, user_act, state, DA):
        """ Book a taxi for user.
            Taxi needs no database query, so once the required
            information has been collected the booking always succeeds.
        """

        blank_info = []
        for info in ['departure', 'destination']:
            if state['belief_state']['taxi']['semi'][info] == "":
                info = REF_USR_DA['Taxi'].get(info, info)
                blank_info.append(info)
        if state['belief_state']['taxi']['semi']['leaveAt'] == "" and state['belief_state']['taxi']['semi'][
            'arriveBy'] == "":
            blank_info += ['Leave', 'Arrive']

        # Finish booking, tell user car type and phone number
        if len(blank_info) == 0:
            # departure, destination, leaveAt, and arriveBy are all
            # filled in, so the booking can proceed
            if 'Taxi-Inform' not in DA:
                DA['Taxi-Inform'] = []
            car = generate_car()
            phone_num = generate_phone_num(11)
            DA['Taxi-Inform'].append(['Car', car])
            DA['Taxi-Inform'].append(['Phone', phone_num])
            return

        # Need essential info to finish booking
        # otherwise essential booking information is missing, so issue a Request
        request_num = random.randint(0, 999999) % len(blank_info) + 1
        if 'Taxi-Request' not in DA:
            DA['Taxi-Request'] = []
        for i in range(request_num):
            slot = REF_USR_DA.get(blank_info[i], blank_info[i])
            DA['Taxi-Request'].append([slot, '?'])

    def _update_booking(self, user_act, state, DA):
        pass

    def _update_DA(self, user_act, user_action, state, DA):
        """ Answer user's utterance about any domain other than taxi or train.
            The reply depends on whether the database can satisfy the user's constraints.
        """

        domain, intent_type = user_act.split('-')

        constraints = []
        for slot in state['belief_state'][domain.lower()]['semi']:
            if state['belief_state'][domain.lower()]['semi'][slot] != "":
                constraints.append([slot, state['belief_state'][domain.lower()]['semi'][slot]])

        kb_result = self.db.query(domain.lower(), constraints)
        self.kb_result[domain] = deepcopy(kb_result)

        # Respond to user's request
        if intent_type == 'Request':
            # answer the user's query first, replying from the database
            if self.recommend_flag > 1:
                self.recommend_flag = -1
                self.choice = ""
            elif self.recommend_flag == 1:
                self.recommend_flag = 0
            if (domain + "-Inform") not in DA:
                DA[domain + "-Inform"] = []
            for slot in user_action[user_act]:
                if len(kb_result) > 0:
                    kb_slot_name = REF_SYS_DA[domain].get(slot[0], slot[0])
                    if kb_slot_name in kb_result[0]:
                        DA[domain + "-Inform"].append([slot[0], kb_result[0][kb_slot_name]])
                    else:
                        DA[domain + "-Inform"].append([slot[0], "unknown"])

        else:
            # There's no result matching user's constraint
            # no result satisfies the constraints: return a NoOffer intent
            # and, with some probability, a Request asking whether the user
            # wants to change a constraint
            if len(kb_result) == 0:
                if (domain + "-NoOffer") not in DA:
                    DA[domain + "-NoOffer"] = []

                for slot in state['belief_state'][domain.lower()]['semi']:
                    if state['belief_state'][domain.lower()]['semi'][slot] != "" and \
                            state['belief_state'][domain.lower()]['semi'][slot] != "do n't care":
                        slot_name = REF_USR_DA[domain].get(slot, slot)
                        DA[domain + "-NoOffer"].append([slot_name, state['belief_state'][domain.lower()]['semi'][slot]])

                p = random.random()

                # Ask user if he wants to change constraint
                # hard-coded probability; pick at most three slots and
                # ask whether the user wants to change them
                if p < 0.3:
                    req_num = min(random.randint(0, 999999) % len(DA[domain + "-NoOffer"]) + 1, 3)
                    if domain + "-Request" not in DA:
                        DA[domain + "-Request"] = []
                    for i in range(req_num):
                        slot_name = REF_USR_DA[domain].get(DA[domain + "-NoOffer"][i][0], DA[domain + "-NoOffer"][i][0])
                        DA[domain + "-Request"].append([slot_name, "?"])

            # There's exactly one result matching user's constraint
            elif len(kb_result) == 1:
                # exactly one result satisfies the constraints,
                # so Inform it directly
                # Inform user about this result
                if (domain + "-Inform") not in DA:
                    DA[domain + "-Inform"] = []
                props = []
                for prop in state['belief_state'][domain.lower()]['semi']:
                    props.append(prop)
                property_num = len(props)
                if property_num > 0:
                    info_num = random.randint(0, 999999) % property_num + 1
                    random.shuffle(props)
                    for i in range(info_num):
                        slot_name = REF_USR_DA[domain].get(props[i], props[i])
                        DA[domain + "-Inform"].append([slot_name, kb_result[0][props[i]]])

            # There are multiple results matching user's constraint
            else:
                p = random.random()
                # multiple results: Inform the count, Recommend one of
                # them, and show a few random attributes of that entity
                # Recommend a choice from kb_list
                if True:  # p < 0.3:
                    if (domain + "-Inform") not in DA:
                        DA[domain + "-Inform"] = []
                    if (domain + "-Recommend") not in DA:
                        DA[domain + "-Recommend"] = []
                    DA[domain + "-Inform"].append(["Choice", str(len(kb_result))])
                    idx = random.randint(0, 999999) % len(kb_result)
                    choice = kb_result[idx]
                    if domain in ["Hotel", "Attraction", "Police", "Restaurant"]:
                        DA[domain + "-Recommend"].append(['Name', choice['name']])
                    self.recommend_flag = 0
                    self.candidate = choice
                    props = []
                    for prop in choice:
                        props.append([prop, choice[prop]])
                    prop_num = min(random.randint(0, 999999) % 3, len(props))
                    random.shuffle(props)
                    for i in range(prop_num):
                        slot = props[i][0]
                        string = REF_USR_DA[domain].get(slot, slot)
                        if string in INFORMABLE_SLOTS:
                            DA[domain + "-Recommend"].append([string, str(props[i][1])])

                # Ask user to choose a candidate.
                elif p < 0.5:
                    prop_values = []
                    props = []
                    for prop in kb_result[0]:
                        for candidate in kb_result:
                            if prop not in candidate:
                                continue
                            if candidate[prop] not in prop_values:
                                prop_values.append(candidate[prop])
                        if len(prop_values) > 1:
                            props.append([prop, prop_values])
                        prop_values = []
                    random.shuffle(props)
                    idx = 0
                    while idx < len(props):
                        if props[idx][0] not in SELECTABLE_SLOTS[domain]:
                            props.pop(idx)
                            idx -= 1
                        idx += 1
                    if domain + "-Select" not in DA:
                        DA[domain + "-Select"] = []
                    for i in range(min(len(props[0][1]), 5)):
                        prop_value = REF_USR_DA[domain].get(props[0][0], props[0][0])
                        DA[domain + "-Select"].append([prop_value, props[0][1][i]])

                # Ask user for more constraint
                else:
                    reqs = []
                    for prop in state['belief_state'][domain.lower()]['semi']:
                        if state['belief_state'][domain.lower()]['semi'][prop] == "":
                            prop_value = REF_USR_DA[domain].get(prop, prop)
                            reqs.append([prop_value, "?"])
                    i = 0
                    while i < len(reqs):
                        if reqs[i][0] not in REQUESTABLE_SLOTS:
                            reqs.pop(i)
                            i -= 1
                        i += 1
                    random.shuffle(reqs)
                    if len(reqs) == 0:
                        return
                    req_num = min(random.randint(0, 999999) % len(reqs) + 1, 2)
                    if (domain + "-Request") not in DA:
                        DA[domain + "-Request"] = []
                    for i in range(req_num):
                        req = reqs[i]
                        req[0] = REF_USR_DA[domain].get(req[0], req[0])
                        DA[domain + "-Request"].append(req)

    def _update_train(self, user_act, user_action, state, DA):
        """
        Unlike taxi, train additionally requires the day.
        """
        constraints = []
        for time in ['leaveAt', 'arriveBy']:
            if state['belief_state']['train']['semi'][time] != "":
                constraints.append([time, state['belief_state']['train']['semi'][time]])

        if len(constraints) == 0:
            p = random.random()
            if 'Train-Request' not in DA:
                DA['Train-Request'] = []
            if p < 0.33:
                DA['Train-Request'].append(['Leave', '?'])
            elif p < 0.66:
                DA['Train-Request'].append(['Arrive', '?'])
            else:
                DA['Train-Request'].append(['Leave', '?'])
                DA['Train-Request'].append(['Arrive', '?'])

        if 'Train-Request' not in DA:
            DA['Train-Request'] = []
        for prop in ['day', 'destination', 'departure']:
            if state['belief_state']['train']['semi'][prop] == "":
                slot = REF_USR_DA['Train'].get(prop, prop)
                DA["Train-Request"].append([slot, '?'])
            else:
                constraints.append([prop, state['belief_state']['train']['semi'][prop]])

        kb_result = self.db.query('train', constraints)
        self.kb_result['Train'] = deepcopy(kb_result)

        if user_act == 'Train-Request':
            del (DA['Train-Request'])
            if 'Train-Inform' not in DA:
                DA['Train-Inform'] = []
            for slot in user_action[user_act]:
                slot_name = REF_SYS_DA['Train'].get(slot[0], slot[0])
                try:
                    DA['Train-Inform'].append([slot[0], kb_result[0][slot_name]])
                except:
                    pass
            return
        if len(kb_result) == 0:
            if 'Train-NoOffer' not in DA:
                DA['Train-NoOffer'] = []
            for prop in constraints:
                DA['Train-NoOffer'].append([REF_USR_DA['Train'].get(prop[0], prop[0]), prop[1]])
            if 'Train-Request' in DA:
                del DA['Train-Request']
        elif len(kb_result) >= 1:
            if len(constraints) < 4:
                # constraints are not yet complete
                return
            # all constraints satisfied: drop the Request
            # and directly offer the booking
            if 'Train-Request' in DA:
                del DA['Train-Request']
            if 'Train-OfferBook' not in DA:
                DA['Train-OfferBook'] = []
            for prop in constraints:
                DA['Train-OfferBook'].append([REF_USR_DA['Train'].get(prop[0], prop[0]), prop[1]])

    def _judge_booking(self, user_act, user_action, DA):
        """ If user want to book, return a ref number. """
        if self.recommend_flag > 1:
            self.recommend_flag = -1
            self.choice = ""
        elif self.recommend_flag == 1:
            self.recommend_flag = 0
        domain, _ = user_act.split('-')
        for slot in user_action[user_act]:
            if domain in booking_info and slot[0] in booking_info[domain]:
                if 'Booking-Book' not in DA:
                    if domain in self.kb_result and len(self.kb_result[domain]) > 0:
                        if 'Ref' in self.kb_result[domain][0]:
                            DA['Booking-Book'] = [["Ref", self.kb_result[domain][0]['Ref']]]
                        else:
                            DA['Booking-Book'] = [["Ref", "N/A"]]
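
To make the dialog-act plumbing in step() concrete, here is _action_to_dict reproduced as a standalone, runnable sketch with an illustrative input; the slot value is made up.

# Standalone copy of _action_to_dict from the class above: flat
# 'domain-intent-slot' keys become {'Domain-Intent': [[Slot, value]]}.
def action_to_dict(das):
    da_dict = {}
    for da, value in das.items():
        domain, intent, slot = da.split('-')
        if domain != 'general':
            domain = domain.capitalize()
        if intent in ['inform', 'request']:
            intent = intent.capitalize()
        da_dict.setdefault('-'.join((domain, intent)), []).append(
            [slot.capitalize(), value])
    return da_dict

print(action_to_dict({'hotel-inform-name': 'acorn guest house'}))
# -> {'Hotel-Inform': [['Name', 'acorn guest house']]}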
Example #16
class UserNeural(StateTracker):
    def __init__(self, args, manager, config, pretrain=False):

        voc_goal_size, voc_usr_size, voc_sys_size = manager.get_voc_size()
        self.user = VHUS(config, voc_goal_size, voc_usr_size,
                         voc_sys_size).to(device=DEVICE)
        self.optim = optim.Adam(self.user.parameters(), lr=args.lr_simu)
        self.goal_gen = GoalGenerator(
            args.data_dir,
            goal_model_path='processed_data/goal_model.pkl',
            corpus_path=config.data_file)
        self.cfg = config
        self.manager = manager
        self.user.eval()

        if pretrain:
            self.print_per_batch = args.print_per_batch
            self.save_dir = args.save_dir
            self.save_per_epoch = args.save_per_epoch
            seq_goals, seq_usr_dass, seq_sys_dass = manager.data_loader_seg()
            train_goals, train_usrdas, train_sysdas, \
            test_goals, test_usrdas, test_sysdas, \
            val_goals, val_usrdas, val_sysdas = manager.train_test_val_split_seg(
                seq_goals, seq_usr_dass, seq_sys_dass)
            self.data_train = (train_goals, train_usrdas, train_sysdas,
                               args.batchsz)
            self.data_valid = (val_goals, val_usrdas, val_sysdas, args.batchsz)
            self.data_test = (test_goals, test_usrdas, test_sysdas,
                              args.batchsz)
            self.nll_loss = nn.NLLLoss(ignore_index=0)  # PAD=0
            self.bce_loss = nn.BCEWithLogitsLoss()
        else:
            from dbquery import DBQuery
            self.db = DBQuery(args.data_dir)

    def user_loop(self, data):
        batch_input = to_device(padding_data(data))
        a_weights, t_weights, argu = self.user(batch_input['goals'], batch_input['goals_length'], \
                                         batch_input['posts'], batch_input['posts_length'], batch_input['origin_responses'])

        loss_a, targets_a = 0, batch_input[
            'origin_responses'][:, 1:]  # remove sos_id
        for i, a_weight in enumerate(a_weights):
            loss_a += self.nll_loss(a_weight, targets_a[:, i])
        loss_a /= i
        loss_t = self.bce_loss(t_weights, batch_input['terminal'])
        loss_a += self.cfg.alpha * kl_gaussian(argu)
        return loss_a, loss_t

    def imitating(self, epoch):
        """
        train the user simulator by simple imitation learning (behavioral cloning)
        """
        self.user.train()
        a_loss, t_loss = 0., 0.
        data_train_iter = batch_iter(self.data_train[0], self.data_train[1],
                                     self.data_train[2], self.data_train[3])
        for i, data in enumerate(data_train_iter):
            self.optim.zero_grad()
            loss_a, loss_t = self.user_loop(data)
            a_loss += loss_a.item()
            t_loss += loss_t.item()
            loss = loss_a + loss_t
            loss.backward()
            self.optim.step()

            if (i + 1) % self.print_per_batch == 0:
                a_loss /= self.print_per_batch
                t_loss /= self.print_per_batch
                logging.debug(
                    '<<user simulator>> epoch {}, iter {}, loss_a:{}, loss_t:{}'
                    .format(epoch, i, a_loss, t_loss))
                a_loss, t_loss = 0., 0.

        if (epoch + 1) % self.save_per_epoch == 0:
            self.save(self.save_dir, epoch)
        self.user.eval()

    def imit_test(self, epoch, best):
        """
        provide an unbiased evaluation of the user simulator fit on the training dataset
        """
        a_loss, t_loss = 0., 0.
        data_valid_iter = batch_iter(self.data_valid[0], self.data_valid[1],
                                     self.data_valid[2], self.data_valid[3])
        for i, data in enumerate(data_valid_iter):
            loss_a, loss_t = self.user_loop(data)
            a_loss += loss_a.item()
            t_loss += loss_t.item()

        a_loss /= i
        t_loss /= i
        logging.debug(
            '<<user simulator>> validation, epoch {}, loss_a:{}, loss_t:{}'.
            format(epoch, a_loss, t_loss))
        loss = a_loss + t_loss
        if loss < best:
            logging.info('<<user simulator>> best model saved')
            best = loss
            self.save(self.save_dir, 'best')

        a_loss, t_loss = 0., 0.
        data_test_iter = batch_iter(self.data_test[0], self.data_test[1],
                                    self.data_test[2], self.data_test[3])
        for i, data in enumerate(data_test_iter):
            loss_a, loss_t = self.user_loop(data)
            a_loss += loss_a.item()
            t_loss += loss_t.item()

        a_loss /= i
        t_loss /= i
        logging.debug(
            '<<user simulator>> test, epoch {}, loss_a:{}, loss_t:{}'.format(
                epoch, a_loss, t_loss))
        return best

    def save(self, directory, epoch):
        if not os.path.exists(directory):
            os.makedirs(directory)

        torch.save(self.user.state_dict(),
                   directory + '/' + str(epoch) + '_simulator.mdl')
        logging.info(
            '<<user simulator>> epoch {}: saved network to mdl'.format(epoch))

    def load(self, filename):
        user_mdl = filename + '_simulator.mdl'
        if os.path.exists(user_mdl):
            self.user.load_state_dict(torch.load(user_mdl))
            logging.info(
                '<<user simulator>> loaded checkpoint from file: {}'.format(
                    user_mdl))

    def reset(self, random_seed=None):
        """
        init a user goal and return init state
        """
        self.time_step = -1
        self.topic = 'NONE'
        while True:
            self.goal = self.goal_gen.get_user_goal(random_seed)
            self._mask_user_goal(self.goal)
            if self.goal['domain_ordering']:
                break
            if random_seed:
                random_seed += 1 << 10

        dummy_state, dummy_goal = init_session(-1, self.cfg)
        init_goal(dummy_goal, self.goal, self.cfg)
        dummy_state['user_goal'] = dummy_goal
        dummy_state['last_user_action'] = dict()
        self.sys_da_stack = []  # to save sys da history

        goal_input = torch.LongTensor(
            self.manager.get_goal_id(self.manager.usrgoal2seq(self.goal)))
        goal_len_input = torch.LongTensor([len(goal_input)]).squeeze()
        usr_a, terminal = self.user.select_action(goal_input, goal_len_input,
                                                  torch.LongTensor([[0]]),
                                                  torch.LongTensor(
                                                      [1]))  # dummy sys da
        usr_a = self._dict_to_vec(
            self.manager.usrseq2da(self.manager.id2sentence(usr_a), self.goal))
        init_state = self.update_belief_usr(dummy_state, usr_a, terminal)

        return init_state

    def step(self, s, sys_a):
        """
        interact with simulator for one sys-user turn
        """
        # update state with sys_act
        current_s = self.update_belief_sys(s, sys_a)
        if current_s['others']['terminal']:
            # user has terminated the session at last turn
            usr_a, terminal = torch.zeros(self.cfg.a_dim_usr,
                                          dtype=torch.int32), True
        else:
            goal_input = torch.LongTensor(
                self.manager.get_goal_id(self.manager.usrgoal2seq(self.goal)))
            goal_len_input = torch.LongTensor([len(goal_input)]).squeeze()
            sys_seq_turn = self.manager.sysda2seq(
                self.manager.ref_data2stand(
                    self._action_to_dict(current_s['sys_action'])), self.goal)
            self.sys_da_stack.append(sys_seq_turn)
            sys_seq = self.manager.get_sysda_id(self.sys_da_stack)
            sys_seq_len = torch.LongTensor(
                [max(len(sen), 1) for sen in sys_seq])
            max_sen_len = sys_seq_len.max().item()
            sys_seq = torch.LongTensor(padding(sys_seq, max_sen_len))
            usr_a, terminal = self.user.select_action(goal_input,
                                                      goal_len_input, sys_seq,
                                                      sys_seq_len)
            usr_a = self._dict_to_vec(
                self.manager.usrseq2da(self.manager.id2sentence(usr_a),
                                       self.goal))

        # update state with user_act
        next_s = self.update_belief_usr(current_s, usr_a, terminal)
        return next_s, terminal
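
Finally, a hedged sketch of how the pretraining entry points above might be driven. The args/manager/config objects and num_epochs are assumptions about the surrounding project, not part of the original code.

# Hypothetical pretraining driver for UserNeural above.
simulator = UserNeural(args, manager, config, pretrain=True)
best = float('inf')
for epoch in range(num_epochs):              # num_epochs is an assumption
    simulator.imitating(epoch)               # behavioral cloning on train set
    best = simulator.imit_test(epoch, best)  # validate; track the best loss
simulator.load(args.save_dir + '/best')      # reload the best checkpoint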