From 679a47a7be220576845c50092a1cccd8303d4535 Mon Sep 17 00:00:00 2001
From: Andres Freund <andres@anarazel.de>
Date: Mon, 7 Dec 2020 13:16:55 -0800
Subject: [PATCH] jit: wip: reference function types instead of re-creating
 them.

Also includes a bugfix that needs to be split out and backpatched.
---
 src/include/jit/llvmjit.h            |  2 +
 src/backend/jit/llvm/llvmjit.c       | 95 +++++++++++++++++-----------
 src/backend/jit/llvm/llvmjit_expr.c  | 35 ++--------
 src/backend/jit/llvm/llvmjit_types.c |  4 ++
 4 files changed, 71 insertions(+), 65 deletions(-)

diff --git a/src/include/jit/llvmjit.h b/src/include/jit/llvmjit.h
index 325409acd5c..1c89075eaff 100644
--- a/src/include/jit/llvmjit.h
+++ b/src/include/jit/llvmjit.h
@@ -92,6 +92,8 @@ extern LLVMModuleRef llvm_mutable_module(LLVMJitContext *context);
 extern char *llvm_expand_funcname(LLVMJitContext *context, const char *basename);
 extern void *llvm_get_function(LLVMJitContext *context, const char *funcname);
 extern void llvm_split_symbol_name(const char *name, char **modname, char **funcname);
+extern LLVMTypeRef llvm_pg_var_type(const char *varname);
+extern LLVMTypeRef llvm_pg_var_func_type(const char *varname);
 extern LLVMValueRef llvm_pg_func(LLVMModuleRef mod, const char *funcname);
 extern void llvm_copy_attributes(LLVMValueRef from, LLVMValueRef to);
 extern LLVMValueRef llvm_function_reference(LLVMJitContext *context,
diff --git a/src/backend/jit/llvm/llvmjit.c b/src/backend/jit/llvm/llvmjit.c
index 40a439326c6..9c4fc75f656 100644
--- a/src/backend/jit/llvm/llvmjit.c
+++ b/src/backend/jit/llvm/llvmjit.c
@@ -367,6 +367,47 @@ llvm_get_function(LLVMJitContext *context, const char *funcname)
 	return NULL;
 }
 
+/*
+ * Return type of a variable in llvmjit_types.c. This is useful to keep types
+ * in sync between plain C and JIT related code.
+ */
+LLVMTypeRef
+llvm_pg_var_type(const char *varname)
+{
+	LLVMValueRef v_srcvar;
+	LLVMTypeRef typ;
+
+	/* this'll return a *pointer* to the global */
+	v_srcvar = LLVMGetNamedGlobal(llvm_types_module, varname);
+	if (!v_srcvar)
+		elog(ERROR, "variable %s not in llvmjit_types.c", varname);
+
+	/* look at the contained type */
+	typ = LLVMTypeOf(v_srcvar);
+	Assert(typ != NULL && LLVMGetTypeKind(typ) == LLVMPointerTypeKind);
+	typ = LLVMGetElementType(typ);
+	Assert(typ != NULL);
+
+	return typ;
+}
+
+/*
+ * Return function type of a variable in llvmjit_types.c. This is useful to
+ * keep function types in sync between C and JITed code.
+ */
+LLVMTypeRef
+llvm_pg_var_func_type(const char *varname)
+{
+	LLVMTypeRef typ = llvm_pg_var_type(varname);
+
+	/* look at the contained type */
+	Assert(LLVMGetTypeKind(typ) == LLVMPointerTypeKind);
+	typ = LLVMGetElementType(typ);
+	Assert(typ != NULL && LLVMGetTypeKind(typ) == LLVMFunctionTypeKind);
+
+	return typ;
+}
+
 /*
  * Return declaration for a function referenced in llvmjit_types.c, adding it
  * to the module if necessary.
@@ -889,26 +930,6 @@ llvm_shutdown(int code, Datum arg)
 #endif							/* LLVM_VERSION_MAJOR > 11 */
 }
 
-/* helper for llvm_create_types, returning a global var's type */
-static LLVMTypeRef
-load_type(LLVMModuleRef mod, const char *name)
-{
-	LLVMValueRef value;
-	LLVMTypeRef typ;
-
-	/* this'll return a *pointer* to the global */
-	value = LLVMGetNamedGlobal(mod, name);
-	if (!value)
-		elog(ERROR, "type %s is unknown", name);
-
-	/* therefore look at the contained type and return that */
-	typ = LLVMTypeOf(value);
-	Assert(typ != NULL);
-	typ = LLVMGetElementType(typ);
-	Assert(typ != NULL);
-	return typ;
-}
-
 /* helper for llvm_create_types, returning a function's return type */
 static LLVMTypeRef
 load_return_type(LLVMModuleRef mod, const char *name)
@@ -970,24 +991,24 @@ llvm_create_types(void)
 	llvm_triple = pstrdup(LLVMGetTarget(llvm_types_module));
 	llvm_layout = pstrdup(LLVMGetDataLayoutStr(llvm_types_module));
 
-	TypeSizeT = load_type(llvm_types_module, "TypeSizeT");
+	TypeSizeT = llvm_pg_var_type("TypeSizeT");
 	TypeParamBool = load_return_type(llvm_types_module, "FunctionReturningBool");
-	TypeStorageBool = load_type(llvm_types_module, "TypeStorageBool");
-	TypePGFunction = load_type(llvm_types_module, "TypePGFunction");
-	StructNullableDatum = load_type(llvm_types_module, "StructNullableDatum");
-	StructExprContext = load_type(llvm_types_module, "StructExprContext");
-	StructExprEvalStep = load_type(llvm_types_module, "StructExprEvalStep");
-	StructExprState = load_type(llvm_types_module, "StructExprState");
-	StructFunctionCallInfoData = load_type(llvm_types_module, "StructFunctionCallInfoData");
-	StructMemoryContextData = load_type(llvm_types_module, "StructMemoryContextData");
-	StructTupleTableSlot = load_type(llvm_types_module, "StructTupleTableSlot");
-	StructHeapTupleTableSlot = load_type(llvm_types_module, "StructHeapTupleTableSlot");
-	StructMinimalTupleTableSlot = load_type(llvm_types_module, "StructMinimalTupleTableSlot");
-	StructHeapTupleData = load_type(llvm_types_module, "StructHeapTupleData");
-	StructTupleDescData = load_type(llvm_types_module, "StructTupleDescData");
-	StructAggState = load_type(llvm_types_module, "StructAggState");
-	StructAggStatePerGroupData = load_type(llvm_types_module, "StructAggStatePerGroupData");
-	StructAggStatePerTransData = load_type(llvm_types_module, "StructAggStatePerTransData");
+	TypeStorageBool = llvm_pg_var_type("TypeStorageBool");
+	TypePGFunction = llvm_pg_var_type("TypePGFunction");
+	StructNullableDatum = llvm_pg_var_type("StructNullableDatum");
+	StructExprContext = llvm_pg_var_type("StructExprContext");
+	StructExprEvalStep = llvm_pg_var_type("StructExprEvalStep");
+	StructExprState = llvm_pg_var_type("StructExprState");
+	StructFunctionCallInfoData = llvm_pg_var_type("StructFunctionCallInfoData");
+	StructMemoryContextData = llvm_pg_var_type("StructMemoryContextData");
+	StructTupleTableSlot = llvm_pg_var_type("StructTupleTableSlot");
+	StructHeapTupleTableSlot = llvm_pg_var_type("StructHeapTupleTableSlot");
+	StructMinimalTupleTableSlot = llvm_pg_var_type("StructMinimalTupleTableSlot");
+	StructHeapTupleData = llvm_pg_var_type("StructHeapTupleData");
+	StructTupleDescData = llvm_pg_var_type("StructTupleDescData");
+	StructAggState = llvm_pg_var_type("StructAggState");
+	StructAggStatePerGroupData = llvm_pg_var_type("StructAggStatePerGroupData");
+	StructAggStatePerTransData = llvm_pg_var_type("StructAggStatePerTransData");
 
 	AttributeTemplate = LLVMGetNamedFunction(llvm_types_module, "AttributeTemplate");
 }
diff --git a/src/backend/jit/llvm/llvmjit_expr.c b/src/backend/jit/llvm/llvmjit_expr.c
index f232397cabf..e0d53c0d0a2 100644
--- a/src/backend/jit/llvm/llvmjit_expr.c
+++ b/src/backend/jit/llvm/llvmjit_expr.c
@@ -84,7 +84,6 @@ llvm_compile_expr(ExprState *state)
 
 	LLVMBuilderRef b;
 	LLVMModuleRef mod;
-	LLVMTypeRef eval_sig;
 	LLVMValueRef eval_fn;
 	LLVMBasicBlockRef entry;
 	LLVMBasicBlockRef *opblocks;
@@ -149,19 +148,9 @@ llvm_compile_expr(ExprState *state)
 
 	funcname = llvm_expand_funcname(context, "evalexpr");
 
-	/* Create the signature and function */
-	{
-		LLVMTypeRef param_types[3];
-
-		param_types[0] = l_ptr(StructExprState);	/* state */
-		param_types[1] = l_ptr(StructExprContext);	/* econtext */
-		param_types[2] = l_ptr(TypeParamBool);	/* isnull */
-
-		eval_sig = LLVMFunctionType(TypeSizeT,
-									param_types, lengthof(param_types),
-									false);
-	}
-	eval_fn = LLVMAddFunction(mod, funcname, eval_sig);
+	/* create function */
+	eval_fn = LLVMAddFunction(mod, funcname,
+							  llvm_pg_var_func_type("TypeExprStateEvalFunc"));
 	LLVMSetLinkage(eval_fn, LLVMExternalLinkage);
 	LLVMSetVisibility(eval_fn, LLVMDefaultVisibility);
 	llvm_copy_attributes(AttributeTemplate, eval_fn);
@@ -265,8 +254,6 @@ llvm_compile_expr(ExprState *state)
 
 					v_tmpvalue = LLVMBuildLoad(b, v_tmpvaluep, "");
 					v_tmpisnull = LLVMBuildLoad(b, v_tmpisnullp, "");
-					v_tmpisnull =
-						LLVMBuildTrunc(b, v_tmpisnull, TypeParamBool, "");
 
 					LLVMBuildStore(b, v_tmpisnull, v_isnullp);
 
@@ -1088,24 +1075,16 @@ llvm_compile_expr(ExprState *state)
 
 			case EEOP_PARAM_CALLBACK:
 				{
-					LLVMTypeRef param_types[3];
-					LLVMValueRef v_params[3];
 					LLVMTypeRef v_functype;
 					LLVMValueRef v_func;
+					LLVMValueRef v_params[3];
 
-					param_types[0] = l_ptr(StructExprState);
-					param_types[1] = l_ptr(TypeSizeT);
-					param_types[2] = l_ptr(StructExprContext);
-
-					v_functype = LLVMFunctionType(LLVMVoidType(),
-												  param_types,
-												  lengthof(param_types),
-												  false);
+					v_functype = llvm_pg_var_func_type("TypeExecEvalSubroutine");
 					v_func = l_ptr_const(op->d.cparam.paramfunc,
-										 l_ptr(v_functype));
+										 LLVMPointerType(v_functype, 0));
 
 					v_params[0] = v_state;
-					v_params[1] = l_ptr_const(op, l_ptr(TypeSizeT));
+					v_params[1] = l_ptr_const(op, l_ptr(StructExprEvalStep));
 					v_params[2] = v_econtext;
 					LLVMBuildCall(b,
 								  v_func,
diff --git a/src/backend/jit/llvm/llvmjit_types.c b/src/backend/jit/llvm/llvmjit_types.c
index 1ed3cafa2f2..2d950463f41 100644
--- a/src/backend/jit/llvm/llvmjit_types.c
+++ b/src/backend/jit/llvm/llvmjit_types.c
@@ -48,6 +48,10 @@
 PGFunction	TypePGFunction;
 size_t		TypeSizeT;
 bool		TypeStorageBool;
+extern ExprStateEvalFunc TypeExprStateEvalFunc;
+ExprStateEvalFunc TypeExprStateEvalFunc;
+extern ExecEvalSubroutine TypeExecEvalSubroutine;
+ExecEvalSubroutine TypeExecEvalSubroutine;
 
 NullableDatum StructNullableDatum;
 AggState	StructAggState;
-- 
2.28.0.651.g306ee63a70

