[BugFix] fix lambda common expr slot id conflicts in array_map (backport #62414) (#62428)

Signed-off-by: silverbullet233 <3675229+silverbullet233@users.noreply.github.com>
Co-authored-by: eyes_on_me <nopainnofame@sina.com>
This commit is contained in:
mergify[bot] 2025-08-28 06:44:24 +00:00 committed by GitHub
parent 6ada019be0
commit df8b4f31f3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 178 additions and 7 deletions

View File

@ -115,8 +115,9 @@ StatusOr<ColumnPtr> ArrayMapExpr::evaluate_lambda_expr(ExprContext* context, Chu
// 2. check captured columns' size
for (auto slot_id : capture_slot_ids) {
DCHECK(slot_id > 0);
auto captured_column = chunk->is_slot_exist(slot_id) ? chunk->get_column_by_slot_id(slot_id)
: tmp_chunk->get_column_by_slot_id(slot_id);
auto captured_column = tmp_chunk->is_slot_exist(slot_id) ? tmp_chunk->get_column_by_slot_id(slot_id)
: chunk->get_column_by_slot_id(slot_id);
if (UNLIKELY(captured_column->size() < input_elements[0]->size())) {
return Status::InternalError(fmt::format("The size of the captured column {} is less than array's size.",
captured_column->get_name()));
@ -176,8 +177,8 @@ StatusOr<ColumnPtr> ArrayMapExpr::evaluate_lambda_expr(ExprContext* context, Chu
// 4. prepare capture columns
for (auto slot_id : capture_slot_ids) {
auto captured_column = chunk->is_slot_exist(slot_id) ? chunk->get_column_by_slot_id(slot_id)
: tmp_chunk->get_column_by_slot_id(slot_id);
auto captured_column = tmp_chunk->is_slot_exist(slot_id) ? tmp_chunk->get_column_by_slot_id(slot_id)
: chunk->get_column_by_slot_id(slot_id);
if constexpr (independent_lambda_expr) {
cur_chunk->append_column(captured_column, slot_id);
} else {

View File

@ -68,6 +68,9 @@
namespace starrocks {
// for ut only
RuntimeState::RuntimeState() : _obj_pool(new ObjectPool()) {}
// for ut only
RuntimeState::RuntimeState(const TUniqueId& fragment_instance_id, const TQueryOptions& query_options,
const TQueryGlobals& query_globals, ExecEnv* exec_env)

View File

@ -93,7 +93,7 @@ constexpr int64_t kRpcHttpMinSize = ((1L << 31) - (1L << 10));
class RuntimeState {
public:
// for ut only
RuntimeState() = default;
RuntimeState();
// for ut only
RuntimeState(const TUniqueId& fragment_instance_id, const TQueryOptions& query_options,
const TQueryGlobals& query_globals, ExecEnv* exec_env);

View File

@ -19,6 +19,7 @@
#include <memory>
#include "butil/time.h"
#include "column/binary_column.h"
#include "column/column_helper.h"
#include "column/fixed_length_column.h"
#include "exprs/arithmetic_expr.h"
@ -216,8 +217,8 @@ std::vector<Expr*> VectorizedLambdaFunctionExprTest::create_lambda_expr(ObjectPo
ColumnRef* col6 = pool->add(new ColumnRef(slot_ref));
slot_ref.slot_ref.slot_id = 1;
ColumnRef* col7 = pool->add(new ColumnRef(slot_ref));
add_expr->_children.push_back(col6);
add_expr->_children.push_back(col7);
add_expr->add_child(col6);
add_expr->add_child(col7);
lambda_func->add_child(add_expr);
lambda_func->add_child(col5);
lambda_funcs.push_back(lambda_func);
@ -488,4 +489,170 @@ TEST_F(VectorizedLambdaFunctionExprTest, array_map_lambda_test_const_array) {
}
}
TEST_F(VectorizedLambdaFunctionExprTest, test_lambda_common_expr_slot_conflict) {
auto cur_chunk = std::make_shared<Chunk>();
std::vector<int> vec_a = {1, 1, 1};
cur_chunk->append_column(build_int_column(vec_a), 1);
// Create a string column for length function
auto string_col = BinaryColumn::create();
string_col->append("abc");
string_col->append("def");
string_col->append("ghi");
cur_chunk->append_column(string_col, 2); // slot_id = 2 for captured column z
auto fake_col = BinaryColumn::create();
fake_col->append("abc");
fake_col->append("def");
fake_col->append("ghi");
// fake column, which will conflict with the slot id of the common expr extracted by the lambda function
cur_chunk->append_column(fake_col, 100002);
// Create two array columns for testing
TypeDescriptor type_arr_int;
type_arr_int.type = LogicalType::TYPE_ARRAY;
type_arr_int.children.emplace_back();
type_arr_int.children.back().type = LogicalType::TYPE_INT;
// col1: [1,2], [3,4], [5,6]
auto array1 = ColumnHelper::create_column(type_arr_int, false);
array1->append_datum(DatumArray{Datum((int32_t)1), Datum((int32_t)2)});
array1->append_datum(DatumArray{Datum((int32_t)3), Datum((int32_t)4)});
array1->append_datum(DatumArray{Datum((int32_t)5), Datum((int32_t)6)});
auto* col1_expr = new_fake_const_expr(std::move(array1), type_arr_int);
// col2: [10,20], [30,40], [50,60]
auto array2 = ColumnHelper::create_column(type_arr_int, false);
array2->append_datum(DatumArray{Datum((int32_t)10), Datum((int32_t)20)});
array2->append_datum(DatumArray{Datum((int32_t)30), Datum((int32_t)40)});
array2->append_datum(DatumArray{Datum((int32_t)50), Datum((int32_t)60)});
auto* col2_expr = new_fake_const_expr(std::move(array2), type_arr_int);
// Create lambda function: (x,y) -> x + y + length(z)
TExprNode tlambda_func;
tlambda_func.opcode = TExprOpcode::ADD;
tlambda_func.child_type = TPrimitiveType::INT;
tlambda_func.node_type = TExprNodeType::LAMBDA_FUNCTION_EXPR;
tlambda_func.num_children = 3; // lambda_expr + 2 arguments
tlambda_func.__isset.opcode = true;
tlambda_func.__isset.child_type = true;
tlambda_func.type = gen_type_desc(TPrimitiveType::INT);
LambdaFunction* lambda_func = _objpool.add(new LambdaFunction(tlambda_func));
// Create lambda arguments: x and y
TExprNode slot_ref_x, slot_ref_y;
slot_ref_x.node_type = TExprNodeType::SLOT_REF;
slot_ref_x.type = gen_type_desc(TPrimitiveType::INT);
slot_ref_x.num_children = 0;
slot_ref_x.__isset.slot_ref = true;
slot_ref_x.slot_ref.slot_id = 100000; // x's slot_id
slot_ref_x.slot_ref.tuple_id = 0;
slot_ref_x.__set_is_nullable(true);
slot_ref_y.node_type = TExprNodeType::SLOT_REF;
slot_ref_y.type = gen_type_desc(TPrimitiveType::INT);
slot_ref_y.num_children = 0;
slot_ref_y.__isset.slot_ref = true;
slot_ref_y.slot_ref.slot_id = 100001; // y's slot_id
slot_ref_y.slot_ref.tuple_id = 0;
slot_ref_y.__set_is_nullable(true);
ColumnRef* col_x = _objpool.add(new ColumnRef(slot_ref_x));
ColumnRef* col_y = _objpool.add(new ColumnRef(slot_ref_y));
// Create captured column reference: z (slot_id = 2)
TExprNode slot_ref_z;
slot_ref_z.node_type = TExprNodeType::SLOT_REF;
slot_ref_z.type = gen_type_desc(TPrimitiveType::VARCHAR);
slot_ref_z.num_children = 0;
slot_ref_z.__isset.slot_ref = true;
slot_ref_z.slot_ref.slot_id = 2; // z's slot_id (captured column)
slot_ref_z.slot_ref.tuple_id = 0;
slot_ref_z.__set_is_nullable(true);
ColumnRef* col_z = _objpool.add(new ColumnRef(slot_ref_z));
// Create length function call: length(z)
TExprNode length_node;
length_node.node_type = TExprNodeType::FUNCTION_CALL;
length_node.type = gen_type_desc(TPrimitiveType::INT);
length_node.num_children = 1;
length_node.__isset.fn = true;
length_node.fn.name.function_name = "length";
length_node.fn.fid = 30120;
length_node.fn.__isset.fid = true;
length_node.fn.binary_type = TFunctionBinaryType::BUILTIN;
auto* length_expr = _objpool.add(new VectorizedFunctionCallExpr(length_node));
length_expr->add_child(col_z);
// Create arithmetic expression: x + y + length(z)
TExprNode add_node1;
add_node1.opcode = TExprOpcode::ADD;
add_node1.child_type = TPrimitiveType::INT;
add_node1.node_type = TExprNodeType::BINARY_PRED;
add_node1.num_children = 2;
add_node1.__isset.opcode = true;
add_node1.__isset.child_type = true;
add_node1.type = gen_type_desc(TPrimitiveType::INT);
auto* add_expr1 = _objpool.add(VectorizedArithmeticExprFactory::from_thrift(add_node1));
add_expr1->add_child(col_x);
add_expr1->add_child(col_y);
TExprNode add_node2;
add_node2.opcode = TExprOpcode::ADD;
add_node2.child_type = TPrimitiveType::INT;
add_node2.node_type = TExprNodeType::BINARY_PRED;
add_node2.num_children = 2;
add_node2.__isset.opcode = true;
add_node2.__isset.child_type = true;
add_node2.type = gen_type_desc(TPrimitiveType::INT);
auto* add_expr2 = _objpool.add(VectorizedArithmeticExprFactory::from_thrift(add_node2));
add_expr2->add_child(add_expr1);
add_expr2->add_child(length_expr);
// Build lambda function: lambda_expr, arg_x, arg_y
lambda_func->add_child(add_expr2); // lambda expression
lambda_func->add_child(col_x); // argument x
lambda_func->add_child(col_y); // argument y
// Create ArrayMapExpr: array_map((x,y)->x + y + length(z), col1, col2)
ArrayMapExpr array_map_expr(array_type(TYPE_INT));
array_map_expr.clear_children();
array_map_expr.add_child(lambda_func);
array_map_expr.add_child(col1_expr);
array_map_expr.add_child(col2_expr);
ExprContext exprContext(&array_map_expr);
std::vector<ExprContext*> expr_ctxs = {&exprContext};
ASSERT_OK(Expr::prepare(expr_ctxs, &_runtime_state));
ASSERT_OK(Expr::open(expr_ctxs, &_runtime_state));
// Check LambdaFunction::prepare()
std::vector<SlotId> ids, arguments;
lambda_func->get_captured_slot_ids(&ids);
lambda_func->get_lambda_arguments_ids(&arguments);
ASSERT_TRUE(arguments.size() == 2 && arguments[0] == 100000 && arguments[1] == 100001);
ColumnPtr result = array_map_expr.evaluate(&exprContext, cur_chunk.get());
// Verify results
// For each row, the lambda function (x,y)->x + y + length(z) is applied to each element
// Row 0: x=1, y=10, length("abc")=3, result should be 1+10+3=14
// Row 0: x=2, y=20, length("abc")=3, result should be 2+20+3=25
// Row 1: x=3, y=30, length("def")=3, result should be 3+30+3=36
// Row 1: x=4, y=40, length("def")=3, result should be 4+40+3=47
// Row 2: x=5, y=50, length("ghi")=3, result should be 5+50+3=58
// Row 2: x=6, y=60, length("ghi")=3, result should be 6+60+3=69
ASSERT_EQ(3, result->size());
ASSERT_EQ(14, result->get(0).get_array()[0].get_int32());
ASSERT_EQ(25, result->get(0).get_array()[1].get_int32());
ASSERT_EQ(36, result->get(1).get_array()[0].get_int32());
ASSERT_EQ(47, result->get(1).get_array()[1].get_int32());
ASSERT_EQ(58, result->get(2).get_array()[0].get_int32());
ASSERT_EQ(69, result->get(2).get_array()[1].get_int32());
Expr::close(expr_ctxs, &_runtime_state);
}
} // namespace starrocks