package org.postgresql.util;

import org.postgresql.Driver;

import javax.net.ssl.*;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.net.Socket;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;

/**
 * Creates SSL sockets for PostgreSQL connections.
 * <p>
 * If the system property {@code org.postgresql.trustStore} (or the legacy name
 * {@code org.postgresql.Driver.trustStore}) is set, that truststore is used for the socket.
 * Its password is read from {@code org.postgresql.trustStorePassword} (or the legacy
 * {@code org.postgresql.Driver.trustStorePassword}).
 * <p>
 * If the system property {@code org.postgresql.ssl_trustandsave} is set (to any value), any
 * server certificate is accepted and the certificates the server presents are saved to the
 * truststore. This disables server authentication for that connection.
 * <p>
 * If neither property is set, {@link SSLSocketFactory#getDefault()} is used unchanged.
 *
 * @author Ulrich Meis
 */
public class PGSSLSocketFactory {

    private static final PGSSLSocketFactory instance = new PGSSLSocketFactory();

    /** System property naming the truststore file. */
    private static final String TRUST_STORE_PROP = "org.postgresql.trustStore";
    /** System property holding the password for {@link #TRUST_STORE_PROP}. */
    private static final String TRUST_STORE_PASSWORD_PROP = "org.postgresql.trustStorePassword";
    /** Previously documented truststore property name; still honoured for compatibility. */
    private static final String LEGACY_TRUST_STORE_PROP = "org.postgresql.Driver.trustStore";
    /** Previously documented password property name; paired with {@link #LEGACY_TRUST_STORE_PROP}. */
    private static final String LEGACY_TRUST_STORE_PASSWORD_PROP = "org.postgresql.Driver.trustStorePassword";
    /** When set (to any value), trust every server and save its certificates. */
    private static final String TRUST_AND_SAVE_PROP = "org.postgresql.ssl_trustandsave";
    /** Password used when no password property is configured. */
    private static final String DEFAULT_STORE_PASSWORD = "changeit";

    /** Path of the loaded truststore; only set after {@link #load()} succeeded. */
    private String storeLocation;
    /** Password of the loaded truststore; only set after {@link #load()} succeeded. */
    private String storePass;
    /** The loaded truststore, or {@code null} until {@link #load()} has succeeded. */
    private KeyStore ks;

    /**
     * Loads the configured truststore into {@link #ks}.
     * <p>
     * The location is resolved in this order: {@code org.postgresql.trustStore},
     * {@code org.postgresql.Driver.trustStore}, {@code javax.net.ssl.trustStore},
     * {@code <java.home>/lib/security/jssecacerts} (if that file exists) and finally
     * {@code <java.home>/lib/security/cacerts}. The password is taken from the property paired
     * with the chosen location property. For the JRE files it comes from
     * {@code javax.net.ssl.trustStorePassword}. If no password is found, it defaults to
     * {@code "changeit"}. If the resolved file does not exist, an empty in-memory store is created.
     * <p>
     * The fields are updated only after loading succeeded. A failed attempt (for example a wrong
     * password) is therefore retried on the next connection instead of leaving an uninitialized
     * store behind.
     *
     * @throws KeyStoreException        if no provider supports the default keystore type
     * @throws IOException              if the file cannot be read, has a bad format, or the password is wrong
     * @throws NoSuchAlgorithmException if the store's integrity-check algorithm is unavailable
     * @throws CertificateException     if a certificate in the store cannot be loaded
     */
    private void load() throws KeyStoreException, IOException, NoSuchAlgorithmException, CertificateException {
        String location = System.getProperty(TRUST_STORE_PROP);
        String password = System.getProperty(TRUST_STORE_PASSWORD_PROP);
        if (location == null) { // legacy PostgreSQL property names
            location = System.getProperty(LEGACY_TRUST_STORE_PROP);
            password = System.getProperty(LEGACY_TRUST_STORE_PASSWORD_PROP);
        }
        if (location == null) { // fall back to the JSSE properties
            location = System.getProperty("javax.net.ssl.trustStore");
            password = System.getProperty("javax.net.ssl.trustStorePassword");
        }
        if (location == null) { // fall back to the JSSE standard files
            // The javax.net.ssl.trustStorePassword read above (if any) is kept for these files.
            String securityDir = System.getProperty("java.home") + File.separatorChar + "lib"
                    + File.separatorChar + "security" + File.separatorChar;
            location = securityDir + "jssecacerts";
            if (!new File(location).exists())
                location = securityDir + "cacerts";
        }
        if (password == null) password = DEFAULT_STORE_PASSWORD;
        Driver.info("Using trustStore " + location);

        KeyStore loaded = KeyStore.getInstance(KeyStore.getDefaultType());
        File file = new File(location);
        if (file.exists()) {
            FileInputStream in = new FileInputStream(file);
            try {
                loaded.load(in, password.toCharArray());
            } finally {
                in.close();
            }
        } else {
            // No file yet: start from an empty store (it may be written later by store()).
            loaded.load(null, password.toCharArray());
        }

        // Publish only after a successful load.
        storeLocation = location;
        storePass = password;
        ks = loaded;
    }

    /**
     * Adds the given certificates to the truststore and writes it back to {@link #storeLocation}.
     * Only a debug message is emitted, and nothing is written, if any of these holds:
     * no truststore has been loaded; the file exists but is not writable; the file does not
     * exist and its directory is not writable.
     *
     * @param serverCerts certificate chain presented by the server; each entry is cast to
     *                    {@link X509Certificate}
     * @throws KeyStoreException        if the store rejects an entry or is not initialized
     * @throws IOException              if the file cannot be written
     * @throws NoSuchAlgorithmException if the store's integrity algorithm is unavailable
     * @throws CertificateException     if a certificate cannot be stored
     */
    private void store(Certificate[] serverCerts) throws KeyStoreException, IOException, NoSuchAlgorithmException, CertificateException {
        if (ks == null || storeLocation == null) {
            Driver.debug("No trustStore loaded, not saving certificates");
            return;
        }
        File keystore = new File(storeLocation);
        if (!keystore.exists()) {
            // getAbsoluteFile() so that a bare relative file name still has a parent directory.
            File dir = keystore.getAbsoluteFile().getParentFile();
            if (dir == null || !dir.canWrite()) {
                Driver.debug("Cannot write to directory "
                        + (dir == null ? keystore.getAbsolutePath() : dir.getAbsolutePath()));
                return;
            }
        } else if (!keystore.canWrite()) {
            Driver.debug("Cannot write to keystore " + storeLocation);
            return;
        }
        for (int i = 0; i < serverCerts.length; i++) {
            X509Certificate cert = (X509Certificate) serverCerts[i];
            // NOTE(review): the alias is the issuer DN, so certificates of different servers signed
            // by the same issuer overwrite each other; consider the subject DN or a fingerprint.
            ks.setCertificateEntry("PGJDBC " + cert.getIssuerDN().toString(), cert);
        }
        FileOutputStream out = new FileOutputStream(keystore);
        try {
            ks.store(out, storePass.toCharArray());
        } finally {
            out.close();
        }
    }

    /**
     * Returns whether a PostgreSQL-specific truststore property (current or legacy name) is set.
     *
     * @return {@code true} if {@code org.postgresql.trustStore} or
     *         {@code org.postgresql.Driver.trustStore} is set
     */
    private static boolean isCustomTrustStoreConfigured() {
        return System.getProperty(TRUST_STORE_PROP) != null
                || System.getProperty(LEGACY_TRUST_STORE_PROP) != null;
    }

    /**
     * Creates an SSL socket layered over an existing socket.
     * <p>
     * If no custom truststore property is set and {@code org.postgresql.ssl_trustandsave} is not
     * defined, this delegates to {@link SSLSocketFactory#getDefault()}. Otherwise the custom
     * truststore is used, or, in trust-and-save mode, any certificate is accepted and the
     * server's certificates are saved to the truststore. If the custom context cannot be set up,
     * the default factory is used instead.
     *
     * @param socket    the existing socket to layer over
     * @param host      the server host name
     * @param port      the server port
     * @param autoClose whether closing the SSL socket closes the underlying socket
     * @return the SSL socket
     * @throws IOException if the SSL socket cannot be created
     */
    public Socket createSocket(Socket socket, String host, int port, boolean autoClose) throws IOException {
        boolean trustandsave = System.getProperty(TRUST_AND_SAVE_PROP) != null;
        if (!trustandsave && !isCustomTrustStoreConfigured())
            return ((SSLSocketFactory) SSLSocketFactory.getDefault()).createSocket(socket, host, port, autoClose);

        SSLSocketFactory factory;
        try {
            // The truststore is loaded only once (after the first successful load). If you want to
            // change the properties/trustStore at runtime, remove the if.
            if (ks == null) load();

            SSLContext sc = SSLContext.getInstance("SSL");

            if (trustandsave) { // trust anyone
                sc.init(null, trustseveryone(), null);
                Driver.debug("trustandsave is on for upcoming connection");
            } else {
                // Initialize KeyManagerFactory and TrustManagerFactory with the loaded store;
                // the same store is used for both key and trust material.
                KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
                TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
                tmf.init(ks);
                kmf.init(ks, storePass.toCharArray());
                sc.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
            }

            factory = sc.getSocketFactory();
        } catch (Exception e) {
            e.printStackTrace();
            Driver.debug("Couldn't set up custom SSLSocketFactory due to previous exception," +
                    " using SSLSocketFactory.getDefault()");
            factory = (SSLSocketFactory) SSLSocketFactory.getDefault();
        }

        SSLSocket sslsocket = (SSLSocket) factory.createSocket(socket, host, port, autoClose);

        if (trustandsave) {
            try {
                // getSession() initiates the handshake if it has not happened yet.
                store(sslsocket.getSession().getPeerCertificates());
            } catch (Exception e) {
                e.printStackTrace();
                Driver.debug("Couldn't save certificate due to previous exception");
            }
        }
        return sslsocket;
    }


    /**
     * Returns the shared factory instance.
     *
     * @return the singleton {@code PGSSLSocketFactory}
     */
    public static PGSSLSocketFactory getInstance() {
        return instance;
    }

    /**
     * Builds trust managers that accept every certificate. Only used in trust-and-save mode;
     * this disables certificate validation entirely.
     *
     * @return array suitable to init SSLContext
     * @see SSLContext
     */
    public TrustManager[] trustseveryone() {
        return new TrustManager[]{
            new X509TrustManager() {
                public X509Certificate[] getAcceptedIssuers() {
                    // The X509TrustManager contract requires a non-null (possibly empty) array.
                    return new X509Certificate[0];
                }

                public void checkClientTrusted(X509Certificate[] certs, String authType) {
                    // Accept every client certificate.
                }

                public void checkServerTrusted(X509Certificate[] certs, String authType) {
                    // Accept every server certificate.
                }
            }
        };
    }
}
