"""
Send a raw PostgreSQL startup packet with _pq_.command_tag_format.
Then execute INSERT and check the CommandComplete message.
"""
import socket
import struct
import sys

def make_startup_packet(user, database, extra_params=None):
    """Return a protocol 3.0 StartupMessage as bytes.

    The body is a run of NUL-terminated name/value pairs: ``user`` and
    ``database`` first, then every entry of *extra_params* (for example
    ``_pq_.*`` protocol options) in iteration order. A single extra NUL byte
    closes the list. The leading Int32 length counts itself, the version
    word, and the body.
    """
    pairs = [("user", user), ("database", database)]
    if extra_params:
        pairs.extend(extra_params.items())

    body = b"".join(
        name.encode() + b"\x00" + value.encode() + b"\x00"
        for name, value in pairs
    ) + b"\x00"

    protocol_version = 3 << 16  # major 3, minor 0 (== 196608)
    # 8 = the 4-byte length word itself + the 4-byte version word.
    return struct.pack("!II", 8 + len(body), protocol_version) + body

def _recv_exact(sock, n):
    """Read exactly *n* bytes from *sock*, or fewer only if the peer closes.

    ``socket.recv`` may legally return fewer bytes than requested, so one
    call is not enough to read a fixed-size field such as the length word.
    """
    buf = bytearray()
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            break  # peer closed; caller decides whether that is an error
        buf += chunk
    return bytes(buf)


def read_message(sock):
    """Read one PG protocol message. Returns (type_byte, payload).

    *type_byte* is the one-character message type as a ``str``. *payload*
    excludes the type byte and the Int32 length word. Returns
    ``(None, None)`` when the peer closed the connection cleanly between
    messages.

    Raises:
        ConnectionError: the connection closed in the middle of a message.
        ValueError: the length word is smaller than 4, which is a malformed
            frame because the length counts itself.
    """
    hdr = sock.recv(1)
    if not hdr:
        return None, None
    msg_type = hdr.decode("ascii")

    length_bytes = _recv_exact(sock, 4)
    if len(length_bytes) < 4:
        raise ConnectionError(
            f"connection closed while reading length of {msg_type!r} message"
        )
    (length,) = struct.unpack("!I", length_bytes)
    if length < 4:
        raise ValueError(f"invalid length {length} for {msg_type!r} message")

    remaining = length - 4  # the length word counts itself
    payload = _recv_exact(sock, remaining)
    if len(payload) < remaining:
        raise ConnectionError(
            f"connection closed after {len(payload)} of {remaining} payload "
            f"bytes of {msg_type!r} message"
        )
    return msg_type, payload

def send_query(sock, sql):
    """Transmit *sql* as a simple-protocol Query ('Q') message.

    Frame layout: the byte ``Q``, then an Int32 length that counts itself
    plus the payload, then the query text encoded with ``str.encode()``'s
    default UTF-8 and terminated by a NUL byte. Sent with a single
    ``sendall`` call.
    """
    text = sql.encode()
    frame = bytearray(b"Q")
    frame += struct.pack("!I", 4 + len(text) + 1)  # +1 for the trailing NUL
    frame += text
    frame += b"\x00"
    sock.sendall(bytes(frame))

def _describe_error(payload):
    """Render an ErrorResponse body as ``"SEVERITY SQLSTATE: message"``.

    The body is a sequence of fields. Each field is a one-byte field code
    followed by a NUL-terminated string, and a lone NUL ends the list.
    Falls back to the raw decoded payload if no 'M' (message) field exists.
    """
    fields = {}
    for raw in payload.split(b"\x00"):
        if raw:
            fields[chr(raw[0])] = raw[1:].decode("utf-8", errors="replace")
    severity = fields.get("S", "ERROR")
    sqlstate = fields.get("C", "?????")
    message = fields.get("M", payload.decode("utf-8", errors="replace"))
    return f"{severity} {sqlstate}: {message}"


def _read_until_ready(sock):
    """Consume the startup/authentication exchange up to ReadyForQuery.

    Returns True once a 'Z' message arrives. Returns False (after printing
    why) if any of these happens first:
      * the server closes the connection;
      * the server sends an ErrorResponse;
      * the server requests an authentication method this script cannot
        answer.
    """
    while True:
        msg_type, payload = read_message(sock)
        if msg_type is None:
            print("Connection closed!")
            return False
        if msg_type == "E":
            print(f"  ERROR: {_describe_error(payload)}")
            return False
        if msg_type == "R":
            auth_type = struct.unpack("!I", payload[:4])[0]
            if auth_type != 0:
                # Only AuthenticationOk (0) is supported. For any other
                # request the server waits for a reply we never send, so
                # bail out instead of blocking forever.
                print(f"  Unsupported authentication request (type {auth_type})")
                return False
        elif msg_type == "v":
            # NegotiateProtocolVersion: the server lists startup options
            # (e.g. _pq_.*) it did not recognize.
            newest_minor, count = struct.unpack("!II", payload[:8])
            names = [n.decode() for n in payload[8:].split(b"\x00")[:count]]
            print(
                f"  NegotiateProtocolVersion: newest minor={newest_minor}, "
                f"unrecognized={names}"
            )
        elif msg_type == "S":
            # ParameterStatus body is "name\0value\0". partition keeps an
            # empty value instead of dropping it.
            name, _, rest = payload.partition(b"\x00")
            value = rest.split(b"\x00", 1)[0]
            if name == b"command_tag_format":
                print(f"  GUC_REPORT: command_tag_format = {value.decode()}")
        elif msg_type == "Z":
            return True


def _run_query(sock, sql):
    """Send *sql* as a simple query and drain replies through ReadyForQuery.

    Prints every ErrorResponse. Returns the list of CommandComplete tags, or
    None if the connection closed before ReadyForQuery arrived (the original
    loops would spin forever in that case).
    """
    send_query(sock, sql)
    tags = []
    while True:
        msg_type, payload = read_message(sock)
        if msg_type is None:
            print("Connection closed!")
            return None
        if msg_type == "C":
            tags.append(payload.rstrip(b"\x00").decode())
        elif msg_type == "E":
            print(f"  ERROR: {_describe_error(payload)}")
        elif msg_type == "Z":
            return tags


def run_test(
    format_value=None,
    *,
    socket_path="/tmp/.s.PGSQL.5432",
    user="postgres",
    database="postgres",
):
    """Connect with an optional ``_pq_.command_tag_format`` and report INSERT's tag.

    Args:
        format_value: Value sent as the ``_pq_.command_tag_format`` startup
            option. A falsy value omits the option entirely.
        socket_path: Unix-domain socket of the server.
        user: Role sent in the startup packet.
        database: Database sent in the startup packet.

    Prints the GUC report, any errors, and the INSERT's CommandComplete
    tag(s). The socket is always closed, including on early return or
    exception.
    """
    extra = {}
    if format_value:
        extra["_pq_.command_tag_format"] = format_value

    label = format_value or "default"
    print(f"\n=== Testing with _pq_.command_tag_format={label} ===")

    with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
        sock.connect(socket_path)
        sock.sendall(make_startup_packet(user, database, extra or None))

        if not _read_until_ready(sock):
            return

        if _run_query(sock, "CREATE TABLE IF NOT EXISTS proto_test(id int)") is None:
            return

        tags = _run_query(sock, "INSERT INTO proto_test VALUES (777)")
        if tags is None:
            return
        for tag in tags:
            print(f"  CommandComplete: {tag}")

        # Cleanup
        if _run_query(sock, "DELETE FROM proto_test WHERE id = 777") is None:
            return

        # Terminate: 'X' with a bare 4-byte length.
        sock.sendall(b"X\x00\x00\x00\x04")

# Run tests only when executed as a script, not on import.
if __name__ == "__main__":
    for value in (
        None,       # No _pq_ param → legacy
        "verbose",  # _pq_ verbose → INSERT tablename N
        "fqn",      # _pq_ fqn → INSERT schema.tablename N
        "legacy",   # _pq_ legacy → INSERT 0 N
        "modern",   # _pq_ modern → should fail (removed)
    ):
        run_test(value)

