[BugFix] fix hour_from_unixtime rule (#63006)

This commit is contained in:
Murphy 2025-09-12 13:57:33 +08:00 committed by GitHub
parent 96c4a26560
commit 3688bc4bd4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 77 additions and 133 deletions

View File

@ -1605,7 +1605,7 @@ StatusOr<ColumnPtr> TimeFunctions::hour_from_unixtime(FunctionContext* context,
auto ctz = context->state()->timezone_obj();
auto size = columns[0]->size();
ColumnViewer<TYPE_BIGINT> data_column(columns[0]);
ColumnBuilder<TYPE_INT> result(size);
ColumnBuilder<TYPE_TINYINT> result(size);
for (int row = 0; row < size; ++row) {
if (data_column.is_null(row)) {
result.append_null();
@ -3899,4 +3899,4 @@ StatusOr<ColumnPtr> TimeFunctions::time_format(FunctionContext* context, const s
} // namespace starrocks
#include "gen_cpp/opcode/TimeFunctions.inc"
#include "gen_cpp/opcode/TimeFunctions.inc"

View File

@ -4560,7 +4560,7 @@ TEST_F(TimeFunctionsTest, hourFromUnixTime) {
columns.emplace_back(tc);
ColumnPtr result = TimeFunctions::hour_from_unixtime(_utils->get_fn_ctx(), columns).value();
auto hours = ColumnHelper::cast_to<TYPE_INT>(result);
auto hours = ColumnHelper::cast_to<TYPE_TINYINT>(result);
for (size_t i = 0; i < sizeof(expected) / sizeof(expected[0]); ++i) {
EXPECT_EQ(expected[i], hours->get_data()[i]) << "Failed for basic positive at index " << i;
}
@ -4593,7 +4593,7 @@ TEST_F(TimeFunctionsTest, hourFromUnixTime) {
columns.emplace_back(tc);
ColumnPtr result = TimeFunctions::hour_from_unixtime(_utils->get_fn_ctx(), columns).value();
auto hours = ColumnHelper::cast_to<TYPE_INT>(result);
auto hours = ColumnHelper::cast_to<TYPE_TINYINT>(result);
for (size_t i = 0; i < sizeof(expected_negative) / sizeof(expected_negative[0]); ++i) {
EXPECT_EQ(expected_negative[i], hours->get_data()[i])
<< "Failed for timezone offset at index " << i << " with value " << tc->get_data()[i];
@ -4628,7 +4628,7 @@ TEST_F(TimeFunctionsTest, hourFromUnixTime) {
columns.emplace_back(tc);
ColumnPtr result = TimeFunctions::hour_from_unixtime(_utils->get_fn_ctx(), columns).value();
auto hours = ColumnHelper::cast_to<TYPE_INT>(result);
auto hours = ColumnHelper::cast_to<TYPE_TINYINT>(result);
for (size_t i = 0; i < sizeof(expected_mixed) / sizeof(expected_mixed[0]); ++i) {
EXPECT_EQ(expected_mixed[i], hours->get_data()[i])
<< "Failed for mixed timezone offset at index " << i << " with value " << tc->get_data()[i];
@ -4665,14 +4665,14 @@ TEST_F(TimeFunctionsTest, hourFromUnixTime) {
ASSERT_EQ(8, nullable_result->size());
// Check that results are in correct order
EXPECT_EQ(3, nullable_result->get(0).get_int32()); // 0 -> hour 3
EXPECT_TRUE(nullable_result->is_null(1)); // null
EXPECT_EQ(4, nullable_result->get(2).get_int32()); // 3600 -> hour 4
EXPECT_TRUE(nullable_result->is_null(3)); // null
EXPECT_EQ(5, nullable_result->get(4).get_int32()); // 7200 -> hour 5
EXPECT_EQ(2, nullable_result->get(5).get_int32()); // 82800 -> hour 2 (next day)
EXPECT_TRUE(nullable_result->is_null(6)); // null
EXPECT_EQ(6, nullable_result->get(7).get_int32()); // 10800 -> hour 6
EXPECT_EQ(3, nullable_result->get(0).get_int8()); // 0 -> hour 3
EXPECT_TRUE(nullable_result->is_null(1)); // null
EXPECT_EQ(4, nullable_result->get(2).get_int8()); // 3600 -> hour 4
EXPECT_TRUE(nullable_result->is_null(3)); // null
EXPECT_EQ(5, nullable_result->get(4).get_int8()); // 7200 -> hour 5
EXPECT_EQ(2, nullable_result->get(5).get_int8()); // 82800 -> hour 2 (next day)
EXPECT_TRUE(nullable_result->is_null(6)); // null
EXPECT_EQ(6, nullable_result->get(7).get_int8()); // 10800 -> hour 6
}
// Test 5: Edge cases for hour wrapping with timezone offset and negative input (should return null)
@ -4715,7 +4715,7 @@ TEST_F(TimeFunctionsTest, hourFromUnixTime) {
// Check non-negative cases
for (size_t i = 0; i < 4; ++i) {
EXPECT_FALSE(nullable_result->is_null(i)) << "Unexpected null at index " << i;
EXPECT_EQ(expected_edge[i], nullable_result->get(i).get_int32())
EXPECT_EQ(expected_edge[i], nullable_result->get(i).get_int8())
<< "Failed for edge case with timezone offset at index " << i << " with value "
<< tc->get(i).get_int64();
}

View File

@ -44,6 +44,7 @@ import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
public class SimplifiedPredicateRule extends BottomUpScalarOperatorRewriteRule {
@ -576,6 +577,19 @@ public class SimplifiedPredicateRule extends BottomUpScalarOperatorRewriteRule {
ConstantOperator.createVarchar(mergePath)), child.getFunction());
}
private static ScalarOperator lookupChild(ScalarOperator call, Predicate<ScalarOperator> predicate) {
if (predicate.test(call)) {
return call;
}
for (ScalarOperator child : call.getChildren()) {
ScalarOperator res = lookupChild(child, predicate);
if (res != null) {
return res;
}
}
return null;
}
// Simplify hour(from_unixtime(ts)) to hour_from_unixtime(ts)
// Also simplify hour(to_datetime(ts)) and hour(to_datetime(ts, 0)) to hour_from_unixtime(ts)
private static ScalarOperator simplifiedHourFromUnixTime(CallOperator call) {
@ -583,31 +597,28 @@ public class SimplifiedPredicateRule extends BottomUpScalarOperatorRewriteRule {
return call;
}
ScalarOperator child = call.getChild(0);
if (!(child instanceof CallOperator)) {
return call;
}
CallOperator childCall = (CallOperator) child;
String childFnName = childCall.getFnName();
// Case 1: hour(from_unixtime(ts)) -> hour_from_unixtime(ts)
if (FunctionSet.FROM_UNIXTIME.equalsIgnoreCase(childFnName)) {
ScalarOperator fromUnixTime = lookupChild(call,
x -> x instanceof CallOperator &&
((CallOperator) x).getFnName().equalsIgnoreCase(FunctionSet.FROM_UNIXTIME));
if (fromUnixTime != null) {
// Keep original behavior: only succeeds when argument list matches hour_from_unixtime signature
Type[] argTypes = childCall.getChildren().stream().map(ScalarOperator::getType).toArray(Type[]::new);
Type[] argTypes = fromUnixTime.getChildren().stream().map(ScalarOperator::getType).toArray(Type[]::new);
Function fn =
Expr.getBuiltinFunction(FunctionSet.HOUR_FROM_UNIXTIME, argTypes, Function.CompareMode.IS_IDENTICAL);
Expr.getBuiltinFunction(FunctionSet.HOUR_FROM_UNIXTIME, argTypes,
Function.CompareMode.IS_IDENTICAL);
if (fn == null) {
return call;
}
return new CallOperator(FunctionSet.HOUR_FROM_UNIXTIME, call.getType(), childCall.getChildren(), fn);
return new CallOperator(FunctionSet.HOUR_FROM_UNIXTIME, call.getType(), fromUnixTime.getChildren(), fn);
}
// Case 2: hour(to_datetime(ts)) or hour(to_datetime(ts, 0)) -> hour_from_unixtime(ts)
if (FunctionSet.TO_DATETIME.equalsIgnoreCase(childFnName)) {
List<ScalarOperator> args = childCall.getChildren();
ScalarOperator toDatetime = lookupChild(call,
x -> x instanceof CallOperator &&
((CallOperator) x).getFnName().equalsIgnoreCase(FunctionSet.TO_DATETIME));
if (toDatetime != null) {
List<ScalarOperator> args = toDatetime.getChildren();
ScalarOperator tsArg;
ScalarOperator unixtimeArgForHour;

View File

@ -16,22 +16,22 @@
package com.starrocks.sql.optimizer.rewrite.scalar;
import com.google.common.collect.Lists;
import com.starrocks.catalog.FunctionSet;
import com.starrocks.catalog.Type;
import com.starrocks.sql.ast.expression.BinaryType;
import com.starrocks.sql.optimizer.operator.OperatorType;
import com.starrocks.sql.optimizer.operator.scalar.BinaryPredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.CallOperator;
import com.starrocks.sql.optimizer.operator.scalar.CaseWhenOperator;
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
import com.starrocks.sql.optimizer.operator.scalar.ConstantOperator;
import com.starrocks.sql.optimizer.operator.scalar.LikePredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
import com.starrocks.sql.plan.PlanTestBase;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class SimplifiedPredicateRuleTest {
public class SimplifiedPredicateRuleTest extends PlanTestBase {
private static final ConstantOperator OI_NULL = ConstantOperator.createNull(Type.INT);
private static final ConstantOperator OI_100 = ConstantOperator.createInt(100);
private static final ConstantOperator OI_200 = ConstantOperator.createInt(200);
@ -42,6 +42,20 @@ public class SimplifiedPredicateRuleTest {
private SimplifiedPredicateRule rule = new SimplifiedPredicateRule();
@BeforeAll
public static void beforeAll() throws Exception {
starRocksAssert.withTable("CREATE TABLE IF NOT EXISTS `test_timestamp` (\n" +
" `id` bigint NULL COMMENT \"\",\n" +
" `ts` bigint NULL COMMENT \"unix timestamp\"\n" +
") ENGINE=OLAP\n" +
"DUPLICATE KEY(`id`)\n" +
"DISTRIBUTED BY HASH(`id`) BUCKETS 3\n" +
"PROPERTIES (\n" +
"\"replication_num\" = \"1\",\n" +
"\"in_memory\" = \"false\"\n" +
");");
}
@Test
public void applyCaseWhen() {
CaseWhenOperator cwo1 = new CaseWhenOperator(Type.INT, new ColumnRefOperator(1, Type.INT, "id", true), null,
@ -105,113 +119,32 @@ public class SimplifiedPredicateRuleTest {
}
@Test
public void applyHourFromUnixTime() {
// Test hour(from_unixtime(ts)) -> hour_from_unixtime(ts)
ColumnRefOperator tsColumn = new ColumnRefOperator(1, Type.BIGINT, "ts", true);
public void applyHourFromUnixTime() throws Exception {
starRocksAssert.query("SELECT hour(from_unixtime(ts)) FROM test_timestamp")
.explainContains("hour_from_unixtime");
// Create from_unixtime(ts) call
CallOperator fromUnixTimeCall = new CallOperator(FunctionSet.FROM_UNIXTIME, Type.VARCHAR,
Lists.newArrayList(tsColumn), null);
starRocksAssert.query("SELECT hour(ts) FROM test_timestamp")
.explainWithout("hour_from_unixtime");
// Create hour(from_unixtime(ts)) call
CallOperator hourCall = new CallOperator(FunctionSet.HOUR, Type.TINYINT,
Lists.newArrayList(fromUnixTimeCall), null);
ScalarOperator result = rule.apply(hourCall, null);
// Verify the result is hour_from_unixtime(ts)
assertEquals(OperatorType.CALL, result.getOpType());
CallOperator resultCall = (CallOperator) result;
assertEquals(FunctionSet.HOUR_FROM_UNIXTIME, resultCall.getFnName());
assertEquals(1, resultCall.getChildren().size());
assertEquals(tsColumn, resultCall.getChild(0));
// Test that hour(ts) is not optimized (not from_unixtime)
CallOperator simpleHourCall = new CallOperator(FunctionSet.HOUR, Type.TINYINT,
Lists.newArrayList(tsColumn), null);
ScalarOperator simpleResult = rule.apply(simpleHourCall, null);
assertEquals(simpleHourCall, simpleResult);
// Test that hour(from_unixtime(ts, format)) is not optimized (multiple arguments)
CallOperator fromUnixTimeCall2 = new CallOperator(FunctionSet.FROM_UNIXTIME, Type.VARCHAR,
Lists.newArrayList(tsColumn, ConstantOperator.createVarchar("format")), null);
CallOperator hourCall2 = new CallOperator(FunctionSet.HOUR, Type.TINYINT,
Lists.newArrayList(fromUnixTimeCall2), null);
ScalarOperator result2 = rule.apply(hourCall2, null);
assertEquals(hourCall2, result2);
starRocksAssert.query("SELECT hour(from_unixtime(ts, '%Y-%m-%d %H:%i:%s')) FROM test_timestamp")
.explainWithout("hour_from_unixtime");
}
@Test
public void applyHourToDatetimeRewrite() {
// hour(to_datetime(ts)) -> hour_from_unixtime(ts)
ColumnRefOperator tsColumn = new ColumnRefOperator(2, Type.BIGINT, "ts2", true);
public void applyHourToDatetimeRewrite() throws Exception {
starRocksAssert.query("SELECT hour(to_datetime(ts)) FROM test_timestamp")
.explainContains("hour_from_unixtime");
CallOperator toDatetimeCall = new CallOperator(FunctionSet.TO_DATETIME, Type.DATETIME,
Lists.newArrayList(tsColumn), null);
CallOperator hourCall = new CallOperator(FunctionSet.HOUR, Type.TINYINT,
Lists.newArrayList(toDatetimeCall), null);
starRocksAssert.query("SELECT hour(to_datetime(ts, 0)) FROM test_timestamp")
.explainContains("hour_from_unixtime");
ScalarOperator result = rule.apply(hourCall, null);
assertEquals(OperatorType.CALL, result.getOpType());
CallOperator resultCall = (CallOperator) result;
assertEquals(FunctionSet.HOUR_FROM_UNIXTIME, resultCall.getFnName());
assertEquals(1, resultCall.getChildren().size());
assertEquals(tsColumn, resultCall.getChild(0));
starRocksAssert.query("SELECT hour(to_datetime(ts, 3)) FROM test_timestamp")
.explainContains("hour_from_unixtime", "/ 1000");
// hour(to_datetime(ts, 0)) -> hour_from_unixtime(ts)
CallOperator toDatetimeCallScale0 = new CallOperator(FunctionSet.TO_DATETIME, Type.DATETIME,
Lists.newArrayList(tsColumn, ConstantOperator.createInt(0)), null);
CallOperator hourCall2 = new CallOperator(FunctionSet.HOUR, Type.TINYINT,
Lists.newArrayList(toDatetimeCallScale0), null);
starRocksAssert.query("SELECT hour(to_datetime(ts, 6)) FROM test_timestamp")
.explainContains("hour_from_unixtime", "/ 1000000");
ScalarOperator result2 = rule.apply(hourCall2, null);
assertEquals(OperatorType.CALL, result2.getOpType());
CallOperator resultCall2 = (CallOperator) result2;
assertEquals(FunctionSet.HOUR_FROM_UNIXTIME, resultCall2.getFnName());
assertEquals(1, resultCall2.getChildren().size());
assertEquals(tsColumn, resultCall2.getChild(0));
// hour(to_datetime(ts, 3)) -> hour_from_unixtime(ts/1000)
CallOperator toDatetimeCallScale3 = new CallOperator(FunctionSet.TO_DATETIME, Type.DATETIME,
Lists.newArrayList(tsColumn, ConstantOperator.createInt(3)), null);
CallOperator hourCall3 = new CallOperator(FunctionSet.HOUR, Type.TINYINT,
Lists.newArrayList(toDatetimeCallScale3), null);
ScalarOperator result3 = rule.apply(hourCall3, null);
assertEquals(OperatorType.CALL, result3.getOpType());
CallOperator resultCall3 = (CallOperator) result3;
assertEquals(FunctionSet.HOUR_FROM_UNIXTIME, resultCall3.getFnName());
assertEquals(1, resultCall3.getChildren().size());
// Expect a divide(ts, 1000) as the argument
ScalarOperator arg3 = resultCall3.getChild(0);
assertEquals(OperatorType.CALL, arg3.getOpType());
CallOperator div3 = (CallOperator) arg3;
assertEquals(FunctionSet.DIVIDE, div3.getFnName());
assertEquals(tsColumn, div3.getChild(0));
assertEquals(ConstantOperator.createInt(1000), div3.getChild(1));
// hour(to_datetime(ts, 6)) -> hour_from_unixtime(ts/1000000)
CallOperator toDatetimeCallScale6 = new CallOperator(FunctionSet.TO_DATETIME, Type.DATETIME,
Lists.newArrayList(tsColumn, ConstantOperator.createInt(6)), null);
CallOperator hourCall6 = new CallOperator(FunctionSet.HOUR, Type.TINYINT,
Lists.newArrayList(toDatetimeCallScale6), null);
ScalarOperator result6 = rule.apply(hourCall6, null);
assertEquals(OperatorType.CALL, result6.getOpType());
CallOperator resultCall6 = (CallOperator) result6;
assertEquals(FunctionSet.HOUR_FROM_UNIXTIME, resultCall6.getFnName());
assertEquals(1, resultCall6.getChildren().size());
ScalarOperator arg6 = resultCall6.getChild(0);
assertEquals(OperatorType.CALL, arg6.getOpType());
CallOperator div6 = (CallOperator) arg6;
assertEquals(FunctionSet.DIVIDE, div6.getFnName());
assertEquals(tsColumn, div6.getChild(0));
assertEquals(ConstantOperator.createInt(1_000_000), div6.getChild(1));
// Unsupported scale like 4 should not be rewritten
CallOperator toDatetimeCallScale4 = new CallOperator(FunctionSet.TO_DATETIME, Type.DATETIME,
Lists.newArrayList(tsColumn, ConstantOperator.createInt(4)), null);
CallOperator hourCall4 = new CallOperator(FunctionSet.HOUR, Type.TINYINT,
Lists.newArrayList(toDatetimeCallScale4), null);
ScalarOperator result4 = rule.apply(hourCall4, null);
assertEquals(hourCall4, result4);
starRocksAssert.query("SELECT hour(to_datetime(ts, 4)) FROM test_timestamp")
.explainWithout("hour_from_unixtime");
}
}

View File

@ -608,7 +608,7 @@ vectorized_functions = [
# TODO: 50380 year_from_unixtime
# TODO: 50381 month_from_unixtime
# TODO: 50382 day_from_unixtime
[50383, 'hour_from_unixtime', True, False, 'INT', ['BIGINT'], 'TimeFunctions::hour_from_unixtime'],
[50383, 'hour_from_unixtime', True, False, 'TINYINT', ['BIGINT'], 'TimeFunctions::hour_from_unixtime'],
# TODO: 50384 minute_from_unixtime
# TODO: 50385 second_from_unixtime