Signed-off-by: Murphy <mofei@starrocks.com> Co-authored-by: Murphy <96611012+murphyatwork@users.noreply.github.com> Co-authored-by: Murphy <mofei@starrocks.com>
This commit is contained in:
parent
c8e85a33ae
commit
25adbac4e0
|
|
@ -1598,7 +1598,7 @@ StatusOr<ColumnPtr> TimeFunctions::hour_from_unixtime(FunctionContext* context,
|
|||
auto ctz = context->state()->timezone_obj();
|
||||
auto size = columns[0]->size();
|
||||
ColumnViewer<TYPE_BIGINT> data_column(columns[0]);
|
||||
ColumnBuilder<TYPE_INT> result(size);
|
||||
ColumnBuilder<TYPE_TINYINT> result(size);
|
||||
for (int row = 0; row < size; ++row) {
|
||||
if (data_column.is_null(row)) {
|
||||
result.append_null();
|
||||
|
|
@ -3892,4 +3892,4 @@ StatusOr<ColumnPtr> TimeFunctions::time_format(FunctionContext* context, const s
|
|||
|
||||
} // namespace starrocks
|
||||
|
||||
#include "gen_cpp/opcode/TimeFunctions.inc"
|
||||
#include "gen_cpp/opcode/TimeFunctions.inc"
|
||||
|
|
|
|||
|
|
@ -4541,7 +4541,7 @@ TEST_F(TimeFunctionsTest, hourFromUnixTime) {
|
|||
columns.emplace_back(tc);
|
||||
ColumnPtr result = TimeFunctions::hour_from_unixtime(_utils->get_fn_ctx(), columns).value();
|
||||
|
||||
auto hours = ColumnHelper::cast_to<TYPE_INT>(result);
|
||||
auto hours = ColumnHelper::cast_to<TYPE_TINYINT>(result);
|
||||
for (size_t i = 0; i < sizeof(expected) / sizeof(expected[0]); ++i) {
|
||||
EXPECT_EQ(expected[i], hours->get_data()[i]) << "Failed for basic positive at index " << i;
|
||||
}
|
||||
|
|
@ -4574,7 +4574,7 @@ TEST_F(TimeFunctionsTest, hourFromUnixTime) {
|
|||
columns.emplace_back(tc);
|
||||
ColumnPtr result = TimeFunctions::hour_from_unixtime(_utils->get_fn_ctx(), columns).value();
|
||||
|
||||
auto hours = ColumnHelper::cast_to<TYPE_INT>(result);
|
||||
auto hours = ColumnHelper::cast_to<TYPE_TINYINT>(result);
|
||||
for (size_t i = 0; i < sizeof(expected_negative) / sizeof(expected_negative[0]); ++i) {
|
||||
EXPECT_EQ(expected_negative[i], hours->get_data()[i])
|
||||
<< "Failed for timezone offset at index " << i << " with value " << tc->get_data()[i];
|
||||
|
|
@ -4609,7 +4609,7 @@ TEST_F(TimeFunctionsTest, hourFromUnixTime) {
|
|||
columns.emplace_back(tc);
|
||||
ColumnPtr result = TimeFunctions::hour_from_unixtime(_utils->get_fn_ctx(), columns).value();
|
||||
|
||||
auto hours = ColumnHelper::cast_to<TYPE_INT>(result);
|
||||
auto hours = ColumnHelper::cast_to<TYPE_TINYINT>(result);
|
||||
for (size_t i = 0; i < sizeof(expected_mixed) / sizeof(expected_mixed[0]); ++i) {
|
||||
EXPECT_EQ(expected_mixed[i], hours->get_data()[i])
|
||||
<< "Failed for mixed timezone offset at index " << i << " with value " << tc->get_data()[i];
|
||||
|
|
@ -4646,14 +4646,14 @@ TEST_F(TimeFunctionsTest, hourFromUnixTime) {
|
|||
ASSERT_EQ(8, nullable_result->size());
|
||||
|
||||
// Check that results are in correct order
|
||||
EXPECT_EQ(3, nullable_result->get(0).get_int32()); // 0 -> hour 3
|
||||
EXPECT_TRUE(nullable_result->is_null(1)); // null
|
||||
EXPECT_EQ(4, nullable_result->get(2).get_int32()); // 3600 -> hour 4
|
||||
EXPECT_TRUE(nullable_result->is_null(3)); // null
|
||||
EXPECT_EQ(5, nullable_result->get(4).get_int32()); // 7200 -> hour 5
|
||||
EXPECT_EQ(2, nullable_result->get(5).get_int32()); // 82800 -> hour 2 (next day)
|
||||
EXPECT_TRUE(nullable_result->is_null(6)); // null
|
||||
EXPECT_EQ(6, nullable_result->get(7).get_int32()); // 10800 -> hour 6
|
||||
EXPECT_EQ(3, nullable_result->get(0).get_int8()); // 0 -> hour 3
|
||||
EXPECT_TRUE(nullable_result->is_null(1)); // null
|
||||
EXPECT_EQ(4, nullable_result->get(2).get_int8()); // 3600 -> hour 4
|
||||
EXPECT_TRUE(nullable_result->is_null(3)); // null
|
||||
EXPECT_EQ(5, nullable_result->get(4).get_int8()); // 7200 -> hour 5
|
||||
EXPECT_EQ(2, nullable_result->get(5).get_int8()); // 82800 -> hour 2 (next day)
|
||||
EXPECT_TRUE(nullable_result->is_null(6)); // null
|
||||
EXPECT_EQ(6, nullable_result->get(7).get_int8()); // 10800 -> hour 6
|
||||
}
|
||||
|
||||
// Test 5: Edge cases for hour wrapping with timezone offset and negative input (should return null)
|
||||
|
|
@ -4696,7 +4696,7 @@ TEST_F(TimeFunctionsTest, hourFromUnixTime) {
|
|||
// Check non-negative cases
|
||||
for (size_t i = 0; i < 4; ++i) {
|
||||
EXPECT_FALSE(nullable_result->is_null(i)) << "Unexpected null at index " << i;
|
||||
EXPECT_EQ(expected_edge[i], nullable_result->get(i).get_int32())
|
||||
EXPECT_EQ(expected_edge[i], nullable_result->get(i).get_int8())
|
||||
<< "Failed for edge case with timezone offset at index " << i << " with value "
|
||||
<< tc->get(i).get_int64();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ import java.util.Arrays;
|
|||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.function.Predicate;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class SimplifiedPredicateRule extends BottomUpScalarOperatorRewriteRule {
|
||||
|
|
@ -576,6 +577,19 @@ public class SimplifiedPredicateRule extends BottomUpScalarOperatorRewriteRule {
|
|||
ConstantOperator.createVarchar(mergePath)), child.getFunction());
|
||||
}
|
||||
|
||||
private static ScalarOperator lookupChild(ScalarOperator call, Predicate<ScalarOperator> predicate) {
|
||||
if (predicate.test(call)) {
|
||||
return call;
|
||||
}
|
||||
for (ScalarOperator child : call.getChildren()) {
|
||||
ScalarOperator res = lookupChild(child, predicate);
|
||||
if (res != null) {
|
||||
return res;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
// Simplify hour(from_unixtime(ts)) to hour_from_unixtime(ts)
|
||||
// Also simplify hour(to_datetime(ts)) and hour(to_datetime(ts, 0)) to hour_from_unixtime(ts)
|
||||
private static ScalarOperator simplifiedHourFromUnixTime(CallOperator call) {
|
||||
|
|
@ -583,31 +597,28 @@ public class SimplifiedPredicateRule extends BottomUpScalarOperatorRewriteRule {
|
|||
return call;
|
||||
}
|
||||
|
||||
ScalarOperator child = call.getChild(0);
|
||||
if (!(child instanceof CallOperator)) {
|
||||
return call;
|
||||
}
|
||||
|
||||
CallOperator childCall = (CallOperator) child;
|
||||
String childFnName = childCall.getFnName();
|
||||
|
||||
// Case 1: hour(from_unixtime(ts)) -> hour_from_unixtime(ts)
|
||||
if (FunctionSet.FROM_UNIXTIME.equalsIgnoreCase(childFnName)) {
|
||||
ScalarOperator fromUnixTime = lookupChild(call,
|
||||
x -> x instanceof CallOperator &&
|
||||
((CallOperator) x).getFnName().equalsIgnoreCase(FunctionSet.FROM_UNIXTIME));
|
||||
if (fromUnixTime != null) {
|
||||
// Keep original behavior: only succeeds when argument list matches hour_from_unixtime signature
|
||||
Type[] argTypes = childCall.getChildren().stream().map(ScalarOperator::getType).toArray(Type[]::new);
|
||||
Type[] argTypes = fromUnixTime.getChildren().stream().map(ScalarOperator::getType).toArray(Type[]::new);
|
||||
Function fn =
|
||||
Expr.getBuiltinFunction(FunctionSet.HOUR_FROM_UNIXTIME, argTypes, Function.CompareMode.IS_IDENTICAL);
|
||||
|
||||
Expr.getBuiltinFunction(FunctionSet.HOUR_FROM_UNIXTIME, argTypes,
|
||||
Function.CompareMode.IS_IDENTICAL);
|
||||
if (fn == null) {
|
||||
return call;
|
||||
}
|
||||
|
||||
return new CallOperator(FunctionSet.HOUR_FROM_UNIXTIME, call.getType(), childCall.getChildren(), fn);
|
||||
return new CallOperator(FunctionSet.HOUR_FROM_UNIXTIME, call.getType(), fromUnixTime.getChildren(), fn);
|
||||
}
|
||||
|
||||
// Case 2: hour(to_datetime(ts)) or hour(to_datetime(ts, 0)) -> hour_from_unixtime(ts)
|
||||
if (FunctionSet.TO_DATETIME.equalsIgnoreCase(childFnName)) {
|
||||
List<ScalarOperator> args = childCall.getChildren();
|
||||
ScalarOperator toDatetime = lookupChild(call,
|
||||
x -> x instanceof CallOperator &&
|
||||
((CallOperator) x).getFnName().equalsIgnoreCase(FunctionSet.TO_DATETIME));
|
||||
if (toDatetime != null) {
|
||||
List<ScalarOperator> args = toDatetime.getChildren();
|
||||
ScalarOperator tsArg;
|
||||
ScalarOperator unixtimeArgForHour;
|
||||
|
||||
|
|
|
|||
|
|
@ -17,21 +17,21 @@ package com.starrocks.sql.optimizer.rewrite.scalar;
|
|||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.starrocks.analysis.BinaryType;
|
||||
import com.starrocks.catalog.FunctionSet;
|
||||
import com.starrocks.catalog.Type;
|
||||
import com.starrocks.sql.optimizer.operator.OperatorType;
|
||||
import com.starrocks.sql.optimizer.operator.scalar.BinaryPredicateOperator;
|
||||
import com.starrocks.sql.optimizer.operator.scalar.CallOperator;
|
||||
import com.starrocks.sql.optimizer.operator.scalar.CaseWhenOperator;
|
||||
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
|
||||
import com.starrocks.sql.optimizer.operator.scalar.ConstantOperator;
|
||||
import com.starrocks.sql.optimizer.operator.scalar.LikePredicateOperator;
|
||||
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
|
||||
import com.starrocks.sql.plan.PlanTestBase;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
public class SimplifiedPredicateRuleTest {
|
||||
public class SimplifiedPredicateRuleTest extends PlanTestBase {
|
||||
private static final ConstantOperator OI_NULL = ConstantOperator.createNull(Type.INT);
|
||||
private static final ConstantOperator OI_100 = ConstantOperator.createInt(100);
|
||||
private static final ConstantOperator OI_200 = ConstantOperator.createInt(200);
|
||||
|
|
@ -42,6 +42,20 @@ public class SimplifiedPredicateRuleTest {
|
|||
|
||||
private SimplifiedPredicateRule rule = new SimplifiedPredicateRule();
|
||||
|
||||
@BeforeAll
|
||||
public static void beforeAll() throws Exception {
|
||||
starRocksAssert.withTable("CREATE TABLE IF NOT EXISTS `test_timestamp` (\n" +
|
||||
" `id` bigint NULL COMMENT \"\",\n" +
|
||||
" `ts` bigint NULL COMMENT \"unix timestamp\"\n" +
|
||||
") ENGINE=OLAP\n" +
|
||||
"DUPLICATE KEY(`id`)\n" +
|
||||
"DISTRIBUTED BY HASH(`id`) BUCKETS 3\n" +
|
||||
"PROPERTIES (\n" +
|
||||
"\"replication_num\" = \"1\",\n" +
|
||||
"\"in_memory\" = \"false\"\n" +
|
||||
");");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void applyCaseWhen() {
|
||||
CaseWhenOperator cwo1 = new CaseWhenOperator(Type.INT, new ColumnRefOperator(1, Type.INT, "id", true), null,
|
||||
|
|
@ -105,113 +119,32 @@ public class SimplifiedPredicateRuleTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void applyHourFromUnixTime() {
|
||||
// Test hour(from_unixtime(ts)) -> hour_from_unixtime(ts)
|
||||
ColumnRefOperator tsColumn = new ColumnRefOperator(1, Type.BIGINT, "ts", true);
|
||||
public void applyHourFromUnixTime() throws Exception {
|
||||
starRocksAssert.query("SELECT hour(from_unixtime(ts)) FROM test_timestamp")
|
||||
.explainContains("hour_from_unixtime");
|
||||
|
||||
// Create from_unixtime(ts) call
|
||||
CallOperator fromUnixTimeCall = new CallOperator(FunctionSet.FROM_UNIXTIME, Type.VARCHAR,
|
||||
Lists.newArrayList(tsColumn), null);
|
||||
starRocksAssert.query("SELECT hour(ts) FROM test_timestamp")
|
||||
.explainWithout("hour_from_unixtime");
|
||||
|
||||
// Create hour(from_unixtime(ts)) call
|
||||
CallOperator hourCall = new CallOperator(FunctionSet.HOUR, Type.TINYINT,
|
||||
Lists.newArrayList(fromUnixTimeCall), null);
|
||||
|
||||
ScalarOperator result = rule.apply(hourCall, null);
|
||||
|
||||
// Verify the result is hour_from_unixtime(ts)
|
||||
assertEquals(OperatorType.CALL, result.getOpType());
|
||||
CallOperator resultCall = (CallOperator) result;
|
||||
assertEquals(FunctionSet.HOUR_FROM_UNIXTIME, resultCall.getFnName());
|
||||
assertEquals(1, resultCall.getChildren().size());
|
||||
assertEquals(tsColumn, resultCall.getChild(0));
|
||||
|
||||
// Test that hour(ts) is not optimized (not from_unixtime)
|
||||
CallOperator simpleHourCall = new CallOperator(FunctionSet.HOUR, Type.TINYINT,
|
||||
Lists.newArrayList(tsColumn), null);
|
||||
ScalarOperator simpleResult = rule.apply(simpleHourCall, null);
|
||||
assertEquals(simpleHourCall, simpleResult);
|
||||
|
||||
// Test that hour(from_unixtime(ts, format)) is not optimized (multiple arguments)
|
||||
CallOperator fromUnixTimeCall2 = new CallOperator(FunctionSet.FROM_UNIXTIME, Type.VARCHAR,
|
||||
Lists.newArrayList(tsColumn, ConstantOperator.createVarchar("format")), null);
|
||||
CallOperator hourCall2 = new CallOperator(FunctionSet.HOUR, Type.TINYINT,
|
||||
Lists.newArrayList(fromUnixTimeCall2), null);
|
||||
ScalarOperator result2 = rule.apply(hourCall2, null);
|
||||
assertEquals(hourCall2, result2);
|
||||
starRocksAssert.query("SELECT hour(from_unixtime(ts, '%Y-%m-%d %H:%i:%s')) FROM test_timestamp")
|
||||
.explainWithout("hour_from_unixtime");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void applyHourToDatetimeRewrite() {
|
||||
// hour(to_datetime(ts)) -> hour_from_unixtime(ts)
|
||||
ColumnRefOperator tsColumn = new ColumnRefOperator(2, Type.BIGINT, "ts2", true);
|
||||
public void applyHourToDatetimeRewrite() throws Exception {
|
||||
starRocksAssert.query("SELECT hour(to_datetime(ts)) FROM test_timestamp")
|
||||
.explainContains("hour_from_unixtime");
|
||||
|
||||
CallOperator toDatetimeCall = new CallOperator(FunctionSet.TO_DATETIME, Type.DATETIME,
|
||||
Lists.newArrayList(tsColumn), null);
|
||||
CallOperator hourCall = new CallOperator(FunctionSet.HOUR, Type.TINYINT,
|
||||
Lists.newArrayList(toDatetimeCall), null);
|
||||
starRocksAssert.query("SELECT hour(to_datetime(ts, 0)) FROM test_timestamp")
|
||||
.explainContains("hour_from_unixtime");
|
||||
|
||||
ScalarOperator result = rule.apply(hourCall, null);
|
||||
assertEquals(OperatorType.CALL, result.getOpType());
|
||||
CallOperator resultCall = (CallOperator) result;
|
||||
assertEquals(FunctionSet.HOUR_FROM_UNIXTIME, resultCall.getFnName());
|
||||
assertEquals(1, resultCall.getChildren().size());
|
||||
assertEquals(tsColumn, resultCall.getChild(0));
|
||||
starRocksAssert.query("SELECT hour(to_datetime(ts, 3)) FROM test_timestamp")
|
||||
.explainContains("hour_from_unixtime", "/ 1000");
|
||||
|
||||
// hour(to_datetime(ts, 0)) -> hour_from_unixtime(ts)
|
||||
CallOperator toDatetimeCallScale0 = new CallOperator(FunctionSet.TO_DATETIME, Type.DATETIME,
|
||||
Lists.newArrayList(tsColumn, ConstantOperator.createInt(0)), null);
|
||||
CallOperator hourCall2 = new CallOperator(FunctionSet.HOUR, Type.TINYINT,
|
||||
Lists.newArrayList(toDatetimeCallScale0), null);
|
||||
starRocksAssert.query("SELECT hour(to_datetime(ts, 6)) FROM test_timestamp")
|
||||
.explainContains("hour_from_unixtime", "/ 1000000");
|
||||
|
||||
ScalarOperator result2 = rule.apply(hourCall2, null);
|
||||
assertEquals(OperatorType.CALL, result2.getOpType());
|
||||
CallOperator resultCall2 = (CallOperator) result2;
|
||||
assertEquals(FunctionSet.HOUR_FROM_UNIXTIME, resultCall2.getFnName());
|
||||
assertEquals(1, resultCall2.getChildren().size());
|
||||
assertEquals(tsColumn, resultCall2.getChild(0));
|
||||
|
||||
// hour(to_datetime(ts, 3)) -> hour_from_unixtime(ts/1000)
|
||||
CallOperator toDatetimeCallScale3 = new CallOperator(FunctionSet.TO_DATETIME, Type.DATETIME,
|
||||
Lists.newArrayList(tsColumn, ConstantOperator.createInt(3)), null);
|
||||
CallOperator hourCall3 = new CallOperator(FunctionSet.HOUR, Type.TINYINT,
|
||||
Lists.newArrayList(toDatetimeCallScale3), null);
|
||||
ScalarOperator result3 = rule.apply(hourCall3, null);
|
||||
assertEquals(OperatorType.CALL, result3.getOpType());
|
||||
CallOperator resultCall3 = (CallOperator) result3;
|
||||
assertEquals(FunctionSet.HOUR_FROM_UNIXTIME, resultCall3.getFnName());
|
||||
assertEquals(1, resultCall3.getChildren().size());
|
||||
// Expect a divide(ts, 1000) as the argument
|
||||
ScalarOperator arg3 = resultCall3.getChild(0);
|
||||
assertEquals(OperatorType.CALL, arg3.getOpType());
|
||||
CallOperator div3 = (CallOperator) arg3;
|
||||
assertEquals(FunctionSet.DIVIDE, div3.getFnName());
|
||||
assertEquals(tsColumn, div3.getChild(0));
|
||||
assertEquals(ConstantOperator.createInt(1000), div3.getChild(1));
|
||||
|
||||
// hour(to_datetime(ts, 6)) -> hour_from_unixtime(ts/1000000)
|
||||
CallOperator toDatetimeCallScale6 = new CallOperator(FunctionSet.TO_DATETIME, Type.DATETIME,
|
||||
Lists.newArrayList(tsColumn, ConstantOperator.createInt(6)), null);
|
||||
CallOperator hourCall6 = new CallOperator(FunctionSet.HOUR, Type.TINYINT,
|
||||
Lists.newArrayList(toDatetimeCallScale6), null);
|
||||
ScalarOperator result6 = rule.apply(hourCall6, null);
|
||||
assertEquals(OperatorType.CALL, result6.getOpType());
|
||||
CallOperator resultCall6 = (CallOperator) result6;
|
||||
assertEquals(FunctionSet.HOUR_FROM_UNIXTIME, resultCall6.getFnName());
|
||||
assertEquals(1, resultCall6.getChildren().size());
|
||||
ScalarOperator arg6 = resultCall6.getChild(0);
|
||||
assertEquals(OperatorType.CALL, arg6.getOpType());
|
||||
CallOperator div6 = (CallOperator) arg6;
|
||||
assertEquals(FunctionSet.DIVIDE, div6.getFnName());
|
||||
assertEquals(tsColumn, div6.getChild(0));
|
||||
assertEquals(ConstantOperator.createInt(1_000_000), div6.getChild(1));
|
||||
|
||||
// Unsupported scale like 4 should not be rewritten
|
||||
CallOperator toDatetimeCallScale4 = new CallOperator(FunctionSet.TO_DATETIME, Type.DATETIME,
|
||||
Lists.newArrayList(tsColumn, ConstantOperator.createInt(4)), null);
|
||||
CallOperator hourCall4 = new CallOperator(FunctionSet.HOUR, Type.TINYINT,
|
||||
Lists.newArrayList(toDatetimeCallScale4), null);
|
||||
ScalarOperator result4 = rule.apply(hourCall4, null);
|
||||
assertEquals(hourCall4, result4);
|
||||
starRocksAssert.query("SELECT hour(to_datetime(ts, 4)) FROM test_timestamp")
|
||||
.explainWithout("hour_from_unixtime");
|
||||
}
|
||||
}
|
||||
|
|
@ -605,7 +605,7 @@ vectorized_functions = [
|
|||
# TODO: 50380 year_from_unixtime
|
||||
# TODO: 50381 month_from_unixtime
|
||||
# TODO: 50382 day_from_unixtime
|
||||
[50383, 'hour_from_unixtime', True, False, 'INT', ['BIGINT'], 'TimeFunctions::hour_from_unixtime'],
|
||||
[50383, 'hour_from_unixtime', True, False, 'TINYINT', ['BIGINT'], 'TimeFunctions::hour_from_unixtime'],
|
||||
# TODO: 50384 minute_from_unixtime
|
||||
# TODO: 50385 second_from_unixtime
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue