Example #1
def one_match(asts, ast2, rewrite=False, ignore_arg_value=False):
    if rewrite:
        raise NotImplementedError
    else:
        ast_rewrites = asts
    cmd2 = data_tools.ast2template(ast2, loose_constraints=True,
                                   arg_type_only=ignore_arg_value)
    for ast1 in ast_rewrites:
        cmd1 = data_tools.ast2template(ast1, loose_constraints=True,
                                       arg_type_only=ignore_arg_value)
        if cmd1 == cmd2:
            return True
    return False
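
A minimal usage sketch of one_match (not part of the original example): it assumes the nl2bash-style data_tools module is importable (e.g. from bashlint import data_tools) and uses hypothetical commands.

from bashlint import data_tools  # assumed import path

# Hypothetical ground truths and a prediction.
gt_asts = [data_tools.bash_parser('find . -name "*.txt"'),
           data_tools.bash_parser('find . -iname "*.txt"')]
pred_ast = data_tools.bash_parser('find . -name "*.log"')

# Template-level comparison abstracts argument values: expected True.
print(one_match(gt_asts, pred_ast, ignore_arg_value=True))
# Full-string comparison keeps argument values: expected False.
print(one_match(gt_asts, pred_ast, ignore_arg_value=False))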
Example #2
def one_match(asts, ast2, rewrite=True, ignore_arg_value=False):
    if rewrite:
        with rewrites.DBConnection() as db:
            ast_rewrites = get_rewrites(asts, db)
    else:
        ast_rewrites = asts

    cmd2 = ignore_differences(data_tools.ast2template(
        ast2, loose_constraints=True, arg_type_only=ignore_arg_value))
    for ast1 in ast_rewrites:
        cmd1 = data_tools.ast2template(
            ast1, loose_constraints=True, arg_type_only=ignore_arg_value)
        cmd1 = ignore_differences(cmd1)
        if cmd1 == cmd2:
            return True
    return False
Example #3
def combine_annotations_multi_files():
    """
    Combine multiple annotation files and discard annotations that have conflicts.
    """

    input_dir = sys.argv[1]

    template_evals = {}
    command_evals = {}
    discarded_keys = set({})

    for in_csv in os.listdir(input_dir):
        in_csv_path = os.path.join(input_dir, in_csv)
        with open(in_csv_path) as f:
            reader = csv.DictReader(f)
            current_description = ''
            for row in reader:
                template_eval = normalize_judgement(row['correct template'])
                command_eval = normalize_judgement(row['correct command'])
                description = get_example_nl_key(row['description'])
                if description.strip():
                    current_description = description
                else:
                    description = current_description
                prediction = row['prediction']
                example_key = '{}<NL_PREDICTION>{}'.format(
                    description, prediction)
                if example_key in template_evals and template_evals[
                        example_key] != template_eval:
                    discarded_keys.add(example_key)
                    continue
                if example_key in command_evals and command_evals[
                        example_key] != command_eval:
                    discarded_keys.add(example_key)
                    continue
                template_evals[example_key] = template_eval
                command_evals[example_key] = command_eval
            print('{} read ({} manually annotated examples, {} discarded)'.
                  format(in_csv_path, len(template_evals),
                         len(discarded_keys)))

    # Write to new file
    assert (len(template_evals) == len(command_evals))
    with open('manual_annotations.additional', 'w') as o_f:
        o_f.write(
            'description,prediction,template,correct template,correct command\n'
        )
        for key in sorted(template_evals.keys()):
            if key in discarded_keys:
                continue
            description, prediction = key.split('<NL_PREDICTION>')
            template_eval = template_evals[key]
            command_eval = command_evals[key]
            pred_tree = data_tools.bash_parser(prediction)
            pred_temp = data_tools.ast2template(pred_tree,
                                                loose_constraints=True)
            o_f.write('"{}","{}","{}",{},{}\n'.format(
                description.replace('"', '""'), prediction.replace('"', '""'),
                pred_temp.replace('"', '""'), template_eval, command_eval))
Example #4
def populate_command_template():
    for cmd in Command.objects.all():
        if len(cmd.str) > 600:
            cmd.delete()
        else:
            ast = data_tools.bash_parser(cmd.str)
            template = data_tools.ast2template(ast, loose_constraints=True)
            cmd.template = template
            cmd.save()
Example #5
def run():
    sqlite_filename = sys.argv[1]
    url_prefix = 'https://stackoverflow.com/questions/'

    urls = {}
    commands = {}

    with sqlite3.connect(sqlite_filename, detect_types=sqlite3.PARSE_DECLTYPES) as db:
        count = 0
        for post_id, answer_body in db.cursor().execute("""
                SELECT questions.Id, answers.Body FROM questions, answers
                WHERE questions.Id = answers.ParentId
                ORDER BY questions.Score DESC"""):
            print(post_id)
            for code_block in extract_code(answer_body):
                for cmd in extract_oneliner_from_code(code_block):
                    print('command string: {}'.format(cmd))
                    ast = data_tools.bash_parser(cmd)
                    if not ast:
                        continue
                    utilities = data_tools.get_utilities(ast)
                    for utility in utilities:
                        if utility in bash.top_100_utilities:
                            print('extracted: {}, {}'.format(utility, cmd))
                            temp = data_tools.ast2template(ast, loose_constraints=True)
                            if utility not in commands:
                                commands[utility] = {}
                                commands[utility][temp] = cmd
                                urls[utility] = {'{}{}'.format(url_prefix, post_id)}
                            else:
                                if len(commands[utility]) >= NUM_COMMAND_THRESHOLD:
                                    continue
                                if temp not in commands[utility]:
                                    commands[utility][temp] = cmd
                                    urls[utility].add('{}{}'.format(url_prefix, post_id))
            count += 1
            if count % 1000 == 0:
                completed = True
                for utility in bash.top_100_utilities:
                    if utility not in commands or len(commands[utility]) < NUM_COMMAND_THRESHOLD:
                        completed = False
                    else:
                        print('{} collection done.'.format(utility))

                if completed:
                    break

    with open('stackoverflow.urls', 'wb') as o_f:
        pickle.dump(urls, o_f)
    with open('stackoverflow.commands', 'wb') as o_f:
        pickle.dump(commands, o_f)

    for utility in commands:
        print('{} ({})'.format(utility, len(commands[utility])))
        for cmd in commands[utility]:
            print(cmd)
Example #6
def get_rewrites(self, ast):
    rewrites = set([ast])
    s1 = data_tools.ast2template(ast, loose_constraints=True)
    c = self.cursor
    for s1, s2 in c.execute("SELECT s1, s2 FROM Rewrites WHERE s1 = ?",
                            (s1, )):
        rw = rewrite(ast, s2)
        if rw is not None:
            rewrites.add(rw)
    return rewrites
Example #7
def Cust_Cmd_Tokenizer(String, parse="Template"):
    """Custom command tokenizer tailored to our needs."""
    if parse == "Norm":
        # Tokenize the raw command string directly.
        return cm_to_partial_tokens(String,
                                    tokenizer=data_tools.bash_tokenizer)
    elif parse == "Template":
        AST = data_tools.bash_parser(String)
        Template = data_tools.ast2template(AST, ignore_flag_order=False)
        Template_Tokens_List = Template.split(" ")
        return Template_Tokens_List
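
An illustrative call of the tokenizer above (the command string is hypothetical, not from the original source):

# With parse="Template" the returned tokens come from the command's
# template form rather than from the raw string.
tokens = Cust_Cmd_Tokenizer('grep -l "pattern" *.py', parse="Template")
print(tokens)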
Example #8
def get_command(command_str):
    command_str = command_str.strip()
    if Command.objects.filter(str=command_str).exists():
        cmd = Command.objects.get(str=command_str)
    else:
        cmd = Command.objects.create(str=command_str)
        ast = data_tools.bash_parser(command_str)
        for utility in data_tools.get_utilities(ast):
            cmd.tags.add(get_tag(utility))
        template = data_tools.ast2template(ast, loose_constraints=True)
        cmd.template = template
        cmd.save()
    return cmd
Example #9
def get_manual_evaluation_metrics(grouped_dataset,
                                  prediction_list,
                                  FLAGS,
                                  num_examples=-1,
                                  interactive=True,
                                  verbose=True):

    if len(grouped_dataset) != len(prediction_list):
        raise ValueError("ground truth and predictions length must be equal: "
                         "{} vs. {}".format(len(grouped_dataset),
                                            len(prediction_list)))

    # Get dev set samples (fixed)
    random.seed(100)
    example_ids = list(range(len(grouped_dataset)))
    random.shuffle(example_ids)
    if num_examples > 0:
        sample_ids = example_ids[:num_examples]
    else:
        sample_ids = example_ids

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = \
        load_cached_evaluations(
            os.path.join(FLAGS.data_dir, 'manual_judgements'), verbose=True)

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash \
        else data_tools.paren_parser

    # Interactive manual evaluation
    num_t_top_1_correct = 0.0
    num_f_top_1_correct = 0.0
    num_t_top_3_correct = 0.0
    num_f_top_3_correct = 0.0

    for exam_id, example_id in enumerate(sample_ids):
        data_group = grouped_dataset[example_id][1]
        sc_txt = data_group[0].sc_txt.strip()
        sc_key = get_example_nl_key(sc_txt)
        command_gts = [dp.tg_txt for dp in data_group]
        command_gt_asts = [data_tools.bash_parser(gt) for gt in command_gts]
        predictions = prediction_list[example_id]
        top_3_s_correct_marked = False
        top_3_f_correct_marked = False
        for i in xrange(min(3, len(predictions))):
            pred_cmd = predictions[i]
            pred_ast = cmd_parser(pred_cmd)
            pred_temp = data_tools.ast2template(pred_ast,
                                                loose_constraints=True)
            temp_match = tree_dist.one_match(command_gt_asts,
                                             pred_ast,
                                             ignore_arg_value=True)
            str_match = tree_dist.one_match(command_gt_asts,
                                            pred_ast,
                                            ignore_arg_value=False)
            # Match ground truths & existing judgements
            command_example_key = '{}<NL_PREDICTION>{}'.format(
                sc_key, pred_cmd)
            structure_example_key = '{}<NL_PREDICTION>{}'.format(
                sc_key, pred_temp)
            command_eval, structure_eval = '', ''
            if str_match:
                command_eval = 'y'
                structure_eval = 'y'
            elif temp_match:
                structure_eval = 'y'
            if command_eval_cache and command_example_key in command_eval_cache:
                command_eval = command_eval_cache[command_example_key]
            if structure_eval_cache and structure_example_key in structure_eval_cache:
                structure_eval = structure_eval_cache[structure_example_key]
            # Prompt for new judgements
            if command_eval != 'y':
                if structure_eval == 'y':
                    if not command_eval and interactive:
                        print('#{}. {}'.format(exam_id, sc_txt))
                        for j, gt in enumerate(command_gts):
                            print('- GT{}: {}'.format(j, gt))
                        print('> {}'.format(pred_cmd))
                        command_eval = input('CORRECT COMMAND? [y/reason] ')
                        add_judgement(FLAGS.data_dir, sc_txt, pred_cmd,
                                      structure_eval, command_eval)
                        print()
                else:
                    if not structure_eval and interactive:
                        print('#{}. {}'.format(exam_id, sc_txt))
                        for j, gt in enumerate(command_gts):
                            print('- GT{}: {}'.format(j, gt))
                        print('> {}'.format(pred_cmd))
                        structure_eval = input(
                            'CORRECT STRUCTURE? [y/reason] ')
                        if structure_eval == 'y':
                            command_eval = input(
                                'CORRECT COMMAND? [y/reason] ')
                        add_judgement(FLAGS.data_dir, sc_txt, pred_cmd,
                                      structure_eval, command_eval)
                        print()
                structure_eval_cache[structure_example_key] = structure_eval
                command_eval_cache[command_example_key] = command_eval
            if structure_eval == 'y':
                if i == 0:
                    num_t_top_1_correct += 1
                if not top_3_s_correct_marked:
                    num_t_top_3_correct += 1
                    top_3_s_correct_marked = True
            if command_eval == 'y':
                if i == 0:
                    num_f_top_1_correct += 1
                if not top_3_f_correct_marked:
                    num_f_top_3_correct += 1
                    top_3_f_correct_marked = True

    metrics = {}
    acc_f_1 = num_f_top_1_correct / len(sample_ids)
    acc_f_3 = num_f_top_3_correct / len(sample_ids)
    acc_t_1 = num_t_top_1_correct / len(sample_ids)
    acc_t_3 = num_t_top_3_correct / len(sample_ids)
    metrics['acc_f'] = [acc_f_1, acc_f_3]
    metrics['acc_t'] = [acc_t_1, acc_t_3]

    if verbose:
        print('{} examples evaluated'.format(len(sample_ids)))
        print('Top 1 Command Acc = {:.3f}'.format(acc_f_1))
        print('Top 3 Command Acc = {:.3f}'.format(acc_f_3))
        print('Top 1 Template Acc = {:.3f}'.format(acc_t_1))
        print('Top 3 Template Acc = {:.3f}'.format(acc_t_3))
    return metrics
Example #10
def gen_manual_evaluation_csv_single_model(dataset, FLAGS):
    """
    Generate .csv spreadsheet for manual evaluation on dev/test set
    examples for a specific model.
    """
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)

    # Load model predictions
    model_subdir, decode_sig = graph_utils.get_decode_signature(FLAGS)
    model_dir = os.path.join(FLAGS.model_root_dir, model_subdir)
    prediction_list = load_predictions(model_dir, decode_sig, top_k=3)
    if len(grouped_dataset) != len(prediction_list):
        raise ValueError("ground truth list and prediction list length must "
                         "be equal: {} vs. {}".format(len(grouped_dataset),
                                                      len(prediction_list)))

    # Load additional ground truths
    template_translations, command_translations = load_cached_correct_translations(
        FLAGS.data_dir)

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = load_cached_evaluations(
        os.path.join(FLAGS.data_dir, 'manual_judgements'))

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash else data_tools.paren_parser

    output_path = os.path.join(model_dir, 'manual.evaluations.single.model')
    with open(output_path, 'w') as o_f:
        # write spreadsheet header
        o_f.write('id,description,command,correct template,correct command\n')
        for example_id in range(len(grouped_dataset)):
            data_group = grouped_dataset[example_id][1]
            sc_txt = data_group[0].sc_txt.strip()
            sc_key = get_example_nl_key(sc_txt)
            command_gts = [dp.tg_txt for dp in data_group]
            command_gts = set(command_gts + command_translations[sc_key])
            command_gt_asts = [
                data_tools.bash_parser(cmd) for cmd in command_gts
            ]
            template_gts = [
                data_tools.cmd2template(cmd, loose_constraints=True)
                for cmd in command_gts
            ]
            template_gts = set(template_gts + template_translations[sc_key])
            template_gt_asts = [
                data_tools.bash_parser(temp) for temp in template_gts
            ]
            predictions = prediction_list[example_id]
            for i in xrange(3):
                if i >= len(predictions):
                    o_f.write(',,,n,n\n')
                    continue
                pred_cmd = predictions[i]
                pred_tree = cmd_parser(pred_cmd)
                pred_temp = data_tools.ast2template(pred_tree,
                                                    loose_constraints=True)
                temp_match = tree_dist.one_match(template_gt_asts,
                                                 pred_tree,
                                                 ignore_arg_value=True)
                str_match = tree_dist.one_match(command_gt_asts,
                                                pred_tree,
                                                ignore_arg_value=False)
                # Match ground truths & existing judgements
                command_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_cmd)
                structure_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_temp)
                command_eval, structure_eval = '', ''
                if str_match:
                    command_eval = 'y'
                    structure_eval = 'y'
                elif temp_match:
                    structure_eval = 'y'
                if command_eval_cache and \
                        command_example_sig in command_eval_cache:
                    command_eval = command_eval_cache[command_example_sig]
                if structure_eval_cache and \
                        structure_example_sig in structure_eval_cache:
                    structure_eval = structure_eval_cache[
                        structure_example_sig]
                if i == 0:
                    o_f.write('{},"{}","{}",{},{}\n'.format(
                        example_id, sc_txt.replace('"', '""'),
                        pred_cmd.replace('"', '""'), structure_eval,
                        command_eval))
                else:
                    o_f.write(',,"{}",{},{}\n'.format(
                        pred_cmd.replace('"', '""'), structure_eval,
                        command_eval))
    print('manual evaluation spreadsheet saved to {}'.format(output_path))
Example #11
def template_match(ast1, ast2):
    temp1 = ignore_differences(
        data_tools.ast2template(ast1, loose_constraints=True))
    temp2 = ignore_differences(
        data_tools.ast2template(ast2, loose_constraints=True))
    return temp1 == temp2
Example #12
def gen_manual_evaluation_csv(dataset, FLAGS, num_examples=100):
    """
    Generate .csv spreadsheet for manual evaluation on a fixed set of test/dev
    examples; predictions of different models are listed side-by-side.
    """
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)

    model_names, model_predictions = load_all_model_predictions(
        grouped_dataset, FLAGS, top_k=3)

    # Get FIXED dev set samples
    random.seed(100)
    example_ids = list(range(len(grouped_dataset)))
    random.shuffle(example_ids)
    sample_ids = example_ids[num_examples:num_examples + 100]

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = \
        load_cached_evaluations(
            os.path.join(FLAGS.data_dir, 'manual_judgements'))

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash \
        else data_tools.paren_parser

    output_path = os.path.join(FLAGS.data_dir, 'manual.evaluations.csv')
    with open(output_path, 'w') as o_f:
        o_f.write('example_id, description, ground_truth, model, prediction, '
                  'correct template, correct command\n')
        for example_id in sample_ids:
            data_group = grouped_dataset[example_id][1]
            sc_txt = data_group[0].sc_txt.strip()
            sc_key = get_example_nl_key(sc_txt)
            command_gts = [dp.tg_txt for dp in data_group]
            command_gt_asts = [
                data_tools.bash_parser(gt) for gt in command_gts
            ]
            for model_id, model_name in enumerate(model_names):
                predictions = model_predictions[model_id][example_id]
                for i in xrange(min(3, len(predictions))):
                    if model_id == 0 and i == 0:
                        output_str = '{},"{}",'.format(
                            example_id, sc_txt.replace('"', '""'))
                    else:
                        output_str = ',,'
                    pred_cmd = predictions[i]
                    pred_tree = cmd_parser(pred_cmd)
                    pred_temp = data_tools.ast2template(pred_tree,
                                                        loose_constraints=True)
                    temp_match = tree_dist.one_match(command_gt_asts,
                                                     pred_tree,
                                                     ignore_arg_value=True)
                    str_match = tree_dist.one_match(command_gt_asts,
                                                    pred_tree,
                                                    ignore_arg_value=False)
                    if (model_id * min(3, len(predictions)) +
                            i) < len(command_gts):
                        output_str += '"{}",'.format(
                            command_gts[model_id * min(3, len(predictions)) +
                                        i].strip().replace('"', '""'))
                    else:
                        output_str += ','
                    output_str += '{},"{}",'.format(
                        model_name, pred_cmd.replace('"', '""'))

                    command_example_sig = '{}<NL_PREDICTION>{}'.format(
                        sc_key, pred_cmd)
                    structure_example_sig = '{}<NL_PREDICTION>{}'.format(
                        sc_key, pred_temp)
                    command_eval, structure_eval = '', ''
                    if str_match:
                        command_eval = 'y'
                        structure_eval = 'y'
                    elif temp_match:
                        structure_eval = 'y'
                    if command_eval_cache and \
                            command_example_sig in command_eval_cache:
                        command_eval = command_eval_cache[command_example_sig]
                    if structure_eval_cache and \
                            structure_example_sig in structure_eval_cache:
                        structure_eval = structure_eval_cache[
                            structure_example_sig]
                    output_str += '{},{}'.format(structure_eval, command_eval)
                    o_f.write('{}\n'.format(output_str))

    print('Manual evaluation results saved to {}'.format(output_path))
Example #13
def tabulate_example_predictions(dataset, FLAGS, num_examples=100):
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)

    model_names, model_predictions = load_all_model_predictions(
        grouped_dataset, FLAGS, top_k=1)

    # Get FIXED dev set samples
    random.seed(100)
    example_ids = list(range(len(grouped_dataset)))
    random.shuffle(example_ids)
    sample_ids = example_ids[:num_examples]

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = \
        load_cached_evaluations(
            os.path.join(FLAGS.data_dir, 'manual_judgements'))

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash \
        else data_tools.paren_parser

    model_name_pt = {
        'token-seq2seq': 'T-Seq2Seq',
        'tellina': 'Tellina',
        'token-copynet': 'T-CopyNet',
        'partial.token-seq2seq': 'ST-Seq2Seq',
        'partial.token-copynet': 'ST-CopyNet',
        'char-seq2seq': 'C-Seq2Seq',
        'char-copynet': 'C-CopyNet'
    }

    for example_id in sample_ids:
        print('Example {}'.format(example_id))
        data_group = grouped_dataset[example_id][1]
        sc_txt = data_group[0].sc_txt.strip()
        sc_key = get_example_nl_key(sc_txt)
        command_gts = [dp.tg_txt for dp in data_group]
        command_gt_asts = [data_tools.bash_parser(gt) for gt in command_gts]
        output_strs = {}
        for model_id, model_name in enumerate(model_names):
            predictions = model_predictions[model_id][example_id]
            for i in xrange(min(3, len(predictions))):
                pred_cmd = predictions[i]
                pred_tree = cmd_parser(pred_cmd)
                pred_temp = data_tools.ast2template(pred_tree,
                                                    loose_constraints=True)
                temp_match = tree_dist.one_match(command_gt_asts,
                                                 pred_tree,
                                                 ignore_arg_value=True)
                str_match = tree_dist.one_match(command_gt_asts,
                                                pred_tree,
                                                ignore_arg_value=False)

                output_str = '& \\<{}> & {}'.format(
                    pred_cmd.replace('__SP__', '').replace('_', '\\_').replace(
                        '$', '\\$').replace('%',
                                            '\\%').replace('{{}}', '\\ttcbs'),
                    model_name_pt[model_name])

                command_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_cmd)
                structure_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_temp)
                command_eval, structure_eval = '', ''
                if str_match:
                    command_eval = 'y'
                    structure_eval = 'y'
                elif temp_match:
                    structure_eval = 'y'
                if command_eval_cache and \
                        command_example_sig in command_eval_cache:
                    command_eval = command_eval_cache[command_example_sig]
                if structure_eval_cache and \
                        structure_example_sig in structure_eval_cache:
                    structure_eval = structure_eval_cache[
                        structure_example_sig]
                output_str += ', {},{} \\\\'.format(structure_eval,
                                                    command_eval)
            output_strs[model_name] = output_str
        for model_name in [
                'char-seq2seq', 'char-copynet', 'token-seq2seq',
                'token-copynet', 'partial.token-seq2seq',
                'partial.token-copynet', 'tellina'
        ]:
            if model_name == 'char-seq2seq':
                print('\\multirow{{7}}{{*}}{{\\specialcell{{{}}}}} '.format(
                    sc_txt) + output_strs[model_name])
            else:
                print(output_strs[model_name])
        output_str = '& \\<{}> & Human \\\\'.format(command_gts[0].replace(
            '__SP__', '').replace('_', '\\_').replace('$', '\\$').replace(
                '%', '\\%').replace('{{}}', '\\ttcbs'))
        print(output_str)
        print()
Example #14
def gen_evaluation_table(dataset, FLAGS, num_examples=-1, interactive=True):
    """
    Generate structure and full command accuracy results on a fixed set of dev/test
        set samples, with the results of multiple models tabulated in the same table.
    :param interactive:
        - If set, prompt the user to enter judgement if a prediction does not
            match any of the groundtruths and the correctness of the prediction
            has not been pre-determined;
          Otherwise, count all predictions that do not match any of the
          groundtruths as wrong.
    """
    def add_judgement(data_dir,
                      nl,
                      command,
                      correct_template='',
                      correct_command=''):
        """
        Append a new judgement
        """
        data_dir = os.path.join(data_dir, 'manual_judgements')
        manual_judgement_path = os.path.join(data_dir,
                                             'manual.evaluations.additional')
        if not os.path.exists(manual_judgement_path):
            with open(manual_judgement_path, 'w') as o_f:
                o_f.write(
                    'description,prediction,template,correct template,correct command\n'
                )
        with open(manual_judgement_path, 'a') as o_f:
            temp = data_tools.cmd2template(command, loose_constraints=True)
            if not correct_template:
                correct_template = 'n'
            if not correct_command:
                correct_command = 'n'
            o_f.write('"{}","{}","{}","{}","{}"\n'.format(
                nl.replace('"', '""'), command.replace('"', '""'),
                temp.replace('"', '""'), correct_template.replace('"', '""'),
                correct_command.replace('"', '""')))
        print('new judgement added to {}'.format(manual_judgement_path))

    # Group dataset
    grouped_dataset = data_utils.group_parallel_data(dataset, use_bucket=True)

    if FLAGS.test:
        model_names, model_predictions = load_all_model_predictions(
            grouped_dataset,
            FLAGS,
            top_k=3,
            tellina=True,
            partial_token_copynet=True,
            token_seq2seq=False,
            token_copynet=False,
            char_seq2seq=False,
            char_copynet=False,
            partial_token_seq2seq=False)
    else:
        model_names, model_predictions = load_all_model_predictions(
            grouped_dataset, FLAGS, top_k=3)

    # Get FIXED dev set samples
    random.seed(100)
    example_ids = list(range(len(grouped_dataset)))
    random.shuffle(example_ids)
    if num_examples > 0:
        sample_ids = example_ids[:num_examples]
    else:
        sample_ids = example_ids

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = \
        load_cached_evaluations(
            os.path.join(FLAGS.data_dir, 'manual_judgements'))

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash \
        else data_tools.paren_parser

    # Interactive manual evaluation interface
    num_s_correct = collections.defaultdict(int)
    num_f_correct = collections.defaultdict(int)
    num_s_top_3_correct = collections.defaultdict(int)
    num_f_top_3_correct = collections.defaultdict(int)
    for exam_id, example_id in enumerate(sample_ids):
        data_group = grouped_dataset[example_id][1]
        sc_txt = data_group[0].sc_txt.strip()
        sc_key = get_example_nl_key(sc_txt)
        command_gts = [dp.tg_txt for dp in data_group]
        command_gt_asts = [data_tools.bash_parser(gt) for gt in command_gts]
        for model_id, model_name in enumerate(model_names):
            predictions = model_predictions[model_id][example_id]
            top_3_s_correct_marked = False
            top_3_f_correct_marked = False
            for i in xrange(min(3, len(predictions))):
                pred_cmd = predictions[i]
                pred_ast = cmd_parser(pred_cmd)
                pred_temp = data_tools.ast2template(pred_ast,
                                                    loose_constraints=True)
                temp_match = tree_dist.one_match(command_gt_asts,
                                                 pred_ast,
                                                 ignore_arg_value=True)
                str_match = tree_dist.one_match(command_gt_asts,
                                                pred_ast,
                                                ignore_arg_value=False)
                # Match ground truths & existing judgements
                command_example_key = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_cmd)
                structure_example_key = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_temp)
                command_eval, structure_eval = '', ''
                if str_match:
                    command_eval = 'y'
                    structure_eval = 'y'
                elif temp_match:
                    structure_eval = 'y'
                if command_eval_cache and command_example_key in command_eval_cache:
                    command_eval = command_eval_cache[command_example_key]
                if structure_eval_cache and structure_example_key in structure_eval_cache:
                    structure_eval = structure_eval_cache[
                        structure_example_key]
                # Prompt for new judgements
                if command_eval != 'y':
                    if structure_eval == 'y':
                        if not command_eval and interactive:
                            print('#{}. {}'.format(exam_id, sc_txt))
                            for j, gt in enumerate(command_gts):
                                print('- GT{}: {}'.format(j, gt))
                            print('> {}'.format(pred_cmd))
                            command_eval = input(
                                'CORRECT COMMAND? [y/reason] ')
                            add_judgement(FLAGS.data_dir, sc_key, pred_cmd,
                                          structure_eval, command_eval)
                            print()
                    else:
                        if not structure_eval and interactive:
                            print('#{}. {}'.format(exam_id, sc_txt))
                            for j, gt in enumerate(command_gts):
                                print('- GT{}: {}'.format(j, gt))
                            print('> {}'.format(pred_cmd))
                            structure_eval = input(
                                'CORRECT STRUCTURE? [y/reason] ')
                            if structure_eval == 'y':
                                command_eval = input(
                                    'CORRECT COMMAND? [y/reason] ')
                            add_judgement(FLAGS.data_dir, sc_key, pred_cmd,
                                          structure_eval, command_eval)
                            print()
                    structure_eval_cache[
                        structure_example_key] = structure_eval
                    command_eval_cache[command_example_key] = command_eval
                if structure_eval == 'y':
                    if i == 0:
                        num_s_correct[model_name] += 1
                    if not top_3_s_correct_marked:
                        num_s_top_3_correct[model_name] += 1
                        top_3_s_correct_marked = True
                if command_eval == 'y':
                    if i == 0:
                        num_f_correct[model_name] += 1
                    if not top_3_f_correct_marked:
                        num_f_top_3_correct[model_name] += 1
                        top_3_f_correct_marked = True
    metrics_names = ['Acc_F_1', 'Acc_F_3', 'Acc_T_1', 'Acc_T_3']
    model_metrics = {}
    for model_name in model_names:
        metrics = [
            num_f_correct[model_name] / len(sample_ids),
            num_f_top_3_correct[model_name] / len(sample_ids),
            num_s_correct[model_name] / len(sample_ids),
            num_s_top_3_correct[model_name] / len(sample_ids)
        ]
        model_metrics[model_name] = metrics
    print_table(model_names, metrics_names, model_metrics)
Example #15
def string_match(ast1, ast2):
    str1 = ignore_differences(
        data_tools.ast2template(ast1, loose_constraints=True, arg_type_only=False))
    str2 = ignore_differences(
        data_tools.ast2template(ast2, loose_constraints=True, arg_type_only=False))
    return str1 == str2
Example #16
def template_match(ast1, ast2):
    temp1 = data_tools.ast2template(ast1, loose_constraints=True)
    temp2 = data_tools.ast2template(ast2, loose_constraints=True)
    return temp1 == temp2
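
A short usage sketch (hypothetical commands; assumes data_tools is imported as in the examples above): two commands that differ only in argument values share a template, so the match succeeds.

ast_a = data_tools.bash_parser('rm -f /tmp/a.log')
ast_b = data_tools.bash_parser('rm -f /var/log/old.log')
print(template_match(ast_a, ast_b))  # expected: True (same structure, different arguments)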
Example #17
def manual_eval(model, dataset, FLAGS, output_dir, num_eval=None):
    num_top1_correct_temp = 0.0
    num_top3_correct_temp = 0.0
    num_top5_correct_temp = 0.0
    num_top10_correct_temp = 0.0
    num_top1_correct = 0.0
    num_top3_correct = 0.0
    num_top5_correct = 0.0
    num_top10_correct = 0.0
    num_evaled = 0

    eval_bash = FLAGS.dataset.startswith("bash")
    use_bucket = False if model == "knn" else True
    tokenizer_selector = 'cm' if FLAGS.explain else 'nl'
    grouped_dataset = data_utils.group_data(
        dataset,
        use_bucket=use_bucket,
        use_temp=eval_bash,
        tokenizer_selector=tokenizer_selector)

    if num_eval is None:
        num_eval = len(grouped_dataset)

    cmd_parser = data_tools.bash_parser if FLAGS.dataset.startswith("bash") \
        else data_tools.paren_parser

    o_f = open(os.path.join(output_dir, "manual.eval.results"), 'w')

    rejudge = False

    with DBConnection() as db:
        db.create_schema()
        while num_evaled < len(grouped_dataset):
            nl_strs, tg_strs, nls, search_historys = grouped_dataset[
                num_evaled]
            nl_str = nl_strs[0].decode('utf-8')

            if num_evaled == num_eval:
                break

            gt_trees = [cmd_parser(cmd) for cmd in tg_strs]

            predictions = db.get_top_k_predictions(model, nl_str, k=10)

            top1_correct_temp, top3_correct_temp, top5_correct_temp, \
                top10_correct_temp = False, False, False, False
            top1_correct, top3_correct, top5_correct, top10_correct = \
                False, False, False, False

            # evaluation ignoring ordering of flags
            print("Example %d (%d)" % (num_evaled + 1, len(tg_strs)))
            o_f.write("Example %d (%d)" % (num_evaled + 1, len(tg_strs)) +
                      "\n")
            print("English: " + nl_str.strip())
            o_f.write("English: " + nl_str.encode('utf-8'))
            for j in xrange(len(tg_strs)):
                print("GT Command %d: " % (j + 1) + tg_strs[j].strip())
                o_f.write("GT Command %d: " % (j + 1) + tg_strs[j].strip() +
                          "\n")

            pred_id = 0
            while pred_id < min(3, len(predictions)):
                pred_cmd, score = predictions[pred_id]
                tree = cmd_parser(pred_cmd)
                print("Prediction {}: {} ({})".format(pred_id + 1, pred_cmd,
                                                      score))
                o_f.write("Prediction {}: {} ({})\n".format(
                    pred_id + 1, pred_cmd, score))
                print()
                pred_temp = data_tools.ast2template(tree,
                                                    loose_constraints=True)
                str_judge = db.get_str_judgement((nl_str, pred_cmd))
                temp_judge = db.get_temp_judgement((nl_str, pred_temp))
                if temp_judge is not None and not rejudge:
                    judgement_str = "y" if temp_judge == 1 \
                        else "n ({})".format(error_types[temp_judge])
                    print("Is the command template correct [y/n]? %s" %
                          judgement_str)
                else:
                    temp_judge = tree_dist.one_match(gt_trees,
                                                     tree,
                                                     rewrite=False,
                                                     ignore_arg_value=True)
                    if not temp_judge:
                        inp = raw_input(
                            "Is the command template correct [y/n]? ")
                        if inp == "REVERSE":
                            rejudge = True
                        else:
                            if inp == "y":
                                temp_judge = True
                                db.add_temp_judgement((nl_str, pred_temp, 1))
                            else:
                                temp_judge = False
                                error_type = raw_input(
                                    "Error type: \n"
                                    "(2) extra utility \n"
                                    "(3) missing utility \n"
                                    "(4) confused utility \n"
                                    "(5) extra flag \n"
                                    "(6) missing flag \n"
                                    "(7) confused flag \n"
                                    "(8) logic error\n"
                                    "(9) count error\n")
                                db.add_temp_judgement(
                                    (nl_str, pred_temp, int(error_type)))
                            rejudge = False
                    else:
                        print("Is the command template correct [y/n]? y")
                if temp_judge == 1:
                    if pred_id < 1:
                        top1_correct_temp = True
                        top3_correct_temp = True
                        top5_correct_temp = True
                        top10_correct_temp = True
                    elif pred_id < 3:
                        top3_correct_temp = True
                        top5_correct_temp = True
                        top10_correct_temp = True
                    elif pred_id < 5:
                        top5_correct_temp = True
                        top10_correct_temp = True
                    elif pred_id < 10:
                        top10_correct_temp = True
                    o_f.write("C")
                    if str_judge is not None and not rejudge:
                        judgement_str = "y" if str_judge == 1 \
                            else "n ({})".format(error_types[str_judge])
                        print("Is the complete command correct [y/n]? %s" %
                              judgement_str)
                    else:
                        str_judge = tree_dist.one_match(gt_trees,
                                                        tree,
                                                        rewrite=False,
                                                        ignore_arg_value=False)
                        if not str_judge:
                            inp = raw_input(
                                "Is the complete command correct [y/n]? ")
                            if inp == "REVERSE":
                                rejudge = True
                                continue
                            elif inp == "y":
                                str_judge = True
                                o_f.write("C")
                                db.add_str_judgement((nl_str, pred_cmd, 1))
                            else:
                                str_judge = False
                                o_f.write("W")
                                db.add_str_judgement((nl_str, pred_cmd, 0))
                        else:
                            print("Is the complete command correct [y/n]? y")
                    if str_judge == 1:
                        if pred_id < 1:
                            top1_correct = True
                            top3_correct = True
                            top5_correct = True
                            top10_correct = True
                        elif pred_id < 3:
                            top3_correct = True
                            top5_correct = True
                            top10_correct = True
                        elif pred_id < 5:
                            top5_correct = True
                            top10_correct = True
                        elif pred_id < 10:
                            top10_correct = True
                        o_f.write("C")
                    else:
                        o_f.write("W")
                else:
                    o_f.write("WW")

                o_f.write("\n")
                o_f.write("\n")

                pred_id += 1

            if rejudge:
                num_evaled -= 1
            else:
                num_evaled += 1
                if top1_correct_temp:
                    num_top1_correct_temp += 1
                if top3_correct_temp:
                    num_top3_correct_temp += 1
                if top5_correct_temp:
                    num_top5_correct_temp += 1
                if top10_correct_temp:
                    num_top10_correct_temp += 1
                if top1_correct:
                    num_top1_correct += 1
                if top3_correct:
                    num_top3_correct += 1
                if top5_correct:
                    num_top5_correct += 1
                if top10_correct:
                    num_top10_correct += 1

            rejudge = False

            print()

    print("%d examples evaluated" % num_eval)
    print("Top 1 Template Match Score = %.2f" %
          (num_top1_correct_temp / num_eval))
    print("Top 1 String Match Score = %.2f" % (num_top1_correct / num_eval))
    if len(predictions) > 3:
        print("Top 5 Template Match Score = %.2f" %
              (num_top5_correct_temp / num_eval))
        print("Top 5 String Match Score = %.2f" %
              (num_top5_correct / num_eval))
        print("Top 10 Template Match Score = %.2f" %
              (num_top10_correct_temp / num_eval))
        print("Top 10 String Match Score = %.2f" %
              (num_top10_correct / num_eval))
    print()

    o_f.write("%d examples evaluated" % num_eval + "\n")
    o_f.write("Top 1 Template MatchScore = %.2f" %
              (num_top1_correct_temp / num_eval) + "\n")
    o_f.write("Top 1 String Match Score = %.2f" %
              (num_top1_correct / num_eval) + "\n")
    if len(predictions) > 1:
        o_f.write("Top 5 Template Match Score = %.2f" %
                  (num_top5_correct_temp / num_eval) + "\n")
        o_f.write("Top 5 String Match Score = %.2f" %
                  (num_top5_correct / num_eval) + "\n")
        o_f.write("Top 10 Template Match Score = %.2f" %
                  (num_top10_correct_temp / num_eval) + "\n")
        o_f.write("Top 10 String Match Score = %.2f" %
                  (num_top10_correct / num_eval) + "\n")
    o_f.write("\n")
Example #18
def string_match(ast1, ast2):
    str1 = data_tools.ast2template(ast1, loose_constraints=True, arg_type_only=False)
    str2 = data_tools.ast2template(ast2, loose_constraints=True, arg_type_only=False)
    return str1 == str2