Example no. 1
0
def process_spider(output_file, debugging_file, tokenizer):
  """Loads, converts, and writes Spider examples to the standard format.

  Args:
    output_file: Open file handle; converted examples are written as JSON
      lines.
    debugging_file: Open file handle; human-readable utterances (and gold
      queries when FLAGS.generate_sql is set) are written here.
    tokenizer: Tokenizer passed through to convert_spider.

  Returns:
    A (num_examples_created, num_examples_failed) tuple.

  Raises:
    ValueError: If FLAGS.splits contains more than one split.
  """
  if len(FLAGS.splits) > 1:
    raise ValueError('Not expecting more than one split for Spider.')
  split = FLAGS.splits[0]

  tables_path = os.path.join(FLAGS.input_filepath, 'tables.json')
  table_definitions = load_spider_tables(tables_path)
  print('Loaded %d table definitions.' % len(table_definitions))

  examples_path = os.path.join(FLAGS.input_filepath, split + '.json')
  spider_examples = load_spider_examples(examples_path)

  # TODO(petershaw): Reduce duplication with other code path for schema
  # pre-processing.
  tables_json = _load_json_from_file(tables_path)
  spider_table_schemas_map = abstract_sql_converters.spider_table_schemas_map(
      tables_json)

  created = 0
  failed = 0
  for spider_example in spider_examples:
    # NOTE(review): an unknown db_id raises KeyError here instead of being
    # counted as a failure — confirm that is the intended behavior.
    db_id = spider_example['db_id']
    converted = None
    try:
      converted = convert_spider(
          spider_example,
          table_definitions[db_id],
          tokenizer,
          FLAGS.generate_sql,
          FLAGS.anonymize_values,
          abstract_sql=FLAGS.abstract_sql,
          table_schemas=spider_table_schemas_map[db_id],
          allow_value_generation=FLAGS.allow_value_generation)
    except abstract_sql.UnsupportedSqlError as e:
      # Unsupported SQL is logged and counted as a failure below.
      print(e)
    if not converted:
      failed += 1
      continue
    output_file.write(json.dumps(converted.to_json()) + '\n')
    created += 1
    debugging_file.write(converted.model_input.original_utterance + '\n')
    if FLAGS.generate_sql:
      debugging_file.write(converted.gold_query_string() + '\n\n')
  return created, failed
Example no. 2
0
def main(unused_argv):
    """Converts Spider gold SQL queries to abstract SQL and writes both out.

    Reads the tables and examples specified by FLAGS, then writes
    tab-separated "(query)\t(db_id)" lines to FLAGS.gold_sql_output and
    FLAGS.abstract_sql_output. When a query cannot be parsed, it is skipped
    if FLAGS.keep_going is set; otherwise the error is re-raised.
    """
    table_json = _load_json(FLAGS.tables)
    # Map of database id to a list of ForeignKeyRelation tuples.
    foreign_key_map = abstract_sql_converters.spider_foreign_keys_map(table_json)
    table_schema_map = abstract_sql_converters.spider_table_schemas_map(table_json)

    examples = _load_json(FLAGS.input)
    num_failures = 0
    num_examples = 0
    # One `with` statement manages both output files.
    with open(FLAGS.gold_sql_output, "w") as gold_sql_file, \
            open(FLAGS.abstract_sql_output, "w") as abstract_sql_file:
        for example in examples:
            num_examples += 1
            print("Parsing example number %s." % num_examples)
            gold_sql_query = example["query"]
            foreign_keys = foreign_key_map[example["db_id"]]
            table_schema = table_schema_map[example["db_id"]]
            try:
                abstract_sql_query = _get_abstract_sql(
                    gold_sql_query,
                    foreign_keys,
                    table_schema,
                    FLAGS.restore_from_clause,
                )
            except abstract_sql.UnsupportedSqlError:
                print("Error for query:\n%s" % gold_sql_query)
                num_failures += 1
                if not FLAGS.keep_going:
                    # Bare `raise` preserves the original traceback;
                    # `raise e` would truncate it at this frame.
                    raise
                continue
            # Write SQL to output files. Tab is the field separator, so
            # strip any tabs from the query itself first.
            gold_sql_query = gold_sql_query.replace("\t", " ")
            gold_sql_file.write("%s\t%s\n" % (gold_sql_query, example["db_id"]))
            abstract_sql_file.write(
                "%s\t%s\n" % (abstract_sql_query, example["db_id"])
            )
    print("Examples: %s" % num_examples)
    print("Failed parses: %s" % num_failures)
def compute_spider_coverage(spider_examples_json, spider_tables_json):
    """Prints out statistics for asql conversions.

    Args:
        spider_examples_json: Path to the Spider examples JSON file.
        spider_tables_json: Path to the Spider tables JSON file.
    """
    table_json = _load_json(spider_tables_json)
    # Map of database id to a list of ForeignKeyRelation tuples.
    foreign_key_map = abstract_sql_converters.spider_foreign_keys_map(
        table_json)
    table_schema_map = abstract_sql_converters.spider_table_schemas_map(
        table_json)
    examples = _load_json(spider_examples_json)
    num_examples = 0
    num_conversion_failures = 0
    # Renamed from the misspelled `num_reconstruction_failtures`.
    num_reconstruction_failures = 0
    for example in examples:
        num_examples += 1
        print("Parsing example number %s: %s" %
              (num_examples, example["query"]))
        gold_sql_query = example["query"]
        foreign_keys = foreign_key_map[example["db_id"]]
        table_schema = table_schema_map[example["db_id"]]
        try:
            # Convert gold SQL to spans and abstract away the FROM clause.
            sql_spans = abstract_sql.sql_to_sql_spans(gold_sql_query,
                                                      table_schema)
            sql_spans = abstract_sql.replace_from_clause(sql_spans)
        except abstract_sql.UnsupportedSqlError as e:
            print("Error converting:\n%s\n%s" % (gold_sql_query, e))
            num_conversion_failures += 1
        else:
            try:
                # Attempt to restore the FROM clause from foreign keys.
                sql_spans = abstract_sql.restore_from_clause(
                    sql_spans, foreign_keys)
            except abstract_sql.UnsupportedSqlError as e:
                # Fixed typo in message: "recontructing" -> "reconstructing".
                print("Error reconstructing:\n%s\n%s" % (gold_sql_query, e))
                num_reconstruction_failures += 1
    print("Examples: %s" % num_examples)
    print("Failed conversions: %s" % num_conversion_failures)
    print("Failed reconstructions: %s" % num_reconstruction_failures)