import json
import os
import pickle
import random
import zipfile
from copy import deepcopy

import numpy as np
import torch

# Project-level imports. The module paths below are assumptions based on
# ConvLab-2's CrossWOZ components; adjust them to this repo's actual layout.
# delexicalize_da, calculateF1, calculateJointState and calculateSlotState are
# assumed to be provided elsewhere in this package.
from convlab2.dst.rule.crosswoz.dst import RuleDST
from convlab2.policy.rule.crosswoz.rule_simulator import Simulator
from convlab2.nlu.jointBERT.crosswoz.nlu import BERTNLU
from convlab2.nlg.template.crosswoz.nlg import TemplateNLG
from convlab2.dialog_agent import PipelineAgent, BiSession


def _build_data(self, root_dir, processed_dir):
    """Vectorize the CrossWOZ corpus into (state, action) pairs for training."""
    raw_data = {}
    for part in ['train', 'val', 'test']:
        archive = zipfile.ZipFile(
            os.path.join(root_dir, 'data/crosswoz/{}.json.zip'.format(part)), 'r')
        with archive.open('{}.json'.format(part), 'r') as f:
            raw_data[part] = json.load(f)

    self.data = {}
    # Replay each dialogue through the rule DST (also keeps cur_domain updated).
    dst = RuleDST()
    for part in ['train', 'val', 'test']:
        self.data[part] = []
        for key in raw_data[part]:
            sess = raw_data[part][key]['messages']
            dst.init_session()
            for i, turn in enumerate(sess):
                if turn['role'] == 'usr':
                    dst.update(usr_da=turn['dialog_act'])
                    if i + 2 == len(sess):
                        dst.state['terminated'] = True
                else:
                    # Overwrite the belief state with the annotated system
                    # state, skipping the database query results.
                    for domain, svs in turn['sys_state'].items():
                        for slot, value in svs.items():
                            if slot != 'selectedResults':
                                dst.state['belief_state'][domain][slot] = value
                    action = turn['dialog_act']
                    self.data[part].append([
                        self.vector.state_vectorize(deepcopy(dst.state)),
                        self.vector.action_vectorize(action)
                    ])
                    dst.state['system_action'] = turn['dialog_act']

    os.makedirs(processed_dir, exist_ok=True)
    for part in ['train', 'val', 'test']:
        with open(os.path.join(processed_dir, '{}.pkl'.format(part)), 'wb') as f:
            pickle.dump(self.data[part], f)
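
# Companion sketch (not part of the original file): reload the pickles that
# _build_data writes, so a trainer can consume the vectorized (state, action)
# pairs without re-parsing the raw CrossWOZ archives. The function name is
# hypothetical.
def _load_processed_data(processed_dir):
    data = {}
    for part in ['train', 'val', 'test']:
        with open(os.path.join(processed_dir, '{}.pkl'.format(part)), 'rb') as f:
            data[part] = pickle.load(f)
    return data
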
def test_sys_state(data, goal_type):
    """Evaluate the rule DST against the annotated system states."""
    ruleDST = RuleDST()
    state_predict_golden = []
    for task_id, item in data.items():
        if goal_type and item['type'] != goal_type:
            continue
        ruleDST.init_session()
        for i, turn in enumerate(item['messages']):
            if turn['role'] == 'sys':
                usr_da = item['messages'][i - 1]['dialog_act']
                if i > 2:
                    # Reset the tracker to the previous system turn's annotated
                    # state before applying the latest user dialog acts.
                    for domain, svs in item['messages'][i - 2]['sys_state'].items():
                        for slot, value in svs.items():
                            if slot != 'selectedResults':
                                ruleDST.state['belief_state'][domain][slot] = value
                ruleDST.update(usr_da)
                new_state = deepcopy(ruleDST.state['belief_state'])
                golden_state = deepcopy(turn['sys_state_init'])
                for x in golden_state:
                    golden_state[x].pop('selectedResults')
                state_predict_golden.append({
                    'predict': new_state,
                    'golden': golden_state
                })
    print('joint state', calculateJointState(state_predict_golden))
    print('slot state', calculateSlotState(state_predict_golden))
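
# Hedged usage sketch (not part of the original file): load one CrossWOZ split
# the same way _build_data does, then score the rule DST per goal type.
# GOAL_TYPES mirrors the task_success keys used by the simulation functions
# below; the helper name is hypothetical.
GOAL_TYPES = ['单领域', '独立多领域', '独立多领域+交通', '不独立多领域', '不独立多领域+交通']


def _load_split(root_dir, part):
    archive = zipfile.ZipFile(
        os.path.join(root_dir, 'data/crosswoz/{}.json.zip'.format(part)), 'r')
    with archive.open('{}.json'.format(part), 'r') as f:
        return json.load(f)

# Example:
#     test_data = _load_split('.', 'test')
#     for goal_type in [None] + GOAL_TYPES:
#         test_sys_state(test_data, goal_type)
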
def evaluate_corpus_f1(policy, data, goal_type=None):
    """Corpus-level F1 of a policy's dialog acts against the gold annotations."""
    dst = RuleDST()
    da_predict_golden = []
    delex_da_predict_golden = []
    for task_id, sess in data.items():
        if goal_type and sess['type'] != goal_type:
            continue
        dst.init_session()
        for i, turn in enumerate(sess['messages']):
            if turn['role'] == 'usr':
                dst.update(usr_da=turn['dialog_act'])
                if i + 2 == len(sess['messages']):
                    dst.state['terminated'] = True
            else:
                for domain, svs in turn['sys_state'].items():
                    for slot, value in svs.items():
                        if slot != 'selectedResults':
                            dst.state['belief_state'][domain][slot] = value
                golden_da = turn['dialog_act']
                predict_da = policy.predict(deepcopy(dst.state))
                da_predict_golden.append({
                    'predict': predict_da,
                    'golden': golden_da
                })
                delex_da_predict_golden.append({
                    'predict': delexicalize_da(predict_da),
                    'golden': delexicalize_da(golden_da)
                })
                # Condition the next turn on the gold system action, not the
                # predicted one.
                dst.state['system_action'] = golden_da
    print('origin precision/recall/f1:', calculateF1(da_predict_golden))
    print('delex precision/recall/f1:', calculateF1(delex_da_predict_golden))
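
# Hedged usage sketch (not part of the original file): corpus-level F1 per goal
# type, reusing the _load_split helper above. Any policy object exposing
# predict(state) -> dialog acts fits the `policy` parameter; the driver name is
# hypothetical.
def _run_corpus_f1(policy, root_dir='.'):
    test_data = _load_split(root_dir, 'test')
    for goal_type in [None] + GOAL_TYPES:
        print('goal type:', goal_type)
        evaluate_corpus_f1(policy, test_data, goal_type=goal_type)
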
def end2end_evaluate_simulation(policy):
    """Simulated evaluation with the full pipeline (NLU + DST + policy + NLG)."""
    nlu = BERTNLU('all', 'crosswoz_all_context.json', None)
    nlg_usr = TemplateNLG(is_user=True, mode='auto_manual')
    nlg_sys = TemplateNLG(is_user=False, mode='auto_manual')
    # Alternative NLG, kept from the original:
    # nlg_usr = SCLSTM(is_user=True, use_cuda=False)
    # nlg_sys = SCLSTM(is_user=False, use_cuda=False)
    usr_policy = Simulator()
    usr_agent = PipelineAgent(nlu, None, usr_policy, nlg_usr, name='usr')
    sys_policy = policy
    sys_dst = RuleDST()
    sys_agent = PipelineAgent(nlu, sys_dst, sys_policy, nlg_sys, name='sys')
    sess = BiSession(sys_agent=sys_agent, user_agent=usr_agent)

    # Success counters, overall and per CrossWOZ goal type.
    task_success = {'All': list(), '单领域': list(), '独立多领域': list(),
                    '独立多领域+交通': list(), '不独立多领域': list(),
                    '不独立多领域+交通': list()}
    simulate_sess_num = 100
    repeat = 10
    random_seed = 2019
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    # Pre-draw per-session seeds from the fixed master seed for reproducibility.
    random_seeds = [random.randint(1, 2 ** 32 - 1)
                    for _ in range(simulate_sess_num * repeat * 10000)]
    while True:
        sys_response = ''
        # Reseed all RNGs before each simulated session.
        random_seed = random_seeds.pop(0)
        random.seed(random_seed)
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)
        sess.init_session()
        # Skip goal types whose quota is already filled.
        if len(task_success[usr_policy.goal_type]) == simulate_sess_num * repeat:
            continue
        for i in range(15):
            sys_response, user_response, session_over, reward = sess.next_turn(sys_response)
            if session_over is True:
                task_success['All'].append(1)
                task_success[usr_policy.goal_type].append(1)
                break
        else:
            # The dialogue hit the 15-turn cap without finishing: count a failure.
            task_success['All'].append(0)
            task_success[usr_policy.goal_type].append(0)
        print([len(x) for x in task_success.values()])
        if len(task_success['All']) % 100 == 0:
            for k, v in task_success.items():
                print(k)
                all_samples = []
                for i in range(repeat):
                    samples = v[i * simulate_sess_num:(i + 1) * simulate_sess_num]
                    all_samples += samples
                    print(sum(samples), len(samples),
                          (sum(samples) / len(samples)) if len(samples) else 0)
                print('avg', (sum(all_samples) / len(all_samples))
                      if len(all_samples) else 0)
        if min([len(x) for x in task_success.values()]) == simulate_sess_num * repeat:
            break

    print('task_success')
    for k, v in task_success.items():
        print(k)
        all_samples = []
        for i in range(repeat):
            samples = v[i * simulate_sess_num:(i + 1) * simulate_sess_num]
            all_samples += samples
            print(sum(samples), len(samples),
                  (sum(samples) / len(samples)) if len(samples) else 0)
        print('avg', (sum(all_samples) / len(all_samples)) if len(all_samples) else 0)
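
# Hedged refactor sketch (not part of the original file): the simulation loops
# above and below print per-goal-type success rates with the same nested loop.
# This helper reproduces that report exactly; both functions could call it
# instead of inlining the loop. The name is hypothetical.
def _print_success_rates(task_success, simulate_sess_num, repeat):
    for k, v in task_success.items():
        print(k)
        all_samples = []
        for i in range(repeat):
            # One batch of simulate_sess_num sessions per repeat.
            samples = v[i * simulate_sess_num:(i + 1) * simulate_sess_num]
            all_samples += samples
            print(sum(samples), len(samples),
                  (sum(samples) / len(samples)) if len(samples) else 0)
        print('avg', (sum(all_samples) / len(all_samples)) if len(all_samples) else 0)
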
def da_evaluate_simulation(policy):
    """Simulated evaluation at the dialog-act level (no NLU/NLG in the loop)."""
    usr_policy = Simulator()
    usr_agent = PipelineAgent(None, None, usr_policy, None, name='usr')
    sys_policy = policy
    sys_dst = RuleDST()
    sys_agent = PipelineAgent(None, sys_dst, sys_policy, None, name='sys')
    sess = BiSession(sys_agent=sys_agent, user_agent=usr_agent)

    # Success counters, overall and per CrossWOZ goal type.
    task_success = {'All': list(), '单领域': list(), '独立多领域': list(),
                    '独立多领域+交通': list(), '不独立多领域': list(),
                    '不独立多领域+交通': list()}
    simulate_sess_num = 100
    repeat = 10
    random_seed = 2019
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    # Pre-draw per-session seeds from the fixed master seed for reproducibility.
    random_seeds = [random.randint(1, 2 ** 32 - 1)
                    for _ in range(simulate_sess_num * repeat * 10000)]
    while True:
        sys_response = []
        # Reseed all RNGs before each simulated session.
        random_seed = random_seeds.pop(0)
        random.seed(random_seed)
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)
        sess.init_session()
        # Skip goal types whose quota is already filled.
        if len(task_success[usr_policy.goal_type]) == simulate_sess_num * repeat:
            continue
        for i in range(15):
            sys_response, user_response, session_over, reward = sess.next_turn(sys_response)
            if session_over is True:
                task_success['All'].append(1)
                task_success[usr_policy.goal_type].append(1)
                break
        else:
            # The dialogue hit the 15-turn cap without finishing: count a failure.
            task_success['All'].append(0)
            task_success[usr_policy.goal_type].append(0)
        print([len(x) for x in task_success.values()])
        if len(task_success['All']) % 100 == 0:
            for k, v in task_success.items():
                print(k)
                all_samples = []
                for i in range(repeat):
                    samples = v[i * simulate_sess_num:(i + 1) * simulate_sess_num]
                    all_samples += samples
                    print(sum(samples), len(samples),
                          (sum(samples) / len(samples)) if len(samples) else 0)
                print('avg', (sum(all_samples) / len(all_samples))
                      if len(all_samples) else 0)
        if min([len(x) for x in task_success.values()]) == simulate_sess_num * repeat:
            break

    print('task_success')
    for k, v in task_success.items():
        print(k)
        all_samples = []
        for i in range(repeat):
            samples = v[i * simulate_sess_num:(i + 1) * simulate_sess_num]
            all_samples += samples
            print(sum(samples), len(samples),
                  (sum(samples) / len(samples)) if len(samples) else 0)
        print('avg', (sum(all_samples) / len(all_samples)) if len(all_samples) else 0)
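
if __name__ == '__main__':
    # Hedged driver sketch (not part of the original file). The DST evaluation
    # below runs as-is on the corpus; the commented-out lines need a trained
    # policy exposing predict(state), and the MLE import path is an assumption
    # based on ConvLab-2's CrossWOZ layout.
    # from convlab2.policy.mle.crosswoz.mle import MLE
    # policy = MLE()
    # _run_corpus_f1(policy)
    # da_evaluate_simulation(policy)
    # end2end_evaluate_simulation(policy)
    test_data = _load_split('.', 'test')
    for goal_type in [None] + GOAL_TYPES:
        print('goal type:', goal_type)
        test_sys_state(test_data, goal_type)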