def one_match(asts, ast2, rewrite=False, ignore_arg_value=False):
    if rewrite:
        raise NotImplementedError
    else:
        ast_rewrites = asts
    cmd2 = data_tools.ast2template(ast2, loose_constraints=True,
                                   arg_type_only=ignore_arg_value)
    for ast1 in ast_rewrites:
        cmd1 = data_tools.ast2template(ast1, loose_constraints=True,
                                       arg_type_only=ignore_arg_value)
        if cmd1 == cmd2:
            return True
    return False

def one_match(asts, ast2, rewrite=True, ignore_arg_value=False):
    if rewrite:
        with rewrites.DBConnection() as db:
            ast_rewrites = get_rewrites(asts, db)
    else:
        ast_rewrites = asts
    cmd2 = ignore_differences(data_tools.ast2template(
        ast2, loose_constraints=True, arg_type_only=ignore_arg_value))
    for ast1 in ast_rewrites:
        cmd1 = data_tools.ast2template(
            ast1, loose_constraints=True, arg_type_only=ignore_arg_value)
        cmd1 = ignore_differences(cmd1)
        if cmd1 == cmd2:
            return True
    return False

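# A minimal usage sketch of one_match, modeled on how the evaluation code
# below calls it. The command strings are illustrative only, and it assumes
# the module-level helpers used above (data_tools.bash_parser,
# ignore_differences) are available in the same scope.
gt_asts = [data_tools.bash_parser(c) for c in
           ['find . -name "*.txt"', 'find . -iname "*.txt"']]
pred_ast = data_tools.bash_parser('find /tmp -name "*.txt"')

# Template-level match (argument values ignored) vs. full-command match.
temp_match = one_match(gt_asts, pred_ast, rewrite=False, ignore_arg_value=True)
str_match = one_match(gt_asts, pred_ast, rewrite=False, ignore_arg_value=False)
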
def combine_annotations_multi_files():
    """
    Combine multiple annotation files and discard the annotations that have a
    conflict.
    """
    input_dir = sys.argv[1]
    template_evals = {}
    command_evals = {}
    discarded_keys = set({})
    for in_csv in os.listdir(input_dir):
        in_csv_path = os.path.join(input_dir, in_csv)
        with open(in_csv_path) as f:
            reader = csv.DictReader(f)
            current_description = ''
            for row in reader:
                template_eval = normalize_judgement(row['correct template'])
                command_eval = normalize_judgement(row['correct command'])
                description = get_example_nl_key(row['description'])
                if description.strip():
                    current_description = description
                else:
                    description = current_description
                prediction = row['prediction']
                example_key = '{}<NL_PREDICTION>{}'.format(
                    description, prediction)
                if example_key in template_evals and template_evals[
                        example_key] != template_eval:
                    discarded_keys.add(example_key)
                    continue
                if example_key in command_evals and command_evals[
                        example_key] != command_eval:
                    discarded_keys.add(example_key)
                    continue
                template_evals[example_key] = template_eval
                command_evals[example_key] = command_eval
        print('{} read ({} manually annotated examples, {} discarded)'.format(
            in_csv_path, len(template_evals), len(discarded_keys)))

    # Write to new file
    assert (len(template_evals) == len(command_evals))
    with open('manual_annotations.additional', 'w') as o_f:
        o_f.write(
            'description,prediction,template,correct template,correct command\n'
        )
        for key in sorted(template_evals.keys()):
            if key in discarded_keys:
                continue
            description, prediction = key.split('<NL_PREDICTION>')
            # Look up the judgements for the current key (the original indexed
            # with a stale `example_key` left over from the reading loop).
            template_eval = template_evals[key]
            command_eval = command_evals[key]
            pred_tree = data_tools.bash_parser(prediction)
            pred_temp = data_tools.ast2template(pred_tree,
                                                loose_constraints=True)
            o_f.write('"{}","{}","{}",{},{}\n'.format(
                description.replace('"', '""'),
                prediction.replace('"', '""'),
                pred_temp.replace('"', '""'), template_eval, command_eval))

def populate_command_template():
    for cmd in Command.objects.all():
        if len(cmd.str) > 600:
            cmd.delete()
        else:
            ast = data_tools.bash_parser(cmd.str)
            template = data_tools.ast2template(ast, loose_constraints=True)
            cmd.template = template
            cmd.save()

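# The per-command transformation above is simply parse-then-template; a small
# stand-alone sketch (the command string is hypothetical).
cmd_str = 'find /var/log -name "*.gz" -mtime +30'
ast = data_tools.bash_parser(cmd_str)
if ast:  # bash_parser may return a falsy value on unparsable input
    print(data_tools.ast2template(ast, loose_constraints=True))
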
def run():
    sqlite_filename = sys.argv[1]
    url_prefix = 'https://stackoverflow.com/questions/'
    urls = {}
    commands = {}
    with sqlite3.connect(sqlite_filename,
                         detect_types=sqlite3.PARSE_DECLTYPES) as db:
        count = 0
        for post_id, answer_body in db.cursor().execute("""
                SELECT questions.Id, answers.Body
                FROM questions, answers
                WHERE questions.Id = answers.ParentId
                ORDER BY questions.Score DESC"""):
            print(post_id)
            for code_block in extract_code(answer_body):
                for cmd in extract_oneliner_from_code(code_block):
                    print('command string: {}'.format(cmd))
                    ast = data_tools.bash_parser(cmd)
                    if not ast:
                        continue
                    utilities = data_tools.get_utilities(ast)
                    for utility in utilities:
                        if utility in bash.top_100_utilities:
                            print('extracted: {}, {}'.format(utility, cmd))
                            temp = data_tools.ast2template(
                                ast, loose_constraints=True)
                            if not utility in commands:
                                commands[utility] = {}
                                commands[utility][temp] = cmd
                                urls[utility] = {
                                    '{}{}'.format(url_prefix, post_id)}
                            else:
                                if len(commands[utility]) >= NUM_COMMAND_THRESHOLD:
                                    continue
                                if not temp in commands[utility]:
                                    commands[utility][temp] = cmd
                                    urls[utility].add(
                                        '{}{}'.format(url_prefix, post_id))
            count += 1
            if count % 1000 == 0:
                # Assume collection is complete until a utility below the
                # threshold is found (the original initialized this flag to
                # False, which made the early-exit check a no-op).
                completed = True
                for utility in bash.top_100_utilities:
                    if not utility in commands or \
                            len(commands[utility]) < NUM_COMMAND_THRESHOLD:
                        completed = False
                    else:
                        print('{} collection done.'.format(utility))
                if completed:
                    break
    with open('stackoverflow.urls', 'wb') as o_f:
        pickle.dump(urls, o_f)
    with open('stackoverflow.commands', 'wb') as o_f:
        pickle.dump(commands, o_f)
    for utility in commands:
        print('{} ({})'.format(utility, len(commands[utility])))
        for cmd in commands[utility]:
            print(cmd)

def get_rewrites(self, ast):
    rewrites = set([ast])
    s1 = data_tools.ast2template(ast, loose_constraints=True)
    c = self.cursor
    for s1, s2 in c.execute("SELECT s1, s2 FROM Rewrites WHERE s1 = ?",
                            (s1, )):
        rw = rewrite(ast, s2)
        if not rw is None:
            rewrites.add(rw)
    return rewrites

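# A hedged usage sketch for the method above, modeled on the call pattern in
# the one_match variant earlier. It assumes a DBConnection context manager
# exposing get_rewrites and an already-populated Rewrites table; the command
# string is illustrative only.
ast = data_tools.bash_parser('du -sh /tmp')
with DBConnection() as db:
    equivalent_asts = db.get_rewrites(ast)  # includes the original AST itself
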
def Cust_Cmd_Tokenizer(String, parse="Template"):
    """Custom command tokenizer (as per our needs)."""
    if parse == "Norm":
        Command = cm_to_partial_tokens(String,
                                       tokenizer=data_tools.bash_tokenizer)
        # Return the partial tokens directly; the original fell through and
        # returned an undefined variable for this branch.
        return Command
    elif parse == "Template":
        AST = data_tools.bash_parser(String)
        Template = data_tools.ast2template(AST, ignore_flag_order=False)
        Template_Tokens_List = Template.split(" ")
        return Template_Tokens_List

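# Hedged usage example; the input command string is hypothetical.
tokens = Cust_Cmd_Tokenizer('tar -czf backup.tar.gz /home/user',
                            parse="Template")
print(tokens)  # list of template tokens split on whitespace
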
def get_command(command_str):
    command_str = command_str.strip()
    if Command.objects.filter(str=command_str).exists():
        cmd = Command.objects.get(str=command_str)
    else:
        cmd = Command.objects.create(str=command_str)
        ast = data_tools.bash_parser(command_str)
        for utility in data_tools.get_utilities(ast):
            cmd.tags.add(get_tag(utility))
        template = data_tools.ast2template(ast, loose_constraints=True)
        cmd.template = template
        cmd.save()
    return cmd

def get_manual_evaluation_metrics(grouped_dataset, prediction_list, FLAGS,
                                  num_examples=-1, interactive=True,
                                  verbose=True):
    if len(grouped_dataset) != len(prediction_list):
        raise ValueError("ground truth and predictions length must be equal: "
                         "{} vs. {}".format(len(grouped_dataset),
                                            len(prediction_list)))

    # Get dev set samples (fixed)
    random.seed(100)
    example_ids = list(range(len(grouped_dataset)))
    random.shuffle(example_ids)
    if num_examples > 0:
        sample_ids = example_ids[:num_examples]
    else:
        sample_ids = example_ids

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = \
        load_cached_evaluations(
            os.path.join(FLAGS.data_dir, 'manual_judgements'), verbose=True)

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash \
        else data_tools.paren_parser

    # Interactive manual evaluation
    num_t_top_1_correct = 0.0
    num_f_top_1_correct = 0.0
    num_t_top_3_correct = 0.0
    num_f_top_3_correct = 0.0
    for exam_id, example_id in enumerate(sample_ids):
        data_group = grouped_dataset[example_id][1]
        sc_txt = data_group[0].sc_txt.strip()
        sc_key = get_example_nl_key(sc_txt)
        command_gts = [dp.tg_txt for dp in data_group]
        command_gt_asts = [data_tools.bash_parser(gt) for gt in command_gts]
        predictions = prediction_list[example_id]
        top_3_s_correct_marked = False
        top_3_f_correct_marked = False
        for i in xrange(min(3, len(predictions))):
            pred_cmd = predictions[i]
            pred_ast = cmd_parser(pred_cmd)
            pred_temp = data_tools.ast2template(pred_ast,
                                                loose_constraints=True)
            temp_match = tree_dist.one_match(command_gt_asts, pred_ast,
                                             ignore_arg_value=True)
            str_match = tree_dist.one_match(command_gt_asts, pred_ast,
                                            ignore_arg_value=False)
            # Match ground truths & existing judgements
            command_example_key = '{}<NL_PREDICTION>{}'.format(
                sc_key, pred_cmd)
            structure_example_key = '{}<NL_PREDICTION>{}'.format(
                sc_key, pred_temp)
            command_eval, structure_eval = '', ''
            if str_match:
                command_eval = 'y'
                structure_eval = 'y'
            elif temp_match:
                structure_eval = 'y'
            if command_eval_cache and command_example_key in command_eval_cache:
                command_eval = command_eval_cache[command_example_key]
            if structure_eval_cache and structure_example_key in structure_eval_cache:
                structure_eval = structure_eval_cache[structure_example_key]
            # Prompt for new judgements
            if command_eval != 'y':
                if structure_eval == 'y':
                    if not command_eval and interactive:
                        print('#{}. {}'.format(exam_id, sc_txt))
                        for j, gt in enumerate(command_gts):
                            print('- GT{}: {}'.format(j, gt))
                        print('> {}'.format(pred_cmd))
                        command_eval = input('CORRECT COMMAND? [y/reason] ')
                        add_judgement(FLAGS.data_dir, sc_txt, pred_cmd,
                                      structure_eval, command_eval)
                        print()
                else:
                    if not structure_eval and interactive:
                        print('#{}. {}'.format(exam_id, sc_txt))
                        for j, gt in enumerate(command_gts):
                            print('- GT{}: {}'.format(j, gt))
                        print('> {}'.format(pred_cmd))
                        structure_eval = input(
                            'CORRECT STRUCTURE? [y/reason] ')
                        if structure_eval == 'y':
                            command_eval = input(
                                'CORRECT COMMAND? [y/reason] ')
                        add_judgement(FLAGS.data_dir, sc_txt, pred_cmd,
                                      structure_eval, command_eval)
                        print()
                structure_eval_cache[structure_example_key] = structure_eval
                command_eval_cache[command_example_key] = command_eval
            if structure_eval == 'y':
                if i == 0:
                    num_t_top_1_correct += 1
                if not top_3_s_correct_marked:
                    num_t_top_3_correct += 1
                    top_3_s_correct_marked = True
            if command_eval == 'y':
                if i == 0:
                    num_f_top_1_correct += 1
                if not top_3_f_correct_marked:
                    num_f_top_3_correct += 1
                    top_3_f_correct_marked = True

    metrics = {}
    acc_f_1 = num_f_top_1_correct / len(sample_ids)
    acc_f_3 = num_f_top_3_correct / len(sample_ids)
    acc_t_1 = num_t_top_1_correct / len(sample_ids)
    acc_t_3 = num_t_top_3_correct / len(sample_ids)
    metrics['acc_f'] = [acc_f_1, acc_f_3]
    metrics['acc_t'] = [acc_t_1, acc_t_3]

    if verbose:
        print('{} examples evaluated'.format(len(sample_ids)))
        print('Top 1 Command Acc = {:.3f}'.format(acc_f_1))
        print('Top 3 Command Acc = {:.3f}'.format(acc_f_3))
        print('Top 1 Template Acc = {:.3f}'.format(acc_t_1))
        print('Top 3 Template Acc = {:.3f}'.format(acc_t_3))
    return metrics

def gen_manual_evaluation_csv_single_model(dataset, FLAGS):
    """
    Generate a .csv spreadsheet for manual evaluation on dev/test set examples
    for a specific model.
    """
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)

    # Load model predictions
    model_subdir, decode_sig = graph_utils.get_decode_signature(FLAGS)
    model_dir = os.path.join(FLAGS.model_root_dir, model_subdir)
    prediction_list = load_predictions(model_dir, decode_sig, top_k=3)
    if len(grouped_dataset) != len(prediction_list):
        raise ValueError("ground truth list and prediction list length must "
                         "be equal: {} vs. {}".format(len(grouped_dataset),
                                                      len(prediction_list)))

    # Load additional ground truths
    template_translations, command_translations = \
        load_cached_correct_translations(FLAGS.data_dir)

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = load_cached_evaluations(
        os.path.join(FLAGS.data_dir, 'manual_judgements'))

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash else data_tools.paren_parser

    output_path = os.path.join(model_dir, 'manual.evaluations.single.model')
    with open(output_path, 'w') as o_f:
        # Write spreadsheet header
        o_f.write('id,description,command,correct template,correct command\n')
        for example_id in range(len(grouped_dataset)):
            data_group = grouped_dataset[example_id][1]
            sc_txt = data_group[0].sc_txt.strip()
            sc_key = get_example_nl_key(sc_txt)
            command_gts = [dp.tg_txt for dp in data_group]
            command_gts = set(command_gts + command_translations[sc_key])
            command_gt_asts = [
                data_tools.bash_parser(cmd) for cmd in command_gts
            ]
            template_gts = [
                data_tools.cmd2template(cmd, loose_constraints=True)
                for cmd in command_gts
            ]
            template_gts = set(template_gts + template_translations[sc_key])
            template_gt_asts = [
                data_tools.bash_parser(temp) for temp in template_gts
            ]
            predictions = prediction_list[example_id]
            for i in xrange(3):
                if i >= len(predictions):
                    o_f.write(',,,n,n\n')
                    continue
                pred_cmd = predictions[i]
                pred_tree = cmd_parser(pred_cmd)
                pred_temp = data_tools.ast2template(pred_tree,
                                                    loose_constraints=True)
                temp_match = tree_dist.one_match(template_gt_asts, pred_tree,
                                                 ignore_arg_value=True)
                str_match = tree_dist.one_match(command_gt_asts, pred_tree,
                                                ignore_arg_value=False)
                # Match ground truths & existing judgements
                command_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_cmd)
                structure_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_temp)
                command_eval, structure_eval = '', ''
                if str_match:
                    command_eval = 'y'
                    structure_eval = 'y'
                elif temp_match:
                    structure_eval = 'y'
                if command_eval_cache and \
                        command_example_sig in command_eval_cache:
                    command_eval = command_eval_cache[command_example_sig]
                if structure_eval_cache and \
                        structure_example_sig in structure_eval_cache:
                    structure_eval = structure_eval_cache[
                        structure_example_sig]
                if i == 0:
                    o_f.write('{},"{}","{}",{},{}\n'.format(
                        example_id, sc_txt.replace('"', '""'),
                        pred_cmd.replace('"', '""'), structure_eval,
                        command_eval))
                else:
                    o_f.write(',,"{}",{},{}\n'.format(
                        pred_cmd.replace('"', '""'), structure_eval,
                        command_eval))
    print('manual evaluation spreadsheet saved to {}'.format(output_path))

def template_match(ast1, ast2):
    temp1 = ignore_differences(
        data_tools.ast2template(ast1, loose_constraints=True))
    temp2 = ignore_differences(
        data_tools.ast2template(ast2, loose_constraints=True))
    return temp1 == temp2

def gen_manual_evaluation_csv(dataset, FLAGS, num_examples=100):
    """
    Generate a .csv spreadsheet for manual evaluation on a fixed set of
    test/dev examples, with the predictions of different models listed
    side-by-side.
    """
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)
    model_names, model_predictions = load_all_model_predictions(
        grouped_dataset, FLAGS, top_k=3)

    # Get FIXED dev set samples
    random.seed(100)
    example_ids = list(range(len(grouped_dataset)))
    random.shuffle(example_ids)
    sample_ids = example_ids[num_examples:num_examples + 100]

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = \
        load_cached_evaluations(
            os.path.join(FLAGS.data_dir, 'manual_judgements'))

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash \
        else data_tools.paren_parser

    output_path = os.path.join(FLAGS.data_dir, 'manual.evaluations.csv')
    with open(output_path, 'w') as o_f:
        o_f.write('example_id, description, ground_truth, model, prediction, '
                  'correct template, correct command\n')
        for example_id in sample_ids:
            data_group = grouped_dataset[example_id][1]
            sc_txt = data_group[0].sc_txt.strip()
            sc_key = get_example_nl_key(sc_txt)
            command_gts = [dp.tg_txt for dp in data_group]
            command_gt_asts = [
                data_tools.bash_parser(gt) for gt in command_gts
            ]
            for model_id, model_name in enumerate(model_names):
                predictions = model_predictions[model_id][example_id]
                for i in xrange(min(3, len(predictions))):
                    if model_id == 0 and i == 0:
                        output_str = '{},"{}",'.format(
                            example_id, sc_txt.replace('"', '""'))
                    else:
                        output_str = ',,'
                    pred_cmd = predictions[i]
                    pred_tree = cmd_parser(pred_cmd)
                    pred_temp = data_tools.ast2template(
                        pred_tree, loose_constraints=True)
                    temp_match = tree_dist.one_match(command_gt_asts,
                                                     pred_tree,
                                                     ignore_arg_value=True)
                    str_match = tree_dist.one_match(command_gt_asts,
                                                    pred_tree,
                                                    ignore_arg_value=False)
                    if (model_id * min(3, len(predictions)) + i) < len(command_gts):
                        output_str += '"{}",'.format(
                            command_gts[model_id * min(3, len(predictions)) +
                                        i].strip().replace('"', '""'))
                    else:
                        output_str += ','
                    output_str += '{},"{}",'.format(
                        model_name, pred_cmd.replace('"', '""'))
                    command_example_sig = '{}<NL_PREDICTION>{}'.format(
                        sc_key, pred_cmd)
                    structure_example_sig = '{}<NL_PREDICTION>{}'.format(
                        sc_key, pred_temp)
                    command_eval, structure_eval = '', ''
                    if str_match:
                        command_eval = 'y'
                        structure_eval = 'y'
                    elif temp_match:
                        structure_eval = 'y'
                    if command_eval_cache and \
                            command_example_sig in command_eval_cache:
                        command_eval = command_eval_cache[command_example_sig]
                    if structure_eval_cache and \
                            structure_example_sig in structure_eval_cache:
                        structure_eval = structure_eval_cache[
                            structure_example_sig]
                    output_str += '{},{}'.format(structure_eval, command_eval)
                    o_f.write('{}\n'.format(output_str))
    print('Manual evaluation results saved to {}'.format(output_path))

def tabulate_example_predictions(dataset, FLAGS, num_examples=100):
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)
    model_names, model_predictions = load_all_model_predictions(
        grouped_dataset, FLAGS, top_k=1)

    # Get FIXED dev set samples
    random.seed(100)
    example_ids = list(range(len(grouped_dataset)))
    random.shuffle(example_ids)
    sample_ids = example_ids[:num_examples]

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = \
        load_cached_evaluations(
            os.path.join(FLAGS.data_dir, 'manual_judgements'))

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash \
        else data_tools.paren_parser

    model_name_pt = {
        'token-seq2seq': 'T-Seq2Seq',
        'tellina': 'Tellina',
        'token-copynet': 'T-CopyNet',
        'partial.token-seq2seq': 'ST-Seq2Seq',
        'partial.token-copynet': 'ST-CopyNet',
        'char-seq2seq': 'C-Seq2Seq',
        'char-copynet': 'C-CopyNet'
    }

    for example_id in sample_ids:
        print('Example {}'.format(example_id))
        data_group = grouped_dataset[example_id][1]
        sc_txt = data_group[0].sc_txt.strip()
        sc_key = get_example_nl_key(sc_txt)
        command_gts = [dp.tg_txt for dp in data_group]
        command_gt_asts = [data_tools.bash_parser(gt) for gt in command_gts]
        output_strs = {}
        for model_id, model_name in enumerate(model_names):
            predictions = model_predictions[model_id][example_id]
            for i in xrange(min(3, len(predictions))):
                pred_cmd = predictions[i]
                pred_tree = cmd_parser(pred_cmd)
                pred_temp = data_tools.ast2template(pred_tree,
                                                    loose_constraints=True)
                temp_match = tree_dist.one_match(command_gt_asts, pred_tree,
                                                 ignore_arg_value=True)
                str_match = tree_dist.one_match(command_gt_asts, pred_tree,
                                                ignore_arg_value=False)
                output_str = '& \\<{}> & {}'.format(
                    pred_cmd.replace('__SP__', '').replace('_', '\\_').replace(
                        '$', '\\$').replace('%', '\\%').replace(
                            '{{}}', '\\ttcbs'), model_name_pt[model_name])
                command_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_cmd)
                structure_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_temp)
                command_eval, structure_eval = '', ''
                if str_match:
                    command_eval = 'y'
                    structure_eval = 'y'
                elif temp_match:
                    structure_eval = 'y'
                if command_eval_cache and \
                        command_example_sig in command_eval_cache:
                    command_eval = command_eval_cache[command_example_sig]
                if structure_eval_cache and \
                        structure_example_sig in structure_eval_cache:
                    structure_eval = structure_eval_cache[
                        structure_example_sig]
                output_str += ', {},{} \\\\'.format(structure_eval,
                                                    command_eval)
            output_strs[model_name] = output_str
        for model_name in [
                'char-seq2seq', 'char-copynet', 'token-seq2seq',
                'token-copynet', 'partial.token-seq2seq',
                'partial.token-copynet', 'tellina'
        ]:
            if model_name == 'char-seq2seq':
                print('\\multirow{{7}}{{*}}{{\\specialcell{{{}}}}} '.format(
                    sc_txt) + output_strs[model_name])
            else:
                print(output_strs[model_name])
        output_str = '& \\<{}> & Human \\\\'.format(command_gts[0].replace(
            '__SP__', '').replace('_', '\\_').replace('$', '\\$').replace(
                '%', '\\%').replace('{{}}', '\\ttcbs'))
        print(output_str)
        print()

def gen_evaluation_table(dataset, FLAGS, num_examples=-1, interactive=True):
    """
    Generate structure and full command accuracy results on a fixed set of
    dev/test set samples, with the results of multiple models tabulated in the
    same table.

    :param interactive:
        - If set, prompt the user to enter a judgement if a prediction does
          not match any of the ground truths and its correctness has not been
          pre-determined; otherwise, count all predictions that do not match
          any of the ground truths as wrong.
    """
    def add_judgement(data_dir, nl, command, correct_template='',
                      correct_command=''):
        """
        Append a new judgement.
        """
        data_dir = os.path.join(data_dir, 'manual_judgements')
        manual_judgement_path = os.path.join(data_dir,
                                             'manual.evaluations.additional')
        if not os.path.exists(manual_judgement_path):
            with open(manual_judgement_path, 'w') as o_f:
                o_f.write(
                    'description,prediction,template,correct template,correct command\n'
                )
        with open(manual_judgement_path, 'a') as o_f:
            temp = data_tools.cmd2template(command, loose_constraints=True)
            if not correct_template:
                correct_template = 'n'
            if not correct_command:
                correct_command = 'n'
            o_f.write('"{}","{}","{}","{}","{}"\n'.format(
                nl.replace('"', '""'), command.replace('"', '""'),
                temp.replace('"', '""'), correct_template.replace('"', '""'),
                correct_command.replace('"', '""')))
        print('new judgement added to {}'.format(manual_judgement_path))

    # Group dataset
    grouped_dataset = data_utils.group_parallel_data(dataset, use_bucket=True)
    if FLAGS.test:
        model_names, model_predictions = load_all_model_predictions(
            grouped_dataset, FLAGS, top_k=3, tellina=True,
            partial_token_copynet=True, token_seq2seq=False,
            token_copynet=False, char_seq2seq=False, char_copynet=False,
            partial_token_seq2seq=False)
    else:
        model_names, model_predictions = load_all_model_predictions(
            grouped_dataset, FLAGS, top_k=3)

    # Get FIXED dev set samples
    random.seed(100)
    example_ids = list(range(len(grouped_dataset)))
    random.shuffle(example_ids)
    if num_examples > 0:
        sample_ids = example_ids[:num_examples]
    else:
        sample_ids = example_ids

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = \
        load_cached_evaluations(
            os.path.join(FLAGS.data_dir, 'manual_judgements'))

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash \
        else data_tools.paren_parser

    # Interactive manual evaluation interface
    num_s_correct = collections.defaultdict(int)
    num_f_correct = collections.defaultdict(int)
    num_s_top_3_correct = collections.defaultdict(int)
    num_f_top_3_correct = collections.defaultdict(int)
    for exam_id, example_id in enumerate(sample_ids):
        data_group = grouped_dataset[example_id][1]
        sc_txt = data_group[0].sc_txt.strip()
        sc_key = get_example_nl_key(sc_txt)
        command_gts = [dp.tg_txt for dp in data_group]
        command_gt_asts = [data_tools.bash_parser(gt) for gt in command_gts]
        for model_id, model_name in enumerate(model_names):
            predictions = model_predictions[model_id][example_id]
            top_3_s_correct_marked = False
            top_3_f_correct_marked = False
            for i in xrange(min(3, len(predictions))):
                pred_cmd = predictions[i]
                pred_ast = cmd_parser(pred_cmd)
                pred_temp = data_tools.ast2template(pred_ast,
                                                    loose_constraints=True)
                temp_match = tree_dist.one_match(command_gt_asts, pred_ast,
                                                 ignore_arg_value=True)
                str_match = tree_dist.one_match(command_gt_asts, pred_ast,
                                                ignore_arg_value=False)
                # Match ground truths & existing judgements
                command_example_key = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_cmd)
                structure_example_key = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_temp)
                command_eval, structure_eval = '', ''
                if str_match:
                    command_eval = 'y'
                    structure_eval = 'y'
                elif temp_match:
                    structure_eval = 'y'
                if command_eval_cache and command_example_key in command_eval_cache:
                    command_eval = command_eval_cache[command_example_key]
                if structure_eval_cache and structure_example_key in structure_eval_cache:
                    structure_eval = structure_eval_cache[
                        structure_example_key]
                # Prompt for new judgements
                if command_eval != 'y':
                    if structure_eval == 'y':
                        if not command_eval and interactive:
                            print('#{}. {}'.format(exam_id, sc_txt))
                            for j, gt in enumerate(command_gts):
                                print('- GT{}: {}'.format(j, gt))
                            print('> {}'.format(pred_cmd))
                            command_eval = input(
                                'CORRECT COMMAND? [y/reason] ')
                            add_judgement(FLAGS.data_dir, sc_key, pred_cmd,
                                          structure_eval, command_eval)
                            print()
                    else:
                        if not structure_eval and interactive:
                            print('#{}. {}'.format(exam_id, sc_txt))
                            for j, gt in enumerate(command_gts):
                                print('- GT{}: {}'.format(j, gt))
                            print('> {}'.format(pred_cmd))
                            structure_eval = input(
                                'CORRECT STRUCTURE? [y/reason] ')
                            if structure_eval == 'y':
                                command_eval = input(
                                    'CORRECT COMMAND? [y/reason] ')
                            add_judgement(FLAGS.data_dir, sc_key, pred_cmd,
                                          structure_eval, command_eval)
                            print()
                    structure_eval_cache[
                        structure_example_key] = structure_eval
                    command_eval_cache[command_example_key] = command_eval
                if structure_eval == 'y':
                    if i == 0:
                        num_s_correct[model_name] += 1
                    if not top_3_s_correct_marked:
                        num_s_top_3_correct[model_name] += 1
                        top_3_s_correct_marked = True
                if command_eval == 'y':
                    if i == 0:
                        num_f_correct[model_name] += 1
                    if not top_3_f_correct_marked:
                        num_f_top_3_correct[model_name] += 1
                        top_3_f_correct_marked = True

    metrics_names = ['Acc_F_1', 'Acc_F_3', 'Acc_T_1', 'Acc_T_3']
    model_metrics = {}
    for model_name in model_names:
        metrics = [
            num_f_correct[model_name] / len(sample_ids),
            num_f_top_3_correct[model_name] / len(sample_ids),
            num_s_correct[model_name] / len(sample_ids),
            num_s_top_3_correct[model_name] / len(sample_ids)
        ]
        model_metrics[model_name] = metrics
    print_table(model_names, metrics_names, model_metrics)

def string_match(ast1, ast2):
    str1 = ignore_differences(
        data_tools.ast2template(ast1, loose_constraints=True,
                                arg_type_only=False))
    str2 = ignore_differences(
        data_tools.ast2template(ast2, loose_constraints=True,
                                arg_type_only=False))
    return str1 == str2

def template_match(ast1, ast2):
    temp1 = data_tools.ast2template(ast1, loose_constraints=True)
    temp2 = data_tools.ast2template(ast2, loose_constraints=True)
    return temp1 == temp2

def manual_eval(model, dataset, FLAGS, output_dir, num_eval=None):
    num_top1_correct_temp = 0.0
    num_top3_correct_temp = 0.0
    num_top5_correct_temp = 0.0
    num_top10_correct_temp = 0.0
    num_top1_correct = 0.0
    num_top3_correct = 0.0
    num_top5_correct = 0.0
    num_top10_correct = 0.0
    num_evaled = 0

    eval_bash = FLAGS.dataset.startswith("bash")
    use_bucket = False if model == "knn" else True
    tokenizer_selector = 'cm' if FLAGS.explain else 'nl'
    grouped_dataset = data_utils.group_data(
        dataset, use_bucket=use_bucket, use_temp=eval_bash,
        tokenizer_selector=tokenizer_selector)
    if num_eval is None:
        num_eval = len(grouped_dataset)

    cmd_parser = data_tools.bash_parser if FLAGS.dataset.startswith("bash") \
        else data_tools.paren_parser

    o_f = open(os.path.join(output_dir, "manual.eval.results"), 'w')

    rejudge = False
    with DBConnection() as db:
        db.create_schema()
        while num_evaled < len(grouped_dataset):
            nl_strs, tg_strs, nls, search_historys = grouped_dataset[num_evaled]
            nl_str = nl_strs[0].decode('utf-8')
            if num_evaled == num_eval:
                break
            gt_trees = [cmd_parser(cmd) for cmd in tg_strs]
            predictions = db.get_top_k_predictions(model, nl_str, k=10)

            top1_correct_temp, top3_correct_temp, top5_correct_temp, \
                top10_correct_temp = False, False, False, False
            top1_correct, top3_correct, top5_correct, top10_correct = \
                False, False, False, False

            # evaluation ignoring ordering of flags
            print("Example %d (%d)" % (num_evaled + 1, len(tg_strs)))
            o_f.write("Example %d (%d)" % (num_evaled + 1, len(tg_strs)) + "\n")
            print("English: " + nl_str.strip())
            o_f.write("English: " + nl_str.encode('utf-8'))
            for j in xrange(len(tg_strs)):
                print("GT Command %d: " % (j + 1) + tg_strs[j].strip())
                o_f.write("GT Command %d: " % (j + 1) + tg_strs[j].strip() + "\n")

            pred_id = 0
            while pred_id < min(3, len(predictions)):
                pred_cmd, score = predictions[pred_id]
                tree = cmd_parser(pred_cmd)
                print("Prediction {}: {} ({})".format(pred_id + 1, pred_cmd,
                                                      score))
                o_f.write("Prediction {}: {} ({})\n".format(
                    pred_id + 1, pred_cmd, score))
                print()
                pred_temp = data_tools.ast2template(tree,
                                                    loose_constraints=True)
                str_judge = db.get_str_judgement((nl_str, pred_cmd))
                temp_judge = db.get_temp_judgement((nl_str, pred_temp))
                if temp_judge is not None and not rejudge:
                    judgement_str = "y" if temp_judge == 1 \
                        else "n ({})".format(error_types[temp_judge])
                    print("Is the command template correct [y/n]? %s"
                          % judgement_str)
                else:
                    temp_judge = tree_dist.one_match(gt_trees, tree,
                                                     rewrite=False,
                                                     ignore_arg_value=True)
                    if not temp_judge:
                        inp = raw_input(
                            "Is the command template correct [y/n]? ")
                        if inp == "REVERSE":
                            rejudge = True
                        else:
                            if inp == "y":
                                temp_judge = True
                                db.add_temp_judgement((nl_str, pred_temp, 1))
                            else:
                                temp_judge = False
                                error_type = raw_input(
                                    "Error type: \n"
                                    "(2) extra utility \n"
                                    "(3) missing utility \n"
                                    "(4) confused utility \n"
                                    "(5) extra flag \n"
                                    "(6) missing flag \n"
                                    "(7) confused flag \n"
                                    "(8) logic error\n"
                                    "(9) count error\n")
                                db.add_temp_judgement(
                                    (nl_str, pred_temp, int(error_type)))
                            rejudge = False
                    else:
                        print("Is the command template correct [y/n]? y")
                if temp_judge == 1:
                    if pred_id < 1:
                        top1_correct_temp = True
                        top3_correct_temp = True
                        top5_correct_temp = True
                        top10_correct_temp = True
                    elif pred_id < 3:
                        top3_correct_temp = True
                        top5_correct_temp = True
                        top10_correct_temp = True
                    elif pred_id < 5:
                        top5_correct_temp = True
                        top10_correct_temp = True
                    elif pred_id < 10:
                        top10_correct_temp = True
                    o_f.write("C")
                    if str_judge is not None and not rejudge:
                        judgement_str = "y" if str_judge == 1 \
                            else "n ({})".format(error_types[str_judge])
                        print("Is the complete command correct [y/n]? %s"
                              % judgement_str)
                    else:
                        str_judge = tree_dist.one_match(gt_trees, tree,
                                                        rewrite=False,
                                                        ignore_arg_value=False)
                        if not str_judge:
                            inp = raw_input(
                                "Is the complete command correct [y/n]? ")
                            if inp == "REVERSE":
                                rejudge = True
                                continue
                            elif inp == "y":
                                str_judge = True
                                o_f.write("C")
                                db.add_str_judgement((nl_str, pred_cmd, 1))
                            else:
                                str_judge = False
                                o_f.write("W")
                                db.add_str_judgement((nl_str, pred_cmd, 0))
                        else:
                            print("Is the complete command correct [y/n]? y")
                    if str_judge == 1:
                        if pred_id < 1:
                            top1_correct = True
                            top3_correct = True
                            top5_correct = True
                            top10_correct = True
                        elif pred_id < 3:
                            top3_correct = True
                            top5_correct = True
                            top10_correct = True
                        elif pred_id < 5:
                            top5_correct = True
                            top10_correct = True
                        elif pred_id < 10:
                            top10_correct = True
                        o_f.write("C")
                    else:
                        o_f.write("W")
                else:
                    o_f.write("WW")
                o_f.write("\n")
                o_f.write("\n")
                pred_id += 1

            if rejudge:
                num_evaled -= 1
            else:
                num_evaled += 1
                if top1_correct_temp:
                    num_top1_correct_temp += 1
                if top3_correct_temp:
                    num_top3_correct_temp += 1
                if top5_correct_temp:
                    num_top5_correct_temp += 1
                if top10_correct_temp:
                    num_top10_correct_temp += 1
                if top1_correct:
                    num_top1_correct += 1
                if top3_correct:
                    num_top3_correct += 1
                if top5_correct:
                    num_top5_correct += 1
                if top10_correct:
                    num_top10_correct += 1
            rejudge = False
            print()

    print("%d examples evaluated" % num_eval)
    print("Top 1 Template Match Score = %.2f" %
          (num_top1_correct_temp / num_eval))
    print("Top 1 String Match Score = %.2f" % (num_top1_correct / num_eval))
    if len(predictions) > 3:
        print("Top 5 Template Match Score = %.2f" %
              (num_top5_correct_temp / num_eval))
        print("Top 5 String Match Score = %.2f" %
              (num_top5_correct / num_eval))
        print("Top 10 Template Match Score = %.2f" %
              (num_top10_correct_temp / num_eval))
        print("Top 10 String Match Score = %.2f" %
              (num_top10_correct / num_eval))
    print()

    o_f.write("%d examples evaluated" % num_eval + "\n")
    o_f.write("Top 1 Template Match Score = %.2f" %
              (num_top1_correct_temp / num_eval) + "\n")
    o_f.write("Top 1 String Match Score = %.2f" %
              (num_top1_correct / num_eval) + "\n")
    if len(predictions) > 1:
        o_f.write("Top 5 Template Match Score = %.2f" %
                  (num_top5_correct_temp / num_eval) + "\n")
        o_f.write("Top 5 String Match Score = %.2f" %
                  (num_top5_correct / num_eval) + "\n")
        o_f.write("Top 10 Template Match Score = %.2f" %
                  (num_top10_correct_temp / num_eval) + "\n")
        o_f.write("Top 10 String Match Score = %.2f" %
                  (num_top10_correct / num_eval) + "\n")
    o_f.write("\n")

def string_match(ast1, ast2):
    str1 = data_tools.ast2template(ast1, loose_constraints=True,
                                   arg_type_only=False)
    str2 = data_tools.ast2template(ast2, loose_constraints=True,
                                   arg_type_only=False)
    return str1 == str2

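# A small hedged sketch contrasting the two matchers above. The commands are
# illustrative only, and it assumes that the default templating abstracts
# concrete argument values into types (as the ignore_arg_value/arg_type_only
# flags elsewhere in this code suggest).
ast_a = data_tools.bash_parser('find . -name "*.py"')
ast_b = data_tools.bash_parser('find /src -name "*.c"')

print(template_match(ast_a, ast_b))  # expected True: same command structure
print(string_match(ast_a, ast_b))    # expected False: argument values differ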