Example #1
import json
import os

import tensorflow as tf  # TF 1.x: the code relies on the tf.gfile API.

# `FLAGS`, `load_spider_examples`, `get_nl_sql_pairs`, and `preprocess_sql` are
# defined elsewhere in the surrounding module.


def match_and_save(predictions_path, output_path, dataset_name, splits,
                   data_filepath, schema_obj, sharded):
    """Loads an original dataset and matches with a predictions file."""
    # Load the predictions file
    prediction_dict = dict()
    if sharded:
        for data_file in tf.gfile.Glob(predictions_path + '*'):
            with tf.gfile.Open(data_file) as infile:
                for line in infile:
                    if line:
                        obj = json.loads(line)
                        prediction_dict[obj['utterance']] = obj

    else:
        with tf.gfile.Open(predictions_path) as infile:
            for line in infile:
                if line:
                    obj = json.loads(line)
                    prediction_dict[obj['utterance']] = obj

    # Load the data for this particular dataset and match each example against
    # the predictions. `matched_examples` is a list of dictionaries, one per
    # example, each containing the utterance, predictions, scores, gold SQL,
    # database paths, and schema (looked up by db_id when running inference on
    # Spider).
    matched_examples = list()
    if dataset_name.lower() == 'spider':
        assert len(splits) == 1
        split = splits[0]

        for example in load_spider_examples(
                os.path.join(data_filepath, split + '.json')):
            # Looks up the example's schema.
            schema = schema_obj[example['db_id']]

            # Relative path to this example's SQLite database.
            database_filepath = os.path.join('spider_databases',
                                             example['db_id'] + '.sqlite')

            key = ' '.join(example['question_toks'])
            prediction = prediction_dict[key]
            matched_examples.append({
                'utterance':
                key,
                'predictions':
                prediction['predictions'],
                'scores':
                prediction['scores'],
                'gold':
                example['query'],
                'database_path':
                os.path.join(FLAGS.database_filepath, database_filepath),
                'empty_database_path':
                os.path.join(FLAGS.empty_database_filepath, database_filepath),
                'schema':
                schema
            })

    elif dataset_name.lower() == 'wikisql':
        raise ValueError('Inference on WikiSQL not supported.')
    else:
        for nl, sql in get_nl_sql_pairs(
                os.path.join(data_filepath, dataset_name + '.json'),
                set(splits)):
            key = nl.encode('utf8')

            # All examples in this dataset share a single database file.
            database_filepath = dataset_name + '.db'

            prediction = prediction_dict[key]
            matched_examples.append({
                'utterance':
                key,
                'predictions':
                prediction['predictions'],
                'scores':
                prediction['scores'],
                'gold':
                preprocess_sql(sql),
                'database_path':
                os.path.join(FLAGS.database_filepath, database_filepath),
                'empty_database_path':
                os.path.join(FLAGS.empty_database_filepath, database_filepath),
                'schema':
                schema_obj
            })

    with tf.gfile.Open(output_path, 'w') as ofile:
        ofile.write(json.dumps(matched_examples))
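
For reference, here is a minimal sketch of the JSON-lines predictions file the function expects (one object per line, keyed by `utterance`) and a hypothetical invocation; the paths, dataset name, and example values are made up for illustration:

# predictions.jsonl, one JSON object per line:
#   {"utterance": "how many heads of departments are older than 56 ?",
#    "predictions": ["select count(*) from head where age > 56"],
#    "scores": [-0.12]}

match_and_save(
    predictions_path='/tmp/predictions.jsonl',
    output_path='/tmp/matched_examples.json',
    dataset_name='spider',
    splits=['dev'],
    data_filepath='/data/spider',
    schema_obj=schema_obj,  # mapping from db_id to schema, built elsewhere
    sharded=False)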
Example #2
import os
from typing import Any, Dict, List, Sequence

# `Config`, `Prediction`, `ExecutionInstructions`, `load_spider_examples`,
# `get_nl_sql_pairs`, and `preprocess_sql` are defined elsewhere in the
# surrounding module.


def match_with_dataset(
        config: Config, predictions: Sequence[Prediction],
        schema_obj: Dict[Any, Any]) -> List[ExecutionInstructions]:
    """
    Loads an original dataset and matches with a predictions file.
    """

    prediction_dict: Dict[str, Any] = {}
    for prediction in predictions:
        prediction_dict[prediction["utterance"]] = prediction

    # Load the data for this particular dataset and match each example against
    # the predictions. `matched_examples` is a list of dictionaries, one per
    # example, each containing the utterance, predictions, scores, gold SQL,
    # database paths, and schema (looked up by db_id when running inference on
    # Spider).
    matched_examples: List[ExecutionInstructions] = []
    if config.dataset_name.lower() == "spider":
        assert len(config.splits) == 1
        split = config.splits[0]

        for example in load_spider_examples(
                os.path.join(config.original_data_directory, split + ".json")):
            # Looks up the example's schema.
            schema = schema_obj[example["db_id"]]

            # Relative path to this example's SQLite database.
            database_filepath = os.path.join("spider_databases",
                                             example["db_id"] + ".sqlite")
            key = " ".join(example["question_toks"])

            try:
                prediction = prediction_dict[key]
            except KeyError:
                continue

            matched_examples.append({
                "prediction": {
                    "utterance": key,
                    "predictions": prediction["predictions"],
                    "scores": prediction["scores"],
                },
                "gold":
                example["query"],
                "database_path":
                os.path.join(config.database_directory, database_filepath),
                "empty_database_path":
                os.path.join(config.empty_database_directory,
                             database_filepath),
                "schema":
                schema,
            })

    elif config.dataset_name.lower() == "wikisql":
        raise ValueError("Inference on WikiSQL not supported.")
    else:  # michigan datasets
        dataset_path: str = os.path.join(config.original_data_directory,
                                         config.dataset_name + ".json")
        for nl, sql in get_nl_sql_pairs(dataset_path,
                                        frozenset(config.splits)):
            # TODO(samuelstevens): What is the point of encoding then decoding? Simplify.
            key = nl.encode("utf8").decode("utf-8")

            # All examples in this dataset share a single database file.
            database_filepath = config.dataset_name + ".db"

            assert len(prediction_dict) > 0
            prediction = prediction_dict[key]

            matched_examples.append({
                "prediction": {
                    "utterance": key,
                    "predictions": prediction["predictions"],
                    "scores": prediction["scores"],
                },
                "gold":
                preprocess_sql(sql),
                "database_path":
                os.path.join(config.database_directory, database_filepath),
                "empty_database_path":
                os.path.join(config.empty_database_directory,
                             database_filepath),
                "schema":
                schema_obj,
            })

    assert len(matched_examples) == len(
        predictions
    ), f"Only matched {len(matched_examples)} of {len(predictions)} examples."

    return matched_examples
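
A quick usage sketch, assuming `Config` can be constructed directly with the fields the function reads (`dataset_name`, `splits`, `original_data_directory`, `database_directory`, `empty_database_directory`); the predictions list and paths below are hypothetical:

predictions = [{
    "utterance": "how many heads of departments are older than 56 ?",
    "predictions": ["select count(*) from head where age > 56"],
    "scores": [-0.12],
}]
config = Config(
    dataset_name="spider",
    splits=["dev"],
    original_data_directory="/data/spider",
    database_directory="/data/spider/databases",
    empty_database_directory="/data/spider/empty_databases")
instructions = match_with_dataset(config, predictions, schema_obj)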