1: a94a11d56c < -: ---------- Add support for custom authentication methods 2: 973f622fea < -: ---------- Add sample extension to test custom auth provider hooks 3: b4a0ab5a4e < -: ---------- Add tests for test_auth_provider extension 4: 49715c3c98 < -: ---------- Add support for "map" and custom auth options 5: 2b2e8d3050 ! 1: c3698bbc3d common/jsonapi: support FRONTEND clients @@ src/common/jsonapi.c: parse_object(JsonLexContext *lex, JsonSemAction *sem) #endif @@ src/common/jsonapi.c: json_lex_string(JsonLexContext *lex) - char *const end = lex->input + lex->input_length; - int hi_surrogate = -1; + return code; \ + } while (0) - if (lex->strval != NULL) - resetStringInfo(lex->strval); @@ src/common/jsonapi.c: json_lex_string(JsonLexContext *lex) Assert(lex->input_length > 0); s = lex->token_start; @@ src/common/jsonapi.c: json_lex_string(JsonLexContext *lex) - return JSON_UNICODE_ESCAPE_FORMAT; - } + else + FAIL_AT_CHAR_END(JSON_UNICODE_ESCAPE_FORMAT); } - if (lex->strval != NULL) + if (lex->parse_strval) @@ src/common/jsonapi.c: json_lex_string(JsonLexContext *lex) + appendPQExpBufferChar(lex->strval, (char) ch); } else - return JSON_UNICODE_HIGH_ESCAPE; + FAIL_AT_CHAR_END(JSON_UNICODE_HIGH_ESCAPE); #endif /* FRONTEND */ } } @@ src/common/jsonapi.c: json_lex_string(JsonLexContext *lex) + else if (lex->parse_strval) { if (hi_surrogate != -1) - return JSON_UNICODE_LOW_SURROGATE; + FAIL_AT_CHAR_END(JSON_UNICODE_LOW_SURROGATE); @@ src/common/jsonapi.c: json_lex_string(JsonLexContext *lex) case '"': case '\\': @@ src/common/jsonapi.c: json_lex_string(JsonLexContext *lex) + appendStrValChar(lex->strval, '\t'); break; default: - /* Not a valid string escape, so signal error. */ + @@ src/common/jsonapi.c: json_lex_string(JsonLexContext *lex) /* @@ src/common/jsonapi.c: json_lex_string(JsonLexContext *lex) /* * s will be incremented at the top of the loop, so set it to just @@ src/common/jsonapi.c: json_lex_string(JsonLexContext *lex) - if (hi_surrogate != -1) return JSON_UNICODE_LOW_SURROGATE; + } +#ifdef FRONTEND + if (lex->parse_strval && PQExpBufferBroken(lex->strval)) @@ src/common/jsonapi.c: report_parse_error(JsonParseContext ctx, JsonLexContext *l -} - /* - * Construct a detail message for a JSON error. + * Construct an (already translated) detail message for a JSON error. * - * Note that the error message generated by this routine may not be - * palloc'd, making it unsafe for frontend code as there is no way to -- * know if this can be safery pfree'd or not. +- * know if this can be safely pfree'd or not. + * The returned allocation is either static or owned by the JsonLexContext and + * should not be freed. */ @@ src/common/jsonapi.c: report_parse_error(JsonParseContext ctx, JsonLexContext *l return _("\\u0000 cannot be converted to text."); case JSON_UNICODE_ESCAPE_FORMAT: @@ src/common/jsonapi.c: json_errdetail(JsonParseErrorType error, JsonLexContext *lex) - return _("Unicode low surrogate must follow a high surrogate."); + /* note: this case is only reachable in frontend not backend */ + return _("Unicode escape values cannot be used for code point values above 007F when the encoding is not UTF8."); + case JSON_UNICODE_UNTRANSLATABLE: +- /* note: this case is only reachable in backend not frontend */ ++ /* ++ * note: this case is only reachable in backend not frontend. ++ * #ifdef it away so the frontend doesn't try to link against ++ * backend functionality. ++ */ ++#ifndef FRONTEND + return psprintf(_("Unicode escape value could not be translated to the server's encoding %s."), + GetDatabaseEncodingName()); ++#else ++ Assert(false); ++ break; ++#endif + case JSON_UNICODE_HIGH_SURROGATE: + return _("Unicode high surrogate must not follow a high surrogate."); + case JSON_UNICODE_LOW_SURROGATE: +@@ src/common/jsonapi.c: json_errdetail(JsonParseErrorType error, JsonLexContext *lex) + break; } - /* @@ src/include/common/jsonapi.h -#include "lib/stringinfo.h" - - typedef enum + typedef enum JsonTokenType { JSON_TOKEN_INVALID, -@@ src/include/common/jsonapi.h: typedef enum +@@ src/include/common/jsonapi.h: typedef enum JsonParseErrorType JSON_EXPECTED_OBJECT_NEXT, JSON_EXPECTED_STRING, JSON_INVALID_TOKEN, @@ src/include/common/jsonapi.h: typedef enum JSON_UNICODE_CODE_POINT_ZERO, JSON_UNICODE_ESCAPE_FORMAT, JSON_UNICODE_HIGH_ESCAPE, -@@ src/include/common/jsonapi.h: typedef enum - JSON_UNICODE_LOW_SURROGATE +@@ src/include/common/jsonapi.h: typedef enum JsonParseErrorType + JSON_SEM_ACTION_FAILED /* error should already be reported */ } JsonParseErrorType; +/* @@ src/include/common/jsonapi.h: typedef struct JsonLexContext + StrValType *errormsg; } JsonLexContext; - typedef void (*json_struct_action) (void *state); + typedef JsonParseErrorType (*json_struct_action) (void *state); @@ src/include/common/jsonapi.h: extern PGDLLIMPORT JsonSemAction nullSemAction; */ extern JsonParseErrorType json_count_array_elements(JsonLexContext *lex, 6: c18d7da6cc ! 2: 0cd726fd55 libpq: add OAUTHBEARER SASL mechanism @@ Commit message - handle cases where the client has been set up with an issuer and scope, but the Postgres server wants to use something different - improve error debuggability during the OAuth handshake + - migrate JSON parsing to the new JSON_SEM_ACTION_FAILED API convention - ...and more. ## configure ## @@ configure: fi + as_fn_error $? "library 'iddawc' is required for OAuth support" "$LINENO" 5 +fi + ++ # Check for an older spelling of i_get_openid_config ++ for ac_func in i_load_openid_config ++do : ++ ac_fn_c_check_func "$LINENO" "i_load_openid_config" "ac_cv_func_i_load_openid_config" ++if test "x$ac_cv_func_i_load_openid_config" = xyes; then : ++ cat >>confdefs.h <<_ACEOF ++#define HAVE_I_LOAD_OPENID_CONFIG 1 ++_ACEOF ++ ++fi ++done ++ +fi + # for contrib/sepgsql @@ configure.ac: fi +if test "$with_oauth" = yes ; then + AC_CHECK_LIB(iddawc, i_init_session, [], [AC_MSG_ERROR([library 'iddawc' is required for OAuth support])]) ++ # Check for an older spelling of i_get_openid_config ++ AC_CHECK_FUNCS([i_load_openid_config]) +fi + # for contrib/sepgsql @@ configure.ac: elif test "$with_uuid" = ossp ; then AC_CHECK_HEADERS(crtdefs.h) fi + ## meson.build ## +@@ meson.build: endif + + + ++############################################################### ++# Library: oauth ++############################################################### ++ ++oauthopt = get_option('oauth') ++if not oauthopt.disabled() ++ oauth = dependency('libiddawc', required: oauthopt) ++ ++ if oauth.found() ++ cdata.set('USE_OAUTH', 1) ++ # Check for an older spelling of i_get_openid_config ++ if cc.has_function('i_load_openid_config', ++ dependencies: oauth, args: test_c_args) ++ cdata.set('HAVE_I_LOAD_OPENID_CONFIG', 1) ++ endif ++ endif ++else ++ oauth = not_found_dep ++endif ++ ++ + ############################################################### + # Library: Tcl (for pltcl) + # +@@ meson.build: libpq_deps += [ + gssapi, + ldap_r, + libintl, ++ oauth, + ssl, + ] + +@@ meson.build: if meson.version().version_compare('>=0.57') + 'llvm': llvm, + 'lz4': lz4, + 'nls': libintl, ++ 'oauth': oauth, + 'openssl': ssl, + 'pam': pam, + 'plperl': perl_dep, + + ## meson_options.txt ## +@@ meson_options.txt: option('lz4', type : 'feature', value: 'auto', + option('nls', type: 'feature', value: 'auto', + description: 'native language support') + ++option('oauth', type : 'feature', value: 'auto', ++ description: 'OAuth 2.0 support') ++ + option('pam', type : 'feature', value: 'auto', + description: 'build with PAM support') + + ## src/Makefile.global.in ## @@ src/Makefile.global.in: with_ldap = @with_ldap@ with_libxml = @with_libxml@ @@ src/Makefile.global.in: with_ldap = @with_ldap@ with_uuid = @with_uuid@ with_zlib = @with_zlib@ + ## src/common/meson.build ## +@@ src/common/meson.build: common_sources_frontend_static += files( + # For the server build of pgcommon, depend on lwlocknames_h, because at least + # cryptohash_openssl.c, hmac_openssl.c depend on it. That's arguably a + # layering violation, but ... ++# ++# XXX Frontend builds need libpq's pqexpbuffer.h, so adjust the include paths ++# appropriately. This seems completely broken. + pgcommon = {} + pgcommon_variants = { + '_srv': internal_lib_args + { ++ 'include_directories': include_directories('.'), + 'sources': common_sources + [lwlocknames_h], + 'dependencies': [backend_common_code], + }, + '': default_lib_args + { ++ 'include_directories': include_directories('../interfaces/libpq', '.'), + 'sources': common_sources_frontend_static, + 'dependencies': [frontend_common_code], + }, + '_shlib': default_lib_args + { + 'pic': true, ++ 'include_directories': include_directories('../interfaces/libpq', '.'), + 'sources': common_sources_frontend_shlib, + 'dependencies': [frontend_common_code], + }, +@@ src/common/meson.build: foreach name, opts : pgcommon_variants + c_args = opts.get('c_args', []) + common_cflags[cflagname] + cflag_libs += static_library('libpgcommon@0@_@1@'.format(name, cflagname), + c_pch: pch_c_h, +- include_directories: include_directories('.'), + kwargs: opts + { + 'sources': sources, + 'c_args': c_args, +@@ src/common/meson.build: foreach name, opts : pgcommon_variants + lib = static_library('libpgcommon@0@'.format(name), + link_with: cflag_libs, + c_pch: pch_c_h, +- include_directories: include_directories('.'), + kwargs: opts + { + 'dependencies': opts['dependencies'] + [ssl], + } + ## src/include/common/oauth-common.h (new) ## @@ +/*------------------------------------------------------------------------- @@ src/include/common/oauth-common.h (new) +#endif /* OAUTH_COMMON_H */ ## src/include/pg_config.h.in ## +@@ + /* Define to 1 if __builtin_constant_p(x) implies "i"(x) acceptance. */ + #undef HAVE_I_CONSTRAINT__BUILTIN_CONSTANT_P + ++/* Define to 1 if you have the `i_load_openid_config' function. */ ++#undef HAVE_I_LOAD_OPENID_CONFIG ++ + /* Define to 1 if you have the `kqueue' function. */ + #undef HAVE_KQUEUE + @@ /* Define to 1 if you have the `crypto' library (-lcrypto). */ #undef HAVE_LIBCRYPTO @@ src/interfaces/libpq/fe-auth-oauth.c (new) +#include "fe-auth.h" +#include "mb/pg_wchar.h" + ++#ifdef HAVE_I_LOAD_OPENID_CONFIG ++/* Older versions of iddawc used 'load' instead of 'get' for some APIs. */ ++#define i_get_openid_config i_load_openid_config ++#endif ++ +/* The exported OAuth callback mechanism. */ +static void *oauth_init(PGconn *conn, const char *password, + const char *sasl_mechanism); @@ src/interfaces/libpq/fe-auth-oauth.c (new) + (ctx)->errmsg = (ctx)->errbuf.data; \ + } while (0) + -+static void ++static JsonParseErrorType +oauth_json_object_start(void *state) +{ + struct json_ctx *ctx = state; + + if (oauth_json_has_error(ctx)) -+ return; /* short-circuit */ ++ return JSON_SUCCESS; /* short-circuit */ + + if (ctx->target_field) + { @@ src/interfaces/libpq/fe-auth-oauth.c (new) + } + + ++ctx->nested; ++ return JSON_SUCCESS; /* TODO: switch all of these to JSON_SEM_ACTION_FAILED */ +} + -+static void ++static JsonParseErrorType +oauth_json_object_end(void *state) +{ + struct json_ctx *ctx = state; + + if (oauth_json_has_error(ctx)) -+ return; /* short-circuit */ ++ return JSON_SUCCESS; /* short-circuit */ + + --ctx->nested; ++ return JSON_SUCCESS; +} + -+static void ++static JsonParseErrorType +oauth_json_object_field_start(void *state, char *name, bool isnull) +{ + struct json_ctx *ctx = state; @@ src/interfaces/libpq/fe-auth-oauth.c (new) + { + /* short-circuit */ + free(name); -+ return; ++ return JSON_SUCCESS; + } + + if (ctx->nested == 1) @@ src/interfaces/libpq/fe-auth-oauth.c (new) + } + + free(name); ++ return JSON_SUCCESS; +} + -+static void ++static JsonParseErrorType +oauth_json_array_start(void *state) +{ + struct json_ctx *ctx = state; + + if (oauth_json_has_error(ctx)) -+ return; /* short-circuit */ ++ return JSON_SUCCESS; /* short-circuit */ + + if (!ctx->nested) + { @@ src/interfaces/libpq/fe-auth-oauth.c (new) + libpq_gettext("field \"%s\" must be a string"), + ctx->target_field_name); + } ++ ++ return JSON_SUCCESS; +} + -+static void ++static JsonParseErrorType +oauth_json_scalar(void *state, char *token, JsonTokenType type) +{ + struct json_ctx *ctx = state; @@ src/interfaces/libpq/fe-auth-oauth.c (new) + { + /* short-circuit */ + free(token); -+ return; ++ return JSON_SUCCESS; + } + + if (!ctx->nested) @@ src/interfaces/libpq/fe-auth-oauth.c (new) + ctx->target_field = NULL; + ctx->target_field_name = NULL; + -+ return; /* don't free the token we're using */ ++ return JSON_SUCCESS; /* don't free the token we're using */ + } + + oauth_json_set_error(ctx, @@ src/interfaces/libpq/fe-auth-oauth.c (new) + } + + free(token); ++ return JSON_SUCCESS; +} + +static bool @@ src/interfaces/libpq/fe-auth.c: pg_SASL_continue(PGconn *conn, int payloadlen, b &done, &success); ## src/interfaces/libpq/fe-auth.h ## -@@ src/interfaces/libpq/fe-auth.h: extern const pg_fe_sasl_mech pg_scram_mech; - extern char *pg_fe_scram_build_secret(const char *password, +@@ src/interfaces/libpq/fe-auth.h: extern char *pg_fe_scram_build_secret(const char *password, + int iterations, const char **errstr); +/* Mechanisms in fe-auth-oauth.c */ @@ src/interfaces/libpq/fe-auth.h: extern const pg_fe_sasl_mech pg_scram_mech; ## src/interfaces/libpq/fe-connect.c ## @@ src/interfaces/libpq/fe-connect.c: static const internalPQconninfoOption PQconninfoOptions[] = { - "Target-Session-Attrs", "", 15, /* sizeof("prefer-standby") = 15 */ - offsetof(struct pg_conn, target_session_attrs)}, + "Load-Balance-Hosts", "", 8, /* sizeof("disable") = 8 */ + offsetof(struct pg_conn, load_balance_hosts)}, + /* OAuth v2 */ + {"oauth_issuer", NULL, NULL, NULL, @@ src/interfaces/libpq/fe-connect.c: keep_going: /* We will come back to here /* @@ src/interfaces/libpq/fe-connect.c: freePGconn(PGconn *conn) - free(conn->outBuffer); free(conn->rowBuf); free(conn->target_session_attrs); + free(conn->load_balance_hosts); + free(conn->oauth_issuer); + free(conn->oauth_discovery_uri); + free(conn->oauth_client_id); @@ src/interfaces/libpq/fe-connect.c: freePGconn(PGconn *conn) ## src/interfaces/libpq/libpq-int.h ## @@ src/interfaces/libpq/libpq-int.h: struct pg_conn - char *ssl_max_protocol_version; /* maximum TLS protocol version */ - char *target_session_attrs; /* desired session properties */ + char *require_auth; /* name of the expected auth method */ + char *load_balance_hosts; /* load balance over hosts */ + /* OAuth v2 */ + char *oauth_issuer; /* token issuer URL */ @@ src/interfaces/libpq/libpq-int.h: struct pg_conn /* Optional file to write trace info to */ FILE *Pfdebug; int traceFlags; + + ## src/interfaces/libpq/meson.build ## +@@ src/interfaces/libpq/meson.build: if gssapi.found() + ) + endif + ++if oauth.found() ++ libpq_sources += files('fe-auth-oauth.c') ++endif ++ + export_file = custom_target('libpq.exports', + kwargs: gen_export_kwargs, + ) + + ## src/makefiles/meson.build ## +@@ src/makefiles/meson.build: pgxs_deps = { + 'llvm': llvm, + 'lz4': lz4, + 'nls': libintl, ++ 'oauth': oauth, + 'pam': pam, + 'perl': perl_dep, + 'python': python3_dep, 7: f6a81f50f2 ! 3: 77889eb986 backend: add OAUTHBEARER SASL mechanism @@ src/backend/libpq/auth.c #include "libpq/pqformat.h" #include "libpq/sasl.h" #include "libpq/scram.h" +@@ + */ + static void auth_failed(Port *port, int status, const char *logdetail); + static char *recv_password_packet(Port *port); +-static void set_authn_id(Port *port, const char *id); + + + /*---------------------------------------------------------------- +@@ src/backend/libpq/auth.c: static int CheckRADIUSAuth(Port *port); + static int PerformRadiusTransaction(const char *server, const char *secret, const char *portstr, const char *identifier, const char *user_name, const char *passwd); + + +-/* +- * Maximum accepted size of GSS and SSPI authentication tokens. +- * We also use this as a limit on ordinary password packet lengths. +- * +- * Kerberos tickets are usually quite small, but the TGTs issued by Windows +- * domain controllers include an authorization field known as the Privilege +- * Attribute Certificate (PAC), which contains the user's Windows permissions +- * (group memberships etc.). The PAC is copied into all tickets obtained on +- * the basis of this TGT (even those issued by Unix realms which the Windows +- * realm trusts), and can be several kB in size. The maximum token size +- * accepted by Windows systems is determined by the MaxAuthToken Windows +- * registry setting. Microsoft recommends that it is not set higher than +- * 65535 bytes, so that seems like a reasonable limit for us as well. +- */ +-#define PG_MAX_AUTH_TOKEN_LENGTH 65535 +- + /*---------------------------------------------------------------- + * Global authentication functions + *---------------------------------------------------------------- @@ src/backend/libpq/auth.c: auth_failed(Port *port, int status, const char *logdetail) case uaRADIUS: errstr = gettext_noop("RADIUS authentication failed for user \"%s\""); @@ src/backend/libpq/auth.c: auth_failed(Port *port, int status, const char *logdet + case uaOAuth: + errstr = gettext_noop("OAuth bearer authentication failed for user \"%s\""); + break; - case uaCustom: - { - CustomAuthProvider *provider = get_provider_by_name(port->hba->custom_provider); + default: + errstr = gettext_noop("authentication failed for user \"%s\": invalid authentication method"); + break; +@@ src/backend/libpq/auth.c: auth_failed(Port *port, int status, const char *logdetail) + * lifetime of MyClientConnectionInfo, so it is safe to pass a string that is + * managed by an external library. + */ +-static void ++void + set_authn_id(Port *port, const char *id) + { + Assert(id); @@ src/backend/libpq/auth.c: ClientAuthentication(Port *port) case uaTrust: status = STATUS_OK; @@ src/backend/libpq/auth.c: ClientAuthentication(Port *port) + case uaOAuth: + status = CheckSASLAuth(&pg_be_oauth_mech, port, NULL, NULL); + break; - case uaCustom: - { - CustomAuthProvider *provider = get_provider_by_name(port->hba->custom_provider); + } + + if ((status == STATUS_OK && port->hba->clientcert == clientCertFull) ## src/backend/libpq/hba.c ## @@ src/backend/libpq/hba.c: static const char *const UserAuthName[] = + "ldap", "cert", "radius", - "custom", - "peer" + "peer", + "oauth", }; - + /* @@ src/backend/libpq/hba.c: parse_hba_line(TokenizedAuthLine *tok_line, int elevel) #endif else if (strcmp(token->string, "radius") == 0) parsedline->auth_method = uaRADIUS; + else if (strcmp(token->string, "oauth") == 0) + parsedline->auth_method = uaOAuth; - else if (strcmp(token->string, "custom") == 0) - parsedline->auth_method = uaCustom; else + { + ereport(elevel, @@ src/backend/libpq/hba.c: parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline, + hbaline->auth_method != uaPeer && hbaline->auth_method != uaGSS && hbaline->auth_method != uaSSPI && - hbaline->auth_method != uaCert && -+ hbaline->auth_method != uaOAuth && - hbaline->auth_method != uaCustom) -- INVALID_AUTH_OPTION("map", gettext_noop("ident, peer, gssapi, sspi, cert and custom")); -+ INVALID_AUTH_OPTION("map", gettext_noop("ident, peer, gssapi, sspi, cert, oauth, and custom")); +- hbaline->auth_method != uaCert) +- INVALID_AUTH_OPTION("map", gettext_noop("ident, peer, gssapi, sspi, and cert")); ++ hbaline->auth_method != uaCert && ++ hbaline->auth_method != uaOAuth) ++ INVALID_AUTH_OPTION("map", gettext_noop("ident, peer, gssapi, sspi, cert, and oauth")); hbaline->usermap = pstrdup(val); } else if (strcmp(name, "clientcert") == 0) @@ src/backend/libpq/hba.c: parse_hba_auth_opt(char *name, char *val, HbaLine *hbal + else + hbaline->oauth_skip_usermap = false; + } - else if (strcmp(name, "provider") == 0) + else { - REQUIRE_AUTH_OPTION(uaCustom, "provider", "custom"); + ereport(elevel, + + ## src/backend/libpq/meson.build ## +@@ + # Copyright (c) 2022-2023, PostgreSQL Global Development Group + + backend_sources += files( ++ 'auth-oauth.c', + 'auth-sasl.c', + 'auth-scram.c', + 'auth.c', ## src/backend/utils/misc/guc_tables.c ## @@ @@ src/backend/utils/misc/guc_tables.c #include "libpq/auth.h" #include "libpq/libpq.h" +#include "libpq/oauth.h" + #include "libpq/scram.h" + #include "nodes/queryjumble.h" #include "optimizer/cost.h" - #include "optimizer/geqo.h" - #include "optimizer/optimizer.h" @@ src/backend/utils/misc/guc_tables.c: struct config_string ConfigureNamesString[] = - check_backtrace_functions, assign_backtrace_functions, NULL + check_io_direct, assign_io_direct, NULL }, + { + {"oauth_validator_command", PGC_SIGHUP, CONN_AUTH_AUTH, + gettext_noop("Command to validate OAuth v2 bearer tokens."), + NULL, -+ GUC_SUPERUSER_ONLY ++ GUC_SUPERUSER_ONLY | GUC_NOT_IN_SAMPLE + }, + &oauth_validator_command, + "", @@ src/backend/utils/misc/guc_tables.c: struct config_string ConfigureNamesString[] { {NULL, 0, 0, NULL, NULL}, NULL, NULL, NULL, NULL, NULL + ## src/include/libpq/auth.h ## +@@ + + #include "libpq/libpq-be.h" + ++/* ++ * Maximum accepted size of GSS and SSPI authentication tokens. ++ * We also use this as a limit on ordinary password packet lengths. ++ * ++ * Kerberos tickets are usually quite small, but the TGTs issued by Windows ++ * domain controllers include an authorization field known as the Privilege ++ * Attribute Certificate (PAC), which contains the user's Windows permissions ++ * (group memberships etc.). The PAC is copied into all tickets obtained on ++ * the basis of this TGT (even those issued by Unix realms which the Windows ++ * realm trusts), and can be several kB in size. The maximum token size ++ * accepted by Windows systems is determined by the MaxAuthToken Windows ++ * registry setting. Microsoft recommends that it is not set higher than ++ * 65535 bytes, so that seems like a reasonable limit for us as well. ++ */ ++#define PG_MAX_AUTH_TOKEN_LENGTH 65535 ++ + extern PGDLLIMPORT char *pg_krb_server_keyfile; + extern PGDLLIMPORT bool pg_krb_caseins_users; + extern PGDLLIMPORT bool pg_gss_accept_deleg; +@@ src/include/libpq/auth.h: extern PGDLLIMPORT char *pg_krb_realm; + extern void ClientAuthentication(Port *port); + extern void sendAuthRequest(Port *port, AuthRequest areq, const char *extradata, + int extralen); ++extern void set_authn_id(Port *port, const char *id); + + /* Hook for plugins to get control in ClientAuthentication() */ + typedef void (*ClientAuthentication_hook_type) (Port *, int); + ## src/include/libpq/hba.h ## @@ src/include/libpq/hba.h: typedef enum UserAuth + uaLDAP, uaCert, uaRADIUS, - uaCustom, - uaPeer -#define USER_AUTH_LAST uaPeer /* Must be last value of this enum */ + uaPeer, @@ src/include/libpq/hba.h: typedef struct HbaLine + char *oauth_issuer; + char *oauth_scope; + bool oauth_skip_usermap; - char *custom_provider; - List *custom_auth_options; } HbaLine; + + typedef struct IdentLine ## src/include/libpq/oauth.h (new) ## @@ 8: e71df89b8f < -: ---------- Add a very simple authn_id extension 9: 73adeb3645 ! 4: 573a2ca3bc Add pytest suite for OAuth @@ src/test/python/README (new) +A test suite for exercising both the libpq client and the server backend at the +protocol level, based on pytest and Construct. + ++WARNING! This suite takes superuser-level control of the cluster under test, ++writing to the server config, creating and destroying databases, etc. It also ++spins up various ephemeral TCP services. This is not safe for production servers ++and therefore must be explicitly opted into by setting PG_TEST_EXTRA=python in ++the environment. ++ +The test suite currently assumes that the standard PG* environment variables +point to the database under test and are sufficient to log in a superuser on +that system. In other words, a bare `psql` needs to Just Work before the test @@ src/test/python/README (new) + +The first run of + -+ make installcheck ++ make installcheck PG_TEST_EXTRA=python + +will install a local virtual environment and all needed dependencies. During +development, if libpq changes incompatibly, you can issue @@ src/test/python/README (new) +The Makefile is there for convenience, but you don't have to use it. Activate +the virtualenv to be able to use pytest directly: + ++ $ export PG_TEST_EXTRA=python + $ source venv/bin/activate + $ py.test -k oauth + ... @@ src/test/python/client/test_oauth.py (new) + with pytest.raises(psycopg2.OperationalError, match=expected_error): + client.check_completed() + ## src/test/python/conftest.py (new) ## +@@ ++# ++# Copyright 2023 Timescale, Inc. ++# SPDX-License-Identifier: PostgreSQL ++# ++ ++import os ++ ++import pytest ++ ++ ++@pytest.fixture(scope="session", autouse=True) ++def _check_PG_TEST_EXTRA(request): ++ """ ++ Automatically skips the whole suite if PG_TEST_EXTRA doesn't contain ++ 'python'. pytestmark doesn't seem to work in a top-level conftest.py, so ++ I've made this an autoused fixture instead. ++ ++ TODO: there are tests here that are probably safe, but until I do a full ++ analysis on which are and which are not, I've made the entire thing opt-in. ++ """ ++ extra_tests = os.getenv("PG_TEST_EXTRA", "").split() ++ if "python" not in extra_tests: ++ pytest.skip("Potentially unsafe test 'python' not enabled in PG_TEST_EXTRA") + ## src/test/python/pq3.py (new) ## @@ +# @@ src/test/python/pq3.py (new) + +# Pq3 + ++ +# Adapted from construct.core.EnumIntegerString +class EnumNamedByte: + def __init__(self, val, name): @@ src/test/python/pytest.ini (new) ## src/test/python/requirements.txt (new) ## @@ +black -+cryptography~=3.4.6 ++# cryptography 39.x removes a lot of platform support, beware ++cryptography~=38.0.4 +construct~=2.10.61 +isort~=5.6 ++# TODO: update to psycopg[c] 3.1 +psycopg2~=2.8.6 -+pytest~=6.1 -+pytest-asyncio~=0.14.0 ++pytest~=7.3 ++pytest-asyncio~=0.21.0 ## src/test/python/server/__init__.py (new) ## @@ src/test/python/server/conftest.py (new) + +import pq3 + ++BLOCKING_TIMEOUT = 2 # the number of seconds to wait for blocking calls ++ + +@pytest.fixture +def connect(): @@ src/test/python/server/conftest.py (new) + addr = (pq3.pghost(), pq3.pgport()) + + try: -+ sock = socket.create_connection(addr, timeout=2) ++ sock = socket.create_connection(addr, timeout=BLOCKING_TIMEOUT) + except ConnectionError as e: + pytest.skip(f"unable to connect to {addr}: {e}") + @@ src/test/python/server/test_oauth.py (new) @@ +# +# Copyright 2021 VMware, Inc. ++# Portions Copyright 2023 Timescale, Inc. +# SPDX-License-Identifier: PostgreSQL +# + @@ src/test/python/server/test_oauth.py (new) + +import pq3 + ++from .conftest import BLOCKING_TIMEOUT ++ +MAX_SASL_MESSAGE_LENGTH = 65535 + +INVALID_AUTHORIZATION_ERRCODE = b"28000" @@ src/test/python/server/test_oauth.py (new) + +SHARED_MEM_NAME = "oauth-pytest" +MAX_TOKEN_SIZE = 4096 -+MAX_UINT16 = 2 ** 16 - 1 ++MAX_UINT16 = 2**16 - 1 + + +def skip_if_no_postgres(): @@ src/test/python/server/test_oauth.py (new) + addr = (pq3.pghost(), pq3.pgport()) + + try: -+ with socket.create_connection(addr, timeout=2): ++ with socket.create_connection(addr, timeout=BLOCKING_TIMEOUT): + pass + except ConnectionError as e: + pytest.skip(f"unable to connect to {addr}: {e}") @@ src/test/python/server/test_oauth.py (new) + return connect() + + -+@pytest.fixture(scope="module", autouse=True) -+def authn_id_extension(oauth_ctx): -+ """ -+ Performs a `CREATE EXTENSION authn_id` in the test database. This fixture is -+ autoused, so tests don't need to rely on it. -+ """ -+ conn = psycopg2.connect(database=oauth_ctx.dbname) -+ conn.autocommit = True -+ -+ with contextlib.closing(conn): -+ c = conn.cursor() -+ c.execute("CREATE EXTENSION authn_id;") -+ -+ +@pytest.fixture(scope="session") +def shared_mem(): + """ @@ src/test/python/server/test_oauth.py (new) + expect_handshake_success(conn) + + # Make sure that the server has not set an authenticated ID. -+ pq3.send(conn, pq3.types.Query, query=b"SELECT authn_id();") ++ pq3.send(conn, pq3.types.Query, query=b"SELECT system_user;") + resp = receive_until(conn, pq3.types.DataRow) + + row = resp.payload @@ src/test/python/server/test_oauth.py (new) + expect_handshake_success(conn) + + # Check the reported authn_id. -+ pq3.send(conn, pq3.types.Query, query=b"SELECT authn_id();") ++ pq3.send(conn, pq3.types.Query, query=b"SELECT system_user;") + resp = receive_until(conn, pq3.types.DataRow) + ++ expected = authn_id ++ if expected is not None: ++ expected = b"oauth:" + expected ++ + row = resp.payload -+ assert row.columns == [authn_id] ++ assert row.columns == [expected] + + +class ExpectedError(object): @@ src/test/python/server/test_oauth.py (new) + expect_handshake_success(conn) + + # Check the user identity. -+ pq3.send(conn, pq3.types.Query, query=b"SELECT authn_id();") ++ pq3.send(conn, pq3.types.Query, query=b"SELECT system_user;") + resp = receive_until(conn, pq3.types.DataRow) + + row = resp.payload -+ expected = oauth_ctx.user.encode("utf-8") ++ expected = b"oauth:" + oauth_ctx.user.encode("utf-8") + assert row.columns == [expected] + + @@ src/test/python/server/validate_bearer.py (new) +import sys +from multiprocessing import shared_memory + -+MAX_UINT16 = 2 ** 16 - 1 ++MAX_UINT16 = 2**16 - 1 + + +def remove_shm_from_resource_tracker(): 10: ab32128f34 < -: ---------- contrib/oauth: switch to pluggable auth API -: ---------- > 5: 4490d029b5 squash! Add pytest suite for OAuth