Code example #1
File: eval_tools.py  Project: dhuruvasaditya/Nlc2cmd
def manual_eval(prediction_path,
                dataset,
                FLAGS,
                top_k,
                num_examples=-1,
                interactive=True,
                verbose=True):
    """
    Conduct dev/test set evaluation.

    Evaluation metrics:
        1) full command accuracy;
        2) command template accuracy. 

    :param interactive:
        - If set, prompt the user to enter judgement if a prediction does not
            match any of the groundtruths and the correctness of the prediction
            has not been pre-determined;
          Otherwise, all predictions that do not match any of the groundtruths are counted as wrong.
    """
    # Group dataset
    grouped_dataset = data_utils.group_parallel_data(dataset)

    # Load model prediction
    prediction_list = load_predictions(prediction_path, top_k)

    metrics = get_manual_evaluation_metrics(grouped_dataset,
                                            prediction_list,
                                            FLAGS,
                                            num_examples=num_examples,
                                            interactive=interactive,
                                            verbose=verbose)

    return metrics
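
The interactive fallback described in the docstring (a prediction that matches no ground truth counts as wrong unless a cached or newly entered judgement says otherwise) can be illustrated with a small standalone sketch. The helper below is illustrative only and is not part of eval_tools.py; the real code keys its judgement cache by an NL+prediction signature.

def judge_prediction(prediction, groundtruths, judgement_cache, interactive=False):
    """Return 'y' if the prediction is judged correct, 'n' otherwise.

    A prediction that matches no ground truth counts as wrong unless a cached
    judgement overrides it or, in interactive mode, the user supplies one.
    """
    if prediction in groundtruths:
        return 'y'
    if prediction in judgement_cache:
        return judgement_cache[prediction]
    if interactive:
        answer = input('CORRECT COMMAND? [y/n] ').strip() or 'n'
        judgement_cache[prediction] = answer
        return answer
    return 'n'

# Non-interactive mode: only exact ground-truth matches (or cached 'y's) count.
cache = {}
print(judge_prediction('find . -name "*.py"', {'find . -name "*.py"'}, cache))  # y
print(judge_prediction('ls -la', {'find . -name "*.py"'}, cache))               # n
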
Code example #2
File: eval_tools.py  Project: dhuruvasaditya/Nlc2cmd
def automatic_eval(prediction_path,
                   dataset,
                   FLAGS,
                   top_k,
                   num_samples=-1,
                   verbose=False):
    """
    Generate automatic evaluation metrics on dev/test set.
    The following metrics are computed:
        Top 1,3,5,10
            1. Structure accuracy
            2. Full command accuracy
            3. Command keyword overlap
            4. BLEU
    """
    grouped_dataset = data_utils.group_parallel_data(dataset)
    try:
        vocabs = data_utils.load_vocabulary(FLAGS)
    except ValueError:
        vocabs = None

    # Load predictions
    prediction_list = load_predictions(prediction_path, top_k)
    if len(grouped_dataset) != len(prediction_list):
        raise ValueError("ground truth and predictions length must be equal: "
                         "{} vs. {}".format(len(grouped_dataset),
                                            len(prediction_list)))

    metrics = get_automatic_evaluation_metrics(grouped_dataset,
                                               prediction_list, vocabs, FLAGS,
                                               top_k, num_samples, verbose)
    return metrics
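
As a rough illustration of the top-k metrics listed in the docstring, here is a self-contained sketch of top-k full-command accuracy over aligned ground-truth groups and ranked predictions. The real computation lives in get_automatic_evaluation_metrics; the data layout below is simplified.

def top_k_accuracy(groundtruth_groups, prediction_list, k):
    """Fraction of examples whose top-k predictions contain any ground truth.

    groundtruth_groups: list of sets of acceptable commands, one per example.
    prediction_list: list of ranked prediction lists, aligned with the above.
    """
    assert len(groundtruth_groups) == len(prediction_list)
    hits = sum(
        any(pred in gts for pred in preds[:k])
        for gts, preds in zip(groundtruth_groups, prediction_list))
    return hits / len(groundtruth_groups)

gts = [{'ls -l'}, {'grep -r foo .', 'grep foo -r .'}]
preds = [['ls -l', 'ls'], ['grep foo', 'grep -r foo .']]
print(top_k_accuracy(gts, preds, 1))  # 0.5
print(top_k_accuracy(gts, preds, 3))  # 1.0
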
Code example #3
def automatic_eval(model_dir,
                   decode_sig,
                   dataset,
                   FLAGS,
                   top_k,
                   num_samples=-1,
                   verbose=False):
    """
    Generate automatic evaluation metrics on a dev/test set.
    The following metrics are computed:
        Top 1,3,5,10
            1. Structure accuracy
            2. Full command accuracy
            3. Command keyword overlap
            4. BLEU
    """
    use_bucket = "knn" not in model_dir

    grouped_dataset = data_utils.group_parallel_data(dataset,
                                                     use_bucket=use_bucket)
    vocabs = data_utils.load_vocabulary(FLAGS)

    # Load predictions
    prediction_list = load_predictions(model_dir, decode_sig, top_k)
    if len(grouped_dataset) != len(prediction_list):
        raise ValueError("ground truth and predictions length must be equal: "
                         "{} vs. {}".format(len(grouped_dataset),
                                            len(prediction_list)))

    M = get_automatic_evaluation_metrics(grouped_dataset, prediction_list,
                                         vocabs, FLAGS, top_k, num_samples,
                                         verbose)
    return M
Code example #4
File: error_analysis.py  Project: syzer/nl2bash
def gen_error_analysis_csv_by_utility(model_dir,
                                      decode_sig,
                                      dataset,
                                      FLAGS,
                                      top_k=10):
    """
    Generate error analysis evaluation sheet grouped by utility.
    """
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)

    # Load model predictions
    prediction_list = load_predictions(model_dir, decode_sig, top_k)
    if len(grouped_dataset) != len(prediction_list):
        raise ValueError(
            "ground truth and predictions length must be equal: {} vs. {}".
            format(len(grouped_dataset), len(prediction_list)))

    # Load cached evaluation results
    cached_evaluation_results = load_cached_evaluations(model_dir)

    # Convert the predictions into csv format
    grammar_errors, argument_errors = print_error_analysis_csv(
        grouped_dataset,
        prediction_list,
        FLAGS,
        cached_evaluation_results,
        group_by_utility=True,
        error_predictions_only=False)

    error_by_utility_path = \
        os.path.join(model_dir, 'error.analysis.by.utility.csv')
    print("Saving grammar errors to {}".format(error_by_utility_path))
    with open(error_by_utility_path, 'w') as error_by_utility_file:
        # print csv file header
        error_by_utility_file.write(
            'utility, example_id, description, groundtruth, prediction, '
            'correct template, correct command\n')
        for line in bash.utility_stats.split('\n'):
            utility = line.split(',')[1]
            error_examples = grammar_errors[utility]
            if len(error_examples) <= 5:
                for example in error_examples:
                    for l in example:
                        error_by_utility_file.write('{},{}\n'.format(
                            utility, l))
            else:
                random.shuffle(error_examples)
                for example in error_examples[:5]:
                    for l in example:
                        error_by_utility_file.write('{},{}\n'.format(
                            utility, l))
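
The per-utility capping at the end of the function (write all error examples if there are at most five, otherwise a random five) can also be expressed with random.sample. A minimal sketch, not part of error_analysis.py, that keeps the original behaviour of leaving small groups unshuffled:

import random

def sample_errors_per_utility(errors_by_utility, cap=5, seed=None):
    """Keep at most `cap` error examples per utility (all of them if fewer)."""
    rng = random.Random(seed)
    sampled = {}
    for utility, examples in errors_by_utility.items():
        if len(examples) <= cap:
            sampled[utility] = list(examples)      # small groups kept as-is
        else:
            sampled[utility] = rng.sample(examples, cap)
    return sampled

errors = {'find': ['e{}'.format(i) for i in range(8)], 'tar': ['e0', 'e1']}
print({u: len(v) for u, v in sample_errors_per_utility(errors, seed=0).items()})
# {'find': 5, 'tar': 2}
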
Code example #5
File: error_analysis.py  Project: syzer/nl2bash
def gen_error_analysis_csv(model_dir, decode_sig, dataset, FLAGS, top_k=3):
    """
    Generate error analysis evaluation spreadsheet.
        - grammar error analysis
        - argument error analysis
    """
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)

    # Load model predictions
    prediction_list = load_predictions(model_dir, decode_sig, top_k)
    if len(grouped_dataset) != len(prediction_list):
        raise ValueError("ground truth and predictions length must be equal: "
                         "{} vs. {}".format(len(grouped_dataset),
                                            len(prediction_list)))

    # Convert the predictions to csv format
    grammar_errors, argument_errors = print_error_analysis_csv(
        grouped_dataset, prediction_list, FLAGS)

    grammar_error_path = os.path.join(model_dir, 'grammar.error.analysis.csv')
    random.shuffle(grammar_errors)
    with open(grammar_error_path, 'w') as grammar_error_file:
        print("Saving grammar errors to {}".format(grammar_error_path))
        # print csv file header
        grammar_error_file.write(
            'example_id, description, ground_truth, prediction, ' +
            'correct template, correct command\n')
        for example in grammar_errors[:100]:
            for line in example:
                grammar_error_file.write('{}\n'.format(line))

    arg_error_path = os.path.join(model_dir, 'argument.error.analysis.csv')
    random.shuffle(argument_errors)
    with open(arg_error_path, 'w') as arg_error_file:
        print("Saving argument errors to {}".format(arg_error_path))
        # print csv file header
        arg_error_file.write(
            'example_id, description, ground_truth, prediction, ' +
            'correct template, correct command\n')
        for example in argument_errors[:100]:
            for line in example:
                arg_error_file.write('{}\n'.format(line))
Code example #6
File: eval_tools.py  Project: stanstarks/nl2bash
def gen_automatic_evaluation_table(dataset, FLAGS):
    # Group dataset
    grouped_dataset = data_utils.group_parallel_data(dataset, use_bucket=True)
    vocabs = data_utils.load_vocabulary(FLAGS)

    model_names, model_predictions = load_all_model_predictions(
        grouped_dataset, FLAGS, top_k=3)

    auto_evaluation_metrics = {}
    for model_id, model_name in enumerate(model_names):
        prediction_list = model_predictions[model_id]
        M = get_automatic_evaluation_metrics(
            grouped_dataset, prediction_list, vocabs, FLAGS, top_k=3)
        auto_evaluation_metrics[model_name] = \
            [M['top_bleu'][0], M['top_bleu'][1], M['top_cms'][0], M['top_cms'][1]]

    metrics_names = ['BLEU1', 'BLEU3', 'TM1', 'TM3']
    print_table(model_names, metrics_names, auto_evaluation_metrics)
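
For reference, here is a self-contained sketch of the kind of per-model metrics table that print_table receives above. print_table itself is defined elsewhere in the project; the helper name and the metric values below are made-up placeholders.

def print_metrics_table(model_names, metrics_names, metrics_by_model):
    """Print one row per model; metrics_by_model maps model name -> list of floats."""
    print('{:<16}'.format('Model') +
          ''.join('{:>8}'.format(m) for m in metrics_names))
    for name in model_names:
        print('{:<16}'.format(name) +
              ''.join('{:>8.3f}'.format(v) for v in metrics_by_model[name]))

print_metrics_table(
    ['tellina', 'token-copynet'],
    ['BLEU1', 'BLEU3', 'TM1', 'TM3'],
    {'tellina': [0.52, 0.60, 0.31, 0.42],
     'token-copynet': [0.49, 0.58, 0.29, 0.40]})
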
Code example #7
File: eval_tools.py  Project: dhuruvasaditya/Nlc2cmd
def gen_manual_evaluation_table(dataset,
                                FLAGS,
                                num_examples=-1,
                                interactive=True):
    """
    Conduct dev/test set evaluation. The results of multiple pre-specified models are tabulated in the same table.

    Evaluation metrics:
        1) full command accuracy;
        2) command template accuracy.
        
    :param interactive:
        - If set, prompt the user to enter judgement if a prediction does not
            match any of the groundtruths and the correctness of the prediction
            has not been pre-determined;
          Otherwise, all predictions that do not match any of the groundtruths are counted as wrong.
    """
    # Group dataset
    grouped_dataset = data_utils.group_parallel_data(dataset)

    # Load all model predictions
    model_names, model_predictions = load_all_model_predictions(
        grouped_dataset, FLAGS, top_k=3)

    manual_eval_metrics = {}
    for model_id, model_name in enumerate(model_names):
        prediction_list = model_predictions[model_id]
        M = get_manual_evaluation_metrics(grouped_dataset,
                                          prediction_list,
                                          FLAGS,
                                          num_examples=num_examples,
                                          interactive=interactive,
                                          verbose=False)
        manual_eval_metrics[model_name] = [
            M['acc_f'][0], M['acc_f'][1], M['acc_t'][0], M['acc_t'][1]
        ]

    metrics_names = ['Acc_F_1', 'Acc_F_3', 'Acc_T_1', 'Acc_T_3']
    print_eval_table(model_names, metrics_names, manual_eval_metrics)
Code example #8
File: eval_tools.py  Project: dhuruvasaditya/Nlc2cmd
def gen_automatic_evaluation_table(dataset, FLAGS):
    # Group dataset
    grouped_dataset = data_utils.group_parallel_data(dataset)
    vocabs = data_utils.load_vocabulary(FLAGS)

    model_names, model_predictions = load_all_model_predictions(
        grouped_dataset, FLAGS, top_k=3)
    auto_eval_metrics = {}
    for model_id, model_name in enumerate(model_names):
        prediction_list = model_predictions[model_id]
        if prediction_list is not None:
            M = get_automatic_evaluation_metrics(grouped_dataset,
                                                 prediction_list,
                                                 vocabs,
                                                 FLAGS,
                                                 top_k=3)
            auto_eval_metrics[model_name] = [
                M['bleu'][0], M['bleu'][1], M['cms'][0], M['cms'][1]
            ]
        else:
            print('Model {} skipped in evaluation'.format(model_name))
    metrics_names = ['BLEU1', 'BLEU3', 'TM1', 'TM3']
    print_eval_table(model_names, metrics_names, auto_eval_metrics)
Code example #9
File: decode_tools.py  Project: hpplinux/nl2bash
def decode_set(sess, model, dataset, top_k, FLAGS, verbose=False):
    """
    Compute top-k predictions on the dev/test dataset and write the predictions
    to disk.

    :param sess: A TensorFlow session.
    :param model: Prediction model object.
    :param top_k: Number of top predictions to compute.
    :param FLAGS: Training/testing hyperparameter settings.
    :param verbose: If set, also print decoding results to screen.
    """
    nl2bash = FLAGS.dataset.startswith('bash') and not FLAGS.explain

    tokenizer_selector = 'cm' if FLAGS.explain else 'nl'
    grouped_dataset = data_utils.group_parallel_data(
        dataset, tokenizer_selector=tokenizer_selector)
    vocabs = data_utils.load_vocabulary(FLAGS)
    rev_sc_vocab = vocabs.rev_sc_vocab

    ts = datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d-%H%M%S')
    pred_file_path = os.path.join(model.model_dir, 'predictions.{}.{}'.format(
        model.decode_sig, ts))
    pred_file = open(pred_file_path, 'w')
    eval_file_path = os.path.join(model.model_dir, 'predictions.{}.{}.csv'.format(
        model.decode_sig, ts))
    eval_file = open(eval_file_path, 'w')
    eval_file.write('example_id, description, ground_truth, prediction, ' +
                    'correct template, correct command\n')
    for example_id in xrange(len(grouped_dataset)):
        key, data_group = grouped_dataset[example_id]

        sc_txt = data_group[0].sc_txt.strip()
        sc_tokens = [rev_sc_vocab[i] for i in data_group[0].sc_ids]
        if FLAGS.channel == 'char':
            sc_temp = ''.join(sc_tokens)
            sc_temp = sc_temp.replace(constants._SPACE, ' ')
        else:
            sc_temp = ' '.join(sc_tokens)
        tg_txts = [dp.tg_txt for dp in data_group]
        tg_asts = [data_tools.bash_parser(tg_txt) for tg_txt in tg_txts]
        if verbose:
            print('\nExample {}:'.format(example_id))
            print('Original Source: {}'.format(sc_txt.encode('utf-8')))
            print('Source: {}'.format(sc_temp.encode('utf-8')))
            for j in xrange(len(data_group)):
                print('GT Target {}: {}'.format(j+1, data_group[j].tg_txt.encode('utf-8')))

        if FLAGS.fill_argument_slots:
            slot_filling_classifier = get_slot_filling_classifer(FLAGS)
            batch_outputs, sequence_logits = translate_fun(data_group, sess, model,
                vocabs, FLAGS, slot_filling_classifier=slot_filling_classifier)
        else:
            batch_outputs, sequence_logits = translate_fun(data_group, sess, model,
                vocabs, FLAGS)
        if FLAGS.tg_char:
            batch_outputs, batch_char_outputs = batch_outputs

        eval_row = '{},"{}",'.format(example_id, sc_txt.replace('"', '""'))
        if batch_outputs:
            if FLAGS.token_decoding_algorithm == 'greedy':
                tree, pred_cmd = batch_outputs[0]
                if nl2bash:
                    pred_cmd = data_tools.ast2command(
                        tree, loose_constraints=True)
                score = sequence_logits[0]
                if verbose:
                    print('Prediction: {} ({})'.format(pred_cmd, score))
                pred_file.write('{}\n'.format(pred_cmd))
            elif FLAGS.token_decoding_algorithm == 'beam_search':
                top_k_predictions = batch_outputs[0]
                if FLAGS.tg_char:
                    top_k_char_predictions = batch_char_outputs[0]
                top_k_scores = sequence_logits[0]
                num_preds = min(FLAGS.beam_size, top_k, len(top_k_predictions))
                for j in xrange(num_preds):
                    if j > 0:
                        eval_row = ',,'
                    if j < len(tg_txts):
                        eval_row += '"{}",'.format(tg_txts[j].strip().replace('"', '""'))
                    else:
                        eval_row += ','
                    top_k_pred_tree, top_k_pred_cmd = top_k_predictions[j]
                    if nl2bash:
                        pred_cmd = data_tools.ast2command(
                            top_k_pred_tree, loose_constraints=True)
                    else:
                        pred_cmd = top_k_pred_cmd
                    pred_file.write('{}|||'.format(pred_cmd.encode('utf-8')))
                    eval_row += '"{}",'.format(pred_cmd.replace('"', '""'))
                    temp_match = tree_dist.one_match(
                        tg_asts, top_k_pred_tree, ignore_arg_value=True)
                    str_match = tree_dist.one_match(
                        tg_asts, top_k_pred_tree, ignore_arg_value=False)
                    if temp_match:
                        eval_row += 'y,'
                    if str_match:
                        eval_row += 'y'
                    eval_file.write('{}\n'.format(eval_row.encode('utf-8')))
                    if verbose:
                        print('Prediction {}: {} ({})'.format(
                            j+1, pred_cmd.encode('utf-8'), top_k_scores[j]))
                        if FLAGS.tg_char:
                            print('Character-based prediction {}: {}'.format(
                                j+1, top_k_char_predictions[j].encode('utf-8')))
                pred_file.write('\n')
        else:
            print(APOLOGY_MSG)
            pred_file.write('\n')
            eval_file.write('{}\n'.format(eval_row))
            eval_file.write('\n')
            eval_file.write('\n')
    pred_file.close()
    eval_file.close()
    shutil.copyfile(pred_file_path, os.path.join(FLAGS.model_dir,
        'predictions.{}.latest'.format(model.decode_sig)))
    shutil.copyfile(eval_file_path, os.path.join(FLAGS.model_dir,
        'predictions.{}.latest.csv'.format(model.decode_sig)))
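
decode_set escapes CSV fields by hand: each field is wrapped in quotes and embedded quotes are doubled. A short illustrative sketch showing that this convention matches what the standard csv module produces:

import csv
import io

def csv_field(text):
    """Escape a field the way decode_set does: quote it and double embedded quotes."""
    return '"{}"'.format(text.replace('"', '""'))

row = [0, 'find files named "*.log"', 'find . -name "*.log"']
hand_rolled = '{},{},{}'.format(row[0], csv_field(row[1]), csv_field(row[2]))

# The csv module applies the same quoting rule and also handles commas/newlines.
buf = io.StringIO()
csv.writer(buf, quoting=csv.QUOTE_NONNUMERIC).writerow(row)
print(hand_rolled)
print(buf.getvalue().strip())
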
Code example #10
File: error_analysis.py  Project: syzer/nl2bash
def gen_manual_evaluation_csv_single_model(dataset, FLAGS):
    """
    Generate .csv spreadsheet for manual evaluation on dev/test set
    examples for a specific model.
    """
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)

    # Load model predictions
    model_subdir, decode_sig = graph_utils.get_decode_signature(FLAGS)
    model_dir = os.path.join(FLAGS.model_root_dir, model_subdir)
    prediction_list = load_predictions(model_dir, decode_sig, top_k=3)
    if len(grouped_dataset) != len(prediction_list):
        raise ValueError("ground truth list and prediction list length must "
                         "be equal: {} vs. {}".format(len(grouped_dataset),
                                                      len(prediction_list)))

    # Load additional ground truths
    template_translations, command_translations = load_cached_correct_translations(
        FLAGS.data_dir)

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = load_cached_evaluations(
        os.path.join(FLAGS.data_dir, 'manual_judgements'))

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash else data_tools.paren_parser

    output_path = os.path.join(model_dir, 'manual.evaluations.single.model')
    with open(output_path, 'w') as o_f:
        # write spreadsheet header
        o_f.write('id,description,command,correct template,correct command\n')
        for example_id in range(len(grouped_dataset)):
            data_group = grouped_dataset[example_id][1]
            sc_txt = data_group[0].sc_txt.strip()
            sc_key = get_example_nl_key(sc_txt)
            command_gts = [dp.tg_txt for dp in data_group]
            command_gts = set(command_gts + command_translations[sc_key])
            command_gt_asts = [
                data_tools.bash_parser(cmd) for cmd in command_gts
            ]
            template_gts = [
                data_tools.cmd2template(cmd, loose_constraints=True)
                for cmd in command_gts
            ]
            template_gts = set(template_gts + template_translations[sc_key])
            template_gt_asts = [
                data_tools.bash_parser(temp) for temp in template_gts
            ]
            predictions = prediction_list[example_id]
            for i in xrange(3):
                if i >= len(predictions):
                    o_f.write(',,,n,n\n')
                    continue
                pred_cmd = predictions[i]
                pred_tree = cmd_parser(pred_cmd)
                pred_temp = data_tools.ast2template(pred_tree,
                                                    loose_constraints=True)
                temp_match = tree_dist.one_match(template_gt_asts,
                                                 pred_tree,
                                                 ignore_arg_value=True)
                str_match = tree_dist.one_match(command_gt_asts,
                                                pred_tree,
                                                ignore_arg_value=False)
                # Match ground truths & existing judgements
                command_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_cmd)
                structure_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_temp)
                command_eval, structure_eval = '', ''
                if str_match:
                    command_eval = 'y'
                    structure_eval = 'y'
                elif temp_match:
                    structure_eval = 'y'
                if command_eval_cache and \
                        command_example_sig in command_eval_cache:
                    command_eval = command_eval_cache[command_example_sig]
                if structure_eval_cache and \
                        structure_example_sig in structure_eval_cache:
                    structure_eval = structure_eval_cache[
                        structure_example_sig]
                if i == 0:
                    o_f.write('{},"{}","{}",{},{}\n'.format(
                        example_id, sc_txt.replace('"', '""'),
                        pred_cmd.replace('"', '""'), structure_eval,
                        command_eval))
                else:
                    o_f.write(',,"{}",{},{}\n'.format(
                        pred_cmd.replace('"', '""'), structure_eval,
                        command_eval))
    print('manual evaluation spreadsheet saved to {}'.format(output_path))
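
The correctness labels above are resolved in two steps: an automatic tree match first, then a lookup in the judgement cache keyed by '<nl><NL_PREDICTION><prediction>'. A minimal standalone sketch of that lookup (the helper name is illustrative):

def lookup_judgement(nl_key, prediction, auto_match, eval_cache):
    """Resolve a correctness label: automatic match first, then cached judgement."""
    if auto_match:
        return 'y'
    signature = '{}<NL_PREDICTION>{}'.format(nl_key, prediction)
    return eval_cache.get(signature, '')

cache = {'list all files<NL_PREDICTION>ls -a': 'y'}
print(lookup_judgement('list all files', 'ls -a', False, cache))  # 'y' (cached)
print(lookup_judgement('list all files', 'ls -z', False, cache))  # '' (no judgement yet)
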
Code example #11
File: error_analysis.py  Project: syzer/nl2bash
def tabulate_example_predictions(dataset, FLAGS, num_examples=100):
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)

    model_names, model_predictions = load_all_model_predictions(
        grouped_dataset, FLAGS, top_k=1)

    # Get FIXED dev set samples
    random.seed(100)
    example_ids = list(range(len(grouped_dataset)))
    random.shuffle(example_ids)
    sample_ids = example_ids[:num_examples]

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = \
        load_cached_evaluations(
            os.path.join(FLAGS.data_dir, 'manual_judgements'))

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash \
        else data_tools.paren_parser

    model_name_pt = {
        'token-seq2seq': 'T-Seq2Seq',
        'tellina': 'Tellina',
        'token-copynet': 'T-CopyNet',
        'partial.token-seq2seq': 'ST-Seq2Seq',
        'partial.token-copynet': 'ST-CopyNet',
        'char-seq2seq': 'C-Seq2Seq',
        'char-copynet': 'C-CopyNet'
    }

    for example_id in sample_ids:
        print('Example {}'.format(example_id))
        data_group = grouped_dataset[example_id][1]
        sc_txt = data_group[0].sc_txt.strip()
        sc_key = get_example_nl_key(sc_txt)
        command_gts = [dp.tg_txt for dp in data_group]
        command_gt_asts = [data_tools.bash_parser(gt) for gt in command_gts]
        output_strs = {}
        for model_id, model_name in enumerate(model_names):
            predictions = model_predictions[model_id][example_id]
            for i in xrange(min(3, len(predictions))):
                pred_cmd = predictions[i]
                pred_tree = cmd_parser(pred_cmd)
                pred_temp = data_tools.ast2template(pred_tree,
                                                    loose_constraints=True)
                temp_match = tree_dist.one_match(command_gt_asts,
                                                 pred_tree,
                                                 ignore_arg_value=True)
                str_match = tree_dist.one_match(command_gt_asts,
                                                pred_tree,
                                                ignore_arg_value=False)

                output_str = '& \\<{}> & {}'.format(
                    pred_cmd.replace('__SP__', '').replace('_', '\\_').replace(
                        '$', '\\$').replace('%',
                                            '\\%').replace('{{}}', '\\ttcbs'),
                    model_name_pt[model_name])

                command_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_cmd)
                structure_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_temp)
                command_eval, structure_eval = '', ''
                if str_match:
                    command_eval = 'y'
                    structure_eval = 'y'
                elif temp_match:
                    structure_eval = 'y'
                if command_eval_cache and \
                        command_example_sig in command_eval_cache:
                    command_eval = command_eval_cache[command_example_sig]
                if structure_eval_cache and \
                        structure_example_sig in structure_eval_cache:
                    structure_eval = structure_eval_cache[
                        structure_example_sig]
                output_str += ', {},{} \\\\'.format(structure_eval,
                                                    command_eval)
            output_strs[model_name] = output_str
        for model_name in [
                'char-seq2seq', 'char-copynet', 'token-seq2seq',
                'token-copynet', 'partial.token-seq2seq',
                'partial.token-copynet', 'tellina'
        ]:
            if model_name == 'char-seq2seq':
                print('\\multirow{{7}}{{*}}{{\\specialcell{{{}}}}} '.format(
                    sc_txt) + output_strs[model_name])
            else:
                print(output_strs[model_name])
        output_str = '& \\<{}> & Human \\\\'.format(command_gts[0].replace(
            '__SP__', '').replace('_', '\\_').replace('$', '\\$').replace(
                '%', '\\%').replace('{{}}', '\\ttcbs'))
        print(output_str)
        print()
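
The chained .replace() calls that prepare a command for the LaTeX table can be factored into one small helper. A sketch under the same assumptions (\ttcbs is the macro name used above, __SP__ is the sub-token separator being stripped):

def latex_escape_command(cmd):
    """Apply the same substitutions as the chained .replace() calls above."""
    return (cmd.replace('__SP__', '')
               .replace('_', '\\_')
               .replace('$', '\\$')
               .replace('%', '\\%')
               .replace('{{}}', '\\ttcbs'))

print(latex_escape_command("awk '{print $1}' file_name"))
# awk '{print \$1}' file\_name
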
Code example #12
File: error_analysis.py  Project: syzer/nl2bash
def gen_manual_evaluation_csv(dataset, FLAGS, num_examples=100):
    """
    Generate a .csv spreadsheet for manual evaluation on a fixed set of test/dev
    examples; the predictions of different models are listed side by side.
    """
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)

    model_names, model_predictions = load_all_model_predictions(
        grouped_dataset, FLAGS, top_k=3)

    # Get FIXED dev set samples
    random.seed(100)
    example_ids = list(range(len(grouped_dataset)))
    random.shuffle(example_ids)
    sample_ids = example_ids[num_examples:num_examples + 100]

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = \
        load_cached_evaluations(
            os.path.join(FLAGS.data_dir, 'manual_judgements'))

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash \
        else data_tools.paren_parser

    output_path = os.path.join(FLAGS.data_dir, 'manual.evaluations.csv')
    with open(output_path, 'w') as o_f:
        o_f.write('example_id, description, ground_truth, model, prediction, '
                  'correct template, correct command\n')
        for example_id in sample_ids:
            data_group = grouped_dataset[example_id][1]
            sc_txt = data_group[0].sc_txt.strip()
            sc_key = get_example_nl_key(sc_txt)
            command_gts = [dp.tg_txt for dp in data_group]
            command_gt_asts = [
                data_tools.bash_parser(gt) for gt in command_gts
            ]
            for model_id, model_name in enumerate(model_names):
                predictions = model_predictions[model_id][example_id]
                for i in xrange(min(3, len(predictions))):
                    if model_id == 0 and i == 0:
                        output_str = '{},"{}",'.format(
                            example_id, sc_txt.replace('"', '""'))
                    else:
                        output_str = ',,'
                    pred_cmd = predictions[i]
                    pred_tree = cmd_parser(pred_cmd)
                    pred_temp = data_tools.ast2template(pred_tree,
                                                        loose_constraints=True)
                    temp_match = tree_dist.one_match(command_gt_asts,
                                                     pred_tree,
                                                     ignore_arg_value=True)
                    str_match = tree_dist.one_match(command_gt_asts,
                                                    pred_tree,
                                                    ignore_arg_value=False)
                    if (model_id * min(3, len(predictions)) +
                            i) < len(command_gts):
                        output_str += '"{}",'.format(
                            command_gts[model_id * min(3, len(predictions)) +
                                        i].strip().replace('"', '""'))
                    else:
                        output_str += ','
                    output_str += '{},"{}",'.format(
                        model_name, pred_cmd.replace('"', '""'))

                    command_example_sig = '{}<NL_PREDICTION>{}'.format(
                        sc_key, pred_cmd)
                    structure_example_sig = '{}<NL_PREDICTION>{}'.format(
                        sc_key, pred_temp)
                    command_eval, structure_eval = '', ''
                    if str_match:
                        command_eval = 'y'
                        structure_eval = 'y'
                    elif temp_match:
                        structure_eval = 'y'
                    if command_eval_cache and \
                            command_example_sig in command_eval_cache:
                        command_eval = command_eval_cache[command_example_sig]
                    if structure_eval_cache and \
                            structure_example_sig in structure_eval_cache:
                        structure_eval = structure_eval_cache[
                            structure_example_sig]
                    output_str += '{},{}'.format(structure_eval, command_eval)
                    o_f.write('{}\n'.format(output_str))

    print('Manual evaluation results saved to {}'.format(output_path))
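
Seeding the random module with a fixed value before shuffling is what keeps the evaluated sample stable across runs and across models. A minimal sketch of that idiom (illustrative, not part of error_analysis.py):

import random

def fixed_sample_ids(num_total, start, count, seed=100):
    """Reproducibly pick a window of example ids: shuffle with a fixed seed, then slice."""
    random.seed(seed)
    ids = list(range(num_total))
    random.shuffle(ids)
    return ids[start:start + count]

# The same seed yields the same sample every run, so manual judgements can be
# reused across runs and shared across models.
print(fixed_sample_ids(1000, 0, 5))
print(fixed_sample_ids(1000, 0, 5))  # identical to the line above
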
Code example #13
def gen_evaluation_table(dataset, FLAGS, num_examples=-1, interactive=True):
    """
    Generate structure and full command accuracy results on a fixed set of dev/test
        set samples, with the results of multiple models tabulated in the same table.
    :param interactive:
        - If set, prompt the user to enter judgement if a prediction does not
            match any of the groundtruths and the correctness of the prediction
            has not been pre-determined;
          Otherwise, count all predictions that do not match any of the
          groundtruths as wrong.
    """
    def add_judgement(data_dir,
                      nl,
                      command,
                      correct_template='',
                      correct_command=''):
        """
        Append a new judgement
        """
        data_dir = os.path.join(data_dir, 'manual_judgements')
        manual_judgement_path = os.path.join(data_dir,
                                             'manual.evaluations.additional')
        if not os.path.exists(manual_judgement_path):
            with open(manual_judgement_path, 'w') as o_f:
                o_f.write(
                    'description,prediction,template,correct template,correct command\n'
                )
        with open(manual_judgement_path, 'a') as o_f:
            temp = data_tools.cmd2template(command, loose_constraints=True)
            if not correct_template:
                correct_template = 'n'
            if not correct_command:
                correct_command = 'n'
            o_f.write('"{}","{}","{}","{}","{}"\n'.format(
                nl.replace('"', '""'), command.replace('"', '""'),
                temp.replace('"', '""'), correct_template.replace('"', '""'),
                correct_command.replace('"', '""')))
        print('new judgement added to {}'.format(manual_judgement_path))

    # Group dataset
    grouped_dataset = data_utils.group_parallel_data(dataset, use_bucket=True)

    if FLAGS.test:
        model_names, model_predictions = load_all_model_predictions(
            grouped_dataset,
            FLAGS,
            top_k=3,
            tellina=True,
            partial_token_copynet=True,
            token_seq2seq=False,
            token_copynet=False,
            char_seq2seq=False,
            char_copynet=False,
            partial_token_seq2seq=False)
    else:
        model_names, model_predictions = load_all_model_predictions(
            grouped_dataset, FLAGS, top_k=3)

    # Get FIXED dev set samples
    random.seed(100)
    example_ids = list(range(len(grouped_dataset)))
    random.shuffle(example_ids)
    if num_examples > 0:
        sample_ids = example_ids[:num_examples]
    else:
        sample_ids = example_ids

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = \
        load_cached_evaluations(
            os.path.join(FLAGS.data_dir, 'manual_judgements'))

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash \
        else data_tools.paren_parser

    # Interactive manual evaluation interface
    num_s_correct = collections.defaultdict(int)
    num_f_correct = collections.defaultdict(int)
    num_s_top_3_correct = collections.defaultdict(int)
    num_f_top_3_correct = collections.defaultdict(int)
    for exam_id, example_id in enumerate(sample_ids):
        data_group = grouped_dataset[example_id][1]
        sc_txt = data_group[0].sc_txt.strip()
        sc_key = get_example_nl_key(sc_txt)
        command_gts = [dp.tg_txt for dp in data_group]
        command_gt_asts = [data_tools.bash_parser(gt) for gt in command_gts]
        for model_id, model_name in enumerate(model_names):
            predictions = model_predictions[model_id][example_id]
            top_3_s_correct_marked = False
            top_3_f_correct_marked = False
            for i in xrange(min(3, len(predictions))):
                pred_cmd = predictions[i]
                pred_ast = cmd_parser(pred_cmd)
                pred_temp = data_tools.ast2template(pred_ast,
                                                    loose_constraints=True)
                temp_match = tree_dist.one_match(command_gt_asts,
                                                 pred_ast,
                                                 ignore_arg_value=True)
                str_match = tree_dist.one_match(command_gt_asts,
                                                pred_ast,
                                                ignore_arg_value=False)
                # Match ground truths & existing judgements
                command_example_key = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_cmd)
                structure_example_key = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_temp)
                command_eval, structure_eval = '', ''
                if str_match:
                    command_eval = 'y'
                    structure_eval = 'y'
                elif temp_match:
                    structure_eval = 'y'
                if command_eval_cache and command_example_key in command_eval_cache:
                    command_eval = command_eval_cache[command_example_key]
                if structure_eval_cache and structure_example_key in structure_eval_cache:
                    structure_eval = structure_eval_cache[
                        structure_example_key]
                # Prompt for new judgements
                if command_eval != 'y':
                    if structure_eval == 'y':
                        if not command_eval and interactive:
                            print('#{}. {}'.format(exam_id, sc_txt))
                            for j, gt in enumerate(command_gts):
                                print('- GT{}: {}'.format(j, gt))
                            print('> {}'.format(pred_cmd))
                            command_eval = input(
                                'CORRECT COMMAND? [y/reason] ')
                            add_judgement(FLAGS.data_dir, sc_key, pred_cmd,
                                          structure_eval, command_eval)
                            print()
                    else:
                        if not structure_eval and interactive:
                            print('#{}. {}'.format(exam_id, sc_txt))
                            for j, gt in enumerate(command_gts):
                                print('- GT{}: {}'.format(j, gt))
                            print('> {}'.format(pred_cmd))
                            structure_eval = input(
                                'CORRECT STRUCTURE? [y/reason] ')
                            if structure_eval == 'y':
                                command_eval = input(
                                    'CORRECT COMMAND? [y/reason] ')
                            add_judgement(FLAGS.data_dir, sc_key, pred_cmd,
                                          structure_eval, command_eval)
                            print()
                    structure_eval_cache[
                        structure_example_key] = structure_eval
                    command_eval_cache[command_example_key] = command_eval
                if structure_eval == 'y':
                    if i == 0:
                        num_s_correct[model_name] += 1
                    if not top_3_s_correct_marked:
                        num_s_top_3_correct[model_name] += 1
                        top_3_s_correct_marked = True
                if command_eval == 'y':
                    if i == 0:
                        num_f_correct[model_name] += 1
                    if not top_3_f_correct_marked:
                        num_f_top_3_correct[model_name] += 1
                        top_3_f_correct_marked = True
    metrics_names = ['Acc_F_1', 'Acc_F_3', 'Acc_T_1', 'Acc_T_3']
    model_metrics = {}
    for model_name in model_names:
        metrics = [
            num_f_correct[model_name] / len(sample_ids),
            num_f_top_3_correct[model_name] / len(sample_ids),
            num_s_correct[model_name] / len(sample_ids),
            num_s_top_3_correct[model_name] / len(sample_ids)
        ]
        model_metrics[model_name] = metrics
    print_table(model_names, metrics_names, model_metrics)
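
The top-1/top-3 tallying in the loop above reduces to checking, for each example, whether the first prediction or any of the first three was judged correct. A standalone sketch of that reduction (not tied to the defaultdict counters used above):

def tally_accuracy(judgements, num_examples):
    """Compute (top-1, top-3) accuracy from per-example judgement lists.

    judgements: one list per evaluated example, e.g. ['n', 'y', 'n'], with one
    'y'/'n' label per ranked prediction.
    """
    top1 = sum(1 for j in judgements if j and j[0] == 'y')
    top3 = sum(1 for j in judgements if 'y' in j[:3])
    return top1 / float(num_examples), top3 / float(num_examples)

judgements = [['y', 'n', 'n'], ['n', 'y', 'n'], ['n', 'n', 'n'], ['n']]
print(tally_accuracy(judgements, len(judgements)))  # (0.25, 0.5)
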