[BugFix] Disable simplifying case-when with complex functions to avoid time-consuming and tedious predicates (#63732)

Signed-off-by: satanson <ranpanf@gmail.com>
This commit is contained in:
satanson 2025-09-30 18:13:37 +08:00 committed by GitHub
parent e5bf03cda5
commit 1a92068146
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 57 additions and 12 deletions

View File

@ -67,10 +67,11 @@ public class ScalarOperatorRewriter {
PruneTediousPredicateRule.INSTANCE
);
private static final List<ScalarOperatorRewriteRule> CASE_WHEN_PREDICATE_ON_SCAN_RULE = Lists.newArrayList(
SimplifiedCaseWhenRule.SKIP_COMPLEX_FUNCTIONS_INSTANCE,
PruneTediousPredicateRule.INSTANCE
);
// Rule list handed to the rewriter by simplifyCaseWhen(...) when its skipComplexFunctions
// flag is true: it pairs the SKIP_COMPLEX_FUNCTIONS_INSTANCE variant of the case-when
// simplifier with the tedious-predicate pruning rule.
private static final List<ScalarOperatorRewriteRule> CASE_WHEN_PREDICATE_SKIP_COMPLEX_FUNCTIONS =
Lists.newArrayList(
SimplifiedCaseWhenRule.SKIP_COMPLEX_FUNCTIONS_INSTANCE,
PruneTediousPredicateRule.INSTANCE
);
public static final List<ScalarOperatorRewriteRule> DEFAULT_REWRITE_SCAN_PREDICATE_RULES = Lists.newArrayList(
// required
@ -165,9 +166,9 @@ public class ScalarOperatorRewriter {
return op;
}
public static ScalarOperator simplifyCaseWhen(ScalarOperator predicates, boolean isOnScan) {
if (isOnScan) {
return new ScalarOperatorRewriter().rewrite(predicates, CASE_WHEN_PREDICATE_ON_SCAN_RULE);
public static ScalarOperator simplifyCaseWhen(ScalarOperator predicates, boolean skipComplexFunctions) {
if (skipComplexFunctions) {
return new ScalarOperatorRewriter().rewrite(predicates, CASE_WHEN_PREDICATE_SKIP_COMPLEX_FUNCTIONS);
} else {
return new ScalarOperatorRewriter().rewrite(predicates, CASE_WHEN_PREDICATE_RULE);
}

View File

@ -20,7 +20,6 @@ import com.starrocks.sql.optimizer.operator.Operator;
import com.starrocks.sql.optimizer.operator.OperatorBuilderFactory;
import com.starrocks.sql.optimizer.operator.logical.LogicalJoinOperator;
import com.starrocks.sql.optimizer.operator.logical.LogicalRepeatOperator;
import com.starrocks.sql.optimizer.operator.logical.LogicalScanOperator;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
import com.starrocks.sql.optimizer.rewrite.ScalarOperatorRewriter;
import com.starrocks.sql.optimizer.task.TaskContext;
@ -71,8 +70,7 @@ public class SimplifyCaseWhenPredicateRule implements TreeRewriteRule {
if (predicate == null) {
return Optional.empty();
}
boolean isScan = optExpression.getOp() instanceof LogicalScanOperator;
ScalarOperator newPredicate = ScalarOperatorRewriter.simplifyCaseWhen(predicate, isScan);
ScalarOperator newPredicate = ScalarOperatorRewriter.simplifyCaseWhen(predicate, true);
if (newPredicate == predicate) {
return Optional.empty();
}
@ -90,12 +88,12 @@ public class SimplifyCaseWhenPredicateRule implements TreeRewriteRule {
}
Optional<ScalarOperator> optNewOnPredicate =
Optional.ofNullable(joinOperator.getOnPredicate()).map(predicate -> {
ScalarOperator newPredicate = ScalarOperatorRewriter.simplifyCaseWhen(predicate, false);
ScalarOperator newPredicate = ScalarOperatorRewriter.simplifyCaseWhen(predicate, true);
return newPredicate == predicate ? null : newPredicate;
});
Optional<ScalarOperator> optNewPredicate =
Optional.ofNullable(joinOperator.getPredicate()).map(predicate -> {
ScalarOperator newPredicate = ScalarOperatorRewriter.simplifyCaseWhen(predicate, false);
ScalarOperator newPredicate = ScalarOperatorRewriter.simplifyCaseWhen(predicate, true);
return newPredicate == predicate ? null : newPredicate;
});
Operator newOperator = LogicalJoinOperator.builder().withOperator(joinOperator)

View File

@ -650,4 +650,50 @@ class SelectStmtWithCaseWhenTest {
" | 5 <-> array_length[([2: col_arr, ARRAY<VARCHAR(100)>, true]); " +
"args: INVALID_TYPE; result: INT; args nullable: true; result nullable: true]"));
}
/**
 * Checks that a CASE WHEN whose branches are built on a complex function call
 * (array_length over array_concat) is left un-simplified when it ends up as an
 * "other join predicate" of a hash join: the verbose plan must still show the
 * original {@code CASE WHEN ... END IS NOT NULL} expression.
 */
@Test
public void testNotSimplifyCaseWhenSkipComplexFunctionsOnHashJoin() throws Exception {
    // One SQL line per element; joined with '\n' so the statement text is unchanged.
    final String query = String.join("\n",
            "with cte1 AS (",
            "select id, col_arr",
            "from t1",
            "),",
            "cte2 AS (",
            "select ta.id as id, array_concat(ta.col_arr, tb.col_arr) as col_arr",
            "from cte1 ta inner join cte1 tb on ta.id = tb.id+1",
            "),",
            "cte3 AS (",
            " SELECT",
            " id,",
            " (",
            " CASE",
            " WHEN (ARRAY_LENGTH(col_arr) < 2) THEN \"bucket1\"",
            " WHEN ((ARRAY_LENGTH(col_arr) >= 2) AND (ARRAY_LENGTH(col_arr) < 4)) THEN \"bucket2\"",
            " ELSE NULL",
            " END",
            " ) AS len_bucket",
            " FROM",
            " cte2",
            ")",
            "SELECT id, len_bucket",
            "FROM cte3",
            "WHERE len_bucket IS NOT NULL;");

    // Verbose rendering of array_length(array_concat(ta.col_arr, tb.col_arr)); the
    // expected predicate repeats this expression once per comparison.
    final String concatArgs = "[4: col_arr, ARRAY<VARCHAR(100)>, true], "
            + "[6: col_arr, ARRAY<VARCHAR(100)>, true]";
    final String arrayConcat = "array_concat[(" + concatArgs + "); args: INVALID_TYPE; "
            + "result: ARRAY<VARCHAR>; args nullable: true; result nullable: true]";
    final String arrayLength = "array_length[(" + arrayConcat + "); args: INVALID_TYPE; "
            + "result: INT; args nullable: true; result nullable: true]";

    final String expectedJoinFragment = " 6:HASH JOIN\n"
            + " | join op: INNER JOIN (PARTITIONED)\n"
            + " | equal join conjunct: [9: cast, DOUBLE, true] = [10: add, DOUBLE, true]\n"
            + " | other join predicates: CASE WHEN " + arrayLength + " < 2 THEN 'bucket1' "
            + "WHEN (" + arrayLength + " >= 2) AND (" + arrayLength + " < 4) THEN 'bucket2' "
            + "ELSE NULL END IS NOT NULL\n";

    final String verbosePlan = UtFrameUtils.getVerboseFragmentPlan(starRocksAssert.getCtx(), query);
    final boolean caseWhenKept = verbosePlan.contains(expectedJoinFragment);
    Assert.assertTrue(caseWhenKept);
}
}