def match_and_save(predictions_path, output_path, dataset_name, splits,
                   data_filepath, schema_obj, sharded):
  """Loads an original dataset and matches with a predictions file."""
  # Load the predictions file.
  prediction_dict = dict()
  if sharded:
    for data_file in tf.gfile.Glob(predictions_path + '*'):
      with tf.gfile.Open(data_file) as infile:
        for line in infile:
          if line:
            obj = json.loads(line)
            prediction_dict[obj['utterance']] = obj
  else:
    with tf.gfile.Open(predictions_path) as infile:
      for line in infile:
        if line:
          obj = json.loads(line)
          prediction_dict[obj['utterance']] = obj

  # Load the data for this particular dataset (for lookup).
  # `matched_examples` is a list of dictionaries, one per example, containing
  # the utterance, predictions, scores, gold SQL, database paths, and schema.
  matched_examples = list()
  if dataset_name.lower() == 'spider':
    assert len(splits) == 1
    split = splits[0]

    for example in load_spider_examples(
        os.path.join(data_filepath, split + '.json')):
      # Looks up the example's schema.
      schema = schema_obj[example['db_id']]

      # Builds the path to the example's database file.
      database_filepath = os.path.join('spider_databases',
                                       example['db_id'] + '.sqlite')

      key = ' '.join(example['question_toks'])
      prediction = prediction_dict[key]
      matched_examples.append({
          'utterance': key,
          'predictions': prediction['predictions'],
          'scores': prediction['scores'],
          'gold': example['query'],
          'database_path': os.path.join(FLAGS.database_filepath,
                                        database_filepath),
          'empty_database_path': os.path.join(FLAGS.empty_database_filepath,
                                              database_filepath),
          'schema': schema
      })
  elif dataset_name.lower() == 'wikisql':
    raise ValueError('Inference on WikiSQL not supported.')
  else:
    for nl, sql in get_nl_sql_pairs(
        os.path.join(data_filepath, dataset_name + '.json'), set(splits)):
      key = nl.encode('utf8')

      # These datasets have a single database file each.
      database_filepath = dataset_name + '.db'

      prediction = prediction_dict[key]
      matched_examples.append({
          'utterance': key,
          'predictions': prediction['predictions'],
          'scores': prediction['scores'],
          'gold': preprocess_sql(sql),
          'database_path': os.path.join(FLAGS.database_filepath,
                                        database_filepath),
          'empty_database_path': os.path.join(FLAGS.empty_database_filepath,
                                              database_filepath),
          'schema': schema_obj
      })

  with tf.gfile.Open(output_path, 'w') as ofile:
    ofile.write(json.dumps(matched_examples))
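# Usage sketch (not part of the original module): match_and_save reads
# predictions as JSON-per-line records keyed by 'utterance', with parallel
# 'predictions' and 'scores' lists, as the parsing loop above assumes. The
# helper below writes one such unsharded file; the helper name and the
# utterance/SQL/score values are illustrative only.
import json


def write_example_predictions(path):
  """Writes a single JSON-per-line prediction record for match_and_save."""
  record = {
      'utterance': 'how many heads of the departments are older than 56 ?',
      'predictions': ['select count(*) from head where age > 56'],
      'scores': [-0.42],
  }
  with open(path, 'w') as ofile:
    ofile.write(json.dumps(record) + '\n')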
def match_with_dataset(
        config: Config, predictions: Sequence[Prediction],
        schema_obj: Dict[Any, Any]) -> List[ExecutionInstructions]:
    """Loads an original dataset and matches it with a predictions file."""
    prediction_dict: Dict[str, Any] = {}
    for prediction in predictions:
        prediction_dict[prediction["utterance"]] = prediction

    # Load the data for this particular dataset (for lookup).
    # `matched_examples` is a list of dictionaries, one per example, containing
    # the prediction (utterance, predictions, scores), gold SQL, database
    # paths, and schema.
    matched_examples: List[ExecutionInstructions] = []

    if config.dataset_name.lower() == "spider":
        assert len(config.splits) == 1
        split = config.splits[0]

        for example in load_spider_examples(
                os.path.join(config.original_data_directory, split + ".json")):
            # Looks up the example's schema.
            schema = schema_obj[example["db_id"]]

            # Builds the path to the example's database file.
            database_filepath = os.path.join("spider_databases",
                                             example["db_id"] + ".sqlite")

            key = " ".join(example["question_toks"])
            try:
                prediction = prediction_dict[key]
            except KeyError:
                continue

            matched_examples.append({
                "prediction": {
                    "utterance": key,
                    "predictions": prediction["predictions"],
                    "scores": prediction["scores"],
                },
                "gold": example["query"],
                "database_path": os.path.join(config.database_directory,
                                              database_filepath),
                "empty_database_path": os.path.join(
                    config.empty_database_directory, database_filepath),
                "schema": schema,
            })
    elif config.dataset_name.lower() == "wikisql":
        raise ValueError("Inference on WikiSQL not supported.")
    else:  # Michigan datasets
        dataset_path: str = os.path.join(config.original_data_directory,
                                         config.dataset_name + ".json")

        for nl, sql in get_nl_sql_pairs(dataset_path,
                                        frozenset(config.splits)):
            # TODO(samuelstevens): What is the point of encoding then
            # decoding? Simplify.
            key = nl.encode("utf8").decode("utf-8")

            # These datasets have a single database file each.
            database_filepath = config.dataset_name + ".db"

            assert len(prediction_dict) > 0
            prediction = prediction_dict[key]
            matched_examples.append({
                "prediction": {
                    "utterance": key,
                    "predictions": prediction["predictions"],
                    "scores": prediction["scores"],
                },
                "gold": preprocess_sql(sql),
                "database_path": os.path.join(config.database_directory,
                                              database_filepath),
                "empty_database_path": os.path.join(
                    config.empty_database_directory, database_filepath),
                "schema": schema_obj,
            })

    assert len(matched_examples) == len(predictions), (
        f"Only matched {len(matched_examples)} of {len(predictions)} examples.")

    return matched_examples
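# Usage sketch (not part of the original module): drives match_with_dataset
# end to end. The Config field names are the ones the function reads, but
# constructing Config by keyword, the directory values, the prediction
# record, and build_schema_obj are all assumptions for illustration.
def run_matching_example() -> None:
    config = Config(
        dataset_name="spider",
        splits=["dev"],
        original_data_directory="data/spider",
        database_directory="data/spider/databases",
        empty_database_directory="data/spider/empty_databases",
    )
    predictions: List[Prediction] = [{
        "utterance": "how many heads of the departments are older than 56 ?",
        "predictions": ["select count(*) from head where age > 56"],
        "scores": [-0.42],
    }]
    schema_obj = build_schema_obj(config)  # Hypothetical schema loader.
    for instructions in match_with_dataset(config, predictions, schema_obj):
        print(instructions["prediction"]["utterance"], "->",
              instructions["gold"])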