Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 95 additions & 36 deletions sqlparse/filters/reindent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,53 @@ def __init__(self, width=2, char=' ', wrap_after=0, n='\n',
self._last_stmt = None
self._last_func = None

def _flatten_up_to_token(self, token):
    """Yields all tokens up to token but excluding current.

    Iterates the current statement's leaves in document order and
    stops just before *token*, so the caller sees everything that
    precedes it.
    """
    # A group has no single position in the flattened stream; compare
    # against its first leaf instead, since flatten() yields leaves.
    if token.is_group:
        token = next(token.flatten())

    for t in self._curr_stmt.flatten():
        if t == token:
            break
        yield t
def _reverse_leaves_before(self, target_leaf, known_parent_and_idx=None):
    """Yield leaf token values in reverse order before target_leaf.

    Climbs from *target_leaf* towards the statement root; at each level
    the left-hand siblings are emitted right-to-left, with groups
    expanded recursively, so callers can scan backwards without
    flattening the whole statement.

    :param target_leaf: leaf token to scan backwards from (excluded).
    :param known_parent_and_idx: optional ``(parent, idx)`` hint giving
        the current token's index within that parent, which skips the
        O(n) ``tokens.index`` lookup for that level.
    """
    current = target_leaf
    # Stop at the statement root or when a token has no parent.
    while current is not self._curr_stmt and current.parent is not None:
        parent = current.parent
        if known_parent_and_idx is not None \
                and parent is known_parent_and_idx[0]:
            idx = known_parent_and_idx[1]
        else:
            try:
                idx = parent.tokens.index(current)
            except ValueError:
                # Token is not (or no longer) in its parent's token
                # list; stop rather than guess a position.
                break
        # Emit left-hand siblings right-to-left; groups contribute
        # their leaves (last leaf first) via _reverse_flatten.
        for i in range(idx - 1, -1, -1):
            sibling = parent.tokens[i]
            if sibling.is_group:
                yield from self._reverse_flatten(sibling)
            else:
                yield sibling.value
        current = parent

def _reverse_flatten(self, token_list):
"""Yield all leaf token values in a TokenList in reverse order."""
for i in range(len(token_list.tokens) - 1, -1, -1):
child = token_list.tokens[i]
if child.is_group:
yield from self._reverse_flatten(child)
else:
yield child.value

@property
def leading_ws(self):
    """Width of the leading whitespace: indent levels plus base offset."""
    return self.indent * self.width + self.offset

def _get_offset(self, token):
    """Return *token*'s column relative to the current indentation.

    Renders every leaf preceding *token*, measures the length of the
    last physical line, and subtracts the active leading whitespace.
    """
    raw = ''.join(map(str, self._flatten_up_to_token(token)))
    # Fall back to '\n' so splitlines() yields at least one (empty)
    # line when nothing precedes *token*.
    line = (raw or '\n').splitlines()[-1]
    # Now take current offset into account and return relative offset.
    return len(line) - len(self.char * self.leading_ws)
def _get_offset(self, token, known_parent_and_idx=None):
    """Return *token*'s column relative to the current indentation.

    Scans leaf values backwards from *token* and stops at the first
    newline, so only the current physical line is measured instead of
    rendering the whole statement prefix.

    :param known_parent_and_idx: optional ``(parent, idx)`` hint passed
        through to ``_reverse_leaves_before`` to skip an index lookup.
    """
    # Groups are measured from their first leaf.
    if token.is_group:
        token = next(token.flatten())

    column = 0
    for value in self._reverse_leaves_before(token, known_parent_and_idx):
        newline_pos = value.rfind('\n')
        if newline_pos != -1:
            # Only the characters after the last newline are on the
            # current line; count them and stop scanning.
            column += len(value) - newline_pos - 1
            break
        column += len(value)

    # Express the column relative to the active leading whitespace.
    return column - len(self.char * self.leading_ws)

def nl(self, offset=0):
return sql.Token(
Expand Down Expand Up @@ -136,37 +164,57 @@ def _process_identifierlist(self, tlist):
num_offset = 1 if self.char == '\t' else self._get_offset(first)

if not tlist.within(sql.Function) and not tlist.within(sql.Values):
# Build index mapping for O(1) lookups instead of O(n)
# token_index calls
token_to_idx = {id(t): i
for i, t in enumerate(tlist.tokens)}
with offset(self, num_offset):
position = 0
shift = 0
for token in identifiers:
# Add 1 for the "," separator
position += len(token.value) + 1
if position > (self.wrap_after - self.offset):
adjust = 0
tidx = token_to_idx[id(token)] + shift
if self.comma_first:
adjust = -2
_, comma = tlist.token_prev(
tlist.token_index(token))
pidx, comma = tlist.token_prev(tidx)
if comma is None:
continue
token = comma
tlist.insert_before(token, self.nl(offset=adjust))
if self.comma_first:
tlist.insert_before(
pidx, self.nl(offset=adjust))
shift += 1
# comma is now at pidx + 1
_, ws = tlist.token_next(
tlist.token_index(token), skip_ws=False)
pidx + 1, skip_ws=False)
if (ws is not None
and ws.ttype is not T.Text.Whitespace):
and ws.ttype is not
T.Text.Whitespace):
tlist.insert_after(
token, sql.Token(T.Whitespace, ' '))
pidx + 1,
sql.Token(T.Whitespace, ' '))
shift += 1
else:
tlist.insert_before(
tidx, self.nl(offset=adjust))
shift += 1
position = 0
else:
# ensure whitespace
for token in tlist:
_, next_ws = tlist.token_next(
tlist.token_index(token), skip_ws=False)
if token.value == ',' and not next_ws.is_whitespace:
tlist.insert_after(
token, sql.Token(T.Whitespace, ' '))
token_to_idx = {id(t): i
for i, t in enumerate(tlist.tokens)}
ws_shift = 0
for token in list(tlist.tokens):
if token.value == ',':
adj_i = token_to_idx[id(token)] + ws_shift
_, next_ws = tlist.token_next(
adj_i, skip_ws=False)
if (next_ws is not None
and not next_ws.is_whitespace):
tlist.insert_after(
adj_i, sql.Token(T.Whitespace, ' '))
ws_shift += 1

end_at = self.offset + sum(len(i.value) + 1 for i in identifiers)
adjusted_offset = 0
Expand All @@ -175,17 +223,25 @@ def _process_identifierlist(self, tlist):
and self._last_func):
adjusted_offset = -len(self._last_func.value) - 1

# Rebuild index mapping after whitespace insertions
token_to_idx = {id(t): i
for i, t in enumerate(tlist.tokens)}
with offset(self, adjusted_offset), indent(self):
shift = 0
if adjusted_offset < 0:
tlist.insert_before(identifiers[0], self.nl())
idx0 = token_to_idx[id(identifiers[0])] + shift
tlist.insert_before(idx0, self.nl())
shift += 1
position = 0
for token in identifiers:
# Add 1 for the "," separator
position += len(token.value) + 1
if (self.wrap_after > 0
and position > (self.wrap_after - self.offset)):
adjust = 0
tlist.insert_before(token, self.nl(offset=adjust))
tidx = token_to_idx[id(token)] + shift
tlist.insert_before(
tidx, self.nl(offset=0))
shift += 1
position = 0
self._process_default(tlist)

Expand Down Expand Up @@ -216,17 +272,20 @@ def _process_values(self, tlist):
tlist.insert_before(0, self.nl())
tidx, token = tlist.token_next_by(i=sql.Parenthesis)
first_token = token

if self.comma_first and first_token:
cf_offset = self._get_offset(first_token) - 2

while token:
ptidx, ptoken = tlist.token_next_by(m=(T.Punctuation, ','),
idx=tidx)
if ptoken:
if self.comma_first:
adjust = -2
offset = self._get_offset(first_token) + adjust
tlist.insert_before(ptoken, self.nl(offset))
tlist.insert_before(ptidx, self.nl(cf_offset))
else:
tlist.insert_after(ptoken,
self.nl(self._get_offset(token)))
nl_offset = self._get_offset(
token, known_parent_and_idx=(tlist, tidx))
tlist.insert_after(ptidx, self.nl(nl_offset))
tidx, token = tlist.token_next_by(i=sql.Parenthesis, idx=tidx)

def _process_default(self, tlist, stmts=True):
Expand Down
145 changes: 145 additions & 0 deletions tests/test_benchmarks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
"""Benchmark tests for sqlparse.format(reindent=True) on large SQL queries.

Run with:
uv run --with pytest --with pytest-benchmark pytest tests/test_benchmarks.py -v
"""

import pytest
import sqlparse

pytest.importorskip("pytest_benchmark")


# ---------------------------------------------------------------------------
# Query generators — deterministic, no randomness
# ---------------------------------------------------------------------------

def _wide_select_sql(n_cols=5000):
cols = ', '.join(f'col_{i}' for i in range(n_cols))
return f'SELECT {cols} FROM t'


def _large_in_list_sql(n=100000):
values = ', '.join(str(i) for i in range(n))
return f'SELECT * FROM t WHERE id IN ({values})'


def _large_insert_sql(n_rows=25000):
rows = ', '.join(f'({i}, {i+1})' for i in range(n_rows))
return f'INSERT INTO t VALUES {rows}'


def _deep_subqueries_sql(depth=12):
q = 'SELECT * FROM t'
for i in range(depth):
q = f'SELECT * FROM ({q}) s{i}'
return q


def _many_joins_sql(n=500):
joins = ' '.join(
f'JOIN t{i} ON t{i}.id = t{i-1}.id' for i in range(1, n + 1))
return f'SELECT * FROM t0 {joins}'


def _complex_where_sql(depth=8, breadth=3):
def _build(d, idx=0):
if d == 0:
return f'col_{idx} = {idx}'
conn = 'AND' if d % 2 == 0 else 'OR'
parts = []
for i in range(breadth):
parts.append(_build(d - 1, idx * breadth + i))
return '(' + f' {conn} '.join(parts) + ')'
return f'SELECT * FROM t WHERE {_build(depth)}'


def _mixed_batch_sql(n=50):
stmts = []
for i in range(n):
if i % 4 == 0:
cols = ', '.join(f'col_{j} INT' for j in range(20))
stmts.append(f'CREATE TABLE t_{i} ({cols})')
elif i % 4 == 1:
rows = ', '.join(f'({j}, {j+1})' for j in range(100))
stmts.append(f'INSERT INTO t_{i} VALUES {rows}')
elif i % 4 == 2:
cols = ', '.join(f't.col_{j}' for j in range(20))
stmts.append(f'SELECT {cols} FROM t_{i} t WHERE t.col_0 > 0')
else:
stmts.append(
f'UPDATE t_{i} SET col_0 = col_0 + 1 WHERE col_1 > 0')
return '; '.join(stmts)


def _heavy_formatting_sql():
cases = ', '.join(
f'CASE WHEN col_{i} > 0 THEN col_{i} ELSE 0 END AS c_{i}'
for i in range(200))
return (
f'WITH cte AS (SELECT {cases} FROM t) '
f'SELECT * FROM cte WHERE c_0 > 0 ORDER BY c_1'
)


# ---------------------------------------------------------------------------
# Reindent benchmarks (one per PR table row)
# ---------------------------------------------------------------------------

@pytest.mark.benchmark(group="reindent")
def test_wide_select(benchmark):
    """Benchmark reindent formatting of a very wide SELECT list."""
    query = _wide_select_sql()
    benchmark(sqlparse.format, query, reindent=True)


@pytest.mark.benchmark(group="reindent")
def test_large_in_list(benchmark):
    """Benchmark reindent formatting of a huge IN (...) list."""
    query = _large_in_list_sql()
    benchmark(sqlparse.format, query, reindent=True)


@pytest.mark.benchmark(group="reindent")
def test_large_insert(benchmark):
    """Benchmark reindent formatting of a many-row INSERT."""
    query = _large_insert_sql()
    benchmark(sqlparse.format, query, reindent=True)


@pytest.mark.benchmark(group="reindent")
def test_deep_subqueries(benchmark):
    """Benchmark reindent formatting of deeply nested subqueries."""
    query = _deep_subqueries_sql()
    benchmark(sqlparse.format, query, reindent=True)


@pytest.mark.benchmark(group="reindent")
def test_many_joins(benchmark):
    """Benchmark reindent formatting of a long JOIN chain."""
    query = _many_joins_sql()
    benchmark(sqlparse.format, query, reindent=True)


@pytest.mark.benchmark(group="reindent")
def test_complex_where(benchmark):
    """Benchmark reindent formatting of a deeply nested WHERE clause."""
    query = _complex_where_sql()
    benchmark(sqlparse.format, query, reindent=True)


@pytest.mark.benchmark(group="reindent")
def test_mixed_batch(benchmark):
    """Benchmark reindent formatting of a mixed multi-statement batch."""
    query = _mixed_batch_sql()
    benchmark(sqlparse.format, query, reindent=True)


@pytest.mark.benchmark(group="reindent")
def test_heavy_formatting(benchmark):
    """Benchmark reindent formatting of a CASE-heavy CTE query."""
    query = _heavy_formatting_sql()
    benchmark(sqlparse.format, query, reindent=True)


# ---------------------------------------------------------------------------
# INSERT scaling benchmarks (_process_values)
# ---------------------------------------------------------------------------

@pytest.mark.benchmark(group="insert-scaling")
@pytest.mark.parametrize("n_rows", [5000, 10000, 25000], ids=["5k", "10k", "25k"])
def test_insert_scaling(benchmark, n_rows):
    """Benchmark reindent formatting of INSERTs at several row counts."""
    query = _large_insert_sql(n_rows)
    benchmark(sqlparse.format, query, reindent=True)
Loading