82 lines
2.4 KiB
Diff
82 lines
2.4 KiB
Diff
From b4a39d9850969b4e1d6940d32094ee0b42a2cf03 Mon Sep 17 00:00:00 2001
|
|
From: Andi Albrecht <albrecht.andi@gmail.com>
|
|
Date: Sat, 13 Apr 2024 13:59:00 +0200
|
|
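Subject: [PATCH] Raise SQLParseError instead of RecursionError.

TokenList.flatten() recurses into every group token, so very deeply
nested input can exceed the interpreter's recursion limit. Catch the
resulting RecursionError in flatten() and re-raise it as SQLParseError,
chained to the original error, so callers get a parse error instead of
an interpreter-level exception. A regression test parses 100 nested
brackets under a lowered recursion limit and expects SQLParseError.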
|
|
|
|
Origin: https://github.com/andialbrecht/sqlparse/commit/b4a39d9850969b4e1d6940d32094ee0b42a2cf03
|
|
|
|
---
|
|
sqlparse/sql.py | 15 +++++++++------
|
|
tests/test_regressions.py | 14 ++++++++++++++
|
|
2 files changed, 23 insertions(+), 6 deletions(-)
|
|
|
|
diff --git a/sqlparse/sql.py b/sqlparse/sql.py
|
|
index a942bcd..84ed1c2 100644
|
|
--- a/sqlparse/sql.py
|
|
+++ b/sqlparse/sql.py
|
|
@@ -12,6 +12,7 @@ from __future__ import print_function
|
|
import re
|
|
|
|
from sqlparse import tokens as T
|
|
+from sqlparse.exceptions import SQLParseError
|
|
from sqlparse.compat import string_types, text_type, unicode_compatible
|
|
from sqlparse.utils import imt, remove_quotes
|
|
|
|
@@ -214,12 +215,14 @@ class TokenList(Token):
|
|
|
|
This method is recursively called for all child tokens.
|
|
"""
|
|
- for token in self.tokens:
|
|
- if token.is_group:
|
|
- for item in token.flatten():
|
|
- yield item
|
|
- else:
|
|
- yield token
|
|
+ try:
|
|
+ for token in self.tokens:
|
|
+ if token.is_group:
|
|
+ yield from token.flatten()
|
|
+ else:
|
|
+ yield token
|
|
+ except RecursionError as err:
|
|
+ raise SQLParseError('Maximum recursion depth exceeded') from err
|
|
|
|
def get_sublists(self):
|
|
for token in self.tokens:
|
|
diff --git a/tests/test_regressions.py b/tests/test_regressions.py
|
|
index 2ed0ff3..0f843b6 100644
|
|
--- a/tests/test_regressions.py
|
|
+++ b/tests/test_regressions.py
|
|
@@ -1,10 +1,12 @@
|
|
# -*- coding: utf-8 -*-
|
|
|
|
import pytest
|
|
+import sys
|
|
|
|
import sqlparse
|
|
from sqlparse import sql, tokens as T
|
|
from sqlparse.compat import PY2
|
|
+from sqlparse.exceptions import SQLParseError
|
|
|
|
|
|
def test_issue9():
|
|
@@ -406,3 +408,15 @@ def test_issue489_tzcasts():
|
|
p = sqlparse.parse('select bar at time zone \'UTC\' as foo')[0]
|
|
assert p.tokens[-1].has_alias() is True
|
|
assert p.tokens[-1].get_alias() == 'foo'
|
|
+
|
|
+@pytest.fixture
|
|
+def limit_recursion():
|
|
+ curr_limit = sys.getrecursionlimit()
|
|
+ sys.setrecursionlimit(70)
|
|
+ yield
|
|
+ sys.setrecursionlimit(curr_limit)
|
|
+
|
|
+
|
|
+def test_max_recursion(limit_recursion):
|
|
+ with pytest.raises(SQLParseError):
|
|
+ sqlparse.parse('[' * 100 + ']' * 100)
|
|
--
|
|
2.33.0
|
|
|