#!/usr/bin/env python3
"""Generate comparison charts from the patch-N benchmark CSV results."""

import pandas as pd
import matplotlib.pyplot as plt
import glob
import os

# Benchmark CSVs are read from, and charts written to, the "results"
# directory next to this script.
RESULTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'results')

# Only these steps appear in the comparison charts.
main_steps = [f'patch-{i}' for i in [1, 2, 3, 4, 5]]


def is_main_step_csv(path):
    """Return True if the file name of *path* contains one of ``main_steps``.

    Only the base name is inspected, so a parent directory that happens to
    contain e.g. "patch-3" does not make unrelated CSV files match.
    """
    file_name = os.path.basename(path)
    return any(step in file_name for step in main_steps)


# Find the result CSVs of the main steps. The directory part is escaped so
# glob metacharacters in the checkout path are taken literally.
csv_files = sorted(
    f for f in glob.glob(os.path.join(glob.escape(RESULTS_DIR), '*.csv'))
    if is_main_step_csv(f)
)

print(f"Found {len(csv_files)} step CSV files:")
for f in csv_files:
    print(f"  - {os.path.basename(f)}")

# Load every CSV as a wide table: one row per (pattern, num_buffers) and one
# column per benchmark holding its median_ns value.
if not csv_files:
    # Without this guard pd.concat([]) fails with an opaque
    # "No objects to concatenate" error.
    raise SystemExit(f"No benchmark CSV files for the main steps found in {RESULTS_DIR}")

all_data = []
for csv_file in csv_files:
    df = pd.read_csv(csv_file)
    # The step name is the file name without its extension.
    step_name = os.path.splitext(os.path.basename(csv_file))[0]

    # pivot_table (rather than pivot) tolerates repeated measurements; they
    # are combined with its default aggregation.
    pivot = df.pivot_table(
        index=['pattern', 'num_buffers'],
        columns='benchmark',
        values='median_ns',
    ).reset_index()
    pivot['step'] = step_name

    # Derived metric: read time with the resowner time subtracted.
    if 'read' in pivot.columns and 'resowner' in pivot.columns:
        pivot['read-resowner'] = pivot['read'] - pivot['resowner']

    all_data.append(pivot)

# One long table covering all steps; the 'step' column tells them apart.
data = pd.concat(all_data, ignore_index=True)

# Human-readable legend labels, keyed by step name.
step_labels = {
    'v4-pg18-baseline': 'PostgreSQL 18 Baseline',
    'patch-1': 'Patch 1: Pg19 Baseline',
    'patch-2': 'Patch 2: Refactoring',
    'patch-3': 'Patch 3: Simple hash',
    'patch-4': 'Patch 4: Compact entry',
    'patch-5': 'Patch 5: No array',
}

# Line color per step. Spelled out explicitly (instead of zipping the label
# keys with a color list) so that a step can never silently end up without a
# color when the two sequences differ in length.
colors = {
    'v4-pg18-baseline': '#e41a1c',
    'patch-1': '#377eb8',
    'patch-2': '#4daf4a',
    'patch-3': '#984ea3',
    'patch-4': '#ff7f00',
    'patch-5': '#a65628',
}

def plot_benchmark(benchmark_name, title, output_name):
    """Plot one benchmark for all main steps and save it as SVG and PNG.

    The figure has two side-by-side panels, one for the 'sequential' and one
    for the 'random' access pattern. Each panel shows the benchmark value
    against the number of buffers (log-scaled x axis), with one line per
    step in ``main_steps``.

    Args:
        benchmark_name: Column of ``data`` to plot.
        title: Figure title.
        output_name: File name stem; the charts are written to
            ``RESULTS_DIR/<output_name>.svg`` and ``.png``.

    Returns:
        None. If ``data`` has no such column, a message is printed and no
        files are written.
    """
    if benchmark_name not in data.columns:
        print(f"Skipping {benchmark_name}: column not found")
        return

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    fig.suptitle(title, fontsize=14, fontweight='bold')

    for ax, pattern in zip(axes, ['sequential', 'random']):
        ax.set_title(f'{pattern.capitalize()} Access Pattern')
        ax.set_xlabel('Number of Buffers')
        ax.set_ylabel('Time (ns)')
        ax.set_xscale('log')
        ax.grid(True, alpha=0.3)

        has_series = False
        for step in main_steps:
            subset = data[(data['pattern'] == pattern) & (data['step'] == step)]
            if subset.empty:
                continue
            ax.plot(subset['num_buffers'], subset[benchmark_name],
                    label=step_labels.get(step, step),
                    color=colors.get(step, 'gray'),
                    linewidth=2, marker='o', markersize=3)
            has_series = True

        # Only request a legend when there is something to list.
        if has_series:
            ax.legend(loc='upper left', fontsize=9)

    fig.tight_layout()

    # Build both paths from the common stem; rewriting '.svg' inside the full
    # path could also hit a directory name that contains it.
    base_path = f'{RESULTS_DIR}/{output_name}'
    svg_path = f'{base_path}.svg'
    fig.savefig(svg_path, format='svg', bbox_inches='tight')
    fig.savefig(f'{base_path}.png', format='png', dpi=150, bbox_inches='tight')
    print(f"Saved: {svg_path}")

    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

# Render one comparison figure per benchmark column:
# (column in `data`, figure title, output file stem).
chart_specs = (
    ('read-resowner', 'Read/Release excluding Resowner', 'compare-read-resowner'),
    ('read', 'Read/Release Performance Comparison', 'compare-read'),
    ('resowner', 'ResourceOwner Performance Comparison', 'compare-resowner'),
    ('pinning', 'Pin/Unpin Performance Comparison', 'compare-pinning'),
    ('locking', 'Lock/Unlock Performance Comparison', 'compare-locking'),
)
for column, chart_title, file_stem in chart_specs:
    plot_benchmark(column, chart_title, file_stem)

print("\nDone! Generated comparison charts.")
