[Enhancement] support map type in UDAF (#60840)

Signed-off-by: yan zhang <dirtysalt1987@gmail.com>
This commit is contained in:
yan zhang 2025-07-18 09:49:51 +08:00 committed by GitHub
parent 96b2c68575
commit f5f8e9bc2c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 400 additions and 11 deletions

View File

@ -452,6 +452,7 @@ public:
helper.get_result_from_boxed_array(ctx, type, output.get(), res, batch_size);
} else {
helper.get_result_from_boxed_array(ctx, type, to, res, batch_size);
(void)ColumnHelper::update_nested_has_null(to);
down_cast<NullableColumn*>(to)->update_has_null();
}
}

58
contrib/java-udf/pom.xml Normal file
View File

@ -0,0 +1,58 @@
<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven build for the com.starrocks.udf:java-udf artifact (version 1.0-SNAPSHOT). -->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.starrocks.udf</groupId>
<artifactId>java-udf</artifactId>
<version>1.0-SNAPSHOT</version>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<!-- Single source of truth for the compiler source/target level configured below. -->
<java.version>17</java.version>
</properties>
<dependencies>
<!-- https://mvnrepository.com/artifact/org.apache.commons/commons-math3 -->
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<version>3.6.1</version>
</dependency>
</dependencies>
<build>
<plugins>
<!-- Compiles with source and target both set to ${java.version}. -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.3</version>
<configuration>
<source>${java.version}</source>
<target>${java.version}</target>
</configuration>
</plugin>
<!-- Runs the assembly "single" goal during the package phase using the predefined
     jar-with-dependencies descriptor. NOTE(review): no <version> is declared for this
     plugin here, so the version is whatever Maven resolves by default; consider pinning. -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-assembly-plugin</artifactId>
<configuration>
<descriptorRefs>
<descriptorRef>jar-with-dependencies</descriptorRef>
</descriptorRefs>
</configuration>
<executions>
<execution>
<id>make-assembly</id>
<phase>package</phase>
<goals>
<goal>single</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>

View File

@ -0,0 +1,12 @@
package com.starrocks.udf;
import org.apache.commons.math3.special.Erf;
/**
 * Scalar UDF that evaluates the cumulative distribution function of a normal distribution.
 */
public class NormalCdf {
    /**
     * Computes {@code 0.5 * (1 + erf((value - mean) / (standardDeviation * sqrt(2))))}.
     *
     * @param mean              mean of the distribution
     * @param standardDeviation standard deviation of the distribution
     * @param value             point at which the function is evaluated
     * @return the computed probability, or {@code null} when any argument is {@code null}
     */
    public final Double evaluate(Double mean, Double standardDeviation, Double value) {
        if (mean != null && standardDeviation != null && value != null) {
            // Standardize the input before handing it to the error function.
            double scaled = (value - mean) / (standardDeviation * Math.sqrt(2));
            return 0.5 * (1 + Erf.erf(scaled));
        }
        return null;
    }
}

View File

@ -0,0 +1,89 @@
package com.starrocks.udf;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.util.Map;
import java.util.function.BinaryOperator;
/**
 * Base implementation of an aggregate function that combines maps key by key.
 *
 * <p>For every key seen in the input maps, the values are folded together with the
 * {@code sumFunction} supplied at construction time. Entries whose value is {@code null}
 * are skipped, because {@link Map#merge} rejects {@code null} values with a
 * {@link NullPointerException}. {@code null} keys are kept (the backing map is a
 * {@link java.util.HashMap}).
 *
 * @param <K> key type of the aggregated map; must be Java-serializable to be (de)serialized
 * @param <T> value type of the aggregated map; must be Java-serializable to be (de)serialized
 */
public class SumMap<K, T> {
    BinaryOperator<T> sumFunction;

    /**
     * @param sumFunction operator used to combine two non-null values that share a key
     */
    public SumMap(BinaryOperator<T> sumFunction) {
        this.sumFunction = sumFunction;
    }

    /**
     * Intermediate aggregation state: the running per-key totals.
     *
     * <p>Wire format (Java object stream): an {@code int} entry count followed by
     * {@code count} key/value object pairs.
     */
    public static class State<K, T> {
        Map<K, T> values = new java.util.HashMap<>();

        /**
         * Serializes the current totals into a fresh byte array.
         *
         * @return the encoded state
         * @throws IOException if a key or value cannot be written
         */
        public byte[] serialize() throws IOException {
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            // try-with-resources guarantees the stream is flushed and closed even on failure.
            try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
                // An explicit count lets the reader stop without relying on EOF.
                oos.writeInt(values.size());
                for (Map.Entry<K, T> entry : values.entrySet()) {
                    oos.writeObject(entry.getKey());
                    oos.writeObject(entry.getValue());
                }
            }
            return baos.toByteArray();
        }

        /**
         * @return number of bytes {@link #serialize()} currently produces
         * @throws IOException if serialization fails
         */
        public int serializeLength() throws IOException {
            return serialize().length;
        }

        /**
         * Writes the encoded state into {@code buff} at its current position.
         *
         * @throws IOException if serialization fails
         */
        public void serialize(java.nio.ByteBuffer buff) throws IOException {
            buff.put(serialize());
        }

        /**
         * Reads all remaining bytes of {@code buff} and loads the entries they encode
         * into this state (existing entries with the same key are overwritten).
         *
         * @throws IOException            if the data is truncated or malformed
         * @throws ClassNotFoundException if an encoded object's class cannot be resolved
         */
        public void deserialize(java.nio.ByteBuffer buff) throws IOException, ClassNotFoundException {
            byte[] data = new byte[buff.remaining()];
            buff.get(data);
            try (java.io.ObjectInputStream ois =
                         new java.io.ObjectInputStream(new java.io.ByteArrayInputStream(data))) {
                int count = ois.readInt();
                if (count < 0) {
                    throw new java.io.StreamCorruptedException("negative entry count: " + count);
                }
                for (int i = 0; i < count; i++) {
                    // The stream was produced by serialize(), which wrote K/T instances.
                    @SuppressWarnings("unchecked")
                    K key = (K) ois.readObject();
                    @SuppressWarnings("unchecked")
                    T value = (T) ois.readObject();
                    values.put(key, value);
                }
            }
        }
    }

    /** Releases the state; nothing to do for this implementation. */
    public void destroy(State<K, T> state) {
    }

    /**
     * Folds one input map into the state. A {@code null} map is ignored, and so are
     * entries whose value is {@code null}.
     */
    public final void update(State<K, T> state, Map<K, T> val) {
        if (val != null) {
            accumulate(state, val);
        }
    }

    /**
     * Writes the encoded state into {@code buff}.
     *
     * @throws IOException if serialization fails
     */
    public void serialize(State<K, T> state, java.nio.ByteBuffer buff) throws IOException {
        state.serialize(buff);
    }

    /**
     * Decodes a serialized state from {@code buffer} and folds it into {@code state}.
     *
     * @throws IOException            if the data is truncated or malformed
     * @throws ClassNotFoundException if an encoded object's class cannot be resolved
     */
    public void merge(State<K, T> state, java.nio.ByteBuffer buffer) throws IOException, ClassNotFoundException {
        State<K, T> incoming = new State<>();
        incoming.deserialize(buffer);
        accumulate(state, incoming.values);
    }

    /**
     * Returns the aggregated map. The returned object is the state's own backing map,
     * not a copy.
     */
    @SuppressWarnings("unchecked")
    public Map<Object, T> finalize(State<K, T> state) {
        // The declared return type is kept as Map<Object, T> for source compatibility
        // with existing subclasses; the widening of the key type is unchecked.
        return (Map<Object, T>) (Map<?, T>) state.values;
    }

    /** Adds every non-null-valued entry of {@code source} to the totals held by {@code state}. */
    private void accumulate(State<K, T> state, Map<K, T> source) {
        for (Map.Entry<K, T> entry : source.entrySet()) {
            T value = entry.getValue();
            // Map.merge throws NullPointerException for null values, so skip them.
            if (value != null) {
                state.values.merge(entry.getKey(), value, sumFunction);
            }
        }
    }
}

View File

@ -0,0 +1,41 @@
package com.starrocks.udf;
import java.io.IOException;
import java.util.Map;
public class SumMapInt64 extends SumMap<Object, Long> {
public SumMapInt64() {
super(Long::sum);
}
public static class State extends SumMap.State<Object, Long> {
public int serializeLength() throws IOException {
return super.serializeLength();
}
}
public State create() {
return new State();
}
public void destroy(State state) {
super.destroy(state);
}
// only support scalar type!!!
public final void update(State state, Map<Object, Long> val) {
super.update(state, val);
}
public void serialize(State state, java.nio.ByteBuffer buff) throws IOException {
super.serialize(state, buff);
}
public void merge(State state, java.nio.ByteBuffer buffer) throws IOException, ClassNotFoundException {
super.merge(state, buffer);
}
public Map<Object, Long> finalize(State state) {
return super.finalize(state);
}
}

View File

@ -179,7 +179,7 @@ public class CreateFunctionAnalyzer {
} else {
createdFunction = analyzeStarrocksJarUdtf(stmt, checksum, handleClass);
}
stmt.setFunction(createdFunction);
}
@ -198,7 +198,7 @@ public class CreateFunctionAnalyzer {
}
private Function analyzeStarrocksJarUdf(CreateFunctionStmt stmt, String checksum,
JavaUDFInternalClass handleClass) {
JavaUDFInternalClass handleClass) {
checkStarrocksJarUdfClass(stmt, handleClass);
FunctionName functionName = stmt.getFunctionName();
@ -304,8 +304,8 @@ public class CreateFunctionAnalyzer {
}
private Function analyzeStarrocksJarUdaf(CreateFunctionStmt stmt, String checksum,
JavaUDFInternalClass mainClass,
JavaUDFInternalClass udafStateClass) {
JavaUDFInternalClass mainClass,
JavaUDFInternalClass udafStateClass) {
FunctionName functionName = stmt.getFunctionName();
FunctionArgsDef argsDef = stmt.getArgsDef();
TypeDef returnType = stmt.getReturnType();
@ -329,7 +329,7 @@ public class CreateFunctionAnalyzer {
}
private Function analyzeStarrocksJarUdtf(CreateFunctionStmt stmt, String checksum,
JavaUDFInternalClass mainClass) {
JavaUDFInternalClass mainClass) {
FunctionName functionName = stmt.getFunctionName();
FunctionArgsDef argsDef = stmt.getArgsDef();
TypeDef returnType = stmt.getReturnType();
@ -488,17 +488,23 @@ public class CreateFunctionAnalyzer {
}
private void checkUdfType(Method method, Type expType, Class<?> ptype, String pname) {
if (!(expType instanceof ScalarType)) {
Class<?> cls = null;
if (expType instanceof ScalarType) {
ScalarType scalarType = (ScalarType) expType;
cls = PRIMITIVE_TYPE_TO_JAVA_CLASS_TYPE.get(scalarType.getPrimitiveType());
} else if (expType instanceof MapType) {
cls = Map.class;
} else if (expType instanceof ArrayType) {
cls = List.class;
} else {
ErrorReport.reportSemanticException(ErrorCode.ERR_COMMON_ERROR,
String.format("UDF class '%s' method '%s' does not support non-scalar type '%s'",
String.format("UDF class '%s' method '%s' does not support type '%s'",
clazz.getCanonicalName(), method.getName(), expType));
}
ScalarType scalarType = (ScalarType) expType;
Class<?> cls = PRIMITIVE_TYPE_TO_JAVA_CLASS_TYPE.get(scalarType.getPrimitiveType());
if (cls == null) {
ErrorReport.reportSemanticException(ErrorCode.ERR_COMMON_ERROR,
String.format("UDF class '%s' method '%s' does not support type '%s'",
clazz.getCanonicalName(), method.getName(), scalarType));
clazz.getCanonicalName(), method.getName(), expType));
}
if (!cls.equals(ptype)) {
ErrorReport.reportSemanticException(ErrorCode.ERR_COMMON_ERROR,

View File

@ -57,6 +57,30 @@ public class CreateFunctionStmtAnalyzerTest {
createFunctionSql, 32).get(0);
}
/**
 * Parses a CREATE FUNCTION statement for ABC.MY_UDAF_MAP, which takes and returns
 * map&lt;string,string&gt;.
 *
 * @param symbol value placed in the "symbol" property
 * @param type   keyword inserted between CREATE and FUNCTION (e.g. AGGREGATE)
 * @return the first statement produced by the parser
 */
private CreateFunctionStmt createMapStmt(String symbol, String type) {
    final String template = "CREATE %s FUNCTION ABC.MY_UDAF_MAP(map<string,string>) \n"
            + "RETURNS map<string,string> \n"
            + "properties (\n"
            + " \"symbol\" = \"%s\",\n"
            + " \"type\" = \"StarrocksJar\",\n"
            + " \"file\" = \"http://localhost:8080/\"\n"
            + ");";
    // The first placeholder is the function kind, the second the symbol.
    final String sql = String.format(template, type, symbol);
    return (CreateFunctionStmt) com.starrocks.sql.parser.SqlParser.parse(sql, 32).get(0);
}
/**
 * Parses a CREATE FUNCTION statement for ABC.MY_UDAF_LIST, which takes and returns
 * array&lt;string&gt;.
 *
 * @param symbol value placed in the "symbol" property
 * @param type   keyword inserted between CREATE and FUNCTION (e.g. AGGREGATE)
 * @return the first statement produced by the parser
 */
private CreateFunctionStmt createListStmt(String symbol, String type) {
    final String template = "CREATE %s FUNCTION ABC.MY_UDAF_LIST(array<string>) \n"
            + "RETURNS array<string> \n"
            + "properties (\n"
            + " \"symbol\" = \"%s\",\n"
            + " \"type\" = \"StarrocksJar\",\n"
            + " \"file\" = \"http://localhost:8080/\"\n"
            + ");";
    // The first placeholder is the function kind, the second the symbol.
    final String sql = String.format(template, type, symbol);
    return (CreateFunctionStmt) com.starrocks.sql.parser.SqlParser.parse(sql, 32).get(0);
}
private CreateFunctionStmt createPyStmt(String symbol, String type, String target) {
Config.enable_udf = true;
String createFunctionSql = String.format("CREATE FUNCTION ABC.MY_UDF_JSON_GET(string, string) \n"
@ -168,6 +192,7 @@ public class CreateFunctionStmtAnalyzerTest {
+ ");", args, ret);
return sql;
}
void mockClazz(Class<?> clazz) {
new MockUp<CreateFunctionAnalyzer>() {
@Mock
@ -249,6 +274,66 @@ public class CreateFunctionStmtAnalyzerTest {
}
}
// Stub aggregate class used by testJUDAFMap: it exposes create/destroy/update/serialize/
// merge/finalize with Map<String, String> as the update argument and finalize result.
// Every method is a no-op or returns a fixed value.
public static class EmptyAggMapEval {
// Stub state object; reports a serialized length of 0.
public static class State {
public int serializeLength() {
return 0;
}
}
public State create() {
return new State();
}
public void destroy(State state) {
}
// Takes a Map argument, matching the map<string,string> parameter declared by createMapStmt.
public final void update(State state, Map<String, String> val) {
}
public void serialize(State state, java.nio.ByteBuffer buff) {
}
public void merge(State state, java.nio.ByteBuffer buffer) {
}
// Returns a Map, matching the map<string,string> return type declared by createMapStmt.
public Map<String, String> finalize(State state) {
return null;
}
}
// Stub aggregate class used by testJUDAFList: it exposes create/destroy/update/serialize/
// merge/finalize with List<String> as the update argument and finalize result.
// Every method is a no-op or returns a fixed value.
public static class EmptyAggListEval {
// Stub state object; reports a serialized length of 0.
public static class State {
public int serializeLength() {
return 0;
}
}
public State create() {
return new State();
}
public void destroy(State state) {
}
// Takes a List argument, matching the array<string> parameter declared by createListStmt.
public final void update(State state, List<String> val) {
}
public void serialize(State state, java.nio.ByteBuffer buff) {
}
public void merge(State state, java.nio.ByteBuffer buffer) {
}
// Returns a List, matching the array<string> return type declared by createListStmt.
public List<String> finalize(State state) {
return null;
}
}
@Test
public void testJUDAF() {
try {
@ -277,6 +362,62 @@ public class CreateFunctionStmtAnalyzerTest {
}
}
/**
 * Analyzes an AGGREGATE function declared over map&lt;string,string&gt; with the checksum
 * computation and class loading mocked out, and expects the mocked checksum on the
 * resulting function.
 */
@Test
public void testJUDAFMap() {
    Config.enable_udf = true;
    try {
        // Fixed checksum instead of hashing a real jar.
        new MockUp<CreateFunctionAnalyzer>() {
            @Mock
            public String computeMd5(CreateFunctionStmt stmt) {
                return "0xff";
            }
        };
        // Names containing '$' resolve to the stub State class, anything else to the stub UDAF.
        new MockUp<UDFInternalClassLoader>() {
            @Mock
            public final Class<?> loadClass(String name, boolean resolve)
                    throws ClassNotFoundException {
                return name.contains("$") ? EmptyAggMapEval.State.class : EmptyAggMapEval.class;
            }
        };
        final CreateFunctionStmt mapUdafStmt = createMapStmt("symbol", "AGGREGATE");
        final CreateFunctionAnalyzer analyzer = new CreateFunctionAnalyzer();
        analyzer.analyze(mapUdafStmt, connectContext);
        Assertions.assertEquals("0xff", mapUdafStmt.getFunction().getChecksum());
    } finally {
        Config.enable_udf = false;
    }
}
/**
 * Analyzes an AGGREGATE function declared over array&lt;string&gt; with the checksum
 * computation and class loading mocked out, and expects the mocked checksum on the
 * resulting function.
 */
@Test
public void testJUDAFList() {
    Config.enable_udf = true;
    try {
        // Fixed checksum instead of hashing a real jar.
        new MockUp<CreateFunctionAnalyzer>() {
            @Mock
            public String computeMd5(CreateFunctionStmt stmt) {
                return "0xff";
            }
        };
        // Names containing '$' resolve to the stub State class, anything else to the stub UDAF.
        new MockUp<UDFInternalClassLoader>() {
            @Mock
            public final Class<?> loadClass(String name, boolean resolve)
                    throws ClassNotFoundException {
                return name.contains("$") ? EmptyAggListEval.State.class : EmptyAggListEval.class;
            }
        };
        final CreateFunctionStmt listUdafStmt = createListStmt("symbol", "AGGREGATE");
        final CreateFunctionAnalyzer analyzer = new CreateFunctionAnalyzer();
        analyzer.analyze(listUdafStmt, connectContext);
        Assertions.assertEquals("0xff", listUdafStmt.getFunction().getChecksum());
    } finally {
        Config.enable_udf = false;
    }
}
public static class JUDTF {
public String[] process(String s, String s2) {
return null;

View File

@ -281,4 +281,31 @@ PROPERTIES
select echo_map2(map("","")), echo_map2(map(null, null));
-- result:
{null:""} {null:null}
-- !result
CREATE aggregate FUNCTION sum_map(map<string,bigint>)
RETURNS map<string,bigint>
properties (
"symbol" = "com.starrocks.udf.SumMapInt64",
"type" = "StarrocksJar",
"file" = "${udf_url}/starrocks-jdbc/java-udf.jar?v=2"
);
-- result:
-- !result
create table map_table (id int, data map<string, bigint>);
-- result:
-- !result
insert into map_table values (1, map{"a": 10, "b": 20}), (1, map{"a": 20, "c": 20}), (2, map{"d": 20, "e": 30}), (2, map{null: 40, "d": 30});
-- result:
-- !result
select id, data from map_table order by id;
-- result:
1 {"a":10,"b":20}
1 {"a":20,"c":20}
2 {"d":20,"e":30}
2 {null:40,"d":30}
-- !result
select id, sum_map(data) from map_table group by id order by id;
-- result:
1 {"a":30,"b":20,"c":20}
2 {null:40,"d":50,"e":30}
-- !result

View File

@ -179,4 +179,18 @@ PROPERTIES
"type" = "StarrocksJar",
"file" = "${udf_url}/starrocks-jdbc/ArrayMap.jar?v=2"
);
select echo_map2(map("","")), echo_map2(map(null, null));
select echo_map2(map("","")), echo_map2(map(null, null));
CREATE aggregate FUNCTION sum_map(map<string,bigint>)
RETURNS map<string,bigint>
properties (
"symbol" = "com.starrocks.udf.SumMapInt64",
"type" = "StarrocksJar",
"file" = "${udf_url}/starrocks-jdbc/java-udf.jar?v=2"
);
create table map_table (id int, data map<string, bigint>);
insert into map_table values (1, map{"a": 10, "b": 20}), (1, map{"a": 20, "c": 20}), (2, map{"d": 20, "e": 30}), (2, map{null: 40, "d": 30});
select id, data from map_table order by id;
select id, sum_map(data) from map_table group by id order by id;