From 3bfae6e6f563acfe63f2ae44feb5799f4e49324b Mon Sep 17 00:00:00 2001
From: Michael Paquier <michael@paquier.xyz>
Date: Mon, 5 Jul 2021 17:16:15 +0900
Subject: [PATCH v4] Generalize SASL exchange code for the backend and the
 frontend

---
 src/include/libpq/auth.h             |   2 +
 src/include/libpq/sasl.h             | 136 +++++++++++++++++++
 src/include/libpq/scram.h            |  13 +-
 src/backend/libpq/Makefile           |   1 +
 src/backend/libpq/auth-sasl.c        | 196 +++++++++++++++++++++++++++
 src/backend/libpq/auth-scram.c       |  51 ++++---
 src/backend/libpq/auth.c             | 167 +----------------------
 src/interfaces/libpq/fe-auth-sasl.h  | 130 ++++++++++++++++++
 src/interfaces/libpq/fe-auth-scram.c |  40 ++++--
 src/interfaces/libpq/fe-auth.c       |  23 +++-
 src/interfaces/libpq/fe-auth.h       |  11 +-
 src/interfaces/libpq/fe-connect.c    |   6 +-
 src/interfaces/libpq/libpq-int.h     |   2 +
 src/tools/pgindent/typedefs.list     |   2 +
 14 files changed, 558 insertions(+), 222 deletions(-)
 create mode 100644 src/include/libpq/sasl.h
 create mode 100644 src/backend/libpq/auth-sasl.c
 create mode 100644 src/interfaces/libpq/fe-auth-sasl.h

diff --git a/src/include/libpq/auth.h b/src/include/libpq/auth.h
index 3610fae3ff..3d6734f253 100644
--- a/src/include/libpq/auth.h
+++ b/src/include/libpq/auth.h
@@ -21,6 +21,8 @@ extern bool pg_krb_caseins_users;
 extern char *pg_krb_realm;
 
 extern void ClientAuthentication(Port *port);
+extern void sendAuthRequest(Port *port, AuthRequest areq, const char *extradata,
+							int extralen);
 
 /* Hook for plugins to get control in ClientAuthentication() */
 typedef void (*ClientAuthentication_hook_type) (Port *, int);
diff --git a/src/include/libpq/sasl.h b/src/include/libpq/sasl.h
new file mode 100644
index 0000000000..f119a62d68
--- /dev/null
+++ b/src/include/libpq/sasl.h
@@ -0,0 +1,136 @@
+/*-------------------------------------------------------------------------
+ *
+ * sasl.h
+ *	  Defines the SASL mechanism interface for the backend.
+ *
+ * Each SASL mechanism defines a frontend and a backend callback structure.
+ *
+ * See src/interfaces/libpq/fe-auth-sasl.h for the frontend counterpart.
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * src/include/libpq/sasl.h
+ *
+ *-------------------------------------------------------------------------
+ */
+
+#ifndef PG_SASL_H
+#define PG_SASL_H
+
+#include "lib/stringinfo.h"
+#include "libpq/libpq-be.h"
+
+/* Status codes for message exchange */
+#define PG_SASL_EXCHANGE_CONTINUE		0
+#define PG_SASL_EXCHANGE_SUCCESS		1
+#define PG_SASL_EXCHANGE_FAILURE		2
+
+/*
+ * Backend SASL mechanism callbacks.
+ *
+ * To implement a backend mechanism, declare a pg_be_sasl_mech struct with
+ * appropriate callback implementations.  Then pass the mechanism to
+ * CheckSASLAuth() during ClientAuthentication(), once the server has decided
+ * which authentication method to use.
+ */
+typedef struct pg_be_sasl_mech
+{
+	/*---------
+	 * get_mechanisms()
+	 *
+	 * Retrieves the list of SASL mechanism names supported by this
+	 * implementation.
+	 *
+	 * Input parameters:
+	 *
+	 *	port: The client Port
+	 *
+	 * Output parameters:
+	 *
+	 *	buf:  A StringInfo buffer that the callback should populate with
+	 *		  supported mechanism names.  The names are appended into this
+	 *		  StringInfo, separated by '\0' bytes.
+	 *---------
+	 */
+	void		(*get_mechanisms) (Port *port, StringInfo buf);
+
+	/*---------
+	 * init()
+	 *
+	 * Initializes mechanism-specific state for a connection. This callback
+	 * must return a pointer to its allocated state, which will be passed
+	 * as-is as the first argument to the other callbacks.
+	 *
+	 * Input paramters:
+	 *
+	 *	port:        The client Port.
+	 *
+	 *	mech:        The actual mechanism name in use by the client.
+	 *
+	 *	shadow_pass: The stored secret for the role being authenticated, or
+	 *				 NULL if one does not exist.  Mechanisms that do not use
+	 *				 shadow entries may ignore this parameter.  If a
+	 *				 mechanism uses shadow entries but shadow_pass is NULL,
+	 *				 the implementation must continue the exchange as if the
+	 *				 user existed and the password did not match, to avoid
+	 *				 disclosing valid user names.
+	 *---------
+	 */
+	void	   *(*init) (Port *port, const char *mech, const char *shadow_pass);
+
+	/*---------
+	 * exchange()
+	 *
+	 * Produces a server challenge to be sent to the client.  The callback
+	 * must return one of the PG_SASL_EXCHANGE_* values, depending on
+	 * whether the exchange continues, has finished successfully, or has
+	 * failed.
+	 *
+	 * Input parameters:
+	 *
+	 *	state:	  The opaque mechanism state returned by init()
+	 *
+	 *	input:	  The response data sent by the client, or NULL if the
+	 *			  mechanism is client-first but the client did not send an
+	 *			  initial response.  (This can only happen during the first
+	 *			  message from the client.)  This is guaranteed to be
+	 *			  null-terminated for safety, but SASL allows embedded
+	 *			  nulls in responses, so mechanisms must be careful to
+	 *            check inputlen.
+	 *
+	 *	inputlen: The length of the challenge data sent by the server, or
+	 *			  -1 if the client did not send an initial response
+	 *
+	 * Output parameters, to be set by the callback function:
+	 *
+	 *	output:    A palloc'd buffer containing either the server's next
+	 *			   challenge (if PG_SASL_EXCHANGE_CONTINUE is returned) or
+	 *			   the server's outcome data (if PG_SASL_EXCHANGE_SUCCESS is
+	 *			   returned and the mechanism requires data to be sent during
+	 *			   a successful outcome).  The callback should set this to
+	 *			   NULL if the exchange is over and no output should be sent,
+	 *			   which should correspond to either PG_SASL_EXCHANGE_FAILURE
+	 *			   or a PG_SASL_EXCHANGE_SUCCESS with no outcome data.
+	 *
+	 *  outputlen: The length of the challenge data.  Ignored if *output is
+	 *			   NULL.
+	 *
+	 *	logdetail: Set to an optional DETAIL message to be printed to the
+	 *			   server log, to disambiguate failure modes.  (The client
+	 *			   will only ever see the same generic authentication
+	 *			   failure message.) Ignored if the exchange is completed
+	 *			   with PG_SASL_EXCHANGE_SUCCESS.
+	 *---------
+	 */
+	int			(*exchange) (void *state,
+							 const char *input, int inputlen,
+							 char **output, int *outputlen,
+							 char **logdetail);
+} pg_be_sasl_mech;
+
+/* Common implementation for auth.c */
+extern int	CheckSASLAuth(const pg_be_sasl_mech *mech, Port *port,
+						  char *shadow_pass, char **logdetail);
+
+#endif							/* PG_SASL_H */
diff --git a/src/include/libpq/scram.h b/src/include/libpq/scram.h
index 2c879150da..9e4540bde3 100644
--- a/src/include/libpq/scram.h
+++ b/src/include/libpq/scram.h
@@ -15,17 +15,10 @@
 
 #include "lib/stringinfo.h"
 #include "libpq/libpq-be.h"
+#include "libpq/sasl.h"
 
-/* Status codes for message exchange */
-#define SASL_EXCHANGE_CONTINUE		0
-#define SASL_EXCHANGE_SUCCESS		1
-#define SASL_EXCHANGE_FAILURE		2
-
-/* Routines dedicated to authentication */
-extern void pg_be_scram_get_mechanisms(Port *port, StringInfo buf);
-extern void *pg_be_scram_init(Port *port, const char *selected_mech, const char *shadow_pass);
-extern int	pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
-								 char **output, int *outputlen, char **logdetail);
+/* Implementation */
+extern const pg_be_sasl_mech pg_be_scram_mech;
 
 /* Routines to handle and check SCRAM-SHA-256 secret */
 extern char *pg_be_scram_build_secret(const char *password);
diff --git a/src/backend/libpq/Makefile b/src/backend/libpq/Makefile
index 8d1d16b0fc..6d385fd6a4 100644
--- a/src/backend/libpq/Makefile
+++ b/src/backend/libpq/Makefile
@@ -15,6 +15,7 @@ include $(top_builddir)/src/Makefile.global
 # be-fsstubs is here for historical reasons, probably belongs elsewhere
 
 OBJS = \
+	auth-sasl.o \
 	auth-scram.o \
 	auth.o \
 	be-fsstubs.o \
diff --git a/src/backend/libpq/auth-sasl.c b/src/backend/libpq/auth-sasl.c
new file mode 100644
index 0000000000..ed04c3b5b0
--- /dev/null
+++ b/src/backend/libpq/auth-sasl.c
@@ -0,0 +1,196 @@
+/*-------------------------------------------------------------------------
+ *
+ * auth-sasl.c
+ *	  Routines to handle authentication via SASL
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ *
+ * IDENTIFICATION
+ *	  src/backend/libpq/auth-sasl.c
+ *
+ *-------------------------------------------------------------------------
+ */
+
+#include "postgres.h"
+
+#include "libpq/auth.h"
+#include "libpq/libpq.h"
+#include "libpq/pqformat.h"
+#include "libpq/sasl.h"
+
+/*
+ * Maximum accepted size of SASL messages.
+ *
+ * The messages that the server or libpq generate are much smaller than this,
+ * but have some headroom.
+ */
+#define PG_MAX_SASL_MESSAGE_LENGTH	1024
+
+/*
+ * Perform a SASL exchange with a libpq client, using a specific mechanism
+ * implementation.
+ *
+ * shadow_pass is an optional pointer to the stored secret of the role
+ * authenticated, from pg_authid.rolpassword.  For mechanisms that use
+ * shadowed passwords, a NULL pointer here means that an entry could not
+ * be found for the role (or the user does not exist), and the mechanism
+ * should fail the authentication exchange.
+ *
+ * Mechanisms must take care not to reveal to the client that a user entry
+ * does not exist; ideally, the external failure mode is identical to that
+ * of an incorrect password.  Mechanisms may instead use the logdetail
+ * output parameter to internally differentiate between failure cases and
+ * assist debugging by the server admin.
+ *
+ * A mechanism is not required to utilize a shadow entry, or even a password
+ * system at all; for these cases, shadow_pass may be ignored and the caller
+ * should just pass NULL.
+ */
+int
+CheckSASLAuth(const pg_be_sasl_mech *mech, Port *port, char *shadow_pass,
+			  char **logdetail)
+{
+	StringInfoData sasl_mechs;
+	int			mtype;
+	StringInfoData buf;
+	void	   *opaq = NULL;
+	char	   *output = NULL;
+	int			outputlen = 0;
+	const char *input;
+	int			inputlen;
+	int			result;
+	bool		initial;
+
+	/*
+	 * Send the SASL authentication request to user.  It includes the list of
+	 * authentication mechanisms that are supported.
+	 */
+	initStringInfo(&sasl_mechs);
+
+	mech->get_mechanisms(port, &sasl_mechs);
+	/* Put another '\0' to mark that list is finished. */
+	appendStringInfoChar(&sasl_mechs, '\0');
+
+	sendAuthRequest(port, AUTH_REQ_SASL, sasl_mechs.data, sasl_mechs.len);
+	pfree(sasl_mechs.data);
+
+	/*
+	 * Loop through SASL message exchange.  This exchange can consist of
+	 * multiple messages sent in both directions.  First message is always
+	 * from the client.  All messages from client to server are password
+	 * packets (type 'p').
+	 */
+	initial = true;
+	do
+	{
+		pq_startmsgread();
+		mtype = pq_getbyte();
+		if (mtype != 'p')
+		{
+			/* Only log error if client didn't disconnect. */
+			if (mtype != EOF)
+			{
+				ereport(ERROR,
+						(errcode(ERRCODE_PROTOCOL_VIOLATION),
+						 errmsg("expected SASL response, got message type %d",
+								mtype)));
+			}
+			else
+				return STATUS_EOF;
+		}
+
+		/* Get the actual SASL message */
+		initStringInfo(&buf);
+		if (pq_getmessage(&buf, PG_MAX_SASL_MESSAGE_LENGTH))
+		{
+			/* EOF - pq_getmessage already logged error */
+			pfree(buf.data);
+			return STATUS_ERROR;
+		}
+
+		elog(DEBUG4, "processing received SASL response of length %d", buf.len);
+
+		/*
+		 * The first SASLInitialResponse message is different from the others.
+		 * It indicates which SASL mechanism the client selected, and contains
+		 * an optional Initial Client Response payload.  The subsequent
+		 * SASLResponse messages contain just the SASL payload.
+		 */
+		if (initial)
+		{
+			const char *selected_mech;
+
+			selected_mech = pq_getmsgrawstring(&buf);
+
+			/*
+			 * Initialize the status tracker for message exchanges.
+			 *
+			 * If the user doesn't exist, or doesn't have a valid password, or
+			 * it's expired, we still go through the motions of SASL
+			 * authentication, but tell the authentication method that the
+			 * authentication is "doomed". That is, it's going to fail, no
+			 * matter what.
+			 *
+			 * This is because we don't want to reveal to an attacker what
+			 * usernames are valid, nor which users have a valid password.
+			 */
+			opaq = mech->init(port, selected_mech, shadow_pass);
+
+			inputlen = pq_getmsgint(&buf, 4);
+			if (inputlen == -1)
+				input = NULL;
+			else
+				input = pq_getmsgbytes(&buf, inputlen);
+
+			initial = false;
+		}
+		else
+		{
+			inputlen = buf.len;
+			input = pq_getmsgbytes(&buf, buf.len);
+		}
+		pq_getmsgend(&buf);
+
+		/*
+		 * The StringInfo guarantees that there's a \0 byte after the
+		 * response.
+		 */
+		Assert(input == NULL || input[inputlen] == '\0');
+
+		/*
+		 * Hand the incoming message to the mechanism implementation.
+		 */
+		result = mech->exchange(opaq, input, inputlen,
+								&output, &outputlen,
+								logdetail);
+
+		/* input buffer no longer used */
+		pfree(buf.data);
+
+		if (output)
+		{
+			/*
+			 * Negotiation generated data to be sent to the client.
+			 */
+			elog(DEBUG4, "sending SASL challenge of length %u", outputlen);
+
+			/* TODO: PG_SASL_EXCHANGE_FAILURE with output is forbidden in SASL */
+			if (result == PG_SASL_EXCHANGE_SUCCESS)
+				sendAuthRequest(port, AUTH_REQ_SASL_FIN, output, outputlen);
+			else
+				sendAuthRequest(port, AUTH_REQ_SASL_CONT, output, outputlen);
+
+			pfree(output);
+		}
+	} while (result == PG_SASL_EXCHANGE_CONTINUE);
+
+	/* Oops, Something bad happened */
+	if (result != PG_SASL_EXCHANGE_SUCCESS)
+	{
+		return STATUS_ERROR;
+	}
+
+	return STATUS_OK;
+}
diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c
index f9e1026a12..9df8f17837 100644
--- a/src/backend/libpq/auth-scram.c
+++ b/src/backend/libpq/auth-scram.c
@@ -101,11 +101,25 @@
 #include "common/sha2.h"
 #include "libpq/auth.h"
 #include "libpq/crypt.h"
+#include "libpq/sasl.h"
 #include "libpq/scram.h"
 #include "miscadmin.h"
 #include "utils/builtins.h"
 #include "utils/timestamp.h"
 
+static void scram_get_mechanisms(Port *port, StringInfo buf);
+static void *scram_init(Port *port, const char *selected_mech,
+						const char *shadow_pass);
+static int	scram_exchange(void *opaq, const char *input, int inputlen,
+						   char **output, int *outputlen, char **logdetail);
+
+/* Mechanism declaration */
+const pg_be_sasl_mech pg_be_scram_mech = {
+	scram_get_mechanisms,
+	scram_init,
+	scram_exchange
+};
+
 /*
  * Status data for a SCRAM authentication exchange.  This should be kept
  * internal to this file.
@@ -170,16 +184,14 @@ static char *sanitize_str(const char *s);
 static char *scram_mock_salt(const char *username);
 
 /*
- * pg_be_scram_get_mechanisms
- *
  * Get a list of SASL mechanisms that this module supports.
  *
  * For the convenience of building the FE/BE packet that lists the
  * mechanisms, the names are appended to the given StringInfo buffer,
  * separated by '\0' bytes.
  */
-void
-pg_be_scram_get_mechanisms(Port *port, StringInfo buf)
+static void
+scram_get_mechanisms(Port *port, StringInfo buf)
 {
 	/*
 	 * Advertise the mechanisms in decreasing order of importance.  So the
@@ -199,15 +211,13 @@ pg_be_scram_get_mechanisms(Port *port, StringInfo buf)
 }
 
 /*
- * pg_be_scram_init
- *
  * Initialize a new SCRAM authentication exchange status tracker.  This
  * needs to be called before doing any exchange.  It will be filled later
  * after the beginning of the exchange with authentication information.
  *
  * 'selected_mech' identifies the SASL mechanism that the client selected.
  * It should be one of the mechanisms that we support, as returned by
- * pg_be_scram_get_mechanisms().
+ * scram_get_mechanisms().
  *
  * 'shadow_pass' is the role's stored secret, from pg_authid.rolpassword.
  * The username was provided by the client in the startup message, and is
@@ -215,10 +225,8 @@ pg_be_scram_get_mechanisms(Port *port, StringInfo buf)
  * an authentication exchange, but it will fail, as if an incorrect password
  * was given.
  */
-void *
-pg_be_scram_init(Port *port,
-				 const char *selected_mech,
-				 const char *shadow_pass)
+static void *
+scram_init(Port *port, const char *selected_mech, const char *shadow_pass)
 {
 	scram_state *state;
 	bool		got_secret;
@@ -325,9 +333,9 @@ pg_be_scram_init(Port *port,
  * string at *logdetail that will be sent to the postmaster log (but not
  * the client).
  */
-int
-pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
-					 char **output, int *outputlen, char **logdetail)
+static int
+scram_exchange(void *opaq, const char *input, int inputlen,
+			   char **output, int *outputlen, char **logdetail)
 {
 	scram_state *state = (scram_state *) opaq;
 	int			result;
@@ -346,7 +354,7 @@ pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
 
 		*output = pstrdup("");
 		*outputlen = 0;
-		return SASL_EXCHANGE_CONTINUE;
+		return PG_SASL_EXCHANGE_CONTINUE;
 	}
 
 	/*
@@ -379,7 +387,7 @@ pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
 			*output = build_server_first_message(state);
 
 			state->state = SCRAM_AUTH_SALT_SENT;
-			result = SASL_EXCHANGE_CONTINUE;
+			result = PG_SASL_EXCHANGE_CONTINUE;
 			break;
 
 		case SCRAM_AUTH_SALT_SENT:
@@ -408,7 +416,8 @@ pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
 			 * erroring out in an application-specific way.  We choose to do
 			 * the latter, so that the error message for invalid password is
 			 * the same for all authentication methods.  The caller will call
-			 * ereport(), when we return SASL_EXCHANGE_FAILURE with no output.
+			 * ereport(), when we return PG_SASL_EXCHANGE_FAILURE with no
+			 * output.
 			 *
 			 * NB: the order of these checks is intentional.  We calculate the
 			 * client proof even in a mock authentication, even though it's
@@ -417,7 +426,7 @@ pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
 			 */
 			if (!verify_client_proof(state) || state->doomed)
 			{
-				result = SASL_EXCHANGE_FAILURE;
+				result = PG_SASL_EXCHANGE_FAILURE;
 				break;
 			}
 
@@ -425,16 +434,16 @@ pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
 			*output = build_server_final_message(state);
 
 			/* Success! */
-			result = SASL_EXCHANGE_SUCCESS;
+			result = PG_SASL_EXCHANGE_SUCCESS;
 			state->state = SCRAM_AUTH_FINISHED;
 			break;
 
 		default:
 			elog(ERROR, "invalid SCRAM exchange state");
-			result = SASL_EXCHANGE_FAILURE;
+			result = PG_SASL_EXCHANGE_FAILURE;
 	}
 
-	if (result == SASL_EXCHANGE_FAILURE && state->logdetail && logdetail)
+	if (result == PG_SASL_EXCHANGE_FAILURE && state->logdetail && logdetail)
 		*logdetail = state->logdetail;
 
 	if (*output)
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
index 967b5ef73c..8cc23ef7fb 100644
--- a/src/backend/libpq/auth.c
+++ b/src/backend/libpq/auth.c
@@ -26,11 +26,11 @@
 #include "commands/user.h"
 #include "common/ip.h"
 #include "common/md5.h"
-#include "common/scram-common.h"
 #include "libpq/auth.h"
 #include "libpq/crypt.h"
 #include "libpq/libpq.h"
 #include "libpq/pqformat.h"
+#include "libpq/sasl.h"
 #include "libpq/scram.h"
 #include "miscadmin.h"
 #include "port/pg_bswap.h"
@@ -45,8 +45,6 @@
  * Global authentication functions
  *----------------------------------------------------------------
  */
-static void sendAuthRequest(Port *port, AuthRequest areq, const char *extradata,
-							int extralen);
 static void auth_failed(Port *port, int status, char *logdetail);
 static char *recv_password_packet(Port *port);
 static void set_authn_id(Port *port, const char *id);
@@ -60,7 +58,6 @@ static int	CheckPasswordAuth(Port *port, char **logdetail);
 static int	CheckPWChallengeAuth(Port *port, char **logdetail);
 
 static int	CheckMD5Auth(Port *port, char *shadow_pass, char **logdetail);
-static int	CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail);
 
 
 /*----------------------------------------------------------------
@@ -224,14 +221,6 @@ static int	PerformRadiusTransaction(const char *server, const char *secret, cons
  */
 #define PG_MAX_AUTH_TOKEN_LENGTH	65535
 
-/*
- * Maximum accepted size of SASL messages.
- *
- * The messages that the server or libpq generate are much smaller than this,
- * but have some headroom.
- */
-#define PG_MAX_SASL_MESSAGE_LENGTH	1024
-
 /*----------------------------------------------------------------
  * Global authentication functions
  *----------------------------------------------------------------
@@ -668,7 +657,7 @@ ClientAuthentication(Port *port)
 /*
  * Send an authentication request packet to the frontend.
  */
-static void
+void
 sendAuthRequest(Port *port, AuthRequest areq, const char *extradata, int extralen)
 {
 	StringInfoData buf;
@@ -848,12 +837,14 @@ CheckPWChallengeAuth(Port *port, char **logdetail)
 	 * SCRAM secret, we must do SCRAM authentication.
 	 *
 	 * If MD5 authentication is not allowed, always use SCRAM.  If the user
-	 * had an MD5 password, CheckSCRAMAuth() will fail.
+	 * had an MD5 password, CheckSASLAuth() with the SCRAM mechanism will
+	 * fail.
 	 */
 	if (port->hba->auth_method == uaMD5 && pwtype == PASSWORD_TYPE_MD5)
 		auth_result = CheckMD5Auth(port, shadow_pass, logdetail);
 	else
-		auth_result = CheckSCRAMAuth(port, shadow_pass, logdetail);
+		auth_result = CheckSASLAuth(&pg_be_scram_mech, port, shadow_pass,
+									logdetail);
 
 	if (shadow_pass)
 		pfree(shadow_pass);
@@ -911,152 +902,6 @@ CheckMD5Auth(Port *port, char *shadow_pass, char **logdetail)
 	return result;
 }
 
-static int
-CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
-{
-	StringInfoData sasl_mechs;
-	int			mtype;
-	StringInfoData buf;
-	void	   *scram_opaq = NULL;
-	char	   *output = NULL;
-	int			outputlen = 0;
-	const char *input;
-	int			inputlen;
-	int			result;
-	bool		initial;
-
-	/*
-	 * Send the SASL authentication request to user.  It includes the list of
-	 * authentication mechanisms that are supported.
-	 */
-	initStringInfo(&sasl_mechs);
-
-	pg_be_scram_get_mechanisms(port, &sasl_mechs);
-	/* Put another '\0' to mark that list is finished. */
-	appendStringInfoChar(&sasl_mechs, '\0');
-
-	sendAuthRequest(port, AUTH_REQ_SASL, sasl_mechs.data, sasl_mechs.len);
-	pfree(sasl_mechs.data);
-
-	/*
-	 * Loop through SASL message exchange.  This exchange can consist of
-	 * multiple messages sent in both directions.  First message is always
-	 * from the client.  All messages from client to server are password
-	 * packets (type 'p').
-	 */
-	initial = true;
-	do
-	{
-		pq_startmsgread();
-		mtype = pq_getbyte();
-		if (mtype != 'p')
-		{
-			/* Only log error if client didn't disconnect. */
-			if (mtype != EOF)
-			{
-				ereport(ERROR,
-						(errcode(ERRCODE_PROTOCOL_VIOLATION),
-						 errmsg("expected SASL response, got message type %d",
-								mtype)));
-			}
-			else
-				return STATUS_EOF;
-		}
-
-		/* Get the actual SASL message */
-		initStringInfo(&buf);
-		if (pq_getmessage(&buf, PG_MAX_SASL_MESSAGE_LENGTH))
-		{
-			/* EOF - pq_getmessage already logged error */
-			pfree(buf.data);
-			return STATUS_ERROR;
-		}
-
-		elog(DEBUG4, "processing received SASL response of length %d", buf.len);
-
-		/*
-		 * The first SASLInitialResponse message is different from the others.
-		 * It indicates which SASL mechanism the client selected, and contains
-		 * an optional Initial Client Response payload.  The subsequent
-		 * SASLResponse messages contain just the SASL payload.
-		 */
-		if (initial)
-		{
-			const char *selected_mech;
-
-			selected_mech = pq_getmsgrawstring(&buf);
-
-			/*
-			 * Initialize the status tracker for message exchanges.
-			 *
-			 * If the user doesn't exist, or doesn't have a valid password, or
-			 * it's expired, we still go through the motions of SASL
-			 * authentication, but tell the authentication method that the
-			 * authentication is "doomed". That is, it's going to fail, no
-			 * matter what.
-			 *
-			 * This is because we don't want to reveal to an attacker what
-			 * usernames are valid, nor which users have a valid password.
-			 */
-			scram_opaq = pg_be_scram_init(port, selected_mech, shadow_pass);
-
-			inputlen = pq_getmsgint(&buf, 4);
-			if (inputlen == -1)
-				input = NULL;
-			else
-				input = pq_getmsgbytes(&buf, inputlen);
-
-			initial = false;
-		}
-		else
-		{
-			inputlen = buf.len;
-			input = pq_getmsgbytes(&buf, buf.len);
-		}
-		pq_getmsgend(&buf);
-
-		/*
-		 * The StringInfo guarantees that there's a \0 byte after the
-		 * response.
-		 */
-		Assert(input == NULL || input[inputlen] == '\0');
-
-		/*
-		 * we pass 'logdetail' as NULL when doing a mock authentication,
-		 * because we should already have a better error message in that case
-		 */
-		result = pg_be_scram_exchange(scram_opaq, input, inputlen,
-									  &output, &outputlen,
-									  logdetail);
-
-		/* input buffer no longer used */
-		pfree(buf.data);
-
-		if (output)
-		{
-			/*
-			 * Negotiation generated data to be sent to the client.
-			 */
-			elog(DEBUG4, "sending SASL challenge of length %u", outputlen);
-
-			if (result == SASL_EXCHANGE_SUCCESS)
-				sendAuthRequest(port, AUTH_REQ_SASL_FIN, output, outputlen);
-			else
-				sendAuthRequest(port, AUTH_REQ_SASL_CONT, output, outputlen);
-
-			pfree(output);
-		}
-	} while (result == SASL_EXCHANGE_CONTINUE);
-
-	/* Oops, Something bad happened */
-	if (result != SASL_EXCHANGE_SUCCESS)
-	{
-		return STATUS_ERROR;
-	}
-
-	return STATUS_OK;
-}
-
 
 /*----------------------------------------------------------------
  * GSSAPI authentication system
diff --git a/src/interfaces/libpq/fe-auth-sasl.h b/src/interfaces/libpq/fe-auth-sasl.h
new file mode 100644
index 0000000000..c8ba3bc7cc
--- /dev/null
+++ b/src/interfaces/libpq/fe-auth-sasl.h
@@ -0,0 +1,130 @@
+/*-------------------------------------------------------------------------
+ *
+ * fe-auth-sasl.h
+ *	  Defines the SASL mechanism interface for libpq.
+ *
+ * Each SASL mechanism defines a frontend and a backend callback structure.
+ * This is not part of the public API for applications.
+ *
+ * See src/include/libpq/sasl.h for the backend counterpart.
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * src/interfaces/libpq/fe-auth-sasl.h
+ *
+ *-------------------------------------------------------------------------
+ */
+
+#ifndef FE_AUTH_SASL_H
+#define FE_AUTH_SASL_H
+
+#include "libpq-fe.h"
+
+/*
+ * Frontend SASL mechanism callbacks.
+ *
+ * To implement a frontend mechanism, declare a pg_be_sasl_mech struct with
+ * appropriate callback implementations, then hook it into conn->sasl during
+ * pg_SASL_init()'s mechanism negotiation.
+ */
+typedef struct pg_fe_sasl_mech
+{
+	/*-------
+	 * init()
+	 *
+	 * Initializes mechanism-specific state for a connection.  This
+	 * callback must return a pointer to its allocated state, which will
+	 * be passed as-is as the first argument to the other callbacks.
+	 * free() is called to release any state resources.
+	 *
+	 * If state allocation fails, the implementation should return NULL to
+	 * fail the authentication exchange.
+	 *
+	 * Input parameters:
+	 *
+	 *   conn:     The connection to the server
+	 *
+	 *   password: The user's supplied password for the current connection
+	 *
+	 *   mech:     The mechanism name in use, for implementations that may
+	 *			   advertise more than one name (such as *-PLUS variants).
+	 *-------
+	 */
+	void	   *(*init) (PGconn *conn, const char *password, const char *mech);
+
+	/*--------
+	 * exchange()
+	 *
+	 * Produces a client response to a server challenge.  As a special case
+	 * for client-first SASL mechanisms, exchange() is called with a NULL
+	 * server response once at the start of the authentication exchange to
+	 * generate an initial response.
+	 *
+	 * Input parameters:
+	 *
+	 *	state:	   The opaque mechanism state returned by init()
+	 *
+	 *	input:	   The challenge data sent by the server, or NULL when
+	 *			   generating a client-first initial response (that is, when
+	 *			   the server expects the client to send a message to start
+	 *			   the exchange).  This is guaranteed to be null-terminated
+	 *			   for safety, but SASL allows embedded nulls in challenges,
+	 *			   so mechanisms must be careful to check inputlen.
+	 *
+	 *	inputlen:  The length of the challenge data sent by the server, or -1
+	 *             during client-first initial response generation.
+	 *
+	 * Output parameters, to be set by the callback function:
+	 *
+	 *	output:	   A malloc'd buffer containing the client's response to
+	 *			   the server, or NULL if the exchange should be aborted.
+	 *			   (*success should be set to false in the latter case.)
+	 *
+	 *	outputlen: The length of the client response buffer, or zero if no
+	 *			   data should be sent due to an exchange failure
+	 *
+	 *	done:      Set to true if the SASL exchange should not continue,
+	 *			   because the exchange is either complete or failed
+	 *
+	 *	success:   set to true if the SASL exchange completed successfully.
+	 *			   Ignored if *done is false.
+	 *--------
+	 */
+	void		(*exchange) (void *state, char *input, int inputlen,
+							 char **output, int *outputlen,
+							 bool *done, bool *success);
+
+	/*--------
+	 * channel_bound()
+	 *
+	 * Returns true if the connection has an established channel binding.  A
+	 * mechanism implementation must ensure that a SASL exchange has actually
+	 * been completed, in addition to checking that channel binding is in use.
+	 *
+	 * Mechanisms that do not implement channel binding may simply return
+	 * false.
+	 *
+	 * Input parameters:
+	 *
+	 *	state:    The opaque mechanism state returned by init()
+	 *--------
+	 */
+	bool		(*channel_bound) (void *state);
+
+	/*--------
+	 * free()
+	 *
+	 * Frees the state allocated by init(). This is called when the connection
+	 * is dropped, not when the exchange is completed.
+	 *
+	 * Input parameters:
+	 *
+	 *   state:    The opaque mechanism state returned by init()
+	 *--------
+	 */
+	void		(*free) (void *state);
+
+} pg_fe_sasl_mech;
+
+#endif							/* FE_AUTH_SASL_H */
diff --git a/src/interfaces/libpq/fe-auth-scram.c b/src/interfaces/libpq/fe-auth-scram.c
index 5881386e37..4337e89ce9 100644
--- a/src/interfaces/libpq/fe-auth-scram.c
+++ b/src/interfaces/libpq/fe-auth-scram.c
@@ -21,6 +21,22 @@
 #include "fe-auth.h"
 
 
+/* The exported SCRAM callback mechanism. */
+static void *scram_init(PGconn *conn, const char *password,
+						const char *sasl_mechanism);
+static void scram_exchange(void *opaq, char *input, int inputlen,
+						   char **output, int *outputlen,
+						   bool *done, bool *success);
+static bool scram_channel_bound(void *opaq);
+static void scram_free(void *opaq);
+
+const pg_fe_sasl_mech pg_scram_mech = {
+	scram_init,
+	scram_exchange,
+	scram_channel_bound,
+	scram_free
+};
+
 /*
  * Status of exchange messages used for SCRAM authentication via the
  * SASL protocol.
@@ -72,10 +88,10 @@ static bool calculate_client_proof(fe_scram_state *state,
 /*
  * Initialize SCRAM exchange status.
  */
-void *
-pg_fe_scram_init(PGconn *conn,
-				 const char *password,
-				 const char *sasl_mechanism)
+static void *
+scram_init(PGconn *conn,
+		   const char *password,
+		   const char *sasl_mechanism)
 {
 	fe_scram_state *state;
 	char	   *prep_password;
@@ -128,8 +144,8 @@ pg_fe_scram_init(PGconn *conn,
  * Note that the caller must also ensure that the exchange was actually
  * successful.
  */
-bool
-pg_fe_scram_channel_bound(void *opaq)
+static bool
+scram_channel_bound(void *opaq)
 {
 	fe_scram_state *state = (fe_scram_state *) opaq;
 
@@ -152,8 +168,8 @@ pg_fe_scram_channel_bound(void *opaq)
 /*
  * Free SCRAM exchange status
  */
-void
-pg_fe_scram_free(void *opaq)
+static void
+scram_free(void *opaq)
 {
 	fe_scram_state *state = (fe_scram_state *) opaq;
 
@@ -188,10 +204,10 @@ pg_fe_scram_free(void *opaq)
 /*
  * Exchange a SCRAM message with backend.
  */
-void
-pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
-					 char **output, int *outputlen,
-					 bool *done, bool *success)
+static void
+scram_exchange(void *opaq, char *input, int inputlen,
+			   char **output, int *outputlen,
+			   bool *done, bool *success)
 {
 	fe_scram_state *state = (fe_scram_state *) opaq;
 	PGconn	   *conn = state->conn;
diff --git a/src/interfaces/libpq/fe-auth.c b/src/interfaces/libpq/fe-auth.c
index e8062647e6..65462f912a 100644
--- a/src/interfaces/libpq/fe-auth.c
+++ b/src/interfaces/libpq/fe-auth.c
@@ -41,6 +41,7 @@
 #include "common/md5.h"
 #include "common/scram-common.h"
 #include "fe-auth.h"
+#include "fe-auth-sasl.h"
 #include "libpq-fe.h"
 
 #ifdef ENABLE_GSS
@@ -482,7 +483,10 @@ pg_SASL_init(PGconn *conn, int payloadlen)
 				 * channel_binding is not disabled.
 				 */
 				if (conn->channel_binding[0] != 'd')	/* disable */
+				{
 					selected_mechanism = SCRAM_SHA_256_PLUS_NAME;
+					conn->sasl = &pg_scram_mech;
+				}
 #else
 				/*
 				 * The client does not support channel binding.  If it is
@@ -516,7 +520,10 @@ pg_SASL_init(PGconn *conn, int payloadlen)
 		}
 		else if (strcmp(mechanism_buf.data, SCRAM_SHA_256_NAME) == 0 &&
 				 !selected_mechanism)
+		{
 			selected_mechanism = SCRAM_SHA_256_NAME;
+			conn->sasl = &pg_scram_mech;
+		}
 	}
 
 	if (!selected_mechanism)
@@ -555,20 +562,22 @@ pg_SASL_init(PGconn *conn, int payloadlen)
 		goto error;
 	}
 
+	Assert(conn->sasl);
+
 	/*
 	 * Initialize the SASL state information with all the information gathered
 	 * during the initial exchange.
 	 *
 	 * Note: Only tls-unique is supported for the moment.
 	 */
-	conn->sasl_state = pg_fe_scram_init(conn,
+	conn->sasl_state = conn->sasl->init(conn,
 										password,
 										selected_mechanism);
 	if (!conn->sasl_state)
 		goto oom_error;
 
 	/* Get the mechanism-specific Initial Client Response, if any */
-	pg_fe_scram_exchange(conn->sasl_state,
+	conn->sasl->exchange(conn->sasl_state,
 						 NULL, -1,
 						 &initialresponse, &initialresponselen,
 						 &done, &success);
@@ -649,7 +658,7 @@ pg_SASL_continue(PGconn *conn, int payloadlen, bool final)
 	/* For safety and convenience, ensure the buffer is NULL-terminated. */
 	challenge[payloadlen] = '\0';
 
-	pg_fe_scram_exchange(conn->sasl_state,
+	conn->sasl->exchange(conn->sasl_state,
 						 challenge, payloadlen,
 						 &output, &outputlen,
 						 &done, &success);
@@ -664,6 +673,12 @@ pg_SASL_continue(PGconn *conn, int payloadlen, bool final)
 							 libpq_gettext("AuthenticationSASLFinal received from server, but SASL authentication was not completed\n"));
 		return STATUS_ERROR;
 	}
+
+	/*
+	 * TODO SASL requires us to accomodate zero-length responses. TODO is it
+	 * legal for a client not to send a response to a server challenge, if the
+	 * exchange isn't being aborted?
+	 */
 	if (outputlen != 0)
 	{
 		/*
@@ -830,7 +845,7 @@ check_expected_areq(AuthRequest areq, PGconn *conn)
 			case AUTH_REQ_SASL_FIN:
 				break;
 			case AUTH_REQ_OK:
-				if (!pg_fe_scram_channel_bound(conn->sasl_state))
+				if (!conn->sasl || !conn->sasl->channel_bound(conn->sasl_state))
 				{
 					appendPQExpBufferStr(&conn->errorMessage,
 										 libpq_gettext("channel binding required, but server authenticated client without channel binding\n"));
diff --git a/src/interfaces/libpq/fe-auth.h b/src/interfaces/libpq/fe-auth.h
index 7877dcbd09..63927480ee 100644
--- a/src/interfaces/libpq/fe-auth.h
+++ b/src/interfaces/libpq/fe-auth.h
@@ -22,15 +22,8 @@
 extern int	pg_fe_sendauth(AuthRequest areq, int payloadlen, PGconn *conn);
 extern char *pg_fe_getauthname(PQExpBuffer errorMessage);
 
-/* Prototypes for functions in fe-auth-scram.c */
-extern void *pg_fe_scram_init(PGconn *conn,
-							  const char *password,
-							  const char *sasl_mechanism);
-extern bool pg_fe_scram_channel_bound(void *opaq);
-extern void pg_fe_scram_free(void *opaq);
-extern void pg_fe_scram_exchange(void *opaq, char *input, int inputlen,
-								 char **output, int *outputlen,
-								 bool *done, bool *success);
+/* Mechanisms in fe-auth-scram.c */
+extern const pg_fe_sasl_mech pg_scram_mech;
 extern char *pg_fe_scram_build_secret(const char *password);
 
 #endif							/* FE_AUTH_H */
diff --git a/src/interfaces/libpq/fe-connect.c b/src/interfaces/libpq/fe-connect.c
index fc65e490ef..e950b41374 100644
--- a/src/interfaces/libpq/fe-connect.c
+++ b/src/interfaces/libpq/fe-connect.c
@@ -516,11 +516,7 @@ pqDropConnection(PGconn *conn, bool flushInput)
 #endif
 	if (conn->sasl_state)
 	{
-		/*
-		 * XXX: if support for more authentication mechanisms is added, this
-		 * needs to call the right 'free' function.
-		 */
-		pg_fe_scram_free(conn->sasl_state);
+		conn->sasl->free(conn->sasl_state);
 		conn->sasl_state = NULL;
 	}
 }
diff --git a/src/interfaces/libpq/libpq-int.h b/src/interfaces/libpq/libpq-int.h
index 6b7fd2c267..e9f214b61b 100644
--- a/src/interfaces/libpq/libpq-int.h
+++ b/src/interfaces/libpq/libpq-int.h
@@ -41,6 +41,7 @@
 #include "getaddrinfo.h"
 #include "libpq/pqcomm.h"
 /* include stuff found in fe only */
+#include "fe-auth-sasl.h"
 #include "pqexpbuffer.h"
 
 #ifdef ENABLE_GSS
@@ -500,6 +501,7 @@ struct pg_conn
 	PGresult   *next_result;	/* next result (used in single-row mode) */
 
 	/* Assorted state for SASL, SSL, GSS, etc */
+	const pg_fe_sasl_mech *sasl;
 	void	   *sasl_state;
 
 	/* SSL structures */
diff --git a/src/tools/pgindent/typedefs.list b/src/tools/pgindent/typedefs.list
index 64c06cf952..3067644cec 100644
--- a/src/tools/pgindent/typedefs.list
+++ b/src/tools/pgindent/typedefs.list
@@ -3318,6 +3318,7 @@ pgParameterStatus
 pg_atomic_flag
 pg_atomic_uint32
 pg_atomic_uint64
+pg_be_sasl_mech
 pg_checksum_context
 pg_checksum_raw_context
 pg_checksum_type
@@ -3333,6 +3334,7 @@ pg_enc
 pg_enc2gettext
 pg_enc2name
 pg_encname
+pg_fe_sasl_mech
 pg_funcptr_t
 pg_gssinfo
 pg_hmac_ctx
-- 
2.32.0

