#!/usr/bin/env python3
"""
Run prefetch regression tests.

Usage:
    python run_regression_test.py [options]

Examples:
    python run_regression_test.py --iterations=5
    python run_regression_test.py --evict=off,pg --workers=0,2
    python run_regression_test.py --columns=sequential,random --reset
"""

import argparse
import subprocess
import itertools
import random
import os
import psycopg
import pandas as pd
from math import gcd


def parse_args(argv=None):
    """Parse command-line options for the regression run.

    argv -- optional list of argument strings; defaults to sys.argv[1:].
            Exposed so tests (or callers) can drive the parser directly.

    Returns an argparse.Namespace with iterations, columns, workers,
    evict, rows, reset, dbname, host and port.
    """
    p = argparse.ArgumentParser(description='Run prefetch regression tests')
    p.add_argument('--iterations', '-n', type=int, default=10,
                   help='Number of test iterations (default: 10)')
    # Comma-separated option strings are split later in main().
    p.add_argument('--columns', '-c', type=str, default='sequential,periodic,random',
                   help='Columns to test (default: sequential,periodic,random)')
    p.add_argument('--workers', '-w', type=str, default='0,2',
                   help='Worker counts to test (default: 0,2)')
    p.add_argument('--evict', '-e', type=str, default='off',
                   help='Evict modes: off,pg,os or "all" (default: off)')
    p.add_argument('--rows', '-r', type=int, default=100000,
                   help='Number of rows in test table (default: 100000)')
    p.add_argument('--reset', action='store_true',
                   help='Reset tables before running')
    p.add_argument('--dbname', '--db', '-d', type=str, default='postgres',
                   help='Database name (default: postgres)')
    # Default host '/tmp' means a local unix-domain socket connection.
    p.add_argument('--host', '-H', type=str, default='/tmp',
                   help='Database host (default: /tmp)')
    p.add_argument('--port', '-p', type=int, default=None,
                   help='Database port (default: use socket)')
    return p.parse_args(argv)


def setup_tables(cur, num_rows, reset=False):
    """Ensure the results and data tables exist, loading rows if needed.

    cur      -- open database cursor (an autocommit connection is assumed,
                since DDL and the bulk INSERT are not explicitly committed)
    num_rows -- number of rows to load into prefetch_test_data
    reset    -- when True, drop both tables first for a clean slate
    """
    if reset:
        cur.execute('DROP TABLE IF EXISTS prefetch_test_results')
        cur.execute('DROP TABLE IF EXISTS prefetch_test_data')

    # pg_buffercache is needed later by evict_pg_buffers().
    cur.execute('''
        CREATE EXTENSION IF NOT EXISTS pg_buffercache
    ''')

    # One row per measured query execution.
    cur.execute('''
        CREATE TABLE IF NOT EXISTS prefetch_test_results (
            id SERIAL PRIMARY KEY,
            run_timestamp TIMESTAMPTZ DEFAULT now(),
            io_method TEXT NOT NULL,
            num_workers INT NOT NULL DEFAULT 0,
            prefetch_enabled BOOLEAN NOT NULL,
            evict_mode TEXT NOT NULL DEFAULT 'off',
            column_name TEXT NOT NULL,
            iteration INT,
            execution_time_ms NUMERIC,
            rows_returned BIGINT,
            blks_hit BIGINT,
            blks_read BIGINT
        )
    ''')

    # Three sort keys over the same rows: heap order, strided, shuffled.
    cur.execute('''
        CREATE TABLE IF NOT EXISTS prefetch_test_data (
            id SERIAL PRIMARY KEY,
            sequential INT,
            periodic INT,
            random INT,
            payload TEXT
        )
    ''')

    cur.execute("SELECT count(*) FROM prefetch_test_data")
    existing = cur.fetchone()[0]

    if existing != 0:
        print(f"Test data exists: {existing} rows")
    else:
        print(f"Populating test data ({num_rows} rows)...")
        # Choose a stride coprime with num_rows so (i * stride) % num_rows
        # visits every residue exactly once -- 'periodic' is a permutation.
        stride = min(10000, num_rows // 5)
        while gcd(stride, num_rows) != 1:
            stride += 1
        cur.execute(f'''
            INSERT INTO prefetch_test_data (sequential, periodic, random, payload)
            SELECT 
                i,
                ((i * {stride}::bigint) % {num_rows} + 1)::int,
                row_number() OVER (ORDER BY random()),
                repeat('x', 200)
            FROM generate_series(1, {num_rows}) i
            ORDER BY i;
        ''')
        for col in ('sequential', 'periodic', 'random'):
            cur.execute(f"CREATE UNIQUE INDEX IF NOT EXISTS idx_{col} ON prefetch_test_data({col})")
        print("Data populated.")


def purge_os_cache():
    """Purge the OS filesystem cache via the companion purge_cache.sh script.

    Best-effort: failures (missing script, permission denied, timeout) are
    ignored so the test run can continue without OS-level eviction.
    """
    # abspath guards against an empty dirname when the script is run from
    # its own directory, which would otherwise trigger a PATH lookup.
    script = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'purge_cache.sh')
    try:
        subprocess.run([script], capture_output=True, timeout=10, check=False)
    except (OSError, subprocess.SubprocessError):
        # OSError covers a missing/non-executable script;
        # SubprocessError covers TimeoutExpired. Deliberately swallowed --
        # the original bare `except:` also hid KeyboardInterrupt, which
        # this narrower clause no longer does.
        pass


def evict_pg_buffers(cur):
    """Force prefetch_test_data's pages out of PostgreSQL shared buffers.

    Uses pg_buffercache_evict() from the pg_buffercache extension created
    in setup_tables(). The eviction count is fetched and discarded.
    """
    query = '''
        SELECT count(pg_buffercache_evict(bufferid))
        FROM pg_buffercache b 
        JOIN pg_class c ON b.relfilenode = pg_relation_filenode(c.oid)
        WHERE c.relname = 'prefetch_test_data'
    '''
    # NOTE(review): psycopg 3 semantics assumed -- execute() returns the
    # cursor, so the original chained .fetchall(); split here for clarity.
    cur.execute(query)
    cur.fetchall()


def apply_eviction(cur, evict_mode):
    """Drop caches before a timed run, according to evict_mode.

    'pg' evicts shared buffers only; 'os' evicts shared buffers and then
    flushes the OS page cache. Any other mode (including 'off') is a no-op.
    """
    if evict_mode not in ('pg', 'os'):
        return
    evict_pg_buffers(cur)
    if evict_mode == 'os':
        purge_os_cache()


def run_test(cur, column_name, prefetch_enabled, num_workers, evict_mode, iteration):
    """Time one ordered scan under the given knobs and record the result.

    Disables the competing plan types so the planner must use a plain
    index scan, optionally evicts caches, runs EXPLAIN ANALYZE, and
    inserts one row into prefetch_test_results.
    """
    # Pin the planner so the prefetch GUC is the only variable under test.
    knobs = (
        ('enable_indexscan_prefetch', 'on' if prefetch_enabled else 'off'),
        ('max_parallel_workers_per_gather', num_workers),
        ('enable_bitmapscan', 'off'),
        ('enable_seqscan', 'off'),
        ('enable_indexonlyscan', 'off'),
        ('enable_sort', 'off'),
    )
    for guc, value in knobs:
        cur.execute(f"SET {guc} = {value}")

    apply_eviction(cur, evict_mode)

    cur.execute(f'''
        EXPLAIN (ANALYZE, BUFFERS, FORMAT JSON) 
        SELECT length(payload) FROM prefetch_test_data 
        ORDER BY {column_name}
    ''')

    # FORMAT JSON returns a single-element list wrapping the plan tree.
    explain = cur.fetchone()[0][0]
    plan = explain['Plan']
    exec_time = explain['Execution Time']

    cur.execute("SHOW io_method")
    io_method = cur.fetchone()[0]

    cur.execute('''
        INSERT INTO prefetch_test_results 
        (io_method, num_workers, prefetch_enabled, evict_mode, column_name, 
         iteration, execution_time_ms, rows_returned, blks_hit, blks_read)
        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
    ''', (io_method, num_workers, prefetch_enabled, evict_mode,
          column_name, iteration, exec_time,
          plan.get('Actual Rows', 0),
          plan.get('Shared Hit Blocks', 0),
          plan.get('Shared Read Blocks', 0)))

    flag = 'Y' if prefetch_enabled else 'N'
    print(f"  [{column_name}] pf={flag} w={num_workers} evict={evict_mode}: {exec_time:.1f}ms")


def print_summary(cur):
    """Aggregate stored results and print a prefetch on/off comparison.

    Pairs each prefetch-off average (alias a) with the matching
    prefetch-on average (alias b); a positive effect_pct means prefetch
    made the query slower.
    """
    summary_sql = '''
        WITH avgs AS (
            SELECT column_name, io_method, num_workers, 
                   prefetch_enabled AS prefetch, evict_mode AS evict,
                   count(*) AS n, avg(execution_time_ms) AS ms
            FROM prefetch_test_results
            GROUP BY column_name, io_method, num_workers, prefetch_enabled, evict_mode
        )
        SELECT a.column_name, a.io_method, a.num_workers, a.evict, a.n,
               round(a.ms::numeric, 1) AS off_ms, 
               round(b.ms::numeric, 1) AS on_ms,
               round(((b.ms - a.ms) / NULLIF(b.ms + a.ms, 0) * 100)::numeric, 1) AS effect_pct
        FROM avgs a JOIN avgs b USING (column_name, io_method, num_workers, evict)
        WHERE NOT a.prefetch AND b.prefetch
        ORDER BY column_name, io_method, num_workers, evict
    '''
    df = pd.read_sql(summary_sql, cur.connection)

    banner = "=" * 80
    print("\n" + banner)
    print("SUMMARY: Prefetch Effect (positive = slower)")
    print(banner)
    print(df.to_string(index=False))


def main():
    """Entry point: parse options, prepare the database, run the matrix."""
    args = parse_args()

    # Expand comma-separated option strings into lists.
    columns = [c.strip() for c in args.columns.split(',')]
    workers = [int(w.strip()) for w in args.workers.split(',')]
    if args.evict == 'all':
        evict_modes = ['off', 'pg', 'os']
    else:
        evict_modes = [e.strip() for e in args.evict.split(',')]

    print(f"Config: iterations={args.iterations}, columns={columns}, workers={workers}, evict={evict_modes}")

    # Connect (unix socket by default); autocommit so SET/DDL apply
    # immediately without explicit transaction management.
    dsn = f"dbname={args.dbname} host={args.host}"
    if args.port:
        dsn += f" port={args.port}"
    conn = psycopg.connect(dsn)
    conn.autocommit = True
    cur = conn.cursor()

    setup_tables(cur, args.rows, args.reset)

    cur.execute("SHOW io_method")
    io_method = cur.fetchone()[0]
    print(f"\nio_method = {io_method}\n")

    # Shuffle the configuration order each iteration so systematic drift
    # (cache warm-up, background load) is spread across configurations.
    for iteration in range(1, args.iterations + 1):
        print(f"Iteration {iteration}/{args.iterations}")

        matrix = list(itertools.product(columns, [False, True], workers, evict_modes))
        random.shuffle(matrix)

        for col, prefetch, nworkers, evict in matrix:
            run_test(cur, col, prefetch, nworkers, evict, iteration)

        print()

    print_summary(cur)
    conn.close()


if __name__ == "__main__":
    main()
