/*-------------------------------------------------------------------------
*
* Copyright (c) 2011, PostgreSQL Global Development Group
*
*
*-------------------------------------------------------------------------
*/
package org.postgresql.ssl;

import java.io.FileInputStream;
import java.io.IOException;
import java.net.Socket;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.security.SecureRandom;
import java.util.Properties;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManagerFactory;


/**
 * An {@link SSLSocketFactory} whose SSL context can carry a client certificate (for
 * certificate-based client authentication) and a custom set of trusted CA certificates.
 * All configuration is read from a {@link Properties} object (by default the system properties).
 * @author Marc-André Laverdière (marc-andre@atc.tcs.com / marcandre.laverdiere@tcs.com)
 */
public class CertAuthFactory extends SSLSocketFactory {

	// ---------------------------- Configuration keys and constants

	public final static String CONFIG_KEYSTORE_PATH = "org.postgresql.jdbc.keystore.path";
	public final static String CONFIG_KEYSTORE_PWD = "org.postgresql.jdbc.keystore.password";
	public final static String CONFIG_TRUSTSTORE_PATH = "org.postgresql.jdbc.truststore.path";
	public final static String CONFIG_TRUSTSTORE_PWD = "org.postgresql.jdbc.truststore.password";

	public final static String SSL_PROTOCOL_NAME = "SSL";
	public final static String TLS_PROTOCOL_NAME = "TLS";
	public final static String SECURE_RANDOM_NAME = "SHA1PRNG";

	public final static String KEYSTORE_TYPE_PKCS12 = "PKCS12";
	public final static String KEYSTORE_TYPE_JKS = "JKS";

	//------------------------------- Actual instance data
	/** Socket factory of the configured SSLContext; every socket is created through it. */
	protected final SSLSocketFactory _internalFactory;
	/** True when a client keystore was configured and loaded into the key managers. */
	protected final boolean isUsingCertificateAuth;


	//------------------------------ Constructors

	/**
	 * Creates a factory configured from the system properties.
	 * @throws IOException on any error loading the keystores
	 * @throws GeneralSecurityException on keystore errors or crypto set-up errors
	 */
	public CertAuthFactory() throws IOException, GeneralSecurityException {
		this(System.getProperties());
	}

	/**
	 * Creates a factory configured from the system properties.
	 * @param ignored not used; configuration always comes from the system properties
	 * @throws IOException on any error loading the keystores
	 * @throws GeneralSecurityException on keystore errors or crypto set-up errors
	 */
	public CertAuthFactory(String ignored) throws IOException, GeneralSecurityException {
		this(System.getProperties());
	}

	/**
	 * Builds the internal SSLContext from the given key store and trust store settings.
	 * The parameters read are as follows:
	 * <ul>
	 * <li>	(optional) <code>org.postgresql.jdbc.keystore.path</code> : the path to the keystore containing the client certificate
	 * <li>	(mandatory if previous is set) <code>org.postgresql.jdbc.keystore.password</code> : password for loading the keystore containing the client certificate
	 * <li> (optional) <code>org.postgresql.jdbc.truststore.path</code> : keystore containing the CA certificate(s) that are trusted.
	 * This should normally be the same CA as the one that signed the certificate of the server you would connect to.
	 * <li> (mandatory if previous is set) <code>org.postgresql.jdbc.truststore.password</code> : the password to open the keystore containing the CA certificate(s)
	 * </ul>
	 * When a path is absent or empty, the JSSE default key managers or trust managers are used
	 * for that part instead.
	 * @param props the properties to load the configuration from. Must be non-null.
	 * @throws IllegalArgumentException if <code>props</code> is null, or a path is set without its password
	 * @throws IOException on any error loading the keystores
	 * @throws GeneralSecurityException on keystore errors or crypto set-up errors
	 */
	public CertAuthFactory(Properties props) throws IOException, GeneralSecurityException {
		if (props == null)
			throw new IllegalArgumentException("Properties is null");

		//Load configuration
		String trustPath = props.getProperty(CONFIG_TRUSTSTORE_PATH);
		String trustPwd = props.getProperty(CONFIG_TRUSTSTORE_PWD);
		String keyPath = props.getProperty(CONFIG_KEYSTORE_PATH);
		String keyPwd = props.getProperty(CONFIG_KEYSTORE_PWD);

		//Key managers: only built when a client keystore is configured
		KeyManagerFactory managerFactory = null;
		if (isSet(keyPath)) {
			if (keyPwd == null)
				throw new IllegalArgumentException("Keystore password must be specified");
			KeyStore ks = loadKeyStore(keyPath, keyPwd);
			managerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
			managerFactory.init(ks, keyPwd.toCharArray());
		}
		isUsingCertificateAuth = (managerFactory != null);

		//Trust managers: only built when a trust store is configured
		TrustManagerFactory trustFactory = null;
		if (isSet(trustPath)) {
			if (trustPwd == null)
				throw new IllegalArgumentException("Trust store password must be specified");
			KeyStore trustKs = loadKeyStore(trustPath, trustPwd);
			trustFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
			trustFactory.init(trustKs);
		}

		//Create + Initialize the context. Passing null for either manager array makes
		//SSLContext.init fall back to the installed defaults instead of failing.
		SSLContext context = SSLContext.getInstance(SSL_PROTOCOL_NAME); //can be TLS too
		context.init(
				managerFactory == null ? null : managerFactory.getKeyManagers(),
				trustFactory == null ? null : trustFactory.getTrustManagers(),
				SecureRandom.getInstance(SECURE_RANDOM_NAME));
		_internalFactory = context.getSocketFactory();
	}

	//------------------------- SSLSocketFactory API implementation

	@Override
	public Socket createSocket() throws IOException {
		SSLSocket sock = (SSLSocket) _internalFactory.createSocket();
		return enableClientAuth(sock);
	}

	@Override
	public Socket createSocket(InetAddress address, int port, InetAddress localhost, int localPort)
			throws IOException {
		SSLSocket sock = (SSLSocket) _internalFactory.createSocket(address, port, localhost, localPort);
		return enableClientAuth(sock);
	}

	@Override
	public Socket createSocket(InetAddress host, int port) throws IOException {
		SSLSocket sock = (SSLSocket) _internalFactory.createSocket(host, port);
		return enableClientAuth(sock);
	}

	@Override
	public Socket createSocket(Socket s, String host, int port, boolean autoClose)
			throws IOException {
		SSLSocket sock = (SSLSocket) _internalFactory.createSocket(s, host, port, autoClose);
		return enableClientAuth(sock);
	}

	@Override
	public Socket createSocket(String host, int port) throws IOException, UnknownHostException {
		SSLSocket sock = (SSLSocket) _internalFactory.createSocket(host, port);
		return enableClientAuth(sock);
	}

	@Override
	public Socket createSocket(String host, int port, InetAddress localAddress, int localPort)
			throws IOException, UnknownHostException {
		SSLSocket sock = (SSLSocket) _internalFactory.createSocket(host, port, localAddress, localPort);
		return enableClientAuth(sock);
	}

	@Override
	public String[] getDefaultCipherSuites() {
		return _internalFactory.getDefaultCipherSuites();
	}

	@Override
	public String[] getSupportedCipherSuites() {
		return _internalFactory.getSupportedCipherSuites();
	}

	// -------------------------------------- Internal Helper Methods

	/**
	 * Tells whether a configuration value is present, i.e. non-null and non-empty.
	 * @param value the configuration value, possibly null
	 * @return true if the value is non-null and non-empty
	 */
	private static boolean isSet(String value) {
		return value != null && !"".equals(value);
	}

	/**
	 * Loads the keystore at the given path. It first tries the JKS format and, if that
	 * raises an IOException, retries as PKCS12. Only the PKCS12 error is reported when
	 * both attempts fail.
	 * @param path the path to the keystore
	 * @param password the keystore password
	 * @return the loaded keystore
	 * @throws IllegalArgumentException if <code>path</code> is null or empty, or <code>password</code> is null
	 * @throws IOException if neither format could be loaded
	 * @throws GeneralSecurityException if the certificates cannot be loaded
	 */
	protected static KeyStore loadKeyStore(String path, String password) throws IOException, GeneralSecurityException{
		if (!isSet(path)) throw new IllegalArgumentException("Path is empty or null");
		if (password == null) throw new IllegalArgumentException("Password is null");

		//first try with JKS
		try{
			return loadKeyStore(path, password, KEYSTORE_TYPE_JKS);
		} catch (IOException e){ //KeyStore.load reports a format mismatch as an IOException
			//try loading PKCS12 instead
			return loadKeyStore(path, password, KEYSTORE_TYPE_PKCS12);
		}
	}

	/**
	 * Tries to open a keystore of the given type
	 * @param path the path to the keystore
	 * @param password the keystore password
	 * @param type the keystore type
	 * @return a valid keystore
	 * @throws IOException If there is any error loading the keystore
	 * @throws GeneralSecurityException if the certificates cannot be loaded or the type specified is invalid
	 */
	protected static KeyStore loadKeyStore(String path, String password, String type) throws IOException, GeneralSecurityException{
		FileInputStream fIn = null;
		try{
			KeyStore ks = KeyStore.getInstance(type);
			fIn = new FileInputStream(path);
			ks.load(fIn, password.toCharArray());
			return ks;
		} finally{
			if (fIn != null)
				fIn.close();
		}
	}

	/**
	 * Puts the socket in client mode and applies the client-auth flag.
	 * NOTE(review): per the SSLSocket API docs, setNeedClientAuth is only meaningful for
	 * server-mode sockets. The client certificate itself is supplied through the key
	 * managers configured in the constructor.
	 * @param sock the socket to configure
	 * @return the same socket, for call chaining
	 */
	protected SSLSocket enableClientAuth(SSLSocket sock){
		sock.setNeedClientAuth(isUsingCertificateAuth);
		sock.setUseClientMode(true);
		return sock;
	}
}
