[Enhancement] Support complex expressions in FILTER clause and add boolean type validation for aggregate functions (backport #62637) (#62665)
Signed-off-by: stephen <stephen5217@163.com> Co-authored-by: stephen <91597003+stephen-shelby@users.noreply.github.com>
This commit is contained in:
parent
ff547d6a45
commit
9b21d191af
|
|
@ -198,6 +198,17 @@ public class FunctionAnalyzer {
|
|||
FunctionName argFuncNameWithoutIf =
|
||||
new FunctionName(AggStateUtils.getAggFuncNameOfCombinator(fnName.getFunction()));
|
||||
FunctionParams params = functionCallExpr.getParams();
|
||||
|
||||
// Validate that the condition parameter (last parameter) is boolean type or can be cast to boolean
|
||||
if (!params.exprs().isEmpty()) {
|
||||
Expr conditionExpr = params.exprs().get(params.exprs().size() - 1);
|
||||
if (!Type.canCastTo(conditionExpr.getType(), Type.BOOLEAN)) {
|
||||
throw new SemanticException(String.format(
|
||||
"The condition expression in %s function must be boolean type or castable to boolean, but got %s",
|
||||
fnName.getFunction(), conditionExpr.getType().toSql()), functionCallExpr.getPos());
|
||||
}
|
||||
}
|
||||
|
||||
FunctionParams functionParamsWithOutIf =
|
||||
new FunctionParams(params.isStar(), params.exprs().subList(0, params.exprs().size() - 1),
|
||||
params.getExprsNames() == null ? null :
|
||||
|
|
|
|||
|
|
@ -7595,7 +7595,7 @@ public class AstBuilder extends StarRocksBaseVisitor<ParseNode> {
|
|||
if (isCountFunc && isDistinct) {
|
||||
throw new ParsingException("Aggregation filter does not support COUNT DISTINCT");
|
||||
}
|
||||
Expr booleanExpr = (Expr) visit(context.filter().booleanExpression());
|
||||
Expr booleanExpr = (Expr) visit(context.filter().expression());
|
||||
functionName = functionName + FunctionSet.AGG_STATE_IF_SUFFIX;
|
||||
exprs.add(booleanExpr);
|
||||
|
||||
|
|
|
|||
|
|
@ -2741,7 +2741,7 @@ whenClause
|
|||
;
|
||||
|
||||
filter
|
||||
: FILTER '(' WHERE booleanExpression ')'
|
||||
: FILTER '(' WHERE expression ')'
|
||||
;
|
||||
|
||||
over
|
||||
|
|
|
|||
|
|
@ -3058,4 +3058,68 @@ public class AggregateTest extends PlanTestBase {
|
|||
plan = getThriftPlan(sql);
|
||||
assertContains(plan, "group_by_min_max:[TExpr(");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAggregateFilterSyntax() throws Exception {
|
||||
// Test basic FILTER syntax with boolean expression
|
||||
String sql = "select count(*) filter (where v1 > 5) from t0";
|
||||
String plan = getFragmentPlan(sql);
|
||||
assertContains(plan, "count_if");
|
||||
|
||||
// Test FILTER with complex boolean expression
|
||||
sql = "select sum(v2) filter (where v1 > 5 and v2 < 10) from t0";
|
||||
plan = getFragmentPlan(sql);
|
||||
assertContains(plan, "sum_if");
|
||||
|
||||
// Test FILTER with logical operators
|
||||
sql = "select avg(v3) filter (where v1 = 1 or v2 = 2) from t0";
|
||||
plan = getFragmentPlan(sql);
|
||||
assertContains(plan, "avg_if");
|
||||
|
||||
// Test FILTER with NOT operator
|
||||
sql = "select max(v1) filter (where not (v2 > 10)) from t0";
|
||||
plan = getFragmentPlan(sql);
|
||||
assertContains(plan, "max_if");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAggregateFilterBooleanTypeValidation() throws Exception {
|
||||
// Test that numeric expressions in FILTER are now allowed (can be cast to boolean)
|
||||
String sql = "select count(*) filter (where v1) from t0";
|
||||
String plan = getFragmentPlan(sql);
|
||||
assertContains(plan, "count_if");
|
||||
|
||||
// Test that string expressions in FILTER are also allowed (can be cast to boolean)
|
||||
sql = "select sum(v2) filter (where 'true') from t0";
|
||||
plan = getFragmentPlan(sql);
|
||||
assertContains(plan, "sum_if");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAggIfFunctionBooleanTypeValidation() throws Exception {
|
||||
// Test sum_if with correct boolean condition
|
||||
String sql = "select sum_if(v2, v1 > 5) from t0";
|
||||
String plan = getFragmentPlan(sql);
|
||||
assertContains(plan, "sum_if");
|
||||
|
||||
// Test count_if with correct boolean condition
|
||||
sql = "select count_if(v1 > 0 and v2 < 100) from t0";
|
||||
plan = getFragmentPlan(sql);
|
||||
assertContains(plan, "count_if");
|
||||
|
||||
// Test that numeric conditions in sum_if are now allowed (can be cast to boolean)
|
||||
sql = "select sum_if(v2, v1) from t0";
|
||||
plan = getFragmentPlan(sql);
|
||||
assertContains(plan, "sum_if");
|
||||
|
||||
// Test that numeric conditions in count_if are now allowed (can be cast to boolean)
|
||||
sql = "select count_if(v2) from t0";
|
||||
plan = getFragmentPlan(sql);
|
||||
assertContains(plan, "count_if");
|
||||
|
||||
// Test string conditions are also allowed
|
||||
sql = "select sum_if(v2, 'true') from t0";
|
||||
plan = getFragmentPlan(sql);
|
||||
assertContains(plan, "sum_if");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue