[BugFix] Do not simplify case-when with complex functions to avoid yielding very tedious result on scan node because of lack of CSE extraction (backport #62505) (#62519)
Signed-off-by: satanson <ranpanf@gmail.com> Co-authored-by: satanson <ranpanf@gmail.com>
This commit is contained in:
parent
2fe7265367
commit
6029763c83
|
|
@ -26,6 +26,8 @@ import com.starrocks.server.GlobalStateMgr;
|
|||
import com.starrocks.sql.optimizer.Utils;
|
||||
import com.starrocks.sql.optimizer.rewrite.ScalarOperatorRewriter;
|
||||
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static com.starrocks.catalog.Function.CompareMode.IS_IDENTICAL;
|
||||
import static com.starrocks.catalog.Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF;
|
||||
import static com.starrocks.sql.optimizer.rewrite.ScalarOperatorRewriter.DEFAULT_TYPE_CAST_RULE;
|
||||
|
|
@ -90,4 +92,9 @@ public class ScalarOperatorUtil {
|
|||
.map(compOp -> compOp.isNot() && isSimpleLike(compOp.getChild(0)))
|
||||
.orElse(false);
|
||||
}
|
||||
|
||||
public static Stream<ScalarOperator> getStream(ScalarOperator operator) {
|
||||
return Stream.concat(Stream.of(operator),
|
||||
operator.getChildren().stream().flatMap(ScalarOperatorUtil::getStream));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -66,6 +66,12 @@ public class ScalarOperatorRewriter {
|
|||
SimplifiedCaseWhenRule.INSTANCE,
|
||||
PruneTediousPredicateRule.INSTANCE
|
||||
);
|
||||
|
||||
private static final List<ScalarOperatorRewriteRule> CASE_WHEN_PREDICATE_ON_SCAN_RULE = Lists.newArrayList(
|
||||
SimplifiedCaseWhenRule.SKIP_COMPLEX_FUNCTIONS_INSTANCE,
|
||||
PruneTediousPredicateRule.INSTANCE
|
||||
);
|
||||
|
||||
public static final List<ScalarOperatorRewriteRule> DEFAULT_REWRITE_SCAN_PREDICATE_RULES = Lists.newArrayList(
|
||||
// required
|
||||
new ImplicitCastRule(),
|
||||
|
|
@ -159,9 +165,12 @@ public class ScalarOperatorRewriter {
|
|||
return op;
|
||||
}
|
||||
|
||||
public static ScalarOperator simplifyCaseWhen(ScalarOperator predicates) {
|
||||
// simplify case-when
|
||||
return new ScalarOperatorRewriter().rewrite(predicates, CASE_WHEN_PREDICATE_RULE);
|
||||
public static ScalarOperator simplifyCaseWhen(ScalarOperator predicates, boolean isOnScan) {
|
||||
if (isOnScan) {
|
||||
return new ScalarOperatorRewriter().rewrite(predicates, CASE_WHEN_PREDICATE_ON_SCAN_RULE);
|
||||
} else {
|
||||
return new ScalarOperatorRewriter().rewrite(predicates, CASE_WHEN_PREDICATE_RULE);
|
||||
}
|
||||
}
|
||||
|
||||
public static ScalarOperator replaceScalarOperatorByColumnRef(ScalarOperator operator,
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ import com.starrocks.sql.optimizer.operator.scalar.ConstantOperator;
|
|||
import com.starrocks.sql.optimizer.operator.scalar.InPredicateOperator;
|
||||
import com.starrocks.sql.optimizer.operator.scalar.IsNullPredicateOperator;
|
||||
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
|
||||
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperatorUtil;
|
||||
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperatorVisitor;
|
||||
|
||||
import java.util.Collection;
|
||||
|
|
@ -39,6 +40,7 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.function.Predicate;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.starrocks.sql.optimizer.operator.scalar.CompoundPredicateOperator.CompoundType;
|
||||
|
|
@ -81,7 +83,6 @@ public class InvertedCaseWhen {
|
|||
return elseBranch;
|
||||
}
|
||||
|
||||
|
||||
public static ScalarOperator in(boolean isNotIn, ScalarOperator lhs, List<ScalarOperator> values) {
|
||||
Preconditions.checkArgument(!values.isEmpty());
|
||||
List<ScalarOperator> args = Lists.newArrayList(lhs);
|
||||
|
|
@ -116,9 +117,21 @@ public class InvertedCaseWhen {
|
|||
}
|
||||
}
|
||||
|
||||
private static class InvertCaseWhenVisitor extends ScalarOperatorVisitor<Optional<InvertedCaseWhen>, Void> {
|
||||
private static class Context {
|
||||
private final boolean skipComplexFunctions;
|
||||
|
||||
public Context(boolean skipComplexFunctions) {
|
||||
this.skipComplexFunctions = skipComplexFunctions;
|
||||
}
|
||||
|
||||
public boolean isSkipComplexFunctions() {
|
||||
return skipComplexFunctions;
|
||||
}
|
||||
}
|
||||
|
||||
private static class InvertCaseWhenVisitor extends ScalarOperatorVisitor<Optional<InvertedCaseWhen>, Context> {
|
||||
@Override
|
||||
public Optional<InvertedCaseWhen> visit(ScalarOperator scalarOperator, Void context) {
|
||||
public Optional<InvertedCaseWhen> visit(ScalarOperator scalarOperator, Context context) {
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
|
|
@ -211,7 +224,22 @@ public class InvertedCaseWhen {
|
|||
}
|
||||
|
||||
@Override
|
||||
public Optional<InvertedCaseWhen> visitCaseWhenOperator(CaseWhenOperator operator, Void context) {
|
||||
public Optional<InvertedCaseWhen> visitCaseWhenOperator(CaseWhenOperator operator, Context context) {
|
||||
if (context.isSkipComplexFunctions()) {
|
||||
Predicate<ScalarOperator> isComplexFunction = op -> (op instanceof CallOperator)
|
||||
&& op.getChildren().stream()
|
||||
.map(ScalarOperator::getType)
|
||||
.anyMatch(t -> t.isComplexType() || t.isJsonType());
|
||||
|
||||
boolean existsComplexFunctions = operator.getAllConditionClause()
|
||||
.stream()
|
||||
.anyMatch(when -> ScalarOperatorUtil.getStream(when).anyMatch(isComplexFunction));
|
||||
|
||||
if (existsComplexFunctions) {
|
||||
return Optional.empty();
|
||||
}
|
||||
}
|
||||
|
||||
if (!operator.getAllValuesClause().stream().allMatch(ScalarOperator::isConstantRef)) {
|
||||
return visit(operator, context);
|
||||
}
|
||||
|
|
@ -224,7 +252,7 @@ public class InvertedCaseWhen {
|
|||
}
|
||||
|
||||
@Override
|
||||
public Optional<InvertedCaseWhen> visitCall(CallOperator call, Void context) {
|
||||
public Optional<InvertedCaseWhen> visitCall(CallOperator call, Context context) {
|
||||
String fnName = call.getFnName();
|
||||
if (fnName.equals(FunctionSet.IF)) {
|
||||
ScalarOperator cond = call.getChild(0);
|
||||
|
|
@ -252,12 +280,10 @@ public class InvertedCaseWhen {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
private static final InvertCaseWhenVisitor INVERT_CASE_WHEN_VISITOR = new InvertCaseWhenVisitor();
|
||||
|
||||
public static Optional<InvertedCaseWhen> from(ScalarOperator op) {
|
||||
return op.accept(INVERT_CASE_WHEN_VISITOR, null);
|
||||
public static Optional<InvertedCaseWhen> from(ScalarOperator op, Context context) {
|
||||
return op.accept(INVERT_CASE_WHEN_VISITOR, context);
|
||||
}
|
||||
|
||||
private static ScalarOperator buildIfThen(ScalarOperator p, ConstantOperator first, ConstantOperator second) {
|
||||
|
|
@ -267,15 +293,15 @@ public class InvertedCaseWhen {
|
|||
Lists.newArrayList(p, first, second), ifFunc);
|
||||
}
|
||||
|
||||
private static class SimplifyVisitor extends ScalarOperatorVisitor<Optional<ScalarOperator>, Void> {
|
||||
private static class SimplifyVisitor extends ScalarOperatorVisitor<Optional<ScalarOperator>, Context> {
|
||||
|
||||
@Override
|
||||
public Optional<ScalarOperator> visit(ScalarOperator scalarOperator, Void context) {
|
||||
public Optional<ScalarOperator> visit(ScalarOperator scalarOperator, Context context) {
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<ScalarOperator> visitInPredicate(InPredicateOperator predicate, Void context) {
|
||||
public Optional<ScalarOperator> visitInPredicate(InPredicateOperator predicate, Context context) {
|
||||
Set<ScalarOperator> inSet = predicate.getChildren().stream().skip(1).collect(Collectors.toSet());
|
||||
// col in (1, 2, 3) is not equal with col in (1, 2, 3, null). For col = 4, the first return false
|
||||
// while the second return null.
|
||||
|
|
@ -283,7 +309,7 @@ public class InvertedCaseWhen {
|
|||
if (!inSet.stream().allMatch(e -> e.isConstantRef() && !e.isNullable())) {
|
||||
return Optional.empty();
|
||||
}
|
||||
Optional<InvertedCaseWhen> maybeInvertedCaseWhen = from(predicate.getChild(0));
|
||||
Optional<InvertedCaseWhen> maybeInvertedCaseWhen = from(predicate.getChild(0), context);
|
||||
if (!maybeInvertedCaseWhen.isPresent()) {
|
||||
return Optional.empty();
|
||||
}
|
||||
|
|
@ -363,8 +389,6 @@ public class InvertedCaseWhen {
|
|||
return result;
|
||||
}
|
||||
|
||||
|
||||
|
||||
// for case when q1 then c1 when q2 then c2 ... else pn end, when it converted into InvertedCaseWhen
|
||||
// thenToWhen records (c1->p1, c2->p2, ..., pn)(lisp style), pi satisfies:
|
||||
// 1. p1 = q1;
|
||||
|
|
@ -412,7 +436,7 @@ public class InvertedCaseWhen {
|
|||
}
|
||||
|
||||
@Override
|
||||
public Optional<ScalarOperator> visitBinaryPredicate(BinaryPredicateOperator predicate, Void context) {
|
||||
public Optional<ScalarOperator> visitBinaryPredicate(BinaryPredicateOperator predicate, Context context) {
|
||||
BinaryType binaryType = predicate.getBinaryType();
|
||||
if (!binaryType.isEqual() && !binaryType.isNotEqual()) {
|
||||
return Optional.empty();
|
||||
|
|
@ -423,8 +447,8 @@ public class InvertedCaseWhen {
|
|||
}
|
||||
|
||||
@Override
|
||||
public Optional<ScalarOperator> visitIsNullPredicate(IsNullPredicateOperator predicate, Void context) {
|
||||
Optional<InvertedCaseWhen> maybeInvertedCaseWhen = from(predicate.getChild(0));
|
||||
public Optional<ScalarOperator> visitIsNullPredicate(IsNullPredicateOperator predicate, Context context) {
|
||||
Optional<InvertedCaseWhen> maybeInvertedCaseWhen = from(predicate.getChild(0), context);
|
||||
if (!maybeInvertedCaseWhen.isPresent()) {
|
||||
return Optional.empty();
|
||||
}
|
||||
|
|
@ -460,7 +484,7 @@ public class InvertedCaseWhen {
|
|||
|
||||
private static final SimplifyVisitor SIMPLIFY_VISITOR = new SimplifyVisitor();
|
||||
|
||||
public static ScalarOperator simplify(ScalarOperator op) {
|
||||
return op.accept(SIMPLIFY_VISITOR, null).orElse(op);
|
||||
public static ScalarOperator simplify(ScalarOperator op, boolean skipComplexFunctions) {
|
||||
return op.accept(SIMPLIFY_VISITOR, new Context(skipComplexFunctions)).orElse(op);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,13 +18,17 @@ import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
|
|||
import com.starrocks.sql.optimizer.rewrite.ScalarOperatorRewriteContext;
|
||||
|
||||
public class SimplifiedCaseWhenRule extends BottomUpScalarOperatorRewriteRule {
|
||||
private SimplifiedCaseWhenRule() {
|
||||
private final boolean skipComplexFunctions;
|
||||
|
||||
private SimplifiedCaseWhenRule(boolean skipComplexFunctions) {
|
||||
this.skipComplexFunctions = skipComplexFunctions;
|
||||
}
|
||||
|
||||
public static final SimplifiedCaseWhenRule INSTANCE = new SimplifiedCaseWhenRule();
|
||||
public static final SimplifiedCaseWhenRule INSTANCE = new SimplifiedCaseWhenRule(false);
|
||||
public static final SimplifiedCaseWhenRule SKIP_COMPLEX_FUNCTIONS_INSTANCE = new SimplifiedCaseWhenRule(true);
|
||||
|
||||
@Override
|
||||
public ScalarOperator apply(ScalarOperator root, ScalarOperatorRewriteContext context) {
|
||||
return InvertedCaseWhen.simplify(root);
|
||||
return InvertedCaseWhen.simplify(root, skipComplexFunctions);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ public class PushDownPredicateScanRule extends TransformationRule {
|
|||
ScalarOperatorRewriter scalarOperatorRewriter = new ScalarOperatorRewriter();
|
||||
ScalarOperator predicates = Utils.compoundAnd(lfo.getPredicate(), logicalScanOperator.getPredicate());
|
||||
|
||||
predicates = ScalarOperatorRewriter.simplifyCaseWhen(predicates);
|
||||
predicates = ScalarOperatorRewriter.simplifyCaseWhen(predicates, true);
|
||||
|
||||
ScalarRangePredicateExtractor rangeExtractor = new ScalarRangePredicateExtractor();
|
||||
predicates = rangeExtractor.rewriteOnlyColumn(Utils.compoundAnd(Utils.extractConjuncts(predicates)
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ import com.starrocks.sql.optimizer.operator.Operator;
|
|||
import com.starrocks.sql.optimizer.operator.OperatorBuilderFactory;
|
||||
import com.starrocks.sql.optimizer.operator.logical.LogicalJoinOperator;
|
||||
import com.starrocks.sql.optimizer.operator.logical.LogicalRepeatOperator;
|
||||
import com.starrocks.sql.optimizer.operator.logical.LogicalScanOperator;
|
||||
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
|
||||
import com.starrocks.sql.optimizer.rewrite.ScalarOperatorRewriter;
|
||||
import com.starrocks.sql.optimizer.task.TaskContext;
|
||||
|
|
@ -70,7 +71,8 @@ public class SimplifyCaseWhenPredicateRule implements TreeRewriteRule {
|
|||
if (predicate == null) {
|
||||
return Optional.empty();
|
||||
}
|
||||
ScalarOperator newPredicate = ScalarOperatorRewriter.simplifyCaseWhen(predicate);
|
||||
boolean isScan = optExpression.getOp() instanceof LogicalScanOperator;
|
||||
ScalarOperator newPredicate = ScalarOperatorRewriter.simplifyCaseWhen(predicate, isScan);
|
||||
if (newPredicate == predicate) {
|
||||
return Optional.empty();
|
||||
}
|
||||
|
|
@ -88,12 +90,12 @@ public class SimplifyCaseWhenPredicateRule implements TreeRewriteRule {
|
|||
}
|
||||
Optional<ScalarOperator> optNewOnPredicate =
|
||||
Optional.ofNullable(joinOperator.getOnPredicate()).map(predicate -> {
|
||||
ScalarOperator newPredicate = ScalarOperatorRewriter.simplifyCaseWhen(predicate);
|
||||
ScalarOperator newPredicate = ScalarOperatorRewriter.simplifyCaseWhen(predicate, false);
|
||||
return newPredicate == predicate ? null : newPredicate;
|
||||
});
|
||||
Optional<ScalarOperator> optNewPredicate =
|
||||
Optional.ofNullable(joinOperator.getPredicate()).map(predicate -> {
|
||||
ScalarOperator newPredicate = ScalarOperatorRewriter.simplifyCaseWhen(predicate);
|
||||
ScalarOperator newPredicate = ScalarOperatorRewriter.simplifyCaseWhen(predicate, false);
|
||||
return newPredicate == predicate ? null : newPredicate;
|
||||
});
|
||||
Operator newOperator = LogicalJoinOperator.builder().withOperator(joinOperator)
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ import org.junit.jupiter.api.Test;
|
|||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.wildfly.common.Assert;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
|
@ -33,7 +34,6 @@ import java.util.stream.Stream;
|
|||
class SelectStmtWithCaseWhenTest {
|
||||
private static StarRocksAssert starRocksAssert;
|
||||
|
||||
|
||||
@BeforeAll
|
||||
public static void setUp()
|
||||
throws Exception {
|
||||
|
|
@ -43,7 +43,7 @@ class SelectStmtWithCaseWhenTest {
|
|||
" `order_date` date NOT NULL COMMENT \"\",\n" +
|
||||
" `income` decimal(7, 0) NOT NULL COMMENT \"\",\n" +
|
||||
" `ship_mode` int NOT NULL COMMENT \"\",\n" +
|
||||
" `ship_code` int" +
|
||||
" `ship_code` int\n" +
|
||||
") ENGINE=OLAP \n" +
|
||||
"DUPLICATE KEY(`region`, `order_date`)\n" +
|
||||
"COMMENT \"OLAP\"\n" +
|
||||
|
|
@ -59,10 +59,18 @@ class SelectStmtWithCaseWhenTest {
|
|||
"\"replicated_storage\" = \"false\",\n" +
|
||||
"\"compression\" = \"LZ4\"\n" +
|
||||
");";
|
||||
String createTbl2StmtStr = " CREATE TABLE `t1` (\n" +
|
||||
" `id` varchar(128) NOT NULL COMMENT \"\",\n" +
|
||||
" `col_arr` array<varchar(100)> " +
|
||||
") ENGINE=OLAP \n" +
|
||||
"PROPERTIES (\n" +
|
||||
"\"replication_num\" = \"1\"\n" +
|
||||
");";
|
||||
|
||||
starRocksAssert = new StarRocksAssert();
|
||||
starRocksAssert.withDatabase("test").useDatabase("test");
|
||||
starRocksAssert.withTable(createTblStmtStr);
|
||||
starRocksAssert.withTable(createTbl2StmtStr);
|
||||
FeConstants.enablePruneEmptyOutputScan = false;
|
||||
FeConstants.setLengthForVarchar = false;
|
||||
}
|
||||
|
|
@ -527,7 +535,8 @@ class SelectStmtWithCaseWhenTest {
|
|||
{"<> 'A'", "[4: ship_mode, INT, false] < 90"},
|
||||
{"<> 'B'", "(4: ship_mode < 80) OR (4: ship_mode >= 90)"},
|
||||
{"<> 'C'", "(4: ship_mode < 70) OR ((4: ship_mode >= 90) OR (4: ship_mode >= 80))"},
|
||||
{"<> 'D'", "(4: ship_mode < 60) OR (((4: ship_mode >= 90) OR (4: ship_mode >= 80)) OR (4: ship_mode >= 70))"},
|
||||
{"<> 'D'",
|
||||
"(4: ship_mode < 60) OR (((4: ship_mode >= 90) OR (4: ship_mode >= 80)) OR (4: ship_mode >= 70))"},
|
||||
{"<> 'E'", "[4: ship_mode, INT, false] >= 60"},
|
||||
|
||||
{"in ('A','B')",
|
||||
|
|
@ -580,11 +589,9 @@ class SelectStmtWithCaseWhenTest {
|
|||
argumentsList.add(Arguments.of(sql, Arrays.asList(tc).subList(1, tc.length)));
|
||||
}
|
||||
|
||||
|
||||
return argumentsList.stream();
|
||||
}
|
||||
|
||||
|
||||
private static Stream<Arguments> caseWhenWithNullableCol() {
|
||||
String sqlTemp = "select * from test.t0 where (case \n" +
|
||||
" when ship_code >= 90 then 'A'\n" +
|
||||
|
|
@ -648,7 +655,6 @@ class SelectStmtWithCaseWhenTest {
|
|||
argumentsList.add(Arguments.of(sql, Arrays.asList(tc).subList(1, tc.length)));
|
||||
}
|
||||
|
||||
|
||||
return argumentsList.stream();
|
||||
}
|
||||
|
||||
|
|
@ -680,7 +686,8 @@ class SelectStmtWithCaseWhenTest {
|
|||
"if(1: region = 'USA', 1, 0) IN (2, 3, NULL)"},
|
||||
{"select * from test.t0 where (if(region='USA', 1, 0) in (2,3, null)) is null",
|
||||
"if(1: region = 'USA', 1, 0) IN (2, 3, NULL) IS NULL"},
|
||||
{"select * from test.t0 where if(region='USA', 1, 0) not in (0)", "[1: region, VARCHAR, false] = 'USA'"},
|
||||
{"select * from test.t0 where if(region='USA', 1, 0) not in (0)",
|
||||
"[1: region, VARCHAR, false] = 'USA'"},
|
||||
|
||||
{"select * from test.t0 where if(region='USA', 1, 0) not in (0,1)", "0:EMPTYSET"},
|
||||
{"select * from test.t0 where if(region='USA', 1, 0) not in (2,3)", " 0:OlapScanNode\n" +
|
||||
|
|
@ -704,13 +711,14 @@ class SelectStmtWithCaseWhenTest {
|
|||
|
||||
{"select * from test.t0 where if(ship_code is null or ship_code > 2, 2, 1) != 2",
|
||||
"if[((5: ship_code IS NULL) OR (5: ship_code > 2), 2, 1)"},
|
||||
{"select * from test.t0 where if(ship_code is null or ship_code > 2, 1, 0) is NOT NULL", " 0:OlapScanNode\n" +
|
||||
" table: t0, rollup: t0\n" +
|
||||
" preAggregation: on\n" +
|
||||
" partitionsRatio=0/3, tabletsRatio=0/0\n" +
|
||||
" tabletList=\n" +
|
||||
" actualRows=0, avgRowSize=5.0\n" +
|
||||
" cardinality: 1\n"},
|
||||
{"select * from test.t0 where if(ship_code is null or ship_code > 2, 1, 0) is NOT NULL",
|
||||
" 0:OlapScanNode\n" +
|
||||
" table: t0, rollup: t0\n" +
|
||||
" preAggregation: on\n" +
|
||||
" partitionsRatio=0/3, tabletsRatio=0/0\n" +
|
||||
" tabletList=\n" +
|
||||
" actualRows=0, avgRowSize=5.0\n" +
|
||||
" cardinality: 1\n"},
|
||||
{"with tmp as (select ship_mode, if(ship_code > 4, 1, 0) as layer0, " +
|
||||
"if (ship_code >= 1 and ship_code <= 4, 1, 0) as layer1," +
|
||||
"if(ship_code is null or ship_code < 1, 1, 0) as layer2 from t0) " +
|
||||
|
|
@ -719,39 +727,46 @@ class SelectStmtWithCaseWhenTest {
|
|||
"if[((5: ship_code >= 1) AND (5: ship_code <= 4), 1, 0)"
|
||||
},
|
||||
|
||||
{"select * from test.t0 where nullif('China', region) = 'China'", "[1: region, VARCHAR, false] != 'China'"},
|
||||
{"select * from test.t0 where nullif('China', region) = 'China'",
|
||||
"[1: region, VARCHAR, false] != 'China'"},
|
||||
{"select * from test.t0 where nullif('China', region) <> 'China'", "0:EMPTYSET"},
|
||||
{"select * from test.t0 where nullif('China', region) is NULL", "[1: region, VARCHAR, false] = 'China'"},
|
||||
{"select * from test.t0 where nullif('China', region) is NULL",
|
||||
"[1: region, VARCHAR, false] = 'China'"},
|
||||
{"select * from test.t0 where (nullif('China', region) is NULL) is NULL",
|
||||
"0:EMPTYSET"},
|
||||
{"select * from test.t0 where (nullif('China', region) is NULL) is NOT NULL",
|
||||
"0:OlapScanNode\n" +
|
||||
" table: t0, rollup: t0\n" +
|
||||
" preAggregation: on"},
|
||||
{"select * from test.t0 where nullif('China', region) is NOT NULL", "[1: region, VARCHAR, false] != 'China'"},
|
||||
{"select * from test.t0 where nullif('China', region) is NOT NULL",
|
||||
"[1: region, VARCHAR, false] != 'China'"},
|
||||
{"select * from test.t0 where (nullif('China', region) is NOT NULL) is NULL",
|
||||
"1: region != 'China' IS NULL"},
|
||||
{"select * from test.t0 where (nullif('China', region) is NOT NULL) is NOT NULL",
|
||||
"1: region != 'China' IS NOT NULL"},
|
||||
|
||||
{"select * from test.t0 where nullif('China', region) = 'USA'", "0:EMPTYSET"},
|
||||
{"select * from test.t0 where nullif('China', region) <> 'USA'", "[1: region, VARCHAR, false] != 'China'"},
|
||||
{"select * from test.t0 where nullif('China', region) <> 'USA'",
|
||||
"[1: region, VARCHAR, false] != 'China'"},
|
||||
|
||||
{"select * from test.t0 where nullif(1, ship_code) = 1", "(5: ship_code != 1) OR (5: ship_code IS NULL)"},
|
||||
{"select * from test.t0 where nullif(1, ship_code) = 1",
|
||||
"(5: ship_code != 1) OR (5: ship_code IS NULL)"},
|
||||
{"select * from test.t0 where nullif(1, ship_code) <> 1", "0:EMPTYSET"},
|
||||
{"select * from test.t0 where nullif(1, ship_code) is NULL", "[5: ship_code, INT, true] = 1"},
|
||||
{"select * from test.t0 where (nullif(1, ship_code) is NULL) is NULL", "0:EMPTYSET"},
|
||||
{"select * from test.t0 where (nullif(1, ship_code) is NULL) is NOT NULL",
|
||||
"0:OlapScanNode\n" +
|
||||
" table: t0, rollup: t0\n" +
|
||||
" preAggregation: on"},
|
||||
{"select * from test.t0 where nullif(1, ship_code) is NOT NULL", "(5: ship_code != 1) OR (5: ship_code IS NULL)"},
|
||||
" table: t0, rollup: t0\n" +
|
||||
" preAggregation: on"},
|
||||
{"select * from test.t0 where nullif(1, ship_code) is NOT NULL",
|
||||
"(5: ship_code != 1) OR (5: ship_code IS NULL)"},
|
||||
{"select * from test.t0 where (nullif(1, ship_code) is NOT NULL) is NULL",
|
||||
"(5: ship_code != 1) OR (5: ship_code IS NULL)"},
|
||||
{"select * from test.t0 where (nullif(1, ship_code) is NOT NULL) is NOT NULL",
|
||||
"(5: ship_code != 1) OR (5: ship_code IS NULL)"},
|
||||
{"select * from test.t0 where nullif(1, ship_code) = 2", "0:EMPTYSET"},
|
||||
{"select * from test.t0 where nullif(1, ship_code) <> 2", "(5: ship_code != 1) OR (5: ship_code IS NULL)"},
|
||||
{"select * from test.t0 where nullif(1, ship_code) <> 2",
|
||||
"(5: ship_code != 1) OR (5: ship_code IS NULL)"},
|
||||
};
|
||||
|
||||
List<Arguments> argumentsList = Lists.newArrayList();
|
||||
|
|
@ -824,7 +839,6 @@ class SelectStmtWithCaseWhenTest {
|
|||
"0:EMPTYSET"
|
||||
}
|
||||
|
||||
|
||||
};
|
||||
List<Arguments> argumentsList = Lists.newArrayList();
|
||||
for (String[] tc : testCases) {
|
||||
|
|
@ -843,4 +857,26 @@ class SelectStmtWithCaseWhenTest {
|
|||
joiner.add(plan);
|
||||
Assertions.assertTrue(patterns.stream().anyMatch(plan::contains), joiner.toString());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNotSimplifyCaseWhenSkipComplexFunctions() throws Exception {
|
||||
String sql = "WITH cte01 AS (\n" +
|
||||
" SELECT\n" +
|
||||
" id,\n" +
|
||||
" (\n" +
|
||||
" CASE\n" +
|
||||
" WHEN (ARRAY_LENGTH(col_arr) < 2) THEN \"bucket1\"\n" +
|
||||
" WHEN ((ARRAY_LENGTH(col_arr) >= 2) AND (ARRAY_LENGTH(col_arr) < 4)) THEN \"bucket2\"\n" +
|
||||
" ELSE NULL\n" +
|
||||
" END\n" +
|
||||
" ) AS len_bucket\n" +
|
||||
" FROM\n" +
|
||||
" t1\n" +
|
||||
")\n" +
|
||||
"SELECT id, len_bucket\n" +
|
||||
"FROM cte01\n" +
|
||||
"WHERE len_bucket IS NOT NULL;";
|
||||
String plan = UtFrameUtils.getVerboseFragmentPlan(starRocksAssert.getCtx(), sql);
|
||||
Assert.assertTrue(plan.contains("Predicates: CASE WHEN array_length"));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue