#!/usr/bin/env python3

import os
import re
import shutil
import statistics
import subprocess
import sys
from os import listdir
from os.path import isfile, join

import testgres

# Measurement keys reported by each benchmark family.  The j-th timing scraped
# from a test's psql output is attributed to the j-th key in this list.
res_names = {
    'btree_index': ['in_memory', 'on_disk'],
    'cluster': ['in_memory_pkey', 'in_memory_index',
                'on_disk_pkey', 'on_disk_index'],
    'gist_index': ['in_memory', 'on_disk'],
    'hash_index': ['on_disk'],
    'tuplesort': ['in_memory', 'on_disk'],
}
# Filled by run_tests(): result[prefix][name][col_num][key] -> {'mean', 'stddev'}.
result = {}

# All working directories live under the directory the script is started from.
dir = os.getcwd()
test_dir, out_dir, result_dir = (join(dir, sub) for sub in ("test", "out", "result"))

# Benchmark scripts: every regular *.sql file in test/, in sorted name order.
test_files = [entry for entry in sorted(listdir(test_dir))
              if entry.endswith(".sql") and isfile(join(test_dir, entry))]

# To run only a subset, override test_files, e.g.:
# test_files = ["cluster_test_1.sql","cluster_test_2.sql","cluster_test_3.sql"]
# test_files = ["hash_index_test_1.sql","hash_index_test_2.sql","hash_index_test_3.sql"]

# Start each invocation with empty out/ and result/ directories.
for fresh_dir in (out_dir, result_dir):
    shutil.rmtree(fresh_dir, ignore_errors=True)
    os.mkdir(fresh_dir)

# Disable testgres' pg_config/initdb caching; the PATH is switched between two
# PostgreSQL builds before each benchmark run further down in this script.
testgres.configure_testgres(cache_pg_config=False, cache_initdb=False)

def run_psql(filename, node):
    """Execute the SQL script *filename* with psql against the testgres *node*.

    psql connects to ``node.host``/``node.port`` as testgres' default user and
    database, with ``-X`` (skip .psqlrc) and ``-a`` (echo all input lines).
    stderr is merged into stdout.

    :param filename: path of the SQL file passed to ``psql -f``.
    :param node: a started testgres node (its ``port`` and ``host`` are read).
    :returns: ``(returncode, output)``; *output* is the raw combined
        stdout/stderr as ``bytes``.
    """
    database = testgres.defaults.default_dbname()
    user = testgres.defaults.default_username()

    command = [testgres.get_bin_path("psql")]
    command += ["-p", str(node.port), "-h", node.host]
    command += ["-U", user]
    command += ["-X", "-f", filename, "-a"]
    command.append(database)

    # Block until psql exits; stdin is a pipe that is closed immediately.
    completed = subprocess.run(
        command,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    return completed.returncode, completed.stdout

def run_tests(prefix, times=10):
    """Run every script in ``test_files`` *times* times on a fresh testgres node.

    For each test file ``<name>_<...>_<col_num>.sql`` (e.g. ``cluster_test_1.sql``)
    the script is executed via :func:`run_psql` and timings are scraped from the
    output: ``Sort (actual time=...)`` lines for ``tuplesort``, psql ``Time:``
    lines otherwise.  The j-th timing of a run is attributed to the j-th key of
    ``res_names[name]``.

    Side effects:
      * the output of the last repetition is written to ``out_dir``;
      * per key, sorted raw values (``*_vals.out``) and mean/stddev plus the test
        source (``*_mean.out``) are written to ``result_dir``;
      * ``result[prefix][name][col_num][key]`` is set to
        ``{'mean': str, 'stddev': str}`` (two-decimal strings).

    :param prefix: run label (e.g. ``"no_patch"``/``"patch"``); the top-level key
        in ``result`` and a file-name prefix.
    :param times: repetitions per test file; must be at least 1.
    :raises ValueError: if *times* is less than 1.
    :raises RuntimeError: if a run yields a different number of timings than
        ``res_names[name]`` has keys (the samples would be misattributed).
    """
    if times < 1:
        raise ValueError("times must be >= 1, got %r" % (times,))

    sort_re = re.compile(r'.*Sort \(actual time=.*\.\.(.*?) ')
    time_re = re.compile(r'^Time: (.*)')

    with testgres.get_new_node() as node:
        node.init()
        node.start()
        result[prefix] = {}
        for tf in test_files:
            pre, _ = os.path.splitext(tf)
            parts = pre.split('_')
            # "cluster_test_1" -> name "cluster", col_num "1"
            name = '_'.join(parts[0:-2])
            col_num = parts[-1]
            out_file = join(out_dir, '_'.join([prefix, pre]) + '.out')
            keys = res_names[name]

            col_result = {}
            result[prefix].setdefault(name, {})[col_num] = col_result

            vals = {key: [] for key in keys}
            for i in range(times):
                print(prefix + ' ' + pre + ' (' + str(i + 1) + '/' + str(times) + '):', end='', flush=True)
                _, out = run_psql(join(test_dir, tf), node)
                text = out.decode("utf-8")
                # Overwritten on every repetition: only the last output is kept.
                with open(out_file, 'w') as f:
                    f.write(text)

                timings = []
                for line in text.splitlines():
                    if name == 'tuplesort':
                        m = sort_re.match(line)
                        if m:
                            timings.append(float(m.group(1)))
                    else:
                        m = time_re.match(line)
                        if m:
                            # "Time: 12,345 ms" -> 12.345 (locale may use a decimal comma)
                            timings.append(float(m.group(1).split(' ')[0].replace(',', '.')))

                if len(timings) != len(keys):
                    raise RuntimeError(
                        "%s run %d: expected %d timing(s) for %r, found %d; see %s"
                        % (tf, i + 1, len(keys), name, len(timings), out_file))
                for key, num in zip(keys, timings):
                    vals[key].append(num)
                print(' Done', flush=True)

            with open(join(test_dir, tf)) as f:
                source = f.read()

            for key in keys:
                samples = vals[key]
                mean = "%.2f" % statistics.mean(samples)
                # statistics.stdev needs at least two data points.
                stddev = "%.2f" % (statistics.stdev(samples) if len(samples) > 1 else 0.0)
                col_result[key] = {'mean': mean, 'stddev': stddev}

                vals_file = os.path.join(result_dir, '_'.join([prefix, pre, key, 'vals.out']))
                mean_file = os.path.join(result_dir, '_'.join([prefix, pre, key, 'mean.out']))

                with open(vals_file, 'w') as vf:
                    # Sort numerically before formatting; sorting the formatted
                    # strings would put "100.00" before "99.00".
                    vf.write("\n".join("%.2f" % round(x, 2) for x in sorted(samples)))
                with open(mean_file, 'w') as mf:
                    mf.write("mean: %s\n" % mean)
                    mf.write("stddev: %s\n" % stddev)
                    mf.write("test source:\n")
                    mf.write(source)
                    mf.write("\n")
            print('')
        node.stop()

# Benchmark each build in turn: prepend its bin/ directory to PATH, report
# which postgres binary testgres resolves, then run the whole suite.
for bench_prefix, build_dir in (("no_patch", "pg16"), ("patch", "pg16_patch")):
    os.environ['PATH'] = os.pathsep.join([join(dir, build_dir, "bin"), os.environ['PATH']])
    print(testgres.get_bin_path("postgres"))
    run_tests(bench_prefix)

# Debugging aid: to re-render the table without re-running the benchmarks,
# comment out the runs above and uncomment the sample result below.
# result = {'no_patch': {'cluster': {'1': {'in_memory_pkey': {'mean': '166.59', 'stddev': '0.32'}, 'in_memory_index': {'mean': '167.37', 'stddev': '1.70'}, 'on_disk_pkey': {'mean': '389.10', 'stddev': '6.08'}, 'on_disk_index': {'mean': '396.86', 'stddev': '4.26'}}, '3': {'in_memory_pkey': {'mean': '37.94', 'stddev': '0.25'}, 'in_memory_index': {'mean': '36.58', 'stddev': '1.84'}, 'on_disk_pkey': {'mean': '63.08', 'stddev': '0.01'}, 'on_disk_index': {'mean': '58.66', 'stddev': '1.07'}}}}, 'patch': {'cluster': {'1': {'in_memory_pkey': {'mean': '167.22', 'stddev': '0.93'}, 'in_memory_index': {'mean': '165.10', 'stddev': '0.58'}, 'on_disk_pkey': {'mean': '390.14', 'stddev': '8.36'}, 'on_disk_index': {'mean': '386.18', 'stddev': '6.95'}}, '3': {'in_memory_pkey': {'mean': '36.50', 'stddev': '0.24'}, 'in_memory_index': {'mean': '35.19', 'stddev': '0.05'}, 'on_disk_pkey': {'mean': '60.59', 'stddev': '2.01'}, 'on_disk_index': {'mean': '58.87', 'stddev': '3.20'}}}}}
# print(result)


# Render result_dir/final_table.txt comparing no_patch vs patch per test key.
# Column order of the table; max_len[i] is the width of column i.
headers = ["test name", "mean", "stddev", "mean", "stddev"]
max_len = [len(x) for x in headers]
for prefix, prefix_result in result.items():
    # "no_patch" stats occupy columns 1-2, "patch" stats columns 3-4.
    off = 1 if prefix == 'no_patch' else 3
    for name, name_result in prefix_result.items():
        for col_num, col_result in name_result.items():
            for key, stats in col_result.items():
                max_len[0] = max(max_len[0], len(name + ' ' + col_num + ' ' + key))
                max_len[off + 0] = max(max_len[off + 0], len(stats['mean']))
                max_len[off + 1] = max(max_len[off + 1], len(stats['stddev']))

# Each group header spans two data columns plus the 3-character " | " separator.
top = ["".center(max_len[0]),
       "no patch".center(max_len[1] + max_len[2] + 3),
       "patch".center(max_len[3] + max_len[4] + 3)]
with open(os.path.join(result_dir, 'final_table.txt'), 'w') as f:
    f.write("| %s | %s | %s |\n" % tuple(top))
    f.write("| %s | %s | %s | %s | %s |\n"
            % tuple(header.center(width) for header, width in zip(headers, max_len)))
    f.write("|-%s-+-%s-+-%s-+-%s-+-%s-|\n" % tuple('-' * width for width in max_len))
    # Drive the rows from the baseline run explicitly instead of relying on
    # whichever `prefix` the width loop above happened to finish with.
    for name, name_result in result['no_patch'].items():
        for col_num, col_result in name_result.items():
            for key in col_result:
                base = result['no_patch'][name][col_num][key]
                patched = result['patch'][name][col_num][key]
                row = [(name + ' ' + col_num + ' ' + key).ljust(max_len[0]),
                       base['mean'].center(max_len[1]),
                       base['stddev'].center(max_len[2]),
                       patched['mean'].center(max_len[3]),
                       patched['stddev'].center(max_len[4])]
                f.write("| %s | %s | %s | %s | %s |\n" % tuple(row))
sys.exit(0)