[Enhancement] Support Constant Folding for some array functions (#63692)

Signed-off-by: satanson <ranpanf@gmail.com>
This commit is contained in:
satanson 2025-10-11 14:22:57 +08:00 committed by GitHub
parent ee66eb3b3f
commit 674145d7c8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 480 additions and 29 deletions

View File

@ -1061,6 +1061,11 @@ public class ScalarOperatorFunctions {
/**
* Arithmetic function
*/
@ConstantFunction(name = "add", argTypes = {TINYINT, TINYINT}, returnType = TINYINT, isMonotonic = true)
public static ConstantOperator addTinyInt(ConstantOperator first, ConstantOperator second) {
return ConstantOperator.createTinyInt((byte) Math.addExact(first.getTinyInt(), second.getTinyInt()));
}
@ConstantFunction(name = "add", argTypes = {SMALLINT, SMALLINT}, returnType = SMALLINT, isMonotonic = true)
public static ConstantOperator addSmallInt(ConstantOperator first, ConstantOperator second) {
return ConstantOperator.createSmallInt((short) Math.addExact(first.getSmallint(), second.getSmallint()));

View File

@ -14,9 +14,23 @@
package com.starrocks.sql.optimizer.rewrite.scalar;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSortedMap;
import com.google.common.collect.ImmutableSortedSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.starrocks.catalog.ArrayType;
import com.starrocks.catalog.PrimitiveType;
import com.starrocks.catalog.ScalarFunction;
import com.starrocks.catalog.ScalarType;
import com.starrocks.catalog.Type;
import com.starrocks.sql.analyzer.SemanticException;
import com.starrocks.sql.ast.expression.ArithmeticExpr;
import com.starrocks.sql.ast.expression.BinaryType;
import com.starrocks.sql.ast.expression.Expr;
import com.starrocks.sql.optimizer.operator.scalar.ArrayOperator;
import com.starrocks.sql.optimizer.operator.scalar.BinaryPredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.CallOperator;
import com.starrocks.sql.optimizer.operator.scalar.CastOperator;
@ -27,12 +41,19 @@ import com.starrocks.sql.optimizer.operator.scalar.LikePredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
import com.starrocks.sql.optimizer.rewrite.ScalarOperatorEvaluator;
import com.starrocks.sql.optimizer.rewrite.ScalarOperatorRewriteContext;
import com.starrocks.sql.optimizer.rewrite.ScalarOperatorRewriter;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
public class FoldConstantsRule extends BottomUpScalarOperatorRewriteRule {
private static final Logger LOG = LogManager.getLogger(FoldConstantsRule.class);
@ -47,12 +68,258 @@ public class FoldConstantsRule extends BottomUpScalarOperatorRewriteRule {
this.needMonotonicFunc = needMonotonicFunc;
}
private <T extends ScalarOperator> Optional<ScalarOperator> arrayScalarFun(
CallOperator call, BiFunction<List<ScalarOperator>, ScalarOperator, T> cb) {
Preconditions.checkArgument(call.getArguments().size() == 2);
ScalarOperator arg0 = call.getArguments().get(0);
ScalarOperator arg1 = call.getArguments().get(1);
if (arg0.isConstantNull()) {
return Optional.of(ConstantOperator.createNull(call.getType()));
}
if (!(arg0 instanceof ArrayOperator) || !arg1.isConstantRef() ||
!arg0.getChildren().stream().allMatch(ScalarOperator::isConstantRef)) {
return Optional.empty();
}
return Optional.of(cb.apply(arg0.getChildren(), arg1));
}
private Optional<ScalarOperator> arrayUnaryFun(CallOperator call,
Function<List<ScalarOperator>, ScalarOperator> cb) {
Preconditions.checkArgument(call.getArguments().size() == 1);
ScalarOperator arg0 = call.getChild(0);
if (arg0 instanceof CastOperator && arg0.getChild(0) instanceof ArrayOperator) {
arg0 = call.getChild(0).getChild(0);
}
if (arg0.isConstantNull()) {
return Optional.of(ConstantOperator.createNull(arg0.getType()));
}
if (!(arg0 instanceof ArrayOperator) || !arg0.getChildren().stream().allMatch(ScalarOperator::isConstantRef)) {
return Optional.empty();
}
try {
return Optional.of(cb.apply(arg0.getChildren()));
} catch (IllegalArgumentException ex) {
return Optional.empty();
}
}
private Optional<ScalarOperator> constArrayAppend(CallOperator call) {
return arrayScalarFun(call, (arrayElements, target) ->
new ArrayOperator(call.getType(), call.isNullable(),
Lists.newArrayList(Iterables.concat(arrayElements, List.of(target)))));
}
boolean constantEqual(ScalarOperator lhs, ScalarOperator rhs) {
Preconditions.checkArgument(lhs.isConstantRef() && rhs.isConstantRef());
ConstantOperator constLhs = (ConstantOperator) lhs;
ConstantOperator constRhs = (ConstantOperator) rhs;
if (constLhs.getType().isDecimalV3()) {
return constLhs.getDecimal().compareTo(constRhs.getDecimal()) == 0 &&
constLhs.isNull() == constRhs.isNull();
} else {
return lhs.equals(rhs);
}
}
private Optional<ScalarOperator> constArrayContains(CallOperator call) {
return arrayScalarFun(call, (arrayElements, target) ->
ConstantOperator.createBoolean(arrayElements.stream().anyMatch(elem -> constantEqual(elem, target))));
}
private Optional<ScalarOperator> constArrayRemove(CallOperator call) {
return arrayScalarFun(call, (arrayElements, target) ->
new ArrayOperator(call.getType(), call.isNullable(),
arrayElements.stream().filter((elem) -> !constantEqual(elem, target))
.collect(Collectors.toList())));
}
private Optional<ScalarOperator> constArrayPosition(CallOperator call) {
return arrayScalarFun(call, (arrayElements, target) ->
ConstantOperator.createInt(IntStream.range(0, arrayElements.size())
.boxed()
.filter(i -> constantEqual(arrayElements.get(i), target))
.findAny().map(i -> i + 1).orElse(0)));
}
private Optional<ScalarOperator> constArrayLength(CallOperator call) {
ScalarOperator arg0 = call.getChild(0);
if (arg0.isConstantNull()) {
return Optional.of(ConstantOperator.createNull(call.getType()));
} else if (arg0 instanceof ArrayOperator) {
return Optional.of(ConstantOperator.createInt(arg0.getChildren().size()));
} else if (arg0 instanceof CastOperator && (arg0.getChild(0) instanceof ArrayOperator)) {
// array_length([]) and array_length([NULL, NULL]) would be
// array_length(cast([] as ARRAY<BOOL>)) and array_length(cast([NULL, NULL] as ARRAY<BOOL>))
return Optional.of(ConstantOperator.createInt(arg0.getChild(0).getChildren().size()));
} else {
return Optional.empty();
}
}
private Optional<ScalarOperator> constArraySum(CallOperator call) {
BinaryOperator<ConstantOperator> add = (lhs, rhs) -> {
String opName = ArithmeticExpr.Operator.ADD.getName();
com.starrocks.catalog.Function fn =
Expr.getBuiltinFunction(opName, new Type[] {lhs.getType(), rhs.getType()},
com.starrocks.catalog.Function.CompareMode.IS_SUPERTYPE_OF);
CallOperator addition =
new CallOperator(ArithmeticExpr.Operator.ADD.getName(), lhs.getType(), List.of(lhs, rhs), fn);
ScalarOperator result = addition.accept(this, new ScalarOperatorRewriteContext());
if (result instanceof ConstantOperator) {
return (ConstantOperator) result;
} else {
throw new IllegalArgumentException();
}
};
return arrayUnaryFun(call, (elements) -> elements.stream().filter(elem -> !elem.isConstantNull())
.map(elem -> (ConstantOperator) elem).reduce(add).orElse(ConstantOperator.NULL));
}
private Optional<ScalarOperator> constArrayMinMax(CallOperator call, boolean lessThan) {
Comparator<ConstantOperator> comparator = (lhs, rhs) -> {
BinaryPredicateOperator cmp =
lessThan ? BinaryPredicateOperator.lt(lhs, rhs) : BinaryPredicateOperator.gt(lhs, rhs);
ScalarOperator result = cmp.accept(this, new ScalarOperatorRewriteContext());
if (result.isConstant() && result.getType().isBoolean()) {
return ((ConstantOperator) result).getBoolean() ? -1 : 1;
} else {
throw new IllegalArgumentException();
}
};
return arrayUnaryFun(call, (elements) -> elements.stream().filter(elem -> !elem.isConstantNull())
.map(elem -> (ConstantOperator) elem).min(comparator).orElse(ConstantOperator.NULL));
}
private Optional<ScalarOperator> constArrayMin(CallOperator call) {
return constArrayMinMax(call, true);
}
private Optional<ScalarOperator> constArrayMax(CallOperator call) {
return constArrayMinMax(call, false);
}
private Optional<ScalarOperator> constArrayAvg(CallOperator call) {
Optional<ScalarOperator> optSum = constArraySum(call);
Optional<ScalarOperator> optLength = constArrayLength(call);
if (!optSum.isPresent() || !optLength.isPresent() ||
!optSum.get().isConstant() || !optLength.get().isConstant()) {
return Optional.empty();
}
ConstantOperator sum = (ConstantOperator) optSum.get();
ConstantOperator length = (ConstantOperator) optLength.get();
if (sum.isNull() || length.isNull()) {
return Optional.of(ConstantOperator.createNull(call.getType()));
}
String opName = ArithmeticExpr.Operator.DIVIDE.getName();
com.starrocks.catalog.Function
fn = Expr.getBuiltinFunction(opName, new Type[] {call.getType(), call.getType()},
com.starrocks.catalog.Function.CompareMode.IS_SUPERTYPE_OF);
// for decimal types, divide function should be rectified.
if (call.getType().isDecimalV3()) {
Preconditions.checkArgument(sum.getType().isDecimalV3());
ScalarType sumType = (ScalarType) sum.getType();
sumType = ScalarType.createDecimalV3Type(call.getType().getPrimitiveType(), sumType.getScalarPrecision(),
sumType.getScalarScale());
int precision = PrimitiveType.getMaxPrecisionOfDecimal(call.getType().getPrimitiveType());
ScalarType countType = ScalarType.createDecimalV3Type(call.getType().getPrimitiveType(), precision, 0);
fn = new ScalarFunction(fn.getFunctionName(), new Type[] {sumType, countType}, call.getType(),
fn.hasVarArgs());
}
CallOperator avg =
new CallOperator(opName, call.getType(), List.of(sum, length), fn);
ScalarOperatorRewriter rewriter = new ScalarOperatorRewriter();
ScalarOperator result = rewriter.rewrite(avg, ScalarOperatorRewriter.DEFAULT_REWRITE_RULES);
if (result.isConstant()) {
return Optional.of(result);
} else {
return Optional.empty();
}
}
private Optional<ScalarOperator> returnNullIfExistsNullArg(CallOperator call) {
ImmutableSet<String> functions = new ImmutableSortedSet.Builder<String>(String.CASE_INSENSITIVE_ORDER)
.add("split")
.add("str_to_map")
.add("regexp_extract_all")
.add("regexp_split")
.add("array_length")
.add("array_sum")
.add("array_avg")
.add("array_min")
.add("array_max")
.add("array_distinct")
.add("array_sort")
.add("reverse")
.add("array_join")
.add("array_difference")
.add("array_slice")
.add("array_concat")
.add("arrays_overlap")
.add("array_intersect")
.add("array_cum_sum")
.add("array_contains_all")
.add("array_contains_seq")
.add("all_match")
.add("any_match")
.add("array_generate")
.add("array_repeat")
.add("array_flatten")
.add("array_map")
.add("map_size")
.add("map_keys")
.add("map_values")
.add("map_from_arrays")
.add("distinct_map_keys")
.add("cardinality")
.add("tokenize")
.build();
if (functions.contains(call.getFnName()) && call.getArguments().stream().anyMatch(
ScalarOperator::isConstantNull)) {
return Optional.of(ConstantOperator.createNull(call.getType()));
} else {
return Optional.empty();
}
}
private Optional<ScalarOperator> tryToProcessConstantArrayFunctions(CallOperator call) {
Optional<ScalarOperator> result = returnNullIfExistsNullArg(call);
if (result.isPresent()) {
return result;
}
ImmutableMap<String, Function<CallOperator, Optional<ScalarOperator>>> handlers =
new ImmutableSortedMap.Builder<String, Function<CallOperator, Optional<ScalarOperator>>>(
String.CASE_INSENSITIVE_ORDER)
.put("array_length", this::constArrayLength)
.put("array_sum", this::constArraySum)
.put("array_min", this::constArrayMin)
.put("array_max", this::constArrayMax)
.put("array_avg", this::constArrayAvg)
.put("array_contains", this::constArrayContains)
.put("array_append", this::constArrayAppend)
.put("array_remove", this::constArrayRemove)
.put("array_position", this::constArrayPosition)
.build();
if (handlers.containsKey(call.getFnName())) {
return handlers.get(call.getFnName()).apply(call);
}
return Optional.empty();
}
@Override
public ScalarOperator visitCall(CallOperator call, ScalarOperatorRewriteContext context) {
if (call.isAggregate()) {
return call;
}
Optional<ScalarOperator> mayBeConstant = tryToProcessConstantArrayFunctions(call);
if (mayBeConstant.isPresent()) {
return mayBeConstant.get();
}
if (notAllConstant(call.getChildren())) {
if (call.getFunction() != null && call.getFunction().isMetaFunction()) {
String errMsg = String.format("Meta function %s does not support non-constant arguments",
@ -143,6 +410,25 @@ public class FoldConstantsRule extends BottomUpScalarOperatorRewriteRule {
return ConstantOperator.createNull(operator.getType());
}
ScalarOperator arg0 = operator.getChild(0);
if (arg0 instanceof ArrayOperator
&& arg0.getChildren().stream().allMatch(ScalarOperator::isConstantRef)
&& !arg0.getChildren().isEmpty() && operator.getType().isArrayType()
&& ((ArrayType) operator.getType()).getItemType().isScalarType()) {
if (notAllConstant(arg0.getChildren())) {
return operator;
}
Type arrayElemType = ((ArrayType) operator.getType()).getItemType();
ScalarOperatorRewriteContext subContext = new ScalarOperatorRewriteContext();
List<ScalarOperator> newArguments = arg0.getChildren().stream()
.map(arg -> visitCastOperator(new CastOperator(arrayElemType, arg), subContext))
.collect(Collectors.toList());
if (newArguments.stream().allMatch(arg -> arg != null && arg.isConstantRef())) {
return new ArrayOperator(operator.getType(), operator.isNullable(), newArguments);
}
return operator;
}
if (notAllConstant(operator.getChildren())) {
return operator;
}

View File

@ -64,14 +64,14 @@ public class TrinoFunctionTransformTest extends TrinoTestBase {
sql = "select concat(c1, c2, array[1,2], array[3,4]) from test_array";
assertPlanContains(sql, "array_concat(2: c1, CAST(3: c2 AS ARRAY<VARCHAR>), " +
"CAST([1,2] AS ARRAY<VARCHAR>), " +
"CAST([3,4] AS ARRAY<VARCHAR>)");
"['1','2'], " +
"['3','4']");
sql = "select concat(c2, 2) from test_array";
assertPlanContains(sql, "array_concat(3: c2, CAST([2] AS ARRAY<INT>))");
assertPlanContains(sql, "array_concat(3: c2, [2])");
sql = "select contains(array[1,2,3], 1)";
assertPlanContains(sql, "array_contains([1,2,3], 1)");
assertPlanContains(sql, " <slot 2> : TRUE");
sql = "select slice(array[1,2,3,4], 2, 2)";
assertPlanContains(sql, "array_slice([1,2,3,4], 2, 2)");

View File

@ -367,10 +367,10 @@ public class TrinoQueryTest extends TrinoTestBase {
assertPlanContains(sql, "array_min(2: c1)");
sql = "select array_position(array[1,2,3], 2) from test_array";
assertPlanContains(sql, "array_position([1,2,3], 2)");
assertPlanContains(sql, "<slot 4> : 2");
sql = "select array_remove(array[1,2,3], 2) from test_array";
assertPlanContains(sql, "array_remove([1,2,3], 2)");
assertPlanContains(sql, "<slot 4> : [1,3]");
sql = "select array_sort(c1) from test_array";
assertPlanContains(sql, "array_sort(2: c1)");
@ -1234,7 +1234,7 @@ public class TrinoQueryTest extends TrinoTestBase {
@Test
public void testCastArrayDataType() throws Exception {
String sql = "select cast(ARRAY[1] as array(int))";
assertPlanContains(sql, "CAST([1] AS ARRAY<INT>)");
assertPlanContains(sql, " <slot 2> : [1]");
}
@Test

View File

@ -385,9 +385,11 @@ public class VectorIndexTest extends PlanTestBase {
" | 5 <-> [12: cast, DOUBLE, true] + 1.0\n" +
" | 6 <-> [12: cast, DOUBLE, true] + 2.0\n" +
" | 7 <-> cast([11: approx_cosine_similarity, FLOAT, true] as VARCHAR(65533))\n" +
" | 8 <-> cast(approx_cosine_similarity[(cast([1.1,2.2,3.3,4.4,5.5] as ARRAY<FLOAT>), [3: c2, ARRAY<FLOAT>, true]); args: INVALID_TYPE,INVALID_TYPE; result: FLOAT; args nullable: true; result nullable: true] as DOUBLE) + 2.0\n" +
" | 8 <-> cast(approx_cosine_similarity[([1.1,2.2,3.3,4.4,5.5], [3: c2, ARRAY<FLOAT>, true]); " +
"args: INVALID_TYPE,INVALID_TYPE; result: FLOAT; args nullable: true; result nullable: true] as DOUBLE) + 2.0\n" +
" | common expressions:\n" +
" | 11 <-> approx_cosine_similarity[(cast([1.1,2.2,3.3,4.4,5.5] as ARRAY<FLOAT>), [2: c1, ARRAY<FLOAT>, false]); args: INVALID_TYPE,INVALID_TYPE; result: FLOAT; args nullable: true; result nullable: true]\n" +
" | 11 <-> approx_cosine_similarity[([1.1,2.2,3.3,4.4,5.5], [2: c1, ARRAY<FLOAT>, false]); " +
"args: INVALID_TYPE,INVALID_TYPE; result: FLOAT; args nullable: true; result nullable: true]\n" +
" | 12 <-> cast([11: approx_cosine_similarity, FLOAT, true] as DOUBLE)\n" +
" | limit: 10\n" +
" | cardinality: 1\n" +
@ -529,7 +531,8 @@ public class VectorIndexTest extends PlanTestBase {
" 1:Project\n" +
" | output columns:\n" +
" | 2 <-> [2: c1, ARRAY<FLOAT>, false]\n" +
" | 4 <-> approx_l2_distance[(cast([1.1,2.2,3.3,4.4] as ARRAY<FLOAT>), [2: c1, ARRAY<FLOAT>, false]); args: INVALID_TYPE,INVALID_TYPE; result: FLOAT; args nullable: true; result nullable: true]\n" +
" | 4 <-> approx_l2_distance[([1.1,2.2,3.3,4.4], [2: c1, ARRAY<FLOAT>, false]); " +
"args: INVALID_TYPE,INVALID_TYPE; result: FLOAT; args nullable: true; result nullable: true]\n" +
" | cardinality: 1\n" +
" | \n" +
" 0:OlapScanNode\n" +

View File

@ -1680,7 +1680,7 @@ public class AggregateTest extends PlanTestBase {
" | group by: \n" +
" | \n" +
" 1:Project\n" +
" | <slot 4> : arrays_overlap(3: v3, CAST([1] AS ARRAY<BIGINT>))\n" +
" | <slot 4> : arrays_overlap(3: v3, [1])\n" +
" | \n" +
" 0:OlapScanNode");
}

View File

@ -114,7 +114,7 @@ public class ArrayTypeTest extends PlanTestBase {
sql = "select concat(v1, [1,2,3], s_1) from adec";
plan = getFragmentPlan(sql);
assertContains(plan, "array_concat(CAST([1: v1] AS ARRAY<VARCHAR>), " +
"CAST([1,2,3] AS ARRAY<VARCHAR>), 3: s_1)");
"['1','2','3'], 3: s_1)");
sql = "select concat(1,2, [1,2])";
plan = getFragmentPlan(sql);
@ -122,13 +122,11 @@ public class ArrayTypeTest extends PlanTestBase {
sql = "select concat(1,2, [1,2], 'a', 'b')";
plan = getFragmentPlan(sql);
assertContains(plan, "array_concat(CAST([1] AS ARRAY<VARCHAR>), CAST([2] AS ARRAY<VARCHAR>), " +
"CAST([1,2] AS ARRAY<VARCHAR>), ['a'], ['b'])");
assertContains(plan, "array_concat(['1'], ['2'], ['1','2'], ['a'], ['b'])");
sql = "select concat(1,2, [1,2], 'a', 'b', 1.1)";
plan = getFragmentPlan(sql);
assertContains(plan, "array_concat(CAST([1] AS ARRAY<VARCHAR>), CAST([2] AS ARRAY<VARCHAR>), " +
"CAST([1,2] AS ARRAY<VARCHAR>), ['a'], ['b'], CAST([1.1] AS ARRAY<VARCHAR>)");
assertContains(plan, " array_concat(['1'], ['2'], ['1','2'], ['a'], ['b'], ['1.1'])");
sql = "with t0 as (\n" +
" select c1 from (values([])) as t(c1)\n" +
@ -357,7 +355,7 @@ public class ArrayTypeTest extends PlanTestBase {
String sql = "select array_append([[1,2,3]], [null])";
String plan = getFragmentPlan(sql);
assertContains(plan,
"<slot 2> : array_append([[1,2,3]], CAST([NULL] AS ARRAY<TINYINT>))");
"<slot 2> : array_append([[1,2,3]], [NULL])");
}
{
starRocksAssert.withTable("create table test_literal_array_insert_t0(" +
@ -727,11 +725,10 @@ public class ArrayTypeTest extends PlanTestBase {
sql = "select array_contains([null], null), array_position([null], null)";
plan = getVerboseExplain(sql);
assertContains(plan, " | output columns:\n" +
" | 2 <-> array_contains[([NULL], NULL); " +
"args: INVALID_TYPE,BOOLEAN; result: BOOLEAN; args nullable: true; result nullable: true]\n" +
" | 3 <-> array_position[([NULL], NULL); " +
"args: INVALID_TYPE,BOOLEAN; result: INT; args nullable: true; result nullable: true]");
assertContains(plan, " 1:Project\n" +
" | output columns:\n" +
" | 2 <-> TRUE\n" +
" | 3 <-> 1");
}
@Test

View File

@ -0,0 +1,166 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package com.starrocks.sql.plan;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
public class ConstArrayFunctionFoldingTest extends PlanTestBase {
@BeforeAll
public static void beforeClass() throws Exception {
PlanTestBase.beforeClass();
}
@Test
public void testConstantArrayFunctions() throws Exception {
String[][] cases = new String[][] {
{"array_contains(['A', 'B'], 'C')", "<slot 2> : FALSE"},
{"array_contains(['A', 'B'], 'B')", "<slot 2> : TRUE"},
{"array_contains(['A', 'B', NULL], NULL)", "<slot 2> : TRUE"},
{"array_contains(['A', 'B', 'C'], NULL)", "<slot 2> : FALSE"},
{"array_contains(NULL, NULL)", "<slot 2> : NULL"}, {"array_remove(['A', 'B'], 'B')", "['A']"},
{"array_remove(['A', 'B', 'B'], 'B')", "['A']"}, {"array_remove(['A', 'B'], 'C')", "['A','B']"},
{"array_remove(['A', 'B'], NULL)", "['A','B']"}, {"array_remove(['A', 'B', NULL], NULL)", "['A','B']"},
{"array_remove(['A', 'B', NULL, NULL], NULL)", "['A','B']"},
{"array_remove(NULL, NULL)", "<slot 2> : NULL"}, {"array_remove(NULL, 'A')", "<slot 2> : NULL"},
{"array_append(['A', 'B'], 'C')", "['A','B','C']"},
{"array_append(['A', 'B'], NULL)", "['A','B',NULL]"}, {"array_append(NULL, NULL)", "<slot 2> : NULL"},
{"array_append(NULL, 'A')", "<slot 2> : NULL"}, {"array_position(['A', 'B'], 'A')", "<slot 2> : 1"},
{"array_position(['A', 'B'], 'C')", "<slot 2> : 0"},
{"array_position(['A', 'B'], NULL)", "<slot 2> : 0"},
{"array_position([NULL,'A', 'B'], NULL)", "<slot 2> : 1"},
{"array_position(NULL, NULL)", "<slot 2> : NULL"}, {"array_position(NULL, 'A')", "<slot 2> : NULL"},
{"array_length(['A', 'B'])", "<slot 2> : 2"}, {"array_length([])", "<slot 2> : 0"},
{"array_length(NULL)", "<slot 2> : NULL"}, {"array_sum([1, 2, 3, 4])", "<slot 2> : 10"},
{"array_sum([NULL,1, 2, 3, 4])", "<slot 2> : 10"}, {"array_sum([])", "<slot 2> : NULL"},
{"array_sum(NULL)", "<slot 2> : NULL"}, {"array_sum([NULL, NULL])", "<slot 2> : NULL"},
{"array_min([1, 2, 3, 4])", "<slot 2> : 1"}, {"array_min([])", "<slot 2> : NULL"},
{"array_min([NULL, 4, 2, 3])", "<slot 2> : 2"}, {"array_min([NULL, 4, 2, 1, NULL])", "<slot 2> : 1"},
{"array_min(NULL)", "<slot 2> : NULL"}, {"array_max([1, 2, 3, 4])", "<slot 2> : 4"},
{"array_max([])", "<slot 2> : NULL"}, {"array_max([NULL, 4, 10, 7])", "<slot 2> : 10"},
{"array_max([NULL, 4, 7, 11, NULL])", "<slot 2> : 11"}, {"array_max(NULL)", "<slot 2> : NULL"},
{"array_avg([1, 2, 3, 4])", "<slot 2> : 2.5"}, {"array_avg([NULL,1, 2, 3, 4])", "<slot 2> : 2"},
{"array_avg([])", "<slot 2> : NULL"}, {"array_avg(NULL)", "<slot 2> : NULL"},
{"array_avg([NULL, NULL])", "<slot 2> : NULL"}
};
String sqlFmt = "select {FUNC}";
for (String[] tc : cases) {
String sql = sqlFmt.replace("{FUNC}", tc[0]);
String expect = tc[1];
System.out.println(sql);
String plan = getFragmentPlan(sql);
assertContains(plan, expect);
}
}
@Test
public void testConstantArrayDecimal() throws Exception {
String[][] cases = new String[][] {
{"array_remove([1.0,2.1,3.2,4.3], 1)", "<slot 2> : [2.1,3.2,4.3]"},
{"array_position([1.0,2.1,3.2,4.3], 1)", "<slot 2> : 1"},
{"array_position([1.0,2.1,3.2,4.3], NULL)", "<slot 2> : 0"},
{"array_avg([1.0, 2.0, 3.0, 4.0])", "<slot 2> : 2.5"},
{"array_contains([1.0,2.1,3.2,4.3], 1)", "<slot 2> : TRUE"},
{"array_avg([1.1, 2.22, 3.333, 4.4444, NULL])", "<slot 2> : 2.21948"},
{"array_sum([1.0, 2.0, 3.0, 4.0])", "<slot 2> : 10.0"},
{"array_sum([1.1, 2.22, 3.333, 4.4444, NULL])", "<slot 2> : 11.0974"},
};
String sqlFmt = "select {FUNC}";
for (String[] tc : cases) {
String sql = sqlFmt.replace("{FUNC}", tc[0]);
String expect = tc[1];
System.out.println(sql);
String plan = getFragmentPlan(sql);
assertContains(plan, expect);
}
}
@Test
public void testConstArrayFunctionsReturnNullExistsNullArgument() throws Exception {
String[] queryList = new String[] {"SELECT 'split' as func_name, split(NULL, ',') as result;",
"SELECT 'split' as func_name, split('a,b', NULL) as result;",
"SELECT 'str_to_map' as func_name, str_to_map(NULL, ',', ':') as result;",
"SELECT 'str_to_map' as func_name, str_to_map('a:1,b:2', NULL, ':') as result;",
"SELECT 'str_to_map' as func_name, str_to_map('a:1,b:2', ',', NULL) as result;",
"SELECT 'regexp_extract_all' as func_name, regexp_extract_all(NULL, '(\\d+)', 1) as result;",
"SELECT 'regexp_extract_all' as func_name, regexp_extract_all('abc123', NULL, 1) as result;",
"SELECT 'regexp_extract_all' as func_name, regexp_extract_all('abc123', '(\\d+)', NULL) as result;",
"SELECT 'regexp_split' as func_name, regexp_split(NULL, ',', 1) as result;",
"SELECT 'regexp_split' as func_name, regexp_split('a,b', NULL, 1) as result;",
"SELECT 'regexp_split' as func_name, regexp_split('a,b', ',', NULL) as result;",
"SELECT 'array_length' as func_name, array_length(NULL) as result;",
"SELECT 'array_sum' as func_name, array_sum(NULL) as result;",
"SELECT 'array_avg' as func_name, array_avg(NULL) as result;",
"SELECT 'array_min' as func_name, array_min(NULL) as result;",
"SELECT 'array_max' as func_name, array_max(NULL) as result;",
"SELECT 'array_distinct' as func_name, array_distinct(NULL) as result;",
"SELECT 'array_sort' as func_name, array_sort(NULL) as result;",
"SELECT 'reverse' as func_name, reverse(NULL) as result;",
"SELECT 'array_join' as func_name, array_join(NULL, ',') as result;",
"SELECT 'array_join' as func_name, array_join(['a', 'b'], NULL) as result;",
"SELECT 'array_difference' as func_name, array_difference(NULL) as result;",
"SELECT 'array_slice' as func_name, array_slice(NULL, 1, 2) as result;",
"SELECT 'array_slice' as func_name, array_slice([1, 2, 3], NULL, 2) as result;",
"SELECT 'array_slice' as func_name, array_slice([1, 2, 3], 1, NULL) as result;",
"SELECT 'array_concat' as func_name, array_concat(NULL, [3, 4]) as result;",
"SELECT 'array_concat' as func_name, array_concat([1, 2], NULL) as result;",
"SELECT 'arrays_overlap' as func_name, arrays_overlap(NULL, [1, 2]) as result;",
"SELECT 'arrays_overlap' as func_name, arrays_overlap([1, 2], NULL) as result;",
"SELECT 'array_intersect' as func_name, array_intersect(NULL, [3, 4]) as result;",
"SELECT 'array_intersect' as func_name, array_intersect([1, 2], NULL) as result;",
"SELECT 'array_cum_sum' as func_name, array_cum_sum(NULL) as result;",
"SELECT 'array_contains_all' as func_name, array_contains_all(NULL, [1, 2]) as result;",
"SELECT 'array_contains_all' as func_name, array_contains_all([1, 2], NULL) as result;",
"SELECT 'array_contains_seq' as func_name, array_contains_seq(NULL, [1, 2]) as result;",
"SELECT 'array_contains_seq' as func_name, array_contains_seq([1, 2], NULL) as result;",
"SELECT 'all_match' as func_name, all_match(NULL) as result;",
"SELECT 'any_match' as func_name, any_match(NULL) as result;",
"SELECT 'array_generate' as func_name, array_generate(NULL, 1, 1) as result;",
"SELECT 'array_generate' as func_name, array_generate(1, NULL, 1) as result;",
"SELECT 'array_generate' as func_name, array_generate(1, 1, NULL) as result;",
"SELECT 'array_repeat' as func_name, array_repeat(NULL, 3) as result;",
"SELECT 'array_repeat' as func_name, array_repeat(1, NULL) as result;",
"SELECT 'array_flatten' as func_name, array_flatten(NULL) as result;",
"SELECT 'map_size' as func_name, map_size(NULL) as result;",
"SELECT 'map_keys' as func_name, map_keys(NULL) as result;",
"SELECT 'map_values' as func_name, map_values(NULL) as result;",
"SELECT 'map_from_arrays' as func_name, map_from_arrays(NULL, [1, 2]) as result;",
"SELECT 'map_from_arrays' as func_name, map_from_arrays(['a', 'b'], NULL) as result;",
"SELECT 'distinct_map_keys' as func_name, distinct_map_keys(NULL) as result;",
"SELECT 'cardinality' as func_name, cardinality(NULL) as result;",
"SELECT 'tokenize' as func_name, tokenize(NULL, ' ') as result;",
"SELECT 'tokenize' as func_name, tokenize('hello world', NULL) as result;"};
for (String q : queryList) {
System.out.println(q);
String plan = getFragmentPlan(q);
assertContains(plan, "<slot 3> : NULL");
}
}
@Test
public void testLeftJoin() throws Exception {
String sql = "select ta.v1, tb.v3 " + "from t0 ta left join t0 tb on ta.v1 = tb.v2 " +
"where array_contains([1,2,3], tb.v3)";
String plan = getFragmentPlan(sql);
assertCContains(plan,
" 3:HASH JOIN\n" + " | join op: INNER JOIN (BROADCAST)\n" + " | colocate: false, reason: \n" +
" | equal join conjunct: 5: v2 = 1: v1");
assertCContains(plan, " 0:OlapScanNode\n" + " TABLE: t0\n" + " PREAGGREGATION: ON\n" +
" PREDICATES: array_contains([1,2,3], 6: v3)");
}
}

View File

@ -677,8 +677,6 @@ public class PruneComplexSubfieldTest extends PlanTestNoneDBBase {
assertContains(plan, " 0:OlapScanNode\n"
+ " table: pc0, rollup: pc0\n"
+ " preAggregation: on\n"
+ " Predicates: array_length[([]); args: INVALID_TYPE; result: INT; args nullable: true; result "
+ "nullable: true] IS NOT NULL\n"
+ " partitionsRatio=0/1, tabletsRatio=0/0\n"
+ " tabletList=\n"
+ " actualRows=0, avgRowSize=1.0\n"
@ -692,8 +690,6 @@ public class PruneComplexSubfieldTest extends PlanTestNoneDBBase {
assertContains(plan, " 0:OlapScanNode\n"
+ " table: sc0, rollup: sc0\n"
+ " preAggregation: on\n"
+ " Predicates: array_length[([1,2,3]); args: INVALID_TYPE; result: INT; args nullable: true; result "
+ "nullable: true] IS NOT NULL\n"
+ " partitionsRatio=0/1, tabletsRatio=0/0\n"
+ " tabletList=\n"
+ " actualRows=0, avgRowSize=3.0\n"
@ -711,8 +707,6 @@ public class PruneComplexSubfieldTest extends PlanTestNoneDBBase {
assertContains(plan, " 0:OlapScanNode\n" +
" table: pc0, rollup: pc0\n" +
" preAggregation: on\n" +
" Predicates: array_length[([]); args: INVALID_TYPE; result: INT; args nullable: true; result nullable:"
+ " true] IS NOT NULL\n" +
" partitionsRatio=0/1, tabletsRatio=0/0\n" +
" tabletList=\n" +
" actualRows=0, avgRowSize=3.0\n" +
@ -734,7 +728,7 @@ public class PruneComplexSubfieldTest extends PlanTestNoneDBBase {
String sql = "select [1, 2, 3] is null from pc0 t1 right join sc0 t2 on t1.v1 = t2.v1;";
String plan = getFragmentPlan(sql);
assertContains(plan, "5:Project\n" +
" | <slot 15> : array_length([1,2,3]) IS NULL");
" | <slot 15> : FALSE");
sql = "select [1, 2, 3][1] is null from pc0 t1 right join sc0 t2 on t1.v1 = t2.v1;";
plan = getFragmentPlan(sql);