Example #1
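Builds policy training data from the CrossWOZ corpus: each split is read from its zip archive, every dialog is replayed through RuleDST, and each system turn contributes a (vectorized state, vectorized system action) pair that is pickled to processed_dir.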
    def _build_data(self, root_dir, processed_dir):
        raw_data = {}
        for part in ['train', 'val', 'test']:
            zip_path = os.path.join(root_dir, 'data/crosswoz/{}.json.zip'.format(part))
            with zipfile.ZipFile(zip_path, 'r') as archive:
                with archive.open('{}.json'.format(part), 'r') as f:
                    raw_data[part] = json.load(f)

        self.data = {}
        # rule-based DST used to replay each dialog and track the current belief state
        dst = RuleDST()
        for part in ['train', 'val', 'test']:
            self.data[part] = []

            for key in raw_data[part]:
                sess = raw_data[part][key]['messages']
                dst.init_session()
                for i, turn in enumerate(sess):
                    if turn['role'] == 'usr':
                        dst.update(usr_da=turn['dialog_act'])
                        if i + 2 == len(sess):
                            dst.state['terminated'] = True
                    else:
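                        # copy the annotated system state into the tracked belief state (skip DB query results)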
                        for domain, svs in turn['sys_state'].items():
                            for slot, value in svs.items():
                                if slot != 'selectedResults':
                                    dst.state['belief_state'][domain][slot] = value
                        action = turn['dialog_act']
                        self.data[part].append([self.vector.state_vectorize(deepcopy(dst.state)),
                                                self.vector.action_vectorize(action)])
                        dst.state['system_action'] = turn['dialog_act']

        os.makedirs(processed_dir, exist_ok=True)
        for part in ['train', 'val', 'test']:
            with open(os.path.join(processed_dir, '{}.pkl'.format(part)), 'wb') as f:
                pickle.dump(self.data[part], f)
Example #2
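Evaluates the rule-based DST against the annotated corpus: before each system turn the previous system turn's sys_state is copied into the tracker, the user dialog act is applied, and the resulting belief state is compared with sys_state_init; joint and slot-level accuracy are printed.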
def test_sys_state(data, goal_type):
    ruleDST = RuleDST()
    state_predict_golden = []
    for task_id, item in data.items():
        if goal_type and item['type'] != goal_type:
            continue
        ruleDST.init_session()
        for i, turn in enumerate(item['messages']):
            if turn['role'] == 'sys':
                usr_da = item['messages'][i - 1]['dialog_act']
                if i > 2:
                    for domain, svs in item['messages'][i - 2]['sys_state'].items():
                        for slot, value in svs.items():
                            if slot != 'selectedResults':
                                ruleDST.state['belief_state'][domain][slot] = value
                ruleDST.update(usr_da)
                new_state = deepcopy(ruleDST.state['belief_state'])
                golden_state = deepcopy(turn['sys_state_init'])
                for x in golden_state:
                    golden_state[x].pop('selectedResults')
                state_predict_golden.append({
                    'predict': new_state,
                    'golden': golden_state
                })
    print('joint state', calculateJointState(state_predict_golden))
    print('slot state', calculateSlotState(state_predict_golden))
Example #3
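Corpus-based policy evaluation: each dialog is replayed through RuleDST, the policy predicts a system dialog act for every system turn, and precision/recall/F1 are printed for both the original and the delexicalized acts.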
def evaluate_corpus_f1(policy, data, goal_type=None):
    dst = RuleDST()
    da_predict_golden = []
    delex_da_predict_golden = []
    for task_id, sess in data.items():
        if goal_type and sess['type'] != goal_type:
            continue
        dst.init_session()
        for i, turn in enumerate(sess['messages']):
            if turn['role'] == 'usr':
                dst.update(usr_da=turn['dialog_act'])
                if i + 2 == len(sess['messages']):  # the next system turn is the last one
                    dst.state['terminated'] = True
            else:
                for domain, svs in turn['sys_state'].items():
                    for slot, value in svs.items():
                        if slot != 'selectedResults':
                            dst.state['belief_state'][domain][slot] = value
                golden_da = turn['dialog_act']

                predict_da = policy.predict(deepcopy(dst.state))
                # print(golden_da)
                # print(predict_da)
                # print()
                # if 'Select' in [x[0] for x in sess['messages'][i - 1]['dialog_act']]:
                da_predict_golden.append({
                    'predict': predict_da,
                    'golden': golden_da
                })
                delex_da_predict_golden.append({
                    'predict': delexicalize_da(predict_da),
                    'golden': delexicalize_da(golden_da)
                })
                # print(delex_da_predict_golden[-1])
                dst.state['system_action'] = golden_da
        # break
    print('origin precision/recall/f1:', calculateF1(da_predict_golden))
    print('delex precision/recall/f1:', calculateF1(delex_da_predict_golden))
Example #4
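End-to-end simulation with the given system policy: a BERTNLU front end and template NLG on both sides, a rule-based user Simulator, and a BiSession running up to 15 turns per dialog. Task success is collected per goal type until every type has 100 sessions x 10 repeats, with per-repeat and average success rates printed along the way.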
def end2end_evaluate_simulation(policy):
    nlu = BERTNLU('all', 'crosswoz_all_context.json', None)
    nlg_usr = TemplateNLG(is_user=True, mode='auto_manual')
    nlg_sys = TemplateNLG(is_user=False, mode='auto_manual')
    # nlg_usr = SCLSTM(is_user=True, use_cuda=False)
    # nlg_sys = SCLSTM(is_user=False, use_cuda=False)
    usr_policy = Simulator()
    usr_agent = PipelineAgent(nlu, None, usr_policy, nlg_usr, name='usr')
    sys_policy = policy
    sys_dst = RuleDST()
    sys_agent = PipelineAgent(nlu, sys_dst, sys_policy, nlg_sys, name='sys')
    sess = BiSession(sys_agent=sys_agent, user_agent=usr_agent)

    task_success = {'All': list(), '单领域': list(), '独立多领域': list(), '独立多领域+交通': list(), '不独立多领域': list(),
                    '不独立多领域+交通': list()}
    simulate_sess_num = 100
    repeat = 10
    random_seed = 2019
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    random_seeds = [random.randint(1, 2**32-1) for _ in range(simulate_sess_num * repeat * 10000)]
    while True:
        sys_response = ''
        random_seed = random_seeds[0]
        random.seed(random_seed)
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)
        random_seeds.pop(0)
        sess.init_session()
        # print(usr_policy.goal_type)
        if len(task_success[usr_policy.goal_type]) == simulate_sess_num*repeat:
            continue
        for i in range(15):
            sys_response, user_response, session_over, reward = sess.next_turn(sys_response)
            # print('user:', user_response)
            # print('sys:', sys_response)
            # print(session_over, reward)
            # print()
            if session_over is True:
                task_success['All'].append(1)
                task_success[usr_policy.goal_type].append(1)
                break
        else:
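            # for/else: the session did not terminate within 15 turns, so count it as a failure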
            task_success['All'].append(0)
            task_success[usr_policy.goal_type].append(0)
        print([len(x) for x in task_success.values()])
        # print(min([len(x) for x in task_success.values()]))
        if len(task_success['All']) % 100 == 0:
            for k, v in task_success.items():
                print(k)
                all_samples = []
                for i in range(repeat):
                    samples = v[i * simulate_sess_num:(i + 1) * simulate_sess_num]
                    all_samples += samples
                    print(sum(samples), len(samples), (sum(samples) / len(samples)) if len(samples) else 0)
                print('avg', (sum(all_samples) / len(all_samples)) if len(all_samples) else 0)
        if min([len(x) for x in task_success.values()]) == simulate_sess_num*repeat:
            break
        # pprint(usr_policy.original_goal)
        # pprint(task_success)
    print('task_success')
    for k, v in task_success.items():
        print(k)
        all_samples = []
        for i in range(repeat):
            samples = v[i * simulate_sess_num:(i + 1) * simulate_sess_num]
            all_samples += samples
            print(sum(samples), len(samples), (sum(samples) / len(samples)) if len(samples) else 0)
        print('avg', (sum(all_samples) / len(all_samples)) if len(all_samples) else 0)
Example #5
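Dialog-act-level simulation: the same success-rate protocol as Example #4, but the agents exchange dialog acts directly (no NLU or NLG), so the initial system response is an empty act list rather than an empty string.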
def da_evaluate_simulation(policy):
    usr_policy = Simulator()
    usr_agent = PipelineAgent(None, None, usr_policy, None, name='usr')
    sys_policy = policy
    sys_dst = RuleDST()
    sys_agent = PipelineAgent(None, sys_dst, sys_policy, None, name='sys')
    sess = BiSession(sys_agent=sys_agent, user_agent=usr_agent)

    task_success = {'All': list(), '单领域': list(), '独立多领域': list(), '独立多领域+交通': list(), '不独立多领域': list(),
                    '不独立多领域+交通': list()}
    simulate_sess_num = 100
    repeat = 10
    random_seed = 2019
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    random_seeds = [random.randint(1, 2**32-1) for _ in range(simulate_sess_num * repeat * 10000)]
    while True:
        sys_response = []
        random_seed = random_seeds[0]
        random.seed(random_seed)
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)
        random_seeds.pop(0)
        sess.init_session()
        # print(usr_policy.goal_type)
        if len(task_success[usr_policy.goal_type]) == simulate_sess_num*repeat:
            continue
        for i in range(15):
            sys_response, user_response, session_over, reward = sess.next_turn(sys_response)
            # print('user:', user_response)
            # print('sys:', sys_response)
            # print(session_over, reward)
            # print()
            if session_over is True:
                # pprint(sys_agent.tracker.state)
                task_success['All'].append(1)
                task_success[usr_policy.goal_type].append(1)
                break
        else:
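            # for/else: no termination within 15 turns, count the session as a failure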
            task_success['All'].append(0)
            task_success[usr_policy.goal_type].append(0)
        print([len(x) for x in task_success.values()])
        # print(min([len(x) for x in task_success.values()]))
        if len(task_success['All']) % 100 == 0:
            for k, v in task_success.items():
                print(k)
                all_samples = []
                for i in range(repeat):
                    samples = v[i * simulate_sess_num:(i + 1) * simulate_sess_num]
                    all_samples += samples
                    print(sum(samples), len(samples), (sum(samples) / len(samples)) if len(samples) else 0)
                print('avg', (sum(all_samples) / len(all_samples)) if len(all_samples) else 0)
        if min([len(x) for x in task_success.values()]) == simulate_sess_num*repeat:
            break
        # pprint(usr_policy.original_goal)
        # pprint(task_success)
    print('task_success')
    for k, v in task_success.items():
        print(k)
        all_samples = []
        for i in range(repeat):
            samples = v[i * simulate_sess_num:(i + 1) * simulate_sess_num]
            all_samples += samples
            print(sum(samples), len(samples), (sum(samples) / len(samples)) if len(samples) else 0)
        print('avg', (sum(all_samples) / len(all_samples)) if len(all_samples) else 0)
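A minimal driver sketch showing how the corpus-based evaluations above might be invoked. It assumes the zip layout from Example #1, that the functions above are importable in the same module, and a trained policy object for the commented-out calls; load_split and the root_dir value are illustrative, not part of the original code.

import json
import os
import zipfile


def load_split(root_dir, part='test'):
    # Each CrossWOZ split is stored as {part}.json inside data/crosswoz/{part}.json.zip.
    path = os.path.join(root_dir, 'data/crosswoz/{}.json.zip'.format(part))
    with zipfile.ZipFile(path, 'r') as archive:
        with archive.open('{}.json'.format(part), 'r') as f:
            return json.load(f)


if __name__ == '__main__':
    test_data = load_split('.', 'test')        # root_dir '.' is an assumption
    test_sys_state(test_data, goal_type=None)  # Example #2: joint/slot DST accuracy
    # With a trained policy in scope, the policy evaluations would be:
    # evaluate_corpus_f1(policy, test_data, goal_type=None)   # Example #3
    # da_evaluate_simulation(policy)                          # Example #5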