From 0b40ad205cf8fc8a83b8c65f5f18502a51c47370 Mon Sep 17 00:00:00 2001
From: Melanie Plageman <melanieplageman@gmail.com>
Date: Wed, 24 Jun 2020 17:07:14 -0700
Subject: [PATCH v1] Move extracting columns for hashagg to planner

Find the columns needed the executor will need to know about for hashagg
and bifurcate them into aggregated and unaggregated column bitmapsets.
Save them in the plan tree for use during execution.

This is done after set_plan_refs() so all the references are correct.
---
 src/backend/executor/nodeAgg.c       | 117 +++++++++++++--------------
 src/backend/nodes/copyfuncs.c        |   2 +
 src/backend/nodes/outfuncs.c         |   2 +
 src/backend/nodes/readfuncs.c        |   2 +
 src/backend/optimizer/plan/setrefs.c |  69 ++++++++++++++++
 src/include/nodes/execnodes.h        |   8 +-
 src/include/nodes/plannodes.h        |   2 +
 7 files changed, 140 insertions(+), 62 deletions(-)

diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c
index a20554ae65..2f340e4031 100644
--- a/src/backend/executor/nodeAgg.c
+++ b/src/backend/executor/nodeAgg.c
@@ -359,6 +359,14 @@ typedef struct HashAggBatch
 	int64		input_tuples;	/* number of tuples in this batch */
 } HashAggBatch;
 
+/* used to find referenced colnos */
+typedef struct FindColsContext
+{
+	bool	   is_aggref;		/* is under an aggref */
+	Bitmapset *aggregated;		/* column references under an aggref */
+	Bitmapset *unaggregated;	/* other column references */
+} FindColsContext;
+
 static void select_current_set(AggState *aggstate, int setno, bool is_hash);
 static void initialize_phase(AggState *aggstate, int newphase);
 static TupleTableSlot *fetch_input_tuple(AggState *aggstate);
@@ -391,8 +399,6 @@ static void finalize_aggregates(AggState *aggstate,
 								AggStatePerAgg peragg,
 								AggStatePerGroup pergroup);
 static TupleTableSlot *project_aggregates(AggState *aggstate);
-static Bitmapset *find_unaggregated_cols(AggState *aggstate);
-static bool find_unaggregated_cols_walker(Node *node, Bitmapset **colnos);
 static void build_hash_tables(AggState *aggstate);
 static void build_hash_table(AggState *aggstate, int setno, long nbuckets);
 static void hashagg_recompile_expressions(AggState *aggstate, bool minslot,
@@ -425,8 +431,8 @@ static MinimalTuple hashagg_batch_read(HashAggBatch *batch, uint32 *hashp);
 static void hashagg_spill_init(HashAggSpill *spill, HashTapeInfo *tapeinfo,
 							   int used_bits, uint64 input_tuples,
 							   double hashentrysize);
-static Size hashagg_spill_tuple(HashAggSpill *spill, TupleTableSlot *slot,
-								uint32 hash);
+static Size hashagg_spill_tuple(AggState *aggstate, HashAggSpill *spill,
+								TupleTableSlot *slot, uint32 hash);
 static void hashagg_spill_finish(AggState *aggstate, HashAggSpill *spill,
 								 int setno);
 static void hashagg_tapeinfo_init(AggState *aggstate);
@@ -1374,49 +1380,6 @@ project_aggregates(AggState *aggstate)
 	return NULL;
 }
 
-/*
- * find_unaggregated_cols
- *	  Construct a bitmapset of the column numbers of un-aggregated Vars
- *	  appearing in our targetlist and qual (HAVING clause)
- */
-static Bitmapset *
-find_unaggregated_cols(AggState *aggstate)
-{
-	Agg		   *node = (Agg *) aggstate->ss.ps.plan;
-	Bitmapset  *colnos;
-
-	colnos = NULL;
-	(void) find_unaggregated_cols_walker((Node *) node->plan.targetlist,
-										 &colnos);
-	(void) find_unaggregated_cols_walker((Node *) node->plan.qual,
-										 &colnos);
-	return colnos;
-}
-
-static bool
-find_unaggregated_cols_walker(Node *node, Bitmapset **colnos)
-{
-	if (node == NULL)
-		return false;
-	if (IsA(node, Var))
-	{
-		Var		   *var = (Var *) node;
-
-		/* setrefs.c should have set the varno to OUTER_VAR */
-		Assert(var->varno == OUTER_VAR);
-		Assert(var->varlevelsup == 0);
-		*colnos = bms_add_member(*colnos, var->varattno);
-		return false;
-	}
-	if (IsA(node, Aggref) || IsA(node, GroupingFunc))
-	{
-		/* do not descend into aggregate exprs */
-		return false;
-	}
-	return expression_tree_walker(node, find_unaggregated_cols_walker,
-								  (void *) colnos);
-}
-
 /*
  * (Re-)initialize the hash table(s) to empty.
  *
@@ -1531,19 +1494,30 @@ build_hash_table(AggState *aggstate, int setno, long nbuckets)
 static void
 find_hash_columns(AggState *aggstate)
 {
-	Bitmapset  *base_colnos;
+	TupleDesc	scanDesc = aggstate->ss.ss_ScanTupleSlot->tts_tupleDescriptor;
 	List	   *outerTlist = outerPlanState(aggstate)->plan->targetlist;
 	int			numHashes = aggstate->num_hashes;
 	EState	   *estate = aggstate->ss.ps.state;
 	int			j;
 
-	/* Find Vars that will be needed in tlist and qual */
-	base_colnos = find_unaggregated_cols(aggstate);
+	Agg *agg = (Agg *) aggstate->ss.ps.plan;
+	aggstate->colnos_needed = bms_union(agg->unaggregated_colnos, agg->aggregated_colnos);
+	aggstate->max_colno_needed = 0;
+	aggstate->all_cols_needed = true;
+
+	for (int i = 0; i < scanDesc->natts; i++)
+	{
+		int colno = i + 1;
+		if (bms_is_member(colno, aggstate->colnos_needed))
+			aggstate->max_colno_needed = colno;
+		else
+			aggstate->all_cols_needed = false;
+	}
 
 	for (j = 0; j < numHashes; ++j)
 	{
 		AggStatePerHash perhash = &aggstate->perhash[j];
-		Bitmapset  *colnos = bms_copy(base_colnos);
+		Bitmapset  *colnos = bms_copy(agg->unaggregated_colnos);
 		AttrNumber *grpColIdx = perhash->aggnode->grpColIdx;
 		List	   *hashTlist = NIL;
 		TupleDesc	hashDesc;
@@ -1637,7 +1611,6 @@ find_hash_columns(AggState *aggstate)
 		bms_free(colnos);
 	}
 
-	bms_free(base_colnos);
 }
 
 /*
@@ -2097,7 +2070,7 @@ lookup_hash_entries(AggState *aggstate)
 								   perhash->aggnode->numGroups,
 								   aggstate->hashentrysize);
 
-			hashagg_spill_tuple(spill, slot, hash);
+			hashagg_spill_tuple(aggstate, spill, slot, hash);
 		}
 	}
 }
@@ -2619,7 +2592,7 @@ agg_refill_hash_table(AggState *aggstate)
 							 HASHAGG_READ_BUFFER_SIZE);
 	for (;;)
 	{
-		TupleTableSlot *slot = aggstate->hash_spill_slot;
+		TupleTableSlot *slot = aggstate->hash_spill_rslot;
 		MinimalTuple tuple;
 		uint32		hash;
 		bool		in_hash_table;
@@ -2655,7 +2628,7 @@ agg_refill_hash_table(AggState *aggstate)
 								   ngroups_estimate, aggstate->hashentrysize);
 			}
 			/* no memory for a new group, spill */
-			hashagg_spill_tuple(&spill, slot, hash);
+			hashagg_spill_tuple(aggstate, &spill, slot, hash);
 		}
 
 		/*
@@ -2934,9 +2907,11 @@ hashagg_spill_init(HashAggSpill *spill, HashTapeInfo *tapeinfo, int used_bits,
  * partition.
  */
 static Size
-hashagg_spill_tuple(HashAggSpill *spill, TupleTableSlot *slot, uint32 hash)
+hashagg_spill_tuple(AggState *aggstate, HashAggSpill *spill,
+					TupleTableSlot *inputslot, uint32 hash)
 {
 	LogicalTapeSet *tapeset = spill->tapeset;
+	TupleTableSlot *spillslot;
 	int			partition;
 	MinimalTuple tuple;
 	int			tapenum;
@@ -2945,8 +2920,28 @@ hashagg_spill_tuple(HashAggSpill *spill, TupleTableSlot *slot, uint32 hash)
 
 	Assert(spill->partitions != NULL);
 
-	/* XXX: may contain unnecessary attributes, should project */
-	tuple = ExecFetchSlotMinimalTuple(slot, &shouldFree);
+	/* spill only attributes that we actually need */
+	if (!aggstate->all_cols_needed)
+	{
+		spillslot = aggstate->hash_spill_wslot;
+		slot_getsomeattrs(inputslot, aggstate->max_colno_needed);
+		ExecClearTuple(spillslot);
+		for (int i = 0; i < spillslot->tts_tupleDescriptor->natts; i++)
+		{
+			if (bms_is_member(i + 1, aggstate->colnos_needed))
+			{
+				spillslot->tts_values[i] = inputslot->tts_values[i];
+				spillslot->tts_isnull[i] = inputslot->tts_isnull[i];
+			}
+			else
+				spillslot->tts_isnull[i] = true;
+		}
+		ExecStoreVirtualTuple(spillslot);
+	}
+	else
+		spillslot = inputslot;
+
+	tuple = ExecFetchSlotMinimalTuple(spillslot, &shouldFree);
 
 	partition = (hash & spill->mask) >> spill->shift;
 	spill->ntuples[partition]++;
@@ -3563,8 +3558,10 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
 		aggstate->hash_metacxt = AllocSetContextCreate(aggstate->ss.ps.state->es_query_cxt,
 													   "HashAgg meta context",
 													   ALLOCSET_DEFAULT_SIZES);
-		aggstate->hash_spill_slot = ExecInitExtraTupleSlot(estate, scanDesc,
-														   &TTSOpsMinimalTuple);
+		aggstate->hash_spill_rslot = ExecInitExtraTupleSlot(estate, scanDesc,
+															&TTSOpsMinimalTuple);
+		aggstate->hash_spill_wslot = ExecInitExtraTupleSlot(estate, scanDesc,
+															&TTSOpsVirtual);
 
 		/* this is an array of pointers, not structures */
 		aggstate->hash_pergroup = pergroups;
diff --git a/src/backend/nodes/copyfuncs.c b/src/backend/nodes/copyfuncs.c
index d8cf87e6d0..5f58a2fc27 100644
--- a/src/backend/nodes/copyfuncs.c
+++ b/src/backend/nodes/copyfuncs.c
@@ -1015,6 +1015,8 @@ _copyAgg(const Agg *from)
 	COPY_SCALAR_FIELD(aggstrategy);
 	COPY_SCALAR_FIELD(aggsplit);
 	COPY_SCALAR_FIELD(numCols);
+	COPY_BITMAPSET_FIELD(aggregated_colnos);
+	COPY_BITMAPSET_FIELD(unaggregated_colnos);
 	if (from->numCols > 0)
 	{
 		COPY_POINTER_FIELD(grpColIdx, from->numCols * sizeof(AttrNumber));
diff --git a/src/backend/nodes/outfuncs.c b/src/backend/nodes/outfuncs.c
index e2f177515d..1d544a7f02 100644
--- a/src/backend/nodes/outfuncs.c
+++ b/src/backend/nodes/outfuncs.c
@@ -779,6 +779,8 @@ _outAgg(StringInfo str, const Agg *node)
 	WRITE_ENUM_FIELD(aggstrategy, AggStrategy);
 	WRITE_ENUM_FIELD(aggsplit, AggSplit);
 	WRITE_INT_FIELD(numCols);
+	WRITE_BITMAPSET_FIELD(aggregated_colnos);
+	WRITE_BITMAPSET_FIELD(unaggregated_colnos);
 	WRITE_ATTRNUMBER_ARRAY(grpColIdx, node->numCols);
 	WRITE_OID_ARRAY(grpOperators, node->numCols);
 	WRITE_OID_ARRAY(grpCollations, node->numCols);
diff --git a/src/backend/nodes/readfuncs.c b/src/backend/nodes/readfuncs.c
index 42050ab719..7874899972 100644
--- a/src/backend/nodes/readfuncs.c
+++ b/src/backend/nodes/readfuncs.c
@@ -2227,6 +2227,8 @@ _readAgg(void)
 	READ_ENUM_FIELD(aggstrategy, AggStrategy);
 	READ_ENUM_FIELD(aggsplit, AggSplit);
 	READ_INT_FIELD(numCols);
+	READ_BITMAPSET_FIELD(aggregated_colnos);
+	READ_BITMAPSET_FIELD(unaggregated_colnos);
 	READ_ATTRNUMBER_ARRAY(grpColIdx, local_node->numCols);
 	READ_OID_ARRAY(grpOperators, local_node->numCols);
 	READ_OID_ARRAY(grpCollations, local_node->numCols);
diff --git a/src/backend/optimizer/plan/setrefs.c b/src/backend/optimizer/plan/setrefs.c
index baefe0e946..340ec3ce66 100644
--- a/src/backend/optimizer/plan/setrefs.c
+++ b/src/backend/optimizer/plan/setrefs.c
@@ -68,6 +68,14 @@ typedef struct
 	int			rtoffset;
 } fix_upper_expr_context;
 
+/* used to find referenced colnos */
+typedef struct BifurcateColsContext
+{
+	bool	   is_aggref;		/* is under an aggref */
+	Bitmapset *aggregated;		/* column references under an aggref */
+	Bitmapset *unaggregated;	/* other column references */
+} BifurcateColsContext;
+
 /*
  * Check if a Const node is a regclass value.  We accept plain OID too,
  * since a regclass Const will get folded to that type if it's an argument
@@ -113,6 +121,8 @@ static Node *fix_scan_expr(PlannerInfo *root, Node *node, int rtoffset);
 static Node *fix_scan_expr_mutator(Node *node, fix_scan_expr_context *context);
 static bool fix_scan_expr_walker(Node *node, fix_scan_expr_context *context);
 static void set_join_references(PlannerInfo *root, Join *join, int rtoffset);
+static void bifurcate_agg_cols(List *quals, List *tlist, Bitmapset **aggregated, Bitmapset **unaggregated);
+static bool bifurcate_agg_cols_walker(Node *node, BifurcateColsContext *context);
 static void set_upper_references(PlannerInfo *root, Plan *plan, int rtoffset);
 static void set_param_references(PlannerInfo *root, Plan *plan);
 static Node *convert_combining_aggrefs(Node *node, void *context);
@@ -1879,6 +1889,58 @@ set_join_references(PlannerInfo *root, Join *join, int rtoffset)
 	pfree(inner_itlist);
 }
 
+
+/*
+ * Walk tlist and qual to find referenced colnos, dividing them into
+ * aggregated and unaggregated sets.
+ */
+static void
+bifurcate_agg_cols(List *quals, List *tlist, Bitmapset **aggregated, Bitmapset **unaggregated)
+{
+	BifurcateColsContext context;
+
+	context.is_aggref = false;
+	context.aggregated = NULL;
+	context.unaggregated = NULL;
+
+	(void) bifurcate_agg_cols_walker((Node *) tlist, &context);
+	(void) bifurcate_agg_cols_walker((Node *) quals, &context);
+
+	*aggregated = context.aggregated;
+	*unaggregated = context.unaggregated;
+}
+
+
+static bool
+bifurcate_agg_cols_walker(Node *node, BifurcateColsContext *context)
+{
+	if (node == NULL)
+		return false;
+	if (IsA(node, Var))
+	{
+		Var		   *var = (Var *) node;
+
+		if (var->varno != OUTER_VAR && var->varlevelsup != 0)
+			return false;
+
+		if (context->is_aggref)
+			context->aggregated = bms_add_member(context->aggregated, var->varattno);
+		else
+			context->unaggregated = bms_add_member(context->unaggregated, var->varattno);
+		return false;
+	}
+	if (IsA(node, Aggref))
+	{
+		Assert(!context->is_aggref);
+		context->is_aggref = true;
+		expression_tree_walker(node, bifurcate_agg_cols_walker, (void *) context);
+		context->is_aggref = false;
+		return false;
+	}
+	return expression_tree_walker(node, bifurcate_agg_cols_walker, (void *) context);
+}
+
+
 /*
  * set_upper_references
  *	  Update the targetlist and quals of an upper-level plan node
@@ -1947,6 +2009,13 @@ set_upper_references(PlannerInfo *root, Plan *plan, int rtoffset)
 					   OUTER_VAR,
 					   rtoffset);
 
+	if (IsA(plan, Agg))
+	{
+		Agg *agg = (Agg *)plan;
+		if (agg->aggstrategy == AGG_HASHED || agg->aggstrategy == AGG_MIXED)
+			bifurcate_agg_cols(plan->qual, plan->targetlist, &agg->aggregated_colnos, &agg->unaggregated_colnos);
+	}
+
 	pfree(subplan_itlist);
 }
 
diff --git a/src/include/nodes/execnodes.h b/src/include/nodes/execnodes.h
index f5dfa32d55..e2aa07cb45 100644
--- a/src/include/nodes/execnodes.h
+++ b/src/include/nodes/execnodes.h
@@ -2169,6 +2169,9 @@ typedef struct AggState
 	int			current_set;	/* The current grouping set being evaluated */
 	Bitmapset  *grouped_cols;	/* grouped cols in current projection */
 	List	   *all_grouped_cols;	/* list of all grouped cols in DESC order */
+	Bitmapset  *colnos_needed;	/* all columns needed from the outer plan */
+	int			max_colno_needed;	/* highest colno needed from outer plan */
+	bool		all_cols_needed;	/* are all cols from outer plan needed? */
 	/* These fields are for grouping set phase data */
 	int			maxsets;		/* The max number of sets in any phase */
 	AggStatePerPhase phases;	/* array of all phases */
@@ -2186,7 +2189,8 @@ typedef struct AggState
 	struct HashTapeInfo *hash_tapeinfo; /* metadata for spill tapes */
 	struct HashAggSpill *hash_spills;	/* HashAggSpill for each grouping set,
 										 * exists only during first pass */
-	TupleTableSlot *hash_spill_slot;	/* slot for reading from spill files */
+	TupleTableSlot *hash_spill_rslot;	/* for reading spill files */
+	TupleTableSlot *hash_spill_wslot;	/* for writing spill files */
 	List	   *hash_batches;	/* hash batches remaining to be processed */
 	bool		hash_ever_spilled;	/* ever spilled during this execution? */
 	bool		hash_spill_mode;	/* we hit a limit during the current batch
@@ -2207,7 +2211,7 @@ typedef struct AggState
 										 * per-group pointers */
 
 	/* support for evaluation of agg input expressions: */
-#define FIELDNO_AGGSTATE_ALL_PERGROUPS 49
+#define FIELDNO_AGGSTATE_ALL_PERGROUPS 53
 	AggStatePerGroup *all_pergroups;	/* array of first ->pergroups, than
 										 * ->hash_pergroup */
 	ProjectionInfo *combinedproj;	/* projection machinery */
diff --git a/src/include/nodes/plannodes.h b/src/include/nodes/plannodes.h
index 83e01074ed..b0c50d2751 100644
--- a/src/include/nodes/plannodes.h
+++ b/src/include/nodes/plannodes.h
@@ -819,6 +819,8 @@ typedef struct Agg
 	AggStrategy aggstrategy;	/* basic strategy, see nodes.h */
 	AggSplit	aggsplit;		/* agg-splitting mode, see nodes.h */
 	int			numCols;		/* number of grouping columns */
+	Bitmapset *aggregated_colnos;
+	Bitmapset *unaggregated_colnos;
 	AttrNumber *grpColIdx;		/* their indexes in the target list */
 	Oid		   *grpOperators;	/* equality operators to compare with */
 	Oid		   *grpCollations;
-- 
2.20.1

