import java.io.File; import java.sql.Connection; import java.sql.DriverManager; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; import java.util.ArrayList; import java.util.Collections; import static org.junit.Assert.*; /** * Test to get some basic performance numbers for a change to Postgres row formats * to elide trailing nulls. */ public class TestTrailingNull { private static final String BASELINEDIR = ""; private static final String PATCHDIR = ""; private static final String DBNAME = "test"; private static final String TABLENAME = "test"; private static final int NUMROWS = 10000000; public static void main(String[] args) throws Exception { System.out.println("=======BASELINE========="); runTest(BASELINEDIR); System.out.println("=======END: BASELINE========="); System.out.println("=======PATCH========="); runTest(PATCHDIR); System.out.println("=======END: PATCH========="); } public static void runTest(String dirStr) throws Exception { File dir = new File(dirStr); assertTrue(dir.exists()); assertTrue(new File(dir, "start.sh").exists()); assertTrue(new File(dir, "stop.sh").exists()); // boot the server ProcessBuilder pb = new ProcessBuilder("/bin/sh", "start.sh"); pb.directory(dir); Process p = pb.start(); p.waitFor(); Thread.sleep(1000); Class.forName("org.postgresql.Driver"); Connection c = DriverManager.getConnection("jdbc:postgresql:"+DBNAME+"//localhost:5432"); createBaseTable(c, TABLENAME, 1, NUMROWS); System.out.println("base table size: "+getTableSizePretty(c, TABLENAME)); execTest(c, TABLENAME, 20, 5, 5, false); execTest(c, TABLENAME, 20, 5, 5, true); execTest(c, TABLENAME, 20, 5, 1, true); execTest(c, TABLENAME, 20, 50, 1, true); execTest(c, TABLENAME, 20, 50, 20, true); execTest(c, TABLENAME, 20, 50, 40, true); // just get some rough approximation for these runs, they take too long execTest(c, TABLENAME, 4, 100, 1, true); execTest(c, TABLENAME, 4, 800, 1, true); execTest(c, TABLENAME, 4, 800, 50, true); execTest(c, TABLENAME, 4, 800, 100, true); execTest(c, TABLENAME, 4, 800, 200, true); c.close(); // shut down the server pb = new ProcessBuilder("/bin/sh", "stop.sh"); pb.directory(dir); p = pb.start(); p.waitFor(); Thread.sleep(1000); } /** * Create a temporary table and then run loops where we run INSERT SELECT followed by TRUNCATE * @param c connection * @param srcTableName source data in this table. Assumption is that it has a column named 'c1' that we select from * @param iterations number of times to run * @param tmpTableCols number of columns in the temp table * @param numColsSet how many tmp columns we want to set. e.g. if 5 cols and set 2 we'll generate * INSERT INTO tmp SELECT c1, c1 FROM table * @param nullable whether the tmp cols are nullable or not * @throws Exception on error */ private static void execTest(Connection c, String srcTableName, int iterations, int tmpTableCols, int numColsSet, boolean nullable) throws Exception { Statement s = c.createStatement(); try { s.execute("drop table tmptest"); } catch (SQLException ignored) {} ; String createTable = generateCreateTable("tmptest", tmpTableCols, nullable, true); s.execute(createTable); StringBuilder sb = new StringBuilder("insert into tmptest select /* "+numColsSet+" */ "); for (int i = 0; i < numColsSet; i++) { if (i > 0) { sb.append(", "); } sb.append("c1"); } sb.append(" from "+srcTableName); String insert = sb.toString(); String truncate = "truncate table tmptest"; System.out.println(createTable); System.out.println(insert); // get size s.execute(insert); System.out.println("tmp table size: "+getTableSizePretty(c, "tmptest")); s.execute(truncate); // prime a few times if (iterations > 1) { execTestOnce(c, insert, truncate); execTestOnce(c, insert, truncate); } ArrayList times = new ArrayList(); for (int i = 0; i < iterations; i++) { times.add(execTestOnce(c, insert, truncate)); } s.close(); generateStats(times); System.out.println("\n"); try { s.execute("drop table tmptest"); } catch (SQLException ignored) {} ; } private static long execTestOnce(Connection c, String insert, String cleanup) throws Exception { Statement s = c.createStatement(); long start = System.currentTimeMillis(); s.execute(insert); long end = System.currentTimeMillis(); s.execute(cleanup); s.close(); return end - start; } private static void generateStats(ArrayList allTimes) { Collections.sort(allTimes); // remove top 10% and bottom 10% int tenPercent = allTimes.size() / 10; Long[] times = new Long[allTimes.size() - 2 * tenPercent]; System.arraycopy(allTimes.toArray(), tenPercent, times, 0, times.length); long sum = 0; StringBuilder allValsString = new StringBuilder("{"); for (int i = 0; i < times.length; i++) { sum += times[i]; if (i > 0) { allValsString.append(", "); } allValsString.append(times[i]); } allValsString.append("}"); System.out.println("avg time = " + ((float)sum)/times.length+" ms"); System.out.println("median 80% of the times (ms) = "+allValsString.toString()); } private static void createBaseTable(Connection c, String tableName, int numCols, int numRows) throws Exception { Statement s = c.createStatement(); s.execute("select count(*) from information_schema.tables where table_name = '"+tableName+"'"); ResultSet rs = s.getResultSet(); assertTrue(rs.next()); boolean exists = rs.getInt(1) == 1; rs.close(); if (exists) { if (numRows == getRowCount(c, tableName)) { System.out.println(tableName + " is already loaded with "+numRows+" rows"); return; } s.execute("drop table "+tableName); } s.execute(generateCreateTable(tableName, numCols, false, false)); StringBuilder sb = new StringBuilder("create or replace function load_data(count integer) returns void as $$ \n"); sb.append("begin \n"); sb.append("for i in 1..count loop \n"); // insert into test values(i, i, i, i, i); sb.append("insert into "+tableName+" values ("); for (int i = 1; i <= numCols; i++) { if (i > 1) { sb.append(", "); } sb.append("i"); } sb.append(");\n"); sb.append("end loop;\n"); sb.append("end;\n"); sb.append("$$ LANGUAGE plpgsql;"); s.execute(sb.toString()); s.execute("select load_data("+numRows+")"); System.out.println(tableName + " loaded with "+numRows+" rows"); } private static int getRowCount(Connection c, String tableName) throws Exception { Statement s = c.createStatement(); try { s.execute("select count(*) from "+tableName+""); ResultSet rs = s.getResultSet(); assertTrue(rs.next()); int count =rs.getInt(1); rs.close(); return count; } finally { s.close(); } } private static String generateCreateTable(String tableName, int numCols, boolean nullable, boolean tmpTable) { StringBuilder sb = new StringBuilder(); if (tmpTable) { sb.append("create temporary table "+tableName); } else { sb.append("create table "+tableName); } sb.append(" /* "+numCols+" cols */ ("); for (int i = 1; i <= numCols; i++) { if (i > 1) { sb.append(", "); } sb.append("c"+i+" int "+ (nullable ? "null" : "not null")); } sb.append(")"); return sb.toString(); } private static String getTableSizePretty(Connection c, String tableName) throws Exception { Statement s = c.createStatement(); try { s.execute("select pg_size_pretty(pg_total_relation_size('"+tableName+"'))"); ResultSet rs = s.getResultSet(); assertTrue(rs.next()); String pretty = rs.getString(1); rs.close(); return pretty; } finally { s.close(); } } }