From 5e037b9cf644e23fd9e9806a0b72690ddb867f75 Mon Sep 17 00:00:00 2001
From: Jelte Fennema-Nio <postgres@jeltef.nl>
Date: Thu, 23 Oct 2025 14:31:52 +0200
Subject: [PATCH v2 3/3] Add pytest based tests for GoAway message

I used this patchset as a trial for the new pytest suite that Jacob is
trying to introduce. Feel free to look at it, but I'd say don't review
this test in detail until we have the pytest changes merged or at least
in a more agreed upon state. This patch is built on top of that v3
patchset. This test is not applied by cfbot.

Testing this any other way is actually quite difficult with the
infrastructure we currently have (of course I could change that, but I'd
much rather spend that energy/time on making the pytest test suite a
thing):
- pgregress and perl tests don't work because we need to call a new
  libpq function that is not exposed in psql (I guess I could expose it
  with some \goawayreceived command, but it doesn't seem very useful).
- libpq_pipeline cannot test this because it would need to restart
  the Postgres server and all it has
---
 src/test/pytest/libpq.py         | 18 ++++++++++++
 src/test/pytest/meson.build      |  1 +
 src/test/pytest/pypg/fixtures.py | 33 +++++++++++++++++++++
 src/test/pytest/pypg/server.py   | 50 ++++++++++++++++++++++++++++++++
 4 files changed, 102 insertions(+)

diff --git a/src/test/pytest/libpq.py b/src/test/pytest/libpq.py
index b851a117b66..5536b605c16 100644
--- a/src/test/pytest/libpq.py
+++ b/src/test/pytest/libpq.py
@@ -133,6 +133,12 @@ def load_libpq_handle(libdir):
     lib.PQftype.restype = ctypes.c_uint
     lib.PQftype.argtypes = [_PGresult_p, ctypes.c_int]
 
+    lib.PQgoAwayReceived.restype = ctypes.c_int
+    lib.PQgoAwayReceived.argtypes = [_PGconn_p]
+
+    lib.PQconsumeInput.restype = ctypes.c_int
+    lib.PQconsumeInput.argtypes = [_PGconn_p]
+
     return lib
 
 
@@ -340,6 +346,18 @@ class PGconn(contextlib.AbstractContextManager):
             error_msg = res.error_message() or f"Unexpected status: {status}"
             raise LibpqError(f"Query failed: {error_msg}\nQuery: {query}")
 
+    def consume_input(self) -> bool:
+        """
+        Consumes any available input from the server. Returns True on success.
+        """
+        return bool(self._lib.PQconsumeInput(self._handle))
+
+    def goaway_received(self) -> bool:
+        """
+        Returns True if a GoAway message was received from the server.
+        """
+        return bool(self._lib.PQgoAwayReceived(self._handle))
+
 
 def connstr(opts: Dict[str, Any]) -> str:
     """
diff --git a/src/test/pytest/meson.build b/src/test/pytest/meson.build
index f53193e8686..3c8518243d9 100644
--- a/src/test/pytest/meson.build
+++ b/src/test/pytest/meson.build
@@ -12,6 +12,7 @@ tests += {
     'tests': [
       'pyt/test_something.py',
       'pyt/test_libpq.py',
+      'pyt/test_goaway.py',
     ],
   },
 }
diff --git a/src/test/pytest/pypg/fixtures.py b/src/test/pytest/pypg/fixtures.py
index cf22c8ec436..ba46f048beb 100644
--- a/src/test/pytest/pypg/fixtures.py
+++ b/src/test/pytest/pypg/fixtures.py
@@ -30,6 +30,30 @@ def remaining_timeout():
     return lambda: max(deadline - time.monotonic(), 0)
 
 
+@pytest.fixture
+def wait_until(remaining_timeout):
+    def wait_until(error_message="Did not complete in time", timeout=None, interval=1):
+        """
+        Loop until the timeout is reached. If the timeout is reached, raise an
+        exception with the given error message.
+        """
+        if timeout is None:
+            timeout = remaining_timeout()
+
+        end = time.time() + timeout
+        print_progress = timeout / 10 > 4
+        last_printed_progress = 0
+        while time.time() < end:
+            if print_progress and time.time() - last_printed_progress > 4:
+                last_printed_progress = time.time()
+                print(f"{error_message} - will retry")
+            yield
+            time.sleep(interval)
+        raise TimeoutError(error_message)
+
+    return wait_until
+
+
 @pytest.fixture(scope="session")
 def libpq_handle(libdir):
     """
@@ -149,6 +173,15 @@ def pg_server_module(pg_server_global):
         yield s
 
 
+@pytest.fixture(autouse=True, scope="function")
+def ensure_server_running(pg_server_global):
+    """
+    Autouse fixture that ensures the server is running before each test.
+    If a test shuts down the server, this will restart it for the next test.
+    """
+    pg_server_global.ensure_running()
+
+
 @pytest.fixture
 def pg(pg_server_module, remaining_timeout):
     """
diff --git a/src/test/pytest/pypg/server.py b/src/test/pytest/pypg/server.py
index d6675cde93d..f09651c089e 100644
--- a/src/test/pytest/pypg/server.py
+++ b/src/test/pytest/pypg/server.py
@@ -332,6 +332,56 @@ class PostgresServer:
             # Server may have already been stopped
             pass
 
+    def ensure_running(self):
+        """
+        Ensure that the PostgreSQL server is running and accepting connections.
+
+        If the server is not running, it will be restarted. This method waits
+        for any in-progress shutdown to complete before attempting to restart.
+        """
+        pid_file = os.path.join(self.datadir, "postmaster.pid")
+
+        # Wait for any in-progress shutdown to complete
+        socket_pattern = os.path.join(self.sockdir, f".s.PGSQL.{self.port}*")
+        for _ in range(100):  # Wait up to 10 seconds
+            # Server is fully down when both PID file and sockets are gone
+            if not os.path.exists(pid_file) and len(glob.glob(socket_pattern)) == 0:
+                break
+            # Server is running if PID exists and we can connect
+            if os.path.exists(pid_file):
+                # Use pg_isready to check if server is accepting connections
+                try:
+                    pg_isready = os.path.join(self._bindir, "pg_isready")
+                    run(
+                        pg_isready,
+                        "-h",
+                        self.sockdir,
+                        "-p",
+                        self.port,
+                        stdout=subprocess.DEVNULL,
+                        stderr=subprocess.DEVNULL,
+                        timeout=1,
+                    )
+                    # Server is up and ready
+                    break
+                except (subprocess.CalledProcessError, subprocess.TimeoutExpired):
+                    # Server is not ready yet, keep waiting
+                    pass
+            time.sleep(0.1)
+
+        # Now check if server needs to be started
+        if not os.path.exists(pid_file):
+            # Restart the server and wait for it to be ready
+            run(
+                self._pg_ctl,
+                "-D",
+                self.datadir,
+                "-l",
+                self._log,
+                "-w",
+                "start",
+            )
+
     def cleanup(self):
         """Run all registered cleanup callbacks."""
         self._cleanup_stack.close()
-- 
2.51.1

