def eval_state_predict(data):
    """Replay each non-initial user turn through the simulator and print a report.

    For every message with role ``'usr'`` that is preceded by a user turn and
    a system turn, a simulator session is initialised from a deep copy of the
    previous user turn's ``user_state``, updated with the previous user and
    system dialog acts, and asked to predict the next user dialog act. The
    three utterances involved, both previous dialog acts, the predicted vs.
    golden state change and the predicted vs. golden dialog act are printed
    for manual inspection.

    Args:
        data: Mapping of task id to a dialogue item. Each item is read for
            ``item['goal']`` and ``item['messages']``; each message for
            ``'role'``, ``'content'`` and ``'dialog_act'``, and user messages
            additionally for ``'user_state'``.

    Returns:
        None. Everything is written to stdout.
    """

    def state_update(prev_state, cur_state):
        """Return ``(changed, last_id)`` describing how ``cur_state`` differs.

        ``changed`` lists the entries of ``cur_state`` that differ from the
        entry at the same position in ``prev_state`` (compared pairwise, up
        to the shorter of the two). ``last_id`` is the first field of the
        last entry of ``cur_state`` whose final field is truthy, or ``1``
        when no entry qualifies.
        """
        changed = [cur for prev, cur in zip(prev_state, cur_state) if cur != prev]
        last_id = next(
            (entry[0] for entry in reversed(cur_state) if entry[-1]),
            1,
        )
        return changed, last_id

    simulator = Simulator()
    for item in data.values():
        messages = item['messages']
        for i, turn in enumerate(messages):
            # A replay needs the previous user turn (i - 2) and the system
            # reply (i - 1). Requiring i >= 2 keeps a user turn at index 1
            # from silently wrapping around to messages[-1].
            if turn['role'] != 'usr' or i < 2:
                continue

            last_turn = messages[i - 2]
            sys_turn = messages[i - 1]
            usr_da = last_turn['dialog_act']
            sys_da = sys_turn['dialog_act']

            # Deep-copy the stored state so the simulator cannot mutate the data.
            simulator.init_session(goal=item['goal'], state=deepcopy(last_turn['user_state']))
            simulator.state_update(prev_user_da=usr_da, prev_sys_da=sys_da)
            cur_da = simulator.state_predict()
            new_state = simulator.state

            print(last_turn['content'])
            print(sys_turn['content'])
            print(turn['content'])
            print('usr da')
            pprint(usr_da)
            print('sys da')
            pprint(sys_da)
            print('predict state update:')
            pprint(state_update(last_turn['user_state'], new_state))
            print('golden state:')
            pprint(state_update(last_turn['user_state'], turn['user_state']))
            print('predict usr da')
            pprint(cur_da)
            print('golden usr da')
            pprint(turn['dialog_act'])
            print('-' * 100)
def eval_simulator_performance(data, goal_type=None):
    """Collect simulator predictions over the corpus and print summary scores.

    Two kinds of samples are gathered from the user turns of every dialogue:

    * the opening user turn (index 0), where a session is initialised from
      the goal alone and ``simulator.begin_da()`` is compared with the
      turn's golden dialog act;
    * every later user turn, where a session is initialised from a deep copy
      of the previous user turn's ``user_state``, updated with the previous
      user and system dialog acts, and both the predicted dialog act and the
      resulting simulator state are compared with the golden ones.

    The collected predict/golden pairs are passed to ``calculateF1``,
    ``calculateJointState`` and ``calculateSlotState`` (defined outside this
    function) and the results are printed.

    Args:
        data: Mapping of task id to a dialogue item. Each item is read for
            ``item['goal']``, ``item['messages']`` and, when ``goal_type`` is
            given, ``item['type']``.
        goal_type: When truthy, only items whose ``item['type']`` equals this
            value are evaluated; otherwise every item is used.

    Returns:
        None. The scores are written to stdout.
    """
    begin_da_predict_golden = []
    state_da_predict_golden = []
    state_predict_golden = []

    simulator = Simulator()
    for item in data.values():
        if goal_type and item['type'] != goal_type:
            continue

        messages = item['messages']
        for i, turn in enumerate(messages):
            if turn['role'] != 'usr':
                continue

            if i == 0:
                # Opening turn: nothing to condition on except the goal.
                simulator.init_session(goal=item['goal'])
                begin_da_predict_golden.append({
                    'predict': simulator.begin_da(),
                    'golden': turn['dialog_act'],
                })
                continue

            # Later turns need messages[i - 2] and messages[i - 1]. A user
            # turn at index 1 would make i - 2 wrap around to the last
            # message, so it is skipped instead of scored on wrong context.
            if i < 2:
                continue

            last_turn = messages[i - 2]
            usr_da = last_turn['dialog_act']
            sys_da = messages[i - 1]['dialog_act']

            # Deep-copy the stored state so the simulator cannot mutate the data.
            simulator.init_session(goal=item['goal'], state=deepcopy(last_turn['user_state']))
            simulator.state_update(prev_user_da=usr_da, prev_sys_da=sys_da)
            cur_da = simulator.state_predict()
            # Snapshot the state: the simulator is re-initialised on the next turn.
            new_state = deepcopy(simulator.state)

            state_da_predict_golden.append({
                'predict': cur_da,
                'golden': turn['dialog_act'],
            })
            state_predict_golden.append({
                'predict': new_state,
                'golden': turn['user_state'],
            })

    print('begin da', calculateF1(begin_da_predict_golden))
    print('state da', calculateF1(state_da_predict_golden))
    print('all da', calculateF1(begin_da_predict_golden + state_da_predict_golden))
    print('joint state', calculateJointState(state_predict_golden))
    print('slot state', calculateSlotState(state_predict_golden))