diff --git a/src/backend/utils/mmgr/mcxt.c b/src/backend/utils/mmgr/mcxt.c index 64bcc7ef32..3d80abbfae 100644 --- a/src/backend/utils/mmgr/mcxt.c +++ b/src/backend/utils/mmgr/mcxt.c @@ -482,6 +482,15 @@ MemoryContextAllowInCriticalSection(MemoryContext context, bool allow) MemoryContext GetMemoryChunkContext(void *pointer) { + /* + * Try to detect bogus pointers handed to us, poorly though we can. + * Presumably, a pointer that isn't MAXALIGNED isn't pointing at an + * allocated chunk. + */ + Assert(pointer != NULL); + Assert(pointer == (void *) MAXALIGN(pointer)); + /* adding further Asserts here? See pre-checks in MemoryContextContains */ + return MCXT_METHOD(pointer, get_chunk_context) (pointer); } @@ -809,11 +818,10 @@ MemoryContextCheck(MemoryContext context) * Detect whether an allocated chunk of memory belongs to a given * context or not. * - * Caution: this test is reliable as long as 'pointer' does point to - * a chunk of memory allocated from *some* context. If 'pointer' points - * at memory obtained in some other way, there is a small chance of a - * false-positive result, since the bits right before it might look like - * a valid chunk header by chance. + * Caution: 'pointer' must point to a pointer which was allocated by a + * MemoryContext. It's not safe or valid to use this function on arbitrary + * pointers as obtaining the MemoryContext which 'pointer' belongs to requires + * possibly several pointer dereferences. */ bool MemoryContextContains(MemoryContext context, void *pointer) @@ -821,9 +829,8 @@ MemoryContextContains(MemoryContext context, void *pointer) MemoryContext ptr_context; /* - * NB: Can't use GetMemoryChunkContext() here - that performs assertions - * that aren't acceptable here since we might be passed memory not - * allocated by any memory context. + * NB: We must perform run-time checks here which GetMemoryChunkContext() + * does as assertions before calling GetMemoryChunkContext(). * * Try to detect bogus pointers handed to us, poorly though we can. * Presumably, a pointer that isn't MAXALIGNED isn't pointing at an @@ -835,7 +842,7 @@ MemoryContextContains(MemoryContext context, void *pointer) /* * OK, it's probably safe to look at the context. */ - ptr_context = *(MemoryContext *) (((char *) pointer) - sizeof(void *)); + ptr_context = GetMemoryChunkContext(pointer); return ptr_context == context; } @@ -953,6 +960,8 @@ MemoryContextAlloc(MemoryContext context, Size size) VALGRIND_MEMPOOL_ALLOC(context, ret, size); + Assert(MemoryContextContains(context, ret)); + return ret; } @@ -991,6 +1000,8 @@ MemoryContextAllocZero(MemoryContext context, Size size) MemSetAligned(ret, 0, size); + Assert(MemoryContextContains(context, ret)); + return ret; } @@ -1029,6 +1040,8 @@ MemoryContextAllocZeroAligned(MemoryContext context, Size size) MemSetLoop(ret, 0, size); + Assert(MemoryContextContains(context, ret)); + return ret; } @@ -1070,6 +1083,8 @@ MemoryContextAllocExtended(MemoryContext context, Size size, int flags) if ((flags & MCXT_ALLOC_ZERO) != 0) MemSetAligned(ret, 0, size); + Assert(MemoryContextContains(context, ret)); + return ret; } @@ -1153,6 +1168,8 @@ palloc(Size size) VALGRIND_MEMPOOL_ALLOC(context, ret, size); + Assert(MemoryContextContains(context, ret)); + return ret; } @@ -1186,6 +1203,8 @@ palloc0(Size size) MemSetAligned(ret, 0, size); + Assert(MemoryContextContains(context, ret)); + return ret; } @@ -1225,6 +1244,8 @@ palloc_extended(Size size, int flags) if ((flags & MCXT_ALLOC_ZERO) != 0) MemSetAligned(ret, 0, size); + Assert(MemoryContextContains(context, ret)); + return ret; } @@ -1278,6 +1299,8 @@ repalloc(void *pointer, Size size) VALGRIND_MEMPOOL_CHANGE(context, pointer, ret, size); + Assert(MemoryContextContains(context, ret)); + return ret; } @@ -1313,6 +1336,8 @@ MemoryContextAllocHuge(MemoryContext context, Size size) VALGRIND_MEMPOOL_ALLOC(context, ret, size); + Assert(MemoryContextContains(context, ret)); + return ret; } @@ -1352,6 +1377,8 @@ repalloc_huge(void *pointer, Size size) VALGRIND_MEMPOOL_CHANGE(context, pointer, ret, size); + Assert(MemoryContextContains(context, ret)); + return ret; }