From 55142837eaf439652fd91e165f3d4c3607ea9987 Mon Sep 17 00:00:00 2001
From: Tomas Vondra <tomas@vondra.me>
Date: Tue, 31 Dec 2024 17:34:59 +0100
Subject: [PATCH v20241231-single-spill] Hash join with a single spill file

We can't use arbitrary number of batches in the hash join, because that
can use substantial memory use, ignored by the memory limit. Instead,
decide how many batches we can keep in memory, and open files only for
those batches, and one "overflow" file for all future batches.

Then use this spill file after "switching" to the first batch that is
not in memory.

NB: This allows enforcing the memory limit very strictly, but it's very
inefficient and causes a lot of temporary file traffic.
---
 src/backend/commands/explain.c        |   6 +-
 src/backend/executor/nodeHash.c       | 206 +++++++++++++++++++++++---
 src/backend/executor/nodeHashjoin.c   |  76 ++++++----
 src/backend/optimizer/path/costsize.c |   2 +
 src/include/executor/hashjoin.h       |   4 +
 src/include/executor/nodeHash.h       |   7 +
 src/include/nodes/execnodes.h         |   1 +
 7 files changed, 250 insertions(+), 52 deletions(-)

diff --git a/src/backend/commands/explain.c b/src/backend/commands/explain.c
index a201ed30824..aa3fbeda3dd 100644
--- a/src/backend/commands/explain.c
+++ b/src/backend/commands/explain.c
@@ -3466,19 +3466,23 @@ show_hash_info(HashState *hashstate, ExplainState *es)
 								   hinstrument.nbatch, es);
 			ExplainPropertyInteger("Original Hash Batches", NULL,
 								   hinstrument.nbatch_original, es);
+			ExplainPropertyInteger("In-Memory Hash Batches", NULL,
+								   hinstrument.nbatch_original, es);
 			ExplainPropertyUInteger("Peak Memory Usage", "kB",
 									spacePeakKb, es);
 		}
 		else if (hinstrument.nbatch_original != hinstrument.nbatch ||
+				 hinstrument.nbatch_inmemory != hinstrument.nbatch ||
 				 hinstrument.nbuckets_original != hinstrument.nbuckets)
 		{
 			ExplainIndentText(es);
 			appendStringInfo(es->str,
-							 "Buckets: %d (originally %d)  Batches: %d (originally %d)  Memory Usage: " UINT64_FORMAT "kB\n",
+							 "Buckets: %d (originally %d)  Batches: %d (originally %d, in-memory %d)  Memory Usage: " UINT64_FORMAT "kB\n",
 							 hinstrument.nbuckets,
 							 hinstrument.nbuckets_original,
 							 hinstrument.nbatch,
 							 hinstrument.nbatch_original,
+							 hinstrument.nbatch_inmemory,
 							 spacePeakKb);
 		}
 		else
diff --git a/src/backend/executor/nodeHash.c b/src/backend/executor/nodeHash.c
index 3e22d50e3a4..58062a8b256 100644
--- a/src/backend/executor/nodeHash.c
+++ b/src/backend/executor/nodeHash.c
@@ -80,6 +80,7 @@ static bool ExecParallelHashTuplePrealloc(HashJoinTable hashtable,
 static void ExecParallelHashMergeCounters(HashJoinTable hashtable);
 static void ExecParallelHashCloseBatchAccessors(HashJoinTable hashtable);
 
+static void ExecHashUpdateSpacePeak(HashJoinTable hashtable);
 
 /* ----------------------------------------------------------------
  *		ExecHash
@@ -199,10 +200,8 @@ MultiExecPrivateHash(HashState *node)
 	if (hashtable->nbuckets != hashtable->nbuckets_optimal)
 		ExecHashIncreaseNumBuckets(hashtable);
 
-	/* Account for the buckets in spaceUsed (reported in EXPLAIN ANALYZE) */
-	hashtable->spaceUsed += hashtable->nbuckets * sizeof(HashJoinTuple);
-	if (hashtable->spaceUsed > hashtable->spacePeak)
-		hashtable->spacePeak = hashtable->spaceUsed;
+	/* refresh info about peak used memory */
+	ExecHashUpdateSpacePeak(hashtable);
 
 	hashtable->partialTuples = hashtable->totalTuples;
 }
@@ -451,6 +450,7 @@ ExecHashTableCreate(HashState *state)
 	size_t		space_allowed;
 	int			nbuckets;
 	int			nbatch;
+	int			nbatch_inmemory;
 	double		rows;
 	int			num_skew_mcvs;
 	int			log2_nbuckets;
@@ -477,7 +477,8 @@ ExecHashTableCreate(HashState *state)
 							state->parallel_state != NULL ?
 							state->parallel_state->nparticipants - 1 : 0,
 							&space_allowed,
-							&nbuckets, &nbatch, &num_skew_mcvs);
+							&nbuckets, &nbatch, &nbatch_inmemory,
+							&num_skew_mcvs);
 
 	/* nbuckets must be a power of 2 */
 	log2_nbuckets = my_log2(nbuckets);
@@ -503,6 +504,7 @@ ExecHashTableCreate(HashState *state)
 	hashtable->nSkewBuckets = 0;
 	hashtable->skewBucketNums = NULL;
 	hashtable->nbatch = nbatch;
+	hashtable->nbatch_inmemory = nbatch_inmemory;
 	hashtable->curbatch = 0;
 	hashtable->nbatch_original = nbatch;
 	hashtable->nbatch_outstart = nbatch;
@@ -512,6 +514,8 @@ ExecHashTableCreate(HashState *state)
 	hashtable->skewTuples = 0;
 	hashtable->innerBatchFile = NULL;
 	hashtable->outerBatchFile = NULL;
+	hashtable->innerOverflowFile = NULL;
+	hashtable->outerOverflowFile = NULL;
 	hashtable->spaceUsed = 0;
 	hashtable->spacePeak = 0;
 	hashtable->spaceAllowed = space_allowed;
@@ -553,14 +557,16 @@ ExecHashTableCreate(HashState *state)
 	{
 		MemoryContext oldctx;
 
+		int	cnt = Min(nbatch, nbatch_inmemory);
+
 		/*
 		 * allocate and initialize the file arrays in hashCxt (not needed for
 		 * parallel case which uses shared tuplestores instead of raw files)
 		 */
 		oldctx = MemoryContextSwitchTo(hashtable->spillCxt);
 
-		hashtable->innerBatchFile = palloc0_array(BufFile *, nbatch);
-		hashtable->outerBatchFile = palloc0_array(BufFile *, nbatch);
+		hashtable->innerBatchFile = palloc0_array(BufFile *, cnt);
+		hashtable->outerBatchFile = palloc0_array(BufFile *, cnt);
 
 		MemoryContextSwitchTo(oldctx);
 
@@ -661,6 +667,7 @@ ExecChooseHashTableSize(double ntuples, int tupwidth, bool useskew,
 						size_t *space_allowed,
 						int *numbuckets,
 						int *numbatches,
+						int *numbatches_inmemory,
 						int *num_skew_mcvs)
 {
 	int			tupsize;
@@ -669,6 +676,7 @@ ExecChooseHashTableSize(double ntuples, int tupwidth, bool useskew,
 	size_t		bucket_bytes;
 	size_t		max_pointers;
 	int			nbatch = 1;
+	int			nbatch_inmemory = 1;
 	int			nbuckets;
 	double		dbuckets;
 
@@ -811,6 +819,7 @@ ExecChooseHashTableSize(double ntuples, int tupwidth, bool useskew,
 									space_allowed,
 									numbuckets,
 									numbatches,
+									numbatches_inmemory,
 									num_skew_mcvs);
 			return;
 		}
@@ -848,11 +857,22 @@ ExecChooseHashTableSize(double ntuples, int tupwidth, bool useskew,
 		nbatch = pg_nextpower2_32(Max(2, minbatch));
 	}
 
+	/*
+	 * See how many batches we can fit into memory (driven mostly by size
+	 * of BufFile, with PGAlignedBlock being the largest part of that).
+	 * We need one BufFile for inner and outer side, so we count it twice
+	 * for each batch, and we stop once we exceed (work_mem/2).
+	 */
+	while ((nbatch_inmemory * 2) * sizeof(PGAlignedBlock) * 2
+			<= (work_mem * 1024L / 2))
+		nbatch_inmemory *= 2;
+
 	Assert(nbuckets > 0);
 	Assert(nbatch > 0);
 
 	*numbuckets = nbuckets;
 	*numbatches = nbatch;
+	*numbatches_inmemory = nbatch_inmemory;
 }
 
 
@@ -874,13 +894,21 @@ ExecHashTableDestroy(HashJoinTable hashtable)
 	 */
 	if (hashtable->innerBatchFile != NULL)
 	{
-		for (i = 1; i < hashtable->nbatch; i++)
+		int nbatch = Min(hashtable->nbatch, hashtable->nbatch_inmemory);
+
+		for (i = 1; i < nbatch; i++)
 		{
 			if (hashtable->innerBatchFile[i])
 				BufFileClose(hashtable->innerBatchFile[i]);
 			if (hashtable->outerBatchFile[i])
 				BufFileClose(hashtable->outerBatchFile[i]);
 		}
+
+		if (hashtable->innerOverflowFile)
+			BufFileClose(hashtable->innerOverflowFile);
+
+		if (hashtable->outerOverflowFile)
+			BufFileClose(hashtable->outerOverflowFile);
 	}
 
 	/* Release working memory (batchCxt is a child, so it goes away too) */
@@ -923,11 +951,13 @@ ExecHashIncreaseNumBatches(HashJoinTable hashtable)
 
 	if (hashtable->innerBatchFile == NULL)
 	{
+		int nbatch_tmp = Min(nbatch, hashtable->nbatch_inmemory);
+
 		MemoryContext oldcxt = MemoryContextSwitchTo(hashtable->spillCxt);
 
 		/* we had no file arrays before */
-		hashtable->innerBatchFile = palloc0_array(BufFile *, nbatch);
-		hashtable->outerBatchFile = palloc0_array(BufFile *, nbatch);
+		hashtable->innerBatchFile = palloc0_array(BufFile *, nbatch_tmp);
+		hashtable->outerBatchFile = palloc0_array(BufFile *, nbatch_tmp);
 
 		MemoryContextSwitchTo(oldcxt);
 
@@ -936,9 +966,12 @@ ExecHashIncreaseNumBatches(HashJoinTable hashtable)
 	}
 	else
 	{
+		int nbatch_tmp = Min(nbatch, hashtable->nbatch_inmemory);
+		int oldnbatch_tmp = Min(oldnbatch, hashtable->nbatch_inmemory);
+
 		/* enlarge arrays and zero out added entries */
-		hashtable->innerBatchFile = repalloc0_array(hashtable->innerBatchFile, BufFile *, oldnbatch, nbatch);
-		hashtable->outerBatchFile = repalloc0_array(hashtable->outerBatchFile, BufFile *, oldnbatch, nbatch);
+		hashtable->innerBatchFile = repalloc0_array(hashtable->innerBatchFile, BufFile *, oldnbatch_tmp, nbatch_tmp);
+		hashtable->outerBatchFile = repalloc0_array(hashtable->outerBatchFile, BufFile *, oldnbatch_tmp, nbatch_tmp);
 	}
 
 	hashtable->nbatch = nbatch;
@@ -1008,11 +1041,18 @@ ExecHashIncreaseNumBatches(HashJoinTable hashtable)
 			}
 			else
 			{
+				BufFile **batchFile;
+
 				/* dump it out */
 				Assert(batchno > curbatch);
+
+				batchFile = ExecHashGetBatchFile(hashtable, batchno,
+												 hashtable->innerBatchFile,
+												 &hashtable->innerOverflowFile);
+
 				ExecHashJoinSaveTuple(HJTUPLE_MINTUPLE(hashTuple),
 									  hashTuple->hashvalue,
-									  &hashtable->innerBatchFile[batchno],
+									  batchFile,
 									  hashtable);
 
 				hashtable->spaceUsed -= hashTupleSize;
@@ -1673,22 +1713,33 @@ ExecHashTableInsert(HashJoinTable hashtable,
 
 		/* Account for space used, and back off if we've used too much */
 		hashtable->spaceUsed += hashTupleSize;
-		if (hashtable->spaceUsed > hashtable->spacePeak)
-			hashtable->spacePeak = hashtable->spaceUsed;
+
+		/* refresh info about peak used memory */
+		ExecHashUpdateSpacePeak(hashtable);
+
+		/* Consider increasing number of batches if we filled work_mem. */
 		if (hashtable->spaceUsed +
-			hashtable->nbuckets_optimal * sizeof(HashJoinTuple)
+			hashtable->nbuckets_optimal * sizeof(HashJoinTuple) +
+			Min(hashtable->nbatch, hashtable->nbatch_inmemory) * sizeof(PGAlignedBlock) * 2	/* inner + outer */
 			> hashtable->spaceAllowed)
 			ExecHashIncreaseNumBatches(hashtable);
 	}
 	else
 	{
+		BufFile **batchFile;
+
 		/*
 		 * put the tuple into a temp file for later batches
 		 */
 		Assert(batchno > hashtable->curbatch);
+
+		batchFile = ExecHashGetBatchFile(hashtable, batchno,
+										 hashtable->innerBatchFile,
+										 &hashtable->innerOverflowFile);
+
 		ExecHashJoinSaveTuple(tuple,
 							  hashvalue,
-							  &hashtable->innerBatchFile[batchno],
+							  batchFile,
 							  hashtable);
 	}
 
@@ -1843,6 +1894,100 @@ ExecHashGetBucketAndBatch(HashJoinTable hashtable,
 	}
 }
 
+int
+ExecHashGetBatchIndex(HashJoinTable hashtable, int batchno)
+{
+	int	slice,
+		curslice;
+
+	if (hashtable->nbatch <= hashtable->nbatch_inmemory)
+		return batchno;
+
+	slice = batchno / hashtable->nbatch_inmemory;
+	curslice = hashtable->curbatch / hashtable->nbatch_inmemory;
+
+	/* slices can't go backwards */
+	Assert(slice >= curslice);
+
+	/* overflow slice */
+	if (slice > curslice)
+		return -1;
+
+	/* current slice, compute index in the current array */
+	return (batchno % hashtable->nbatch_inmemory);
+}
+
+BufFile **
+ExecHashGetBatchFile(HashJoinTable hashtable, int batchno,
+					 BufFile **batchFiles, BufFile **overflowFile)
+{
+	int		idx = ExecHashGetBatchIndex(hashtable, batchno);
+
+	if (idx == -1)
+		return overflowFile;
+
+	return &batchFiles[idx];
+}
+
+void
+ExecHashSwitchToNextBatchSlice(HashJoinTable hashtable)
+{
+	memset(hashtable->innerBatchFile, 0,
+		   hashtable->nbatch_inmemory * sizeof(BufFile *));
+
+	hashtable->innerBatchFile[0] = hashtable->innerOverflowFile;
+	hashtable->innerOverflowFile = NULL;
+
+	memset(hashtable->outerBatchFile, 0,
+		   hashtable->nbatch_inmemory * sizeof(BufFile *));
+
+	hashtable->outerBatchFile[0] = hashtable->outerOverflowFile;
+	hashtable->outerOverflowFile = NULL;
+}
+
+int
+ExecHashSwitchToNextBatch(HashJoinTable hashtable)
+{
+	int		batchidx;
+
+	hashtable->curbatch++;
+
+	/* see if we skipped to the next batch slice */
+	batchidx = ExecHashGetBatchIndex(hashtable, hashtable->curbatch);
+
+	/* Can't be -1, current batch is in the current slice by definition. */
+	Assert(batchidx >= 0 && batchidx < hashtable->nbatch_inmemory);
+
+	/*
+	 * If we skipped to the next slice of batches, reset the array of files
+	 * and use the overflow file as the first batch.
+	 */
+	if (batchidx == 0)
+		ExecHashSwitchToNextBatchSlice(hashtable);
+
+	return hashtable->curbatch;
+}
+
+static void
+ExecHashUpdateSpacePeak(HashJoinTable hashtable)
+{
+	Size	spaceUsed = hashtable->spaceUsed;
+
+	/* Account for the buckets in spaceUsed (reported in EXPLAIN ANALYZE) */
+	spaceUsed += hashtable->nbuckets * sizeof(HashJoinTuple);
+
+	/* Account for memory used for batch files (inner + outer) */
+	spaceUsed += Min(hashtable->nbatch, hashtable->nbatch_inmemory) *
+				 sizeof(PGAlignedBlock) * 2;
+
+	/* Account for slice files (inner + outer) */
+	spaceUsed += (hashtable->nbatch / hashtable->nbatch_inmemory) *
+				 sizeof(PGAlignedBlock) * 2;
+
+	if (spaceUsed > hashtable->spacePeak)
+		hashtable->spacePeak = spaceUsed;
+}
+
 /*
  * ExecScanHashBucket
  *		scan a hash bucket for matches to the current outer tuple
@@ -2349,8 +2494,9 @@ ExecHashBuildSkewHash(HashState *hashstate, HashJoinTable hashtable,
 			+ mcvsToUse * sizeof(int);
 		hashtable->spaceUsedSkew += nbuckets * sizeof(HashSkewBucket *)
 			+ mcvsToUse * sizeof(int);
-		if (hashtable->spaceUsed > hashtable->spacePeak)
-			hashtable->spacePeak = hashtable->spaceUsed;
+
+		/* refresh info about peak used memory */
+		ExecHashUpdateSpacePeak(hashtable);
 
 		/*
 		 * Create a skew bucket for each MCV hash value.
@@ -2399,8 +2545,9 @@ ExecHashBuildSkewHash(HashState *hashstate, HashJoinTable hashtable,
 			hashtable->nSkewBuckets++;
 			hashtable->spaceUsed += SKEW_BUCKET_OVERHEAD;
 			hashtable->spaceUsedSkew += SKEW_BUCKET_OVERHEAD;
-			if (hashtable->spaceUsed > hashtable->spacePeak)
-				hashtable->spacePeak = hashtable->spaceUsed;
+
+			/* refresh info about peak used memory */
+			ExecHashUpdateSpacePeak(hashtable);
 		}
 
 		free_attstatsslot(&sslot);
@@ -2489,8 +2636,10 @@ ExecHashSkewTableInsert(HashJoinTable hashtable,
 	/* Account for space used, and back off if we've used too much */
 	hashtable->spaceUsed += hashTupleSize;
 	hashtable->spaceUsedSkew += hashTupleSize;
-	if (hashtable->spaceUsed > hashtable->spacePeak)
-		hashtable->spacePeak = hashtable->spaceUsed;
+
+	/* refresh info about peak used memory */
+	ExecHashUpdateSpacePeak(hashtable);
+
 	while (hashtable->spaceUsedSkew > hashtable->spaceAllowedSkew)
 		ExecHashRemoveNextSkewBucket(hashtable);
 
@@ -2569,10 +2718,17 @@ ExecHashRemoveNextSkewBucket(HashJoinTable hashtable)
 		}
 		else
 		{
+			BufFile **batchFile;
+
 			/* Put the tuple into a temp file for later batches */
 			Assert(batchno > hashtable->curbatch);
+
+			batchFile = ExecHashGetBatchFile(hashtable, batchno,
+											 hashtable->innerBatchFile,
+											 &hashtable->innerOverflowFile);
+
 			ExecHashJoinSaveTuple(tuple, hashvalue,
-								  &hashtable->innerBatchFile[batchno],
+								  batchFile,
 								  hashtable);
 			pfree(hashTuple);
 			hashtable->spaceUsed -= tupleSize;
@@ -2750,6 +2906,8 @@ ExecHashAccumInstrumentation(HashInstrumentation *instrument,
 							 hashtable->nbatch);
 	instrument->nbatch_original = Max(instrument->nbatch_original,
 									  hashtable->nbatch_original);
+	instrument->nbatch_inmemory = Min(hashtable->nbatch,
+									  hashtable->nbatch_inmemory);
 	instrument->space_peak = Max(instrument->space_peak,
 								 hashtable->spacePeak);
 }
diff --git a/src/backend/executor/nodeHashjoin.c b/src/backend/executor/nodeHashjoin.c
index ea0045bc0f3..2f5d94653e8 100644
--- a/src/backend/executor/nodeHashjoin.c
+++ b/src/backend/executor/nodeHashjoin.c
@@ -481,6 +481,7 @@ ExecHashJoinImpl(PlanState *pstate, bool parallel)
 				if (batchno != hashtable->curbatch &&
 					node->hj_CurSkewBucketNo == INVALID_SKEW_BUCKET_NO)
 				{
+					BufFile	  **batchFile;
 					bool		shouldFree;
 					MinimalTuple mintuple = ExecFetchSlotMinimalTuple(outerTupleSlot,
 																	  &shouldFree);
@@ -491,8 +492,13 @@ ExecHashJoinImpl(PlanState *pstate, bool parallel)
 					 */
 					Assert(parallel_state == NULL);
 					Assert(batchno > hashtable->curbatch);
+
+					batchFile = ExecHashGetBatchFile(hashtable, batchno,
+													 hashtable->outerBatchFile,
+													 &hashtable->outerOverflowFile);
+
 					ExecHashJoinSaveTuple(mintuple, hashvalue,
-										  &hashtable->outerBatchFile[batchno],
+										  batchFile,
 										  hashtable);
 
 					if (shouldFree)
@@ -1030,17 +1036,19 @@ ExecHashJoinOuterGetTuple(PlanState *outerNode,
 	}
 	else if (curbatch < hashtable->nbatch)
 	{
-		BufFile    *file = hashtable->outerBatchFile[curbatch];
+		BufFile    **file = ExecHashGetBatchFile(hashtable, curbatch,
+												 hashtable->outerBatchFile,
+												 &hashtable->outerOverflowFile);
 
 		/*
 		 * In outer-join cases, we could get here even though the batch file
 		 * is empty.
 		 */
-		if (file == NULL)
+		if (*file == NULL)
 			return NULL;
 
 		slot = ExecHashJoinGetSavedTuple(hjstate,
-										 file,
+										 *file,
 										 hashvalue,
 										 hjstate->hj_OuterTupleSlot);
 		if (!TupIsNull(slot))
@@ -1135,9 +1143,18 @@ ExecHashJoinNewBatch(HashJoinState *hjstate)
 	BufFile    *innerFile;
 	TupleTableSlot *slot;
 	uint32		hashvalue;
+	int			batchidx;
+	int			curbatch_old;
 
 	nbatch = hashtable->nbatch;
 	curbatch = hashtable->curbatch;
+	curbatch_old = curbatch;
+
+	/* index of the old batch */
+	batchidx = ExecHashGetBatchIndex(hashtable, curbatch);
+
+	/* has to be in the current slice of batches */
+	Assert(batchidx >= 0 && batchidx < hashtable->nbatch_inmemory);
 
 	if (curbatch > 0)
 	{
@@ -1145,9 +1162,9 @@ ExecHashJoinNewBatch(HashJoinState *hjstate)
 		 * We no longer need the previous outer batch file; close it right
 		 * away to free disk space.
 		 */
-		if (hashtable->outerBatchFile[curbatch])
-			BufFileClose(hashtable->outerBatchFile[curbatch]);
-		hashtable->outerBatchFile[curbatch] = NULL;
+		if (hashtable->outerBatchFile[batchidx])
+			BufFileClose(hashtable->outerBatchFile[batchidx]);
+		hashtable->outerBatchFile[batchidx] = NULL;
 	}
 	else						/* we just finished the first batch */
 	{
@@ -1182,45 +1199,50 @@ ExecHashJoinNewBatch(HashJoinState *hjstate)
 	 * scan, we have to rescan outer batches in case they contain tuples that
 	 * need to be reassigned.
 	 */
-	curbatch++;
+	curbatch = ExecHashSwitchToNextBatch(hashtable);
+	batchidx = ExecHashGetBatchIndex(hashtable, curbatch);
+
 	while (curbatch < nbatch &&
-		   (hashtable->outerBatchFile[curbatch] == NULL ||
-			hashtable->innerBatchFile[curbatch] == NULL))
+		   (hashtable->outerBatchFile[batchidx] == NULL ||
+			hashtable->innerBatchFile[batchidx] == NULL))
 	{
-		if (hashtable->outerBatchFile[curbatch] &&
+		if (hashtable->outerBatchFile[batchidx] &&
 			HJ_FILL_OUTER(hjstate))
 			break;				/* must process due to rule 1 */
-		if (hashtable->innerBatchFile[curbatch] &&
+		if (hashtable->innerBatchFile[batchidx] &&
 			HJ_FILL_INNER(hjstate))
 			break;				/* must process due to rule 1 */
-		if (hashtable->innerBatchFile[curbatch] &&
+		if (hashtable->innerBatchFile[batchidx] &&
 			nbatch != hashtable->nbatch_original)
 			break;				/* must process due to rule 2 */
-		if (hashtable->outerBatchFile[curbatch] &&
+		if (hashtable->outerBatchFile[batchidx] &&
 			nbatch != hashtable->nbatch_outstart)
 			break;				/* must process due to rule 3 */
 		/* We can ignore this batch. */
 		/* Release associated temp files right away. */
-		if (hashtable->innerBatchFile[curbatch])
-			BufFileClose(hashtable->innerBatchFile[curbatch]);
-		hashtable->innerBatchFile[curbatch] = NULL;
-		if (hashtable->outerBatchFile[curbatch])
-			BufFileClose(hashtable->outerBatchFile[curbatch]);
-		hashtable->outerBatchFile[curbatch] = NULL;
-		curbatch++;
+		if (hashtable->innerBatchFile[batchidx])
+			BufFileClose(hashtable->innerBatchFile[batchidx]);
+		hashtable->innerBatchFile[batchidx] = NULL;
+		if (hashtable->outerBatchFile[batchidx])
+			BufFileClose(hashtable->outerBatchFile[batchidx]);
+		hashtable->outerBatchFile[batchidx] = NULL;
+
+		curbatch = ExecHashSwitchToNextBatch(hashtable);
+		batchidx = ExecHashGetBatchIndex(hashtable, curbatch);
 	}
 
 	if (curbatch >= nbatch)
+	{
+		hashtable->curbatch = curbatch_old;
 		return false;			/* no more batches */
-
-	hashtable->curbatch = curbatch;
+	}
 
 	/*
 	 * Reload the hash table with the new inner batch (which could be empty)
 	 */
 	ExecHashTableReset(hashtable);
 
-	innerFile = hashtable->innerBatchFile[curbatch];
+	innerFile = hashtable->innerBatchFile[batchidx];
 
 	if (innerFile != NULL)
 	{
@@ -1246,15 +1268,15 @@ ExecHashJoinNewBatch(HashJoinState *hjstate)
 		 * needed
 		 */
 		BufFileClose(innerFile);
-		hashtable->innerBatchFile[curbatch] = NULL;
+		hashtable->innerBatchFile[batchidx] = NULL;
 	}
 
 	/*
 	 * Rewind outer batch file (if present), so that we can start reading it.
 	 */
-	if (hashtable->outerBatchFile[curbatch] != NULL)
+	if (hashtable->outerBatchFile[batchidx] != NULL)
 	{
-		if (BufFileSeek(hashtable->outerBatchFile[curbatch], 0, 0, SEEK_SET))
+		if (BufFileSeek(hashtable->outerBatchFile[batchidx], 0, 0, SEEK_SET))
 			ereport(ERROR,
 					(errcode_for_file_access(),
 					 errmsg("could not rewind hash-join temporary file")));
diff --git a/src/backend/optimizer/path/costsize.c b/src/backend/optimizer/path/costsize.c
index c36687aa4df..6f40600d10b 100644
--- a/src/backend/optimizer/path/costsize.c
+++ b/src/backend/optimizer/path/costsize.c
@@ -4173,6 +4173,7 @@ initial_cost_hashjoin(PlannerInfo *root, JoinCostWorkspace *workspace,
 	int			num_hashclauses = list_length(hashclauses);
 	int			numbuckets;
 	int			numbatches;
+	int			numbatches_inmemory;
 	int			num_skew_mcvs;
 	size_t		space_allowed;	/* unused */
 
@@ -4227,6 +4228,7 @@ initial_cost_hashjoin(PlannerInfo *root, JoinCostWorkspace *workspace,
 							&space_allowed,
 							&numbuckets,
 							&numbatches,
+							&numbatches_inmemory,
 							&num_skew_mcvs);
 
 	/*
diff --git a/src/include/executor/hashjoin.h b/src/include/executor/hashjoin.h
index 2d8ed8688cd..c612ae98117 100644
--- a/src/include/executor/hashjoin.h
+++ b/src/include/executor/hashjoin.h
@@ -320,6 +320,7 @@ typedef struct HashJoinTableData
 	int		   *skewBucketNums; /* array indexes of active skew buckets */
 
 	int			nbatch;			/* number of batches */
+	int			nbatch_inmemory;	/* max number of in-memory batches */
 	int			curbatch;		/* current batch #; 0 during 1st pass */
 
 	int			nbatch_original;	/* nbatch when we started inner scan */
@@ -331,6 +332,9 @@ typedef struct HashJoinTableData
 	double		partialTuples;	/* # tuples obtained from inner plan by me */
 	double		skewTuples;		/* # tuples inserted into skew tuples */
 
+	BufFile	   *innerOverflowFile;	/* temp file for overflow batch batch */
+	BufFile	   *outerOverflowFile;	/* temp file for overflow batch batch */
+
 	/*
 	 * These arrays are allocated for the life of the hash join, but only if
 	 * nbatch > 1.  A file is opened only when we first write a tuple into it
diff --git a/src/include/executor/nodeHash.h b/src/include/executor/nodeHash.h
index e4eb7bc6359..fc82ea43c3b 100644
--- a/src/include/executor/nodeHash.h
+++ b/src/include/executor/nodeHash.h
@@ -16,6 +16,7 @@
 
 #include "access/parallel.h"
 #include "nodes/execnodes.h"
+#include "storage/buffile.h"
 
 struct SharedHashJoinBatch;
 
@@ -46,6 +47,11 @@ extern void ExecHashGetBucketAndBatch(HashJoinTable hashtable,
 									  uint32 hashvalue,
 									  int *bucketno,
 									  int *batchno);
+extern int ExecHashGetBatchIndex(HashJoinTable hashtable, int batchno);
+extern BufFile **ExecHashGetBatchFile(HashJoinTable hashtable, int batchno,
+					 BufFile **files, BufFile **overflow);
+extern void ExecHashSwitchToNextBatchSlice(HashJoinTable hashtable);
+extern int ExecHashSwitchToNextBatch(HashJoinTable hashtable);
 extern bool ExecScanHashBucket(HashJoinState *hjstate, ExprContext *econtext);
 extern bool ExecParallelScanHashBucket(HashJoinState *hjstate, ExprContext *econtext);
 extern void ExecPrepHashTableForUnmatched(HashJoinState *hjstate);
@@ -62,6 +68,7 @@ extern void ExecChooseHashTableSize(double ntuples, int tupwidth, bool useskew,
 									size_t *space_allowed,
 									int *numbuckets,
 									int *numbatches,
+									int *numbatches_inmemory,
 									int *num_skew_mcvs);
 extern int	ExecHashGetSkewBucket(HashJoinTable hashtable, uint32 hashvalue);
 extern void ExecHashEstimate(HashState *node, ParallelContext *pcxt);
diff --git a/src/include/nodes/execnodes.h b/src/include/nodes/execnodes.h
index 1590b643920..eb8eaea89ab 100644
--- a/src/include/nodes/execnodes.h
+++ b/src/include/nodes/execnodes.h
@@ -2755,6 +2755,7 @@ typedef struct HashInstrumentation
 	int			nbuckets_original;	/* planned number of buckets */
 	int			nbatch;			/* number of batches at end of execution */
 	int			nbatch_original;	/* planned number of batches */
+	int			nbatch_inmemory;	/* number of batches kept in memory */
 	Size		space_peak;		/* peak memory usage in bytes */
 } HashInstrumentation;
 
-- 
2.47.1

