[Enhancement] support expr reuse in outer join where predicates (backport #62139) (#62625)
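When an outer join's WHERE predicate repeats the same sub-expression across several conjuncts (for example `abs(t0.v1 + t1.v4)` appearing in both `abs(t0.v1 + t1.v4) = abs(t0.v2 + t1.v5)` and `abs(t0.v1 + t1.v4) > 5`), the planner now extracts it once: ScalarOperatorsReuseRule rewrites the join predicate and records the shared sub-expressions in a `common_slot_map` sent to the BE via THashJoinNode/TNestLoopJoinNode. On the BE, HashJoiner and NLJoinProbeOperator evaluate those common expressions once per chunk through the new CommonExprEvalScopeGuard, append the results as temporary columns keyed by slot id, evaluate the where/other conjuncts against them, and drop the temporary columns when the guard goes out of scope. EXPLAIN output now lists the reused expressions under a "common sub expr" section.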

Signed-off-by: silverbullet233 <3675229+silverbullet233@users.noreply.github.com>
Co-authored-by: eyes_on_me <nopainnofame@sina.com>
mergify[bot] 2025-09-23 11:13:51 +08:00 committed by GitHub
parent 0aa7cfb99e
commit 449a3ab2a6
30 changed files with 643 additions and 62 deletions

View File

@ -88,6 +88,13 @@ Status CrossJoinNode::init(const TPlanNode& tnode, RuntimeState* state) {
_build_runtime_filters.emplace_back(rf_desc);
}
}
if (tnode.nestloop_join_node.__isset.common_slot_map) {
for (const auto& [key, val] : tnode.nestloop_join_node.common_slot_map) {
ExprContext* context;
RETURN_IF_ERROR(Expr::create_expr_tree(_pool, val, &context, state, true));
_common_expr_ctxs.insert({key, context});
}
}
return Status::OK();
}
@ -608,10 +615,10 @@ std::vector<std::shared_ptr<pipeline::OperatorFactory>> CrossJoinNode::_decompos
OpFactories left_ops = _children[0]->decompose_to_pipeline(context);
// communication with CrossJoinRight through shared_data.
auto left_factory =
std::make_shared<ProbeFactory>(context->next_operator_id(), id(), _row_descriptor, child(0)->row_desc(),
child(1)->row_desc(), _sql_join_conjuncts, std::move(_join_conjuncts),
std::move(_conjunct_ctxs), std::move(cross_join_context), _join_op);
auto left_factory = std::make_shared<ProbeFactory>(
context->next_operator_id(), id(), _row_descriptor, child(0)->row_desc(), child(1)->row_desc(),
_sql_join_conjuncts, std::move(_join_conjuncts), std::move(_conjunct_ctxs), std::move(_common_expr_ctxs),
std::move(cross_join_context), _join_op);
// Initialize OperatorFactory's fields involving runtime filters.
this->init_runtime_filter_for_operator(left_factory.get(), context, rc_rf_probe_collector);
if (!context->is_colocate_group()) {

View File

@ -128,6 +128,8 @@ private:
std::vector<RuntimeFilterBuildDescriptor*> _build_runtime_filters;
bool _interpolate_passthrough = false;
std::map<SlotId, ExprContext*> _common_expr_ctxs;
};
} // namespace starrocks

View File

@ -126,6 +126,14 @@ Status HashJoinNode::init(const TPlanNode& tnode, RuntimeState* state) {
_build_equivalence_partition_expr_ctxs = _build_expr_ctxs;
}
if (tnode.__isset.hash_join_node && tnode.hash_join_node.__isset.common_slot_map) {
for (const auto& [key, val] : tnode.hash_join_node.common_slot_map) {
ExprContext* context;
RETURN_IF_ERROR(Expr::create_expr_tree(_pool, val, &context, state, true));
_common_expr_ctxs.insert({key, context});
}
}
RETURN_IF_ERROR(Expr::create_expr_trees(_pool, tnode.hash_join_node.other_join_conjuncts,
&_other_join_conjunct_ctxs, state));
@ -484,7 +492,7 @@ pipeline::OpFactories HashJoinNode::_decompose_to_pipeline(pipeline::PipelineBui
_other_join_conjunct_ctxs, _conjunct_ctxs, child(1)->row_desc(), child(0)->row_desc(),
child(1)->type(), child(0)->type(), child(1)->conjunct_ctxs().empty(), _build_runtime_filters,
_output_slots, _output_slots, context->degree_of_parallelism(), _distribution_mode,
_enable_late_materialization, _enable_partition_hash_join, _is_skew_join);
_enable_late_materialization, _enable_partition_hash_join, _is_skew_join, _common_expr_ctxs);
auto hash_joiner_factory = std::make_shared<starrocks::pipeline::HashJoinerFactory>(param);
// Create a shared RefCountedRuntimeFilterCollector

View File

@ -140,6 +140,8 @@ private:
bool _probe_eos = false; // probe table scan finished;
size_t _runtime_join_filter_pushdown_limit = 1024000;
std::map<SlotId, ExprContext*> _common_expr_ctxs;
RuntimeProfile::Counter* _build_timer = nullptr;
RuntimeProfile::Counter* _build_ht_timer = nullptr;
RuntimeProfile::Counter* _copy_right_table_chunk_timer = nullptr;

View File

@ -33,6 +33,7 @@
#include "pipeline/hashjoin/hash_joiner_fwd.h"
#include "runtime/current_thread.h"
#include "simd/simd.h"
#include "storage/chunk_helper.h"
#include "util/runtime_profile.h"
namespace starrocks {
@ -73,6 +74,7 @@ HashJoiner::HashJoiner(const HashJoinerParam& param)
_probe_expr_ctxs(param._probe_expr_ctxs),
_other_join_conjunct_ctxs(param._other_join_conjunct_ctxs),
_conjunct_ctxs(param._conjunct_ctxs),
_common_expr_ctxs(param._common_expr_ctxs),
_build_row_descriptor(param._build_row_descriptor),
_probe_row_descriptor(param._probe_row_descriptor),
_build_node_type(param._build_node_type),
@ -158,6 +160,11 @@ void HashJoiner::_init_hash_table_param(HashTableParam* param, RuntimeState* sta
param->column_view_concat_rows_limit = state->column_view_concat_rows_limit();
param->column_view_concat_bytes_limit = state->column_view_concat_bytes_limit();
std::set<SlotId> predicate_slots;
for (const auto& [slot_id, ctx] : _common_expr_ctxs) {
std::vector<SlotId> expr_slots;
ctx->root()->get_slot_ids(&expr_slots);
predicate_slots.insert(expr_slots.begin(), expr_slots.end());
}
for (ExprContext* expr_context : _conjunct_ctxs) {
std::vector<SlotId> expr_slots;
expr_context->root()->get_slot_ids(&expr_slots);
@ -388,6 +395,9 @@ Status HashJoiner::_calc_filter_for_other_conjunct(ChunkPtr* chunk, Filter& filt
hit_all = false;
filter.assign((*chunk)->num_rows(), 1);
CommonExprEvalScopeGuard guard(*chunk, _common_expr_ctxs);
RETURN_IF_ERROR(guard.evaluate());
for (auto* ctx : _other_join_conjunct_ctxs) {
ASSIGN_OR_RETURN(ColumnPtr column, ctx->evaluate((*chunk).get()))
size_t true_count = ColumnHelper::count_true_with_notnull(column);
@ -516,6 +526,8 @@ Status HashJoiner::_process_other_conjunct(ChunkPtr* chunk, JoinHashTable& hash_
Status HashJoiner::_process_where_conjunct(ChunkPtr* chunk) {
SCOPED_TIMER(probe_metrics().where_conjunct_evaluate_timer);
CommonExprEvalScopeGuard guard(*chunk, _common_expr_ctxs);
RETURN_IF_ERROR(guard.evaluate());
return ExecNode::eval_conjuncts(_conjunct_ctxs, (*chunk).get());
}

View File

@ -72,7 +72,8 @@ struct HashJoinerParam {
bool build_conjunct_ctxs_is_empty, std::list<RuntimeFilterBuildDescriptor*> build_runtime_filters,
std::set<SlotId> build_output_slots, std::set<SlotId> probe_output_slots, size_t max_dop,
const TJoinDistributionMode::type distribution_mode, bool enable_late_materialization,
bool enable_partition_hash_join, bool is_skew_join)
bool enable_partition_hash_join, bool is_skew_join,
const std::map<SlotId, ExprContext*>& common_expr_ctxs)
: _pool(pool),
_hash_join_node(hash_join_node),
_is_null_safes(std::move(is_null_safes)),
@ -92,7 +93,8 @@ struct HashJoinerParam {
_distribution_mode(distribution_mode),
_enable_late_materialization(enable_late_materialization),
_enable_partition_hash_join(enable_partition_hash_join),
_is_skew_join(is_skew_join) {}
_is_skew_join(is_skew_join),
_common_expr_ctxs(common_expr_ctxs) {}
HashJoinerParam(HashJoinerParam&&) = default;
HashJoinerParam(HashJoinerParam&) = default;
@ -120,6 +122,7 @@ struct HashJoinerParam {
const bool _enable_late_materialization;
const bool _enable_partition_hash_join;
const bool _is_skew_join;
const std::map<SlotId, ExprContext*> _common_expr_ctxs;
};
inline bool could_short_circuit(TJoinOp::type join_type) {
@ -439,6 +442,7 @@ private:
const std::vector<ExprContext*>& _other_join_conjunct_ctxs;
// Conjuncts in Join followed by a filter predicate, usually in Where and Having.
const std::vector<ExprContext*>& _conjunct_ctxs;
const std::map<SlotId, ExprContext*>& _common_expr_ctxs;
const RowDescriptor& _build_row_descriptor;
const RowDescriptor& _probe_row_descriptor;
const TPlanNodeType::type _build_node_type;

View File

@ -19,16 +19,19 @@ namespace starrocks::pipeline {
Status HashJoinerFactory::prepare(RuntimeState* state) {
RETURN_IF_ERROR(Expr::prepare(_param._build_expr_ctxs, state));
RETURN_IF_ERROR(Expr::prepare(_param._probe_expr_ctxs, state));
RETURN_IF_ERROR(Expr::prepare(_param._common_expr_ctxs, state));
RETURN_IF_ERROR(Expr::prepare(_param._other_join_conjunct_ctxs, state));
RETURN_IF_ERROR(Expr::prepare(_param._conjunct_ctxs, state));
RETURN_IF_ERROR(Expr::open(_param._build_expr_ctxs, state));
RETURN_IF_ERROR(Expr::open(_param._probe_expr_ctxs, state));
RETURN_IF_ERROR(Expr::open(_param._common_expr_ctxs, state));
RETURN_IF_ERROR(Expr::open(_param._other_join_conjunct_ctxs, state));
RETURN_IF_ERROR(Expr::open(_param._conjunct_ctxs, state));
return Status::OK();
}
void HashJoinerFactory::close(RuntimeState* state) {
Expr::close(_param._common_expr_ctxs, state);
Expr::close(_param._conjunct_ctxs, state);
Expr::close(_param._other_join_conjunct_ctxs, state);
Expr::close(_param._probe_expr_ctxs, state);

View File

@ -22,6 +22,7 @@
#include "runtime/current_thread.h"
#include "runtime/descriptors.h"
#include "simd/simd.h"
#include "storage/chunk_helper.h"
namespace starrocks::pipeline {
@ -30,6 +31,7 @@ NLJoinProbeOperator::NLJoinProbeOperator(OperatorFactory* factory, int32_t id, i
const std::string& sql_join_conjuncts,
const std::vector<ExprContext*>& join_conjuncts,
const std::vector<ExprContext*>& conjunct_ctxs,
const std::map<SlotId, ExprContext*>& common_expr_ctxs,
const std::vector<SlotDescriptor*>& col_types, size_t probe_column_count,
const std::shared_ptr<NLJoinContext>& cross_join_context)
: OperatorWithDependency(factory, id, "nestloop_join_probe", plan_node_id, false, driver_sequence),
@ -39,6 +41,7 @@ NLJoinProbeOperator::NLJoinProbeOperator(OperatorFactory* factory, int32_t id, i
_sql_join_conjuncts(sql_join_conjuncts),
_join_conjuncts(join_conjuncts),
_conjunct_ctxs(conjunct_ctxs),
_common_expr_ctxs(common_expr_ctxs),
_cross_join_context(cross_join_context) {}
Status NLJoinProbeOperator::prepare(RuntimeState* state) {
@ -309,6 +312,9 @@ Status NLJoinProbeOperator::_eval_nullaware_anti_conjuncts(const ChunkPtr& chunk
// for null-aware left anti join, join_conjunct[0] is on-predicate
// others are other-conjuncts
// process on conjuncts
CommonExprEvalScopeGuard guard(chunk, _common_expr_ctxs);
RETURN_IF_ERROR(guard.evaluate());
{
ASSIGN_OR_RETURN(ColumnPtr column, _join_conjuncts[0]->evaluate(chunk.get()));
size_t num_false = ColumnHelper::count_false_with_notnull(column);
@ -354,6 +360,8 @@ Status NLJoinProbeOperator::_eval_nullaware_anti_conjuncts(const ChunkPtr& chunk
Status NLJoinProbeOperator::_probe_for_inner_join(const ChunkPtr& chunk) {
if (!_join_conjuncts.empty() && chunk && !chunk->is_empty()) {
CommonExprEvalScopeGuard guard(chunk, _common_expr_ctxs);
RETURN_IF_ERROR(guard.evaluate());
RETURN_IF_ERROR(eval_conjuncts_and_in_filters(_join_conjuncts, chunk.get(), nullptr, true));
}
return Status::OK();
@ -374,7 +382,10 @@ Status NLJoinProbeOperator::_probe_for_other_join(const ChunkPtr& chunk) {
if (_is_null_aware_left_anti_join()) {
RETURN_IF_ERROR(_eval_nullaware_anti_conjuncts(chunk, &filter));
} else {
CommonExprEvalScopeGuard guard(chunk, _common_expr_ctxs);
RETURN_IF_ERROR(guard.evaluate());
RETURN_IF_ERROR(eval_conjuncts_and_in_filters(_join_conjuncts, chunk.get(), &filter, apply_filter));
chunk->check_or_die();
}
DCHECK(!!filter);
// The filter has not been assigned if no rows matched
@ -652,8 +663,11 @@ Status NLJoinProbeOperator::_permute_right_join(size_t chunk_size) {
}
}
permute_rows += chunk->num_rows();
RETURN_IF_ERROR(eval_conjuncts(_conjunct_ctxs, chunk.get(), nullptr));
{
CommonExprEvalScopeGuard guard(chunk, _common_expr_ctxs);
RETURN_IF_ERROR(guard.evaluate());
RETURN_IF_ERROR(eval_conjuncts(_conjunct_ctxs, chunk.get(), nullptr));
}
RETURN_IF_ERROR(_output_accumulator.push(std::move(chunk)));
match_flag_index += cur_chunk_size;
}
@ -703,7 +717,11 @@ StatusOr<ChunkPtr> NLJoinProbeOperator::_pull_chunk_for_other_join(size_t chunk_
ASSIGN_OR_RETURN(ChunkPtr chunk, _permute_chunk_for_other_join(chunk_size));
DCHECK(chunk);
RETURN_IF_ERROR(_probe_for_other_join(chunk));
RETURN_IF_ERROR(eval_conjuncts(_conjunct_ctxs, chunk.get(), nullptr));
{
CommonExprEvalScopeGuard guard(chunk, _common_expr_ctxs);
RETURN_IF_ERROR(guard.evaluate());
RETURN_IF_ERROR(eval_conjuncts(_conjunct_ctxs, chunk.get(), nullptr));
}
RETURN_IF_ERROR(_output_accumulator.push(std::move(chunk)));
if (ChunkPtr res = _output_accumulator.pull()) {
@ -800,9 +818,9 @@ void NLJoinProbeOperatorFactory::_init_row_desc() {
}
OperatorPtr NLJoinProbeOperatorFactory::create(int32_t degree_of_parallelism, int32_t driver_sequence) {
return std::make_shared<NLJoinProbeOperator>(this, _id, _plan_node_id, driver_sequence, _join_op,
_sql_join_conjuncts, _join_conjuncts, _conjunct_ctxs, _col_types,
_probe_column_count, _cross_join_context);
return std::make_shared<NLJoinProbeOperator>(
this, _id, _plan_node_id, driver_sequence, _join_op, _sql_join_conjuncts, _join_conjuncts, _conjunct_ctxs,
_common_expr_ctxs, _col_types, _probe_column_count, _cross_join_context);
}
Status NLJoinProbeOperatorFactory::prepare(RuntimeState* state) {
@ -812,6 +830,9 @@ Status NLJoinProbeOperatorFactory::prepare(RuntimeState* state) {
_cross_join_context->ref();
_init_row_desc();
RETURN_IF_ERROR(Expr::prepare(_common_expr_ctxs, state));
RETURN_IF_ERROR(Expr::open(_common_expr_ctxs, state));
RETURN_IF_ERROR(Expr::prepare(_join_conjuncts, state));
RETURN_IF_ERROR(Expr::open(_join_conjuncts, state));
RETURN_IF_ERROR(Expr::prepare(_conjunct_ctxs, state));
@ -821,6 +842,7 @@ Status NLJoinProbeOperatorFactory::prepare(RuntimeState* state) {
}
void NLJoinProbeOperatorFactory::close(RuntimeState* state) {
Expr::close(_common_expr_ctxs, state);
Expr::close(_join_conjuncts, state);
Expr::close(_conjunct_ctxs, state);

View File

@ -37,6 +37,7 @@ public:
NLJoinProbeOperator(OperatorFactory* factory, int32_t id, int32_t plan_node_id, int32_t driver_sequence,
TJoinOp::type join_op, const std::string& sql_join_conjuncts,
const std::vector<ExprContext*>& join_conjuncts, const std::vector<ExprContext*>& conjunct_ctxs,
const std::map<SlotId, ExprContext*>& common_expr_ctxs,
const std::vector<SlotDescriptor*>& col_types, size_t probe_column_count,
const std::shared_ptr<NLJoinContext>& cross_join_context);
@ -115,6 +116,7 @@ private:
const std::vector<ExprContext*>& _join_conjuncts;
const std::vector<ExprContext*>& _conjunct_ctxs;
const std::map<SlotId, ExprContext*>& _common_expr_ctxs;
const std::shared_ptr<NLJoinContext>& _cross_join_context;
bool _input_finished = false;
@ -147,6 +149,7 @@ public:
const RowDescriptor& left_row_desc, const RowDescriptor& right_row_desc,
std::string sql_join_conjuncts, std::vector<ExprContext*>&& join_conjuncts,
std::vector<ExprContext*>&& conjunct_ctxs,
std::map<SlotId, ExprContext*>&& common_expr_ctxs,
std::shared_ptr<NLJoinContext>&& cross_join_context, TJoinOp::type join_op)
: OperatorWithDependencyFactory(id, "cross_join_left", plan_node_id),
_join_op(join_op),
@ -155,6 +158,7 @@ public:
_sql_join_conjuncts(std::move(sql_join_conjuncts)),
_join_conjuncts(std::move(join_conjuncts)),
_conjunct_ctxs(std::move(conjunct_ctxs)),
_common_expr_ctxs(std::move(common_expr_ctxs)),
_cross_join_context(std::move(cross_join_context)) {}
~NLJoinProbeOperatorFactory() override = default;
@ -178,6 +182,7 @@ private:
std::string _sql_join_conjuncts;
std::vector<ExprContext*> _join_conjuncts;
std::vector<ExprContext*> _conjunct_ctxs;
std::map<SlotId, ExprContext*> _common_expr_ctxs;
std::shared_ptr<NLJoinContext> _cross_join_context;
};

View File

@ -27,12 +27,14 @@ namespace starrocks::pipeline {
NLJoinProber::NLJoinProber(TJoinOp::type join_op, const std::vector<ExprContext*>& join_conjuncts,
const std::vector<ExprContext*>& conjunct_ctxs,
const std::map<SlotId, ExprContext*>& common_expr_ctxs,
const std::vector<SlotDescriptor*>& col_types, size_t probe_column_count)
: _join_op(join_op),
_col_types(col_types),
_probe_column_count(probe_column_count),
_join_conjuncts(join_conjuncts),
_conjunct_ctxs(conjunct_ctxs) {}
_conjunct_ctxs(conjunct_ctxs),
_common_expr_ctxs(common_expr_ctxs) {}
Status NLJoinProber::prepare(RuntimeState* state, RuntimeProfile* profile) {
_permute_rows_counter = ADD_COUNTER(profile, "PermuteRows", TUnit::UNIT);
@ -115,10 +117,11 @@ void NLJoinProber::_permute_probe_row(Chunk* dst, const ChunkPtr& build_chunk) {
SpillableNLJoinProbeOperator::SpillableNLJoinProbeOperator(
OperatorFactory* factory, int32_t id, int32_t plan_node_id, int32_t driver_sequence, TJoinOp::type join_op,
const std::string& sql_join_conjuncts, const std::vector<ExprContext*>& join_conjuncts,
const std::vector<ExprContext*>& conjunct_ctxs, const std::vector<SlotDescriptor*>& col_types,
size_t probe_column_count, const std::shared_ptr<NLJoinContext>& cross_join_context)
const std::vector<ExprContext*>& conjunct_ctxs, const std::map<SlotId, ExprContext*>& common_expr_ctxs,
const std::vector<SlotDescriptor*>& col_types, size_t probe_column_count,
const std::shared_ptr<NLJoinContext>& cross_join_context)
: OperatorWithDependency(factory, id, "spillable_nestloop_join_probe", plan_node_id, false, driver_sequence),
_prober(join_op, join_conjuncts, conjunct_ctxs, col_types, probe_column_count),
_prober(join_op, join_conjuncts, conjunct_ctxs, common_expr_ctxs, col_types, probe_column_count),
_cross_join_context(cross_join_context) {}
Status SpillableNLJoinProbeOperator::prepare(RuntimeState* state) {
@ -244,9 +247,9 @@ void SpillableNLJoinProbeOperatorFactory::_init_row_desc() {
}
OperatorPtr SpillableNLJoinProbeOperatorFactory::create(int32_t degree_of_parallelism, int32_t driver_sequence) {
return std::make_shared<SpillableNLJoinProbeOperator>(this, _id, _plan_node_id, driver_sequence, _join_op,
_sql_join_conjuncts, _join_conjuncts, _conjunct_ctxs,
_col_types, _probe_column_count, _cross_join_context);
return std::make_shared<SpillableNLJoinProbeOperator>(
this, _id, _plan_node_id, driver_sequence, _join_op, _sql_join_conjuncts, _join_conjuncts, _conjunct_ctxs,
_common_expr_ctxs, _col_types, _probe_column_count, _cross_join_context);
}
Status SpillableNLJoinProbeOperatorFactory::prepare(RuntimeState* state) {

View File

@ -30,8 +30,8 @@ namespace starrocks::pipeline {
class NLJoinProber {
public:
NLJoinProber(TJoinOp::type join_op, const std::vector<ExprContext*>& join_conjuncts,
const std::vector<ExprContext*>& conjunct_ctxs, const std::vector<SlotDescriptor*>& col_types,
size_t probe_column_count);
const std::vector<ExprContext*>& conjunct_ctxs, const std::map<SlotId, ExprContext*>& common_expr_ctxs,
const std::vector<SlotDescriptor*>& col_types, size_t probe_column_count);
~NLJoinProber() = default;
@ -80,6 +80,7 @@ private:
const std::vector<ExprContext*>& _join_conjuncts;
const std::vector<ExprContext*>& _conjunct_ctxs;
const std::map<SlotId, ExprContext*>& _common_expr_ctxs;
//
ChunkPtr _probe_chunk = nullptr;
@ -98,6 +99,7 @@ public:
TJoinOp::type join_op, const std::string& sql_join_conjuncts,
const std::vector<ExprContext*>& join_conjuncts,
const std::vector<ExprContext*>& conjunct_ctxs,
const std::map<SlotId, ExprContext*>& common_expr_ctxs,
const std::vector<SlotDescriptor*>& col_types, size_t probe_column_count,
const std::shared_ptr<NLJoinContext>& cross_join_context);
@ -153,6 +155,7 @@ public:
const RowDescriptor& left_row_desc, const RowDescriptor& right_row_desc,
std::string sql_join_conjuncts, std::vector<ExprContext*>&& join_conjuncts,
std::vector<ExprContext*>&& conjunct_ctxs,
std::map<SlotId, ExprContext*>&& common_expr_ctxs,
std::shared_ptr<NLJoinContext>&& cross_join_context, TJoinOp::type join_op)
: OperatorWithDependencyFactory(id, "spillable_nl_join_left", plan_node_id),
_join_op(join_op),
@ -161,6 +164,7 @@ public:
_sql_join_conjuncts(std::move(sql_join_conjuncts)),
_join_conjuncts(std::move(join_conjuncts)),
_conjunct_ctxs(std::move(conjunct_ctxs)),
_common_expr_ctxs(std::move(common_expr_ctxs)),
_cross_join_context(std::move(cross_join_context)) {}
~SpillableNLJoinProbeOperatorFactory() override = default;
@ -184,6 +188,7 @@ private:
std::string _sql_join_conjuncts;
std::vector<ExprContext*> _join_conjuncts;
std::vector<ExprContext*> _conjunct_ctxs;
std::map<SlotId, ExprContext*> _common_expr_ctxs;
std::shared_ptr<NLJoinContext> _cross_join_context;
};

View File

@ -513,6 +513,13 @@ Status Expr::prepare(const std::vector<ExprContext*>& ctxs, RuntimeState* state)
return Status::OK();
}
Status Expr::prepare(const std::map<SlotId, ExprContext*>& ctxs, RuntimeState* state) {
for (const auto& [_, ctx] : ctxs) {
RETURN_IF_ERROR(ctx->prepare(state));
}
return Status::OK();
}
Status Expr::prepare(RuntimeState* state, ExprContext* context) {
FAIL_POINT_TRIGGER_RETURN_ERROR(random_error);
DCHECK(_type.type != TYPE_UNKNOWN);
@ -529,6 +536,13 @@ Status Expr::open(const std::vector<ExprContext*>& ctxs, RuntimeState* state) {
return Status::OK();
}
Status Expr::open(const std::map<SlotId, ExprContext*>& ctxs, RuntimeState* state) {
for (const auto& [_, ctx] : ctxs) {
RETURN_IF_ERROR(ctx->open(state));
}
return Status::OK();
}
Status Expr::open(RuntimeState* state, ExprContext* context, FunctionContext::FunctionStateScope scope) {
FAIL_POINT_TRIGGER_RETURN_ERROR(random_error);
DCHECK(_type.type != TYPE_UNKNOWN);
@ -546,6 +560,14 @@ void Expr::close(const std::vector<ExprContext*>& ctxs, RuntimeState* state) {
}
}
void Expr::close(const std::map<SlotId, ExprContext*>& ctxs, RuntimeState* state) {
for (const auto& [_, ctx] : ctxs) {
if (ctx != nullptr) {
ctx->close(state);
}
}
}
void Expr::close(RuntimeState* state, ExprContext* context, FunctionContext::FunctionStateScope scope) {
for (auto& i : _children) {
i->close(state, context, scope);

View File

@ -179,9 +179,11 @@ public:
/// Convenience function for preparing multiple expr trees.
static Status prepare(const std::vector<ExprContext*>& ctxs, RuntimeState* state);
static Status prepare(const std::map<SlotId, ExprContext*>& ctxs, RuntimeState* state);
/// Convenience function for opening multiple expr trees.
static Status open(const std::vector<ExprContext*>& ctxs, RuntimeState* state);
static Status open(const std::map<SlotId, ExprContext*>& ctxs, RuntimeState* state);
/// Clones each ExprContext for multiple expr trees. 'new_ctxs' must be non-NULL.
/// Idempotent: if '*new_ctxs' is empty, a clone of each context in 'ctxs' will be added
@ -192,6 +194,7 @@ public:
/// Convenience function for closing multiple expr trees.
static void close(const std::vector<ExprContext*>& ctxs, RuntimeState* state);
static void close(const std::map<SlotId, ExprContext*>& ctxs, RuntimeState* state);
/// Convenience functions for closing a list of ScalarExpr.
static void close(const std::vector<Expr*>& exprs);

View File

@ -26,6 +26,7 @@
#include "column/struct_column.h"
#include "column/type_traits.h"
#include "column/vectorized_fwd.h"
#include "exprs/expr_context.h"
#include "gutil/strings/fastmem.h"
#include "runtime/current_thread.h"
#include "runtime/descriptors.h"
@ -1037,4 +1038,22 @@ void SegmentedChunk::check_or_die() {
}
}
CommonExprEvalScopeGuard::CommonExprEvalScopeGuard(const ChunkPtr& chunk,
const std::map<SlotId, ExprContext*>& common_expr_ctxs)
: _chunk(chunk), _common_expr_ctxs(common_expr_ctxs) {}
CommonExprEvalScopeGuard::~CommonExprEvalScopeGuard() {
for (const auto& [slot_id, _] : _common_expr_ctxs) {
_chunk->remove_column_by_slot_id(slot_id);
}
}
Status CommonExprEvalScopeGuard::evaluate() {
for (const auto& [slot_id, ctx] : _common_expr_ctxs) {
ASSIGN_OR_RETURN(auto column, ctx->evaluate(_chunk.get()));
_chunk->append_column(std::move(column), slot_id);
}
return Status::OK();
}
} // namespace starrocks

View File

@ -221,4 +221,29 @@ private:
const size_t _segment_size;
};
class ExprContext;
/**
* RAII guard for evaluating common expressions on a chunk.
*
* This class provides automatic scope management for evaluating common expressions
* that are temporarily used during expression computation. Common expressions are
* computed once and reused across multiple expressions to avoid redundant computation,
* but they are only needed during the computation phase and should be cleaned up
* from the chunk after computation completes.
*
* The destructor automatically removes the common expressions from the chunk
* to prevent memory leaks and ensure proper cleanup.
*/
class CommonExprEvalScopeGuard {
public:
CommonExprEvalScopeGuard(const ChunkPtr& chunk, const std::map<SlotId, ExprContext*>& common_expr_ctxs);
~CommonExprEvalScopeGuard();
Status evaluate();
private:
const ChunkPtr& _chunk;
const std::map<SlotId, ExprContext*>& _common_expr_ctxs;
};
} // namespace starrocks
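
The intended call pattern is the one used in HashJoiner::_process_where_conjunct and the NLJoin probe paths above. A minimal sketch, assuming the surrounding StarRocks headers and types are available (the helper function name here is illustrative, not part of the patch):

Status evaluate_where_conjuncts(const ChunkPtr& chunk,
                                const std::map<SlotId, ExprContext*>& common_expr_ctxs,
                                const std::vector<ExprContext*>& conjunct_ctxs) {
    // Evaluate each common sub-expression once and append the result to the
    // chunk under its assigned slot id, so the conjuncts below can reference it.
    CommonExprEvalScopeGuard guard(chunk, common_expr_ctxs);
    RETURN_IF_ERROR(guard.evaluate());
    // The FE-rewritten conjuncts read the appended slots instead of
    // recomputing the shared expression in every conjunct.
    RETURN_IF_ERROR(ExecNode::eval_conjuncts(conjunct_ctxs, chunk.get()));
    // ~CommonExprEvalScopeGuard removes the appended columns on scope exit,
    // so the temporary slots do not leak into the operator's output chunk.
    return Status::OK();
}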

View File

@ -192,6 +192,9 @@ public class HashJoinNode extends JoinNode {
if (isSkewJoin) {
msg.hash_join_node.setIs_skew_join(isSkewJoin);
}
if (commonSlotMap != null) {
commonSlotMap.forEach((key, value) -> msg.hash_join_node.putToCommon_slot_map(key.asInt(), value.treeToThrift()));
}
}
@Override

View File

@ -58,7 +58,9 @@ import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
@ -95,6 +97,7 @@ public abstract class JoinNode extends PlanNode implements RuntimeFilterBuildNod
// The partitionByExprs which need to check the probe side for partition join.
protected List<Expr> probePartitionByExprs;
protected boolean canLocalShuffle = false;
protected Map<SlotId, Expr> commonSlotMap;
public List<RuntimeFilterDescription> getBuildRuntimeFilters() {
return buildRuntimeFilters;
@ -492,6 +495,10 @@ public abstract class JoinNode extends PlanNode implements RuntimeFilterBuildNod
this.ukfkProperty = ukfkProperty;
}
public void setCommonSlotMap(Map<SlotId, Expr> commonSlotMap) {
this.commonSlotMap = commonSlotMap;
}
@Override
protected String getNodeExplainString(String detailPrefix, TExplainLevel detailLevel) {
String distrModeStr =
@ -527,6 +534,14 @@ public abstract class JoinNode extends PlanNode implements RuntimeFilterBuildNod
.append("\n");
}
if (commonSlotMap != null && !commonSlotMap.isEmpty()) {
output.append(detailPrefix + " common sub expr:" + "\n");
for (Map.Entry<SlotId, Expr> entry : commonSlotMap.entrySet()) {
output.append(detailPrefix + " <slot " + entry.getKey().toString() + "> : "
+ getExplainString(Arrays.asList(entry.getValue())) + "\n");
}
}
if (detailLevel == TExplainLevel.VERBOSE) {
if (!buildRuntimeFilters.isEmpty()) {

View File

@ -136,6 +136,10 @@ public class NestLoopJoinNode extends JoinNode implements RuntimeFilterBuildNode
msg.nestloop_join_node.setBuild_runtime_filters(
RuntimeFilterDescription.toThriftRuntimeFilterDescriptions(buildRuntimeFilters));
}
if (commonSlotMap != null) {
commonSlotMap.forEach((key, value) ->
msg.nestloop_join_node.putToCommon_slot_map(key.asInt(), value.treeToThrift()));
}
}
@Override

View File

@ -78,6 +78,11 @@ public class DictMappingOperator extends ScalarOperator {
public void setChild(int index, ScalarOperator child) {
}
@Override
public boolean isConstant() {
return false;
}
@Override
public String toString() {
String stringOperator = stringProvideOperator == null ? "" : ", " + stringProvideOperator;

View File

@ -20,7 +20,7 @@ import com.starrocks.sql.optimizer.OptExpressionVisitor;
import com.starrocks.sql.optimizer.base.ColumnRefFactory;
import com.starrocks.sql.optimizer.operator.OperatorType;
import com.starrocks.sql.optimizer.operator.Projection;
import com.starrocks.sql.optimizer.operator.physical.PhysicalFilterOperator;
import com.starrocks.sql.optimizer.operator.physical.PhysicalOperator;
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
import com.starrocks.sql.optimizer.rule.tree.TreeRewriteRule;
@ -49,10 +49,10 @@ public class ScalarOperatorsReuseRule implements TreeRewriteRule {
if (shouldRewritePredicate(opt, context)) {
Projection result = rewritePredicate(opt, context);
if (!result.getCommonSubOperatorMap().isEmpty()) {
PhysicalFilterOperator filter = (PhysicalFilterOperator) opt.getOp();
PhysicalOperator op = (PhysicalOperator) opt.getOp();
ScalarOperator newPredicate = result.getColumnRefMap().values().iterator().next();
filter.setPredicate(newPredicate);
filter.setPredicateCommonOperators(result.getCommonSubOperatorMap());
op.setPredicate(newPredicate);
op.setPredicateCommonOperators(result.getCommonSubOperatorMap());
}
}
@ -89,8 +89,9 @@ public class ScalarOperatorsReuseRule implements TreeRewriteRule {
|| input.getOp().getPredicate() == null) {
return false;
}
// for now, only support rewrite predicates in PhysicalFilterOperator
if (input.getOp().getOpType() == OperatorType.PHYSICAL_FILTER) {
if (input.getOp().getOpType() == OperatorType.PHYSICAL_FILTER ||
input.getOp().getOpType() == OperatorType.PHYSICAL_HASH_JOIN ||
input.getOp().getOpType() == OperatorType.PHYSICAL_NESTLOOP_JOIN) {
return true;
}
return false;

View File

@ -162,6 +162,9 @@ public class InputDependenciesChecker implements PlanValidator.Checker {
}
if (joinOperator.getPredicate() != null) {
usedCols.union(joinOperator.getPredicate().getUsedColumns());
if (joinOperator.getPredicateCommonOperators() != null) {
inputCols.union(joinOperator.getPredicateCommonOperators().keySet());
}
}
}
checkInputCols(inputCols, usedCols, optExpression);

View File

@ -2732,6 +2732,7 @@ public class PlanFragmentBuilder {
rightFragment.getPlanRoot().forceCollectExecStats();
this.currentExecGroup = leftExecGroup;
Map<SlotId, Expr> commonSubExprMap = buildCommonSubExprMap(node.getPredicateCommonOperators(), context);
List<Expr> conjuncts = extractConjuncts(node.getPredicate(), context);
List<Expr> joinOnConjuncts = extractConjuncts(node.getOnPredicate(), context);
List<Expr> probePartitionByExprs = Lists.newArrayList();
@ -2749,6 +2750,7 @@ public class PlanFragmentBuilder {
NestLoopJoinNode joinNode = new NestLoopJoinNode(context.getNextNodeId(),
leftFragment.getPlanRoot(), rightFragment.getPlanRoot(),
null, node.getJoinType(), Lists.newArrayList(), joinOnConjuncts);
joinNode.setCommonSlotMap(commonSubExprMap);
joinNode.setLimit(node.getLimit());
joinNode.computeStatistics(optExpr.getStatistics());
@ -2863,6 +2865,7 @@ public class PlanFragmentBuilder {
List<Expr> eqJoinConjuncts = joinExpr.eqJoinConjuncts;
List<Expr> otherJoinConjuncts = joinExpr.otherJoin;
List<Expr> conjuncts = joinExpr.conjuncts;
Map<SlotId, Expr> commonSlotMap = joinExpr.commonSubOperatorMap;
setNullableForJoin(joinOperator, leftFragment, rightFragment, context);
@ -2880,6 +2883,7 @@ public class PlanFragmentBuilder {
joinNode.setUkfkProperty(joinProperty);
}
}
joinNode.setCommonSlotMap(commonSlotMap);
// set skew join, this is used by runtime filter
PhysicalHashJoinOperator physicalHashJoinOperator = (PhysicalHashJoinOperator) node;
boolean isSkewJoin = physicalHashJoinOperator.getSkewColumn() != null;
@ -3414,23 +3418,8 @@ public class PlanFragmentBuilder {
TupleDescriptor tupleDescriptor = context.getDescTbl().createTupleDescriptor();
Map<SlotId, Expr> commonSubOperatorMap = Maps.newHashMap();
if (filter.getPredicateCommonOperators() != null) {
for (Map.Entry<ColumnRefOperator, ScalarOperator> entry : filter.getPredicateCommonOperators().entrySet()) {
Expr expr = ScalarOperatorToExpr.buildExecExpression(entry.getValue(),
new ScalarOperatorToExpr.FormatterContext(context.getColRefToExpr(),
filter.getPredicateCommonOperators()));
Map<SlotId, Expr> commonSubOperatorMap = buildCommonSubExprMap(filter.getPredicateCommonOperators(), context);
commonSubOperatorMap.put(new SlotId(entry.getKey().getId()), expr);
SlotDescriptor slotDescriptor =
context.getDescTbl().addSlotDescriptor(tupleDescriptor, new SlotId(entry.getKey().getId()));
slotDescriptor.setIsNullable(expr.isNullable());
slotDescriptor.setIsMaterialized(false);
slotDescriptor.setType(expr.getType());
context.getColRefToExpr().put(entry.getKey(), new SlotRef(entry.getKey().toString(), slotDescriptor));
}
}
List<Expr> predicates = Utils.extractConjuncts(filter.getPredicate()).stream()
.map(d -> ScalarOperatorToExpr.buildExecExpression(d,
@ -3648,12 +3637,38 @@ public class PlanFragmentBuilder {
public final List<Expr> eqJoinConjuncts;
public final List<Expr> otherJoin;
public final List<Expr> conjuncts;
public final Map<SlotId, Expr> commonSubOperatorMap;
public JoinExprInfo(List<Expr> eqJoinConjuncts, List<Expr> otherJoin, List<Expr> conjuncts) {
public JoinExprInfo(List<Expr> eqJoinConjuncts, List<Expr> otherJoin, List<Expr> conjuncts,
Map<SlotId, Expr> commonSubOperatorMap) {
this.eqJoinConjuncts = eqJoinConjuncts;
this.otherJoin = otherJoin;
this.conjuncts = conjuncts;
this.commonSubOperatorMap = commonSubOperatorMap;
}
}
private Map<SlotId, Expr> buildCommonSubExprMap(
Map<ColumnRefOperator, ScalarOperator> commonSubOperators, ExecPlan context) {
Map<SlotId, Expr> commonSubExprMap = Maps.newHashMap();
if (commonSubOperators != null && !commonSubOperators.isEmpty()) {
TupleDescriptor tupleDescriptor = context.getDescTbl().createTupleDescriptor();
for (Map.Entry<ColumnRefOperator, ScalarOperator> entry : commonSubOperators.entrySet()) {
Expr expr = ScalarOperatorToExpr.buildExecExpression(entry.getValue(),
new ScalarOperatorToExpr.FormatterContext(context.getColRefToExpr(), commonSubOperators));
commonSubExprMap.put(new SlotId(entry.getKey().getId()), expr);
SlotDescriptor slotDescriptor =
context.getDescTbl().addSlotDescriptor(tupleDescriptor, new SlotId(entry.getKey().getId()));
slotDescriptor.setIsNullable(expr.isNullable());
slotDescriptor.setIsMaterialized(false);
slotDescriptor.setType(expr.getType());
context.getColRefToExpr().put(entry.getKey(), new SlotRef(entry.getKey().toString(), slotDescriptor));
}
}
return commonSubExprMap;
}
private JoinExprInfo buildJoinExpr(OptExpression optExpr, ExecPlan context) {
@ -3699,13 +3714,18 @@ public class PlanFragmentBuilder {
List<Expr> otherJoinConjuncts = otherJoin.stream().map(e -> ScalarOperatorToExpr.buildExecExpression(e,
new ScalarOperatorToExpr.FormatterContext(context.getColRefToExpr())))
.collect(Collectors.toList());
Map<SlotId, Expr> commonSubExprMap = Maps.newHashMap();
if (optExpr.getOp() instanceof PhysicalJoinOperator) {
PhysicalJoinOperator joinOperator = (PhysicalJoinOperator) optExpr.getOp();
commonSubExprMap = buildCommonSubExprMap(joinOperator.getPredicateCommonOperators(), context);
}
List<ScalarOperator> predicates = Utils.extractConjuncts(predicate);
List<Expr> conjuncts = predicates.stream().map(e -> ScalarOperatorToExpr.buildExecExpression(e,
new ScalarOperatorToExpr.FormatterContext(context.getColRefToExpr())))
.collect(Collectors.toList());
return new JoinExprInfo(eqJoinConjuncts, otherJoinConjuncts, conjuncts);
return new JoinExprInfo(eqJoinConjuncts, otherJoinConjuncts, conjuncts, commonSubExprMap);
}
// TODO(murphy) consider state distribution

View File

@ -16,6 +16,8 @@ package com.starrocks.planner;
import com.starrocks.common.FeConstants;
import com.starrocks.qe.ConnectContext;
import com.starrocks.sql.plan.PlanTestBase;
import com.starrocks.sql.plan.PlanTestNoneDBBase;
import com.starrocks.statistic.StatsConstants;
import com.starrocks.utframe.StarRocksAssert;
import com.starrocks.utframe.UtFrameUtils;
@ -111,9 +113,15 @@ public class PushDownSubfieldHashJoinTest {
" | colocate: false, reason: \n" +
" | equal join conjunct: 3: fk = 1: fk\n" +
" | other predicates: CAST(array_sum(array_map(<slot 8> -> <slot 8> != 'A', " +
"if(array_length(array_filter(['A','B'], CAST([0,CAST((2: col_int = 1) AND " +
"(4: id IS NOT NULL) AS TINYINT)] AS ARRAY<BOOLEAN>))) = 0, ['C'], " +
"array_filter(['A','B'], CAST([0,CAST((2: col_int = 1) AND (4: id IS NOT NULL) AS TINYINT)] " +
"AS ARRAY<BOOLEAN>))))) AS BOOLEAN)"));
"if(array_length(26: array_filter) = 0, ['C'], 26: array_filter))) AS BOOLEAN)\n" +
" | common sub expr:\n" +
" | <slot 20> : 2: col_int = 1\n" +
" | <slot 21> : 4: id IS NOT NULL\n" +
" | <slot 22> : (20: expr) AND (21: expr)\n" +
" | <slot 23> : CAST(22: expr AS TINYINT)\n" +
" | <slot 24> : [0,CAST((2: col_int = 1) AND (4: id IS NOT NULL) AS TINYINT)]\n" +
" | <slot 25> : CAST([0,CAST((2: col_int = 1) AND (4: id IS NOT NULL) AS TINYINT)] AS ARRAY<BOOLEAN>)\n" +
" | <slot 26> : array_filter(['A','B'], 25: cast)"));
}
}

View File

@ -0,0 +1,130 @@
// Copyright 2021-present StarRocks, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.starrocks.sql.optimizer;
import com.starrocks.sql.plan.PlanTestBase;
import org.junit.jupiter.api.Test;
public class JoinPredicateExprReuseTest extends PlanTestBase {
@Test
public void testHashJoin() throws Exception {
{
String sql = "select * from t0 left join t1 on t0.v1 = t1.v4 where " +
"abs(t0.v1 + t1.v4) = abs(t0.v2 + t1.v5) and abs(t0.v1 + t1.v4) > 5";
String plan = getFragmentPlan(sql);
assertContains(plan, " | equal join conjunct: 1: v1 = 4: v4\n" +
" | other predicates: 8: abs = abs(2: v2 + 5: v5), 8: abs > 5\n" +
" | common sub expr:\n" +
" | <slot 7> : 1: v1 + 4: v4\n" +
" | <slot 8> : abs(7: add)");
}
{
String sql = "select * from t0 left join t1 on t0.v1 = t1.v4 where " +
"bit_shift_left(t0.v1 + t1.v4, 1) = 10 or bit_shift_left(t0.v1 + t1.v4, 1) = 20";
String plan = getFragmentPlan(sql);
assertContains(plan, " | equal join conjunct: 1: v1 = 4: v4\n" +
" | other predicates: (8: bit_shift_left = 10) OR (8: bit_shift_left = 20), " +
"8: bit_shift_left IN (10, 20)\n" +
" | common sub expr:\n" +
" | <slot 7> : 1: v1 + 4: v4\n" +
" | <slot 8> : 7: add BITSHIFTLEFT 1");
}
{
String sql = "select * from t0 right join t1 on t0.v1 = t1.v4 where " +
"abs(t0.v1 + t1.v4) = abs(t0.v2 + t1.v5) and abs(t0.v1 + t1.v4) > 5";
String plan = getFragmentPlan(sql);
assertContains(plan, " | equal join conjunct: 1: v1 = 4: v4\n" +
" | other predicates: 8: abs = abs(2: v2 + 5: v5), 8: abs > 5\n" +
" | common sub expr:\n" +
" | <slot 7> : 1: v1 + 4: v4\n" +
" | <slot 8> : abs(7: add)");
}
{
String sql = "select * from t0 right join t1 on t0.v1 = t1.v4 where " +
"bit_shift_left(t0.v1 + t1.v4, 1) = 10 or bit_shift_left(t0.v1 + t1.v4, 1) = 20";
String plan = getFragmentPlan(sql);
assertContains(plan, " | equal join conjunct: 1: v1 = 4: v4\n" +
" | other predicates: (8: bit_shift_left = 10) OR (8: bit_shift_left = 20), " +
"8: bit_shift_left IN (10, 20)\n" +
" | common sub expr:\n" +
" | <slot 7> : 1: v1 + 4: v4\n" +
" | <slot 8> : 7: add BITSHIFTLEFT 1");
}
}
@Test
public void testNestLoopJoin() throws Exception {
{
String sql = "select * from t0 left join t1 on t0.v1 > t1.v4 where " +
"abs(t0.v1 + t1.v4) = abs(t0.v2 + t1.v5) and abs(t0.v1 + t1.v4) > 5";
String plan = getFragmentPlan(sql);
assertContains(plan, " | other join predicates: 1: v1 > 4: v4\n" +
" | other predicates: 8: abs = abs(2: v2 + 5: v5), 8: abs > 5\n" +
" | common sub expr:\n" +
" | <slot 7> : 1: v1 + 4: v4\n" +
" | <slot 8> : abs(7: add)");
}
{
String sql = "select * from t0 left join t1 on t0.v1 > t1.v4 where " +
"bit_shift_left(t0.v1 + t1.v4, 1) = 10 or bit_shift_left(t0.v1 + t1.v4, 1) = 20";
String plan = getFragmentPlan(sql);
assertContains(plan, " | other join predicates: 1: v1 > 4: v4\n" +
" | other predicates: (8: bit_shift_left = 10) OR (8: bit_shift_left = 20), " +
"8: bit_shift_left IN (10, 20)\n" +
" | common sub expr:\n" +
" | <slot 7> : 1: v1 + 4: v4\n" +
" | <slot 8> : 7: add BITSHIFTLEFT 1");
}
{
String sql = "select * from t0 right join t1 on t0.v1 > t1.v4 and t0.v2 = t1.v5 where " +
"abs(t0.v1 + t1.v4) = abs(t0.v2 + t1.v5) and abs(t0.v1 + t1.v4) > 5";
String plan = getFragmentPlan(sql);
assertContains(plan, " | equal join conjunct: 2: v2 = 5: v5\n" +
" | other join predicates: 1: v1 > 4: v4\n" +
" | other predicates: 8: abs = abs(2: v2 + 5: v5), 8: abs > 5\n" +
" | common sub expr:\n" +
" | <slot 7> : 1: v1 + 4: v4\n" +
" | <slot 8> : abs(7: add)");
}
{
String sql = "select * from t0 right join t1 on t0.v1 > t1.v4 where " +
"bit_shift_left(t0.v1 + t1.v4, 1) = 10 or bit_shift_left(t0.v1 + t1.v4, 1) = 20";
String plan = getFragmentPlan(sql);
assertContains(plan, " | other join predicates: 1: v1 > 4: v4\n" +
" | other predicates: (8: bit_shift_left = 10) OR (8: bit_shift_left = 20), " +
"8: bit_shift_left IN (10, 20)\n" +
" | common sub expr:\n" +
" | <slot 7> : 1: v1 + 4: v4\n" +
" | <slot 8> : 7: add BITSHIFTLEFT 1");
}
{
String sql = "select * from t0 left join t1 on t0.v1 = t1.v4 where " +
"abs(t0.v1 + t1.v4) > 5 and abs(t0.v1 + t1.v4) < 10";
String plan = getFragmentPlan(sql);
assertContains(plan, " | equal join conjunct: 1: v1 = 4: v4\n" +
" | other predicates: 8: abs > 5, 8: abs < 10\n" +
" | common sub expr:\n" +
" | <slot 7> : 1: v1 + 4: v4\n" +
" | <slot 8> : abs(7: add)");
}
}
}

View File

@ -3326,12 +3326,11 @@ public class JoinTest extends PlanTestBase {
" | colocate: false, reason: \n" +
" | equal join conjunct: 1: v1 = 5: v4\n" +
" | equal join conjunct: 4: count = 8: count\n" +
" | other predicates: ((1: v1 != 1) AND (if(8: count = 1, 'a', 'b') = 'b')) OR ((1: v1 = 1) AND " +
"(if(8: count = 1, 'a', 'b') = 'b')), if(8: count = 1, 'a', 'b') = 'b'\n" +
" | \n" +
" |----5:EXCHANGE\n" +
" | \n" +
" 2:EXCHANGE");
" | other predicates: ((1: v1 != 1) AND (31: expr)) OR ((1: v1 = 1) AND (31: expr)), 31: expr\n" +
" | common sub expr:\n" +
" | <slot 29> : 8: count = 1\n" +
" | <slot 30> : if(29: expr, 'a', 'b')\n" +
" | <slot 31> : 30: if = 'b'");
}
}

View File

@ -810,10 +810,11 @@ public class LowCardinalityTest extends PlanTestBase {
// test input two string column
sql = "select if(S_ADDRESS='kks', S_COMMENT, S_COMMENT) from supplier";
plan = getVerboseExplain(sql);
Assertions.assertTrue(plan.contains(
"9 <-> if[(DictDecode(10: S_ADDRESS, [<place-holder> = 'kks']), DictDecode(11: S_COMMENT, [<place-holder>]), " +
"DictDecode(11: S_COMMENT, [<place-holder>])); args: BOOLEAN,VARCHAR,VARCHAR; " +
"result: VARCHAR; args nullable: true; result nullable: true]"));
assertContains(plan, " | 9 <-> if[(DictDecode(10: S_ADDRESS, [<place-holder> = 'kks']), " +
"[12: expr, VARCHAR(101), true], [12: expr, VARCHAR(101), true]); " +
"args: BOOLEAN,VARCHAR,VARCHAR; result: VARCHAR; args nullable: true; result nullable: true]\n" +
" | common expressions:\n" +
" | 12 <-> DictDecode(11: S_COMMENT, [<place-holder>])");
assertNotContains(plan, "DecodeNode");
// common expression reuse 3
@ -824,7 +825,10 @@ public class LowCardinalityTest extends PlanTestBase {
// support(support(unsupport(Column), unsupport(Column)))
sql = "select REVERSE(SUBSTR(LEFT(REVERSE(S_ADDRESS),INSTR(REVERSE(S_ADDRESS),'/')-1),5)) FROM supplier";
plan = getFragmentPlan(sql);
assertContains(plan, "<slot 9> : reverse(substr(left(DictDecode(10: S_ADDRESS, [reverse(<place-holder>)])");
assertContains(plan, " | <slot 9> : " +
"reverse(substr(left(11: expr, CAST(CAST(instr(11: expr, '/') AS BIGINT) - 1 AS INT)), 5))\n" +
" | common expressions:\n" +
" | <slot 11> : DictDecode(10: S_ADDRESS, [reverse(<place-holder>)])");
}
@Test

View File

@ -922,7 +922,10 @@ public class LowCardinalityTest2 extends PlanTestBase {
// support(support(unsupport(Column), unsupport(Column)))
sql = "select REVERSE(SUBSTR(LEFT(REVERSE(S_ADDRESS),INSTR(REVERSE(S_ADDRESS),'/')-1),5)) FROM supplier";
plan = getFragmentPlan(sql);
assertContains(plan, "<slot 9> : reverse(substr(left(DictDecode(10: S_ADDRESS, [reverse(<place-holder>)])");
assertContains(plan, " | <slot 9> : reverse(substr(left(11: expr, " +
"CAST(CAST(instr(11: expr, '/') AS BIGINT) - 1 AS INT)), 5))\n" +
" | common expressions:\n" +
" | <slot 11> : DictDecode(10: S_ADDRESS, [reverse(<place-holder>)])");
}
@Test

View File

@ -766,6 +766,8 @@ struct THashJoinNode {
56: optional bool late_materialization = false
57: optional bool enable_partition_hash_join = false
58: optional bool is_skew_join = false
59: optional map<Types.TSlotId, Exprs.TExpr> common_slot_map
}
struct TMergeJoinNode {
@ -805,6 +807,7 @@ struct TNestLoopJoinNode {
3: optional list<Exprs.TExpr> join_conjuncts
4: optional string sql_join_conjuncts
5: optional bool interpolate_passthrough = false
6: optional map<Types.TSlotId, Exprs.TExpr> common_slot_map
}
enum TAggregationOp {

View File

@ -0,0 +1,135 @@
-- name: test_outer_join_predicate_expr_reuse
CREATE TABLE t0 (
v1 INT,
v2 INT,
v3 VARCHAR(20)
) DUPLICATE KEY(v1)
DISTRIBUTED BY HASH(v1) BUCKETS 3
PROPERTIES (
"replication_num" = "1"
);
-- result:
-- !result
CREATE TABLE t1 (
v4 INT,
v5 INT,
v6 VARCHAR(20)
) DUPLICATE KEY(v4)
DISTRIBUTED BY HASH(v4) BUCKETS 3
PROPERTIES (
"replication_num" = "1"
);
-- result:
-- !result
INSERT INTO t0 VALUES
(1, 10, 'a'), (2, 20, 'b'), (3, 30, 'c'), (4, 40, 'd'), (5, 50, 'e'),
(6, 60, 'f'), (7, 70, 'g'), (8, 80, 'h'), (9, 90, 'i'), (10, 100, 'j'),
(11, 110, 'a'), (12, 120, 'b'), (13, 130, 'c'), (14, 140, 'd'), (15, 150, 'e'),
(16, 160, 'f'), (17, 170, 'g'), (18, 180, 'h'), (19, 190, 'i'), (20, 200, 'j'),
(21, 210, 'a'), (22, 220, 'b'), (23, 230, 'c'), (24, 240, 'd'), (25, 250, 'e'),
(26, 260, 'f'), (27, 270, 'g'), (28, 280, 'h'), (29, 290, 'i'), (30, 300, 'j'),
(31, 310, 'a'), (32, 320, 'b'), (33, 330, 'c'), (34, 340, 'd'), (35, 350, 'e'),
(36, 360, 'f'), (37, 370, 'g'), (38, 380, 'h'), (39, 390, 'i'), (40, 400, 'j'),
(41, 410, 'a'), (42, 420, 'b'), (43, 430, 'c'), (44, 440, 'd'), (45, 450, 'e'),
(46, 460, 'f'), (47, 470, 'g'), (48, 480, 'h'), (49, 490, 'i'), (50, 500, 'j'),
(51, 510, 'a'), (52, 520, 'b'), (53, 530, 'c'), (54, 540, 'd'), (55, 550, 'e'),
(56, 560, 'f'), (57, 570, 'g'), (58, 580, 'h'), (59, 590, 'i'), (60, 600, 'j'),
(61, 610, 'a'), (62, 620, 'b'), (63, 630, 'c'), (64, 640, 'd'), (65, 650, 'e'),
(66, 660, 'f'), (67, 670, 'g'), (68, 680, 'h'), (69, 690, 'i'), (70, 700, 'j'),
(71, 710, 'a'), (72, 720, 'b'), (73, 730, 'c'), (74, 740, 'd'), (75, 750, 'e'),
(76, 760, 'f'), (77, 770, 'g'), (78, 780, 'h'), (79, 790, 'i'), (80, 800, 'j'),
(81, 810, 'a'), (82, 820, 'b'), (83, 830, 'c'), (84, 840, 'd'), (85, 850, 'e'),
(86, 860, 'f'), (87, 870, 'g'), (88, 880, 'h'), (89, 890, 'i'), (90, 900, 'j'),
(91, 910, 'a'), (92, 920, 'b'), (93, 930, 'c'), (94, 940, 'd'), (95, 950, 'e'),
(96, 960, 'f'), (97, 970, 'g'), (98, 980, 'h'), (99, 990, 'i'), (100, 1000, 'j');
-- result:
-- !result
INSERT INTO t1 VALUES
(1, 15, 'x'), (2, 25, 'y'), (3, 35, 'z'), (4, 45, 'w'), (5, 55, 'v'),
(6, 65, 'u'), (7, 75, 't'), (8, 85, 's'), (9, 95, 'r'), (10, 105, 'q'),
(11, 115, 'x'), (12, 125, 'y'), (13, 135, 'z'), (14, 145, 'w'), (15, 155, 'v'),
(16, 165, 'u'), (17, 175, 't'), (18, 185, 's'), (19, 195, 'r'), (20, 205, 'q'),
(21, 215, 'x'), (22, 225, 'y'), (23, 235, 'z'), (24, 245, 'w'), (25, 255, 'v'),
(26, 265, 'u'), (27, 275, 't'), (28, 285, 's'), (29, 295, 'r'), (30, 305, 'q'),
(31, 315, 'x'), (32, 325, 'y'), (33, 335, 'z'), (34, 345, 'w'), (35, 355, 'v'),
(36, 365, 'u'), (37, 375, 't'), (38, 385, 's'), (39, 395, 'r'), (40, 405, 'q'),
(41, 415, 'x'), (42, 425, 'y'), (43, 435, 'z'), (44, 445, 'w'), (45, 455, 'v'),
(46, 465, 'u'), (47, 475, 't'), (48, 485, 's'), (49, 495, 'r'), (50, 505, 'q'),
(51, 515, 'x'), (52, 525, 'y'), (53, 535, 'z'), (54, 545, 'w'), (55, 555, 'v'),
(56, 565, 'u'), (57, 575, 't'), (58, 585, 's'), (59, 595, 'r'), (60, 605, 'q'),
(61, 615, 'x'), (62, 625, 'y'), (63, 635, 'z'), (64, 645, 'w'), (65, 655, 'v'),
(66, 665, 'u'), (67, 675, 't'), (68, 685, 's'), (69, 695, 'r'), (70, 705, 'q'),
(71, 715, 'x'), (72, 725, 'y'), (73, 735, 'z'), (74, 745, 'w'), (75, 755, 'v'),
(76, 765, 'u'), (77, 775, 't'), (78, 785, 's'), (79, 795, 'r'), (80, 805, 'q'),
(81, 815, 'x'), (82, 825, 'y'), (83, 835, 'z'), (84, 845, 'w'), (85, 855, 'v'),
(86, 865, 'u'), (87, 875, 't'), (88, 885, 's'), (89, 895, 'r'), (90, 905, 'q'),
(91, 915, 'x'), (92, 925, 'y'), (93, 935, 'z'), (94, 945, 'w'), (95, 955, 'v'),
(96, 965, 'u'), (97, 975, 't'), (98, 985, 's'), (99, 995, 'r'), (100, 1005, 'q');
-- result:
-- !result
SELECT COUNT(*) FROM t0 LEFT JOIN t1 ON t0.v1 = t1.v4
WHERE abs(t0.v1 + t1.v4) = abs(t0.v2 + t1.v5) AND abs(t0.v1 + t1.v4) > 5;
-- result:
0
-- !result
SELECT COUNT(*) FROM t0 LEFT JOIN t1 ON t0.v1 = t1.v4
WHERE bit_shift_left(t0.v1 + t1.v4, 1) = 10 OR bit_shift_left(t0.v1 + t1.v4, 1) = 20;
-- result:
1
-- !result
SELECT COUNT(*) FROM t0 RIGHT JOIN t1 ON t0.v1 = t1.v4
WHERE abs(t0.v1 + t1.v4) = abs(t0.v2 + t1.v5) AND abs(t0.v1 + t1.v4) > 5;
-- result:
0
-- !result
SELECT COUNT(*) FROM t0 RIGHT JOIN t1 ON t0.v1 = t1.v4
WHERE bit_shift_left(t0.v1 + t1.v4, 1) = 10 OR bit_shift_left(t0.v1 + t1.v4, 1) = 20;
-- result:
1
-- !result
SELECT COUNT(*) FROM t0 LEFT JOIN t1 ON t0.v1 > t1.v4
WHERE abs(t0.v1 + t1.v4) = abs(t0.v2 + t1.v5) AND abs(t0.v1 + t1.v4) > 5;
-- result:
0
-- !result
SELECT COUNT(*) FROM t0 LEFT JOIN t1 ON t0.v1 > t1.v4
WHERE bit_shift_left(t0.v1 + t1.v4, 1) = 10 OR bit_shift_left(t0.v1 + t1.v4, 1) = 20;
-- result:
6
-- !result
SELECT COUNT(*) FROM t0 RIGHT JOIN t1 ON t0.v1 > t1.v4 AND t0.v2 = t1.v5
WHERE abs(t0.v1 + t1.v4) = abs(t0.v2 + t1.v5) AND abs(t0.v1 + t1.v4) > 5;
-- result:
0
-- !result
SELECT COUNT(*) FROM t0 LEFT JOIN t1 ON t0.v1 = t1.v4
WHERE abs(t0.v1 + t1.v4) > 5 AND abs(t0.v1 + t1.v4) < 10;
-- result:
2
-- !result
SELECT COUNT(*) FROM t0 LEFT JOIN t1 ON t0.v1 = t1.v4
WHERE abs(t0.v1 + t1.v4) = abs(t0.v2 + t1.v5)
AND bit_shift_left(t0.v1 + t1.v4, 1) > 10
AND bit_shift_left(t0.v1 + t1.v4, 1) < 20;
-- result:
0
-- !result
SELECT COUNT(*) FROM t0 RIGHT JOIN t1 ON t0.v1 > t1.v4
WHERE (abs(t0.v1 + t1.v4) > 5 AND abs(t0.v1 + t1.v4) < 10)
OR (abs(t0.v1 + t1.v4) > 15 AND abs(t0.v1 + t1.v4) < 20);
-- result:
44
-- !result
SELECT COUNT(*) FROM t0 LEFT JOIN t1 ON t0.v1 = t1.v4
WHERE abs(bit_shift_left(t0.v1 + t1.v4, 1)) = abs(bit_shift_left(t0.v2 + t1.v5, 1))
AND abs(bit_shift_left(t0.v1 + t1.v4, 1)) > 10;
-- result:
0
-- !result
SELECT COUNT(*) FROM t0 RIGHT JOIN t1 ON t0.v1 = t1.v4
WHERE (t0.v1 + t1.v4) * 2 = (t0.v2 + t1.v5) * 2
AND (t0.v1 + t1.v4) * 2 > 10
AND (t0.v1 + t1.v4) * 2 < 100;
-- result:
0
-- !result

View File

@ -0,0 +1,106 @@
-- name: test_outer_join_predicate_expr_reuse
CREATE TABLE t0 (
v1 INT,
v2 INT,
v3 VARCHAR(20)
) DUPLICATE KEY(v1)
DISTRIBUTED BY HASH(v1) BUCKETS 3
PROPERTIES (
"replication_num" = "1"
);
CREATE TABLE t1 (
v4 INT,
v5 INT,
v6 VARCHAR(20)
) DUPLICATE KEY(v4)
DISTRIBUTED BY HASH(v4) BUCKETS 3
PROPERTIES (
"replication_num" = "1"
);
INSERT INTO t0 VALUES
(1, 10, 'a'), (2, 20, 'b'), (3, 30, 'c'), (4, 40, 'd'), (5, 50, 'e'),
(6, 60, 'f'), (7, 70, 'g'), (8, 80, 'h'), (9, 90, 'i'), (10, 100, 'j'),
(11, 110, 'a'), (12, 120, 'b'), (13, 130, 'c'), (14, 140, 'd'), (15, 150, 'e'),
(16, 160, 'f'), (17, 170, 'g'), (18, 180, 'h'), (19, 190, 'i'), (20, 200, 'j'),
(21, 210, 'a'), (22, 220, 'b'), (23, 230, 'c'), (24, 240, 'd'), (25, 250, 'e'),
(26, 260, 'f'), (27, 270, 'g'), (28, 280, 'h'), (29, 290, 'i'), (30, 300, 'j'),
(31, 310, 'a'), (32, 320, 'b'), (33, 330, 'c'), (34, 340, 'd'), (35, 350, 'e'),
(36, 360, 'f'), (37, 370, 'g'), (38, 380, 'h'), (39, 390, 'i'), (40, 400, 'j'),
(41, 410, 'a'), (42, 420, 'b'), (43, 430, 'c'), (44, 440, 'd'), (45, 450, 'e'),
(46, 460, 'f'), (47, 470, 'g'), (48, 480, 'h'), (49, 490, 'i'), (50, 500, 'j'),
(51, 510, 'a'), (52, 520, 'b'), (53, 530, 'c'), (54, 540, 'd'), (55, 550, 'e'),
(56, 560, 'f'), (57, 570, 'g'), (58, 580, 'h'), (59, 590, 'i'), (60, 600, 'j'),
(61, 610, 'a'), (62, 620, 'b'), (63, 630, 'c'), (64, 640, 'd'), (65, 650, 'e'),
(66, 660, 'f'), (67, 670, 'g'), (68, 680, 'h'), (69, 690, 'i'), (70, 700, 'j'),
(71, 710, 'a'), (72, 720, 'b'), (73, 730, 'c'), (74, 740, 'd'), (75, 750, 'e'),
(76, 760, 'f'), (77, 770, 'g'), (78, 780, 'h'), (79, 790, 'i'), (80, 800, 'j'),
(81, 810, 'a'), (82, 820, 'b'), (83, 830, 'c'), (84, 840, 'd'), (85, 850, 'e'),
(86, 860, 'f'), (87, 870, 'g'), (88, 880, 'h'), (89, 890, 'i'), (90, 900, 'j'),
(91, 910, 'a'), (92, 920, 'b'), (93, 930, 'c'), (94, 940, 'd'), (95, 950, 'e'),
(96, 960, 'f'), (97, 970, 'g'), (98, 980, 'h'), (99, 990, 'i'), (100, 1000, 'j');
INSERT INTO t1 VALUES
(1, 15, 'x'), (2, 25, 'y'), (3, 35, 'z'), (4, 45, 'w'), (5, 55, 'v'),
(6, 65, 'u'), (7, 75, 't'), (8, 85, 's'), (9, 95, 'r'), (10, 105, 'q'),
(11, 115, 'x'), (12, 125, 'y'), (13, 135, 'z'), (14, 145, 'w'), (15, 155, 'v'),
(16, 165, 'u'), (17, 175, 't'), (18, 185, 's'), (19, 195, 'r'), (20, 205, 'q'),
(21, 215, 'x'), (22, 225, 'y'), (23, 235, 'z'), (24, 245, 'w'), (25, 255, 'v'),
(26, 265, 'u'), (27, 275, 't'), (28, 285, 's'), (29, 295, 'r'), (30, 305, 'q'),
(31, 315, 'x'), (32, 325, 'y'), (33, 335, 'z'), (34, 345, 'w'), (35, 355, 'v'),
(36, 365, 'u'), (37, 375, 't'), (38, 385, 's'), (39, 395, 'r'), (40, 405, 'q'),
(41, 415, 'x'), (42, 425, 'y'), (43, 435, 'z'), (44, 445, 'w'), (45, 455, 'v'),
(46, 465, 'u'), (47, 475, 't'), (48, 485, 's'), (49, 495, 'r'), (50, 505, 'q'),
(51, 515, 'x'), (52, 525, 'y'), (53, 535, 'z'), (54, 545, 'w'), (55, 555, 'v'),
(56, 565, 'u'), (57, 575, 't'), (58, 585, 's'), (59, 595, 'r'), (60, 605, 'q'),
(61, 615, 'x'), (62, 625, 'y'), (63, 635, 'z'), (64, 645, 'w'), (65, 655, 'v'),
(66, 665, 'u'), (67, 675, 't'), (68, 685, 's'), (69, 695, 'r'), (70, 705, 'q'),
(71, 715, 'x'), (72, 725, 'y'), (73, 735, 'z'), (74, 745, 'w'), (75, 755, 'v'),
(76, 765, 'u'), (77, 775, 't'), (78, 785, 's'), (79, 795, 'r'), (80, 805, 'q'),
(81, 815, 'x'), (82, 825, 'y'), (83, 835, 'z'), (84, 845, 'w'), (85, 855, 'v'),
(86, 865, 'u'), (87, 875, 't'), (88, 885, 's'), (89, 895, 'r'), (90, 905, 'q'),
(91, 915, 'x'), (92, 925, 'y'), (93, 935, 'z'), (94, 945, 'w'), (95, 955, 'v'),
(96, 965, 'u'), (97, 975, 't'), (98, 985, 's'), (99, 995, 'r'), (100, 1005, 'q');
SELECT COUNT(*) FROM t0 LEFT JOIN t1 ON t0.v1 = t1.v4
WHERE abs(t0.v1 + t1.v4) = abs(t0.v2 + t1.v5) AND abs(t0.v1 + t1.v4) > 5;
SELECT COUNT(*) FROM t0 LEFT JOIN t1 ON t0.v1 = t1.v4
WHERE bit_shift_left(t0.v1 + t1.v4, 1) = 10 OR bit_shift_left(t0.v1 + t1.v4, 1) = 20;
SELECT COUNT(*) FROM t0 RIGHT JOIN t1 ON t0.v1 = t1.v4
WHERE abs(t0.v1 + t1.v4) = abs(t0.v2 + t1.v5) AND abs(t0.v1 + t1.v4) > 5;
SELECT COUNT(*) FROM t0 RIGHT JOIN t1 ON t0.v1 = t1.v4
WHERE bit_shift_left(t0.v1 + t1.v4, 1) = 10 OR bit_shift_left(t0.v1 + t1.v4, 1) = 20;
SELECT COUNT(*) FROM t0 LEFT JOIN t1 ON t0.v1 > t1.v4
WHERE abs(t0.v1 + t1.v4) = abs(t0.v2 + t1.v5) AND abs(t0.v1 + t1.v4) > 5;
SELECT COUNT(*) FROM t0 LEFT JOIN t1 ON t0.v1 > t1.v4
WHERE bit_shift_left(t0.v1 + t1.v4, 1) = 10 OR bit_shift_left(t0.v1 + t1.v4, 1) = 20;
SELECT COUNT(*) FROM t0 RIGHT JOIN t1 ON t0.v1 > t1.v4 AND t0.v2 = t1.v5
WHERE abs(t0.v1 + t1.v4) = abs(t0.v2 + t1.v5) AND abs(t0.v1 + t1.v4) > 5;
SELECT COUNT(*) FROM t0 LEFT JOIN t1 ON t0.v1 = t1.v4
WHERE abs(t0.v1 + t1.v4) > 5 AND abs(t0.v1 + t1.v4) < 10;
SELECT COUNT(*) FROM t0 LEFT JOIN t1 ON t0.v1 = t1.v4
WHERE abs(t0.v1 + t1.v4) = abs(t0.v2 + t1.v5)
AND bit_shift_left(t0.v1 + t1.v4, 1) > 10
AND bit_shift_left(t0.v1 + t1.v4, 1) < 20;
SELECT COUNT(*) FROM t0 RIGHT JOIN t1 ON t0.v1 > t1.v4
WHERE (abs(t0.v1 + t1.v4) > 5 AND abs(t0.v1 + t1.v4) < 10)
OR (abs(t0.v1 + t1.v4) > 15 AND abs(t0.v1 + t1.v4) < 20);
SELECT COUNT(*) FROM t0 LEFT JOIN t1 ON t0.v1 = t1.v4
WHERE abs(bit_shift_left(t0.v1 + t1.v4, 1)) = abs(bit_shift_left(t0.v2 + t1.v5, 1))
AND abs(bit_shift_left(t0.v1 + t1.v4, 1)) > 10;
SELECT COUNT(*) FROM t0 RIGHT JOIN t1 ON t0.v1 = t1.v4
WHERE (t0.v1 + t1.v4) * 2 = (t0.v2 + t1.v5) * 2
AND (t0.v1 + t1.v4) * 2 > 10
AND (t0.v1 + t1.v4) * 2 < 100;