*** a/src/pl/plpgsql/src/pl_gram.y --- b/src/pl/plpgsql/src/pl_gram.y *************** *** 22,27 **** --- 22,28 ---- #include "parser/scanner.h" #include "parser/scansup.h" #include "utils/builtins.h" + #include "nodes/nodefuncs.h" /* Location tracking support --- simpler than bison's default */ *************** *** 97,103 **** static PLpgSQL_row *make_scalar_list1(char *initial_name, PLpgSQL_datum *initial_datum, int lineno, int location); static void check_sql_expr(const char *stmt, int location, ! int leaderlen); static void plpgsql_sql_error_callback(void *arg); static PLpgSQL_type *parse_datatype(const char *string, int location); static void check_labels(const char *start_label, --- 98,104 ---- PLpgSQL_datum *initial_datum, int lineno, int location); static void check_sql_expr(const char *stmt, int location, ! int leaderlen, PLpgSQL_row *check_row); static void plpgsql_sql_error_callback(void *arg); static PLpgSQL_type *parse_datatype(const char *string, int location); static void check_labels(const char *start_label, *************** *** 106,111 **** static void check_labels(const char *start_label, --- 107,115 ---- static PLpgSQL_expr *read_cursor_args(PLpgSQL_var *cursor, int until, const char *expected); static List *read_raise_options(void); + static bool find_a_star_walker(Node *node, void *context); + static int tlist_result_column_count(Node *stmt); + %} *************** *** 1408,1414 **** for_control : for_variable K_IN PLpgSQL_stmt_fori *new; /* Check first expression is well-formed */ ! check_sql_expr(expr1->query, expr1loc, 7); /* Read and check the second one */ expr2 = read_sql_expression2(K_LOOP, K_BY, --- 1412,1418 ---- PLpgSQL_stmt_fori *new; /* Check first expression is well-formed */ ! check_sql_expr(expr1->query, expr1loc, 7, NULL); /* Read and check the second one */ expr2 = read_sql_expression2(K_LOOP, K_BY, *************** *** 1470,1476 **** for_control : for_variable K_IN pfree(expr1->query); expr1->query = tmp_query; ! check_sql_expr(expr1->query, expr1loc, 0); new = palloc0(sizeof(PLpgSQL_stmt_fors)); new->cmd_type = PLPGSQL_STMT_FORS; --- 1474,1480 ---- pfree(expr1->query); expr1->query = tmp_query; ! check_sql_expr(expr1->query, expr1loc, 0, NULL); new = palloc0(sizeof(PLpgSQL_stmt_fors)); new->cmd_type = PLPGSQL_STMT_FORS; *************** *** 2562,2568 **** read_sql_construct(int until, pfree(ds.data); if (valid_sql) ! check_sql_expr(expr->query, startlocation, strlen(sqlstart)); return expr; } --- 2566,2572 ---- pfree(ds.data); if (valid_sql) ! check_sql_expr(expr->query, startlocation, strlen(sqlstart), NULL); return expr; } *************** *** 2785,2791 **** make_execsql_stmt(int firsttoken, int location) expr->ns = plpgsql_ns_top(); pfree(ds.data); ! check_sql_expr(expr->query, location, 0); execsql = palloc(sizeof(PLpgSQL_stmt_execsql)); execsql->cmd_type = PLPGSQL_STMT_EXECSQL; --- 2789,2795 ---- expr->ns = plpgsql_ns_top(); pfree(ds.data); ! check_sql_expr(expr->query, location, 0, have_into ? row : NULL); execsql = palloc(sizeof(PLpgSQL_stmt_execsql)); execsql->cmd_type = PLPGSQL_STMT_EXECSQL; *************** *** 3379,3391 **** make_scalar_list1(char *initial_name, * (typically "SELECT ") prefixed to the source text. We use this assumption * to transpose any error cursor position back to the function source text. * If no error cursor is provided, we'll just point at "location". */ static void ! check_sql_expr(const char *stmt, int location, int leaderlen) { sql_error_callback_arg cbarg; ErrorContextCallback syntax_errcontext; MemoryContext oldCxt; if (!plpgsql_check_syntax) return; --- 3383,3400 ---- * (typically "SELECT ") prefixed to the source text. We use this assumption * to transpose any error cursor position back to the function source text. * If no error cursor is provided, we'll just point at "location". + * + * If check_row is specified, the statement's result column count is compared + * against it when the exact number of columns is known. If the columns counts + * don't match, a syntax error is raised. */ static void ! check_sql_expr(const char *stmt, int location, int leaderlen, PLpgSQL_row *check_row) { sql_error_callback_arg cbarg; ErrorContextCallback syntax_errcontext; MemoryContext oldCxt; + List *raw_parsetree_list; if (!plpgsql_check_syntax) return; *************** *** 3399,3411 **** check_sql_expr(const char *stmt, int location, int leaderlen) error_context_stack = &syntax_errcontext; oldCxt = MemoryContextSwitchTo(compile_tmp_cxt); ! (void) raw_parser(stmt); MemoryContextSwitchTo(oldCxt); /* Restore former ereport callback */ error_context_stack = syntax_errcontext.previous; } static void plpgsql_sql_error_callback(void *arg) { --- 3408,3497 ---- error_context_stack = &syntax_errcontext; oldCxt = MemoryContextSwitchTo(compile_tmp_cxt); ! raw_parsetree_list = raw_parser(stmt); ! if (check_row != NULL) ! { ! Node *raw_parse_tree; ! int ncols; ! int fnum; ! int expected_ncols = 0; ! ! for (fnum = 0; fnum < check_row->nfields; fnum++) ! { ! if (check_row->varnos[fnum] < 0) ! continue; ! expected_ncols++; ! } ! ! raw_parse_tree = linitial(raw_parsetree_list); ! ncols = tlist_result_column_count(raw_parse_tree); ! if (ncols >= 0 && ncols > expected_ncols) ! ereport(ERROR, ! (errcode(ERRCODE_SYNTAX_ERROR), ! errmsg("query has more expressions than expected by the INTO target"))); ! if (ncols >= 0 && ncols < expected_ncols) ! ereport(ERROR, ! (errcode(ERRCODE_SYNTAX_ERROR), ! errmsg("INTO target expects more expressions than the query has"))); ! } MemoryContextSwitchTo(oldCxt); /* Restore former ereport callback */ error_context_stack = syntax_errcontext.previous; } + /* + * Expression tree walker for tlist_result_column_count. Returns true if the + * expression tree contains an A_Star node, false otherwise. + */ + static bool + find_a_star_walker(Node *node, void *context) + { + if (node == NULL) + return false; + if (IsA(node, A_Star)) + return true; + if (IsA(node, ColumnRef)) + { + ColumnRef *ref = (ColumnRef *) node; + /* A_Star can only be the last element */ + if (IsA(llast(ref->fields), A_Star)) + return true; + } + return raw_expression_tree_walker((Node *) node, + find_a_star_walker, + context); + } + + /* + * Find the number of columns in a raw statement's targetList (if SELECT) or + * returningList (if INSERT, UPDATE or DELETE). Returns -1 if the number of + * columns could not be determined because of an A_Star. + */ + static int + tlist_result_column_count(Node *stmt) + { + List *tlist; + + if (IsA(stmt, SelectStmt)) + tlist = ((SelectStmt *) stmt)->targetList; + else if (IsA(stmt, InsertStmt)) + tlist = ((InsertStmt *) stmt)->returningList; + else if (IsA(stmt, UpdateStmt)) + tlist = ((UpdateStmt *) stmt)->returningList; + else if (IsA(stmt, DeleteStmt)) + tlist = ((DeleteStmt *) stmt)->returningList; + else + elog(ERROR, "unknown nodeTag %d", nodeTag(stmt)); + + if (tlist == NIL) + return 0; + + if (raw_expression_tree_walker((Node *) tlist, find_a_star_walker, NULL)) + return -1; + return list_length(tlist); + } + static void plpgsql_sql_error_callback(void *arg) { *** a/src/test/regress/expected/plpgsql.out --- b/src/test/regress/expected/plpgsql.out *************** *** 5218,5220 **** NOTICE: outer_func() done --- 5218,5280 ---- drop function outer_outer_func(int); drop function outer_func(int); drop function inner_func(int); + create temporary table somecols(a int, b int); + -- INTO column counts + create or replace function into_counts() + returns void as $$ + declare + myresult int; + myrec record; + begin + -- not an error + select 1 into myresult; + -- we don't know the column count without parse analysis, not an error + select * into myresult from somecols; + -- records are OK + select 1, 2 into myrec; + -- error + select a, b into myresult from somecols; + end; + $$ language plpgsql; + ERROR: query has more expressions than expected by the INTO target + LINE 14: select a, b into myresult from somecols; + ^ + create or replace function into_counts() + returns void as $$ + declare + myresult int; + begin + -- error + select into myresult from somecols; + end; + $$ language plpgsql; + ERROR: INTO target expects more expressions than the query has + LINE 7: select into myresult from somecols; + ^ + create or replace function into_counts() + returns void as $$ + declare + myresult somecols; + begin + -- not error + select 1, 2 into myresult; + -- error + select 1 into myresult; + end; + $$ language plpgsql; + ERROR: INTO target expects more expressions than the query has + LINE 9: select 1 into myresult; + ^ + create or replace function into_counts() + returns void as $$ + declare + myresult somecols; + begin + -- error + select 1, 2, 3 into myresult; + end; + $$ language plpgsql; + ERROR: query has more expressions than expected by the INTO target + LINE 7: select 1, 2, 3 into myresult; + ^ + drop table somecols; *** a/src/test/regress/sql/plpgsql.sql --- b/src/test/regress/sql/plpgsql.sql *************** *** 4104,4106 **** select outer_outer_func(20); --- 4104,4160 ---- drop function outer_outer_func(int); drop function outer_func(int); drop function inner_func(int); + + create temporary table somecols(a int, b int); + + -- INTO column counts + create or replace function into_counts() + returns void as $$ + declare + myresult int; + myrec record; + begin + -- not an error + select 1 into myresult; + -- we don't know the column count without parse analysis, not an error + select * into myresult from somecols; + -- records are OK + select 1, 2 into myrec; + -- error + select a, b into myresult from somecols; + end; + $$ language plpgsql; + + create or replace function into_counts() + returns void as $$ + declare + myresult int; + begin + -- error + select into myresult from somecols; + end; + $$ language plpgsql; + + create or replace function into_counts() + returns void as $$ + declare + myresult somecols; + begin + -- not error + select 1, 2 into myresult; + -- error + select 1 into myresult; + end; + $$ language plpgsql; + + create or replace function into_counts() + returns void as $$ + declare + myresult somecols; + begin + -- error + select 1, 2, 3 into myresult; + end; + $$ language plpgsql; + + drop table somecols;