Ejemplo n.º 1
0
class MultiWozVector(Vector):
    """Map MultiWOZ dialog states / dialog acts to flat numpy vectors and back.

    A state vector (see ``state_vectorize``) is the concatenation of:

    * the opponent's last dialog act, multi-hot over the ``voc_opp_file`` vocabulary;
    * our own last dialog act, multi-hot over the ``voc_file`` vocabulary;
    * one flag per 'semi' belief slot of ``belief_domains``, set when the slot is filled;
    * one flag per ``db_domains`` entry, set when its 'booked' field is truthy;
    * a 6-bucket one-hot of the DB match count for each ``db_domains`` entry;
    * a final flag set when the dialog is terminated.
    """

    # Inclusive upper bounds of the first five Train match-count buckets;
    # larger counts fall into the sixth bucket (see ``one_hot_vector``).
    _TRAIN_BUCKET_BOUNDS = (0, 2, 5, 10, 40)

    def __init__(self,
                 voc_file,
                 voc_opp_file,
                 character='sys',
                 intent_file=os.path.join(
                     os.path.dirname(
                         os.path.dirname(
                             os.path.dirname(
                                 os.path.dirname(os.path.abspath(__file__))))),
                     'data/multiwoz/trackable_intent.json')):
        """
        Args:
            voc_file: text file with one delexicalized dialog act per line for
                the side being vectorized (``character``).
            voc_opp_file: same format, for the opposite side.
            character: 'sys' vectorizes from the system's point of view (the
                user is the opponent); any other value uses the user's view.
            intent_file: JSON file with 'informable' and 'requestable' keys.
        """
        self.belief_domains = [
            'Attraction', 'Restaurant', 'Train', 'Hotel', 'Taxi', 'Hospital',
            'Police'
        ]
        self.db_domains = ['Attraction', 'Restaurant', 'Train', 'Hotel']

        with open(intent_file, encoding='utf-8') as f:
            intents = json.load(f)
        self.informable = intents['informable']
        self.requestable = intents['requestable']
        self.db = Database()

        with open(voc_file, encoding='utf-8') as f:
            self.da_voc = f.read().splitlines()
        with open(voc_opp_file, encoding='utf-8') as f:
            self.da_voc_opp = f.read().splitlines()
        self.character = character
        self.generate_dict()

    def generate_dict(self):
        """Build the act <-> index lookup tables and compute vector sizes."""
        self.act2vec = {act: i for i, act in enumerate(self.da_voc)}
        # Built from the vocabulary directly so every index is decodable,
        # even if the vocabulary file contains a duplicated act.
        self.vec2act = dict(enumerate(self.da_voc))
        self.da_dim = len(self.da_voc)
        self.opp2vec = {act: i for i, act in enumerate(self.da_voc_opp)}
        self.da_opp_dim = len(self.da_voc_opp)

        belief_template = default_state()['belief_state']
        self.belief_state_dim = sum(
            len(belief_template[domain.lower()]['semi'])
            for domain in self.belief_domains)

        # opp act + own act + belief flags + booked flags
        # + 6-bucket DB pointer per DB domain + terminal flag
        self.state_dim = (self.da_opp_dim + self.da_dim + self.belief_state_dim +
                          len(self.db_domains) + 6 * len(self.db_domains) + 1)

    def _da_to_multi_hot(self, action, act2idx, dim):
        """Delexicalize and flatten ``action``, then multi-hot encode it via ``act2idx``.

        Acts that are not in the vocabulary are silently ignored.
        """
        vec = np.zeros(dim)
        for da in flat_da(delexicalize_da(action, self.requestable)):
            if da in act2idx:
                vec[act2idx[da]] = 1.
        return vec

    @staticmethod
    def _db_constraint(belief_state, domain):
        """Turn the mapped 'semi' slots of ``domain`` into ``Database.query`` constraints."""
        dom = domain.lower()
        return [(mapping[dom][k], v)
                for k, v in belief_state[dom]['semi'].items()
                if k in mapping[dom]]

    def pointer(self, turn):
        """Encode how many DB entities match the belief state, per DB domain.

        Args:
            turn (dict): belief state (``state['belief_state']``).
        Returns:
            np.array of length ``6 * len(self.db_domains)``: one 6-bucket
            one-hot block per domain, in ``self.db_domains`` order.
        """
        pointer_vector = np.zeros(6 * len(self.db_domains))
        for domain in self.db_domains:
            entities = self.db.query(domain.lower(),
                                     self._db_constraint(turn, domain))
            pointer_vector = self.one_hot_vector(len(entities), domain,
                                                 pointer_vector)
        return pointer_vector

    def one_hot_vector(self, num, domain, vector):
        """Write a 6-bucket one-hot of entity count ``num`` into ``domain``'s slice of ``vector``.

        Buckets are 0, 1, 2, 3, 4, >=5 for every domain except Train, which
        uses 0, 1-2, 3-5, 6-10, 11-40, >40.

        Args:
            num (int): non-negative number of matching DB entities.
            domain (str): an entry of ``self.db_domains``.
            vector (np.array): pointer vector, modified in place.
        Returns:
            vector, for chaining.
        Raises:
            ValueError: if ``num`` is negative.
        """
        if num < 0:
            raise ValueError('entity count must be non-negative, got %r' % num)
        idx = self.db_domains.index(domain)
        # Case-insensitive: db_domains holds 'Train', so the former
        # `domain != 'train'` test never selected the Train bucketing.
        # NOTE: policies trained with that old behaviour saw Train counts in
        # the generic buckets and will see different Train pointer features.
        if domain.lower() == 'train':
            bucket = sum(num > bound for bound in self._TRAIN_BUCKET_BOUNDS)
        else:
            bucket = min(num, 5)
        block = np.zeros(6)
        block[bucket] = 1
        vector[idx * 6:idx * 6 + 6] = block
        return vector

    def state_vectorize(self, state):
        """Vectorize a dialog state.

        Also stores ``state['belief_state']`` in ``self.state``; it is read
        later by ``action_devectorize``.

        Args:
            state (dict):
                Dialog state with 'belief_state', 'user_action',
                'system_action' and 'terminated' entries.
        Returns:
            state_vec (np.array):
                Dialog state vector of length ``self.state_dim``.
        """
        belief = state['belief_state']
        self.state = belief

        if self.character == 'sys':
            opp_action, own_action = state['user_action'], state['system_action']
        else:
            opp_action, own_action = state['system_action'], state['user_action']
        opp_act_vec = self._da_to_multi_hot(opp_action, self.opp2vec,
                                            self.da_opp_dim)
        last_act_vec = self._da_to_multi_hot(own_action, self.act2vec,
                                             self.da_dim)

        belief_state = np.zeros(self.belief_state_dim)
        slot_values = (value for domain in self.belief_domains
                       for value in belief[domain.lower()]['semi'].values())
        for i, value in enumerate(slot_values):
            if value:
                belief_state[i] = 1.

        book = np.array([
            1. if belief[domain.lower()]['book']['booked'] else 0.
            for domain in self.db_domains
        ])

        degree = self.pointer(belief)

        final = 1. if state['terminated'] else 0.

        state_vec = np.r_[opp_act_vec, last_act_vec, belief_state, book,
                          degree, final]
        assert len(state_vec) == self.state_dim
        return state_vec

    def action_devectorize(self, action_vec):
        """Recover a lexicalized dialog act from an action vector.

        Entries equal to 1 select acts via ``self.vec2act``. Slot values are
        filled by ``lexicalize_da`` from DB matches for the belief state saved
        by the most recent ``state_vectorize`` call, which must come first.

        Args:
            action_vec (np.array):
                Dialog act vector
        Returns:
            action (tuple):
                Dialog act
        """
        act_array = [self.vec2act[i] for i, bit in enumerate(action_vec)
                     if bit == 1]
        action = deflat_da(act_array)
        entities = {}
        for domint in action:
            domain, _intent = domint.split('-')
            if domain not in entities and domain.lower() not in (
                    'general', 'booking'):
                entities[domain] = self.db.query(
                    domain.lower(), self._db_constraint(self.state, domain))
        return lexicalize_da(action, entities, self.state, self.requestable)

    def action_vectorize(self, action):
        """Multi-hot encode a dialog act over this side's act vocabulary."""
        return self._da_to_multi_hot(action, self.act2vec, self.da_dim)
Ejemplo n.º 2
0
class GoalGenerator:
    """User goal generator."""
    def __init__(self,
                 goal_model_path=os.path.join(
                     get_root_path(), 'data/multiwoz/goal/new_goal_model.pkl'),
                 corpus_path=None,
                 boldify=False,
                 sample_info_from_trainset=True,
                 sample_reqt_from_trainset=False):
        """
        Load the goal model from ``goal_model_path`` if that file exists,
        otherwise build it from ``corpus_path`` (which also caches it there).

        Args:
            goal_model_path: path to a pickled goal model
            corpus_path: path to a dialog corpus to build a goal model; only
                used when ``goal_model_path`` does not exist
            boldify: highlight some information in the goal message
            sample_info_from_trainset: if True, sample info slots combination from train set, else sample each slot independently
            sample_reqt_from_trainset: if True, sample reqt slots combination from train set, else sample each slot independently
        """
        self.goal_model_path = goal_model_path
        self.corpus_path = corpus_path
        self.db = Database()
        self.boldify = do_boldify if boldify else null_boldify
        self.sample_info_from_trainset = sample_info_from_trainset
        self.sample_reqt_from_trainset = sample_reqt_from_trainset
        # All trains in the DB; train goals take departure/destination from one entry.
        self.train_database = self.db.query('train', [])
        if os.path.exists(self.goal_model_path):
            # NOTE(review): unpickling can execute arbitrary code, so
            # goal_model_path must point at a trusted file.
            with open(self.goal_model_path, 'rb') as f:
                (self.ind_slot_dist, self.ind_slot_value_dist,
                 self.domain_ordering_dist, self.book_dist,
                 self.slots_num_dist, self.slots_combination_dist) = pickle.load(f)
            print('Loading goal model is done')
        else:
            self._build_goal_model()
            print('Building goal model is done')

        # Remove these request slots so sampled police/hospital goals never ask
        # for them; pop() tolerates a model in which they are already absent.
        for domain, slot in (('police', 'postcode'), ('hospital', 'postcode'),
                             ('hospital', 'address')):
            self.ind_slot_dist[domain]['reqt'].pop(slot, None)
            self.ind_slot_value_dist[domain]['reqt'].pop(slot, None)

    def _build_goal_model(self):
        """Estimate the goal-model distributions from ``self.corpus_path`` and cache them.

        The corpus is a JSON object mapping dialog ids to dialogs. Each dialog
        has a 'goal' with one sub-goal per domain and a 'message'. This method
        sets the attributes below, then pickles the tuple (ind_slot_dist,
        ind_slot_value_dist, domain_ordering_dist, book_dist, slots_num_dist,
        slots_combination_dist) to ``self.goal_model_path``.

        Sets:
            domain_ordering_dist: Counter mapping a tuple of domains, ordered by
                their first mention in the goal message, to its probability.
            ind_slot_dist: per domain/section/slot, the number of goals that
                specify the slot (for info/book: with a value not containing
                'care') divided by the number of goals that include the domain.
            ind_slot_value_dist: per-slot value distributions for 'info'/'book';
                for 'reqt', the same per-slot frequency as ind_slot_dist.
            book_dist: per domain, fraction of all dialogs whose goal books it.
            slots_num_dist / slots_combination_dist: raw counts of the number of
                slots and of each sorted slot tuple, per domain/section.
        """
        with open(self.corpus_path, encoding='utf-8') as f:
            dialogs = json.load(f)

        # domain ordering
        def _get_dialog_domains(dialog):
            return [x for x in dialog['goal']
                    if x in domains and len(dialog['goal'][x]) > 0]

        domain_orderings = []
        for dialog in dialogs.values():
            # (index of the first message entry mentioning the domain, domain).
            # Index and domain stay paired, so a domain that is never mentioned
            # is simply left out instead of shifting every later index onto the
            # wrong domain (as zipping two lists of different lengths did).
            first_mentions = []
            for domain in _get_dialog_domains(dialog):
                message = dialog['goal']['message']
                if isinstance(message, str):
                    message = [message]
                for i, m in enumerate(message):
                    m_lower = m.lower()
                    if domain_keywords[domain].lower() in m_lower \
                            or domain.lower() in m_lower:
                        first_mentions.append((i, domain))
                        break
            first_mentions.sort(key=lambda x: x[0])
            domain_orderings.append(tuple(dom for _, dom in first_mentions))

        domain_ordering_cnt = Counter(domain_orderings)
        total_orderings = sum(domain_ordering_cnt.values())
        self.domain_ordering_dist = deepcopy(domain_ordering_cnt)
        for order, cnt in domain_ordering_cnt.items():
            self.domain_ordering_dist[order] = cnt / total_orderings

        # independent goal slot distribution
        ind_slot_value_cnt = {domain: {} for domain in domains}
        domain_cnt = Counter()
        book_cnt = Counter()
        self.slots_combination_dist = {domain: {} for domain in domains}
        self.slots_num_dist = {domain: {} for domain in domains}

        def _count_combination(domain, section, slots):
            """Count one occurrence of the sorted slot tuple ``slots`` and of its length."""
            combination_cnt = self.slots_combination_dist[domain][section]
            num_cnt = self.slots_num_dist[domain][section]
            combination_cnt[slots] = combination_cnt.get(slots, 0) + 1
            num_cnt[len(slots)] = num_cnt.get(len(slots), 0) + 1

        def _count_values(domain, section, slot_values):
            """Tally the slot values of an 'info' or 'book' section.

            Slots whose name contains 'invalid' are skipped. Values containing
            'care' (presumably "dontcare") register the slot but are not counted.
            """
            for slot, value in slot_values.items():
                if 'invalid' in slot:
                    continue
                slot_cnt = ind_slot_value_cnt[domain].setdefault(
                    section, {}).setdefault(slot, Counter())
                if 'care' in value:
                    continue
                slot_cnt[value] += 1

        for dialog in dialogs.values():
            goal = dialog['goal']
            for domain in domains:
                domain_goal = goal[domain]
                if domain_goal != {}:
                    domain_cnt[domain] += 1
                if 'info' in domain_goal:
                    self.slots_combination_dist[domain].setdefault('info', {})
                    self.slots_num_dist[domain].setdefault('info', {})
                    _count_combination(domain, 'info',
                                       tuple(sorted(domain_goal['info'])))
                    _count_values(domain, 'info', domain_goal['info'])
                if 'reqt' in domain_goal:
                    self.slots_combination_dist[domain].setdefault('reqt', {})
                    self.slots_num_dist[domain].setdefault('reqt', {})
                    slots = sorted(domain_goal['reqt'])
                    if domain in ('police', 'hospital') and 'postcode' in slots:
                        # postcode requests are not modeled for police/hospital
                        slots.remove('postcode')
                    else:
                        assert len(slots) > 0, (sorted(domain_goal['reqt']), slots)
                    if slots:
                        _count_combination(domain, 'reqt', tuple(slots))
                    for slot in domain_goal['reqt']:
                        ind_slot_value_cnt[domain].setdefault(
                            'reqt', Counter())[slot] += 1
                if 'book' in domain_goal:
                    book_cnt[domain] += 1
                    _count_values(domain, 'book', domain_goal['book'])

        self.ind_slot_value_dist = deepcopy(ind_slot_value_cnt)
        self.ind_slot_dist = {domain: {} for domain in domains}
        self.book_dist = {}

        def _normalize_values(domain, section):
            """Turn 'info'/'book' value counts into slot frequencies and value distributions."""
            for slot, counter in ind_slot_value_cnt[domain].get(section, {}).items():
                slot_total = sum(counter.values())
                self.ind_slot_dist[domain].setdefault(
                    section, {})[slot] = slot_total / domain_cnt[domain]
                value_dist = self.ind_slot_value_dist[domain][section][slot]
                for val in value_dist:
                    value_dist[val] = counter[val] / slot_total

        for domain in domains:
            _normalize_values(domain, 'info')
            for slot, cnt in ind_slot_value_cnt[domain].get('reqt', {}).items():
                freq = cnt / domain_cnt[domain]
                self.ind_slot_dist[domain].setdefault('reqt', {})[slot] = freq
                self.ind_slot_value_dist[domain]['reqt'][slot] = freq
            _normalize_values(domain, 'book')
            self.book_dist[domain] = book_cnt[domain] / len(dialogs)

        # `with` guarantees the cache file is flushed and closed.
        with open(self.goal_model_path, 'wb') as f:
            pickle.dump((self.ind_slot_dist, self.ind_slot_value_dist,
                         self.domain_ordering_dist, self.book_dist,
                         self.slots_num_dist, self.slots_combination_dist), f)

    def _get_domain_goal(self, domain):
        """Sample a goal for a single domain.

        Resamples until all of the following hold:
        - the 'info' part is non-empty (when the domain has info slots);
        - the info constraints match at least one DB entry, possibly after
          adjustment (see ``_make_info_satisfiable``);
        - the goal has a 'reqt' or a 'book' part.

        Args:
            domain: domain name as used as a key in the goal-model dicts.
        Returns:
            dict with an 'info' key plus some of 'reqt', 'book', 'fail_info'.
        """
        cnt_slot = self.ind_slot_dist[domain]
        cnt_slot_value = self.ind_slot_value_dist[domain]
        pro_book = self.book_dist[domain]

        while True:
            domain_goal = {'info': {}}
            # inform
            if 'info' in cnt_slot:
                domain_goal['info'] = self._sample_info(domain, cnt_slot,
                                                        cnt_slot_value)
                if not domain_goal['info']:
                    continue
            # request
            if 'reqt' in cnt_slot:
                reqt = self._sample_reqt(domain, cnt_slot, domain_goal['info'])
                if reqt:
                    domain_goal['reqt'] = reqt
            # book ('fail_book' is no longer generated, since 2020.8.18)
            if 'book' in cnt_slot and random.random(
            ) < pro_book + pro_correction['book']:
                domain_goal['book'] = self._sample_book(domain, cnt_slot,
                                                        cnt_slot_value)
            # fail_info
            if not self._make_info_satisfiable(domain, domain_goal):
                continue
            # at least there is one request or book
            if 'reqt' in domain_goal or 'book' in domain_goal:
                return domain_goal

    def _sample_info(self, domain, cnt_slot, cnt_slot_value):
        """Sample the 'info' constraints of a domain goal (may return an empty dict)."""
        if self.sample_info_from_trainset:
            combinations = self.slots_combination_dist[domain]['info']
            slots = random.choices(list(combinations.keys()),
                                   list(combinations.values()))[0]
            info = {}
            for slot in slots:
                info[slot] = nomial_sample(cnt_slot_value['info'][slot])
        else:
            info = {}
            for slot in cnt_slot['info']:
                if random.random(
                ) < cnt_slot['info'][slot] + pro_correction['info']:
                    info[slot] = nomial_sample(cnt_slot_value['info'][slot])

        # Never combine 'name' with other constraints: keep either the name
        # alone or the remaining slots.
        if domain in ('hotel', 'restaurant', 'attraction') \
                and 'name' in info and len(info) > 1:
            if random.random() < cnt_slot['info']['name']:
                info = {'name': info['name']}
            else:
                del info['name']

        if domain in ('taxi', 'train'):
            leave_p = cnt_slot['info']['leaveAt']
            arrive_p = cnt_slot['info']['arriveBy']
            # Keep exactly one time constraint, weighted by slot frequency.
            if 'arriveBy' in info and 'leaveAt' in info:
                if random.random() < leave_p / (arrive_p + leave_p):
                    del info['arriveBy']
                else:
                    del info['leaveAt']
            if 'arriveBy' not in info and 'leaveAt' not in info:
                if random.random() < arrive_p / (arrive_p + leave_p):
                    info['arriveBy'] = nomial_sample(
                        cnt_slot_value['info']['arriveBy'])
                else:
                    info['leaveAt'] = nomial_sample(
                        cnt_slot_value['info']['leaveAt'])

        if domain == 'train':
            # Take both endpoints from one DB train so the route exists.
            random_train = random.choice(self.train_database)
            info['departure'] = random_train['departure']
            info['destination'] = random_train['destination']

        if domain == 'taxi':
            if 'departure' not in info:
                info['departure'] = nomial_sample(
                    cnt_slot_value['info']['departure'])
            if 'destination' not in info:
                info['destination'] = nomial_sample(
                    cnt_slot_value['info']['destination'])
            if info['departure'] == info['destination']:
                # Resample one endpoint once, weighted by slot frequency.
                dep_p = cnt_slot['info']['departure']
                dest_p = cnt_slot['info']['destination']
                if random.random() < dep_p / (dep_p + dest_p):
                    info['departure'] = nomial_sample(
                        cnt_slot_value['info']['departure'])
                else:
                    info['destination'] = nomial_sample(
                        cnt_slot_value['info']['destination'])
        return info

    def _sample_reqt(self, domain, cnt_slot, info):
        """Sample the slots to request; slots already constrained in ``info`` are never requested."""
        if self.sample_reqt_from_trainset:
            candidates = {
                slots: weight
                for slots, weight in self.slots_combination_dist[domain]
                ['reqt'].items() if not any(slot in info for slot in slots)
            }
            if not candidates:
                # Every observed combination overlaps the info slots
                # (random.choices would fail on an empty population).
                return []
            return list(
                random.choices(list(candidates.keys()),
                               list(candidates.values()))[0])
        return [
            slot for slot in cnt_slot['reqt']
            if random.random() < cnt_slot['reqt'][slot] +
            pro_correction['reqt'] and slot not in info
        ]

    def _sample_book(self, domain, cnt_slot, cnt_slot_value):
        """Sample booking constraints, then add the slots each domain needs to book."""
        book = {}
        for slot in cnt_slot['book']:
            if random.random(
            ) < cnt_slot['book'][slot] + pro_correction['book']:
                book[slot] = nomial_sample(cnt_slot_value['book'][slot])

        # makes sure that there are all necessary slots for booking
        required = {
            'restaurant': ('time', 'day', 'people'),
            'hotel': ('stay', 'day', 'people'),
        }.get(domain, ())
        for slot in required:
            if slot not in book:
                book[slot] = nomial_sample(cnt_slot_value['book'][slot])

        if domain == 'train' and not book:
            book['people'] = nomial_sample(cnt_slot_value['book']['people'])
        return book

    def _make_info_satisfiable(self, domain, domain_goal, max_trials=100):
        """Ensure ``domain_goal['info']`` matches at least one DB entity.

        If the sampled info has no DB match, up to ``max_trials`` variants from
        ``self._adjust_info`` are tried. On success the variant replaces
        'info'. For non-train domains, the unmatched info is also kept as
        'fail_info', so the user first asks for it and then for 'info'.
        Mutates ``domain_goal`` in place.

        Returns:
            False when no variant matched (the goal should be resampled),
            True otherwise.
        """
        if len(self.db.query(domain, domain_goal['info'].items())) > 0:
            return True
        for _ in range(max_trials):
            adjusted_info = self._adjust_info(domain, domain_goal['info'])
            if len(self.db.query(domain, adjusted_info.items())) > 0:
                if domain != 'train':
                    domain_goal['fail_info'] = domain_goal['info']
                domain_goal['info'] = adjusted_info
                return True
        return False

    def get_user_goal(self):
        domain_ordering = ()
        while len(domain_ordering) <= 0:
            domain_ordering = nomial_sample(self.domain_ordering_dist)
        # domain_ordering = ('restaurant',)

        user_goal = {
            dom: self._get_domain_goal(dom)
            for dom in domain_ordering
        }
        assert len(user_goal.keys()) > 0

        # using taxi to communte between places, removing destination and departure.
        if 'taxi' in domain_ordering:
            places = [
                dom for dom in domain_ordering[:domain_ordering.index('taxi')]
                if dom in
                ['attraction', 'hotel', 'restaurant', 'police', 'hospital']
            ]
            if len(places) >= 1:
                del user_goal['taxi']['info']['destination']
                # if 'reqt' not in user_goal[places[-1]]:
                #     user_goal[places[-1]]['reqt'] = []
                # if 'address' not in user_goal[places[-1]]['reqt']:
                #     user_goal[places[-1]]['reqt'].append('address')
                # the line below introduce randomness by `union`
                # user_goal[places[-1]]['reqt'] = list(set(user_goal[places[-1]].get('reqt', [])).union({'address'}))
                if places[-1] == 'restaurant' and 'book' in user_goal[
                        'restaurant']:
                    user_goal['taxi']['info']['arriveBy'] = user_goal[
                        'restaurant']['book']['time']
                    if 'leaveAt' in user_goal['taxi']['info']:
                        del user_goal['taxi']['info']['leaveAt']
            if len(places) >= 2:
                del user_goal['taxi']['info']['departure']
                # if 'reqt' not in user_goal[places[-2]]:
                #     user_goal[places[-2]]['reqt'] = []
                # if 'address' not in user_goal[places[-2]]['reqt']:
                #     user_goal[places[-2]]['reqt'].append('address')
                # the line below introduce randomness by `union`
                # user_goal[places[-2]]['reqt'] = list(set(user_goal[places[-2]].get('reqt', [])).union({'address'}))

        # match area of attraction and restaurant
        if 'restaurant' in domain_ordering and \
                'attraction' in domain_ordering and \
                'fail_info' not in user_goal['restaurant'] and \
                domain_ordering.index('restaurant') > domain_ordering.index('attraction') and \
                'area' in user_goal['restaurant']['info'] and 'area' in user_goal['attraction']['info']:
            adjusted_restaurant_goal = deepcopy(
                user_goal['restaurant']['info'])
            adjusted_restaurant_goal['area'] = user_goal['attraction']['info'][
                'area']
            if len(
                    self.db.query('restaurant', adjusted_restaurant_goal.items(
                    ))) > 0 and random.random() < 0.5:
                user_goal['restaurant']['info']['area'] = user_goal[
                    'attraction']['info']['area']

        # match day and people of restaurant and hotel
        if 'restaurant' in domain_ordering and 'hotel' in domain_ordering and \
                'book' in user_goal['restaurant'] and 'book' in user_goal['hotel']:
            if random.random() < 0.5:
                user_goal['restaurant']['book']['people'] = user_goal['hotel'][
                    'book']['people']
                if 'fail_book' in user_goal['restaurant']:
                    user_goal['restaurant']['fail_book']['people'] = user_goal[
                        'hotel']['book']['people']
            if random.random() < 1.0:
                user_goal['restaurant']['book']['day'] = user_goal['hotel'][
                    'book']['day']
                if 'fail_book' in user_goal['restaurant']:
                    user_goal['restaurant']['fail_book']['day'] = user_goal[
                        'hotel']['book']['day']
                    if user_goal['restaurant']['book']['day'] == user_goal['restaurant']['fail_book']['day'] and \
                            user_goal['restaurant']['book']['time'] == user_goal['restaurant']['fail_book']['time'] and \
                            user_goal['restaurant']['book']['people'] == user_goal['restaurant']['fail_book']['people']:
                        del user_goal['restaurant']['fail_book']

        # match day and people of hotel and train
        if 'hotel' in domain_ordering and 'train' in domain_ordering and \
                'book' in user_goal['hotel'] and 'info' in user_goal['train']:
            if user_goal['train']['info']['destination'] == 'cambridge' and \
                'day' in user_goal['hotel']['book']:
                user_goal['train']['info']['day'] = user_goal['hotel']['book'][
                    'day']
            elif user_goal['train']['info']['departure'] == 'cambridge' and \
                'day' in user_goal['hotel']['book'] and 'stay' in user_goal['hotel']['book']:
                user_goal['train']['info']['day'] = days[
                    (days.index(user_goal['hotel']['book']['day']) +
                     int(user_goal['hotel']['book']['stay'])) % 7]
            # In case, we have no query results with adjusted train goal, we simply drop the train goal.
            if len(self.db.query('train',
                                 user_goal['train']['info'].items())) == 0:
                del user_goal['train']
                domain_ordering = tuple(list(domain_ordering).remove('train'))

        for domain in user_goal:
            if not user_goal[domain]['info']:
                user_goal[domain]['info'] = {'none': 'none'}

        user_goal['domain_ordering'] = domain_ordering

        return user_goal

    def _adjust_info(self, domain, info):
        # adjust one of the slots of the info
        adjusted_info = deepcopy(info)
        slot = random.choice(list(info.keys()))
        adjusted_info[slot] = random.choice(
            list(self.ind_slot_value_dist[domain]['info'][slot].keys()))
        return adjusted_info

    def build_message(self, user_goal, boldify=null_boldify):
        """Render ``user_goal`` as natural-language task instructions.

        Walks ``user_goal['domain_ordering']`` and, for each domain, emits
        sentences for its ``info``/``fail_info`` constraints, ``reqt``
        requests and ``book``/``fail_book`` details using the module-level
        ``templates``. Individual values are wrapped with ``self.boldify``.
        Info slots are grouped into randomly sized batches (1-3 slots per
        sentence group), so the output depends on the ``random`` state.

        Args:
            user_goal: goal dict keyed by domain name, plus a
                ``'domain_ordering'`` sequence listing the domains in order.
            boldify: only compared against ``do_boldify``. When it is that
                function, the words ``wifi``, ``internet`` and ``parking`` are
                wrapped in ``<b>`` tags in the flat ``message`` list.

        Returns:
            ``(message, message_by_domain)``. ``message`` is the flat list of
            sentences. ``message_by_domain`` holds one string per domain in
            ``domain_ordering``, joining that domain's sentences; it is built
            before the ``<b>`` keyword emphasis is applied.
        """
        message = []
        message_by_domain = []
        # Index into ``message`` where the current domain's sentences start.
        mess_ptr4domain = 0

        def fill_info_template(user_goal, domain, slot, info):
            """Format one info/fail_info constraint with its template.

            When restaurant and attraction share the same area, the domain
            that comes later in ``domain_ordering`` refers back to the earlier
            one instead of repeating the area name.
            """
            if slot != 'area' or not (
                    'restaurant' in user_goal and 'attraction' in user_goal
                    and info in user_goal['restaurant'].keys()
                    and info in user_goal['attraction'].keys()
                    and 'area' in user_goal['restaurant'][info]
                    and 'area' in user_goal['attraction'][info]
                    and user_goal['restaurant'][info]['area']
                    == user_goal['attraction'][info]['area']):
                return templates[domain][slot].format(
                    self.boldify(user_goal[domain][info][slot]))
            else:
                restaurant_index = user_goal['domain_ordering'].index(
                    'restaurant')
                attraction_index = user_goal['domain_ordering'].index(
                    'attraction')
                if restaurant_index > attraction_index and domain == 'restaurant':
                    return templates[domain][slot].format(
                        self.boldify('same area as the attraction'))
                elif attraction_index > restaurant_index and domain == 'attraction':
                    return templates[domain][slot].format(
                        self.boldify('same area as the restaurant'))
            return templates[domain][slot].format(
                self.boldify(user_goal[domain][info][slot]))

        def get_same_people_domain(user_goal, domain, slot):
            """Return an earlier booked domain sharing ``slot``'s value.

            Only ``day`` and ``people`` are considered. The earlier domain
            must be a restaurant, hotel or train with a ``book`` entry.
            Returns ``None`` when there is no match.
            """
            if slot not in ['day', 'people']:
                return None
            domain_index = user_goal['domain_ordering'].index(domain)
            previous_domains = user_goal['domain_ordering'][:domain_index]
            for prev in previous_domains:
                if prev in ['restaurant', 'hotel', 'train'] and 'book' in user_goal[prev] and \
                                slot in user_goal[prev]['book'] and user_goal[prev]['book'][slot] == \
                        user_goal[domain]['book'][slot]:
                    return prev
            return None

        for dom in user_goal['domain_ordering']:
            state = deepcopy(user_goal[dom])

            # A taxi goal reduced to a single (time) slot is phrased as a
            # commute below and gets no intro. ``m`` is always (re)bound here.
            # Previously it was only bound in the intro branch, so a
            # taxi-first ordering raised NameError, and otherwise the taxi
            # branch reused whatever the previous domain left in ``m``.
            if dom == 'taxi' and len(state['info']) == 1:
                m = []
            else:
                # intro
                m = [templates[dom]['intro']]

            # info
            info = 'info'
            if 'fail_info' in user_goal[dom]:
                info = 'fail_info'
            if dom == 'taxi' and len(state[info]) == 1:
                taxi_index = user_goal['domain_ordering'].index('taxi')
                places = [
                    d for d in user_goal['domain_ordering'][:taxi_index]
                    if d in ['attraction', 'hotel', 'restaurant']
                ]
                if len(places) >= 2:
                    random.shuffle(places)
                    m.append(templates['taxi']['commute'])
                    if 'arriveBy' in state[info]:
                        m.append(
                            'The taxi should arrive at the {} from the {} by {}.'
                            .format(self.boldify(places[0]),
                                    self.boldify(places[1]),
                                    self.boldify(state[info]['arriveBy'])))
                    elif 'leaveAt' in state[info]:
                        m.append(
                            'The taxi should leave from the {} to the {} after {}.'
                            .format(self.boldify(places[0]),
                                    self.boldify(places[1]),
                                    self.boldify(state[info]['leaveAt'])))
                    message.append(' '.join(m))
            else:
                # Consume the info slots in random batches of 1-3 slots; each
                # batch becomes one message entry.
                while len(state[info]) > 0:
                    num_acts = random.randint(1, min(len(state[info]), 3))
                    slots = random.sample(list(state[info].keys()), num_acts)
                    sents = [
                        fill_info_template(user_goal, dom, slot, info)
                        for slot in slots
                        if slot not in ['parking', 'internet']
                    ]
                    # parking/internet use value-specific templates.
                    if 'parking' in slots:
                        sents.append(templates[dom]['parking ' +
                                                    state[info]['parking']])
                    if 'internet' in slots:
                        sents.append(templates[dom]['internet ' +
                                                    state[info]['internet']])
                    m.extend(sents)
                    message.append(' '.join(m))
                    m = []
                    for slot in slots:
                        del state[info][slot]

            # fail_info: describe the first slot whose value differs between
            # info and fail_info (relies on both dicts sharing key order).
            if 'fail_info' in user_goal[dom]:
                adjusted_slot = list(
                    filter(
                        lambda x: x[0][1] != x[1][1],
                        zip(user_goal[dom]['info'].items(),
                            user_goal[dom]['fail_info'].items())))[0][0][0]
                if adjusted_slot in ['internet', 'parking']:
                    message.append(
                        templates[dom]['fail_info ' + adjusted_slot + ' ' +
                                       user_goal[dom]['info'][adjusted_slot]])
                else:
                    message.append(
                        templates[dom]['fail_info ' + adjusted_slot].format(
                            self.boldify(
                                user_goal[dom]['info'][adjusted_slot])))

            # reqt
            if 'reqt' in state:
                slot_strings = []
                for slot in state['reqt']:
                    if slot in ['internet', 'parking', 'food']:
                        continue
                    slot_strings.append(slot if slot not in
                                        request_slot_string_map else
                                        request_slot_string_map[slot])
                if len(slot_strings) > 0:
                    message.append(templates[dom]['request'].format(
                        self.boldify(', '.join(slot_strings))))
                if 'internet' in state['reqt']:
                    message.append(
                        'Make sure to ask if the hotel includes free wifi.')
                if 'parking' in state['reqt']:
                    message.append(
                        'Make sure to ask if the hotel includes free parking.')
                if 'food' in state['reqt']:
                    message.append(
                        'Make sure to ask about what food it serves.')

            # book: the phrasing comes from fail_book when it exists, so the
            # user first asks for the booking that will fail.
            book = 'book'
            if 'fail_book' in user_goal[dom]:
                book = 'fail_book'
            if 'book' in state:
                slot_strings = []
                for slot in ['people', 'time', 'day', 'stay']:
                    if slot in state[book]:
                        if slot == 'people':
                            same_people_domain = get_same_people_domain(
                                user_goal, dom, slot)
                            if same_people_domain is None:
                                slot_strings.append('for {} people'.format(
                                    self.boldify(state[book][slot])))
                            else:
                                slot_strings.append(
                                    self.boldify(
                                        'for the same group of people as the {} booking'
                                        .format(same_people_domain)))
                        elif slot == 'time':
                            slot_strings.append('at {}'.format(
                                self.boldify(state[book][slot])))
                        elif slot == 'day':
                            same_people_domain = get_same_people_domain(
                                user_goal, dom, slot)
                            if same_people_domain is None:
                                slot_strings.append('on {}'.format(
                                    self.boldify(state[book][slot])))
                            else:
                                slot_strings.append(
                                    self.boldify(
                                        'on the same day as the {} booking'.
                                        format(same_people_domain)))
                        elif slot == 'stay':
                            slot_strings.append('for {} nights'.format(
                                self.boldify(state[book][slot])))
                        del state[book][slot]

                # Every booking slot must have been rendered above.
                assert len(state[book]) <= 0, state[book]

                if len(slot_strings) > 0:
                    message.append(templates[dom]['book'].format(
                        ' '.join(slot_strings)))

            # fail_book: describe the first slot that differs between book
            # and fail_book (relies on both dicts sharing key order).
            if 'fail_book' in user_goal[dom]:
                adjusted_slot = list(
                    filter(
                        lambda x: x[0][1] != x[1][1],
                        zip(user_goal[dom]['book'].items(),
                            user_goal[dom]['fail_book'].items())))[0][0][0]

                if adjusted_slot in ['internet', 'parking']:
                    message.append(
                        templates[dom]['fail_book ' + adjusted_slot + ' ' +
                                       user_goal[dom]['book'][adjusted_slot]])
                else:
                    message.append(
                        templates[dom]['fail_book ' + adjusted_slot].format(
                            self.boldify(
                                user_goal[dom]['book'][adjusted_slot])))

            # Collect this domain's sentences into one per-domain string.
            dm = message[mess_ptr4domain:]
            mess_ptr4domain = len(message)
            message_by_domain.append(' '.join(dm))

        if boldify == do_boldify:
            for i, sentence in enumerate(message):
                message[i] = sentence.replace('wifi', "<b>wifi</b>").replace(
                    'internet', "<b>internet</b>").replace(
                        'parking', "<b>parking</b>")

        return message, message_by_domain
Ejemplo n.º 3
0
class HDSA_predictor():
    """Multi-label dialogue-act predictor for HDSA built on a BERT classifier.

    Each turn it builds one BERT input from the dialogue history and a
    textual summary of the first DB match for the current domain. It then
    thresholds the sigmoid of the 44 classifier outputs.
    """

    def __init__(self, archive_file, model_file=None, use_cuda=False):
        """Load the predictor checkpoint, unpacking archives when needed.

        Args:
            archive_file: local path to the zipped checkpoints. If it is not a
                file, ``model_file`` is fetched via ``cached_path`` instead.
            model_file: fallback location used when ``archive_file`` is
                missing.
            use_cuda: run the model on ``'cuda'`` instead of ``'cpu'``.

        Raises:
            Exception: ``archive_file`` is not a file and no ``model_file``
                was given.
        """
        if not os.path.isfile(archive_file):
            if not model_file:
                raise Exception("No model for DA-predictor is specified!")
            archive_file = cached_path(model_file)
        model_dir = os.path.dirname(os.path.abspath(__file__))
        if not os.path.exists(os.path.join(model_dir, 'checkpoints')):
            # Context manager so the archive handle is always closed.
            with zipfile.ZipFile(archive_file, 'r') as archive:
                archive.extractall(model_dir)

        load_dir = os.path.join(model_dir,
                                "checkpoints/predictor/save_step_23926")
        self.db = Database()
        if not os.path.exists(load_dir):
            with zipfile.ZipFile('{}.zip'.format(load_dir), 'r') as archive:
                archive.extractall(os.path.dirname(load_dir))

        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased",
                                                       do_lower_case=True)
        self.max_seq_length = 256
        # Domain used for DB lookups. gen_example updates it from the user
        # acts and keeps the previous value when no qualifying domain appears.
        self.domain = 'restaurant'
        self.model = BertForSequenceClassification.from_pretrained(
            load_dir,
            cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE),
                                   'distributed_{}'.format(-1)),
            num_labels=44)
        self.device = 'cuda' if use_cuda else 'cpu'
        self.model.to(self.device)
        # Inference only: switch off dropout so predictions are deterministic.
        self.model.eval()

    def gen_example(self, state):
        """Build the model input example for the current turn.

        Sets ``self.domain`` to the last domain in ``state['user_action']``
        whose index in ``Constants.domains`` is below 8. It then queries the
        DB with that domain's non-empty belief-state slots (no constraints for
        ``'bus'``).

        Returns:
            ``(example, kb)``. ``example`` is an ``InputExample`` with a dummy
            all-zero 44-label target. ``kb`` is the first DB match (or
            ``{'count': '0'}``) with ``'count'`` (number of matches, as str)
            and ``'domain'`` added.
        """
        file = ''
        turn = 0
        guid = 'infer'

        act = state['user_action']
        for w in act:
            d = w[1]
            if Constants.domains.index(d.lower()) < 8:
                self.domain = d.lower()
        hierarchical_act_vecs = [0 for _ in range(44)]  # fake target

        meta = state['belief_state']
        constraints = []
        if self.domain != 'bus':
            for slot in meta[self.domain]['semi']:
                if meta[self.domain]['semi'][slot] != "":
                    constraints.append([slot, meta[self.domain]['semi'][slot]])
        query_result = self.db.query(self.domain, constraints)
        if not query_result:
            kb = {'count': '0'}
            src = "no information"
        else:
            # Shallow copy so adding 'count'/'domain' never writes into a
            # record object the Database might hand out again later.
            kb = dict(query_result[0])
            kb['count'] = str(len(query_result))
            src = []
            # Linearise string-valued fields as "slot is value" triples.
            for k, v in kb.items():
                k = examine(self.domain, k.lower())
                if k != 'illegal' and isinstance(v, str):
                    src.extend([k, 'is', v])
            src = " ".join(src)

        # Last history entry = current user turn; the one before it (if any)
        # is the previous system turn.
        usr = state['history'][-1][-1]
        sys_utt = state['history'][-2][-1] if len(state['history']) > 1 else None

        example = InputExample(file, turn, guid, src, usr, sys_utt,
                               hierarchical_act_vecs)
        kb['domain'] = self.domain
        return example, kb

    def gen_feature(self, example):
        """Convert ``example`` into padded BERT input features.

        Layout: ``[CLS] a [SEP] b [SEP] m [SEP]`` with segment ids 0/1/0.
        ``a`` and ``b`` are truncated together to fit ``max_seq_length - 3``.
        ``m`` only gets whatever room remains and is dropped entirely when
        none is left. Everything is zero-padded to ``self.max_seq_length``.
        """
        tokens_a = self.tokenizer.tokenize(example.text_a)
        tokens_b = self.tokenizer.tokenize(example.text_b)
        tokens_m = self.tokenizer.tokenize(example.text_m)
        # Modifies `tokens_a` and `tokens_b` in place so that the total
        # length fits the specified length.
        # Account for [CLS], [SEP], [SEP] with "- 3"
        truncate_seq_pair(tokens_a, tokens_b, self.max_seq_length - 3)

        tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
        segment_ids = [0] * (len(tokens_a) + 2)

        assert len(tokens) == len(segment_ids)

        tokens += tokens_b + ["[SEP]"]
        segment_ids += [1] * (len(tokens_b) + 1)

        if len(tokens) < self.max_seq_length:
            # Leave room for the trailing [SEP].
            if len(tokens_m) > self.max_seq_length - len(tokens) - 1:
                tokens_m = tokens_m[:self.max_seq_length - len(tokens) - 1]

            tokens += tokens_m + ['[SEP]']
            segment_ids += [0] * (len(tokens_m) + 1)

        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)
        # Zero-pad up to the sequence length.
        padding = [0] * (self.max_seq_length - len(input_ids))
        input_ids += padding
        input_mask += padding
        segment_ids += padding

        assert len(input_ids) == self.max_seq_length
        assert len(input_mask) == self.max_seq_length
        assert len(segment_ids) == self.max_seq_length

        feature = InputFeatures(file=example.file,
                                turn=example.turn,
                                input_ids=input_ids,
                                input_mask=input_mask,
                                segment_ids=segment_ids,
                                label_id=example.label)
        return feature

    def predict(self, state):
        """Predict the hierarchical dialogue-act labels for ``state``.

        Returns:
            ``(preds, kb)``. ``preds`` is a float tensor on ``self.device``
            holding 1.0 where ``sigmoid(logit) > 0.4`` and 0.0 elsewhere
            (presumably shape ``(1, 44)``: one example, 44 labels). ``kb`` is
            the DB summary dict from ``gen_example``.
        """
        example, kb = self.gen_example(state)
        feature = self.gen_feature(example)

        # Batch of one example.
        input_ids = torch.tensor([feature.input_ids],
                                 dtype=torch.long).to(self.device)
        input_masks = torch.tensor([feature.input_mask],
                                   dtype=torch.long).to(self.device)
        segment_ids = torch.tensor([feature.segment_ids],
                                   dtype=torch.long).to(self.device)

        with torch.no_grad():
            logits = self.model(input_ids,
                                segment_ids,
                                input_masks,
                                labels=None)
            logits = torch.sigmoid(logits)
        preds = (logits > 0.4).float()

        return preds, kb


class RuleBasedMultiwozBot(Policy):
    ''' Rule-based bot. Implemented for Multiwoz dataset. '''

    recommend_flag = -1
    choice = ""

    def __init__(self):
        """Initialise the policy with an empty turn memory and a DB handle."""
        Policy.__init__(self)
        # Previous dialogue state; predict() diffs against it when the
        # incoming state carries no explicit user_action.
        self.last_state = dict()
        # Handle used for all entity lookups.
        self.db = Database()

    def init_session(self):
        """Reset the per-dialogue state before a new session starts.

        Besides the remembered previous state, the recommendation tracking
        (``recommend_flag`` / ``choice``) is restored to its class defaults.
        ``predict``/``_update_DA`` otherwise keep these on the instance across
        sessions, so a pending recommendation from one dialogue would leak
        into the next.
        """
        self.last_state = {}
        self.recommend_flag = -1
        self.choice = ""

    def predict(self, state):
        """Choose the system dialogue acts for the current turn.

        Args:
            state: dialogue state dict; see util/state.py. Reads
                ``user_action`` (list of ``[intent, domain, slot, value]``;
                when absent or empty, acts are inferred with ``check_diff``
                against the previous state) and ``belief_state``.
                ``state['system_action']`` is written.

        Returns:
            List of ``[intent, domain, slot, value]`` system acts (also stored
            in ``state['system_action']``). Intents without slots are emitted
            with ``'none'``/``'none'``; an empty result becomes a single
            ``general-greet``.
        """
        if self.recommend_flag != -1:
            self.recommend_flag += 1

        self.kb_result = {}

        DA = {}

        if 'user_action' in state and (len(state['user_action']) > 0):
            # Group the flat user acts by "Domain-Intent", keeping every
            # slot/value pair. Previously the append sat inside the
            # "new key" branch, so only the first pair per key survived.
            user_action = {}
            for intent, dom, slot, value in state['user_action']:
                key = '-'.join((dom, intent))
                if key not in user_action:
                    user_action[key] = []
                user_action[key].append([slot, value])
        else:
            user_action = check_diff(self.last_state, state)

        self.last_state = state

        for user_act in user_action:
            domain, intent_type = user_act.split('-')

            # Respond to general greetings
            if domain == 'general':
                self._update_greeting(user_act, state, DA)

            # Book taxi for user
            elif domain == 'Taxi':
                self._book_taxi(user_act, state, DA)

            elif domain == 'Booking':
                self._update_booking(user_act, state, DA)

            # User's talking about other domain
            elif domain != "Train":
                self._update_DA(user_act, user_action, state, DA)

            # Info about train
            else:
                self._update_train(user_act, user_action, state, DA)

            # Judge if user want to book
            self._judge_booking(user_act, user_action, DA)

            if 'Booking-Book' in DA:
                # A booking confirmation is sent on its own, optionally with a
                # "reqmore". Prune first, then maybe add reqmore. Previously
                # reqmore was added before the pruning and deleted at once.
                for act_key in [a for a in DA if a != 'Booking-Book']:
                    del DA[act_key]
                if random.random() < 0.5:
                    DA['general-reqmore'] = []

        if not DA:
            DA = {'general-greet': [['none', 'none']]}
        tuples = []
        for domain_intent, svs in DA.items():
            domain, intent = domain_intent.split('-')
            if not svs:
                tuples.append([intent, domain, 'none', 'none'])
            else:
                for slot, value in svs:
                    tuples.append([intent, domain, slot, value])
        state['system_action'] = tuples
        return tuples

    def _update_greeting(self, user_act, state, DA):
        """Reply to general (domain-less) user acts.

        ``bye`` ensures a ``general-bye`` entry and, with probability 0.3,
        also a ``general-welcome`` entry. ``thank`` sets ``general-welcome``
        to an empty list. Any other general intent leaves ``DA`` untouched.
        ``state`` is not used.
        """
        _domain, intent = user_act.split('-')

        if intent == 'thank':
            DA['general-welcome'] = []
            return
        if intent != 'bye':
            return

        DA.setdefault('general-bye', [])
        if random.random() < 0.3:
            DA.setdefault('general-welcome', [])

    def _book_taxi(self, user_act, state, DA):
        """Book a taxi for the user, or ask for the missing information.

        When departure, destination and at least one of leaveAt/arriveBy are
        filled in the taxi belief state, this appends a generated car type
        and phone number to ``Taxi-Inform``. Otherwise it appends a random,
        non-empty prefix of the missing slots to ``Taxi-Request`` with
        value ``'?'``.
        """
        taxi_semi = state['belief_state']['taxi']['semi']

        blank_info = []
        for info in ['departure', 'destination']:
            # Test the slot's own value. The old code compared the whole
            # ``semi`` dict with "", which is never true, so missing
            # departure/destination were never requested.
            if taxi_semi[info] == "":
                blank_info.append(REF_USR_DA['Taxi'].get(info, info))
        if taxi_semi['leaveAt'] == "" and taxi_semi['arriveBy'] == "":
            blank_info += ['Leave', 'Arrive']

        # Finish booking, tell user car type and phone number
        if len(blank_info) == 0:
            if 'Taxi-Inform' not in DA:
                DA['Taxi-Inform'] = []
            car = generate_car()
            phone_num = generate_phone_num(11)
            DA['Taxi-Inform'].append(['Car', car])
            DA['Taxi-Inform'].append(['Phone', phone_num])
            return

        # Need essential info to finish booking: request 1..len(blank_info)
        # of the missing slots, in order.
        request_num = random.randint(0, 999999) % len(blank_info) + 1
        if 'Taxi-Request' not in DA:
            DA['Taxi-Request'] = []
        for i in range(request_num):
            slot = REF_USR_DA.get(blank_info[i], blank_info[i])
            DA['Taxi-Request'].append([slot, '?'])

    def _update_booking(self, user_act, state, DA):
        """Handle ``Booking-*`` user acts; currently a no-op that leaves ``DA`` unchanged."""
        return None

    def _update_DA(self, user_act, user_action, state, DA):
        """ Answer user's utterance about any domain other than taxi or train.

        Queries the DB with the non-empty belief-state slots of the act's
        domain and stores a deep copy of the result in
        ``self.kb_result[domain]``.

        * ``Request`` intents: each requested slot is answered from the first
          DB match, or with ``"unknown"`` when that entity lacks the slot.
        * Other intents, by number of matches:
          - 0: ``NoOffer`` listing the constraints, sometimes followed by a
            ``Request`` to change some of them;
          - 1: ``Inform`` about a random subset of the belief-state slots;
          - more: ``Inform`` of the count plus a ``Recommend`` of a random
            match.

        Args:
            user_act: the ``"Domain-Intent"`` key being handled.
            user_action: all user acts of the turn,
                ``{"Domain-Intent": [[slot, value], ...]}``.
            state: dialogue state; only ``belief_state`` is read.
            DA: system act dict under construction; updated in place.
        """
        domain, intent_type = user_act.split('-')
        semi = state['belief_state'][domain.lower()]['semi']

        constraints = []
        for slot in semi:
            if semi[slot] != "":
                constraints.append([slot, semi[slot]])

        kb_result = self.db.query(domain.lower(), constraints)
        self.kb_result[domain] = deepcopy(kb_result)

        # Respond to user's request
        if intent_type == 'Request':
            if self.recommend_flag > 1:
                self.recommend_flag = -1
                self.choice = ""
            elif self.recommend_flag == 1:
                # Was ``self.recommend_flag == 0``: a comparison whose result
                # was discarded, so the flag was never reset here.
                self.recommend_flag = 0
            if (domain + "-Inform") not in DA:
                DA[domain + "-Inform"] = []
            for slot in user_action[user_act]:
                if len(kb_result) > 0:
                    kb_slot_name = REF_SYS_DA[domain].get(slot[0], slot[0])
                    if kb_slot_name in kb_result[0]:
                        DA[domain + "-Inform"].append(
                            [slot[0], kb_result[0][kb_slot_name]])
                    else:
                        DA[domain + "-Inform"].append([slot[0], "unknown"])

        else:
            # There's no result matching user's constraint
            if len(kb_result) == 0:
                if (domain + "-NoOffer") not in DA:
                    DA[domain + "-NoOffer"] = []

                for slot in semi:
                    if semi[slot] != "" and semi[slot] != "do n't care":
                        slot_name = REF_USR_DA[domain].get(slot, slot)
                        DA[domain + "-NoOffer"].append([slot_name, semi[slot]])

                p = random.random()

                # Ask user if he wants to change constraint (up to 3 of the
                # NoOffer slots). Skip when NoOffer is empty: the modulo
                # below would otherwise raise ZeroDivisionError.
                if p < 0.3 and len(DA[domain + "-NoOffer"]) > 0:
                    req_num = min(
                        random.randint(0, 999999) %
                        len(DA[domain + "-NoOffer"]) + 1, 3)
                    if domain + "-Request" not in DA:
                        DA[domain + "-Request"] = []
                    for i in range(req_num):
                        slot_name = REF_USR_DA[domain].get(
                            DA[domain + "-NoOffer"][i][0],
                            DA[domain + "-NoOffer"][i][0])
                        DA[domain + "-Request"].append([slot_name, "?"])

            # There's exactly one result matching user's constraint
            elif len(kb_result) == 1:

                # Inform user about this result
                if (domain + "-Inform") not in DA:
                    DA[domain + "-Inform"] = []
                props = []
                for prop in semi:
                    props.append(prop)
                property_num = len(props)
                if property_num > 0:
                    info_num = random.randint(0, 999999) % property_num + 1
                    random.shuffle(props)
                    for i in range(info_num):
                        slot_name = REF_USR_DA[domain].get(props[i], props[i])
                        DA[domain + "-Inform"].append(
                            [slot_name, kb_result[0][props[i]]])

            # There are multiple results matching user's constraint
            else:
                p = random.random()

                # Recommend a choice from kb_list. The condition is hard-wired
                # to True, so the two branches below are currently unreachable.
                if True:  # p < 0.3:
                    if (domain + "-Inform") not in DA:
                        DA[domain + "-Inform"] = []
                    if (domain + "-Recommend") not in DA:
                        DA[domain + "-Recommend"] = []
                    DA[domain + "-Inform"].append(
                        ["Choice", str(len(kb_result))])
                    idx = random.randint(0, 999999) % len(kb_result)
                    choice = kb_result[idx]
                    if domain in [
                            "Hotel", "Attraction", "Police", "Restaurant"
                    ]:
                        DA[domain + "-Recommend"].append(
                            ['Name', choice['name']])
                    self.recommend_flag = 0
                    self.candidate = choice
                    props = []
                    for prop in choice:
                        props.append([prop, choice[prop]])
                    # Mention 0-2 extra properties of the candidate, keeping
                    # only informable ones.
                    prop_num = min(random.randint(0, 999999) % 3, len(props))
                    random.shuffle(props)
                    for i in range(prop_num):
                        slot = props[i][0]
                        string = REF_USR_DA[domain].get(slot, slot)
                        if string in INFORMABLE_SLOTS:
                            DA[domain + "-Recommend"].append(
                                [string, str(props[i][1])])

                # Ask user to choose a candidate.
                elif p < 0.5:
                    prop_values = []
                    props = []
                    for prop in kb_result[0]:
                        for candidate in kb_result:
                            if prop not in candidate:
                                continue
                            if candidate[prop] not in prop_values:
                                prop_values.append(candidate[prop])
                        if len(prop_values) > 1:
                            props.append([prop, prop_values])
                        prop_values = []
                    random.shuffle(props)
                    idx = 0
                    while idx < len(props):
                        if props[idx][0] not in SELECTABLE_SLOTS[domain]:
                            props.pop(idx)
                            idx -= 1
                        idx += 1
                    if domain + "-Select" not in DA:
                        DA[domain + "-Select"] = []
                    for i in range(min(len(props[0][1]), 5)):
                        prop_value = REF_USR_DA[domain].get(
                            props[0][0], props[0][0])
                        DA[domain + "-Select"].append(
                            [prop_value, props[0][1][i]])

                # Ask user for more constraint
                else:
                    reqs = []
                    for prop in semi:
                        if semi[prop] == "":
                            prop_value = REF_USR_DA[domain].get(prop, prop)
                            reqs.append([prop_value, "?"])
                    i = 0
                    while i < len(reqs):
                        if reqs[i][0] not in REQUESTABLE_SLOTS:
                            reqs.pop(i)
                            i -= 1
                        i += 1
                    random.shuffle(reqs)
                    if len(reqs) == 0:
                        return
                    req_num = min(random.randint(0, 999999) % len(reqs) + 1, 2)
                    if (domain + "-Request") not in DA:
                        DA[domain + "-Request"] = []
                    for i in range(req_num):
                        req = reqs[i]
                        req[0] = REF_USR_DA[domain].get(req[0], req[0])
                        DA[domain + "-Request"].append(req)

    def _update_train(self, user_act, user_action, state, DA):
        """Fill ``DA`` with the system's train-domain acts for this turn.

        Constraints are collected from ``state['belief_state']['train']['semi']``:
        set ``leaveAt``/``arriveBy`` values first, then ``day``, ``destination``
        and ``departure``. Each of the latter three that is still empty is
        requested instead (``Train-Request``). If neither time is set, a
        departure and/or arrival time is requested at random. The train DB is
        then queried and a copy of the result is stored in
        ``self.kb_result['Train']``.

        Outcome:
        * ``user_act == 'Train-Request'``: pending requests are dropped, and
          each requested slot present in the first matching train is
          informed (``Train-Inform``). Slots it lacks, or every slot when
          nothing matched, are skipped.
        * No match: ``Train-NoOffer`` echoes the constraints, and pending
          requests are dropped.
        * Match with at least 4 constraints: ``Train-OfferBook`` echoes the
          constraints, and pending requests are dropped. With fewer
          constraints, the requests are kept.

        Args:
            user_act: user dialog-act key being handled, e.g. ``'Train-Request'``.
            user_action: dict mapping user act keys to lists of ``[slot, value]``.
            state: dialog state holding the train belief state.
            DA: system dialog-act dict; mutated in place.
        """
        semi = state['belief_state']['train']['semi']
        constraints = []
        for time in ['leaveAt', 'arriveBy']:
            if semi[time] != "":
                constraints.append([time, semi[time]])

        requests = DA.setdefault('Train-Request', [])
        if not constraints:
            # No time known yet: ask for departure time, arrival time, or both.
            p = random.random()
            if p < 0.33:
                requests.append(['Leave', '?'])
            elif p < 0.66:
                requests.append(['Arrive', '?'])
            else:
                requests.append(['Leave', '?'])
                requests.append(['Arrive', '?'])

        for prop in ['day', 'destination', 'departure']:
            if semi[prop] == "":
                slot = REF_USR_DA['Train'].get(prop, prop)
                requests.append([slot, '?'])
            else:
                constraints.append([prop, semi[prop]])

        kb_result = self.db.query('train', constraints)
        self.kb_result['Train'] = deepcopy(kb_result)

        if user_act == 'Train-Request':
            del DA['Train-Request']
            inform = DA.setdefault('Train-Inform', [])
            # Answer from the first matching train. Previously a bare
            # `except: pass` hid the "no match"/"missing slot" cases (and any
            # other error); now those two cases are checked explicitly.
            top = kb_result[0] if kb_result else None
            for slot in user_action[user_act]:
                slot_name = REF_SYS_DA['Train'].get(slot[0], slot[0])
                if top is not None and slot_name in top:
                    inform.append([slot[0], top[slot_name]])
            return

        if not kb_result:
            no_offer = DA.setdefault('Train-NoOffer', [])
            for prop in constraints:
                no_offer.append(
                    [REF_USR_DA['Train'].get(prop[0], prop[0]), prop[1]])
            DA.pop('Train-Request', None)
        else:
            # Only offer to book once enough of the train is pinned down.
            if len(constraints) < 4:
                return
            DA.pop('Train-Request', None)
            offer = DA.setdefault('Train-OfferBook', [])
            for prop in constraints:
                offer.append(
                    [REF_USR_DA['Train'].get(prop[0], prop[0]), prop[1]])

    def _judge_booking(self, user_act, user_action, DA):
        """ If user want to book, return a ref number. """
        if self.recommend_flag > 1:
            self.recommend_flag = -1
            self.choice = ""
        elif self.recommend_flag == 1:
            self.recommend_flag == 0
        domain, _ = user_act.split('-')
        for slot in user_action[user_act]:
            if domain in booking_info and slot[0] in booking_info[domain]:
                if 'Booking-Book' not in DA:
                    if domain in self.kb_result and len(
                            self.kb_result[domain]) > 0:
                        if 'Ref' in self.kb_result[domain][0]:
                            DA['Booking-Book'] = [[
                                "Ref", self.kb_result[domain][0]['Ref']
                            ]]
                        else:
                            DA['Booking-Book'] = [["Ref", "N/A"]]
Ejemplo n.º 5
0
class MultiWozEvaluator(Evaluator):
    """Session-level evaluator for MultiWOZ dialogues.

    Dialog acts are flattened to strings of the form
    ``domain-intent-slot-value``. The ``domain-intent-slot`` prefix is
    lowercased, and the value is ``str(value)`` with its case kept. They are
    stored in ``sys_da_array`` / ``usr_da_array``. System acts that carry a
    booking reference also fill ``booked`` with the referenced database
    entity. The scoring methods compare all of this against the user goal.
    """

    # Intents whose slot/value pairs count as the speaker providing information.
    _INFORM_INTENTS = ('inform', 'recommend', 'offerbook', 'offerbooked')

    def __init__(self):
        self.sys_da_array = []
        self.usr_da_array = []
        self.goal = {}
        self.cur_domain = ''
        # One entry per belief domain (None = nothing booked). This used to
        # be an empty dict, which made add_sys_da() raise KeyError on a
        # booking reference when add_goal() had not been called first.
        self.booked = self._init_dict_booked()
        self.database = Database()
        self.dbs = self.database.dbs

    def _init_dict(self):
        """Return an empty goal: ``{'info': {}, 'book': {}, 'reqt': []}`` per belief domain."""
        return {domain: {'info': {}, 'book': {}, 'reqt': []}
                for domain in belief_domains}

    def _init_dict_booked(self):
        """Return a booking table mapping every belief domain to ``None``."""
        return {domain: None for domain in belief_domains}

    def _expand(self, _goal):
        """Return a deep copy of ``_goal`` in which every belief domain exists
        and carries 'info' (dict), 'book' (dict) and 'reqt' (list) entries.
        """
        goal = deepcopy(_goal)
        for domain in belief_domains:
            if domain not in goal:
                goal[domain] = {'info': {}, 'book': {}, 'reqt': []}
                continue
            goal[domain].setdefault('info', {})
            goal[domain].setdefault('book', {})
            goal[domain].setdefault('reqt', [])
        return goal

    def add_goal(self, goal):
        """init goal and array

        args:
            goal:
                dict[domain] dict['info'/'book'/'reqt'] dict/dict/list[slot]
        """
        self.sys_da_array = []
        self.usr_da_array = []
        self.goal = goal
        self.cur_domain = ''
        self.booked = self._init_dict_booked()

    def _record_da(self, da_array, intent, domain, slot, value):
        """Append one act to ``da_array`` as ``domain-intent-slot-value``.

        The ``domain-intent-slot`` prefix is lowercased; ``value`` is passed
        through ``str`` but keeps its case. ``cur_domain`` follows the act's
        domain whenever that domain is a belief domain.

        Returns:
            ``(da, value)``: the lowercased ``domain-intent-slot`` prefix and
            the stringified value.
        """
        dom_int = '-'.join([domain, intent])
        act_domain = dom_int.split('-')[0].lower()
        if act_domain in belief_domains:
            self.cur_domain = act_domain
        da = (dom_int + '-' + slot).lower()
        value = str(value)
        da_array.append(da + '-' + value)
        return da, value

    def _book_by_ref(self, domain, ref):
        """Record the first booking in ``domain`` from an 8-digit reference.

        The reference is used as an index into ``self.dbs[domain]``. The
        entity there is copied, tagged with ``'Ref'`` and stored in
        ``self.booked[domain]``. Nothing happens if something is already
        booked in that domain, or if the reference is not 8 digits or is out
        of range.
        """
        if self.booked[domain] or not re.match(r'^\d{8}$', ref):
            return
        index = int(ref)
        if len(self.dbs[domain]) > index:
            entity = self.dbs[domain][index].copy()
            entity['Ref'] = ref
            self.booked[domain] = entity

    def add_sys_da(self, da_turn):
        """add sys_da into array

        Also records bookings: a ``booking-book-ref`` act books in the
        current domain (hotel/restaurant/train only), and
        ``train-offerbooked-ref`` / ``train-inform-ref`` acts book a train.
        A ``taxi-inform-car`` act marks the taxi as ``'booked'``.

        args:
            da_turn:
                list[intent, domain, slot, value]
        """
        for intent, domain, slot, value in da_turn:
            da, value = self._record_da(self.sys_da_array, intent, domain,
                                        slot, value)
            if da == 'booking-book-ref' and self.cur_domain in [
                    'hotel', 'restaurant', 'train'
            ]:
                self._book_by_ref(self.cur_domain, value)
            elif da in ('train-offerbooked-ref', 'train-inform-ref'):
                self._book_by_ref('train', value)
            elif da == 'taxi-inform-car':
                if not self.booked['taxi']:
                    self.booked['taxi'] = 'booked'

    def add_usr_da(self, da_turn):
        """add usr_da into array

        args:
            da_turn:
                list[intent, domain, slot, value]
        """
        for intent, domain, slot, value in da_turn:
            self._record_da(self.usr_da_array, intent, domain, slot, value)

    @staticmethod
    def _hhmm_to_int(time_str):
        """Convert ``'HH:MM'`` to the integer ``HH * 100 + MM``.

        Raises ValueError or IndexError on malformed input.
        """
        parts = time_str.split(':')
        return int(parts[0]) * 100 + int(parts[1])

    def _book_rate_goal(self, goal, booked_entity, domains=None):
        """
        judge if the selected entity meets the constraint

        For each domain with a non-empty 'book' goal and non-empty 'info',
        the score is 0 if nothing was booked, 1 for a booked taxi, and
        otherwise the fraction of 'info' constraints the booked entity
        satisfies. ``destination``/``departure`` are excluded from the count;
        ``leaveAt`` must be no earlier and ``arriveBy`` no later than the
        goal, and unparseable times count as satisfied.

        Returns:
            list of per-domain scores (domains without a booking goal are
            omitted).
        """
        if domains is None:
            domains = belief_domains
        score = []
        for domain in domains:
            if not goal[domain].get('book'):
                continue
            tot = len(goal[domain]['info'])
            if tot == 0:
                continue
            entity = booked_entity[domain]
            if entity is None:
                score.append(0)
                continue
            if domain == 'taxi':
                score.append(1)
                continue
            match = 0
            for k, v in goal[domain]['info'].items():
                if k in ('destination', 'departure'):
                    tot -= 1
                elif k in ('leaveAt', 'arriveBy'):
                    try:
                        v_constraint = self._hhmm_to_int(v)
                        v_select = self._hhmm_to_int(entity[k])
                    except (ValueError, IndexError):
                        # Cannot compare the times: give the system the benefit of the doubt.
                        match += 1
                        continue
                    if k == 'leaveAt':
                        satisfied = v_constraint <= v_select
                    else:
                        satisfied = v_constraint >= v_select
                    if satisfied:
                        match += 1
                elif v.strip() == entity[k].strip():
                    match += 1
            if tot != 0:
                score.append(match / tot)
        return score

    def _inform_F1_goal(self, goal, sys_history, domains=None):
        """
        judge if all the requested information is answered

        Returns:
            ``(TP, FP, FN, bad_inform, reqt_not_inform, inform_not_reqt)``.
            TP counts requested slots the system informed. FN counts
            requested slots it never informed. FP counts informs whose value
            fails ``_check_value``, plus requestable slots informed without
            being requested or given by the user.
        """
        if domains is None:
            domains = belief_domains
        inform_slot = {domain: set() for domain in domains}
        TP, FP, FN = 0, 0, 0

        inform_not_reqt = set()
        reqt_not_inform = set()
        bad_inform = set()

        for da in sys_history:
            domain, intent, slot, value = da.split('-', 3)
            if intent in self._INFORM_INTENTS and domain in domains \
                    and slot in mapping[domain] \
                    and value.strip() not in NUL_VALUE:
                key = mapping[domain][slot]
                if self._check_value(domain, key, value):
                    inform_slot[domain].add(key)
                else:
                    bad_inform.add((intent, domain, key))
                    FP += 1

        for domain in domains:
            for k in goal[domain]['reqt']:
                if k in inform_slot[domain]:
                    TP += 1
                else:
                    reqt_not_inform.add(('request', domain, k))
                    FN += 1
            for k in inform_slot[domain]:
                # exclude slots that are informed by users
                if k not in goal[domain]['reqt'] \
                        and k not in goal[domain]['info'] \
                        and k in requestable[domain]:
                    inform_not_reqt.add(('inform', domain, k))
                    FP += 1
        return TP, FP, FN, bad_inform, reqt_not_inform, inform_not_reqt

    def _check_value(self, domain, key, value):
        """Return a truthy value if ``value`` is well-formed for slot ``key``."""
        if key == "area":
            return value.lower() in [
                "centre", "east", "south", "west", "north"
            ]
        elif key == "arriveBy" or key == "leaveAt":
            return time_re.match(value)
        elif key == "day":
            return value.lower() in [
                "monday", "tuesday", "wednesday", "thursday", "friday",
                "saturday", "sunday"
            ]
        elif key == "duration":
            return 'minute' in value
        elif key == "internet" or key == "parking":
            return value in ["yes", "no", "none"]
        elif key == "phone":
            return re.match(r'^\d{11}$', value) or domain == "restaurant"
        elif key == "price":
            return 'pound' in value
        elif key == "pricerange":
            return value in ["cheap", "expensive", "moderate", "free"
                             ] or domain == "attraction"
        elif key == "postcode":
            return re.match(r'^cb\d{1,3}[a-z]{2,3}$',
                            value) or value == 'pe296fl'
        elif key == "stars":
            return re.match(r'^\d$', value)
        elif key == "trainID":
            return re.match(r'^tr\d{4}$', value.lower())
        else:
            return True

    def _goal_from_usr_da(self, domains=None):
        """Reconstruct a goal from the user's own dialog acts.

        'info' collects the values the user informed, with slot names
        translated through ``mapping``. 'reqt' collects the slots the user
        requested, and 'book' is copied from ``self.goal``. Acts whose domain
        is not in ``domains`` (e.g. ``general``) are skipped; they used to
        raise KeyError.

        Args:
            domains: domains to include; defaults to ``belief_domains``.
        """
        if domains is None:
            domains = belief_domains
        goal = {}
        for domain in domains:
            goal[domain] = {'info': {}, 'book': {}, 'reqt': []}
            if domain in self.goal and 'book' in self.goal[domain]:
                goal[domain]['book'] = self.goal[domain]['book']
        for da in self.usr_da_array:
            d, i, s, v = da.split('-', 3)
            if d not in goal:
                continue
            slot_map = mapping.get(d, {})
            if i in self._INFORM_INTENTS and s in slot_map:
                goal[d]['info'][slot_map[s]] = v
            elif i == 'request':
                goal[d]['reqt'].append(s)
        return goal

    def _domain_goal(self, domain, ref2goal):
        """Return ``{domain: goal}``, taken from the reference goal when
        ``ref2goal`` is true and from the user's dialog acts otherwise."""
        if ref2goal:
            return {domain: self._expand(self.goal)[domain]}
        return self._goal_from_usr_da([domain])

    def book_rate(self, ref2goal=True, aggregate=True):
        """Booking score against the reference goal (``ref2goal``) or against
        the goal the user expressed.

        Returns the mean score, or ``None`` if no domain needed a booking.
        With ``aggregate=False``, returns the list of per-domain scores.
        """
        if ref2goal:
            goal = self._expand(self.goal)
        else:
            goal = self._goal_from_usr_da()
        score = self._book_rate_goal(goal, self.booked)
        if aggregate:
            return np.mean(score) if score else None
        return score

    def inform_F1(self, ref2goal=True, aggregate=True):
        """Inform precision, recall and F1 over the whole session.

        Returns ``(prec, rec, F1)``. It is ``(None, None, None)`` when
        nothing was requested, and ``(0, rec, 0)`` when precision or F1 is
        undefined. With ``aggregate=False``, returns ``[TP, FP, FN]``.
        """
        if ref2goal:
            goal = self._expand(self.goal)
        else:
            goal = self._goal_from_usr_da()
        TP, FP, FN, _, _, _ = self._inform_F1_goal(goal, self.sys_da_array)
        if not aggregate:
            return [TP, FP, FN]
        try:
            rec = TP / (TP + FN)
        except ZeroDivisionError:
            return None, None, None
        try:
            prec = TP / (TP + FP)
            F1 = 2 * prec * rec / (prec + rec)
        except ZeroDivisionError:
            return 0, rec, 0
        return prec, rec, F1

    def task_success(self, ref2goal=True):
        """
        judge if all the domains are successfully completed

        Success (1) requires a final-goal score of 1 and a book rate and
        inform recall that are each 1 or undefined (``None``), but not both
        undefined. Otherwise the result is 0.
        """
        book_sess = self.book_rate(ref2goal)
        inform_sess = self.inform_F1(ref2goal)
        goal_sess = self.final_goal_analyze()
        inform_rec = inform_sess[1]
        if ((book_sess == 1 and inform_rec == 1)
                or (book_sess == 1 and inform_rec is None)
                or (book_sess is None and inform_rec == 1)) \
                and goal_sess == 1:
            return 1
        return 0

    def domain_reqt_inform_analyze(self, domain, ref2goal=True):
        """Return the ``_inform_F1_goal`` tuple restricted to ``domain``, or
        ``None`` if the domain is not part of the goal."""
        if domain not in self.goal:
            return None
        goal = self._domain_goal(domain, ref2goal)
        return self._inform_F1_goal(goal, self.sys_da_array, [domain])

    def domain_success(self, domain, ref2goal=True):
        """
        judge if the domain (subtask) is successfully completed

        Returns ``None`` if the domain is not in the goal. Otherwise it
        returns 1 when the book rate and inform recall are each 1 or
        undefined (but not both undefined), and 0 in all other cases.
        """
        if domain not in self.goal:
            return None
        goal = self._domain_goal(domain, ref2goal)

        book_rate = self._book_rate_goal(goal, self.booked, [domain])
        book_rate = np.mean(book_rate) if book_rate else None

        TP, _, FN, _, _, _ = self._inform_F1_goal(goal, self.sys_da_array,
                                                  [domain])
        inform_rec = TP / (TP + FN) if TP + FN else None

        if (book_rate == 1 and inform_rec == 1) \
                or (book_rate == 1 and inform_rec is None) \
                or (book_rate is None and inform_rec == 1):
            return 1
        return 0

    def _final_goal_analyze(self):
        """whether the final goal satisfies constraints

        A domain counts as a mismatch when the database has no entry for its
        'info' constraints. It also counts as a mismatch when the goal
        requires a booking and the booked entity's ``Ref`` is not among the
        matching entries. Every other domain counts as a match.

        Returns:
            ``(match, mismatch)`` domain counts.
        """
        match = mismatch = 0
        for domain, dom_goal_dict in self.goal.items():
            # NOTE(review): 'reqt' is treated as a dict here (slot -> value)
            # although add_goal() documents it as a list; confirm the goal
            # format passed in by callers.
            if 'reqt' in dom_goal_dict:
                reqt_constraints = list(dom_goal_dict['reqt'].items())
            else:
                reqt_constraints = []
            if 'info' in dom_goal_dict:
                info_constraints = list(dom_goal_dict['info'].items())
            else:
                info_constraints = []
            query_result = self.database.query(
                domain, info_constraints, soft_contraints=reqt_constraints)
            if not query_result:
                mismatch += 1
                continue

            booked = self.booked.get(domain)
            if not dom_goal_dict.get('book'):
                match += 1
            elif isinstance(booked, dict):
                ref = booked['Ref']
                if any(found.get('Ref') == ref for found in query_result):
                    match += 1
                else:
                    mismatch += 1
            else:
                match += 1
        return match, mismatch

    def final_goal_analyze(self):
        """percentage of domains, in which the final goal satisfies the database constraints.
        If there is no dialog action, returns 1."""
        match, mismatch = self._final_goal_analyze()
        if match == mismatch == 0:
            return 1
        return match / (match + mismatch)