Ejemplo n.º 1
0
def convert_spider(
    spider_example: SpiderExample,
    schema: Schema,
    wordpiece_tokenizer,
    generate_sql,
    anonymize_values,
    abstract_sql=False,
    table_schemas: Optional[List[TableSchema]] = None,
    allow_value_generation=False,
) -> Optional[NLToSQLExample]:
    """
    Build an NLToSQLExample from one Spider record.

    Args:
        spider_example: Spider record; its "question", "question_toks" and
            "query" entries are read.
        schema: Schema handed to populate_utterance.
        wordpiece_tokenizer: Tokenizer handed to populate_utterance.
        generate_sql: If truthy, the SQL side of the example is populated.
        anonymize_values: Forwarded to the SQL population routine.
        abstract_sql: If truthy, populate via
            abstract_sql_converters.populate_abstract_sql instead of
            populate_sql.
        table_schemas: Must be truthy when abstract_sql is used (asserted).
        allow_value_generation: If truthy, an example is still returned when
            the SQL population reports an unsuccessful copy.

    Returns:
        The populated example, or None when the question is listed in
        WRONG_TRAINING_EXAMPLES, when populate_sql raises ParseError, or when
        the copy was unsuccessful and allow_value_generation is falsy.
    """
    if spider_example["question"] in WRONG_TRAINING_EXAMPLES:
        return None

    # The query is parsed up front, even if the abstract-SQL path (which
    # works from the raw query string) ends up being taken.
    cleaned_query = preprocess_sql(spider_example["query"].rstrip("; ").lower())
    parsed_statement: sqlparse.sql.Statement = sqlparse.parse(cleaned_query)[0]

    nl_query = " ".join(spider_example["question_toks"])
    example = NLToSQLExample.empty(nl_query)

    # Input side.
    populate_utterance(example, schema, wordpiece_tokenizer)

    # Without SQL generation there is nothing that could fail to copy.
    if not generate_sql:
        return example

    # Output side.
    if abstract_sql:
        assert table_schemas
        copied_ok = abstract_sql_converters.populate_abstract_sql(
            example, spider_example["query"], table_schemas, anonymize_values
        )
    else:
        try:
            copied_ok = populate_sql(parsed_statement, example, anonymize_values)
        except ParseError:
            return None

    # An unsuccessful copy is only acceptable when values may be generated.
    if copied_ok or allow_value_generation:
        return example
    return None
Ejemplo n.º 2
0
def convert_wikisql(input_example, schema, tokenizer, generate_sql,
                    anonymize_values):
    """Converts a WikiSQL example into a NLToSQLExample.

    Args:
      input_example: Indexable pair; element 0 is passed to
        populate_utterance as the utterance, element 1 is the SQL string.
      schema: Mapping with exactly one key (asserted); that key replaces the
        'TABLE' placeholder in the SQL string.
      tokenizer: Tokenizer passed to populate_utterance.
      generate_sql: If truthy, the SQL side of the example is populated.
      anonymize_values: Forwarded to populate_sql.

    Returns:
      The populated NLToSQLExample, or None if the utterance or SQL could not
      be processed, if no SQL actions were produced, or if the last SQL action
      is an equals sign.
    """
    example = NLToSQLExample()

    try:
        try:
            example = populate_utterance(example, input_example[0], schema,
                                         tokenizer)
        except ValueError as e:
            print(e)
            return None

        # WikiSQL databases have a single table.
        assert len(schema) == 1

        # Some preprocessing of the WikiSQL SQL queries.
        sql = input_example[1].rstrip('; ')
        # dict.keys() is a non-subscriptable view on Python 3, so it has to be
        # materialised before indexing.
        sql = sql.replace('TABLE', list(schema.keys())[0])
        sql = sql.replace('_FIELD', '')
        string_split_sql = sql.split(' ')
        if string_split_sql[1].lower() in {
                'count', 'min', 'max', 'avg', 'sum'
        }:
            # Add parentheses around the column that's an argument of any of these
            # aggregate functions (because gold annotations don't have it).
            sql = ' '.join(string_split_sql[0:2] +
                           ['(', string_split_sql[2], ')'] +
                           string_split_sql[3:])

        sql = normalize_sql(sql, replace_period=False)

        try:
            sql = preprocess_sql(sql)
        except UnicodeDecodeError as e:
            print('Unicode error: ' + str(e))
            return None

        sql = sql.lower()
        parsed_sql = sqlparse.parse(sql)[0]

        if generate_sql:
            try:
                populate_sql(parsed_sql, example, anonymize_values)
            except (ParseError, ValueError, AssertionError, KeyError,
                    IndexError) as e:
                print(e)
                return None

        actions = example.gold_sql_query.actions
        # Guard against an empty action list (e.g. SQL generation was
        # skipped) so that indexing the last action cannot raise IndexError.
        if not actions:
            return None
        if actions[-1].symbol == '=':
            print('The last token should not be an equals sign!')
            return None

    except UnicodeEncodeError as e:
        print(e)
        return None

    return example
Ejemplo n.º 3
0
def convert_spider(spider_example,
                   schema,
                   wordpiece_tokenizer,
                   generate_sql,
                   anonymize_values,
                   abstract_sql=False,
                   table_schemas=None,
                   allow_value_generation=False):
    """Turns one Spider record into an NLToSQLExample.

    Args:
      spider_example: Spider record; its 'question', 'question_toks' and
        'query' entries are read.
      schema: Schema handed to populate_utterance.
      wordpiece_tokenizer: Tokenizer handed to populate_utterance.
      generate_sql: If truthy, the SQL side of the example is populated.
      anonymize_values: Forwarded to the SQL population routine.
      abstract_sql: If truthy, populate via
        abstract_sql_converters.populate_abstract_sql instead of populate_sql.
      table_schemas: Forwarded to populate_abstract_sql.
      allow_value_generation: If truthy, an example is still returned when the
        SQL population reports an unsuccessful copy.

    Returns:
      The populated example, or None when the question is listed in
      WRONG_TRAINING_EXAMPLES or when the copy was unsuccessful and
      allow_value_generation is falsy.
    """
    if spider_example['question'] in WRONG_TRAINING_EXAMPLES:
        return None

    # The query is parsed up front, even if the abstract-SQL path (which
    # works from the raw query string) ends up being taken.
    raw_sql = spider_example['query'].rstrip('; ')
    parsed_sql = sqlparse.parse(preprocess_sql(raw_sql.lower()))[0]

    example = NLToSQLExample()

    # Input side.
    utterance = ' '.join(spider_example['question_toks'])
    populate_utterance(example, utterance, schema, wordpiece_tokenizer)

    # Without SQL generation there is nothing that could fail to copy.
    if not generate_sql:
        return example

    # Output side.
    if abstract_sql:
        copied_ok = abstract_sql_converters.populate_abstract_sql(
            example, spider_example['query'], table_schemas, anonymize_values)
    else:
        copied_ok = populate_sql(parsed_sql, example, anonymize_values)

    # An unsuccessful copy is only acceptable when values may be generated.
    if copied_ok or allow_value_generation:
        return example
    return None
Ejemplo n.º 4
0
def convert_wikisql(
    input_example: Union[Tuple[str, str], Tuple[str, str, Any]],
    schema,
    tokenizer,
    generate_sql: bool,
    anonymize_values: bool,
    use_abstract_sql: bool,
    tables_schema=None,
    allow_value_generation=False,
) -> Optional[NLToSQLExample]:
    """
    Converts a WikiSQL example into a NLToSQLExample.

    Args:
        input_example: (utterance, sql) pair, optionally followed by a third
            element that is ignored.
        schema: Mapping with exactly one key (asserted); that key replaces
            the "TABLE" placeholder in the SQL string.
        tokenizer: Tokenizer passed to populate_utterance.
        generate_sql: If True, the SQL side of the example is populated.
        anonymize_values: Forwarded to the SQL population routine.
        use_abstract_sql: If True, populate via
            abstract_sql_converters.populate_abstract_sql instead of
            populate_sql.
        tables_schema: Forwarded to populate_abstract_sql.
        allow_value_generation: If truthy, an example is still returned when
            the SQL population reports an unsuccessful copy.

    Returns:
        The populated example, or None if the utterance or SQL could not be
        processed, if the copy was unsuccessful and value generation is not
        allowed, if no SQL actions were produced, or if the last SQL action
        is an equals sign.
    """

    if len(input_example) == 2:
        # https://github.com/python/mypy/issues/1178
        nl_str, sql_str = cast(Tuple[str, str], input_example)
    else:
        nl_str, sql_str, _ = cast(Tuple[str, str, Any], input_example)

    example = NLToSQLExample.empty(nl_str)

    try:
        try:
            populate_utterance(example, schema, tokenizer)
        except ValueError as e:
            print("Couldn't populate utterance in wikisql example:")
            print(e)
            return None

        # WikiSQL databases have a single table.
        assert len(schema) == 1

        # Some preprocessing of the WikiSQL SQL queries. The SQL string was
        # already unpacked above, so there is no need to index the tuple again.
        sql = sql_str.rstrip("; ")
        sql = sql.replace("TABLE", list(schema.keys())[0])
        sql = sql.replace("_FIELD", "")
        string_split_sql = sql.split(" ")
        if string_split_sql[1].lower() in {
                "count", "min", "max", "avg", "sum"
        }:
            # Add parentheses around the column that's an argument of any of these
            # aggregate functions (because gold annotations don't have it).
            sql = " ".join(string_split_sql[0:2] +
                           ["(", string_split_sql[2], ")"] +
                           string_split_sql[3:])

        sql = normalize_sql(sql, replace_period=False)

        try:
            sql = preprocess_sql(sql)
        except UnicodeDecodeError as e:
            # Print the error after the header, mirroring the utterance
            # failure handler above.
            print("Couldn't preprocess sql in wikisql example:")
            print(e)
            return None

        sql = sql.lower()
        parsed_sql = sqlparse.parse(sql)[0]

        successful_copy = True
        if generate_sql:
            try:
                if use_abstract_sql:
                    successful_copy = abstract_sql_converters.populate_abstract_sql(
                        example, sql, tables_schema, anonymize_values)
                else:
                    successful_copy = populate_sql(parsed_sql, example,
                                                   anonymize_values)
            except (
                    ParseError,
                    ValueError,
                    AssertionError,
                    KeyError,
                    IndexError,
                    abstract_sql.ParseError,
                    abstract_sql.UnsupportedSqlError,
            ):
                return None

        # An unsuccessful copy is only acceptable when values may be generated.
        if not successful_copy and not allow_value_generation:
            return None

        if not example.gold_sql_query.actions:
            return None
        elif example.gold_sql_query.actions[-1].symbol == "=":
            return None

    except UnicodeEncodeError as e:
        print(e)
        return None

    return example
Ejemplo n.º 5
0
def convert_wikisql(input_example,
                    schema,
                    tokenizer,
                    generate_sql,
                    anonymize_values,
                    use_abstract_sql,
                    tables_schema=None,
                    allow_value_generation=False):
    """Builds an NLToSQLExample from one WikiSQL example.

    Args:
      input_example: Indexable pair; element 0 is passed to
        populate_utterance as the utterance, element 1 is the SQL string.
      schema: Mapping with exactly one key (asserted); that key replaces the
        'TABLE' placeholder in the SQL string.
      tokenizer: Tokenizer passed to populate_utterance.
      generate_sql: If truthy, the SQL side of the example is populated.
      anonymize_values: Forwarded to the SQL population routine.
      use_abstract_sql: If truthy, populate via
        abstract_sql_converters.populate_abstract_sql instead of populate_sql.
      tables_schema: Forwarded to populate_abstract_sql.
      allow_value_generation: If truthy, an example is still returned when the
        SQL population reports an unsuccessful copy.

    Returns:
      The populated example, or None if the utterance or SQL could not be
      processed, if the copy was unsuccessful and value generation is not
      allowed, if no SQL actions were produced, or if the last SQL action is
      an equals sign.
    """
    example = NLToSQLExample()

    try:
        try:
            example = populate_utterance(example, input_example[0], schema,
                                         tokenizer)
        except ValueError as error:
            print(error)
            return None

        # WikiSQL databases have a single table.
        assert len(schema) == 1

        # Substitute the placeholders used by the WikiSQL annotations.
        query = input_example[1].rstrip('; ')
        query = query.replace('TABLE', list(schema.keys())[0])
        query = query.replace('_FIELD', '')

        tokens = query.split(' ')
        if tokens[1].lower() in {'count', 'min', 'max', 'avg', 'sum'}:
            # Gold annotations omit the parentheses around the argument of an
            # aggregate function, so put them back.
            wrapped = tokens[0:2] + ['(', tokens[2], ')'] + tokens[3:]
            query = ' '.join(wrapped)

        query = normalize_sql(query, replace_period=False)

        try:
            query = preprocess_sql(query)
        except UnicodeDecodeError:
            return None

        query = query.lower()
        statement = sqlparse.parse(query)[0]

        copied_ok = True
        if generate_sql:
            try:
                if use_abstract_sql:
                    copied_ok = abstract_sql_converters.populate_abstract_sql(
                        example, query, tables_schema, anonymize_values)
                else:
                    copied_ok = populate_sql(statement, example,
                                             anonymize_values)
            except (ParseError, ValueError, AssertionError, KeyError,
                    IndexError, abstract_sql.ParseError,
                    abstract_sql.UnsupportedSqlError):
                return None

        # An unsuccessful copy is only acceptable when values may be generated.
        if not (copied_ok or allow_value_generation):
            return None

        # Reject examples with no SQL actions or a dangling trailing '='.
        actions = example.gold_sql_query.actions
        if not actions or actions[-1].symbol == '=':
            return None

    except UnicodeEncodeError as error:
        print(error)
        return None

    return example
Ejemplo n.º 6
0
def convert_michigan(
    nl_query: str,
    sql_str: str,
    schema: Schema,
    tokenizer: FullTokenizer,
    generate_sql: bool,
    anonymize_values: bool,
    abstract_sql: bool,
    table_schemas: Optional[List[TableSchema]],
    allow_value_generation: bool,
) -> Optional[NLToSQLExample]:
    """
    Converts a Michigan example to the standard format.

    Args:
        nl_query: Natural language query.
        sql_str: SQL query string.
        schema: Schema handed to populate_utterance.
        tokenizer: Tokenizer handed to populate_utterance.
        generate_sql: If True, will populate SQL.
        anonymize_values: If True, anonymizes values in SQL.
        abstract_sql: If True, populate via
            abstract_sql_converters.populate_abstract_sql instead of
            sql_parsing.populate_sql.
        table_schemas: Required (asserted truthy) if abstract_sql is used.
        allow_value_generation: If True, an example is still returned when
            the SQL population reports an unsuccessful copy.

    Returns:
        The populated NLToSQLExample, or None when the copy was unsuccessful
        (including a sql_parsing.ParseError) and allow_value_generation is
        False.
    """
    example = NLToSQLExample.empty(nl_query)
    populate_utterance(example, schema, tokenizer)

    # Set the output
    successful_copy = True
    if generate_sql:
        if abstract_sql:
            assert table_schemas
            successful_copy = abstract_sql_converters.populate_abstract_sql(
                example, sql_str, table_schemas, anonymize_values)
        else:
            sql_query: sqlparse.sql.Statement = sqlparse.parse(
                preprocess_sql(sql_str.rstrip("; ").lower()))[0]
            try:
                successful_copy = sql_parsing.populate_sql(
                    sql_query, example, anonymize_values)
            except sql_parsing.ParseError as e:
                # A parse failure is treated as an unsuccessful copy rather
                # than an immediate rejection, so allow_value_generation can
                # still let the example through.
                print(e)
                successful_copy = False

    # If the example contained an unsuccessful copy action, and values should
    # not be generated, then discard the example.
    if not successful_copy and not allow_value_generation:
        return None

    return example
Ejemplo n.º 7
0
 def test_count_function(self) -> None:
     """Checks that preprocess_sql returns this count( * ) query unchanged."""
     query = "select count( * ) from paper where paper.year = '1999' ;"
     processed = sql_utils.preprocess_sql(query)
     self.assertEqual(query, processed)
Ejemplo n.º 8
0
 def test_anonymize_alias_in_function(self) -> None:
     """Checks alias renaming when an alias appears inside a function call."""
     raw_sql = "select count(WRITESalias0.AUTHORID) from paper as PAPERalias0 , writes as WRITESalias0 , where PAPERalias0.paperid = WRITESalias0.paperid ;"
     expected_sql = "select count( T1.authorid ) from paper as T2 , writes as T1 , where T2.paperid = T1.paperid ;"
     # The input is lower-cased before preprocessing.
     actual_sql = sql_utils.preprocess_sql(raw_sql.lower())
     self.assertEqual(expected_sql, actual_sql)
Ejemplo n.º 9
0
 def test_anonymize_alias(self) -> None:
     """Checks that table aliases are renamed to T1, T2, ... by preprocess_sql."""
     raw_sql = 'SELECT DISTINCT WRITESalias0.AUTHORID FROM PAPER AS PAPERalias0 , VENUE AS VENUEalias0 , WRITES AS WRITESalias0 WHERE VENUEalias0.VENUEID = PAPERalias0.VENUEID AND VENUEalias0.VENUENAME = "venuename0" AND WRITESalias0.PAPERID = PAPERalias0.PAPERID ;'
     expected_sql = "select distinct T1.authorid from paper as T2 , venue as T3 , writes as T1 where T3.venueid = T2.venueid and T3.venuename = 'venuename0' and T1.paperid = T2.paperid ;"
     # The input is lower-cased before preprocessing.
     actual_sql = sql_utils.preprocess_sql(raw_sql.lower())
     self.assertEqual(expected_sql, actual_sql)