def action_vectorize(self, da):
    """Encode a system dialog act as a multi-hot vector.

    The act is passed through ``delexicalize_da``; each resulting item that
    has an entry in ``self.sys_da2id`` switches on the matching position.
    Items without an entry are silently ignored.

    Args:
        da: Dialog act in whatever form ``delexicalize_da`` accepts.

    Returns:
        NumPy array of length ``self.sys_da_dim`` filled with 0./1. flags.
    """
    delex_items = delexicalize_da(da)
    act_vec = np.zeros(self.sys_da_dim)
    for item in delex_items:
        if item not in self.sys_da2id:
            continue
        act_vec[self.sys_da2id[item]] = 1.
    return act_vec
def state_vectorize(self, state):
    """Flatten a dialog state dict into one feature vector.

    The result concatenates, in this order:
      1. multi-hot encoding of the delexicalized ``state['user_action']``
         (length ``self.usr_da_dim``),
      2. multi-hot encoding of the delexicalized ``state['system_action']``
         (length ``self.sys_da_dim``),
      3. one flag per belief-state slot, 1. when the slot value is truthy
         (length ``self.belief_state_dim``),
      4. a 4-way one-hot bucket of the database match count
         (0, 1, 2-4, 5 or more),
      5. a single 1./0. flag mirroring ``state['terminated']``.

    Side effects: stores ``state['belief_state']``, ``state['cur_domain']``
    and the database query result on ``self`` as ``belief_state``,
    ``cur_domain`` and ``db_res``.

    Args:
        state: Mapping with the keys ``belief_state``, ``cur_domain``,
            ``user_action``, ``system_action`` and ``terminated``.

    Returns:
        1-D NumPy array built with ``np.r_`` from the five parts above.
    """
    self.belief_state = state['belief_state']
    self.cur_domain = state['cur_domain']

    # Last user act -> multi-hot over the known user act vocabulary.
    user_items = delexicalize_da(state['user_action'])
    user_part = np.zeros(self.usr_da_dim)
    for item in user_items:
        if item not in self.usr_da2id:
            continue
        user_part[self.usr_da2id[item]] = 1.

    # Last system act -> multi-hot over the known system act vocabulary.
    system_items = delexicalize_da(state['system_action'])
    system_part = np.zeros(self.sys_da_dim)
    for item in system_items:
        if item not in self.sys_da2id:
            continue
        system_part[self.sys_da2id[item]] = 1.

    # Walk every slot of every domain in iteration order; the running
    # position of the slot is its index in the belief vector.
    belief_part = np.zeros(self.belief_state_dim)
    slot_values = (
        value
        for slots in state['belief_state'].values()
        for value in slots.values()
    )
    for position, value in enumerate(slot_values):
        if value:
            belief_part[position] = 1.

    # Bucket the number of database matches: 0 / 1 / 2-4 / 5 or more.
    self.db_res = self.database.query(state['belief_state'], state['cur_domain'])
    match_count = len(self.db_res)
    if match_count == 0:
        bucket = 0
    elif match_count == 1:
        bucket = 1
    elif match_count < 5:
        bucket = 2
    else:
        bucket = 3
    db_part = np.zeros(4)
    db_part[bucket] = 1.

    done_flag = 1. if state['terminated'] else 0.

    return np.r_[user_part, system_part, belief_part, db_part, done_flag]
def gen_da_voc(data):
    """Collect the delexicalized dialog-act vocabularies of a corpus.

    Every turn's ``dialog_act`` is delexicalized with ``delexicalize_da``
    and the resulting items are added to the user vocabulary when the
    turn's ``role`` is ``'usr'``, otherwise to the system vocabulary.

    Args:
        data: Mapping of task id to session; each session exposes a
            ``'messages'`` list of turns with ``'role'`` and
            ``'dialog_act'`` entries.

    Returns:
        Tuple ``(usr_vocab, sys_vocab)`` of sorted, de-duplicated lists.
    """
    user_acts, system_acts = set(), set()
    for session in data.values():
        for turn in session['messages']:
            bucket = user_acts if turn['role'] == 'usr' else system_acts
            bucket.update(delexicalize_da(turn['dialog_act']))
    return sorted(user_acts), sorted(system_acts)
def evaluate_corpus_f1(policy, data, goal_type=None):
    """Replay a corpus through a rule-based DST and print policy F1 scores.

    For every session the tracker is reset and the messages are replayed in
    order. User turns update the tracker with the annotated dialog act.
    On system turns the annotated ``sys_state`` is copied into the tracker's
    belief state (except ``selectedResults`` slots), the policy is asked to
    predict an act from a deep copy of the tracker state, and the prediction
    is paired with the annotated (golden) act. The golden act is then fed
    back as the tracker's ``system_action`` so later turns are conditioned
    on the corpus, not on the policy's own output.

    Two scores are printed via ``calculateF1``: one on the raw acts and one
    on the acts after ``delexicalize_da``.

    Args:
        policy: Object exposing ``predict(state)`` that returns a dialog act.
        data: Mapping of task id to session; each session has ``'type'`` and
            ``'messages'`` entries.
        goal_type: When truthy, only sessions whose ``'type'`` equals this
            value are evaluated.

    Returns:
        None. Results are printed to stdout.
    """
    dst = RuleDST()
    da_predict_golden = []
    delex_da_predict_golden = []

    for sess in data.values():
        if goal_type and sess['type'] != goal_type:
            continue

        dst.init_session()
        messages = sess['messages']
        for i, turn in enumerate(messages):
            if turn['role'] == 'usr':
                dst.update(usr_da=turn['dialog_act'])
                # Flag the second-to-last message (presumably the final user
                # turn, if dialogues end with a system reply) as terminal.
                # The session is a dict, so the comparison must use the
                # message list's length: the dict's own length counts its
                # keys, not its turns.
                if i + 2 == len(messages):
                    dst.state['terminated'] = True
                continue

            # System turn: sync the tracker with the annotated belief state.
            for domain, svs in turn['sys_state'].items():
                for slot, value in svs.items():
                    if slot != 'selectedResults':
                        dst.state['belief_state'][domain][slot] = value

            golden_da = turn['dialog_act']
            # Deep copy so the policy cannot mutate the tracker's state.
            predict_da = policy.predict(deepcopy(dst.state))

            da_predict_golden.append({
                'predict': predict_da,
                'golden': golden_da,
            })
            delex_da_predict_golden.append({
                'predict': delexicalize_da(predict_da),
                'golden': delexicalize_da(golden_da),
            })

            # Teacher forcing: continue from the corpus act, not the prediction.
            dst.state['system_action'] = golden_da

    print('origin precision/recall/f1:', calculateF1(da_predict_golden))
    print('delex precision/recall/f1:', calculateF1(delex_da_predict_golden))