コード例 #1
0
    def test_nested_query(self):
        """Round-trips a query whose FROM clause is itself a subquery.

        Both the outer and the inner FROM clause are expected to become
        placeholders, and restoring with an empty foreign-key list is
        expected to yield the normalized query with both FROM clauses back.
        """
        nested_query = (
            "SELECT count(*) FROM (SELECT * FROM endowment WHERE amount  >  "
            "8.5 GROUP BY school_id HAVING count(*)  >  1)"
        )
        spans = abstract_sql.sql_to_sql_spans(nested_query)

        # Stage 1: every FROM clause is swapped for a placeholder token.
        with_placeholders = abstract_sql.replace_from_clause(spans)
        self.assertEqual(
            "select count ( * ) <from_clause_placeholder> ( select * "
            "<from_clause_placeholder> where endowment.amount > 8.5 group by "
            "endowment.school_id having count ( * ) > 1 )",
            abstract_sql.sql_spans_to_string(with_placeholders),
        )

        # Stage 2: placeholders are expanded back into FROM clauses.
        restored = abstract_sql.restore_from_clause(
            with_placeholders, fk_relations=[]
        )
        self.assertEqual(
            "select count ( * ) from ( select * from endowment where "
            "endowment.amount > 8.5 group by endowment.school_id having "
            "count ( * ) > 1 )",
            abstract_sql.sql_spans_to_string(restored),
        )
コード例 #2
0
def _get_abstract_sql(gold_sql, foreign_keys, table_schema, restore_from_clause):
    """Returns `gold_sql` re-rendered via abstract SQL spans.

    Args:
      gold_sql: SQL query string to process.
      foreign_keys: Foreign-key relations forwarded to
        `abstract_sql.restore_from_clause`.
      table_schema: Table schema forwarded to `abstract_sql.sql_to_sql_spans`.
      restore_from_clause: If truthy, the FROM clause is first replaced with a
        placeholder (the intermediate form is printed) and then restored from
        `foreign_keys` before rendering.

    Returns:
      The string produced by `abstract_sql.sql_spans_to_string`.
    """
    print("Processing query:\n%s" % gold_sql)
    spans = abstract_sql.sql_to_sql_spans(gold_sql, table_schema)
    if not restore_from_clause:
        return abstract_sql.sql_spans_to_string(spans)

    spans = abstract_sql.replace_from_clause(spans)
    placeholder_sql = abstract_sql.sql_spans_to_string(spans)
    print("Replaced clause query:\n%s" % placeholder_sql)
    spans = abstract_sql.restore_from_clause(spans, foreign_keys)
    return abstract_sql.sql_spans_to_string(spans)
コード例 #3
0
 def test_restore_from_string(self):
     """Restores FROM clauses after re-parsing the placeholder string.

     The placeholder query is serialized to text and parsed again before
     restoration, so restoration is exercised on freshly parsed spans rather
     than on the spans returned by `replace_from_clause`.
     """
     placeholder_sql = abstract_sql.sql_spans_to_string(
         abstract_sql.replace_from_clause(
             abstract_sql.sql_to_sql_spans(TEST_QUERY)))
     reparsed = abstract_sql.sql_to_sql_spans(placeholder_sql)
     relations = [
         abstract_sql.ForeignKeyRelation("user_profiles", "follows", "uid", "f1")
     ]
     restored_sql = abstract_sql.sql_spans_to_string(
         abstract_sql.restore_from_clause(reparsed, fk_relations=relations))
     self.assertEqual(
         "select user_profiles.name from follows join user_profiles on "
         "follows.f1 = user_profiles.uid group by follows.f1 having count ( * ) "
         "> ( select count ( * ) from follows join user_profiles on "
         "follows.f1 = user_profiles.uid where user_profiles.name = "
         "'tyler swift' )",
         restored_sql)
コード例 #4
0
 def test_inner_join(self):
     """Replacing the FROM clause of an aliased INNER JOIN resolves `f` to `foo`."""
     spans = abstract_sql.sql_to_sql_spans(
         "SELECT f.name FROM foo AS f INNER JOIN bar AS b")
     placeholder_sql = abstract_sql.sql_spans_to_string(
         abstract_sql.replace_from_clause(spans), sep=" ")
     self.assertEqual("select foo.name <from_clause_placeholder> bar",
                      placeholder_sql)
コード例 #5
0
def compute_michigan_coverage():
    """Prints out statistics for asql conversions.

    For every gold SQL query in the configured Michigan dataset splits, the
    query is parsed into spans, its FROM clause is replaced with a
    placeholder, and the FROM clause is then restored from foreign keys.
    Each example is counted as a conversion failure, a parse failure, a
    reconstruction failure, or a success. Failure messages, truncated to 100
    characters, are tallied and printed most-common first so the dominant
    causes are easy to spot.

    Configuration comes from FLAGS (`michigan_data_dir`, `dataset_name`,
    `splits`, `use_oracle_foriegn_keys`). Results are printed, not returned.
    """
    # Read data files.
    schema_csv_path = os.path.join(FLAGS.michigan_data_dir,
                                   '%s-schema.csv' % FLAGS.dataset_name)
    examples_json_path = os.path.join(FLAGS.michigan_data_dir,
                                      '%s.json' % FLAGS.dataset_name)
    schema = michigan_preprocessing.read_schema(schema_csv_path)
    # NOTE(review): 'foriegn' / 'orcale' are misspelled, but they presumably
    # match the names as defined elsewhere; rename at the definitions first.
    if FLAGS.use_oracle_foriegn_keys:
        foreign_keys = abstract_sql_converters.michigan_db_to_foreign_key_tuples_orcale(
            FLAGS.dataset_name)
    else:
        foreign_keys = abstract_sql_converters.michigan_db_to_foreign_key_tuples(
            schema)
    table_schema = abstract_sql_converters.michigan_db_to_table_tuples(schema)
    nl_sql_pairs = michigan_preprocessing.get_nl_sql_pairs(
        examples_json_path, FLAGS.splits)

    # Iterate through examples and generate counts.
    num_examples = 0  # Remains 0 when there are no examples at all.
    num_conversion_failures = 0
    num_parse_failures = 0
    num_reconstruction_failures = 0
    num_successes = 0
    # Keyed by the first 100 chars of the message, so near-identical errors
    # (e.g. differing only in a long trailing detail) share one bucket.
    exception_counts = collections.Counter()
    for num_examples, (_, gold_sql_query) in enumerate(nl_sql_pairs, start=1):
        print('Parsing example number %s.' % num_examples)
        try:
            sql_spans = abstract_sql.sql_to_sql_spans(gold_sql_query,
                                                      table_schema)
            sql_spans = abstract_sql.replace_from_clause(sql_spans)
        except abstract_sql.UnsupportedSqlError as e:
            print('Error converting:\n%s\n%s' % (gold_sql_query, e))
            num_conversion_failures += 1
            exception_counts[str(e)[:100]] += 1
            continue
        except abstract_sql.ParseError as e:
            print('Error parsing:\n%s\n%s' % (gold_sql_query, e))
            num_parse_failures += 1
            exception_counts[str(e)[:100]] += 1
            continue
        try:
            sql_spans = abstract_sql.restore_from_clause(
                sql_spans, foreign_keys)
        except abstract_sql.UnsupportedSqlError as e:
            print('Error reconstructing:\n%s\n%s' % (gold_sql_query, e))
            num_reconstruction_failures += 1
            exception_counts[str(e)[:100]] += 1
            continue
        print('Success:\n%s\n%s' %
              (gold_sql_query, abstract_sql.sql_spans_to_string(sql_spans)))
        num_successes += 1
    # most_common() orders buckets by frequency, largest first.
    print('exception_counts: %s' % dict(exception_counts.most_common()))
    print('Examples: %s' % num_examples)
    print('Failed conversions: %s' % num_conversion_failures)
    print('Failed parses: %s' % num_parse_failures)
    print('Failed reconstructions: %s' % num_reconstruction_failures)
    print('Successes: %s' % num_successes)
コード例 #6
0
 def test_parse_asql_tables(self):
     """Tokenizes a query that already contains a FROM-clause placeholder."""
     spans = abstract_sql.sql_to_sql_spans(
         "SELECT table1.foo <from_clause_placeholder> table1 table2")
     # Joining with commas makes every span boundary visible, so this
     # checks how the query was split.
     self.assertEqual(
         "select,table1.foo,<from_clause_placeholder>,table1,table2",
         abstract_sql.sql_spans_to_string(spans, sep=","))
コード例 #7
0
 def test_restore_from_string_no_tables(self):
     """Round-trips a single-table query through its placeholder string.

     With an empty foreign-key list, the restored query is expected to name
     `bar` in its FROM clause again.
     """
     spans = abstract_sql.sql_to_sql_spans(
         "SELECT foo, count(*) FROM bar WHERE id = 5")
     placeholder_sql = abstract_sql.sql_spans_to_string(
         abstract_sql.replace_from_clause(spans))
     self.assertEqual(
         "select bar.foo , count ( * ) <from_clause_placeholder> where bar.id = 5",
         placeholder_sql)

     # Restore from freshly parsed text, not from the replaced spans.
     reparsed = abstract_sql.sql_to_sql_spans(placeholder_sql)
     restored_sql = abstract_sql.sql_spans_to_string(
         abstract_sql.restore_from_clause(reparsed, fk_relations=[]))
     self.assertEqual(
         "select bar.foo , count ( * ) from bar where bar.id = 5",
         restored_sql)
コード例 #8
0
 def test_order_by(self):
     """ORDER BY columns come out table-qualified after FROM replacement."""
     spans = abstract_sql.replace_from_clause(
         abstract_sql.sql_to_sql_spans(
             "SELECT title FROM course ORDER BY title ,  credits"))
     self.assertEqual(
         "select course.title <from_clause_placeholder> order by "
         "course.title , course.credits",
         abstract_sql.sql_spans_to_string(spans, sep=" "))
コード例 #9
0
 def test_sql_to_sql_spans(self):
   """Parses TEST_QUERY into spans and checks its normalized rendering."""
   rendered = abstract_sql.sql_spans_to_string(
       abstract_sql.sql_to_sql_spans(TEST_QUERY))
   self.assertEqual(
       "select user_profiles.name from user_profiles join follows "
       "on user_profiles.uid = follows.f1 group by follows.f1 "
       "having count ( * ) > ( select count ( * ) from "
       "user_profiles join follows on user_profiles.uid = "
       "follows.f1 where user_profiles.name = 'tyler swift' )",
       rendered)
コード例 #10
0
 def test_parse_order_by(self):
   """Checks span boundaries for a query ending in ORDER BY ... ASC."""
   spans = abstract_sql.sql_to_sql_spans(
       "SELECT Total_Horses FROM farm ORDER BY Total_Horses ASC")
   # Comma separation exposes each span; "order by" is expected to be a
   # single span.
   self.assertEqual(
       "select,farm.total_horses,from,farm,order by,farm.total_horses,asc",
       abstract_sql.sql_spans_to_string(spans, sep=","))
コード例 #11
0
 def test_remove_from_clause(self):
   """Replaces the FROM clauses of TEST_QUERY, including its subquery's."""
   placeholder_sql = abstract_sql.sql_spans_to_string(
       abstract_sql.replace_from_clause(
           abstract_sql.sql_to_sql_spans(TEST_QUERY)))
   self.assertEqual(
       "select user_profiles.name <from_clause_placeholder> group "
       "by follows.f1 having count ( * ) > ( select count ( * ) "
       "<from_clause_placeholder> follows where user_profiles.name"
       " = 'tyler swift' )",
       placeholder_sql)
コード例 #12
0
 def test_find_table(self):
     """Builds a FROM clause joining all six tables along foreign keys."""
     tables = [
         "author",
         "domain",
         "domain_author",
         "organization",
         "publication",
         "writes",
     ]
     # Each row: (child_table, parent_table, child_column, parent_column).
     relation_specs = [
         ("publication", "writes", "pid", "pid"),
         ("author", "writes", "aid", "aid"),
         ("author", "organization", "oid", "oid"),
         ("author", "domain_author", "aid", "aid"),
         ("domain", "domain_author", "did", "did"),
     ]
     relations = [
         abstract_sql.ForeignKeyRelation(
             child_table=child_table,
             parent_table=parent_table,
             child_column=child_column,
             parent_column=parent_column,
         )
         for child_table, parent_table, child_column, parent_column
         in relation_specs
     ]
     from_clause = abstract_sql.sql_spans_to_string(
         abstract_sql._get_from_clause_for_tables(tables, relations), sep=" ")
     self.assertEqual(
         "author "
         "join domain_author on author.aid = domain_author.aid "
         "join domain on domain_author.did = domain.did "
         "join organization on author.oid = organization.oid "
         "join writes on author.aid = writes.aid "
         "join publication on writes.pid = publication.pid",
         from_clause,
     )
コード例 #13
0
 def test_union_clause(self):
   """Each side of a UNION gets its own FROM clause back after a round trip."""
   spans = abstract_sql.sql_to_sql_spans(
       "SELECT student_id FROM student_course_registrations UNION SELECT "
       "student_id FROM student_course_attendance")
   round_tripped = abstract_sql.restore_from_clause(
       abstract_sql.replace_from_clause(spans), fk_relations=[])
   self.assertEqual(
       "select student_course_registrations.student_id from "
       "student_course_registrations union select "
       "student_course_attendance.student_id from student_course_attendance",
       abstract_sql.sql_spans_to_string(round_tripped))
コード例 #14
0
 def test_nested_sql_with_unqualified_column(self):
   """An unqualified column inside a NOT IN subquery is qualified by its table."""
   spans = abstract_sql.sql_to_sql_spans(
       "SELECT count(*) FROM enzyme WHERE id NOT IN ( SELECT enzyme_id "
       "FROM medicine_enzyme_interaction )")
   round_tripped = abstract_sql.restore_from_clause(
       abstract_sql.replace_from_clause(spans), fk_relations=[])
   self.assertEqual(
       "select count ( * ) from enzyme where enzyme.id not in ( select "
       "medicine_enzyme_interaction.enzyme_id from "
       "medicine_enzyme_interaction )",
       abstract_sql.sql_spans_to_string(round_tripped))
コード例 #15
0
def restore_predicted_sql(sql_string, table_schemas, foreign_keys):
  """Restore FROM clause in predicted SQL.

  TODO(petershaw): Add call to this function from run_inference.py.

  Parsing uses `lowercase=False`; presumably this keeps the input's casing
  rather than normalizing it (confirm against `abstract_sql`).

  Args:
    sql_string: SQL query as string (presumably containing FROM-clause
      placeholders produced by the model).
    table_schemas: List of TableSchema tuples.
    foreign_keys: List of ForeignKeyRelation tuples.

  Returns:
    SQL query with restored FROM clause as a string.
  """
  parsed = abstract_sql.sql_to_sql_spans(sql_string, table_schemas,
                                         lowercase=False)
  restored = abstract_sql.restore_from_clause(parsed, foreign_keys)
  return abstract_sql.sql_spans_to_string(restored)
コード例 #16
0
 def test_remove_from_clause_multiple_joins(self):
     """Replaces a FROM clause that chains two JOINs."""
     query = ("select business.name , user.name from business join review on "
              "business.business_id = review.business_id join user on "
              "review.user_id = user.user_id where user.name = 'drake'")
     placeholder_sql = abstract_sql.sql_spans_to_string(
         abstract_sql.replace_from_clause(abstract_sql.sql_to_sql_spans(query)))
     self.assertEqual(
         "select business.name , user.name <from_clause_placeholder> review "
         "where user.name = 'drake'",
         placeholder_sql)