Example 1
def gen_error_analysis_csv_by_utility(model_dir,
                                      decode_sig,
                                      dataset,
                                      FLAGS,
                                      top_k=10):
    """
    Generate error analysis evaluation sheet grouped by utility.
    """
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)

    # Load model predictions
    prediction_list = load_predictions(model_dir, decode_sig, top_k)
    if len(grouped_dataset) != len(prediction_list):
        raise ValueError(
            "ground truth and predictions length must be equal: {} vs. {}".
            format(len(grouped_dataset), len(prediction_list)))

    # Load cached evaluation results
    cached_evaluation_results = load_cached_evaluations(model_dir)

    # Convert the predictions into csv format
    grammar_errors, argument_errors = print_error_analysis_csv(
        grouped_dataset,
        prediction_list,
        FLAGS,
        cached_evaluation_results,
        group_by_utility=True,
        error_predictions_only=False)

    error_by_utility_path = \
        os.path.join(model_dir, 'error.analysis.by.utility.csv')
    print("Saving grammar errors to {}".format(error_by_utility_path))
    with open(error_by_utility_path, 'w') as error_by_utility_file:
        # print csv file header
        error_by_utility_file.write(
            'utility, example_id, description, groundtruth, prediction, '
            'correct template, correct command\n')
        for line in bash.utility_stats.split('\n'):
            utility = line.split(',')[1]
            error_examples = grammar_errors[utility]
            # Report at most five error examples per utility, sampled at
            # random when there are more than five
            if len(error_examples) > 5:
                random.shuffle(error_examples)
            for example in error_examples[:5]:
                for l in example:
                    error_by_utility_file.write('{},{}\n'.format(
                        utility, l))
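The per-utility sampling at the end of this function (at most five randomly chosen error examples per utility) can be exercised in isolation; the toy grammar_errors mapping below is invented purely for illustration.

import random

# Toy stand-in for the grammar_errors mapping: each utility name maps to a
# list of error examples, and each example is a list of pre-formatted CSV
# lines. The contents are made up for illustration only.
grammar_errors = {
    'find': [['1,desc,gt,pred,,'], ['2,desc,gt,pred,,'],
             ['3,desc,gt,pred,,'], ['4,desc,gt,pred,,'],
             ['5,desc,gt,pred,,'], ['6,desc,gt,pred,,']],
    'grep': [['7,desc,gt,pred,,']],
}

for utility, error_examples in grammar_errors.items():
    if len(error_examples) > 5:
        random.shuffle(error_examples)    # pick which errors get reported
    for example in error_examples[:5]:    # cap the report at five per utility
        for l in example:
            print('{},{}'.format(utility, l))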
Example 2
def gen_error_analysis_csv(model_dir, decode_sig, dataset, FLAGS, top_k=3):
    """
    Generate error analysis evaluation spreadsheet.
        - grammar error analysis
        - argument error analysis
    """
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)

    # Load model predictions
    prediction_list = load_predictions(model_dir, decode_sig, top_k)
    if len(grouped_dataset) != len(prediction_list):
        raise ValueError("ground truth and predictions length must be equal: "
                         "{} vs. {}".format(len(grouped_dataset),
                                            len(prediction_list)))

    # Convert the predictions to csv format
    grammar_errors, argument_errors = print_error_analysis_csv(
        grouped_dataset, prediction_list, FLAGS)

    grammar_error_path = os.path.join(model_dir, 'grammar.error.analysis.csv')
    random.shuffle(grammar_errors)
    with open(grammar_error_path, 'w') as grammar_error_file:
        print("Saving grammar errors to {}".format(grammar_error_path))
        # print csv file header
        grammar_error_file.write(
            'example_id, description, ground_truth, prediction, ' +
            'correct template, correct command\n')
        for example in grammar_errors[:100]:
            for line in example:
                grammar_error_file.write('{}\n'.format(line))

    arg_error_path = os.path.join(model_dir, 'argument.error.analysis.csv')
    random.shuffle(argument_errors)
    with open(arg_error_path, 'w') as arg_error_file:
        print("Saving argument errors to {}".format(arg_error_path))
        # print csv file header
        arg_error_file.write(
            'example_id, description, ground_truth, prediction, ' +
            'correct template, correct command\n')
        for example in argument_errors[:100]:
            for line in example:
                arg_error_file.write('{}\n'.format(line))
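Each error example above is written as a block of pre-formatted rows under a shared header, so the resulting spreadsheet can be read back with the standard csv module; the round-trip below runs on a toy in-memory file whose row contents are invented.

import csv
import io

# Toy stand-in for grammar.error.analysis.csv; the single data row is made up.
toy_csv = io.StringIO(
    'example_id, description, ground_truth, prediction, '
    'correct template, correct command\n'
    '42,list files larger than 10MB,find . -size +10M,find . -size 10M,n,n\n'
)
reader = csv.reader(toy_csv)
header = [column.strip() for column in next(reader)]
for row in reader:
    record = dict(zip(header, row))
    print(record['example_id'], record['prediction'])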
Example 3
def gen_manual_evaluation_csv_single_model(dataset, FLAGS):
    """
    Generate .csv spreadsheet for manual evaluation on dev/test set
    examples for a specific model.
    """
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)

    # Load model predictions
    model_subdir, decode_sig = graph_utils.get_decode_signature(FLAGS)
    model_dir = os.path.join(FLAGS.model_root_dir, model_subdir)
    prediction_list = load_predictions(model_dir, decode_sig, top_k=3)
    if len(grouped_dataset) != len(prediction_list):
        raise ValueError("ground truth list and prediction list length must "
                         "be equal: {} vs. {}".format(len(grouped_dataset),
                                                      len(prediction_list)))

    # Load additional ground truths
    template_translations, command_translations = load_cached_correct_translations(
        FLAGS.data_dir)

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = load_cached_evaluations(
        os.path.join(FLAGS.data_dir, 'manual_judgements'))

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash else data_tools.paren_parser

    output_path = os.path.join(model_dir, 'manual.evaluations.single.model')
    with open(output_path, 'w') as o_f:
        # write spreadsheet header
        o_f.write('id,description,command,correct template,correct command\n')
        for example_id in range(len(grouped_dataset)):
            data_group = grouped_dataset[example_id][1]
            sc_txt = data_group[0].sc_txt.strip()
            sc_key = get_example_nl_key(sc_txt)
            command_gts = [dp.tg_txt for dp in data_group]
            command_gts = set(command_gts + command_translations[sc_key])
            command_gt_asts = [
                data_tools.bash_parser(cmd) for cmd in command_gts
            ]
            template_gts = [
                data_tools.cmd2template(cmd, loose_constraints=True)
                for cmd in command_gts
            ]
            template_gts = set(template_gts + template_translations[sc_key])
            template_gt_asts = [
                data_tools.bash_parser(temp) for temp in template_gts
            ]
            predictions = prediction_list[example_id]
            for i in range(3):
                if i >= len(predictions):
                    o_f.write(',,,n,n\n')
                    continue
                pred_cmd = predictions[i]
                pred_tree = cmd_parser(pred_cmd)
                pred_temp = data_tools.ast2template(pred_tree,
                                                    loose_constraints=True)
                temp_match = tree_dist.one_match(template_gt_asts,
                                                 pred_tree,
                                                 ignore_arg_value=True)
                str_match = tree_dist.one_match(command_gt_asts,
                                                pred_tree,
                                                ignore_arg_value=False)
                # Match ground truths & existing judgements
                command_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_cmd)
                structure_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_temp)
                command_eval, structure_eval = '', ''
                if str_match:
                    command_eval = 'y'
                    structure_eval = 'y'
                elif temp_match:
                    structure_eval = 'y'
                if command_eval_cache and \
                        command_example_sig in command_eval_cache:
                    command_eval = command_eval_cache[command_example_sig]
                if structure_eval_cache and \
                        structure_example_sig in structure_eval_cache:
                    structure_eval = structure_eval_cache[
                        structure_example_sig]
                if i == 0:
                    o_f.write('{},"{}","{}",{},{}\n'.format(
                        example_id, sc_txt.replace('"', '""'),
                        pred_cmd.replace('"', '""'), structure_eval,
                        command_eval))
                else:
                    o_f.write(',,"{}",{},{}\n'.format(
                        pred_cmd.replace('"', '""'), structure_eval,
                        command_eval))
    print('manual evaluation spreadsheet saved to {}'.format(output_path))
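Cached manual judgements are looked up through a composite signature that joins the normalized natural-language key and the predicted string with the <NL_PREDICTION> separator; the lookup pattern can be tried on a toy cache as below (the keys and values are invented for illustration).

# Toy cache mimicking the shape of command_eval_cache; the real caches are
# returned by load_cached_evaluations and keyed by previously judged
# (query, prediction) pairs.
command_eval_cache = {
    'list all files including hidden ones<NL_PREDICTION>ls -a': 'y',
}

sc_key = 'list all files including hidden ones'
pred_cmd = 'ls -a'
command_example_sig = '{}<NL_PREDICTION>{}'.format(sc_key, pred_cmd)

# An empty string means no human judgement is available for this prediction.
command_eval = command_eval_cache.get(command_example_sig, '')
print(command_eval)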