示例#1
0
def test_generate_overlap(total_num=1000, seed=42, output_file='goal.json'):
    """Print how many test goals and generated goals overlap with train goals.

    Goals are compared through the value returned by
    ``extract_slot_combination_from_goal``. Two lines are printed, each of the
    form ``<#train> <#compared> <#overlap>``: first for the MultiWOZ test
    split, then for ``total_num`` goals sampled from ``GoalGenerator``.

    Args:
        total_num: number of goals to sample from the generator.
        seed: seed applied to both ``random`` and ``numpy.random``.
        output_file: not used by this function body; kept so existing callers
            that pass it keep working.
    """
    train_data = read_zipped_json('../../../data/multiwoz/train.json.zip',
                                  'train.json')
    train_serialized_goals = [
        extract_slot_combination_from_goal(train_data[d]['goal'])
        for d in train_data
    ]

    test_data = read_zipped_json('../../../data/multiwoz/test.json.zip',
                                 'test.json')
    test_serialized_goals = [
        extract_slot_combination_from_goal(test_data[d]['goal'])
        for d in test_data
    ]

    overlap = sum(1 for serialized_goal in test_serialized_goals
                  if serialized_goal in train_serialized_goals)
    print(len(train_serialized_goals), len(test_serialized_goals),
          overlap)  # 8434 1000 430

    random.seed(seed)
    np.random.seed(seed)
    goal_generator = GoalGenerator()
    goals = []
    avg_domains = []
    serialized_goals = []
    while len(goals) < total_num:
        goal = goal_generator.get_user_goal()
        # The police domain is stripped from every sampled goal.
        if 'police' in goal['domain_ordering']:
            no_police = list(goal['domain_ordering'])
            no_police.remove('police')
            goal['domain_ordering'] = tuple(no_police)
            del goal['police']
        try:
            message = goal_generator.build_message(goal)[1]
        except Exception:
            # Goals whose description cannot be built are skipped and
            # resampled. ``Exception`` (rather than a bare ``except``) keeps
            # Ctrl-C / SystemExit able to stop this otherwise unbounded loop.
            continue
        avg_domains.append(len(goal['domain_ordering']))
        goals.append({
            "goals": [],
            "ori_goals": goal,
            "description": message,
            "timestamp": str(datetime.datetime.now()),
            "ID": len(goals)
        })
        serialized_goals.append(extract_slot_combination_from_goal(goal))
        if len(serialized_goals) == 1:
            # Show the first serialized goal as a sample of the format.
            print(serialized_goals)

    overlap = sum(1 for serialized_goal in serialized_goals
                  if serialized_goal in train_serialized_goals)
    print(len(train_serialized_goals), len(serialized_goals),
          overlap)  # 8434 1000 199
示例#2
0
    def org_data_loader(self):
        """Return the cached (goals, user acts, system acts) sequences.

        On first use the MultiWOZ ``train.json`` is read from
        ``data/multiwoz/train.json.zip``, located five directory levels above
        this file. Each session's log is consumed as (user, system) turn
        pairs; a session is dropped entirely as soon as one of its turns
        yields an empty converted dialog act. The surviving sessions are
        converted with the ``UserDataManager`` ``*2seq`` helpers and cached on
        the instance.

        Returns:
            tuple: ``(org_goals, org_usr_dass, org_sys_dass)``.
        """
        cached = (self.__org_goals is not None
                  and self.__org_usr_dass is not None
                  and self.__org_sys_dass is not None)
        if cached:
            return self.__org_goals, self.__org_usr_dass, self.__org_sys_dass

        # Climb five levels up from this file to reach the directory that
        # holds ``data/``.
        base_dir = os.path.abspath(__file__)
        for _ in range(5):
            base_dir = os.path.dirname(base_dir)
        zip_path = os.path.join(base_dir, 'data/multiwoz/train.json.zip')
        full_data = read_zipped_json(zip_path, 'train.json')

        goals, usr_dass, sys_dass = [], [], []
        for session in full_data.values():
            session_goal = session.get('goal', {})
            logs = session.get('log', [])
            session_usr, session_sys = [], []
            keep_session = True
            # Even log entries are user turns, odd ones system turns; a
            # trailing unpaired entry is ignored.
            for usr_log, sys_log in zip(logs[0::2], logs[1::2]):
                # Raw dialog_act example:
                # {'Hotel-Inform': [['Price', 'cheap'], ['Type', 'hotel']]}
                usr_act = self.ref_data2stand(usr_log.get('dialog_act', {}))
                sys_act = self.ref_data2stand(sys_log.get('dialog_act', {}))
                session_usr.append(usr_act)
                session_sys.append(sys_act)
                if len(usr_act) <= 0 or len(sys_act) <= 0:
                    keep_session = False
                    break
            if keep_session:
                goals.append(session_goal)
                usr_dass.append(session_usr)
                sys_dass.append(session_sys)

        self.__org_goals = [UserDataManager.usrgoal2seq(g) for g in goals]
        self.__org_usr_dass = [
            [UserDataManager.usrda2seq(act, g) for act in acts]
            for (acts, g) in zip(usr_dass, goals)
        ]
        self.__org_sys_dass = [
            [UserDataManager.sysda2seq(act, g) for act in acts]
            for (acts, g) in zip(sys_dass, goals)
        ]

        return self.__org_goals, self.__org_usr_dass, self.__org_sys_dass
    def load_data(self,
                  data_dir=None,
                  data_key='all',
                  role='all',
                  utterance=False,
                  dialog_act=False,
                  context=False,
                  context_window_size=0,
                  context_dialog_act=False,
                  user_state=False,
                  sys_state=False,
                  sys_state_init=False,
                  last_opponent_utterance=False,
                  last_self_utterance=False,
                  session_id=False,
                  terminated=False,
                  goal=False,
                  final_goal=False,
                  task_description=False):
        """Load CrossWOZ splits into parallel per-turn lists on ``self.data``.

        Every enabled flag creates a list ``self.data[split][<flag name>]``
        that receives one entry per selected turn.

        Args:
            data_dir: directory holding ``<split>.json.zip``. Defaults to
                ``DATA_ROOT/crosswoz`` (``crosswoz_en`` when ``self.en``).
            data_key: ``'all'`` for train/val/test, otherwise a single split
                name.
            role: ``'usr'``, ``'sys'`` or ``'all'``. With ``'usr'``/``'sys'``
                the other side's turns only feed the running context.
            utterance: collect ``turn['content']``.
            dialog_act: collect the turn's acts as 4-element lists.
            context: collect the last ``context_window_size`` previous
                utterances (only when the window size is non-zero).
            context_window_size: window length for both context fields.
            context_dialog_act: like ``context`` but for dialog acts.
            user_state: collect ``turn['user_state']`` for user turns.
            sys_state: collect ``turn['sys_state']`` for system turns.
            sys_state_init: collect ``turn['sys_state_init']`` for system
                turns and ``{}`` for other selected turns (role ``'sys'`` or
                ``'all'`` only).
            last_opponent_utterance: collect the previous utterance or ``''``.
            last_self_utterance: collect the utterance before that or ``''``.
            session_id: collect the session key.
            terminated: collect ``i + 2 >= len(sess['messages'])``.
            goal: collect ``sess['goal']``.
            final_goal: collect ``sess['final_goal']``.
            task_description: collect ``sess['task description']``.

        Returns:
            dict: ``self.data``, keyed by split name plus ``'role'``.
        """
        if data_dir is None:
            data_dir = os.path.join(DATA_ROOT,
                                    'crosswoz' + ('_en' if self.en else ''))

        def da2tuples(dialog_act):
            # Keep exactly the first four fields of every act.
            return [[act[0], act[1], act[2], act[3]] for act in dialog_act]

        assert role in ['sys', 'usr', 'all']
        # Explicit name/flag table: same order and result as evaluating the
        # parameter names, without relying on eval() frame introspection.
        requested = [
            ('utterance', utterance),
            ('dialog_act', dialog_act),
            ('context', context),
            ('context_dialog_act', context_dialog_act),
            ('user_state', user_state),
            ('sys_state', sys_state),
            ('sys_state_init', sys_state_init),
            ('last_opponent_utterance', last_opponent_utterance),
            ('last_self_utterance', last_self_utterance),
            ('session_id', session_id),
            ('terminated', terminated),
            ('goal', goal),
            ('final_goal', final_goal),
            ('task_description', task_description),
        ]
        info_list = [name for name, enabled in requested if enabled]
        self.data = {
            'train': {},
            'val': {},
            'test': {},
            'role': role,
            'human_val': {}
        }
        if data_key == 'all':
            data_key_list = ['train', 'val', 'test']
        else:
            data_key_list = [data_key]
        for split in data_key_list:
            data = read_zipped_json(
                os.path.join(data_dir, '{}.json.zip'.format(split)),
                '{}.json'.format(split))
            print('loaded {}, size {}'.format(split, len(data)))
            for x in info_list:
                self.data[split][x] = []
            for sess_id, sess in data.items():
                cur_context = []
                cur_context_dialog_act = []
                for i, turn in enumerate(sess['messages']):
                    text = turn['content']
                    da = da2tuples(turn['dialog_act'])
                    # Turn of the side that was not requested: it only
                    # extends the context.
                    if {role, turn['role']} == {'usr', 'sys'}:
                        cur_context.append(text)
                        cur_context_dialog_act.append(da)
                        continue
                    if utterance:
                        self.data[split]['utterance'].append(text)
                    if dialog_act:
                        self.data[split]['dialog_act'].append(da)
                    if context and context_window_size:
                        self.data[split]['context'].append(
                            cur_context[-context_window_size:])
                    if context_dialog_act and context_window_size:
                        self.data[split]['context_dialog_act'].append(
                            cur_context_dialog_act[-context_window_size:])
                    if role in ['usr', 'all'
                                ] and user_state and turn['role'] == 'usr':
                        self.data[split]['user_state'].append(
                            turn['user_state'])
                    if role in ['sys', 'all'
                                ] and sys_state and turn['role'] == 'sys':
                        self.data[split]['sys_state'].append(
                            turn['sys_state'])
                    if role in ['sys', 'all'] and sys_state_init:
                        if turn['role'] == 'sys':
                            self.data[split]['sys_state_init'].append(
                                turn['sys_state_init'])
                        else:
                            self.data[split]['sys_state_init'].append({})
                    if last_opponent_utterance:
                        self.data[split]['last_opponent_utterance'].append(
                            cur_context[-1] if len(cur_context) >= 1 else '')
                    if last_self_utterance:
                        self.data[split]['last_self_utterance'].append(
                            cur_context[-2] if len(cur_context) >= 2 else '')
                    if session_id:
                        self.data[split]['session_id'].append(sess_id)
                    if terminated:
                        self.data[split]['terminated'].append(
                            i + 2 >= len(sess['messages']))
                    if goal:
                        self.data[split]['goal'].append(sess['goal'])
                    if final_goal:
                        self.data[split]['final_goal'].append(
                            sess['final_goal'])
                    if task_description:
                        self.data[split]['task_description'].append(
                            sess['task description'])
                    cur_context.append(text)
                    cur_context_dialog_act.append(da)

        return self.data
    def load_data(self,
                  data_dir=os.path.abspath(
                      os.path.join(os.path.abspath(__file__),
                                   '../../../../data/camrest')),
                  data_key='all',
                  role='all',
                  utterance=False,
                  dialog_act=False,
                  context=False,
                  context_window_size=0,
                  context_dialog_act=False,
                  last_opponent_utterance=False,
                  last_self_utterance=False,
                  session_id=False,
                  terminated=False,
                  goal=False):
        """Load CamRest splits into parallel per-utterance lists on ``self.data``.

        Every enabled flag creates a list ``self.data[split][<flag name>]``
        that receives one entry per selected utterance. Each entry of
        ``sess['dial']`` contributes a user utterance followed by a system
        utterance.

        Args:
            data_dir: directory holding ``<split>.json.zip``.
            data_key: ``'all'`` for train/val/test, otherwise a single split
                name.
            role: ``'usr'``, ``'sys'`` or ``'all'``. With ``'usr'``/``'sys'``
                the other side's utterances only feed the running context.
            utterance: collect the utterance text (``transcript`` for the
                user side, ``sent`` for the system side).
            dialog_act: collect ``[intent, slot, value]`` triples, slots
                sorted within each intent.
            context: collect the last ``context_window_size`` previous
                utterances (only when the window size is non-zero).
            context_window_size: window length for both context fields.
            context_dialog_act: like ``context`` but for dialog acts.
            last_opponent_utterance: collect the previous utterance or ``''``.
            last_self_utterance: collect the utterance before that or ``''``.
            session_id: collect ``sess['dialogue_id']``.
            terminated: collect ``turn['turn'] >= len(sess['dial'])``.
            goal: collect ``sess['goal']``.

        Returns:
            dict: ``self.data``, keyed by split name plus ``'role'``.
        """
        def da2tuples(dialog_act):
            tuples = []
            for intent, svs in dialog_act.items():
                for slot, value in sorted(svs, key=lambda x: x[0]):
                    tuples.append([intent, slot, value])
            return tuples

        assert role in ['sys', 'usr', 'all']
        # Explicit name/flag table: same order and result as evaluating the
        # parameter names, without relying on eval() frame introspection.
        requested = [
            ('utterance', utterance),
            ('dialog_act', dialog_act),
            ('context', context),
            ('context_dialog_act', context_dialog_act),
            ('last_opponent_utterance', last_opponent_utterance),
            ('last_self_utterance', last_self_utterance),
            ('session_id', session_id),
            ('terminated', terminated),
            ('goal', goal),
        ]
        info_list = [name for name, enabled in requested if enabled]
        self.data = {'train': {}, 'val': {}, 'test': {}, 'role': role}
        if data_key == 'all':
            data_key_list = ['train', 'val', 'test']
        else:
            data_key_list = [data_key]
        for split in data_key_list:
            data = read_zipped_json(
                os.path.join(data_dir, '{}.json.zip'.format(split)),
                '{}.json'.format(split))
            print('loaded {}, size {}'.format(split, len(data)))
            for x in info_list:
                self.data[split][x] = []
            for sess in data:
                cur_context = []
                cur_context_dialog_act = []
                for turn in sess['dial']:
                    turn_id = turn['turn']
                    for side_id in ['usr', 'sys']:
                        if side_id == 'usr':
                            text = turn[side_id]['transcript']
                        else:
                            text = turn[side_id]['sent']
                        da = da2tuples(turn[side_id]['dialog_act'])
                        # Utterance of the side that was not requested: it
                        # only extends the context.
                        if {role, side_id} == {'usr', 'sys'}:
                            cur_context.append(text)
                            cur_context_dialog_act.append(da)
                            continue
                        if utterance:
                            self.data[split]['utterance'].append(text)
                        if dialog_act:
                            self.data[split]['dialog_act'].append(da)
                        if context and context_window_size:
                            self.data[split]['context'].append(
                                cur_context[-context_window_size:])
                        if context_dialog_act and context_window_size:
                            self.data[split]['context_dialog_act'].append(
                                cur_context_dialog_act[-context_window_size:])
                        if last_opponent_utterance:
                            self.data[split][
                                'last_opponent_utterance'].append(
                                    cur_context[-1]
                                    if len(cur_context) >= 1 else '')
                        if last_self_utterance:
                            self.data[split]['last_self_utterance'].append(
                                cur_context[-2]
                                if len(cur_context) >= 2 else '')
                        if session_id:
                            self.data[split]['session_id'].append(
                                sess['dialogue_id'])
                        if terminated:
                            self.data[split]['terminated'].append(
                                turn_id >= len(sess['dial']))
                        if goal:
                            self.data[split]['goal'].append(sess['goal'])
                        cur_context.append(text)
                        cur_context_dialog_act.append(da)

        return self.data
示例#5
0
def preprocess():
    """Convert the raw CamRest dialogues to the unified format and build the ontology.

    Results are cached as ``data.zip`` and ``ontology.json`` under
    ``self_dir``. When both files exist they are loaded and returned;
    otherwise ``original_data.zip`` is extracted, every train/val/test
    dialogue is converted, and both cache files are (re)written.

    Returns:
        tuple: ``(all_data, new_ont)`` - the list of converted dialogues and
        the ontology dict.

    Raises:
        FileNotFoundError: if ``original_data.zip`` is missing from
            ``self_dir``.
    """
    original_zipped_path = os.path.join(self_dir, 'original_data.zip')
    if not os.path.exists(original_zipped_path):
        raise FileNotFoundError(original_zipped_path)

    data_zip_path = os.path.join(self_dir, 'data.zip')
    ontology_path = os.path.join(self_dir, 'ontology.json')

    # The cache check is anchored at self_dir, the same location the cached
    # files are read from and written to, so the result does not depend on
    # the current working directory.
    if os.path.exists(data_zip_path) and os.path.exists(ontology_path):
        all_data = read_zipped_json(data_zip_path, 'data.json')
        with open(ontology_path, 'r') as ontology_file:
            new_ont = json.load(ontology_file)
        return all_data, new_ont

    with zipfile.ZipFile(original_zipped_path, 'r') as archive:
        archive.extractall(self_dir)

    all_data = []
    all_intent = []
    all_binary_das = []
    all_state_slots = ['pricerange', 'area', 'food']

    data_splits = ['train', 'val', 'test']
    # extract_dir already includes self_dir; it must not be joined onto
    # self_dir a second time.
    extract_dir = os.path.join(self_dir, 'original_data')

    dialog_id = 1
    for data_split in data_splits:
        split_path = os.path.join(extract_dir, '{}.json'.format(data_split))
        with open(split_path) as split_file:
            data = json.load(split_file)

        for d in data:
            dialogue = d['dial']
            converted_dialogue = {
                'dataset': 'camrest',
                'data_split': data_split,
                'dialogue_id': 'camrest_' + str(dialog_id),
                'original_id': d['dialogue_id'],
                'domains': ['restaurant'],
                'turns': []
            }

            prev_state = {'restaurant': {}}
            for turn in dialogue:
                usr_text = turn['usr']['transcript'].lower()
                usr_da = turn['usr']['dialog_act']

                sys_text = turn['sys']['sent'].lower()
                sys_da = turn['sys']['dialog_act']

                cur_state = convert_state(turn['usr']['slu'],
                                          all_state_slots)
                # convert_da receives the shared all_intent / all_binary_das
                # lists; presumably it records newly seen entries in them -
                # confirm against its definition.
                cur_user_da = convert_da(usr_text, usr_da, all_intent,
                                         all_binary_das)

                usr_turn = {
                    'utt_idx': len(converted_dialogue['turns']),
                    'speaker': 'user',
                    'utterance': usr_text,
                    'dialogue_act': cur_user_da,
                    'state': copy.deepcopy(cur_state),
                    'state_update': get_state_update(
                        prev_state['restaurant'],
                        cur_state['restaurant'],
                        converted_dialogue['turns'],
                        cur_user_da['non-categorical'],
                        converted_dialogue['dialogue_id'])
                }

                sys_turn = {
                    'utt_idx': len(converted_dialogue['turns']) + 1,
                    'speaker': 'system',
                    'utterance': sys_text,
                    'dialogue_act': convert_da(sys_text, sys_da, all_intent,
                                               all_binary_das),
                }

                prev_state = copy.deepcopy(cur_state)

                converted_dialogue['turns'].append(usr_turn)
                converted_dialogue['turns'].append(sys_turn)
            # Drop the trailing system turn so the dialogue ends on a user
            # turn.
            if converted_dialogue['turns'][-1]['speaker'] == 'system':
                converted_dialogue['turns'].pop(-1)
            all_data.append(converted_dialogue)
            dialog_id += 1

    # data.json is staged in the current working directory because
    # write_zipped_json is handed the bare file name (presumably resolved
    # relative to the CWD - confirm against its definition). The handle is
    # closed before zipping so the file is fully flushed.
    with open('./data.json', 'w') as data_file:
        json.dump(all_data, data_file, indent=4)
    write_zipped_json(data_zip_path, 'data.json')
    os.remove('data.json')

    new_ont = {
        'domains': {},
        'intents': {},
        'binary_dialogue_act': [],
        'state': {}
    }

    new_ont['state']['restaurant'] = {}
    for ss in all_state_slots:
        new_ont['state']['restaurant'][ss] = ''

    for b in all_binary_das:
        new_ont['binary_dialogue_act'].append(b)

    for intent in all_intent:
        new_ont['intents'][intent] = {
            'description': camrest_desc['intents'][intent]
        }

    new_ont['domains']['restaurant'] = {
        'description': camrest_desc['restaurant']['domain'],
        'slots': {}
    }
    for s in all_slots:
        new_ont['domains']['restaurant']['slots'][s] = {
            "description": camrest_desc['restaurant'][s],
            "is_categorical": True if s in cat_slot_values else False,
            "possible_values":
            cat_slot_values[s] if s in cat_slot_values else []
        }
    with open(ontology_path, 'w') as ontology_file:
        json.dump(new_ont, ontology_file, indent=4)

    return all_data, new_ont
示例#6
0
def preprocess():
    processed_dialogue = []
    ontology = {
        'domains': {},
        'intents': {},
        'binary_dialogue_act': [],
        'state': {}
    }
    ontology['intents'].update(get_intent())
    numerical_slots = {}
    original_zipped_path = os.path.join(self_dir, 'original_data.zip')
    new_dir = os.path.join(self_dir, 'original_data')
    if not os.path.exists(original_zipped_path):
        raise FileNotFoundError(original_zipped_path)
    if not os.path.exists(os.path.join(
            self_dir, 'data.zip')) or not os.path.exists(
                os.path.join(self_dir, 'ontology.json')):
        print('unzip to', new_dir)
        print('This may take several minutes')
        archive = zipfile.ZipFile(original_zipped_path, 'r')
        archive.extractall(self_dir)
        cnt = 1
        non_cate_slot_update_cnt = 0
        non_cate_slot_update_fail_cnt = 0
        state_cnt = {}
        num_train_dialog = 0
        num_train_utt = 0
        for data_split in ['train', 'dev', 'test']:
            dataset_name = 'schema'
            data_dir = os.path.join(new_dir, data_split)
            # schema => ontology
            f = open(os.path.join(data_dir, 'schema.json'))
            data = json.load(f)
            for schema in data:
                domain = service2domain(schema['service_name'])
                ontology['domains'].setdefault(domain, {})
                ontology['domains'][domain]['description'] = schema[
                    'description']
                ontology['domains'][domain].setdefault('slots', {})
                ontology['state'].setdefault(domain, {})
                for slot in schema['slots']:
                    # numerical => non-categorical: not use
                    # is_numerical = slot['is_categorical']
                    # for value in slot['possible_values']:
                    #     if not value.isdigit():
                    #         is_numerical = False
                    #         break
                    # if is_numerical:
                    #     numerical_slots.setdefault(slot['name'].lower(), 1)
                    lower_values = [x.lower() for x in slot['possible_values']]
                    ontology['domains'][domain]['slots'][
                        slot['name'].lower()] = {
                            "description": slot['description'],
                            "is_categorical": slot['is_categorical'],
                            "possible_values": lower_values
                        }
                    ontology['state'][domain][slot['name'].lower()] = ''
                # add 'count' slot
                ontology['domains'][domain]['slots']['count'] = {
                    "description":
                    "the number of items found that satisfy the user's request.",
                    "is_categorical": False,
                    "possible_values": []
                }
                # ontology['state'][domain]['count'] = ''
            # pprint(numerical_slots)
            # dialog
            for root, dirs, files in os.walk(data_dir):
                fs = sorted([x for x in files if 'dialogues' in x])
                for f in tqdm(
                        fs,
                        desc='processing schema-guided-{}'.format(data_split)):
                    data = json.load(open(os.path.join(data_dir, f)))
                    if data_split == 'train':
                        num_train_dialog += len(data)
                    for d in data:
                        dialogue = {
                            "dataset": dataset_name,
                            "data_split":
                            data_split if data_split != 'dev' else 'val',
                            "dialogue_id": dataset_name + '_' + str(cnt),
                            "original_id": d['dialogue_id'],
                            "domains":
                            [service2domain(s) for s in d['services']],
                            "turns": []
                        }
                        # if d['dialogue_id'] != '84_00008':
                        #     continue
                        cnt += 1
                        prev_sys_frames = []
                        prev_user_frames = []
                        all_slot_spans_from_da = []
                        state = {}
                        for domain in dialogue['domains']:
                            state.setdefault(
                                domain, deepcopy(ontology['state'][domain]))
                        if data_split == 'train':
                            num_train_utt += len(d['turns'])
                        for utt_idx, t in enumerate(d['turns']):
                            speaker = t['speaker'].lower()
                            turn = {
                                'speaker': speaker,
                                'utterance': t['utterance'],
                                'utt_idx': utt_idx,
                                'dialogue_act': {
                                    'binary': [],
                                    'categorical': [],
                                    'non-categorical': [],
                                },
                            }
                            for i, frame in enumerate(t['frames']):
                                domain = service2domain(frame['service'])
                                for action in frame['actions']:
                                    intent = action['act'].lower()
                                    assert intent in ontology['intents'], [
                                        intent
                                    ]
                                    slot = action['slot'].lower()
                                    value_list = action['values']
                                    if action['act'] in [
                                            'REQ_MORE', 'AFFIRM', 'NEGATE',
                                            'THANK_YOU', 'GOODBYE'
                                    ]:
                                        turn['dialogue_act']['binary'].append({
                                            "intent":
                                            intent,
                                            "domain":
                                            '',
                                            "slot":
                                            '',
                                            "value":
                                            '',
                                        })
                                    elif action['act'] in [
                                            'NOTIFY_SUCCESS', 'NOTIFY_FAILURE',
                                            'REQUEST_ALTS', 'AFFIRM_INTENT',
                                            'NEGATE_INTENT'
                                    ]:
                                        # Slot and values are always empty
                                        turn['dialogue_act']['binary'].append({
                                            "intent":
                                            intent,
                                            "domain":
                                            domain,
                                            "slot":
                                            '',
                                            "value":
                                            '',
                                        })
                                    elif action['act'] in [
                                            'OFFER_INTENT', 'INFORM_INTENT'
                                    ]:
                                        # always has "intent" as the slot, and a single value containing the intent being offered.
                                        assert slot == 'intent'
                                        turn['dialogue_act']['binary'].append({
                                            "intent":
                                            intent,
                                            "domain":
                                            domain,
                                            "slot":
                                            slot,
                                            "value":
                                            value_list[0].lower(),
                                        })
                                    elif action['act'] in [
                                            'REQUEST', 'SELECT'
                                    ] and not value_list:
                                        # always contains a slot, but values are optional.
                                        # assert slot in ontology['domains'][domain]['slots']
                                        turn['dialogue_act']['binary'].append({
                                            "intent":
                                            intent,
                                            "domain":
                                            domain,
                                            "slot":
                                            slot,
                                            "value":
                                            '',
                                        })
                                    elif action['act'] in ['INFORM_COUNT']:
                                        # always has "count" as the slot, and a single element in values for the number of results obtained by the system.
                                        value = value_list[0]
                                        assert slot in ontology['domains'][
                                            domain]['slots']
                                        (start, end), num = pharse_in_sen(
                                            value, t['utterance'])
                                        if num:
                                            assert value.lower() == t['utterance'][start:end].lower() \
                                                   or digit2word[value].lower() == t['utterance'][start:end].lower()
                                            turn['dialogue_act'][
                                                'non-categorical'].append({
                                                    "intent":
                                                    intent,
                                                    "domain":
                                                    domain,
                                                    "slot":
                                                    slot.lower(),
                                                    "value":
                                                    t['utterance']
                                                    [start:end].lower(),
                                                    "start":
                                                    start,
                                                    "end":
                                                    end
                                                })
                                    else:
                                        # have slot & value
                                        if ontology['domains'][domain][
                                                'slots'][slot][
                                                    'is_categorical']:
                                            for value in value_list:
                                                value = value.lower()
                                                if value not in ontology['domains'][
                                                        domain]['slots'][slot][
                                                            'possible_values'] and value != 'dontcare':
                                                    ontology['domains'][
                                                        domain]['slots'][slot][
                                                            'possible_values'].append(
                                                                value)
                                                    print(
                                                        'add value to ontology',
                                                        domain, slot, value)
                                                assert value in ontology['domains'][
                                                    domain]['slots'][slot][
                                                        'possible_values'] or value == 'dontcare'
                                                turn['dialogue_act'][
                                                    'categorical'].append({
                                                        "intent":
                                                        intent,
                                                        "domain":
                                                        domain,
                                                        "slot":
                                                        slot,
                                                        "value":
                                                        value,
                                                    })
                                        elif slot in numerical_slots:
                                            value = value_list[-1]
                                            (start, end), num = pharse_in_sen(
                                                value, t['utterance'])
                                            if num:
                                                assert value.lower() == t['utterance'][start:end].lower() \
                                                       or digit2word[value].lower() == t['utterance'][start:end].lower()
                                                turn['dialogue_act'][
                                                    'non-categorical'].append({
                                                        "intent":
                                                        intent,
                                                        "domain":
                                                        domain,
                                                        "slot":
                                                        slot.lower(),
                                                        "value":
                                                        t['utterance']
                                                        [start:end].lower(),
                                                        "start":
                                                        start,
                                                        "end":
                                                        end
                                                    })
                                        else:
                                            # span info in frame['slots']
                                            for value in value_list:
                                                for slot_info in frame[
                                                        'slots']:
                                                    start = slot_info['start']
                                                    end = slot_info[
                                                        'exclusive_end']
                                                    if slot_info[
                                                            'slot'] == slot and t[
                                                                'utterance'][
                                                                    start:
                                                                    end] == value:
                                                        turn['dialogue_act'][
                                                            'non-categorical'].append(
                                                                {
                                                                    "intent":
                                                                    intent,
                                                                    "domain":
                                                                    domain,
                                                                    "slot":
                                                                    slot,
                                                                    "value":
                                                                    value.
                                                                    lower(),
                                                                    "start":
                                                                    start,
                                                                    "end":
                                                                    end
                                                                })
                                                        break
                            # add span da to all_slot_spans_from_da
                            for ele in turn['dialogue_act']['non-categorical']:
                                all_slot_spans_from_da.insert(
                                    0, {
                                        "domain": ele["domain"],
                                        "slot": ele["slot"],
                                        "value": ele["value"].lower(),
                                        "utt_idx": utt_idx,
                                        "start": ele["start"],
                                        "end": ele["end"]
                                    })
                            if speaker == 'user':
                                # DONE: record state update, may come from sys acts
                                # prev_state: state. update the state using current frames.
                                # candidate span info from prev frames and current frames
                                slot_spans = []
                                for frame in t['frames']:
                                    for ele in frame['slots']:
                                        slot, start, end = ele['slot'].lower(
                                        ), ele['start'], ele['exclusive_end']
                                        slot_spans.append({
                                            "domain":
                                            service2domain(frame['service']),
                                            "slot":
                                            slot,
                                            "value":
                                            t['utterance'][start:end].lower(),
                                            "utt_idx":
                                            utt_idx,
                                            "start":
                                            start,
                                            "end":
                                            end
                                        })
                                for frame in prev_sys_frames:
                                    for ele in frame['slots']:
                                        slot, start, end = ele['slot'].lower(
                                        ), ele['start'], ele['exclusive_end']
                                        slot_spans.append({
                                            "domain":
                                            service2domain(frame['service']),
                                            "slot":
                                            slot,
                                            "value":
                                            d['turns'][utt_idx - 1]
                                            ['utterance'][start:end].lower(),
                                            "utt_idx":
                                            utt_idx - 1,
                                            "start":
                                            start,
                                            "end":
                                            end
                                        })
                                # turn['slot_spans'] = slot_spans
                                # turn['all_slot_span'] = deepcopy(all_slot_spans_from_da)
                                state_update = {
                                    "categorical": [],
                                    "non-categorical": []
                                }
                                # print(utt_idx)
                                for frame in t['frames']:
                                    domain = service2domain(frame['service'])
                                    # print(domain)
                                    for slot, value_list in frame['state'][
                                            'slot_values'].items():
                                        # For categorical slots, this list contains a single value assigned to the slot.
                                        # For non-categorical slots, all the values in this list are spoken variations
                                        # of each other and are equivalent (e.g, "6 pm", "six in the evening",
                                        # "evening at 6" etc.).
                                        numerical_equal_values = []
                                        if slot in numerical_slots:
                                            for value in value_list:
                                                if value in digit2word:
                                                    numerical_equal_values.append(
                                                        digit2word[value])
                                        value_list += numerical_equal_values
                                        assert len(value_list) > 0, print(
                                            slot, value_list)
                                        assert slot in state[domain]
                                        value_list = list(
                                            set([
                                                x.lower() for x in value_list
                                            ]))
                                        if state[domain][slot] in value_list:
                                            continue
                                        # new value
                                        candidate_values = value_list
                                        for prev_user_frame in prev_user_frames:
                                            prev_domain = service2domain(
                                                prev_user_frame['service'])
                                            if prev_domain == domain and slot in prev_user_frame[
                                                    'state']['slot_values']:
                                                prev_value_list = [
                                                    x.lower() for x in
                                                    prev_user_frame['state']
                                                    ['slot_values'][slot]
                                                ]
                                                candidate_values = list(
                                                    set(value_list) -
                                                    set(prev_value_list))
                                        assert state[domain][
                                            slot] not in candidate_values
                                        assert candidate_values

                                        if ontology['domains'][domain][
                                                'slots'][slot][
                                                    'is_categorical']:
                                            state_cnt.setdefault(
                                                'cate_slot_update', 0)
                                            state_cnt['cate_slot_update'] += 1
                                            value = candidate_values[0]
                                            state_update['categorical'].append(
                                                {
                                                    "domain": domain,
                                                    "slot": slot,
                                                    "value": value
                                                })
                                            state[domain][slot] = value
                                        else:
                                            state_cnt.setdefault(
                                                'non_cate_slot_update', 0)
                                            state_cnt[
                                                'non_cate_slot_update'] += 1
                                            span_priority = []
                                            slot_spans_len = len(slot_spans)
                                            all_slot_spans = slot_spans + all_slot_spans_from_da
                                            for span_idx, slot_span in enumerate(
                                                    all_slot_spans):
                                                priority = 0
                                                span_domain = slot_span[
                                                    'domain']
                                                span_slot = slot_span['slot']
                                                span_value = slot_span['value']
                                                if domain == span_domain:
                                                    priority += 1
                                                if slot == span_slot:
                                                    priority += 10
                                                if span_value in candidate_values:
                                                    priority += 100
                                                if span_idx + 1 <= slot_spans_len:
                                                    priority += 0.5
                                                span_priority.append(priority)
                                                if span_idx + 1 <= slot_spans_len:
                                                    # slot_spans not run out
                                                    if max(span_priority
                                                           ) >= 111.5:
                                                        break
                                                else:
                                                    # search in previous da
                                                    if max(span_priority
                                                           ) >= 111:
                                                        break
                                            if span_priority and max(
                                                    span_priority) >= 100:
                                                # {111.5: 114255,
                                                #  111: 29591,
                                                #  100: 15208,
                                                #  110: 2159,
                                                #  100.5: 642,
                                                #  110.5: 125,
                                                #  101: 24}
                                                max_priority = max(
                                                    span_priority)
                                                state_cnt.setdefault(
                                                    'max_priority', Counter())
                                                state_cnt['max_priority'][
                                                    max_priority] += 1
                                                span_idx = np.argmax(
                                                    span_priority)
                                                ele = all_slot_spans[span_idx]
                                                state_update[
                                                    'non-categorical'].append({
                                                        "domain":
                                                        domain,
                                                        "slot":
                                                        slot,
                                                        "value":
                                                        ele['value'],
                                                        "utt_idx":
                                                        ele["utt_idx"],
                                                        "start":
                                                        ele["start"],
                                                        "end":
                                                        ele["end"]
                                                    })
                                                state[domain][slot] = ele[
                                                    'value']
                                            else:
                                                # not found
                                                value = candidate_values[0]
                                                state_update[
                                                    'non-categorical'].append({
                                                        "domain":
                                                        domain,
                                                        "slot":
                                                        slot,
                                                        "value":
                                                        value
                                                    })
                                                state[domain][slot] = value
                                                # print(t['utterance'])
                                                non_cate_slot_update_fail_cnt += 1
                                            non_cate_slot_update_cnt += 1
                                turn['state'] = deepcopy(state)
                                turn['state_update'] = state_update
                                prev_user_frames = deepcopy(t['frames'])
                            else:
                                prev_sys_frames = deepcopy(t['frames'])

                            for da in turn['dialogue_act']['binary']:
                                if da not in ontology['binary_dialogue_act']:
                                    ontology['binary_dialogue_act'].append(
                                        deepcopy(da))
                            dialogue['turns'].append(deepcopy(turn))
                        assert len(dialogue['turns']) % 2 == 0
                        dialogue['turns'].pop()
                        processed_dialogue.append(dialogue)
                        # break
        # sort ontology binary
        pprint(state_cnt)
        ontology['binary_dialogue_act'] = sorted(
            ontology['binary_dialogue_act'], key=lambda x: x['intent'])
        json.dump(ontology,
                  open(os.path.join(self_dir, 'ontology.json'), 'w'),
                  indent=2)
        json.dump(processed_dialogue, open('data.json', 'w'), indent=2)
        write_zipped_json(os.path.join(self_dir, 'data.zip'), 'data.json')
        os.remove('data.json')
        print('# train dialog: {}, # train utterance: {}'.format(
            num_train_dialog, num_train_utt))
        print(non_cate_slot_update_fail_cnt,
              non_cate_slot_update_cnt)  # 395 162399

    else:
        # read from file
        processed_dialogue = read_zipped_json(
            os.path.join(self_dir, 'data.zip'), 'data.json')
        ontology = json.load(open(os.path.join(self_dir, 'ontology.json')))
    return processed_dialogue, ontology
示例#7
0
    def load_data(self,
                  data_dir=None,
                  data_key='all',
                  role='all',
                  utterance=False,
                  dialog_act=False,
                  context=False,
                  context_window_size=0,
                  context_dialog_act=False,
                  belief_state=False,
                  last_opponent_utterance=False,
                  last_self_utterance=False,
                  ontology=False,
                  session_id=False,
                  span_info=False,
                  terminated=False,
                  goal=False
                  ):
        """Load the zipped dialogue splits into parallel per-turn lists.

        For every requested split, ``self.data[split]`` maps each enabled
        field name to a list with one entry per kept turn, so that index
        ``k`` of every list describes the same turn.

        Args:
            data_dir: Directory holding ``<split>.json.zip`` (and
                ``ontology.json``). Defaults to ``DATA_ROOT/multiwoz`` or
                ``DATA_ROOT/multiwoz_zh`` when ``self.zh`` is truthy.
            data_key: ``'all'`` loads ``train``, ``val`` and ``test``; any
                other value is used as the single split name to load.
            role: ``'sys'`` keeps only odd-indexed turns, ``'usr'`` keeps
                only even-indexed turns, ``'all'`` keeps every turn.
                Skipped turns are still appended to the running context.
                (Presumably even turns are user turns - TODO confirm.)
            utterance: Collect ``turn['text']``.
            dialog_act: Collect the turn's dialog act as a list of
                ``[intent, domain, slot, value]`` lists.
            context: Collect the preceding utterances of the session.
            context_window_size: Number of trailing context entries to keep.
                NOTE: the slice is ``[-context_window_size:]``, so the
                default ``0`` keeps the *entire* history, not an empty one.
            context_dialog_act: Collect the preceding turns' dialog acts,
                windowed the same way as ``context``.
            belief_state: Collect ``turn['metadata']``.
            last_opponent_utterance: Collect the previous utterance
                (``''`` if there is none).
            last_self_utterance: Collect the utterance two turns back
                (``''`` if there is none).
            ontology: Also load ``ontology.json`` into
                ``self.data['ontology']``.
            session_id: Collect the session id of each turn.
            span_info: Collect ``turn['span_info']``.
            terminated: Collect a flag that is True for the last two turns
                of a session (``i + 2 >= len(log)``).
            goal: Collect ``sess['goal']`` for each turn.

        Returns:
            ``self.data``, which is also stored on the instance.
        """
        if data_dir is None:
            data_dir = os.path.join(DATA_ROOT, 'multiwoz' + ('_zh' if self.zh else ''))

        def da2tuples(dialog_act):
            # Flatten {"Domain-Intent": [[slot, value], ...]} into a list of
            # [intent, domain, slot, value], with slots sorted within each key.
            tuples = []
            for domain_intent, svs in dialog_act.items():
                for slot, value in sorted(svs, key=lambda x: x[0]):
                    domain, intent = domain_intent.split('-')
                    tuples.append([intent, domain, slot, value])
            return tuples

        assert role in ['sys', 'usr', 'all']
        # Explicit (name, flag) table instead of eval() on local variable
        # names. 'ontology' is deliberately absent: it is not a per-turn field.
        requested_fields = [
            ('utterance', utterance),
            ('dialog_act', dialog_act),
            ('context', context),
            ('context_dialog_act', context_dialog_act),
            ('belief_state', belief_state),
            ('last_opponent_utterance', last_opponent_utterance),
            ('last_self_utterance', last_self_utterance),
            ('session_id', session_id),
            ('span_info', span_info),
            ('terminated', terminated),
            ('goal', goal),
        ]
        info_list = [name for name, enabled in requested_fields if enabled]
        self.data = {'train': {}, 'val': {}, 'test': {}, 'role': role, 'human_val': {}}
        if data_key == 'all':
            data_key_list = ['train', 'val', 'test']
        else:
            data_key_list = [data_key]
        for data_key in data_key_list:
            data = read_zipped_json(os.path.join(data_dir, '{}.json.zip'.format(data_key)), '{}.json'.format(data_key))
            print('loaded {}, size {}'.format(data_key, len(data)))
            for x in info_list:
                self.data[data_key][x] = []
            for sess_id, sess in data.items():
                # Running history of the current session, including turns
                # that are filtered out by `role`.
                cur_context = []
                cur_context_dialog_act = []
                for i, turn in enumerate(sess['log']):
                    text = turn['text']
                    da = da2tuples(turn['dialog_act'])
                    if role == 'sys' and i % 2 == 0:
                        cur_context.append(text)
                        cur_context_dialog_act.append(da)
                        continue
                    elif role == 'usr' and i % 2 == 1:
                        cur_context.append(text)
                        cur_context_dialog_act.append(da)
                        continue
                    if utterance:
                        self.data[data_key]['utterance'].append(text)
                    if dialog_act:
                        self.data[data_key]['dialog_act'].append(da)
                    if context:
                        self.data[data_key]['context'].append(cur_context[-context_window_size:])
                    if context_dialog_act:
                        self.data[data_key]['context_dialog_act'].append(cur_context_dialog_act[-context_window_size:])
                    if belief_state:
                        self.data[data_key]['belief_state'].append(turn['metadata'])
                    if last_opponent_utterance:
                        self.data[data_key]['last_opponent_utterance'].append(
                            cur_context[-1] if len(cur_context) >= 1 else '')
                    if last_self_utterance:
                        self.data[data_key]['last_self_utterance'].append(
                            cur_context[-2] if len(cur_context) >= 2 else '')
                    if session_id:
                        self.data[data_key]['session_id'].append(sess_id)
                    if span_info:
                        self.data[data_key]['span_info'].append(turn['span_info'])
                    if terminated:
                        self.data[data_key]['terminated'].append(i + 2 >= len(sess['log']))
                    if goal:
                        self.data[data_key]['goal'].append(sess['goal'])
                    cur_context.append(text)
                    cur_context_dialog_act.append(da)
        if ontology:
            ontology_path = os.path.join(data_dir, 'ontology.json')
            # Close the file deterministically instead of leaking the handle.
            with open(ontology_path) as ontology_file:
                self.data['ontology'] = json.load(ontology_file)

        return self.data
示例#8
0
def preprocess():
    """Build (or load cached) unified-format Taskmaster dialogues and ontology.

    If either ``data.zip`` or ``ontology.json`` is missing next to this
    file, ``original_data.zip`` is extracted, every TM-1 / TM-2 data file is
    converted, and both outputs are written; otherwise the cached outputs
    are read back.

    Each original session becomes a dialogue dict whose turns carry
    ``'inform'`` non-categorical dialogue acts built from the annotated
    segments. Annotations whose slot is not in the original ontology of the
    dialogue's domain are dropped.

    Returns:
        A ``(processed_dialogue, ontology)`` tuple: the list of dialogue
        dicts and the ontology dict.
    """
    self_dir = os.path.dirname(os.path.abspath(__file__))
    processed_dialogue = []
    ontology = {
        'domains': {},
        'intents': {},
        'binary_dialogue_act': [],
        'state': {}
    }
    original_zipped_path = os.path.join(self_dir, 'original_data.zip')
    new_dir = os.path.join(self_dir, 'original_data')
    data_zip_path = os.path.join(self_dir, 'data.zip')
    ontology_path = os.path.join(self_dir, 'ontology.json')

    def load_json(path):
        # Read one JSON file and close the handle right away.
        with open(path) as json_file:
            return json.load(json_file)

    if not os.path.exists(data_zip_path) or not os.path.exists(ontology_path):
        print('unzip to', new_dir)
        print('This may take several minutes')
        # NOTE(review): the archive is extracted into self_dir while files are
        # read from new_dir; presumably the zip has a top-level
        # 'original_data/' folder - confirm.
        with zipfile.ZipFile(original_zipped_path, 'r') as archive:
            archive.extractall(self_dir)
        # (dialogue file, ontology file) pairs, relative to new_dir.
        files = [
            ('TM-1-2019/woz-dialogs.json', 'TM-1-2019/ontology.json'),
            ('TM-1-2019/self-dialogs.json', 'TM-1-2019/ontology.json'),
            ('TM-2-2020/data/flights.json', 'TM-2-2020/ontology/flights.json'),
            ('TM-2-2020/data/food-ordering.json',
             'TM-2-2020/ontology/food-ordering.json'),
            ('TM-2-2020/data/hotels.json', 'TM-2-2020/ontology/hotels.json'),
            ('TM-2-2020/data/movies.json', 'TM-2-2020/ontology/movies.json'),
            ('TM-2-2020/data/music.json', 'TM-2-2020/ontology/music.json'),
            ('TM-2-2020/data/restaurant-search.json',
             'TM-2-2020/ontology/restaurant-search.json'),
            ('TM-2-2020/data/sports.json', 'TM-2-2020/ontology/sports.json')
        ]
        # Running 1-based counter used to build unique dialogue ids.
        idx_count = 1

        for filename, ontology_filename in files:
            data = load_json(os.path.join(new_dir, filename))
            raw_ontology = load_json(os.path.join(new_dir, ontology_filename))
            # ori_ontology: {domain: {slot: 0}} for the slots declared by the
            # original ontology of this file.
            ori_ontology = {}
            if 'TM-1' in filename:
                # TM-1 ontology: one entry per domain, slots listed under
                # "required" and "optional"; the domain key is item["id"].
                for domain, item in raw_ontology.items():
                    ori_ontology[item["id"]] = {}
                    for slot in item["required"] + item["optional"]:
                        ori_ontology[item["id"]][slot] = 0
            else:
                # TM-2 ontology: a single domain named after the data file.
                domain = normalize_domain_name(
                    filename.split('/')[-1].split('.')[0])
                ori_ontology[domain] = {}
                for _, item in raw_ontology.items():
                    for group in item:
                        for anno in group["annotations"]:
                            ori_ontology[domain][anno] = 0
            # Register unseen domains / slots in the unified ontology.
            for d in ori_ontology:
                if d not in ontology['domains']:
                    ontology['domains'][d] = {
                        'description': descriptions[d][d],
                        'slots': {}
                    }
                for s in ori_ontology[d]:
                    if s not in ontology['domains'][d]['slots']:
                        ontology['domains'][d]['slots'][s] = {
                            'description': descriptions[d][s],
                            'is_categorical': False,
                            'possible_values': [],
                            'count': 0,
                            'in original ontology': True
                        }
            for ori_sess in tqdm(
                    data, desc='processing taskmaster-{}'.format(filename)):
                turns = format_turns(ori_sess['utterances'])
                if not turns:
                    continue
                if 'TM-2' in filename:
                    dial_domain = normalize_domain_name(
                        filename.split('/')[-1].split('.')[0])
                else:
                    dial_domain = normalize_domain_name(
                        ori_sess['instruction_id'].split('-', 1)[0])
                dialogue = {
                    "dataset": "taskmaster",
                    "data_split": "train",
                    "dialogue_id": 'taskmaster_' + str(idx_count),
                    "original_id": ori_sess['conversation_id'],
                    "instruction_id": ori_sess['instruction_id'],
                    "domains": [dial_domain],
                    "turns": []
                }
                idx_count += 1
                assert turns[0]['speaker'] == 'user' and turns[-1][
                    'speaker'] == 'user', print(turns)
                for utt_idx, uttr in enumerate(turns):
                    speaker = uttr['speaker']
                    turn = {
                        'speaker': speaker,
                        'utterance': uttr['text'],
                        'utt_idx': utt_idx,
                        'dialogue_act': {
                            'binary': [],
                            'categorical': [],
                            'non-categorical': [],
                        },
                    }
                    if speaker == 'user':
                        turn['state'] = {}
                        turn['state_update'] = {
                            'categorical': [],
                            'non-categorical': []
                        }

                    if 'segments' in uttr:
                        for segment in uttr['segments']:
                            for item in segment['annotations']:
                                # The dialogue-level domain is used, not the
                                # prefix of the annotation name.
                                domain = dial_domain
                                slot = item['name'].split('.', 1)[-1]
                                # Strip a trailing ".accept" / ".reject".
                                if slot.endswith(('.accept', '.reject')):
                                    slot = slot[:-len('.accept')]
                                if slot not in ori_ontology[domain]:
                                    # Slot unknown to the original ontology.
                                    continue

                                if not segment['text']:
                                    print(slot)
                                    print(segment)
                                    print()
                                # The span offsets must reproduce the text.
                                assert turn['utterance'][
                                    segment['start_index']:
                                    segment['end_index']] == segment['text']
                                turn['dialogue_act']['non-categorical'].append(
                                    {
                                        'intent': 'inform',
                                        'domain': domain,
                                        'slot': slot,
                                        'value': segment['text'].lower(),
                                        'start': segment['start_index'],
                                        'end': segment['end_index']
                                    })
                        log_ontology(turn['dialogue_act']['non-categorical'],
                                     ontology, ori_ontology)
                    dialogue['turns'].append(turn)
                processed_dialogue.append(dialogue)
        # save ontology json
        with open(ontology_path, 'w') as ontology_file:
            json.dump(ontology, ontology_file, indent=2)
        # data.json is a temporary file in the current working directory; it
        # must be closed (flushed) before write_zipped_json is called on it.
        with open('data.json', 'w') as data_file:
            json.dump(processed_dialogue, data_file, indent=2)
        write_zipped_json(data_zip_path, 'data.json')
        os.remove('data.json')
    else:
        # read from file
        processed_dialogue = read_zipped_json(data_zip_path, 'data.json')
        ontology = load_json(ontology_path)
    return processed_dialogue, ontology
示例#9
0
def preprocess():
    """Convert the original Frames dump into the unified dialogue format.

    On the first run the archive ``original_data.zip`` (located in
    ``self_dir``) is extracted, ``frames.json`` is converted turn by turn,
    and the results are cached as ``data.zip`` and ``ontology.json`` in
    ``self_dir``.  When both cache files already exist they are loaded and
    returned without re-processing.

    Returns:
        A ``(processed_dialogue, ontology)`` tuple: the list of converted
        dialogues and the ontology dict (domains/slots, intents and binary
        dialogue acts) collected while converting.

    Raises:
        FileNotFoundError: if ``original_data.zip`` is missing from
            ``self_dir`` (checked even when the cache is present).
    """
    original_zipped_path = os.path.join(self_dir, 'original_data.zip')
    new_dir = os.path.join(self_dir, 'original_data')
    data_zip_path = os.path.join(self_dir, 'data.zip')
    ontology_path = os.path.join(self_dir, 'ontology.json')
    if not os.path.exists(original_zipped_path):
        raise FileNotFoundError(original_zipped_path)

    # Fast path: both cached artefacts exist, so just read them back.
    if os.path.exists(data_zip_path) and os.path.exists(ontology_path):
        processed_dialogue = read_zipped_json(data_zip_path, 'data.json')
        with open(ontology_path) as ontology_file:
            ontology = json.load(ontology_file)
        return processed_dialogue, ontology

    processed_dialogue = []
    ontology = {
        'domains': {
            'travel': {
                "description":
                "Book a vacation package containing round-trip flights and a hotel.",
                "slots": {}
            }
        },
        'intents': {},
        'binary_dialogue_act': [],
        'state': {}
    }
    travel_slots = ontology['domains']['travel']['slots']
    binary_acts = ontology['binary_dialogue_act']

    print('unzip to', new_dir)
    print('This may take several minutes')
    # Context managers make sure the archive and the JSON file are closed.
    with zipfile.ZipFile(original_zipped_path, 'r') as archive:
        archive.extractall(new_dir)
    with open(os.path.join(new_dir, 'frames.json')) as frames_file:
        data = json.load(frames_file)

    # Dialogue ids are 1-based: frames_1, frames_2, ...
    for cnt, d in enumerate(tqdm(data, desc='dialogue'), start=1):
        dialogue = {
            "dataset": 'frames',
            "data_split": 'train',
            "dialogue_id": 'frames_' + str(cnt),
            "original_id": d['id'],
            "user_id": d['user_id'],
            "system_id": d['wizard_id'],
            "userSurveyRating": d['labels']['userSurveyRating'],
            "wizardSurveyTaskSuccessful":
            d['labels']['wizardSurveyTaskSuccessful'],
            "domains": ['travel'],
            "turns": []
        }
        for utt_idx, t in enumerate(d['turns']):
            # The original data calls the system side "wizard".
            speaker = 'system' if t['author'] == 'wizard' else t['author']
            turn = {
                'speaker': speaker,
                'utterance': t['text'],
                'utt_idx': utt_idx,
                'dialogue_act': {
                    'binary': [],
                    'categorical': [],
                    'non-categorical': [],
                },
            }
            for intent, slot, value in iter_over_acts(t['labels']['acts']):
                da_type, da = normalize_da(intent, slot, value, t['text'])
                if da is None:
                    # normalize_da returned no act for this triple; skip it.
                    continue
                da['value'] = da['value'].lower()
                turn['dialogue_act'][da_type].append(da)
                # From here on use the normalized slot/value, but keep the
                # original intent name for the ontology's intent table.
                slot = da['slot']
                value = da['value']
                if da_type == 'binary':
                    if da not in binary_acts:
                        binary_acts.append(da)
                else:
                    # The first occurrence of a slot fixes whether it is
                    # recorded as categorical in the ontology.
                    if slot not in travel_slots:
                        travel_slots[slot] = {
                            "description": slot2des[slot],
                            "is_categorical": da_type == 'categorical',
                            "possible_values": []
                        }
                    possible_values = travel_slots[slot]['possible_values']
                    if da_type == 'categorical' \
                            and value not in possible_values:
                        possible_values.append(value)
                if intent not in ontology['intents']:
                    ontology['intents'][intent] = {
                        "description": intent2des[intent]
                    }
            # User turns carry (empty) state placeholders.
            if speaker == 'user':
                turn['state'] = {}
                turn['state_update'] = {
                    'categorical': [],
                    'non-categorical': [],
                }
            # Copy so the stored turn does not share act dicts with the
            # ontology's binary_dialogue_act list.
            dialogue['turns'].append(deepcopy(turn))
        # Keep an odd number of turns by dropping the last one if needed
        # (presumably so every dialogue ends on a user turn -- TODO confirm).
        if len(dialogue['turns']) % 2 == 0:
            dialogue['turns'] = dialogue['turns'][:-1]
        processed_dialogue.append(deepcopy(dialogue))

    ontology['binary_dialogue_act'] = sorted(
        binary_acts, key=lambda x: x['intent'])

    with open(ontology_path, 'w') as ontology_file:
        json.dump(ontology, ontology_file, indent=2)
    # write_zipped_json is given a bare file name, so the intermediate
    # data.json is written to the current working directory first.
    with open('data.json', 'w') as data_file:
        json.dump(processed_dialogue, data_file, indent=2)
    try:
        write_zipped_json(data_zip_path, 'data.json')
    finally:
        # Never leave the temporary file behind, even if zipping fails.
        os.remove('data.json')
    return processed_dialogue, ontology