    def _read(self, file_path: str):
        """
        This dataset reader consumes the data from
        https://github.com/jkkummerfeld/text2sql-data/tree/master/data
        formatted using ``scripts/reformat_text2sql_data.py``.

        Parameters
        ----------
        file_path : ``str``, required.
            For this dataset reader, file_path can either be a path to a single
            file or a glob matching several json files. This is because some of
            the text2sql datasets use cross validation and are therefore split
            into many small files, of which exactly one should be excluded per run.
        """
        files = [
            p for p in glob.glob(file_path) if
            self._cross_validation_split_to_exclude not in os.path.basename(p)
        ]
        schema = read_dataset_schema(self._schema_path)

        for path in files:
            with open(cached_path(path), "r") as data_file:
                data = json.load(data_file)

            for sql_data in text2sql_utils.process_sql_data(
                    data,
                    use_all_sql=self._use_all_sql,
                    remove_unneeded_aliases=self._remove_unneeded_aliases,
                    schema=schema):
                linked_entities = sql_data.sql_variables if self._use_prelinked_entities else None
                instance = self.text_to_instance(sql_data.text_with_variables,
                                                 linked_entities, sql_data.sql)
                if instance is not None:
                    yield instance
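A minimal sketch of the file selection above, using hypothetical split file
names, showing how one cross-validation split is excluded (the glob pattern
and split name here are assumptions for illustration):

import glob
import os

# Hypothetical layout: one json file per cross-validation split,
# e.g. data/restaurants/split_0.json ... data/restaurants/split_9.json
file_path = "data/restaurants/*.json"          # glob pattern, as _read expects
cross_validation_split_to_exclude = "split_3"  # the held-out split

files = [
    p for p in glob.glob(file_path)
    if cross_validation_split_to_exclude not in os.path.basename(p)
]
# `files` now contains every split except split_3.
print(files)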
Example #2
    def _read(self, file_path: str):
        """
        This dataset reader consumes the data from
        https://github.com/jkkummerfeld/text2sql-data/tree/master/data
        formatted using ``scripts/reformat_text2sql_data.py``.

        Parameters
        ----------
        file_path : ``str``, required.
            For this dataset reader, file_path can either be a path to a single
            file or a glob matching several json files. This is because some of
            the text2sql datasets use cross validation and are therefore split
            into many small files, of which exactly one should be excluded per run.
        """
        files = [
            p for p in glob.glob(file_path) if
            self._cross_validation_split_to_exclude not in os.path.basename(p)
        ]

        for path in files:
            with open(cached_path(path), "r") as data_file:
                data = json.load(data_file)

            for sql_data in text2sql_utils.process_sql_data(
                    data, self._use_all_sql):
                template = " ".join(sql_data.sql)
                yield self.text_to_instance(sql_data.text,
                                            sql_data.variable_tags, template)
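For intuition, a small hand-rolled example (token values borrowed from the
tests below) of what the joined SQL template looks like:

# Hypothetical tokenized SQL with prelinked variables, shaped like the
# output of text2sql_utils.process_sql_data on the restaurants data.
sql_tokens = ['SELECT', 'COUNT', '(', '*', ')', 'FROM', 'RESTAURANT', 'AS',
              'RESTAURANTalias0', 'WHERE', 'RESTAURANTalias0', '.', 'NAME',
              '=', "'name0'", ';']
template = " ".join(sql_tokens)
# SELECT COUNT ( * ) FROM RESTAURANT AS RESTAURANTalias0 WHERE ...
print(template)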
Example #3
    def test_process_sql_data_blob(self):

        data = json.load(open(str(self.data)))
        dataset = text2sql_utils.process_sql_data([data[0]])
        dataset = list(dataset)
        sql_data = dataset[0]
        # Check that question de-duplication happens by default
        # (otherwise there would be more than 1 dataset element).
        assert len(dataset) == 1
        assert sql_data.text == ['how', 'many', 'buttercup', 'kitchen', 'are', 'there', 'in', 'san', 'francisco', '?']
        assert sql_data.text_with_variables == ['how', 'many', 'name0', 'are', 'there', 'in', 'city_name0', '?']
        assert sql_data.sql == ['SELECT', 'COUNT', '(', '*', ')', 'FROM', 'LOCATION', 'AS', 'LOCATIONalias0', ',',
                                'RESTAURANT', 'AS', 'RESTAURANTalias0', 'WHERE', 'LOCATIONalias0', '.', 'CITY_NAME', '=',
                                '\'city_name0\'', 'AND', 'RESTAURANTalias0', '.', 'ID', '=', 'LOCATIONalias0', '.', 'RESTAURANT_ID',
                                'AND', 'RESTAURANTalias0', '.', 'NAME', '=', '\'name0\'', ';']

        assert sql_data.text_variables == {'city_name0': 'san francisco', 'name0': 'buttercup kitchen'}
        assert sql_data.sql_variables == {'city_name0': {'text': 'san francisco', 'type': 'city_name'},
                                          'name0': {'text': 'buttercup kitchen', 'type': 'name'}}


        dataset = text2sql_utils.process_sql_data([data[1]])
        correct_text = [
                [['how', 'many', 'chinese', 'restaurants', 'are', 'there', 'in', 'the', 'bay', 'area', '?'],
                 ['how', 'many', 'food_type0', 'restaurants', 'are', 'there', 'in', 'the', 'region0', '?']],
                [['how', 'many', 'places', 'for', 'chinese', 'food', 'are', 'there', 'in', 'the', 'bay', 'area', '?'],
                 ['how', 'many', 'places', 'for', 'food_type0', 'food', 'are', 'there', 'in', 'the', 'region0', '?']],
                [['how', 'many', 'chinese', 'places', 'are', 'there', 'in', 'the', 'bay', 'area', '?'],
                 ['how', 'many', 'food_type0', 'places', 'are', 'there', 'in', 'the', 'region0', '?']],
                [['how', 'many', 'places', 'for', 'chinese', 'are', 'there', 'in', 'the', 'bay', 'area', '?'],
                 ['how', 'many', 'places', 'for', 'food_type0', 'are', 'there', 'in', 'the', 'region0', '?']],
        ]

        for i, sql_data in enumerate(dataset):
            assert sql_data.sql == ['SELECT', 'COUNT', '(', '*', ')', 'FROM', 'GEOGRAPHIC', 'AS', 'GEOGRAPHICalias0',
                                    ',', 'RESTAURANT', 'AS', 'RESTAURANTalias0', 'WHERE', 'GEOGRAPHICalias0', '.', 'REGION',
                                    '=', '\'region0\'', 'AND', 'RESTAURANTalias0', '.', 'CITY_NAME', '=', 'GEOGRAPHICalias0',
                                    '.', 'CITY_NAME', 'AND', 'RESTAURANTalias0', '.', 'FOOD_TYPE', '=', '\'food_type0\'', ';']
            assert sql_data.text_variables == {'region0': 'bay area', 'food_type0': 'chinese'}
            assert sql_data.sql_variables == {'region0': {'text': 'bay area', 'type': 'region'},
                                              'food_type0': {'text': 'chinese', 'type': 'food_type'}}
            assert sql_data.text == correct_text[i][0]
            assert sql_data.text_with_variables == correct_text[i][1]
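Example #4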
def parse_dataset(filename: str, filter_by: str = None, verbose: bool = False):

    # Sentinel that should never appear in any SQL string, so that by default
    # no parse failures are filtered out.
    filter_by = filter_by or "13754332dvmklfdsaf-3543543"
    data = json.load(open(filename))
    num_queries = 0
    num_parsed = 0
    filtered_errors = 0

    for i, sql_data in enumerate(process_sql_data(data)):
        sql_visitor = SqlVisitor(SQL_GRAMMAR2)

        if any(x[:7] == "DERIVED" for x in sql_data.sql):
            # NOTE: DATA hack alert - the geography dataset doesn't alias derived tables consistently,
            # so we fix the data a bit here instead of completely re-working the grammar.
            sql_to_use = []
            for j, token in enumerate(sql_data.sql):
                if token[:7] == "DERIVED" and sql_data.sql[j-1] == ")":
                    sql_to_use.append("AS")
                sql_to_use.append(token)

            sql_string = " ".join(sql_to_use)
        else:
            sql_string = " ".join(sql_data.sql)
        num_queries += 1
        try:
            prod_rules = sql_visitor.parse(sql_string)
            num_parsed += 1
        except Exception as e:

            if filter_by in sql_string:
                filtered_errors += 1

            if verbose and filter_by not in sql_string:
                print()
                print(e)
                print(" ".join(sql_data.text))
                print(sql_data.sql)
                try:
                    import sqlparse
                    print(sqlparse.format(sql_string, reindent=True))
                except Exception:
                    print(sql_string)

        if (i + 1) % 500 == 0:
            print(f"\tProcessed {i + 1} queries.")

    return num_parsed, num_queries, filtered_errors
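A hedged usage sketch (the data path is hypothetical) showing how the returned
counts can be used to report parse coverage:

num_parsed, num_queries, filtered_errors = parse_dataset(
    "restaurants.json", filter_by=None, verbose=False)
print(f"Parsed {num_parsed}/{num_queries} queries "
      f"({100.0 * num_parsed / max(num_queries, 1):.1f}%); "
      f"{filtered_errors} failures matched the filter.")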
Example #5
def parse_dataset(filename: str, filter_by: str = None, verbose: bool = False):

    grammar_string = format_grammar_string(GRAMMAR_DICTIONARY)
    grammar = Grammar(grammar_string)

    # Sentinel that should never appear in any SQL string, so that by default
    # no parse failures are filtered out.
    filter_by = filter_by or "13754332dvmklfdsaf-3543543"
    data = json.load(open(filename))
    num_queries = 0
    num_parsed = 0
    filtered_errors = 0

    non_basic_as_aliases = 0
    as_count = 0
    queries_with_weird_as = 0

    for i, sql_data in enumerate(process_sql_data(data)):
        sql_visitor = SqlVisitor(grammar)

        if any(x[:7] == "DERIVED" for x in sql_data.sql):
            # NOTE: DATA hack alert - the geography dataset doesn't alias derived tables consistently,
            # so we fix the data a bit here instead of completely re-working the grammar.
            sql_to_use = []
            for j, token in enumerate(sql_data.sql):
                if token[:7] == "DERIVED" and sql_data.sql[j-1] == ")":
                    sql_to_use.append("AS")
                sql_to_use.append(token)

            previous_token = None
            query_has_weird_as = False
            for j, token in enumerate(sql_to_use[:-1]):
                if token == "AS" and previous_token is not None:
                    # Aliases look like "RESTAURANTalias0"; stripping the
                    # 6-character "aliasN" suffix should recover the table
                    # name that precedes the AS.
                    table_name = sql_to_use[j + 1][:-6]
                    if table_name != previous_token:
                        # The alias does not follow the <table_name>aliasN scheme.
                        non_basic_as_aliases += 1
                        query_has_weird_as = True
                    as_count += 1
                previous_token = token

            if query_has_weird_as:
                queries_with_weird_as += 1

            sql_string = " ".join(sql_to_use)
        else:
            sql_string = " ".join(sql_data.sql)
        num_queries += 1
        try:
            prod_rules = sql_visitor.parse(sql_string)
            num_parsed += 1
        except Exception as e:

            if filter_by in sql_string:
                filtered_errors += 1

            if verbose and filter_by not in sql_string:
                print()
                print(e)
                print(" ".join(sql_data.text))
                print(sql_data.sql)
                try:
                    import sqlparse
                    print(sqlparse.format(sql_string, reindent=True))
                except Exception:
                    print(sql_string)

        if (i + 1) % 500 == 0:
            print(f"\tProcessed {i + 1} queries.")

    return num_parsed, num_queries, filtered_errors, non_basic_as_aliases, as_count, queries_with_weird_as
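The derived-table fix-up used in both versions of parse_dataset is easiest to
see in isolation; a minimal sketch over a made-up token sequence:

# Made-up tokens for a query with an unaliased derived table:
# "... ) DERIVEDalias0 ..." is rewritten to "... ) AS DERIVEDalias0 ..."
tokens = ["FROM", "(", "SELECT", "*", ")", "DERIVEDalias0", "WHERE"]
fixed = []
for j, token in enumerate(tokens):
    if token[:7] == "DERIVED" and tokens[j - 1] == ")":
        fixed.append("AS")
    fixed.append(token)
print(" ".join(fixed))  # FROM ( SELECT * ) AS DERIVEDalias0 WHERE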
Example #6
    def test_process_sql_data_can_yield_all_queries(self):
        data = json.load(open(str(self.data)))
        dataset = text2sql_utils.process_sql_data([data[0]], use_all_queries=True)
        dataset = list(dataset)
        assert len(dataset) == 3
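Taken together with test_process_sql_data_blob above, the flag's effect on the
first data blob is simply (a hedged sketch; the path is hypothetical):

data = json.load(open("restaurants.json"))  # hypothetical data file
deduplicated = list(text2sql_utils.process_sql_data([data[0]]))
everything = list(text2sql_utils.process_sql_data([data[0]], use_all_queries=True))
assert len(deduplicated) == 1  # one canonical paraphrase per SQL query
assert len(everything) == 3    # every paraphrase is kept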