コード例 #1
0
ファイル: preprocess.py プロジェクト: elnaz776655/PARG
    def __init__(self):
        """Load the tokenizer, database, raw dialogues and delex value tables.

        The raw dialogues are read from a zipped JSON file and lower-cased as
        a whole before parsing. The delexicalisation value dictionaries are
        read from their cached JSON files when the single-token dictionary
        file exists; otherwise they are rebuilt via ``get_delex_valdict``.
        """

        def _read_json(path):
            # Context manager guarantees the handle is closed after parsing.
            with open(path, 'r') as f:
                return json.load(f)

        self.nlp = spacy.load('en_core_web_sm')
        self.db = MultiWozDB(cfg.dbs)

        data_path = 'data/multi-woz/annotated_user_da_with_span_full.json'
        # The archive member is named after the last path component.
        with zipfile.ZipFile(data_path + '.zip', 'r') as archive:
            with archive.open(data_path.split('/')[-1], 'r') as member:
                raw_text = member.read().decode('utf-8').lower()
        self.convlab_data = json.loads(raw_text)

        self.delex_sg_valdict_path = 'data/multi-woz-processed/delex_single_valdict.json'
        self.delex_mt_valdict_path = 'data/multi-woz-processed/delex_multi_valdict.json'
        self.ambiguous_val_path = 'data/multi-woz-processed/ambiguous_values.json'
        self.delex_refs_path = 'data/multi-woz-processed/reference_no.json'
        self.delex_refs = _read_json(self.delex_refs_path)

        # Only the single-token file is checked; the other two cache files
        # are assumed to exist whenever it does.
        if not os.path.exists(self.delex_sg_valdict_path):
            (self.delex_sg_valdict, self.delex_mt_valdict,
             self.ambiguous_vals) = self.get_delex_valdict()
        else:
            self.delex_sg_valdict = _read_json(self.delex_sg_valdict_path)
            self.delex_mt_valdict = _read_json(self.delex_mt_valdict_path)
            self.ambiguous_vals = _read_json(self.ambiguous_val_path)

        self.vocab = utils.Vocab(cfg.vocab_size)
コード例 #2
0
ファイル: preprocess.py プロジェクト: gusalsdmlwlq/DAMD
    def __init__(self):
        """Load the tokenizer, database, raw dialogues and delex value tables.

        The raw dialogues are read from a zipped JSON file; the member's raw
        bytes are lower-cased before parsing. The delexicalisation value
        dictionaries are read from their cached JSON files when the
        single-token dictionary file exists; otherwise they are rebuilt via
        ``get_delex_valdict``. Also sets the slot list and the gate-label
        mapping.
        """

        def _read_json(path):
            # Context manager guarantees the handle is closed after parsing.
            with open(path, 'r') as f:
                return json.load(f)

        self.nlp = spacy.load('en_core_web_sm')
        self.db = MultiWozDB(cfg.dbs)

        data_path = 'data/multi-woz/annotated_user_da_with_span_full.json'
        # The archive member is named after the last path component.
        with zipfile.ZipFile(data_path + '.zip', 'r') as archive:
            with archive.open(data_path.split('/')[-1], 'r') as member:
                raw_bytes = member.read().lower()
        self.convlab_data = json.loads(raw_bytes)

        self.delex_sg_valdict_path = 'data/multi-woz-processed/delex_single_valdict.json'
        self.delex_mt_valdict_path = 'data/multi-woz-processed/delex_multi_valdict.json'
        self.ambiguous_val_path = 'data/multi-woz-processed/ambiguous_values.json'
        self.delex_refs_path = 'data/multi-woz-processed/reference_no.json'
        self.delex_refs = _read_json(self.delex_refs_path)

        # Only the single-token file is checked; the other two cache files
        # are assumed to exist whenever it does.
        if not os.path.exists(self.delex_sg_valdict_path):
            (self.delex_sg_valdict, self.delex_mt_valdict,
             self.ambiguous_vals) = self.get_delex_valdict()
        else:
            self.delex_sg_valdict = _read_json(self.delex_sg_valdict_path)
            self.delex_mt_valdict = _read_json(self.delex_mt_valdict_path)
            self.ambiguous_vals = _read_json(self.ambiguous_val_path)

        self.vocab = utils.Vocab(cfg.vocab_size)

        self.slot_list = [
            'hotel-pricerange', 'hotel-type', 'hotel-parking', 'hotel-stay',
            'hotel-day', 'hotel-people', 'hotel-area', 'hotel-stars',
            'hotel-internet', 'train-destination', 'train-day',
            'train-departure', 'train-arrive', 'train-people', 'train-leave',
            'attraction-area', 'restaurant-food', 'restaurant-pricerange',
            'restaurant-area', 'attraction-name', 'restaurant-name',
            'attraction-type', 'hotel-name', 'taxi-leave', 'taxi-destination',
            'taxi-departure', 'restaurant-time', 'restaurant-day',
            'restaurant-people', 'taxi-arrive', "hospital-department",
        ]
        self.gating_dict = {'ptr': 0, 'dontcare': 1, 'none': 2}
コード例 #3
0
ファイル: reader.py プロジェクト: Verylovenlp/MinTL-SKKU
    def __init__(self, vocab=None):
        """Load resources, split definitions and data for the reader.

        Args:
            vocab: optional pre-built vocabulary object. When truthy it is
                used as-is and its ``size`` attribute becomes
                ``self.vocab_size``; otherwise the vocabulary is built via
                ``self._build_vocab()``.

        Raises:
            ValueError: if a domain in ``cfg.exp_domains`` has no (or an
                empty) file list in the domain-file mapping.
        """
        super().__init__()

        def _read_json(path):
            # Context manager guarantees the handle is closed after parsing.
            with open(path, 'r') as f:
                return json.load(f)

        def _read_file_ids(path):
            # One entry per line, stripped and lower-cased.
            with open(path, 'r') as f:
                return [line.strip().lower() for line in f]

        self.nlp = spacy.load('en_core_web_sm')
        self.db = MultiWozDB(cfg.dbs)

        self.domain_files = _read_json(cfg.domain_file_path)
        self.slot_value_set = _read_json(cfg.slot_value_set_path)
        if cfg.multi_acts_training:
            self.multi_acts = _read_json(cfg.multi_acts_path)

        test_list = _read_file_ids(cfg.test_list)
        dev_list = _read_file_ids(cfg.dev_list)
        # Dicts keyed by dialogue id (without '.json') act as membership sets.
        self.dev_files = {fn.replace('.json', ''): 1 for fn in dev_list}
        self.test_files = {fn.replace('.json', ''): 1 for fn in test_list}

        # Restrict to the dialogues of the requested experiment domains
        # unless 'all' is requested.
        self.exp_files = {}
        if 'all' not in cfg.exp_domains:
            for domain in cfg.exp_domains:
                fn_list = self.domain_files.get(domain)
                if not fn_list:
                    raise ValueError('[%s] is an invalid experiment setting' %
                                     domain)
                for fn in fn_list:
                    self.exp_files[fn.replace('.json', '')] = 1

        if vocab:
            self.vocab = vocab
            self.vocab_size = vocab.size
        else:
            self.vocab_size = self._build_vocab()
        self._load_data()

        if cfg.limit_bspn_vocab:
            self.bspn_masks = self._construct_bspn_constraint()
        if cfg.limit_aspn_vocab:
            self.aspn_masks = self._construct_aspn_constraint()

        self.multi_acts_record = None
コード例 #4
0
ファイル: preprocess.py プロジェクト: elnaz776655/PARG
class DataPreprocessor(object):
    def __init__(self):
        """Load the tokenizer, database, raw dialogues and delex value tables.

        The raw dialogues are read from a zipped JSON file and lower-cased as
        a whole before parsing. The delexicalisation value dictionaries are
        read from their cached JSON files when the single-token dictionary
        file exists; otherwise they are rebuilt via ``get_delex_valdict``.
        """

        def _read_json(path):
            # Context manager guarantees the handle is closed after parsing.
            with open(path, 'r') as f:
                return json.load(f)

        self.nlp = spacy.load('en_core_web_sm')
        self.db = MultiWozDB(cfg.dbs)

        data_path = 'data/multi-woz/annotated_user_da_with_span_full.json'
        # The archive member is named after the last path component.
        with zipfile.ZipFile(data_path + '.zip', 'r') as archive:
            with archive.open(data_path.split('/')[-1], 'r') as member:
                raw_text = member.read().decode('utf-8').lower()
        self.convlab_data = json.loads(raw_text)

        self.delex_sg_valdict_path = 'data/multi-woz-processed/delex_single_valdict.json'
        self.delex_mt_valdict_path = 'data/multi-woz-processed/delex_multi_valdict.json'
        self.ambiguous_val_path = 'data/multi-woz-processed/ambiguous_values.json'
        self.delex_refs_path = 'data/multi-woz-processed/reference_no.json'
        self.delex_refs = _read_json(self.delex_refs_path)

        # Only the single-token file is checked; the other two cache files
        # are assumed to exist whenever it does.
        if not os.path.exists(self.delex_sg_valdict_path):
            (self.delex_sg_valdict, self.delex_mt_valdict,
             self.ambiguous_vals) = self.get_delex_valdict()
        else:
            self.delex_sg_valdict = _read_json(self.delex_sg_valdict_path)
            self.delex_mt_valdict = _read_json(self.delex_mt_valdict_path)
            self.ambiguous_vals = _read_json(self.ambiguous_val_path)

        self.vocab = utils.Vocab(cfg.vocab_size)

    def delex_by_annotation(self, dial_turn):
        u = dial_turn['text'].split()
        span = dial_turn['span_info']
        for s in span:
            slot = s[1]
            if slot == 'open':
                continue
            if ontology.da_abbr_to_slot_name.get(slot):
                slot = ontology.da_abbr_to_slot_name[slot]
            for idx in range(s[3], s[4] + 1):
                u[idx] = ''
            try:
                u[s[3]] = '[value_' + slot + ']'
            except:
                u[5] = '[value_' + slot + ']'
        u_delex = ' '.join([t for t in u if t is not ''])
        u_delex = u_delex.replace(
            '[value_address] , [value_address] , [value_address]',
            '[value_address]')
        u_delex = u_delex.replace('[value_address] , [value_address]',
                                  '[value_address]')
        u_delex = u_delex.replace('[value_name] [value_name]', '[value_name]')
        u_delex = u_delex.replace('[value_name]([value_phone] )',
                                  '[value_name] ( [value_phone] )')
        return u_delex

    def delex_by_valdict(self, text):
        """Delexicalise ``text`` using regex patterns and the value dictionaries.

        Steps, in order: clean the text with ``clean_text``; replace phone,
        stars, price, train-id and postcode patterns; replace multi-token
        values by substring; replace single-token values by exact token
        match; finally resolve ambiguous values (times / places) from the
        cue words that precede them.

        Args:
            text: raw utterance string.

        Returns:
            The delexicalised string.
        """
        text = clean_text(text)

        text = re.sub(r'\d{5}\s?\d{5,7}', '[value_phone]', text)
        text = re.sub(r'\d[\s-]stars?', '[value_stars]', text)
        text = re.sub(r'\$\d+|\$?\d+.?(\d+)?\s(pounds?|gbps?)',
                      '[value_price]', text)
        text = re.sub(r'tr[\d]{4}', '[value_id]', text)
        text = re.sub(
            r'([a-z]{1}[\. ]?[a-z]{1}[\. ]?\d{1,2}[, ]+\d{1}[\. ]?[a-z]{1}[\. ]?[a-z]{1}|[a-z]{2}\d{2}[a-z]{2})',
            '[value_postcode]', text)

        for value, slot in self.delex_mt_valdict.items():
            text = text.replace(value, '[value_%s]' % slot)

        # Single pass with a dict lookup per token instead of re-splitting the
        # text once per dictionary entry. The guard keeps the text untouched
        # (no whitespace normalisation) when the dictionary is empty.
        sg_valdict = self.delex_sg_valdict
        if sg_valdict:
            text = ' '.join(
                '[value_%s]' % sg_valdict[tk] if tk in sg_valdict else tk
                for tk in text.split())

        arrive_cues = ('arrive', 'arrives', 'arrived', 'arriving', 'arrival',
                       'destination', 'there', 'reach', 'to', 'by', 'before')
        leave_cues = ('leave', 'leaves', 'leaving', 'depart', 'departs',
                      'departing', 'departure', 'from', 'after', 'pulls')

        for ambg_ent in self.ambiguous_vals:
            # ely is a place, but appears in words like moderately
            needle = ' ' + ambg_ent
            start_idx = text.find(needle)
            if start_idx == -1:
                continue
            front_words = text[:start_idx].split()
            ent_type = 'time' if ':' in ambg_ent else 'place'

            # Walk backwards from the entity; the nearest cue word decides.
            for fw in reversed(front_words):
                if fw in arrive_cues:
                    slot = '[value_arrive]' if ent_type == 'time' else '[value_destination]'
                elif fw in leave_cues:
                    slot = '[value_leave]' if ent_type == 'time' else '[value_departure]'
                else:
                    continue
                # Literal replacement: the entity was located literally with
                # str.find, so it must not be reinterpreted as a regex.
                text = text.replace(needle, ' ' + slot)
                # Every occurrence is replaced now; nothing left to resolve.
                break

        text = text.replace('[value_car] [value_car]', '[value_car]')
        return text

    def get_delex_valdict(self, ):
        skip_entry_type = {
            'taxi': ['taxi_phone'],
            'police': ['id'],
            'hospital': ['id'],
            'hotel': [
                'id', 'location', 'internet', 'parking', 'takesbookings',
                'stars', 'price', 'n', 'postcode', 'phone'
            ],
            'attraction': [
                'id', 'location', 'pricerange', 'price', 'openhours',
                'postcode', 'phone'
            ],
            'train': ['price', 'id'],
            'restaurant': [
                'id', 'location', 'introduction', 'signature', 'type',
                'postcode', 'phone'
            ],
        }
        entity_value_to_slot = {}
        ambiguous_entities = []
        for domain, db_data in self.db.dbs.items():
            print('Processing entity values in [%s]' % domain)
            if domain != 'taxi':
                for db_entry in db_data:
                    for slot, value in db_entry.items():
                        if slot not in skip_entry_type[domain]:
                            if type(value) is not str:
                                raise TypeError(
                                    "value '%s' in domain '%s' should be rechecked"
                                    % (slot, domain))
                            else:
                                slot, value = clean_slot_values(
                                    domain, slot, value)
                                value = ' '.join([
                                    token.text for token in self.nlp(value)
                                ]).strip()
                                if value in entity_value_to_slot and entity_value_to_slot[
                                        value] != slot:
                                    # print(value, ": ",entity_value_to_slot[value], slot)
                                    ambiguous_entities.append(value)
                                entity_value_to_slot[value] = slot
            else:  # taxi db specific
                db_entry = db_data[0]
                for slot, ent_list in db_entry.items():
                    if slot not in skip_entry_type[domain]:
                        for ent in ent_list:
                            entity_value_to_slot[ent] = 'car'
        ambiguous_entities = set(ambiguous_entities)
        ambiguous_entities.remove('cambridge')
        ambiguous_entities = list(ambiguous_entities)
        for amb_ent in ambiguous_entities:  # departure or destination? arrive time or leave time?
            entity_value_to_slot.pop(amb_ent)
        entity_value_to_slot['parkside'] = 'address'
        entity_value_to_slot['parkside, cambridge'] = 'address'
        entity_value_to_slot['cambridge belfry'] = 'name'
        entity_value_to_slot['hills road'] = 'address'
        entity_value_to_slot['hills rd'] = 'address'
        entity_value_to_slot['Parkside Police Station'] = 'name'

        single_token_values = {}
        multi_token_values = {}
        for val, slt in entity_value_to_slot.items():
            if val in ['cambridge']:
                continue
            if len(val.split()) > 1:
                multi_token_values[val] = slt
            else:
                single_token_values[val] = slt

        with open(self.delex_sg_valdict_path, 'w') as f:
            single_token_values = OrderedDict(
                sorted(single_token_values.items(),
                       key=lambda kv: len(kv[0]),
                       reverse=True))
            json.dump(single_token_values, f, indent=2)
            print('single delex value dict saved!')
        with open(self.delex_mt_valdict_path, 'w') as f:
            multi_token_values = OrderedDict(
                sorted(multi_token_values.items(),
                       key=lambda kv: len(kv[0]),
                       reverse=True))
            json.dump(multi_token_values, f, indent=2)
            print('multi delex value dict saved!')
        with open(self.ambiguous_val_path, 'w') as f:
            json.dump(ambiguous_entities, f, indent=2)
            print('ambiguous value dict saved!')

        return single_token_values, multi_token_values, ambiguous_entities

    def preprocess_main(self, save_path=None, is_test=False):
        """Convert every raw dialogue into a list of processed turns.

        For each dialogue in ``self.convlab_data`` a user turn (empty
        ``metadata``) and the following system turn are merged into one
        record holding: ``user``, ``user_delex``, ``resp``, ``pointer``,
        ``match``, ``constraint``, ``cons_delex``, ``sys_act``, ``turn_num``
        and ``turn_domain``. All seen tokens are added to ``self.vocab``.

        Side effects: sets ``self.unique_da``; constructs and saves the
        vocabulary to ``data/multi-woz-processed/vocab``; writes
        ``dialog_acts.json`` and ``dialog_act_type.json`` under
        ``data/multi-woz-analysis/``.

        Args:
            save_path: not used by this method.
            is_test: not used by this method.

        Returns:
            Dict mapping dialogue id to ``{'goal': ..., 'log': [turn, ...]}``.
        """
        data = {}
        self.unique_da = {}
        ordered_sysact_dict = {}
        for fn, raw_dial in tqdm(list(self.convlab_data.items())):

            # Keep only non-empty goal entries and normalise requested slots.
            compressed_goal = {}
            dial_domains = []
            for dom, g in raw_dial['goal'].items():
                if dom != 'topic' and dom != 'message' and g:
                    if g.get('reqt'):
                        for i, req_slot in enumerate(g['reqt']):
                            if ontology.normlize_slot_names.get(req_slot):
                                g['reqt'][i] = ontology.normlize_slot_names[
                                    req_slot]
                    compressed_goal[dom] = g
                    if dom in ontology.all_domains:
                        dial_domains.append(dom)

            dial = {'goal': compressed_goal, 'log': []}
            single_turn = {}
            constraint_dict = OrderedDict()
            prev_constraint_dict = {}
            prev_turn_domain = ['general']
            ordered_sysact_dict[fn] = {}

            for turn_num, dial_turn in enumerate(raw_dial['log']):

                dial_state = dial_turn['metadata']
                if not dial_state:  # user
                    u = ' '.join(clean_text(dial_turn['text']).split())
                    if dial_turn['span_info']:
                        u_delex = clean_text(
                            self.delex_by_annotation(dial_turn))
                    else:
                        u_delex = self.delex_by_valdict(dial_turn['text'])

                    single_turn['user'] = u
                    single_turn['user_delex'] = u_delex

                else:  # system
                    if dial_turn['span_info']:
                        s_delex = clean_text(
                            self.delex_by_annotation(dial_turn))
                    else:
                        if not dial_turn['text']:
                            print(fn)
                        s_delex = self.delex_by_valdict(dial_turn['text'])
                    single_turn['resp'] = s_delex

                    # get belief state
                    for domain in dial_domains:
                        if not constraint_dict.get(domain):
                            constraint_dict[domain] = OrderedDict()
                        info_sv = dial_state[domain]['semi']
                        for s, v in info_sv.items():
                            s, v = clean_slot_values(domain, s, v)
                            if len(v.split()) > 1:
                                v = ' '.join([
                                    token.text for token in self.nlp(v)
                                ]).strip()
                            if v != '':
                                constraint_dict[domain][s] = v
                        book_sv = dial_state[domain]['book']
                        for s, v in book_sv.items():
                            if s == 'booked':
                                continue
                            s, v = clean_slot_values(domain, s, v)
                            if len(v.split()) > 1:
                                v = ' '.join([
                                    token.text for token in self.nlp(v)
                                ]).strip()
                            if v != '':
                                constraint_dict[domain][s] = v

                    # Flatten the belief state and record which domains
                    # changed since the previous system turn.
                    constraints = []
                    cons_delex = []
                    turn_dom_bs = []
                    for domain, info_slots in constraint_dict.items():
                        if info_slots:
                            constraints.append('[' + domain + ']')
                            cons_delex.append('[' + domain + ']')
                            for slot, value in info_slots.items():
                                constraints.append(slot)
                                constraints.extend(value.split())
                                cons_delex.append(slot)
                            if domain not in prev_constraint_dict:
                                turn_dom_bs.append(domain)
                            elif prev_constraint_dict[
                                    domain] != constraint_dict[domain]:
                                turn_dom_bs.append(domain)

                    # Domains mentioned by the dialog acts of this turn.
                    sys_act_dict = {}
                    turn_dom_da = set()
                    for act in dial_turn['dialog_act']:
                        d, a = act.split('-')
                        turn_dom_da.add(d)
                    turn_dom_da = list(turn_dom_da)
                    if len(turn_dom_da) != 1 and 'general' in turn_dom_da:
                        turn_dom_da.remove('general')
                    if len(turn_dom_da) != 1 and 'booking' in turn_dom_da:
                        turn_dom_da.remove('booking')

                    # get turn domain
                    turn_domain = turn_dom_bs
                    for dom in turn_dom_da:
                        if dom != 'booking' and dom not in turn_domain:
                            turn_domain.append(dom)
                    if not turn_domain:
                        turn_domain = prev_turn_domain
                    if len(turn_domain) == 2 and 'general' in turn_domain:
                        turn_domain.remove('general')
                    if len(turn_domain) == 2:
                        # Put the previous turn's domain first.
                        if len(prev_turn_domain) == 1 and prev_turn_domain[
                                0] == turn_domain[1]:
                            turn_domain = turn_domain[::-1]

                    # get system action
                    for dom in turn_domain:
                        sys_act_dict[dom] = {}
                    add_to_last_collect = []
                    booking_act_map = {
                        'inform': 'offerbook',
                        'book': 'offerbooked'
                    }
                    for act, params in dial_turn['dialog_act'].items():
                        if act == 'general-greet':
                            continue
                        d, a = act.split('-')
                        if d == 'general' and d not in sys_act_dict:
                            sys_act_dict[d] = {}
                        if d == 'booking':
                            # Booking acts are attributed to the turn domain.
                            d = turn_domain[0]
                            a = booking_act_map.get(a, a)
                        add_p = []
                        for param in params:
                            p = param[0]
                            if p == 'none':
                                continue
                            elif ontology.da_abbr_to_slot_name.get(p):
                                p = ontology.da_abbr_to_slot_name[p]
                            if p not in add_p:
                                add_p.append(p)
                        # These act types are inserted after all others so
                        # they come last in the act dictionary.
                        if a in ('request', 'reqmore', 'bye', 'offerbook'):
                            add_to_last_collect.append((d, a, add_p))
                        else:
                            sys_act_dict[d][a] = add_p
                    for d, a, add_p in add_to_last_collect:
                        sys_act_dict[d][a] = add_p

                    # Iterate over a copy because entries may be deleted.
                    for d in copy.copy(sys_act_dict):
                        acts = sys_act_dict[d]
                        if not acts:
                            del sys_act_dict[d]
                            continue
                        if 'inform' in acts and 'offerbooked' in acts:
                            # Merge 'inform' slots into 'offerbooked'.
                            for s in sys_act_dict[d]['inform']:
                                sys_act_dict[d]['offerbooked'].append(s)
                            del sys_act_dict[d]['inform']

                    ordered_sysact_dict[fn][len(dial['log'])] = sys_act_dict

                    sys_act = []
                    if 'general-greet' in dial_turn['dialog_act']:
                        sys_act.extend(['[general]', '[greet]'])
                    for d, acts in sys_act_dict.items():
                        sys_act += ['[' + d + ']']
                        for a, slots in acts.items():
                            self.unique_da[d + '-' + a] = 1
                            sys_act += ['[' + a + ']']
                            sys_act += slots

                    # get db pointers
                    matnums = self.db.get_match_num(constraint_dict)
                    match_dom = turn_domain[0] if len(
                        turn_domain) == 1 else turn_domain[1]
                    match = matnums[match_dom]
                    dbvec = self.db.addDBPointer(match_dom, match)
                    bkvec = self.db.addBookingPointer(dial_turn['dialog_act'])

                    single_turn['pointer'] = ','.join(
                        [str(d) for d in dbvec + bkvec])
                    single_turn['match'] = str(match)
                    single_turn['constraint'] = ' '.join(constraints)
                    single_turn['cons_delex'] = ' '.join(cons_delex)
                    single_turn['sys_act'] = ' '.join(sys_act)
                    single_turn['turn_num'] = len(dial['log'])
                    single_turn['turn_domain'] = ' '.join(
                        ['[' + d + ']' for d in turn_domain])

                    prev_turn_domain = copy.deepcopy(turn_domain)
                    prev_constraint_dict = copy.deepcopy(constraint_dict)

                    # A system turn without a preceding user turn is dropped.
                    if 'user' in single_turn:
                        dial['log'].append(single_turn)
                        for t in single_turn['user'].split() + single_turn[
                                'resp'].split() + constraints + sys_act:
                            self.vocab.add_word(t)
                        for t in single_turn['user_delex'].split():
                            if '[' in t and ']' in t and not t.startswith(
                                    '[') and not t.endswith(']'):
                                # Strip characters glued around a placeholder.
                                # str.replace returns a new string, so the
                                # result must be stored back (it was
                                # previously discarded, making this a no-op).
                                single_turn['user_delex'] = single_turn[
                                    'user_delex'].replace(
                                        t, t[t.index('['):t.index(']') + 1])
                            elif not self.vocab.has_word(t):
                                self.vocab.add_word(t)

                    single_turn = {}

            data[fn] = dial
        self.vocab.construct()
        self.vocab.save_vocab('data/multi-woz-processed/vocab')
        with open('data/multi-woz-analysis/dialog_acts.json', 'w') as f:
            json.dump(ordered_sysact_dict, f, indent=2)
        with open('data/multi-woz-analysis/dialog_act_type.json', 'w') as f:
            json.dump(self.unique_da, f, indent=2)
        return data
コード例 #5
0
ファイル: reader.py プロジェクト: gusalsdmlwlq/DAMD
class MultiWozReader(_ReaderBase):
    def __init__(self):
        """Load resources, split definitions and data for the reader.

        Raises:
            ValueError: if a domain in ``cfg.exp_domains`` has no (or an
                empty) file list in the domain-file mapping.
        """
        super().__init__()

        def _read_json(path):
            # Context manager guarantees the handle is closed after parsing.
            with open(path, 'r') as f:
                return json.load(f)

        def _read_file_ids(path):
            # One entry per line, stripped and lower-cased.
            with open(path, 'r') as f:
                return [line.strip().lower() for line in f]

        self.nlp = spacy.load('en_core_web_sm')
        self.db = MultiWozDB(cfg.dbs)
        self.vocab_size = self._build_vocab()
        self.domain_files = _read_json(cfg.domain_file_path)
        self.slot_value_set = _read_json(cfg.slot_value_set_path)
        if cfg.multi_acts_training:
            self.multi_acts = _read_json(cfg.multi_acts_path)

        self.gating_dict = {'ptr': 0, 'dontcare': 1, 'none': 2}

        test_list = _read_file_ids(cfg.test_list)
        dev_list = _read_file_ids(cfg.dev_list)
        # Dicts keyed by dialogue id (without '.json') act as membership sets.
        self.dev_files = {fn.replace('.json', ''): 1 for fn in dev_list}
        self.test_files = {fn.replace('.json', ''): 1 for fn in test_list}

        # Restrict to the dialogues of the requested experiment domains
        # unless 'all' is requested.
        self.exp_files = {}
        if 'all' not in cfg.exp_domains:
            for domain in cfg.exp_domains:
                fn_list = self.domain_files.get(domain)
                if not fn_list:
                    raise ValueError('[%s] is an invalid experiment setting' %
                                     domain)
                for fn in fn_list:
                    self.exp_files[fn.replace('.json', '')] = 1

        self._load_data()

        if cfg.limit_bspn_vocab:
            self.bspn_masks = self._construct_bspn_constraint()
        if cfg.limit_aspn_vocab:
            self.aspn_masks = self._construct_aspn_constraint()

        self.multi_acts_record = None

    def _build_vocab(self):
        """Create ``self.vocab``, load it from file and return its size.

        The training vocabulary path is used in train mode or when no
        evaluation path is configured; otherwise the evaluation path is used.
        """
        vocab = utils.Vocab(cfg.vocab_size)
        self.vocab = vocab
        if cfg.mode == 'train' or cfg.vocab_path_eval is None:
            vocab_file = cfg.vocab_path_train
        else:
            vocab_file = cfg.vocab_path_eval
        vocab.load_vocab(vocab_file)
        return vocab.vocab_size

    def _construct_bspn_constraint(self):
        """Build and save the allowed-next-token table for belief spans.

        The table maps a token code to the list of codes that may follow it.
        Domain tokens are followed by their slots, slots by the first token
        of each of their values, value tokens by the next token of the value,
        and the final token of a value by ``<eos_b>`` or any domain/slot
        code. The ``police`` domain is skipped. The table is also written in
        decoded form to ``data/multi-woz-processed/bspn_masks.txt``.

        Returns:
            Dict mapping token code to a list of allowed following codes.
        """
        encode = self.vocab.encode
        decode = self.vocab.decode
        known = self.vocab.has_word

        domain_names = ('restaurant', 'hotel', 'attraction', 'train', 'taxi',
                        'hospital')
        domain_codes = [encode('[' + name + ']') for name in domain_names]
        slot_codes = [encode(name) for name in ontology.all_slots]
        # Codes that may start a new domain/slot after a completed value.
        after_value = domain_codes + slot_codes

        masks = {}
        masks[encode('<go_b>')] = domain_codes + [encode('<eos_b>'), 0]
        masks[encode('<eos_b>')] = [encode('<pad>')]
        masks[encode('<pad>')] = [encode('<pad>')]

        for domain, slot_values in self.slot_value_set.items():
            if domain == 'police':
                continue
            dom_code = encode('[' + domain + ']')
            masks[dom_code] = []
            for slot, values in slot_values.items():
                slot_code = encode(slot)
                masks.setdefault(slot_code, [])
                if slot_code not in masks[dom_code]:
                    masks[dom_code].append(slot_code)
                for value in values:
                    tokens = value.split()
                    last_pos = len(tokens) - 1
                    for pos, token in enumerate(tokens):
                        if not known(token):
                            continue
                        tok_code = encode(token)
                        masks.setdefault(tok_code, [])
                        if pos == 0 and tok_code not in masks[slot_code]:
                            masks[slot_code].append(tok_code)
                        if pos == last_pos:
                            # End of the value: <eos_b> goes in first, then
                            # every domain and slot code.
                            if after_value:
                                eos_code = encode('<eos_b>')
                                if eos_code not in masks[tok_code]:
                                    masks[tok_code].append(eos_code)
                            for code in after_value:
                                if code not in masks[tok_code]:
                                    masks[tok_code].append(code)
                            break
                        following = tokens[pos + 1]
                        if not known(following):
                            continue
                        following_code = encode(following)
                        if following_code not in masks[tok_code]:
                            masks[tok_code].append(following_code)
        # Unknown tokens may be followed by any code registered so far.
        masks[encode('<unk>')] = list(masks.keys())

        with open('data/multi-woz-processed/bspn_masks.txt', 'w') as f:
            for code, allowed in masks.items():
                allowed_text = ' '.join([decode(int(m)) for m in allowed])
                f.write(decode(code) + ': ' + allowed_text + '\n')
        return masks

    def _construct_aspn_constraint(self):
        """Build and save the allowed-next-token table for action spans.

        The table maps a token code to the list of codes that may follow it:
        ``<go_a>`` is followed by a domain or ``<eos_a>``, a domain by one of
        its acts, and an act or slot by any domain, any slot or ``<eos_a>``.
        The table is also written in decoded form to
        ``data/multi-woz-processed/aspn_masks.txt``.

        Returns:
            Dict mapping token code to a list of allowed following codes.
        """
        encode = self.vocab.encode
        decode = self.vocab.decode

        domain_codes = [
            encode('[' + name + ']') for name in ontology.dialog_acts.keys()
        ]
        act_codes = [
            encode('[' + name + ']') for name in ontology.dialog_act_params
        ]
        slot_codes = [encode(name) for name in ontology.dialog_act_all_slots]

        masks = {}
        masks[encode('<go_a>')] = domain_codes + [encode('<eos_a>'), 0]
        masks[encode('<eos_a>')] = [encode('<pad>')]
        masks[encode('<pad>')] = [encode('<pad>')]

        # Each key gets its own freshly built list (no sharing between keys).
        for act_code in act_codes:
            masks[act_code] = domain_codes + slot_codes + [encode('<eos_a>')]

        for domain, acts in ontology.dialog_acts.items():
            dom_code = encode('[' + domain + ']')
            masks[dom_code] = []
            for act in acts:
                code = encode('[' + act + ']')
                if code not in masks[dom_code]:
                    masks[dom_code].append(code)

        for slot_code in slot_codes:
            masks[slot_code] = domain_codes + slot_codes + [encode('<eos_a>')]

        # Unknown tokens may be followed by any code registered so far.
        masks[encode('<unk>')] = list(masks.keys())

        with open('data/multi-woz-processed/aspn_masks.txt', 'w') as f:
            for code, allowed in masks.items():
                allowed_text = ' '.join([decode(int(m)) for m in allowed])
                f.write(decode(code) + ': ' + allowed_text + '\n')
        return masks

    def _load_data(self, save_temp=False):
        """Load the dialogue file and fill ``self.train``/``self.dev``/``self.test``.

        The raw file text is lower-cased before it is parsed as JSON. A
        dialogue is kept when ``'all'`` is among ``cfg.exp_domains`` or its id
        is in ``self.exp_files``; it goes to the dev or test split when its id
        is listed in ``self.dev_files`` / ``self.test_files`` and to the
        training split otherwise. All three splits are shuffled in place.

        Args:
            save_temp: when true, also dump the encoded test split and the
                vocabulary under ``data/multi-woz-analysis/``.
        """
        # context managers make sure both files are closed (and the dump is
        # flushed) instead of relying on garbage collection
        with open(cfg.data_path + cfg.data_file, 'r', encoding='utf-8') as f:
            self.data = json.loads(f.read().lower())
        self.train, self.dev, self.test = [], [], []
        use_all = 'all' in cfg.exp_domains
        for fn, dial in self.data.items():
            if not (use_all or self.exp_files.get(fn)):
                continue
            if self.dev_files.get(fn):
                split = self.dev
            elif self.test_files.get(fn):
                split = self.test
            else:
                split = self.train
            split.append(self._get_encoded_data(fn, dial))
        if save_temp:
            with open('data/multi-woz-analysis/test.encoded.json', 'w') as f:
                json.dump(self.test, f, indent=2)
            self.vocab.save_vocab('data/multi-woz-analysis/vocab_temp')

        random.shuffle(self.train)
        random.shuffle(self.dev)
        random.shuffle(self.test)

    def _get_encoded_data(self, fn, dial):
        """Encode every turn in ``dial['log']`` and return the list of turns.

        Each returned dict holds the dialogue id, the id-encoded text spans
        (each closed by its end-of-sequence token), the DB pointer as a list
        of ints, the turn domain tokens, the turn number, the TRADE gating and
        pointer labels and, when ``cfg.multi_acts_training`` is set, the
        augmented act spans recorded for the turn in ``self.multi_acts``.
        """
        # (output key, source field of the turn, closing token)
        span_specs = (
            ('user', 'user', '<eos_u>'),
            ('usdx', 'user_delex', '<eos_u>'),
            ('resp', 'resp', '<eos_r>'),
            ('bspn', 'constraint', '<eos_b>'),
            ('bsdx', 'cons_delex', '<eos_b>'),
            ('aspn', 'sys_act', '<eos_a>'),
            ('dspn', 'turn_domain', '<eos_d>'),
        )
        encoded_turns = []
        for turn_idx, turn in enumerate(dial['log']):
            enc = {'dial_id': fn}
            for key, source, closing in span_specs:
                enc[key] = self.vocab.sentence_encode(turn[source].split() +
                                                      [closing])
            enc['pointer'] = [int(part) for part in turn['pointer'].split(',')]
            enc['turn_domain'] = turn['turn_domain'].split()
            enc['turn_num'] = turn['turn_num']

            # TRADE labels: one gate id and one id sequence per comma-separated slot
            enc["gating_label"] = [
                self.gating_dict[name]
                for name in turn["gating_label"].split(",")
            ]
            slot_sequences = []
            for slot_text in turn["ptr_label"].split(","):
                ids = []
                for word in slot_text.split():
                    word_id = self.vocab.encode(word)
                    if word_id >= self.vocab.vocab_size:
                        # ids outside the model vocabulary become <unk>
                        word_id = self.vocab.encode("<unk>")
                    ids.append(word_id)
                ids.append(self.vocab.encode("<eos_b>"))
                slot_sequences.append(ids)
            enc["ptr_label"] = slot_sequences

            if cfg.multi_acts_training:
                augmented = []
                if fn in self.multi_acts:
                    per_type = self.multi_acts[fn].get(str(turn_idx), {})
                    for spans in per_type.values():
                        augmented.append([
                            self.vocab.sentence_encode(span.split() +
                                                       ['<eos_a>'])
                            for span in spans
                        ])
                enc['aspn_aug'] = augmented

            encoded_turns.append(enc)
        return encoded_turns

    def bspan_to_constraint_dict(self, bspan, bspn_mode='bspn'):
        """Parse a belief span into ``{domain: {slot: value}}``.

        Args:
            bspan: a whitespace-separated string, or a sequence whose items
                are token strings or vocabulary ids (non-string items are
                decoded with ``self.vocab.decode``).
            bspn_mode: with ``'bsdx'`` slot values are not read and every
                slot seen is recorded with the value ``1``.

        Returns:
            dict keyed by domain name (brackets stripped). Parsing stops at
            ``<eos_b>``. A bracketed token naming a domain in
            ``ontology.all_domains`` switches the current domain; a token in
            ``ontology.get_slot`` opens a slot whose value is the run of
            following tokens up to the next slot token, bracketed token or
            ``<eos_b>``, joined by single spaces. Slots seen before any
            domain token are ignored.
        """
        tokens = bspan.split() if isinstance(bspan, str) else bspan

        def as_word(tok):
            # ids are decoded, strings pass through unchanged
            return tok if isinstance(tok, str) else self.vocab.decode(tok)

        constraint_dict = {}
        domain = None
        conslen = len(tokens)
        for idx, cons in enumerate(tokens):
            cons = as_word(cons)
            if cons == '<eos_b>':
                break
            if '[' in cons:
                if cons[1:-1] not in ontology.all_domains:
                    continue
                domain = cons[1:-1]
            elif cons in ontology.get_slot:
                if domain is None:
                    continue
                if cons == 'people':
                    # "people" followed by "'s" is part of a value name
                    # ("people's portraits ..."), not the slot. A trailing
                    # "people" with nothing after it is skipped as well; the
                    # explicit bounds check replaces a bare ``except`` that
                    # hid every other error too.
                    if idx + 1 >= conslen or as_word(tokens[idx + 1]) == "'s":
                        continue
                if not constraint_dict.get(domain):
                    constraint_dict[domain] = {}
                if bspn_mode == 'bsdx':
                    constraint_dict[domain][cons] = 1
                    continue
                # collect the value: everything up to the next structural token
                vt_collect = []
                for vidx in range(idx + 1, conslen):
                    vt = as_word(tokens[vidx])
                    if vt == '<eos_b>' or '[' in vt or vt in ontology.get_slot:
                        break
                    vt_collect.append(vt)
                if vt_collect:
                    constraint_dict[domain][cons] = ' '.join(vt_collect)

        return constraint_dict

    def bspan_to_DBpointer(self, bspan, turn_domain):
        """Return the DB pointer vector for the turn's matching domain.

        The belief span is parsed into constraints, the database is asked for
        per-domain match counts, and the count of one domain is turned into a
        vector by ``self.db.addDBPointer``. The domain used is the only entry
        of ``turn_domain`` when it has exactly one, otherwise the second
        entry; surrounding brackets are stripped from it.
        """
        match_counts = self.db.get_match_num(
            self.bspan_to_constraint_dict(bspan))
        if len(turn_domain) == 1:
            target = turn_domain[0]
        else:
            target = turn_domain[1]
        if target.startswith('['):
            target = target[1:-1]
        return self.db.addDBPointer(target, match_counts[target])

    def aspan_to_act_list(self, aspan):
        """Flatten a system-act span into ``'domain-act-slot'`` strings.

        ``aspan`` is a whitespace-separated string or a sequence of token
        strings / vocabulary ids. Reading stops at ``<eos_a>``. A bracketed
        token naming a key of ``ontology.dialog_acts`` sets the current
        domain; a bracketed token naming a key of
        ``ontology.dialog_act_params`` is an act, and each following
        non-bracketed token is one of its slots. An act with no slot yields
        ``'domain-act-none'``; acts seen before any domain are dropped.
        """
        tokens = aspan.split() if isinstance(aspan, str) else aspan
        total = len(tokens)

        def word_at(pos):
            tok = tokens[pos]
            return tok if type(tok) is str else self.vocab.decode(tok)

        act_list = []
        current_domain = None
        for pos in range(total):
            tok = word_at(pos)
            if tok == '<eos_a>':
                break
            if '[' not in tok:
                continue
            name = tok[1:-1]
            if name in ontology.dialog_acts:
                current_domain = name
                continue
            if name not in ontology.dialog_act_params:
                continue
            if current_domain is None:
                continue

            # gather the slot tokens that directly follow this act
            slots = []
            for nxt in range(pos + 1, total):
                follower = word_at(nxt)
                if follower == '<eos_a>' or '[' in follower:
                    break
                slots.append(follower)

            prefix = current_domain + '-' + name + '-'
            if slots:
                act_list.extend(prefix + slot for slot in slots)
            else:
                act_list.append(prefix + 'none')

        return act_list

    def dspan_to_domain(self, dspan):
        """Collect the domain tokens of a domain span into a ``{token: 1}`` dict.

        ``dspan`` is a whitespace-separated string or a sequence of token
        strings / vocabulary ids; everything from ``<eos_d>`` on is ignored.
        Tokens are stored as they appear (brackets are not stripped).
        """
        tokens = dspan.split() if isinstance(dspan, str) else dspan
        found = {}
        for tok in tokens:
            name = tok if type(tok) is str else self.vocab.decode(tok)
            if name == '<eos_d>':
                break
            found[name] = 1
        return found

    def convert_batch(self, py_batch, py_prev, first_turn=False):
        """Turn one batch of encoded turns into padded numpy inputs.

        Args:
            py_batch: dict of per-field lists for the current turn.
            py_prev: dict of per-field lists carried over from the previous
                turn (entries may be ``None``).
            first_turn: when true, every previous-turn entry is replaced by a
                column of ones instead of being padded.

        Returns:
            dict with ``<field>_np`` and ``<field>_unk_np`` arrays for the
            previous and current fields, the optional augmented act spans,
            ``db_np``, ``turn_domain`` and the TRADE label arrays. Fields
            switched off through ``cfg.enable_aspn`` / ``cfg.enable_bspn`` /
            ``cfg.enable_dspn`` are left out.
        """
        def pad(seqs, trunc_method):
            return utils.padSeqs(seqs,
                                 truncated=cfg.truncated,
                                 trunc_method=trunc_method)

        def oov_to_unk(arr):
            # copy, then overwrite ids at or above the vocabulary size
            masked = deepcopy(arr)
            masked[masked >= self.vocab_size] = 2  # <unk>
            return masked

        switches = (('aspn', cfg.enable_aspn), ('bspn', cfg.enable_bspn),
                    ('dspn', cfg.enable_dspn))
        disabled = [tag for tag, enabled in switches if not enabled]

        inputs = {}
        if first_turn:
            batch_size = len(py_batch['user'])
            for item in py_prev:
                inputs[item + '_np'] = np.array([[1]] * batch_size)
                inputs[item + '_unk_np'] = np.array([[1]] * batch_size)
        else:
            for item, py_list in py_prev.items():
                if py_list is None:
                    continue
                # previous-turn keys are matched by substring (e.g. 'pv_aspn')
                if any(tag in item for tag in disabled):
                    continue
                padded = pad(py_list, 'pre')
                inputs[item + '_np'] = padded
                if item in ['pv_resp', 'pv_bspn', "pv_aspn"]:
                    inputs[item + '_unk_np'] = oov_to_unk(padded)
                else:
                    inputs[item + '_unk_np'] = padded

        for item in ['user', 'usdx', 'resp', 'bspn', 'aspn', 'bsdx', 'dspn']:
            # current-turn keys are matched exactly, so 'bsdx' is never skipped
            if item in disabled:
                continue
            padded = pad(py_batch[item], 'post' if item == 'resp' else 'pre')
            inputs[item + '_np'] = padded
            if item in ['user', 'usdx', 'resp', 'bspn', "aspn"]:
                inputs[item + '_unk_np'] = oov_to_unk(padded)
            else:
                inputs[item + '_unk_np'] = padded

        if cfg.multi_acts_training and cfg.mode == 'train':
            inputs['aspn_bidx'], multi_aspn = [], []
            for bidx, aspn_type_list in enumerate(py_batch['aspn_aug']):
                if not aspn_type_list:
                    continue
                for aspn_list in aspn_type_list:
                    # shuffling in place makes the leading entries a random pick
                    random.shuffle(aspn_list)
                    chosen = [aspn_list[0]]
                    if cfg.multi_act_sampling_num > 1:
                        chosen.extend(
                            aspn_list[1:cfg.multi_act_sampling_num + 1])
                    multi_aspn.extend(chosen)
                    inputs['aspn_bidx'].extend([bidx] * len(chosen))

            if multi_aspn:
                # one row per sampled act span in the batch
                inputs['aspn_aug_np'] = pad(multi_aspn, 'pre')
                inputs['aspn_aug_unk_np'] = inputs['aspn_aug_np']

        inputs['db_np'] = np.array(py_batch['pointer'])
        inputs['turn_domain'] = py_batch['turn_domain']

        # TRADE: gate labels plus zero-padded pointer label sequences
        gating = np.array(py_batch["gating_label"])
        inputs["gating_label_np"] = gating
        ptr_labels = py_batch["ptr_label"]
        longest = max(len(label) for batch in ptr_labels for label in batch)
        ptr_np = np.zeros((len(ptr_labels), gating.shape[1], longest))
        for bidx, batch in enumerate(ptr_labels):
            for sidx, label in enumerate(batch):
                ptr_np[bidx, sidx, :len(label)] = np.array(label)
        inputs["ptr_label_np"] = ptr_np

        return inputs

    def wrap_result(self, result_dict, eos_syntax=None):
        """Decode model outputs into flat report rows.

        Args:
            result_dict: mapping of dialogue id to its list of turn dicts.
            eos_syntax: mapping of field name to end-of-sequence token;
                ``ontology.eos_tokens`` is used when it is empty or missing.

        Returns:
            ``(results, field)``: ``results`` holds, per dialogue, one header
            row (dialogue id, number of turns, every other column blank)
            followed by one row per turn; ``field`` is the column list, which
            depends on ``cfg.enable_trade``, ``cfg.bspn_mode`` and
            ``cfg.enable_dst``.
        """
        decode_fn = self.vocab.sentence_decode
        if not eos_syntax:
            eos_syntax = ontology.eos_tokens

        head = ['dial_id', 'turn_num', 'user']
        middle = ['resp_gen', 'resp', 'aspn_gen', 'aspn', 'dspn_gen', 'dspn']
        if cfg.enable_trade:
            field = head + ['bspn_gen', 'bspn'] + middle + [
                'pointer', "gating_label", "trade_gate", "ptr_label",
                "trade_ptr"
            ]
        elif cfg.bspn_mode == 'bspn':
            field = head + ['bspn_gen', 'bspn'] + middle + ['pointer']
        elif not cfg.enable_dst:
            field = head + ['bsdx_gen', 'bsdx'] + middle + ['bspn', 'pointer']
        else:
            field = head + ['bsdx_gen', 'bsdx'] + middle + [
                'bspn_gen', 'bspn', 'pointer'
            ]
        if self.multi_acts_record is not None:
            field.insert(7, 'multi_act_gen')

        results = []
        for dial_id, turns in result_dict.items():
            # header row describing the dialogue
            header = {'dial_id': dial_id, 'turn_num': len(turns)}
            header.update(dict.fromkeys(field[2:], ''))
            results.append(header)

            for turn in turns:
                entry = {'dial_id': dial_id}
                trade_gate = []
                for key in field:
                    if key == 'dial_id':
                        continue
                    v = turn.get(key, '')
                    if key == "trade_gate":
                        # remembered so that 'trade_ptr' (later in the field
                        # list) can look up the gate of each slot
                        trade_gate = v
                        entry[key] = v
                    elif key == "trade_ptr":
                        decoded = []
                        for idx, slot in enumerate(v):
                            gate = trade_gate[idx].item()
                            if gate == 0:
                                decoded.append(
                                    decode_fn(slot.tolist(),
                                              eos=eos_syntax["trade"]))
                            elif gate == 1:
                                decoded.append("do n't care")
                            else:
                                decoded.append("none")
                        entry[key] = decoded
                    elif key == "ptr_label":
                        entry[key] = [
                            decode_fn(slot, eos=eos_syntax["trade"])
                            for slot in v
                        ]
                    else:
                        if key == 'turn_domain':
                            v = ' '.join(v)
                        if key in eos_syntax and v != '':
                            entry[key] = decode_fn(v, eos=eos_syntax[key])
                        else:
                            entry[key] = v
                results.append(entry)
        return results, field

    def restore(self, resp, domain, constraint_dict, mat_ents):
        """Fill the ``[value_*]`` placeholders of a delexicalized response.

        Args:
            resp: response text containing placeholders such as
                ``[value_name]``.
            domain: list of domain names active in the turn; may be empty.
            constraint_dict: ``{domain: {slot: value}}`` user constraints.
            mat_ents: ``{domain: [entity dict, ...]}`` database matches.

        Returns:
            The response with placeholders replaced, in this order: fixed
            reference number and car, user constraints, the match count for
            ``[value_choice]`` (``'3'`` when no count is available), fields of
            the first matched entity of the last domain, and finally fixed
            phone / postcode / address values.

        Raises:
            AttributeError, IndexError, KeyError, TypeError: if the matched
                entity data is malformed. The original and partially restored
                texts are printed first. (Previously any error here, including
                an empty ``domain`` list, called ``quit()`` and terminated the
                interpreter.)
        """
        restored = resp

        restored = restored.replace('[value_reference]', '53022')
        restored = restored.replace('[value_car]', 'BMW')

        for d in domain:
            constraint = constraint_dict.get(d, None)
            if constraint:
                for slot in ('stay', 'day', 'people', 'time', 'type'):
                    if slot in constraint:
                        restored = restored.replace('[value_%s]' % slot,
                                                    constraint[slot])
                if d in mat_ents and len(mat_ents[d]) == 0:
                    # nothing matched: echo the user's own constraints back
                    for s in constraint:
                        if s == 'pricerange' and d in [
                                'hotel', 'restaurant'
                        ] and 'price]' in restored:
                            restored = restored.replace(
                                '[value_price]', constraint['pricerange'])
                        if s + ']' in restored:
                            restored = restored.replace(
                                '[value_%s]' % s, constraint[s])

            if '[value_choice' in restored and mat_ents.get(d):
                restored = restored.replace('[value_choice]',
                                            str(len(mat_ents[d])))
        if '[value_choice' in restored:
            restored = restored.replace('[value_choice]', '3')

        try:
            # with no active domain there is no entity to draw values from
            ent = mat_ents.get(domain[-1], []) if domain else []
            if ent:
                ent = ent[0]

                for t in restored.split():
                    if '[value' in t:
                        slot = t[7:-1]  # strip '[value_' and the closing ']'
                        if ent.get(slot):
                            if domain[-1] == 'hotel' and slot == 'price':
                                slot = 'pricerange'
                            restored = restored.replace(t, ent[slot])
                        elif slot == 'price':
                            if ent.get('pricerange'):
                                restored = restored.replace(
                                    t, ent['pricerange'])
                            else:
                                print(restored, domain)
        except (AttributeError, IndexError, KeyError, TypeError):
            # keep the diagnostics, but let the caller decide what to do
            # instead of ending the whole process
            print(resp)
            print(restored)
            raise

        restored = restored.replace('[value_phone]', '62781111')
        restored = restored.replace('[value_postcode]', 'CG9566')
        restored = restored.replace('[value_address]', 'Parkside, Cambridge')

        return restored

    def record_utterance(self, result_dict):
        """Write decoded dialogues to ``dialogue_record.csv``, most diverse first.

        Each dialogue gets a diversity score: per turn, three points for every
        distinct ``domain-act`` pair and one point for every distinct slot of
        such a pair found across the ``cfg.nbest`` candidate act spans,
        averaged over the dialogue's turns. Dialogues are written in
        descending score order. For every turn the CSV gets one row with the
        reference data and then one row per candidate response, shortest
        restored response first.

        Args:
            result_dict: mapping of dialogue id to its list of turn dicts;
                each turn provides ``user``, ``bspn``, ``aspn``, ``resp``,
                ``dspn``, ``pointer``, ``multi_act`` and ``multi_resp``.
        """
        decode_fn = self.vocab.sentence_decode

        diversity = {}
        for dial_id, turns in result_dict.items():
            diverse = 0
            for turn in turns:
                act_collect = {}
                slot_score = 0
                for i in range(cfg.nbest):
                    aspn = decode_fn(turn['multi_act'][i],
                                     eos=ontology.eos_tokens['aspn'])
                    # The decoded span is handed over as-is. It is used as a
                    # string elsewhere in this method, and joining a string
                    # with ' ' would split it into single characters so that
                    # no act could ever be recognised.
                    for act in self.aspan_to_act_list(aspn):
                        # maxsplit keeps a slot token containing '-' intact
                        d, a, s = act.split('-', 2)
                        slots = act_collect.setdefault(d + '-' + a, {})
                        # count a slot once per act (the check used to look
                        # the slot up among the act keys, so repeats counted)
                        if s not in slots:
                            slots[s] = 1
                            slot_score += 1
                diverse += len(act_collect) * 3 + slot_score
            # a dialogue without turns scores 0 instead of dividing by zero
            diversity[dial_id] = diverse / len(turns) if turns else 0

        ordered_dial = sorted(diversity, key=lambda x: -diversity[x])

        # NOTE(review): dialog_record is filled below but neither returned nor
        # saved within this method -- confirm whether a caller-side use or a
        # dump step is missing.
        dialog_record = {}

        # newline='' lets the csv module control line endings itself
        with open(cfg.eval_load_path + '/dialogue_record.csv', 'w',
                  newline='') as rf:
            writer = csv.writer(rf)

            for dial_id in ordered_dial:
                dialog_record[dial_id] = []
                turns = result_dict[dial_id]
                writer.writerow([dial_id])
                for turn_no, turn in enumerate(turns):
                    user = decode_fn(turn['user'],
                                     eos=ontology.eos_tokens['user'])
                    bspn = decode_fn(turn['bspn'],
                                     eos=ontology.eos_tokens['bspn'])
                    aspn = decode_fn(turn['aspn'],
                                     eos=ontology.eos_tokens['aspn'])
                    resp = decode_fn(turn['resp'],
                                     eos=ontology.eos_tokens['resp'])
                    constraint_dict = self.bspan_to_constraint_dict(bspn)
                    mat_ents = self.db.get_match_num(constraint_dict, True)
                    domain = [
                        i[1:-1] for i in self.dspan_to_domain(turn['dspn'])
                    ]
                    restored = self.restore(resp, domain, constraint_dict,
                                            mat_ents)
                    writer.writerow([
                        turn_no, user, turn['pointer'], domain, restored, resp
                    ])
                    turn_record = {
                        'user': user,
                        'bspn': bspn,
                        'aspn': aspn,
                        'dom': domain,
                        'resp': resp,
                        'resp_res': restored
                    }

                    # (restored response, raw response, act span) per candidate
                    candidates = []
                    for i in range(cfg.nbest):
                        cand_aspn = decode_fn(turn['multi_act'][i],
                                              eos=ontology.eos_tokens['aspn'])
                        cand_resp = decode_fn(turn['multi_resp'][i],
                                              eos=ontology.eos_tokens['resp'])
                        cand_restored = self.restore(cand_resp, domain,
                                                     constraint_dict,
                                                     mat_ents)
                        candidates.append(
                            (cand_restored, cand_resp, cand_aspn))

                    # stable sort: shortest restored response first
                    candidates.sort(key=lambda c: len(c[0]))
                    turn_record['aspn_col'] = [c[2] for c in candidates]
                    turn_record['resp_col'] = [c[1] for c in candidates]
                    turn_record['resp_res_col'] = [c[0] for c in candidates]
                    for cand_restored, cand_resp, cand_aspn in candidates:
                        writer.writerow(
                            ['', cand_restored, cand_resp, cand_aspn])

                    dialog_record[dial_id].append(turn_record)
コード例 #6
0
ファイル: reader.py プロジェクト: Verylovenlp/MinTL-SKKU
class MultiWozReader(_ReaderBase):
    def __init__(self, vocab=None):
        """Load resources, build the vocabulary and encode the dataset.

        Args:
            vocab: optional ready-made vocabulary object. When given it is
                used as-is and ``self.vocab_size`` is taken from its ``size``
                attribute; otherwise the vocabulary is built by
                ``_build_vocab``.

        Raises:
            ValueError: if a domain in ``cfg.exp_domains`` has no (or an
                empty) file list in the domain-file index.
        """
        super().__init__()
        self.nlp = spacy.load('en_core_web_sm')
        self.db = MultiWozDB(cfg.dbs)

        # every file is read inside a context manager so the handle is closed
        # right away instead of being left to the garbage collector
        def read_json(path):
            with open(path, 'r') as f:
                return json.loads(f.read())

        def read_name_list(path):
            with open(path, 'r') as f:
                return [line.strip().lower() for line in f]

        self.domain_files = read_json(cfg.domain_file_path)
        self.slot_value_set = read_json(cfg.slot_value_set_path)
        if cfg.multi_acts_training:
            self.multi_acts = read_json(cfg.multi_acts_path)

        test_list = read_name_list(cfg.test_list)
        dev_list = read_name_list(cfg.dev_list)
        # dialogue ids are stored without the '.json' suffix
        self.dev_files, self.test_files = {}, {}
        for fn in test_list:
            self.test_files[fn.replace('.json', '')] = 1
        for fn in dev_list:
            self.dev_files[fn.replace('.json', '')] = 1

        # restrict to the dialogues of the selected domains unless 'all' is set
        self.exp_files = {}
        if 'all' not in cfg.exp_domains:
            for domain in cfg.exp_domains:
                fn_list = self.domain_files.get(domain)
                if not fn_list:
                    raise ValueError('[%s] is an invalid experiment setting' %
                                     domain)
                for fn in fn_list:
                    self.exp_files[fn.replace('.json', '')] = 1

        if vocab:
            self.vocab = vocab
            # NOTE(review): reads `vocab.size`, while _build_vocab returns
            # `vocab.vocab_size` -- confirm both attributes exist and agree.
            self.vocab_size = vocab.size
        else:
            self.vocab_size = self._build_vocab()
        self._load_data()

        if cfg.limit_bspn_vocab:
            self.bspn_masks = self._construct_bspn_constraint()
        if cfg.limit_aspn_vocab:
            self.aspn_masks = self._construct_aspn_constraint()

        self.multi_acts_record = None

    def _build_vocab(self):
        """Create ``self.vocab``, load it from disk and return its size.

        The training vocabulary path is used in train mode or when no
        evaluation vocabulary path is configured; otherwise the evaluation
        vocabulary path is used.
        """
        self.vocab = utils.Vocab(cfg.vocab_size)
        if cfg.mode == 'train' or cfg.vocab_path_eval is None:
            vocab_file = cfg.vocab_path_train
        else:
            vocab_file = cfg.vocab_path_eval
        self.vocab.load_vocab(vocab_file)
        return self.vocab.vocab_size

    def _construct_bspn_constraint(self):
        """Build the next-token whitelist for belief-span decoding.

        Returns:
            dict mapping a token id to the list of token ids allowed directly
            after it, derived from ``self.slot_value_set`` (the ``police``
            domain is skipped). A readable dump of the table is also written
            to ``data/multi-woz-processed/bspn_masks.txt``.
        """
        encode = self.vocab.encode
        has_word = self.vocab.has_word

        masks = {}
        domain_codes = [
            encode('[' + name + ']')
            for name in ('restaurant', 'hotel', 'attraction', 'train', 'taxi',
                         'hospital')
        ]
        slot_codes = [encode(name) for name in ontology.all_slots]
        eos_code = encode('<eos_b>')
        pad_code = encode('<pad>')
        # the literal id 0 is kept as written in the original table
        # (presumably the padding id -- confirm against the vocabulary)
        masks[encode('<go_b>')] = domain_codes + [eos_code, 0]
        masks[eos_code] = [pad_code]
        masks[pad_code] = [pad_code]

        # what may follow the last word of a value: end of span first, then
        # any domain token or slot token
        after_value = [eos_code] + domain_codes + slot_codes

        for domain, slot_values in self.slot_value_set.items():
            if domain == 'police':
                continue
            dom_code = encode('[' + domain + ']')
            masks[dom_code] = []
            for slot, values in slot_values.items():
                slot_code = encode(slot)
                masks.setdefault(slot_code, [])
                if slot_code not in masks[dom_code]:
                    masks[dom_code].append(slot_code)
                for value in values:
                    words = value.split()
                    last_pos = len(words) - 1
                    for pos, word in enumerate(words):
                        if not has_word(word):
                            continue
                        code = encode(word)
                        masks.setdefault(code, [])
                        # the first word of a value may directly follow the slot
                        if pos == 0 and code not in masks[slot_code]:
                            masks[slot_code].append(code)
                        if pos == last_pos:
                            for nxt in after_value:
                                if nxt not in masks[code]:
                                    masks[code].append(nxt)
                            break
                        following = words[pos + 1]
                        if not has_word(following):
                            continue
                        following_code = encode(following)
                        if following_code not in masks[code]:
                            masks[code].append(following_code)

        # <unk> may be followed by every token that has an entry so far
        masks[encode('<unk>')] = list(masks)

        with open('data/multi-woz-processed/bspn_masks.txt', 'w') as f:
            for code, followers in masks.items():
                names = [self.vocab.decode(int(m)) for m in followers]
                f.write(self.vocab.decode(code) + ': ' + ' '.join(names) + '\n')
        return masks

    def _construct_aspn_constraint(self):
        """Build the next-token whitelist for system-act span decoding.

        Returns:
            dict mapping a token id to the list of token ids allowed directly
            after it. A readable dump of the table is also written to
            ``data/multi-woz-processed/aspn_masks.txt``.
        """
        encode = self.vocab.encode
        domain_codes = [encode('[' + name + ']') for name in ontology.dialog_acts]
        act_codes = [
            encode('[' + name + ']') for name in ontology.dialog_act_params
        ]
        slot_codes = [encode(name) for name in ontology.dialog_act_all_slots]
        eos_code = encode('<eos_a>')
        pad_code = encode('<pad>')

        masks = {}
        # the literal id 0 is kept as written in the original table
        # (presumably the padding id -- confirm against the vocabulary)
        masks[encode('<go_a>')] = domain_codes + [eos_code, 0]
        masks[eos_code] = [pad_code]
        masks[pad_code] = [pad_code]

        # after an act token or a slot token: any domain, any slot, or the end
        after_act_or_slot = domain_codes + slot_codes + [eos_code]
        for code in act_codes:
            masks[code] = list(after_act_or_slot)

        # a domain token may only be followed by the acts listed for that domain
        for domain, acts in ontology.dialog_acts.items():
            dom_code = encode('[' + domain + ']')
            allowed = masks[dom_code] = []
            for name in acts:
                act_code = encode('[' + name + ']')
                if act_code not in allowed:
                    allowed.append(act_code)

        for code in slot_codes:
            masks[code] = list(after_act_or_slot)

        # <unk> may be followed by every token that has an entry so far
        masks[encode('<unk>')] = list(masks)

        with open('data/multi-woz-processed/aspn_masks.txt', 'w') as f:
            for code, followers in masks.items():
                names = [self.vocab.decode(int(m)) for m in followers]
                f.write(self.vocab.decode(code) + ': ' + ' '.join(names) + '\n')
        return masks

    def _load_data(self, save_temp=False):
        """Load the dialogue file and split it into encoded train/dev/test sets.

        The file at ``cfg.data_path + cfg.data_file`` is read as UTF-8,
        lower-cased as raw text and parsed as JSON into ``self.data``. Each
        dialogue that passes the experiment filter (``'all'`` in
        ``cfg.exp_domains`` or a truthy ``self.exp_files`` entry) is encoded
        with ``_get_encoded_data`` and appended to ``self.dev`` or
        ``self.test`` when listed in ``self.dev_files`` / ``self.test_files``,
        otherwise to ``self.train``. All three lists are shuffled in place at
        the end.

        Args:
            save_temp: when true, dump the encoded test set (before shuffling)
                to ``data/multi-woz-analysis/test.encoded.json`` and save the
                vocabulary to ``data/multi-woz-analysis/vocab_temp``.
        """
        # Context managers make sure the handles are closed (and the dump
        # below is flushed) even when parsing or encoding raises.
        with open(cfg.data_path + cfg.data_file, 'r',
                  encoding='utf-8') as data_file:
            self.data = json.loads(data_file.read().lower())
        self.train, self.dev, self.test = [], [], []

        use_all = 'all' in cfg.exp_domains
        for fn, dial in self.data.items():
            if not (use_all or self.exp_files.get(fn)):
                continue
            if self.dev_files.get(fn):
                split = self.dev
            elif self.test_files.get(fn):
                split = self.test
            else:
                split = self.train
            split.append(self._get_encoded_data(fn, dial))

        if save_temp:
            with open('data/multi-woz-analysis/test.encoded.json',
                      'w') as dump_file:
                json.dump(self.test, dump_file, indent=2)
            self.vocab.save_vocab('data/multi-woz-analysis/vocab_temp')

        random.shuffle(self.train)
        random.shuffle(self.dev)
        random.shuffle(self.test)

    def _get_encoded_data(self, fn, dial):
        """Encode every turn in ``dial['log']`` with the vocabulary.

        Args:
            fn: dialogue id; stored in each turn under ``'dial_id'`` and used
                as the lookup key into ``self.multi_acts``.
            dial: dialogue dict whose ``'log'`` entry is the list of turns.

        Returns:
            A list with one dict per turn holding the encoded text spans, the
            integer pointer vector (also aliased as ``'input_pointer'``), the
            split turn domain, the turn number and, when
            ``cfg.multi_acts_training`` is set, the ``'aspn_aug'`` list of
            encoded alternative act spans.
        """
        # (output key, source field of the turn, end-of-sequence token)
        span_specs = (
            ('user', 'user', '<eos_u>'),
            ('usdx', 'user_delex', '<eos_u>'),
            ('resp', 'resp', '<eos_r>'),
            ('bspn', 'constraint', '<eos_b>'),
            ('bsdx', 'cons_delex', '<eos_b>'),
            ('aspn', 'sys_act', '<eos_a>'),
            ('dspn', 'turn_domain', '<eos_d>'),
        )
        encode = self.vocab.sentence_encode

        turns = []
        for turn_idx, turn in enumerate(dial['log']):
            record = {'dial_id': fn}
            for out_key, src_key, eos in span_specs:
                record[out_key] = encode(turn[src_key].split() + [eos])

            pointer = [int(p) for p in turn['pointer'].split(',')]
            record['pointer'] = pointer
            record['input_pointer'] = pointer  # same list object, not a copy
            record['turn_domain'] = turn['turn_domain'].split()
            record['turn_num'] = turn['turn_num']

            if cfg.multi_acts_training:
                augmented = []
                record['aspn_aug'] = augmented
                if fn in self.multi_acts:
                    # turns are keyed by their index rendered as a string
                    by_type = self.multi_acts[fn].get(str(turn_idx), {})
                    for spans in by_type.values():
                        augmented.append(
                            [encode(s.split() + ['<eos_a>']) for s in spans])

            turns.append(record)
        return turns

    def bspan_to_constraint_dict(self, bspan, bspn_mode='bspn'):
        """Parse a belief span into a ``{domain: {slot: value}}`` dict.

        Args:
            bspan: whitespace-separated string, or a sequence whose items are
                either token strings or ids understood by
                ``self.vocab.decode``.
            bspn_mode: with ``'bsdx'`` every recognised slot is recorded with
                the value ``1`` and no value tokens are collected; any other
                mode collects the tokens following the slot as its value.

        Returns:
            dict keyed by domain name (without brackets). Parsing stops at
            ``'<eos_b>'``. Slots seen before any known ``[domain]`` token are
            ignored, as are bracketed tokens not in ``ontology.all_domains``.
        """
        tokens = bspan.split() if isinstance(bspan, str) else bspan
        conslen = len(tokens)

        def text_at(pos):
            # Tokens may arrive as strings or as vocabulary ids.
            tok = tokens[pos]
            return tok if isinstance(tok, str) else self.vocab.decode(tok)

        constraint_dict = {}
        domain = None
        for idx in range(conslen):
            cons = text_at(idx)
            if cons == '<eos_b>':
                break
            if '[' in cons:
                if cons[1:-1] in ontology.all_domains:
                    domain = cons[1:-1]
                continue
            if cons not in ontology.get_slot or domain is None:
                continue

            vidx = idx + 1
            if cons == 'people':
                # handle confusion of value name "people's portraits..." and
                # slot people: skip when nothing follows or when the next
                # token is the possessive "'s". (Explicit bounds check instead
                # of a bare ``except`` so unrelated errors are not hidden.)
                if vidx >= conslen or text_at(vidx) == "'s":
                    continue

            if not constraint_dict.get(domain):
                constraint_dict[domain] = {}
            if bspn_mode == 'bsdx':
                constraint_dict[domain][cons] = 1
                continue

            # The value runs until <eos_b>, a bracketed token or another slot.
            vt_collect = []
            while vidx < conslen:
                vt = text_at(vidx)
                if vt == '<eos_b>' or '[' in vt or vt in ontology.get_slot:
                    break
                vt_collect.append(vt)
                vidx += 1
            if vt_collect:
                constraint_dict[domain][cons] = ' '.join(vt_collect)

        return constraint_dict

    def bspan_to_DBpointer(self, bspan, turn_domain):
        """Return the DB pointer vector for the turn's active domain.

        The belief span is parsed into constraints, the database is asked for
        match numbers, and the entry of the active domain is handed to
        ``self.db.addDBPointer``. The active domain is the only element of
        ``turn_domain`` when it has exactly one, otherwise its second element;
        surrounding brackets are stripped when the name starts with ``'['``.
        """
        constraints = self.bspan_to_constraint_dict(bspan)
        match_counts = self.db.get_match_num(constraints)

        if len(turn_domain) == 1:
            active = turn_domain[0]
        else:
            active = turn_domain[1]
        if active.startswith('['):
            active = active[1:-1]

        return self.db.addDBPointer(active, match_counts[active])

    def aspan_to_act_list(self, aspan):
        """Flatten an action span into ``'domain-act-slot'`` strings.

        Args:
            aspan: whitespace-separated string, or a sequence whose items are
                token strings or ids understood by ``self.vocab.decode``.

        Returns:
            list of ``domain-act-slot`` strings in span order. An act token
            with no following slot tokens yields ``domain-act-none``. Act
            tokens seen before any domain token are ignored, and parsing stops
            at ``'<eos_a>'``.
        """
        tokens = aspan.split() if isinstance(aspan, str) else aspan
        total = len(tokens)

        def as_text(tok):
            return tok if type(tok) is str else self.vocab.decode(tok)

        acts = []
        domain = None
        for pos in range(total):
            word = as_text(tokens[pos])
            if word == '<eos_a>':
                break
            if '[' not in word:
                continue  # slot tokens are consumed by the look-ahead below

            label = word[1:-1]
            if label in ontology.dialog_acts:
                domain = label
                continue
            if label not in ontology.dialog_act_params or domain is None:
                continue

            # Everything up to the next bracketed token or <eos_a> is a slot
            # of this act.
            prefix = domain + '-' + label + '-'
            has_param = False
            for nxt in range(pos + 1, total):
                param = as_text(tokens[nxt])
                if param == '<eos_a>' or '[' in param:
                    break
                acts.append(prefix + param)
                has_param = True
            if not has_param:
                acts.append(prefix + 'none')

        return acts

    def dspan_to_domain(self, dspan):
        """Collect the tokens of a domain span that precede ``'<eos_d>'``.

        Args:
            dspan: whitespace-separated string, or a sequence whose items are
                token strings or ids understood by ``self.vocab.decode``.

        Returns:
            dict mapping each token (as it appears, brackets included) to 1,
            in first-seen order.
        """
        tokens = dspan.split() if isinstance(dspan, str) else dspan
        found = {}
        for raw in tokens:
            name = raw if type(raw) is str else self.vocab.decode(raw)
            if name == '<eos_d>':
                break
            found[name] = 1
        return found

    def convert_batch(self, py_batch, py_prev, first_turn=False):
        """Assemble the model input dict for one turn of a batch.

        Args:
            py_batch: dict of per-field python lists for the current turn; the
                fields read here are ``user``, ``usdx``, ``resp``, ``bspn``,
                ``aspn``, ``bsdx``, ``dspn``, ``input_pointer``,
                ``turn_domain`` and (for multi-act training) ``aspn_aug``.
            py_prev: dict of previous-turn fields (values may be ``None``).
            first_turn: when true, every previous-turn field is replaced by a
                ``(batch, 1)`` array filled with 1.

        Returns:
            dict with ``<field>_np`` arrays produced by ``utils.padSeqs`` and
            matching ``<field>_unk_np`` arrays. For the fields that get a
            separate copy, ids ``>= self.vocab_size`` are replaced by 2
            (``<unk>``); for the rest the ``_unk_np`` entry is the same array.
            Also contains ``db_np``, ``turn_domain`` and, in multi-act
            training mode, ``aspn_bidx`` / ``aspn_aug_np`` /
            ``aspn_aug_unk_np``.
        """
        inputs = {}

        def unk_copy(arr):
            # Independent copy with out-of-vocabulary ids mapped to 2 (<unk>).
            masked = deepcopy(arr)
            masked[masked >= self.vocab_size] = 2
            return masked

        # ---- previous-turn fields ----
        if first_turn:
            for item in py_prev:
                n_rows = len(py_batch['user'])
                inputs[item + '_np'] = np.array([[1]] * n_rows)
                inputs[item + '_unk_np'] = np.array([[1]] * n_rows)
        else:
            for item, py_list in py_prev.items():
                if py_list is None:
                    continue
                if ((not cfg.enable_aspn and 'aspn' in item)
                        or (not cfg.enable_bspn and 'bspn' in item)
                        or (not cfg.enable_dspn and 'dspn' in item)):
                    continue
                padded = utils.padSeqs(py_list,
                                       truncated=cfg.truncated,
                                       trunc_method='pre')
                inputs[item + '_np'] = padded
                if item in ('pv_resp', 'pv_bspn'):
                    inputs[item + '_unk_np'] = unk_copy(padded)
                else:
                    inputs[item + '_unk_np'] = padded

        # ---- current-turn fields ----
        for item in ('user', 'usdx', 'resp', 'bspn', 'aspn', 'bsdx', 'dspn'):
            if ((not cfg.enable_aspn and item == 'aspn')
                    or (not cfg.enable_bspn and item == 'bspn')
                    or (not cfg.enable_dspn and item == 'dspn')):
                continue
            # only responses are truncated with the 'post' method
            method = 'post' if item == 'resp' else 'pre'
            padded = utils.padSeqs(py_batch[item],
                                   truncated=cfg.truncated,
                                   trunc_method=method)
            inputs[item + '_np'] = padded
            if item in ('user', 'usdx', 'resp', 'bspn'):
                inputs[item + '_unk_np'] = unk_copy(padded)
            else:
                inputs[item + '_unk_np'] = padded

        # ---- sampled alternative act spans (multi-act training only) ----
        if cfg.multi_acts_training and cfg.mode == 'train':
            inputs['aspn_bidx'], multi_aspn = [], []
            for bidx, aspn_type_list in enumerate(py_batch['aspn_aug']):
                if not aspn_type_list:
                    continue
                for aspn_list in aspn_type_list:
                    # shuffled in place, then the leading spans are taken
                    random.shuffle(aspn_list)
                    picks = [aspn_list[0]]
                    if cfg.multi_act_sampling_num > 1:
                        picks.extend(
                            aspn_list[1:cfg.multi_act_sampling_num + 1])
                    multi_aspn.extend(picks)
                    inputs['aspn_bidx'].extend([bidx] * len(picks))

            if multi_aspn:
                # [all available aspn num in the batch, T]
                aug = utils.padSeqs(multi_aspn,
                                    truncated=cfg.truncated,
                                    trunc_method='pre')
                inputs['aspn_aug_np'] = aug
                inputs['aspn_aug_unk_np'] = aug

        inputs['db_np'] = np.array(py_batch['input_pointer'])
        inputs['turn_domain'] = py_batch['turn_domain']

        return inputs

    def wrap_result(self, result_dict, eos_syntax=None):
        """Flatten per-dialogue results into a list of row dicts.

        Args:
            result_dict: mapping of dialogue id to its list of turn dicts.
            eos_syntax: mapping of field name to end-of-sequence token used
                when decoding; falls back to ``ontology.eos_tokens`` when
                falsy.

        Returns:
            ``(results, field)``: ``results`` holds, per dialogue, one header
            row (dialogue id, number of turns, every other column empty)
            followed by one row per turn; ``field`` is the column order, which
            depends on ``cfg.bspn_mode``, ``cfg.enable_dst`` and whether
            ``self.multi_acts_record`` is set.
        """
        decode = self.vocab.sentence_decode
        if not eos_syntax:
            eos_syntax = ontology.eos_tokens

        # Columns common to all three layouts, between the belief-state
        # columns and the layout-specific tail.
        shared = ['resp_gen', 'resp', 'aspn_gen', 'aspn', 'dspn_gen', 'dspn']
        if cfg.bspn_mode == 'bspn':
            field = (['dial_id', 'turn_num', 'user', 'bspn_gen', 'bspn'] +
                     shared + ['pointer'])
        elif not cfg.enable_dst:
            field = (['dial_id', 'turn_num', 'user', 'bsdx_gen', 'bsdx'] +
                     shared + ['bspn', 'pointer'])
        else:
            field = (['dial_id', 'turn_num', 'user', 'bsdx_gen', 'bsdx'] +
                     shared + ['bspn_gen', 'bspn', 'pointer'])
        if self.multi_acts_record is not None:
            field.insert(7, 'multi_act_gen')

        results = []
        for dial_id, turns in result_dict.items():
            header = {'dial_id': dial_id, 'turn_num': len(turns)}
            header.update(dict.fromkeys(field[2:], ''))
            results.append(header)

            for turn in turns:
                row = {'dial_id': dial_id}
                for key in field:
                    if key == 'dial_id':
                        continue
                    value = turn.get(key, '')
                    if key == 'turn_domain':
                        value = ' '.join(value)
                    # decode only fields that have an eos token and a value
                    if key in eos_syntax and value != '':
                        value = decode(value, eos=eos_syntax[key])
                    row[key] = value
                results.append(row)
        return results, field

    def restore(self, resp, domain, constraint_dict):
        """Fill the ``[value_*]`` placeholders of a delexicalized response.

        Args:
            resp: response string; it is passed through ``str.capitalize``
                first, and the separated suffixes `` -s``, `` -ly``, `` -er``
                are re-attached.
            domain: list of domain names for the turn; entities of the last
                one are used for lexicalisation. May be empty.
            constraint_dict: ``{domain: {slot: value}}`` belief state; it is
                used to query ``self.db`` and to fill booking-style
                placeholders directly.

        Returns:
            The response with placeholders replaced. ``[value_reference]``
            gets a random entry of ``self.delex_refs`` (upper-cased),
            ``[value_car]`` becomes ``BMW``, an unresolved ``[value_choice]``
            becomes a random number from 1 to 5, and any other placeholder
            that cannot be resolved becomes ``UNKNOWN``. The token after
            ``.``, ``?`` or ``!`` is capitalized.
        """
        stopwords = frozenset((
            "i", "me", "my", "myself", "we", "our", "ours", "ourselves", "you",
            "your", "yours", "yourself", "yourselves", "he", "him", "his",
            "himself", "she", "her", "hers", "herself", "it", "its", "itself",
            "they", "them", "their", "theirs", "themselves", "what", "which",
            "who", "whom", "this", "that", "these", "those", "am", "is", "are",
            "was", "were", "be", "been", "being", "have", "has", "had",
            "having", "do", "does", "did", "doing", "a", "an", "the", "and",
            "but", "if", "or", "because", "as", "until", "while", "of", "at",
            "by", "for", "with", "about", "against", "between", "into",
            "through", "during", "before", "after", "above", "below", "to",
            "from", "up", "down", "in", "out", "on", "off", "over", "under",
            "again", "further", "then", "once", "here", "there", "when",
            "where", "why", "how", "all", "any", "both", "each", "few", "more",
            "most", "other", "some", "such", "no", "nor", "not", "only", "own",
            "same", "so", "than", "too", "very", "s", "t", "can", "will",
            "just", "don", "should", "now",
        ))
        # An empty domain list simply means there is no entity to draw from.
        last_domain = domain[-1] if domain else None

        def entity_value(entity, slot):
            # Text to substitute for `slot` from one DB entity, or None when
            # the entity cannot provide it. `slot` is local here, so the
            # hotel remap below never leaks from one entity to the next.
            if entity.get(slot):
                if last_domain == 'hotel' and slot == 'price':
                    # hotel price placeholders are filled from 'pricerange'
                    slot = 'pricerange'
                    if slot not in entity:
                        return None
                raw = entity[slot]
                if slot in ('name', 'address'):
                    return ' '.join(w if w in stopwords else w.capitalize()
                                    for w in raw.split())
                if slot in ('id', 'postcode'):
                    return raw.upper()
                return raw
            if slot == 'price' and entity.get('pricerange'):
                return entity['pricerange']
            return None

        restored = resp.capitalize()
        for suffix in ('s', 'ly', 'er'):
            restored = restored.replace(' -' + suffix, suffix)

        mat_ents = self.db.get_match_num(constraint_dict, True)

        ref = random.choice(self.delex_refs)
        restored = restored.replace('[value_reference]', ref.upper())
        restored = restored.replace('[value_car]', 'BMW')

        for d in domain:
            constraint = constraint_dict.get(d, None)
            if constraint:
                # these placeholders are copied straight from the belief state
                for slot in ('stay', 'day', 'people', 'time', 'type'):
                    if slot in constraint:
                        restored = restored.replace('[value_%s]' % slot,
                                                    constraint[slot])
                if d in mat_ents and len(mat_ents[d]) == 0:
                    # no DB match: echo the user's own constraints back
                    for s in constraint:
                        if (s == 'pricerange' and d in ('hotel', 'restaurant')
                                and 'price]' in restored):
                            restored = restored.replace(
                                '[value_price]', constraint['pricerange'])
                        if s + ']' in restored:
                            restored = restored.replace(
                                '[value_%s]' % s, constraint[s])

            if '[value_choice' in restored and mat_ents.get(d):
                restored = restored.replace('[value_choice]',
                                            str(len(mat_ents[d])))
        if '[value_choice' in restored:
            restored = restored.replace('[value_choice]',
                                        str(random.choice([1, 2, 3, 4, 5])))

        ent = mat_ents.get(last_domain, []) if domain else []
        if ent:
            # handle multiple [value_xxx] tokens first: when a placeholder
            # occurs several times (and no more often than there are
            # entities), successive occurrences take successive entities.
            snapshot = restored.split()
            token_count = Counter(snapshot)
            for t in snapshot:
                if '[value' in t and 1 < token_count[t] <= len(ent):
                    slot = t[7:-1]
                    placeholder = '[' + t[1:-1] + ']'
                    for e in ent:
                        rep = entity_value(e, slot)
                        if rep is not None:
                            # literal, first-occurrence-only replacement:
                            # DB values are data, not regex templates
                            restored = restored.replace(placeholder, rep, 1)

            # handle normal 1 entity case
            first = ent[0]
            for t in restored.split():
                if '[value' in t:
                    rep = entity_value(first, t[7:-1])
                    if rep is not None:
                        restored = restored.replace(t, rep)

        for t in restored.split():
            if '[value' in t:
                restored = restored.replace(t, 'UNKNOWN')

        words = restored.split()
        for idx in range(1, len(words)):
            if words[idx - 1] in ('.', '?', '!'):
                words[idx] = words[idx].capitalize()
        return ' '.join(words)

    # def restore(self, resp, domain, constraint_dict, mat_ents):
    #     restored = resp

    #     restored = restored.replace('[value_reference]', '53022')
    #     restored = restored.replace('[value_car]', 'BMW')

    #     # restored.replace('[value_phone]', '830-430-6666')
    #     for d in domain:
    #         constraint = constraint_dict.get(d,None)
    #         if constraint:
    #             if 'stay' in constraint:
    #                 restored = restored.replace('[value_stay]', constraint['stay'])
    #             if 'day' in constraint:
    #                 restored = restored.replace('[value_day]', constraint['day'])
    #             if 'people' in constraint:
    #                 restored = restored.replace('[value_people]', constraint['people'])
    #             if 'time' in constraint:
    #                 restored = restored.replace('[value_time]', constraint['time'])
    #             if 'type' in constraint:
    #                 restored = restored.replace('[value_type]', constraint['type'])
    #             if d in mat_ents and len(mat_ents[d])==0:
    #                 for s in constraint:
    #                     if s == 'pricerange' and d in ['hotel', 'restaurant'] and 'price]' in restored:
    #                         restored = restored.replace('[value_price]', constraint['pricerange'])
    #                     if s+']' in restored:
    #                         restored = restored.replace('[value_%s]'%s, constraint[s])

    #         if '[value_choice' in restored and mat_ents.get(d):
    #             restored = restored.replace('[value_choice]', str(len(mat_ents[d])))
    #     if '[value_choice' in restored:
    #         restored = restored.replace('[value_choice]', '3')

    #     # restored.replace('[value_car]', 'BMW')

    #     try:
    #         ent = mat_ents.get(domain[-1], [])
    #         if ent:
    #             ent = ent[0]

    #             for t in restored.split():
    #                 if '[value' in t:
    #                     slot = t[7:-1]
    #                     if ent.get(slot):
    #                         if domain[-1] == 'hotel' and slot == 'price':
    #                             slot = 'pricerange'
    #                         restored = restored.replace(t, ent[slot])
    #                     elif slot == 'price':
    #                         if ent.get('pricerange'):
    #                             restored = restored.replace(t, ent['pricerange'])
    #                         else:
    #                             print(restored, domain)
    #     except:
    #         print(resp)
    #         print(restored)
    #         quit()

    #     restored = restored.replace('[value_phone]', '62781111')
    #     restored = restored.replace('[value_postcode]', 'CG9566')
    #     restored = restored.replace('[value_address]', 'Parkside, Cambridge')

    #     return restored

    def record_utterance(self, result_dict):
        """Write a CSV of restored responses, most diverse dialogues first.

        Each dialogue gets a diversity score: per turn, three points for each
        distinct ``domain-act`` pair found across the ``cfg.nbest`` generated
        act spans plus one point for each distinct slot of such a pair,
        averaged over the dialogue's turns. Dialogues are then written to
        ``cfg.eval_load_path + '/dialogue_record.csv'`` in descending score
        order: one row with the dialogue id, then per turn a row with the
        reference response (restored and delexicalized) followed by one row
        per n-best candidate, shortest restored response first.

        Args:
            result_dict: mapping of dialogue id to its list of turn dicts;
                each turn provides ``user``, ``bspn``, ``aspn``, ``resp``,
                ``dspn``, ``pointer``, ``multi_act`` and ``multi_resp``.

        The same per-turn information is also collected in the local
        ``dialog_record`` dict.
        """
        decode_fn = self.vocab.sentence_decode

        ordered_dial = {}
        for dial_id, turns in result_dict.items():
            diverse = 0
            turn_count = 0
            for turn in turns:
                act_collect = {}
                slot_score = 0
                for i in range(cfg.nbest):
                    aspn = decode_fn(turn['multi_act'][i],
                                     eos=ontology.eos_tokens['aspn'])
                    # aspan_to_act_list accepts a string or a token list, so
                    # the decoder output is passed as is (joining a string
                    # with ' ' would split it into single characters).
                    for act in self.aspan_to_act_list(aspn):
                        d, a, s = act.split('-')
                        # count every slot once per domain-act pair
                        slots = act_collect.setdefault(d + '-' + a, {})
                        if s not in slots:
                            slots[s] = 1
                            slot_score += 1
                turn_count += 1
                diverse += len(act_collect) * 3 + slot_score
            # a dialogue without turns scores 0 instead of dividing by zero
            ordered_dial[dial_id] = diverse / turn_count if turn_count else 0

        ordered_dial = sorted(ordered_dial.keys(),
                              key=lambda x: -ordered_dial[x])

        dialog_record = {}

        # newline='' lets the csv writer control line endings itself
        with open(cfg.eval_load_path + '/dialogue_record.csv', 'w',
                  newline='') as rf:
            writer = csv.writer(rf)

            for dial_id in ordered_dial:
                dialog_record[dial_id] = []
                turns = result_dict[dial_id]
                writer.writerow([dial_id])
                for turn_no, turn in enumerate(turns):
                    user = decode_fn(turn['user'],
                                     eos=ontology.eos_tokens['user'])
                    bspn = decode_fn(turn['bspn'],
                                     eos=ontology.eos_tokens['bspn'])
                    aspn = decode_fn(turn['aspn'],
                                     eos=ontology.eos_tokens['aspn'])
                    resp = decode_fn(turn['resp'],
                                     eos=ontology.eos_tokens['resp'])
                    constraint_dict = self.bspan_to_constraint_dict(bspn)
                    domain = [
                        i[1:-1]
                        for i in self.dspan_to_domain(turn['dspn']).keys()
                    ]
                    # restore() looks up the matching entities itself and
                    # takes exactly these three arguments
                    restored = self.restore(resp, domain, constraint_dict)
                    writer.writerow([
                        turn_no, user, turn['pointer'], domain, restored, resp
                    ])
                    turn_record = {
                        'user': user,
                        'bspn': bspn,
                        'aspn': aspn,
                        'dom': domain,
                        'resp': resp,
                        'resp_res': restored
                    }

                    # (restored response, raw response, act span) per candidate
                    zipped = []
                    for i in range(cfg.nbest):
                        cand_aspn = decode_fn(turn['multi_act'][i],
                                              eos=ontology.eos_tokens['aspn'])
                        cand_resp = decode_fn(turn['multi_resp'][i],
                                              eos=ontology.eos_tokens['resp'])
                        cand_restored = self.restore(cand_resp, domain,
                                                     constraint_dict)
                        zipped.append((cand_restored, cand_resp, cand_aspn))

                    # shortest restored response first
                    zipped.sort(key=lambda s: len(s[0]))
                    turn_record['aspn_col'] = [z[2] for z in zipped]
                    turn_record['resp_col'] = [z[1] for z in zipped]
                    turn_record['resp_res_col'] = [z[0] for z in zipped]
                    for resp_restore, cand_resp, cand_aspn in zipped:
                        writer.writerow(
                            ['', resp_restore, cand_resp, cand_aspn])

                    dialog_record[dial_id].append(turn_record)