#!/usr/bin/env python3

# Demonstrates corruption of GiST indexes when multiple threads are involved.

# libpq connection string used for every connection this script opens.
# CHANGE THIS TO SUIT YOU.
# The target database also needs: CREATE EXTENSION btree_gist;
CONNECTION_STRING = "dbname=duncan"


import datetime
import psycopg2.extras
import threading

# Number of Worker threads started concurrently.
NUM_THREADS = 32
# Width of the half-open slice of meta ids handed to each worker.
META_IDS_PER_THREAD = 100
# A worker runs its consistency check once per this many loop iterations.
CHECK_EVERY = 100

# (Re)create the two tables the test works on.  `sections` carries an
# exclusion constraint backed by a gist index on (file, valid).
#
# psycopg2's `with connection` block only commits (or rolls back on error);
# it does NOT close the connection, so close it explicitly instead of leaving
# an idle session open for the whole run.
conn = psycopg2.connect(CONNECTION_STRING)
try:
    with conn, conn.cursor() as cursor:
        cursor.execute(
            """
            DROP TABLE IF EXISTS files CASCADE;
            CREATE TABLE files (
              id INTEGER PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY,
              meta INTEGER NOT NULL
            );
    
            DROP TABLE IF EXISTS sections CASCADE;
            CREATE TABLE sections (
               id INTEGER PRIMARY KEY,
               file INTEGER NOT NULL REFERENCES files,
               valid DATERANGE NOT NULL,
               EXCLUDE USING gist(file WITH =, valid WITH &&)
            );
            """
        )
finally:
    conn.close()

# Shared stop flag: every Worker polls it at the top of its loop and sets it
# when a check fails or an exception escapes, which makes all workers exit.
found_problem = False

class Worker(threading.Thread):
    """Stress thread that keeps splitting the newest section of its meta ids.

    A worker owns the meta ids in ``[lower, upper)``.  On its own database
    connection it cycles through them; for each one it fetches the section
    attached to the newest ``files`` row, deletes it, and re-inserts it as up
    to two adjacent date ranges, each under a freshly inserted ``files`` row.
    Every ``CHECK_EVERY`` iterations it verifies that looking a section up by
    ``id`` and by ``file`` give the same answer.

    The loop stops once the module-level ``found_problem`` flag is set; the
    worker sets that flag itself when a check fails or an exception escapes.
    """

    def __init__(self, lower, upper):  # upper not included
        super().__init__()
        self.lower = lower
        self.upper = upper

    def run(self):
        """Open a connection, run the stress loop, and always close it again."""
        global found_problem

        try:
            conn = psycopg2.connect(CONNECTION_STRING)
            try:
                # `with conn` only commits / rolls back; it does not close the
                # connection, hence the explicit close() in the finally below.
                with conn, conn.cursor() as cursor:
                    if not self._stress(cursor):
                        found_problem = True
            finally:
                conn.close()
        except BaseException:
            # Deliberately broad: whatever goes wrong in this thread must stop
            # the other workers before the exception propagates.
            found_problem = True
            raise

    def _stress(self, cursor):
        """Run the split loop; return False as soon as an inconsistency shows.

        Returns True when the loop ended only because ``found_problem`` had
        already been set.
        """
        check_counter = 0
        metaid = self.lower

        while not found_problem:
            # Advance first, wrapping back to `lower` once `upper` is reached.
            metaid = metaid + 1
            if metaid >= self.upper:
                metaid = self.lower

            latest = self._latest_section(cursor, metaid)
            if not latest:
                # Nothing for this meta id yet: seed it with an unbounded range.
                cursor.execute("BEGIN")
                fileid = self._insert_file(cursor, metaid)
                cursor.execute(
                    "INSERT INTO sections(id,file,valid) VALUES ("
                    " %s, %s, '(,)')",
                    (fileid, fileid),
                )
                cursor.execute("COMMIT")

                latest = self._latest_section(cursor, metaid)
                if not latest:
                    # The row committed a moment ago cannot be found again.
                    return False

            section_id, valid = latest

            cursor.execute("BEGIN")

            lower_bound = valid.lower
            mid = self._split_point(lower_bound)
            v1 = psycopg2.extras.DateRange(
                lower_bound, mid, '[)',
                empty=lower_bound is not None and lower_bound >= mid)
            v2 = psycopg2.extras.DateRange(
                mid, valid.upper, '[)',
                empty=valid.upper is not None and mid >= valid.upper)

            add_1 = not v1.isempty
            add_2 = not v2.isempty

            # split ranges
            if add_1 or add_2:
                cursor.execute(
                    "DELETE FROM sections WHERE id=%s", (section_id,))
                for wanted, piece in ((add_1, v1), (add_2, v2)):
                    if not wanted:
                        continue
                    new_id = self._insert_file(cursor, metaid)
                    cursor.execute(
                        "INSERT INTO sections(id,file,valid) VALUES ("
                        " %s, %s, %s::daterange)",
                        (new_id, new_id, piece))

            cursor.execute("COMMIT")

            check_counter = check_counter + 1
            if check_counter % CHECK_EVERY == 0:
                if not self._sections_consistent(cursor):
                    return False

        return True

    @staticmethod
    def _split_point(lower):
        """Return the date to split at: one day after `lower`, or 1970-01-01
        when the range has no lower bound."""
        if lower is None:
            return datetime.date(1970, 1, 1)
        return lower + datetime.timedelta(days=1)

    @staticmethod
    def _latest_section(cursor, metaid):
        """Return ``(section id, valid)`` for the newest file of `metaid`, or None."""
        cursor.execute(
            "SELECT sections.id, valid FROM sections"
            " JOIN files ON sections.file = files.id WHERE meta=%s"
            " ORDER BY files.id DESC LIMIT 1",
            (metaid,))
        return cursor.fetchone()

    @staticmethod
    def _insert_file(cursor, metaid):
        """Insert a ``files`` row for `metaid` and return its generated id."""
        cursor.execute(
            "INSERT INTO files(meta) VALUES (%s) RETURNING id",
            (metaid,)
        )
        # fetchone() yields a 1-tuple; unpack it so callers bind a plain
        # integer rather than relying on how psycopg2 renders a tuple.
        return cursor.fetchone()[0]

    def _sections_consistent(self, cursor):
        """Return False if, for any file of this worker, a ``sections`` lookup
        by ``id`` and one by ``file`` disagree about whether a row exists."""
        cursor.execute(
            "SELECT id FROM files WHERE meta>=%s AND meta<%s",
            (self.lower, self.upper))
        file_ids = [record[0] for record in cursor]
        for file_id in file_ids:
            cursor.execute("SELECT * FROM sections WHERE id=%s", (file_id,))
            found_by_id = cursor.fetchone() is not None
            cursor.execute("SELECT * FROM sections WHERE file=%s", (file_id,))
            found_by_file = cursor.fetchone() is not None
            if found_by_id != found_by_file:
                return False
        return True
    

# Give every worker its own half-open slice of meta ids, start them all,
# then block until each one has finished.
workers = [
    Worker(first, first + META_IDS_PER_THREAD)
    for first in range(0, NUM_THREADS * META_IDS_PER_THREAD, META_IDS_PER_THREAD)
]
for w in workers:
    w.start()
for w in workers:
    w.join()

# Final sweep over every file row: report each id for which a `sections`
# lookup by `id` and a lookup by `file` disagree about a row existing.
with psycopg2.connect(CONNECTION_STRING) as conn, conn.cursor() as cursor:
    cursor.execute("SELECT id FROM files")
    ids = [record[0] for record in cursor]
    for i in ids:
        cursor.execute("SELECT * FROM sections WHERE id=%s", (i,))
        found_by_id = cursor.fetchone() is not None
        cursor.execute("SELECT * FROM sections WHERE file=%s", (i,))
        found_by_file = cursor.fetchone() is not None
        if found_by_id == found_by_file:
            continue
        print("inconsistent: %s" % i)
