import os
import socket
import struct

# Connection settings. Each one may be overridden through the standard libpq
# environment variable shown; otherwise the literal default is used. Setting
# PGPASSWORD keeps the real password out of this file.
HOST = os.environ.get("PGHOST", "127.0.0.1")  # TCP host/address (used with socket.create_connection)
PORT = int(os.environ.get("PGPORT", "5432"))
USER = os.environ.get("PGUSER", "postgres")
PASSWORD = os.environ.get("PGPASSWORD", "yourpassword")
DATABASE = os.environ.get("PGDATABASE", "postgres")

def send_msg(sock, msg_type, payload):
    """Write one tagged frontend message to *sock* in a single ``sendall``.

    Frame layout: the ASCII type character, a big-endian uint32 length that
    counts itself (4 bytes) plus the payload, then the payload bytes.
    """
    tag = msg_type.encode("ascii")
    length_field = struct.pack("!I", 4 + len(payload))
    sock.sendall(b"".join([tag, length_field, payload]))

def _recv_exact(sock, n):
    """Read exactly *n* bytes from *sock*.

    ``socket.recv(n)`` may legally return fewer than *n* bytes, so keep
    reading until the requested amount has arrived.

    Raises:
        ConnectionError: the peer closed the connection before *n* bytes arrived.
    """
    buf = bytearray()
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:  # b"" means the peer closed the connection
            raise ConnectionError(
                f"connection closed after {len(buf)} of {n} expected bytes"
            )
        buf += chunk
    return bytes(buf)

def recv_msg(sock):
    """Read one complete backend message and return ``(type_byte, payload)``.

    Wire layout: a 1-byte type, a big-endian uint32 length that counts
    itself but not the type byte, then ``length - 4`` payload bytes.

    Raises:
        ConnectionError: the connection closed in the middle of a message.
        ValueError: the length field is smaller than 4 (malformed message).
    """
    header = _recv_exact(sock, 5)
    msg_type = header[:1]
    (length,) = struct.unpack("!I", header[1:])
    if length < 4:
        raise ValueError(f"invalid message length {length} for type {msg_type!r}")
    return msg_type, _recv_exact(sock, length - 4)

def read_cstring(data, offset):
    """Decode the NUL-terminated string that starts at *offset* in *data*.

    Returns ``(text, next_offset)``, where *next_offset* is the index just
    past the terminating NUL. Raises ValueError if no NUL follows *offset*.
    """
    terminator = data.index(b"\x00", offset)
    text = data[offset:terminator].decode()
    return text, terminator + 1

def startup(sock):
    """Send the StartupMessage that opens a protocol-3.0 session.

    Unlike other frontend messages it carries no type byte. It consists of
    an Int32 total length (including itself), the Int32 protocol version
    196608 (3 << 16, i.e. 3.0), and NUL-terminated key/value strings closed
    by one extra NUL.
    """
    version = 3 << 16
    options = (("user", USER), ("database", DATABASE), ("client_encoding", "UTF8"))
    strings = [text.encode() for pair in options for text in pair]
    # Join with NULs, then terminate the last value and the whole list.
    body = b"\x00".join(strings) + b"\x00\x00"
    # 8 = length field itself + version field.
    sock.sendall(struct.pack("!II", 8 + len(body), version) + body)

def authenticate(sock):
    """Finish the authentication phase that follows :func:`startup`.

    Reads backend messages until ReadyForQuery ('Z'):

    * 'R' code 0 (AuthenticationOk) prints "Authentication OK".
    * 'R' code 3 (cleartext password request) is answered with a 'p'
      PasswordMessage carrying PASSWORD.
    * 'E' (ErrorResponse) aborts with a readable error.
    * Every other message type (e.g. 'K' BackendKeyData, 'S'
      ParameterStatus) is read and discarded.

    Raises:
        Exception: the server asked for an auth method other than 0 or 3.
        RuntimeError: the server rejected the session (ErrorResponse), e.g. a
            bad password or unknown database.
    """
    while True:
        msg_type, data = recv_msg(sock)
        if msg_type == b"R":
            (code,) = struct.unpack("!I", data[:4])
            if code == 0:
                print("Authentication OK")
            elif code == 3:
                # The password is sent as-is; it is only protected if the
                # connection itself is encrypted.
                send_msg(sock, "p", PASSWORD.encode() + b"\x00")
            else:
                raise Exception(f"Unsupported auth method: {code}")
        elif msg_type == b"E":
            # ErrorResponse body: repeated (1-byte field code, C string),
            # ending in a zero byte. Without this branch the loop would keep
            # reading after the server closes the connection and fail with an
            # unrelated low-level error instead of the server's message.
            fields = {
                part[:1].decode(errors="replace"): part[1:].decode(errors="replace")
                for part in data.split(b"\x00")
                if part
            }
            raise RuntimeError(
                f"Authentication failed: {fields.get('S', '')} "
                f"{fields.get('C', '')}: {fields.get('M', '')}"
            )
        elif msg_type == b"Z":
            break

def parse_row_description(payload):
    """Return the column names listed in a RowDescription ('T') payload.

    Layout: an Int16 field count, then for each field a NUL-terminated name
    followed by 18 bytes of fixed metadata. This client does not use the
    metadata: table OID, column number, type OID, type size, type modifier
    and format code.
    """
    per_field_meta = 4 + 2 + 4 + 2 + 4 + 2  # == 18 bytes, skipped
    (n_fields,) = struct.unpack("!H", payload[:2])
    names = []
    pos = 2
    while len(names) < n_fields:
        name, pos = read_cstring(payload, pos)
        pos += per_field_meta
        names.append(name)
    return names

def parse_data_row(payload):
    """Decode a DataRow ('D') payload into a list of str / None values.

    Layout: an Int16 column count, then per column a uint32 byte length
    (0xFFFFFFFF marks SQL NULL) followed by that many bytes, which are
    decoded as UTF-8 text.
    """
    NULL_MARKER = 0xFFFFFFFF
    count = struct.unpack_from("!H", payload, 0)[0]
    pos = 2
    values = []
    for _ in range(count):
        size = struct.unpack_from("!I", payload, pos)[0]
        pos += 4
        if size == NULL_MARKER:
            values.append(None)
            continue
        values.append(payload[pos:pos + size].decode())
        pos += size
    return values

def extended_query(sock, sql):
    """Run *sql* through the extended-query protocol and print each row as a dict.

    Sends Parse / Bind / Describe(portal) / Execute(max 7 rows) /
    Execute(max 1 row) / Sync as one pipeline, then reads responses until
    ReadyForQuery ('Z'). The two row-limited Executes deliberately run the
    same portal back to back: the first returns at most 7 rows, and the
    second resumes the same portal for at most 1 more.

    Args:
        sock: a connected socket that has completed startup/authentication.
        sql: statement text; no bind parameters are sent.
    """
    # Parse: unnamed statement, zero parameter type OIDs.
    send_msg(sock, "P", b"\x00" + sql.encode() + b"\x00" + struct.pack("!H", 0))

    # Bind: unnamed portal <- unnamed statement; zero parameter format codes,
    # zero parameter values, zero result format codes (all columns as text).
    send_msg(sock, "B", b"\x00\x00" + struct.pack("!HHH", 0, 0, 0))

    # Describe the unnamed portal ('P' + empty name) to get a RowDescription.
    send_msg(sock, "D", b"P\x00")

    # Execute the unnamed portal. A non-zero limit means "at most this many
    # rows"; 0 would mean "fetch all".
    send_msg(sock, "E", b"\x00" + struct.pack("!I", 7))
    print("back-to-back execute ......")
    send_msg(sock, "E", b"\x00" + struct.pack("!I", 1))

    # Sync closes the pipeline; the server always replies with ReadyForQuery.
    send_msg(sock, "S", b"")

    columns = []
    rows = []
    while True:
        msg_type, data = recv_msg(sock)
        if msg_type == b"T":  # RowDescription
            columns = parse_row_description(data)
        elif msg_type == b"D":  # DataRow
            rows.append(parse_data_row(data))
        elif msg_type == b"E":  # ErrorResponse
            # After an error the server skips messages up to Sync and then
            # still sends ReadyForQuery. Keep reading so that 'Z' is consumed
            # and the connection stays usable.
            print("Error:", data)
        elif msg_type == b"Z":  # ReadyForQuery
            break
        # ParseComplete '1', BindComplete '2', CommandComplete 'C', and any
        # other message type need no action here.

    for row in rows:
        print(dict(zip(columns, row)))

def main():
    """Connect, authenticate, run the demo query, then close gracefully.

    Sends a Terminate message ('X', empty body) before the socket is closed,
    which is the protocol's documented way to end a session, rather than
    just dropping the connection.
    """
    with socket.create_connection((HOST, PORT)) as sock:
        startup(sock)
        authenticate(sock)
        extended_query(sock, "SELECT * from foo;")
        send_msg(sock, "X", b"")  # Terminate

# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

