[BugFix] fix non-deterministic predicate push down problem (#62827)
Signed-off-by: before-Sunrise <unclejyj@gmail.com>
This commit is contained in:
parent
bfb83e7e9c
commit
e0ce1d4b42
|
|
@ -838,15 +838,11 @@ public class Utils {
|
|||
}
|
||||
|
||||
public static boolean hasNonDeterministicFunc(ScalarOperator operator) {
|
||||
for (ScalarOperator child : operator.getChildren()) {
|
||||
if (child instanceof CallOperator) {
|
||||
CallOperator call = (CallOperator) child;
|
||||
String fnName = call.getFnName();
|
||||
if (FunctionSet.nonDeterministicFunctions.contains(fnName)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
if (hasNonDeterministicFuncImpl(operator)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
for (ScalarOperator child : operator.getChildren()) {
|
||||
if (hasNonDeterministicFunc(child)) {
|
||||
return true;
|
||||
}
|
||||
|
|
@ -854,6 +850,17 @@ public class Utils {
|
|||
return false;
|
||||
}
|
||||
|
||||
private static boolean hasNonDeterministicFuncImpl(ScalarOperator operator) {
|
||||
if (operator instanceof CallOperator) {
|
||||
CallOperator call = (CallOperator) operator;
|
||||
String fnName = call.getFnName();
|
||||
if (FunctionSet.nonDeterministicFunctions.contains(fnName)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public static void calculateStatistics(OptExpression expr, OptimizerContext context) {
|
||||
for (OptExpression child : expr.getInputs()) {
|
||||
calculateStatistics(child, context);
|
||||
|
|
|
|||
|
|
@ -18,12 +18,14 @@ import com.google.common.collect.Lists;
|
|||
import com.starrocks.catalog.FunctionSet;
|
||||
import com.starrocks.sql.optimizer.OptExpression;
|
||||
import com.starrocks.sql.optimizer.OptimizerContext;
|
||||
import com.starrocks.sql.optimizer.Utils;
|
||||
import com.starrocks.sql.optimizer.operator.OperatorType;
|
||||
import com.starrocks.sql.optimizer.operator.logical.LogicalFilterOperator;
|
||||
import com.starrocks.sql.optimizer.operator.logical.LogicalProjectOperator;
|
||||
import com.starrocks.sql.optimizer.operator.pattern.Pattern;
|
||||
import com.starrocks.sql.optimizer.operator.scalar.CallOperator;
|
||||
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
|
||||
import com.starrocks.sql.optimizer.operator.scalar.CompoundPredicateOperator;
|
||||
import com.starrocks.sql.optimizer.operator.scalar.LambdaFunctionOperator;
|
||||
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
|
||||
import com.starrocks.sql.optimizer.rewrite.ReplaceColumnRefRewriter;
|
||||
|
|
@ -35,9 +37,11 @@ import com.starrocks.sql.optimizer.rewrite.scalar.ScalarOperatorRewriteRule;
|
|||
import com.starrocks.sql.optimizer.rewrite.scalar.SimplifiedPredicateRule;
|
||||
import com.starrocks.sql.optimizer.rule.RuleType;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
|
||||
public class PushDownPredicateProjectRule extends TransformationRule {
|
||||
private static final List<ScalarOperatorRewriteRule> PROJECT_REWRITE_PREDICATE_RULE = Lists.newArrayList(
|
||||
|
|
@ -101,8 +105,44 @@ public class PushDownPredicateProjectRule extends TransformationRule {
|
|||
}
|
||||
}
|
||||
|
||||
// Check if the filter's predicate contains non-deterministic functions
|
||||
// If it does, don't push down the predicate below project to avoid incorrect results
|
||||
List<ScalarOperator> compoundAndPredicates = Utils.extractConjuncts(filter.getPredicate());
|
||||
Set<ScalarOperator> deterministicPredicates = new HashSet<>();
|
||||
Set<ScalarOperator> nonDeterministicPredicates = new HashSet<>();
|
||||
for (var entry : project.getColumnRefMap().entrySet()) {
|
||||
if (Utils.hasNonDeterministicFunc(entry.getValue())) {
|
||||
compoundAndPredicates.forEach(scalarOperator -> {
|
||||
if (scalarOperator.getUsedColumns().contains(entry.getKey())) {
|
||||
nonDeterministicPredicates.add(scalarOperator);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
compoundAndPredicates.forEach(predicate -> {
|
||||
if (!nonDeterministicPredicates.contains(predicate)) {
|
||||
deterministicPredicates.add(predicate);
|
||||
}
|
||||
});
|
||||
|
||||
// if all non-deterministic predicate, do not push down!
|
||||
if (deterministicPredicates.isEmpty()) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
|
||||
ScalarOperator deterministicPredicateTree;
|
||||
if (nonDeterministicPredicates.isEmpty()) {
|
||||
deterministicPredicateTree = filter.getPredicate();
|
||||
} else {
|
||||
deterministicPredicateTree =
|
||||
Utils.createCompound(CompoundPredicateOperator.CompoundType.AND, deterministicPredicates);
|
||||
}
|
||||
|
||||
ScalarOperator nonDeterministicPredicateTree =
|
||||
Utils.createCompound(CompoundPredicateOperator.CompoundType.AND, nonDeterministicPredicates);
|
||||
|
||||
ReplaceColumnRefRewriter rewriter = new ReplaceColumnRefRewriter(project.getColumnRefMap());
|
||||
ScalarOperator newPredicate = rewriter.rewrite(filter.getPredicate());
|
||||
ScalarOperator newPredicate = rewriter.rewrite(deterministicPredicateTree);
|
||||
|
||||
// try rewrite new predicate
|
||||
// e.g. : select 1 as b, MIN(v1) from t0 having (b + 1) != b;
|
||||
|
|
@ -113,6 +153,15 @@ public class PushDownPredicateProjectRule extends TransformationRule {
|
|||
newFilter.getInputs().addAll(child.getInputs());
|
||||
|
||||
OptExpression newProject = OptExpression.create(project, newFilter);
|
||||
return Lists.newArrayList(newProject);
|
||||
OptExpression root;
|
||||
if (!nonDeterministicPredicates.isEmpty()) {
|
||||
OptExpression nonDeterministicFilter =
|
||||
OptExpression.create(new LogicalFilterOperator(nonDeterministicPredicateTree), newProject);
|
||||
root = nonDeterministicFilter;
|
||||
} else {
|
||||
root = newProject;
|
||||
}
|
||||
|
||||
return Lists.newArrayList(root);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -100,17 +100,25 @@ public class ScalarOperatorsReuseRuleTest extends PlanTestBase {
|
|||
{
|
||||
String query = "select * from (select rand() as rnd) t where t.rnd < 10 or t.rnd > 20";
|
||||
String plan = getFragmentPlan(query);
|
||||
assertContains(plan, " 1:SELECT\n" +
|
||||
" | predicates: (3: rand < 10.0) OR (3: rand > 20.0)\n" +
|
||||
" | common sub expr:\n" +
|
||||
" | <slot 3> : rand()");
|
||||
assertContains(plan, "2:SELECT\n" +
|
||||
" | predicates: (2: rand < 10.0) OR (2: rand > 20.0)\n" +
|
||||
" | \n" +
|
||||
" 1:Project\n" +
|
||||
" | <slot 2> : rand()\n" +
|
||||
" | \n" +
|
||||
" 0:UNION");
|
||||
}
|
||||
{
|
||||
connectContext.getSessionVariable().setEnablePredicateExprReuse(false);
|
||||
String query = "select * from (select rand() as rnd) t where t.rnd < 10 or t.rnd > 20";
|
||||
String plan = getFragmentPlan(query);
|
||||
assertContains(plan, " 1:SELECT\n" +
|
||||
" | predicates: (rand() < 10.0) OR (rand() > 20.0)");
|
||||
assertContains(plan, " 2:SELECT\n" +
|
||||
" | predicates: (2: rand < 10.0) OR (2: rand > 20.0)\n" +
|
||||
" | \n" +
|
||||
" 1:Project\n" +
|
||||
" | <slot 2> : rand()\n" +
|
||||
" | \n" +
|
||||
" 0:UNION");
|
||||
connectContext.getSessionVariable().setEnablePredicateExprReuse(true);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -141,4 +141,30 @@ public class PredicatePushDownTest extends PlanTestBase {
|
|||
assertContains(plan, "5: v5 < 3");
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNonDeterministicFunctionInCTE() throws Exception {
|
||||
// Test that predicates with non-deterministic functions are not pushed down through Project
|
||||
// This prevents rand() from being calculated twice in CTE scenarios
|
||||
String sql = "WITH input AS (SELECT v1, rand() AS rn FROM t0) " +
|
||||
"SELECT v1, rn, rn < 0.5 FROM input WHERE rn < 0.5";
|
||||
String plan = getFragmentPlan(sql);
|
||||
|
||||
// The plan should not push down the predicate through the project that contains rand()
|
||||
// Instead, the filter should remain above the project
|
||||
// The filter should not be pushed down to the scan level
|
||||
assertContains(plan, "3:Project\n" +
|
||||
" | <slot 1> : 1: v1\n" +
|
||||
" | <slot 4> : 4: rand\n" +
|
||||
" | <slot 5> : 4: rand < 0.5\n" +
|
||||
" | \n" +
|
||||
" 2:SELECT\n" +
|
||||
" | predicates: 4: rand < 0.5\n" +
|
||||
" | \n" +
|
||||
" 1:Project\n" +
|
||||
" | <slot 1> : 1: v1\n" +
|
||||
" | <slot 4> : rand()\n" +
|
||||
" | \n" +
|
||||
" 0:OlapScanNode");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue