#!/usr/bin/env python3
from util import *

def handle_argument(root):
    """Parse ``sys.argv`` and record the requested command on *root*.

    Sets ``root.arglist``, ``root.argnum`` and ``root.commandkind``.  The
    ``gdata`` and ``increasedata`` commands accept one optional decimal
    argument giving the scale (default 1), stored as ``root.datascale`` or
    ``root.increasescale`` respectively.

    Every malformed invocation is reported through ``error_exit`` with the
    message 'Wrong call of this tool'.  ``error_exit`` comes from ``util``
    and is assumed to terminate the process (TODO confirm).
    """
    command_list = ['bin', 'startall', 'stopall', 'gdata', 'initdb', 'runselect',
                    'runcount', 'runselectall', 'increasedata', 'resetparent']
    # Commands that take an optional scale, and the attribute that stores it.
    scale_attrs = {'gdata': 'datascale', 'increasedata': 'increasescale'}
    usage_error = 'Wrong call of this tool'

    debug_output('handle_argument start')
    root.arglist = sys.argv
    root.argnum = len(root.arglist)

    if root.argnum < 2:
        error_exit(usage_error)

    if root.arglist[1] in command_list:
        root.commandkind = root.arglist[1]
    else:
        error_exit(usage_error)

    scale_attr = scale_attrs.get(root.commandkind)
    if scale_attr is not None:
        if root.argnum == 2:
            setattr(root, scale_attr, 1)
        # isdecimal() only accepts characters that int() can convert, unlike
        # isnumeric(), which also lets through e.g. vulgar fractions.
        elif root.argnum == 3 and root.arglist[2].isdecimal():
            setattr(root, scale_attr, int(root.arglist[2]))
        else:
            # Covers a non-numeric scale as well as surplus arguments, which
            # previously left the scale attribute unset without any message.
            error_exit(usage_error)
    debug_output('handle_argument end')

class Prepare:
    def __init__(self):
        """Set the fixed paths, file names and tuning values used by the tool."""
        # Install locations: local server without the patch ("op"), local
        # server with the patch ("cp"), and the install directory on nodes.
        self.op_install_path = '/h1/opt/pg'
        self.cp_install_path = '/h1/opt/pg_withpatch'
        self.remote_install_path = '/h1/opt/pg'
        # Prefix of the remote data directories; the 1-based instance number
        # is appended to it.
        self.datadir = '/h1/opt/pg/data'

        # File names, all relative to the current working directory.
        self.tarfilename = 'pg_bin.tar'
        self.node_conf = 'node_conf'
        self.opresultfile = 'op.out'
        self.cpresultfile = 'cp.out'

        # Benchmark settings.
        self.tt = 'pft'            # table queried by the exec_test_* methods
        self.testtimes = 3         # repetitions per server in runtest
        self.remoteport = 5434     # instance n listens on remoteport + n
        self.dbsseverynode = 1
        # True until parent.sql has been written once by this process.
        self.first_generate_parent_sql = True
        # Rows generated per unit of scale (an earlier, commented-out value
        # was 1400).
        self.oneGscale = 14000000
    
    def get_node_tuple(self, nodestr):
        """Parse one ``node_conf`` line into an ``(ip, instance_count)`` tuple.

        A valid line holds exactly two whitespace-separated fields: the host
        address and an integer instance count, e.g. ``'10.0.0.1 3'``.  Any
        run of spaces or tabs is accepted as the separator, and surrounding
        whitespace is ignored.

        Anything else is reported through ``error_exit('Wrong node str')``;
        ``error_exit`` comes from ``util`` and is assumed to terminate the
        process (TODO confirm).
        """
        fields = nodestr.split()
        if len(fields) == 2:
            try:
                return (fields[0], int(fields[1]))
            except ValueError:
                # Non-integer count: fall through to the common error path
                # instead of surfacing a bare traceback.
                pass
        error_exit('Wrong node str')
    
    def load_nodes(self):
        """Load the node list from ``self.node_conf`` and check local paths.

        Every non-blank line is parsed by ``get_node_tuple`` into an
        ``(ip, instance_count)`` tuple and collected, in file order, in
        ``self.node_list``.  Also sets the remote login name
        (``self.node_user``) and ``self.keypath``, then calls ``check_dir``.
        """
        self.node_list = []
        # The context manager closes the file even if a line fails to parse.
        with open(self.node_conf, 'r') as conf:
            for line in conf:
                line = line.strip('\n')
                # A blank line (typically at the end of the file) is not a
                # node entry; skip it rather than aborting the whole run.
                if not line.strip():
                    continue
                self.node_list.append(self.get_node_tuple(line))
        self.node_user = 'ec2-user'
        self.keypath = '/home/ec2-user/highgokey.pem'
        self.check_dir()

    def check_dir(self):
        """Report through ``error_exit`` any missing local install directory.

        The "op" path is checked first, then the "cp" path.
        """
        required = (
            (self.op_install_path, 'op install path is not exists'),
            (self.cp_install_path, 'cp install path is not exists'),
        )
        for path, complaint in required:
            if dir_exist(path):
                continue
            error_exit(complaint)
    
    def clean_remote(self):
        """Leave a fresh, empty install directory on every node.

        An existing ``self.remote_install_path`` is removed with ``rm -rf``
        before the directory is created again with ``mkdir -p``.
        """
        target = self.remote_install_path
        remove_cmd = 'rm -rf %s' % target
        create_cmd = 'mkdir -p  %s' % target
        for node in self.node_list:
            host = node[0]
            if dir_remote_exist(target, self.node_user, host):
                exec_command_system_remote(remove_cmd, self.node_user, host)
            exec_command_system_remote(create_cmd, self.node_user, host)
    
    def prepare_bin(self):
        """Ship the local "op" install to every node as a tarball.

        The archive ``self.tarfilename`` is built in the current directory
        from ``self.op_install_path`` unless a file of that name already
        exists.  For each node it is copied with ``scp`` into
        ``self.remote_install_path``, unpacked there, and ``self.datadir`` is
        then removed on the node.
        """
        archive = self.tarfilename
        if not file_exist(archive):
            start_dir = getcwd()
            pack_cmd = 'cd %s;tar -czvf %s/%s .' % (self.op_install_path, start_dir, archive)
            exec_comand_system(pack_cmd)
            # NOTE(review): presumably has no lasting effect if every command
            # runs in its own shell -- confirm against util before removing.
            exec_comand_system('cd %s' % (start_dir))

        unpack_cmd = 'cd %s;tar -xzvf %s' % (self.remote_install_path, archive)
        drop_data_cmd = 'rm -rf %s' % (self.datadir)
        for node in self.node_list:
            host = node[0]
            upload_cmd = 'scp -r -i ~/highgokey.pem %s %s@%s:%s' % (
                archive, self.node_user, host, self.remote_install_path)
            exec_comand_system(upload_cmd)
            exec_command_system_remote(unpack_cmd, self.node_user, host)

            sleep(1)
            exec_command_system_remote(drop_data_cmd, self.node_user, host, False)

    def initdb_remote(self):
        """Run ``initdb`` for every configured instance on every node.

        Instance *n* (counted from 1 on each node) uses the data directory
        ``self.datadir + str(n)``; a directory that already exists is removed
        first.
        """
        for node in self.node_list:
            host, instances = node[0], node[1]
            for index in range(1, instances + 1):
                data_path = self.datadir + str(index)
                if dir_remote_exist(data_path, self.node_user, host):
                    exec_command_system_remote('rm -rf %s' % data_path, self.node_user, host)
                init_cmd = 'cd %s/bin; ./initdb %s' % (self.remote_install_path, data_path)
                exec_command_system_remote(init_cmd, self.node_user, host)

    def opg_op(self, kind):
        """Run ``pg_ctl`` *kind* for the local "op" install's ``../data`` dir.

        *kind* must be ``'start'`` or ``'stop'``; any other value is reported
        through ``error_exit('Wrong kind')``.
        """
        if kind not in ('start', 'stop'):
            error_exit('Wrong kind')
        pg_ctl_cmd = 'cd %s/bin; ./pg_ctl %s -D ../data' % (self.op_install_path, kind)
        exec_comand_system(pg_ctl_cmd, False, False)
    
    def cpg_op(self,kind):
        """Run ``pg_ctl`` *kind* for the local "cp" install's ``../data`` dir.

        *kind* must be ``'start'`` or ``'stop'``; any other value is reported
        through ``error_exit('Wrong kind')``.
        """
        if kind not in ('start', 'stop'):
            error_exit('Wrong kind')
        pg_ctl_cmd = 'cd %s/bin; ./pg_ctl %s -D ../data' % (self.cp_install_path, kind)
        exec_comand_system(pg_ctl_cmd, False, False)

    def opg_start(self):
        """Stop both local servers, then start only the "op" one."""
        for stop_server in (self.opg_op, self.cpg_op):
            stop_server('stop')
        self.opg_op('start')

    def cpg_start(self):
        """Stop both local servers, then start only the "cp" one."""
        for stop_server in (self.opg_op, self.cpg_op):
            stop_server('stop')
        self.cpg_op('start')
    
    def startall(self):
        """Run ``pg_ctl start`` for every instance on every node.

        Instance *n* (counted from 1 on each node) uses the data directory
        ``self.datadir + str(n)``.
        """
        for node in self.node_list:
            host, instances = node[0], node[1]
            for index in range(1, instances + 1):
                start_cmd = 'cd %s/bin; ./pg_ctl start -D %s' % (
                    self.remote_install_path, self.datadir + str(index))
                exec_command_system_remote(start_cmd, self.node_user, host)

    def stopall(self):
        """Run ``pg_ctl stop`` for every remote instance and both local servers.

        Remote instances are stopped first, node by node; the local "op" and
        "cp" servers follow, in that order.
        """
        for node in self.node_list:
            host, instances = node[0], node[1]
            for index in range(1, instances + 1):
                stop_cmd = 'cd %s/bin; ./pg_ctl stop -D %s' % (
                    self.remote_install_path, self.datadir + str(index))
                exec_command_system_remote(stop_cmd, self.node_user, host, False)

        for install_path in (self.op_install_path, self.cp_install_path):
            local_stop_cmd = 'cd %s/bin; ./pg_ctl stop -D ../data' % (install_path)
            exec_comand_system(local_stop_cmd, False, False)
    
    def general_parent_sql(self):
        """Write the ``postgres_fdw`` setup script ``parent.sql``.

        For every remote instance the script drops and recreates a foreign
        server, a public user mapping and a foreign table, then creates the
        list-partitioned table ``pft`` and attaches each foreign table as a
        partition.  Statements are grouped by kind, in the order: drops,
        servers, user mappings, foreign tables, parent table, attachments.

        The first call made on this object truncates ``parent.sql``; later
        calls append to it (tracked by ``self.first_generate_parent_sql``).
        """
        # One (id, host, port) entry per remote instance.  Nodes and the
        # instances on a node are both counted from 1; the id is
        # ``instance + 10 * node`` and the port ``self.remoteport + instance``.
        # NOTE: ids repeat once a node has 10 or more instances (node 1,
        # instance 11 and node 2, instance 1 both give 21).
        instances = []
        for node_no, node in enumerate(self.node_list, start=1):
            for index in range(1, node[1] + 1):
                instances.append((index + node_no * 10, node[0], self.remoteport + index))

        sql_list = [
            'CREATE EXTENSION IF NOT EXISTS postgres_fdw;',
            'DROP TABLE IF EXISTS pft CASCADE;',
        ]
        sql_list.extend(
            'DROP SERVER IF EXISTS sv%s CASCADE;' % sid
            for sid, _host, _port in instances)
        sql_list.extend(
            "CREATE SERVER sv%s FOREIGN DATA WRAPPER postgres_fdw OPTIONS "
            "(host '%s', port '%s', dbname 'postgres');" % (sid, host, port)
            for sid, host, port in instances)
        sql_list.extend(
            'CREATE USER MAPPING FOR public SERVER sv%s;' % sid
            for sid, _host, _port in instances)
        sql_list.extend(
            'CREATE FOREIGN TABLE pft%s (a int, b int, c text) SERVER sv%s;' % (sid, sid)
            for sid, _host, _port in instances)
        sql_list.append('CREATE TABLE pft (a int, b int, c text) PARTITION BY LIST(a);')
        sql_list.extend(
            'ALTER TABLE pft ATTACH PARTITION pft%s FOR VALUES IN (%s);' % (sid, sid)
            for sid, _host, _port in instances)

        mode = 'w' if self.first_generate_parent_sql else 'a'
        # The context manager closes the file even if a write fails.
        with open('parent.sql', mode) as sql_file:
            self.first_generate_parent_sql = False
            sql_file.writelines(statement + '\n' for statement in sql_list)
    
    def general_node_sql(self):
        """Write the per-instance data-loading script ``node.sql``.

        The script drops and recreates the table named by the psql variable
        ``:tablename`` and fills it with ``self.oneGscale * :scale`` rows
        whose first column is ``:list_value``.  The file is overwritten on
        every call.
        """
        sql_list = [
            'DROP TABLE IF EXISTS :tablename CASCADE;',
            'CREATE TABLE :tablename(a int, b int, c text);',
            # NOTE(review): unlike the statements above, this one has no
            # terminating semicolon -- confirm that is intended.
            'INSERT INTO :tablename SELECT :list_value, generate_series(1,%s*:scale), md5(random()::text)' % (self.oneGscale),
        ]
        # The context manager closes the file even if a write fails.
        with open('node.sql', 'w') as sql_file:
            sql_file.writelines(statement + '\n' for statement in sql_list)
    
    def copy_sql(self):
        """Generate both SQL scripts and put them where later steps read them.

        ``node.sql`` is copied with ``scp`` to ``self.op_install_path`` on
        every node; ``parent.sql`` is copied to ``data/node.sql`` inside both
        local installs.
        """
        self.general_node_sql()
        for node in self.node_list:
            upload_cmd = 'scp -i ~/highgokey.pem node.sql %s@%s:%s/node.sql' % (
                self.node_user, node[0], self.op_install_path)
            exec_comand_system(upload_cmd)

        self.general_parent_sql()
        for install_path in (self.op_install_path, self.cp_install_path):
            exec_comand_system('cp parent.sql %s/data/node.sql' % (install_path))
    
    def resetparent(self):
        """Regenerate ``parent.sql`` and replay it on both local servers.

        The script is copied to ``data/node.sql`` inside both local installs
        and then run through ``generatedata_thread``, first for "cp" and
        then for "op".
        """
        self.general_parent_sql()
        for install_path in (self.op_install_path, self.cp_install_path):
            exec_comand_system('cp parent.sql %s/data/node.sql' % (install_path))
        for local_kind in ('cp', 'op'):
            self.generatedata_thread(local_kind, 0, 0)

    def generatedata_thread(self, item, loop, loop2):
        """Run the prepared SQL script on one target.

        item:  ``'op'`` or ``'cp'`` for the local servers, otherwise the
               address of a remote node.
        loop:  1-based instance number on the node (unused for op/cp).
        loop2: 1-based node number (unused for op/cp).

        For a local target the matching server is (re)started and
        ``data/node.sql`` from its install is executed.  For a remote target
        ``node.sql`` is executed on port ``self.remoteport + loop`` with the
        psql variables ``list_value``, ``scale`` and ``tablename`` set.
        """
        if item not in ('op', 'cp'):
            list_value = loop2 * 10 + loop
            remote_cmd = '%s/bin/psql -p %s -d postgres -v list_value=%s -v scale=%s -v tablename=pft%s -f %s/node.sql' % (
                self.remote_install_path, self.remoteport + loop, list_value,
                self.datascale, list_value, self.op_install_path)
            output('>>Generate %s node data' % item)
            exec_command_system_remote(remote_cmd, self.node_user, item, True)
            return

        if item == 'op':
            start_server, install_path, banner = self.opg_start, self.op_install_path, '>>Generate opg tables'
        else:
            start_server, install_path, banner = self.cpg_start, self.cp_install_path, '>>Generate cpg tables'
        start_server()
        local_cmd = '%s/bin/psql -d postgres -f %s/data/node.sql' % (install_path, install_path)
        output(banner)
        exec_comand_system(local_cmd)

    def generatedata(self):
        """Create the tables locally and load every remote instance in parallel.

        Runs ``generatedata_thread`` for the local "cp" and "op" servers,
        starts all remote instances, then launches one ``myThread`` per
        remote instance and waits until all of them have finished.
        """
        def still_running(worker):
            # threading.Thread.isAlive() was removed in Python 3.9 in favour
            # of is_alive().  myThread comes from util (presumably a Thread
            # subclass -- confirm), so prefer the modern name and fall back
            # to the old one only when it is all the object offers.
            probe = getattr(worker, 'is_alive', None)
            if probe is None:
                probe = worker.isAlive
            return probe()

        output('>>Generate datas...')
        self.generatedata_thread('cp', 0, 0)
        self.generatedata_thread('op', 0, 0)
        self.startall()

        threadlist = []
        # Nodes and the instances on a node are both numbered from 1.
        for node_no, node in enumerate(self.node_list, start=1):
            for index in range(1, node[1] + 1):
                worker = myThread(self.generatedata_thread, node[0], index, node_no)
                worker.start()
                threadlist.append(worker)

        # Poll every two seconds; as before, the first check happens only
        # after one sleep.
        while True:
            sleep(2)
            if not any(still_running(worker) for worker in threadlist):
                break
    
    def increasedata_thread(self, item, loop, loop2):
        """Insert additional rows into one remote instance's ``pftN`` table.

        item:  address of the remote node.
        loop:  1-based instance number on the node; selects the port
               ``self.remoteport + loop``.
        loop2: 1-based node number.

        The table suffix and the value of column ``a`` are both
        ``loop + 10 * loop2``.
        """
        list_value = loop + 10 * loop2
        # NOTE(review): the series starts at list_value rather than 1, unlike
        # the INSERT written by general_node_sql -- confirm this is intended.
        insert_sql = 'INSERT INTO pft%s SELECT %s, generate_series(%s,%s*%s), md5(random()::text)' % (
            list_value, list_value, list_value, self.oneGscale, self.increasescale)
        psql_cmd = "%s/bin/psql -p %s -d postgres -c '%s'" % (
            self.remote_install_path, self.remoteport + loop, insert_sql)
        output('>>Generate %sG to %s node data' % (self.increasescale, item))
        exec_command_system_remote(psql_cmd, self.node_user, item, True)

    def increasedata(self):
        """Add rows to every remote instance in parallel.

        Launches one ``myThread`` running ``increasedata_thread`` per remote
        instance and waits until all of them have finished.
        """
        def still_running(worker):
            # threading.Thread.isAlive() was removed in Python 3.9 in favour
            # of is_alive().  myThread comes from util (presumably a Thread
            # subclass -- confirm), so prefer the modern name and fall back
            # to the old one only when it is all the object offers.
            probe = getattr(worker, 'is_alive', None)
            if probe is None:
                probe = worker.isAlive
            return probe()

        threadlist = []
        # Nodes and the instances on a node are both numbered from 1.
        for node_no, node in enumerate(self.node_list, start=1):
            for index in range(1, node[1] + 1):
                worker = myThread(self.increasedata_thread, node[0], index, node_no)
                worker.start()
                threadlist.append(worker)

        # Poll every two seconds; as before, the first check happens only
        # after one sleep.
        while True:
            sleep(2)
            if not any(still_running(worker) for worker in threadlist):
                break
    
    def change_node_pgconf(self):
        """Push a tailored postgresql.conf and pg_hba.conf to every instance.

        For each remote instance the local "op" install's two config files
        are copied to ``*.change`` files, the listen address, the instance
        port (``self.remoteport + n``) and a trust rule for all hosts are
        appended, and the results are uploaded with ``scp`` into the
        instance's data directory (``self.datadir + str(n)``).
        """
        local_data = '%s/%s' % (self.op_install_path, 'data')
        pg_conf = local_data + '/postgresql.conf'
        pg_conf_edit = local_data + '/postgresql.conf.change'
        pg_hba = local_data + '/pg_hba.conf'
        pg_hba_edit = local_data + '/pg_hba.conf.change'
        debug_output('change_node_pgconf')

        for node in self.node_list:
            host = node[0]
            for index in range(1, node[1] + 1):
                # Start every instance from pristine copies of the local files.
                exec_comand_system('cp %s %s;cp %s %s' % (pg_conf, pg_conf_edit, pg_hba, pg_hba_edit))

                appended = (
                    ("listen_addresses = '*'", pg_conf_edit),
                    ('port = %s' % (self.remoteport + index), pg_conf_edit),
                    ('host    all             all             all            trust', pg_hba_edit),
                )
                for setting, target in appended:
                    exec_comand_system('echo "%s" >> %s' % (setting, target), True)

                remote_data = self.datadir + str(index)
                uploads = (
                    (pg_conf_edit, '%s/postgresql.conf' % (remote_data)),
                    (pg_hba_edit, '%s/pg_hba.conf' % (remote_data)),
                )
                for edited, remote_path in uploads:
                    upload_cmd = 'scp -i ~/highgokey.pem %s %s@%s:%s' % (
                        edited, self.node_user, host, remote_path)
                    exec_comand_system(upload_cmd, True)

    def exec_test_count(self):
        """Time ``explain analyse verbose select count(*)`` on ``self.tt``.

        The query runs through the psql found under ``self.testpath``; the
        elapsed time is reported as a whole number via ``output``.
        """
        query = 'explain analyse verbose select count(*) from %s' % self.tt
        psql_cmd = "%s/bin/psql -d postgres -c '%s'" % (self.testpath, query)
        output(query)
        started = gettime()
        exec_comand_system(psql_cmd, True)
        elapsed = gettime() - started
        output('>>>Run time %s second' % int(elapsed))

    def exec_test_select(self):
        """Time ``explain analyse verbose select * ... where b = 100`` on ``self.tt``.

        The query runs through the psql found under ``self.testpath``; the
        elapsed time is reported as a whole number via ``output``.
        """
        query = 'explain analyse verbose %s' % ('select * from %s where b = 100' % self.tt)
        psql_cmd = "%s/bin/psql -d postgres -c '%s'" % (self.testpath, query)
        started = gettime()
        output(query)
        exec_comand_system(psql_cmd, True)
        elapsed = gettime() - started
        output('>>>>Run time %s second' % int(elapsed))
    
    def exec_test_selectall(self):
        """Time ``explain analyse verbose select *`` on ``self.tt``.

        The query runs through the psql found under ``self.testpath``; the
        elapsed time is reported as a whole number via ``output``.
        """
        query = 'explain analyse verbose %s' % ('select * from %s' % self.tt)
        psql_cmd = "%s/bin/psql -d postgres -c '%s'" % (self.testpath, query)
        started = gettime()
        output(query)
        exec_comand_system(psql_cmd, True)
        elapsed = gettime() - started
        output('>>>>Run time %s second' % int(elapsed))
    
    def runtest(self, kind):
        """Run one benchmark query repeatedly, with and then without the patch.

        kind: 1 runs ``exec_test_select``, 2 ``exec_test_count`` and
              3 ``exec_test_selectall``; any other value runs no query.

        For each phase ``self.testpath`` is pointed at the matching install
        ("cp" first, then "op"), that server is (re)started, and the query
        is executed ``self.testtimes`` times.
        """
        runners = {
            1: self.exec_test_select,
            2: self.exec_test_count,
            3: self.exec_test_selectall,
        }
        run_query = runners.get(kind)
        phases = (
            ('>>Test With Patch', self.cp_install_path, self.cpg_start),
            ('>>Test Without Patch', self.op_install_path, self.opg_start),
        )
        for banner, install_path, start_server in phases:
            output(banner)
            self.testpath = install_path
            start_server()
            # Number the runs from 1 up to testtimes; the previous
            # ``4 - count`` was only correct while testtimes was 3.
            for attempt in range(1, self.testtimes + 1):
                output('>>>The %sth test' % attempt)
                if run_query is not None:
                    run_query()
    



p = Prepare()

handle_argument(p)
p.load_nodes()

# Each command maps to the steps it runs, in order.  handle_argument has
# already restricted p.commandkind to one of these keys.
_COMMAND_STEPS = {
    'bin': (p.clean_remote, p.prepare_bin),
    'initdb': (p.initdb_remote, p.change_node_pgconf, p.copy_sql),
    'gdata': (p.generatedata,),
    'startall': (p.startall,),
    'stopall': (p.stopall,),
    'runselect': (lambda: p.runtest(1),),
    'runcount': (lambda: p.runtest(2),),
    'runselectall': (lambda: p.runtest(3),),
    'increasedata': (p.increasedata,),
    'resetparent': (p.resetparent,),
}

for _step in _COMMAND_STEPS.get(p.commandkind, ()):
    _step()
