diff --git a/src/backend/nodes/nodeFuncs.c b/src/backend/nodes/nodeFuncs.c index 015dfdc..bcfaf06 100644 --- a/src/backend/nodes/nodeFuncs.c +++ b/src/backend/nodes/nodeFuncs.c @@ -2354,6 +2354,50 @@ bool return true; } break; + case T_InsertStmt: + { + InsertStmt *stmt = (InsertStmt *) node; + + if (walker(stmt->relation, context)) + return true; + if (walker(stmt->cols, context)) + return true; + if (walker(stmt->selectStmt, context)) + return true; + if (walker(stmt->returningList, context)) + return true; + } + break; + case T_UpdateStmt: + { + UpdateStmt *stmt = (UpdateStmt *) node; + + if (walker(stmt->relation, context)) + return true; + if (walker(stmt->targetList, context)) + return true; + if (walker(stmt->whereClause, context)) + return true; + if (walker(stmt->fromClause, context)) + return true; + if (walker(stmt->returningList, context)) + return true; + } + break; + case T_DeleteStmt: + { + DeleteStmt *stmt = (DeleteStmt *) node; + + if (walker(stmt->relation, context)) + return true; + if (walker(stmt->usingClause, context)) + return true; + if (walker(stmt->whereClause, context)) + return true; + if (walker(stmt->returningList, context)) + return true; + } + break; case T_A_Expr: { A_Expr *expr = (A_Expr *) node; diff --git a/src/backend/parser/gram.y b/src/backend/parser/gram.y index 9a45355..9e66536 100644 --- a/src/backend/parser/gram.y +++ b/src/backend/parser/gram.y @@ -7028,7 +7028,8 @@ cte_list: | cte_list ',' common_table_expr { $$ = lappend($1, $3); } ; -common_table_expr: name opt_name_list AS select_with_parens +common_table_expr: + name opt_name_list AS select_with_parens { CommonTableExpr *n = makeNode(CommonTableExpr); n->ctename = $1; @@ -7037,6 +7038,33 @@ common_table_expr: name opt_name_list AS select_with_parens n->location = @1; $$ = (Node *) n; } + | name opt_name_list AS '(' InsertStmt ')' + { + CommonTableExpr *n = makeNode(CommonTableExpr); + n->ctename = $1; + n->aliascolnames = $2; + n->ctequery = $5; + n->location = @1; + $$ = (Node *) n; + } + | name opt_name_list AS '(' UpdateStmt ')' + { + CommonTableExpr *n = makeNode(CommonTableExpr); + n->ctename = $1; + n->aliascolnames = $2; + n->ctequery = $5; + n->location = @1; + $$ = (Node *) n; + } + | name opt_name_list AS '(' DeleteStmt ')' + { + CommonTableExpr *n = makeNode(CommonTableExpr); + n->ctename = $1; + n->aliascolnames = $2; + n->ctequery = $5; + n->location = @1; + $$ = (Node *) n; + } ; into_clause: diff --git a/src/backend/parser/parse_cte.c b/src/backend/parser/parse_cte.c index 988e8eb..2347b28 100644 --- a/src/backend/parser/parse_cte.c +++ b/src/backend/parser/parse_cte.c @@ -246,23 +246,40 @@ transformWithClause(ParseState *pstate, WithClause *withClause) static void analyzeCTE(ParseState *pstate, CommonTableExpr *cte) { - Query *query; + Query *query; + List *ctelist; /* Analysis not done already */ - Assert(IsA(cte->ctequery, SelectStmt)); + /* This needs to be one of SelectStmt, InsertStmt, UpdateStmt, DeleteStmt instead of: + * Assert(IsA(cte->ctequery, SelectStmt)); */ query = parse_sub_analyze(cte->ctequery, pstate); cte->ctequery = (Node *) query; + if (query->commandType == CMD_SELECT) + ctelist = query->targetList; + else + { + ctelist = query->returningList; + } + /* * Check that we got something reasonable. Many of these conditions are * impossible given restrictions of the grammar, but check 'em anyway. - * (These are the same checks as in transformRangeSubselect.) + * (In addition to the same checks as in transformRangeSubselect, + * this adds checks for (INSERT|UPDATE|DELETE)...RETURNING.) */ if (!IsA(query, Query) || query->commandType != CMD_SELECT || - query->utilityStmt != NULL) - elog(ERROR, "unexpected non-SELECT command in subquery in WITH"); + query->utilityStmt != NULL || + ((query->commandType == CMD_INSERT || + query->commandType == CMD_UPDATE || + query->commandType == CMD_DELETE) && + query->returningList == NULL)) + ereport(ERROR, + (errcode(ERRCODE_SYNTAX_ERROR), + errmsg("unexpected non-row-returning command in subquery in WITH"), + parser_errposition(pstate, 0))); if (query->intoClause) ereport(ERROR, (errcode(ERRCODE_SYNTAX_ERROR), @@ -273,7 +290,7 @@ analyzeCTE(ParseState *pstate, CommonTableExpr *cte) if (!cte->cterecursive) { /* Compute the output column names/types if not done yet */ - analyzeCTETargetList(pstate, cte, query->targetList); + analyzeCTETargetList(pstate, cte, ctelist); } else { @@ -291,7 +308,7 @@ analyzeCTE(ParseState *pstate, CommonTableExpr *cte) lctyp = list_head(cte->ctecoltypes); lctypmod = list_head(cte->ctecoltypmods); varattno = 0; - foreach(lctlist, query->targetList) + foreach(lctlist, ctelist) { TargetEntry *te = (TargetEntry *) lfirst(lctlist); Node *texpr; diff --git a/src/backend/parser/parse_target.c b/src/backend/parser/parse_target.c index 08b8edb..9af7d91 100644 --- a/src/backend/parser/parse_target.c +++ b/src/backend/parser/parse_target.c @@ -310,10 +310,12 @@ markTargetListOrigin(ParseState *pstate, TargetEntry *tle, { CommonTableExpr *cte = GetCTEForRTE(pstate, rte, netlevelsup); TargetEntry *ste; + Query *query; /* should be analyzed by now */ Assert(IsA(cte->ctequery, Query)); - ste = get_tle_by_resno(((Query *) cte->ctequery)->targetList, + query = (Query *) cte->ctequery; + ste = get_tle_by_resno((query->commandType == CMD_SELECT) ? query->targetList : query->returningList, attnum); if (ste == NULL || ste->resjunk) elog(ERROR, "subquery %s does not have attribute %d", @@ -1233,11 +1235,19 @@ expandRecordVariable(ParseState *pstate, Var *var, int levelsup) { CommonTableExpr *cte = GetCTEForRTE(pstate, rte, netlevelsup); TargetEntry *ste; + Query *query; + List *ctelist; /* should be analyzed by now */ Assert(IsA(cte->ctequery, Query)); - ste = get_tle_by_resno(((Query *) cte->ctequery)->targetList, - attnum); + query = (Query *) cte->ctequery; + if (query->commandType == CMD_SELECT) + ctelist = query->targetList; + else + { + ctelist = query->returningList; + } + ste = get_tle_by_resno(ctelist, attnum); if (ste == NULL || ste->resjunk) elog(ERROR, "subquery %s does not have attribute %d", rte->eref->aliasname, attnum); diff --git a/src/backend/utils/adt/ruleutils.c b/src/backend/utils/adt/ruleutils.c index d302fb8..68c98d4 100644 --- a/src/backend/utils/adt/ruleutils.c +++ b/src/backend/utils/adt/ruleutils.c @@ -3800,9 +3800,17 @@ get_name_for_var_field(Var *var, int fieldno, } if (lc != NULL) { - Query *ctequery = (Query *) cte->ctequery; - TargetEntry *ste = get_tle_by_resno(ctequery->targetList, - attnum); + Query *ctequery = (Query *) cte->ctequery; + List *ctelist; + + if (ctequery->commandType == CMD_SELECT) + ctelist = ctequery->targetList; + else + { + ctelist = ctequery->returningList; + } + + TargetEntry *ste = get_tle_by_resno(ctelist, attnum); if (ste == NULL || ste->resjunk) elog(ERROR, "subquery %s does not have attribute %d", diff --git a/src/test/regress/expected/with.out b/src/test/regress/expected/with.out index 4a2f18c..cb603ca 100644 --- a/src/test/regress/expected/with.out +++ b/src/test/regress/expected/with.out @@ -912,3 +912,23 @@ ERROR: recursive query "foo" column 1 has type numeric(3,0) in non-recursive te LINE 2: (SELECT i::numeric(3,0) FROM (VALUES(1),(2)) t(i) ^ HINT: Cast the output of the non-recursive term to the correct type. + +-- DELETE inside the CTE +CREATE TEMPORARY TABLE t(i INTEGER); +INSERT INTO t(i) SELECT * FROM generate_series(1,10); + +WITH RECURSIVE foo(i) AS ( + DELETE FROM t RETURNING i +) +SELECT i FROM foo ORDER BY i; + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + 10 +(10 rows) diff --git a/src/test/regress/sql/with.sql b/src/test/regress/sql/with.sql index c736441..eb83aab 100644 --- a/src/test/regress/sql/with.sql +++ b/src/test/regress/sql/with.sql @@ -469,3 +469,12 @@ WITH RECURSIVE foo(i) AS UNION ALL SELECT (i+1)::numeric(10,0) FROM foo WHERE i < 10) SELECT * FROM foo; + +-- DELETE inside the CTE +CREATE TEMPORARY TABLE t(i INTEGER); +INSERT INTO t(i) SELECT * FROM generate_series(1,10); + +WITH RECURSIVE foo(i) AS ( + DELETE FROM t RETURNING i +) +SELECT i FROM foo ORDER BY i;