From 5d63ecc1642818f4de006faf0c39c23124d1c0e5 Mon Sep 17 00:00:00 2001
From: Michael Paquier <michael@otacoo.com>
Date: Sat, 17 Jan 2015 22:52:16 +0900
Subject: [PATCH] Fix compatibility checks of connectby in tablefunc

Coverity has pointed out a block of dead code that was used to check the
compatibility of the input and return types of connectby. The check done
simply compared the input and output OIDs, returning an error if the type
OIDs do not match, however it seems more solid to check for a cast between
those two types, to avoid at the same time errors related to data input.
---
 contrib/tablefunc/expected/tablefunc.out |  8 ++++
 contrib/tablefunc/sql/tablefunc.sql      |  6 +++
 contrib/tablefunc/tablefunc.c            | 64 ++++++++++++++++++--------------
 3 files changed, 50 insertions(+), 28 deletions(-)

diff --git a/contrib/tablefunc/expected/tablefunc.out b/contrib/tablefunc/expected/tablefunc.out
index 0437ecf..75eff3e 100644
--- a/contrib/tablefunc/expected/tablefunc.out
+++ b/contrib/tablefunc/expected/tablefunc.out
@@ -376,6 +376,14 @@ SELECT * FROM connectby('connectby_int', 'keyid', 'parent_keyid', '2', 4, '~') A
     11 |           10 |     4 | 2~5~9~10~11
 (8 rows)
 
+-- should fail as first two columns must have the same type
+SELECT * FROM connectby('connectby_int', 'keyid', 'parent_keyid', '2', 0, '~') AS t(keyid text, parent_keyid int, level int, branch text);
+ERROR:  invalid return type
+DETAIL:  First two columns must be the same type.
+-- should fail as key field datatype should match return datatype
+SELECT * FROM connectby('connectby_int', 'keyid', 'parent_keyid', '2', 0, '~') AS t(keyid json, parent_keyid json, level int, branch text);
+ERROR:  invalid return type
+DETAIL:  Failed to cast integer to json
 -- test for falsely detected recursion
 DROP TABLE connectby_int;
 CREATE TABLE connectby_int(keyid int, parent_keyid int);
diff --git a/contrib/tablefunc/sql/tablefunc.sql b/contrib/tablefunc/sql/tablefunc.sql
index bf874f2..aff0699 100644
--- a/contrib/tablefunc/sql/tablefunc.sql
+++ b/contrib/tablefunc/sql/tablefunc.sql
@@ -179,6 +179,12 @@ SELECT * FROM connectby('connectby_int', 'keyid', 'parent_keyid', '2', 0, '~') A
 -- infinite recursion failure avoided by depth limit
 SELECT * FROM connectby('connectby_int', 'keyid', 'parent_keyid', '2', 4, '~') AS t(keyid int, parent_keyid int, level int, branch text);
 
+-- should fail as first two columns must have the same type
+SELECT * FROM connectby('connectby_int', 'keyid', 'parent_keyid', '2', 0, '~') AS t(keyid text, parent_keyid int, level int, branch text);
+
+-- should fail as key field datatype should match return datatype
+SELECT * FROM connectby('connectby_int', 'keyid', 'parent_keyid', '2', 0, '~') AS t(keyid json, parent_keyid json, level int, branch text);
+
 -- test for falsely detected recursion
 DROP TABLE connectby_int;
 CREATE TABLE connectby_int(keyid int, parent_keyid int);
diff --git a/contrib/tablefunc/tablefunc.c b/contrib/tablefunc/tablefunc.c
index 3388fab..3677d96 100644
--- a/contrib/tablefunc/tablefunc.c
+++ b/contrib/tablefunc/tablefunc.c
@@ -40,7 +40,10 @@
 #include "funcapi.h"
 #include "lib/stringinfo.h"
 #include "miscadmin.h"
+#include "nodes/makefuncs.h"
+#include "parser/parse_coerce.h"
 #include "utils/builtins.h"
+#include "utils/lsyscache.h"
 
 #include "tablefunc.h"
 
@@ -54,7 +57,7 @@ static Tuplestorestate *get_crosstab_tuplestore(char *sql,
 						bool randomAccess);
 static void validateConnectbyTupleDesc(TupleDesc tupdesc, bool show_branch, bool show_serial);
 static bool compatCrosstabTupleDescs(TupleDesc tupdesc1, TupleDesc tupdesc2);
-static bool compatConnectbyTupleDescs(TupleDesc tupdesc1, TupleDesc tupdesc2);
+static void compatConnectbyTupleDescs(TupleDesc tupdesc1, TupleDesc tupdesc2);
 static void get_normal_pair(float8 *x1, float8 *x2);
 static Tuplestorestate *connectby(char *relname,
 		  char *key_fld,
@@ -1317,20 +1320,14 @@ build_tuplestore_recursively(char *key_fld,
 		StringInfoData chk_current_key;
 
 		/* First time through, do a little more setup */
-		if (level == 0)
+		if (level == 1)
 		{
 			/*
 			 * Check that return tupdesc is compatible with the one we got
-			 * from the query, but only at level 0 -- no need to check more
-			 * than once
+			 * from the query, but only at the first level -- no need to check
+			 * more than once
 			 */
-
-			if (!compatConnectbyTupleDescs(tupdesc, spi_tupdesc))
-				ereport(ERROR,
-						(errcode(ERRCODE_SYNTAX_ERROR),
-						 errmsg("invalid return type"),
-						 errdetail("Return and SQL tuple descriptions are " \
-								   "incompatible.")));
+			compatConnectbyTupleDescs(tupdesc, spi_tupdesc);
 		}
 
 		initStringInfo(&branchstr);
@@ -1486,36 +1483,47 @@ validateConnectbyTupleDesc(TupleDesc tupdesc, bool show_branch, bool show_serial
 }
 
 /*
- * Check if spi sql tupdesc and return tupdesc are compatible
+ * Check if spi sql tupdesc and return tupdesc are compatible and
+ * if implicit casts can be used.
  */
-static bool
+static void
 compatConnectbyTupleDescs(TupleDesc ret_tupdesc, TupleDesc sql_tupdesc)
 {
 	Oid			ret_atttypid;
 	Oid			sql_atttypid;
+	int32		ret_typmod;
+	Oid			ret_basetype;
+	Oid			ret_basecoll;
+	Expr		*defval;
 
-	/* check the key_fld types match */
+	/*
+	 * Note that the attribute type of the two first columns must match
+	 * (see validateConnectbyTupleDesc).
+	 */
+	Assert(sql_tupdesc->attrs[0]->atttypid == sql_tupdesc->attrs[1]->atttypid);
+
+	/* check compatibility of the the key fields */
 	ret_atttypid = ret_tupdesc->attrs[0]->atttypid;
 	sql_atttypid = sql_tupdesc->attrs[0]->atttypid;
-	if (ret_atttypid != sql_atttypid)
+	ret_basetype = getBaseTypeAndTypmod(ret_atttypid, &ret_typmod);
+	ret_basecoll = get_typcollation(ret_basetype);
+	defval = (Expr *) makeNullConst(ret_basetype, ret_typmod, ret_basecoll);
+	defval = (Expr *) coerce_to_target_type(NULL, (Node *) defval,
+								   sql_atttypid,
+								   ret_basetype, ret_typmod,
+								   COERCION_EXPLICIT,
+								   COERCE_EXPLICIT_CAST,
+								   -1);
+
+	if (defval == NULL)
 		ereport(ERROR,
 				(errcode(ERRCODE_SYNTAX_ERROR),
 				 errmsg("invalid return type"),
-				 errdetail("SQL key field datatype does " \
-						   "not match return key field datatype.")));
-
-	/* check the parent_key_fld types match */
-	ret_atttypid = ret_tupdesc->attrs[1]->atttypid;
-	sql_atttypid = sql_tupdesc->attrs[1]->atttypid;
-	if (ret_atttypid != sql_atttypid)
-		ereport(ERROR,
-				(errcode(ERRCODE_SYNTAX_ERROR),
-				 errmsg("invalid return type"),
-				 errdetail("SQL parent key field datatype does " \
-						   "not match return parent key field datatype.")));
+				 errdetail("Failed to cast %s to %s",
+						   format_type_be(sql_atttypid),
+						   format_type_be(ret_atttypid))));
 
 	/* OK, the two tupdescs are compatible for our purposes */
-	return true;
 }
 
 /*
-- 
2.2.2

