from time import time
from sys import argv
import psycopg2
# import psycopg2.extras

"""
Usage
Set the input_path, output_path and conn_params in the main guard below.
 
To create the test table
python3 bytea_test.py c

To insert the file at input_path into the table N times
python3 bytea_test.py w <N>

To select the data from the table and write the first row's file to output_path
python3 bytea_test.py r
"""


def read_binary(input_path):
    """Return the entire contents of the file at ``input_path`` as bytes."""
    with open(input_path, mode='rb') as source:
        contents = source.read()
    return contents


def create_table(conn_params):
    """Create the BYTEA_TABLE test table if it does not already exist.

    Args:
        conn_params: Mapping of keyword arguments passed to
            ``psycopg2.connect``.

    Raises:
        psycopg2.Error: Propagated unchanged if connecting or executing the
            statement fails; the connection is still closed in that case.
    """
    conn = psycopg2.connect(**conn_params)
    try:
        # The cursor context manager closes the cursor even when execute()
        # or commit() raises.
        with conn.cursor() as cursor:
            cursor.execute(
                """CREATE TABLE IF NOT EXISTS BYTEA_TABLE (data BYTEA);"""
            )
            conn.commit()
    finally:
        # Always release the connection so a failure does not leak it.
        conn.close()


def save_file(binary_data, duplicate, conn_params):
    """Insert ``binary_data`` into bytea_table ``duplicate`` times.

    Prints the elapsed wall-clock seconds spent in the insert loop (the
    commit is not included in the timing), then a confirmation line.

    Args:
        binary_data: Raw bytes to store in the ``data`` column.
        duplicate: Number of rows to insert.
        conn_params: Mapping of keyword arguments passed to
            ``psycopg2.connect``.

    Raises:
        psycopg2.Error: Propagated unchanged if connecting, inserting or
            committing fails; the connection is still closed in that case.
    """
    conn = psycopg2.connect(**conn_params)
    try:
        # The cursor context manager closes the cursor even when an insert
        # or the commit raises.
        with conn.cursor() as cursor:
            data = psycopg2.Binary(binary_data)

            # Alternative batched insert, kept for comparison. It needs the
            # (currently commented-out) ``import psycopg2.extras``.
            # rows = [(data,) for _ in range(duplicate)]
            # start = time()
            # psycopg2.extras.execute_values(cursor, """INSERT INTO bytea_table (data) VALUES %s;""", rows)
            # print(time()-start)

            start = time()
            for _ in range(duplicate):
                cursor.execute(
                    """INSERT INTO bytea_table (data) VALUES (%s);""",
                    (data,),
                )
            print(time()-start)

            conn.commit()
    finally:
        # Always release the connection so a failure does not leak it.
        conn.close()
    print(f"file saved {duplicate} times")


def retrieve_file(output_path, conn_params):
    """Select the data from bytea_table and write the first row to a file.

    Prints the elapsed wall-clock seconds of the SELECT ``execute`` call,
    then either a confirmation line or ``ERROR`` when no row was returned.

    Args:
        output_path: Path of the file the first row's data is written to.
        conn_params: Mapping of keyword arguments passed to
            ``psycopg2.connect``.

    Raises:
        psycopg2.Error: Propagated unchanged if connecting or querying fails.
        OSError: Propagated unchanged if ``output_path`` cannot be written.
        In both cases the connection is still closed.
    """
    conn = psycopg2.connect(**conn_params)
    try:
        # The cursor context manager closes the cursor even when the query
        # or the file write raises.
        with conn.cursor() as cursor:
            start = time()
            cursor.execute("""SELECT data FROM bytea_table;""")
            print(time()-start)

            result = cursor.fetchone()
            if result:
                binary_data = result[0]
                with open(output_path, 'wb') as f:
                    f.write(binary_data)
                print(f"File retrieved, saved to '{output_path}'")
            else:
                # fetchone() returned nothing: no row to write out.
                print("ERROR")
    finally:
        # Always release the connection so a failure does not leak it.
        conn.close()


if __name__ == '__main__':
    # Edit these three settings before running; the values below are
    # placeholders.
    input_path = 'INPUT/PATH/file.jpg'
    output_path = 'OUTPUT/PATCH/retrieved.jpg'
    conn_params = {
        'dbname': '',
        'user': '',
        'password': '',
        'host': '',
        'port': 0
    }

    # Mode is the first command-line argument, matched case-insensitively:
    #   c      -> create the test table
    #   r      -> read the table and write the first file to output_path
    #   w <N>  -> insert the file at input_path N times (N > 0)
    mode = argv[1].lower() if len(argv) > 1 else None

    if mode is None:
        print("Specify Mode")
    elif mode == "c":
        create_table(conn_params)
    elif mode == "r":
        retrieve_file(output_path, conn_params)
    elif mode == "w":
        # A missing or non-integer N used to crash with IndexError or
        # ValueError; treat both like a non-positive count and report it.
        try:
            duplicate = int(argv[2])
        except (IndexError, ValueError):
            duplicate = 0

        if duplicate > 0:
            save_file(read_binary(input_path), duplicate, conn_params)
        else:
            print("Invalid Mode: 'w' requires a positive integer N")
    else:
        print("Invalid Mode")
