[BugFix] Fix group by compressed key cause wrong result on decimal (backport #62022) (#62147)

Signed-off-by: stdpain <drfeng08@gmail.com>
Co-authored-by: stdpain <34912776+stdpain@users.noreply.github.com>
This commit is contained in:
mergify[bot] 2025-08-20 12:17:12 +00:00 committed by GitHub
parent 8d11089dcb
commit d90d3bc5b6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 195 additions and 10 deletions

View File

@ -229,7 +229,7 @@ public:
}
template <typename T>
Status do_visit(const DecimalV3Column<T>& column) {
Status do_visit(DecimalV3Column<T>* column) {
bit_decompress<DecimalV3Column<T>, T>(column);
return Status::OK();
}

View File

@ -16,21 +16,18 @@
#include <algorithm>
#include <memory>
#include <optional>
#include <type_traits>
#include <utility>
#include "column/chunk.h"
#include "column/column_helper.h"
#include "column/vectorized_fwd.h"
#include "common/config.h"
#include "common/logging.h"
#include "common/status.h"
#include "exec/agg_runtime_filter_builder.h"
#include "exec/aggregate/agg_hash_variant.h"
#include "exec/aggregate/agg_profile.h"
#include "exec/exec_node.h"
#include "exec/limited_pipeline_chunk_buffer.h"
#include "exec/pipeline/operator.h"
#include "exprs/agg/agg_state_if.h"
#include "exprs/agg/agg_state_merge.h"
@ -1362,6 +1359,7 @@ bool could_apply_bitcompress_opt(
const std::vector<std::optional<std::pair<VectorizedLiteral*, VectorizedLiteral*>>>& ranges,
std::vector<std::any>& base, std::vector<int>& used_bytes, size_t* max_size, bool* has_null) {
size_t accumulated = 0;
size_t accumulated_fixed_length_bits = 0;
for (size_t i = 0; i < group_by_types.size(); i++) {
size_t size = 0;
// 1 bytes for null flag.
@ -1376,6 +1374,7 @@ bool could_apply_bitcompress_opt(
size_t fixed_base_size = get_size_of_fixed_length_type(ltype);
if (fixed_base_size == 0) return false;
accumulated_fixed_length_bits += fixed_base_size * 8;
if (!ranges[i].has_value()) {
return false;
@ -1389,8 +1388,28 @@ bool could_apply_bitcompress_opt(
accumulated += size;
used_bytes[i] = accumulated;
}
*max_size = accumulated;
return true;
auto get_level = [](size_t used_bits) {
if (used_bits <= sizeof(uint8_t) * 8)
return 1;
else if (used_bits <= sizeof(uint16_t) * 8)
return 2;
else if (used_bits <= sizeof(uint32_t) * 8)
return 3;
else if (used_bits <= sizeof(uint64_t) * 8)
return 4;
else if (used_bits <= sizeof(int128_t) * 8)
return 5;
else
return 6;
};
// If they are at the same level, grouping by compressed key will not optimize performance, so we disable it.
// eg: For example, two int32 values both have a threshold of 0-2^32, so they need to use group by int64.
// In this case, there will be no optimization effect. We disable this situation.
if (get_level(accumulated_fixed_length_bits) > get_level(accumulated)) {
*max_size = accumulated;
return true;
}
return false;
}
bool is_group_columns_fixed_size(std::vector<ColumnType>& group_by_types, size_t* max_size, bool* has_null) {
@ -1489,10 +1508,7 @@ typename HashVariantType::Type Aggregator::_try_to_apply_compressed_key_opt(type
if (could_apply_bitcompress_opt(_group_by_types, _ranges, bases, used_bits, &new_max_bit_size,
&has_null_column)) {
if (new_max_bit_size <= 8 && _group_by_types.size() == 1) {
type = _aggr_phase == AggrPhase1 ? HashVariantType::Type::phase1_slice_cx1
: HashVariantType::Type::phase2_slice_cx1;
} else if (_group_by_types.size() > 1) {
if (_group_by_types.size() > 0) {
if (new_max_bit_size <= 8) {
type = _aggr_phase == AggrPhase1 ? HashVariantType::Type::phase1_slice_cx1
: HashVariantType::Type::phase2_slice_cx1;

View File

@ -117,6 +117,7 @@ set(EXEC_FILES
./exprs/agg/data_sketch/ds_theta_test.cpp
./exprs/agg/json_each_test.cpp
./exprs/agg/aggregate_test.cpp
./exprs/agg/agg_compressed_key_test.cpp
./exprs/arithmetic_expr_test.cpp
./exprs/arithmetic_operation_test.cpp
./exprs/array_element_expr_test.cpp

View File

@ -0,0 +1,117 @@
// Copyright 2021-present StarRocks, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "column/column_helper.h"
#include "common/object_pool.h"
#include "exec/aggregator.h"
#include "exprs/literal.h"
#include "runtime/types.h"
#include "types/logical_type.h"
namespace starrocks {
bool could_apply_bitcompress_opt(
const std::vector<ColumnType>& group_by_types,
const std::vector<std::optional<std::pair<VectorizedLiteral*, VectorizedLiteral*>>>& ranges,
std::vector<std::any>& base, std::vector<int>& used_bytes, size_t* max_size, bool* has_null);
TEST(AggCompressedKey, could_bound) {
// group by 1 columns
{
ObjectPool pool;
std::vector<ColumnType> groupby;
std::optional<std::pair<VectorizedLiteral*, VectorizedLiteral*>> range;
std::vector<std::any> bases;
std::vector<int> used_bytes;
size_t max_size;
bool has_null;
bases.resize(1);
used_bytes.resize(1);
auto type1 = TypeDescriptor(TYPE_INT);
groupby.emplace_back(type1, false);
std::vector<std::optional<std::pair<VectorizedLiteral*, VectorizedLiteral*>>> ranges;
auto* min = pool.add(new VectorizedLiteral(ColumnHelper::create_const_column<TYPE_INT>(0, 1), type1));
auto* max = pool.add(new VectorizedLiteral(ColumnHelper::create_const_column<TYPE_INT>(100, 1), type1));
range = {min, max};
ranges.emplace_back(range);
bool res = could_apply_bitcompress_opt(groupby, ranges, bases, used_bytes, &max_size, &has_null);
EXPECT_EQ(max_size, 7);
ASSERT_EQ(res, true);
}
// group by 2 columns
{
ObjectPool pool;
std::vector<ColumnType> groupby;
std::optional<std::pair<VectorizedLiteral*, VectorizedLiteral*>> range;
std::vector<std::any> bases;
std::vector<int> used_bytes;
size_t max_size;
bool has_null;
bases.resize(2);
used_bytes.resize(2);
auto type1 = TypeDescriptor(TYPE_INT);
groupby.emplace_back(type1, false);
groupby.emplace_back(type1, true);
std::vector<std::optional<std::pair<VectorizedLiteral*, VectorizedLiteral*>>> ranges;
auto* min = pool.add(new VectorizedLiteral(ColumnHelper::create_const_column<TYPE_INT>(0, 1), type1));
auto* max = pool.add(new VectorizedLiteral(ColumnHelper::create_const_column<TYPE_INT>(100, 1), type1));
range = {min, max};
ranges.emplace_back(range);
ranges.emplace_back(range);
bool res = could_apply_bitcompress_opt(groupby, ranges, bases, used_bytes, &max_size, &has_null);
EXPECT_EQ(max_size, 15);
ASSERT_EQ(res, true);
}
// group by decimal columns
{
ObjectPool pool;
std::vector<ColumnType> groupby;
std::optional<std::pair<VectorizedLiteral*, VectorizedLiteral*>> range;
std::vector<std::any> bases;
std::vector<int> used_bytes;
size_t max_size;
bool has_null;
bases.resize(2);
used_bytes.resize(2);
auto type1 = TypeDescriptor::create_decimalv3_type(TYPE_DECIMAL128, 8, 4);
groupby.emplace_back(type1, false);
groupby.emplace_back(type1, true);
std::vector<std::optional<std::pair<VectorizedLiteral*, VectorizedLiteral*>>> ranges;
auto* min = pool.add(
new VectorizedLiteral(ColumnHelper::create_const_decimal_column<TYPE_DECIMAL128>(0, 8, 4, 1), type1));
auto* max = pool.add(
new VectorizedLiteral(ColumnHelper::create_const_decimal_column<TYPE_DECIMAL128>(100, 8, 4, 1), type1));
range = {min, max};
ranges.emplace_back(range);
ranges.emplace_back(range);
bool res = could_apply_bitcompress_opt(groupby, ranges, bases, used_bytes, &max_size, &has_null);
EXPECT_EQ(max_size, 15);
ASSERT_EQ(res, true);
}
}
} // namespace starrocks

View File

@ -348,6 +348,46 @@ select c4, sum(c1) from all_decimal group by 1 order by 1, 2 limit 1;
-- result:
0.00000 0.00
-- !result
select c1, c2, sum(c1) from all_decimal group by 1,2 order by 1,2,3 limit 1;
-- result:
0.00 0.00 0.00
-- !result
select c1, c3, sum(c1) from all_decimal group by 1,2 order by 1,2,3 limit 1;
-- result:
0.00 0E-9 0.00
-- !result
select c1, c4, sum(c1) from all_decimal group by 1,2 order by 1,2,3 limit 1;
-- result:
0.00 0.00000 0.00
-- !result
select c2, c3, sum(c1) from all_decimal group by 1,2 order by 1,2,3 limit 1;
-- result:
0.00 0E-9 0.00
-- !result
select c2, c4, sum(c1) from all_decimal group by 1,2 order by 1,2,3 limit 1;
-- result:
0.00 0.00000 0.00
-- !result
select c3, c4, sum(c1) from all_decimal group by 1,2 order by 1,2,3 limit 1;
-- result:
0E-9 0.00000 0.00
-- !result
select c1, c2, c3, sum(c1) from all_decimal group by 1,2,3 order by 1,2,3,4 limit 1;
-- result:
0.00 0.00 0E-9 0.00
-- !result
select c1, c2, c4, sum(c1) from all_decimal group by 1,2,3 order by 1,2,3,4 limit 1;
-- result:
0.00 0.00 0.00000 0.00
-- !result
select c2, c3, c4, sum(c1) from all_decimal group by 1,2,3 order by 1,2,3,4 limit 1;
-- result:
0.00 0E-9 0.00000 0.00
-- !result
select c1, c2, c3, c4, sum(c1) from all_decimal group by 1,2,3,4 order by 1,2,3,4,5 limit 1;
-- result:
0.00 0.00 0E-9 0.00000 0.00
-- !result
create table all_numbers_t0 (
c1 tinyint,
c2 smallint,

View File

@ -127,6 +127,17 @@ select c2, sum(c1) from all_decimal group by 1 order by 1, 2 limit 1;
select c3, sum(c1) from all_decimal group by 1 order by 1, 2 limit 1;
select c4, sum(c1) from all_decimal group by 1 order by 1, 2 limit 1;
select c1, c2, sum(c1) from all_decimal group by 1,2 order by 1,2,3 limit 1;
select c1, c3, sum(c1) from all_decimal group by 1,2 order by 1,2,3 limit 1;
select c1, c4, sum(c1) from all_decimal group by 1,2 order by 1,2,3 limit 1;
select c2, c3, sum(c1) from all_decimal group by 1,2 order by 1,2,3 limit 1;
select c2, c4, sum(c1) from all_decimal group by 1,2 order by 1,2,3 limit 1;
select c3, c4, sum(c1) from all_decimal group by 1,2 order by 1,2,3 limit 1;
select c1, c2, c3, sum(c1) from all_decimal group by 1,2,3 order by 1,2,3,4 limit 1;
select c1, c2, c4, sum(c1) from all_decimal group by 1,2,3 order by 1,2,3,4 limit 1;
select c2, c3, c4, sum(c1) from all_decimal group by 1,2,3 order by 1,2,3,4 limit 1;
select c1, c2, c3, c4, sum(c1) from all_decimal group by 1,2,3,4 order by 1,2,3,4,5 limit 1;
-- int overflow
create table all_numbers_t0 (
c1 tinyint,