#!/usr/bin/env python3
"""
Run prefetch regression tests.

Usage:
    python run_regression_test.py [options]

Examples:
    python run_regression_test.py --iterations=5
    python run_regression_test.py --evict=off,pg --workers=0,2
    python run_regression_test.py --columns=sequential,random --reset
"""

import argparse
import subprocess
import itertools
import random
import os
import psycopg
import pandas as pd
from math import gcd


def parse_args(argv=None):
    """Parse the command-line options for a regression run.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as a plain ``parse_args()``
            call would; passing a list allows programmatic use and testing.

    Returns:
        argparse.Namespace with attributes iterations, columns, workers, evict,
        rows, payload_size, reset, dbname, host, port and ntables. The
        list-style options (columns, workers, evict) are returned as raw
        comma-separated strings; splitting them is left to the caller.
    """
    p = argparse.ArgumentParser(description='Run prefetch regression tests')
    p.add_argument('--iterations', '-n', type=int, default=10,
                   help='Number of test iterations (default: 10)')
    p.add_argument('--columns', '-c', type=str, default='sequential,periodic,random',
                   help='Columns to test (default: sequential,periodic,random)')
    p.add_argument('--workers', '-w', type=str, default='0,2',
                   help='Worker counts to test (default: 0,2)')
    p.add_argument('--evict', '-e', type=str, default='off',
                   help='Evict modes: off,pg,os or "all" (default: off)')
    p.add_argument('--rows', '-r', type=int, default=100000,
                   help='Number of rows in test table (default: 100000)')
    p.add_argument('--payload-size', '-ps', type=int, default=50,
                   help='Size of payload in test table (default: 50)')
    p.add_argument('--reset', action='store_true',
                   help='Reset tables before running')
    p.add_argument('--dbname', '--db', '-d', type=str, default='postgres',
                   help='Database name (default: postgres)')
    p.add_argument('--host', '-H', type=str, default='/tmp',
                   help='Database host (default: /tmp)')
    p.add_argument('--port', '-p', type=int, default=None,
                   help='Database port (default: use socket)')
    p.add_argument('--ntables', '-t', type=int, default=1,
                   help='Number of tables to join (default: 1)')
    return p.parse_args(argv)


def _create_indexes(cur, table_name):
    """Create the unique index on each test column of *table_name* if it is missing."""
    for column in ('sequential', 'periodic', 'random'):
        cur.execute(f"CREATE UNIQUE INDEX IF NOT EXISTS idx_{table_name}_{column} "
                    f"ON {table_name}({column})")


def setup_tables(cur, num_rows, payload_size, ntables=1, reset=False):
    """Create the results and data tables, populating any data table that is empty.

    Args:
        cur: open cursor; each statement is executed as issued.
        num_rows: number of rows inserted into a data table that is empty.
        payload_size: number of '+'-joined numbers in each row's payload text.
        ntables: number of data tables, named prefetch_test_data_1..N.
        reset: drop the results table and the data tables first.

    Each data table has three INT columns holding different permutations of
    1..num_rows (in physical insertion order ``i``):
        sequential = i
        periodic   = (i * r) mod num_rows + 1, for a stride r coprime with
                     num_rows, which makes it a permutation
        random     = row_number() over a random ordering
    and a unique index on each of them.
    """
    if reset:
        cur.execute('DROP TABLE IF EXISTS prefetch_test_results')
        # Drop at least the first 100 numbered tables (leftovers from earlier
        # runs with a larger --ntables) and always every table of this run.
        for i in range(1, max(100, ntables) + 1):
            cur.execute(f'DROP TABLE IF EXISTS prefetch_test_data_{i}')
        # Un-numbered table name not otherwise created by this function.
        cur.execute('DROP TABLE IF EXISTS prefetch_test_data')

    # Needed by evict_pg_buffers (pg_buffercache view and pg_buffercache_evict).
    cur.execute('CREATE EXTENSION IF NOT EXISTS pg_buffercache')
    cur.execute('''
        CREATE TABLE IF NOT EXISTS prefetch_test_results (
            id SERIAL PRIMARY KEY,
            run_timestamp TIMESTAMPTZ DEFAULT now(),
            io_method TEXT NOT NULL,
            num_workers INT NOT NULL DEFAULT 0,
            prefetch_enabled BOOLEAN NOT NULL,
            evict_mode TEXT NOT NULL DEFAULT 'off',
            column_name TEXT NOT NULL,
            iteration INT,
            execution_time_ms NUMERIC,
            rows_returned BIGINT,
            blks_hit BIGINT,
            blks_read BIGINT
        )
    ''')

    # Create multiple tables for join tests
    for t in range(1, ntables + 1):
        table_name = f'prefetch_test_data_{t}'
        cur.execute(f'''
            CREATE TABLE IF NOT EXISTS {table_name} (
                sequential INT,
                periodic INT,
                random INT,
                payload TEXT
            )
        ''')

        cur.execute(f"SELECT count(*) FROM {table_name}")
        row_count = cur.fetchone()[0]

        if row_count == 0:
            print(f"Populating {table_name} ({num_rows} rows)...")
            # Smallest stride >= min(10000, num_rows // 5) that is coprime with
            # num_rows, so the periodic column is a permutation of 1..num_rows.
            r = min(10000, num_rows // 5)
            while gcd(r, num_rows) != 1:
                r += 1
            # i::bigint * j: plain INT multiplication overflows for large
            # --rows values (e.g. 50M rows * 50 > 2^31 - 1).
            cur.execute(f'''
                INSERT INTO {table_name} (sequential, periodic, random, payload)
                SELECT
                    i,
                    ((i * {r}::bigint) % {num_rows} + 1)::int,
                    row_number() OVER (ORDER BY random()),
                    (SELECT string_agg((i::bigint * j)::text, '+')
                     FROM generate_series(1, {payload_size}) j)
                FROM generate_series(1, {num_rows}) i
                ORDER BY i
            ''')
            _create_indexes(cur, table_name)
            print(f"{table_name} populated.")
        else:
            print(f"{table_name} exists: {row_count} rows")
            # IF NOT EXISTS makes this a no-op normally, but it repairs a table
            # whose population was interrupted before its indexes were built;
            # the tests disable seq/bitmap scans and rely on these indexes.
            _create_indexes(cur, table_name)


def purge_os_cache():
    """Run the ``drop_cache`` helper located next to this script.

    The helper's stdout/stderr are captured. A non-zero exit status is
    reported as a warning, because otherwise the 'os' eviction mode would
    silently run with a warm OS cache. A missing helper raises
    FileNotFoundError, and a run longer than 20 s raises
    subprocess.TimeoutExpired.
    """
    # abspath: if __file__ were a bare relative name, dirname() would be ''
    # and subprocess would then look 'drop_cache' up on $PATH instead of in
    # the script's directory.
    script = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'drop_cache')
    # Especially the first time, syncing dirty pages can take a while.
    result = subprocess.run([script], capture_output=True, timeout=20)
    if result.returncode != 0:
        detail = result.stderr.decode(errors='replace').strip()
        message = f"WARNING: {script} exited with status {result.returncode}"
        if detail:
            message += f": {detail}"
        print(message)


def evict_pg_buffers(cur):
    """Evict shared buffers of the prefetch_test_data* relations in this database.

    pg_buffercache lists buffers of every database in the cluster, and a
    relfilenode is only unique within one database. The join is therefore
    restricted to the current database, so that buffers of unrelated
    relations elsewhere that happen to share a relfilenode are not evicted.

    NOTE(review): the indexes created by setup_tables are named
    idx_prefetch_test_data_*, which this LIKE pattern does not match, so index
    pages stay in shared buffers. Confirm that only heap pages should be evicted.
    """
    cur.execute('''
        SELECT count(pg_buffercache_evict(b.bufferid))
        FROM pg_buffercache b
        JOIN pg_class c ON b.relfilenode = pg_relation_filenode(c.oid)
        WHERE b.reldatabase = (SELECT oid FROM pg_database
                               WHERE datname = current_database())
          AND c.relname LIKE 'prefetch_test_data%'
    ''').fetchall()


def apply_eviction(cur, evict_mode):
    """Drop cached test data according to *evict_mode* before a timed query.

    Modes:
        'off': do nothing.
        'pg':  evict the test tables from PostgreSQL shared buffers.
        'os':  as 'pg', then also purge the OS page cache.

    Raises:
        ValueError: for any other mode. Without this check a typo would run
            the test with nothing evicted and record it under the bogus name.
    """
    if evict_mode == 'off':
        return
    if evict_mode not in ('pg', 'os'):
        raise ValueError(
            f"unknown evict mode {evict_mode!r} (expected 'off', 'pg' or 'os')")
    evict_pg_buffers(cur)
    if evict_mode == 'os':
        purge_os_cache()


def _build_query(column_name, ntables):
    """Return the timed SELECT: one table, or *ntables* tables joined to t1 on the column."""
    if ntables == 1:
        return f'''
            SELECT length(t1.payload) FROM prefetch_test_data_1 t1
            ORDER BY t1.{column_name}
        '''
    else:
        # Alias tK always names prefetch_test_data_K.
        from_items = [f'prefetch_test_data_{k} t{k}' for k in range(1, ntables + 1)]
        select_list = ' || '.join(f'length(t{k}.payload)::text' for k in range(1, ntables + 1))
        join_clauses = ' '.join(
            f'JOIN {from_items[k - 1]} ON t1.{column_name} = t{k}.{column_name}'
            for k in range(2, ntables + 1)
        )
        first_table = from_items[0]
        return f'''
            SELECT {select_list}
            FROM {first_table}
            {join_clauses}
            ORDER BY t1.{column_name}
        '''


def run_test(cur, column_name, prefetch_enabled, num_workers, evict_mode, iteration, ntables=1):
    """Time one configuration with EXPLAIN ANALYZE and store it in prefetch_test_results.

    The session toggles enable_indexscan_prefetch, sets the parallel worker
    limit and disables bitmap scans, sequential scans, index-only scans and
    sorts (presumably to steer the planner to a plain index scan). It then
    applies the requested eviction and runs the query ordered by
    *column_name*. The execution time, the top plan node's shared hit/read
    blocks and actual rows, and the server's io_method are inserted as one
    result row, and a one-line progress message is printed.
    """
    prefetch_value = 'on' if prefetch_enabled else 'off'
    for setting in (
        f"SET enable_indexscan_prefetch = {prefetch_value}",
        f"SET max_parallel_workers_per_gather = {num_workers}",
        "SET enable_bitmapscan = off",
        "SET enable_seqscan = off",
        "SET enable_indexonlyscan = off",
        "SET enable_sort = off",
    ):
        cur.execute(setting)

    apply_eviction(cur, evict_mode)

    query = _build_query(column_name, ntables)
    cur.execute(f'EXPLAIN (ANALYZE, BUFFERS, FORMAT JSON) {query}')

    # FORMAT JSON yields a one-element list holding the plan document.
    explained = cur.fetchone()[0][0]
    plan = explained['Plan']
    exec_time = explained['Execution Time']
    blks_hit, blks_read, rows = (
        plan.get(key, 0)
        for key in ('Shared Hit Blocks', 'Shared Read Blocks', 'Actual Rows')
    )

    cur.execute("SHOW io_method")
    io_method = cur.fetchone()[0]

    cur.execute('''
        INSERT INTO prefetch_test_results 
        (io_method, num_workers, prefetch_enabled, evict_mode, column_name, 
         iteration, execution_time_ms, rows_returned, blks_hit, blks_read)
        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
    ''', (io_method, num_workers, prefetch_enabled, evict_mode,
          column_name, iteration, exec_time, rows, blks_hit, blks_read))

    flag = 'Y' if prefetch_enabled else 'N'
    print(f"  [{column_name}] pf={flag} w={num_workers} evict={evict_mode}: {exec_time:.1f}ms")


def print_summary(cur):
    """Print average execution time with prefetch off vs. on for each configuration.

    Averages cover every row in prefetch_test_results, which persists across
    runs unless --reset drops it. Rows are grouped by column, io_method,
    worker count and evict mode. ``n`` is the number of prefetch-off runs in
    a group, and effect_pct = (on - off) / (on + off) * 100, so a positive
    value means prefetching made the query slower.
    """
    cur.execute('''
        WITH avgs AS (
            SELECT column_name, io_method, num_workers,
                   prefetch_enabled AS prefetch, evict_mode AS evict,
                   count(*) AS n, avg(execution_time_ms) AS ms
            FROM prefetch_test_results
            GROUP BY column_name, io_method, num_workers, prefetch_enabled, evict_mode
        )
        SELECT a.column_name, a.io_method, a.num_workers, a.evict, a.n,
               round(a.ms::numeric, 1) AS off_ms,
               round(b.ms::numeric, 1) AS on_ms,
               round(((b.ms - a.ms) / NULLIF(b.ms + a.ms, 0) * 100)::numeric, 1) AS effect_pct
        FROM avgs a JOIN avgs b USING (column_name, io_method, num_workers, evict)
        WHERE NOT a.prefetch AND b.prefetch
        ORDER BY column_name, io_method, num_workers, evict
    ''')
    columns = [desc.name for desc in cur.description]
    # Build the frame from the cursor rather than pd.read_sql(): pandas only
    # supports SQLAlchemy connectables and sqlite3 connections there and warns
    # for other DBAPI connections such as psycopg. coerce_float converts the
    # NUMERIC (Decimal) columns to floats, as read_sql does by default.
    df = pd.DataFrame.from_records(cur.fetchall(), columns=columns, coerce_float=True)

    print("\n" + "=" * 80)
    print("SUMMARY: Prefetch Effect (positive = slower)")
    print("=" * 80)
    print(df.to_string(index=False))


def _split_csv(text):
    """Split a comma-separated option value into stripped, non-empty items."""
    return [item.strip() for item in text.split(',') if item.strip()]


def main():
    """Run every (column, workers, evict, prefetch) configuration once per iteration.

    Parses the command line, creates and populates the test tables, runs
    each configuration via run_test() once per iteration, and prints the
    summary. Exits with an error message if --evict names an unknown mode.
    """
    args = parse_args()

    # Parse list arguments
    valid_evict = ['off', 'pg', 'os']
    columns = _split_csv(args.columns)
    workers = [int(w) for w in _split_csv(args.workers)]
    evict_modes = list(valid_evict) if args.evict == 'all' else _split_csv(args.evict)
    unknown = [m for m in evict_modes if m not in valid_evict]
    if unknown:
        # Validate up front so a typo is reported before touching the database.
        raise SystemExit(
            f"Unknown evict mode(s) {unknown}; expected any of {valid_evict} or 'all'")

    print(f"Config: iterations={args.iterations}, columns={columns}, workers={workers}, evict={evict_modes}, ntables={args.ntables}")

    # Keyword parameters are quoted by psycopg, unlike a hand-built
    # "key=value" string, so values containing spaces are handled.
    conn_params = {'dbname': args.dbname, 'host': args.host}
    if args.port:
        conn_params['port'] = args.port

    # The with-block closes the connection even if a test raises.
    with psycopg.connect(autocommit=True, **conn_params) as conn:
        cur = conn.cursor()

        setup_tables(cur, args.rows, args.payload_size, args.ntables, args.reset)

        cur.execute("SHOW io_method")
        io_method = cur.fetchone()[0]
        print(f"\nio_method = {io_method}\n")

        # The configuration matrix is identical every iteration; build it once.
        prefetch_opts = [False, True]
        configs = list(itertools.product(columns, workers, evict_modes, prefetch_opts))

        for i in range(1, args.iterations + 1):
            print(f"Iteration {i}/{args.iterations}")
            # random.shuffle(configs)  # uncomment to randomize order each iteration
            for col, w, evict, pf in configs:
                run_test(cur, col, pf, w, evict, i, args.ntables)
            print()

        print_summary(cur)


# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
