Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 64 additions & 49 deletions sqlparse/engine/grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from sqlparse import sql
from sqlparse import tokens as T
from sqlparse.utils import recurse, imt
from sqlparse.utils import recurse

T_NUMERICAL = (T.Number, T.Number.Integer, T.Number.Float)
T_STRING = (T.String, T.String.Single, T.String.Symbol)
Expand All @@ -17,36 +17,28 @@
def _group_matching(tlist, cls):
"""Groups Tokens that have beginning and end."""
opens = []
tidx_offset = 0
for idx, token in enumerate(list(tlist)):
tidx = idx - tidx_offset
n = len(tlist.tokens)
for idx in range(n):
token = tlist.tokens[idx]
if token is None:
continue

if token.is_whitespace:
# ~50% of tokens will be whitespace. Will checking early
# for them avoid 3 comparisons, but then add 1 more comparison
# for the other ~50% of tokens...
continue

if token.is_group and not isinstance(token, cls):
# Check inside previously grouped (i.e. parenthesis) if group
# of different type is inside (i.e., case). though ideally should
# should check for all open/close tokens at once to avoid recursion
_group_matching(token, cls)
continue

if token.match(*cls.M_OPEN):
opens.append(tidx)
opens.append(idx)

elif token.match(*cls.M_CLOSE):
try:
open_idx = opens.pop()
except IndexError:
# this indicates invalid sql and unbalanced tokens.
# instead of break, continue in case other "valid" groups exist
continue
close_idx = tidx
tlist.group_tokens(cls, open_idx, close_idx)
tidx_offset += close_idx - open_idx
tlist.group_tokens(cls, open_idx, idx)


def group_brackets(tlist):
Expand Down Expand Up @@ -114,7 +106,7 @@ def group_typed_literal(tlist):
# https://www.postgresql.org/docs/9.1/datatype-datetime.html
# https://www.postgresql.org/docs/9.1/functions-datetime.html
def match(token):
return imt(token, m=sql.TypedLiteral.M_OPEN)
return any(token.match(*pattern) for pattern in sql.TypedLiteral.M_OPEN)

def match_to_extend(token):
return isinstance(token, sql.TypedLiteral)
Expand Down Expand Up @@ -147,20 +139,20 @@ def match(token):
return False

def valid_prev(token):
sqlcls = sql.SquareBrackets, sql.Identifier
ttypes = T.Name, T.String.Symbol
return imt(token, i=sqlcls, t=ttypes)
return (isinstance(token, (sql.SquareBrackets, sql.Identifier))
or token.ttype in (T.Name, T.String.Symbol))

def valid_next(token):
# issue261, allow invalid next token
return True

def post(tlist, pidx, tidx, nidx):
# next_ validation is being performed here. issue261
sqlcls = sql.SquareBrackets, sql.Function
ttypes = T.Name, T.String.Symbol, T.Wildcard, T.String.Single
next_ = tlist[nidx] if nidx is not None else None
valid_next = imt(next_, i=sqlcls, t=ttypes)
valid_next = (next_ is not None
and (isinstance(next_, (sql.SquareBrackets, sql.Function))
or next_.ttype in (T.Name, T.String.Symbol,
T.Wildcard, T.String.Single)))

return (pidx, nidx) if valid_next else (pidx, tidx)

Expand All @@ -175,8 +167,7 @@ def valid_prev(token):
return token.normalized == 'NULL' or not token.is_keyword

def valid_next(token):
ttypes = T.DML, T.DDL, T.CTE
return not imt(token, t=ttypes) and token is not None
return token is not None and token.ttype not in (T.DML, T.DDL, T.CTE)

def post(tlist, pidx, tidx, nidx):
return pidx, nidx
Expand Down Expand Up @@ -210,12 +201,13 @@ def match(token):
return token.ttype == T.Operator.Comparison

def valid(token):
if imt(token, t=ttypes, i=sqlcls):
if token is None:
return False
if isinstance(token, sqlcls) or token.ttype in ttypes:
return True
elif token and token.is_keyword and token.normalized == 'NULL':
if token.is_keyword and token.normalized == 'NULL':
return True
else:
return False
return False

def post(tlist, pidx, tidx, nidx):
return pidx, nidx
Expand All @@ -240,7 +232,9 @@ def group_over(tlist):
tidx, token = tlist.token_next_by(m=sql.Over.M_OPEN)
while token:
nidx, next_ = tlist.token_next(tidx)
if imt(next_, i=sql.Parenthesis, t=T.Name):
if (next_ is not None
and (isinstance(next_, sql.Parenthesis)
or next_.ttype in T.Name)):
tlist.group_tokens(sql.Over, tidx, nidx)
tidx, token = tlist.token_next_by(m=sql.Over.M_OPEN, idx=tidx)

Expand All @@ -253,7 +247,7 @@ def match(token):
return isinstance(token, sql.SquareBrackets)

def valid_prev(token):
return imt(token, i=sqlcls, t=ttypes)
return isinstance(token, sqlcls) or token.ttype in ttypes

def valid_next(token):
return True
Expand All @@ -271,13 +265,16 @@ def group_operator(tlist):
sql.Identifier, sql.Operation, sql.TypedLiteral)

def match(token):
return imt(token, t=(T.Operator, T.Wildcard))
return token.ttype in (T.Operator, T.Wildcard)

def valid(token):
return imt(token, i=sqlcls, t=ttypes) \
or (token and token.match(
T.Keyword,
('CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP')))
if token is None:
return False
return (isinstance(token, sqlcls)
or token.ttype in ttypes
or token.match(
T.Keyword,
('CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP')))

def post(tlist, pidx, tidx, nidx):
tlist[tidx].ttype = T.Operator
Expand All @@ -299,7 +296,11 @@ def match(token):
return token.match(T.Punctuation, ',')

def valid(token):
return imt(token, i=sqlcls, m=m_role, t=ttypes)
if token is None:
return False
return (isinstance(token, sqlcls)
or token.match(*m_role)
or token.ttype in ttypes)

def post(tlist, pidx, tidx, nidx):
return pidx, nidx
Expand All @@ -314,7 +315,7 @@ def group_comments(tlist):
tidx, token = tlist.token_next_by(t=T.Comment)
while token:
eidx, end = tlist.token_not_matching(
lambda tk: imt(tk, t=T.Comment) or tk.is_newline, idx=tidx)
lambda tk: tk.ttype in T.Comment or tk.is_newline, idx=tidx)
if end is not None:
eidx, end = tlist.token_prev(eidx, skip_ws=False)
tlist.group_tokens(sql.Comment, tidx, eidx)
Expand Down Expand Up @@ -358,11 +359,13 @@ def group_functions(tlist):
has_table = False
has_as = False
for tmp_token in tlist.tokens:
if tmp_token.value.upper() == 'CREATE':
val = str(tmp_token)
val_upper = val.upper()
if val_upper == 'CREATE':
has_create = True
if tmp_token.value.upper() == 'TABLE':
if val_upper == 'TABLE':
has_table = True
if tmp_token.value == 'AS':
if val == 'AS':
has_as = True
if has_create and has_table and not has_as:
return
Expand All @@ -386,7 +389,9 @@ def group_order(tlist):
tidx, token = tlist.token_next_by(t=T.Keyword.Order)
while token:
pidx, prev_ = tlist.token_prev(tidx)
if imt(prev_, i=sql.Identifier, t=T.Number):
if (prev_ is not None
and (isinstance(prev_, sql.Identifier)
or prev_.ttype in T.Number)):
tlist.group_tokens(sql.Identifier, pidx, tidx)
tidx = pidx
tidx, token = tlist.token_next_by(t=T.Keyword.Order, idx=tidx)
Expand Down Expand Up @@ -415,6 +420,14 @@ def group_values(tlist):
tlist.group_tokens(sql.Values, start_idx, end_idx, extend=True)


def _compact_all(tlist):
"""Recursively compact None sentinels from all token lists."""
tlist._compact_tokens()
for token in tlist.tokens:
if token.is_group:
_compact_all(token)


def group(stmt):
for func in [
group_comments,
Expand Down Expand Up @@ -448,6 +461,7 @@ def group(stmt):
group_values,
]:
func(stmt)
_compact_all(stmt)
return stmt


Expand All @@ -460,11 +474,9 @@ def _group(tlist, cls, match,
):
"""Groups together tokens that are joined by a middle token. i.e. x < y"""

tidx_offset = 0
pidx, prev_ = None, None
for idx, token in enumerate(list(tlist)):
tidx = idx - tidx_offset
if tidx < 0: # tidx shouldn't get negative
if token is None:
continue

if token.is_whitespace:
Expand All @@ -473,14 +485,17 @@ def _group(tlist, cls, match,
if recurse and token.is_group and not isinstance(token, cls):
_group(token, cls, match, valid_prev, valid_next, post, extend)

# Token may have been consumed by a prior group_tokens in this pass
if tlist.tokens[idx] is None:
continue

if match(token):
nidx, next_ = tlist.token_next(tidx)
nidx, next_ = tlist.token_next(idx)
if prev_ and valid_prev(prev_) and valid_next(next_):
from_idx, to_idx = post(tlist, pidx, tidx, nidx)
from_idx, to_idx = post(tlist, pidx, idx, nidx)
grp = tlist.group_tokens(cls, from_idx, to_idx, extend=extend)

tidx_offset += to_idx - from_idx
pidx, prev_ = from_idx, grp
continue

pidx, prev_ = tidx, token
pidx, prev_ = idx, token
3 changes: 2 additions & 1 deletion sqlparse/filters/others.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def _get_insert_token(token):
tidx, token = get_next_comment(idx=tidx)

def process(self, stmt):
[self.process(sgroup) for sgroup in stmt.get_sublists()]
[self.process(sgroup) for sgroup in stmt.get_sublists()
if not isinstance(sgroup, sql.Comment)]
StripCommentsFilter._process(stmt)
return stmt

Expand Down
52 changes: 39 additions & 13 deletions sqlparse/filters/reindent.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,35 +94,61 @@ def _next_token(self, tlist, idx=-1):
return tidx, token

def _split_kwds(self, tlist):
# Pass 1: scan unmodified list for all keyword positions
inserts = {} # idx -> token to insert before
deletes = set()

tidx, token = self._next_token(tlist)
while token:
pidx, prev_ = tlist.token_prev(tidx, skip_ws=False)
uprev = str(prev_)

if prev_ and prev_.is_whitespace:
del tlist.tokens[pidx]
tidx -= 1
deletes.add(pidx)

if not (uprev.endswith('\n') or uprev.endswith('\r')):
tlist.insert_before(tidx, self.nl())
tidx += 1
inserts[tidx] = self.nl()

tidx, token = self._next_token(tlist, tidx)

# Pass 2: rebuild token list in O(n)
if inserts or deletes:
new_tokens = []
for i, tok in enumerate(tlist.tokens):
if i in inserts:
nl_tok = inserts[i]
nl_tok.parent = tlist
new_tokens.append(nl_tok)
if i not in deletes:
new_tokens.append(tok)
tlist.tokens = new_tokens

def _split_statements(self, tlist):
ttypes = T.Keyword.DML, T.Keyword.DDL
inserts = {}
deletes = set()

tidx, token = tlist.token_next_by(t=ttypes)
while token:
pidx, prev_ = tlist.token_prev(tidx, skip_ws=False)
if prev_ and prev_.is_whitespace:
del tlist.tokens[pidx]
tidx -= 1
deletes.add(pidx)
# only break if it's not the first token
if prev_:
tlist.insert_before(tidx, self.nl())
tidx += 1
if prev_ is not None:
inserts[tidx] = self.nl()
tidx, token = tlist.token_next_by(t=ttypes, idx=tidx)

if inserts or deletes:
new_tokens = []
for i, tok in enumerate(tlist.tokens):
if i in inserts:
nl_tok = inserts[i]
nl_tok.parent = tlist
new_tokens.append(nl_tok)
if i not in deletes:
new_tokens.append(tok)
tlist.tokens = new_tokens

def _process(self, tlist):
func_name = f'_process_{type(tlist).__name__}'
func = getattr(self, func_name.lower(), self._process_default)
Expand Down Expand Up @@ -172,7 +198,7 @@ def _process_identifierlist(self, tlist):
shift = 0
for token in identifiers:
# Add 1 for the "," separator
position += len(token.value) + 1
position += len(str(token)) + 1
if position > (self.wrap_after - self.offset):
adjust = 0
tidx = token_to_idx[id(token)] + shift
Expand Down Expand Up @@ -215,12 +241,12 @@ def _process_identifierlist(self, tlist):
adj_i, sql.Token(T.Whitespace, ' '))
ws_shift += 1

end_at = self.offset + sum(len(i.value) + 1 for i in identifiers)
end_at = self.offset + sum(len(str(i)) + 1 for i in identifiers)
adjusted_offset = 0
if (self.wrap_after > 0
and end_at > (self.wrap_after - self.offset)
and self._last_func):
adjusted_offset = -len(self._last_func.value) - 1
adjusted_offset = -len(str(self._last_func)) - 1

# Rebuild index mapping after whitespace insertions
token_to_idx = {id(t): i
Expand All @@ -234,7 +260,7 @@ def _process_identifierlist(self, tlist):
position = 0
for token in identifiers:
# Add 1 for the "," separator
position += len(token.value) + 1
position += len(str(token)) + 1
if (self.wrap_after > 0
and position > (self.wrap_after - self.offset)):
tidx = token_to_idx[id(token)] + shift
Expand Down
Loading
Loading