Signed-off-by: silverbullet233 <3675229+silverbullet233@users.noreply.github.com> Co-authored-by: eyes_on_me <nopainnofame@sina.com>
This commit is contained in:
parent
6ada019be0
commit
df8b4f31f3
|
|
@ -115,8 +115,9 @@ StatusOr<ColumnPtr> ArrayMapExpr::evaluate_lambda_expr(ExprContext* context, Chu
|
|||
// 2. check captured columns' size
|
||||
for (auto slot_id : capture_slot_ids) {
|
||||
DCHECK(slot_id > 0);
|
||||
auto captured_column = chunk->is_slot_exist(slot_id) ? chunk->get_column_by_slot_id(slot_id)
|
||||
: tmp_chunk->get_column_by_slot_id(slot_id);
|
||||
auto captured_column = tmp_chunk->is_slot_exist(slot_id) ? tmp_chunk->get_column_by_slot_id(slot_id)
|
||||
: chunk->get_column_by_slot_id(slot_id);
|
||||
|
||||
if (UNLIKELY(captured_column->size() < input_elements[0]->size())) {
|
||||
return Status::InternalError(fmt::format("The size of the captured column {} is less than array's size.",
|
||||
captured_column->get_name()));
|
||||
|
|
@ -176,8 +177,8 @@ StatusOr<ColumnPtr> ArrayMapExpr::evaluate_lambda_expr(ExprContext* context, Chu
|
|||
|
||||
// 4. prepare capture columns
|
||||
for (auto slot_id : capture_slot_ids) {
|
||||
auto captured_column = chunk->is_slot_exist(slot_id) ? chunk->get_column_by_slot_id(slot_id)
|
||||
: tmp_chunk->get_column_by_slot_id(slot_id);
|
||||
auto captured_column = tmp_chunk->is_slot_exist(slot_id) ? tmp_chunk->get_column_by_slot_id(slot_id)
|
||||
: chunk->get_column_by_slot_id(slot_id);
|
||||
if constexpr (independent_lambda_expr) {
|
||||
cur_chunk->append_column(captured_column, slot_id);
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -68,6 +68,9 @@
|
|||
|
||||
namespace starrocks {
|
||||
|
||||
// for ut only
|
||||
// Default constructor, intended for unit tests only (see the "for ut only" note on its declaration).
// The only explicit initialization is _obj_pool, which is set up with a newly allocated ObjectPool.
RuntimeState::RuntimeState() : _obj_pool(new ObjectPool()) {}
|
||||
|
||||
// for ut only
|
||||
RuntimeState::RuntimeState(const TUniqueId& fragment_instance_id, const TQueryOptions& query_options,
|
||||
const TQueryGlobals& query_globals, ExecEnv* exec_env)
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ constexpr int64_t kRpcHttpMinSize = ((1L << 31) - (1L << 10));
|
|||
class RuntimeState {
|
||||
public:
|
||||
// for ut only
|
||||
RuntimeState() = default;
|
||||
RuntimeState();
|
||||
// for ut only
|
||||
RuntimeState(const TUniqueId& fragment_instance_id, const TQueryOptions& query_options,
|
||||
const TQueryGlobals& query_globals, ExecEnv* exec_env);
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@
|
|||
#include <memory>
|
||||
|
||||
#include "butil/time.h"
|
||||
#include "column/binary_column.h"
|
||||
#include "column/column_helper.h"
|
||||
#include "column/fixed_length_column.h"
|
||||
#include "exprs/arithmetic_expr.h"
|
||||
|
|
@ -216,8 +217,8 @@ std::vector<Expr*> VectorizedLambdaFunctionExprTest::create_lambda_expr(ObjectPo
|
|||
ColumnRef* col6 = pool->add(new ColumnRef(slot_ref));
|
||||
slot_ref.slot_ref.slot_id = 1;
|
||||
ColumnRef* col7 = pool->add(new ColumnRef(slot_ref));
|
||||
add_expr->_children.push_back(col6);
|
||||
add_expr->_children.push_back(col7);
|
||||
add_expr->add_child(col6);
|
||||
add_expr->add_child(col7);
|
||||
lambda_func->add_child(add_expr);
|
||||
lambda_func->add_child(col5);
|
||||
lambda_funcs.push_back(lambda_func);
|
||||
|
|
@ -488,4 +489,170 @@ TEST_F(VectorizedLambdaFunctionExprTest, array_map_lambda_test_const_array) {
|
|||
}
|
||||
}
|
||||
|
||||
// Regression test for array_map when the input chunk already owns a column whose slot id
// collides with a slot id used internally while evaluating the lambda.
//
// Evaluates array_map((x, y) -> x + y + length(z), col1, col2) where:
//   - x / y are the lambda arguments (slot ids 100000 / 100001),
//   - z is a captured VARCHAR column of the input chunk (slot id 2),
//   - the input chunk additionally carries an unrelated column at slot id 100002, the id the
//     original author expected to clash with the common expression extracted from the lambda.
// The test passes when every output element equals x + y + length(z).
TEST_F(VectorizedLambdaFunctionExprTest, test_lambda_common_expr_slot_conflict) {
    // Builds a nullable SLOT_REF thrift node (tuple id 0, no children) for the given slot and type.
    auto make_slot_ref_node = [&](SlotId slot_id, TPrimitiveType::type primitive_type) {
        TExprNode node;
        node.node_type = TExprNodeType::SLOT_REF;
        node.type = gen_type_desc(primitive_type);
        node.num_children = 0;
        node.__isset.slot_ref = true;
        node.slot_ref.slot_id = slot_id;
        node.slot_ref.tuple_id = 0;
        node.__set_is_nullable(true);
        return node;
    };

    // Builds a two-operand INT addition expression owned by the fixture's object pool.
    auto make_int_add_expr = [&]() {
        TExprNode node;
        node.opcode = TExprOpcode::ADD;
        node.child_type = TPrimitiveType::INT;
        node.node_type = TExprNodeType::BINARY_PRED;
        node.num_children = 2;
        node.__isset.opcode = true;
        node.__isset.child_type = true;
        node.type = gen_type_desc(TPrimitiveType::INT);
        return _objpool.add(VectorizedArithmeticExprFactory::from_thrift(node));
    };

    // ---- Input chunk ----
    auto cur_chunk = std::make_shared<Chunk>();
    std::vector<int> vec_a = {1, 1, 1};
    cur_chunk->append_column(build_int_column(vec_a), 1);

    // String column consumed by length(); slot_id = 2 is the captured column z.
    auto string_col = BinaryColumn::create();
    string_col->append("abc");
    string_col->append("def");
    string_col->append("ghi");
    cur_chunk->append_column(string_col, 2);

    // Decoy column: its slot id is chosen to conflict with the slot id of the common
    // expression extracted from the lambda function.
    auto fake_col = BinaryColumn::create();
    fake_col->append("abc");
    fake_col->append("def");
    fake_col->append("ghi");
    cur_chunk->append_column(fake_col, 100002);

    // ---- Array inputs ----
    TypeDescriptor type_arr_int;
    type_arr_int.type = LogicalType::TYPE_ARRAY;
    type_arr_int.children.emplace_back();
    type_arr_int.children.back().type = LogicalType::TYPE_INT;

    // col1: [1,2], [3,4], [5,6]
    auto array1 = ColumnHelper::create_column(type_arr_int, false);
    array1->append_datum(DatumArray{Datum((int32_t)1), Datum((int32_t)2)});
    array1->append_datum(DatumArray{Datum((int32_t)3), Datum((int32_t)4)});
    array1->append_datum(DatumArray{Datum((int32_t)5), Datum((int32_t)6)});
    auto* col1_expr = new_fake_const_expr(std::move(array1), type_arr_int);

    // col2: [10,20], [30,40], [50,60]
    auto array2 = ColumnHelper::create_column(type_arr_int, false);
    array2->append_datum(DatumArray{Datum((int32_t)10), Datum((int32_t)20)});
    array2->append_datum(DatumArray{Datum((int32_t)30), Datum((int32_t)40)});
    array2->append_datum(DatumArray{Datum((int32_t)50), Datum((int32_t)60)});
    auto* col2_expr = new_fake_const_expr(std::move(array2), type_arr_int);

    // ---- Lambda function node: (x, y) -> x + y + length(z) ----
    TExprNode tlambda_func;
    tlambda_func.opcode = TExprOpcode::ADD;
    tlambda_func.child_type = TPrimitiveType::INT;
    tlambda_func.node_type = TExprNodeType::LAMBDA_FUNCTION_EXPR;
    tlambda_func.num_children = 3; // lambda_expr + 2 arguments
    tlambda_func.__isset.opcode = true;
    tlambda_func.__isset.child_type = true;
    tlambda_func.type = gen_type_desc(TPrimitiveType::INT);
    LambdaFunction* lambda_func = _objpool.add(new LambdaFunction(tlambda_func));

    // Lambda arguments x and y.
    ColumnRef* col_x = _objpool.add(new ColumnRef(make_slot_ref_node(100000, TPrimitiveType::INT)));
    ColumnRef* col_y = _objpool.add(new ColumnRef(make_slot_ref_node(100001, TPrimitiveType::INT)));

    // Captured column reference z (slot_id = 2).
    ColumnRef* col_z = _objpool.add(new ColumnRef(make_slot_ref_node(2, TPrimitiveType::VARCHAR)));

    // length(z)
    TExprNode length_node;
    length_node.node_type = TExprNodeType::FUNCTION_CALL;
    length_node.type = gen_type_desc(TPrimitiveType::INT);
    length_node.num_children = 1;
    length_node.__isset.fn = true;
    length_node.fn.name.function_name = "length";
    length_node.fn.fid = 30120;
    length_node.fn.__isset.fid = true;
    length_node.fn.binary_type = TFunctionBinaryType::BUILTIN;
    auto* length_expr = _objpool.add(new VectorizedFunctionCallExpr(length_node));
    length_expr->add_child(col_z);

    // (x + y)
    auto* add_expr1 = make_int_add_expr();
    add_expr1->add_child(col_x);
    add_expr1->add_child(col_y);

    // (x + y) + length(z)
    auto* add_expr2 = make_int_add_expr();
    add_expr2->add_child(add_expr1);
    add_expr2->add_child(length_expr);

    // Children order of the lambda: body first, then its arguments.
    lambda_func->add_child(add_expr2); // lambda expression
    lambda_func->add_child(col_x);     // argument x
    lambda_func->add_child(col_y);     // argument y

    // ---- array_map((x,y) -> x + y + length(z), col1, col2) ----
    ArrayMapExpr array_map_expr(array_type(TYPE_INT));
    array_map_expr.clear_children();
    array_map_expr.add_child(lambda_func);
    array_map_expr.add_child(col1_expr);
    array_map_expr.add_child(col2_expr);

    ExprContext exprContext(&array_map_expr);
    std::vector<ExprContext*> expr_ctxs = {&exprContext};
    ASSERT_OK(Expr::prepare(expr_ctxs, &_runtime_state));
    ASSERT_OK(Expr::open(expr_ctxs, &_runtime_state));

    // Check what LambdaFunction::prepare() recorded. Only the argument ids are asserted;
    // the captured ids are fetched but deliberately not pinned to exact values here.
    std::vector<SlotId> ids, arguments;
    lambda_func->get_captured_slot_ids(&ids);
    lambda_func->get_lambda_arguments_ids(&arguments);

    // Separate assertions so a failure reports which condition was violated.
    ASSERT_EQ(2, arguments.size());
    ASSERT_EQ(100000, arguments[0]);
    ASSERT_EQ(100001, arguments[1]);

    ColumnPtr result = array_map_expr.evaluate(&exprContext, cur_chunk.get());

    // Expected output: each element is x + y + length(z), and every z has length 3.
    //   Row 0: 1+10+3=14, 2+20+3=25
    //   Row 1: 3+30+3=36, 4+40+3=47
    //   Row 2: 5+50+3=58, 6+60+3=69
    const std::vector<std::vector<int32_t>> expected = {{14, 25}, {36, 47}, {58, 69}};

    ASSERT_EQ(expected.size(), result->size());
    for (size_t row = 0; row < expected.size(); ++row) {
        // Keep the Datum alive in a local: get_array() is used through a reference below.
        const Datum row_datum = result->get(row);
        const auto& elements = row_datum.get_array();
        // Guard the element count before indexing so a short result fails cleanly.
        ASSERT_EQ(expected[row].size(), elements.size()) << "row=" << row;
        for (size_t i = 0; i < expected[row].size(); ++i) {
            ASSERT_EQ(expected[row][i], elements[i].get_int32()) << "row=" << row << ", element=" << i;
        }
    }

    Expr::close(expr_ctxs, &_runtime_state);
}
|
||||
|
||||
} // namespace starrocks
|
||||
|
|
|
|||
Loading…
Reference in New Issue