[BugFix] Need local shuffle for the merged local agg (#14599)

zihe.liu 2022-12-05 21:06:16 +08:00 committed by GitHub
parent 5e63dbe70d
commit 1caaca622f
7 changed files with 102 additions and 52 deletions


@@ -228,14 +228,25 @@ std::vector<std::shared_ptr<pipeline::OperatorFactory>> AggregateBlockingNode::d
return context->maybe_interpolate_local_shuffle_exchange(runtime_state(), ops, group_by_expr_ctxs);
};
if (agg_node.need_finalize && !sorted_streaming_aggregate) {
// If finalize aggregate with group by clause, then it can be parallelized
if (has_group_by_keys) {
if (could_local_shuffle) {
if (!sorted_streaming_aggregate) {
// 1. Finalize aggregation:
// - Without group by clause, it cannot be parallelized and needs a local passthrough.
// - With group by clause, it can be parallelized and needs a local shuffle when could_local_shuffle is true.
// 2. Non-finalize aggregation:
// - Without group by clause, it can be parallelized and doesn't need a local shuffle.
// - With group by clause, it can be parallelized and needs a local shuffle when could_local_shuffle is true.
if (agg_node.need_finalize) {
if (!has_group_by_keys) {
ops_with_sink = context->maybe_interpolate_local_passthrough_exchange(runtime_state(), ops_with_sink);
} else if (could_local_shuffle) {
ops_with_sink = try_interpolate_local_shuffle(ops_with_sink);
}
} else {
ops_with_sink = context->maybe_interpolate_local_passthrough_exchange(runtime_state(), ops_with_sink);
if (!has_group_by_keys) {
// Do nothing.
} else if (could_local_shuffle) {
ops_with_sink = try_interpolate_local_shuffle(ops_with_sink);
}
}
}
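The four cases in the comment above collapse into a small decision matrix. Below is a minimal Java sketch (illustrative only, not part of this commit; LocalExchangeDecision and its names are hypothetical) mirroring the C++ branches that apply when the aggregate is not a sorted streaming aggregate:

// Hypothetical sketch of the local-exchange choice made in AggregateBlockingNode above.
final class LocalExchangeDecision {
    enum Kind { PASSTHROUGH, SHUFFLE, NONE }

    static Kind decide(boolean needFinalize, boolean hasGroupByKeys, boolean couldLocalShuffle) {
        if (needFinalize && !hasGroupByKeys) {
            // Finalize without group by: not parallelizable, needs a local passthrough exchange.
            return Kind.PASSTHROUGH;
        }
        if (hasGroupByKeys && couldLocalShuffle) {
            // With group by (finalize or not): parallelizable, needs a local shuffle when allowed.
            return Kind.SHUFFLE;
        }
        // Non-finalize without group by, or local shuffle not permitted: no local exchange.
        return Kind.NONE;
    }
}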


@@ -124,9 +124,6 @@ public class PlanFragment extends TreeNode<PlanFragment> {
protected int pipelineDop = 1;
protected boolean dopEstimated = false;
// Enable shared_scan for this fragment: OlapScanOperator could share the output data to avoid data skew
protected boolean enableSharedScan = true;
// Whether to assign scan ranges to each driver sequence of pipeline,
// for the normal backend assignment (not colocate, bucket, and replicated join).
protected boolean assignScanRangesPerDriverSeq = false;
@@ -239,14 +236,6 @@ public class PlanFragment extends TreeNode<PlanFragment> {
this.forceSetTableSinkDop = true;
}
public void setEnableSharedScan(boolean enable) {
this.enableSharedScan = enable;
}
public boolean isEnableSharedScan() {
return enableSharedScan;
}
public boolean isAssignScanRangesPerDriverSeq() {
return assignScanRangesPerDriverSeq;
}


@@ -1418,8 +1418,7 @@ public class CoordinatorPreprocessor {
if (isEnablePipelineEngine) {
commonParams.setIs_pipeline(true);
commonParams.getQuery_options().setBatch_size(SessionVariable.PIPELINE_BATCH_SIZE);
commonParams.setEnable_shared_scan(
sessionVariable.isEnableSharedScan() && fragment.isEnableSharedScan());
commonParams.setEnable_shared_scan(sessionVariable.isEnableSharedScan());
commonParams.params.setEnable_exchange_pass_through(sessionVariable.isEnableExchangePassThrough());
commonParams.params.setEnable_exchange_perf(sessionVariable.isEnableExchangePerf());


@@ -87,6 +87,13 @@ public class PhysicalHashAggregateOperator extends PhysicalOperator {
return type.isGlobal() && !isSplit;
}
/**
* Whether this operator is the first phase of a three/four-phase aggregation whose second phase has been pruned.
*/
public boolean isMergedLocalAgg() {
return type.isLocal() && !useStreamingPreAgg;
}
public List<ColumnRefOperator> getPartitionByColumns() {
return partitionByColumns;
}
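As a hedged illustration (not part of this commit): a merged local agg is the LOCAL phase that remains when the second, merge phase of a three/four-phase aggregation is pruned, and it does not use streaming pre-aggregation. PlanFragmentBuilder (next file) then treats isOnePhaseAgg() || isMergedLocalAgg() the same way when deciding on local shuffle and scan-range assignment. The sketch below restates the two predicates with simplified stand-in parameters:

// Illustrative sketch only; parameters stand in for the operator's own fields.
final class AggPhaseSketch {
    enum AggType { LOCAL, GLOBAL }

    // One-phase agg: a GLOBAL aggregation that was never split into multiple phases.
    static boolean isOnePhaseAgg(AggType type, boolean isSplit) {
        return type == AggType.GLOBAL && !isSplit;
    }

    // Merged local agg: the LOCAL first phase of a three/four-phase aggregation whose
    // second (merge) phase has been pruned; it does not use streaming pre-aggregation.
    static boolean isMergedLocalAgg(AggType type, boolean useStreamingPreAgg) {
        return type == AggType.LOCAL && !useStreamingPreAgg;
    }
}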


@@ -1530,10 +1530,9 @@ public class PlanFragmentBuilder {
aggregationNode.setHasNullableGenerateChild();
aggregationNode.computeStatistics(optExpr.getStatistics());
if ((node.isOnePhaseAgg() || node.getType().isDistinct())) {
if (node.isOnePhaseAgg() || node.isMergedLocalAgg()) {
// For ScanNode->LocalShuffle->AggNode, we don't need to assign scan ranges per driver sequence.
inputFragment.setAssignScanRangesPerDriverSeq(!withLocalShuffle);
inputFragment.setEnableSharedScan(withLocalShuffle);
aggregationNode.setWithLocalShuffle(withLocalShuffle);
}
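A minimal sketch of the scheduling consequence in the hunk above (hypothetical helper class, not part of the diff): when a local shuffle sits between the scan and the aggregation, scan ranges must not be pinned to driver sequences, because the shuffle redistributes rows across pipeline drivers anyway.

import com.starrocks.planner.AggregationNode;
import com.starrocks.planner.PlanFragment;

// Illustrative helper mirroring the two calls in the hunk above.
final class LocalShuffleScanBinding {
    static void apply(PlanFragment inputFragment, AggregationNode aggregationNode, boolean withLocalShuffle) {
        // ScanNode -> LocalShuffle -> AggNode: per-driver scan-range assignment is unnecessary.
        inputFragment.setAssignScanRangesPerDriverSeq(!withLocalShuffle);
        aggregationNode.setWithLocalShuffle(withLocalShuffle);
    }
}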
@@ -2671,10 +2670,6 @@ public class PlanFragmentBuilder {
leftFragment.mergeQueryGlobalDicts(rightFragment.getQueryGlobalDicts());
if (distributionMode.equals(HashJoinNode.DistributionMode.COLOCATE)) {
leftFragment.setEnableSharedScan(false);
}
return leftFragment;
} else if (distributionMode.equals(JoinNode.DistributionMode.SHUFFLE_HASH_BUCKET)) {
setJoinPushDown(joinNode);


@@ -1456,25 +1456,6 @@ public class AggregateTest extends PlanTestBase {
" | <slot 21> : 3: t1c");
assertContains(plan, "4:AGGREGATE (update serialize)\n" +
" | output: multi_distinct_count(NULL)");
int prevAggStage = connectContext.getSessionVariable().getNewPlannerAggStage();
try {
connectContext.getSessionVariable().setNewPlanerAggStage(3);
sql = "select count(distinct t1b) from test_all_type";
ExecPlan execPlan = UtFrameUtils.getPlanAndFragment(connectContext, sql).second;
assertContains(execPlan.getFragments().get(1).getExplainString(TExplainLevel.NORMAL),
" 4:AGGREGATE (update serialize)\n" +
" | output: count(2: t1b)\n" +
" | group by: \n" +
" | \n" +
" 3:AGGREGATE (merge serialize)\n" +
" | group by: 2: t1b\n" +
" | ");
Assert.assertFalse(execPlan.getFragments().get(1).isEnableSharedScan());
} finally {
connectContext.getSessionVariable().setNewPlanerAggStage(prevAggStage);
}
}
@Test


@@ -5,6 +5,7 @@ import com.starrocks.catalog.OlapTable;
import com.starrocks.common.FeConstants;
import com.starrocks.common.Pair;
import com.starrocks.planner.AggregationNode;
import com.starrocks.planner.PlanFragment;
import com.starrocks.qe.ConnectContext;
import com.starrocks.sql.optimizer.statistics.ColumnStatistic;
import com.starrocks.thrift.TExplainLevel;
@@ -1014,9 +1015,9 @@ public class DistributedEnvPlanWithCostTest extends DistributedEnvPlanTestBase {
public void testPruneAggNode() throws Exception {
ConnectContext.get().getSessionVariable().setNewPlanerAggStage(3);
String sql = "select count(distinct C_NAME) from customer group by C_CUSTKEY;";
String plan = getFragmentPlan(sql);
assertContains(plan, "2:AGGREGATE (update finalize)\n" +
ExecPlan plan = getExecPlan(sql);
Assert.assertTrue(plan.getFragments().get(1).isAssignScanRangesPerDriverSeq());
assertContains(plan.getExplainString(TExplainLevel.NORMAL), "2:AGGREGATE (update finalize)\n" +
" | output: count(2: C_NAME)\n" +
" | group by: 1: C_CUSTKEY\n" +
" | \n" +
@@ -1025,8 +1026,9 @@ public class DistributedEnvPlanWithCostTest extends DistributedEnvPlanTestBase {
ConnectContext.get().getSessionVariable().setNewPlanerAggStage(4);
sql = "select count(distinct C_CUSTKEY) from customer;";
plan = getFragmentPlan(sql);
assertContains(plan, " 2:AGGREGATE (update serialize)\n" +
plan = getExecPlan(sql);
Assert.assertTrue(plan.getFragments().get(1).isAssignScanRangesPerDriverSeq());
assertContains(plan.getExplainString(TExplainLevel.NORMAL), " 2:AGGREGATE (update serialize)\n" +
" | output: count(1: C_CUSTKEY)\n" +
" | group by: \n" +
" | \n" +
@@ -1036,8 +1038,9 @@ public class DistributedEnvPlanWithCostTest extends DistributedEnvPlanTestBase {
ConnectContext.get().getSessionVariable().setNewPlanerAggStage(0);
sql = "select count(distinct C_CUSTKEY, C_NAME) from customer;";
plan = getFragmentPlan(sql);
assertContains(plan, " 2:AGGREGATE (update serialize)\n" +
plan = getExecPlan(sql);
Assert.assertTrue(plan.getFragments().get(1).isAssignScanRangesPerDriverSeq());
assertContains(plan.getExplainString(TExplainLevel.NORMAL), " 2:AGGREGATE (update serialize)\n" +
" | output: count(if(1: C_CUSTKEY IS NULL, NULL, 2: C_NAME))\n" +
" | group by: \n" +
" | \n" +
@@ -1045,8 +1048,9 @@ public class DistributedEnvPlanWithCostTest extends DistributedEnvPlanTestBase {
" | group by: 1: C_CUSTKEY, 2: C_NAME");
sql = "select count(distinct C_CUSTKEY, C_NAME) from customer group by C_CUSTKEY;";
plan = getFragmentPlan(sql);
assertContains(plan, " 2:AGGREGATE (update finalize)\n" +
plan = getExecPlan(sql);
Assert.assertTrue(plan.getFragments().get(1).isAssignScanRangesPerDriverSeq());
assertContains(plan.getExplainString(TExplainLevel.NORMAL), " 2:AGGREGATE (update finalize)\n" +
" | output: count(if(1: C_CUSTKEY IS NULL, NULL, 2: C_NAME))\n" +
" | group by: 1: C_CUSTKEY\n" +
" | \n" +
@@ -1054,6 +1058,70 @@ public class DistributedEnvPlanWithCostTest extends DistributedEnvPlanTestBase {
" | group by: 1: C_CUSTKEY, 2: C_NAME");
}
@Test
public void testAggNodeAndBucketDistribution() throws Exception {
// For the local one-phase aggregation, enable AssignScanRangesPerDriverSeq and disable SharedScan.
String sql = "select count(1) from customer group by C_CUSTKEY";
ExecPlan plan = getExecPlan(sql);
PlanFragment fragment = plan.getFragments().get(1);
Assert.assertTrue(fragment.isAssignScanRangesPerDriverSeq());
assertContains(fragment.getExplainString(TExplainLevel.NORMAL), " 1:AGGREGATE (update finalize)\n" +
" | output: count(1)\n" +
" | group by: 1: C_CUSTKEY\n" +
" | \n" +
" 0:OlapScanNode");
// For the non-local one-phase aggregation, disable AssignScanRangesPerDriverSeq and enable SharedScan.
ConnectContext.get().getSessionVariable().setNewPlanerAggStage(1);
sql = "select count(1) from customer group by C_NAME";
plan = getExecPlan(sql);
fragment = plan.getFragments().get(2);
Assert.assertFalse(fragment.isAssignScanRangesPerDriverSeq());
assertContains(fragment.getExplainString(TExplainLevel.NORMAL), " STREAM DATA SINK\n" +
" EXCHANGE ID: 01\n" +
" HASH_PARTITIONED: 2: C_NAME\n" +
"\n" +
" 0:OlapScanNode");
// For the two-phase aggregation, disable AssignScanRangesPerDriverSeq and enable SharedScan.
ConnectContext.get().getSessionVariable().setNewPlanerAggStage(2);
sql = "select count(1) from customer group by C_NAME";
plan = getExecPlan(sql);
fragment = plan.getFragments().get(2);
Assert.assertFalse(fragment.isAssignScanRangesPerDriverSeq());
assertContains(fragment.getExplainString(TExplainLevel.NORMAL), " 1:AGGREGATE (update serialize)\n" +
" | STREAMING\n" +
" | output: count(1)\n" +
" | group by: 2: C_NAME\n" +
" | \n" +
" 0:OlapScanNode");
// For the non-pruned four-phase aggregation, disable AssignScanRangesPerDriverSeq and enable SharedScan.
ConnectContext.get().getSessionVariable().setNewPlanerAggStage(0);
sql = "select count(distinct C_ADDRESS) from customer group by C_NAME";
plan = getExecPlan(sql);
fragment = plan.getFragments().get(2);
Assert.assertFalse(fragment.isAssignScanRangesPerDriverSeq());
assertContains(fragment.getExplainString(TExplainLevel.NORMAL), " 1:AGGREGATE (update serialize)\n" +
" | STREAMING\n" +
" | group by: 2: C_NAME, 3: C_ADDRESS\n" +
" | \n" +
" 0:OlapScanNode");
// For the non-pruned three-phase aggregation, disable AssignScanRangesPerDriverSeq and enable SharedScan.
sql = "select count(distinct C_ADDRESS) from customer";
plan = getExecPlan(sql);
fragment = plan.getFragments().get(2);
Assert.assertFalse(fragment.isAssignScanRangesPerDriverSeq());
assertContains(fragment.getExplainString(TExplainLevel.NORMAL), " 1:AGGREGATE (update serialize)\n" +
" | STREAMING\n" +
" | group by: 3: C_ADDRESS\n" +
" | \n" +
" 0:OlapScanNode");
}
@Test
public void testCastDistributionPrune() throws Exception {
String sql =