#!/usr/bin/env python3
"""
Prefetch regression analysis - compute effects and generate forest plot.

Usage: python run_analysis.py [options]
"""

import argparse
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import psycopg
from scipy import stats
from matplotlib.lines import Line2D


def parse_args():
    """Parse the command-line options of the analysis script.

    Returns:
        argparse.Namespace: carries ``prefix`` (str), ``dbname`` (str),
        ``host`` (str) and ``port`` (int or None).
    """
    # One (flags, settings) pair per option, listed in --help order.
    option_specs = (
        (('--prefix', '-o'),
         {'type': str, 'default': '', 'help': 'Output file prefix'}),
        (('--dbname', '--db', '-d'),
         {'type': str, 'default': 'postgres',
          'help': 'Database name (default: postgres)'}),
        (('--host', '-H'),
         {'type': str, 'default': '/tmp',
          'help': 'Database host (default: /tmp)'}),
        (('--port', '-p'),
         {'type': int, 'default': None, 'help': 'Database port'}),
    )

    parser = argparse.ArgumentParser(
        description='Analyze prefetch regression test results')
    for flags, settings in option_specs:
        parser.add_argument(*flags, **settings)
    return parser.parse_args()


def load_timing_data(conn):
    """Read every row of ``prefetch_test_results`` into a DataFrame.

    Args:
        conn: open database connection handed to ``pd.read_sql``.

    Returns:
        pd.DataFrame with the columns ``column_name``, ``io_method``,
        ``num_workers``, ``prefetch_enabled``, ``evict_mode``, ``iteration``
        and ``execution_time_ms``.
    """
    query = '''
        SELECT column_name, io_method, num_workers, prefetch_enabled,
               evict_mode, iteration, execution_time_ms
        FROM prefetch_test_results
    '''
    return pd.read_sql(query, conn)


def compute_ci(group):
    """Compute effect with CI for a single configuration group."""
    off = group[~group['prefetch_enabled']]['execution_time_ms']
    on = group[group['prefetch_enabled']]['execution_time_ms']
    assert len(off) == len(on), "Off and on should have the same length"
    if len(off) < 2 or len(on) < 2:
        return None
    
    result = stats.ttest_ind(on, off, equal_var=False)
    ci = result.confidence_interval(0.95)
    
    return pd.Series({
        'n': min(len(off), len(on)),
        'off_ms': off.mean(),
        'on_ms': on.mean(),
        'pct_change': (on.mean() / off.mean() - 1) * 100,
        'ci_low': ci.low / off.mean() * 100,
        'ci_high': ci.high / off.mean() * 100,
        'p_value': result.pvalue,
        'significant': result.pvalue < 0.05
    })


def compute_effects_with_ci(df):
    """Compute the prefetch effect with confidence intervals per configuration.

    The rows of ``df`` are grouped by ``num_workers``, ``column_name``,
    ``io_method`` and ``evict_mode``; each group is handed to ``compute_ci``
    (Welch's t-test). Groups for which ``compute_ci`` returns ``None`` and
    rows containing NaN statistics are left out.

    Args:
        df: raw timing rows containing the four grouping columns plus the
            columns ``compute_ci`` reads.

    Returns:
        pd.DataFrame with one row per retained configuration: the grouping
        columns (``evict_mode`` renamed to ``evict``) followed by the
        statistics produced by ``compute_ci``, sorted by ``column_name``,
        ``io_method``, ``num_workers`` and ``evict``. When no configuration
        is retained the frame is empty but still has all columns.
    """
    group_keys = ['num_workers', 'column_name', 'io_method', 'evict_mode']
    # Only used to shape the empty result; mirrors the keys of the Series
    # returned by compute_ci.
    effect_columns = ['n', 'off_ms', 'on_ms', 'pct_change',
                      'ci_low', 'ci_high', 'p_value', 'significant']

    # Iterate the groups explicitly instead of using GroupBy.apply: this does
    # not rely on how apply combines None/Series results, and it keeps the
    # column layout well defined when every group is skipped.
    records = []
    for key, group in df.groupby(group_keys):
        effect = compute_ci(group)
        if effect is None:
            continue
        records.append({**dict(zip(group_keys, key)), **effect.to_dict()})

    if records:
        results = pd.DataFrame(records)
    else:
        results = pd.DataFrame(columns=group_keys + effect_columns)

    results = results.dropna().reset_index(drop=True)
    results = results.rename(columns={'evict_mode': 'evict'})
    return results.sort_values(['column_name', 'io_method', 'num_workers', 'evict'])


def plot_forest(df, prefix=''):
    """Generate a forest plot with an aligned text table as y tick labels.

    Each row of ``df`` becomes one marker with a horizontal error bar running
    from ``ci_low`` to ``ci_high``. Significant rows are drawn red when
    ``pct_change`` is positive and green otherwise; non-significant rows are
    gray.

    Args:
        df: effect table with the columns ``evict``, ``column_name``,
            ``num_workers``, ``n``, ``pct_change``, ``ci_low``, ``ci_high``
            and ``significant``. ``io_method`` is shown in the row labels
            when the column is present.
        prefix: string prepended to the output file name.

    Returns:
        Path of the saved PNG, or ``None`` when ``df`` has no rows.
    """
    if df.empty:
        print("No data to plot!")
        return None

    # Effects are computed per io_method, so include it in the ordering and
    # in the labels; otherwise rows that differ only by I/O method would be
    # indistinguishable. Frames without the column are still accepted.
    has_io_method = 'io_method' in df.columns
    sort_keys = ['evict', 'column_name', 'num_workers']
    label_cols = ['num_workers', 'column_name']
    if has_io_method:
        sort_keys.append('io_method')
        label_cols.append('io_method')
    sort_keys.append('n')
    label_cols.append('evict')

    df = df.sort_values(sort_keys).reset_index(drop=True)
    n = len(df)

    # Format table as aligned text for ytick labels
    table = df[label_cols].rename(columns={'evict': 'eviction'})
    lines = table.to_string(index=False, header=True).split('\n')
    header = lines[0]
    labels = lines[1:] + [header]  # header at top (highest y)

    fig, ax = plt.subplots(figsize=(10, max(4, (n + 1) * 0.35)))

    # Forest plot: after reset_index the positional index doubles as y.
    for y, row in df.iterrows():
        color = 'gray'
        if row['significant']:
            color = 'red' if row['pct_change'] > 0 else 'green'

        ax.errorbar(
            row['pct_change'], y,
            xerr=[[row['pct_change'] - row['ci_low']],
                  [row['ci_high'] - row['pct_change']]],
            fmt='s', color=color, markersize=8, capsize=3, linewidth=2
        )
        # Place the value just right of the upper CI bound.
        ax.annotate(f"{row['pct_change']:.1f}%",
                    (row['ci_high'] + 2, y), fontsize=8, va='center')

    ax.axvline(0, color='black', linewidth=1)
    # One extra tick above the data rows carries the table header.
    ax.set_yticks(range(n + 1))
    ax.set_yticklabels(labels, fontfamily='monospace', fontsize=9)
    ax.set_xlabel('% Change in Execution Time (95% CI)')
    ax.set_title('Prefetch Effect (Positive = SLOWER)')
    ax.grid(axis='x', alpha=0.3)

    # Legend
    legend = [
        Line2D([0], [0], marker='s', color='w', markerfacecolor=face,
               markersize=10, label=text)
        for face, text in (('green', 'Faster'),
                           ('red', 'Slower'),
                           ('gray', 'Not significant'))
    ]
    ax.legend(handles=legend, loc='lower right')

    # Operate on this figure explicitly rather than on pyplot's implicit
    # "current figure".
    fig.tight_layout()
    plot_path = f'{prefix}prefetch_forest.png'
    fig.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved: {plot_path}")
    return plot_path


def analyze_factors(results):
    """Print the mean execution time for every level of each test factor.

    For ``evict_mode``, ``column_name``, ``num_workers`` and
    ``prefetch_enabled`` the rows of ``results`` are grouped by that single
    column and ``execution_time_ms`` is averaged per level. The averages,
    rounded to one decimal, are printed as a table indexed by
    (factor, state).

    Args:
        results: DataFrame containing the four factor columns and
            ``execution_time_ms``.
    """
    banner = "=" * 50
    print("", banner, "AVERAGE EFFECT BY FACTOR", banner, sep="\n")

    factor_names = ('evict_mode', 'column_name', 'num_workers', 'prefetch_enabled')

    summary_rows = []
    for name in factor_names:
        level_means = results.groupby(name)['execution_time_ms'].mean().sort_index()
        summary_rows.extend(
            {'factor': name, 'state': str(level), 'mean_ms': round(avg, 1)}
            for level, avg in level_means.items()
        )

    summary = pd.DataFrame(summary_rows).set_index(['factor', 'state'])
    print(summary.to_string())

def main():
    """Run the whole analysis: load, summarise, plot and export.

    Steps:
        1. Parse the command-line options.
        2. Load the raw timing rows from the database.
        3. Compute the per-configuration prefetch effect and print it.
        4. Print the per-factor mean execution times.
        5. Save the forest plot and a CSV summary, both named with the
           optional output prefix.
    """
    args = parse_args()
    prefix = f"{args.prefix}_" if args.prefix else ""

    # Connect. Parameters are passed as keywords so that psycopg assembles
    # the connection string itself; values containing spaces or quotes would
    # break a hand-formatted "key=value" string.
    conn_params = {'dbname': args.dbname, 'host': args.host}
    if args.port:
        conn_params['port'] = args.port
    conn = psycopg.connect(**conn_params)

    # Load data; close the connection even when the query fails.
    print("Loading data...")
    try:
        df = load_timing_data(conn)
    finally:
        conn.close()

    print(f"Loaded {len(df)} rows\n")

    if df.empty:
        print("No data found!")
        return

    # Compute effects with CI
    results = compute_effects_with_ci(df)

    # Print summary
    print("=" * 90)
    print("PREFETCH EFFECT (positive = slower)")
    print("=" * 90)
    display = results.copy()
    # pct_change is rounded together with its CI bounds so the three
    # percentage columns are printed with the same precision.
    for k in ['off_ms', 'on_ms', 'pct_change', 'ci_low', 'ci_high']:
        display[k] = display[k].round(1)
    print(display[['column_name', 'io_method', 'num_workers', 'evict', 'n',
                   'off_ms', 'on_ms', 'pct_change', 'ci_low', 'ci_high']].to_string(index=False))

    # Factor analysis works on the raw rows, not on the effect table.
    analyze_factors(df)

    # Generate plot
    plot_forest(results, prefix)

    # Save CSV (unrounded values)
    csv_path = f'{prefix}prefetch_summary.csv'
    results.to_csv(csv_path, index=False)
    print(f"Saved: {csv_path}")

    print("\nDone.")


# Run the analysis only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
