Refactor type dispatch for predicate (#3341)

This commit is contained in:
mofei 2022-02-15 10:51:09 +08:00 committed by GitHub
parent b84bd52578
commit af0881614e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 125 additions and 203 deletions

View File

@ -2,6 +2,8 @@
#pragma once
#include <limits>
#include "column/binary_column.h"
#include "column/decimalv3_column.h"
#include "column/json_column.h"
@ -9,6 +11,7 @@
#include "column/object_column.h"
#include "column/vectorized_fwd.h"
#include "runtime/primitive_type.h"
#include "util/json.h"
namespace starrocks {
@ -301,5 +304,69 @@ template <>
struct ColumnTraits<TimestampValue> {
using ColumnType = TimestampColumn;
};
template <PrimitiveType ptype, typename = guard::Guard>
struct RunTimeTypeLimits {};
template <PrimitiveType ptype>
struct RunTimeTypeLimits<ptype, ArithmeticPTGuard<ptype>> {
using value_type = RunTimeCppType<ptype>;
static constexpr value_type min_value() { return std::numeric_limits<value_type>::lowest(); }
static constexpr value_type max_value() { return std::numeric_limits<value_type>::max(); }
};
template <PrimitiveType ptype>
struct RunTimeTypeLimits<ptype, BinaryPTGuard<ptype>> {
using value_type = RunTimeCppType<ptype>;
static constexpr value_type min_value() { return Slice(&_min, 0); }
static constexpr value_type max_value() { return Slice(&_max, 1); }
private:
static inline char _min = 0x00;
static inline char _max = 0xff;
};
template <>
struct RunTimeTypeLimits<TYPE_DATE> {
using value_type = RunTimeCppType<TYPE_DATE>;
static value_type min_value() { return DateValue::MIN_DATE_VALUE; }
static value_type max_value() { return DateValue::MAX_DATE_VALUE; }
};
template <>
struct RunTimeTypeLimits<TYPE_DATETIME> {
using value_type = RunTimeCppType<TYPE_DATETIME>;
static value_type min_value() { return TimestampValue::MIN_TIMESTAMP_VALUE; }
static value_type max_value() { return TimestampValue::MAX_TIMESTAMP_VALUE; }
};
template <>
struct RunTimeTypeLimits<TYPE_DECIMALV2> {
using value_type = RunTimeCppType<TYPE_DECIMALV2>;
static value_type min_value() { return DecimalV2Value::get_min_decimal(); }
static value_type max_value() { return DecimalV2Value::get_max_decimal(); }
};
template <PrimitiveType ptype>
struct RunTimeTypeLimits<ptype, DecimalPTGuard<ptype>> {
using value_type = RunTimeCppType<ptype>;
static constexpr value_type min_value() { return get_min_decimal<value_type>(); }
static constexpr value_type max_value() { return get_max_decimal<value_type>(); }
};
template <>
struct RunTimeTypeLimits<TYPE_JSON> {
using value_type = JsonValue;
static value_type min_value() { return JsonValue{vpack::Slice::minKeySlice()}; }
static value_type max_value() { return JsonValue{vpack::Slice::maxKeySlice()}; }
};
} // namespace vectorized
} // namespace starrocks

View File

@ -3,9 +3,13 @@
#include "exec/vectorized/olap_scan_prepare.h"
#include "column/type_traits.h"
#include "exprs/vectorized/in_const_predicate.hpp"
#include "gutil/map_util.h"
#include "runtime/date_value.hpp"
#include "runtime/descriptors.h"
#include "runtime/primitive_type.h"
#include "runtime/primitive_type_infra.h"
#include "storage/vectorized/column_predicate.h"
#include "storage/vectorized/predicate_parser.h"
@ -475,6 +479,37 @@ void OlapScanConjunctsManager::normalize_predicate(const SlotDescriptor& slot,
normalize_join_runtime_filter<SlotType, RangeValueType>(slot, range);
}
struct ColumnRangeBuilder {
template <PrimitiveType ptype>
std::nullptr_t operator()(OlapScanConjunctsManager* cm, const SlotDescriptor* slot,
std::map<std::string, ColumnValueRangeType>* column_value_ranges) {
if constexpr (ptype == TYPE_TIME || ptype == TYPE_NULL || ptype == TYPE_JSON || pt_is_float<ptype>) {
return nullptr;
} else {
// Treat tinyint and boolean as int
constexpr PrimitiveType real_type = ptype == TYPE_TINYINT || ptype == TYPE_BOOLEAN ? TYPE_INT : ptype;
using value_type = typename RunTimeTypeLimits<real_type>::value_type;
using RangeType = ColumnValueRange<value_type>;
const std::string& col_name = slot->col_name();
RangeType full_range(col_name, ptype, RunTimeTypeLimits<ptype>::min_value(),
RunTimeTypeLimits<ptype>::max_value());
if constexpr (pt_is_decimal<real_type>) {
full_range.set_precision(slot->type().precision);
full_range.set_scale(slot->type().scale);
}
ColumnValueRangeType& v = LookupOrInsert(column_value_ranges, col_name, full_range);
RangeType& range = boost::get<ColumnValueRange<value_type>>(v);
if constexpr (pt_is_decimal<real_type>) {
range.set_precision(slot->type().precision);
range.set_scale(slot->type().scale);
}
cm->normalize_predicate<real_type, value_type>(*slot, &range);
return nullptr;
}
}
};
Status OlapScanConjunctsManager::normalize_conjuncts() {
// Note: _normalized_conjuncts size must be equal to _conjunct_ctxs size,
// but HashJoinNode will push down predicate to OlapScanNode's _conjunct_ctxs,
@ -486,162 +521,8 @@ Status OlapScanConjunctsManager::normalize_conjuncts() {
// TODO(zhuming): if any of the normalized column range is empty, we can know that
// no row will be selected anymore and can return EOF directly.
for (auto& slot : tuple_desc->decoded_slots()) {
const std::string& col_name = slot->col_name();
PrimitiveType type = slot->type().type;
switch (type) {
case TYPE_TINYINT: {
// TYPE_TINYINT use int32_t to present
// because it's easy to be converted to string when building Olap fetch Query
using RangeType = ColumnValueRange<int32_t>;
RangeType full_range(col_name, type, std::numeric_limits<int8_t>::lowest(),
std::numeric_limits<int8_t>::max());
ColumnValueRangeType& v = LookupOrInsert(&column_value_ranges, col_name, full_range);
RangeType& range = boost::get<ColumnValueRange<int32_t>>(v);
normalize_predicate<TYPE_TINYINT, int32_t>(*slot, &range);
break;
}
case TYPE_BOOLEAN: {
// TYPE_BOOLEAN use int32_t to present
// because it's easy to be converted to string when building Olap fetch Query
using RangeType = ColumnValueRange<int32_t>;
RangeType full_range(col_name, type, 0, 1);
ColumnValueRangeType& v = LookupOrInsert(&column_value_ranges, col_name, full_range);
RangeType& range = boost::get<ColumnValueRange<int32_t>>(v);
normalize_predicate<TYPE_BOOLEAN, int32_t>(*slot, &range);
break;
}
case TYPE_SMALLINT: {
using RangeType = ColumnValueRange<int16_t>;
RangeType full_range(col_name, type, std::numeric_limits<int16_t>::lowest(),
std::numeric_limits<int16_t>::max());
ColumnValueRangeType& v = LookupOrInsert(&column_value_ranges, col_name, full_range);
RangeType& range = boost::get<ColumnValueRange<int16_t>>(v);
normalize_predicate<TYPE_SMALLINT, int16_t>(*slot, &range);
break;
}
case TYPE_INT: {
using RangeType = ColumnValueRange<int32_t>;
RangeType full_range(col_name, type, std::numeric_limits<int32_t>::lowest(),
std::numeric_limits<int32_t>::max());
ColumnValueRangeType& v = LookupOrInsert(&column_value_ranges, col_name, full_range);
RangeType& range = boost::get<ColumnValueRange<int32_t>>(v);
normalize_predicate<TYPE_INT, int32_t>(*slot, &range);
break;
}
case TYPE_BIGINT: {
using RangeType = ColumnValueRange<int64_t>;
RangeType full_range(col_name, type, std::numeric_limits<int64_t>::lowest(),
std::numeric_limits<int64_t>::max());
ColumnValueRangeType& v = LookupOrInsert(&column_value_ranges, col_name, full_range);
RangeType& range = boost::get<ColumnValueRange<int64_t>>(v);
normalize_predicate<TYPE_BIGINT, int64_t>(*slot, &range);
break;
}
case TYPE_LARGEINT: {
using RangeType = ColumnValueRange<int128_t>;
RangeType full_range(col_name, type, MIN_INT128, MAX_INT128);
ColumnValueRangeType& v = LookupOrInsert(&column_value_ranges, col_name, full_range);
RangeType& range = boost::get<ColumnValueRange<int128_t>>(v);
normalize_predicate<TYPE_LARGEINT, int128_t>(*slot, &range);
break;
}
case TYPE_CHAR:
// for a CHAR column, its `in` predicate will be represented as a
// `InConstPredicate<PrimitiveType::TYPE_VARCHAR>`, so here we mapping CHAR as VARCHAR.
[[fallthrough]];
case TYPE_VARCHAR: {
using RangeType = ColumnValueRange<Slice>;
static char min_char = 0x00;
static char max_char = (char)0xff;
RangeType full_range(col_name, type, Slice(&min_char, 0), Slice(&max_char, 1));
ColumnValueRangeType& v = LookupOrInsert(&column_value_ranges, col_name, full_range);
RangeType& range = boost::get<ColumnValueRange<Slice>>(v);
normalize_predicate<TYPE_VARCHAR, Slice>(*slot, &range);
break;
}
case TYPE_DATE: {
using RangeType = ColumnValueRange<DateValue>;
RangeType full_range(col_name, type, DateValue::MIN_DATE_VALUE, DateValue::MAX_DATE_VALUE);
ColumnValueRangeType& v = LookupOrInsert(&column_value_ranges, col_name, full_range);
RangeType& range = boost::get<ColumnValueRange<DateValue>>(v);
normalize_predicate<TYPE_DATE, DateValue>(*slot, &range);
break;
}
case TYPE_DATETIME: {
using RangeType = ColumnValueRange<TimestampValue>;
RangeType full_range(col_name, type, TimestampValue::MIN_TIMESTAMP_VALUE,
TimestampValue::MAX_TIMESTAMP_VALUE);
ColumnValueRangeType& v = LookupOrInsert(&column_value_ranges, col_name, full_range);
RangeType& range = boost::get<ColumnValueRange<TimestampValue>>(v);
normalize_predicate<TYPE_DATETIME, TimestampValue>(*slot, &range);
break;
}
case TYPE_DECIMALV2: {
using RangeType = ColumnValueRange<DecimalV2Value>;
RangeType full_range(col_name, type, DecimalV2Value::get_min_decimal(), DecimalV2Value::get_max_decimal());
ColumnValueRangeType& v = LookupOrInsert(&column_value_ranges, col_name, full_range);
RangeType& range = boost::get<ColumnValueRange<DecimalV2Value>>(v);
normalize_predicate<TYPE_DECIMALV2, DecimalV2Value>(*slot, &range);
break;
}
case TYPE_DECIMAL32: {
using DecimalValueType = int32_t;
using RangeType = ColumnValueRange<DecimalValueType>;
RangeType full_range(col_name, type, get_min_decimal<DecimalValueType>(),
get_max_decimal<DecimalValueType>());
full_range.set_precision(slot->type().precision);
full_range.set_scale(slot->type().scale);
ColumnValueRangeType& v = LookupOrInsert(&column_value_ranges, col_name, full_range);
RangeType& range = boost::get<RangeType>(v);
range.set_precision(slot->type().precision);
range.set_scale(slot->type().scale);
normalize_predicate<TYPE_DECIMAL32, DecimalValueType>(*slot, &range);
break;
}
case TYPE_DECIMAL64: {
using DecimalValueType = int64_t;
using RangeType = ColumnValueRange<DecimalValueType>;
RangeType full_range(col_name, type, get_min_decimal<DecimalValueType>(),
get_max_decimal<DecimalValueType>());
full_range.set_precision(slot->type().precision);
full_range.set_scale(slot->type().scale);
ColumnValueRangeType& v = LookupOrInsert(&column_value_ranges, col_name, full_range);
RangeType& range = boost::get<RangeType>(v);
range.set_precision(slot->type().precision);
range.set_scale(slot->type().scale);
normalize_predicate<TYPE_DECIMAL64, DecimalValueType>(*slot, &range);
break;
}
case TYPE_DECIMAL128: {
using DecimalValueType = int128_t;
using RangeType = ColumnValueRange<DecimalValueType>;
RangeType full_range(col_name, type, get_min_decimal<DecimalValueType>(),
get_max_decimal<DecimalValueType>());
full_range.set_precision(slot->type().precision);
full_range.set_scale(slot->type().scale);
ColumnValueRangeType& v = LookupOrInsert(&column_value_ranges, col_name, full_range);
RangeType& range = boost::get<RangeType>(v);
range.set_precision(slot->type().precision);
range.set_scale(slot->type().scale);
normalize_predicate<TYPE_DECIMAL128, DecimalValueType>(*slot, &range);
break;
}
case INVALID_TYPE:
case TYPE_NULL:
case TYPE_FLOAT:
case TYPE_DOUBLE:
case TYPE_BINARY:
case TYPE_DECIMAL:
case TYPE_STRUCT:
case TYPE_ARRAY:
case TYPE_MAP:
case TYPE_HLL:
case TYPE_TIME:
case TYPE_OBJECT:
case TYPE_PERCENTILE:
case TYPE_JSON:
break;
}
type_dispatch_predicate<std::nullptr_t>(slot->type().type, false, ColumnRangeBuilder(), this, slot,
&column_value_ranges);
}
return Status::OK();
}

View File

@ -48,6 +48,7 @@ public:
bool enable_column_expr_predicate = false);
private:
friend struct ColumnRangeBuilder;
Status normalize_conjuncts();
Status build_olap_filters();
Status build_scan_keys(bool unlimited, int32_t max_scan_key_num);

View File

@ -4,6 +4,7 @@
#include "column/array_column.h"
#include "column/column_hash.h"
#include "column/type_traits.h"
#include "util/raw_container.h"
namespace starrocks::vectorized {
@ -672,43 +673,11 @@ public:
size_t index;
if constexpr (!pt_is_binary<value_type>) {
if constexpr (pt_is_arithmetic<value_type>) {
if constexpr (is_min) {
result = std::numeric_limits<ResultType>::max();
} else {
result = std::numeric_limits<ResultType>::lowest();
}
} else if constexpr (pt_is_decimalv2<value_type>) {
if constexpr (is_min) {
result = DecimalV2Value::get_max_decimal();
} else {
result = DecimalV2Value::get_min_decimal();
}
} else if constexpr (pt_is_decimal<value_type>) {
if constexpr (is_min) {
result = get_max_decimal<ResultType>();
} else {
result = get_min_decimal<ResultType>();
}
} else if constexpr (pt_is_datetime<value_type>) {
if constexpr (is_min) {
result = TimestampValue::MAX_TIMESTAMP_VALUE;
} else {
result = TimestampValue::MIN_TIMESTAMP_VALUE;
}
} else if constexpr (pt_is_date<value_type>) {
if constexpr (is_min) {
result = DateValue::MAX_DATE_VALUE;
} else {
result = DateValue::MIN_DATE_VALUE;
}
if constexpr (is_min) {
result = RunTimeTypeLimits<value_type>::max_value();
} else {
LOG(ERROR) << "unhandled types other than arithmetic/time/decimal/string "
"for min and max";
DCHECK(false) << "other types than arithmetic/time/decimal/string is not "
"support min and max";
result = RunTimeTypeLimits<value_type>::min_value();
}
index = 0;
} else {
int j = 0;

View File

@ -112,7 +112,7 @@ struct BinaryPredicateBuilder {
Expr* VectorizedBinaryPredicateFactory::from_thrift(const TExprNode& node) {
PrimitiveType type = thrift_to_type(node.child_type);
return type_dispatch_predicate(type, BinaryPredicateBuilder(), node);
return type_dispatch_predicate<Expr*>(type, true, BinaryPredicateBuilder(), node);
}
} // namespace starrocks::vectorized

View File

@ -28,7 +28,7 @@ struct FilterBuilder {
};
JoinRuntimeFilter* RuntimeFilterHelper::create_join_runtime_filter(ObjectPool* pool, PrimitiveType type) {
JoinRuntimeFilter* filter = type_dispatch_filter(type, FilterBuilder());
JoinRuntimeFilter* filter = type_dispatch_filter(type, (JoinRuntimeFilter*)nullptr, FilterBuilder());
if (pool != nullptr && filter != nullptr) {
return pool->add(filter);
} else {
@ -84,7 +84,7 @@ JoinRuntimeFilter* RuntimeFilterHelper::create_runtime_bloom_filter(ObjectPool*
struct FilterIniter {
template <PrimitiveType ptype>
void operator()(const ColumnPtr& column, size_t column_offset, JoinRuntimeFilter* expr, bool eq_null) {
auto operator()(const ColumnPtr& column, size_t column_offset, JoinRuntimeFilter* expr, bool eq_null) {
using ColumnType = typename RunTimeTypeTraits<ptype>::ColumnType;
auto* filter = (RuntimeBloomFilter<ptype>*)(expr);
@ -107,12 +107,13 @@ struct FilterIniter {
filter->insert(&data_ptr[j]);
}
}
return nullptr;
}
};
Status RuntimeFilterHelper::fill_runtime_bloom_filter(const ColumnPtr& column, PrimitiveType type,
JoinRuntimeFilter* filter, size_t column_offset, bool eq_null) {
type_dispatch_filter(type, FilterIniter(), column, column_offset, filter, eq_null);
type_dispatch_filter(type, nullptr, FilterIniter(), column, column_offset, filter, eq_null);
return Status::OK();
}

View File

@ -129,23 +129,26 @@ auto type_dispatch_sortable(PrimitiveType ptype, Functor fun, Args... args) {
}
}
template <class Functor, class... Args>
auto type_dispatch_predicate(PrimitiveType ptype, Functor fun, Args... args) {
template <class Ret, class Functor, class... Args>
Ret type_dispatch_predicate(PrimitiveType ptype, bool assert, Functor fun, Args... args) {
switch (ptype) {
APPLY_FOR_ALL_SCALAR_TYPE(_TYPE_DISPATCH_CASE)
default:
CHECK(false) << "Unknown type: " << ptype;
__builtin_unreachable();
if (assert) {
CHECK(false) << "Unknown type: " << ptype;
__builtin_unreachable();
} else {
return Ret{};
}
}
}
template <class Functor, class... Args>
auto type_dispatch_filter(PrimitiveType ptype, Functor fun, Args... args) {
template <class Functor, class Ret, class... Args>
auto type_dispatch_filter(PrimitiveType ptype, Ret default_value, Functor fun, Args... args) {
switch (ptype) {
APPLY_FOR_ALL_SCALAR_TYPE(_TYPE_DISPATCH_CASE)
default:
CHECK(false) << "Unknown type: " << ptype;
__builtin_unreachable();
return default_value;
}
}