def test_repr(self):
    token = sql.Token(T.Keyword, 'foo')
    tst = "<Keyword 'foo' at 0x"
    self.assertEqual(repr(token)[:len(tst)], tst)
    token = sql.Token(T.Keyword, '1234567890')
    tst = "<Keyword '123456...' at 0x"
    self.assertEqual(repr(token)[:len(tst)], tst)

def test_token_repr():
    token = sql.Token(T.Keyword, 'foo')
    tst = "<Keyword 'foo' at 0x"
    assert repr(token)[:len(tst)] == tst
    token = sql.Token(T.Keyword, '1234567890')
    tst = "<Keyword '123456...' at 0x"
    assert repr(token)[:len(tst)] == tst

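# For context, a minimal sketch (assuming the sqlparse package and its
# sql/tokens modules) of the repr behavior the two tests above pin down:
# short values print verbatim, longer ones are truncated with "...".
from sqlparse import sql, tokens as T

print(repr(sql.Token(T.Keyword, 'foo')))
# e.g. <Keyword 'foo' at 0x7F0A2A4BD2E8>
print(repr(sql.Token(T.Keyword, '1234567890')))
# e.g. <Keyword '123456...' at 0x7F0A2A4BD348>
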
def _replace_table_entity_name(self, parent, token, table_name,
                               entity_name=None):
    if not entity_name:
        entity_name = table_name
    next_token = self._token_next(parent, token)
    if table_name not in self._skip_tables + self._translate_tables:
        token_to_replace = parent.tokens[self._token_idx(parent, token)]
        if isinstance(token_to_replace, Types.Function):
            t = self._token_first(token_to_replace)
            if isinstance(t, Types.Identifier):
                token_to_replace.tokens[self._token_idx(
                    token_to_replace, t)] = Types.Token(
                        Tokens.Keyword,
                        self._prefixed_table_entity_name(entity_name))
        elif isinstance(token_to_replace, (Types.Identifier, Types.Token)):
            parent.tokens[self._token_idx(
                parent, token_to_replace)] = Types.Token(
                    Tokens.Keyword,
                    self._prefixed_table_entity_name(entity_name))
        else:
            raise Exception(
                "Internal error, invalid table entity token type")
    return next_token

def _process_identifierlist(self, tlist):
    identifiers = list(tlist.get_identifiers())
    if self.indent_columns:
        first = next(identifiers[0].flatten())
        num_offset = 1 if self.char == '\t' else self.width
    else:
        first = next(identifiers.pop(0).flatten())
        num_offset = 1 if self.char == '\t' else self._get_offset(first)

    if not tlist.within(sql.Function):
        with offset(self, num_offset):
            position = 0
            for token in identifiers:
                # Add 1 for the "," separator
                position += len(token.value) + 1
                if position > (self.wrap_after - self.offset):
                    adjust = 0
                    if self.comma_first:
                        adjust = -2
                        _, comma = tlist.token_prev(
                            tlist.token_index(token))
                        if comma is None:
                            continue
                        token = comma
                    tlist.insert_before(token, self.nl(offset=adjust))
                    if self.comma_first:
                        _, ws = tlist.token_next(
                            tlist.token_index(token), skip_ws=False)
                        if (ws is not None
                                and ws.ttype is not T.Text.Whitespace):
                            tlist.insert_after(
                                token, sql.Token(T.Whitespace, ' '))
                    position = 0
    else:
        # ensure whitespace
        for token in tlist:
            _, next_ws = tlist.token_next(
                tlist.token_index(token), skip_ws=False)
            if token.value == ',' and not next_ws.is_whitespace:
                tlist.insert_after(
                    token, sql.Token(T.Whitespace, ' '))

        end_at = self.offset + sum(len(i.value) + 1 for i in identifiers)
        adjusted_offset = 0
        if (self.wrap_after > 0
                and end_at > (self.wrap_after - self.offset)
                and self._last_func):
            adjusted_offset = -len(self._last_func.value) - 1

        with offset(self, adjusted_offset), indent(self):
            if adjusted_offset < 0:
                tlist.insert_before(identifiers[0], self.nl())

            position = 0
            for token in identifiers:
                # Add 1 for the "," separator
                position += len(token.value) + 1
                if (self.wrap_after > 0
                        and position > (self.wrap_after - self.offset)):
                    adjust = 0
                    tlist.insert_before(token, self.nl(offset=adjust))
                    position = 0

    self._process_default(tlist)

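# A hedged sketch of how this filter surfaces through the public API
# (assuming sqlparse's `reindent` and `comma_first` format options; the
# exact output may vary between versions):
import sqlparse

print(sqlparse.format('select a, some_long_column as b, c from tbl',
                      reindent=True))
# e.g.:
# select a,
#        some_long_column as b,
#        c
# from tbl
print(sqlparse.format('select a, some_long_column as b, c from tbl',
                      reindent=True, comma_first=True))
# e.g.:
# select a
#      , some_long_column as b
#      , c
# from tbl
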
def _update_delete_where_limit(self, table_name, parent, start_token):
    if not start_token:
        return
    where_token = start_token if isinstance(start_token, Types.Where) \
        else self._token_next_by_instance(parent, start_token, Types.Where)
    if where_token:
        self._where(parent, where_token)
    if table_name not in self._translate_tables:
        return
    if where_token:
        keywords = [self._product_column, '=', "'", self._product_prefix,
                    "'", ' ', 'AND', ' ']
        keywords.reverse()
        token = self._token_first(where_token)
        if not token.match(Tokens.Keyword, 'WHERE'):
            token = self._token_next_match(where_token, token,
                                           Tokens.Keyword, 'WHERE')
        if not token:
            raise Exception(
                "Invalid UPDATE statement, failed to parse WHERE")
        for keyword in keywords:
            self._token_insert_after(
                where_token, token, Types.Token(Tokens.Keyword, keyword))
    else:
        keywords = ['WHERE', ' ', self._product_column, '=', "'",
                    self._product_prefix, "'"]
        limit_token = self._token_next_match(parent, start_token,
                                             Tokens.Keyword, 'LIMIT')
        if limit_token:
            for keyword in keywords:
                self._token_insert_before(
                    parent, limit_token,
                    Types.Token(Tokens.Keyword, keyword))
            self._token_insert_before(parent, limit_token,
                                      Types.Token(Tokens.Keyword, ' '))
        else:
            last_token = token = start_token
            while token:
                last_token = token
                token = self._token_next(parent, token)
            keywords.reverse()
            for keyword in keywords:
                self._token_insert_after(
                    parent, last_token,
                    Types.Token(Tokens.Keyword, keyword))
    return

def test_tokenlist_token_matching():
    t1 = sql.Token(T.Keyword, 'foo')
    t2 = sql.Token(T.Punctuation, ',')
    x = sql.TokenList([t1, t2])
    assert x.token_matching([lambda t: t.ttype is T.Keyword], 0) == t1
    assert x.token_matching([lambda t: t.ttype is T.Punctuation], 0) == t2
    assert x.token_matching([lambda t: t.ttype is T.Keyword], 1) is None

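# A small usage sketch (assuming sqlparse) of token_matching on a parsed
# statement rather than a hand-built TokenList, using the newer
# (funcs, idx) argument order seen in the test above:
import sqlparse
from sqlparse import tokens as T

stmt = sqlparse.parse('select a, b from tbl')[0]
from_kw = stmt.token_matching(
    [lambda t: t.ttype is T.Keyword and t.normalized == 'FROM'], 0)
print(from_kw)  # the FROM keyword token, or None if no match
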
def _process(tlist):
    def get_next_comment():
        # TODO(andi) Comment types should be unified, see related issue38
        return tlist.token_next_by(i=sql.Comment, t=T.Comment)

    tidx, token = get_next_comment()
    while token:
        pidx, prev_ = tlist.token_prev(tidx, skip_ws=False)
        nidx, next_ = tlist.token_next(tidx, skip_ws=False)
        # Replace by whitespace if prev and next exist and if they're not
        # whitespaces. This doesn't apply if prev or next is a parenthesis.
        if (prev_ is None or next_ is None
                or prev_.is_whitespace or prev_.match(T.Punctuation, '(')
                or next_.is_whitespace or next_.match(T.Punctuation, ')')):
            # Insert a whitespace to ensure the following SQL produces
            # a valid SQL (see #425). For example:
            #
            # Before: select a--comment\nfrom foo
            # After: select a from foo
            if prev_ is not None and next_ is None:
                tlist.tokens.insert(tidx, sql.Token(T.Whitespace, ' '))
            tlist.tokens.remove(token)
        else:
            tlist.tokens[tidx] = sql.Token(T.Whitespace, ' ')

        tidx, token = get_next_comment()

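# The behavior described in the comment above, reached through the public
# API (assuming sqlparse; note that newer releases preserve the comment's
# trailing line break instead, per issue484 below):
import sqlparse

print(sqlparse.format('select a--comment\nfrom foo', strip_comments=True))
# e.g. with this version of the filter: select a from foo
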
def _get_insert_token(token):
    """Returns either a whitespace or the line breaks from token."""
    # See issue484 why line breaks should be preserved.
    m = re.search(r'((\r\n|\r|\n)+) *$', token.value)
    if m is not None:
        return sql.Token(T.Whitespace.Newline, m.groups()[0])
    else:
        return sql.Token(T.Whitespace, ' ')

def test_issue212_py2unicode():
    if sys.version_info < (3, ):
        t1 = sql.Token(T.String, u"schöner ")
    else:
        t1 = sql.Token(T.String, "schöner ")
    t2 = sql.Token(T.String, u"bug")
    l = sql.TokenList([t1, t2])
    assert str(l) == 'schöner bug'

def test_token_matching(self):
    t1 = sql.Token(Keyword, 'foo')
    t2 = sql.Token(Punctuation, ',')
    x = sql.TokenList([t1, t2])
    self.assertEqual(x.token_matching(0, [lambda t: t.ttype is Keyword]),
                     t1)
    self.assertEqual(
        x.token_matching(0, [lambda t: t.ttype is Punctuation]), t2)
    self.assertEqual(x.token_matching(1, [lambda t: t.ttype is Keyword]),
                     None)

def _generate_attribute_token(self, attr_name_token, attr_value_token=None):
    attribute_token_list = [
        sql.Token(value=attr_name_token.value, ttype=T.Keyword)
    ]
    if attr_value_token is not None:
        attribute_token_list.append(
            sql.Token(value=attr_value_token.value.strip('`"\''),
                      ttype=attr_value_token.ttype))
    return sql.Attribute(tokens=attribute_token_list)

def _get_insert_token(token):
    """Returns either a whitespace or the line breaks from token."""
    # See issue484 why line breaks should be preserved.
    # Note: The actual value for a line break is replaced by \n
    # in SerializerUnicode which will be executed in the
    # postprocessing state.
    m = re.search(r'((\r|\n)+) *$', token.value)
    if m is not None:
        return sql.Token(T.Whitespace.Newline, m.groups()[0])
    else:
        return sql.Token(T.Whitespace, ' ')

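# A hedged illustration of the regex above: it captures trailing line
# breaks (optionally followed by spaces) so they can be re-inserted in
# place of the stripped comment (see issue484):
import re

for value in ('-- comment\n', '-- comment\r\n  ', '/* inline */'):
    m = re.search(r'((\r|\n)+) *$', value)
    print(repr(m.groups()[0]) if m else 'no trailing break -> single space')
# e.g. '\n', then '\r\n', then the fallback message
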
def _process(tlist):
    def next_token(idx=0):
        return tlist.token_next_by(t=(T.Operator, T.Comparison), idx=idx)

    token = next_token()
    while token:
        prev_ = tlist.token_prev(token, skip_ws=False)
        if prev_ and prev_.ttype != T.Whitespace:
            tlist.insert_before(token, sql.Token(T.Whitespace, ' '))

        next_ = tlist.token_next(token, skip_ws=False)
        if next_ and next_.ttype != T.Whitespace:
            tlist.insert_after(token, sql.Token(T.Whitespace, ' '))

        token = next_token(idx=token)

def build_comparison(identifier, operator):
    '''
    Build an SQL token representing a comparison ``identifier operator %s``

    :param identifier: An identifier previously found by calling
        ``find_identifier``
    :param operator: An operator like =, <, >=, ...
    '''
    return sql.Comparison([
        bare_identifier(identifier),
        sql.Token(tokens.Whitespace, " "),
        sql.Token(tokens.Comparison, operator),
        sql.Token(tokens.Whitespace, " "),
        sql.Token(tokens.Wildcard, "%s")
    ])

def nl(self, offset=1):
    # offset = 1 represents a single space after SELECT
    offset = -len(offset) if not isinstance(offset, int) else offset
    # add two for the space and parens
    return sql.Token(T.Whitespace,
                     self.n + self.char * (self.leading_ws + offset))

def _process(tlist):
    ttypes = (T.Operator, T.Comparison)
    tidx, token = tlist.token_next_by(t=ttypes)
    while token:
        nidx, next_ = tlist.token_next(tidx, skip_ws=False)
        if next_ and next_.ttype != T.Whitespace:
            tlist.insert_after(tidx, sql.Token(T.Whitespace, ' '))

        pidx, prev_ = tlist.token_prev(tidx, skip_ws=False)
        if prev_ and prev_.ttype != T.Whitespace:
            tlist.insert_before(tidx, sql.Token(T.Whitespace, ' '))
            tidx += 1  # has to shift since token inserted before it

        # assert tlist.token_index(token) == tidx
        tidx, token = tlist.token_next_by(t=ttypes, idx=tidx)

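# How this filter is typically reached (assuming sqlparse's
# use_space_around_operators format option):
import sqlparse

print(sqlparse.format('select 1+2 from t where a>=2',
                      use_space_around_operators=True))
# e.g. select 1 + 2 from t where a >= 2
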
def insert_extra_column_value(tablename, ptoken, before_token):
    if tablename in self._translate_tables:
        for keyword in [',', "'", self._product_prefix, "'"]:
            self._token_insert_before(
                ptoken, before_token,
                Types.Token(Tokens.Keyword, keyword))
    return

def process(self, stream):
    """Process the stream"""
    EOS_TTYPE = T.Whitespace, T.Comment.Single

    # Run over all stream tokens
    for ttype, value in stream:
        # Yield the statement if we finished one and this token is not
        # whitespace. A newline token counts as non-whitespace here; in
        # this context "whitespace" ignores newlines.
        # why don't multi line comments also count?
        if self.consume_ws and ttype not in EOS_TTYPE:
            yield sql.Statement(self.tokens)

            # Reset filter and prepare to process next statement
            self._reset()

        # Change current split level (increase, decrease or remain equal)
        self.level += self._change_splitlevel(ttype, value)

        # Append the token to the current statement
        self.tokens.append(sql.Token(ttype, value))

        # Check if we get the end of a statement
        if self.level <= 0 and ttype is T.Punctuation and value == ';':
            self.consume_ws = True

    # Yield pending statement (if any)
    if self.tokens and not all(t.is_whitespace for t in self.tokens):
        yield sql.Statement(self.tokens)

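# The splitter's observable behavior via the public helper (assuming
# sqlparse.split):
import sqlparse

print(sqlparse.split('select 1; select 2;'))
# e.g. ['select 1;', 'select 2;']
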
def test_extract_value(_label, token_value, expected):
    token = token_groups.Token(token_types.Literal, token_value)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        value = common.extract_value(token)

    assert value == expected

def insert_extra_column(tablename, columns_token):
    if tablename in self._translate_tables and \
            isinstance(columns_token, Types.Parenthesis):
        ptoken = self._token_first(columns_token)
        if not ptoken.match(Tokens.Punctuation, '('):
            raise Exception(
                "Invalid INSERT statement, expected parenthesis "
                "around columns")
        ptoken = self._token_next(columns_token, ptoken)
        last_token = ptoken
        while ptoken:
            if isinstance(ptoken, Types.IdentifierList):
                if any(i.get_name() == 'product'
                       for i in ptoken.get_identifiers()
                       if isinstance(i, Types.Identifier)):
                    return True
            last_token = ptoken
            ptoken = self._token_next(columns_token, ptoken)
        if not last_token or \
                not last_token.match(Tokens.Punctuation, ')'):
            raise Exception(
                "Invalid INSERT statement, unable to find column "
                "parenthesis end")
        for keyword in [',', ' ', self._product_column]:
            self._token_insert_before(
                columns_token, last_token,
                Types.Token(Tokens.Keyword, keyword))
    return False

def _process_case(self, tlist):
    offset_ = len('case ') + len('when ')
    cases = tlist.get_cases(skip_ws=True)
    # align the end as well
    end_token = tlist.token_next_by(m=(T.Keyword, 'END'))[1]
    cases.append((None, [end_token]))

    condition_width = [
        len(' '.join(map(text_type, cond))) if cond else 0
        for cond, _ in cases
    ]
    max_cond_width = max(condition_width)

    for i, (cond, value) in enumerate(cases):
        # cond is None when 'else or end'
        stmt = cond[0] if cond else value[0]

        if i > 0:
            tlist.insert_before(stmt,
                                self.nl(offset_ - len(text_type(stmt))))
        if cond:
            ws = sql.Token(T.Whitespace, self.char * (
                max_cond_width - condition_width[i]))
            tlist.insert_after(cond[-1], ws)

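# A hedged sketch of the aligned CASE output this method produces
# (assuming sqlparse's reindent_aligned format option; alignment details
# may differ between versions):
import sqlparse

print(sqlparse.format(
    "select case when x = 1 then 'one' when x = 22 then 'many' "
    "else 'other' end from t", reindent_aligned=True))
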
def _process_identifierlist(self, tlist):
    identifiers = list(tlist.get_identifiers())
    first = next(identifiers.pop(0).flatten())
    num_offset = 1 if self.char == '\t' else self._get_offset(first)
    if not tlist.within(sql.Function):
        with offset(self, num_offset):
            position = 0
            for token in identifiers:
                # Add 1 for the "," separator
                position += len(token.value) + 1
                if position > (self.wrap_after - self.offset):
                    adjust = 0
                    if self.comma_first:
                        adjust = -2
                        _, comma = tlist.token_prev(
                            tlist.token_index(token))
                        if comma is None:
                            continue
                        token = comma
                    tlist.insert_before(token, self.nl(offset=adjust))
                    if self.comma_first:
                        _, ws = tlist.token_next(
                            tlist.token_index(token), skip_ws=False)
                        if (ws is not None
                                and ws.ttype is not T.Text.Whitespace):
                            tlist.insert_after(
                                token, sql.Token(T.Whitespace, ' '))
                    position = 0
    self._process_default(tlist)

def _function(node, token):
    """Resolve a SQL function.

    :node: The current tree node on which the function operates.
    :token: The token representing the function.
    """
    funs = ["DP", "K_ANONIMITY", "L_DIVERSITY", "T_CLOSENESS", "TOKENIZE"]

    def _resolve_function_parameters(node, token):
        subtokens = token.tokens[1:-1]
        first = _token_first(subtokens)
        # extract columns from function parameters when present
        if first:
            if first.match(T.Keyword.DML, "SELECT"):
                child = select(subtokens)
                token.tokens = [token.tokens[0], child.state,
                                token.tokens[-1]]
                node.childs.append(child)
            else:
                _comma_separated_list(node, subtokens, skip=_skip_modifier,
                                      item_resolver=_parameter_resolver)

    identifier = token.tokens[0]
    parenthesis = token.tokens[1]
    # sqlparse mistakes operators for functions when a space doesn't
    # separate the operator and the parentheses
    if identifier.normalized.upper() in OPERATORS:
        # convert identifier into keyword
        identifier = S.Token(T.Keyword, identifier.normalized)
        _expression(node, [identifier, parenthesis])
    # anonymization function
    elif identifier.normalized in funs:
        # keep the anonymization function as we consider it part of the
        # schema
        prev_count = len(node.state.cols)
        _resolve_function_parameters(node, parenthesis)
        next_count = len(node.state.cols)
        cols = node.state.cols[prev_count:next_count]
        node.state.cols = node.state.cols[:prev_count]  # restore state
        for col in cols:
            node.state.cols.append(
                S.Token(T.Literal,
                        col.normalized + '/'
                        + identifier.normalized.lower()))
    # other function
    else:
        _resolve_function_parameters(node, parenthesis)

def nl(self, offset=1):
    # offset = 1 represents a single space after SELECT
    offset = -len(offset) if not isinstance(offset, int) else offset
    # add two for the space and parens
    indent = self.indent * (2 + self._max_kwd_len)
    return sql.Token(T.Whitespace, self.n + self.char * (
        self._max_kwd_len + offset + indent + self.offset))

def handle_insert_table(table_name):
    if insert_table and insert_table in self._translate_tables:
        if not field_lists or not field_lists[-1]:
            raise Exception("Invalid SELECT field list")
        last_token = list(field_lists[-1][-1].flatten())[-1]
        for keyword in ["'", self._product_prefix, "'", ' ', ',']:
            self._token_insert_after(
                last_token.parent, last_token,
                Types.Token(Tokens.Keyword, keyword))
    return

def process_functions(self, group=None):
    """Recursively parse and process FQL functions in the given group token.

    TODO: switch to sqlite3.Connection.create_function().

    Currently handles: me(), now(), strlen(), substr(), strpos()
    """
    if group is None:
        group = self.statement

    for tok in group.tokens:
        if isinstance(tok, sql.Function):
            assert isinstance(tok.tokens[0], sql.Identifier)
            name = tok.tokens[0].tokens[0]
            if name.value not in Fql.FUNCTIONS:
                raise InvalidFunctionError(name.value)

            # check number of params
            #
            # i wish i could use tok.get_parameters() here, but it doesn't
            # work with string parameters for some reason. :/
            assert isinstance(tok.tokens[1], sql.Parenthesis)
            params = [t for t in tok.tokens[1].flatten()
                      if t.ttype not in (tokens.Punctuation,
                                         tokens.Whitespace)]
            actual_num = len(params)
            expected_num = Fql.FUNCTIONS[name.value]
            if actual_num != expected_num:
                raise ParamMismatchError(name.value, expected_num,
                                         actual_num)

            # handle each function
            replacement = None
            if name.value == 'me':
                replacement = str(self.me)
            elif name.value == 'now':
                replacement = str(int(time.time()))
            elif name.value == 'strlen':
                # pass through to sqlite's length() function
                name.value = 'length'
            elif name.value == 'substr':
                # the index param is 0-based in FQL but 1-based in sqlite
                params[1].value = str(int(params[1].value) + 1)
            elif name.value == 'strpos':
                # strip quote chars
                string = params[0].value[1:-1]
                sub = params[1].value[1:-1]
                replacement = str(string.find(sub))
            else:
                # shouldn't happen
                assert False, 'unknown function: %s' % name.value

            if replacement is not None:
                tok.tokens = [sql.Token(tokens.Number, replacement)]
        elif tok.is_group():
            self.process_functions(tok)

def bare_identifier(identifier):
    '''
    Copy an identifier without any aliases.

    :param identifier: An identifier to copy.
    '''
    itokens = []
    if isinstance(identifier, str):
        for part in identifier.split(","):
            itokens.append(sql.Token(tokens.Literal, part))
    else:
        for token in identifier.tokens:
            if token.ttype in (tokens.Name, tokens.Punctuation):
                itokens.append(sql.Token(token.ttype, token.value))
            else:
                break
    return sql.Identifier(itokens)

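# A usage sketch of the two helpers above (assuming the same
# `from sqlparse import sql, tokens` imports they rely on): the built
# token tree serializes to a placeholder-style SQL fragment.
print(str(build_comparison('user_id', '=')))  # e.g. user_id = %s
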
def process(self, stmt):
    self._curr_stmt = stmt
    self._process(stmt)

    if self._last_stmt is not None:
        nl = "\n" if str(self._last_stmt).endswith("\n") else "\n\n"
        stmt.tokens.insert(0, sql.Token(T.Whitespace, nl))

    self._last_stmt = stmt
    return stmt

def nl(self):
    # TODO: newline character should be configurable
    space = (self.char * ((self.indent * self.width) + self.offset))
    # Detect runaway indenting due to parsing errors
    if len(space) > 200:
        # something seems to be wrong, flip back
        self.indent = self.offset = 0
        space = (self.char * ((self.indent * self.width) + self.offset))
    ws = '\n' + space
    return sql.Token(T.Whitespace, ws)

def process(self, stmt):
    self._curr_stmt = stmt
    self._process(stmt)

    if self._last_stmt is not None:
        nl = '\n' if text_type(self._last_stmt).endswith('\n') else '\n\n'
        stmt.tokens.insert(0, sql.Token(T.Whitespace, nl))

    self._last_stmt = stmt
    return stmt

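# The effect of the insert above through the public API (assuming
# sqlparse): reindenting separates consecutive statements with a blank
# line.
import sqlparse

print(sqlparse.format('select foo; select bar', reindent=True))
# e.g.:
# select foo;
#
# select bar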