Example #1
def gen_error_analysis_csv_by_utility(model_dir,
                                      decode_sig,
                                      dataset,
                                      FLAGS,
                                      top_k=10):
    """
    Generate an error-analysis evaluation sheet grouped by utility.
    """
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)

    # Load model predictions
    prediction_list = load_predictions(model_dir, decode_sig, top_k)
    if len(grouped_dataset) != len(prediction_list):
        raise ValueError(
            "ground truth and predictions length must be equal: {} vs. {}".
            format(len(grouped_dataset), len(prediction_list)))

    # Load cached evaluation results
    cached_evaluation_results = load_cached_evaluations(model_dir)

    # Convert the predictions into csv format
    grammar_errors, argument_errors = print_error_analysis_csv(
        grouped_dataset,
        prediction_list,
        FLAGS,
        cached_evaluation_results,
        group_by_utility=True,
        error_predictions_only=False)

    error_by_utility_path = \
        os.path.join(model_dir, 'error.analysis.by.utility.csv')
    print("Saving grammar errors to {}".format(error_by_utility_path))
    with open(error_by_utility_path, 'w') as error_by_utility_file:
        # print csv file header
        error_by_utility_file.write(
            'utility, example_id, description, groundtruth, prediction, '
            'correct template, correct command\n')
        for line in bash.utility_stats.split('\n'):
            utility = line.split(',')[1]
            error_examples = grammar_errors[utility]
            # Cap the output at five randomly sampled error examples per utility
            if len(error_examples) > 5:
                random.shuffle(error_examples)
                error_examples = error_examples[:5]
            for example in error_examples:
                for l in example:
                    error_by_utility_file.write('{},{}\n'.format(
                        utility, l))
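
For context, a hedged driver sketch showing how the function above might be invoked. It reuses the decode-signature lookup that appears in Example #3 below; load_data is a hypothetical loader and the flag values are assumptions, not the project's actual configuration.

# Hypothetical driver sketch; load_data is assumed, the rest mirrors Example #3.
model_subdir, decode_sig = graph_utils.get_decode_signature(FLAGS)
model_dir = os.path.join(FLAGS.model_root_dir, model_subdir)
dataset = load_data(FLAGS)  # assumed: returns the parallel NL/command dataset
gen_error_analysis_csv_by_utility(model_dir, decode_sig, dataset, FLAGS, top_k=10)
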
Example #2
def tabulate_example_predictions(dataset, FLAGS, num_examples=100):
    """
    Print LaTeX-formatted table rows comparing the top predictions of
    different models on a fixed random sample of examples.
    """
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)

    model_names, model_predictions = load_all_model_predictions(
        grouped_dataset, FLAGS, top_k=1)

    # Get FIXED dev set samples
    random.seed(100)
    example_ids = list(range(len(grouped_dataset)))
    random.shuffle(example_ids)
    sample_ids = example_ids[:num_examples]

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = \
        load_cached_evaluations(
            os.path.join(FLAGS.data_dir, 'manual_judgements'))

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash \
        else data_tools.paren_parser

    model_name_pt = {
        'token-seq2seq': 'T-Seq2Seq',
        'tellina': 'Tellina',
        'token-copynet': 'T-CopyNet',
        'partial.token-seq2seq': 'ST-Seq2Seq',
        'partial.token-copynet': 'ST-CopyNet',
        'char-seq2seq': 'C-Seq2Seq',
        'char-copynet': 'C-CopyNet'
    }

    for example_id in sample_ids:
        print('Example {}'.format(example_id))
        data_group = grouped_dataset[example_id][1]
        sc_txt = data_group[0].sc_txt.strip()
        sc_key = get_example_nl_key(sc_txt)
        command_gts = [dp.tg_txt for dp in data_group]
        command_gt_asts = [data_tools.bash_parser(gt) for gt in command_gts]
        output_strs = {}
        for model_id, model_name in enumerate(model_names):
            predictions = model_predictions[model_id][example_id]
            for i in range(min(3, len(predictions))):
                pred_cmd = predictions[i]
                pred_tree = cmd_parser(pred_cmd)
                pred_temp = data_tools.ast2template(pred_tree,
                                                    loose_constraints=True)
                temp_match = tree_dist.one_match(command_gt_asts,
                                                 pred_tree,
                                                 ignore_arg_value=True)
                str_match = tree_dist.one_match(command_gt_asts,
                                                pred_tree,
                                                ignore_arg_value=False)

                output_str = '& \\<{}> & {}'.format(
                    pred_cmd.replace('__SP__', '').replace('_', '\\_').replace(
                        '$', '\\$').replace('%',
                                            '\\%').replace('{{}}', '\\ttcbs'),
                    model_name_pt[model_name])

                command_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_cmd)
                structure_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_temp)
                command_eval, structure_eval = '', ''
                if str_match:
                    command_eval = 'y'
                    structure_eval = 'y'
                elif temp_match:
                    structure_eval = 'y'
                if command_eval_cache and \
                        command_example_sig in command_eval_cache:
                    command_eval = command_eval_cache[command_example_sig]
                if structure_eval_cache and \
                        structure_example_sig in structure_eval_cache:
                    structure_eval = structure_eval_cache[
                        structure_example_sig]
                output_str += ', {},{} \\\\'.format(structure_eval,
                                                    command_eval)
            output_strs[model_name] = output_str
        for model_name in [
                'char-seq2seq', 'char-copynet', 'token-seq2seq',
                'token-copynet', 'partial.token-seq2seq',
                'partial.token-copynet', 'tellina'
        ]:
            if model_name == 'char-seq2seq':
                print('\\multirow{{7}}{{*}}{{\\specialcell{{{}}}}} '.format(
                    sc_txt) + output_strs[model_name])
            else:
                print(output_strs[model_name])
        output_str = '& \\<{}> & Human \\\\'.format(command_gts[0].replace(
            '__SP__', '').replace('_', '\\_').replace('$', '\\$').replace(
                '%', '\\%').replace('{{}}', '\\ttcbs'))
        print(output_str)
        print()
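
The evaluation-cache lookup in the inner loop above (and in Examples #3 and #4 below) follows a single pattern: key the caches on the natural-language description joined to the prediction with '<NL_PREDICTION>', start from the tree-match verdicts, and let any cached human judgement override them. A minimal self-contained sketch of that pattern, using the hypothetical helper name lookup_eval:

def lookup_eval(sc_key, pred_cmd, pred_temp, str_match, temp_match,
                command_eval_cache, structure_eval_cache):
    # Hypothetical helper factoring out the cache lookup used above.
    structure_eval = 'y' if (str_match or temp_match) else ''
    command_eval = 'y' if str_match else ''
    command_sig = '{}<NL_PREDICTION>{}'.format(sc_key, pred_cmd)
    structure_sig = '{}<NL_PREDICTION>{}'.format(sc_key, pred_temp)
    if command_eval_cache and command_sig in command_eval_cache:
        command_eval = command_eval_cache[command_sig]
    if structure_eval_cache and structure_sig in structure_eval_cache:
        structure_eval = structure_eval_cache[structure_sig]
    return structure_eval, command_eval
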
Example #3
def gen_manual_evaluation_csv_single_model(dataset, FLAGS):
    """
    Generate a .csv spreadsheet for manual evaluation of dev/test set
    examples for a specific model.
    """
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)

    # Load model predictions
    model_subdir, decode_sig = graph_utils.get_decode_signature(FLAGS)
    model_dir = os.path.join(FLAGS.model_root_dir, model_subdir)
    prediction_list = load_predictions(model_dir, decode_sig, top_k=3)
    if len(grouped_dataset) != len(prediction_list):
        raise ValueError("ground truth list and prediction list length must "
                         "be equal: {} vs. {}".format(len(grouped_dataset),
                                                      len(prediction_list)))

    # Load additional ground truths
    template_translations, command_translations = load_cached_correct_translations(
        FLAGS.data_dir)

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = load_cached_evaluations(
        os.path.join(FLAGS.data_dir, 'manual_judgements'))

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash else data_tools.paren_parser

    output_path = os.path.join(model_dir, 'manual.evaluations.single.model')
    with open(output_path, 'w') as o_f:
        # write spreadsheet header
        o_f.write('id,description,command,correct template,correct command\n')
        for example_id in range(len(grouped_dataset)):
            data_group = grouped_dataset[example_id][1]
            sc_txt = data_group[0].sc_txt.strip()
            sc_key = get_example_nl_key(sc_txt)
            command_gts = [dp.tg_txt for dp in data_group]
            command_gts = set(command_gts + command_translations[sc_key])
            command_gt_asts = [
                data_tools.bash_parser(cmd) for cmd in command_gts
            ]
            template_gts = [
                data_tools.cmd2template(cmd, loose_constraints=True)
                for cmd in command_gts
            ]
            template_gts = set(template_gts + template_translations[sc_key])
            template_gt_asts = [
                data_tools.bash_parser(temp) for temp in template_gts
            ]
            predictions = prediction_list[example_id]
            for i in range(3):
                if i >= len(predictions):
                    o_f.write(',,,n,n\n')
                    continue
                pred_cmd = predictions[i]
                pred_tree = cmd_parser(pred_cmd)
                pred_temp = data_tools.ast2template(pred_tree,
                                                    loose_constraints=True)
                temp_match = tree_dist.one_match(template_gt_asts,
                                                 pred_tree,
                                                 ignore_arg_value=True)
                str_match = tree_dist.one_match(command_gt_asts,
                                                pred_tree,
                                                ignore_arg_value=False)
                # Match ground truths & existing judgements
                command_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_cmd)
                structure_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_temp)
                command_eval, structure_eval = '', ''
                if str_match:
                    command_eval = 'y'
                    structure_eval = 'y'
                elif temp_match:
                    structure_eval = 'y'
                if command_eval_cache and \
                        command_example_sig in command_eval_cache:
                    command_eval = command_eval_cache[command_example_sig]
                if structure_eval_cache and \
                        structure_example_sig in structure_eval_cache:
                    structure_eval = structure_eval_cache[
                        structure_example_sig]
                if i == 0:
                    o_f.write('{},"{}","{}",{},{}\n'.format(
                        example_id, sc_txt.replace('"', '""'),
                        pred_cmd.replace('"', '""'), structure_eval,
                        command_eval))
                else:
                    o_f.write(',,"{}",{},{}\n'.format(
                        pred_cmd.replace('"', '""'), structure_eval,
                        command_eval))
    print('manual evaluation spreadsheet saved to {}'.format(output_path))
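
The writer above escapes double quotes by hand (doubling each '"' to '""') and assembles rows with string formatting. As an alternative formulation, not what the source code does, the standard-library csv module performs that quoting automatically; a short sketch with the hypothetical name write_rows:

import csv

def write_rows(output_path, rows):
    # rows: iterable of (example_id, description, command,
    #                    structure_eval, command_eval) tuples
    with open(output_path, 'w', newline='') as o_f:
        writer = csv.writer(o_f)
        writer.writerow(['id', 'description', 'command',
                         'correct template', 'correct command'])
        for row in rows:
            # csv.writer quotes fields containing commas or double quotes
            writer.writerow(row)
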
Example #4
def gen_manual_evaluation_csv(dataset, FLAGS, num_examples=100):
    """
    Generate a .csv spreadsheet for manual evaluation on a fixed set of
    test/dev examples, with the predictions of different models listed
    side by side.
    """
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)

    model_names, model_predictions = load_all_model_predictions(
        grouped_dataset, FLAGS, top_k=3)

    # Get FIXED dev set samples
    random.seed(100)
    example_ids = list(range(len(grouped_dataset)))
    random.shuffle(example_ids)
    sample_ids = example_ids[num_examples:num_examples + 100]

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = \
        load_cached_evaluations(
            os.path.join(FLAGS.data_dir, 'manual_judgements'))

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash \
        else data_tools.paren_parser

    output_path = os.path.join(FLAGS.data_dir, 'manual.evaluations.csv')
    with open(output_path, 'w') as o_f:
        o_f.write('example_id, description, ground_truth, model, prediction, '
                  'correct template, correct command\n')
        for example_id in sample_ids:
            data_group = grouped_dataset[example_id][1]
            sc_txt = data_group[0].sc_txt.strip()
            sc_key = get_example_nl_key(sc_txt)
            command_gts = [dp.tg_txt for dp in data_group]
            command_gt_asts = [
                data_tools.bash_parser(gt) for gt in command_gts
            ]
            for model_id, model_name in enumerate(model_names):
                predictions = model_predictions[model_id][example_id]
                for i in range(min(3, len(predictions))):
                    if model_id == 0 and i == 0:
                        output_str = '{},"{}",'.format(
                            example_id, sc_txt.replace('"', '""'))
                    else:
                        output_str = ',,'
                    pred_cmd = predictions[i]
                    pred_tree = cmd_parser(pred_cmd)
                    pred_temp = data_tools.ast2template(pred_tree,
                                                        loose_constraints=True)
                    temp_match = tree_dist.one_match(command_gt_asts,
                                                     pred_tree,
                                                     ignore_arg_value=True)
                    str_match = tree_dist.one_match(command_gt_asts,
                                                    pred_tree,
                                                    ignore_arg_value=False)
                    if (model_id * min(3, len(predictions)) +
                            i) < len(command_gts):
                        output_str += '"{}",'.format(
                            command_gts[model_id * min(3, len(predictions)) +
                                        i].strip().replace('"', '""'))
                    else:
                        output_str += ','
                    output_str += '{},"{}",'.format(
                        model_name, pred_cmd.replace('"', '""'))

                    command_example_sig = '{}<NL_PREDICTION>{}'.format(
                        sc_key, pred_cmd)
                    structure_example_sig = '{}<NL_PREDICTION>{}'.format(
                        sc_key, pred_temp)
                    command_eval, structure_eval = '', ''
                    if str_match:
                        command_eval = 'y'
                        structure_eval = 'y'
                    elif temp_match:
                        structure_eval = 'y'
                    if command_eval_cache and \
                            command_example_sig in command_eval_cache:
                        command_eval = command_eval_cache[command_example_sig]
                    if structure_eval_cache and \
                            structure_example_sig in structure_eval_cache:
                        structure_eval = structure_eval_cache[
                            structure_example_sig]
                    output_str += '{},{}'.format(structure_eval, command_eval)
                    o_f.write('{}\n'.format(output_str))

    print('Manual evaluation results saved to {}'.format(output_path))
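
Examples #2 and #4 seed the shuffle (random.seed(100)) before sampling, so every model is scored on the same subset of examples across runs. A minimal illustration of that reproducibility; the function name fixed_sample is hypothetical and only the seed value comes from the code above:

import random

def fixed_sample(num_items, num_examples, offset=0, seed=100):
    # Reproducible sample of example ids, as in the snippets above:
    # the same seed yields the same shuffle, hence the same sample.
    ids = list(range(num_items))
    random.seed(seed)
    random.shuffle(ids)
    return ids[offset:offset + num_examples]

# Different invocations (e.g. for different models) select identical examples.
assert fixed_sample(500, 100) == fixed_sample(500, 100)
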