#!/usr/bin/env python3
"""Test read stream distance oscillation bug using regex-like patterns.
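
The pattern is expanded to one character per row id (the table holds one row
per heap page): 'm' = the page is not in shared buffers and must be read from
disk (miss), 'h' = the page is pre-loaded into shared buffers (hit).
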
Examples:
    --pattern "(mh)+"      # alternating miss/hit - BUG: distance stuck at 1-2
    --pattern "m(mh)+"     # miss first, then alternating - distance ramps up
    --pattern "m+"         # all misses - prefetch works great
    --pattern "h+"         # all hits - no I/O needed
    --pattern "m{2}(mh)*"  # 2 misses, then alternating
    
    --plot-distance dist.png   # plot the prefetch distances during scan
"""

import argparse, subprocess, os, warnings, statistics, psycopg

def purge_os_cache():
    """Drop the OS page cache via the ``drop_cache`` helper that sits next to this script.

    The helper's output is captured rather than shown. A non-zero exit status
    is reported as a RuntimeWarning: if the cache was not actually dropped,
    "misses" may be served from the OS cache and the timings are misleading.
    """
    # abspath: a bare relative __file__ (e.g. "script.py") has an empty dirname,
    # and subprocess would then search PATH for "drop_cache" instead of using
    # the helper beside the script.
    helper = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'drop_cache')
    proc = subprocess.run([helper], capture_output=True, text=True, errors='replace')
    if proc.returncode != 0:
        warnings.warn(
            f"{helper} exited with status {proc.returncode}; the OS page cache "
            f"may not have been dropped: {proc.stderr.strip()}",
            RuntimeWarning,
            stacklevel=2,
        )

def expand_pattern(pattern, length):
    """Expand a regex-like pattern into a hit/miss string of ``length`` characters.

    Python's regex parser is used to walk the pattern. The ``m``/``h`` literals
    it matches are emitted greedily, and the total is bounded by ``length``.
    Interpreted constructs are literals, groups, and repeats
    (``*``, ``+``, ``?``, ``{n}``, ``{n,m}``). Other literal characters are
    ignored. A result shorter than ``length`` is padded with its last character.

    Fallback: the pattern may fail to parse, or expand to nothing because it
    only uses constructs that are not interpreted (e.g. ``[mh]``, ``m|h``).
    In that case the ``m``/``h`` characters of the raw pattern are repeated
    (``m`` if it has none).

    Examples:
        m+       -> mmmmm...
        (hm)+    -> hmhmhm...
        m(hm)+   -> mhmhmhm...
        m{3}(hm)*-> mmmhmhmhm...
        h{5}m+   -> hhhhhmmmmm...
    """
    def fallback():
        # Best-effort: cycle through the m/h characters literally present.
        chars = ''.join(c for c in pattern if c in 'mh') or 'm'
        return (chars * (length // len(chars) + 1))[:length]

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", DeprecationWarning)
        try:
            from re import _parser as sre_parse  # Python 3.11+
        except ImportError:
            import sre_parse  # older Pythons

    def generate(parsed, remaining):
        """Emit at most `remaining` m/h chars for the parsed (sub)pattern."""
        result = []
        for op, av in parsed:
            if remaining <= 0:
                break
            if op == sre_parse.LITERAL:
                c = chr(av)
                if c in 'mh':
                    result.append(c)
                    remaining -= 1
            elif op == sre_parse.SUBPATTERN:
                # av = (group, add_flags, del_flags, subpattern)
                sub = generate(av[-1], remaining)
                result.append(sub)
                remaining -= len(sub)
            elif op in (sre_parse.MAX_REPEAT, sre_parse.MIN_REPEAT):
                lo, hi, subpat = av
                # Unbounded repeats run until the length budget is used up.
                hi = min(hi, remaining) if hi != sre_parse.MAXREPEAT else remaining
                for _ in range(max(lo, hi)):
                    if remaining <= 0:
                        break
                    sub = generate(subpat, remaining)
                    if not sub:
                        # Body produces nothing; repeating it would loop forever.
                        break
                    result.append(sub)
                    remaining -= len(sub)
        return ''.join(result)

    try:
        seq = generate(list(sre_parse.parse(pattern)), length)
    except Exception:
        # Deliberate best-effort: any parse/walk failure uses the literal fallback.
        return fallback()
    if not seq:
        return fallback()
    # Pad with the last char if the pattern ran out before `length`.
    if len(seq) < length:
        seq += seq[-1] * (length - len(seq))
    return seq[:length]

def run_one_sample(cur, seq, pf, collect_distances=False):
    """Run one EXPLAIN ANALYZE index scan over distance_test with a controlled cache state.

    Before the scan, the table is evicted from shared buffers. Rows whose
    position in ``seq`` is ``'h'`` are then read back into shared buffers, and
    the OS page cache is dropped, so the remaining pages must come from disk.

    Args:
        cur: open psycopg cursor; the SET commands issued here stay in effect
            for the rest of the session.
        seq: expanded hit/miss string, one character per row id.
        pf: value for ``enable_indexscan_prefetch`` during the measured scan.
        collect_distances: if true, enable prefetch tracking before the scan
            and drain ``pg_prefetch_distance()`` after it.

    Returns:
        Tuple ``(execution_time, shared_hit_blocks, shared_read_blocks,
        distances)`` taken from the top plan node of the EXPLAIN JSON;
        ``distances`` is empty unless ``collect_distances`` is true.
    """
    # Evict the table's pages from shared buffers.
    cur.execute("SELECT pg_buffercache_evict_relation('distance_test')")

    # Force a plain, serial index scan. These SETs persist for the session, so
    # they must be issued BEFORE the warm-up query. Otherwise the very first
    # sample warms with default planner settings, where e.g. a seq scan (could
    # cache every page) or an index-only scan (could cache no heap pages) may
    # be chosen instead of visiting only the 'h' rows.
    cur.execute("""
    SET enable_seqscan=off;
    SET enable_indexscan=on;
    SET enable_bitmapscan=off;
    SET enable_indexonlyscan=off;
    SET max_parallel_workers_per_gather=0;
    SET enable_sort=off;
    """)

    # Warm specific pages into PG buffer - rows i where seq[i] == 'h'.
    if 'h' in seq:
        cur.execute("""
            SELECT id FROM distance_test 
            WHERE id = ANY(ARRAY(
                SELECT i FROM generate_series(0, length(%s)-1) i 
                WHERE substr(%s, i+1, 1) = 'h'
            ))
        """, (seq, seq))

    # Buffers not held by PG will have to be read from disk.
    purge_os_cache()

    if collect_distances:
        cur.execute("SELECT pg_set_prefetch_tracking('distance_test')")

    cur.execute(f"SET enable_indexscan_prefetch={'on' if pf else 'off'}")
    cur.execute("EXPLAIN (ANALYZE,BUFFERS,FORMAT JSON) SELECT filler FROM distance_test ORDER BY id")
    explain = cur.fetchone()[0][0]

    # Drain the recorded distances: keep calling until the function returns NULL.
    distances = []
    if collect_distances:
        while True:
            cur.execute("SELECT pg_prefetch_distance()")
            d = cur.fetchone()[0]
            if d is None:
                break
            distances.append(d)

    top = explain['Plan']
    return (explain['Execution Time'],
            top.get('Shared Hit Blocks', 0),
            top.get('Shared Read Blocks', 0),
            distances)


def plot_distances(distances_off, distances_on, seq, pattern, output_file):
    """Plot the prefetch-distance traces of the OFF and ON runs and save them to a file.

    Args:
        distances_off: list of distance sequences (one per sample) recorded
            with prefetch off; empty sequences are drawn as empty lines.
        distances_on: same, for prefetch on.
        seq: expanded hit/miss string (currently unused; kept for interface
            compatibility).
        pattern: the user-supplied pattern, shown in the title.
        output_file: path of the image file to write.
    """
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(12, 5))

    # Only the first sample of each setting gets a legend label.
    for runs, style, label in ((distances_off, 'b.-', 'OFF'), (distances_on, 'r.-', 'ON')):
        for i, d in enumerate(runs or ()):
            ax.plot(d, style, label=label if i == 0 else None)

    ax.set_xlabel('Buffer Return Order')
    ax.set_ylabel('Prefetch Distance')
    ax.set_title(f'Prefetch Distance During Scan - Pattern: {pattern}')
    # Avoid matplotlib's "no artists with labels" warning when nothing was plotted.
    if ax.get_legend_handles_labels()[0]:
        ax.legend(loc='lower center')
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(output_file, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Distance plot saved to: {output_file}")


def _create_table(cur, pages):
    """Rebuild distance_test with ids 0..pages-1, one row per heap page, and verify the layout."""
    # Setup: 1 row per page using CHAR(5000) STORAGE PLAIN (blank-padded and
    # kept inline, so with the default page size two rows cannot share a page).
    cur.execute("DROP TABLE IF EXISTS distance_test")
    cur.execute("CREATE TABLE distance_test (id INT PRIMARY KEY, filler CHAR(5000))")
    cur.execute("ALTER TABLE distance_test ALTER COLUMN filler SET STORAGE PLAIN")
    print(f"Creating {pages} pages...")
    cur.execute(f"INSERT INTO distance_test (id, filler) SELECT i, 'not null' FROM generate_series(0,{pages-1}) i")
    # NOTE(review): the PRIMARY KEY already provides an index on id, so this
    # second index looks redundant — confirm before removing.
    cur.execute("CREATE INDEX IF NOT EXISTS idx_dt ON distance_test(id)")

    # Verify that each row is on a different page (first ctid component = block number).
    cur.execute("SELECT count(distinct id), count(distinct (ctid::text::point)[0]::int) FROM distance_test")
    ids, ctids = cur.fetchone()
    if ids != ctids:
        raise SystemExit(f"Expected {ids} ids, got {ctids} ctids")


def _report(results, blocks):
    """Print mean ± stdev execution time per prefetch setting and the ON-vs-OFF effect.

    Args:
        results: {pf: [execution times]} with at least one time per setting.
        blocks: {pf: (shared_hit_blocks, shared_read_blocks)} of the first sample.
    """
    for pf in (False, True):
        times = results[pf]
        avg = statistics.mean(times)
        std = statistics.stdev(times) if len(times) > 1 else 0
        hit, read = blocks[pf]
        print(f"  Prefetch {'ON' if pf else 'OFF'}: {avg:6.1f}ms ± {std:5.1f}ms  (n={len(times)}, "
              f"first sample: {hit} shared hit / {read} shared read blocks)")

    # Compute speedup
    avg_off = statistics.mean(results[False])
    avg_on = statistics.mean(results[True])
    if avg_off > 0:
        pct = (avg_on / avg_off - 1) * 100
        print(f"  Effect: {pct:+.1f}% ({'slower' if pct > 0 else 'faster'})")


def main():
    """Parse arguments, build the test table, run the OFF/ON samples and report the results."""
    p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    p.add_argument('--port', '-p', type=int, default=5432)
    p.add_argument('--pages', type=int, default=500)
    p.add_argument('--pattern', default='(mh)+', help='Regex pattern: m=miss, h=hit')
    p.add_argument('--samples', '-n', type=int, default=1, help='Number of samples for avg/stddev')
    p.add_argument('--reset', action='store_true',
                   help='Currently ignored: the table is always recreated')
    p.add_argument('--plot-distance', metavar='FILE', help='Plot prefetch distances to PNG file')
    args = p.parse_args()
    if args.pages < 1:
        p.error('--pages must be at least 1')
    if args.samples < 1:
        p.error('--samples must be at least 1')

    results = {False: [], True: []}
    blocks = {}  # pf -> (shared hit, shared read) of the first sample
    all_distances = {False: [], True: []}
    # Distances are only collected (for the first sample) when a plot is requested.
    collect_distances = args.plot_distance is not None

    conninfo = f"dbname=postgres host=/tmp port={args.port}"
    with psycopg.connect(conninfo, autocommit=True) as conn:
        cur = conn.cursor()
        cur.execute("CREATE EXTENSION IF NOT EXISTS pg_buffercache")
        _create_table(cur, args.pages)

        # Generate pattern sequence
        seq = expand_pattern(args.pattern, args.pages)
        print(f"Table: {args.pages} pages, samples: {args.samples}")
        print(f"Pattern: {args.pattern} -> {seq[:20]}{'...' if len(seq)>20 else ''}")
        print(f"Cache: {seq.count('h')} hits, {seq.count('m')} misses")

        for i in range(args.samples):
            for pf in (False, True):
                t, hit, read, dists = run_one_sample(cur, seq, pf, collect_distances and i == 0)
                results[pf].append(t)
                blocks.setdefault(pf, (hit, read))
                all_distances[pf].append(dists)

    if args.plot_distance:
        plot_distances(all_distances[False], all_distances[True], seq, args.pattern, args.plot_distance)
    # Always report: previously nothing was printed with the default --samples 1.
    _report(results, blocks)
    
if __name__ == '__main__':
    # Script entry point; main() returns None, so the exit status is 0.
    raise SystemExit(main())
