Example #1
def get_manual_evaluation_metrics(grouped_dataset,
                                  prediction_list,
                                  FLAGS,
                                  num_examples=-1,
                                  interactive=True,
                                  verbose=True):

    if len(grouped_dataset) != len(prediction_list):
        raise ValueError("ground truth and predictions length must be equal: "
                         "{} vs. {}".format(len(grouped_dataset),
                                            len(prediction_list)))

    # Get dev set samples (fixed)
    random.seed(100)
    example_ids = list(range(len(grouped_dataset)))
    random.shuffle(example_ids)
    if num_examples > 0:
        sample_ids = example_ids[:num_examples]
    else:
        sample_ids = example_ids

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = \
        load_cached_evaluations(
            os.path.join(FLAGS.data_dir, 'manual_judgements'), verbose=True)

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash \
        else data_tools.paren_parser

    # Interactive manual evaluation
    num_t_top_1_correct = 0.0
    num_f_top_1_correct = 0.0
    num_t_top_3_correct = 0.0
    num_f_top_3_correct = 0.0

    for exam_id, example_id in enumerate(sample_ids):
        data_group = grouped_dataset[example_id][1]
        sc_txt = data_group[0].sc_txt.strip()
        sc_key = get_example_nl_key(sc_txt)
        command_gts = [dp.tg_txt for dp in data_group]
        command_gt_asts = [data_tools.bash_parser(gt) for gt in command_gts]
        predictions = prediction_list[example_id]
        top_3_s_correct_marked = False
        top_3_f_correct_marked = False
        for i in xrange(min(3, len(predictions))):
            pred_cmd = predictions[i]
            pred_ast = cmd_parser(pred_cmd)
            pred_temp = data_tools.ast2template(pred_ast,
                                                loose_constraints=True)
            temp_match = tree_dist.one_match(command_gt_asts,
                                             pred_ast,
                                             ignore_arg_value=True)
            str_match = tree_dist.one_match(command_gt_asts,
                                            pred_ast,
                                            ignore_arg_value=False)
            # Match ground truths & existing judgements
            command_example_key = '{}<NL_PREDICTION>{}'.format(
                sc_key, pred_cmd)
            structure_example_key = '{}<NL_PREDICTION>{}'.format(
                sc_key, pred_temp)
            command_eval, structure_eval = '', ''
            if str_match:
                command_eval = 'y'
                structure_eval = 'y'
            elif temp_match:
                structure_eval = 'y'
            if command_eval_cache and command_example_key in command_eval_cache:
                command_eval = command_eval_cache[command_example_key]
            if structure_eval_cache and structure_example_key in structure_eval_cache:
                structure_eval = structure_eval_cache[structure_example_key]
            # Prompt for new judgements
            if command_eval != 'y':
                if structure_eval == 'y':
                    if not command_eval and interactive:
                        print('#{}. {}'.format(exam_id, sc_txt))
                        for j, gt in enumerate(command_gts):
                            print('- GT{}: {}'.format(j, gt))
                        print('> {}'.format(pred_cmd))
                        command_eval = input('CORRECT COMMAND? [y/reason] ')
                        add_judgement(FLAGS.data_dir, sc_txt, pred_cmd,
                                      structure_eval, command_eval)
                        print()
                else:
                    if not structure_eval and interactive:
                        print('#{}. {}'.format(exam_id, sc_txt))
                        for j, gt in enumerate(command_gts):
                            print('- GT{}: {}'.format(j, gt))
                        print('> {}'.format(pred_cmd))
                        structure_eval = input(
                            'CORRECT STRUCTURE? [y/reason] ')
                        if structure_eval == 'y':
                            command_eval = input(
                                'CORRECT COMMAND? [y/reason] ')
                        add_judgement(FLAGS.data_dir, sc_txt, pred_cmd,
                                      structure_eval, command_eval)
                        print()
                structure_eval_cache[structure_example_key] = structure_eval
                command_eval_cache[command_example_key] = command_eval
            if structure_eval == 'y':
                if i == 0:
                    num_t_top_1_correct += 1
                if not top_3_s_correct_marked:
                    num_t_top_3_correct += 1
                    top_3_s_correct_marked = True
            if command_eval == 'y':
                if i == 0:
                    num_f_top_1_correct += 1
                if not top_3_f_correct_marked:
                    num_f_top_3_correct += 1
                    top_3_f_correct_marked = True

    metrics = {}
    acc_f_1 = num_f_top_1_correct / len(sample_ids)
    acc_f_3 = num_f_top_3_correct / len(sample_ids)
    acc_t_1 = num_t_top_1_correct / len(sample_ids)
    acc_t_3 = num_t_top_3_correct / len(sample_ids)
    metrics['acc_f'] = [acc_f_1, acc_f_3]
    metrics['acc_t'] = [acc_t_1, acc_t_3]

    if verbose:
        print('{} examples evaluated'.format(len(sample_ids)))
        print('Top 1 Command Acc = {:.3f}'.format(acc_f_1))
        print('Top 3 Command Acc = {:.3f}'.format(acc_f_3))
        print('Top 1 Template Acc = {:.3f}'.format(acc_t_1))
        print('Top 3 Template Acc = {:.3f}'.format(acc_t_3))
    return metrics
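
The returned dictionary has two keys: 'acc_f' (full command accuracy) and 'acc_t' (template accuracy), each holding the top-1 and top-3 values. A minimal usage sketch, assuming the grouping and prediction-loading helpers that appear in the later examples; the variable `dataset` and the flag values are placeholders.

# Hypothetical driver code; reuses helpers shown in the other examples
# (module-level imports such as os, data_utils, graph_utils are assumed).
grouped_dataset = data_utils.group_parallel_data(
    dataset, use_bucket=True, tokenizer_selector='nl')
model_subdir, decode_sig = graph_utils.get_decode_signature(FLAGS)
model_dir = os.path.join(FLAGS.model_root_dir, model_subdir)
prediction_list = load_predictions(model_dir, decode_sig, top_k=3)

metrics = get_manual_evaluation_metrics(
    grouped_dataset, prediction_list, FLAGS,
    num_examples=100, interactive=False)
acc_f_1, acc_f_3 = metrics['acc_f']   # top-1 / top-3 command accuracy
acc_t_1, acc_t_3 = metrics['acc_t']   # top-1 / top-3 template accuracy
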
Example #2
def get_automatic_evaluation_metrics(grouped_dataset,
                                     prediction_list,
                                     vocabs,
                                     FLAGS,
                                     top_k,
                                     num_samples=-1,
                                     verbose=False):
    cmd_parser = data_tools.bash_parser
    rev_sc_vocab = vocabs.rev_sc_vocab if vocabs is not None else None

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = \
        load_cached_evaluations(
            os.path.join(FLAGS.data_dir, 'manual_judgements'))

    # Evaluate on a fixed subset of examples
    if num_samples > 0:
        # Get FIXED dev set samples
        random.seed(100)
        example_ids = list(range(len(grouped_dataset)))
        random.shuffle(example_ids)
        sample_ids = example_ids[:num_samples]
        grouped_dataset = [grouped_dataset[i] for i in sample_ids]
        prediction_list = [prediction_list[i] for i in sample_ids]

    num_eval = 0
    top_k_temp_correct = np.zeros([len(grouped_dataset), top_k])
    top_k_str_correct = np.zeros([len(grouped_dataset), top_k])
    top_k_cms = np.zeros([len(grouped_dataset), top_k])
    top_k_bleu = np.zeros([len(grouped_dataset), top_k])

    command_gt_asts_list, pred_ast_list = [], []

    for data_id in xrange(len(grouped_dataset)):
        _, data_group = grouped_dataset[data_id]
        sc_str = data_group[0].sc_txt.strip()
        sc_key = get_example_nl_key(sc_str)
        if vocabs is not None:
            sc_tokens = [rev_sc_vocab[i] for i in data_group[0].sc_ids]
            if FLAGS.channel == 'char':
                sc_features = ''.join(sc_tokens)
                sc_features = sc_features.replace(constants._SPACE, ' ')
            else:
                sc_features = ' '.join(sc_tokens)
        command_gts = [dp.tg_txt.strip() for dp in data_group]
        command_gt_asts = [cmd_parser(cmd) for cmd in command_gts]
        command_gt_asts_list.append(command_gt_asts)
        template_gts = [
            data_tools.cmd2template(cmd, loose_constraints=True)
            for cmd in command_gts
        ]
        template_gt_asts = [cmd_parser(temp) for temp in template_gts]
        if verbose:
            print("Example {}".format(data_id))
            print("Original Source: {}".format(sc_str.encode('utf-8')))
            if vocabs is not None:
                print("Source: {}".format(
                    [x.encode('utf-8') for x in sc_features]))
            for j, command_gt in enumerate(command_gts):
                print("GT Target {}: {}".format(
                    j + 1,
                    command_gt.strip().encode('utf-8')))
        num_eval += 1
        predictions = prediction_list[data_id]
        for i in xrange(len(predictions)):
            pred_cmd = predictions[i]
            pred_ast = cmd_parser(pred_cmd)
            if i == 0:
                pred_ast_list.append(pred_ast)
            pred_temp = data_tools.cmd2template(pred_cmd,
                                                loose_constraints=True)
            # A) Exact match with ground truths & existing judgements
            command_example_key = '{}<NL_PREDICTION>{}'.format(
                sc_key, pred_cmd)
            structure_example_key = '{}<NL_PREDICTION>{}'.format(
                sc_key, pred_temp)
            # B) Match ignoring flag orders
            temp_match = tree_dist.one_match(template_gt_asts,
                                             pred_ast,
                                             ignore_arg_value=True)
            str_match = tree_dist.one_match(command_gt_asts,
                                            pred_ast,
                                            ignore_arg_value=False)
            if command_eval_cache and command_example_key in command_eval_cache:
                str_match = normalize_judgement(
                    command_eval_cache[command_example_key]) == 'y'
            if structure_eval_cache and structure_example_key in structure_eval_cache:
                temp_match = normalize_judgement(
                    structure_eval_cache[structure_example_key]) == 'y'
            if temp_match:
                top_k_temp_correct[data_id, i] = 1
            if str_match:
                top_k_str_correct[data_id, i] = 1
            cms = token_based.command_match_score(command_gt_asts, pred_ast)
            # if pred_cmd.strip():
            #     bleu = token_based.sentence_bleu_score(command_gt_asts, pred_ast)
            # else:
            #     bleu = 0
            # NLTK's sentence_bleu expects tokenized references and a
            # tokenized hypothesis, so compare whitespace tokens here.
            bleu = nltk.translate.bleu_score.sentence_bleu(
                [gt.split() for gt in command_gts], pred_cmd.split())
            top_k_cms[data_id, i] = cms
            top_k_bleu[data_id, i] = bleu
            if verbose:
                print("Prediction {}: {} ({}, {})".format(
                    i + 1, pred_cmd, cms, bleu))
        if verbose:
            print()

    bleu = token_based.corpus_bleu_score(command_gt_asts_list, pred_ast_list)

    top_temp_acc = [-1 for _ in [1, 3, 5, 10]]
    top_cmd_acc = [-1 for _ in [1, 3, 5, 10]]
    top_cms = [-1 for _ in [1, 3, 5, 10]]
    top_bleu = [-1 for _ in [1, 3, 5, 10]]
    top_temp_acc[0] = top_k_temp_correct[:, 0].mean()
    top_cmd_acc[0] = top_k_str_correct[:, 0].mean()
    top_cms[0] = top_k_cms[:, 0].mean()
    top_bleu[0] = top_k_bleu[:, 0].mean()
    print("{} examples evaluated".format(num_eval))
    print("Top 1 Template Acc = %.3f" % top_temp_acc[0])
    print("Top 1 Command Acc = %.3f" % top_cmd_acc[0])
    print("Average top 1 Template Match Score = %.3f" % top_cms[0])
    print("Average top 1 BLEU Score = %.3f" % top_bleu[0])
    if len(predictions) > 1:
        top_temp_acc[1] = np.max(top_k_temp_correct[:, :3], 1).mean()
        top_cmd_acc[1] = np.max(top_k_str_correct[:, :3], 1).mean()
        top_cms[1] = np.max(top_k_cms[:, :3], 1).mean()
        top_bleu[1] = np.max(top_k_bleu[:, :3], 1).mean()
        print("Top 3 Template Acc = %.3f" % top_temp_acc[1])
        print("Top 3 Command Acc = %.3f" % top_cmd_acc[1])
        print("Average top 3 Template Match Score = %.3f" % top_cms[1])
        print("Average top 3 BLEU Score = %.3f" % top_bleu[1])
    if len(predictions) > 3:
        top_temp_acc[2] = np.max(top_k_temp_correct[:, :5], 1).mean()
        top_cmd_acc[2] = np.max(top_k_str_correct[:, :5], 1).mean()
        top_cms[2] = np.max(top_k_cms[:, :5], 1).mean()
        top_bleu[2] = np.max(top_k_bleu[:, :5], 1).mean()
        print("Top 5 Template Acc = %.3f" % top_temp_acc[2])
        print("Top 5 Command Acc = %.3f" % top_cmd_acc[2])
        print("Average top 5 Template Match Score = %.3f" % top_cms[2])
        print("Average top 5 BLEU Score = %.3f" % top_bleu[2])
    if len(predictions) > 5:
        top_temp_acc[3] = np.max(top_k_temp_correct[:, :10], 1).mean()
        top_cmd_acc[3] = np.max(top_k_str_correct[:, :10], 1).mean()
        top_cms[3] = np.max(top_k_cms[:, :10], 1).mean()
        top_bleu[3] = np.max(top_k_bleu[:, :10], 1).mean()
        print("Top 10 Template Acc = %.3f" % top_temp_acc[3])
        print("Top 10 Command Acc = %.3f" % top_cmd_acc[3])
        print("Average top 10 Template Match Score = %.3f" % top_cms[3])
        print("Average top 10 BLEU Score = %.3f" % top_bleu[3])
    print('Corpus BLEU = %.3f' % bleu)
    print()

    metrics = {}
    metrics['acc_f'] = top_cmd_acc
    metrics['acc_t'] = top_temp_acc
    metrics['cms'] = top_cms
    metrics['bleu'] = top_bleu

    return metrics
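
The top-k numbers above are computed by taking, for each example, the maximum of the 0/1 match indicators over the first k ranked predictions and then averaging over examples. A small self-contained illustration of that aggregation with made-up values:

import numpy as np

# Rows are examples, columns are ranked predictions (1 = match, 0 = no match).
top_k_correct = np.array([[0, 1, 0],
                          [1, 0, 0],
                          [0, 0, 0]])
top_1_acc = top_k_correct[:, 0].mean()              # 1/3: only example 2 matches at rank 1
top_3_acc = np.max(top_k_correct[:, :3], 1).mean()  # 2/3: examples 1 and 2 match within top 3
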
Example #3
def gen_manual_evaluation_csv_single_model(dataset, FLAGS):
    """
    Generate a .csv spreadsheet for manual evaluation on dev/test set
    examples for a specific model.
    """
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)

    # Load model predictions
    model_subdir, decode_sig = graph_utils.get_decode_signature(FLAGS)
    model_dir = os.path.join(FLAGS.model_root_dir, model_subdir)
    prediction_list = load_predictions(model_dir, decode_sig, top_k=3)
    if len(grouped_dataset) != len(prediction_list):
        raise ValueError("ground truth list and prediction list length must "
                         "be equal: {} vs. {}".format(len(grouped_dataset),
                                                      len(prediction_list)))

    # Load additional ground truths
    template_translations, command_translations = load_cached_correct_translations(
        FLAGS.data_dir)

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = load_cached_evaluations(
        os.path.join(FLAGS.data_dir, 'manual_judgements'))

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash else data_tools.paren_parser

    output_path = os.path.join(model_dir, 'manual.evaluations.single.model')
    with open(output_path, 'w') as o_f:
        # write spreadsheet header
        o_f.write('id,description,command,correct template,correct command\n')
        for example_id in range(len(grouped_dataset)):
            data_group = grouped_dataset[example_id][1]
            sc_txt = data_group[0].sc_txt.strip()
            sc_key = get_example_nl_key(sc_txt)
            command_gts = [dp.tg_txt for dp in data_group]
            command_gts = set(command_gts + command_translations[sc_key])
            command_gt_asts = [
                data_tools.bash_parser(cmd) for cmd in command_gts
            ]
            template_gts = [
                data_tools.cmd2template(cmd, loose_constraints=True)
                for cmd in command_gts
            ]
            template_gts = set(template_gts + template_translations[sc_key])
            template_gt_asts = [
                data_tools.bash_parser(temp) for temp in template_gts
            ]
            predictions = prediction_list[example_id]
            for i in xrange(3):
                if i >= len(predictions):
                    o_f.write(',,,n,n\n')
                    continue
                pred_cmd = predictions[i]
                pred_tree = cmd_parser(pred_cmd)
                pred_temp = data_tools.ast2template(pred_tree,
                                                    loose_constraints=True)
                temp_match = tree_dist.one_match(template_gt_asts,
                                                 pred_tree,
                                                 ignore_arg_value=True)
                str_match = tree_dist.one_match(command_gt_asts,
                                                pred_tree,
                                                ignore_arg_value=False)
                # Match ground truths & existing judgements
                command_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_cmd)
                structure_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_temp)
                command_eval, structure_eval = '', ''
                if str_match:
                    command_eval = 'y'
                    structure_eval = 'y'
                elif temp_match:
                    structure_eval = 'y'
                if command_eval_cache and \
                        command_example_sig in command_eval_cache:
                    command_eval = command_eval_cache[command_example_sig]
                if structure_eval_cache and \
                        structure_example_sig in structure_eval_cache:
                    structure_eval = structure_eval_cache[
                        structure_example_sig]
                if i == 0:
                    o_f.write('{},"{}","{}",{},{}\n'.format(
                        example_id, sc_txt.replace('"', '""'),
                        pred_cmd.replace('"', '""'), structure_eval,
                        command_eval))
                else:
                    o_f.write(',,"{}",{},{}\n'.format(
                        pred_cmd.replace('"', '""'), structure_eval,
                        command_eval))
    print('manual evaluation spreadsheet saved to {}'.format(output_path))
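
The '""' replacements above implement standard CSV quoting by hand. A hedged alternative sketch using the csv module, which applies the same escaping automatically (shown with Python 3 file handling; the field values are the same five columns as the spreadsheet header above):

import csv

# Sketch only: csv.writer doubles embedded quotes and adds surrounding
# quotes for fields that contain commas or quotes.
with open(output_path, 'w', newline='') as o_f:
    writer = csv.writer(o_f)
    writer.writerow(['id', 'description', 'command',
                     'correct template', 'correct command'])
    writer.writerow([example_id, sc_txt, pred_cmd,
                     structure_eval, command_eval])
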
Example #4
def decode_set(sess, model, dataset, top_k, FLAGS, verbose=False):
    """
    Compute top-k predictions on the dev/test dataset and write the predictions
    to disk.

    :param sess: A TensorFlow session.
    :param model: Prediction model object.
    :param dataset: The dev/test dataset to decode.
    :param top_k: Number of top predictions to compute.
    :param FLAGS: Training/testing hyperparameter settings.
    :param verbose: If set, also print decoding results to screen.
    """
    nl2bash = FLAGS.dataset.startswith('bash') and not FLAGS.explain

    tokenizer_selector = 'cm' if FLAGS.explain else 'nl'
    grouped_dataset = data_utils.group_parallel_data(
        dataset, tokenizer_selector=tokenizer_selector)
    vocabs = data_utils.load_vocabulary(FLAGS)
    rev_sc_vocab = vocabs.rev_sc_vocab

    ts = datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d-%H%M%S')
    pred_file_path = os.path.join(model.model_dir, 'predictions.{}.{}'.format(
        model.decode_sig, ts))
    pred_file = open(pred_file_path, 'w')
    eval_file_path = os.path.join(model.model_dir, 'predictions.{}.{}.csv'.format(
        model.decode_sig, ts))
    eval_file = open(eval_file_path, 'w')
    eval_file.write('example_id, description, ground_truth, prediction, ' +
                    'correct template, correct command\n')
    for example_id in xrange(len(grouped_dataset)):
        key, data_group = grouped_dataset[example_id]

        sc_txt = data_group[0].sc_txt.strip()
        sc_tokens = [rev_sc_vocab[i] for i in data_group[0].sc_ids]
        if FLAGS.channel == 'char':
            sc_temp = ''.join(sc_tokens)
            sc_temp = sc_temp.replace(constants._SPACE, ' ')
        else:
            sc_temp = ' '.join(sc_tokens)
        tg_txts = [dp.tg_txt for dp in data_group]
        tg_asts = [data_tools.bash_parser(tg_txt) for tg_txt in tg_txts]
        if verbose:
            print('\nExample {}:'.format(example_id))
            print('Original Source: {}'.format(sc_txt.encode('utf-8')))
            print('Source: {}'.format(sc_temp.encode('utf-8')))
            for j in xrange(len(data_group)):
                print('GT Target {}: {}'.format(j+1, data_group[j].tg_txt.encode('utf-8')))

        if FLAGS.fill_argument_slots:
            slot_filling_classifier = get_slot_filling_classifer(FLAGS)
            batch_outputs, sequence_logits = translate_fun(data_group, sess, model,
                vocabs, FLAGS, slot_filling_classifier=slot_filling_classifier)
        else:
            batch_outputs, sequence_logits = translate_fun(data_group, sess, model,
                vocabs, FLAGS)
        if FLAGS.tg_char:
            batch_outputs, batch_char_outputs = batch_outputs

        eval_row = '{},"{}",'.format(example_id, sc_txt.replace('"', '""'))
        if batch_outputs:
            if FLAGS.token_decoding_algorithm == 'greedy':
                tree, pred_cmd = batch_outputs[0]
                if nl2bash:
                    pred_cmd = data_tools.ast2command(
                        tree, loose_constraints=True)
                score = sequence_logits[0]
                if verbose:
                    print('Prediction: {} ({})'.format(pred_cmd, score))
                pred_file.write('{}\n'.format(pred_cmd))
            elif FLAGS.token_decoding_algorithm == 'beam_search':
                top_k_predictions = batch_outputs[0]
                if FLAGS.tg_char:
                    top_k_char_predictions = batch_char_outputs[0]
                top_k_scores = sequence_logits[0]
                num_preds = min(FLAGS.beam_size, top_k, len(top_k_predictions))
                for j in xrange(num_preds):
                    if j > 0:
                        eval_row = ',,'
                    if j < len(tg_txts):
                        eval_row += '"{}",'.format(tg_txts[j].strip().replace('"', '""'))
                    else:
                        eval_row += ','
                    top_k_pred_tree, top_k_pred_cmd = top_k_predictions[j]
                    if nl2bash:
                        pred_cmd = data_tools.ast2command(
                            top_k_pred_tree, loose_constraints=True)
                    else:
                        pred_cmd = top_k_pred_cmd
                    pred_file.write('{}|||'.format(pred_cmd.encode('utf-8')))
                    eval_row += '"{}",'.format(pred_cmd.replace('"', '""'))
                    temp_match = tree_dist.one_match(
                        tg_asts, top_k_pred_tree, ignore_arg_value=True)
                    str_match = tree_dist.one_match(
                        tg_asts, top_k_pred_tree, ignore_arg_value=False)
                    if temp_match:
                        eval_row += 'y,'
                    if str_match:
                        eval_row += 'y'
                    eval_file.write('{}\n'.format(eval_row.encode('utf-8')))
                    if verbose:
                        print('Prediction {}: {} ({})'.format(
                            j+1, pred_cmd.encode('utf-8'), top_k_scores[j]))
                        if FLAGS.tg_char:
                            print('Character-based prediction {}: {}'.format(
                                j+1, top_k_char_predictions[j].encode('utf-8')))
                pred_file.write('\n')
        else:
            print(APOLOGY_MSG)
            pred_file.write('\n')
            eval_file.write('{}\n'.format(eval_row))
            eval_file.write('\n')
            eval_file.write('\n')
    pred_file.close()
    eval_file.close()
    shutil.copyfile(pred_file_path, os.path.join(model.model_dir,
        'predictions.{}.latest'.format(model.decode_sig)))
    shutil.copyfile(eval_file_path, os.path.join(model.model_dir,
        'predictions.{}.latest.csv'.format(model.decode_sig)))
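
In beam-search mode each line of the prediction file holds the top-k commands of one example separated by '|||'. The load_predictions helper used by the other examples is not shown here; the following is only a hedged sketch of a matching reader, and the actual implementation may differ.

import os

def load_predictions(model_dir, decode_sig, top_k):
    # Hypothetical reader for the '|||'-delimited files written by decode_set.
    prediction_path = os.path.join(
        model_dir, 'predictions.{}.latest'.format(decode_sig))
    prediction_list = []
    with open(prediction_path) as f:
        for line in f:
            predictions = [p for p in line.rstrip('\n').split('|||') if p]
            prediction_list.append(predictions[:top_k])
    return prediction_list
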
Example #5
def tabulate_example_predictions(dataset, FLAGS, num_examples=100):
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)

    model_names, model_predictions = load_all_model_predictions(
        grouped_dataset, FLAGS, top_k=1)

    # Get FIXED dev set samples
    random.seed(100)
    example_ids = list(range(len(grouped_dataset)))
    random.shuffle(example_ids)
    sample_ids = example_ids[:num_examples]

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = \
        load_cached_evaluations(
            os.path.join(FLAGS.data_dir, 'manual_judgements'))

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash \
        else data_tools.paren_parser

    model_name_pt = {
        'token-seq2seq': 'T-Seq2Seq',
        'tellina': 'Tellina',
        'token-copynet': 'T-CopyNet',
        'partial.token-seq2seq': 'ST-Seq2Seq',
        'partial.token-copynet': 'ST-CopyNet',
        'char-seq2seq': 'C-Seq2Seq',
        'char-copynet': 'C-CopyNet'
    }

    for example_id in sample_ids:
        print('Example {}'.format(example_id))
        data_group = grouped_dataset[example_id][1]
        sc_txt = data_group[0].sc_txt.strip()
        sc_key = get_example_nl_key(sc_txt)
        command_gts = [dp.tg_txt for dp in data_group]
        command_gt_asts = [data_tools.bash_parser(gt) for gt in command_gts]
        output_strs = {}
        for model_id, model_name in enumerate(model_names):
            predictions = model_predictions[model_id][example_id]
            for i in xrange(min(3, len(predictions))):
                pred_cmd = predictions[i]
                pred_tree = cmd_parser(pred_cmd)
                pred_temp = data_tools.ast2template(pred_tree,
                                                    loose_constraints=True)
                temp_match = tree_dist.one_match(command_gt_asts,
                                                 pred_tree,
                                                 ignore_arg_value=True)
                str_match = tree_dist.one_match(command_gt_asts,
                                                pred_tree,
                                                ignore_arg_value=False)

                output_str = '& \\<{}> & {}'.format(
                    pred_cmd.replace('__SP__', '').replace('_', '\\_').replace(
                        '$', '\\$').replace('%',
                                            '\\%').replace('{{}}', '\\ttcbs'),
                    model_name_pt[model_name])

                command_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_cmd)
                structure_example_sig = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_temp)
                command_eval, structure_eval = '', ''
                if str_match:
                    command_eval = 'y'
                    structure_eval = 'y'
                elif temp_match:
                    structure_eval = 'y'
                if command_eval_cache and \
                        command_example_sig in command_eval_cache:
                    command_eval = command_eval_cache[command_example_sig]
                if structure_eval_cache and \
                        structure_example_sig in structure_eval_cache:
                    structure_eval = structure_eval_cache[
                        structure_example_sig]
                output_str += ', {},{} \\\\'.format(structure_eval,
                                                    command_eval)
            output_strs[model_name] = output_str
        for model_name in [
                'char-seq2seq', 'char-copynet', 'token-seq2seq',
                'token-copynet', 'partial.token-seq2seq',
                'partial.token-copynet', 'tellina'
        ]:
            if model_name == 'char-seq2seq':
                print('\\multirow{{7}}{{*}}{{\\specialcell{{{}}}}} '.format(
                    sc_txt) + output_strs[model_name])
            else:
                print(output_strs[model_name])
        output_str = '& \<{}> & Human \\\\'.format(command_gts[0].replace(
            '__SP__', '').replace('_', '\\_').replace('$', '\\$').replace(
                '%', '\\%').replace('{{}}', '\\ttcbs'))
        print(output_str)
        print()
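
The same chain of replace calls is applied to both the model predictions and the human ground truth above; a small helper (hypothetical, not part of the original module) would keep the LaTeX escaping in one place:

def latex_escape_command(cmd):
    """Escape a command string for the LaTeX table (same substitutions as above)."""
    return (cmd.replace('__SP__', '')
               .replace('_', '\\_')
               .replace('$', '\\$')
               .replace('%', '\\%')
               .replace('{{}}', '\\ttcbs'))
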
Example #6
def print_error_analysis_csv(grouped_dataset,
                             prediction_list,
                             FLAGS,
                             cached_evaluation_results=None,
                             group_by_utility=False,
                             error_predictions_only=True):
    """
    Convert dev/test set examples to csv format so that human annotators can
    enter their judgements more easily.

    :param grouped_dataset: dev/test set grouped by natural language.
    :param prediction_list: model predictions.
    :param FLAGS: experiment hyperparameters.
    :param cached_evaluation_results: cached evaluation results from previous
        rounds.
    :param group_by_utility: if set, group the error examples by the utilities
        used in the ground truth.
    :param error_predictions_only: if set, only record examples whose top
        prediction contains a grammar or argument error.
    """
    def mark_example(error_list, example, gt_utility=None):
        if gt_utility:
            error_list[gt_utility].append(example)
        else:
            error_list.append(example)

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash \
        else data_tools.paren_parser
    if group_by_utility:
        utility_index = {}
        for line in bash.utility_stats.split('\n'):
            ind, utility, _, _ = line.split(',')
            utility_index[utility] = ind

    grammar_errors = collections.defaultdict(list) if group_by_utility else []
    argument_errors = collections.defaultdict(list) if group_by_utility else []
    example_id = 0
    for nl_temp, data_group in grouped_dataset:
        sc_txt = data_group[0].sc_txt.strip()
        sc_temp = get_example_nl_key(sc_txt)
        tg_strs = [dp.tg_txt for dp in data_group]
        gt_trees = [cmd_parser(cm_str) for cm_str in tg_strs]
        if group_by_utility:
            gt_utilities = functools.reduce(
                lambda x, y: x | y,
                [data_tools.get_utilities(gt) for gt in gt_trees])
            gt_utility = sorted(list(gt_utilities),
                                key=lambda x: int(utility_index[x]))[-1]
        else:
            gt_utility = None
        predictions = prediction_list[example_id]
        example_id += 1
        example = []
        grammar_error, argument_error = False, False
        for i in xrange(min(3, len(predictions))):
            if i == 0:
                output_str = '{},"{}",'.format(example_id,
                                               sc_txt.replace('"', '""'))
            else:
                output_str = ',,'
            pred_cmd = predictions[i]
            tree = cmd_parser(pred_cmd)

            # evaluation ignoring flag orders
            temp_match = tree_dist.one_match(gt_trees,
                                             tree,
                                             ignore_arg_value=True)
            str_match = tree_dist.one_match(gt_trees,
                                            tree,
                                            ignore_arg_value=False)
            if i < len(tg_strs):
                output_str += '"{}",'.format(tg_strs[i].strip().replace(
                    '"', '""'))
            else:
                output_str += ','
            output_str += '"{}",'.format(pred_cmd.replace('"', '""'))
            if not str_match:
                if temp_match:
                    if i == 0:
                        argument_error = True
                        grammar_error = True
                else:
                    if i == 0:
                        grammar_error = True

            example_sig = '{}<NL_PREDICTION>{}'.format(sc_temp, pred_cmd)
            if cached_evaluation_results and \
                    example_sig in cached_evaluation_results:
                output_str += cached_evaluation_results[example_sig]
            else:
                if str_match:
                    output_str += 'y,y'
                elif temp_match:
                    output_str += 'y,'
            example.append(output_str)
        if error_predictions_only:
            if grammar_error:
                mark_example(grammar_errors, example, gt_utility)
            elif argument_error:
                mark_example(argument_errors, example, gt_utility)
        else:
            mark_example(grammar_errors, example, gt_utility)

    return grammar_errors, argument_errors
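
Despite its name, print_error_analysis_csv only collects the spreadsheet rows and returns them; writing the file is left to the caller. A hypothetical sketch of dumping the collected error examples, assuming group_by_utility is off (the output file name is an assumption):

grammar_errors, argument_errors = print_error_analysis_csv(
    grouped_dataset, prediction_list, FLAGS, group_by_utility=False)

with open('error.analysis.csv', 'w') as o_f:
    o_f.write('example_id,description,ground_truth,prediction,'
              'correct template,correct command\n')
    for example in grammar_errors + argument_errors:
        for row in example:
            o_f.write(row + '\n')
        o_f.write('\n')   # blank line between examples
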
Example #7
def gen_manual_evaluation_csv(dataset, FLAGS, num_examples=100):
    """
    Generate a .csv spreadsheet for manual evaluation on a fixed set of test/dev
    examples; the predictions of the different models are listed side by side.
    """
    # Group dataset
    tokenizer_selector = "cm" if FLAGS.explain else "nl"
    grouped_dataset = data_utils.group_parallel_data(
        dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)

    model_names, model_predictions = load_all_model_predictions(
        grouped_dataset, FLAGS, top_k=3)

    # Get FIXED dev set samples
    random.seed(100)
    example_ids = list(range(len(grouped_dataset)))
    random.shuffle(example_ids)
    sample_ids = example_ids[num_examples:num_examples + 100]

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = \
        load_cached_evaluations(
            os.path.join(FLAGS.data_dir, 'manual_judgements'))

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash \
        else data_tools.paren_parser

    output_path = os.path.join(FLAGS.data_dir, 'manual.evaluations.csv')
    with open(output_path, 'w') as o_f:
        o_f.write('example_id, description, ground_truth, model, prediction, '
                  'correct template, correct command\n')
        for example_id in sample_ids:
            data_group = grouped_dataset[example_id][1]
            sc_txt = data_group[0].sc_txt.strip()
            sc_key = get_example_nl_key(sc_txt)
            command_gts = [dp.tg_txt for dp in data_group]
            command_gt_asts = [
                data_tools.bash_parser(gt) for gt in command_gts
            ]
            for model_id, model_name in enumerate(model_names):
                predictions = model_predictions[model_id][example_id]
                for i in xrange(min(3, len(predictions))):
                    if model_id == 0 and i == 0:
                        output_str = '{},"{}",'.format(
                            example_id, sc_txt.replace('"', '""'))
                    else:
                        output_str = ',,'
                    pred_cmd = predictions[i]
                    pred_tree = cmd_parser(pred_cmd)
                    pred_temp = data_tools.ast2template(pred_tree,
                                                        loose_constraints=True)
                    temp_match = tree_dist.one_match(command_gt_asts,
                                                     pred_tree,
                                                     ignore_arg_value=True)
                    str_match = tree_dist.one_match(command_gt_asts,
                                                    pred_tree,
                                                    ignore_arg_value=False)
                    if (model_id * min(3, len(predictions)) +
                            i) < len(command_gts):
                        output_str += '"{}",'.format(
                            command_gts[model_id * min(3, len(predictions)) +
                                        i].strip().replace('"', '""'))
                    else:
                        output_str += ','
                    output_str += '{},"{}",'.format(
                        model_name, pred_cmd.replace('"', '""'))

                    command_example_sig = '{}<NL_PREDICTION>{}'.format(
                        sc_key, pred_cmd)
                    structure_example_sig = '{}<NL_PREDICTION>{}'.format(
                        sc_key, pred_temp)
                    command_eval, structure_eval = '', ''
                    if str_match:
                        command_eval = 'y'
                        structure_eval = 'y'
                    elif temp_match:
                        structure_eval = 'y'
                    if command_eval_cache and \
                            command_example_sig in command_eval_cache:
                        command_eval = command_eval_cache[command_example_sig]
                    if structure_eval_cache and \
                            structure_example_sig in structure_eval_cache:
                        structure_eval = structure_eval_cache[
                            structure_example_sig]
                    output_str += '{},{}'.format(structure_eval, command_eval)
                    o_f.write('{}\n'.format(output_str))

    print('Manual evaluation results saved to {}'.format(output_path))
Example #8
def gen_evaluation_table(dataset, FLAGS, num_examples=-1, interactive=True):
    """
    Generate structure and full command accuracy results on a fixed set of dev/test
        set samples, with the results of multiple models tabulated in the same table.
    :param interactive:
        - If set, prompt the user to enter a judgement if a prediction does not
            match any of the ground truths and the correctness of the prediction
            has not been pre-determined;
          Otherwise, count all predictions that do not match any of the
          ground truths as wrong.
    """
    def add_judgement(data_dir,
                      nl,
                      command,
                      correct_template='',
                      correct_command=''):
        """
        Append a new judgement
        """
        data_dir = os.path.join(data_dir, 'manual_judgements')
        manual_judgement_path = os.path.join(data_dir,
                                             'manual.evaluations.additional')
        if not os.path.exists(manual_judgement_path):
            with open(manual_judgement_path, 'w') as o_f:
                o_f.write(
                    'description,prediction,template,correct template,correct command\n'
                )
        with open(manual_judgement_path, 'a') as o_f:
            temp = data_tools.cmd2template(command, loose_constraints=True)
            if not correct_template:
                correct_template = 'n'
            if not correct_command:
                correct_command = 'n'
            o_f.write('"{}","{}","{}","{}","{}"\n'.format(
                nl.replace('"', '""'), command.replace('"', '""'),
                temp.replace('"', '""'), correct_template.replace('"', '""'),
                correct_command.replace('"', '""')))
        print('new judgement added to {}'.format(manual_judgement_path))

    # Group dataset
    grouped_dataset = data_utils.group_parallel_data(dataset, use_bucket=True)

    if FLAGS.test:
        model_names, model_predictions = load_all_model_predictions(
            grouped_dataset,
            FLAGS,
            top_k=3,
            tellina=True,
            partial_token_copynet=True,
            token_seq2seq=False,
            token_copynet=False,
            char_seq2seq=False,
            char_copynet=False,
            partial_token_seq2seq=False)
    else:
        model_names, model_predictions = load_all_model_predictions(
            grouped_dataset, FLAGS, top_k=3)

    # Get FIXED dev set samples
    random.seed(100)
    example_ids = list(range(len(grouped_dataset)))
    random.shuffle(example_ids)
    if num_examples > 0:
        sample_ids = example_ids[:num_examples]
    else:
        sample_ids = example_ids

    # Load cached evaluation results
    structure_eval_cache, command_eval_cache = \
        load_cached_evaluations(
            os.path.join(FLAGS.data_dir, 'manual_judgements'))

    eval_bash = FLAGS.dataset.startswith("bash")
    cmd_parser = data_tools.bash_parser if eval_bash \
        else data_tools.paren_parser

    # Interactive manual evaluation interface
    num_s_correct = collections.defaultdict(int)
    num_f_correct = collections.defaultdict(int)
    num_s_top_3_correct = collections.defaultdict(int)
    num_f_top_3_correct = collections.defaultdict(int)
    for exam_id, example_id in enumerate(sample_ids):
        data_group = grouped_dataset[example_id][1]
        sc_txt = data_group[0].sc_txt.strip()
        sc_key = get_example_nl_key(sc_txt)
        command_gts = [dp.tg_txt for dp in data_group]
        command_gt_asts = [data_tools.bash_parser(gt) for gt in command_gts]
        for model_id, model_name in enumerate(model_names):
            predictions = model_predictions[model_id][example_id]
            top_3_s_correct_marked = False
            top_3_f_correct_marked = False
            for i in xrange(min(3, len(predictions))):
                pred_cmd = predictions[i]
                pred_ast = cmd_parser(pred_cmd)
                pred_temp = data_tools.ast2template(pred_ast,
                                                    loose_constraints=True)
                temp_match = tree_dist.one_match(command_gt_asts,
                                                 pred_ast,
                                                 ignore_arg_value=True)
                str_match = tree_dist.one_match(command_gt_asts,
                                                pred_ast,
                                                ignore_arg_value=False)
                # Match ground truths & existing judgements
                command_example_key = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_cmd)
                structure_example_key = '{}<NL_PREDICTION>{}'.format(
                    sc_key, pred_temp)
                command_eval, structure_eval = '', ''
                if str_match:
                    command_eval = 'y'
                    structure_eval = 'y'
                elif temp_match:
                    structure_eval = 'y'
                if command_eval_cache and command_example_key in command_eval_cache:
                    command_eval = command_eval_cache[command_example_key]
                if structure_eval_cache and structure_example_key in structure_eval_cache:
                    structure_eval = structure_eval_cache[
                        structure_example_key]
                # Prompt for new judgements
                if command_eval != 'y':
                    if structure_eval == 'y':
                        if not command_eval and interactive:
                            print('#{}. {}'.format(exam_id, sc_txt))
                            for j, gt in enumerate(command_gts):
                                print('- GT{}: {}'.format(j, gt))
                            print('> {}'.format(pred_cmd))
                            command_eval = input(
                                'CORRECT COMMAND? [y/reason] ')
                            add_judgement(FLAGS.data_dir, sc_key, pred_cmd,
                                          structure_eval, command_eval)
                            print()
                    else:
                        if not structure_eval and interactive:
                            print('#{}. {}'.format(exam_id, sc_txt))
                            for j, gt in enumerate(command_gts):
                                print('- GT{}: {}'.format(j, gt))
                            print('> {}'.format(pred_cmd))
                            structure_eval = input(
                                'CORRECT STRUCTURE? [y/reason] ')
                            if structure_eval == 'y':
                                command_eval = input(
                                    'CORRECT COMMAND? [y/reason] ')
                            add_judgement(FLAGS.data_dir, sc_key, pred_cmd,
                                          structure_eval, command_eval)
                            print()
                    structure_eval_cache[
                        structure_example_key] = structure_eval
                    command_eval_cache[command_example_key] = command_eval
                if structure_eval == 'y':
                    if i == 0:
                        num_s_correct[model_name] += 1
                    if not top_3_s_correct_marked:
                        num_s_top_3_correct[model_name] += 1
                        top_3_s_correct_marked = True
                if command_eval == 'y':
                    if i == 0:
                        num_f_correct[model_name] += 1
                    if not top_3_f_correct_marked:
                        num_f_top_3_correct[model_name] += 1
                        top_3_f_correct_marked = True
    metrics_names = ['Acc_F_1', 'Acc_F_3', 'Acc_T_1', 'Acc_T_3']
    model_metrics = {}
    for model_name in model_names:
        metrics = [
            num_f_correct[model_name] / len(sample_ids),
            num_f_top_3_correct[model_name] / len(sample_ids),
            num_s_correct[model_name] / len(sample_ids),
            num_s_top_3_correct[model_name] / len(sample_ids)
        ]
        model_metrics[model_name] = metrics
    print_table(model_names, metrics_names, model_metrics)
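
print_table is not defined in these excerpts; a minimal stand-in sketch, assuming it simply prints one row of metrics per model (the real helper may format its output differently):

def print_table(model_names, metrics_names, model_metrics):
    # Hypothetical tabulation helper for the accuracy table above.
    print('{:<24}'.format('Model') +
          ''.join('{:>10}'.format(name) for name in metrics_names))
    for model_name in model_names:
        print('{:<24}'.format(model_name) +
              ''.join('{:>10.3f}'.format(v) for v in model_metrics[model_name]))
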
Example #9
def gen_eval_sheet(model, dataset, FLAGS, output_path):
    """
    :param model:
    :param dataset:
    :param FLAGS:
    :param output_path:

    :return:
    """
    with open(output_path, 'w') as o_f:
        # print evaluation form header
        o_f.write('example_id, description, ground_truth, prediction, ' +
                  'correct template, correct command\n')

        example_id = 0
        eval_bash = FLAGS.dataset.startswith("bash")
        cmd_parser = data_tools.bash_parser if eval_bash \
            else data_tools.paren_parser
        tokenizer_selector = "cm" if FLAGS.explain else "nl"
        grouped_dataset = data_utils.group_data(
            dataset, use_bucket=True, tokenizer_selector=tokenizer_selector)

        with DBConnection() as db:
            for nl_temp in grouped_dataset:
                data_group = grouped_dataset[nl_temp]
                nl_str = data_group[0].sc_txt

                tg_strs = [dp.tg_txt for dp in data_group]
                gt_trees = [cmd_parser(cm_str) for cm_str in tg_strs]
                if any(data_tools.is_low_frequency(t) for t in gt_trees):
                    continue
                gt_trees = gt_trees + [
                    cmd_parser(cmd) for cmd in db.get_correct_temps(nl_str)
                ]

                predictions = db.get_top_k_predictions(model, nl_str, k=10)

                example_id += 1

                for i in xrange(min(3, len(predictions))):
                    if i == 0:
                        output_str = '{},{},'.format(example_id,
                                                     nl_temp.strip())
                    else:
                        output_str = ',,'
                    pred_cmd, score = predictions[i]
                    tree = cmd_parser(pred_cmd)

                    # evaluation ignoring flag orders
                    temp_match = tree_dist.one_match(gt_trees,
                                                     tree,
                                                     ignore_arg_value=True)
                    str_match = tree_dist.one_match(gt_trees,
                                                    tree,
                                                    ignore_arg_value=False)
                    if i < len(tg_strs):
                        output_str += '{},'.format(tg_strs[i].strip())
                    else:
                        output_str += ','
                    output_str += '{},'.format(pred_cmd)
                    if temp_match:
                        output_str += 'y,'
                    if str_match:
                        output_str += 'y'

                    o_f.write(output_str + '\n')
Example #10
def eval_set(model_dir, decode_sig, dataset, top_k, FLAGS, verbose=True):
    eval_bash = FLAGS.dataset.startswith("bash") and not FLAGS.explain
    eval_regex = FLAGS.dataset.startswith("regex") and not FLAGS.explain

    cmd_parser = data_tools.bash_parser if eval_bash \
        else data_tools.paren_parser

    use_bucket = False if "knn" in model_dir else True

    tokenizer_selector = 'cm' if FLAGS.explain else 'nl'
    grouped_dataset = data_utils.group_data(
        dataset,
        use_bucket=use_bucket,
        use_temp=(eval_bash and FLAGS.normalized),
        tokenizer_selector=tokenizer_selector)
    top_k_temp_correct = np.zeros([len(grouped_dataset), top_k])
    top_k_str_correct = np.zeros([len(grouped_dataset), top_k])
    if eval_bash:
        top_k_cms = np.zeros([len(grouped_dataset), top_k])

    prediction_list = load_predictions(model_dir, decode_sig, top_k)
    if len(grouped_dataset) != len(prediction_list):
        raise ValueError(
            "ground truth and predictions length must be equal: {} vs. {}".
            format(len(grouped_dataset), len(prediction_list)))

    with DBConnection() as db:
        data_id = 0
        for num_eval in xrange(len(grouped_dataset)):
            key, data_group = grouped_dataset[num_eval]
            sc_str = data_group[0].sc_txt
            tg_strs = [dp.tg_txt.strip() for dp in data_group]
            num_gts = len(tg_strs)
            if eval_bash:
                gt_trees = [cmd_parser(cm_str) for cm_str in tg_strs]
                gts = gt_trees + \
                      [cmd_parser(cmd) for cmd in db.get_correct_temps(sc_str)]
            else:
                gts = tg_strs + db.get_correct_temps(sc_str)

            if verbose:
                print("Example {} ({})".format(num_eval, len(tg_strs)))
                print("Original Source: {}".format(sc_str))
                print("Source: {}".format(key))
                for j in xrange(len(tg_strs)):
                    print("GT Target {}: ".format(j + 1) + tg_strs[j].strip())
            num_eval += (1 if eval_bash else num_gts)

            predictions = prediction_list[data_id]
            for i in xrange(len(predictions)):
                pred_cmd = predictions[i]
                tree = cmd_parser(pred_cmd)
                unprocessed_pred_cmd = regexDFAEquals.unprocess_regex(pred_cmd)
                # evaluation ignoring flag orders
                if eval_bash:
                    temp_match = tree_dist.one_match(gts,
                                                     tree,
                                                     ignore_arg_value=True)
                    str_match = tree_dist.one_match(gts,
                                                    tree,
                                                    ignore_arg_value=False)
                else:
                    if eval_regex:
                        str_match = False
                        for cmd_str in gts:
                            unprocessed_cmd_str = regexDFAEquals.unprocess_regex(
                                cmd_str)
                            if regexDFAEquals.regex_equiv_from_raw(
                                    cmd_str, pred_cmd):
                                str_match = True
                                # Debugging
                                if verbose:
                                    if cmd_str != pred_cmd:
                                        print(
                                            "----------------------------------"
                                        )
                                        print("1) {} ({})".format(
                                            cmd_str, unprocessed_cmd_str))
                                        print("2) {} ({})".format(
                                            pred_cmd, unprocessed_pred_cmd))
                                        print(
                                            "----------------------------------"
                                        )
                                    else:
                                        print(
                                            "----------------------------------"
                                        )
                                        print("i) {} ({})".format(
                                            cmd_str, unprocessed_cmd_str))
                                        print("ii) {} ({})".format(
                                            pred_cmd, unprocessed_pred_cmd))
                                        print(
                                            "----------------------------------"
                                        )
                                break
                            else:
                                if verbose:
                                    print("----------------------------------")
                                    print("A) {} ({})".format(
                                        cmd_str, unprocessed_cmd_str))
                                    print("B) {} ({})".format(
                                        pred_cmd, unprocessed_pred_cmd))
                                    print("----------------------------------")
                    else:
                        str_match = pred_cmd in gts
                    temp_match = str_match

                cms = token_based.command_match_score(gts, tree) \
                    if eval_bash else -1

                if temp_match:
                    top_k_temp_correct[data_id,
                                       i] = 1 if eval_bash else num_gts
                if str_match:
                    top_k_str_correct[data_id, i] = 1 if eval_bash else num_gts
                if eval_bash:
                    top_k_cms[data_id, i] = cms
                    if verbose:
                        print("Prediction {}: {} ({})".format(
                            i + 1, pred_cmd, cms))
                else:
                    if verbose:
                        print("Prediction {}: {}".format(i + 1, pred_cmd))

            if verbose:
                print()

            data_id += 1
    M = {}
    M["top1_temp_ms"] = np.sum(top_k_temp_correct[:, 0]) / num_eval
    M["top3_temp_ms"] = -1
    M["top5_temp_ms"] = -1
    M["top10_temp_ms"] = -1
    M["top1_str_ms"] = np.sum(top_k_str_correct[:, 0]) / num_eval
    M["top3_str_ms"] = -1
    M["top5_str_ms"] = -1
    M["top10_str_ms"] = -1
    if eval_bash:
        M["top1_cms"] = np.sum(top_k_cms[:, 0] / num_eval)
        M["top3_cms"] = -1
        M["top5_cms"] = -1
        M["top10_cms"] = -1

    print("%d examples evaluated" % num_eval)
    print("Top 1 Match (template-only) = %.3f" % M["top1_temp_ms"])
    print("Top 1 Match (whole-string) = %.3f" % M["top1_str_ms"])
    if eval_bash:
        print("Average top 1 Template Match Score = %.3f" % M["top1_cms"])

    if len(predictions) > 1:
        M["top3_temp_ms"] = \
            np.sum(np.max(top_k_temp_correct[:, :3], 1)) / num_eval
        M["top3_str_ms"] = \
            np.sum(np.max(top_k_str_correct[:, :3], 1)) /num_eval
        print("Top 3 Match (template-only) = %.3f" % M["top3_temp_ms"])
        print("Top 3 Match (whole-string) = %.3f" % M["top3_str_ms"])
        if eval_bash:
            M["top3_cms"] = np.sum(np.max(top_k_cms[:, :3], 1)) / num_eval
            print("Average top 3 Template Match Score = %.3f" % M["top3_cms"])
    if len(predictions) > 3:
        M["top5_temp_ms"] = \
            np.sum(np.max(top_k_temp_correct[:, :5], 1)) / num_eval
        M["top5_str_ms"] = \
            np.sum(np.max(top_k_str_correct[:, :5], 1)) /num_eval
        print("Top 5 Match (template-only) = %.3f" % M["top5_temp_ms"])
        print("Top 5 Match (whole-string) = %.3f" % M["top5_str_ms"])
        if eval_bash:
            M["top5_cms"] = np.sum(np.max(top_k_cms[:, :5], 1)) / num_eval
            print("Average top 5 Template Match Score = %.3f" % M["top5_cms"])
    if len(predictions) > 5:
        M["top10_temp_ms"] = \
            np.sum(np.max(top_k_temp_correct[:, :10], 1)) / num_eval
        M["top10_str_ms"] = \
            np.sum(np.max(top_k_str_correct[:, :10], 1)) / num_eval
        print("Top 10 Match (template-only) = %.3f" % M["top10_temp_ms"])
        print("Top 10 Match (whole-string) = %.3f" % M["top10_str_ms"])
        if eval_bash:
            M["top10_cms"] = np.sum(np.max(top_k_cms[:, :10], 1)) / num_eval
            print("Average top 10 Template Match Score = %.3f" %
                  M["top10_cms"])
    print()

    return M
Example #11
def manual_eval(model, dataset, FLAGS, output_dir, num_eval=None):
    num_top1_correct_temp = 0.0
    num_top3_correct_temp = 0.0
    num_top5_correct_temp = 0.0
    num_top10_correct_temp = 0.0
    num_top1_correct = 0.0
    num_top3_correct = 0.0
    num_top5_correct = 0.0
    num_top10_correct = 0.0
    num_evaled = 0

    eval_bash = FLAGS.dataset.startswith("bash")
    use_bucket = False if model == "knn" else True
    tokenizer_selector = 'cm' if FLAGS.explain else 'nl'
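    # Group data points that share the same source text (the NL query, or the
    # command when --explain is set) so that all ground truths for a query are
    # judged together.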
    grouped_dataset = data_utils.group_data(
        dataset,
        use_bucket=use_bucket,
        use_temp=eval_bash,
        tokenizer_selector=tokenizer_selector)

    if num_eval is None:
        num_eval = len(grouped_dataset)

    cmd_parser = data_tools.bash_parser if eval_bash \
        else data_tools.paren_parser

    o_f = open(os.path.join(output_dir, "manual.eval.results"), 'w')

    rejudge = False

    with DBConnection() as db:
        db.create_schema()
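        # Judgements are cached in a local database so that repeated runs do
        # not re-prompt the annotator for the same (query, prediction) pair.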
        while num_evaled < len(grouped_dataset):
            nl_strs, tg_strs, nls, search_historys = grouped_dataset[
                num_evaled]
            nl_str = nl_strs[0].decode('utf-8')

            if num_evaled == num_eval:
                break

            gt_trees = [cmd_parser(cmd) for cmd in tg_strs]

            predictions = db.get_top_k_predictions(model, nl_str, k=10)

            top1_correct_temp, top3_correct_temp, top5_correct_temp, \
                top10_correct_temp = False, False, False, False
            top1_correct, top3_correct, top5_correct, top10_correct = \
                False, False, False, False

            # evaluation ignoring ordering of flags
            print("Example %d (%d)" % (num_evaled + 1, len(tg_strs)))
            o_f.write("Example %d (%d)" % (num_evaled + 1, len(tg_strs)) +
                      "\n")
            print("English: " + nl_str.strip())
            o_f.write("English: " + nl_str.encode('utf-8') + "\n")
            for j in xrange(len(tg_strs)):
                print("GT Command %d: " % (j + 1) + tg_strs[j].strip())
                o_f.write("GT Command %d: " % (j + 1) + tg_strs[j].strip() +
                          "\n")

            pred_id = 0
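            # Manually judge the top-3 predictions; typing REVERSE at a prompt
            # marks the current pass for re-judgement.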
            while pred_id < min(3, len(predictions)):
                pred_cmd, score = predictions[pred_id]
                tree = cmd_parser(pred_cmd)
                print("Prediction {}: {} ({})".format(pred_id + 1, pred_cmd,
                                                      score))
                o_f.write("Prediction {}: {} ({})\n".format(
                    pred_id + 1, pred_cmd, score))
                print()
                pred_temp = data_tools.ast2template(tree,
                                                    loose_constraints=True)
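                # Reuse previously stored judgements for this
                # (query, prediction) pair when available.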
                str_judge = db.get_str_judgement((nl_str, pred_cmd))
                temp_judge = db.get_temp_judgement((nl_str, pred_temp))
                if temp_judge is not None and not rejudge:
                    judgement_str = "y" if temp_judge == 1 \
                        else "n ({})".format(error_types[temp_judge])
                    print("Is the command template correct [y/n]? %s" %
                          judgement_str)
                else:
                    temp_judge = tree_dist.one_match(gt_trees,
                                                     tree,
                                                     rewrite=False,
                                                     ignore_arg_value=True)
                    if not temp_judge:
                        inp = raw_input(
                            "Is the command template correct [y/n]? ")
                        if inp == "REVERSE":
                            rejudge = True
                        else:
                            if inp == "y":
                                temp_judge = True
                                db.add_temp_judgement((nl_str, pred_temp, 1))
                            else:
                                temp_judge = False
                                error_type = raw_input(
                                    "Error type: \n"
                                    "(2) extra utility \n"
                                    "(3) missing utility \n"
                                    "(4) confused utility \n"
                                    "(5) extra flag \n"
                                    "(6) missing flag \n"
                                    "(7) confused flag \n"
                                    "(8) logic error\n"
                                    "(9) count error\n")
                                db.add_temp_judgement(
                                    (nl_str, pred_temp, int(error_type)))
                            rejudge = False
                    else:
                        print("Is the command template correct [y/n]? y")
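                # Top-k counters are cumulative: a hit at rank r counts for
                # every k >= r. The string-level judgement is only requested
                # when the command template is already correct.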
                if temp_judge == 1:
                    if pred_id < 1:
                        top1_correct_temp = True
                        top3_correct_temp = True
                        top5_correct_temp = True
                        top10_correct_temp = True
                    elif pred_id < 3:
                        top3_correct_temp = True
                        top5_correct_temp = True
                        top10_correct_temp = True
                    elif pred_id < 5:
                        top5_correct_temp = True
                        top10_correct_temp = True
                    elif pred_id < 10:
                        top10_correct_temp = True
                    o_f.write("C")
                    if str_judge is not None and not rejudge:
                        judgement_str = "y" if str_judge == 1 \
                            else "n ({})".format(error_types[str_judge])
                        print("Is the complete command correct [y/n]? %s" %
                              judgement_str)
                    else:
                        str_judge = tree_dist.one_match(gt_trees,
                                                        tree,
                                                        rewrite=False,
                                                        ignore_arg_value=False)
                        if not str_judge:
                            inp = raw_input(
                                "Is the complete command correct [y/n]? ")
                            if inp == "REVERSE":
                                rejudge = True
                                continue
                            elif inp == "y":
                                str_judge = True
                                o_f.write("C")
                                db.add_str_judgement((nl_str, pred_cmd, 1))
                            else:
                                str_judge = False
                                o_f.write("W")
                                db.add_str_judgement((nl_str, pred_cmd, 0))
                        else:
                            print("Is the complete command correct [y/n]? y")
                    if str_judge == 1:
                        if pred_id < 1:
                            top1_correct = True
                            top3_correct = True
                            top5_correct = True
                            top10_correct = True
                        elif pred_id < 3:
                            top3_correct = True
                            top5_correct = True
                            top10_correct = True
                        elif pred_id < 5:
                            top5_correct = True
                            top10_correct = True
                        elif pred_id < 10:
                            top10_correct = True
                        o_f.write("C")
                    else:
                        o_f.write("W")
                else:
                    o_f.write("WW")

                o_f.write("\n")
                o_f.write("\n")

                pred_id += 1
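            # REVERSE rolls num_evaled back so the affected example is
            # re-judged; otherwise record this example's cumulative top-k hits.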

            if rejudge:
                num_evaled -= 1
            else:
                num_evaled += 1
                if top1_correct_temp:
                    num_top1_correct_temp += 1
                if top3_correct_temp:
                    num_top3_correct_temp += 1
                if top5_correct_temp:
                    num_top5_correct_temp += 1
                if top10_correct_temp:
                    num_top10_correct_temp += 1
                if top1_correct:
                    num_top1_correct += 1
                if top3_correct:
                    num_top3_correct += 1
                if top5_correct:
                    num_top5_correct += 1
                if top10_correct:
                    num_top10_correct += 1

            rejudge = False

            print()

    print("%d examples evaluated" % num_eval)
    print("Top 1 Template Match Score = %.2f" %
          (num_top1_correct_temp / num_eval))
    print("Top 1 String Match Score = %.2f" % (num_top1_correct / num_eval))
    if len(predictions) > 3:
        print("Top 5 Template Match Score = %.2f" %
              (num_top5_correct_temp / num_eval))
        print("Top 5 String Match Score = %.2f" %
              (num_top5_correct / num_eval))
        print("Top 10 Template Match Score = %.2f" %
              (num_top10_correct_temp / num_eval))
        print("Top 10 String Match Score = %.2f" %
              (num_top10_correct / num_eval))
    print()

    o_f.write("%d examples evaluated" % num_eval + "\n")
    o_f.write("Top 1 Template Match Score = %.2f" %
              (num_top1_correct_temp / num_eval) + "\n")
    o_f.write("Top 1 String Match Score = %.2f" %
              (num_top1_correct / num_eval) + "\n")
    if len(predictions) > 3:
        o_f.write("Top 5 Template Match Score = %.2f" %
                  (num_top5_correct_temp / num_eval) + "\n")
        o_f.write("Top 5 String Match Score = %.2f" %
                  (num_top5_correct / num_eval) + "\n")
        o_f.write("Top 10 Template Match Score = %.2f" %
                  (num_top10_correct_temp / num_eval) + "\n")
        o_f.write("Top 10 String Match Score = %.2f" %
                  (num_top10_correct / num_eval) + "\n")
    o_f.write("\n")
    o_f.close()