[BugFix] Need local shuffle for the merged local agg (#14599)
Parent: 5e63dbe70d · Commit: 1caaca622f
@@ -228,14 +228,25 @@ std::vector<std::shared_ptr<pipeline::OperatorFactory>> AggregateBlockingNode::d
         return context->maybe_interpolate_local_shuffle_exchange(runtime_state(), ops, group_by_expr_ctxs);
     };

-    if (agg_node.need_finalize && !sorted_streaming_aggregate) {
-        // If finalize aggregate with group by clause, then it can be parallelized
-        if (has_group_by_keys) {
-            if (could_local_shuffle) {
-                ops_with_sink = try_interpolate_local_shuffle(ops_with_sink);
-            }
-        } else {
-            ops_with_sink = context->maybe_interpolate_local_passthrough_exchange(runtime_state(), ops_with_sink);
-        }
-    }
+    if (!sorted_streaming_aggregate) {
+        // 1. Finalize aggregation:
+        //    - Without a group-by clause, it cannot be parallelized and needs a local passthrough.
+        //    - With a group-by clause, it can be parallelized and needs a local shuffle when could_local_shuffle is true.
+        // 2. Non-finalize aggregation:
+        //    - Without a group-by clause, it can be parallelized and needs no local shuffle.
+        //    - With a group-by clause, it can be parallelized and needs a local shuffle when could_local_shuffle is true.
+        if (agg_node.need_finalize) {
+            if (!has_group_by_keys) {
+                ops_with_sink = context->maybe_interpolate_local_passthrough_exchange(runtime_state(), ops_with_sink);
+            } else if (could_local_shuffle) {
+                ops_with_sink = try_interpolate_local_shuffle(ops_with_sink);
+            }
+        } else {
+            if (!has_group_by_keys) {
+                // Do nothing.
+            } else if (could_local_shuffle) {
+                ops_with_sink = try_interpolate_local_shuffle(ops_with_sink);
+            }
+        }
+    }

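The comment block above is really a decision table for which local exchange to interpolate in front of the aggregation sink. As a hedged illustration, the same table in a few lines of Java (the class, enum, and method names are invented for this sketch and are not StarRocks APIs):

    // A minimal sketch of the local-exchange decision table, assuming the three
    // boolean inputs mean the same as in the C++ hunk above.
    final class LocalExchangePolicy {
        enum Kind { NONE, PASSTHROUGH, SHUFFLE }

        static Kind choose(boolean needFinalize, boolean hasGroupByKeys, boolean couldLocalShuffle) {
            if (needFinalize && !hasGroupByKeys) {
                // One global state: funnel every row to a single driver.
                return Kind.PASSTHROUGH;
            }
            if (hasGroupByKeys && couldLocalShuffle) {
                // Rows with the same group-by key must meet on the same driver,
                // for both finalize and non-finalize (merged local) aggs.
                return Kind.SHUFFLE;
            }
            return Kind.NONE;
        }
    }
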
@@ -124,9 +124,6 @@ public class PlanFragment extends TreeNode<PlanFragment> {
     protected int pipelineDop = 1;
     protected boolean dopEstimated = false;

-    // Enable shared_scan for this fragment: OlapScanOperator could share the output data to avoid data skew
-    protected boolean enableSharedScan = true;
-
     // Whether to assign scan ranges to each driver sequence of pipeline,
     // for the normal backend assignment (not colocate, bucket, and replicated join).
     protected boolean assignScanRangesPerDriverSeq = false;

@@ -239,14 +236,6 @@ public class PlanFragment extends TreeNode<PlanFragment> {
         this.forceSetTableSinkDop = true;
     }

-    public void setEnableSharedScan(boolean enable) {
-        this.enableSharedScan = enable;
-    }
-
-    public boolean isEnableSharedScan() {
-        return enableSharedScan;
-    }
-
     public boolean isAssignScanRangesPerDriverSeq() {
         return assignScanRangesPerDriverSeq;
     }

@@ -1418,8 +1418,7 @@ public class CoordinatorPreprocessor {
         if (isEnablePipelineEngine) {
             commonParams.setIs_pipeline(true);
             commonParams.getQuery_options().setBatch_size(SessionVariable.PIPELINE_BATCH_SIZE);
-            commonParams.setEnable_shared_scan(
-                    sessionVariable.isEnableSharedScan() && fragment.isEnableSharedScan());
+            commonParams.setEnable_shared_scan(sessionVariable.isEnableSharedScan());
             commonParams.params.setEnable_exchange_pass_through(sessionVariable.isEnableExchangePassThrough());
             commonParams.params.setEnable_exchange_perf(sessionVariable.isEnableExchangePerf());

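With the per-fragment flag removed from PlanFragment above, the coordinator now derives enable_shared_scan from the session variable alone; per-fragment scan behavior is carried by assignScanRangesPerDriverSeq instead. A hedged before/after sketch (only the identifiers visible in the hunk are real):

    // Before: the fragment could veto shared scan.
    // commonParams.setEnable_shared_scan(
    //         sessionVariable.isEnableSharedScan() && fragment.isEnableSharedScan());
    // After: a pure session toggle.
    commonParams.setEnable_shared_scan(sessionVariable.isEnableSharedScan());
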
@@ -87,6 +87,13 @@ public class PhysicalHashAggregateOperator extends PhysicalOperator {
         return type.isGlobal() && !isSplit;
     }

+    /**
+     * Whether it is the first phase in three/four-phase agg whose second phase is pruned.
+     */
+    public boolean isMergedLocalAgg() {
+        return type.isLocal() && !useStreamingPreAgg;
+    }
+
     public List<ColumnRefOperator> getPartitionByColumns() {
         return partitionByColumns;
     }

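A quick gloss on the new predicate: a merged local agg is a local aggregation that does not use streaming pre-aggregation, i.e. the first phase of a three/four-phase agg that is left to fully merge each group's state after the second phase was pruned. Such a node must see every row of its group on one driver, which is exactly why the backend hunk above now interpolates a local shuffle for it. A toy model of the predicate (an illustrative record, not the optimizer's real type):

    // Mirrors isMergedLocalAgg(): local agg type + no streaming pre-agg.
    record AggPhase(boolean isLocalType, boolean useStreamingPreAgg) {
        boolean isMergedLocalAgg() {
            return isLocalType && !useStreamingPreAgg;
        }
    }
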
@@ -1530,10 +1530,9 @@ public class PlanFragmentBuilder {
             aggregationNode.setHasNullableGenerateChild();
             aggregationNode.computeStatistics(optExpr.getStatistics());

-            if ((node.isOnePhaseAgg() || node.getType().isDistinct())) {
+            if (node.isOnePhaseAgg() || node.isMergedLocalAgg()) {
+                // For ScanNode->LocalShuffle->AggNode, we needn't assign scan ranges per driver sequence.
                 inputFragment.setAssignScanRangesPerDriverSeq(!withLocalShuffle);
-                inputFragment.setEnableSharedScan(withLocalShuffle);
                 aggregationNode.setWithLocalShuffle(withLocalShuffle);
             }

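Note how the two settings move in lockstep: in a ScanNode -> LocalShuffle -> AggNode pipeline, the local shuffle already re-partitions rows across drivers, so pinning scan ranges to driver sequences is skipped. A hedged sketch of the invariant (the wrapper is invented; the two setters are the real ones from the hunk):

    // Exactly one mechanism distributes rows among drivers: either the scan
    // assigns ranges per driver sequence, or the local shuffle re-partitions.
    static void configureLocalShuffle(PlanFragment inputFragment, AggregationNode aggregationNode,
                                      boolean withLocalShuffle) {
        inputFragment.setAssignScanRangesPerDriverSeq(!withLocalShuffle);
        aggregationNode.setWithLocalShuffle(withLocalShuffle);
    }
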
@@ -2671,10 +2670,6 @@ public class PlanFragmentBuilder {

             leftFragment.mergeQueryGlobalDicts(rightFragment.getQueryGlobalDicts());

-            if (distributionMode.equals(HashJoinNode.DistributionMode.COLOCATE)) {
-                leftFragment.setEnableSharedScan(false);
-            }
-
             return leftFragment;
         } else if (distributionMode.equals(JoinNode.DistributionMode.SHUFFLE_HASH_BUCKET)) {
             setJoinPushDown(joinNode);

@@ -1456,25 +1456,6 @@ public class AggregateTest extends PlanTestBase {
                 " | <slot 21> : 3: t1c");
         assertContains(plan, "4:AGGREGATE (update serialize)\n" +
                 " | output: multi_distinct_count(NULL)");
-
-        int prevAggStage = connectContext.getSessionVariable().getNewPlannerAggStage();
-        try {
-            connectContext.getSessionVariable().setNewPlanerAggStage(3);
-            sql = "select count(distinct t1b) from test_all_type";
-
-            ExecPlan execPlan = UtFrameUtils.getPlanAndFragment(connectContext, sql).second;
-            assertContains(execPlan.getFragments().get(1).getExplainString(TExplainLevel.NORMAL),
-                    " 4:AGGREGATE (update serialize)\n" +
-                    " | output: count(2: t1b)\n" +
-                    " | group by: \n" +
-                    " | \n" +
-                    " 3:AGGREGATE (merge serialize)\n" +
-                    " | group by: 2: t1b\n" +
-                    " | ");
-            Assert.assertFalse(execPlan.getFragments().get(1).isEnableSharedScan());
-        } finally {
-            connectContext.getSessionVariable().setNewPlanerAggStage(prevAggStage);
-        }
     }

     @Test

@@ -5,6 +5,7 @@ import com.starrocks.catalog.OlapTable;
 import com.starrocks.common.FeConstants;
 import com.starrocks.common.Pair;
 import com.starrocks.planner.AggregationNode;
+import com.starrocks.planner.PlanFragment;
 import com.starrocks.qe.ConnectContext;
 import com.starrocks.sql.optimizer.statistics.ColumnStatistic;
 import com.starrocks.thrift.TExplainLevel;

@@ -1014,9 +1015,9 @@ public class DistributedEnvPlanWithCostTest extends DistributedEnvPlanTestBase {
     public void testPruneAggNode() throws Exception {
         ConnectContext.get().getSessionVariable().setNewPlanerAggStage(3);
         String sql = "select count(distinct C_NAME) from customer group by C_CUSTKEY;";
-        String plan = getFragmentPlan(sql);
-
-        assertContains(plan, "2:AGGREGATE (update finalize)\n" +
+        ExecPlan plan = getExecPlan(sql);
+        Assert.assertTrue(plan.getFragments().get(1).isAssignScanRangesPerDriverSeq());
+        assertContains(plan.getExplainString(TExplainLevel.NORMAL), "2:AGGREGATE (update finalize)\n" +
                 " | output: count(2: C_NAME)\n" +
                 " | group by: 1: C_CUSTKEY\n" +
                 " | \n" +

@@ -1025,8 +1026,9 @@ public class DistributedEnvPlanWithCostTest extends DistributedEnvPlanTestBase {

         ConnectContext.get().getSessionVariable().setNewPlanerAggStage(4);
         sql = "select count(distinct C_CUSTKEY) from customer;";
-        plan = getFragmentPlan(sql);
-        assertContains(plan, " 2:AGGREGATE (update serialize)\n" +
+        plan = getExecPlan(sql);
+        Assert.assertTrue(plan.getFragments().get(1).isAssignScanRangesPerDriverSeq());
+        assertContains(plan.getExplainString(TExplainLevel.NORMAL), " 2:AGGREGATE (update serialize)\n" +
                 " | output: count(1: C_CUSTKEY)\n" +
                 " | group by: \n" +
                 " | \n" +

@@ -1036,8 +1038,9 @@ public class DistributedEnvPlanWithCostTest extends DistributedEnvPlanTestBase {
         ConnectContext.get().getSessionVariable().setNewPlanerAggStage(0);

         sql = "select count(distinct C_CUSTKEY, C_NAME) from customer;";
-        plan = getFragmentPlan(sql);
-        assertContains(plan, " 2:AGGREGATE (update serialize)\n" +
+        plan = getExecPlan(sql);
+        Assert.assertTrue(plan.getFragments().get(1).isAssignScanRangesPerDriverSeq());
+        assertContains(plan.getExplainString(TExplainLevel.NORMAL), " 2:AGGREGATE (update serialize)\n" +
                 " | output: count(if(1: C_CUSTKEY IS NULL, NULL, 2: C_NAME))\n" +
                 " | group by: \n" +
                 " | \n" +

@@ -1045,8 +1048,9 @@ public class DistributedEnvPlanWithCostTest extends DistributedEnvPlanTestBase {
                 " | group by: 1: C_CUSTKEY, 2: C_NAME");

         sql = "select count(distinct C_CUSTKEY, C_NAME) from customer group by C_CUSTKEY;";
-        plan = getFragmentPlan(sql);
-        assertContains(plan, " 2:AGGREGATE (update finalize)\n" +
+        plan = getExecPlan(sql);
+        Assert.assertTrue(plan.getFragments().get(1).isAssignScanRangesPerDriverSeq());
+        assertContains(plan.getExplainString(TExplainLevel.NORMAL), " 2:AGGREGATE (update finalize)\n" +
                 " | output: count(if(1: C_CUSTKEY IS NULL, NULL, 2: C_NAME))\n" +
                 " | group by: 1: C_CUSTKEY\n" +
                 " | \n" +

@@ -1054,6 +1058,70 @@ public class DistributedEnvPlanWithCostTest extends DistributedEnvPlanTestBase {
                 " | group by: 1: C_CUSTKEY, 2: C_NAME");
     }

+    @Test
+    public void testAggNodeAndBucketDistribution() throws Exception {
+        // For the local one-phase aggregation, enable AssignScanRangesPerDriverSeq and disable SharedScan.
+        String sql = "select count(1) from customer group by C_CUSTKEY";
+        ExecPlan plan = getExecPlan(sql);
+        PlanFragment fragment = plan.getFragments().get(1);
+        Assert.assertTrue(fragment.isAssignScanRangesPerDriverSeq());
+        assertContains(fragment.getExplainString(TExplainLevel.NORMAL), " 1:AGGREGATE (update finalize)\n" +
+                " | output: count(1)\n" +
+                " | group by: 1: C_CUSTKEY\n" +
+                " | \n" +
+                " 0:OlapScanNode");
+
+        // For the non-local one-phase aggregation, disable AssignScanRangesPerDriverSeq and enable SharedScan.
+        ConnectContext.get().getSessionVariable().setNewPlanerAggStage(1);
+        sql = "select count(1) from customer group by C_NAME";
+        plan = getExecPlan(sql);
+        fragment = plan.getFragments().get(2);
+        Assert.assertFalse(fragment.isAssignScanRangesPerDriverSeq());
+        assertContains(fragment.getExplainString(TExplainLevel.NORMAL), " STREAM DATA SINK\n" +
+                " EXCHANGE ID: 01\n" +
+                " HASH_PARTITIONED: 2: C_NAME\n" +
+                "\n" +
+                " 0:OlapScanNode");
+
+        // For the two-phase aggregation, disable AssignScanRangesPerDriverSeq and enable SharedScan.
+        ConnectContext.get().getSessionVariable().setNewPlanerAggStage(2);
+        sql = "select count(1) from customer group by C_NAME";
+        plan = getExecPlan(sql);
+        fragment = plan.getFragments().get(2);
+        Assert.assertFalse(fragment.isAssignScanRangesPerDriverSeq());
+        assertContains(fragment.getExplainString(TExplainLevel.NORMAL), " 1:AGGREGATE (update serialize)\n" +
+                " | STREAMING\n" +
+                " | output: count(1)\n" +
+                " | group by: 2: C_NAME\n" +
+                " | \n" +
+                " 0:OlapScanNode");
+
+        // For the non-pruned four-phase aggregation, disable AssignScanRangesPerDriverSeq and enable SharedScan.
+        ConnectContext.get().getSessionVariable().setNewPlanerAggStage(0);
+        sql = "select count(distinct C_ADDRESS) from customer group by C_NAME";
+        plan = getExecPlan(sql);
+        fragment = plan.getFragments().get(2);
+        Assert.assertFalse(fragment.isAssignScanRangesPerDriverSeq());
+        assertContains(fragment.getExplainString(TExplainLevel.NORMAL), " 1:AGGREGATE (update serialize)\n" +
+                " | STREAMING\n" +
+                " | group by: 2: C_NAME, 3: C_ADDRESS\n" +
+                " | \n" +
+                " 0:OlapScanNode");
+
+        // For the non-pruned three-phase aggregation, disable AssignScanRangesPerDriverSeq and enable SharedScan.
+        sql = "select count(distinct C_ADDRESS) from customer";
+        plan = getExecPlan(sql);
+        System.out.println(plan.getExplainString(TExplainLevel.NORMAL));
+        fragment = plan.getFragments().get(2);
+        Assert.assertFalse(fragment.isAssignScanRangesPerDriverSeq());
+        assertContains(fragment.getExplainString(TExplainLevel.NORMAL), " 1:AGGREGATE (update serialize)\n" +
+                " | STREAMING\n" +
+                " | group by: 3: C_ADDRESS\n" +
+                " | \n" +
+                " 0:OlapScanNode");
+    }
+
     @Test
     public void testCastDistributionPrune() throws Exception {
         String sql =
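A small housekeeping note on these planner tests: they mutate the planner agg-stage session variable, so the save-and-restore pattern already used in AggregateTest above is the safer template (identifiers verbatim from this diff, including the historical "Planer" spelling of the setter):

    int prevAggStage = connectContext.getSessionVariable().getNewPlannerAggStage();
    try {
        // Force a three-phase aggregation plan; 0 returns the choice to the optimizer.
        connectContext.getSessionVariable().setNewPlanerAggStage(3);
        // ... build ExecPlans and assert on their fragments ...
    } finally {
        // Restore the previous stage so later tests see a clean session.
        connectContext.getSessionVariable().setNewPlanerAggStage(prevAggStage);
    }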