diff --git a/src/pl/plpgsql/src/pl_gram.y b/src/pl/plpgsql/src/pl_gram.y index 6a09bfdd67..3ff12b7af9 100644 --- a/src/pl/plpgsql/src/pl_gram.y +++ b/src/pl/plpgsql/src/pl_gram.y @@ -23,6 +23,7 @@ #include "parser/scanner.h" #include "parser/scansup.h" #include "utils/builtins.h" +#include "utils/syscache.h" #include "plpgsql.h" @@ -76,6 +77,7 @@ static PLpgSQL_expr *read_sql_expression2(int until, int until2, int *endtoken); static PLpgSQL_expr *read_sql_stmt(void); static PLpgSQL_type *read_datatype(int tok); +static PLpgSQL_type *read_datatype_array(PLpgSQL_type *elem_typ); static PLpgSQL_stmt *make_execsql_stmt(int firsttoken, int location); static PLpgSQL_stmt_fetch *read_fetch_direction(void); static void complete_direction(PLpgSQL_stmt_fetch *fetch, @@ -2783,6 +2785,55 @@ read_sql_construct(int until, return expr; } +static PLpgSQL_type * +read_datatype_array(PLpgSQL_type *elem_typ) +{ + int tok; + HeapTuple type_tup = NULL; + Form_pg_type type_frm; + Oid arrtyp_oid; + + Assert(elem_typ); + if (!OidIsValid(elem_typ->typoid)) + return elem_typ; + + tok = yylex(); + /* Next token is not square bracket. */ + if (tok != '[') + { + plpgsql_push_back_token(tok); + + return elem_typ; + } + + tok = yylex(); + /* For now, deal only with []. */ + if (tok != ']') + { + plpgsql_push_back_token('['); + plpgsql_push_back_token(tok); + + return elem_typ; + } + + type_tup = SearchSysCache1(TYPEOID, + ObjectIdGetDatum(elem_typ->typoid)); + if (!HeapTupleIsValid(type_tup)) + return elem_typ; + + type_frm = (Form_pg_type) GETSTRUCT(type_tup); + arrtyp_oid = type_frm->typarray; + ReleaseSysCache(type_tup); + + if (OidIsValid(arrtyp_oid)) + return plpgsql_build_datatype(arrtyp_oid, + elem_typ->atttypmod, + elem_typ->collation, + NULL); + else + return elem_typ; +} + static PLpgSQL_type * read_datatype(int tok) { @@ -2818,7 +2869,9 @@ read_datatype(int tok) { result = plpgsql_parse_wordtype(dtname); if (result) - return result; + { + return read_datatype_array(result); + } } else if (tok_is_keyword(tok, &yylval, K_ROWTYPE, "rowtype")) @@ -2842,7 +2895,9 @@ read_datatype(int tok) { result = plpgsql_parse_wordtype(dtname); if (result) - return result; + { + return read_datatype_array(result); + } } else if (tok_is_keyword(tok, &yylval, K_ROWTYPE, "rowtype")) @@ -2866,7 +2921,9 @@ read_datatype(int tok) { result = plpgsql_parse_cwordtype(dtnames); if (result) - return result; + { + return read_datatype_array(result); + } } else if (tok_is_keyword(tok, &yylval, K_ROWTYPE, "rowtype")) diff --git a/src/test/regress/expected/plpgsql.out b/src/test/regress/expected/plpgsql.out index 272f5d2111..8db28c1122 100644 --- a/src/test/regress/expected/plpgsql.out +++ b/src/test/regress/expected/plpgsql.out @@ -5814,6 +5814,31 @@ SELECT * FROM list_partitioned_table() AS t; 2 (2 rows) +CREATE OR REPLACE FUNCTION array_partitioned_table() +RETURNS SETOF partitioned_table.a%TYPE AS $$ +DECLARE + i int; + row partitioned_table%ROWTYPE; + a_val partitioned_table.a%TYPE[]; + b_val partitioned_table.a%TYPE; + c_val b_val%TYPE[]; +BEGIN + i := 1; + FOR row IN SELECT * FROM partitioned_table ORDER BY a LOOP + a_val[i] := row.a; + c_val[i] := a_val[i]; + i = i + 1; + END LOOP; + RETURN QUERY SELECT unnest(c_val); +END; $$ LANGUAGE plpgsql; +NOTICE: type reference partitioned_table.a%TYPE converted to integer +SELECT * FROM array_partitioned_table() AS t; + t +--- + 1 + 2 +(2 rows) + -- -- Check argument name is used instead of $n in error message -- diff --git a/src/test/regress/sql/plpgsql.sql b/src/test/regress/sql/plpgsql.sql index 924d524094..7b4df77d85 100644 --- a/src/test/regress/sql/plpgsql.sql +++ b/src/test/regress/sql/plpgsql.sql @@ -4748,6 +4748,26 @@ END; $$ LANGUAGE plpgsql; SELECT * FROM list_partitioned_table() AS t; +CREATE OR REPLACE FUNCTION array_partitioned_table() +RETURNS SETOF partitioned_table.a%TYPE AS $$ +DECLARE + i int; + row partitioned_table%ROWTYPE; + a_val partitioned_table.a%TYPE[]; + b_val partitioned_table.a%TYPE; + c_val b_val%TYPE[]; +BEGIN + i := 1; + FOR row IN SELECT * FROM partitioned_table ORDER BY a LOOP + a_val[i] := row.a; + c_val[i] := a_val[i]; + i = i + 1; + END LOOP; + RETURN QUERY SELECT unnest(c_val); +END; $$ LANGUAGE plpgsql; + +SELECT * FROM array_partitioned_table() AS t; + -- -- Check argument name is used instead of $n in error message --