#!/usr/bin/env python3
"""
Visualization tool for Tier 1 micro-benchmark results
Generates plots comparing OLD (workspace) vs NEW (inplace) purge methods.

This version fixes the summary and speed-up plots to use keep_ratio≈0.5
rather than picking the best-case scenario (e.g. 0.9).
"""

import csv
import argparse
import sys
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

# -------------------- Utility Helpers --------------------

def approx(a, b, tol=1e-9):
    """Return True when ``a`` and ``b`` differ by at most ``tol`` (absolute)."""
    gap = abs(a - b)
    return gap <= tol

def safe_improvement(old, new):
    """Return how much smaller ``new`` is than ``old``, as a percentage of ``old``.

    A positive result means ``new`` is lower than ``old``; a negative result
    means it is higher. NaN is returned when ``old`` is not strictly positive,
    because the ratio is undefined there.
    """
    if old > 0:
        return 100 * (old - new) / old
    return float('nan')

def fmt_size_short(x):
    """Format a count compactly: millions as ``'<n>M'``, thousands as ``'<n>k'``.

    The scaled value is rendered with zero decimals; values below 1000 are
    returned via ``str(x)`` unchanged.
    """
    # Largest unit first so that e.g. 2_000_000 becomes '2M', not '2000k'.
    for unit, suffix in ((1_000_000, 'M'), (1_000, 'k')):
        if x >= unit:
            return f'{x/unit:.0f}{suffix}'
    return str(x)

def read_csv(filepath):
    """Load benchmark rows from a CSV file into a list of dicts.

    Each returned dict has the keys ``method``, ``xcnt``, ``keep_ratio``,
    ``distribution``, ``mean_ns``, ``median_ns``, ``p95_ns``, ``ci_lower_ns``,
    ``ci_upper_ns``, ``survivors`` and ``bytes_total``. The byte counter is
    read from the ``bytes_total`` column, falling back to ``bytes_moved``,
    and defaults to 0 when neither column holds a usable number.

    Args:
        filepath: Path of the CSV file to read.

    Returns:
        A list with one dict per data row, in file order.

    Raises:
        KeyError: If a required column (anything but the byte counter) is
            missing from the header.
        ValueError: If a required numeric column cannot be parsed.
    """
    def _to_int(row, *names, default=0):
        """Return the first of ``names`` in ``row`` that parses as an integer."""
        for n in names:
            if n in row and row[n] not in (None, ""):
                try:
                    return int(float(row[n]))
                except (ValueError, OverflowError):
                    # Unparseable text or a non-finite value ("nan"/"inf"):
                    # fall through to the next alias instead of crashing.
                    continue
        return default

    data = []
    # newline='' is what the csv module expects from its input file, and
    # utf-8-sig drops a leading BOM that would otherwise corrupt the first
    # header name (turning 'method' into '\ufeffmethod').
    with open(filepath, newline='', encoding='utf-8-sig') as f:
        reader = csv.DictReader(f)
        for row in reader:
            data.append({
                'method': row['method'],
                'xcnt': int(row['xcnt']),
                'keep_ratio': float(row['keep_ratio']),
                'distribution': row['distribution'],
                'mean_ns': float(row['mean_ns']),
                'median_ns': float(row['median_ns']),
                'p95_ns': float(row['p95_ns']),
                'ci_lower_ns': float(row['ci_lower_ns']),
                'ci_upper_ns': float(row['ci_upper_ns']),
                'survivors': int(row['survivors']),
                'bytes_total': _to_int(row, 'bytes_total', 'bytes_moved'),
            })
    return data

def group_by_params(data):
    """Index rows by ``(xcnt, keep_ratio, distribution)`` and then by method.

    Returns a dict mapping each parameter triple to a ``{method: row}`` dict.
    When several rows share the same triple and method, the last one wins.
    """
    by_params = {}
    for record in data:
        params = (record['xcnt'], record['keep_ratio'], record['distribution'])
        by_params.setdefault(params, {})[record['method']] = record
    return by_params

# -------------------- Plot Functions --------------------

def plot_time_vs_xcnt(data, output_dir):
    """Draw a 2x2 grid of purge-time-vs-size curves, one panel per keep_ratio.

    Only rows whose distribution is ``'scattered'`` are used. Each panel
    compares the mean time (converted from ns to µs) of the ``workspace`` and
    ``inplace`` methods on log-log axes; a panel that lacks either method
    shows a "No data" placeholder instead. The figure is written to
    ``output_dir / 'plot1_time_vs_xcnt.png'``.
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle('Purge Time vs Array Size (scattered distribution)', fontsize=16, fontweight='bold')

    # axes.flat walks the grid row by row, matching the order of the ratios.
    for ax, keep_ratio in zip(axes.flat, (0.9, 0.5, 0.1, 0.01)):
        samples = {'workspace': [], 'inplace': []}
        for row in data:
            if not approx(row['keep_ratio'], keep_ratio) or row['distribution'] != 'scattered':
                continue
            bucket = samples.get(row['method'])
            if bucket is not None:
                bucket.append((row['xcnt'], row['mean_ns'] / 1000.0))

        old_pts = sorted(samples['workspace'])
        new_pts = sorted(samples['inplace'])
        if not (old_pts and new_pts):
            ax.text(0.5, 0.5, 'No data', transform=ax.transAxes, ha='center', va='center')
            continue

        xo, yo = zip(*old_pts)
        xn, yn = zip(*new_pts)
        ax.plot(xo, yo, '-o', label='OLD (workspace)', linewidth=2, markersize=7)
        ax.plot(xn, yn, '-^', label='NEW (inplace)', linewidth=2, markersize=7)
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_title(f'keep_ratio = {keep_ratio} ({int(keep_ratio*100)}% survivors)')
        ax.set_xlabel('Array Size (xcnt)')
        ax.set_ylabel('Time (µs)')
        ax.grid(True, alpha=0.3)
        ax.legend()

    fig.tight_layout()
    out = output_dir / 'plot1_time_vs_xcnt.png'
    fig.savefig(out, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"✓ Saved: {out}")

def plot_improvement_heatmap(data, output_dir):
    """Render the inplace-vs-workspace improvement as an xcnt × keep_ratio heatmap.

    Only the ``'scattered'`` distribution is considered. Each cell holds
    ``safe_improvement(workspace mean, inplace mean)``; cells without both
    methods stay NaN and are drawn in light gray. Nothing is written when no
    cell has a value; otherwise the image goes to
    ``output_dir / 'plot2_improvement_heatmap.png'``.
    """
    groups = group_by_params(data)

    # Axis values: every size and every keep ratio present in the data.
    xcnts = sorted({row['xcnt'] for row in data})
    keep_ratios = sorted({row['keep_ratio'] for row in data})

    # Rows follow keep_ratios, columns follow xcnts; NaN marks "no measurement".
    improvement_matrix = np.full((len(keep_ratios), len(xcnts)), np.nan)
    for r, keep_ratio in enumerate(keep_ratios):
        for c, xcnt in enumerate(xcnts):
            methods = groups.get((xcnt, keep_ratio, 'scattered'), {})
            if 'workspace' in methods and 'inplace' in methods:
                improvement_matrix[r, c] = safe_improvement(
                    methods['workspace']['mean_ns'],
                    methods['inplace']['mean_ns'],
                )

    if np.isnan(improvement_matrix).all():
        print("⚠️  No data for heatmap; skipping")
        return

    # Mask the NaN cells so imshow paints them with the colormap's "bad" color.
    masked = np.ma.masked_invalid(improvement_matrix)
    vmax = np.nanmax(improvement_matrix)

    fig, ax = plt.subplots(figsize=(12, 6))
    im = ax.imshow(masked, cmap='RdYlGn', aspect='auto', vmin=0, vmax=max(30, vmax))
    im.cmap.set_bad(color='#eeeeee')  # Gray for missing data

    ax.set_xticks(np.arange(len(xcnts)))
    ax.set_yticks(np.arange(len(keep_ratios)))
    ax.set_xticklabels([fmt_size_short(size) for size in xcnts])
    ax.set_yticklabels([f'{ratio:.2f}' for ratio in keep_ratios])

    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label('Improvement (%)', rotation=270, labelpad=20, fontsize=12)

    # Print the numeric value inside every cell that has one.
    for (r, c), val in np.ndenumerate(improvement_matrix):
        if np.isnan(val):
            continue
        ax.text(c, r, f'{val:.1f}%', ha="center", va="center", color="black", fontsize=9)

    ax.set_xlabel('Array Size (xcnt)', fontsize=12)
    ax.set_ylabel('Keep Ratio (fraction of survivors)', fontsize=12)
    ax.set_title('Improvement: NEW (inplace) vs OLD (workspace) - Scattered Distribution',
                 fontsize=14, fontweight='bold')

    fig.tight_layout()
    out = output_dir / 'plot2_improvement_heatmap.png'
    fig.savefig(out, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"✓ Saved: {out}")

def plot_memory_traffic(data, output_dir):
    """Plot memory traffic of both methods at keep_ratio≈0.5 (scattered).

    Draws one pair of bars (workspace vs inplace, in MB) per array size.
    Only sizes for which BOTH methods were measured are shown, so tick
    positions, tick labels and bar heights always line up. Nothing is written
    when no such size exists; otherwise the image goes to
    ``output_dir / 'plot3_memory_traffic.png'``.
    """
    groups = group_by_params(data)

    # Select the groups through approx() so that a keep_ratio that is only
    # approximately 0.5 is matched the same way everywhere in this function.
    paired = {}
    for (xcnt, keep_ratio, dist), methods in groups.items():
        if (dist == 'scattered' and approx(keep_ratio, 0.5)
                and 'workspace' in methods and 'inplace' in methods):
            paired[xcnt] = methods

    if not paired:
        print("⚠️  No data for memory traffic plot; skipping")
        return

    xcnts = sorted(paired)
    bytes_per_mb = 1024 * 1024
    old_bytes = [paired[n]['workspace']['bytes_total'] / bytes_per_mb for n in xcnts]
    new_bytes = [paired[n]['inplace']['bytes_total'] / bytes_per_mb for n in xcnts]

    # Create bar plot
    fig, ax = plt.subplots(figsize=(10, 6))
    x = np.arange(len(xcnts))
    width = 0.35

    bars1 = ax.bar(x - width/2, old_bytes, width, label='OLD (workspace)', color='#d62728')
    bars2 = ax.bar(x + width/2, new_bytes, width, label='NEW (inplace)', color='#2ca02c')

    ax.set_xlabel('Array Size (xcnt)', fontsize=12)
    ax.set_ylabel('Memory Traffic (MB)', fontsize=12)
    ax.set_title('Memory Traffic: OLD vs NEW (keep_ratio=0.5, scattered)', fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels([fmt_size_short(xcnt) for xcnt in xcnts])
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3, axis='y')

    # Add value labels on bars (skip very small values to avoid overlap)
    for bars in (bars1, bars2):
        for bar in bars:
            height = bar.get_height()
            if height >= 0.005:  # Label only bars of at least 0.005 MB (~5 KB)
                ax.text(bar.get_x() + bar.get_width()/2., height, f'{height:.2f}',
                        ha='center', va='bottom', fontsize=9)

    fig.tight_layout()
    out = output_dir / 'plot3_memory_traffic.png'
    fig.savefig(out, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"✓ Saved: {out}")

def plot_distribution_comparison(data, output_dir, xcnt=10000, keep_ratio=0.5):
    """Compare scattered vs contiguous distributions for one scenario.

    Draws two panels (OLD/workspace and NEW/inplace), each with one bar per
    distribution showing the mean time in microseconds. A combination that
    was not measured is marked "n/a" in its slot instead of breaking the bar
    chart. Nothing is written when the scenario has no data at all;
    otherwise the image goes to
    ``output_dir / 'plot4_distribution_comparison.png'``.

    Args:
        data: Rows as returned by ``read_csv``.
        output_dir: Directory (``pathlib.Path``) that receives the image.
        xcnt: Array size of the scenario to show (default 10000).
        keep_ratio: Keep ratio of the scenario; matched with ``approx``
            (default 0.5).
    """
    groups = group_by_params(data)
    distributions = ['scattered', 'contiguous']
    colors = {'scattered': '#1f77b4', 'contiguous': '#ff7f0e'}
    panels = [('workspace', 'OLD Method (workspace)'),
              ('inplace', 'NEW Method (inplace)')]

    def _find(dist, method):
        """Return the row for this scenario, distribution and method, or None."""
        for (n, kr, d), methods in groups.items():
            if n == xcnt and d == dist and approx(kr, keep_ratio) and method in methods:
                return methods[method]
        return None

    found = {(method, dist): _find(dist, method)
             for method, _ in panels for dist in distributions}

    if all(row is None for row in found.values()):
        print(f"⚠️  No data for distribution comparison (xcnt={xcnt}, keep_ratio={keep_ratio}); skipping")
        return

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    fig.suptitle(f'Distribution Impact (xcnt={fmt_size_short(xcnt)}, keep_ratio={keep_ratio})',
                 fontsize=14, fontweight='bold')

    slots = np.arange(len(distributions))
    for ax, (method, title) in zip(axes, panels):
        # Each distribution keeps a fixed slot and color, so the two panels
        # stay comparable even when one measurement is missing.
        for pos, dist in zip(slots, distributions):
            row = found[(method, dist)]
            if row is None:
                ax.text(pos, 0, 'n/a', ha='center', va='bottom', fontsize=10, color='gray')
                continue
            micros = row['mean_ns'] / 1000
            ax.bar(pos, micros, color=colors[dist])
            ax.text(pos, micros, f'{micros:.0f}', ha='center', va='bottom', fontsize=10)

        ax.set_ylabel('Time (microseconds)', fontsize=11)
        ax.set_title(title, fontsize=12)
        ax.set_xticks(slots)
        ax.set_xticklabels(distributions)
        # Pin the x-range so that an empty slot is still visible.
        ax.set_xlim(-0.5, len(distributions) - 0.5)
        ax.grid(True, alpha=0.3, axis='y')

    fig.tight_layout()
    out = output_dir / 'plot4_distribution_comparison.png'
    fig.savefig(out, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"✓ Saved: {out}")

def plot_summary_dashboard(data, output_dir):
    """Create the summary dashboard for the balanced (keep_ratio≈0.5) scenario.

    The dashboard uses the ``'scattered'`` distribution only and contains:
    time vs array size, per-size improvement bars, memory traffic bars, and a
    text box with key statistics. If no keep_ratio close to 0.5 exists, the
    keep_ratio nearest to 0.5 is used and a warning is printed. Nothing is
    written when ``data`` is empty; otherwise the image goes to
    ``output_dir / 'plot5_summary_dashboard.png'``.
    """
    if not data:
        print("⚠️  No data for summary dashboard; skipping")
        return

    groups = group_by_params(data)
    kr_target = 0.5  # fixed mid-point scenario
    ratios = sorted(set(row['keep_ratio'] for row in data))
    kr_used = next((kr for kr in ratios if approx(kr, kr_target)), None)
    if kr_used is None:
        # Fall back to the ratio that is actually nearest to the target.
        kr_used = min(ratios, key=lambda kr: abs(kr - kr_target))
        print(f"⚠️ keep_ratio=0.5 not found, using closest: {kr_used}")

    # Sizes for which BOTH methods were measured in the chosen scenario. All
    # paired comparisons below use this set, so bars, labels and statistics
    # always refer to the same sizes.
    paired = {
        xcnt: methods
        for (xcnt, kr, dist), methods in groups.items()
        if dist == 'scattered' and approx(kr, kr_used)
        and 'workspace' in methods and 'inplace' in methods
    }
    sizes = sorted(paired)

    fig = plt.figure(figsize=(16, 10))
    gs = gridspec.GridSpec(3, 3, figure=fig, hspace=0.3, wspace=0.3)
    fig.suptitle('Tier 1 Micro-Benchmark: Summary Dashboard', fontsize=18, fontweight='bold', y=0.98)

    # --- Plot 1: Performance vs Array Size ---
    ax1 = fig.add_subplot(gs[0, :2])
    points = {'workspace': [], 'inplace': []}
    for row in data:
        # Rows of any other method are ignored rather than counted as NEW.
        if (approx(row['keep_ratio'], kr_used) and row['distribution'] == 'scattered'
                and row['method'] in points):
            points[row['method']].append((row['xcnt'], row['mean_ns'] / 1000))

    if points['workspace'] and points['inplace']:
        xo, yo = zip(*sorted(points['workspace']))
        xn, yn = zip(*sorted(points['inplace']))
        ax1.plot(xo, yo, '-o', label='OLD (workspace)', linewidth=2, markersize=7)
        ax1.plot(xn, yn, '-^', label='NEW (inplace)', linewidth=2, markersize=7)
        ax1.set_xscale('log')
        ax1.set_yscale('log')
        ax1.set_xlabel('Array Size (xcnt)')
        ax1.set_ylabel('Time (µs)')
        ax1.set_title(f'Performance vs Array Size (keep_ratio≈{kr_used})', fontsize=12, fontweight='bold')
        ax1.legend()
        ax1.grid(True, alpha=0.3)
    else:
        ax1.text(0.5, 0.5, 'No comparable data', ha='center', va='center', transform=ax1.transAxes)

    # --- Plot 2: Speed-up by Size ---
    ax2 = fig.add_subplot(gs[0, 2])
    speedups = []
    for xcnt in sizes:
        imp = safe_improvement(paired[xcnt]['workspace']['mean_ns'],
                               paired[xcnt]['inplace']['mean_ns'])
        if not np.isnan(imp):  # undefined when the OLD time is not positive
            speedups.append((xcnt, imp))

    if speedups:
        rows_y = range(len(speedups))
        ax2.barh(rows_y, [imp for _, imp in speedups], color='#2ca02c')
        ax2.set_yticks(rows_y)
        ax2.set_yticklabels([fmt_size_short(xcnt) for xcnt, _ in speedups])
        ax2.axvline(x=25, color='red', linestyle='--', linewidth=1, label='Target (25%)')
        ax2.legend(fontsize=8)
        ax2.set_xlabel('Improvement (%)')
        ax2.set_title('Speedup by Size', fontsize=11, fontweight='bold')
        ax2.grid(True, alpha=0.3, axis='x')
    else:
        ax2.text(0.5, 0.5, 'No comparable data', transform=ax2.transAxes, ha='center')

    # --- Plot 3: Memory Traffic Comparison ---
    ax3 = fig.add_subplot(gs[1, :])
    if sizes:
        old_mb = [paired[n]['workspace']['bytes_total'] / 1024 / 1024 for n in sizes]
        new_mb = [paired[n]['inplace']['bytes_total'] / 1024 / 1024 for n in sizes]
        pos = np.arange(len(sizes))
        ax3.bar(pos - 0.35/2, old_mb, 0.35, label='OLD')
        ax3.bar(pos + 0.35/2, new_mb, 0.35, label='NEW')
        ax3.set_xticks(pos)
        ax3.set_xticklabels([fmt_size_short(n) for n in sizes])
        ax3.set_xlabel('Array Size')
        ax3.set_ylabel('Memory Traffic (MB)')
        ax3.set_title(f'Memory Traffic Comparison (keep_ratio≈{kr_used})', fontsize=12, fontweight='bold')
        ax3.grid(True, axis='y', alpha=0.3)
        ax3.legend()
    else:
        ax3.text(0.5, 0.5, 'No data', transform=ax3.transAxes, ha='center')

    # --- Plot 4: Summary Statistics ---
    ax4 = fig.add_subplot(gs[2, :])
    ax4.axis('off')

    valid = [imp for _, imp in speedups]
    if valid:
        avg_imp = float(np.mean(valid))
        min_imp = min(valid)
        max_imp = max(valid)
    else:
        avg_imp = min_imp = max_imp = float('nan')

    # Memory traffic reduction per size; sizes whose OLD byte count is not
    # positive are skipped because the percentage is undefined there.
    reductions = {}
    for xcnt in sizes:
        old_b = paired[xcnt]['workspace']['bytes_total']
        new_b = paired[xcnt]['inplace']['bytes_total']
        if old_b > 0:
            reductions[xcnt] = 100.0 * (old_b - new_b) / old_b

    # Build memory traffic line with both average and largest-size reduction
    if reductions:
        mem_reduct_avg = float(np.mean(list(reductions.values())))
        mem_line = f"• Memory Traffic:       {mem_reduct_avg:.0f}% avg reduction"
        largest_reduct = reductions.get(sizes[-1])
        if largest_reduct is not None and abs(largest_reduct - mem_reduct_avg) > 1.0:
            mem_line += f" ({largest_reduct:.0f}% at largest size)"
    else:
        # Do not report a made-up 0% when nothing could be measured.
        mem_line = "• Memory Traffic:       n/a"

    # NOTE: the check is based on the average over all paired sizes.
    check = '[PASS]' if avg_imp >= 25 else '[WARN]'

    # Only state the positive conclusion when the data supports it: every
    # paired size got faster, every measured size moved fewer bytes, and the
    # scenario really is the balanced one.
    consistent = (bool(valid) and min_imp > 0
                  and bool(reductions) and min(reductions.values()) > 0
                  and approx(kr_used, kr_target))
    if consistent:
        conclusion = (
            "The NEW (inplace) method consistently reduces time and memory traffic \n"
            "under balanced (50%) survivor workloads, making it a robust general optimization."
        )
    else:
        conclusion = (
            f"Results at keep_ratio≈{kr_used} do not show a consistent time and memory-traffic\n"
            "reduction for every measured size; review the per-size panels above."
        )

    text = f"""
KEY FINDINGS (distribution=scattered, keep_ratio≈{kr_used}):
• Average Improvement:  {avg_imp:.1f}%
• Improvement Range:    {min_imp:.1f}% to {max_imp:.1f}%
{mem_line}

SUCCESS CRITERIA:
{check} Target: ≥25% improvement at larger sizes

CONCLUSION:
{conclusion}
"""
    ax4.text(0.5, 0.5, text, transform=ax4.transAxes,
             fontsize=11, verticalalignment='center',
             horizontalalignment='center',
             bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3),
             family='monospace')

    out = output_dir / 'plot5_summary_dashboard.png'
    fig.savefig(out, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"✓ Saved: {out}")

# -------------------- Main Entry --------------------

def main():
    """Parse command-line arguments, load the CSV and generate all plots.

    Options:
        --csv  Path to the benchmark CSV (required).
        --out  Output directory for the PNG files (default: ./plots).

    Exits with status 1, with a message on stderr, when the CSV path is not
    an existing file or the file contains no data rows.
    """
    parser = argparse.ArgumentParser(description="Visualize Tier 1 micro-benchmark results")
    parser.add_argument('--csv', type=Path, required=True, help='Path to benchmark CSV')
    parser.add_argument('--out', type=Path, default=Path('./plots'), help='Output directory')
    args = parser.parse_args()

    csv_file = args.csv
    output_dir = args.out
    # is_file() also rejects a directory, which would otherwise fail later
    # inside open() with a less helpful error.
    if not csv_file.is_file():
        print(f"Error: CSV not found: {csv_file}", file=sys.stderr)
        sys.exit(1)

    data = read_csv(csv_file)
    if not data:
        print(f"Error: no data rows in CSV: {csv_file}", file=sys.stderr)
        sys.exit(1)

    # Create the output directory only once there is something to plot.
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Loaded {len(data)} data points\n")
    print("Generating visualizations...\n")

    plot_time_vs_xcnt(data, output_dir)
    plot_improvement_heatmap(data, output_dir)
    plot_memory_traffic(data, output_dir)
    plot_distribution_comparison(data, output_dir)
    plot_summary_dashboard(data, output_dir)

    print("\n" + "="*80)
    print(f"✓ All plots saved to: {output_dir}")
    print()
    print("Generated plots:")
    print("  1. plot1_time_vs_xcnt.png            - Time vs array size for different keep_ratios")
    print("  2. plot2_improvement_heatmap.png     - Improvement heatmap (xcnt × keep_ratio)")
    print("  3. plot3_memory_traffic.png          - Memory traffic comparison")
    print("  4. plot4_distribution_comparison.png - Scattered vs contiguous impact")
    print("  5. plot5_summary_dashboard.png       - Complete dashboard with key metrics")
    print("="*80)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()