[BugFix] Fix invalid ProjectOperator above table-pruning frontier CTEConsumerOperator (backport #62914) (#62936)

This commit is contained in:
mergify[bot] 2025-09-10 19:52:19 +08:00 committed by GitHub
parent 9fba692e37
commit 35ae03962e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 59 additions and 1 deletions

View File

@ -32,6 +32,7 @@ import com.starrocks.sql.optimizer.base.ColumnRefFactory;
import com.starrocks.sql.optimizer.base.ColumnRefSet;
import com.starrocks.sql.optimizer.operator.Operator;
import com.starrocks.sql.optimizer.operator.OperatorBuilderFactory;
import com.starrocks.sql.optimizer.operator.logical.LogicalCTEConsumeOperator;
import com.starrocks.sql.optimizer.operator.logical.LogicalJoinOperator;
import com.starrocks.sql.optimizer.operator.logical.LogicalProjectOperator;
import com.starrocks.sql.optimizer.operator.logical.LogicalScanOperator;
@ -50,6 +51,7 @@ import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@ -1060,7 +1062,18 @@ public class CPJoinGardener extends OptExpressionVisitor<Boolean, Void> {
frontier = grafter.graft(frontier).orElse(frontier);
// step4: (Top-down) remove the unused ColumnRefOperators.
colRefMap.replaceAll((k, v) -> pruner.columnRefRewriter.rewrite(v));
// If the frontier is a LogicalCTEConsumeOperator, getRowOutputInfo().getColumnRefMap() returns
// LogicalCTEConsumeOperator.cteOutputColumnRefMap. The input ColumnRefOperators in this map are
// the output ColumnRefOperators of the LogicalCTEProducerOperator, so those ColumnRefOperators
// are dangling here. Therefore, when we create a LogicalProjectOperator above the
// LogicalCTEConsumeOperator, we just need to map each output ColumnRefOperator to itself.
if (frontier.getOp() instanceof LogicalCTEConsumeOperator) {
colRefMap = colRefMap.keySet().stream()
.collect(Collectors.toMap(Function.identity(), Function.identity()));
} else {
colRefMap.replaceAll((k, v) -> pruner.columnRefRewriter.rewrite(v));
}
LogicalProjectOperator projectOperator = new LogicalProjectOperator(colRefMap);
frontier = OptExpression.create(projectOperator, frontier);
Cleaner cleaner = new Cleaner(columnRefFactory);

View File

@ -25,6 +25,8 @@ import org.junit.jupiter.api.Test;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static com.starrocks.sql.optimizer.statistics.CachedStatisticStorageTest.DEFAULT_CREATE_TABLE_TEMPLATE;
@ -402,4 +404,47 @@ public class TablePruningCTETest extends TablePruningTestBase {
" a.l_orderkey = d.l_orderkey and a.l_partkey = d.l_partkey and a.l_suppkey = d.l_suppkey;";
checkHashJoinCountWithBothRBOAndCBO(sql, 3);
}
@Test
public void testCteConsumerAsPruneFrontier() throws Exception {
    // Build a wide "flat" view over the TPC-H PK tables, then query it through several
    // CTEs so that a LogicalCTEConsumeOperator becomes the table-pruning frontier.
    String subquery = getSqlList("sql/tpch_pk_tables/", "lineitem_flat_subquery").get(0);
    String viewSql = "CREATE VIEW lineitem_flat_view AS " + subquery;
    starRocksAssert.withView(viewSql);
    // NOTE(review): the original also fetched "lineitem_flat_cte" and passed it to
    // String.format(), but the format string contained no placeholder, so the argument
    // was dead; the unused fetch and the pointless String.format wrapper are removed.
    String sql = "with lineitem_flat as (select * from lineitem_flat_view),\n" +
            "cteA as (\n" +
            "select l_orderkey,l_partkey,l_suppkey,sum(l_quantity) as sum_qty\n" +
            "from lineitem_flat \n" +
            "group by l_orderkey,l_partkey,l_suppkey\n" +
            "),\n" +
            "cteB as(\n" +
            "select l_orderkey,l_partkey,avg(l_quantity) as avg_qty\n" +
            "from lineitem_flat \n" +
            "group by l_orderkey,l_partkey\n" +
            "),\n" +
            "cteC as(\n" +
            "select \n" +
            "  l_orderkey,\n" +
            "  count(distinct l_partkey) as uniq_partkey, \n" +
            "  count(distinct l_suppkey) as uniq_suppkey, \n" +
            "  count(distinct l_quantity) as uniq_qty\n" +
            "from lineitem_flat \n" +
            "group by l_orderkey\n" +
            ")\n" +
            "\n" +
            "select /*+SET_VAR(enable_cbo_table_prune=true,enable_rbo_table_prune=true)*/\n" +
            "t.l_orderkey, t.l_partkey, t.l_suppkey, sum_qty, avg_qty, uniq_partkey, uniq_suppkey, uniq_qty\n" +
            "from lineitem_flat t\n" +
            "  join cteA a on t.l_orderkey = a.l_orderkey AND t.l_partkey = a.l_partkey " +
            "   AND t.l_suppkey = a.l_suppkey\n" +
            "  join cteB b on t.l_orderkey = b.l_orderkey AND t.l_partkey = b.l_partkey\n" +
            "  join cteC c on t.l_orderkey = c.l_orderkey\n";
    String plan = UtFrameUtils.getVerboseFragmentPlan(starRocksAssert.getCtx(), sql);
    // Collect every "table: ..." line of the verbose plan; after pruning, only the
    // lineitem tables should survive in the scan nodes.
    List<String> tables = Stream.of(plan.split("\n"))
            .filter(ln -> ln.matches("^\\s*table:\\s*\\S+,.*$"))
            .collect(Collectors.toList());
    Assertions.assertFalse(tables.isEmpty(), plan);
    Assertions.assertTrue(tables.stream().allMatch(ln -> ln.contains("lineitem")), plan);
}
}