#!/usr/bin/python3
#
# Reproducer for the libpq TLS buffering bug
#
# Subject: libpq: Process buffered SSL read bytes to support records >8kB on async API
# https://www.postgresql.org/message-id/2039ac58-d3e0-434b-ac1a-2a987f3b4cb1%40greiz-reinsdorf.de
#
# This test program has two parts:
#
# - A psycopg2 client program that runs a query against the server,
#   connecting through the proxy with the asynchronous API
#
# - A TLS proxy that sits between the client and the running server. It
#   buffers the server's responses into larger chunks and then splits them
#   again into TLS records, so that the TLS records don't align with the
#   Postgres protocol messages
#
# Usage:
#
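# You must have a PostgreSQL server running on localhost port 5432 that
# accepts unencrypted connections to the 'postgres' database without a
# password (the proxy talks to it over plain TCP). The client needs
# libpq 17 or later, because it connects with sslnegotiation='direct'.
# Run this script directly; it prints "FAIL" and exits with status 1 if
# the bug is triggered.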
#
import select
import socket
import ssl
import tempfile
import time
import threading
import psycopg2
import psycopg2.extras
import psycopg2.extensions

# The real PostgreSQL server behind the TLS proxy; the proxy reaches it
# over plain, unencrypted TCP.
dest_host, dest_port = 'localhost', 5432

#
# TLS Proxy
#
# This listens for a single incoming TLS connection on the given
# 'server_socket', which has already been bound to a port. When the
# connection is received, it opens a connection to the server listening
# at dest_host:dest_port. The connection from the proxy to the server is
# not encrypted.
#
# To tickle the bug, all data coming from the server is held back until
# no more data has arrived for 1 s. The buffered data is then forwarded
# to the client in slices of at most 12 kB, each written to the TLS
# socket separately so that it goes out as its own TLS record.
# Test-only private key and matching certificate for the proxy's TLS
# listener. The client connects with sslmode='require', which does not
# verify the certificate. Never reuse this key for anything real.
PROXY_KEY_AND_CERT_PEM = b"""
-----BEGIN PRIVATE KEY-----
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDkpRy8folzDqSk
rV3NbIvIS9E4ZKiJmcnvHVCysPVwXB7iQiFNif2gJLqsV8KLsJxdM7QxE/dh2vug
sazXZAsCNBDQZD756hRtRvS+2OtCdjZ3WrPvVzQwG1iDHHIRISC9mrvQAvbdFtN1
nxILdIemPZtfrR3uzq0A+uHIGwoHih3+msrozQvF5IgHnz6J7JBZoabBFS8HGp9I
+SWrlwmj4LFVX0Gfx1x/Bf+8tJe4/lSgi6fS6qoBq9AtAb4+hyuw+VLyGlBbvE+C
DzEVjCJ9yXKguF4yNEy+VQp5IXrmiF9JAS/rg9/qGqTe9h89RTCP3ZIWKvGGvpoN
AhQEIHfnAgMBAAECggEAC+13Hc2GExao98mZfUgT1LDr1a1z7gBep0GCvlz6jiV9
cPUyh+PSc3e8YJWQGbnEC8f7jEc2iQgjnKLthCunuAeKTFdmbPgin6lrcXH0d+Tz
Jhp5DNvM5lOO3wFtYT+2qwghkxI3Lu/BZpST8ZK0ft4eNwlbaI6nOh1cXj7W25oQ
1EvomnUwUlc6ADudpg5rQh1HbiICk6uSE0A530Qo8D4FpiYoVZwt8l4vcIyGoyiU
kITL3KK17yZ+zWI8bqZT3iSzzcmiKxz6qLUomcfkwUAwFvgZVYmP2c5fTcRUMl5c
t0zuyJlus6vvovIvTwkCqs6LaY//fEO2xcYIK9UbfQKBgQD8c4ss1QDO1McOdG+3
Wd8Aj0t9Iu2ziknjwFDs2YbBOF6FW0WzItrKurxfKNPhN73qJPD6GmFiujRAR/Ui
xbVND8eGx9ocAmvUQyst8NK56OmKNpuKpNGnL5s2TeVJf5iXdaeBs/i+hkwiXe9C
+Ufdc5QAfVDD+B1m3146p3CcBQKBgQDn2+aA9ZMYv3qfypmg5VfhfmTF712HkSvn
nRoBH5YW5QQCG0kuXDz8TdTBSEonbyCDXD182szzn9epKb+IZhScXRYlYnChN6uG
3OmHvzAzsYUw5f+zWA3XfKnL3y9qkfaAOo9zONzWs9FB2sITv0eKNd4mf8Y+ddE1
EE1OHyCz+wKBgQDk5GpS+snhvmDBRWcpag3ctw/t5OZ6vC7kljGJnm0k8dQZu7jF
hBu2Znt3GFCLynuiOV5YleSonEXV5qhnn7UTqvPwy3GBpdxYt5IF9G1L7Nca3wpG
OcxxdqOXKCd1bYBQC3gWDLTDIocTPfI62kSDkFCn5Pd+x475ABuyuLBMdQKBgA4Z
J/X1eMFLe1hWCGtpJqPWfKgwet5wbFwECH3C/uxbdpfuMs/32dl5nhM2oxOsxSxX
ooGCCG5T7NgjarsPgfdUDbGuP6z95pcnvad8b6DlDXVAtwCfvQ+6S9TSuF5hi7yW
Uvytm3gOrQ21EJIE0oPL7Lsoj9Ric5snZ5v1dpabAoGAY5Akg6JCQhmefQtfGoQN
PDNIZGoc9MvAlhP+9Yg6qMSPapjN0BMt4KDcG7nF8fNAnIlAgzE7VVhyn3uHXr9g
q3MnT3ZZ/bOcmQSgNiUY/BYP2G8Wz47aC50RAsdX28/vF8HomD16bHwaoWN2M343
PXw+1B4WzaAKSS0JTRqfaxM=
-----END PRIVATE KEY-----
-----BEGIN CERTIFICATE-----
MIIDjzCCAnegAwIBAgIUKRYbc0eN7+LEdjc+X8QX3a+15EYwDQYJKoZIhvcNAQEL
BQAwVzELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5
MRUwEwYDVQQKDAxPcmdhbml6YXRpb24xEjAQBgNVBAMMCWxvY2FsaG9zdDAeFw0y
NjAzMjUxMTQ2MjhaFw0yNzAzMjUxMTQ2MjhaMFcxCzAJBgNVBAYTAlVTMQ4wDAYD
VQQIDAVTdGF0ZTENMAsGA1UEBwwEQ2l0eTEVMBMGA1UECgwMT3JnYW5pemF0aW9u
MRIwEAYDVQQDDAlsb2NhbGhvc3QwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEK
AoIBAQDkpRy8folzDqSkrV3NbIvIS9E4ZKiJmcnvHVCysPVwXB7iQiFNif2gJLqs
V8KLsJxdM7QxE/dh2vugsazXZAsCNBDQZD756hRtRvS+2OtCdjZ3WrPvVzQwG1iD
HHIRISC9mrvQAvbdFtN1nxILdIemPZtfrR3uzq0A+uHIGwoHih3+msrozQvF5IgH
nz6J7JBZoabBFS8HGp9I+SWrlwmj4LFVX0Gfx1x/Bf+8tJe4/lSgi6fS6qoBq9At
Ab4+hyuw+VLyGlBbvE+CDzEVjCJ9yXKguF4yNEy+VQp5IXrmiF9JAS/rg9/qGqTe
9h89RTCP3ZIWKvGGvpoNAhQEIHfnAgMBAAGjUzBRMB0GA1UdDgQWBBQ/KejUOncs
pm5k/wkn24uD4mbvqjAfBgNVHSMEGDAWgBQ/KejUOncspm5k/wkn24uD4mbvqjAP
BgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQA9GS+ZJqgidut9ZCBH
+GC/WlFZ5ueQFZWiaPGdylRdR3hDqdQy/s2pfshJPs8kmMqMGOJr2lI1GeZOYOQC
vgn3m2YZabtABzuLGpZxJXGByXa07AIob4DGgpoDY1hNEf8F7jjhnl7M/64R2blE
tA5hCUsKuF7qeGVmXKlkfMLbJvwbncXGMY4wComF4crrF4v9K+TpSLTTZpvaft63
ucwSU70izWOzt/QoiStpeOQ9aQ1p4Yf4wALd4LN5zvQklJTr98/30odGZj+vfgjg
p3AshIu+zB2GgUVMCsRbQcKuWwD5a28ZavqFWEu382v4tatAzh1HxnII52FKCepI
wRBr
-----END CERTIFICATE-----
"""


def close_sockets(*socks):
    """Shut down and close every socket in *socks*, ignoring OSError.

    Each socket is handled independently, so a failure on one does not
    prevent the others from being closed. shutdown() is called before
    close() because close() alone does not necessarily end the connection
    promptly (for instance while another thread is still blocked in recv()
    on the same socket); shutdown() wakes such readers and signals EOF to
    the peer.
    """
    for sock in socks:
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except OSError:
            pass  # not connected, already shut down, or already closed
        try:
            sock.close()
        except OSError:
            pass


def send_in_chunks(dst_socket, buf, chunk_size, direction, reason):
    """Send all of bytearray *buf* to *dst_socket*, then empty *buf*.

    The data is written in slices of at most *chunk_size* bytes, each with
    its own sendall() call, so on a TLS socket every slice is written
    separately. *reason* is only used in the log line for each slice.
    """
    for start in range(0, len(buf), chunk_size):
        chunk = buf[start:start + chunk_size]
        dst_socket.sendall(chunk)
        print(f"[{direction}] {reason}, Forwarded {len(chunk)} bytes")
    buf.clear()


def forward_data(src_socket, dst_socket, direction, bufsize=100 * 1024):
    """Copy everything from *src_socket* to *dst_socket* unchanged.

    Runs until *src_socket* reports EOF, then closes both sockets.
    *direction* is a label used in log lines.
    """
    try:
        while True:
            data = src_socket.recv(bufsize)
            if not data:
                print(f"[{direction}] exiting")
                break
            print(f"[{direction}] Received and forwarded {len(data)} bytes")
            # send() may write only part of the data; sendall() writes it all.
            dst_socket.sendall(data)
    finally:
        close_sockets(src_socket, dst_socket)


def forward_data_buffered(src_socket, dst_socket, direction,
                          idle_timeout=1, chunk_size=12 * 1024,
                          bufsize=100 * 1024):
    """Forward data from *src_socket* to *dst_socket* in delayed bursts.

    Incoming data is accumulated until *src_socket* has been idle for
    *idle_timeout* seconds, or until it reaches EOF. The whole buffer is
    then sent in slices of at most *chunk_size* bytes (see send_in_chunks).
    Both sockets are closed on exit. *direction* is a label used in log
    lines.

    *src_socket* is deliberately left in blocking mode. recv() is only
    called after select() reports it readable, and in tls_proxy() the same
    socket object is concurrently written by the other forwarder, which
    must not see a non-blocking socket.
    """
    buf = bytearray()
    try:
        while True:
            # Wait indefinitely while there is nothing to flush; otherwise
            # wait at most idle_timeout for more data before flushing.
            timeout = idle_timeout if buf else None
            readable, _, _ = select.select([src_socket], [], [], timeout)
            if not readable:
                send_in_chunks(dst_socket, buf, chunk_size, direction, "Timeout")
                continue
            data = src_socket.recv(bufsize)
            if not data:
                # EOF: deliver what is still buffered, then stop. Without this
                # check, select() keeps reporting the socket as readable and
                # the loop would spin forever without ever flushing.
                send_in_chunks(dst_socket, buf, chunk_size, direction, "EOF")
                print(f"[{direction}] exiting")
                break
            print(f"[{direction}] Received {len(data)} bytes")
            buf.extend(data)
    finally:
        close_sockets(src_socket, dst_socket)


def make_listen_context():
    """Return a server-side SSLContext that uses PROXY_KEY_AND_CERT_PEM.

    The context advertises the "postgresql" ALPN protocol, which libpq
    expects when connecting with sslnegotiation='direct'.
    """
    context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
    # load_cert_chain() only accepts a file path, so stage the PEM in a
    # temporary file. Writing through the handle, rather than reopening the
    # file by name, and using a context manager means the file is removed
    # even if loading fails.
    with tempfile.NamedTemporaryFile() as cert_file:
        cert_file.write(PROXY_KEY_AND_CERT_PEM)
        cert_file.flush()
        print(f"cert file {cert_file.name}")
        context.load_cert_chain(cert_file.name)
    context.set_alpn_protocols(["postgresql"])
    return context


def tls_proxy(server_socket, dest_host, dest_port):
    """Accept one TLS client on *server_socket* and proxy it to the server.

    The client side is TLS, terminated here with PROXY_KEY_AND_CERT_PEM.
    The connection to dest_host:dest_port is plain TCP. Client-to-server
    traffic is forwarded unchanged. Server-to-client traffic goes through
    forward_data_buffered(), so the server's output is regrouped into
    separate writes of at most 12 kB on the TLS socket.

    Blocks until both forwarding threads have finished. A failure to
    connect to the destination is printed rather than raised; a failed TLS
    handshake with the client is raised.
    """
    print(f"Forwarding to {dest_host}:{dest_port}")

    server_socket.listen(1)
    # Build the TLS context before accepting, so that a bad key or
    # certificate fails before there is a client socket to leak.
    listen_context = make_listen_context()

    raw_client_socket, addr = server_socket.accept()
    print(f"Connection from {addr}")

    try:
        # wrap_socket() runs the TLS handshake right away on a connected
        # socket; don't leak the raw socket if that fails.
        ssl_client_socket = listen_context.wrap_socket(raw_client_socket, server_side=True)
    except OSError:
        raw_client_socket.close()
        raise

    try:
        dest_socket = socket.create_connection((dest_host, dest_port))

        # NOTE(review): ssl_client_socket is read by one thread and written
        # by the other. OpenSSL does not promise that concurrent use of a
        # single SSL object is safe; confirm this causes no spurious errors.
        forwarders = [
            threading.Thread(
                target=forward_data,
                args=(ssl_client_socket, dest_socket, "CLIENT->SERVER"),
                daemon=True,
            ),
            threading.Thread(
                target=forward_data_buffered,
                args=(dest_socket, ssl_client_socket, "SERVER->CLIENT"),
                daemon=True,
            ),
        ]
        for thread in forwarders:
            thread.start()
        for thread in forwarders:
            thread.join()

    except Exception as e:
        # Top-level boundary of the proxy thread: report and finish.
        print(f"Error handling client: {e}")
    finally:
        close_sockets(ssl_client_socket)

# Create a listener socket for the TLS proxy, and start a thread that
# waits for the client's connection.
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server_socket.bind(('localhost', 0))
# Listen here, before the client connects: the kernel then queues the
# connection even if the proxy thread has not reached accept() yet, so
# there is no need to sleep and hope the thread got as far as listen().
# tls_proxy() calling listen() again on this socket is harmless.
server_socket.listen(1)
listen_port = server_socket.getsockname()[1]

print(f"Listening on port {listen_port}")

client_thread = threading.Thread(
    target=tls_proxy,
    args=(server_socket, dest_host, dest_port),
    daemon=True,
)
client_thread.start()


#
# Run the client
#
aconn = psycopg2.connect(f"host='localhost' port={listen_port} dbname='postgres' sslnegotiation='direct' sslmode='require'", async_=True)
psycopg2.extras.wait_select(aconn)

acurs = aconn.cursor()

acurs.execute("""
SELECT repeat('x', 32000)
union all
select repeat('y', 800)
""")

# Drive the asynchronous query by hand, rather than with wait_select(), so
# that a read that never becomes ready is detected with a timeout instead
# of hanging forever.
while True:
    print("polling")
    state = aconn.poll()
    if state == psycopg2.extensions.POLL_OK:
        break
    elif state == psycopg2.extensions.POLL_WRITE:
        select.select([], [aconn.fileno()], [])
    elif state == psycopg2.extensions.POLL_READ:
        print("read ready")
        # If you hit the libpq bug, this will time out: presumably the rest
        # of the result is already sitting in libpq's TLS buffer, so no more
        # bytes ever arrive on the socket itself.
        readable, _, _ = select.select([aconn.fileno()], [], [], 5)
        if not readable:
            print("FAIL: the libpq bug was triggered, select timed out")
            # Not exit(): that helper is added by the site module for the
            # interactive shell and is not meant for programs.
            raise SystemExit(1)
    else:
        raise psycopg2.OperationalError(f"poll() returned {state}")

result = acurs.fetchall()
aconn.close()

print(f"Success, bug not triggered ({len(result)} rows)")
