From d73e0a9463332d7f25dc42b5bca7e1878d0ee7fe Mon Sep 17 00:00:00 2001
From: Alexander Pyhalov <a.pyhalov@postgrespro.ru>
Date: Fri, 14 Feb 2025 10:08:52 +0300
Subject: [PATCH 4/4] Handle SQL functions which are modified between rewrite
 and plan stages.

Query can be modified between rewrite and plan stages by
check_sql_fn_retval(). If later revalidation is considered by
RevalidateCachedQuery(), modifications, done by check_sql_fn_retval(),
could be lost.

To fix this issue

1) don't cache plans, which were modified by check_sql_fn_retval();
2) introduce callback in RevalidateCachedQuery(), which calls
check_sql_fn_retval() for non-saved plans, which were invalidated;
3) build plans with fixed_result = true, so that if target list
was mistakenly changed on revalidation, we throw an error instead
of crashing.
---
 src/backend/catalog/pg_proc.c        |   2 +-
 src/backend/executor/functions.c     | 134 +++++++++++++++++++++++++--
 src/backend/optimizer/util/clauses.c |   4 +-
 src/backend/utils/cache/plancache.c  |  24 +++++
 src/include/executor/functions.h     |   3 +-
 src/include/utils/plancache.h        |   9 ++
 6 files changed, 165 insertions(+), 11 deletions(-)

diff --git a/src/backend/catalog/pg_proc.c b/src/backend/catalog/pg_proc.c
index fe0490259e9..196a01ca803 100644
--- a/src/backend/catalog/pg_proc.c
+++ b/src/backend/catalog/pg_proc.c
@@ -960,7 +960,7 @@ fmgr_sql_validator(PG_FUNCTION_ARGS)
 			(void) check_sql_fn_retval(querytree_list,
 									   rettype, rettupdesc,
 									   proc->prokind,
-									   false, NULL);
+									   false, NULL, NULL);
 		}
 
 		error_context_stack = sqlerrcontext.previous;
diff --git a/src/backend/executor/functions.c b/src/backend/executor/functions.c
index 7fe49dbe17c..f2e3e601a65 100644
--- a/src/backend/executor/functions.c
+++ b/src/backend/executor/functions.c
@@ -174,9 +174,18 @@ typedef struct SQLFunctionPlanEntry
 	 */
 	SQLFunctionParseInfoPtr pinfo;	/* cached information about arguments */
 
-	MemoryContext entry_ctx;	/* memory context for allocated fields of this entry */
+	MemoryContext entry_ctx;	/* memory context for allocated fields of this
+								 * entry */
 }			SQLFunctionPlanEntry;
 
+/* Data necessary to plansource after-rewrite callback to modify query targetlist */
+typedef struct SQLFunctionPlanSourceCallbackData
+{
+	Oid			rettype;		/* function return type */
+	TupleDesc	rettupdesc;		/* function return record type */
+	char		prokind;		/* function kind */
+}			SQLFunctionPlanSourceCallbackData;
+
 static HTAB *sql_plan_cache_htab = NULL;
 
 /* non-export function prototypes */
@@ -212,14 +221,18 @@ static void sqlfunction_shutdown(DestReceiver *self);
 static void sqlfunction_destroy(DestReceiver *self);
 
 /* SQL-functions plan cache-related routines */
-static void compute_plan_entry_key(SQLFunctionPlanKey *hashkey, FunctionCallInfo fcinfo, Form_pg_proc procedureStruct);
-static SQLFunctionPlanEntry *get_cached_plan_entry(SQLFunctionPlanKey *hashkey);
+static void compute_plan_entry_key(SQLFunctionPlanKey * hashkey, FunctionCallInfo fcinfo, Form_pg_proc procedureStruct);
+static SQLFunctionPlanEntry * get_cached_plan_entry(SQLFunctionPlanKey * hashkey);
 static void save_cached_plan_entry(SQLFunctionPlanKey * hashkey, HeapTuple procedureTuple, List *plansource_list, List *result_tlist, bool returnsTuple, SQLFunctionParseInfoPtr pinfo, MemoryContext alianable_context);
 static void delete_cached_plan_entry(SQLFunctionPlanEntry * entry);
 
 static bool check_sql_fn_retval_matches(List *tlist, Oid rettype, TupleDesc rettupdesc, char prokind);
 static bool target_entry_has_compatible_type(TargetEntry *tle, Oid res_type, int32 res_typmod);
 
+static void plancache_rewrite_cb(struct CachedPlanSource *plansource, List *tlist, void *arg);
+
+static void register_plancache_cb(List *queryTree_list, List *plansource_list, Oid rettype, TupleDesc rettupdesc, char prokind);
+
 /*
  * Fill array of arguments with actual function argument types oids
  */
@@ -837,6 +850,66 @@ target_entry_has_compatible_type(TargetEntry *tle, Oid res_type, int32 res_typmo
 	return result;
 }
 
+/* Rewrite queries when plan is revalidated */
+static void
+plancache_rewrite_cb(struct CachedPlanSource *plansource, List *queries, void *arg)
+{
+	SQLFunctionPlanSourceCallbackData *cbdata = (SQLFunctionPlanSourceCallbackData *) arg;
+
+	/* Shouldn't happen */
+	if (cbdata == NULL)
+		elog(ERROR, "plancache callback data is missing");
+
+	check_sql_fn_retval(list_make1(queries) /* expects list of lists */ , cbdata->rettype, cbdata->rettupdesc, cbdata->prokind, false, NULL, NULL);
+}
+
+
+/*
+ * Register plancache callback which fires after query rewrite if plansource is invalidated.
+ */
+static void
+register_plancache_cb(List *queryTree_list, List *plansource_list, Oid rettype, TupleDesc rettupdesc, char prokind)
+{
+	ListCell   *qlc;
+	ListCell   *slc;
+	ListCell   *plc;
+	CachedPlanSource *modified_plansource = NULL;
+	SQLFunctionPlanSourceCallbackData *cbdata = NULL;
+
+	/* find plansource, which was modified by check_sql_fn_retval() */
+	forboth(qlc, queryTree_list, plc, plansource_list)
+	{
+		List	   *sublist = lfirst_node(List, qlc);
+
+		foreach(slc, sublist)
+		{
+			Query	   *q = lfirst_node(Query, slc);
+
+			if (q->canSetTag)
+				modified_plansource = (CachedPlanSource *) lfirst(plc);
+		}
+	}
+
+	/*
+	 * We've modified some queries, so now should find corresponding
+	 * plansource
+	 */
+	if (modified_plansource == NULL)
+		elog(ERROR, "couldn't find plansource, corresponding to query with modified targetlist");
+
+	/*
+	 * We don't care much about persistent storage for callback data or
+	 * finding out actual rettype and rettupdesc as such callback is only
+	 * registered for plans, which are not saved.
+	 */
+	cbdata = palloc(sizeof(SQLFunctionPlanSourceCallbackData));
+	cbdata->rettype = rettype;
+	cbdata->rettupdesc = rettupdesc;
+	cbdata->prokind = prokind;
+
+	CachedPlanRegisterPostRewriteCallback(modified_plansource, plancache_rewrite_cb, (void *) cbdata);
+}
+
 /*
  * Check if result tlist would be changed by check_sql_fn_retval()
  */
@@ -1063,6 +1136,7 @@ init_sql_fcache(FunctionCallInfo fcinfo, Oid collation, bool *lazyEvalOK)
 	else
 	{
 		MemoryContext alianable_context = fcontext;
+		bool		tlist_was_modified = false;
 
 		/* We need to preserve parse info */
 		if (use_plan_cache)
@@ -1179,7 +1253,8 @@ init_sql_fcache(FunctionCallInfo fcinfo, Oid collation, bool *lazyEvalOK)
 												   rettupdesc,
 												   procedureStruct->prokind,
 												   false,
-												   &resulttlist);
+												   &resulttlist,
+												   &tlist_was_modified);
 
 		/*
 		 * Queries could be rewritten by check_sql_fn_retval(). Now when they
@@ -1196,7 +1271,13 @@ init_sql_fcache(FunctionCallInfo fcinfo, Oid collation, bool *lazyEvalOK)
 				CachedPlanSource *plansource = lfirst(plc);
 
 
-				/* Finish filling in the CachedPlanSource */
+				/*
+				 * Finish filling in the CachedPlanSource. We force fixed
+				 * result type to be sure that query rewriting by
+				 * sql_fn_retval_cb() leads to the same result. And it should,
+				 * as callback is used only when plansource is not saved
+				 * (during one function call).
+				 */
 				CompleteCachedPlan(plansource,
 								   queryTree_sublist,
 								   NULL,
@@ -1205,10 +1286,24 @@ init_sql_fcache(FunctionCallInfo fcinfo, Oid collation, bool *lazyEvalOK)
 								   (ParserSetupHook) sql_fn_parser_setup,
 								   fcache->pinfo,
 								   CURSOR_OPT_PARALLEL_OK | CURSOR_OPT_NO_SCROLL,
-								   false);
+								   true);
+
 			}
 		}
 
+		if (tlist_was_modified)
+		{
+			/* Avoid caching plan if check_sql_fn_retval() has modified query */
+			use_plan_cache = false;
+
+			/*
+			 * Now we know that target query was modified. Find corresponding
+			 * plansource and add callback, which would reapply these
+			 * modifications if plan is invalidated.
+			 */
+			register_plancache_cb(queryTree_list, plansource_list, rettype, rettupdesc, procedureStruct->prokind);
+		}
+
 		/* If we can possibly use cached plan entry, save it. */
 		if (use_plan_cache)
 			save_cached_plan_entry(&plan_cache_entry_key, procedureTuple, plansource_list, resulttlist, fcache->returnsTuple, fcache->pinfo, alianable_context);
@@ -2139,7 +2234,8 @@ check_sql_fn_retval(List *queryTreeLists,
 					Oid rettype, TupleDesc rettupdesc,
 					char prokind,
 					bool insertDroppedCols,
-					List **resultTargetList)
+					List **resultTargetList,
+					bool *targetListModified)
 {
 	bool		is_tuple_result = false;
 	Query	   *parse;
@@ -2151,6 +2247,10 @@ check_sql_fn_retval(List *queryTreeLists,
 	List	   *upper_tlist = NIL;
 	bool		upper_tlist_nontrivial = false;
 	ListCell   *lc;
+	List	   *tlist_copy = NIL;
+
+	if (targetListModified)
+		*targetListModified = false;
 
 	if (resultTargetList)
 		*resultTargetList = NIL;	/* initialize in case of VOID result */
@@ -2240,6 +2340,14 @@ check_sql_fn_retval(List *queryTreeLists,
 	 * just does a projection.
 	 */
 
+	/*
+	 * If caller wants to check  check if tlist was modified, we have no much
+	 * choice except copying original tlist and compare (as tlist could be
+	 * modified in place).
+	 */
+	if (targetListModified)
+		tlist_copy = copyObject(tlist);
+
 	/*
 	 * Count the non-junk entries in the result targetlist.
 	 */
@@ -2343,6 +2451,12 @@ check_sql_fn_retval(List *queryTreeLists,
 			/* Return tlist if requested */
 			if (resultTargetList)
 				*resultTargetList = tlist;
+			if (targetListModified)
+			{
+				*targetListModified = !equal(tlist, tlist_copy);
+
+				list_free_deep(tlist_copy);
+			}
 			return true;
 		}
 
@@ -2514,6 +2628,12 @@ tlist_coercion_finished:
 	/* Return tlist (possibly modified) if requested */
 	if (resultTargetList)
 		*resultTargetList = upper_tlist;
+	if (targetListModified)
+	{
+		*targetListModified = upper_tlist_nontrivial || !equal(tlist, tlist_copy);
+
+		list_free_deep(tlist_copy);
+	}
 
 	return is_tuple_result;
 }
diff --git a/src/backend/optimizer/util/clauses.c b/src/backend/optimizer/util/clauses.c
index 43dfecfb47f..0b7c9657c8c 100644
--- a/src/backend/optimizer/util/clauses.c
+++ b/src/backend/optimizer/util/clauses.c
@@ -4742,7 +4742,7 @@ inline_function(Oid funcid, Oid result_type, Oid result_collid,
 	if (check_sql_fn_retval(list_make1(querytree_list),
 							result_type, rettupdesc,
 							funcform->prokind,
-							false, NULL))
+							false, NULL, NULL))
 		goto fail;				/* reject whole-tuple-result cases */
 
 	/*
@@ -5288,7 +5288,7 @@ inline_set_returning_function(PlannerInfo *root, RangeTblEntry *rte)
 	if (!check_sql_fn_retval(list_make1(querytree_list),
 							 fexpr->funcresulttype, rettupdesc,
 							 funcform->prokind,
-							 true, NULL) &&
+							 true, NULL, NULL) &&
 		(functypclass == TYPEFUNC_COMPOSITE ||
 		 functypclass == TYPEFUNC_COMPOSITE_DOMAIN ||
 		 functypclass == TYPEFUNC_RECORD))
diff --git a/src/backend/utils/cache/plancache.c b/src/backend/utils/cache/plancache.c
index da65a6010bd..51344caf411 100644
--- a/src/backend/utils/cache/plancache.c
+++ b/src/backend/utils/cache/plancache.c
@@ -279,6 +279,8 @@ CreateCachedPlan(RawStmt *raw_parse_tree,
 	plansource->total_custom_cost = 0;
 	plansource->num_generic_plans = 0;
 	plansource->num_custom_plans = 0;
+	plansource->post_rewrite_cb = NULL;
+	plansource->post_rewrite_cb_arg = NULL;
 
 	MemoryContextSwitchTo(oldcxt);
 
@@ -306,6 +308,20 @@ CreateCachedPlanForQuery(Query *analyzed_parse_tree,
 	return plansource;
 }
 
+/*
+ * CachedPlanRegisterPostRewriteCallback() registers plansource
+ * callback which fires after rewriting query during plan
+ * revalidation.
+ */
+void
+CachedPlanRegisterPostRewriteCallback(CachedPlanSource *plansource,
+						post_rewrite_cb_type post_rewrite_cb,
+						void* post_rewrite_cb_arg)
+{
+	plansource->post_rewrite_cb = post_rewrite_cb;
+	plansource->post_rewrite_cb_arg = post_rewrite_cb_arg;
+}
+
 /*
  * CreateOneShotCachedPlan: initially create a one-shot plan cache entry.
  *
@@ -787,6 +803,11 @@ RevalidateCachedQuery(CachedPlanSource *plansource,
 												   plansource->num_params,
 												   queryEnv);
 
+
+	/* Call post-rewrite callback */
+	if (plansource->post_rewrite_cb)
+		plansource->post_rewrite_cb(plansource, tlist, plansource->post_rewrite_cb_arg);
+
 	/* Release snapshot if we got one */
 	if (snapshot_set)
 		PopActiveSnapshot();
@@ -1663,6 +1684,9 @@ CopyCachedPlan(CachedPlanSource *plansource)
 	newsource->rewriteRowSecurity = plansource->rewriteRowSecurity;
 	newsource->dependsOnRLS = plansource->dependsOnRLS;
 
+	newsource->post_rewrite_cb = plansource->post_rewrite_cb;
+	newsource->post_rewrite_cb_arg = plansource->post_rewrite_cb_arg;
+
 	newsource->gplan = NULL;
 
 	newsource->is_oneshot = false;
diff --git a/src/include/executor/functions.h b/src/include/executor/functions.h
index a6ae2e72d79..144a9b91fba 100644
--- a/src/include/executor/functions.h
+++ b/src/include/executor/functions.h
@@ -49,7 +49,8 @@ extern bool check_sql_fn_retval(List *queryTreeLists,
 								Oid rettype, TupleDesc rettupdesc,
 								char prokind,
 								bool insertDroppedCols,
-								List **resultTargetList);
+								List **resultTargetList,
+								bool *targetListModified);
 
 extern DestReceiver *CreateSQLFunctionDestReceiver(void);
 
diff --git a/src/include/utils/plancache.h b/src/include/utils/plancache.h
index 1493f726649..8f0853062c8 100644
--- a/src/include/utils/plancache.h
+++ b/src/include/utils/plancache.h
@@ -35,6 +35,9 @@ typedef enum
 	PLAN_CACHE_MODE_FORCE_CUSTOM_PLAN,
 }			PlanCacheMode;
 
+struct CachedPlanSource;
+typedef void (*post_rewrite_cb_type) (struct CachedPlanSource *plansource, List *tlist, void *arg);
+
 /* GUC parameter */
 extern PGDLLIMPORT int plan_cache_mode;
 
@@ -134,6 +137,10 @@ typedef struct CachedPlanSource
 	double		total_custom_cost;	/* total cost of custom plans so far */
 	int64		num_custom_plans;	/* # of custom plans included in total */
 	int64		num_generic_plans;	/* # of generic plans */
+
+	/* Post-rewrite callback */
+	post_rewrite_cb_type 	post_rewrite_cb;
+	void 		*post_rewrite_cb_arg;	/* post-rewrite callback argument */
 } CachedPlanSource;
 
 /*
@@ -211,6 +218,8 @@ extern void CompleteCachedPlan(CachedPlanSource *plansource,
 							   int cursor_options,
 							   bool fixed_result);
 
+extern void CachedPlanRegisterPostRewriteCallback(CachedPlanSource *plansource,	post_rewrite_cb_type post_rewrite_cb, void* post_rewrite_cb_arg);
+
 extern void SaveCachedPlan(CachedPlanSource *plansource);
 extern void DropCachedPlan(CachedPlanSource *plansource);
 
-- 
2.43.0

