 def test_read_database_schema(self):
     schema = text2sql_utils.read_dataset_schema(self.FIXTURES_ROOT /
                                                 "data" / "text2sql" /
                                                 "restaurants-schema.csv")
     # Make it easier to compare:
     schema = {
         k: [(x.name, x.column_type, x.is_primary_key) for x in v]
         for k, v in schema.items()
     }
     assert schema == {
         "RESTAURANT": [
             ("RESTAURANT_ID", "int(11)", True),
             ("NAME", "varchar(255)", False),
             ("FOOD_TYPE", "varchar(255)", False),
             ("CITY_NAME", "varchar(255)", False),
             ("RATING", '"decimal(1', False),
         ],
         "LOCATION": [
             ("RESTAURANT_ID", "int(11)", True),
             ("HOUSE_NUMBER", "int(11)", False),
             ("STREET_NAME", "varchar(255)", False),
             ("CITY_NAME", "varchar(255)", False),
         ],
         "GEOGRAPHIC": [
             ("CITY_NAME", "varchar(255)", True),
             ("COUNTY", "varchar(255)", False),
             ("REGION", "varchar(255)", False),
         ],
     }
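The assertion above also documents the shape that `read_dataset_schema` returns: a dict mapping table names to lists of `TableColumn(name, column_type, is_primary_key)` entries. As a rough illustration only (not the library's implementation; the CSV header names `Table`, `Field`, `Type` and `Primary Key`, and the `"y"` primary-key flag, are assumptions about the fixture format), an equivalent dictionary could be built by hand like this:

import csv
from collections import defaultdict
from typing import Dict, List, NamedTuple

class TableColumn(NamedTuple):
    name: str
    column_type: str
    is_primary_key: bool

def read_schema_sketch(path: str) -> Dict[str, List[TableColumn]]:
    # Group one TableColumn per CSV row under its table name.
    schema: Dict[str, List[TableColumn]] = defaultdict(list)
    with open(path, newline="") as csv_file:
        for row in csv.DictReader(csv_file):
            schema[row["Table"].strip()].append(
                TableColumn(row["Field"].strip(),
                            row["Type"].strip(),
                            # Assumed convention: a "y" flag marks the primary key.
                            row["Primary Key"].strip() == "y"))
    return dict(schema)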
    def _read(self, file_path: str):
        """
        This dataset reader consumes the data from
        https://github.com/jkkummerfeld/text2sql-data/tree/master/data
        formatted using ``scripts/reformat_text2sql_data.py``.

        Parameters
        ----------
        file_path : ``str``, required.
            For this dataset reader, file_path can either be a path to a file `or` a
            path to a directory containing json files. This is because some of the
            text2sql datasets require cross validation, which means they are split up
            into many small files, from which you may want to exclude only one.
        """
        files = [
            p for p in glob.glob(file_path) if
            self._cross_validation_split_to_exclude not in os.path.basename(p)
        ]
        schema = read_dataset_schema(self._schema_path)

        for path in files:
            with open(cached_path(path), "r") as data_file:
                data = json.load(data_file)

            for sql_data in text2sql_utils.process_sql_data(
                    data,
                    use_all_sql=self._use_all_sql,
                    remove_unneeded_aliases=self._remove_unneeded_aliases,
                    schema=schema):
                linked_entities = sql_data.sql_variables if self._use_prelinked_entities else None
                instance = self.text_to_instance(sql_data.text_with_variables,
                                                 linked_entities, sql_data.sql)
                if instance is not None:
                    yield instance
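The file filtering above is a plain substring test against the basename of each globbed path. A minimal sketch of that behaviour, assuming the `split_{n}.json` layout described in a later example's docstring and hypothetical paths:

import glob
import os

# Hypothetical layout: data/restaurants/split_0.json ... split_4.json
file_path = "data/restaurants/split_*.json"
cross_validation_split_to_exclude = str(3)

files = [
    p for p in glob.glob(file_path)
    if cross_validation_split_to_exclude not in os.path.basename(p)
]
# `files` now contains every split except split_3.json, which is held out.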
 def __init__(self,
              schema_path: str = None) -> None:
     self.grammar_dictionary = deepcopy(GRAMMAR_DICTIONARY)
     schema = read_dataset_schema(schema_path)
     self.all_tables = {k: [x[0] for x in v] for k, v in schema.items()}
     self.grammar_str: str = self.initialize_grammar_str()
     self.grammar: Grammar = Grammar(self.grammar_str)
     self.valid_actions: Dict[str, List[str]] = initialize_valid_actions(self.grammar)
    def test_resolve_primary_keys_in_schema(self):
        schema = text2sql_utils.read_dataset_schema(self.FIXTURES_ROOT / 'data' / 'text2sql' / 'restaurants-schema.csv')
        sql = ['SELECT', 'COUNT', '(', '*', ')', 'FROM', 'MAX', '(', 'LOCATION', '.', 'ID', ')', 'AS', 'LOCATIONalias0', ";"]

        resolved = text2sql_utils.resolve_primary_keys_in_schema(sql, schema)
        assert resolved == ['SELECT', 'COUNT', '(', '*', ')', 'FROM', 'MAX', '(', 'LOCATION', '.', 'RESTAURANT_ID', ')', 'AS', 'LOCATIONalias0', ";"]
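The test shows `resolve_primary_keys_in_schema` rewriting the placeholder column `ID` after a table reference into that table's primary-key column. A hedged sketch of the idea, as an illustrative re-implementation rather than the library's code, using the schema dictionary built in the tests above:

from typing import Dict, List

def resolve_primary_keys_sketch(sql_tokens: List[str],
                                schema: Dict[str, list]) -> List[str]:
    # Map each table to the name of its primary-key column.
    primary_keys = {
        table: next(col.name for col in columns if col.is_primary_key)
        for table, columns in schema.items()
    }
    resolved = []
    for i, token in enumerate(sql_tokens):
        # Rewrite `<TABLE> . ID` as `<TABLE> . <primary key of TABLE>`.
        if (token == "ID" and i >= 2 and sql_tokens[i - 1] == "."
                and sql_tokens[i - 2] in primary_keys):
            resolved.append(primary_keys[sql_tokens[i - 2]])
        else:
            resolved.append(token)
    return resolved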
Example #5
    def __init__(self,
                 schema_path: str,
                 cursor: Cursor = None,
                 use_prelinked_entities: bool = True) -> None:
        self.cursor = cursor
        self.schema = read_dataset_schema(schema_path)
        self.use_prelinked_entities = use_prelinked_entities

        # NOTE: This base dictionary should not be modified.
        self.base_grammar_dictionary = self._initialize_grammar_dictionary(
            deepcopy(GRAMMAR_DICTIONARY))
Example #6
    def __init__(self,
                 schema_path: str,
                 cursor: Cursor = None,
                 use_prelinked_entities: bool = True) -> None:
        self.cursor = cursor
        self.schema = read_dataset_schema(schema_path)
        self.columns = {column.name: column for table in self.schema.values() for column in table}
        self.dataset_name = os.path.basename(schema_path).split("-")[0]
        self.use_prelinked_entities = use_prelinked_entities

        # NOTE: This base dictionary should not be modified.
        self.base_grammar_dictionary = self._initialize_grammar_dictionary(deepcopy(GRAMMAR_DICTIONARY))
Example #7
    def __init__(self,
                 schema_path: str,
                 cursor: Cursor = None,
                 use_prelinked_entities: bool = True,
                 variable_free: bool = True) -> None:
        self.cursor = cursor
        self.schema = read_dataset_schema(schema_path)
        self.dataset_name = os.path.basename(schema_path).split("-")[0]
        self.use_prelinked_entities = use_prelinked_entities
        self.variable_free = variable_free

        # NOTE: This base dictionary should not be modified.
        self.base_grammar_dictionary = self._initialize_grammar_dictionary(deepcopy(GRAMMAR_DICTIONARY))
Example #8
    def __init__(self,
                 schema_path: str,
                 cursor: Cursor = None,
                 use_prelinked_entities: bool = True,
                 variable_free: bool = True,
                 use_untyped_entities: bool = False) -> None:
        self.cursor = cursor
        self.schema = read_dataset_schema(schema_path)
        self.columns = {column.name: column for table in self.schema.values() for column in table}
        self.dataset_name = os.path.basename(schema_path).split("-")[0]
        self.use_prelinked_entities = use_prelinked_entities
        self.variable_free = variable_free
        self.use_untyped_entities = use_untyped_entities

        # NOTE: This base dictionary should not be modified.
        self.base_grammar_dictionary = self._initialize_grammar_dictionary(deepcopy(GRAMMAR_DICTIONARY))
Example #9
 def test_read_database_schema(self):
     schema = text2sql_utils.read_dataset_schema(self.FIXTURES_ROOT /
                                                 'data' / 'text2sql' /
                                                 'restaurants-schema.csv')
     assert schema == {
         'RESTAURANT': [('RESTAURANT_ID', 'int(11)'),
                        ('NAME', 'varchar(255)'),
                        ('FOOD_TYPE', 'varchar(255)'),
                        ('CITY_NAME', 'varchar(255)'),
                        ('RATING', '"decimal(1')],
         'LOCATION': [('RESTAURANT_ID', 'int(11)'),
                      ('HOUSE_NUMBER', 'int(11)'),
                      ('STREET_NAME', 'varchar(255)'),
                      ('CITY_NAME', 'varchar(255)')],
         'GEOGRAPHIC': [('CITY_NAME', 'varchar(255)'),
                        ('COUNTY', 'varchar(255)'),
                        ('REGION', 'varchar(255)')]
     }
    def test_resolve_primary_keys_in_schema(self):
        schema = text2sql_utils.read_dataset_schema(self.FIXTURES_ROOT /
                                                    "data" / "text2sql" /
                                                    "restaurants-schema.csv")
        sql = [
            "SELECT",
            "COUNT",
            "(",
            "*",
            ")",
            "FROM",
            "MAX",
            "(",
            "LOCATION",
            ".",
            "ID",
            ")",
            "AS",
            "LOCATIONalias0",
            ";",
        ]

        resolved = text2sql_utils.resolve_primary_keys_in_schema(sql, schema)
        assert resolved == [
            "SELECT",
            "COUNT",
            "(",
            "*",
            ")",
            "FROM",
            "MAX",
            "(",
            "LOCATION",
            ".",
            "RESTAURANT_ID",
            ")",
            "AS",
            "LOCATIONalias0",
            ";",
        ]
Example #11
 def test_read_database_schema(self):
     schema = text2sql_utils.read_dataset_schema(self.FIXTURES_ROOT / 'data' / 'text2sql' / 'restaurants-schema.csv')
     # Make it easier to compare:
     schema = {k: [(x.name, x.column_type, x.is_primary_key) for x in v]
               for k, v in schema.items()}
     assert schema == {
             'RESTAURANT': [
                     ('RESTAURANT_ID', 'int(11)', True),
                     ('NAME', 'varchar(255)', False),
                     ('FOOD_TYPE', 'varchar(255)', False),
                     ('CITY_NAME', 'varchar(255)', False),
                     ('RATING', '"decimal(1', False)
             ],
             'LOCATION': [
                     ('RESTAURANT_ID', 'int(11)', True),
                     ('HOUSE_NUMBER', 'int(11)', False),
                     ('STREET_NAME', 'varchar(255)', False),
                     ('CITY_NAME', 'varchar(255)', False)
             ],
             'GEOGRAPHIC': [
                     ('CITY_NAME', 'varchar(255)', True),
                     ('COUNTY', 'varchar(255)', False),
                     ('REGION', 'varchar(255)', False)]
             }
Example #12
    def test_resolve_primary_keys_in_schema(self):
        schema = text2sql_utils.read_dataset_schema(self.FIXTURES_ROOT / 'data' / 'text2sql' / 'restaurants-schema.csv')
        sql = ['SELECT', 'COUNT', '(', '*', ')', 'FROM', 'MAX', '(', 'LOCATION', '.', 'ID', ')', 'AS', 'LOCATIONalias0', ";"]

        resolved = text2sql_utils.resolve_primary_keys_in_schema(sql, schema)
        assert resolved == ['SELECT', 'COUNT', '(', '*', ')', 'FROM', 'MAX', '(', 'LOCATION', '.', 'RESTAURANT_ID', ')', 'AS', 'LOCATIONalias0', ";"]
Example #13
    def __init__(self,
                 schema_path: str,
                 database_file: str = None,
                 use_all_sql: bool = False,
                 use_all_queries: bool = True,
                 remove_unneeded_aliases: bool = False,
                 use_prelinked_entities: bool = True,
                 use_untyped_entities: bool = True,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 cross_validation_split_to_exclude: int = None,
                 lazy: bool = False,
                 load_cache: bool = True,
                 save_cache: bool = True,
                 loading_limit: int = -1) -> None:
        """
        :param schema_path: str path to csv file that describes the database schema
        :param database_file: str optional, used to load values from the database as terminal rules. unused in case of anonymization
        :param use_all_sql: bool if false, for each english query only one example is read, using the first SQL program.
        :param use_all_queries: bool if false, read only one example out of all the examples with the same SQL
        :param remove_unneeded_aliases: bool not supported, default False
        :param use_prelinked_entities: bool not supported, default True
        :param use_untyped_entities: bool not supported, default True
        :param token_indexers: Dict[str, TokenIndexer], optional (default=``{"tokens": SingleIdTokenIndexer()}``)
        We use this to define the input representation for the text.  See :class:`TokenIndexer`.
        Note that the `output` tags will always correspond to single token IDs based on how they
        are pre-tokenised in the data file.
        :param cross_validation_split_to_exclude: int, optional (default = None)
        Some of the text2sql datasets are very small, so you may need to do cross validation.
        Here, you can specify a integer corresponding to a split_{int}.json file not to include
        in the training set.
        :param load_cache: bool if true, loads dataset from cahce
        :param save_cache: bool if true, saves dataset to cache
        :param loading_limit: int if larger than -1, read only loading_limit examples (for debug)
        """
        super().__init__(lazy)

        self._load_cache = load_cache
        self._save_cache = save_cache
        self._loading_limit = loading_limit

        self._token_indexers = token_indexers or {
            'tokens': SingleIdTokenIndexer()
        }
        self._use_all_sql = use_all_sql
        self._remove_unneeded_aliases = remove_unneeded_aliases
        self._use_prelinked_entities = use_prelinked_entities
        self._use_all_queries = use_all_queries

        if not self._use_prelinked_entities:
            raise ConfigurationError(
                "The grammar based text2sql dataset reader "
                "currently requires the use of entity pre-linking.")

        self._cross_validation_split_to_exclude = str(
            cross_validation_split_to_exclude)

        if database_file:
            try:
                database_file = cached_path(database_file)
                connection = sqlite3.connect(database_file)
                self._cursor = connection.cursor()
            except FileNotFoundError as e:
                self._cursor = None
        else:
            self._cursor = None

        self._schema_path = schema_path
        self._schema = read_dataset_schema(self._schema_path)
        self._world = AttnSupGrammarBasedWorld(
            schema_path,
            self._cursor,
            use_prelinked_entities=use_prelinked_entities,
            use_untyped_entities=use_untyped_entities)
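For orientation, a hedged usage sketch of the reader whose constructor is shown above. The class name `AttnSupGrammarBasedDatasetReader` and the file paths are hypothetical; the keyword arguments and the `split_{n}.json` convention come from the docstring:

# Hypothetical class name and paths, shown only to illustrate the constructor arguments.
reader = AttnSupGrammarBasedDatasetReader(
    schema_path="data/restaurants-schema.csv",
    database_file="data/restaurants.db",
    use_all_sql=False,
    cross_validation_split_to_exclude=3,
)

# `file_path` can be a glob over the per-split json files; split_3.json is skipped.
for instance in reader.read("data/restaurants/split_*.json"):
    ...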