From 9f9dce9a106351e969b5adc6d0744a62a677470b Mon Sep 17 00:00:00 2001
From: Jacob Champion <jacob.champion@enterprisedb.com>
Date: Tue, 16 Dec 2025 09:31:46 +0100
Subject: [PATCH v9 5/5] WIP: pytest: Add some server-side SSL tests

In the same vein as the previous commit, this is a server-only test
suite operating against a mock client. The test itself is a heavily
parameterized check for direct-SSL handshake behavior, using a
combination of "standard" and "custom" certificates via the certs
fixture.

installcheck is currently unsupported, but the architecture has some
extension points that should make it possible later. For now, a new
server is always started for the test session.

TODOs:
- improve remaining_timeout() integration with socket operations; at the
  moment, the timeout resets on every call rather than decrementing
---
 src/test/ssl/pyt/conftest.py    |  50 ++++++++++
 src/test/ssl/pyt/test_server.py | 161 ++++++++++++++++++++++++++++++++
 2 files changed, 211 insertions(+)
 create mode 100644 src/test/ssl/pyt/test_server.py

diff --git a/src/test/ssl/pyt/conftest.py b/src/test/ssl/pyt/conftest.py
index 870f738ac44..d121724800b 100644
--- a/src/test/ssl/pyt/conftest.py
+++ b/src/test/ssl/pyt/conftest.py
@@ -126,3 +126,53 @@ def certs(cryptography, tmp_path_factory):
             return f.name
 
     return _Certs()
+
+
+@pytest.fixture(scope="module", autouse=True)
+def ssl_setup(pg_server_module, certs, datadir):
+    """
+    Sets up required server settings for all tests in this module.
+    """
+    try:
+        with pg_server_module.restarting() as s:
+            s.conf.set(
+                ssl="on",
+                ssl_ca_file=certs.ca.certpath,
+                ssl_cert_file=certs.server.certpath,
+                ssl_key_file=certs.server.keypath,
+            )
+
+            # Reject by default.
+            s.hba.prepend("hostssl all all all reject")
+
+    except subprocess.CalledProcessError:
+        # This is a decent place to skip if the server isn't set up for SSL.
+        logpath = datadir / "postgresql.log"
+        unsupported = re.compile("SSL is not supported")
+
+        with open(logpath, "r") as log:
+            for line in log:
+                if unsupported.search(line):
+                    pytest.skip("the server does not support SSL")
+
+        # Some other error happened.
+        raise
+
+    users = pg_server_module.create_users("ssl")
+    dbs = pg_server_module.create_dbs("ssl")
+
+    return (users, dbs)
+
+
+@pytest.fixture(scope="module")
+def client_cert(ssl_setup, certs):
+    """
+    Creates a Cert for the "ssl" user.
+    """
+    from cryptography import x509
+    from cryptography.x509.oid import NameOID
+
+    users, _ = ssl_setup
+    user = users["ssl"]
+
+    return certs.new(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, user)]))
diff --git a/src/test/ssl/pyt/test_server.py b/src/test/ssl/pyt/test_server.py
new file mode 100644
index 00000000000..d5cb14b6c9a
--- /dev/null
+++ b/src/test/ssl/pyt/test_server.py
@@ -0,0 +1,161 @@
+# Copyright (c) 2025, PostgreSQL Global Development Group
+
+import re
+import socket
+import ssl
+import struct
+
+import pytest
+
+import pypg
+
+# This suite opens up local TCP ports and is hidden behind PG_TEST_EXTRA=ssl.
+pytestmark = pypg.require_test_extras("ssl")
+
+# For use with the `creds` parameter below.
+CLIENT = "client"
+SERVER = "server"
+
+
+# fmt: off
+@pytest.mark.parametrize(
+    "auth_method,                    creds,  expected_error",
+[
+    # Trust allows anything.
+    ("trust",                        None,   None),
+    ("trust",                        CLIENT, None),
+    ("trust",                        SERVER, None),
+
+    # verify-ca allows any CA-signed certificate.
+    ("trust clientcert=verify-ca",   None,   "requires a valid client certificate"),
+    ("trust clientcert=verify-ca",   CLIENT, None),
+    ("trust clientcert=verify-ca",   SERVER, None),
+
+    # cert and verify-full allow only the correct certificate.
+    ("trust clientcert=verify-full", None,   "requires a valid client certificate"),
+    ("trust clientcert=verify-full", CLIENT, None),
+    ("trust clientcert=verify-full", SERVER, "authentication failed for user"),
+    ("cert",                         None,   "requires a valid client certificate"),
+    ("cert",                         CLIENT, None),
+    ("cert",                         SERVER, "authentication failed for user"),
+],
+)
+# fmt: on
+def test_direct_ssl_certificate_authentication(
+    pg,
+    ssl_setup,
+    certs,
+    client_cert,
+    remaining_timeout,
+    # test parameters
+    auth_method,
+    creds,
+    expected_error,
+):
+    """
+    Tests direct SSL connections with various client-certificate/HBA
+    combinations.
+    """
+
+    # Set up the HBA as desired by the test.
+    users, dbs = ssl_setup
+
+    user = users["ssl"]
+    db = dbs["ssl"]
+
+    with pg.reloading() as s:
+        s.hba.prepend(
+            ["hostssl", db, user, "127.0.0.1/32", auth_method],
+            ["hostssl", db, user, "::1/128", auth_method],
+        )
+
+    # Configure the SSL settings for the client.
+    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+    ctx.load_verify_locations(cafile=certs.ca.certpath)
+    ctx.set_alpn_protocols(["postgresql"])  # for direct SSL
+
+    # Load up a client certificate if required by the test.
+    if creds == CLIENT:
+        ctx.load_cert_chain(client_cert.certpath, client_cert.keypath)
+    elif creds == SERVER:
+        # Using a server certificate as the client credential is expected to
+        # work only for clientcert=verify-ca (and `trust`, naturally).
+        ctx.load_cert_chain(certs.server.certpath, certs.server.keypath)
+
+    # Make a direct SSL connection. There's no SSLRequest in the handshake; we
+    # simply wrap a TCP connection with OpenSSL.
+    addr = (pg.hostaddr, pg.port)
+    with socket.create_connection(addr) as s:
+        s.settimeout(remaining_timeout())  # XXX this resets every operation
+
+        with ctx.wrap_socket(s, server_hostname=certs.server_host) as conn:
+            # Build and send the startup packet.
+            startup_options = dict(
+                user=user,
+                database=db,
+                application_name="pytest",
+            )
+
+            payload = b""
+            for k, v in startup_options.items():
+                payload += k.encode() + b"\0"
+                payload += str(v).encode() + b"\0"
+            payload += b"\0"  # null terminator
+
+            pktlen = 4 + 4 + len(payload)
+            conn.send(struct.pack("!IHH", pktlen, 3, 0) + payload)
+
+            if not expected_error:
+                # Expect an AuthenticationOK to come back.
+                pkttype, pktlen = struct.unpack("!cI", conn.recv(5))
+                assert pkttype == b"R"
+                assert pktlen == 8
+
+                authn_result = struct.unpack("!I", conn.recv(4))[0]
+                assert authn_result == 0
+
+                # Read and discard to ReadyForQuery.
+                while True:
+                    pkttype, pktlen = struct.unpack("!cI", conn.recv(5))
+                    payload = conn.recv(pktlen - 4)
+
+                    if pkttype == b"Z":
+                        assert payload == b"I"
+                        break
+
+                # Send an empty query.
+                conn.send(struct.pack("!cI", b"Q", 5) + b"\0")
+
+                # Expect EmptyQueryResponse+ReadyForQuery.
+                pkttype, pktlen = struct.unpack("!cI", conn.recv(5))
+                assert pkttype == b"I"
+                assert pktlen == 4
+
+                pkttype, pktlen = struct.unpack("!cI", conn.recv(5))
+                assert pkttype == b"Z"
+
+                payload = conn.recv(pktlen - 4)
+                assert payload == b"I"
+
+            else:
+                # Match the expected authentication error.
+                pkttype, pktlen = struct.unpack("!cI", conn.recv(5))
+                assert pkttype == b"E"
+
+                payload = conn.recv(pktlen - 4)
+                msg = None
+
+                for component in payload.split(b"\0"):
+                    if not component:
+                        break  # end of message
+
+                    key, val = component[:1], component[1:]
+                    if key == b"S":
+                        assert val == b"FATAL"
+                    elif key == b"M":
+                        msg = val.decode()
+
+                assert re.search(expected_error, msg), "server error did not match"
+
+            # Terminate.
+            conn.send(struct.pack("!cI", b"X", 4))
-- 
2.52.0

