#include "postgres.h"
#include "fmgr.h"
#include "libpq-fe.h"
#include "miscadmin.h"
#include "pgstat.h"

#include "access/xact.h"
#include "access/xlog.h"
#include "access/xlogdefs.h"
#include "lib/stringinfo.h"
#include "libpq/pqformat.h"
#include "storage/latch.h"
#include "storage/proc.h"
#include "tcop/pquery.h"
#include "tcop/utility.h"
#include "utils/builtins.h"
#include "utils/memutils.h"
#include "utils/snapmgr.h"

/*
 * Name used for the replication slot in the CREATE_REPLICATION_SLOT and
 * START_REPLICATION commands; it is also passed as the connection name
 * (application_name) when connecting.
 */
#define DDL_SLOT_NAME "pg_ddl_get_changes"

/* SQL-callable functions defined in this file. */
PG_FUNCTION_INFO_V1(pg_ddl_decode_get_changes);
PG_FUNCTION_INFO_V1(pg_ddl_decode_create_slot);

/*
 * Association between a local commit and the remote position it corresponds
 * to.  handle_commit() appends one entry per committed transaction and
 * get_flush_position() removes entries once they are locally flushed.
 */
typedef struct PGLFlushPosition
{
	/* list link in lsn_mapping */
	dlist_node node;
	/* value of XactLastCommitEnd right after the local commit */
	XLogRecPtr local_end;
	/* end LSN read from the remote COMMIT message */
	XLogRecPtr remote_end;
} PGLFlushPosition;

/*
 * List of PGLFlushPosition entries, in commit order (entries are pushed at
 * the tail).
 */
dlist_head lsn_mapping = DLIST_STATIC_INIT(lsn_mapping);

/* Forward declarations for helpers used by the SQL-callable functions. */
static PGconn *pglogical_connect_replica(const char *connstring,
										 const char *connname);
static void pglogical_start_replication(PGconn *streamConn,
										const char *slot_name);

static void apply_work(PGconn *streamConn);

/*
 * SQL-callable entry point: connect to the server described by the
 * connection string in argument 0, start logical replication on the
 * DDL_SLOT_NAME slot and process the stream with apply_work().
 *
 * Returns true once apply_work() returns; failures are raised as ERROR.
 *
 * The libpq connection is not managed by any memory context, so it is closed
 * explicitly on the error path as well.  Otherwise an ERROR (including query
 * cancel) would leave the replication connection open for the rest of the
 * session.
 */
Datum
pg_ddl_decode_get_changes(PG_FUNCTION_ARGS)
{
	char		   *provider_dsn = text_to_cstring(PG_GETARG_TEXT_PP(0));
	PGconn		   *streamConn;

	streamConn = pglogical_connect_replica(provider_dsn, DDL_SLOT_NAME);

	PG_TRY();
	{
		pglogical_start_replication(streamConn, DDL_SLOT_NAME);
		apply_work(streamConn);
	}
	PG_CATCH();
	{
		/* release the connection before propagating the error */
		PQfinish(streamConn);
		PG_RE_THROW();
	}
	PG_END_TRY();

	PQfinish(streamConn);

	PG_RETURN_BOOL(true);
}

Datum
pg_ddl_decode_create_slot(PG_FUNCTION_ARGS)
{
	char		   *provider_dsn = text_to_cstring(PG_GETARG_TEXT_PP(0));
	PGconn		   *streamConn;
	PGresult	   *res;
	StringInfoData	query;

	streamConn = pglogical_connect_replica(provider_dsn, DDL_SLOT_NAME);

	initStringInfo(&query);
	appendStringInfo(&query, "CREATE_REPLICATION_SLOT \"%s\" LOGICAL %s",
					 DDL_SLOT_NAME, "pg_ddl_decode");

	res = PQexec(streamConn, query.data);
	if (PQresultStatus(res) != PGRES_TUPLES_OK)
	{
		elog(ERROR, "could not send replication command \"%s\": status %s: %s\n",
			 query.data,
			 PQresStatus(PQresultStatus(res)), PQresultErrorMessage(res));
	}

	PQclear(res);
	PQfinish(streamConn);

	PG_RETURN_BOOL(true);
}

/*
 * Make replication connection, ERROR on failure.
 *
 * 'connstring' is passed as the "dbname" keyword with expand_dbname enabled,
 * and 'connname' is sent as application_name.  The connection is opened with
 * replication=database.
 *
 * Returns an open connection that the caller must close with PQfinish().
 * On failure the half-open connection object is released here before the
 * error is raised.
 */
static PGconn *
pglogical_connect_replica(const char *connstring, const char *connname)
{
	PGconn		   *conn;
	const char	   *keys[4];
	const char	   *vals[4];

	/*
	 * We use the expand_dbname parameter to process the connection string
	 * (or URI), and pass some extra options.
	 */
	keys[0] = "dbname";
	vals[0] = connstring;
	keys[1] = "replication";
	vals[1] = "database";
	keys[2] = "application_name";
	vals[2] = connname;
	keys[3] = NULL;
	vals[3] = NULL;

	conn = PQconnectdbParams(keys, vals, /* expand_dbname = */ true);
	if (PQstatus(conn) != CONNECTION_OK)
	{
		/*
		 * The message is owned by the connection object, so copy it before
		 * freeing the connection.
		 */
		char	   *reason = pstrdup(PQerrorMessage(conn));

		PQfinish(conn);

		/*
		 * NOTE(review): the DSN is echoed verbatim in the error detail; if
		 * callers embed a password in it, it will show up in logs.
		 */
		ereport(ERROR,
				(errmsg("could not connect to the postgresql server in replication mode: %s",
						reason),
				 errdetail("dsn was: %s", connstring)));
	}

	return conn;
}

/*
 * Send START_REPLICATION for logical slot 'slot_name', starting at 0/0, on
 * 'streamConn'.
 *
 * Raises ERROR unless the command puts the connection into COPY BOTH mode.
 * The connection itself stays owned by the caller; only the command result
 * is released here (on both the success and the error path).
 */
static void
pglogical_start_replication(PGconn *streamConn, const char *slot_name)
{
	StringInfoData	command;
	PGresult	   *res;

	initStringInfo(&command);
	appendStringInfo(&command, "START_REPLICATION SLOT \"%s\" LOGICAL 0/0",
					 slot_name);

	res = PQexec(streamConn, command.data);
	if (PQresultStatus(res) != PGRES_COPY_BOTH)
	{
		StringInfoData	failure;
		const char	   *sqlstate = PQresultErrorField(res, PG_DIAG_SQLSTATE);

		/*
		 * The error strings belong to 'res'; format the message into
		 * palloc'd memory so the result can be freed before raising.  The
		 * SQLSTATE field can be absent (e.g. errors not reported by the
		 * server), in which case a placeholder is printed instead of
		 * passing NULL to %s.
		 */
		initStringInfo(&failure);
		appendStringInfo(&failure,
						 "could not send replication command \"%s\": %s\n, sqlstate: %s",
						 command.data, PQresultErrorMessage(res),
						 sqlstate ? sqlstate : "(null)");

		PQclear(res);

		elog(ERROR, "%s", failure.data);
	}
	PQclear(res);
}

/*
 * Read transaction BEGIN from the stream.
 *
 * Consumes, in order: one flags byte (asserted to be zero), a 64-bit LSN
 * stored in *remote_lsn, a 64-bit commit timestamp stored in *committime,
 * and a 32-bit transaction id stored in *remote_xid.
 */
static void
pglogical_read_begin(StringInfo in, XLogRecPtr *remote_lsn,
					 TimestampTz *committime, TransactionId *remote_xid)
{
	/* read flags */
	uint8	flags = pq_getmsgbyte(in);
	Assert(flags == 0);
	(void) flags;	/* only inspected in assert-enabled builds */

	/* read fields */
	*remote_lsn = pq_getmsgint64(in);
	Assert(*remote_lsn != InvalidXLogRecPtr);
	*committime = pq_getmsgint64(in);
	*remote_xid = pq_getmsgint(in, 4);
}

/*
 * Make sure a transaction is in progress and that MessageContext is the
 * current memory context.
 *
 * Returns true if this call had to start a transaction, false if one was
 * already running.
 */
static bool
ensure_transaction(void)
{
	bool	started_here = !IsTransactionState();

	if (started_here)
		StartTransactionCommand();

	/* a freshly started transaction always needs the context switch */
	if (started_here || CurrentMemoryContext != MessageContext)
		MemoryContextSwitchTo(MessageContext);

	return started_here;
}

/*
 * Handle BEGIN message.
 *
 * The message fields are consumed from the buffer but not otherwise used
 * here; the backend's activity state is set to running.
 */
static void
handle_begin(StringInfo s)
{
	XLogRecPtr		commit_lsn;
	TimestampTz		commit_time;
	TransactionId	remote_xid;

	pglogical_read_begin(s, &commit_lsn, &commit_time, &remote_xid);

	/* second argument is a query string pointer, so pass NULL, not false */
	pgstat_report_activity(STATE_RUNNING, NULL);
}

/*
 * Read transaction COMMIT from the stream.
 *
 * Consumes, in order: one flags byte (asserted to be zero) and three 64-bit
 * values stored in *commit_lsn, *end_lsn and *committime.
 */
static void
pglogical_read_commit(StringInfo in, XLogRecPtr *commit_lsn,
					   XLogRecPtr *end_lsn, TimestampTz *committime)
{
	/* read flags */
	uint8	flags = pq_getmsgbyte(in);
	Assert(flags == 0);
	(void) flags;	/* only inspected in assert-enabled builds */

	/* read fields */
	*commit_lsn = pq_getmsgint64(in);
	*end_lsn = pq_getmsgint64(in);
	*committime = pq_getmsgint64(in);
}

/*
 * Handle COMMIT message.
 *
 * If a local transaction is open it is committed, and the pair (local commit
 * end, remote end LSN) is appended to lsn_mapping.  In all cases the
 * backend's activity state is set to idle afterwards.
 */
static void
handle_commit(StringInfo s)
{
	XLogRecPtr		remote_commit_lsn;
	XLogRecPtr		remote_end_lsn;
	TimestampTz		remote_commit_time;

	pglogical_read_commit(s, &remote_commit_lsn, &remote_end_lsn,
						  &remote_commit_time);

	if (IsTransactionState())
	{
		PGLFlushPosition *entry;

		CommitTransactionCommand();

		/*
		 * Remember where this commit ended locally and remotely.  The entry
		 * is allocated in TopMemoryContext and released later by
		 * get_flush_position().
		 */
		MemoryContextSwitchTo(TopMemoryContext);
		entry = (PGLFlushPosition *) palloc(sizeof(PGLFlushPosition));
		entry->remote_end = remote_end_lsn;
		entry->local_end = XactLastCommitEnd;
		dlist_push_tail(&lsn_mapping, &entry->node);

		MemoryContextSwitchTo(MessageContext);
	}

	pgstat_report_activity(STATE_IDLE, NULL);
}

/*
 * Read DDL command from stream.
 *
 * Consumes one flags byte (asserted to be zero), a 32-bit length, and then
 * that many bytes of payload.
 *
 * Returns the DDL query as handed back by pq_getmsgbytes(), i.e. a pointer
 * into the message buffer rather than a copy; it is only valid while the
 * buffer behind 'in' is.
 *
 * NOTE(review): nothing here appends a terminating NUL, yet callers treat the
 * result as a C string; presumably the sender includes the terminator in the
 * transmitted length -- confirm against the output plugin.
 */
static char *
pglogical_read_ddl(StringInfo in)
{
	uint8		flags;
	Size		sz;

	/* read the flags */
	flags = pq_getmsgbyte(in);
	Assert(flags == 0);
	(void) flags;	/* only inspected in assert-enabled builds */

	/* read the message length */
	sz = pq_getmsgint(in, 4);

	/* read the message */
	return (char *) pq_getmsgbytes(in, sz);
}

/*
 * Error context callback installed by pglogical_execute_sql_command():
 * appends the text of the SQL statement being executed to the error report.
 *
 * 'arg' is the statement text.
 */
static void
execute_sql_command_error_cb(void *arg)
{
	char	   *statement = (char *) arg;

	errcontext("during execution of queued SQL statement: %s", statement);
}

/*
 * Execute an SQL command. This can be multiple queries.
 *
 * 'cmdstr' is the SQL text; it is parsed here and each resulting statement
 * is analyzed, planned and run through its own portal.  'role' is assigned
 * to the "role" setting before each statement.  'isTopLevel' is forwarded to
 * PortalRun(), but only if the text contains exactly one statement.
 *
 * While running, an error context callback reporting 'cmdstr' is installed.
 */
static void
pglogical_execute_sql_command(char *cmdstr, char *role, bool isTopLevel)
{
	List	   *commands;
	ListCell   *command_i;
	MemoryContext oldcontext;
	ErrorContextCallback errcallback;

	/* parse in MessageContext, then go back to the caller's context */
	oldcontext = MemoryContextSwitchTo(MessageContext);

	/* push our callback on the error context stack */
	errcallback.callback = execute_sql_command_error_cb;
	errcallback.arg = cmdstr;
	errcallback.previous = error_context_stack;
	error_context_stack = &errcallback;

	commands = pg_parse_query(cmdstr);

	MemoryContextSwitchTo(oldcontext);

	/*
	 * Do a limited amount of safety checking against CONCURRENTLY commands
	 * executed in situations where they aren't allowed. The sender side should
	 * provide protection, but better be safe than sorry.
	 */
	isTopLevel = isTopLevel && (list_length(commands) == 1);

	/* run every parsed statement in turn */
	foreach(command_i, commands)
	{
		List	   *plantree_list;
		List	   *querytree_list;
		Node	   *command = (Node *) lfirst(command_i);
		const char *commandTag;
		Portal		portal;
		DestReceiver *receiver;

		/* temporarily push snapshot for parse analysis/planning */
		PushActiveSnapshot(GetTransactionSnapshot());

		oldcontext = MemoryContextSwitchTo(MessageContext);

		/*
		 * Set the current role to the user that executed the command on the
		 * origin server.  NB: there is no need to reset this afterwards, as
		 * the value will be gone with our transaction.
		 */
		SetConfigOption("role", role, PGC_INTERNAL, PGC_S_OVERRIDE);

		commandTag = CreateCommandTag(command);

		querytree_list = pg_analyze_and_rewrite(
			command, cmdstr, NULL, 0);

		plantree_list = pg_plan_queries(
			querytree_list, 0, NULL);

		/* analysis and planning are done; drop the temporary snapshot */
		PopActiveSnapshot();

		/* execute the plans through a portal named "pg_ddl_decode" */
		portal = CreatePortal("pg_ddl_decode", true, true);
		PortalDefineQuery(portal, NULL,
						  cmdstr, commandTag,
						  plantree_list, NULL);
		PortalStart(portal, NULL, 0, InvalidSnapshot);

		/* output goes to a DestNone receiver */
		receiver = CreateDestReceiver(DestNone);

		(void) PortalRun(portal, FETCH_ALL,
						 isTopLevel,
						 receiver, receiver,
						 NULL);
		(*receiver->rDestroy) (receiver);

		PortalDrop(portal, false);

		/* advance the command counter before running the next statement */
		CommandCounterIncrement();

		MemoryContextSwitchTo(oldcontext);
	}

	/* protect against stack resets during CONCURRENTLY processing */
	if (error_context_stack == &errcallback)
		error_context_stack = errcallback.previous;
}

/*
 * Handle DDL message: read the statement text from the stream, make sure a
 * transaction is open, and execute the text locally under the name of the
 * current user.
 */
static void
handle_ddl(StringInfo s)
{
	char	   *ddl_text;
	char	   *current_role;
	bool		is_top_level;

	ddl_text = pglogical_read_ddl(s);

	/* true only when no transaction was open before this message */
	is_top_level = ensure_transaction();

	current_role = GetUserNameFromId(GetUserId(), false);

	/* Run the statement(s) locally. */
	pglogical_execute_sql_command(ddl_text, current_role, is_top_level);

	CommandCounterIncrement();
}

/*
 * Dispatch one replication message according to its leading type byte:
 * 'B' = BEGIN, 'C' = COMMIT, 'M' = DDL message.  Any other type raises
 * ERROR.
 */
static void
replication_handler(StringInfo s)
{
	char	msgtype = pq_getmsgbyte(s);

	if (msgtype == 'B')
		handle_begin(s);
	else if (msgtype == 'C')
		handle_commit(s);
	else if (msgtype == 'M')
		handle_ddl(s);
	else
		elog(ERROR, "unknown action of type %c", msgtype);
}

/*
 * Compute the write and flush positions to report to the walsender.
 *
 * The last LSN received from the sender cannot be reported as flushed
 * directly, because the corresponding local transaction may not have reached
 * disk yet.  lsn_mapping therefore records, for every commit, the local and
 * the remote LSN.  This function walks that list from the head: entries whose
 * local LSN is at or below the local flush pointer are reported as flushed
 * and removed; the walk stops at the first entry that is not flushed yet.
 *
 * *write and *flush are set to InvalidXLogRecPtr if there is nothing to
 * report for them.
 *
 * Returns true if no transactions remain that still need to be flushed.
 */
static bool
get_flush_position(XLogRecPtr *write, XLogRecPtr *flush)
{
	dlist_mutable_iter cursor;
	XLogRecPtr	flushed_upto = GetFlushRecPtr();

	*write = InvalidXLogRecPtr;
	*flush = InvalidXLogRecPtr;

	dlist_foreach_modify(cursor, &lsn_mapping)
	{
		PGLFlushPosition *entry =
			dlist_container(PGLFlushPosition, node, cursor.cur);

		*write = entry->remote_end;

		if (entry->local_end > local_flush_guard(flushed_upto))
		{
			/*
			 * Everything from here on is unflushed.  Rather than walking
			 * the remainder of a possibly long list, take the write
			 * position straight from its last element.
			 */
			entry = dlist_tail_element(PGLFlushPosition, node,
									   &lsn_mapping);
			*write = entry->remote_end;
			return false;
		}

		/* locally flushed: report it and drop the entry */
		*flush = entry->remote_end;
		dlist_delete(cursor.cur);
		pfree(entry);
	}

	return dlist_is_empty(&lsn_mapping);
}

/*
 * Send a Standby Status Update message to server.
 *
 * 'recvpos' is the latest LSN we've received data to, 'now' is put into the
 * message as its send time, and force is set if we need to send a response
 * to avoid timeouts.
 *
 * The positions reported last time are kept in function-static variables,
 * so they persist across calls for the life of the process; reported
 * positions never move backwards.  Unless 'force' is set, nothing is sent
 * when the write and flush positions are unchanged.
 *
 * Always returns true (whether or not a message was actually sent); failure
 * to send the message raises ERROR.
 */
static bool
send_feedback(PGconn *conn, XLogRecPtr recvpos, int64 now, bool force)
{
	static StringInfo	reply_message = NULL;

	static XLogRecPtr last_recvpos = InvalidXLogRecPtr;
	static XLogRecPtr last_writepos = InvalidXLogRecPtr;
	static XLogRecPtr last_flushpos = InvalidXLogRecPtr;

	XLogRecPtr writepos;
	XLogRecPtr flushpos;

	/* It's legal to not pass a recvpos */
	if (recvpos < last_recvpos)
		recvpos = last_recvpos;

	if (get_flush_position(&writepos, &flushpos))
	{
		/*
		 * No outstanding transactions to flush, we can report the latest
		 * received position. This is important for synchronous replication.
		 */
		flushpos = writepos = recvpos;
	}

	if (writepos < last_writepos)
		writepos = last_writepos;

	if (flushpos < last_flushpos)
		flushpos = last_flushpos;

	/* if we've already reported everything we're good */
	if (!force &&
		writepos == last_writepos &&
		flushpos == last_flushpos)
		return true;

	/* the reply buffer is created once in TopMemoryContext and then reused */
	if (!reply_message)
	{
		MemoryContext	oldcontext = MemoryContextSwitchTo(TopMemoryContext);
		reply_message = makeStringInfo();
		MemoryContextSwitchTo(oldcontext);
	}
	else
		resetStringInfo(reply_message);

	pq_sendbyte(reply_message, 'r');
	pq_sendint64(reply_message, recvpos);		/* write */
	pq_sendint64(reply_message, flushpos);		/* flush */
	pq_sendint64(reply_message, writepos);		/* apply */
	pq_sendint64(reply_message, now);			/* sendTime */
	pq_sendbyte(reply_message, false);			/* replyRequested */

	elog(DEBUG2, "sending feedback (force %d) to recv %X/%X, write %X/%X, flush %X/%X",
		 force,
		 (uint32) (recvpos >> 32), (uint32) recvpos,
		 (uint32) (writepos >> 32), (uint32) writepos,
		 (uint32) (flushpos >> 32), (uint32) flushpos
		);

	/* ereport(ERROR) does not return, so no failure value is needed */
	if (PQputCopyData(conn, reply_message->data, reply_message->len) <= 0 ||
		PQflush(conn))
		ereport(ERROR,
				(errcode(ERRCODE_CONNECTION_FAILURE),
				 errmsg("could not send feedback packet: %s",
						PQerrorMessage(conn))));

	if (recvpos > last_recvpos)
		last_recvpos = recvpos;
	if (writepos > last_writepos)
		last_writepos = writepos;
	if (flushpos > last_flushpos)
		last_flushpos = flushpos;

	return true;
}

/*
 * Receive the replication stream on 'streamConn' and apply it.
 *
 * Two kinds of COPY messages are processed, others are ignored:
 *   'w'  data: start LSN, end LSN and send time (each 64 bits), followed by
 *        a replication message that is handed to replication_handler();
 *   'k'  keepalive: end position and timestamp (each 64 bits) and a
 *        reply-requested byte; answered through send_feedback().
 *
 * After draining the available input, feedback is sent and the function
 * sleeps until the socket becomes readable, the process latch is set, or one
 * second has passed.  The loop can be interrupted by query cancel.
 *
 * Returns when the server ends the COPY stream; raises ERROR if the
 * connection breaks or the stream cannot be read.
 */
static void
apply_work(PGconn *streamConn)
{
	char	   *copybuf = NULL;
	XLogRecPtr	last_received = InvalidXLogRecPtr;
	bool		wait_data = true;

	while (wait_data)
	{
		int			r;

		/* let query cancel / termination requests break the loop */
		CHECK_FOR_INTERRUPTS();

		if (PQstatus(streamConn) == CONNECTION_BAD)
		{
			elog(ERROR, "connection to other side has died");
		}

		/* in async mode PQgetCopyData() only sees what was consumed here */
		if (!PQconsumeInput(streamConn))
			elog(ERROR, "could not receive data from other side: %s",
				 PQerrorMessage(streamConn));

		for (;;)
		{
			if (copybuf != NULL)
			{
				PQfreemem(copybuf);
				copybuf = NULL;
			}
			r = PQgetCopyData(streamConn, &copybuf, 1);
			if (r == -1)
			{
				/* End of copy stream: no further data will arrive */
				wait_data = false;
				break;
			}
			else if (r == -2)
			{
				elog(ERROR, "could not read COPY data: %s",
					 PQerrorMessage(streamConn));
			}
			else if (r < 0)
				elog(ERROR, "invalid COPY status %d", r);
			else if (r == 0)
			{
				/* need to wait for new data */
				break;
			}
			else
			{
				int c;
				StringInfoData s;

				/*
				 * Wrap the libpq buffer for the pq_getmsg* readers.  The
				 * fields are set directly: initStringInfo() would palloc a
				 * buffer that is immediately lost when 'data' is
				 * overwritten.
				 */
				s.data = copybuf;
				s.len = r;
				s.maxlen = -1;
				s.cursor = 0;

				c = pq_getmsgbyte(&s);

				if (c == 'w')
				{
					XLogRecPtr	start_lsn;
					XLogRecPtr	end_lsn;

					start_lsn = pq_getmsgint64(&s);
					end_lsn = pq_getmsgint64(&s);
					pq_getmsgint64(&s); /* sendTime */

					if (last_received < start_lsn)
						last_received = start_lsn;

					if (last_received < end_lsn)
						last_received = end_lsn;

					replication_handler(&s);
				}
				else if (c == 'k')
				{
					XLogRecPtr endpos;
					bool reply_requested;

					endpos = pq_getmsgint64(&s);
					/* timestamp = */ pq_getmsgint64(&s);
					reply_requested = pq_getmsgbyte(&s);

					send_feedback(streamConn, endpos,
								  GetCurrentTimestamp(),
								  reply_requested);
				}
				/* other message types are purposefully ignored */
			}
		}

		/* the stream is over; feedback can no longer be sent on it */
		if (!wait_data)
			break;

		send_feedback(streamConn, last_received, GetCurrentTimestamp(),
					  false);

		/*
		 * Sleep until there is something to read instead of spinning.  The
		 * timeout bounds the wait so that the connection state is rechecked
		 * regularly; interrupts are serviced at the top of the loop, right
		 * after the latch is reset.
		 */
		(void) WaitLatchOrSocket(&MyProc->procLatch,
								 WL_LATCH_SET | WL_SOCKET_READABLE |
								 WL_TIMEOUT,
								 PQsocket(streamConn), 1000L);
		ResetLatch(&MyProc->procLatch);
	}
}
