From c82e188e3e31277ced29e102532682630f90dfc6 Mon Sep 17 00:00:00 2001
From: Jacob Champion <jacob.champion@enterprisedb.com>
Date: Tue, 16 Dec 2025 09:30:55 +0100
Subject: [PATCH v9 4/5] WIP: pytest: Add some SSL client tests

This is a sample client-only test suite. It tests some handshake
failures against a mock server, as well as a full SSL handshake + empty
query + response.

pyca/cryptography is added as a new package dependency. Certificates for
testing are generated on the fly.

The mock design is threaded: the server socket is listening on a
background thread, and the test provides the server logic via a
callback. There is some additional work still needed to make this
production-ready; see the notes for _TCPServer.background(). (Currently,
an exception in the wrong place could result in a hang-until-timeout
rather than an immediate failure.)

TODOs:
- local_server and tcp_server_class are nearly identical and should
  share code.
- fix exception-related timeouts for .background()
- figure out the proper use of "session" vs "module" scope
- ensure that pq.libpq unwinds (to close connections) before tcp_server;
  see comment in test_server_with_ssl_disabled()
---
 .cirrus.tasks.yml               |   2 +
 pyproject.toml                  |   8 +
 src/test/ssl/Makefile           |   2 +
 src/test/ssl/meson.build        |   6 +
 src/test/ssl/pyt/conftest.py    | 128 +++++++++++++++
 src/test/ssl/pyt/test_client.py | 278 ++++++++++++++++++++++++++++++++
 6 files changed, 424 insertions(+)
 create mode 100644 src/test/ssl/pyt/conftest.py
 create mode 100644 src/test/ssl/pyt/test_client.py

diff --git a/.cirrus.tasks.yml b/.cirrus.tasks.yml
index 1b0deae8d87..17fd7e0c8c3 100644
--- a/.cirrus.tasks.yml
+++ b/.cirrus.tasks.yml
@@ -645,6 +645,7 @@ task:
     CIRRUS_WORKING_DIR: ${HOME}/pgsql/
     CCACHE_DIR: ${HOME}/ccache
     MACPORTS_CACHE: ${HOME}/macports-cache
+    PYTEST_DEBUG_TEMPROOT: /tmp  # default is too long for UNIX sockets on Mac
 
     MESON_FEATURES: >-
       -Dbonjour=enabled
@@ -665,6 +666,7 @@ task:
       p5.34-io-tty
       p5.34-ipc-run
       python312
+      py312-cryptography
       py312-packaging
       py312-pytest
       tcl
diff --git a/pyproject.toml b/pyproject.toml
index 4628d2274e0..00c8ae88583 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,6 +12,14 @@ dependencies = [
     # Any other dependencies are effectively optional (added below). We import
     # these libraries using pytest.importorskip(). So tests will be skipped if
     # they are not available.
+
+    # Notes on the cryptography package:
+    # - 3.3.2 is shipped on Debian bullseye.
+    # - 3.4.x drops support for Python 2, making it a version of note for older LTS
+    #   distros.
+    # - 35.x switched versioning schemes and moved to Rust parsing.
+    # - 40.x is the last version supporting Python 3.6.
+    "cryptography >= 3.3.2",
 ]
 
 [tool.pytest.ini_options]
diff --git a/src/test/ssl/Makefile b/src/test/ssl/Makefile
index aa062945fb9..287729ad9fb 100644
--- a/src/test/ssl/Makefile
+++ b/src/test/ssl/Makefile
@@ -30,6 +30,8 @@ clean distclean:
 # Doesn't depend on sslfiles because we don't rebuild them by default
 check:
 	$(prove_check)
+	# XXX these suites should run independently, not serially
+	$(pytest_check)
 
 installcheck:
 	$(prove_installcheck)
diff --git a/src/test/ssl/meson.build b/src/test/ssl/meson.build
index 9e5bdbb6136..6ec274d8165 100644
--- a/src/test/ssl/meson.build
+++ b/src/test/ssl/meson.build
@@ -15,4 +15,10 @@ tests += {
       't/003_sslinfo.pl',
     ],
   },
+  'pytest': {
+    'tests': [
+      'pyt/test_client.py',
+      'pyt/test_server.py',
+    ],
+  },
 }
diff --git a/src/test/ssl/pyt/conftest.py b/src/test/ssl/pyt/conftest.py
new file mode 100644
index 00000000000..870f738ac44
--- /dev/null
+++ b/src/test/ssl/pyt/conftest.py
@@ -0,0 +1,128 @@
+# Copyright (c) 2025, PostgreSQL Global Development Group
+
+import datetime
+import re
+import subprocess
+import tempfile
+from collections import namedtuple
+
+import pytest
+
+
+@pytest.fixture(scope="session")
+def cryptography():
+    return pytest.importorskip("cryptography", "3.3.2")
+
+
+Cert = namedtuple("Cert", "cert, certpath, key, keypath")
+
+
+@pytest.fixture(scope="session")
+def certs(cryptography, tmp_path_factory):
+    """
+    Caches commonly used certificates at the session level, and provides a way
+    to create new ones.
+
+    - certs.ca: the root CA certificate
+
+    - certs.server: the "standard" server certficate, signed by certs.ca
+
+    - certs.server_host: the hostname of the certs.server certificate
+
+    - certs.new(): creates a custom certificate, signed by certs.ca
+    """
+
+    from cryptography import x509
+    from cryptography.hazmat.primitives import hashes, serialization
+    from cryptography.hazmat.primitives.asymmetric import rsa
+    from cryptography.x509.oid import NameOID
+
+    tmpdir = tmp_path_factory.mktemp("test-certs")
+
+    class _Certs:
+        def __init__(self):
+            self.ca = self.new(
+                x509.Name(
+                    [x509.NameAttribute(NameOID.COMMON_NAME, "PG pytest CA")],
+                ),
+                ca=True,
+            )
+
+            self.server_host = "example.org"
+            self.server = self.new(
+                x509.Name(
+                    [x509.NameAttribute(NameOID.COMMON_NAME, self.server_host)],
+                )
+            )
+
+        def new(self, subject: x509.Name, *, ca=False) -> Cert:
+            """
+            Creates and signs a new Cert with the given subject name. If ca is
+            True, the certificate will be self-signed; otherwise the certificate
+            is signed by self.ca.
+            """
+            key = rsa.generate_private_key(
+                public_exponent=65537,
+                key_size=2048,
+            )
+
+            builder = x509.CertificateBuilder()
+            now = datetime.datetime.now(datetime.timezone.utc)
+
+            builder = (
+                builder.subject_name(subject)
+                .public_key(key.public_key())
+                .serial_number(x509.random_serial_number())
+                .not_valid_before(now)
+                .not_valid_after(now + datetime.timedelta(hours=1))
+            )
+
+            if ca:
+                builder = builder.issuer_name(subject)
+            else:
+                builder = builder.issuer_name(self.ca.cert.subject)
+
+            builder = builder.add_extension(
+                x509.BasicConstraints(ca=ca, path_length=None),
+                critical=True,
+            )
+
+            cert = builder.sign(
+                private_key=key if ca else self.ca.key,
+                algorithm=hashes.SHA256(),
+            )
+
+            # Dump the certificate and key to file.
+            keypath = self._tofile(
+                key.private_bytes(
+                    serialization.Encoding.PEM,
+                    serialization.PrivateFormat.PKCS8,
+                    serialization.NoEncryption(),
+                ),
+                suffix=".key",
+            )
+            certpath = self._tofile(
+                cert.public_bytes(serialization.Encoding.PEM),
+                suffix="-ca.crt" if ca else ".crt",
+            )
+
+            return Cert(
+                cert=cert,
+                certpath=certpath,
+                key=key,
+                keypath=keypath,
+            )
+
+        def _tofile(self, data: bytes, *, suffix) -> str:
+            """
+            Dumps data to a file on disk with the requested suffix and returns
+            the path. The file is located somewhere in pytest's temporary
+            directory root.
+            """
+            f = tempfile.NamedTemporaryFile(suffix=suffix, dir=tmpdir, delete=False)
+            with f:
+                f.write(data)
+
+            return f.name
+
+    return _Certs()
diff --git a/src/test/ssl/pyt/test_client.py b/src/test/ssl/pyt/test_client.py
new file mode 100644
index 00000000000..556bad33bf8
--- /dev/null
+++ b/src/test/ssl/pyt/test_client.py
@@ -0,0 +1,278 @@
+# Copyright (c) 2025, PostgreSQL Global Development Group
+
+import contextlib
+import ctypes
+import socket
+import ssl
+import struct
+import threading
+from typing import Callable
+
+import pytest
+
+import pypg
+from libpq import LibpqError, ExecStatus
+
+# This suite opens up local TCP ports and is hidden behind PG_TEST_EXTRA=ssl.
+pytestmark = pypg.require_test_extras("ssl")
+
+
+@pytest.fixture(scope="session", autouse=True)
+def skip_if_no_ssl_support(libpq_handle):
+    """Skips tests if SSL support is not configured."""
+
+    # Declare PQsslAttribute().
+    PQsslAttribute = libpq_handle.PQsslAttribute
+    PQsslAttribute.restype = ctypes.c_char_p
+    PQsslAttribute.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
+
+    if not PQsslAttribute(None, b"library"):
+        pytest.skip("requires SSL support to be configured")
+
+
+#
+# Test Fixtures
+#
+
+
+@pytest.fixture
+def tcp_server_class(remaining_timeout):
+    """
+    Metafixture to combine related logic for tcp_server and ssl_server.
+
+    TODO: combine with test_libpq.local_server
+    """
+
+    class _TCPServer(contextlib.ExitStack):
+        """
+        Implementation class for tcp_server. See .background() for the primary
+        entry point for tests. Postgres clients may connect to this server via
+        **tcp_server.conninfo.
+
+        _TCPServer derives from contextlib.ExitStack to provide easy cleanup of
+        associated resources; see the documentation for that class for a full
+        explanation.
+        """
+
+        def __init__(self):
+            super().__init__()
+
+            self._thread = None
+            self._thread_exc = None
+            self._listener = self.enter_context(
+                socket.socket(socket.AF_INET, socket.SOCK_STREAM),
+            )
+
+            self._bind_and_listen()
+            sockname = self._listener.getsockname()
+            self.conninfo = dict(
+                hostaddr=sockname[0],
+                port=sockname[1],
+            )
+
+        def _bind_and_listen(self):
+            """
+            Does the actual work of binding the socket and listening for
+            connections.
+
+            The listen backlog is currently hardcoded to one.
+            """
+            self._listener.bind(("127.0.0.1", 0))
+            self._listener.listen(1)
+
+        def background(self, fn: Callable[[socket.socket], None]) -> None:
+            """
+            Accepts a client connection on a background thread and passes it to
+            the provided callback. Any exceptions raised from the callback will
+            be re-raised on the main thread during fixture teardown.
+
+            Blocking operations on the connected socket default to using the
+            remaining_timeout(), though this can be changed by the test via the
+            socket's .settimeout().
+            """
+
+            def _bg():
+                try:
+                    self._listener.settimeout(remaining_timeout())
+                    sock, _ = self._listener.accept()
+
+                    with sock:
+                        sock.settimeout(remaining_timeout())
+                        fn(sock)
+
+                except Exception as e:
+                    # Save the exception for re-raising on the main thread.
+                    self._thread_exc = e
+
+            # TODO: rather than using callback(), consider explicitly signaling
+            # the fn() implementation to stop early if we get an exception.
+            # Otherwise we'll hang until the end of the timeout.
+            self._thread = threading.Thread(target=_bg)
+            self.callback(self._join)
+
+            self._thread.start()
+
+        def _join(self):
+            """
+            Waits for the background thread to finish and raises any thrown
+            exception. This is called during fixture teardown.
+            """
+            # Give a little bit of wiggle room on the join timeout, since we're
+            # racing against the test's own use of remaining_timeout(). (It's
+            # preferable to let tests report timeouts; the stack traces will
+            # help with debugging.)
+            self._thread.join(remaining_timeout() + 1)
+            if self._thread.is_alive():
+                raise TimeoutError("background thread is still running after timeout")
+
+            if self._thread_exc is not None:
+                raise self._thread_exc
+
+    return _TCPServer
+
+
+@pytest.fixture
+def tcp_server(tcp_server_class):
+    """
+    Opens up a local TCP socket for mocking a Postgres server on a background
+    thread. See the _TCPServer API for usage.
+    """
+    with tcp_server_class() as s:
+        yield s
+
+
+@pytest.fixture
+def ssl_server(tcp_server_class, certs):
+    """
+    Like tcp_server, but with an additional .background_ssl() method which will
+    perform a SSLRequest handshake on the socket before handing the connection
+    to the test callback.
+
+    This server uses certs.server as its identity.
+    """
+
+    class _SSLServer(tcp_server_class):
+        def __init__(self):
+            super().__init__()
+
+            self.conninfo["host"] = certs.server_host
+
+            self._ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+            self._ctx.load_cert_chain(certs.server.certpath, certs.server.keypath)
+
+        def background_ssl(self, fn: Callable[[ssl.SSLSocket], None]) -> None:
+            """
+            Invokes a server callback as with .background(), but an SSLRequest
+            handshake is performed first, and the socket provided to the
+            callback has been wrapped in an OpenSSL layer.
+            """
+
+            def handshake(s: socket.socket):
+                pktlen = struct.unpack("!I", s.recv(4))[0]
+
+                # Make sure we get an SSLRequest.
+                version = struct.unpack("!HH", s.recv(4))
+                assert version == (1234, 5679)
+                assert pktlen == 8
+
+                # Accept the SSLRequest.
+                s.send(b"S")
+
+                with self._ctx.wrap_socket(s, server_side=True) as wrapped:
+                    fn(wrapped)
+
+            self.background(handshake)
+
+    with _SSLServer() as s:
+        yield s
+
+
+#
+# Tests
+#
+
+
+@pytest.mark.parametrize("sslmode", ("require", "verify-ca", "verify-full"))
+def test_server_with_ssl_disabled(connect, tcp_server, certs, sslmode):
+    """
+    Make sure client refuses to talk to non-SSL servers with stricter
+    sslmodes.
+    """
+
+    def refuse_ssl(s: socket.socket):
+        pktlen = struct.unpack("!I", s.recv(4))[0]
+
+        # Make sure we get an SSLRequest.
+        version = struct.unpack("!HH", s.recv(4))
+        assert version == (1234, 5679)
+        assert pktlen == 8
+
+        # Refuse the SSLRequest.
+        s.send(b"N")
+
+        # Wait for the client to close the connection.
+        assert not s.recv(1), "client sent unexpected data"
+
+    tcp_server.background(refuse_ssl)
+
+    with pytest.raises(LibpqError, match="server does not support SSL"):
+        connect(
+            **tcp_server.conninfo,
+            sslrootcert=certs.ca.certpath,
+            sslmode=sslmode,
+        )
+
+
+def test_verify_full_connection(connect, ssl_server, certs):
+    """Completes a verify-full connection and empty query."""
+
+    def handle_empty_query(s: ssl.SSLSocket):
+        pktlen = struct.unpack("!I", s.recv(4))[0]
+
+        # Check the startup packet version, then discard the remainder.
+        version = struct.unpack("!HH", s.recv(4))
+        assert version == (3, 0)
+        s.recv(pktlen - 8)
+
+        # Send the required litany of server messages.
+        s.send(struct.pack("!cII", b"R", 8, 0))  # AuthenticationOK
+
+        # ParameterStatus: client_encoding
+        key = b"client_encoding\0"
+        val = b"UTF-8\0"
+        s.send(struct.pack("!cI", b"S", 4 + len(key) + len(val)) + key + val)
+
+        # ParameterStatus: DateStyle
+        key = b"DateStyle\0"
+        val = b"ISO, MDY\0"
+        s.send(struct.pack("!cI", b"S", 4 + len(key) + len(val)) + key + val)
+
+        s.send(struct.pack("!cIII", b"K", 12, 1234, 1234))  # BackendKeyData
+        s.send(struct.pack("!cIc", b"Z", 5, b"I"))  # ReadyForQuery
+
+        # Expect an empty query.
+        pkttype = s.recv(1)
+        assert pkttype == b"Q"
+        pktlen = struct.unpack("!I", s.recv(4))[0]
+        assert s.recv(pktlen - 4) == b"\0"
+
+        # Send an EmptyQueryResponse+ReadyForQuery.
+        s.send(struct.pack("!cI", b"I", 4))
+        s.send(struct.pack("!cIc", b"Z", 5, b"I"))
+
+        # libpq should terminate and close the connection.
+        assert s.recv(1) == b"X"
+        pktlen = struct.unpack("!I", s.recv(4))[0]
+        assert pktlen == 4
+
+        assert not s.recv(1), "client sent unexpected data"
+
+    ssl_server.background_ssl(handle_empty_query)
+
+    conn = connect(
+        **ssl_server.conninfo,
+        sslrootcert=certs.ca.certpath,
+        sslmode="verify-full",
+    )
+    with conn:
+        assert conn.exec("").status() == ExecStatus.PGRES_EMPTY_QUERY
-- 
2.52.0

