def test_nested_query(self):
    """FROM clauses inside a derived-table subquery are abstracted and restored."""
    sql = ("SELECT count(*) FROM (SELECT * FROM endowment WHERE amount > "
           "8.5 GROUP BY school_id HAVING count(*) > 1)")

    # Both the outer and the nested FROM clause become placeholders.
    replaced = abstract_sql.replace_from_clause(
        abstract_sql.sql_to_sql_spans(sql))
    self.assertEqual(
        "select count ( * ) <from_clause_placeholder> ( select * "
        "<from_clause_placeholder> where endowment.amount > 8.5 group by "
        "endowment.school_id having count ( * ) > 1 )",
        abstract_sql.sql_spans_to_string(replaced))

    # Restoring with no foreign-key relations should recover the FROM clauses.
    restored = abstract_sql.restore_from_clause(replaced, fk_relations=[])
    self.assertEqual(
        "select count ( * ) from ( select * from endowment where "
        "endowment.amount > 8.5 group by endowment.school_id "
        "having count ( * ) > 1 )",
        abstract_sql.sql_spans_to_string(restored))
def compute_michigan_coverage():
    """Prints out statistics for asql conversions.

    Loads the schema CSV and examples JSON for FLAGS.dataset_name from
    FLAGS.michigan_data_dir. For every gold SQL query it does two things:
    (1) parses the query and replaces its FROM clause with a placeholder;
    (2) restores the FROM clause using the dataset's foreign keys.

    Each example's outcome is printed, followed by summary counts and a
    histogram of exception messages.
    """
    data_dir = FLAGS.michigan_data_dir
    dataset = FLAGS.dataset_name
    schema_csv_path = os.path.join(data_dir, '%s-schema.csv' % dataset)
    examples_json_path = os.path.join(data_dir, '%s.json' % dataset)

    schema = michigan_preprocessing.read_schema(schema_csv_path)
    # Foreign keys are either looked up by dataset name (the "oracle" variant)
    # or derived from the parsed schema.
    if not FLAGS.use_oracle_foriegn_keys:
        foreign_keys = abstract_sql_converters.michigan_db_to_foreign_key_tuples(
            schema)
    else:
        foreign_keys = (
            abstract_sql_converters.michigan_db_to_foreign_key_tuples_orcale(
                dataset))
    table_schema = abstract_sql_converters.michigan_db_to_table_tuples(schema)
    nl_sql_pairs = michigan_preprocessing.get_nl_sql_pairs(
        examples_json_path, FLAGS.splits)

    examples_seen = 0
    conversion_failures = 0
    parse_failures = 0
    reconstruction_failures = 0
    successes = 0
    # Histogram keyed by the first 100 characters of each exception message.
    exception_counts = collections.defaultdict(int)

    for _, gold_sql_query in nl_sql_pairs:
        examples_seen += 1
        print('Parsing example number %s.' % examples_seen)

        # Stage 1: parse the query and abstract away its FROM clause.
        try:
            sql_spans = abstract_sql.replace_from_clause(
                abstract_sql.sql_to_sql_spans(gold_sql_query, table_schema))
        except abstract_sql.UnsupportedSqlError as err:
            print('Error converting:\n%s\n%s' % (gold_sql_query, err))
            conversion_failures += 1
            exception_counts[str(err)[:100]] += 1
            continue
        except abstract_sql.ParseError as err:
            print('Error parsing:\n%s\n%s' % (gold_sql_query, err))
            parse_failures += 1
            exception_counts[str(err)[:100]] += 1
            continue

        # Stage 2: rebuild the FROM clause from the foreign-key relations.
        try:
            sql_spans = abstract_sql.restore_from_clause(
                sql_spans, foreign_keys)
        except abstract_sql.UnsupportedSqlError as err:
            print('Error recontructing:\n%s\n%s' % (gold_sql_query, err))
            exception_counts[str(err)[:100]] += 1
            reconstruction_failures += 1
            continue

        restored_sql = abstract_sql.sql_spans_to_string(sql_spans)
        print('Success:\n%s\n%s' % (gold_sql_query, restored_sql))
        successes += 1

    print('exception_counts: %s' % exception_counts)
    print('Examples: %s' % examples_seen)
    print('Failed conversions: %s' % conversion_failures)
    print('Failed parses: %s' % parse_failures)
    print('Failed reconstructions: %s' % reconstruction_failures)
    print('Successes: %s' % successes)
def _get_abstract_sql(gold_sql, foreign_keys, table_schema,
                      restore_from_clause):
    """Returns string using abstract SQL transformations.

    The query is always parsed into spans against `table_schema`. When
    `restore_from_clause` is truthy, the FROM clause is first replaced with a
    placeholder, the intermediate query is printed, and the FROM clause is
    then rebuilt from `foreign_keys`. Otherwise the parsed spans are rendered
    back to a string directly.
    """
    print("Processing query:\n%s" % gold_sql)
    spans = abstract_sql.sql_to_sql_spans(gold_sql, table_schema)
    if not restore_from_clause:
        return abstract_sql.sql_spans_to_string(spans)

    spans = abstract_sql.replace_from_clause(spans)
    replaced_text = abstract_sql.sql_spans_to_string(spans)
    print("Replaced clause query:\n%s" % replaced_text)
    spans = abstract_sql.restore_from_clause(spans, foreign_keys)
    return abstract_sql.sql_spans_to_string(spans)
def test_union_clause(self):
    """Each side of a UNION gets its own FROM clause restored."""
    sql = ("SELECT student_id FROM student_course_registrations UNION SELECT "
           "student_id FROM student_course_attendance")
    replaced = abstract_sql.replace_from_clause(
        abstract_sql.sql_to_sql_spans(sql))
    restored = abstract_sql.restore_from_clause(replaced, fk_relations=[])
    want = ("select student_course_registrations.student_id from "
            "student_course_registrations union select "
            "student_course_attendance.student_id from "
            "student_course_attendance")
    self.assertEqual(want, abstract_sql.sql_spans_to_string(restored))
def test_nested_sql_with_unqualified_column(self):
    """Unqualified columns in a NOT IN subquery are qualified on round trip."""
    sql = ("SELECT count(*) FROM enzyme WHERE id NOT IN ( SELECT enzyme_id "
           "FROM medicine_enzyme_interaction )")
    replaced = abstract_sql.replace_from_clause(
        abstract_sql.sql_to_sql_spans(sql))
    restored = abstract_sql.restore_from_clause(replaced, fk_relations=[])
    want = ("select count ( * ) from enzyme where enzyme.id not in ( select "
            "medicine_enzyme_interaction.enzyme_id from "
            "medicine_enzyme_interaction )")
    self.assertEqual(want, abstract_sql.sql_spans_to_string(restored))
def test_restore_from_string(self):
    """Restoration works on spans re-parsed from the placeholder string."""
    # Round-trip the abstracted query through its string form before restoring.
    placeholder_sql = abstract_sql.sql_spans_to_string(
        abstract_sql.replace_from_clause(
            abstract_sql.sql_to_sql_spans(TEST_QUERY)))
    reparsed = abstract_sql.sql_to_sql_spans(placeholder_sql)

    relations = [
        abstract_sql.ForeignKeyRelation("user_profiles", "follows", "uid",
                                        "f1")
    ]
    restored = abstract_sql.restore_from_clause(
        reparsed, fk_relations=relations)

    want = ("select user_profiles.name from follows join user_profiles on "
            "follows.f1 = user_profiles.uid group by follows.f1 "
            "having count ( * ) > "
            "( select count ( * ) from follows join user_profiles on "
            "follows.f1 = user_profiles.uid "
            "where user_profiles.name = 'tyler swift' )")
    self.assertEqual(want, abstract_sql.sql_spans_to_string(restored))
def test_restore_from_string_no_tables(self):
    """A single-table query round-trips through the placeholder string."""
    replaced = abstract_sql.replace_from_clause(
        abstract_sql.sql_to_sql_spans(
            "SELECT foo, count(*) FROM bar WHERE id = 5"))
    placeholder_sql = abstract_sql.sql_spans_to_string(replaced)
    self.assertEqual(
        "select bar.foo , count ( * ) <from_clause_placeholder> "
        "where bar.id = 5",
        placeholder_sql)

    # No foreign keys are supplied; the FROM clause is still recoverable.
    reparsed = abstract_sql.sql_to_sql_spans(placeholder_sql)
    restored = abstract_sql.restore_from_clause(reparsed, fk_relations=[])
    self.assertEqual(
        "select bar.foo , count ( * ) from bar where bar.id = 5",
        abstract_sql.sql_spans_to_string(restored))
def restore_predicted_sql(sql_string, table_schemas, foreign_keys):
    """Restore FROM clause in predicted SQL.

    TODO(petershaw): Add call to this function from run_inference.py.

    The query is parsed with `lowercase=False`, presumably so that the
    predicted text keeps its original casing. The FROM clause is then rebuilt
    from the supplied foreign-key relations.

    Args:
      sql_string: SQL query as string.
      table_schemas: List of TableSchema tuples.
      foreign_keys: List of ForeignKeyRelation tuples.

    Returns:
      SQL query with restored FROM clause as a string.
    """
    parsed = abstract_sql.sql_to_sql_spans(
        sql_string, table_schemas, lowercase=False)
    restored = abstract_sql.restore_from_clause(parsed, foreign_keys)
    return abstract_sql.sql_spans_to_string(restored)
def compute_spider_coverage(spider_examples_json, spider_tables_json):
    """Prints out statistics for asql conversions.

    For each Spider example, this function does two things:
    (1) parses the gold query and replaces its FROM clause with a placeholder;
    (2) restores the FROM clause from the example database's foreign keys.

    Errors are printed per example and counted. The run continues past
    unsupported or unparsable queries instead of aborting.

    Args:
      spider_examples_json: Examples file passed to `_load_json`; presumably
        a path (confirm against callers).
      spider_tables_json: Tables file passed to `_load_json`; presumably a
        path (confirm against callers).
    """
    table_json = _load_json(spider_tables_json)
    # Map of database id to a list of ForeignKeyRelation tuples.
    foreign_key_map = abstract_sql_converters.spider_foreign_keys_map(
        table_json)
    table_schema_map = abstract_sql_converters.spider_table_schemas_map(
        table_json)
    examples = _load_json(spider_examples_json)

    num_examples = 0
    num_conversion_failures = 0
    num_parse_failures = 0
    num_reconstruction_failures = 0
    for example in examples:
        num_examples += 1
        gold_sql_query = example["query"]
        print("Parsing example number %s: %s" % (num_examples, gold_sql_query))
        foreign_keys = foreign_key_map[example["db_id"]]
        table_schema = table_schema_map[example["db_id"]]

        try:
            sql_spans = abstract_sql.sql_to_sql_spans(gold_sql_query,
                                                      table_schema)
            sql_spans = abstract_sql.replace_from_clause(sql_spans)
        except abstract_sql.UnsupportedSqlError as e:
            print("Error converting:\n%s\n%s" % (gold_sql_query, e))
            num_conversion_failures += 1
            continue
        except abstract_sql.ParseError as e:
            # Previously uncaught: a single unparsable gold query aborted the
            # whole coverage run. Count it instead, as the Michigan variant
            # does.
            print("Error parsing:\n%s\n%s" % (gold_sql_query, e))
            num_parse_failures += 1
            continue

        try:
            # Only success/failure matters here; the restored spans are unused.
            abstract_sql.restore_from_clause(sql_spans, foreign_keys)
        except abstract_sql.UnsupportedSqlError as e:
            print("Error recontructing:\n%s\n%s" % (gold_sql_query, e))
            num_reconstruction_failures += 1

    print("Examples: %s" % num_examples)
    print("Failed conversions: %s" % num_conversion_failures)
    print("Failed parses: %s" % num_parse_failures)
    print("Failed reconstructions: %s" % num_reconstruction_failures)