def convert_spider(
    spider_example: SpiderExample,
    schema: Schema,
    wordpiece_tokenizer,
    generate_sql,
    anonymize_values,
    abstract_sql=False,
    table_schemas: Optional[List[TableSchema]] = None,
    allow_value_generation=False,
) -> Optional[NLToSQLExample]:
    """Converts a Spider example to the standard format.

    Args:
        spider_example: JSON object for SPIDER example in original format.
        schema: JSON object for SPIDER schema in converted format.
        wordpiece_tokenizer: language.bert.tokenization.FullTokenizer instance.
        generate_sql: If True, will populate SQL.
        anonymize_values: If True, anonymizes values in SQL.
        abstract_sql: If True, use under-specified FROM clause.
        table_schemas: required if abstract_sql, list of TableSchema tuples.
        allow_value_generation: Allow value generation.

    Returns:
        NLToSQLExample instance, or None if the example is skipped.
    """
    # Skip examples known to be mislabeled in the training data.
    if spider_example["question"] in WRONG_TRAINING_EXAMPLES:
        return None

    cleaned = spider_example["query"].rstrip("; ").lower()
    parsed_statement: sqlparse.sql.Statement = sqlparse.parse(
        preprocess_sql(cleaned)
    )[0]

    example = NLToSQLExample.empty(" ".join(spider_example["question_toks"]))

    # Attach the tokenized natural-language input.
    populate_utterance(example, schema, wordpiece_tokenizer)

    # Populate the gold SQL output, tracking whether every value could be
    # copied from the input.
    copy_ok = True
    if generate_sql:
        if abstract_sql:
            assert table_schemas
            copy_ok = abstract_sql_converters.populate_abstract_sql(
                example, spider_example["query"], table_schemas, anonymize_values
            )
        else:
            try:
                copy_ok = populate_sql(parsed_statement, example, anonymize_values)
            except ParseError:
                return None

    # If the example contained an unsuccessful copy action, and values should
    # not be generated, then return an empty example.
    if not copy_ok and not allow_value_generation:
        return None
    return example
def convert_wikisql(input_example, schema, tokenizer, generate_sql, anonymize_values):
    """Converts a WikiSQL example into a NLToSQLExample.

    Args:
        input_example: Pair of (utterance, SQL string) for the WikiSQL example.
        schema: Mapping from table name to column list; WikiSQL schemas are
            assumed to contain exactly one table.
        tokenizer: Wordpiece tokenizer used to tokenize the utterance.
        generate_sql: If True, will populate SQL.
        anonymize_values: If True, anonymizes values in SQL.

    Returns:
        NLToSQLExample instance, or None if conversion failed.
    """
    example = NLToSQLExample()
    try:
        try:
            example = populate_utterance(example, input_example[0], schema, tokenizer)
        except ValueError as e:
            print(e)
            return None

        # WikiSQL databases have a single table.
        assert len(schema) == 1

        # Some preprocessing of the WikiSQL SQL queries.
        sql = input_example[1].rstrip('; ')
        # Fix: dict.keys() is a non-subscriptable view on Python 3; it must be
        # materialized before indexing (matches the other WikiSQL converter).
        sql = sql.replace('TABLE', list(schema.keys())[0])
        sql = sql.replace('_FIELD', '')

        string_split_sql = sql.split(' ')
        if string_split_sql[1].lower() in {'count', 'min', 'max', 'avg', 'sum'}:
            # Add parentheses around the column that's an argument of any of these
            # aggregate functions (because gold annotations don't have it).
            sql = ' '.join(string_split_sql[0:2] + ['(', string_split_sql[2], ')'] +
                           string_split_sql[3:])
        sql = normalize_sql(sql, replace_period=False)

        try:
            sql = preprocess_sql(sql)
        except UnicodeDecodeError as e:
            print('Unicode error: ' + str(e))
            return None
        sql = sql.lower()

        parsed_sql = sqlparse.parse(sql)[0]

        if generate_sql:
            try:
                populate_sql(parsed_sql, example, anonymize_values)
            except (ParseError, ValueError, AssertionError, KeyError,
                    IndexError) as e:
                print(e)
                return None

        # A trailing '=' indicates a query whose value was dropped.
        if example.gold_sql_query.actions[-1].symbol == '=':
            print('The last token should not be an equals sign!')
            return None
    except UnicodeEncodeError as e:
        print(e)
        return None

    return example
def convert_spider(spider_example, schema, wordpiece_tokenizer, generate_sql,
                   anonymize_values, abstract_sql=False, table_schemas=None,
                   allow_value_generation=False):
    """Converts a Spider example to the standard format.

    Args:
        spider_example: JSON object for SPIDER example in original format.
        schema: JSON object for SPIDER schema in converted format.
        wordpiece_tokenizer: language.bert.tokenization.FullTokenizer instance.
        generate_sql: If True, will populate SQL.
        anonymize_values: If True, anonymizes values in SQL.
        abstract_sql: If True, use under-specified FROM clause.
        table_schemas: required if abstract_sql, list of TableSchema tuples.
        allow_value_generation: Allow value generation.

    Returns:
        NLToSQLExample instance, or None if the example is skipped.
    """
    # Skip examples known to be mislabeled in the training data.
    if spider_example['question'] in WRONG_TRAINING_EXAMPLES:
        return None

    stripped_query = spider_example['query'].rstrip('; ')
    parsed_query = sqlparse.parse(preprocess_sql(stripped_query.lower()))[0]

    example = NLToSQLExample()

    # Set the input.
    utterance = ' '.join(spider_example['question_toks'])
    populate_utterance(example, utterance, schema, wordpiece_tokenizer)

    # Set the output.
    copy_succeeded = True
    if generate_sql:
        if abstract_sql:
            copy_succeeded = abstract_sql_converters.populate_abstract_sql(
                example, spider_example['query'], table_schemas, anonymize_values)
        else:
            copy_succeeded = populate_sql(parsed_query, example, anonymize_values)

    # If the example contained an unsuccessful copy action, and values should
    # not be generated, then return an empty example.
    if not copy_succeeded and not allow_value_generation:
        return None
    return example
def convert_wikisql(
    input_example: Union[Tuple[str, str], Tuple[str, str, Any]],
    schema,
    tokenizer,
    generate_sql: bool,
    anonymize_values: bool,
    use_abstract_sql: bool,
    tables_schema=None,
    allow_value_generation=False,
) -> Optional[NLToSQLExample]:
    """Converts a WikiSQL example into a NLToSQLExample.

    Args:
        input_example: (utterance, SQL) pair, optionally with a third element
            which is ignored.
        schema: Mapping from table name to column list; WikiSQL schemas are
            assumed to contain exactly one table.
        tokenizer: Wordpiece tokenizer used to tokenize the utterance.
        generate_sql: If True, will populate SQL.
        anonymize_values: If True, anonymizes values in SQL.
        use_abstract_sql: If True, use under-specified FROM clause.
        tables_schema: required if use_abstract_sql, list of TableSchema tuples.
        allow_value_generation: Allow value generation.

    Returns:
        NLToSQLExample instance, or None if conversion failed.
    """
    # Accept both 2-tuples and 3-tuples; see
    # https://github.com/python/mypy/issues/1178
    if len(input_example) == 2:
        nl_str, sql_str = cast(Tuple[str, str], input_example)
    else:
        nl_str, sql_str, _ = cast(Tuple[str, str, Any], input_example)

    example = NLToSQLExample.empty(nl_str)
    try:
        try:
            populate_utterance(example, schema, tokenizer)
        except ValueError as err:
            print("Couldn't populate utterance in wikisql example:")
            print(err)
            return None

        # WikiSQL databases have a single table.
        assert len(schema) == 1

        # Some preprocessing of the WikiSQL SQL queries.
        query = sql_str.rstrip("; ")
        query = query.replace("TABLE", list(schema.keys())[0])
        query = query.replace("_FIELD", "")

        tokens = query.split(" ")
        if tokens[1].lower() in {"count", "min", "max", "avg", "sum"}:
            # Add parentheses around the column that's an argument of any of
            # these aggregate functions (because gold annotations don't have
            # it).
            query = " ".join(tokens[0:2] + ["(", tokens[2], ")"] + tokens[3:])

        query = normalize_sql(query, replace_period=False)
        try:
            query = preprocess_sql(query)
        except UnicodeDecodeError:
            print("Couldn't preprocess sql in wikisql example:")
            return None
        query = query.lower()

        parsed_query = sqlparse.parse(query)[0]

        copy_ok = True
        if generate_sql:
            try:
                if use_abstract_sql:
                    copy_ok = abstract_sql_converters.populate_abstract_sql(
                        example, query, tables_schema, anonymize_values)
                else:
                    copy_ok = populate_sql(parsed_query, example, anonymize_values)
            except (
                ParseError,
                ValueError,
                AssertionError,
                KeyError,
                IndexError,
                abstract_sql.ParseError,
                abstract_sql.UnsupportedSqlError,
            ):
                return None

        # Drop examples with failed copies unless values may be generated.
        if not copy_ok and not allow_value_generation:
            return None

        # Reject empty outputs and outputs truncated at an equals sign.
        if not example.gold_sql_query.actions:
            return None
        elif example.gold_sql_query.actions[-1].symbol == "=":
            return None
    except UnicodeEncodeError as e:
        print(e)
        return None

    return example
def convert_wikisql(input_example, schema, tokenizer, generate_sql,
                    anonymize_values, use_abstract_sql, tables_schema=None,
                    allow_value_generation=False):
    """Converts a WikiSQL example into a NLToSQLExample.

    Args:
        input_example: Pair of (utterance, SQL string) for the WikiSQL example.
        schema: Mapping from table name to column list; WikiSQL schemas are
            assumed to contain exactly one table.
        tokenizer: Wordpiece tokenizer used to tokenize the utterance.
        generate_sql: If True, will populate SQL.
        anonymize_values: If True, anonymizes values in SQL.
        use_abstract_sql: If True, use under-specified FROM clause.
        tables_schema: required if use_abstract_sql, list of TableSchema tuples.
        allow_value_generation: Allow value generation.

    Returns:
        NLToSQLExample instance, or None if conversion failed.
    """
    example = NLToSQLExample()
    try:
        try:
            example = populate_utterance(example, input_example[0], schema,
                                         tokenizer)
        except ValueError as e:
            print(e)
            return None

        # WikiSQL databases have a single table.
        assert len(schema) == 1

        # Some preprocessing of the WikiSQL SQL queries.
        sql = input_example[1].rstrip('; ')
        sql = sql.replace('TABLE', list(schema.keys())[0])
        sql = sql.replace('_FIELD', '')

        string_split_sql = sql.split(' ')
        if string_split_sql[1].lower() in {'count', 'min', 'max', 'avg', 'sum'}:
            # Add parentheses around the column that's an argument of any of these
            # aggregate functions (because gold annotations don't have it).
            sql = ' '.join(string_split_sql[0:2] + ['(', string_split_sql[2], ')'] +
                           string_split_sql[3:])
        sql = normalize_sql(sql, replace_period=False)

        try:
            sql = preprocess_sql(sql)
        except UnicodeDecodeError:  # Fix: drop unused `as e` binding.
            return None
        sql = sql.lower()

        parsed_sql = sqlparse.parse(sql)[0]

        successful_copy = True
        if generate_sql:
            try:
                if use_abstract_sql:
                    successful_copy = abstract_sql_converters.populate_abstract_sql(
                        example, sql, tables_schema, anonymize_values)
                else:
                    successful_copy = populate_sql(parsed_sql, example,
                                                   anonymize_values)
            # Fix: drop unused `as e` binding on the silently-handled errors.
            except (ParseError, ValueError, AssertionError, KeyError, IndexError,
                    abstract_sql.ParseError, abstract_sql.UnsupportedSqlError):
                return None

        # Drop examples with failed copies unless values may be generated.
        if not successful_copy and not allow_value_generation:
            return None

        # Reject empty outputs and outputs truncated at an equals sign.
        if not example.gold_sql_query.actions:
            return None
        elif example.gold_sql_query.actions[-1].symbol == '=':
            return None
    except UnicodeEncodeError as e:
        print(e)
        return None

    return example
def convert_michigan(
    nl_query: str,
    sql_str: str,
    schema: Schema,
    tokenizer: FullTokenizer,
    generate_sql: bool,
    anonymize_values: bool,
    abstract_sql: bool,
    table_schemas: Optional[List[TableSchema]],
    allow_value_generation: bool,
) -> Optional[NLToSQLExample]:
    """Converts a Michigan example to the standard format.

    Args:
        nl_query: Natural language query.
        sql_str: Gold SQL query string for the example.
        schema: JSON object for the schema in converted format.
        tokenizer: language.bert.tokenization.FullTokenizer instance.
        generate_sql: If True, will populate SQL.
        anonymize_values: If True, anonymizes values in SQL.
        abstract_sql: If True, use under-specified FROM clause.
        table_schemas: required if abstract_sql, list of TableSchema tuples.
        allow_value_generation: Allow value generation.

    Returns:
        NLToSQLExample instance, or None if the copy failed and value
        generation is disallowed.
    """
    example = NLToSQLExample.empty(nl_query)
    populate_utterance(example, schema, tokenizer)

    # Set the output.
    successful_copy = True
    if generate_sql:
        if abstract_sql:
            assert table_schemas
            successful_copy = abstract_sql_converters.populate_abstract_sql(
                example, sql_str, table_schemas, anonymize_values)
        else:
            sql_query: sqlparse.sql.Statement = sqlparse.parse(
                preprocess_sql(sql_str.rstrip("; ").lower()))[0]
            try:
                successful_copy = sql_parsing.populate_sql(
                    sql_query, example, anonymize_values)
            except sql_parsing.ParseError as e:
                # Best-effort: keep the example but mark the copy as failed.
                print(e)
                successful_copy = False

    # If the example contained an unsuccessful copy action, and values should
    # not be generated, then return an empty example.
    if not successful_copy and not allow_value_generation:
        return None
    return example
    # Fix: removed unreachable code after the return — a leftover
    # `if generate_sql: raise ValueError(...)` that referenced `sql_query`,
    # which is unbound on the abstract-sql path.
def test_count_function(self) -> None:
    """An already-normalized COUNT(*) query passes through unchanged."""
    query = "select count( * ) from paper where paper.year = '1999' ;"
    self.assertEqual(query, sql_utils.preprocess_sql(query))
def test_anonymize_alias_in_function(self) -> None:
    """Aliases used inside an aggregate call are rewritten to T-style names."""
    raw = ("select count(WRITESalias0.AUTHORID) from paper as PAPERalias0 , "
           "writes as WRITESalias0 , where PAPERalias0.paperid = "
           "WRITESalias0.paperid ;").lower()
    expected = ("select count( T1.authorid ) from paper as T2 , writes as T1 , "
                "where T2.paperid = T1.paperid ;")
    self.assertEqual(expected, sql_utils.preprocess_sql(raw))
def test_anonymize_alias(self) -> None:
    """Table aliases are anonymized and double quotes become single quotes."""
    raw = ('SELECT DISTINCT WRITESalias0.AUTHORID FROM PAPER AS PAPERalias0 , '
           'VENUE AS VENUEalias0 , WRITES AS WRITESalias0 WHERE '
           'VENUEalias0.VENUEID = PAPERalias0.VENUEID AND '
           'VENUEalias0.VENUENAME = "venuename0" AND '
           'WRITESalias0.PAPERID = PAPERalias0.PAPERID ;').lower()
    expected = ("select distinct T1.authorid from paper as T2 , venue as T3 , "
                "writes as T1 where T3.venueid = T2.venueid and "
                "T3.venuename = 'venuename0' and T1.paperid = T2.paperid ;")
    self.assertEqual(expected, sql_utils.preprocess_sql(raw))