Exemplo n.º 1
0
def load_data(except_domain, keep_pct=0, filename='data/train_dials.json'):
    """Load dialogues, separating those that touch ``except_domain``.

    For every dialogue in ``filename``:

    * turns are sorted (in place) by ``int(turn_idx)``;
    * ``domains`` is completed with every domain that appears as the prefix
      of a non-"none" slot in a turn's belief state (as returned by
      ``fix_general_label_error``), because the annotation sometimes omits
      them, and is then stored sorted;
    * the dialogue is dropped if any of its domains is not in
      ``EXPERIMENT_DOMAINS``.

    Args:
        except_domain: domain to hold out.
        keep_pct: fraction of the held-out (``except_domain``) dialogues to
            keep, chosen after a ``random.shuffle``; ``0`` keeps none.
        filename: path of the JSON dialogue file. The default is the path
            this function always used.

    Returns:
        The dialogues that do not involve ``except_domain``. If
        ``keep_pct > 0``, the kept share of in-domain dialogues follows them.
    """
    # JSON text is UTF-8 by specification; don't depend on the locale.
    with open(filename, encoding='utf-8') as fp:
        data = json.load(fp)

    experiment_domains = set(EXPERIMENT_DOMAINS)
    out_of_domain_data = []
    in_domain_data = []

    for dialogue in data:
        dialogue['dialogue'].sort(key=lambda x: int(x['turn_idx']))

        # add sometimes missing domains to annotation
        all_domains = set(dialogue['domains'])
        for turn in dialogue['dialogue']:
            turn_belief_dict = fix_general_label_error(
                turn["belief_state"], False, ALL_SLOTS)
            for slot_key, slot_value in turn_belief_dict.items():
                if slot_value == 'none':
                    continue
                domain, _slot_name = slot_key.split('-', maxsplit=1)
                all_domains.add(domain)
        dialogue['domains'] = sorted(all_domains)

        # Keep only dialogues whose domains are all experiment domains.
        if not all_domains <= experiment_domains:
            continue

        if except_domain in all_domains:
            in_domain_data.append(dialogue)
        else:
            out_of_domain_data.append(dialogue)

    if keep_pct > 0:
        random.shuffle(in_domain_data)
        to_keep = int(keep_pct * len(in_domain_data))
        return out_of_domain_data + in_domain_data[:to_keep]
    return out_of_domain_data
def main():
    """Copy dialogues from stdin to stdout, dropping those that touch a domain.

    Usage: ``<script> DOMAIN < dialogues.json > filtered.json``.

    A dialogue is dropped when ``DOMAIN`` (``sys.argv[1]``) is in its
    ``domains`` list or is the domain prefix of any non-"none" slot in a
    turn's belief state (as returned by ``fix_general_label_error``). The
    remaining dialogues are written to stdout as indented JSON. The input and
    output dialogue counts are printed to stderr.
    """
    if len(sys.argv) < 2:
        sys.exit('usage: {} DOMAIN < dialogues.json'.format(sys.argv[0]))
    excluded_domain = sys.argv[1]

    data = json.load(sys.stdin)

    new_data = []
    for dialogue in data:
        all_domains = set(dialogue['domains'])
        # The 'domains' annotation sometimes omits domains that do appear in
        # the belief state, so recover them from the slot names.
        for turn in dialogue['dialogue']:
            turn_belief_dict = fix_general_label_error(turn["belief_state"],
                                                       False, ALL_SLOTS)
            # .items() is required: iterating the dict directly yields bare
            # key strings, which cannot be unpacked into (key, value).
            for slot_key, slot_value in turn_belief_dict.items():
                if slot_value == 'none':
                    continue
                domain, _slot_name = slot_key.split('-', maxsplit=1)
                all_domains.add(domain)

        if excluded_domain in all_domains:
            continue
        new_data.append(dialogue)

    print(len(data), len(new_data), file=sys.stderr)
    json.dump(new_data, sys.stdout, indent=2)
Exemplo n.º 3
0
def compute_templates(data):
    """Count (delexicalized history, belief string) templates over dialogues.

    For each dialogue, turns are sorted in place by ``int(turn_idx)``. A
    history is then accumulated turn by turn: the system transcript (only
    when ``turn_idx > 0``) followed by the user transcript. Accumulation stops
    after the first turn whose belief state, as returned by
    ``fix_general_label_error``, is non-empty.

    Two forms of the history are kept:

    * the delexicalized form (system text via ``replace_with_system_act``,
      user text via ``replace_with_slots``), joined with ``' <sep> '``, which
      together with ``belief_state_to_string`` of that last belief state
      forms the counted key;
    * the raw form, whose distinct count is printed to stderr.

    Returns:
        ``Counter`` mapping ``(history, belief_string)`` to its frequency.
    """
    templates = Counter()
    distinct = set()

    for dialogue in data:
        # Convert turn_idx with int(), as load_data and transfer_data do, so
        # string-typed indices sort numerically ("2" before "10") and the
        # ``turn_idx > 0`` test below cannot raise TypeError. Integer indices
        # behave exactly as before.
        dialogue['dialogue'].sort(key=lambda x: int(x['turn_idx']))

        raw_history = []
        history = []
        label_dict = dict()

        for turn in dialogue['dialogue']:
            turn_idx = int(turn['turn_idx'])

            label_dict = fix_general_label_error(turn['belief_state'], False,
                                                 ALL_SLOTS)

            # System transcripts are included only from turn_idx 1 onward.
            if turn_idx > 0:
                raw_history.append(turn['system_transcript'])
                history.append(replace_with_system_act(
                    turn['system_transcript'], turn['system_acts']))

            raw_history.append(turn['transcript'])
            history.append(replace_with_slots(turn['transcript'], label_dict))

            # Stop at the first turn that has a non-empty belief state.
            if label_dict:
                break

        raw_history = ' <sep> '.join(raw_history)
        history = ' <sep> '.join(history)
        templates[(history, belief_state_to_string(label_dict))] += 1
        distinct.add(raw_history)

    print(len(distinct), file=sys.stderr)
    return templates
Exemplo n.º 4
0
def get_full_belief(turn):
    """Return a turn's belief state as ``"slot-value"`` strings and slot names.

    The turn's ``belief_state`` is first passed through
    ``fix_general_label_error`` (with ``False`` and ``ALL_SLOTS``).

    Returns:
        A 2-tuple ``(belief_strings, slot_names)``. ``belief_strings`` holds
        ``"<slot>-<value>"`` for each slot in the resulting dict, and
        ``slot_names`` lists the same slots in the same order.
    """
    slot_values = fix_general_label_error(turn['belief_state'], False,
                                          ALL_SLOTS)
    belief_strings = []
    for slot in slot_values:
        belief_strings.append('-'.join((slot, slot_values[slot])))
    slot_names = list(slot_values)
    return belief_strings, slot_names
Exemplo n.º 5
0
def read_langs(file_name, gating_dict, SLOTS, dataset, lang, mem_lang, sequicity, training, max_line=None):
    """Read a dialogue JSON file into per-turn examples and extend vocabularies.

    Args:
        file_name: path of a JSON file holding a list of dialogue dicts.
        gating_dict: maps "dontcare", "none" and "ptr" to gate labels.
        SLOTS: ordered list of all slot names.
        dataset: split name. "train" and "dev" hold ``args.except_domain``
            out; any other split (e.g. "test") keeps only that domain's slots.
        lang: vocabulary that receives utterance words via ``index_words``.
        mem_lang: vocabulary that receives belief words and ``t{i}`` tokens.
        sequicity: unused.
        training: if true, vocabularies may be extended and the train split
            may be subsampled by ``args.data_ratio``.
        max_line: optional limit on the number of dialogues read (see the
            NOTE at the loop end).

    Also reads the globals ``args`` (``all_vocab``, ``data_ratio``,
    ``only_domain``, ``except_domain``), ``EXPERIMENT_DOMAINS`` and
    ``fix_general_label_error``.

    Returns:
        ``(data, max_resp_len, slot_temp)``:

        * ``data``: list of per-turn example dicts;
        * ``max_resp_len``: the longest dialogue history, in whitespace
          tokens;
        * ``slot_temp``: the slot list used for this split.
    """
    print(("Reading from {}".format(file_name)))
    data = []
    max_resp_len, max_value_len = 0, 0
    domain_counter = {}

    # The slot subset depends only on the split and the domain flags, never
    # on the individual turn, so decide it once here. This also keeps
    # ``slot_temp`` bound for the return statement when no turn survives the
    # filters; before, that case raised UnboundLocalError.
    if args.except_domain != "":
        # train/dev drop the held-out domain's slots; other splits keep only them.
        filter_domain = args.except_domain
        keep_matching = dataset not in ("train", "dev")
    elif args.only_domain != "":
        filter_domain = args.only_domain
        keep_matching = True
    else:
        filter_domain = None
        keep_matching = True

    def _keep_slot(slot_name):
        """Return True if ``slot_name`` passes the split's domain filter."""
        return (filter_domain in slot_name) == keep_matching

    if filter_domain is None:
        slot_temp = SLOTS
    else:
        slot_temp = [k for k in SLOTS if _keep_slot(k)]

    update_vocab = (args.all_vocab or dataset == "train") and training

    with open(file_name) as f:
        dials = json.load(f)

    # create vocab first
    if update_vocab:
        for dial_dict in dials:
            for turn in dial_dict["dialogue"]:
                lang.index_words(turn["system_transcript"], 'utter')
                lang.index_words(turn["transcript"], 'utter')

    # determine training data ratio, default is 100%
    if training and dataset == "train" and args.data_ratio != 100:
        random.Random(10).shuffle(dials)
        dials = dials[:int(len(dials) * 0.01 * args.data_ratio)]

    cnt_lin = 1
    for dial_dict in dials:
        dialog_history = ""
        # Count experiment domains. This runs before the filters below, so
        # dialogues skipped by them are counted too.
        for domain in dial_dict["domains"]:
            if domain in EXPERIMENT_DOMAINS:
                domain_counter[domain] = domain_counter.get(domain, 0) + 1

        # Unseen domain setting
        if args.only_domain != "" and args.only_domain not in dial_dict["domains"]:
            continue
        if args.except_domain != "":
            if (dataset == "test" and args.except_domain not in dial_dict["domains"]) or \
                    (dataset != "test" and [args.except_domain] == dial_dict["domains"]):
                continue

        # Reading data
        for turn in dial_dict["dialogue"]:
            turn_domain = turn["domain"]
            turn_id = turn["turn_idx"]
            turn_uttr_strip = (turn["system_transcript"] + " ; " + turn["transcript"]).strip()
            dialog_history += (turn["system_transcript"] + " ; " + turn["transcript"] + " ; ")
            source_text = dialog_history.strip()
            turn_belief_dict = fix_general_label_error(turn["belief_state"], False, SLOTS)
            if filter_domain is not None:
                turn_belief_dict = OrderedDict(
                    (k, v) for k, v in turn_belief_dict.items() if _keep_slot(k))

            turn_belief_list = [str(k) + '-' + str(v) for k, v in turn_belief_dict.items()]

            if update_vocab:
                mem_lang.index_words(turn_belief_dict, 'belief')

            # One target value and one gate label per slot of this split.
            generate_y, gating_label = [], []
            for slot in slot_temp:
                if slot in turn_belief_dict:
                    value = turn_belief_dict[slot]
                    generate_y.append(value)
                    if value == "dontcare":
                        gating_label.append(gating_dict["dontcare"])
                    elif value == "none":
                        gating_label.append(gating_dict["none"])
                    else:
                        gating_label.append(gating_dict["ptr"])
                    max_value_len = max(max_value_len, len(value))
                else:
                    generate_y.append("none")
                    gating_label.append(gating_dict["none"])

            data_detail = {
                "ID": dial_dict["dialogue_idx"],
                "domains": dial_dict["domains"],
                "turn_domain": turn_domain,
                "turn_id": turn_id,
                "dialog_history": source_text,
                "turn_belief": turn_belief_list,
                "gating_label": gating_label,
                "turn_uttr": turn_uttr_strip,
                'generate_y': generate_y
            }
            data.append(data_detail)

            max_resp_len = max(max_resp_len, len(source_text.split()))

        # NOTE(review): cnt_lin starts at 1 and counts only dialogues that pass
        # the filters, so at most max_line - 1 dialogues are read. Confirm
        # that this is intended.
        cnt_lin += 1
        if max_line and cnt_lin >= max_line:
            break

    # add t{} to the lang file
    if "t{}".format(max_value_len - 1) not in mem_lang.word2index.keys() and training:
        for time_i in range(max_value_len):
            mem_lang.index_words("t{}".format(time_i), 'utter')

    print("domain_counter", domain_counter)
    return data, max_resp_len, slot_temp
Exemplo n.º 6
0
def get_data(file_data):
    """Build per-turn examples with per-slot operation gates from a dialogue file.

    Only dialogues whose domains all belong to ``EXPERIMENT_DOMAINS`` are
    used. Each turn becomes one example. Its ``generate_y`` holds one value
    per slot in ``ALL_SLOTS`` order (``NULL_token`` for absent slots). Its
    ``gating_label`` compares ``generate_y`` with the previous turn's value
    for each slot:

    * "carryover" if the value is unchanged;
    * "confirm" if the new value is a key of ``sub_gating_dict``;
    * "update" otherwise.

    Args:
        file_data: path of a JSON file holding a list of dialogue dicts.

    Returns:
        ``(data, (max_input_len, max_input), (max_value_len, max_value))``:

        * ``data``: the list of per-turn example dicts;
        * ``max_input``: the longest tokenized ``input_seq`` seen;
        * ``max_value``: the longest tokenized slot value seen.
    """
    print(("Reading from {}".format(file_data)))
    # JSON text is UTF-8 by specification; don't depend on the locale.
    with open(file_data, encoding='utf-8') as f:
        dials = json.load(f)

    experiment_domains = set(EXPERIMENT_DOMAINS)
    data = []  # one example per turn over all kept dialogues
    max_value_len, max_input_len, max_value, max_input = 0, 0, '', ''
    for dial_dict in tqdm(dials, total=len(dials)):
        # Keep the dialogue only if every one of its domains is an experiment
        # domain. The subset test must be non-strict (<=): a strict < would
        # silently drop dialogues covering all experiment domains, unlike the
        # equivalent filter in load_data.
        if not set(dial_dict["domains"]) <= experiment_domains:
            continue

        # State before the first turn (B_0): every slot holds NULL, so an
        # unchanged slot at turn 1 is labelled "carryover".
        previous_utterances = A_token + " " + U_token
        previous_generate_y = [NULL_token] * len(ALL_SLOTS)
        previous_belief = _convert_bs({}).strip()
        previous_gating_label = [gating_dict["carryover"]] * len(ALL_SLOTS)

        for turn in dial_dict["dialogue"]:
            # 1. Basic turn information.
            turn_domain = turn["domain"]
            turn_id = turn["turn_idx"]

            turn_belief_dict = fix_general_label_error(
                turn["belief_state"], False, ALL_SLOTS)

            # Target belief as "slot-value" strings, skipping "none" values.
            if turn_belief_dict:
                turn_belief_list = [
                    str(k) + '-' + fix(v) for k, v in turn_belief_dict.items()
                    if v != "none"
                ]
                turn_belief_dict = {
                    k: fix(v)
                    for k, v in turn_belief_dict.items()
                }
            else:
                turn_belief_list = []

            # 2. Labels.
            # 2.1 generate_y: one value per slot, in ALL_SLOTS order.
            if turn_belief_dict:
                generate_y = []
                for slot in ALL_SLOTS:
                    if slot in turn_belief_dict:
                        value = fix(turn_belief_dict[slot])
                        value_seq = tokenizer.tokenize(turn_belief_dict[slot])
                        if max_value_len < len(value_seq):
                            max_value_len = len(value_seq)
                            max_value = value_seq
                    else:
                        value = NULL_token
                    generate_y.append(value)
            else:
                generate_y = [NULL_token] * len(ALL_SLOTS)

            # 2.2 gating_label: per-slot operation from the previous turn's
            # value to this turn's value.
            gating_label = []
            for p_v, n_v in zip(previous_generate_y, generate_y):
                if n_v == p_v:
                    operation = gating_dict["carryover"]
                elif n_v in sub_gating_dict:  # e.g. yes / no / dontcare
                    operation = gating_dict["confirm"]
                else:
                    operation = gating_dict["update"]
                gating_label.append(operation)

            # 2.3 Input sequence:
            #   story     = D_{t-1} [SEP] D_t [SEP]
            #   input_seq = [CLS] + tokenize(story + " " + B_{t-1}) + [EOS]
            # NOTE(review): the original author's notes describe the sequence
            # as ending with SEP, but the code appends EOS_token. Confirm.
            current_utterances = (A_token + " " + turn["system_transcript"] +
                                  " " + U_token + " " +
                                  turn["transcript"]).strip()
            story = (previous_utterances + " " + SEP_token + " " +
                     current_utterances + " " + SEP_token)

            current_belief = _convert_bs(turn_belief_dict).strip()

            input_seq = ([CLS_token] +
                         tokenizer.tokenize(story + " " + previous_belief) +
                         [EOS_token])

            if max_input_len < len(input_seq):
                max_input_len = len(input_seq)
                max_input = input_seq

            # 2.4 domain_focus, computed by get_domain_focus. Per the original
            # author's note: 0 = domain untouched, 1 = at least one of its
            # slots is not "carryover" (TODO confirm against get_domain_focus).
            domain_focus = get_domain_focus(generate_y, gating_label)

            if use2turn:
                # OR in the previous turn's focus as well.
                pre_domain_focus = get_domain_focus(previous_generate_y,
                                                    previous_gating_label)
                domain_focus = [
                    cur | pre
                    for cur, pre in zip(domain_focus, pre_domain_focus)
                ]

            # 3. Package the example.
            data_detail = {
                "ID": dial_dict["dialogue_idx"],
                "domains": dial_dict["domains"],  # all domains of this dialogue
                "turn_domain": turn_domain,
                "turn_id": turn_id,
                "turn_belief": turn_belief_list,
                'previous_generate_y': previous_generate_y,
                'generate_y': generate_y,
                "gating_label": gating_label,
                "previous_utterances": previous_utterances,
                "current_utterances": current_utterances,
                'domain_focus': domain_focus,
                'input_seq': input_seq,  # D_{t-1} + D_t + B_{t-1}, tokenized
            }
            data.append(data_detail)

            previous_utterances = current_utterances
            previous_generate_y = generate_y
            previous_belief = current_belief
            previous_gating_label = gating_label

    return data, (max_input_len, max_input), (max_value_len, max_value)
Exemplo n.º 7
0
def transfer_data(original_data, from_domain, to_domain):
    """Augment ``original_data`` with dialogues transferred between domains.

    Every input dialogue is kept, in order. For each dialogue whose
    ``domains`` contain ``from_domain``, a deep copy is made in which:

    * phrases from ``TRANSFER_PHRASES[from_domain]`` in the system (turn_idx
      > 0) and user transcripts are replaced via a ``ReplaceBag`` by phrases
      chosen with ``random.choice`` from ``TRANSFER_PHRASES[to_domain]``. The
      old text is saved under ``original_system_transcript`` and
      ``original_transcript``;
    * ``from_domain`` belief slots are renamed to ``to_domain`` slots, and the
      result is written to ``belief_state`` and ``turn_label``;
    * turn ``domain`` values equal to ``from_domain`` become ``to_domain``.

    The copy is appended right after its original only if it is "good":

    * no turn has a ``to_domain`` slot;
    * every renamed slot is in ``ALL_SLOTS``;
    * any turn with a transferred slot also had ``transfer_replace_bag.used
      > 0``.

    A good copy also gets ``dialogue_idx`` suffixed with
    ``/<from_domain>-><to_domain>``, ``from_domain`` replaced by ``to_domain``
    in ``domains``, and its values re-sampled by
    ``Augmenter(only_domain=to_domain)``.

    Raises:
        KeyError: if either domain is missing from ``TRANSFER_PHRASES``.
        AssertionError: if either domain has no transfer phrases (only when
            assertions are enabled).
    """
    new_data = []

    transfer_phrases_from = TRANSFER_PHRASES[from_domain]
    transfer_phrases_to = TRANSFER_PHRASES[to_domain]
    assert len(transfer_phrases_from) > 0
    assert len(transfer_phrases_to) > 0

    for dialogue in original_data:
        # The original dialogue is always kept.
        new_data.append(dialogue)
        if not from_domain in dialogue['domains']:
            continue

        new_dialogue = copy.deepcopy(dialogue)

        # One random target phrase per source phrase, fixed for this dialogue.
        transfer_replace_bag = ReplaceBag()
        for phrase in transfer_phrases_from:
            transfer_replace_bag.add(
                phrase.split(' '),
                random.choice(transfer_phrases_to).split(' '))

        good_dialogue = True
        for turn in new_dialogue['dialogue']:
            turn_idx = int(turn['turn_idx'])

            # System transcripts are only rewritten from turn_idx 1 onward.
            if turn_idx > 0:
                turn['original_system_transcript'] = turn['system_transcript']
                turn['system_transcript'] = ' '.join(
                    apply_replacement(turn['system_transcript'],
                                      transfer_replace_bag))
            turn['original_transcript'] = turn['transcript']
            turn['transcript'] = ' '.join(
                apply_replacement(turn['transcript'], transfer_replace_bag))
            # NOTE(review): the bag is shared by all turns of this dialogue and
            # is not reset here. If ``used`` accumulates, this means "some
            # replacement so far in the dialogue", not "in this turn". Confirm.
            found_transfer_phrase = transfer_replace_bag.used > 0

            label_dict = fix_general_label_error(turn['belief_state'], False,
                                                 ALL_SLOTS)
            label_dict = remove_none_slots(label_dict)

            found_transfer_slot = False
            found_bad_slot = False
            new_label_dict = dict()
            for slot_key, slot_value in label_dict.items():
                domain, slot_name = slot_key.split('-', maxsplit=1)

                # we have removed all "to_domain" data so if we see this it's a mislabel, ignore as bad
                # NOTE(review): this break only leaves the slot loop. The
                # remaining turns are still rewritten, although the copy will
                # be discarded because good_dialogue is now False.
                if domain == to_domain:
                    good_dialogue = False
                    break
                # Slots of other domains are copied unchanged.
                if domain != from_domain:
                    new_label_dict[slot_key] = slot_value
                    continue

                # Rename the slot into the target domain; reject the copy if
                # the target domain has no such slot.
                found_transfer_slot = True
                new_slot_key = to_domain + '-' + slot_name
                if new_slot_key not in ALL_SLOTS:
                    found_bad_slot = True
                    break
                new_label_dict[new_slot_key] = slot_value

            # Both belief_state and turn_label are set from the full renamed
            # belief (turn_label as (slot, value) pairs).
            turn['belief_state'] = belief_to_json(new_label_dict)
            turn['turn_label'] = [
                (slot_key, slot_value)
                for slot_key, slot_value in new_label_dict.items()
            ]
            turn['domain'] = to_domain if turn[
                'domain'] == from_domain else turn['domain']

            # A transferred slot without any phrase replacement means the text
            # was not adapted to the new domain, so reject the copy.
            if found_bad_slot or (found_transfer_slot
                                  and not found_transfer_phrase):
                good_dialogue = False
                break

        if good_dialogue:
            new_dialogue['dialogue_idx'] = new_dialogue[
                'dialogue_idx'] + '/' + from_domain + '->' + to_domain
            new_dialogue['domains'] = [
                x for x in new_dialogue['domains'] if x != from_domain
            ] + [to_domain]

            # replace the values in the new dialogue with values that make sense for the domain
            augmenter = Augmenter(only_domain=to_domain)
            augmenter.augment(new_dialogue['dialogue'])

            new_data.append(new_dialogue)

            #print(json.dumps(dialogue['dialogue'], indent=2))
            #sys.exit(0)

    return new_data
Exemplo n.º 8
0
    dials_v2 = []
    for dial_dict in dials:
        dial_domains = dial_dict["domains"]
        prev_turn_state = {}
        for slot in slot_meta:
            prev_turn_state[slot] = "none"

        for ti, turn in enumerate(dial_dict["dialogue"]):
            dial_dict["dialogue"][ti]["system_transcript"] = normalize_text(
                turn["system_transcript"])
            dial_dict["dialogue"][ti]["transcript"] = normalize_text(
                turn["transcript"])

            # state
            turn_dialog_state = fix_general_label_error(
                turn["belief_state"], False, slot_meta)
            for slot in slot_meta:
                if slot not in turn_dialog_state or slot.split(
                        '-')[0] not in dial_domains:
                    turn_dialog_state[slot] = "none"
                else:
                    turn_dialog_state[slot] = normalize_label(
                        slot, turn_dialog_state[slot])

                if turn_dialog_state[slot] == "dontcare":
                    turn_dialog_state[slot] = "do not care"

                ontology_modified[slot].append(turn_dialog_state[slot])

            dial_dict["dialogue"][ti]["belief_state"] = []