diff --git a/doc/src/sgml/func.sgml b/doc/src/sgml/func.sgml index 88145c5..b939340 100644 --- a/doc/src/sgml/func.sgml +++ b/doc/src/sgml/func.sgml @@ -12750,6 +12750,29 @@ NULL baz(3 rows) + weighted_average + + + weighted_avg + + weighted_avg(value expression, weight expression) + + + smallint, int, + bigint, real, double + precision, numeric, or interval + + + numeric for any integer-type argument, + double precision for a floating-point argument, + otherwise the same as the argument data type + + the average (arithmetic mean) of all input values, weighted by the input weights + + + + + bit_and bit_and(expression) @@ -13430,6 +13453,29 @@ SELECT xmlagg(x) FROM (SELECT x FROM test ORDER BY y DESC) AS tab; + weighted standard deviation + population + + + weighted_stddev_pop + + weighted_stddev_pop(value expression, weight expression) + + + smallint, int, + bigint, real, double + precision, or numeric + + + double precision for floating-point arguments, + otherwise numeric + + weighted population standard deviation of the input values + + + + + standard deviation sample @@ -13454,6 +13500,29 @@ SELECT xmlagg(x) FROM (SELECT x FROM test ORDER BY y DESC) AS tab; + weighted standard deviation + sample + + + weighted_stddev_samp + + weighted_stddev_samp(value expression, weight expression) + + + smallint, int, + bigint, real, double + precision, or numeric + + + double precision for floating-point arguments, + otherwise numeric + + weighted sample standard deviation of the input values + + + + + variance variance(expression) diff --git a/src/backend/utils/adt/float.c b/src/backend/utils/adt/float.c index c7c0b58..cd7b10a 100644 --- a/src/backend/utils/adt/float.c +++ b/src/backend/utils/adt/float.c @@ -2405,6 +2405,7 @@ setseed(PG_FUNCTION_ARGS) * float8_accum - accumulate for AVG(), variance aggregates, etc. * float4_accum - same, but input data is float4 * float8_avg - produce final result for float AVG() + * float8_weighted_avg - produce final result for float WEIGHTED_AVG() * float8_var_samp - produce final result for float VAR_SAMP() * float8_var_pop - produce final result for float VAR_POP() * float8_stddev_samp - produce final result for float STDDEV_SAMP() @@ -3205,6 +3206,164 @@ float8_regr_intercept(PG_FUNCTION_ARGS) PG_RETURN_FLOAT8(numeratorXXY / numeratorX); } +/* + * =================== + * WEIGHTED AGGREGATES + * =================== + * + * The transition datatype for these aggregates is a 5-element array + * of float8, holding the values N, sum(W), sum(W*X), and sum(W*X*X) + * in that order. + * + * First, an accumulator function for those we can't pirate from the + * other accumulators. This accumulator function takes out some of + * the rounding error inherent in the general one. + * https://en.wikipedia.org/wiki/Standard_deviation#Rapid_calculation_methods + * + * It consists of a five-element array which includes: + * + * N, the number of non-zero-weighted values seen thus far, + * W, the running sum of weights, + * WX, the running dot product of weights and values, + * A, an intermediate value used in the calculation, and + * Q, another intermediate value. + * + */ + +Datum +float8_weighted_accum(PG_FUNCTION_ARGS) +{ + ArrayType *transarray = PG_GETARG_ARRAYTYPE_P(0); + float8 newvalX = PG_GETARG_FLOAT8(1); + float8 newvalW = PG_GETARG_FLOAT8(2); + float8 *transvalues; + float8 N, /* common */ + W, /* common */ + WX, /* Used in avg */ + A, /* Used in stddev_* */ + Q; /* Used in stddev_* */ + + transvalues = check_float8_array(transarray, "float8_weighted_accum", 5); + + if (newvalW == 0.0) /* Discard zero weights */ + PG_RETURN_NULL(); + + if (newvalW < 0.0) /* Negative weights are an error. */ + ereport(ERROR, + (errmsg("negative weights are not allowed"))); + + N = transvalues[0]; + W = transvalues[1]; + WX = transvalues[2]; + A = transvalues[3]; + Q = transvalues[4]; + + N += 1.0; + CHECKFLOATVAL(N, isinf(transvalues[0]), true); + W += newvalW; + CHECKFLOATVAL(W, isinf(transvalues[1]) || isinf(newvalW), true); + WX += newvalW * newvalX; + CHECKFLOATVAL(WX, isinf(transvalues[1]) || isinf(newvalW), true); + A += newvalW * ( newvalX - transvalues[3] ) / W; + CHECKFLOATVAL(A, isinf(newvalW) || isinf(transvalues[3]) || isinf(1.0/W), true); + Q += newvalW * (newvalX - transvalues[3]) * (newvalX - A); + CHECKFLOATVAL(A, isinf(newvalX - transvalues[4]) || isinf(newvalX - A) || isinf(1.0/W), true); + + if (AggCheckCallContext(fcinfo, NULL)) /* Update in place is safe in Agg context */ + { + transvalues[0] = N; + transvalues[1] = W; + transvalues[2] = WX; + transvalues[3] = A; + transvalues[4] = Q; + + PG_RETURN_ARRAYTYPE_P(transarray); + } + else /* You do not need to call this directly. */ + ereport(ERROR, + (errmsg("float8_weighted_accum called outside agg context"))); +} + +/* + * This is the final function for the weighted mean. It uses the + * 5-element accumulator common to weighted aggregates. + * + * N, the number of elements with non-zero weights, + * sumW, the sum of the weights, and + * sumWX, the dot product of elements and weights. + * + * While it might be possible to optimize this further by making a + * more compact accumulator, the performance gain is likely marginal. + * + */ +Datum +float8_weighted_avg(PG_FUNCTION_ARGS) +{ + ArrayType *transarray = PG_GETARG_ARRAYTYPE_P(0); + float8 *transvalues; + float8 N, + sumWX, + sumW; + + transvalues = check_float8_array(transarray, "float8_weighted_avg", 5); + N = transvalues[0]; + sumW = transvalues[1]; + sumWX = transvalues[2]; + + if (N < 1.0) + PG_RETURN_NULL(); + + CHECKFLOATVAL(N, isinf(1.0/sumW) || isinf(sumWX), true); + + PG_RETURN_FLOAT8(sumWX/sumW); +} + +Datum +float8_weighted_stddev_samp(PG_FUNCTION_ARGS) +{ + ArrayType *transarray = PG_GETARG_ARRAYTYPE_P(0); + float8 *transvalues; + float8 N, + W, + /* Skip A. Not used in the calculation */ + Q; + + transvalues = check_float8_array(transarray, "float8_weighted_stddev_samp", 5); + N = transvalues[0]; + W = transvalues[1]; + Q = transvalues[4]; + + if (N < 2.0) /* Must have at least two samples to get a stddev */ + PG_RETURN_NULL(); + + PG_RETURN_FLOAT8( + sqrt( + N * Q / + ( (N-1) * W ) + ) + ); +} + +Datum +float8_weighted_stddev_pop(PG_FUNCTION_ARGS) +{ + ArrayType *transarray = PG_GETARG_ARRAYTYPE_P(0); + float8 *transvalues; + float8 N, + W, + /* Skip A. Not used in the calculation */ + Q; + + transvalues = check_float8_array(transarray, "float8_weighted_stddev_pop", 5); + N = transvalues[0]; + W = transvalues[1]; + Q = transvalues[4]; + + if (N < 2.0) /* Must have at least two samples to get a stddev */ + PG_RETURN_NULL(); + + PG_RETURN_FLOAT8( sqrt( Q / W ) ); +} /* * ==================================== diff --git a/src/bin/pg_dump/pg_dump_sort.c b/src/bin/pg_dump/pg_dump_sort.c index dc35a93..36de6b6 100644 --- a/src/bin/pg_dump/pg_dump_sort.c +++ b/src/bin/pg_dump/pg_dump_sort.c @@ -848,14 +848,9 @@ repairTypeFuncLoop(DumpableObject *typeobj, DumpableObject *funcobj) if (typeInfo->shellType) { addObjectDependency(funcobj, typeInfo->shellType->dobj.dumpId); - /* - * Mark shell type (always including the definition, as we need - * the shell type defined to identify the function fully) as to be - * dumped if any such function is - */ + /* Mark shell type as to be dumped if any such function is */ if (funcobj->dump) - typeInfo->shellType->dobj.dump = funcobj->dump | - DUMP_COMPONENT_DEFINITION; + typeInfo->shellType->dobj.dump = true; } } diff --git a/src/include/catalog/pg_aggregate.h b/src/include/catalog/pg_aggregate.h index e16aa48..8e2a3a3 100644 --- a/src/include/catalog/pg_aggregate.h +++ b/src/include/catalog/pg_aggregate.h @@ -145,6 +145,7 @@ DATA(insert ( 2103 n 0 numeric_avg_accum numeric_avg numeric_avg_combine numer DATA(insert ( 2104 n 0 float4_accum float8_avg float8_combine - - - - - f f 0 1022 0 0 0 0 "{0,0,0}" _null_ )); DATA(insert ( 2105 n 0 float8_accum float8_avg float8_combine - - - - - f f 0 1022 0 0 0 0 "{0,0,0}" _null_ )); DATA(insert ( 2106 n 0 interval_accum interval_avg interval_combine - - interval_accum interval_accum_inv interval_avg f f 0 1187 0 0 1187 0 "{0 second,0 second}" "{0 second,0 second}" )); +DATA(insert ( 3998 n 0 float8_weighted_accum float8_weighted_avg - - - - - - f f 0 1022 0 40 0 0 "{0,0,0,0,0}" _null_)); /* sum */ DATA(insert ( 2107 n 0 int8_avg_accum numeric_poly_sum int8_avg_combine int8_avg_serialize int8_avg_deserialize int8_avg_accum int8_avg_accum_inv numeric_poly_sum f f 0 2281 17 48 2281 48 _null_ _null_ )); @@ -237,6 +238,7 @@ DATA(insert ( 2726 n 0 int2_accum numeric_poly_stddev_pop numeric_poly_combine DATA(insert ( 2727 n 0 float4_accum float8_stddev_pop float8_combine - - - - - f f 0 1022 0 0 0 0 "{0,0,0}" _null_ )); DATA(insert ( 2728 n 0 float8_accum float8_stddev_pop float8_combine - - - - - f f 0 1022 0 0 0 0 "{0,0,0}" _null_ )); DATA(insert ( 2729 n 0 numeric_accum numeric_stddev_pop numeric_combine numeric_serialize numeric_deserialize numeric_accum numeric_accum_inv numeric_stddev_pop f f 0 2281 17 128 2281 128 _null_ _null_ )); +DATA(insert ( 4032 n 0 float8_weighted_accum float8_weighted_stddev_pop - - - - - - f f 0 1022 0 40 0 0 "{0,0,0,0,0}" _null_)); /* stddev_samp */ DATA(insert ( 2712 n 0 int8_accum numeric_stddev_samp numeric_combine numeric_serialize numeric_deserialize int8_accum int8_accum_inv numeric_stddev_samp f f 0 2281 17 128 2281 128 _null_ _null_ )); @@ -245,6 +247,7 @@ DATA(insert ( 2714 n 0 int2_accum numeric_poly_stddev_samp numeric_poly_combine DATA(insert ( 2715 n 0 float4_accum float8_stddev_samp float8_combine - - - - - f f 0 1022 0 0 0 0 "{0,0,0}" _null_ )); DATA(insert ( 2716 n 0 float8_accum float8_stddev_samp float8_combine - - - - - f f 0 1022 0 0 0 0 "{0,0,0}" _null_ )); DATA(insert ( 2717 n 0 numeric_accum numeric_stddev_samp numeric_combine numeric_serialize numeric_deserialize numeric_accum numeric_accum_inv numeric_stddev_samp f f 0 2281 17 128 2281 128 _null_ _null_ )); +DATA(insert ( 4101 n 0 float8_weighted_accum float8_weighted_stddev_samp - - - - - - f f 0 1022 0 40 0 0 "{0,0,0,0,0}" _null_)); /* stddev: historical Postgres syntax for stddev_samp */ DATA(insert ( 2154 n 0 int8_accum numeric_stddev_samp numeric_combine numeric_serialize numeric_deserialize int8_accum int8_accum_inv numeric_stddev_samp f f 0 2281 17 128 2281 128 _null_ _null_ )); diff --git a/src/include/catalog/pg_proc.h b/src/include/catalog/pg_proc.h index bb539d4..7dff0d3 100644 --- a/src/include/catalog/pg_proc.h +++ b/src/include/catalog/pg_proc.h @@ -2433,6 +2433,12 @@ DESCR("join selectivity of case-insensitive regex non-match"); /* Aggregate-related functions */ DATA(insert OID = 1830 ( float8_avg PGNSP PGUID 12 1 0 0 0 f f f f t f i s 1 0 701 "1022" _null_ _null_ _null_ _null_ _null_ float8_avg _null_ _null_ _null_ )); DESCR("aggregate final function"); +DATA(insert OID = 3997 ( float8_weighted_avg PGNSP PGUID 12 1 0 0 0 f f f f t f i s 1 0 701 "1022" _null_ _null_ _null_ _null_ _null_ float8_weighted_avg _null_ _null_ _null_ )); +DESCR("aggregate final function"); +DATA(insert OID = 4099 ( float8_weighted_stddev_pop PGNSP PGUID 12 1 0 0 0 f f f f t f i s 1 0 701 "1022" _null_ _null_ _null_ _null_ _null_ float8_weighted_stddev_pop _null_ _null_ _null_ )); +DESCR("aggregate final function"); +DATA(insert OID = 4100 ( float8_weighted_stddev_samp PGNSP PGUID 12 1 0 0 0 f f f f t f i s 1 0 701 "1022" _null_ _null_ _null_ _null_ _null_ float8_weighted_stddev_samp _null_ _null_ _null_ )); +DESCR("aggregate final function"); DATA(insert OID = 2512 ( float8_var_pop PGNSP PGUID 12 1 0 0 0 f f f f t f i s 1 0 701 "1022" _null_ _null_ _null_ _null_ _null_ float8_var_pop _null_ _null_ _null_ )); DESCR("aggregate final function"); DATA(insert OID = 1831 ( float8_var_samp PGNSP PGUID 12 1 0 0 0 f f f f t f i s 1 0 701 "1022" _null_ _null_ _null_ _null_ _null_ float8_var_samp _null_ _null_ _null_ )); @@ -2544,6 +2550,8 @@ DATA(insert OID = 2805 ( int8inc_float8_float8 PGNSP PGUID 12 1 0 0 0 f f f f DESCR("aggregate transition function"); DATA(insert OID = 2806 ( float8_regr_accum PGNSP PGUID 12 1 0 0 0 f f f f t f i s 3 0 1022 "1022 701 701" _null_ _null_ _null_ _null_ _null_ float8_regr_accum _null_ _null_ _null_ )); DESCR("aggregate transition function"); +DATA(insert OID = 3999 ( float8_weighted_accum PGNSP PGUID 12 1 0 0 0 f f f f t f i s 3 0 1022 "1022 701 701" _null_ _null_ _null_ _null_ _null_ float8_weighted_accum _null_ _null_ _null_ )); +DESCR("aggregate transition function"); DATA(insert OID = 3342 ( float8_regr_combine PGNSP PGUID 12 1 0 0 0 f f f f t f i s 2 0 1022 "1022 1022" _null_ _null_ _null_ _null_ _null_ float8_regr_combine _null_ _null_ _null_ )); DESCR("aggregate combine function"); DATA(insert OID = 2807 ( float8_regr_sxx PGNSP PGUID 12 1 0 0 0 f f f f t f i s 1 0 701 "1022" _null_ _null_ _null_ _null_ _null_ float8_regr_sxx _null_ _null_ _null_ )); @@ -3204,6 +3212,8 @@ DATA(insert OID = 2104 ( avg PGNSP PGUID 12 1 0 0 0 t f f f f f i s 1 0 701 DESCR("the average (arithmetic mean) as float8 of all float4 values"); DATA(insert OID = 2105 ( avg PGNSP PGUID 12 1 0 0 0 t f f f f f i s 1 0 701 "701" _null_ _null_ _null_ _null_ _null_ aggregate_dummy _null_ _null_ _null_ )); DESCR("the average (arithmetic mean) as float8 of all float8 values"); +DATA(insert OID = 3998 ( weighted_avg PGNSP PGUID 12 1 0 0 0 t f f f f f i s 2 0 701 "701 701" _null_ _null_ _null_ _null_ _null_ aggregate_dummy _null_ _null_ _null_ )); +DESCR("the weighted average (arithmetic mean) as float8 of all float8 values"); DATA(insert OID = 2106 ( avg PGNSP PGUID 12 1 0 0 0 t f f f f f i s 1 0 1186 "1186" _null_ _null_ _null_ _null_ _null_ aggregate_dummy _null_ _null_ _null_ )); DESCR("the average (arithmetic mean) as interval of all interval values"); @@ -3364,6 +3374,8 @@ DATA(insert OID = 2728 ( stddev_pop PGNSP PGUID 12 1 0 0 0 t f f f f f i s 1 0 DESCR("population standard deviation of float8 input values"); DATA(insert OID = 2729 ( stddev_pop PGNSP PGUID 12 1 0 0 0 t f f f f f i s 1 0 1700 "1700" _null_ _null_ _null_ _null_ _null_ aggregate_dummy _null_ _null_ _null_ )); DESCR("population standard deviation of numeric input values"); +DATA(insert OID = 4032 ( weighted_stddev_pop PGNSP PGUID 12 1 0 0 0 t f f f f f i s 2 0 701 "701 701" _null_ _null_ _null_ _null_ _null_ aggregate_dummy _null_ _null_ _null_ )); +DESCR("population weighted standard deviation of float8 input values"); DATA(insert OID = 2712 ( stddev_samp PGNSP PGUID 12 1 0 0 0 t f f f f f i s 1 0 1700 "20" _null_ _null_ _null_ _null_ _null_ aggregate_dummy _null_ _null_ _null_ )); DESCR("sample standard deviation of bigint input values"); @@ -3377,6 +3389,8 @@ DATA(insert OID = 2716 ( stddev_samp PGNSP PGUID 12 1 0 0 0 t f f f f f i s 1 DESCR("sample standard deviation of float8 input values"); DATA(insert OID = 2717 ( stddev_samp PGNSP PGUID 12 1 0 0 0 t f f f f f i s 1 0 1700 "1700" _null_ _null_ _null_ _null_ _null_ aggregate_dummy _null_ _null_ _null_ )); DESCR("sample standard deviation of numeric input values"); +DATA(insert OID = 4101 ( weighted_stddev_samp PGNSP PGUID 12 1 0 0 0 t f f f f f i s 2 0 701 "701 701" _null_ _null_ _null_ _null_ _null_ aggregate_dummy _null_ _null_ _null_ )); +DESCR("sample weighted standard deviation of float8 input values"); DATA(insert OID = 2154 ( stddev PGNSP PGUID 12 1 0 0 0 t f f f f f i s 1 0 1700 "20" _null_ _null_ _null_ _null_ _null_ aggregate_dummy _null_ _null_ _null_ )); DESCR("historical alias for stddev_samp"); diff --git a/src/include/utils/builtins.h b/src/include/utils/builtins.h index 01976a1..1c1d0d4 100644 --- a/src/include/utils/builtins.h +++ b/src/include/utils/builtins.h @@ -429,8 +429,12 @@ extern Datum drandom(PG_FUNCTION_ARGS); extern Datum setseed(PG_FUNCTION_ARGS); extern Datum float8_combine(PG_FUNCTION_ARGS); extern Datum float8_accum(PG_FUNCTION_ARGS); +extern Datum float8_weighted_accum(PG_FUNCTION_ARGS); extern Datum float4_accum(PG_FUNCTION_ARGS); extern Datum float8_avg(PG_FUNCTION_ARGS); +extern Datum float8_weighted_avg(PG_FUNCTION_ARGS); +extern Datum float8_weighted_stddev_pop(PG_FUNCTION_ARGS); +extern Datum float8_weighted_stddev_samp(PG_FUNCTION_ARGS); extern Datum float8_var_pop(PG_FUNCTION_ARGS); extern Datum float8_var_samp(PG_FUNCTION_ARGS); extern Datum float8_stddev_pop(PG_FUNCTION_ARGS); diff --git a/src/test/regress/expected/aggregates.out b/src/test/regress/expected/aggregates.out index 3ff6691..c12ea3b 100644 --- a/src/test/regress/expected/aggregates.out +++ b/src/test/regress/expected/aggregates.out @@ -247,6 +247,18 @@ SELECT covar_pop(b, a), covar_samp(b, a) FROM aggtest; 653.62895538751 | 871.505273850014 (1 row) +SELECT weighted_avg(a, b) FROM aggtest; + weighted_avg +------------------ + 55.5553072763149 +(1 row) + +SELECT weighted_stddev_pop(a, b), weighted_stddev_samp(a, b) FROM aggtest; + weighted_stddev_pop | weighted_stddev_samp +---------------------+---------------------- + 24.3364627240769 | 28.1013266097382 +(1 row) + SELECT corr(b, a) FROM aggtest; corr ------------------- diff --git a/src/test/regress/sql/aggregates.sql b/src/test/regress/sql/aggregates.sql index 80ef14c..6f236a1 100644 --- a/src/test/regress/sql/aggregates.sql +++ b/src/test/regress/sql/aggregates.sql @@ -60,6 +60,8 @@ SELECT regr_avgx(b, a), regr_avgy(b, a) FROM aggtest; SELECT regr_r2(b, a) FROM aggtest; SELECT regr_slope(b, a), regr_intercept(b, a) FROM aggtest; SELECT covar_pop(b, a), covar_samp(b, a) FROM aggtest; +SELECT weighted_avg(a, b) FROM aggtest; +SELECT weighted_stddev_pop(a, b), weighted_stddev_samp(a, b) FROM aggtest; SELECT corr(b, a) FROM aggtest; SELECT count(four) AS cnt_1000 FROM onek;