Exemplo n.º 1
0
    def test_clean_unneeded_aliases(self):
        """Exercise ``clean_unneeded_aliases`` on three tokenized inputs.

        Covers a query whose table aliases are expected to be replaced by the
        table names, a lone decimal literal, and a query whose alias is
        attached to a function call and is expected to be left untouched.
        """
        # Tokens are whitespace-separated, so ``split`` yields the token list.
        aliased_tokens = (
            "SELECT COUNT ( * ) FROM LOCATION AS LOCATIONalias0 , "
            "RESTAURANT AS RESTAURANTalias0 "
            "WHERE LOCATIONalias0 . CITY_NAME = 'city_name0' "
            "AND RESTAURANTalias0 . ID = LOCATIONalias0 . RESTAURANT_ID "
            "AND RESTAURANTalias0 . NAME = 'name0' ;"
        ).split()
        expected_tokens = (
            "SELECT COUNT ( * ) FROM LOCATION , RESTAURANT "
            "WHERE LOCATION . CITY_NAME = 'city_name0' "
            "AND RESTAURANT . ID = LOCATION . RESTAURANT_ID "
            "AND RESTAURANT . NAME = 'name0' ;"
        ).split()

        result = text2sql_utils.clean_unneeded_aliases(aliased_tokens)
        assert result == expected_tokens

        # Check we don't mangle decimal numbers:
        decimal_only = ["2.5"]
        assert text2sql_utils.clean_unneeded_aliases(decimal_only) == ["2.5"]

        # Check we don't remove non-trivial aliases:
        function_alias_tokens = (
            "SELECT COUNT ( * ) FROM MAX ( LOCATION . ID ) AS LOCATIONalias0 ;"
        ).split()
        unchanged = text2sql_utils.clean_unneeded_aliases(function_alias_tokens)
        assert unchanged == function_alias_tokens
Exemplo n.º 2
0
    def test_clean_unneeded_aliases(self):
        """Verify alias cleaning on an aliased join, a decimal and a function alias."""
        # ``str.split()`` with no arguments splits on any whitespace run,
        # including newlines, so each word below becomes one token.
        query_with_aliases = """
            SELECT COUNT ( * )
            FROM LOCATION AS LOCATIONalias0 , RESTAURANT AS RESTAURANTalias0
            WHERE LOCATIONalias0 . CITY_NAME = 'city_name0'
            AND RESTAURANTalias0 . ID = LOCATIONalias0 . RESTAURANT_ID
            AND RESTAURANTalias0 . NAME = 'name0' ;
        """.split()
        query_without_aliases = """
            SELECT COUNT ( * )
            FROM LOCATION , RESTAURANT
            WHERE LOCATION . CITY_NAME = 'city_name0'
            AND RESTAURANT . ID = LOCATION . RESTAURANT_ID
            AND RESTAURANT . NAME = 'name0' ;
        """.split()

        assert text2sql_utils.clean_unneeded_aliases(query_with_aliases) == query_without_aliases

        # Check we don't mangle decimal numbers:
        assert text2sql_utils.clean_unneeded_aliases(["2.5"]) == ["2.5"]

        # Check we don't remove non-trivial aliases:
        aliased_function_call = "SELECT COUNT ( * ) FROM MAX ( LOCATION . ID ) AS LOCATIONalias0 ;".split()
        assert text2sql_utils.clean_unneeded_aliases(aliased_function_call) == aliased_function_call
    def text_to_instance(
            self,
            source_string: str,
            target_string: str = None,
            spans: List[Tuple[int, int]] = None) -> Instance:  # type: ignore
        # pylint: disable=arguments-differ
        """Build an ``Instance`` from a source string, optional target and spans.

        The source is tokenized with ``self._source_tokenizer``; a start token
        is prepended only when ``self._source_add_start_token`` is set, and an
        end token is always appended. ``spans`` is passed through
        ``self._fix_spans_coverage`` together with the final source length, and
        each resulting ``(start, end)`` pair becomes a ``SpanField`` over the
        source field.

        When ``target_string`` is given it is optionally rewritten by
        ``sql_schema_sanitize`` (if ``self._schema_free_supervision``),
        tokenized, optionally passed through ``clean_unneeded_aliases``
        (if ``self._remove_unneeded_aliases``) and wrapped in start/end tokens.

        Returns an ``Instance`` with ``source_tokens`` and ``spans``, plus
        ``target_tokens`` when a target was supplied.
        """
        source_tokens = self._source_tokenizer.tokenize(source_string)
        if self._source_add_start_token:
            source_tokens.insert(0, Token(START_SYMBOL))
        source_tokens.append(Token(END_SYMBOL))
        source_field = TextField(source_tokens, self._source_token_indexers)

        # The length handed over already includes the start/end symbols.
        covered_spans = self._fix_spans_coverage(spans, len(source_tokens))
        span_fields: List[Field] = [
            SpanField(start, end, source_field) for start, end in covered_spans
        ]
        fields = {
            "source_tokens": source_field,
            "spans": ListField(span_fields),
        }

        if target_string is None:
            return Instance(fields)

        if self._schema_free_supervision:
            # Only the third element of the sanitizer's result is used.
            sanitized = sql_schema_sanitize(
                target_string,
                text2sql_utils.read_schema_dict(self._schema_path))
            target_string = sanitized[2]

        target_tokens = self._target_tokenizer.tokenize(target_string)
        if self._remove_unneeded_aliases:
            token_texts = [token.text for token in target_tokens]
            target_tokens = [
                Token(text) for text in tu.clean_unneeded_aliases(token_texts)
            ]
        target_tokens.insert(0, Token(START_SYMBOL))
        target_tokens.append(Token(END_SYMBOL))

        fields["target_tokens"] = TextField(target_tokens,
                                            self._target_token_indexers)
        return Instance(fields)
    def test_clean_unneeded_aliases(self):
        """Check ``clean_unneeded_aliases`` against three tokenized inputs.

        The token lists are laid out one SQL clause per line; the leading
        ``SELECT COUNT ( * )`` tokens are shared by every query in this test.
        """
        select_count = ["SELECT", "COUNT", "(", "*", ")"]

        aliased_tokens = select_count + [
            "FROM", "LOCATION", "AS", "LOCATIONalias0", ",",
            "RESTAURANT", "AS", "RESTAURANTalias0",
            "WHERE", "LOCATIONalias0", ".", "CITY_NAME", "=", "'city_name0'",
            "AND", "RESTAURANTalias0", ".", "ID", "=",
            "LOCATIONalias0", ".", "RESTAURANT_ID",
            "AND", "RESTAURANTalias0", ".", "NAME", "=", "'name0'",
            ";",
        ]
        expected_tokens = select_count + [
            "FROM", "LOCATION", ",", "RESTAURANT",
            "WHERE", "LOCATION", ".", "CITY_NAME", "=", "'city_name0'",
            "AND", "RESTAURANT", ".", "ID", "=",
            "LOCATION", ".", "RESTAURANT_ID",
            "AND", "RESTAURANT", ".", "NAME", "=", "'name0'",
            ";",
        ]

        actual_tokens = text2sql_utils.clean_unneeded_aliases(aliased_tokens)
        assert actual_tokens == expected_tokens

        # Check we don't mangle decimal numbers:
        assert text2sql_utils.clean_unneeded_aliases(["2.5"]) == ["2.5"]

        # Check we don't remove non-trivial aliases:
        function_alias_tokens = select_count + [
            "FROM", "MAX", "(", "LOCATION", ".", "ID", ")",
            "AS", "LOCATIONalias0",
            ";",
        ]
        assert text2sql_utils.clean_unneeded_aliases(function_alias_tokens) == function_alias_tokens