#!/usr/bin/env python3
"""
Compare function profiles between prefetch OFF and ON configurations.
Parses profile files named: profile-<role>,<prefetch>,<distance>.<pid>.txt

Identifies functions with significant runtime differences when prefetch is enabled.
"""

import os
import sys
import re
import glob
import argparse
from collections import defaultdict
from statistics import median, stdev, mean, variance
import math

def parse_profile_filename(filepath):
    """
    Parse a profile filename to extract role, prefetch, distance and pid.

    Format: profile-<role>,<prefetch>,<distance>.<pid>.txt
    Example: profile-backend,on,128.12345.txt

    Only the basename of ``filepath`` is examined, and the whole basename
    must match the format; names with trailing text after ``.txt`` (for
    example ``.txt.bak``) are rejected.

    Returns a dict with keys 'role' (str), 'prefetch' (str),
    'distance' (int, may be negative) and 'pid' (int), or None when the
    basename does not match.
    """
    basename = os.path.basename(filepath)
    # fullmatch anchors both ends, so trailing garbage cannot slip through.
    match = re.fullmatch(r'profile-([^,]+),([^,]+),(-?\d+)\.(\d+)\.txt', basename)
    if match is None:
        return None

    role, prefetch, distance, pid = match.groups()
    return {
        'role': role,
        'prefetch': prefetch,
        'distance': int(distance),
        'pid': int(pid)
    }

def parse_profile(filepath):
    """
    Read one profile file and return a dict of func -> (calls, total_ms).

    Expected row layout (whitespace separated):
        function_name   calls   total_ms   avg_us   min_us   max_us   pct%

    Lines starting with '#', '-' or 'Function', blank lines, rows with
    fewer than four fields, and rows whose calls/total_ms fields are not
    numeric are ignored. If a function name appears more than once, the
    last row wins.
    """
    skipped_prefixes = ('#', '-', 'Function')
    parsed = {}
    with open(filepath) as handle:
        for raw in handle:
            # Comment, separator and header lines carry no data.
            if raw.startswith(skipped_prefixes):
                continue
            # Blank lines split into an empty list and are dropped here too.
            fields = raw.split()
            if len(fields) < 4:
                continue
            name, calls_text, total_text = fields[:3]
            try:
                parsed[name] = (int(calls_text), float(total_text))
            except ValueError:
                # Non-numeric columns: not a data row, skip it.
                continue
    return parsed

def load_profiles_by_config(directory, role_filter='backend'):
    """
    Load all profiles under a directory, grouped by (prefetch, distance).

    Files named profile-*.txt are searched for in ``directory`` itself and
    recursively in its subdirectories. Files whose name cannot be parsed by
    parse_profile_filename() are ignored. When ``role_filter`` is truthy,
    only files whose role equals it are loaded.

    Returns dict of (prefetch, distance) -> {func -> list of (calls, total_ms)},
    with one list entry per profile file that contains the function.
    """
    configs = defaultdict(lambda: defaultdict(list))

    patterns = [
        os.path.join(directory, 'profile-*.txt'),
        os.path.join(directory, '**', 'profile-*.txt')
    ]

    # With recursive=True, '**' also matches zero directories, so the second
    # pattern returns the top-level files again. Collect into a set so each
    # file contributes exactly one sample instead of being counted twice.
    files = set()
    for pattern in patterns:
        files.update(os.path.normpath(p) for p in glob.glob(pattern, recursive=True))

    # Sorted so that the load order does not depend on filesystem order.
    for filepath in sorted(files):
        meta = parse_profile_filename(filepath)
        if meta is None:
            continue

        # Filter by role (usually we want 'backend' profiles)
        if role_filter and meta['role'] != role_filter:
            continue

        key = (meta['prefetch'], meta['distance'])
        for func, sample in parse_profile(filepath).items():
            configs[key][func].append(sample)

    return configs

def calculate_z_statistic(sample1, sample2):
    """
    Return the Welch-style test statistic for the difference of two means.

    The value is (mean(sample2) - mean(sample1)) divided by the standard
    error sqrt(var1/n1 + var2/n2), using sample variances. A positive
    result means sample2 has the larger mean.

    Returns 0 when either sample has fewer than two values or when the
    standard error is not positive.
    """
    size_a = len(sample1)
    size_b = len(sample2)
    if min(size_a, size_b) < 2:
        return 0

    squared_se = variance(sample1) / size_a + variance(sample2) / size_b
    # "not > 0" (rather than "<= 0") also rejects NaN.
    if not squared_se > 0:
        return 0

    return (mean(sample2) - mean(sample1)) / math.sqrt(squared_se)

def aggregate_func_stats(func_data, exclude_outliers=True, n_outliers=3):
    """
    Aggregate per-profile samples of one function.

    Args:
        func_data: list of (calls, total_ms) tuples, one per profile.
        exclude_outliers: when true, drop the slowest samples before
            aggregating.
        n_outliers: how many of the slowest samples to drop. Trimming only
            happens when there are more than ``n_outliers`` samples, so at
            least one sample always remains. Values <= 0 disable trimming.

    Returns:
        (total_calls, total_time, avg_time, times) where ``times`` is the
        list of total_ms values that were kept (sorted ascending when
        trimming was applied). Returns (0, 0, 0, []) for empty input.
    """
    if not func_data:
        return 0, 0, 0, []

    samples = list(func_data)

    if exclude_outliers and n_outliers > 0 and len(samples) > n_outliers:
        # Sort the (calls, time) pairs together by time so the call counts
        # that are dropped belong to the same runs as the dropped times.
        samples = sorted(samples, key=lambda sample: sample[1])[:-n_outliers]

    times = [t for (_, t) in samples]
    calls = [c for (c, _) in samples]

    total_time = sum(times)
    total_calls = sum(calls)
    avg_time = mean(times) if times else 0

    return total_calls, total_time, avg_time, times

def compare_configs(base_funcs, test_funcs, base_label, test_label):
    """
    Build one comparison record per function seen in either configuration.

    ``base_funcs`` and ``test_funcs`` map func -> list of (calls, total_ms).
    Functions whose aggregated time is below 1 in both configurations are
    left out. ``base_label`` and ``test_label`` are accepted but not used
    here.

    Each record is a dict with keys 'func', 'base_time', 'test_time',
    'base_calls', 'test_calls', 'time_diff', 'pct_change', 'z_score',
    'n_base' and 'n_test'. 'pct_change' is float('inf') when the base time
    is zero and the test time is positive.

    NOTE(review): the times are sums over the kept samples, so 'time_diff'
    may be misleading if the two configurations have different sample
    counts -- confirm that callers provide balanced runs.
    """
    results = []
    for name in set(base_funcs.keys()) | set(test_funcs.keys()):
        b_calls, b_total, _b_avg, b_samples = aggregate_func_stats(base_funcs.get(name, []))
        t_calls, t_total, _t_avg, t_samples = aggregate_func_stats(test_funcs.get(name, []))

        # Negligible in both configurations: not worth reporting.
        if b_total < 1 and t_total < 1:
            continue

        delta = t_total - b_total
        if b_total > 0:
            relative = (delta / b_total) * 100
        else:
            # No base time: an infinite increase if the test spent any time.
            relative = float('inf') if t_total > 0 else 0

        results.append(dict(
            func=name,
            base_time=b_total,
            test_time=t_total,
            base_calls=b_calls,
            test_calls=t_calls,
            time_diff=delta,
            pct_change=relative,
            z_score=calculate_z_statistic(b_samples, t_samples),
            n_base=len(b_samples),
            n_test=len(t_samples),
        ))

    return results

def _format_pct_change(pct_change, inf_label):
    """Format a percentage as e.g. '+12.3%', or ``inf_label`` when it is infinite."""
    if pct_change == float('inf'):
        # A label such as "NEW" is not a percentage, so no '%' is appended.
        return inf_label
    return f"{pct_change:+.1f}%"


def print_comparison(comparisons, base_label, test_label):
    """
    Print a report of the significant differences between two configurations.

    ``comparisons`` is a list of dicts with at least the keys 'func',
    'base_time', 'test_time', 'time_diff', 'pct_change' and 'z_score'.
    A record is significant when |z_score| > 1.96. Up to 20 significant
    increases and 20 significant decreases are listed, each ordered by
    descending |z_score|, followed by a summary of the total times and
    counts. ``base_label`` and ``test_label`` are used as column headings.
    """
    significant = [c for c in comparisons if abs(c['z_score']) > 1.96]

    significant_increases = [c for c in significant if c['time_diff'] > 0]
    significant_increases.sort(key=lambda x: abs(x['z_score']), reverse=True)

    significant_decreases = [c for c in significant if c['time_diff'] < 0]
    significant_decreases.sort(key=lambda x: abs(x['z_score']), reverse=True)

    print("=" * 100)
    print(f"COMPARISON: {base_label} vs {test_label}")
    print("=" * 100)
    print()

    print(f"{'Function':<40} {base_label:>12} {test_label:>12} {'Diff':>10} {'%':>8} {'z-statistic':>8}")
    print("-" * 100)

    print("\n** SIGNIFICANT INCREASES (test is slower, |z| > 1.96) **")
    for c in significant_increases[:20]:
        pct = _format_pct_change(c['pct_change'], "NEW")
        print(f"{c['func']:<40} {c['base_time']:>12.1f} {c['test_time']:>12.1f} {c['time_diff']:>+10.1f} {pct:>8} {c['z_score']:>+8.2f}")

    print("\n** SIGNIFICANT DECREASES (test is faster, |z| > 1.96) **")
    for c in significant_decreases[:20]:
        pct = _format_pct_change(c['pct_change'], "GONE")
        print(f"{c['func']:<40} {c['base_time']:>12.1f} {c['test_time']:>12.1f} {c['time_diff']:>+10.1f} {pct:>8} {c['z_score']:>-8.2f}")

    # Summary
    total_base = sum(c['base_time'] for c in comparisons)
    total_test = sum(c['test_time'] for c in comparisons)

    print()
    print("-" * 100)
    print("SUMMARY")
    print("-" * 100)
    print(f"Total function time ({base_label}): {total_base:,.2f} ms")
    print(f"Total function time ({test_label}): {total_test:,.2f} ms")
    if total_base > 0:
        print(f"Difference: {total_test - total_base:+,.2f} ms ({(total_test - total_base) / total_base * 100:+.1f}%)")
    print()
    print(f"Functions analyzed: {len(comparisons)}")
    print(f"  Significant increases: {len(significant_increases)}")
    print(f"  Significant decreases: {len(significant_decreases)}")
    print()

def main():
    """
    Command-line entry point.

    Loads the profiles found under the given directory, picks a baseline
    configuration (by default prefetch='off', distance=0) and prints a
    comparison of every other configuration against it. Exits with
    status 1 when the directory does not exist or no profiles are found.
    """
    parser = argparse.ArgumentParser(description='Compare profiling results')
    parser.add_argument('directory', nargs='?', default='./profiling',
                       help='Directory containing profile files')
    parser.add_argument('--role', default='backend',
                       help='Filter by backend role (default: backend)')
    parser.add_argument('--base', default='off',
                       help='Base configuration prefetch setting (default: off)')
    parser.add_argument('--base-distance', type=int, default=0,
                       help='Base configuration distance (default: 0)')
    args = parser.parse_args()

    # isdir rather than exists: a regular file is not a usable profile directory.
    if not os.path.isdir(args.directory):
        print(f"Error: Directory not found: {args.directory}")
        sys.exit(1)

    print(f"Loading profiles from: {args.directory}")
    print(f"Filtering by role: {args.role}")
    print()

    configs = load_profiles_by_config(args.directory, args.role)

    if not configs:
        print("No profiles found!")
        sys.exit(1)

    print("Found configurations:")
    for (prefetch, distance), funcs in sorted(configs.items()):
        # Sample count is approximated by the first function's sample list.
        n_profiles = len(next(iter(funcs.values()))) if funcs else 0
        print(f"  prefetch={prefetch}, distance={distance}: {len(funcs)} functions, ~{n_profiles} samples")
    print()

    # Find baseline (prefetch=off or specified)
    requested_key = (args.base, args.base_distance)
    base_key = requested_key
    if base_key not in configs:
        # Fall back to any 'off' config, else to any config at all. min()
        # makes the choice independent of the order files were discovered in.
        off_keys = [k for k in configs if k[0] == 'off']
        base_key = min(off_keys) if off_keys else min(configs)
        print(f"Note: baseline prefetch={requested_key[0]}, distance={requested_key[1]} not found; "
              f"using prefetch={base_key[0]}, distance={base_key[1]} instead")
        print()

    base_funcs = configs[base_key]
    # Label reflects the baseline actually used, which may not be 'off'.
    base_label = f"{base_key[0]},d={base_key[1]}"

    # Compare every other config against the baseline
    for (prefetch, distance), test_funcs in sorted(configs.items()):
        if (prefetch, distance) == base_key:
            continue

        test_label = f"{prefetch},d={distance}"
        comparisons = compare_configs(base_funcs, test_funcs, base_label, test_label)
        print_comparison(comparisons, base_label, test_label)
        print("\n")

# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
