def get_air_score(self, encoder_states, expected_action, kb):
    """Decode the predicted action from encoder states and score it.

    Extracts intent, name, and ticket/flight predictions from the
    encoder outputs, builds a predicted action dict, and attaches the
    reward components from comparing it against ``expected_action``.

    Args:
        encoder_states: sequence of model outputs; index 2 holds ticket
            attention logits, 3 intent logits, 4 per-position name logits.
        expected_action: the ground-truth action object.
        kb: knowledge-base dict; ``kb['reservation'] > 0`` signals an
            existing reservation.

    Returns:
        dict with 'status', 'flight', 'name', 'intent' predictions plus
        'reward', 'name_score', 'flight_score', 'status_score'.
    """
    # Intent: argmax over the intent logits.
    intent_pred = encoder_states[3].max(-1)[1].item()

    # Name: per-position argmax, truncated at the first 0 (stop id).
    name_ids = encoder_states[4].max(-1)[1]
    cutoff = 0
    while cutoff < len(name_ids) and name_ids[cutoff] != 0:
        cutoff += 1
    decoded_name = self._v2t(name_ids[:cutoff])
    # Capitalize each whitespace-separated token of the decoded name.
    name_pred = ' '.join(tok.capitalize() for tok in decoded_name.split(' '))

    # Ticket / flight: argmax over the ticket attention, then map
    # (ticket, has-reservation, intent) to a status/flight pair.
    ticket_pred = encoder_states[2].max(-1)[1].item()
    status_pred, flight_pred = intent_to_status(
        ticket_pred, kb['reservation'] > 0, intent_pred)
    # Flight id 0 means "no flight"; otherwise shift into the 1000-based
    # flight-number space.
    flights = [] if flight_pred == 0 else [flight_pred - 1 + 1000]

    pred = {
        'status': STATUS_DICT[status_pred],
        'flight': flights,
        'name': name_pred,
    }
    reward, name_score, flight_score, status_score = compute_reward(
        action_obj_to_str(pred),
        action_obj_to_str(expected_action),
        tokenize_kb(kb),
    )
    pred['reward'] = reward
    pred['name_score'] = name_score
    pred['flight_score'] = flight_score
    pred['status_score'] = status_score
    pred['intent'] = INTENT_DICT[intent_pred]
    return pred
def score_human_data(flags):
    """Score human reference dialogues against their expected actions.

    Reads dialogue and KB files line by line in lockstep. Samples marked
    ``correct_sample`` get a perfect score; the rest are scored with
    ``compute_reward``.

    Args:
        flags: object with ``true_data`` and ``true_kb`` path attributes.

    Returns:
        dict with the mean overall reward under key 'score'.
    """
    assert flags.true_data and flags.true_kb
    scores = []
    expanded_kb = expanduser(flags.true_kb)
    expanded_data = expanduser(flags.true_data)
    # Open both files in one context manager so the KB handle is closed
    # on exit (the original leaked the gfile.Open handle).
    with gfile.Open(expanded_data) as f, gfile.Open(expanded_kb) as f2:
        for line in tqdm(f):
            a = json.loads(line)
            kb_line = f2.readline()
            if not a['correct_sample']:
                pred_action = action_obj_to_str(a['action'])
                true_action = action_obj_to_str(a['expected_action'])
                kb = tokenize_kb(json.loads(kb_line))
                ss = compute_reward(pred_action, true_action, kb)
                scores.append(ss)
            else:
                # Perfect score on all four components
                # (reward, name, flight, status).
                scores.append([1, 1, 1, 1])
    sn = np.array(scores)
    # Column 0 is the overall reward component.
    score = np.mean(sn[:, 0])
    print('final score', score)
    return {'score': score}
def score_selfplay(flags):
    """Score self-play output against reference data: action reward + BLEU.

    Iterates prediction, reference, and KB files in lockstep; dialogues
    missing an 'action' are scored as an all-<unk> action.

    Args:
        flags: object with ``pred_data``, ``true_data`` and ``true_kb``
            path attributes.

    Returns:
        dict with mean 'score' and mean 'bleu' (percent).
    """
    assert flags.true_data and flags.true_kb and flags.pred_data  # check output
    all_scores = []
    bleu_scores = []
    with tf.gfile.GFile(flags.pred_data) as pred_f, \
            tf.gfile.GFile(flags.true_data) as true_f, \
            tf.gfile.GFile(flags.true_kb) as kb_f:
        # Iterate lazily: the original wrapped zip(...) in list(), loading
        # all three files into memory, and rebound the file-handle name
        # `kb` inside the loop (shadowing).
        for pred_line, true_line, kb_line in tqdm(zip(pred_f, true_f, kb_f)):
            pred_json_obj = json.loads(pred_line)
            true_json_obj = json.loads(true_line)
            kb = tokenize_kb(json.loads(kb_line))
            if 'action' in pred_json_obj:
                pred_action = action_obj_to_str(pred_json_obj['action'])
            else:
                # No predicted action: fall back to a dummy all-<unk> one.
                pred_action = '<unk> <unk> <unk> <unk>'.split(' ')
            true_action = action_obj_to_str(true_json_obj['expected_action'])
            score = compute_reward(pred_action, true_action, kb)
            all_scores.append(score)
            pred_raw_text = json_obj_to_tokens(pred_json_obj)
            true_raw_text = json_obj_to_tokens(true_json_obj)
            b = compute_bleu([[true_raw_text]], [pred_raw_text])
            bleu_scores.append(b[0] * 100)
    # NOTE(review): np.mean over the list of score tuples averages ALL
    # reward components together (unlike score_human_data, which takes
    # column 0 only) — presumably intentional; confirm with the authors.
    avg_score = np.mean(all_scores)
    avg_bleu = np.mean(bleu_scores)
    print('score=', avg_score)
    print('bleu=', avg_bleu)
    return {'score': avg_score, 'bleu': avg_bleu}
def main():
    """Re-score reference dialogues and restructure each into aligned
    (customer, agent, generated-agent) response lists, saved as
    ``data.json`` in the output dir.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--ref_file', required=True, type=str,
                        help='path to reference file')
    parser.add_argument('--ref_kb', required=True, type=str,
                        help='path to reference kb file')
    parser.add_argument('--output_dir', required=True, type=str,
                        help='path to output dir')
    args = parser.parse_args()
    print(args)

    ref_data = []
    ref_score = 0
    with open(args.ref_file, 'r') as ref_file, \
            open(args.ref_kb, 'r') as kb_file:
        for l, kb_line in zip(ref_file, kb_file):
            _json = json.loads(l)
            kb = tokenize_kb(json.loads(kb_line))
            action = action_obj_to_str(_json['action'])
            expected_action = action_obj_to_str(_json['expected_action'])
            if 'reward' not in _json:
                # compute_reward returns (reward, name, flight, status);
                # only the overall reward is stored.
                score = compute_reward(action, expected_action, kb)
                _json['reward'] = score[0]
            ref_score += _json['reward']
            ref_data.append(_json)
    # Fail with a clear message instead of ZeroDivisionError below.
    if not ref_data:
        raise ValueError('reference file %s is empty' % args.ref_file)
    print('# of reference dialogue: ', len(ref_data))
    print('avg ref reward: ', ref_score / len(ref_data))

    result_data = []
    for s in ref_data:
        dia = s['dialogue']
        new_s = copy.deepcopy(s)
        customer_lines = len([l for l in dia if l.startswith('customer: ')])
        agent_lines = len([l for l in dia if l.startswith('agent: ')])
        gen_lines = len([l for l in dia if l.startswith('agent_tgt: ')])
        # Every turn must belong to one of the three known speakers.
        assert agent_lines + customer_lines + gen_lines == len(dia)
        # Normalize so turns cycle customer -> agent -> agent_tgt:
        # pad a leading empty customer turn / drop a trailing customer turn.
        if not dia[0].startswith('customer: '):
            new_s['dialogue'] = ['customer: '] + new_s['dialogue']
        if dia[-1].startswith('customer: '):
            new_s['dialogue'] = new_s['dialogue'][:-1]
        assert len(new_s['dialogue']) % 3 == 0
        new_s['ref_customer_response'] = [
            l.replace('customer:', '').strip()
            for l in new_s['dialogue'][::3]
        ]
        new_s['ref_agent_response'] = [
            l.replace('agent:', '').strip()
            for l in new_s['dialogue'][1::3]
        ]
        new_s['gen_agent_response'] = [
            l.replace('agent_tgt:', '').strip()
            for l in new_s['dialogue'][2::3]
        ]
        result_data.append(new_s)

    os.makedirs(args.output_dir, exist_ok=True)
    print('saving to :', args.output_dir)
    with open(os.path.join(args.output_dir, 'data.json'), 'w') as save_file:
        # One JSON object per line (the unused enumerate index is gone).
        for r in result_data:
            save_file.write(json.dumps(r))
            save_file.write('\n')