[BugFix] fix non-deterministic predicate push down problem (#62827)

Signed-off-by: before-Sunrise <unclejyj@gmail.com>
This commit is contained in:
before-Sunrise 2025-09-22 09:49:23 +08:00 committed by GitHub
parent bfb83e7e9c
commit e0ce1d4b42
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 106 additions and 16 deletions

View File

@ -838,15 +838,11 @@ public class Utils {
}
public static boolean hasNonDeterministicFunc(ScalarOperator operator) {
for (ScalarOperator child : operator.getChildren()) {
if (child instanceof CallOperator) {
CallOperator call = (CallOperator) child;
String fnName = call.getFnName();
if (FunctionSet.nonDeterministicFunctions.contains(fnName)) {
return true;
}
}
if (hasNonDeterministicFuncImpl(operator)) {
return true;
}
for (ScalarOperator child : operator.getChildren()) {
if (hasNonDeterministicFunc(child)) {
return true;
}
@ -854,6 +850,17 @@ public class Utils {
return false;
}
private static boolean hasNonDeterministicFuncImpl(ScalarOperator operator) {
if (operator instanceof CallOperator) {
CallOperator call = (CallOperator) operator;
String fnName = call.getFnName();
if (FunctionSet.nonDeterministicFunctions.contains(fnName)) {
return true;
}
}
return false;
}
public static void calculateStatistics(OptExpression expr, OptimizerContext context) {
for (OptExpression child : expr.getInputs()) {
calculateStatistics(child, context);

View File

@ -18,12 +18,14 @@ import com.google.common.collect.Lists;
import com.starrocks.catalog.FunctionSet;
import com.starrocks.sql.optimizer.OptExpression;
import com.starrocks.sql.optimizer.OptimizerContext;
import com.starrocks.sql.optimizer.Utils;
import com.starrocks.sql.optimizer.operator.OperatorType;
import com.starrocks.sql.optimizer.operator.logical.LogicalFilterOperator;
import com.starrocks.sql.optimizer.operator.logical.LogicalProjectOperator;
import com.starrocks.sql.optimizer.operator.pattern.Pattern;
import com.starrocks.sql.optimizer.operator.scalar.CallOperator;
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
import com.starrocks.sql.optimizer.operator.scalar.CompoundPredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.LambdaFunctionOperator;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
import com.starrocks.sql.optimizer.rewrite.ReplaceColumnRefRewriter;
@ -35,9 +37,11 @@ import com.starrocks.sql.optimizer.rewrite.scalar.ScalarOperatorRewriteRule;
import com.starrocks.sql.optimizer.rewrite.scalar.SimplifiedPredicateRule;
import com.starrocks.sql.optimizer.rule.RuleType;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
public class PushDownPredicateProjectRule extends TransformationRule {
private static final List<ScalarOperatorRewriteRule> PROJECT_REWRITE_PREDICATE_RULE = Lists.newArrayList(
@ -101,8 +105,44 @@ public class PushDownPredicateProjectRule extends TransformationRule {
}
}
// Check if the filter's predicate contains non-deterministic functions
// If it does, don't push down the predicate below project to avoid incorrect results
List<ScalarOperator> compoundAndPredicates = Utils.extractConjuncts(filter.getPredicate());
Set<ScalarOperator> deterministicPredicates = new HashSet<>();
Set<ScalarOperator> nonDeterministicPredicates = new HashSet<>();
for (var entry : project.getColumnRefMap().entrySet()) {
if (Utils.hasNonDeterministicFunc(entry.getValue())) {
compoundAndPredicates.forEach(scalarOperator -> {
if (scalarOperator.getUsedColumns().contains(entry.getKey())) {
nonDeterministicPredicates.add(scalarOperator);
}
});
}
}
compoundAndPredicates.forEach(predicate -> {
if (!nonDeterministicPredicates.contains(predicate)) {
deterministicPredicates.add(predicate);
}
});
// if all non-deterministic predicate, do not push down!
if (deterministicPredicates.isEmpty()) {
return Lists.newArrayList();
}
ScalarOperator deterministicPredicateTree;
if (nonDeterministicPredicates.isEmpty()) {
deterministicPredicateTree = filter.getPredicate();
} else {
deterministicPredicateTree =
Utils.createCompound(CompoundPredicateOperator.CompoundType.AND, deterministicPredicates);
}
ScalarOperator nonDeterministicPredicateTree =
Utils.createCompound(CompoundPredicateOperator.CompoundType.AND, nonDeterministicPredicates);
ReplaceColumnRefRewriter rewriter = new ReplaceColumnRefRewriter(project.getColumnRefMap());
ScalarOperator newPredicate = rewriter.rewrite(filter.getPredicate());
ScalarOperator newPredicate = rewriter.rewrite(deterministicPredicateTree);
// try rewrite new predicate
// e.g. : select 1 as b, MIN(v1) from t0 having (b + 1) != b;
@ -113,6 +153,15 @@ public class PushDownPredicateProjectRule extends TransformationRule {
newFilter.getInputs().addAll(child.getInputs());
OptExpression newProject = OptExpression.create(project, newFilter);
return Lists.newArrayList(newProject);
OptExpression root;
if (!nonDeterministicPredicates.isEmpty()) {
OptExpression nonDeterministicFilter =
OptExpression.create(new LogicalFilterOperator(nonDeterministicPredicateTree), newProject);
root = nonDeterministicFilter;
} else {
root = newProject;
}
return Lists.newArrayList(root);
}
}

View File

@ -100,17 +100,25 @@ public class ScalarOperatorsReuseRuleTest extends PlanTestBase {
{
String query = "select * from (select rand() as rnd) t where t.rnd < 10 or t.rnd > 20";
String plan = getFragmentPlan(query);
assertContains(plan, " 1:SELECT\n" +
" | predicates: (3: rand < 10.0) OR (3: rand > 20.0)\n" +
" | common sub expr:\n" +
" | <slot 3> : rand()");
assertContains(plan, "2:SELECT\n" +
" | predicates: (2: rand < 10.0) OR (2: rand > 20.0)\n" +
" | \n" +
" 1:Project\n" +
" | <slot 2> : rand()\n" +
" | \n" +
" 0:UNION");
}
{
connectContext.getSessionVariable().setEnablePredicateExprReuse(false);
String query = "select * from (select rand() as rnd) t where t.rnd < 10 or t.rnd > 20";
String plan = getFragmentPlan(query);
assertContains(plan, " 1:SELECT\n" +
" | predicates: (rand() < 10.0) OR (rand() > 20.0)");
assertContains(plan, " 2:SELECT\n" +
" | predicates: (2: rand < 10.0) OR (2: rand > 20.0)\n" +
" | \n" +
" 1:Project\n" +
" | <slot 2> : rand()\n" +
" | \n" +
" 0:UNION");
connectContext.getSessionVariable().setEnablePredicateExprReuse(true);
}
}

View File

@ -141,4 +141,30 @@ public class PredicatePushDownTest extends PlanTestBase {
assertContains(plan, "5: v5 < 3");
}
}
@Test
public void testNonDeterministicFunctionInCTE() throws Exception {
// Test that predicates with non-deterministic functions are not pushed down through Project
// This prevents rand() from being calculated twice in CTE scenarios
String sql = "WITH input AS (SELECT v1, rand() AS rn FROM t0) " +
"SELECT v1, rn, rn < 0.5 FROM input WHERE rn < 0.5";
String plan = getFragmentPlan(sql);
// The plan should not push down the predicate through the project that contains rand()
// Instead, the filter should remain above the project
// The filter should not be pushed down to the scan level
assertContains(plan, "3:Project\n" +
" | <slot 1> : 1: v1\n" +
" | <slot 4> : 4: rand\n" +
" | <slot 5> : 4: rand < 0.5\n" +
" | \n" +
" 2:SELECT\n" +
" | predicates: 4: rand < 0.5\n" +
" | \n" +
" 1:Project\n" +
" | <slot 1> : 1: v1\n" +
" | <slot 4> : rand()\n" +
" | \n" +
" 0:OlapScanNode");
}
}