/*-------------------------------------------------------------------------
 *
 * snapbuild_bench.c
 *    Microbenchmark for SnapBuildPurgeOlderTxn optimization
 *
 * This extension provides a unit-level benchmark for the committed.xip
 * purging logic in SnapBuildPurgeOlderTxn. It compares:
 *   - OLD method: allocate workspace + memcpy (lines 874-897 in snapbuild.c)
 *   - NEW method: in-place two-pointer compaction (from patch)
 *
 * Uses PostgreSQL's own types (TransactionId arrays) and XID comparison macro
 * (NormalTransactionIdPrecedes), so the timed loops match the server code.
 *
 *-------------------------------------------------------------------------
 */

#include "postgres.h"
#include "fmgr.h"
#include "funcapi.h"
#include "access/transam.h"
#include "utils/builtins.h"
#include "utils/memutils.h"
#include "portability/instr_time.h"
#include "common/pg_prng.h"
#include <math.h>
#include <string.h>

PG_MODULE_MAGIC;

/*
 * Pre-patch purge strategy, modelled on the workspace-based code that
 * SnapBuildPurgeOlderTxn() in snapbuild.c used before the patch.
 *
 * Every XID in xip[0..xcnt-1] that does not precede xmin is copied, in order,
 * into a scratch buffer allocated in "context"; the survivors are then copied
 * back to the front of xip and the scratch buffer is freed.
 *
 * Returns the number of surviving XIDs.  The out-parameters receive the
 * estimated memory traffic of this call, in bytes:
 *   *bytes_read_out    = (xcnt scanned + survivors read back from scratch)
 *   *bytes_written_out = (survivors stored to scratch + survivors copied back)
 * each multiplied by sizeof(TransactionId).
 */
static int
purge_old_method(TransactionId *xip, int xcnt, TransactionId xmin,
				 MemoryContext context, int64 *bytes_read_out, int64 *bytes_written_out)
{
	/* The scratch allocation is deliberately part of the timed work. */
	TransactionId *scratch = MemoryContextAlloc(context,
												xcnt * sizeof(TransactionId));
	int			kept = 0;
	int			idx = 0;

	while (idx < xcnt)
	{
		TransactionId candidate = xip[idx++];

		/* Keep anything at or after the horizon; drop older XIDs. */
		if (!NormalTransactionIdPrecedes(candidate, xmin))
			scratch[kept++] = candidate;
	}

	/* Publish the survivors back into the caller's array. */
	memcpy(xip, scratch, kept * sizeof(TransactionId));
	pfree(scratch);

	*bytes_read_out = (int64) (xcnt + kept) * sizeof(TransactionId);
	*bytes_written_out = (int64) kept * 2 * sizeof(TransactionId);

	return kept;
}

/*
 * Patched purge strategy: stable in-place compaction.
 *
 * Walks xip[0..xcnt-1] once with a read index "src"; each XID that does not
 * precede xmin is stored at the write index "dst", so survivors end up packed
 * at the front of the array in their original relative order.  No extra
 * memory is allocated.
 *
 * Returns the number of surviving XIDs.  Traffic estimates, in bytes:
 *   *bytes_read_out    = xcnt * sizeof(TransactionId)       (single scan)
 *   *bytes_written_out = survivors * sizeof(TransactionId)  (one store each)
 */
static int
purge_new_method(TransactionId *xip, int xcnt, TransactionId xmin,
				 int64 *bytes_read_out, int64 *bytes_written_out)
{
	int			dst = 0;
	int			src;

	for (src = 0; src < xcnt; ++src)
	{
		if (NormalTransactionIdPrecedes(xip[src], xmin))
			continue;			/* older than the horizon: drop it */
		xip[dst++] = xip[src];
	}

	*bytes_read_out = (int64) xcnt * sizeof(TransactionId);
	*bytes_written_out = (int64) dst * sizeof(TransactionId);

	return dst;
}

/*
 * Generate a committed.xip-like test array with a controlled survivor count.
 *
 * xip:          caller-allocated output array of length xcnt
 * xcnt:         total number of XIDs to generate (>= 0)
 * keep_ratio:   fraction of XIDs that must survive the purge, 0.0 .. 1.0;
 *               the survivor count is xcnt * keep_ratio rounded to nearest
 * distribution: 'contiguous' - purgeable XIDs first, all keepers at the end
 *                              (easiest case for the branch predictor)
 *               'scattered'  - the same XIDs, Fisher-Yates shuffled so that
 *                              keepers and victims are interleaved (hardest
 *                              case for the branch predictor)
 *               compared case-insensitively; anything else raises ERROR
 * xmin_out:     receives the purge horizon; exactly the target number of
 *               generated XIDs are >= *xmin_out
 * seed:         PRNG seed, making the scattered layout reproducible
 *
 * Generated XIDs start at FirstNormalTransactionId.  Real committed.xip
 * entries are normal XIDs, and NormalTransactionIdPrecedes() asserts that
 * both of its arguments are normal, so the special XIDs 1 and 2 must not
 * appear in the array or as the horizon.
 */
static void
generate_test_input(TransactionId *xip, int xcnt, double keep_ratio,
					const char *distribution, TransactionId *xmin_out,
					unsigned int seed)
{
	int			survivors_target = (int) (xcnt * keep_ratio + 0.5);
	int			removals_target = xcnt - survivors_target;
	bool		scattered = (pg_strcasecmp(distribution, "scattered") == 0);
	int			i;

	Assert(xcnt >= 0);
	Assert(keep_ratio >= 0.0 && keep_ratio <= 1.0);

	/* Validate distribution parameter */
	if (!scattered && pg_strcasecmp(distribution, "contiguous") != 0)
		ereport(ERROR,
				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
				 errmsg("distribution must be 'contiguous' or 'scattered'")));

	/*
	 * Sequential normal XIDs: the first removals_target values precede xmin
	 * and will be purged; the remaining survivors_target values are >= xmin.
	 * This is already the 'contiguous' layout.
	 */
	for (i = 0; i < xcnt; i++)
		xip[i] = FirstNormalTransactionId + (TransactionId) i;

	*xmin_out = FirstNormalTransactionId + (TransactionId) removals_target;

	if (scattered)
	{
		pg_prng_state rng;

		/* Process-local PRNG state; does not disturb the backend's global PRNG */
		pg_prng_seed(&rng, seed);

		/*
		 * Fisher-Yates shuffle.  The multiset of values (and therefore the
		 * survivor count relative to xmin) is unchanged; only positions move.
		 * pg_prng_uint64_range() yields an unbiased index in [0, i], unlike a
		 * plain modulo reduction.
		 */
		for (i = xcnt - 1; i > 0; i--)
		{
			int			j = (int) pg_prng_uint64_range(&rng, 0, (uint64) i);
			TransactionId tmp = xip[i];

			xip[i] = xip[j];
			xip[j] = tmp;
		}
	}
}

/*
 * qsort() comparator ordering doubles ascending.
 *
 * Yields -1, 0 or 1.  Pairs that are neither less nor greater (equal values,
 * or a NaN operand) compare as 0.
 */
static int
compare_doubles(const void *a, const void *b)
{
	const double lhs = *(const double *) a;
	const double rhs = *(const double *) b;

	/* Each relational test is 0 or 1, so the difference is -1, 0 or 1. */
	return (lhs > rhs) - (lhs < rhs);
}

/*
 * Compute percentile from sorted array
 */
static double
percentile_double(const double *sorted_data, int n, double p)
{
	double		idx;
	int			lo, hi;

	if (n == 0)
		return 0.0;
	if (n == 1)
		return sorted_data[0];

	/* p=0.5 -> median, p=0.95 -> P95 */
	idx = p * (n - 1);
	lo = (int)floor(idx);
	hi = (int)ceil(idx);

	if (lo == hi)
		return sorted_data[lo];

	/* Linear interpolation */
	return sorted_data[lo] + (sorted_data[hi] - sorted_data[lo]) * (idx - lo);
}

/*
 * Two-tailed 95% critical value of Student's t distribution.
 *
 * df: degrees of freedom; values below 1 are clamped to 1.
 *
 * For df = 1..30 the value comes from a standard t table (3 decimals).  For
 * larger df the Cornish-Fisher expansion of the t quantile around the normal
 * quantile z = Phi^-1(0.975) is used (Abramowitz & Stegun 26.7.5, three
 * correction terms).  That agrees with the exact quantile to roughly 1e-5 for
 * df > 30 and converges to z as df grows.  A flat normal approximation (1.96)
 * would understate the margin by ~4% at df = 30 and ~2% at df = 60.
 */
static double
tcrit_95(int df)
{
	/* df: 1..30 -> index df-1 */
	static const double table[] = {
		12.706, 4.303, 3.182, 2.776, 2.571, 2.447, 2.365, 2.306, 2.262, 2.228,
		2.201, 2.179, 2.160, 2.145, 2.131, 2.120, 2.110, 2.101, 2.093, 2.086,
		2.080, 2.074, 2.069, 2.064, 2.060, 2.056, 2.052, 2.048, 2.045, 2.042
	};
	const int	table_size = (int) (sizeof(table) / sizeof(table[0]));
	const double z = 1.959963984540054;	/* Phi^-1(0.975) */
	double		z2,
				z3,
				z5,
				z7,
				g1,
				g2,
				g3,
				v;

	if (df < 1)
		df = 1;

	if (df <= table_size)
		return table[df - 1];

	/* Cornish-Fisher correction polynomials g1..g3 evaluated at z */
	z2 = z * z;
	z3 = z2 * z;
	z5 = z3 * z2;
	z7 = z5 * z2;
	g1 = (z3 + z) / 4.0;
	g2 = (5.0 * z5 + 16.0 * z3 + 3.0 * z) / 96.0;
	g3 = (3.0 * z7 + 19.0 * z5 + 17.0 * z3 - 15.0 * z) / 384.0;
	v = (double) df;

	return z + g1 / v + g2 / (v * v) + g3 / (v * v * v);
}

/*
 * SQL-callable function: bench_purge(method, xcnt, keep_ratio, reps, distribution)
 */
PG_FUNCTION_INFO_V1(bench_purge_sql);
Datum
bench_purge_sql(PG_FUNCTION_ARGS)
{
	text	   *method_text = PG_GETARG_TEXT_PP(0);
	int64		xcnt = PG_GETARG_INT64(1);
	double		keep_ratio = PG_GETARG_FLOAT8(2);
	int			reps = PG_GETARG_INT32(3);
	text	   *distribution_text = PG_GETARG_TEXT_PP(4);

	char	   *method = text_to_cstring(method_text);
	char	   *distribution = text_to_cstring(distribution_text);

	TupleDesc	tupdesc;
	Datum		values[14];
	bool		nulls[14];
	HeapTuple	tuple;

	TransactionId *xip_master;
	TransactionId *xip_work;
	TransactionId xmin;
	double	   *timings;
	int			i;
	int			last_survivors = -1;
	int64		total_bytes_read = 0;
	int64		total_bytes_written = 0;
	MemoryContext bench_context;
	MemoryContext oldcontext;

	double		mean_ns, median_ns, p95_ns, stdev, stderr_val, ci_margin;
	double		sum_sq_diff;
	double		t_crit;

	/* Validate inputs */
	if (xcnt <= 0 || xcnt > 100000)  /* cap at 100k for realistic committed.xip sizes */
		ereport(ERROR,
				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
				 errmsg("xcnt must be between 1 and 100000")));

	if (keep_ratio < 0.0 || keep_ratio > 1.0)
		ereport(ERROR,
				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
				 errmsg("keep_ratio must be between 0.0 and 1.0")));

	if (reps <= 0 || reps > 10000)
		ereport(ERROR,
				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
				 errmsg("reps must be between 1 and 10000")));

	if (pg_strcasecmp(method, "workspace") != 0 &&
		pg_strcasecmp(method, "inplace") != 0)
		ereport(ERROR,
				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
				 errmsg("method must be 'workspace' or 'inplace'")));

	/* Create a memory context for benchmark allocations */
	bench_context = AllocSetContextCreate(CurrentMemoryContext,
										  "SnapBuild Bench Context",
										  ALLOCSET_DEFAULT_SIZES);
	oldcontext = MemoryContextSwitchTo(bench_context);

	/* Allocate master input array (reused across all reps) */
	xip_master = (TransactionId *) palloc(xcnt * sizeof(TransactionId));
	
	/* Generate input once with fixed seed for reproducibility */
	generate_test_input(xip_master, (int)xcnt, keep_ratio, distribution, &xmin, 42);

	/* Allocate array to store timings from each rep */
	timings = (double *) palloc(reps * sizeof(double));

	/* Run benchmark repetitions */
	for (i = 0; i < reps; i++)
	{
		instr_time	start, end;
		int			survivors;
		int64		bytes_read, bytes_written;

		/* Create a working copy so each rep starts with identical input */
		xip_work = (TransactionId *) palloc(xcnt * sizeof(TransactionId));
		memcpy(xip_work, xip_master, xcnt * sizeof(TransactionId));

		/* Time the purge operation */
		INSTR_TIME_SET_CURRENT(start);

		if (pg_strcasecmp(method, "workspace") == 0)
			survivors = purge_old_method(xip_work, (int)xcnt, xmin, bench_context, &bytes_read, &bytes_written);
		else /* inplace */
			survivors = purge_new_method(xip_work, (int)xcnt, xmin, &bytes_read, &bytes_written);

		INSTR_TIME_SET_CURRENT(end);
		INSTR_TIME_SUBTRACT(end, start);

		/* Record timing in nanoseconds */
		timings[i] = INSTR_TIME_GET_DOUBLE(end) * 1e9;

		/* Verify consistency across reps */
		if (last_survivors == -1)
			last_survivors = survivors;
		else if (survivors != last_survivors)
			ereport(ERROR,
					(errcode(ERRCODE_INTERNAL_ERROR),
					 errmsg("inconsistent survivor count across reps: %d vs %d",
							survivors, last_survivors)));

		total_bytes_read += bytes_read;
		total_bytes_written += bytes_written;
		pfree(xip_work);
	}

	/* Compute statistics */
	qsort(timings, reps, sizeof(double), compare_doubles);

	mean_ns = 0.0;
	for (i = 0; i < reps; i++)
		mean_ns += timings[i];
	mean_ns /= reps;

	median_ns = percentile_double(timings, reps, 0.5);
	p95_ns = percentile_double(timings, reps, 0.95);

	/* Standard deviation and 95% confidence interval */
	sum_sq_diff = 0.0;
	for (i = 0; i < reps; i++)
		sum_sq_diff += (timings[i] - mean_ns) * (timings[i] - mean_ns);

	stdev = (reps > 1) ? sqrt(sum_sq_diff / (reps - 1)) : 0.0;
	stderr_val = (reps > 0) ? stdev / sqrt(reps) : 0.0;

	/* t-critical value for 95% CI (exact for df = reps - 1) */
	t_crit = tcrit_95(reps - 1);
	ci_margin = t_crit * stderr_val;

	/* Build result tuple */
	if (get_call_result_type(fcinfo, NULL, &tupdesc) != TYPEFUNC_COMPOSITE)
		ereport(ERROR,
				(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
				 errmsg("function returning record called in context that cannot accept type record")));

	BlessTupleDesc(tupdesc);

	memset(nulls, 0, sizeof(nulls));

	values[0] = CStringGetTextDatum(method);
	values[1] = Int64GetDatum(xcnt);
	values[2] = Float8GetDatum(keep_ratio);
	values[3] = CStringGetTextDatum(distribution);
	values[4] = Int32GetDatum(reps);
	values[5] = Float8GetDatum(mean_ns);
	values[6] = Float8GetDatum(median_ns);
	values[7] = Float8GetDatum(p95_ns);
	values[8] = Float8GetDatum(mean_ns - ci_margin);
	values[9] = Float8GetDatum(mean_ns + ci_margin);
	values[10] = Int64GetDatum(last_survivors);
	values[11] = Int64GetDatum(total_bytes_read / reps);
	values[12] = Int64GetDatum(total_bytes_written / reps);
	values[13] = Int64GetDatum((total_bytes_read + total_bytes_written) / reps);

	tuple = heap_form_tuple(tupdesc, values, nulls);

	/* Clean up */
	MemoryContextSwitchTo(oldcontext);
	MemoryContextDelete(bench_context);

	PG_RETURN_DATUM(HeapTupleGetDatum(tuple));
}

