Ejemplo n.º 1
0
    def __init__(self, sys_da_voc_json, usr_da_voc_json):
        """Load the dialogue-act vocabularies and derive the state-vector sizes.

        Args:
            sys_da_voc_json: path to a UTF-8 JSON file holding the system
                dialogue-act vocabulary.
            usr_da_voc_json: path to a UTF-8 JSON file holding the user
                dialogue-act vocabulary.
        """
        # Context managers close the vocabulary files as soon as they are
        # parsed instead of leaving the handles to the garbage collector.
        with open(sys_da_voc_json, encoding='utf8') as sys_voc_file:
            self.sys_da_voc = json.load(sys_voc_file)
        with open(usr_da_voc_json, encoding='utf8') as usr_voc_file:
            self.usr_da_voc = json.load(usr_voc_file)
        self.database = Database()

        # Two-way lookup between system dialogue acts and their indices
        self.sys_da2id = {a: i for i, a in enumerate(self.sys_da_voc)}
        self.id2sys_da = {i: a for i, a in enumerate(self.sys_da_voc)}

        # 155 (original author's note; presumably the size for the shipped
        # vocabulary -- TODO confirm)
        self.sys_da_dim = len(self.sys_da_voc)

        # Two-way lookup between user dialogue acts and their indices
        self.usr_da2id = {a: i for i, a in enumerate(self.usr_da_voc)}
        self.id2usr_da = {i: a for i, a in enumerate(self.usr_da_voc)}

        # 142 (original author's note; presumably for the shipped vocabulary)
        self.usr_da_dim = len(self.usr_da_voc)

        # 26 (original author's note)
        # Total number of slot-values over all domains of the default
        # belief_state.
        self.belief_state_dim = sum(
            len(svs) for svs in default_state()['belief_state'].values())

        # NOTE(review): the meaning of the 4 database-result dimensions is not
        # visible in this method -- confirm against the code that fills them.
        self.db_res_dim = 4

        # The trailing +1 is the "terminated" entry.
        self.state_dim = (self.sys_da_dim + self.usr_da_dim +
                          self.belief_state_dim + self.db_res_dim + 1)

        # Per-dialogue fields, unset until filled in elsewhere
        self.cur_domain = None
        self.belief_state = None
        self.db_res = None
Ejemplo n.º 2
0
    def __init__(self):
        """Build the TRADE tracker: read configs, fetch files, load the model.

        Steps, in order: merge the model config with the common config,
        download every file listed in ``TradeDST.model_urls`` that is missing
        (or all of them when ``use_cache`` is falsy), load the ontology and
        the two pickled vocabularies, build ``Trade``, load its trained
        weights and put it in eval mode on the configured device.
        """
        super(TradeDST, self).__init__()
        # load config
        common_config_path = os.path.join(get_config_path(),
                                          TradeDST.common_config_name)
        with open(common_config_path, encoding='utf8') as config_file:
            common_config = json.load(config_file)
        model_config_path = os.path.join(get_config_path(),
                                         TradeDST.model_config_name)
        with open(model_config_path, encoding='utf8') as config_file:
            model_config = json.load(config_file)
        # Values from the common config override the model config.
        model_config.update(common_config)
        self.model_config = model_config
        self.model_config['data_path'] = os.path.join(
            get_data_path(), 'crosswoz/dst_trade_data')
        self.model_config['n_gpus'] = 0 if self.model_config[
            'device'] == 'cpu' else torch.cuda.device_count()
        self.model_config['device'] = torch.device(self.model_config['device'])
        if model_config['load_embedding']:
            # NOTE(review): presumably 300 matches the dimension of the
            # pretrained embeddings -- confirm.
            model_config['hidden_size'] = 300

        # download data
        for model_key, url in TradeDST.model_urls.items():
            dst = os.path.join(self.model_config['data_path'], model_key)
            # Derive the config key under which the local path is stored.
            if model_key.endswith('pth'):
                file_name = 'trained_model_path'
            elif model_key.endswith('pkl'):
                file_name = model_key.rsplit('-', maxsplit=1)[0]
            else:
                file_name = model_key.split('.')[0]  # ontology
            self.model_config[file_name] = dst
            if not os.path.exists(dst) or not self.model_config['use_cache']:
                download_from_url(url, dst)

        # load data & model
        with open(self.model_config['ontology'], 'r',
                  encoding='utf8') as ontology_file:
            ontology = json.load(ontology_file)
        self.all_slots = get_slot_information(ontology)
        self.gate2id = {'ptr': 0, 'none': 1}
        self.id2gate = {id_: gate for gate, id_ in self.gate2id.items()}
        # NOTE(security): these pickles are downloaded files; unpickling runs
        # arbitrary code, so the URLs in ``model_urls`` must be trusted.
        with open(self.model_config['lang'], 'rb') as lang_file:
            self.lang = pickle.load(lang_file)
        with open(self.model_config['mem-lang'], 'rb') as mem_lang_file:
            self.mem_lang = pickle.load(mem_lang_file)

        model = Trade(
            lang=self.lang,
            vocab_size=len(self.lang.index2word),
            hidden_size=self.model_config['hidden_size'],
            dropout=self.model_config['dropout'],
            num_encoder_layers=self.model_config['num_encoder_layers'],
            num_decoder_layers=self.model_config['num_decoder_layers'],
            pad_id=self.model_config['pad_id'],
            slots=self.all_slots,
            num_gates=len(self.gate2id),
            unk_mask=self.model_config['unk_mask'])

        # map_location lets a checkpoint saved on GPU load on a CPU-only
        # machine (and vice versa) by remapping tensors to the target device.
        model.load_state_dict(
            torch.load(self.model_config['trained_model_path'],
                       map_location=self.model_config['device']))

        self.model = model.to(self.model_config['device']).eval()
        print(f'>>> {self.model_config["trained_model_path"]} loaded ...')
        self.state = default_state()
        print('>>> State initialized ...')
Ejemplo n.º 3
0
    def init_session(self, state: Optional[dict] = None) -> None:
        """Start a session by (re)setting ``self.dialogue_state``.

        Args:
            state: see xbot.util.state.default_state. When falsy, a fresh
                default state is used; otherwise a deep copy of ``state`` is
                stored so the caller's object is never shared.
        """
        if state:
            self.dialogue_state = deepcopy(state)
        else:
            self.dialogue_state = default_state()
Ejemplo n.º 4
0
    def __init__(self):
        """Set up the restaurant-domain state machine.

        The machine's states are the restaurant slots of the default belief
        state plus ``"Request"``; one transition per entry of
        ``self.dest2trigger`` is registered from any source state.
        """
        super(FSMDST, self).__init__()
        self.dialogue_state = default_state()
        self.dialogue_state["cur_domain"] = "餐馆"

        restaurant_slots = self.dialogue_state["belief_state"]["餐馆"]
        self.restaurant_states = [*restaurant_slots.keys(), "Request"]
        self.machine = Machine(
            model=self,
            states=self.restaurant_states,
            initial="greet",
        )

        # (destination slot, trigger name, callback name)
        transition_table = (
            ("推荐菜", "inform_restaurant_dish", "set_recommend_dish"),
            ("评分", "inform_restaurant_rating", "set_rating"),
            ("人均消费", "inform_restaurant_avg_cost", "set_avg_cost"),
            ("周边酒店", "inform_restaurant_surrounding_hotel",
             "set_surrounding_hotel"),
            ("周边景点", "inform_restaurant_surrounding_attraction",
             "set_surrounding_attraction"),
            ("周边餐馆", "inform_restaurant_surrounding_restaurant",
             "set_surrounding_restaurant"),
        )
        self.dest2trigger = {
            dest: {"trigger": trigger_name, "callback": callback_name}
            for dest, trigger_name, callback_name in transition_table
        }

        # Each trigger may fire from any state; its callback is passed as
        # the ``before`` hook of the transition.
        for dest, spec in self.dest2trigger.items():
            self.machine.add_transition(
                trigger=spec["trigger"],
                source="*",
                dest=dest,
                before=spec["callback"],
            )
Ejemplo n.º 5
0
    def __init__(self):
        """Create the restaurant state machine and register its transitions.

        States come from the restaurant slots of the default belief state,
        with ``'Request'`` appended. ``self.dest2trigger`` maps each
        destination slot to the trigger name and callback name used for its
        transition.
        """
        super(FSMDST, self).__init__()
        self.dialogue_state = default_state()
        self.dialogue_state['cur_domain'] = '餐馆'

        slot_names = self.dialogue_state['belief_state']['餐馆'].keys()
        self.restaurant_states = list(slot_names) + ['Request']
        self.machine = Machine(model=self,
                               states=self.restaurant_states,
                               initial='greet')

        # Rows are (destination slot, trigger name, callback name).
        rows = [
            ('推荐菜', 'inform_restaurant_dish', 'set_recommend_dish'),
            ('评分', 'inform_restaurant_rating', 'set_rating'),
            ('人均消费', 'inform_restaurant_avg_cost', 'set_avg_cost'),
            ('周边酒店', 'inform_restaurant_surrounding_hotel',
             'set_surrounding_hotel'),
            ('周边景点', 'inform_restaurant_surrounding_attraction',
             'set_surrounding_attraction'),
            ('周边餐馆', 'inform_restaurant_surrounding_restaurant',
             'set_surrounding_restaurant'),
        ]
        self.dest2trigger = {}
        for dest, trigger_name, callback_name in rows:
            self.dest2trigger[dest] = {
                'trigger': trigger_name,
                'callback': callback_name
            }

        # Register every transition as reachable from any source state.
        for dest, entry in self.dest2trigger.items():
            self.machine.add_transition(trigger=entry['trigger'],
                                        source='*',
                                        dest=dest,
                                        before=entry['callback'])
Ejemplo n.º 6
0
    def __init__(self):
        """Load config, fetch data, then build the BERT model and tokenizer.

        After construction the model is in eval mode on the configured
        device, ``self.state`` holds the default state and ``self.domains``
        is the set of its belief-state domain names.
        """
        super(BertDST, self).__init__()
        # load config
        infer_config = self.load_config()

        # download data
        self.download_data(infer_config)

        # Close the ontology file as soon as it is parsed instead of leaking
        # the handle.
        with open(infer_config['cleaned_ontology'], 'r',
                  encoding='utf8') as ontology_file:
            self.ontology = json.load(ontology_file)
        self.model = BertForSequenceClassification.from_pretrained(
            infer_config['model_dir'])
        self.model.to(infer_config['device'])
        self.tokenizer = BertTokenizer.from_pretrained(
            infer_config['model_dir'])
        self.config = infer_config

        self.model.eval()
        self.state = default_state()
        self.domains = set(self.state['belief_state'].keys())
Ejemplo n.º 7
0
 def init_session(self):
     """Reset ``self.state`` to a freshly built default state."""
     fresh_state = default_state()
     self.state = fresh_state
Ejemplo n.º 8
0
                        if absence:
                            return False
                        s_arr = [val for val in plan_values.split(" ") if val]
                        # 只要有一个推荐菜没有被包含就视为不满足约束
                        if [val for val in s_arr if contains(db_values, val)]:
                            return False
                else:
                    plan_values = options["params"]
                    if plan_values is not None:
                        if absence:
                            return False
                        if db_values.find(plan_values) < 0:
                            return False
            return True

        return [item for item in db if func3(item)]


if __name__ == "__main__":
    from pprint import pprint
    from collections import Counter

    # Manual smoke test: query the restaurant domain with the default
    # (unconstrained) belief state and count how often each recommended dish
    # appears across the returned records.
    database = Database()
    state = default_state()
    # Counter tallies directly; no hand-rolled setdefault/+= bookkeeping.
    # Each query result is a (name, record) pair; only the record is used.
    dishes = Counter(
        dish
        for _, record in database.query(state["belief_state"], "餐馆")
        for dish in record["推荐菜"]
    )
    pprint(dishes)
Ejemplo n.º 9
0
    def update(self, usr_da=None):
        """Update belief_state, cur_domain and request_slots for one turn.

        :param usr_da: List[List[intent, domain, slot, value]]; ``None`` is
            treated as "no user acts this turn"
        :return: the updated ``self.state``
        """
        # The declared default is None, which used to be iterated directly
        # (TypeError); normalise it to an empty list of acts.
        if usr_da is None:
            usr_da = []
        sys_da = self.state['system_action']
        sys_intents = [x[0] for x in sys_da]

        # Count the domains mentioned under each intent
        select_domains = Counter(x[1] for x in usr_da if x[0] == 'Select')
        request_domains = Counter(x[1] for x in usr_da if x[0] == 'Request')
        inform_domains = Counter(x[1] for x in usr_da if x[0] == 'Inform')
        sys_domains = Counter(
            x[1] for x in sys_da if x[0] in ['Inform', 'Recommend'])

        # Why the Select domain is preferred:
        # in the dataset, a Select intent dominates the whole utterance; any
        # Inform or Request that accompanies it only supplies supporting
        # information. Inform ranks after Request for the same reason.
        if select_domains:
            self.state['cur_domain'] = select_domains.most_common(1)[0][0]
        elif request_domains:
            self.state['cur_domain'] = request_domains.most_common(1)[0][0]
        elif inform_domains:
            self.state['cur_domain'] = inform_domains.most_common(1)[0][0]
        elif sys_domains:
            self.state['cur_domain'] = sys_domains.most_common(1)[0][0]
        else:
            self.state['cur_domain'] = None

        # It is a real "no offer" only when the system action contains NoOffer
        # and no Inform. An Inform means the system may have proposed
        # something new based on the earlier constraints, so the current
        # domain's constraints are dropped only when the system could offer
        # nothing at all.
        no_offer = 'NoOffer' in sys_intents and 'Inform' not in sys_intents
        # DONE: clean cur domain constraints because no offer
        if no_offer and self.state['cur_domain']:
            # No offer: reset every slot value of the current domain
            cur_domain = self.state['cur_domain']
            self.state['belief_state'][cur_domain] = deepcopy(
                default_state()['belief_state'][cur_domain])

        # DONE: clean request slot
        # An Inform/Recommend in the previous system action answers a request,
        # so the matching [domain, slot] leaves request_slots.
        # The answered pairs do not change inside the loop: build them once.
        answered = [
            x[1:3] for x in sys_da if x[0] in ['Inform', 'Recommend']
        ]
        # Iterate over a copy because the live list is mutated in the loop.
        for domain, slot in deepcopy(self.state['request_slots']):
            if [domain, slot] in answered:
                self.state['request_slots'].remove([domain, slot])

        # DONE: domain switch
        for intent, domain, slot, value in usr_da:
            # ["Select", "酒店", "源领域", "餐馆"] means "find a hotel near the
            # restaurant", so the restaurant is from_domain
            if intent == 'Select':
                from_domain = value
                name = self.state['belief_state'][from_domain]['名称']
                if name:
                    # Same domain as the source domain: the user wants another
                    # entity of that domain, so the old constraints are
                    # cleared before the new one is written.
                    if domain == from_domain:
                        self.state['belief_state'][domain] = deepcopy(
                            default_state()['belief_state'][domain])
                    self.state['belief_state'][domain]['周边{}'.format(
                        from_domain)] = name

        for intent, domain, slot, value in usr_da:
            if intent == 'Inform':
                if slot in [
                        '名称', '游玩时间', '酒店类型', '出发地', '目的地', '评分', '门票', '价格',
                        '人均消费'
                ]:
                    # Single-valued slots are simply overwritten
                    self.state['belief_state'][domain][slot] = value
                elif slot == '推荐菜':
                    # Several recommended dishes may accumulate,
                    # space-separated
                    if not self.state['belief_state'][domain][slot]:
                        self.state['belief_state'][domain][slot] = value
                    else:
                        self.state['belief_state'][domain][slot] += ' ' + value
                elif '酒店设施' in slot:  # ["Inform", "酒店", "酒店设施-吹风机", "是"]
                    # Only a requested facility (value "是") is recorded; the
                    # facility name follows the dash in the slot name.
                    if value == '是':
                        facility = slot.split('-')[1]
                        if not self.state['belief_state'][domain]['酒店设施']:
                            self.state['belief_state'][domain][
                                '酒店设施'] = facility
                        else:
                            self.state['belief_state'][domain][
                                '酒店设施'] += ' ' + facility
            elif intent == 'Request':  # store the newly requested domain-slot
                self.state['request_slots'].append([domain, slot])

        return self.state
Ejemplo n.º 10
0
 def init_session(self, state=None):
     """Start a session by (re)setting ``self.state``.

     :state: see xbot.util.state.default_state; when falsy a fresh default
         state is used, otherwise a deep copy of ``state`` is stored
     """
     if state:
         self.state = deepcopy(state)
     else:
         self.state = default_state()
Ejemplo n.º 11
0
 def __init__(self):
     """Run the parent initialiser, then start from the default state."""
     super(RuleDST, self).__init__()
     initial_state = default_state()
     self.state = initial_state
Ejemplo n.º 12
0
    def __init__(self):
        """Build the TRADE tracker: read configs, fetch files, load the model.

        Steps, in order: merge the model config with the common config,
        download every file listed in ``TradeDST.model_urls`` that is missing
        (or all of them when ``use_cache`` is falsy), load the ontology and
        the two pickled vocabularies, build ``Trade``, load its trained
        weights and put it in eval mode on the configured device.
        """
        super(TradeDST, self).__init__()
        # load config
        common_config_path = os.path.join(get_config_path(),
                                          TradeDST.common_config_name)
        with open(common_config_path, encoding="utf8") as config_file:
            common_config = json.load(config_file)
        model_config_path = os.path.join(get_config_path(),
                                         TradeDST.model_config_name)
        with open(model_config_path, encoding="utf8") as config_file:
            model_config = json.load(config_file)
        # Values from the common config override the model config.
        model_config.update(common_config)
        self.model_config = model_config
        self.model_config["data_path"] = os.path.join(
            get_data_path(), "crosswoz/dst_trade_data")
        self.model_config["n_gpus"] = (0 if self.model_config["device"]
                                       == "cpu" else torch.cuda.device_count())
        self.model_config["device"] = torch.device(self.model_config["device"])
        if model_config["load_embedding"]:
            # NOTE(review): presumably 300 matches the dimension of the
            # pretrained embeddings -- confirm.
            model_config["hidden_size"] = 300

        # download data
        for model_key, url in TradeDST.model_urls.items():
            dst = os.path.join(self.model_config["data_path"], model_key)
            # Derive the config key under which the local path is stored.
            if model_key.endswith("pth"):
                file_name = "trained_model_path"
            elif model_key.endswith("pkl"):
                file_name = model_key.rsplit("-", maxsplit=1)[0]
            else:
                file_name = model_key.split(".")[0]  # ontology
            self.model_config[file_name] = dst
            if not os.path.exists(dst) or not self.model_config["use_cache"]:
                download_from_url(url, dst)

        # load data & model
        with open(self.model_config["ontology"], "r",
                  encoding="utf8") as ontology_file:
            ontology = json.load(ontology_file)
        self.all_slots = get_slot_information(ontology)
        self.gate2id = {"ptr": 0, "none": 1}
        self.id2gate = {id_: gate for gate, id_ in self.gate2id.items()}
        # NOTE(security): these pickles are downloaded files; unpickling runs
        # arbitrary code, so the URLs in ``model_urls`` must be trusted.
        with open(self.model_config["lang"], "rb") as lang_file:
            self.lang = pickle.load(lang_file)
        with open(self.model_config["mem-lang"], "rb") as mem_lang_file:
            self.mem_lang = pickle.load(mem_lang_file)

        model = Trade(
            lang=self.lang,
            vocab_size=len(self.lang.index2word),
            hidden_size=self.model_config["hidden_size"],
            dropout=self.model_config["dropout"],
            num_encoder_layers=self.model_config["num_encoder_layers"],
            num_decoder_layers=self.model_config["num_decoder_layers"],
            pad_id=self.model_config["pad_id"],
            slots=self.all_slots,
            num_gates=len(self.gate2id),
            unk_mask=self.model_config["unk_mask"],
        )

        # map_location lets a checkpoint saved on GPU load on a CPU-only
        # machine (and vice versa) by remapping tensors to the target device.
        model.load_state_dict(
            torch.load(self.model_config["trained_model_path"],
                       map_location=self.model_config["device"]))

        self.model = model.to(self.model_config["device"]).eval()
        print(f'>>> {self.model_config["trained_model_path"]} loaded ...')
        self.state = default_state()
        print(">>> State initialized ...")
Ejemplo n.º 13
0
    def update(self, usr_da=None):
        """Update belief_state, cur_domain and request_slots for one turn.

        :param usr_da: List[List[intent, domain, slot, value]]; ``None`` is
            treated as "no user acts this turn"
        :return: the updated ``self.state``
        """
        # The declared default is None, which used to be iterated directly
        # (TypeError); normalise it to an empty list of acts.
        if usr_da is None:
            usr_da = []
        sys_da = self.state["system_action"]
        sys_intents = [x[0] for x in sys_da]

        # Count the domains mentioned under each intent
        select_domains = Counter(x[1] for x in usr_da if x[0] == "Select")
        request_domains = Counter(x[1] for x in usr_da if x[0] == "Request")
        inform_domains = Counter(x[1] for x in usr_da if x[0] == "Inform")
        sys_domains = Counter(
            x[1] for x in sys_da if x[0] in ["Inform", "Recommend"])

        # Decide the current domain.
        # Why the Select domain is preferred:
        # in the dataset, a Select intent dominates the whole utterance; any
        # Inform or Request that accompanies it only supplies supporting
        # information. Inform ranks after Request for the same reason.
        if select_domains:
            self.state["cur_domain"] = select_domains.most_common(1)[0][0]
        elif request_domains:
            self.state["cur_domain"] = request_domains.most_common(1)[0][0]
        elif inform_domains:
            self.state["cur_domain"] = inform_domains.most_common(1)[0][0]
        elif sys_domains:
            self.state["cur_domain"] = sys_domains.most_common(1)[0][0]
        else:
            self.state["cur_domain"] = None

        # It is a real "no offer" only when the system action contains NoOffer
        # and no Inform. An Inform means the system may have proposed
        # something new based on the earlier constraints, so the current
        # domain's constraints are dropped only when the system could offer
        # nothing at all.
        no_offer = "NoOffer" in sys_intents and "Inform" not in sys_intents
        # DONE: clean cur domain constraints because no offer
        # ISSUE: this is an overly heavy-handed approach
        if no_offer and self.state["cur_domain"]:
            # No offer: reset every slot value of the current domain
            cur_domain = self.state["cur_domain"]
            self.state["belief_state"][cur_domain] = deepcopy(
                default_state()["belief_state"][cur_domain])

        # DONE: clean request slot
        # An Inform/Recommend in the previous system action answers a request,
        # so the matching [domain, slot] leaves request_slots.
        # ISSUE: clearing every resolved slot is too heavy-handed.
        # The answered pairs do not change inside the loop: build them once.
        answered = [
            x[1:3] for x in sys_da if x[0] in ["Inform", "Recommend"]
        ]
        # Iterate over a copy because the live list is mutated in the loop.
        for domain, slot in deepcopy(self.state["request_slots"]):
            if [domain, slot] in answered:
                self.state["request_slots"].remove([domain, slot])

        # DONE: domain switch
        # Depends very heavily on the upstream NLU
        for intent, domain, slot, value in usr_da:
            # ["Select", "酒店", "源领域", "餐馆"] means "find a hotel near the
            # restaurant", so the restaurant is from_domain
            if intent == "Select":
                from_domain = value
                name = self.state["belief_state"][from_domain]["名称"]
                if name:
                    # Same domain as the source domain: the user wants another
                    # entity of that domain, so the old constraints are
                    # cleared before the new one is written.
                    if domain == from_domain:
                        self.state["belief_state"][domain] = deepcopy(
                            default_state()["belief_state"][domain])
                    self.state["belief_state"][domain]["周边{}".format(
                        from_domain)] = name

        for intent, domain, slot, value in usr_da:
            if intent == "Inform":
                if slot in [
                        "名称",
                        "游玩时间",
                        "酒店类型",
                        "出发地",
                        "目的地",
                        "评分",
                        "门票",
                        "价格",
                        "人均消费",
                ]:
                    # Single-valued slots are simply overwritten
                    self.state["belief_state"][domain][slot] = value
                elif slot == "推荐菜":
                    # Several recommended dishes may accumulate,
                    # space-separated
                    if not self.state["belief_state"][domain][slot]:
                        self.state["belief_state"][domain][slot] = value
                    else:
                        self.state["belief_state"][domain][slot] += " " + value
                elif "酒店设施" in slot:  # ["Inform", "酒店", "酒店设施-吹风机", "是"]
                    # Only a requested facility (value "是") is recorded; the
                    # facility name follows the dash in the slot name.
                    if value == "是":
                        facility = slot.split("-")[1]
                        if not self.state["belief_state"][domain]["酒店设施"]:
                            self.state["belief_state"][domain][
                                "酒店设施"] = facility
                        else:
                            self.state["belief_state"][domain][
                                "酒店设施"] += " " + facility
            elif intent == "Request":  # store the newly requested domain-slot
                self.state["request_slots"].append([domain, slot])

        # ISSUE: slot clarification
        # ISSUE: Repeat Intent
        # ISSUE: System State
        # ISSUE: Hidden Slot
        # ISSUE: sibling (same-level) slots and dependent slots

        return self.state
Ejemplo n.º 14
0
 def init_session(self) -> None:
     """Begin a new session by replacing ``self.state`` with the default state."""
     new_state = default_state()
     self.state = new_state