Example #1
def process_michigan_datasets(output_file, debugging_file, tokenizer):
  """Loads, converts, and writes Michigan examples to the standard format."""
  # TODO(alanesuhr,petershaw): Support asql for this dataset.
  if FLAGS.generate_sql and FLAGS.abstract_sql:
    raise NotImplementedError(
        'Abstract SQL currently only supported for SPIDER.')

  schema_csv = os.path.join(FLAGS.input_filepath,
                            FLAGS.dataset_name + '_schema.csv')
  data_filepath = os.path.join(FLAGS.input_filepath,
                               FLAGS.dataset_name + '.json')

  # Don't actually provide table entities.
  num_examples_created = 0
  num_examples_failed = 0

  print('Loading from ' + data_filepath)
  paired_data = get_nl_sql_pairs(data_filepath, set(FLAGS.splits))
  print('Loaded %d examples.' % len(paired_data))

  schema = read_schema(schema_csv)

  for nl, sql in paired_data:
    example = convert_michigan(nl, sql, schema, tokenizer, FLAGS.generate_sql)
    if example is not None:
      output_file.write(json.dumps(example.to_json()) + '\n')
      num_examples_created += 1

      debugging_file.write(example.model_input.original_utterance + '\n')
      if FLAGS.generate_sql:
        debugging_file.write(example.gold_query_string() + '\n\n')
    else:
      num_examples_failed += 1

  return num_examples_created, num_examples_failed
def compute_michigan_coverage():
    """Prints out statistics for asql conversions."""
    # Read data files.
    schema_csv_path = os.path.join(FLAGS.michigan_data_dir,
                                   '%s-schema.csv' % FLAGS.dataset_name)
    examples_json_path = os.path.join(FLAGS.michigan_data_dir,
                                      '%s.json' % FLAGS.dataset_name)
    schema = michigan_preprocessing.read_schema(schema_csv_path)
    if FLAGS.use_oracle_foriegn_keys:
        foreign_keys = abstract_sql_converters.michigan_db_to_foreign_key_tuples_orcale(
            FLAGS.dataset_name)
    else:
        foreign_keys = abstract_sql_converters.michigan_db_to_foreign_key_tuples(
            schema)
    table_schema = abstract_sql_converters.michigan_db_to_table_tuples(schema)
    nl_sql_pairs = michigan_preprocessing.get_nl_sql_pairs(
        examples_json_path, FLAGS.splits)

    # Iterate through examples and generate counts.
    num_examples = 0
    num_conversion_failures = 0
    num_successes = 0
    num_parse_failures = 0
    num_reconstruction_failures = 0
    exception_counts = collections.defaultdict(int)
    for _, gold_sql_query in nl_sql_pairs:
        num_examples += 1
        print('Parsing example number %s.' % num_examples)
        try:
            sql_spans = abstract_sql.sql_to_sql_spans(gold_sql_query,
                                                      table_schema)
            sql_spans = abstract_sql.replace_from_clause(sql_spans)
        except abstract_sql.UnsupportedSqlError as e:
            print('Error converting:\n%s\n%s' % (gold_sql_query, e))
            num_conversion_failures += 1
            exception_counts[str(e)[:100]] += 1
            continue
        except abstract_sql.ParseError as e:
            print('Error parsing:\n%s\n%s' % (gold_sql_query, e))
            num_parse_failures += 1
            exception_counts[str(e)[:100]] += 1
            continue
        try:
            sql_spans = abstract_sql.restore_from_clause(
                sql_spans, foreign_keys)
        except abstract_sql.UnsupportedSqlError as e:
            print('Error reconstructing:\n%s\n%s' % (gold_sql_query, e))
            exception_counts[str(e)[:100]] += 1
            num_reconstruction_failures += 1
            continue
        print('Success:\n%s\n%s' %
              (gold_sql_query, abstract_sql.sql_spans_to_string(sql_spans)))
        num_successes += 1
    print('exception_counts: %s' % exception_counts)
    print('Examples: %s' % num_examples)
    print('Failed conversions: %s' % num_conversion_failures)
    print('Failed parses: %s' % num_parse_failures)
    print('Failed reconstructions: %s' % num_reconstruction_failures)
    print('Successes: %s' % num_successes)
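A hypothetical driver for the coverage check above, assuming the FLAGS this module reads are defined via absl.flags; the flag values in the comment are illustrative only.
# Hypothetical entry point; assumes compute_michigan_coverage and its FLAGS are
# defined in this module with absl.flags.
from absl import app


def main(unused_argv):
    # Illustrative flag values:
    #   --michigan_data_dir=/path/to/michigan --dataset_name=geoquery --splits=train
    compute_michigan_coverage()


if __name__ == '__main__':
    app.run(main)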
def process_michigan_datasets(output_file, debugging_file,
                              tokenizer) -> Tuple[int, int]:
    """Loads, converts, and writes Michigan examples to the standard format."""
    # TODO(alanesuhr,petershaw): Support asql for this dataset.
    # if FLAGS.generate_sql and FLAGS.abstract_sql:
    #     raise NotImplementedError("Abstract SQL currently only supported for SPIDER.")

    schema_csv: str = os.path.join(FLAGS.input_filepath,
                                   FLAGS.dataset_name + "_schema.csv")
    data_filepath: str = os.path.join(FLAGS.input_filepath,
                                      FLAGS.dataset_name + ".json")

    # Don't actually provide table entities.
    num_examples_created = 0
    num_examples_failed = 0

    print("Loading from " + data_filepath)
    paired_data = get_nl_sql_pairs(data_filepath, frozenset(FLAGS.splits))
    print("Loaded %d examples." % len(paired_data))

    schema = read_schema(schema_csv)

    for nl, sql in paired_data:
        example = convert_michigan(
            nl,
            sql,
            schema,
            tokenizer,
            FLAGS.generate_sql,
            FLAGS.anonymize_values,
            FLAGS.abstract_sql,
            abstract_sql_converters.michigan_db_to_table_tuples(schema),
            FLAGS.allow_value_generation,
        )
        if example is not None:
            output_file.write(json.dumps(example.to_json()) + "\n")
            num_examples_created += 1

            debugging_file.write(example.model_input.original_utterance + "\n")
            if FLAGS.generate_sql:
                debugging_file.write(example.gold_query_string() + "\n\n")
        else:
            num_examples_failed += 1

    return num_examples_created, num_examples_failed
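A usage sketch (hypothetical wrapper name and output paths) showing how process_michigan_datasets might be driven once FLAGS are parsed and a tokenizer compatible with convert_michigan is available.
def write_michigan_outputs(tokenizer, output_dir="/tmp"):
    # Hypothetical wrapper: opens the two files process_michigan_datasets writes
    # to and reports the counts it returns.
    output_path = os.path.join(output_dir, FLAGS.dataset_name + "_examples.jsonl")
    debug_path = os.path.join(output_dir, FLAGS.dataset_name + "_debug.txt")
    with open(output_path, "w") as output_file, \
            open(debug_path, "w") as debugging_file:
        created, failed = process_michigan_datasets(output_file, debugging_file,
                                                    tokenizer)
    print("Created %d examples; %d failed conversion." % (created, failed))
    return created, failed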
def process_wikisql(output_file, debugging_file, tokenizer) -> Tuple[int, int]:
    """Loads, converts, and writes WikiSQL examples to the standard format."""

    if len(FLAGS.splits) > 1:
        raise ValueError("Not expecting more than one split for WikiSQL.")
    split = FLAGS.splits[0]

    num_examples_created = 0
    num_examples_failed = 0

    data_filepath = os.path.join(FLAGS.input_filepath,
                                 FLAGS.dataset_name + ".json")

    paired_data = get_nl_sql_pairs(data_filepath,
                                   frozenset(FLAGS.splits),
                                   with_dbs=True)

    table_definitions = load_wikisql_tables(
        os.path.join(FLAGS.input_filepath, split + ".tables.jsonl"))

    wikisql_table_schemas_map = abstract_sql_converters.wikisql_table_schemas_map(
        table_definitions)

    for input_example in paired_data:
        example = convert_wikisql(
            input_example,
            table_definitions[input_example[2]],
            tokenizer,
            FLAGS.generate_sql,
            FLAGS.anonymize_values,
            FLAGS.abstract_sql,
            tables_schema=wikisql_table_schemas_map[input_example[2]],
            allow_value_generation=FLAGS.allow_value_generation,
        )
        if example:
            output_file.write(json.dumps(example.to_json()) + "\n")
            num_examples_created += 1

            debugging_file.write(example.model_input.original_utterance + "\n")
            if FLAGS.generate_sql:
                debugging_file.write(example.gold_query_string() + "\n\n")
        else:
            num_examples_failed += 1
    return num_examples_created, num_examples_failed
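For reference, a hypothetical helper (not from the source) making explicit the (utterance, sql, table_id) triple shape that the indexing input_example[2] above assumes when with_dbs=True.
def describe_wikisql_pairs(paired_data, table_definitions):
    # Illustrative only: each element is assumed to be an
    # (utterance, sql, table_id) triple, with table_id keying table_definitions.
    for utterance, sql, table_id in paired_data:
        assert table_id in table_definitions
        print("%s\n  %s\n  table: %s" % (utterance, sql, table_id))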
Example #5
def process_wikisql(output_file, debugging_file, tokenizer):
    """Loads, converts, and writes WikiSQL examples to the standard format."""
    # TODO(alanesuhr,petershaw): Support asql for this dataset.
    if FLAGS.generate_sql and FLAGS.abstract_sql:
        raise NotImplementedError(
            'Abstract SQL currently only supported for SPIDER.')

    if len(FLAGS.splits) > 1:
        raise ValueError('Not expecting more than one split for WikiSQL.')
    split = FLAGS.splits[0]

    num_examples_created = 0
    num_examples_failed = 0

    data_filepath = os.path.join(FLAGS.input_filepath,
                                 FLAGS.dataset_name + '.json')

    paired_data = get_nl_sql_pairs(data_filepath,
                                   set(FLAGS.splits),
                                   with_dbs=True)

    table_definitions = load_wikisql_tables(
        os.path.join(FLAGS.input_filepath, split + '.tables.jsonl'))

    for input_example in paired_data:
        example = convert_wikisql(input_example,
                                  table_definitions[input_example[2]],
                                  tokenizer,
                                  FLAGS.generate_sql,
                                  FLAGS.anonymize_values)
        if example:
            output_file.write(json.dumps(example.to_json()) + '\n')
            num_examples_created += 1

            debugging_file.write(example.model_input.original_utterance + '\n')
            if FLAGS.generate_sql:
                debugging_file.write(example.gold_query_string() + '\n\n')
        else:
            num_examples_failed += 1
    return num_examples_created, num_examples_failed
Example #6
def match_and_save(predictions_path, output_path, dataset_name, splits,
                   data_filepath, schema_obj, sharded):
    """Loads an original dataset and matches with a predictions file."""
    # Load the predictions file
    prediction_dict = dict()
    if sharded:
        for data_file in tf.gfile.Glob(predictions_path + '*'):
            with tf.gfile.Open(data_file) as infile:
                for line in infile:
                    if line:
                        obj = json.loads(line)
                        prediction_dict[obj['utterance']] = obj

    else:
        with tf.gfile.Open(predictions_path) as infile:
            for line in infile:
                if line:
                    obj = json.loads(line)
                    prediction_dict[obj['utterance']] = obj

    # Load the data for this particular dataset (for lookup).
    # `matched_examples` is a list of dictionaries, one per example, containing
    # the utterance, the predictions and scores, the gold query, the database
    # paths, and the schema.
    matched_examples = list()
    if dataset_name.lower() == 'spider':
        assert len(splits) == 1
        split = splits[0]

        for example in load_spider_examples(
                os.path.join(data_filepath, split + '.json')):
            # Looks up the example's schema.
            schema = schema_obj[example['db_id']]

            # Returns a dictionary containing relevant prediction information.
            database_filepath = os.path.join('spider_databases',
                                             example['db_id'] + '.sqlite')

            key = ' '.join(example['question_toks'])
            prediction = prediction_dict[key]
            matched_examples.append({
                'utterance':
                key,
                'predictions':
                prediction['predictions'],
                'scores':
                prediction['scores'],
                'gold':
                example['query'],
                'database_path':
                os.path.join(FLAGS.database_filepath, database_filepath),
                'empty_database_path':
                os.path.join(FLAGS.empty_database_filepath, database_filepath),
                'schema':
                schema
            })

    elif dataset_name.lower() == 'wikisql':
        raise ValueError('Inference on WikiSQL not supported.')
    else:
        for nl, sql in get_nl_sql_pairs(
                os.path.join(data_filepath, dataset_name + '.json'),
                set(splits)):
            key = nl.encode('utf8')

            # Returns a dictionary containing relevant prediction information.
            database_filepath = dataset_name + '.db'

            prediction = prediction_dict[key]
            matched_examples.append({
                'utterance':
                key,
                'predictions':
                prediction['predictions'],
                'scores':
                prediction['scores'],
                'gold':
                preprocess_sql(sql),
                'database_path':
                os.path.join(FLAGS.database_filepath, database_filepath),
                'empty_database_path':
                os.path.join(FLAGS.empty_database_filepath, database_filepath),
                'schema':
                schema_obj
            })

    with tf.gfile.Open(output_path, 'w') as ofile:
        ofile.write(json.dumps(matched_examples))
Example #7
def match_with_dataset(
        config: Config, predictions: Sequence[Prediction],
        schema_obj: Dict[Any, Any]) -> List[ExecutionInstructions]:
    """
    Loads an original dataset and matches with a predictions file.
    """

    prediction_dict: Dict[str, Any] = {}
    for prediction in predictions:
        prediction_dict[prediction["utterance"]] = prediction

    # Load the data for this particular dataset (for lookup).
    # `matched_examples` is a list of dictionaries, one per example, containing
    # the prediction (utterance, predictions, scores), the gold query, the
    # database paths, and the schema.
    matched_examples: List[ExecutionInstructions] = []
    if config.dataset_name.lower() == "spider":
        assert len(config.splits) == 1
        split = config.splits[0]

        for example in load_spider_examples(
                os.path.join(config.original_data_directory, split + ".json")):
            # Looks up the example's schema.
            schema = schema_obj[example["db_id"]]

            # Returns a dictionary containing relevant prediction information.
            database_filepath = os.path.join("spider_databases",
                                             example["db_id"] + ".sqlite")
            key = " ".join(example["question_toks"])

            try:
                prediction = prediction_dict[key]
            except KeyError:
                continue

            matched_examples.append({
                "prediction": {
                    "utterance": key,
                    "predictions": prediction["predictions"],
                    "scores": prediction["scores"],
                },
                "gold":
                example["query"],
                "database_path":
                os.path.join(config.database_directory, database_filepath),
                "empty_database_path":
                os.path.join(config.empty_database_directory,
                             database_filepath),
                "schema":
                schema,
            })

    elif config.dataset_name.lower() == "wikisql":
        raise ValueError("Inference on WikiSQL not supported.")
    else:  # michigan datasets
        dataset_path: str = os.path.join(config.original_data_directory,
                                         config.dataset_name + ".json")
        for nl, sql in get_nl_sql_pairs(dataset_path,
                                        frozenset(config.splits)):
            # TODO(samuelstevens): What is the point of encoding then decoding? Simplify.
            key = nl.encode("utf8").decode("utf-8")

            # Returns a dictionary containing relevant prediction information.
            database_filepath = config.dataset_name + ".db"

            assert len(prediction_dict) > 0
            prediction = prediction_dict[key]

            matched_examples.append({
                "prediction": {
                    "utterance": key,
                    "predictions": prediction["predictions"],
                    "scores": prediction["scores"],
                },
                "gold":
                preprocess_sql(sql),
                "database_path":
                os.path.join(config.database_directory, database_filepath),
                "empty_database_path":
                os.path.join(config.empty_database_directory,
                             database_filepath),
                "schema":
                schema_obj,
            })

    assert len(matched_examples) == len(
        predictions
    ), f"Only matched {len(matched_examples)} of {len(predictions)} examples."

    return matched_examples
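A hedged call-site sketch: the JSON-lines prediction format (objects with "utterance", "predictions", and "scores" keys) is inferred from how prediction_dict is used above, and the loader name and path are hypothetical.
def load_predictions(path: str) -> List[Dict[str, Any]]:
    # Reads one JSON object per line; each object is expected to carry the
    # "utterance", "predictions", and "scores" keys consumed above.
    predictions: List[Dict[str, Any]] = []
    with open(path) as infile:
        for line in infile:
            if line.strip():
                predictions.append(json.loads(line))
    return predictions


# Hypothetical usage, given a Config and schema_obj built elsewhere:
#   instructions = match_with_dataset(
#       config, load_predictions("predictions.jsonl"), schema_obj)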