From 674145d7c877b754f3f589ecdbb3e5bef4356f25 Mon Sep 17 00:00:00 2001 From: satanson Date: Sat, 11 Oct 2025 14:22:57 +0800 Subject: [PATCH] [Enhancement] Support Constant Folding for some array functions (#63692) Signed-off-by: satanson --- .../rewrite/ScalarOperatorFunctions.java | 5 + .../rewrite/scalar/FoldConstantsRule.java | 286 ++++++++++++++++++ .../trino/TrinoFunctionTransformTest.java | 8 +- .../parser/trino/TrinoQueryTest.java | 6 +- .../starrocks/planner/VectorIndexTest.java | 9 +- .../com/starrocks/sql/plan/AggregateTest.java | 2 +- .../com/starrocks/sql/plan/ArrayTypeTest.java | 19 +- .../plan/ConstArrayFunctionFoldingTest.java | 166 ++++++++++ .../sql/plan/PruneComplexSubfieldTest.java | 8 +- 9 files changed, 480 insertions(+), 29 deletions(-) create mode 100644 fe/fe-core/src/test/java/com/starrocks/sql/plan/ConstArrayFunctionFoldingTest.java diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rewrite/ScalarOperatorFunctions.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rewrite/ScalarOperatorFunctions.java index 52e498fb787..55b38f1b622 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rewrite/ScalarOperatorFunctions.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rewrite/ScalarOperatorFunctions.java @@ -1061,6 +1061,11 @@ public class ScalarOperatorFunctions { /** * Arithmetic function */ + + @ConstantFunction(name = "add", argTypes = {TINYINT, TINYINT}, returnType = TINYINT, isMonotonic = true) + public static ConstantOperator addTinyInt(ConstantOperator first, ConstantOperator second) { + return ConstantOperator.createTinyInt((byte) Math.addExact(first.getTinyInt(), second.getTinyInt())); + } @ConstantFunction(name = "add", argTypes = {SMALLINT, SMALLINT}, returnType = SMALLINT, isMonotonic = true) public static ConstantOperator addSmallInt(ConstantOperator first, ConstantOperator second) { return ConstantOperator.createSmallInt((short) Math.addExact(first.getSmallint(), second.getSmallint())); diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rewrite/scalar/FoldConstantsRule.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rewrite/scalar/FoldConstantsRule.java index 7daeab90009..a37741fafa4 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rewrite/scalar/FoldConstantsRule.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rewrite/scalar/FoldConstantsRule.java @@ -14,9 +14,23 @@ package com.starrocks.sql.optimizer.rewrite.scalar; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.ImmutableSortedMap; +import com.google.common.collect.ImmutableSortedSet; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; +import com.starrocks.catalog.ArrayType; +import com.starrocks.catalog.PrimitiveType; +import com.starrocks.catalog.ScalarFunction; +import com.starrocks.catalog.ScalarType; import com.starrocks.catalog.Type; import com.starrocks.sql.analyzer.SemanticException; +import com.starrocks.sql.ast.expression.ArithmeticExpr; import com.starrocks.sql.ast.expression.BinaryType; +import com.starrocks.sql.ast.expression.Expr; +import com.starrocks.sql.optimizer.operator.scalar.ArrayOperator; import com.starrocks.sql.optimizer.operator.scalar.BinaryPredicateOperator; import com.starrocks.sql.optimizer.operator.scalar.CallOperator; import com.starrocks.sql.optimizer.operator.scalar.CastOperator; @@ -27,12 +41,19 @@ import com.starrocks.sql.optimizer.operator.scalar.LikePredicateOperator; import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator; import com.starrocks.sql.optimizer.rewrite.ScalarOperatorEvaluator; import com.starrocks.sql.optimizer.rewrite.ScalarOperatorRewriteContext; +import com.starrocks.sql.optimizer.rewrite.ScalarOperatorRewriter; import org.apache.commons.lang3.StringUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import java.util.Comparator; import java.util.List; import java.util.Optional; +import java.util.function.BiFunction; +import java.util.function.BinaryOperator; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.IntStream; public class FoldConstantsRule extends BottomUpScalarOperatorRewriteRule { private static final Logger LOG = LogManager.getLogger(FoldConstantsRule.class); @@ -47,12 +68,258 @@ public class FoldConstantsRule extends BottomUpScalarOperatorRewriteRule { this.needMonotonicFunc = needMonotonicFunc; } + private Optional arrayScalarFun( + CallOperator call, BiFunction, ScalarOperator, T> cb) { + Preconditions.checkArgument(call.getArguments().size() == 2); + ScalarOperator arg0 = call.getArguments().get(0); + ScalarOperator arg1 = call.getArguments().get(1); + if (arg0.isConstantNull()) { + return Optional.of(ConstantOperator.createNull(call.getType())); + } + if (!(arg0 instanceof ArrayOperator) || !arg1.isConstantRef() || + !arg0.getChildren().stream().allMatch(ScalarOperator::isConstantRef)) { + return Optional.empty(); + } + return Optional.of(cb.apply(arg0.getChildren(), arg1)); + } + + private Optional arrayUnaryFun(CallOperator call, + Function, ScalarOperator> cb) { + Preconditions.checkArgument(call.getArguments().size() == 1); + ScalarOperator arg0 = call.getChild(0); + if (arg0 instanceof CastOperator && arg0.getChild(0) instanceof ArrayOperator) { + arg0 = call.getChild(0).getChild(0); + } + if (arg0.isConstantNull()) { + return Optional.of(ConstantOperator.createNull(arg0.getType())); + } + if (!(arg0 instanceof ArrayOperator) || !arg0.getChildren().stream().allMatch(ScalarOperator::isConstantRef)) { + return Optional.empty(); + } + try { + return Optional.of(cb.apply(arg0.getChildren())); + } catch (IllegalArgumentException ex) { + return Optional.empty(); + } + } + + private Optional constArrayAppend(CallOperator call) { + return arrayScalarFun(call, (arrayElements, target) -> + new ArrayOperator(call.getType(), call.isNullable(), + Lists.newArrayList(Iterables.concat(arrayElements, List.of(target))))); + } + + boolean constantEqual(ScalarOperator lhs, ScalarOperator rhs) { + Preconditions.checkArgument(lhs.isConstantRef() && rhs.isConstantRef()); + ConstantOperator constLhs = (ConstantOperator) lhs; + ConstantOperator constRhs = (ConstantOperator) rhs; + if (constLhs.getType().isDecimalV3()) { + return constLhs.getDecimal().compareTo(constRhs.getDecimal()) == 0 && + constLhs.isNull() == constRhs.isNull(); + } else { + return lhs.equals(rhs); + } + } + + private Optional constArrayContains(CallOperator call) { + + return arrayScalarFun(call, (arrayElements, target) -> + ConstantOperator.createBoolean(arrayElements.stream().anyMatch(elem -> constantEqual(elem, target)))); + } + + private Optional constArrayRemove(CallOperator call) { + return arrayScalarFun(call, (arrayElements, target) -> + new ArrayOperator(call.getType(), call.isNullable(), + arrayElements.stream().filter((elem) -> !constantEqual(elem, target)) + .collect(Collectors.toList()))); + } + + private Optional constArrayPosition(CallOperator call) { + return arrayScalarFun(call, (arrayElements, target) -> + ConstantOperator.createInt(IntStream.range(0, arrayElements.size()) + .boxed() + .filter(i -> constantEqual(arrayElements.get(i), target)) + .findAny().map(i -> i + 1).orElse(0))); + } + + private Optional constArrayLength(CallOperator call) { + ScalarOperator arg0 = call.getChild(0); + if (arg0.isConstantNull()) { + return Optional.of(ConstantOperator.createNull(call.getType())); + } else if (arg0 instanceof ArrayOperator) { + return Optional.of(ConstantOperator.createInt(arg0.getChildren().size())); + } else if (arg0 instanceof CastOperator && (arg0.getChild(0) instanceof ArrayOperator)) { + // array_length([]) and array_length([NULL, NULL]) would be + // array_length(cast([] as ARRAY)) and array_length(cast([NULL, NULL] as ARRAY)) + return Optional.of(ConstantOperator.createInt(arg0.getChild(0).getChildren().size())); + } else { + return Optional.empty(); + } + } + + private Optional constArraySum(CallOperator call) { + BinaryOperator add = (lhs, rhs) -> { + String opName = ArithmeticExpr.Operator.ADD.getName(); + com.starrocks.catalog.Function fn = + Expr.getBuiltinFunction(opName, new Type[] {lhs.getType(), rhs.getType()}, + com.starrocks.catalog.Function.CompareMode.IS_SUPERTYPE_OF); + CallOperator addition = + new CallOperator(ArithmeticExpr.Operator.ADD.getName(), lhs.getType(), List.of(lhs, rhs), fn); + ScalarOperator result = addition.accept(this, new ScalarOperatorRewriteContext()); + if (result instanceof ConstantOperator) { + return (ConstantOperator) result; + } else { + throw new IllegalArgumentException(); + } + }; + return arrayUnaryFun(call, (elements) -> elements.stream().filter(elem -> !elem.isConstantNull()) + .map(elem -> (ConstantOperator) elem).reduce(add).orElse(ConstantOperator.NULL)); + } + + private Optional constArrayMinMax(CallOperator call, boolean lessThan) { + Comparator comparator = (lhs, rhs) -> { + BinaryPredicateOperator cmp = + lessThan ? BinaryPredicateOperator.lt(lhs, rhs) : BinaryPredicateOperator.gt(lhs, rhs); + ScalarOperator result = cmp.accept(this, new ScalarOperatorRewriteContext()); + if (result.isConstant() && result.getType().isBoolean()) { + return ((ConstantOperator) result).getBoolean() ? -1 : 1; + } else { + throw new IllegalArgumentException(); + } + }; + return arrayUnaryFun(call, (elements) -> elements.stream().filter(elem -> !elem.isConstantNull()) + .map(elem -> (ConstantOperator) elem).min(comparator).orElse(ConstantOperator.NULL)); + } + + private Optional constArrayMin(CallOperator call) { + return constArrayMinMax(call, true); + } + + private Optional constArrayMax(CallOperator call) { + return constArrayMinMax(call, false); + } + + private Optional constArrayAvg(CallOperator call) { + Optional optSum = constArraySum(call); + Optional optLength = constArrayLength(call); + if (!optSum.isPresent() || !optLength.isPresent() || + !optSum.get().isConstant() || !optLength.get().isConstant()) { + return Optional.empty(); + } + ConstantOperator sum = (ConstantOperator) optSum.get(); + ConstantOperator length = (ConstantOperator) optLength.get(); + if (sum.isNull() || length.isNull()) { + return Optional.of(ConstantOperator.createNull(call.getType())); + } + + String opName = ArithmeticExpr.Operator.DIVIDE.getName(); + com.starrocks.catalog.Function + fn = Expr.getBuiltinFunction(opName, new Type[] {call.getType(), call.getType()}, + com.starrocks.catalog.Function.CompareMode.IS_SUPERTYPE_OF); + + // for decimal types, divide function should be rectified. + if (call.getType().isDecimalV3()) { + Preconditions.checkArgument(sum.getType().isDecimalV3()); + ScalarType sumType = (ScalarType) sum.getType(); + sumType = ScalarType.createDecimalV3Type(call.getType().getPrimitiveType(), sumType.getScalarPrecision(), + sumType.getScalarScale()); + int precision = PrimitiveType.getMaxPrecisionOfDecimal(call.getType().getPrimitiveType()); + ScalarType countType = ScalarType.createDecimalV3Type(call.getType().getPrimitiveType(), precision, 0); + fn = new ScalarFunction(fn.getFunctionName(), new Type[] {sumType, countType}, call.getType(), + fn.hasVarArgs()); + } + + CallOperator avg = + new CallOperator(opName, call.getType(), List.of(sum, length), fn); + ScalarOperatorRewriter rewriter = new ScalarOperatorRewriter(); + ScalarOperator result = rewriter.rewrite(avg, ScalarOperatorRewriter.DEFAULT_REWRITE_RULES); + if (result.isConstant()) { + return Optional.of(result); + } else { + return Optional.empty(); + } + } + + private Optional returnNullIfExistsNullArg(CallOperator call) { + ImmutableSet functions = new ImmutableSortedSet.Builder(String.CASE_INSENSITIVE_ORDER) + .add("split") + .add("str_to_map") + .add("regexp_extract_all") + .add("regexp_split") + .add("array_length") + .add("array_sum") + .add("array_avg") + .add("array_min") + .add("array_max") + .add("array_distinct") + .add("array_sort") + .add("reverse") + .add("array_join") + .add("array_difference") + .add("array_slice") + .add("array_concat") + .add("arrays_overlap") + .add("array_intersect") + .add("array_cum_sum") + .add("array_contains_all") + .add("array_contains_seq") + .add("all_match") + .add("any_match") + .add("array_generate") + .add("array_repeat") + .add("array_flatten") + .add("array_map") + .add("map_size") + .add("map_keys") + .add("map_values") + .add("map_from_arrays") + .add("distinct_map_keys") + .add("cardinality") + .add("tokenize") + .build(); + if (functions.contains(call.getFnName()) && call.getArguments().stream().anyMatch( + ScalarOperator::isConstantNull)) { + return Optional.of(ConstantOperator.createNull(call.getType())); + } else { + return Optional.empty(); + } + } + + private Optional tryToProcessConstantArrayFunctions(CallOperator call) { + Optional result = returnNullIfExistsNullArg(call); + if (result.isPresent()) { + return result; + } + ImmutableMap>> handlers = + new ImmutableSortedMap.Builder>>( + String.CASE_INSENSITIVE_ORDER) + .put("array_length", this::constArrayLength) + .put("array_sum", this::constArraySum) + .put("array_min", this::constArrayMin) + .put("array_max", this::constArrayMax) + .put("array_avg", this::constArrayAvg) + .put("array_contains", this::constArrayContains) + .put("array_append", this::constArrayAppend) + .put("array_remove", this::constArrayRemove) + .put("array_position", this::constArrayPosition) + .build(); + if (handlers.containsKey(call.getFnName())) { + return handlers.get(call.getFnName()).apply(call); + } + return Optional.empty(); + } + @Override public ScalarOperator visitCall(CallOperator call, ScalarOperatorRewriteContext context) { if (call.isAggregate()) { return call; } + Optional mayBeConstant = tryToProcessConstantArrayFunctions(call); + if (mayBeConstant.isPresent()) { + return mayBeConstant.get(); + } + if (notAllConstant(call.getChildren())) { if (call.getFunction() != null && call.getFunction().isMetaFunction()) { String errMsg = String.format("Meta function %s does not support non-constant arguments", @@ -143,6 +410,25 @@ public class FoldConstantsRule extends BottomUpScalarOperatorRewriteRule { return ConstantOperator.createNull(operator.getType()); } + ScalarOperator arg0 = operator.getChild(0); + if (arg0 instanceof ArrayOperator + && arg0.getChildren().stream().allMatch(ScalarOperator::isConstantRef) + && !arg0.getChildren().isEmpty() && operator.getType().isArrayType() + && ((ArrayType) operator.getType()).getItemType().isScalarType()) { + if (notAllConstant(arg0.getChildren())) { + return operator; + } + Type arrayElemType = ((ArrayType) operator.getType()).getItemType(); + ScalarOperatorRewriteContext subContext = new ScalarOperatorRewriteContext(); + List newArguments = arg0.getChildren().stream() + .map(arg -> visitCastOperator(new CastOperator(arrayElemType, arg), subContext)) + .collect(Collectors.toList()); + if (newArguments.stream().allMatch(arg -> arg != null && arg.isConstantRef())) { + return new ArrayOperator(operator.getType(), operator.isNullable(), newArguments); + } + return operator; + } + if (notAllConstant(operator.getChildren())) { return operator; } diff --git a/fe/fe-core/src/test/java/com/starrocks/connector/parser/trino/TrinoFunctionTransformTest.java b/fe/fe-core/src/test/java/com/starrocks/connector/parser/trino/TrinoFunctionTransformTest.java index f83d2956e76..72b4e1f8cb6 100644 --- a/fe/fe-core/src/test/java/com/starrocks/connector/parser/trino/TrinoFunctionTransformTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/connector/parser/trino/TrinoFunctionTransformTest.java @@ -64,14 +64,14 @@ public class TrinoFunctionTransformTest extends TrinoTestBase { sql = "select concat(c1, c2, array[1,2], array[3,4]) from test_array"; assertPlanContains(sql, "array_concat(2: c1, CAST(3: c2 AS ARRAY), " + - "CAST([1,2] AS ARRAY), " + - "CAST([3,4] AS ARRAY)"); + "['1','2'], " + + "['3','4']"); sql = "select concat(c2, 2) from test_array"; - assertPlanContains(sql, "array_concat(3: c2, CAST([2] AS ARRAY))"); + assertPlanContains(sql, "array_concat(3: c2, [2])"); sql = "select contains(array[1,2,3], 1)"; - assertPlanContains(sql, "array_contains([1,2,3], 1)"); + assertPlanContains(sql, " : TRUE"); sql = "select slice(array[1,2,3,4], 2, 2)"; assertPlanContains(sql, "array_slice([1,2,3,4], 2, 2)"); diff --git a/fe/fe-core/src/test/java/com/starrocks/connector/parser/trino/TrinoQueryTest.java b/fe/fe-core/src/test/java/com/starrocks/connector/parser/trino/TrinoQueryTest.java index 0c60419240a..a99f71604dd 100644 --- a/fe/fe-core/src/test/java/com/starrocks/connector/parser/trino/TrinoQueryTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/connector/parser/trino/TrinoQueryTest.java @@ -367,10 +367,10 @@ public class TrinoQueryTest extends TrinoTestBase { assertPlanContains(sql, "array_min(2: c1)"); sql = "select array_position(array[1,2,3], 2) from test_array"; - assertPlanContains(sql, "array_position([1,2,3], 2)"); + assertPlanContains(sql, " : 2"); sql = "select array_remove(array[1,2,3], 2) from test_array"; - assertPlanContains(sql, "array_remove([1,2,3], 2)"); + assertPlanContains(sql, " : [1,3]"); sql = "select array_sort(c1) from test_array"; assertPlanContains(sql, "array_sort(2: c1)"); @@ -1234,7 +1234,7 @@ public class TrinoQueryTest extends TrinoTestBase { @Test public void testCastArrayDataType() throws Exception { String sql = "select cast(ARRAY[1] as array(int))"; - assertPlanContains(sql, "CAST([1] AS ARRAY)"); + assertPlanContains(sql, " : [1]"); } @Test diff --git a/fe/fe-core/src/test/java/com/starrocks/planner/VectorIndexTest.java b/fe/fe-core/src/test/java/com/starrocks/planner/VectorIndexTest.java index d2b829ccb21..e7f3f27be2c 100644 --- a/fe/fe-core/src/test/java/com/starrocks/planner/VectorIndexTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/planner/VectorIndexTest.java @@ -385,9 +385,11 @@ public class VectorIndexTest extends PlanTestBase { " | 5 <-> [12: cast, DOUBLE, true] + 1.0\n" + " | 6 <-> [12: cast, DOUBLE, true] + 2.0\n" + " | 7 <-> cast([11: approx_cosine_similarity, FLOAT, true] as VARCHAR(65533))\n" + - " | 8 <-> cast(approx_cosine_similarity[(cast([1.1,2.2,3.3,4.4,5.5] as ARRAY), [3: c2, ARRAY, true]); args: INVALID_TYPE,INVALID_TYPE; result: FLOAT; args nullable: true; result nullable: true] as DOUBLE) + 2.0\n" + + " | 8 <-> cast(approx_cosine_similarity[([1.1,2.2,3.3,4.4,5.5], [3: c2, ARRAY, true]); " + + "args: INVALID_TYPE,INVALID_TYPE; result: FLOAT; args nullable: true; result nullable: true] as DOUBLE) + 2.0\n" + " | common expressions:\n" + - " | 11 <-> approx_cosine_similarity[(cast([1.1,2.2,3.3,4.4,5.5] as ARRAY), [2: c1, ARRAY, false]); args: INVALID_TYPE,INVALID_TYPE; result: FLOAT; args nullable: true; result nullable: true]\n" + + " | 11 <-> approx_cosine_similarity[([1.1,2.2,3.3,4.4,5.5], [2: c1, ARRAY, false]); " + + "args: INVALID_TYPE,INVALID_TYPE; result: FLOAT; args nullable: true; result nullable: true]\n" + " | 12 <-> cast([11: approx_cosine_similarity, FLOAT, true] as DOUBLE)\n" + " | limit: 10\n" + " | cardinality: 1\n" + @@ -529,7 +531,8 @@ public class VectorIndexTest extends PlanTestBase { " 1:Project\n" + " | output columns:\n" + " | 2 <-> [2: c1, ARRAY, false]\n" + - " | 4 <-> approx_l2_distance[(cast([1.1,2.2,3.3,4.4] as ARRAY), [2: c1, ARRAY, false]); args: INVALID_TYPE,INVALID_TYPE; result: FLOAT; args nullable: true; result nullable: true]\n" + + " | 4 <-> approx_l2_distance[([1.1,2.2,3.3,4.4], [2: c1, ARRAY, false]); " + + "args: INVALID_TYPE,INVALID_TYPE; result: FLOAT; args nullable: true; result nullable: true]\n" + " | cardinality: 1\n" + " | \n" + " 0:OlapScanNode\n" + diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/AggregateTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/AggregateTest.java index 3c4465073f0..fa040d1f225 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/AggregateTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/AggregateTest.java @@ -1680,7 +1680,7 @@ public class AggregateTest extends PlanTestBase { " | group by: \n" + " | \n" + " 1:Project\n" + - " | : arrays_overlap(3: v3, CAST([1] AS ARRAY))\n" + + " | : arrays_overlap(3: v3, [1])\n" + " | \n" + " 0:OlapScanNode"); } diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/ArrayTypeTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/ArrayTypeTest.java index 52f3d4182db..c5eada30c88 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/ArrayTypeTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/ArrayTypeTest.java @@ -114,7 +114,7 @@ public class ArrayTypeTest extends PlanTestBase { sql = "select concat(v1, [1,2,3], s_1) from adec"; plan = getFragmentPlan(sql); assertContains(plan, "array_concat(CAST([1: v1] AS ARRAY), " + - "CAST([1,2,3] AS ARRAY), 3: s_1)"); + "['1','2','3'], 3: s_1)"); sql = "select concat(1,2, [1,2])"; plan = getFragmentPlan(sql); @@ -122,13 +122,11 @@ public class ArrayTypeTest extends PlanTestBase { sql = "select concat(1,2, [1,2], 'a', 'b')"; plan = getFragmentPlan(sql); - assertContains(plan, "array_concat(CAST([1] AS ARRAY), CAST([2] AS ARRAY), " + - "CAST([1,2] AS ARRAY), ['a'], ['b'])"); + assertContains(plan, "array_concat(['1'], ['2'], ['1','2'], ['a'], ['b'])"); sql = "select concat(1,2, [1,2], 'a', 'b', 1.1)"; plan = getFragmentPlan(sql); - assertContains(plan, "array_concat(CAST([1] AS ARRAY), CAST([2] AS ARRAY), " + - "CAST([1,2] AS ARRAY), ['a'], ['b'], CAST([1.1] AS ARRAY)"); + assertContains(plan, " array_concat(['1'], ['2'], ['1','2'], ['a'], ['b'], ['1.1'])"); sql = "with t0 as (\n" + " select c1 from (values([])) as t(c1)\n" + @@ -357,7 +355,7 @@ public class ArrayTypeTest extends PlanTestBase { String sql = "select array_append([[1,2,3]], [null])"; String plan = getFragmentPlan(sql); assertContains(plan, - " : array_append([[1,2,3]], CAST([NULL] AS ARRAY))"); + " : array_append([[1,2,3]], [NULL])"); } { starRocksAssert.withTable("create table test_literal_array_insert_t0(" + @@ -727,11 +725,10 @@ public class ArrayTypeTest extends PlanTestBase { sql = "select array_contains([null], null), array_position([null], null)"; plan = getVerboseExplain(sql); - assertContains(plan, " | output columns:\n" + - " | 2 <-> array_contains[([NULL], NULL); " + - "args: INVALID_TYPE,BOOLEAN; result: BOOLEAN; args nullable: true; result nullable: true]\n" + - " | 3 <-> array_position[([NULL], NULL); " + - "args: INVALID_TYPE,BOOLEAN; result: INT; args nullable: true; result nullable: true]"); + assertContains(plan, " 1:Project\n" + + " | output columns:\n" + + " | 2 <-> TRUE\n" + + " | 3 <-> 1"); } @Test diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/ConstArrayFunctionFoldingTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/ConstArrayFunctionFoldingTest.java new file mode 100644 index 00000000000..94116fc5ad2 --- /dev/null +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/ConstArrayFunctionFoldingTest.java @@ -0,0 +1,166 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package com.starrocks.sql.plan; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +public class ConstArrayFunctionFoldingTest extends PlanTestBase { + + @BeforeAll + public static void beforeClass() throws Exception { + PlanTestBase.beforeClass(); + } + + @Test + public void testConstantArrayFunctions() throws Exception { + String[][] cases = new String[][] { + {"array_contains(['A', 'B'], 'C')", " : FALSE"}, + {"array_contains(['A', 'B'], 'B')", " : TRUE"}, + {"array_contains(['A', 'B', NULL], NULL)", " : TRUE"}, + {"array_contains(['A', 'B', 'C'], NULL)", " : FALSE"}, + {"array_contains(NULL, NULL)", " : NULL"}, {"array_remove(['A', 'B'], 'B')", "['A']"}, + {"array_remove(['A', 'B', 'B'], 'B')", "['A']"}, {"array_remove(['A', 'B'], 'C')", "['A','B']"}, + {"array_remove(['A', 'B'], NULL)", "['A','B']"}, {"array_remove(['A', 'B', NULL], NULL)", "['A','B']"}, + {"array_remove(['A', 'B', NULL, NULL], NULL)", "['A','B']"}, + {"array_remove(NULL, NULL)", " : NULL"}, {"array_remove(NULL, 'A')", " : NULL"}, + {"array_append(['A', 'B'], 'C')", "['A','B','C']"}, + {"array_append(['A', 'B'], NULL)", "['A','B',NULL]"}, {"array_append(NULL, NULL)", " : NULL"}, + {"array_append(NULL, 'A')", " : NULL"}, {"array_position(['A', 'B'], 'A')", " : 1"}, + {"array_position(['A', 'B'], 'C')", " : 0"}, + {"array_position(['A', 'B'], NULL)", " : 0"}, + {"array_position([NULL,'A', 'B'], NULL)", " : 1"}, + {"array_position(NULL, NULL)", " : NULL"}, {"array_position(NULL, 'A')", " : NULL"}, + {"array_length(['A', 'B'])", " : 2"}, {"array_length([])", " : 0"}, + {"array_length(NULL)", " : NULL"}, {"array_sum([1, 2, 3, 4])", " : 10"}, + {"array_sum([NULL,1, 2, 3, 4])", " : 10"}, {"array_sum([])", " : NULL"}, + {"array_sum(NULL)", " : NULL"}, {"array_sum([NULL, NULL])", " : NULL"}, + {"array_min([1, 2, 3, 4])", " : 1"}, {"array_min([])", " : NULL"}, + {"array_min([NULL, 4, 2, 3])", " : 2"}, {"array_min([NULL, 4, 2, 1, NULL])", " : 1"}, + {"array_min(NULL)", " : NULL"}, {"array_max([1, 2, 3, 4])", " : 4"}, + {"array_max([])", " : NULL"}, {"array_max([NULL, 4, 10, 7])", " : 10"}, + {"array_max([NULL, 4, 7, 11, NULL])", " : 11"}, {"array_max(NULL)", " : NULL"}, + {"array_avg([1, 2, 3, 4])", " : 2.5"}, {"array_avg([NULL,1, 2, 3, 4])", " : 2"}, + {"array_avg([])", " : NULL"}, {"array_avg(NULL)", " : NULL"}, + {"array_avg([NULL, NULL])", " : NULL"} + }; + String sqlFmt = "select {FUNC}"; + for (String[] tc : cases) { + String sql = sqlFmt.replace("{FUNC}", tc[0]); + String expect = tc[1]; + System.out.println(sql); + String plan = getFragmentPlan(sql); + assertContains(plan, expect); + } + } + + @Test + public void testConstantArrayDecimal() throws Exception { + String[][] cases = new String[][] { + {"array_remove([1.0,2.1,3.2,4.3], 1)", " : [2.1,3.2,4.3]"}, + {"array_position([1.0,2.1,3.2,4.3], 1)", " : 1"}, + {"array_position([1.0,2.1,3.2,4.3], NULL)", " : 0"}, + {"array_avg([1.0, 2.0, 3.0, 4.0])", " : 2.5"}, + {"array_contains([1.0,2.1,3.2,4.3], 1)", " : TRUE"}, + {"array_avg([1.1, 2.22, 3.333, 4.4444, NULL])", " : 2.21948"}, + {"array_sum([1.0, 2.0, 3.0, 4.0])", " : 10.0"}, + {"array_sum([1.1, 2.22, 3.333, 4.4444, NULL])", " : 11.0974"}, + }; + String sqlFmt = "select {FUNC}"; + for (String[] tc : cases) { + String sql = sqlFmt.replace("{FUNC}", tc[0]); + String expect = tc[1]; + System.out.println(sql); + String plan = getFragmentPlan(sql); + assertContains(plan, expect); + } + } + @Test + public void testConstArrayFunctionsReturnNullExistsNullArgument() throws Exception { + String[] queryList = new String[] {"SELECT 'split' as func_name, split(NULL, ',') as result;", + "SELECT 'split' as func_name, split('a,b', NULL) as result;", + "SELECT 'str_to_map' as func_name, str_to_map(NULL, ',', ':') as result;", + "SELECT 'str_to_map' as func_name, str_to_map('a:1,b:2', NULL, ':') as result;", + "SELECT 'str_to_map' as func_name, str_to_map('a:1,b:2', ',', NULL) as result;", + "SELECT 'regexp_extract_all' as func_name, regexp_extract_all(NULL, '(\\d+)', 1) as result;", + "SELECT 'regexp_extract_all' as func_name, regexp_extract_all('abc123', NULL, 1) as result;", + "SELECT 'regexp_extract_all' as func_name, regexp_extract_all('abc123', '(\\d+)', NULL) as result;", + "SELECT 'regexp_split' as func_name, regexp_split(NULL, ',', 1) as result;", + "SELECT 'regexp_split' as func_name, regexp_split('a,b', NULL, 1) as result;", + "SELECT 'regexp_split' as func_name, regexp_split('a,b', ',', NULL) as result;", + "SELECT 'array_length' as func_name, array_length(NULL) as result;", + "SELECT 'array_sum' as func_name, array_sum(NULL) as result;", + "SELECT 'array_avg' as func_name, array_avg(NULL) as result;", + "SELECT 'array_min' as func_name, array_min(NULL) as result;", + "SELECT 'array_max' as func_name, array_max(NULL) as result;", + "SELECT 'array_distinct' as func_name, array_distinct(NULL) as result;", + "SELECT 'array_sort' as func_name, array_sort(NULL) as result;", + "SELECT 'reverse' as func_name, reverse(NULL) as result;", + "SELECT 'array_join' as func_name, array_join(NULL, ',') as result;", + "SELECT 'array_join' as func_name, array_join(['a', 'b'], NULL) as result;", + "SELECT 'array_difference' as func_name, array_difference(NULL) as result;", + "SELECT 'array_slice' as func_name, array_slice(NULL, 1, 2) as result;", + "SELECT 'array_slice' as func_name, array_slice([1, 2, 3], NULL, 2) as result;", + "SELECT 'array_slice' as func_name, array_slice([1, 2, 3], 1, NULL) as result;", + "SELECT 'array_concat' as func_name, array_concat(NULL, [3, 4]) as result;", + "SELECT 'array_concat' as func_name, array_concat([1, 2], NULL) as result;", + "SELECT 'arrays_overlap' as func_name, arrays_overlap(NULL, [1, 2]) as result;", + "SELECT 'arrays_overlap' as func_name, arrays_overlap([1, 2], NULL) as result;", + "SELECT 'array_intersect' as func_name, array_intersect(NULL, [3, 4]) as result;", + "SELECT 'array_intersect' as func_name, array_intersect([1, 2], NULL) as result;", + "SELECT 'array_cum_sum' as func_name, array_cum_sum(NULL) as result;", + "SELECT 'array_contains_all' as func_name, array_contains_all(NULL, [1, 2]) as result;", + "SELECT 'array_contains_all' as func_name, array_contains_all([1, 2], NULL) as result;", + "SELECT 'array_contains_seq' as func_name, array_contains_seq(NULL, [1, 2]) as result;", + "SELECT 'array_contains_seq' as func_name, array_contains_seq([1, 2], NULL) as result;", + "SELECT 'all_match' as func_name, all_match(NULL) as result;", + "SELECT 'any_match' as func_name, any_match(NULL) as result;", + "SELECT 'array_generate' as func_name, array_generate(NULL, 1, 1) as result;", + "SELECT 'array_generate' as func_name, array_generate(1, NULL, 1) as result;", + "SELECT 'array_generate' as func_name, array_generate(1, 1, NULL) as result;", + "SELECT 'array_repeat' as func_name, array_repeat(NULL, 3) as result;", + "SELECT 'array_repeat' as func_name, array_repeat(1, NULL) as result;", + "SELECT 'array_flatten' as func_name, array_flatten(NULL) as result;", + "SELECT 'map_size' as func_name, map_size(NULL) as result;", + "SELECT 'map_keys' as func_name, map_keys(NULL) as result;", + "SELECT 'map_values' as func_name, map_values(NULL) as result;", + "SELECT 'map_from_arrays' as func_name, map_from_arrays(NULL, [1, 2]) as result;", + "SELECT 'map_from_arrays' as func_name, map_from_arrays(['a', 'b'], NULL) as result;", + "SELECT 'distinct_map_keys' as func_name, distinct_map_keys(NULL) as result;", + "SELECT 'cardinality' as func_name, cardinality(NULL) as result;", + "SELECT 'tokenize' as func_name, tokenize(NULL, ' ') as result;", + "SELECT 'tokenize' as func_name, tokenize('hello world', NULL) as result;"}; + for (String q : queryList) { + System.out.println(q); + String plan = getFragmentPlan(q); + assertContains(plan, " : NULL"); + } + } + + @Test + public void testLeftJoin() throws Exception { + String sql = "select ta.v1, tb.v3 " + "from t0 ta left join t0 tb on ta.v1 = tb.v2 " + + "where array_contains([1,2,3], tb.v3)"; + String plan = getFragmentPlan(sql); + assertCContains(plan, + " 3:HASH JOIN\n" + " | join op: INNER JOIN (BROADCAST)\n" + " | colocate: false, reason: \n" + + " | equal join conjunct: 5: v2 = 1: v1"); + assertCContains(plan, " 0:OlapScanNode\n" + " TABLE: t0\n" + " PREAGGREGATION: ON\n" + + " PREDICATES: array_contains([1,2,3], 6: v3)"); + } +} diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/PruneComplexSubfieldTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/PruneComplexSubfieldTest.java index 00ea96db19a..a8418d621c4 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/PruneComplexSubfieldTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/PruneComplexSubfieldTest.java @@ -677,8 +677,6 @@ public class PruneComplexSubfieldTest extends PlanTestNoneDBBase { assertContains(plan, " 0:OlapScanNode\n" + " table: pc0, rollup: pc0\n" + " preAggregation: on\n" - + " Predicates: array_length[([]); args: INVALID_TYPE; result: INT; args nullable: true; result " - + "nullable: true] IS NOT NULL\n" + " partitionsRatio=0/1, tabletsRatio=0/0\n" + " tabletList=\n" + " actualRows=0, avgRowSize=1.0\n" @@ -692,8 +690,6 @@ public class PruneComplexSubfieldTest extends PlanTestNoneDBBase { assertContains(plan, " 0:OlapScanNode\n" + " table: sc0, rollup: sc0\n" + " preAggregation: on\n" - + " Predicates: array_length[([1,2,3]); args: INVALID_TYPE; result: INT; args nullable: true; result " - + "nullable: true] IS NOT NULL\n" + " partitionsRatio=0/1, tabletsRatio=0/0\n" + " tabletList=\n" + " actualRows=0, avgRowSize=3.0\n" @@ -711,8 +707,6 @@ public class PruneComplexSubfieldTest extends PlanTestNoneDBBase { assertContains(plan, " 0:OlapScanNode\n" + " table: pc0, rollup: pc0\n" + " preAggregation: on\n" + - " Predicates: array_length[([]); args: INVALID_TYPE; result: INT; args nullable: true; result nullable:" - + " true] IS NOT NULL\n" + " partitionsRatio=0/1, tabletsRatio=0/0\n" + " tabletList=\n" + " actualRows=0, avgRowSize=3.0\n" + @@ -734,7 +728,7 @@ public class PruneComplexSubfieldTest extends PlanTestNoneDBBase { String sql = "select [1, 2, 3] is null from pc0 t1 right join sc0 t2 on t1.v1 = t2.v1;"; String plan = getFragmentPlan(sql); assertContains(plan, "5:Project\n" + - " | : array_length([1,2,3]) IS NULL"); + " | : FALSE"); sql = "select [1, 2, 3][1] is null from pc0 t1 right join sc0 t2 on t1.v1 = t2.v1;"; plan = getFragmentPlan(sql);