[BugFix] Fix UnionToValuesRule bug with unligned constant values (backport #59647) (#63894)

Signed-off-by: Seaven <seaven_7@qq.com>
Co-authored-by: shuming.li <ming.moriarty@gmail.com>
Co-authored-by: Seaven <seaven_7@qq.com>
This commit is contained in:
mergify[bot] 2025-10-11 14:18:40 +08:00 committed by GitHub
parent 36acd25e9f
commit 2aed931195
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 59 additions and 36 deletions

View File

@ -70,7 +70,8 @@ public class LogicalValuesOperator extends LogicalOperator {
@Override
public RowOutputInfo deriveRowOutputInfo(List<OptExpression> inputs) {
return new RowOutputInfo(columnRefSet.stream().collect(Collectors.toMap(Function.identity(), Function.identity())));
return new RowOutputInfo(columnRefSet.stream().distinct()
.collect(Collectors.toMap(Function.identity(), Function.identity())));
}
@Override

View File

@ -55,7 +55,8 @@ public class PhysicalValuesOperator extends PhysicalOperator {
@Override
public RowOutputInfo deriveRowOutputInfo(List<OptExpression> inputs) {
return new RowOutputInfo(columnRefSet.stream().collect(Collectors.toMap(Function.identity(), Function.identity())));
return new RowOutputInfo(columnRefSet.stream().distinct()
.collect(Collectors.toMap(Function.identity(), Function.identity())));
}
@Override

View File

@ -17,6 +17,7 @@ package com.starrocks.sql.optimizer.rule.transformation;
import com.google.common.collect.Lists;
import com.starrocks.sql.optimizer.OptExpression;
import com.starrocks.sql.optimizer.OptimizerContext;
import com.starrocks.sql.optimizer.base.ColumnRefFactory;
import com.starrocks.sql.optimizer.operator.OperatorType;
import com.starrocks.sql.optimizer.operator.logical.LogicalUnionOperator;
import com.starrocks.sql.optimizer.operator.logical.LogicalValuesOperator;
@ -70,7 +71,9 @@ public class UnionToValuesRule extends TransformationRule {
public boolean check(OptExpression input, OptimizerContext context) {
return input.getInputs().stream()
.filter(UnionToValuesRule::isMergeable)
.count() > 1;
.skip(1)
.findFirst()
.isPresent();
}
@Override
@ -80,25 +83,22 @@ public class UnionToValuesRule extends TransformationRule {
List<List<ScalarOperator>> newRows = new ArrayList<>();
List<OptExpression> otherChildren = new ArrayList<>();
List<List<ColumnRefOperator>> newChildOutputs = Lists.newArrayList();
int firstLogicalValuesIndex = -1;
for (int i = 0; i < input.getInputs().size(); i++) {
final int numChildren = input.getInputs().size();
for (int i = 0; i < numChildren; i++) {
OptExpression child = input.getInputs().get(i);
if (isMergeable(child)) {
if (firstLogicalValuesIndex == -1) {
firstLogicalValuesIndex = i;
}
LogicalValuesOperator valuesOp = (LogicalValuesOperator) child.getOp();
List<List<ScalarOperator>> rows = valuesOp.getRows();
if (isConstantUnion(valuesOp)) {
List<ScalarOperator> scalarOperators = unionOp.getChildOutputColumns().get(i).stream()
.map(valuesOp.getProjection().getColumnRefMap()::get).collect(Collectors.toList());
List<ScalarOperator> scalarOperators = unionOp.getChildOutputColumns().get(i)
.stream()
.map(valuesOp.getProjection().getColumnRefMap()::get)
.collect(Collectors.toList());
newRows.add(scalarOperators);
} else if (isConstantValues(valuesOp)) {
} else {
newRows.addAll(rows);
}
} else {
newChildOutputs.add(unionOp.getChildOutputColumns().get(i));
otherChildren.add(child);
@ -106,26 +106,42 @@ public class UnionToValuesRule extends TransformationRule {
}
if (otherChildren.isEmpty()) {
LogicalValuesOperator newValuesOperator =
new LogicalValuesOperator.Builder().setColumnRefSet(unionOp.getOutputColumnRefOp()).setRows(newRows)
.setLimit(unionOp.getLimit())
.setPredicate(unionOp.getPredicate()).setProjection(unionOp.getProjection()).build();
LogicalValuesOperator newValuesOperator = new LogicalValuesOperator.Builder()
.setColumnRefSet(unionOp.getOutputColumnRefOp())
.setRows(newRows)
.setLimit(unionOp.getLimit())
.setPredicate(unionOp.getPredicate())
.setProjection(unionOp.getProjection())
.build();
return List.of(OptExpression.create(newValuesOperator));
} else {
List<OptExpression> inputs = new ArrayList<>(otherChildren);
if (!newRows.isEmpty()) {
LogicalValuesOperator newValuesOperator =
new LogicalValuesOperator.Builder()
.setColumnRefSet(unionOp.getChildOutputColumns().get(firstLogicalValuesIndex))
.setRows(newRows).setPredicate(null).build();
// use new ColumnRefOperator for the new child output columns to avoid conflicts
// eg:
// SELECT 'test1' AS c1, 'test1' AS c2, 'test1' AS c3
// UNION ALL
// SELECT 'test1' AS c1, 'test2' AS c2, 'test3' AS c3
// 1th child's original output only contain one element because of the same name 'test1',
// use new ColumnRefOperator to avoid the conflict.
final ColumnRefFactory columnRefFactory = context.getColumnRefFactory();
final List<ColumnRefOperator> newColRefs = unionOp.getChildOutputColumns().get(0)
.stream()
.map(c -> columnRefFactory.create(c, c.getType(), c.isNullable()))
.collect(Collectors.toUnmodifiableList());
final LogicalValuesOperator newValuesOperator = new LogicalValuesOperator.Builder()
.setColumnRefSet(newColRefs)
.setRows(newRows)
.setPredicate(null)
.build();
inputs.add(OptExpression.create(newValuesOperator));
newChildOutputs.add(newValuesOperator.getColumnRefSet());
}
LogicalUnionOperator newUnionOp =
new LogicalUnionOperator.Builder().withOperator(unionOp)
.setChildOutputColumns(newChildOutputs)
.build();
LogicalUnionOperator newUnionOp = new LogicalUnionOperator.Builder()
.withOperator(unionOp)
.setChildOutputColumns(newChildOutputs)
.build();
OptExpression newUnionExpr = OptExpression.create(newUnionOp, inputs);
return List.of(newUnionExpr);
@ -147,12 +163,12 @@ public class UnionToValuesRule extends TransformationRule {
private static boolean isConstantUnion(LogicalValuesOperator valuesOp) {
if (valuesOp.getProjection() == null ||
!valuesOp.getProjection().getColumnRefMap().values().stream().allMatch(ScalarOperator::isConstant)) {
valuesOp.getProjection().getColumnRefMap().values().stream().anyMatch(expr -> !expr.isConstant())) {
return false;
}
List<List<ScalarOperator>> rows = valuesOp.getRows();
if (!(rows.size() == 1 && rows.get(0).size() == 1)) {
if (rows.size() != 1 || rows.get(0).size() != 1) {
return false;
}

View File

@ -1912,6 +1912,7 @@ public class PlanFragmentBuilder {
}
tupleDescriptor.computeMemLayout();
final int dstSlotCount = tupleDescriptor.getSlots().size();
if (valuesOperator.getRows().isEmpty()) {
EmptySetNode emptyNode = new EmptySetNode(context.getNextNodeId(),
Lists.newArrayList(tupleDescriptor.getId()));
@ -1928,6 +1929,12 @@ public class PlanFragmentBuilder {
List<List<Expr>> consts = new ArrayList<>();
for (List<ScalarOperator> row : valuesOperator.getRows()) {
if (row.size() != dstSlotCount) {
throw new StarRocksPlannerException(
String.format("The number of columns in each row of values %s must be equal to the number of " +
"slots %s", row.size(), dstSlotCount),
INTERNAL_ERROR);
}
List<Expr> exprRow = new ArrayList<>();
for (ScalarOperator field : row) {
exprRow.add(ScalarOperatorToExpr.buildExecExpression(

View File

@ -637,7 +637,7 @@ public class SetTest extends PlanTestBase {
" | [11, DOUBLE, true]\n" +
" | child exprs:\n" +
" | [10: cast, DOUBLE, true]\n" +
" | [3: cast, DOUBLE, true]\n");
" | [12: cast, DOUBLE, true]");
sql = "(select 1 limit 1) UNION ALL select 2;";
plan = getVerboseExplain(sql);
@ -653,25 +653,23 @@ public class SetTest extends PlanTestBase {
" all select 2 union all select * from (values (3)) t";
plan = getVerboseExplain(sql);
assertContains(plan, " 0:UNION\n" +
" | output exprs:\n" +
assertContains(plan, "| output exprs:\n" +
" | [13, VARCHAR(32), true]\n" +
" | child exprs:\n" +
" | [1: k1, VARCHAR, true]\n" +
" | [12: cast, VARCHAR(32), false]\n" +
" | [7: cast, VARCHAR(32), true]\n");
" | [14: k1, VARCHAR(32), true]\n");
sql = "select k1 from db1.tbl6 union all select 1 union" +
" all select 2 union all select * from (values (3)) t";
plan = getVerboseExplain(sql);
assertContains(plan, " 0:UNION\n" +
" | output exprs:\n" +
assertContains(plan, " | output exprs:\n" +
" | [13, VARCHAR(32), true]\n" +
" | child exprs:\n" +
" | [1: k1, VARCHAR, true]\n" +
" | [12: cast, VARCHAR(32), false]\n" +
" | [7: cast, VARCHAR(32), true]\n");
" | [14: k1, VARCHAR(32), true]");
sql = "select 1 union all select 2 union all select * from (values (1)) t;";
plan = getVerboseExplain(sql);

View File

@ -86,4 +86,4 @@ select cast(100 as time);
select cast(1.123 as time);
-- result:
0:00:01
-- !result
-- !result

View File

@ -32,4 +32,4 @@ select cast(10000000 as float), cast(1000000000000000 as double);
select cast(0.00001 as float), cast(0.00001 as double);
select cast(100 as time);
select cast(1.123 as time);
select cast(1.123 as time);