#!/usr/bin/env python3
"""
Standalone PostgreSQL Index Prefetch Benchmark Tool

Benchmarks different indexscan_prefetch_distance values and stores results
in a PostgreSQL table for analysis.

Usage:
    # Run benchmark and store in database
    python prefetch_benchmark.py --table prefetch_size_medium_50mb --runs 2
    
    # Run with specific distances
    python prefetch_benchmark.py --table prefetch_size_xlarge_500mb --distances "1,4,16,64,0"
    
    # Query stored results
    python prefetch_benchmark.py --query
    
    # Export results to CSV
    python prefetch_benchmark.py --export results.csv

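Requirements:
    pip install psycopg[binary]
    The pg_buffercache extension in the target database (its
    pg_buffercache_evict() function evicts the table before cold-cache runs).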
"""

import argparse
import json
import os
import sys
import threading
import time
from collections import defaultdict
from datetime import datetime

try:
    import psycopg
except ImportError:
    print("Error: psycopg not found. Install with: pip install psycopg[binary]")
    sys.exit(1)

# Default configuration
DEFAULT_CONN = "dbname=postgres host=/tmp"

# Prefetch distances swept by default. Same encoding as --distances:
#   -1        -> prefetching turned off
#    0        -> prefetching on, no distance cap ("unlimited")
#   positive  -> prefetching on, capped at that distance
# The positive caps are the powers of two from 1 through 512.
DEFAULT_DISTANCES = [-1, 0] + [2 ** exp for exp in range(10)]

# Tables in the target database that receive benchmark output.
RESULTS_TABLE = "prefetch_benchmark_results"
WAIT_SAMPLES_TABLE = "prefetch_benchmark_waits"


def distance_label(distance):
    """Return the display label for a prefetch distance setting.

    Negative values mean prefetching is disabled ("off"), zero means no
    distance cap ("unlimited"), and any positive value is shown as its
    string form.
    """
    if distance < 0:
        return "off"
    return "unlimited" if distance == 0 else str(distance)


def get_connection(conn_str):
    """Open a new psycopg connection using the libpq string *conn_str*.

    Each call opens a fresh connection; the caller must close it.
    """
    connection = psycopg.connect(conn_str)
    return connection


def init_results_tables(conn_str):
    """Create the results tables and their indexes if they do not exist yet.

    Every statement uses IF NOT EXISTS, so this is safe to call before each
    benchmark run. The connection is closed even if a statement fails.
    """
    index_statements = [
        f"CREATE INDEX IF NOT EXISTS idx_results_run_id ON {RESULTS_TABLE}(run_id)",
        f"CREATE INDEX IF NOT EXISTS idx_results_table ON {RESULTS_TABLE}(test_table)",
        f"CREATE INDEX IF NOT EXISTS idx_results_distance ON {RESULTS_TABLE}(prefetch_distance)",
    ]

    # Leaving a psycopg 3 connection's ``with`` block closes the connection.
    with get_connection(conn_str) as conn:
        conn.autocommit = True
        cur = conn.cursor()

        # Main results table: one row per benchmark run.
        cur.execute(f"""
            CREATE TABLE IF NOT EXISTS {RESULTS_TABLE} (
                id SERIAL PRIMARY KEY,
                run_id TEXT NOT NULL,
                run_timestamp TIMESTAMPTZ DEFAULT NOW(),
                test_table TEXT NOT NULL,
                table_size_bytes BIGINT,
                prefetch_distance INT NOT NULL,
                io_method TEXT,
                io_workers INT,
                shared_buffers TEXT,
                effective_io_concurrency INT,
                cold_cache BOOLEAN DEFAULT TRUE,
                elapsed_seconds FLOAT NOT NULL,
                aio_wait_pct FLOAT,
                total_samples INT,
                notes TEXT
            )
        """)

        # Wait event samples, removed together with their parent result row.
        cur.execute(f"""
            CREATE TABLE IF NOT EXISTS {WAIT_SAMPLES_TABLE} (
                id SERIAL PRIMARY KEY,
                result_id INT REFERENCES {RESULTS_TABLE}(id) ON DELETE CASCADE,
                wait_event_type TEXT,
                wait_event TEXT,
                sample_count INT
            )
        """)

        # Indexes for common queries. Use one statement per execute():
        # psycopg only allows several semicolon-separated statements in one
        # call when no parameters are passed.
        for statement in index_statements:
            cur.execute(statement)

    print(f"Initialized tables: {RESULTS_TABLE}, {WAIT_SAMPLES_TABLE}")


def get_pg_settings(conn_str):
    """Return the server settings recorded alongside each benchmark result.

    Returns:
        dict mapping each parameter name below to its ``SHOW`` output (a
        string), or to ``None`` if the server rejects the ``SHOW`` (for
        example, a parameter this server version does not know).
    """
    params = ['io_method', 'shared_buffers', 'io_workers',
              'effective_io_concurrency', 'max_parallel_workers_per_gather']
    settings = {}

    with get_connection(conn_str) as conn:
        # Autocommit is essential here. Inside an implicit transaction, one
        # failing SHOW would abort it, and every later SHOW would then fail
        # too and be recorded as None.
        conn.autocommit = True
        cur = conn.cursor()
        for param in params:
            try:
                # ``param`` comes from the fixed list above, never from user input.
                cur.execute(f"SHOW {param}")
                settings[param] = cur.fetchone()[0]
            except psycopg.Error:
                settings[param] = None

    return settings


def get_table_size(conn_str, table_name):
    """Return ``pg_relation_size`` of *table_name* in bytes.

    Any failure of the lookup query (for example, no such relation) is
    swallowed, and ``None`` is returned instead.
    """
    conn = get_connection(conn_str)
    cur = conn.cursor()

    try:
        cur.execute("SELECT pg_relation_size(%s)", (table_name,))
        (size_bytes,) = cur.fetchone()
    except Exception:
        size_bytes = None

    conn.close()
    return size_bytes


def purge_os_cache():
    """Try to drop the OS filesystem cache by running ``sudo /usr/sbin/purge``.

    ``purge`` is a macOS utility and needs sudo access. Returns True when
    the command exits successfully and False on any failure.
    """
    import subprocess

    command = ['sudo', '/usr/sbin/purge']
    try:
        subprocess.run(command, check=True, capture_output=True)
    except Exception:
        return False
    return True


def evict_table_buffers(conn_str, table_name, purge_os=False):
    """Evict *table_name*'s pages from shared_buffers before a cold-cache run.

    Requires the pg_buffercache extension (``pg_buffercache_evict``). Only
    buffers of relations named *table_name* in the current database are
    targeted. The table's indexes have different names, so they are not
    evicted.

    Args:
        conn_str: libpq connection string.
        table_name: matched against ``pg_class.relname`` without a schema,
            so same-named tables in other schemas match as well.
        purge_os: after eviction, also try ``purge_os_cache()``. A failed
            purge is reported on stderr rather than ignored.

    Returns:
        Number of matching buffers passed to ``pg_buffercache_evict``
        (not necessarily the number actually evicted).
    """
    with get_connection(conn_str) as conn:
        conn.autocommit = True
        cur = conn.cursor()
        # pg_buffercache covers the whole cluster. Restrict to this database
        # so a relation elsewhere that shares the relfilenode is untouched.
        cur.execute("""
            SELECT count(pg_buffercache_evict(b.bufferid))
            FROM pg_buffercache b
            JOIN pg_class c ON b.relfilenode = pg_relation_filenode(c.oid)
            WHERE c.relname = %s
              AND b.reldatabase = (SELECT oid FROM pg_database
                                   WHERE datname = current_database())
        """, (table_name,))
        evicted = cur.fetchone()[0]

    # Purge only after eviction. Evicting a dirty buffer writes it back
    # through the kernel, which would otherwise refill the page cache we
    # just purged.
    if purge_os and not purge_os_cache():
        print("Warning: OS cache purge failed (is 'sudo /usr/sbin/purge' available?)",
              file=sys.stderr)

    return evicted


def run_single_benchmark(conn_str, table_name, distance, cold_cache=True, purge_os=False):
    """Time one execution of the ordered index-scan query at one prefetch distance.

    Args:
        conn_str: libpq connection string.
        table_name: table to scan. It is interpolated into the SQL verbatim,
            so it must come from a trusted source.
        distance: -1 disables prefetching, 0 enables it without a cap, and
            a positive value sets ``indexscan_prefetch_distance``.
        cold_cache: evict the table's buffers before running.
        purge_os: also purge the OS cache (only used when cold_cache is set).

    Returns:
        dict with keys:
        - 'elapsed': query plus fetch time in seconds.
        - 'success': whether the run completed.
        - 'waits': stays {} in this function.
        - 'samples': stays 0 in this function.
        - 'error': message string, or None.
        Failures, including cache eviction failures, are reported through
        'success'/'error' and are not raised.
    """
    result = {
        'elapsed': 0,
        'success': False,
        'waits': {},
        'samples': 0,
        'error': None
    }

    try:
        # Evict inside the try. A failure (e.g. pg_buffercache missing) is
        # then reported like any other run error, and the query never runs
        # against a cache that was not actually cleared.
        if cold_cache:
            evict_table_buffers(conn_str, table_name, purge_os=purge_os)

        # The ``with`` block closes the connection even if a statement fails.
        with get_connection(conn_str) as conn:
            conn.autocommit = True
            cur = conn.cursor()

            # Steer the planner toward a serial, plain index scan so that
            # prefetching is the variable being measured.
            cur.execute("SET max_parallel_workers_per_gather = 0")
            cur.execute("SET enable_bitmapscan = off")
            cur.execute("SET enable_seqscan = off")
            cur.execute("SET enable_indexonlyscan = off")

            # Configure prefetching: -1 = off, 0 = unlimited, positive = cap
            if distance < 0:
                cur.execute("SET enable_indexscan_prefetch = off")
            else:
                cur.execute("SET enable_indexscan_prefetch = on")
                cur.execute(f"SET indexscan_prefetch_distance = {distance}")

            # perf_counter is monotonic and high-resolution; time.time() can
            # jump when the wall clock is adjusted mid-run.
            start = time.perf_counter()
            cur.execute(f"SELECT length(payload) FROM {table_name} ORDER BY sequential")
            cur.fetchall()
            result['elapsed'] = time.perf_counter() - start

        result['success'] = True

    except Exception as e:
        result['error'] = str(e)

    return result


def store_result(conn_str, run_id, table_name, distance, result, settings, table_size, cold_cache):
    """Insert one benchmark run and its wait-event samples; return the new row id.

    Args:
        conn_str: libpq connection string.
        run_id: identifier shared by all runs of one benchmark invocation.
        table_name: benchmarked table.
        distance: prefetch distance used (-1 off, 0 unlimited, N cap).
        result: dict from ``run_single_benchmark``. Uses 'elapsed', 'waits'
            ({(wait_event_type, wait_event): count}) and 'samples'.
        settings: dict from ``get_pg_settings`` (string values or None).
        table_size: table size in bytes, or None.
        cold_cache: whether the run started from a cold cache.

    The result row and its wait rows are written in one transaction, so a
    run is never stored without its samples.
    """
    def _as_int(value):
        # SHOW returns strings; store numeric settings as INT, NULL if unknown/empty.
        return int(value) if value else None

    with get_connection(conn_str) as conn:
        conn.autocommit = True
        cur = conn.cursor()

        with conn.transaction():
            cur.execute(f"""
                INSERT INTO {RESULTS_TABLE}
                (run_id, test_table, table_size_bytes, prefetch_distance,
                 io_method, io_workers, shared_buffers, effective_io_concurrency,
                 cold_cache, elapsed_seconds, total_samples)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                RETURNING id
            """, (
                run_id, table_name, table_size, distance,
                settings.get('io_method'),
                _as_int(settings.get('io_workers')),
                settings.get('shared_buffers'),
                _as_int(settings.get('effective_io_concurrency')),
                cold_cache, result['elapsed'], result.get('samples'),
            ))
            result_id = cur.fetchone()[0]

            wait_rows = [
                (result_id, wait_type, wait_event, count)
                for (wait_type, wait_event), count in result['waits'].items()
            ]
            if wait_rows:
                cur.executemany(f"""
                    INSERT INTO {WAIT_SAMPLES_TABLE}
                    (result_id, wait_event_type, wait_event, sample_count)
                    VALUES (%s, %s, %s, %s)
                """, wait_rows)

    return result_id


def run_benchmark(conn_str, table_name, distances, runs=2, cold_cache=True, purge_os=False):
    """Benchmark each prefetch distance ``runs`` times and summarise the timings.

    Every successful run is persisted with ``store_result`` under a run id
    taken from the local start time (``YYYYmmdd_HHMMSS``). Failed runs are
    reported and skipped. A distance with no successful run is left out of
    the summary.

    Returns:
        ``(run_id, summaries)``. Each summary dict holds 'distance',
        'avg_time', 'min_time', 'max_time', 'waits' (counts summed over
        runs) and 'total_samples'.
    """
    run_id = datetime.now().strftime('%Y%m%d_%H%M%S')
    settings = get_pg_settings(conn_str)
    table_size = get_table_size(conn_str, table_name)

    banner = '=' * 70
    if table_size:
        table_desc = f"Table: {table_name} ({table_size / 1024 / 1024:.1f} MB)"
    else:
        table_desc = f"Table: {table_name}"

    for line in (
        "\n" + banner,
        "PostgreSQL Index Prefetch Benchmark",
        banner,
        f"Run ID: {run_id}",
        table_desc,
        f"Settings: io_method={settings.get('io_method')}, io_workers={settings.get('io_workers')}",
        f"Distances: {distances}",
        f"Runs per distance: {runs}",
        f"Cold cache: {cold_cache}, OS purge: {purge_os}",
        banner,
    ):
        print(line)

    summaries = []

    for distance in distances:
        print(f"\nTesting prefetch={distance_label(distance)}...")

        elapsed_times = []
        wait_totals = defaultdict(int)
        sample_total = 0

        for attempt in range(1, runs + 1):
            print(f"  Run {attempt}/{runs}...", end=" ", flush=True)

            outcome = run_single_benchmark(conn_str, table_name, distance, cold_cache, purge_os)

            if not outcome['success']:
                print(f"ERROR: {outcome['error']}")
                continue

            elapsed_times.append(outcome['elapsed'])
            for wait_key, count in outcome.get('waits', {}).items():
                wait_totals[wait_key] += count
            sample_total += outcome.get('samples', 0)
            store_result(conn_str, run_id, table_name, distance,
                         outcome, settings, table_size, cold_cache)
            print(f"done ({outcome['elapsed']:.2f}s)")

        if not elapsed_times:
            continue

        summaries.append({
            'distance': distance,
            'avg_time': sum(elapsed_times) / len(elapsed_times),
            'min_time': min(elapsed_times),
            'max_time': max(elapsed_times),
            'waits': dict(wait_totals),
            'total_samples': sample_total
        })

    return run_id, summaries


def print_results_table(results):
    """Print a per-distance timing summary and, where sampled, a wait-event breakdown.

    Args:
        results: summaries as produced by ``run_benchmark``. Each is a dict
            with 'distance', 'avg_time', 'min_time', 'max_time', 'waits'
            ({(wait_event_type, wait_event): count}) and 'total_samples'.
            An empty list prints nothing.
    """
    if not results:
        return

    rule = '=' * 70
    print(f"\n{rule}")
    print("RESULTS SUMMARY")
    print(rule)
    print(f"{'Prefetch':<12} {'Avg Time':<12} {'Min Time':<12} {'Max Time':<12}")
    print(f"{'-'*48}")

    for r in results:
        # Format each time with its unit first, then pad, so the columns
        # line up with the header whatever the magnitude.
        avg_s = f"{r['avg_time']:.2f}s"
        min_s = f"{r['min_time']:.2f}s"
        max_s = f"{r['max_time']:.2f}s"
        print(f"{distance_label(r['distance']):<12} {avg_s:<12} {min_s:<12} {max_s}")

    best = min(results, key=lambda x: x['avg_time'])
    print(f"\nOPTIMAL: prefetch={distance_label(best['distance'])} ({best['avg_time']:.2f}s)")

    # Wait event breakdown for each configuration that has samples.
    print(f"\n{rule}")
    print("WAIT EVENT BREAKDOWN (estimated time)")
    print(rule)

    for r in results:
        total = r.get('total_samples', 0)
        if not total:
            continue  # no sampling data for this distance

        elapsed = r.get('avg_time', 0)
        print(f"\nprefetch={distance_label(r['distance'])} (elapsed: {elapsed:.2f}s):")

        # Most frequent first. The sort is stable, so ties keep insertion order.
        ranked = sorted(r.get('waits', {}).items(), key=lambda item: item[1], reverse=True)

        for (wtype, wevent), count in ranked[:10]:  # Top 10
            pct = count / total * 100
            # Attribute average elapsed time in proportion to the sample share.
            est_time = elapsed * pct / 100
            bar = '#' * int(pct / 2)
            label = f"{wtype}/{wevent}" if wtype != wevent else wtype
            print(f"  {label:<30} {est_time:5.2f}s ({pct:4.1f}%) |{bar}")


def query_results(conn_str, run_id=None, table_name=None, limit=50):
    """Print stored results aggregated per run, table, distance and I/O settings.

    Args:
        conn_str: libpq connection string.
        run_id: if given, only show this run.
        table_name: if given, only show this test table.
        limit: maximum number of aggregated rows to print.
    """
    query = f"""
        SELECT run_id, test_table, prefetch_distance, io_method, io_workers,
               AVG(elapsed_seconds) as avg_time,
               MIN(elapsed_seconds) as min_time,
               MAX(elapsed_seconds) as max_time,
               AVG(aio_wait_pct) as avg_aio_pct,
               COUNT(*) as runs
        FROM {RESULTS_TABLE}
        WHERE 1=1
    """
    params = []

    if run_id:
        query += " AND run_id = %s"
        params.append(run_id)
    if table_name:
        query += " AND test_table = %s"
        params.append(table_name)

    query += """
        GROUP BY run_id, test_table, prefetch_distance, io_method, io_workers
        ORDER BY run_id DESC, prefetch_distance
        LIMIT %s
    """
    params.append(limit)

    with get_connection(conn_str) as conn:
        cur = conn.cursor()
        cur.execute(query, params)
        rows = cur.fetchall()

    if not rows:
        print("No results found.")
        return

    print(f"\n{'='*100}")
    print("STORED BENCHMARK RESULTS")
    print(f"{'='*100}")
    print(f"{'Run ID':<18} {'Table':<25} {'Prefetch':<10} {'io_method':<10} {'workers':<8} "
          f"{'Avg(s)':<8} {'AIO%':<8} {'Runs':<5}")
    print(f"{'-'*100}")

    for row in rows:
        row_run_id, table, dist, io_method, workers, avg_t, _min_t, _max_t, aio_pct, runs = row
        # aio_wait_pct is a nullable column, and AVG over only NULLs is NULL.
        # Formatting None with '.1f' raises TypeError, so show N/A instead.
        aio_col = f"{aio_pct:.1f}%" if aio_pct is not None else "N/A"
        # Only a missing value is N/A; a stored 0 is shown as 0.
        workers_col = workers if workers is not None else "N/A"
        print(f"{row_run_id:<18} {table[:24]:<25} {distance_label(dist):<10} {io_method or 'N/A':<10} "
              f"{workers_col:<8} {avg_t:<8.2f} {aio_col:<8} {runs}")


def export_results(conn_str, output_file, run_id=None):
    """Write stored result rows (one per run) to *output_file* as CSV.

    The first line is the header of column names. NULL values become empty
    fields. Fields containing commas, quotes or newlines are quoted by the
    csv module instead of corrupting the row layout.

    Args:
        conn_str: libpq connection string.
        output_file: path of the CSV file to create or overwrite.
        run_id: if given, only export rows from this run.
    """
    import csv

    query = f"""
        SELECT r.run_id, r.run_timestamp, r.test_table, r.table_size_bytes,
               r.prefetch_distance, r.io_method, r.io_workers, r.shared_buffers,
               r.effective_io_concurrency, r.cold_cache, r.elapsed_seconds,
               r.aio_wait_pct, r.total_samples
        FROM {RESULTS_TABLE} r
    """
    params = []

    if run_id:
        query += " WHERE r.run_id = %s"
        params.append(run_id)

    query += " ORDER BY r.run_id, r.prefetch_distance"

    with get_connection(conn_str) as conn:
        cur = conn.cursor()
        cur.execute(query, params)
        rows = cur.fetchall()
        columns = [desc[0] for desc in cur.description]

    # newline='' as the csv docs require; lineterminator='\n' keeps the
    # previous Unix line endings (csv defaults to '\r\n').
    with open(output_file, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f, lineterminator='\n')
        writer.writerow(columns)
        writer.writerows(rows)

    print(f"Exported {len(rows)} results to {output_file}")


def export_json(conn_str, output_file, run_id=None):
    """Write stored results, each with its wait-event samples, to a JSON file.

    Each array element holds every column of one results row, plus
    ``wait_events``. That is a list of {wait_type, wait_event, count}
    objects, or null when the run has no samples. ``run_timestamp`` is
    written as ``str()`` of the timestamp.

    Args:
        conn_str: libpq connection string.
        output_file: path of the JSON file to create or overwrite.
        run_id: if given, only export rows from this run.
    """
    query = f"""
        SELECT r.*,
               json_agg(json_build_object(
                   'wait_type', w.wait_event_type,
                   'wait_event', w.wait_event,
                   'count', w.sample_count
               )) FILTER (WHERE w.id IS NOT NULL) as wait_events
        FROM {RESULTS_TABLE} r
        LEFT JOIN {WAIT_SAMPLES_TABLE} w ON r.id = w.result_id
    """
    params = []

    if run_id:
        query += " WHERE r.run_id = %s"
        params.append(run_id)

    # Grouping by the primary key alone is enough to select r.*.
    query += " GROUP BY r.id ORDER BY r.run_id, r.prefetch_distance"

    # The ``with`` block closes the connection even if the query fails.
    with get_connection(conn_str) as conn:
        cur = conn.cursor()
        cur.execute(query, params)
        rows = cur.fetchall()
        columns = [desc[0] for desc in cur.description]

    results = []
    for row in rows:
        result = dict(zip(columns, row))
        # datetime is not JSON-serializable; store its string form.
        if result.get('run_timestamp') is not None:
            result['run_timestamp'] = str(result['run_timestamp'])
        results.append(result)

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(f"Exported {len(results)} results to {output_file}")


def list_tables(conn_str):
    """Print the ordinary tables whose names start with 'prefetch', largest first.

    The tool's own results tables (RESULTS_TABLE and WAIT_SAMPLES_TABLE) also
    start with 'prefetch'. They are excluded because they are not benchmark
    targets.
    """
    with get_connection(conn_str) as conn:
        cur = conn.cursor()
        # Pass the LIKE pattern as a parameter. Once parameters are present,
        # a literal '%' in the SQL text would have to be written as '%%'.
        cur.execute("""
            SELECT relname, pg_size_pretty(pg_relation_size(oid)) as size,
                   pg_relation_size(oid) as size_bytes
            FROM pg_class
            WHERE relname LIKE %s AND relkind = 'r'
              AND relname NOT IN (%s, %s)
            ORDER BY pg_relation_size(oid) DESC
        """, ('prefetch%', RESULTS_TABLE, WAIT_SAMPLES_TABLE))
        rows = cur.fetchall()

    if not rows:
        print("No prefetch test tables found.")
        return

    print(f"\n{'='*50}")
    print("Available Test Tables")
    print(f"{'='*50}")
    print(f"{'Table Name':<30} {'Size':<15}")
    print(f"{'-'*45}")

    for name, size, _ in rows:
        print(f"{name:<30} {size:<15}")


def create_test_table(conn_str, table_name, rows=100000, payload_size=500):
    """Create and fill a benchmark table unless a relation with that name exists.

    The table gets ``rows`` rows. ``sequential`` runs from 1 to ``rows`` and
    ``payload`` is ``payload_size`` copies of 'x'. An index on
    ``sequential`` is created and the table is ANALYZEd.

    The whole build runs in one transaction. An interrupted or failed build
    therefore leaves no half-filled table that a later call would report as
    "already exists".

    Args:
        conn_str: libpq connection string.
        table_name: name of the table to create. It is interpolated into
            SQL verbatim, so it must come from a trusted source.
        rows: number of rows to insert.
        payload_size: length of each payload string in characters.
    """
    with get_connection(conn_str) as conn:
        conn.autocommit = True
        cur = conn.cursor()

        # Any pg_class entry with this name (any schema or kind) counts as existing.
        cur.execute("SELECT EXISTS (SELECT 1 FROM pg_class WHERE relname = %s)", (table_name,))
        if cur.fetchone()[0]:
            print(f"Table {table_name} already exists.")
            return

        print(f"Creating table {table_name} with {rows} rows, {payload_size}B payload...")

        with conn.transaction():
            cur.execute(f"""
                CREATE TABLE {table_name} (
                    id SERIAL PRIMARY KEY,
                    sequential INT NOT NULL,
                    payload TEXT NOT NULL
                )
            """)

            # Insert data in batches
            batch_size = 10000
            for i in range(0, rows, batch_size):
                batch_end = min(i + batch_size, rows)
                cur.execute(f"""
                    INSERT INTO {table_name} (sequential, payload)
                    SELECT g, repeat('x', {payload_size})
                    FROM generate_series({i + 1}, {batch_end}) g
                """)
                print(f"  Inserted {batch_end}/{rows} rows...")

            cur.execute(f"CREATE INDEX ON {table_name} (sequential)")
            cur.execute(f"ANALYZE {table_name}")

    print(f"Created table {table_name}")


def main():
    """Parse command-line arguments and run the requested action.

    Utility actions are checked in this order, and the first one given runs
    and returns: --list-tables, --query, --export, --export-json. Otherwise
    a benchmark runs against --table (required), its summary is printed,
    and the run id is shown for later queries.

    Invalid --runs or --distances values are rejected with
    ``parser.error`` (exit status 2) before anything is written to the
    database.
    """
    parser = argparse.ArgumentParser(
        description='PostgreSQL Index Prefetch Benchmark Tool',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Run benchmark (includes prefetch off, unlimited, and various caps)
  python prefetch_benchmark.py --table prefetch_size_medium_50mb --runs 2

  # Run with specific settings (-1=off, 0=unlimited, positive=cap).
  # Use the --distances=VALUE form: a separate value starting with "-"
  # would be mistaken for an option.
  python prefetch_benchmark.py --table prefetch_size_xlarge_500mb --distances=-1,0,16,64,128

  # Query stored results
  python prefetch_benchmark.py --query

  # Query specific run
  python prefetch_benchmark.py --query --run-id 20240101_120000

  # Export to CSV
  python prefetch_benchmark.py --export results.csv

  # Export to JSON
  python prefetch_benchmark.py --export-json results.json

  # List available test tables
  python prefetch_benchmark.py --list-tables

  # Create a test table
  python prefetch_benchmark.py --create-table my_test --rows 50000 --payload-size 1000
        """
    )

    # Connection
    parser.add_argument('--conn', default=DEFAULT_CONN,
                        help=f'Connection string (default: {DEFAULT_CONN})')

    # Benchmark options
    parser.add_argument('--table', help='Table name to benchmark')
    parser.add_argument('--distances', type=str,
                        help='Comma-separated list of distances (-1=off, 0=unlimited, N=cap); '
                             'write --distances=-1,... when the list starts with a negative value')
    parser.add_argument('--runs', type=int, default=2,
                        help='Number of runs per distance (default: 2)')
    parser.add_argument('--warm-cache', action='store_true',
                        help='Run with warm cache (default: cold cache)')
    parser.add_argument('--purge-os', action='store_true',
                        help='Also purge OS filesystem cache (requires sudo purge)')

    # Query/Export options
    parser.add_argument('--query', action='store_true',
                        help='Query and display stored results')
    parser.add_argument('--run-id', help='Filter by specific run ID')
    parser.add_argument('--export', metavar='FILE',
                        help='Export results to CSV file')
    parser.add_argument('--export-json', metavar='FILE',
                        help='Export results to JSON file')

    # Utility options
    parser.add_argument('--list-tables', action='store_true',
                        help='List available test tables')
    parser.add_argument('--ensure-table', action='store_true',
                        help='Create test table if missing before running benchmark')
    parser.add_argument('--rows', type=int, default=100000,
                        help='Rows for test table (default: 100000)')
    parser.add_argument('--payload-size', type=int, default=500,
                        help='Payload size for test table (default: 500)')

    args = parser.parse_args()

    # Handle utility commands
    if args.list_tables:
        list_tables(args.conn)
        return

    if args.query:
        query_results(args.conn, args.run_id, args.table)
        return

    if args.export:
        export_results(args.conn, args.export, args.run_id)
        return

    if args.export_json:
        export_json(args.conn, args.export_json, args.run_id)
        return

    # Run benchmark
    if not args.table:
        parser.print_help()
        print("\nError: --table is required for benchmarking")
        sys.exit(1)

    # Validate input before touching the database, so bad input fails fast.
    if args.runs < 1:
        parser.error("--runs must be at least 1")

    distances = DEFAULT_DISTANCES
    if args.distances:
        try:
            # Empty items (e.g. from a trailing comma) are skipped; int()
            # itself tolerates surrounding whitespace.
            distances = [int(part) for part in args.distances.split(',') if part.strip()]
        except ValueError:
            parser.error(f"--distances must be comma-separated integers, got {args.distances!r}")
        if not distances:
            parser.error("--distances must contain at least one value")

    # Initialize results tables
    init_results_tables(args.conn)

    # Ensure test table exists if requested
    if args.ensure_table:
        create_test_table(args.conn, args.table, args.rows, args.payload_size)

    run_id, results = run_benchmark(
        args.conn,
        args.table,
        distances,
        args.runs,
        cold_cache=not args.warm_cache,
        purge_os=args.purge_os
    )

    print_results_table(results)
    print(f"\nResults stored in database with run_id: {run_id}")
    print(f"Query with: python {sys.argv[0]} --query --run-id {run_id}")


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
