Ejemplo n.º 1
0
def error_analysis(sp):
    """Majority-vote an ensemble's dev-set predictions and report disagreements.

    Loads the processed dataset (using the module-level ``args``) and one
    ``predictions.16.txt`` file per directory in the module-level
    ``ensemble_model_dirs``; each file must contain one prediction per dev
    example, in dev-set order. For every dev example it

    * writes the most-voted prediction to ``majority_vote.txt`` (one line per
      example), and
    * evaluates every model's prediction with ``eval_tools.eval_prediction``;
      when the models' evaluation results are not all equal, the example and
      every model's prediction are pretty-printed.

    Finally it prints, for each model pair, how many examples they disagree
    on, and how many examples *every* model pair disagrees on.

    Calls ``sys.exit()`` when fewer than three models are configured.

    Args:
        sp: Object whose ``schema_graphs`` attribute is overwritten with the
            loaded dataset's schema graphs; those are used to look up each
            example's schema for pretty-printing.
    """
    import collections

    dataset = data_loader.load_processed_data(args)
    dev_examples = dataset['dev']
    sp.schema_graphs = dataset['schema']
    print('{} dev examples loaded'.format(len(dev_examples)))

    if len(ensemble_model_dirs) <= 2:
        print('Needs at least 3 models to perform majority vote')
        sys.exit()

    # predictions[i][e_id] is model i's prediction string for dev example e_id.
    predictions = []
    for model_dir in ensemble_model_dirs:
        pred_file = os.path.join(model_dir, 'predictions.16.txt')
        with open(pred_file) as f:
            predictions.append([x.strip() for x in f])
    for model_dir, model_preds in zip(ensemble_model_dirs, predictions):
        assert len(dev_examples) == len(model_preds), \
            '{}: {} predictions for {} dev examples'.format(
                model_dir, len(model_preds), len(dev_examples))

    num_models = len(predictions)
    # Every unordered model pair (i, j) with i < j.
    model_pairs = [(i, j) for i in range(num_models)
                   for j in range(i + 1, num_models)]

    # disagree[i][j] (i < j): ids of examples on which models i and j obtain
    # different evaluation results.
    disagree = collections.defaultdict(lambda: collections.defaultdict(list))
    out_txt = 'majority_vote.txt'
    with open(out_txt, 'w') as o_f:
        for e_id, example in enumerate(dev_examples):
            gt_program_list = example.program_list
            model_preds = [preds[e_id] for preds in predictions]

            # Map each distinct prediction to the indices of the models voting for it.
            votes = collections.defaultdict(list)
            for i, pred_sql in enumerate(model_preds):
                votes[pred_sql].append(i)
            # max() returns the first maximal key in dict insertion order, so a
            # tie goes to the candidate first predicted by the lowest-indexed
            # model. NOTE: an earlier variant that broke ties by the sum of the
            # voters' model indices was dropped; its author flagged it as
            # cheating.
            voted_sql = max(votes, key=lambda sql: len(votes[sql]))
            o_f.write(voted_sql + '\n')

            evals = []
            for pred_sql in model_preds:
                eval_results, _, _ = eval_tools.eval_prediction(
                    pred=pred_sql,
                    gt_list=gt_program_list,
                    dataset_id=example.dataset_id,
                    db_name=example.db_name,
                    in_execution_order=False
                )
                evals.append(eval_results)

            # set() requires the evaluation results to be hashable.
            if len(set(evals)) > 1:
                for i, j in model_pairs:
                    if evals[i] != evals[j]:
                        disagree[i][j].append(e_id)
                schema = sp.schema_graphs[example.db_name]
                print('Example {}'.format(e_id + 1))
                example.pretty_print(schema)
                for i, pred_sql in enumerate(model_preds):
                    print('Prediction {} [{}]: {}'.format(
                        i + 1, evals[i], pred_sql))
                print()

    for i, j in model_pairs:
        print('Disagree {}, {}: {}'.format(i + 1, j + 1, len(disagree[i][j])))
    # Examples on which every model pair disagrees; model_pairs is non-empty
    # because at least three models are required above.
    disagree_all = set.intersection(
        *(set(disagree[i][j]) for i, j in model_pairs))
    print('Disagree all: {}'.format(len(disagree_all)))
    print('Majority voting results saved to {}'.format(out_txt))
Ejemplo n.º 2
0
    def inference(self,
                  examples,
                  decode_str_output=True,
                  restore_clause_order=False,
                  pred_restored_cache=None,
                  check_schema_consistency_=True,
                  engine=None,
                  inline_eval=False,
                  model_ensemble=None,
                  verbose=False):
        """Run the model over ``examples`` in mini-batches and post-process outputs.

        Args:
            examples: Sliceable sequence of examples, processed in mini-batches
                of ``self.dev_batch_size``. Each example must expose
                ``db_name`` and ``dataset_id`` (and ``program_list`` /
                ``program_ast_list_`` / ``program`` when ``inline_eval`` is set).
            decode_str_output: If true, post-process every beam into a query
                string and return them under ``'pred_decoded'``.
            restore_clause_order: If true, each decoded query is passed through
                ``moz_sp.restore_clause_order``; results are memoised in
                ``pred_restored_cache`` while ``check_schema_consistency_`` is on.
            pred_restored_cache: Optional memo
                ``{db_name: {pred_sql: (restored, grammatical, schema_consistent)}}``
                that callers may reuse across calls; a new dict is created when
                it is ``None`` and ``restore_clause_order`` is set.
            check_schema_consistency_: If true, drop decoded queries that fail
                the schema-consistency check.
            engine: Query executor; required when
                ``self.args.execution_guided_decoding`` is set. Queries that
                raise or yield a falsy execution result are then dropped.
            inline_eval: If true, evaluate each decoded query against the gold
                programs and stop scanning an example's beam at the first
                correct one.
            model_ensemble: Forwarded unchanged to ``self.forward``.
            verbose: If true, print each example's predictions through
                ``self.print_predictions``. This reads ``hardness`` and the
                per-beam correctness, which are only computed with ``inline_eval``.

        Returns:
            dict with ``'preds'`` (``ops.pad_and_cat`` of the raw per-batch
            predictions) and ``'pred_scores'``; plus ``'pred_decoded'`` and
            ``'pred_decoded_scores'`` when ``decode_str_output`` is set, and
            ``'pred_restored_cache'`` when ``restore_clause_order`` is set.

        Raises:
            NotImplementedError: for an unsupported ``self.model_id``, an
                unsupported ``dataset_id`` with ``inline_eval``, or a decoding
                algorithm other than beam search.
        """
        # Leaderboard / demo runs must not use debugging or oracle features.
        if self.args.leaderboard_submission or self.args.demo:
            assert (not verbose and not inline_eval
                    and not self.args.use_oracle_tables)

        pred_list, pred_score_list, pred_decoded_list, pred_decoded_score_list = [], [], [], []
        if restore_clause_order and pred_restored_cache is None:
            pred_restored_cache = dict()
        if self.save_vis:
            text_ptr_weights_vis, pointer_vis = [], []

        num_error_cases = 0
        for batch_start_id in tqdm(range(0, len(examples),
                                         self.dev_batch_size)):
            mini_batch = examples[batch_start_id:batch_start_id +
                                  self.dev_batch_size]
            formatted_batch = self.format_batch(mini_batch)
            outputs = self.forward(formatted_batch, model_ensemble)
            if self.model_id in [SEQ2SEQ_PG, BRIDGE]:
                preds, pred_scores, text_p_pointers, text_ptr_weights, seq_len = outputs
                text_p_pointers.unsqueeze_(2)
                # Stack (1 - p, p) along a new last dimension.
                p_pointers = torch.cat([1 - text_p_pointers, text_p_pointers],
                                       dim=2)
            elif self.model_id == SEQ2SEQ:
                preds, pred_scores, text_ptr_weights, seq_len = outputs
                p_pointers = None
            else:
                raise NotImplementedError

            pred_list.append(preds)
            pred_score_list.append(pred_scores)
            if not (decode_str_output or verbose):
                continue

            for i, example in enumerate(mini_batch):
                db_name = example.db_name
                schema = self.schema_graphs[db_name]
                table_po, field_po = None, None
                if self.args.use_oracle_tables:
                    # TODO: The implementation below is incorrect.
                    if self.args.num_random_tables_added > 0:
                        table_po, field_po = formatted_batch[-1][i]

                exp_output_strs, exp_output_scores, exp_seq_lens, exp_correct = [], [], [], []

                # NOTE(review): gt_program_list and hardness are bound only
                # here, but hardness is also read in the verbose branch below;
                # verbose without inline_eval raises UnboundLocalError. Confirm
                # callers always pair the two flags.
                if inline_eval:
                    if example.dataset_id == SPIDER:
                        gt_program_list = example.program_list
                        gt_program_ast = example.program_ast_list_[0] \
                            if example.program_ast_list_ else example.program
                        hardness = spider_eval_tools.Evaluator(
                        ).eval_hardness(gt_program_ast,
                                        db_dir=self.args.db_dir,
                                        db_name=example.db_name)
                    elif example.dataset_id == WIKISQL:
                        gt_program_list = example.program_ast_list_
                        hardness = 'easy'
                    else:
                        raise NotImplementedError

                if self.decoding_algorithm != 'beam-search':
                    raise NotImplementedError

                for j in range(self.beam_size):
                    # Beams of example i occupy consecutive rows of the batch output.
                    beam_id = i * self.beam_size + j
                    post_processed_output = self.post_process_nn_output(
                        beam_id,
                        example.dataset_id,
                        example,
                        preds,
                        schema,
                        text_ptr_weights,
                        p_pointers,
                        table_po=table_po,
                        field_po=field_po,
                        verbose=verbose)
                    if post_processed_output:
                        pred_sql = post_processed_output[0]
                        if restore_clause_order:
                            db_cache = pred_restored_cache.get(db_name)
                            if db_cache is not None and pred_sql in db_cache:
                                restored_pred, grammatical, schema_consistent = db_cache[pred_sql]
                            else:
                                restored_pred, grammatical, schema_consistent = moz_sp.restore_clause_order(
                                    pred_sql,
                                    schema,
                                    check_schema_consistency_=
                                    check_schema_consistency_,
                                    verbose=verbose)
                                # Test the flag, not the cache's truthiness: a fresh
                                # (empty) cache must still be populated.
                                # TODO: we don't cache the results when check_schema_consistency_ is off to
                                # avoid logging false negatives
                                if check_schema_consistency_:
                                    pred_restored_cache.setdefault(db_name, dict())[pred_sql] = \
                                        restored_pred, grammatical, schema_consistent
                            if check_schema_consistency_ and not schema_consistent:
                                restored_pred = None
                            pred_sql = restored_pred
                        elif check_schema_consistency_:
                            if not moz_sp.check_schema_consistency(
                                    pred_sql,
                                    schema,
                                    in_execution_order=self.args.
                                    process_sql_in_execution_order
                            ):
                                pred_sql = None
                        if pred_sql and self.args.execution_guided_decoding:
                            assert (engine is not None)
                            try:
                                pred_query = Query.from_dict(
                                    pred_sql, ordered=False)
                                pred_ex = engine.execute_query(
                                    example.db_name,
                                    pred_query,
                                    lower=True)
                                if not pred_ex:
                                    pred_sql = None
                            except Exception:
                                # Best effort: a query that cannot be built or
                                # executed is simply dropped from the beam.
                                pred_sql = None
                    else:
                        pred_sql = None

                    if not pred_sql:
                        continue

                    exp_output_strs.append(pred_sql)
                    exp_output_scores.append(float(pred_scores[beam_id]))
                    exp_seq_lens.append(int(seq_len[beam_id]))
                    if self.save_vis:
                        self.save_vis_parameters(
                            post_processed_output,
                            text_ptr_weights_vis, pointer_vis)
                    if inline_eval:
                        correct, _, _ = eval_tools.eval_prediction(
                            pred=pred_sql,
                            gt_list=gt_program_list,
                            dataset_id=example.dataset_id,
                            db_name=example.db_name,
                            in_execution_order=(
                                self.args.process_sql_in_execution_order
                                and not restore_clause_order))
                        exp_correct.append(correct)
                        # When correct is a tuple, its second entry is the flag
                        # used to stop scanning the beam.
                        correct_ = correct[1] if isinstance(
                            correct, tuple) else correct
                        if correct_:
                            break

                pred_decoded_list.append(exp_output_strs)
                pred_decoded_score_list.append(list(exp_output_scores))
                if verbose:
                    predictions = zip(exp_output_strs, exp_output_scores,
                                      exp_seq_lens, exp_correct)
                    is_error_case = self.print_predictions(
                        batch_start_id + i, example, hardness, predictions,
                        schema)
                    if is_error_case:
                        num_error_cases += 1
                        print('Error Case {}'.format(num_error_cases))
                        print()
                if not pred_decoded_list[-1] and not self.args.demo:
                    # No beam survived post-processing: fall back to a dummy
                    # prediction carrying a very low score.
                    pred_decoded_list[-1].append(
                        self.get_dummy_prediction(schema))
                    pred_decoded_score_list[-1].append(-ops.HUGE_INT)

        out_dict = dict()
        out_dict['preds'] = ops.pad_and_cat(pred_list, self.out_vocab.pad_id)
        out_dict['pred_scores'] = torch.cat(pred_score_list)
        if decode_str_output:
            out_dict['pred_decoded'] = pred_decoded_list
            out_dict['pred_decoded_scores'] = pred_decoded_score_list
        if restore_clause_order:
            out_dict['pred_restored_cache'] = pred_restored_cache

        if self.save_vis:
            # Text attention is saved under the 'text' label, pointers under 'all'.
            self.vis_writer.save_cross_attention(text_ptr_weights_vis, 'text')
            self.vis_writer.save_pointer(pointer_vis, 'all')

        return out_dict