[Enhancement] Support complex expressions in FILTER clause and add boolean type validation for aggregate functions (backport #62637) (#62665)

Signed-off-by: stephen <stephen5217@163.com>
Co-authored-by: stephen <91597003+stephen-shelby@users.noreply.github.com>
This commit is contained in:
mergify[bot] 2025-09-03 10:27:31 +08:00 committed by GitHub
parent ff547d6a45
commit 9b21d191af
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 77 additions and 2 deletions

View File

@ -198,6 +198,17 @@ public class FunctionAnalyzer {
FunctionName argFuncNameWithoutIf =
new FunctionName(AggStateUtils.getAggFuncNameOfCombinator(fnName.getFunction()));
FunctionParams params = functionCallExpr.getParams();
// Validate that the condition parameter (last parameter) is boolean type or can be cast to boolean
if (!params.exprs().isEmpty()) {
Expr conditionExpr = params.exprs().get(params.exprs().size() - 1);
if (!Type.canCastTo(conditionExpr.getType(), Type.BOOLEAN)) {
throw new SemanticException(String.format(
"The condition expression in %s function must be boolean type or castable to boolean, but got %s",
fnName.getFunction(), conditionExpr.getType().toSql()), functionCallExpr.getPos());
}
}
FunctionParams functionParamsWithOutIf =
new FunctionParams(params.isStar(), params.exprs().subList(0, params.exprs().size() - 1),
params.getExprsNames() == null ? null :

View File

@ -7595,7 +7595,7 @@ public class AstBuilder extends StarRocksBaseVisitor<ParseNode> {
if (isCountFunc && isDistinct) {
throw new ParsingException("Aggregation filter does not support COUNT DISTINCT");
}
Expr booleanExpr = (Expr) visit(context.filter().booleanExpression());
Expr booleanExpr = (Expr) visit(context.filter().expression());
functionName = functionName + FunctionSet.AGG_STATE_IF_SUFFIX;
exprs.add(booleanExpr);

View File

@ -2741,7 +2741,7 @@ whenClause
;
filter
: FILTER '(' WHERE booleanExpression ')'
: FILTER '(' WHERE expression ')'
;
over

View File

@ -3058,4 +3058,68 @@ public class AggregateTest extends PlanTestBase {
plan = getThriftPlan(sql);
assertContains(plan, "group_by_min_max:[TExpr(");
}
@Test
public void testAggregateFilterSyntax() throws Exception {
// Test basic FILTER syntax with boolean expression
String sql = "select count(*) filter (where v1 > 5) from t0";
String plan = getFragmentPlan(sql);
assertContains(plan, "count_if");
// Test FILTER with complex boolean expression
sql = "select sum(v2) filter (where v1 > 5 and v2 < 10) from t0";
plan = getFragmentPlan(sql);
assertContains(plan, "sum_if");
// Test FILTER with logical operators
sql = "select avg(v3) filter (where v1 = 1 or v2 = 2) from t0";
plan = getFragmentPlan(sql);
assertContains(plan, "avg_if");
// Test FILTER with NOT operator
sql = "select max(v1) filter (where not (v2 > 10)) from t0";
plan = getFragmentPlan(sql);
assertContains(plan, "max_if");
}
@Test
public void testAggregateFilterBooleanTypeValidation() throws Exception {
// Test that numeric expressions in FILTER are now allowed (can be cast to boolean)
String sql = "select count(*) filter (where v1) from t0";
String plan = getFragmentPlan(sql);
assertContains(plan, "count_if");
// Test that string expressions in FILTER are also allowed (can be cast to boolean)
sql = "select sum(v2) filter (where 'true') from t0";
plan = getFragmentPlan(sql);
assertContains(plan, "sum_if");
}
@Test
public void testAggIfFunctionBooleanTypeValidation() throws Exception {
// Test sum_if with correct boolean condition
String sql = "select sum_if(v2, v1 > 5) from t0";
String plan = getFragmentPlan(sql);
assertContains(plan, "sum_if");
// Test count_if with correct boolean condition
sql = "select count_if(v1 > 0 and v2 < 100) from t0";
plan = getFragmentPlan(sql);
assertContains(plan, "count_if");
// Test that numeric conditions in sum_if are now allowed (can be cast to boolean)
sql = "select sum_if(v2, v1) from t0";
plan = getFragmentPlan(sql);
assertContains(plan, "sum_if");
// Test that numeric conditions in count_if are now allowed (can be cast to boolean)
sql = "select count_if(v2) from t0";
plan = getFragmentPlan(sql);
assertContains(plan, "count_if");
// Test string conditions are also allowed
sql = "select sum_if(v2, 'true') from t0";
plan = getFragmentPlan(sql);
assertContains(plan, "sum_if");
}
}