From 566b43c29526c2c6a64b695a049531a747501540 Mon Sep 17 00:00:00 2001
From: Julien Rouhaud <julien.rouhaud@free.fr>
Date: Thu, 22 Apr 2021 01:33:42 +0800
Subject: [PATCH v1 3/4] Add a new MODE_SINGLE_QUERY to the core parser and use
 it in pg_parse_query.

If a third-party module provides a parser_hook, pg_parse_query() switches to
single-query parsing so multi-query commands using different grammars can work
properly.  If the third-party module supports the full set of SQL we support,
or wants to prevent fallback on the core parser, it can ignore the
RAW_PARSE_SINGLE_QUERY mode and parse the full query string.  In that case it
must return a List with more than one RawStmt or a single RawStmt with a 0
length to stop the parsing phase, or raise an ERROR.

Otherwise, plugins should parse a single query only and always return a List
containing a single RawStmt with a properly set length (possibly 0 if it was a
single query without end of query delimiter).  If the command is valid but
doesn't contain any statements (e.g. a single semi-colon), a single RawStmt
with a NULL stmt field should be returned, containing the consumed query string
length so we can move to the next command in a single pass rather than 1 byte
at a time.

Also, third-party modules can choose to ignore some or all parsing errors if
they want to implement only a subset of the syntax supported by postgres, or
even a totally different syntax, and fall back on the core grammar for
unhandled cases.  In that case, they should set the error flag to true.  The
returned List will be ignored and the same offset of the input string will be
parsed using the core parser.

Finally, note that third-party plugins that want to fall back on another
grammar should first try to call a previous parser hook, if any, before
setting the error flag and returning.
---
 .../pg_stat_statements/pg_stat_statements.c   |   3 +-
 src/backend/commands/tablecmds.c              |   2 +-
 src/backend/executor/spi.c                    |   4 +-
 src/backend/parser/gram.y                     |  27 ++++
 src/backend/parser/parse_type.c               |   2 +-
 src/backend/parser/parser.c                   |   7 +-
 src/backend/parser/scan.l                     |  13 +-
 src/backend/tcop/postgres.c                   | 131 ++++++++++++++++--
 src/include/parser/parser.h                   |   5 +-
 src/include/parser/scanner.h                  |   3 +-
 src/include/tcop/tcopprot.h                   |   3 +-
 src/pl/plpgsql/src/pl_gram.y                  |   2 +-
 src/pl/plpgsql/src/pl_scanner.c               |   2 +-
 13 files changed, 179 insertions(+), 25 deletions(-)

diff --git a/contrib/pg_stat_statements/pg_stat_statements.c b/contrib/pg_stat_statements/pg_stat_statements.c
index f42f07622e..7c911ef58d 100644
--- a/contrib/pg_stat_statements/pg_stat_statements.c
+++ b/contrib/pg_stat_statements/pg_stat_statements.c
@@ -2711,7 +2711,8 @@ fill_in_constant_lengths(JumbleState *jstate, const char *query,
 	yyscanner = scanner_init(query,
 							 &yyextra,
 							 &ScanKeywords,
-							 ScanKeywordTokens);
+							 ScanKeywordTokens,
+							 0);
 
 	/* we don't want to re-emit any escape string warnings */
 	yyextra.escape_string_warning = false;
diff --git a/src/backend/commands/tablecmds.c b/src/backend/commands/tablecmds.c
index d9ba87a2a3..cc9c86778c 100644
--- a/src/backend/commands/tablecmds.c
+++ b/src/backend/commands/tablecmds.c
@@ -12602,7 +12602,7 @@ ATPostAlterTypeParse(Oid oldId, Oid oldRelId, Oid refRelId, char *cmd,
 	 * parse_analyze() or the rewriter, but instead we need to pass them
 	 * through parse_utilcmd.c to make them ready for execution.
 	 */
-	raw_parsetree_list = raw_parser(cmd, RAW_PARSE_DEFAULT);
+	raw_parsetree_list = raw_parser(cmd, RAW_PARSE_DEFAULT, 0);
 	querytree_list = NIL;
 	foreach(list_item, raw_parsetree_list)
 	{
diff --git a/src/backend/executor/spi.c b/src/backend/executor/spi.c
index 00aa78ea53..e456172fef 100644
--- a/src/backend/executor/spi.c
+++ b/src/backend/executor/spi.c
@@ -2121,7 +2121,7 @@ _SPI_prepare_plan(const char *src, SPIPlanPtr plan)
 	/*
 	 * Parse the request string into a list of raw parse trees.
 	 */
-	raw_parsetree_list = raw_parser(src, plan->parse_mode);
+	raw_parsetree_list = raw_parser(src, plan->parse_mode, 0);
 
 	/*
 	 * Do parse analysis and rule rewrite for each raw parsetree, storing the
@@ -2229,7 +2229,7 @@ _SPI_prepare_oneshot_plan(const char *src, SPIPlanPtr plan)
 	/*
 	 * Parse the request string into a list of raw parse trees.
 	 */
-	raw_parsetree_list = raw_parser(src, plan->parse_mode);
+	raw_parsetree_list = raw_parser(src, plan->parse_mode, 0);
 
 	/*
 	 * Construct plancache entries, but don't do parse analysis yet.
diff --git a/src/backend/parser/gram.y b/src/backend/parser/gram.y
index b4ab4014c8..9733b30529 100644
--- a/src/backend/parser/gram.y
+++ b/src/backend/parser/gram.y
@@ -753,6 +753,7 @@ static Node *makeRecursiveViewSelect(char *relname, List *aliases, Node *query);
 %token		MODE_PLPGSQL_ASSIGN1
 %token		MODE_PLPGSQL_ASSIGN2
 %token		MODE_PLPGSQL_ASSIGN3
+%token		MODE_SINGLE_QUERY
 
 
 /* Precedence: lowest to highest */
@@ -858,6 +859,32 @@ parse_toplevel:
 				pg_yyget_extra(yyscanner)->parsetree =
 					list_make1(makeRawStmt((Node *) n, 0));
 			}
+			| MODE_SINGLE_QUERY toplevel_stmt ';'
+			{
+				RawStmt *raw = makeRawStmt($2, 0);
+				updateRawStmtEnd(raw, @3 + 1);
+				/* NOTE: we can return a raw statement containing a NULL stmt.
+				 * This is done to allow pg_parse_query to ignore that part of
+				 * the input string and move to the next command.
+				 */
+				pg_yyget_extra(yyscanner)->parsetree = list_make1(raw);
+				YYACCEPT;
+			}
+			/*
+			 * We need to explicitly look for EOF to parse non-semicolon
+			 * terminated statements in single query mode, as we could
+			 * otherwise successfully parse the beginning of an otherwise
+			 * invalid query.
+			 */
+			| MODE_SINGLE_QUERY toplevel_stmt YYEOF
+			{
+				/* NOTE: we can return a raw statement containing a NULL stmt.
+				 * This is done to allow pg_parse_query to ignore that part of
+				 * the input string.
+				 */
+				pg_yyget_extra(yyscanner)->parsetree = list_make1(makeRawStmt($2, 0));
+				YYACCEPT;
+			}
 		;
 
 /*
diff --git a/src/backend/parser/parse_type.c b/src/backend/parser/parse_type.c
index abe131ebeb..e9a7b5d62a 100644
--- a/src/backend/parser/parse_type.c
+++ b/src/backend/parser/parse_type.c
@@ -746,7 +746,7 @@ typeStringToTypeName(const char *str)
 	ptserrcontext.previous = error_context_stack;
 	error_context_stack = &ptserrcontext;
 
-	raw_parsetree_list = raw_parser(str, RAW_PARSE_TYPE_NAME);
+	raw_parsetree_list = raw_parser(str, RAW_PARSE_TYPE_NAME, 0);
 
 	error_context_stack = ptserrcontext.previous;
 
diff --git a/src/backend/parser/parser.c b/src/backend/parser/parser.c
index 875de7ba28..7297733168 100644
--- a/src/backend/parser/parser.c
+++ b/src/backend/parser/parser.c
@@ -39,7 +39,7 @@ static char *str_udeescape(const char *str, char escape,
  * list have the form required by the specified RawParseMode.
  */
 List *
-raw_parser(const char *str, RawParseMode mode)
+raw_parser(const char *str, RawParseMode mode, int offset)
 {
 	core_yyscan_t yyscanner;
 	base_yy_extra_type yyextra;
@@ -47,7 +47,7 @@ raw_parser(const char *str, RawParseMode mode)
 
 	/* initialize the flex scanner */
 	yyscanner = scanner_init(str, &yyextra.core_yy_extra,
-							 &ScanKeywords, ScanKeywordTokens);
+							 &ScanKeywords, ScanKeywordTokens, offset);
 
 	/* base_yylex() only needs us to initialize the lookahead token, if any */
 	if (mode == RAW_PARSE_DEFAULT)
@@ -61,7 +61,8 @@ raw_parser(const char *str, RawParseMode mode)
 			MODE_PLPGSQL_EXPR,	/* RAW_PARSE_PLPGSQL_EXPR */
 			MODE_PLPGSQL_ASSIGN1,	/* RAW_PARSE_PLPGSQL_ASSIGN1 */
 			MODE_PLPGSQL_ASSIGN2,	/* RAW_PARSE_PLPGSQL_ASSIGN2 */
-			MODE_PLPGSQL_ASSIGN3	/* RAW_PARSE_PLPGSQL_ASSIGN3 */
+			MODE_PLPGSQL_ASSIGN3,	/* RAW_PARSE_PLPGSQL_ASSIGN3 */
+			MODE_SINGLE_QUERY		/* RAW_PARSE_SINGLE_QUERY */
 		};
 
 		yyextra.have_lookahead = true;
diff --git a/src/backend/parser/scan.l b/src/backend/parser/scan.l
index 9f9d8a1706..2191360a72 100644
--- a/src/backend/parser/scan.l
+++ b/src/backend/parser/scan.l
@@ -1189,8 +1189,10 @@ core_yyscan_t
 scanner_init(const char *str,
 			 core_yy_extra_type *yyext,
 			 const ScanKeywordList *keywordlist,
-			 const uint16 *keyword_tokens)
+			 const uint16 *keyword_tokens,
+			 int offset)
 {
+	YY_BUFFER_STATE state;
 	Size		slen = strlen(str);
 	yyscan_t	scanner;
 
@@ -1213,13 +1215,20 @@ scanner_init(const char *str,
 	yyext->scanbuflen = slen;
 	memcpy(yyext->scanbuf, str, slen);
 	yyext->scanbuf[slen] = yyext->scanbuf[slen + 1] = YY_END_OF_BUFFER_CHAR;
-	yy_scan_buffer(yyext->scanbuf, slen + 2, scanner);
+	state = yy_scan_buffer(yyext->scanbuf, slen + 2, scanner);
 
 	/* initialize literal buffer to a reasonable but expansible size */
 	yyext->literalalloc = 1024;
 	yyext->literalbuf = (char *) palloc(yyext->literalalloc);
 	yyext->literallen = 0;
 
+	/*
+	 * Adjust the offset in the input string.  This is required in single-query
+	 * mode, as we need to register the same token locations as we would have
+	 * in normal mode with multi-statement query string.
+	 */
+	state->yy_buf_pos += offset;
+
 	return scanner;
 }
 
diff --git a/src/backend/tcop/postgres.c b/src/backend/tcop/postgres.c
index e91db69830..a45dd602c0 100644
--- a/src/backend/tcop/postgres.c
+++ b/src/backend/tcop/postgres.c
@@ -602,17 +602,130 @@ ProcessClientWriteInterrupt(bool blocked)
 List *
 pg_parse_query(const char *query_string)
 {
-	List	   *raw_parsetree_list = NIL;
+	List		   *result = NIL;
+	int				stmt_len, offset;
 
 	TRACE_POSTGRESQL_QUERY_PARSE_START(query_string);
 
 	if (log_parser_stats)
 		ResetUsage();
 
-	if (parser_hook)
-		raw_parsetree_list = (*parser_hook) (query_string, RAW_PARSE_DEFAULT);
-	else
-		raw_parsetree_list = raw_parser(query_string, RAW_PARSE_DEFAULT);
+	stmt_len = 0; /* lazily computed when needed */
+	offset = 0;
+
+	while(true)
+	{
+		List *raw_parsetree_list;
+		RawStmt *raw;
+		bool	error = false;
+
+		/*----------------
+		 * Start parsing the input string.  If a third-party module provided a
+		 * parser_hook, we switch to single-query parsing so multi-query
+		 * commands using different grammar can work properly.
+		 * If the third-party module supports the full set of SQL we support,
+		 * or wants to prevent fallback on the core parser, it can ignore the
+		 * RAW_PARSE_SINGLE_QUERY mode and parse the full query string.
+		 * In that case it must return a List with more than one RawStmt or a
+		 * single RawStmt with a 0 length to stop the parsing phase, or raise
+		 * an ERROR.
+		 *
+		 * Otherwise, plugins should parse a single query only and always
+		 * return a List containing a single RawStmt with a properly set length
+		 * (possibly 0 if it was a single query without end of query
+		 * delimiter).  If the command is valid but doesn't contain any
+		 * statements (e.g. a single semi-colon), a single RawStmt with a NULL
+		 * stmt field should be returned, containing the consumed query string
+		 * length so we can move to the next command in a single pass rather
+		 * than 1 byte at a time.
+		 *
+		 * Also, third-party modules can choose to ignore some or all
+		 * parsing errors if they want to implement only a subset of the
+		 * syntax postgres supports, or even a totally different syntax, and
+		 * fall back on the core grammar for unhandled cases.  In that case,
+		 * they should set the error flag to true.  The returned List will be
+		 * ignored and the same offset of the input string will be parsed
+		 * using the core parser.
+		 *
+		 * Finally, note that third-party modules that want to fall back on
+		 * another grammar should first try to call a previous parser hook, if
+		 * any, before setting the error flag and returning.
+		 */
+		if (parser_hook)
+			raw_parsetree_list = (*parser_hook) (query_string,
+												 RAW_PARSE_SINGLE_QUERY,
+												 offset,
+												 &error);
+
+		/*
+		 * If a third-party module couldn't parse a single query or if no
+		 * third-party module is configured, fallback on core parser.
+		 */
+		if (error || !parser_hook)
+			raw_parsetree_list = raw_parser(query_string,
+					error ? RAW_PARSE_SINGLE_QUERY : RAW_PARSE_DEFAULT, offset);
+
+		/*
+		 * If there is no third-party plugin, or none of the parsers found a
+		 * valid query, or if a third-party module consumed the whole
+		 * query string, we're done.
+		 */
+		if (!parser_hook || raw_parsetree_list == NIL ||
+			list_length(raw_parsetree_list) > 1)
+		{
+			/*
+			 * Warn third-party plugins if they mix "single query" and "whole
+			 * input string" strategy rather than silently accepting it and
+			 * maybe allow fallback on core grammar even if they want to avoid
+			 * that.  This way plugin authors can be warned early of the issue.
+			 */
+			if (result != NIL)
+			{
+				Assert(parser_hook != NULL);
+				elog(ERROR, "parser_hook should parse a single statement at "
+						"a time or consume the whole input string at once");
+			}
+			result = raw_parsetree_list;
+			break;
+		}
+
+		if (stmt_len == 0)
+			stmt_len = strlen(query_string);
+
+		raw = linitial_node(RawStmt, raw_parsetree_list);
+
+		/*
+		 * In single-query mode, the parser will return statement location info
+		 * relative to the beginning of complete original string, not the part
+		 * we just parsed, so adjust the location info.
+		 */
+		if (offset > 0 && raw->stmt_len > 0)
+		{
+			Assert(raw->stmt_len > offset);
+			raw->stmt_location = offset;
+			raw->stmt_len -= offset;
+		}
+
+		/* Ignore the statement if it didn't contain any command. */
+		if (raw->stmt)
+			result = lappend(result, raw);
+
+		if (raw->stmt_len == 0)
+		{
+			/* The statement was the whole string, we're done. */
+			break;
+		}
+		else if (raw->stmt_len + offset >= stmt_len)
+		{
+			/* We consumed all of the input string, we're done. */
+			break;
+		}
+		else
+		{
+			/* Advance the offset to the next command. */
+			offset += raw->stmt_len;
+		}
+	}
 
 	if (log_parser_stats)
 		ShowUsage("PARSER STATISTICS");
@@ -620,13 +733,13 @@ pg_parse_query(const char *query_string)
 #ifdef COPY_PARSE_PLAN_TREES
 	/* Optional debugging check: pass raw parsetrees through copyObject() */
 	{
-		List	   *new_list = copyObject(raw_parsetree_list);
+		List	   *new_list = copyObject(result);
 
 		/* This checks both copyObject() and the equal() routines... */
-		if (!equal(new_list, raw_parsetree_list))
+		if (!equal(new_list, result))
 			elog(WARNING, "copyObject() failed to produce an equal raw parse tree");
 		else
-			raw_parsetree_list = new_list;
+			result = new_list;
 	}
 #endif
 
@@ -638,7 +751,7 @@ pg_parse_query(const char *query_string)
 
 	TRACE_POSTGRESQL_QUERY_PARSE_DONE(query_string);
 
-	return raw_parsetree_list;
+	return result;
 }
 
 /*
diff --git a/src/include/parser/parser.h b/src/include/parser/parser.h
index 853b0f1606..5694ae791a 100644
--- a/src/include/parser/parser.h
+++ b/src/include/parser/parser.h
@@ -41,7 +41,8 @@ typedef enum
 	RAW_PARSE_PLPGSQL_EXPR,
 	RAW_PARSE_PLPGSQL_ASSIGN1,
 	RAW_PARSE_PLPGSQL_ASSIGN2,
-	RAW_PARSE_PLPGSQL_ASSIGN3
+	RAW_PARSE_PLPGSQL_ASSIGN3,
+	RAW_PARSE_SINGLE_QUERY
 } RawParseMode;
 
 /* Values for the backslash_quote GUC */
@@ -59,7 +60,7 @@ extern PGDLLIMPORT bool standard_conforming_strings;
 
 
 /* Primary entry point for the raw parsing functions */
-extern List *raw_parser(const char *str, RawParseMode mode);
+extern List *raw_parser(const char *str, RawParseMode mode, int offset);
 
 /* Utility functions exported by gram.y (perhaps these should be elsewhere) */
 extern List *SystemFuncName(char *name);
diff --git a/src/include/parser/scanner.h b/src/include/parser/scanner.h
index 0d8182faa0..2747e8b1a0 100644
--- a/src/include/parser/scanner.h
+++ b/src/include/parser/scanner.h
@@ -136,7 +136,8 @@ extern PGDLLIMPORT const uint16 ScanKeywordTokens[];
 extern core_yyscan_t scanner_init(const char *str,
 								  core_yy_extra_type *yyext,
 								  const ScanKeywordList *keywordlist,
-								  const uint16 *keyword_tokens);
+								  const uint16 *keyword_tokens,
+								  int offset);
 extern void scanner_finish(core_yyscan_t yyscanner);
 extern int	core_yylex(core_YYSTYPE *lvalp, YYLTYPE *llocp,
 					   core_yyscan_t yyscanner);
diff --git a/src/include/tcop/tcopprot.h b/src/include/tcop/tcopprot.h
index 131dc2b22e..27201dde1d 100644
--- a/src/include/tcop/tcopprot.h
+++ b/src/include/tcop/tcopprot.h
@@ -45,7 +45,8 @@ typedef enum
 extern PGDLLIMPORT int log_statement;
 
 /* Hook for plugins to get control in pg_parse_query() */
-typedef List *(*parser_hook_type) (const char *str, RawParseMode mode);
+typedef List *(*parser_hook_type) (const char *str, RawParseMode mode,
+								   int offset, bool *error);
 extern PGDLLIMPORT parser_hook_type parser_hook;
 
 extern List *pg_parse_query(const char *query_string);
diff --git a/src/pl/plpgsql/src/pl_gram.y b/src/pl/plpgsql/src/pl_gram.y
index 34e0520719..6e09f01370 100644
--- a/src/pl/plpgsql/src/pl_gram.y
+++ b/src/pl/plpgsql/src/pl_gram.y
@@ -3690,7 +3690,7 @@ check_sql_expr(const char *stmt, RawParseMode parseMode, int location)
 	error_context_stack = &syntax_errcontext;
 
 	oldCxt = MemoryContextSwitchTo(plpgsql_compile_tmp_cxt);
-	(void) raw_parser(stmt, parseMode);
+	(void) raw_parser(stmt, parseMode, 0);
 	MemoryContextSwitchTo(oldCxt);
 
 	/* Restore former ereport callback */
diff --git a/src/pl/plpgsql/src/pl_scanner.c b/src/pl/plpgsql/src/pl_scanner.c
index e4c7a91ab5..a2886c42ec 100644
--- a/src/pl/plpgsql/src/pl_scanner.c
+++ b/src/pl/plpgsql/src/pl_scanner.c
@@ -587,7 +587,7 @@ plpgsql_scanner_init(const char *str)
 {
 	/* Start up the core scanner */
 	yyscanner = scanner_init(str, &core_yy,
-							 &ReservedPLKeywords, ReservedPLKeywordTokens);
+							 &ReservedPLKeywords, ReservedPLKeywordTokens, 0);
 
 	/*
 	 * scanorig points to the original string, which unlike the scanner's
-- 
2.30.1

