Example no. 1
 def test_repr(self):
     token = sql.Token(T.Keyword, 'foo')
     tst = "<Keyword 'foo' at 0x"
     self.assertEqual(repr(token)[:len(tst)], tst)
     token = sql.Token(T.Keyword, '1234567890')
     tst = "<Keyword '123456...' at 0x"
     self.assertEqual(repr(token)[:len(tst)], tst)
Example no. 2
def test_token_repr():
    token = sql.Token(T.Keyword, 'foo')
    tst = "<Keyword 'foo' at 0x"
    assert repr(token)[:len(tst)] == tst
    token = sql.Token(T.Keyword, '1234567890')
    tst = "<Keyword '123456...' at 0x"
    assert repr(token)[:len(tst)] == tst
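The two tests above check the same behavior in unittest and pytest style: a Token's repr shows its type and a shortened value. A minimal standalone sketch (assuming sqlparse is installed):

from sqlparse import sql
from sqlparse import tokens as T

token = sql.Token(T.Keyword, 'foo')
print(repr(token))    # e.g. <Keyword 'foo' at 0x7f...>
print(token.ttype)    # Token.Keyword
print(token.value)    # foo

token = sql.Token(T.Keyword, '1234567890')
print(repr(token))    # long values are shortened, e.g. <Keyword '123456...' at 0x7f...>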
Example no. 3
 def _replace_table_entity_name(self,
                                parent,
                                token,
                                table_name,
                                entity_name=None):
     if not entity_name:
         entity_name = table_name
     next_token = self._token_next(parent, token)
     if not table_name in self._skip_tables + self._translate_tables:
         token_to_replace = parent.tokens[self._token_idx(parent, token)]
         if isinstance(token_to_replace, Types.Function):
             t = self._token_first(token_to_replace)
             if isinstance(t, Types.Identifier):
                 token_to_replace.tokens[self._token_idx(
                     token_to_replace, t)] = Types.Token(
                         Tokens.Keyword,
                         self._prefixed_table_entity_name(entity_name))
         elif isinstance(token_to_replace, Types.Identifier) or isinstance(
                 token_to_replace, Types.Token):
             parent.tokens[self._token_idx(
                 parent, token_to_replace)] = Types.Token(
                     Tokens.Keyword,
                     self._prefixed_table_entity_name(entity_name))
         else:
             raise Exception(
                 "Internal error, invalid table entity token type")
     return next_token
Example no. 4
    def _process_identifierlist(self, tlist):
        identifiers = list(tlist.get_identifiers())
        if self.indent_columns:
            first = next(identifiers[0].flatten())
            num_offset = 1 if self.char == '\t' else self.width
        else:
            first = next(identifiers.pop(0).flatten())
            num_offset = 1 if self.char == '\t' else self._get_offset(first)

        if not tlist.within(sql.Function):
            with offset(self, num_offset):
                position = 0
                for token in identifiers:
                    # Add 1 for the "," separator
                    position += len(token.value) + 1
                    if position > (self.wrap_after - self.offset):
                        adjust = 0
                        if self.comma_first:
                            adjust = -2
                            _, comma = tlist.token_prev(
                                tlist.token_index(token))
                            if comma is None:
                                continue
                            token = comma
                        tlist.insert_before(token, self.nl(offset=adjust))
                        if self.comma_first:
                            _, ws = tlist.token_next(tlist.token_index(token),
                                                     skip_ws=False)
                            if (ws is not None
                                    and ws.ttype is not T.Text.Whitespace):
                                tlist.insert_after(
                                    token, sql.Token(T.Whitespace, ' '))
                        position = 0
        else:
            # ensure whitespace
            for token in tlist:
                _, next_ws = tlist.token_next(tlist.token_index(token),
                                              skip_ws=False)
                if token.value == ',' and not next_ws.is_whitespace:
                    tlist.insert_after(token, sql.Token(T.Whitespace, ' '))

            end_at = self.offset + sum(len(i.value) + 1 for i in identifiers)
            adjusted_offset = 0
            if (self.wrap_after > 0 and end_at >
                (self.wrap_after - self.offset) and self._last_func):
                adjusted_offset = -len(self._last_func.value) - 1

            with offset(self, adjusted_offset), indent(self):
                if adjusted_offset < 0:
                    tlist.insert_before(identifiers[0], self.nl())
                position = 0
                for token in identifiers:
                    # Add 1 for the "," separator
                    position += len(token.value) + 1
                    if (self.wrap_after > 0 and position >
                        (self.wrap_after - self.offset)):
                        adjust = 0
                        tlist.insert_before(token, self.nl(offset=adjust))
                        position = 0
        self._process_default(tlist)
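This method appears to be sqlparse's ReindentFilter handling of identifier lists. A short usage sketch through the public API (the expected layout below assumes default options and may differ slightly between versions):

import sqlparse

print(sqlparse.format('select a, b, c from t', reindent=True))
# select a,
#        b,
#        c
# from t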
Example no. 5
 def _update_delete_where_limit(self, table_name, parent, start_token):
     if not start_token:
         return
     where_token = start_token if isinstance(start_token, Types.Where) \
                               else self._token_next_by_instance(parent, start_token, Types.Where)
     if where_token:
         self._where(parent, where_token)
     if not table_name in self._translate_tables:
         return
     if where_token:
         keywords = [self._product_column, '=', "'", self._product_prefix, "'", ' ', 'AND', ' ']
         keywords.reverse()
         token = self._token_first(where_token)
         if not token.match(Tokens.Keyword, 'WHERE'):
             token = self._token_next_match(where_token, token, Tokens.Keyword, 'WHERE')
         if not token:
             raise Exception("Invalid UPDATE statement, failed to parse WHERE")
         for keyword in keywords:
             self._token_insert_after(where_token, token, Types.Token(Tokens.Keyword, keyword))
     else:
         keywords = ['WHERE', ' ', self._product_column, '=', "'", self._product_prefix, "'"]
         limit_token = self._token_next_match(parent, start_token, Tokens.Keyword, 'LIMIT')
         if limit_token:
             for keyword in keywords:
                 self._token_insert_before(parent, limit_token, Types.Token(Tokens.Keyword, keyword))
             self._token_insert_before(parent, limit_token, Types.Token(Tokens.Keyword, ' '))
         else:
             last_token = token = start_token
             while token:
                 last_token = token
                 token = self._token_next(parent, token)
             keywords.reverse()
             for keyword in keywords:
                 self._token_insert_after(parent, last_token, Types.Token(Tokens.Keyword, keyword))
     return
Example no. 6
def test_tokenlist_token_matching():
    t1 = sql.Token(T.Keyword, 'foo')
    t2 = sql.Token(T.Punctuation, ',')
    x = sql.TokenList([t1, t2])
    assert x.token_matching([lambda t: t.ttype is T.Keyword], 0) == t1
    assert x.token_matching([lambda t: t.ttype is T.Punctuation], 0) == t2
    assert x.token_matching([lambda t: t.ttype is T.Keyword], 1) is None
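A small sketch of token_matching on a parsed statement, using the (funcs, idx) argument order shown in this test; older releases take (idx, funcs) instead, as in example no. 10 below:

import sqlparse
from sqlparse import tokens as T

stmt = sqlparse.parse('select * from foo')[0]
kw = stmt.token_matching([lambda t: t.ttype is T.DML], 0)
print(kw)    # select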
Example no. 7
    def _process(tlist):
        def get_next_comment():
            # TODO(andi) Comment types should be unified, see related issue38
            return tlist.token_next_by(i=sql.Comment, t=T.Comment)

        tidx, token = get_next_comment()
        while token:
            pidx, prev_ = tlist.token_prev(tidx, skip_ws=False)
            nidx, next_ = tlist.token_next(tidx, skip_ws=False)
            # Replace by whitespace if prev and next exist and if they're not
            # whitespaces. This doesn't apply if prev or next is a parenthesis.
            if (prev_ is None or next_ is None or prev_.is_whitespace
                    or prev_.match(T.Punctuation, '(') or next_.is_whitespace
                    or next_.match(T.Punctuation, ')')):
                # Insert a whitespace to ensure the following SQL produces
                # a valid SQL (see #425). For example:
                #
                # Before: select a--comment\nfrom foo
                # After: select a from foo
                if prev_ is not None and next_ is None:
                    tlist.tokens.insert(tidx, sql.Token(T.Whitespace, ' '))
                tlist.tokens.remove(token)
            else:
                tlist.tokens[tidx] = sql.Token(T.Whitespace, ' ')

            tidx, token = get_next_comment()
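This _process pass looks like sqlparse's StripCommentsFilter. Through the public API the same behavior is reached with strip_comments, matching the before/after pair in the comment above:

import sqlparse

print(sqlparse.format('select a--comment\nfrom foo', strip_comments=True))
# select a from foo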
Example no. 8
 def _get_insert_token(token):
     """Returns either a whitespace or the line breaks from token."""
     # See issue484 why line breaks should be preserved.
     m = re.search(r'((\r\n|\r|\n)+) *$', token.value)
     if m is not None:
         return sql.Token(T.Whitespace.Newline, m.groups()[0])
     else:
         return sql.Token(T.Whitespace, ' ')
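A quick check of the helper above (the driver lines are illustrative, assuming _get_insert_token is in scope with re, sql and T imported): a value that ends in line breaks keeps them, anything else collapses to a single space.

from sqlparse import sql
from sqlparse import tokens as T

print(repr(_get_insert_token(sql.Token(T.Whitespace, '\n  ')).value))   # '\n'
print(repr(_get_insert_token(sql.Token(T.Keyword, 'foo')).value))       # ' '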
Example no. 9
def test_issue212_py2unicode():
    if sys.version_info < (3, ):
        t1 = sql.Token(T.String, u"schöner ")
    else:
        t1 = sql.Token(T.String, "schöner ")
    t2 = sql.Token(T.String, u"bug")
    l = sql.TokenList([t1, t2])
    assert str(l) == 'schöner bug'
Example no. 10
 def test_token_matching(self):
     t1 = sql.Token(Keyword, 'foo')
     t2 = sql.Token(Punctuation, ',')
     x = sql.TokenList([t1, t2])
     self.assertEqual(x.token_matching(0, [lambda t: t.ttype is Keyword]),
                      t1)
     self.assertEqual(
         x.token_matching(0, [lambda t: t.ttype is Punctuation]), t2)
     self.assertEqual(x.token_matching(1, [lambda t: t.ttype is Keyword]),
                      None)
Example no. 11
 def _generate_attribute_token(self,
                               attr_name_token,
                               attr_value_token=None):
     attribute_token_list = [
         sql.Token(value=attr_name_token.value, ttype=T.Keyword)
     ]
     if attr_value_token is not None:
         attribute_token_list.append(
             sql.Token(value=attr_value_token.value.strip('`"\''),
                       ttype=attr_value_token.ttype))
     return sql.Attribute(tokens=attribute_token_list)
Example no. 12
 def _get_insert_token(token):
     """Returns either a whitespace or the line breaks from token."""
     # See issue484 why line breaks should be preserved.
     # Note: The actual value for a line break is replaced by \n
     # in SerializerUnicode which will be executed in the
     # postprocessing state.
     m = re.search(r'((\r|\n)+) *$', token.value)
     if m is not None:
         return sql.Token(T.Whitespace.Newline, m.groups()[0])
     else:
         return sql.Token(T.Whitespace, ' ')
Example no. 13
    def _process(tlist):
        def next_token(idx=0):
            return tlist.token_next_by(t=(T.Operator, T.Comparison), idx=idx)

        token = next_token()
        while token:
            prev_ = tlist.token_prev(token, skip_ws=False)
            if prev_ and prev_.ttype != T.Whitespace:
                tlist.insert_before(token, sql.Token(T.Whitespace, ' '))

            next_ = tlist.token_next(token, skip_ws=False)
            if next_ and next_.ttype != T.Whitespace:
                tlist.insert_after(token, sql.Token(T.Whitespace, ' '))

            token = next_token(idx=token)
Example no. 14
def build_comparison(identifier, operator):
    '''
    Build an SQL token representing a comparison ``identifier operator %s``
    
    :param identifier: An identifier previously found by calling ``find_identifier``
    :param operator: An operator like =, <, >=, ...
    '''

    return sql.Comparison([
        bare_identifier(identifier),
        sql.Token(tokens.Whitespace, " "),
        sql.Token(tokens.Comparison, operator),
        sql.Token(tokens.Whitespace, " "),
        sql.Token(tokens.Wildcard, "%s")
    ])
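A self-contained sketch of the same shape, without the project-specific bare_identifier helper (the column name here is illustrative):

from sqlparse import sql, tokens

comparison = sql.Comparison([
    sql.Identifier([sql.Token(tokens.Name, 'name')]),
    sql.Token(tokens.Whitespace, ' '),
    sql.Token(tokens.Comparison, '='),
    sql.Token(tokens.Whitespace, ' '),
    sql.Token(tokens.Wildcard, '%s'),
])
print(comparison)    # name = %s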
Example no. 15
    def nl(self, offset=1):
        # offset = 1 represents a single space after SELECT
        offset = -len(offset) if not isinstance(offset, int) else offset
        # add two for the space and parens

        return sql.Token(T.Whitespace,
                         self.n + self.char * (self.leading_ws + offset))
Example no. 16
    def _process(tlist):

        ttypes = (T.Operator, T.Comparison)
        tidx, token = tlist.token_next_by(t=ttypes)
        while token:
            nidx, next_ = tlist.token_next(tidx, skip_ws=False)
            if next_ and next_.ttype != T.Whitespace:
                tlist.insert_after(tidx, sql.Token(T.Whitespace, ' '))

            pidx, prev_ = tlist.token_prev(tidx, skip_ws=False)
            if prev_ and prev_.ttype != T.Whitespace:
                tlist.insert_before(tidx, sql.Token(T.Whitespace, ' '))
                tidx += 1  # has to shift since token inserted before it

            # assert tlist.token_index(token) == tidx
            tidx, token = tlist.token_next_by(t=ttypes, idx=tidx)
Example no. 17
 def insert_extra_column_value(tablename, ptoken, before_token):
     if tablename in self._translate_tables:
         for keyword in [',', "'", self._product_prefix, "'"]:
             self._token_insert_before(
                 ptoken, before_token,
                 Types.Token(Tokens.Keyword, keyword))
     return
Example no. 18
    def process(self, stream):
        """Process the stream"""
        EOS_TTYPE = T.Whitespace, T.Comment.Single

        # Run over all stream tokens
        for ttype, value in stream:
            # Yield token if we finished a statement and there's no whitespace.
            # It will count a newline token as non-whitespace. In this context
            # whitespace ignores newlines.
            # why don't multi line comments also count?
            if self.consume_ws and ttype not in EOS_TTYPE:
                yield sql.Statement(self.tokens)

                # Reset filter and prepare to process next statement
                self._reset()

            # Change current split level (increase, decrease or remain equal)
            self.level += self._change_splitlevel(ttype, value)

            # Append the token to the current statement
            self.tokens.append(sql.Token(ttype, value))

            # Check if we get the end of a statement
            if self.level <= 0 and ttype is T.Punctuation and value == ';':
                self.consume_ws = True

        # Yield pending statement (if any)
        if self.tokens and not all(t.is_whitespace for t in self.tokens):
            yield sql.Statement(self.tokens)
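This process method looks like sqlparse's statement splitter; the same machinery backs the public split() and parse() helpers:

import sqlparse

print(sqlparse.split('select 1; select 2;'))
# ['select 1;', 'select 2;']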
Example no. 19
def test_extract_value(_label, token_value, expected):
    token = token_groups.Token(token_types.Literal, token_value)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        value = common.extract_value(token)
    assert value == expected
Example no. 20
 def insert_extra_column(tablename, columns_token):
     if tablename in self._translate_tables and \
        isinstance(columns_token, Types.Parenthesis):
         ptoken = self._token_first(columns_token)
         if not ptoken.match(Tokens.Punctuation, '('):
             raise Exception(
                 "Invalid INSERT statement, expected parenthesis around columns"
             )
         ptoken = self._token_next(columns_token, ptoken)
         last_token = ptoken
         while ptoken:
             if isinstance(ptoken, Types.IdentifierList):
                 if any(i.get_name() == 'product'
                        for i in ptoken.get_identifiers()
                        if isinstance(i, Types.Identifier)):
                     return True
             last_token = ptoken
             ptoken = self._token_next(columns_token, ptoken)
         if not last_token or \
            not last_token.match(Tokens.Punctuation, ')'):
             raise Exception(
                 "Invalid INSERT statement, unable to find column parenthesis end"
             )
         for keyword in [',', ' ', self._product_column]:
             self._token_insert_before(
                 columns_token, last_token,
                 Types.Token(Tokens.Keyword, keyword))
     return False
Example no. 21
    def _process_case(self, tlist):
        offset_ = len('case ') + len('when ')
        cases = tlist.get_cases(skip_ws=True)
        # align the end as well
        end_token = tlist.token_next_by(m=(T.Keyword, 'END'))[1]
        cases.append((None, [end_token]))

        condition_width = [
            len(' '.join(map(text_type, cond))) if cond else 0
            for cond, _ in cases
        ]
        max_cond_width = max(condition_width)

        for i, (cond, value) in enumerate(cases):
            # cond is None when 'else or end'
            stmt = cond[0] if cond else value[0]

            if i > 0:
                tlist.insert_before(stmt,
                                    self.nl(offset_ - len(text_type(stmt))))
            if cond:
                ws = sql.Token(
                    T.Whitespace,
                    self.char * (max_cond_width - condition_width[i]))
                tlist.insert_after(cond[-1], ws)
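This CASE handling belongs to the aligned-indent filter. A usage sketch through the public API (the aligned output depends on the sqlparse version, so it is not shown verbatim):

import sqlparse

query = "select case when a = 1 then 'one' when a = 10 then 'ten' else 'other' end from t"
print(sqlparse.format(query, reindent_aligned=True))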
Example no. 22
 def _process_identifierlist(self, tlist):
     identifiers = list(tlist.get_identifiers())
     first = next(identifiers.pop(0).flatten())
     num_offset = 1 if self.char == '\t' else self._get_offset(first)
     if not tlist.within(sql.Function):
         with offset(self, num_offset):
             position = 0
             for token in identifiers:
                 # Add 1 for the "," separator
                 position += len(token.value) + 1
                 if position > (self.wrap_after - self.offset):
                     adjust = 0
                     if self.comma_first:
                         adjust = -2
                         _, comma = tlist.token_prev(
                             tlist.token_index(token))
                         if comma is None:
                             continue
                         token = comma
                     tlist.insert_before(token, self.nl(offset=adjust))
                     if self.comma_first:
                         _, ws = tlist.token_next(tlist.token_index(token),
                                                  skip_ws=False)
                         if (ws is not None
                                 and ws.ttype is not T.Text.Whitespace):
                             tlist.insert_after(
                                 token, sql.Token(T.Whitespace, ' '))
                     position = 0
     self._process_default(tlist)
Example no. 23
def _function(node, token):
    """Resolve a SQL function.

    :node: The current tree node on which the function operates.
    :token: The token representing the function.
    """
    funs = ["DP", "K_ANONIMITY", "L_DIVERSITY", "T_CLOSENESS", "TOKENIZE"]

    def _resolve_function_parameters(node, token):
        subtokens = token.tokens[1:-1]
        first = _token_first(subtokens)
        # extract columns from function parameters when present
        if first:
            if first.match(T.Keyword.DML, "SELECT"):
                child = select(subtokens)
                token.tokens = [token.tokens[0], child.state, token.tokens[-1]]
                node.childs.append(child)
            else:
                _comma_separated_list(node,
                                      subtokens,
                                      skip=_skip_modifier,
                                      item_resolver=_parameter_resolver)

    identifier = token.tokens[0]
    parenthesis = token.tokens[1]

    # sqlparse mistakes operators for functions when a space doesn't separate
    # the operator and the parentheses
    if identifier.normalized.upper() in OPERATORS:
        # convert identifier into keyword
        identifier = S.Token(T.Keyword, identifier.normalized)
        _expression(node, [identifier, parenthesis])
    # anonymization function
    elif identifier.normalized in funs:
        # keep anonymization function as we consider it part of the schema
        prev_count = len(node.state.cols)
        _resolve_function_parameters(node, parenthesis)
        next_count = len(node.state.cols)
        cols = node.state.cols[prev_count:next_count]
        node.state.cols = node.state.cols[:prev_count]  # restore state
        for col in cols:
            node.state.cols.append(
                S.Token(T.Literal,
                        col.normalized + '/' + identifier.normalized.lower()))
    # other function
    else:
        _resolve_function_parameters(node, parenthesis)
Example no. 24
    def nl(self, offset=1):
        # offset = 1 represents a single space after SELECT
        offset = -len(offset) if not isinstance(offset, int) else offset
        # add two for the space and parens
        indent = self.indent * (2 + self._max_kwd_len)

        return sql.Token(T.Whitespace, self.n + self.char * (
            self._max_kwd_len + offset + indent + self.offset))
Example no. 25
 def handle_insert_table(table_name):
     if insert_table and insert_table in self._translate_tables:
         if not field_lists or not field_lists[-1]:
             raise Exception("Invalid SELECT field list")
         last_token = list(field_lists[-1][-1].flatten())[-1]
         for keyword in ["'", self._product_prefix, "'", ' ', ',']:
             self._token_insert_after(last_token.parent, last_token, Types.Token(Tokens.Keyword, keyword))
     return
Example no. 26
    def process_functions(self, group=None):
        """Recursively parse and process FQL functions in the given group token.

    TODO: switch to sqlite3.Connection.create_function().

    Currently handles: me(), now()
    """
        if group is None:
            group = self.statement

        for tok in group.tokens:
            if isinstance(tok, sql.Function):
                assert isinstance(tok.tokens[0], sql.Identifier)
                name = tok.tokens[0].tokens[0]
                if name.value not in Fql.FUNCTIONS:
                    raise InvalidFunctionError(name.value)

                # check number of params
                #
                # i wish i could use tok.get_parameters() here, but it doesn't work
                # with string parameters for some reason. :/
                assert isinstance(tok.tokens[1], sql.Parenthesis)
                params = [
                    t for t in tok.tokens[1].flatten()
                    if t.ttype not in (tokens.Punctuation, tokens.Whitespace)
                ]
                actual_num = len(params)
                expected_num = Fql.FUNCTIONS[name.value]
                if actual_num != expected_num:
                    raise ParamMismatchError(name.value, expected_num,
                                             actual_num)

                # handle each function
                replacement = None
                if name.value == 'me':
                    replacement = str(self.me)
                elif name.value == 'now':
                    replacement = str(int(time.time()))
                elif name.value == 'strlen':
                    # pass through to sqlite's length() function
                    name.value = 'length'
                elif name.value == 'substr':
                    # the index param is 0-based in FQL but 1-based in sqlite
                    params[1].value = str(int(params[1].value) + 1)
                elif name.value == 'strpos':
                    # strip quote chars
                    string = params[0].value[1:-1]
                    sub = params[1].value[1:-1]
                    replacement = str(string.find(sub))
                else:
                    # shouldn't happen
                    assert False, 'unknown function: %s' % name.value

                if replacement is not None:
                    tok.tokens = [sql.Token(tokens.Number, replacement)]

            elif tok.is_group():
                self.process_functions(tok)
Example no. 27
def bare_identifier(identifier):
    '''
    Copy an identifier without any aliases.
    :param identifier: An identifier to copy.
    '''
    itokens = []

    if isinstance(identifier, str):
        for part in identifier.split(","):
            itokens.append(sql.Token(tokens.Literal, part))
    else:
        for token in identifier.tokens:
            if token.ttype in (tokens.Name, tokens.Punctuation):
                itokens.append(sql.Token(token.ttype, token.value))
            else:
                break

    return sql.Identifier(itokens)
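A small driver for the helper above (illustrative only): drop the alias from a parsed identifier.

import sqlparse

parsed = sqlparse.parse('t.col AS c')[0]
identifier = parsed.tokens[0]        # the whole expression groups into one sql.Identifier
print(bare_identifier(identifier))   # t.col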
Example no. 28
    def process(self, stmt):
        self._curr_stmt = stmt
        self._process(stmt)

        if self._last_stmt is not None:
            nl = "\n" if str(self._last_stmt).endswith("\n") else "\n\n"
            stmt.tokens.insert(0, sql.Token(T.Whitespace, nl))

        self._last_stmt = stmt
        return stmt
Example no. 29
 def nl(self):
     # TODO: newline character should be configurable
     space = (self.char * ((self.indent * self.width) + self.offset))
     # Detect runaway indenting due to parsing errors
     if len(space) > 200:
         # something seems to be wrong, flip back
         self.indent = self.offset = 0
         space = (self.char * ((self.indent * self.width) + self.offset))
     ws = '\n' + space
     return sql.Token(T.Whitespace, ws)
Example no. 30
    def process(self, stmt):
        self._curr_stmt = stmt
        self._process(stmt)

        if self._last_stmt is not None:
            nl = '\n' if text_type(self._last_stmt).endswith('\n') else '\n\n'
            stmt.tokens.insert(0, sql.Token(T.Whitespace, nl))

        self._last_stmt = stmt
        return stmt