Example #1
def combine_annotations_multi_files():
    """
    Combine multiple annotation files and discard the annotations that have a conflict.
    """

    input_dir = sys.argv[1]

    template_evals = {}
    command_evals = {}
    discarded_keys = set()

    for in_csv in os.listdir(input_dir):
        in_csv_path = os.path.join(input_dir, in_csv)
        with open(in_csv_path) as f:
            reader = csv.DictReader(f)
            current_description = ''
            for row in reader:
                template_eval = normalize_judgement(row['correct template'])
                command_eval = normalize_judgement(row['correct command'])
                description = get_example_nl_key(row['description'])
                if description.strip():
                    current_description = description
                else:
                    description = current_description
                prediction = row['prediction']
                example_key = '{}<NL_PREDICTION>{}'.format(
                    description, prediction)
                if example_key in template_evals and template_evals[
                        example_key] != template_eval:
                    discarded_keys.add(example_key)
                    continue
                if example_key in command_evals and command_evals[
                        example_key] != command_eval:
                    discarded_keys.add(example_key)
                    continue
                template_evals[example_key] = template_eval
                command_evals[example_key] = command_eval
            print('{} read ({} manually annotated examples, {} discarded)'.
                  format(in_csv_path, len(template_evals),
                         len(discarded_keys)))

    # Write to new file
    assert (len(template_evals) == len(command_evals))
    with open('manual_annotations.additional', 'w') as o_f:
        o_f.write(
            'description,prediction,template,correct template,correct command\n'
        )
        for key in sorted(template_evals.keys()):
            if key in discarded_keys:
                continue
            description, prediction = key.split('<NL_PREDICTION>')
            template_eval = template_evals[key]
            command_eval = command_evals[key]
            pred_tree = data_tools.bash_parser(prediction)
            pred_temp = data_tools.ast2template(pred_tree,
                                                loose_constraints=True)
            o_f.write('"{}","{}","{}",{},{}\n'.format(
                description.replace('"', '""'), prediction.replace('"', '""'),
                pred_temp.replace('"', '""'), template_eval, command_eval))
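The function above, like the later examples in this listing, relies on two helpers that are not shown: normalize_judgement, which maps a raw spreadsheet judgement to a canonical label, and get_example_nl_key, which normalizes a natural-language description before it is used in a '<NL_PREDICTION>' lookup key. A minimal sketch of plausible implementations, assuming judgements are free-form 'y'/'yes'/'n' strings and descriptions only need whitespace and case normalization:

def normalize_judgement(judgement):
    """Sketch: map a raw annotation to a canonical 'y'/'n' label (assumed behavior)."""
    judgement = judgement.strip().lower()
    return 'y' if judgement in ('y', 'yes', '1') else 'n'

def get_example_nl_key(description):
    """Sketch: normalize a description string for use as a lookup key (assumed behavior)."""
    return ' '.join(description.strip().lower().split())

The actual normalization rules may differ; the sketch only illustrates the interface the examples depend on.
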
def combine_annotations():
    """
    Combine the annotations input by three annotators.

    Command-line arguments:
        input_file1: main annotation file 1.
        input_file2: main annotation file 2 (should contain the same number of
            lines as input_file1).
        input_file3: supplementary annotation file containing annotations for the
            lines in input_file1 and input_file2 on which the two disagree.
        output_file: file that will contain the combined annotations.
    """
    input_file1 = sys.argv[1]
    input_file2 = sys.argv[2]
    input_file3 = sys.argv[3]
    output_file = sys.argv[4]
    o_f = open(output_file, 'w')
    o_f.write('description,prediction,template,correct template,correct command,'
              'correct template A,correct command A,'
              'correct template B,correct command B,'
              'correct template C,correct command C\n')
    sup_structure_eval, sup_command_eval = load_cached_evaluations_from_file(
        input_file3, treat_empty_as_correct=True)
    # for key in sup_structure_eval:
    #     print(key)
    # print('------------------')
    with open(input_file1) as f1:
        with open(input_file2) as f2:
            reader1 = csv.DictReader(f1)
            reader2 = csv.DictReader(f2)
            current_desp = ''
            for row1, row2 in zip(reader1, reader2):
                row1_template_eval = normalize_judgement(row1['correct template'].strip())
                row1_command_eval = normalize_judgement(row1['correct command'].strip())
                row2_template_eval = normalize_judgement(row2['correct template'].strip())
                row2_command_eval = normalize_judgement(row2['correct command'].strip())
                if row1['description']:
                    current_desp = row1['description'].strip()
                sc_key = get_example_nl_key(current_desp)
                pred_cmd = row1['prediction'].strip()
                if not pred_cmd:
                    row1_template_eval, row1_command_eval = 'n', 'n'
                    row2_template_eval, row2_command_eval = 'n', 'n'
                pred_temp = data_tools.cmd2template(pred_cmd, loose_constraints=True)
                structure_example_key = '{}<NL_PREDICTION>{}'.format(sc_key, pred_temp)
                command_example_key = '{}<NL_PREDICTION>{}'.format(sc_key, pred_cmd)
                row3_template_eval, row3_command_eval = None, None
                if structure_example_key in sup_structure_eval:
                    row3_template_eval = sup_structure_eval[structure_example_key]
                if command_example_key in sup_command_eval:
                    row3_command_eval = sup_command_eval[command_example_key]
                if row1_template_eval != row2_template_eval or row1_command_eval != row2_command_eval:
                    if row1_template_eval != row2_template_eval:
                        if row3_template_eval is None:
                            print(structure_example_key)
                        assert(row3_template_eval is not None)
                        template_eval = row3_template_eval
                    else:
                        template_eval = row1_template_eval
                    if row1_command_eval != row2_command_eval:
                        # if row3_command_eval is None:
                        #     print(command_example_key)
                        assert(row3_command_eval is not None)
                        command_eval = row3_command_eval
                    else:
                        command_eval = row1_command_eval
                else:
                    template_eval = row1_template_eval
                    command_eval = row1_command_eval
                if row3_template_eval is None:
                    row3_template_eval = ''
                if row3_command_eval is None:
                    row3_command_eval = ''
                o_f.write('"{}","{}","{}",{},{},{},{},{},{},{},{}\n'.format(
                    current_desp.replace('"', '""'), pred_cmd.replace('"', '""'), pred_temp.replace('"', '""'),
                    template_eval, command_eval,
                    row1_template_eval, row1_command_eval,
                    row2_template_eval, row2_command_eval,
                    row3_template_eval, row3_command_eval))
    o_f.close()
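combine_annotations above (and print_error_analysis_sheet below) also depends on load_cached_evaluations_from_file, which is not included in these examples. Judging from how its return values are used, it reads a judgement spreadsheet into two dicts keyed by the 'description<NL_PREDICTION>template' and 'description<NL_PREDICTION>command' signatures, with treat_empty_as_correct=True presumably turning blank cells into 'y'. A rough sketch under those assumptions, not the original implementation:

def load_cached_evaluations_from_file(input_file, treat_empty_as_correct=False):
    """Sketch: load a judgement spreadsheet into template/command judgement dicts."""
    structure_eval, command_eval = {}, {}
    with open(input_file) as f:
        reader = csv.DictReader(f)
        current_desp = ''
        for row in reader:
            if row['description'].strip():
                current_desp = row['description'].strip()
            sc_key = get_example_nl_key(current_desp)
            pred_cmd = row['prediction'].strip()
            pred_temp = data_tools.cmd2template(pred_cmd, loose_constraints=True)
            template_judgement = row['correct template'].strip()
            command_judgement = row['correct command'].strip()
            if treat_empty_as_correct:
                # assumption: an empty cell counts as a positive judgement
                template_judgement = template_judgement or 'y'
                command_judgement = command_judgement or 'y'
            structure_eval['{}<NL_PREDICTION>{}'.format(
                sc_key, pred_temp)] = normalize_judgement(template_judgement)
            command_eval['{}<NL_PREDICTION>{}'.format(
                sc_key, pred_cmd)] = normalize_judgement(command_judgement)
    return structure_eval, command_eval
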
def print_error_analysis_sheet():
    input_file1 = sys.argv[1]
    input_file2 = sys.argv[2]
    input_file3 = sys.argv[3]
    output_file = sys.argv[4]
    o_f = open(output_file, 'w')
    o_f.write('description,model,prediction,correct template,correct command,'
              'correct template A,correct command A,'
              'correct template B,correct command B,'
              'correct template C,correct command C\n')
    sup_structure_eval, sup_command_eval = load_cached_evaluations_from_file(
        input_file3, treat_empty_as_correct=True)
    # for key in sup_structure_eval:
    #     print(key)
    # print('------------------')
    with open(input_file1) as f1:
        with open(input_file2) as f2:
            reader1 = csv.DictReader(f1)
            reader2 = csv.DictReader(f2)
            current_desp = ''
            for row_id, (row1, row2) in enumerate(zip(reader1, reader2)):
                if row1['description']:
                    current_desp = row1['description'].strip()
                model_name = row2['model']
                if model_name not in ['partial.token-copynet', 'tellina']:
                    continue
                if row_id % 3 != 0:
                    continue
                row1_template_eval = normalize_judgement(row1['correct template'].strip())
                row1_command_eval = normalize_judgement(row1['correct command'].strip())
                row2_template_eval = normalize_judgement(row2['correct template'].strip())
                row2_command_eval = normalize_judgement(row2['correct command'].strip())
                sc_key = get_example_nl_key(current_desp)
                pred_cmd = row1['prediction'].strip()
                if not pred_cmd:
                    row1_template_eval, row1_command_eval = 'n', 'n'
                    row2_template_eval, row2_command_eval = 'n', 'n'
                pred_temp = data_tools.cmd2template(pred_cmd, loose_constraints=True)
                structure_example_key = '{}<NL_PREDICTION>{}'.format(sc_key, pred_temp)
                command_example_key = '{}<NL_PREDICTION>{}'.format(sc_key, pred_cmd)
                row3_template_eval, row3_command_eval = None, None
                if structure_example_key in sup_structure_eval:
                    row3_template_eval = sup_structure_eval[structure_example_key]
                if command_example_key in sup_command_eval:
                    row3_command_eval = sup_command_eval[command_example_key]
                if row1_template_eval != row2_template_eval or row1_command_eval != row2_command_eval:
                    if row1_template_eval != row2_template_eval:
                        if row3_template_eval is None:
                            print(structure_example_key)
                        assert (row3_template_eval is not None)
                        template_eval = row3_template_eval
                    else:
                        template_eval = row1_template_eval
                    if row1_command_eval != row2_command_eval:
                        # if row3_command_eval is None:
                        #     print(command_example_key)
                        assert (row3_command_eval is not None)
                        command_eval = row3_command_eval
                    else:
                        command_eval = row1_command_eval
                else:
                    template_eval = row1_template_eval
                    command_eval = row1_command_eval
                if row3_template_eval is None:
                    row3_template_eval = ''
                if row3_command_eval is None:
                    row3_command_eval = ''
                o_f.write('"{}","{}","{}",{},{},{},{},{},{},{},{}\n'.format(
                    current_desp.replace('"', '""'), model_name, pred_cmd.replace('"', '""'),
                    template_eval, command_eval,
                    row1_template_eval, row1_command_eval,
                    row2_template_eval, row2_command_eval,
                    row3_template_eval, row3_command_eval))
    o_f.close()
Example #4
def gen_manual_evaluation_csv_single_model(dataset, FLAGS):
    """
    Generate a .csv spreadsheet for manual evaluation on dev/test set
    examples for a specific model.
    """
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)

    # Load model predictions
    model_subdir, decode_sig = graph_utils.get_decode_signature(FLAGS)
    model_dir = os.path.join(FLAGS.model_root_dir, model_subdir)
    prediction_list = load_predictions(model_dir, decode_sig, top_k=3)
    if len(grouped_dataset) != len(prediction_list):
        raise ValueError("ground truth list and prediction list length must "
                         "be equal: {} vs. {}".format(len(grouped_dataset),
                                                      len(prediction_list)))

    # Load additional ground truths
    template_translations, command_translations = load_cached_correct_translations(
        FLAGS.data_dir)

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = load_cached_evaluations(
        os.path.join(FLAGS.data_dir, 'manual_judgements'))

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash else data_tools.paren_parser

    output_path = os.path.join(model_dir, 'manual.evaluations.single.model')
    with open(output_path, 'w') as o_f:
        # write spreadsheet header
        o_f.write('id,description,command,correct template,correct command\n')
        for example_id in range(len(grouped_dataset)):
            data_group = grouped_dataset[example_id][1]
            sc_txt = data_group[0].sc_txt.strip()
            sc_key = get_example_nl_key(sc_txt)
            command_gts = [dp.tg_txt for dp in data_group]
            command_gts = set(command_gts + command_translations[sc_key])
            command_gt_asts = [
                data_tools.bash_parser(cmd) for cmd in command_gts
            ]
            template_gts = [
                data_tools.cmd2template(cmd, loose_constraints=True)
                for cmd in command_gts
            ]
            template_gts = set(template_gts + template_translations[sc_key])
            template_gt_asts = [
                data_tools.bash_parser(temp) for temp in template_gts
            ]
            predictions = prediction_list[example_id]
            for i in xrange(3):
                if i >= len(predictions):
                    o_f.write(',,,n,n\n')
                    continue
                pred_cmd = predictions[i]
                pred_tree = cmd_parser(pred_cmd)
                pred_temp = data_tools.ast2template(pred_tree,
                                                    loose_constraints=True)
                temp_match = tree_dist.one_match(template_gt_asts,
                                                 pred_tree,
                                                 ignore_arg_value=True)
                str_match = tree_dist.one_match(command_gt_asts,
                                                pred_tree,
                                                ignore_arg_value=False)
                # Match against ground truths & existing judgements
                command_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_cmd)
                structure_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_temp)
                command_eval, structure_eval = '', ''
                if str_match:
                    command_eval = 'y'
                    structure_eval = 'y'
                elif temp_match:
                    structure_eval = 'y'
                if command_eval_cache and \
                        command_example_sig in command_eval_cache:
                    command_eval = command_eval_cache[command_example_sig]
                if structure_eval_cache and \
                        structure_example_sig in structure_eval_cache:
                    structure_eval = structure_eval_cache[
                        structure_example_sig]
                if i == 0:
                    o_f.write('{},"{}","{}",{},{}\n'.format(
                        example_id, sc_txt.replace('"', '""'),
                        pred_cmd.replace('"', '""'), structure_eval,
                        command_eval))
                else:
                    o_f.write(',,"{}",{},{}\n'.format(
                        pred_cmd.replace('"', '""'), structure_eval,
                        command_eval))
    print('manual evaluation spreadsheet saved to {}'.format(output_path))
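The spreadsheet written above pre-fills a 'y' wherever a prediction matches a ground truth or a cached judgement exists, and leaves the remaining cells blank for the annotator. Once those blanks are filled in, the top-1 numbers can be tallied directly from the file; a small hedged sketch (this helper is hypothetical, not part of the original code):

def tally_top1_accuracy(csv_path):
    """Sketch: compute top-1 template/command accuracy from the filled-in spreadsheet.

    Assumes every example has at least one prediction, so the top-1 row carries an id.
    """
    temp_correct, cmd_correct, total = 0, 0, 0
    with open(csv_path) as f:
        for row in csv.DictReader(f):
            if not row['id']:
                continue  # rows without an id hold the rank-2/3 predictions
            total += 1
            temp_correct += int(normalize_judgement(row['correct template']) == 'y')
            cmd_correct += int(normalize_judgement(row['correct command']) == 'y')
    return float(temp_correct) / total, float(cmd_correct) / total
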
Example #5
def print_error_analysis_csv(grouped_dataset,
                             prediction_list,
                             FLAGS,
                             cached_evaluation_results=None,
                             group_by_utility=False,
                             error_predictions_only=True):
    """
    Convert dev/test set examples to csv format so as to make it easier for
    human annotators to enter their judgements.

    :param grouped_dataset: dev/test set grouped by natural language.
    :param prediction_list: model predictions.
    :param FLAGS: experiment hyperparameters.
    :param cached_evaluation_results: cached evaluation results from previous
        rounds.
    :param group_by_utility: if set, group the error examples by the utilities
        used in the ground truth.
    :param error_predictions_only: if set, include only examples whose top
        prediction contains a grammar or argument error.
    """
    def mark_example(error_list, example, gt_utility=None):
        if gt_utility:
            error_list[gt_utility].append(example)
        else:
            error_list.append(example)

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash \
        else data_tools.paren_parser
    if group_by_utility:
        utility_index = {}
        for line in bash.utility_stats.split('\n'):
            ind, utility, _, _ = line.split(',')
            utility_index[utility] = ind

    grammar_errors = collections.defaultdict(list) if group_by_utility else []
    argument_errors = collections.defaultdict(list) if group_by_utility else []
    example_id = 0
    for nl_temp, data_group in grouped_dataset:
        sc_txt = data_group[0].sc_txt.strip()
        sc_temp = get_example_nl_key(sc_txt)
        tg_strs = [dp.tg_txt for dp in data_group]
        gt_trees = [cmd_parser(cm_str) for cm_str in tg_strs]
        if group_by_utility:
            gt_utilities = functools.reduce(
                lambda x, y: x | y,
                [data_tools.get_utilities(gt) for gt in gt_trees])
            gt_utility = sorted(list(gt_utilities),
                                key=lambda x: int(utility_index[x]))[-1]
        else:
            gt_utility = None
        predictions = prediction_list[example_id]
        example_id += 1
        example = []
        grammar_error, argument_error = False, False
        for i in xrange(min(3, len(predictions))):
            if i == 0:
                output_str = '{},"{}",'.format(example_id,
                                               sc_txt.replace('"', '""'))
            else:
                output_str = ',,'
            pred_cmd = predictions[i]
            tree = cmd_parser(pred_cmd)

            # evaluation ignoring flag orders
            temp_match = tree_dist.one_match(gt_trees,
                                             tree,
                                             ignore_arg_value=True)
            str_match = tree_dist.one_match(gt_trees,
                                            tree,
                                            ignore_arg_value=False)
            if i < len(tg_strs):
                output_str += '"{}",'.format(tg_strs[i].strip().replace(
                    '"', '""'))
            else:
                output_str += ','
            output_str += '"{}",'.format(pred_cmd.replace('"', '""'))
            if not str_match:
                if temp_match:
                    if i == 0:
                        argument_error = True
                        grammar_error = True
                else:
                    if i == 0:
                        grammar_error = True

            example_sig = '{}<NL_PREDICTION>{}'.format(sc_temp, pred_cmd)
            if cached_evaluation_results and \
                    example_sig in cached_evaluation_results:
                output_str += cached_evaluation_results[example_sig]
            else:
                if str_match:
                    output_str += 'y,y'
                elif temp_match:
                    output_str += 'y,'
            example.append(output_str)
        if error_predictions_only:
            if grammar_error:
                mark_example(grammar_errors, example, gt_utility)
            elif argument_error:
                mark_example(argument_errors, example, gt_utility)
        else:
            mark_example(grammar_errors, example, gt_utility)

    return grammar_errors, argument_errors
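Despite its name, print_error_analysis_csv only collects the per-example output rows and returns them; the actual file writing presumably happens in the caller. A hedged sketch of such a caller (the function name, output path, and header below are assumptions, not from the original code):

def write_error_analysis_csv(error_groups, output_path):
    """Sketch: dump the rows returned by print_error_analysis_csv into a .csv file."""
    with open(output_path, 'w') as o_f:
        o_f.write('example_id,description,ground truth,prediction,'
                  'correct template,correct command\n')
        if isinstance(error_groups, dict):
            # group_by_utility=True: a defaultdict mapping utility -> list of examples
            for utility in sorted(error_groups):
                for example in error_groups[utility]:
                    for line in example:
                        o_f.write('{}\n'.format(line))
        else:
            # group_by_utility=False: a flat list of examples
            for example in error_groups:
                for line in example:
                    o_f.write('{}\n'.format(line))
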
Example #6
def tabulate_example_predictions(dataset, FLAGS, num_examples=100):
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)

    model_names, model_predictions = load_all_model_predictions(
        grouped_dataset, FLAGS, top_k=1)

    # Get FIXED dev set samples
    random.seed(100)
    example_ids = list(range(len(grouped_dataset)))
    random.shuffle(example_ids)
    sample_ids = example_ids[:num_examples]

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = \
        load_cached_evaluations(
            os.path.join(FLAGS.data_dir, 'manual_judgements'))

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash \
        else data_tools.paren_parser

    model_name_pt = {
        'token-seq2seq': 'T-Seq2Seq',
        'tellina': 'Tellina',
        'token-copynet': 'T-CopyNet',
        'partial.token-seq2seq': 'ST-Seq2Seq',
        'partial.token-copynet': 'ST-CopyNet',
        'char-seq2seq': 'C-Seq2Seq',
        'char-copynet': 'C-CopyNet'
    }

    for example_id in sample_ids:
        print('Example {}'.format(example_id))
        data_group = grouped_dataset[example_id][1]
        sc_txt = data_group[0].sc_txt.strip()
        sc_key = get_example_nl_key(sc_txt)
        command_gts = [dp.tg_txt for dp in data_group]
        command_gt_asts = [data_tools.bash_parser(gt) for gt in command_gts]
        output_strs = {}
        for model_id, model_name in enumerate(model_names):
            predictions = model_predictions[model_id][example_id]
            for i in xrange(min(3, len(predictions))):
                pred_cmd = predictions[i]
                pred_tree = cmd_parser(pred_cmd)
                pred_temp = data_tools.ast2template(pred_tree,
                                                    loose_constraints=True)
                temp_match = tree_dist.one_match(command_gt_asts,
                                                 pred_tree,
                                                 ignore_arg_value=True)
                str_match = tree_dist.one_match(command_gt_asts,
                                                pred_tree,
                                                ignore_arg_value=False)

                output_str = '& \\<{}> & {}'.format(
                    pred_cmd.replace('__SP__', '').replace('_', '\\_').replace(
                        '$', '\\$').replace('%',
                                            '\\%').replace('{{}}', '\\ttcbs'),
                    model_name_pt[model_name])

                command_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_cmd)
                structure_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_temp)
                command_eval, structure_eval = '', ''
                if str_match:
                    command_eval = 'y'
                    structure_eval = 'y'
                elif temp_match:
                    structure_eval = 'y'
                if command_eval_cache and \
                        command_example_sig in command_eval_cache:
                    command_eval = command_eval_cache[command_example_sig]
                if structure_eval_cache and \
                        structure_example_sig in structure_eval_cache:
                    structure_eval = structure_eval_cache[
                        structure_example_sig]
                output_str += ', {},{} \\\\'.format(structure_eval,
                                                    command_eval)
            output_strs[model_name] = output_str
        for model_name in [
                'char-seq2seq', 'char-copynet', 'token-seq2seq',
                'token-copynet', 'partial.token-seq2seq',
                'partial.token-copynet', 'tellina'
        ]:
            if model_name == 'char-seq2seq':
                print('\\multirow{{7}}{{*}}{{\\specialcell{{{}}}}} '.format(
                    sc_txt) + output_strs[model_name])
            else:
                print(output_strs[model_name])
        output_str = '& \\<{}> & Human \\\\'.format(command_gts[0].replace(
            '__SP__', '').replace('_', '\\_').replace('$', '\\$').replace(
                '%', '\\%').replace('{{}}', '\\ttcbs'))
        print(output_str)
        print()
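The LaTeX escaping of a command appears twice in this function as a long chain of .replace() calls; for readability the chain could be factored into a small helper. This is a sketch, not part of the original code, and simply reproduces the existing replacements:

def escape_for_latex(cmd):
    """Sketch: escape a bash command for the LaTeX comparison table."""
    return (cmd.replace('__SP__', '')
               .replace('_', '\\_')
               .replace('$', '\\$')
               .replace('%', '\\%')
               .replace('{{}}', '\\ttcbs'))
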
Example #7
def gen_manual_evaluation_csv(dataset, FLAGS, num_examples=100):
    """
    Generate a .csv spreadsheet for manual evaluation on a fixed set of dev/test
    examples, with the predictions of different models listed side by side.
    """
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)

    model_names, model_predictions = load_all_model_predictions(
        grouped_dataset, FLAGS, top_k=3)

    # Get FIXED dev set samples
    random.seed(100)
    example_ids = list(range(len(grouped_dataset)))
    random.shuffle(example_ids)
    sample_ids = example_ids[num_examples:num_examples + 100]

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = \
        load_cached_evaluations(
            os.path.join(FLAGS.data_dir, 'manual_judgements'))

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash \
        else data_tools.paren_parser

    output_path = os.path.join(FLAGS.data_dir, 'manual.evaluations.csv')
    with open(output_path, 'w') as o_f:
        o_f.write('example_id, description, ground_truth, model, prediction, '
                  'correct template, correct command\n')
        for example_id in sample_ids:
            data_group = grouped_dataset[example_id][1]
            sc_txt = data_group[0].sc_txt.strip()
            sc_key = get_example_nl_key(sc_txt)
            command_gts = [dp.tg_txt for dp in data_group]
            command_gt_asts = [
                data_tools.bash_parser(gt) for gt in command_gts
            ]
            for model_id, model_name in enumerate(model_names):
                predictions = model_predictions[model_id][example_id]
                for i in xrange(min(3, len(predictions))):
                    if model_id == 0 and i == 0:
                        output_str = '{},"{}",'.format(
                            example_id, sc_txt.replace('"', '""'))
                    else:
                        output_str = ',,'
                    pred_cmd = predictions[i]
                    pred_tree = cmd_parser(pred_cmd)
                    pred_temp = data_tools.ast2template(pred_tree,
                                                        loose_constraints=True)
                    temp_match = tree_dist.one_match(command_gt_asts,
                                                     pred_tree,
                                                     ignore_arg_value=True)
                    str_match = tree_dist.one_match(command_gt_asts,
                                                    pred_tree,
                                                    ignore_arg_value=False)
                    if (model_id * min(3, len(predictions)) +
                            i) < len(command_gts):
                        output_str += '"{}",'.format(
                            command_gts[model_id * min(3, len(predictions)) +
                                        i].strip().replace('"', '""'))
                    else:
                        output_str += ','
                    output_str += '{},"{}",'.format(
                        model_name, pred_cmd.replace('"', '""'))

                    command_example_sig = '{}<NL_PREDICTION>{}'.format(
                        sc_key, pred_cmd)
                    structure_example_sig = '{}<NL_PREDICTION>{}'.format(
                        sc_key, pred_temp)
                    command_eval, structure_eval = '', ''
                    if str_match:
                        command_eval = 'y'
                        structure_eval = 'y'
                    elif temp_match:
                        structure_eval = 'y'
                    if command_eval_cache and \
                            command_example_sig in command_eval_cache:
                        command_eval = command_eval_cache[command_example_sig]
                    if structure_eval_cache and \
                            structure_example_sig in structure_eval_cache:
                        structure_eval = structure_eval_cache[
                            structure_example_sig]
                    output_str += '{},{}'.format(structure_eval, command_eval)
                    o_f.write('{}\n'.format(output_str))

    print('Manual evaluation results saved to {}'.format(output_path))
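Both tabulate_example_predictions and gen_manual_evaluation_csv draw their examples from the same shuffle seeded with random.seed(100); the former takes example_ids[:num_examples], while the latter takes example_ids[num_examples:num_examples + 100], so the two samples stay fixed across runs and (for the default num_examples=100) disjoint from each other. Isolated as a sketch, with a hypothetical helper name:

import random

def fixed_sample_ids(num_grouped_examples, offset, sample_size=100, seed=100):
    """Sketch: reproduce the fixed, seed-controlled example sampling used above."""
    random.seed(seed)
    example_ids = list(range(num_grouped_examples))
    random.shuffle(example_ids)
    return example_ids[offset:offset + sample_size]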