def manual_eval(prediction_path, dataset, FLAGS, top_k, num_examples=-1,
                interactive=True, verbose=True):
    """
    Conduct dev/test set evaluation.

    Evaluation metrics:
        1) full command accuracy;
        2) command template accuracy.

    :param interactive:
        - If set, prompt the user to enter a judgement if a prediction does
          not match any of the ground truths and the correctness of the
          prediction has not been pre-determined;
          otherwise, all predictions that do not match any of the ground
          truths are counted as wrong.
    """
    # Group dataset
    grouped_dataset = data_utils.group_parallel_data(dataset)
    # Load model predictions
    prediction_list = load_predictions(prediction_path, top_k)

    metrics = get_manual_evaluation_metrics(
        grouped_dataset, prediction_list, FLAGS, num_examples=num_examples,
        interactive=interactive, verbose=verbose)

    return metrics
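# A minimal usage sketch for manual_eval above (the dataset variable and the
# prediction file path shown here are illustrative assumptions, not part of
# this module):
#
#     dataset = ...  # a dev/test split in the format expected by
#                    # data_utils.group_parallel_data
#     metrics = manual_eval('model_dir/predictions.beam.dev', dataset, FLAGS,
#                           top_k=3, num_examples=100, interactive=True)
#     print(metrics['acc_f'][0], metrics['acc_t'][0])  # top-1 accuracies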
def automatic_eval(prediction_path, dataset, FLAGS, top_k, num_samples=-1,
                   verbose=False):
    """
    Generate automatic evaluation metrics on the dev/test set.

    The following metrics are computed at top 1, 3, 5, 10:
        1. Structure accuracy
        2. Full command accuracy
        3. Command keyword overlap
        4. BLEU
    """
    grouped_dataset = data_utils.group_parallel_data(dataset)
    try:
        vocabs = data_utils.load_vocabulary(FLAGS)
    except ValueError:
        vocabs = None

    # Load predictions
    prediction_list = load_predictions(prediction_path, top_k)
    if len(grouped_dataset) != len(prediction_list):
        raise ValueError("ground truth and predictions length must be equal: "
                         "{} vs. {}".format(len(grouped_dataset),
                                            len(prediction_list)))

    metrics = get_automatic_evaluation_metrics(
        grouped_dataset, prediction_list, vocabs, FLAGS, top_k, num_samples,
        verbose)
    return metrics
def automatic_eval(model_dir, decode_sig, dataset, FLAGS, top_k,
                   num_samples=-1, verbose=False):
    """
    Generate automatic evaluation metrics on a dev/test set.

    The following metrics are computed at top 1, 3, 5, 10:
        1. Structure accuracy
        2. Full command accuracy
        3. Command keyword overlap
        4. BLEU
    """
    use_bucket = False if "knn" in model_dir else True
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=use_bucket)
    vocabs = data_utils.load_vocabulary(FLAGS)

    # Load predictions
    prediction_list = load_predictions(model_dir, decode_sig, top_k)
    if len(grouped_dataset) != len(prediction_list):
        raise ValueError("ground truth and predictions length must be equal: "
                         "{} vs. {}".format(len(grouped_dataset),
                                            len(prediction_list)))

    M = get_automatic_evaluation_metrics(
        grouped_dataset, prediction_list, vocabs, FLAGS, top_k, num_samples,
        verbose)
    return M
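# The two automatic_eval variants differ only in how predictions are located:
# the first takes an explicit prediction file path, the second a model
# directory plus a decode signature. A hedged usage sketch of the second
# variant (the model subdirectory and decode signature strings are
# illustrative assumptions):
#
#     model_dir = os.path.join(FLAGS.model_root_dir, 'token-seq2seq')
#     M = automatic_eval(model_dir, 'beam.dev', dataset, FLAGS, top_k=3)
#     print(M['bleu'][0], M['cms'][0])  # top-1 BLEU and template match score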
def gen_error_analysis_csv_by_utility(model_dir, decode_sig, dataset, FLAGS,
                                      top_k=10):
    """
    Generate error analysis evaluation sheet grouped by utility.
    """
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)
    # Load model predictions
    prediction_list = load_predictions(model_dir, decode_sig, top_k)
    if len(grouped_dataset) != len(prediction_list):
        raise ValueError(
            "ground truth and predictions length must be equal: {} vs. {}".format(
                len(grouped_dataset), len(prediction_list)))

    # Load cached evaluation results
    cached_evaluation_results = load_cached_evaluations(model_dir)

    # Convert the predictions into csv format
    grammar_errors, argument_errors = print_error_analysis_csv(
        grouped_dataset, prediction_list, FLAGS, cached_evaluation_results,
        group_by_utility=True, error_predictions_only=False)

    error_by_utility_path = os.path.join(
        model_dir, 'error.analysis.by.utility.csv')
    print("Saving grammar errors to {}".format(error_by_utility_path))
    with open(error_by_utility_path, 'w') as error_by_utility_file:
        # print csv file header
        error_by_utility_file.write(
            'utility, example_id, description, groundtruth, prediction, '
            'correct template, correct command\n')
        for line in bash.utility_stats.split('\n'):
            utility = line.split(',')[1]
            error_examples = grammar_errors[utility]
            if len(error_examples) <= 5:
                for example in error_examples:
                    for l in example:
                        error_by_utility_file.write(
                            '{},{}\n'.format(utility, l))
            else:
                random.shuffle(error_examples)
                for example in error_examples[:5]:
                    for l in example:
                        error_by_utility_file.write(
                            '{},{}\n'.format(utility, l))
def gen_error_analysis_csv(model_dir, decode_sig, dataset, FLAGS, top_k=3):
    """
    Generate error analysis evaluation spreadsheets:
        - grammar error analysis
        - argument error analysis
    """
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)
    # Load model predictions
    prediction_list = load_predictions(model_dir, decode_sig, top_k)
    if len(grouped_dataset) != len(prediction_list):
        raise ValueError("ground truth and predictions length must be equal: "
                         "{} vs. {}".format(len(grouped_dataset),
                                            len(prediction_list)))

    # Convert the predictions to csv format
    grammar_errors, argument_errors = print_error_analysis_csv(
        grouped_dataset, prediction_list, FLAGS)

    grammar_error_path = os.path.join(model_dir, 'grammar.error.analysis.csv')
    random.shuffle(grammar_errors)
    with open(grammar_error_path, 'w') as grammar_error_file:
        print("Saving grammar errors to {}".format(grammar_error_path))
        # print csv file header
        grammar_error_file.write(
            'example_id, description, ground_truth, prediction, '
            'correct template, correct command\n')
        for example in grammar_errors[:100]:
            for line in example:
                grammar_error_file.write('{}\n'.format(line))

    arg_error_path = os.path.join(model_dir, 'argument.error.analysis.csv')
    random.shuffle(argument_errors)
    with open(arg_error_path, 'w') as arg_error_file:
        print("Saving argument errors to {}".format(arg_error_path))
        # print csv file header
        arg_error_file.write(
            'example_id, description, ground_truth, prediction, '
            'correct template, correct command\n')
        for example in argument_errors[:100]:
            for line in example:
                arg_error_file.write('{}\n'.format(line))
def gen_automatic_evaluation_table(dataset, FLAGS):
    # Group dataset
    grouped_dataset = data_utils.group_parallel_data(dataset, use_bucket=True)
    vocabs = data_utils.load_vocabulary(FLAGS)

    model_names, model_predictions = load_all_model_predictions(
        grouped_dataset, FLAGS, top_k=3)
    auto_evaluation_metrics = {}
    for model_id, model_name in enumerate(model_names):
        prediction_list = model_predictions[model_id]
        M = get_automatic_evaluation_metrics(
            grouped_dataset, prediction_list, vocabs, FLAGS, top_k=3)
        auto_evaluation_metrics[model_name] = [
            M['top_bleu'][0], M['top_bleu'][1], M['top_cms'][0], M['top_cms'][1]
        ]

    metrics_names = ['BLEU1', 'BLEU3', 'TM1', 'TM3']
    print_table(model_names, metrics_names, auto_evaluation_metrics)
def gen_manual_evaluation_table(dataset, FLAGS, num_examples=-1,
                                interactive=True):
    """
    Conduct dev/test set evaluation. The results of multiple pre-specified
    models are tabulated in the same table.

    Evaluation metrics:
        1) full command accuracy;
        2) command template accuracy.

    :param interactive:
        - If set, prompt the user to enter a judgement if a prediction does
          not match any of the ground truths and the correctness of the
          prediction has not been pre-determined;
          otherwise, all predictions that do not match any of the ground
          truths are counted as wrong.
    """
    # Group dataset
    grouped_dataset = data_utils.group_parallel_data(dataset)
    # Load all model predictions
    model_names, model_predictions = load_all_model_predictions(
        grouped_dataset, FLAGS, top_k=3)

    manual_eval_metrics = {}
    for model_id, model_name in enumerate(model_names):
        # Index predictions by the model's position in the list
        prediction_list = model_predictions[model_id]
        M = get_manual_evaluation_metrics(
            grouped_dataset, prediction_list, FLAGS,
            num_examples=num_examples, interactive=interactive, verbose=False)
        manual_eval_metrics[model_name] = [
            M['acc_f'][0], M['acc_f'][1], M['acc_t'][0], M['acc_t'][1]
        ]

    metrics_names = ['Acc_F_1', 'Acc_F_3', 'Acc_T_1', 'Acc_T_3']
    print_eval_table(model_names, metrics_names, manual_eval_metrics)
def gen_automatic_evaluation_table(dataset, FLAGS):
    # Group dataset
    grouped_dataset = data_utils.group_parallel_data(dataset)
    vocabs = data_utils.load_vocabulary(FLAGS)

    model_names, model_predictions = load_all_model_predictions(
        grouped_dataset, FLAGS, top_k=3)
    auto_eval_metrics = {}
    for model_id, model_name in enumerate(model_names):
        prediction_list = model_predictions[model_id]
        if prediction_list is not None:
            M = get_automatic_evaluation_metrics(
                grouped_dataset, prediction_list, vocabs, FLAGS, top_k=3)
            auto_eval_metrics[model_name] = [
                M['bleu'][0], M['bleu'][1], M['cms'][0], M['cms'][1]
            ]
        else:
            print('Model {} skipped in evaluation'.format(model_name))

    metrics_names = ['BLEU1', 'BLEU3', 'TM1', 'TM3']
    print_eval_table(model_names, metrics_names, auto_eval_metrics)
def decode_set(sess, model, dataset, top_k, FLAGS, verbose=False):
    """
    Compute top-k predictions on the dev/test dataset and write the
    predictions to disk.

    :param sess: A TensorFlow session.
    :param model: Prediction model object.
    :param top_k: Number of top predictions to compute.
    :param FLAGS: Training/testing hyperparameter settings.
    :param verbose: If set, also print decoding results to screen.
    """
    nl2bash = FLAGS.dataset.startswith('bash') and not FLAGS.explain

    tokenizer_selector = 'cm' if FLAGS.explain else 'nl'
    grouped_dataset = data_utils.group_parallel_data(
        dataset, tokenizer_selector=tokenizer_selector)
    vocabs = data_utils.load_vocabulary(FLAGS)
    rev_sc_vocab = vocabs.rev_sc_vocab

    ts = datetime.datetime.fromtimestamp(
        time.time()).strftime('%Y-%m-%d-%H%M%S')
    pred_file_path = os.path.join(model.model_dir, 'predictions.{}.{}'.format(
        model.decode_sig, ts))
    pred_file = open(pred_file_path, 'w')
    eval_file_path = os.path.join(
        model.model_dir, 'predictions.{}.{}.csv'.format(model.decode_sig, ts))
    eval_file = open(eval_file_path, 'w')
    eval_file.write('example_id, description, ground_truth, prediction, '
                    'correct template, correct command\n')

    for example_id in xrange(len(grouped_dataset)):
        key, data_group = grouped_dataset[example_id]

        sc_txt = data_group[0].sc_txt.strip()
        sc_tokens = [rev_sc_vocab[i] for i in data_group[0].sc_ids]
        if FLAGS.channel == 'char':
            sc_temp = ''.join(sc_tokens)
            sc_temp = sc_temp.replace(constants._SPACE, ' ')
        else:
            sc_temp = ' '.join(sc_tokens)
        tg_txts = [dp.tg_txt for dp in data_group]
        tg_asts = [data_tools.bash_parser(tg_txt) for tg_txt in tg_txts]
        if verbose:
            print('\nExample {}:'.format(example_id))
            print('Original Source: {}'.format(sc_txt.encode('utf-8')))
            print('Source: {}'.format(sc_temp.encode('utf-8')))
            for j in xrange(len(data_group)):
                print('GT Target {}: {}'.format(
                    j + 1, data_group[j].tg_txt.encode('utf-8')))

        if FLAGS.fill_argument_slots:
            slot_filling_classifier = get_slot_filling_classifer(FLAGS)
            batch_outputs, sequence_logits = translate_fun(
                data_group, sess, model, vocabs, FLAGS,
                slot_filling_classifier=slot_filling_classifier)
        else:
            batch_outputs, sequence_logits = translate_fun(
                data_group, sess, model, vocabs, FLAGS)
        if FLAGS.tg_char:
            batch_outputs, batch_char_outputs = batch_outputs

        eval_row = '{},"{}",'.format(example_id, sc_txt.replace('"', '""'))
        if batch_outputs:
            if FLAGS.token_decoding_algorithm == 'greedy':
                tree, pred_cmd = batch_outputs[0]
                if nl2bash:
                    pred_cmd = data_tools.ast2command(
                        tree, loose_constraints=True)
                score = sequence_logits[0]
                if verbose:
                    print('Prediction: {} ({})'.format(pred_cmd, score))
                pred_file.write('{}\n'.format(pred_cmd))
            elif FLAGS.token_decoding_algorithm == 'beam_search':
                top_k_predictions = batch_outputs[0]
                if FLAGS.tg_char:
                    top_k_char_predictions = batch_char_outputs[0]
                top_k_scores = sequence_logits[0]
                num_preds = min(FLAGS.beam_size, top_k, len(top_k_predictions))
                for j in xrange(num_preds):
                    if j > 0:
                        eval_row = ',,'
                    if j < len(tg_txts):
                        eval_row += '"{}",'.format(
                            tg_txts[j].strip().replace('"', '""'))
                    else:
                        eval_row += ','
                    top_k_pred_tree, top_k_pred_cmd = top_k_predictions[j]
                    if nl2bash:
                        pred_cmd = data_tools.ast2command(
                            top_k_pred_tree, loose_constraints=True)
                    else:
                        pred_cmd = top_k_pred_cmd
                    pred_file.write('{}|||'.format(pred_cmd.encode('utf-8')))
                    eval_row += '"{}",'.format(pred_cmd.replace('"', '""'))
                    temp_match = tree_dist.one_match(
                        tg_asts, top_k_pred_tree, ignore_arg_value=True)
                    str_match = tree_dist.one_match(
                        tg_asts, top_k_pred_tree, ignore_arg_value=False)
                    if temp_match:
                        eval_row += 'y,'
                    if str_match:
                        eval_row += 'y'
                    eval_file.write('{}\n'.format(eval_row.encode('utf-8')))
                    if verbose:
                        print('Prediction {}: {} ({})'.format(
                            j + 1, pred_cmd.encode('utf-8'), top_k_scores[j]))
                        if FLAGS.tg_char:
                            print('Character-based prediction {}: {}'.format(
                                j + 1,
                                top_k_char_predictions[j].encode('utf-8')))
                pred_file.write('\n')
        else:
            print(APOLOGY_MSG)
            pred_file.write('\n')
            eval_file.write('{}\n'.format(eval_row))
            eval_file.write('\n')
            eval_file.write('\n')
    pred_file.close()
    eval_file.close()

    shutil.copyfile(pred_file_path, os.path.join(
        FLAGS.model_dir, 'predictions.{}.latest'.format(model.decode_sig)))
    shutil.copyfile(eval_file_path, os.path.join(
        FLAGS.model_dir, 'predictions.{}.latest.csv'.format(model.decode_sig)))
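# decode_set writes one line per example to the predictions file; under beam
# search the candidate commands on a line are separated by '|||'. A minimal
# reader sketch (the helper name parse_prediction_file is an assumption for
# illustration, not part of this module):
#
#     def parse_prediction_file(pred_file_path, top_k):
#         predictions = []
#         with open(pred_file_path) as f:
#             for line in f:
#                 candidates = [c for c in line.rstrip('\n').split('|||') if c]
#                 predictions.append(candidates[:top_k])
#         return predictions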
def gen_manual_evaluation_csv_single_model(dataset, FLAGS):
    """
    Generate a .csv spreadsheet for manual evaluation on dev/test set
    examples for a specific model.
    """
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)

    # Load model predictions
    model_subdir, decode_sig = graph_utils.get_decode_signature(FLAGS)
    model_dir = os.path.join(FLAGS.model_root_dir, model_subdir)
    prediction_list = load_predictions(model_dir, decode_sig, top_k=3)
    if len(grouped_dataset) != len(prediction_list):
        raise ValueError("ground truth list and prediction list length must "
                         "be equal: {} vs. {}".format(len(grouped_dataset),
                                                      len(prediction_list)))

    # Load additional ground truths
    template_translations, command_translations = \
        load_cached_correct_translations(FLAGS.data_dir)

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = load_cached_evaluations(
        os.path.join(FLAGS.data_dir, 'manual_judgements'))

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash \
        else data_tools.paren_parser

    output_path = os.path.join(model_dir, 'manual.evaluations.single.model')
    with open(output_path, 'w') as o_f:
        # write spreadsheet header
        o_f.write('id,description,command,correct template,correct command\n')
        for example_id in range(len(grouped_dataset)):
            data_group = grouped_dataset[example_id][1]
            sc_txt = data_group[0].sc_txt.strip()
            sc_key = get_example_nl_key(sc_txt)
            command_gts = [dp.tg_txt for dp in data_group]
            command_gts = set(command_gts + command_translations[sc_key])
            command_gt_asts = [
                data_tools.bash_parser(cmd) for cmd in command_gts
            ]
            template_gts = [
                data_tools.cmd2template(cmd, loose_constraints=True)
                for cmd in command_gts
            ]
            template_gts = set(template_gts + template_translations[sc_key])
            template_gt_asts = [
                data_tools.bash_parser(temp) for temp in template_gts
            ]
            predictions = prediction_list[example_id]
            for i in xrange(3):
                if i >= len(predictions):
                    o_f.write(',,,n,n\n')
                    continue
                pred_cmd = predictions[i]
                pred_tree = cmd_parser(pred_cmd)
                pred_temp = data_tools.ast2template(
                    pred_tree, loose_constraints=True)
                temp_match = tree_dist.one_match(
                    template_gt_asts, pred_tree, ignore_arg_value=True)
                str_match = tree_dist.one_match(
                    command_gt_asts, pred_tree, ignore_arg_value=False)
                # Match ground truths & existing judgements
                command_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_cmd)
                structure_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_temp)
                command_eval, structure_eval = '', ''
                if str_match:
                    command_eval = 'y'
                    structure_eval = 'y'
                elif temp_match:
                    structure_eval = 'y'
                if command_eval_cache and \
                        command_example_sig in command_eval_cache:
                    command_eval = command_eval_cache[command_example_sig]
                if structure_eval_cache and \
                        structure_example_sig in structure_eval_cache:
                    structure_eval = structure_eval_cache[
                        structure_example_sig]
                if i == 0:
                    o_f.write('{},"{}","{}",{},{}\n'.format(
                        example_id, sc_txt.replace('"', '""'),
                        pred_cmd.replace('"', '""'), structure_eval,
                        command_eval))
                else:
                    o_f.write(',,"{}",{},{}\n'.format(
                        pred_cmd.replace('"', '""'), structure_eval,
                        command_eval))
    print('manual evaluation spreadsheet saved to {}'.format(output_path))
def tabulate_example_predictions(dataset, FLAGS, num_examples=100):
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)

    model_names, model_predictions = load_all_model_predictions(
        grouped_dataset, FLAGS, top_k=1)

    # Get FIXED dev set samples
    random.seed(100)
    example_ids = list(range(len(grouped_dataset)))
    random.shuffle(example_ids)
    sample_ids = example_ids[:num_examples]

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = \
        load_cached_evaluations(
            os.path.join(FLAGS.data_dir, 'manual_judgements'))

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash \
        else data_tools.paren_parser

    model_name_pt = {
        'token-seq2seq': 'T-Seq2Seq',
        'tellina': 'Tellina',
        'token-copynet': 'T-CopyNet',
        'partial.token-seq2seq': 'ST-Seq2Seq',
        'partial.token-copynet': 'ST-CopyNet',
        'char-seq2seq': 'C-Seq2Seq',
        'char-copynet': 'C-CopyNet'
    }

    for example_id in sample_ids:
        print('Example {}'.format(example_id))
        data_group = grouped_dataset[example_id][1]
        sc_txt = data_group[0].sc_txt.strip()
        sc_key = get_example_nl_key(sc_txt)
        command_gts = [dp.tg_txt for dp in data_group]
        command_gt_asts = [data_tools.bash_parser(gt) for gt in command_gts]
        output_strs = {}
        for model_id, model_name in enumerate(model_names):
            predictions = model_predictions[model_id][example_id]
            for i in xrange(min(3, len(predictions))):
                pred_cmd = predictions[i]
                pred_tree = cmd_parser(pred_cmd)
                pred_temp = data_tools.ast2template(
                    pred_tree, loose_constraints=True)
                temp_match = tree_dist.one_match(
                    command_gt_asts, pred_tree, ignore_arg_value=True)
                str_match = tree_dist.one_match(
                    command_gt_asts, pred_tree, ignore_arg_value=False)
                output_str = '& \\<{}> & {}'.format(
                    pred_cmd.replace('__SP__', '').replace('_', '\\_').replace(
                        '$', '\\$').replace('%', '\\%').replace(
                        '{{}}', '\\ttcbs'),
                    model_name_pt[model_name])
                command_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_cmd)
                structure_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_temp)
                command_eval, structure_eval = '', ''
                if str_match:
                    command_eval = 'y'
                    structure_eval = 'y'
                elif temp_match:
                    structure_eval = 'y'
                if command_eval_cache and \
                        command_example_sig in command_eval_cache:
                    command_eval = command_eval_cache[command_example_sig]
                if structure_eval_cache and \
                        structure_example_sig in structure_eval_cache:
                    structure_eval = structure_eval_cache[
                        structure_example_sig]
                output_str += ', {},{} \\\\'.format(structure_eval,
                                                    command_eval)
            output_strs[model_name] = output_str
        for model_name in [
                'char-seq2seq', 'char-copynet', 'token-seq2seq',
                'token-copynet', 'partial.token-seq2seq',
                'partial.token-copynet', 'tellina'
        ]:
            if model_name == 'char-seq2seq':
                print('\\multirow{{7}}{{*}}{{\\specialcell{{{}}}}} '.format(
                    sc_txt) + output_strs[model_name])
            else:
                print(output_strs[model_name])
        output_str = '& \\<{}> & Human \\\\'.format(command_gts[0].replace(
            '__SP__', '').replace('_', '\\_').replace('$', '\\$').replace(
            '%', '\\%').replace('{{}}', '\\ttcbs'))
        print(output_str)
        print()
def gen_manual_evaluation_csv(dataset, FLAGS, num_examples=100):
    """
    Generate a .csv spreadsheet for manual evaluation on a fixed set of
    test/dev examples, with the predictions of different models listed
    side-by-side.
    """
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)

    model_names, model_predictions = load_all_model_predictions(
        grouped_dataset, FLAGS, top_k=3)

    # Get FIXED dev set samples
    random.seed(100)
    example_ids = list(range(len(grouped_dataset)))
    random.shuffle(example_ids)
    sample_ids = example_ids[num_examples:num_examples + 100]

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = \
        load_cached_evaluations(
            os.path.join(FLAGS.data_dir, 'manual_judgements'))

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash \
        else data_tools.paren_parser

    output_path = os.path.join(FLAGS.data_dir, 'manual.evaluations.csv')
    with open(output_path, 'w') as o_f:
        o_f.write('example_id, description, ground_truth, model, prediction, '
                  'correct template, correct command\n')
        for example_id in sample_ids:
            data_group = grouped_dataset[example_id][1]
            sc_txt = data_group[0].sc_txt.strip()
            sc_key = get_example_nl_key(sc_txt)
            command_gts = [dp.tg_txt for dp in data_group]
            command_gt_asts = [
                data_tools.bash_parser(gt) for gt in command_gts
            ]
            for model_id, model_name in enumerate(model_names):
                predictions = model_predictions[model_id][example_id]
                for i in xrange(min(3, len(predictions))):
                    if model_id == 0 and i == 0:
                        output_str = '{},"{}",'.format(
                            example_id, sc_txt.replace('"', '""'))
                    else:
                        output_str = ',,'
                    pred_cmd = predictions[i]
                    pred_tree = cmd_parser(pred_cmd)
                    pred_temp = data_tools.ast2template(
                        pred_tree, loose_constraints=True)
                    temp_match = tree_dist.one_match(
                        command_gt_asts, pred_tree, ignore_arg_value=True)
                    str_match = tree_dist.one_match(
                        command_gt_asts, pred_tree, ignore_arg_value=False)
                    if (model_id * min(3, len(predictions)) + i) < \
                            len(command_gts):
                        output_str += '"{}",'.format(
                            command_gts[model_id * min(3, len(predictions))
                                        + i].strip().replace('"', '""'))
                    else:
                        output_str += ','
                    output_str += '{},"{}",'.format(
                        model_name, pred_cmd.replace('"', '""'))
                    command_example_sig = '{}<NL_PREDICTION>{}'.format(
                        sc_key, pred_cmd)
                    structure_example_sig = '{}<NL_PREDICTION>{}'.format(
                        sc_key, pred_temp)
                    command_eval, structure_eval = '', ''
                    if str_match:
                        command_eval = 'y'
                        structure_eval = 'y'
                    elif temp_match:
                        structure_eval = 'y'
                    if command_eval_cache and \
                            command_example_sig in command_eval_cache:
                        command_eval = command_eval_cache[command_example_sig]
                    if structure_eval_cache and \
                            structure_example_sig in structure_eval_cache:
                        structure_eval = structure_eval_cache[
                            structure_example_sig]
                    output_str += '{},{}'.format(structure_eval, command_eval)
                    o_f.write('{}\n'.format(output_str))
    print('Manual evaluation results saved to {}'.format(output_path))
def gen_evaluation_table(dataset, FLAGS, num_examples=-1, interactive=True):
    """
    Generate structure and full command accuracy results on a fixed set of
    dev/test set samples, with the results of multiple models tabulated in
    the same table.

    :param interactive:
        - If set, prompt the user to enter a judgement if a prediction does
          not match any of the ground truths and the correctness of the
          prediction has not been pre-determined;
          otherwise, count all predictions that do not match any of the
          ground truths as wrong.
    """
    def add_judgement(data_dir, nl, command, correct_template='',
                      correct_command=''):
        """
        Append a new judgement.
        """
        data_dir = os.path.join(data_dir, 'manual_judgements')
        manual_judgement_path = os.path.join(
            data_dir, 'manual.evaluations.additional')
        if not os.path.exists(manual_judgement_path):
            with open(manual_judgement_path, 'w') as o_f:
                o_f.write('description,prediction,template,correct template,'
                          'correct command\n')
        with open(manual_judgement_path, 'a') as o_f:
            temp = data_tools.cmd2template(command, loose_constraints=True)
            if not correct_template:
                correct_template = 'n'
            if not correct_command:
                correct_command = 'n'
            o_f.write('"{}","{}","{}","{}","{}"\n'.format(
                nl.replace('"', '""'), command.replace('"', '""'),
                temp.replace('"', '""'), correct_template.replace('"', '""'),
                correct_command.replace('"', '""')))
        print('new judgement added to {}'.format(manual_judgement_path))

    # Group dataset
    grouped_dataset = data_utils.group_parallel_data(dataset, use_bucket=True)

    if FLAGS.test:
        model_names, model_predictions = load_all_model_predictions(
            grouped_dataset, FLAGS, top_k=3, tellina=True,
            partial_token_copynet=True, token_seq2seq=False,
            token_copynet=False, char_seq2seq=False, char_copynet=False,
            partial_token_seq2seq=False)
    else:
        model_names, model_predictions = load_all_model_predictions(
            grouped_dataset, FLAGS, top_k=3)

    # Get FIXED dev set samples
    random.seed(100)
    example_ids = list(range(len(grouped_dataset)))
    random.shuffle(example_ids)
    if num_examples > 0:
        sample_ids = example_ids[:num_examples]
    else:
        sample_ids = example_ids

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = \
        load_cached_evaluations(
            os.path.join(FLAGS.data_dir, 'manual_judgements'))

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash \
        else data_tools.paren_parser

    # Interactive manual evaluation interface
    num_s_correct = collections.defaultdict(int)
    num_f_correct = collections.defaultdict(int)
    num_s_top_3_correct = collections.defaultdict(int)
    num_f_top_3_correct = collections.defaultdict(int)
    for exam_id, example_id in enumerate(sample_ids):
        data_group = grouped_dataset[example_id][1]
        sc_txt = data_group[0].sc_txt.strip()
        sc_key = get_example_nl_key(sc_txt)
        command_gts = [dp.tg_txt for dp in data_group]
        command_gt_asts = [data_tools.bash_parser(gt) for gt in command_gts]
        for model_id, model_name in enumerate(model_names):
            predictions = model_predictions[model_id][example_id]
            top_3_s_correct_marked = False
            top_3_f_correct_marked = False
            for i in xrange(min(3, len(predictions))):
                pred_cmd = predictions[i]
                pred_ast = cmd_parser(pred_cmd)
                pred_temp = data_tools.ast2template(
                    pred_ast, loose_constraints=True)
                temp_match = tree_dist.one_match(
                    command_gt_asts, pred_ast, ignore_arg_value=True)
                str_match = tree_dist.one_match(
                    command_gt_asts, pred_ast, ignore_arg_value=False)
                # Match ground truths & existing judgements
                command_example_key = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_cmd)
                structure_example_key = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_temp)
                command_eval, structure_eval = '', ''
                if str_match:
                    command_eval = 'y'
                    structure_eval = 'y'
                elif temp_match:
                    structure_eval = 'y'
                if command_eval_cache and \
                        command_example_key in command_eval_cache:
                    command_eval = command_eval_cache[command_example_key]
                if structure_eval_cache and \
                        structure_example_key in structure_eval_cache:
                    structure_eval = structure_eval_cache[
                        structure_example_key]

                # Prompt for new judgements
                if command_eval != 'y':
                    if structure_eval == 'y':
                        if not command_eval and interactive:
                            print('#{}. {}'.format(exam_id, sc_txt))
                            for j, gt in enumerate(command_gts):
                                print('- GT{}: {}'.format(j, gt))
                            print('> {}'.format(pred_cmd))
                            command_eval = input(
                                'CORRECT COMMAND? [y/reason] ')
                            add_judgement(FLAGS.data_dir, sc_key, pred_cmd,
                                          structure_eval, command_eval)
                            print()
                    else:
                        if not structure_eval and interactive:
                            print('#{}. {}'.format(exam_id, sc_txt))
                            for j, gt in enumerate(command_gts):
                                print('- GT{}: {}'.format(j, gt))
                            print('> {}'.format(pred_cmd))
                            structure_eval = input(
                                'CORRECT STRUCTURE? [y/reason] ')
                            if structure_eval == 'y':
                                command_eval = input(
                                    'CORRECT COMMAND? [y/reason] ')
                            add_judgement(FLAGS.data_dir, sc_key, pred_cmd,
                                          structure_eval, command_eval)
                            print()
                    structure_eval_cache[
                        structure_example_key] = structure_eval
                    command_eval_cache[command_example_key] = command_eval
                if structure_eval == 'y':
                    if i == 0:
                        num_s_correct[model_name] += 1
                    if not top_3_s_correct_marked:
                        num_s_top_3_correct[model_name] += 1
                        top_3_s_correct_marked = True
                if command_eval == 'y':
                    if i == 0:
                        num_f_correct[model_name] += 1
                    if not top_3_f_correct_marked:
                        num_f_top_3_correct[model_name] += 1
                        top_3_f_correct_marked = True

    metrics_names = ['Acc_F_1', 'Acc_F_3', 'Acc_T_1', 'Acc_T_3']
    model_metrics = {}
    for model_name in model_names:
        metrics = [
            num_f_correct[model_name] / len(sample_ids),
            num_f_top_3_correct[model_name] / len(sample_ids),
            num_s_correct[model_name] / len(sample_ids),
            num_s_top_3_correct[model_name] / len(sample_ids)
        ]
        model_metrics[model_name] = metrics
    print_table(model_names, metrics_names, model_metrics)
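# The manual judgement caches used above are keyed by a signature that joins
# the natural language description key and the predicted command (or command
# template) with the '<NL_PREDICTION>' separator. A minimal sketch of the
# convention; the description and command strings below are illustrative
# examples, not taken from the dataset:
#
#     sc_key = get_example_nl_key('remove all empty files in the current directory')
#     pred_cmd = 'find . -type f -empty -delete'
#     command_example_key = '{}<NL_PREDICTION>{}'.format(sc_key, pred_cmd)
#     # command_eval_cache[command_example_key] == 'y' if judged correct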