Code Example #1
def _convert_single_wtq(interaction_file, prediction_file, output_file):
    """Convert predictions to WikiTablequestions format."""

    interactions = dict(
        (prediction_utils.parse_interaction_id(i.id), i)
        for i in prediction_utils.iterate_interactions(interaction_file))
    missing_interaction_ids = set(interactions.keys())

    with tf.io.gfile.GFile(output_file, 'w') as output:
        for prediction in prediction_utils.iterate_predictions(
                prediction_file):
            interaction_id = prediction['id']
            # Handle each interaction at most once; skip duplicate or
            # unknown interaction ids.
            if interaction_id not in missing_interaction_ids:
                continue
            missing_interaction_ids.remove(interaction_id)

            coordinates = prediction_utils.parse_coordinates(
                prediction['answer_coordinates'])

            # Execute the predicted aggregation over the predicted cell
            # coordinates to obtain the denotation.
            denot_pred, _ = calc_metrics_utils.execute(
                int(prediction.get('pred_aggr', '0')), coordinates,
                prediction_utils.table_to_panda_frame(
                    interactions[interaction_id].table))

            answers = '\t'.join(sorted(map(str, denot_pred)))
            output.write('{}\t{}\n'.format(interaction_id, answers))

        # Write bare ids for interactions that never received a prediction.
        for interaction_id in missing_interaction_ids:
            output.write('{}\n'.format(interaction_id))
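A minimal usage sketch for the converter above; the file paths below are hypothetical stand-ins for real WTQ interaction and prediction files:

# Hypothetical paths; substitute your own interaction/prediction files.
_convert_single_wtq(
    interaction_file='wtq/interactions/test.tfrecord',
    prediction_file='wtq/model/test_predictions.tsv',
    output_file='wtq/denotations/test.tsv')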
Code Example #2
def _read_data_examples(data_path):
    """Reads examples from a dataset csv file."""
    data_examples = {}
    with tf.io.gfile.GFile(data_path, mode='r') as f:
        reader = csv.DictReader(f, delimiter='\t')
        for row in reader:
            # Example ids follow the SQA convention: <id>-<annotator>_<position>.
            ex_id = '{}-{}_{}'.format(row['id'], row['annotator'],
                                      row['position'])
            question = row['question'].strip()
            table_id = row['table_file']
            gold_cell_coo = prediction_utils.parse_coordinates(
                row['answer_coordinates'])
            gold_agg_function = int(row['aggregation'])
            # An empty 'float_answer' field means no scalar answer is annotated.
            float_answer_raw = row['float_answer']
            float_answer = float(
                float_answer_raw) if float_answer_raw else None
            ex = calc_metrics_utils.Example(ex_id,
                                            question,
                                            table_id,
                                            None,
                                            gold_cell_coo,
                                            gold_agg_function,
                                            float_answer,
                                            has_gold_answer=True)
            data_examples[ex_id] = ex
    return data_examples
Code Example #3
def read_predictions(predictions_path, examples):
  """Reads predictions from a csv file."""
  for row in prediction_utils.iterate_predictions(predictions_path):
    pred_id = '{}-{}_{}'.format(row['id'], row['annotator'], row['position'])
    example = examples[pred_id]
    example.pred_cell_coo = prediction_utils.parse_coordinates(
        row['answer_coordinates'])
    # 'pred_aggr' defaults to '0', i.e. no aggregation predicted.
    example.pred_agg_function = int(row.get('pred_aggr', '0'))
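A plausible way to combine this helper with _read_data_examples from Code Example #2; the paths are illustrative only:

# Hypothetical paths; both files are tab-separated.
examples = _read_data_examples('sqa/test.tsv')
read_predictions('sqa/model/test_predictions.tsv', examples)
# Each Example now carries both gold and predicted coordinates and can be
# handed to the metric helpers in calc_metrics_utils.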
Code Example #4
def get_predictions(prediction_file):
    """Yields an iterable of Prediction objects from a tsv prediction file."""
    # Per-field parsers; fields not listed here are kept as raw strings.
    fn_map = {
        'logits_cls': float,
        'position': int,
        'answer_coordinates':
            lambda x: list(prediction_utils.parse_coordinates(x)),
        'answers': token_answers_from_text,
        'token_probabilities': json.loads,
    }
    for prediction_dict in prediction_utils.iterate_predictions(
            prediction_file):
        for key in tuple(prediction_dict.keys()):
            fn = fn_map.get(key, lambda x: x)
            prediction_dict[key] = fn(prediction_dict[key])
        yield Prediction(**prediction_dict)
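Assuming Prediction is a namedtuple (or similar) whose fields match the tsv columns, consumption could look like this; the file name is hypothetical, and Prediction and token_answers_from_text are defined elsewhere in the surrounding module:

for pred in get_predictions('sqa/model/test_predictions.tsv'):
    # Fields listed in fn_map arrive already parsed; the rest are strings.
    print(pred.position, pred.answer_coordinates)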
Code Example #5
def read_predictions(predictions_path, examples):
    """Reads predictions from a csv file."""
    for row in prediction_utils.iterate_predictions(predictions_path):
        pred_id = '{}-{}_{}'.format(row['id'], row['annotator'],
                                    row['position'])
        example = examples[pred_id]
        example.pred_cell_coo = prediction_utils.parse_coordinates(
            row['answer_coordinates'])
        # 'pred_aggr' defaults to '0', i.e. no aggregation predicted.
        example.pred_agg_function = int(row.get('pred_aggr', '0'))
        if 'column_scores' in row:
            # 'column_scores' is serialized like '[0.1 -0.2 0.3]'; strip the
            # brackets and split on spaces.
            column_scores = list(
                filter(None, row['column_scores'][1:-1].split(' ')))
            removed_column_scores = [
                float(score) for score in column_scores if float(score) < 0.0
            ]
            if column_scores:
                # Weight the example by the fraction of columns scored below
                # zero (i.e. removed by column pruning).
                example.weight = len(removed_column_scores) / len(
                    column_scores)
Code Example #6
def predict(table_data, queries):
    print("Prediction started!")
    # Parse the pipe-separated table string into a list of stripped cells.
    table = [
        list(map(lambda s: s.strip(), row.split("|")))
        for row in table_data.split("\n") if row.strip()
    ]
    examples = convert_interactions_to_examples([(table, queries)])
    write_tf_example("results/sqa/tf_examples/test.tfrecord", examples)
    write_tf_example("results/sqa/tf_examples/random-split-1-dev.tfrecord", [])

    print("Processed table data!")

    # Run the TAPAS SQA task runner in predict mode; stderr is captured in
    # a local 'error' file.
    os.system(''' python tapas/tapas/run_task_main.py \
    --task="SQA" \
    --output_dir="results" \
    --noloop_predict \
    --test_batch_size=3 \
    --tapas_verbosity="ERROR" \
    --compression_type= \
    --init_checkpoint="tapas_sqa_base/model.ckpt" \
    --bert_config_file="tapas_sqa_base/bert_config.json" \
    --mode="predict" 2> error''')

    print("Prediction completed!")

    results_path = "results/sqa/model/test_sequence.tsv"
    all_coordinates = []
    answers_lst = []
    # The first row of the parsed table is the header.
    df = pd.DataFrame(table[1:], columns=table[0])
    #display(IPython.display.HTML(df.to_html(index=False)))
    print("Printing results:")
    with open(results_path) as csvfile:
        reader = csv.DictReader(csvfile, delimiter='\t')
        for row in reader:
            coordinates = prediction_utils.parse_coordinates(
                row["answer_coordinates"])
            all_coordinates.append(coordinates)
            # +1 skips the header row of the parsed table.
            answers = ', '.join(
                [table[r + 1][c] for r, c in coordinates])
            position = int(row['position'])
            print(">", queries[position])
            print(answers)
            answers_lst.append(answers)
    return answers_lst
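An illustrative call with a small pipe-separated table; the data and question are made up:

table_data = """Actor | Age
Brad Pitt | 59
Leonardo DiCaprio | 48"""
queries = ["Who is the oldest actor?"]
answers = predict(table_data, queries)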
Code Example #7
def predict(table_data, queries):
    table = table_data.values.tolist()
    examples = convert_interactions_to_examples([(table, queries)])
    write_tf_example("results/sqa/tf_examples/test.tfrecord", examples)
    write_tf_example("results/sqa/tf_examples/random-split-1-dev.tfrecord", [])

    # Shell script wrapping the TAPAS prediction command.
    cmd = '/mnt/d/Data_Science_Work/tapas/predict.sh'
    subprocess.call(cmd)
    
    results_path = "results/sqa/model/test_sequence.tsv"
    all_coordinates = []
    answers_lst = []
    with open(results_path) as csvfile:
        reader = csv.DictReader(csvfile, delimiter='\t')
        for row in reader:
            coordinates = prediction_utils.parse_coordinates(row["answer_coordinates"])
            all_coordinates.append(coordinates)
            answers = ', '.join([table[r + 1][c] for r, c in coordinates])
            position = int(row['position'])
            print(">", queries[position])
            print(answers)
            answers_lst.append(answers)
    return answers_lst
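This variant takes a pandas DataFrame instead of a raw string; an illustrative call might be:

import pandas as pd

df = pd.DataFrame({"Actor": ["Brad Pitt", "Leonardo DiCaprio"],
                   "Age": ["59", "48"]})
answers = predict(df, ["Who is the oldest actor?"])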
Code Example #8
def eval_cell_selection(
    questions,
    predictions_file,
):
    """Evaluates cell selection results in HybridQA experiment.

  Args:
    questions: A map of Question protos by their respective ids.
    predictions_file: Path to a tsv file with predictions for a checkpoint.

  Yields:
    An AnswerType and its corresponding CellSelectionMetrics instance
  """
    total = collections.Counter()
    total_correct = collections.Counter()
    total_correct_at_k = {k: collections.Counter() for k in _RECALL_KS}
    total_seen = collections.Counter()
    total_non_empty = collections.Counter()
    total_coordinates = collections.Counter()
    sum_precision = collections.defaultdict(float)
    # Whether any prediction row carried token probabilities; gates recall@k.
    any_probabilities = False

    for question in questions.values():
        for answer_type in [AnswerType.ALL, _get_answer_type(question)]:
            total[answer_type] += 1

    for row in prediction_utils.iterate_predictions(predictions_file):
        question = questions.get(row['question_id'])
        if question is None:
            # The dataset lost some examples after an update.
            continue
        gold_coordinates = {(x.row_index, x.column_index)
                            for x in question.answer.answer_coordinates}
        coordinates = prediction_utils.parse_coordinates(
            row['answer_coordinates'])
        # We only care about finding one correct cell for the downstream model.
        correct_coordinates = len(coordinates & gold_coordinates)
        has_probabilities = 'token_probabilities' in row
        any_probabilities = any_probabilities or has_probabilities
        if has_probabilities:
            best_cells = get_best_cells(json.loads(row['token_probabilities']))
            correct_at_k = {
                k: bool(set(best_cells[:k]) & gold_coordinates)
                for k in _RECALL_KS
            }
        else:
            correct_at_k = {}
        for answer_type in [AnswerType.ALL, _get_answer_type(question)]:
            total_coordinates[answer_type] += len(coordinates)
            total_correct[answer_type] += bool(correct_coordinates)
            total_seen[answer_type] += 1
            for k, correct in correct_at_k.items():
                total_correct_at_k[k][answer_type] += correct
            if coordinates:
                sum_precision[answer_type] += correct_coordinates / len(
                    coordinates)
                total_non_empty[answer_type] += 1

    for answer_type in AnswerType:
        if total[answer_type]:
            recall_at_k = {
                f'recall_at_{k}':
                (total_correct_at_k[k][answer_type] /
                 total[answer_type]) if any_probabilities else None
                for k in _RECALL_KS
            }
            yield answer_type, CellSelectionMetrics(
                recall=total_correct[answer_type] / total[answer_type],
                non_empty=total_non_empty[answer_type] / total[answer_type],
                coverage=total_seen[answer_type] / total[answer_type],
                answer_len=total_coordinates[answer_type] / total[answer_type],
                precision=((sum_precision[answer_type] /
                            total_non_empty[answer_type])
                           if total_non_empty[answer_type] else None),
                **recall_at_k,
            )
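A possible consumption pattern, assuming questions has already been loaded from the HybridQA interaction files; the tsv path is hypothetical:

for answer_type, metrics in eval_cell_selection(
        questions, 'hybridqa/model/test_predictions.tsv'):
    print(answer_type, metrics.recall, metrics.precision)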