def gen_error_analysis_csv_by_utility(model_dir, decode_sig, dataset, FLAGS,
                                       top_k=10):
    """
    Generate error analysis evaluation sheet grouped by utility.
    """
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)

    # Load model predictions
    prediction_list = load_predictions(model_dir, decode_sig, top_k)
    if len(grouped_dataset) != len(prediction_list):
        raise ValueError(
            "ground truth and predictions length must be equal: {} vs. {}"
            .format(len(grouped_dataset), len(prediction_list)))

    # Load cached evaluation results
    cached_evaluation_results = load_cached_evaluations(model_dir)

    # Convert the predictions into csv format
    grammar_errors, argument_errors = print_error_analysis_csv(
        grouped_dataset, prediction_list, FLAGS, cached_evaluation_results,
        group_by_utility=True, error_predictions_only=False)

    error_by_utility_path = os.path.join(
        model_dir, 'error.analysis.by.utility.csv')
    print("Saving grammar errors to {}".format(error_by_utility_path))
    with open(error_by_utility_path, 'w') as error_by_utility_file:
        # print csv file header
        error_by_utility_file.write(
            'utility, example_id, description, groundtruth, prediction, '
            'correct template, correct command\n')
        for line in bash.utility_stats.split('\n'):
            utility = line.split(',')[1]
            error_examples = grammar_errors[utility]
            # Write at most 5 (randomly sampled) error examples per utility
            if len(error_examples) > 5:
                random.shuffle(error_examples)
                error_examples = error_examples[:5]
            for example in error_examples:
                for l in example:
                    error_by_utility_file.write(
                        '{},{}\n'.format(utility, l))

def gen_error_analysis_csv(model_dir, decode_sig, dataset, FLAGS, top_k=3):
    """
    Generate error analysis evaluation spreadsheets:
        - grammar error analysis
        - argument error analysis
    """
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)

    # Load model predictions
    prediction_list = load_predictions(model_dir, decode_sig, top_k)
    if len(grouped_dataset) != len(prediction_list):
        raise ValueError("ground truth and predictions length must be equal: "
                         "{} vs. {}".format(len(grouped_dataset),
                                            len(prediction_list)))

    # Convert the predictions to csv format
    grammar_errors, argument_errors = print_error_analysis_csv(
        grouped_dataset, prediction_list, FLAGS)

    grammar_error_path = os.path.join(model_dir, 'grammar.error.analysis.csv')
    random.shuffle(grammar_errors)
    with open(grammar_error_path, 'w') as grammar_error_file:
        print("Saving grammar errors to {}".format(grammar_error_path))
        # print csv file header
        grammar_error_file.write(
            'example_id, description, ground_truth, prediction, '
            'correct template, correct command\n')
        for example in grammar_errors[:100]:
            for line in example:
                grammar_error_file.write('{}\n'.format(line))

    arg_error_path = os.path.join(model_dir, 'argument.error.analysis.csv')
    random.shuffle(argument_errors)
    with open(arg_error_path, 'w') as arg_error_file:
        print("Saving argument errors to {}".format(arg_error_path))
        # print csv file header
        arg_error_file.write(
            'example_id, description, ground_truth, prediction, '
            'correct template, correct command\n')
        for example in argument_errors[:100]:
            for line in example:
                arg_error_file.write('{}\n'.format(line))

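# Usage sketch (illustrative only, not invoked by this module): both
# error-analysis writers expect the model directory and decode signature that
# the rest of the evaluation code derives from FLAGS, e.g.
#
#   model_subdir, decode_sig = graph_utils.get_decode_signature(FLAGS)
#   model_dir = os.path.join(FLAGS.model_root_dir, model_subdir)
#   gen_error_analysis_csv(model_dir, decode_sig, dataset, FLAGS, top_k=3)
#   gen_error_analysis_csv_by_utility(model_dir, decode_sig, dataset, FLAGS)
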
def gen_manual_evaluation_csv_single_model(dataset, FLAGS):
    """
    Generate .csv spreadsheet for manual evaluation on dev/test set
    examples for a specific model.
    """
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)

    # Load model predictions
    model_subdir, decode_sig = graph_utils.get_decode_signature(FLAGS)
    model_dir = os.path.join(FLAGS.model_root_dir, model_subdir)
    prediction_list = load_predictions(model_dir, decode_sig, top_k=3)
    if len(grouped_dataset) != len(prediction_list):
        raise ValueError("ground truth list and prediction list length must "
                         "be equal: {} vs. {}".format(len(grouped_dataset),
                                                      len(prediction_list)))

    # Load additional ground truths
    template_translations, command_translations = \
        load_cached_correct_translations(FLAGS.data_dir)

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = load_cached_evaluations(
        os.path.join(FLAGS.data_dir, 'manual_judgements'))

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash \
        else data_tools.paren_parser

    output_path = os.path.join(model_dir, 'manual.evaluations.single.model')
    with open(output_path, 'w') as o_f:
        # write spreadsheet header
        o_f.write('id,description,command,correct template,correct command\n')
        for example_id in range(len(grouped_dataset)):
            data_group = grouped_dataset[example_id][1]
            sc_txt = data_group[0].sc_txt.strip()
            sc_key = get_example_nl_key(sc_txt)
            command_gts = [dp.tg_txt for dp in data_group]
            command_gts = set(command_gts + command_translations[sc_key])
            command_gt_asts = [
                data_tools.bash_parser(cmd) for cmd in command_gts]
            template_gts = [
                data_tools.cmd2template(cmd, loose_constraints=True)
                for cmd in command_gts]
            template_gts = set(template_gts + template_translations[sc_key])
            template_gt_asts = [
                data_tools.bash_parser(temp) for temp in template_gts]
            predictions = prediction_list[example_id]
            for i in range(3):
                if i >= len(predictions):
                    o_f.write(',,,n,n\n')
                    continue
                pred_cmd = predictions[i]
                pred_tree = cmd_parser(pred_cmd)
                pred_temp = data_tools.ast2template(
                    pred_tree, loose_constraints=True)
                temp_match = tree_dist.one_match(
                    template_gt_asts, pred_tree, ignore_arg_value=True)
                str_match = tree_dist.one_match(
                    command_gt_asts, pred_tree, ignore_arg_value=False)
                # Match ground truths & existing judgements
                command_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_cmd)
                structure_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_temp)
                command_eval, structure_eval = '', ''
                if str_match:
                    command_eval = 'y'
                    structure_eval = 'y'
                elif temp_match:
                    structure_eval = 'y'
                if command_eval_cache and \
                        command_example_sig in command_eval_cache:
                    command_eval = command_eval_cache[command_example_sig]
                if structure_eval_cache and \
                        structure_example_sig in structure_eval_cache:
                    structure_eval = structure_eval_cache[
                        structure_example_sig]
                if i == 0:
                    o_f.write('{},"{}","{}",{},{}\n'.format(
                        example_id, sc_txt.replace('"', '""'),
                        pred_cmd.replace('"', '""'),
                        structure_eval, command_eval))
                else:
                    o_f.write(',,"{}",{},{}\n'.format(
                        pred_cmd.replace('"', '""'),
                        structure_eval, command_eval))
    print('manual evaluation spreadsheet saved to {}'.format(output_path))

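# Usage sketch (illustrative only): the model directory and decode signature
# are resolved from FLAGS inside the function, so a caller only supplies the
# dataset split to annotate (`dev_set` below is a hypothetical, already-loaded
# dev split); the spreadsheet is written to
# <model_dir>/manual.evaluations.single.model.
#
#   gen_manual_evaluation_csv_single_model(dev_set, FLAGS)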